fix for statemachine where operations would unintentionally block if the lock was acquired in a long-running transition

This commit is contained in:
Tom Nichols
2010-07-01 15:10:22 -04:00
parent 8bdfa77024
commit 0a23f84ec3
2 changed files with 122 additions and 35 deletions

View File

@@ -5,7 +5,6 @@
See the file license.txt for copying permission.
"""
from __future__ import with_statement
import threading
import time
import logging
@@ -14,18 +13,21 @@ import logging
class StateMachine(object):
def __init__(self, states=[]):
self.lock = threading.Condition(threading.RLock())
self.lock = threading.Lock()
self.notifier = threading.Event()
self.__states= []
self.addStates(states)
self.__default_state = self.__states[0]
self.__current_state = self.__default_state
def addStates(self, states):
with self.lock:
self.lock.acquire()
try:
for state in states:
if state in self.__states:
raise IndexError("The state '%s' is already in the StateMachine." % state)
self.__states.append( state )
finally: self.lock.release()
def transition(self, from_state, to_state, wait=0.0, func=None, args=[], kwargs={} ):
@@ -78,30 +80,33 @@ class StateMachine(object):
if not to_state in self.__states:
raise ValueError( "StateMachine does not contain to_state %s." % to_state )
with self.lock:
start = time.time()
while not self.__current_state in from_states:
# detect timeout:
if time.time() >= start + wait: return False
self.lock.wait(wait)
start = time.time()
while not self.__current_state in from_states or not self.lock.acquire(False):
# detect timeout:
if time.time() >= start + wait: return False
self.notifier.wait(wait)
try: # lock is acquired; all other threads will return false or wait until notify/timeout
self.notifier.clear()
if self.__current_state in from_states: # should always be True due to lock
return_val = True
# Note that func might throw an exception, but that's OK, it aborts the transition
if func is not None: return_val = func(*args,**kwargs)
return_val = func(*args,**kwargs) if func is not None else True
# some 'false' value returned from func,
# indicating that transition should not occur:
if not return_val: return return_val
logging.debug(' ==== TRANSITION %s -> %s', self.__current_state, to_state)
self.__current_state = to_state
self.lock.notify_all()
self._set_state( to_state )
return return_val # some 'true' value returned by func or True if func was None
else:
logging.error( "StateMachine bug!! The lock should ensure this doesn't happen!" )
return False
finally:
self.notifier.set()
self.lock.release()
def transition_ctx(self, from_state, to_state, wait=0.0):
@@ -148,7 +153,15 @@ class StateMachine(object):
def ensure_any(self, states, wait=0.0):
'''
Ensure we are currently in one of the given `states`
Ensure we are currently in one of the given `states` or wait until
we enter one of those states.
Note that due to the nature of the function, you cannot guarantee that
the entirety of some operation completes while you remain in a given
state. That would require acquiring and holding a lock, which
would mean no other threads could do the same. (You'd essentially
be serializing all of the threads that are 'ensuring' their tasks
occurred in some state.
'''
if not (isinstance(states,tuple) or isinstance(states,list)):
raise ValueError('states arg should be a tuple or list')
@@ -157,13 +170,17 @@ class StateMachine(object):
if not state in self.__states:
raise ValueError( "StateMachine does not contain state '%s'" % state )
with self.lock:
start = time.time()
while not self.__current_state in states:
# detect timeout:
if time.time() >= start + wait: return False
self.lock.wait(wait)
return self.__current_state in states # should always be True due to lock
# Locking never really gained us anything here, since the lock was released
# before the function returned anyways. The only thing it _did_ do was
# increase the probability that this function would block for longer than
# intended if a `transition` function or context was running while holding
# the lock.
start = time.time()
while not self.__current_state in states:
# detect timeout:
if time.time() >= start + wait: return False
self.notifier.wait(wait)
return True
def reset(self):
@@ -202,19 +219,19 @@ class _StateCtx:
self.from_state = from_state
self.to_state = to_state
self.wait = wait
self._timeout = False
self._locked = False
def __enter__(self):
self.state_machine.lock.acquire()
start = time.time()
while not self.state_machine[ self.from_state ]:
while not self.state_machine[ self.from_state ] or not self.state_machine.lock.acquire(False):
# detect timeout:
if time.time() >= start + self.wait:
logging.debug('StateMachine timeout while waiting for state: %s', self.from_state )
self._timeout = True # to indicate we should not transition
return False
self.state_machine.lock.wait(self.wait)
self.state_machine.notifier.wait(self.wait)
self._locked = True # lock has been acquired at this point
self.state_machine.notifier.clear()
logging.debug('StateMachine entered context in state: %s',
self.state_machine.current_state() )
return True
@@ -222,13 +239,16 @@ class _StateCtx:
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_val is not None:
logging.exception( "StateMachine exception in context, remaining in state: %s\n%s:%s",
self.state_machine.current_state(), exc_type.__name__, exc_val )
elif not self._timeout:
logging.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.current_state(), exc_type.__name__, exc_val )
if self._locked:
if exc_val is None:
logging.debug(' ==== TRANSITION %s -> %s',
self.state_machine.current_state(), self.to_state)
self.state_machine._set_state( self.to_state )
self.state_machine.notifier.set()
self.state_machine.lock.release()
self.state_machine.lock.notify_all()
self.state_machine.lock.release()
return False # re-raise any exception