From af9a392a9392df168538f39252cdccd1b2c8e0bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Fri, 15 Oct 2021 13:36:39 +0200 Subject: [PATCH] WIP: mesa compat All tests pass but some features are still missing/unclear: - Mesa agents do not have a `state`, so their "metrics" don't get stored. I will probably refactor this to remove some magic in this regard. This should get rid of the `_state` dictionary and the setitem/getitem magic. - Simulation is still different from a runner. So far only Agent and Environment/Model have been updated. --- examples/rabbits/rabbit_agents.py | 8 ++-- examples/rabbits/rabbits.yml | 2 +- soil/agents/Geo.py | 3 +- soil/agents/__init__.py | 73 ++++++++++++++++--------------- soil/environment.py | 11 ++--- soil/time.py | 5 ++- tests/test_main.py | 39 +++++++++++++++++ 7 files changed, 91 insertions(+), 50 deletions(-) diff --git a/examples/rabbits/rabbit_agents.py b/examples/rabbits/rabbit_agents.py index 499acae..1c978b1 100644 --- a/examples/rabbits/rabbit_agents.py +++ b/examples/rabbits/rabbit_agents.py @@ -1,7 +1,6 @@ from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent from enum import Enum from random import random, choice -from itertools import islice import logging import math @@ -22,7 +21,7 @@ class RabbitModel(FSM): 'offspring': 0, } - sexual_maturity = 4*30 + sexual_maturity = 3 #4*30 life_expectancy = 365 * 3 gestation = 33 pregnancy = -1 @@ -31,9 +30,11 @@ class RabbitModel(FSM): @default_state @state def newborn(self): + self.debug(f'I am a newborn at age {self["age"]}') self['age'] += 1 if self['age'] >= self.sexual_maturity: + self.debug('I am fertile!') return self.fertile @state @@ -46,8 +47,7 @@ class RabbitModel(FSM): return # Males try to mate - females = self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False) - for f in islice(females, self.max_females): + for f in self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False, limit=self.max_females): r = random() if r < self['mating_prob']: self.impregnate(f) diff --git a/examples/rabbits/rabbits.yml b/examples/rabbits/rabbits.yml index 25275f3..204aa7a 100644 --- a/examples/rabbits/rabbits.yml +++ b/examples/rabbits/rabbits.yml @@ -1,7 +1,7 @@ --- load_module: rabbit_agents name: rabbits_example -max_time: 500 +max_time: 200 interval: 1 seed: MySeed agent_type: RabbitModel diff --git a/soil/agents/Geo.py b/soil/agents/Geo.py index 8e872d9..bf505bf 100644 --- a/soil/agents/Geo.py +++ b/soil/agents/Geo.py @@ -1,5 +1,6 @@ from scipy.spatial import cKDTree as KDTree -from . import NetworkAgent +import networkx as nx +from . import NetworkAgent, as_node class Geo(NetworkAgent): '''In this type of network, nodes have a "pos" attribute.''' diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index 5af3cf9..bc8b685 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -1,12 +1,11 @@ import logging from collections import OrderedDict, defaultdict from copy import deepcopy -from functools import partial +from functools import partial, wraps +from itertools import islice import json import networkx as nx -from functools import wraps - from .. import serialization, history, utils, time from mesa import Agent @@ -91,7 +90,7 @@ class BaseAgent(Agent): def __getitem__(self, key): if isinstance(key, tuple): key, t_step = key - k = history.Key(key=key, t_step=t_step, agent_id=self.id) + k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id) return self.model[k] return self._state.get(key, None) @@ -151,25 +150,25 @@ class BaseAgent(Agent): def info(self, *args, **kwargs): return self.log(*args, level=logging.INFO, **kwargs) - def __getstate__(self): - ''' - Serializing an agent will lose all its running information (you cannot - serialize an iterator), but it keeps the state and link to the environment, - so it can be used for inspection and dumping to a file - ''' - state = {} - state['id'] = self.id - state['environment'] = self.model - state['_state'] = self._state - return state + # def __getstate__(self): + # ''' + # Serializing an agent will lose all its running information (you cannot + # serialize an iterator), but it keeps the state and link to the environment, + # so it can be used for inspection and dumping to a file + # ''' + # state = {} + # state['id'] = self.id + # state['environment'] = self.model + # state['_state'] = self._state + # return state - def __setstate__(self, state): - ''' - Get back a serialized agent and try to re-compose it - ''' - self.state_id = state['id'] - self._state = state['_state'] - self.model = state['environment'] + # def __setstate__(self, state): + # ''' + # Get back a serialized agent and try to re-compose it + # ''' + # self.state_id = state['id'] + # self._state = state['_state'] + # self.model = state['environment'] class NetworkAgent(BaseAgent): @@ -190,7 +189,13 @@ class NetworkAgent(BaseAgent): def get_neighboring_agents(self, state_id=None, **kwargs): return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs) - def get_agents(self, agents=None, limit_neighbors=False, **kwargs): + def get_agents(self, *args, limit=None, **kwargs): + it = self.iter_agents(*args, **kwargs) + if limit is not None: + it = islice(it, limit) + return list(it) + + def iter_agents(self, agents=None, limit_neighbors=False, **kwargs): if limit_neighbors: agents = self.topology.neighbors(self.unique_id) @@ -199,7 +204,7 @@ class NetworkAgent(BaseAgent): def subgraph(self, center=True, **kwargs): include = [self] if center else [] - return self.topology.subgraph(n.unique_id for n in self.get_agents(**kwargs)+include) + return self.topology.subgraph(n.unique_id for n in list(self.get_agents(**kwargs))+include) def remove_node(self, unique_id): self.topology.remove_node(unique_id) @@ -208,10 +213,10 @@ class NetworkAgent(BaseAgent): # return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs) if self.unique_id not in self.topology.nodes(data=False): raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id)) - if other not in self.topology.nodes(data=False): + if other.unique_id not in self.topology.nodes(data=False): raise ValueError('{} not in list of existing agents in the network'.format(other)) - self.topology.add_edge(self.unique_id, other, edge_attr_dict=edge_attr_dict, *edge_attrs) + self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs) def ego_search(self, steps=1, center=False, node=None, **kwargs): @@ -308,21 +313,21 @@ class FSM(NetworkAgent, metaclass=MetaFSM): def step(self): self.debug(f'Agent {self.unique_id} @ state {self["id"]}') interval = super().step() - if 'id' not in self: - if 'id' in self.state: - self.set_state(self['state_id']) - elif self.default_state: + if 'id' not in self.state: + # if 'id' in self.state: + # self.set_state(self.state['id']) + if self.default_state: self.set_state(self.default_state.id) else: raise Exception('{} has no valid state id or default state'.format(self)) - return self.states[self['id']](self) or interval + return self.states[self.state['id']](self) or interval def set_state(self, state): if hasattr(state, 'id'): state = state.id if state not in self.states: raise ValueError('{} is not a valid state'.format(state)) - self['state_id'] = state + self.state['id'] = state return state @@ -530,8 +535,6 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, except TypeError: agent_type = tuple([agent_type]) - checks = [] - f = agents if ignore: @@ -547,7 +550,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, if iterator: return f - return list(f) + return f from .BassModel import * diff --git a/soil/environment.py b/soil/environment.py index af2dc03..50c0745 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -81,9 +81,6 @@ class Environment(Model): self._history = history.History(name=self.name, backup=True) self['SEED'] = seed - # Add environment agents first, so their events get - # executed before network agents - if network_agents: distro = agents.calculate_distribution(network_agents) @@ -97,7 +94,6 @@ class Environment(Model): environment_agents = agents._convert_agent_types(distro) self.environment_agents = environment_agents - @property def now(self): if self.schedule: @@ -169,6 +165,7 @@ class Environment(Model): unique_id=agent_id, state=state) node['agent'] = a + self.schedule.add(a) return a def add_node(self, agent_type, state=None): @@ -188,8 +185,6 @@ class Environment(Model): def run(self, until, *args, **kwargs): self._save_state() - for agent in self.agents: - self.schedule.add(agent) while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time): self.schedule.step(until=until) @@ -238,8 +233,8 @@ class Environment(Model): def get_agents(self, nodes=None): if nodes is None: - return list(self.agents) - return [self.G.nodes[i]['agent'] for i in nodes] + return self.agents + return (self.G.nodes[i]['agent'] for i in nodes) def dump_csv(self, f): with utils.open_or_reuse(f, 'w') as f: diff --git a/soil/time.py b/soil/time.py index 27e39c9..52ed2eb 100644 --- a/soil/time.py +++ b/soil/time.py @@ -18,6 +18,9 @@ class Delta: def __init__(self, delta): self._delta = delta + def __eq__(self, other): + return self._delta == other._delta + def abs(self, time): return time + self._delta @@ -35,7 +38,7 @@ class TimedActivation(BaseScheduler): def add(self, agent: Agent): if agent.unique_id not in self._agents: heappush(self._queue, (self.time, agent.unique_id)) - super().add(agent) + super().add(agent) def step(self, until: float =float('inf')) -> None: """ diff --git a/tests/test_main.py b/tests/test_main.py index f3cb4fd..db28e19 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -10,6 +10,7 @@ from functools import partial from os.path import join from soil import (simulation, Environment, agents, serialization, history, utils) +from soil.time import Delta ROOT = os.path.abspath(os.path.dirname(__file__)) @@ -356,3 +357,41 @@ class TestMain(TestCase): runs = list(s.run_simulation(dry_run=True)) over = list(x.now for x in runs if x.now>2) assert len(over) == 0 + + + def test_fsm(self): + '''Basic state change''' + class ToggleAgent(agents.FSM): + @agents.default_state + @agents.state + def ping(self): + return self.pong + + @agents.state + def pong(self): + return self.ping + + a = ToggleAgent(unique_id=1, model=Environment()) + assert a.state["id"] == a.ping.id + a.step() + assert a.state["id"] == a.pong.id + a.step() + assert a.state["id"] == a.ping.id + + def test_fsm_when(self): + '''Basic state change''' + class ToggleAgent(agents.FSM): + @agents.default_state + @agents.state + def ping(self): + return self.pong, 2 + + @agents.state + def pong(self): + return self.ping + + a = ToggleAgent(unique_id=1, model=Environment()) + when = a.step() + assert when == 2 + when = a.step() + assert when == Delta(a.interval)