Use CallbackCoroutine with Iq callbacks too

This commit is contained in:
mathieui
2015-02-22 20:13:48 +01:00
parent 2b3b86e281
commit 06358d0665
2 changed files with 14 additions and 9 deletions

View File

@@ -8,7 +8,7 @@
from slixmpp.stanza.rootstanza import RootStanza from slixmpp.stanza.rootstanza import RootStanza
from slixmpp.xmlstream import StanzaBase, ET from slixmpp.xmlstream import StanzaBase, ET
from slixmpp.xmlstream.handler import Waiter, Callback from slixmpp.xmlstream.handler import Waiter, Callback, CoroutineCallback
from slixmpp.xmlstream.asyncio import asyncio from slixmpp.xmlstream.asyncio import asyncio
from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId from slixmpp.xmlstream.matcher import MatchIDSender, MatcherId
from slixmpp.exceptions import IqTimeout, IqError from slixmpp.exceptions import IqTimeout, IqError
@@ -249,6 +249,10 @@ class Iq(RootStanza):
if callback is not None and self['type'] in ('get', 'set'): if callback is not None and self['type'] in ('get', 'set'):
handler_name = 'IqCallback_%s' % self['id'] handler_name = 'IqCallback_%s' % self['id']
if asyncio.iscoroutinefunction(callback):
constr = CoroutineCallback
else:
constr = Callback
if timeout_callback: if timeout_callback:
self.callback = callback self.callback = callback
self.timeout_callback = timeout_callback self.timeout_callback = timeout_callback
@@ -256,12 +260,12 @@ class Iq(RootStanza):
timeout, timeout,
self._fire_timeout, self._fire_timeout,
repeat=False) repeat=False)
handler = Callback(handler_name, handler = constr(handler_name,
matcher, matcher,
self._handle_result, self._handle_result,
once=True) once=True)
else: else:
handler = Callback(handler_name, handler = constr(handler_name,
matcher, matcher,
callback, callback,
once=True) once=True)

View File

@@ -489,6 +489,7 @@ class XMLStream(asyncio.BaseProtocol):
ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context, ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context,
sock=self.socket, sock=self.socket,
server_hostname=self.address[0]) server_hostname=self.address[0])
@asyncio.coroutine
def ssl_coro(): def ssl_coro():
try: try:
transp, prot = yield from ssl_connect_routine transp, prot = yield from ssl_connect_routine