xmlstream: add more types

This commit is contained in:
mathieui 2021-07-03 11:07:01 +02:00
parent c07476e7de
commit db48c8f4da

View File

@ -9,17 +9,24 @@
# :license: MIT, see LICENSE for more details # :license: MIT, see LICENSE for more details
from typing import ( from typing import (
Any, Any,
Dict,
Awaitable,
Generator,
Coroutine, Coroutine,
Callable, Callable,
Iterable,
Iterator, Iterator,
List, List,
Optional, Optional,
Set, Set,
Union, Union,
Tuple, Tuple,
TypeVar,
NoReturn,
Type,
cast,
) )
import asyncio
import functools import functools
import logging import logging
import socket as Socket import socket as Socket
@ -27,30 +34,66 @@ import ssl
import weakref import weakref
import uuid import uuid
import asyncio
from asyncio import iscoroutinefunction, wait, Future
from contextlib import contextmanager from contextlib import contextmanager
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from asyncio import (
AbstractEventLoop,
BaseTransport,
Future,
Task,
TimerHandle,
Transport,
iscoroutinefunction,
wait,
)
from slixmpp.xmlstream import tostring from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
from slixmpp.xmlstream.resolver import resolve, default_resolver from slixmpp.xmlstream.resolver import resolve, default_resolver
from slixmpp.xmlstream.handler.base import BaseHandler
T = TypeVar('T')
#: The time in seconds to wait before timing out waiting for response stanzas. #: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30 RESPONSE_TIMEOUT = 30
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class ContinueQueue(Exception): class ContinueQueue(Exception):
""" """
Exception raised in the send queue to "continue" from within an inner loop Exception raised in the send queue to "continue" from within an inner loop
""" """
class NotConnectedError(Exception): class NotConnectedError(Exception):
""" """
Raised when we try to send something over the wire but we are not Raised when we try to send something over the wire but we are not
connected. connected.
""" """
_T = TypeVar('_T', str, ElementBase, StanzaBase)
SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
Filter = Union[
SyncFilter,
AsyncFilter,
]
_FiltersDict = Dict[str, List[Filter]]
Handler = Callable[[Any], Union[
Any,
Coroutine[Any, Any, Any]
]]
class XMLStream(asyncio.BaseProtocol): class XMLStream(asyncio.BaseProtocol):
""" """
An XML stream connection manager and event dispatcher. An XML stream connection manager and event dispatcher.
@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0. :param int port: The port to use for the connection. Defaults to 0.
""" """
def __init__(self, host='', port=0): transport: Optional[Transport]
# The asyncio.Transport object provided by the connection_made()
# callback when we are connected
self.transport = None
# The socket that is used internally by the transport object # The socket that is used internally by the transport object
self.socket = None socket: Optional[ssl.SSLSocket]
# The backoff of the connect routine (increases exponentially # The backoff of the connect routine (increases exponentially
# after each failure) # after each failure)
_connect_loop_wait: float
parser: Optional[ET.XMLPullParser]
xml_depth: int
xml_root: Optional[ET.Element]
force_starttls: Optional[bool]
disable_starttls: Optional[bool]
waiting_queue: asyncio.Queue[Tuple[Union[StanzaBase, str], bool]]
# A dict of {name: handle}
scheduled_events: Dict[str, TimerHandle]
ssl_context: ssl.SSLContext
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
event_when_connected: str
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
ciphers: Optional[str]
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
ca_certs: Optional[str]
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
certfile: Optional[str]
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
keyfile: Optional[str]
# The asyncio event loop
_loop: Optional[AbstractEventLoop]
#: The default port to return when querying DNS records.
default_port: int
#: The domain to try when querying DNS records.
default_domain: str
#: The expected name of the server, for validation.
_expected_server_name: str
_service_name: str
#: The desired, or actual, address of the connected server.
address: Tuple[str, int]
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
use_ssl: bool
#: If set to ``True``, attempt to use IPv6.
use_ipv6: bool
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
use_aiodns: bool
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
use_cdata: bool
#: The default namespace of the stream content, not of the
#: stream wrapper it
default_ns: str
default_lang: Optional[str]
peer_default_lang: Optional[str]
#: The namespace of the enveloping stream element.
stream_ns: str
#: The default opening tag for the stream element.
stream_header: str
#: The default closing tag for the stream element.
stream_footer: str
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
whitespace_keepalive: bool
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
whitespace_keepalive_interval: int
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
end_session_on_disconnect: bool
#: A mapping of XML namespaces to well-known prefixes.
namespace_map: dict
__root_stanza: List[Type[StanzaBase]]
__handlers: List[BaseHandler]
__event_handlers: Dict[str, List[Tuple[Handler, bool]]]
__filters: _FiltersDict
# Current connection attempt (Future)
_current_connection_attempt: Optional[Future[None]]
#: A list of DNS results that have not yet been tried.
_dns_answers: Optional[Iterator[Tuple[str, str, int]]]
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
dns_service: Optional[str]
#: The reason why we are disconnecting from the server
disconnect_reason: Optional[str]
#: An asyncio Future being done when the stream is disconnected.
disconnected: Future[bool]
# If the session has been started or not
_session_started: bool
# If we want to bypass the send() check (e.g. unit tests)
_always_send_everything: bool
_run_out_filters: Optional[Future]
__slow_tasks: List[Task]
__queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
def __init__(self, host: str = '', port: int = 0):
self.transport = None
self.socket = None
self._connect_loop_wait = 0 self._connect_loop_wait = 0
self.parser = None self.parser = None
@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.check_hostname = False self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE self.ssl_context.verify_mode = ssl.CERT_NONE
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
self.event_when_connected = "connected" self.event_when_connected = "connected"
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
self.ciphers = None self.ciphers = None
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
self.ca_certs = None self.ca_certs = None
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
self.certfile = None
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
self.keyfile = None self.keyfile = None
self._der_cert = None
# The asyncio event loop
self._loop = None self._loop = None
#: The default port to return when querying DNS records.
self.default_port = int(port) self.default_port = int(port)
#: The domain to try when querying DNS records.
self.default_domain = '' self.default_domain = ''
#: The expected name of the server, for validation.
self._expected_server_name = '' self._expected_server_name = ''
self._service_name = '' self._service_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port)) self.address = (host, int(port))
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
self.use_ssl = False self.use_ssl = False
#: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True self.use_ipv6 = True
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
self.use_aiodns = True self.use_aiodns = True
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
self.use_cdata = False self.use_cdata = False
#: The default namespace of the stream content, not of the
#: stream wrapper itself.
self.default_ns = '' self.default_ns = ''
self.default_lang = None self.default_lang = None
self.peer_default_lang = None self.peer_default_lang = None
#: The namespace of the enveloping stream element.
self.stream_ns = '' self.stream_ns = ''
#: The default opening tag for the stream element.
self.stream_header = "<stream>" self.stream_header = "<stream>"
#: The default closing tag for the stream element.
self.stream_footer = "</stream>" self.stream_footer = "</stream>"
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
self.whitespace_keepalive = True self.whitespace_keepalive = True
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
self.whitespace_keepalive_interval = 300 self.whitespace_keepalive_interval = 300
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
self.end_session_on_disconnect = True self.end_session_on_disconnect = True
#: A mapping of XML namespaces to well-known prefixes.
self.namespace_map = {StanzaBase.xml_ns: 'xml'} self.namespace_map = {StanzaBase.xml_ns: 'xml'}
self.__root_stanza = [] self.__root_stanza = []
self.__handlers = [] self.__handlers = []
self.__event_handlers = {} self.__event_handlers = {}
self.__filters = {'in': [], 'out': [], 'out_sync': []} self.__filters = {
'in': [], 'out': [], 'out_sync': []
}
# Current connection attempt (Future)
self._current_connection_attempt = None self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried. self._dns_answers = None
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
self.dns_service = None self.dns_service = None
#: The reason why we are disconnecting from the server
self.disconnect_reason = None self.disconnect_reason = None
self.disconnected = Future()
#: An asyncio Future being done when the stream is disconnected.
self.disconnected: Future = Future()
# If the session has been started or not
self._session_started = False self._session_started = False
# If we want to bypass the send() check (e.g. unit tests)
self._always_send_everything = False self._always_send_everything = False
self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('disconnected', self._remove_schedules)
@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('session_start', self._set_session_start) self.add_event_handler('session_start', self._set_session_start)
self.add_event_handler('session_resumed', self._set_session_start) self.add_event_handler('session_resumed', self._set_session_start)
self._run_out_filters: Optional[Future] = None self._run_out_filters = None
self.__slow_tasks: List[Future] = [] self.__slow_tasks = []
self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = [] self.__queued_stanzas = []
@property @property
def loop(self): def loop(self) -> AbstractEventLoop:
if self._loop is None: if self._loop is None:
self._loop = asyncio.get_event_loop() self._loop = asyncio.get_event_loop()
return self._loop return self._loop
@loop.setter @loop.setter
def loop(self, value): def loop(self, value: AbstractEventLoop) -> None:
self._loop = value self._loop = value
def new_id(self): def new_id(self) -> str:
"""Generate and return a new stream ID in hexadecimal form. """Generate and return a new stream ID in hexadecimal form.
Many stanzas, handlers, or matchers may require unique Many stanzas, handlers, or matchers may require unique
@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
return uuid.uuid4().hex return uuid.uuid4().hex
def _set_session_start(self, event): def _set_session_start(self, event: Any) -> None:
""" """
On session start, queue all pending stanzas to be sent. On session start, queue all pending stanzas to be sent.
""" """
@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
self.waiting_queue.put_nowait(stanza) self.waiting_queue.put_nowait(stanza)
self.__queued_stanzas = [] self.__queued_stanzas = []
def _set_disconnected(self, event): def _set_disconnected(self, event: Any) -> None:
self._session_started = False self._session_started = False
def _set_disconnected_future(self): def _set_disconnected_future(self) -> None:
"""Set the self.disconnected future on disconnect""" """Set the self.disconnected future on disconnect"""
if not self.disconnected.done(): if not self.disconnected.done():
self.disconnected.set_result(True) self.disconnected.set_result(True)
self.disconnected = asyncio.Future() self.disconnected = asyncio.Future()
def connect(self, host='', port=0, use_ssl=False, def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
force_starttls=True, disable_starttls=False): force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
"""Create a new socket and connect to the server. """Create a new socket and connect to the server.
:param host: The name of the desired server for the connection. :param host: The name of the desired server for the connection.
@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop, loop=self.loop,
) )
async def _connect_routine(self): async def _connect_routine(self) -> None:
self.event_when_connected = "connected" self.event_when_connected = "connected"
if self._connect_loop_wait > 0: if self._connect_loop_wait > 0:
@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
# and try (host, port) as a last resort # and try (host, port) as a last resort
self._dns_answers = None self._dns_answers = None
ssl_context: Optional[ssl.SSLContext]
if self.use_ssl: if self.use_ssl:
ssl_context = self.get_ssl_context() ssl_context = self.get_ssl_context()
else: else:
@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop, loop=self.loop,
) )
def process(self, *, forever=True, timeout=None): def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the """Process all the available XMPP events (receiving or sending data on the
socket(s), calling various registered callbacks, calling expired socket(s), calling various registered callbacks, calling expired
timers, handling signal events, etc). If timeout is None, this timers, handling signal events, etc). If timeout is None, this
@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
else: else:
self.loop.run_until_complete(self.disconnected) self.loop.run_until_complete(self.disconnected)
else: else:
tasks = [asyncio.sleep(timeout, loop=self.loop)] tasks: List[Future[bool]] = [asyncio.sleep(timeout, loop=self.loop)]
if not forever: if not forever:
tasks.append(self.disconnected) tasks.append(self.disconnected)
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
def init_parser(self): def init_parser(self) -> None:
"""init the XML parser. The parser must always be reset for each new """init the XML parser. The parser must always be reset for each new
connexion connexion
""" """
@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
self.xml_root = None self.xml_root = None
self.parser = ET.XMLPullParser(("start", "end")) self.parser = ET.XMLPullParser(("start", "end"))
def connection_made(self, transport): def connection_made(self, transport: BaseTransport) -> None:
"""Called when the TCP connection has been established with the server """Called when the TCP connection has been established with the server
""" """
self.event(self.event_when_connected) self.event(self.event_when_connected)
self.transport = transport self.transport = cast(Transport, transport)
if self.transport is None:
raise ValueError("Transport cannot be none")
self.socket = self.transport.get_extra_info( self.socket = self.transport.get_extra_info(
"ssl_object", "ssl_object",
default=self.transport.get_extra_info("socket") default=self.transport.get_extra_info("socket")
@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
self.send_raw(self.stream_header) self.send_raw(self.stream_header)
self._dns_answers = None self._dns_answers = None
def data_received(self, data): def data_received(self, data: bytes) -> None:
"""Called when incoming data is received on the socket. """Called when incoming data is received on the socket.
We feed that data to the parser and the see if this produced any XML We feed that data to the parser and the see if this produced any XML
@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
self.send(error) self.send(error)
self.disconnect() self.disconnect()
def is_connecting(self): def is_connecting(self) -> bool:
return self._current_connection_attempt is not None return self._current_connection_attempt is not None
def is_connected(self): def is_connected(self) -> bool:
return self.transport is not None return self.transport is not None
def eof_received(self): def eof_received(self) -> None:
"""When the TCP connection is properly closed by the remote end """When the TCP connection is properly closed by the remote end
""" """
self.event("eof_received") self.event("eof_received")
def connection_lost(self, exception): def connection_lost(self, exception: Optional[BaseException]) -> None:
"""On any kind of disconnection, initiated by us or not. This signals the """On any kind of disconnection, initiated by us or not. This signals the
closure of the TCP connection closure of the TCP connection
""" """
@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
self._reset_sendq() self._reset_sendq()
self.event('session_end') self.event('session_end')
self._set_disconnected_future() self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception and exception.strerror) self.event("disconnected", self.disconnect_reason or exception)
def cancel_connection_attempt(self): def cancel_connection_attempt(self) -> None:
""" """
Immediately cancel the current create_connection() Future. Immediately cancel the current create_connection() Future.
This is useful when a client using slixmpp tries to connect This is useful when a client using slixmpp tries to connect
@ -506,7 +626,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt.cancel() self._current_connection_attempt.cancel()
self._current_connection_attempt = None self._current_connection_attempt = None
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future: def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future[None]:
"""Close the XML stream and wait for an acknowldgement from the server for """Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has at most `wait` seconds. After the given number of seconds has
passed without a response from the server, or when the server passed without a response from the server, or when the server
@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
# `disconnect(wait=True)` for ages. This doesn't mean anything to the # `disconnect(wait=True)` for ages. This doesn't mean anything to the
# schedule call below. It would fortunately be converted to `1` later # schedule call below. It would fortunately be converted to `1` later
# down the call chain. Praise the implicit casts lord. # down the call chain. Praise the implicit casts lord.
if wait == True: if wait is True:
wait = 2.0 wait = 2.0
if self.transport: if self.transport:
@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
else: else:
self._set_disconnected_future() self._set_disconnected_future()
self.event("disconnected", reason) self.event("disconnected", reason)
future = Future() future: Future[None] = Future()
future.set_result(None) future.set_result(None)
return future return future
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
"""Wait until the send queue is empty before disconnecting""" """Wait until the send queue is empty before disconnecting"""
try: try:
await asyncio.wait_for( await asyncio.wait_for(
@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = reason self.disconnect_reason = reason
await self._end_stream_wait(wait) await self._end_stream_wait(wait)
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None): async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
""" """
Run abort() if we do not received the disconnected event Run abort() if we do not received the disconnected event
after a waiting time. after a waiting time.
@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
# that means the disconnect has already been handled # that means the disconnect has already been handled
pass pass
def abort(self): def abort(self) -> None:
""" """
Forcibly close the connection Forcibly close the connection
""" """
@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
self.transport.abort() self.transport.abort()
self.event("killed") self.event("killed")
def reconnect(self, wait=2.0, reason="Reconnecting"): def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
"""Calls disconnect(), and once we are disconnected (after the timeout, or """Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect() when the server acknowledgement is received), call connect()
""" """
log.debug("reconnecting...") log.debug("reconnecting...")
async def handler(event): async def handler(event: Any) -> None:
# We yield here to allow synchronous handlers to work first # We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0, loop=self.loop)
self.connect() self.connect()
self.add_event_handler('disconnected', handler, disposable=True) self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason) self.disconnect(wait, reason)
def configure_socket(self): def configure_socket(self) -> None:
"""Set timeout and other options for self.socket. """Set timeout and other options for self.socket.
Meant to be overridden. Meant to be overridden.
""" """
pass pass
def configure_dns(self, resolver, domain=None, port=None): def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
""" """
Configure and set options for a :class:`~dns.resolver.Resolver` Configure and set options for a :class:`~dns.resolver.Resolver`
instance, and other DNS related tasks. For example, you instance, and other DNS related tasks. For example, you
@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
pass pass
def get_ssl_context(self): def get_ssl_context(self) -> ssl.SSLContext:
""" """
Get SSL context. Get SSL context.
""" """
@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context return self.ssl_context
async def start_tls(self): async def start_tls(self) -> bool:
"""Perform handshakes for TLS. """Perform handshakes for TLS.
If the handshake is successful, the XML stream will need If the handshake is successful, the XML stream will need
to be restarted. to be restarted.
""" """
if self.transport is None:
raise ValueError("Transport should not be None")
self.event_when_connected = "tls_success" self.event_when_connected = "tls_success"
ssl_context = self.get_ssl_context() ssl_context = self.get_ssl_context()
try: try:
@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol):
self.connection_made(transp) self.connection_made(transp)
return True return True
def _start_keepalive(self, event): def _start_keepalive(self, event: Any) -> None:
"""Begin sending whitespace periodically to keep the connection alive. """Begin sending whitespace periodically to keep the connection alive.
May be disabled by setting:: May be disabled by setting::
@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol):
args=(' ',), args=(' ',),
repeat=True) repeat=True)
def _remove_schedules(self, event): def _remove_schedules(self, event: Any) -> None:
"""Remove some schedules that become pointless when disconnected""" """Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive') self.cancel_schedule('Whitespace Keepalive')
def start_stream_handler(self, xml): def start_stream_handler(self, xml: ET.Element) -> None:
"""Perform any initialization actions, such as handshakes, """Perform any initialization actions, such as handshakes,
once the stream header has been sent. once the stream header has been sent.
@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
pass pass
def register_stanza(self, stanza_class): def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Add a stanza object class as a known root stanza. """Add a stanza object class as a known root stanza.
A root stanza is one that appears as a direct child of the stream's A root stanza is one that appears as a direct child of the stream's
@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
self.__root_stanza.append(stanza_class) self.__root_stanza.append(stanza_class)
def remove_stanza(self, stanza_class): def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Remove a stanza from being a known root stanza. """Remove a stanza from being a known root stanza.
A root stanza is one that appears as a direct child of the stream's A root stanza is one that appears as a direct child of the stream's
@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
self.__root_stanza.remove(stanza_class) self.__root_stanza.remove(stanza_class)
def add_filter(self, mode, handler, order=None): def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
"""Add a filter for incoming or outgoing stanzas. """Add a filter for incoming or outgoing stanzas.
These filters are applied before incoming stanzas are These filters are applied before incoming stanzas are
@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol):
else: else:
self.__filters[mode].append(handler) self.__filters[mode].append(handler)
def del_filter(self, mode, handler): def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
"""Remove an incoming or outgoing filter.""" """Remove an incoming or outgoing filter."""
self.__filters[mode].remove(handler) self.__filters[mode].remove(handler)
def register_handler(self, handler, before=None, after=None): def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
"""Add a stream event handler that will be executed when a matching """Add a stream event handler that will be executed when a matching
stanza is received. stanza is received.
@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__handlers.append(handler) self.__handlers.append(handler)
handler.stream = weakref.ref(self) handler.stream = weakref.ref(self)
def remove_handler(self, name): def remove_handler(self, name: str) -> bool:
"""Remove any stream event handlers with the given name. """Remove any stream event handlers with the given name.
:param name: The name of the handler. :param name: The name of the handler.
@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol):
try: try:
return next(self._dns_answers) return next(self._dns_answers)
except StopIteration: except StopIteration:
return return None
def add_event_handler(self, name, pointer, disposable=False): def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
"""Add a custom event handler that will be executed whenever """Add a custom event handler that will be executed whenever
its event is manually triggered. its event is manually triggered.
@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__event_handlers[name] = [] self.__event_handlers[name] = []
self.__event_handlers[name].append((pointer, disposable)) self.__event_handlers[name].append((pointer, disposable))
def del_event_handler(self, name, pointer): def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None:
"""Remove a function as a handler for an event. """Remove a function as a handler for an event.
:param name: The name of the event. :param name: The name of the event.
@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol):
# Need to keep handlers that do not use # Need to keep handlers that do not use
# the given function pointer # the given function pointer
def filter_pointers(handler): def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool:
return handler[0] != pointer return handler[0] != pointer
self.__event_handlers[name] = list(filter( self.__event_handlers[name] = list(filter(
filter_pointers, filter_pointers,
self.__event_handlers[name])) self.__event_handlers[name]))
def event_handled(self, name): def event_handled(self, name: str) -> int:
"""Returns the number of registered handlers for an event. """Returns the number of registered handlers for an event.
:param name: The name of the event to check. :param name: The name of the event to check.
""" """
return len(self.__event_handlers.get(name, [])) return len(self.__event_handlers.get(name, []))
async def event_async(self, name: str, data: Any = {}): async def event_async(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event, but await coroutines immediately. """Manually trigger a custom event, but await coroutines immediately.
This event generator should only be called in situations when This event generator should only be called in situations when
@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol):
except Exception as e: except Exception as e:
self.exception(e) self.exception(e)
def event(self, name: str, data: Any = {}): def event(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event. """Manually trigger a custom event.
Coroutine handlers are wrapped into a future and sent into the Coroutine handlers are wrapped into a future and sent into the
event loop for their execution, and not awaited. event loop for their execution, and not awaited.
@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol):
# If the callback is a coroutine, schedule it instead of # If the callback is a coroutine, schedule it instead of
# running it directly # running it directly
if iscoroutinefunction(handler_callback): if iscoroutinefunction(handler_callback):
async def handler_callback_routine(cb): async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None:
try: try:
await cb(data) await cb(data)
except Exception as e: except Exception as e:
@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol):
except ValueError: except ValueError:
pass pass
def schedule(self, name, seconds, callback, args=tuple(), def schedule(self, name: str, seconds: int, callback: Callable[..., None],
kwargs={}, repeat=False): args: Tuple[Any, ...] = tuple(),
kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None:
"""Schedule a callback function to execute after a given delay. """Schedule a callback function to execute after a given delay.
:param name: A unique name for the scheduled callback. :param name: A unique name for the scheduled callback.
@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol):
# canceling scheduled_events[name] # canceling scheduled_events[name]
self.scheduled_events[name] = handle self.scheduled_events[name] = handle
def cancel_schedule(self, name): def cancel_schedule(self, name: str) -> None:
try: try:
handle = self.scheduled_events.pop(name) handle = self.scheduled_events.pop(name)
handle.cancel() handle.cancel()
except KeyError: except KeyError:
log.debug("Tried to cancel unscheduled event: %s" % (name,)) log.debug("Tried to cancel unscheduled event: %s" % (name,))
def _safe_cb_run(self, name, cb): def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None:
log.debug('Scheduled event: %s', name) log.debug('Scheduled event: %s', name)
try: try:
cb() cb()
except Exception as e: except Exception as e:
self.exception(e) self.exception(e)
def _execute_and_reschedule(self, name, cb, seconds): def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None:
"""Simple method that calls the given callback, and then schedule itself to """Simple method that calls the given callback, and then schedule itself to
be called after the given number of seconds. be called after the given number of seconds.
""" """
@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol):
name, cb, seconds) name, cb, seconds)
self.scheduled_events[name] = handle self.scheduled_events[name] = handle
def _execute_and_unschedule(self, name, cb): def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None:
""" """
Execute the callback and remove the handler for it. Execute the callback and remove the handler for it.
""" """
@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol):
if name in self.scheduled_events: if name in self.scheduled_events:
del self.scheduled_events[name] del self.scheduled_events[name]
def incoming_filter(self, xml): def incoming_filter(self, xml: ET.Element) -> ET.Element:
"""Filter incoming XML objects before they are processed. """Filter incoming XML objects before they are processed.
Possible uses include remapping namespaces, or correcting elements Possible uses include remapping namespaces, or correcting elements
@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
return xml return xml
def _reset_sendq(self): def _reset_sendq(self) -> None:
"""Clear sending tasks on session end""" """Clear sending tasks on session end"""
# Cancel all pending slow send tasks # Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks)) log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
@ -1042,8 +1165,8 @@ class XMLStream(asyncio.BaseProtocol):
async def _continue_slow_send( async def _continue_slow_send(
self, self,
task: asyncio.Task, task: asyncio.Task[Optional[StanzaBase]],
already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]] already_used: Set[Filter]
) -> None: ) -> None:
""" """
Used when an item in the send queue has taken too long to process. Used when an item in the send queue has taken too long to process.
@ -1062,12 +1185,14 @@ class XMLStream(asyncio.BaseProtocol):
if iscoroutinefunction(filter): if iscoroutinefunction(filter):
data = await filter(data) data = await filter(data)
else: else:
filter = cast(SyncFilter, filter)
data = filter(data) data = filter(data)
if data is None: if data is None:
return return
if isinstance(data, ElementBase): if isinstance(data, StanzaBase):
for filter in self.__filters['out_sync']: for filter in self.__filters['out_sync']:
filter = cast(SyncFilter, filter)
data = filter(data) data = filter(data)
if data is None: if data is None:
return return
@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol):
else: else:
self.send_raw(data) self.send_raw(data)
async def run_filters(self): async def run_filters(self) -> NoReturn:
""" """
Background loop that processes stanzas to send. Background loop that processes stanzas to send.
""" """
while True: while True:
data: Optional[Union[StanzaBase, str]]
(data, use_filters) = await self.waiting_queue.get() (data, use_filters) = await self.waiting_queue.get()
try: try:
if isinstance(data, ElementBase): if isinstance(data, StanzaBase):
if use_filters: if use_filters:
already_run_filters = set() already_run_filters = set()
for filter in self.__filters['out']: for filter in self.__filters['out']:
already_run_filters.add(filter) already_run_filters.add(filter)
if iscoroutinefunction(filter): if iscoroutinefunction(filter):
filter = cast(AsyncFilter, filter)
task = asyncio.create_task(filter(data)) task = asyncio.create_task(filter(data))
completed, pending = await wait( completed, pending = await wait(
{task}, {task},
@ -1108,19 +1235,24 @@ class XMLStream(asyncio.BaseProtocol):
"Slow coroutine, rescheduling filters" "Slow coroutine, rescheduling filters"
) )
data = task.result() data = task.result()
else: elif isinstance(data, StanzaBase):
filter = cast(SyncFilter, filter)
data = filter(data) data = filter(data)
if data is None: if data is None:
raise ContinueQueue('Empty stanza') raise ContinueQueue('Empty stanza')
if isinstance(data, ElementBase): if isinstance(data, StanzaBase):
if use_filters: if use_filters:
for filter in self.__filters['out_sync']: for filter in self.__filters['out_sync']:
filter = cast(SyncFilter, filter)
data = filter(data) data = filter(data)
if data is None: if data is None:
raise ContinueQueue('Empty stanza') raise ContinueQueue('Empty stanza')
if isinstance(data, StanzaBase):
str_data = tostring(data.xml, xmlns=self.default_ns, str_data = tostring(data.xml, xmlns=self.default_ns,
stream=self, top_level=True) stream=self, top_level=True)
else:
str_data = data
self.send_raw(str_data) self.send_raw(str_data)
else: else:
self.send_raw(data) self.send_raw(data)
@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol):
log.error('Exception raised in send queue:', exc_info=True) log.error('Exception raised in send queue:', exc_info=True)
self.waiting_queue.task_done() self.waiting_queue.task_done()
def send(self, data, use_filters=True): def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None:
"""A wrapper for :meth:`send_raw()` for sending stanza objects. """A wrapper for :meth:`send_raw()` for sending stanza objects.
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` :param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to send on the stream. stanza to send on the stream.
:param bool use_filters: Indicates if outgoing filters should be :param bool use_filters: Indicates if outgoing filters should be
applied to the given stanza data. Disabling applied to the given stanza data. Disabling
@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol):
return return
self.waiting_queue.put_nowait((data, use_filters)) self.waiting_queue.put_nowait((data, use_filters))
def send_xml(self, data): def send_xml(self, data: ET.Element) -> None:
"""Send an XML object on the stream """Send an XML object on the stream
:param data: The :class:`~xml.etree.ElementTree.Element` XML object :param data: The :class:`~xml.etree.ElementTree.Element` XML object
to send on the stream. to send on the stream.
""" """
return self.send(tostring(data)) self.send(tostring(data))
def send_raw(self, data): def send_raw(self, data: Union[str, bytes]) -> None:
"""Send raw data across the stream. """Send raw data across the stream.
:param string data: Any bytes or utf-8 string value. :param string data: Any bytes or utf-8 string value.
@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol):
data = data.encode('utf-8') data = data.encode('utf-8')
self.transport.write(data) self.transport.write(data)
def _build_stanza(self, xml, default_ns=None): def _build_stanza(self, xml: ET.Element,
default_ns: Optional[str] = None) -> StanzaBase:
"""Create a stanza object from a given XML object. """Create a stanza object from a given XML object.
If a specialized stanza type is not found for the XML, then If a specialized stanza type is not found for the XML, then
@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol):
stanza['lang'] = self.peer_default_lang stanza['lang'] = self.peer_default_lang
return stanza return stanza
def _spawn_event(self, xml): def _spawn_event(self, xml: ET.Element) -> None:
""" """
Analyze incoming XML stanzas and convert them into stanza Analyze incoming XML stanzas and convert them into stanza
objects if applicable and queue stream events to be processed objects if applicable and queue stream events to be processed
@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol):
# Convert the raw XML object into a stanza object. If no registered # Convert the raw XML object into a stanza object. If no registered
# stanza type applies, a generic StanzaBase stanza will be used. # stanza type applies, a generic StanzaBase stanza will be used.
stanza = self._build_stanza(xml) stanza: Optional[StanzaBase] = self._build_stanza(xml)
for filter in self.__filters['in']: for filter in self.__filters['in']:
if stanza is not None: if stanza is not None:
filter = cast(SyncFilter, filter)
stanza = filter(stanza) stanza = filter(stanza)
if stanza is None: if stanza is None:
return return
@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol):
if not handled: if not handled:
stanza.unhandled() stanza.unhandled()
def exception(self, exception): def exception(self, exception: Exception) -> None:
"""Process an unknown exception. """Process an unknown exception.
Meant to be overridden. Meant to be overridden.
@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol):
""" """
pass pass
async def wait_until(self, event: str, timeout=30) -> Any: async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any:
"""Utility method to wake on the next firing of an event. """Utility method to wake on the next firing of an event.
(Registers a disposable handler on it) (Registers a disposable handler on it)
@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout :param int timeout: Timeout
:raises: :class:`asyncio.TimeoutError` when the timeout is reached :raises: :class:`asyncio.TimeoutError` when the timeout is reached
""" """
fut = asyncio.Future() fut: Future[Any] = asyncio.Future()
def result_handler(event_data): def result_handler(event_data: Any) -> None:
if not fut.done(): if not fut.done():
fut.set_result(event_data) fut.set_result(event_data)
else: else:
@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol):
return await asyncio.wait_for(fut, timeout) return await asyncio.wait_for(fut, timeout)
@contextmanager @contextmanager
def event_handler(self, event: str, handler: Callable): def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]:
""" """
Context manager that adds then removes an event handler. Context manager that adds then removes an event handler.
""" """
self.add_event_handler(event, handler) self.add_event_handler(event, handler)
try: try:
yield yield
except Exception as exc: except Exception:
raise raise
finally: finally:
self.del_event_handler(event, handler) self.del_event_handler(event, handler)
def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future: def wrap(self, coroutine: Coroutine[None, None, T]) -> Future[T]:
"""Make a Future out of a coroutine with the current loop. """Make a Future out of a coroutine with the current loop.
:param coroutine: The coroutine to wrap. :param coroutine: The coroutine to wrap.