mirror of https://github.com/gsi-upm/soil
Refactor
* Removed references to `set_state` * Split some functionality from `agents` into separate files (`fsm` and `network_agents`) * Rename `neighboring_agents` to `neighbors` * Delete some spurious functionsmesa
parent
880a9f2a1c
commit
3776c4e5c5
@ -0,0 +1,133 @@
|
|||||||
|
from . import MetaAgent, BaseAgent
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def state(name=None):
|
||||||
|
def decorator(func, name=None):
|
||||||
|
"""
|
||||||
|
A state function should return either a state id, or a tuple (state_id, when)
|
||||||
|
The default value for state_id is the current state id.
|
||||||
|
The default value for when is the interval defined in the environment.
|
||||||
|
"""
|
||||||
|
if inspect.isgeneratorfunction(func):
|
||||||
|
orig_func = func
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def func(self):
|
||||||
|
while True:
|
||||||
|
if not self._coroutine:
|
||||||
|
self._coroutine = orig_func(self)
|
||||||
|
try:
|
||||||
|
n = next(self._coroutine)
|
||||||
|
if n:
|
||||||
|
return None, n
|
||||||
|
return
|
||||||
|
except StopIteration as ex:
|
||||||
|
self._coroutine = None
|
||||||
|
next_state = ex.value
|
||||||
|
if next_state is not None:
|
||||||
|
self._set_state(next_state)
|
||||||
|
return next_state
|
||||||
|
|
||||||
|
func.id = name or func.__name__
|
||||||
|
func.is_default = False
|
||||||
|
return func
|
||||||
|
|
||||||
|
if callable(name):
|
||||||
|
return decorator(name)
|
||||||
|
else:
|
||||||
|
return partial(decorator, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def default_state(func):
|
||||||
|
func.is_default = True
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class MetaFSM(MetaAgent):
|
||||||
|
def __new__(mcls, name, bases, namespace):
|
||||||
|
states = {}
|
||||||
|
# Re-use states from inherited classes
|
||||||
|
default_state = None
|
||||||
|
for i in bases:
|
||||||
|
if isinstance(i, MetaFSM):
|
||||||
|
for state_id, state in i._states.items():
|
||||||
|
if state.is_default:
|
||||||
|
default_state = state
|
||||||
|
states[state_id] = state
|
||||||
|
|
||||||
|
# Add new states
|
||||||
|
for attr, func in namespace.items():
|
||||||
|
if hasattr(func, "id"):
|
||||||
|
if func.is_default:
|
||||||
|
default_state = func
|
||||||
|
states[func.id] = func
|
||||||
|
|
||||||
|
namespace.update(
|
||||||
|
{
|
||||||
|
"_default_state": default_state,
|
||||||
|
"_states": states,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return super(MetaFSM, mcls).__new__(
|
||||||
|
mcls=mcls, name=name, bases=bases, namespace=namespace
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(FSM, self).__init__(**kwargs)
|
||||||
|
if not hasattr(self, "state_id"):
|
||||||
|
if not self._default_state:
|
||||||
|
raise ValueError(
|
||||||
|
"No default state specified for {}".format(self.unique_id)
|
||||||
|
)
|
||||||
|
self.state_id = self._default_state.id
|
||||||
|
|
||||||
|
self._coroutine = None
|
||||||
|
self._set_state(self.state_id)
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
||||||
|
default_interval = super().step()
|
||||||
|
|
||||||
|
next_state = self._states[self.state_id](self)
|
||||||
|
|
||||||
|
when = None
|
||||||
|
try:
|
||||||
|
next_state, *when = next_state
|
||||||
|
if not when:
|
||||||
|
when = None
|
||||||
|
elif len(when) == 1:
|
||||||
|
when = when[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Too many values returned. Only state (and time) allowed"
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if next_state is not None:
|
||||||
|
self._set_state(next_state)
|
||||||
|
|
||||||
|
return when or default_interval
|
||||||
|
|
||||||
|
def _set_state(self, state, when=None):
|
||||||
|
if hasattr(state, "id"):
|
||||||
|
state = state.id
|
||||||
|
if state not in self._states:
|
||||||
|
raise ValueError("{} is not a valid state".format(state))
|
||||||
|
self.state_id = state
|
||||||
|
if when is not None:
|
||||||
|
self.model.schedule.add(self, when=when)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def die(self):
|
||||||
|
return self.dead, super().die()
|
||||||
|
|
||||||
|
@state
|
||||||
|
def dead(self):
|
||||||
|
return self.die()
|
@ -0,0 +1,82 @@
|
|||||||
|
from . import BaseAgent
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkAgent(BaseAgent):
|
||||||
|
def __init__(self, *args, topology, node_id, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
assert topology is not None
|
||||||
|
assert node_id is not None
|
||||||
|
self.G = topology
|
||||||
|
assert self.G
|
||||||
|
self.node_id = node_id
|
||||||
|
|
||||||
|
def count_neighbors(self, state_id=None, **kwargs):
|
||||||
|
return len(self.get_neighbors(state_id=state_id, **kwargs))
|
||||||
|
|
||||||
|
def get_neighbors(self, **kwargs):
|
||||||
|
return list(self.iter_agents(limit_neighbors=True, **kwargs))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node(self):
|
||||||
|
return self.G.nodes[self.node_id]
|
||||||
|
|
||||||
|
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
||||||
|
unique_ids = None
|
||||||
|
if isinstance(unique_id, list):
|
||||||
|
unique_ids = set(unique_id)
|
||||||
|
elif unique_id is not None:
|
||||||
|
unique_ids = set(
|
||||||
|
[
|
||||||
|
unique_id,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if limit_neighbors:
|
||||||
|
neighbor_ids = set()
|
||||||
|
for node_id in self.G.neighbors(self.node_id):
|
||||||
|
if self.G.nodes[node_id].get("agent") is not None:
|
||||||
|
neighbor_ids.add(node_id)
|
||||||
|
if unique_ids:
|
||||||
|
unique_ids = unique_ids & neighbor_ids
|
||||||
|
else:
|
||||||
|
unique_ids = neighbor_ids
|
||||||
|
if not unique_ids:
|
||||||
|
return
|
||||||
|
unique_ids = list(unique_ids)
|
||||||
|
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
||||||
|
|
||||||
|
def subgraph(self, center=True, **kwargs):
|
||||||
|
include = [self] if center else []
|
||||||
|
G = self.G.subgraph(
|
||||||
|
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
||||||
|
)
|
||||||
|
return G
|
||||||
|
|
||||||
|
def remove_node(self):
|
||||||
|
print(f"Removing node for {self.unique_id}: {self.node_id}")
|
||||||
|
self.G.remove_node(self.node_id)
|
||||||
|
self.node_id = None
|
||||||
|
|
||||||
|
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||||
|
if self.node_id not in self.G.nodes(data=False):
|
||||||
|
raise ValueError(
|
||||||
|
"{} not in list of existing agents in the network".format(
|
||||||
|
self.unique_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if other.node_id not in self.G.nodes(data=False):
|
||||||
|
raise ValueError(
|
||||||
|
"{} not in list of existing agents in the network".format(other)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.G.add_edge(
|
||||||
|
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
||||||
|
)
|
||||||
|
|
||||||
|
def die(self, remove=True):
|
||||||
|
if not self.alive:
|
||||||
|
return None
|
||||||
|
if remove:
|
||||||
|
self.remove_node()
|
||||||
|
return super().die()
|
Loading…
Reference in New Issue