mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
WIP: mesa compat
All tests pass but some features are still missing/unclear: - Mesa agents do not have a `state`, so their "metrics" don't get stored. I will probably refactor this to remove some magic in this regard. This should get rid of the `_state` dictionary and the setitem/getitem magic. - Simulation is still different from a runner. So far only Agent and Environment/Model have been updated.
This commit is contained in:
parent
5d7e57675a
commit
af9a392a93
@ -1,7 +1,6 @@
|
|||||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from random import random, choice
|
from random import random, choice
|
||||||
from itertools import islice
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
@ -22,7 +21,7 @@ class RabbitModel(FSM):
|
|||||||
'offspring': 0,
|
'offspring': 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
sexual_maturity = 4*30
|
sexual_maturity = 3 #4*30
|
||||||
life_expectancy = 365 * 3
|
life_expectancy = 365 * 3
|
||||||
gestation = 33
|
gestation = 33
|
||||||
pregnancy = -1
|
pregnancy = -1
|
||||||
@ -31,9 +30,11 @@ class RabbitModel(FSM):
|
|||||||
@default_state
|
@default_state
|
||||||
@state
|
@state
|
||||||
def newborn(self):
|
def newborn(self):
|
||||||
|
self.debug(f'I am a newborn at age {self["age"]}')
|
||||||
self['age'] += 1
|
self['age'] += 1
|
||||||
|
|
||||||
if self['age'] >= self.sexual_maturity:
|
if self['age'] >= self.sexual_maturity:
|
||||||
|
self.debug('I am fertile!')
|
||||||
return self.fertile
|
return self.fertile
|
||||||
|
|
||||||
@state
|
@state
|
||||||
@ -46,8 +47,7 @@ class RabbitModel(FSM):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Males try to mate
|
# Males try to mate
|
||||||
females = self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False)
|
for f in self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False, limit=self.max_females):
|
||||||
for f in islice(females, self.max_females):
|
|
||||||
r = random()
|
r = random()
|
||||||
if r < self['mating_prob']:
|
if r < self['mating_prob']:
|
||||||
self.impregnate(f)
|
self.impregnate(f)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
---
|
---
|
||||||
load_module: rabbit_agents
|
load_module: rabbit_agents
|
||||||
name: rabbits_example
|
name: rabbits_example
|
||||||
max_time: 500
|
max_time: 200
|
||||||
interval: 1
|
interval: 1
|
||||||
seed: MySeed
|
seed: MySeed
|
||||||
agent_type: RabbitModel
|
agent_type: RabbitModel
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from scipy.spatial import cKDTree as KDTree
|
from scipy.spatial import cKDTree as KDTree
|
||||||
from . import NetworkAgent
|
import networkx as nx
|
||||||
|
from . import NetworkAgent, as_node
|
||||||
|
|
||||||
class Geo(NetworkAgent):
|
class Geo(NetworkAgent):
|
||||||
'''In this type of network, nodes have a "pos" attribute.'''
|
'''In this type of network, nodes have a "pos" attribute.'''
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial, wraps
|
||||||
|
from itertools import islice
|
||||||
import json
|
import json
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
from .. import serialization, history, utils, time
|
from .. import serialization, history, utils, time
|
||||||
|
|
||||||
from mesa import Agent
|
from mesa import Agent
|
||||||
@ -91,7 +90,7 @@ class BaseAgent(Agent):
|
|||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, tuple):
|
if isinstance(key, tuple):
|
||||||
key, t_step = key
|
key, t_step = key
|
||||||
k = history.Key(key=key, t_step=t_step, agent_id=self.id)
|
k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id)
|
||||||
return self.model[k]
|
return self.model[k]
|
||||||
return self._state.get(key, None)
|
return self._state.get(key, None)
|
||||||
|
|
||||||
@ -151,25 +150,25 @@ class BaseAgent(Agent):
|
|||||||
def info(self, *args, **kwargs):
|
def info(self, *args, **kwargs):
|
||||||
return self.log(*args, level=logging.INFO, **kwargs)
|
return self.log(*args, level=logging.INFO, **kwargs)
|
||||||
|
|
||||||
def __getstate__(self):
|
# def __getstate__(self):
|
||||||
'''
|
# '''
|
||||||
Serializing an agent will lose all its running information (you cannot
|
# Serializing an agent will lose all its running information (you cannot
|
||||||
serialize an iterator), but it keeps the state and link to the environment,
|
# serialize an iterator), but it keeps the state and link to the environment,
|
||||||
so it can be used for inspection and dumping to a file
|
# so it can be used for inspection and dumping to a file
|
||||||
'''
|
# '''
|
||||||
state = {}
|
# state = {}
|
||||||
state['id'] = self.id
|
# state['id'] = self.id
|
||||||
state['environment'] = self.model
|
# state['environment'] = self.model
|
||||||
state['_state'] = self._state
|
# state['_state'] = self._state
|
||||||
return state
|
# return state
|
||||||
|
|
||||||
def __setstate__(self, state):
|
# def __setstate__(self, state):
|
||||||
'''
|
# '''
|
||||||
Get back a serialized agent and try to re-compose it
|
# Get back a serialized agent and try to re-compose it
|
||||||
'''
|
# '''
|
||||||
self.state_id = state['id']
|
# self.state_id = state['id']
|
||||||
self._state = state['_state']
|
# self._state = state['_state']
|
||||||
self.model = state['environment']
|
# self.model = state['environment']
|
||||||
|
|
||||||
class NetworkAgent(BaseAgent):
|
class NetworkAgent(BaseAgent):
|
||||||
|
|
||||||
@ -190,7 +189,13 @@ class NetworkAgent(BaseAgent):
|
|||||||
def get_neighboring_agents(self, state_id=None, **kwargs):
|
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||||
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||||
|
|
||||||
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
def get_agents(self, *args, limit=None, **kwargs):
|
||||||
|
it = self.iter_agents(*args, **kwargs)
|
||||||
|
if limit is not None:
|
||||||
|
it = islice(it, limit)
|
||||||
|
return list(it)
|
||||||
|
|
||||||
|
def iter_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
||||||
if limit_neighbors:
|
if limit_neighbors:
|
||||||
agents = self.topology.neighbors(self.unique_id)
|
agents = self.topology.neighbors(self.unique_id)
|
||||||
|
|
||||||
@ -199,7 +204,7 @@ class NetworkAgent(BaseAgent):
|
|||||||
|
|
||||||
def subgraph(self, center=True, **kwargs):
|
def subgraph(self, center=True, **kwargs):
|
||||||
include = [self] if center else []
|
include = [self] if center else []
|
||||||
return self.topology.subgraph(n.unique_id for n in self.get_agents(**kwargs)+include)
|
return self.topology.subgraph(n.unique_id for n in list(self.get_agents(**kwargs))+include)
|
||||||
|
|
||||||
def remove_node(self, unique_id):
|
def remove_node(self, unique_id):
|
||||||
self.topology.remove_node(unique_id)
|
self.topology.remove_node(unique_id)
|
||||||
@ -208,10 +213,10 @@ class NetworkAgent(BaseAgent):
|
|||||||
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
||||||
if self.unique_id not in self.topology.nodes(data=False):
|
if self.unique_id not in self.topology.nodes(data=False):
|
||||||
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
||||||
if other not in self.topology.nodes(data=False):
|
if other.unique_id not in self.topology.nodes(data=False):
|
||||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
||||||
|
|
||||||
self.topology.add_edge(self.unique_id, other, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||||
|
|
||||||
|
|
||||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||||
@ -308,21 +313,21 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
|
|||||||
def step(self):
|
def step(self):
|
||||||
self.debug(f'Agent {self.unique_id} @ state {self["id"]}')
|
self.debug(f'Agent {self.unique_id} @ state {self["id"]}')
|
||||||
interval = super().step()
|
interval = super().step()
|
||||||
if 'id' not in self:
|
if 'id' not in self.state:
|
||||||
if 'id' in self.state:
|
# if 'id' in self.state:
|
||||||
self.set_state(self['state_id'])
|
# self.set_state(self.state['id'])
|
||||||
elif self.default_state:
|
if self.default_state:
|
||||||
self.set_state(self.default_state.id)
|
self.set_state(self.default_state.id)
|
||||||
else:
|
else:
|
||||||
raise Exception('{} has no valid state id or default state'.format(self))
|
raise Exception('{} has no valid state id or default state'.format(self))
|
||||||
return self.states[self['id']](self) or interval
|
return self.states[self.state['id']](self) or interval
|
||||||
|
|
||||||
def set_state(self, state):
|
def set_state(self, state):
|
||||||
if hasattr(state, 'id'):
|
if hasattr(state, 'id'):
|
||||||
state = state.id
|
state = state.id
|
||||||
if state not in self.states:
|
if state not in self.states:
|
||||||
raise ValueError('{} is not a valid state'.format(state))
|
raise ValueError('{} is not a valid state'.format(state))
|
||||||
self['state_id'] = state
|
self.state['id'] = state
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
@ -530,8 +535,6 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
agent_type = tuple([agent_type])
|
agent_type = tuple([agent_type])
|
||||||
|
|
||||||
checks = []
|
|
||||||
|
|
||||||
f = agents
|
f = agents
|
||||||
|
|
||||||
if ignore:
|
if ignore:
|
||||||
@ -547,7 +550,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
|
|||||||
|
|
||||||
if iterator:
|
if iterator:
|
||||||
return f
|
return f
|
||||||
return list(f)
|
return f
|
||||||
|
|
||||||
|
|
||||||
from .BassModel import *
|
from .BassModel import *
|
||||||
|
@ -81,9 +81,6 @@ class Environment(Model):
|
|||||||
self._history = history.History(name=self.name,
|
self._history = history.History(name=self.name,
|
||||||
backup=True)
|
backup=True)
|
||||||
self['SEED'] = seed
|
self['SEED'] = seed
|
||||||
# Add environment agents first, so their events get
|
|
||||||
# executed before network agents
|
|
||||||
|
|
||||||
|
|
||||||
if network_agents:
|
if network_agents:
|
||||||
distro = agents.calculate_distribution(network_agents)
|
distro = agents.calculate_distribution(network_agents)
|
||||||
@ -97,7 +94,6 @@ class Environment(Model):
|
|||||||
environment_agents = agents._convert_agent_types(distro)
|
environment_agents = agents._convert_agent_types(distro)
|
||||||
self.environment_agents = environment_agents
|
self.environment_agents = environment_agents
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def now(self):
|
def now(self):
|
||||||
if self.schedule:
|
if self.schedule:
|
||||||
@ -169,6 +165,7 @@ class Environment(Model):
|
|||||||
unique_id=agent_id,
|
unique_id=agent_id,
|
||||||
state=state)
|
state=state)
|
||||||
node['agent'] = a
|
node['agent'] = a
|
||||||
|
self.schedule.add(a)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def add_node(self, agent_type, state=None):
|
def add_node(self, agent_type, state=None):
|
||||||
@ -188,8 +185,6 @@ class Environment(Model):
|
|||||||
|
|
||||||
def run(self, until, *args, **kwargs):
|
def run(self, until, *args, **kwargs):
|
||||||
self._save_state()
|
self._save_state()
|
||||||
for agent in self.agents:
|
|
||||||
self.schedule.add(agent)
|
|
||||||
|
|
||||||
while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time):
|
while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time):
|
||||||
self.schedule.step(until=until)
|
self.schedule.step(until=until)
|
||||||
@ -238,8 +233,8 @@ class Environment(Model):
|
|||||||
|
|
||||||
def get_agents(self, nodes=None):
|
def get_agents(self, nodes=None):
|
||||||
if nodes is None:
|
if nodes is None:
|
||||||
return list(self.agents)
|
return self.agents
|
||||||
return [self.G.nodes[i]['agent'] for i in nodes]
|
return (self.G.nodes[i]['agent'] for i in nodes)
|
||||||
|
|
||||||
def dump_csv(self, f):
|
def dump_csv(self, f):
|
||||||
with utils.open_or_reuse(f, 'w') as f:
|
with utils.open_or_reuse(f, 'w') as f:
|
||||||
|
@ -18,6 +18,9 @@ class Delta:
|
|||||||
def __init__(self, delta):
|
def __init__(self, delta):
|
||||||
self._delta = delta
|
self._delta = delta
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self._delta == other._delta
|
||||||
|
|
||||||
def abs(self, time):
|
def abs(self, time):
|
||||||
return time + self._delta
|
return time + self._delta
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from functools import partial
|
|||||||
from os.path import join
|
from os.path import join
|
||||||
from soil import (simulation, Environment, agents, serialization,
|
from soil import (simulation, Environment, agents, serialization,
|
||||||
history, utils)
|
history, utils)
|
||||||
|
from soil.time import Delta
|
||||||
|
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
@ -356,3 +357,41 @@ class TestMain(TestCase):
|
|||||||
runs = list(s.run_simulation(dry_run=True))
|
runs = list(s.run_simulation(dry_run=True))
|
||||||
over = list(x.now for x in runs if x.now>2)
|
over = list(x.now for x in runs if x.now>2)
|
||||||
assert len(over) == 0
|
assert len(over) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_fsm(self):
|
||||||
|
'''Basic state change'''
|
||||||
|
class ToggleAgent(agents.FSM):
|
||||||
|
@agents.default_state
|
||||||
|
@agents.state
|
||||||
|
def ping(self):
|
||||||
|
return self.pong
|
||||||
|
|
||||||
|
@agents.state
|
||||||
|
def pong(self):
|
||||||
|
return self.ping
|
||||||
|
|
||||||
|
a = ToggleAgent(unique_id=1, model=Environment())
|
||||||
|
assert a.state["id"] == a.ping.id
|
||||||
|
a.step()
|
||||||
|
assert a.state["id"] == a.pong.id
|
||||||
|
a.step()
|
||||||
|
assert a.state["id"] == a.ping.id
|
||||||
|
|
||||||
|
def test_fsm_when(self):
|
||||||
|
'''Basic state change'''
|
||||||
|
class ToggleAgent(agents.FSM):
|
||||||
|
@agents.default_state
|
||||||
|
@agents.state
|
||||||
|
def ping(self):
|
||||||
|
return self.pong, 2
|
||||||
|
|
||||||
|
@agents.state
|
||||||
|
def pong(self):
|
||||||
|
return self.ping
|
||||||
|
|
||||||
|
a = ToggleAgent(unique_id=1, model=Environment())
|
||||||
|
when = a.step()
|
||||||
|
assert when == 2
|
||||||
|
when = a.step()
|
||||||
|
assert when == Delta(a.interval)
|
||||||
|
Loading…
Reference in New Issue
Block a user