Merge branch 'ca-cert-list' into 'master'

Allow Xmlstream.ca_certs to be an iterable

See merge request poezio/slixmpp!177
This commit is contained in:
Link Mauve 2022-01-03 11:04:36 +01:00
commit e56930e0a1

View File

@ -15,6 +15,7 @@ from typing import (
Coroutine, Coroutine,
Callable, Callable,
Iterator, Iterator,
Iterable,
List, List,
Optional, Optional,
Set, Set,
@ -33,7 +34,6 @@ import socket as Socket
import ssl import ssl
import weakref import weakref
import uuid import uuid
from pathlib import Path
from contextlib import contextmanager from contextlib import contextmanager
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@ -47,6 +47,7 @@ from asyncio import (
iscoroutinefunction, iscoroutinefunction,
wait, wait,
) )
from pathlib import Path
from slixmpp.types import FilterString from slixmpp.types import FilterString
from slixmpp.xmlstream.tostring import tostring from slixmpp.xmlstream.tostring import tostring
@ -75,6 +76,15 @@ class NotConnectedError(Exception):
""" """
class InvalidCABundle(Exception):
"""
Exception raised when the CA Bundle file hasn't been found.
"""
def __init__(self, path: Optional[Path]):
self.path = path
_T = TypeVar('_T', str, ElementBase, StanzaBase) _T = TypeVar('_T', str, ElementBase, StanzaBase)
@ -162,7 +172,7 @@ class XMLStream(asyncio.BaseProtocol):
#: #:
#: On Mac OS X, certificates in the system keyring will #: On Mac OS X, certificates in the system keyring will
#: be consulted, even if they are not in the provided file. #: be consulted, even if they are not in the provided file.
ca_certs: Optional[Path] ca_certs: Optional[Union[Path, Iterable[Path]]]
#: Path to a file containing a client certificate to use for #: Path to a file containing a client certificate to use for
#: authenticating via SASL EXTERNAL. If set, there must also #: authenticating via SASL EXTERNAL. If set, there must also
@ -760,8 +770,20 @@ class XMLStream(asyncio.BaseProtocol):
log.debug('Loaded cert file %s and key file %s', log.debug('Loaded cert file %s and key file %s',
self.certfile, self.keyfile) self.certfile, self.keyfile)
if self.ca_certs is not None: if self.ca_certs is not None:
ca_cert: Optional[Path] = None
if isinstance(self.ca_certs, Path):
if self.ca_certs.is_file():
ca_cert = self.ca_certs
else:
for bundle in self.ca_certs:
if bundle.is_file():
ca_cert = bundle
break
if ca_cert is None:
raise InvalidCABundle(ca_cert)
self.ssl_context.verify_mode = ssl.CERT_REQUIRED self.ssl_context.verify_mode = ssl.CERT_REQUIRED
self.ssl_context.load_verify_locations(cafile=self.ca_certs) self.ssl_context.load_verify_locations(cafile=ca_cert)
return self.ssl_context return self.ssl_context