Fix JID validation bugs, add lots of tests.
This commit is contained in:
parent
78aa5c3dfa
commit
352ee2f2fd
@ -10,7 +10,7 @@ from sleekxmpp.basexmpp import BaseXMPP
|
|||||||
from sleekxmpp.clientxmpp import ClientXMPP
|
from sleekxmpp.clientxmpp import ClientXMPP
|
||||||
from sleekxmpp.componentxmpp import ComponentXMPP
|
from sleekxmpp.componentxmpp import ComponentXMPP
|
||||||
from sleekxmpp.stanza import Message, Presence, Iq
|
from sleekxmpp.stanza import Message, Presence, Iq
|
||||||
from sleekxmpp.jid import JID
|
from sleekxmpp.jid import JID, InvalidJID
|
||||||
from sleekxmpp.xmlstream.handler import *
|
from sleekxmpp.xmlstream.handler import *
|
||||||
from sleekxmpp.xmlstream import XMLStream, RestartStream
|
from sleekxmpp.xmlstream import XMLStream, RestartStream
|
||||||
from sleekxmpp.xmlstream.matcher import *
|
from sleekxmpp.xmlstream.matcher import *
|
||||||
|
@ -140,13 +140,12 @@ def _validate_node(node):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if node is not None:
|
if node is not None:
|
||||||
if not node:
|
|
||||||
raise InvalidJID('Localpart must not be 0 bytes')
|
|
||||||
|
|
||||||
node = nodeprep(node)
|
node = nodeprep(node)
|
||||||
|
|
||||||
if not node:
|
if not node:
|
||||||
raise InvalidJID('Localpart must not be 0 bytes')
|
raise InvalidJID('Localpart must not be 0 bytes')
|
||||||
|
if len(node) > 1023:
|
||||||
|
raise InvalidJID('Localpart must be less than 1024 bytes')
|
||||||
return node
|
return node
|
||||||
except stringprep_profiles.StringPrepError:
|
except stringprep_profiles.StringPrepError:
|
||||||
raise InvalidJID('Invalid local part')
|
raise InvalidJID('Invalid local part')
|
||||||
@ -179,6 +178,7 @@ def _validate_domain(domain):
|
|||||||
if not ip_addr and hasattr(socket, 'inet_pton'):
|
if not ip_addr and hasattr(socket, 'inet_pton'):
|
||||||
try:
|
try:
|
||||||
socket.inet_pton(socket.AF_INET6, domain.strip('[]'))
|
socket.inet_pton(socket.AF_INET6, domain.strip('[]'))
|
||||||
|
domain = '[%s]' % domain.strip('[]')
|
||||||
ip_addr = True
|
ip_addr = True
|
||||||
except socket.error:
|
except socket.error:
|
||||||
pass
|
pass
|
||||||
@ -186,12 +186,19 @@ def _validate_domain(domain):
|
|||||||
if not ip_addr:
|
if not ip_addr:
|
||||||
# This is a domain name, which must be checked further
|
# This is a domain name, which must be checked further
|
||||||
|
|
||||||
|
if domain and domain[-1] == '.':
|
||||||
|
domain = domain[:-1]
|
||||||
|
|
||||||
domain_parts = []
|
domain_parts = []
|
||||||
for label in domain.split('.'):
|
for label in domain.split('.'):
|
||||||
try:
|
try:
|
||||||
label = encodings.idna.nameprep(label)
|
label = encodings.idna.nameprep(label)
|
||||||
encodings.idna.ToASCII(label)
|
encodings.idna.ToASCII(label)
|
||||||
|
pass_nameprep = True
|
||||||
except UnicodeError:
|
except UnicodeError:
|
||||||
|
pass_nameprep = False
|
||||||
|
|
||||||
|
if not pass_nameprep:
|
||||||
raise InvalidJID('Could not encode domain as ASCII')
|
raise InvalidJID('Could not encode domain as ASCII')
|
||||||
|
|
||||||
if label.startswith('xn--'):
|
if label.startswith('xn--'):
|
||||||
@ -209,6 +216,8 @@ def _validate_domain(domain):
|
|||||||
|
|
||||||
if not domain:
|
if not domain:
|
||||||
raise InvalidJID('Domain must not be 0 bytes')
|
raise InvalidJID('Domain must not be 0 bytes')
|
||||||
|
if len(domain) > 1023:
|
||||||
|
raise InvalidJID('Domain must be less than 1024 bytes')
|
||||||
|
|
||||||
return domain
|
return domain
|
||||||
|
|
||||||
@ -222,13 +231,12 @@ def _validate_resource(resource):
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if resource is not None:
|
if resource is not None:
|
||||||
if not resource:
|
|
||||||
raise InvalidJID('Resource must not be 0 bytes')
|
|
||||||
|
|
||||||
resource = resourceprep(resource)
|
resource = resourceprep(resource)
|
||||||
|
|
||||||
if not resource:
|
if not resource:
|
||||||
raise InvalidJID('Resource must not be 0 bytes')
|
raise InvalidJID('Resource must not be 0 bytes')
|
||||||
|
if len(resource) > 1023:
|
||||||
|
raise InvalidJID('Resource must be less than 1024 bytes')
|
||||||
return resource
|
return resource
|
||||||
except stringprep_profiles.StringPrepError:
|
except stringprep_profiles.StringPrepError:
|
||||||
raise InvalidJID('Invalid resource')
|
raise InvalidJID('Invalid resource')
|
||||||
|
@ -77,6 +77,9 @@ def check_bidi(data):
|
|||||||
character MUST be the first character of the string, and a
|
character MUST be the first character of the string, and a
|
||||||
RandALCat character MUST be the last character of the string.
|
RandALCat character MUST be the last character of the string.
|
||||||
"""
|
"""
|
||||||
|
if not data:
|
||||||
|
return data
|
||||||
|
|
||||||
has_lcat = False
|
has_lcat = False
|
||||||
has_randal = False
|
has_randal = False
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from sleekxmpp.test import *
|
from sleekxmpp.test import *
|
||||||
from sleekxmpp import JID
|
from sleekxmpp import JID, InvalidJID
|
||||||
|
|
||||||
|
|
||||||
class TestJIDClass(SleekTest):
|
class TestJIDClass(SleekTest):
|
||||||
@ -137,5 +137,146 @@ class TestJIDClass(SleekTest):
|
|||||||
self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal")
|
self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal")
|
||||||
self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal")
|
self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal")
|
||||||
|
|
||||||
|
def testZeroLengthDomain(self):
|
||||||
|
self.assertRaises(InvalidJID, JID, domain='')
|
||||||
|
self.assertRaises(InvalidJID, JID, 'user@/resource')
|
||||||
|
|
||||||
|
def testZeroLengthLocalPart(self):
|
||||||
|
self.assertRaises(InvalidJID, JID, local='', domain='test.com')
|
||||||
|
self.assertRaises(InvalidJID, JID, '@/test.com')
|
||||||
|
|
||||||
|
def testZeroLengthResource(self):
|
||||||
|
self.assertRaises(InvalidJID, JID, domain='test.com', resource='')
|
||||||
|
self.assertRaises(InvalidJID, JID, 'test.com/')
|
||||||
|
|
||||||
|
def test1023LengthDomain(self):
|
||||||
|
domain = ('a.' * 509) + 'a.com'
|
||||||
|
jid1 = JID(domain=domain)
|
||||||
|
jid2 = JID('user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def test1023LengthLocalPart(self):
|
||||||
|
local = 'a' * 1023
|
||||||
|
jid1 = JID(local=local, domain='test.com')
|
||||||
|
jid2 = JID('%s@test.com' % local)
|
||||||
|
|
||||||
|
def test1023LengthResource(self):
|
||||||
|
resource = 'r' * 1023
|
||||||
|
jid1 = JID(domain='test.com', resource=resource)
|
||||||
|
jid2 = JID('test.com/%s' % resource)
|
||||||
|
|
||||||
|
def test1024LengthDomain(self):
|
||||||
|
domain = ('a.' * 509) + 'aa.com'
|
||||||
|
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||||
|
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def test1024LengthLocalPart(self):
|
||||||
|
local = 'a' * 1024
|
||||||
|
self.assertRaises(InvalidJID, JID, local=local, domain='test.com')
|
||||||
|
self.assertRaises(InvalidJID, JID, '%s@/test.com' % local)
|
||||||
|
|
||||||
|
def test1024LengthResource(self):
|
||||||
|
resource = 'r' * 1024
|
||||||
|
self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource)
|
||||||
|
self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource)
|
||||||
|
|
||||||
|
def testTooLongDomainLabel(self):
|
||||||
|
domain = ('a' * 64) + '.com'
|
||||||
|
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||||
|
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def testDomainEmptyLabel(self):
|
||||||
|
domain = 'aaa..bbb.com'
|
||||||
|
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||||
|
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def testDomainIPv4(self):
|
||||||
|
domain = '127.0.0.1'
|
||||||
|
jid1 = JID(domain=domain)
|
||||||
|
jid2 = JID('user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def testDomainIPv6(self):
|
||||||
|
domain = '[::1]'
|
||||||
|
jid1 = JID(domain=domain)
|
||||||
|
jid2 = JID('user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def testDomainInvalidIPv6NoBrackets(self):
|
||||||
|
domain = '::1'
|
||||||
|
jid1 = JID(domain=domain)
|
||||||
|
jid2 = JID('user@%s/resource' % domain)
|
||||||
|
|
||||||
|
self.assertEqual(jid1.domain, '[::1]')
|
||||||
|
self.assertEqual(jid2.domain, '[::1]')
|
||||||
|
|
||||||
|
def testDomainInvalidIPv6MissingBracket(self):
|
||||||
|
domain = '[::1'
|
||||||
|
jid1 = JID(domain=domain)
|
||||||
|
jid2 = JID('user@%s/resource' % domain)
|
||||||
|
|
||||||
|
self.assertEqual(jid1.domain, '[::1]')
|
||||||
|
self.assertEqual(jid2.domain, '[::1]')
|
||||||
|
|
||||||
|
def testDomainWithPort(self):
|
||||||
|
domain = 'example.com:5555'
|
||||||
|
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||||
|
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def testDomainWithTrailingDot(self):
|
||||||
|
domain = 'example.com.'
|
||||||
|
jid1 = JID(domain=domain)
|
||||||
|
jid2 = JID('user@%s/resource' % domain)
|
||||||
|
|
||||||
|
self.assertEqual(jid1.domain, 'example.com')
|
||||||
|
self.assertEqual(jid2.domain, 'example.com')
|
||||||
|
|
||||||
|
def testDomainWithDashes(self):
|
||||||
|
domain = 'example.com-'
|
||||||
|
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||||
|
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||||
|
|
||||||
|
domain = '-example.com'
|
||||||
|
self.assertRaises(InvalidJID, JID, domain=domain)
|
||||||
|
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
|
||||||
|
|
||||||
|
def testACEDomain(self):
|
||||||
|
domain = 'xn--bcher-kva.ch'
|
||||||
|
jid1 = JID(domain=domain)
|
||||||
|
jid2 = JID('user@%s/resource' % domain)
|
||||||
|
|
||||||
|
self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
|
||||||
|
self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
|
||||||
|
|
||||||
|
def testJIDEscapeExistingSequences(self):
|
||||||
|
jid = JID(local='blah\\foo\\20bar', domain='example.com')
|
||||||
|
self.assertEqual(jid.local, 'blah\\foo\\5c20bar')
|
||||||
|
|
||||||
|
def testJIDEscape(self):
|
||||||
|
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
|
||||||
|
domain='example.com')
|
||||||
|
self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)')
|
||||||
|
|
||||||
|
def testJIDUnescape(self):
|
||||||
|
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
|
||||||
|
domain='example.com')
|
||||||
|
ujid = jid.unescape()
|
||||||
|
self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")')
|
||||||
|
|
||||||
|
jid = JID(local='blah\\foo\\20bar', domain='example.com')
|
||||||
|
ujid = jid.unescape()
|
||||||
|
self.assertEqual(ujid.local, 'blah\\foo\\20bar')
|
||||||
|
|
||||||
|
def testStartOrEndWithEscapedSpaces(self):
|
||||||
|
local = ' foo'
|
||||||
|
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
|
||||||
|
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
|
||||||
|
|
||||||
|
local = 'bar '
|
||||||
|
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
|
||||||
|
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
|
||||||
|
|
||||||
|
# Need more input for these cases. A JID starting with \20 *is* valid
|
||||||
|
# according to RFC 6122, but is not according to XEP-0106.
|
||||||
|
#self.assertRaises(InvalidJID, JID, '%s@example.com' % '\\20foo2')
|
||||||
|
#self.assertRaises(InvalidJID, JID, '%s@example.com' % 'bar2\\20')
|
||||||
|
|
||||||
|
|
||||||
suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass)
|
suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user