mirror of https://github.com/gsi-upm/soil
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
149 lines
4.4 KiB
Python
149 lines
4.4 KiB
Python
from . import MetaAgent, BaseAgent
|
|
from ..time import Delta
|
|
|
|
from functools import partial, wraps
|
|
import inspect
|
|
|
|
|
|
def _wrap_generator_state(gen_func):
    """Adapt a generator-based state into a plain state function.

    The wrapper keeps the running generator on ``self._coroutine`` and resumes
    it on every call. It sends ``self._last_return`` into the generator, or
    throws ``self._last_except`` into it if that is set; both attributes are
    cleared afterwards.
    """

    @wraps(gen_func)
    def wrapper(self):
        if not self._coroutine:
            self._coroutine = gen_func(self)
        try:
            if self._last_except:
                yielded = self._coroutine.throw(self._last_except)
            else:
                yielded = self._coroutine.send(self._last_return)
        except StopIteration as stop:
            # The generator finished: its return value selects the next state.
            self._coroutine = None
            next_state = stop.value
            if next_state is not None:
                # In a (state_id, when) tuple only the first item names the
                # state; passing the whole tuple to _set_state would fail.
                if isinstance(next_state, tuple):
                    target = next_state[0]
                else:
                    target = next_state
                if target is not None:
                    self._set_state(target)
            return next_state
        finally:
            self._last_return = None
            self._last_except = None
        # Any yielded value (including falsy ones such as 0) means
        # "stay in this state and wait that long".
        if yielded is not None:
            return None, yielded
        return None

    return wrapper


def state(name=None, default=False):
    """Mark an agent method as a state of a finite-state machine.

    Can be used bare (``@state``), with no arguments (``@state()``), or with
    arguments (``@state("custom_id", default=True)``).

    A state function should return either a state id (or state function), or a
    tuple ``(state_id, when)``. A ``state_id`` of ``None`` keeps the current
    state. If no ``when`` is given, the FSM uses its default interval, which is
    taken from the model's interval.

    Generator functions are accepted too. Each ``yield value`` suspends the
    state and is reported as ``(None, value)``, meaning stay here and wait
    ``value``. A bare ``yield`` is reported as ``None``. The generator's
    ``return`` value selects the next state as described above, and the
    following call starts a fresh generator.

    Args:
        name: the state id, or the decorated function itself when used bare.
            Falls back to the function's ``__name__``.
        default: whether this state is the FSM's initial state. A default
            flag already set on the function (e.g. by ``@default_state``
            applied underneath) is preserved.

    Returns:
        The decorated function (or a decorator when called with arguments),
        carrying ``id`` and ``is_default`` attributes.
    """

    def decorator(func, state_name=None):
        if inspect.isgeneratorfunction(func):
            func = _wrap_generator_state(func)
        func.id = state_name or func.__name__
        # Don't clobber a default flag set by an inner @default_state.
        func.is_default = default or getattr(func, "is_default", False)
        return func

    if callable(name):
        return decorator(name)
    return partial(decorator, state_name=name)
|
|
|
|
|
|
def default_state(func):
    """Flag ``func`` as the FSM's initial state and hand it back unchanged.

    Only the ``is_default`` attribute is set here. ``MetaFSM`` registers only
    attributes that also carry an ``id``, so this is meant to wrap a function
    already decorated with ``@state``.
    """
    setattr(func, "is_default", True)
    return func
|
|
|
|
|
|
class MetaFSM(MetaAgent):
    """Metaclass that builds the state registry of an FSM class.

    Two attributes are injected into every class it creates:

    * ``_states``: a dict mapping state id to state function. It contains the
      states inherited from ``MetaFSM`` bases, then those declared in the
      class body; later entries override earlier ones.
    * ``_default_state``: the last function seen with ``is_default`` set, or
      ``None`` if there is none.

    Any class-body value carrying an ``id`` attribute (as set by ``@state``)
    is treated as a state.
    """

    def __new__(mcls, name, bases, namespace):
        registry = {}
        initial = None

        # Start from the states of any FSM base classes.
        fsm_bases = [base for base in bases if isinstance(base, MetaFSM)]
        for base in fsm_bases:
            for sid, fn in base._states.items():
                if fn.is_default:
                    initial = fn
                registry[sid] = fn

        # States declared in this class body extend or override inherited ones.
        for candidate in namespace.values():
            if not hasattr(candidate, "id"):
                continue
            if candidate.is_default:
                initial = candidate
            registry[candidate.id] = candidate

        namespace["_default_state"] = initial
        namespace["_states"] = registry

        return super(MetaFSM, mcls).__new__(
            mcls=mcls, name=name, bases=bases, namespace=namespace
        )
|
|
|
|
|
|
class FSM(BaseAgent, metaclass=MetaFSM):
    """Agent whose behaviour is driven by a finite-state machine.

    Each method decorated with ``@state`` is one state. On every ``step`` the
    method registered for the current ``state_id`` is called, and its return
    value is interpreted as follows:

    * ``None`` keeps the current state;
    * a state id or a state function switches to that state;
    * a 2-item ``(state, when)`` sequence also sets the value returned by
      ``step`` (``state`` may be ``None`` to stay put).

    If no truthy ``when`` is produced, ``step`` returns ``default_interval``,
    which is built from the model's interval. Unless ``state_id`` is already
    set once ``BaseAgent.__init__`` returns, the class must declare a default
    state.
    """

    def __init__(self, init=True, **kwargs):
        # Presumably BaseAgent would call init() itself when init=True. It is
        # called below instead, once the FSM attributes exist.
        super().__init__(**kwargs, init=False)
        if not hasattr(self, "state_id"):
            if not self._default_state:
                raise ValueError(
                    "No default state specified for {}".format(self.unique_id)
                )
            self.state_id = self._default_state.id

        # Suspended generator of a generator-based state, if any.
        self._coroutine = None
        self.default_interval = Delta(self.model.interval)
        self._set_state(self.state_id)
        if init:
            self.init()

    @classmethod
    def states(cls):
        """Return the ids of every state registered on this class."""
        return list(cls._states.keys())

    def step(self):
        """Run the current state and apply the transition it requests.

        Returns:
            The ``when`` value produced by the state, or ``default_interval``
            when none (or a falsy one) was given.

        Raises:
            ValueError: if the state returns more than ``(state, when)``, or
                names an unknown state.
        """
        self.debug(f"Agent {self.unique_id} @ state {self.state_id}")

        self._check_alive()
        next_state = self._states[self.state_id](self)

        when = None
        # Strings are iterable: unpacking a plain state id such as "neutral"
        # would split it into characters. Only non-string iterables are
        # treated as (state, when) pairs.
        if not isinstance(next_state, str):
            try:
                next_state, *when = next_state
            except TypeError:
                # Not iterable: a bare state function, or None.
                pass
            else:
                if not when:
                    when = None
                elif len(when) == 1:
                    when = when[0]
                else:
                    raise ValueError(
                        "Too many values returned. Only state (and time) allowed"
                    )

        if next_state is not None:
            self._set_state(next_state)

        return when or self.default_interval

    def _set_state(self, state, when=None):
        """Switch to ``state`` (an id or a function with an ``id``).

        If ``when`` is given, the agent is also added to the model's schedule
        at that time.

        Returns:
            The resolved state id.

        Raises:
            ValueError: if the state is not registered on this class.
        """
        if hasattr(state, "id"):
            state = state.id
        if state not in self._states:
            raise ValueError("{} is not a valid state".format(state))
        if state != getattr(self, "state_id", None):
            # A suspended generator belongs to the state being left. Drop it
            # so the new state does not resume another state's coroutine.
            self._coroutine = None
        self.state_id = state
        if when is not None:
            self.model.schedule.add(self, when=when)
        return state

    def die(self, *args, **kwargs):
        """Kill the agent and move it to the ``dead`` state.

        Returns ``(self.dead, <value returned by BaseAgent.die>)``. When
        returned from a state, ``step`` therefore uses the parent's return
        value as its ``when``.
        """
        return self.dead, super().die(*args, **kwargs)

    @state
    def dead(self):
        """Terminal state: calls ``die()`` again on every step."""
        return self.die()
|