mirror of
https://github.com/gsi-upm/soil
synced 2025-08-23 19:52:19 +00:00
Refactoring v0.15.1
See CHANGELOG.md for a full list of changes:

* Removed nxsim
* Refactored `agents.NetworkAgent` and `agents.BaseAgent`
* Refactored exporters
* Added stats to history
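In practice the headline change is that network-aware behaviour now lives in agents.NetworkAgent rather than in the old nxsim-backed agents.BaseAgent. A minimal sketch of an agent under the refactored API (the class name and state key below are illustrative; the helper methods are the ones introduced in this diff):

    from soil.agents import NetworkAgent

    class NeighborCounter(NetworkAgent):
        '''Illustrative agent: counts neighbors in state 1 on every step.'''
        def step(self):
            aware = self.count_neighboring_agents(state_id=1)
            self['aware_neighbors'] = aware   # state access via item syntax
            self.debug('step {}: {} aware neighbors'.format(self.now, aware))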
@@ -1 +1 @@
-0.14.9
+0.15.1
@@ -17,12 +17,12 @@ from .environment import Environment
 from .history import History
 from . import serialization
 from . import analysis
 from .utils import logger

 def main():
     import argparse
     from . import simulation

     logging.basicConfig(level=logging.INFO)
     logging.info('Running SOIL version: {}'.format(__version__))

     parser = argparse.ArgumentParser(description='Run a SOIL simulation')
@@ -40,6 +40,8 @@ def main():
                         help='Dump GEXF graph. Defaults to false.')
+    parser.add_argument('--csv', action='store_true',
+                        help='Dump history in CSV format. Defaults to false.')
     parser.add_argument('--level', type=str,
                         help='Logging level')
     parser.add_argument('--output', '-o', type=str, default="soil_output",
                         help='folder to write results to. It defaults to the current directory.')
     parser.add_argument('--synchronous', action='store_true',
@@ -48,6 +50,7 @@ def main():
                         help='Export environment and/or simulations using this exporter')

     args = parser.parse_args()
+    logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper()))

     if os.getcwd() not in sys.path:
         sys.path.append(os.getcwd())
@@ -9,7 +9,7 @@ class BassModel(BaseAgent):
         imitation_prob
     """

-    def __init__(self, environment, agent_id, state):
+    def __init__(self, environment, agent_id, state, **kwargs):
         super().__init__(environment=environment, agent_id=agent_id, state=state)
         env_params = environment.environment_params
         self.state['sentimentCorrelation'] = 0
@@ -19,7 +19,7 @@ class BassModel(BaseAgent):

     def behaviour(self):
         # Outside effects
-        if random.random() < self.state_params['innovation_prob']:
+        if random.random() < self['innovation_prob']:
             if self.state['id'] == 0:
                 self.state['id'] = 1
                 self.state['sentimentCorrelation'] = 1
@@ -32,7 +32,7 @@ class BassModel(BaseAgent):
         if self.state['id'] == 0:
             aware_neighbors = self.get_neighboring_agents(state_id=1)
             num_neighbors_aware = len(aware_neighbors)
-            if random.random() < (self.state_params['imitation_prob']*num_neighbors_aware):
+            if random.random() < (self['imitation_prob']*num_neighbors_aware):
                 self.state['id'] = 1
                 self.state['sentimentCorrelation'] = 1
@@ -1,7 +1,7 @@
-from . import BaseAgent
+from . import NetworkAgent


-class CounterModel(BaseAgent):
+class CounterModel(NetworkAgent):
     """
     Dummy behaviour. It counts the number of nodes in the network and neighbors
     in each step and adds it to its state.
@@ -9,14 +9,14 @@ class CounterModel(BaseAgent):

     def step(self):
         # Outside effects
-        total = len(list(self.get_all_agents()))
+        total = len(list(self.get_agents()))
         neighbors = len(list(self.get_neighboring_agents()))
         self['times'] = self.get('times', 0) + 1
         self['neighbors'] = neighbors
         self['total'] = total


-class AggregatedCounter(BaseAgent):
+class AggregatedCounter(NetworkAgent):
     """
     Dummy behaviour. It counts the number of nodes in the network and neighbors
     in each step and adds it to its state.
@@ -33,6 +33,6 @@ class AggregatedCounter(BaseAgent):
         self['times'] += 1
         neighbors = len(list(self.get_neighboring_agents()))
         self['neighbors'] += neighbors
-        total = len(list(self.get_all_agents()))
+        total = len(list(self.get_agents()))
         self['total'] += total
         self.debug('Running for step: {}. Total: {}'.format(self.now, total))
@@ -3,19 +3,19 @@
 # for x in range(0, settings.network_params["number_of_nodes"]):
 #     sentimentCorrelationNodeArray.append({'id': x})
 # Initialize agent states. Let's assume everyone is normal.

-import nxsim

 import logging
 from collections import OrderedDict
 from copy import deepcopy
 from functools import partial
 from scipy.spatial import cKDTree as KDTree
 import json
 import simpy

 from functools import wraps

-from .. import serialization, history
+from .. import serialization, history, utils


 def as_node(agent):
@@ -24,7 +24,7 @@ def as_node(agent):
     return agent


-class BaseAgent(nxsim.BaseAgent):
+class BaseAgent:
     """
     A special simpy BaseAgent that keeps track of its state history.
     """
@@ -32,14 +32,13 @@ class BaseAgent(nxsim.BaseAgent):
     defaults = {}

     def __init__(self, environment, agent_id, state=None,
-                 name=None, interval=None, **state_params):
+                 name=None, interval=None):
         # Check for REQUIRED arguments
         assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
                                                   'Cannot be NoneType.')
         # Initialize agent parameters
         self.id = agent_id
         self.name = name or '{}[{}]'.format(type(self).__name__, self.id)
-        self.state_params = state_params

         # Register agent to environment
         self.env = environment
@@ -51,10 +50,10 @@ class BaseAgent(nxsim.BaseAgent):
         self.state = real_state
         self.interval = interval

-        if not hasattr(self, 'level'):
-            self.level = logging.DEBUG
-        self.logger = logging.getLogger(self.env.name)
-        self.logger.setLevel(self.level)
+        self.logger = logging.getLogger(self.env.name).getChild(self.name)
+
+        if hasattr(self, 'level'):
+            self.logger.setLevel(self.level)

         # initialize every time an instance of the agent is created
         self.action = self.env.process(self.run())
@@ -75,14 +74,10 @@ class BaseAgent(nxsim.BaseAgent):
         for k, v in value.items():
             self[k] = v

-    @property
-    def global_topology(self):
-        return self.env.G
-
     @property
     def environment_params(self):
         return self.env.environment_params

     @environment_params.setter
     def environment_params(self, value):
         self.env.environment_params = value
@@ -135,36 +130,10 @@ class BaseAgent(nxsim.BaseAgent):
     def die(self, remove=False):
         self.alive = False
         if remove:
-            super().die()
+            self.remove_node(self.id)

     def step(self):
         pass

-    def count_agents(self, **kwargs):
-        return len(list(self.get_agents(**kwargs)))
-
-    def count_neighboring_agents(self, state_id=None, **kwargs):
-        return len(super().get_neighboring_agents(state_id=state_id, **kwargs))
-
-    def get_neighboring_agents(self, state_id=None, **kwargs):
-        return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
-
-    def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
-        if limit_neighbors:
-            agents = super().get_agents(limit_neighbors=limit_neighbors)
-        else:
-            agents = self.env.get_agents(agents)
-        return select(agents, **kwargs)
-
-    def log(self, message, *args, level=logging.INFO, **kwargs):
-        message = message + " ".join(str(i) for i in args)
-        message = "\t{:10}@{:>5}:\t{}".format(self.name, self.now, message)
-        for k, v in kwargs:
-            message += " {k}={v} ".format(k, v)
-        extra = {}
-        extra['now'] = self.now
-        extra['id'] = self.id
-        return self.logger.log(level, message, extra=extra)
-        return
-
     def debug(self, *args, **kwargs):
         return self.log(*args, level=logging.DEBUG, **kwargs)
@@ -192,24 +161,59 @@ class BaseAgent(nxsim.BaseAgent):
         self._state = state['_state']
         self.env = state['environment']

-    def add_edge(self, node1, node2, **attrs):
-        node1 = as_node(node1)
-        node2 = as_node(node2)
-
-        for n in [node1, node2]:
-            if n not in self.global_topology.nodes(data=False):
-                raise ValueError('"{}" not in the graph'.format(n))
-        return self.global_topology.add_edge(node1, node2, **attrs)
+
+class NetworkAgent(BaseAgent):
+
+    @property
+    def topology(self):
+        return self.env.G
+
+    @property
+    def G(self):
+        return self.env.G
+
+    def count_agents(self, **kwargs):
+        return len(list(self.get_agents(**kwargs)))
+
+    def count_neighboring_agents(self, state_id=None, **kwargs):
+        return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
+
+    def get_neighboring_agents(self, state_id=None, **kwargs):
+        return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
+
+    def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
+        if limit_neighbors:
+            agents = self.topology.neighbors(self.id)
+
+        agents = self.env.get_agents(agents)
+        return select(agents, **kwargs)
+
+    def log(self, message, *args, level=logging.INFO, **kwargs):
+        message = message + " ".join(str(i) for i in args)
+        message = " @{:>3}: {}".format(self.now, message)
+        for k, v in kwargs:
+            message += " {k}={v} ".format(k, v)
+        extra = {}
+        extra['now'] = self.now
+        extra['agent_id'] = self.id
+        extra['agent_name'] = self.name
+        return self.logger.log(level, message, extra=extra)

     def subgraph(self, center=True, **kwargs):
         include = [self] if center else []
-        return self.global_topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
+        return self.topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)

+    def remove_node(self, agent_id):
+        self.topology.remove_node(agent_id)

-class NetworkAgent(BaseAgent):
-    def add_edge(self, other, **kwargs):
-        return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
+    def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
+        # return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
+        if self.id not in self.topology.nodes(data=False):
+            raise ValueError('{} not in list of existing agents in the network'.format(self.id))
+        if other not in self.topology.nodes(data=False):
+            raise ValueError('{} not in list of existing agents in the network'.format(other))
+
+        self.topology.add_edge(self.id, other, edge_attr_dict=edge_attr_dict, *edge_attrs)

     def ego_search(self, steps=1, center=False, node=None, **kwargs):
         '''Get a list of nodes in the ego network of *node* of radius *steps*'''
@@ -220,14 +224,14 @@ class NetworkAgent(BaseAgent):
     def degree(self, node, force=False):
         node = as_node(node)
         if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now:
-            self.env._degree = nx.degree_centrality(self.global_topology)
+            self.env._degree = nx.degree_centrality(self.topology)
             self.env._last_step = self.now
         return self.env._degree[node]

     def betweenness(self, node, force=False):
         node = as_node(node)
         if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now:
-            self.env._betweenness = nx.betweenness_centrality(self.global_topology)
+            self.env._betweenness = nx.betweenness_centrality(self.topology)
             self.env._last_step = self.now
         return self.env._betweenness[node]
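The centrality helpers above memoize on the environment: the first call in a step computes the metric for the whole graph and caches it (env._degree / env._betweenness), stamped with env._last_step; later calls in the same step reuse the cache. A minimal sketch of the resulting call pattern, where agent stands for any NetworkAgent instance:

    d0 = agent.degree(agent.id)              # computes nx.degree_centrality(agent.topology) and caches it
    d1 = agent.degree(agent.id)              # same step: served from env._degree, no recomputation
    d2 = agent.degree(agent.id, force=True)  # recompute even if the cached value is still fresh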
@@ -292,16 +296,22 @@ class MetaFSM(type):
         cls.states = states


-class FSM(BaseAgent, metaclass=MetaFSM):
+class FSM(NetworkAgent, metaclass=MetaFSM):
     def __init__(self, *args, **kwargs):
         super(FSM, self).__init__(*args, **kwargs)
         if 'id' not in self.state:
             if not self.default_state:
                 raise ValueError('No default state specified for {}'.format(self.id))
             self['id'] = self.default_state.id
+        self._next_change = simpy.core.Infinity
+        self._next_state = self.state

     def step(self):
-        if 'id' in self.state:
+        if self._next_change < self.now:
+            next_state = self._next_state
+            self._next_change = simpy.core.Infinity
+            self['id'] = next_state
+        elif 'id' in self.state:
             next_state = self['id']
         elif self.default_state:
             next_state = self.default_state.id
@@ -311,6 +321,10 @@ class FSM(BaseAgent, metaclass=MetaFSM):
             raise Exception('{} is not a valid id for {}'.format(next_state, self))
         return self.states[next_state](self)

+    def next_state(self, state):
+        self._next_change = self.now
+        self._next_state = state
+
     def set_state(self, state):
         if hasattr(state, 'id'):
             state = state.id
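A sketch of the new delayed-transition mechanism: next_state() records the target state and a change time, and the next step() applies it before dispatching. The state/default_state decorators used here are part of soil's FSM machinery but are not shown in this diff, so treat that import as an assumption:

    from soil.agents import FSM, state, default_state  # decorators assumed, not in this diff

    class Light(FSM):
        @default_state
        @state
        def red(self):
            self.next_state('green')   # recorded now, applied on the next step

        @state
        def green(self):
            self.next_state('red')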
@@ -371,14 +385,18 @@ def calculate_distribution(network_agents=None,
     else:
         raise ValueError('Specify a distribution or a default agent type')

+    # Fix missing weights and incompatible types
+    for x in network_agents:
+        x['weight'] = float(x.get('weight', 1))
+
     # Calculate the thresholds
-    total = sum(x.get('weight', 1) for x in network_agents)
+    total = sum(x['weight'] for x in network_agents)
     acc = 0
     for v in network_agents:
         if 'ids' in v:
             v['threshold'] = STATIC_THRESHOLD
             continue
-        upper = acc + (v.get('weight', 1)/total)
+        upper = acc + (v['weight']/total)
         v['threshold'] = [acc, upper]
         acc = upper
     return network_agents
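A worked example of the threshold computation above (the agent type names are placeholders):

    network_agents = [
        {'agent_type': 'CounterModel', 'weight': 1},
        {'agent_type': 'AggregatedCounter', 'weight': 3},
    ]
    # calculate_distribution coerces weights to float (total = 4.0) and
    # assigns cumulative thresholds over [0, 1):
    #   CounterModel      -> 'threshold': [0.0, 0.25]
    #   AggregatedCounter -> 'threshold': [0.25, 1.0]
    # Entries with an explicit 'ids' list get STATIC_THRESHOLD instead.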
@@ -425,7 +443,7 @@ def _validate_states(states, topology):
     states = states or []
     if isinstance(states, dict):
         for x in states:
-            assert x in topology.node
+            assert x in topology.nodes
     else:
         assert len(states) <= len(topology)
     return states
@@ -28,13 +28,13 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
             df = read_csv(trial_data, **kwargs)
             yield config_file, df, config
     else:
-        for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
+        for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))):
             df = read_sql(trial_data, **kwargs)
             yield config_file, df, config


 def read_sql(db, *args, **kwargs):
-    h = history.History(db_path=db, backup=False)
+    h = history.History(db_path=db, backup=False, readonly=True)
     df = h.read_sql(*args, **kwargs)
     return df

@@ -69,6 +69,13 @@ def convert_types_slow(df):
     df = df.apply(convert_row, axis=1)
     return df

+
+def split_processed(df):
+    env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])]
+    agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])]
+    return env, agents
+

 def split_df(df):
     '''
     Split a dataframe in two dataframes: one with the history of agents,
@@ -136,7 +143,7 @@ def get_value(df, *keys, aggfunc='sum'):
     return df.groupby(axis=1, level=0).agg(aggfunc)


-def plot_all(*args, **kwargs):
+def plot_all(*args, plot_args={}, **kwargs):
     '''
     Read all the trial data and plot the result of applying a function on them.
     '''
@@ -144,14 +151,17 @@ def plot_all(*args, **kwargs):
     ps = []
     for line in dfs:
         f, df, config = line
-        df.plot(title=config['name'])
+        if len(df) < 1:
+            continue
+        df.plot(title=config['name'], **plot_args)
         ps.append(df)
     return ps

 def do_all(pattern, func, *keys, include_env=False, **kwargs):
     for config_file, df, config in read_data(pattern, keys=keys):
+        if len(df) < 1:
+            continue
         p = func(df, *keys, **kwargs)
         p.plot(title=config['name'])
         yield config_file, p, config
@@ -8,11 +8,10 @@ import yaml
 import tempfile
 import pandas as pd
 from copy import deepcopy
 from collections import Counter
 from networkx.readwrite import json_graph

 import networkx as nx
-import nxsim
 import simpy

 from . import serialization, agents, analysis, history, utils

@@ -23,7 +22,7 @@ _CONFIG_PROPS = [ 'name',
                   'interval',
                   ]

-class Environment(nxsim.NetworkEnvironment):
+class Environment(simpy.Environment):
     """
     The environment is key in a simulation. It contains the network topology,
     a reference to network and environment agents, as well as the environment
@@ -42,7 +41,10 @@ class Environment(nxsim.NetworkEnvironment):
                  interval=1,
                  seed=None,
                  topology=None,
-                 *args, **kwargs):
+                 initial_time=0,
+                 **environment_params):


         self.name = name or 'UnnamedEnvironment'
         seed = seed or time.time()
         random.seed(seed)
@@ -52,7 +54,11 @@ class Environment(nxsim.NetworkEnvironment):
         self.default_state = deepcopy(default_state) or {}
         if not topology:
             topology = nx.Graph()
-        super().__init__(*args, topology=topology, **kwargs)
+        self.G = nx.Graph(topology)
+
+        super().__init__(initial_time=initial_time)
+        self.environment_params = environment_params
+
         self._env_agents = {}
         self.interval = interval
         self._history = history.History(name=self.name,
@@ -151,12 +157,10 @@ class Environment(nxsim.NetworkEnvironment):
         start = start or self.now
         return self.G.add_edge(agent1, agent2, **attrs)

-    def run(self, *args, **kwargs):
+    def run(self, until, *args, **kwargs):
         self._save_state()
-        self.log_stats()
-        super().run(*args, **kwargs)
+        super().run(until, *args, **kwargs)
         self._history.flush_cache()
-        self.log_stats()

     def _save_state(self, now=None):
         serialization.logger.debug('Saving state @{}'.format(self.now))
@@ -318,25 +322,6 @@ class Environment(nxsim.NetworkEnvironment):

         return G

-    def stats(self):
-        stats = {}
-        stats['network'] = {}
-        stats['network']['n_nodes'] = self.G.number_of_nodes()
-        stats['network']['n_edges'] = self.G.number_of_edges()
-        c = Counter()
-        c.update(a.__class__.__name__ for a in self.network_agents)
-        stats['agents'] = {}
-        stats['agents']['model_count'] = dict(c)
-        c2 = Counter()
-        c2.update(a['id'] for a in self.network_agents)
-        stats['agents']['state_count'] = dict(c2)
-        stats['params'] = self.environment_params
-        return stats
-
-    def log_stats(self):
-        stats = self.stats()
-        serialization.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
-
     def __getstate__(self):
         state = {}
         for prop in _CONFIG_PROPS:
@@ -344,6 +329,7 @@ class Environment(nxsim.NetworkEnvironment):
         state['G'] = json_graph.node_link_data(self.G)
         state['environment_agents'] = self._env_agents
         state['history'] = self._history
+        state['_now'] = self._now
         return state

     def __setstate__(self, state):
@@ -352,6 +338,8 @@ class Environment(nxsim.NetworkEnvironment):
         self._env_agents = state['environment_agents']
         self.G = json_graph.node_link_graph(state['G'])
         self._history = state['history']
+        self._now = state['_now']
+        self._queue = []


 SoilEnvironment = Environment
@@ -1,10 +1,11 @@
 import os
+import csv as csvlib
 import time
 from io import BytesIO

 import matplotlib.pyplot as plt
 import networkx as nx
 import pandas as pd

 from .serialization import deserialize
 from .utils import open_or_reuse, logger, timer
@@ -49,7 +50,7 @@ class Exporter:
     '''

     def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
-        self.sim = simulation
+        self.simulation = simulation
         outdir = outdir or os.path.join(os.getcwd(), 'soil_output')
         self.outdir = os.path.join(outdir,
                                    simulation.group or '',
@@ -59,12 +60,15 @@ class Exporter:

     def start(self):
         '''Method to call when the simulation starts'''
         pass

-    def end(self):
+    def end(self, stats):
         '''Method to call when the simulation ends'''
         pass

-    def trial_end(self, env):
+    def trial(self, env, stats):
         '''Method to call when a trial ends'''
         pass

     def output(self, f, mode='w', **kwargs):
         if self.dry_run:
@@ -84,13 +88,13 @@ class default(Exporter):
     def start(self):
         if not self.dry_run:
             logger.info('Dumping results to %s', self.outdir)
-            self.sim.dump_yaml(outdir=self.outdir)
+            self.simulation.dump_yaml(outdir=self.outdir)
         else:
             logger.info('NOT dumping results')

-    def trial_end(self, env):
+    def trial(self, env, stats):
         if not self.dry_run:
-            with timer('Dumping simulation {} trial {}'.format(self.sim.name,
+            with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
                                                                env.name)):
                 with self.output('{}.sqlite'.format(env.name), mode='wb') as f:
                     env.dump_sqlite(f)
@@ -98,21 +102,27 @@ class default(Exporter):

 class csv(Exporter):
     '''Export the state of each environment (and its agents) in a separate CSV file'''
-    def trial_end(self, env):
-        with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.sim.name,
+    def trial(self, env, stats):
+        with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
                                                                           env.name,
                                                                           self.outdir)):
             with self.output('{}.csv'.format(env.name)) as f:
                 env.dump_csv(f)

+            with self.output('{}.stats.csv'.format(env.name)) as f:
+                statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL)
+
+                for stat in stats:
+                    statwriter.writerow(stat)


 class gexf(Exporter):
-    def trial_end(self, env):
+    def trial(self, env, stats):
         if self.dry_run:
             logger.info('Not dumping GEXF in dry_run mode')
             return

-        with timer('[GEXF] Dumping simulation {} trial {}'.format(self.sim.name,
+        with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
                                                                   env.name)):
             with self.output('{}.gexf'.format(env.name), mode='wb') as f:
                 env.dump_gexf(f)
@@ -124,56 +134,24 @@ class dummy(Exporter):
         with self.output('dummy', 'w') as f:
             f.write('simulation started @ {}\n'.format(time.time()))

-    def trial_end(self, env):
+    def trial(self, env, stats):
         with self.output('dummy', 'w') as f:
             for i in env.history_to_tuples():
                 f.write(','.join(map(str, i)))
                 f.write('\n')

-    def end(self):
+    def sim(self, stats):
         with self.output('dummy', 'a') as f:
             f.write('simulation ended @ {}\n'.format(time.time()))


-class distribution(Exporter):
-    '''
-    Write the distribution of agent states at the end of each trial,
-    the mean value, and its deviation.
-    '''
-
-    def start(self):
-        self.means = []
-        self.counts = []
-
-    def trial_end(self, env):
-        df = env[None, None, None].df()
-        ix = df.index[-1]
-        attrs = df.columns.levels[0]
-        vc = {}
-        stats = {}
-        for a in attrs:
-            t = df.loc[(ix, a)]
-            try:
-                self.means.append(('mean', a, t.mean()))
-            except TypeError:
-                for name, count in t.value_counts().iteritems():
-                    self.counts.append(('count', a, name, count))
-
-    def end(self):
-        dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
-        dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
-        dfm = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
-        dfc = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
-        with self.output('counts.csv') as f:
-            dfc.to_csv(f)
-        with self.output('metrics.csv') as f:
-            dfm.to_csv(f)
-
 class graphdrawing(Exporter):

-    def trial_end(self, env):
+    def trial(self, env, stats):
         # Outside effects
         f = plt.figure()
         nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
         with open('graph-{}.png'.format(env.name)) as f:
             f.savefig(f)
soil/history.py
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 from collections import UserDict, namedtuple

 from . import serialization
-from .utils import open_or_reuse
+from .utils import open_or_reuse, unflatten_dict


 class History:
@@ -19,29 +19,43 @@ class History:
     Store and retrieve values from a sqlite database.
     """

-    def __init__(self, name=None, db_path=None, backup=False):
-        self._db = None
-        if db_path is None:
+    def __init__(self, name=None, db_path=None, backup=False, readonly=False):
+        if readonly and (not os.path.exists(db_path)):
+            raise Exception('The DB file does not exist. Cannot open in read-only mode')
+
+        self._db = None
+        self._temp = db_path is None
+        self._stats_columns = None
+        self.readonly = readonly
+
+        if self._temp:
             if not name:
                 name = time.time()
-            _, db_path = tempfile.mkstemp(suffix='{}.sqlite'.format(name))
+            # The file will be deleted as soon as it's closed
+            # Normally, that will be on destruction
+            db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name

         if backup and os.path.exists(db_path):
             newname = db_path + '.backup{}.sqlite'.format(time.time())
             os.rename(db_path, newname)

         self.db_path = db_path

         self.db = db_path
+        self._dtypes = {}
+        self._tups = []
+
+        if self.readonly:
+            return

         with self.db:
             logger.debug('Creating database {}'.format(self.db_path))
-            self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
+            self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text)''')
             self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
+            self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
             self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
-        self._dtypes = {}
-        self._tups = []

@@ -58,6 +72,7 @@ class History:
         if isinstance(db_path, str):
             logger.debug('Connecting to database {}'.format(db_path))
             self._db = sqlite3.connect(db_path)
+            self._db.row_factory = sqlite3.Row
         else:
             self._db = db_path

@@ -68,9 +83,56 @@ class History:
             self._db.close()
             self._db = None

+    def save_stats(self, stat):
+        if self.readonly:
+            print('DB in readonly mode')
+            return
+        if not stat:
+            return
+        with self.db:
+            if not self._stats_columns:
+                self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
+
+            for column, value in stat.items():
+                if column in self._stats_columns:
+                    continue
+                dtype = 'text'
+                if not isinstance(value, str):
+                    try:
+                        float(value)
+                        dtype = 'real'
+                        int(value)
+                        dtype = 'int'
+                    except ValueError:
+                        pass
+                self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
+                self._stats_columns.append(column)
+
+            columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
+            values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
+            query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
+                columns=columns,
+                values=values
+            )
+            self.db.execute(query)
+
+    def get_stats(self, unflatten=True):
+        rows = self.db.execute("select * from stats").fetchall()
+        res = []
+        for row in rows:
+            d = {}
+            for k in row.keys():
+                if row[k] is None:
+                    continue
+                d[k] = row[k]
+            if unflatten:
+                d = unflatten_dict(d)
+            res.append(d)
+        return res
+
     @property
     def dtypes(self):
-        self.read_types()
+        self._read_types()
         return {k:v[0] for k, v in self._dtypes.items()}

     def save_tuples(self, tuples):
@@ -93,18 +155,10 @@ class History:
         Save a collection of records to the database.
         Database writes are cached.
         '''
+        value = self.convert(key, value)
         self._tups.append(Record(agent_id=agent_id,
                                  t_step=t_step,
                                  key=key,
                                  value=value))
         if len(self._tups) > 100:
             self.flush_cache()

+    def convert(self, key, value):
+        """Get the serialized value for a given key."""
+        if self.readonly:
+            raise Exception('DB in readonly mode')
         if key not in self._dtypes:
-            self.read_types()
+            self._read_types()
         if key not in self._dtypes:
             name = serialization.name(value)
             serializer = serialization.serializer(name)
@@ -112,21 +166,21 @@ class History:
             self._dtypes[key] = (name, serializer, deserializer)
             with self.db:
                 self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
+        return self._dtypes[key][1](value)

     def recover(self, key, value):
+        """Get the deserialized value for a given key, and the serialized version."""
         if key not in self._dtypes:
             self.read_types()
         if key not in self._dtypes:
             raise ValueError("Unknown datatype for {} and {}".format(key, value))
         return self._dtypes[key][2](value)
-        value = self._dtypes[key][1](value)
-        self._tups.append(Record(agent_id=agent_id,
-                                 t_step=t_step,
-                                 key=key,
-                                 value=value))
-        if len(self._tups) > 100:
-            self.flush_cache()

     def flush_cache(self):
         '''
         Use a cache to save state changes to avoid opening a session for every change.
         The cache will be flushed at the end of the simulation, and when history is accessed.
         '''
+        if self.readonly:
+            raise Exception('DB in readonly mode')
         logger.debug('Flushing cache {}'.format(self.db_path))
         with self.db:
             for rec in self._tups:
@@ -139,10 +193,14 @@ class History:
         res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
         for r in res:
             agent_id, t_step, key, value = r
-            value = self.recover(key, value)
+            if key not in self._dtypes:
+                self._read_types()
+            if key not in self._dtypes:
+                raise ValueError("Unknown datatype for {} and {}".format(key, value))
+            value = self._dtypes[key][2](value)
             yield agent_id, t_step, key, value

-    def read_types(self):
+    def _read_types(self):
         with self.db:
             res = self.db.execute("select key, value_type from value_types ").fetchall()
             for k, v in res:
@@ -167,7 +225,7 @@ class History:

     def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):

-        self.read_types()
+        self._read_types()

         def escape_and_join(v):
             if v is None:
@@ -181,7 +239,13 @@ class History:

         last_df = None
         if t_steps:
-            # Look for the last value before the minimum step in the query
+            # Convert negative indices into positive
+            if any(x<0 for x in t_steps):
+                max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
+                t_steps = [t if t>0 else max_t+1+t for t in t_steps]
+
+            # We will be doing ffill interpolation, so we need to look for
+            # the last value before the minimum step in the query
             min_step = min(t_steps)
             last_filters = ['t_step < {}'.format(min_step),]
             last_filters = last_filters + filters
@@ -219,7 +283,11 @@ class History:
         for k, v in self._dtypes.items():
             if k in df_p:
                 dtype, _, deserial = v
-                df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
+                try:
+                    df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
+                except (TypeError, ValueError):
+                    # Avoid forward-filling unknown/incompatible types
+                    continue
         if t_steps:
             df_p = df_p.reindex(t_steps, method='ffill')
         return df_p.ffill()
@@ -313,3 +381,5 @@ class Records():

 Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
 Record = namedtuple('Record', 'agent_id t_step key value')
+
+Stat = namedtuple('Stat', 'trial_id')
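A quick sketch of the new stats round trip: Simulation flattens nested stats with utils.flatten_dict before calling save_stats, and get_stats unflattens the dotted column names back into nested dicts (values come back as stored in sqlite):

    from soil.history import History

    h = History(name='demo')               # backed by a temporary sqlite file
    h.save_stats({'trial_id': 't0',
                  'network.n_nodes': 100})  # dotted keys become columns of the stats table
    print(h.get_stats())
    # e.g. [{'trial_id': 't0', 'network': {'n_nodes': 100}}]  (unflattened on the way out)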
@@ -17,10 +17,10 @@ logger.setLevel(logging.INFO)


 def load_network(network_params, dir_path=None):
-    if network_params is None:
-        return nx.Graph()
-    path = network_params.get('path', None)
-    if path:
+    G = nx.Graph()
+
+    if 'path' in network_params:
+        path = network_params['path']
         if dir_path and not os.path.isabs(path):
             path = os.path.join(dir_path, path)
         extension = os.path.splitext(path)[1][1:]
@@ -32,21 +32,22 @@ def load_network(network_params, dir_path=None):
             method = getattr(nx.readwrite, 'read_' + extension)
         except AttributeError:
             raise AttributeError('Unknown format')
-        return method(path, **kwargs)
-
-    net_args = network_params.copy()
-    if 'generator' not in net_args:
-        return nx.Graph()
-
-    net_gen = net_args.pop('generator')
-
-    if dir_path not in sys.path:
-        sys.path.append(dir_path)
-
-    method = deserializer(net_gen,
-                          known_modules=['networkx.generators',])
-
-    return method(**net_args)
+        G = method(path, **kwargs)
+
+    elif 'generator' in network_params:
+        net_args = network_params.copy()
+        net_gen = net_args.pop('generator')
+
+        if dir_path not in sys.path:
+            sys.path.append(dir_path)
+
+        method = deserializer(net_gen,
+                              known_modules=['networkx.generators',])
+        G = method(**net_args)
+
+    return G


 def load_file(infile):
@@ -66,11 +67,32 @@ def expand_template(config):
         raise ValueError(('You must provide a definition of variables'
                           ' for the template.'))

-    template = Template(config['template'])
+    template = config['template']

-    sampler_name = config.get('sampler', 'SALib.sample.morris.sample')
-    n_samples = int(config.get('samples', 100))
-    sampler = deserializer(sampler_name)
+    if not isinstance(template, str):
+        template = yaml.dump(template)
+
+    template = Template(template)
+
+    params = params_for_template(config)
+
+    blank_str = template.render({k: 0 for k in params[0].keys()})
+    blank = list(load_string(blank_str))
+    if len(blank) > 1:
+        raise ValueError('Templates must not return more than one configuration')
+    if 'name' in blank[0]:
+        raise ValueError('Templates cannot be named, use group instead')
+
+    for ps in params:
+        string = template.render(ps)
+        for c in load_string(string):
+            yield c
+
+
+def params_for_template(config):
+    sampler_config = config.get('sampler', {'N': 100})
+    sampler = sampler_config.pop('method', 'SALib.sample.morris.sample')
+    sampler = deserializer(sampler)
     bounds = config['vars']['bounds']

     problem = {
@@ -78,7 +100,7 @@ def expand_template(config):
         'names': list(bounds.keys()),
         'bounds': list(v for v in bounds.values())
     }
-    samples = sampler(problem, n_samples)
+    samples = sampler(problem, **sampler_config)

     lists = config['vars'].get('lists', {})
     names = list(lists.keys())
@@ -88,20 +110,7 @@ def expand_template(config):
     allnames = names + problem['names']
     allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
     params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
-
-    blank_str = template.render({k: 0 for k in allnames})
-    blank = list(load_string(blank_str))
-    if len(blank) > 1:
-        raise ValueError('Templates must not return more than one configuration')
-    if 'name' in blank[0]:
-        raise ValueError('Templates cannot be named, use group instead')
-
-    confs = []
-    for ps in params:
-        string = template.render(ps)
-        for c in load_string(string):
-            yield c
+    return params


 def load_files(*patterns, **kwargs):
@@ -116,7 +125,7 @@ def load_files(*patterns, **kwargs):

 def load_config(config):
     if isinstance(config, dict):
-        yield config, None
+        yield config, os.getcwd()
     else:
         yield from load_files(config)
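Sketch of a template config consumed by expand_template/params_for_template above. The sampler entry is a dict whose method key is deserialized (defaulting to SALib's Morris sampler) and whose remaining keys are passed straight to the sampler; the field names follow the code above, while the concrete values are illustrative:

    config = {
        'template': {
            'group': 'sweep',
            'environment_params': {'prob_infect': '{{ p }}'},  # Jinja-style placeholder
        },
        'vars': {'bounds': {'p': [0.0, 1.0]}},
        'sampler': {'method': 'SALib.sample.morris.sample', 'N': 10},
    }
    # expand_template(config) renders and yields one configuration per sampled value of p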
@@ -4,6 +4,7 @@ import importlib
 import sys
 import yaml
 import traceback
+import logging
 import networkx as nx
 from networkx.readwrite import json_graph
 from multiprocessing import Pool
@@ -11,17 +12,19 @@ from functools import partial

 import pickle

-from nxsim import NetworkSimulation
-
 from . import serialization, utils, basestring, agents
 from .environment import Environment
 from .utils import logger
-from .exporters import for_sim as exporters_for_sim
+from .exporters import default, for_sim as exporters_for_sim
+from .stats import defaultStats
+from .history import History


-class Simulation(NetworkSimulation):
+#TODO: change documentation for simulation
+class Simulation:
     """
-    Subclass of nsim.NetworkSimulation with three main differences:
+    Similar to nsim.NetworkSimulation with three main differences:
     1) agent type can be specified by name or by class.
     2) instead of just one type, a network agents distribution can be used.
        The distribution specifies the weight (or probability) of each
@@ -91,11 +94,12 @@ class Simulation(NetworkSimulation):
                  environment_params=None, environment_class=None,
                  **kwargs):

-        self.seed = str(seed) or str(time.time())
         self.load_module = load_module
         self.network_params = network_params
-        self.name = name or 'Unnamed_' + time.strftime("%Y-%m-%d_%H.%M.%S")
-        self.group = group or None
+        self.name = name or 'Unnamed'
+        self.seed = str(seed or name)
+        self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S"))
+        self.group = group or ''
         self.num_trials = num_trials
         self.max_time = max_time
         self.default_state = default_state or {}
@@ -128,12 +132,15 @@ class Simulation(NetworkSimulation):
         self.states = agents._validate_states(states,
                                               self.topology)

+        self._history = History(name=self.name,
+                                backup=False)
+
     def run_simulation(self, *args, **kwargs):
         return self.run(*args, **kwargs)

     def run(self, *args, **kwargs):
         '''Run the simulation and return the list of resulting environments'''
-        return list(self._run_simulation_gen(*args, **kwargs))
+        return list(self.run_gen(*args, **kwargs))

     def _run_sync_or_async(self, parallel=False, *args, **kwargs):
         if parallel:
@@ -148,12 +155,16 @@ class Simulation(NetworkSimulation):
                 yield i
         else:
             for i in range(self.num_trials):
-                yield self.run_trial(i,
-                                     *args,
+                yield self.run_trial(*args,
                                      **kwargs)

-    def _run_simulation_gen(self, *args, parallel=False, dry_run=False,
-                            exporters=['default', ], outdir=None, exporter_params={}, **kwargs):
+    def run_gen(self, *args, parallel=False, dry_run=False,
+                exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
+                stats_params={}, log_level=None,
+                **kwargs):
         '''Run the simulation and yield the resulting environments.'''
+        if log_level:
+            logger.setLevel(log_level)
         logger.info('Using exporters: %s', exporters or [])
         logger.info('Output directory: %s', outdir)
         exporters = exporters_for_sim(self,
@@ -161,31 +172,63 @@ class Simulation(NetworkSimulation):
                                       dry_run=dry_run,
                                       outdir=outdir,
                                       **exporter_params)
+        stats = exporters_for_sim(self,
+                                  stats,
+                                  **stats_params)

         with utils.timer('simulation {}'.format(self.name)):
+            for stat in stats:
+                stat.start()
+
             for exporter in exporters:
                 exporter.start()
-            for env in self._run_sync_or_async(*args, parallel=parallel,
+
+            for env in self._run_sync_or_async(*args,
+                                               parallel=parallel,
+                                               log_level=log_level,
                                                **kwargs):
+
+                collected = list(stat.trial(env) for stat in stats)
+
+                saved = self.save_stats(collected, t_step=env.now, trial_id=env.name)
+
                 for exporter in exporters:
-                    exporter.trial_end(env)
+                    exporter.trial(env, saved)
+
                 yield env
-            for exporter in exporters:
-                exporter.end()

-    def get_env(self, trial_id = 0, **kwargs):
+            collected = list(stat.end() for stat in stats)
+            saved = self.save_stats(collected)
+
+            for exporter in exporters:
+                exporter.end(saved)
+
+
+    def save_stats(self, collection, **kwargs):
+        stats = dict(kwargs)
+        for stat in collection:
+            stats.update(stat)
+        self._history.save_stats(utils.flatten_dict(stats))
+        return stats
+
+    def get_stats(self, **kwargs):
+        return self._history.get_stats(**kwargs)
+
+    def log_stats(self, stats):
+        logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
+
+
+    def get_env(self, trial_id=0, **kwargs):
         '''Create an environment for a trial of the simulation'''
         opts = self.environment_params.copy()
-        env_name = '{}_trial_{}'.format(self.name, trial_id)
         opts.update({
-            'name': env_name,
+            'name': trial_id,
             'topology': self.topology.copy(),
-            'seed': self.seed+env_name,
+            'seed': '{}_trial_{}'.format(self.seed, trial_id),
+            'initial_time': 0,
             'interval': self.interval,
             'network_agents': self.network_agents,
-            'initial_time': 0,
             'states': self.states,
             'default_state': self.default_state,
             'environment_agents': self.environment_agents,
@@ -194,20 +237,22 @@ class Simulation(NetworkSimulation):
         env = self.environment_class(**opts)
         return env

-    def run_trial(self, trial_id=0, until=None, **opts):
-        """Run a single trial of the simulation
-
-        Parameters
-        ----------
-        trial_id : int
+    def run_trial(self, until=None, log_level=logging.INFO, **opts):
+        """
+        Run a single trial of the simulation

         """
+        trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-')
+        if log_level:
+            logger.setLevel(log_level)
         # Set-up trial environment and graph
         until = until or self.max_time
-        env = self.get_env(trial_id = trial_id, **opts)
+        env = self.get_env(trial_id=trial_id, **opts)
         # Set up agents on nodes
         with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
             env.run(until)
         return env

     def run_trial_exceptions(self, *args, **kwargs):
         '''
         A wrapper for run_trial that catches exceptions and returns them.
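Putting the pieces together: run_gen now wires exporter classes and stats classes into every trial. A hedged end-to-end sketch (the network_params/network_agents keys follow soil's usual config schema; parameter names not shown in this diff are assumptions):

    from soil.simulation import Simulation
    from soil.exporters import csv
    from soil.stats import distribution

    s = Simulation(name='counter_demo',
                   network_params={'generator': 'complete_graph', 'n': 10},
                   network_agents=[{'agent_type': 'CounterModel', 'weight': 1}],
                   max_time=20,
                   num_trials=2)
    envs = s.run(exporters=[csv], stats=[distribution], outdir='soil_output')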
soil/stats.py (new file, 106 lines)
@@ -0,0 +1,106 @@
import pandas as pd

from collections import Counter

class Stats:
    '''
    Interface for all stats. It is not necessary, but it is useful
    if you don't plan to implement all the methods.
    '''

    def __init__(self, simulation):
        self.simulation = simulation

    def start(self):
        '''Method to call when the simulation starts'''
        pass

    def end(self):
        '''Method to call when the simulation ends'''
        return {}

    def trial(self, env):
        '''Method to call when a trial ends'''
        return {}


class distribution(Stats):
    '''
    Calculate the distribution of agent states at the end of each trial,
    the mean value, and its deviation.
    '''

    def start(self):
        self.means = []
        self.counts = []

    def trial(self, env):
        df = env[None, None, None].df()
        df = df.drop('SEED', axis=1)
        ix = df.index[-1]
        attrs = df.columns.get_level_values(0)
        vc = {}
        stats = {
            'mean': {},
            'count': {},
        }
        for a in attrs:
            t = df.loc[(ix, a)]
            try:
                stats['mean'][a] = t.mean()
                self.means.append(('mean', a, t.mean()))
            except TypeError:
                pass

            for name, count in t.value_counts().iteritems():
                if a not in stats['count']:
                    stats['count'][a] = {}
                stats['count'][a][name] = count
                self.counts.append(('count', a, name, count))

        return stats

    def end(self):
        dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
        dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])

        count = {}
        mean = {}

        if self.means:
            res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
            mean = res['value'].to_dict()
        if self.counts:
            res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
            for k,v in res['count'].to_dict().items():
                if k not in count:
                    count[k] = {}
                for tup, times in v.items():
                    subkey, subcount = tup
                    if subkey not in count[k]:
                        count[k][subkey] = {}
                    count[k][subkey][subcount] = times

        return {'count': count, 'mean': mean}


class defaultStats(Stats):

    def trial(self, env):
        c = Counter()
        c.update(a.__class__.__name__ for a in env.network_agents)

        c2 = Counter()
        c2.update(a['id'] for a in env.network_agents)

        return {
            'network ': {
                'n_nodes': env.G.number_of_nodes(),
                'n_edges': env.G.number_of_nodes(),
            },
            'agents': {
                'model_count': dict(c),
                'state_count': dict(c2),
            }
        }
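Custom statistics plug in by subclassing Stats and returning dicts from trial()/end(); Simulation flattens and stores them through History.save_stats and hands them to the exporters. A small sketch (note that defaultStats, as committed, fills n_edges using number_of_nodes(), which looks unintended):

    from soil.stats import Stats

    class meanDegree(Stats):
        '''Report the mean degree of the network at the end of each trial.'''
        def trial(self, env):
            degrees = [d for _, d in env.G.degree()]
            return {'mean_degree': sum(degrees) / len(degrees) if degrees else 0}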
@@ -7,6 +7,7 @@ from shutil import copyfile
 from contextlib import contextmanager

 logger = logging.getLogger('soil')
+logging.basicConfig()
 logger.setLevel(logging.INFO)


@@ -31,14 +32,13 @@ def safe_open(path, mode='r', backup=True, **kwargs):
         os.makedirs(outdir)
     if backup and 'w' in mode and os.path.exists(path):
         creation = os.path.getctime(path)
-        stamp = time.strftime('%Y-%m-%d_%H.%M', time.localtime(creation))
+        stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation))

-        backup_dir = os.path.join(outdir, stamp)
+        backup_dir = os.path.join(outdir, 'backup')
         if not os.path.exists(backup_dir):
             os.makedirs(backup_dir)
-        newpath = os.path.join(backup_dir, os.path.basename(path))
-        if os.path.exists(newpath):
-            newpath = '{}@{}'.format(newpath, time.time())
+        newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path),
+                                                          stamp))
         copyfile(path, newpath)
     return open(path, mode=mode, **kwargs)

@@ -48,3 +48,40 @@ def open_or_reuse(f, *args, **kwargs):
         return safe_open(f, *args, **kwargs)
     except (AttributeError, TypeError):
         return f

+def flatten_dict(d):
+    if not isinstance(d, dict):
+        return d
+    return dict(_flatten_dict(d))
+
+def _flatten_dict(d, prefix=''):
+    if not isinstance(d, dict):
+        # print('END:', prefix, d)
+        yield prefix, d
+        return
+    if prefix:
+        prefix = prefix + '.'
+    for k, v in d.items():
+        # print(k, v)
+        res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k)))
+        # print('RES:', res)
+        yield from res
+
+
+def unflatten_dict(d):
+    out = {}
+    for k, v in d.items():
+        target = out
+        if not isinstance(k, str):
+            target[k] = v
+            continue
+        tokens = k.split('.')
+        if len(tokens) < 2:
+            target[k] = v
+            continue
+        for token in tokens[:-1]:
+            if token not in target:
+                target[token] = {}
+            target = target[token]
+        target[tokens[-1]] = v
+    return out
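The two helpers are inverses for dotted string keys, which is what lets the stats table store nested dicts as flat columns. A quick round trip:

    from soil.utils import flatten_dict, unflatten_dict

    nested = {'network': {'n_nodes': 4, 'n_edges': 3}, 'params': {'seed': 'x'}}
    flat = flatten_dict(nested)
    # {'network.n_nodes': 4, 'network.n_edges': 3, 'params.seed': 'x'}
    assert unflatten_dict(flat) == nested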