diff --git a/slixmpp/clientxmpp.py b/slixmpp/clientxmpp.py index 8ef3493b..c88f3716 100644 --- a/slixmpp/clientxmpp.py +++ b/slixmpp/clientxmpp.py @@ -139,7 +139,7 @@ class ClientXMPP(BaseXMPP): def connect(self, address: Optional[Tuple[str, int]] = None, # type: ignore use_ssl: Optional[bool] = None, force_starttls: Optional[bool] = None, - disable_starttls: Optional[bool] = None) -> None: + disable_starttls: Optional[bool] = None) -> asyncio.Future: """Connect to the XMPP server. When no address is given, a SRV lookup for the server will @@ -166,8 +166,9 @@ class ClientXMPP(BaseXMPP): host, port = (self.boundjid.host, 5222) self.dns_service = 'xmpp-client' - XMLStream.connect(self, host, port, use_ssl=use_ssl, - force_starttls=force_starttls, disable_starttls=disable_starttls) + return XMLStream.connect(self, host, port, use_ssl=use_ssl, + force_starttls=force_starttls, + disable_starttls=disable_starttls) def register_feature(self, name: str, handler: Callable, restart: bool = False, order: int = 5000) -> None: """Register a stream feature handler. diff --git a/slixmpp/componentxmpp.py b/slixmpp/componentxmpp.py index f811987d..ab7f68db 100644 --- a/slixmpp/componentxmpp.py +++ b/slixmpp/componentxmpp.py @@ -9,6 +9,7 @@ import logging import hashlib +from asyncio import Future from typing import Optional from slixmpp import Message, Iq, Presence @@ -97,7 +98,7 @@ class ComponentXMPP(BaseXMPP): def connect(self, host: Optional[str] = None, port: int = 0, use_ssl: Optional[bool] = None, force_starttls: Optional[bool] = None, - disable_starttls: Optional[bool] = None) -> None: + disable_starttls: Optional[bool] = None) -> Future: """Connect to the server. @@ -118,7 +119,7 @@ class ComponentXMPP(BaseXMPP): self.server_name = self.boundjid.host log.debug("Connecting to %s:%s", host, port) - XMLStream.connect(self, host=self.server_host, port=self.server_port, use_ssl=use_ssl) + return XMLStream.connect(self, host=self.server_host, port=self.server_port, use_ssl=use_ssl) def incoming_filter(self, xml): """ diff --git a/slixmpp/xmlstream/xmlstream.py b/slixmpp/xmlstream/xmlstream.py index 90985858..8fae894a 100644 --- a/slixmpp/xmlstream/xmlstream.py +++ b/slixmpp/xmlstream/xmlstream.py @@ -511,7 +511,7 @@ class XMLStream(asyncio.BaseProtocol): ssl_context = None if self._current_connection_attempt is None: - return + return None try: server_hostname = self.default_domain if self.use_ssl else None await self.loop.create_connection(lambda: self, @@ -528,6 +528,7 @@ class XMLStream(asyncio.BaseProtocol): log.debug('Connection failed: %s', e) self.event("connection_failed", e) return self.reschedule_connection_attempt() + return None def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: """Process all the available XMPP events (receiving or sending data on the @@ -672,7 +673,7 @@ class XMLStream(asyncio.BaseProtocol): """ # abort if there is no ongoing connection attempt if self._current_connection_attempt is None: - return + return None self._connect_loop_wait = min(300, self._connect_loop_wait * 2 + 1) self._current_connection_attempt = asyncio.ensure_future( self._connect_routine(),