mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
Parallelism and granular exporting options
* Graphs are not saved by default (not backwards compatible) * Modified newsspread examples * More granular options to save results (exporting to CSV and GEXF are now optional) * Updated tutorial to include exporting options * Removed references from environment to simulation * Added parallelism to simulations (can be turned off with a flag or argument).
This commit is contained in:
parent
a4b32afa2f
commit
7d1c800490
File diff suppressed because one or more lines are too long
@ -68,7 +68,7 @@ network_agents:
|
|||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: infected
|
id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
@ -95,7 +95,7 @@ network_agents:
|
|||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: infected
|
id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
@ -121,7 +121,7 @@ network_agents:
|
|||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: infected
|
id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from soil.agents import BaseAgent,FSM, state, default_state
|
from soil.agents import FSM, state, default_state, prob
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@ -10,70 +9,73 @@ class DumbViewer(FSM):
|
|||||||
'''
|
'''
|
||||||
defaults = {
|
defaults = {
|
||||||
'prob_neighbor_spread': 0.5,
|
'prob_neighbor_spread': 0.5,
|
||||||
'prob_neighbor_cure': 0.25,
|
'prob_tv_spread': 0.1,
|
||||||
}
|
}
|
||||||
|
|
||||||
@default_state
|
@default_state
|
||||||
@state
|
@state
|
||||||
def neutral(self):
|
def neutral(self):
|
||||||
r = random.random()
|
if self['has_tv']:
|
||||||
if self['has_tv'] and r < self.env['prob_tv_spread']:
|
if prob(self.env['prob_tv_spread']):
|
||||||
self.infect()
|
self.set_state(self.infected)
|
||||||
return
|
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
||||||
prob_infect = self.env['prob_neighbor_spread']
|
if prob(self.env['prob_neighbor_spread']):
|
||||||
r = random.random()
|
|
||||||
if r < prob_infect:
|
|
||||||
self.set_state(self.infected.id)
|
|
||||||
neighbor.infect()
|
neighbor.infect()
|
||||||
return
|
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
self.set_state(self.infected)
|
self.set_state(self.infected)
|
||||||
|
|
||||||
|
|
||||||
class HerdViewer(DumbViewer):
|
class HerdViewer(DumbViewer):
|
||||||
'''
|
'''
|
||||||
A viewer whose probability of infection depends on the state of its neighbors.
|
A viewer whose probability of infection depends on the state of its neighbors.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||||
total = self.count_neighboring_agents()
|
total = self.count_neighboring_agents()
|
||||||
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
||||||
self.debug('prob_infect', prob_infect)
|
self.debug('prob_infect', prob_infect)
|
||||||
r = random.random()
|
if prob(prob_infect):
|
||||||
if r < prob_infect:
|
|
||||||
self.set_state(self.infected.id)
|
self.set_state(self.infected.id)
|
||||||
|
|
||||||
|
|
||||||
class WiseViewer(HerdViewer):
|
class WiseViewer(HerdViewer):
|
||||||
'''
|
'''
|
||||||
A viewer that can change its mind.
|
A viewer that can change its mind.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
defaults = {
|
||||||
|
'prob_neighbor_spread': 0.5,
|
||||||
|
'prob_neighbor_cure': 0.25,
|
||||||
|
'prob_tv_spread': 0.1,
|
||||||
|
}
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def cured(self):
|
def cured(self):
|
||||||
prob_cure = self.env['prob_neighbor_cure']
|
prob_cure = self.env['prob_neighbor_cure']
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
||||||
r = random.random()
|
if prob(prob_cure):
|
||||||
if r < prob_cure:
|
|
||||||
try:
|
try:
|
||||||
neighbor.cure()
|
neighbor.cure()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self.debug('Viewer {} cannot be cured'.format(neighbor.id))
|
self.debug('Viewer {} cannot be cured'.format(neighbor.id))
|
||||||
return
|
|
||||||
|
|
||||||
def cure(self):
|
def cure(self):
|
||||||
self.set_state(self.cured.id)
|
self.set_state(self.cured.id)
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
prob_cure = self.env['prob_neighbor_cure']
|
cured = max(self.count_neighboring_agents(self.cured.id),
|
||||||
r = random.random()
|
1.0)
|
||||||
if r < prob_cure:
|
infected = max(self.count_neighboring_agents(self.infected.id),
|
||||||
self.cure()
|
1.0)
|
||||||
return
|
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
||||||
return super().infected()
|
if prob(prob_cure):
|
||||||
|
return self.cure()
|
||||||
|
return self.set_state(super().infected)
|
||||||
|
@ -35,8 +35,14 @@ def main():
|
|||||||
help='Do not store the results of the simulation.')
|
help='Do not store the results of the simulation.')
|
||||||
parser.add_argument('--pdb', action='store_true',
|
parser.add_argument('--pdb', action='store_true',
|
||||||
help='Use a pdb console in case of exception.')
|
help='Use a pdb console in case of exception.')
|
||||||
parser.add_argument('--output', '-o', type=str,
|
parser.add_argument('--graph', '-g', action='store_true',
|
||||||
|
help='Dump GEXF graph. Defaults to false.')
|
||||||
|
parser.add_argument('--csv', action='store_true',
|
||||||
|
help='Dump history in CSV format. Defaults to false.')
|
||||||
|
parser.add_argument('--output', '-o', type=str, default="soil_output",
|
||||||
help='folder to write results to. It defaults to the current directory.')
|
help='folder to write results to. It defaults to the current directory.')
|
||||||
|
parser.add_argument('--synchronous', action='store_true',
|
||||||
|
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@ -47,7 +53,17 @@ def main():
|
|||||||
logging.info('Loading config file: {}'.format(args.file, args.output))
|
logging.info('Loading config file: {}'.format(args.file, args.output))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
simulation.run_from_config(args.file, dump=(not args.dry_run), results_dir=args.output)
|
dump = []
|
||||||
|
if not args.dry_run:
|
||||||
|
if args.csv:
|
||||||
|
dump.append('csv')
|
||||||
|
if args.graph:
|
||||||
|
dump.append('gexf')
|
||||||
|
simulation.run_from_config(args.file,
|
||||||
|
dry_run=args.dry_run,
|
||||||
|
dump=dump,
|
||||||
|
parallel=(not args.synchronous),
|
||||||
|
results_dir=args.output)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
if args.pdb:
|
if args.pdb:
|
||||||
pdb.post_mortem()
|
pdb.post_mortem()
|
||||||
|
@ -15,4 +15,4 @@ class DrawingAgent(BaseAgent):
|
|||||||
# Outside effects
|
# Outside effects
|
||||||
f = plt.figure()
|
f = plt.figure()
|
||||||
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
||||||
f.savefig(os.path.join(self.env.sim().dir_path, "graph-"+str(self.env.now)+".png"))
|
f.savefig(os.path.join(self.env.get_path(), "graph-"+str(self.env.now)+".png"))
|
||||||
|
@ -12,9 +12,9 @@ from copy import deepcopy
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
|
from .. import utils
|
||||||
|
|
||||||
agent_types = {}
|
agent_types = {}
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
if not hasattr(self, 'level'):
|
if not hasattr(self, 'level'):
|
||||||
self.level = logging.DEBUG
|
self.level = logging.DEBUG
|
||||||
self.logger = logging.getLogger('Agent-{}'.format(self.id))
|
self.logger = logging.getLogger('{}-Agent-{}'.format(self.env.name, self.id))
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
|
|
||||||
|
|
||||||
@ -140,20 +140,24 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
|||||||
|
|
||||||
|
|
||||||
def state(func):
|
def state(func):
|
||||||
|
'''
|
||||||
|
A state function should return either a state id, or a tuple (state_id, when)
|
||||||
|
The default value for state_id is the current state id.
|
||||||
|
The default value for when is the interval defined in the nevironment.
|
||||||
|
'''
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def func_wrapper(self):
|
def func_wrapper(self):
|
||||||
when = None
|
|
||||||
next_state = func(self)
|
next_state = func(self)
|
||||||
|
when = None
|
||||||
|
if next_state is None:
|
||||||
|
return when
|
||||||
try:
|
try:
|
||||||
next_state, when = next_state
|
next_state, when = next_state
|
||||||
except TypeError:
|
except (ValueError, TypeError):
|
||||||
pass
|
pass
|
||||||
if next_state:
|
if next_state:
|
||||||
try:
|
self.set_state(next_state)
|
||||||
self.state['id'] = next_state.id
|
|
||||||
except AttributeError:
|
|
||||||
raise ValueError('State id %s is not valid.' % next_state)
|
|
||||||
return when
|
return when
|
||||||
|
|
||||||
func_wrapper.id = func.__name__
|
func_wrapper.id = func.__name__
|
||||||
@ -212,6 +216,116 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|||||||
if state not in self.states:
|
if state not in self.states:
|
||||||
raise ValueError('{} is not a valid state'.format(state))
|
raise ValueError('{} is not a valid state'.format(state))
|
||||||
self.state['id'] = state
|
self.state['id'] = state
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def prob(prob=1):
|
||||||
|
'''
|
||||||
|
A true/False uniform distribution with a given probability.
|
||||||
|
To be used like this:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
if prob(0.3):
|
||||||
|
do_something()
|
||||||
|
|
||||||
|
'''
|
||||||
|
r = random.random()
|
||||||
|
return r < prob
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_distribution(network_agents=None,
|
||||||
|
agent_type=None):
|
||||||
|
'''
|
||||||
|
Calculate the threshold values (thresholds for a uniform distribution)
|
||||||
|
of an agent distribution given the weights of each agent type.
|
||||||
|
|
||||||
|
The input has this form: ::
|
||||||
|
|
||||||
|
[
|
||||||
|
{'agent_type': 'agent_type_1',
|
||||||
|
'weight': 0.2,
|
||||||
|
'state': {
|
||||||
|
'id': 0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{'agent_type': 'agent_type_2',
|
||||||
|
'weight': 0.8,
|
||||||
|
'state': {
|
||||||
|
'id': 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
In this example, 20% of the nodes will be marked as type
|
||||||
|
'agent_type_1'.
|
||||||
|
'''
|
||||||
|
if network_agents:
|
||||||
|
network_agents = deepcopy(network_agents)
|
||||||
|
elif agent_type:
|
||||||
|
network_agents = [{'agent_type': agent_type}]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Calculate the thresholds
|
||||||
|
total = sum(x.get('weight', 1) for x in network_agents)
|
||||||
|
acc = 0
|
||||||
|
for v in network_agents:
|
||||||
|
upper = acc + (v.get('weight', 1)/total)
|
||||||
|
v['threshold'] = [acc, upper]
|
||||||
|
acc = upper
|
||||||
|
return network_agents
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_distribution(network_agents):
|
||||||
|
d = _convert_agent_types(network_agents,
|
||||||
|
to_string=True)
|
||||||
|
'''
|
||||||
|
When serializing an agent distribution, remove the thresholds, in order
|
||||||
|
to avoid cluttering the YAML definition file.
|
||||||
|
'''
|
||||||
|
for v in d:
|
||||||
|
if 'threshold' in v:
|
||||||
|
del v['threshold']
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_states(states, topology):
|
||||||
|
'''Validate states to avoid ignoring states during initialization'''
|
||||||
|
states = states or []
|
||||||
|
if isinstance(states, dict):
|
||||||
|
for x in states:
|
||||||
|
assert x in topology.node
|
||||||
|
else:
|
||||||
|
assert len(states) <= len(topology)
|
||||||
|
return states
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_agent_types(ind, to_string=False):
|
||||||
|
'''Convenience method to allow specifying agents by class or class name.'''
|
||||||
|
d = deepcopy(ind)
|
||||||
|
for v in d:
|
||||||
|
agent_type = v['agent_type']
|
||||||
|
if to_string and not isinstance(agent_type, str):
|
||||||
|
v['agent_type'] = str(agent_type.__name__)
|
||||||
|
elif not to_string and isinstance(agent_type, str):
|
||||||
|
v['agent_type'] = agent_types[agent_type]
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def _agent_from_distribution(distribution, value=-1):
|
||||||
|
"""Used in the initialization of agents given an agent distribution."""
|
||||||
|
if value < 0:
|
||||||
|
value = random.random()
|
||||||
|
for d in distribution:
|
||||||
|
threshold = d['threshold']
|
||||||
|
if value >= threshold[0] and value < threshold[1]:
|
||||||
|
state = {}
|
||||||
|
if 'state' in d:
|
||||||
|
state = deepcopy(d['state'])
|
||||||
|
return d['agent_type'], state
|
||||||
|
|
||||||
|
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||||
|
|
||||||
|
|
||||||
from .BassModel import *
|
from .BassModel import *
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
import time
|
||||||
import weakref
|
|
||||||
import csv
|
import csv
|
||||||
import random
|
import random
|
||||||
import simpy
|
import simpy
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from networkx.readwrite import json_graph
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import nxsim
|
import nxsim
|
||||||
|
|
||||||
from . import utils
|
from . import utils, agents
|
||||||
|
|
||||||
|
|
||||||
class SoilEnvironment(nxsim.NetworkEnvironment):
|
class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||||
@ -22,21 +22,18 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
default_state=None,
|
default_state=None,
|
||||||
interval=1,
|
interval=1,
|
||||||
seed=None,
|
seed=None,
|
||||||
dump=False,
|
dry_run=False,
|
||||||
simulation=None,
|
dir_path=None,
|
||||||
*args, **kwargs):
|
*args, **kwargs):
|
||||||
self.name = name or 'UnnamedEnvironment'
|
self.name = name or 'UnnamedEnvironment'
|
||||||
if isinstance(states, list):
|
if isinstance(states, list):
|
||||||
states = dict(enumerate(states))
|
states = dict(enumerate(states))
|
||||||
self.states = deepcopy(states) if states else {}
|
self.states = deepcopy(states) if states else {}
|
||||||
self.default_state = deepcopy(default_state) or {}
|
self.default_state = deepcopy(default_state) or {}
|
||||||
self.sim = weakref.ref(simulation)
|
|
||||||
if 'topology' not in kwargs and simulation:
|
|
||||||
kwargs['topology'] = self.sim().topology.copy()
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._env_agents = {}
|
self._env_agents = {}
|
||||||
|
self.dry_run = dry_run
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.dump = dump
|
|
||||||
# Add environment agents first, so their events get
|
# Add environment agents first, so their events get
|
||||||
# executed before network agents
|
# executed before network agents
|
||||||
self['SEED'] = seed or time.time()
|
self['SEED'] = seed or time.time()
|
||||||
@ -44,10 +41,11 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
self.process(self.save_state())
|
self.process(self.save_state())
|
||||||
self.environment_agents = environment_agents or []
|
self.environment_agents = environment_agents or []
|
||||||
self.network_agents = network_agents or []
|
self.network_agents = network_agents or []
|
||||||
if self.dump:
|
self.dir_path = dir_path
|
||||||
self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name))
|
if self.dry_run:
|
||||||
else:
|
|
||||||
self._db_path = ":memory:"
|
self._db_path = ":memory:"
|
||||||
|
else:
|
||||||
|
self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name))
|
||||||
self.create_db(self._db_path)
|
self.create_db(self._db_path)
|
||||||
|
|
||||||
def create_db(self, db_path=None):
|
def create_db(self, db_path=None):
|
||||||
@ -93,7 +91,7 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
for ix in self.G.nodes():
|
for ix in self.G.nodes():
|
||||||
i = ix
|
i = ix
|
||||||
node = self.G.node[i]
|
node = self.G.node[i]
|
||||||
agent, state = utils.agent_from_distribution(network_agents)
|
agent, state = agents._agent_from_distribution(network_agents)
|
||||||
self.set_agent(i, agent_type=agent, state=state)
|
self.set_agent(i, agent_type=agent, state=state)
|
||||||
|
|
||||||
def set_agent(self, agent_id, agent_type, state=None):
|
def set_agent(self, agent_id, agent_type, state=None):
|
||||||
@ -200,7 +198,7 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
return self[key] if key in self else default
|
return self[key] if key in self else default
|
||||||
|
|
||||||
def get_path(self, dir_path=None):
|
def get_path(self, dir_path=None):
|
||||||
dir_path = dir_path or self.sim().dir_path
|
dir_path = dir_path or self.dir_path
|
||||||
if not os.path.exists(dir_path):
|
if not os.path.exists(dir_path):
|
||||||
os.makedirs(dir_path)
|
os.makedirs(dir_path)
|
||||||
return dir_path
|
return dir_path
|
||||||
@ -227,6 +225,19 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
self.name+".gexf")
|
self.name+".gexf")
|
||||||
nx.write_gexf(G, graph_path, version="1.2draft")
|
nx.write_gexf(G, graph_path, version="1.2draft")
|
||||||
|
|
||||||
|
def dump(self, dir_path=None, formats=None):
|
||||||
|
if not formats:
|
||||||
|
return
|
||||||
|
functions = {
|
||||||
|
'csv': self.dump_csv,
|
||||||
|
'gexf': self.dump_gexf
|
||||||
|
}
|
||||||
|
for f in formats:
|
||||||
|
if f in functions:
|
||||||
|
functions[f](dir_path)
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown format: {}'.format(f))
|
||||||
|
|
||||||
def state_to_tuples(self, now=None):
|
def state_to_tuples(self, now=None):
|
||||||
if now is None:
|
if now is None:
|
||||||
now = self.now
|
now = self.now
|
||||||
@ -289,3 +300,23 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
G.add_node(agent.id, **attributes)
|
G.add_node(agent.id, **attributes)
|
||||||
|
|
||||||
return G
|
return G
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state['G'] = json_graph.node_link_data(self.G)
|
||||||
|
state['network_agents'] = agents.serialize_distribution(self.network_agents)
|
||||||
|
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||||
|
to_string=True)
|
||||||
|
del state['_queue']
|
||||||
|
import inspect
|
||||||
|
for k, v in state.items():
|
||||||
|
if inspect.isgeneratorfunction(v):
|
||||||
|
print(k, v, type(v))
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__ = state
|
||||||
|
self.G = json_graph.node_link_graph(state['G'])
|
||||||
|
self.network_agents = self.calculate_distribution(self._convert_agent_types(self.network_agents))
|
||||||
|
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
||||||
|
return state
|
||||||
|
@ -5,14 +5,14 @@ import sys
|
|||||||
import yaml
|
import yaml
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
|
from multiprocessing import Pool
|
||||||
from copy import deepcopy
|
from functools import partial
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
from nxsim import NetworkSimulation
|
from nxsim import NetworkSimulation
|
||||||
|
|
||||||
from . import agents, utils, environment, basestring
|
from . import utils, environment, basestring, agents
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, name=None, topology=None, network_params=None,
|
def __init__(self, name=None, topology=None, network_params=None,
|
||||||
network_agents=None, agent_type=None, states=None,
|
network_agents=None, agent_type=None, states=None,
|
||||||
default_state=None, interval=1, dump=False,
|
default_state=None, interval=1, dump=None, dry_run=False,
|
||||||
dir_path=None, num_trials=1, max_time=100,
|
dir_path=None, num_trials=1, max_time=100,
|
||||||
agent_module=None, load_module=None, seed=None,
|
agent_module=None, load_module=None, seed=None,
|
||||||
environment_agents=None, environment_params=None):
|
environment_agents=None, environment_params=None):
|
||||||
@ -57,7 +57,6 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
||||||
topology = json_graph.node_link_graph(topology)
|
topology = json_graph.node_link_graph(topology)
|
||||||
|
|
||||||
|
|
||||||
self.load_module = load_module
|
self.load_module = load_module
|
||||||
self.topology = nx.Graph(topology)
|
self.topology = nx.Graph(topology)
|
||||||
self.network_params = network_params
|
self.network_params = network_params
|
||||||
@ -69,94 +68,64 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.seed = str(seed) or str(time.time())
|
self.seed = str(seed) or str(time.time())
|
||||||
self.dump = dump
|
self.dump = dump
|
||||||
|
self.dry_run = dry_run
|
||||||
self.environment_params = environment_params or {}
|
self.environment_params = environment_params or {}
|
||||||
|
|
||||||
if load_module:
|
if load_module:
|
||||||
path = sys.path + [self.dir_path]
|
path = sys.path + [self.dir_path, os.getcwd()]
|
||||||
f, fp, desc = imp.find_module(load_module, path)
|
f, fp, desc = imp.find_module(load_module, path)
|
||||||
imp.load_module('soil.agents.custom', f, fp, desc)
|
imp.load_module('soil.agents.custom', f, fp, desc)
|
||||||
|
|
||||||
environment_agents = environment_agents or []
|
environment_agents = environment_agents or []
|
||||||
self.environment_agents = self._convert_agent_types(environment_agents)
|
self.environment_agents = agents._convert_agent_types(environment_agents)
|
||||||
|
|
||||||
distro = self.calculate_distribution(network_agents,
|
distro = agents.calculate_distribution(network_agents,
|
||||||
agent_type)
|
agent_type)
|
||||||
self.network_agents = self._convert_agent_types(distro)
|
self.network_agents = agents._convert_agent_types(distro)
|
||||||
|
|
||||||
self.states = self.validate_states(states,
|
self.states = agents._validate_states(states,
|
||||||
self.topology)
|
self.topology)
|
||||||
|
|
||||||
def calculate_distribution(self,
|
def run_simulation(self, *args, **kwargs):
|
||||||
network_agents=None,
|
return self.run(*args, **kwargs)
|
||||||
agent_type=None):
|
|
||||||
if network_agents:
|
|
||||||
network_agents = deepcopy(network_agents)
|
|
||||||
elif agent_type:
|
|
||||||
network_agents = [{'agent_type': agent_type}]
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Calculate the thresholds
|
def run(self, *args, **kwargs):
|
||||||
total = sum(x.get('weight', 1) for x in network_agents)
|
return list(self.run_simulation_gen(*args, **kwargs))
|
||||||
acc = 0
|
|
||||||
for v in network_agents:
|
|
||||||
upper = acc + (v.get('weight', 1)/total)
|
|
||||||
v['threshold'] = [acc, upper]
|
|
||||||
acc = upper
|
|
||||||
return network_agents
|
|
||||||
|
|
||||||
def serialize_distribution(self):
|
def run_simulation_gen(self, *args, parallel=False, **kwargs):
|
||||||
d = self._convert_agent_types(self.network_agents,
|
p = Pool()
|
||||||
to_string=True)
|
|
||||||
for v in d:
|
|
||||||
if 'threshold' in v:
|
|
||||||
del v['threshold']
|
|
||||||
return d
|
|
||||||
|
|
||||||
def _convert_agent_types(self, ind, to_string=False):
|
|
||||||
d = deepcopy(ind)
|
|
||||||
for v in d:
|
|
||||||
agent_type = v['agent_type']
|
|
||||||
if to_string and not isinstance(agent_type, str):
|
|
||||||
v['agent_type'] = str(agent_type.__name__)
|
|
||||||
elif not to_string and isinstance(agent_type, str):
|
|
||||||
v['agent_type'] = agents.agent_types[agent_type]
|
|
||||||
return d
|
|
||||||
|
|
||||||
def validate_states(self, states, topology):
|
|
||||||
states = states or []
|
|
||||||
# Validate states to avoid ignoring states during
|
|
||||||
# initialization
|
|
||||||
if isinstance(states, dict):
|
|
||||||
for x in states:
|
|
||||||
assert x in self.topology.node
|
|
||||||
else:
|
|
||||||
assert len(states) <= len(self.topology)
|
|
||||||
return states
|
|
||||||
|
|
||||||
def run_simulation(self):
|
|
||||||
return self.run()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
return list(self.run_simulation_gen())
|
|
||||||
|
|
||||||
def run_simulation_gen(self, *args, **kwargs):
|
|
||||||
with utils.timer('simulation'):
|
with utils.timer('simulation'):
|
||||||
for i in range(self.num_trials):
|
if parallel:
|
||||||
res = self.run_trial(i)
|
func = partial(self.run_trial, return_env=not parallel)
|
||||||
if self.dump:
|
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||||
res.dump_gexf(self.dir_path)
|
yield i
|
||||||
res.dump_csv(self.dir_path)
|
else:
|
||||||
yield res
|
for i in range(self.num_trials):
|
||||||
|
yield self.run_trial(i)
|
||||||
if self.dump:
|
if not self.dry_run:
|
||||||
logger.info('Dumping results to {}'.format(self.dir_path))
|
logger.info('Dumping results to {}'.format(self.dir_path))
|
||||||
self.dump_pickle(self.dir_path)
|
self.dump_pickle(self.dir_path)
|
||||||
self.dump_yaml(self.dir_path)
|
self.dump_yaml(self.dir_path)
|
||||||
else:
|
else:
|
||||||
logger.info('NOT dumping results')
|
logger.info('NOT dumping results')
|
||||||
|
|
||||||
def run_trial(self, trial_id=0, dump=False, dir_path=None):
|
def get_env(self, trial_id=0, dump=False, dir_path=None):
|
||||||
|
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||||
|
env = environment.SoilEnvironment(name=env_name,
|
||||||
|
topology=self.topology.copy(),
|
||||||
|
seed=self.seed+env_name,
|
||||||
|
initial_time=0,
|
||||||
|
dry_run=self.dry_run,
|
||||||
|
interval=self.interval,
|
||||||
|
network_agents=self.network_agents,
|
||||||
|
states=self.states,
|
||||||
|
default_state=self.default_state,
|
||||||
|
environment_agents=self.environment_agents,
|
||||||
|
dir_path=dir_path or self.dir_path,
|
||||||
|
**self.environment_params)
|
||||||
|
return env
|
||||||
|
|
||||||
|
def run_trial(self, trial_id=0, dump=False, dir_path=None, until=None, return_env=False):
|
||||||
"""Run a single trial of the simulation
|
"""Run a single trial of the simulation
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -164,25 +133,16 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
trial_id : int
|
trial_id : int
|
||||||
"""
|
"""
|
||||||
# Set-up trial environment and graph
|
# Set-up trial environment and graph
|
||||||
logger.info('Trial: {}'.format(trial_id))
|
until = until or self.max_time
|
||||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
env = self.get_env(trial_id=trial_id, dump=dump, dir_path=dir_path)
|
||||||
env = environment.SoilEnvironment(name=env_name,
|
|
||||||
topology=self.topology.copy(),
|
|
||||||
seed=self.seed+env_name,
|
|
||||||
initial_time=0,
|
|
||||||
dump=self.dump,
|
|
||||||
interval=self.interval,
|
|
||||||
network_agents=self.network_agents,
|
|
||||||
states=self.states,
|
|
||||||
default_state=self.default_state,
|
|
||||||
environment_agents=self.environment_agents,
|
|
||||||
simulation=self,
|
|
||||||
**self.environment_params)
|
|
||||||
# Set up agents on nodes
|
# Set up agents on nodes
|
||||||
logger.info('\tRunning')
|
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
with utils.timer('trial'):
|
env.run(until)
|
||||||
env.run(until=self.max_time)
|
if self.dump and not self.dry_run:
|
||||||
return env
|
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
|
env.dump(dir_path, formats=self.dump)
|
||||||
|
if return_env:
|
||||||
|
return env
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return self.__getstate__()
|
return self.__getstate__()
|
||||||
@ -213,16 +173,16 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = self.__dict__.copy()
|
state = self.__dict__.copy()
|
||||||
state['topology'] = json_graph.node_link_data(self.topology)
|
state['topology'] = json_graph.node_link_data(self.topology)
|
||||||
state['network_agents'] = self.serialize_distribution()
|
state['network_agents'] = agents._serialize_distribution(self.network_agents)
|
||||||
state['environment_agents'] = self._convert_agent_types(self.environment_agents,
|
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||||
to_string=True)
|
to_string=True)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
self.__dict__ = state
|
self.__dict__ = state
|
||||||
self.topology = json_graph.node_link_graph(state['topology'])
|
self.topology = json_graph.node_link_graph(state['topology'])
|
||||||
self.network_agents = self._convert_agent_types(self.network_agents)
|
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
||||||
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
self.environment_agents = agents._convert_agent_types(self.environment_agents)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
@ -235,21 +195,18 @@ def from_config(config, G=None):
|
|||||||
return sim
|
return sim
|
||||||
|
|
||||||
|
|
||||||
def run_from_config(*configs, dump=True, results_dir=None, timestamp=False):
|
def run_from_config(*configs, results_dir='soil_output', dump=None, timestamp=False, **kwargs):
|
||||||
if not results_dir:
|
|
||||||
results_dir = 'soil_output'
|
|
||||||
for config_def in configs:
|
for config_def in configs:
|
||||||
for config, cpath in utils.load_config(config_def):
|
for config, cpath in utils.load_config(config_def):
|
||||||
name = config.get('name', 'unnamed')
|
name = config.get('name', 'unnamed')
|
||||||
logger.info("Using config(s): {name}".format(name=name))
|
logger.info("Using config(s): {name}".format(name=name))
|
||||||
|
|
||||||
sim = SoilSimulation(**config)
|
|
||||||
if timestamp:
|
if timestamp:
|
||||||
sim_folder = '{}_{}'.format(sim.name,
|
sim_folder = '{}_{}'.format(name,
|
||||||
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
||||||
else:
|
else:
|
||||||
sim_folder = sim.name
|
sim_folder = name
|
||||||
sim.dir_path = os.path.join(results_dir, sim_folder)
|
dir_path = os.path.join(results_dir, sim_folder)
|
||||||
sim.dump = dump
|
sim = SoilSimulation(dir_path=dir_path, dump=dump, **config)
|
||||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, dump))
|
logger.info('Dumping results to {} : {}'.format(sim.dir_path, sim.dump))
|
||||||
results = sim.run_simulation()
|
results = sim.run_simulation(**kwargs)
|
||||||
|
@ -11,7 +11,7 @@ import networkx as nx
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger('soil')
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@ -62,6 +62,7 @@ def load_config(config):
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||||
start = time()
|
start = time()
|
||||||
|
function('{}Starting {} at {}.'.format(pre, name, start))
|
||||||
yield start
|
yield start
|
||||||
end = time()
|
end = time()
|
||||||
function('{}Finished {} in {} seconds'.format(pre, name, str(end-start)))
|
function('{}Finished {} in {} seconds'.format(pre, name, str(end-start)))
|
||||||
@ -70,21 +71,6 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
|||||||
to_object.end = end
|
to_object.end = end
|
||||||
|
|
||||||
|
|
||||||
def agent_from_distribution(distribution, value=-1):
|
|
||||||
"""Find the agent """
|
|
||||||
if value < 0:
|
|
||||||
value = random()
|
|
||||||
for d in distribution:
|
|
||||||
threshold = d['threshold']
|
|
||||||
if value >= threshold[0] and value < threshold[1]:
|
|
||||||
state = {}
|
|
||||||
if 'state' in d:
|
|
||||||
state = deepcopy(d['state'])
|
|
||||||
return d['agent_type'], state
|
|
||||||
|
|
||||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
|
||||||
|
|
||||||
|
|
||||||
def repr(v):
|
def repr(v):
|
||||||
if isinstance(v, bool):
|
if isinstance(v, bool):
|
||||||
v = "true" if v else ""
|
v = "true" if v else ""
|
||||||
|
Loading…
Reference in New Issue
Block a user