Merge branch 'reconnect-logic-doomed' into 'master'
fix reconnect logic See merge request poezio/slixmpp!104
This commit is contained in:
commit
dbcd0c6050
@ -174,6 +174,9 @@ class XEP_0198(BasePlugin):
|
|||||||
|
|
||||||
def send_ack(self):
|
def send_ack(self):
|
||||||
"""Send the current ack count to the server."""
|
"""Send the current ack count to the server."""
|
||||||
|
if not self.xmpp.transport:
|
||||||
|
log.debug('Disconnected: not sending ack')
|
||||||
|
return
|
||||||
ack = stanza.Ack(self.xmpp)
|
ack = stanza.Ack(self.xmpp)
|
||||||
ack['h'] = self.handled
|
ack['h'] = self.handled
|
||||||
self.xmpp.send_raw(str(ack))
|
self.xmpp.send_raw(str(ack))
|
||||||
@ -198,20 +201,7 @@ class XEP_0198(BasePlugin):
|
|||||||
# We've already negotiated stream management,
|
# We've already negotiated stream management,
|
||||||
# so no need to do it again.
|
# so no need to do it again.
|
||||||
return False
|
return False
|
||||||
if not self.sm_id:
|
if self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
|
||||||
if 'bind' in self.xmpp.features:
|
|
||||||
enable = stanza.Enable(self.xmpp)
|
|
||||||
enable['resume'] = self.allow_resume
|
|
||||||
enable.send()
|
|
||||||
log.debug("enabling SM")
|
|
||||||
|
|
||||||
waiter = Waiter('enabled_or_failed',
|
|
||||||
MatchMany([
|
|
||||||
MatchXPath(stanza.Enabled.tag_name()),
|
|
||||||
MatchXPath(stanza.Failed.tag_name())]))
|
|
||||||
self.xmpp.register_handler(waiter)
|
|
||||||
result = await waiter.wait()
|
|
||||||
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
|
|
||||||
resume = stanza.Resume(self.xmpp)
|
resume = stanza.Resume(self.xmpp)
|
||||||
resume['h'] = self.handled
|
resume['h'] = self.handled
|
||||||
resume['previd'] = self.sm_id
|
resume['previd'] = self.sm_id
|
||||||
@ -229,6 +219,19 @@ class XEP_0198(BasePlugin):
|
|||||||
result = await waiter.wait()
|
result = await waiter.wait()
|
||||||
if result is not None and result.name == 'resumed':
|
if result is not None and result.name == 'resumed':
|
||||||
return True
|
return True
|
||||||
|
self.xmpp.event("session_end")
|
||||||
|
if 'bind' in self.xmpp.features:
|
||||||
|
enable = stanza.Enable(self.xmpp)
|
||||||
|
enable['resume'] = self.allow_resume
|
||||||
|
enable.send()
|
||||||
|
log.debug("enabling SM")
|
||||||
|
|
||||||
|
waiter = Waiter('enabled_or_failed',
|
||||||
|
MatchMany([
|
||||||
|
MatchXPath(stanza.Enabled.tag_name()),
|
||||||
|
MatchXPath(stanza.Failed.tag_name())]))
|
||||||
|
self.xmpp.register_handler(waiter)
|
||||||
|
result = await waiter.wait()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _handle_enabled(self, stanza):
|
def _handle_enabled(self, stanza):
|
||||||
|
@ -12,7 +12,15 @@
|
|||||||
:license: MIT, see LICENSE for more details
|
:license: MIT, see LICENSE for more details
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Set, Callable, Any
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
@ -21,7 +29,7 @@ import ssl
|
|||||||
import weakref
|
import weakref
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from asyncio import iscoroutinefunction, wait
|
from asyncio import iscoroutinefunction, wait, Future
|
||||||
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
@ -224,12 +232,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.disconnect_reason = None
|
self.disconnect_reason = None
|
||||||
|
|
||||||
#: An asyncio Future being done when the stream is disconnected.
|
#: An asyncio Future being done when the stream is disconnected.
|
||||||
self.disconnected = asyncio.Future()
|
self.disconnected: Future = Future()
|
||||||
|
|
||||||
self.add_event_handler('disconnected', self._remove_schedules)
|
self.add_event_handler('disconnected', self._remove_schedules)
|
||||||
self.add_event_handler('session_start', self._start_keepalive)
|
self.add_event_handler('session_start', self._start_keepalive)
|
||||||
|
|
||||||
self._run_filters = None
|
self._run_out_filters: Optional[Future] = None
|
||||||
|
self.__slow_tasks: List[Future] = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop(self):
|
def loop(self):
|
||||||
@ -250,6 +259,12 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
return uuid.uuid4().hex
|
return uuid.uuid4().hex
|
||||||
|
|
||||||
|
def _set_disconnected_future(self):
|
||||||
|
"""Set the self.disconnected future on disconnect"""
|
||||||
|
if not self.disconnected.done():
|
||||||
|
self.disconnected.set_result(True)
|
||||||
|
self.disconnected = asyncio.Future()
|
||||||
|
|
||||||
def connect(self, host='', port=0, use_ssl=False,
|
def connect(self, host='', port=0, use_ssl=False,
|
||||||
force_starttls=True, disable_starttls=False):
|
force_starttls=True, disable_starttls=False):
|
||||||
"""Create a new socket and connect to the server.
|
"""Create a new socket and connect to the server.
|
||||||
@ -272,8 +287,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
localhost
|
localhost
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if self._run_filters is None:
|
if self._run_out_filters is None or self._run_out_filters.done():
|
||||||
self._run_filters = asyncio.ensure_future(
|
self._run_out_filters = asyncio.ensure_future(
|
||||||
self.run_filters(),
|
self.run_filters(),
|
||||||
loop=self.loop,
|
loop=self.loop,
|
||||||
)
|
)
|
||||||
@ -418,10 +433,10 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
if self.xml_depth == 0:
|
if self.xml_depth == 0:
|
||||||
# The stream's root element has closed,
|
# The stream's root element has closed,
|
||||||
# terminating the stream.
|
# terminating the stream.
|
||||||
self.end_session_on_disconnect = True
|
|
||||||
log.debug("End of stream received")
|
log.debug("End of stream received")
|
||||||
self.disconnect_reason = "End of stream"
|
self.disconnect_reason = "End of stream"
|
||||||
self.abort()
|
self.abort()
|
||||||
|
return
|
||||||
elif self.xml_depth == 1:
|
elif self.xml_depth == 1:
|
||||||
# A stanza is an XML element that is a direct child of
|
# A stanza is an XML element that is a direct child of
|
||||||
# the root element, hence the check of depth == 1
|
# the root element, hence the check of depth == 1
|
||||||
@ -463,11 +478,11 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.parser = None
|
self.parser = None
|
||||||
self.transport = None
|
self.transport = None
|
||||||
self.socket = None
|
self.socket = None
|
||||||
if self._run_filters:
|
|
||||||
self._run_filters.cancel()
|
|
||||||
# Fire the events after cleanup
|
# Fire the events after cleanup
|
||||||
if self.end_session_on_disconnect:
|
if self.end_session_on_disconnect:
|
||||||
|
self._reset_sendq()
|
||||||
self.event('session_end')
|
self.event('session_end')
|
||||||
|
self._set_disconnected_future()
|
||||||
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
|
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
|
||||||
|
|
||||||
def cancel_connection_attempt(self):
|
def cancel_connection_attempt(self):
|
||||||
@ -480,10 +495,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
if self._current_connection_attempt:
|
if self._current_connection_attempt:
|
||||||
self._current_connection_attempt.cancel()
|
self._current_connection_attempt.cancel()
|
||||||
self._current_connection_attempt = None
|
self._current_connection_attempt = None
|
||||||
if self._run_filters:
|
|
||||||
self._run_filters.cancel()
|
|
||||||
|
|
||||||
def disconnect(self, wait: float = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> None:
|
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
|
||||||
"""Close the XML stream and wait for an acknowldgement from the server for
|
"""Close the XML stream and wait for an acknowldgement from the server for
|
||||||
at most `wait` seconds. After the given number of seconds has
|
at most `wait` seconds. After the given number of seconds has
|
||||||
passed without a response from the server, or when the server
|
passed without a response from the server, or when the server
|
||||||
@ -491,10 +504,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
called. If wait is 0.0, this will call abort() directly without closing
|
called. If wait is 0.0, this will call abort() directly without closing
|
||||||
the stream.
|
the stream.
|
||||||
|
|
||||||
Does nothing if we are not connected.
|
Does nothing but trigger the disconnected event if we are not connected.
|
||||||
|
|
||||||
:param wait: Time to wait for a response from the server.
|
:param wait: Time to wait for a response from the server.
|
||||||
|
:param reason: An optional reason for the disconnect.
|
||||||
|
:param ignore_send_queue: Boolean to toggle if we want to ignore
|
||||||
|
the in-flight stanzas and disconnect immediately.
|
||||||
|
:return: A future that ends when all code involved in the disconnect has ended
|
||||||
"""
|
"""
|
||||||
# Compat: docs/getting_started/sendlogout.rst has been promoting
|
# Compat: docs/getting_started/sendlogout.rst has been promoting
|
||||||
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
|
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
|
||||||
@ -504,50 +520,75 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
wait = 2.0
|
wait = 2.0
|
||||||
|
|
||||||
if self.transport:
|
if self.transport:
|
||||||
|
self.disconnect_reason = reason
|
||||||
if self.waiting_queue.empty() or ignore_send_queue:
|
if self.waiting_queue.empty() or ignore_send_queue:
|
||||||
self.disconnect_reason = reason
|
|
||||||
self.cancel_connection_attempt()
|
self.cancel_connection_attempt()
|
||||||
if wait > 0.0:
|
return asyncio.ensure_future(
|
||||||
self.send_raw(self.stream_footer)
|
self._end_stream_wait(wait, reason=reason),
|
||||||
self.schedule('Disconnect wait', wait,
|
loop=self.loop,
|
||||||
self.abort, repeat=False)
|
)
|
||||||
else:
|
else:
|
||||||
asyncio.ensure_future(
|
return asyncio.ensure_future(
|
||||||
self._consume_send_queue_before_disconnecting(reason, wait),
|
self._consume_send_queue_before_disconnecting(reason, wait),
|
||||||
loop=self.loop,
|
loop=self.loop,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
self._set_disconnected_future()
|
||||||
self.event("disconnected", reason)
|
self.event("disconnected", reason)
|
||||||
|
future = Future()
|
||||||
|
future.set_result(None)
|
||||||
|
return future
|
||||||
|
|
||||||
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
|
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
|
||||||
"""Wait until the send queue is empty before disconnecting"""
|
"""Wait until the send queue is empty before disconnecting"""
|
||||||
await self.waiting_queue.join()
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self.waiting_queue.join(),
|
||||||
|
wait,
|
||||||
|
loop=self.loop
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
wait = 0 # we already consumed the timeout
|
||||||
self.disconnect_reason = reason
|
self.disconnect_reason = reason
|
||||||
self.cancel_connection_attempt()
|
await self._end_stream_wait(wait)
|
||||||
if wait > 0.0:
|
|
||||||
|
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Run abort() if we do not received the disconnected event
|
||||||
|
after a waiting time.
|
||||||
|
|
||||||
|
:param wait: The waiting time (defaults to 2)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
self.send_raw(self.stream_footer)
|
self.send_raw(self.stream_footer)
|
||||||
self.schedule('Disconnect wait', wait,
|
await self.wait_until('disconnected', wait)
|
||||||
self.abort, repeat=False)
|
except asyncio.TimeoutError:
|
||||||
|
self.abort()
|
||||||
|
except NotConnectedError:
|
||||||
|
# We are not connected when sending the end of stream
|
||||||
|
# that means the disconnect has already been handled
|
||||||
|
pass
|
||||||
|
|
||||||
def abort(self):
|
def abort(self):
|
||||||
"""
|
"""
|
||||||
Forcibly close the connection
|
Forcibly close the connection
|
||||||
"""
|
"""
|
||||||
self.cancel_connection_attempt()
|
|
||||||
if self.transport:
|
if self.transport:
|
||||||
|
self.cancel_connection_attempt()
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
self.transport.abort()
|
self.transport.abort()
|
||||||
self.event("killed")
|
self.event("killed")
|
||||||
self.disconnected.set_result(True)
|
|
||||||
self.disconnected = asyncio.Future()
|
|
||||||
self.event("disconnected", self.disconnect_reason)
|
|
||||||
|
|
||||||
def reconnect(self, wait=2.0, reason="Reconnecting"):
|
def reconnect(self, wait=2.0, reason="Reconnecting"):
|
||||||
"""Calls disconnect(), and once we are disconnected (after the timeout, or
|
"""Calls disconnect(), and once we are disconnected (after the timeout, or
|
||||||
when the server acknowledgement is received), call connect()
|
when the server acknowledgement is received), call connect()
|
||||||
"""
|
"""
|
||||||
log.debug("reconnecting...")
|
log.debug("reconnecting...")
|
||||||
self.add_event_handler('disconnected', lambda event: self.connect(), disposable=True)
|
async def handler(event):
|
||||||
|
# We yield here to allow synchronous handlers to work first
|
||||||
|
await asyncio.sleep(0, loop=self.loop)
|
||||||
|
self.connect()
|
||||||
|
self.add_event_handler('disconnected', handler, disposable=True)
|
||||||
self.disconnect(wait, reason)
|
self.disconnect(wait, reason)
|
||||||
|
|
||||||
def configure_socket(self):
|
def configure_socket(self):
|
||||||
@ -655,7 +696,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
def _remove_schedules(self, event):
|
def _remove_schedules(self, event):
|
||||||
"""Remove some schedules that become pointless when disconnected"""
|
"""Remove some schedules that become pointless when disconnected"""
|
||||||
self.cancel_schedule('Whitespace Keepalive')
|
self.cancel_schedule('Whitespace Keepalive')
|
||||||
self.cancel_schedule('Disconnect wait')
|
|
||||||
|
|
||||||
def start_stream_handler(self, xml):
|
def start_stream_handler(self, xml):
|
||||||
"""Perform any initialization actions, such as handshakes,
|
"""Perform any initialization actions, such as handshakes,
|
||||||
@ -833,7 +873,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
log.debug("Event triggered: %s", name)
|
log.debug("Event triggered: %s", name)
|
||||||
|
|
||||||
handlers = self.__event_handlers.get(name, [])
|
handlers = self.__event_handlers.get(name, [])[:]
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
handler_callback, disposable = handler
|
handler_callback, disposable = handler
|
||||||
old_exception = getattr(data, 'exception', None)
|
old_exception = getattr(data, 'exception', None)
|
||||||
@ -941,6 +981,18 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
return xml
|
return xml
|
||||||
|
|
||||||
|
def _reset_sendq(self):
|
||||||
|
"""Clear sending tasks on session end"""
|
||||||
|
# Cancel all pending slow send tasks
|
||||||
|
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
|
||||||
|
for slow_task in self.__slow_tasks:
|
||||||
|
slow_task.cancel()
|
||||||
|
self.__slow_tasks.clear()
|
||||||
|
# Purge pending stanzas
|
||||||
|
while not self.waiting_queue.empty():
|
||||||
|
discarded = self.waiting_queue.get_nowait()
|
||||||
|
log.debug('Discarded stanza: %s', discarded)
|
||||||
|
|
||||||
async def _continue_slow_send(
|
async def _continue_slow_send(
|
||||||
self,
|
self,
|
||||||
task: asyncio.Task,
|
task: asyncio.Task,
|
||||||
@ -954,6 +1006,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
:param set already_used: Filters already used on this outgoing stanza
|
:param set already_used: Filters already used on this outgoing stanza
|
||||||
"""
|
"""
|
||||||
data = await task
|
data = await task
|
||||||
|
self.__slow_tasks.remove(task)
|
||||||
for filter in self.__filters['out']:
|
for filter in self.__filters['out']:
|
||||||
if filter in already_used:
|
if filter in already_used:
|
||||||
continue
|
continue
|
||||||
@ -975,7 +1028,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
else:
|
else:
|
||||||
self.send_raw(data)
|
self.send_raw(data)
|
||||||
|
|
||||||
|
|
||||||
async def run_filters(self):
|
async def run_filters(self):
|
||||||
"""
|
"""
|
||||||
Background loop that processes stanzas to send.
|
Background loop that processes stanzas to send.
|
||||||
@ -995,11 +1047,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
timeout=1,
|
timeout=1,
|
||||||
)
|
)
|
||||||
if pending:
|
if pending:
|
||||||
|
self.slow_tasks.append(task)
|
||||||
asyncio.ensure_future(
|
asyncio.ensure_future(
|
||||||
self._continue_slow_send(
|
self._continue_slow_send(
|
||||||
task,
|
task,
|
||||||
already_run_filters
|
already_run_filters
|
||||||
)
|
),
|
||||||
|
loop=self.loop,
|
||||||
)
|
)
|
||||||
raise Exception("Slow coro, rescheduling")
|
raise Exception("Slow coro, rescheduling")
|
||||||
data = task.result()
|
data = task.result()
|
||||||
@ -1142,9 +1196,15 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
:param int timeout: Timeout
|
:param int timeout: Timeout
|
||||||
"""
|
"""
|
||||||
fut = asyncio.Future()
|
fut = asyncio.Future()
|
||||||
|
def result_handler(event_data):
|
||||||
|
if not fut.done():
|
||||||
|
fut.set_result(event_data)
|
||||||
|
else:
|
||||||
|
log.debug("Future registered on event '%s' was alredy done", event)
|
||||||
|
|
||||||
self.add_event_handler(
|
self.add_event_handler(
|
||||||
event,
|
event,
|
||||||
fut.set_result,
|
result_handler,
|
||||||
disposable=True,
|
disposable=True,
|
||||||
)
|
)
|
||||||
return await asyncio.wait_for(fut, timeout)
|
return await asyncio.wait_for(fut, timeout, loop=self.loop)
|
||||||
|
Loading…
Reference in New Issue
Block a user