Add better DNS resolver wrapper.

This commit is contained in:
Lance Stout
2012-03-29 15:11:24 -07:00
parent aad2eb31fc
commit c1d36cad46
3 changed files with 324 additions and 128 deletions

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
sleekxmpp.xmlstream.xmlstream
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -39,19 +38,13 @@ from sleekxmpp.xmlstream import Scheduler, tostring
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
from sleekxmpp.xmlstream.handler import Waiter, XMLCallback
from sleekxmpp.xmlstream.matcher import MatchXMLMask
from sleekxmpp.xmlstream.resolver import resolve, default_resolver
# In Python 2.x, file socket objects are broken. A patched socket
# wrapper is provided for this case in filesocket.py.
if sys.version_info < (3, 0):
from sleekxmpp.xmlstream.filesocket import FileSocket, Socket26
try:
import dns.resolver
except ImportError:
DNSPYTHON = False
else:
DNSPYTHON = True
#: The time in seconds to wait before timing out waiting for response stanzas.
RESPONSE_TIMEOUT = 30
@@ -306,6 +299,11 @@ class XMLStream(object):
#: A list of DNS results that have not yet been tried.
self.dns_answers = []
#: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the
#: ``_xmpp-client._tcp`` service.
self.dns_service = None
self.add_event_handler('connected', self._handle_connected)
self.add_event_handler('session_start', self._start_keepalive)
self.add_event_handler('disconnected', self._end_keepalive)
@@ -445,25 +443,10 @@ class XMLStream(object):
self.stop.set()
return False
try:
# Look for IPv6 addresses, in addition to IPv4
for res in Socket.getaddrinfo(self.address[0],
int(self.address[1]),
0,
Socket.SOCK_STREAM):
log.debug("Trying: %s", res[-1])
af, sock_type, proto, canonical, sock_addr = res
try:
self.socket = self.socket_class(af, sock_type, proto)
break
except Socket.error:
log.debug("Could not open IPv%s socket." % proto)
except Socket.gaierror:
log.warning("Socket could not be opened: no connectivity" + \
" or wrong IP versions.")
if reattempt:
self.reconnect_delay = delay
return False
af = Socket.AF_INET
if ':' in self.address[0]:
af = Socket.AF_INET6
self.socket = self.socket_class(af, Socket.SOCK_STREAM)
self.configure_socket()
@@ -511,7 +494,10 @@ class XMLStream(object):
except Socket.error as serr:
error_msg = "Could not connect to %s:%s. Socket Error #%s: %s"
self.event('socket_error', serr, direct=True)
log.error(error_msg, self.address[0], self.address[1],
domain = self.address[0]
if ':' in domain:
domain = '[%s]' % domain
log.error(error_msg, domain, self.address[1],
serr.errno, serr.strerror)
if reattempt:
self.reconnect_delay = delay
@@ -915,50 +901,11 @@ class XMLStream(object):
"""
if port is None:
port = self.default_port
if DNSPYTHON:
resolver = dns.resolver.get_default_resolver()
self.configure_dns(resolver, domain=domain, port=port)
resolver = default_resolver()
self.configure_dns(resolver, domain=domain, port=port)
v4_answers = []
v6_answers = []
answers = []
try:
log.debug("Querying A records for %s" % domain)
v4_answers = resolver.query(domain, dns.rdatatype.A)
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
log.warning("No A records for %s", domain)
v4_answers = [((domain, port), 0, 0)]
except dns.exception.Timeout:
log.warning("DNS resolution timed out " + \
"for A record of %s", domain)
v4_answers = [((domain, port), 0, 0)]
else:
for ans in v4_answers:
log.debug("Found A record: %s", ans.address)
answers.append(((ans.address, port), 0, 0))
try:
log.debug("Querying AAAA records for %s" % domain)
v6_answers = resolver.query(domain, dns.rdatatype.AAAA)
except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
log.warning("No AAAA records for %s", domain)
v6_answers = [((domain, port), 0, 0)]
except dns.exception.Timeout:
log.warning("DNS resolution timed out " + \
"for AAAA record of %s", domain)
v6_answers = [((domain, port), 0, 0)]
else:
for ans in v6_answers:
log.debug("Found AAAA record: %s", ans.address)
answers.append(((ans.address, port), 0, 0))
return answers
else:
log.warning("dnspython is not installed -- " + \
"relying on OS A/AAAA record resolution")
self.configure_dns(None, domain=domain, port=port)
return [((domain, port), 0, 0)]
return resolve(domain, port, service=self.dns_service, resolver=resolver)
def pick_dns_answer(self, domain, port=None):
"""Pick a server and port from DNS answers.
@@ -971,33 +918,16 @@ class XMLStream(object):
"""
if not self.dns_answers:
self.dns_answers = self.get_dns_records(domain, port)
addresses = {}
intmax = 0
topprio = 65535
for answer in self.dns_answers:
topprio = min(topprio, answer[1])
for answer in self.dns_answers:
if answer[1] == topprio:
intmax += answer[2]
addresses[intmax] = answer[0]
#python3 returns a generator for dictionary keys
items = [x for x in addresses.keys()]
items.sort()
address = (domain, port)
picked = random.randint(0, intmax)
for item in items:
if picked <= item:
address = addresses[item]
break
for idx, answer in enumerate(self.dns_answers):
if self.dns_answers[0] == address:
self.dns_answers.pop(idx)
break
log.debug("Trying to connect to %s:%s", *address)
return address
try:
if sys.version_info < (3, 0):
return self.dns_answers.next()
else:
return next(self.dns_answers)
except StopIteration:
self.dns_answers = None
return (domain, port)
def add_event_handler(self, name, pointer,
threaded=False, disposable=False):
"""Add a custom event handler that will be executed whenever