diff --git a/slixmpp/plugins/xep_0115/caps.py b/slixmpp/plugins/xep_0115/caps.py index a606d662..4587e616 100644 --- a/slixmpp/plugins/xep_0115/caps.py +++ b/slixmpp/plugins/xep_0115/caps.py @@ -7,7 +7,8 @@ import logging import hashlib import base64 -from asyncio import Future +from asyncio import Future, Lock +from collections import defaultdict from typing import Optional from slixmpp import __version__ @@ -94,6 +95,9 @@ class XEP_0115(BasePlugin): disco.assign_verstring = self.assign_verstring disco.get_verstring = self.get_verstring + # prevent concurrent fetches for the same hash + self._locks = defaultdict(Lock) + def plugin_end(self): self.xmpp['xep_0030'].del_feature(feature=stanza.Capabilities.namespace) self.xmpp.del_filter('out', self._filter_add_caps) @@ -137,7 +141,7 @@ class XEP_0115(BasePlugin): self.xmpp.event('entity_caps', p) - async def _process_caps(self, pres): + async def _process_caps(self, pres: Presence): if not pres['caps']['hash']: log.debug("Received unsupported legacy caps: %s, %s, %s", pres['caps']['node'], @@ -147,7 +151,11 @@ class XEP_0115(BasePlugin): return ver = pres['caps']['ver'] + async with self._locks[ver]: + await self._process_caps_wrapped(pres, ver) + self._locks.pop(ver, None) + async def _process_caps_wrapped(self, pres: Presence, ver: str): existing_verstring = await self.get_verstring(pres['from'].full) if str(existing_verstring) == str(ver): return diff --git a/tests/test_stream_xep_0115.py b/tests/test_stream_xep_0115.py new file mode 100644 index 00000000..c145234a --- /dev/null +++ b/tests/test_stream_xep_0115.py @@ -0,0 +1,76 @@ +import logging +import unittest +from slixmpp.test import SlixTest + + +class TestCaps(SlixTest): + def setUp(self): + self.stream_start(plugins=["xep_0115"]) + + def testConcurrentSameHash(self): + """ + Check that we only resolve a given ver string to a disco info once, + even if we receive several presences with that same ver string + consecutively. + """ + self.recv( # language=XML + """ + + + + """ + ) + self.recv( # language=XML + """ + + + + """ + ) + self.send( # language=XML + """ + + + + """ + ) + self.send(None) + self.recv( # language=XML + """ + + + + + + + """ + ) + self.send(None) + self.assertTrue( + self.xmpp["xep_0030"].supports( + "romeo@montague.lit/orchard", "http://jabber.org/protocol/caps" + ) + ) + self.assertTrue( + self.xmpp["xep_0030"].supports( + "i-dont-know-much-shakespeare@montague.lit/orchard", + "http://jabber.org/protocol/caps", + ) + ) + + +logging.basicConfig(level=logging.DEBUG) +suite = unittest.TestLoader().loadTestsFromTestCase(TestCaps)