mirror of https://github.com/gsi-upm/soil
Refactor
* Removed references to `set_state` * Split some functionality from `agents` into separate files (`fsm` and `network_agents`) * Rename `neighboring_agents` to `neighbors` * Delete some spurious functionsmesa
parent
880a9f2a1c
commit
3776c4e5c5
@ -0,0 +1,133 @@
|
||||
from . import MetaAgent, BaseAgent
|
||||
|
||||
from functools import partial
|
||||
import inspect
|
||||
|
||||
|
||||
def state(name=None):
|
||||
def decorator(func, name=None):
|
||||
"""
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the environment.
|
||||
"""
|
||||
if inspect.isgeneratorfunction(func):
|
||||
orig_func = func
|
||||
|
||||
@wraps(func)
|
||||
def func(self):
|
||||
while True:
|
||||
if not self._coroutine:
|
||||
self._coroutine = orig_func(self)
|
||||
try:
|
||||
n = next(self._coroutine)
|
||||
if n:
|
||||
return None, n
|
||||
return
|
||||
except StopIteration as ex:
|
||||
self._coroutine = None
|
||||
next_state = ex.value
|
||||
if next_state is not None:
|
||||
self._set_state(next_state)
|
||||
return next_state
|
||||
|
||||
func.id = name or func.__name__
|
||||
func.is_default = False
|
||||
return func
|
||||
|
||||
if callable(name):
|
||||
return decorator(name)
|
||||
else:
|
||||
return partial(decorator, name=name)
|
||||
|
||||
|
||||
def default_state(func):
|
||||
func.is_default = True
|
||||
return func
|
||||
|
||||
|
||||
class MetaFSM(MetaAgent):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
states = {}
|
||||
# Re-use states from inherited classes
|
||||
default_state = None
|
||||
for i in bases:
|
||||
if isinstance(i, MetaFSM):
|
||||
for state_id, state in i._states.items():
|
||||
if state.is_default:
|
||||
default_state = state
|
||||
states[state_id] = state
|
||||
|
||||
# Add new states
|
||||
for attr, func in namespace.items():
|
||||
if hasattr(func, "id"):
|
||||
if func.is_default:
|
||||
default_state = func
|
||||
states[func.id] = func
|
||||
|
||||
namespace.update(
|
||||
{
|
||||
"_default_state": default_state,
|
||||
"_states": states,
|
||||
}
|
||||
)
|
||||
|
||||
return super(MetaFSM, mcls).__new__(
|
||||
mcls=mcls, name=name, bases=bases, namespace=namespace
|
||||
)
|
||||
|
||||
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, **kwargs):
|
||||
super(FSM, self).__init__(**kwargs)
|
||||
if not hasattr(self, "state_id"):
|
||||
if not self._default_state:
|
||||
raise ValueError(
|
||||
"No default state specified for {}".format(self.unique_id)
|
||||
)
|
||||
self.state_id = self._default_state.id
|
||||
|
||||
self._coroutine = None
|
||||
self._set_state(self.state_id)
|
||||
|
||||
def step(self):
|
||||
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
||||
default_interval = super().step()
|
||||
|
||||
next_state = self._states[self.state_id](self)
|
||||
|
||||
when = None
|
||||
try:
|
||||
next_state, *when = next_state
|
||||
if not when:
|
||||
when = None
|
||||
elif len(when) == 1:
|
||||
when = when[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Too many values returned. Only state (and time) allowed"
|
||||
)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
if next_state is not None:
|
||||
self._set_state(next_state)
|
||||
|
||||
return when or default_interval
|
||||
|
||||
def _set_state(self, state, when=None):
|
||||
if hasattr(state, "id"):
|
||||
state = state.id
|
||||
if state not in self._states:
|
||||
raise ValueError("{} is not a valid state".format(state))
|
||||
self.state_id = state
|
||||
if when is not None:
|
||||
self.model.schedule.add(self, when=when)
|
||||
return state
|
||||
|
||||
def die(self):
|
||||
return self.dead, super().die()
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
return self.die()
|
@ -0,0 +1,82 @@
|
||||
from . import BaseAgent
|
||||
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
def __init__(self, *args, topology, node_id, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
assert topology is not None
|
||||
assert node_id is not None
|
||||
self.G = topology
|
||||
assert self.G
|
||||
self.node_id = node_id
|
||||
|
||||
def count_neighbors(self, state_id=None, **kwargs):
|
||||
return len(self.get_neighbors(state_id=state_id, **kwargs))
|
||||
|
||||
def get_neighbors(self, **kwargs):
|
||||
return list(self.iter_agents(limit_neighbors=True, **kwargs))
|
||||
|
||||
@property
|
||||
def node(self):
|
||||
return self.G.nodes[self.node_id]
|
||||
|
||||
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
||||
unique_ids = None
|
||||
if isinstance(unique_id, list):
|
||||
unique_ids = set(unique_id)
|
||||
elif unique_id is not None:
|
||||
unique_ids = set(
|
||||
[
|
||||
unique_id,
|
||||
]
|
||||
)
|
||||
|
||||
if limit_neighbors:
|
||||
neighbor_ids = set()
|
||||
for node_id in self.G.neighbors(self.node_id):
|
||||
if self.G.nodes[node_id].get("agent") is not None:
|
||||
neighbor_ids.add(node_id)
|
||||
if unique_ids:
|
||||
unique_ids = unique_ids & neighbor_ids
|
||||
else:
|
||||
unique_ids = neighbor_ids
|
||||
if not unique_ids:
|
||||
return
|
||||
unique_ids = list(unique_ids)
|
||||
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
G = self.G.subgraph(
|
||||
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
||||
)
|
||||
return G
|
||||
|
||||
def remove_node(self):
|
||||
print(f"Removing node for {self.unique_id}: {self.node_id}")
|
||||
self.G.remove_node(self.node_id)
|
||||
self.node_id = None
|
||||
|
||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||
if self.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(
|
||||
self.unique_id
|
||||
)
|
||||
)
|
||||
if other.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(other)
|
||||
)
|
||||
|
||||
self.G.add_edge(
|
||||
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
||||
)
|
||||
|
||||
def die(self, remove=True):
|
||||
if not self.alive:
|
||||
return None
|
||||
if remove:
|
||||
self.remove_node()
|
||||
return super().die()
|
Loading…
Reference in New Issue