Fix JID validation bugs, add lots of tests.

This commit is contained in:
Lance Stout 2012-07-23 21:45:24 -07:00
parent 78aa5c3dfa
commit 352ee2f2fd
4 changed files with 160 additions and 8 deletions

View File

@ -10,7 +10,7 @@ from sleekxmpp.basexmpp import BaseXMPP
from sleekxmpp.clientxmpp import ClientXMPP from sleekxmpp.clientxmpp import ClientXMPP
from sleekxmpp.componentxmpp import ComponentXMPP from sleekxmpp.componentxmpp import ComponentXMPP
from sleekxmpp.stanza import Message, Presence, Iq from sleekxmpp.stanza import Message, Presence, Iq
from sleekxmpp.jid import JID from sleekxmpp.jid import JID, InvalidJID
from sleekxmpp.xmlstream.handler import * from sleekxmpp.xmlstream.handler import *
from sleekxmpp.xmlstream import XMLStream, RestartStream from sleekxmpp.xmlstream import XMLStream, RestartStream
from sleekxmpp.xmlstream.matcher import * from sleekxmpp.xmlstream.matcher import *

View File

@ -140,13 +140,12 @@ def _validate_node(node):
""" """
try: try:
if node is not None: if node is not None:
if not node:
raise InvalidJID('Localpart must not be 0 bytes')
node = nodeprep(node) node = nodeprep(node)
if not node: if not node:
raise InvalidJID('Localpart must not be 0 bytes') raise InvalidJID('Localpart must not be 0 bytes')
if len(node) > 1023:
raise InvalidJID('Localpart must be less than 1024 bytes')
return node return node
except stringprep_profiles.StringPrepError: except stringprep_profiles.StringPrepError:
raise InvalidJID('Invalid local part') raise InvalidJID('Invalid local part')
@ -179,6 +178,7 @@ def _validate_domain(domain):
if not ip_addr and hasattr(socket, 'inet_pton'): if not ip_addr and hasattr(socket, 'inet_pton'):
try: try:
socket.inet_pton(socket.AF_INET6, domain.strip('[]')) socket.inet_pton(socket.AF_INET6, domain.strip('[]'))
domain = '[%s]' % domain.strip('[]')
ip_addr = True ip_addr = True
except socket.error: except socket.error:
pass pass
@ -186,12 +186,19 @@ def _validate_domain(domain):
if not ip_addr: if not ip_addr:
# This is a domain name, which must be checked further # This is a domain name, which must be checked further
if domain and domain[-1] == '.':
domain = domain[:-1]
domain_parts = [] domain_parts = []
for label in domain.split('.'): for label in domain.split('.'):
try: try:
label = encodings.idna.nameprep(label) label = encodings.idna.nameprep(label)
encodings.idna.ToASCII(label) encodings.idna.ToASCII(label)
pass_nameprep = True
except UnicodeError: except UnicodeError:
pass_nameprep = False
if not pass_nameprep:
raise InvalidJID('Could not encode domain as ASCII') raise InvalidJID('Could not encode domain as ASCII')
if label.startswith('xn--'): if label.startswith('xn--'):
@ -209,6 +216,8 @@ def _validate_domain(domain):
if not domain: if not domain:
raise InvalidJID('Domain must not be 0 bytes') raise InvalidJID('Domain must not be 0 bytes')
if len(domain) > 1023:
raise InvalidJID('Domain must be less than 1024 bytes')
return domain return domain
@ -222,13 +231,12 @@ def _validate_resource(resource):
""" """
try: try:
if resource is not None: if resource is not None:
if not resource:
raise InvalidJID('Resource must not be 0 bytes')
resource = resourceprep(resource) resource = resourceprep(resource)
if not resource: if not resource:
raise InvalidJID('Resource must not be 0 bytes') raise InvalidJID('Resource must not be 0 bytes')
if len(resource) > 1023:
raise InvalidJID('Resource must be less than 1024 bytes')
return resource return resource
except stringprep_profiles.StringPrepError: except stringprep_profiles.StringPrepError:
raise InvalidJID('Invalid resource') raise InvalidJID('Invalid resource')

View File

@ -77,6 +77,9 @@ def check_bidi(data):
character MUST be the first character of the string, and a character MUST be the first character of the string, and a
RandALCat character MUST be the last character of the string. RandALCat character MUST be the last character of the string.
""" """
if not data:
return data
has_lcat = False has_lcat = False
has_randal = False has_randal = False

View File

@ -1,5 +1,5 @@
from sleekxmpp.test import * from sleekxmpp.test import *
from sleekxmpp import JID from sleekxmpp import JID, InvalidJID
class TestJIDClass(SleekTest): class TestJIDClass(SleekTest):
@ -137,5 +137,146 @@ class TestJIDClass(SleekTest):
self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal") self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal")
self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal") self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal")
def testZeroLengthDomain(self):
self.assertRaises(InvalidJID, JID, domain='')
self.assertRaises(InvalidJID, JID, 'user@/resource')
def testZeroLengthLocalPart(self):
self.assertRaises(InvalidJID, JID, local='', domain='test.com')
self.assertRaises(InvalidJID, JID, '@/test.com')
def testZeroLengthResource(self):
self.assertRaises(InvalidJID, JID, domain='test.com', resource='')
self.assertRaises(InvalidJID, JID, 'test.com/')
def test1023LengthDomain(self):
domain = ('a.' * 509) + 'a.com'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
def test1023LengthLocalPart(self):
local = 'a' * 1023
jid1 = JID(local=local, domain='test.com')
jid2 = JID('%s@test.com' % local)
def test1023LengthResource(self):
resource = 'r' * 1023
jid1 = JID(domain='test.com', resource=resource)
jid2 = JID('test.com/%s' % resource)
def test1024LengthDomain(self):
domain = ('a.' * 509) + 'aa.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def test1024LengthLocalPart(self):
local = 'a' * 1024
self.assertRaises(InvalidJID, JID, local=local, domain='test.com')
self.assertRaises(InvalidJID, JID, '%s@/test.com' % local)
def test1024LengthResource(self):
resource = 'r' * 1024
self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource)
self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource)
def testTooLongDomainLabel(self):
domain = ('a' * 64) + '.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testDomainEmptyLabel(self):
domain = 'aaa..bbb.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testDomainIPv4(self):
domain = '127.0.0.1'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
def testDomainIPv6(self):
domain = '[::1]'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
def testDomainInvalidIPv6NoBrackets(self):
domain = '::1'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain, '[::1]')
self.assertEqual(jid2.domain, '[::1]')
def testDomainInvalidIPv6MissingBracket(self):
domain = '[::1'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain, '[::1]')
self.assertEqual(jid2.domain, '[::1]')
def testDomainWithPort(self):
domain = 'example.com:5555'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testDomainWithTrailingDot(self):
domain = 'example.com.'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain, 'example.com')
self.assertEqual(jid2.domain, 'example.com')
def testDomainWithDashes(self):
domain = 'example.com-'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
domain = '-example.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testACEDomain(self):
domain = 'xn--bcher-kva.ch'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
def testJIDEscapeExistingSequences(self):
jid = JID(local='blah\\foo\\20bar', domain='example.com')
self.assertEqual(jid.local, 'blah\\foo\\5c20bar')
def testJIDEscape(self):
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
domain='example.com')
self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)')
def testJIDUnescape(self):
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
domain='example.com')
ujid = jid.unescape()
self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")')
jid = JID(local='blah\\foo\\20bar', domain='example.com')
ujid = jid.unescape()
self.assertEqual(ujid.local, 'blah\\foo\\20bar')
def testStartOrEndWithEscapedSpaces(self):
local = ' foo'
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
local = 'bar '
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
# Need more input for these cases. A JID starting with \20 *is* valid
# according to RFC 6122, but is not according to XEP-0106.
#self.assertRaises(InvalidJID, JID, '%s@example.com' % '\\20foo2')
#self.assertRaises(InvalidJID, JID, '%s@example.com' % 'bar2\\20')
suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass) suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass)