1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-24 11:52:29 +00:00

WIP: mesa compat

All tests pass but some features are still missing/unclear:

- Mesa agents do not have a `state`, so their "metrics" don't get stored. I will
probably refactor this to remove some magic in this regard. This should get rid
of the `_state` dictionary and the setitem/getitem magic.
- Simulation is still different from a runner. So far only Agent and
Environment/Model have been updated.
This commit is contained in:
J. Fernando Sánchez 2021-10-15 13:36:39 +02:00
parent 5d7e57675a
commit af9a392a93
7 changed files with 91 additions and 50 deletions

View File

@ -1,7 +1,6 @@
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
from enum import Enum from enum import Enum
from random import random, choice from random import random, choice
from itertools import islice
import logging import logging
import math import math
@ -22,7 +21,7 @@ class RabbitModel(FSM):
'offspring': 0, 'offspring': 0,
} }
sexual_maturity = 4*30 sexual_maturity = 3 #4*30
life_expectancy = 365 * 3 life_expectancy = 365 * 3
gestation = 33 gestation = 33
pregnancy = -1 pregnancy = -1
@ -31,9 +30,11 @@ class RabbitModel(FSM):
@default_state @default_state
@state @state
def newborn(self): def newborn(self):
self.debug(f'I am a newborn at age {self["age"]}')
self['age'] += 1 self['age'] += 1
if self['age'] >= self.sexual_maturity: if self['age'] >= self.sexual_maturity:
self.debug('I am fertile!')
return self.fertile return self.fertile
@state @state
@ -46,8 +47,7 @@ class RabbitModel(FSM):
return return
# Males try to mate # Males try to mate
females = self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False) for f in self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False, limit=self.max_females):
for f in islice(females, self.max_females):
r = random() r = random()
if r < self['mating_prob']: if r < self['mating_prob']:
self.impregnate(f) self.impregnate(f)

View File

@ -1,7 +1,7 @@
--- ---
load_module: rabbit_agents load_module: rabbit_agents
name: rabbits_example name: rabbits_example
max_time: 500 max_time: 200
interval: 1 interval: 1
seed: MySeed seed: MySeed
agent_type: RabbitModel agent_type: RabbitModel

View File

@ -1,5 +1,6 @@
from scipy.spatial import cKDTree as KDTree from scipy.spatial import cKDTree as KDTree
from . import NetworkAgent import networkx as nx
from . import NetworkAgent, as_node
class Geo(NetworkAgent): class Geo(NetworkAgent):
'''In this type of network, nodes have a "pos" attribute.''' '''In this type of network, nodes have a "pos" attribute.'''

View File

@ -1,12 +1,11 @@
import logging import logging
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial, wraps
from itertools import islice
import json import json
import networkx as nx import networkx as nx
from functools import wraps
from .. import serialization, history, utils, time from .. import serialization, history, utils, time
from mesa import Agent from mesa import Agent
@ -91,7 +90,7 @@ class BaseAgent(Agent):
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, tuple): if isinstance(key, tuple):
key, t_step = key key, t_step = key
k = history.Key(key=key, t_step=t_step, agent_id=self.id) k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id)
return self.model[k] return self.model[k]
return self._state.get(key, None) return self._state.get(key, None)
@ -151,25 +150,25 @@ class BaseAgent(Agent):
def info(self, *args, **kwargs): def info(self, *args, **kwargs):
return self.log(*args, level=logging.INFO, **kwargs) return self.log(*args, level=logging.INFO, **kwargs)
def __getstate__(self): # def __getstate__(self):
''' # '''
Serializing an agent will lose all its running information (you cannot # Serializing an agent will lose all its running information (you cannot
serialize an iterator), but it keeps the state and link to the environment, # serialize an iterator), but it keeps the state and link to the environment,
so it can be used for inspection and dumping to a file # so it can be used for inspection and dumping to a file
''' # '''
state = {} # state = {}
state['id'] = self.id # state['id'] = self.id
state['environment'] = self.model # state['environment'] = self.model
state['_state'] = self._state # state['_state'] = self._state
return state # return state
def __setstate__(self, state): # def __setstate__(self, state):
''' # '''
Get back a serialized agent and try to re-compose it # Get back a serialized agent and try to re-compose it
''' # '''
self.state_id = state['id'] # self.state_id = state['id']
self._state = state['_state'] # self._state = state['_state']
self.model = state['environment'] # self.model = state['environment']
class NetworkAgent(BaseAgent): class NetworkAgent(BaseAgent):
@ -190,7 +189,13 @@ class NetworkAgent(BaseAgent):
def get_neighboring_agents(self, state_id=None, **kwargs): def get_neighboring_agents(self, state_id=None, **kwargs):
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs) return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
def get_agents(self, agents=None, limit_neighbors=False, **kwargs): def get_agents(self, *args, limit=None, **kwargs):
it = self.iter_agents(*args, **kwargs)
if limit is not None:
it = islice(it, limit)
return list(it)
def iter_agents(self, agents=None, limit_neighbors=False, **kwargs):
if limit_neighbors: if limit_neighbors:
agents = self.topology.neighbors(self.unique_id) agents = self.topology.neighbors(self.unique_id)
@ -199,7 +204,7 @@ class NetworkAgent(BaseAgent):
def subgraph(self, center=True, **kwargs): def subgraph(self, center=True, **kwargs):
include = [self] if center else [] include = [self] if center else []
return self.topology.subgraph(n.unique_id for n in self.get_agents(**kwargs)+include) return self.topology.subgraph(n.unique_id for n in list(self.get_agents(**kwargs))+include)
def remove_node(self, unique_id): def remove_node(self, unique_id):
self.topology.remove_node(unique_id) self.topology.remove_node(unique_id)
@ -208,10 +213,10 @@ class NetworkAgent(BaseAgent):
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs) # return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
if self.unique_id not in self.topology.nodes(data=False): if self.unique_id not in self.topology.nodes(data=False):
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id)) raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
if other not in self.topology.nodes(data=False): if other.unique_id not in self.topology.nodes(data=False):
raise ValueError('{} not in list of existing agents in the network'.format(other)) raise ValueError('{} not in list of existing agents in the network'.format(other))
self.topology.add_edge(self.unique_id, other, edge_attr_dict=edge_attr_dict, *edge_attrs) self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
def ego_search(self, steps=1, center=False, node=None, **kwargs): def ego_search(self, steps=1, center=False, node=None, **kwargs):
@ -308,21 +313,21 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
def step(self): def step(self):
self.debug(f'Agent {self.unique_id} @ state {self["id"]}') self.debug(f'Agent {self.unique_id} @ state {self["id"]}')
interval = super().step() interval = super().step()
if 'id' not in self: if 'id' not in self.state:
if 'id' in self.state: # if 'id' in self.state:
self.set_state(self['state_id']) # self.set_state(self.state['id'])
elif self.default_state: if self.default_state:
self.set_state(self.default_state.id) self.set_state(self.default_state.id)
else: else:
raise Exception('{} has no valid state id or default state'.format(self)) raise Exception('{} has no valid state id or default state'.format(self))
return self.states[self['id']](self) or interval return self.states[self.state['id']](self) or interval
def set_state(self, state): def set_state(self, state):
if hasattr(state, 'id'): if hasattr(state, 'id'):
state = state.id state = state.id
if state not in self.states: if state not in self.states:
raise ValueError('{} is not a valid state'.format(state)) raise ValueError('{} is not a valid state'.format(state))
self['state_id'] = state self.state['id'] = state
return state return state
@ -530,8 +535,6 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
except TypeError: except TypeError:
agent_type = tuple([agent_type]) agent_type = tuple([agent_type])
checks = []
f = agents f = agents
if ignore: if ignore:
@ -547,7 +550,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
if iterator: if iterator:
return f return f
return list(f) return f
from .BassModel import * from .BassModel import *

View File

@ -81,9 +81,6 @@ class Environment(Model):
self._history = history.History(name=self.name, self._history = history.History(name=self.name,
backup=True) backup=True)
self['SEED'] = seed self['SEED'] = seed
# Add environment agents first, so their events get
# executed before network agents
if network_agents: if network_agents:
distro = agents.calculate_distribution(network_agents) distro = agents.calculate_distribution(network_agents)
@ -97,7 +94,6 @@ class Environment(Model):
environment_agents = agents._convert_agent_types(distro) environment_agents = agents._convert_agent_types(distro)
self.environment_agents = environment_agents self.environment_agents = environment_agents
@property @property
def now(self): def now(self):
if self.schedule: if self.schedule:
@ -169,6 +165,7 @@ class Environment(Model):
unique_id=agent_id, unique_id=agent_id,
state=state) state=state)
node['agent'] = a node['agent'] = a
self.schedule.add(a)
return a return a
def add_node(self, agent_type, state=None): def add_node(self, agent_type, state=None):
@ -188,8 +185,6 @@ class Environment(Model):
def run(self, until, *args, **kwargs): def run(self, until, *args, **kwargs):
self._save_state() self._save_state()
for agent in self.agents:
self.schedule.add(agent)
while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time): while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time):
self.schedule.step(until=until) self.schedule.step(until=until)
@ -238,8 +233,8 @@ class Environment(Model):
def get_agents(self, nodes=None): def get_agents(self, nodes=None):
if nodes is None: if nodes is None:
return list(self.agents) return self.agents
return [self.G.nodes[i]['agent'] for i in nodes] return (self.G.nodes[i]['agent'] for i in nodes)
def dump_csv(self, f): def dump_csv(self, f):
with utils.open_or_reuse(f, 'w') as f: with utils.open_or_reuse(f, 'w') as f:

View File

@ -18,6 +18,9 @@ class Delta:
def __init__(self, delta): def __init__(self, delta):
self._delta = delta self._delta = delta
def __eq__(self, other):
return self._delta == other._delta
def abs(self, time): def abs(self, time):
return time + self._delta return time + self._delta
@ -35,7 +38,7 @@ class TimedActivation(BaseScheduler):
def add(self, agent: Agent): def add(self, agent: Agent):
if agent.unique_id not in self._agents: if agent.unique_id not in self._agents:
heappush(self._queue, (self.time, agent.unique_id)) heappush(self._queue, (self.time, agent.unique_id))
super().add(agent) super().add(agent)
def step(self, until: float =float('inf')) -> None: def step(self, until: float =float('inf')) -> None:
""" """

View File

@ -10,6 +10,7 @@ from functools import partial
from os.path import join from os.path import join
from soil import (simulation, Environment, agents, serialization, from soil import (simulation, Environment, agents, serialization,
history, utils) history, utils)
from soil.time import Delta
ROOT = os.path.abspath(os.path.dirname(__file__)) ROOT = os.path.abspath(os.path.dirname(__file__))
@ -356,3 +357,41 @@ class TestMain(TestCase):
runs = list(s.run_simulation(dry_run=True)) runs = list(s.run_simulation(dry_run=True))
over = list(x.now for x in runs if x.now>2) over = list(x.now for x in runs if x.now>2)
assert len(over) == 0 assert len(over) == 0
def test_fsm(self):
'''Basic state change'''
class ToggleAgent(agents.FSM):
@agents.default_state
@agents.state
def ping(self):
return self.pong
@agents.state
def pong(self):
return self.ping
a = ToggleAgent(unique_id=1, model=Environment())
assert a.state["id"] == a.ping.id
a.step()
assert a.state["id"] == a.pong.id
a.step()
assert a.state["id"] == a.ping.id
def test_fsm_when(self):
'''Basic state change'''
class ToggleAgent(agents.FSM):
@agents.default_state
@agents.state
def ping(self):
return self.pong, 2
@agents.state
def pong(self):
return self.ping
a = ToggleAgent(unique_id=1, model=Environment())
when = a.step()
assert when == 2
when = a.step()
assert when == Delta(a.interval)