XEP-0115: API changes

- ``get_verstring``, ``get_caps`` are now coroutines.
- ``assign_verstring``, ``cache_caps`` now return a Future.

side-effect: fix supports() and has_identity() broken since forever
This commit is contained in:
mathieui 2021-02-14 11:56:20 +01:00
parent f0aec1614f
commit 4960cffcb4
3 changed files with 99 additions and 29 deletions

View File

@ -8,6 +8,52 @@ XEP-0115: Entity Capabilities
:members: :members:
:exclude-members: session_bind, plugin_init, plugin_end :exclude-members: session_bind, plugin_init, plugin_end
Internal API methods
--------------------
This internal API extends the Disco internal API, and also manages an
in-memory cache of verstring→disco info, and fulljid→verstring.
.. glossary::
cache_caps
- **jid**: unused
- **node**: unused
- **ifrom**: unused
- **args**: a ``dict`` containing the verstring and
:class:`~.DiscoInfo` payload (
``{'verstring': Optional[str], 'info': Optional[DiscoInfo]}``)
Cache a verification string with its payload.
get_caps
- **jid**: JID to retrieve the verstring for (unused with the default
handler)
- **node**: unused
- **ifrom**: unused
- **args**: a ``dict`` containing the verstring
``{'verstring': str}``
- **returns**: The :class:`~.DiscoInfo` payload for that verstring.
Get a disco payload from a verstring.
assign_verstring
- **jid**: :class:`~.JID` (full) to assign the verstring to
- **node**: unused
- **ifrom**: unused
- **args**: a ``dict`` containing the verstring
``{'verstring': str}``
Cache JID→verstring information.
get_verstring
- **jid**: :class:`~.JID` to use for fetching the verstring
- **node**: unused
- **ifrom**: unused
- **args**: unused
- **returns**: ``str``, the verstring
Retrieve a verstring for a JID.
Stanza elements Stanza elements
--------------- ---------------

View File

@ -7,6 +7,8 @@ import logging
import hashlib import hashlib
import base64 import base64
from asyncio import Future
from slixmpp import __version__ from slixmpp import __version__
from slixmpp.stanza import StreamFeatures, Presence, Iq from slixmpp.stanza import StreamFeatures, Presence, Iq
from slixmpp.xmlstream import register_stanza_plugin, JID from slixmpp.xmlstream import register_stanza_plugin, JID
@ -104,14 +106,14 @@ class XEP_0115(BasePlugin):
def session_bind(self, jid): def session_bind(self, jid):
self.xmpp['xep_0030'].add_feature(stanza.Capabilities.namespace) self.xmpp['xep_0030'].add_feature(stanza.Capabilities.namespace)
def _filter_add_caps(self, stanza): async def _filter_add_caps(self, stanza):
if not isinstance(stanza, Presence) or not self.broadcast: if not isinstance(stanza, Presence) or not self.broadcast:
return stanza return stanza
if stanza['type'] not in ('available', 'chat', 'away', 'dnd', 'xa'): if stanza['type'] not in ('available', 'chat', 'away', 'dnd', 'xa'):
return stanza return stanza
ver = self.get_verstring(stanza['from']) ver = await self.get_verstring(stanza['from'])
if ver: if ver:
stanza['caps']['node'] = self.caps_node stanza['caps']['node'] = self.caps_node
stanza['caps']['hash'] = self.hash stanza['caps']['hash'] = self.hash
@ -145,13 +147,13 @@ class XEP_0115(BasePlugin):
ver = pres['caps']['ver'] ver = pres['caps']['ver']
existing_verstring = self.get_verstring(pres['from'].full) existing_verstring = await self.get_verstring(pres['from'].full)
if str(existing_verstring) == str(ver): if str(existing_verstring) == str(ver):
return return
existing_caps = self.get_caps(verstring=ver) existing_caps = await self.get_caps(verstring=ver)
if existing_caps is not None: if existing_caps is not None:
self.assign_verstring(pres['from'], ver) await self.assign_verstring(pres['from'], ver)
return return
ifrom = pres['to'] if self.xmpp.is_component else None ifrom = pres['to'] if self.xmpp.is_component else None
@ -174,13 +176,13 @@ class XEP_0115(BasePlugin):
if isinstance(caps, Iq): if isinstance(caps, Iq):
caps = caps['disco_info'] caps = caps['disco_info']
if self._validate_caps(caps, pres['caps']['hash'], if await self._validate_caps(caps, pres['caps']['hash'],
pres['caps']['ver']): pres['caps']['ver']):
self.assign_verstring(pres['from'], pres['caps']['ver']) await self.assign_verstring(pres['from'], pres['caps']['ver'])
except XMPPError: except XMPPError:
log.debug("Could not retrieve disco#info results for caps for %s", node) log.debug("Could not retrieve disco#info results for caps for %s", node)
def _validate_caps(self, caps, hash, check_verstring): async def _validate_caps(self, caps, hash, check_verstring):
# Check Identities # Check Identities
full_ids = caps.get_identities(dedupe=False) full_ids = caps.get_identities(dedupe=False)
deduped_ids = caps.get_identities() deduped_ids = caps.get_identities()
@ -232,7 +234,7 @@ class XEP_0115(BasePlugin):
verstring, check_verstring)) verstring, check_verstring))
return False return False
self.cache_caps(verstring, caps) await self.cache_caps(verstring, caps)
return True return True
def generate_verstring(self, info, hash): def generate_verstring(self, info, hash):
@ -290,12 +292,13 @@ class XEP_0115(BasePlugin):
if isinstance(info, Iq): if isinstance(info, Iq):
info = info['disco_info'] info = info['disco_info']
ver = self.generate_verstring(info, self.hash) ver = self.generate_verstring(info, self.hash)
self.xmpp['xep_0030'].set_info( await self.xmpp['xep_0030'].set_info(
jid=jid, jid=jid,
node='%s#%s' % (self.caps_node, ver), node='%s#%s' % (self.caps_node, ver),
info=info) info=info
self.cache_caps(ver, info) )
self.assign_verstring(jid, ver) await self.cache_caps(ver, info)
await self.assign_verstring(jid, ver)
if self.xmpp.sessionstarted and self.broadcast: if self.xmpp.sessionstarted and self.broadcast:
if self.xmpp.is_component or preserve: if self.xmpp.is_component or preserve:
@ -306,32 +309,53 @@ class XEP_0115(BasePlugin):
except XMPPError: except XMPPError:
return return
def get_verstring(self, jid=None): def get_verstring(self, jid=None) -> Future:
"""Get the stored verstring for a JID.
.. versionchanged:: 1.8.0
This function now returns a Future.
"""
if jid in ('', None): if jid in ('', None):
jid = self.xmpp.boundjid.full jid = self.xmpp.boundjid.full
if isinstance(jid, JID): if isinstance(jid, JID):
jid = jid.full jid = jid.full
return self.api['get_verstring'](jid) return self.api['get_verstring'](jid)
def assign_verstring(self, jid=None, verstring=None): def assign_verstring(self, jid=None, verstring=None) -> Future:
"""Assign a vertification string to a jid.
.. versionchanged:: 1.8.0
This function now returns a Future.
"""
if jid in (None, ''): if jid in (None, ''):
jid = self.xmpp.boundjid.full jid = self.xmpp.boundjid.full
if isinstance(jid, JID): if isinstance(jid, JID):
jid = jid.full jid = jid.full
return self.api['assign_verstring'](jid, args={ return self.api['assign_verstring'](jid, args={
'verstring': verstring}) 'verstring': verstring
})
def cache_caps(self, verstring=None, info=None): def cache_caps(self, verstring=None, info=None) -> Future:
"""Add caps to the cache.
.. versionchanged:: 1.8.0
This function now returns a Future.
"""
data = {'verstring': verstring, 'info': info} data = {'verstring': verstring, 'info': info}
return self.api['cache_caps'](args=data) return self.api['cache_caps'](args=data)
def get_caps(self, jid=None, verstring=None): async def get_caps(self, jid=None, verstring=None):
"""Get caps for a JID.
.. versionchanged:: 1.8.0
This function is now a coroutine.
"""
if verstring is None: if verstring is None:
if jid is not None: if jid is not None:
verstring = self.get_verstring(jid) verstring = await self.get_verstring(jid)
else: else:
return None return None
if isinstance(jid, JID): if isinstance(jid, JID):
jid = jid.full jid = jid.full
data = {'verstring': verstring} data = {'verstring': verstring}
return self.api['get_caps'](jid, args=data) return await self.api['get_caps'](jid, args=data)

View File

@ -32,7 +32,7 @@ class StaticCaps(object):
self.static = static self.static = static
self.jid_vers = {} self.jid_vers = {}
def supports(self, jid, node, ifrom, data): async def supports(self, jid, node, ifrom, data):
""" """
Check if a JID supports a given feature. Check if a JID supports a given feature.
@ -65,7 +65,7 @@ class StaticCaps(object):
return True return True
try: try:
info = self.disco.get_info(jid=jid, node=node, info = await self.disco.get_info(jid=jid, node=node,
ifrom=ifrom, **data) ifrom=ifrom, **data)
info = self.disco._wrap(ifrom, jid, info, True) info = self.disco._wrap(ifrom, jid, info, True)
return feature in info['disco_info']['features'] return feature in info['disco_info']['features']
@ -74,7 +74,7 @@ class StaticCaps(object):
except IqTimeout: except IqTimeout:
return None return None
def has_identity(self, jid, node, ifrom, data): async def has_identity(self, jid, node, ifrom, data):
""" """
Check if a JID has a given identity. Check if a JID has a given identity.
@ -110,7 +110,7 @@ class StaticCaps(object):
return True return True
try: try:
info = self.disco.get_info(jid=jid, node=node, info = await self.disco.get_info(jid=jid, node=node,
ifrom=ifrom, **data) ifrom=ifrom, **data)
info = self.disco._wrap(ifrom, jid, info, True) info = self.disco._wrap(ifrom, jid, info, True)
return identity in map(trunc, info['disco_info']['identities']) return identity in map(trunc, info['disco_info']['identities'])