Merge branch 'hacks' of github.com:tomstrummer/SleekXMPP into hacks

This commit is contained in:
Thom Nichols 2010-07-01 17:06:50 -04:00
commit ba9633f8f7
4 changed files with 142 additions and 49 deletions

View File

@ -20,12 +20,12 @@ class Callback(base.BaseHandler):
def prerun(self, payload): # prerun actually calls run?!? WTF! Then it gets run AGAIN!
base.BaseHandler.prerun(self, payload)
if self._instream:
logging.debug('callback "%s" prerun', self.name)
# logging.debug('callback "%s" prerun', self.name)
self.run(payload, True)
def run(self, payload, instream=False):
if not self._instream or instream:
logging.debug('callback "%s" run', self.name)
# logging.debug('callback "%s" run', self.name)
base.BaseHandler.run(self, payload)
#if self._thread:
# x = threading.Thread(name="Callback_%s" % self.name, target=self._pointer, args=(payload,))

View File

@ -5,27 +5,31 @@
See the file license.txt for copying permission.
"""
from __future__ import with_statement
import threading
import time
import logging
log = logging.getLogger(__name__)
class StateMachine(object):
def __init__(self, states=[]):
self.lock = threading.Condition(threading.RLock())
self.lock = threading.Lock()
self.notifier = threading.Event()
self.__states= []
self.addStates(states)
self.__default_state = self.__states[0]
self.__current_state = self.__default_state
def addStates(self, states):
with self.lock:
self.lock.acquire()
try:
for state in states:
if state in self.__states:
raise IndexError("The state '%s' is already in the StateMachine." % state)
self.__states.append( state )
finally: self.lock.release()
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
@ -78,30 +82,33 @@ class StateMachine(object):
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
with self.lock:
start = time.time()
while not self.__current_state in from_states:
while not self.__current_state in from_states or not self.lock.acquire(False):
# detect timeout:
if time.time() >= start + wait: return False
self.lock.wait(wait)
self.notifier.wait(wait)
try: # lock is acquired; all other threads will return false or wait until notify/timeout
self.notifier.clear()
if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs)
return_val = func(*args,**kwargs) if func is not None else True
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state
self.lock.notify_all()
log.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self._set_state( to_state )
return return_val # some 'true' value returned by func or True if func was None
else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
log.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False
finally:
self.notifier.set()
self.lock.release()
def transition_ctx(self, from_state, to_state, wait=0.0):
@ -148,7 +155,15 @@ class StateMachine(object):
def ensure_any(self, states, wait=0.0):
'''
Ensure we are currently in one of the given `states`
Ensure we are currently in one of the given `states` or wait until
we enter one of those states.
Note that due to the nature of the function, you cannot guarantee that
the entirety of some operation completes while you remain in a given
state. That would require acquiring and holding a lock, which
would mean no other threads could do the same. (You'd essentially
be serializing all of the threads that are 'ensuring' their tasks
occurred in some state.
'''
if not (isinstance(states,tuple) or isinstance(states,list)):
raise ValueError('states arg should be a tuple or list')
@ -157,13 +172,17 @@ class StateMachine(object):
if not state in self.__states:
raise ValueError( "StateMachine does not contain state '%s'" % state )
with self.lock:
# Locking never really gained us anything here, since the lock was released
# before the function returned anyways. The only thing it _did_ do was
# increase the probability that this function would block for longer than
# intended if a `transition` function or context was running while holding
# the lock.
start = time.time()
while not self.__current_state in states:
# detect timeout:
if time.time() >= start + wait: return False
self.lock.wait(wait)
return self.__current_state in states # should always be True due to lock
self.notifier.wait(wait)
return True
def reset(self):
@ -202,33 +221,36 @@ class _StateCtx:
self.from_state = from_state
self.to_state = to_state
self.wait = wait
self._timeout = False
self._locked = False
def __enter__(self):
self.state_machine.lock.acquire()
start = time.time()
while not self.state_machine[ self.from_state ]:
while not self.state_machine[ self.from_state ] or not self.state_machine.lock.acquire(False):
# detect timeout:
if time.time() >= start + self.wait:
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
self._timeout = True # to indicate we should not transition
log.debug('StateMachine timeout while waiting for state: %s', self.from_state )
return False
self.state_machine.lock.wait(self.wait)
self.state_machine.notifier.wait(self.wait)
logging.debug('StateMachine entered context in state: %s',
self._locked = True # lock has been acquired at this point
self.state_machine.notifier.clear()
log.debug('StateMachine entered context in state: %s',
self.state_machine.current_state() )
return True
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is not None:
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
log.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
self.state_machine.current_state(), exc_type.__name__, exc_val )
elif not self._timeout:
logging.debug(' ==== TRANSITION %s -> %s',
if self._locked:
if exc_val is None:
log.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.lock.notify_all()
self.state_machine.notifier.set()
self.state_machine.lock.release()
return False # re-raise any exception

View File

@ -58,8 +58,7 @@ class XMLStream(object):
global ssl_support
self.ssl_support = ssl_support
self.escape_quotes = escape_quotes
self.state = statemachine.StateMachine(('disconnected','connecting',
'connected'))
self.state = statemachine.StateMachine(('disconnected','connected'))
self.should_reconnect = True
self.setSocket(socket)
@ -92,9 +91,11 @@ class XMLStream(object):
def setSocket(self, socket):
"Set the socket"
self.socket = socket
if socket is not None and self.state.transition('disconnected','connecting'):
self.filesocket = socket.makefile('rb', 0) # ElementTree.iterparse requires a file. 0 buffer files have to be binary
self.state.transition('connecting','connected')
if socket is not None:
with self.state.transition_ctx('disconnected','connected') as locked:
if not locked: raise Exception('Already connected')
# ElementTree.iterparse requires a file. 0 buffer files have to be binary
self.filesocket = socket.makefile('rb', 0)
def setFileSocket(self, filesocket):
self.filesocket = filesocket
@ -235,6 +236,9 @@ class XMLStream(object):
logging.debug("System interrupt detected")
self.shutdown()
self.eventqueue.put(('quit', None, None))
except cElementTree.XMLParserError:
logging.warn('XML RCV parsing error!', exc_info=1)
# don't restart the stream on an XML parse error.
except:
logging.exception('Unexpected error in RCV thread')
if self.should_reconnect:
@ -352,11 +356,11 @@ class XMLStream(object):
# TODO inefficient linear search; performance might be improved by hashtable lookup
for handler in self.__handlers:
if handler.match(stanza):
logging.debug('matched stanza to handler %s', handler.name)
# logging.debug('matched stanza to handler %s', handler.name)
handler.prerun(stanza)
self.eventqueue.put(('stanza', handler, stanza))
if handler.checkDelete():
logging.debug('deleting callback %s', handler.name)
# logging.debug('deleting callback %s', handler.name)
self.__handlers.pop(self.__handlers.index(handler))
unhandled = False
if unhandled:

View File

@ -256,6 +256,73 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( s['three'] )
def testTransitionsDontUnintentionallyBlock(self):
'''
There was a bug where a long-running transition (e.g. one with a 'func'
arg or a `transition_ctx` call would cause any `transition` or `ensure`
call to block since the lock is acquired before checking the current
state. Attempts to acquire the mutex need to be non-blocking so when a
timeout is _not_ given, the caller can return immediately. At the same
time, threads that _do_ want to wait need the ability to be notified
(to avoid waiting beyond when the lock is released) so we've moved to a
combination of a plain-ol `threading.Lock` to act as mutex, and a
`threading.Event` to perform notification for threads who choose to wait.
'''
s = sm.StateMachine(('one','two','three'))
with s.transition_ctx('two','three') as result:
self.failIf( result )
self.assertTrue( s['one'] )
self.failIf( s.current_state in ('two','three') )
self.assertTrue( s['one'] )
statuses = {'t1':"not started",
't2':'not started'}
def t1():
print 'thread 1 started'
# no wait, so this should 'return False' immediately.
self.failIf( s.transition('two','three') )
statuses['t1'] = 'complete'
print 'thread 1 transitioned'
def t2():
print 'thread 2 started'
self.failIf( s['two'] )
self.failIf( s['three'] )
# we want this thread to acquire the lock, but for
# the second thread not to wait on the first.
with s.transition_ctx('one','two', 10) as locked:
statuses['t2'] = 'started'
print 'thread 2 has entered context'
self.assertTrue( locked )
# give thread1 a chance to complete while this
# thread still owns the lock
time.sleep(5)
self.assertTrue( s['two'] )
statuses['t2'] = 'complete'
t1 = threading.Thread(target=t1)
t2 = threading.Thread(target=t2)
t2.start() # this should acquire the lock
time.sleep(.2)
self.assertEqual( 'started', statuses['t2'] )
t1.start() # but it shouldn't prevent thread 1 from completing
time.sleep(1)
self.assertEqual( 'complete', statuses['t1'] )
t1.join()
t2.join()
self.assertEqual( 'complete', statuses['t2'] )
self.assertTrue( s['two'] )
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
if __name__ == '__main__': unittest.main()