DNS is now properly checked and different answers are tried for each reconnect until exhausted
This commit is contained in:
@@ -36,6 +36,13 @@ from sleekxmpp.xmlstream.matcher import MatchXMLMask
|
||||
if sys.version_info < (3, 0):
|
||||
from sleekxmpp.xmlstream.filesocket import FileSocket, Socket26
|
||||
|
||||
try:
|
||||
import dns.resolver
|
||||
except ImportError:
|
||||
DNSPYTHON = False
|
||||
else:
|
||||
DNSPYTHON = True
|
||||
|
||||
|
||||
# The time in seconds to wait before timing out waiting for response stanzas.
|
||||
RESPONSE_TIMEOUT = 10
|
||||
@@ -51,7 +58,6 @@ SSL_SUPPORT = True
|
||||
# Maximum time to delay between connection attempts is one hour.
|
||||
RECONNECT_MAX_DELAY = 600
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -92,6 +98,7 @@ class XMLStream(object):
|
||||
events to be processed.
|
||||
filesocket -- A filesocket created from the main connection socket.
|
||||
Required for ElementTree.iterparse.
|
||||
default_port -- Default port to connect to.
|
||||
namespace_map -- Optional mapping of namespaces to namespace prefixes.
|
||||
scheduler -- A scheduler object for triggering events
|
||||
after a given period of time.
|
||||
@@ -121,6 +128,7 @@ class XMLStream(object):
|
||||
reconnect_max_delay -- Maximum time to delay between connection
|
||||
attempts. Defaults to RECONNECT_MAX_DELAY,
|
||||
which is one hour.
|
||||
dns_answers -- List of dns answers not yet used to connect.
|
||||
|
||||
Methods:
|
||||
add_event_handler -- Add a handler for a custom event.
|
||||
@@ -177,6 +185,8 @@ class XMLStream(object):
|
||||
self.state = StateMachine(('disconnected', 'connected'))
|
||||
self.state._set_state('disconnected')
|
||||
|
||||
self.default_port = int(port)
|
||||
self.default_domain = ''
|
||||
self.address = (host, int(port))
|
||||
self.filesocket = None
|
||||
self.set_socket(socket)
|
||||
@@ -219,6 +229,7 @@ class XMLStream(object):
|
||||
|
||||
self.auto_reconnect = True
|
||||
self.is_client = False
|
||||
self.dns_answers = []
|
||||
|
||||
def use_signals(self, signals=None):
|
||||
"""
|
||||
@@ -303,6 +314,10 @@ class XMLStream(object):
|
||||
"""
|
||||
if host and port:
|
||||
self.address = (host, int(port))
|
||||
try:
|
||||
Socket.inet_aton(self.address[0])
|
||||
except Socket.error:
|
||||
self.default_domain = self.address[0]
|
||||
|
||||
self.is_client = True
|
||||
# Respect previous SSL and TLS usage directives.
|
||||
@@ -322,6 +337,8 @@ class XMLStream(object):
|
||||
|
||||
def _connect(self):
|
||||
self.stop.clear()
|
||||
if self.default_domain:
|
||||
self.address = self.pick_dns_answer(self.default_domain, self.address[1])
|
||||
self.socket = self.socket_class(Socket.AF_INET, Socket.SOCK_STREAM)
|
||||
self.socket.settimeout(None)
|
||||
|
||||
@@ -639,6 +656,51 @@ class XMLStream(object):
|
||||
idx += 1
|
||||
return False
|
||||
|
||||
def get_dns_records(self, domain, port=None):
|
||||
if port is None:
|
||||
port = self.default_port
|
||||
if DNSPYTHON:
|
||||
try:
|
||||
answers = dns.resolver.query(domain, dns.rdatatype.A)
|
||||
except dns.resolver.NXDOMAIN, dns.resolver.NoAnswer:
|
||||
log.warning("No A records for %s" % domain)
|
||||
except dns.exception.Timeout:
|
||||
log.warning("DNS resolution timed out for A record of %s" % domain)
|
||||
answers = [((answer.address, port), 0, 0) for answer in answers]
|
||||
return answers
|
||||
else:
|
||||
log.warning("dnspython is not installed -- relying on OS A record resolution")
|
||||
return [((domain, port), 0, 0)]
|
||||
|
||||
def pick_dns_answer(self, domain, port=None):
|
||||
if not self.dns_answers:
|
||||
self.dns_answers = self.get_dns_records(domain, port)
|
||||
addresses = {}
|
||||
intmax = 0
|
||||
topprio = 65535
|
||||
for answer in self.dns_answers:
|
||||
topprio = min(topprio, answer[1])
|
||||
for answer in self.dns_answers:
|
||||
if answer[1] == topprio:
|
||||
intmax += answer[2]
|
||||
addresses[intmax] = answer[0]
|
||||
|
||||
#python3 returns a generator for dictionary keys
|
||||
items = [x for x in addresses.keys()]
|
||||
items.sort()
|
||||
|
||||
picked = random.randint(0, intmax)
|
||||
for item in items:
|
||||
if picked <= item:
|
||||
address = addresses[item]
|
||||
break
|
||||
for idx, answer in enumerate(self.dns_answers):
|
||||
if self.dns_answers[0] == address:
|
||||
break
|
||||
self.dns_answers.pop(idx)
|
||||
log.debug("Trying to connect to %s:%s" % address)
|
||||
return address
|
||||
|
||||
def add_event_handler(self, name, pointer,
|
||||
threaded=False, disposable=False):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user