Merge branch 'reconnect-logic-doomed' into 'master'

fix reconnect logic

See merge request poezio/slixmpp!104
This commit is contained in:
Link Mauve 2021-01-29 16:11:29 +01:00
commit dbcd0c6050
2 changed files with 114 additions and 51 deletions

View File

@ -174,6 +174,9 @@ class XEP_0198(BasePlugin):
def send_ack(self): def send_ack(self):
"""Send the current ack count to the server.""" """Send the current ack count to the server."""
if not self.xmpp.transport:
log.debug('Disconnected: not sending ack')
return
ack = stanza.Ack(self.xmpp) ack = stanza.Ack(self.xmpp)
ack['h'] = self.handled ack['h'] = self.handled
self.xmpp.send_raw(str(ack)) self.xmpp.send_raw(str(ack))
@ -198,20 +201,7 @@ class XEP_0198(BasePlugin):
# We've already negotiated stream management, # We've already negotiated stream management,
# so no need to do it again. # so no need to do it again.
return False return False
if not self.sm_id: if self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
if 'bind' in self.xmpp.features:
enable = stanza.Enable(self.xmpp)
enable['resume'] = self.allow_resume
enable.send()
log.debug("enabling SM")
waiter = Waiter('enabled_or_failed',
MatchMany([
MatchXPath(stanza.Enabled.tag_name()),
MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter)
result = await waiter.wait()
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
resume = stanza.Resume(self.xmpp) resume = stanza.Resume(self.xmpp)
resume['h'] = self.handled resume['h'] = self.handled
resume['previd'] = self.sm_id resume['previd'] = self.sm_id
@ -229,6 +219,19 @@ class XEP_0198(BasePlugin):
result = await waiter.wait() result = await waiter.wait()
if result is not None and result.name == 'resumed': if result is not None and result.name == 'resumed':
return True return True
self.xmpp.event("session_end")
if 'bind' in self.xmpp.features:
enable = stanza.Enable(self.xmpp)
enable['resume'] = self.allow_resume
enable.send()
log.debug("enabling SM")
waiter = Waiter('enabled_or_failed',
MatchMany([
MatchXPath(stanza.Enabled.tag_name()),
MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter)
result = await waiter.wait()
return False return False
def _handle_enabled(self, stanza): def _handle_enabled(self, stanza):

View File

@ -12,7 +12,15 @@
:license: MIT, see LICENSE for more details :license: MIT, see LICENSE for more details
""" """
from typing import Optional, Set, Callable, Any from typing import (
Any,
Callable,
Iterable,
List,
Optional,
Set,
Union,
)
import functools import functools
import logging import logging
@ -21,7 +29,7 @@ import ssl
import weakref import weakref
import uuid import uuid
from asyncio import iscoroutinefunction, wait from asyncio import iscoroutinefunction, wait, Future
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@ -224,12 +232,13 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = None self.disconnect_reason = None
#: An asyncio Future being done when the stream is disconnected. #: An asyncio Future being done when the stream is disconnected.
self.disconnected = asyncio.Future() self.disconnected: Future = Future()
self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('session_start', self._start_keepalive)
self._run_filters = None self._run_out_filters: Optional[Future] = None
self.__slow_tasks: List[Future] = []
@property @property
def loop(self): def loop(self):
@ -250,6 +259,12 @@ class XMLStream(asyncio.BaseProtocol):
""" """
return uuid.uuid4().hex return uuid.uuid4().hex
def _set_disconnected_future(self):
"""Set the self.disconnected future on disconnect"""
if not self.disconnected.done():
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
def connect(self, host='', port=0, use_ssl=False, def connect(self, host='', port=0, use_ssl=False,
force_starttls=True, disable_starttls=False): force_starttls=True, disable_starttls=False):
"""Create a new socket and connect to the server. """Create a new socket and connect to the server.
@ -272,8 +287,8 @@ class XMLStream(asyncio.BaseProtocol):
localhost localhost
""" """
if self._run_filters is None: if self._run_out_filters is None or self._run_out_filters.done():
self._run_filters = asyncio.ensure_future( self._run_out_filters = asyncio.ensure_future(
self.run_filters(), self.run_filters(),
loop=self.loop, loop=self.loop,
) )
@ -418,10 +433,10 @@ class XMLStream(asyncio.BaseProtocol):
if self.xml_depth == 0: if self.xml_depth == 0:
# The stream's root element has closed, # The stream's root element has closed,
# terminating the stream. # terminating the stream.
self.end_session_on_disconnect = True
log.debug("End of stream received") log.debug("End of stream received")
self.disconnect_reason = "End of stream" self.disconnect_reason = "End of stream"
self.abort() self.abort()
return
elif self.xml_depth == 1: elif self.xml_depth == 1:
# A stanza is an XML element that is a direct child of # A stanza is an XML element that is a direct child of
# the root element, hence the check of depth == 1 # the root element, hence the check of depth == 1
@ -463,11 +478,11 @@ class XMLStream(asyncio.BaseProtocol):
self.parser = None self.parser = None
self.transport = None self.transport = None
self.socket = None self.socket = None
if self._run_filters:
self._run_filters.cancel()
# Fire the events after cleanup # Fire the events after cleanup
if self.end_session_on_disconnect: if self.end_session_on_disconnect:
self._reset_sendq()
self.event('session_end') self.event('session_end')
self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception and exception.strerror) self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
def cancel_connection_attempt(self): def cancel_connection_attempt(self):
@ -480,10 +495,8 @@ class XMLStream(asyncio.BaseProtocol):
if self._current_connection_attempt: if self._current_connection_attempt:
self._current_connection_attempt.cancel() self._current_connection_attempt.cancel()
self._current_connection_attempt = None self._current_connection_attempt = None
if self._run_filters:
self._run_filters.cancel()
def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None: def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
"""Close the XML stream and wait for an acknowldgement from the server for """Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has at most `wait` seconds. After the given number of seconds has
passed without a response from the server, or when the server passed without a response from the server, or when the server
@ -491,10 +504,13 @@ class XMLStream(asyncio.BaseProtocol):
called. If wait is 0.0, this will call abort() directly without closing called. If wait is 0.0, this will call abort() directly without closing
the stream. the stream.
Does nothing if we are not connected. Does nothing but trigger the disconnected event if we are not connected.
:param wait: Time to wait for a response from the server. :param wait: Time to wait for a response from the server.
:param reason: An optional reason for the disconnect.
:param ignore_send_queue: Boolean to toggle if we want to ignore
the in-flight stanzas and disconnect immediately.
:return: A future that ends when all code involved in the disconnect has ended
""" """
# Compat: docs/getting_started/sendlogout.rst has been promoting # Compat: docs/getting_started/sendlogout.rst has been promoting
# `disconnect(wait=True)` for ages. This doesn't mean anything to the # `disconnect(wait=True)` for ages. This doesn't mean anything to the
@ -504,50 +520,75 @@ class XMLStream(asyncio.BaseProtocol):
wait = 2.0 wait = 2.0
if self.transport: if self.transport:
self.disconnect_reason = reason
if self.waiting_queue.empty() or ignore_send_queue: if self.waiting_queue.empty() or ignore_send_queue:
self.disconnect_reason = reason
self.cancel_connection_attempt() self.cancel_connection_attempt()
if wait > 0.0: return asyncio.ensure_future(
self.send_raw(self.stream_footer) self._end_stream_wait(wait, reason=reason),
self.schedule('Disconnect wait', wait, loop=self.loop,
self.abort, repeat=False) )
else: else:
asyncio.ensure_future( return asyncio.ensure_future(
self._consume_send_queue_before_disconnecting(reason, wait), self._consume_send_queue_before_disconnecting(reason, wait),
loop=self.loop, loop=self.loop,
) )
else: else:
self._set_disconnected_future()
self.event("disconnected", reason) self.event("disconnected", reason)
future = Future()
future.set_result(None)
return future
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
"""Wait until the send queue is empty before disconnecting""" """Wait until the send queue is empty before disconnecting"""
await self.waiting_queue.join() try:
await asyncio.wait_for(
self.waiting_queue.join(),
wait,
loop=self.loop
)
except asyncio.TimeoutError:
wait = 0 # we already consumed the timeout
self.disconnect_reason = reason self.disconnect_reason = reason
self.cancel_connection_attempt() await self._end_stream_wait(wait)
if wait > 0.0:
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
"""
Run abort() if we do not received the disconnected event
after a waiting time.
:param wait: The waiting time (defaults to 2)
"""
try:
self.send_raw(self.stream_footer) self.send_raw(self.stream_footer)
self.schedule('Disconnect wait', wait, await self.wait_until('disconnected', wait)
self.abort, repeat=False) except asyncio.TimeoutError:
self.abort()
except NotConnectedError:
# We are not connected when sending the end of stream
# that means the disconnect has already been handled
pass
def abort(self): def abort(self):
""" """
Forcibly close the connection Forcibly close the connection
""" """
self.cancel_connection_attempt()
if self.transport: if self.transport:
self.cancel_connection_attempt()
self.transport.close() self.transport.close()
self.transport.abort() self.transport.abort()
self.event("killed") self.event("killed")
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
self.event("disconnected", self.disconnect_reason)
def reconnect(self, wait=2.0, reason="Reconnecting"): def reconnect(self, wait=2.0, reason="Reconnecting"):
"""Calls disconnect(), and once we are disconnected (after the timeout, or """Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect() when the server acknowledgement is received), call connect()
""" """
log.debug("reconnecting...") log.debug("reconnecting...")
self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True) async def handler(event):
# We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason) self.disconnect(wait, reason)
def configure_socket(self): def configure_socket(self):
@ -655,7 +696,6 @@ class XMLStream(asyncio.BaseProtocol):
def _remove_schedules(self, event): def _remove_schedules(self, event):
"""Remove some schedules that become pointless when disconnected""" """Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive') self.cancel_schedule('Whitespace Keepalive')
self.cancel_schedule('Disconnect wait')
def start_stream_handler(self, xml): def start_stream_handler(self, xml):
"""Perform any initialization actions, such as handshakes, """Perform any initialization actions, such as handshakes,
@ -833,7 +873,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
log.debug("Event triggered: %s", name) log.debug("Event triggered: %s", name)
handlers = self.__event_handlers.get(name, []) handlers = self.__event_handlers.get(name, [])[:]
for handler in handlers: for handler in handlers:
handler_callback, disposable = handler handler_callback, disposable = handler
old_exception = getattr(data, 'exception', None) old_exception = getattr(data, 'exception', None)
@ -941,6 +981,18 @@ class XMLStream(asyncio.BaseProtocol):
""" """
return xml return xml
def _reset_sendq(self):
"""Clear sending tasks on session end"""
# Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
for slow_task in self.__slow_tasks:
slow_task.cancel()
self.__slow_tasks.clear()
# Purge pending stanzas
while not self.waiting_queue.empty():
discarded = self.waiting_queue.get_nowait()
log.debug('Discarded stanza: %s', discarded)
async def _continue_slow_send( async def _continue_slow_send(
self, self,
task: asyncio.Task, task: asyncio.Task,
@ -954,6 +1006,7 @@ class XMLStream(asyncio.BaseProtocol):
:param set already_used: Filters already used on this outgoing stanza :param set already_used: Filters already used on this outgoing stanza
""" """
data = await task data = await task
self.__slow_tasks.remove(task)
for filter in self.__filters['out']: for filter in self.__filters['out']:
if filter in already_used: if filter in already_used:
continue continue
@ -975,7 +1028,6 @@ class XMLStream(asyncio.BaseProtocol):
else: else:
self.send_raw(data) self.send_raw(data)
async def run_filters(self): async def run_filters(self):
""" """
Background loop that processes stanzas to send. Background loop that processes stanzas to send.
@ -995,11 +1047,13 @@ class XMLStream(asyncio.BaseProtocol):
timeout=1, timeout=1,
) )
if pending: if pending:
self.slow_tasks.append(task)
asyncio.ensure_future( asyncio.ensure_future(
self._continue_slow_send( self._continue_slow_send(
task, task,
already_run_filters already_run_filters
) ),
loop=self.loop,
) )
raise Exception("Slow coro, rescheduling") raise Exception("Slow coro, rescheduling")
data = task.result() data = task.result()
@ -1142,9 +1196,15 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout :param int timeout: Timeout
""" """
fut = asyncio.Future() fut = asyncio.Future()
def result_handler(event_data):
if not fut.done():
fut.set_result(event_data)
else:
log.debug("Future registered on event '%s' was alredy done", event)
self.add_event_handler( self.add_event_handler(
event, event,
fut.set_result, result_handler,
disposable=True, disposable=True,
) )
return await asyncio.wait_for(fut, timeout) return await asyncio.wait_for(fut, timeout, loop=self.loop)