Add JID escaping support.

This commit is contained in:
Lance Stout 2012-07-22 23:41:46 -07:00
parent e4e18a416f
commit b5c9c98a8b

View File

@ -29,6 +29,30 @@ ILLEGAL_CHARS = '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + \
JID_PATTERN = "^(?:([^\"&'/:<>@]{1,1023})@)?([^/@]{1,1023})(?:/(.{1,1023}))?$" JID_PATTERN = "^(?:([^\"&'/:<>@]{1,1023})@)?([^/@]{1,1023})(?:/(.{1,1023}))?$"
JID_ESCAPE_SEQUENCES = set(['\\20', '\\22', '\\26', '\\27', '\\2f',
'\\3a', '\\3c', '\\3e', '\\40', '\\5c'])
JID_ESCAPE_TRANSFORMATIONS = {' ': '\\20',
'"': '\\22',
'&': '\\26',
"'": '\\27',
'/': '\\2f',
':': '\\3a',
'<': '\\3c',
'>': '\\3e',
'@': '\\40'}
JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ',
'\\22': '"',
'\\26': '&',
'\\27': "'",
'\\2f': '/',
'\\3a': ':',
'\\3c': '<',
'\\3e': '>',
'\\40': '@',
'\\5c': '\\'}
nodeprep = stringprep_profiles.create( nodeprep = stringprep_profiles.create(
nfkc=True, nfkc=True,
@ -70,21 +94,33 @@ resourceprep = stringprep_profiles.create(
unassigned=[stringprep.in_table_a1]) unassigned=[stringprep.in_table_a1])
class InvalidJID(ValueError): def _parse_jid(data):
pass
def parse_jid(data):
""" """
Parse string data into the node, domain, and resource Parse string data into the node, domain, and resource
components of a JID. components of a JID.
""" """
match = re.match(JID_PATTERN, data) match = re.match(JID_PATTERN, data)
if not match: if not match:
raise InvalidJID raise InvalidJID('JID could not be parsed')
(node, domain, resource) = match.groups() (node, domain, resource) = match.groups()
_validate_node(node)
_validate_domain(domain)
_validate_resource(resource)
return node, domain, resource
def _validate_node(node):
try:
if node is not None:
node = nodeprep(node)
except stringprep_profiles.StringPrepError:
raise InvalidJID('Invalid local part')
def _validate_domain(domain):
ip_addr = False ip_addr = False
try: try:
@ -107,27 +143,122 @@ def parse_jid(data):
label = encodings.idna.nameprep(label) label = encodings.idna.nameprep(label)
encodings.idna.ToASCII(label) encodings.idna.ToASCII(label)
except UnicodeError: except UnicodeError:
raise InvalidJID raise InvalidJID('Could not encode domain as ASCII')
for char in label: for char in label:
if char in ILLEGAL_CHARS: if char in ILLEGAL_CHARS:
raise InvalidJID raise InvalidJID('Domain contains illegar characters')
if '-' in (label[0], label[-1]): if '-' in (label[0], label[-1]):
raise InvalidJID raise InvalidJID('Domain started or ended with -')
domain_parts.append(label) domain_parts.append(label)
domain = '.'.join(domain_parts) domain = '.'.join(domain_parts)
if not domain:
raise InvalidJID('Missing domain')
def _validate_resource(resource):
try: try:
if node is not None:
node = nodeprep(node)
if resource is not None: if resource is not None:
resource = resourceprep(resource) resource = resourceprep(resource)
except stringprep_profiles.StringPrepError: except stringprep_profiles.StringPrepError:
raise InvalidJID raise InvalidJID('Invalid resource')
return node, domain, resource
def _escape_node(node):
result = []
for i, char in enumerate(node):
if char == '\\':
if ''.join((data[i:i+3])) in JID_ESCAPE_SEQUENCES:
result.append('\\5c')
continue
result.append(char)
for i, char in enumerate(result):
result[i] = JID_ESCAPE_TRANSFORMATIONS.get(char, char)
escaped = ''.join(result)
if escaped.startswith('\\20') or escaped.endswith('\\20'):
raise InvalidJID('Escaped local part starts or ends with "\\20"')
_validate_node(escaped)
return escaped
def _unescape_node(node):
unescaped = []
seq = ''
for i, char in enumerate(node):
if char == '\\':
seq = node[i:i+3]
if seq not in JID_ESCAPE_SEQUENCES:
seq = ''
if seq:
if len(seq) == 3:
unescaped.append(JID_UNESCAPE_TRANSFORMATIONS.get(seq, char))
# Pop character off the escape sequence, and ignore it
seq = seq[1:]
else:
unescaped.append(char)
unescaped = ''.join(unescaped)
return unescaped
def _format_jid(local=None, domain=None, resource=None):
result = []
if local:
result.append(local)
result.append('@')
if domain:
result.append(domain)
if resource:
result.append('/')
result.append(resource)
return ''.join(result)
class InvalidJID(ValueError):
pass
class UnescapedJID(object):
def __init__(self, local, domain, resource):
self._jid = (local, domain, resource)
def __getattr__(self, name):
"""
:param name: one of: user, server, domain, resource,
full, or bare.
"""
if name == 'resource':
return self._jid[2] or ''
elif name in ('user', 'username', 'local', 'node'):
return self._jid[0] or ''
elif name in ('server', 'domain', 'host'):
return self._jid[1] or ''
elif name in ('full', 'jid'):
return _format_jid(*self._jid)
elif name == 'bare':
return _format_jid(self._jid[0], self._jid[1])
elif name == '_jid':
return getattr(super(JID, self), '_jid')
else:
return None
def __str__(self):
"""Use the full JID as the string value."""
return _format_jid(*self._jid)
def __repr__(self):
return self.__str__()
class JID(object): class JID(object):
@ -157,21 +288,37 @@ class JID(object):
:param string jid: A string of the form ``'[user@]domain[/resource]'``. :param string jid: A string of the form ``'[user@]domain[/resource]'``.
""" """
def __init__(self, jid=None, local=None, domain=None, resource=None): def __init__(self, jid=None, **kwargs):
"""Initialize a new JID""" """Initialize a new JID"""
self._jid = (None, None, None) self._jid = (None, None, None)
if jid is None or jid == '': if jid is None or jid == '':
jid = (None, None, None) jid = (None, None, None)
elif not isinstance(jid, JID): elif not isinstance(jid, JID):
jid = parse_jid(jid) jid = _parse_jid(jid)
else: else:
jid = jid._jid jid = jid._jid
orig_local, orig_domain, orig_resource = jid local, domain, resource = jid
self._jid = (local or orig_local or None, validated = True
domain or orig_domain or None,
resource or orig_resource or None) local = kwargs.get('local', local)
domain = kwargs.get('domain', domain)
resource = kwargs.get('resource', resource)
if 'local' in kwargs:
local = _escape_node(local)
if 'domain' in kwargs:
_validate_domain(domain)
if 'resource' in kwargs:
_validate_resource(resource)
self._jid = (local, domain, resource)
def unescape(self):
return UnescapedJID(_unescape_node(self._jid[0]),
self._jid[1],
self._jid[2])
def regenerate(self): def regenerate(self):
"""Deprecated""" """Deprecated"""
@ -185,8 +332,7 @@ class JID(object):
self._jid = JID(data)._jid self._jid = JID(data)._jid
def __getattr__(self, name): def __getattr__(self, name):
"""handle getting the jid values, using cache if available. """
:param name: one of: user, server, domain, resource, :param name: one of: user, server, domain, resource,
full, or bare. full, or bare.
""" """
@ -197,16 +343,16 @@ class JID(object):
elif name in ('server', 'domain', 'host'): elif name in ('server', 'domain', 'host'):
return self._jid[1] or '' return self._jid[1] or ''
elif name in ('full', 'jid'): elif name in ('full', 'jid'):
return str(self) return _format_jid(*self._jid)
elif name == 'bare': elif name == 'bare':
return str(JID(local=self._jid[0], return _format_jid(self._jid[0], self._jid[1])
domain=self._jid[1])) elif name == '_jid':
return getattr(super(JID, self), '_jid')
else: else:
object.__getattr__(self, name) return None
def __setattr__(self, name, value): def __setattr__(self, name, value):
"""handle getting the jid values, using cache if available. """
:param name: one of: ``user``, ``username``, ``local``, :param name: one of: ``user``, ``username``, ``local``,
``node``, ``server``, ``domain``, ``host``, ``node``, ``server``, ``domain``, ``host``,
``resource``, ``full``, ``jid``, or ``bare``. ``resource``, ``full``, ``jid``, or ``bare``.
@ -223,21 +369,12 @@ class JID(object):
elif name == 'bare': elif name == 'bare':
parsed = JID(value)._jid parsed = JID(value)._jid
self._jid = (parsed[0], parsed[1], self._jid[2]) self._jid = (parsed[0], parsed[1], self._jid[2])
else: elif name == '_jid':
object.__setattr__(self, name, value) super(JID, self).__setattr__('_jid', value)
def __str__(self): def __str__(self):
"""Use the full JID as the string value.""" """Use the full JID as the string value."""
result = [] return _format_jid(*self._jid)
if self._jid[0]:
result.append(self._jid[0])
result.append('@')
if self._jid[1]:
result.append(self._jid[1])
if self._jid[2]:
result.append('/')
result.append(self._jid[2])
return ''.join(result)
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
@ -246,6 +383,9 @@ class JID(object):
""" """
Two JIDs are considered equal if they have the same full JID value. Two JIDs are considered equal if they have the same full JID value.
""" """
if isinstance(other, UnescapedJID):
return False
other = JID(other) other = JID(other)
return self._jid == other._jid return self._jid == other._jid