mirror of
https://github.com/gsi-upm/soil
synced 2024-11-21 18:52:28 +00:00
WIP: mesa compat
All tests pass but some features are still missing/unclear: - Mesa agents do not have a `state`, so their "metrics" don't get stored. I will probably refactor this to remove some magic in this regard. This should get rid of the `_state` dictionary and the setitem/getitem magic. - Simulation is still different from a runner. So far only Agent and Environment/Model have been updated.
This commit is contained in:
parent
5d7e57675a
commit
af9a392a93
@ -1,7 +1,6 @@
|
||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||
from enum import Enum
|
||||
from random import random, choice
|
||||
from itertools import islice
|
||||
import logging
|
||||
import math
|
||||
|
||||
@ -22,7 +21,7 @@ class RabbitModel(FSM):
|
||||
'offspring': 0,
|
||||
}
|
||||
|
||||
sexual_maturity = 4*30
|
||||
sexual_maturity = 3 #4*30
|
||||
life_expectancy = 365 * 3
|
||||
gestation = 33
|
||||
pregnancy = -1
|
||||
@ -31,9 +30,11 @@ class RabbitModel(FSM):
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.debug(f'I am a newborn at age {self["age"]}')
|
||||
self['age'] += 1
|
||||
|
||||
if self['age'] >= self.sexual_maturity:
|
||||
self.debug('I am fertile!')
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
@ -46,8 +47,7 @@ class RabbitModel(FSM):
|
||||
return
|
||||
|
||||
# Males try to mate
|
||||
females = self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False)
|
||||
for f in islice(females, self.max_females):
|
||||
for f in self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False, limit=self.max_females):
|
||||
r = random()
|
||||
if r < self['mating_prob']:
|
||||
self.impregnate(f)
|
||||
|
@ -1,7 +1,7 @@
|
||||
---
|
||||
load_module: rabbit_agents
|
||||
name: rabbits_example
|
||||
max_time: 500
|
||||
max_time: 200
|
||||
interval: 1
|
||||
seed: MySeed
|
||||
agent_type: RabbitModel
|
||||
|
@ -1,5 +1,6 @@
|
||||
from scipy.spatial import cKDTree as KDTree
|
||||
from . import NetworkAgent
|
||||
import networkx as nx
|
||||
from . import NetworkAgent, as_node
|
||||
|
||||
class Geo(NetworkAgent):
|
||||
'''In this type of network, nodes have a "pos" attribute.'''
|
||||
|
@ -1,12 +1,11 @@
|
||||
import logging
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from itertools import islice
|
||||
import json
|
||||
import networkx as nx
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from .. import serialization, history, utils, time
|
||||
|
||||
from mesa import Agent
|
||||
@ -91,7 +90,7 @@ class BaseAgent(Agent):
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, tuple):
|
||||
key, t_step = key
|
||||
k = history.Key(key=key, t_step=t_step, agent_id=self.id)
|
||||
k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id)
|
||||
return self.model[k]
|
||||
return self._state.get(key, None)
|
||||
|
||||
@ -151,25 +150,25 @@ class BaseAgent(Agent):
|
||||
def info(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.INFO, **kwargs)
|
||||
|
||||
def __getstate__(self):
|
||||
'''
|
||||
Serializing an agent will lose all its running information (you cannot
|
||||
serialize an iterator), but it keeps the state and link to the environment,
|
||||
so it can be used for inspection and dumping to a file
|
||||
'''
|
||||
state = {}
|
||||
state['id'] = self.id
|
||||
state['environment'] = self.model
|
||||
state['_state'] = self._state
|
||||
return state
|
||||
# def __getstate__(self):
|
||||
# '''
|
||||
# Serializing an agent will lose all its running information (you cannot
|
||||
# serialize an iterator), but it keeps the state and link to the environment,
|
||||
# so it can be used for inspection and dumping to a file
|
||||
# '''
|
||||
# state = {}
|
||||
# state['id'] = self.id
|
||||
# state['environment'] = self.model
|
||||
# state['_state'] = self._state
|
||||
# return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
'''
|
||||
Get back a serialized agent and try to re-compose it
|
||||
'''
|
||||
self.state_id = state['id']
|
||||
self._state = state['_state']
|
||||
self.model = state['environment']
|
||||
# def __setstate__(self, state):
|
||||
# '''
|
||||
# Get back a serialized agent and try to re-compose it
|
||||
# '''
|
||||
# self.state_id = state['id']
|
||||
# self._state = state['_state']
|
||||
# self.model = state['environment']
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
|
||||
@ -190,7 +189,13 @@ class NetworkAgent(BaseAgent):
|
||||
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||
|
||||
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
||||
def get_agents(self, *args, limit=None, **kwargs):
|
||||
it = self.iter_agents(*args, **kwargs)
|
||||
if limit is not None:
|
||||
it = islice(it, limit)
|
||||
return list(it)
|
||||
|
||||
def iter_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
||||
if limit_neighbors:
|
||||
agents = self.topology.neighbors(self.unique_id)
|
||||
|
||||
@ -199,7 +204,7 @@ class NetworkAgent(BaseAgent):
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
return self.topology.subgraph(n.unique_id for n in self.get_agents(**kwargs)+include)
|
||||
return self.topology.subgraph(n.unique_id for n in list(self.get_agents(**kwargs))+include)
|
||||
|
||||
def remove_node(self, unique_id):
|
||||
self.topology.remove_node(unique_id)
|
||||
@ -208,10 +213,10 @@ class NetworkAgent(BaseAgent):
|
||||
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
||||
if self.unique_id not in self.topology.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
||||
if other not in self.topology.nodes(data=False):
|
||||
if other.unique_id not in self.topology.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
||||
|
||||
self.topology.add_edge(self.unique_id, other, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
|
||||
|
||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||
@ -308,21 +313,21 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
|
||||
def step(self):
|
||||
self.debug(f'Agent {self.unique_id} @ state {self["id"]}')
|
||||
interval = super().step()
|
||||
if 'id' not in self:
|
||||
if 'id' in self.state:
|
||||
self.set_state(self['state_id'])
|
||||
elif self.default_state:
|
||||
if 'id' not in self.state:
|
||||
# if 'id' in self.state:
|
||||
# self.set_state(self.state['id'])
|
||||
if self.default_state:
|
||||
self.set_state(self.default_state.id)
|
||||
else:
|
||||
raise Exception('{} has no valid state id or default state'.format(self))
|
||||
return self.states[self['id']](self) or interval
|
||||
return self.states[self.state['id']](self) or interval
|
||||
|
||||
def set_state(self, state):
|
||||
if hasattr(state, 'id'):
|
||||
state = state.id
|
||||
if state not in self.states:
|
||||
raise ValueError('{} is not a valid state'.format(state))
|
||||
self['state_id'] = state
|
||||
self.state['id'] = state
|
||||
return state
|
||||
|
||||
|
||||
@ -530,8 +535,6 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
|
||||
except TypeError:
|
||||
agent_type = tuple([agent_type])
|
||||
|
||||
checks = []
|
||||
|
||||
f = agents
|
||||
|
||||
if ignore:
|
||||
@ -547,7 +550,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
|
||||
|
||||
if iterator:
|
||||
return f
|
||||
return list(f)
|
||||
return f
|
||||
|
||||
|
||||
from .BassModel import *
|
||||
|
@ -81,9 +81,6 @@ class Environment(Model):
|
||||
self._history = history.History(name=self.name,
|
||||
backup=True)
|
||||
self['SEED'] = seed
|
||||
# Add environment agents first, so their events get
|
||||
# executed before network agents
|
||||
|
||||
|
||||
if network_agents:
|
||||
distro = agents.calculate_distribution(network_agents)
|
||||
@ -97,7 +94,6 @@ class Environment(Model):
|
||||
environment_agents = agents._convert_agent_types(distro)
|
||||
self.environment_agents = environment_agents
|
||||
|
||||
|
||||
@property
|
||||
def now(self):
|
||||
if self.schedule:
|
||||
@ -169,6 +165,7 @@ class Environment(Model):
|
||||
unique_id=agent_id,
|
||||
state=state)
|
||||
node['agent'] = a
|
||||
self.schedule.add(a)
|
||||
return a
|
||||
|
||||
def add_node(self, agent_type, state=None):
|
||||
@ -188,8 +185,6 @@ class Environment(Model):
|
||||
|
||||
def run(self, until, *args, **kwargs):
|
||||
self._save_state()
|
||||
for agent in self.agents:
|
||||
self.schedule.add(agent)
|
||||
|
||||
while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time):
|
||||
self.schedule.step(until=until)
|
||||
@ -238,8 +233,8 @@ class Environment(Model):
|
||||
|
||||
def get_agents(self, nodes=None):
|
||||
if nodes is None:
|
||||
return list(self.agents)
|
||||
return [self.G.nodes[i]['agent'] for i in nodes]
|
||||
return self.agents
|
||||
return (self.G.nodes[i]['agent'] for i in nodes)
|
||||
|
||||
def dump_csv(self, f):
|
||||
with utils.open_or_reuse(f, 'w') as f:
|
||||
|
@ -18,6 +18,9 @@ class Delta:
|
||||
def __init__(self, delta):
|
||||
self._delta = delta
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._delta == other._delta
|
||||
|
||||
def abs(self, time):
|
||||
return time + self._delta
|
||||
|
||||
|
@ -10,6 +10,7 @@ from functools import partial
|
||||
from os.path import join
|
||||
from soil import (simulation, Environment, agents, serialization,
|
||||
history, utils)
|
||||
from soil.time import Delta
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
@ -356,3 +357,41 @@ class TestMain(TestCase):
|
||||
runs = list(s.run_simulation(dry_run=True))
|
||||
over = list(x.now for x in runs if x.now>2)
|
||||
assert len(over) == 0
|
||||
|
||||
|
||||
def test_fsm(self):
|
||||
'''Basic state change'''
|
||||
class ToggleAgent(agents.FSM):
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def ping(self):
|
||||
return self.pong
|
||||
|
||||
@agents.state
|
||||
def pong(self):
|
||||
return self.ping
|
||||
|
||||
a = ToggleAgent(unique_id=1, model=Environment())
|
||||
assert a.state["id"] == a.ping.id
|
||||
a.step()
|
||||
assert a.state["id"] == a.pong.id
|
||||
a.step()
|
||||
assert a.state["id"] == a.ping.id
|
||||
|
||||
def test_fsm_when(self):
|
||||
'''Basic state change'''
|
||||
class ToggleAgent(agents.FSM):
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def ping(self):
|
||||
return self.pong, 2
|
||||
|
||||
@agents.state
|
||||
def pong(self):
|
||||
return self.ping
|
||||
|
||||
a = ToggleAgent(unique_id=1, model=Environment())
|
||||
when = a.step()
|
||||
assert when == 2
|
||||
when = a.step()
|
||||
assert when == Delta(a.interval)
|
||||
|
Loading…
Reference in New Issue
Block a user