Allow the use of a custom loop instead of asyncio.get_event_loop()

This commit is contained in:
mathieui 2015-05-12 00:02:32 +02:00
parent f1e6d6b0a9
commit a2852eb249
No known key found for this signature in database
GPG Key ID: C59F84CEEFD616E3
3 changed files with 42 additions and 36 deletions

View File

@ -332,7 +332,7 @@ class XEP_0325(BasePlugin):
self.sessions[session]["nodeDone"][node] = False self.sessions[session]["nodeDone"][node] = False
for node in self.sessions[session]["node_list"]: for node in self.sessions[session]["node_list"]:
timer = asyncio.get_event_loop().call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node))) timer = self.xmpp.loop.call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
self.sessions[session]["commTimers"][node] = timer self.sessions[session]["commTimers"][node] = timer
self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback) self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback)

View File

@ -32,14 +32,14 @@ except ImportError as e:
"Not all features will be available") "Not all features will be available")
def default_resolver(): def default_resolver(loop):
"""Return a basic DNS resolver object. """Return a basic DNS resolver object.
:returns: A :class:`aiodns.DNSResolver` object if aiodns :returns: A :class:`aiodns.DNSResolver` object if aiodns
is available. Otherwise, ``None``. is available. Otherwise, ``None``.
""" """
if AIODNS_AVAILABLE: if AIODNS_AVAILABLE:
return aiodns.DNSResolver(loop=asyncio.get_event_loop(), return aiodns.DNSResolver(loop=loop,
tries=1, tries=1,
timeout=1.0) timeout=1.0)
return None return None
@ -47,7 +47,7 @@ def default_resolver():
@asyncio.coroutine @asyncio.coroutine
def resolve(host, port=None, service=None, proto='tcp', def resolve(host, port=None, service=None, proto='tcp',
resolver=None, use_ipv6=True, use_aiodns=True): resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
"""Peform DNS resolution for a given hostname. """Peform DNS resolution for a given hostname.
Resolution may perform SRV record lookups if a service and protocol Resolution may perform SRV record lookups if a service and protocol
@ -97,7 +97,7 @@ def resolve(host, port=None, service=None, proto='tcp',
log.debug("DNS: Use of IPv6 has been disabled.") log.debug("DNS: Use of IPv6 has been disabled.")
if resolver is None and AIODNS_AVAILABLE and use_aiodns: if resolver is None and AIODNS_AVAILABLE and use_aiodns:
resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop()) resolver = aiodns.DNSResolver(loop=loop)
# An IPv6 literal is allowed to be enclosed in square brackets, but # An IPv6 literal is allowed to be enclosed in square brackets, but
# the brackets must be stripped in order to process the literal; # the brackets must be stripped in order to process the literal;
@ -142,19 +142,19 @@ def resolve(host, port=None, service=None, proto='tcp',
if use_ipv6: if use_ipv6:
aaaa = yield from get_AAAA(host, resolver=resolver, aaaa = yield from get_AAAA(host, resolver=resolver,
use_aiodns=use_aiodns) use_aiodns=use_aiodns, loop=loop)
for address in aaaa: for address in aaaa:
results.append((host, address, port)) results.append((host, address, port))
a = yield from get_A(host, resolver=resolver, a = yield from get_A(host, resolver=resolver,
use_aiodns=use_aiodns) use_aiodns=use_aiodns, loop=loop)
for address in a: for address in a:
results.append((host, address, port)) results.append((host, address, port))
return results return results
@asyncio.coroutine @asyncio.coroutine
def get_A(host, resolver=None, use_aiodns=True): def get_A(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS A records for a given host. """Lookup DNS A records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will If ``resolver`` is not provided, or is ``None``, then resolution will
@ -177,7 +177,6 @@ def get_A(host, resolver=None, use_aiodns=True):
# If not using aiodns, attempt lookup using the OS level # If not using aiodns, attempt lookup using the OS level
# getaddrinfo() method. # getaddrinfo() method.
if resolver is None or not use_aiodns: if resolver is None or not use_aiodns:
loop = asyncio.get_event_loop()
try: try:
recs = yield from loop.getaddrinfo(host, None, recs = yield from loop.getaddrinfo(host, None,
family=socket.AF_INET, family=socket.AF_INET,
@ -198,7 +197,7 @@ def get_A(host, resolver=None, use_aiodns=True):
@asyncio.coroutine @asyncio.coroutine
def get_AAAA(host, resolver=None, use_aiodns=True): def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
"""Lookup DNS AAAA records for a given host. """Lookup DNS AAAA records for a given host.
If ``resolver`` is not provided, or is ``None``, then resolution will If ``resolver`` is not provided, or is ``None``, then resolution will
@ -224,7 +223,6 @@ def get_AAAA(host, resolver=None, use_aiodns=True):
if not socket.has_ipv6: if not socket.has_ipv6:
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host) log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
return [] return []
loop = asyncio.get_event_loop()
try: try:
recs = yield from loop.getaddrinfo(host, None, recs = yield from loop.getaddrinfo(host, None,
family=socket.AF_INET6, family=socket.AF_INET6,

View File

@ -116,6 +116,9 @@ class XMLStream(asyncio.BaseProtocol):
self._der_cert = None self._der_cert = None
# The asyncio event loop
self._loop = None
#: The default port to return when querying DNS records. #: The default port to return when querying DNS records.
self.default_port = int(port) self.default_port = int(port)
@ -213,6 +216,16 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('disconnected', self._remove_schedules) self.add_event_handler('disconnected', self._remove_schedules)
self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('session_start', self._start_keepalive)
@property
def loop(self):
if self._loop is None:
self._loop = asyncio.get_event_loop()
return self._loop
@loop.setter
def loop(self, value):
self._loop = value
def new_id(self): def new_id(self):
"""Generate and return a new stream ID in hexadecimal form. """Generate and return a new stream ID in hexadecimal form.
@ -270,7 +283,6 @@ class XMLStream(asyncio.BaseProtocol):
@asyncio.coroutine @asyncio.coroutine
def _connect_routine(self): def _connect_routine(self):
loop = asyncio.get_event_loop()
self.event_when_connected = "connected" self.event_when_connected = "connected"
try: try:
@ -290,7 +302,7 @@ class XMLStream(asyncio.BaseProtocol):
self.dns_answers = None self.dns_answers = None
try: try:
yield from loop.create_connection(lambda: self, yield from self.loop.create_connection(lambda: self,
self.address[0], self.address[0],
self.address[1], self.address[1],
ssl=self.use_ssl) ssl=self.use_ssl)
@ -309,17 +321,16 @@ class XMLStream(asyncio.BaseProtocol):
function will run forever. If timeout is a number, this function function will run forever. If timeout is a number, this function
will return after the given time in seconds. will return after the given time in seconds.
""" """
loop = asyncio.get_event_loop()
if timeout is None: if timeout is None:
if forever: if forever:
loop.run_forever() self.loop.run_forever()
else: else:
loop.run_until_complete(self.disconnected) self.loop.run_until_complete(self.disconnected)
else: else:
tasks = [asyncio.sleep(timeout)] tasks = [asyncio.sleep(timeout)]
if not forever: if not forever:
tasks.append(self.disconnected) tasks.append(self.disconnected)
loop.run_until_complete(asyncio.wait(tasks)) self.loop.run_until_complete(asyncio.wait(tasks))
def init_parser(self): def init_parser(self):
"""init the XML parser. The parser must always be reset for each new """init the XML parser. The parser must always be reset for each new
@ -367,8 +378,7 @@ class XMLStream(asyncio.BaseProtocol):
elif self.xml_depth == 1: elif self.xml_depth == 1:
# A stanza is an XML element that is a direct child of # A stanza is an XML element that is a direct child of
# the root element, hence the check of depth == 1 # the root element, hence the check of depth == 1
asyncio.get_event_loop().\ self.loop.idle_call(functools.partial(self.__spawn_event, xml))
idle_call(functools.partial(self.__spawn_event, xml))
if self.xml_root is not None: if self.xml_root is not None:
# Keep the root element empty of children to # Keep the root element empty of children to
# save on memory use. # save on memory use.
@ -461,7 +471,6 @@ class XMLStream(asyncio.BaseProtocol):
If the handshake is successful, the XML stream will need If the handshake is successful, the XML stream will need
to be restarted. to be restarted.
""" """
loop = asyncio.get_event_loop()
self.event_when_connected = "tls_success" self.event_when_connected = "tls_success"
if self.ciphers is not None: if self.ciphers is not None:
@ -478,7 +487,7 @@ class XMLStream(asyncio.BaseProtocol):
self.ssl_context.verify_mode = ssl.CERT_REQUIRED self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.load_verify_locations(cafile=self.ca_certs) self.ssl_context.load_verify_locations(cafile=self.ca_certs)
ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context, ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=self.ssl_context,
sock=self.socket, sock=self.socket,
server_hostname=self.address[0]) server_hostname=self.address[0])
@asyncio.coroutine @asyncio.coroutine
@ -621,14 +630,15 @@ class XMLStream(asyncio.BaseProtocol):
if port is None: if port is None:
port = self.default_port port = self.default_port
resolver = default_resolver() resolver = default_resolver(loop=self.loop)
self.configure_dns(resolver, domain=domain, port=port) self.configure_dns(resolver, domain=domain, port=port)
result = yield from resolve(domain, port, result = yield from resolve(domain, port,
service=self.dns_service, service=self.dns_service,
resolver=resolver, resolver=resolver,
use_ipv6=self.use_ipv6, use_ipv6=self.use_ipv6,
use_aiodns=self.use_aiodns) use_aiodns=self.use_aiodns,
loop=self.loop)
return result return result
@asyncio.coroutine @asyncio.coroutine
@ -746,13 +756,12 @@ class XMLStream(asyncio.BaseProtocol):
""" """
if seconds is None: if seconds is None:
seconds = RESPONSE_TIMEOUT seconds = RESPONSE_TIMEOUT
loop = asyncio.get_event_loop()
cb = functools.partial(callback, *args, **kwargs) cb = functools.partial(callback, *args, **kwargs)
if repeat: if repeat:
handle = loop.call_later(seconds, self._execute_and_reschedule, handle = self.loop.call_later(seconds, self._execute_and_reschedule,
name, cb, seconds) name, cb, seconds)
else: else:
handle = loop.call_later(seconds, self._execute_and_unschedule, handle = self.loop.call_later(seconds, self._execute_and_unschedule,
name, cb) name, cb)
# Save that handle, so we can just cancel this scheduled event by # Save that handle, so we can just cancel this scheduled event by
@ -778,8 +787,7 @@ class XMLStream(asyncio.BaseProtocol):
be called after the given number of seconds. be called after the given number of seconds.
""" """
self._safe_cb_run(name, cb) self._safe_cb_run(name, cb)
loop = asyncio.get_event_loop() handle = self.loop.call_later(seconds, self._execute_and_reschedule,
handle = loop.call_later(seconds, self._execute_and_reschedule,
name, cb, seconds) name, cb, seconds)
self.scheduled_events[name] = handle self.scheduled_events[name] = handle