xep-0115: perf: avoid simultaneous disco info queries for the same verstring

This commit is contained in:
nicoco 2023-08-20 01:07:57 +02:00 committed by mathieui
parent 7de5cbcf33
commit ca90d3908e
2 changed files with 86 additions and 2 deletions

View File

@ -7,7 +7,8 @@ import logging
import hashlib import hashlib
import base64 import base64
from asyncio import Future from asyncio import Future, Lock
from collections import defaultdict
from typing import Optional from typing import Optional
from slixmpp import __version__ from slixmpp import __version__
@ -94,6 +95,9 @@ class XEP_0115(BasePlugin):
disco.assign_verstring = self.assign_verstring disco.assign_verstring = self.assign_verstring
disco.get_verstring = self.get_verstring disco.get_verstring = self.get_verstring
# prevent concurrent fetches for the same hash
self._locks = defaultdict(Lock)
def plugin_end(self): def plugin_end(self):
self.xmpp['xep_0030'].del_feature(feature=stanza.Capabilities.namespace) self.xmpp['xep_0030'].del_feature(feature=stanza.Capabilities.namespace)
self.xmpp.del_filter('out', self._filter_add_caps) self.xmpp.del_filter('out', self._filter_add_caps)
@ -137,7 +141,7 @@ class XEP_0115(BasePlugin):
self.xmpp.event('entity_caps', p) self.xmpp.event('entity_caps', p)
async def _process_caps(self, pres): async def _process_caps(self, pres: Presence):
if not pres['caps']['hash']: if not pres['caps']['hash']:
log.debug("Received unsupported legacy caps: %s, %s, %s", log.debug("Received unsupported legacy caps: %s, %s, %s",
pres['caps']['node'], pres['caps']['node'],
@ -147,7 +151,11 @@ class XEP_0115(BasePlugin):
return return
ver = pres['caps']['ver'] ver = pres['caps']['ver']
async with self._locks[ver]:
await self._process_caps_wrapped(pres, ver)
self._locks.pop(ver, None)
async def _process_caps_wrapped(self, pres: Presence, ver: str):
existing_verstring = await self.get_verstring(pres['from'].full) existing_verstring = await self.get_verstring(pres['from'].full)
if str(existing_verstring) == str(ver): if str(existing_verstring) == str(ver):
return return

View File

@ -0,0 +1,76 @@
import logging
import unittest
from slixmpp.test import SlixTest
class TestCaps(SlixTest):
def setUp(self):
self.stream_start(plugins=["xep_0115"])
def testConcurrentSameHash(self):
"""
Check that we only resolve a given ver string to a disco info once,
even if we receive several presences with that same ver string
consecutively.
"""
self.recv( # language=XML
"""
<presence from='romeo@montague.lit/orchard'>
<c xmlns='http://jabber.org/protocol/caps'
hash='sha-1'
node='a-node'
ver='h0TdMvqNR8FHUfFG1HauOLYZDqE='/>
</presence>
"""
)
self.recv( # language=XML
"""
<presence from='i-dont-know-much-shakespeare@montague.lit/orchard'>
<c xmlns='http://jabber.org/protocol/caps'
hash='sha-1'
node='a-node'
ver='h0TdMvqNR8FHUfFG1HauOLYZDqE='/>
</presence>
"""
)
self.send( # language=XML
"""
<iq xmlns="jabber:client"
id="1"
to="romeo@montague.lit/orchard"
type="get">
<query xmlns="http://jabber.org/protocol/disco#info"
node="a-node#h0TdMvqNR8FHUfFG1HauOLYZDqE="/>
</iq>
"""
)
self.send(None)
self.recv( # language=XML
"""
<iq from='romeo@montague.lit/orchard'
id='1'
type='result'>
<query xmlns='http://jabber.org/protocol/disco#info'
node='a-nodes#h0TdMvqNR8FHUfFG1HauOLYZDqE='>
<identity category='client' name='a client' type='pc'/>
<feature var='http://jabber.org/protocol/caps'/>
</query>
</iq>
"""
)
self.send(None)
self.assertTrue(
self.xmpp["xep_0030"].supports(
"romeo@montague.lit/orchard", "http://jabber.org/protocol/caps"
)
)
self.assertTrue(
self.xmpp["xep_0030"].supports(
"i-dont-know-much-shakespeare@montague.lit/orchard",
"http://jabber.org/protocol/caps",
)
)
logging.basicConfig(level=logging.DEBUG)
suite = unittest.TestLoader().loadTestsFromTestCase(TestCaps)