Fix for StateMachine where operations would unintentionally block if the lock was held by a long-running transition

This commit is contained in:
Tom Nichols 2010-07-01 15:10:22 -04:00
parent 8bdfa77024
commit 0a23f84ec3
2 changed files with 122 additions and 35 deletions

View File

@ -5,7 +5,6 @@
See the file license.txt for copying permission. See the file license.txt for copying permission.
""" """
from __future__ import with_statement
import threading import threading
import time import time
import logging import logging
@ -14,18 +13,21 @@ import logging
class StateMachine(object): class StateMachine(object):
def __init__(self, states=[]): def __init__(self, states=[]):
self.lock = threading.Condition(threading.RLock()) self.lock = threading.Lock()
self.notifier = threading.Event()
self.__states= [] self.__states= []
self.addStates(states) self.addStates(states)
self.__default_state = self.__states[0] self.__default_state = self.__states[0]
self.__current_state = self.__default_state self.__current_state = self.__default_state
def addStates(self, states): def addStates(self, states):
with self.lock: self.lock.acquire()
try:
for state in states: for state in states:
if state in self.__states: if state in self.__states:
raise IndexError("The state '%s' is already in the StateMachine." % state) raise IndexError("The state '%s' is already in the StateMachine." % state)
self.__states.append( state ) self.__states.append( state )
finally: self.lock.release()
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ): def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
@ -78,30 +80,33 @@ class StateMachine(object):
if not to_state in self.__states: if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state ) raise ValueError( "StateMachine does not contain to_state %s." % to_state )
with self.lock:
start = time.time() start = time.time()
while not self.__current_state in from_states: while not self.__current_state in from_states or not self.lock.acquire(False):
# detect timeout: # detect timeout:
if time.time() >= start + wait: return False if time.time() >= start + wait: return False
self.lock.wait(wait) self.notifier.wait(wait)
try: # lock is acquired; all other threads will return false or wait until notify/timeout
self.notifier.clear()
if self.__current_state in from_states: # should always be True due to lock if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition # Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs) return_val = func(*args,**kwargs) if func is not None else True
# some 'false' value returned from func, # some 'false' value returned from func,
# indicating that transition should not occur: # indicating that transition should not occur:
if not return_val: return return_val if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state) logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state self._set_state( to_state )
self.lock.notify_all()
return return_val # some 'true' value returned by func or True if func was None return return_val # some 'true' value returned by func or True if func was None
else: else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" ) logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False return False
finally:
self.notifier.set()
self.lock.release()
def transition_ctx(self, from_state, to_state, wait=0.0): def transition_ctx(self, from_state, to_state, wait=0.0):
@ -148,7 +153,15 @@ class StateMachine(object):
def ensure_any(self, states, wait=0.0): def ensure_any(self, states, wait=0.0):
''' '''
Ensure we are currently in one of the given `states` Ensure we are currently in one of the given `states` or wait until
we enter one of those states.
Note that due to the nature of the function, you cannot guarantee that
the entirety of some operation completes while you remain in a given
state. That would require acquiring and holding a lock, which
would mean no other threads could do the same. (You'd essentially
be serializing all of the threads that are 'ensuring' their tasks
occurred in some state.)
''' '''
if not (isinstance(states,tuple) or isinstance(states,list)): if not (isinstance(states,tuple) or isinstance(states,list)):
raise ValueError('states arg should be a tuple or list') raise ValueError('states arg should be a tuple or list')
@ -157,13 +170,17 @@ class StateMachine(object):
if not state in self.__states: if not state in self.__states:
raise ValueError( "StateMachine does not contain state '%s'" % state ) raise ValueError( "StateMachine does not contain state '%s'" % state )
with self.lock: # Locking never really gained us anything here, since the lock was released
start = time.time() # before the function returned anyways. The only thing it _did_ do was
while not self.__current_state in states: # increase the probability that this function would block for longer than
# detect timeout: # intended if a `transition` function or context was running while holding
if time.time() >= start + wait: return False # the lock.
self.lock.wait(wait) start = time.time()
return self.__current_state in states # should always be True due to lock while not self.__current_state in states:
# detect timeout:
if time.time() >= start + wait: return False
self.notifier.wait(wait)
return True
def reset(self): def reset(self):
@ -202,19 +219,19 @@ class _StateCtx:
self.from_state = from_state self.from_state = from_state
self.to_state = to_state self.to_state = to_state
self.wait = wait self.wait = wait
self._timeout = False self._locked = False
def __enter__(self): def __enter__(self):
self.state_machine.lock.acquire()
start = time.time() start = time.time()
while not self.state_machine[ self.from_state ]: while not self.state_machine[ self.from_state ] or not self.state_machine.lock.acquire(False):
# detect timeout: # detect timeout:
if time.time() >= start + self.wait: if time.time() >= start + self.wait:
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state ) logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
self._timeout = True # to indicate we should not transition
return False return False
self.state_machine.lock.wait(self.wait) self.state_machine.notifier.wait(self.wait)
self._locked = True # lock has been acquired at this point
self.state_machine.notifier.clear()
logging.debug('StateMachine entered context in state: %s', logging.debug('StateMachine entered context in state: %s',
self.state_machine.current_state() ) self.state_machine.current_state() )
return True return True
@ -222,13 +239,16 @@ class _StateCtx:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is not None: if exc_val is not None:
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s", logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
self.state_machine.current_state(), exc_type.__name__, exc_val ) self.state_machine.current_state(), exc_type.__name__, exc_val )
elif not self._timeout:
logging.debug(' ==== TRANSITION %s -> %s', if self._locked:
self.state_machine.current_state(), self.to_state) if exc_val is None:
self.state_machine._set_state( self.to_state ) logging.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.notifier.set()
self.state_machine.lock.release()
self.state_machine.lock.notify_all()
self.state_machine.lock.release()
return False # re-raise any exception return False # re-raise any exception

View File

@ -256,6 +256,73 @@ class testStateMachine(unittest.TestCase):
self.assertTrue( s['three'] ) self.assertTrue( s['three'] )
def testTransitionsDontUnintentionallyBlock(self):
    '''
    Regression test: a long-running transition must not block other callers.

    There was a bug where a long-running transition (e.g. one with a 'func'
    arg or a `transition_ctx` call) would cause any `transition` or `ensure`
    call to block, since the lock was acquired before checking the current
    state. Attempts to acquire the mutex need to be non-blocking so that,
    when a timeout is _not_ given, the caller can return immediately. At the
    same time, threads that _do_ want to wait need the ability to be notified
    (to avoid waiting beyond when the lock is released), so we've moved to a
    combination of a plain-ol' `threading.Lock` to act as mutex, and a
    `threading.Event` to perform notification for threads who choose to wait.
    '''
    s = sm.StateMachine(('one','two','three'))

    # We are in 'one', not 'two', and no wait was given: the context must
    # give up right away, yield False, and leave the state untouched.
    with s.transition_ctx('two','three') as result:
        self.assertFalse( result )
        self.assertTrue( s['one'] )
        # current_state is a method; it has to be *called*, otherwise the
        # bound method object is compared with the strings and this check
        # can never fail.
        self.assertFalse( s.current_state() in ('two','three') )
    self.assertTrue( s['one'] )

    # An assertion that fails inside a worker thread does not fail the test
    # by itself, so each worker records its progress here and the main
    # thread asserts on it.
    statuses = {'t1':"not started",
                't2':'not started'}

    def run_t1():
        print('thread 1 started')
        # no wait, so this should 'return False' immediately.
        self.assertFalse( s.transition('two','three') )
        statuses['t1'] = 'complete'
        print('thread 1 transitioned')

    def run_t2():
        print('thread 2 started')
        self.assertFalse( s['two'] )
        self.assertFalse( s['three'] )
        # we want this thread to acquire the lock, but for
        # the second thread not to wait on the first.
        with s.transition_ctx('one','two', 10) as locked:
            statuses['t2'] = 'started'
            print('thread 2 has entered context')
            self.assertTrue( locked )
            # give thread1 a chance to complete while this
            # thread still owns the lock
            time.sleep(5)
        self.assertTrue( s['two'] )
        statuses['t2'] = 'complete'

    # Distinct names for the Thread objects so the worker functions above
    # are not shadowed.
    t1 = threading.Thread(target=run_t1)
    t2 = threading.Thread(target=run_t2)

    t2.start() # this should acquire the lock
    time.sleep(.2)
    self.assertEqual( 'started', statuses['t2'] )

    t1.start() # but it shouldn't prevent thread 1 from completing
    time.sleep(1)
    self.assertEqual( 'complete', statuses['t1'] )

    t1.join()
    t2.join()
    self.assertEqual( 'complete', statuses['t2'] )
    self.assertTrue( s['two'] )
suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine) suite = unittest.TestLoader().loadTestsFromTestCase(testStateMachine)
if __name__ == '__main__': unittest.main() if __name__ == '__main__': unittest.main()