Merge branch 'hacks' of github.com:tomstrummer/SleekXMPP into hacks
This commit is contained in:
commit
ba9633f8f7
@ -20,12 +20,12 @@ class Callback(base.BaseHandler):
|
|||||||
def prerun(self, payload): # prerun actually calls run?!? WTF! Then it gets run AGAIN!
|
def prerun(self, payload): # prerun actually calls run?!? WTF! Then it gets run AGAIN!
|
||||||
base.BaseHandler.prerun(self, payload)
|
base.BaseHandler.prerun(self, payload)
|
||||||
if self._instream:
|
if self._instream:
|
||||||
logging.debug('callback "%s" prerun', self.name)
|
# logging.debug('callback "%s" prerun', self.name)
|
||||||
self.run(payload, True)
|
self.run(payload, True)
|
||||||
|
|
||||||
def run(self, payload, instream=False):
|
def run(self, payload, instream=False):
|
||||||
if not self._instream or instream:
|
if not self._instream or instream:
|
||||||
logging.debug('callback "%s" run', self.name)
|
# logging.debug('callback "%s" run', self.name)
|
||||||
base.BaseHandler.run(self, payload)
|
base.BaseHandler.run(self, payload)
|
||||||
#if self._thread:
|
#if self._thread:
|
||||||
# x = threading.Thread(name="Callback_%s" % self.name, target=self._pointer, args=(payload,))
|
# x = threading.Thread(name="Callback_%s" % self.name, target=self._pointer, args=(payload,))
|
||||||
|
@ -5,27 +5,31 @@
|
|||||||
|
|
||||||
See the file license.txt for copying permission.
|
See the file license.txt for copying permission.
|
||||||
"""
|
"""
|
||||||
from __future__ import with_statement
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StateMachine(object):
|
class StateMachine(object):
|
||||||
|
|
||||||
def __init__(self, states=[]):
|
def __init__(self, states=[]):
|
||||||
self.lock = threading.Condition(threading.RLock())
|
self.lock = threading.Lock()
|
||||||
|
self.notifier = threading.Event()
|
||||||
self.__states= []
|
self.__states= []
|
||||||
self.addStates(states)
|
self.addStates(states)
|
||||||
self.__default_state = self.__states[0]
|
self.__default_state = self.__states[0]
|
||||||
self.__current_state = self.__default_state
|
self.__current_state = self.__default_state
|
||||||
|
|
||||||
def addStates(self, states):
|
def addStates(self, states):
|
||||||
with self.lock:
|
self.lock.acquire()
|
||||||
|
try:
|
||||||
for state in states:
|
for state in states:
|
||||||
if state in self.__states:
|
if state in self.__states:
|
||||||
raise IndexError("The state '%s' is already in the StateMachine." % state)
|
raise IndexError("The state '%s' is already in the StateMachine." % state)
|
||||||
self.__states.append( state )
|
self.__states.append( state )
|
||||||
|
finally: self.lock.release()
|
||||||
|
|
||||||
|
|
||||||
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
|
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
|
||||||
@ -78,30 +82,33 @@ class StateMachine(object):
|
|||||||
if not to_state in self.__states:
|
if not to_state in self.__states:
|
||||||
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
|
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
|
||||||
|
|
||||||
with self.lock:
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
while not self.__current_state in from_states:
|
while not self.__current_state in from_states or not self.lock.acquire(False):
|
||||||
# detect timeout:
|
# detect timeout:
|
||||||
if time.time() >= start + wait: return False
|
if time.time() >= start + wait: return False
|
||||||
self.lock.wait(wait)
|
self.notifier.wait(wait)
|
||||||
|
|
||||||
|
try: # lock is acquired; all other threads will return false or wait until notify/timeout
|
||||||
|
self.notifier.clear()
|
||||||
if self.__current_state in from_states: # should always be True due to lock
|
if self.__current_state in from_states: # should always be True due to lock
|
||||||
|
|
||||||
return_val = True
|
|
||||||
# Note that func might throw an exception, but that's OK, it aborts the transition
|
# Note that func might throw an exception, but that's OK, it aborts the transition
|
||||||
if func is not None: return_val = func(*args,**kwargs)
|
return_val = func(*args,**kwargs) if func is not None else True
|
||||||
|
|
||||||
# some 'false' value returned from func,
|
# some 'false' value returned from func,
|
||||||
# indicating that transition should not occur:
|
# indicating that transition should not occur:
|
||||||
if not return_val: return return_val
|
if not return_val: return return_val
|
||||||
|
|
||||||
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
|
log.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
|
||||||
self.__current_state = to_state
|
self._set_state( to_state )
|
||||||
self.lock.notify_all()
|
|
||||||
return return_val # some 'true' value returned by func or True if func was None
|
return return_val # some 'true' value returned by func or True if func was None
|
||||||
else:
|
else:
|
||||||
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
|
log.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
|
||||||
return False
|
return False
|
||||||
|
finally:
|
||||||
|
self.notifier.set()
|
||||||
|
self.lock.release()
|
||||||
|
|
||||||
|
|
||||||
def transition_ctx(self, from_state, to_state, wait=0.0):
|
def transition_ctx(self, from_state, to_state, wait=0.0):
|
||||||
@ -148,7 +155,15 @@ class StateMachine(object):
|
|||||||
|
|
||||||
def ensure_any(self, states, wait=0.0):
|
def ensure_any(self, states, wait=0.0):
|
||||||
'''
|
'''
|
||||||
Ensure we are currently in one of the given `states`
|
Ensure we are currently in one of the given `states` or wait until
|
||||||
|
we enter one of those states.
|
||||||
|
|
||||||
|
Note that due to the nature of the function, you cannot guarantee that
|
||||||
|
the entirety of some operation completes while you remain in a given
|
||||||
|
state. That would require acquiring and holding a lock, which
|
||||||
|
would mean no other threads could do the same. (You'd essentially
|
||||||
|
be serializing all of the threads that are 'ensuring' their tasks
|
||||||
|
occurred in some state.
|
||||||
'''
|
'''
|
||||||
if not (isinstance(states,tuple) or isinstance(states,list)):
|
if not (isinstance(states,tuple) or isinstance(states,list)):
|
||||||
raise ValueError('states arg should be a tuple or list')
|
raise ValueError('states arg should be a tuple or list')
|
||||||
@ -157,13 +172,17 @@ class StateMachine(object):
|
|||||||
if not state in self.__states:
|
if not state in self.__states:
|
||||||
raise ValueError( "StateMachine does not contain state '%s'" % state )
|
raise ValueError( "StateMachine does not contain state '%s'" % state )
|
||||||
|
|
||||||
with self.lock:
|
# Locking never really gained us anything here, since the lock was released
|
||||||
|
# before the function returned anyways. The only thing it _did_ do was
|
||||||
|
# increase the probability that this function would block for longer than
|
||||||
|
# intended if a `transition` function or context was running while holding
|
||||||
|
# the lock.
|
||||||
start = time.time()
|
start = time.time()
|
||||||
while not self.__current_state in states:
|
while not self.__current_state in states:
|
||||||
# detect timeout:
|
# detect timeout:
|
||||||
if time.time() >= start + wait: return False
|
if time.time() >= start + wait: return False
|
||||||
self.lock.wait(wait)
|
self.notifier.wait(wait)
|
||||||
return self.__current_state in states # should always be True due to lock
|
return True
|
||||||
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@ -202,33 +221,36 @@ class _StateCtx:
|
|||||||
self.from_state = from_state
|
self.from_state = from_state
|
||||||
self.to_state = to_state
|
self.to_state = to_state
|
||||||
self.wait = wait
|
self.wait = wait
|
||||||
self._timeout = False
|
self._locked = False
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.state_machine.lock.acquire()
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
while not self.state_machine[ self.from_state ]:
|
while not self.state_machine[ self.from_state ] or not self.state_machine.lock.acquire(False):
|
||||||
# detect timeout:
|
# detect timeout:
|
||||||
if time.time() >= start + self.wait:
|
if time.time() >= start + self.wait:
|
||||||
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
|
log.debug('StateMachine timeout while waiting for state: %s', self.from_state )
|
||||||
self._timeout = True # to indicate we should not transition
|
|
||||||
return False
|
return False
|
||||||
self.state_machine.lock.wait(self.wait)
|
self.state_machine.notifier.wait(self.wait)
|
||||||
|
|
||||||
logging.debug('StateMachine entered context in state: %s',
|
self._locked = True # lock has been acquired at this point
|
||||||
|
self.state_machine.notifier.clear()
|
||||||
|
log.debug('StateMachine entered context in state: %s',
|
||||||
self.state_machine.current_state() )
|
self.state_machine.current_state() )
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
if exc_val is not None:
|
if exc_val is not None:
|
||||||
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
|
log.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
|
||||||
self.state_machine.current_state(), exc_type.__name__, exc_val )
|
self.state_machine.current_state(), exc_type.__name__, exc_val )
|
||||||
elif not self._timeout:
|
|
||||||
logging.debug(' ==== TRANSITION %s -> %s',
|
if self._locked:
|
||||||
|
if exc_val is None:
|
||||||
|
log.debug(' ==== TRANSITION %s -> %s',
|
||||||
self.state_machine.current_state(), self.to_state)
|
self.state_machine.current_state(), self.to_state)
|
||||||
self.state_machine._set_state( self.to_state )
|
self.state_machine._set_state( self.to_state )
|
||||||
|
|
||||||
self.state_machine.lock.notify_all()
|
self.state_machine.notifier.set()
|
||||||
self.state_machine.lock.release()
|
self.state_machine.lock.release()
|
||||||
|
|
||||||
return False # re-raise any exception
|
return False # re-raise any exception
|
||||||
|
|
||||||
|
@ -58,8 +58,7 @@ class XMLStream(object):
|
|||||||
global ssl_support
|
global ssl_support
|
||||||
self.ssl_support = ssl_support
|
self.ssl_support = ssl_support
|
||||||
self.escape_quotes = escape_quotes
|
self.escape_quotes = escape_quotes
|
||||||
self.state = statemachine.StateMachine(('disconnected','connecting',
|
self.state = statemachine.StateMachine(('disconnected','connected'))
|
||||||
'connected'))
|
|
||||||
self.should_reconnect = True
|
self.should_reconnect = True
|
||||||
|
|
||||||
self.setSocket(socket)
|
self.setSocket(socket)
|
||||||
@ -92,9 +91,11 @@ class XMLStream(object):
|
|||||||
def setSocket(self, socket):
|
def setSocket(self, socket):
|
||||||
"Set the socket"
|
"Set the socket"
|
||||||
self.socket = socket
|
self.socket = socket
|
||||||
if socket is not None and self.state.transition('disconnected','connecting'):
|
if socket is not None:
|
||||||
self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary
|
with self.state.transition_ctx('disconnected','connected') as locked:
|
||||||
self.state.transition('connecting','connected')
|
if not locked: raise Exception('Already connected')
|
||||||
|
# ElementTree.iterparse requires a file. 0 buffer files have to be binary
|
||||||
|
self.filesocket = socket.makefile('rb', 0)
|
||||||
|
|
||||||
def setFileSocket(self, filesocket):
|
def setFileSocket(self, filesocket):
|
||||||
self.filesocket = filesocket
|
self.filesocket = filesocket
|
||||||
@ -235,6 +236,9 @@ class XMLStream(object):
|
|||||||
logging.debug("System interrupt detected")
|
logging.debug("System interrupt detected")
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
self.eventqueue.put(('quit', None, None))
|
self.eventqueue.put(('quit', None, None))
|
||||||
|
except cElementTree.XMLParserError:
|
||||||
|
logging.warn('XML RCV parsing error!', exc_info=1)
|
||||||
|
# don't restart the stream on an XML parse error.
|
||||||
except:
|
except:
|
||||||
logging.exception('Unexpected error in RCV thread')
|
logging.exception('Unexpected error in RCV thread')
|
||||||
if self.should_reconnect:
|
if self.should_reconnect:
|
||||||
@ -352,11 +356,11 @@ class XMLStream(object):
|
|||||||
# TODO inefficient linear search; performance might be improved by hashtable lookup
|
# TODO inefficient linear search; performance might be improved by hashtable lookup
|
||||||
for handler in self.__handlers:
|
for handler in self.__handlers:
|
||||||
if handler.match(stanza):
|
if handler.match(stanza):
|
||||||
logging.debug('matched stanza to handler %s', handler.name)
|
# logging.debug('matched stanza to handler %s', handler.name)
|
||||||
handler.prerun(stanza)
|
handler.prerun(stanza)
|
||||||
self.eventqueue.put(('stanza', handler, stanza))
|
self.eventqueue.put(('stanza', handler, stanza))
|
||||||
if handler.checkDelete():
|
if handler.checkDelete():
|
||||||
logging.debug('deleting callback %s', handler.name)
|
# logging.debug('deleting callback %s', handler.name)
|
||||||
self.__handlers.pop(self.__handlers.index(handler))
|
self.__handlers.pop(self.__handlers.index(handler))
|
||||||
unhandled = False
|
unhandled = False
|
||||||
if unhandled:
|
if unhandled:
|
||||||
|
@ -256,6 +256,73 @@ class testStateMachine(unittest.TestCase):
|
|||||||
self.assertTrue( s['three'] )
|
self.assertTrue( s['three'] )
|
||||||
|
|
||||||
|
|
||||||
|
def testTransitionsDontUnintentionallyBlock(self):
|
||||||
|
'''
|
||||||
|
There was a bug where a long-running transition (e.g. one with a 'func'
|
||||||
|
arg or a `transition_ctx` call would cause any `transition` or `ensure`
|
||||||
|
call to block since the lock is acquired before checking the current
|
||||||
|
state. Attempts to acquire the mutex need to be non-blocking so when a
|
||||||
|
timeout is _not_ given, the caller can return immediately. At the same
|
||||||
|
time, threads that _do_ want to wait need the ability to be notified
|
||||||
|
(to avoid waiting beyond when the lock is released) so we've moved to a
|
||||||
|
combination of a plain-ol `threading.Lock` to act as mutex, and a
|
||||||
|
`threading.Event` to perform notification for threads who choose to wait.
|
||||||
|
'''
|
||||||
|
|
||||||
|
s = sm.StateMachine(('one','two','three'))
|
||||||
|
|
||||||
|
with s.transition_ctx('two','three') as result:
|
||||||
|
self.failIf( result )
|
||||||
|
self.assertTrue( s['one'] )
|
||||||
|
self.failIf( s.current_state in ('two','three') )
|
||||||
|
|
||||||
|
self.assertTrue( s['one'] )
|
||||||
|
|
||||||
|
statuses = {'t1':"not started",
|
||||||
|
't2':'not started'}
|
||||||
|
|
||||||
|
def t1():
|
||||||
|
print 'thread 1 started'
|
||||||
|
# no wait, so this should 'return False' immediately.
|
||||||
|
self.failIf( s.transition('two','three') )
|
||||||
|
statuses['t1'] = 'complete'
|
||||||
|
print 'thread 1 transitioned'
|
||||||
|
|
||||||
|
def t2():
|
||||||
|
print 'thread 2 started'
|
||||||
|
self.failIf( s['two'] )
|
||||||
|
self.failIf( s['three'] )
|
||||||
|
# we want this thread to acquire the lock, but for
|
||||||
|
# the second thread not to wait on the first.
|
||||||
|
with s.transition_ctx('one','two', 10) as locked:
|
||||||
|
statuses['t2'] = 'started'
|
||||||
|
print 'thread 2 has entered context'
|
||||||
|
self.assertTrue( locked )
|
||||||
|
# give thread1 a chance to complete while this
|
||||||
|
# thread still owns the lock
|
||||||
|
time.sleep(5)
|
||||||
|
self.assertTrue( s['two'] )
|
||||||
|
statuses['t2'] = 'complete'
|
||||||
|
|
||||||
|
t1 = threading.Thread(target=t1)
|
||||||
|
t2 = threading.Thread(target=t2)
|
||||||
|
|
||||||
|
t2.start() # this should acquire the lock
|
||||||
|
time.sleep(.2)
|
||||||
|
self.assertEqual( 'started', statuses['t2'] )
|
||||||
|
t1.start() # but it shouldn't prevent thread 1 from completing
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
self.assertEqual( 'complete', statuses['t1'] )
|
||||||
|
|
||||||
|
t1.join()
|
||||||
|
t2.join()
|
||||||
|
|
||||||
|
self.assertEqual( 'complete', statuses['t2'] )
|
||||||
|
|
||||||
|
self.assertTrue( s['two'] )
|
||||||
|
|
||||||
|
|
||||||
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
|
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
|
||||||
|
|
||||||
if __name__ == '__main__': unittest.main()
|
if __name__ == '__main__': unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user