
Refactoring v0.15.1

See CHANGELOG.md for a full list of changes

* Removed nxsim
* Refactored `agents.NetworkAgent` and `agents.BaseAgent`
* Refactored exporters
* Added stats to history
Commit 05f7f49233 (parent 3b2c6a3db5)
Author: J. Fernando Sánchez, 2020-10-19 13:14:48 +02:00
29 changed files with 4847 additions and 572 deletions

View File

@ -3,6 +3,29 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.15.1]
### Added
* read-only `History`
### Fixed
* Serialization problem with the `Environment` on parallel mode.
* Analysis functions now work as they should in the tutorial
## [0.15.0]
### Added
* Control logging level in CLI and simulation
* `Stats` to calculate trial and simulation-wide statistics
* Simulation statistics are stored in a separate table in history (see `History.get_stats` and `History.save_stats`, as well as `soil.stats`)
* Aliased `NetworkAgent.G` to `NetworkAgent.topology`.
### Changed
* Templates in config files can be given as dictionaries in addition to strings
* Samplers are used more explicitly
* Removed nxsim dependency. We had already made a lot of changes, and nxsim has not been updated in 5 years.
* Exporter methods renamed to `trial` and `end`. Added `start`.
* `Distribution` exporter now a stats class
* `global_topology` renamed to `topology`
* Moved topology-related methods to `NetworkAgent`
### Fixed
* Temporary files used for history in dry_run mode are no longer left open
## [0.14.9]
### Changed
* Seed random before environment initialization
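
For a quick sense of the new stats workflow, here is a minimal sketch based on the tests added in this commit (the configuration values are illustrative):

    from soil import simulation, stats

    config = {
        'name': 'stats_demo',
        'network_params': {'generator': 'complete_graph', 'n': 4},
        'agent_type': 'CounterModel',
        'max_time': 2,
        'num_trials': 1,
        'environment_params': {},
    }

    s = simulation.from_config(config)
    # Stats classes are evaluated after every trial and at the end of the run
    for env in s.run_simulation(stats=[stats.distribution], dry_run=True):
        pass

    # Trial and simulation-wide statistics are stored in the history
    for stat in s.get_stats():
        print(stat.get('trial_id'), stat.get('mean'))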

View File

@ -31,7 +31,7 @@
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = []
extensions = ['IPython.sphinxext.ipython_console_highlighting']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
@ -69,7 +69,7 @@ language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

View File

@ -218,3 +218,24 @@ These agents are programmed in much the same way as network agents, the only dif
You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance.
They are also useful to add behavior that has little to do with the network and the interactions within that network.
Templating
==========
Sometimes, it is useful to parameterize a simulation and run it over a range of values in order to compare the runs and measure the effect of those parameters on the simulation.
For instance, you may want to run a simulation with different agent distributions.
This can be done in Soil using **templates**.
A template is a configuration where some of the values are specified with a variable, e.g., ``weight: "{{ var1 }}"`` instead of ``weight: 1``.
There are two types of variables, depending on how their values are decided:
* Fixed. A list of values is provided, and a new simulation is run for each possible value. If more than one variable is given, a new simulation will be run per combination of values.
* Bounded/Sampled. The bounds of the variable are provided, along with a sampler method, which will be used to compute all the configuration combinations.
When fixed and bounded variables are mixed, Soil generates a new configuration per combination of fixed values and bounded values.
Here is an example with a single fixed variable and two bounded variables:
.. literalinclude:: ../examples/template.yml
:language: yaml
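
Template expansion can also be driven programmatically. The following is a rough sketch using the ``expand_template`` helper from ``soil.serialization`` (shown further down in this commit); the variable names and values are illustrative:

    from soil import serialization

    config = {
        'template': {
            'group': 'simple',
            'num_trials': 1,
            'max_time': 2,
            'interval': 1,
            'network_params': {'generator': 'complete_graph', 'n': 10},
            'network_agents': [
                {'agent_type': 'CounterModel',
                 'weight': '{{ x1 }}',
                 'state': {'id': 0}},
                {'agent_type': 'AggregatedCounter',
                 'weight': '{{ 1 - x1 }}'},
            ],
            'environment_params': {},
        },
        'vars': {'bounds': {'x1': [0, 1], 'x2': [1, 2]}},
        'sampler': {'method': 'SALib.sample.morris.sample', 'N': 10},
    }

    # Each yielded configuration has concrete values substituted in
    for conf in serialization.expand_template(config):
        print(conf['network_agents'][0]['weight'])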

View File

@ -500,7 +500,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.8.5"
},
"toc": {
"colors": {

View File

@ -80800,7 +80800,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.8.6"
}
},
"nbformat": 4,

View File

@ -1,4 +1,4 @@
from soil.agents import FSM, state, default_state, BaseAgent
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
from enum import Enum
from random import random, choice
from itertools import islice
@ -80,7 +80,7 @@ class RabbitModel(FSM):
self.env.add_edge(self['mate'], child.id)
# self.add_edge()
self.debug('A BABY IS COMING TO LIFE')
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.global_topology.number_of_nodes())+1
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1
self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive']))
self['offspring'] += 1
self.env.get_agent(self['mate'])['offspring'] += 1
@ -97,12 +97,14 @@ class RabbitModel(FSM):
return
class RandomAccident(BaseAgent):
class RandomAccident(NetworkAgent):
level = logging.DEBUG
def step(self):
rabbits_total = self.global_topology.number_of_nodes()
rabbits_total = self.topology.number_of_nodes()
if 'rabbits_alive' not in self.env:
self.env['rabbits_alive'] = 0
rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
@ -116,5 +118,5 @@ class RandomAccident(BaseAgent):
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
i.set_state(i.dead)
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
if self.count_agents(state_id=RabbitModel.dead.id) == self.global_topology.number_of_nodes():
if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes():
self.die()

View File

@ -1,13 +1,8 @@
---
vars:
bounds:
x1: [0, 1]
x2: [1, 2]
fixed:
x3: ["a", "b", "c"]
sampler: "SALib.sample.morris.sample"
samples: 10
template: |
sampler:
method: "SALib.sample.morris.sample"
N: 10
template:
group: simple
num_trials: 1
interval: 1
@ -19,11 +14,17 @@ template: |
n: 10
network_agents:
- agent_type: CounterModel
weight: {{ x1 }}
weight: "{{ x1 }}"
state:
id: 0
- agent_type: AggregatedCounter
weight: {{ 1 - x1 }}
weight: "{{ 1 - x1 }}"
environment_params:
name: {{ x3 }}
name: "{{ x3 }}"
skip_test: true
vars:
bounds:
x1: [0, 1]
x2: [1, 2]
fixed:
x3: ["a", "b", "c"]

View File

@ -195,14 +195,14 @@ class TerroristNetworkModel(TerroristSpreadModel):
break
def get_distance(self, target):
source_x, source_y = nx.get_node_attributes(self.global_topology, 'pos')[self.id]
target_x, target_y = nx.get_node_attributes(self.global_topology, 'pos')[target]
source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id]
target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target]
dx = abs( source_x - target_x )
dy = abs( source_y - target_y )
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
def shortest_path_length(self, target):
try:
return nx.shortest_path_length(self.global_topology, self.id, target)
return nx.shortest_path_length(self.topology, self.id, target)
except nx.NetworkXNoPath:
return float('inf')

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,5 @@
nxsim>=0.1.2
simpy
networkx>=2.0,<2.4
simpy>=4.0
networkx>=2.5
numpy
matplotlib
pyyaml>=5.1

View File

@ -1 +1 @@
0.14.9
0.15.1

View File

@ -17,12 +17,12 @@ from .environment import Environment
from .history import History
from . import serialization
from . import analysis
from .utils import logger
def main():
import argparse
from . import simulation
logging.basicConfig(level=logging.INFO)
logging.info('Running SOIL version: {}'.format(__version__))
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
@ -40,6 +40,8 @@ def main():
help='Dump GEXF graph. Defaults to false.')
parser.add_argument('--csv', action='store_true',
help='Dump history in CSV format. Defaults to false.')
parser.add_argument('--level', type=str,
help='Logging level')
parser.add_argument('--output', '-o', type=str, default="soil_output",
help='folder to write results to. It defaults to the current directory.')
parser.add_argument('--synchronous', action='store_true',
@ -48,6 +50,7 @@ def main():
help='Export environment and/or simulations using this exporter')
args = parser.parse_args()
logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper()))
if os.getcwd() not in sys.path:
sys.path.append(os.getcwd())

View File

@ -9,7 +9,7 @@ class BassModel(BaseAgent):
imitation_prob
"""
def __init__(self, environment, agent_id, state):
def __init__(self, environment, agent_id, state, **kwargs):
super().__init__(environment=environment, agent_id=agent_id, state=state)
env_params = environment.environment_params
self.state['sentimentCorrelation'] = 0
@ -19,7 +19,7 @@ class BassModel(BaseAgent):
def behaviour(self):
# Outside effects
if random.random() < self.state_params['innovation_prob']:
if random.random() < self['innovation_prob']:
if self.state['id'] == 0:
self.state['id'] = 1
self.state['sentimentCorrelation'] = 1
@ -32,7 +32,7 @@ class BassModel(BaseAgent):
if self.state['id'] == 0:
aware_neighbors = self.get_neighboring_agents(state_id=1)
num_neighbors_aware = len(aware_neighbors)
if random.random() < (self.state_params['imitation_prob']*num_neighbors_aware):
if random.random() < (self['imitation_prob']*num_neighbors_aware):
self.state['id'] = 1
self.state['sentimentCorrelation'] = 1

View File

@ -1,7 +1,7 @@
from . import BaseAgent
from . import NetworkAgent
class CounterModel(BaseAgent):
class CounterModel(NetworkAgent):
"""
Dummy behaviour. It counts the number of nodes in the network and its
neighbors at each step, and stores both values in its state.
@ -9,14 +9,14 @@ class CounterModel(BaseAgent):
def step(self):
# Outside effects
total = len(list(self.get_all_agents()))
total = len(list(self.get_agents()))
neighbors = len(list(self.get_neighboring_agents()))
self['times'] = self.get('times', 0) + 1
self['neighbors'] = neighbors
self['total'] = total
class AggregatedCounter(BaseAgent):
class AggregatedCounter(NetworkAgent):
"""
Dummy behaviour. It counts the number of nodes in the network and its
neighbors at each step, and stores both values in its state.
@ -33,6 +33,6 @@ class AggregatedCounter(BaseAgent):
self['times'] += 1
neighbors = len(list(self.get_neighboring_agents()))
self['neighbors'] += neighbors
total = len(list(self.get_all_agents()))
total = len(list(self.get_agents()))
self['total'] += total
self.debug('Running for step: {}. Total: {}'.format(self.now, total))

View File

@ -3,19 +3,19 @@
# for x in range(0, settings.network_params["number_of_nodes"]):
# sentimentCorrelationNodeArray.append({'id': x})
# Initialize agent states. Let's assume everyone is normal.
import nxsim
import logging
from collections import OrderedDict
from copy import deepcopy
from functools import partial
from scipy.spatial import cKDTree as KDTree
import json
import simpy
from functools import wraps
from .. import serialization, history
from .. import serialization, history, utils
def as_node(agent):
@ -24,7 +24,7 @@ def as_node(agent):
return agent
class BaseAgent(nxsim.BaseAgent):
class BaseAgent:
"""
An agent that keeps track of its state history.
"""
@ -32,14 +32,13 @@ class BaseAgent(nxsim.BaseAgent):
defaults = {}
def __init__(self, environment, agent_id, state=None,
name=None, interval=None, **state_params):
name=None, interval=None):
# Check for REQUIRED arguments
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
'Cannot be NoneType.')
# Initialize agent parameters
self.id = agent_id
self.name = name or '{}[{}]'.format(type(self).__name__, self.id)
self.state_params = state_params
# Register agent to environment
self.env = environment
@ -51,10 +50,10 @@ class BaseAgent(nxsim.BaseAgent):
self.state = real_state
self.interval = interval
if not hasattr(self, 'level'):
self.level = logging.DEBUG
self.logger = logging.getLogger(self.env.name)
self.logger.setLevel(self.level)
self.logger = logging.getLogger(self.env.name).getChild(self.name)
if hasattr(self, 'level'):
self.logger.setLevel(self.level)
# initialize every time an instance of the agent is created
self.action = self.env.process(self.run())
@ -75,14 +74,10 @@ class BaseAgent(nxsim.BaseAgent):
for k, v in value.items():
self[k] = v
@property
def global_topology(self):
return self.env.G
@property
def environment_params(self):
return self.env.environment_params
@environment_params.setter
def environment_params(self, value):
self.env.environment_params = value
@ -135,36 +130,10 @@ class BaseAgent(nxsim.BaseAgent):
def die(self, remove=False):
self.alive = False
if remove:
super().die()
self.remove_node(self.id)
def step(self):
pass
def count_agents(self, **kwargs):
return len(list(self.get_agents(**kwargs)))
def count_neighboring_agents(self, state_id=None, **kwargs):
return len(super().get_neighboring_agents(state_id=state_id, **kwargs))
def get_neighboring_agents(self, state_id=None, **kwargs):
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
if limit_neighbors:
agents = super().get_agents(limit_neighbors=limit_neighbors)
else:
agents = self.env.get_agents(agents)
return select(agents, **kwargs)
def log(self, message, *args, level=logging.INFO, **kwargs):
message = message + " ".join(str(i) for i in args)
message = "\t{:10}@{:>5}:\t{}".format(self.name, self.now, message)
for k, v in kwargs:
message += " {k}={v} ".format(k, v)
extra = {}
extra['now'] = self.now
extra['id'] = self.id
return self.logger.log(level, message, extra=extra)
return
def debug(self, *args, **kwargs):
return self.log(*args, level=logging.DEBUG, **kwargs)
@ -192,24 +161,59 @@ class BaseAgent(nxsim.BaseAgent):
self._state = state['_state']
self.env = state['environment']
def add_edge(self, node1, node2, **attrs):
node1 = as_node(node1)
node2 = as_node(node2)
class NetworkAgent(BaseAgent):
for n in [node1, node2]:
if n not in self.global_topology.nodes(data=False):
raise ValueError('"{}" not in the graph'.format(n))
return self.global_topology.add_edge(node1, node2, **attrs)
@property
def topology(self):
return self.env.G
@property
def G(self):
return self.env.G
def count_agents(self, **kwargs):
return len(list(self.get_agents(**kwargs)))
def count_neighboring_agents(self, state_id=None, **kwargs):
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
def get_neighboring_agents(self, state_id=None, **kwargs):
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
if limit_neighbors:
agents = self.topology.neighbors(self.id)
agents = self.env.get_agents(agents)
return select(agents, **kwargs)
def log(self, message, *args, level=logging.INFO, **kwargs):
message = message + " ".join(str(i) for i in args)
message = " @{:>3}: {}".format(self.now, message)
for k, v in kwargs.items():
message += " {}={} ".format(k, v)
extra = {}
extra['now'] = self.now
extra['agent_id'] = self.id
extra['agent_name'] = self.name
return self.logger.log(level, message, extra=extra)
def subgraph(self, center=True, **kwargs):
include = [self] if center else []
return self.global_topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
return self.topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
def remove_node(self, agent_id):
self.topology.remove_node(agent_id)
class NetworkAgent(BaseAgent):
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
if self.id not in self.topology.nodes(data=False):
raise ValueError('{} not in list of existing agents in the network'.format(self.id))
if other not in self.topology.nodes(data=False):
raise ValueError('{} not in list of existing agents in the network'.format(other))
self.topology.add_edge(self.id, other, edge_attr_dict=edge_attr_dict, *edge_attrs)
def add_edge(self, other, **kwargs):
return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
def ego_search(self, steps=1, center=False, node=None, **kwargs):
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
@ -220,14 +224,14 @@ class NetworkAgent(BaseAgent):
def degree(self, node, force=False):
node = as_node(node)
if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now:
self.env._degree = nx.degree_centrality(self.global_topology)
self.env._degree = nx.degree_centrality(self.topology)
self.env._last_step = self.now
return self.env._degree[node]
def betweenness(self, node, force=False):
node = as_node(node)
if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now:
self.env._betweenness = nx.betweenness_centrality(self.global_topology)
self.env._betweenness = nx.betweenness_centrality(self.topology)
self.env._last_step = self.now
return self.env._betweenness[node]
@ -292,16 +296,22 @@ class MetaFSM(type):
cls.states = states
class FSM(BaseAgent, metaclass=MetaFSM):
class FSM(NetworkAgent, metaclass=MetaFSM):
def __init__(self, *args, **kwargs):
super(FSM, self).__init__(*args, **kwargs)
if 'id' not in self.state:
if not self.default_state:
raise ValueError('No default state specified for {}'.format(self.id))
self['id'] = self.default_state.id
self._next_change = simpy.core.Infinity
self._next_state = self.state
def step(self):
if 'id' in self.state:
if self._next_change < self.now:
next_state = self._next_state
self._next_change = simpy.core.Infinity
self['id'] = next_state
elif 'id' in self.state:
next_state = self['id']
elif self.default_state:
next_state = self.default_state.id
@ -311,6 +321,10 @@ class FSM(BaseAgent, metaclass=MetaFSM):
raise Exception('{} is not a valid id for {}'.format(next_state, self))
return self.states[next_state](self)
def next_state(self, state):
self._next_change = self.now
self._next_state = state
def set_state(self, state):
if hasattr(state, 'id'):
state = state.id
@ -371,14 +385,18 @@ def calculate_distribution(network_agents=None,
else:
raise ValueError('Specify a distribution or a default agent type')
# Fix missing weights and incompatible types
for x in network_agents:
x['weight'] = float(x.get('weight', 1))
# Calculate the thresholds
total = sum(x.get('weight', 1) for x in network_agents)
total = sum(x['weight'] for x in network_agents)
acc = 0
for v in network_agents:
if 'ids' in v:
v['threshold'] = STATIC_THRESHOLD
continue
upper = acc + (v.get('weight', 1)/total)
upper = acc + (v['weight']/total)
v['threshold'] = [acc, upper]
acc = upper
return network_agents
@ -425,7 +443,7 @@ def _validate_states(states, topology):
states = states or []
if isinstance(states, dict):
for x in states:
assert x in topology.node
assert x in topology.nodes
else:
assert len(states) <= len(topology)
return states
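
As a reference for the new hierarchy, a minimal FSM agent now looks like this (the class and state names are made up for illustration; `set_state` is the same mechanism the rabbit example above uses):

    from soil.agents import FSM, state, default_state

    class Switch(FSM):
        @default_state
        @state
        def off(self):
            # Queue the `on` state to run on the next step
            self.set_state(self.on)

        @state
        def on(self):
            self.set_state(self.off)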

View File

@ -28,13 +28,13 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
df = read_csv(trial_data, **kwargs)
yield config_file, df, config
else:
for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))):
df = read_sql(trial_data, **kwargs)
yield config_file, df, config
def read_sql(db, *args, **kwargs):
h = history.History(db_path=db, backup=False)
h = history.History(db_path=db, backup=False, readonly=True)
df = h.read_sql(*args, **kwargs)
return df
@ -69,6 +69,13 @@ def convert_types_slow(df):
df = df.apply(convert_row, axis=1)
return df
def split_processed(df):
env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])]
agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])]
return env, agents
def split_df(df):
'''
Split a dataframe in two dataframes: one with the history of agents,
@ -136,7 +143,7 @@ def get_value(df, *keys, aggfunc='sum'):
return df.groupby(axis=1, level=0).agg(aggfunc)
def plot_all(*args, **kwargs):
def plot_all(*args, plot_args={}, **kwargs):
'''
Read all the trial data and plot the result of applying a function on them.
'''
@ -144,14 +151,17 @@ def plot_all(*args, **kwargs):
ps = []
for line in dfs:
f, df, config = line
df.plot(title=config['name'])
if len(df) < 1:
continue
df.plot(title=config['name'], **plot_args)
ps.append(df)
return ps
def do_all(pattern, func, *keys, include_env=False, **kwargs):
for config_file, df, config in read_data(pattern, keys=keys):
if len(df) < 1:
continue
p = func(df, *keys, **kwargs)
p.plot(title=config['name'])
yield config_file, p, config
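
A short sketch of reading a trial database back with the updated helpers; the path is illustrative, and with the glob change above any `*.sqlite` file in the output folder is picked up:

    from soil import analysis

    df = analysis.read_sql('soil_output/sim/sim_trial_0.sqlite')
    res = analysis.get_count(df, 'id')   # agents per state id, per time step
    res.plot()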

View File

@ -8,11 +8,10 @@ import yaml
import tempfile
import pandas as pd
from copy import deepcopy
from collections import Counter
from networkx.readwrite import json_graph
import networkx as nx
import nxsim
import simpy
from . import serialization, agents, analysis, history, utils
@ -23,7 +22,7 @@ _CONFIG_PROPS = [ 'name',
'interval',
]
class Environment(nxsim.NetworkEnvironment):
class Environment(simpy.Environment):
"""
The environment is key in a simulation. It contains the network topology,
a reference to network and environment agents, as well as the environment
@ -42,7 +41,10 @@ class Environment(nxsim.NetworkEnvironment):
interval=1,
seed=None,
topology=None,
*args, **kwargs):
initial_time=0,
**environment_params):
self.name = name or 'UnnamedEnvironment'
seed = seed or time.time()
random.seed(seed)
@ -52,7 +54,11 @@ class Environment(nxsim.NetworkEnvironment):
self.default_state = deepcopy(default_state) or {}
if not topology:
topology = nx.Graph()
super().__init__(*args, topology=topology, **kwargs)
self.G = nx.Graph(topology)
super().__init__(initial_time=initial_time)
self.environment_params = environment_params
self._env_agents = {}
self.interval = interval
self._history = history.History(name=self.name,
@ -151,12 +157,10 @@ class Environment(nxsim.NetworkEnvironment):
start = start or self.now
return self.G.add_edge(agent1, agent2, **attrs)
def run(self, *args, **kwargs):
def run(self, until, *args, **kwargs):
self._save_state()
self.log_stats()
super().run(*args, **kwargs)
super().run(until, *args, **kwargs)
self._history.flush_cache()
self.log_stats()
def _save_state(self, now=None):
serialization.logger.debug('Saving state @{}'.format(self.now))
@ -318,25 +322,6 @@ class Environment(nxsim.NetworkEnvironment):
return G
def stats(self):
stats = {}
stats['network'] = {}
stats['network']['n_nodes'] = self.G.number_of_nodes()
stats['network']['n_edges'] = self.G.number_of_edges()
c = Counter()
c.update(a.__class__.__name__ for a in self.network_agents)
stats['agents'] = {}
stats['agents']['model_count'] = dict(c)
c2 = Counter()
c2.update(a['id'] for a in self.network_agents)
stats['agents']['state_count'] = dict(c2)
stats['params'] = self.environment_params
return stats
def log_stats(self):
stats = self.stats()
serialization.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
def __getstate__(self):
state = {}
for prop in _CONFIG_PROPS:
@ -344,6 +329,7 @@ class Environment(nxsim.NetworkEnvironment):
state['G'] = json_graph.node_link_data(self.G)
state['environment_agents'] = self._env_agents
state['history'] = self._history
state['_now'] = self._now
return state
def __setstate__(self, state):
@ -352,6 +338,8 @@ class Environment(nxsim.NetworkEnvironment):
self._env_agents = state['environment_agents']
self.G = json_graph.node_link_graph(state['G'])
self._history = state['history']
self._now = state['_now']
self._queue = []
SoilEnvironment = Environment
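
With `simpy.Environment` as the direct base class, an environment can now be created and run standalone. A minimal sketch (all values illustrative; unknown keyword arguments end up in `environment_params`):

    import networkx as nx
    from soil.environment import Environment

    env = Environment(name='demo',
                      topology=nx.complete_graph(4),
                      interval=1,
                      prob_death=0.01)   # stored in environment_params
    env.run(until=10)
    print(env.now, env.G.number_of_nodes())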

View File

@ -1,10 +1,11 @@
import os
import csv as csvlib
import time
from io import BytesIO
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from .serialization import deserialize
from .utils import open_or_reuse, logger, timer
@ -49,7 +50,7 @@ class Exporter:
'''
def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
self.sim = simulation
self.simulation = simulation
outdir = outdir or os.path.join(os.getcwd(), 'soil_output')
self.outdir = os.path.join(outdir,
simulation.group or '',
@ -59,12 +60,15 @@ class Exporter:
def start(self):
'''Method to call when the simulation starts'''
pass
def end(self):
def end(self, stats):
'''Method to call when the simulation ends'''
pass
def trial_end(self, env):
def trial(self, env, stats):
'''Method to call when a trial ends'''
pass
def output(self, f, mode='w', **kwargs):
if self.dry_run:
@ -84,13 +88,13 @@ class default(Exporter):
def start(self):
if not self.dry_run:
logger.info('Dumping results to %s', self.outdir)
self.sim.dump_yaml(outdir=self.outdir)
self.simulation.dump_yaml(outdir=self.outdir)
else:
logger.info('NOT dumping results')
def trial_end(self, env):
def trial(self, env, stats):
if not self.dry_run:
with timer('Dumping simulation {} trial {}'.format(self.sim.name,
with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
env.name)):
with self.output('{}.sqlite'.format(env.name), mode='wb') as f:
env.dump_sqlite(f)
@ -98,21 +102,27 @@ class default(Exporter):
class csv(Exporter):
'''Export the state of each environment (and its agents) in a separate CSV file'''
def trial_end(self, env):
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.sim.name,
def trial(self, env, stats):
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
env.name,
self.outdir)):
with self.output('{}.csv'.format(env.name)) as f:
env.dump_csv(f)
with self.output('{}.stats.csv'.format(env.name)) as f:
statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL)
for stat in stats:
statwriter.writerow(stat)
class gexf(Exporter):
def trial_end(self, env):
def trial(self, env, stats):
if self.dry_run:
logger.info('Not dumping GEXF in dry_run mode')
return
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.sim.name,
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
env.name)):
with self.output('{}.gexf'.format(env.name), mode='wb') as f:
env.dump_gexf(f)
@ -124,56 +134,24 @@ class dummy(Exporter):
with self.output('dummy', 'w') as f:
f.write('simulation started @ {}\n'.format(time.time()))
def trial_end(self, env):
def trial(self, env, stats):
with self.output('dummy', 'w') as f:
for i in env.history_to_tuples():
f.write(','.join(map(str, i)))
f.write('\n')
def end(self):
def sim(self, stats):
with self.output('dummy', 'a') as f:
f.write('simulation ended @ {}\n'.format(time.time()))
class distribution(Exporter):
'''
Write the distribution of agent states at the end of each trial,
the mean value, and its deviation.
'''
def start(self):
self.means = []
self.counts = []
def trial_end(self, env):
df = env[None, None, None].df()
ix = df.index[-1]
attrs = df.columns.levels[0]
vc = {}
stats = {}
for a in attrs:
t = df.loc[(ix, a)]
try:
self.means.append(('mean', a, t.mean()))
except TypeError:
for name, count in t.value_counts().iteritems():
self.counts.append(('count', a, name, count))
def end(self):
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
dfm = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
dfc = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
with self.output('counts.csv') as f:
dfc.to_csv(f)
with self.output('metrics.csv') as f:
dfm.to_csv(f)
class graphdrawing(Exporter):
def trial_end(self, env):
def trial(self, env, stats):
# Outside effects
fig = plt.figure()
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=fig.add_subplot(111))
with open('graph-{}.png'.format(env.name), 'wb') as f:
fig.savefig(f)
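
Custom exporters now implement `start`, `trial` and `end` instead of `start`, `trial_end` and `end`. A minimal sketch of the new interface, modelled on the Dummy exporter in the updated tests below (the class name and file name are illustrative):

    from soil import exporters

    class StepCounter(exporters.Exporter):
        '''Record how far each trial ran.'''
        def start(self):
            self.steps = []

        def trial(self, env, stats):
            self.steps.append(env.now)

        def end(self, stats):
            with self.output('steps.txt') as f:
                f.write('\n'.join(str(s) for s in self.steps))

It would be passed to a run as `s.run_simulation(exporters=[StepCounter], outdir='soil_output')`.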

View File

@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
from collections import UserDict, namedtuple
from . import serialization
from .utils import open_or_reuse
from .utils import open_or_reuse, unflatten_dict
class History:
@ -19,29 +19,43 @@ class History:
Store and retrieve values from a sqlite database.
"""
def __init__(self, name=None, db_path=None, backup=False):
self._db = None
def __init__(self, name=None, db_path=None, backup=False, readonly=False):
if readonly and (not os.path.exists(db_path)):
raise Exception('The DB file does not exist. Cannot open in read-only mode')
if db_path is None:
self._db = None
self._temp = db_path is None
self._stats_columns = None
self.readonly = readonly
if self._temp:
if not name:
name = time.time()
_, db_path = tempfile.mkstemp(suffix='{}.sqlite'.format(name))
# The file will be deleted as soon as it's closed
# Normally, that will be on destruction
db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
if backup and os.path.exists(db_path):
newname = db_path + '.backup{}.sqlite'.format(time.time())
os.rename(db_path, newname)
newname = db_path + '.backup{}.sqlite'.format(time.time())
os.rename(db_path, newname)
self.db_path = db_path
self.db = db_path
self._dtypes = {}
self._tups = []
if self.readonly:
return
with self.db:
logger.debug('Creating database {}'.format(self.db_path))
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text)''')
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
self._dtypes = {}
self._tups = []
@property
def db(self):
@ -58,6 +72,7 @@ class History:
if isinstance(db_path, str):
logger.debug('Connecting to database {}'.format(db_path))
self._db = sqlite3.connect(db_path)
self._db.row_factory = sqlite3.Row
else:
self._db = db_path
@ -68,9 +83,56 @@ class History:
self._db.close()
self._db = None
def save_stats(self, stat):
if self.readonly:
print('DB in readonly mode')
return
if not stat:
return
with self.db:
if not self._stats_columns:
self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
for column, value in stat.items():
if column in self._stats_columns:
continue
dtype = 'text'
if not isinstance(value, str):
try:
float(value)
dtype = 'real'
int(value)
dtype = 'int'
except ValueError:
pass
self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
self._stats_columns.append(column)
columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
columns=columns,
values=values
)
self.db.execute(query)
def get_stats(self, unflatten=True):
rows = self.db.execute("select * from stats").fetchall()
res = []
for row in rows:
d = {}
for k in row.keys():
if row[k] is None:
continue
d[k] = row[k]
if unflatten:
d = unflatten_dict(d)
res.append(d)
return res
@property
def dtypes(self):
self.read_types()
self._read_types()
return {k:v[0] for k, v in self._dtypes.items()}
def save_tuples(self, tuples):
@ -93,18 +155,10 @@ class History:
Save a collection of records to the database.
Database writes are cached.
'''
value = self.convert(key, value)
self._tups.append(Record(agent_id=agent_id,
t_step=t_step,
key=key,
value=value))
if len(self._tups) > 100:
self.flush_cache()
def convert(self, key, value):
"""Get the serialized value for a given key."""
if self.readonly:
raise Exception('DB in readonly mode')
if key not in self._dtypes:
self.read_types()
self._read_types()
if key not in self._dtypes:
name = serialization.name(value)
serializer = serialization.serializer(name)
@ -112,21 +166,21 @@ class History:
self._dtypes[key] = (name, serializer, deserializer)
with self.db:
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
return self._dtypes[key][1](value)
def recover(self, key, value):
"""Get the deserialized value for a given key, and the serialized version."""
if key not in self._dtypes:
self.read_types()
if key not in self._dtypes:
raise ValueError("Unknown datatype for {} and {}".format(key, value))
return self._dtypes[key][2](value)
value = self._dtypes[key][1](value)
self._tups.append(Record(agent_id=agent_id,
t_step=t_step,
key=key,
value=value))
if len(self._tups) > 100:
self.flush_cache()
def flush_cache(self):
'''
Use a cache to save state changes to avoid opening a session for every change.
The cache will be flushed at the end of the simulation, and when history is accessed.
'''
if self.readonly:
raise Exception('DB in readonly mode')
logger.debug('Flushing cache {}'.format(self.db_path))
with self.db:
for rec in self._tups:
@ -139,10 +193,14 @@ class History:
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
for r in res:
agent_id, t_step, key, value = r
value = self.recover(key, value)
if key not in self._dtypes:
self._read_types()
if key not in self._dtypes:
raise ValueError("Unknown datatype for {} and {}".format(key, value))
value = self._dtypes[key][2](value)
yield agent_id, t_step, key, value
def read_types(self):
def _read_types(self):
with self.db:
res = self.db.execute("select key, value_type from value_types ").fetchall()
for k, v in res:
@ -167,7 +225,7 @@ class History:
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
self.read_types()
self._read_types()
def escape_and_join(v):
if v is None:
@ -181,7 +239,13 @@ class History:
last_df = None
if t_steps:
# Look for the last value before the minimum step in the query
# Convert negative indices into positive
if any(x<0 for x in t_steps):
max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
t_steps = [t if t>0 else max_t+1+t for t in t_steps]
# We will be doing ffill interpolation, so we need to look for
# the last value before the minimum step in the query
min_step = min(t_steps)
last_filters = ['t_step < {}'.format(min_step),]
last_filters = last_filters + filters
@ -219,7 +283,11 @@ class History:
for k, v in self._dtypes.items():
if k in df_p:
dtype, _, deserial = v
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
try:
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
except (TypeError, ValueError):
# Avoid forward-filling unknown/incompatible types
continue
if t_steps:
df_p = df_p.reindex(t_steps, method='ffill')
return df_p.ffill()
@ -313,3 +381,5 @@ class Records():
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
Record = namedtuple('Record', 'agent_id t_step key value')
Stat = namedtuple('Stat', 'trial_id')
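
A quick sketch of the stats round trip on `History`, mirroring the new test in tests/test_history.py below:

    from soil import history

    h = history.History()   # uses a temporary database file
    h.save_stats({'trial_id': 'demo_trial_0', 'num_infected': 5, 'runtime': 0.2})
    h.save_stats({'trial_id': 'demo_trial_1', 'mean.neighbors': 3})

    for stat in h.get_stats():
        print(stat)          # dotted keys come back as nested dicts

    # Reopening an existing database without writing to it
    ro = history.History(db_path=h.db_path, backup=False, readonly=True)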

View File

@ -17,10 +17,10 @@ logger.setLevel(logging.INFO)
def load_network(network_params, dir_path=None):
if network_params is None:
return nx.Graph()
path = network_params.get('path', None)
if path:
G = nx.Graph()
if 'path' in network_params:
path = network_params['path']
if dir_path and not os.path.isabs(path):
path = os.path.join(dir_path, path)
extension = os.path.splitext(path)[1][1:]
@ -32,21 +32,22 @@ def load_network(network_params, dir_path=None):
method = getattr(nx.readwrite, 'read_' + extension)
except AttributeError:
raise AttributeError('Unknown format')
return method(path, **kwargs)
G = method(path, **kwargs)
net_args = network_params.copy()
if 'generator' not in net_args:
return nx.Graph()
elif 'generator' in network_params:
net_args = network_params.copy()
net_gen = net_args.pop('generator')
net_gen = net_args.pop('generator')
if dir_path not in sys.path:
sys.path.append(dir_path)
if dir_path not in sys.path:
sys.path.append(dir_path)
method = deserializer(net_gen,
known_modules=['networkx.generators',])
G = method(**net_args)
return G
method = deserializer(net_gen,
known_modules=['networkx.generators',])
return method(**net_args)
def load_file(infile):
@ -66,11 +67,32 @@ def expand_template(config):
raise ValueError(('You must provide a definition of variables'
' for the template.'))
template = Template(config['template'])
template = config['template']
sampler_name = config.get('sampler', 'SALib.sample.morris.sample')
n_samples = int(config.get('samples', 100))
sampler = deserializer(sampler_name)
if not isinstance(template, str):
template = yaml.dump(template)
template = Template(template)
params = params_for_template(config)
blank_str = template.render({k: 0 for k in params[0].keys()})
blank = list(load_string(blank_str))
if len(blank) > 1:
raise ValueError('Templates must not return more than one configuration')
if 'name' in blank[0]:
raise ValueError('Templates cannot be named, use group instead')
for ps in params:
string = template.render(ps)
for c in load_string(string):
yield c
def params_for_template(config):
sampler_config = config.get('sampler', {'N': 100})
sampler = sampler_config.pop('method', 'SALib.sample.morris.sample')
sampler = deserializer(sampler)
bounds = config['vars']['bounds']
problem = {
@ -78,7 +100,7 @@ def expand_template(config):
'names': list(bounds.keys()),
'bounds': list(v for v in bounds.values())
}
samples = sampler(problem, n_samples)
samples = sampler(problem, **sampler_config)
lists = config['vars'].get('lists', {})
names = list(lists.keys())
@ -88,20 +110,7 @@ def expand_template(config):
allnames = names + problem['names']
allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
blank_str = template.render({k: 0 for k in allnames})
blank = list(load_string(blank_str))
if len(blank) > 1:
raise ValueError('Templates must not return more than one configuration')
if 'name' in blank[0]:
raise ValueError('Templates cannot be named, use group instead')
confs = []
for ps in params:
string = template.render(ps)
for c in load_string(string):
yield c
return params
def load_files(*patterns, **kwargs):
@ -116,7 +125,7 @@ def load_files(*patterns, **kwargs):
def load_config(config):
if isinstance(config, dict):
yield config, None
yield config, os.getcwd()
else:
yield from load_files(config)
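
`load_network` now always returns a graph, falling back to an empty `nx.Graph` when neither `path` nor `generator` is given. A short sketch (generator arguments are forwarded to the networkx function):

    from soil import serialization

    G = serialization.load_network({'generator': 'complete_graph', 'n': 4})
    assert G.number_of_nodes() == 4

    empty = serialization.load_network({})   # no path or generator
    assert empty.number_of_nodes() == 0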

View File

@ -4,6 +4,7 @@ import importlib
import sys
import yaml
import traceback
import logging
import networkx as nx
from networkx.readwrite import json_graph
from multiprocessing import Pool
@ -11,17 +12,19 @@ from functools import partial
import pickle
from nxsim import NetworkSimulation
from . import serialization, utils, basestring, agents
from .environment import Environment
from .utils import logger
from .exporters import for_sim as exporters_for_sim
from .exporters import default, for_sim as exporters_for_sim
from .stats import defaultStats
from .history import History
class Simulation(NetworkSimulation):
#TODO: change documentation for simulation
class Simulation:
"""
Subclass of nxsim.NetworkSimulation with three main differences:
Similar to nxsim.NetworkSimulation with three main differences:
1) agent type can be specified by name or by class.
2) instead of just one type, a network agents distribution can be used.
The distribution specifies the weight (or probability) of each
@ -91,11 +94,12 @@ class Simulation(NetworkSimulation):
environment_params=None, environment_class=None,
**kwargs):
self.seed = str(seed) or str(time.time())
self.load_module = load_module
self.network_params = network_params
self.name = name or 'Unnamed_' + time.strftime("%Y-%m-%d_%H.%M.%S")
self.group = group or None
self.name = name or 'Unnamed'
self.seed = str(seed or name)
self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S"))
self.group = group or ''
self.num_trials = num_trials
self.max_time = max_time
self.default_state = default_state or {}
@ -128,12 +132,15 @@ class Simulation(NetworkSimulation):
self.states = agents._validate_states(states,
self.topology)
self._history = History(name=self.name,
backup=False)
def run_simulation(self, *args, **kwargs):
return self.run(*args, **kwargs)
def run(self, *args, **kwargs):
'''Run the simulation and return the list of resulting environments'''
return list(self._run_simulation_gen(*args, **kwargs))
return list(self.run_gen(*args, **kwargs))
def _run_sync_or_async(self, parallel=False, *args, **kwargs):
if parallel:
@ -148,12 +155,16 @@ class Simulation(NetworkSimulation):
yield i
else:
for i in range(self.num_trials):
yield self.run_trial(i,
*args,
yield self.run_trial(*args,
**kwargs)
def _run_simulation_gen(self, *args, parallel=False, dry_run=False,
exporters=['default', ], outdir=None, exporter_params={}, **kwargs):
def run_gen(self, *args, parallel=False, dry_run=False,
exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
stats_params={}, log_level=None,
**kwargs):
'''Run the simulation and yield the resulting environments.'''
if log_level:
logger.setLevel(log_level)
logger.info('Using exporters: %s', exporters or [])
logger.info('Output directory: %s', outdir)
exporters = exporters_for_sim(self,
@ -161,31 +172,63 @@ class Simulation(NetworkSimulation):
dry_run=dry_run,
outdir=outdir,
**exporter_params)
stats = exporters_for_sim(self,
stats,
**stats_params)
with utils.timer('simulation {}'.format(self.name)):
for stat in stats:
stat.start()
for exporter in exporters:
exporter.start()
for env in self._run_sync_or_async(*args, parallel=parallel,
for env in self._run_sync_or_async(*args,
parallel=parallel,
log_level=log_level,
**kwargs):
collected = list(stat.trial(env) for stat in stats)
saved = self.save_stats(collected, t_step=env.now, trial_id=env.name)
for exporter in exporters:
exporter.trial_end(env)
exporter.trial(env, saved)
yield env
for exporter in exporters:
exporter.end()
def get_env(self, trial_id = 0, **kwargs):
collected = list(stat.end() for stat in stats)
saved = self.save_stats(collected)
for exporter in exporters:
exporter.end(saved)
def save_stats(self, collection, **kwargs):
stats = dict(kwargs)
for stat in collection:
stats.update(stat)
self._history.save_stats(utils.flatten_dict(stats))
return stats
def get_stats(self, **kwargs):
return self._history.get_stats(**kwargs)
def log_stats(self, stats):
logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
def get_env(self, trial_id=0, **kwargs):
'''Create an environment for a trial of the simulation'''
opts = self.environment_params.copy()
env_name = '{}_trial_{}'.format(self.name, trial_id)
opts.update({
'name': env_name,
'name': trial_id,
'topology': self.topology.copy(),
'seed': self.seed+env_name,
'seed': '{}_trial_{}'.format(self.seed, trial_id),
'initial_time': 0,
'interval': self.interval,
'network_agents': self.network_agents,
'initial_time': 0,
'states': self.states,
'default_state': self.default_state,
'environment_agents': self.environment_agents,
@ -194,20 +237,22 @@ class Simulation(NetworkSimulation):
env = self.environment_class(**opts)
return env
def run_trial(self, trial_id=0, until=None, **opts):
"""Run a single trial of the simulation
Parameters
----------
trial_id : int
def run_trial(self, until=None, log_level=logging.INFO, **opts):
"""
Run a single trial of the simulation
"""
trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-')
if log_level:
logger.setLevel(log_level)
# Set-up trial environment and graph
until = until or self.max_time
env = self.get_env(trial_id = trial_id, **opts)
env = self.get_env(trial_id=trial_id, **opts)
# Set up agents on nodes
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
env.run(until)
return env
def run_trial_exceptions(self, *args, **kwargs):
'''
A wrapper for run_trial that catches exceptions and returns them.

soil/stats.py (new file, 106 lines)
View File

@ -0,0 +1,106 @@
import pandas as pd
from collections import Counter
class Stats:
'''
Base interface for all stats. Subclassing it is not strictly necessary,
but it is useful if you don't plan to implement all the methods.
'''
def __init__(self, simulation):
self.simulation = simulation
def start(self):
'''Method to call when the simulation starts'''
pass
def end(self):
'''Method to call when the simulation ends'''
return {}
def trial(self, env):
'''Method to call when a trial ends'''
return {}
class distribution(Stats):
'''
Calculate the distribution of agent states at the end of each trial,
the mean value, and its deviation.
'''
def start(self):
self.means = []
self.counts = []
def trial(self, env):
df = env[None, None, None].df()
df = df.drop('SEED', axis=1)
ix = df.index[-1]
attrs = df.columns.get_level_values(0)
vc = {}
stats = {
'mean': {},
'count': {},
}
for a in attrs:
t = df.loc[(ix, a)]
try:
stats['mean'][a] = t.mean()
self.means.append(('mean', a, t.mean()))
except TypeError:
pass
for name, count in t.value_counts().iteritems():
if a not in stats['count']:
stats['count'][a] = {}
stats['count'][a][name] = count
self.counts.append(('count', a, name, count))
return stats
def end(self):
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
count = {}
mean = {}
if self.means:
res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
mean = res['value'].to_dict()
if self.counts:
res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
for k,v in res['count'].to_dict().items():
if k not in count:
count[k] = {}
for tup, times in v.items():
subkey, subcount = tup
if subkey not in count[k]:
count[k][subkey] = {}
count[k][subkey][subcount] = times
return {'count': count, 'mean': mean}
class defaultStats(Stats):
def trial(self, env):
c = Counter()
c.update(a.__class__.__name__ for a in env.network_agents)
c2 = Counter()
c2.update(a['id'] for a in env.network_agents)
return {
'network': {
'n_nodes': env.G.number_of_nodes(),
'n_edges': env.G.number_of_edges(),
},
'agents': {
'model_count': dict(c),
'state_count': dict(c2),
}
}
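
Writing a new statistic only requires subclassing `Stats`; any dictionary returned by `trial` or `end` is merged into the saved stats. An illustrative sketch:

    from soil.stats import Stats

    class EdgeCount(Stats):
        '''Track the number of edges per trial.'''
        def start(self):
            self.edges = []

        def trial(self, env):
            n = env.G.number_of_edges()
            self.edges.append(n)
            return {'n_edges': n}

        def end(self):
            return {'max_edges': max(self.edges) if self.edges else 0}

It would be passed to a run as `s.run_simulation(stats=[EdgeCount])`.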

View File

@ -7,6 +7,7 @@ from shutil import copyfile
from contextlib import contextmanager
logger = logging.getLogger('soil')
logging.basicConfig()
logger.setLevel(logging.INFO)
@ -31,14 +32,13 @@ def safe_open(path, mode='r', backup=True, **kwargs):
os.makedirs(outdir)
if backup and 'w' in mode and os.path.exists(path):
creation = os.path.getctime(path)
stamp = time.strftime('%Y-%m-%d_%H.%M', time.localtime(creation))
stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation))
backup_dir = os.path.join(outdir, stamp)
backup_dir = os.path.join(outdir, 'backup')
if not os.path.exists(backup_dir):
os.makedirs(backup_dir)
newpath = os.path.join(backup_dir, os.path.basename(path))
if os.path.exists(newpath):
newpath = '{}@{}'.format(newpath, time.time())
newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path),
stamp))
copyfile(path, newpath)
return open(path, mode=mode, **kwargs)
@ -48,3 +48,40 @@ def open_or_reuse(f, *args, **kwargs):
return safe_open(f, *args, **kwargs)
except (AttributeError, TypeError):
return f
def flatten_dict(d):
if not isinstance(d, dict):
return d
return dict(_flatten_dict(d))
def _flatten_dict(d, prefix=''):
if not isinstance(d, dict):
# print('END:', prefix, d)
yield prefix, d
return
if prefix:
prefix = prefix + '.'
for k, v in d.items():
# print(k, v)
res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k)))
# print('RES:', res)
yield from res
def unflatten_dict(d):
out = {}
for k, v in d.items():
target = out
if not isinstance(k, str):
target[k] = v
continue
tokens = k.split('.')
if len(tokens) < 2:
target[k] = v
continue
for token in tokens[:-1]:
if token not in target:
target[token] = {}
target = target[token]
target[tokens[-1]] = v
return out
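
The two helpers are inverses for dotted string keys, which is how stats are flattened into and recovered from the database:

    from soil.utils import flatten_dict, unflatten_dict

    nested = {'mean': {'neighbors': 3, 'total': 4}, 't_step': 2}
    flat = flatten_dict(nested)
    # {'mean.neighbors': 3, 'mean.total': 4, 't_step': 2}
    assert unflatten_dict(flat) == nested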

View File

@ -66,8 +66,8 @@ class TestAnalysis(TestCase):
env = self.env
df = analysis.read_sql(env._history.db_path)
res = analysis.get_count(df, 'SEED', 'id')
assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
assert res['SEED'][self.env['SEED']].iloc[0] == 1
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
assert res['id']['odd'].iloc[0] == 2
assert res['id']['even'].iloc[0] == 0
assert res['id']['odd'].iloc[-1] == 1
@ -75,7 +75,7 @@ class TestAnalysis(TestCase):
def test_value(self):
env = self.env
df = analysis.read_sql(env._history._db)
df = analysis.read_sql(env._history.db_path)
res_sum = analysis.get_value(df, 'count')
assert res_sum['count'].iloc[0] == 2
@ -86,4 +86,4 @@ class TestAnalysis(TestCase):
res_total = analysis.get_value(df)
res_total['SEED'].iloc[0] == 'seedanalysis_trial_0'
assert res_total['SEED'].iloc[0] == self.env['SEED']

View File

@ -31,7 +31,7 @@ def make_example_test(path, config):
try:
n = config['network_params']['n']
assert len(list(env.network_agents)) == n
assert env.now > 2 # It has run
assert env.now > 0 # It has run
assert env.now <= config['max_time'] # But not further than allowed
except KeyError:
pass

View File

@ -6,26 +6,32 @@ from time import time
from unittest import TestCase
from soil import exporters
from soil.utils import safe_open
from soil import simulation
from soil.stats import distribution
class Dummy(exporters.Exporter):
started = False
trials = 0
ended = False
total_time = 0
called_start = 0
called_trial = 0
called_end = 0
def start(self):
self.__class__.called_start += 1
self.__class__.started = True
def trial_end(self, env):
def trial(self, env, stats):
assert env
self.__class__.trials += 1
self.__class__.total_time += env.now
self.__class__.called_trial += 1
def end(self):
def end(self, stats):
self.__class__.ended = True
self.__class__.called_end += 1
class Exporters(TestCase):
@ -39,32 +45,17 @@ class Exporters(TestCase):
'environment_params': {}
}
s = simulation.from_config(config)
s.run_simulation(exporters=[Dummy], dry_run=True)
for env in s.run_simulation(exporters=[Dummy], dry_run=True):
assert env.now <= 2
assert Dummy.started
assert Dummy.ended
assert Dummy.called_start == 1
assert Dummy.called_end == 1
assert Dummy.called_trial == 5
assert Dummy.trials == 5
assert Dummy.total_time == 2*5
def test_distribution(self):
'''The distribution exporter should write the number of agents in each state'''
config = {
'name': 'exporter_sim',
'network_params': {
'generator': 'complete_graph',
'n': 4
},
'agent_type': 'CounterModel',
'max_time': 2,
'num_trials': 5,
'environment_params': {}
}
output = io.StringIO()
s = simulation.from_config(config)
s.run_simulation(exporters=[exporters.distribution], dry_run=True, exporter_params={'copy_to': output})
result = output.getvalue()
assert 'count' in result
assert 'SEED,Noneexporter_sim_trial_3,1,,1,1,1,1' in result
def test_writing(self):
'''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
n_trials = 5
@ -86,8 +77,8 @@ class Exporters(TestCase):
exporters.default,
exporters.csv,
exporters.gexf,
exporters.distribution,
],
stats=[distribution,],
outdir=tmpdir,
exporter_params={'copy_to': output})
result = output.getvalue()

View File

@ -5,6 +5,7 @@ import shutil
from glob import glob
from soil import history
from soil import utils
ROOT = os.path.abspath(os.path.dirname(__file__))
@ -154,3 +155,49 @@ class TestHistory(TestCase):
assert recovered
for i in recovered:
assert i in tuples
def test_stats(self):
"""
The data recovered should be equal to the one recorded.
"""
tuples = (
('a_1', 0, 'id', 'v'),
('a_1', 1, 'id', 'a'),
('a_1', 2, 'id', 'l'),
('a_1', 3, 'id', 'u'),
('a_1', 4, 'id', 'e'),
('env', 1, 'prob', 1),
('env', 2, 'prob', 2),
('env', 3, 'prob', 3),
('a_2', 7, 'finished', True),
)
stat_tuples = [
{'num_infected': 5, 'runtime': 0.2},
{'num_infected': 5, 'runtime': 0.2},
{'new': '40'},
]
h = history.History()
h.save_tuples(tuples)
for stat in stat_tuples:
h.save_stats(stat)
recovered = h.get_stats()
assert recovered
assert recovered[0]['num_infected'] == 5
assert recovered[1]['runtime'] == 0.2
assert recovered[2]['new'] == '40'
def test_unflatten(self):
ex = {'count.neighbors.3': 4,
'count.times.2': 4,
'count.total.4': 4,
'mean.neighbors': 3,
'mean.times': 2,
'mean.total': 4,
't_step': 2,
'trial_id': 'exporter_sim_trial_1605817956-4475424'}
res = utils.unflatten_dict(ex)
assert 'count' in res
assert 'mean' in res
assert 't_step' in res
assert 'trial_id' in res

View File

@ -343,4 +343,16 @@ class TestMain(TestCase):
configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
assert len(configs) > 0
def test_until(self):
config = {
'name': 'exporter_sim',
'network_params': {},
'agent_type': 'CounterModel',
'max_time': 2,
'num_trials': 100,
'environment_params': {}
}
s = simulation.from_config(config)
runs = list(s.run_simulation(dry_run=True))
over = list(x.now for x in runs if x.now>2)
assert len(over) == 0

tests/test_stats.py (new file, 34 lines)
View File

@ -0,0 +1,34 @@
from unittest import TestCase
from soil import simulation, stats
from soil.utils import unflatten_dict
class Stats(TestCase):
def test_distribution(self):
'''The distribution exporter should write the number of agents in each state'''
config = {
'name': 'exporter_sim',
'network_params': {
'generator': 'complete_graph',
'n': 4
},
'agent_type': 'CounterModel',
'max_time': 2,
'num_trials': 5,
'environment_params': {}
}
s = simulation.from_config(config)
for env in s.run_simulation(stats=[stats.distribution]):
pass
# stats_res = unflatten_dict(dict(env._history['stats', -1, None]))
allstats = s.get_stats()
for stat in allstats:
assert 'count' in stat
assert 'mean' in stat
if 'trial_id' in stat:
assert stat['mean']['neighbors'] == 3
assert stat['count']['total']['4'] == 4
else:
assert stat['count']['count']['neighbors']['3'] == 20
assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors']