xmlstream: make dns_answers private

This commit is contained in:
mathieui 2021-02-04 18:42:01 +01:00
parent d3063a0368
commit ccbba89cbd

View File

@ -16,10 +16,12 @@ from typing import (
Any, Any,
Callable, Callable,
Iterable, Iterable,
Iterator,
List, List,
Optional, Optional,
Set, Set,
Union, Union,
Tuple,
) )
import functools import functools
@ -212,7 +214,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt = None self._current_connection_attempt = None
#: A list of DNS results that have not yet been tried. #: A list of DNS results that have not yet been tried.
self.dns_answers = None self._dns_answers: Optional[Iterator[Tuple[str, str, int]]] = None
#: The service name to check with DNS SRV records. For #: The service name to check with DNS SRV records. For
#: example, setting this to ``'xmpp-client'`` would query the #: example, setting this to ``'xmpp-client'`` would query the
@ -315,7 +317,7 @@ class XMLStream(asyncio.BaseProtocol):
self.event('reconnect_delay', self._connect_loop_wait) self.event('reconnect_delay', self._connect_loop_wait)
await asyncio.sleep(self._connect_loop_wait, loop=self.loop) await asyncio.sleep(self._connect_loop_wait, loop=self.loop)
record = await self.pick_dns_answer(self.default_domain) record = await self._pick_dns_answer(self.default_domain)
if record is not None: if record is not None:
host, address, dns_port = record host, address, dns_port = record
port = dns_port if dns_port else self.address[1] port = dns_port if dns_port else self.address[1]
@ -324,7 +326,7 @@ class XMLStream(asyncio.BaseProtocol):
else: else:
# No DNS records left, stop iterating # No DNS records left, stop iterating
# and try (host, port) as a last resort # and try (host, port) as a last resort
self.dns_answers = None self._dns_answers = None
if self.use_ssl: if self.use_ssl:
ssl_context = self.get_ssl_context() ssl_context = self.get_ssl_context()
@ -392,7 +394,7 @@ class XMLStream(asyncio.BaseProtocol):
self._current_connection_attempt = None self._current_connection_attempt = None
self.init_parser() self.init_parser()
self.send_raw(self.stream_header) self.send_raw(self.stream_header)
self.dns_answers = None self._dns_answers = None
def data_received(self, data): def data_received(self, data):
"""Called when incoming data is received on the socket. """Called when incoming data is received on the socket.
@ -777,7 +779,7 @@ class XMLStream(asyncio.BaseProtocol):
idx += 1 idx += 1
return False return False
async def get_dns_records(self, domain, port=None): async def get_dns_records(self, domain: str, port: Optional[int] = None) -> List[Tuple[str, str, int]]:
"""Get the DNS records for a domain. """Get the DNS records for a domain.
:param domain: The domain in question. :param domain: The domain in question.
@ -797,7 +799,7 @@ class XMLStream(asyncio.BaseProtocol):
loop=self.loop) loop=self.loop)
return result return result
async def pick_dns_answer(self, domain, port=None): async def _pick_dns_answer(self, domain: str, port: Optional[int] = None) -> Optional[Tuple[str, str, int]]:
"""Pick a server and port from DNS answers. """Pick a server and port from DNS answers.
Gets DNS answers if none available. Gets DNS answers if none available.
@ -806,12 +808,12 @@ class XMLStream(asyncio.BaseProtocol):
:param domain: The domain in question. :param domain: The domain in question.
:param port: If the results don't include a port, use this one. :param port: If the results don't include a port, use this one.
""" """
if self.dns_answers is None: if self._dns_answers is None:
dns_records = await self.get_dns_records(domain, port) dns_records = await self.get_dns_records(domain, port)
self.dns_answers = iter(dns_records) self._dns_answers = iter(dns_records)
try: try:
return next(self.dns_answers) return next(self._dns_answers)
except StopIteration: except StopIteration:
return return