mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
Parallelism and granular exporting options
* Graphs are not saved by default (not backwards compatible) * Modified newsspread examples * More granular options to save results (exporting to CSV and GEXF are now optional) * Updated tutorial to include exporting options * Removed references from environment to simulation * Added parallelism to simulations (can be turned off with a flag or argument).
This commit is contained in:
parent
a4b32afa2f
commit
7d1c800490
File diff suppressed because one or more lines are too long
@ -68,7 +68,7 @@ network_agents:
|
||||
- agent_type: HerdViewer
|
||||
state:
|
||||
has_tv: true
|
||||
id: infected
|
||||
id: neutral
|
||||
weight: 1
|
||||
- agent_type: HerdViewer
|
||||
state:
|
||||
@ -95,7 +95,7 @@ network_agents:
|
||||
- agent_type: HerdViewer
|
||||
state:
|
||||
has_tv: true
|
||||
id: infected
|
||||
id: neutral
|
||||
weight: 1
|
||||
- agent_type: WiseViewer
|
||||
state:
|
||||
@ -121,7 +121,7 @@ network_agents:
|
||||
- agent_type: WiseViewer
|
||||
state:
|
||||
has_tv: true
|
||||
id: infected
|
||||
id: neutral
|
||||
weight: 1
|
||||
- agent_type: WiseViewer
|
||||
state:
|
||||
|
@ -1,5 +1,4 @@
|
||||
from soil.agents import BaseAgent,FSM, state, default_state
|
||||
import random
|
||||
from soil.agents import FSM, state, default_state, prob
|
||||
import logging
|
||||
|
||||
|
||||
@ -10,70 +9,73 @@ class DumbViewer(FSM):
|
||||
'''
|
||||
defaults = {
|
||||
'prob_neighbor_spread': 0.5,
|
||||
'prob_neighbor_cure': 0.25,
|
||||
'prob_tv_spread': 0.1,
|
||||
}
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def neutral(self):
|
||||
r = random.random()
|
||||
if self['has_tv'] and r < self.env['prob_tv_spread']:
|
||||
self.infect()
|
||||
return
|
||||
if self['has_tv']:
|
||||
if prob(self.env['prob_tv_spread']):
|
||||
self.set_state(self.infected)
|
||||
|
||||
@state
|
||||
def infected(self):
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
||||
prob_infect = self.env['prob_neighbor_spread']
|
||||
r = random.random()
|
||||
if r < prob_infect:
|
||||
self.set_state(self.infected.id)
|
||||
if prob(self.env['prob_neighbor_spread']):
|
||||
neighbor.infect()
|
||||
return
|
||||
|
||||
def infect(self):
|
||||
self.set_state(self.infected)
|
||||
|
||||
|
||||
class HerdViewer(DumbViewer):
|
||||
'''
|
||||
A viewer whose probability of infection depends on the state of its neighbors.
|
||||
'''
|
||||
|
||||
level = logging.DEBUG
|
||||
|
||||
|
||||
def infect(self):
|
||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||
total = self.count_neighboring_agents()
|
||||
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
||||
self.debug('prob_infect', prob_infect)
|
||||
r = random.random()
|
||||
if r < prob_infect:
|
||||
if prob(prob_infect):
|
||||
self.set_state(self.infected.id)
|
||||
|
||||
|
||||
class WiseViewer(HerdViewer):
|
||||
'''
|
||||
A viewer that can change its mind.
|
||||
'''
|
||||
|
||||
defaults = {
|
||||
'prob_neighbor_spread': 0.5,
|
||||
'prob_neighbor_cure': 0.25,
|
||||
'prob_tv_spread': 0.1,
|
||||
}
|
||||
|
||||
@state
|
||||
def cured(self):
|
||||
prob_cure = self.env['prob_neighbor_cure']
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
||||
r = random.random()
|
||||
if r < prob_cure:
|
||||
if prob(prob_cure):
|
||||
try:
|
||||
neighbor.cure()
|
||||
except AttributeError:
|
||||
self.debug('Viewer {} cannot be cured'.format(neighbor.id))
|
||||
return
|
||||
|
||||
def cure(self):
|
||||
self.set_state(self.cured.id)
|
||||
|
||||
@state
|
||||
def infected(self):
|
||||
prob_cure = self.env['prob_neighbor_cure']
|
||||
r = random.random()
|
||||
if r < prob_cure:
|
||||
self.cure()
|
||||
return
|
||||
return super().infected()
|
||||
cured = max(self.count_neighboring_agents(self.cured.id),
|
||||
1.0)
|
||||
infected = max(self.count_neighboring_agents(self.infected.id),
|
||||
1.0)
|
||||
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
||||
if prob(prob_cure):
|
||||
return self.cure()
|
||||
return self.set_state(super().infected)
|
||||
|
@ -35,8 +35,14 @@ def main():
|
||||
help='Do not store the results of the simulation.')
|
||||
parser.add_argument('--pdb', action='store_true',
|
||||
help='Use a pdb console in case of exception.')
|
||||
parser.add_argument('--output', '-o', type=str,
|
||||
parser.add_argument('--graph', '-g', action='store_true',
|
||||
help='Dump GEXF graph. Defaults to false.')
|
||||
parser.add_argument('--csv', action='store_true',
|
||||
help='Dump history in CSV format. Defaults to false.')
|
||||
parser.add_argument('--output', '-o', type=str, default="soil_output",
|
||||
help='folder to write results to. It defaults to the current directory.')
|
||||
parser.add_argument('--synchronous', action='store_true',
|
||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -47,7 +53,17 @@ def main():
|
||||
logging.info('Loading config file: {}'.format(args.file, args.output))
|
||||
|
||||
try:
|
||||
simulation.run_from_config(args.file, dump=(not args.dry_run), results_dir=args.output)
|
||||
dump = []
|
||||
if not args.dry_run:
|
||||
if args.csv:
|
||||
dump.append('csv')
|
||||
if args.graph:
|
||||
dump.append('gexf')
|
||||
simulation.run_from_config(args.file,
|
||||
dry_run=args.dry_run,
|
||||
dump=dump,
|
||||
parallel=(not args.synchronous),
|
||||
results_dir=args.output)
|
||||
except Exception as ex:
|
||||
if args.pdb:
|
||||
pdb.post_mortem()
|
||||
|
@ -15,4 +15,4 @@ class DrawingAgent(BaseAgent):
|
||||
# Outside effects
|
||||
f = plt.figure()
|
||||
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
||||
f.savefig(os.path.join(self.env.sim().dir_path, "graph-"+str(self.env.now)+".png"))
|
||||
f.savefig(os.path.join(self.env.get_path(), "graph-"+str(self.env.now)+".png"))
|
||||
|
@ -12,9 +12,9 @@ from copy import deepcopy
|
||||
from functools import partial
|
||||
import json
|
||||
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from .. import utils
|
||||
|
||||
agent_types = {}
|
||||
|
||||
@ -41,7 +41,7 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
||||
super().__init__(**kwargs)
|
||||
if not hasattr(self, 'level'):
|
||||
self.level = logging.DEBUG
|
||||
self.logger = logging.getLogger('Agent-{}'.format(self.id))
|
||||
self.logger = logging.getLogger('{}-Agent-{}'.format(self.env.name, self.id))
|
||||
self.logger.setLevel(self.level)
|
||||
|
||||
|
||||
@ -140,20 +140,24 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
||||
|
||||
|
||||
def state(func):
|
||||
'''
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the nevironment.
|
||||
'''
|
||||
|
||||
@wraps(func)
|
||||
def func_wrapper(self):
|
||||
when = None
|
||||
next_state = func(self)
|
||||
when = None
|
||||
if next_state is None:
|
||||
return when
|
||||
try:
|
||||
next_state, when = next_state
|
||||
except TypeError:
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if next_state:
|
||||
try:
|
||||
self.state['id'] = next_state.id
|
||||
except AttributeError:
|
||||
raise ValueError('State id %s is not valid.' % next_state)
|
||||
self.set_state(next_state)
|
||||
return when
|
||||
|
||||
func_wrapper.id = func.__name__
|
||||
@ -212,6 +216,116 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
if state not in self.states:
|
||||
raise ValueError('{} is not a valid state'.format(state))
|
||||
self.state['id'] = state
|
||||
return state
|
||||
|
||||
|
||||
def prob(prob=1):
|
||||
'''
|
||||
A true/False uniform distribution with a given probability.
|
||||
To be used like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
if prob(0.3):
|
||||
do_something()
|
||||
|
||||
'''
|
||||
r = random.random()
|
||||
return r < prob
|
||||
|
||||
|
||||
def calculate_distribution(network_agents=None,
|
||||
agent_type=None):
|
||||
'''
|
||||
Calculate the threshold values (thresholds for a uniform distribution)
|
||||
of an agent distribution given the weights of each agent type.
|
||||
|
||||
The input has this form: ::
|
||||
|
||||
[
|
||||
{'agent_type': 'agent_type_1',
|
||||
'weight': 0.2,
|
||||
'state': {
|
||||
'id': 0
|
||||
}
|
||||
},
|
||||
{'agent_type': 'agent_type_2',
|
||||
'weight': 0.8,
|
||||
'state': {
|
||||
'id': 1
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
In this example, 20% of the nodes will be marked as type
|
||||
'agent_type_1'.
|
||||
'''
|
||||
if network_agents:
|
||||
network_agents = deepcopy(network_agents)
|
||||
elif agent_type:
|
||||
network_agents = [{'agent_type': agent_type}]
|
||||
else:
|
||||
return []
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x.get('weight', 1) for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
upper = acc + (v.get('weight', 1)/total)
|
||||
v['threshold'] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
|
||||
|
||||
def _serialize_distribution(network_agents):
|
||||
d = _convert_agent_types(network_agents,
|
||||
to_string=True)
|
||||
'''
|
||||
When serializing an agent distribution, remove the thresholds, in order
|
||||
to avoid cluttering the YAML definition file.
|
||||
'''
|
||||
for v in d:
|
||||
if 'threshold' in v:
|
||||
del v['threshold']
|
||||
return d
|
||||
|
||||
|
||||
def _validate_states(states, topology):
|
||||
'''Validate states to avoid ignoring states during initialization'''
|
||||
states = states or []
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
assert x in topology.node
|
||||
else:
|
||||
assert len(states) <= len(topology)
|
||||
return states
|
||||
|
||||
|
||||
def _convert_agent_types(ind, to_string=False):
|
||||
'''Convenience method to allow specifying agents by class or class name.'''
|
||||
d = deepcopy(ind)
|
||||
for v in d:
|
||||
agent_type = v['agent_type']
|
||||
if to_string and not isinstance(agent_type, str):
|
||||
v['agent_type'] = str(agent_type.__name__)
|
||||
elif not to_string and isinstance(agent_type, str):
|
||||
v['agent_type'] = agent_types[agent_type]
|
||||
return d
|
||||
|
||||
|
||||
def _agent_from_distribution(distribution, value=-1):
|
||||
"""Used in the initialization of agents given an agent distribution."""
|
||||
if value < 0:
|
||||
value = random.random()
|
||||
for d in distribution:
|
||||
threshold = d['threshold']
|
||||
if value >= threshold[0] and value < threshold[1]:
|
||||
state = {}
|
||||
if 'state' in d:
|
||||
state = deepcopy(d['state'])
|
||||
return d['agent_type'], state
|
||||
|
||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||
|
||||
|
||||
from .BassModel import *
|
||||
|
@ -1,16 +1,16 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import weakref
|
||||
import csv
|
||||
import random
|
||||
import simpy
|
||||
from copy import deepcopy
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
import networkx as nx
|
||||
import nxsim
|
||||
|
||||
from . import utils
|
||||
from . import utils, agents
|
||||
|
||||
|
||||
class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
@ -22,21 +22,18 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
default_state=None,
|
||||
interval=1,
|
||||
seed=None,
|
||||
dump=False,
|
||||
simulation=None,
|
||||
dry_run=False,
|
||||
dir_path=None,
|
||||
*args, **kwargs):
|
||||
self.name = name or 'UnnamedEnvironment'
|
||||
if isinstance(states, list):
|
||||
states = dict(enumerate(states))
|
||||
self.states = deepcopy(states) if states else {}
|
||||
self.default_state = deepcopy(default_state) or {}
|
||||
self.sim = weakref.ref(simulation)
|
||||
if 'topology' not in kwargs and simulation:
|
||||
kwargs['topology'] = self.sim().topology.copy()
|
||||
super().__init__(*args, **kwargs)
|
||||
self._env_agents = {}
|
||||
self.dry_run = dry_run
|
||||
self.interval = interval
|
||||
self.dump = dump
|
||||
# Add environment agents first, so their events get
|
||||
# executed before network agents
|
||||
self['SEED'] = seed or time.time()
|
||||
@ -44,10 +41,11 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
self.process(self.save_state())
|
||||
self.environment_agents = environment_agents or []
|
||||
self.network_agents = network_agents or []
|
||||
if self.dump:
|
||||
self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name))
|
||||
else:
|
||||
self.dir_path = dir_path
|
||||
if self.dry_run:
|
||||
self._db_path = ":memory:"
|
||||
else:
|
||||
self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name))
|
||||
self.create_db(self._db_path)
|
||||
|
||||
def create_db(self, db_path=None):
|
||||
@ -93,7 +91,7 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
for ix in self.G.nodes():
|
||||
i = ix
|
||||
node = self.G.node[i]
|
||||
agent, state = utils.agent_from_distribution(network_agents)
|
||||
agent, state = agents._agent_from_distribution(network_agents)
|
||||
self.set_agent(i, agent_type=agent, state=state)
|
||||
|
||||
def set_agent(self, agent_id, agent_type, state=None):
|
||||
@ -200,7 +198,7 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
return self[key] if key in self else default
|
||||
|
||||
def get_path(self, dir_path=None):
|
||||
dir_path = dir_path or self.sim().dir_path
|
||||
dir_path = dir_path or self.dir_path
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
return dir_path
|
||||
@ -227,6 +225,19 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
self.name+".gexf")
|
||||
nx.write_gexf(G, graph_path, version="1.2draft")
|
||||
|
||||
def dump(self, dir_path=None, formats=None):
|
||||
if not formats:
|
||||
return
|
||||
functions = {
|
||||
'csv': self.dump_csv,
|
||||
'gexf': self.dump_gexf
|
||||
}
|
||||
for f in formats:
|
||||
if f in functions:
|
||||
functions[f](dir_path)
|
||||
else:
|
||||
raise ValueError('Unknown format: {}'.format(f))
|
||||
|
||||
def state_to_tuples(self, now=None):
|
||||
if now is None:
|
||||
now = self.now
|
||||
@ -289,3 +300,23 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
G.add_node(agent.id, **attributes)
|
||||
|
||||
return G
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state['G'] = json_graph.node_link_data(self.G)
|
||||
state['network_agents'] = agents.serialize_distribution(self.network_agents)
|
||||
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||
to_string=True)
|
||||
del state['_queue']
|
||||
import inspect
|
||||
for k, v in state.items():
|
||||
if inspect.isgeneratorfunction(v):
|
||||
print(k, v, type(v))
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self.G = json_graph.node_link_graph(state['G'])
|
||||
self.network_agents = self.calculate_distribution(self._convert_agent_types(self.network_agents))
|
||||
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
||||
return state
|
||||
|
@ -5,14 +5,14 @@ import sys
|
||||
import yaml
|
||||
import networkx as nx
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
from copy import deepcopy
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
import pickle
|
||||
|
||||
from nxsim import NetworkSimulation
|
||||
|
||||
from . import agents, utils, environment, basestring
|
||||
from . import utils, environment, basestring, agents
|
||||
from .utils import logger
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ class SoilSimulation(NetworkSimulation):
|
||||
"""
|
||||
def __init__(self, name=None, topology=None, network_params=None,
|
||||
network_agents=None, agent_type=None, states=None,
|
||||
default_state=None, interval=1, dump=False,
|
||||
default_state=None, interval=1, dump=None, dry_run=False,
|
||||
dir_path=None, num_trials=1, max_time=100,
|
||||
agent_module=None, load_module=None, seed=None,
|
||||
environment_agents=None, environment_params=None):
|
||||
@ -57,7 +57,6 @@ class SoilSimulation(NetworkSimulation):
|
||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
||||
topology = json_graph.node_link_graph(topology)
|
||||
|
||||
|
||||
self.load_module = load_module
|
||||
self.topology = nx.Graph(topology)
|
||||
self.network_params = network_params
|
||||
@ -69,94 +68,64 @@ class SoilSimulation(NetworkSimulation):
|
||||
self.interval = interval
|
||||
self.seed = str(seed) or str(time.time())
|
||||
self.dump = dump
|
||||
self.dry_run = dry_run
|
||||
self.environment_params = environment_params or {}
|
||||
|
||||
if load_module:
|
||||
path = sys.path + [self.dir_path]
|
||||
path = sys.path + [self.dir_path, os.getcwd()]
|
||||
f, fp, desc = imp.find_module(load_module, path)
|
||||
imp.load_module('soil.agents.custom', f, fp, desc)
|
||||
|
||||
environment_agents = environment_agents or []
|
||||
self.environment_agents = self._convert_agent_types(environment_agents)
|
||||
self.environment_agents = agents._convert_agent_types(environment_agents)
|
||||
|
||||
distro = self.calculate_distribution(network_agents,
|
||||
agent_type)
|
||||
self.network_agents = self._convert_agent_types(distro)
|
||||
distro = agents.calculate_distribution(network_agents,
|
||||
agent_type)
|
||||
self.network_agents = agents._convert_agent_types(distro)
|
||||
|
||||
self.states = self.validate_states(states,
|
||||
self.topology)
|
||||
self.states = agents._validate_states(states,
|
||||
self.topology)
|
||||
|
||||
def calculate_distribution(self,
|
||||
network_agents=None,
|
||||
agent_type=None):
|
||||
if network_agents:
|
||||
network_agents = deepcopy(network_agents)
|
||||
elif agent_type:
|
||||
network_agents = [{'agent_type': agent_type}]
|
||||
else:
|
||||
return []
|
||||
def run_simulation(self, *args, **kwargs):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x.get('weight', 1) for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
upper = acc + (v.get('weight', 1)/total)
|
||||
v['threshold'] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
def run(self, *args, **kwargs):
|
||||
return list(self.run_simulation_gen(*args, **kwargs))
|
||||
|
||||
def serialize_distribution(self):
|
||||
d = self._convert_agent_types(self.network_agents,
|
||||
to_string=True)
|
||||
for v in d:
|
||||
if 'threshold' in v:
|
||||
del v['threshold']
|
||||
return d
|
||||
|
||||
def _convert_agent_types(self, ind, to_string=False):
|
||||
d = deepcopy(ind)
|
||||
for v in d:
|
||||
agent_type = v['agent_type']
|
||||
if to_string and not isinstance(agent_type, str):
|
||||
v['agent_type'] = str(agent_type.__name__)
|
||||
elif not to_string and isinstance(agent_type, str):
|
||||
v['agent_type'] = agents.agent_types[agent_type]
|
||||
return d
|
||||
|
||||
def validate_states(self, states, topology):
|
||||
states = states or []
|
||||
# Validate states to avoid ignoring states during
|
||||
# initialization
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
assert x in self.topology.node
|
||||
else:
|
||||
assert len(states) <= len(self.topology)
|
||||
return states
|
||||
|
||||
def run_simulation(self):
|
||||
return self.run()
|
||||
|
||||
def run(self):
|
||||
return list(self.run_simulation_gen())
|
||||
|
||||
def run_simulation_gen(self, *args, **kwargs):
|
||||
def run_simulation_gen(self, *args, parallel=False, **kwargs):
|
||||
p = Pool()
|
||||
with utils.timer('simulation'):
|
||||
for i in range(self.num_trials):
|
||||
res = self.run_trial(i)
|
||||
if self.dump:
|
||||
res.dump_gexf(self.dir_path)
|
||||
res.dump_csv(self.dir_path)
|
||||
yield res
|
||||
|
||||
if self.dump:
|
||||
if parallel:
|
||||
func = partial(self.run_trial, return_env=not parallel)
|
||||
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||
yield i
|
||||
else:
|
||||
for i in range(self.num_trials):
|
||||
yield self.run_trial(i)
|
||||
if not self.dry_run:
|
||||
logger.info('Dumping results to {}'.format(self.dir_path))
|
||||
self.dump_pickle(self.dir_path)
|
||||
self.dump_yaml(self.dir_path)
|
||||
else:
|
||||
logger.info('NOT dumping results')
|
||||
|
||||
def run_trial(self, trial_id=0, dump=False, dir_path=None):
|
||||
def get_env(self, trial_id=0, dump=False, dir_path=None):
|
||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||
env = environment.SoilEnvironment(name=env_name,
|
||||
topology=self.topology.copy(),
|
||||
seed=self.seed+env_name,
|
||||
initial_time=0,
|
||||
dry_run=self.dry_run,
|
||||
interval=self.interval,
|
||||
network_agents=self.network_agents,
|
||||
states=self.states,
|
||||
default_state=self.default_state,
|
||||
environment_agents=self.environment_agents,
|
||||
dir_path=dir_path or self.dir_path,
|
||||
**self.environment_params)
|
||||
return env
|
||||
|
||||
def run_trial(self, trial_id=0, dump=False, dir_path=None, until=None, return_env=False):
|
||||
"""Run a single trial of the simulation
|
||||
|
||||
Parameters
|
||||
@ -164,25 +133,16 @@ class SoilSimulation(NetworkSimulation):
|
||||
trial_id : int
|
||||
"""
|
||||
# Set-up trial environment and graph
|
||||
logger.info('Trial: {}'.format(trial_id))
|
||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||
env = environment.SoilEnvironment(name=env_name,
|
||||
topology=self.topology.copy(),
|
||||
seed=self.seed+env_name,
|
||||
initial_time=0,
|
||||
dump=self.dump,
|
||||
interval=self.interval,
|
||||
network_agents=self.network_agents,
|
||||
states=self.states,
|
||||
default_state=self.default_state,
|
||||
environment_agents=self.environment_agents,
|
||||
simulation=self,
|
||||
**self.environment_params)
|
||||
until = until or self.max_time
|
||||
env = self.get_env(trial_id=trial_id, dump=dump, dir_path=dir_path)
|
||||
# Set up agents on nodes
|
||||
logger.info('\tRunning')
|
||||
with utils.timer('trial'):
|
||||
env.run(until=self.max_time)
|
||||
return env
|
||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||
env.run(until)
|
||||
if self.dump and not self.dry_run:
|
||||
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
||||
env.dump(dir_path, formats=self.dump)
|
||||
if return_env:
|
||||
return env
|
||||
|
||||
def to_dict(self):
|
||||
return self.__getstate__()
|
||||
@ -213,16 +173,16 @@ class SoilSimulation(NetworkSimulation):
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state['topology'] = json_graph.node_link_data(self.topology)
|
||||
state['network_agents'] = self.serialize_distribution()
|
||||
state['environment_agents'] = self._convert_agent_types(self.environment_agents,
|
||||
to_string=True)
|
||||
state['network_agents'] = agents._serialize_distribution(self.network_agents)
|
||||
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||
to_string=True)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self.topology = json_graph.node_link_graph(state['topology'])
|
||||
self.network_agents = self._convert_agent_types(self.network_agents)
|
||||
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
||||
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
||||
self.environment_agents = agents._convert_agent_types(self.environment_agents)
|
||||
return state
|
||||
|
||||
|
||||
@ -235,21 +195,18 @@ def from_config(config, G=None):
|
||||
return sim
|
||||
|
||||
|
||||
def run_from_config(*configs, dump=True, results_dir=None, timestamp=False):
|
||||
if not results_dir:
|
||||
results_dir = 'soil_output'
|
||||
def run_from_config(*configs, results_dir='soil_output', dump=None, timestamp=False, **kwargs):
|
||||
for config_def in configs:
|
||||
for config, cpath in utils.load_config(config_def):
|
||||
name = config.get('name', 'unnamed')
|
||||
logger.info("Using config(s): {name}".format(name=name))
|
||||
|
||||
sim = SoilSimulation(**config)
|
||||
if timestamp:
|
||||
sim_folder = '{}_{}'.format(sim.name,
|
||||
sim_folder = '{}_{}'.format(name,
|
||||
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
||||
else:
|
||||
sim_folder = sim.name
|
||||
sim.dir_path = os.path.join(results_dir, sim_folder)
|
||||
sim.dump = dump
|
||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, dump))
|
||||
results = sim.run_simulation()
|
||||
sim_folder = name
|
||||
dir_path = os.path.join(results_dir, sim_folder)
|
||||
sim = SoilSimulation(dir_path=dir_path, dump=dump, **config)
|
||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, sim.dump))
|
||||
results = sim.run_simulation(**kwargs)
|
||||
|
@ -11,7 +11,7 @@ import networkx as nx
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger('soil')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@ -62,6 +62,7 @@ def load_config(config):
|
||||
@contextmanager
|
||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
start = time()
|
||||
function('{}Starting {} at {}.'.format(pre, name, start))
|
||||
yield start
|
||||
end = time()
|
||||
function('{}Finished {} in {} seconds'.format(pre, name, str(end-start)))
|
||||
@ -70,21 +71,6 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
to_object.end = end
|
||||
|
||||
|
||||
def agent_from_distribution(distribution, value=-1):
|
||||
"""Find the agent """
|
||||
if value < 0:
|
||||
value = random()
|
||||
for d in distribution:
|
||||
threshold = d['threshold']
|
||||
if value >= threshold[0] and value < threshold[1]:
|
||||
state = {}
|
||||
if 'state' in d:
|
||||
state = deepcopy(d['state'])
|
||||
return d['agent_type'], state
|
||||
|
||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||
|
||||
|
||||
def repr(v):
|
||||
if isinstance(v, bool):
|
||||
v = "true" if v else ""
|
||||
|
Loading…
Reference in New Issue
Block a user