Fix the ordering of stream features

since iq.send is non-blocking, some features handlers could end up
being executed before others were set, leading to issues. Adding yield
from where it’s necessary fixes that.
This commit is contained in:
mathieui 2016-05-28 14:46:39 +02:00
parent bd6ec10939
commit 4905407092
3 changed files with 17 additions and 7 deletions

View File

@ -12,6 +12,7 @@
:license: MIT, see LICENSE for more details :license: MIT, see LICENSE for more details
""" """
import asyncio
import logging import logging
from slixmpp.stanza import StreamFeatures from slixmpp.stanza import StreamFeatures
@ -19,7 +20,7 @@ from slixmpp.basexmpp import BaseXMPP
from slixmpp.exceptions import XMPPError from slixmpp.exceptions import XMPPError
from slixmpp.xmlstream import XMLStream from slixmpp.xmlstream import XMLStream
from slixmpp.xmlstream.matcher import StanzaPath, MatchXPath from slixmpp.xmlstream.matcher import StanzaPath, MatchXPath
from slixmpp.xmlstream.handler import Callback from slixmpp.xmlstream.handler import Callback, CoroutineCallback
# Flag indicating if DNS SRV records are available for use. # Flag indicating if DNS SRV records are available for use.
try: try:
@ -104,7 +105,7 @@ class ClientXMPP(BaseXMPP):
self.register_stanza(StreamFeatures) self.register_stanza(StreamFeatures)
self.register_handler( self.register_handler(
Callback('Stream Features', CoroutineCallback('Stream Features',
MatchXPath('{%s}features' % self.stream_ns), MatchXPath('{%s}features' % self.stream_ns),
self._handle_stream_features)) self._handle_stream_features))
self.register_handler( self.register_handler(
@ -249,6 +250,7 @@ class ClientXMPP(BaseXMPP):
self.bindfail = False self.bindfail = False
self.features = set() self.features = set()
@asyncio.coroutine
def _handle_stream_features(self, features): def _handle_stream_features(self, features):
"""Process the received stream features. """Process the received stream features.
@ -257,7 +259,11 @@ class ClientXMPP(BaseXMPP):
for order, name in self._stream_feature_order: for order, name in self._stream_feature_order:
if name in features['features']: if name in features['features']:
handler, restart = self._stream_feature_handlers[name] handler, restart = self._stream_feature_handlers[name]
if handler(features) and restart: if asyncio.iscoroutinefunction(handler):
result = yield from handler(features)
else:
result = handler(features)
if result and restart:
# Don't continue if the feature requires # Don't continue if the feature requires
# restarting the XML stream. # restarting the XML stream.
return True return True

View File

@ -6,6 +6,7 @@
See the file LICENSE for copying permission. See the file LICENSE for copying permission.
""" """
import asyncio
import logging import logging
from slixmpp.jid import JID from slixmpp.jid import JID
@ -34,6 +35,7 @@ class FeatureBind(BasePlugin):
register_stanza_plugin(Iq, stanza.Bind) register_stanza_plugin(Iq, stanza.Bind)
register_stanza_plugin(StreamFeatures, stanza.Bind) register_stanza_plugin(StreamFeatures, stanza.Bind)
@asyncio.coroutine
def _handle_bind_resource(self, features): def _handle_bind_resource(self, features):
""" """
Handle requesting a specific resource. Handle requesting a specific resource.
@ -49,7 +51,7 @@ class FeatureBind(BasePlugin):
if self.xmpp.requested_jid.resource: if self.xmpp.requested_jid.resource:
iq['bind']['resource'] = self.xmpp.requested_jid.resource iq['bind']['resource'] = self.xmpp.requested_jid.resource
iq.send(callback=self._on_bind_response) yield from iq.send(callback=self._on_bind_response)
def _on_bind_response(self, response): def _on_bind_response(self, response):
self.xmpp.boundjid = JID(response['bind']['jid']) self.xmpp.boundjid = JID(response['bind']['jid'])

View File

@ -6,6 +6,7 @@
See the file LICENSE for copying permission. See the file LICENSE for copying permission.
""" """
import asyncio
import logging import logging
from slixmpp.stanza import Iq, StreamFeatures from slixmpp.stanza import Iq, StreamFeatures
@ -34,6 +35,7 @@ class FeatureSession(BasePlugin):
register_stanza_plugin(Iq, stanza.Session) register_stanza_plugin(Iq, stanza.Session)
register_stanza_plugin(StreamFeatures, stanza.Session) register_stanza_plugin(StreamFeatures, stanza.Session)
@asyncio.coroutine
def _handle_start_session(self, features): def _handle_start_session(self, features):
""" """
Handle the start of the session. Handle the start of the session.
@ -44,7 +46,7 @@ class FeatureSession(BasePlugin):
iq = self.xmpp.Iq() iq = self.xmpp.Iq()
iq['type'] = 'set' iq['type'] = 'set'
iq.enable('session') iq.enable('session')
iq.send(callback=self._on_start_session_response) yield from iq.send(callback=self._on_start_session_response)
def _on_start_session_response(self, response): def _on_start_session_response(self, response):
self.xmpp.features.add('session') self.xmpp.features.add('session')