Add better DNS resolver wrapper.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
sleekxmpp.xmlstream.xmlstream
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@@ -39,19 +38,13 @@ from sleekxmpp.xmlstream import Scheduler, tostring
|
||||
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
|
||||
from sleekxmpp.xmlstream.handler import Waiter, XMLCallback
|
||||
from sleekxmpp.xmlstream.matcher import MatchXMLMask
|
||||
from sleekxmpp.xmlstream.resolver import resolve, default_resolver
|
||||
|
||||
# In Python 2.x, file socket objects are broken. A patched socket
|
||||
# wrapper is provided for this case in filesocket.py.
|
||||
if sys.version_info < (3, 0):
|
||||
from sleekxmpp.xmlstream.filesocket import FileSocket, Socket26
|
||||
|
||||
try:
|
||||
import dns.resolver
|
||||
except ImportError:
|
||||
DNSPYTHON = False
|
||||
else:
|
||||
DNSPYTHON = True
|
||||
|
||||
|
||||
#: The time in seconds to wait before timing out waiting for response stanzas.
|
||||
RESPONSE_TIMEOUT = 30
|
||||
@@ -306,6 +299,11 @@ class XMLStream(object):
|
||||
#: A list of DNS results that have not yet been tried.
|
||||
self.dns_answers = []
|
||||
|
||||
#: The service name to check with DNS SRV records. For
|
||||
#: example, setting this to ``'xmpp-client'`` would query the
|
||||
#: ``_xmpp-client._tcp`` service.
|
||||
self.dns_service = None
|
||||
|
||||
self.add_event_handler('connected', self._handle_connected)
|
||||
self.add_event_handler('session_start', self._start_keepalive)
|
||||
self.add_event_handler('disconnected', self._end_keepalive)
|
||||
@@ -445,25 +443,10 @@ class XMLStream(object):
|
||||
self.stop.set()
|
||||
return False
|
||||
|
||||
try:
|
||||
# Look for IPv6 addresses, in addition to IPv4
|
||||
for res in Socket.getaddrinfo(self.address[0],
|
||||
int(self.address[1]),
|
||||
0,
|
||||
Socket.SOCK_STREAM):
|
||||
log.debug("Trying: %s", res[-1])
|
||||
af, sock_type, proto, canonical, sock_addr = res
|
||||
try:
|
||||
self.socket = self.socket_class(af, sock_type, proto)
|
||||
break
|
||||
except Socket.error:
|
||||
log.debug("Could not open IPv%s socket." % proto)
|
||||
except Socket.gaierror:
|
||||
log.warning("Socket could not be opened: no connectivity" + \
|
||||
" or wrong IP versions.")
|
||||
if reattempt:
|
||||
self.reconnect_delay = delay
|
||||
return False
|
||||
af = Socket.AF_INET
|
||||
if ':' in self.address[0]:
|
||||
af = Socket.AF_INET6
|
||||
self.socket = self.socket_class(af, Socket.SOCK_STREAM)
|
||||
|
||||
self.configure_socket()
|
||||
|
||||
@@ -511,7 +494,10 @@ class XMLStream(object):
|
||||
except Socket.error as serr:
|
||||
error_msg = "Could not connect to %s:%s. Socket Error #%s: %s"
|
||||
self.event('socket_error', serr, direct=True)
|
||||
log.error(error_msg, self.address[0], self.address[1],
|
||||
domain = self.address[0]
|
||||
if ':' in domain:
|
||||
domain = '[%s]' % domain
|
||||
log.error(error_msg, domain, self.address[1],
|
||||
serr.errno, serr.strerror)
|
||||
if reattempt:
|
||||
self.reconnect_delay = delay
|
||||
@@ -915,50 +901,11 @@ class XMLStream(object):
|
||||
"""
|
||||
if port is None:
|
||||
port = self.default_port
|
||||
if DNSPYTHON:
|
||||
resolver = dns.resolver.get_default_resolver()
|
||||
self.configure_dns(resolver, domain=domain, port=port)
|
||||
|
||||
resolver = default_resolver()
|
||||
self.configure_dns(resolver, domain=domain, port=port)
|
||||
|
||||
v4_answers = []
|
||||
v6_answers = []
|
||||
answers = []
|
||||
|
||||
try:
|
||||
log.debug("Querying A records for %s" % domain)
|
||||
v4_answers = resolver.query(domain, dns.rdatatype.A)
|
||||
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
|
||||
log.warning("No A records for %s", domain)
|
||||
v4_answers = [((domain, port), 0, 0)]
|
||||
except dns.exception.Timeout:
|
||||
log.warning("DNS resolution timed out " + \
|
||||
"for A record of %s", domain)
|
||||
v4_answers = [((domain, port), 0, 0)]
|
||||
else:
|
||||
for ans in v4_answers:
|
||||
log.debug("Found A record: %s", ans.address)
|
||||
answers.append(((ans.address, port), 0, 0))
|
||||
|
||||
try:
|
||||
log.debug("Querying AAAA records for %s" % domain)
|
||||
v6_answers = resolver.query(domain, dns.rdatatype.AAAA)
|
||||
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
|
||||
log.warning("No AAAA records for %s", domain)
|
||||
v6_answers = [((domain, port), 0, 0)]
|
||||
except dns.exception.Timeout:
|
||||
log.warning("DNS resolution timed out " + \
|
||||
"for AAAA record of %s", domain)
|
||||
v6_answers = [((domain, port), 0, 0)]
|
||||
else:
|
||||
for ans in v6_answers:
|
||||
log.debug("Found AAAA record: %s", ans.address)
|
||||
answers.append(((ans.address, port), 0, 0))
|
||||
|
||||
return answers
|
||||
else:
|
||||
log.warning("dnspython is not installed -- " + \
|
||||
"relying on OS A/AAAA record resolution")
|
||||
self.configure_dns(None, domain=domain, port=port)
|
||||
return [((domain, port), 0, 0)]
|
||||
return resolve(domain, port, service=self.dns_service, resolver=resolver)
|
||||
|
||||
def pick_dns_answer(self, domain, port=None):
|
||||
"""Pick a server and port from DNS answers.
|
||||
@@ -971,33 +918,16 @@ class XMLStream(object):
|
||||
"""
|
||||
if not self.dns_answers:
|
||||
self.dns_answers = self.get_dns_records(domain, port)
|
||||
addresses = {}
|
||||
intmax = 0
|
||||
topprio = 65535
|
||||
for answer in self.dns_answers:
|
||||
topprio = min(topprio, answer[1])
|
||||
for answer in self.dns_answers:
|
||||
if answer[1] == topprio:
|
||||
intmax += answer[2]
|
||||
addresses[intmax] = answer[0]
|
||||
|
||||
#python3 returns a generator for dictionary keys
|
||||
items = [x for x in addresses.keys()]
|
||||
items.sort()
|
||||
|
||||
address = (domain, port)
|
||||
picked = random.randint(0, intmax)
|
||||
for item in items:
|
||||
if picked <= item:
|
||||
address = addresses[item]
|
||||
break
|
||||
for idx, answer in enumerate(self.dns_answers):
|
||||
if self.dns_answers[0] == address:
|
||||
self.dns_answers.pop(idx)
|
||||
break
|
||||
log.debug("Trying to connect to %s:%s", *address)
|
||||
return address
|
||||
|
||||
try:
|
||||
if sys.version_info < (3, 0):
|
||||
return self.dns_answers.next()
|
||||
else:
|
||||
return next(self.dns_answers)
|
||||
except StopIteration:
|
||||
self.dns_answers = None
|
||||
return (domain, port)
|
||||
|
||||
def add_event_handler(self, name, pointer,
|
||||
threaded=False, disposable=False):
|
||||
"""Add a custom event handler that will be executed whenever
|
||||
|
||||
Reference in New Issue
Block a user