You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
soil/soil/agents/fsm.py

149 lines
4.4 KiB
Python

from . import MetaAgent, BaseAgent
from ..time import Delta
from functools import partial, wraps
import inspect
def state(name=None, default=False):
    """Mark a function as a state of a finite state machine agent.

    Can be used bare (``@state``) or with arguments
    (``@state(name="idle", default=True)``).

    The decorated function receives two attributes:

    * ``id``: ``name`` if given, otherwise the function's ``__name__``.
    * ``is_default``: true when ``default`` is true, or when the function was
      already flagged as default before this decorator ran.

    Generator functions are wrapped so that every call advances a coroutine
    stored on the agent (``self._coroutine``) by one step.

    :param name: state id, or the function itself when used as a bare decorator.
    :param default: whether this state is the default state.
    :return: the decorated function, or a decorator when called with arguments.
    """

    def decorator(func, name=None):
        """
        A state function should return either a state id, or a tuple (state_id, when)
        The default value for state_id is the current state id.
        The default value for when is the interval defined in the environment.
        """
        # Keep a default flag that was set before this decorator ran (e.g. when
        # ``@state`` is stacked on top of ``@default_state``); otherwise the
        # assignment below would silently clear it.
        already_default = getattr(func, "is_default", False)

        if inspect.isgeneratorfunction(func):
            orig_func = func

            @wraps(func)
            def func(self):
                # ``_coroutine``, ``_last_return`` and ``_last_except`` are read
                # from the agent; they are expected to exist on ``self``.
                if not self._coroutine:
                    self._coroutine = orig_func(self)
                try:
                    if self._last_except:
                        # Re-raise the pending exception inside the generator.
                        n = self._coroutine.throw(self._last_except)
                    else:
                        n = self._coroutine.send(self._last_return)
                    if n:
                        # A truthy yielded value is reported as the ``when``
                        # half of a (state, when) pair, keeping the state.
                        return None, n
                    return n
                except StopIteration as ex:
                    # The generator finished: its return value is the next state.
                    self._coroutine = None
                    next_state = ex.value
                    if next_state is not None:
                        self._set_state(next_state)
                    return next_state
                finally:
                    # Pending values are consumed by exactly one step.
                    self._last_return = None
                    self._last_except = None

        func.id = name or func.__name__
        func.is_default = default or already_default
        return func

    if callable(name):
        # Bare ``@state``: ``name`` is actually the function to decorate.
        return decorator(name)
    else:
        return partial(decorator, name=name)
def default_state(func):
    """Flag ``func`` as the default state and hand the same object back.

    Only the ``is_default`` attribute is set; the function is not wrapped.
    """
    setattr(func, "is_default", True)
    return func
class MetaFSM(MetaAgent):
    """Metaclass that collects the state functions of an FSM class.

    Every attribute in the class body that has an ``id`` attribute (as set by
    the ``state`` decorator) is registered under that id. States from base
    classes created by this metaclass are inherited. Two attributes are added
    to the new class:

    * ``_states``: dict mapping state id to state function.
    * ``_default_state``: the state flagged ``is_default``, or ``None``.
    """

    def __new__(mcls, name, bases, namespace):
        states = {}
        # Re-use states from inherited classes
        default_state = None
        for base in bases:
            if not isinstance(base, MetaFSM):
                continue
            states.update(base._states)
            # Use the default the base class resolved for itself instead of
            # re-scanning ``is_default`` flags: a base may hold several flagged
            # states (an inherited default plus its own override), and a scan
            # in dict order could pick a different one than the base uses.
            if base._default_state is not None:
                default_state = base._default_state

        # Add new states
        for func in namespace.values():
            if hasattr(func, "id"):
                # A default declared in this class body wins over inherited ones.
                if func.is_default:
                    default_state = func
                states[func.id] = func

        namespace.update(
            {
                "_default_state": default_state,
                "_states": states,
            }
        )

        return super(MetaFSM, mcls).__new__(
            mcls=mcls, name=name, bases=bases, namespace=namespace
        )
class FSM(BaseAgent, metaclass=MetaFSM):
    """Agent whose behaviour is a finite state machine.

    Each call to ``step`` runs the function registered for the current
    ``state_id`` and moves to the state that function returns, if any.
    """

    def __init__(self, init=True, **kwargs):
        """Set up the initial state.

        :param init: when true, ``self.init()`` is called once the state has
            been set up. The parent constructor is always called with
            ``init=False``.
        :param kwargs: forwarded to the parent constructor.
        :raises ValueError: if no ``state_id`` is set and the class has no
            default state.
        """
        super().__init__(**kwargs, init=False)
        if not hasattr(self, "state_id"):
            if not self._default_state:
                raise ValueError(
                    "No default state specified for {}".format(self.unique_id)
                )
            self.state_id = self._default_state.id

        self._coroutine = None
        self.default_interval = Delta(self.model.interval)
        # Validates the initial state id (raises ValueError if unknown).
        self._set_state(self.state_id)
        if init:
            self.init()

    @classmethod
    def states(cls):
        """Return the list of state ids registered for this class."""
        return list(cls._states.keys())

    def step(self):
        """Run the current state function and apply the transition it returns.

        The state function may return ``None`` (stay in the current state), a
        state (an id or an object with an ``id`` attribute), or an iterable
        ``(state,)`` / ``(state, when)``.

        :return: ``when`` if the state function provided a truthy one,
            otherwise ``self.default_interval``.
        :raises ValueError: if more than two values are returned, or the next
            state is not a registered state.
        """
        self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
        self._check_alive()
        next_state = self._states[self.state_id](self)

        when = None
        # A state id given as a string is iterable, so unpacking it would split
        # it into characters. Strings are always treated as a single state.
        if not isinstance(next_state, str):
            try:
                next_state, *when = next_state
            except TypeError:
                # Not iterable: a single state (or None); ``when`` stays None.
                pass
            else:
                if not when:
                    when = None
                elif len(when) == 1:
                    when = when[0]
                else:
                    raise ValueError(
                        "Too many values returned. Only state (and time) allowed"
                    )

        if next_state is not None:
            self._set_state(next_state)

        return when or self.default_interval

    def _set_state(self, state, when=None):
        """Make ``state`` the current state.

        :param state: a state id, or an object with an ``id`` attribute.
        :param when: if not ``None``, the agent is added to
            ``self.model.schedule`` with this value.
        :return: the id of the new state.
        :raises ValueError: if the id is not a registered state.
        """
        if hasattr(state, "id"):
            state = state.id
        if state not in self._states:
            raise ValueError("{} is not a valid state".format(state))
        self.state_id = state
        if when is not None:
            self.model.schedule.add(self, when=when)
        return state

    def die(self, *args, **kwargs):
        """Return the ``dead`` state paired with the result of the parent's ``die``.

        The pair has the ``(state, when)`` shape that ``step`` unpacks.
        NOTE(review): presumably the parent's ``die`` returns a time; confirm.
        """
        return self.dead, super().die(*args, **kwargs)

    @state
    def dead(self):
        """State for a dead agent: every step calls ``die`` again."""
        return self.die()