Merge branch 'more-typing' into 'master'
Add more typing See merge request poezio/slixmpp!166
This commit is contained in:
		| @@ -1,7 +1,17 @@ | ||||
| stages: | ||||
|   - lint | ||||
|   - test | ||||
|   - trigger | ||||
|  | ||||
| mypy: | ||||
|   stage: lint | ||||
|   tags: | ||||
|     - docker | ||||
|   image: python:3 | ||||
|   script: | ||||
|     - pip3 install mypy | ||||
|     - mypy slixmpp | ||||
|  | ||||
| test: | ||||
|   stage: test | ||||
|   tags: | ||||
|   | ||||
							
								
								
									
										15
									
								
								mypy.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								mypy.ini
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| [mypy] | ||||
| check_untyped_defs = False | ||||
| ignore_missing_imports = True | ||||
|  | ||||
| [mypy-slixmpp.types] | ||||
| ignore_errors = True | ||||
|  | ||||
| [mypy-slixmpp.thirdparty.*] | ||||
| ignore_errors = True | ||||
|  | ||||
| [mypy-slixmpp.plugins.*] | ||||
| ignore_errors = True | ||||
|  | ||||
| [mypy-slixmpp.plugins.base] | ||||
| ignore_errors = False | ||||
							
								
								
									
										1
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								setup.py
									
									
									
									
									
								
							| @@ -83,6 +83,7 @@ setup( | ||||
|     url='https://lab.louiz.org/poezio/slixmpp', | ||||
|     license='MIT', | ||||
|     platforms=['any'], | ||||
|     package_data={'slixmpp': ['py.typed']}, | ||||
|     packages=packages, | ||||
|     ext_modules=ext_modules, | ||||
|     install_requires=['aiodns>=1.0', 'pyasn1', 'pyasn1_modules', 'typing_extensions; python_version < "3.8.0"'], | ||||
|   | ||||
| @@ -19,7 +19,6 @@ from slixmpp.xmlstream.stanzabase import ET, ElementBase, register_stanza_plugin | ||||
| from slixmpp.xmlstream.handler import * | ||||
| from slixmpp.xmlstream import XMLStream | ||||
| from slixmpp.xmlstream.matcher import * | ||||
| from slixmpp.xmlstream.asyncio import asyncio, future_wrapper | ||||
| from slixmpp.basexmpp import BaseXMPP | ||||
| from slixmpp.clientxmpp import ClientXMPP | ||||
| from slixmpp.componentxmpp import ComponentXMPP | ||||
|   | ||||
| @@ -21,7 +21,7 @@ class APIWrapper(object): | ||||
|         if name not in self.api.settings: | ||||
|             self.api.settings[name] = {} | ||||
|  | ||||
|     def __getattr__(self, attr): | ||||
|     def __getattr__(self, attr: str): | ||||
|         """Curry API management commands with the API name.""" | ||||
|         if attr == 'name': | ||||
|             return self.name | ||||
| @@ -33,13 +33,13 @@ class APIWrapper(object): | ||||
|                 return register(handler, self.name, op, jid, node, default) | ||||
|             return partial | ||||
|         elif attr == 'register_default': | ||||
|             def partial(handler, op, jid=None, node=None): | ||||
|             def partial1(handler, op, jid=None, node=None): | ||||
|                 return getattr(self.api, attr)(handler, self.name, op) | ||||
|             return partial | ||||
|             return partial1 | ||||
|         elif attr in ('run', 'restore_default', 'unregister'): | ||||
|             def partial(*args, **kwargs): | ||||
|             def partial2(*args, **kwargs): | ||||
|                 return getattr(self.api, attr)(self.name, *args, **kwargs) | ||||
|             return partial | ||||
|             return partial2 | ||||
|         return None | ||||
|  | ||||
|     def __getitem__(self, attr): | ||||
| @@ -82,7 +82,7 @@ class APIRegistry(object): | ||||
|         """Return a wrapper object that targets a specific API.""" | ||||
|         return APIWrapper(self, ctype) | ||||
|  | ||||
|     def purge(self, ctype: str): | ||||
|     def purge(self, ctype: str) -> None: | ||||
|         """Remove all information for a given API.""" | ||||
|         del self.settings[ctype] | ||||
|         del self._handler_defaults[ctype] | ||||
| @@ -131,22 +131,23 @@ class APIRegistry(object): | ||||
|             jid = JID(jid) | ||||
|         elif jid == JID(''): | ||||
|             jid = self.xmpp.boundjid | ||||
|         assert jid is not None | ||||
|  | ||||
|         if node is None: | ||||
|             node = '' | ||||
|  | ||||
|         if self.xmpp.is_component: | ||||
|             if self.settings[ctype].get('component_bare', False): | ||||
|                 jid = jid.bare | ||||
|                 jid_str = jid.bare | ||||
|             else: | ||||
|                 jid = jid.full | ||||
|                 jid_str = jid.full | ||||
|         else: | ||||
|             if self.settings[ctype].get('client_bare', False): | ||||
|                 jid = jid.bare | ||||
|                 jid_str = jid.bare | ||||
|             else: | ||||
|                 jid = jid.full | ||||
|                 jid_str = jid.full | ||||
|  | ||||
|         jid = JID(jid) | ||||
|         jid = JID(jid_str) | ||||
|  | ||||
|         handler = self._handlers[ctype][op]['node'].get((jid, node), None) | ||||
|         if handler is None: | ||||
| @@ -167,8 +168,11 @@ class APIRegistry(object): | ||||
|                 # To preserve backward compatibility, drop the ifrom | ||||
|                 # parameter for existing handlers that don't understand it. | ||||
|                 return handler(jid, node, args) | ||||
|         future = Future() | ||||
|         future.set_result(None) | ||||
|         return future | ||||
|  | ||||
|     def register(self, handler: APIHandler, ctype: str, op: str, | ||||
|     def register(self, handler: Optional[APIHandler], ctype: str, op: str, | ||||
|                  jid: Optional[JID] = None, node: Optional[str] = None, | ||||
|                  default: bool = False): | ||||
|         """Register an API callback, with JID+node specificity. | ||||
|   | ||||
| @@ -45,10 +45,11 @@ log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| from slixmpp.types import ( | ||||
|     PresenceShows, | ||||
|     PresenceTypes, | ||||
|     MessageTypes, | ||||
|     IqTypes, | ||||
|     JidStr, | ||||
|     OptJidStr, | ||||
| ) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| @@ -263,9 +264,9 @@ class BaseXMPP(XMLStream): | ||||
|         if not pconfig: | ||||
|             pconfig = self.plugin_config.get(plugin, {}) | ||||
|  | ||||
|         if not self.plugin.registered(plugin): | ||||
|         if not self.plugin.registered(plugin):  # type: ignore | ||||
|             load_plugin(plugin, module) | ||||
|         self.plugin.enable(plugin, pconfig) | ||||
|         self.plugin.enable(plugin, pconfig)  # type: ignore | ||||
|  | ||||
|     def register_plugins(self): | ||||
|         """Register and initialize all built-in plugins. | ||||
| @@ -298,25 +299,25 @@ class BaseXMPP(XMLStream): | ||||
|         """Return a plugin given its name, if it has been registered.""" | ||||
|         return self.plugin.get(key, default) | ||||
|  | ||||
|     def Message(self, *args, **kwargs) -> Message: | ||||
|     def Message(self, *args, **kwargs) -> stanza.Message: | ||||
|         """Create a Message stanza associated with this stream.""" | ||||
|         msg = Message(self, *args, **kwargs) | ||||
|         msg['lang'] = self.default_lang | ||||
|         return msg | ||||
|  | ||||
|     def Iq(self, *args, **kwargs) -> Iq: | ||||
|     def Iq(self, *args, **kwargs) -> stanza.Iq: | ||||
|         """Create an Iq stanza associated with this stream.""" | ||||
|         return Iq(self, *args, **kwargs) | ||||
|  | ||||
|     def Presence(self, *args, **kwargs) -> Presence: | ||||
|     def Presence(self, *args, **kwargs) -> stanza.Presence: | ||||
|         """Create a Presence stanza associated with this stream.""" | ||||
|         pres = Presence(self, *args, **kwargs) | ||||
|         pres['lang'] = self.default_lang | ||||
|         return pres | ||||
|  | ||||
|     def make_iq(self, id: str = "0", ifrom: Optional[JID] = None, | ||||
|                 ito: Optional[JID] = None, itype: Optional[IqTypes] = None, | ||||
|                 iquery: Optional[str] = None) -> Iq: | ||||
|     def make_iq(self, id: str = "0", ifrom: OptJidStr = None, | ||||
|                 ito: OptJidStr = None, itype: Optional[IqTypes] = None, | ||||
|                 iquery: Optional[str] = None) -> stanza.Iq: | ||||
|         """Create a new :class:`~.Iq` stanza with a given Id and from JID. | ||||
|  | ||||
|         :param id: An ideally unique ID value for this stanza thread. | ||||
| @@ -339,8 +340,8 @@ class BaseXMPP(XMLStream): | ||||
|         return iq | ||||
|  | ||||
|     def make_iq_get(self, queryxmlns: Optional[str] =None, | ||||
|                     ito: Optional[JID] = None, ifrom: Optional[JID] = None, | ||||
|                     iq: Optional[Iq] = None) -> Iq: | ||||
|                     ito: OptJidStr = None, ifrom: OptJidStr = None, | ||||
|                     iq: Optional[stanza.Iq] = None) -> stanza.Iq: | ||||
|         """Create an :class:`~.Iq` stanza of type ``'get'``. | ||||
|  | ||||
|         Optionally, a query element may be added. | ||||
| @@ -364,8 +365,8 @@ class BaseXMPP(XMLStream): | ||||
|         return iq | ||||
|  | ||||
|     def make_iq_result(self, id: Optional[str] = None, | ||||
|                        ito: Optional[JID] = None, ifrom: Optional[JID] = None, | ||||
|                        iq: Optional[Iq] = None) -> Iq: | ||||
|                        ito: OptJidStr = None, ifrom: OptJidStr = None, | ||||
|                        iq: Optional[stanza.Iq] = None) -> stanza.Iq: | ||||
|         """ | ||||
|         Create an :class:`~.Iq` stanza of type | ||||
|         ``'result'`` with the given ID value. | ||||
| @@ -391,8 +392,8 @@ class BaseXMPP(XMLStream): | ||||
|         return iq | ||||
|  | ||||
|     def make_iq_set(self, sub: Optional[Union[ElementBase, ET.Element]] = None, | ||||
|                     ito: Optional[JID] = None, ifrom: Optional[JID] = None, | ||||
|                     iq: Optional[Iq] = None) -> Iq: | ||||
|                     ito: OptJidStr = None, ifrom: OptJidStr = None, | ||||
|                     iq: Optional[stanza.Iq] = None) -> stanza.Iq: | ||||
|         """ | ||||
|         Create an :class:`~.Iq` stanza of type ``'set'``. | ||||
|  | ||||
| @@ -414,7 +415,7 @@ class BaseXMPP(XMLStream): | ||||
|         if not iq: | ||||
|             iq = self.Iq() | ||||
|         iq['type'] = 'set' | ||||
|         if sub != None: | ||||
|         if sub is not None: | ||||
|             iq.append(sub) | ||||
|         if ito: | ||||
|             iq['to'] = ito | ||||
| @@ -453,9 +454,9 @@ class BaseXMPP(XMLStream): | ||||
|             iq['from'] = ifrom | ||||
|         return iq | ||||
|  | ||||
|     def make_iq_query(self, iq: Optional[Iq] = None, xmlns: str = '', | ||||
|                       ito: Optional[JID] = None, | ||||
|                       ifrom: Optional[JID] = None) -> Iq: | ||||
|     def make_iq_query(self, iq: Optional[stanza.Iq] = None, xmlns: str = '', | ||||
|                       ito: OptJidStr = None, | ||||
|                       ifrom: OptJidStr = None) -> stanza.Iq: | ||||
|         """ | ||||
|         Create or modify an :class:`~.Iq` stanza | ||||
|         to use the given query namespace. | ||||
| @@ -477,7 +478,7 @@ class BaseXMPP(XMLStream): | ||||
|             iq['from'] = ifrom | ||||
|         return iq | ||||
|  | ||||
|     def make_query_roster(self, iq: Optional[Iq] = None) -> ET.Element: | ||||
|     def make_query_roster(self, iq: Optional[stanza.Iq] = None) -> ET.Element: | ||||
|         """Create a roster query element. | ||||
|  | ||||
|         :param iq: Optionally use an existing stanza instead | ||||
| @@ -487,11 +488,11 @@ class BaseXMPP(XMLStream): | ||||
|             iq['query'] = 'jabber:iq:roster' | ||||
|         return ET.Element("{jabber:iq:roster}query") | ||||
|  | ||||
|     def make_message(self, mto: JID, mbody: Optional[str] = None, | ||||
|     def make_message(self, mto: JidStr, mbody: Optional[str] = None, | ||||
|                      msubject: Optional[str] = None, | ||||
|                      mtype: Optional[MessageTypes] = None, | ||||
|                      mhtml: Optional[str] = None, mfrom: Optional[JID] = None, | ||||
|                      mnick: Optional[str] = None) -> Message: | ||||
|                      mhtml: Optional[str] = None, mfrom: OptJidStr = None, | ||||
|                      mnick: Optional[str] = None) -> stanza.Message: | ||||
|         """ | ||||
|         Create and initialize a new | ||||
|         :class:`~.Message` stanza. | ||||
| @@ -516,13 +517,13 @@ class BaseXMPP(XMLStream): | ||||
|             message['html']['body'] = mhtml | ||||
|         return message | ||||
|  | ||||
|     def make_presence(self, pshow: Optional[PresenceShows] = None, | ||||
|     def make_presence(self, pshow: Optional[str] = None, | ||||
|                       pstatus: Optional[str] = None, | ||||
|                       ppriority: Optional[int] = None, | ||||
|                       pto: Optional[JID] = None, | ||||
|                       pto: OptJidStr = None, | ||||
|                       ptype: Optional[PresenceTypes] = None, | ||||
|                       pfrom: Optional[JID] = None, | ||||
|                       pnick: Optional[str] = None) -> Presence: | ||||
|                       pfrom: OptJidStr = None, | ||||
|                       pnick: Optional[str] = None) -> stanza.Presence: | ||||
|         """ | ||||
|         Create and initialize a new | ||||
|         :class:`~.Presence` stanza. | ||||
| @@ -548,7 +549,7 @@ class BaseXMPP(XMLStream): | ||||
|     def send_message(self, mto: JID, mbody: Optional[str] = None, | ||||
|                      msubject: Optional[str] = None, | ||||
|                      mtype: Optional[MessageTypes] = None, | ||||
|                      mhtml: Optional[str] = None, mfrom: Optional[JID] = None, | ||||
|                      mhtml: Optional[str] = None, mfrom: OptJidStr = None, | ||||
|                      mnick: Optional[str] = None): | ||||
|         """ | ||||
|         Create, initialize, and send a new | ||||
| @@ -568,12 +569,12 @@ class BaseXMPP(XMLStream): | ||||
|         self.make_message(mto, mbody, msubject, mtype, | ||||
|                           mhtml, mfrom, mnick).send() | ||||
|  | ||||
|     def send_presence(self, pshow: Optional[PresenceShows] = None, | ||||
|     def send_presence(self, pshow: Optional[str] = None, | ||||
|                       pstatus: Optional[str] = None, | ||||
|                       ppriority: Optional[int] = None, | ||||
|                       pto: Optional[JID] = None, | ||||
|                       pto: OptJidStr = None, | ||||
|                       ptype: Optional[PresenceTypes] = None, | ||||
|                       pfrom: Optional[JID] = None, | ||||
|                       pfrom: OptJidStr = None, | ||||
|                       pnick: Optional[str] = None): | ||||
|         """ | ||||
|         Create, initialize, and send a new | ||||
| @@ -590,8 +591,9 @@ class BaseXMPP(XMLStream): | ||||
|         self.make_presence(pshow, pstatus, ppriority, pto, | ||||
|                            ptype, pfrom, pnick).send() | ||||
|  | ||||
|     def send_presence_subscription(self, pto, pfrom=None, | ||||
|                                    ptype='subscribe', pnick=None): | ||||
|     def send_presence_subscription(self, pto: JidStr, pfrom: OptJidStr = None, | ||||
|                                    ptype: PresenceTypes='subscribe', pnick: | ||||
|                                    Optional[str] = None): | ||||
|         """ | ||||
|         Create, initialize, and send a new | ||||
|         :class:`~.Presence` stanza of | ||||
| @@ -608,62 +610,62 @@ class BaseXMPP(XMLStream): | ||||
|                            pnick=pnick).send() | ||||
|  | ||||
|     @property | ||||
|     def jid(self): | ||||
|     def jid(self) -> str: | ||||
|         """Attribute accessor for bare jid""" | ||||
|         log.warning("jid property deprecated. Use boundjid.bare") | ||||
|         return self.boundjid.bare | ||||
|  | ||||
|     @jid.setter | ||||
|     def jid(self, value): | ||||
|     def jid(self, value: str): | ||||
|         log.warning("jid property deprecated. Use boundjid.bare") | ||||
|         self.boundjid.bare = value | ||||
|  | ||||
|     @property | ||||
|     def fulljid(self): | ||||
|     def fulljid(self) -> str: | ||||
|         """Attribute accessor for full jid""" | ||||
|         log.warning("fulljid property deprecated. Use boundjid.full") | ||||
|         return self.boundjid.full | ||||
|  | ||||
|     @fulljid.setter | ||||
|     def fulljid(self, value): | ||||
|     def fulljid(self, value: str): | ||||
|         log.warning("fulljid property deprecated. Use boundjid.full") | ||||
|         self.boundjid.full = value | ||||
|  | ||||
|     @property | ||||
|     def resource(self): | ||||
|     def resource(self) -> str: | ||||
|         """Attribute accessor for jid resource""" | ||||
|         log.warning("resource property deprecated. Use boundjid.resource") | ||||
|         return self.boundjid.resource | ||||
|  | ||||
|     @resource.setter | ||||
|     def resource(self, value): | ||||
|     def resource(self, value: str): | ||||
|         log.warning("fulljid property deprecated. Use boundjid.resource") | ||||
|         self.boundjid.resource = value | ||||
|  | ||||
|     @property | ||||
|     def username(self): | ||||
|     def username(self) -> str: | ||||
|         """Attribute accessor for jid usernode""" | ||||
|         log.warning("username property deprecated. Use boundjid.user") | ||||
|         return self.boundjid.user | ||||
|  | ||||
|     @username.setter | ||||
|     def username(self, value): | ||||
|     def username(self, value: str): | ||||
|         log.warning("username property deprecated. Use boundjid.user") | ||||
|         self.boundjid.user = value | ||||
|  | ||||
|     @property | ||||
|     def server(self): | ||||
|     def server(self) -> str: | ||||
|         """Attribute accessor for jid host""" | ||||
|         log.warning("server property deprecated. Use boundjid.host") | ||||
|         return self.boundjid.server | ||||
|  | ||||
|     @server.setter | ||||
|     def server(self, value): | ||||
|     def server(self, value: str): | ||||
|         log.warning("server property deprecated. Use boundjid.host") | ||||
|         self.boundjid.server = value | ||||
|  | ||||
|     @property | ||||
|     def auto_authorize(self): | ||||
|     def auto_authorize(self) -> Optional[bool]: | ||||
|         """Auto accept or deny subscription requests. | ||||
|  | ||||
|         If ``True``, auto accept subscription requests. | ||||
| @@ -673,11 +675,11 @@ class BaseXMPP(XMLStream): | ||||
|         return self.roster.auto_authorize | ||||
|  | ||||
|     @auto_authorize.setter | ||||
|     def auto_authorize(self, value): | ||||
|     def auto_authorize(self, value: Optional[bool]): | ||||
|         self.roster.auto_authorize = value | ||||
|  | ||||
|     @property | ||||
|     def auto_subscribe(self): | ||||
|     def auto_subscribe(self) -> bool: | ||||
|         """Auto send requests for mutual subscriptions. | ||||
|  | ||||
|         If ``True``, auto send mutual subscription requests. | ||||
| @@ -685,21 +687,21 @@ class BaseXMPP(XMLStream): | ||||
|         return self.roster.auto_subscribe | ||||
|  | ||||
|     @auto_subscribe.setter | ||||
|     def auto_subscribe(self, value): | ||||
|     def auto_subscribe(self, value: bool): | ||||
|         self.roster.auto_subscribe = value | ||||
|  | ||||
|     def set_jid(self, jid): | ||||
|     def set_jid(self, jid: JidStr): | ||||
|         """Rip a JID apart and claim it as our own.""" | ||||
|         log.debug("setting jid to %s", jid) | ||||
|         self.boundjid = JID(jid) | ||||
|  | ||||
|     def getjidresource(self, fulljid): | ||||
|     def getjidresource(self, fulljid: str): | ||||
|         if '/' in fulljid: | ||||
|             return fulljid.split('/', 1)[-1] | ||||
|         else: | ||||
|             return '' | ||||
|  | ||||
|     def getjidbare(self, fulljid): | ||||
|     def getjidbare(self, fulljid: str): | ||||
|         return fulljid.split('/', 1)[0] | ||||
|  | ||||
|     def _handle_session_start(self, event): | ||||
|   | ||||
| @@ -8,23 +8,18 @@ | ||||
| # :license: MIT, see LICENSE for more details | ||||
| import asyncio | ||||
| import logging | ||||
| from typing import Optional, Any, Callable, Tuple, Dict, Set, List | ||||
|  | ||||
| from slixmpp.jid import JID | ||||
| from slixmpp.stanza import StreamFeatures | ||||
| from slixmpp.stanza import StreamFeatures, Iq | ||||
| from slixmpp.basexmpp import BaseXMPP | ||||
| from slixmpp.exceptions import XMPPError | ||||
| from slixmpp.types import JidStr | ||||
| from slixmpp.xmlstream import XMLStream | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
| from slixmpp.xmlstream.matcher import StanzaPath, MatchXPath | ||||
| from slixmpp.xmlstream.handler import Callback, CoroutineCallback | ||||
|  | ||||
| # Flag indicating if DNS SRV records are available for use. | ||||
| try: | ||||
|     import dns.resolver | ||||
| except ImportError: | ||||
|     DNSPYTHON = False | ||||
| else: | ||||
|     DNSPYTHON = True | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -53,7 +48,7 @@ class ClientXMPP(BaseXMPP): | ||||
|     :param escape_quotes: **Deprecated.** | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, jid, password, plugin_config=None, | ||||
|     def __init__(self, jid: JidStr, password: str, plugin_config=None, | ||||
|                  plugin_whitelist=None, escape_quotes=True, sasl_mech=None, | ||||
|                  lang='en', **kwargs): | ||||
|         if not plugin_whitelist: | ||||
| @@ -69,7 +64,7 @@ class ClientXMPP(BaseXMPP): | ||||
|         self.default_port = 5222 | ||||
|         self.default_lang = lang | ||||
|  | ||||
|         self.credentials = {} | ||||
|         self.credentials: Dict[str, str] = {} | ||||
|  | ||||
|         self.password = password | ||||
|  | ||||
| @@ -81,9 +76,9 @@ class ClientXMPP(BaseXMPP): | ||||
|                 "version='1.0'") | ||||
|         self.stream_footer = "</stream:stream>" | ||||
|  | ||||
|         self.features = set() | ||||
|         self._stream_feature_handlers = {} | ||||
|         self._stream_feature_order = [] | ||||
|         self.features: Set[str] = set() | ||||
|         self._stream_feature_handlers: Dict[str, Tuple[Callable, bool]] = {} | ||||
|         self._stream_feature_order: List[Tuple[int, str]] = [] | ||||
|  | ||||
|         self.dns_service = 'xmpp-client' | ||||
|  | ||||
| @@ -100,10 +95,14 @@ class ClientXMPP(BaseXMPP): | ||||
|         self.register_stanza(StreamFeatures) | ||||
|  | ||||
|         self.register_handler( | ||||
|                 CoroutineCallback('Stream Features', | ||||
|                      MatchXPath('{%s}features' % self.stream_ns), | ||||
|                      self._handle_stream_features)) | ||||
|         def roster_push_filter(iq): | ||||
|             CoroutineCallback( | ||||
|                 'Stream Features', | ||||
|                 MatchXPath('{%s}features' % self.stream_ns), | ||||
|                 self._handle_stream_features,  # type: ignore | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         def roster_push_filter(iq: StanzaBase) -> None: | ||||
|             from_ = iq['from'] | ||||
|             if from_ and from_ != JID('') and from_ != self.boundjid.bare: | ||||
|                 reply = iq.reply() | ||||
| @@ -131,15 +130,16 @@ class ClientXMPP(BaseXMPP): | ||||
|             self['feature_mechanisms'].use_mech = sasl_mech | ||||
|  | ||||
|     @property | ||||
|     def password(self): | ||||
|     def password(self) -> str: | ||||
|         return self.credentials.get('password', '') | ||||
|  | ||||
|     @password.setter | ||||
|     def password(self, value): | ||||
|     def password(self, value: str) -> None: | ||||
|         self.credentials['password'] = value | ||||
|  | ||||
|     def connect(self, address=tuple(), use_ssl=False, | ||||
|                 force_starttls=True, disable_starttls=False): | ||||
|     def connect(self, address: Optional[Tuple[str, int]] = None,  # type: ignore | ||||
|                 use_ssl: bool = False, force_starttls: bool = True, | ||||
|                 disable_starttls: bool = False) -> None: | ||||
|         """Connect to the XMPP server. | ||||
|  | ||||
|         When no address is given, a SRV lookup for the server will | ||||
| @@ -161,14 +161,15 @@ class ClientXMPP(BaseXMPP): | ||||
|         # XMPP client port and allow SRV lookup. | ||||
|         if address: | ||||
|             self.dns_service = None | ||||
|             host, port = address | ||||
|         else: | ||||
|             address = (self.boundjid.host, 5222) | ||||
|             host, port = (self.boundjid.host, 5222) | ||||
|             self.dns_service = 'xmpp-client' | ||||
|  | ||||
|         return XMLStream.connect(self, address[0], address[1], use_ssl=use_ssl, | ||||
|         return XMLStream.connect(self, host, port, use_ssl=use_ssl, | ||||
|                                  force_starttls=force_starttls, disable_starttls=disable_starttls) | ||||
|  | ||||
|     def register_feature(self, name, handler, restart=False, order=5000): | ||||
|     def register_feature(self, name: str, handler: Callable, restart: bool = False, order: int = 5000) -> None: | ||||
|         """Register a stream feature handler. | ||||
|  | ||||
|         :param name: The name of the stream feature. | ||||
| @@ -183,13 +184,13 @@ class ClientXMPP(BaseXMPP): | ||||
|         self._stream_feature_order.append((order, name)) | ||||
|         self._stream_feature_order.sort() | ||||
|  | ||||
|     def unregister_feature(self, name, order): | ||||
|     def unregister_feature(self, name: str, order: int) -> None: | ||||
|         if name in self._stream_feature_handlers: | ||||
|             del self._stream_feature_handlers[name] | ||||
|         self._stream_feature_order.remove((order, name)) | ||||
|         self._stream_feature_order.sort() | ||||
|  | ||||
|     def update_roster(self, jid, **kwargs): | ||||
|     def update_roster(self, jid: JID, **kwargs) -> None: | ||||
|         """Add or change a roster item. | ||||
|  | ||||
|         :param jid: The JID of the entry to modify. | ||||
| @@ -251,7 +252,7 @@ class ClientXMPP(BaseXMPP): | ||||
|  | ||||
|         return iq.send(callback, timeout, timeout_callback) | ||||
|  | ||||
|     def _reset_connection_state(self, event=None): | ||||
|     def _reset_connection_state(self, event: Optional[Any] = None) -> None: | ||||
|         #TODO: Use stream state here | ||||
|         self.authenticated = False | ||||
|         self.sessionstarted = False | ||||
| @@ -259,7 +260,7 @@ class ClientXMPP(BaseXMPP): | ||||
|         self.bindfail = False | ||||
|         self.features = set() | ||||
|  | ||||
|     async def _handle_stream_features(self, features): | ||||
|     async def _handle_stream_features(self, features: StreamFeatures) -> Optional[bool]: | ||||
|         """Process the received stream features. | ||||
|  | ||||
|         :param features: The features stanza. | ||||
| @@ -277,8 +278,9 @@ class ClientXMPP(BaseXMPP): | ||||
|                     return True | ||||
|         log.debug('Finished processing stream features.') | ||||
|         self.event('stream_negotiated') | ||||
|         return None | ||||
|  | ||||
|     def _handle_roster(self, iq): | ||||
|     def _handle_roster(self, iq: Iq) -> None: | ||||
|         """Update the roster after receiving a roster stanza. | ||||
|  | ||||
|         :param iq: The roster stanza. | ||||
| @@ -310,7 +312,7 @@ class ClientXMPP(BaseXMPP): | ||||
|             resp.enable('roster') | ||||
|             resp.send() | ||||
|  | ||||
|     def _handle_session_bind(self, jid): | ||||
|     def _handle_session_bind(self, jid: JID) -> None: | ||||
|         """Set the client roster to the JID set by the server. | ||||
|  | ||||
|         :param :class:`slixmpp.xmlstream.jid.JID` jid: The bound JID as | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2011  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| @@ -11,6 +10,7 @@ from slixmpp.stanza import Iq, StreamFeatures | ||||
| from slixmpp.features.feature_bind import stanza | ||||
| from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.plugins import BasePlugin | ||||
| from typing import ClassVar, Set | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| @@ -20,7 +20,7 @@ class FeatureBind(BasePlugin): | ||||
|  | ||||
|     name = 'feature_bind' | ||||
|     description = 'RFC 6120: Stream Feature: Resource Binding' | ||||
|     dependencies = set() | ||||
|     dependencies: ClassVar[Set[str]] = set() | ||||
|     stanza = stanza | ||||
|  | ||||
|     def plugin_init(self): | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2011  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| @@ -15,6 +14,8 @@ from slixmpp.xmlstream.matcher import MatchXPath | ||||
| from slixmpp.xmlstream.handler import Callback | ||||
| from slixmpp.features.feature_mechanisms import stanza | ||||
|  | ||||
| from typing import ClassVar, Set | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -23,7 +24,7 @@ class FeatureMechanisms(BasePlugin): | ||||
|  | ||||
|     name = 'feature_mechanisms' | ||||
|     description = 'RFC 6120: Stream Feature: SASL' | ||||
|     dependencies = set() | ||||
|     dependencies: ClassVar[Set[str]] = set() | ||||
|     stanza = stanza | ||||
|     default_config = { | ||||
|         'use_mech': None, | ||||
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2011  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.xmlstream import StanzaBase | ||||
| from typing import ClassVar, Set | ||||
|  | ||||
|  | ||||
| class Abort(StanzaBase): | ||||
| @@ -13,7 +13,7 @@ class Abort(StanzaBase): | ||||
|  | ||||
|     name = 'abort' | ||||
|     namespace = 'urn:ietf:params:xml:ns:xmpp-sasl' | ||||
|     interfaces = set() | ||||
|     interfaces: ClassVar[Set[str]] = set() | ||||
|     plugin_attrib = name | ||||
|  | ||||
|     def setup(self, xml): | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2012  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| @@ -9,6 +8,7 @@ from slixmpp.stanza import StreamFeatures | ||||
| from slixmpp.features.feature_preapproval import stanza | ||||
| from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.plugins.base import BasePlugin | ||||
| from typing import ClassVar, Set | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| @@ -18,7 +18,7 @@ class FeaturePreApproval(BasePlugin): | ||||
|  | ||||
|     name = 'feature_preapproval' | ||||
|     description = 'RFC 6121: Stream Feature: Subscription Pre-Approval' | ||||
|     dependences = set() | ||||
|     dependencies: ClassVar[Set[str]] = set() | ||||
|     stanza = stanza | ||||
|  | ||||
|     def plugin_init(self): | ||||
|   | ||||
| @@ -1,14 +1,14 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2012  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.xmlstream import ElementBase | ||||
| from typing import ClassVar, Set | ||||
|  | ||||
|  | ||||
| class PreApproval(ElementBase): | ||||
|  | ||||
|     name = 'sub' | ||||
|     namespace = 'urn:xmpp:features:pre-approval' | ||||
|     interfaces = set() | ||||
|     interfaces: ClassVar[Set[str]] = set() | ||||
|     plugin_attrib = 'preapproval' | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2012  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| @@ -9,6 +8,7 @@ from slixmpp.stanza import StreamFeatures | ||||
| from slixmpp.features.feature_rosterver import stanza | ||||
| from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.plugins.base import BasePlugin | ||||
| from typing import ClassVar, Set | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| @@ -18,7 +18,7 @@ class FeatureRosterVer(BasePlugin): | ||||
|  | ||||
|     name = 'feature_rosterver' | ||||
|     description = 'RFC 6121: Stream Feature: Roster Versioning' | ||||
|     dependences = set() | ||||
|     dependences: ClassVar[Set[str]] = set() | ||||
|     stanza = stanza | ||||
|  | ||||
|     def plugin_init(self): | ||||
|   | ||||
| @@ -1,14 +1,14 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2012  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.xmlstream import ElementBase | ||||
| from typing import Set, ClassVar | ||||
|  | ||||
|  | ||||
| class RosterVer(ElementBase): | ||||
|  | ||||
|     name = 'ver' | ||||
|     namespace = 'urn:xmpp:features:rosterver' | ||||
|     interfaces = set() | ||||
|     interfaces: ClassVar[Set[str]] = set() | ||||
|     plugin_attrib = 'rosterver' | ||||
|   | ||||
| @@ -11,6 +11,7 @@ from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.plugins import BasePlugin | ||||
|  | ||||
| from slixmpp.features.feature_session import stanza | ||||
| from typing import  ClassVar, Set | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
| @@ -20,7 +21,7 @@ class FeatureSession(BasePlugin): | ||||
|  | ||||
|     name = 'feature_session' | ||||
|     description = 'RFC 3920: Stream Feature: Start Session' | ||||
|     dependencies = set() | ||||
|     dependencies: ClassVar[Set[str]] = set() | ||||
|     stanza = stanza | ||||
|  | ||||
|     def plugin_init(self): | ||||
|   | ||||
| @@ -4,39 +4,47 @@ | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.xmlstream import StanzaBase, ElementBase | ||||
| from typing import Set, ClassVar | ||||
|  | ||||
|  | ||||
| class STARTTLS(ElementBase): | ||||
|  | ||||
|     """ | ||||
| class STARTTLS(StanzaBase): | ||||
|     """ | ||||
|  | ||||
|     .. code-block:: xml | ||||
|  | ||||
|          <starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/> | ||||
|  | ||||
|     """ | ||||
|     name = 'starttls' | ||||
|     namespace = 'urn:ietf:params:xml:ns:xmpp-tls' | ||||
|     interfaces = {'required'} | ||||
|     plugin_attrib = name | ||||
|  | ||||
|     def get_required(self): | ||||
|         """ | ||||
|         """ | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class Proceed(StanzaBase): | ||||
|  | ||||
|     """ | ||||
|     """ | ||||
|  | ||||
|     .. code-block:: xml | ||||
|  | ||||
|         <proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/> | ||||
|  | ||||
|     """ | ||||
|     name = 'proceed' | ||||
|     namespace = 'urn:ietf:params:xml:ns:xmpp-tls' | ||||
|     interfaces = set() | ||||
|     interfaces: ClassVar[Set[str]] = set() | ||||
|  | ||||
|  | ||||
| class Failure(StanzaBase): | ||||
|  | ||||
|     """ | ||||
|     """ | ||||
|  | ||||
|     .. code-block:: xml | ||||
|  | ||||
|         <failure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/> | ||||
|  | ||||
|     """ | ||||
|     name = 'failure' | ||||
|     namespace = 'urn:ietf:params:xml:ns:xmpp-tls' | ||||
|     interfaces = set() | ||||
|     interfaces: ClassVar[Set[str]] = set() | ||||
|   | ||||
| @@ -12,6 +12,8 @@ from slixmpp.xmlstream.matcher import MatchXPath | ||||
| from slixmpp.xmlstream.handler import CoroutineCallback | ||||
| from slixmpp.features.feature_starttls import stanza | ||||
|  | ||||
| from typing import ClassVar, Set | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -20,7 +22,7 @@ class FeatureSTARTTLS(BasePlugin): | ||||
|  | ||||
|     name = 'feature_starttls' | ||||
|     description = 'RFC 6120: Stream Feature: STARTTLS' | ||||
|     dependencies = set() | ||||
|     dependencies: ClassVar[Set[str]] = set() | ||||
|     stanza = stanza | ||||
|  | ||||
|     def plugin_init(self): | ||||
| @@ -52,7 +54,7 @@ class FeatureSTARTTLS(BasePlugin): | ||||
|         elif self.xmpp.disable_starttls: | ||||
|             return False | ||||
|         else: | ||||
|             self.xmpp.send(features['starttls']) | ||||
|             self.xmpp.send(stanza.STARTTLS()) | ||||
|             return True | ||||
|  | ||||
|     async def _handle_starttls_proceed(self, proceed): | ||||
|   | ||||
| @@ -350,36 +350,10 @@ class JID: | ||||
|                       if self._resource | ||||
|                       else self._bare) | ||||
|  | ||||
|     @property | ||||
|     def node(self) -> str: | ||||
|         return self._node | ||||
|  | ||||
|     @property | ||||
|     def domain(self) -> str: | ||||
|         return self._domain | ||||
|  | ||||
|     @property | ||||
|     def resource(self) -> str: | ||||
|         return self._resource | ||||
|  | ||||
|     @property | ||||
|     def bare(self) -> str: | ||||
|         return self._bare | ||||
|  | ||||
|     @property | ||||
|     def full(self) -> str: | ||||
|         return self._full | ||||
|  | ||||
|     @node.setter | ||||
|     def node(self, value: str): | ||||
|         self._node = _validate_node(value) | ||||
|         self._update_bare_full() | ||||
|  | ||||
|     @domain.setter | ||||
|     def domain(self, value: str): | ||||
|         self._domain = _validate_domain(value) | ||||
|         self._update_bare_full() | ||||
|  | ||||
|     @bare.setter | ||||
|     def bare(self, value: str): | ||||
|         node, domain, resource = _parse_jid(value) | ||||
| @@ -388,11 +362,38 @@ class JID: | ||||
|         self._domain = domain | ||||
|         self._update_bare_full() | ||||
|  | ||||
|  | ||||
|     @property | ||||
|     def node(self) -> str: | ||||
|         return self._node | ||||
|  | ||||
|     @node.setter | ||||
|     def node(self, value: str): | ||||
|         self._node = _validate_node(value) | ||||
|         self._update_bare_full() | ||||
|  | ||||
|     @property | ||||
|     def domain(self) -> str: | ||||
|         return self._domain | ||||
|  | ||||
|     @domain.setter | ||||
|     def domain(self, value: str): | ||||
|         self._domain = _validate_domain(value) | ||||
|         self._update_bare_full() | ||||
|  | ||||
|     @property | ||||
|     def resource(self) -> str: | ||||
|         return self._resource | ||||
|  | ||||
|     @resource.setter | ||||
|     def resource(self, value: str): | ||||
|         self._resource = _validate_resource(value) | ||||
|         self._update_bare_full() | ||||
|  | ||||
|     @property | ||||
|     def full(self) -> str: | ||||
|         return self._full | ||||
|  | ||||
|     @full.setter | ||||
|     def full(self, value: str): | ||||
|         self._node, self._domain, self._resource = _parse_jid(value) | ||||
|   | ||||
| @@ -12,6 +12,8 @@ import copy | ||||
| import logging | ||||
| import threading | ||||
|  | ||||
| from typing import Any, Dict, Set, ClassVar | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -250,17 +252,17 @@ class BasePlugin(object): | ||||
|  | ||||
|     #: A short name for the plugin based on the implemented specification. | ||||
|     #: For example, a plugin for XEP-0030 would use `'xep_0030'`. | ||||
|     name = '' | ||||
|     name: str = '' | ||||
|  | ||||
|     #: A longer name for the plugin, describing its purpose. For example, | ||||
|     #: a plugin for XEP-0030 would use `'Service Discovery'` as its | ||||
|     #: description value. | ||||
|     description = '' | ||||
|     description: str = '' | ||||
|  | ||||
|     #: Some plugins may depend on others in order to function properly. | ||||
|     #: Any plugin names included in :attr:`~BasePlugin.dependencies` will | ||||
|     #: be initialized as needed if this plugin is enabled. | ||||
|     dependencies = set() | ||||
|     dependencies: ClassVar[Set[str]] = set() | ||||
|  | ||||
|     #: The basic, standard configuration for the plugin, which may | ||||
|     #: be overridden when initializing the plugin. The configuration | ||||
| @@ -268,7 +270,7 @@ class BasePlugin(object): | ||||
|     #: the plugin. For example, including the configuration field 'foo' | ||||
|     #: would mean accessing `plugin.foo` returns the current value of | ||||
|     #: `plugin.config['foo']`. | ||||
|     default_config = {} | ||||
|     default_config: ClassVar[Dict[str, Any]] = {} | ||||
|  | ||||
|     def __init__(self, xmpp, config=None): | ||||
|         self.xmpp = xmpp | ||||
|   | ||||
| @@ -11,11 +11,11 @@ from typing import ( | ||||
|     Optional | ||||
| ) | ||||
|  | ||||
| from slixmpp.plugins import BasePlugin, register_plugin | ||||
| from slixmpp import future_wrapper, JID | ||||
| from slixmpp.plugins import BasePlugin | ||||
| from slixmpp import JID | ||||
| from slixmpp.stanza import Iq | ||||
| from slixmpp.exceptions import XMPPError | ||||
| from slixmpp.xmlstream import JID, register_stanza_plugin | ||||
| from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.xmlstream.handler import CoroutineCallback | ||||
| from slixmpp.xmlstream.matcher import StanzaPath | ||||
| from slixmpp.plugins.xep_0012 import stanza, LastActivity | ||||
|   | ||||
| @@ -4,7 +4,6 @@ | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| import logging | ||||
| from asyncio import Future | ||||
| from typing import Optional | ||||
|  | ||||
| from slixmpp import JID | ||||
| @@ -15,7 +14,6 @@ from slixmpp.xmlstream.handler import CoroutineCallback | ||||
| from slixmpp.xmlstream.matcher import StanzaPath | ||||
| from slixmpp.plugins import BasePlugin | ||||
| from slixmpp.plugins.xep_0054 import VCardTemp, stanza | ||||
| from slixmpp import future_wrapper | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2015 Emmanuel Gil Peyrot | ||||
| # This file is part of Slixmpp. | ||||
| @@ -7,11 +6,10 @@ import asyncio | ||||
| import logging | ||||
| from uuid import uuid4 | ||||
|  | ||||
| from slixmpp.plugins import BasePlugin, register_plugin | ||||
| from slixmpp import future_wrapper, Iq, Message | ||||
| from slixmpp.exceptions import XMPPError, IqError, IqTimeout | ||||
| from slixmpp.plugins import BasePlugin | ||||
| from slixmpp import Iq, Message | ||||
| from slixmpp.jid import JID | ||||
| from slixmpp.xmlstream import JID, register_stanza_plugin | ||||
| from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.xmlstream.handler import Callback | ||||
| from slixmpp.xmlstream.matcher import StanzaPath | ||||
| from slixmpp.plugins.xep_0070 import stanza, Confirm | ||||
| @@ -52,7 +50,6 @@ class XEP_0070(BasePlugin): | ||||
|     def session_bind(self, jid): | ||||
|         self.xmpp['xep_0030'].add_feature('http://jabber.org/protocol/http-auth') | ||||
|  | ||||
|     @future_wrapper | ||||
|     def ask_confirm(self, jid, id, url, method, *, ifrom=None, message=None): | ||||
|         jid = JID(jid) | ||||
|         if jid.resource: | ||||
| @@ -70,7 +67,9 @@ class XEP_0070(BasePlugin): | ||||
|             if message is not None: | ||||
|                 stanza['body'] = message.format(id=id, url=url, method=method) | ||||
|             stanza.send() | ||||
|             return stanza | ||||
|             fut = asyncio.Future() | ||||
|             fut.set_result(stanza) | ||||
|             return fut | ||||
|         else: | ||||
|             return stanza.send() | ||||
|  | ||||
|   | ||||
| @@ -17,7 +17,6 @@ from slixmpp.exceptions import XMPPError, IqTimeout, IqError | ||||
| from slixmpp.xmlstream import register_stanza_plugin, ElementBase | ||||
| from slixmpp.plugins.base import BasePlugin | ||||
| from slixmpp.plugins.xep_0153 import stanza, VCardTempUpdate | ||||
| from slixmpp import future_wrapper | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|   | ||||
| @@ -3,10 +3,11 @@ | ||||
| # Copyright (C) 2011 Nathanael C. Fritz, Lance J.T. Stout | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| import asyncio | ||||
| import logging | ||||
|  | ||||
| from typing import Optional, Callable | ||||
| from slixmpp import asyncio, JID | ||||
| from slixmpp import JID | ||||
| from slixmpp.xmlstream import register_stanza_plugin, ElementBase | ||||
| from slixmpp.plugins.base import BasePlugin, register_plugin | ||||
| from slixmpp.plugins.xep_0004.stanza import Form | ||||
|   | ||||
| @@ -3,6 +3,7 @@ | ||||
| # Copyright (C) 2010 Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| import asyncio | ||||
| import time | ||||
| import logging | ||||
|  | ||||
| @@ -11,7 +12,6 @@ from typing import Optional, Callable, List | ||||
|  | ||||
| from slixmpp.jid import JID | ||||
| from slixmpp.stanza import Iq | ||||
| from slixmpp import asyncio | ||||
| from slixmpp.exceptions import IqError, IqTimeout | ||||
| from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.xmlstream.matcher import StanzaPath | ||||
|   | ||||
| @@ -9,14 +9,13 @@ import hashlib | ||||
| from asyncio import Future | ||||
| from typing import Optional | ||||
|  | ||||
| from slixmpp import future_wrapper, JID | ||||
| from slixmpp import JID | ||||
| from slixmpp.stanza import Iq, Message, Presence | ||||
| from slixmpp.exceptions import XMPPError | ||||
| from slixmpp.xmlstream.handler import CoroutineCallback | ||||
| from slixmpp.xmlstream.matcher import StanzaPath | ||||
| from slixmpp.xmlstream import register_stanza_plugin | ||||
| from slixmpp.plugins.base import BasePlugin | ||||
| from slixmpp.plugins.xep_0231 import stanza, BitsOfBinary | ||||
| from slixmpp.plugins.xep_0231 import BitsOfBinary | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|   | ||||
| @@ -5,10 +5,10 @@ | ||||
| # Copyright (C) 2013 Sustainable Innovation, Joachim.lindborg@sust.se, bjorn.westrom@consoden.se | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| import asyncio | ||||
| import logging | ||||
| import time | ||||
|  | ||||
| from slixmpp import asyncio | ||||
| from functools import partial | ||||
| from slixmpp.xmlstream import JID | ||||
| from slixmpp.xmlstream.handler import Callback | ||||
|   | ||||
| @@ -20,6 +20,7 @@ from slixmpp.plugins.xep_0030 import XEP_0030 | ||||
| from slixmpp.plugins.xep_0033 import XEP_0033 | ||||
| from slixmpp.plugins.xep_0045 import XEP_0045 | ||||
| from slixmpp.plugins.xep_0047 import XEP_0047 | ||||
| from slixmpp.plugins.xep_0048 import XEP_0048 | ||||
| from slixmpp.plugins.xep_0049 import XEP_0049 | ||||
| from slixmpp.plugins.xep_0050 import XEP_0050 | ||||
| from slixmpp.plugins.xep_0054 import XEP_0054 | ||||
| @@ -112,6 +113,7 @@ class PluginsDict(TypedDict): | ||||
|     xep_0033: XEP_0033 | ||||
|     xep_0045: XEP_0045 | ||||
|     xep_0047: XEP_0047 | ||||
|     xep_0048: XEP_0048 | ||||
|     xep_0049: XEP_0049 | ||||
|     xep_0050: XEP_0050 | ||||
|     xep_0054: XEP_0054 | ||||
|   | ||||
							
								
								
									
										0
									
								
								slixmpp/py.typed
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								slixmpp/py.typed
									
									
									
									
									
										Normal file
									
								
							| @@ -1,8 +1,9 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2010  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from __future__ import annotations | ||||
| from typing import Optional, Dict, Type, ClassVar | ||||
| from slixmpp.xmlstream import ElementBase, ET | ||||
|  | ||||
|  | ||||
| @@ -49,10 +50,10 @@ class Error(ElementBase): | ||||
|     name = 'error' | ||||
|     plugin_attrib = 'error' | ||||
|     interfaces = {'code', 'condition', 'text', 'type', | ||||
|                       'gone', 'redirect', 'by'} | ||||
|                   'gone', 'redirect', 'by'} | ||||
|     sub_interfaces = {'text'} | ||||
|     plugin_attrib_map = {} | ||||
|     plugin_tag_map = {} | ||||
|     plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {} | ||||
|     plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {} | ||||
|     conditions = {'bad-request', 'conflict', 'feature-not-implemented', | ||||
|                   'forbidden', 'gone', 'internal-server-error', | ||||
|                   'item-not-found', 'jid-malformed', 'not-acceptable', | ||||
| @@ -62,10 +63,10 @@ class Error(ElementBase): | ||||
|                   'remote-server-timeout', 'resource-constraint', | ||||
|                   'service-unavailable', 'subscription-required', | ||||
|                   'undefined-condition', 'unexpected-request'} | ||||
|     condition_ns = 'urn:ietf:params:xml:ns:xmpp-stanzas' | ||||
|     condition_ns: str = 'urn:ietf:params:xml:ns:xmpp-stanzas' | ||||
|     types = {'cancel', 'continue', 'modify', 'auth', 'wait'} | ||||
|  | ||||
|     def setup(self, xml=None): | ||||
|     def setup(self, xml: Optional[ET.Element] = None): | ||||
|         """ | ||||
|         Populate the stanza object using an optional XML object. | ||||
|  | ||||
| @@ -82,9 +83,11 @@ class Error(ElementBase): | ||||
|             self['type'] = 'cancel' | ||||
|             self['condition'] = 'feature-not-implemented' | ||||
|         if self.parent is not None: | ||||
|             self.parent()['type'] = 'error' | ||||
|             parent = self.parent() | ||||
|             if parent: | ||||
|                 parent['type'] = 'error' | ||||
|  | ||||
|     def get_condition(self): | ||||
|     def get_condition(self) -> str: | ||||
|         """Return the condition element's name.""" | ||||
|         for child in self.xml: | ||||
|             if "{%s}" % self.condition_ns in child.tag: | ||||
| @@ -93,7 +96,7 @@ class Error(ElementBase): | ||||
|                     return cond | ||||
|         return '' | ||||
|  | ||||
|     def set_condition(self, value): | ||||
|     def set_condition(self, value: str) -> Error: | ||||
|         """ | ||||
|         Set the tag name of the condition element. | ||||
|  | ||||
| @@ -105,7 +108,7 @@ class Error(ElementBase): | ||||
|             self.xml.append(ET.Element("{%s}%s" % (self.condition_ns, value))) | ||||
|         return self | ||||
|  | ||||
|     def del_condition(self): | ||||
|     def del_condition(self) -> Error: | ||||
|         """Remove the condition element.""" | ||||
|         for child in self.xml: | ||||
|             if "{%s}" % self.condition_ns in child.tag: | ||||
| @@ -139,14 +142,14 @@ class Error(ElementBase): | ||||
|     def get_redirect(self): | ||||
|         return self._get_sub_text('{%s}redirect' % self.condition_ns, '') | ||||
|  | ||||
|     def set_gone(self, value): | ||||
|     def set_gone(self, value: str): | ||||
|         if value: | ||||
|             del self['condition'] | ||||
|             return self._set_sub_text('{%s}gone' % self.condition_ns, value) | ||||
|         elif self['condition'] == 'gone': | ||||
|             del self['condition'] | ||||
|  | ||||
|     def set_redirect(self, value): | ||||
|     def set_redirect(self, value: str): | ||||
|         if value: | ||||
|             del self['condition'] | ||||
|             ns = self.condition_ns | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
| # See the file LICENSE for copying permission. | ||||
|  | ||||
| from slixmpp.xmlstream import StanzaBase | ||||
| from typing import Optional | ||||
|  | ||||
|  | ||||
| class Handshake(StanzaBase): | ||||
| @@ -18,7 +19,7 @@ class Handshake(StanzaBase): | ||||
|     def set_value(self, value: str): | ||||
|         self.xml.text = value | ||||
|  | ||||
|     def get_value(self) -> str: | ||||
|     def get_value(self) -> Optional[str]: | ||||
|         return self.xml.text | ||||
|  | ||||
|     def del_value(self): | ||||
|   | ||||
| @@ -3,10 +3,10 @@ | ||||
| # Copyright (C) 2010  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| import asyncio | ||||
| from slixmpp.stanza.rootstanza import RootStanza | ||||
| from slixmpp.xmlstream import StanzaBase, ET | ||||
| from slixmpp.xmlstream.handler import Waiter, Callback, CoroutineCallback | ||||
| from slixmpp.xmlstream.asyncio import asyncio | ||||
| from slixmpp.xmlstream.handler import Callback, CoroutineCallback | ||||
| from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId | ||||
| from slixmpp.exceptions import IqTimeout, IqError | ||||
|  | ||||
|   | ||||
| @@ -61,8 +61,10 @@ class Message(RootStanza): | ||||
|         """ | ||||
|         StanzaBase.__init__(self, *args, **kwargs) | ||||
|         if not recv and self['id'] == '': | ||||
|             if self.stream is not None and self.stream.use_message_ids: | ||||
|                 self['id'] = self.stream.new_id() | ||||
|             if self.stream: | ||||
|                 use_ids = getattr(self.stream, 'use_message_ids', None) | ||||
|                 if use_ids: | ||||
|                     self['id'] = self.stream.new_id() | ||||
|             else: | ||||
|                 del self['origin_id'] | ||||
|  | ||||
| @@ -93,8 +95,10 @@ class Message(RootStanza): | ||||
|  | ||||
|         self.xml.attrib['id'] = value | ||||
|  | ||||
|         if self.stream and not self.stream.use_origin_id: | ||||
|             return None | ||||
|         if self.stream: | ||||
|             use_orig_ids = getattr(self.stream, 'use_origin_id', None) | ||||
|             if not use_orig_ids: | ||||
|                 return None | ||||
|  | ||||
|         sub = self.xml.find(ORIGIN_NAME) | ||||
|         if sub is not None: | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2010  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| @@ -61,7 +60,7 @@ class Presence(RootStanza): | ||||
|              'subscribed', 'unsubscribe', 'unsubscribed'} | ||||
|     showtypes = {'dnd', 'chat', 'xa', 'away'} | ||||
|  | ||||
|     def __init__(self, *args, recv=False, **kwargs): | ||||
|     def __init__(self, *args, recv: bool = False, **kwargs): | ||||
|         """ | ||||
|         Initialize a new <presence /> stanza with an optional 'id' value. | ||||
|  | ||||
| @@ -69,10 +68,12 @@ class Presence(RootStanza): | ||||
|         """ | ||||
|         StanzaBase.__init__(self, *args, **kwargs) | ||||
|         if not recv and self['id'] == '': | ||||
|             if self.stream is not None and self.stream.use_presence_ids: | ||||
|                 self['id'] = self.stream.new_id() | ||||
|             if self.stream: | ||||
|                 use_ids = getattr(self.stream, 'use_presence_ids', None) | ||||
|                 if use_ids: | ||||
|                     self['id'] = self.stream.new_id() | ||||
|  | ||||
|     def set_show(self, show): | ||||
|     def set_show(self, show: str): | ||||
|         """ | ||||
|         Set the value of the <show> element. | ||||
|  | ||||
| @@ -84,7 +85,7 @@ class Presence(RootStanza): | ||||
|             self._set_sub_text('show', text=show) | ||||
|         return self | ||||
|  | ||||
|     def get_type(self): | ||||
|     def get_type(self) -> str: | ||||
|         """ | ||||
|         Return the value of the <presence> stanza's type attribute, or | ||||
|         the value of the <show> element if valid. | ||||
| @@ -96,7 +97,7 @@ class Presence(RootStanza): | ||||
|             out = 'available' | ||||
|         return out | ||||
|  | ||||
|     def set_type(self, value): | ||||
|     def set_type(self, value: str): | ||||
|         """ | ||||
|         Set the type attribute's value, and the <show> element | ||||
|         if applicable. | ||||
| @@ -119,7 +120,7 @@ class Presence(RootStanza): | ||||
|         self._del_attr('type') | ||||
|         self._del_sub('show') | ||||
|  | ||||
|     def set_priority(self, value): | ||||
|     def set_priority(self, value: int): | ||||
|         """ | ||||
|         Set the entity's priority value. Some server use priority to | ||||
|         determine message routing behavior. | ||||
|   | ||||
| @@ -4,7 +4,8 @@ | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.stanza.error import Error | ||||
| from slixmpp.xmlstream import StanzaBase | ||||
| from slixmpp.xmlstream import StanzaBase, ET | ||||
| from typing import Optional, Dict, Union | ||||
|  | ||||
|  | ||||
| class StreamError(Error, StanzaBase): | ||||
| @@ -62,19 +63,20 @@ class StreamError(Error, StanzaBase): | ||||
|         'system-shutdown', 'undefined-condition', 'unsupported-encoding', | ||||
|         'unsupported-feature', 'unsupported-stanza-type', | ||||
|         'unsupported-version'} | ||||
|     condition_ns = 'urn:ietf:params:xml:ns:xmpp-streams' | ||||
|     condition_ns: str = 'urn:ietf:params:xml:ns:xmpp-streams' | ||||
|  | ||||
|     def get_see_other_host(self): | ||||
|     def get_see_other_host(self) -> Union[str, Dict[str, str]]: | ||||
|         ns = self.condition_ns | ||||
|         return self._get_sub_text('{%s}see-other-host' % ns, '') | ||||
|  | ||||
|     def set_see_other_host(self, value): | ||||
|     def set_see_other_host(self, value: str) -> Optional[ET.Element]: | ||||
|         if value: | ||||
|             del self['condition'] | ||||
|             ns = self.condition_ns | ||||
|             return self._set_sub_text('{%s}see-other-host' % ns, value) | ||||
|         elif self['condition'] == 'see-other-host': | ||||
|             del self['condition'] | ||||
|         return None | ||||
|  | ||||
|     def del_see_other_host(self): | ||||
|     def del_see_other_host(self) -> None: | ||||
|         self._del_sub('{%s}see-other-host' % self.condition_ns) | ||||
|   | ||||
| @@ -3,7 +3,8 @@ | ||||
| # Copyright (C) 2010  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.xmlstream import StanzaBase | ||||
| from slixmpp.xmlstream import StanzaBase, ElementBase | ||||
| from typing import ClassVar, Dict, Type | ||||
|  | ||||
|  | ||||
| class StreamFeatures(StanzaBase): | ||||
| @@ -15,8 +16,8 @@ class StreamFeatures(StanzaBase): | ||||
|     namespace = 'http://etherx.jabber.org/streams' | ||||
|     interfaces = {'features', 'required', 'optional'} | ||||
|     sub_interfaces = interfaces | ||||
|     plugin_tag_map = {} | ||||
|     plugin_attrib_map = {} | ||||
|     plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {} | ||||
|     plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {} | ||||
|  | ||||
|     def setup(self, xml): | ||||
|         StanzaBase.setup(self, xml) | ||||
|   | ||||
| @@ -11,7 +11,7 @@ except ImportError: | ||||
|     # Python < 3.8 | ||||
|     # just to make sure the imports do not break, but | ||||
|     # not usable. | ||||
|     from unittest import TestCase as IsolatedAsyncioTestCase | ||||
|     from unittest import TestCase as IsolatedAsyncioTestCase  # type: ignore | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     List, | ||||
|   | ||||
| @@ -17,9 +17,7 @@ from slixmpp.xmlstream.matcher import StanzaPath, MatcherId, MatchIDSender | ||||
| from slixmpp.xmlstream.matcher import MatchXMLMask, MatchXPath | ||||
|  | ||||
| import asyncio | ||||
| cls = asyncio.get_event_loop().__class__ | ||||
|  | ||||
| cls.idle_call = lambda self, callback: callback() | ||||
|  | ||||
| class SlixTest(unittest.TestCase): | ||||
|  | ||||
|   | ||||
| @@ -16,11 +16,13 @@ try: | ||||
|     from typing import ( | ||||
|         Literal, | ||||
|         TypedDict, | ||||
|         Protocol, | ||||
|     ) | ||||
| except ImportError: | ||||
|     from typing_extensions import ( | ||||
|         Literal, | ||||
|         TypedDict, | ||||
|         Protocol, | ||||
|     ) | ||||
|  | ||||
| from slixmpp.jid import JID | ||||
| @@ -78,3 +80,11 @@ JidStr = Union[str, JID] | ||||
| OptJidStr = Optional[Union[str, JID]] | ||||
|  | ||||
| MAMDefault = Literal['always', 'never', 'roster'] | ||||
|  | ||||
| FilterString = Literal['in', 'out', 'out_sync'] | ||||
|  | ||||
| __all__ = [ | ||||
|     'Protocol', 'TypedDict', 'Literal', 'OptJid', 'JidStr', 'MAMDefault', | ||||
|     'PresenceTypes', 'PresenceShows', 'MessageTypes', 'IqTypes', 'MucRole', | ||||
|     'MucAffiliation', 'FilterString', | ||||
| ] | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2018 Emmanuel Gil Peyrot | ||||
| # This file is part of Slixmpp. | ||||
| @@ -6,8 +5,11 @@ | ||||
| import os | ||||
| import logging | ||||
|  | ||||
| from typing import Callable, Optional, Any | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class Cache: | ||||
|     def retrieve(self, key): | ||||
|         raise NotImplementedError | ||||
| @@ -16,7 +18,8 @@ class Cache: | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def remove(self, key): | ||||
|         raise NotImplemented | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class PerJidCache: | ||||
|     def retrieve_by_jid(self, jid, key): | ||||
| @@ -28,6 +31,7 @@ class PerJidCache: | ||||
|     def remove_by_jid(self, jid, key): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class MemoryCache(Cache): | ||||
|     def __init__(self): | ||||
|         self.cache = {} | ||||
| @@ -44,6 +48,7 @@ class MemoryCache(Cache): | ||||
|             del self.cache[key] | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class MemoryPerJidCache(PerJidCache): | ||||
|     def __init__(self): | ||||
|         self.cache = {} | ||||
| @@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache): | ||||
|             del cache[key] | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class FileSystemStorage: | ||||
|     def __init__(self, encode, decode, binary): | ||||
|     def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool): | ||||
|         self.encode = encode if encode is not None else lambda x: x | ||||
|         self.decode = decode if decode is not None else lambda x: x | ||||
|         self.read = 'rb' if binary else 'r' | ||||
|         self.write = 'wb' if binary else 'w' | ||||
|  | ||||
|     def _retrieve(self, directory, key): | ||||
|     def _retrieve(self, directory: str, key: str): | ||||
|         filename = os.path.join(directory, key.replace('/', '_')) | ||||
|         try: | ||||
|             with open(filename, self.read) as cache_file: | ||||
| @@ -86,7 +92,7 @@ class FileSystemStorage: | ||||
|             log.debug('Removing %s entry', key) | ||||
|             self._remove(directory, key) | ||||
|  | ||||
|     def _store(self, directory, key, value): | ||||
|     def _store(self, directory: str, key: str, value): | ||||
|         filename = os.path.join(directory, key.replace('/', '_')) | ||||
|         try: | ||||
|             os.makedirs(directory, exist_ok=True) | ||||
| @@ -99,7 +105,7 @@ class FileSystemStorage: | ||||
|         except Exception: | ||||
|             log.debug('Failed to encode %s to cache:', key, exc_info=True) | ||||
|  | ||||
|     def _remove(self, directory, key): | ||||
|     def _remove(self, directory: str, key: str): | ||||
|         filename = os.path.join(directory, key.replace('/', '_')) | ||||
|         try: | ||||
|             os.remove(filename) | ||||
| @@ -108,8 +114,9 @@ class FileSystemStorage: | ||||
|             return False | ||||
|         return True | ||||
|  | ||||
|  | ||||
| class FileSystemCache(Cache, FileSystemStorage): | ||||
|     def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): | ||||
|     def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False): | ||||
|         FileSystemStorage.__init__(self, encode, decode, binary) | ||||
|         self.base_dir = os.path.join(directory, cache_type) | ||||
|  | ||||
| @@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage): | ||||
|     def remove(self, key): | ||||
|         return self._remove(self.base_dir, key) | ||||
|  | ||||
|  | ||||
| class FileSystemPerJidCache(PerJidCache, FileSystemStorage): | ||||
|     def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False): | ||||
|     def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False): | ||||
|         FileSystemStorage.__init__(self, encode, decode, binary) | ||||
|         self.base_dir = os.path.join(directory, cache_type) | ||||
|  | ||||
|   | ||||
| @@ -2,15 +2,19 @@ import builtins | ||||
| import sys | ||||
| import hashlib | ||||
|  | ||||
| from typing import Optional, Union, Callable, List | ||||
|  | ||||
| def unicode(text): | ||||
| bytes_ = builtins.bytes  # alias the stdlib type but ew | ||||
|  | ||||
|  | ||||
| def unicode(text: Union[bytes_, str]) -> str: | ||||
|     if not isinstance(text, str): | ||||
|         return text.decode('utf-8') | ||||
|     else: | ||||
|         return text | ||||
|  | ||||
|  | ||||
| def bytes(text): | ||||
| def bytes(text: Optional[Union[str, bytes_]]) -> bytes_: | ||||
|     """ | ||||
|     Convert Unicode text to UTF-8 encoded bytes. | ||||
|  | ||||
| @@ -34,7 +38,7 @@ def bytes(text): | ||||
|         return builtins.bytes(text, encoding='utf-8') | ||||
|  | ||||
|  | ||||
| def quote(text): | ||||
| def quote(text: Union[str, bytes_]) -> bytes_: | ||||
|     """ | ||||
|     Enclose in quotes and escape internal slashes and double quotes. | ||||
|  | ||||
| @@ -44,7 +48,7 @@ def quote(text): | ||||
|     return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"' | ||||
|  | ||||
|  | ||||
| def num_to_bytes(num): | ||||
| def num_to_bytes(num: int) -> bytes_: | ||||
|     """ | ||||
|     Convert an integer into a four byte sequence. | ||||
|  | ||||
| @@ -58,21 +62,21 @@ def num_to_bytes(num): | ||||
|     return bval | ||||
|  | ||||
|  | ||||
| def bytes_to_num(bval): | ||||
| def bytes_to_num(bval: bytes_) -> int: | ||||
|     """ | ||||
|     Convert a four byte sequence to an integer. | ||||
|  | ||||
|     :param bytes bval: A four byte sequence to turn into an integer. | ||||
|     """ | ||||
|     num = 0 | ||||
|     num += ord(bval[0] << 24) | ||||
|     num += ord(bval[1] << 16) | ||||
|     num += ord(bval[2] << 8) | ||||
|     num += ord(bval[3]) | ||||
|     num += (bval[0] << 24) | ||||
|     num += (bval[1] << 16) | ||||
|     num += (bval[2] << 8) | ||||
|     num += (bval[3]) | ||||
|     return num | ||||
|  | ||||
|  | ||||
| def XOR(x, y): | ||||
| def XOR(x: bytes_, y: bytes_) -> bytes_: | ||||
|     """ | ||||
|     Return the results of an XOR operation on two equal length byte strings. | ||||
|  | ||||
| @@ -85,7 +89,7 @@ def XOR(x, y): | ||||
|     return builtins.bytes([a ^ b for a, b in zip(x, y)]) | ||||
|  | ||||
|  | ||||
| def hash(name): | ||||
| def hash(name: str) -> Optional[Callable]: | ||||
|     """ | ||||
|     Return a hash function implementing the given algorithm. | ||||
|  | ||||
| @@ -102,7 +106,7 @@ def hash(name): | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def hashes(): | ||||
| def hashes() -> List[str]: | ||||
|     """ | ||||
|     Return a list of available hashing algorithms. | ||||
|  | ||||
| @@ -115,28 +119,3 @@ def hashes(): | ||||
|         t += ['MD2'] | ||||
|     hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')] | ||||
|     return t + hashes | ||||
|  | ||||
|  | ||||
| def setdefaultencoding(encoding): | ||||
|     """ | ||||
|     Set the current default string encoding used by the Unicode implementation. | ||||
|  | ||||
|     Actually calls sys.setdefaultencoding under the hood - see the docs for that | ||||
|     for more details.  This method exists only as a way to call find/call it | ||||
|     even after it has been 'deleted' when the site module is executed. | ||||
|  | ||||
|     :param string encoding: An encoding name, compatible with sys.setdefaultencoding | ||||
|     """ | ||||
|     func = getattr(sys, 'setdefaultencoding', None) | ||||
|     if func is None: | ||||
|         import gc | ||||
|         import types | ||||
|         for obj in gc.get_objects(): | ||||
|             if (isinstance(obj, types.BuiltinFunctionType) | ||||
|                     and obj.__name__ == 'setdefaultencoding'): | ||||
|                 func = obj | ||||
|                 break | ||||
|         if func is None: | ||||
|             raise RuntimeError("Could not find setdefaultencoding") | ||||
|         sys.setdefaultencoding = func | ||||
|     return func(encoding) | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # slixmpp.util.sasl.client | ||||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| # This module was originally based on Dave Cridland's Suelta library. | ||||
| @@ -6,9 +5,11 @@ | ||||
| # :copryight: (c) 2004-2013 David Alan Cridland | ||||
| # :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from __future__ import annotations | ||||
| import logging | ||||
| import stringprep | ||||
|  | ||||
| from typing import Iterable, Set, Callable, Dict, Any, Optional, Type | ||||
| from slixmpp.util import hashes, bytes, stringprep_profiles | ||||
|  | ||||
|  | ||||
| @@ -16,11 +17,11 @@ log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| #: Global registry mapping mechanism names to implementation classes. | ||||
| MECHANISMS = {} | ||||
| MECHANISMS: Dict[str, Type[Mech]] = {} | ||||
|  | ||||
|  | ||||
| #: Global registry mapping mechanism names to security scores. | ||||
| MECH_SEC_SCORES = {} | ||||
| MECH_SEC_SCORES: Dict[str, int] = {} | ||||
|  | ||||
|  | ||||
| #: The SASLprep profile of stringprep used to validate simple username | ||||
| @@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create( | ||||
|     unassigned=[stringprep.in_table_a1]) | ||||
|  | ||||
|  | ||||
| def sasl_mech(score): | ||||
| def sasl_mech(score: int): | ||||
|     sec_score = score | ||||
|     def register(mech): | ||||
|  | ||||
|     def register(mech: Type[Mech]): | ||||
|         n = 0 | ||||
|         mech.score = sec_score | ||||
|         if mech.use_hashes: | ||||
| @@ -99,9 +101,9 @@ class Mech(object): | ||||
|     score = -1 | ||||
|     use_hashes = False | ||||
|     channel_binding = False | ||||
|     required_credentials = set() | ||||
|     optional_credentials = set() | ||||
|     security = set() | ||||
|     required_credentials: Set[str] = set() | ||||
|     optional_credentials: Set[str] = set() | ||||
|     security: Set[str] = set() | ||||
|  | ||||
|     def __init__(self, name, credentials, security_settings): | ||||
|         self.credentials = credentials | ||||
| @@ -118,7 +120,14 @@ class Mech(object): | ||||
|         return b'' | ||||
|  | ||||
|  | ||||
| def choose(mech_list, credentials, security_settings, limit=None, min_mech=None): | ||||
| CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]] | ||||
| SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]] | ||||
|  | ||||
|  | ||||
| def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback, | ||||
|            security_settings: SecurityCallback, | ||||
|            limit: Optional[Iterable[Type[Mech]]] = None, | ||||
|            min_mech: Optional[str] = None) -> Mech: | ||||
|     available_mechs = set(MECHANISMS.keys()) | ||||
|     if limit is None: | ||||
|         limit = set(mech_list) | ||||
| @@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None) | ||||
|     mech_list = mech_list.intersection(limit) | ||||
|     available_mechs = available_mechs.intersection(mech_list) | ||||
|  | ||||
|     best_score = MECH_SEC_SCORES.get(min_mech, -1) | ||||
|     if min_mech is None: | ||||
|         best_score = -1 | ||||
|     else: | ||||
|         best_score = MECH_SEC_SCORES.get(min_mech, -1) | ||||
|     best_mech = None | ||||
|     for name in available_mechs: | ||||
|         if name in MECH_SEC_SCORES: | ||||
|   | ||||
| @@ -11,6 +11,9 @@ import hmac | ||||
| import random | ||||
|  | ||||
| from base64 import b64encode, b64decode | ||||
| from typing import List, Dict, Optional | ||||
|  | ||||
| bytes_ = bytes | ||||
|  | ||||
| from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes | ||||
| from slixmpp.util.sasl.client import sasl_mech, Mech, \ | ||||
| @@ -63,7 +66,7 @@ class PLAIN(Mech): | ||||
|             if not self.security_settings['encrypted_plain']: | ||||
|                 raise SASLCancelled('PLAIN with encryption') | ||||
|  | ||||
|     def process(self, challenge=b''): | ||||
|     def process(self, challenge: bytes_ = b'') -> bytes_: | ||||
|         authzid = self.credentials['authzid'] | ||||
|         authcid = self.credentials['username'] | ||||
|         password = self.credentials['password'] | ||||
| @@ -148,7 +151,7 @@ class CRAM(Mech): | ||||
|     required_credentials = {'username', 'password'} | ||||
|     security = {'encrypted', 'unencrypted_cram'} | ||||
|  | ||||
|     def setup(self, name): | ||||
|     def setup(self, name: str): | ||||
|         self.hash_name = name[5:] | ||||
|         self.hash = hash(self.hash_name) | ||||
|         if self.hash is None: | ||||
| @@ -157,14 +160,14 @@ class CRAM(Mech): | ||||
|             if not self.security_settings['unencrypted_cram']: | ||||
|                 raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name) | ||||
|  | ||||
|     def process(self, challenge=b''): | ||||
|     def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: | ||||
|         if not challenge: | ||||
|             return None | ||||
|  | ||||
|         username = self.credentials['username'] | ||||
|         password = self.credentials['password'] | ||||
|  | ||||
|         mac = hmac.HMAC(key=password, digestmod=self.hash) | ||||
|         mac = hmac.HMAC(key=password, digestmod=self.hash)  # type: ignore | ||||
|         mac.update(challenge) | ||||
|  | ||||
|         return username + b' ' + bytes(mac.hexdigest()) | ||||
| @@ -201,43 +204,42 @@ class SCRAM(Mech): | ||||
|     def HMAC(self, key, msg): | ||||
|         return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest() | ||||
|  | ||||
|     def Hi(self, text, salt, iterations): | ||||
|         text = bytes(text) | ||||
|         ui1 = self.HMAC(text, salt + b'\0\0\0\01') | ||||
|     def Hi(self, text: str, salt: bytes_, iterations: int): | ||||
|         text_enc = bytes(text) | ||||
|         ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01') | ||||
|         ui = ui1 | ||||
|         for i in range(iterations - 1): | ||||
|             ui1 = self.HMAC(text, ui1) | ||||
|             ui1 = self.HMAC(text_enc, ui1) | ||||
|             ui = XOR(ui, ui1) | ||||
|         return ui | ||||
|  | ||||
|     def H(self, text): | ||||
|     def H(self, text: str) -> bytes_: | ||||
|         return self.hash(text).digest() | ||||
|  | ||||
|     def saslname(self, value): | ||||
|         value = value.decode("utf-8") | ||||
|         escaped = [] | ||||
|     def saslname(self, value_b: bytes_) -> bytes_: | ||||
|         value = value_b.decode("utf-8") | ||||
|         escaped: List[str] = [] | ||||
|         for char in value: | ||||
|             if char == ',': | ||||
|                 escaped += b'=2C' | ||||
|                 escaped.append('=2C') | ||||
|             elif char == '=': | ||||
|                 escaped += b'=3D' | ||||
|                 escaped.append('=3D') | ||||
|             else: | ||||
|                 escaped += char | ||||
|                 escaped.append(char) | ||||
|         return "".join(escaped).encode("utf-8") | ||||
|  | ||||
|     def parse(self, challenge): | ||||
|     def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]: | ||||
|         items = {} | ||||
|         for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]: | ||||
|             items[key] = value | ||||
|         return items | ||||
|  | ||||
|     def process(self, challenge=b''): | ||||
|     def process(self, challenge: bytes_ = b''): | ||||
|         steps = [self.process_1, self.process_2, self.process_3] | ||||
|         return steps[self.step](challenge) | ||||
|  | ||||
|     def process_1(self, challenge): | ||||
|     def process_1(self, challenge: bytes_) -> bytes_: | ||||
|         self.step = 1 | ||||
|         data = {} | ||||
|  | ||||
|         self.cnonce = bytes(('%s' % random.random())[2:]) | ||||
|  | ||||
| @@ -263,7 +265,7 @@ class SCRAM(Mech): | ||||
|  | ||||
|         return self.client_first_message | ||||
|  | ||||
|     def process_2(self, challenge): | ||||
|     def process_2(self, challenge: bytes_) -> bytes_: | ||||
|         self.step = 2 | ||||
|  | ||||
|         data = self.parse(challenge) | ||||
| @@ -304,7 +306,7 @@ class SCRAM(Mech): | ||||
|  | ||||
|         return client_final_message | ||||
|  | ||||
|     def process_3(self, challenge): | ||||
|     def process_3(self, challenge: bytes_) -> bytes_: | ||||
|         data = self.parse(challenge) | ||||
|         verifier = data.get(b'v', None) | ||||
|         error = data.get(b'e', 'Unknown error') | ||||
| @@ -345,17 +347,16 @@ class DIGEST(Mech): | ||||
|         self.cnonce = b'' | ||||
|         self.nonce_count = 1 | ||||
|  | ||||
|     def parse(self, challenge=b''): | ||||
|         data = {} | ||||
|     def parse(self, challenge:  bytes_ = b''): | ||||
|         data: Dict[str, bytes_] = {} | ||||
|         var_name = b'' | ||||
|         var_value = b'' | ||||
|  | ||||
|         # States: var, new_var, end, quote, escaped_quote | ||||
|         state = 'var' | ||||
|  | ||||
|  | ||||
|         for char in challenge: | ||||
|             char = bytes([char]) | ||||
|         for char_int in challenge: | ||||
|             char = bytes_([char_int]) | ||||
|  | ||||
|             if state == 'var': | ||||
|                 if char.isspace(): | ||||
| @@ -401,14 +402,14 @@ class DIGEST(Mech): | ||||
|         state = 'var' | ||||
|         return data | ||||
|  | ||||
|     def MAC(self, key, seq, msg): | ||||
|     def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_: | ||||
|         mac = hmac.HMAC(key=key, digestmod=self.hash) | ||||
|         seqnum = num_to_bytes(seq) | ||||
|         mac.update(seqnum) | ||||
|         mac.update(msg) | ||||
|         return mac.digest()[:10] + b'\x00\x01' + seqnum | ||||
|  | ||||
|     def A1(self): | ||||
|     def A1(self) -> bytes_: | ||||
|         username = self.credentials['username'] | ||||
|         password = self.credentials['password'] | ||||
|         authzid = self.credentials['authzid'] | ||||
| @@ -423,13 +424,13 @@ class DIGEST(Mech): | ||||
|  | ||||
|         return bytes(a1) | ||||
|  | ||||
|     def A2(self, prefix=b''): | ||||
|     def A2(self, prefix: bytes_ = b'') -> bytes_: | ||||
|         a2 = prefix + b':' + self.digest_uri() | ||||
|         if self.qop in (b'auth-int', b'auth-conf'): | ||||
|             a2 += b':00000000000000000000000000000000' | ||||
|         return bytes(a2) | ||||
|  | ||||
|     def response(self, prefix=b''): | ||||
|     def response(self, prefix: bytes_ = b'') -> bytes_: | ||||
|         nc = bytes('%08x' % self.nonce_count) | ||||
|  | ||||
|         a1 = bytes(self.hash(self.A1()).hexdigest().lower()) | ||||
| @@ -439,7 +440,7 @@ class DIGEST(Mech): | ||||
|  | ||||
|         return bytes(self.hash(a1 + b':' + s).hexdigest().lower()) | ||||
|  | ||||
|     def digest_uri(self): | ||||
|     def digest_uri(self) -> bytes_: | ||||
|         serv_type = self.credentials['service'] | ||||
|         serv_name = self.credentials['service-name'] | ||||
|         host = self.credentials['host'] | ||||
| @@ -449,7 +450,7 @@ class DIGEST(Mech): | ||||
|             uri += b'/' + serv_name | ||||
|         return uri | ||||
|  | ||||
|     def respond(self): | ||||
|     def respond(self) -> bytes_: | ||||
|         data = { | ||||
|             'username': quote(self.credentials['username']), | ||||
|             'authzid': quote(self.credentials['authzid']), | ||||
| @@ -469,7 +470,7 @@ class DIGEST(Mech): | ||||
|                 resp += b',' + bytes(key) + b'=' + bytes(value) | ||||
|         return resp[1:] | ||||
|  | ||||
|     def process(self, challenge=b''): | ||||
|     def process(self, challenge: bytes_ = b'') -> Optional[bytes_]: | ||||
|         if not challenge: | ||||
|             if self.cnonce and self.nonce and self.nonce_count and self.qop: | ||||
|                 self.nonce_count += 1 | ||||
| @@ -480,6 +481,7 @@ class DIGEST(Mech): | ||||
|         if 'rspauth' in data: | ||||
|             if data['rspauth'] != self.response(): | ||||
|                 raise SASLMutualAuthFailed() | ||||
|             return None | ||||
|         else: | ||||
|             self.nonce_count = 1 | ||||
|             self.cnonce = bytes('%s' % random.random())[2:] | ||||
|   | ||||
| @@ -1,22 +0,0 @@ | ||||
| """ | ||||
| asyncio-related utilities | ||||
| """ | ||||
|  | ||||
| import asyncio | ||||
| from functools import wraps | ||||
|  | ||||
| def future_wrapper(func): | ||||
|     """ | ||||
|     Make sure the result of a function call is an asyncio.Future() | ||||
|     object. | ||||
|     """ | ||||
|     @wraps(func) | ||||
|     def wrapper(*args, **kwargs): | ||||
|         result = func(*args, **kwargs) | ||||
|         if isinstance(result, asyncio.Future): | ||||
|             return result | ||||
|         future = asyncio.Future() | ||||
|         future.set_result(result) | ||||
|         return future | ||||
|  | ||||
|     return wrapper | ||||
| @@ -1,5 +1,6 @@ | ||||
| import logging | ||||
| from datetime import datetime, timedelta | ||||
| from typing import Dict, Set, Tuple, Optional | ||||
|  | ||||
| # Make a call to strptime before starting threads to | ||||
| # prevent thread safety issues. | ||||
| @@ -32,13 +33,13 @@ class CertificateError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def decode_str(data): | ||||
| def decode_str(data: bytes) -> str: | ||||
|     encoding = 'utf-16-be' if isinstance(data, BMPString) else 'utf-8' | ||||
|     return bytes(data).decode(encoding) | ||||
|  | ||||
|  | ||||
| def extract_names(raw_cert): | ||||
|     results = {'CN': set(), | ||||
| def extract_names(raw_cert: bytes) -> Dict[str, Set[str]]: | ||||
|     results: Dict[str, Set[str]] = {'CN': set(), | ||||
|                'DNS': set(), | ||||
|                'SRV': set(), | ||||
|                'URI': set(), | ||||
| @@ -96,7 +97,7 @@ def extract_names(raw_cert): | ||||
|     return results | ||||
|  | ||||
|  | ||||
| def extract_dates(raw_cert): | ||||
| def extract_dates(raw_cert: bytes) -> Tuple[Optional[datetime], Optional[datetime]]: | ||||
|     if not HAVE_PYASN1: | ||||
|         log.warning("Could not find pyasn1 and pyasn1_modules. " + \ | ||||
|                     "SSL certificate expiration COULD NOT BE VERIFIED.") | ||||
| @@ -125,24 +126,29 @@ def extract_dates(raw_cert): | ||||
|     return not_before, not_after | ||||
|  | ||||
|  | ||||
| def get_ttl(raw_cert): | ||||
| def get_ttl(raw_cert: bytes) -> Optional[timedelta]: | ||||
|     not_before, not_after = extract_dates(raw_cert) | ||||
|     if not_after is None: | ||||
|     if not_after is None or not_before is None: | ||||
|         return None | ||||
|     return not_after - datetime.utcnow() | ||||
|  | ||||
|  | ||||
| def verify(expected, raw_cert): | ||||
| def verify(expected: str, raw_cert: bytes) -> Optional[bool]: | ||||
|     if not HAVE_PYASN1: | ||||
|         log.warning("Could not find pyasn1 and pyasn1_modules. " + \ | ||||
|                     "SSL certificate COULD NOT BE VERIFIED.") | ||||
|         return | ||||
|         return None | ||||
|  | ||||
|     not_before, not_after = extract_dates(raw_cert) | ||||
|     cert_names = extract_names(raw_cert) | ||||
|  | ||||
|     now = datetime.utcnow() | ||||
|  | ||||
|     if not not_before or not not_after: | ||||
|         raise CertificateError( | ||||
|             "Error while checking the dates of the certificate" | ||||
|         ) | ||||
|  | ||||
|     if not_before > now: | ||||
|         raise CertificateError( | ||||
|                 'Certificate has not entered its valid date range.') | ||||
|   | ||||
| @@ -4,10 +4,19 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from __future__ import annotations | ||||
|  | ||||
| import weakref | ||||
| from weakref import ReferenceType | ||||
| from typing import Optional, TYPE_CHECKING, Union | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
| from xml.etree.ElementTree import Element | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from slixmpp.xmlstream import XMLStream, StanzaBase | ||||
|  | ||||
|  | ||||
| class BaseHandler(object): | ||||
| class BaseHandler: | ||||
|  | ||||
|     """ | ||||
|     Base class for stream handlers. Stream handlers are matched with | ||||
| @@ -26,8 +35,13 @@ class BaseHandler(object): | ||||
|     :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` | ||||
|                     instance that the handle will respond to. | ||||
|     """ | ||||
|     name: str | ||||
|     stream: Optional[ReferenceType[XMLStream]] | ||||
|     _destroy: bool | ||||
|     _matcher: MatcherBase | ||||
|     _payload: Optional[StanzaBase] | ||||
|  | ||||
|     def __init__(self, name, matcher, stream=None): | ||||
|     def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None): | ||||
|         #: The name of the handler | ||||
|         self.name = name | ||||
|  | ||||
| @@ -41,33 +55,33 @@ class BaseHandler(object): | ||||
|         self._payload = None | ||||
|         self._matcher = matcher | ||||
|  | ||||
|     def match(self, xml): | ||||
|     def match(self, xml: StanzaBase) -> bool: | ||||
|         """Compare a stanza or XML object with the handler's matcher. | ||||
|  | ||||
|         :param xml: An XML or | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object | ||||
|         """ | ||||
|         return self._matcher.match(xml) | ||||
|  | ||||
|     def prerun(self, payload): | ||||
|     def prerun(self, payload: StanzaBase) -> None: | ||||
|         """Prepare the handler for execution while the XML | ||||
|         stream is being processed. | ||||
|  | ||||
|         :param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase` | ||||
|         :param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` | ||||
|                         object. | ||||
|         """ | ||||
|         self._payload = payload | ||||
|  | ||||
|     def run(self, payload): | ||||
|     def run(self, payload: StanzaBase) -> None: | ||||
|         """Execute the handler after XML stream processing and during the | ||||
|         main event loop. | ||||
|  | ||||
|         :param payload: A :class:`~slixmpp.xmlstream.stanzabase.ElementBase` | ||||
|         :param payload: A :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` | ||||
|                         object. | ||||
|         """ | ||||
|         self._payload = payload | ||||
|  | ||||
|     def check_delete(self): | ||||
|     def check_delete(self) -> bool: | ||||
|         """Check if the handler should be removed from the list | ||||
|         of stream handlers. | ||||
|         """ | ||||
|   | ||||
| @@ -1,10 +1,17 @@ | ||||
|  | ||||
| # slixmpp.xmlstream.handler.callback | ||||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import Optional, Callable, Any, TYPE_CHECKING | ||||
| from slixmpp.xmlstream.handler.base import BaseHandler | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
|     from slixmpp.xmlstream.xmlstream import XMLStream | ||||
|  | ||||
|  | ||||
| class Callback(BaseHandler): | ||||
| @@ -28,8 +35,6 @@ class Callback(BaseHandler): | ||||
|     :param matcher: A :class:`~slixmpp.xmlstream.matcher.base.MatcherBase` | ||||
|                     derived object for matching stanza objects. | ||||
|     :param pointer: The function to execute during callback. | ||||
|     :param bool thread: **DEPRECATED.** Remains only for | ||||
|                         backwards compatibility. | ||||
|     :param bool once: Indicates if the handler should be used only | ||||
|                       once. Defaults to False. | ||||
|     :param bool instream: Indicates if the callback should be executed | ||||
| @@ -38,31 +43,36 @@ class Callback(BaseHandler): | ||||
|     :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` | ||||
|                    instance this handler should monitor. | ||||
|     """ | ||||
|     _once: bool | ||||
|     _instream: bool | ||||
|  | ||||
|     def __init__(self, name, matcher, pointer, thread=False, | ||||
|                  once=False, instream=False, stream=None): | ||||
|     def __init__(self, name: str, matcher: MatcherBase, | ||||
|                  pointer: Callable[[StanzaBase], Any], | ||||
|                  once: bool = False, instream: bool = False, | ||||
|                  stream: Optional[XMLStream] = None): | ||||
|         BaseHandler.__init__(self, name, matcher, stream) | ||||
|         self._pointer: Callable[[StanzaBase], Any] = pointer | ||||
|         self._pointer = pointer | ||||
|         self._once = once | ||||
|         self._instream = instream | ||||
|  | ||||
|     def prerun(self, payload): | ||||
|     def prerun(self, payload: StanzaBase) -> None: | ||||
|         """Execute the callback during stream processing, if | ||||
|         the callback was created with ``instream=True``. | ||||
|  | ||||
|         :param payload: The matched | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. | ||||
|         """ | ||||
|         if self._once: | ||||
|             self._destroy = True | ||||
|         if self._instream: | ||||
|             self.run(payload, True) | ||||
|  | ||||
|     def run(self, payload, instream=False): | ||||
|     def run(self, payload: StanzaBase, instream: bool = False) -> None: | ||||
|         """Execute the callback function with the matched stanza payload. | ||||
|  | ||||
|         :param payload: The matched | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. | ||||
|         :param bool instream: Force the handler to execute during stream | ||||
|                               processing. This should only be used by | ||||
|                               :meth:`prerun()`. Defaults to ``False``. | ||||
|   | ||||
| @@ -4,11 +4,17 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2012 Nathanael C. Fritz, Lance J.T. Stout | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from __future__ import annotations | ||||
|  | ||||
| import logging | ||||
| from queue import Queue, Empty | ||||
| from typing import List, Optional, TYPE_CHECKING | ||||
|  | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
| from slixmpp.xmlstream.handler.base import BaseHandler | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from slixmpp.xmlstream.xmlstream import XMLStream | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -27,35 +33,35 @@ class Collector(BaseHandler): | ||||
|     :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` | ||||
|                    instance this handler should monitor. | ||||
|     """ | ||||
|     _stanzas: List[StanzaBase] | ||||
|  | ||||
|     def __init__(self, name, matcher, stream=None): | ||||
|     def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None): | ||||
|         BaseHandler.__init__(self, name, matcher, stream=stream) | ||||
|         self._payload = Queue() | ||||
|         self._stanzas = [] | ||||
|  | ||||
|     def prerun(self, payload): | ||||
|     def prerun(self, payload: StanzaBase) -> None: | ||||
|         """Store the matched stanza when received during processing. | ||||
|  | ||||
|         :param payload: The matched | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. | ||||
|         """ | ||||
|         self._payload.put(payload) | ||||
|         self._stanzas.append(payload) | ||||
|  | ||||
|     def run(self, payload): | ||||
|     def run(self, payload: StanzaBase) -> None: | ||||
|         """Do not process this handler during the main event loop.""" | ||||
|         pass | ||||
|  | ||||
|     def stop(self): | ||||
|     def stop(self) -> List[StanzaBase]: | ||||
|         """ | ||||
|         Stop collection of matching stanzas, and return the ones that | ||||
|         have been stored so far. | ||||
|         """ | ||||
|         stream_ref = self.stream | ||||
|         if stream_ref is None: | ||||
|             raise ValueError('stop() called without a stream!') | ||||
|         stream = stream_ref() | ||||
|         if stream is None: | ||||
|             raise ValueError('stop() called without a stream!') | ||||
|         self._destroy = True | ||||
|         results = [] | ||||
|         try: | ||||
|             while True: | ||||
|                 results.append(self._payload.get(False)) | ||||
|         except Empty: | ||||
|             pass | ||||
|  | ||||
|         self.stream().remove_handler(self.name) | ||||
|         return results | ||||
|         stream.remove_handler(self.name) | ||||
|         return self._stanzas | ||||
|   | ||||
| @@ -4,8 +4,19 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from __future__ import annotations | ||||
|  | ||||
| from asyncio import iscoroutinefunction, ensure_future | ||||
| from typing import Optional, Callable, Awaitable, TYPE_CHECKING | ||||
|  | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
| from slixmpp.xmlstream.handler.base import BaseHandler | ||||
| from slixmpp.xmlstream.asyncio import asyncio | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
|  | ||||
| CoroutineFunction = Callable[[StanzaBase], Awaitable[None]] | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from slixmpp.xmlstream.xmlstream import XMLStream | ||||
|  | ||||
|  | ||||
| class CoroutineCallback(BaseHandler): | ||||
| @@ -34,45 +45,49 @@ class CoroutineCallback(BaseHandler): | ||||
|                    instance this handler should monitor. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, name, matcher, pointer, once=False, | ||||
|                  instream=False, stream=None): | ||||
|     _once: bool | ||||
|     _instream: bool | ||||
|  | ||||
|     def __init__(self, name: str, matcher: MatcherBase, | ||||
|                  pointer: CoroutineFunction, once: bool = False, | ||||
|                  instream: bool = False, stream: Optional[XMLStream] = None): | ||||
|         BaseHandler.__init__(self, name, matcher, stream) | ||||
|         if not asyncio.iscoroutinefunction(pointer): | ||||
|         if not iscoroutinefunction(pointer): | ||||
|             raise ValueError("Given function is not a coroutine") | ||||
|  | ||||
|         async def pointer_wrapper(stanza, *args, **kwargs): | ||||
|         async def pointer_wrapper(stanza: StanzaBase) -> None: | ||||
|             try: | ||||
|                 await pointer(stanza, *args, **kwargs) | ||||
|                 await pointer(stanza) | ||||
|             except Exception as e: | ||||
|                 stanza.exception(e) | ||||
|  | ||||
|         self._pointer = pointer_wrapper | ||||
|         self._pointer: CoroutineFunction = pointer_wrapper | ||||
|         self._once = once | ||||
|         self._instream = instream | ||||
|  | ||||
|     def prerun(self, payload): | ||||
|     def prerun(self, payload: StanzaBase) -> None: | ||||
|         """Execute the callback during stream processing, if | ||||
|         the callback was created with ``instream=True``. | ||||
|  | ||||
|         :param payload: The matched | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. | ||||
|         """ | ||||
|         if self._once: | ||||
|             self._destroy = True | ||||
|         if self._instream: | ||||
|             self.run(payload, True) | ||||
|  | ||||
|     def run(self, payload, instream=False): | ||||
|     def run(self, payload: StanzaBase, instream: bool = False) -> None: | ||||
|         """Execute the callback function with the matched stanza payload. | ||||
|  | ||||
|         :param payload: The matched | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. | ||||
|         :param bool instream: Force the handler to execute during stream | ||||
|                               processing. This should only be used by | ||||
|                               :meth:`prerun()`. Defaults to ``False``. | ||||
|         """ | ||||
|         if not self._instream or instream: | ||||
|             asyncio.ensure_future(self._pointer(payload)) | ||||
|             ensure_future(self._pointer(payload)) | ||||
|             if self._once: | ||||
|                 self._destroy = True | ||||
|                 del self._pointer | ||||
|   | ||||
| @@ -4,13 +4,20 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from __future__ import annotations | ||||
|  | ||||
| import logging | ||||
| import asyncio | ||||
| from  asyncio import Queue, wait_for, TimeoutError | ||||
| from asyncio import Event, wait_for, TimeoutError | ||||
| from typing import Optional, TYPE_CHECKING, Union | ||||
| from xml.etree.ElementTree import Element | ||||
|  | ||||
| import slixmpp | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
| from slixmpp.xmlstream.handler.base import BaseHandler | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from slixmpp.xmlstream.xmlstream import XMLStream | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| @@ -28,24 +35,27 @@ class Waiter(BaseHandler): | ||||
|     :param stream: The :class:`~slixmpp.xmlstream.xmlstream.XMLStream` | ||||
|                    instance this handler should monitor. | ||||
|     """ | ||||
|     _event: Event | ||||
|  | ||||
|     def __init__(self, name, matcher, stream=None): | ||||
|     def __init__(self, name: str, matcher: MatcherBase, stream: Optional[XMLStream] = None): | ||||
|         BaseHandler.__init__(self, name, matcher, stream=stream) | ||||
|         self._payload = Queue() | ||||
|         self._event = Event() | ||||
|  | ||||
|     def prerun(self, payload): | ||||
|     def prerun(self, payload: StanzaBase) -> None: | ||||
|         """Store the matched stanza when received during processing. | ||||
|  | ||||
|         :param payload: The matched | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.ElementBase` object. | ||||
|             :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` object. | ||||
|         """ | ||||
|         self._payload.put_nowait(payload) | ||||
|         if not self._event.is_set(): | ||||
|             self._event.set() | ||||
|             self._payload = payload | ||||
|  | ||||
|     def run(self, payload): | ||||
|     def run(self, payload: StanzaBase) -> None: | ||||
|         """Do not process this handler during the main event loop.""" | ||||
|         pass | ||||
|  | ||||
|     async def wait(self, timeout=None): | ||||
|     async def wait(self, timeout: Optional[int] = None) -> Optional[StanzaBase]: | ||||
|         """Block an event handler while waiting for a stanza to arrive. | ||||
|  | ||||
|         Be aware that this will impact performance if called from a | ||||
| @@ -59,17 +69,24 @@ class Waiter(BaseHandler): | ||||
|             :class:`~slixmpp.xmlstream.xmlstream.XMLStream.response_timeout` | ||||
|             value. | ||||
|         """ | ||||
|         stream_ref = self.stream | ||||
|         if stream_ref is None: | ||||
|             raise ValueError('wait() called without a stream') | ||||
|         stream = stream_ref() | ||||
|         if stream is None: | ||||
|             raise ValueError('wait() called without a stream') | ||||
|         if timeout is None: | ||||
|             timeout = slixmpp.xmlstream.RESPONSE_TIMEOUT | ||||
|  | ||||
|         stanza = None | ||||
|         try: | ||||
|             stanza = await self._payload.get() | ||||
|             await wait_for( | ||||
|                 self._event.wait(), timeout, loop=stream.loop | ||||
|             ) | ||||
|         except TimeoutError: | ||||
|             log.warning("Timed out waiting for %s", self.name) | ||||
|         self.stream().remove_handler(self.name) | ||||
|         return stanza | ||||
|         stream.remove_handler(self.name) | ||||
|         return self._payload | ||||
|  | ||||
|     def check_delete(self): | ||||
|     def check_delete(self) -> bool: | ||||
|         """Always remove waiters after use.""" | ||||
|         return True | ||||
|   | ||||
| @@ -4,6 +4,7 @@ | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.xmlstream.handler import Callback | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
|  | ||||
|  | ||||
| class XMLCallback(Callback): | ||||
| @@ -17,7 +18,7 @@ class XMLCallback(Callback): | ||||
|         run -- Overrides Callback.run | ||||
|     """ | ||||
|  | ||||
|     def run(self, payload, instream=False): | ||||
|     def run(self, payload: StanzaBase, instream: bool = False) -> None: | ||||
|         """ | ||||
|         Execute the callback function with the matched stanza's | ||||
|         XML contents, instead of the stanza itself. | ||||
| @@ -30,4 +31,4 @@ class XMLCallback(Callback): | ||||
|                         stream processing. Used only by prerun. | ||||
|                         Defaults to False. | ||||
|         """ | ||||
|         Callback.run(self, payload.xml, instream) | ||||
|         Callback.run(self, payload.xml, instream)  # type: ignore | ||||
|   | ||||
| @@ -3,6 +3,7 @@ | ||||
| # Copyright (C) 2010  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
| from slixmpp.xmlstream.handler import Waiter | ||||
|  | ||||
|  | ||||
| @@ -17,7 +18,7 @@ class XMLWaiter(Waiter): | ||||
|         prerun -- Overrides Waiter.prerun | ||||
|     """ | ||||
|  | ||||
|     def prerun(self, payload): | ||||
|     def prerun(self, payload: StanzaBase) -> None: | ||||
|         """ | ||||
|         Store the XML contents of the stanza to return to the | ||||
|         waiting event handler. | ||||
| @@ -27,4 +28,4 @@ class XMLWaiter(Waiter): | ||||
|         Arguments: | ||||
|             payload -- The matched stanza object. | ||||
|         """ | ||||
|         Waiter.prerun(self, payload.xml) | ||||
|         Waiter.prerun(self, payload.xml)  # type: ignore | ||||
|   | ||||
| @@ -1,10 +1,13 @@ | ||||
|  | ||||
| # slixmpp.xmlstream.matcher.base | ||||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
|  | ||||
| from typing import Any | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
|  | ||||
|  | ||||
| class MatcherBase(object): | ||||
|  | ||||
|     """ | ||||
| @@ -15,10 +18,10 @@ class MatcherBase(object): | ||||
|     :param criteria: Object to compare some aspect of a stanza against. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, criteria): | ||||
|     def __init__(self, criteria: Any): | ||||
|         self._criteria = criteria | ||||
|  | ||||
|     def match(self, xml): | ||||
|     def match(self, xml: StanzaBase) -> bool: | ||||
|         """Check if a stanza matches the stored criteria. | ||||
|  | ||||
|         Meant to be overridden. | ||||
|   | ||||
| @@ -5,6 +5,7 @@ | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
|  | ||||
|  | ||||
| class MatcherId(MatcherBase): | ||||
| @@ -13,12 +14,13 @@ class MatcherId(MatcherBase): | ||||
|     The ID matcher selects stanzas that have the same stanza 'id' | ||||
|     interface value as the desired ID. | ||||
|     """ | ||||
|     _criteria: str | ||||
|  | ||||
|     def match(self, xml): | ||||
|     def match(self, xml: StanzaBase) -> bool: | ||||
|         """Compare the given stanza's ``'id'`` attribute to the stored | ||||
|         ``id`` value. | ||||
|  | ||||
|         :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` | ||||
|         :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` | ||||
|                     stanza to compare against. | ||||
|         """ | ||||
|         return xml['id'] == self._criteria | ||||
|         return bool(xml['id'] == self._criteria) | ||||
|   | ||||
| @@ -4,7 +4,19 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
|  | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
| from slixmpp.jid import JID | ||||
| from slixmpp.types import TypedDict | ||||
|  | ||||
| from typing import Dict | ||||
|  | ||||
|  | ||||
| class CriteriaType(TypedDict): | ||||
|     self: JID | ||||
|     peer: JID | ||||
|     id: str | ||||
|  | ||||
|  | ||||
| class MatchIDSender(MatcherBase): | ||||
| @@ -14,25 +26,26 @@ class MatchIDSender(MatcherBase): | ||||
|     interface value as the desired ID, and that the 'from' value is one | ||||
|     of a set of approved entities that can respond to a request. | ||||
|     """ | ||||
|     _criteria: CriteriaType | ||||
|  | ||||
|     def match(self, xml): | ||||
|     def match(self, xml: StanzaBase) -> bool: | ||||
|         """Compare the given stanza's ``'id'`` attribute to the stored | ||||
|         ``id`` value, and verify the sender's JID. | ||||
|  | ||||
|         :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` | ||||
|         :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` | ||||
|                     stanza to compare against. | ||||
|         """ | ||||
|  | ||||
|         selfjid = self._criteria['self'] | ||||
|         peerjid = self._criteria['peer'] | ||||
|  | ||||
|         allowed = {} | ||||
|         allowed: Dict[str, bool] = {} | ||||
|         allowed[''] = True | ||||
|         allowed[selfjid.bare] = True | ||||
|         allowed[selfjid.host] = True | ||||
|         allowed[selfjid.domain] = True | ||||
|         allowed[peerjid.full] = True | ||||
|         allowed[peerjid.bare] = True | ||||
|         allowed[peerjid.host] = True | ||||
|         allowed[peerjid.domain] = True | ||||
|  | ||||
|         _from = xml['from'] | ||||
|  | ||||
|   | ||||
| @@ -3,7 +3,9 @@ | ||||
| # Copyright (C) 2010  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| # See the file LICENSE for copying permission. | ||||
| from typing import Iterable | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase | ||||
|  | ||||
|  | ||||
| class MatchMany(MatcherBase): | ||||
| @@ -18,8 +20,9 @@ class MatchMany(MatcherBase): | ||||
|     Methods: | ||||
|         match -- Overrides MatcherBase.match. | ||||
|     """ | ||||
|     _criteria: Iterable[MatcherBase] | ||||
|  | ||||
|     def match(self, xml): | ||||
|     def match(self, xml: StanzaBase) -> bool: | ||||
|         """ | ||||
|         Match a stanza against multiple criteria. The match is successful | ||||
|         if one of the criteria matches. | ||||
|   | ||||
| @@ -4,8 +4,9 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from typing import cast, List | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
| from slixmpp.xmlstream.stanzabase import fix_ns | ||||
| from slixmpp.xmlstream.stanzabase import fix_ns, StanzaBase | ||||
|  | ||||
|  | ||||
| class StanzaPath(MatcherBase): | ||||
| @@ -17,22 +18,28 @@ class StanzaPath(MatcherBase): | ||||
|  | ||||
|     :param criteria: Object to compare some aspect of a stanza against. | ||||
|     """ | ||||
|     _criteria: List[str] | ||||
|     _raw_criteria: str | ||||
|  | ||||
|     def __init__(self, criteria): | ||||
|         self._criteria = fix_ns(criteria, split=True, | ||||
|                                           propagate_ns=False, | ||||
|                                           default_ns='jabber:client') | ||||
|     def __init__(self, criteria: str): | ||||
|         self._criteria = cast( | ||||
|             List[str], | ||||
|             fix_ns( | ||||
|                 criteria, split=True, propagate_ns=False, | ||||
|                 default_ns='jabber:client' | ||||
|             ) | ||||
|         ) | ||||
|         self._raw_criteria = criteria | ||||
|  | ||||
|     def match(self, stanza): | ||||
|     def match(self, stanza: StanzaBase) -> bool: | ||||
|         """ | ||||
|         Compare a stanza against a "stanza path". A stanza path is similar to | ||||
|         an XPath expression, but uses the stanza's interfaces and plugins | ||||
|         instead of the underlying XML. See the documentation for the stanza | ||||
|         :meth:`~slixmpp.xmlstream.stanzabase.ElementBase.match()` method | ||||
|         :meth:`~slixmpp.xmlstream.stanzabase.StanzaBase.match()` method | ||||
|         for more information. | ||||
|  | ||||
|         :param stanza: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` | ||||
|         :param stanza: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` | ||||
|                        stanza to compare against. | ||||
|         """ | ||||
|         return stanza.match(self._criteria) or stanza.match(self._raw_criteria) | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # Slixmpp: The Slick XMPP Library | ||||
| # Copyright (C) 2010  Nathanael C. Fritz | ||||
| # This file is part of Slixmpp. | ||||
| @@ -6,8 +5,9 @@ | ||||
| import logging | ||||
|  | ||||
| from xml.parsers.expat import ExpatError | ||||
| from xml.etree.ElementTree import Element | ||||
|  | ||||
| from slixmpp.xmlstream.stanzabase import ET | ||||
| from slixmpp.xmlstream.stanzabase import ET, StanzaBase | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
|  | ||||
|  | ||||
| @@ -33,32 +33,33 @@ class MatchXMLMask(MatcherBase): | ||||
|     :param criteria: Either an :class:`~xml.etree.ElementTree.Element` XML | ||||
|                      object or XML string to use as a mask. | ||||
|     """ | ||||
|     _criteria: Element | ||||
|  | ||||
|     def __init__(self, criteria, default_ns='jabber:client'): | ||||
|     def __init__(self, criteria: str, default_ns: str = 'jabber:client'): | ||||
|         MatcherBase.__init__(self, criteria) | ||||
|         if isinstance(criteria, str): | ||||
|             self._criteria = ET.fromstring(self._criteria) | ||||
|             self._criteria = ET.fromstring(criteria) | ||||
|         self.default_ns = default_ns | ||||
|  | ||||
|     def setDefaultNS(self, ns): | ||||
|     def setDefaultNS(self, ns: str) -> None: | ||||
|         """Set the default namespace to use during comparisons. | ||||
|  | ||||
|         :param ns: The new namespace to use as the default. | ||||
|         """ | ||||
|         self.default_ns = ns | ||||
|  | ||||
|     def match(self, xml): | ||||
|     def match(self, xml: StanzaBase) -> bool: | ||||
|         """Compare a stanza object or XML object against the stored XML mask. | ||||
|  | ||||
|         Overrides MatcherBase.match. | ||||
|  | ||||
|         :param xml: The stanza object or XML object to compare against. | ||||
|         """ | ||||
|         if hasattr(xml, 'xml'): | ||||
|             xml = xml.xml | ||||
|         return self._mask_cmp(xml, self._criteria, True) | ||||
|         real_xml = xml.xml | ||||
|         return self._mask_cmp(real_xml, self._criteria, True) | ||||
|  | ||||
|     def _mask_cmp(self, source, mask, use_ns=False, default_ns='__no_ns__'): | ||||
|     def _mask_cmp(self, source: Element, mask: Element, use_ns: bool = False, | ||||
|                   default_ns: str = '__no_ns__') -> bool: | ||||
|         """Compare an XML object against an XML mask. | ||||
|  | ||||
|         :param source: The :class:`~xml.etree.ElementTree.Element` XML object | ||||
| @@ -75,13 +76,6 @@ class MatchXMLMask(MatcherBase): | ||||
|             # If the element was not found. May happen during recursive calls. | ||||
|             return False | ||||
|  | ||||
|         # Convert the mask to an XML object if it is a string. | ||||
|         if not hasattr(mask, 'attrib'): | ||||
|             try: | ||||
|                 mask = ET.fromstring(mask) | ||||
|             except ExpatError: | ||||
|                 log.warning("Expat error: %s\nIn parsing: %s", '', mask) | ||||
|  | ||||
|         mask_ns_tag = "{%s}%s" % (self.default_ns, mask.tag) | ||||
|         if source.tag not in [mask.tag, mask_ns_tag]: | ||||
|             return False | ||||
|   | ||||
| @@ -4,7 +4,8 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from slixmpp.xmlstream.stanzabase import ET, fix_ns | ||||
| from typing import cast | ||||
| from slixmpp.xmlstream.stanzabase import ET, fix_ns, StanzaBase | ||||
| from slixmpp.xmlstream.matcher.base import MatcherBase | ||||
|  | ||||
|  | ||||
| @@ -17,23 +18,23 @@ class MatchXPath(MatcherBase): | ||||
|     If the value of :data:`IGNORE_NS` is set to ``True``, then XPath | ||||
|     expressions will be matched without using namespaces. | ||||
|     """ | ||||
|     _criteria: str | ||||
|  | ||||
|     def __init__(self, criteria): | ||||
|         self._criteria = fix_ns(criteria) | ||||
|     def __init__(self, criteria: str): | ||||
|         self._criteria = cast(str, fix_ns(criteria)) | ||||
|  | ||||
|     def match(self, xml): | ||||
|     def match(self, xml: StanzaBase) -> bool: | ||||
|         """ | ||||
|         Compare a stanza's XML contents to an XPath expression. | ||||
|  | ||||
|         If the value of :data:`IGNORE_NS` is set to ``True``, then XPath | ||||
|         expressions will be matched without using namespaces. | ||||
|  | ||||
|         :param xml: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` | ||||
|         :param xml: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` | ||||
|                     stanza to compare against. | ||||
|         """ | ||||
|         if hasattr(xml, 'xml'): | ||||
|             xml = xml.xml | ||||
|         real_xml = xml.xml | ||||
|         x = ET.Element('x') | ||||
|         x.append(xml) | ||||
|         x.append(real_xml) | ||||
|  | ||||
|         return x.find(self._criteria) is not None | ||||
|   | ||||
| @@ -1,18 +1,32 @@ | ||||
|  | ||||
| # slixmpp.xmlstream.dns | ||||
| # ~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| # :copyright: (c) 2012 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
|  | ||||
| from slixmpp.xmlstream.asyncio import asyncio | ||||
| import socket | ||||
| import sys | ||||
| import logging | ||||
| import random | ||||
| from asyncio import Future, AbstractEventLoop | ||||
| from typing import Optional, Tuple, Dict, List, Iterable, cast | ||||
| from slixmpp.types import Protocol | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class AnswerProtocol(Protocol): | ||||
|     host: str | ||||
|     priority: int | ||||
|     weight: int | ||||
|     port: int | ||||
|  | ||||
|  | ||||
| class ResolverProtocol(Protocol): | ||||
|     def query(self, query: str, querytype: str) -> Future: | ||||
|         ... | ||||
|  | ||||
|  | ||||
| #: Global flag indicating the availability of the ``aiodns`` package. | ||||
| #: Installing ``aiodns`` can be done via: | ||||
| #: | ||||
| @@ -23,12 +37,12 @@ AIODNS_AVAILABLE = False | ||||
| try: | ||||
|     import aiodns | ||||
|     AIODNS_AVAILABLE = True | ||||
| except ImportError as e: | ||||
|     log.debug("Could not find aiodns package. " + \ | ||||
| except ImportError: | ||||
|     log.debug("Could not find aiodns package. " | ||||
|               "Not all features will be available") | ||||
|  | ||||
|  | ||||
| def default_resolver(loop): | ||||
| def default_resolver(loop: AbstractEventLoop) -> Optional[ResolverProtocol]: | ||||
|     """Return a basic DNS resolver object. | ||||
|  | ||||
|     :returns: A :class:`aiodns.DNSResolver` object if aiodns | ||||
| @@ -41,8 +55,11 @@ def default_resolver(loop): | ||||
|     return None | ||||
|  | ||||
|  | ||||
| async def resolve(host, port=None, service=None, proto='tcp', | ||||
|             resolver=None, use_ipv6=True, use_aiodns=True, loop=None): | ||||
| async def resolve(host: str, port: int, *, loop: AbstractEventLoop, | ||||
|                   service: Optional[str] = None, proto: str = 'tcp', | ||||
|                   resolver: Optional[ResolverProtocol] = None, | ||||
|                   use_ipv6: bool = True, | ||||
|                   use_aiodns: bool = True) -> List[Tuple[str, str, int]]: | ||||
|     """Peform DNS resolution for a given hostname. | ||||
|  | ||||
|     Resolution may perform SRV record lookups if a service and protocol | ||||
| @@ -91,8 +108,8 @@ async def resolve(host, port=None, service=None, proto='tcp', | ||||
|     if not use_ipv6: | ||||
|         log.debug("DNS: Use of IPv6 has been disabled.") | ||||
|  | ||||
|     if resolver is None and AIODNS_AVAILABLE and use_aiodns: | ||||
|         resolver = aiodns.DNSResolver(loop=loop) | ||||
|     if resolver is None and use_aiodns: | ||||
|         resolver = default_resolver(loop=loop) | ||||
|  | ||||
|     # An IPv6 literal is allowed to be enclosed in square brackets, but | ||||
|     # the brackets must be stripped in order to process the literal; | ||||
| @@ -101,7 +118,7 @@ async def resolve(host, port=None, service=None, proto='tcp', | ||||
|  | ||||
|     try: | ||||
|         # If `host` is an IPv4 literal, we can return it immediately. | ||||
|         ipv4 = socket.inet_aton(host) | ||||
|         socket.inet_aton(host) | ||||
|         return [(host, host, port)] | ||||
|     except socket.error: | ||||
|         pass | ||||
| @@ -111,7 +128,7 @@ async def resolve(host, port=None, service=None, proto='tcp', | ||||
|             # Likewise, If `host` is an IPv6 literal, we can return | ||||
|             # it immediately. | ||||
|             if hasattr(socket, 'inet_pton'): | ||||
|                 ipv6 = socket.inet_pton(socket.AF_INET6, host) | ||||
|                 socket.inet_pton(socket.AF_INET6, host) | ||||
|                 return [(host, host, port)] | ||||
|         except (socket.error, ValueError): | ||||
|             pass | ||||
| @@ -148,7 +165,10 @@ async def resolve(host, port=None, service=None, proto='tcp', | ||||
|  | ||||
|     return results | ||||
|  | ||||
| async def get_A(host, resolver=None, use_aiodns=True, loop=None): | ||||
|  | ||||
| async def get_A(host: str, *, loop: AbstractEventLoop, | ||||
|                 resolver: Optional[ResolverProtocol] = None, | ||||
|                 use_aiodns: bool = True) -> List[str]: | ||||
|     """Lookup DNS A records for a given host. | ||||
|  | ||||
|     If ``resolver`` is not provided, or is ``None``, then resolution will | ||||
| @@ -172,10 +192,10 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None): | ||||
|     # getaddrinfo() method. | ||||
|     if resolver is None or not use_aiodns: | ||||
|         try: | ||||
|             recs = await loop.getaddrinfo(host, None, | ||||
|             inet_recs = await loop.getaddrinfo(host, None, | ||||
|                                                family=socket.AF_INET, | ||||
|                                                type=socket.SOCK_STREAM) | ||||
|             return [rec[4][0] for rec in recs] | ||||
|             return [rec[4][0] for rec in inet_recs] | ||||
|         except socket.gaierror: | ||||
|             log.debug("DNS: Error retrieving A address info for %s." % host) | ||||
|             return [] | ||||
| @@ -183,14 +203,16 @@ async def get_A(host, resolver=None, use_aiodns=True, loop=None): | ||||
|     # Using aiodns: | ||||
|     future = resolver.query(host, 'A') | ||||
|     try: | ||||
|         recs = await future | ||||
|         recs = cast(Iterable[AnswerProtocol], await future) | ||||
|     except Exception as e: | ||||
|         log.debug('DNS: Exception while querying for %s A records: %s', host, e) | ||||
|         recs = [] | ||||
|     return [rec.host for rec in recs] | ||||
|  | ||||
|  | ||||
| async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): | ||||
| async def get_AAAA(host: str, *, loop: AbstractEventLoop, | ||||
|                    resolver: Optional[ResolverProtocol] = None, | ||||
|                    use_aiodns: bool = True) -> List[str]: | ||||
|     """Lookup DNS AAAA records for a given host. | ||||
|  | ||||
|     If ``resolver`` is not provided, or is ``None``, then resolution will | ||||
| @@ -217,10 +239,10 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): | ||||
|             log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) | ||||
|             return [] | ||||
|         try: | ||||
|             recs = await loop.getaddrinfo(host, None, | ||||
|             inet_recs = await loop.getaddrinfo(host, None, | ||||
|                                                family=socket.AF_INET6, | ||||
|                                                type=socket.SOCK_STREAM) | ||||
|             return [rec[4][0] for rec in recs] | ||||
|             return [rec[4][0] for rec in inet_recs] | ||||
|         except (OSError, socket.gaierror): | ||||
|             log.debug("DNS: Error retrieving AAAA address " + \ | ||||
|                       "info for %s." % host) | ||||
| @@ -229,13 +251,17 @@ async def get_AAAA(host, resolver=None, use_aiodns=True, loop=None): | ||||
|     # Using aiodns: | ||||
|     future = resolver.query(host, 'AAAA') | ||||
|     try: | ||||
|         recs = await future | ||||
|         recs = cast(Iterable[AnswerProtocol], await future) | ||||
|     except Exception as e: | ||||
|         log.debug('DNS: Exception while querying for %s AAAA records: %s', host, e) | ||||
|         recs = [] | ||||
|     return [rec.host for rec in recs] | ||||
|  | ||||
| async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=True): | ||||
|  | ||||
| async def get_SRV(host: str, port: int, service: str, | ||||
|                   proto: str = 'tcp', | ||||
|                   resolver: Optional[ResolverProtocol] = None, | ||||
|                   use_aiodns: bool = True) -> List[Tuple[str, int]]: | ||||
|     """Perform SRV record resolution for a given host. | ||||
|  | ||||
|     .. note:: | ||||
| @@ -269,12 +295,12 @@ async def get_SRV(host, port, service, proto='tcp', resolver=None, use_aiodns=Tr | ||||
|     try: | ||||
|         future = resolver.query('_%s._%s.%s' % (service, proto, host), | ||||
|                                 'SRV') | ||||
|         recs = await future | ||||
|         recs = cast(Iterable[AnswerProtocol], await future) | ||||
|     except Exception as e: | ||||
|         log.debug('DNS: Exception while querying for %s SRV records: %s', host, e) | ||||
|         return [] | ||||
|  | ||||
|     answers = {} | ||||
|     answers: Dict[int, List[AnswerProtocol]] = {} | ||||
|     for rec in recs: | ||||
|         if rec.priority not in answers: | ||||
|             answers[rec.priority] = [] | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # slixmpp.xmlstream.stanzabase | ||||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| # module implements a wrapper layer for XML objects | ||||
| @@ -11,13 +10,34 @@ from __future__ import annotations | ||||
| import copy | ||||
| import logging | ||||
| import weakref | ||||
| from typing import Optional | ||||
| from typing import ( | ||||
|     cast, | ||||
|     Any, | ||||
|     Callable, | ||||
|     ClassVar, | ||||
|     Coroutine, | ||||
|     Dict, | ||||
|     List, | ||||
|     Iterable, | ||||
|     Optional, | ||||
|     Set, | ||||
|     Tuple, | ||||
|     Type, | ||||
|     TYPE_CHECKING, | ||||
|     Union, | ||||
| ) | ||||
| from weakref import ReferenceType | ||||
| from xml.etree import ElementTree as ET | ||||
|  | ||||
| from slixmpp.types import JidStr | ||||
| from slixmpp.xmlstream import JID | ||||
| from slixmpp.xmlstream.tostring import tostring | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from slixmpp.xmlstream import XMLStream | ||||
|  | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @@ -28,7 +48,8 @@ XML_TYPE = type(ET.Element('xml')) | ||||
| XML_NS = 'http://www.w3.org/XML/1998/namespace' | ||||
|  | ||||
|  | ||||
| def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False): | ||||
| def register_stanza_plugin(stanza: Type[ElementBase], plugin: Type[ElementBase], | ||||
|                            iterable: bool = False, overrides: bool = False) -> None: | ||||
|     """ | ||||
|     Associate a stanza object as a plugin for another stanza. | ||||
|  | ||||
| @@ -85,15 +106,15 @@ def register_stanza_plugin(stanza, plugin, iterable=False, overrides=False): | ||||
|             stanza.plugin_overrides[interface] = plugin.plugin_attrib | ||||
|  | ||||
|  | ||||
| def multifactory(stanza, plugin_attrib): | ||||
| def multifactory(stanza: Type[ElementBase], plugin_attrib: str) -> Type[ElementBase]: | ||||
|     """ | ||||
|     Returns a ElementBase class for handling reoccuring child stanzas | ||||
|     """ | ||||
|  | ||||
|     def plugin_filter(self): | ||||
|     def plugin_filter(self: Multi) -> Callable[..., bool]: | ||||
|         return lambda x: isinstance(x, self._multistanza) | ||||
|  | ||||
|     def plugin_lang_filter(self, lang): | ||||
|     def plugin_lang_filter(self: Multi, lang: Optional[str]) -> Callable[..., bool]: | ||||
|         return lambda x: isinstance(x, self._multistanza) and \ | ||||
|                          x['lang'] == lang | ||||
|  | ||||
| @@ -101,31 +122,41 @@ def multifactory(stanza, plugin_attrib): | ||||
|         """ | ||||
|         Template class for multifactory | ||||
|         """ | ||||
|         def setup(self, xml=None): | ||||
|             self.xml = ET.Element('') | ||||
|         _multistanza: Type[ElementBase] | ||||
|  | ||||
|     def get_multi(self, lang=None): | ||||
|         parent = self.parent() | ||||
|         def setup(self, xml: Optional[ET.Element] = None) -> bool: | ||||
|             self.xml = ET.Element('') | ||||
|             return False | ||||
|  | ||||
|     def get_multi(self: Multi, lang: Optional[str] = None) -> List[ElementBase]: | ||||
|         parent = fail_without_parent(self) | ||||
|         if not lang or lang == '*': | ||||
|             res = filter(plugin_filter(self), parent) | ||||
|         else: | ||||
|             res = filter(plugin_filter(self, lang), parent) | ||||
|             res = filter(plugin_lang_filter(self, lang), parent) | ||||
|         return list(res) | ||||
|  | ||||
|     def set_multi(self, val, lang=None): | ||||
|         parent = self.parent() | ||||
|     def set_multi(self: Multi, val: Iterable[ElementBase], lang: Optional[str] = None) -> None: | ||||
|         parent = fail_without_parent(self) | ||||
|         del_multi = getattr(self, 'del_%s' % plugin_attrib) | ||||
|         del_multi(lang) | ||||
|         for sub in val: | ||||
|             parent.append(sub) | ||||
|  | ||||
|     def del_multi(self, lang=None): | ||||
|         parent = self.parent() | ||||
|     def fail_without_parent(self: Multi) -> ElementBase: | ||||
|         parent = None | ||||
|         if self.parent: | ||||
|             parent = self.parent() | ||||
|         if not parent: | ||||
|             raise ValueError('No stanza parent for multifactory') | ||||
|         return parent | ||||
|  | ||||
|     def del_multi(self: Multi, lang: Optional[str] = None) -> None: | ||||
|         parent = fail_without_parent(self) | ||||
|         if not lang or lang == '*': | ||||
|             res = filter(plugin_filter(self), parent) | ||||
|             res = list(filter(plugin_filter(self), parent)) | ||||
|         else: | ||||
|             res = filter(plugin_filter(self, lang), parent) | ||||
|         res = list(res) | ||||
|             res = list(filter(plugin_lang_filter(self, lang), parent)) | ||||
|         if not res: | ||||
|             del parent.plugins[(plugin_attrib, None)] | ||||
|             parent.loaded_plugins.remove(plugin_attrib) | ||||
| @@ -149,7 +180,8 @@ def multifactory(stanza, plugin_attrib): | ||||
|     return Multi | ||||
|  | ||||
|  | ||||
| def fix_ns(xpath, split=False, propagate_ns=True, default_ns=''): | ||||
| def fix_ns(xpath: str, split: bool = False, propagate_ns: bool = True, | ||||
|            default_ns: str = '') -> Union[str, List[str]]: | ||||
|     """Apply the stanza's namespace to elements in an XPath expression. | ||||
|  | ||||
|     :param string xpath: The XPath expression to fix with namespaces. | ||||
| @@ -275,12 +307,12 @@ class ElementBase(object): | ||||
|     #: The XML tag name of the element, not including any namespace | ||||
|     #: prefixes. For example, an :class:`ElementBase` object for | ||||
|     #: ``<message />`` would use ``name = 'message'``. | ||||
|     name = 'stanza' | ||||
|     name: ClassVar[str] = 'stanza' | ||||
|  | ||||
|     #: The XML namespace for the element. Given ``<foo xmlns="bar" />``, | ||||
|     #: then ``namespace = "bar"`` should be used. The default namespace | ||||
|     #: is ``jabber:client`` since this is being used in an XMPP library. | ||||
|     namespace = 'jabber:client' | ||||
|     namespace: str = 'jabber:client' | ||||
|  | ||||
|     #: For :class:`ElementBase` subclasses which are intended to be used | ||||
|     #: as plugins, the ``plugin_attrib`` value defines the plugin name. | ||||
| @@ -290,7 +322,7 @@ class ElementBase(object): | ||||
|     #:     register_stanza_plugin(Message, FooPlugin) | ||||
|     #:     msg = Message() | ||||
|     #:     msg['foo']['an_interface_from_the_foo_plugin'] | ||||
|     plugin_attrib = 'plugin' | ||||
|     plugin_attrib: ClassVar[str] = 'plugin' | ||||
|  | ||||
|     #: For :class:`ElementBase` subclasses that are intended to be an | ||||
|     #: iterable group of items, the ``plugin_multi_attrib`` value defines | ||||
| @@ -300,29 +332,29 @@ class ElementBase(object): | ||||
|     #:     # Given stanza class Foo, with plugin_multi_attrib = 'foos' | ||||
|     #:     parent['foos'] | ||||
|     #:     filter(isinstance(item, Foo), parent['substanzas']) | ||||
|     plugin_multi_attrib = '' | ||||
|     plugin_multi_attrib: ClassVar[str] = '' | ||||
|  | ||||
|     #: The set of keys that the stanza provides for accessing and | ||||
|     #: manipulating the underlying XML object. This set may be augmented | ||||
|     #: with the :attr:`plugin_attrib` value of any registered | ||||
|     #: stanza plugins. | ||||
|     interfaces = {'type', 'to', 'from', 'id', 'payload'} | ||||
|     interfaces: ClassVar[Set[str]] = {'type', 'to', 'from', 'id', 'payload'} | ||||
|  | ||||
|     #: A subset of :attr:`interfaces` which maps interfaces to direct | ||||
|     #: subelements of the underlying XML object. Using this set, the text | ||||
|     #: of these subelements may be set, retrieved, or removed without | ||||
|     #: needing to define custom methods. | ||||
|     sub_interfaces = set() | ||||
|     sub_interfaces: ClassVar[Set[str]] = set() | ||||
|  | ||||
|     #: A subset of :attr:`interfaces` which maps the presence of | ||||
|     #: subelements to boolean values. Using this set allows for quickly | ||||
|     #: checking for the existence of empty subelements like ``<required />``. | ||||
|     #: | ||||
|     #: .. versionadded:: 1.1 | ||||
|     bool_interfaces = set() | ||||
|     bool_interfaces: ClassVar[Set[str]] = set() | ||||
|  | ||||
|     #: .. versionadded:: 1.1.2 | ||||
|     lang_interfaces = set() | ||||
|     lang_interfaces: ClassVar[Set[str]] = set() | ||||
|  | ||||
|     #: In some cases you may wish to override the behaviour of one of the | ||||
|     #: parent stanza's interfaces. The ``overrides`` list specifies the | ||||
| @@ -336,7 +368,7 @@ class ElementBase(object): | ||||
|     #: be affected. | ||||
|     #: | ||||
|     #: .. versionadded:: 1.0-Beta5 | ||||
|     overrides = [] | ||||
|     overrides: ClassVar[List[str]] = [] | ||||
|  | ||||
|     #: If you need to add a new interface to an existing stanza, you | ||||
|     #: can create a plugin and set ``is_extension = True``. Be sure | ||||
| @@ -346,7 +378,7 @@ class ElementBase(object): | ||||
|     #: parent stanza will be passed to the plugin directly. | ||||
|     #: | ||||
|     #: .. versionadded:: 1.0-Beta5 | ||||
|     is_extension = False | ||||
|     is_extension: ClassVar[bool] = False | ||||
|  | ||||
|     #: A map of interface operations to the overriding functions. | ||||
|     #: For example, after overriding the ``set`` operation for | ||||
| @@ -355,15 +387,15 @@ class ElementBase(object): | ||||
|     #:     {'set_body': <some function>} | ||||
|     #: | ||||
|     #: .. versionadded: 1.0-Beta5 | ||||
|     plugin_overrides = {} | ||||
|     plugin_overrides: ClassVar[Dict[str, str]] = {} | ||||
|  | ||||
|     #: A mapping of the :attr:`plugin_attrib` values of registered | ||||
|     #: plugins to their respective classes. | ||||
|     plugin_attrib_map = {} | ||||
|     plugin_attrib_map: ClassVar[Dict[str, Type[ElementBase]]] = {} | ||||
|  | ||||
|     #: A mapping of root element tag names (in ``'{namespace}elementname'`` | ||||
|     #: format) to the plugin classes responsible for them. | ||||
|     plugin_tag_map = {} | ||||
|     plugin_tag_map: ClassVar[Dict[str, Type[ElementBase]]] = {} | ||||
|  | ||||
|     #: The set of stanza classes that can be iterated over using | ||||
|     #: the 'substanzas' interface. Classes are added to this set | ||||
| @@ -372,17 +404,26 @@ class ElementBase(object): | ||||
|     #:     register_stanza_plugin(DiscoInfo, DiscoItem, iterable=True) | ||||
|     #: | ||||
|     #: .. versionadded:: 1.0-Beta5 | ||||
|     plugin_iterables = set() | ||||
|     plugin_iterables: ClassVar[Set[Type[ElementBase]]] = set() | ||||
|  | ||||
|     #: The default XML namespace: ``http://www.w3.org/XML/1998/namespace``. | ||||
|     xml_ns = XML_NS | ||||
|     xml_ns: ClassVar[str] = XML_NS | ||||
|  | ||||
|     def __init__(self, xml=None, parent=None): | ||||
|     plugins: Dict[Tuple[str, Optional[str]], ElementBase] | ||||
|     #: The underlying XML object for the stanza. It is a standard | ||||
|     #: :class:`xml.etree.ElementTree` object. | ||||
|     xml: ET.Element | ||||
|     _index: int | ||||
|     loaded_plugins: Set[str] | ||||
|     iterables: List[ElementBase] | ||||
|     tag: str | ||||
|     parent: Optional[ReferenceType[ElementBase]] | ||||
|  | ||||
|     def __init__(self, xml: Optional[ET.Element] = None, parent: Union[Optional[ElementBase], ReferenceType[ElementBase]] = None): | ||||
|         self._index = 0 | ||||
|  | ||||
|         #: The underlying XML object for the stanza. It is a standard | ||||
|         #: :class:`xml.etree.ElementTree` object. | ||||
|         self.xml = xml | ||||
|         if xml is not None: | ||||
|             self.xml = xml | ||||
|  | ||||
|         #: An ordered dictionary of plugin stanzas, mapped by their | ||||
|         #: :attr:`plugin_attrib` value. | ||||
| @@ -419,7 +460,7 @@ class ElementBase(object): | ||||
|                                  existing_xml=child, | ||||
|                                  reuse=False) | ||||
|  | ||||
|     def setup(self, xml=None): | ||||
|     def setup(self, xml: Optional[ET.Element] = None) -> bool: | ||||
|         """Initialize the stanza's XML contents. | ||||
|  | ||||
|         Will return ``True`` if XML was generated according to the stanza's | ||||
| @@ -429,29 +470,31 @@ class ElementBase(object): | ||||
|         :param xml: An existing XML object to use for the stanza's content | ||||
|                     instead of generating new XML. | ||||
|         """ | ||||
|         if self.xml is None: | ||||
|         if hasattr(self, 'xml'): | ||||
|             return False | ||||
|         if not hasattr(self, 'xml') and xml is not None: | ||||
|             self.xml = xml | ||||
|  | ||||
|         last_xml = self.xml | ||||
|         if self.xml is None: | ||||
|             # Generate XML from the stanza definition | ||||
|             for ename in self.name.split('/'): | ||||
|                 new = ET.Element("{%s}%s" % (self.namespace, ename)) | ||||
|                 if self.xml is None: | ||||
|                     self.xml = new | ||||
|                 else: | ||||
|                     last_xml.append(new) | ||||
|                 last_xml = new | ||||
|             if self.parent is not None: | ||||
|                 self.parent().xml.append(self.xml) | ||||
|  | ||||
|             # We had to generate XML | ||||
|             return True | ||||
|         else: | ||||
|             # We did not generate XML | ||||
|             return False | ||||
|  | ||||
|     def enable(self, attrib, lang=None): | ||||
|  | ||||
|         # Generate XML from the stanza definition | ||||
|         last_xml = ET.Element('') | ||||
|         for ename in self.name.split('/'): | ||||
|             new = ET.Element("{%s}%s" % (self.namespace, ename)) | ||||
|             if not hasattr(self, 'xml'): | ||||
|                 self.xml = new | ||||
|             else: | ||||
|                 last_xml.append(new) | ||||
|             last_xml = new | ||||
|         if self.parent is not None: | ||||
|             parent = self.parent() | ||||
|             if parent: | ||||
|                 parent.xml.append(self.xml) | ||||
|  | ||||
|         # We had to generate XML | ||||
|         return True | ||||
|  | ||||
|     def enable(self, attrib: str, lang: Optional[str] = None) -> ElementBase: | ||||
|         """Enable and initialize a stanza plugin. | ||||
|  | ||||
|         Alias for :meth:`init_plugin`. | ||||
| @@ -487,7 +530,10 @@ class ElementBase(object): | ||||
|             else: | ||||
|                 return None if check else self.init_plugin(name, lang) | ||||
|  | ||||
|     def init_plugin(self, attrib, lang=None, existing_xml=None, element=None, reuse=True): | ||||
|     def init_plugin(self, attrib: str, lang: Optional[str] = None, | ||||
|                     existing_xml: Optional[ET.Element] = None, | ||||
|                     reuse: bool = True, | ||||
|                     element: Optional[ElementBase] = None) -> ElementBase: | ||||
|         """Enable and initialize a stanza plugin. | ||||
|  | ||||
|         :param string attrib: The :attr:`plugin_attrib` value of the | ||||
| @@ -525,7 +571,7 @@ class ElementBase(object): | ||||
|  | ||||
|         return plugin | ||||
|  | ||||
|     def _get_stanza_values(self): | ||||
|     def _get_stanza_values(self) -> Dict[str, Any]: | ||||
|         """Return A JSON/dictionary version of the XML content | ||||
|         exposed through the stanza's interfaces:: | ||||
|  | ||||
| @@ -567,7 +613,7 @@ class ElementBase(object): | ||||
|             values['substanzas'] = iterables | ||||
|         return values | ||||
|  | ||||
|     def _set_stanza_values(self, values): | ||||
|     def _set_stanza_values(self, values: Dict[str, Any]) -> ElementBase: | ||||
|         """Set multiple stanza interface values using a dictionary. | ||||
|  | ||||
|         Stanza plugin values may be set using nested dictionaries. | ||||
| @@ -623,7 +669,7 @@ class ElementBase(object): | ||||
|                         plugin.values = value | ||||
|         return self | ||||
|  | ||||
|     def __getitem__(self, full_attrib): | ||||
|     def __getitem__(self, full_attrib: str) -> Any: | ||||
|         """Return the value of a stanza interface using dict-like syntax. | ||||
|  | ||||
|         Example:: | ||||
| @@ -688,7 +734,7 @@ class ElementBase(object): | ||||
|         else: | ||||
|             return '' | ||||
|  | ||||
|     def __setitem__(self, attrib, value): | ||||
|     def __setitem__(self, attrib: str, value: Any) -> Any: | ||||
|         """Set the value of a stanza interface using dictionary-like syntax. | ||||
|  | ||||
|         Example:: | ||||
| @@ -773,7 +819,7 @@ class ElementBase(object): | ||||
|                 plugin[full_attrib] = value | ||||
|         return self | ||||
|  | ||||
|     def __delitem__(self, attrib): | ||||
|     def __delitem__(self, attrib: str) -> Any: | ||||
|         """Delete the value of a stanza interface using dict-like syntax. | ||||
|  | ||||
|         Example:: | ||||
| @@ -851,7 +897,7 @@ class ElementBase(object): | ||||
|                 pass | ||||
|         return self | ||||
|  | ||||
|     def _set_attr(self, name, value): | ||||
|     def _set_attr(self, name: str, value: Optional[JidStr]) -> None: | ||||
|         """Set the value of a top level attribute of the XML object. | ||||
|  | ||||
|         If the new value is None or an empty string, then the attribute will | ||||
| @@ -868,7 +914,7 @@ class ElementBase(object): | ||||
|                 value = str(value) | ||||
|             self.xml.attrib[name] = value | ||||
|  | ||||
|     def _del_attr(self, name): | ||||
|     def _del_attr(self, name: str) -> None: | ||||
|         """Remove a top level attribute of the XML object. | ||||
|  | ||||
|         :param name: The name of the attribute. | ||||
| @@ -876,7 +922,7 @@ class ElementBase(object): | ||||
|         if name in self.xml.attrib: | ||||
|             del self.xml.attrib[name] | ||||
|  | ||||
|     def _get_attr(self, name, default=''): | ||||
|     def _get_attr(self, name: str, default: str = '') -> str: | ||||
|         """Return the value of a top level attribute of the XML object. | ||||
|  | ||||
|         In case the attribute has not been set, a default value can be | ||||
| @@ -889,7 +935,8 @@ class ElementBase(object): | ||||
|         """ | ||||
|         return self.xml.attrib.get(name, default) | ||||
|  | ||||
|     def _get_sub_text(self, name, default='', lang=None): | ||||
|     def _get_sub_text(self, name: str, default: str = '', | ||||
|                       lang: Optional[str] = None) -> Union[str, Dict[str, str]]: | ||||
|         """Return the text contents of a sub element. | ||||
|  | ||||
|         In case the element does not exist, or it has no textual content, | ||||
| @@ -900,7 +947,7 @@ class ElementBase(object): | ||||
|         :param default: Optional default to return if the element does | ||||
|                         not exists. An empty string is returned otherwise. | ||||
|         """ | ||||
|         name = self._fix_ns(name) | ||||
|         name = cast(str, self._fix_ns(name)) | ||||
|         if lang == '*': | ||||
|             return self._get_all_sub_text(name, default, None) | ||||
|  | ||||
| @@ -924,8 +971,9 @@ class ElementBase(object): | ||||
|             return result | ||||
|         return default | ||||
|  | ||||
|     def _get_all_sub_text(self, name, default='', lang=None): | ||||
|         name = self._fix_ns(name) | ||||
|     def _get_all_sub_text(self, name: str, default: str = '', | ||||
|                           lang: Optional[str] = None) -> Dict[str, str]: | ||||
|         name = cast(str, self._fix_ns(name)) | ||||
|  | ||||
|         default_lang = self.get_lang() | ||||
|         results = {} | ||||
| @@ -935,10 +983,16 @@ class ElementBase(object): | ||||
|                 stanza_lang = stanza.attrib.get('{%s}lang' % XML_NS, | ||||
|                                                 default_lang) | ||||
|                 if not lang or lang == '*' or stanza_lang == lang: | ||||
|                     results[stanza_lang] = stanza.text | ||||
|                     if stanza.text is None: | ||||
|                         text = default | ||||
|                     else: | ||||
|                         text = stanza.text | ||||
|                     results[stanza_lang] = text | ||||
|         return results | ||||
|  | ||||
|     def _set_sub_text(self, name, text=None, keep=False, lang=None): | ||||
|     def _set_sub_text(self, name: str, text: Optional[str] = None, | ||||
|                       keep: bool = False, | ||||
|                       lang: Optional[str] = None) -> Optional[ET.Element]: | ||||
|         """Set the text contents of a sub element. | ||||
|  | ||||
|         In case the element does not exist, a element will be created, | ||||
| @@ -959,15 +1013,16 @@ class ElementBase(object): | ||||
|             lang = default_lang | ||||
|  | ||||
|         if not text and not keep: | ||||
|             return self._del_sub(name, lang=lang) | ||||
|             self._del_sub(name, lang=lang) | ||||
|             return None | ||||
|  | ||||
|         path = self._fix_ns(name, split=True) | ||||
|         path = cast(List[str], self._fix_ns(name, split=True)) | ||||
|         name = path[-1] | ||||
|         parent = self.xml | ||||
|         parent: Optional[ET.Element] = self.xml | ||||
|  | ||||
|         # The first goal is to find the parent of the subelement, or, if | ||||
|         # we can't find that, the closest grandparent element. | ||||
|         missing_path = [] | ||||
|         missing_path: List[str] = [] | ||||
|         search_order = path[:-1] | ||||
|         while search_order: | ||||
|             parent = self.xml.find('/'.join(search_order)) | ||||
| @@ -1008,15 +1063,17 @@ class ElementBase(object): | ||||
|         parent.append(element) | ||||
|         return element | ||||
|  | ||||
|     def _set_all_sub_text(self, name, values, keep=False, lang=None): | ||||
|         self._del_sub(name, lang) | ||||
|     def _set_all_sub_text(self, name: str, values: Dict[str, str], | ||||
|                           keep: bool = False, | ||||
|                           lang: Optional[str] = None) -> None: | ||||
|         self._del_sub(name, lang=lang) | ||||
|         for value_lang, value in values.items(): | ||||
|             if not lang or lang == '*' or value_lang == lang: | ||||
|                 self._set_sub_text(name, text=value, | ||||
|                                          keep=keep, | ||||
|                                          lang=value_lang) | ||||
|  | ||||
|     def _del_sub(self, name, all=False, lang=None): | ||||
|     def _del_sub(self, name: str, all: bool = False, lang: Optional[str] = None) -> None: | ||||
|         """Remove sub elements that match the given name or XPath. | ||||
|  | ||||
|         If the element is in a path, then any parent elements that become | ||||
| @@ -1034,11 +1091,11 @@ class ElementBase(object): | ||||
|         if not lang: | ||||
|             lang = default_lang | ||||
|  | ||||
|         parent = self.xml | ||||
|         parent: Optional[ET.Element] = self.xml | ||||
|         for level, _ in enumerate(path): | ||||
|             # Generate the paths to the target elements and their parent. | ||||
|             element_path = "/".join(path[:len(path) - level]) | ||||
|             parent_path = "/".join(path[:len(path) - level - 1]) | ||||
|             parent_path: Optional[str] = "/".join(path[:len(path) - level - 1]) | ||||
|  | ||||
|             elements = self.xml.findall(element_path) | ||||
|             if parent_path == '': | ||||
| @@ -1061,7 +1118,7 @@ class ElementBase(object): | ||||
|                 # after deleting the first level of elements. | ||||
|                 return | ||||
|  | ||||
|     def match(self, xpath): | ||||
|     def match(self, xpath: Union[str, List[str]]) -> bool: | ||||
|         """Compare a stanza object with an XPath-like expression. | ||||
|  | ||||
|         If the XPath matches the contents of the stanza object, the match | ||||
| @@ -1127,7 +1184,7 @@ class ElementBase(object): | ||||
|         # Everything matched. | ||||
|         return True | ||||
|  | ||||
|     def get(self, key, default=None): | ||||
|     def get(self, key: str, default: Optional[Any] = None) -> Any: | ||||
|         """Return the value of a stanza interface. | ||||
|  | ||||
|         If the found value is None or an empty string, return the supplied | ||||
| @@ -1144,7 +1201,7 @@ class ElementBase(object): | ||||
|             return default | ||||
|         return value | ||||
|  | ||||
|     def keys(self): | ||||
|     def keys(self) -> List[str]: | ||||
|         """Return the names of all stanza interfaces provided by the | ||||
|         stanza object. | ||||
|  | ||||
| @@ -1158,7 +1215,7 @@ class ElementBase(object): | ||||
|             out.append('substanzas') | ||||
|         return out | ||||
|  | ||||
|     def append(self, item): | ||||
|     def append(self, item: Union[ET.Element, ElementBase]) -> ElementBase: | ||||
|         """Append either an XML object or a substanza to this stanza object. | ||||
|  | ||||
|         If a substanza object is appended, it will be added to the list | ||||
| @@ -1189,7 +1246,7 @@ class ElementBase(object): | ||||
|  | ||||
|         return self | ||||
|  | ||||
|     def appendxml(self, xml): | ||||
|     def appendxml(self, xml: ET.Element) -> ElementBase: | ||||
|         """Append an XML object to the stanza's XML. | ||||
|  | ||||
|         The added XML will not be included in the list of | ||||
| @@ -1200,7 +1257,7 @@ class ElementBase(object): | ||||
|         self.xml.append(xml) | ||||
|         return self | ||||
|  | ||||
|     def pop(self, index=0): | ||||
|     def pop(self, index: int = 0) -> ElementBase: | ||||
|         """Remove and return the last substanza in the list of | ||||
|         iterable substanzas. | ||||
|  | ||||
| @@ -1212,11 +1269,11 @@ class ElementBase(object): | ||||
|         self.xml.remove(substanza.xml) | ||||
|         return substanza | ||||
|  | ||||
|     def next(self): | ||||
|     def next(self) -> ElementBase: | ||||
|         """Return the next iterable substanza.""" | ||||
|         return self.__next__() | ||||
|  | ||||
|     def clear(self): | ||||
|     def clear(self) -> ElementBase: | ||||
|         """Remove all XML element contents and plugins. | ||||
|  | ||||
|         Any attribute values will be preserved. | ||||
| @@ -1229,7 +1286,7 @@ class ElementBase(object): | ||||
|         return self | ||||
|  | ||||
|     @classmethod | ||||
|     def tag_name(cls): | ||||
|     def tag_name(cls) -> str: | ||||
|         """Return the namespaced name of the stanza's root element. | ||||
|  | ||||
|         The format for the tag name is:: | ||||
| @@ -1241,29 +1298,32 @@ class ElementBase(object): | ||||
|         """ | ||||
|         return "{%s}%s" % (cls.namespace, cls.name) | ||||
|  | ||||
|     def get_lang(self, lang=None): | ||||
|     def get_lang(self, lang: Optional[str] = None) -> str: | ||||
|         result = self.xml.attrib.get('{%s}lang' % XML_NS, '') | ||||
|         if not result and self.parent and self.parent(): | ||||
|             return self.parent()['lang'] | ||||
|         if not result and self.parent: | ||||
|             parent = self.parent() | ||||
|             if parent: | ||||
|                 return cast(str, parent['lang']) | ||||
|         return result | ||||
|  | ||||
|     def set_lang(self, lang): | ||||
|     def set_lang(self, lang: Optional[str]) -> None: | ||||
|         self.del_lang() | ||||
|         attr = '{%s}lang' % XML_NS | ||||
|         if lang: | ||||
|             self.xml.attrib[attr] = lang | ||||
|  | ||||
|     def del_lang(self): | ||||
|     def del_lang(self) -> None: | ||||
|         attr = '{%s}lang' % XML_NS | ||||
|         if attr in self.xml.attrib: | ||||
|             del self.xml.attrib[attr] | ||||
|  | ||||
|     def _fix_ns(self, xpath, split=False, propagate_ns=True): | ||||
|     def _fix_ns(self, xpath: str, split: bool = False, | ||||
|                 propagate_ns: bool = True) -> Union[str, List[str]]: | ||||
|         return fix_ns(xpath, split=split, | ||||
|                              propagate_ns=propagate_ns, | ||||
|                              default_ns=self.namespace) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|     def __eq__(self, other: Any) -> bool: | ||||
|         """Compare the stanza object with another to test for equality. | ||||
|  | ||||
|         Stanzas are equal if their interfaces return the same values, | ||||
| @@ -1290,7 +1350,7 @@ class ElementBase(object): | ||||
|         # must be equal. | ||||
|         return True | ||||
|  | ||||
|     def __ne__(self, other): | ||||
|     def __ne__(self, other: Any) -> bool: | ||||
|         """Compare the stanza object with another to test for inequality. | ||||
|  | ||||
|         Stanzas are not equal if their interfaces return different values, | ||||
| @@ -1300,16 +1360,16 @@ class ElementBase(object): | ||||
|         """ | ||||
|         return not self.__eq__(other) | ||||
|  | ||||
|     def __bool__(self): | ||||
|     def __bool__(self) -> bool: | ||||
|         """Stanza objects should be treated as True in boolean contexts. | ||||
|         """ | ||||
|         return True | ||||
|  | ||||
|     def __len__(self): | ||||
|     def __len__(self) -> int: | ||||
|         """Return the number of iterable substanzas in this stanza.""" | ||||
|         return len(self.iterables) | ||||
|  | ||||
|     def __iter__(self): | ||||
|     def __iter__(self) -> ElementBase: | ||||
|         """Return an iterator object for the stanza's substanzas. | ||||
|  | ||||
|         The iterator is the stanza object itself. Attempting to use two | ||||
| @@ -1318,7 +1378,7 @@ class ElementBase(object): | ||||
|         self._index = 0 | ||||
|         return self | ||||
|  | ||||
|     def __next__(self): | ||||
|     def __next__(self) -> ElementBase: | ||||
|         """Return the next iterable substanza.""" | ||||
|         self._index += 1 | ||||
|         if self._index > len(self.iterables): | ||||
| @@ -1326,13 +1386,16 @@ class ElementBase(object): | ||||
|             raise StopIteration | ||||
|         return self.iterables[self._index - 1] | ||||
|  | ||||
|     def __copy__(self): | ||||
|     def __copy__(self) -> ElementBase: | ||||
|         """Return a copy of the stanza object that does not share the same | ||||
|         underlying XML object. | ||||
|         """ | ||||
|         return self.__class__(xml=copy.deepcopy(self.xml), parent=self.parent) | ||||
|         return self.__class__( | ||||
|             xml=copy.deepcopy(self.xml), | ||||
|             parent=self.parent, | ||||
|         ) | ||||
|  | ||||
|     def __str__(self, top_level_ns=True): | ||||
|     def __str__(self, top_level_ns: bool = True) -> str: | ||||
|         """Return a string serialization of the underlying XML object. | ||||
|  | ||||
|         .. seealso:: :ref:`tostring` | ||||
| @@ -1343,12 +1406,33 @@ class ElementBase(object): | ||||
|         return tostring(self.xml, xmlns='', | ||||
|                         top_level=True) | ||||
|  | ||||
|     def __repr__(self): | ||||
|     def __repr__(self) -> str: | ||||
|         """Use the stanza's serialized XML as its representation.""" | ||||
|         return self.__str__() | ||||
|  | ||||
|     # Compatibility. | ||||
|     _get_plugin = get_plugin | ||||
|     get_stanza_values = _get_stanza_values | ||||
|     set_stanza_values = _set_stanza_values | ||||
|  | ||||
|     #: A JSON/dictionary version of the XML content exposed through | ||||
|     #: the stanza interfaces:: | ||||
|     #: | ||||
|     #:     >>> msg = Message() | ||||
|     #:     >>> msg.values | ||||
|     #:    {'body': '', 'from': , 'mucnick': '', 'mucroom': '', | ||||
|     #:     'to': , 'type': 'normal', 'id': '', 'subject': ''} | ||||
|     #: | ||||
|     #: Likewise, assigning to the :attr:`values` will change the XML | ||||
|     #: content:: | ||||
|     #: | ||||
|     #:     >>> msg = Message() | ||||
|     #:     >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'} | ||||
|     #:     >>> msg | ||||
|     #:     '<message to="user@example.com"><body>Hi!</body></message>' | ||||
|     #: | ||||
|     #: Child stanzas are exposed as nested dictionaries. | ||||
|     values = property(_get_stanza_values, _set_stanza_values)  # type: ignore | ||||
|  | ||||
|  | ||||
| class StanzaBase(ElementBase): | ||||
| @@ -1386,9 +1470,14 @@ class StanzaBase(ElementBase): | ||||
|  | ||||
|     #: The default XMPP client namespace | ||||
|     namespace = 'jabber:client' | ||||
|     types: ClassVar[Set[str]] = set() | ||||
|  | ||||
|     def __init__(self, stream=None, xml=None, stype=None, | ||||
|                  sto=None, sfrom=None, sid=None, parent=None, recv=False): | ||||
|     def __init__(self, stream: Optional[XMLStream] = None, | ||||
|                  xml: Optional[ET.Element] = None, | ||||
|                  stype: Optional[str] = None, | ||||
|                  sto: Optional[JidStr] = None, sfrom: Optional[JidStr] = None, | ||||
|                  sid: Optional[str] = None, | ||||
|                  parent: Optional[ElementBase] = None, recv: bool = False): | ||||
|         self.stream = stream | ||||
|         if stream is not None: | ||||
|             self.namespace = stream.default_ns | ||||
| @@ -1403,7 +1492,7 @@ class StanzaBase(ElementBase): | ||||
|             self['id'] = sid | ||||
|         self.tag = "{%s}%s" % (self.namespace, self.name) | ||||
|  | ||||
|     def set_type(self, value): | ||||
|     def set_type(self, value: str) -> StanzaBase: | ||||
|         """Set the stanza's ``'type'`` attribute. | ||||
|  | ||||
|         Only type values contained in :attr:`types` are accepted. | ||||
| @@ -1414,11 +1503,11 @@ class StanzaBase(ElementBase): | ||||
|             self.xml.attrib['type'] = value | ||||
|         return self | ||||
|  | ||||
|     def get_to(self): | ||||
|     def get_to(self) -> JID: | ||||
|         """Return the value of the stanza's ``'to'`` attribute.""" | ||||
|         return JID(self._get_attr('to')) | ||||
|  | ||||
|     def set_to(self, value): | ||||
|     def set_to(self, value: JidStr) -> None: | ||||
|         """Set the ``'to'`` attribute of the stanza. | ||||
|  | ||||
|         :param value: A string or :class:`slixmpp.xmlstream.JID` object | ||||
| @@ -1426,11 +1515,11 @@ class StanzaBase(ElementBase): | ||||
|         """ | ||||
|         return self._set_attr('to', str(value)) | ||||
|  | ||||
|     def get_from(self): | ||||
|     def get_from(self) -> JID: | ||||
|         """Return the value of the stanza's ``'from'`` attribute.""" | ||||
|         return JID(self._get_attr('from')) | ||||
|  | ||||
|     def set_from(self, value): | ||||
|     def set_from(self, value: JidStr) -> None: | ||||
|         """Set the 'from' attribute of the stanza. | ||||
|  | ||||
|         :param from: A string or JID object representing the sender's JID. | ||||
| @@ -1438,11 +1527,11 @@ class StanzaBase(ElementBase): | ||||
|         """ | ||||
|         return self._set_attr('from', str(value)) | ||||
|  | ||||
|     def get_payload(self): | ||||
|     def get_payload(self) -> List[ET.Element]: | ||||
|         """Return a list of XML objects contained in the stanza.""" | ||||
|         return list(self.xml) | ||||
|  | ||||
|     def set_payload(self, value): | ||||
|     def set_payload(self, value: Union[List[ElementBase], ElementBase]) -> StanzaBase: | ||||
|         """Add XML content to the stanza. | ||||
|  | ||||
|         :param value: Either an XML or a stanza object, or a list | ||||
| @@ -1454,12 +1543,12 @@ class StanzaBase(ElementBase): | ||||
|             self.append(val) | ||||
|         return self | ||||
|  | ||||
|     def del_payload(self): | ||||
|     def del_payload(self) -> StanzaBase: | ||||
|         """Remove the XML contents of the stanza.""" | ||||
|         self.clear() | ||||
|         return self | ||||
|  | ||||
|     def reply(self, clear=True): | ||||
|     def reply(self, clear: bool = True) -> StanzaBase: | ||||
|         """Prepare the stanza for sending a reply. | ||||
|  | ||||
|         Swaps the ``'from'`` and ``'to'`` attributes. | ||||
| @@ -1475,7 +1564,7 @@ class StanzaBase(ElementBase): | ||||
|         new_stanza = copy.copy(self) | ||||
|         # if it's a component, use from | ||||
|         if self.stream and hasattr(self.stream, "is_component") and \ | ||||
|             self.stream.is_component: | ||||
|                 getattr(self.stream, 'is_component'): | ||||
|             new_stanza['from'], new_stanza['to'] = self['to'], self['from'] | ||||
|         else: | ||||
|             new_stanza['to'] = self['from'] | ||||
| @@ -1484,19 +1573,19 @@ class StanzaBase(ElementBase): | ||||
|             new_stanza.clear() | ||||
|         return new_stanza | ||||
|  | ||||
|     def error(self): | ||||
|     def error(self) -> StanzaBase: | ||||
|         """Set the stanza's type to ``'error'``.""" | ||||
|         self['type'] = 'error' | ||||
|         return self | ||||
|  | ||||
|     def unhandled(self): | ||||
|     def unhandled(self) -> None: | ||||
|         """Called if no handlers have been registered to process this stanza. | ||||
|  | ||||
|         Meant to be overridden. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def exception(self, e): | ||||
|     def exception(self, e: Exception) -> None: | ||||
|         """Handle exceptions raised during stanza processing. | ||||
|  | ||||
|         Meant to be overridden. | ||||
| @@ -1504,18 +1593,21 @@ class StanzaBase(ElementBase): | ||||
|         log.exception('Error handling {%s}%s stanza', self.namespace, | ||||
|                                                       self.name) | ||||
|  | ||||
|     def send(self): | ||||
|     def send(self) -> None: | ||||
|         """Queue the stanza to be sent on the XML stream.""" | ||||
|         self.stream.send(self) | ||||
|         if self.stream is not None: | ||||
|             self.stream.send(self) | ||||
|         else: | ||||
|             log.error("Tried to send stanza without a stream: %s", self) | ||||
|  | ||||
|     def __copy__(self): | ||||
|     def __copy__(self) -> StanzaBase: | ||||
|         """Return a copy of the stanza object that does not share the | ||||
|         same underlying XML object, but does share the same XML stream. | ||||
|         """ | ||||
|         return self.__class__(xml=copy.deepcopy(self.xml), | ||||
|                               stream=self.stream) | ||||
|  | ||||
|     def __str__(self, top_level_ns=False): | ||||
|     def __str__(self, top_level_ns: bool = False) -> str: | ||||
|         """Serialize the stanza's XML to a string. | ||||
|  | ||||
|         :param bool top_level_ns: Display the top-most namespace. | ||||
| @@ -1525,27 +1617,3 @@ class StanzaBase(ElementBase): | ||||
|         return tostring(self.xml, xmlns=xmlns, | ||||
|                         stream=self.stream, | ||||
|                         top_level=(self.stream is None)) | ||||
|  | ||||
|  | ||||
| #: A JSON/dictionary version of the XML content exposed through | ||||
| #: the stanza interfaces:: | ||||
| #: | ||||
| #:     >>> msg = Message() | ||||
| #:     >>> msg.values | ||||
| #:    {'body': '', 'from': , 'mucnick': '', 'mucroom': '', | ||||
| #:     'to': , 'type': 'normal', 'id': '', 'subject': ''} | ||||
| #: | ||||
| #: Likewise, assigning to the :attr:`values` will change the XML | ||||
| #: content:: | ||||
| #: | ||||
| #:     >>> msg = Message() | ||||
| #:     >>> msg.values = {'body': 'Hi!', 'to': 'user@example.com'} | ||||
| #:     >>> msg | ||||
| #:     '<message to="user@example.com"><body>Hi!</body></message>' | ||||
| #: | ||||
| #: Child stanzas are exposed as nested dictionaries. | ||||
| ElementBase.values = property(ElementBase._get_stanza_values, | ||||
|                               ElementBase._set_stanza_values) | ||||
|  | ||||
| ElementBase.get_stanza_values = ElementBase._get_stanza_values | ||||
| ElementBase.set_stanza_values = ElementBase._set_stanza_values | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
|  | ||||
| # slixmpp.xmlstream.tostring | ||||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||
| # This module converts XML objects into Unicode strings and | ||||
| @@ -7,11 +6,20 @@ | ||||
| # Part of Slixmpp: The Slick XMPP Library | ||||
| # :copyright: (c) 2011 Nathanael C. Fritz | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import Optional, Set, TYPE_CHECKING | ||||
| from xml.etree.ElementTree import Element | ||||
| if TYPE_CHECKING: | ||||
|     from slixmpp.xmlstream import XMLStream | ||||
|  | ||||
| XML_NS = 'http://www.w3.org/XML/1998/namespace' | ||||
|  | ||||
|  | ||||
| def tostring(xml=None, xmlns='', stream=None, outbuffer='', | ||||
|              top_level=False, open_only=False, namespaces=None): | ||||
| def tostring(xml: Optional[Element] = None, xmlns: str = '', | ||||
|              stream: Optional[XMLStream] = None, outbuffer: str = '', | ||||
|              top_level: bool = False, open_only: bool = False, | ||||
|              namespaces: Optional[Set[str]] = None) -> str: | ||||
|     """Serialize an XML object to a Unicode string. | ||||
|  | ||||
|     If an outer xmlns is provided using ``xmlns``, then the current element's | ||||
| @@ -35,6 +43,8 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='', | ||||
|  | ||||
|     :rtype: Unicode string | ||||
|     """ | ||||
|     if xml is None: | ||||
|         return '' | ||||
|     # Add previous results to the start of the output. | ||||
|     output = [outbuffer] | ||||
|  | ||||
| @@ -123,11 +133,12 @@ def tostring(xml=None, xmlns='', stream=None, outbuffer='', | ||||
|         # Remove namespaces introduced in this context. This is necessary | ||||
|         # because the namespaces object continues to be shared with other | ||||
|         # contexts. | ||||
|         namespaces.remove(ns) | ||||
|         if namespaces is not None: | ||||
|             namespaces.remove(ns) | ||||
|     return ''.join(output) | ||||
|  | ||||
|  | ||||
| def escape(text, use_cdata=False): | ||||
| def escape(text: str, use_cdata: bool = False) -> str: | ||||
|     """Convert special characters in XML to escape sequences. | ||||
|  | ||||
|     :param string text: The XML text to convert. | ||||
|   | ||||
| @@ -9,17 +9,24 @@ | ||||
| # :license: MIT, see LICENSE for more details | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Dict, | ||||
|     Awaitable, | ||||
|     Generator, | ||||
|     Coroutine, | ||||
|     Callable, | ||||
|     Iterable, | ||||
|     Iterator, | ||||
|     List, | ||||
|     Optional, | ||||
|     Set, | ||||
|     Union, | ||||
|     Tuple, | ||||
|     TypeVar, | ||||
|     NoReturn, | ||||
|     Type, | ||||
|     cast, | ||||
| ) | ||||
|  | ||||
| import asyncio | ||||
| import functools | ||||
| import logging | ||||
| import socket as Socket | ||||
| @@ -27,30 +34,66 @@ import ssl | ||||
| import weakref | ||||
| import uuid | ||||
|  | ||||
| from asyncio import iscoroutinefunction, wait, Future | ||||
| from contextlib import contextmanager | ||||
| import xml.etree.ElementTree as ET | ||||
| from asyncio import ( | ||||
|     AbstractEventLoop, | ||||
|     BaseTransport, | ||||
|     Future, | ||||
|     Task, | ||||
|     TimerHandle, | ||||
|     Transport, | ||||
|     iscoroutinefunction, | ||||
|     wait, | ||||
| ) | ||||
|  | ||||
| from slixmpp.xmlstream.asyncio import asyncio | ||||
| from slixmpp.xmlstream import tostring | ||||
| from slixmpp.types import FilterString | ||||
| from slixmpp.xmlstream.tostring import tostring | ||||
| from slixmpp.xmlstream.stanzabase import StanzaBase, ElementBase | ||||
| from slixmpp.xmlstream.resolver import resolve, default_resolver | ||||
| from slixmpp.xmlstream.handler.base import BaseHandler | ||||
|  | ||||
| T = TypeVar('T') | ||||
|  | ||||
| #: The time in seconds to wait before timing out waiting for response stanzas. | ||||
| RESPONSE_TIMEOUT = 30 | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| class ContinueQueue(Exception): | ||||
|     """ | ||||
|     Exception raised in the send queue to "continue" from within an inner loop | ||||
|     """ | ||||
|  | ||||
|  | ||||
| class NotConnectedError(Exception): | ||||
|     """ | ||||
|     Raised when we try to send something over the wire but we are not | ||||
|     connected. | ||||
|     """ | ||||
|  | ||||
|  | ||||
| _T = TypeVar('_T', str, ElementBase, StanzaBase) | ||||
|  | ||||
|  | ||||
| SyncFilter = Callable[[StanzaBase], Optional[StanzaBase]] | ||||
| AsyncFilter = Callable[[StanzaBase], Awaitable[Optional[StanzaBase]]] | ||||
|  | ||||
|  | ||||
| Filter = Union[ | ||||
|     SyncFilter, | ||||
|     AsyncFilter, | ||||
| ] | ||||
|  | ||||
| _FiltersDict = Dict[str, List[Filter]] | ||||
|  | ||||
| Handler = Callable[[Any], Union[ | ||||
|     Any, | ||||
|     Coroutine[Any, Any, Any] | ||||
| ]] | ||||
|  | ||||
|  | ||||
| class XMLStream(asyncio.BaseProtocol): | ||||
|     """ | ||||
|     An XML stream connection manager and event dispatcher. | ||||
| @@ -78,16 +121,156 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|     :param int port: The port to use for the connection. Defaults to 0. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, host='', port=0): | ||||
|         # The asyncio.Transport object provided by the connection_made() | ||||
|         # callback when we are connected | ||||
|     transport: Optional[Transport] | ||||
|  | ||||
|     # The socket that is used internally by the transport object | ||||
|     socket: Optional[ssl.SSLSocket] | ||||
|  | ||||
|     # The backoff of the connect routine (increases exponentially | ||||
|     # after each failure) | ||||
|     _connect_loop_wait: float | ||||
|  | ||||
|     parser: Optional[ET.XMLPullParser] | ||||
|     xml_depth: int | ||||
|     xml_root: Optional[ET.Element] | ||||
|  | ||||
|     force_starttls: Optional[bool] | ||||
|     disable_starttls: Optional[bool] | ||||
|  | ||||
|     waiting_queue: asyncio.Queue | ||||
|  | ||||
|     # A dict of {name: handle} | ||||
|     scheduled_events: Dict[str, TimerHandle] | ||||
|  | ||||
|     ssl_context: ssl.SSLContext | ||||
|  | ||||
|     # The event to trigger when the create_connection() succeeds. It can | ||||
|     # be "connected" or "tls_success" depending on the step we are at. | ||||
|     event_when_connected: str | ||||
|  | ||||
|     #: The list of accepted ciphers, in OpenSSL Format. | ||||
|     #: It might be useful to override it for improved security | ||||
|     #: over the python defaults. | ||||
|     ciphers: Optional[str] | ||||
|  | ||||
|     #: Path to a file containing certificates for verifying the | ||||
|     #: server SSL certificate. A non-``None`` value will trigger | ||||
|     #: certificate checking. | ||||
|     #: | ||||
|     #: .. note:: | ||||
|     #: | ||||
|     #:     On Mac OS X, certificates in the system keyring will | ||||
|     #:     be consulted, even if they are not in the provided file. | ||||
|     ca_certs: Optional[str] | ||||
|  | ||||
|     #: Path to a file containing a client certificate to use for | ||||
|     #: authenticating via SASL EXTERNAL. If set, there must also | ||||
|     #: be a corresponding `:attr:keyfile` value. | ||||
|     certfile: Optional[str] | ||||
|  | ||||
|     #: Path to a file containing the private key for the selected | ||||
|     #: client certificate to use for authenticating via SASL EXTERNAL. | ||||
|     keyfile: Optional[str] | ||||
|  | ||||
|     # The asyncio event loop | ||||
|     _loop: Optional[AbstractEventLoop] | ||||
|  | ||||
|     #: The default port to return when querying DNS records. | ||||
|     default_port: int | ||||
|  | ||||
|     #: The domain to try when querying DNS records. | ||||
|     default_domain: str | ||||
|  | ||||
|     #: The expected name of the server, for validation. | ||||
|     _expected_server_name: str | ||||
|     _service_name: str | ||||
|  | ||||
|     #: The desired, or actual, address of the connected server. | ||||
|     address: Tuple[str, int] | ||||
|  | ||||
|     #: Enable connecting to the server directly over SSL, in | ||||
|     #: particular when the service provides two ports: one for | ||||
|     #: non-SSL traffic and another for SSL traffic. | ||||
|     use_ssl: bool | ||||
|  | ||||
|     #: If set to ``True``, attempt to use IPv6. | ||||
|     use_ipv6: bool | ||||
|  | ||||
|     #: If set to ``True``, allow using the ``dnspython`` DNS library | ||||
|     #: if available. If set to ``False``, the builtin DNS resolver | ||||
|     #: will be used, even if ``dnspython`` is installed. | ||||
|     use_aiodns: bool | ||||
|  | ||||
|     #: Use CDATA for escaping instead of XML entities. Defaults | ||||
|     #: to ``False``. | ||||
|     use_cdata: bool | ||||
|  | ||||
|     #: The default namespace of the stream content, not of the | ||||
|     #: stream wrapper it | ||||
|     default_ns: str | ||||
|  | ||||
|     default_lang: Optional[str] | ||||
|     peer_default_lang: Optional[str] | ||||
|  | ||||
|     #: The namespace of the enveloping stream element. | ||||
|     stream_ns: str | ||||
|  | ||||
|     #: The default opening tag for the stream element. | ||||
|     stream_header: str | ||||
|  | ||||
|     #: The default closing tag for the stream element. | ||||
|     stream_footer: str | ||||
|  | ||||
|     #: If ``True``, periodically send a whitespace character over the | ||||
|     #: wire to keep the connection alive. Mainly useful for connections | ||||
|     #: traversing NAT. | ||||
|     whitespace_keepalive: bool | ||||
|  | ||||
|     #: The default interval between keepalive signals when | ||||
|     #: :attr:`whitespace_keepalive` is enabled. | ||||
|     whitespace_keepalive_interval: int | ||||
|  | ||||
|     #: Flag for controlling if the session can be considered ended | ||||
|     #: if the connection is terminated. | ||||
|     end_session_on_disconnect: bool | ||||
|  | ||||
|     #: A mapping of XML namespaces to well-known prefixes. | ||||
|     namespace_map: dict | ||||
|  | ||||
|     __root_stanza: List[Type[StanzaBase]] | ||||
|     __handlers: List[BaseHandler] | ||||
|     __event_handlers: Dict[str, List[Tuple[Handler, bool]]] | ||||
|     __filters: _FiltersDict | ||||
|  | ||||
|     # Current connection attempt (Future) | ||||
|     _current_connection_attempt: Optional[Future] | ||||
|  | ||||
|     #: A list of DNS results that have not yet been tried. | ||||
|     _dns_answers: Optional[Iterator[Tuple[str, str, int]]] | ||||
|  | ||||
|     #: The service name to check with DNS SRV records. For | ||||
|     #: example, setting this to ``'xmpp-client'`` would query the | ||||
|     #: ``_xmpp-client._tcp`` service. | ||||
|     dns_service: Optional[str] | ||||
|  | ||||
|     #: The reason why we are disconnecting from the server | ||||
|     disconnect_reason: Optional[str] | ||||
|  | ||||
|     #: An asyncio Future being done when the stream is disconnected. | ||||
|     disconnected: Future | ||||
|  | ||||
|     # If the session has been started or not | ||||
|     _session_started: bool | ||||
|     # If we want to bypass the send() check (e.g. unit tests) | ||||
|     _always_send_everything: bool | ||||
|  | ||||
|     _run_out_filters: Optional[Future] | ||||
|     __slow_tasks: List[Task] | ||||
|     __queued_stanzas: List[Tuple[Union[StanzaBase, str], bool]] | ||||
|  | ||||
|     def __init__(self, host: str = '', port: int = 0): | ||||
|         self.transport = None | ||||
|  | ||||
|         # The socket that is used internally by the transport object | ||||
|         self.socket = None | ||||
|  | ||||
|         # The backoff of the connect routine (increases exponentially | ||||
|         # after each failure) | ||||
|         self._connect_loop_wait = 0 | ||||
|  | ||||
|         self.parser = None | ||||
| @@ -106,126 +289,60 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         self.ssl_context.check_hostname = False | ||||
|         self.ssl_context.verify_mode = ssl.CERT_NONE | ||||
|  | ||||
|         # The event to trigger when the create_connection() succeeds. It can | ||||
|         # be "connected" or "tls_success" depending on the step we are at. | ||||
|         self.event_when_connected = "connected" | ||||
|  | ||||
|         #: The list of accepted ciphers, in OpenSSL Format. | ||||
|         #: It might be useful to override it for improved security | ||||
|         #: over the python defaults. | ||||
|         self.ciphers = None | ||||
|  | ||||
|         #: Path to a file containing certificates for verifying the | ||||
|         #: server SSL certificate. A non-``None`` value will trigger | ||||
|         #: certificate checking. | ||||
|         #: | ||||
|         #: .. note:: | ||||
|         #: | ||||
|         #:     On Mac OS X, certificates in the system keyring will | ||||
|         #:     be consulted, even if they are not in the provided file. | ||||
|         self.ca_certs = None | ||||
|  | ||||
|         #: Path to a file containing a client certificate to use for | ||||
|         #: authenticating via SASL EXTERNAL. If set, there must also | ||||
|         #: be a corresponding `:attr:keyfile` value. | ||||
|         self.certfile = None | ||||
|  | ||||
|         #: Path to a file containing the private key for the selected | ||||
|         #: client certificate to use for authenticating via SASL EXTERNAL. | ||||
|         self.keyfile = None | ||||
|  | ||||
|         self._der_cert = None | ||||
|  | ||||
|         # The asyncio event loop | ||||
|         self._loop = None | ||||
|  | ||||
|         #: The default port to return when querying DNS records. | ||||
|         self.default_port = int(port) | ||||
|  | ||||
|         #: The domain to try when querying DNS records. | ||||
|         self.default_domain = '' | ||||
|  | ||||
|         #: The expected name of the server, for validation. | ||||
|         self._expected_server_name = '' | ||||
|         self._service_name = '' | ||||
|  | ||||
|         #: The desired, or actual, address of the connected server. | ||||
|         self.address = (host, int(port)) | ||||
|  | ||||
|         #: Enable connecting to the server directly over SSL, in | ||||
|         #: particular when the service provides two ports: one for | ||||
|         #: non-SSL traffic and another for SSL traffic. | ||||
|         self.use_ssl = False | ||||
|  | ||||
|         #: If set to ``True``, attempt to use IPv6. | ||||
|         self.use_ipv6 = True | ||||
|  | ||||
|         #: If set to ``True``, allow using the ``dnspython`` DNS library | ||||
|         #: if available. If set to ``False``, the builtin DNS resolver | ||||
|         #: will be used, even if ``dnspython`` is installed. | ||||
|         self.use_aiodns = True | ||||
|  | ||||
|         #: Use CDATA for escaping instead of XML entities. Defaults | ||||
|         #: to ``False``. | ||||
|         self.use_cdata = False | ||||
|  | ||||
|         #: The default namespace of the stream content, not of the | ||||
|         #: stream wrapper itself. | ||||
|         self.default_ns = '' | ||||
|  | ||||
|         self.default_lang = None | ||||
|         self.peer_default_lang = None | ||||
|  | ||||
|         #: The namespace of the enveloping stream element. | ||||
|         self.stream_ns = '' | ||||
|  | ||||
|         #: The default opening tag for the stream element. | ||||
|         self.stream_header = "<stream>" | ||||
|  | ||||
|         #: The default closing tag for the stream element. | ||||
|         self.stream_footer = "</stream>" | ||||
|  | ||||
|         #: If ``True``, periodically send a whitespace character over the | ||||
|         #: wire to keep the connection alive. Mainly useful for connections | ||||
|         #: traversing NAT. | ||||
|         self.whitespace_keepalive = True | ||||
|  | ||||
|         #: The default interval between keepalive signals when | ||||
|         #: :attr:`whitespace_keepalive` is enabled. | ||||
|         self.whitespace_keepalive_interval = 300 | ||||
|  | ||||
|         #: Flag for controlling if the session can be considered ended | ||||
|         #: if the connection is terminated. | ||||
|         self.end_session_on_disconnect = True | ||||
|  | ||||
|         #: A mapping of XML namespaces to well-known prefixes. | ||||
|         self.namespace_map = {StanzaBase.xml_ns: 'xml'} | ||||
|  | ||||
|         self.__root_stanza = [] | ||||
|         self.__handlers = [] | ||||
|         self.__event_handlers = {} | ||||
|         self.__filters = {'in': [], 'out': [], 'out_sync': []} | ||||
|         self.__filters = { | ||||
|             'in': [], 'out': [], 'out_sync': [] | ||||
|         } | ||||
|  | ||||
|         # Current connection attempt (Future) | ||||
|         self._current_connection_attempt = None | ||||
|  | ||||
|         #: A list of DNS results that have not yet been tried. | ||||
|         self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None | ||||
|  | ||||
|         #: The service name to check with DNS SRV records. For | ||||
|         #: example, setting this to ``'xmpp-client'`` would query the | ||||
|         #: ``_xmpp-client._tcp`` service. | ||||
|         self._dns_answers = None | ||||
|         self.dns_service = None | ||||
|  | ||||
|         #: The reason why we are disconnecting from the server | ||||
|         self.disconnect_reason = None | ||||
|  | ||||
|         #: An asyncio Future being done when the stream is disconnected. | ||||
|         self.disconnected: Future = Future() | ||||
|  | ||||
|         # If the session has been started or not | ||||
|         self.disconnected = Future() | ||||
|         self._session_started = False | ||||
|         # If we want to bypass the send() check (e.g. unit tests) | ||||
|         self._always_send_everything = False | ||||
|  | ||||
|         self.add_event_handler('disconnected', self._remove_schedules) | ||||
| @@ -234,21 +351,21 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         self.add_event_handler('session_start', self._set_session_start) | ||||
|         self.add_event_handler('session_resumed', self._set_session_start) | ||||
|  | ||||
|         self._run_out_filters: Optional[Future] = None | ||||
|         self.__slow_tasks: List[Future] = [] | ||||
|         self.__queued_stanzas: List[Tuple[StanzaBase, bool]] = [] | ||||
|         self._run_out_filters = None | ||||
|         self.__slow_tasks = [] | ||||
|         self.__queued_stanzas = [] | ||||
|  | ||||
|     @property | ||||
|     def loop(self): | ||||
|     def loop(self) -> AbstractEventLoop: | ||||
|         if self._loop is None: | ||||
|             self._loop = asyncio.get_event_loop() | ||||
|         return self._loop | ||||
|  | ||||
|     @loop.setter | ||||
|     def loop(self, value): | ||||
|     def loop(self, value: AbstractEventLoop) -> None: | ||||
|         self._loop = value | ||||
|  | ||||
|     def new_id(self): | ||||
|     def new_id(self) -> str: | ||||
|         """Generate and return a new stream ID in hexadecimal form. | ||||
|  | ||||
|         Many stanzas, handlers, or matchers may require unique | ||||
| @@ -257,7 +374,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         """ | ||||
|         return uuid.uuid4().hex | ||||
|  | ||||
|     def _set_session_start(self, event): | ||||
|     def _set_session_start(self, event: Any) -> None: | ||||
|         """ | ||||
|         On session start, queue all pending stanzas to be sent. | ||||
|         """ | ||||
| @@ -266,17 +383,17 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             self.waiting_queue.put_nowait(stanza) | ||||
|         self.__queued_stanzas = [] | ||||
|  | ||||
|     def _set_disconnected(self, event): | ||||
|     def _set_disconnected(self, event: Any) -> None: | ||||
|         self._session_started = False | ||||
|  | ||||
|     def _set_disconnected_future(self): | ||||
|     def _set_disconnected_future(self) -> None: | ||||
|         """Set the self.disconnected future on disconnect""" | ||||
|         if not self.disconnected.done(): | ||||
|             self.disconnected.set_result(True) | ||||
|         self.disconnected = asyncio.Future() | ||||
|  | ||||
|     def connect(self, host='', port=0, use_ssl=False, | ||||
|                 force_starttls=True, disable_starttls=False): | ||||
|     def connect(self, host: str = '', port: int = 0, use_ssl: Optional[bool] = False, | ||||
|                 force_starttls: Optional[bool] = True, disable_starttls: Optional[bool] = False) -> None: | ||||
|         """Create a new socket and connect to the server. | ||||
|  | ||||
|         :param host: The name of the desired server for the connection. | ||||
| @@ -327,7 +444,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             loop=self.loop, | ||||
|         ) | ||||
|  | ||||
|     async def _connect_routine(self): | ||||
|     async def _connect_routine(self) -> None: | ||||
|         self.event_when_connected = "connected" | ||||
|  | ||||
|         if self._connect_loop_wait > 0: | ||||
| @@ -345,6 +462,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             # and try (host, port) as a last resort | ||||
|             self._dns_answers = None | ||||
|  | ||||
|         ssl_context: Optional[ssl.SSLContext] | ||||
|         if self.use_ssl: | ||||
|             ssl_context = self.get_ssl_context() | ||||
|         else: | ||||
| @@ -373,7 +491,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                 loop=self.loop, | ||||
|             ) | ||||
|  | ||||
|     def process(self, *, forever=True, timeout=None): | ||||
|     def process(self, *, forever: bool = True, timeout: Optional[int] = None) -> None: | ||||
|         """Process all the available XMPP events (receiving or sending data on the | ||||
|         socket(s), calling various registered callbacks, calling expired | ||||
|         timers, handling signal events, etc).  If timeout is None, this | ||||
| @@ -386,12 +504,12 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             else: | ||||
|                 self.loop.run_until_complete(self.disconnected) | ||||
|         else: | ||||
|             tasks = [asyncio.sleep(timeout, loop=self.loop)] | ||||
|             tasks: List[Future] = [asyncio.sleep(timeout, loop=self.loop)] | ||||
|             if not forever: | ||||
|                 tasks.append(self.disconnected) | ||||
|             self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) | ||||
|  | ||||
|     def init_parser(self): | ||||
|     def init_parser(self) -> None: | ||||
|         """init the XML parser. The parser must always be reset for each new | ||||
|         connexion | ||||
|         """ | ||||
| @@ -399,11 +517,13 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         self.xml_root = None | ||||
|         self.parser = ET.XMLPullParser(("start", "end")) | ||||
|  | ||||
|     def connection_made(self, transport): | ||||
|     def connection_made(self, transport: BaseTransport) -> None: | ||||
|         """Called when the TCP connection has been established with the server | ||||
|         """ | ||||
|         self.event(self.event_when_connected) | ||||
|         self.transport = transport | ||||
|         self.transport = cast(Transport, transport) | ||||
|         if self.transport is None: | ||||
|             raise ValueError("Transport cannot be none") | ||||
|         self.socket = self.transport.get_extra_info( | ||||
|             "ssl_object", | ||||
|             default=self.transport.get_extra_info("socket") | ||||
| @@ -413,7 +533,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         self.send_raw(self.stream_header) | ||||
|         self._dns_answers = None | ||||
|  | ||||
|     def data_received(self, data): | ||||
|     def data_received(self, data: bytes) -> None: | ||||
|         """Called when incoming data is received on the socket. | ||||
|  | ||||
|         We feed that data to the parser and the see if this produced any XML | ||||
| @@ -467,18 +587,18 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             self.send(error) | ||||
|             self.disconnect() | ||||
|  | ||||
|     def is_connecting(self): | ||||
|     def is_connecting(self) -> bool: | ||||
|         return self._current_connection_attempt is not None | ||||
|  | ||||
|     def is_connected(self): | ||||
|     def is_connected(self) -> bool: | ||||
|         return self.transport is not None | ||||
|  | ||||
|     def eof_received(self): | ||||
|     def eof_received(self) -> None: | ||||
|         """When the TCP connection is properly closed by the remote end | ||||
|         """ | ||||
|         self.event("eof_received") | ||||
|  | ||||
|     def connection_lost(self, exception): | ||||
|     def connection_lost(self, exception: Optional[BaseException]) -> None: | ||||
|         """On any kind of disconnection, initiated by us or not.  This signals the | ||||
|         closure of the TCP connection | ||||
|         """ | ||||
| @@ -493,9 +613,9 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             self._reset_sendq() | ||||
|             self.event('session_end') | ||||
|         self._set_disconnected_future() | ||||
|         self.event("disconnected", self.disconnect_reason or exception and exception.strerror) | ||||
|         self.event("disconnected", self.disconnect_reason or exception) | ||||
|  | ||||
|     def cancel_connection_attempt(self): | ||||
|     def cancel_connection_attempt(self) -> None: | ||||
|         """ | ||||
|         Immediately cancel the current create_connection() Future. | ||||
|         This is useful when a client using slixmpp tries to connect | ||||
| @@ -526,7 +646,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         # `disconnect(wait=True)` for ages. This doesn't mean anything to the | ||||
|         # schedule call below. It would fortunately be converted to `1` later | ||||
|         # down the call chain. Praise the implicit casts lord. | ||||
|         if wait == True: | ||||
|         if wait is True: | ||||
|             wait = 2.0 | ||||
|  | ||||
|         if self.transport: | ||||
| @@ -545,11 +665,11 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         else: | ||||
|             self._set_disconnected_future() | ||||
|             self.event("disconnected", reason) | ||||
|             future = Future() | ||||
|             future: Future = Future() | ||||
|             future.set_result(None) | ||||
|             return future | ||||
|  | ||||
|     async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float): | ||||
|     async def _consume_send_queue_before_disconnecting(self, reason: Optional[str], wait: float) -> None: | ||||
|         """Wait until the send queue is empty before disconnecting""" | ||||
|         try: | ||||
|             await asyncio.wait_for( | ||||
| @@ -561,7 +681,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         self.disconnect_reason = reason | ||||
|         await self._end_stream_wait(wait) | ||||
|  | ||||
|     async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None): | ||||
|     async def _end_stream_wait(self, wait: Union[int, float] = 2, reason: Optional[str] = None) -> None: | ||||
|         """ | ||||
|         Run abort() if we do not received the disconnected event | ||||
|         after a waiting time. | ||||
| @@ -578,7 +698,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             # that means the disconnect has already been handled | ||||
|             pass | ||||
|  | ||||
|     def abort(self): | ||||
|     def abort(self) -> None: | ||||
|         """ | ||||
|         Forcibly close the connection | ||||
|         """ | ||||
| @@ -588,26 +708,26 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             self.transport.abort() | ||||
|             self.event("killed") | ||||
|  | ||||
|     def reconnect(self, wait=2.0, reason="Reconnecting"): | ||||
|     def reconnect(self, wait: Union[int, float] = 2.0, reason: str = "Reconnecting") -> None: | ||||
|         """Calls disconnect(), and once we are disconnected (after the timeout, or | ||||
|         when the server acknowledgement is received), call connect() | ||||
|         """ | ||||
|         log.debug("reconnecting...") | ||||
|         async def handler(event): | ||||
|         async def handler(event: Any) -> None: | ||||
|             # We yield here to allow synchronous handlers to work first | ||||
|             await asyncio.sleep(0, loop=self.loop) | ||||
|             self.connect() | ||||
|         self.add_event_handler('disconnected', handler, disposable=True) | ||||
|         self.disconnect(wait, reason) | ||||
|  | ||||
|     def configure_socket(self): | ||||
|     def configure_socket(self) -> None: | ||||
|         """Set timeout and other options for self.socket. | ||||
|  | ||||
|         Meant to be overridden. | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def configure_dns(self, resolver, domain=None, port=None): | ||||
|     def configure_dns(self, resolver: Any, domain: Optional[str] = None, port: Optional[int] = None) -> None: | ||||
|         """ | ||||
|         Configure and set options for a :class:`~dns.resolver.Resolver` | ||||
|         instance, and other DNS related tasks. For example, you | ||||
| @@ -624,7 +744,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def get_ssl_context(self): | ||||
|     def get_ssl_context(self) -> ssl.SSLContext: | ||||
|         """ | ||||
|         Get SSL context. | ||||
|         """ | ||||
| @@ -644,12 +764,14 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|  | ||||
|         return self.ssl_context | ||||
|  | ||||
|     async def start_tls(self): | ||||
|     async def start_tls(self) -> bool: | ||||
|         """Perform handshakes for TLS. | ||||
|  | ||||
|         If the handshake is successful, the XML stream will need | ||||
|         to be restarted. | ||||
|         """ | ||||
|         if self.transport is None: | ||||
|             raise ValueError("Transport should not be None") | ||||
|         self.event_when_connected = "tls_success" | ||||
|         ssl_context = self.get_ssl_context() | ||||
|         try: | ||||
| @@ -685,7 +807,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             self.connection_made(transp) | ||||
|         return True | ||||
|  | ||||
|     def _start_keepalive(self, event): | ||||
|     def _start_keepalive(self, event: Any) -> None: | ||||
|         """Begin sending whitespace periodically to keep the connection alive. | ||||
|  | ||||
|         May be disabled by setting:: | ||||
| @@ -702,11 +824,11 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                       args=(' ',), | ||||
|                       repeat=True) | ||||
|  | ||||
|     def _remove_schedules(self, event): | ||||
|     def _remove_schedules(self, event: Any) -> None: | ||||
|         """Remove some schedules that become pointless when disconnected""" | ||||
|         self.cancel_schedule('Whitespace Keepalive') | ||||
|  | ||||
|     def start_stream_handler(self, xml): | ||||
|     def start_stream_handler(self, xml: ET.Element) -> None: | ||||
|         """Perform any initialization actions, such as handshakes, | ||||
|         once the stream header has been sent. | ||||
|  | ||||
| @@ -714,7 +836,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     def register_stanza(self, stanza_class): | ||||
|     def register_stanza(self, stanza_class: Type[StanzaBase]) -> None: | ||||
|         """Add a stanza object class as a known root stanza. | ||||
|  | ||||
|         A root stanza is one that appears as a direct child of the stream's | ||||
| @@ -732,7 +854,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         """ | ||||
|         self.__root_stanza.append(stanza_class) | ||||
|  | ||||
|     def remove_stanza(self, stanza_class): | ||||
|     def remove_stanza(self, stanza_class: Type[StanzaBase]) -> None: | ||||
|         """Remove a stanza from being a known root stanza. | ||||
|  | ||||
|         A root stanza is one that appears as a direct child of the stream's | ||||
| @@ -744,7 +866,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         """ | ||||
|         self.__root_stanza.remove(stanza_class) | ||||
|  | ||||
|     def add_filter(self, mode, handler, order=None): | ||||
|     def add_filter(self, mode: FilterString, handler: Callable[[StanzaBase], Optional[StanzaBase]], order: Optional[int] = None) -> None: | ||||
|         """Add a filter for incoming or outgoing stanzas. | ||||
|  | ||||
|         These filters are applied before incoming stanzas are | ||||
| @@ -766,11 +888,11 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         else: | ||||
|             self.__filters[mode].append(handler) | ||||
|  | ||||
|     def del_filter(self, mode, handler): | ||||
|     def del_filter(self, mode: str, handler: Callable[[StanzaBase], Optional[StanzaBase]]) -> None: | ||||
|         """Remove an incoming or outgoing filter.""" | ||||
|         self.__filters[mode].remove(handler) | ||||
|  | ||||
|     def register_handler(self, handler, before=None, after=None): | ||||
|     def register_handler(self, handler: BaseHandler, before: Optional[BaseHandler] = None, after: Optional[BaseHandler] = None) -> None: | ||||
|         """Add a stream event handler that will be executed when a matching | ||||
|         stanza is received. | ||||
|  | ||||
| @@ -782,7 +904,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             self.__handlers.append(handler) | ||||
|             handler.stream = weakref.ref(self) | ||||
|  | ||||
|     def remove_handler(self, name): | ||||
|     def remove_handler(self, name: str) -> bool: | ||||
|         """Remove any stream event handlers with the given name. | ||||
|  | ||||
|         :param name: The name of the handler. | ||||
| @@ -831,9 +953,9 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         try: | ||||
|             return next(self._dns_answers) | ||||
|         except StopIteration: | ||||
|             return | ||||
|             return None | ||||
|  | ||||
|     def add_event_handler(self, name, pointer, disposable=False): | ||||
|     def add_event_handler(self, name: str, pointer: Callable[..., Any], disposable: bool = False) -> None: | ||||
|         """Add a custom event handler that will be executed whenever | ||||
|         its event is manually triggered. | ||||
|  | ||||
| @@ -847,7 +969,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             self.__event_handlers[name] = [] | ||||
|         self.__event_handlers[name].append((pointer, disposable)) | ||||
|  | ||||
|     def del_event_handler(self, name, pointer): | ||||
|     def del_event_handler(self, name: str, pointer: Callable[..., Any]) -> None: | ||||
|         """Remove a function as a handler for an event. | ||||
|  | ||||
|         :param name: The name of the event. | ||||
| @@ -858,21 +980,21 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|  | ||||
|         # Need to keep handlers that do not use | ||||
|         # the given function pointer | ||||
|         def filter_pointers(handler): | ||||
|         def filter_pointers(handler: Tuple[Callable[..., Any], bool]) -> bool: | ||||
|             return handler[0] != pointer | ||||
|  | ||||
|         self.__event_handlers[name] = list(filter( | ||||
|             filter_pointers, | ||||
|             self.__event_handlers[name])) | ||||
|  | ||||
|     def event_handled(self, name): | ||||
|     def event_handled(self, name: str) -> int: | ||||
|         """Returns the number of registered handlers for an event. | ||||
|  | ||||
|         :param name: The name of the event to check. | ||||
|         """ | ||||
|         return len(self.__event_handlers.get(name, [])) | ||||
|  | ||||
|     async def event_async(self, name: str, data: Any = {}): | ||||
|     async def event_async(self, name: str, data: Any = {}) -> None: | ||||
|         """Manually trigger a custom event, but await coroutines immediately. | ||||
|  | ||||
|         This event generator should only be called in situations when | ||||
| @@ -908,7 +1030,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                 except Exception as e: | ||||
|                     self.exception(e) | ||||
|  | ||||
|     def event(self, name: str, data: Any = {}): | ||||
|     def event(self, name: str, data: Any = {}) -> None: | ||||
|         """Manually trigger a custom event. | ||||
|         Coroutine handlers are wrapped into a future and sent into the | ||||
|         event loop for their execution, and not awaited. | ||||
| @@ -928,7 +1050,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             # If the callback is a coroutine, schedule it instead of | ||||
|             # running it directly | ||||
|             if iscoroutinefunction(handler_callback): | ||||
|                 async def handler_callback_routine(cb): | ||||
|                 async def handler_callback_routine(cb: Callable[[ElementBase], Any]) -> None: | ||||
|                     try: | ||||
|                         await cb(data) | ||||
|                     except Exception as e: | ||||
| @@ -957,8 +1079,9 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                 except ValueError: | ||||
|                     pass | ||||
|  | ||||
|     def schedule(self, name, seconds, callback, args=tuple(), | ||||
|                  kwargs={}, repeat=False): | ||||
|     def schedule(self, name: str, seconds: int, callback: Callable[..., None], | ||||
|             args: Tuple[Any, ...] = tuple(), | ||||
|             kwargs: Dict[Any, Any] = {}, repeat: bool = False) -> None: | ||||
|         """Schedule a callback function to execute after a given delay. | ||||
|  | ||||
|         :param name: A unique name for the scheduled callback. | ||||
| @@ -986,21 +1109,21 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         # canceling scheduled_events[name] | ||||
|         self.scheduled_events[name] = handle | ||||
|  | ||||
|     def cancel_schedule(self, name): | ||||
|     def cancel_schedule(self, name: str) -> None: | ||||
|         try: | ||||
|             handle = self.scheduled_events.pop(name) | ||||
|             handle.cancel() | ||||
|         except KeyError: | ||||
|             log.debug("Tried to cancel unscheduled event: %s" % (name,)) | ||||
|  | ||||
|     def _safe_cb_run(self, name, cb): | ||||
|     def _safe_cb_run(self, name: str, cb: Callable[[], None]) -> None: | ||||
|         log.debug('Scheduled event: %s', name) | ||||
|         try: | ||||
|             cb() | ||||
|         except Exception as e: | ||||
|             self.exception(e) | ||||
|  | ||||
|     def _execute_and_reschedule(self, name, cb, seconds): | ||||
|     def _execute_and_reschedule(self, name: str, cb: Callable[[], None], seconds: int) -> None: | ||||
|         """Simple method that calls the given callback, and then schedule itself to | ||||
|         be called after the given number of seconds. | ||||
|         """ | ||||
| @@ -1009,7 +1132,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                                       name, cb, seconds) | ||||
|         self.scheduled_events[name] = handle | ||||
|  | ||||
|     def _execute_and_unschedule(self, name, cb): | ||||
|     def _execute_and_unschedule(self, name: str, cb: Callable[[], None]) -> None: | ||||
|         """ | ||||
|         Execute the callback and remove the handler for it. | ||||
|         """ | ||||
| @@ -1018,7 +1141,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         if name in self.scheduled_events: | ||||
|             del self.scheduled_events[name] | ||||
|  | ||||
|     def incoming_filter(self, xml): | ||||
|     def incoming_filter(self, xml: ET.Element) -> ET.Element: | ||||
|         """Filter incoming XML objects before they are processed. | ||||
|  | ||||
|         Possible uses include remapping namespaces, or correcting elements | ||||
| @@ -1028,7 +1151,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         """ | ||||
|         return xml | ||||
|  | ||||
|     def _reset_sendq(self): | ||||
|     def _reset_sendq(self) -> None: | ||||
|         """Clear sending tasks on session end""" | ||||
|         # Cancel all pending slow send tasks | ||||
|         log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks)) | ||||
| @@ -1043,7 +1166,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|     async def _continue_slow_send( | ||||
|             self, | ||||
|             task: asyncio.Task, | ||||
|             already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]] | ||||
|             already_used: Set[Filter] | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Used when an item in the send queue has taken too long to process. | ||||
| @@ -1060,14 +1183,16 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             if filter in already_used: | ||||
|                 continue | ||||
|             if iscoroutinefunction(filter): | ||||
|                 data = await filter(data) | ||||
|                 data = await filter(data)  # type: ignore | ||||
|             else: | ||||
|                 filter = cast(SyncFilter, filter) | ||||
|                 data = filter(data) | ||||
|             if data is None: | ||||
|                 return | ||||
|  | ||||
|         if isinstance(data, ElementBase): | ||||
|         if isinstance(data, StanzaBase): | ||||
|             for filter in self.__filters['out_sync']: | ||||
|                 filter = cast(SyncFilter, filter) | ||||
|                 data = filter(data) | ||||
|                 if data is None: | ||||
|                     return | ||||
| @@ -1077,19 +1202,21 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         else: | ||||
|             self.send_raw(data) | ||||
|  | ||||
|     async def run_filters(self): | ||||
|     async def run_filters(self) -> NoReturn: | ||||
|         """ | ||||
|         Background loop that processes stanzas to send. | ||||
|         """ | ||||
|         while True: | ||||
|             data: Optional[Union[StanzaBase, str]] | ||||
|             (data, use_filters) = await self.waiting_queue.get() | ||||
|             try: | ||||
|                 if isinstance(data, ElementBase): | ||||
|                 if isinstance(data, StanzaBase): | ||||
|                     if use_filters: | ||||
|                         already_run_filters = set() | ||||
|                         for filter in self.__filters['out']: | ||||
|                             already_run_filters.add(filter) | ||||
|                             if iscoroutinefunction(filter): | ||||
|                                 filter = cast(AsyncFilter, filter) | ||||
|                                 task = asyncio.create_task(filter(data)) | ||||
|                                 completed, pending = await wait( | ||||
|                                     {task}, | ||||
| @@ -1108,21 +1235,26 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                                         "Slow coroutine, rescheduling filters" | ||||
|                                     ) | ||||
|                                 data = task.result() | ||||
|                             else: | ||||
|                             elif isinstance(data, StanzaBase): | ||||
|                                 filter = cast(SyncFilter, filter) | ||||
|                                 data = filter(data) | ||||
|                             if data is None: | ||||
|                                 raise ContinueQueue('Empty stanza') | ||||
|  | ||||
|                 if isinstance(data, ElementBase): | ||||
|                 if isinstance(data, StanzaBase): | ||||
|                     if use_filters: | ||||
|                         for filter in self.__filters['out_sync']: | ||||
|                             filter = cast(SyncFilter, filter) | ||||
|                             data = filter(data) | ||||
|                             if data is None: | ||||
|                                 raise ContinueQueue('Empty stanza') | ||||
|                     str_data = tostring(data.xml, xmlns=self.default_ns, | ||||
|                                         stream=self, top_level=True) | ||||
|                     if isinstance(data, StanzaBase): | ||||
|                         str_data = tostring(data.xml, xmlns=self.default_ns, | ||||
|                                             stream=self, top_level=True) | ||||
|                     else: | ||||
|                         str_data = data | ||||
|                     self.send_raw(str_data) | ||||
|                 else: | ||||
|                 elif isinstance(data, (str, bytes)): | ||||
|                     self.send_raw(data) | ||||
|             except ContinueQueue as exc: | ||||
|                 log.debug('Stanza in send queue not sent: %s', exc) | ||||
| @@ -1130,10 +1262,10 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                 log.error('Exception raised in send queue:', exc_info=True) | ||||
|             self.waiting_queue.task_done() | ||||
|  | ||||
|     def send(self, data, use_filters=True): | ||||
|     def send(self, data: Union[StanzaBase, str], use_filters: bool = True) -> None: | ||||
|         """A wrapper for :meth:`send_raw()` for sending stanza objects. | ||||
|  | ||||
|         :param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase` | ||||
|         :param data: The :class:`~slixmpp.xmlstream.stanzabase.StanzaBase` | ||||
|                      stanza to send on the stream. | ||||
|         :param bool use_filters: Indicates if outgoing filters should be | ||||
|                                  applied to the given stanza data. Disabling | ||||
| @@ -1156,15 +1288,15 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|                 return | ||||
|         self.waiting_queue.put_nowait((data, use_filters)) | ||||
|  | ||||
|     def send_xml(self, data): | ||||
|     def send_xml(self, data: ET.Element) -> None: | ||||
|         """Send an XML object on the stream | ||||
|  | ||||
|         :param data: The :class:`~xml.etree.ElementTree.Element` XML object | ||||
|                      to send on the stream. | ||||
|         """ | ||||
|         return self.send(tostring(data)) | ||||
|         self.send(tostring(data)) | ||||
|  | ||||
|     def send_raw(self, data): | ||||
|     def send_raw(self, data: Union[str, bytes]) -> None: | ||||
|         """Send raw data across the stream. | ||||
|  | ||||
|         :param string data: Any bytes or utf-8 string value. | ||||
| @@ -1176,7 +1308,8 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             data = data.encode('utf-8') | ||||
|         self.transport.write(data) | ||||
|  | ||||
|     def _build_stanza(self, xml, default_ns=None): | ||||
|     def _build_stanza(self, xml: ET.Element, | ||||
|                       default_ns: Optional[str] = None) -> StanzaBase: | ||||
|         """Create a stanza object from a given XML object. | ||||
|  | ||||
|         If a specialized stanza type is not found for the XML, then | ||||
| @@ -1201,7 +1334,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|             stanza['lang'] = self.peer_default_lang | ||||
|         return stanza | ||||
|  | ||||
|     def _spawn_event(self, xml): | ||||
|     def _spawn_event(self, xml: ET.Element) -> None: | ||||
|         """ | ||||
|         Analyze incoming XML stanzas and convert them into stanza | ||||
|         objects if applicable and queue stream events to be processed | ||||
| @@ -1215,9 +1348,10 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|  | ||||
|         # Convert the raw XML object into a stanza object. If no registered | ||||
|         # stanza type applies, a generic StanzaBase stanza will be used. | ||||
|         stanza = self._build_stanza(xml) | ||||
|         stanza: Optional[StanzaBase] = self._build_stanza(xml) | ||||
|         for filter in self.__filters['in']: | ||||
|             if stanza is not None: | ||||
|                 filter = cast(SyncFilter, filter) | ||||
|                 stanza = filter(stanza) | ||||
|         if stanza is None: | ||||
|             return | ||||
| @@ -1244,7 +1378,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         if not handled: | ||||
|             stanza.unhandled() | ||||
|  | ||||
|     def exception(self, exception): | ||||
|     def exception(self, exception: Exception) -> None: | ||||
|         """Process an unknown exception. | ||||
|  | ||||
|         Meant to be overridden. | ||||
| @@ -1253,7 +1387,7 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         """ | ||||
|         pass | ||||
|  | ||||
|     async def wait_until(self, event: str, timeout=30) -> Any: | ||||
|     async def wait_until(self, event: str, timeout: Union[int, float] = 30) -> Any: | ||||
|         """Utility method to wake on the next firing of an event. | ||||
|         (Registers a disposable handler on it) | ||||
|  | ||||
| @@ -1261,9 +1395,9 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         :param int timeout: Timeout | ||||
|         :raises: :class:`asyncio.TimeoutError` when the timeout is reached | ||||
|         """ | ||||
|         fut = asyncio.Future() | ||||
|         fut: Future = asyncio.Future() | ||||
|  | ||||
|         def result_handler(event_data): | ||||
|         def result_handler(event_data: Any) -> None: | ||||
|             if not fut.done(): | ||||
|                 fut.set_result(event_data) | ||||
|             else: | ||||
| @@ -1280,19 +1414,19 @@ class XMLStream(asyncio.BaseProtocol): | ||||
|         return await asyncio.wait_for(fut, timeout) | ||||
|  | ||||
|     @contextmanager | ||||
|     def event_handler(self, event: str, handler: Callable): | ||||
|     def event_handler(self, event: str, handler: Callable[..., Any]) -> Generator[None, None, None]: | ||||
|         """ | ||||
|         Context manager that adds then removes an event handler. | ||||
|         """ | ||||
|         self.add_event_handler(event, handler) | ||||
|         try: | ||||
|             yield | ||||
|         except Exception as exc: | ||||
|         except Exception: | ||||
|             raise | ||||
|         finally: | ||||
|             self.del_event_handler(event, handler) | ||||
|  | ||||
|     def wrap(self, coroutine: Coroutine[Any, Any, Any]) -> Future: | ||||
|     def wrap(self, coroutine: Coroutine[None, None, T]) -> Future: | ||||
|         """Make a Future out of a coroutine with the current loop. | ||||
|  | ||||
|         :param coroutine: The coroutine to wrap. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 mathieui
					mathieui