Allow Xmlstream.ca_certs to be an iterable

Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
Maxime “pep” Buquet 2021-12-28 19:50:20 +01:00
parent 834ea8ed74
commit d733c54518
No known key found for this signature in database
GPG Key ID: DEDA74AEECA9D0F2

View File

@ -15,6 +15,7 @@ from typing import (
Coroutine, Coroutine,
Callable, Callable,
Iterator, Iterator,
Iterable,
List, List,
Optional, Optional,
Set, Set,
@ -33,7 +34,6 @@ import socket as Socket
import ssl import ssl
import weakref import weakref
import uuid import uuid
from pathlib import Path
from contextlib import contextmanager from contextlib import contextmanager
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@ -47,6 +47,7 @@ from asyncio import (
iscoroutinefunction, iscoroutinefunction,
wait, wait,
) )
from pathlib import Path
from slixmpp.types import FilterString from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring from slixmpp.xmlstream.tostring import tostring
@ -75,6 +76,15 @@ class NotConnectedError(Exception):
""" """
class InvalidCABundle(Exception):
"""
Exception raised when the CA Bundle file hasn't been found.
"""
def __init__(self, path: Optional[Path]):
self.path = path
_T = TypeVar('_T', str, ElementBase, StanzaBase) _T = TypeVar('_T', str, ElementBase, StanzaBase)
@ -162,7 +172,7 @@ class XMLStream(asyncio.BaseProtocol):
#: #:
#: On Mac OS X, certificates in the system keyring will #: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file. #: be consulted, even if they are not in the provided file.
ca_certs: Optional[Path] ca_certs: Optional[Union[Path, Iterable[Path]]]
#: Path to a file containing a client certificate to use for #: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also #: authenticating via SASL EXTERNAL. If set, there must also
@ -760,8 +770,20 @@ class XMLStream(asyncio.BaseProtocol):
log.debug('Loaded cert file %s and key file %s', log.debug('Loaded cert file %s and key file %s',
self.certfile, self.keyfile) self.certfile, self.keyfile)
if self.ca_certs is not None: if self.ca_certs is not None:
ca_cert: Optional[Path] = None
if isinstance(self.ca_certs, Path):
if self.ca_certs.is_file():
ca_cert = self.ca_certs
else:
for bundle in self.ca_certs:
if bundle.is_file():
ca_cert = bundle
break
if ca_cert is None:
raise InvalidCABundle(ca_cert)
self.ssl_context.verify_mode = ssl.CERT_REQUIRED self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.load_verify_locations(cafile=self.ca_certs) self.ssl_context.load_verify_locations(cafile=ca_cert)
return self.ssl_context return self.ssl_context