xmlstream: add more types
This commit is contained in:
parent
c07476e7de
commit
db48c8f4da
@ -9,17 +9,24 @@
|
|||||||
# :license: MIT, see LICENSE for more details
|
# :license: MIT, see LICENSE for more details
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Dict,
|
||||||
|
Awaitable,
|
||||||
|
Generator,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Callable,
|
Callable,
|
||||||
Iterable,
|
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
Union,
|
Union,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
NoReturn,
|
||||||
|
Type,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import socket as Socket
|
import socket as Socket
|
||||||
@ -27,30 +34,66 @@ import ssl
|
|||||||
import weakref
|
import weakref
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from asyncio import iscoroutinefunction, wait, Future
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
from asyncio import (
|
||||||
|
AbstractEventLoop,
|
||||||
|
BaseTransport,
|
||||||
|
Future,
|
||||||
|
Task,
|
||||||
|
TimerHandle,
|
||||||
|
Transport,
|
||||||
|
iscoroutinefunction,
|
||||||
|
wait,
|
||||||
|
)
|
||||||
|
|
||||||
from slixmpp.xmlstream import tostring
|
from slixmpp.types import FilterString
|
||||||
|
from slixmpp.xmlstream.tostring import tostring
|
||||||
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
|
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
|
||||||
from slixmpp.xmlstream.resolver import resolve, default_resolver
|
from slixmpp.xmlstream.resolver import resolve, default_resolver
|
||||||
|
from slixmpp.xmlstream.handler.base import BaseHandler
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
#: The time in seconds to wait before timing out waiting for response stanzas.
|
#: The time in seconds to wait before timing out waiting for response stanzas.
|
||||||
RESPONSE_TIMEOUT = 30
|
RESPONSE_TIMEOUT = 30
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ContinueQueue(Exception):
|
class ContinueQueue(Exception):
|
||||||
"""
|
"""
|
||||||
Exception raised in the send queue to "continue" from within an inner loop
|
Exception raised in the send queue to "continue" from within an inner loop
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class NotConnectedError(Exception):
|
class NotConnectedError(Exception):
|
||||||
"""
|
"""
|
||||||
Raised when we try to send something over the wire but we are not
|
Raised when we try to send something over the wire but we are not
|
||||||
connected.
|
connected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar('_T', str, ElementBase, StanzaBase)
|
||||||
|
|
||||||
|
|
||||||
|
SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
|
||||||
|
AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
|
||||||
|
|
||||||
|
|
||||||
|
Filter = Union[
|
||||||
|
SyncFilter,
|
||||||
|
AsyncFilter,
|
||||||
|
]
|
||||||
|
|
||||||
|
_FiltersDict = Dict[str, List[Filter]]
|
||||||
|
|
||||||
|
Handler = Callable[[Any], Union[
|
||||||
|
Any,
|
||||||
|
Coroutine[Any, Any, Any]
|
||||||
|
]]
|
||||||
|
|
||||||
|
|
||||||
class XMLStream(asyncio.BaseProtocol):
|
class XMLStream(asyncio.BaseProtocol):
|
||||||
"""
|
"""
|
||||||
An XML stream connection manager and event dispatcher.
|
An XML stream connection manager and event dispatcher.
|
||||||
@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
:param int port: The port to use for the connection. Defaults to 0.
|
:param int port: The port to use for the connection. Defaults to 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, host='', port=0):
|
transport: Optional[Transport]
|
||||||
# The asyncio.Transport object provided by the connection_made()
|
|
||||||
# callback when we are connected
|
|
||||||
self.transport = None
|
|
||||||
|
|
||||||
# The socket that is used internally by the transport object
|
# The socket that is used internally by the transport object
|
||||||
self.socket = None
|
socket: Optional[ssl.SSLSocket]
|
||||||
|
|
||||||
# The backoff of the connect routine (increases exponentially
|
# The backoff of the connect routine (increases exponentially
|
||||||
# after each failure)
|
# after each failure)
|
||||||
|
_connect_loop_wait: float
|
||||||
|
|
||||||
|
parser: Optional[ET.XMLPullParser]
|
||||||
|
xml_depth: int
|
||||||
|
xml_root: Optional[ET.Element]
|
||||||
|
|
||||||
|
force_starttls: Optional[bool]
|
||||||
|
disable_starttls: Optional[bool]
|
||||||
|
|
||||||
|
waiting_queue: asyncio.Queue[Tuple[Union[StanzaBase, str], bool]]
|
||||||
|
|
||||||
|
# A dict of {name: handle}
|
||||||
|
scheduled_events: Dict[str, TimerHandle]
|
||||||
|
|
||||||
|
ssl_context: ssl.SSLContext
|
||||||
|
|
||||||
|
# The event to trigger when the create_connection() succeeds. It can
|
||||||
|
# be "connected" or "tls_success" depending on the step we are at.
|
||||||
|
event_when_connected: str
|
||||||
|
|
||||||
|
#: The list of accepted ciphers, in OpenSSL Format.
|
||||||
|
#: It might be useful to override it for improved security
|
||||||
|
#: over the python defaults.
|
||||||
|
ciphers: Optional[str]
|
||||||
|
|
||||||
|
#: Path to a file containing certificates for verifying the
|
||||||
|
#: server SSL certificate. A non-``None`` value will trigger
|
||||||
|
#: certificate checking.
|
||||||
|
#:
|
||||||
|
#: .. note::
|
||||||
|
#:
|
||||||
|
#: On Mac OS X, certificates in the system keyring will
|
||||||
|
#: be consulted, even if they are not in the provided file.
|
||||||
|
ca_certs: Optional[str]
|
||||||
|
|
||||||
|
#: Path to a file containing a client certificate to use for
|
||||||
|
#: authenticating via SASL EXTERNAL. If set, there must also
|
||||||
|
#: be a corresponding `:attr:keyfile` value.
|
||||||
|
certfile: Optional[str]
|
||||||
|
|
||||||
|
#: Path to a file containing the private key for the selected
|
||||||
|
#: client certificate to use for authenticating via SASL EXTERNAL.
|
||||||
|
keyfile: Optional[str]
|
||||||
|
|
||||||
|
# The asyncio event loop
|
||||||
|
_loop: Optional[AbstractEventLoop]
|
||||||
|
|
||||||
|
#: The default port to return when querying DNS records.
|
||||||
|
default_port: int
|
||||||
|
|
||||||
|
#: The domain to try when querying DNS records.
|
||||||
|
default_domain: str
|
||||||
|
|
||||||
|
#: The expected name of the server, for validation.
|
||||||
|
_expected_server_name: str
|
||||||
|
_service_name: str
|
||||||
|
|
||||||
|
#: The desired, or actual, address of the connected server.
|
||||||
|
address: Tuple[str, int]
|
||||||
|
|
||||||
|
#: Enable connecting to the server directly over SSL, in
|
||||||
|
#: particular when the service provides two ports: one for
|
||||||
|
#: non-SSL traffic and another for SSL traffic.
|
||||||
|
use_ssl: bool
|
||||||
|
|
||||||
|
#: If set to ``True``, attempt to use IPv6.
|
||||||
|
use_ipv6: bool
|
||||||
|
|
||||||
|
#: If set to ``True``, allow using the ``dnspython`` DNS library
|
||||||
|
#: if available. If set to ``False``, the builtin DNS resolver
|
||||||
|
#: will be used, even if ``dnspython`` is installed.
|
||||||
|
use_aiodns: bool
|
||||||
|
|
||||||
|
#: Use CDATA for escaping instead of XML entities. Defaults
|
||||||
|
#: to ``False``.
|
||||||
|
use_cdata: bool
|
||||||
|
|
||||||
|
#: The default namespace of the stream content, not of the
|
||||||
|
#: stream wrapper it
|
||||||
|
default_ns: str
|
||||||
|
|
||||||
|
default_lang: Optional[str]
|
||||||
|
peer_default_lang: Optional[str]
|
||||||
|
|
||||||
|
#: The namespace of the enveloping stream element.
|
||||||
|
stream_ns: str
|
||||||
|
|
||||||
|
#: The default opening tag for the stream element.
|
||||||
|
stream_header: str
|
||||||
|
|
||||||
|
#: The default closing tag for the stream element.
|
||||||
|
stream_footer: str
|
||||||
|
|
||||||
|
#: If ``True``, periodically send a whitespace character over the
|
||||||
|
#: wire to keep the connection alive. Mainly useful for connections
|
||||||
|
#: traversing NAT.
|
||||||
|
whitespace_keepalive: bool
|
||||||
|
|
||||||
|
#: The default interval between keepalive signals when
|
||||||
|
#: :attr:`whitespace_keepalive` is enabled.
|
||||||
|
whitespace_keepalive_interval: int
|
||||||
|
|
||||||
|
#: Flag for controlling if the session can be considered ended
|
||||||
|
#: if the connection is terminated.
|
||||||
|
end_session_on_disconnect: bool
|
||||||
|
|
||||||
|
#: A mapping of XML namespaces to well-known prefixes.
|
||||||
|
namespace_map: dict
|
||||||
|
|
||||||
|
__root_stanza: List[Type[StanzaBase]]
|
||||||
|
__handlers: List[BaseHandler]
|
||||||
|
__event_handlers: Dict[str, List[Tuple[Handler, bool]]]
|
||||||
|
__filters: _FiltersDict
|
||||||
|
|
||||||
|
# Current connection attempt (Future)
|
||||||
|
_current_connection_attempt: Optional[Future[None]]
|
||||||
|
|
||||||
|
#: A list of DNS results that have not yet been tried.
|
||||||
|
_dns_answers: Optional[Iterator[Tuple[str, str, int]]]
|
||||||
|
|
||||||
|
#: The service name to check with DNS SRV records. For
|
||||||
|
#: example, setting this to ``'xmpp-client'`` would query the
|
||||||
|
#: ``_xmpp-client._tcp`` service.
|
||||||
|
dns_service: Optional[str]
|
||||||
|
|
||||||
|
#: The reason why we are disconnecting from the server
|
||||||
|
disconnect_reason: Optional[str]
|
||||||
|
|
||||||
|
#: An asyncio Future being done when the stream is disconnected.
|
||||||
|
disconnected: Future[bool]
|
||||||
|
|
||||||
|
# If the session has been started or not
|
||||||
|
_session_started: bool
|
||||||
|
# If we want to bypass the send() check (e.g. unit tests)
|
||||||
|
_always_send_everything: bool
|
||||||
|
|
||||||
|
_run_out_filters: Optional[Future]
|
||||||
|
__slow_tasks: List[Task]
|
||||||
|
__queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
|
||||||
|
|
||||||
|
def __init__(self, host: str = '', port: int = 0):
|
||||||
|
self.transport = None
|
||||||
|
self.socket = None
|
||||||
self._connect_loop_wait = 0
|
self._connect_loop_wait = 0
|
||||||
|
|
||||||
self.parser = None
|
self.parser = None
|
||||||
@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.ssl_context.check_hostname = False
|
self.ssl_context.check_hostname = False
|
||||||
self.ssl_context.verify_mode = ssl.CERT_NONE
|
self.ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
# The event to trigger when the create_connection() succeeds. It can
|
|
||||||
# be "connected" or "tls_success" depending on the step we are at.
|
|
||||||
self.event_when_connected = "connected"
|
self.event_when_connected = "connected"
|
||||||
|
|
||||||
#: The list of accepted ciphers, in OpenSSL Format.
|
|
||||||
#: It might be useful to override it for improved security
|
|
||||||
#: over the python defaults.
|
|
||||||
self.ciphers = None
|
self.ciphers = None
|
||||||
|
|
||||||
#: Path to a file containing certificates for verifying the
|
|
||||||
#: server SSL certificate. A non-``None`` value will trigger
|
|
||||||
#: certificate checking.
|
|
||||||
#:
|
|
||||||
#: .. note::
|
|
||||||
#:
|
|
||||||
#: On Mac OS X, certificates in the system keyring will
|
|
||||||
#: be consulted, even if they are not in the provided file.
|
|
||||||
self.ca_certs = None
|
self.ca_certs = None
|
||||||
|
|
||||||
#: Path to a file containing a client certificate to use for
|
|
||||||
#: authenticating via SASL EXTERNAL. If set, there must also
|
|
||||||
#: be a corresponding `:attr:keyfile` value.
|
|
||||||
self.certfile = None
|
|
||||||
|
|
||||||
#: Path to a file containing the private key for the selected
|
|
||||||
#: client certificate to use for authenticating via SASL EXTERNAL.
|
|
||||||
self.keyfile = None
|
self.keyfile = None
|
||||||
|
|
||||||
self._der_cert = None
|
|
||||||
|
|
||||||
# The asyncio event loop
|
|
||||||
self._loop = None
|
self._loop = None
|
||||||
|
|
||||||
#: The default port to return when querying DNS records.
|
|
||||||
self.default_port = int(port)
|
self.default_port = int(port)
|
||||||
|
|
||||||
#: The domain to try when querying DNS records.
|
|
||||||
self.default_domain = ''
|
self.default_domain = ''
|
||||||
|
|
||||||
#: The expected name of the server, for validation.
|
|
||||||
self._expected_server_name = ''
|
self._expected_server_name = ''
|
||||||
self._service_name = ''
|
self._service_name = ''
|
||||||
|
|
||||||
#: The desired, or actual, address of the connected server.
|
|
||||||
self.address = (host, int(port))
|
self.address = (host, int(port))
|
||||||
|
|
||||||
#: Enable connecting to the server directly over SSL, in
|
|
||||||
#: particular when the service provides two ports: one for
|
|
||||||
#: non-SSL traffic and another for SSL traffic.
|
|
||||||
self.use_ssl = False
|
self.use_ssl = False
|
||||||
|
|
||||||
#: If set to ``True``, attempt to use IPv6.
|
|
||||||
self.use_ipv6 = True
|
self.use_ipv6 = True
|
||||||
|
|
||||||
#: If set to ``True``, allow using the ``dnspython`` DNS library
|
|
||||||
#: if available. If set to ``False``, the builtin DNS resolver
|
|
||||||
#: will be used, even if ``dnspython`` is installed.
|
|
||||||
self.use_aiodns = True
|
self.use_aiodns = True
|
||||||
|
|
||||||
#: Use CDATA for escaping instead of XML entities. Defaults
|
|
||||||
#: to ``False``.
|
|
||||||
self.use_cdata = False
|
self.use_cdata = False
|
||||||
|
|
||||||
#: The default namespace of the stream content, not of the
|
|
||||||
#: stream wrapper itself.
|
|
||||||
self.default_ns = ''
|
self.default_ns = ''
|
||||||
|
|
||||||
self.default_lang = None
|
self.default_lang = None
|
||||||
self.peer_default_lang = None
|
self.peer_default_lang = None
|
||||||
|
|
||||||
#: The namespace of the enveloping stream element.
|
|
||||||
self.stream_ns = ''
|
self.stream_ns = ''
|
||||||
|
|
||||||
#: The default opening tag for the stream element.
|
|
||||||
self.stream_header = "<stream>"
|
self.stream_header = "<stream>"
|
||||||
|
|
||||||
#: The default closing tag for the stream element.
|
|
||||||
self.stream_footer = "</stream>"
|
self.stream_footer = "</stream>"
|
||||||
|
|
||||||
#: If ``True``, periodically send a whitespace character over the
|
|
||||||
#: wire to keep the connection alive. Mainly useful for connections
|
|
||||||
#: traversing NAT.
|
|
||||||
self.whitespace_keepalive = True
|
self.whitespace_keepalive = True
|
||||||
|
|
||||||
#: The default interval between keepalive signals when
|
|
||||||
#: :attr:`whitespace_keepalive` is enabled.
|
|
||||||
self.whitespace_keepalive_interval = 300
|
self.whitespace_keepalive_interval = 300
|
||||||
|
|
||||||
#: Flag for controlling if the session can be considered ended
|
|
||||||
#: if the connection is terminated.
|
|
||||||
self.end_session_on_disconnect = True
|
self.end_session_on_disconnect = True
|
||||||
|
|
||||||
#: A mapping of XML namespaces to well-known prefixes.
|
|
||||||
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
|
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
|
||||||
|
|
||||||
self.__root_stanza = []
|
self.__root_stanza = []
|
||||||
self.__handlers = []
|
self.__handlers = []
|
||||||
self.__event_handlers = {}
|
self.__event_handlers = {}
|
||||||
self.__filters = {'in': [], 'out': [], 'out_sync': []}
|
self.__filters = {
|
||||||
|
'in': [], 'out': [], 'out_sync': []
|
||||||
|
}
|
||||||
|
|
||||||
# Current connection attempt (Future)
|
|
||||||
self._current_connection_attempt = None
|
self._current_connection_attempt = None
|
||||||
|
|
||||||
#: A list of DNS results that have not yet been tried.
|
self._dns_answers = None
|
||||||
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
|
|
||||||
|
|
||||||
#: The service name to check with DNS SRV records. For
|
|
||||||
#: example, setting this to ``'xmpp-client'`` would query the
|
|
||||||
#: ``_xmpp-client._tcp`` service.
|
|
||||||
self.dns_service = None
|
self.dns_service = None
|
||||||
|
|
||||||
#: The reason why we are disconnecting from the server
|
|
||||||
self.disconnect_reason = None
|
self.disconnect_reason = None
|
||||||
|
self.disconnected = Future()
|
||||||
#: An asyncio Future being done when the stream is disconnected.
|
|
||||||
self.disconnected: Future = Future()
|
|
||||||
|
|
||||||
# If the session has been started or not
|
|
||||||
self._session_started = False
|
self._session_started = False
|
||||||
# If we want to bypass the send() check (e.g. unit tests)
|
|
||||||
self._always_send_everything = False
|
self._always_send_everything = False
|
||||||
|
|
||||||
self.add_event_handler('disconnected', self._remove_schedules)
|
self.add_event_handler('disconnected', self._remove_schedules)
|
||||||
@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.add_event_handler('session_start', self._set_session_start)
|
self.add_event_handler('session_start', self._set_session_start)
|
||||||
self.add_event_handler('session_resumed', self._set_session_start)
|
self.add_event_handler('session_resumed', self._set_session_start)
|
||||||
|
|
||||||
self._run_out_filters: Optional[Future] = None
|
self._run_out_filters = None
|
||||||
self.__slow_tasks: List[Future] = []
|
self.__slow_tasks = []
|
||||||
self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = []
|
self.__queued_stanzas = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def loop(self):
|
def loop(self) -> AbstractEventLoop:
|
||||||
if self._loop is None:
|
if self._loop is None:
|
||||||
self._loop = asyncio.get_event_loop()
|
self._loop = asyncio.get_event_loop()
|
||||||
return self._loop
|
return self._loop
|
||||||
|
|
||||||
@loop.setter
|
@loop.setter
|
||||||
def loop(self, value):
|
def loop(self, value: AbstractEventLoop) -> None:
|
||||||
self._loop = value
|
self._loop = value
|
||||||
|
|
||||||
def new_id(self):
|
def new_id(self) -> str:
|
||||||
"""Generate and return a new stream ID in hexadecimal form.
|
"""Generate and return a new stream ID in hexadecimal form.
|
||||||
|
|
||||||
Many stanzas, handlers, or matchers may require unique
|
Many stanzas, handlers, or matchers may require unique
|
||||||
@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
return uuid.uuid4().hex
|
return uuid.uuid4().hex
|
||||||
|
|
||||||
def _set_session_start(self, event):
|
def _set_session_start(self, event: Any) -> None:
|
||||||
"""
|
"""
|
||||||
On session start, queue all pending stanzas to be sent.
|
On session start, queue all pending stanzas to be sent.
|
||||||
"""
|
"""
|
||||||
@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.waiting_queue.put_nowait(stanza)
|
self.waiting_queue.put_nowait(stanza)
|
||||||
self.__queued_stanzas = []
|
self.__queued_stanzas = []
|
||||||
|
|
||||||
def _set_disconnected(self, event):
|
def _set_disconnected(self, event: Any) -> None:
|
||||||
self._session_started = False
|
self._session_started = False
|
||||||
|
|
||||||
def _set_disconnected_future(self):
|
def _set_disconnected_future(self) -> None:
|
||||||
"""Set the self.disconnected future on disconnect"""
|
"""Set the self.disconnected future on disconnect"""
|
||||||
if not self.disconnected.done():
|
if not self.disconnected.done():
|
||||||
self.disconnected.set_result(True)
|
self.disconnected.set_result(True)
|
||||||
self.disconnected = asyncio.Future()
|
self.disconnected = asyncio.Future()
|
||||||
|
|
||||||
def connect(self, host='', port=0, use_ssl=False,
|
def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
|
||||||
force_starttls=True, disable_starttls=False):
|
force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
|
||||||
"""Create a new socket and connect to the server.
|
"""Create a new socket and connect to the server.
|
||||||
|
|
||||||
:param host: The name of the desired server for the connection.
|
:param host: The name of the desired server for the connection.
|
||||||
@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
loop=self.loop,
|
loop=self.loop,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _connect_routine(self):
|
async def _connect_routine(self) -> None:
|
||||||
self.event_when_connected = "connected"
|
self.event_when_connected = "connected"
|
||||||
|
|
||||||
if self._connect_loop_wait > 0:
|
if self._connect_loop_wait > 0:
|
||||||
@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
# and try (host, port) as a last resort
|
# and try (host, port) as a last resort
|
||||||
self._dns_answers = None
|
self._dns_answers = None
|
||||||
|
|
||||||
|
ssl_context: Optional[ssl.SSLContext]
|
||||||
if self.use_ssl:
|
if self.use_ssl:
|
||||||
ssl_context = self.get_ssl_context()
|
ssl_context = self.get_ssl_context()
|
||||||
else:
|
else:
|
||||||
@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
loop=self.loop,
|
loop=self.loop,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, *, forever=True, timeout=None):
|
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
|
||||||
"""Process all the available XMPP events (receiving or sending data on the
|
"""Process all the available XMPP events (receiving or sending data on the
|
||||||
socket(s), calling various registered callbacks, calling expired
|
socket(s), calling various registered callbacks, calling expired
|
||||||
timers, handling signal events, etc). If timeout is None, this
|
timers, handling signal events, etc). If timeout is None, this
|
||||||
@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
else:
|
else:
|
||||||
self.loop.run_until_complete(self.disconnected)
|
self.loop.run_until_complete(self.disconnected)
|
||||||
else:
|
else:
|
||||||
tasks = [asyncio.sleep(timeout, loop=self.loop)]
|
tasks: List[Future[bool]] = [asyncio.sleep(timeout, loop=self.loop)]
|
||||||
if not forever:
|
if not forever:
|
||||||
tasks.append(self.disconnected)
|
tasks.append(self.disconnected)
|
||||||
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
|
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
|
||||||
|
|
||||||
def init_parser(self):
|
def init_parser(self) -> None:
|
||||||
"""init the XML parser. The parser must always be reset for each new
|
"""init the XML parser. The parser must always be reset for each new
|
||||||
connexion
|
connexion
|
||||||
"""
|
"""
|
||||||
@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.xml_root = None
|
self.xml_root = None
|
||||||
self.parser = ET.XMLPullParser(("start", "end"))
|
self.parser = ET.XMLPullParser(("start", "end"))
|
||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport: BaseTransport) -> None:
|
||||||
"""Called when the TCP connection has been established with the server
|
"""Called when the TCP connection has been established with the server
|
||||||
"""
|
"""
|
||||||
self.event(self.event_when_connected)
|
self.event(self.event_when_connected)
|
||||||
self.transport = transport
|
self.transport = cast(Transport, transport)
|
||||||
|
if self.transport is None:
|
||||||
|
raise ValueError("Transport cannot be none")
|
||||||
self.socket = self.transport.get_extra_info(
|
self.socket = self.transport.get_extra_info(
|
||||||
"ssl_object",
|
"ssl_object",
|
||||||
default=self.transport.get_extra_info("socket")
|
default=self.transport.get_extra_info("socket")
|
||||||
@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.send_raw(self.stream_header)
|
self.send_raw(self.stream_header)
|
||||||
self._dns_answers = None
|
self._dns_answers = None
|
||||||
|
|
||||||
def data_received(self, data):
|
def data_received(self, data: bytes) -> None:
|
||||||
"""Called when incoming data is received on the socket.
|
"""Called when incoming data is received on the socket.
|
||||||
|
|
||||||
We feed that data to the parser and the see if this produced any XML
|
We feed that data to the parser and the see if this produced any XML
|
||||||
@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.send(error)
|
self.send(error)
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
|
||||||
def is_connecting(self):
|
def is_connecting(self) -> bool:
|
||||||
return self._current_connection_attempt is not None
|
return self._current_connection_attempt is not None
|
||||||
|
|
||||||
def is_connected(self):
|
def is_connected(self) -> bool:
|
||||||
return self.transport is not None
|
return self.transport is not None
|
||||||
|
|
||||||
def eof_received(self):
|
def eof_received(self) -> None:
|
||||||
"""When the TCP connection is properly closed by the remote end
|
"""When the TCP connection is properly closed by the remote end
|
||||||
"""
|
"""
|
||||||
self.event("eof_received")
|
self.event("eof_received")
|
||||||
|
|
||||||
def connection_lost(self, exception):
|
def connection_lost(self, exception: Optional[BaseException]) -> None:
|
||||||
"""On any kind of disconnection, initiated by us or not. This signals the
|
"""On any kind of disconnection, initiated by us or not. This signals the
|
||||||
closure of the TCP connection
|
closure of the TCP connection
|
||||||
"""
|
"""
|
||||||
@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self._reset_sendq()
|
self._reset_sendq()
|
||||||
self.event('session_end')
|
self.event('session_end')
|
||||||
self._set_disconnected_future()
|
self._set_disconnected_future()
|
||||||
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
|
self.event("disconnected", self.disconnect_reason or exception)
|
||||||
|
|
||||||
def cancel_connection_attempt(self):
|
def cancel_connection_attempt(self) -> None:
|
||||||
"""
|
"""
|
||||||
Immediately cancel the current create_connection() Future.
|
Immediately cancel the current create_connection() Future.
|
||||||
This is useful when a client using slixmpp tries to connect
|
This is useful when a client using slixmpp tries to connect
|
||||||
@ -506,7 +626,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self._current_connection_attempt.cancel()
|
self._current_connection_attempt.cancel()
|
||||||
self._current_connection_attempt = None
|
self._current_connection_attempt = None
|
||||||
|
|
||||||
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future:
|
def disconnect(self, wait: Union[float, int] = 2.0, reason: Optional[str] = None, ignore_send_queue: bool = False) -> Future[None]:
|
||||||
"""Close the XML stream and wait for an acknowldgement from the server for
|
"""Close the XML stream and wait for an acknowldgement from the server for
|
||||||
at most `wait` seconds. After the given number of seconds has
|
at most `wait` seconds. After the given number of seconds has
|
||||||
passed without a response from the server, or when the server
|
passed without a response from the server, or when the server
|
||||||
@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
|
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
|
||||||
# schedule call below. It would fortunately be converted to `1` later
|
# schedule call below. It would fortunately be converted to `1` later
|
||||||
# down the call chain. Praise the implicit casts lord.
|
# down the call chain. Praise the implicit casts lord.
|
||||||
if wait == True:
|
if wait is True:
|
||||||
wait = 2.0
|
wait = 2.0
|
||||||
|
|
||||||
if self.transport:
|
if self.transport:
|
||||||
@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
else:
|
else:
|
||||||
self._set_disconnected_future()
|
self._set_disconnected_future()
|
||||||
self.event("disconnected", reason)
|
self.event("disconnected", reason)
|
||||||
future = Future()
|
future: Future[None] = Future()
|
||||||
future.set_result(None)
|
future.set_result(None)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
|
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
|
||||||
"""Wait until the send queue is empty before disconnecting"""
|
"""Wait until the send queue is empty before disconnecting"""
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.disconnect_reason = reason
|
self.disconnect_reason = reason
|
||||||
await self._end_stream_wait(wait)
|
await self._end_stream_wait(wait)
|
||||||
|
|
||||||
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
|
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Run abort() if we do not received the disconnected event
|
Run abort() if we do not received the disconnected event
|
||||||
after a waiting time.
|
after a waiting time.
|
||||||
@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
# that means the disconnect has already been handled
|
# that means the disconnect has already been handled
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def abort(self):
|
def abort(self) -> None:
|
||||||
"""
|
"""
|
||||||
Forcibly close the connection
|
Forcibly close the connection
|
||||||
"""
|
"""
|
||||||
@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.transport.abort()
|
self.transport.abort()
|
||||||
self.event("killed")
|
self.event("killed")
|
||||||
|
|
||||||
def reconnect(self, wait=2.0, reason="Reconnecting"):
|
def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
|
||||||
"""Calls disconnect(), and once we are disconnected (after the timeout, or
|
"""Calls disconnect(), and once we are disconnected (after the timeout, or
|
||||||
when the server acknowledgement is received), call connect()
|
when the server acknowledgement is received), call connect()
|
||||||
"""
|
"""
|
||||||
log.debug("reconnecting...")
|
log.debug("reconnecting...")
|
||||||
async def handler(event):
|
async def handler(event: Any) -> None:
|
||||||
# We yield here to allow synchronous handlers to work first
|
# We yield here to allow synchronous handlers to work first
|
||||||
await asyncio.sleep(0, loop=self.loop)
|
await asyncio.sleep(0, loop=self.loop)
|
||||||
self.connect()
|
self.connect()
|
||||||
self.add_event_handler('disconnected', handler, disposable=True)
|
self.add_event_handler('disconnected', handler, disposable=True)
|
||||||
self.disconnect(wait, reason)
|
self.disconnect(wait, reason)
|
||||||
|
|
||||||
def configure_socket(self):
|
def configure_socket(self) -> None:
|
||||||
"""Set timeout and other options for self.socket.
|
"""Set timeout and other options for self.socket.
|
||||||
|
|
||||||
Meant to be overridden.
|
Meant to be overridden.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def configure_dns(self, resolver, domain=None, port=None):
|
def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Configure and set options for a :class:`~dns.resolver.Resolver`
|
Configure and set options for a :class:`~dns.resolver.Resolver`
|
||||||
instance, and other DNS related tasks. For example, you
|
instance, and other DNS related tasks. For example, you
|
||||||
@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_ssl_context(self):
|
def get_ssl_context(self) -> ssl.SSLContext:
|
||||||
"""
|
"""
|
||||||
Get SSL context.
|
Get SSL context.
|
||||||
"""
|
"""
|
||||||
@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
|
|
||||||
return self.ssl_context
|
return self.ssl_context
|
||||||
|
|
||||||
async def start_tls(self):
|
async def start_tls(self) -> bool:
|
||||||
"""Perform handshakes for TLS.
|
"""Perform handshakes for TLS.
|
||||||
|
|
||||||
If the handshake is successful, the XML stream will need
|
If the handshake is successful, the XML stream will need
|
||||||
to be restarted.
|
to be restarted.
|
||||||
"""
|
"""
|
||||||
|
if self.transport is None:
|
||||||
|
raise ValueError("Transport should not be None")
|
||||||
self.event_when_connected = "tls_success"
|
self.event_when_connected = "tls_success"
|
||||||
ssl_context = self.get_ssl_context()
|
ssl_context = self.get_ssl_context()
|
||||||
try:
|
try:
|
||||||
@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.connection_made(transp)
|
self.connection_made(transp)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _start_keepalive(self, event):
|
def _start_keepalive(self, event: Any) -> None:
|
||||||
"""Begin sending whitespace periodically to keep the connection alive.
|
"""Begin sending whitespace periodically to keep the connection alive.
|
||||||
|
|
||||||
May be disabled by setting::
|
May be disabled by setting::
|
||||||
@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
args=(' ',),
|
args=(' ',),
|
||||||
repeat=True)
|
repeat=True)
|
||||||
|
|
||||||
def _remove_schedules(self, event):
|
def _remove_schedules(self, event: Any) -> None:
|
||||||
"""Remove some schedules that become pointless when disconnected"""
|
"""Remove some schedules that become pointless when disconnected"""
|
||||||
self.cancel_schedule('Whitespace Keepalive')
|
self.cancel_schedule('Whitespace Keepalive')
|
||||||
|
|
||||||
def start_stream_handler(self, xml):
|
def start_stream_handler(self, xml: ET.Element) -> None:
|
||||||
"""Perform any initialization actions, such as handshakes,
|
"""Perform any initialization actions, such as handshakes,
|
||||||
once the stream header has been sent.
|
once the stream header has been sent.
|
||||||
|
|
||||||
@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def register_stanza(self, stanza_class):
|
def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
|
||||||
"""Add a stanza object class as a known root stanza.
|
"""Add a stanza object class as a known root stanza.
|
||||||
|
|
||||||
A root stanza is one that appears as a direct child of the stream's
|
A root stanza is one that appears as a direct child of the stream's
|
||||||
@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
self.__root_stanza.append(stanza_class)
|
self.__root_stanza.append(stanza_class)
|
||||||
|
|
||||||
def remove_stanza(self, stanza_class):
|
def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
|
||||||
"""Remove a stanza from being a known root stanza.
|
"""Remove a stanza from being a known root stanza.
|
||||||
|
|
||||||
A root stanza is one that appears as a direct child of the stream's
|
A root stanza is one that appears as a direct child of the stream's
|
||||||
@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
self.__root_stanza.remove(stanza_class)
|
self.__root_stanza.remove(stanza_class)
|
||||||
|
|
||||||
def add_filter(self, mode, handler, order=None):
|
def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
|
||||||
"""Add a filter for incoming or outgoing stanzas.
|
"""Add a filter for incoming or outgoing stanzas.
|
||||||
|
|
||||||
These filters are applied before incoming stanzas are
|
These filters are applied before incoming stanzas are
|
||||||
@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
else:
|
else:
|
||||||
self.__filters[mode].append(handler)
|
self.__filters[mode].append(handler)
|
||||||
|
|
||||||
def del_filter(self, mode, handler):
|
def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
|
||||||
"""Remove an incoming or outgoing filter."""
|
"""Remove an incoming or outgoing filter."""
|
||||||
self.__filters[mode].remove(handler)
|
self.__filters[mode].remove(handler)
|
||||||
|
|
||||||
def register_handler(self, handler, before=None, after=None):
|
def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
|
||||||
"""Add a stream event handler that will be executed when a matching
|
"""Add a stream event handler that will be executed when a matching
|
||||||
stanza is received.
|
stanza is received.
|
||||||
|
|
||||||
@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.__handlers.append(handler)
|
self.__handlers.append(handler)
|
||||||
handler.stream = weakref.ref(self)
|
handler.stream = weakref.ref(self)
|
||||||
|
|
||||||
def remove_handler(self, name):
|
def remove_handler(self, name: str) -> bool:
|
||||||
"""Remove any stream event handlers with the given name.
|
"""Remove any stream event handlers with the given name.
|
||||||
|
|
||||||
:param name: The name of the handler.
|
:param name: The name of the handler.
|
||||||
@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
try:
|
try:
|
||||||
return next(self._dns_answers)
|
return next(self._dns_answers)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
return
|
return None
|
||||||
|
|
||||||
def add_event_handler(self, name, pointer, disposable=False):
|
def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
|
||||||
"""Add a custom event handler that will be executed whenever
|
"""Add a custom event handler that will be executed whenever
|
||||||
its event is manually triggered.
|
its event is manually triggered.
|
||||||
|
|
||||||
@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.__event_handlers[name] = []
|
self.__event_handlers[name] = []
|
||||||
self.__event_handlers[name].append((pointer, disposable))
|
self.__event_handlers[name].append((pointer, disposable))
|
||||||
|
|
||||||
def del_event_handler(self, name, pointer):
|
def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None:
|
||||||
"""Remove a function as a handler for an event.
|
"""Remove a function as a handler for an event.
|
||||||
|
|
||||||
:param name: The name of the event.
|
:param name: The name of the event.
|
||||||
@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
|
|
||||||
# Need to keep handlers that do not use
|
# Need to keep handlers that do not use
|
||||||
# the given function pointer
|
# the given function pointer
|
||||||
def filter_pointers(handler):
|
def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool:
|
||||||
return handler[0] != pointer
|
return handler[0] != pointer
|
||||||
|
|
||||||
self.__event_handlers[name] = list(filter(
|
self.__event_handlers[name] = list(filter(
|
||||||
filter_pointers,
|
filter_pointers,
|
||||||
self.__event_handlers[name]))
|
self.__event_handlers[name]))
|
||||||
|
|
||||||
def event_handled(self, name):
|
def event_handled(self, name: str) -> int:
|
||||||
"""Returns the number of registered handlers for an event.
|
"""Returns the number of registered handlers for an event.
|
||||||
|
|
||||||
:param name: The name of the event to check.
|
:param name: The name of the event to check.
|
||||||
"""
|
"""
|
||||||
return len(self.__event_handlers.get(name, []))
|
return len(self.__event_handlers.get(name, []))
|
||||||
|
|
||||||
async def event_async(self, name: str, data: Any = {}):
|
async def event_async(self, name: str, data: Any = {}) -> None:
|
||||||
"""Manually trigger a custom event, but await coroutines immediately.
|
"""Manually trigger a custom event, but await coroutines immediately.
|
||||||
|
|
||||||
This event generator should only be called in situations when
|
This event generator should only be called in situations when
|
||||||
@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.exception(e)
|
self.exception(e)
|
||||||
|
|
||||||
def event(self, name: str, data: Any = {}):
|
def event(self, name: str, data: Any = {}) -> None:
|
||||||
"""Manually trigger a custom event.
|
"""Manually trigger a custom event.
|
||||||
Coroutine handlers are wrapped into a future and sent into the
|
Coroutine handlers are wrapped into a future and sent into the
|
||||||
event loop for their execution, and not awaited.
|
event loop for their execution, and not awaited.
|
||||||
@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
# If the callback is a coroutine, schedule it instead of
|
# If the callback is a coroutine, schedule it instead of
|
||||||
# running it directly
|
# running it directly
|
||||||
if iscoroutinefunction(handler_callback):
|
if iscoroutinefunction(handler_callback):
|
||||||
async def handler_callback_routine(cb):
|
async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None:
|
||||||
try:
|
try:
|
||||||
await cb(data)
|
await cb(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def schedule(self, name, seconds, callback, args=tuple(),
|
def schedule(self, name: str, seconds: int, callback: Callable[..., None],
|
||||||
kwargs={}, repeat=False):
|
args: Tuple[Any, ...] = tuple(),
|
||||||
|
kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None:
|
||||||
"""Schedule a callback function to execute after a given delay.
|
"""Schedule a callback function to execute after a given delay.
|
||||||
|
|
||||||
:param name: A unique name for the scheduled callback.
|
:param name: A unique name for the scheduled callback.
|
||||||
@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
# canceling scheduled_events[name]
|
# canceling scheduled_events[name]
|
||||||
self.scheduled_events[name] = handle
|
self.scheduled_events[name] = handle
|
||||||
|
|
||||||
def cancel_schedule(self, name):
|
def cancel_schedule(self, name: str) -> None:
|
||||||
try:
|
try:
|
||||||
handle = self.scheduled_events.pop(name)
|
handle = self.scheduled_events.pop(name)
|
||||||
handle.cancel()
|
handle.cancel()
|
||||||
except KeyError:
|
except KeyError:
|
||||||
log.debug("Tried to cancel unscheduled event: %s" % (name,))
|
log.debug("Tried to cancel unscheduled event: %s" % (name,))
|
||||||
|
|
||||||
def _safe_cb_run(self, name, cb):
|
def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None:
|
||||||
log.debug('Scheduled event: %s', name)
|
log.debug('Scheduled event: %s', name)
|
||||||
try:
|
try:
|
||||||
cb()
|
cb()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.exception(e)
|
self.exception(e)
|
||||||
|
|
||||||
def _execute_and_reschedule(self, name, cb, seconds):
|
def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None:
|
||||||
"""Simple method that calls the given callback, and then schedule itself to
|
"""Simple method that calls the given callback, and then schedule itself to
|
||||||
be called after the given number of seconds.
|
be called after the given number of seconds.
|
||||||
"""
|
"""
|
||||||
@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
name, cb, seconds)
|
name, cb, seconds)
|
||||||
self.scheduled_events[name] = handle
|
self.scheduled_events[name] = handle
|
||||||
|
|
||||||
def _execute_and_unschedule(self, name, cb):
|
def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the callback and remove the handler for it.
|
Execute the callback and remove the handler for it.
|
||||||
"""
|
"""
|
||||||
@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
if name in self.scheduled_events:
|
if name in self.scheduled_events:
|
||||||
del self.scheduled_events[name]
|
del self.scheduled_events[name]
|
||||||
|
|
||||||
def incoming_filter(self, xml):
|
def incoming_filter(self, xml: ET.Element) -> ET.Element:
|
||||||
"""Filter incoming XML objects before they are processed.
|
"""Filter incoming XML objects before they are processed.
|
||||||
|
|
||||||
Possible uses include remapping namespaces, or correcting elements
|
Possible uses include remapping namespaces, or correcting elements
|
||||||
@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
return xml
|
return xml
|
||||||
|
|
||||||
def _reset_sendq(self):
|
def _reset_sendq(self) -> None:
|
||||||
"""Clear sending tasks on session end"""
|
"""Clear sending tasks on session end"""
|
||||||
# Cancel all pending slow send tasks
|
# Cancel all pending slow send tasks
|
||||||
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
|
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
|
||||||
@ -1042,8 +1165,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
|
|
||||||
async def _continue_slow_send(
|
async def _continue_slow_send(
|
||||||
self,
|
self,
|
||||||
task: asyncio.Task,
|
task: asyncio.Task[Optional[StanzaBase]],
|
||||||
already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
|
already_used: Set[Filter]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Used when an item in the send queue has taken too long to process.
|
Used when an item in the send queue has taken too long to process.
|
||||||
@ -1062,12 +1185,14 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
if iscoroutinefunction(filter):
|
if iscoroutinefunction(filter):
|
||||||
data = await filter(data)
|
data = await filter(data)
|
||||||
else:
|
else:
|
||||||
|
filter = cast(SyncFilter, filter)
|
||||||
data = filter(data)
|
data = filter(data)
|
||||||
if data is None:
|
if data is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(data, ElementBase):
|
if isinstance(data, StanzaBase):
|
||||||
for filter in self.__filters['out_sync']:
|
for filter in self.__filters['out_sync']:
|
||||||
|
filter = cast(SyncFilter, filter)
|
||||||
data = filter(data)
|
data = filter(data)
|
||||||
if data is None:
|
if data is None:
|
||||||
return
|
return
|
||||||
@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
else:
|
else:
|
||||||
self.send_raw(data)
|
self.send_raw(data)
|
||||||
|
|
||||||
async def run_filters(self):
|
async def run_filters(self) -> NoReturn:
|
||||||
"""
|
"""
|
||||||
Background loop that processes stanzas to send.
|
Background loop that processes stanzas to send.
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
|
data: Optional[Union[StanzaBase, str]]
|
||||||
(data, use_filters) = await self.waiting_queue.get()
|
(data, use_filters) = await self.waiting_queue.get()
|
||||||
try:
|
try:
|
||||||
if isinstance(data, ElementBase):
|
if isinstance(data, StanzaBase):
|
||||||
if use_filters:
|
if use_filters:
|
||||||
already_run_filters = set()
|
already_run_filters = set()
|
||||||
for filter in self.__filters['out']:
|
for filter in self.__filters['out']:
|
||||||
already_run_filters.add(filter)
|
already_run_filters.add(filter)
|
||||||
if iscoroutinefunction(filter):
|
if iscoroutinefunction(filter):
|
||||||
|
filter = cast(AsyncFilter, filter)
|
||||||
task = asyncio.create_task(filter(data))
|
task = asyncio.create_task(filter(data))
|
||||||
completed, pending = await wait(
|
completed, pending = await wait(
|
||||||
{task},
|
{task},
|
||||||
@ -1108,19 +1235,24 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"Slow coroutine, rescheduling filters"
|
"Slow coroutine, rescheduling filters"
|
||||||
)
|
)
|
||||||
data = task.result()
|
data = task.result()
|
||||||
else:
|
elif isinstance(data, StanzaBase):
|
||||||
|
filter = cast(SyncFilter, filter)
|
||||||
data = filter(data)
|
data = filter(data)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise ContinueQueue('Empty stanza')
|
raise ContinueQueue('Empty stanza')
|
||||||
|
|
||||||
if isinstance(data, ElementBase):
|
if isinstance(data, StanzaBase):
|
||||||
if use_filters:
|
if use_filters:
|
||||||
for filter in self.__filters['out_sync']:
|
for filter in self.__filters['out_sync']:
|
||||||
|
filter = cast(SyncFilter, filter)
|
||||||
data = filter(data)
|
data = filter(data)
|
||||||
if data is None:
|
if data is None:
|
||||||
raise ContinueQueue('Empty stanza')
|
raise ContinueQueue('Empty stanza')
|
||||||
|
if isinstance(data, StanzaBase):
|
||||||
str_data = tostring(data.xml, xmlns=self.default_ns,
|
str_data = tostring(data.xml, xmlns=self.default_ns,
|
||||||
stream=self, top_level=True)
|
stream=self, top_level=True)
|
||||||
|
else:
|
||||||
|
str_data = data
|
||||||
self.send_raw(str_data)
|
self.send_raw(str_data)
|
||||||
else:
|
else:
|
||||||
self.send_raw(data)
|
self.send_raw(data)
|
||||||
@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
log.error('Exception raised in send queue:', exc_info=True)
|
log.error('Exception raised in send queue:', exc_info=True)
|
||||||
self.waiting_queue.task_done()
|
self.waiting_queue.task_done()
|
||||||
|
|
||||||
def send(self, data, use_filters=True):
|
def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None:
|
||||||
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
|
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
|
||||||
|
|
||||||
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
|
:param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
|
||||||
stanza to send on the stream.
|
stanza to send on the stream.
|
||||||
:param bool use_filters: Indicates if outgoing filters should be
|
:param bool use_filters: Indicates if outgoing filters should be
|
||||||
applied to the given stanza data. Disabling
|
applied to the given stanza data. Disabling
|
||||||
@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
return
|
return
|
||||||
self.waiting_queue.put_nowait((data, use_filters))
|
self.waiting_queue.put_nowait((data, use_filters))
|
||||||
|
|
||||||
def send_xml(self, data):
|
def send_xml(self, data: ET.Element) -> None:
|
||||||
"""Send an XML object on the stream
|
"""Send an XML object on the stream
|
||||||
|
|
||||||
:param data: The :class:`~xml.etree.ElementTree.Element` XML object
|
:param data: The :class:`~xml.etree.ElementTree.Element` XML object
|
||||||
to send on the stream.
|
to send on the stream.
|
||||||
"""
|
"""
|
||||||
return self.send(tostring(data))
|
self.send(tostring(data))
|
||||||
|
|
||||||
def send_raw(self, data):
|
def send_raw(self, data: Union[str, bytes]) -> None:
|
||||||
"""Send raw data across the stream.
|
"""Send raw data across the stream.
|
||||||
|
|
||||||
:param string data: Any bytes or utf-8 string value.
|
:param string data: Any bytes or utf-8 string value.
|
||||||
@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
data = data.encode('utf-8')
|
data = data.encode('utf-8')
|
||||||
self.transport.write(data)
|
self.transport.write(data)
|
||||||
|
|
||||||
def _build_stanza(self, xml, default_ns=None):
|
def _build_stanza(self, xml: ET.Element,
|
||||||
|
default_ns: Optional[str] = None) -> StanzaBase:
|
||||||
"""Create a stanza object from a given XML object.
|
"""Create a stanza object from a given XML object.
|
||||||
|
|
||||||
If a specialized stanza type is not found for the XML, then
|
If a specialized stanza type is not found for the XML, then
|
||||||
@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
stanza['lang'] = self.peer_default_lang
|
stanza['lang'] = self.peer_default_lang
|
||||||
return stanza
|
return stanza
|
||||||
|
|
||||||
def _spawn_event(self, xml):
|
def _spawn_event(self, xml: ET.Element) -> None:
|
||||||
"""
|
"""
|
||||||
Analyze incoming XML stanzas and convert them into stanza
|
Analyze incoming XML stanzas and convert them into stanza
|
||||||
objects if applicable and queue stream events to be processed
|
objects if applicable and queue stream events to be processed
|
||||||
@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
|
|
||||||
# Convert the raw XML object into a stanza object. If no registered
|
# Convert the raw XML object into a stanza object. If no registered
|
||||||
# stanza type applies, a generic StanzaBase stanza will be used.
|
# stanza type applies, a generic StanzaBase stanza will be used.
|
||||||
stanza = self._build_stanza(xml)
|
stanza: Optional[StanzaBase] = self._build_stanza(xml)
|
||||||
for filter in self.__filters['in']:
|
for filter in self.__filters['in']:
|
||||||
if stanza is not None:
|
if stanza is not None:
|
||||||
|
filter = cast(SyncFilter, filter)
|
||||||
stanza = filter(stanza)
|
stanza = filter(stanza)
|
||||||
if stanza is None:
|
if stanza is None:
|
||||||
return
|
return
|
||||||
@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
if not handled:
|
if not handled:
|
||||||
stanza.unhandled()
|
stanza.unhandled()
|
||||||
|
|
||||||
def exception(self, exception):
|
def exception(self, exception: Exception) -> None:
|
||||||
"""Process an unknown exception.
|
"""Process an unknown exception.
|
||||||
|
|
||||||
Meant to be overridden.
|
Meant to be overridden.
|
||||||
@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def wait_until(self, event: str, timeout=30) -> Any:
|
async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any:
|
||||||
"""Utility method to wake on the next firing of an event.
|
"""Utility method to wake on the next firing of an event.
|
||||||
(Registers a disposable handler on it)
|
(Registers a disposable handler on it)
|
||||||
|
|
||||||
@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
:param int timeout: Timeout
|
:param int timeout: Timeout
|
||||||
:raises: :class:`asyncio.TimeoutError` when the timeout is reached
|
:raises: :class:`asyncio.TimeoutError` when the timeout is reached
|
||||||
"""
|
"""
|
||||||
fut = asyncio.Future()
|
fut: Future[Any] = asyncio.Future()
|
||||||
|
|
||||||
def result_handler(event_data):
|
def result_handler(event_data: Any) -> None:
|
||||||
if not fut.done():
|
if not fut.done():
|
||||||
fut.set_result(event_data)
|
fut.set_result(event_data)
|
||||||
else:
|
else:
|
||||||
@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
return await asyncio.wait_for(fut, timeout)
|
return await asyncio.wait_for(fut, timeout)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def event_handler(self, event: str, handler: Callable):
|
def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]:
|
||||||
"""
|
"""
|
||||||
Context manager that adds then removes an event handler.
|
Context manager that adds then removes an event handler.
|
||||||
"""
|
"""
|
||||||
self.add_event_handler(event, handler)
|
self.add_event_handler(event, handler)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
except Exception as exc:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self.del_event_handler(event, handler)
|
self.del_event_handler(event, handler)
|
||||||
|
|
||||||
def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future:
|
def wrap(self, coroutine: Coroutine[None, None, T]) -> Future[T]:
|
||||||
"""Make a Future out of a coroutine with the current loop.
|
"""Make a Future out of a coroutine with the current loop.
|
||||||
|
|
||||||
:param coroutine: The coroutine to wrap.
|
:param coroutine: The coroutine to wrap.
|
||||||
|
Loading…
Reference in New Issue
Block a user