Allow Xmlstream.ca_certs to be an iterable
Signed-off-by: Maxime “pep” Buquet <pep@bouah.net>
This commit is contained in:
parent
834ea8ed74
commit
d733c54518
@ -15,6 +15,7 @@ from typing import (
|
|||||||
Coroutine,
|
Coroutine,
|
||||||
Callable,
|
Callable,
|
||||||
Iterator,
|
Iterator,
|
||||||
|
Iterable,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
@ -33,7 +34,6 @@ import socket as Socket
|
|||||||
import ssl
|
import ssl
|
||||||
import weakref
|
import weakref
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
@ -47,6 +47,7 @@ from asyncio import (
|
|||||||
iscoroutinefunction,
|
iscoroutinefunction,
|
||||||
wait,
|
wait,
|
||||||
)
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from slixmpp.types import FilterString
|
from slixmpp.types import FilterString
|
||||||
from slixmpp.xmlstream.tostring import tostring
|
from slixmpp.xmlstream.tostring import tostring
|
||||||
@ -75,6 +76,15 @@ class NotConnectedError(Exception):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCABundle(Exception):
|
||||||
|
"""
|
||||||
|
Exception raised when the CA Bundle file hasn't been found.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: Optional[Path]):
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar('_T', str, ElementBase, StanzaBase)
|
_T = TypeVar('_T', str, ElementBase, StanzaBase)
|
||||||
|
|
||||||
|
|
||||||
@ -162,7 +172,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
#:
|
#:
|
||||||
#: On Mac OS X, certificates in the system keyring will
|
#: On Mac OS X, certificates in the system keyring will
|
||||||
#: be consulted, even if they are not in the provided file.
|
#: be consulted, even if they are not in the provided file.
|
||||||
ca_certs: Optional[Path]
|
ca_certs: Optional[Union[Path, Iterable[Path]]]
|
||||||
|
|
||||||
#: Path to a file containing a client certificate to use for
|
#: Path to a file containing a client certificate to use for
|
||||||
#: authenticating via SASL EXTERNAL. If set, there must also
|
#: authenticating via SASL EXTERNAL. If set, there must also
|
||||||
@ -760,8 +770,20 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
log.debug('Loaded cert file %s and key file %s',
|
log.debug('Loaded cert file %s and key file %s',
|
||||||
self.certfile, self.keyfile)
|
self.certfile, self.keyfile)
|
||||||
if self.ca_certs is not None:
|
if self.ca_certs is not None:
|
||||||
|
ca_cert: Optional[Path] = None
|
||||||
|
if isinstance(self.ca_certs, Path):
|
||||||
|
if self.ca_certs.is_file():
|
||||||
|
ca_cert = self.ca_certs
|
||||||
|
else:
|
||||||
|
for bundle in self.ca_certs:
|
||||||
|
if bundle.is_file():
|
||||||
|
ca_cert = bundle
|
||||||
|
break
|
||||||
|
if ca_cert is None:
|
||||||
|
raise InvalidCABundle(ca_cert)
|
||||||
|
|
||||||
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
self.ssl_context.verify_mode = ssl.CERT_REQUIRED
|
||||||
self.ssl_context.load_verify_locations(cafile=self.ca_certs)
|
self.ssl_context.load_verify_locations(cafile=ca_cert)
|
||||||
|
|
||||||
return self.ssl_context
|
return self.ssl_context
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user