Partial MESA compatibility and several fixes

Documentation for the new APIs is still a work in progress :)
pull/8/head
J. Fernando Sánchez 3 years ago
parent af9a392a93
commit 6c4f44b4cb

1
.gitignore vendored

@ -8,3 +8,4 @@ soil_output
docs/_build* docs/_build*
build/* build/*
dist/* dist/*
prof

@ -5,23 +5,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased] ## [Unreleased]
### Added ### Added
* [WIP] Integration with MESA * Integration with MESA
* `not_agent_ids` parameter to get sql in history * `not_agent_ids` parameter to get sql in history
### Changed ### Changed
* `soil.Environment` now also inherits from `mesa.Model` * `soil.Environment` now also inherits from `mesa.Model`
* `soil.Agent` now also inherits from `mesa.Agent` * `soil.Agent` now also inherits from `mesa.Agent`
* `soil.time` to replace `simpy` events, delays, duration, etc. * `soil.time` to replace `simpy` events, delays, duration, etc.
* `agent.id` is now `agent.unique_id` to be compatible with `mesa`. A property `BaseAgent.id` has been added for compatibility.
* `agent.environment` is now `agent.model`, for the same reason as above. The parameter name in `BaseAgent.__init__` has also been renamed.
### Removed ### Removed
* `simpy` dependency and compatibility. Each agent used to be a simpy generator, but that made debugging and error handling more complex. That has been replaced by a scheduler within the `soil.Environment` class, similar to how `mesa` does it. * `simpy` dependency and compatibility. Each agent used to be a simpy generator, but that made debugging and error handling more complex. That has been replaced by a scheduler within the `soil.Environment` class, similar to how `mesa` does it.
* `soil.history` is now a separate package named `tsih`. The keys namedtuple uses `dict_id` instead of `agent_id`.
### TODO: ### Added
* agent_id -> unique_id? * An option to choose whether a database should be used for history
* mesa has Agent.model and soil has Agent.env
* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that
* soil.History should mimic a mesa.datacollector :/
* soil.Simulation *could* mimic a mesa.batchrunner
* DONE include scheduler in environment
* DONE environment inherits from `mesa.Model`
## [0.15.2] ## [0.15.2]
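
The two renames listed under "Changed" (`agent.id` to `agent.unique_id`, `agent.environment` to `agent.model`) can be illustrated with a minimal sketch. It does not import `mesa` or `soil`: `Model` and `Agent` below are hypothetical stand-ins for `mesa.Model` and `mesa.Agent`, and the `id` property only mirrors the compatibility property that the changelog describes for `BaseAgent`.

```python
class Model:
    """Hypothetical stand-in for mesa.Model (assumption, not the real class)."""

    def __init__(self, name="demo"):
        self.name = name


class Agent:
    """Hypothetical stand-in for mesa.Agent: stores unique_id and model."""

    def __init__(self, unique_id, model):
        self.unique_id = unique_id
        self.model = model


class BaseAgent(Agent):
    """Sketch of the compatibility layer described in the changelog."""

    @property
    def id(self):
        # Old-style accessor kept as a read-only alias of the new name.
        return self.unique_id


agent = BaseAgent(unique_id=7, model=Model())
print(agent.id, agent.unique_id, agent.model.name)
```

Code written against the old attribute names would read `agent.id` unchanged, while anything that used `agent.environment` would need to switch to `agent.model`.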

@ -5,6 +5,9 @@ Learn how to run your own simulations with our [documentation](http://soilsim.re
Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models. Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models.
## Citation
If you use Soil in your research, don't forget to cite this paper: If you use Soil in your research, don't forget to cite this paper:
```bibtex ```bibtex
@ -28,7 +31,24 @@ If you use Soil in your research, don't forget to cite this paper:
``` ```
@Copyright GSI - Universidad Politécnica de Madrid 2017 ## Mesa compatibility
[![SOIL](logo_gsi.png)](https://www.gsi.upm.es) Soil is in the process of becoming fully compatible with MESA.
As of this writing,
This is a non-exhaustive list of tasks to achieve compatibility:
* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that
- [ ] Integrate `soil.Simulation` with mesa's runners:
- [ ] `soil.Simulation` could mimic/become a `mesa.batchrunner`
- [ ] Integrate `soil.Environment` with `mesa.Model`:
- [x] `Soil.Environment` inherits from `mesa.Model`
- [x] `Soil.Environment` includes a Mesa-like Scheduler (see the `soil.time` module).
- [ ] Integrate `soil.Agent` with `mesa.Agent`:
- [x] Rename agent.id to unique_id?
- [x] mesa agents can be used in soil simulations (see `examples/mesa`)
- [ ] Document the new APIs and usage
@Copyright GSI - Universidad Politécnica de Madrid 2017-2021
[![SOIL](logo_gsi.png)](https://www.gsi.upm.es)

@ -13,7 +13,7 @@ network_agents:
- agent_type: CounterModel - agent_type: CounterModel
weight: 1 weight: 1
state: state:
id: 0 state_id: 0
- agent_type: AggregatedCounter - agent_type: AggregatedCounter
weight: 0.2 weight: 0.2
environment_agents: [] environment_agents: []

@ -13,4 +13,4 @@ network_agents:
- agent_type: CounterModel - agent_type: CounterModel
weight: 1 weight: 1
state: state:
id: 0 state_id: 0

@ -23,7 +23,6 @@ def network_portrayal(env):
} }
for (agent_id) in env.G.nodes for (agent_id) in env.G.nodes
] ]
# import pdb;pdb.set_trace()
portrayal["edges"] = [ portrayal["edges"] = [
{"id": edge_id, "source": source, "target": target, "color": "#000000"} {"id": edge_id, "source": source, "target": target, "color": "#000000"}
@ -65,7 +64,7 @@ model_params = {
"N": UserSettableParameter( "N": UserSettableParameter(
"slider", "slider",
"N", "N",
1, 5,
1, 1,
10, 10,
1, 1,

@ -34,9 +34,7 @@ class MoneyAgent(MesaAgent):
self.pos, self.pos,
moore=True, moore=True,
include_center=False) include_center=False)
print(self.pos, possible_steps)
new_position = self.random.choice(possible_steps) new_position = self.random.choice(possible_steps)
print(self.pos, new_position)
self.model.grid.move_agent(self, new_position) self.model.grid.move_agent(self, new_position)
def give_money(self): def give_money(self):
@ -74,21 +72,13 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
class MoneyEnv(Environment): class MoneyEnv(Environment):
"""A model with some number of agents.""" """A model with some number of agents."""
def __init__(self, N, width, height, *args, network_params, **kwargs): def __init__(self, N, width, height, *args, network_params, **kwargs):
self.initialized = True
# import pdb;pdb.set_trace()
network_params['n'] = N network_params['n'] = N
super().__init__(*args, network_params=network_params, **kwargs) super().__init__(*args, network_params=network_params, **kwargs)
self.grid = MultiGrid(width, height, False) self.grid = MultiGrid(width, height, False)
# self.schedule = RandomActivation(self)
self.running = True
# Create agents # Create agents
for agent in self.agents: for agent in self.agents:
self.schedule.add(agent)
# a = MoneyAgent(i, self)
# self.schedule.add(a)
# Add the agent to a random grid cell
x = self.random.randrange(self.grid.width) x = self.random.randrange(self.grid.width)
y = self.random.randrange(self.grid.height) y = self.random.randrange(self.grid.height)
self.grid.place_agent(agent, (x, y)) self.grid.place_agent(agent, (x, y))
@ -97,10 +87,6 @@ class MoneyEnv(Environment):
model_reporters={"Gini": compute_gini}, model_reporters={"Gini": compute_gini},
agent_reporters={"Wealth": "wealth"}) agent_reporters={"Wealth": "wealth"})
def step(self):
super().step()
self.datacollector.collect(self)
self.schedule.step()
def graph_generator(n=5): def graph_generator(n=5):
G = nx.Graph() G = nx.Graph()

@ -68,12 +68,12 @@ network_agents:
- agent_type: HerdViewer - agent_type: HerdViewer
state: state:
has_tv: true has_tv: true
id: neutral state_id: neutral
weight: 1 weight: 1
- agent_type: HerdViewer - agent_type: HerdViewer
state: state:
has_tv: true has_tv: true
id: neutral state_id: neutral
weight: 1 weight: 1
network_params: network_params:
generator: barabasi_albert_graph generator: barabasi_albert_graph
@ -95,7 +95,7 @@ network_agents:
- agent_type: HerdViewer - agent_type: HerdViewer
state: state:
has_tv: true has_tv: true
id: neutral state_id: neutral
weight: 1 weight: 1
- agent_type: WiseViewer - agent_type: WiseViewer
state: state:
@ -121,7 +121,7 @@ network_agents:
- agent_type: WiseViewer - agent_type: WiseViewer
state: state:
has_tv: true has_tv: true
id: neutral state_id: neutral
weight: 1 weight: 1
- agent_type: WiseViewer - agent_type: WiseViewer
state: state:

@ -34,8 +34,6 @@ class HerdViewer(DumbViewer):
A viewer whose probability of infection depends on the state of its neighbors. A viewer whose probability of infection depends on the state of its neighbors.
''' '''
level = logging.DEBUG
def infect(self): def infect(self):
infected = self.count_neighboring_agents(state_id=self.infected.id) infected = self.count_neighboring_agents(state_id=self.infected.id)
total = self.count_neighboring_agents() total = self.count_neighboring_agents()

@ -1,7 +1,7 @@
--- ---
load_module: rabbit_agents load_module: rabbit_agents
name: rabbits_example name: rabbits_example
max_time: 200 max_time: 150
interval: 1 interval: 1
seed: MySeed seed: MySeed
agent_type: RabbitModel agent_type: RabbitModel

@ -16,7 +16,7 @@ template:
- agent_type: CounterModel - agent_type: CounterModel
weight: "{{ x1 }}" weight: "{{ x1 }}"
state: state:
id: 0 state_id: 0
- agent_type: AggregatedCounter - agent_type: AggregatedCounter
weight: "{{ 1 - x1 }}" weight: "{{ 1 - x1 }}"
environment_params: environment_params:

@ -6,3 +6,4 @@ pandas>=0.23
SALib>=1.3 SALib>=1.3
Jinja2 Jinja2
Mesa>=0.8 Mesa>=0.8
tsih>=0.1.5

@ -1 +1 @@
0.15.2 0.20.0

@ -15,7 +15,6 @@ from .agents import *
from . import agents from . import agents
from .simulation import * from .simulation import *
from .environment import Environment from .environment import Environment
from .history import History
from . import serialization from . import serialization
from . import analysis from . import analysis
from .utils import logger from .utils import logger

@ -6,7 +6,9 @@ from itertools import islice
import json import json
import networkx as nx import networkx as nx
from .. import serialization, history, utils, time from .. import serialization, utils, time
from tsih import Key
from mesa import Agent from mesa import Agent
@ -16,6 +18,7 @@ def as_node(agent):
return agent.id return agent.id
return agent return agent
IGNORED_FIELDS = ('model', 'logger')
class BaseAgent(Agent): class BaseAgent(Agent):
""" """
@ -27,23 +30,20 @@ class BaseAgent(Agent):
def __init__(self, def __init__(self,
unique_id, unique_id,
model, model,
state=None,
name=None, name=None,
interval=None): interval=None):
# Check for REQUIRED arguments # Check for REQUIRED arguments
# Initialize agent parameters # Initialize agent parameters
if isinstance(unique_id, Agent): if isinstance(unique_id, Agent):
raise Exception() raise Exception()
self._saved = set()
super().__init__(unique_id=unique_id, model=model) super().__init__(unique_id=unique_id, model=model)
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id) self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
self._neighbors = None self._neighbors = None
self.alive = True self.alive = True
real_state = deepcopy(self.defaults)
real_state.update(state or {})
self.state = real_state
self.interval = interval or self.get('interval', getattr(self.model, 'interval', 1)) self.interval = interval or self.get('interval', 1)
self.logger = logging.getLogger(self.model.name).getChild(self.name) self.logger = logging.getLogger(self.model.name).getChild(self.name)
if hasattr(self, 'level'): if hasattr(self, 'level'):
@ -75,7 +75,6 @@ class BaseAgent(Agent):
@state.setter @state.setter
def state(self, value): def state(self, value):
self._state = {}
for k, v in value.items(): for k, v in value.items():
self[k] = v self[k] = v
@ -87,28 +86,36 @@ class BaseAgent(Agent):
def environment_params(self, value): def environment_params(self, value):
self.model.environment_params = value self.model.environment_params = value
def __setattr__(self, key, value):
if not key.startswith('_') and key not in IGNORED_FIELDS:
try:
k = Key(t_step=self.now,
dict_id=self.unique_id,
key=key)
self._saved.add(key)
self.model[k] = value
except AttributeError:
pass
super().__setattr__(key, value)
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, tuple): if isinstance(key, tuple):
key, t_step = key key, t_step = key
k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id) k = Key(key=key, t_step=t_step, dict_id=self.unique_id)
return self.model[k] return self.model[k]
return self._state.get(key, None) return getattr(self, key)
def __delitem__(self, key): def __delitem__(self, key):
self._state[key] = None return delattr(self, key)
def __contains__(self, key): def __contains__(self, key):
return key in self._state return hasattr(self, key)
def __setitem__(self, key, value): def __setitem__(self, key, value):
self._state[key] = value setattr(self, key, value)
k = history.Key(t_step=self.now,
agent_id=self.id,
key=key)
self.model[k] = value
def items(self): def items(self):
return self._state.items() return ((k, getattr(self, k)) for k in self._saved)
def get(self, key, default=None): def get(self, key, default=None):
return self[key] if key in self else default return self[key] if key in self else default
@ -150,25 +157,6 @@ class BaseAgent(Agent):
def info(self, *args, **kwargs): def info(self, *args, **kwargs):
return self.log(*args, level=logging.INFO, **kwargs) return self.log(*args, level=logging.INFO, **kwargs)
# def __getstate__(self):
# '''
# Serializing an agent will lose all its running information (you cannot
# serialize an iterator), but it keeps the state and link to the environment,
# so it can be used for inspection and dumping to a file
# '''
# state = {}
# state['id'] = self.id
# state['environment'] = self.model
# state['_state'] = self._state
# return state
# def __setstate__(self, state):
# '''
# Get back a serialized agent and try to re-compose it
# '''
# self.state_id = state['id']
# self._state = state['_state']
# self.model = state['environment']
class NetworkAgent(BaseAgent): class NetworkAgent(BaseAgent):
@ -303,15 +291,15 @@ class MetaFSM(type):
class FSM(NetworkAgent, metaclass=MetaFSM): class FSM(NetworkAgent, metaclass=MetaFSM):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(FSM, self).__init__(*args, **kwargs) super(FSM, self).__init__(*args, **kwargs)
if 'id' not in self.state: if not hasattr(self, 'state_id'):
if not self.default_state: if not self.default_state:
raise ValueError('No default state specified for {}'.format(self.unique_id)) raise ValueError('No default state specified for {}'.format(self.unique_id))
self['id'] = self.default_state.id self.state_id = self.default_state.id
self.set_state(self.state['id']) self.set_state(self.state_id)
def step(self): def step(self):
self.debug(f'Agent {self.unique_id} @ state {self["id"]}') self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
interval = super().step() interval = super().step()
if 'id' not in self.state: if 'id' not in self.state:
# if 'id' in self.state: # if 'id' in self.state:
@ -320,14 +308,14 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
self.set_state(self.default_state.id) self.set_state(self.default_state.id)
else: else:
raise Exception('{} has no valid state id or default state'.format(self)) raise Exception('{} has no valid state id or default state'.format(self))
return self.states[self.state['id']](self) or interval return self.states[self.state_id](self) or interval
def set_state(self, state): def set_state(self, state):
if hasattr(state, 'id'): if hasattr(state, 'id'):
state = state.id state = state.id
if state not in self.states: if state not in self.states:
raise ValueError('{} is not a valid state'.format(state)) raise ValueError('{} is not a valid state'.format(state))
self.state['id'] = state self.state_id = state
return state return state
@ -541,7 +529,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
f = filter(lambda x: x not in ignore, f) f = filter(lambda x: x not in ignore, f)
if state_id is not None: if state_id is not None:
f = filter(lambda agent: agent.state.get('id', None) in state_id, f) f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
if agent_type is not None: if agent_type is not None:
f = filter(lambda agent: isinstance(agent, agent_type), f) f = filter(lambda agent: isinstance(agent, agent_type), f)
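
The new `__setattr__` in `BaseAgent` turns plain attribute assignment into a history write. The sketch below reproduces that pattern in isolation so it can be run without `soil`, `mesa` or `tsih`. `RecordingModel` and `RecordingAgent` are hypothetical names, and the local `Key` namedtuple is only an assumption about the shape of `tsih.Key` (the diff confirms the `dict_id`, `t_step` and `key` fields, not their order).

```python
from collections import namedtuple

# Assumed shape of the history key; a local stand-in for tsih.Key.
Key = namedtuple("Key", ["dict_id", "t_step", "key"])

IGNORED_FIELDS = ("model", "logger")


class RecordingModel:
    """Hypothetical model that stores every record in a plain dict."""

    def __init__(self):
        self.now = 0
        self.records = {}

    def __setitem__(self, key, value):
        self.records[key] = value


class RecordingAgent:
    """Hypothetical agent whose public attributes are logged on assignment."""

    def __init__(self, unique_id, model):
        self._saved = set()         # private name: never recorded
        self.model = model          # ignored field: never recorded
        self.unique_id = unique_id  # not recorded: the key cannot be built yet

    @property
    def now(self):
        return self.model.now

    def __setattr__(self, key, value):
        if not key.startswith("_") and key not in IGNORED_FIELDS:
            try:
                k = Key(dict_id=self.unique_id, t_step=self.now, key=key)
                self._saved.add(key)
                self.model[k] = value
            except AttributeError:
                # Raised while the agent is still being initialised.
                pass
        super().__setattr__(key, value)

    def items(self):
        return ((k, getattr(self, k)) for k in self._saved)


model = RecordingModel()
agent = RecordingAgent(unique_id=1, model=model)
agent.wealth = 3   # recorded at t_step 0
model.now = 1
agent.wealth = 5   # recorded again at t_step 1
print(dict(agent.items()))
print(sorted(model.records.items()))
```

One consequence of this design, visible in the sketch, is that `items()` only reports attributes that went through the recording branch, so names starting with `_` and the entries in `IGNORED_FIELDS` stay out of the agent state.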

@ -4,7 +4,8 @@ import glob
import yaml import yaml
from os.path import join from os.path import join
from . import serialization, history from . import serialization
from tsih import History
def read_data(*args, group=False, **kwargs): def read_data(*args, group=False, **kwargs):
@ -34,7 +35,7 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
def read_sql(db, *args, **kwargs): def read_sql(db, *args, **kwargs):
h = history.History(db_path=db, backup=False, readonly=True) h = History(db_path=db, backup=False, readonly=True)
df = h.read_sql(*args, **kwargs) df = h.read_sql(*args, **kwargs)
return df return df

@ -12,9 +12,11 @@ from networkx.readwrite import json_graph
import networkx as nx import networkx as nx
from tsih import History, Record, Key, NoHistory
from mesa import Model from mesa import Model
from . import serialization, agents, analysis, history, utils, time from . import serialization, agents, analysis, utils, time
# These properties will be copied when pickling/unpickling the environment # These properties will be copied when pickling/unpickling the environment
_CONFIG_PROPS = [ 'name', _CONFIG_PROPS = [ 'name',
@ -46,6 +48,7 @@ class Environment(Model):
schedule=None, schedule=None,
initial_time=0, initial_time=0,
environment_params=None, environment_params=None,
history=True,
dir_path=None, dir_path=None,
**kwargs): **kwargs):
@ -78,8 +81,12 @@ class Environment(Model):
self._env_agents = {} self._env_agents = {}
self.interval = interval self.interval = interval
self._history = history.History(name=self.name, if history:
backup=True) history = History
else:
history = NoHistory
self._history = history(name=self.name,
backup=True)
self['SEED'] = seed self['SEED'] = seed
if network_agents: if network_agents:
@ -162,8 +169,15 @@ class Environment(Model):
if agent_type: if agent_type:
state = defstate state = defstate
a = agent_type(model=self, a = agent_type(model=self,
unique_id=agent_id, unique_id=agent_id)
state=state)
for (k, v) in getattr(a, 'defaults', {}).items():
if not hasattr(a, k) or getattr(a, k) is None:
setattr(a, k, v)
for (k, v) in state.items():
setattr(a, k, v)
node['agent'] = a node['agent'] = a
self.schedule.add(a) self.schedule.add(a)
return a return a
@ -183,6 +197,11 @@ class Environment(Model):
start = start or self.now start = start or self.now
return self.G.add_edge(agent1, agent2, **attrs) return self.G.add_edge(agent1, agent2, **attrs)
def step(self):
super().step()
self.datacollector.collect(self)
self.schedule.step()
def run(self, until, *args, **kwargs): def run(self, until, *args, **kwargs):
self._save_state() self._save_state()
@ -204,12 +223,12 @@ class Environment(Model):
def __setitem__(self, key, value): def __setitem__(self, key, value):
if isinstance(key, tuple): if isinstance(key, tuple):
k = history.Key(*key) k = Key(*key)
self._history.save_record(*k, self._history.save_record(*k,
value=value) value=value)
return return
self.environment_params[key] = value self.environment_params[key] = value
self._history.save_record(agent_id='env', self._history.save_record(dict_id='env',
t_step=self.now, t_step=self.now,
key=key, key=key,
value=value) value=value)
@ -274,16 +293,16 @@ class Environment(Model):
if now is None: if now is None:
now = self.now now = self.now
for k, v in self.environment_params.items(): for k, v in self.environment_params.items():
yield history.Record(agent_id='env', yield Record(dict_id='env',
t_step=now, t_step=now,
key=k, key=k,
value=v) value=v)
for agent in self.agents: for agent in self.agents:
for k, v in agent.state.items(): for k, v in agent.state.items():
yield history.Record(agent_id=agent.id, yield Record(dict_id=agent.id,
t_step=now, t_step=now,
key=k, key=k,
value=v) value=v)
def history_to_tuples(self): def history_to_tuples(self):
return self._history.to_tuples() return self._history.to_tuples()
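
Because `BaseAgent.__init__` no longer accepts `state`, the environment now applies class-level `defaults` first and the configured state second. The following sketch isolates that ordering. `DemoAgent` and `apply_state` are hypothetical names introduced for illustration; only the two loops are taken from the hunk above.

```python
class DemoAgent:
    """Hypothetical agent with class-level defaults, as soil agents may define."""

    defaults = {"wealth": 1, "has_tv": False}

    def __init__(self, unique_id, model=None):
        self.unique_id = unique_id
        self.model = model
        self.wealth = None  # explicitly unset, so the default should apply


def apply_state(agent, state):
    """Mirror of the two loops added to the environment's agent setup."""
    # 1. Fill in defaults only where the attribute is missing or None.
    for k, v in getattr(agent, "defaults", {}).items():
        if not hasattr(agent, k) or getattr(agent, k) is None:
            setattr(agent, k, v)
    # 2. The configured state always wins over defaults.
    for k, v in state.items():
        setattr(agent, k, v)
    return agent


a = apply_state(DemoAgent(0), {"has_tv": True, "state_id": "neutral"})
print(a.wealth, a.has_tv, a.state_id)
```

With this ordering, a value set in the constructor survives unless it is `None`, and any key present in the configuration (such as `state_id` in the YAML examples) overrides both.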

@ -1,388 +0,0 @@
import time
import os
import pandas as pd
import sqlite3
import copy
import logging
import tempfile
logger = logging.getLogger(__name__)
from collections import UserDict, namedtuple
from . import serialization
from .utils import open_or_reuse, unflatten_dict
class History:
"""
Store and retrieve values from a sqlite database.
"""
def __init__(self, name=None, db_path=None, backup=False, readonly=False):
if readonly and (not os.path.exists(db_path)):
raise Exception('The DB file does not exist. Cannot open in read-only mode')
self._db = None
self._temp = db_path is None
self._stats_columns = None
self.readonly = readonly
if self._temp:
if not name:
name = time.time()
# The file will be deleted as soon as it's closed
# Normally, that will be on destruction
db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
if backup and os.path.exists(db_path):
newname = db_path + '.backup{}.sqlite'.format(time.time())
os.rename(db_path, newname)
self.db_path = db_path
self.db = db_path
self._dtypes = {}
self._tups = []
if self.readonly:
return
with self.db:
logger.debug('Creating database {}'.format(self.db_path))
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step real, key text, value text)''')
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
@property
def db(self):
try:
self._db.cursor()
except (sqlite3.ProgrammingError, AttributeError):
self.db = None # Reset the database
return self._db
@db.setter
def db(self, db_path=None):
self._close()
db_path = db_path or self.db_path
if isinstance(db_path, str):
logger.debug('Connecting to database {}'.format(db_path))
self._db = sqlite3.connect(db_path)
self._db.row_factory = sqlite3.Row
else:
self._db = db_path
def _close(self):
if self._db is None:
return
self.flush_cache()
self._db.close()
self._db = None
def save_stats(self, stat):
if self.readonly:
print('DB in readonly mode')
return
if not stat:
return
with self.db:
if not self._stats_columns:
self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
for column, value in stat.items():
if column in self._stats_columns:
continue
dtype = 'text'
if not isinstance(value, str):
try:
float(value)
dtype = 'real'
int(value)
dtype = 'int'
except (ValueError, OverflowError):
pass
self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
self._stats_columns.append(column)
columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
columns=columns,
values=values
)
self.db.execute(query)
def get_stats(self, unflatten=True):
rows = self.db.execute("select * from stats").fetchall()
res = []
for row in rows:
d = {}
for k in row.keys():
if row[k] is None:
continue
d[k] = row[k]
if unflatten:
d = unflatten_dict(d)
res.append(d)
return res
@property
def dtypes(self):
self._read_types()
return {k:v[0] for k, v in self._dtypes.items()}
def save_tuples(self, tuples):
'''
Save a series of tuples, converting them to records if necessary
'''
self.save_records(Record(*tup) for tup in tuples)
def save_records(self, records):
'''
Save a collection of records
'''
for record in records:
if not isinstance(record, Record):
record = Record(*record)
self.save_record(*record)
def save_record(self, agent_id, t_step, key, value):
'''
Save a collection of records to the database.
Database writes are cached.
'''
if self.readonly:
raise Exception('DB in readonly mode')
if key not in self._dtypes:
self._read_types()
if key not in self._dtypes:
name = serialization.name(value)
serializer = serialization.serializer(name)
deserializer = serialization.deserializer(name)
self._dtypes[key] = (name, serializer, deserializer)
with self.db:
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
value = self._dtypes[key][1](value)
self._tups.append(Record(agent_id=agent_id,
t_step=t_step,
key=key,
value=value))
if len(self._tups) > 100:
self.flush_cache()
def flush_cache(self):
'''
Use a cache to save state changes to avoid opening a session for every change.
The cache will be flushed at the end of the simulation, and when history is accessed.
'''
if self.readonly:
raise Exception('DB in readonly mode')
logger.debug('Flushing cache {}'.format(self.db_path))
with self.db:
self.db.executemany("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", self._tups)
# (rec.agent_id, rec.t_step, rec.key, rec.value))
self._tups.clear()
def to_tuples(self):
self.flush_cache()
with self.db:
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
for r in res:
agent_id, t_step, key, value = r
if key not in self._dtypes:
self._read_types()
if key not in self._dtypes:
raise ValueError("Unknown datatype for {} and {}".format(key, value))
value = self._dtypes[key][2](value)
yield agent_id, t_step, key, value
def _read_types(self):
with self.db:
res = self.db.execute("select key, value_type from value_types ").fetchall()
for k, v in res:
serializer = serialization.serializer(v)
deserializer = serialization.deserializer(v)
self._dtypes[k] = (v, serializer, deserializer)
def __getitem__(self, key):
# raise NotImplementedError()
self.flush_cache()
key = Key(*key)
agent_ids = [key.agent_id] if key.agent_id is not None else []
t_steps = [key.t_step] if key.t_step is not None else []
keys = [key.key] if key.key is not None else []
df = self.read_sql(agent_ids=agent_ids,
t_steps=t_steps,
keys=keys)
r = Records(df, filter=key, dtypes=self._dtypes)
if r.resolved:
return r.value()
return r
def read_sql(self, keys=None, agent_ids=None, not_agent_ids=None, t_steps=None, convert_types=False, limit=-1):
self._read_types()
def escape_and_join(v):
if v is None:
return
return ",".join(map(lambda x: "\'{}\'".format(x), v))
filters = [("key in ({})".format(escape_and_join(keys)), keys),
("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids),
("agent_id not in ({})".format(escape_and_join(not_agent_ids)), not_agent_ids)
]
filters = list(k[0] for k in filters if k[1])
last_df = None
if t_steps:
# Convert negative indices into positive
if any(x<0 for x in t_steps):
max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
t_steps = [t if t>0 else max_t+1+t for t in t_steps]
# We will be doing ffill interpolation, so we need to look for
# the last value before the minimum step in the query
min_step = min(t_steps)
last_filters = ['t_step < {}'.format(min_step),]
last_filters = last_filters + filters
condition = ' and '.join(last_filters)
last_query = '''
select h1.*
from history h1
inner join (
select agent_id, key, max(t_step) as t_step
from history
where {condition}
group by agent_id, key
) h2
on h1.agent_id = h2.agent_id and
h1.key = h2.key and
h1.t_step = h2.t_step
'''.format(condition=condition)
last_df = pd.read_sql_query(last_query, self.db)
filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps)))
condition = ''
if filters:
condition = 'where {} '.format(' and '.join(filters))
query = 'select * from history {} limit {}'.format(condition, limit)
df = pd.read_sql_query(query, self.db)
if last_df is not None:
df = pd.concat([df, last_df])
df_p = df.pivot_table(values='value', index=['t_step'],
columns=['key', 'agent_id'],
aggfunc='first')
for k, v in self._dtypes.items():
if k in df_p:
dtype, _, deserial = v
try:
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
except (TypeError, ValueError):
# Avoid forward-filling unknown/incompatible types
continue
if t_steps:
df_p = df_p.reindex(t_steps, method='ffill')
return df_p.ffill()
def __getstate__(self):
state = dict(**self.__dict__)
del state['_db']
del state['_dtypes']
return state
def __setstate__(self, state):
self.__dict__ = state
self._dtypes = {}
self._db = None
def dump(self, f):
self._close()
for line in open_or_reuse(self.db_path, 'rb'):
f.write(line)
class Records():
def __init__(self, df, filter=None, dtypes=None):
if not filter:
filter = Key(agent_id=None,
t_step=None,
key=None)
self._df = df
self._filter = filter
self.dtypes = dtypes or {}
super().__init__()
def mask(self, tup):
res = ()
for i, k in zip(tup[:-1], self._filter):
if k is None:
res = res + (i,)
res = res + (tup[-1],)
return res
def filter(self, newKey):
f = list(self._filter)
for ix, i in enumerate(f):
if i is None:
f[ix] = newKey
self._filter = Key(*f)
@property
def resolved(self):
return sum(1 for i in self._filter if i is not None) == 3
def __iter__(self):
for column, series in self._df.iteritems():
key, agent_id = column
for t_step, value in series.iteritems():
r = Record(t_step=t_step,
agent_id=agent_id,
key=key,
value=value)
yield self.mask(r)
def value(self):
if self.resolved:
f = self._filter
try:
i = self._df[f.key][str(f.agent_id)]
ix = i.index.get_loc(f.t_step, method='ffill')
return i.iloc[ix]
except KeyError as ex:
return self.dtypes[f.key][2]()
return list(self)
def df(self):
return self._df
def __getitem__(self, k):
n = copy.copy(self)
n.filter(k)
if n.resolved:
return n.value()
return n
def __len__(self):
return len(self._df)
def __str__(self):
if self.resolved:
return str(self.value())
return '<Records for [{}]>'.format(self._filter)
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
Record = namedtuple('Record', 'agent_id t_step key value')
Stat = namedtuple('Stat', 'trial_id')

@ -9,6 +9,7 @@ import networkx as nx
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
from multiprocessing import Pool from multiprocessing import Pool
from functools import partial from functools import partial
from tsih import History
import pickle import pickle
@ -17,7 +18,6 @@ from .environment import Environment
from .utils import logger from .utils import logger
from .exporters import default from .exporters import default
from .stats import defaultStats from .stats import defaultStats
from .history import History
#TODO: change documentation for simulation #TODO: change documentation for simulation
@ -159,7 +159,7 @@ class Simulation:
**kwargs) **kwargs)
def run_gen(self, *args, parallel=False, dry_run=False, def run_gen(self, *args, parallel=False, dry_run=False,
exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={}, exporters=[default, ], stats=[], outdir=None, exporter_params={},
stats_params={}, log_level=None, stats_params={}, log_level=None,
**kwargs): **kwargs):
'''Run the simulation and yield the resulting environments.''' '''Run the simulation and yield the resulting environments.'''

@ -97,7 +97,7 @@ class defaultStats(Stats):
return { return {
'network ': { 'network ': {
'n_nodes': env.G.number_of_nodes(), 'n_nodes': env.G.number_of_nodes(),
'n_edges': env.G.number_of_nodes(), 'n_edges': env.G.number_of_edges(),
}, },
'agents': { 'agents': {
'model_count': dict(c), 'model_count': dict(c),

@ -26,6 +26,8 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
to_object.end = end to_object.end = end
def safe_open(path, mode='r', backup=True, **kwargs): def safe_open(path, mode='r', backup=True, **kwargs):
outdir = os.path.dirname(path) outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir): if outdir and not os.path.exists(outdir):

@ -67,13 +67,13 @@ class TestAnalysis(TestCase):
def test_count(self): def test_count(self):
env = self.env env = self.env
df = analysis.read_sql(env._history.db_path) df = analysis.read_sql(env._history.db_path)
res = analysis.get_count(df, 'SEED', 'id') res = analysis.get_count(df, 'SEED', 'state_id')
assert res['SEED'][self.env['SEED']].iloc[0] == 1 assert res['SEED'][self.env['SEED']].iloc[0] == 1
assert res['SEED'][self.env['SEED']].iloc[-1] == 1 assert res['SEED'][self.env['SEED']].iloc[-1] == 1
assert res['id']['odd'].iloc[0] == 2 assert res['state_id']['odd'].iloc[0] == 2
assert res['id']['even'].iloc[0] == 0 assert res['state_id']['even'].iloc[0] == 0
assert res['id']['odd'].iloc[-1] == 1 assert res['state_id']['odd'].iloc[-1] == 1
assert res['id']['even'].iloc[-1] == 1 assert res['state_id']['even'].iloc[-1] == 1
def test_value(self): def test_value(self):
env = self.env env = self.env

@ -9,7 +9,7 @@ from functools import partial
from os.path import join from os.path import join
from soil import (simulation, Environment, agents, serialization, from soil import (simulation, Environment, agents, serialization,
history, utils) utils)
from soil.time import Delta from soil.time import Delta
@ -21,8 +21,8 @@ class CustomAgent(agents.FSM):
@agents.default_state @agents.default_state
@agents.state @agents.state
def normal(self): def normal(self):
self.state['neighbors'] = self.count_agents(state_id='normal', self.neighbors = self.count_agents(state_id='normal',
limit_neighbors=True) limit_neighbors=True)
@agents.state @agents.state
def unreachable(self): def unreachable(self):
return return
@ -116,7 +116,7 @@ class TestMain(TestCase):
'network_agents': [{ 'network_agents': [{
'agent_type': 'AggregatedCounter', 'agent_type': 'AggregatedCounter',
'weight': 1, 'weight': 1,
'state': {'id': 0} 'state': {'state_id': 0}
}], }],
'max_time': 10, 'max_time': 10,
@ -149,10 +149,9 @@ class TestMain(TestCase):
} }
s = simulation.from_config(config) s = simulation.from_config(config)
env = s.run_simulation(dry_run=True)[0] env = s.run_simulation(dry_run=True)[0]
assert env.get_agent(0).state['neighbors'] == 1
assert env.get_agent(0).state['neighbors'] == 1
assert env.get_agent(1).count_agents(state_id='normal') == 2 assert env.get_agent(1).count_agents(state_id='normal') == 2
assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1 assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
assert env.get_agent(0).neighbors == 1
def test_torvalds_example(self): def test_torvalds_example(self):
"""A complete example from a documentation should work.""" """A complete example from a documentation should work."""
@ -317,12 +316,6 @@ class TestMain(TestCase):
assert recovered['key', 0] == 'test' assert recovered['key', 0] == 'test'
assert recovered['key'] == 'test' assert recovered['key'] == 'test'
def test_history(self):
'''Test storing in and retrieving from history (sqlite)'''
h = history.History()
h.save_record(agent_id=0, t_step=0, key="test", value="hello")
assert h[0, 0, "test"] == "hello"
def test_subgraph(self): def test_subgraph(self):
'''An agent should be able to subgraph the global topology''' '''An agent should be able to subgraph the global topology'''
G = nx.Graph() G = nx.Graph()
@ -350,12 +343,13 @@ class TestMain(TestCase):
'network_params': {}, 'network_params': {},
'agent_type': 'CounterModel', 'agent_type': 'CounterModel',
'max_time': 2, 'max_time': 2,
'num_trials': 100, 'num_trials': 50,
'environment_params': {} 'environment_params': {}
} }
s = simulation.from_config(config) s = simulation.from_config(config)
runs = list(s.run_simulation(dry_run=True)) runs = list(s.run_simulation(dry_run=True))
over = list(x.now for x in runs if x.now>2) over = list(x.now for x in runs if x.now>2)
assert len(runs) == config['num_trials']
assert len(over) == 0 assert len(over) == 0
@ -372,11 +366,11 @@ class TestMain(TestCase):
return self.ping return self.ping
a = ToggleAgent(unique_id=1, model=Environment()) a = ToggleAgent(unique_id=1, model=Environment())
assert a.state["id"] == a.ping.id assert a.state_id == a.ping.id
a.step() a.step()
assert a.state["id"] == a.pong.id assert a.state_id == a.pong.id
a.step() a.step()
assert a.state["id"] == a.ping.id assert a.state_id == a.ping.id
def test_fsm_when(self): def test_fsm_when(self):
'''Basic state change''' '''Basic state change'''
