Use CallbackCoroutine with Iq callbacks too

This commit is contained in:
mathieui 2015-02-22 20:13:48 +01:00
parent 2b3b86e281
commit 06358d0665
No known key found for this signature in database
GPG Key ID: C59F84CEEFD616E3
2 changed files with 14 additions and 9 deletions

View File

@ -8,7 +8,7 @@
from slixmpp.stanza.rootstanza import RootStanza from slixmpp.stanza.rootstanza import RootStanza
from slixmpp.xmlstream import StanzaBase, ET from slixmpp.xmlstream import StanzaBase, ET
from slixmpp.xmlstream.handler import Waiter, Callback from slixmpp.xmlstream.handler import Waiter, Callback, CoroutineCallback
from slixmpp.xmlstream.asyncio import asyncio from slixmpp.xmlstream.asyncio import asyncio
from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId
from slixmpp.exceptions import IqTimeout, IqError from slixmpp.exceptions import IqTimeout, IqError
@ -249,6 +249,10 @@ class Iq(RootStanza):
if callback is not None and self['type'] in ('get', 'set'): if callback is not None and self['type'] in ('get', 'set'):
handler_name = 'IqCallback_%s' % self['id'] handler_name = 'IqCallback_%s' % self['id']
if asyncio.iscoroutinefunction(callback):
constr = CoroutineCallback
else:
constr = Callback
if timeout_callback: if timeout_callback:
self.callback = callback self.callback = callback
self.timeout_callback = timeout_callback self.timeout_callback = timeout_callback
@ -256,15 +260,15 @@ class Iq(RootStanza):
timeout, timeout,
self._fire_timeout, self._fire_timeout,
repeat=False) repeat=False)
handler = Callback(handler_name, handler = constr(handler_name,
matcher, matcher,
self._handle_result, self._handle_result,
once=True) once=True)
else: else:
handler = Callback(handler_name, handler = constr(handler_name,
matcher, matcher,
callback, callback,
once=True) once=True)
self.stream.register_handler(handler) self.stream.register_handler(handler)
StanzaBase.send(self) StanzaBase.send(self)
return handler_name return handler_name

View File

@ -489,6 +489,7 @@ class XMLStream(asyncio.BaseProtocol):
ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context, ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context,
sock=self.socket, sock=self.socket,
server_hostname=self.address[0]) server_hostname=self.address[0])
@asyncio.coroutine
def ssl_coro(): def ssl_coro():
try: try:
transp, prot = yield from ssl_connect_routine transp, prot = yield from ssl_connect_routine