xmlstream: add more types

This commit is contained in:
mathieui 2021-07-03 11:07:01 +02:00
parent c07476e7de
commit db48c8f4da

View File

@ -9,17 +9,24 @@
# :license: MIT, see LICENSE for more details
from typing import (
Any,
Dict,
Awaitable,
Generator,
Coroutine,
Callable,
Iterable,
Iterator,
List,
Optional,
Set,
Union,
Tuple,
TypeVar,
NoReturn,
Type,
cast,
)
import asyncio
import functools
import logging
import socket as Socket
@ -27,30 +34,66 @@ import ssl
import weakref
import uuid
import asyncio
from asyncio import iscoroutinefunction, wait, Future
from contextlib import contextmanager
import xml.etree.ElementTree as ET
from asyncio import (
AbstractEventLoop,
BaseTransport,
Future,
Task,
TimerHandle,
Transport,
iscoroutinefunction,
wait,
)
from slixmpp.xmlstream import tostring
from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
from slixmpp.xmlstream.resolver import resolve, default_resolver
from slixmpp.xmlstream.handler.base import BaseHandler
T = TypeVar('T')
#: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30
log = logging.getLogger(__name__)
class ContinueQueue(Exception):
"""
Exception raised in the send queue to "continue" from within an inner loop
"""
class NotConnectedError(Exception):
"""
Raised when we try to send something over the wire but we are not
connected.
"""
_T = TypeVar('_T', str, ElementBase, StanzaBase)
SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
Filter = Union[
SyncFilter,
AsyncFilter,
]
_FiltersDict = Dict[str, List[Filter]]
Handler = Callable[[Any], Union[
Any,
Coroutine[Any, Any, Any]
]]
class XMLStream(asyncio.BaseProtocol):
"""
An XML stream connection manager and event dispatcher.
@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
def __init__(self, host='', port=0):
# The asyncio.Transport object provided by the connection_made()
# callback when we are connected
self.transport = None
transport: Optional[Transport]
# The socket that is used internally by the transport object
self.socket = None
socket: Optional[ssl.SSLSocket]
# The backoff of the connect routine (increases exponentially
# after each failure)
_connect_loop_wait: float
parser: Optional[ET.XMLPullParser]
xml_depth: int
xml_root: Optional[ET.Element]
force_starttls: Optional[bool]
disable_starttls: Optional[bool]
waiting_queue: asyncio.Queue[Tuple[Union[StanzaBase, str], bool]]
# A dict of {name: handle}
scheduled_events: Dict[str, TimerHandle]
ssl_context: ssl.SSLContext
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
event_when_connected: str
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
ciphers: Optional[str]
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
ca_certs: Optional[str]
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
certfile: Optional[str]
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
keyfile: Optional[str]
# The asyncio event loop
_loop: Optional[AbstractEventLoop]
#: The default port to return when querying DNS records.
default_port: int
#: The domain to try when querying DNS records.
default_domain: str
#: The expected name of the server, for validation.
_expected_server_name: str
_service_name: str
#: The desired, or actual, address of the connected server.
address: Tuple[str, int]
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
use_ssl: bool
#: If set to ``True``, attempt to use IPv6.
use_ipv6: bool
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
use_aiodns: bool
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
use_cdata: bool
#: The default namespace of the stream content, not of the
#: stream wrapper it
default_ns: str
default_lang: Optional[str]
peer_default_lang: Optional[str]
#: The namespace of the enveloping stream element.
stream_ns: str
#: The default opening tag for the stream element.
stream_header: str
#: The default closing tag for the stream element.
stream_footer: str
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
whitespace_keepalive: bool
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
whitespace_keepalive_interval: int
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
end_session_on_disconnect: bool
#: A mapping of XML namespaces to well-known prefixes.
namespace_map: dict
__root_stanza: List[Type[StanzaBase]]
__handlers: List[BaseHandler]
__event_handlers: Dict[str, List[Tuple[Handler, bool]]]
__filters: _FiltersDict
# Current connection attempt (Future)
_current_connection_attempt: Optional[Future[None]]
#: A list of DNS results that have not yet been tried.
_dns_answers: Optional[Iterator[Tuple[str, str, int]]]
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
dns_service: Optional[str]
#: The reason why we are disconnecting from the server
disconnect_reason: Optional[str]
#: An asyncio Future being done when the stream is disconnected.
disconnected: Future[bool]
# If the session has been started or not
_session_started: bool
# If we want to bypass the send() check (e.g. unit tests)
_always_send_everything: bool
_run_out_filters: Optional[Future]
__slow_tasks: List[Task]
__queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
def __init__(self, host: str = '', port: int = 0):
self.transport = None
self.socket = None
self._connect_loop_wait = 0
self.parser = None
@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
self.event_when_connected = "connected"
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
self.ciphers = None
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
self.ca_certs = None
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
self.certfile = None
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
self.keyfile = None
self._der_cert = None
# The asyncio event loop
self._loop = None
#: The default port to return when querying DNS records.
self.default_port = int(port)
#: The domain to try when querying DNS records.
self.default_domain = ''
#: The expected name of the server, for validation.
self._expected_server_name = ''
self._service_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port))
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
self.use_ssl = False
#: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
self.use_aiodns = True
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
self.use_cdata = False
#: The default namespace of the stream content, not of the
#: stream wrapper itself.
self.default_ns = ''
self.default_lang = None
self.peer_default_lang = None
#: The namespace of the enveloping stream element.
self.stream_ns = ''
#: The default opening tag for the stream element.
self.stream_header = "<stream>"
#: The default closing tag for the stream element.
self.stream_footer = "</stream>"
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
self.whitespace_keepalive = True
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
self.whitespace_keepalive_interval = 300
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
self.end_session_on_disconnect = True
#: A mapping of XML namespaces to well-known prefixes.
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
self.__root_stanza = []
self.__handlers = []
self.__event_handlers = {}
self.__filters = {'in': [], 'out': [], 'out_sync': []}
self.__filters = {
'in': [], 'out': [], 'out_sync': []
}
# Current connection attempt (Future)
self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried.
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
self._dns_answers = None
self.dns_service = None
#: The reason why we are disconnecting from the server
self.disconnect_reason = None
#: An asyncio Future being done when the stream is disconnected.
self.disconnected: Future = Future()
# If the session has been started or not
self.disconnected = Future()
self._session_started = False
# If we want to bypass the send() check (e.g. unit tests)
self._always_send_everything = False
self.add_event_handler('disconnected', self._remove_schedules)
@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('session_start', self._set_session_start)
self.add_event_handler('session_resumed', self._set_session_start)
self._run_out_filters: Optional[Future] = None
self.__slow_tasks: List[Future] = []
self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = []
self._run_out_filters = None
self.__slow_tasks = []
self.__queued_stanzas = []
@property
def loop(self):
def loop(self) -> AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_event_loop()
return self._loop
@loop.setter
def loop(self, value):
def loop(self, value: AbstractEventLoop) -> None:
self._loop = value
def new_id(self):
def new_id(self) -> str:
"""Generate and return a new stream ID in hexadecimal form.
Many stanzas, handlers, or matchers may require unique
@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return uuid.uuid4().hex
def _set_session_start(self, event):
def _set_session_start(self, event: Any) -> None:
"""
On session start, queue all pending stanzas to be sent.
"""
@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
self.waiting_queue.put_nowait(stanza)
self.__queued_stanzas = []
def _set_disconnected(self, event):
def _set_disconnected(self, event: Any) -> None:
self._session_started = False
def _set_disconnected_future(self):
def _set_disconnected_future(self) -> None:
"""Set the self.disconnected future on disconnect"""
if not self.disconnected.done():
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
def connect(self, host='', port=0, use_ssl=False,
force_starttls=True, disable_starttls=False):
def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
"""Create a new socket and connect to the server.
:param host: The name of the desired server for the connection.
@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
async def _connect_routine(self):
async def _connect_routine(self) -> None:
self.event_when_connected = "connected"
if self._connect_loop_wait > 0:
@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
# and try (host, port) as a last resort
self._dns_answers = None
ssl_context: Optional[ssl.SSLContext]
if self.use_ssl:
ssl_context = self.get_ssl_context()
else:
@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
def process(self, *, forever=True, timeout=None):
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the
socket(s), calling various registered callbacks, calling expired
timers, handling signal events, etc). If timeout is None, this
@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.loop.run_until_complete(self.disconnected)
else:
tasks = [asyncio.sleep(timeout, loop=self.loop)]
tasks: List[Future[bool]] = [asyncio.sleep(timeout, loop=self.loop)]
if not forever:
tasks.append(self.disconnected)
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
def init_parser(self):
def init_parser(self) -> None:
"""init the XML parser. The parser must always be reset for each new
connexion
"""
@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
self.xml_root = None
self.parser = ET.XMLPullParser(("start", "end"))
def connection_made(self, transport):
def connection_made(self, transport: BaseTransport) -> None:
"""Called when the TCP connection has been established with the server
"""
self.event(self.event_when_connected)
self.transport = transport
self.transport = cast(Transport, transport)
if self.transport is None:
raise ValueError("Transport cannot be none")
self.socket = self.transport.get_extra_info(
"ssl_object",
default=self.transport.get_extra_info("socket")
@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
self.send_raw(self.stream_header)
self._dns_answers = None
def data_received(self, data):
def data_received(self, data: bytes) -> None:
"""Called when incoming data is received on the socket.
We feed that data to the parser and the see if this produced any XML
@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
self.send(error)
self.disconnect()
def is_connecting(self):
def is_connecting(self) -> bool:
return self._current_connection_attempt is not None
def is_connected(self):
def is_connected(self) -> bool:
return self.transport is not None
def eof_received(self):
def eof_received(self) -> None:
"""When the TCP connection is properly closed by the remote end
"""
self.event("eof_received")
def connection_lost(self, exception):
def connection_lost(self, exception: Optional[BaseException]) -> None:
"""On any kind of disconnection, initiated by us or not. This signals the
closure of the TCP connection
"""
@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
self._reset_sendq()
self.event('session_end')
self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
self.event("disconnected", self.disconnect_reason or exception)
def cancel_connection_attempt(self):
def cancel_connection_attempt(self) -> None:
"""
Immediately cancel the current create_connection() Future.
This is useful when a client using slixmpp tries to connect
@ -506,7 +626,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt.cancel()
self._current_connection_attempt = None
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future[None]:
"""Close the XML stream and wait for an acknowldgement from the server for
at most `wait` seconds. After the given number of seconds has
passed without a response from the server, or when the server
@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
# schedule call below. It would fortunately be converted to `1` later
# down the call chain. Praise the implicit casts lord.
if wait == True:
if wait is True:
wait = 2.0
if self.transport:
@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self._set_disconnected_future()
self.event("disconnected", reason)
future = Future()
future: Future[None] = Future()
future.set_result(None)
return future
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
"""Wait until the send queue is empty before disconnecting"""
try:
await asyncio.wait_for(
@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = reason
await self._end_stream_wait(wait)
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
"""
Run abort() if we do not received the disconnected event
after a waiting time.
@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
# that means the disconnect has already been handled
pass
def abort(self):
def abort(self) -> None:
"""
Forcibly close the connection
"""
@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
self.transport.abort()
self.event("killed")
def reconnect(self, wait=2.0, reason="Reconnecting"):
def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
async def handler(event):
async def handler(event: Any) -> None:
# We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
def configure_socket(self):
def configure_socket(self) -> None:
"""Set timeout and other options for self.socket.
Meant to be overridden.
"""
pass
def configure_dns(self, resolver, domain=None, port=None):
def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
"""
Configure and set options for a :class:`~dns.resolver.Resolver`
instance, and other DNS related tasks. For example, you
@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
def get_ssl_context(self):
def get_ssl_context(self) -> ssl.SSLContext:
"""
Get SSL context.
"""
@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
async def start_tls(self):
async def start_tls(self) -> bool:
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
if self.transport is None:
raise ValueError("Transport should not be None")
self.event_when_connected = "tls_success"
ssl_context = self.get_ssl_context()
try:
@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol):
self.connection_made(transp)
return True
def _start_keepalive(self, event):
def _start_keepalive(self, event: Any) -> None:
"""Begin sending whitespace periodically to keep the connection alive.
May be disabled by setting::
@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol):
args=(' ',),
repeat=True)
def _remove_schedules(self, event):
def _remove_schedules(self, event: Any) -> None:
"""Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive')
def start_stream_handler(self, xml):
def start_stream_handler(self, xml: ET.Element) -> None:
"""Perform any initialization actions, such as handshakes,
once the stream header has been sent.
@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
def register_stanza(self, stanza_class):
def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Add a stanza object class as a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.append(stanza_class)
def remove_stanza(self, stanza_class):
def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Remove a stanza from being a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.remove(stanza_class)
def add_filter(self, mode, handler, order=None):
def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
"""Add a filter for incoming or outgoing stanzas.
These filters are applied before incoming stanzas are
@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.__filters[mode].append(handler)
def del_filter(self, mode, handler):
def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
"""Remove an incoming or outgoing filter."""
self.__filters[mode].remove(handler)
def register_handler(self, handler, before=None, after=None):
def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
"""Add a stream event handler that will be executed when a matching
stanza is received.
@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__handlers.append(handler)
handler.stream = weakref.ref(self)
def remove_handler(self, name):
def remove_handler(self, name: str) -> bool:
"""Remove any stream event handlers with the given name.
:param name: The name of the handler.
@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol):
try:
return next(self._dns_answers)
except StopIteration:
return
return None
def add_event_handler(self, name, pointer, disposable=False):
def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
"""Add a custom event handler that will be executed whenever
its event is manually triggered.
@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__event_handlers[name] = []
self.__event_handlers[name].append((pointer, disposable))
def del_event_handler(self, name, pointer):
def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None:
"""Remove a function as a handler for an event.
:param name: The name of the event.
@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol):
# Need to keep handlers that do not use
# the given function pointer
def filter_pointers(handler):
def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool:
return handler[0] != pointer
self.__event_handlers[name] = list(filter(
filter_pointers,
self.__event_handlers[name]))
def event_handled(self, name):
def event_handled(self, name: str) -> int:
"""Returns the number of registered handlers for an event.
:param name: The name of the event to check.
"""
return len(self.__event_handlers.get(name, []))
async def event_async(self, name: str, data: Any = {}):
async def event_async(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event, but await coroutines immediately.
This event generator should only be called in situations when
@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol):
except Exception as e:
self.exception(e)
def event(self, name: str, data: Any = {}):
def event(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event.
Coroutine handlers are wrapped into a future and sent into the
event loop for their execution, and not awaited.
@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol):
# If the callback is a coroutine, schedule it instead of
# running it directly
if iscoroutinefunction(handler_callback):
async def handler_callback_routine(cb):
async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None:
try:
await cb(data)
except Exception as e:
@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol):
except ValueError:
pass
def schedule(self, name, seconds, callback, args=tuple(),
kwargs={}, repeat=False):
def schedule(self, name: str, seconds: int, callback: Callable[..., None],
args: Tuple[Any, ...] = tuple(),
kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None:
"""Schedule a callback function to execute after a given delay.
:param name: A unique name for the scheduled callback.
@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol):
# canceling scheduled_events[name]
self.scheduled_events[name] = handle
def cancel_schedule(self, name):
def cancel_schedule(self, name: str) -> None:
try:
handle = self.scheduled_events.pop(name)
handle.cancel()
except KeyError:
log.debug("Tried to cancel unscheduled event: %s" % (name,))
def _safe_cb_run(self, name, cb):
def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None:
log.debug('Scheduled event: %s', name)
try:
cb()
except Exception as e:
self.exception(e)
def _execute_and_reschedule(self, name, cb, seconds):
def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None:
"""Simple method that calls the given callback, and then schedule itself to
be called after the given number of seconds.
"""
@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol):
name, cb, seconds)
self.scheduled_events[name] = handle
def _execute_and_unschedule(self, name, cb):
def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None:
"""
Execute the callback and remove the handler for it.
"""
@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol):
if name in self.scheduled_events:
del self.scheduled_events[name]
def incoming_filter(self, xml):
def incoming_filter(self, xml: ET.Element) -> ET.Element:
"""Filter incoming XML objects before they are processed.
Possible uses include remapping namespaces, or correcting elements
@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
def _reset_sendq(self):
def _reset_sendq(self) -> None:
"""Clear sending tasks on session end"""
# Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
@ -1042,8 +1165,8 @@ class XMLStream(asyncio.BaseProtocol):
async def _continue_slow_send(
self,
task: asyncio.Task,
already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
task: asyncio.Task[Optional[StanzaBase]],
already_used: Set[Filter]
) -> None:
"""
Used when an item in the send queue has taken too long to process.
@ -1062,12 +1185,14 @@ class XMLStream(asyncio.BaseProtocol):
if iscoroutinefunction(filter):
data = await filter(data)
else:
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
if isinstance(data, ElementBase):
if isinstance(data, StanzaBase):
for filter in self.__filters['out_sync']:
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.send_raw(data)
async def run_filters(self):
async def run_filters(self) -> NoReturn:
"""
Background loop that processes stanzas to send.
"""
while True:
data: Optional[Union[StanzaBase, str]]
(data, use_filters) = await self.waiting_queue.get()
try:
if isinstance(data, ElementBase):
if isinstance(data, StanzaBase):
if use_filters:
already_run_filters = set()
for filter in self.__filters['out']:
already_run_filters.add(filter)
if iscoroutinefunction(filter):
filter = cast(AsyncFilter, filter)
task = asyncio.create_task(filter(data))
completed, pending = await wait(
{task},
@ -1108,19 +1235,24 @@ class XMLStream(asyncio.BaseProtocol):
"Slow coroutine, rescheduling filters"
)
data = task.result()
else:
elif isinstance(data, StanzaBase):
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
if isinstance(data, ElementBase):
if isinstance(data, StanzaBase):
if use_filters:
for filter in self.__filters['out_sync']:
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
if isinstance(data, StanzaBase):
str_data = tostring(data.xml, xmlns=self.default_ns,
stream=self, top_level=True)
else:
str_data = data
self.send_raw(str_data)
else:
self.send_raw(data)
@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol):
log.error('Exception raised in send queue:', exc_info=True)
self.waiting_queue.task_done()
def send(self, data, use_filters=True):
def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None:
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to send on the stream.
:param bool use_filters: Indicates if outgoing filters should be
applied to the given stanza data. Disabling
@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol):
return
self.waiting_queue.put_nowait((data, use_filters))
def send_xml(self, data):
def send_xml(self, data: ET.Element) -> None:
"""Send an XML object on the stream
:param data: The :class:`~xml.etree.ElementTree.Element` XML object
to send on the stream.
"""
return self.send(tostring(data))
self.send(tostring(data))
def send_raw(self, data):
def send_raw(self, data: Union[str, bytes]) -> None:
"""Send raw data across the stream.
:param string data: Any bytes or utf-8 string value.
@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol):
data = data.encode('utf-8')
self.transport.write(data)
def _build_stanza(self, xml, default_ns=None):
def _build_stanza(self, xml: ET.Element,
default_ns: Optional[str] = None) -> StanzaBase:
"""Create a stanza object from a given XML object.
If a specialized stanza type is not found for the XML, then
@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol):
stanza['lang'] = self.peer_default_lang
return stanza
def _spawn_event(self, xml):
def _spawn_event(self, xml: ET.Element) -> None:
"""
Analyze incoming XML stanzas and convert them into stanza
objects if applicable and queue stream events to be processed
@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol):
# Convert the raw XML object into a stanza object. If no registered
# stanza type applies, a generic StanzaBase stanza will be used.
stanza = self._build_stanza(xml)
stanza: Optional[StanzaBase] = self._build_stanza(xml)
for filter in self.__filters['in']:
if stanza is not None:
filter = cast(SyncFilter, filter)
stanza = filter(stanza)
if stanza is None:
return
@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol):
if not handled:
stanza.unhandled()
def exception(self, exception):
def exception(self, exception: Exception) -> None:
"""Process an unknown exception.
Meant to be overridden.
@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
async def wait_until(self, event: str, timeout=30) -> Any:
async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any:
"""Utility method to wake on the next firing of an event.
(Registers a disposable handler on it)
@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout
:raises: :class:`asyncio.TimeoutError` when the timeout is reached
"""
fut = asyncio.Future()
fut: Future[Any] = asyncio.Future()
def result_handler(event_data):
def result_handler(event_data: Any) -> None:
if not fut.done():
fut.set_result(event_data)
else:
@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol):
return await asyncio.wait_for(fut, timeout)
@contextmanager
def event_handler(self, event: str, handler: Callable):
def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]:
"""
Context manager that adds then removes an event handler.
"""
self.add_event_handler(event, handler)
try:
yield
except Exception as exc:
except Exception:
raise
finally:
self.del_event_handler(event, handler)
def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future:
def wrap(self, coroutine: Coroutine[None, None, T]) -> Future[T]:
"""Make a Future out of a coroutine with the current loop.
:param coroutine: The coroutine to wrap.