xmlstream: purge send queue and pending tasks on session end

and keep track of slow tasks
This commit is contained in:
mathieui 2021-01-23 16:16:56 +01:00
parent 8700f8d162
commit 9fbd40578c

View File

@ -12,7 +12,15 @@
:license: MIT, see LICENSE for more details :license: MIT, see LICENSE for more details
""" """
from typing import Optional, Set, Callable, Any from typing import (
Any,
Callable,
Iterable,
List,
Optional,
Set,
Union,
)
import functools import functools
import logging import logging
@ -21,7 +29,7 @@ import ssl
import weakref import weakref
import uuid import uuid
from asyncio import iscoroutinefunction, wait from asyncio import iscoroutinefunction, wait, Future
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@ -230,6 +238,7 @@ class XMLStream(asyncio.BaseProtocol):
self.add_event_handler('session_start', self._start_keepalive) self.add_event_handler('session_start', self._start_keepalive)
self._run_filters = None self._run_filters = None
self.__slow_tasks: List[Future] = []
@property @property
def loop(self): def loop(self):
@ -465,6 +474,7 @@ class XMLStream(asyncio.BaseProtocol):
self.socket = None self.socket = None
# Fire the events after cleanup # Fire the events after cleanup
if self.end_session_on_disconnect: if self.end_session_on_disconnect:
self._reset_sendq()
self.event('session_end') self.event('session_end')
self.event("disconnected", self.disconnect_reason or exception and exception.strerror) self.event("disconnected", self.disconnect_reason or exception and exception.strerror)
@ -937,6 +947,18 @@ class XMLStream(asyncio.BaseProtocol):
""" """
return xml return xml
def _reset_sendq(self):
"""Clear sending tasks on session end"""
# Cancel all pending slow send tasks
log.debug('Cancelling %d slow send tasks', len(self.__slow_tasks))
for slow_task in self.__slow_tasks:
slow_task.cancel()
self.__slow_tasks.clear()
# Purge pending stanzas
while not self.waiting_queue.empty():
discarded = self.waiting_queue.get_nowait()
log.debug('Discarded stanza: %s', discarded)
async def _continue_slow_send( async def _continue_slow_send(
self, self,
task: asyncio.Task, task: asyncio.Task,
@ -950,6 +972,7 @@ class XMLStream(asyncio.BaseProtocol):
:param set already_used: Filters already used on this outgoing stanza :param set already_used: Filters already used on this outgoing stanza
""" """
data = await task data = await task
self.__slow_tasks.remove(task)
for filter in self.__filters['out']: for filter in self.__filters['out']:
if filter in already_used: if filter in already_used:
continue continue
@ -990,6 +1013,7 @@ class XMLStream(asyncio.BaseProtocol):
timeout=1, timeout=1,
) )
if pending: if pending:
self.slow_tasks.append(task)
asyncio.ensure_future( asyncio.ensure_future(
self._continue_slow_send( self._continue_slow_send(
task, task,