Partial MESA compatibility and several fixes

Documentation for the new APIs is still a work in progress :)
pull/8/head
J. Fernando Sánchez 2 years ago
parent af9a392a93
commit 6c4f44b4cb

1
.gitignore vendored

@ -8,3 +8,4 @@ soil_output
docs/_build*
build/*
dist/*
prof

@ -5,23 +5,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased]
### Added
* [WIP] Integration with MESA
* Integration with MESA
* `not_agent_ids` paramter to get sql in history
### Changed
* `soil.Environment` now also inherits from `mesa.Model`
* `soil.Agent` now also inherits from `mesa.Agent`
* `soil.time` to replace `simpy` events, delays, duration, etc.
* `agent.id` is not `agent.unique_id` to be compatible with `mesa`. A property `BaseAgent.id` has been added for compatibility.
* `agent.environment` is now `agent.model`, for the same reason as above. The parameter name in `BaseAgent.__init__` has also been renamed.
### Removed
* `simpy` dependency and compatibility. Each agent used to be a simpy generator, but that made debugging and error handling more complex. That has been replaced by a scheduler within the `soil.Environment` class, similar to how `mesa` does it.
* `soil.history` is now a separate package named `tsih`. The keys namedtuple uses `dict_id` instead of `agent_id`.
### TODO:
* agent_id -> unique_id?
* mesa has Agent.model and soil has Agent.env
* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that
* soil.History should mimic a mesa.datacollector :/
* soil.Simulation *could* mimic a mesa.batchrunner
* DONE include scheduler in environment
* DONE environment inherits from `mesa.Model`
### Added
* An option to choose whether a database should be used for history
## [0.15.2]

@ -5,6 +5,9 @@ Learn how to run your own simulations with our [documentation](http://soilsim.re
Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models.
## Citation
If you use Soil in your research, don't forget to cite this paper:
```bibtex
@ -28,7 +31,24 @@ If you use Soil in your research, don't forget to cite this paper:
```
@Copyright GSI - Universidad Politécnica de Madrid 2017
## Mesa compatibility
[![SOIL](logo_gsi.png)](https://www.gsi.upm.es)
Soil is in the process of becoming fully compatible with MESA.
As of this writing,
This is a non-exhaustive list of tasks to achieve compatibility:
* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that
- [ ] Integrate `soil.Simulation` with mesa's runners:
- [ ] `soil.Simulation` could mimic/become a `mesa.batchrunner`
- [ ] Integrate `soil.Environment` with `mesa.Model`:
- [x] `Soil.Environment` inherits from `mesa.Model`
- [x] `Soil.Environment` includes a Mesa-like Scheduler (see the `soil.time` module.
- [ ] Integrate `soil.Agent` with `mesa.Agent`:
- [x] Rename agent.id to unique_id?
- [x] mesa agents can be used in soil simulations (see `examples/mesa`)
- [ ] Document the new APIs and usage
@Copyright GSI - Universidad Politécnica de Madrid 2017-2021
[![SOIL](logo_gsi.png)](https://www.gsi.upm.es)

@ -13,7 +13,7 @@ network_agents:
- agent_type: CounterModel
weight: 1
state:
id: 0
state_id: 0
- agent_type: AggregatedCounter
weight: 0.2
environment_agents: []

@ -13,4 +13,4 @@ network_agents:
- agent_type: CounterModel
weight: 1
state:
id: 0
state_id: 0

@ -23,7 +23,6 @@ def network_portrayal(env):
}
for (agent_id) in env.G.nodes
]
# import pdb;pdb.set_trace()
portrayal["edges"] = [
{"id": edge_id, "source": source, "target": target, "color": "#000000"}
@ -65,7 +64,7 @@ model_params = {
"N": UserSettableParameter(
"slider",
"N",
1,
5,
1,
10,
1,

@ -34,9 +34,7 @@ class MoneyAgent(MesaAgent):
self.pos,
moore=True,
include_center=False)
print(self.pos, possible_steps)
new_position = self.random.choice(possible_steps)
print(self.pos, new_position)
self.model.grid.move_agent(self, new_position)
def give_money(self):
@ -74,21 +72,13 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
class MoneyEnv(Environment):
"""A model with some number of agents."""
def __init__(self, N, width, height, *args, network_params, **kwargs):
self.initialized = True
# import pdb;pdb.set_trace()
network_params['n'] = N
super().__init__(*args, network_params=network_params, **kwargs)
self.grid = MultiGrid(width, height, False)
# self.schedule = RandomActivation(self)
self.running = True
# Create agents
for agent in self.agents:
self.schedule.add(agent)
# a = MoneyAgent(i, self)
# self.schedule.add(a)
# Add the agent to a random grid cell
x = self.random.randrange(self.grid.width)
y = self.random.randrange(self.grid.height)
self.grid.place_agent(agent, (x, y))
@ -97,10 +87,6 @@ class MoneyEnv(Environment):
model_reporters={"Gini": compute_gini},
agent_reporters={"Wealth": "wealth"})
def step(self):
super().step()
self.datacollector.collect(self)
self.schedule.step()
def graph_generator(n=5):
G = nx.Graph()

@ -68,12 +68,12 @@ network_agents:
- agent_type: HerdViewer
state:
has_tv: true
id: neutral
state_id: neutral
weight: 1
- agent_type: HerdViewer
state:
has_tv: true
id: neutral
state_id: neutral
weight: 1
network_params:
generator: barabasi_albert_graph
@ -95,7 +95,7 @@ network_agents:
- agent_type: HerdViewer
state:
has_tv: true
id: neutral
state_id: neutral
weight: 1
- agent_type: WiseViewer
state:
@ -121,7 +121,7 @@ network_agents:
- agent_type: WiseViewer
state:
has_tv: true
id: neutral
state_id: neutral
weight: 1
- agent_type: WiseViewer
state:

@ -34,8 +34,6 @@ class HerdViewer(DumbViewer):
A viewer whose probability of infection depends on the state of its neighbors.
'''
level = logging.DEBUG
def infect(self):
infected = self.count_neighboring_agents(state_id=self.infected.id)
total = self.count_neighboring_agents()

@ -1,7 +1,7 @@
---
load_module: rabbit_agents
name: rabbits_example
max_time: 200
max_time: 150
interval: 1
seed: MySeed
agent_type: RabbitModel

@ -16,7 +16,7 @@ template:
- agent_type: CounterModel
weight: "{{ x1 }}"
state:
id: 0
state_id: 0
- agent_type: AggregatedCounter
weight: "{{ 1 - x1 }}"
environment_params:

@ -6,3 +6,4 @@ pandas>=0.23
SALib>=1.3
Jinja2
Mesa>=0.8
tsih>=0.1.5

@ -1 +1 @@
0.15.2
0.20.0

@ -15,7 +15,6 @@ from .agents import *
from . import agents
from .simulation import *
from .environment import Environment
from .history import History
from . import serialization
from . import analysis
from .utils import logger

@ -6,7 +6,9 @@ from itertools import islice
import json
import networkx as nx
from .. import serialization, history, utils, time
from .. import serialization, utils, time
from tsih import Key
from mesa import Agent
@ -16,6 +18,7 @@ def as_node(agent):
return agent.id
return agent
IGNORED_FIELDS = ('model', 'logger')
class BaseAgent(Agent):
"""
@ -27,23 +30,20 @@ class BaseAgent(Agent):
def __init__(self,
unique_id,
model,
state=None,
name=None,
interval=None):
# Check for REQUIRED arguments
# Initialize agent parameters
if isinstance(unique_id, Agent):
raise Exception()
self._saved = set()
super().__init__(unique_id=unique_id, model=model)
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
self._neighbors = None
self.alive = True
real_state = deepcopy(self.defaults)
real_state.update(state or {})
self.state = real_state
self.interval = interval or self.get('interval', getattr(self.model, 'interval', 1))
self.interval = interval or self.get('interval', 1)
self.logger = logging.getLogger(self.model.name).getChild(self.name)
if hasattr(self, 'level'):
@ -75,7 +75,6 @@ class BaseAgent(Agent):
@state.setter
def state(self, value):
self._state = {}
for k, v in value.items():
self[k] = v
@ -87,28 +86,36 @@ class BaseAgent(Agent):
def environment_params(self, value):
self.model.environment_params = value
def __setattr__(self, key, value):
if not key.startswith('_') and key not in IGNORED_FIELDS:
try:
k = Key(t_step=self.now,
dict_id=self.unique_id,
key=key)
self._saved.add(key)
self.model[k] = value
except AttributeError:
pass
super().__setattr__(key, value)
def __getitem__(self, key):
if isinstance(key, tuple):
key, t_step = key
k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id)
k = Key(key=key, t_step=t_step, dict_id=self.unique_id)
return self.model[k]
return self._state.get(key, None)
return getattr(self, key)
def __delitem__(self, key):
self._state[key] = None
return delattr(self, key)
def __contains__(self, key):
return key in self._state
return hasattr(self, key)
def __setitem__(self, key, value):
self._state[key] = value
k = history.Key(t_step=self.now,
agent_id=self.id,
key=key)
self.model[k] = value
setattr(self, key, value)
def items(self):
return self._state.items()
return ((k, getattr(self, k)) for k in self._saved)
def get(self, key, default=None):
return self[key] if key in self else default
@ -150,25 +157,6 @@ class BaseAgent(Agent):
def info(self, *args, **kwargs):
return self.log(*args, level=logging.INFO, **kwargs)
# def __getstate__(self):
# '''
# Serializing an agent will lose all its running information (you cannot
# serialize an iterator), but it keeps the state and link to the environment,
# so it can be used for inspection and dumping to a file
# '''
# state = {}
# state['id'] = self.id
# state['environment'] = self.model
# state['_state'] = self._state
# return state
# def __setstate__(self, state):
# '''
# Get back a serialized agent and try to re-compose it
# '''
# self.state_id = state['id']
# self._state = state['_state']
# self.model = state['environment']
class NetworkAgent(BaseAgent):
@ -303,15 +291,15 @@ class MetaFSM(type):
class FSM(NetworkAgent, metaclass=MetaFSM):
def __init__(self, *args, **kwargs):
super(FSM, self).__init__(*args, **kwargs)
if 'id' not in self.state:
if not hasattr(self, 'state_id'):
if not self.default_state:
raise ValueError('No default state specified for {}'.format(self.unique_id))
self['id'] = self.default_state.id
self.state_id = self.default_state.id
self.set_state(self.state['id'])
self.set_state(self.state_id)
def step(self):
self.debug(f'Agent {self.unique_id} @ state {self["id"]}')
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
interval = super().step()
if 'id' not in self.state:
# if 'id' in self.state:
@ -320,14 +308,14 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
self.set_state(self.default_state.id)
else:
raise Exception('{} has no valid state id or default state'.format(self))
return self.states[self.state['id']](self) or interval
return self.states[self.state_id](self) or interval
def set_state(self, state):
if hasattr(state, 'id'):
state = state.id
if state not in self.states:
raise ValueError('{} is not a valid state'.format(state))
self.state['id'] = state
self.state_id = state
return state
@ -541,7 +529,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
f = filter(lambda x: x not in ignore, f)
if state_id is not None:
f = filter(lambda agent: agent.state.get('id', None) in state_id, f)
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
if agent_type is not None:
f = filter(lambda agent: isinstance(agent, agent_type), f)

@ -4,7 +4,8 @@ import glob
import yaml
from os.path import join
from . import serialization, history
from . import serialization
from tsih import History
def read_data(*args, group=False, **kwargs):
@ -34,7 +35,7 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
def read_sql(db, *args, **kwargs):
h = history.History(db_path=db, backup=False, readonly=True)
h = History(db_path=db, backup=False, readonly=True)
df = h.read_sql(*args, **kwargs)
return df

@ -12,9 +12,11 @@ from networkx.readwrite import json_graph
import networkx as nx
from tsih import History, Record, Key, NoHistory
from mesa import Model
from . import serialization, agents, analysis, history, utils, time
from . import serialization, agents, analysis, utils, time
# These properties will be copied when pickling/unpickling the environment
_CONFIG_PROPS = [ 'name',
@ -46,6 +48,7 @@ class Environment(Model):
schedule=None,
initial_time=0,
environment_params=None,
history=True,
dir_path=None,
**kwargs):
@ -78,8 +81,12 @@ class Environment(Model):
self._env_agents = {}
self.interval = interval
self._history = history.History(name=self.name,
backup=True)
if history:
history = History
else:
history = NoHistory
self._history = history(name=self.name,
backup=True)
self['SEED'] = seed
if network_agents:
@ -162,8 +169,15 @@ class Environment(Model):
if agent_type:
state = defstate
a = agent_type(model=self,
unique_id=agent_id,
state=state)
unique_id=agent_id)
for (k, v) in getattr(a, 'defaults', {}).items():
if not hasattr(a, k) or getattr(a, k) is None:
setattr(a, k, v)
for (k, v) in state.items():
setattr(a, k, v)
node['agent'] = a
self.schedule.add(a)
return a
@ -183,6 +197,11 @@ class Environment(Model):
start = start or self.now
return self.G.add_edge(agent1, agent2, **attrs)
def step(self):
super().step()
self.datacollector.collect(self)
self.schedule.step()
def run(self, until, *args, **kwargs):
self._save_state()
@ -204,12 +223,12 @@ class Environment(Model):
def __setitem__(self, key, value):
if isinstance(key, tuple):
k = history.Key(*key)
k = Key(*key)
self._history.save_record(*k,
value=value)
return
self.environment_params[key] = value
self._history.save_record(agent_id='env',
self._history.save_record(dict_id='env',
t_step=self.now,
key=key,
value=value)
@ -274,16 +293,16 @@ class Environment(Model):
if now is None:
now = self.now
for k, v in self.environment_params.items():
yield history.Record(agent_id='env',
t_step=now,
key=k,
value=v)
yield Record(dict_id='env',
t_step=now,
key=k,
value=v)
for agent in self.agents:
for k, v in agent.state.items():
yield history.Record(agent_id=agent.id,
t_step=now,
key=k,
value=v)
yield Record(dict_id=agent.id,
t_step=now,
key=k,
value=v)
def history_to_tuples(self):
return self._history.to_tuples()

@ -1,388 +0,0 @@
import time
import os
import pandas as pd
import sqlite3
import copy
import logging
import tempfile
logger = logging.getLogger(__name__)
from collections import UserDict, namedtuple
from . import serialization
from .utils import open_or_reuse, unflatten_dict
class History:
"""
Store and retrieve values from a sqlite database.
"""
def __init__(self, name=None, db_path=None, backup=False, readonly=False):
if readonly and (not os.path.exists(db_path)):
raise Exception('The DB file does not exist. Cannot open in read-only mode')
self._db = None
self._temp = db_path is None
self._stats_columns = None
self.readonly = readonly
if self._temp:
if not name:
name = time.time()
# The file will be deleted as soon as it's closed
# Normally, that will be on destruction
db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
if backup and os.path.exists(db_path):
newname = db_path + '.backup{}.sqlite'.format(time.time())
os.rename(db_path, newname)
self.db_path = db_path
self.db = db_path
self._dtypes = {}
self._tups = []
if self.readonly:
return
with self.db:
logger.debug('Creating database {}'.format(self.db_path))
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step real, key text, value text)''')
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
@property
def db(self):
try:
self._db.cursor()
except (sqlite3.ProgrammingError, AttributeError):
self.db = None # Reset the database
return self._db
@db.setter
def db(self, db_path=None):
self._close()
db_path = db_path or self.db_path
if isinstance(db_path, str):
logger.debug('Connecting to database {}'.format(db_path))
self._db = sqlite3.connect(db_path)
self._db.row_factory = sqlite3.Row
else:
self._db = db_path
def _close(self):
if self._db is None:
return
self.flush_cache()
self._db.close()
self._db = None
def save_stats(self, stat):
if self.readonly:
print('DB in readonly mode')
return
if not stat:
return
with self.db:
if not self._stats_columns:
self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
for column, value in stat.items():
if column in self._stats_columns:
continue
dtype = 'text'
if not isinstance(value, str):
try:
float(value)
dtype = 'real'
int(value)
dtype = 'int'
except (ValueError, OverflowError):
pass
self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
self._stats_columns.append(column)
columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
columns=columns,
values=values
)
self.db.execute(query)
def get_stats(self, unflatten=True):
rows = self.db.execute("select * from stats").fetchall()
res = []
for row in rows:
d = {}
for k in row.keys():
if row[k] is None:
continue
d[k] = row[k]
if unflatten:
d = unflatten_dict(d)
res.append(d)
return res
@property
def dtypes(self):
self._read_types()
return {k:v[0] for k, v in self._dtypes.items()}
def save_tuples(self, tuples):
'''
Save a series of tuples, converting them to records if necessary
'''
self.save_records(Record(*tup) for tup in tuples)
def save_records(self, records):
'''
Save a collection of records
'''
for record in records:
if not isinstance(record, Record):
record = Record(*record)
self.save_record(*record)
def save_record(self, agent_id, t_step, key, value):
'''
Save a collection of records to the database.
Database writes are cached.
'''
if self.readonly:
raise Exception('DB in readonly mode')
if key not in self._dtypes:
self._read_types()
if key not in self._dtypes:
name = serialization.name(value)
serializer = serialization.serializer(name)
deserializer = serialization.deserializer(name)
self._dtypes[key] = (name, serializer, deserializer)
with self.db:
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
value = self._dtypes[key][1](value)
self._tups.append(Record(agent_id=agent_id,
t_step=t_step,
key=key,
value=value))
if len(self._tups) > 100:
self.flush_cache()
def flush_cache(self):
'''
Use a cache to save state changes to avoid opening a session for every change.
The cache will be flushed at the end of the simulation, and when history is accessed.
'''
if self.readonly:
raise Exception('DB in readonly mode')
logger.debug('Flushing cache {}'.format(self.db_path))
with self.db:
self.db.executemany("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", self._tups)
# (rec.agent_id, rec.t_step, rec.key, rec.value))
self._tups.clear()
def to_tuples(self):
self.flush_cache()
with self.db:
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
for r in res:
agent_id, t_step, key, value = r
if key not in self._dtypes:
self._read_types()
if key not in self._dtypes:
raise ValueError("Unknown datatype for {} and {}".format(key, value))
value = self._dtypes[key][2](value)
yield agent_id, t_step, key, value
def _read_types(self):
with self.db:
res = self.db.execute("select key, value_type from value_types ").fetchall()
for k, v in res:
serializer = serialization.serializer(v)
deserializer = serialization.deserializer(v)
self._dtypes[k] = (v, serializer, deserializer)
def __getitem__(self, key):
# raise NotImplementedError()
self.flush_cache()
key = Key(*key)
agent_ids = [key.agent_id] if key.agent_id is not None else []
t_steps = [key.t_step] if key.t_step is not None else []
keys = [key.key] if key.key is not None else []
df = self.read_sql(agent_ids=agent_ids,
t_steps=t_steps,
keys=keys)
r = Records(df, filter=key, dtypes=self._dtypes)
if r.resolved:
return r.value()
return r
def read_sql(self, keys=None, agent_ids=None, not_agent_ids=None, t_steps=None, convert_types=False, limit=-1):
self._read_types()
def escape_and_join(v):
if v is None:
return
return ",".join(map(lambda x: "\'{}\'".format(x), v))
filters = [("key in ({})".format(escape_and_join(keys)), keys),
("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids),
("agent_id not in ({})".format(escape_and_join(not_agent_ids)), not_agent_ids)
]
filters = list(k[0] for k in filters if k[1])
last_df = None
if t_steps:
# Convert negative indices into positive
if any(x<0 for x in t_steps):
max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
t_steps = [t if t>0 else max_t+1+t for t in t_steps]
# We will be doing ffill interpolation, so we need to look for
# the last value before the minimum step in the query
min_step = min(t_steps)
last_filters = ['t_step < {}'.format(min_step),]
last_filters = last_filters + filters
condition = ' and '.join(last_filters)
last_query = '''
select h1.*
from history h1
inner join (
select agent_id, key, max(t_step) as t_step
from history
where {condition}
group by agent_id, key
) h2
on h1.agent_id = h2.agent_id and
h1.key = h2.key and
h1.t_step = h2.t_step
'''.format(condition=condition)
last_df = pd.read_sql_query(last_query, self.db)
filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps)))
condition = ''
if filters:
condition = 'where {} '.format(' and '.join(filters))
query = 'select * from history {} limit {}'.format(condition, limit)
df = pd.read_sql_query(query, self.db)
if last_df is not None:
df = pd.concat([df, last_df])
df_p = df.pivot_table(values='value', index=['t_step'],
columns=['key', 'agent_id'],
aggfunc='first')
for k, v in self._dtypes.items():
if k in df_p:
dtype, _, deserial = v
try:
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
except (TypeError, ValueError):
# Avoid forward-filling unknown/incompatible types
continue
if t_steps:
df_p = df_p.reindex(t_steps, method='ffill')
return df_p.ffill()
def __getstate__(self):
state = dict(**self.__dict__)
del state['_db']
del state['_dtypes']
return state
def __setstate__(self, state):
self.__dict__ = state
self._dtypes = {}
self._db = None
def dump(self, f):
self._close()
for line in open_or_reuse(self.db_path, 'rb'):
f.write(line)
class Records():
def __init__(self, df, filter=None, dtypes=None):
if not filter:
filter = Key(agent_id=None,
t_step=None,
key=None)
self._df = df
self._filter = filter
self.dtypes = dtypes or {}
super().__init__()
def mask(self, tup):
res = ()
for i, k in zip(tup[:-1], self._filter):
if k is None:
res = res + (i,)
res = res + (tup[-1],)
return res
def filter(self, newKey):
f = list(self._filter)
for ix, i in enumerate(f):
if i is None:
f[ix] = newKey
self._filter = Key(*f)
@property
def resolved(self):
return sum(1 for i in self._filter if i is not None) == 3
def __iter__(self):
for column, series in self._df.iteritems():
key, agent_id = column
for t_step, value in series.iteritems():
r = Record(t_step=t_step,
agent_id=agent_id,
key=key,
value=value)
yield self.mask(r)
def value(self):
if self.resolved:
f = self._filter
try:
i = self._df[f.key][str(f.agent_id)]
ix = i.index.get_loc(f.t_step, method='ffill')
return i.iloc[ix]
except KeyError as ex:
return self.dtypes[f.key][2]()
return list(self)
def df(self):
return self._df
def __getitem__(self, k):
n = copy.copy(self)
n.filter(k)
if n.resolved:
return n.value()
return n
def __len__(self):
return len(self._df)
def __str__(self):
if self.resolved:
return str(self.value())
return '<Records for [{}]>'.format(self._filter)
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
Record = namedtuple('Record', 'agent_id t_step key value')
Stat = namedtuple('Stat', 'trial_id')

@ -9,6 +9,7 @@ import networkx as nx
from networkx.readwrite import json_graph
from multiprocessing import Pool
from functools import partial
from tsih import History
import pickle
@ -17,7 +18,6 @@ from .environment import Environment
from .utils import logger
from .exporters import default
from .stats import defaultStats
from .history import History
#TODO: change documentation for simulation
@ -159,7 +159,7 @@ class Simulation:
**kwargs)
def run_gen(self, *args, parallel=False, dry_run=False,
exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
exporters=[default, ], stats=[], outdir=None, exporter_params={},
stats_params={}, log_level=None,
**kwargs):
'''Run the simulation and yield the resulting environments.'''

@ -97,7 +97,7 @@ class defaultStats(Stats):
return {
'network ': {
'n_nodes': env.G.number_of_nodes(),
'n_edges': env.G.number_of_nodes(),
'n_edges': env.G.number_of_edges(),
},
'agents': {
'model_count': dict(c),

@ -26,6 +26,8 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
to_object.end = end
def safe_open(path, mode='r', backup=True, **kwargs):
outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir):

@ -67,13 +67,13 @@ class TestAnalysis(TestCase):
def test_count(self):
env = self.env
df = analysis.read_sql(env._history.db_path)
res = analysis.get_count(df, 'SEED', 'id')
res = analysis.get_count(df, 'SEED', 'state_id')
assert res['SEED'][self.env['SEED']].iloc[0] == 1
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
assert res['id']['odd'].iloc[0] == 2
assert res['id']['even'].iloc[0] == 0
assert res['id']['odd'].iloc[-1] == 1
assert res['id']['even'].iloc[-1] == 1
assert res['state_id']['odd'].iloc[0] == 2
assert res['state_id']['even'].iloc[0] == 0
assert res['state_id']['odd'].iloc[-1] == 1
assert res['state_id']['even'].iloc[-1] == 1
def test_value(self):
env = self.env

@ -9,7 +9,7 @@ from functools import partial
from os.path import join
from soil import (simulation, Environment, agents, serialization,
history, utils)
utils)
from soil.time import Delta
@ -21,8 +21,8 @@ class CustomAgent(agents.FSM):
@agents.default_state
@agents.state
def normal(self):
self.state['neighbors'] = self.count_agents(state_id='normal',
limit_neighbors=True)
self.neighbors = self.count_agents(state_id='normal',
limit_neighbors=True)
@agents.state
def unreachable(self):
return
@ -116,7 +116,7 @@ class TestMain(TestCase):
'network_agents': [{
'agent_type': 'AggregatedCounter',
'weight': 1,
'state': {'id': 0}
'state': {'state_id': 0}
}],
'max_time': 10,
@ -149,10 +149,9 @@ class TestMain(TestCase):
}
s = simulation.from_config(config)
env = s.run_simulation(dry_run=True)[0]
assert env.get_agent(0).state['neighbors'] == 1
assert env.get_agent(0).state['neighbors'] == 1
assert env.get_agent(1).count_agents(state_id='normal') == 2
assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
assert env.get_agent(0).neighbors == 1
def test_torvalds_example(self):
"""A complete example from a documentation should work."""
@ -317,12 +316,6 @@ class TestMain(TestCase):
assert recovered['key', 0] == 'test'
assert recovered['key'] == 'test'
def test_history(self):
'''Test storing in and retrieving from history (sqlite)'''
h = history.History()
h.save_record(agent_id=0, t_step=0, key="test", value="hello")
assert h[0, 0, "test"] == "hello"
def test_subgraph(self):
'''An agent should be able to subgraph the global topology'''
G = nx.Graph()
@ -350,12 +343,13 @@ class TestMain(TestCase):
'network_params': {},
'agent_type': 'CounterModel',
'max_time': 2,
'num_trials': 100,
'num_trials': 50,
'environment_params': {}
}
s = simulation.from_config(config)
runs = list(s.run_simulation(dry_run=True))
over = list(x.now for x in runs if x.now>2)
assert len(runs) == config['num_trials']
assert len(over) == 0
@ -372,11 +366,11 @@ class TestMain(TestCase):
return self.ping
a = ToggleAgent(unique_id=1, model=Environment())
assert a.state["id"] == a.ping.id
assert a.state_id == a.ping.id
a.step()
assert a.state["id"] == a.pong.id
assert a.state_id == a.pong.id
a.step()
assert a.state["id"] == a.ping.id
assert a.state_id == a.ping.id
def test_fsm_when(self):
'''Basic state change'''

Loading…
Cancel
Save