xmlstream: add more types
This commit is contained in:
parent
c07476e7de
commit
db48c8f4da
@ -9,17 +9,24 @@
|
||||
# :license: MIT, see LICENSE for more details
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Awaitable,
|
||||
Generator,
|
||||
Coroutine,
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Union,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
NoReturn,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import socket as Socket
|
||||
@ -27,30 +34,66 @@ import ssl
|
||||
import weakref
|
||||
import uuid
|
||||
|
||||
import asyncio
|
||||
from asyncio import iscoroutinefunction, wait, Future
|
||||
from contextlib import contextmanager
|
||||
import xml.etree.ElementTree as ET
|
||||
from asyncio import (
|
||||
AbstractEventLoop,
|
||||
BaseTransport,
|
||||
Future,
|
||||
Task,
|
||||
TimerHandle,
|
||||
Transport,
|
||||
iscoroutinefunction,
|
||||
wait,
|
||||
)
|
||||
|
||||
from slixmpp.xmlstream import tostring
|
||||
from slixmpp.types import FilterString
|
||||
from slixmpp.xmlstream.tostring import tostring
|
||||
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
|
||||
from slixmpp.xmlstream.resolver import resolve, default_resolver
|
||||
from slixmpp.xmlstream.handler.base import BaseHandler
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
#: The time in seconds to wait before timing out waiting for response stanzas.
|
||||
RESPONSE_TIMEOUT = 30
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContinueQueue(Exception):
|
||||
"""
|
||||
Exception raised in the send queue to "continue" from within an inner loop
|
||||
"""
|
||||
|
||||
|
||||
class NotConnectedError(Exception):
|
||||
"""
|
||||
Raised when we try to send something over the wire but we are not
|
||||
connected.
|
||||
"""
|
||||
|
||||
|
||||
_T = TypeVar('_T', str, ElementBase, StanzaBase)
|
||||
|
||||
|
||||
SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
|
||||
AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
|
||||
|
||||
|
||||
Filter = Union[
|
||||
SyncFilter,
|
||||
AsyncFilter,
|
||||
]
|
||||
|
||||
_FiltersDict = Dict[str, List[Filter]]
|
||||
|
||||
Handler = Callable[[Any], Union[
|
||||
Any,
|
||||
Coroutine[Any, Any, Any]
|
||||
]]
|
||||
|
||||
|
||||
class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
An XML stream connection manager and event dispatcher.
|
||||
@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
:param int port: The port to use for the connection. Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(self, host='', port=0):
|
||||
# The asyncio.Transport object provided by the connection_made()
|
||||
# callback when we are connected
|
||||
self.transport = None
|
||||
transport: Optional[Transport]
|
||||
|
||||
# The socket that is used internally by the transport object
|
||||
self.socket = None
|
||||
socket: Optional[ssl.SSLSocket]
|
||||
|
||||
# The backoff of the connect routine (increases exponentially
|
||||
# after each failure)
|
||||
_connect_loop_wait: float
|
||||
|
||||
parser: Optional[ET.XMLPullParser]
|
||||
xml_depth: int
|
||||
xml_root: Optional[ET.Element]
|
||||
|
||||
force_starttls: Optional[bool]
|
||||
disable_starttls: Optional[bool]
|
||||
|
||||
waiting_queue: asyncio.Queue[Tuple[Union[StanzaBase, str], bool]]
|
||||
|
||||
# A dict of {name: handle}
|
||||
scheduled_events: Dict[str, TimerHandle]
|
||||
|
||||
ssl_context: ssl.SSLContext
|
||||
|
||||
# The event to trigger when the create_connection() succeeds. It can
|
||||
# be "connected" or "tls_success" depending on the step we are at.
|
||||
event_when_connected: str
|
||||
|
||||
#: The list of accepted ciphers, in OpenSSL Format.
|
||||
#: It might be useful to override it for improved security
|
||||
#: over the python defaults.
|
||||
ciphers: Optional[str]
|
||||
|
||||
#: Path to a file containing certificates for verifying the
|
||||
#: server SSL certificate. A non-``None`` value will trigger
|
||||
#: certificate checking.
|
||||
#:
|
||||
#: .. note::
|
||||
#:
|
||||
#: On Mac OS X, certificates in the system keyring will
|
||||
#: be consulted, even if they are not in the provided file.
|
||||
ca_certs: Optional[str]
|
||||
|
||||
#: Path to a file containing a client certificate to use for
|
||||
#: authenticating via SASL EXTERNAL. If set, there must also
|
||||
#: be a corresponding `:attr:keyfile` value.
|
||||
certfile: Optional[str]
|
||||
|
||||
#: Path to a file containing the private key for the selected
|
||||
#: client certificate to use for authenticating via SASL EXTERNAL.
|
||||
keyfile: Optional[str]
|
||||
|
||||
# The asyncio event loop
|
||||
_loop: Optional[AbstractEventLoop]
|
||||
|
||||
#: The default port to return when querying DNS records.
|
||||
default_port: int
|
||||
|
||||
#: The domain to try when querying DNS records.
|
||||
default_domain: str
|
||||
|
||||
#: The expected name of the server, for validation.
|
||||
_expected_server_name: str
|
||||
_service_name: str
|
||||
|
||||
#: The desired, or actual, address of the connected server.
|
||||
address: Tuple[str, int]
|
||||
|
||||
#: Enable connecting to the server directly over SSL, in
|
||||
#: particular when the service provides two ports: one for
|
||||
#: non-SSL traffic and another for SSL traffic.
|
||||
use_ssl: bool
|
||||
|
||||
#: If set to ``True``, attempt to use IPv6.
|
||||
use_ipv6: bool
|
||||
|
||||
#: If set to ``True``, allow using the ``dnspython`` DNS library
|
||||
#: if available. If set to ``False``, the builtin DNS resolver
|
||||
#: will be used, even if ``dnspython`` is installed.
|
||||
use_aiodns: bool
|
||||
|
||||
#: Use CDATA for escaping instead of XML entities. Defaults
|
||||
#: to ``False``.
|
||||
use_cdata: bool
|
||||
|
||||
#: The default namespace of the stream content, not of the
|
||||
#: stream wrapper it
|
||||
default_ns: str
|
||||
|
||||
default_lang: Optional[str]
|
||||
peer_default_lang: Optional[str]
|
||||
|
||||
#: The namespace of the enveloping stream element.
|
||||
stream_ns: str
|
||||
|
||||
#: The default opening tag for the stream element.
|
||||
stream_header: str
|
||||
|
||||
#: The default closing tag for the stream element.
|
||||
stream_footer: str
|
||||
|
||||
#: If ``True``, periodically send a whitespace character over the
|
||||
#: wire to keep the connection alive. Mainly useful for connections
|
||||
#: traversing NAT.
|
||||
whitespace_keepalive: bool
|
||||
|
||||
#: The default interval between keepalive signals when
|
||||
#: :attr:`whitespace_keepalive` is enabled.
|
||||
whitespace_keepalive_interval: int
|
||||
|
||||
#: Flag for controlling if the session can be considered ended
|
||||
#: if the connection is terminated.
|
||||
end_session_on_disconnect: bool
|
||||
|
||||
#: A mapping of XML namespaces to well-known prefixes.
|
||||
namespace_map: dict
|
||||
|
||||
__root_stanza: List[Type[StanzaBase]]
|
||||
__handlers: List[BaseHandler]
|
||||
__event_handlers: Dict[str, List[Tuple[Handler, bool]]]
|
||||
__filters: _FiltersDict
|
||||
|
||||
# Current connection attempt (Future)
|
||||
_current_connection_attempt: Optional[Future[None]]
|
||||
|
||||
#: A list of DNS results that have not yet been tried.
|
||||
_dns_answers: Optional[Iterator[Tuple[str, str, int]]]
|
||||
|
||||
#: The service name to check with DNS SRV records. For
|
||||
#: example, setting this to ``'xmpp-client'`` would query the
|
||||
#: ``_xmpp-client._tcp`` service.
|
||||
dns_service: Optional[str]
|
||||
|
||||
#: The reason why we are disconnecting from the server
|
||||
disconnect_reason: Optional[str]
|
||||
|
||||
#: An asyncio Future being done when the stream is disconnected.
|
||||
disconnected: Future[bool]
|
||||
|
||||
# If the session has been started or not
|
||||
_session_started: bool
|
||||
# If we want to bypass the send() check (e.g. unit tests)
|
||||
_always_send_everything: bool
|
||||
|
||||
_run_out_filters: Optional[Future]
|
||||
__slow_tasks: List[Task]
|
||||
__queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
|
||||
|
||||
def __init__(self, host: str = '', port: int = 0):
|
||||
self.transport = None
|
||||
self.socket = None
|
||||
self._connect_loop_wait = 0
|
||||
|
||||
self.parser = None
|
||||
@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.ssl_context.check_hostname = False
|
||||
self.ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# The event to trigger when the create_connection() succeeds. It can
|
||||
# be "connected" or "tls_success" depending on the step we are at.
|
||||
self.event_when_connected = "connected"
|
||||
|
||||
#: The list of accepted ciphers, in OpenSSL Format.
|
||||
#: It might be useful to override it for improved security
|
||||
#: over the python defaults.
|
||||
self.ciphers = None
|
||||
|
||||
#: Path to a file containing certificates for verifying the
|
||||
#: server SSL certificate. A non-``None`` value will trigger
|
||||
#: certificate checking.
|
||||
#:
|
||||
#: .. note::
|
||||
#:
|
||||
#: On Mac OS X, certificates in the system keyring will
|
||||
#: be consulted, even if they are not in the provided file.
|
||||
self.ca_certs = None
|
||||
|
||||
#: Path to a file containing a client certificate to use for
|
||||
#: authenticating via SASL EXTERNAL. If set, there must also
|
||||
#: be a corresponding `:attr:keyfile` value.
|
||||
self.certfile = None
|
||||
|
||||
#: Path to a file containing the private key for the selected
|
||||
#: client certificate to use for authenticating via SASL EXTERNAL.
|
||||
self.keyfile = None
|
||||
|
||||
self._der_cert = None
|
||||
|
||||
# The asyncio event loop
|
||||
self._loop = None
|
||||
|
||||
#: The default port to return when querying DNS records.
|
||||
self.default_port = int(port)
|
||||
|
||||
#: The domain to try when querying DNS records.
|
||||
self.default_domain = ''
|
||||
|
||||
#: The expected name of the server, for validation.
|
||||
self._expected_server_name = ''
|
||||
self._service_name = ''
|
||||
|
||||
#: The desired, or actual, address of the connected server.
|
||||
self.address = (host, int(port))
|
||||
|
||||
#: Enable connecting to the server directly over SSL, in
|
||||
#: particular when the service provides two ports: one for
|
||||
#: non-SSL traffic and another for SSL traffic.
|
||||
self.use_ssl = False
|
||||
|
||||
#: If set to ``True``, attempt to use IPv6.
|
||||
self.use_ipv6 = True
|
||||
|
||||
#: If set to ``True``, allow using the ``dnspython`` DNS library
|
||||
#: if available. If set to ``False``, the builtin DNS resolver
|
||||
#: will be used, even if ``dnspython`` is installed.
|
||||
self.use_aiodns = True
|
||||
|
||||
#: Use CDATA for escaping instead of XML entities. Defaults
|
||||
#: to ``False``.
|
||||
self.use_cdata = False
|
||||
|
||||
#: The default namespace of the stream content, not of the
|
||||
#: stream wrapper itself.
|
||||
self.default_ns = ''
|
||||
|
||||
self.default_lang = None
|
||||
self.peer_default_lang = None
|
||||
|
||||
#: The namespace of the enveloping stream element.
|
||||
self.stream_ns = ''
|
||||
|
||||
#: The default opening tag for the stream element.
|
||||
self.stream_header = "<stream>"
|
||||
|
||||
#: The default closing tag for the stream element.
|
||||
self.stream_footer = "</stream>"
|
||||
|
||||
#: If ``True``, periodically send a whitespace character over the
|
||||
#: wire to keep the connection alive. Mainly useful for connections
|
||||
#: traversing NAT.
|
||||
self.whitespace_keepalive = True
|
||||
|
||||
#: The default interval between keepalive signals when
|
||||
#: :attr:`whitespace_keepalive` is enabled.
|
||||
self.whitespace_keepalive_interval = 300
|
||||
|
||||
#: Flag for controlling if the session can be considered ended
|
||||
#: if the connection is terminated.
|
||||
self.end_session_on_disconnect = True
|
||||
|
||||
#: A mapping of XML namespaces to well-known prefixes.
|
||||
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
|
||||
|
||||
self.__root_stanza = []
|
||||
self.__handlers = []
|
||||
self.__event_handlers = {}
|
||||
self.__filters = {'in': [], 'out': [], 'out_sync': []}
|
||||
self.__filters = {
|
||||
'in': [], 'out': [], 'out_sync': []
|
||||
}
|
||||
|
||||
# Current connection attempt (Future)
|
||||
self._current_connection_attempt = None
|
||||
|
||||
#: A list of DNS results that have not yet been tried.
|
||||
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
|
||||
|
||||
#: The service name to check with DNS SRV records. For
|
||||
#: example, setting this to ``'xmpp-client'`` would query the
|
||||
#: ``_xmpp-client._tcp`` service.
|
||||
self._dns_answers = None
|
||||
self.dns_service = None
|
||||
|
||||
#: The reason why we are disconnecting from the server
|
||||
self.disconnect_reason = None
|
||||
|
||||
#: An asyncio Future being done when the stream is disconnected.
|
||||
self.disconnected: Future = Future()
|
||||
|
||||
# If the session has been started or not
|
||||
self.disconnected = Future()
|
||||
self._session_started = False
|
||||
# If we want to bypass the send() check (e.g. unit tests)
|
||||
self._always_send_everything = False
|
||||
|
||||
self.add_event_handler('disconnected', self._remove_schedules)
|
||||
@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.add_event_handler('session_start', self._set_session_start)
|
||||
self.add_event_handler('session_resumed', self._set_session_start)
|
||||
|
||||
self._run_out_filters: Optional[Future] = None
|
||||
self.__slow_tasks: List[Future] = []
|
||||
self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = []
|
||||
self._run_out_filters = None
|
||||
self.__slow_tasks = []
|
||||
self.__queued_stanzas = []
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
def loop(self) -> AbstractEventLoop:
|
||||
if self._loop is None:
|
||||
self._loop = asyncio.get_event_loop()
|
||||
return self._loop
|
||||
|
||||
@loop.setter
|
||||
def loop(self, value):
|
||||
def loop(self, value: AbstractEventLoop) -> None:
|
||||
self._loop = value
|
||||
|
||||
def new_id(self):
|
||||
def new_id(self) -> str:
|
||||
"""Generate and return a new stream ID in hexadecimal form.
|
||||
|
||||
Many stanzas, handlers, or matchers may require unique
|
||||
@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
return uuid.uuid4().hex
|
||||
|
||||
def _set_session_start(self, event):
|
||||
def _set_session_start(self, event: Any) -> None:
|
||||
"""
|
||||
On session start, queue all pending stanzas to be sent.
|
||||
"""
|
||||
@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.waiting_queue.put_nowait(stanza)
|
||||
self.__queued_stanzas = []
|
||||
|
||||
def _set_disconnected(self, event):
|
||||
def _set_disconnected(self, event: Any) -> None:
|
||||
self._session_started = False
|
||||
|
||||
def _set_disconnected_future(self):
|
||||
def _set_disconnected_future(self) -> None:
|
||||
"""Set the self.disconnected future on disconnect"""
|
||||
if not self.disconnected.done():
|
||||
self.disconnected.set_result(True)
|
||||
self.disconnected = asyncio.Future()
|
||||
|
||||
def connect(self, host='', port=0, use_ssl=False,
|
||||
force_starttls=True, disable_starttls=False):
|
||||
def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
|
||||
force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
|
||||
"""Create a new socket and connect to the server.
|
||||
|
||||
:param host: The name of the desired server for the connection.
|
||||
@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
loop=self.loop,
|
||||
)
|
||||
|
||||
async def _connect_routine(self):
|
||||
async def _connect_routine(self) -> None:
|
||||
self.event_when_connected = "connected"
|
||||
|
||||
if self._connect_loop_wait > 0:
|
||||
@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
# and try (host, port) as a last resort
|
||||
self._dns_answers = None
|
||||
|
||||
ssl_context: Optional[ssl.SSLContext]
|
||||
if self.use_ssl:
|
||||
ssl_context = self.get_ssl_context()
|
||||
else:
|
||||
@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
loop=self.loop,
|
||||
)
|
||||
|
||||
def process(self, *, forever=True, timeout=None):
|
||||
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
|
||||
"""Process all the available XMPP events (receiving or sending data on the
|
||||
socket(s), calling various registered callbacks, calling expired
|
||||
timers, handling signal events, etc). If timeout is None, this
|
||||
@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
else:
|
||||
self.loop.run_until_complete(self.disconnected)
|
||||
else:
|
||||
tasks = [asyncio.sleep(timeout, loop=self.loop)]
|
||||
tasks: List[Future[bool]] = [asyncio.sleep(timeout, loop=self.loop)]
|
||||
if not forever:
|
||||
tasks.append(self.disconnected)
|
||||
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
|
||||
|
||||
def init_parser(self):
|
||||
def init_parser(self) -> None:
|
||||
"""init the XML parser. The parser must always be reset for each new
|
||||
connexion
|
||||
"""
|
||||
@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.xml_root = None
|
||||
self.parser = ET.XMLPullParser(("start", "end"))
|
||||
|
||||
def connection_made(self, transport):
|
||||
def connection_made(self, transport: BaseTransport) -> None:
|
||||
"""Called when the TCP connection has been established with the server
|
||||
"""
|
||||
self.event(self.event_when_connected)
|
||||
self.transport = transport
|
||||
self.transport = cast(Transport, transport)
|
||||
if self.transport is None:
|
||||
raise ValueError("Transport cannot be none")
|
||||
self.socket = self.transport.get_extra_info(
|
||||
"ssl_object",
|
||||
default=self.transport.get_extra_info("socket")
|
||||
@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.send_raw(self.stream_header)
|
||||
self._dns_answers = None
|
||||
|
||||
def data_received(self, data):
|
||||
def data_received(self, data: bytes) -> None:
|
||||
"""Called when incoming data is received on the socket.
|
||||
|
||||
We feed that data to the parser and the see if this produced any XML
|
||||
@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.send(error)
|
||||
self.disconnect()
|
||||
|
||||
def is_connecting(self):
|
||||
def is_connecting(self) -> bool:
|
||||
return self._current_connection_attempt is not None
|
||||
|
||||
def is_connected(self):
|
||||
def is_connected(self) -> bool:
|
||||
return self.transport is not None
|
||||
|
||||
def eof_received(self):
|
||||
def eof_received(self) -> None:
|
||||
"""When the TCP connection is properly closed by the remote end
|
||||
"""
|
||||
self.event("eof_received")
|
||||
|
||||
def connection_lost(self, exception):
|
||||
def connection_lost(self, exception: Optional[BaseException]) -> None:
|
||||
"""On any kind of disconnection, initiated by us or not. This signals the
|
||||
closure of the TCP connection
|
||||
"""
|
||||
@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self._reset_sendq()
|
||||
self.event('session_end')
|
||||
self._set_disconnected_future()
|
||||
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
|
||||
self.event("disconnected", self.disconnect_reason or exception)
|
||||
|
||||
def cancel_connection_attempt(self):
|
||||
def cancel_connection_attempt(self) -> None:
|
||||
"""
|
||||
Immediately cancel the current create_connection() Future.
|
||||
This is useful when a client using slixmpp tries to connect
|
||||
@ -506,7 +626,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self._current_connection_attempt.cancel()
|
||||
self._current_connection_attempt = None
|
||||
|
||||
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
|
||||
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future[None]:
|
||||
"""Close the XML stream and wait for an acknowldgement from the server for
|
||||
at most `wait` seconds. After the given number of seconds has
|
||||
passed without a response from the server, or when the server
|
||||
@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
|
||||
# schedule call below. It would fortunately be converted to `1` later
|
||||
# down the call chain. Praise the implicit casts lord.
|
||||
if wait == True:
|
||||
if wait is True:
|
||||
wait = 2.0
|
||||
|
||||
if self.transport:
|
||||
@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
else:
|
||||
self._set_disconnected_future()
|
||||
self.event("disconnected", reason)
|
||||
future = Future()
|
||||
future: Future[None] = Future()
|
||||
future.set_result(None)
|
||||
return future
|
||||
|
||||
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
|
||||
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
|
||||
"""Wait until the send queue is empty before disconnecting"""
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.disconnect_reason = reason
|
||||
await self._end_stream_wait(wait)
|
||||
|
||||
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
|
||||
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run abort() if we do not received the disconnected event
|
||||
after a waiting time.
|
||||
@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
# that means the disconnect has already been handled
|
||||
pass
|
||||
|
||||
def abort(self):
|
||||
def abort(self) -> None:
|
||||
"""
|
||||
Forcibly close the connection
|
||||
"""
|
||||
@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.transport.abort()
|
||||
self.event("killed")
|
||||
|
||||
def reconnect(self, wait=2.0, reason="Reconnecting"):
|
||||
def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
|
||||
"""Calls disconnect(), and once we are disconnected (after the timeout, or
|
||||
when the server acknowledgement is received), call connect()
|
||||
"""
|
||||
log.debug("reconnecting...")
|
||||
async def handler(event):
|
||||
async def handler(event: Any) -> None:
|
||||
# We yield here to allow synchronous handlers to work first
|
||||
await asyncio.sleep(0, loop=self.loop)
|
||||
self.connect()
|
||||
self.add_event_handler('disconnected', handler, disposable=True)
|
||||
self.disconnect(wait, reason)
|
||||
|
||||
def configure_socket(self):
|
||||
def configure_socket(self) -> None:
|
||||
"""Set timeout and other options for self.socket.
|
||||
|
||||
Meant to be overridden.
|
||||
"""
|
||||
pass
|
||||
|
||||
def configure_dns(self, resolver, domain=None, port=None):
|
||||
def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
|
||||
"""
|
||||
Configure and set options for a :class:`~dns.resolver.Resolver`
|
||||
instance, and other DNS related tasks. For example, you
|
||||
@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_ssl_context(self):
|
||||
def get_ssl_context(self) -> ssl.SSLContext:
|
||||
"""
|
||||
Get SSL context.
|
||||
"""
|
||||
@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
|
||||
return self.ssl_context
|
||||
|
||||
async def start_tls(self):
|
||||
async def start_tls(self) -> bool:
|
||||
"""Perform handshakes for TLS.
|
||||
|
||||
If the handshake is successful, the XML stream will need
|
||||
to be restarted.
|
||||
"""
|
||||
if self.transport is None:
|
||||
raise ValueError("Transport should not be None")
|
||||
self.event_when_connected = "tls_success"
|
||||
ssl_context = self.get_ssl_context()
|
||||
try:
|
||||
@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.connection_made(transp)
|
||||
return True
|
||||
|
||||
def _start_keepalive(self, event):
|
||||
def _start_keepalive(self, event: Any) -> None:
|
||||
"""Begin sending whitespace periodically to keep the connection alive.
|
||||
|
||||
May be disabled by setting::
|
||||
@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
args=(' ',),
|
||||
repeat=True)
|
||||
|
||||
def _remove_schedules(self, event):
|
||||
def _remove_schedules(self, event: Any) -> None:
|
||||
"""Remove some schedules that become pointless when disconnected"""
|
||||
self.cancel_schedule('Whitespace Keepalive')
|
||||
|
||||
def start_stream_handler(self, xml):
|
||||
def start_stream_handler(self, xml: ET.Element) -> None:
|
||||
"""Perform any initialization actions, such as handshakes,
|
||||
once the stream header has been sent.
|
||||
|
||||
@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
pass
|
||||
|
||||
def register_stanza(self, stanza_class):
|
||||
def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
|
||||
"""Add a stanza object class as a known root stanza.
|
||||
|
||||
A root stanza is one that appears as a direct child of the stream's
|
||||
@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
self.__root_stanza.append(stanza_class)
|
||||
|
||||
def remove_stanza(self, stanza_class):
|
||||
def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
|
||||
"""Remove a stanza from being a known root stanza.
|
||||
|
||||
A root stanza is one that appears as a direct child of the stream's
|
||||
@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
self.__root_stanza.remove(stanza_class)
|
||||
|
||||
def add_filter(self, mode, handler, order=None):
|
||||
def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
|
||||
"""Add a filter for incoming or outgoing stanzas.
|
||||
|
||||
These filters are applied before incoming stanzas are
|
||||
@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
else:
|
||||
self.__filters[mode].append(handler)
|
||||
|
||||
def del_filter(self, mode, handler):
|
||||
def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
|
||||
"""Remove an incoming or outgoing filter."""
|
||||
self.__filters[mode].remove(handler)
|
||||
|
||||
def register_handler(self, handler, before=None, after=None):
|
||||
def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
|
||||
"""Add a stream event handler that will be executed when a matching
|
||||
stanza is received.
|
||||
|
||||
@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.__handlers.append(handler)
|
||||
handler.stream = weakref.ref(self)
|
||||
|
||||
def remove_handler(self, name):
|
||||
def remove_handler(self, name: str) -> bool:
|
||||
"""Remove any stream event handlers with the given name.
|
||||
|
||||
:param name: The name of the handler.
|
||||
@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
try:
|
||||
return next(self._dns_answers)
|
||||
except StopIteration:
|
||||
return
|
||||
return None
|
||||
|
||||
def add_event_handler(self, name, pointer, disposable=False):
|
||||
def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
|
||||
"""Add a custom event handler that will be executed whenever
|
||||
its event is manually triggered.
|
||||
|
||||
@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
self.__event_handlers[name] = []
|
||||
self.__event_handlers[name].append((pointer, disposable))
|
||||
|
||||
def del_event_handler(self, name, pointer):
|
||||
def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None:
|
||||
"""Remove a function as a handler for an event.
|
||||
|
||||
:param name: The name of the event.
|
||||
@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
|
||||
# Need to keep handlers that do not use
|
||||
# the given function pointer
|
||||
def filter_pointers(handler):
|
||||
def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool:
|
||||
return handler[0] != pointer
|
||||
|
||||
self.__event_handlers[name] = list(filter(
|
||||
filter_pointers,
|
||||
self.__event_handlers[name]))
|
||||
|
||||
def event_handled(self, name):
|
||||
def event_handled(self, name: str) -> int:
|
||||
"""Returns the number of registered handlers for an event.
|
||||
|
||||
:param name: The name of the event to check.
|
||||
"""
|
||||
return len(self.__event_handlers.get(name, []))
|
||||
|
||||
async def event_async(self, name: str, data: Any = {}):
|
||||
async def event_async(self, name: str, data: Any = {}) -> None:
|
||||
"""Manually trigger a custom event, but await coroutines immediately.
|
||||
|
||||
This event generator should only be called in situations when
|
||||
@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
except Exception as e:
|
||||
self.exception(e)
|
||||
|
||||
def event(self, name: str, data: Any = {}):
|
||||
def event(self, name: str, data: Any = {}) -> None:
|
||||
"""Manually trigger a custom event.
|
||||
Coroutine handlers are wrapped into a future and sent into the
|
||||
event loop for their execution, and not awaited.
|
||||
@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
# If the callback is a coroutine, schedule it instead of
|
||||
# running it directly
|
||||
if iscoroutinefunction(handler_callback):
|
||||
async def handler_callback_routine(cb):
|
||||
async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None:
|
||||
try:
|
||||
await cb(data)
|
||||
except Exception as e:
|
||||
@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def schedule(self, name, seconds, callback, args=tuple(),
|
||||
kwargs={}, repeat=False):
|
||||
def schedule(self, name: str, seconds: int, callback: Callable[..., None],
|
||||
args: Tuple[Any, ...] = tuple(),
|
||||
kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None:
|
||||
"""Schedule a callback function to execute after a given delay.
|
||||
|
||||
:param name: A unique name for the scheduled callback.
|
||||
@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
# canceling scheduled_events[name]
|
||||
self.scheduled_events[name] = handle
|
||||
|
||||
def cancel_schedule(self, name):
|
||||
def cancel_schedule(self, name: str) -> None:
|
||||
try:
|
||||
handle = self.scheduled_events.pop(name)
|
||||
handle.cancel()
|
||||
except KeyError:
|
||||
log.debug("Tried to cancel unscheduled event: %s" % (name,))
|
||||
|
||||
def _safe_cb_run(self, name, cb):
|
||||
def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None:
|
||||
log.debug('Scheduled event: %s', name)
|
||||
try:
|
||||
cb()
|
||||
except Exception as e:
|
||||
self.exception(e)
|
||||
|
||||
def _execute_and_reschedule(self, name, cb, seconds):
|
||||
def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None:
|
||||
"""Simple method that calls the given callback, and then schedule itself to
|
||||
be called after the given number of seconds.
|
||||
"""
|
||||
@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
name, cb, seconds)
|
||||
self.scheduled_events[name] = handle
|
||||
|
||||
def _execute_and_unschedule(self, name, cb):
|
||||
def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None:
|
||||
"""
|
||||
Execute the callback and remove the handler for it.
|
||||
"""
|
||||
@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
if name in self.scheduled_events:
|
||||
del self.scheduled_events[name]
|
||||
|
||||
def incoming_filter(self, xml):
|
||||
def incoming_filter(self, xml: ET.Element) -> ET.Element:
|
||||
"""Filter incoming XML objects before they are processed.
|
||||
|
||||
Possible uses include remapping namespaces, or correcting elements
|
||||
@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
return xml
|
||||
|
||||
def _reset_sendq(self):
|
||||
def _reset_sendq(self) -> None:
|
||||
"""Clear sending tasks on session end"""
|
||||
# Cancel all pending slow send tasks
|
||||
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
|
||||
@ -1042,8 +1165,8 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
|
||||
async def _continue_slow_send(
|
||||
self,
|
||||
task: asyncio.Task,
|
||||
already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
|
||||
task: asyncio.Task[Optional[StanzaBase]],
|
||||
already_used: Set[Filter]
|
||||
) -> None:
|
||||
"""
|
||||
Used when an item in the send queue has taken too long to process.
|
||||
@ -1062,12 +1185,14 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
if iscoroutinefunction(filter):
|
||||
data = await filter(data)
|
||||
else:
|
||||
filter = cast(SyncFilter, filter)
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
return
|
||||
|
||||
if isinstance(data, ElementBase):
|
||||
if isinstance(data, StanzaBase):
|
||||
for filter in self.__filters['out_sync']:
|
||||
filter = cast(SyncFilter, filter)
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
return
|
||||
@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
else:
|
||||
self.send_raw(data)
|
||||
|
||||
async def run_filters(self):
|
||||
async def run_filters(self) -> NoReturn:
|
||||
"""
|
||||
Background loop that processes stanzas to send.
|
||||
"""
|
||||
while True:
|
||||
data: Optional[Union[StanzaBase, str]]
|
||||
(data, use_filters) = await self.waiting_queue.get()
|
||||
try:
|
||||
if isinstance(data, ElementBase):
|
||||
if isinstance(data, StanzaBase):
|
||||
if use_filters:
|
||||
already_run_filters = set()
|
||||
for filter in self.__filters['out']:
|
||||
already_run_filters.add(filter)
|
||||
if iscoroutinefunction(filter):
|
||||
filter = cast(AsyncFilter, filter)
|
||||
task = asyncio.create_task(filter(data))
|
||||
completed, pending = await wait(
|
||||
{task},
|
||||
@ -1108,19 +1235,24 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"Slow coroutine, rescheduling filters"
|
||||
)
|
||||
data = task.result()
|
||||
else:
|
||||
elif isinstance(data, StanzaBase):
|
||||
filter = cast(SyncFilter, filter)
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
raise ContinueQueue('Empty stanza')
|
||||
|
||||
if isinstance(data, ElementBase):
|
||||
if isinstance(data, StanzaBase):
|
||||
if use_filters:
|
||||
for filter in self.__filters['out_sync']:
|
||||
filter = cast(SyncFilter, filter)
|
||||
data = filter(data)
|
||||
if data is None:
|
||||
raise ContinueQueue('Empty stanza')
|
||||
if isinstance(data, StanzaBase):
|
||||
str_data = tostring(data.xml, xmlns=self.default_ns,
|
||||
stream=self, top_level=True)
|
||||
else:
|
||||
str_data = data
|
||||
self.send_raw(str_data)
|
||||
else:
|
||||
self.send_raw(data)
|
||||
@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
log.error('Exception raised in send queue:', exc_info=True)
|
||||
self.waiting_queue.task_done()
|
||||
|
||||
def send(self, data, use_filters=True):
|
||||
def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None:
|
||||
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
|
||||
|
||||
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
|
||||
:param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
|
||||
stanza to send on the stream.
|
||||
:param bool use_filters: Indicates if outgoing filters should be
|
||||
applied to the given stanza data. Disabling
|
||||
@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
return
|
||||
self.waiting_queue.put_nowait((data, use_filters))
|
||||
|
||||
def send_xml(self, data):
|
||||
def send_xml(self, data: ET.Element) -> None:
|
||||
"""Send an XML object on the stream
|
||||
|
||||
:param data: The :class:`~xml.etree.ElementTree.Element` XML object
|
||||
to send on the stream.
|
||||
"""
|
||||
return self.send(tostring(data))
|
||||
self.send(tostring(data))
|
||||
|
||||
def send_raw(self, data):
|
||||
def send_raw(self, data: Union[str, bytes]) -> None:
|
||||
"""Send raw data across the stream.
|
||||
|
||||
:param string data: Any bytes or utf-8 string value.
|
||||
@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
data = data.encode('utf-8')
|
||||
self.transport.write(data)
|
||||
|
||||
def _build_stanza(self, xml, default_ns=None):
|
||||
def _build_stanza(self, xml: ET.Element,
|
||||
default_ns: Optional[str] = None) -> StanzaBase:
|
||||
"""Create a stanza object from a given XML object.
|
||||
|
||||
If a specialized stanza type is not found for the XML, then
|
||||
@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
stanza['lang'] = self.peer_default_lang
|
||||
return stanza
|
||||
|
||||
def _spawn_event(self, xml):
|
||||
def _spawn_event(self, xml: ET.Element) -> None:
|
||||
"""
|
||||
Analyze incoming XML stanzas and convert them into stanza
|
||||
objects if applicable and queue stream events to be processed
|
||||
@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
|
||||
# Convert the raw XML object into a stanza object. If no registered
|
||||
# stanza type applies, a generic StanzaBase stanza will be used.
|
||||
stanza = self._build_stanza(xml)
|
||||
stanza: Optional[StanzaBase] = self._build_stanza(xml)
|
||||
for filter in self.__filters['in']:
|
||||
if stanza is not None:
|
||||
filter = cast(SyncFilter, filter)
|
||||
stanza = filter(stanza)
|
||||
if stanza is None:
|
||||
return
|
||||
@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
if not handled:
|
||||
stanza.unhandled()
|
||||
|
||||
def exception(self, exception):
|
||||
def exception(self, exception: Exception) -> None:
|
||||
"""Process an unknown exception.
|
||||
|
||||
Meant to be overridden.
|
||||
@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def wait_until(self, event: str, timeout=30) -> Any:
|
||||
async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any:
|
||||
"""Utility method to wake on the next firing of an event.
|
||||
(Registers a disposable handler on it)
|
||||
|
||||
@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
:param int timeout: Timeout
|
||||
:raises: :class:`asyncio.TimeoutError` when the timeout is reached
|
||||
"""
|
||||
fut = asyncio.Future()
|
||||
fut: Future[Any] = asyncio.Future()
|
||||
|
||||
def result_handler(event_data):
|
||||
def result_handler(event_data: Any) -> None:
|
||||
if not fut.done():
|
||||
fut.set_result(event_data)
|
||||
else:
|
||||
@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol):
|
||||
return await asyncio.wait_for(fut, timeout)
|
||||
|
||||
@contextmanager
|
||||
def event_handler(self, event: str, handler: Callable):
|
||||
def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]:
|
||||
"""
|
||||
Context manager that adds then removes an event handler.
|
||||
"""
|
||||
self.add_event_handler(event, handler)
|
||||
try:
|
||||
yield
|
||||
except Exception as exc:
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
self.del_event_handler(event, handler)
|
||||
|
||||
def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future:
|
||||
def wrap(self, coroutine: Coroutine[None, None, T]) -> Future[T]:
|
||||
"""Make a Future out of a coroutine with the current loop.
|
||||
|
||||
:param coroutine: The coroutine to wrap.
|
||||
|
Loading…
Reference in New Issue
Block a user