Merge branch 'master' into develop

This commit is contained in:
Lance Stout 2012-07-24 20:01:18 -07:00
commit c42f1ad4c7
24 changed files with 936 additions and 212 deletions

View File

@ -49,6 +49,7 @@ packages = [ 'sleekxmpp',
'sleekxmpp/stanza', 'sleekxmpp/stanza',
'sleekxmpp/test', 'sleekxmpp/test',
'sleekxmpp/roster', 'sleekxmpp/roster',
'sleekxmpp/util',
'sleekxmpp/xmlstream', 'sleekxmpp/xmlstream',
'sleekxmpp/xmlstream/matcher', 'sleekxmpp/xmlstream/matcher',
'sleekxmpp/xmlstream/handler', 'sleekxmpp/xmlstream/handler',

View File

@ -10,6 +10,7 @@ from sleekxmpp.basexmpp import BaseXMPP
from sleekxmpp.clientxmpp import ClientXMPP from sleekxmpp.clientxmpp import ClientXMPP
from sleekxmpp.componentxmpp import ComponentXMPP from sleekxmpp.componentxmpp import ComponentXMPP
from sleekxmpp.stanza import Message, Presence, Iq from sleekxmpp.stanza import Message, Presence, Iq
from sleekxmpp.jid import JID, InvalidJID
from sleekxmpp.xmlstream.handler import * from sleekxmpp.xmlstream.handler import *
from sleekxmpp.xmlstream import XMLStream, RestartStream from sleekxmpp.xmlstream import XMLStream, RestartStream
from sleekxmpp.xmlstream.matcher import * from sleekxmpp.xmlstream.matcher import *

View File

@ -179,8 +179,7 @@ class ClientXMPP(BaseXMPP):
self._stream_feature_order.remove((order, name)) self._stream_feature_order.remove((order, name))
self._stream_feature_order.sort() self._stream_feature_order.sort()
def update_roster(self, jid, name=None, subscription=None, groups=[], def update_roster(self, jid, **kwargs):
block=True, timeout=None, callback=None):
"""Add or change a roster item. """Add or change a roster item.
:param jid: The JID of the entry to modify. :param jid: The JID of the entry to modify.
@ -201,6 +200,16 @@ class ClientXMPP(BaseXMPP):
Will be executed when the roster is received. Will be executed when the roster is received.
Implies ``block=False``. Implies ``block=False``.
""" """
current = self.client_roster[jid]
name = kwargs.get('name', current['name'])
subscription = kwargs.get('subscription', current['subscription'])
groups = kwargs.get('groups', current['groups'])
block = kwargs.get('block', True)
timeout = kwargs.get('timeout', None)
callback = kwargs.get('callback', None)
return self.client_roster.update(jid, name, subscription, groups, return self.client_roster.update(jid, name, subscription, groups,
block, timeout, callback) block, timeout, callback)

541
sleekxmpp/jid.py Normal file
View File

@ -0,0 +1,541 @@
# -*- coding: utf-8 -*-
"""
sleekxmpp.jid
~~~~~~~~~~~~~~~~~~~~~~~
This module allows for working with Jabber IDs (JIDs).
Part of SleekXMPP: The Sleek XMPP Library
:copyright: (c) 2011 Nathanael C. Fritz
:license: MIT, see LICENSE for more details
"""
from __future__ import unicode_literals
import re
import socket
import stringprep
import encodings.idna
from sleekxmpp.util import stringprep_profiles
#: These characters are not allowed to appear in a JID.
ILLEGAL_CHARS = '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r' + \
'\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19' + \
'\x1a\x1b\x1c\x1d\x1e\x1f' + \
' !"#$%&\'()*+,./:;<=>?@[\\]^_`{|}~\x7f'
#: The basic regex pattern that a JID must match in order to determine
#: the local, domain, and resource parts. This regex does NOT do any
#: validation, which requires application of nodeprep, resourceprep, etc.
JID_PATTERN = "^(?:([^\"&'/:<>@]{1,1023})@)?([^/@]{1,1023})(?:/(.{1,1023}))?$"
#: The set of escape sequences for the characters not allowed by nodeprep.
JID_ESCAPE_SEQUENCES = set(['\\20', '\\22', '\\26', '\\27', '\\2f',
'\\3a', '\\3c', '\\3e', '\\40', '\\5c'])
#: A mapping of unallowed characters to their escape sequences. An escape
#: sequence for '\' is also included since it must also be escaped in
#: certain situations.
JID_ESCAPE_TRANSFORMATIONS = {' ': '\\20',
'"': '\\22',
'&': '\\26',
"'": '\\27',
'/': '\\2f',
':': '\\3a',
'<': '\\3c',
'>': '\\3e',
'@': '\\40',
'\\': '\\5c'}
#: The reverse mapping of escape sequences to their original forms.
JID_UNESCAPE_TRANSFORMATIONS = {'\\20': ' ',
'\\22': '"',
'\\26': '&',
'\\27': "'",
'\\2f': '/',
'\\3a': ':',
'\\3c': '<',
'\\3e': '>',
'\\40': '@',
'\\5c': '\\'}
# pylint: disable=c0103
#: The nodeprep profile of stringprep used to validate the local,
#: or username, portion of a JID.
nodeprep = stringprep_profiles.create(
nfkc=True,
bidi=True,
mappings=[
stringprep_profiles.b1_mapping,
stringprep_profiles.c12_mapping],
prohibited=[
stringprep.in_table_c11,
stringprep.in_table_c12,
stringprep.in_table_c21,
stringprep.in_table_c22,
stringprep.in_table_c3,
stringprep.in_table_c4,
stringprep.in_table_c5,
stringprep.in_table_c6,
stringprep.in_table_c7,
stringprep.in_table_c8,
stringprep.in_table_c9,
lambda c: c in ' \'"&/:<>@'],
unassigned=[stringprep.in_table_a1])
# pylint: disable=c0103
#: The resourceprep profile of stringprep, which is used to validate
#: the resource portion of a JID.
resourceprep = stringprep_profiles.create(
nfkc=True,
bidi=True,
mappings=[stringprep_profiles.b1_mapping],
prohibited=[
stringprep.in_table_c12,
stringprep.in_table_c21,
stringprep.in_table_c22,
stringprep.in_table_c3,
stringprep.in_table_c4,
stringprep.in_table_c5,
stringprep.in_table_c6,
stringprep.in_table_c7,
stringprep.in_table_c8,
stringprep.in_table_c9],
unassigned=[stringprep.in_table_a1])
def _parse_jid(data):
"""
Parse string data into the node, domain, and resource
components of a JID, if possible.
:param string data: A string that is potentially a JID.
:raises InvalidJID:
:returns: tuple of the validated local, domain, and resource strings
"""
match = re.match(JID_PATTERN, data)
if not match:
raise InvalidJID('JID could not be parsed')
(node, domain, resource) = match.groups()
node = _validate_node(node)
domain = _validate_domain(domain)
resource = _validate_resource(resource)
return node, domain, resource
def _validate_node(node):
"""Validate the local, or username, portion of a JID.
:raises InvalidJID:
:returns: The local portion of a JID, as validated by nodeprep.
"""
try:
if node is not None:
node = nodeprep(node)
if not node:
raise InvalidJID('Localpart must not be 0 bytes')
if len(node) > 1023:
raise InvalidJID('Localpart must be less than 1024 bytes')
return node
except stringprep_profiles.StringPrepError:
raise InvalidJID('Invalid local part')
def _validate_domain(domain):
"""Validate the domain portion of a JID.
IP literal addresses are left as-is, if valid. Domain names
are stripped of any trailing label separators (`.`), and are
checked with the nameprep profile of stringprep. If the given
domain is actually a punyencoded version of a domain name, it
is converted back into its original Unicode form. Domains must
also not start or end with a dash (`-`).
:raises InvalidJID:
:returns: The validated domain name
"""
ip_addr = False
# First, check if this is an IPv4 address
try:
socket.inet_aton(domain)
ip_addr = True
except socket.error:
pass
# Check if this is an IPv6 address
if not ip_addr and hasattr(socket, 'inet_pton'):
try:
socket.inet_pton(socket.AF_INET6, domain.strip('[]'))
domain = '[%s]' % domain.strip('[]')
ip_addr = True
except socket.error:
pass
if not ip_addr:
# This is a domain name, which must be checked further
if domain and domain[-1] == '.':
domain = domain[:-1]
domain_parts = []
for label in domain.split('.'):
try:
label = encodings.idna.nameprep(label)
encodings.idna.ToASCII(label)
pass_nameprep = True
except UnicodeError:
pass_nameprep = False
if not pass_nameprep:
raise InvalidJID('Could not encode domain as ASCII')
if label.startswith('xn--'):
label = encodings.idna.ToUnicode(label)
for char in label:
if char in ILLEGAL_CHARS:
raise InvalidJID('Domain contains illegar characters')
if '-' in (label[0], label[-1]):
raise InvalidJID('Domain started or ended with -')
domain_parts.append(label)
domain = '.'.join(domain_parts)
if not domain:
raise InvalidJID('Domain must not be 0 bytes')
if len(domain) > 1023:
raise InvalidJID('Domain must be less than 1024 bytes')
return domain
def _validate_resource(resource):
"""Validate the resource portion of a JID.
:raises InvalidJID:
:returns: The local portion of a JID, as validated by resourceprep.
"""
try:
if resource is not None:
resource = resourceprep(resource)
if not resource:
raise InvalidJID('Resource must not be 0 bytes')
if len(resource) > 1023:
raise InvalidJID('Resource must be less than 1024 bytes')
return resource
except stringprep_profiles.StringPrepError:
raise InvalidJID('Invalid resource')
def _escape_node(node):
"""Escape the local portion of a JID."""
result = []
for i, char in enumerate(node):
if char == '\\':
if ''.join((node[i:i+3])) in JID_ESCAPE_SEQUENCES:
result.append('\\5c')
continue
result.append(char)
for i, char in enumerate(result):
if char != '\\':
result[i] = JID_ESCAPE_TRANSFORMATIONS.get(char, char)
escaped = ''.join(result)
if escaped.startswith('\\20') or escaped.endswith('\\20'):
raise InvalidJID('Escaped local part starts or ends with "\\20"')
_validate_node(escaped)
return escaped
def _unescape_node(node):
"""Unescape a local portion of a JID.
.. note::
The unescaped local portion is meant ONLY for presentation,
and should not be used for other purposes.
"""
unescaped = []
seq = ''
for i, char in enumerate(node):
if char == '\\':
seq = node[i:i+3]
if seq not in JID_ESCAPE_SEQUENCES:
seq = ''
if seq:
if len(seq) == 3:
unescaped.append(JID_UNESCAPE_TRANSFORMATIONS.get(seq, char))
# Pop character off the escape sequence, and ignore it
seq = seq[1:]
else:
unescaped.append(char)
unescaped = ''.join(unescaped)
return unescaped
def _format_jid(local=None, domain=None, resource=None):
"""Format the given JID components into a full or bare JID.
:param string local: Optional. The local portion of the JID.
:param string domain: Required. The domain name portion of the JID.
:param strin resource: Optional. The resource portion of the JID.
:return: A full or bare JID string.
"""
result = []
if local:
result.append(local)
result.append('@')
if domain:
result.append(domain)
if resource:
result.append('/')
result.append(resource)
return ''.join(result)
class InvalidJID(ValueError):
"""
Raised when attempting to create a JID that does not pass validation.
It can also be raised if modifying an existing JID in such a way as
to make it invalid, such trying to remove the domain from an existing
full JID while the local and resource portions still exist.
"""
# pylint: disable=R0903
class UnescapedJID(object):
"""
.. versionadded:: 1.1.10
"""
def __init__(self, local, domain, resource):
self._jid = (local, domain, resource)
# pylint: disable=R0911
def __getattr__(self, name):
"""Retrieve the given JID component.
:param name: one of: user, server, domain, resource,
full, or bare.
"""
if name == 'resource':
return self._jid[2] or ''
elif name in ('user', 'username', 'local', 'node'):
return self._jid[0] or ''
elif name in ('server', 'domain', 'host'):
return self._jid[1] or ''
elif name in ('full', 'jid'):
return _format_jid(*self._jid)
elif name == 'bare':
return _format_jid(self._jid[0], self._jid[1])
elif name == '_jid':
return getattr(super(JID, self), '_jid')
else:
return None
def __str__(self):
"""Use the full JID as the string value."""
return _format_jid(*self._jid)
def __repr__(self):
"""Use the full JID as the representation."""
return self.__str__()
class JID(object):
"""
A representation of a Jabber ID, or JID.
Each JID may have three components: a user, a domain, and an optional
resource. For example: user@domain/resource
When a resource is not used, the JID is called a bare JID.
The JID is a full JID otherwise.
**JID Properties:**
:jid: Alias for ``full``.
:full: The string value of the full JID.
:bare: The string value of the bare JID.
:user: The username portion of the JID.
:username: Alias for ``user``.
:local: Alias for ``user``.
:node: Alias for ``user``.
:domain: The domain name portion of the JID.
:server: Alias for ``domain``.
:host: Alias for ``domain``.
:resource: The resource portion of the JID.
:param string jid:
A string of the form ``'[user@]domain[/resource]'``.
:param string local:
Optional. Specify the local, or username, portion
of the JID. If provided, it will override the local
value provided by the `jid` parameter. The given
local value will also be escaped if necessary.
:param string domain:
Optional. Specify the domain of the JID. If
provided, it will override the domain given by
the `jid` parameter.
:param string resource:
Optional. Specify the resource value of the JID.
If provided, it will override the domain given
by the `jid` parameter.
:raises InvalidJID:
"""
# pylint: disable=W0212
def __init__(self, jid=None, **kwargs):
self._jid = (None, None, None)
if jid is None or jid == '':
jid = (None, None, None)
elif not isinstance(jid, JID):
jid = _parse_jid(jid)
else:
jid = jid._jid
local, domain, resource = jid
local = kwargs.get('local', local)
domain = kwargs.get('domain', domain)
resource = kwargs.get('resource', resource)
if 'local' in kwargs:
local = _escape_node(local)
if 'domain' in kwargs:
domain = _validate_domain(domain)
if 'resource' in kwargs:
resource = _validate_resource(resource)
self._jid = (local, domain, resource)
def unescape(self):
"""Return an unescaped JID object.
Using an unescaped JID is preferred for displaying JIDs
to humans, and they should NOT be used for any other
purposes than for presentation.
:return: :class:`UnescapedJID`
.. versionadded:: 1.1.10
"""
return UnescapedJID(_unescape_node(self._jid[0]),
self._jid[1],
self._jid[2])
def regenerate(self):
"""No-op
.. deprecated:: 1.1.10
"""
pass
def reset(self, data):
"""Start fresh from a new JID string.
:param string data: A string of the form ``'[user@]domain[/resource]'``.
.. deprecated:: 1.1.10
"""
self._jid = JID(data)._jid
# pylint: disable=R0911
def __getattr__(self, name):
"""Retrieve the given JID component.
:param name: one of: user, server, domain, resource,
full, or bare.
"""
if name == 'resource':
return self._jid[2] or ''
elif name in ('user', 'username', 'local', 'node'):
return self._jid[0] or ''
elif name in ('server', 'domain', 'host'):
return self._jid[1] or ''
elif name in ('full', 'jid'):
return _format_jid(*self._jid)
elif name == 'bare':
return _format_jid(self._jid[0], self._jid[1])
elif name == '_jid':
return getattr(super(JID, self), '_jid')
else:
return None
# pylint: disable=W0212
def __setattr__(self, name, value):
"""Update the given JID component.
:param name: one of: ``user``, ``username``, ``local``,
``node``, ``server``, ``domain``, ``host``,
``resource``, ``full``, ``jid``, or ``bare``.
:param value: The new string value of the JID component.
"""
if name == 'resource':
self._jid = JID(self, resource=value)._jid
elif name in ('user', 'username', 'local', 'node'):
self._jid = JID(self, local=value)._jid
elif name in ('server', 'domain', 'host'):
self._jid = JID(self, domain=value)._jid
elif name in ('full', 'jid'):
self._jid = JID(value)._jid
elif name == 'bare':
parsed = JID(value)._jid
self._jid = (parsed[0], parsed[1], self._jid[2])
elif name == '_jid':
super(JID, self).__setattr__('_jid', value)
def __str__(self):
"""Use the full JID as the string value."""
return _format_jid(*self._jid)
def __repr__(self):
"""Use the full JID as the representation."""
return self.__str__()
# pylint: disable=W0212
def __eq__(self, other):
"""Two JIDs are equal if they have the same full JID value."""
if isinstance(other, UnescapedJID):
return False
other = JID(other)
return self._jid == other._jid
# pylint: disable=W0212
def __ne__(self, other):
"""Two JIDs are considered unequal if they are not equal."""
return not self == other
def __hash__(self):
"""Hash a JID based on the string version of its full JID."""
return hash(self.__str__())
def __copy__(self):
"""Generate a duplicate JID."""
return JID(self)

View File

@ -37,6 +37,7 @@ __all__ = [
'xep_0085', # Chat State Notifications 'xep_0085', # Chat State Notifications
'xep_0086', # Legacy Error Codes 'xep_0086', # Legacy Error Codes
'xep_0092', # Software Version 'xep_0092', # Software Version
'xep_0106', # JID Escaping
'xep_0107', # User Mood 'xep_0107', # User Mood
'xep_0108', # User Activity 'xep_0108', # User Activity
'xep_0115', # Entity Capabilities 'xep_0115', # Entity Capabilities

View File

@ -1,11 +1,8 @@
import socket import socket
import threading import threading
import logging import logging
try:
import queue
except ImportError:
import Queue as queue
from sleekxmpp.util import Queue
from sleekxmpp.exceptions import XMPPError from sleekxmpp.exceptions import XMPPError
@ -33,7 +30,7 @@ class IBBytestream(object):
self.stream_in_closed = threading.Event() self.stream_in_closed = threading.Event()
self.stream_out_closed = threading.Event() self.stream_out_closed = threading.Event()
self.recv_queue = queue.Queue() self.recv_queue = Queue()
self.send_window = threading.BoundedSemaphore(value=self.window_size) self.send_window = threading.BoundedSemaphore(value=self.window_size)
self.window_ids = set() self.window_ids = set()

View File

@ -41,6 +41,9 @@ class XEP_0084(BasePlugin):
def session_bind(self, jid): def session_bind(self, jid):
self.xmpp['xep_0163'].register_pep('avatar_metadata', MetaData) self.xmpp['xep_0163'].register_pep('avatar_metadata', MetaData)
def generate_id(self, data):
return hashlib.sha1(data).hexdigest()
def retrieve_avatar(self, jid, id, url=None, ifrom=None, block=True, def retrieve_avatar(self, jid, id, url=None, ifrom=None, block=True,
callback=None, timeout=None): callback=None, timeout=None):
return self.xmpp['xep_0060'].get_item(jid, Data.namespace, id, return self.xmpp['xep_0060'].get_item(jid, Data.namespace, id,
@ -54,8 +57,7 @@ class XEP_0084(BasePlugin):
payload = Data() payload = Data()
payload['value'] = data payload['value'] = data
return self.xmpp['xep_0163'].publish(payload, return self.xmpp['xep_0163'].publish(payload,
node=Data.namespace, id=self.generate_id(data),
id=hashlib.sha1(data).hexdigest(),
ifrom=ifrom, ifrom=ifrom,
block=block, block=block,
callback=callback, callback=callback,
@ -72,12 +74,12 @@ class XEP_0084(BasePlugin):
height=info.get('height', ''), height=info.get('height', ''),
width=info.get('width', ''), width=info.get('width', ''),
url=info.get('url', '')) url=info.get('url', ''))
if pointers is not None:
for pointer in pointers: for pointer in pointers:
metadata.add_pointer(pointer) metadata.add_pointer(pointer)
return self.xmpp['xep_0163'].publish(payload, return self.xmpp['xep_0163'].publish(metadata,
node=Data.namespace,
id=hashlib.sha1(data).hexdigest(),
ifrom=ifrom, ifrom=ifrom,
block=block, block=block,
callback=callback, callback=callback,

View File

@ -43,7 +43,7 @@ class MetaData(ElementBase):
info = Info() info = Info()
info.values = {'id': id, info.values = {'id': id,
'type': itype, 'type': itype,
'bytes': ibytes, 'bytes': '%s' % ibytes,
'height': height, 'height': height,
'width': width, 'width': width,
'url': url} 'url': url}

View File

@ -0,0 +1,26 @@
"""
SleekXMPP: The Sleek XMPP Library
Copyright (C) 2012 Nathanael C. Fritz, Lance J.T. Stout
This file is part of SleekXMPP.
See the file LICENSE for copying permission.
"""
from sleekxmpp.plugins import BasePlugin, register_plugin
class XEP_0106(BasePlugin):
name = 'xep_0106'
description = 'XEP-0106: JID Escaping'
dependencies = set(['xep_0030'])
def session_bind(self, jid):
self.xmpp['xep_0030'].add_feature(feature='jid\\20escaping')
def plugin_end(self):
self.xmpp['xep_0030'].del_feature(feature='jid\\20escaping')
register_plugin(XEP_0106)

View File

@ -75,6 +75,9 @@ class XEP_0153(BasePlugin):
return stanza return stanza
def _reset_hash(self, jid=None): def _reset_hash(self, jid=None):
if jid is None:
jid = self.xmpp.boundjid
own_jid = (jid.bare == self.xmpp.boundjid.bare) own_jid = (jid.bare == self.xmpp.boundjid.bare)
if self.xmpp.is_component: if self.xmpp.is_component:
own_jid = (jid.domain == self.xmpp.boundjid.domain) own_jid = (jid.domain == self.xmpp.boundjid.domain)

View File

@ -52,7 +52,7 @@ class Error(ElementBase):
name = 'error' name = 'error'
plugin_attrib = 'error' plugin_attrib = 'error'
interfaces = set(('code', 'condition', 'text', 'type', interfaces = set(('code', 'condition', 'text', 'type',
'gone', 'redirect')) 'gone', 'redirect', 'by'))
sub_interfaces = set(('text',)) sub_interfaces = set(('text',))
plugin_attrib_map = {} plugin_attrib_map = {}
plugin_tag_map = {} plugin_tag_map = {}

View File

@ -8,10 +8,8 @@
import socket import socket
import threading import threading
try:
import queue from sleekxmpp.util import Queue
except ImportError:
import Queue as queue
class TestLiveSocket(object): class TestLiveSocket(object):
@ -39,8 +37,8 @@ class TestLiveSocket(object):
""" """
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.recv_buffer = [] self.recv_buffer = []
self.recv_queue = queue.Queue() self.recv_queue = Queue()
self.send_queue = queue.Queue() self.send_queue = Queue()
self.send_queue_lock = threading.Lock() self.send_queue_lock = threading.Lock()
self.recv_queue_lock = threading.Lock() self.recv_queue_lock = threading.Lock()
self.is_live = True self.is_live = True

View File

@ -7,10 +7,8 @@
""" """
import socket import socket
try:
import queue from sleekxmpp.util import Queue
except ImportError:
import Queue as queue
class TestSocket(object): class TestSocket(object):
@ -36,8 +34,8 @@ class TestSocket(object):
Same as arguments for socket.socket Same as arguments for socket.socket
""" """
self.socket = socket.socket(*args, **kwargs) self.socket = socket.socket(*args, **kwargs)
self.recv_queue = queue.Queue() self.recv_queue = Queue()
self.send_queue = queue.Queue() self.send_queue = Queue()
self.is_live = False self.is_live = False
self.disconnected = False self.disconnected = False

View File

@ -8,13 +8,10 @@
import unittest import unittest
from xml.parsers.expat import ExpatError from xml.parsers.expat import ExpatError
try:
import Queue as queue
except:
import queue
import sleekxmpp import sleekxmpp
from sleekxmpp import ClientXMPP, ComponentXMPP from sleekxmpp import ClientXMPP, ComponentXMPP
from sleekxmpp.util import Queue
from sleekxmpp.stanza import Message, Iq, Presence from sleekxmpp.stanza import Message, Iq, Presence
from sleekxmpp.test import TestSocket, TestLiveSocket from sleekxmpp.test import TestSocket, TestLiveSocket
from sleekxmpp.exceptions import XMPPError, IqTimeout, IqError from sleekxmpp.exceptions import XMPPError, IqTimeout, IqError
@ -338,7 +335,7 @@ class SleekTest(unittest.TestCase):
# We will use this to wait for the session_start event # We will use this to wait for the session_start event
# for live connections. # for live connections.
skip_queue = queue.Queue() skip_queue = Queue()
if socket == 'mock': if socket == 'mock':
self.xmpp.set_socket(TestSocket()) self.xmpp.set_socket(TestSocket())

View File

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""
sleekxmpp.util
~~~~~~~~~~~~~~
Part of SleekXMPP: The Sleek XMPP Library
:copyright: (c) 2012 Nathanael C. Fritz, Lance J.T. Stout
:license: MIT, see LICENSE for more details
"""
# =====================================================================
# Standardize import of Queue class:
try:
import queue
except ImportError:
import Queue as queue
Queue = queue.Queue
QueueEmpty = queue.Empty

View File

@ -0,0 +1,119 @@
from __future__ import unicode_literals
import sys
import stringprep
import unicodedata
class StringPrepError(UnicodeError):
pass
def to_unicode(data):
if sys.version_info < (3, 0):
return unicode(data)
else:
return str(data)
def b1_mapping(char):
return '' if stringprep.in_table_c12(char) else None
def c12_mapping(char):
return ' ' if stringprep.in_table_c12(char) else None
def map_input(data, tables=None):
"""
Each character in the input stream MUST be checked against
a mapping table.
"""
result = []
for char in data:
replacement = None
for mapping in tables:
replacement = mapping(char)
if replacement is not None:
break
if replacement is None:
replacement = char
result.append(replacement)
return ''.join(result)
def normalize(data, nfkc=True):
"""
A profile can specify one of two options for Unicode normalization:
- no normalization
- Unicode normalization with form KC
"""
if nfkc:
data = unicodedata.normalize('NFKC', data)
return data
def prohibit_output(data, tables=None):
"""
Before the text can be emitted, it MUST be checked for prohibited
code points.
"""
for char in data:
for check in tables:
if check(char):
raise StringPrepError("Prohibited code point: %s" % char)
def check_bidi(data):
"""
1) The characters in section 5.8 MUST be prohibited.
2) If a string contains any RandALCat character, the string MUST NOT
contain any LCat character.
3) If a string contains any RandALCat character, a RandALCat
character MUST be the first character of the string, and a
RandALCat character MUST be the last character of the string.
"""
if not data:
return data
has_lcat = False
has_randal = False
for c in data:
if stringprep.in_table_c8(c):
raise StringPrepError("BIDI violation: seciton 6 (1)")
if stringprep.in_table_d1(c):
has_randal = True
elif stringprep.in_table_d2(c):
has_lcat = True
if has_randal and has_lcat:
raise StringPrepError("BIDI violation: section 6 (2)")
first_randal = stringprep.in_table_d1(data[0])
last_randal = stringprep.in_table_d1(data[-1])
if has_randal and not (first_randal and last_randal):
raise StringPrepError("BIDI violation: section 6 (3)")
def create(nfkc=True, bidi=True, mappings=None,
prohibited=None, unassigned=None):
def profile(data, query=False):
try:
data = to_unicode(data)
except UnicodeError:
raise StringPrepError
data = map_input(data, mappings)
data = normalize(data, nfkc)
prohibit_output(data, prohibited)
if bidi:
check_bidi(data)
if query and unassigned:
check_unassigned(data, unassigned)
return data
return profile

View File

@ -6,7 +6,7 @@
See the file LICENSE for copying permission. See the file LICENSE for copying permission.
""" """
from sleekxmpp.xmlstream.jid import JID from sleekxmpp.jid import JID
from sleekxmpp.xmlstream.scheduler import Scheduler from sleekxmpp.xmlstream.scheduler import Scheduler
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ElementBase, ET from sleekxmpp.xmlstream.stanzabase import StanzaBase, ElementBase, ET
from sleekxmpp.xmlstream.stanzabase import register_stanza_plugin from sleekxmpp.xmlstream.stanzabase import register_stanza_plugin

View File

@ -10,11 +10,8 @@
""" """
import logging import logging
try:
import queue
except ImportError:
import Queue as queue
from sleekxmpp.util import Queue, QueueEmpty
from sleekxmpp.xmlstream.handler.base import BaseHandler from sleekxmpp.xmlstream.handler.base import BaseHandler
@ -37,7 +34,7 @@ class Waiter(BaseHandler):
def __init__(self, name, matcher, stream=None): def __init__(self, name, matcher, stream=None):
BaseHandler.__init__(self, name, matcher, stream=stream) BaseHandler.__init__(self, name, matcher, stream=stream)
self._payload = queue.Queue() self._payload = Queue()
def prerun(self, payload): def prerun(self, payload):
"""Store the matched stanza when received during processing. """Store the matched stanza when received during processing.
@ -74,7 +71,7 @@ class Waiter(BaseHandler):
try: try:
stanza = self._payload.get(True, 1) stanza = self._payload.get(True, 1)
break break
except queue.Empty: except QueueEmpty:
elapsed_time += 1 elapsed_time += 1
if elapsed_time >= timeout: if elapsed_time >= timeout:
log.warning("Timed out waiting for %s", self.name) log.warning("Timed out waiting for %s", self.name)

View File

@ -1,148 +1,5 @@
# -*- coding: utf-8 -*- import logging
"""
sleekxmpp.xmlstream.jid
~~~~~~~~~~~~~~~~~~~~~~~
This module allows for working with Jabber IDs (JIDs) by logging.warning('Deprecated: sleekxmpp.xmlstream.jid is moving to sleekxmpp.jid')
providing accessors for the various components of a JID.
Part of SleekXMPP: The Sleek XMPP Library from sleekxmpp.jid import JID
:copyright: (c) 2011 Nathanael C. Fritz
:license: MIT, see LICENSE for more details
"""
from __future__ import unicode_literals
class JID(object):
"""
A representation of a Jabber ID, or JID.
Each JID may have three components: a user, a domain, and an optional
resource. For example: user@domain/resource
When a resource is not used, the JID is called a bare JID.
The JID is a full JID otherwise.
**JID Properties:**
:jid: Alias for ``full``.
:full: The value of the full JID.
:bare: The value of the bare JID.
:user: The username portion of the JID.
:domain: The domain name portion of the JID.
:server: Alias for ``domain``.
:resource: The resource portion of the JID.
:param string jid: A string of the form ``'[user@]domain[/resource]'``.
"""
def __init__(self, jid):
"""Initialize a new JID"""
self.reset(jid)
def reset(self, jid):
"""Start fresh from a new JID string.
:param string jid: A string of the form ``'[user@]domain[/resource]'``.
"""
if isinstance(jid, JID):
jid = jid.full
self._full = self._jid = jid
self._domain = None
self._resource = None
self._user = None
self._bare = None
def __getattr__(self, name):
"""Handle getting the JID values, using cache if available.
:param name: One of: user, server, domain, resource,
full, or bare.
"""
if name == 'resource':
if self._resource is None and '/' in self._jid:
self._resource = self._jid.split('/', 1)[-1]
return self._resource or ""
elif name == 'user':
if self._user is None:
if '@' in self._jid:
self._user = self._jid.split('@', 1)[0]
else:
self._user = self._user
return self._user or ""
elif name in ('server', 'domain', 'host'):
if self._domain is None:
self._domain = self._jid.split('@', 1)[-1].split('/', 1)[0]
return self._domain or ""
elif name in ('full', 'jid'):
return self._jid or ""
elif name == 'bare':
if self._bare is None:
self._bare = self._jid.split('/', 1)[0]
return self._bare or ""
def __setattr__(self, name, value):
"""Edit a JID by updating it's individual values, resetting the
generated JID in the end.
Arguments:
name -- The name of the JID part. One of: user, domain,
server, resource, full, jid, or bare.
value -- The new value for the JID part.
"""
if name in ('resource', 'user', 'domain'):
object.__setattr__(self, "_%s" % name, value)
self.regenerate()
elif name in ('server', 'domain', 'host'):
self.domain = value
elif name in ('full', 'jid'):
self.reset(value)
self.regenerate()
elif name == 'bare':
if '@' in value:
u, d = value.split('@', 1)
object.__setattr__(self, "_user", u)
object.__setattr__(self, "_domain", d)
else:
object.__setattr__(self, "_user", '')
object.__setattr__(self, "_domain", value)
self.regenerate()
else:
object.__setattr__(self, name, value)
def regenerate(self):
"""Generate a new JID based on current values, useful after editing."""
jid = ""
if self.user:
jid = "%s@" % self.user
jid += self.domain
if self.resource:
jid += "/%s" % self.resource
self.reset(jid)
def __str__(self):
"""Use the full JID as the string value."""
return self.full
def __repr__(self):
return self.full
def __eq__(self, other):
"""
Two JIDs are considered equal if they have the same full JID value.
"""
other = JID(other)
return self.full == other.full
def __ne__(self, other):
"""Two JIDs are considered unequal if they are not equal."""
return not self == other
def __hash__(self):
"""Hash a JID based on the string version of its full JID."""
return hash(self.full)
def __copy__(self):
return JID(self.jid)

View File

@ -15,10 +15,8 @@
import time import time
import threading import threading
import logging import logging
try:
import queue from sleekxmpp.util import Queue, QueueEmpty
except ImportError:
import Queue as queue
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -102,7 +100,7 @@ class Scheduler(object):
def __init__(self, parentstop=None): def __init__(self, parentstop=None):
#: A queue for storing tasks #: A queue for storing tasks
self.addq = queue.Queue() self.addq = Queue()
#: A list of tasks in order of execution time. #: A list of tasks in order of execution time.
self.schedule = [] self.schedule = []
@ -157,7 +155,7 @@ class Scheduler(object):
elapsed < wait: elapsed < wait:
newtask = self.addq.get(True, 0.1) newtask = self.addq.get(True, 0.1)
elapsed += 0.1 elapsed += 0.1
except queue.Empty: except QueueEmpty:
cleanup = [] cleanup = []
self.schedule_lock.acquire() self.schedule_lock.acquire()
for task in self.schedule: for task in self.schedule:

View File

@ -63,9 +63,11 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None,
default_ns = '' default_ns = ''
stream_ns = '' stream_ns = ''
use_cdata = False
if stream: if stream:
default_ns = stream.default_ns default_ns = stream.default_ns
stream_ns = stream.stream_ns stream_ns = stream.stream_ns
use_cdata = stream.use_cdata
# Output the tag name and derived namespace of the element. # Output the tag name and derived namespace of the element.
namespace = '' namespace = ''
@ -81,7 +83,7 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None,
# Output escaped attribute values. # Output escaped attribute values.
for attrib, value in xml.attrib.items(): for attrib, value in xml.attrib.items():
value = xml_escape(value) value = escape(value, use_cdata)
if '}' not in attrib: if '}' not in attrib:
output.append(' %s="%s"' % (attrib, value)) output.append(' %s="%s"' % (attrib, value))
else: else:
@ -105,24 +107,24 @@ def tostring(xml=None, xmlns='', stanza_ns='', stream=None,
# If there are additional child elements to serialize. # If there are additional child elements to serialize.
output.append(">") output.append(">")
if xml.text: if xml.text:
output.append(xml_escape(xml.text)) output.append(escape(xml.text, use_cdata))
if len(xml): if len(xml):
for child in xml: for child in xml:
output.append(tostring(child, tag_xmlns, stanza_ns, stream)) output.append(tostring(child, tag_xmlns, stanza_ns, stream))
output.append("</%s>" % tag_name) output.append("</%s>" % tag_name)
elif xml.text: elif xml.text:
# If we only have text content. # If we only have text content.
output.append(">%s</%s>" % (xml_escape(xml.text), tag_name)) output.append(">%s</%s>" % (escape(xml.text, use_cdata), tag_name))
else: else:
# Empty element. # Empty element.
output.append(" />") output.append(" />")
if xml.tail: if xml.tail:
# If there is additional text after the element. # If there is additional text after the element.
output.append(xml_escape(xml.tail)) output.append(escape(xml.tail, use_cdata))
return ''.join(output) return ''.join(output)
def xml_escape(text): def escape(text, use_cdata=False):
"""Convert special characters in XML to escape sequences. """Convert special characters in XML to escape sequences.
:param string text: The XML text to convert. :param string text: The XML text to convert.
@ -132,12 +134,24 @@ def xml_escape(text):
if type(text) != types.UnicodeType: if type(text) != types.UnicodeType:
text = unicode(text, 'utf-8', 'ignore') text = unicode(text, 'utf-8', 'ignore')
text = list(text)
escapes = {'&': '&amp;', escapes = {'&': '&amp;',
'<': '&lt;', '<': '&lt;',
'>': '&gt;', '>': '&gt;',
"'": '&apos;', "'": '&apos;',
'"': '&quot;'} '"': '&quot;'}
if not use_cdata:
text = list(text)
for i, c in enumerate(text): for i, c in enumerate(text):
text[i] = escapes.get(c, c) text[i] = escapes.get(c, c)
return ''.join(text) return ''.join(text)
else:
escape_needed = False
for c in text:
if c in escapes:
escape_needed = True
break
if escape_needed:
escaped = map(lambda x : "<![CDATA[%s]]>" % x, text.split("]]>"))
return "<![CDATA[]]]><![CDATA[]>]]>".join(escaped)
return text

View File

@ -26,14 +26,11 @@ import time
import random import random
import weakref import weakref
import uuid import uuid
try:
import queue
except ImportError:
import Queue as queue
from xml.parsers.expat import ExpatError from xml.parsers.expat import ExpatError
import sleekxmpp import sleekxmpp
from sleekxmpp.util import Queue, QueueEmpty
from sleekxmpp.thirdparty.statemachine import StateMachine from sleekxmpp.thirdparty.statemachine import StateMachine
from sleekxmpp.xmlstream import Scheduler, tostring, cert from sleekxmpp.xmlstream import Scheduler, tostring, cert
from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase from sleekxmpp.xmlstream.stanzabase import StanzaBase, ET, ElementBase
@ -215,6 +212,10 @@ class XMLStream(object):
#: If set to ``True``, attempt to use IPv6. #: If set to ``True``, attempt to use IPv6.
self.use_ipv6 = True self.use_ipv6 = True
#: Use CDATA for escaping instead of XML entities. Defaults
#: to ``False``.
self.use_cdata = False
#: An optional dictionary of proxy settings. It may provide: #: An optional dictionary of proxy settings. It may provide:
#: :host: The host offering proxy services. #: :host: The host offering proxy services.
#: :port: The port for the proxy service. #: :port: The port for the proxy service.
@ -270,10 +271,10 @@ class XMLStream(object):
self.end_session_on_disconnect = True self.end_session_on_disconnect = True
#: A queue of stream, custom, and scheduled events to be processed. #: A queue of stream, custom, and scheduled events to be processed.
self.event_queue = queue.Queue() self.event_queue = Queue()
#: A queue of string data to be sent over the stream. #: A queue of string data to be sent over the stream.
self.send_queue = queue.Queue() self.send_queue = Queue()
self.send_queue_lock = threading.Lock() self.send_queue_lock = threading.Lock()
self.send_lock = threading.RLock() self.send_lock = threading.RLock()
@ -1586,7 +1587,7 @@ class XMLStream(object):
try: try:
wait = self.wait_timeout wait = self.wait_timeout
event = self.event_queue.get(True, timeout=wait) event = self.event_queue.get(True, timeout=wait)
except queue.Empty: except QueueEmpty:
event = None event = None
if event is None: if event is None:
continue continue
@ -1655,7 +1656,7 @@ class XMLStream(object):
else: else:
try: try:
data = self.send_queue.get(True, 1) data = self.send_queue.get(True, 1)
except queue.Empty: except QueueEmpty:
continue continue
log.debug("SEND: %s", data) log.debug("SEND: %s", data)
enc_data = data.encode('utf-8') enc_data = data.encode('utf-8')

View File

@ -1,5 +1,5 @@
from sleekxmpp.test import * from sleekxmpp.test import *
from sleekxmpp.xmlstream.jid import JID from sleekxmpp import JID, InvalidJID
class TestJIDClass(SleekTest): class TestJIDClass(SleekTest):
@ -137,5 +137,146 @@ class TestJIDClass(SleekTest):
self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal") self.assertFalse(jid1 == jid2, "Same JIDs are not considered equal")
self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal") self.assertTrue(jid1 != jid2, "Same JIDs are considered not equal")
def testZeroLengthDomain(self):
self.assertRaises(InvalidJID, JID, domain='')
self.assertRaises(InvalidJID, JID, 'user@/resource')
def testZeroLengthLocalPart(self):
self.assertRaises(InvalidJID, JID, local='', domain='test.com')
self.assertRaises(InvalidJID, JID, '@/test.com')
def testZeroLengthResource(self):
self.assertRaises(InvalidJID, JID, domain='test.com', resource='')
self.assertRaises(InvalidJID, JID, 'test.com/')
def test1023LengthDomain(self):
domain = ('a.' * 509) + 'a.com'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
def test1023LengthLocalPart(self):
local = 'a' * 1023
jid1 = JID(local=local, domain='test.com')
jid2 = JID('%s@test.com' % local)
def test1023LengthResource(self):
resource = 'r' * 1023
jid1 = JID(domain='test.com', resource=resource)
jid2 = JID('test.com/%s' % resource)
def test1024LengthDomain(self):
domain = ('a.' * 509) + 'aa.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def test1024LengthLocalPart(self):
local = 'a' * 1024
self.assertRaises(InvalidJID, JID, local=local, domain='test.com')
self.assertRaises(InvalidJID, JID, '%s@/test.com' % local)
def test1024LengthResource(self):
resource = 'r' * 1024
self.assertRaises(InvalidJID, JID, domain='test.com', resource=resource)
self.assertRaises(InvalidJID, JID, 'test.com/%s' % resource)
def testTooLongDomainLabel(self):
domain = ('a' * 64) + '.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testDomainEmptyLabel(self):
domain = 'aaa..bbb.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testDomainIPv4(self):
domain = '127.0.0.1'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
def testDomainIPv6(self):
domain = '[::1]'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
def testDomainInvalidIPv6NoBrackets(self):
domain = '::1'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain, '[::1]')
self.assertEqual(jid2.domain, '[::1]')
def testDomainInvalidIPv6MissingBracket(self):
domain = '[::1'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain, '[::1]')
self.assertEqual(jid2.domain, '[::1]')
def testDomainWithPort(self):
domain = 'example.com:5555'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testDomainWithTrailingDot(self):
domain = 'example.com.'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain, 'example.com')
self.assertEqual(jid2.domain, 'example.com')
def testDomainWithDashes(self):
domain = 'example.com-'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
domain = '-example.com'
self.assertRaises(InvalidJID, JID, domain=domain)
self.assertRaises(InvalidJID, JID, 'user@%s/resource' % domain)
def testACEDomain(self):
domain = 'xn--bcher-kva.ch'
jid1 = JID(domain=domain)
jid2 = JID('user@%s/resource' % domain)
self.assertEqual(jid1.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
self.assertEqual(jid2.domain.encode('utf-8'), b'b\xc3\xbccher.ch')
def testJIDEscapeExistingSequences(self):
jid = JID(local='blah\\foo\\20bar', domain='example.com')
self.assertEqual(jid.local, 'blah\\foo\\5c20bar')
def testJIDEscape(self):
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
domain='example.com')
self.assertEqual(jid.local, r'here\27s_a_wild_\26_\2fcr%zy\2f_address_for\3a\3cwv\3e(\22IMPS\22)')
def testJIDUnescape(self):
jid = JID(local='here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")',
domain='example.com')
ujid = jid.unescape()
self.assertEqual(ujid.local, 'here\'s_a_wild_&_/cr%zy/_address_for:<wv>("IMPS")')
jid = JID(local='blah\\foo\\20bar', domain='example.com')
ujid = jid.unescape()
self.assertEqual(ujid.local, 'blah\\foo\\20bar')
def testStartOrEndWithEscapedSpaces(self):
local = ' foo'
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
local = 'bar '
self.assertRaises(InvalidJID, JID, local=local, domain='example.com')
self.assertRaises(InvalidJID, JID, '%s@example.com' % local)
# Need more input for these cases. A JID starting with \20 *is* valid
# according to RFC 6122, but is not according to XEP-0106.
#self.assertRaises(InvalidJID, JID, '%s@example.com' % '\\20foo2')
#self.assertRaises(InvalidJID, JID, '%s@example.com' % 'bar2\\20')
suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass) suite = unittest.TestLoader().loadTestsFromTestCase(TestJIDClass)

View File

@ -1,7 +1,7 @@
from sleekxmpp.test import * from sleekxmpp.test import *
from sleekxmpp.stanza import Message from sleekxmpp.stanza import Message
from sleekxmpp.xmlstream.stanzabase import ET, ElementBase from sleekxmpp.xmlstream.stanzabase import ET, ElementBase
from sleekxmpp.xmlstream.tostring import tostring, xml_escape from sleekxmpp.xmlstream.tostring import tostring, escape
class TestToString(SleekTest): class TestToString(SleekTest):
@ -30,7 +30,7 @@ class TestToString(SleekTest):
def testXMLEscape(self): def testXMLEscape(self):
"""Test escaping XML special characters.""" """Test escaping XML special characters."""
original = """<foo bar="baz">'Hi & welcome!'</foo>""" original = """<foo bar="baz">'Hi & welcome!'</foo>"""
escaped = xml_escape(original) escaped = escape(original)
desired = """&lt;foo bar=&quot;baz&quot;&gt;&apos;Hi""" desired = """&lt;foo bar=&quot;baz&quot;&gt;&apos;Hi"""
desired += """ &amp; welcome!&apos;&lt;/foo&gt;""" desired += """ &amp; welcome!&apos;&lt;/foo&gt;"""