Merge branch 'more-typing' into 'master'

Add more typing

See merge request poezio/slixmpp!166
This commit is contained in:
mathieui 2021-07-15 10:01:03 +02:00
commit 22fa8bc4d9
63 changed files with 1160 additions and 783 deletions

View File

@ -1,7 +1,17 @@
stages:
- lint
- test
- trigger
mypy:
stage: lint
tags:
- docker
image: python:3
script:
- pip3 install mypy
- mypy slixmpp
test:
stage: test
tags:

15
mypy.ini Normal file
View File

@ -0,0 +1,15 @@
[mypy]
check_untyped_defs = False
ignore_missing_imports = True
[mypy-slixmpp.types]
ignore_errors = True
[mypy-slixmpp.thirdparty.*]
ignore_errors = True
[mypy-slixmpp.plugins.*]
ignore_errors = True
[mypy-slixmpp.plugins.base]
ignore_errors = False

View File

@ -83,6 +83,7 @@ setup(
url='https://lab.louiz.org/poezio/slixmpp',
license='MIT',
platforms=['any'],
package_data={'slixmpp': ['py.typed']},
packages=packages,
ext_modules=ext_modules,
install_requires=['aiodns>=1.0', 'pyasn1', 'pyasn1_modules', 'typing_extensions; python_version < "3.8.0"'],

View File

@ -19,7 +19,6 @@ from slixmpp.xmlstream.stanzabase import ET, ElementBase, register_stanza_plugin
from slixmpp.xmlstream.handler import *
from slixmpp.xmlstream import XMLStream
from slixmpp.xmlstream.matcher import *
from slixmpp.xmlstream.asyncio import asyncio, future_wrapper
from slixmpp.basexmpp import BaseXMPP
from slixmpp.clientxmpp import ClientXMPP
from slixmpp.componentxmpp import ComponentXMPP

View File

@ -21,7 +21,7 @@ class APIWrapper(object):
if name not in self.api.settings:
self.api.settings[name] = {}
def __getattr__(self, attr):
def __getattr__(self, attr: str):
"""Curry API management commands with the API name."""
if attr == 'name':
return self.name
@ -33,13 +33,13 @@ class APIWrapper(object):
return register(handler, self.name, op, jid, node, default)
return partial
elif attr == 'register_default':
def partial(handler, op, jid=None, node=None):
def partial1(handler, op, jid=None, node=None):
return getattr(self.api, attr)(handler, self.name, op)
return partial
return partial1
elif attr in ('run', 'restore_default', 'unregister'):
def partial(*args, **kwargs):
def partial2(*args, **kwargs):
return getattr(self.api, attr)(self.name, *args, **kwargs)
return partial
return partial2
return None
def __getitem__(self, attr):
@ -82,7 +82,7 @@ class APIRegistry(object):
"""Return a wrapper object that targets a specific API."""
return APIWrapper(self, ctype)
def purge(self, ctype: str):
def purge(self, ctype: str) -> None:
"""Remove all information for a given API."""
del self.settings[ctype]
del self._handler_defaults[ctype]
@ -131,22 +131,23 @@ class APIRegistry(object):
jid = JID(jid)
elif jid == JID(''):
jid = self.xmpp.boundjid
assert jid is not None
if node is None:
node = ''
if self.xmpp.is_component:
if self.settings[ctype].get('component_bare', False):
jid = jid.bare
jid_str = jid.bare
else:
jid = jid.full
jid_str = jid.full
else:
if self.settings[ctype].get('client_bare', False):
jid = jid.bare
jid_str = jid.bare
else:
jid = jid.full
jid_str = jid.full
jid = JID(jid)
jid = JID(jid_str)
handler = self._handlers[ctype][op]['node'].get((jid, node), None)
if handler is None:
@ -167,8 +168,11 @@ class APIRegistry(object):
# To preserve backward compatibility, drop the ifrom
# parameter for existing handlers that don't understand it.
return handler(jid, node, args)
future = Future()
future.set_result(None)
return future
def register(self, handler: APIHandler, ctype: str, op: str,
def register(self, handler: Optional[APIHandler], ctype: str, op: str,
jid: Optional[JID] = None, node: Optional[str] = None,
default: bool = False):
"""Register an API callback, with JID+node specificity.

View File

@ -45,10 +45,11 @@ log = logging.getLogger(__name__)
from slixmpp.types import (
PresenceShows,
PresenceTypes,
MessageTypes,
IqTypes,
JidStr,
OptJidStr,
)
if TYPE_CHECKING:
@ -263,9 +264,9 @@ class BaseXMPP(XMLStream):
if not pconfig:
pconfig = self.plugin_config.get(plugin, {})
if not self.plugin.registered(plugin):
if not self.plugin.registered(plugin): # type: ignore
load_plugin(plugin, module)
self.plugin.enable(plugin, pconfig)
self.plugin.enable(plugin, pconfig) # type: ignore
def register_plugins(self):
"""Register and initialize all built-in plugins.
@ -298,25 +299,25 @@ class BaseXMPP(XMLStream):
"""Return a plugin given its name, if it has been registered."""
return self.plugin.get(key, default)
def Message(self, *args, **kwargs) -> Message:
def Message(self, *args, **kwargs) -> stanza.Message:
"""Create a Message stanza associated with this stream."""
msg = Message(self, *args, **kwargs)
msg['lang'] = self.default_lang
return msg
def Iq(self, *args, **kwargs) -> Iq:
def Iq(self, *args, **kwargs) -> stanza.Iq:
"""Create an Iq stanza associated with this stream."""
return Iq(self, *args, **kwargs)
def Presence(self, *args, **kwargs) -> Presence:
def Presence(self, *args, **kwargs) -> stanza.Presence:
"""Create a Presence stanza associated with this stream."""
pres = Presence(self, *args, **kwargs)
pres['lang'] = self.default_lang
return pres
def make_iq(self, id: str = "0", ifrom: Optional[JID] = None,
ito: Optional[JID] = None, itype: Optional[IqTypes] = None,
iquery: Optional[str] = None) -> Iq:
def make_iq(self, id: str = "0", ifrom: OptJidStr = None,
ito: OptJidStr = None, itype: Optional[IqTypes] = None,
iquery: Optional[str] = None) -> stanza.Iq:
"""Create a new :class:`~.Iq` stanza with a given Id and from JID.
:param id: An ideally unique ID value for this stanza thread.
@ -339,8 +340,8 @@ class BaseXMPP(XMLStream):
return iq
def make_iq_get(self, queryxmlns: Optional[str] =None,
ito: Optional[JID] = None, ifrom: Optional[JID] = None,
iq: Optional[Iq] = None) -> Iq:
ito: OptJidStr = None, ifrom: OptJidStr = None,
iq: Optional[stanza.Iq] = None) -> stanza.Iq:
"""Create an :class:`~.Iq` stanza of type ``'get'``.
Optionally, a query element may be added.
@ -364,8 +365,8 @@ class BaseXMPP(XMLStream):
return iq
def make_iq_result(self, id: Optional[str] = None,
ito: Optional[JID] = None, ifrom: Optional[JID] = None,
iq: Optional[Iq] = None) -> Iq:
ito: OptJidStr = None, ifrom: OptJidStr = None,
iq: Optional[stanza.Iq] = None) -> stanza.Iq:
"""
Create an :class:`~.Iq` stanza of type
``'result'`` with the given ID value.
@ -391,8 +392,8 @@ class BaseXMPP(XMLStream):
return iq
def make_iq_set(self, sub: Optional[Union[ElementBase, ET.Element]] = None,
ito: Optional[JID] = None, ifrom: Optional[JID] = None,
iq: Optional[Iq] = None) -> Iq:
ito: OptJidStr = None, ifrom: OptJidStr = None,
iq: Optional[stanza.Iq] = None) -> stanza.Iq:
"""
Create an :class:`~.Iq` stanza of type ``'set'``.
@ -414,7 +415,7 @@ class BaseXMPP(XMLStream):
if not iq:
iq = self.Iq()
iq['type'] = 'set'
if sub != None:
if sub is not None:
iq.append(sub)
if ito:
iq['to'] = ito
@ -453,9 +454,9 @@ class BaseXMPP(XMLStream):
iq['from'] = ifrom
return iq
def make_iq_query(self, iq: Optional[Iq] = None, xmlns: str = '',
ito: Optional[JID] = None,
ifrom: Optional[JID] = None) -> Iq:
def make_iq_query(self, iq: Optional[stanza.Iq] = None, xmlns: str = '',
ito: OptJidStr = None,
ifrom: OptJidStr = None) -> stanza.Iq:
"""
Create or modify an :class:`~.Iq` stanza
to use the given query namespace.
@ -477,7 +478,7 @@ class BaseXMPP(XMLStream):
iq['from'] = ifrom
return iq
def make_query_roster(self, iq: Optional[Iq] = None) -> ET.Element:
def make_query_roster(self, iq: Optional[stanza.Iq] = None) -> ET.Element:
"""Create a roster query element.
:param iq: Optionally use an existing stanza instead
@ -487,11 +488,11 @@ class BaseXMPP(XMLStream):
iq['query'] = 'jabber:iq:roster'
return ET.Element("{jabber:iq:roster}query")
def make_message(self, mto: JID, mbody: Optional[str] = None,
def make_message(self, mto: JidStr, mbody: Optional[str] = None,
msubject: Optional[str] = None,
mtype: Optional[MessageTypes] = None,
mhtml: Optional[str] = None, mfrom: Optional[JID] = None,
mnick: Optional[str] = None) -> Message:
mhtml: Optional[str] = None, mfrom: OptJidStr = None,
mnick: Optional[str] = None) -> stanza.Message:
"""
Create and initialize a new
:class:`~.Message` stanza.
@ -516,13 +517,13 @@ class BaseXMPP(XMLStream):
message['html']['body'] = mhtml
return message
def make_presence(self, pshow: Optional[PresenceShows] = None,
def make_presence(self, pshow: Optional[str] = None,
pstatus: Optional[str] = None,
ppriority: Optional[int] = None,
pto: Optional[JID] = None,
pto: OptJidStr = None,
ptype: Optional[PresenceTypes] = None,
pfrom: Optional[JID] = None,
pnick: Optional[str] = None) -> Presence:
pfrom: OptJidStr = None,
pnick: Optional[str] = None) -> stanza.Presence:
"""
Create and initialize a new
:class:`~.Presence` stanza.
@ -548,7 +549,7 @@ class BaseXMPP(XMLStream):
def send_message(self, mto: JID, mbody: Optional[str] = None,
msubject: Optional[str] = None,
mtype: Optional[MessageTypes] = None,
mhtml: Optional[str] = None, mfrom: Optional[JID] = None,
mhtml: Optional[str] = None, mfrom: OptJidStr = None,
mnick: Optional[str] = None):
"""
Create, initialize, and send a new
@ -568,12 +569,12 @@ class BaseXMPP(XMLStream):
self.make_message(mto, mbody, msubject, mtype,
mhtml, mfrom, mnick).send()
def send_presence(self, pshow: Optional[PresenceShows] = None,
def send_presence(self, pshow: Optional[str] = None,
pstatus: Optional[str] = None,
ppriority: Optional[int] = None,
pto: Optional[JID] = None,
pto: OptJidStr = None,
ptype: Optional[PresenceTypes] = None,
pfrom: Optional[JID] = None,
pfrom: OptJidStr = None,
pnick: Optional[str] = None):
"""
Create, initialize, and send a new
@ -590,8 +591,9 @@ class BaseXMPP(XMLStream):
self.make_presence(pshow, pstatus, ppriority, pto,
ptype, pfrom, pnick).send()
def send_presence_subscription(self, pto, pfrom=None,
ptype='subscribe', pnick=None):
def send_presence_subscription(self, pto: JidStr, pfrom: OptJidStr = None,
ptype: PresenceTypes='subscribe', pnick:
Optional[str] = None):
"""
Create, initialize, and send a new
:class:`~.Presence` stanza of
@ -608,62 +610,62 @@ class BaseXMPP(XMLStream):
pnick=pnick).send()
@property
def jid(self):
def jid(self) -> str:
"""Attribute accessor for bare jid"""
log.warning("jid property deprecated. Use boundjid.bare")
return self.boundjid.bare
@jid.setter
def jid(self, value):
def jid(self, value: str):
log.warning("jid property deprecated. Use boundjid.bare")
self.boundjid.bare = value
@property
def fulljid(self):
def fulljid(self) -> str:
"""Attribute accessor for full jid"""
log.warning("fulljid property deprecated. Use boundjid.full")
return self.boundjid.full
@fulljid.setter
def fulljid(self, value):
def fulljid(self, value: str):
log.warning("fulljid property deprecated. Use boundjid.full")
self.boundjid.full = value
@property
def resource(self):
def resource(self) -> str:
"""Attribute accessor for jid resource"""
log.warning("resource property deprecated. Use boundjid.resource")
return self.boundjid.resource
@resource.setter
def resource(self, value):
def resource(self, value: str):
log.warning("fulljid property deprecated. Use boundjid.resource")
self.boundjid.resource = value
@property
def username(self):
def username(self) -> str:
"""Attribute accessor for jid usernode"""
log.warning("username property deprecated. Use boundjid.user")
return self.boundjid.user
@username.setter
def username(self, value):
def username(self, value: str):
log.warning("username property deprecated. Use boundjid.user")
self.boundjid.user = value
@property
def server(self):
def server(self) -> str:
"""Attribute accessor for jid host"""
log.warning("server property deprecated. Use boundjid.host")
return self.boundjid.server
@server.setter
def server(self, value):
def server(self, value: str):
log.warning("server property deprecated. Use boundjid.host")
self.boundjid.server = value
@property
def auto_authorize(self):
def auto_authorize(self) -> Optional[bool]:
"""Auto accept or deny subscription requests.
If ``True``, auto accept subscription requests.
@ -673,11 +675,11 @@ class BaseXMPP(XMLStream):
return self.roster.auto_authorize
@auto_authorize.setter
def auto_authorize(self, value):
def auto_authorize(self, value: Optional[bool]):
self.roster.auto_authorize = value
@property
def auto_subscribe(self):
def auto_subscribe(self) -> bool:
"""Auto send requests for mutual subscriptions.
If ``True``, auto send mutual subscription requests.
@ -685,21 +687,21 @@ class BaseXMPP(XMLStream):
return self.roster.auto_subscribe
@auto_subscribe.setter
def auto_subscribe(self, value):
def auto_subscribe(self, value: bool):
self.roster.auto_subscribe = value
def set_jid(self, jid):
def set_jid(self, jid: JidStr):
"""Rip a JID apart and claim it as our own."""
log.debug("setting jid to %s", jid)
self.boundjid = JID(jid)
def getjidresource(self, fulljid):
def getjidresource(self, fulljid: str):
if '/' in fulljid:
return fulljid.split('/', 1)[-1]
else:
return ''
def getjidbare(self, fulljid):
def getjidbare(self, fulljid: str):
return fulljid.split('/', 1)[0]
def _handle_session_start(self, event):

View File

@ -8,23 +8,18 @@
# :license: MIT, see LICENSE for more details
import asyncio
import logging
from typing import Optional, Any, Callable, Tuple, Dict, Set, List
from slixmpp.jid import JID
from slixmpp.stanza import StreamFeatures
from slixmpp.stanza import StreamFeatures, Iq
from slixmpp.basexmpp import BaseXMPP
from slixmpp.exceptions import XMPPError
from slixmpp.types import JidStr
from slixmpp.xmlstream import XMLStream
from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.matcher import StanzaPath, MatchXPath
from slixmpp.xmlstream.handler import Callback, CoroutineCallback
# Flag indicating if DNS SRV records are available for use.
try:
import dns.resolver
except ImportError:
DNSPYTHON = False
else:
DNSPYTHON = True
log = logging.getLogger(__name__)
@ -53,7 +48,7 @@ class ClientXMPP(BaseXMPP):
:param escape_quotes: **Deprecated.**
"""
def __init__(self, jid, password, plugin_config=None,
def __init__(self, jid: JidStr, password: str, plugin_config=None,
plugin_whitelist=None, escape_quotes=True, sasl_mech=None,
lang='en', **kwargs):
if not plugin_whitelist:
@ -69,7 +64,7 @@ class ClientXMPP(BaseXMPP):
self.default_port = 5222
self.default_lang = lang
self.credentials = {}
self.credentials: Dict[str, str] = {}
self.password = password
@ -81,9 +76,9 @@ class ClientXMPP(BaseXMPP):
"version='1.0'")
self.stream_footer = "</stream:stream>"
self.features = set()
self._stream_feature_handlers = {}
self._stream_feature_order = []
self.features: Set[str] = set()
self._stream_feature_handlers: Dict[str, Tuple[Callable, bool]] = {}
self._stream_feature_order: List[Tuple[int, str]] = []
self.dns_service = 'xmpp-client'
@ -100,10 +95,14 @@ class ClientXMPP(BaseXMPP):
self.register_stanza(StreamFeatures)
self.register_handler(
CoroutineCallback('Stream Features',
MatchXPath('{%s}features' % self.stream_ns),
self._handle_stream_features))
def roster_push_filter(iq):
CoroutineCallback(
'Stream Features',
MatchXPath('{%s}features' % self.stream_ns),
self._handle_stream_features, # type: ignore
)
)
def roster_push_filter(iq: StanzaBase) -> None:
from_ = iq['from']
if from_ and from_ != JID('') and from_ != self.boundjid.bare:
reply = iq.reply()
@ -131,15 +130,16 @@ class ClientXMPP(BaseXMPP):
self['feature_mechanisms'].use_mech = sasl_mech
@property
def password(self):
def password(self) -> str:
return self.credentials.get('password', '')
@password.setter
def password(self, value):
def password(self, value: str) -> None:
self.credentials['password'] = value
def connect(self, address=tuple(), use_ssl=False,
force_starttls=True, disable_starttls=False):
def connect(self, address: Optional[Tuple[str, int]] = None, # type: ignore
use_ssl: bool = False, force_starttls: bool = True,
disable_starttls: bool = False) -> None:
"""Connect to the XMPP server.
When no address is given, a SRV lookup for the server will
@ -161,14 +161,15 @@ class ClientXMPP(BaseXMPP):
# XMPP client port and allow SRV lookup.
if address:
self.dns_service = None
host, port = address
else:
address = (self.boundjid.host, 5222)
host, port = (self.boundjid.host, 5222)
self.dns_service = 'xmpp-client'
return XMLStream.connect(self, address[0], address[1], use_ssl=use_ssl,
return XMLStream.connect(self, host, port, use_ssl=use_ssl,
force_starttls=force_starttls, disable_starttls=disable_starttls)
def register_feature(self, name, handler, restart=False, order=5000):
def register_feature(self, name: str, handler: Callable, restart: bool = False, order: int = 5000) -> None:
"""Register a stream feature handler.
:param name: The name of the stream feature.
@ -183,13 +184,13 @@ class ClientXMPP(BaseXMPP):
self._stream_feature_order.append((order, name))
self._stream_feature_order.sort()
def unregister_feature(self, name, order):
def unregister_feature(self, name: str, order: int) -> None:
if name in self._stream_feature_handlers:
del self._stream_feature_handlers[name]
self._stream_feature_order.remove((order, name))
self._stream_feature_order.sort()
def update_roster(self, jid, **kwargs):
def update_roster(self, jid: JID, **kwargs) -> None:
"""Add or change a roster item.
:param jid: The JID of the entry to modify.
@ -251,7 +252,7 @@ class ClientXMPP(BaseXMPP):
return iq.send(callback, timeout, timeout_callback)
def _reset_connection_state(self, event=None):
def _reset_connection_state(self, event: Optional[Any] = None) -> None:
#TODO: Use stream state here
self.authenticated = False
self.sessionstarted = False
@ -259,7 +260,7 @@ class ClientXMPP(BaseXMPP):
self.bindfail = False
self.features = set()
async def _handle_stream_features(self, features):
async def _handle_stream_features(self, features: StreamFeatures) -> Optional[bool]:
"""Process the received stream features.
:param features: The features stanza.
@ -277,8 +278,9 @@ class ClientXMPP(BaseXMPP):
return True
log.debug('Finished processing stream features.')
self.event('stream_negotiated')
return None
def _handle_roster(self, iq):
def _handle_roster(self, iq: Iq) -> None:
"""Update the roster after receiving a roster stanza.
:param iq: The roster stanza.
@ -310,7 +312,7 @@ class ClientXMPP(BaseXMPP):
resp.enable('roster')
resp.send()
def _handle_session_bind(self, jid):
def _handle_session_bind(self, jid: JID) -> None:
"""Set the client roster to the JID set by the server.
:param :class:`slixmpp.xmlstream.jid.JID` jid: The bound JID as

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2011 Nathanael C. Fritz
# This file is part of Slixmpp.
@ -11,6 +10,7 @@ from slixmpp.stanza import Iq, StreamFeatures
from slixmpp.features.feature_bind import stanza
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins import BasePlugin
from typing import ClassVar, Set
log = logging.getLogger(__name__)
@ -20,7 +20,7 @@ class FeatureBind(BasePlugin):
name = 'feature_bind'
description = 'RFC 6120: Stream Feature: Resource Binding'
dependencies = set()
dependencies: ClassVar[Set[str]] = set()
stanza = stanza
def plugin_init(self):

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2011 Nathanael C. Fritz
# This file is part of Slixmpp.
@ -15,6 +14,8 @@ from slixmpp.xmlstream.matcher import MatchXPath
from slixmpp.xmlstream.handler import Callback
from slixmpp.features.feature_mechanisms import stanza
from typing import ClassVar, Set
log = logging.getLogger(__name__)
@ -23,7 +24,7 @@ class FeatureMechanisms(BasePlugin):
name = 'feature_mechanisms'
description = 'RFC 6120: Stream Feature: SASL'
dependencies = set()
dependencies: ClassVar[Set[str]] = set()
stanza = stanza
default_config = {
'use_mech': None,

View File

@ -1,9 +1,9 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2011 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream import StanzaBase
from typing import ClassVar, Set
class Abort(StanzaBase):
@ -13,7 +13,7 @@ class Abort(StanzaBase):
name = 'abort'
namespace = 'urn:ietf:params:xml:ns:xmpp-sasl'
interfaces = set()
interfaces: ClassVar[Set[str]] = set()
plugin_attrib = name
def setup(self, xml):

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2012 Nathanael C. Fritz
# This file is part of Slixmpp.
@ -9,6 +8,7 @@ from slixmpp.stanza import StreamFeatures
from slixmpp.features.feature_preapproval import stanza
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins.base import BasePlugin
from typing import ClassVar, Set
log = logging.getLogger(__name__)
@ -18,7 +18,7 @@ class FeaturePreApproval(BasePlugin):
name = 'feature_preapproval'
description = 'RFC 6121: Stream Feature: Subscription Pre-Approval'
dependences = set()
dependencies: ClassVar[Set[str]] = set()
stanza = stanza
def plugin_init(self):

View File

@ -1,14 +1,14 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2012 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream import ElementBase
from typing import ClassVar, Set
class PreApproval(ElementBase):
name = 'sub'
namespace = 'urn:xmpp:features:pre-approval'
interfaces = set()
interfaces: ClassVar[Set[str]] = set()
plugin_attrib = 'preapproval'

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2012 Nathanael C. Fritz
# This file is part of Slixmpp.
@ -9,6 +8,7 @@ from slixmpp.stanza import StreamFeatures
from slixmpp.features.feature_rosterver import stanza
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins.base import BasePlugin
from typing import ClassVar, Set
log = logging.getLogger(__name__)
@ -18,7 +18,7 @@ class FeatureRosterVer(BasePlugin):
name = 'feature_rosterver'
description = 'RFC 6121: Stream Feature: Roster Versioning'
dependences = set()
dependences: ClassVar[Set[str]] = set()
stanza = stanza
def plugin_init(self):

View File

@ -1,14 +1,14 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2012 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream import ElementBase
from typing import Set, ClassVar
class RosterVer(ElementBase):
name = 'ver'
namespace = 'urn:xmpp:features:rosterver'
interfaces = set()
interfaces: ClassVar[Set[str]] = set()
plugin_attrib = 'rosterver'

View File

@ -11,6 +11,7 @@ from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins import BasePlugin
from slixmpp.features.feature_session import stanza
from typing import ClassVar, Set
log = logging.getLogger(__name__)
@ -20,7 +21,7 @@ class FeatureSession(BasePlugin):
name = 'feature_session'
description = 'RFC 3920: Stream Feature: Start Session'
dependencies = set()
dependencies: ClassVar[Set[str]] = set()
stanza = stanza
def plugin_init(self):

View File

@ -4,39 +4,47 @@
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream import StanzaBase, ElementBase
from typing import Set, ClassVar
class STARTTLS(ElementBase):
"""
class STARTTLS(StanzaBase):
"""
.. code-block:: xml
<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>
"""
name = 'starttls'
namespace = 'urn:ietf:params:xml:ns:xmpp-tls'
interfaces = {'required'}
plugin_attrib = name
def get_required(self):
"""
"""
return True
class Proceed(StanzaBase):
"""
"""
.. code-block:: xml
<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>
"""
name = 'proceed'
namespace = 'urn:ietf:params:xml:ns:xmpp-tls'
interfaces = set()
interfaces: ClassVar[Set[str]] = set()
class Failure(StanzaBase):
"""
"""
.. code-block:: xml
<failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>
"""
name = 'failure'
namespace = 'urn:ietf:params:xml:ns:xmpp-tls'
interfaces = set()
interfaces: ClassVar[Set[str]] = set()

View File

@ -12,6 +12,8 @@ from slixmpp.xmlstream.matcher import MatchXPath
from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.features.feature_starttls import stanza
from typing import ClassVar, Set
log = logging.getLogger(__name__)
@ -20,7 +22,7 @@ class FeatureSTARTTLS(BasePlugin):
name = 'feature_starttls'
description = 'RFC 6120: Stream Feature: STARTTLS'
dependencies = set()
dependencies: ClassVar[Set[str]] = set()
stanza = stanza
def plugin_init(self):
@ -52,7 +54,7 @@ class FeatureSTARTTLS(BasePlugin):
elif self.xmpp.disable_starttls:
return False
else:
self.xmpp.send(features['starttls'])
self.xmpp.send(stanza.STARTTLS())
return True
async def _handle_starttls_proceed(self, proceed):

View File

@ -350,36 +350,10 @@ class JID:
if self._resource
else self._bare)
@property
def node(self) -> str:
return self._node
@property
def domain(self) -> str:
return self._domain
@property
def resource(self) -> str:
return self._resource
@property
def bare(self) -> str:
return self._bare
@property
def full(self) -> str:
return self._full
@node.setter
def node(self, value: str):
self._node = _validate_node(value)
self._update_bare_full()
@domain.setter
def domain(self, value: str):
self._domain = _validate_domain(value)
self._update_bare_full()
@bare.setter
def bare(self, value: str):
node, domain, resource = _parse_jid(value)
@ -388,11 +362,38 @@ class JID:
self._domain = domain
self._update_bare_full()
@property
def node(self) -> str:
return self._node
@node.setter
def node(self, value: str):
self._node = _validate_node(value)
self._update_bare_full()
@property
def domain(self) -> str:
return self._domain
@domain.setter
def domain(self, value: str):
self._domain = _validate_domain(value)
self._update_bare_full()
@property
def resource(self) -> str:
return self._resource
@resource.setter
def resource(self, value: str):
self._resource = _validate_resource(value)
self._update_bare_full()
@property
def full(self) -> str:
return self._full
@full.setter
def full(self, value: str):
self._node, self._domain, self._resource = _parse_jid(value)

View File

@ -12,6 +12,8 @@ import copy
import logging
import threading
from typing import Any, Dict, Set, ClassVar
log = logging.getLogger(__name__)
@ -250,17 +252,17 @@ class BasePlugin(object):
#: A short name for the plugin based on the implemented specification.
#: For example, a plugin for XEP-0030 would use `'xep_0030'`.
name = ''
name: str = ''
#: A longer name for the plugin, describing its purpose. For example,
#: a plugin for XEP-0030 would use `'Service Discovery'` as its
#: description value.
description = ''
description: str = ''
#: Some plugins may depend on others in order to function properly.
#: Any plugin names included in :attr:`~BasePlugin.dependencies` will
#: be initialized as needed if this plugin is enabled.
dependencies = set()
dependencies: ClassVar[Set[str]] = set()
#: The basic, standard configuration for the plugin, which may
#: be overridden when initializing the plugin. The configuration
@ -268,7 +270,7 @@ class BasePlugin(object):
#: the plugin. For example, including the configuration field 'foo'
#: would mean accessing `plugin.foo` returns the current value of
#: `plugin.config['foo']`.
default_config = {}
default_config: ClassVar[Dict[str, Any]] = {}
def __init__(self, xmpp, config=None):
self.xmpp = xmpp

View File

@ -11,11 +11,11 @@ from typing import (
Optional
)
from slixmpp.plugins import BasePlugin, register_plugin
from slixmpp import future_wrapper, JID
from slixmpp.plugins import BasePlugin
from slixmpp import JID
from slixmpp.stanza import Iq
from slixmpp.exceptions import XMPPError
from slixmpp.xmlstream import JID, register_stanza_plugin
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.xmlstream.matcher import StanzaPath
from slixmpp.plugins.xep_0012 import stanza, LastActivity

View File

@ -4,7 +4,6 @@
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import logging
from asyncio import Future
from typing import Optional
from slixmpp import JID
@ -15,7 +14,6 @@ from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.xmlstream.matcher import StanzaPath
from slixmpp.plugins import BasePlugin
from slixmpp.plugins.xep_0054 import VCardTemp, stanza
from slixmpp import future_wrapper
log = logging.getLogger(__name__)

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2015 Emmanuel Gil Peyrot
# This file is part of Slixmpp.
@ -7,11 +6,10 @@ import asyncio
import logging
from uuid import uuid4
from slixmpp.plugins import BasePlugin, register_plugin
from slixmpp import future_wrapper, Iq, Message
from slixmpp.exceptions import XMPPError, IqError, IqTimeout
from slixmpp.plugins import BasePlugin
from slixmpp import Iq, Message
from slixmpp.jid import JID
from slixmpp.xmlstream import JID, register_stanza_plugin
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.xmlstream.handler import Callback
from slixmpp.xmlstream.matcher import StanzaPath
from slixmpp.plugins.xep_0070 import stanza, Confirm
@ -52,7 +50,6 @@ class XEP_0070(BasePlugin):
def session_bind(self, jid):
self.xmpp['xep_0030'].add_feature('http://jabber.org/protocol/http-auth')
@future_wrapper
def ask_confirm(self, jid, id, url, method, *, ifrom=None, message=None):
jid = JID(jid)
if jid.resource:
@ -70,7 +67,9 @@ class XEP_0070(BasePlugin):
if message is not None:
stanza['body'] = message.format(id=id, url=url, method=method)
stanza.send()
return stanza
fut = asyncio.Future()
fut.set_result(stanza)
return fut
else:
return stanza.send()

View File

@ -17,7 +17,6 @@ from slixmpp.exceptions import XMPPError, IqTimeout, IqError
from slixmpp.xmlstream import register_stanza_plugin, ElementBase
from slixmpp.plugins.base import BasePlugin
from slixmpp.plugins.xep_0153 import stanza, VCardTempUpdate
from slixmpp import future_wrapper
log = logging.getLogger(__name__)

View File

@ -3,10 +3,11 @@
# Copyright (C) 2011 Nathanael C. Fritz, Lance J.T. Stout
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import asyncio
import logging
from typing import Optional, Callable
from slixmpp import asyncio, JID
from slixmpp import JID
from slixmpp.xmlstream import register_stanza_plugin, ElementBase
from slixmpp.plugins.base import BasePlugin, register_plugin
from slixmpp.plugins.xep_0004.stanza import Form

View File

@ -3,6 +3,7 @@
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import asyncio
import time
import logging
@ -11,7 +12,6 @@ from typing import Optional, Callable, List
from slixmpp.jid import JID
from slixmpp.stanza import Iq
from slixmpp import asyncio
from slixmpp.exceptions import IqError, IqTimeout
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.xmlstream.matcher import StanzaPath

View File

@ -9,14 +9,13 @@ import hashlib
from asyncio import Future
from typing import Optional
from slixmpp import future_wrapper, JID
from slixmpp import JID
from slixmpp.stanza import Iq, Message, Presence
from slixmpp.exceptions import XMPPError
from slixmpp.xmlstream.handler import CoroutineCallback
from slixmpp.xmlstream.matcher import StanzaPath
from slixmpp.xmlstream import register_stanza_plugin
from slixmpp.plugins.base import BasePlugin
from slixmpp.plugins.xep_0231 import stanza, BitsOfBinary
from slixmpp.plugins.xep_0231 import BitsOfBinary
log = logging.getLogger(__name__)

View File

@ -5,10 +5,10 @@
# Copyright (C) 2013 Sustainable Innovation, Joachim.lindborg@sust.se, bjorn.westrom@consoden.se
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import asyncio
import logging
import time
from slixmpp import asyncio
from functools import partial
from slixmpp.xmlstream import JID
from slixmpp.xmlstream.handler import Callback

View File

@ -20,6 +20,7 @@ from slixmpp.plugins.xep_0030 import XEP_0030
from slixmpp.plugins.xep_0033 import XEP_0033
from slixmpp.plugins.xep_0045 import XEP_0045
from slixmpp.plugins.xep_0047 import XEP_0047
from slixmpp.plugins.xep_0048 import XEP_0048
from slixmpp.plugins.xep_0049 import XEP_0049
from slixmpp.plugins.xep_0050 import XEP_0050
from slixmpp.plugins.xep_0054 import XEP_0054
@ -112,6 +113,7 @@ class PluginsDict(TypedDict):
xep_0033: XEP_0033
xep_0045: XEP_0045
xep_0047: XEP_0047
xep_0048: XEP_0048
xep_0049: XEP_0049
xep_0050: XEP_0050
xep_0054: XEP_0054

0
slixmpp/py.typed Normal file
View File

View File

@ -1,8 +1,9 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from __future__ import annotations
from typing import Optional, Dict, Type, ClassVar
from slixmpp.xmlstream import ElementBase, ET
@ -49,10 +50,10 @@ class Error(ElementBase):
name = 'error'
plugin_attrib = 'error'
interfaces = {'code', 'condition', 'text', 'type',
'gone', 'redirect', 'by'}
'gone', 'redirect', 'by'}
sub_interfaces = {'text'}
plugin_attrib_map = {}
plugin_tag_map = {}
plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
conditions = {'bad-request', 'conflict', 'feature-not-implemented',
'forbidden', 'gone', 'internal-server-error',
'item-not-found', 'jid-malformed', 'not-acceptable',
@ -62,10 +63,10 @@ class Error(ElementBase):
'remote-server-timeout', 'resource-constraint',
'service-unavailable', 'subscription-required',
'undefined-condition', 'unexpected-request'}
condition_ns = 'urn:ietf:params:xml:ns:xmpp-stanzas'
condition_ns: str = 'urn:ietf:params:xml:ns:xmpp-stanzas'
types = {'cancel', 'continue', 'modify', 'auth', 'wait'}
def setup(self, xml=None):
def setup(self, xml: Optional[ET.Element] = None):
"""
Populate the stanza object using an optional XML object.
@ -82,9 +83,11 @@ class Error(ElementBase):
self['type'] = 'cancel'
self['condition'] = 'feature-not-implemented'
if self.parent is not None:
self.parent()['type'] = 'error'
parent = self.parent()
if parent:
parent['type'] = 'error'
def get_condition(self):
def get_condition(self) -> str:
"""Return the condition element's name."""
for child in self.xml:
if "{%s}" % self.condition_ns in child.tag:
@ -93,7 +96,7 @@ class Error(ElementBase):
return cond
return ''
def set_condition(self, value):
def set_condition(self, value: str) -> Error:
"""
Set the tag name of the condition element.
@ -105,7 +108,7 @@ class Error(ElementBase):
self.xml.append(ET.Element("{%s}%s" % (self.condition_ns, value)))
return self
def del_condition(self):
def del_condition(self) -> Error:
"""Remove the condition element."""
for child in self.xml:
if "{%s}" % self.condition_ns in child.tag:
@ -139,14 +142,14 @@ class Error(ElementBase):
def get_redirect(self):
return self._get_sub_text('{%s}redirect' % self.condition_ns, '')
def set_gone(self, value):
def set_gone(self, value: str):
if value:
del self['condition']
return self._set_sub_text('{%s}gone' % self.condition_ns, value)
elif self['condition'] == 'gone':
del self['condition']
def set_redirect(self, value):
def set_redirect(self, value: str):
if value:
del self['condition']
ns = self.condition_ns

View File

@ -4,6 +4,7 @@
# See the file LICENSE for copying permission.
from slixmpp.xmlstream import StanzaBase
from typing import Optional
class Handshake(StanzaBase):
@ -18,7 +19,7 @@ class Handshake(StanzaBase):
def set_value(self, value: str):
self.xml.text = value
def get_value(self) -> str:
def get_value(self) -> Optional[str]:
return self.xml.text
def del_value(self):

View File

@ -3,10 +3,10 @@
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
import asyncio
from slixmpp.stanza.rootstanza import RootStanza
from slixmpp.xmlstream import StanzaBase, ET
from slixmpp.xmlstream.handler import Waiter, Callback, CoroutineCallback
from slixmpp.xmlstream.asyncio import asyncio
from slixmpp.xmlstream.handler import Callback, CoroutineCallback
from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId
from slixmpp.exceptions import IqTimeout, IqError

View File

@ -61,8 +61,10 @@ class Message(RootStanza):
"""
StanzaBase.__init__(self, *args, **kwargs)
if not recv and self['id'] == '':
if self.stream is not None and self.stream.use_message_ids:
self['id'] = self.stream.new_id()
if self.stream:
use_ids = getattr(self.stream, 'use_message_ids', None)
if use_ids:
self['id'] = self.stream.new_id()
else:
del self['origin_id']
@ -93,8 +95,10 @@ class Message(RootStanza):
self.xml.attrib['id'] = value
if self.stream and not self.stream.use_origin_id:
return None
if self.stream:
use_orig_ids = getattr(self.stream, 'use_origin_id', None)
if not use_orig_ids:
return None
sub = self.xml.find(ORIGIN_NAME)
if sub is not None:

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
@ -61,7 +60,7 @@ class Presence(RootStanza):
'subscribed', 'unsubscribe', 'unsubscribed'}
showtypes = {'dnd', 'chat', 'xa', 'away'}
def __init__(self, *args, recv=False, **kwargs):
def __init__(self, *args, recv: bool = False, **kwargs):
"""
Initialize a new <presence /> stanza with an optional 'id' value.
@ -69,10 +68,12 @@ class Presence(RootStanza):
"""
StanzaBase.__init__(self, *args, **kwargs)
if not recv and self['id'] == '':
if self.stream is not None and self.stream.use_presence_ids:
self['id'] = self.stream.new_id()
if self.stream:
use_ids = getattr(self.stream, 'use_presence_ids', None)
if use_ids:
self['id'] = self.stream.new_id()
def set_show(self, show):
def set_show(self, show: str):
"""
Set the value of the <show> element.
@ -84,7 +85,7 @@ class Presence(RootStanza):
self._set_sub_text('show', text=show)
return self
def get_type(self):
def get_type(self) -> str:
"""
Return the value of the <presence> stanza's type attribute, or
the value of the <show> element if valid.
@ -96,7 +97,7 @@ class Presence(RootStanza):
out = 'available'
return out
def set_type(self, value):
def set_type(self, value: str):
"""
Set the type attribute's value, and the <show> element
if applicable.
@ -119,7 +120,7 @@ class Presence(RootStanza):
self._del_attr('type')
self._del_sub('show')
def set_priority(self, value):
def set_priority(self, value: int):
"""
Set the entity's priority value. Some server use priority to
determine message routing behavior.

View File

@ -4,7 +4,8 @@
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.stanza.error import Error
from slixmpp.xmlstream import StanzaBase
from slixmpp.xmlstream import StanzaBase, ET
from typing import Optional, Dict, Union
class StreamError(Error, StanzaBase):
@ -62,19 +63,20 @@ class StreamError(Error, StanzaBase):
'system-shutdown', 'undefined-condition', 'unsupported-encoding',
'unsupported-feature', 'unsupported-stanza-type',
'unsupported-version'}
condition_ns = 'urn:ietf:params:xml:ns:xmpp-streams'
condition_ns: str = 'urn:ietf:params:xml:ns:xmpp-streams'
def get_see_other_host(self):
def get_see_other_host(self) -> Union[str, Dict[str, str]]:
ns = self.condition_ns
return self._get_sub_text('{%s}see-other-host' % ns, '')
def set_see_other_host(self, value):
def set_see_other_host(self, value: str) -> Optional[ET.Element]:
if value:
del self['condition']
ns = self.condition_ns
return self._set_sub_text('{%s}see-other-host' % ns, value)
elif self['condition'] == 'see-other-host':
del self['condition']
return None
def del_see_other_host(self):
def del_see_other_host(self) -> None:
self._del_sub('{%s}see-other-host' % self.condition_ns)

View File

@ -3,7 +3,8 @@
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream import StanzaBase
from slixmpp.xmlstream import StanzaBase, ElementBase
from typing import ClassVar, Dict, Type
class StreamFeatures(StanzaBase):
@ -15,8 +16,8 @@ class StreamFeatures(StanzaBase):
namespace = 'http://etherx.jabber.org/streams'
interfaces = {'features', 'required', 'optional'}
sub_interfaces = interfaces
plugin_tag_map = {}
plugin_attrib_map = {}
plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
def setup(self, xml):
StanzaBase.setup(self, xml)

View File

@ -11,7 +11,7 @@ except ImportError:
# Python < 3.8
# just to make sure the imports do not break, but
# not usable.
from unittest import TestCase as IsolatedAsyncioTestCase
from unittest import TestCase as IsolatedAsyncioTestCase # type: ignore
from typing import (
Dict,
List,

View File

@ -17,9 +17,7 @@ from slixmpp.xmlstream.matcher import StanzaPath, MatcherId, MatchIDSender
from slixmpp.xmlstream.matcher import MatchXMLMask, MatchXPath
import asyncio
cls = asyncio.get_event_loop().__class__
cls.idle_call = lambda self, callback: callback()
class SlixTest(unittest.TestCase):

View File

@ -16,11 +16,13 @@ try:
from typing import (
Literal,
TypedDict,
Protocol,
)
except ImportError:
from typing_extensions import (
Literal,
TypedDict,
Protocol,
)
from slixmpp.jid import JID
@ -78,3 +80,11 @@ JidStr = Union[str, JID]
OptJidStr = Optional[Union[str, JID]]
MAMDefault = Literal['always', 'never', 'roster']
FilterString = Literal['in', 'out', 'out_sync']
__all__ = [
'Protocol', 'TypedDict', 'Literal', 'OptJid', 'JidStr', 'MAMDefault',
'PresenceTypes', 'PresenceShows', 'MessageTypes', 'IqTypes', 'MucRole',
'MucAffiliation', 'FilterString',
]

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2018 Emmanuel Gil Peyrot
# This file is part of Slixmpp.
@ -6,8 +5,11 @@
import os
import logging
from typing import Callable, Optional, Any
log = logging.getLogger(__name__)
class Cache:
def retrieve(self, key):
raise NotImplementedError
@ -16,7 +18,8 @@ class Cache:
raise NotImplementedError
def remove(self, key):
raise NotImplemented
raise NotImplementedError
class PerJidCache:
def retrieve_by_jid(self, jid, key):
@ -28,6 +31,7 @@ class PerJidCache:
def remove_by_jid(self, jid, key):
raise NotImplementedError
class MemoryCache(Cache):
def __init__(self):
self.cache = {}
@ -44,6 +48,7 @@ class MemoryCache(Cache):
del self.cache[key]
return True
class MemoryPerJidCache(PerJidCache):
def __init__(self):
self.cache = {}
@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache):
del cache[key]
return True
class FileSystemStorage:
def __init__(self, encode, decode, binary):
def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool):
self.encode = encode if encode is not None else lambda x: x
self.decode = decode if decode is not None else lambda x: x
self.read = 'rb' if binary else 'r'
self.write = 'wb' if binary else 'w'
def _retrieve(self, directory, key):
def _retrieve(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_'))
try:
with open(filename, self.read) as cache_file:
@ -86,7 +92,7 @@ class FileSystemStorage:
log.debug('Removing %s entry', key)
self._remove(directory, key)
def _store(self, directory, key, value):
def _store(self, directory: str, key: str, value):
filename = os.path.join(directory, key.replace('/', '_'))
try:
os.makedirs(directory, exist_ok=True)
@ -99,7 +105,7 @@ class FileSystemStorage:
except Exception:
log.debug('Failed to encode %s to cache:', key, exc_info=True)
def _remove(self, directory, key):
def _remove(self, directory: str, key: str):
filename = os.path.join(directory, key.replace('/', '_'))
try:
os.remove(filename)
@ -108,8 +114,9 @@ class FileSystemStorage:
return False
return True
class FileSystemCache(Cache, FileSystemStorage):
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type)
@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage):
def remove(self, key):
return self._remove(self.base_dir, key)
class FileSystemPerJidCache(PerJidCache, FileSystemStorage):
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
FileSystemStorage.__init__(self, encode, decode, binary)
self.base_dir = os.path.join(directory, cache_type)

View File

@ -2,15 +2,19 @@ import builtins
import sys
import hashlib
from typing import Optional, Union, Callable, List
def unicode(text):
bytes_ = builtins.bytes # alias the stdlib type but ew
def unicode(text: Union[bytes_, str]) -> str:
if not isinstance(text, str):
return text.decode('utf-8')
else:
return text
def bytes(text):
def bytes(text: Optional[Union[str, bytes_]]) -> bytes_:
"""
Convert Unicode text to UTF-8 encoded bytes.
@ -34,7 +38,7 @@ def bytes(text):
return builtins.bytes(text, encoding='utf-8')
def quote(text):
def quote(text: Union[str, bytes_]) -> bytes_:
"""
Enclose in quotes and escape internal slashes and double quotes.
@ -44,7 +48,7 @@ def quote(text):
return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
def num_to_bytes(num):
def num_to_bytes(num: int) -> bytes_:
"""
Convert an integer into a four byte sequence.
@ -58,21 +62,21 @@ def num_to_bytes(num):
return bval
def bytes_to_num(bval):
def bytes_to_num(bval: bytes_) -> int:
"""
Convert a four byte sequence to an integer.
:param bytes bval: A four byte sequence to turn into an integer.
"""
num = 0
num += ord(bval[0] << 24)
num += ord(bval[1] << 16)
num += ord(bval[2] << 8)
num += ord(bval[3])
num += (bval[0] << 24)
num += (bval[1] << 16)
num += (bval[2] << 8)
num += (bval[3])
return num
def XOR(x, y):
def XOR(x: bytes_, y: bytes_) -> bytes_:
"""
Return the results of an XOR operation on two equal length byte strings.
@ -85,7 +89,7 @@ def XOR(x, y):
return builtins.bytes([a ^ b for a, b in zip(x, y)])
def hash(name):
def hash(name: str) -> Optional[Callable]:
"""
Return a hash function implementing the given algorithm.
@ -102,7 +106,7 @@ def hash(name):
return None
def hashes():
def hashes() -> List[str]:
"""
Return a list of available hashing algorithms.
@ -115,28 +119,3 @@ def hashes():
t += ['MD2']
hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')]
return t + hashes
def setdefaultencoding(encoding):
"""
Set the current default string encoding used by the Unicode implementation.
Actually calls sys.setdefaultencoding under the hood - see the docs for that
for more details. This method exists only as a way to call find/call it
even after it has been 'deleted' when the site module is executed.
:param string encoding: An encoding name, compatible with sys.setdefaultencoding
"""
func = getattr(sys, 'setdefaultencoding', None)
if func is None:
import gc
import types
for obj in gc.get_objects():
if (isinstance(obj, types.BuiltinFunctionType)
and obj.__name__ == 'setdefaultencoding'):
func = obj
break
if func is None:
raise RuntimeError("Could not find setdefaultencoding")
sys.setdefaultencoding = func
return func(encoding)

View File

@ -1,4 +1,3 @@
# slixmpp.util.sasl.client
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# This module was originally based on Dave Cridland's Suelta library.
@ -6,9 +5,11 @@
# :copryight: (c) 2004-2013 David Alan Cridland
# :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout
# :license: MIT, see LICENSE for more details
from __future__ import annotations
import logging
import stringprep
from typing import Iterable, Set, Callable, Dict, Any, Optional, Type
from slixmpp.util import hashes, bytes, stringprep_profiles
@ -16,11 +17,11 @@ log = logging.getLogger(__name__)
#: Global registry mapping mechanism names to implementation classes.
MECHANISMS = {}
MECHANISMS: Dict[str, Type[Mech]] = {}
#: Global registry mapping mechanism names to security scores.
MECH_SEC_SCORES = {}
MECH_SEC_SCORES: Dict[str, int] = {}
#: The SASLprep profile of stringprep used to validate simple username
@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create(
unassigned=[stringprep.in_table_a1])
def sasl_mech(score):
def sasl_mech(score: int):
sec_score = score
def register(mech):
def register(mech: Type[Mech]):
n = 0
mech.score = sec_score
if mech.use_hashes:
@ -99,9 +101,9 @@ class Mech(object):
score = -1
use_hashes = False
channel_binding = False
required_credentials = set()
optional_credentials = set()
security = set()
required_credentials: Set[str] = set()
optional_credentials: Set[str] = set()
security: Set[str] = set()
def __init__(self, name, credentials, security_settings):
self.credentials = credentials
@ -118,7 +120,14 @@ class Mech(object):
return b''
def choose(mech_list, credentials, security_settings, limit=None, min_mech=None):
CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]]
SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]]
def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback,
security_settings: SecurityCallback,
limit: Optional[Iterable[Type[Mech]]] = None,
min_mech: Optional[str] = None) -> Mech:
available_mechs = set(MECHANISMS.keys())
if limit is None:
limit = set(mech_list)
@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None)
mech_list = mech_list.intersection(limit)
available_mechs = available_mechs.intersection(mech_list)
best_score = MECH_SEC_SCORES.get(min_mech, -1)
if min_mech is None:
best_score = -1
else:
best_score = MECH_SEC_SCORES.get(min_mech, -1)
best_mech = None
for name in available_mechs:
if name in MECH_SEC_SCORES:

View File

@ -11,6 +11,9 @@ import hmac
import random
from base64 import b64encode, b64decode
from typing import List, Dict, Optional
bytes_ = bytes
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
from slixmpp.util.sasl.client import sasl_mech, Mech, \
@ -63,7 +66,7 @@ class PLAIN(Mech):
if not self.security_settings['encrypted_plain']:
raise SASLCancelled('PLAIN with encryption')
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b'') -> bytes_:
authzid = self.credentials['authzid']
authcid = self.credentials['username']
password = self.credentials['password']
@ -148,7 +151,7 @@ class CRAM(Mech):
required_credentials = {'username', 'password'}
security = {'encrypted', 'unencrypted_cram'}
def setup(self, name):
def setup(self, name: str):
self.hash_name = name[5:]
self.hash = hash(self.hash_name)
if self.hash is None:
@ -157,14 +160,14 @@ class CRAM(Mech):
if not self.security_settings['unencrypted_cram']:
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
return None
username = self.credentials['username']
password = self.credentials['password']
mac = hmac.HMAC(key=password, digestmod=self.hash)
mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore
mac.update(challenge)
return username + b' ' + bytes(mac.hexdigest())
@ -201,43 +204,42 @@ class SCRAM(Mech):
def HMAC(self, key, msg):
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
def Hi(self, text, salt, iterations):
text = bytes(text)
ui1 = self.HMAC(text, salt + b'\0\0\0\01')
def Hi(self, text: str, salt: bytes_, iterations: int):
text_enc = bytes(text)
ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01')
ui = ui1
for i in range(iterations - 1):
ui1 = self.HMAC(text, ui1)
ui1 = self.HMAC(text_enc, ui1)
ui = XOR(ui, ui1)
return ui
def H(self, text):
def H(self, text: str) -> bytes_:
return self.hash(text).digest()
def saslname(self, value):
value = value.decode("utf-8")
escaped = []
def saslname(self, value_b: bytes_) -> bytes_:
value = value_b.decode("utf-8")
escaped: List[str] = []
for char in value:
if char == ',':
escaped += b'=2C'
escaped.append('=2C')
elif char == '=':
escaped += b'=3D'
escaped.append('=3D')
else:
escaped += char
escaped.append(char)
return "".join(escaped).encode("utf-8")
def parse(self, challenge):
def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]:
items = {}
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
items[key] = value
return items
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b''):
steps = [self.process_1, self.process_2, self.process_3]
return steps[self.step](challenge)
def process_1(self, challenge):
def process_1(self, challenge: bytes_) -> bytes_:
self.step = 1
data = {}
self.cnonce = bytes(('%s' % random.random())[2:])
@ -263,7 +265,7 @@ class SCRAM(Mech):
return self.client_first_message
def process_2(self, challenge):
def process_2(self, challenge: bytes_) -> bytes_:
self.step = 2
data = self.parse(challenge)
@ -304,7 +306,7 @@ class SCRAM(Mech):
return client_final_message
def process_3(self, challenge):
def process_3(self, challenge: bytes_) -> bytes_:
data = self.parse(challenge)
verifier = data.get(b'v', None)
error = data.get(b'e', 'Unknown error')
@ -345,17 +347,16 @@ class DIGEST(Mech):
self.cnonce = b''
self.nonce_count = 1
def parse(self, challenge=b''):
data = {}
def parse(self, challenge: bytes_ = b''):
data: Dict[str, bytes_] = {}
var_name = b''
var_value = b''
# States: var, new_var, end, quote, escaped_quote
state = 'var'
for char in challenge:
char = bytes([char])
for char_int in challenge:
char = bytes_([char_int])
if state == 'var':
if char.isspace():
@ -401,14 +402,14 @@ class DIGEST(Mech):
state = 'var'
return data
def MAC(self, key, seq, msg):
def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_:
mac = hmac.HMAC(key=key, digestmod=self.hash)
seqnum = num_to_bytes(seq)
mac.update(seqnum)
mac.update(msg)
return mac.digest()[:10] + b'\x00\x01' + seqnum
def A1(self):
def A1(self) -> bytes_:
username = self.credentials['username']
password = self.credentials['password']
authzid = self.credentials['authzid']
@ -423,13 +424,13 @@ class DIGEST(Mech):
return bytes(a1)
def A2(self, prefix=b''):
def A2(self, prefix: bytes_ = b'') -> bytes_:
a2 = prefix + b':' + self.digest_uri()
if self.qop in (b'auth-int', b'auth-conf'):
a2 += b':00000000000000000000000000000000'
return bytes(a2)
def response(self, prefix=b''):
def response(self, prefix: bytes_ = b'') -> bytes_:
nc = bytes('%08x' % self.nonce_count)
a1 = bytes(self.hash(self.A1()).hexdigest().lower())
@ -439,7 +440,7 @@ class DIGEST(Mech):
return bytes(self.hash(a1 + b':' + s).hexdigest().lower())
def digest_uri(self):
def digest_uri(self) -> bytes_:
serv_type = self.credentials['service']
serv_name = self.credentials['service-name']
host = self.credentials['host']
@ -449,7 +450,7 @@ class DIGEST(Mech):
uri += b'/' + serv_name
return uri
def respond(self):
def respond(self) -> bytes_:
data = {
'username': quote(self.credentials['username']),
'authzid': quote(self.credentials['authzid']),
@ -469,7 +470,7 @@ class DIGEST(Mech):
resp += b',' + bytes(key) + b'=' + bytes(value)
return resp[1:]
def process(self, challenge=b''):
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
if not challenge:
if self.cnonce and self.nonce and self.nonce_count and self.qop:
self.nonce_count += 1
@ -480,6 +481,7 @@ class DIGEST(Mech):
if 'rspauth' in data:
if data['rspauth'] != self.response():
raise SASLMutualAuthFailed()
return None
else:
self.nonce_count = 1
self.cnonce = bytes('%s' % random.random())[2:]

View File

@ -1,22 +0,0 @@
"""
asyncio-related utilities
"""
import asyncio
from functools import wraps
def future_wrapper(func):
"""
Make sure the result of a function call is an asyncio.Future()
object.
"""
@wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if isinstance(result, asyncio.Future):
return result
future = asyncio.Future()
future.set_result(result)
return future
return wrapper

View File

@ -1,5 +1,6 @@
import logging
from datetime import datetime, timedelta
from typing import Dict, Set, Tuple, Optional
# Make a call to strptime before starting threads to
# prevent thread safety issues.
@ -32,13 +33,13 @@ class CertificateError(Exception):
pass
def decode_str(data):
def decode_str(data: bytes) -> str:
encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8'
return bytes(data).decode(encoding)
def extract_names(raw_cert):
results = {'CN': set(),
def extract_names(raw_cert: bytes) -> Dict[str, Set[str]]:
results: Dict[str, Set[str]] = {'CN': set(),
'DNS': set(),
'SRV': set(),
'URI': set(),
@ -96,7 +97,7 @@ def extract_names(raw_cert):
return results
def extract_dates(raw_cert):
def extract_dates(raw_cert: bytes) -> Tuple[Optional[datetime], Optional[datetime]]:
if not HAVE_PYASN1:
log.warning("Could not find pyasn1 and pyasn1_modules. " + \
"SSL certificate expiration COULD NOT BE VERIFIED.")
@ -125,24 +126,29 @@ def extract_dates(raw_cert):
return not_before, not_after
def get_ttl(raw_cert):
def get_ttl(raw_cert: bytes) -> Optional[timedelta]:
not_before, not_after = extract_dates(raw_cert)
if not_after is None:
if not_after is None or not_before is None:
return None
return not_after - datetime.utcnow()
def verify(expected, raw_cert):
def verify(expected: str, raw_cert: bytes) -> Optional[bool]:
if not HAVE_PYASN1:
log.warning("Could not find pyasn1 and pyasn1_modules. " + \
"SSL certificate COULD NOT BE VERIFIED.")
return
return None
not_before, not_after = extract_dates(raw_cert)
cert_names = extract_names(raw_cert)
now = datetime.utcnow()
if not not_before or not not_after:
raise CertificateError(
"Error while checking the dates of the certificate"
)
if not_before > now:
raise CertificateError(
'Certificate has not entered its valid date range.')

View File

@ -4,10 +4,19 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from __future__ import annotations
import weakref
from weakref import ReferenceType
from typing import Optional, TYPE_CHECKING, Union
from slixmpp.xmlstream.matcher.base import MatcherBase
from xml.etree.ElementTree import Element
if TYPE_CHECKING:
from slixmpp.xmlstream import XMLStream, StanzaBase
class BaseHandler(object):
class BaseHandler:
"""
Base class for stream handlers. Stream handlers are matched with
@ -26,8 +35,13 @@ class BaseHandler(object):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance that the handle will respond to.
"""
name: str
stream: Optional[ReferenceType[XMLStream]]
_destroy: bool
_matcher: MatcherBase
_payload: Optional[StanzaBase]
def __init__(self, name, matcher, stream=None):
def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None):
#: The name of the handler
self.name = name
@ -41,33 +55,33 @@ class BaseHandler(object):
self._payload = None
self._matcher = matcher
def match(self, xml):
def match(self, xml: StanzaBase) -> bool:
"""Compare a stanza or XML object with the handler's matcher.
:param xml: An XML or
:class:`~slixmpp.xmlstream.stanzabase.ElementBase` object
:class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object
"""
return self._matcher.match(xml)
def prerun(self, payload):
def prerun(self, payload: StanzaBase) -> None:
"""Prepare the handler for execution while the XML
stream is being processed.
:param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
object.
"""
self._payload = payload
def run(self, payload):
def run(self, payload: StanzaBase) -> None:
"""Execute the handler after XML stream processing and during the
main event loop.
:param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
object.
"""
self._payload = payload
def check_delete(self):
def check_delete(self) -> bool:
"""Check if the handler should be removed from the list
of stream handlers.
"""

View File

@ -1,10 +1,17 @@
# slixmpp.xmlstream.handler.callback
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from __future__ import annotations
from typing import Optional, Callable, Any, TYPE_CHECKING
from slixmpp.xmlstream.handler.base import BaseHandler
from slixmpp.xmlstream.matcher.base import MatcherBase
if TYPE_CHECKING:
from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.xmlstream import XMLStream
class Callback(BaseHandler):
@ -28,8 +35,6 @@ class Callback(BaseHandler):
:param matcher: A :class:`~slixmpp.xmlstream.matcher.base.MatcherBase`
derived object for matching stanza objects.
:param pointer: The function to execute during callback.
:param bool thread: **DEPRECATED.** Remains only for
backwards compatibility.
:param bool once: Indicates if the handler should be used only
once. Defaults to False.
:param bool instream: Indicates if the callback should be executed
@ -38,31 +43,36 @@ class Callback(BaseHandler):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance this handler should monitor.
"""
_once: bool
_instream: bool
def __init__(self, name, matcher, pointer, thread=False,
once=False, instream=False, stream=None):
def __init__(self, name: str, matcher: MatcherBase,
pointer: Callable[[StanzaBase], Any],
once: bool = False, instream: bool = False,
stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream)
self._pointer: Callable[[StanzaBase], Any] = pointer
self._pointer = pointer
self._once = once
self._instream = instream
def prerun(self, payload):
def prerun(self, payload: StanzaBase) -> None:
"""Execute the callback during stream processing, if
the callback was created with ``instream=True``.
:param payload: The matched
:class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
:class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
if self._once:
self._destroy = True
if self._instream:
self.run(payload, True)
def run(self, payload, instream=False):
def run(self, payload: StanzaBase, instream: bool = False) -> None:
"""Execute the callback function with the matched stanza payload.
:param payload: The matched
:class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
:class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
:param bool instream: Force the handler to execute during stream
processing. This should only be used by
:meth:`prerun()`. Defaults to ``False``.

View File

@ -4,11 +4,17 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2012 Nathanael C. Fritz, Lance J.T. Stout
# :license: MIT, see LICENSE for more details
from __future__ import annotations
import logging
from queue import Queue, Empty
from typing import List, Optional, TYPE_CHECKING
from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler.base import BaseHandler
from slixmpp.xmlstream.matcher.base import MatcherBase
if TYPE_CHECKING:
from slixmpp.xmlstream.xmlstream import XMLStream
log = logging.getLogger(__name__)
@ -27,35 +33,35 @@ class Collector(BaseHandler):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance this handler should monitor.
"""
_stanzas: List[StanzaBase]
def __init__(self, name, matcher, stream=None):
def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream=stream)
self._payload = Queue()
self._stanzas = []
def prerun(self, payload):
def prerun(self, payload: StanzaBase) -> None:
"""Store the matched stanza when received during processing.
:param payload: The matched
:class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
:class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
self._payload.put(payload)
self._stanzas.append(payload)
def run(self, payload):
def run(self, payload: StanzaBase) -> None:
"""Do not process this handler during the main event loop."""
pass
def stop(self):
def stop(self) -> List[StanzaBase]:
"""
Stop collection of matching stanzas, and return the ones that
have been stored so far.
"""
stream_ref = self.stream
if stream_ref is None:
raise ValueError('stop() called without a stream!')
stream = stream_ref()
if stream is None:
raise ValueError('stop() called without a stream!')
self._destroy = True
results = []
try:
while True:
results.append(self._payload.get(False))
except Empty:
pass
self.stream().remove_handler(self.name)
return results
stream.remove_handler(self.name)
return self._stanzas

View File

@ -4,8 +4,19 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from __future__ import annotations
from asyncio import iscoroutinefunction, ensure_future
from typing import Optional, Callable, Awaitable, TYPE_CHECKING
from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler.base import BaseHandler
from slixmpp.xmlstream.asyncio import asyncio
from slixmpp.xmlstream.matcher.base import MatcherBase
CoroutineFunction = Callable[[StanzaBase], Awaitable[None]]
if TYPE_CHECKING:
from slixmpp.xmlstream.xmlstream import XMLStream
class CoroutineCallback(BaseHandler):
@ -34,45 +45,49 @@ class CoroutineCallback(BaseHandler):
instance this handler should monitor.
"""
def __init__(self, name, matcher, pointer, once=False,
instream=False, stream=None):
_once: bool
_instream: bool
def __init__(self, name: str, matcher: MatcherBase,
pointer: CoroutineFunction, once: bool = False,
instream: bool = False, stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream)
if not asyncio.iscoroutinefunction(pointer):
if not iscoroutinefunction(pointer):
raise ValueError("Given function is not a coroutine")
async def pointer_wrapper(stanza, *args, **kwargs):
async def pointer_wrapper(stanza: StanzaBase) -> None:
try:
await pointer(stanza, *args, **kwargs)
await pointer(stanza)
except Exception as e:
stanza.exception(e)
self._pointer = pointer_wrapper
self._pointer: CoroutineFunction = pointer_wrapper
self._once = once
self._instream = instream
def prerun(self, payload):
def prerun(self, payload: StanzaBase) -> None:
"""Execute the callback during stream processing, if
the callback was created with ``instream=True``.
:param payload: The matched
:class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
:class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
if self._once:
self._destroy = True
if self._instream:
self.run(payload, True)
def run(self, payload, instream=False):
def run(self, payload: StanzaBase, instream: bool = False) -> None:
"""Execute the callback function with the matched stanza payload.
:param payload: The matched
:class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
:class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
:param bool instream: Force the handler to execute during stream
processing. This should only be used by
:meth:`prerun()`. Defaults to ``False``.
"""
if not self._instream or instream:
asyncio.ensure_future(self._pointer(payload))
ensure_future(self._pointer(payload))
if self._once:
self._destroy = True
del self._pointer

View File

@ -4,13 +4,20 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from __future__ import annotations
import logging
import asyncio
from asyncio import Queue, wait_for, TimeoutError
from asyncio import Event, wait_for, TimeoutError
from typing import Optional, TYPE_CHECKING, Union
from xml.etree.ElementTree import Element
import slixmpp
from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler.base import BaseHandler
from slixmpp.xmlstream.matcher.base import MatcherBase
if TYPE_CHECKING:
from slixmpp.xmlstream.xmlstream import XMLStream
log = logging.getLogger(__name__)
@ -28,24 +35,27 @@ class Waiter(BaseHandler):
:param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream`
instance this handler should monitor.
"""
_event: Event
def __init__(self, name, matcher, stream=None):
def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None):
BaseHandler.__init__(self, name, matcher, stream=stream)
self._payload = Queue()
self._event = Event()
def prerun(self, payload):
def prerun(self, payload: StanzaBase) -> None:
"""Store the matched stanza when received during processing.
:param payload: The matched
:class:`~slixmpp.xmlstream.stanzabase.ElementBase` object.
:class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object.
"""
self._payload.put_nowait(payload)
if not self._event.is_set():
self._event.set()
self._payload = payload
def run(self, payload):
def run(self, payload: StanzaBase) -> None:
"""Do not process this handler during the main event loop."""
pass
async def wait(self, timeout=None):
async def wait(self, timeout: Optional[int] = None) -> Optional[StanzaBase]:
"""Block an event handler while waiting for a stanza to arrive.
Be aware that this will impact performance if called from a
@ -59,17 +69,24 @@ class Waiter(BaseHandler):
:class:`~slixmpp.xmlstream.xmlstream.XMLStream.response_timeout`
value.
"""
stream_ref = self.stream
if stream_ref is None:
raise ValueError('wait() called without a stream')
stream = stream_ref()
if stream is None:
raise ValueError('wait() called without a stream')
if timeout is None:
timeout = slixmpp.xmlstream.RESPONSE_TIMEOUT
stanza = None
try:
stanza = await self._payload.get()
await wait_for(
self._event.wait(), timeout, loop=stream.loop
)
except TimeoutError:
log.warning("Timed out waiting for %s", self.name)
self.stream().remove_handler(self.name)
return stanza
stream.remove_handler(self.name)
return self._payload
def check_delete(self):
def check_delete(self) -> bool:
"""Always remove waiters after use."""
return True

View File

@ -4,6 +4,7 @@
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream.handler import Callback
from slixmpp.xmlstream.stanzabase import StanzaBase
class XMLCallback(Callback):
@ -17,7 +18,7 @@ class XMLCallback(Callback):
run -- Overrides Callback.run
"""
def run(self, payload, instream=False):
def run(self, payload: StanzaBase, instream: bool = False) -> None:
"""
Execute the callback function with the matched stanza's
XML contents, instead of the stanza itself.
@ -30,4 +31,4 @@ class XMLCallback(Callback):
stream processing. Used only by prerun.
Defaults to False.
"""
Callback.run(self, payload.xml, instream)
Callback.run(self, payload.xml, instream) # type: ignore

View File

@ -3,6 +3,7 @@
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.xmlstream.handler import Waiter
@ -17,7 +18,7 @@ class XMLWaiter(Waiter):
prerun -- Overrides Waiter.prerun
"""
def prerun(self, payload):
def prerun(self, payload: StanzaBase) -> None:
"""
Store the XML contents of the stanza to return to the
waiting event handler.
@ -27,4 +28,4 @@ class XMLWaiter(Waiter):
Arguments:
payload -- The matched stanza object.
"""
Waiter.prerun(self, payload.xml)
Waiter.prerun(self, payload.xml) # type: ignore

View File

@ -1,10 +1,13 @@
# slixmpp.xmlstream.matcher.base
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from typing import Any
from slixmpp.xmlstream.stanzabase import StanzaBase
class MatcherBase(object):
"""
@ -15,10 +18,10 @@ class MatcherBase(object):
:param criteria: Object to compare some aspect of a stanza against.
"""
def __init__(self, criteria):
def __init__(self, criteria: Any):
self._criteria = criteria
def match(self, xml):
def match(self, xml: StanzaBase) -> bool:
"""Check if a stanza matches the stored criteria.
Meant to be overridden.

View File

@ -5,6 +5,7 @@
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from slixmpp.xmlstream.matcher.base import MatcherBase
from slixmpp.xmlstream.stanzabase import StanzaBase
class MatcherId(MatcherBase):
@ -13,12 +14,13 @@ class MatcherId(MatcherBase):
The ID matcher selects stanzas that have the same stanza 'id'
interface value as the desired ID.
"""
_criteria: str
def match(self, xml):
def match(self, xml: StanzaBase) -> bool:
"""Compare the given stanza's ``'id'`` attribute to the stored
``id`` value.
:param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
return xml['id'] == self._criteria
return bool(xml['id'] == self._criteria)

View File

@ -4,7 +4,19 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from slixmpp.xmlstream.matcher.base import MatcherBase
from slixmpp.xmlstream.stanzabase import StanzaBase
from slixmpp.jid import JID
from slixmpp.types import TypedDict
from typing import Dict
class CriteriaType(TypedDict):
self: JID
peer: JID
id: str
class MatchIDSender(MatcherBase):
@ -14,25 +26,26 @@ class MatchIDSender(MatcherBase):
interface value as the desired ID, and that the 'from' value is one
of a set of approved entities that can respond to a request.
"""
_criteria: CriteriaType
def match(self, xml):
def match(self, xml: StanzaBase) -> bool:
"""Compare the given stanza's ``'id'`` attribute to the stored
``id`` value, and verify the sender's JID.
:param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
selfjid = self._criteria['self']
peerjid = self._criteria['peer']
allowed = {}
allowed: Dict[str, bool] = {}
allowed[''] = True
allowed[selfjid.bare] = True
allowed[selfjid.host] = True
allowed[selfjid.domain] = True
allowed[peerjid.full] = True
allowed[peerjid.bare] = True
allowed[peerjid.host] = True
allowed[peerjid.domain] = True
_from = xml['from']

View File

@ -3,7 +3,9 @@
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
# See the file LICENSE for copying permission.
from typing import Iterable
from slixmpp.xmlstream.matcher.base import MatcherBase
from slixmpp.xmlstream.stanzabase import StanzaBase
class MatchMany(MatcherBase):
@ -18,8 +20,9 @@ class MatchMany(MatcherBase):
Methods:
match -- Overrides MatcherBase.match.
"""
_criteria: Iterable[MatcherBase]
def match(self, xml):
def match(self, xml: StanzaBase) -> bool:
"""
Match a stanza against multiple criteria. The match is successful
if one of the criteria matches.

View File

@ -4,8 +4,9 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from typing import cast, List
from slixmpp.xmlstream.matcher.base import MatcherBase
from slixmpp.xmlstream.stanzabase import fix_ns
from slixmpp.xmlstream.stanzabase import fix_ns, StanzaBase
class StanzaPath(MatcherBase):
@ -17,22 +18,28 @@ class StanzaPath(MatcherBase):
:param criteria: Object to compare some aspect of a stanza against.
"""
_criteria: List[str]
_raw_criteria: str
def __init__(self, criteria):
self._criteria = fix_ns(criteria, split=True,
propagate_ns=False,
default_ns='jabber:client')
def __init__(self, criteria: str):
self._criteria = cast(
List[str],
fix_ns(
criteria, split=True, propagate_ns=False,
default_ns='jabber:client'
)
)
self._raw_criteria = criteria
def match(self, stanza):
def match(self, stanza: StanzaBase) -> bool:
"""
Compare a stanza against a "stanza path". A stanza path is similar to
an XPath expression, but uses the stanza's interfaces and plugins
instead of the underlying XML. See the documentation for the stanza
:meth:`~slixmpp.xmlstream.stanzabase.ElementBase.match()` method
:meth:`~slixmpp.xmlstream.stanzabase.StanzaBase.match()` method
for more information.
:param stanza: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param stanza: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
return stanza.match(self._criteria) or stanza.match(self._raw_criteria)

View File

@ -1,4 +1,3 @@
# Slixmpp: The Slick XMPP Library
# Copyright (C) 2010 Nathanael C. Fritz
# This file is part of Slixmpp.
@ -6,8 +5,9 @@
import logging
from xml.parsers.expat import ExpatError
from xml.etree.ElementTree import Element
from slixmpp.xmlstream.stanzabase import ET
from slixmpp.xmlstream.stanzabase import ET, StanzaBase
from slixmpp.xmlstream.matcher.base import MatcherBase
@ -33,32 +33,33 @@ class MatchXMLMask(MatcherBase):
:param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML
object or XML string to use as a mask.
"""
_criteria: Element
def __init__(self, criteria, default_ns='jabber:client'):
def __init__(self, criteria: str, default_ns: str = 'jabber:client'):
MatcherBase.__init__(self, criteria)
if isinstance(criteria, str):
self._criteria = ET.fromstring(self._criteria)
self._criteria = ET.fromstring(criteria)
self.default_ns = default_ns
def setDefaultNS(self, ns):
def setDefaultNS(self, ns: str) -> None:
"""Set the default namespace to use during comparisons.
:param ns: The new namespace to use as the default.
"""
self.default_ns = ns
def match(self, xml):
def match(self, xml: StanzaBase) -> bool:
"""Compare a stanza object or XML object against the stored XML mask.
Overrides MatcherBase.match.
:param xml: The stanza object or XML object to compare against.
"""
if hasattr(xml, 'xml'):
xml = xml.xml
return self._mask_cmp(xml, self._criteria, True)
real_xml = xml.xml
return self._mask_cmp(real_xml, self._criteria, True)
def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'):
def _mask_cmp(self, source: Element, mask: Element, use_ns: bool = False,
default_ns: str = '__no_ns__') -> bool:
"""Compare an XML object against an XML mask.
:param source: The :class:`~xml.etree.ElementTree.Element` XML object
@ -75,13 +76,6 @@ class MatchXMLMask(MatcherBase):
# If the element was not found. May happen during recursive calls.
return False
# Convert the mask to an XML object if it is a string.
if not hasattr(mask, 'attrib'):
try:
mask = ET.fromstring(mask)
except ExpatError:
log.warning("Expat error: %s\nIn parsing: %s", '', mask)
mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag)
if source.tag not in [mask.tag, mask_ns_tag]:
return False

View File

@ -4,7 +4,8 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from slixmpp.xmlstream.stanzabase import ET, fix_ns
from typing import cast
from slixmpp.xmlstream.stanzabase import ET, fix_ns, StanzaBase
from slixmpp.xmlstream.matcher.base import MatcherBase
@ -17,23 +18,23 @@ class MatchXPath(MatcherBase):
If the value of :data:`IGNORE_NS` is set to ``True``, then XPath
expressions will be matched without using namespaces.
"""
_criteria: str
def __init__(self, criteria):
self._criteria = fix_ns(criteria)
def __init__(self, criteria: str):
self._criteria = cast(str, fix_ns(criteria))
def match(self, xml):
def match(self, xml: StanzaBase) -> bool:
"""
Compare a stanza's XML contents to an XPath expression.
If the value of :data:`IGNORE_NS` is set to ``True``, then XPath
expressions will be matched without using namespaces.
:param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to compare against.
"""
if hasattr(xml, 'xml'):
xml = xml.xml
real_xml = xml.xml
x = ET.Element('x')
x.append(xml)
x.append(real_xml)
return x.find(self._criteria) is not None

View File

@ -1,18 +1,32 @@
# slixmpp.xmlstream.dns
# ~~~~~~~~~~~~~~~~~~~~~~~
# :copyright: (c) 2012 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from slixmpp.xmlstream.asyncio import asyncio
import socket
import sys
import logging
import random
from asyncio import Future, AbstractEventLoop
from typing import Optional, Tuple, Dict, List, Iterable, cast
from slixmpp.types import Protocol
log = logging.getLogger(__name__)
class AnswerProtocol(Protocol):
host: str
priority: int
weight: int
port: int
class ResolverProtocol(Protocol):
def query(self, query: str, querytype: str) -> Future:
...
#: Global flag indicating the availability of the ``aiodns`` package.
#: Installing ``aiodns`` can be done via:
#:
@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False
try:
import aiodns
AIODNS_AVAILABLE = True
except ImportError as e:
log.debug("Could not find aiodns package. " + \
except ImportError:
log.debug("Could not find aiodns package. "
"Not all features will be available")
def default_resolver(loop):
def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]:
"""Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns
@ -41,8 +55,11 @@ def default_resolver(loop):
return None
async def resolve(host, port=None, service=None, proto='tcp',
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
async def resolve(host: str, port: int, *, loop: AbstractEventLoop,
service: Optional[str] = None, proto: str = 'tcp',
resolver: Optional[ResolverProtocol] = None,
use_ipv6: bool = True,
use_aiodns: bool = True) -> List[Tuple[str, str, int]]:
"""Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol
@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp',
if not use_ipv6:
log.debug("DNS: Use of IPv6 has been disabled.")
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
resolver = aiodns.DNSResolver(loop=loop)
if resolver is None and use_aiodns:
resolver = default_resolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal;
@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
try:
# If `host` is an IPv4 literal, we can return it immediately.
ipv4 = socket.inet_aton(host)
socket.inet_aton(host)
return [(host, host, port)]
except socket.error:
pass
@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp',
# Likewise, If `host` is an IPv6 literal, we can return
# it immediately.
if hasattr(socket, 'inet_pton'):
ipv6 = socket.inet_pton(socket.AF_INET6, host)
socket.inet_pton(socket.AF_INET6, host)
return [(host, host, port)]
except (socket.error, ValueError):
pass
@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp',
return results
async def get_A(host, resolver=None, use_aiodns=True, loop=None):
async def get_A(host: str, *, loop: AbstractEventLoop,
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[str]:
"""Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# getaddrinfo() method.
if resolver is None or not use_aiodns:
try:
recs = await loop.getaddrinfo(host, None,
inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET,
type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs]
return [rec[4][0] for rec in inet_recs]
except socket.gaierror:
log.debug("DNS: Error retrieving A address info for %s." % host)
return []
@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'A')
try:
recs = await future
recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s A records: %s', host, e)
recs = []
return [rec.host for rec in recs]
async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
async def get_AAAA(host: str, *, loop: AbstractEventLoop,
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[str]:
"""Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will
@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return []
try:
recs = await loop.getaddrinfo(host, None,
inet_recs = await loop.getaddrinfo(host, None,
family=socket.AF_INET6,
type=socket.SOCK_STREAM)
return [rec[4][0] for rec in recs]
return [rec[4][0] for rec in inet_recs]
except (OSError, socket.gaierror):
log.debug("DNS: Error retrieving AAAA address " + \
"info for %s." % host)
@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
# Using aiodns:
future = resolver.query(host, 'AAAA')
try:
recs = await future
recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e)
recs = []
return [rec.host for rec in recs]
async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True):
async def get_SRV(host: str, port: int, service: str,
proto: str = 'tcp',
resolver: Optional[ResolverProtocol] = None,
use_aiodns: bool = True) -> List[Tuple[str, int]]:
"""Perform SRV record resolution for a given host.
.. note::
@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr
try:
future = resolver.query('_%s._%s.%s' % (service, proto, host),
'SRV')
recs = await future
recs = cast(Iterable[AnswerProtocol], await future)
except Exception as e:
log.debug('DNS: Exception while querying for %s SRV records: %s', host, e)
return []
answers = {}
answers: Dict[int, List[AnswerProtocol]] = {}
for rec in recs:
if rec.priority not in answers:
answers[rec.priority] = []

View File

@ -1,4 +1,3 @@
# slixmpp.xmlstream.stanzabase
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# module implements a wrapper layer for XML objects
@ -11,13 +10,34 @@ from __future__ import annotations
import copy
import logging
import weakref
from typing import Optional
from typing import (
cast,
Any,
Callable,
ClassVar,
Coroutine,
Dict,
List,
Iterable,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from weakref import ReferenceType
from xml.etree import ElementTree as ET
from slixmpp.types import JidStr
from slixmpp.xmlstream import JID
from slixmpp.xmlstream.tostring import tostring
if TYPE_CHECKING:
from slixmpp.xmlstream import XMLStream
log = logging.getLogger(__name__)
@ -28,7 +48,8 @@ XML_TYPE = type(ET.Element('xml'))
XML_NS = 'http://www.w3.org/XML/1998/namespace'
def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False):
def register_stanza_plugin(stanza: Type[ElementBase], plugin: Type[ElementBase],
iterable: bool = False, overrides: bool = False) -> None:
"""
Associate a stanza object as a plugin for another stanza.
@ -85,15 +106,15 @@ def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False):
stanza.plugin_overrides[interface] = plugin.plugin_attrib
def multifactory(stanza, plugin_attrib):
def multifactory(stanza: Type[ElementBase], plugin_attrib: str) -> Type[ElementBase]:
"""
Returns a ElementBase class for handling reoccuring child stanzas
"""
def plugin_filter(self):
def plugin_filter(self: Multi) -> Callable[..., bool]:
return lambda x: isinstance(x, self._multistanza)
def plugin_lang_filter(self, lang):
def plugin_lang_filter(self: Multi, lang: Optional[str]) -> Callable[..., bool]:
return lambda x: isinstance(x, self._multistanza) and \
x['lang'] == lang
@ -101,31 +122,41 @@ def multifactory(stanza, plugin_attrib):
"""
Template class for multifactory
"""
def setup(self, xml=None):
self.xml = ET.Element('')
_multistanza: Type[ElementBase]
def get_multi(self, lang=None):
parent = self.parent()
def setup(self, xml: Optional[ET.Element] = None) -> bool:
self.xml = ET.Element('')
return False
def get_multi(self: Multi, lang: Optional[str] = None) -> List[ElementBase]:
parent = fail_without_parent(self)
if not lang or lang == '*':
res = filter(plugin_filter(self), parent)
else:
res = filter(plugin_filter(self, lang), parent)
res = filter(plugin_lang_filter(self, lang), parent)
return list(res)
def set_multi(self, val, lang=None):
parent = self.parent()
def set_multi(self: Multi, val: Iterable[ElementBase], lang: Optional[str] = None) -> None:
parent = fail_without_parent(self)
del_multi = getattr(self, 'del_%s' % plugin_attrib)
del_multi(lang)
for sub in val:
parent.append(sub)
def del_multi(self, lang=None):
parent = self.parent()
def fail_without_parent(self: Multi) -> ElementBase:
parent = None
if self.parent:
parent = self.parent()
if not parent:
raise ValueError('No stanza parent for multifactory')
return parent
def del_multi(self: Multi, lang: Optional[str] = None) -> None:
parent = fail_without_parent(self)
if not lang or lang == '*':
res = filter(plugin_filter(self), parent)
res = list(filter(plugin_filter(self), parent))
else:
res = filter(plugin_filter(self, lang), parent)
res = list(res)
res = list(filter(plugin_lang_filter(self, lang), parent))
if not res:
del parent.plugins[(plugin_attrib, None)]
parent.loaded_plugins.remove(plugin_attrib)
@ -149,7 +180,8 @@ def multifactory(stanza, plugin_attrib):
return Multi
def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''):
def fix_ns(xpath: str, split: bool = False, propagate_ns: bool = True,
default_ns: str = '') -> Union[str, List[str]]:
"""Apply the stanza's namespace to elements in an XPath expression.
:param string xpath: The XPath expression to fix with namespaces.
@ -275,12 +307,12 @@ class ElementBase(object):
#: The XML tag name of the element, not including any namespace
#: prefixes. For example, an :class:`ElementBase` object for
#: ``<message />`` would use ``name = 'message'``.
name = 'stanza'
name: ClassVar[str] = 'stanza'
#: The XML namespace for the element. Given ``<foo xmlns="bar" />``,
#: then ``namespace = "bar"`` should be used. The default namespace
#: is ``jabber:client`` since this is being used in an XMPP library.
namespace = 'jabber:client'
namespace: str = 'jabber:client'
#: For :class:`ElementBase` subclasses which are intended to be used
#: as plugins, the ``plugin_attrib`` value defines the plugin name.
@ -290,7 +322,7 @@ class ElementBase(object):
#: register_stanza_plugin(Message, FooPlugin)
#: msg = Message()
#: msg['foo']['an_interface_from_the_foo_plugin']
plugin_attrib = 'plugin'
plugin_attrib: ClassVar[str] = 'plugin'
#: For :class:`ElementBase` subclasses that are intended to be an
#: iterable group of items, the ``plugin_multi_attrib`` value defines
@ -300,29 +332,29 @@ class ElementBase(object):
#: # Given stanza class Foo, with plugin_multi_attrib = 'foos'
#: parent['foos']
#: filter(isinstance(item, Foo), parent['substanzas'])
plugin_multi_attrib = ''
plugin_multi_attrib: ClassVar[str] = ''
#: The set of keys that the stanza provides for accessing and
#: manipulating the underlying XML object. This set may be augmented
#: with the :attr:`plugin_attrib` value of any registered
#: stanza plugins.
interfaces = {'type', 'to', 'from', 'id', 'payload'}
interfaces: ClassVar[Set[str]] = {'type', 'to', 'from', 'id', 'payload'}
#: A subset of :attr:`interfaces` which maps interfaces to direct
#: subelements of the underlying XML object. Using this set, the text
#: of these subelements may be set, retrieved, or removed without
#: needing to define custom methods.
sub_interfaces = set()
sub_interfaces: ClassVar[Set[str]] = set()
#: A subset of :attr:`interfaces` which maps the presence of
#: subelements to boolean values. Using this set allows for quickly
#: checking for the existence of empty subelements like ``<required />``.
#:
#: .. versionadded:: 1.1
bool_interfaces = set()
bool_interfaces: ClassVar[Set[str]] = set()
#: .. versionadded:: 1.1.2
lang_interfaces = set()
lang_interfaces: ClassVar[Set[str]] = set()
#: In some cases you may wish to override the behaviour of one of the
#: parent stanza's interfaces. The ``overrides`` list specifies the
@ -336,7 +368,7 @@ class ElementBase(object):
#: be affected.
#:
#: .. versionadded:: 1.0-Beta5
overrides = []
overrides: ClassVar[List[str]] = []
#: If you need to add a new interface to an existing stanza, you
#: can create a plugin and set ``is_extension = True``. Be sure
@ -346,7 +378,7 @@ class ElementBase(object):
#: parent stanza will be passed to the plugin directly.
#:
#: .. versionadded:: 1.0-Beta5
is_extension = False
is_extension: ClassVar[bool] = False
#: A map of interface operations to the overriding functions.
#: For example, after overriding the ``set`` operation for
@ -355,15 +387,15 @@ class ElementBase(object):
#: {'set_body': <some function>}
#:
#: .. versionadded: 1.0-Beta5
plugin_overrides = {}
plugin_overrides: ClassVar[Dict[str, str]] = {}
#: A mapping of the :attr:`plugin_attrib` values of registered
#: plugins to their respective classes.
plugin_attrib_map = {}
plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
#: A mapping of root element tag names (in ``'{namespace}elementname'``
#: format) to the plugin classes responsible for them.
plugin_tag_map = {}
plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {}
#: The set of stanza classes that can be iterated over using
#: the 'substanzas' interface. Classes are added to this set
@ -372,17 +404,26 @@ class ElementBase(object):
#: register_stanza_plugin(DiscoInfo, DiscoItem, iterable=True)
#:
#: .. versionadded:: 1.0-Beta5
plugin_iterables = set()
plugin_iterables: ClassVar[Set[Type[ElementBase]]] = set()
#: The default XML namespace: ``http://www.w3.org/XML/1998/namespace``.
xml_ns = XML_NS
xml_ns: ClassVar[str] = XML_NS
def __init__(self, xml=None, parent=None):
plugins: Dict[Tuple[str, Optional[str]], ElementBase]
#: The underlying XML object for the stanza. It is a standard
#: :class:`xml.etree.ElementTree` object.
xml: ET.Element
_index: int
loaded_plugins: Set[str]
iterables: List[ElementBase]
tag: str
parent: Optional[ReferenceType[ElementBase]]
def __init__(self, xml: Optional[ET.Element] = None, parent: Union[Optional[ElementBase], ReferenceType[ElementBase]] = None):
self._index = 0
#: The underlying XML object for the stanza. It is a standard
#: :class:`xml.etree.ElementTree` object.
self.xml = xml
if xml is not None:
self.xml = xml
#: An ordered dictionary of plugin stanzas, mapped by their
#: :attr:`plugin_attrib` value.
@ -419,7 +460,7 @@ class ElementBase(object):
existing_xml=child,
reuse=False)
def setup(self, xml=None):
def setup(self, xml: Optional[ET.Element] = None) -> bool:
"""Initialize the stanza's XML contents.
Will return ``True`` if XML was generated according to the stanza's
@ -429,29 +470,31 @@ class ElementBase(object):
:param xml: An existing XML object to use for the stanza's content
instead of generating new XML.
"""
if self.xml is None:
if hasattr(self, 'xml'):
return False
if not hasattr(self, 'xml') and xml is not None:
self.xml = xml
last_xml = self.xml
if self.xml is None:
# Generate XML from the stanza definition
for ename in self.name.split('/'):
new = ET.Element("{%s}%s" % (self.namespace, ename))
if self.xml is None:
self.xml = new
else:
last_xml.append(new)
last_xml = new
if self.parent is not None:
self.parent().xml.append(self.xml)
# We had to generate XML
return True
else:
# We did not generate XML
return False
def enable(self, attrib, lang=None):
# Generate XML from the stanza definition
last_xml = ET.Element('')
for ename in self.name.split('/'):
new = ET.Element("{%s}%s" % (self.namespace, ename))
if not hasattr(self, 'xml'):
self.xml = new
else:
last_xml.append(new)
last_xml = new
if self.parent is not None:
parent = self.parent()
if parent:
parent.xml.append(self.xml)
# We had to generate XML
return True
def enable(self, attrib: str, lang: Optional[str] = None) -> ElementBase:
"""Enable and initialize a stanza plugin.
Alias for :meth:`init_plugin`.
@ -487,7 +530,10 @@ class ElementBase(object):
else:
return None if check else self.init_plugin(name, lang)
def init_plugin(self, attrib, lang=None, existing_xml=None, element=None, reuse=True):
def init_plugin(self, attrib: str, lang: Optional[str] = None,
existing_xml: Optional[ET.Element] = None,
reuse: bool = True,
element: Optional[ElementBase] = None) -> ElementBase:
"""Enable and initialize a stanza plugin.
:param string attrib: The :attr:`plugin_attrib` value of the
@ -525,7 +571,7 @@ class ElementBase(object):
return plugin
def _get_stanza_values(self):
def _get_stanza_values(self) -> Dict[str, Any]:
"""Return A JSON/dictionary version of the XML content
exposed through the stanza's interfaces::
@ -567,7 +613,7 @@ class ElementBase(object):
values['substanzas'] = iterables
return values
def _set_stanza_values(self, values):
def _set_stanza_values(self, values: Dict[str, Any]) -> ElementBase:
"""Set multiple stanza interface values using a dictionary.
Stanza plugin values may be set using nested dictionaries.
@ -623,7 +669,7 @@ class ElementBase(object):
plugin.values = value
return self
def __getitem__(self, full_attrib):
def __getitem__(self, full_attrib: str) -> Any:
"""Return the value of a stanza interface using dict-like syntax.
Example::
@ -688,7 +734,7 @@ class ElementBase(object):
else:
return ''
def __setitem__(self, attrib, value):
def __setitem__(self, attrib: str, value: Any) -> Any:
"""Set the value of a stanza interface using dictionary-like syntax.
Example::
@ -773,7 +819,7 @@ class ElementBase(object):
plugin[full_attrib] = value
return self
def __delitem__(self, attrib):
def __delitem__(self, attrib: str) -> Any:
"""Delete the value of a stanza interface using dict-like syntax.
Example::
@ -851,7 +897,7 @@ class ElementBase(object):
pass
return self
def _set_attr(self, name, value):
def _set_attr(self, name: str, value: Optional[JidStr]) -> None:
"""Set the value of a top level attribute of the XML object.
If the new value is None or an empty string, then the attribute will
@ -868,7 +914,7 @@ class ElementBase(object):
value = str(value)
self.xml.attrib[name] = value
def _del_attr(self, name):
def _del_attr(self, name: str) -> None:
"""Remove a top level attribute of the XML object.
:param name: The name of the attribute.
@ -876,7 +922,7 @@ class ElementBase(object):
if name in self.xml.attrib:
del self.xml.attrib[name]
def _get_attr(self, name, default=''):
def _get_attr(self, name: str, default: str = '') -> str:
"""Return the value of a top level attribute of the XML object.
In case the attribute has not been set, a default value can be
@ -889,7 +935,8 @@ class ElementBase(object):
"""
return self.xml.attrib.get(name, default)
def _get_sub_text(self, name, default='', lang=None):
def _get_sub_text(self, name: str, default: str = '',
lang: Optional[str] = None) -> Union[str, Dict[str, str]]:
"""Return the text contents of a sub element.
In case the element does not exist, or it has no textual content,
@ -900,7 +947,7 @@ class ElementBase(object):
:param default: Optional default to return if the element does
not exists. An empty string is returned otherwise.
"""
name = self._fix_ns(name)
name = cast(str, self._fix_ns(name))
if lang == '*':
return self._get_all_sub_text(name, default, None)
@ -924,8 +971,9 @@ class ElementBase(object):
return result
return default
def _get_all_sub_text(self, name, default='', lang=None):
name = self._fix_ns(name)
def _get_all_sub_text(self, name: str, default: str = '',
lang: Optional[str] = None) -> Dict[str, str]:
name = cast(str, self._fix_ns(name))
default_lang = self.get_lang()
results = {}
@ -935,10 +983,16 @@ class ElementBase(object):
stanza_lang = stanza.attrib.get('{%s}lang' % XML_NS,
default_lang)
if not lang or lang == '*' or stanza_lang == lang:
results[stanza_lang] = stanza.text
if stanza.text is None:
text = default
else:
text = stanza.text
results[stanza_lang] = text
return results
def _set_sub_text(self, name, text=None, keep=False, lang=None):
def _set_sub_text(self, name: str, text: Optional[str] = None,
keep: bool = False,
lang: Optional[str] = None) -> Optional[ET.Element]:
"""Set the text contents of a sub element.
In case the element does not exist, a element will be created,
@ -959,15 +1013,16 @@ class ElementBase(object):
lang = default_lang
if not text and not keep:
return self._del_sub(name, lang=lang)
self._del_sub(name, lang=lang)
return None
path = self._fix_ns(name, split=True)
path = cast(List[str], self._fix_ns(name, split=True))
name = path[-1]
parent = self.xml
parent: Optional[ET.Element] = self.xml
# The first goal is to find the parent of the subelement, or, if
# we can't find that, the closest grandparent element.
missing_path = []
missing_path: List[str] = []
search_order = path[:-1]
while search_order:
parent = self.xml.find('/'.join(search_order))
@ -1008,15 +1063,17 @@ class ElementBase(object):
parent.append(element)
return element
def _set_all_sub_text(self, name, values, keep=False, lang=None):
self._del_sub(name, lang)
def _set_all_sub_text(self, name: str, values: Dict[str, str],
keep: bool = False,
lang: Optional[str] = None) -> None:
self._del_sub(name, lang=lang)
for value_lang, value in values.items():
if not lang or lang == '*' or value_lang == lang:
self._set_sub_text(name, text=value,
keep=keep,
lang=value_lang)
def _del_sub(self, name, all=False, lang=None):
def _del_sub(self, name: str, all: bool = False, lang: Optional[str] = None) -> None:
"""Remove sub elements that match the given name or XPath.
If the element is in a path, then any parent elements that become
@ -1034,11 +1091,11 @@ class ElementBase(object):
if not lang:
lang = default_lang
parent = self.xml
parent: Optional[ET.Element] = self.xml
for level, _ in enumerate(path):
# Generate the paths to the target elements and their parent.
element_path = "/".join(path[:len(path) - level])
parent_path = "/".join(path[:len(path) - level - 1])
parent_path: Optional[str] = "/".join(path[:len(path) - level - 1])
elements = self.xml.findall(element_path)
if parent_path == '':
@ -1061,7 +1118,7 @@ class ElementBase(object):
# after deleting the first level of elements.
return
def match(self, xpath):
def match(self, xpath: Union[str, List[str]]) -> bool:
"""Compare a stanza object with an XPath-like expression.
If the XPath matches the contents of the stanza object, the match
@ -1127,7 +1184,7 @@ class ElementBase(object):
# Everything matched.
return True
def get(self, key, default=None):
def get(self, key: str, default: Optional[Any] = None) -> Any:
"""Return the value of a stanza interface.
If the found value is None or an empty string, return the supplied
@ -1144,7 +1201,7 @@ class ElementBase(object):
return default
return value
def keys(self):
def keys(self) -> List[str]:
"""Return the names of all stanza interfaces provided by the
stanza object.
@ -1158,7 +1215,7 @@ class ElementBase(object):
out.append('substanzas')
return out
def append(self, item):
def append(self, item: Union[ET.Element, ElementBase]) -> ElementBase:
"""Append either an XML object or a substanza to this stanza object.
If a substanza object is appended, it will be added to the list
@ -1189,7 +1246,7 @@ class ElementBase(object):
return self
def appendxml(self, xml):
def appendxml(self, xml: ET.Element) -> ElementBase:
"""Append an XML object to the stanza's XML.
The added XML will not be included in the list of
@ -1200,7 +1257,7 @@ class ElementBase(object):
self.xml.append(xml)
return self
def pop(self, index=0):
def pop(self, index: int = 0) -> ElementBase:
"""Remove and return the last substanza in the list of
iterable substanzas.
@ -1212,11 +1269,11 @@ class ElementBase(object):
self.xml.remove(substanza.xml)
return substanza
def next(self):
def next(self) -> ElementBase:
"""Return the next iterable substanza."""
return self.__next__()
def clear(self):
def clear(self) -> ElementBase:
"""Remove all XML element contents and plugins.
Any attribute values will be preserved.
@ -1229,7 +1286,7 @@ class ElementBase(object):
return self
@classmethod
def tag_name(cls):
def tag_name(cls) -> str:
"""Return the namespaced name of the stanza's root element.
The format for the tag name is::
@ -1241,29 +1298,32 @@ class ElementBase(object):
"""
return "{%s}%s" % (cls.namespace, cls.name)
def get_lang(self, lang=None):
def get_lang(self, lang: Optional[str] = None) -> str:
result = self.xml.attrib.get('{%s}lang' % XML_NS, '')
if not result and self.parent and self.parent():
return self.parent()['lang']
if not result and self.parent:
parent = self.parent()
if parent:
return cast(str, parent['lang'])
return result
def set_lang(self, lang):
def set_lang(self, lang: Optional[str]) -> None:
self.del_lang()
attr = '{%s}lang' % XML_NS
if lang:
self.xml.attrib[attr] = lang
def del_lang(self):
def del_lang(self) -> None:
attr = '{%s}lang' % XML_NS
if attr in self.xml.attrib:
del self.xml.attrib[attr]
def _fix_ns(self, xpath, split=False, propagate_ns=True):
def _fix_ns(self, xpath: str, split: bool = False,
propagate_ns: bool = True) -> Union[str, List[str]]:
return fix_ns(xpath, split=split,
propagate_ns=propagate_ns,
default_ns=self.namespace)
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
"""Compare the stanza object with another to test for equality.
Stanzas are equal if their interfaces return the same values,
@ -1290,7 +1350,7 @@ class ElementBase(object):
# must be equal.
return True
def __ne__(self, other):
def __ne__(self, other: Any) -> bool:
"""Compare the stanza object with another to test for inequality.
Stanzas are not equal if their interfaces return different values,
@ -1300,16 +1360,16 @@ class ElementBase(object):
"""
return not self.__eq__(other)
def __bool__(self):
def __bool__(self) -> bool:
"""Stanza objects should be treated as True in boolean contexts.
"""
return True
def __len__(self):
def __len__(self) -> int:
"""Return the number of iterable substanzas in this stanza."""
return len(self.iterables)
def __iter__(self):
def __iter__(self) -> ElementBase:
"""Return an iterator object for the stanza's substanzas.
The iterator is the stanza object itself. Attempting to use two
@ -1318,7 +1378,7 @@ class ElementBase(object):
self._index = 0
return self
def __next__(self):
def __next__(self) -> ElementBase:
"""Return the next iterable substanza."""
self._index += 1
if self._index > len(self.iterables):
@ -1326,13 +1386,16 @@ class ElementBase(object):
raise StopIteration
return self.iterables[self._index - 1]
def __copy__(self):
def __copy__(self) -> ElementBase:
"""Return a copy of the stanza object that does not share the same
underlying XML object.
"""
return self.__class__(xml=copy.deepcopy(self.xml), parent=self.parent)
return self.__class__(
xml=copy.deepcopy(self.xml),
parent=self.parent,
)
def __str__(self, top_level_ns=True):
def __str__(self, top_level_ns: bool = True) -> str:
"""Return a string serialization of the underlying XML object.
.. seealso:: :ref:`tostring`
@ -1343,12 +1406,33 @@ class ElementBase(object):
return tostring(self.xml, xmlns='',
top_level=True)
def __repr__(self):
def __repr__(self) -> str:
"""Use the stanza's serialized XML as its representation."""
return self.__str__()
# Compatibility.
_get_plugin = get_plugin
get_stanza_values = _get_stanza_values
set_stanza_values = _set_stanza_values
#: A JSON/dictionary version of the XML content exposed through
#: the stanza interfaces::
#:
#: >>> msg = Message()
#: >>> msg.values
#: {'body': '', 'from': , 'mucnick': '', 'mucroom': '',
#: 'to': , 'type': 'normal', 'id': '', 'subject': ''}
#:
#: Likewise, assigning to the :attr:`values` will change the XML
#: content::
#:
#: >>> msg = Message()
#: >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'}
#: >>> msg
#: '<message to="user@example.com"><body>Hi!</body></message>'
#:
#: Child stanzas are exposed as nested dictionaries.
values = property(_get_stanza_values, _set_stanza_values) # type: ignore
class StanzaBase(ElementBase):
@ -1386,9 +1470,14 @@ class StanzaBase(ElementBase):
#: The default XMPP client namespace
namespace = 'jabber:client'
types: ClassVar[Set[str]] = set()
def __init__(self, stream=None, xml=None, stype=None,
sto=None, sfrom=None, sid=None, parent=None, recv=False):
def __init__(self, stream: Optional[XMLStream] = None,
xml: Optional[ET.Element] = None,
stype: Optional[str] = None,
sto: Optional[JidStr] = None, sfrom: Optional[JidStr] = None,
sid: Optional[str] = None,
parent: Optional[ElementBase] = None, recv: bool = False):
self.stream = stream
if stream is not None:
self.namespace = stream.default_ns
@ -1403,7 +1492,7 @@ class StanzaBase(ElementBase):
self['id'] = sid
self.tag = "{%s}%s" % (self.namespace, self.name)
def set_type(self, value):
def set_type(self, value: str) -> StanzaBase:
"""Set the stanza's ``'type'`` attribute.
Only type values contained in :attr:`types` are accepted.
@ -1414,11 +1503,11 @@ class StanzaBase(ElementBase):
self.xml.attrib['type'] = value
return self
def get_to(self):
def get_to(self) -> JID:
"""Return the value of the stanza's ``'to'`` attribute."""
return JID(self._get_attr('to'))
def set_to(self, value):
def set_to(self, value: JidStr) -> None:
"""Set the ``'to'`` attribute of the stanza.
:param value: A string or :class:`slixmpp.xmlstream.JID` object
@ -1426,11 +1515,11 @@ class StanzaBase(ElementBase):
"""
return self._set_attr('to', str(value))
def get_from(self):
def get_from(self) -> JID:
"""Return the value of the stanza's ``'from'`` attribute."""
return JID(self._get_attr('from'))
def set_from(self, value):
def set_from(self, value: JidStr) -> None:
"""Set the 'from' attribute of the stanza.
:param from: A string or JID object representing the sender's JID.
@ -1438,11 +1527,11 @@ class StanzaBase(ElementBase):
"""
return self._set_attr('from', str(value))
def get_payload(self):
def get_payload(self) -> List[ET.Element]:
"""Return a list of XML objects contained in the stanza."""
return list(self.xml)
def set_payload(self, value):
def set_payload(self, value: Union[List[ElementBase], ElementBase]) -> StanzaBase:
"""Add XML content to the stanza.
:param value: Either an XML or a stanza object, or a list
@ -1454,12 +1543,12 @@ class StanzaBase(ElementBase):
self.append(val)
return self
def del_payload(self):
def del_payload(self) -> StanzaBase:
"""Remove the XML contents of the stanza."""
self.clear()
return self
def reply(self, clear=True):
def reply(self, clear: bool = True) -> StanzaBase:
"""Prepare the stanza for sending a reply.
Swaps the ``'from'`` and ``'to'`` attributes.
@ -1475,7 +1564,7 @@ class StanzaBase(ElementBase):
new_stanza = copy.copy(self)
# if it's a component, use from
if self.stream and hasattr(self.stream, "is_component") and \
self.stream.is_component:
getattr(self.stream, 'is_component'):
new_stanza['from'], new_stanza['to'] = self['to'], self['from']
else:
new_stanza['to'] = self['from']
@ -1484,19 +1573,19 @@ class StanzaBase(ElementBase):
new_stanza.clear()
return new_stanza
def error(self):
def error(self) -> StanzaBase:
"""Set the stanza's type to ``'error'``."""
self['type'] = 'error'
return self
def unhandled(self):
def unhandled(self) -> None:
"""Called if no handlers have been registered to process this stanza.
Meant to be overridden.
"""
pass
def exception(self, e):
def exception(self, e: Exception) -> None:
"""Handle exceptions raised during stanza processing.
Meant to be overridden.
@ -1504,18 +1593,21 @@ class StanzaBase(ElementBase):
log.exception('Error handling {%s}%s stanza', self.namespace,
self.name)
def send(self):
def send(self) -> None:
"""Queue the stanza to be sent on the XML stream."""
self.stream.send(self)
if self.stream is not None:
self.stream.send(self)
else:
log.error("Tried to send stanza without a stream: %s", self)
def __copy__(self):
def __copy__(self) -> StanzaBase:
"""Return a copy of the stanza object that does not share the
same underlying XML object, but does share the same XML stream.
"""
return self.__class__(xml=copy.deepcopy(self.xml),
stream=self.stream)
def __str__(self, top_level_ns=False):
def __str__(self, top_level_ns: bool = False) -> str:
"""Serialize the stanza's XML to a string.
:param bool top_level_ns: Display the top-most namespace.
@ -1525,27 +1617,3 @@ class StanzaBase(ElementBase):
return tostring(self.xml, xmlns=xmlns,
stream=self.stream,
top_level=(self.stream is None))
#: A JSON/dictionary version of the XML content exposed through
#: the stanza interfaces::
#:
#: >>> msg = Message()
#: >>> msg.values
#: {'body': '', 'from': , 'mucnick': '', 'mucroom': '',
#: 'to': , 'type': 'normal', 'id': '', 'subject': ''}
#:
#: Likewise, assigning to the :attr:`values` will change the XML
#: content::
#:
#: >>> msg = Message()
#: >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'}
#: >>> msg
#: '<message to="user@example.com"><body>Hi!</body></message>'
#:
#: Child stanzas are exposed as nested dictionaries.
ElementBase.values = property(ElementBase._get_stanza_values,
ElementBase._set_stanza_values)
ElementBase.get_stanza_values = ElementBase._get_stanza_values
ElementBase.set_stanza_values = ElementBase._set_stanza_values

View File

@ -1,4 +1,3 @@
# slixmpp.xmlstream.tostring
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# This module converts XML objects into Unicode strings and
@ -7,11 +6,20 @@
# Part of Slixmpp: The Slick XMPP Library
# :copyright: (c) 2011 Nathanael C. Fritz
# :license: MIT, see LICENSE for more details
from __future__ import annotations
from typing import Optional, Set, TYPE_CHECKING
from xml.etree.ElementTree import Element
if TYPE_CHECKING:
from slixmpp.xmlstream import XMLStream
XML_NS = 'http://www.w3.org/XML/1998/namespace'
def tostring(xml=None, xmlns='', stream=None, outbuffer='',
top_level=False, open_only=False, namespaces=None):
def tostring(xml: Optional[Element] = None, xmlns: str = '',
stream: Optional[XMLStream] = None, outbuffer: str = '',
top_level: bool = False, open_only: bool = False,
namespaces: Optional[Set[str]] = None) -> str:
"""Serialize an XML object to a Unicode string.
If an outer xmlns is provided using ``xmlns``, then the current element's
@ -35,6 +43,8 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='',
:rtype: Unicode string
"""
if xml is None:
return ''
# Add previous results to the start of the output.
output = [outbuffer]
@ -123,11 +133,12 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='',
# Remove namespaces introduced in this context. This is necessary
# because the namespaces object continues to be shared with other
# contexts.
namespaces.remove(ns)
if namespaces is not None:
namespaces.remove(ns)
return ''.join(output)
def escape(text, use_cdata=False):
def escape(text: str, use_cdata: bool = False) -> str:
"""Convert special characters in XML to escape sequences.
:param string text: The XML text to convert.

View File

@ -9,17 +9,24 @@
# :license: MIT, see LICENSE for more details
from typing import (
Any,
Dict,
Awaitable,
Generator,
Coroutine,
Callable,
Iterable,
Iterator,
List,
Optional,
Set,
Union,
Tuple,
TypeVar,
NoReturn,
Type,
cast,
)
import asyncio
import functools
import logging
import socket as Socket
@ -27,30 +34,66 @@ import ssl
import weakref
import uuid
from asyncio import iscoroutinefunction, wait, Future
from contextlib import contextmanager
import xml.etree.ElementTree as ET
from asyncio import (
AbstractEventLoop,
BaseTransport,
Future,
Task,
TimerHandle,
Transport,
iscoroutinefunction,
wait,
)
from slixmpp.xmlstream.asyncio import asyncio
from slixmpp.xmlstream import tostring
from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring
from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase
from slixmpp.xmlstream.resolver import resolve, default_resolver
from slixmpp.xmlstream.handler.base import BaseHandler
T = TypeVar('T')
#: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30
log = logging.getLogger(__name__)
class ContinueQueue(Exception):
"""
Exception raised in the send queue to "continue" from within an inner loop
"""
class NotConnectedError(Exception):
"""
Raised when we try to send something over the wire but we are not
connected.
"""
_T = TypeVar('_T', str, ElementBase, StanzaBase)
SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]]
AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]]
Filter = Union[
SyncFilter,
AsyncFilter,
]
_FiltersDict = Dict[str, List[Filter]]
Handler = Callable[[Any], Union[
Any,
Coroutine[Any, Any, Any]
]]
class XMLStream(asyncio.BaseProtocol):
"""
An XML stream connection manager and event dispatcher.
@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol):
:param int port: The port to use for the connection. Defaults to 0.
"""
def __init__(self, host='', port=0):
# The asyncio.Transport object provided by the connection_made()
# callback when we are connected
transport: Optional[Transport]
# The socket that is used internally by the transport object
socket: Optional[ssl.SSLSocket]
# The backoff of the connect routine (increases exponentially
# after each failure)
_connect_loop_wait: float
parser: Optional[ET.XMLPullParser]
xml_depth: int
xml_root: Optional[ET.Element]
force_starttls: Optional[bool]
disable_starttls: Optional[bool]
waiting_queue: asyncio.Queue
# A dict of {name: handle}
scheduled_events: Dict[str, TimerHandle]
ssl_context: ssl.SSLContext
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
event_when_connected: str
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
ciphers: Optional[str]
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
ca_certs: Optional[str]
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
certfile: Optional[str]
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
keyfile: Optional[str]
# The asyncio event loop
_loop: Optional[AbstractEventLoop]
#: The default port to return when querying DNS records.
default_port: int
#: The domain to try when querying DNS records.
default_domain: str
#: The expected name of the server, for validation.
_expected_server_name: str
_service_name: str
#: The desired, or actual, address of the connected server.
address: Tuple[str, int]
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
use_ssl: bool
#: If set to ``True``, attempt to use IPv6.
use_ipv6: bool
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
use_aiodns: bool
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
use_cdata: bool
#: The default namespace of the stream content, not of the
#: stream wrapper it
default_ns: str
default_lang: Optional[str]
peer_default_lang: Optional[str]
#: The namespace of the enveloping stream element.
stream_ns: str
#: The default opening tag for the stream element.
stream_header: str
#: The default closing tag for the stream element.
stream_footer: str
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
whitespace_keepalive: bool
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
whitespace_keepalive_interval: int
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
end_session_on_disconnect: bool
#: A mapping of XML namespaces to well-known prefixes.
namespace_map: dict
__root_stanza: List[Type[StanzaBase]]
__handlers: List[BaseHandler]
__event_handlers: Dict[str, List[Tuple[Handler, bool]]]
__filters: _FiltersDict
# Current connection attempt (Future)
_current_connection_attempt: Optional[Future]
#: A list of DNS results that have not yet been tried.
_dns_answers: Optional[Iterator[Tuple[str, str, int]]]
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
dns_service: Optional[str]
#: The reason why we are disconnecting from the server
disconnect_reason: Optional[str]
#: An asyncio Future being done when the stream is disconnected.
disconnected: Future
# If the session has been started or not
_session_started: bool
# If we want to bypass the send() check (e.g. unit tests)
_always_send_everything: bool
_run_out_filters: Optional[Future]
__slow_tasks: List[Task]
__queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]]
def __init__(self, host: str = '', port: int = 0):
self.transport = None
# The socket that is used internally by the transport object
self.socket = None
# The backoff of the connect routine (increases exponentially
# after each failure)
self._connect_loop_wait = 0
self.parser = None
@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.check_hostname = False
self.ssl_context.verify_mode = ssl.CERT_NONE
# The event to trigger when the create_connection() succeeds. It can
# be "connected" or "tls_success" depending on the step we are at.
self.event_when_connected = "connected"
#: The list of accepted ciphers, in OpenSSL Format.
#: It might be useful to override it for improved security
#: over the python defaults.
self.ciphers = None
#: Path to a file containing certificates for verifying the
#: server SSL certificate. A non-``None`` value will trigger
#: certificate checking.
#:
#: .. note::
#:
#: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file.
self.ca_certs = None
#: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also
#: be a corresponding `:attr:keyfile` value.
self.certfile = None
#: Path to a file containing the private key for the selected
#: client certificate to use for authenticating via SASL EXTERNAL.
self.keyfile = None
self._der_cert = None
# The asyncio event loop
self._loop = None
#: The default port to return when querying DNS records.
self.default_port = int(port)
#: The domain to try when querying DNS records.
self.default_domain = ''
#: The expected name of the server, for validation.
self._expected_server_name = ''
self._service_name = ''
#: The desired, or actual, address of the connected server.
self.address = (host, int(port))
#: Enable connecting to the server directly over SSL, in
#: particular when the service provides two ports: one for
#: non-SSL traffic and another for SSL traffic.
self.use_ssl = False
#: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True
#: If set to ``True``, allow using the ``dnspython`` DNS library
#: if available. If set to ``False``, the builtin DNS resolver
#: will be used, even if ``dnspython`` is installed.
self.use_aiodns = True
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
self.use_cdata = False
#: The default namespace of the stream content, not of the
#: stream wrapper itself.
self.default_ns = ''
self.default_lang = None
self.peer_default_lang = None
#: The namespace of the enveloping stream element.
self.stream_ns = ''
#: The default opening tag for the stream element.
self.stream_header = "<stream>"
#: The default closing tag for the stream element.
self.stream_footer = "</stream>"
#: If ``True``, periodically send a whitespace character over the
#: wire to keep the connection alive. Mainly useful for connections
#: traversing NAT.
self.whitespace_keepalive = True
#: The default interval between keepalive signals when
#: :attr:`whitespace_keepalive` is enabled.
self.whitespace_keepalive_interval = 300
#: Flag for controlling if the session can be considered ended
#: if the connection is terminated.
self.end_session_on_disconnect = True
#: A mapping of XML namespaces to well-known prefixes.
self.namespace_map = {StanzaBase.xml_ns: 'xml'}
self.__root_stanza = []
self.__handlers = []
self.__event_handlers = {}
self.__filters = {'in': [], 'out': [], 'out_sync': []}
self.__filters = {
'in': [], 'out': [], 'out_sync': []
}
# Current connection attempt (Future)
self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried.
self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
self._dns_answers = None
self.dns_service = None
#: The reason why we are disconnecting from the server
self.disconnect_reason = None
#: An asyncio Future being done when the stream is disconnected.
self.disconnected: Future = Future()
# If the session has been started or not
self.disconnected = Future()
self._session_started = False
# If we want to bypass the send() check (e.g. unit tests)
self._always_send_everything = False
self.add_event_handler('disconnected', self._remove_schedules)
@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('session_start', self._set_session_start)
self.add_event_handler('session_resumed', self._set_session_start)
self._run_out_filters: Optional[Future] = None
self.__slow_tasks: List[Future] = []
self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = []
self._run_out_filters = None
self.__slow_tasks = []
self.__queued_stanzas = []
@property
def loop(self):
def loop(self) -> AbstractEventLoop:
if self._loop is None:
self._loop = asyncio.get_event_loop()
return self._loop
@loop.setter
def loop(self, value):
def loop(self, value: AbstractEventLoop) -> None:
self._loop = value
def new_id(self):
def new_id(self) -> str:
"""Generate and return a new stream ID in hexadecimal form.
Many stanzas, handlers, or matchers may require unique
@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return uuid.uuid4().hex
def _set_session_start(self, event):
def _set_session_start(self, event: Any) -> None:
"""
On session start, queue all pending stanzas to be sent.
"""
@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol):
self.waiting_queue.put_nowait(stanza)
self.__queued_stanzas = []
def _set_disconnected(self, event):
def _set_disconnected(self, event: Any) -> None:
self._session_started = False
def _set_disconnected_future(self):
def _set_disconnected_future(self) -> None:
"""Set the self.disconnected future on disconnect"""
if not self.disconnected.done():
self.disconnected.set_result(True)
self.disconnected = asyncio.Future()
def connect(self, host='', port=0, use_ssl=False,
force_starttls=True, disable_starttls=False):
def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False,
force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None:
"""Create a new socket and connect to the server.
:param host: The name of the desired server for the connection.
@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
async def _connect_routine(self):
async def _connect_routine(self) -> None:
self.event_when_connected = "connected"
if self._connect_loop_wait > 0:
@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol):
# and try (host, port) as a last resort
self._dns_answers = None
ssl_context: Optional[ssl.SSLContext]
if self.use_ssl:
ssl_context = self.get_ssl_context()
else:
@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop,
)
def process(self, *, forever=True, timeout=None):
def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None:
"""Process all the available XMPP events (receiving or sending data on the
socket(s), calling various registered callbacks, calling expired
timers, handling signal events, etc). If timeout is None, this
@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.loop.run_until_complete(self.disconnected)
else:
tasks = [asyncio.sleep(timeout, loop=self.loop)]
tasks: List[Future] = [asyncio.sleep(timeout, loop=self.loop)]
if not forever:
tasks.append(self.disconnected)
self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop))
def init_parser(self):
def init_parser(self) -> None:
"""init the XML parser. The parser must always be reset for each new
connexion
"""
@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol):
self.xml_root = None
self.parser = ET.XMLPullParser(("start", "end"))
def connection_made(self, transport):
def connection_made(self, transport: BaseTransport) -> None:
"""Called when the TCP connection has been established with the server
"""
self.event(self.event_when_connected)
self.transport = transport
self.transport = cast(Transport, transport)
if self.transport is None:
raise ValueError("Transport cannot be none")
self.socket = self.transport.get_extra_info(
"ssl_object",
default=self.transport.get_extra_info("socket")
@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol):
self.send_raw(self.stream_header)
self._dns_answers = None
def data_received(self, data):
def data_received(self, data: bytes) -> None:
"""Called when incoming data is received on the socket.
We feed that data to the parser and the see if this produced any XML
@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol):
self.send(error)
self.disconnect()
def is_connecting(self):
def is_connecting(self) -> bool:
return self._current_connection_attempt is not None
def is_connected(self):
def is_connected(self) -> bool:
return self.transport is not None
def eof_received(self):
def eof_received(self) -> None:
"""When the TCP connection is properly closed by the remote end
"""
self.event("eof_received")
def connection_lost(self, exception):
def connection_lost(self, exception: Optional[BaseException]) -> None:
"""On any kind of disconnection, initiated by us or not. This signals the
closure of the TCP connection
"""
@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol):
self._reset_sendq()
self.event('session_end')
self._set_disconnected_future()
self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
self.event("disconnected", self.disconnect_reason or exception)
def cancel_connection_attempt(self):
def cancel_connection_attempt(self) -> None:
"""
Immediately cancel the current create_connection() Future.
This is useful when a client using slixmpp tries to connect
@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol):
# `disconnect(wait=True)` for ages. This doesn't mean anything to the
# schedule call below. It would fortunately be converted to `1` later
# down the call chain. Praise the implicit casts lord.
if wait == True:
if wait is True:
wait = 2.0
if self.transport:
@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self._set_disconnected_future()
self.event("disconnected", reason)
future = Future()
future: Future = Future()
future.set_result(None)
return future
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float):
async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None:
"""Wait until the send queue is empty before disconnecting"""
try:
await asyncio.wait_for(
@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol):
self.disconnect_reason = reason
await self._end_stream_wait(wait)
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None):
async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None:
"""
Run abort() if we do not received the disconnected event
after a waiting time.
@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol):
# that means the disconnect has already been handled
pass
def abort(self):
def abort(self) -> None:
"""
Forcibly close the connection
"""
@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol):
self.transport.abort()
self.event("killed")
def reconnect(self, wait=2.0, reason="Reconnecting"):
def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None:
"""Calls disconnect(), and once we are disconnected (after the timeout, or
when the server acknowledgement is received), call connect()
"""
log.debug("reconnecting...")
async def handler(event):
async def handler(event: Any) -> None:
# We yield here to allow synchronous handlers to work first
await asyncio.sleep(0, loop=self.loop)
self.connect()
self.add_event_handler('disconnected', handler, disposable=True)
self.disconnect(wait, reason)
def configure_socket(self):
def configure_socket(self) -> None:
"""Set timeout and other options for self.socket.
Meant to be overridden.
"""
pass
def configure_dns(self, resolver, domain=None, port=None):
def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None:
"""
Configure and set options for a :class:`~dns.resolver.Resolver`
instance, and other DNS related tasks. For example, you
@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
def get_ssl_context(self):
def get_ssl_context(self) -> ssl.SSLContext:
"""
Get SSL context.
"""
@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol):
return self.ssl_context
async def start_tls(self):
async def start_tls(self) -> bool:
"""Perform handshakes for TLS.
If the handshake is successful, the XML stream will need
to be restarted.
"""
if self.transport is None:
raise ValueError("Transport should not be None")
self.event_when_connected = "tls_success"
ssl_context = self.get_ssl_context()
try:
@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol):
self.connection_made(transp)
return True
def _start_keepalive(self, event):
def _start_keepalive(self, event: Any) -> None:
"""Begin sending whitespace periodically to keep the connection alive.
May be disabled by setting::
@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol):
args=(' ',),
repeat=True)
def _remove_schedules(self, event):
def _remove_schedules(self, event: Any) -> None:
"""Remove some schedules that become pointless when disconnected"""
self.cancel_schedule('Whitespace Keepalive')
def start_stream_handler(self, xml):
def start_stream_handler(self, xml: ET.Element) -> None:
"""Perform any initialization actions, such as handshakes,
once the stream header has been sent.
@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
def register_stanza(self, stanza_class):
def register_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Add a stanza object class as a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.append(stanza_class)
def remove_stanza(self, stanza_class):
def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None:
"""Remove a stanza from being a known root stanza.
A root stanza is one that appears as a direct child of the stream's
@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
self.__root_stanza.remove(stanza_class)
def add_filter(self, mode, handler, order=None):
def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None:
"""Add a filter for incoming or outgoing stanzas.
These filters are applied before incoming stanzas are
@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.__filters[mode].append(handler)
def del_filter(self, mode, handler):
def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None:
"""Remove an incoming or outgoing filter."""
self.__filters[mode].remove(handler)
def register_handler(self, handler, before=None, after=None):
def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None:
"""Add a stream event handler that will be executed when a matching
stanza is received.
@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__handlers.append(handler)
handler.stream = weakref.ref(self)
def remove_handler(self, name):
def remove_handler(self, name: str) -> bool:
"""Remove any stream event handlers with the given name.
:param name: The name of the handler.
@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol):
try:
return next(self._dns_answers)
except StopIteration:
return
return None
def add_event_handler(self, name, pointer, disposable=False):
def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None:
"""Add a custom event handler that will be executed whenever
its event is manually triggered.
@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol):
self.__event_handlers[name] = []
self.__event_handlers[name].append((pointer, disposable))
def del_event_handler(self, name, pointer):
def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None:
"""Remove a function as a handler for an event.
:param name: The name of the event.
@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol):
# Need to keep handlers that do not use
# the given function pointer
def filter_pointers(handler):
def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool:
return handler[0] != pointer
self.__event_handlers[name] = list(filter(
filter_pointers,
self.__event_handlers[name]))
def event_handled(self, name):
def event_handled(self, name: str) -> int:
"""Returns the number of registered handlers for an event.
:param name: The name of the event to check.
"""
return len(self.__event_handlers.get(name, []))
async def event_async(self, name: str, data: Any = {}):
async def event_async(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event, but await coroutines immediately.
This event generator should only be called in situations when
@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol):
except Exception as e:
self.exception(e)
def event(self, name: str, data: Any = {}):
def event(self, name: str, data: Any = {}) -> None:
"""Manually trigger a custom event.
Coroutine handlers are wrapped into a future and sent into the
event loop for their execution, and not awaited.
@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol):
# If the callback is a coroutine, schedule it instead of
# running it directly
if iscoroutinefunction(handler_callback):
async def handler_callback_routine(cb):
async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None:
try:
await cb(data)
except Exception as e:
@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol):
except ValueError:
pass
def schedule(self, name, seconds, callback, args=tuple(),
kwargs={}, repeat=False):
def schedule(self, name: str, seconds: int, callback: Callable[..., None],
args: Tuple[Any, ...] = tuple(),
kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None:
"""Schedule a callback function to execute after a given delay.
:param name: A unique name for the scheduled callback.
@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol):
# canceling scheduled_events[name]
self.scheduled_events[name] = handle
def cancel_schedule(self, name):
def cancel_schedule(self, name: str) -> None:
try:
handle = self.scheduled_events.pop(name)
handle.cancel()
except KeyError:
log.debug("Tried to cancel unscheduled event: %s" % (name,))
def _safe_cb_run(self, name, cb):
def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None:
log.debug('Scheduled event: %s', name)
try:
cb()
except Exception as e:
self.exception(e)
def _execute_and_reschedule(self, name, cb, seconds):
def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None:
"""Simple method that calls the given callback, and then schedule itself to
be called after the given number of seconds.
"""
@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol):
name, cb, seconds)
self.scheduled_events[name] = handle
def _execute_and_unschedule(self, name, cb):
def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None:
"""
Execute the callback and remove the handler for it.
"""
@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol):
if name in self.scheduled_events:
del self.scheduled_events[name]
def incoming_filter(self, xml):
def incoming_filter(self, xml: ET.Element) -> ET.Element:
"""Filter incoming XML objects before they are processed.
Possible uses include remapping namespaces, or correcting elements
@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
return xml
def _reset_sendq(self):
def _reset_sendq(self) -> None:
"""Clear sending tasks on session end"""
# Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
@ -1043,7 +1166,7 @@ class XMLStream(asyncio.BaseProtocol):
async def _continue_slow_send(
self,
task: asyncio.Task,
already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
already_used: Set[Filter]
) -> None:
"""
Used when an item in the send queue has taken too long to process.
@ -1060,14 +1183,16 @@ class XMLStream(asyncio.BaseProtocol):
if filter in already_used:
continue
if iscoroutinefunction(filter):
data = await filter(data)
data = await filter(data) # type: ignore
else:
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
if isinstance(data, ElementBase):
if isinstance(data, StanzaBase):
for filter in self.__filters['out_sync']:
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
return
@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol):
else:
self.send_raw(data)
async def run_filters(self):
async def run_filters(self) -> NoReturn:
"""
Background loop that processes stanzas to send.
"""
while True:
data: Optional[Union[StanzaBase, str]]
(data, use_filters) = await self.waiting_queue.get()
try:
if isinstance(data, ElementBase):
if isinstance(data, StanzaBase):
if use_filters:
already_run_filters = set()
for filter in self.__filters['out']:
already_run_filters.add(filter)
if iscoroutinefunction(filter):
filter = cast(AsyncFilter, filter)
task = asyncio.create_task(filter(data))
completed, pending = await wait(
{task},
@ -1108,21 +1235,26 @@ class XMLStream(asyncio.BaseProtocol):
"Slow coroutine, rescheduling filters"
)
data = task.result()
else:
elif isinstance(data, StanzaBase):
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
if isinstance(data, ElementBase):
if isinstance(data, StanzaBase):
if use_filters:
for filter in self.__filters['out_sync']:
filter = cast(SyncFilter, filter)
data = filter(data)
if data is None:
raise ContinueQueue('Empty stanza')
str_data = tostring(data.xml, xmlns=self.default_ns,
stream=self, top_level=True)
if isinstance(data, StanzaBase):
str_data = tostring(data.xml, xmlns=self.default_ns,
stream=self, top_level=True)
else:
str_data = data
self.send_raw(str_data)
else:
elif isinstance(data, (str, bytes)):
self.send_raw(data)
except ContinueQueue as exc:
log.debug('Stanza in send queue not sent: %s', exc)
@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol):
log.error('Exception raised in send queue:', exc_info=True)
self.waiting_queue.task_done()
def send(self, data, use_filters=True):
def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None:
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
:param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase`
stanza to send on the stream.
:param bool use_filters: Indicates if outgoing filters should be
applied to the given stanza data. Disabling
@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol):
return
self.waiting_queue.put_nowait((data, use_filters))
def send_xml(self, data):
def send_xml(self, data: ET.Element) -> None:
"""Send an XML object on the stream
:param data: The :class:`~xml.etree.ElementTree.Element` XML object
to send on the stream.
"""
return self.send(tostring(data))
self.send(tostring(data))
def send_raw(self, data):
def send_raw(self, data: Union[str, bytes]) -> None:
"""Send raw data across the stream.
:param string data: Any bytes or utf-8 string value.
@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol):
data = data.encode('utf-8')
self.transport.write(data)
def _build_stanza(self, xml, default_ns=None):
def _build_stanza(self, xml: ET.Element,
default_ns: Optional[str] = None) -> StanzaBase:
"""Create a stanza object from a given XML object.
If a specialized stanza type is not found for the XML, then
@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol):
stanza['lang'] = self.peer_default_lang
return stanza
def _spawn_event(self, xml):
def _spawn_event(self, xml: ET.Element) -> None:
"""
Analyze incoming XML stanzas and convert them into stanza
objects if applicable and queue stream events to be processed
@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol):
# Convert the raw XML object into a stanza object. If no registered
# stanza type applies, a generic StanzaBase stanza will be used.
stanza = self._build_stanza(xml)
stanza: Optional[StanzaBase] = self._build_stanza(xml)
for filter in self.__filters['in']:
if stanza is not None:
filter = cast(SyncFilter, filter)
stanza = filter(stanza)
if stanza is None:
return
@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol):
if not handled:
stanza.unhandled()
def exception(self, exception):
def exception(self, exception: Exception) -> None:
"""Process an unknown exception.
Meant to be overridden.
@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol):
"""
pass
async def wait_until(self, event: str, timeout=30) -> Any:
async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any:
"""Utility method to wake on the next firing of an event.
(Registers a disposable handler on it)
@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol):
:param int timeout: Timeout
:raises: :class:`asyncio.TimeoutError` when the timeout is reached
"""
fut = asyncio.Future()
fut: Future = asyncio.Future()
def result_handler(event_data):
def result_handler(event_data: Any) -> None:
if not fut.done():
fut.set_result(event_data)
else:
@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol):
return await asyncio.wait_for(fut, timeout)
@contextmanager
def event_handler(self, event: str, handler: Callable):
def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]:
"""
Context manager that adds then removes an event handler.
"""
self.add_event_handler(event, handler)
try:
yield
except Exception as exc:
except Exception:
raise
finally:
self.del_event_handler(event, handler)
def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future:
def wrap(self, coroutine: Coroutine[None, None, T]) -> Future:
"""Make a Future out of a coroutine with the current loop.
:param coroutine: The coroutine to wrap.