Fix connect parameters used for follow-up calls

XMLStream.connect() is supposed to persist the parameters
it gets called with to allow follow-up calls to call
XMLStream.connect() without any parameters to result in a connection
with the same properties as the original one. That's for example used by
XMLStream.reconnect() when establishing a new connection.

Unfortunately that was broken for some of the parameters and resulted
different TLS related settings on reconnections. This commit fixes that.
This commit is contained in:
Daniel Roschka 2023-01-08 09:23:02 +01:00 committed by mathieui
parent 76a11d4899
commit 7de5cbcf33
No known key found for this signature in database
GPG Key ID: C59F84CEEFD616E3
3 changed files with 17 additions and 15 deletions

View File

@ -138,8 +138,8 @@ class ClientXMPP(BaseXMPP):
self.credentials['password'] = value self.credentials['password'] = value
def connect(self, address: Optional[Tuple[str, int]] = None, # type: ignore def connect(self, address: Optional[Tuple[str, int]] = None, # type: ignore
use_ssl: bool = False, force_starttls: bool = True, use_ssl: Optional[bool] = None, force_starttls: Optional[bool] = None,
disable_starttls: bool = False) -> None: disable_starttls: Optional[bool] = None) -> None:
"""Connect to the XMPP server. """Connect to the XMPP server.
When no address is given, a SRV lookup for the server will When no address is given, a SRV lookup for the server will
@ -166,8 +166,8 @@ class ClientXMPP(BaseXMPP):
host, port = (self.boundjid.host, 5222) host, port = (self.boundjid.host, 5222)
self.dns_service = 'xmpp-client' self.dns_service = 'xmpp-client'
return XMLStream.connect(self, host, port, use_ssl=use_ssl, XMLStream.connect(self, host, port, use_ssl=use_ssl,
force_starttls=force_starttls, disable_starttls=disable_starttls) force_starttls=force_starttls, disable_starttls=disable_starttls)
def register_feature(self, name: str, handler: Callable, restart: bool = False, order: int = 5000) -> None: def register_feature(self, name: str, handler: Callable, restart: bool = False, order: int = 5000) -> None:
"""Register a stream feature handler. """Register a stream feature handler.

View File

@ -9,6 +9,8 @@
import logging import logging
import hashlib import hashlib
from typing import Optional
from slixmpp import Message, Iq, Presence from slixmpp import Message, Iq, Presence
from slixmpp.basexmpp import BaseXMPP from slixmpp.basexmpp import BaseXMPP
from slixmpp.stanza import Handshake from slixmpp.stanza import Handshake
@ -93,7 +95,7 @@ class ComponentXMPP(BaseXMPP):
for st in Message, Iq, Presence: for st in Message, Iq, Presence:
register_stanza_plugin(st, Error) register_stanza_plugin(st, Error)
def connect(self, host=None, port=None, use_ssl=False): def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = None) -> None:
"""Connect to the server. """Connect to the server.
@ -104,16 +106,15 @@ class ComponentXMPP(BaseXMPP):
:param use_ssl: Flag indicating if SSL should be used by connecting :param use_ssl: Flag indicating if SSL should be used by connecting
directly to a port using SSL. directly to a port using SSL.
""" """
if host is None: if host is not None:
host = self.server_host self.server_host = host
if port is None: if port:
port = self.server_port self.server_port = port
self.server_name = self.boundjid.host self.server_name = self.boundjid.host
log.debug("Connecting to %s:%s", host, port) log.debug("Connecting to %s:%s", host, port)
return XMLStream.connect(self, host=host, port=port, XMLStream.connect(self, host=self.server_host, port=self.server_port, use_ssl=use_ssl)
use_ssl=use_ssl)
def incoming_filter(self, xml): def incoming_filter(self, xml):
""" """

View File

@ -290,8 +290,8 @@ class XMLStream(asyncio.BaseProtocol):
self.xml_depth = 0 self.xml_depth = 0
self.xml_root = None self.xml_root = None
self.force_starttls = None self.force_starttls = True
self.disable_starttls = None self.disable_starttls = False
self.waiting_queue = asyncio.Queue() self.waiting_queue = asyncio.Queue()
@ -405,8 +405,9 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnected.set_result(True) self.disconnected.set_result(True)
self.disconnected = asyncio.Future() self.disconnected = asyncio.Future()
def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False, def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = None,
force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None: force_starttls: Optional[bool] = None,
disable_starttls: Optional[bool] = None) -> None:
"""Create a new socket and connect to the server. """Create a new socket and connect to the server.
:param host: The name of the desired server for the connection. :param host: The name of the desired server for the connection.