slixmpp.util: type things
Fix a bug in the SASL implementation as well. (some special chars would make things crash instead of being escaped)
This commit is contained in:
parent
b1411d8ed7
commit
ef06429941
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
# Slixmpp: The Slick XMPP Library
|
# Slixmpp: The Slick XMPP Library
|
||||||
# Copyright (C) 2018 Emmanuel Gil Peyrot
|
# Copyright (C) 2018 Emmanuel Gil Peyrot
|
||||||
# This file is part of Slixmpp.
|
# This file is part of Slixmpp.
|
||||||
@ -6,8 +5,11 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Callable, Optional, Any
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Cache:
|
class Cache:
|
||||||
def retrieve(self, key):
|
def retrieve(self, key):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -16,7 +18,8 @@ class Cache:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def remove(self, key):
|
def remove(self, key):
|
||||||
raise NotImplemented
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PerJidCache:
|
class PerJidCache:
|
||||||
def retrieve_by_jid(self, jid, key):
|
def retrieve_by_jid(self, jid, key):
|
||||||
@ -28,6 +31,7 @@ class PerJidCache:
|
|||||||
def remove_by_jid(self, jid, key):
|
def remove_by_jid(self, jid, key):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class MemoryCache(Cache):
|
class MemoryCache(Cache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
@ -44,6 +48,7 @@ class MemoryCache(Cache):
|
|||||||
del self.cache[key]
|
del self.cache[key]
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class MemoryPerJidCache(PerJidCache):
|
class MemoryPerJidCache(PerJidCache):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
@ -65,14 +70,15 @@ class MemoryPerJidCache(PerJidCache):
|
|||||||
del cache[key]
|
del cache[key]
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class FileSystemStorage:
|
class FileSystemStorage:
|
||||||
def __init__(self, encode, decode, binary):
|
def __init__(self, encode: Optional[Callable[[Any], str]], decode: Optional[Callable[[str], Any]], binary: bool):
|
||||||
self.encode = encode if encode is not None else lambda x: x
|
self.encode = encode if encode is not None else lambda x: x
|
||||||
self.decode = decode if decode is not None else lambda x: x
|
self.decode = decode if decode is not None else lambda x: x
|
||||||
self.read = 'rb' if binary else 'r'
|
self.read = 'rb' if binary else 'r'
|
||||||
self.write = 'wb' if binary else 'w'
|
self.write = 'wb' if binary else 'w'
|
||||||
|
|
||||||
def _retrieve(self, directory, key):
|
def _retrieve(self, directory: str, key: str):
|
||||||
filename = os.path.join(directory, key.replace('/', '_'))
|
filename = os.path.join(directory, key.replace('/', '_'))
|
||||||
try:
|
try:
|
||||||
with open(filename, self.read) as cache_file:
|
with open(filename, self.read) as cache_file:
|
||||||
@ -86,7 +92,7 @@ class FileSystemStorage:
|
|||||||
log.debug('Removing %s entry', key)
|
log.debug('Removing %s entry', key)
|
||||||
self._remove(directory, key)
|
self._remove(directory, key)
|
||||||
|
|
||||||
def _store(self, directory, key, value):
|
def _store(self, directory: str, key: str, value):
|
||||||
filename = os.path.join(directory, key.replace('/', '_'))
|
filename = os.path.join(directory, key.replace('/', '_'))
|
||||||
try:
|
try:
|
||||||
os.makedirs(directory, exist_ok=True)
|
os.makedirs(directory, exist_ok=True)
|
||||||
@ -99,7 +105,7 @@ class FileSystemStorage:
|
|||||||
except Exception:
|
except Exception:
|
||||||
log.debug('Failed to encode %s to cache:', key, exc_info=True)
|
log.debug('Failed to encode %s to cache:', key, exc_info=True)
|
||||||
|
|
||||||
def _remove(self, directory, key):
|
def _remove(self, directory: str, key: str):
|
||||||
filename = os.path.join(directory, key.replace('/', '_'))
|
filename = os.path.join(directory, key.replace('/', '_'))
|
||||||
try:
|
try:
|
||||||
os.remove(filename)
|
os.remove(filename)
|
||||||
@ -108,8 +114,9 @@ class FileSystemStorage:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class FileSystemCache(Cache, FileSystemStorage):
|
class FileSystemCache(Cache, FileSystemStorage):
|
||||||
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
|
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
|
||||||
FileSystemStorage.__init__(self, encode, decode, binary)
|
FileSystemStorage.__init__(self, encode, decode, binary)
|
||||||
self.base_dir = os.path.join(directory, cache_type)
|
self.base_dir = os.path.join(directory, cache_type)
|
||||||
|
|
||||||
@ -122,8 +129,9 @@ class FileSystemCache(Cache, FileSystemStorage):
|
|||||||
def remove(self, key):
|
def remove(self, key):
|
||||||
return self._remove(self.base_dir, key)
|
return self._remove(self.base_dir, key)
|
||||||
|
|
||||||
|
|
||||||
class FileSystemPerJidCache(PerJidCache, FileSystemStorage):
|
class FileSystemPerJidCache(PerJidCache, FileSystemStorage):
|
||||||
def __init__(self, directory, cache_type, *, encode=None, decode=None, binary=False):
|
def __init__(self, directory: str, cache_type: str, *, encode=None, decode=None, binary=False):
|
||||||
FileSystemStorage.__init__(self, encode, decode, binary)
|
FileSystemStorage.__init__(self, encode, decode, binary)
|
||||||
self.base_dir = os.path.join(directory, cache_type)
|
self.base_dir = os.path.join(directory, cache_type)
|
||||||
|
|
||||||
|
@ -2,15 +2,19 @@ import builtins
|
|||||||
import sys
|
import sys
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
from typing import Optional, Union, Callable, List
|
||||||
|
|
||||||
def unicode(text):
|
bytes_ = builtins.bytes # alias the stdlib type but ew
|
||||||
|
|
||||||
|
|
||||||
|
def unicode(text: Union[bytes_, str]) -> str:
|
||||||
if not isinstance(text, str):
|
if not isinstance(text, str):
|
||||||
return text.decode('utf-8')
|
return text.decode('utf-8')
|
||||||
else:
|
else:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def bytes(text):
|
def bytes(text: Optional[Union[str, bytes_]]) -> bytes_:
|
||||||
"""
|
"""
|
||||||
Convert Unicode text to UTF-8 encoded bytes.
|
Convert Unicode text to UTF-8 encoded bytes.
|
||||||
|
|
||||||
@ -34,7 +38,7 @@ def bytes(text):
|
|||||||
return builtins.bytes(text, encoding='utf-8')
|
return builtins.bytes(text, encoding='utf-8')
|
||||||
|
|
||||||
|
|
||||||
def quote(text):
|
def quote(text: Union[str, bytes_]) -> bytes_:
|
||||||
"""
|
"""
|
||||||
Enclose in quotes and escape internal slashes and double quotes.
|
Enclose in quotes and escape internal slashes and double quotes.
|
||||||
|
|
||||||
@ -44,7 +48,7 @@ def quote(text):
|
|||||||
return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
|
return b'"' + text.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
|
||||||
|
|
||||||
|
|
||||||
def num_to_bytes(num):
|
def num_to_bytes(num: int) -> bytes_:
|
||||||
"""
|
"""
|
||||||
Convert an integer into a four byte sequence.
|
Convert an integer into a four byte sequence.
|
||||||
|
|
||||||
@ -58,21 +62,21 @@ def num_to_bytes(num):
|
|||||||
return bval
|
return bval
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_num(bval):
|
def bytes_to_num(bval: bytes_) -> int:
|
||||||
"""
|
"""
|
||||||
Convert a four byte sequence to an integer.
|
Convert a four byte sequence to an integer.
|
||||||
|
|
||||||
:param bytes bval: A four byte sequence to turn into an integer.
|
:param bytes bval: A four byte sequence to turn into an integer.
|
||||||
"""
|
"""
|
||||||
num = 0
|
num = 0
|
||||||
num += ord(bval[0] << 24)
|
num += (bval[0] << 24)
|
||||||
num += ord(bval[1] << 16)
|
num += (bval[1] << 16)
|
||||||
num += ord(bval[2] << 8)
|
num += (bval[2] << 8)
|
||||||
num += ord(bval[3])
|
num += (bval[3])
|
||||||
return num
|
return num
|
||||||
|
|
||||||
|
|
||||||
def XOR(x, y):
|
def XOR(x: bytes_, y: bytes_) -> bytes_:
|
||||||
"""
|
"""
|
||||||
Return the results of an XOR operation on two equal length byte strings.
|
Return the results of an XOR operation on two equal length byte strings.
|
||||||
|
|
||||||
@ -85,7 +89,7 @@ def XOR(x, y):
|
|||||||
return builtins.bytes([a ^ b for a, b in zip(x, y)])
|
return builtins.bytes([a ^ b for a, b in zip(x, y)])
|
||||||
|
|
||||||
|
|
||||||
def hash(name):
|
def hash(name: str) -> Optional[Callable]:
|
||||||
"""
|
"""
|
||||||
Return a hash function implementing the given algorithm.
|
Return a hash function implementing the given algorithm.
|
||||||
|
|
||||||
@ -102,7 +106,7 @@ def hash(name):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def hashes():
|
def hashes() -> List[str]:
|
||||||
"""
|
"""
|
||||||
Return a list of available hashing algorithms.
|
Return a list of available hashing algorithms.
|
||||||
|
|
||||||
@ -115,28 +119,3 @@ def hashes():
|
|||||||
t += ['MD2']
|
t += ['MD2']
|
||||||
hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')]
|
hashes = ['SHA-' + h[3:] for h in dir(hashlib) if h.startswith('sha')]
|
||||||
return t + hashes
|
return t + hashes
|
||||||
|
|
||||||
|
|
||||||
def setdefaultencoding(encoding):
|
|
||||||
"""
|
|
||||||
Set the current default string encoding used by the Unicode implementation.
|
|
||||||
|
|
||||||
Actually calls sys.setdefaultencoding under the hood - see the docs for that
|
|
||||||
for more details. This method exists only as a way to call find/call it
|
|
||||||
even after it has been 'deleted' when the site module is executed.
|
|
||||||
|
|
||||||
:param string encoding: An encoding name, compatible with sys.setdefaultencoding
|
|
||||||
"""
|
|
||||||
func = getattr(sys, 'setdefaultencoding', None)
|
|
||||||
if func is None:
|
|
||||||
import gc
|
|
||||||
import types
|
|
||||||
for obj in gc.get_objects():
|
|
||||||
if (isinstance(obj, types.BuiltinFunctionType)
|
|
||||||
and obj.__name__ == 'setdefaultencoding'):
|
|
||||||
func = obj
|
|
||||||
break
|
|
||||||
if func is None:
|
|
||||||
raise RuntimeError("Could not find setdefaultencoding")
|
|
||||||
sys.setdefaultencoding = func
|
|
||||||
return func(encoding)
|
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
# slixmpp.util.sasl.client
|
# slixmpp.util.sasl.client
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# This module was originally based on Dave Cridland's Suelta library.
|
# This module was originally based on Dave Cridland's Suelta library.
|
||||||
@ -6,9 +5,11 @@
|
|||||||
# :copryight: (c) 2004-2013 David Alan Cridland
|
# :copryight: (c) 2004-2013 David Alan Cridland
|
||||||
# :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout
|
# :copyright: (c) 2013 Nathanael C. Fritz, Lance J.T. Stout
|
||||||
# :license: MIT, see LICENSE for more details
|
# :license: MIT, see LICENSE for more details
|
||||||
|
from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import stringprep
|
import stringprep
|
||||||
|
|
||||||
|
from typing import Iterable, Set, Callable, Dict, Any, Optional, Type
|
||||||
from slixmpp.util import hashes, bytes, stringprep_profiles
|
from slixmpp.util import hashes, bytes, stringprep_profiles
|
||||||
|
|
||||||
|
|
||||||
@ -16,11 +17,11 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
#: Global registry mapping mechanism names to implementation classes.
|
#: Global registry mapping mechanism names to implementation classes.
|
||||||
MECHANISMS = {}
|
MECHANISMS: Dict[str, Type[Mech]] = {}
|
||||||
|
|
||||||
|
|
||||||
#: Global registry mapping mechanism names to security scores.
|
#: Global registry mapping mechanism names to security scores.
|
||||||
MECH_SEC_SCORES = {}
|
MECH_SEC_SCORES: Dict[str, int] = {}
|
||||||
|
|
||||||
|
|
||||||
#: The SASLprep profile of stringprep used to validate simple username
|
#: The SASLprep profile of stringprep used to validate simple username
|
||||||
@ -45,9 +46,10 @@ saslprep = stringprep_profiles.create(
|
|||||||
unassigned=[stringprep.in_table_a1])
|
unassigned=[stringprep.in_table_a1])
|
||||||
|
|
||||||
|
|
||||||
def sasl_mech(score):
|
def sasl_mech(score: int):
|
||||||
sec_score = score
|
sec_score = score
|
||||||
def register(mech):
|
|
||||||
|
def register(mech: Type[Mech]):
|
||||||
n = 0
|
n = 0
|
||||||
mech.score = sec_score
|
mech.score = sec_score
|
||||||
if mech.use_hashes:
|
if mech.use_hashes:
|
||||||
@ -99,9 +101,9 @@ class Mech(object):
|
|||||||
score = -1
|
score = -1
|
||||||
use_hashes = False
|
use_hashes = False
|
||||||
channel_binding = False
|
channel_binding = False
|
||||||
required_credentials = set()
|
required_credentials: Set[str] = set()
|
||||||
optional_credentials = set()
|
optional_credentials: Set[str] = set()
|
||||||
security = set()
|
security: Set[str] = set()
|
||||||
|
|
||||||
def __init__(self, name, credentials, security_settings):
|
def __init__(self, name, credentials, security_settings):
|
||||||
self.credentials = credentials
|
self.credentials = credentials
|
||||||
@ -118,7 +120,14 @@ class Mech(object):
|
|||||||
return b''
|
return b''
|
||||||
|
|
||||||
|
|
||||||
def choose(mech_list, credentials, security_settings, limit=None, min_mech=None):
|
CredentialsCallback = Callable[[Iterable[str], Iterable[str]], Dict[str, Any]]
|
||||||
|
SecurityCallback = Callable[[Iterable[str]], Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
def choose(mech_list: Iterable[Type[Mech]], credentials: CredentialsCallback,
|
||||||
|
security_settings: SecurityCallback,
|
||||||
|
limit: Optional[Iterable[Type[Mech]]] = None,
|
||||||
|
min_mech: Optional[str] = None) -> Mech:
|
||||||
available_mechs = set(MECHANISMS.keys())
|
available_mechs = set(MECHANISMS.keys())
|
||||||
if limit is None:
|
if limit is None:
|
||||||
limit = set(mech_list)
|
limit = set(mech_list)
|
||||||
@ -130,7 +139,10 @@ def choose(mech_list, credentials, security_settings, limit=None, min_mech=None)
|
|||||||
mech_list = mech_list.intersection(limit)
|
mech_list = mech_list.intersection(limit)
|
||||||
available_mechs = available_mechs.intersection(mech_list)
|
available_mechs = available_mechs.intersection(mech_list)
|
||||||
|
|
||||||
best_score = MECH_SEC_SCORES.get(min_mech, -1)
|
if min_mech is None:
|
||||||
|
best_score = -1
|
||||||
|
else:
|
||||||
|
best_score = MECH_SEC_SCORES.get(min_mech, -1)
|
||||||
best_mech = None
|
best_mech = None
|
||||||
for name in available_mechs:
|
for name in available_mechs:
|
||||||
if name in MECH_SEC_SCORES:
|
if name in MECH_SEC_SCORES:
|
||||||
|
@ -11,6 +11,9 @@ import hmac
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
from base64 import b64encode, b64decode
|
from base64 import b64encode, b64decode
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
bytes_ = bytes
|
||||||
|
|
||||||
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
|
from slixmpp.util import bytes, hash, XOR, quote, num_to_bytes
|
||||||
from slixmpp.util.sasl.client import sasl_mech, Mech, \
|
from slixmpp.util.sasl.client import sasl_mech, Mech, \
|
||||||
@ -63,7 +66,7 @@ class PLAIN(Mech):
|
|||||||
if not self.security_settings['encrypted_plain']:
|
if not self.security_settings['encrypted_plain']:
|
||||||
raise SASLCancelled('PLAIN with encryption')
|
raise SASLCancelled('PLAIN with encryption')
|
||||||
|
|
||||||
def process(self, challenge=b''):
|
def process(self, challenge: bytes_ = b'') -> bytes_:
|
||||||
authzid = self.credentials['authzid']
|
authzid = self.credentials['authzid']
|
||||||
authcid = self.credentials['username']
|
authcid = self.credentials['username']
|
||||||
password = self.credentials['password']
|
password = self.credentials['password']
|
||||||
@ -148,7 +151,7 @@ class CRAM(Mech):
|
|||||||
required_credentials = {'username', 'password'}
|
required_credentials = {'username', 'password'}
|
||||||
security = {'encrypted', 'unencrypted_cram'}
|
security = {'encrypted', 'unencrypted_cram'}
|
||||||
|
|
||||||
def setup(self, name):
|
def setup(self, name: str):
|
||||||
self.hash_name = name[5:]
|
self.hash_name = name[5:]
|
||||||
self.hash = hash(self.hash_name)
|
self.hash = hash(self.hash_name)
|
||||||
if self.hash is None:
|
if self.hash is None:
|
||||||
@ -157,14 +160,14 @@ class CRAM(Mech):
|
|||||||
if not self.security_settings['unencrypted_cram']:
|
if not self.security_settings['unencrypted_cram']:
|
||||||
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
|
raise SASLCancelled('Unecrypted CRAM-%s' % self.hash_name)
|
||||||
|
|
||||||
def process(self, challenge=b''):
|
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
|
||||||
if not challenge:
|
if not challenge:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
username = self.credentials['username']
|
username = self.credentials['username']
|
||||||
password = self.credentials['password']
|
password = self.credentials['password']
|
||||||
|
|
||||||
mac = hmac.HMAC(key=password, digestmod=self.hash)
|
mac = hmac.HMAC(key=password, digestmod=self.hash) # type: ignore
|
||||||
mac.update(challenge)
|
mac.update(challenge)
|
||||||
|
|
||||||
return username + b' ' + bytes(mac.hexdigest())
|
return username + b' ' + bytes(mac.hexdigest())
|
||||||
@ -201,43 +204,42 @@ class SCRAM(Mech):
|
|||||||
def HMAC(self, key, msg):
|
def HMAC(self, key, msg):
|
||||||
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
|
return hmac.HMAC(key=key, msg=msg, digestmod=self.hash).digest()
|
||||||
|
|
||||||
def Hi(self, text, salt, iterations):
|
def Hi(self, text: str, salt: bytes_, iterations: int):
|
||||||
text = bytes(text)
|
text_enc = bytes(text)
|
||||||
ui1 = self.HMAC(text, salt + b'\0\0\0\01')
|
ui1 = self.HMAC(text_enc, salt + b'\0\0\0\01')
|
||||||
ui = ui1
|
ui = ui1
|
||||||
for i in range(iterations - 1):
|
for i in range(iterations - 1):
|
||||||
ui1 = self.HMAC(text, ui1)
|
ui1 = self.HMAC(text_enc, ui1)
|
||||||
ui = XOR(ui, ui1)
|
ui = XOR(ui, ui1)
|
||||||
return ui
|
return ui
|
||||||
|
|
||||||
def H(self, text):
|
def H(self, text: str) -> bytes_:
|
||||||
return self.hash(text).digest()
|
return self.hash(text).digest()
|
||||||
|
|
||||||
def saslname(self, value):
|
def saslname(self, value_b: bytes_) -> bytes_:
|
||||||
value = value.decode("utf-8")
|
value = value_b.decode("utf-8")
|
||||||
escaped = []
|
escaped: List[str] = []
|
||||||
for char in value:
|
for char in value:
|
||||||
if char == ',':
|
if char == ',':
|
||||||
escaped += b'=2C'
|
escaped.append('=2C')
|
||||||
elif char == '=':
|
elif char == '=':
|
||||||
escaped += b'=3D'
|
escaped.append('=3D')
|
||||||
else:
|
else:
|
||||||
escaped += char
|
escaped.append(char)
|
||||||
return "".join(escaped).encode("utf-8")
|
return "".join(escaped).encode("utf-8")
|
||||||
|
|
||||||
def parse(self, challenge):
|
def parse(self, challenge: bytes_) -> Dict[bytes_, bytes_]:
|
||||||
items = {}
|
items = {}
|
||||||
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
|
for key, value in [item.split(b'=', 1) for item in challenge.split(b',')]:
|
||||||
items[key] = value
|
items[key] = value
|
||||||
return items
|
return items
|
||||||
|
|
||||||
def process(self, challenge=b''):
|
def process(self, challenge: bytes_ = b''):
|
||||||
steps = [self.process_1, self.process_2, self.process_3]
|
steps = [self.process_1, self.process_2, self.process_3]
|
||||||
return steps[self.step](challenge)
|
return steps[self.step](challenge)
|
||||||
|
|
||||||
def process_1(self, challenge):
|
def process_1(self, challenge: bytes_) -> bytes_:
|
||||||
self.step = 1
|
self.step = 1
|
||||||
data = {}
|
|
||||||
|
|
||||||
self.cnonce = bytes(('%s' % random.random())[2:])
|
self.cnonce = bytes(('%s' % random.random())[2:])
|
||||||
|
|
||||||
@ -263,7 +265,7 @@ class SCRAM(Mech):
|
|||||||
|
|
||||||
return self.client_first_message
|
return self.client_first_message
|
||||||
|
|
||||||
def process_2(self, challenge):
|
def process_2(self, challenge: bytes_) -> bytes_:
|
||||||
self.step = 2
|
self.step = 2
|
||||||
|
|
||||||
data = self.parse(challenge)
|
data = self.parse(challenge)
|
||||||
@ -304,7 +306,7 @@ class SCRAM(Mech):
|
|||||||
|
|
||||||
return client_final_message
|
return client_final_message
|
||||||
|
|
||||||
def process_3(self, challenge):
|
def process_3(self, challenge: bytes_) -> bytes_:
|
||||||
data = self.parse(challenge)
|
data = self.parse(challenge)
|
||||||
verifier = data.get(b'v', None)
|
verifier = data.get(b'v', None)
|
||||||
error = data.get(b'e', 'Unknown error')
|
error = data.get(b'e', 'Unknown error')
|
||||||
@ -345,17 +347,16 @@ class DIGEST(Mech):
|
|||||||
self.cnonce = b''
|
self.cnonce = b''
|
||||||
self.nonce_count = 1
|
self.nonce_count = 1
|
||||||
|
|
||||||
def parse(self, challenge=b''):
|
def parse(self, challenge: bytes_ = b''):
|
||||||
data = {}
|
data: Dict[str, bytes_] = {}
|
||||||
var_name = b''
|
var_name = b''
|
||||||
var_value = b''
|
var_value = b''
|
||||||
|
|
||||||
# States: var, new_var, end, quote, escaped_quote
|
# States: var, new_var, end, quote, escaped_quote
|
||||||
state = 'var'
|
state = 'var'
|
||||||
|
|
||||||
|
for char_int in challenge:
|
||||||
for char in challenge:
|
char = bytes_([char_int])
|
||||||
char = bytes([char])
|
|
||||||
|
|
||||||
if state == 'var':
|
if state == 'var':
|
||||||
if char.isspace():
|
if char.isspace():
|
||||||
@ -401,14 +402,14 @@ class DIGEST(Mech):
|
|||||||
state = 'var'
|
state = 'var'
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def MAC(self, key, seq, msg):
|
def MAC(self, key: bytes_, seq: int, msg: bytes_) -> bytes_:
|
||||||
mac = hmac.HMAC(key=key, digestmod=self.hash)
|
mac = hmac.HMAC(key=key, digestmod=self.hash)
|
||||||
seqnum = num_to_bytes(seq)
|
seqnum = num_to_bytes(seq)
|
||||||
mac.update(seqnum)
|
mac.update(seqnum)
|
||||||
mac.update(msg)
|
mac.update(msg)
|
||||||
return mac.digest()[:10] + b'\x00\x01' + seqnum
|
return mac.digest()[:10] + b'\x00\x01' + seqnum
|
||||||
|
|
||||||
def A1(self):
|
def A1(self) -> bytes_:
|
||||||
username = self.credentials['username']
|
username = self.credentials['username']
|
||||||
password = self.credentials['password']
|
password = self.credentials['password']
|
||||||
authzid = self.credentials['authzid']
|
authzid = self.credentials['authzid']
|
||||||
@ -423,13 +424,13 @@ class DIGEST(Mech):
|
|||||||
|
|
||||||
return bytes(a1)
|
return bytes(a1)
|
||||||
|
|
||||||
def A2(self, prefix=b''):
|
def A2(self, prefix: bytes_ = b'') -> bytes_:
|
||||||
a2 = prefix + b':' + self.digest_uri()
|
a2 = prefix + b':' + self.digest_uri()
|
||||||
if self.qop in (b'auth-int', b'auth-conf'):
|
if self.qop in (b'auth-int', b'auth-conf'):
|
||||||
a2 += b':00000000000000000000000000000000'
|
a2 += b':00000000000000000000000000000000'
|
||||||
return bytes(a2)
|
return bytes(a2)
|
||||||
|
|
||||||
def response(self, prefix=b''):
|
def response(self, prefix: bytes_ = b'') -> bytes_:
|
||||||
nc = bytes('%08x' % self.nonce_count)
|
nc = bytes('%08x' % self.nonce_count)
|
||||||
|
|
||||||
a1 = bytes(self.hash(self.A1()).hexdigest().lower())
|
a1 = bytes(self.hash(self.A1()).hexdigest().lower())
|
||||||
@ -439,7 +440,7 @@ class DIGEST(Mech):
|
|||||||
|
|
||||||
return bytes(self.hash(a1 + b':' + s).hexdigest().lower())
|
return bytes(self.hash(a1 + b':' + s).hexdigest().lower())
|
||||||
|
|
||||||
def digest_uri(self):
|
def digest_uri(self) -> bytes_:
|
||||||
serv_type = self.credentials['service']
|
serv_type = self.credentials['service']
|
||||||
serv_name = self.credentials['service-name']
|
serv_name = self.credentials['service-name']
|
||||||
host = self.credentials['host']
|
host = self.credentials['host']
|
||||||
@ -449,7 +450,7 @@ class DIGEST(Mech):
|
|||||||
uri += b'/' + serv_name
|
uri += b'/' + serv_name
|
||||||
return uri
|
return uri
|
||||||
|
|
||||||
def respond(self):
|
def respond(self) -> bytes_:
|
||||||
data = {
|
data = {
|
||||||
'username': quote(self.credentials['username']),
|
'username': quote(self.credentials['username']),
|
||||||
'authzid': quote(self.credentials['authzid']),
|
'authzid': quote(self.credentials['authzid']),
|
||||||
@ -469,7 +470,7 @@ class DIGEST(Mech):
|
|||||||
resp += b',' + bytes(key) + b'=' + bytes(value)
|
resp += b',' + bytes(key) + b'=' + bytes(value)
|
||||||
return resp[1:]
|
return resp[1:]
|
||||||
|
|
||||||
def process(self, challenge=b''):
|
def process(self, challenge: bytes_ = b'') -> Optional[bytes_]:
|
||||||
if not challenge:
|
if not challenge:
|
||||||
if self.cnonce and self.nonce and self.nonce_count and self.qop:
|
if self.cnonce and self.nonce and self.nonce_count and self.qop:
|
||||||
self.nonce_count += 1
|
self.nonce_count += 1
|
||||||
@ -480,6 +481,7 @@ class DIGEST(Mech):
|
|||||||
if 'rspauth' in data:
|
if 'rspauth' in data:
|
||||||
if data['rspauth'] != self.response():
|
if data['rspauth'] != self.response():
|
||||||
raise SASLMutualAuthFailed()
|
raise SASLMutualAuthFailed()
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
self.nonce_count = 1
|
self.nonce_count = 1
|
||||||
self.cnonce = bytes('%s' % random.random())[2:]
|
self.cnonce = bytes('%s' % random.random())[2:]
|
||||||
|
Loading…
Reference in New Issue
Block a user