Merge branch 'async-filters' into 'master'
Add async filters on the send process See merge request poezio/slixmpp!24
This commit is contained in:
commit
115c234527
@ -352,6 +352,7 @@ class SlixTest(unittest.TestCase):
|
|||||||
header = self.xmpp.stream_header
|
header = self.xmpp.stream_header
|
||||||
|
|
||||||
self.xmpp.data_received(header)
|
self.xmpp.data_received(header)
|
||||||
|
self.wait_for_send_queue()
|
||||||
|
|
||||||
if skip:
|
if skip:
|
||||||
self.xmpp.socket.next_sent()
|
self.xmpp.socket.next_sent()
|
||||||
@ -599,6 +600,7 @@ class SlixTest(unittest.TestCase):
|
|||||||
'id', 'stanzapath', 'xpath', and 'mask'.
|
'id', 'stanzapath', 'xpath', and 'mask'.
|
||||||
Defaults to the value of self.match_method.
|
Defaults to the value of self.match_method.
|
||||||
"""
|
"""
|
||||||
|
self.wait_for_send_queue()
|
||||||
sent = self.xmpp.socket.next_sent(timeout)
|
sent = self.xmpp.socket.next_sent(timeout)
|
||||||
if data is None and sent is None:
|
if data is None and sent is None:
|
||||||
return
|
return
|
||||||
@ -615,6 +617,14 @@ class SlixTest(unittest.TestCase):
|
|||||||
defaults=defaults,
|
defaults=defaults,
|
||||||
use_values=use_values)
|
use_values=use_values)
|
||||||
|
|
||||||
|
def wait_for_send_queue(self):
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
future = asyncio.ensure_future(self.xmpp.run_filters(), loop=loop)
|
||||||
|
queue = self.xmpp.waiting_queue
|
||||||
|
print(queue)
|
||||||
|
loop.run_until_complete(queue.join())
|
||||||
|
future.cancel()
|
||||||
|
|
||||||
def stream_close(self):
|
def stream_close(self):
|
||||||
"""
|
"""
|
||||||
Disconnect the dummy XMPP client.
|
Disconnect the dummy XMPP client.
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
:license: MIT, see LICENSE for more details
|
:license: MIT, see LICENSE for more details
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Set, Callable
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
@ -21,6 +21,8 @@ import ssl
|
|||||||
import weakref
|
import weakref
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from asyncio import iscoroutinefunction, wait
|
||||||
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
from slixmpp.xmlstream.asyncio import asyncio
|
from slixmpp.xmlstream.asyncio import asyncio
|
||||||
@ -32,6 +34,10 @@ from slixmpp.xmlstream.resolver import resolve, default_resolver
|
|||||||
RESPONSE_TIMEOUT = 30
|
RESPONSE_TIMEOUT = 30
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
class ContinueQueue(Exception):
|
||||||
|
"""
|
||||||
|
Exception raised in the send queue to "continue" from within an inner loop
|
||||||
|
"""
|
||||||
|
|
||||||
class NotConnectedError(Exception):
|
class NotConnectedError(Exception):
|
||||||
"""
|
"""
|
||||||
@ -83,6 +89,8 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
self.force_starttls = None
|
self.force_starttls = None
|
||||||
self.disable_starttls = None
|
self.disable_starttls = None
|
||||||
|
|
||||||
|
self.waiting_queue = asyncio.Queue()
|
||||||
|
|
||||||
# A dict of {name: handle}
|
# A dict of {name: handle}
|
||||||
self.scheduled_events = {}
|
self.scheduled_events = {}
|
||||||
|
|
||||||
@ -263,6 +271,10 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
localhost
|
localhost
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
asyncio.ensure_future(
|
||||||
|
self.run_filters(),
|
||||||
|
loop=self.loop,
|
||||||
|
)
|
||||||
self.disconnect_reason = None
|
self.disconnect_reason = None
|
||||||
self.cancel_connection_attempt()
|
self.cancel_connection_attempt()
|
||||||
if host and port:
|
if host and port:
|
||||||
@ -789,7 +801,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
|
|
||||||
# If the callback is a coroutine, schedule it instead of
|
# If the callback is a coroutine, schedule it instead of
|
||||||
# running it directly
|
# running it directly
|
||||||
if asyncio.iscoroutinefunction(handler_callback):
|
if iscoroutinefunction(handler_callback):
|
||||||
async def handler_callback_routine(cb):
|
async def handler_callback_routine(cb):
|
||||||
try:
|
try:
|
||||||
await cb(data)
|
await cb(data)
|
||||||
@ -888,11 +900,93 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
"""
|
"""
|
||||||
return xml
|
return xml
|
||||||
|
|
||||||
|
async def _continue_slow_send(
|
||||||
|
self,
|
||||||
|
task: asyncio.Task,
|
||||||
|
already_used: Set[Callable[[ElementBase], Optional[StanzaBase]]]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Used when an item in the send queue has taken too long to process.
|
||||||
|
|
||||||
|
This is away from the send queue and can take as much time as needed.
|
||||||
|
:param asyncio.Task task: the Task wrapping the coroutine
|
||||||
|
:param set already_used: Filters already used on this outgoing stanza
|
||||||
|
"""
|
||||||
|
data = await task
|
||||||
|
for filter in self.__filters['out']:
|
||||||
|
if filter in already_used:
|
||||||
|
continue
|
||||||
|
if iscoroutinefunction(filter):
|
||||||
|
data = await task
|
||||||
|
else:
|
||||||
|
data = filter(data)
|
||||||
|
if data is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(data, ElementBase):
|
||||||
|
for filter in self.__filters['out_sync']:
|
||||||
|
data = filter(data)
|
||||||
|
if data is None:
|
||||||
|
return
|
||||||
|
str_data = tostring(data.xml, xmlns=self.default_ns,
|
||||||
|
stream=self, top_level=True)
|
||||||
|
self.send_raw(str_data)
|
||||||
|
else:
|
||||||
|
self.send_raw(data)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_filters(self):
|
||||||
|
"""
|
||||||
|
Background loop that processes stanzas to send.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
(data, use_filters) = await self.waiting_queue.get()
|
||||||
|
try:
|
||||||
|
if isinstance(data, ElementBase):
|
||||||
|
if use_filters:
|
||||||
|
already_run_filters = set()
|
||||||
|
for filter in self.__filters['out']:
|
||||||
|
already_run_filters.add(filter)
|
||||||
|
if iscoroutinefunction(filter):
|
||||||
|
task = asyncio.create_task(filter(data))
|
||||||
|
completed, pending = await wait(
|
||||||
|
{task},
|
||||||
|
timeout=1,
|
||||||
|
)
|
||||||
|
if pending:
|
||||||
|
asyncio.ensure_future(
|
||||||
|
self._continue_slow_send(
|
||||||
|
task,
|
||||||
|
already_run_filters
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise Exception("Slow coro, rescheduling")
|
||||||
|
data = task.result()
|
||||||
|
else:
|
||||||
|
data = filter(data)
|
||||||
|
if data is None:
|
||||||
|
raise ContinueQueue('Empty stanza')
|
||||||
|
|
||||||
|
if isinstance(data, ElementBase):
|
||||||
|
if use_filters:
|
||||||
|
for filter in self.__filters['out_sync']:
|
||||||
|
data = filter(data)
|
||||||
|
if data is None:
|
||||||
|
raise ContinueQueue('Empty stanza')
|
||||||
|
str_data = tostring(data.xml, xmlns=self.default_ns,
|
||||||
|
stream=self, top_level=True)
|
||||||
|
self.send_raw(str_data)
|
||||||
|
else:
|
||||||
|
self.send_raw(data)
|
||||||
|
except ContinueQueue as exc:
|
||||||
|
log.debug('Stanza in send queue not sent: %s', exc)
|
||||||
|
except Exception:
|
||||||
|
log.error('Exception raised in send queue:', exc_info=True)
|
||||||
|
self.waiting_queue.task_done()
|
||||||
|
|
||||||
def send(self, data, use_filters=True):
|
def send(self, data, use_filters=True):
|
||||||
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
|
"""A wrapper for :meth:`send_raw()` for sending stanza objects.
|
||||||
|
|
||||||
May optionally block until an expected response is received.
|
|
||||||
|
|
||||||
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
|
:param data: The :class:`~slixmpp.xmlstream.stanzabase.ElementBase`
|
||||||
stanza to send on the stream.
|
stanza to send on the stream.
|
||||||
:param bool use_filters: Indicates if outgoing filters should be
|
:param bool use_filters: Indicates if outgoing filters should be
|
||||||
@ -900,24 +994,7 @@ class XMLStream(asyncio.BaseProtocol):
|
|||||||
filters is useful when resending stanzas.
|
filters is useful when resending stanzas.
|
||||||
Defaults to ``True``.
|
Defaults to ``True``.
|
||||||
"""
|
"""
|
||||||
if isinstance(data, ElementBase):
|
self.waiting_queue.put_nowait((data, use_filters))
|
||||||
if use_filters:
|
|
||||||
for filter in self.__filters['out']:
|
|
||||||
data = filter(data)
|
|
||||||
if data is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(data, ElementBase):
|
|
||||||
if use_filters:
|
|
||||||
for filter in self.__filters['out_sync']:
|
|
||||||
data = filter(data)
|
|
||||||
if data is None:
|
|
||||||
return
|
|
||||||
str_data = tostring(data.xml, xmlns=self.default_ns,
|
|
||||||
stream=self, top_level=True)
|
|
||||||
self.send_raw(str_data)
|
|
||||||
else:
|
|
||||||
self.send_raw(data)
|
|
||||||
|
|
||||||
def send_xml(self, data):
|
def send_xml(self, data):
|
||||||
"""Send an XML object on the stream
|
"""Send an XML object on the stream
|
||||||
|
@ -4,6 +4,7 @@ import sys
|
|||||||
import datetime
|
import datetime
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
import unittest
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from slixmpp.test import *
|
from slixmpp.test import *
|
||||||
@ -11,6 +12,7 @@ from slixmpp.xmlstream import ElementBase
|
|||||||
from slixmpp.plugins.xep_0323.device import Device
|
from slixmpp.plugins.xep_0323.device import Device
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip('')
|
||||||
class TestStreamSensorData(SlixTest):
|
class TestStreamSensorData(SlixTest):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user