Update XEP-0198 for asyncio

This commit is contained in:
mathieui 2016-06-04 20:51:59 +02:00
parent ffced0ed9a
commit 8fc6814b6d

View File

@ -6,8 +6,8 @@
See the file LICENSE for copying permission. See the file LICENSE for copying permission.
""" """
import asyncio
import logging import logging
import threading
import collections import collections
from slixmpp.stanza import Message, Presence, Iq, StreamFeatures from slixmpp.stanza import Message, Presence, Iq, StreamFeatures
@ -70,15 +70,10 @@ class XEP_0198(BasePlugin):
return return
self.window_counter = self.window self.window_counter = self.window
self.window_counter_lock = threading.Lock()
self.enabled = threading.Event() self.enabled = False
self.unacked_queue = collections.deque() self.unacked_queue = collections.deque()
self.seq_lock = threading.Lock()
self.handled_lock = threading.Lock()
self.ack_lock = threading.Lock()
register_stanza_plugin(StreamFeatures, stanza.StreamManagement) register_stanza_plugin(StreamFeatures, stanza.StreamManagement)
self.xmpp.register_stanza(stanza.Enable) self.xmpp.register_stanza(stanza.Enable)
self.xmpp.register_stanza(stanza.Enabled) self.xmpp.register_stanza(stanza.Enabled)
@ -161,7 +156,7 @@ class XEP_0198(BasePlugin):
def session_end(self, event): def session_end(self, event):
"""Reset stream management state.""" """Reset stream management state."""
self.enabled.clear() self.enabled = False
self.unacked_queue.clear() self.unacked_queue.clear()
self.sm_id = None self.sm_id = None
self.handled = 0 self.handled = 0
@ -171,15 +166,15 @@ class XEP_0198(BasePlugin):
def send_ack(self): def send_ack(self):
"""Send the current ack count to the server.""" """Send the current ack count to the server."""
ack = stanza.Ack(self.xmpp) ack = stanza.Ack(self.xmpp)
with self.handled_lock:
ack['h'] = self.handled ack['h'] = self.handled
self.xmpp.send_raw(str(ack)) self.xmpp.send_raw(str(ack))
def request_ack(self, e=None): def request_ack(self, e=None):
"""Request an ack from the server.""" """Request an ack from the server."""
req = stanza.RequestAck(self.xmpp) req = stanza.RequestAck(self.xmpp)
self.xmpp.send_queue.put(str(req)) self.xmpp.send_raw(str(req))
@asyncio.coroutine
def _handle_sm_feature(self, features): def _handle_sm_feature(self, features):
""" """
Enable or resume stream management. Enable or resume stream management.
@ -196,13 +191,21 @@ class XEP_0198(BasePlugin):
return False return False
if not self.sm_id: if not self.sm_id:
if 'bind' in self.xmpp.features: if 'bind' in self.xmpp.features:
self.enabled.set()
enable = stanza.Enable(self.xmpp) enable = stanza.Enable(self.xmpp)
enable['resume'] = self.allow_resume enable['resume'] = self.allow_resume
enable.send() enable.send()
self.enabled = True
self.handled = 0 self.handled = 0
elif self.sm_id and self.allow_resume: self.unacked_queue.clear()
self.enabled.set()
waiter = Waiter('enabled_or_failed',
MatchMany([
MatchXPath(stanza.Enabled.tag_name()),
MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter)
result = yield from waiter.wait()
elif self.sm_id and self.allow_resume and 'bind' not in self.xmpp.features:
self.enabled = True
resume = stanza.Resume(self.xmpp) resume = stanza.Resume(self.xmpp)
resume['h'] = self.handled resume['h'] = self.handled
resume['previd'] = self.sm_id resume['previd'] = self.sm_id
@ -216,7 +219,7 @@ class XEP_0198(BasePlugin):
MatchXPath(stanza.Resumed.tag_name()), MatchXPath(stanza.Resumed.tag_name()),
MatchXPath(stanza.Failed.tag_name())])) MatchXPath(stanza.Failed.tag_name())]))
self.xmpp.register_handler(waiter) self.xmpp.register_handler(waiter)
result = waiter.wait() result = yield from waiter.wait()
if result is not None and result.name == 'resumed': if result is not None and result.name == 'resumed':
return True return True
return False return False
@ -250,7 +253,7 @@ class XEP_0198(BasePlugin):
Raises an :term:`sm_failed` event. Raises an :term:`sm_failed` event.
""" """
self.enabled.clear() self.enabled = False
self.unacked_queue.clear() self.unacked_queue.clear()
self.xmpp.event('sm_failed', stanza) self.xmpp.event('sm_failed', stanza)
@ -262,7 +265,6 @@ class XEP_0198(BasePlugin):
if ack['h'] == self.last_ack: if ack['h'] == self.last_ack:
return return
with self.ack_lock:
num_acked = (ack['h'] - self.last_ack) % MAX_SEQ num_acked = (ack['h'] - self.last_ack) % MAX_SEQ
num_unacked = len(self.unacked_queue) num_unacked = len(self.unacked_queue)
log.debug("Ack: %s, Last Ack: %s, " + \ log.debug("Ack: %s, Last Ack: %s, " + \
@ -273,6 +275,10 @@ class XEP_0198(BasePlugin):
num_unacked, num_unacked,
num_acked, num_acked,
num_unacked - num_acked) num_unacked - num_acked)
if num_acked > len(self.unacked_queue) or num_acked < 0:
log.error('Inconsistent sequence numbers from the server,'
' ignoring and replacing ours with them.')
num_acked = len(self.unacked_queue)
for x in range(num_acked): for x in range(num_acked):
seq, stanza = self.unacked_queue.popleft() seq, stanza = self.unacked_queue.popleft()
self.xmpp.event('stanza_acked', stanza) self.xmpp.event('stanza_acked', stanza)
@ -284,28 +290,25 @@ class XEP_0198(BasePlugin):
def _handle_incoming(self, stanza): def _handle_incoming(self, stanza):
"""Increment the handled counter for each inbound stanza.""" """Increment the handled counter for each inbound stanza."""
if not self.enabled.is_set(): if not self.enabled:
return stanza return stanza
if isinstance(stanza, (Message, Presence, Iq)): if isinstance(stanza, (Message, Presence, Iq)):
with self.handled_lock:
# Sequence numbers are mod 2^32 # Sequence numbers are mod 2^32
self.handled = (self.handled + 1) % MAX_SEQ self.handled = (self.handled + 1) % MAX_SEQ
return stanza return stanza
def _handle_outgoing(self, stanza): def _handle_outgoing(self, stanza):
"""Store outgoing stanzas in a queue to be acked.""" """Store outgoing stanzas in a queue to be acked."""
if not self.enabled.is_set(): if not self.enabled:
return stanza return stanza
if isinstance(stanza, (Message, Presence, Iq)): if isinstance(stanza, (Message, Presence, Iq)):
seq = None seq = None
with self.seq_lock:
# Sequence numbers are mod 2^32 # Sequence numbers are mod 2^32
self.seq = (self.seq + 1) % MAX_SEQ self.seq = (self.seq + 1) % MAX_SEQ
seq = self.seq seq = self.seq
self.unacked_queue.append((seq, stanza)) self.unacked_queue.append((seq, stanza))
with self.window_counter_lock:
self.window_counter -= 1 self.window_counter -= 1
if self.window_counter == 0: if self.window_counter == 0:
self.window_counter = self.window self.window_counter = self.window