Allow the use of a custom loop instead of asyncio.get_event_loop()
This commit is contained in:
parent
f1e6d6b0a9
commit
a2852eb249
@ -332,7 +332,7 @@ class XEP_0325(BasePlugin):
|
|||||||
self.sessions[session]["nodeDone"][node] = False
|
self.sessions[session]["nodeDone"][node] = False
|
||||||
|
|
||||||
for node in self.sessions[session]["node_list"]:
|
for node in self.sessions[session]["node_list"]:
|
||||||
timer = asyncio.get_event_loop().call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
|
timer = self.xmpp.loop.call_later(self.nodes[node]['commTimeout'], partial(self._event_comm_timeout, args=(session, node)))
|
||||||
self.sessions[session]["commTimers"][node] = timer
|
self.sessions[session]["commTimers"][node] = timer
|
||||||
self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback)
|
self.nodes[node]['device'].set_control_fields(process_fields, session=session, callback=self._device_set_command_callback)
|
||||||
|
|
||||||
|
@ -32,14 +32,14 @@ except ImportError as e:
|
|||||||
"Not all features will be available")
|
"Not all features will be available")
|
||||||
|
|
||||||
|
|
||||||
def default_resolver():
|
def default_resolver(loop):
|
||||||
"""Return a basic DNS resolver object.
|
"""Return a basic DNS resolver object.
|
||||||
|
|
||||||
:returns: A :class:`aiodns.DNSResolver` object if aiodns
|
:returns: A :class:`aiodns.DNSResolver` object if aiodns
|
||||||
is available. Otherwise, ``None``.
|
is available. Otherwise, ``None``.
|
||||||
"""
|
"""
|
||||||
if AIODNS_AVAILABLE:
|
if AIODNS_AVAILABLE:
|
||||||
return aiodns.DNSResolver(loop=asyncio.get_event_loop(),
|
return aiodns.DNSResolver(loop=loop,
|
||||||
tries=1,
|
tries=1,
|
||||||
timeout=1.0)
|
timeout=1.0)
|
||||||
return None
|
return None
|
||||||
@ -47,7 +47,7 @@ def default_resolver():
|
|||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def resolve(host, port=None, service=None, proto='tcp',
|
def resolve(host, port=None, service=None, proto='tcp',
|
||||||
resolver=None, use_ipv6=True, use_aiodns=True):
|
resolver=None, use_ipv6=True, use_aiodns=True, loop=None):
|
||||||
"""Peform DNS resolution for a given hostname.
|
"""Peform DNS resolution for a given hostname.
|
||||||
|
|
||||||
Resolution may perform SRV record lookups if a service and protocol
|
Resolution may perform SRV record lookups if a service and protocol
|
||||||
@ -97,7 +97,7 @@ def resolve(host, port=None, service=None, proto='tcp',
|
|||||||
log.debug("DNS: Use of IPv6 has been disabled.")
|
log.debug("DNS: Use of IPv6 has been disabled.")
|
||||||
|
|
||||||
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
|
if resolver is None and AIODNS_AVAILABLE and use_aiodns:
|
||||||
resolver = aiodns.DNSResolver(loop=asyncio.get_event_loop())
|
resolver = aiodns.DNSResolver(loop=loop)
|
||||||
|
|
||||||
# An IPv6 literal is allowed to be enclosed in square brackets, but
|
# An IPv6 literal is allowed to be enclosed in square brackets, but
|
||||||
# the brackets must be stripped in order to process the literal;
|
# the brackets must be stripped in order to process the literal;
|
||||||
@ -142,19 +142,19 @@ def resolve(host, port=None, service=None, proto='tcp',
|
|||||||
|
|
||||||
if use_ipv6:
|
if use_ipv6:
|
||||||
aaaa = yield from get_AAAA(host, resolver=resolver,
|
aaaa = yield from get_AAAA(host, resolver=resolver,
|
||||||
use_aiodns=use_aiodns)
|
use_aiodns=use_aiodns, loop=loop)
|
||||||
for address in aaaa:
|
for address in aaaa:
|
||||||
results.append((host, address, port))
|
results.append((host, address, port))
|
||||||
|
|
||||||
a = yield from get_A(host, resolver=resolver,
|
a = yield from get_A(host, resolver=resolver,
|
||||||
use_aiodns=use_aiodns)
|
use_aiodns=use_aiodns, loop=loop)
|
||||||
for address in a:
|
for address in a:
|
||||||
results.append((host, address, port))
|
results.append((host, address, port))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get_A(host, resolver=None, use_aiodns=True):
|
def get_A(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
"""Lookup DNS A records for a given host.
|
"""Lookup DNS A records for a given host.
|
||||||
|
|
||||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||||
@ -177,7 +177,6 @@ def get_A(host, resolver=None, use_aiodns=True):
|
|||||||
# If not using aiodns, attempt lookup using the OS level
|
# If not using aiodns, attempt lookup using the OS level
|
||||||
# getaddrinfo() method.
|
# getaddrinfo() method.
|
||||||
if resolver is None or not use_aiodns:
|
if resolver is None or not use_aiodns:
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
try:
|
try:
|
||||||
recs = yield from loop.getaddrinfo(host, None,
|
recs = yield from loop.getaddrinfo(host, None,
|
||||||
family=socket.AF_INET,
|
family=socket.AF_INET,
|
||||||
@ -198,7 +197,7 @@ def get_A(host, resolver=None, use_aiodns=True):
|
|||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def get_AAAA(host, resolver=None, use_aiodns=True):
|
def get_AAAA(host, resolver=None, use_aiodns=True, loop=None):
|
||||||
"""Lookup DNS AAAA records for a given host.
|
"""Lookup DNS AAAA records for a given host.
|
||||||
|
|
||||||
If ``resolver`` is not provided, or is ``None``, then resolution will
|
If ``resolver`` is not provided, or is ``None``, then resolution will
|
||||||
@ -224,7 +223,6 @@ def get_AAAA(host, resolver=None, use_aiodns=True):
|
|||||||
if not socket.has_ipv6:
|
if not socket.has_ipv6:
|
||||||
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
log.debug("DNS: Unable to query %s for AAAA records: IPv6 is not supported", host)
|
||||||
return []
|
return []
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
try:
|
try:
|
||||||
recs = yield from loop.getaddrinfo(host, None,
|
recs = yield from loop.getaddrinfo(host, None,
|
||||||
family=socket.AF_INET6,
|
family=socket.AF_INET6,
|
||||||
|
@ -116,6 +116,9 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
|
|
||||||
self._der_cert = None
|
self._der_cert = None
|
||||||
|
|
||||||
|
# The asyncio event loop
|
||||||
|
self._loop = None
|
||||||
|
|
||||||
#: The default port to return when querying DNS records.
|
#: The default port to return when querying DNS records.
|
||||||
self.default_port = int(port)
|
self.default_port = int(port)
|
||||||
|
|
||||||
@ -213,6 +216,16 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.add_event_handler('disconnected', self._remove_schedules)
|
self.add_event_handler('disconnected', self._remove_schedules)
|
||||||
self.add_event_handler('session_start', self._start_keepalive)
|
self.add_event_handler('session_start', self._start_keepalive)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self):
|
||||||
|
if self._loop is None:
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
return self._loop
|
||||||
|
|
||||||
|
@loop.setter
|
||||||
|
def loop(self, value):
|
||||||
|
self._loop = value
|
||||||
|
|
||||||
def new_id(self):
|
def new_id(self):
|
||||||
"""Generate and return a new stream ID in hexadecimal form.
|
"""Generate and return a new stream ID in hexadecimal form.
|
||||||
|
|
||||||
@ -270,7 +283,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _connect_routine(self):
|
def _connect_routine(self):
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
self.event_when_connected = "connected"
|
self.event_when_connected = "connected"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -290,7 +302,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.dns_answers = None
|
self.dns_answers = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield from loop.create_connection(lambda: self,
|
yield from self.loop.create_connection(lambda: self,
|
||||||
self.address[0],
|
self.address[0],
|
||||||
self.address[1],
|
self.address[1],
|
||||||
ssl=self.use_ssl)
|
ssl=self.use_ssl)
|
||||||
@ -309,17 +321,16 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
function will run forever. If timeout is a number, this function
|
function will run forever. If timeout is a number, this function
|
||||||
will return after the given time in seconds.
|
will return after the given time in seconds.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if timeout is None:
|
if timeout is None:
|
||||||
if forever:
|
if forever:
|
||||||
loop.run_forever()
|
self.loop.run_forever()
|
||||||
else:
|
else:
|
||||||
loop.run_until_complete(self.disconnected)
|
self.loop.run_until_complete(self.disconnected)
|
||||||
else:
|
else:
|
||||||
tasks = [asyncio.sleep(timeout)]
|
tasks = [asyncio.sleep(timeout)]
|
||||||
if not forever:
|
if not forever:
|
||||||
tasks.append(self.disconnected)
|
tasks.append(self.disconnected)
|
||||||
loop.run_until_complete(asyncio.wait(tasks))
|
self.loop.run_until_complete(asyncio.wait(tasks))
|
||||||
|
|
||||||
def init_parser(self):
|
def init_parser(self):
|
||||||
"""init the XML parser. The parser must always be reset for each new
|
"""init the XML parser. The parser must always be reset for each new
|
||||||
@ -367,8 +378,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
elif self.xml_depth == 1:
|
elif self.xml_depth == 1:
|
||||||
# A stanza is an XML element that is a direct child of
|
# A stanza is an XML element that is a direct child of
|
||||||
# the root element, hence the check of depth == 1
|
# the root element, hence the check of depth == 1
|
||||||
asyncio.get_event_loop().\
|
self.loop.idle_call(functools.partial(self.__spawn_event, xml))
|
||||||
idle_call(functools.partial(self.__spawn_event, xml))
|
|
||||||
if self.xml_root is not None:
|
if self.xml_root is not None:
|
||||||
# Keep the root element empty of children to
|
# Keep the root element empty of children to
|
||||||
# save on memory use.
|
# save on memory use.
|
||||||
@ -461,7 +471,6 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
If the handshake is successful, the XML stream will need
|
If the handshake is successful, the XML stream will need
|
||||||
to be restarted.
|
to be restarted.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
self.event_when_connected = "tls_success"
|
self.event_when_connected = "tls_success"
|
||||||
|
|
||||||
if self.ciphers is not None:
|
if self.ciphers is not None:
|
||||||
@ -478,7 +487,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||||
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
|
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
|
||||||
|
|
||||||
ssl_connect_routine = loop.create_connection(lambda: self, ssl=self.ssl_context,
|
ssl_connect_routine = self.loop.create_connection(lambda: self, ssl=self.ssl_context,
|
||||||
sock=self.socket,
|
sock=self.socket,
|
||||||
server_hostname=self.address[0])
|
server_hostname=self.address[0])
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
@ -621,14 +630,15 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
if port is None:
|
if port is None:
|
||||||
port = self.default_port
|
port = self.default_port
|
||||||
|
|
||||||
resolver = default_resolver()
|
resolver = default_resolver(loop=self.loop)
|
||||||
self.configure_dns(resolver, domain=domain, port=port)
|
self.configure_dns(resolver, domain=domain, port=port)
|
||||||
|
|
||||||
result = yield from resolve(domain, port,
|
result = yield from resolve(domain, port,
|
||||||
service=self.dns_service,
|
service=self.dns_service,
|
||||||
resolver=resolver,
|
resolver=resolver,
|
||||||
use_ipv6=self.use_ipv6,
|
use_ipv6=self.use_ipv6,
|
||||||
use_aiodns=self.use_aiodns)
|
use_aiodns=self.use_aiodns,
|
||||||
|
loop=self.loop)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
@ -746,13 +756,12 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
if seconds is None:
|
if seconds is None:
|
||||||
seconds = RESPONSE_TIMEOUT
|
seconds = RESPONSE_TIMEOUT
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
cb = functools.partial(callback, *args, **kwargs)
|
cb = functools.partial(callback, *args, **kwargs)
|
||||||
if repeat:
|
if repeat:
|
||||||
handle = loop.call_later(seconds, self._execute_and_reschedule,
|
handle = self.loop.call_later(seconds, self._execute_and_reschedule,
|
||||||
name, cb, seconds)
|
name, cb, seconds)
|
||||||
else:
|
else:
|
||||||
handle = loop.call_later(seconds, self._execute_and_unschedule,
|
handle = self.loop.call_later(seconds, self._execute_and_unschedule,
|
||||||
name, cb)
|
name, cb)
|
||||||
|
|
||||||
# Save that handle, so we can just cancel this scheduled event by
|
# Save that handle, so we can just cancel this scheduled event by
|
||||||
@ -778,8 +787,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
be called after the given number of seconds.
|
be called after the given number of seconds.
|
||||||
"""
|
"""
|
||||||
self._safe_cb_run(name, cb)
|
self._safe_cb_run(name, cb)
|
||||||
loop = asyncio.get_event_loop()
|
handle = self.loop.call_later(seconds, self._execute_and_reschedule,
|
||||||
handle = loop.call_later(seconds, self._execute_and_reschedule,
|
|
||||||
name, cb, seconds)
|
name, cb, seconds)
|
||||||
self.scheduled_events[name] = handle
|
self.scheduled_events[name] = handle
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user