From f94a4f2dbd04fefc40d759c729132b6ac96eac8b Mon Sep 17 00:00:00 2001 From: mathieui Date: Sun, 9 Feb 2025 12:07:56 +0100 Subject: [PATCH] xmlstream: return a future on connect() which can make sense for users of the lib to wait on. --- slixmpp/xmlstream/xmlstream.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 97a356bb..60928999 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -409,9 +409,10 @@ class XMLStream(asyncio.BaseProtocol): self.disconnected.set_result(True) self.disconnected = asyncio.Future() - def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = None, + def connect(self, host: str = '', port: int = 0, + use_ssl: Optional[bool] = None, force_starttls: Optional[bool] = None, - disable_starttls: Optional[bool] = None) -> None: + disable_starttls: Optional[bool] = None) -> asyncio.Future: """Create a new socket and connect to the server. :param host: The name of the desired server for the connection. @@ -430,6 +431,7 @@ class XMLStream(asyncio.BaseProtocol): upgrade to TLS, even if the server provides it. Use this for example if you’re on localhost + :returns: A future on the current connection attempt """ if self._run_out_filters is None or self._run_out_filters.done(): @@ -461,8 +463,14 @@ class XMLStream(asyncio.BaseProtocol): self._connect_routine(), loop=self.loop, ) + return self._current_connection_attempt - async def _connect_routine(self) -> None: + async def _connect_routine(self) -> Optional[asyncio.Future]: + """ + Returns None if the attempt was canceled or if the connection succeeded + (cancelling done manually by the library user, so that should be known) + or the next connection attempt future if a new try has been scheduled. + """ self.event_when_connected = "connected" if self._connect_loop_wait > 0: @@ -499,11 +507,11 @@ class XMLStream(asyncio.BaseProtocol): except Socket.gaierror as e: self.event('connection_failed', 'No DNS record available for %s' % self.default_domain) - self.reschedule_connection_attempt() + return self.reschedule_connection_attempt() except OSError as e: log.debug('Connection failed: %s', e) self.event("connection_failed", e) - self.reschedule_connection_attempt() + return self.reschedule_connection_attempt() def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: """Process all the available XMPP events (receiving or sending data on the @@ -639,10 +647,12 @@ class XMLStream(asyncio.BaseProtocol): self._set_disconnected_future() self.event("disconnected", self.disconnect_reason or exception) - def reschedule_connection_attempt(self) -> None: + def reschedule_connection_attempt(self) -> Optional[asyncio.Future]: """ Increase the exponential back-off and initate another background _connect_routine call to connect to the server. + + :returns: A future on the next scheduled connection attempt. """ # abort if there is no ongoing connection attempt if self._current_connection_attempt is None: @@ -652,6 +662,7 @@ class XMLStream(asyncio.BaseProtocol): self._connect_routine(), loop=self.loop, ) + return self._current_connection_attempt def cancel_connection_attempt(self) -> None: """