xmlstream: return a future on connect()

which can make sense for users of the lib to wait on.
This commit is contained in:
mathieui 2025-02-09 12:07:56 +01:00
parent 75ea0bf039
commit f94a4f2dbd
No known key found for this signature in database
GPG Key ID: C59F84CEEFD616E3

View File

@ -409,9 +409,10 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnected.set_result(True) self.disconnected.set_result(True)
self.disconnected = asyncio.Future() self.disconnected = asyncio.Future()
def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = None, def connect(self, host: str = '', port: int = 0,
use_ssl: Optional[bool] = None,
force_starttls: Optional[bool] = None, force_starttls: Optional[bool] = None,
disable_starttls: Optional[bool] = None) -> None: disable_starttls: Optional[bool] = None) -> asyncio.Future:
"""Create a new socket and connect to the server. """Create a new socket and connect to the server.
:param host: The name of the desired server for the connection. :param host: The name of the desired server for the connection.
@ -430,6 +431,7 @@ class XMLStream(asyncio.BaseProtocol):
upgrade to TLS, even if the server provides upgrade to TLS, even if the server provides
it. Use this for example if youre on it. Use this for example if youre on
localhost localhost
:returns: A future on the current connection attempt
""" """
if self._run_out_filters is None or self._run_out_filters.done(): if self._run_out_filters is None or self._run_out_filters.done():
@ -461,8 +463,14 @@ class XMLStream(asyncio.BaseProtocol):
self._connect_routine(), self._connect_routine(),
loop=self.loop, loop=self.loop,
) )
return self._current_connection_attempt
async def _connect_routine(self) -> None: async def _connect_routine(self) -> Optional[asyncio.Future]:
"""
Returns None if the attempt was canceled or if the connection succeeded
(cancelling done manually by the library user, so that should be known)
or the next connection attempt future if a new try has been scheduled.
"""
self.event_when_connected = "connected" self.event_when_connected = "connected"
if self._connect_loop_wait > 0: if self._connect_loop_wait > 0:
@ -499,11 +507,11 @@ class XMLStream(asyncio.BaseProtocol):
except Socket.gaierror as e: except Socket.gaierror as e:
self.event('connection_failed', self.event('connection_failed',
'No DNS record available for %s' % self.default_domain) 'No DNS record available for %s' % self.default_domain)
self.reschedule_connection_attempt() return self.reschedule_connection_attempt()
except OSError as e: except OSError as e:
log.debug('Connection failed: %s', e) log.debug('Connection failed: %s', e)
self.event("connection_failed", e) self.event("connection_failed", e)
self.reschedule_connection_attempt() return self.reschedule_connection_attempt()
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the """Process all the available XMPP events (receiving or sending data on the
@ -639,10 +647,12 @@ class XMLStream(asyncio.BaseProtocol):
self._set_disconnected_future() self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception) self.event("disconnected", self.disconnect_reason or exception)
def reschedule_connection_attempt(self) -> None: def reschedule_connection_attempt(self) -> Optional[asyncio.Future]:
""" """
Increase the exponential back-off and initate another background Increase the exponential back-off and initate another background
_connect_routine call to connect to the server. _connect_routine call to connect to the server.
:returns: A future on the next scheduled connection attempt.
""" """
# abort if there is no ongoing connection attempt # abort if there is no ongoing connection attempt
if self._current_connection_attempt is None: if self._current_connection_attempt is None:
@ -652,6 +662,7 @@ class XMLStream(asyncio.BaseProtocol):
self._connect_routine(), self._connect_routine(),
loop=self.loop, loop=self.loop,
) )
return self._current_connection_attempt
def cancel_connection_attempt(self) -> None: def cancel_connection_attempt(self) -> None:
""" """