mirror of
https://github.com/gsi-upm/soil
synced 2025-09-13 19:52:20 +00:00
Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
5d89827ccf | ||
|
fc48ed7e09 | ||
|
73c90887e8 | ||
|
497c8a55db | ||
|
7d1c800490 | ||
|
a4b32afa2f |
8
docker-compose.yml
Normal file
8
docker-compose.yml
Normal file
@@ -0,0 +1,8 @@
|
||||
version: '3'
|
||||
services:
|
||||
dev:
|
||||
build: .
|
||||
volumes:
|
||||
- .:/usr/src/app
|
||||
tty: true
|
||||
entrypoint: /bin/bash
|
334
examples/NewsSpread.ipynb
Normal file
334
examples/NewsSpread.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -68,7 +68,7 @@ network_agents:
|
||||
- agent_type: HerdViewer
|
||||
state:
|
||||
has_tv: true
|
||||
id: infected
|
||||
id: neutral
|
||||
weight: 1
|
||||
- agent_type: HerdViewer
|
||||
state:
|
||||
@@ -95,7 +95,7 @@ network_agents:
|
||||
- agent_type: HerdViewer
|
||||
state:
|
||||
has_tv: true
|
||||
id: infected
|
||||
id: neutral
|
||||
weight: 1
|
||||
- agent_type: WiseViewer
|
||||
state:
|
||||
@@ -121,7 +121,7 @@ network_agents:
|
||||
- agent_type: WiseViewer
|
||||
state:
|
||||
has_tv: true
|
||||
id: infected
|
||||
id: neutral
|
||||
weight: 1
|
||||
- agent_type: WiseViewer
|
||||
state:
|
||||
|
@@ -1,5 +1,4 @@
|
||||
from soil.agents import BaseAgent,FSM, state, default_state
|
||||
import random
|
||||
from soil.agents import FSM, state, default_state, prob
|
||||
import logging
|
||||
|
||||
|
||||
@@ -10,70 +9,73 @@ class DumbViewer(FSM):
|
||||
'''
|
||||
defaults = {
|
||||
'prob_neighbor_spread': 0.5,
|
||||
'prob_neighbor_cure': 0.25,
|
||||
'prob_tv_spread': 0.1,
|
||||
}
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def neutral(self):
|
||||
r = random.random()
|
||||
if self['has_tv'] and r < self.env['prob_tv_spread']:
|
||||
self.infect()
|
||||
return
|
||||
if self['has_tv']:
|
||||
if prob(self.env['prob_tv_spread']):
|
||||
self.set_state(self.infected)
|
||||
|
||||
@state
|
||||
def infected(self):
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
||||
prob_infect = self.env['prob_neighbor_spread']
|
||||
r = random.random()
|
||||
if r < prob_infect:
|
||||
self.set_state(self.infected.id)
|
||||
if prob(self.env['prob_neighbor_spread']):
|
||||
neighbor.infect()
|
||||
return
|
||||
|
||||
def infect(self):
|
||||
self.set_state(self.infected)
|
||||
|
||||
|
||||
class HerdViewer(DumbViewer):
|
||||
'''
|
||||
A viewer whose probability of infection depends on the state of its neighbors.
|
||||
'''
|
||||
|
||||
level = logging.DEBUG
|
||||
|
||||
|
||||
def infect(self):
|
||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||
total = self.count_neighboring_agents()
|
||||
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
||||
self.debug('prob_infect', prob_infect)
|
||||
r = random.random()
|
||||
if r < prob_infect:
|
||||
if prob(prob_infect):
|
||||
self.set_state(self.infected.id)
|
||||
|
||||
|
||||
class WiseViewer(HerdViewer):
|
||||
'''
|
||||
A viewer that can change its mind.
|
||||
'''
|
||||
|
||||
defaults = {
|
||||
'prob_neighbor_spread': 0.5,
|
||||
'prob_neighbor_cure': 0.25,
|
||||
'prob_tv_spread': 0.1,
|
||||
}
|
||||
|
||||
@state
|
||||
def cured(self):
|
||||
prob_cure = self.env['prob_neighbor_cure']
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
||||
r = random.random()
|
||||
if r < prob_cure:
|
||||
if prob(prob_cure):
|
||||
try:
|
||||
neighbor.cure()
|
||||
except AttributeError:
|
||||
self.debug('Viewer {} cannot be cured'.format(neighbor.id))
|
||||
return
|
||||
|
||||
def cure(self):
|
||||
self.set_state(self.cured.id)
|
||||
|
||||
@state
|
||||
def infected(self):
|
||||
prob_cure = self.env['prob_neighbor_cure']
|
||||
r = random.random()
|
||||
if r < prob_cure:
|
||||
self.cure()
|
||||
return
|
||||
return super().infected()
|
||||
cured = max(self.count_neighboring_agents(self.cured.id),
|
||||
1.0)
|
||||
infected = max(self.count_neighboring_agents(self.infected.id),
|
||||
1.0)
|
||||
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
||||
if prob(prob_cure):
|
||||
return self.cure()
|
||||
return self.set_state(super().infected)
|
||||
|
File diff suppressed because one or more lines are too long
@@ -1,6 +1,7 @@
|
||||
nxsim
|
||||
simpy
|
||||
networkx
|
||||
networkx>=2.0
|
||||
numpy
|
||||
matplotlib
|
||||
pyyaml
|
||||
pandas
|
||||
|
29
setup.py
29
setup.py
@@ -1,20 +1,21 @@
|
||||
import pip
|
||||
import os
|
||||
from setuptools import setup
|
||||
# parse_requirements() returns generator of pip.req.InstallRequirement objects
|
||||
from pip.req import parse_requirements
|
||||
from soil import __version__
|
||||
|
||||
try:
|
||||
install_reqs = parse_requirements(
|
||||
"requirements.txt", session=pip.download.PipSession())
|
||||
test_reqs = parse_requirements(
|
||||
"test-requirements.txt", session=pip.download.PipSession())
|
||||
except AttributeError:
|
||||
install_reqs = parse_requirements("requirements.txt")
|
||||
test_reqs = parse_requirements("test-requirements.txt")
|
||||
|
||||
install_reqs = [str(ir.req) for ir in install_reqs]
|
||||
test_reqs = [str(ir.req) for ir in test_reqs]
|
||||
with open(os.path.join('soil', 'VERSION')) as f:
|
||||
__version__ = f.readlines()[0].strip()
|
||||
assert __version__
|
||||
|
||||
|
||||
def parse_requirements(filename):
|
||||
""" load requirements from a pip requirements file """
|
||||
with open(filename, 'r') as f:
|
||||
lineiter = list(line.strip() for line in f)
|
||||
return [line for line in lineiter if line and not line.startswith("#")]
|
||||
|
||||
|
||||
install_reqs = parse_requirements("requirements.txt")
|
||||
test_reqs = parse_requirements("test-requirements.txt")
|
||||
|
||||
|
||||
setup(
|
||||
|
@@ -1,61 +0,0 @@
|
||||
---
|
||||
name: ControlModelM2_sim
|
||||
max_time: 50
|
||||
num_trials: 1
|
||||
network_params:
|
||||
generator: barabasi_albert_graph
|
||||
n: 100
|
||||
m: 2
|
||||
network_agents:
|
||||
- agent_type: ControlModelM2
|
||||
weight: 0.1
|
||||
state:
|
||||
id: 1
|
||||
- agent_type: ControlModelM2
|
||||
weight: 0.9
|
||||
state:
|
||||
id: 0
|
||||
environment_params:
|
||||
prob_neutral_making_denier: 0.035
|
||||
prob_infect: 0.075
|
||||
prob_cured_healing_infected: 0.035
|
||||
prob_cured_vaccinate_neutral: 0.035
|
||||
prob_vaccinated_healing_infected: 0.035
|
||||
prob_vaccinated_vaccinate_neutral: 0.035
|
||||
prob_generate_anti_rumor: 0.035
|
||||
standard_variance: 0.055
|
||||
---
|
||||
name: SISA_sm
|
||||
max_time: 50
|
||||
num_trials: 2
|
||||
network_params:
|
||||
generator: erdos_renyi_graph
|
||||
n: 1000
|
||||
p: 0.05
|
||||
#other_agents:
|
||||
# - agent_type: DrawingAgent
|
||||
network_agents:
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: content
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: neutral
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: discontent
|
||||
environment_params:
|
||||
neutral_discontent_spon_prob: 0.04
|
||||
neutral_discontent_infected_prob: 0.04
|
||||
neutral_content_spon_prob: 0.18
|
||||
neutral_content_infected_prob: 0.02
|
||||
discontent_neutral: 0.13
|
||||
discontent_content: 0.07
|
||||
variance_d_c: 0.02
|
||||
content_discontent: 0.009
|
||||
variance_c_d: 0.003
|
||||
content_neutral: 0.088
|
||||
standard_variance: 0.055
|
1
soil/VERSION
Normal file
1
soil/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
0.11.1
|
@@ -4,7 +4,7 @@ import os
|
||||
import pdb
|
||||
import logging
|
||||
|
||||
__version__ = "0.10"
|
||||
from .version import __version__
|
||||
|
||||
try:
|
||||
basestring
|
||||
@@ -35,8 +35,14 @@ def main():
|
||||
help='Do not store the results of the simulation.')
|
||||
parser.add_argument('--pdb', action='store_true',
|
||||
help='Use a pdb console in case of exception.')
|
||||
parser.add_argument('--output', '-o', type=str,
|
||||
parser.add_argument('--graph', '-g', action='store_true',
|
||||
help='Dump GEXF graph. Defaults to false.')
|
||||
parser.add_argument('--csv', action='store_true',
|
||||
help='Dump history in CSV format. Defaults to false.')
|
||||
parser.add_argument('--output', '-o', type=str, default="soil_output",
|
||||
help='folder to write results to. It defaults to the current directory.')
|
||||
parser.add_argument('--synchronous', action='store_true',
|
||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -47,7 +53,17 @@ def main():
|
||||
logging.info('Loading config file: {}'.format(args.file, args.output))
|
||||
|
||||
try:
|
||||
simulation.run_from_config(args.file, dump=(not args.dry_run), results_dir=args.output)
|
||||
dump = []
|
||||
if not args.dry_run:
|
||||
if args.csv:
|
||||
dump.append('csv')
|
||||
if args.graph:
|
||||
dump.append('gexf')
|
||||
simulation.run_from_config(args.file,
|
||||
dry_run=args.dry_run,
|
||||
dump=dump,
|
||||
parallel=(not args.synchronous and not args.pdb),
|
||||
results_dir=args.output)
|
||||
except Exception as ex:
|
||||
if args.pdb:
|
||||
pdb.post_mortem()
|
||||
|
@@ -1,36 +1,4 @@
|
||||
import importlib
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import logging
|
||||
from . import simulation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
||||
parser.add_argument('file', type=str,
|
||||
nargs="?",
|
||||
default='simulation.yml',
|
||||
help='python module containing the simulation configuration.')
|
||||
parser.add_argument('--module', '-m', type=str,
|
||||
help='file containing the code of any custom agents.')
|
||||
parser.add_argument('--dry-run', '--dry', action='store_true',
|
||||
help='Do not store the results of the simulation.')
|
||||
parser.add_argument('--output', '-o', type=str,
|
||||
help='folder to write results to. It defaults to the current directory.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.module:
|
||||
sys.path.append(os.getcwd())
|
||||
importlib.import_module(args.module)
|
||||
|
||||
print('Loading config file: {}'.format(args.file, args.output))
|
||||
simulation.run_from_config(args.file, dump=not args.dry_run, results_dir=args.output)
|
||||
|
||||
from . import main
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@@ -11,9 +11,9 @@ class CounterModel(BaseAgent):
|
||||
# Outside effects
|
||||
total = len(list(self.get_all_agents()))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self.state['times'] = self.state.get('times', 0) + 1
|
||||
self.state['neighbors'] = neighbors
|
||||
self.state['total'] = total
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = neighbors
|
||||
self['total'] = total
|
||||
|
||||
|
||||
class AggregatedCounter(BaseAgent):
|
||||
@@ -26,7 +26,7 @@ class AggregatedCounter(BaseAgent):
|
||||
# Outside effects
|
||||
total = len(list(self.get_all_agents()))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self.state['times'] = self.state.get('times', 0) + 1
|
||||
self.state['neighbors'] = self.state.get('neighbors', 0) + neighbors
|
||||
self.state['total'] = total = self.state.get('total', 0) + total
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = self.get('neighbors', 0) + neighbors
|
||||
self['total'] = total = self.get('total', 0) + total
|
||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||
|
@@ -15,4 +15,4 @@ class DrawingAgent(BaseAgent):
|
||||
# Outside effects
|
||||
f = plt.figure()
|
||||
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
||||
f.savefig(os.path.join(self.env.sim().dir_path, "graph-"+str(self.env.now)+".png"))
|
||||
f.savefig(os.path.join(self.env.get_path(), "graph-"+str(self.env.now)+".png"))
|
||||
|
@@ -12,9 +12,9 @@ from copy import deepcopy
|
||||
from functools import partial
|
||||
import json
|
||||
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from .. import utils, history
|
||||
|
||||
agent_types = {}
|
||||
|
||||
@@ -32,33 +32,67 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
||||
|
||||
defaults = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, environment=None, agent_id=None, state=None,
|
||||
name='network_process', interval=None, **state_params):
|
||||
# Check for REQUIRED arguments
|
||||
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
|
||||
'Cannot be NoneType.')
|
||||
# Initialize agent parameters
|
||||
self.id = agent_id
|
||||
self.name = name
|
||||
self.state_params = state_params
|
||||
|
||||
# Global parameters
|
||||
self.global_topology = environment.G
|
||||
self.environment_params = environment.environment_params
|
||||
|
||||
# Register agent to environment
|
||||
self.env = environment
|
||||
|
||||
self._neighbors = None
|
||||
self.alive = True
|
||||
state = deepcopy(self.defaults)
|
||||
state.update(kwargs.pop('state', {}))
|
||||
kwargs['state'] = state
|
||||
super().__init__(**kwargs)
|
||||
real_state = deepcopy(self.defaults)
|
||||
real_state.update(state or {})
|
||||
self._state = real_state
|
||||
self.interval = interval
|
||||
|
||||
if not hasattr(self, 'level'):
|
||||
self.level = logging.DEBUG
|
||||
self.logger = logging.getLogger('Agent-{}'.format(self.id))
|
||||
self.logger = logging.getLogger('{}-Agent-{}'.format(self.env.name,
|
||||
self.id))
|
||||
self.logger.setLevel(self.level)
|
||||
|
||||
# initialize every time an instance of the agent is created
|
||||
self.action = self.env.process(self.run())
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, value):
|
||||
for k, v in value.items():
|
||||
self[k] = v
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, tuple):
|
||||
k, t_step = key
|
||||
return self.env[self.id, t_step, k]
|
||||
key, t_step = key
|
||||
k = history.Key(key=key, t_step=t_step, agent_id=self.id)
|
||||
return self.env[k]
|
||||
return self.state.get(key, None)
|
||||
|
||||
def __delitem__(self, key):
|
||||
del self.state[key]
|
||||
self.state[key] = None
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.state
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.state[key] = value
|
||||
k = history.Key(t_step=self.now,
|
||||
agent_id=self.id,
|
||||
key=key)
|
||||
self.env[k] = value
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self[key] if key in self else default
|
||||
@@ -72,7 +106,12 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
interval = self.env.interval
|
||||
if self.interval is not None:
|
||||
interval = self.interval
|
||||
elif 'interval' in self:
|
||||
interval = self['interval']
|
||||
else:
|
||||
interval = self.env.interval
|
||||
while self.alive:
|
||||
res = self.step()
|
||||
yield res or self.env.timeout(interval)
|
||||
@@ -95,7 +134,7 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
||||
agents = self.global_topology.nodes()
|
||||
count = 0
|
||||
for agent in agents:
|
||||
if state_id and state_id != self.global_topology.node[agent]['agent'].state['id']:
|
||||
if state_id and state_id != self.global_topology.node[agent]['agent']['id']:
|
||||
continue
|
||||
count += 1
|
||||
return count
|
||||
@@ -140,20 +179,24 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
||||
|
||||
|
||||
def state(func):
|
||||
'''
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the nevironment.
|
||||
'''
|
||||
|
||||
@wraps(func)
|
||||
def func_wrapper(self):
|
||||
when = None
|
||||
next_state = func(self)
|
||||
when = None
|
||||
if next_state is None:
|
||||
return when
|
||||
try:
|
||||
next_state, when = next_state
|
||||
except TypeError:
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if next_state:
|
||||
try:
|
||||
self.state['id'] = next_state.id
|
||||
except AttributeError:
|
||||
raise ValueError('State id %s is not valid.' % next_state)
|
||||
self.set_state(next_state)
|
||||
return when
|
||||
|
||||
func_wrapper.id = func.__name__
|
||||
@@ -193,11 +236,13 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FSM, self).__init__(*args, **kwargs)
|
||||
if 'id' not in self.state:
|
||||
self.state['id'] = self.default_state.id
|
||||
if not self.default_state:
|
||||
raise ValueError('No default state specified for {}'.format(self.id))
|
||||
self['id'] = self.default_state.id
|
||||
|
||||
def step(self):
|
||||
if 'id' in self.state:
|
||||
next_state = self.state['id']
|
||||
next_state = self['id']
|
||||
elif self.default_state:
|
||||
next_state = self.default_state.id
|
||||
else:
|
||||
@@ -211,7 +256,117 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
state = state.id
|
||||
if state not in self.states:
|
||||
raise ValueError('{} is not a valid state'.format(state))
|
||||
self.state['id'] = state
|
||||
self['id'] = state
|
||||
return state
|
||||
|
||||
|
||||
def prob(prob=1):
|
||||
'''
|
||||
A true/False uniform distribution with a given probability.
|
||||
To be used like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
if prob(0.3):
|
||||
do_something()
|
||||
|
||||
'''
|
||||
r = random.random()
|
||||
return r < prob
|
||||
|
||||
|
||||
def calculate_distribution(network_agents=None,
|
||||
agent_type=None):
|
||||
'''
|
||||
Calculate the threshold values (thresholds for a uniform distribution)
|
||||
of an agent distribution given the weights of each agent type.
|
||||
|
||||
The input has this form: ::
|
||||
|
||||
[
|
||||
{'agent_type': 'agent_type_1',
|
||||
'weight': 0.2,
|
||||
'state': {
|
||||
'id': 0
|
||||
}
|
||||
},
|
||||
{'agent_type': 'agent_type_2',
|
||||
'weight': 0.8,
|
||||
'state': {
|
||||
'id': 1
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
In this example, 20% of the nodes will be marked as type
|
||||
'agent_type_1'.
|
||||
'''
|
||||
if network_agents:
|
||||
network_agents = deepcopy(network_agents)
|
||||
elif agent_type:
|
||||
network_agents = [{'agent_type': agent_type}]
|
||||
else:
|
||||
return []
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x.get('weight', 1) for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
upper = acc + (v.get('weight', 1)/total)
|
||||
v['threshold'] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
|
||||
|
||||
def _serialize_distribution(network_agents):
|
||||
d = _convert_agent_types(network_agents,
|
||||
to_string=True)
|
||||
'''
|
||||
When serializing an agent distribution, remove the thresholds, in order
|
||||
to avoid cluttering the YAML definition file.
|
||||
'''
|
||||
for v in d:
|
||||
if 'threshold' in v:
|
||||
del v['threshold']
|
||||
return d
|
||||
|
||||
|
||||
def _validate_states(states, topology):
|
||||
'''Validate states to avoid ignoring states during initialization'''
|
||||
states = states or []
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
assert x in topology.node
|
||||
else:
|
||||
assert len(states) <= len(topology)
|
||||
return states
|
||||
|
||||
|
||||
def _convert_agent_types(ind, to_string=False):
|
||||
'''Convenience method to allow specifying agents by class or class name.'''
|
||||
d = deepcopy(ind)
|
||||
for v in d:
|
||||
agent_type = v['agent_type']
|
||||
if to_string and not isinstance(agent_type, str):
|
||||
v['agent_type'] = str(agent_type.__name__)
|
||||
elif not to_string and isinstance(agent_type, str):
|
||||
v['agent_type'] = agent_types[agent_type]
|
||||
return d
|
||||
|
||||
|
||||
def _agent_from_distribution(distribution, value=-1):
|
||||
"""Used in the initialization of agents given an agent distribution."""
|
||||
if value < 0:
|
||||
value = random.random()
|
||||
for d in distribution:
|
||||
threshold = d['threshold']
|
||||
if value >= threshold[0] and value < threshold[1]:
|
||||
state = {}
|
||||
if 'state' in d:
|
||||
state = deepcopy(d['state'])
|
||||
return d['agent_type'], state
|
||||
|
||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||
|
||||
|
||||
from .BassModel import *
|
||||
|
@@ -4,7 +4,7 @@ import glob
|
||||
import yaml
|
||||
from os.path import join
|
||||
|
||||
from . import utils
|
||||
from . import utils, history
|
||||
|
||||
|
||||
def read_data(*args, group=False, **kwargs):
|
||||
@@ -15,8 +15,9 @@ def read_data(*args, group=False, **kwargs):
|
||||
return list(iterable)
|
||||
|
||||
|
||||
def _read_data(pattern, keys=None, convert_types=False,
|
||||
process=None, from_csv=False, **kwargs):
|
||||
def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
||||
if not process_args:
|
||||
process_args = {}
|
||||
for folder in glob.glob(pattern):
|
||||
config_file = glob.glob(join(folder, '*.yml'))[0]
|
||||
config = yaml.load(open(config_file))
|
||||
@@ -24,19 +25,20 @@ def _read_data(pattern, keys=None, convert_types=False,
|
||||
if from_csv:
|
||||
for trial_data in sorted(glob.glob(join(folder,
|
||||
'*.environment.csv'))):
|
||||
df = read_csv(trial_data, convert_types=convert_types)
|
||||
if process:
|
||||
df = process(df, **kwargs)
|
||||
df = read_csv(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
else:
|
||||
for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
|
||||
df = read_sql(trial_data, convert_types=convert_types,
|
||||
keys=keys)
|
||||
if process:
|
||||
df = process(df, **kwargs)
|
||||
df = read_sql(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
|
||||
|
||||
def read_sql(db, *args, **kwargs):
|
||||
h = history.History(db, backup=False)
|
||||
df = h.read_sql(*args, **kwargs)
|
||||
return df
|
||||
|
||||
|
||||
def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
||||
'''
|
||||
Read a CSV in canonical form: ::
|
||||
@@ -49,18 +51,7 @@ def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
||||
df = convert_types_slow(df)
|
||||
if keys:
|
||||
df = df[df['key'].isin(keys)]
|
||||
return df
|
||||
|
||||
|
||||
def read_sql(filename, keys=None, convert_types=False, limit=-1):
|
||||
condition = ''
|
||||
if keys:
|
||||
k = map(lambda x: "\'{}\'".format(x), keys)
|
||||
condition = 'where key in ({})'.format(','.join(k))
|
||||
query = 'select * from history {} limit {}'.format(condition, limit)
|
||||
df = pd.read_sql_query(query, 'sqlite:///{}'.format(filename))
|
||||
if convert_types:
|
||||
df = convert_types_slow(df)
|
||||
df = process_one(df)
|
||||
return df
|
||||
|
||||
|
||||
@@ -108,8 +99,9 @@ def get_types(df):
|
||||
return {k:v[0] for k,v in dtypes.iteritems()}
|
||||
|
||||
|
||||
def process_one(df, *keys, columns=['key'], values='value',
|
||||
index=['t_step', 'agent_id'], aggfunc='first', **kwargs):
|
||||
def process_one(df, *keys, columns=['key', 'agent_id'], values='value',
|
||||
fill=True, index=['t_step',],
|
||||
aggfunc='first', **kwargs):
|
||||
'''
|
||||
Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into
|
||||
a dataframe with a column per key
|
||||
@@ -119,35 +111,29 @@ def process_one(df, *keys, columns=['key'], values='value',
|
||||
if keys:
|
||||
df = df[df['key'].isin(keys)]
|
||||
|
||||
dtypes = get_types(df)
|
||||
|
||||
df = df.pivot_table(values=values, index=index, columns=columns,
|
||||
aggfunc=aggfunc, **kwargs)
|
||||
df = df.fillna(0).astype(dtypes)
|
||||
if fill:
|
||||
df = fillna(df)
|
||||
return df
|
||||
|
||||
|
||||
def get_count_processed(df, *keys):
|
||||
if keys:
|
||||
df = df[list(keys)]
|
||||
# p = df.groupby(level=0).apply(pd.Series.value_counts)
|
||||
p = df.unstack().apply(pd.Series.value_counts, axis=1)
|
||||
return p
|
||||
|
||||
|
||||
def get_count(df, *keys):
|
||||
if keys:
|
||||
df = df[df['key'].isin(keys)]
|
||||
p = df.groupby(by=['t_step', 'key', 'value']).size().unstack(level=[1,2]).fillna(0)
|
||||
return p
|
||||
df = df[list(keys)]
|
||||
counts = pd.DataFrame()
|
||||
for key in df.columns.levels[0]:
|
||||
g = df[key].apply(pd.Series.value_counts, axis=1).fillna(0)
|
||||
for value, series in g.iteritems():
|
||||
counts[key, value] = series
|
||||
counts.columns = pd.MultiIndex.from_tuples(counts.columns)
|
||||
return counts
|
||||
|
||||
|
||||
def get_value(df, *keys, aggfunc='sum'):
|
||||
if keys:
|
||||
df = df[df['key'].isin(keys)]
|
||||
p = process_one(df, *keys)
|
||||
p = p.groupby(level='t_step').agg(aggfunc)
|
||||
return p
|
||||
df = df[list(keys)]
|
||||
return df.groupby(axis=1, level=0).agg(aggfunc, axis=1)
|
||||
|
||||
|
||||
def plot_all(*args, **kwargs):
|
||||
@@ -175,4 +161,6 @@ def group_trials(trials, aggfunc=['mean', 'min', 'max', 'std']):
|
||||
return pd.concat(trials).groupby(level=0).agg(aggfunc).reorder_levels([2, 0,1] ,axis=1)
|
||||
|
||||
|
||||
|
||||
def fillna(df):
|
||||
new_df = df.ffill(axis=0)
|
||||
return new_df
|
||||
|
@@ -1,19 +1,30 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import weakref
|
||||
import csv
|
||||
import random
|
||||
import simpy
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
from copy import deepcopy
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
import networkx as nx
|
||||
import nxsim
|
||||
|
||||
from . import utils
|
||||
from . import utils, agents, analysis, history
|
||||
|
||||
|
||||
class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
"""
|
||||
The environment is key in a simulation. It contains the network topology,
|
||||
a reference to network and environment agents, as well as the environment
|
||||
params, which are used as shared state between agents.
|
||||
|
||||
The environment parameters and the state of every agent can be accessed
|
||||
both by using the environment as a dictionary or with the environment's
|
||||
:meth:`soil.environment.SoilEnvironment.get` method.
|
||||
"""
|
||||
|
||||
def __init__(self, name=None,
|
||||
network_agents=None,
|
||||
@@ -22,42 +33,31 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
default_state=None,
|
||||
interval=1,
|
||||
seed=None,
|
||||
dump=False,
|
||||
simulation=None,
|
||||
dry_run=False,
|
||||
dir_path=None,
|
||||
topology=None,
|
||||
*args, **kwargs):
|
||||
self.name = name or 'UnnamedEnvironment'
|
||||
if isinstance(states, list):
|
||||
states = dict(enumerate(states))
|
||||
self.states = deepcopy(states) if states else {}
|
||||
self.default_state = deepcopy(default_state) or {}
|
||||
self.sim = weakref.ref(simulation)
|
||||
if 'topology' not in kwargs and simulation:
|
||||
kwargs['topology'] = self.sim().topology.copy()
|
||||
super().__init__(*args, **kwargs)
|
||||
if not topology:
|
||||
topology = nx.Graph()
|
||||
super().__init__(*args, topology=topology, **kwargs)
|
||||
self._env_agents = {}
|
||||
self.dry_run = dry_run
|
||||
self.interval = interval
|
||||
self.dump = dump
|
||||
self.dir_path = dir_path or tempfile.mkdtemp('soil-env')
|
||||
self.get_path()
|
||||
self._history = history.History(name=self.name if not dry_run else None,
|
||||
dir_path=self.dir_path)
|
||||
# Add environment agents first, so their events get
|
||||
# executed before network agents
|
||||
self['SEED'] = seed or time.time()
|
||||
random.seed(self['SEED'])
|
||||
self.process(self.save_state())
|
||||
self.environment_agents = environment_agents or []
|
||||
self.network_agents = network_agents or []
|
||||
if self.dump:
|
||||
self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name))
|
||||
else:
|
||||
self._db_path = ":memory:"
|
||||
self.create_db(self._db_path)
|
||||
|
||||
def create_db(self, db_path=None):
|
||||
db_path = db_path or self._db_path
|
||||
if os.path.exists(db_path):
|
||||
newname = db_path.replace('db.sqlite', 'backup{}.sqlite'.format(time.time()))
|
||||
os.rename(db_path, newname)
|
||||
self._db = sqlite3.connect(db_path)
|
||||
with self._db:
|
||||
self._db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text, value_type text)''')
|
||||
self['SEED'] = seed or time.time()
|
||||
random.seed(self['SEED'])
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
@@ -90,11 +90,11 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
|
||||
@network_agents.setter
|
||||
def network_agents(self, network_agents):
|
||||
if not network_agents:
|
||||
return
|
||||
for ix in self.G.nodes():
|
||||
i = ix
|
||||
node = self.G.node[i]
|
||||
agent, state = utils.agent_from_distribution(network_agents)
|
||||
self.set_agent(i, agent_type=agent, state=state)
|
||||
agent, state = agents._agent_from_distribution(network_agents)
|
||||
self.set_agent(ix, agent_type=agent, state=state)
|
||||
|
||||
def set_agent(self, agent_id, agent_type, state=None):
|
||||
node = self.G.nodes[agent_id]
|
||||
@@ -121,16 +121,21 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
return self.G.add_edge(agent1, agent2)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
self._save_state()
|
||||
super().run(*args, **kwargs)
|
||||
self._history.flush_cache()
|
||||
|
||||
def _save_state(self, now=None):
|
||||
# for agent in self.agents:
|
||||
# agent.save_state()
|
||||
utils.logger.debug('Saving state @{}'.format(self.now))
|
||||
with self._db:
|
||||
self._db.executemany("insert into history(agent_id, t_step, key, value, value_type) values (?, ?, ?, ?, ?)", self.state_to_tuples(now=now))
|
||||
self._history.save_records(self.state_to_tuples(now=now))
|
||||
|
||||
def save_state(self):
|
||||
'''
|
||||
:DEPRECATED:
|
||||
Periodically save the state of the environment and the agents.
|
||||
'''
|
||||
self._save_state()
|
||||
while self.peek() != simpy.core.Infinity:
|
||||
delay = max(self.peek() - self.now, self.interval)
|
||||
@@ -145,66 +150,44 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, tuple):
|
||||
values = {"agent_id": key[0],
|
||||
"t_step": key[1],
|
||||
"key": key[2],
|
||||
"value": None,
|
||||
"value_type": None
|
||||
}
|
||||
|
||||
fields = list(k for k, v in values.items() if v is None)
|
||||
conditions = " and ".join("{}='{}'".format(k, v) for k, v in values.items() if v is not None)
|
||||
|
||||
query = """SELECT {fields} from history""".format(fields=",".join(fields))
|
||||
if conditions:
|
||||
query = """{query} where {conditions}""".format(query=query,
|
||||
conditions=conditions)
|
||||
with self._db:
|
||||
rows = self._db.execute(query).fetchall()
|
||||
|
||||
utils.logger.debug(rows)
|
||||
results = self.rows_to_dict(rows)
|
||||
return results
|
||||
self._history.flush_cache()
|
||||
return self._history[key]
|
||||
|
||||
return self.environment_params[key]
|
||||
|
||||
def rows_to_dict(self, rows):
|
||||
if len(rows) < 1:
|
||||
return None
|
||||
|
||||
level = len(rows[0])-2
|
||||
|
||||
if level == 0:
|
||||
if len(rows) != 1:
|
||||
raise ValueError('Cannot convert {} to dictionaries'.format(rows))
|
||||
value, value_type = rows[0]
|
||||
return utils.convert(value, value_type)
|
||||
|
||||
results = {}
|
||||
for row in rows:
|
||||
item = results
|
||||
for i in range(level-1):
|
||||
key = row[i]
|
||||
if key not in item:
|
||||
item[key] = {}
|
||||
item = item[key]
|
||||
key, value, value_type = row[level-1:]
|
||||
item[key] = utils.convert(value, value_type)
|
||||
return results
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if isinstance(key, tuple):
|
||||
k = history.Key(*key)
|
||||
self._history.save_record(*k,
|
||||
value=value)
|
||||
return
|
||||
self.environment_params[key] = value
|
||||
self._history.save_record(agent_id='env',
|
||||
t_step=self.now,
|
||||
key=key,
|
||||
value=value)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.environment_params
|
||||
|
||||
def get(self, key, default=None):
|
||||
'''
|
||||
Get the value of an environment attribute in a
|
||||
given point in the simulation (history).
|
||||
If key is an attribute name, this method returns
|
||||
the current value.
|
||||
To get values at other times, use a
|
||||
:meth: `soil.history.Key` tuple.
|
||||
'''
|
||||
return self[key] if key in self else default
|
||||
|
||||
def get_path(self, dir_path=None):
|
||||
dir_path = dir_path or self.sim().dir_path
|
||||
dir_path = dir_path or self.dir_path
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
try:
|
||||
os.makedirs(dir_path)
|
||||
except FileExistsError:
|
||||
pass
|
||||
return dir_path
|
||||
|
||||
def get_agent(self, agent_id):
|
||||
@@ -227,23 +210,45 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
G = self.history_to_graph()
|
||||
graph_path = os.path.join(self.get_path(dir_path),
|
||||
self.name+".gexf")
|
||||
# Workaround for geometric models
|
||||
# See soil/soil#4
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.node[node]:
|
||||
G.node[node]['viz'] = {"position": {"x": G.node[node]['pos'][0], "y": G.node[node]['pos'][1], "z": 0.0}}
|
||||
del (G.node[node]['pos'])
|
||||
|
||||
nx.write_gexf(G, graph_path, version="1.2draft")
|
||||
|
||||
def dump(self, dir_path=None, formats=None):
|
||||
if not formats:
|
||||
return
|
||||
functions = {
|
||||
'csv': self.dump_csv,
|
||||
'gexf': self.dump_gexf
|
||||
}
|
||||
for f in formats:
|
||||
if f in functions:
|
||||
functions[f](dir_path)
|
||||
else:
|
||||
raise ValueError('Unknown format: {}'.format(f))
|
||||
|
||||
def state_to_tuples(self, now=None):
|
||||
if now is None:
|
||||
now = self.now
|
||||
for k, v in self.environment_params.items():
|
||||
v, v_t = utils.repr(v)
|
||||
yield 'env', now, k, v, v_t
|
||||
yield history.Record(agent_id='env',
|
||||
t_step=now,
|
||||
key=k,
|
||||
value=v)
|
||||
for agent in self.agents:
|
||||
for k, v in agent.state.items():
|
||||
v, v_t = utils.repr(v)
|
||||
yield agent.id, now, k, v, v_t
|
||||
yield history.Record(agent_id=agent.id,
|
||||
t_step=now,
|
||||
key=k,
|
||||
value=v)
|
||||
|
||||
def history_to_tuples(self):
|
||||
with self._db:
|
||||
res = self._db.execute("select agent_id, t_step, key, value, value_type from history ").fetchall()
|
||||
yield from res
|
||||
return self._history.to_tuples()
|
||||
|
||||
def history_to_graph(self):
|
||||
G = nx.Graph(self.G)
|
||||
@@ -291,3 +296,19 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||
G.add_node(agent.id, **attributes)
|
||||
|
||||
return G
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state['G'] = json_graph.node_link_data(self.G)
|
||||
state['network_agents'] = agents._serialize_distribution(self.network_agents)
|
||||
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||
to_string=True)
|
||||
del state['_queue']
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self.G = json_graph.node_link_graph(state['G'])
|
||||
self.network_agents = self.calculate_distribution(self._convert_agent_types(self.network_agents))
|
||||
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
||||
return state
|
||||
|
231
soil/history.py
Normal file
231
soil/history.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import time
|
||||
import os
|
||||
import pandas as pd
|
||||
import sqlite3
|
||||
import copy
|
||||
from collections import UserDict, Iterable, namedtuple
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
class History:
|
||||
"""
|
||||
Store and retrieve values from a sqlite database.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path=None, name=None, dir_path=None, backup=True):
|
||||
if db_path is None and name:
|
||||
db_path = os.path.join(dir_path or os.getcwd(),
|
||||
'{}.db.sqlite'.format(name))
|
||||
if db_path is None:
|
||||
db_path = ":memory:"
|
||||
else:
|
||||
if backup and os.path.exists(db_path):
|
||||
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
||||
os.rename(db_path, newname)
|
||||
self._db_path = db_path
|
||||
if isinstance(db_path, str):
|
||||
self._db = sqlite3.connect(db_path)
|
||||
else:
|
||||
self._db = db_path
|
||||
|
||||
with self._db:
|
||||
self._db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
|
||||
self._db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
|
||||
self._db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
|
||||
self._dtypes = {}
|
||||
self._tups = []
|
||||
|
||||
def conversors(self, key):
|
||||
"""Get the serializer and deserializer for a given key."""
|
||||
if key not in self._dtypes:
|
||||
self.read_types()
|
||||
return self._dtypes[key]
|
||||
|
||||
@property
|
||||
def dtypes(self):
|
||||
return {k:v[0] for k, v in self._dtypes.items()}
|
||||
|
||||
def save_tuples(self, tuples):
|
||||
self.save_records(Record(*tup) for tup in tuples)
|
||||
|
||||
def save_records(self, records):
|
||||
with self._db:
|
||||
for rec in records:
|
||||
if not isinstance(rec, Record):
|
||||
rec = Record(*rec)
|
||||
if rec.key not in self._dtypes:
|
||||
name = utils.name(rec.value)
|
||||
serializer = utils.serializer(name)
|
||||
deserializer = utils.deserializer(name)
|
||||
self._dtypes[rec.key] = (name, serializer, deserializer)
|
||||
self._db.execute("replace into value_types (key, value_type) values (?, ?)", (rec.key, name))
|
||||
self._db.execute("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", (rec.agent_id, rec.t_step, rec.key, rec.value))
|
||||
|
||||
def save_record(self, *args, **kwargs):
|
||||
self._tups.append(Record(*args, **kwargs))
|
||||
if len(self._tups) > 100:
|
||||
self.flush_cache()
|
||||
|
||||
def flush_cache(self):
|
||||
'''
|
||||
Use a cache to save state changes to avoid opening a session for every change.
|
||||
The cache will be flushed at the end of the simulation, and when history is accessed.
|
||||
'''
|
||||
self.save_records(self._tups)
|
||||
self._tups = list()
|
||||
|
||||
def to_tuples(self):
|
||||
self.flush_cache()
|
||||
with self._db:
|
||||
res = self._db.execute("select agent_id, t_step, key, value from history ").fetchall()
|
||||
for r in res:
|
||||
agent_id, t_step, key, value = r
|
||||
_, _ , des = self.conversors(key)
|
||||
yield agent_id, t_step, key, des(value)
|
||||
|
||||
def read_types(self):
|
||||
with self._db:
|
||||
res = self._db.execute("select key, value_type from value_types ").fetchall()
|
||||
for k, v in res:
|
||||
serializer = utils.serializer(v)
|
||||
deserializer = utils.deserializer(v)
|
||||
self._dtypes[k] = (v, serializer, deserializer)
|
||||
|
||||
def __getitem__(self, key):
|
||||
key = Key(*key)
|
||||
agent_ids = [key.agent_id] if key.agent_id is not None else []
|
||||
t_steps = [key.t_step] if key.t_step is not None else []
|
||||
keys = [key.key] if key.key is not None else []
|
||||
|
||||
df = self.read_sql(agent_ids=agent_ids,
|
||||
t_steps=t_steps,
|
||||
keys=keys)
|
||||
r = Records(df, filter=key, dtypes=self._dtypes)
|
||||
return r.value()
|
||||
|
||||
|
||||
|
||||
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
||||
|
||||
self.read_types()
|
||||
|
||||
def escape_and_join(v):
|
||||
if v is None:
|
||||
return
|
||||
return ",".join(map(lambda x: "\'{}\'".format(x), v))
|
||||
|
||||
filters = [("key in ({})".format(escape_and_join(keys)), keys),
|
||||
("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids)
|
||||
]
|
||||
filters = list(k[0] for k in filters if k[1])
|
||||
|
||||
last_df = None
|
||||
if t_steps:
|
||||
# Look for the last value before the minimum step in the query
|
||||
min_step = min(t_steps)
|
||||
last_filters = ['t_step < {}'.format(min_step),]
|
||||
last_filters = last_filters + filters
|
||||
condition = ' and '.join(last_filters)
|
||||
|
||||
last_query = '''
|
||||
select h1.*
|
||||
from history h1
|
||||
inner join (
|
||||
select agent_id, key, max(t_step) as t_step
|
||||
from history
|
||||
where {condition}
|
||||
group by agent_id, key
|
||||
) h2
|
||||
on h1.agent_id = h2.agent_id and
|
||||
h1.key = h2.key and
|
||||
h1.t_step = h2.t_step
|
||||
'''.format(condition=condition)
|
||||
last_df = pd.read_sql_query(last_query, self._db)
|
||||
|
||||
filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps)))
|
||||
|
||||
condition = ''
|
||||
if filters:
|
||||
condition = 'where {} '.format(' and '.join(filters))
|
||||
query = 'select * from history {} limit {}'.format(condition, limit)
|
||||
df = pd.read_sql_query(query, self._db)
|
||||
if last_df is not None:
|
||||
df = pd.concat([df, last_df])
|
||||
|
||||
df_p = df.pivot_table(values='value', index=['t_step'],
|
||||
columns=['key', 'agent_id'],
|
||||
aggfunc='first')
|
||||
|
||||
for k, v in self._dtypes.items():
|
||||
if k in df_p:
|
||||
dtype, _, deserial = v
|
||||
df_p[k] = df_p[k].fillna(method='ffill').fillna(deserial()).astype(dtype)
|
||||
if t_steps:
|
||||
df_p = df_p.reindex(t_steps, method='ffill')
|
||||
return df_p.ffill()
|
||||
|
||||
|
||||
class Records():
|
||||
|
||||
def __init__(self, df, filter=None, dtypes=None):
|
||||
if not filter:
|
||||
filter = Key(agent_id=None,
|
||||
t_step=None,
|
||||
key=None)
|
||||
self._df = df
|
||||
self._filter = filter
|
||||
self.dtypes = dtypes or {}
|
||||
super().__init__()
|
||||
|
||||
def mask(self, tup):
|
||||
res = ()
|
||||
for i, k in zip(tup[:-1], self._filter):
|
||||
if k is None:
|
||||
res = res + (i,)
|
||||
res = res + (tup[-1],)
|
||||
return res
|
||||
|
||||
def filter(self, newKey):
|
||||
f = list(self._filter)
|
||||
for ix, i in enumerate(f):
|
||||
if i is None:
|
||||
f[ix] = newKey
|
||||
self._filter = Key(*f)
|
||||
|
||||
@property
|
||||
def resolved(self):
|
||||
return sum(1 for i in self._filter if i is not None) == 3
|
||||
|
||||
def __iter__(self):
|
||||
for column, series in self._df.iteritems():
|
||||
key, agent_id = column
|
||||
for t_step, value in series.iteritems():
|
||||
r = Record(t_step=t_step,
|
||||
agent_id=agent_id,
|
||||
key=key,
|
||||
value=value)
|
||||
yield self.mask(r)
|
||||
|
||||
def value(self):
|
||||
if self.resolved:
|
||||
f = self._filter
|
||||
try:
|
||||
i = self._df[f.key][str(f.agent_id)]
|
||||
ix = i.index.get_loc(f.t_step, method='ffill')
|
||||
return i.iloc[ix]
|
||||
except KeyError:
|
||||
return self.dtypes[f.key][2]()
|
||||
return self
|
||||
|
||||
def __getitem__(self, k):
|
||||
n = copy.copy(self)
|
||||
n.filter(k)
|
||||
return n.value()
|
||||
|
||||
def __len__(self):
|
||||
return len(self._df)
|
||||
|
||||
|
||||
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
||||
Record = namedtuple('Record', 'agent_id t_step key value')
|
@@ -5,14 +5,14 @@ import sys
|
||||
import yaml
|
||||
import networkx as nx
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
from copy import deepcopy
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
import pickle
|
||||
|
||||
from nxsim import NetworkSimulation
|
||||
|
||||
from . import agents, utils, environment, basestring
|
||||
from . import utils, environment, basestring, agents
|
||||
from .utils import logger
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class SoilSimulation(NetworkSimulation):
|
||||
"""
|
||||
Subclass of nsim.NetworkSimulation with three main differences:
|
||||
1) agent type can be specified by name or by class.
|
||||
2) instead of just one type, an network_agents can be used.
|
||||
2) instead of just one type, a network agents distribution can be used.
|
||||
The distribution specifies the weight (or probability) of each
|
||||
agent type in the topology. This is an example distribution: ::
|
||||
|
||||
@@ -46,7 +46,7 @@ class SoilSimulation(NetworkSimulation):
|
||||
"""
|
||||
def __init__(self, name=None, topology=None, network_params=None,
|
||||
network_agents=None, agent_type=None, states=None,
|
||||
default_state=None, interval=1, dump=False,
|
||||
default_state=None, interval=1, dump=None, dry_run=False,
|
||||
dir_path=None, num_trials=1, max_time=100,
|
||||
agent_module=None, load_module=None, seed=None,
|
||||
environment_agents=None, environment_params=None):
|
||||
@@ -57,7 +57,6 @@ class SoilSimulation(NetworkSimulation):
|
||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
||||
topology = json_graph.node_link_graph(topology)
|
||||
|
||||
|
||||
self.load_module = load_module
|
||||
self.topology = nx.Graph(topology)
|
||||
self.network_params = network_params
|
||||
@@ -69,94 +68,70 @@ class SoilSimulation(NetworkSimulation):
|
||||
self.interval = interval
|
||||
self.seed = str(seed) or str(time.time())
|
||||
self.dump = dump
|
||||
self.dry_run = dry_run
|
||||
self.environment_params = environment_params or {}
|
||||
|
||||
if load_module:
|
||||
path = sys.path + [self.dir_path]
|
||||
path = sys.path + [self.dir_path, os.getcwd()]
|
||||
f, fp, desc = imp.find_module(load_module, path)
|
||||
imp.load_module('soil.agents.custom', f, fp, desc)
|
||||
|
||||
environment_agents = environment_agents or []
|
||||
self.environment_agents = self._convert_agent_types(environment_agents)
|
||||
self.environment_agents = agents._convert_agent_types(environment_agents)
|
||||
|
||||
distro = self.calculate_distribution(network_agents,
|
||||
agent_type)
|
||||
self.network_agents = self._convert_agent_types(distro)
|
||||
distro = agents.calculate_distribution(network_agents,
|
||||
agent_type)
|
||||
self.network_agents = agents._convert_agent_types(distro)
|
||||
|
||||
self.states = self.validate_states(states,
|
||||
self.topology)
|
||||
self.states = agents._validate_states(states,
|
||||
self.topology)
|
||||
|
||||
def calculate_distribution(self,
|
||||
network_agents=None,
|
||||
agent_type=None):
|
||||
if network_agents:
|
||||
network_agents = deepcopy(network_agents)
|
||||
elif agent_type:
|
||||
network_agents = [{'agent_type': agent_type}]
|
||||
else:
|
||||
return []
|
||||
def run_simulation(self, *args, **kwargs):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x.get('weight', 1) for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
upper = acc + (v.get('weight', 1)/total)
|
||||
v['threshold'] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
def run(self, *args, **kwargs):
|
||||
return list(self.run_simulation_gen(*args, **kwargs))
|
||||
|
||||
def serialize_distribution(self):
|
||||
d = self._convert_agent_types(self.network_agents,
|
||||
to_string=True)
|
||||
for v in d:
|
||||
if 'threshold' in v:
|
||||
del v['threshold']
|
||||
return d
|
||||
|
||||
def _convert_agent_types(self, ind, to_string=False):
|
||||
d = deepcopy(ind)
|
||||
for v in d:
|
||||
agent_type = v['agent_type']
|
||||
if to_string and not isinstance(agent_type, str):
|
||||
v['agent_type'] = str(agent_type.__name__)
|
||||
elif not to_string and isinstance(agent_type, str):
|
||||
v['agent_type'] = agents.agent_types[agent_type]
|
||||
return d
|
||||
|
||||
def validate_states(self, states, topology):
|
||||
states = states or []
|
||||
# Validate states to avoid ignoring states during
|
||||
# initialization
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
assert x in self.topology.node
|
||||
else:
|
||||
assert len(states) <= len(self.topology)
|
||||
return states
|
||||
|
||||
def run_simulation(self):
|
||||
return self.run()
|
||||
|
||||
def run(self):
|
||||
return list(self.run_simulation_gen())
|
||||
|
||||
def run_simulation_gen(self, *args, **kwargs):
|
||||
with utils.timer('simulation'):
|
||||
for i in range(self.num_trials):
|
||||
res = self.run_trial(i)
|
||||
if self.dump:
|
||||
res.dump_gexf(self.dir_path)
|
||||
res.dump_csv(self.dir_path)
|
||||
yield res
|
||||
|
||||
if self.dump:
|
||||
def run_simulation_gen(self, *args, parallel=False, dry_run=False,
|
||||
**kwargs):
|
||||
p = Pool()
|
||||
with utils.timer('simulation {}'.format(self.name)):
|
||||
if parallel:
|
||||
func = partial(self.run_trial, dry_run=dry_run or self.dry_run,
|
||||
return_env=not parallel, **kwargs)
|
||||
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||
yield i
|
||||
else:
|
||||
for i in range(self.num_trials):
|
||||
yield self.run_trial(i, dry_run=dry_run or self.dry_run, **kwargs)
|
||||
if not (dry_run or self.dry_run):
|
||||
logger.info('Dumping results to {}'.format(self.dir_path))
|
||||
self.dump_pickle(self.dir_path)
|
||||
self.dump_yaml(self.dir_path)
|
||||
else:
|
||||
logger.info('NOT dumping results')
|
||||
|
||||
def run_trial(self, trial_id=0, dump=False, dir_path=None):
|
||||
def get_env(self, trial_id=0, **kwargs):
|
||||
opts = self.environment_params.copy()
|
||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||
opts.update({
|
||||
'name': env_name,
|
||||
'topology': self.topology.copy(),
|
||||
'seed': self.seed+env_name,
|
||||
'initial_time': 0,
|
||||
'dry_run': self.dry_run,
|
||||
'interval': self.interval,
|
||||
'network_agents': self.network_agents,
|
||||
'states': self.states,
|
||||
'default_state': self.default_state,
|
||||
'environment_agents': self.environment_agents,
|
||||
'dir_path': self.dir_path,
|
||||
})
|
||||
opts.update(kwargs)
|
||||
env = environment.SoilEnvironment(**opts)
|
||||
return env
|
||||
|
||||
def run_trial(self, trial_id=0, until=None, return_env=True, **opts):
|
||||
"""Run a single trial of the simulation
|
||||
|
||||
Parameters
|
||||
@@ -164,25 +139,16 @@ class SoilSimulation(NetworkSimulation):
|
||||
trial_id : int
|
||||
"""
|
||||
# Set-up trial environment and graph
|
||||
logger.info('Trial: {}'.format(trial_id))
|
||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||
env = environment.SoilEnvironment(name=env_name,
|
||||
topology=self.topology.copy(),
|
||||
seed=self.seed+env_name,
|
||||
initial_time=0,
|
||||
dump=self.dump,
|
||||
interval=self.interval,
|
||||
network_agents=self.network_agents,
|
||||
states=self.states,
|
||||
default_state=self.default_state,
|
||||
environment_agents=self.environment_agents,
|
||||
simulation=self,
|
||||
**self.environment_params)
|
||||
until = until or self.max_time
|
||||
env = self.get_env(trial_id=trial_id, **opts)
|
||||
# Set up agents on nodes
|
||||
logger.info('\tRunning')
|
||||
with utils.timer('trial'):
|
||||
env.run(until=self.max_time)
|
||||
return env
|
||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||
env.run(until)
|
||||
if self.dump and not self.dry_run:
|
||||
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
||||
env.dump(formats=self.dump)
|
||||
if return_env:
|
||||
return env
|
||||
|
||||
def to_dict(self):
|
||||
return self.__getstate__()
|
||||
@@ -213,20 +179,20 @@ class SoilSimulation(NetworkSimulation):
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state['topology'] = json_graph.node_link_data(self.topology)
|
||||
state['network_agents'] = self.serialize_distribution()
|
||||
state['environment_agents'] = self._convert_agent_types(self.environment_agents,
|
||||
to_string=True)
|
||||
state['network_agents'] = agents._serialize_distribution(self.network_agents)
|
||||
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||
to_string=True)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self.topology = json_graph.node_link_graph(state['topology'])
|
||||
self.network_agents = self._convert_agent_types(self.network_agents)
|
||||
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
||||
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
||||
self.environment_agents = agents._convert_agent_types(self.environment_agents)
|
||||
return state
|
||||
|
||||
|
||||
def from_config(config, G=None):
|
||||
def from_config(config):
|
||||
config = list(utils.load_config(config))
|
||||
if len(config) > 1:
|
||||
raise AttributeError('Provide only one configuration')
|
||||
@@ -235,21 +201,19 @@ def from_config(config, G=None):
|
||||
return sim
|
||||
|
||||
|
||||
def run_from_config(*configs, dump=True, results_dir=None, timestamp=False):
|
||||
if not results_dir:
|
||||
results_dir = 'soil_output'
|
||||
def run_from_config(*configs, results_dir='soil_output', dry_run=False, dump=None, timestamp=False, **kwargs):
|
||||
for config_def in configs:
|
||||
for config, cpath in utils.load_config(config_def):
|
||||
# logger.info("Found {} config(s)".format(len(ls)))
|
||||
for config, _ in utils.load_config(config_def):
|
||||
name = config.get('name', 'unnamed')
|
||||
logger.info("Using config(s): {name}".format(name=name))
|
||||
|
||||
sim = SoilSimulation(**config)
|
||||
if timestamp:
|
||||
sim_folder = '{}_{}'.format(sim.name,
|
||||
sim_folder = '{}_{}'.format(name,
|
||||
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
||||
else:
|
||||
sim_folder = sim.name
|
||||
sim.dir_path = os.path.join(results_dir, sim_folder)
|
||||
sim.dump = dump
|
||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, dump))
|
||||
results = sim.run_simulation()
|
||||
sim_folder = name
|
||||
dir_path = os.path.join(results_dir, sim_folder)
|
||||
sim = SoilSimulation(dir_path=dir_path, dump=dump, **config)
|
||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, sim.dump))
|
||||
sim.run_simulation(**kwargs)
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import yaml
|
||||
import logging
|
||||
import importlib
|
||||
from time import time
|
||||
from glob import glob
|
||||
from random import random
|
||||
@@ -11,7 +12,7 @@ import networkx as nx
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger('soil')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@@ -62,6 +63,7 @@ def load_config(config):
|
||||
@contextmanager
|
||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
start = time()
|
||||
function('{}Starting {} at {}.'.format(pre, name, start))
|
||||
yield start
|
||||
end = time()
|
||||
function('{}Finished {} in {} seconds'.format(pre, name, str(end-start)))
|
||||
@@ -70,29 +72,23 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
to_object.end = end
|
||||
|
||||
|
||||
def agent_from_distribution(distribution, value=-1):
|
||||
"""Find the agent """
|
||||
if value < 0:
|
||||
value = random()
|
||||
for d in distribution:
|
||||
threshold = d['threshold']
|
||||
if value >= threshold[0] and value < threshold[1]:
|
||||
state = {}
|
||||
if 'state' in d:
|
||||
state = deepcopy(d['state'])
|
||||
return d['agent_type'], state
|
||||
|
||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||
|
||||
|
||||
def repr(v):
|
||||
if isinstance(v, bool):
|
||||
v = "true" if v else ""
|
||||
return v, bool.__name__
|
||||
return v, type(v).__name__
|
||||
func = serializer(v)
|
||||
tname = name(v)
|
||||
return func(v), tname
|
||||
|
||||
def convert(value, type_):
|
||||
import importlib
|
||||
|
||||
def name(v):
|
||||
return type(v).__name__
|
||||
|
||||
|
||||
def serializer(type_):
|
||||
if type_ == 'bool':
|
||||
return lambda x: "true" if x else ""
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def deserializer(type_):
|
||||
try:
|
||||
# Check if it's a builtin type
|
||||
module = importlib.import_module('builtins')
|
||||
@@ -102,4 +98,8 @@ def convert(value, type_):
|
||||
module, type_ = type_.rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
cls = getattr(module, type_)
|
||||
return cls(value)
|
||||
return cls
|
||||
|
||||
|
||||
def convert(value, type_):
|
||||
return deserializer(type_)(value)
|
||||
|
20
soil/version.py
Normal file
20
soil/version.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ROOT = os.path.dirname(__file__)
|
||||
DEFAULT_FILE = os.path.join(ROOT, 'VERSION')
|
||||
|
||||
|
||||
def read_version(versionfile=DEFAULT_FILE):
|
||||
try:
|
||||
with open(versionfile) as f:
|
||||
return f.read().strip()
|
||||
except IOError: # pragma: no cover
|
||||
logger.error(('Running an unknown version of {}.'
|
||||
'Be careful!.').format(__name__))
|
||||
return '0.0'
|
||||
|
||||
|
||||
__version__ = read_version()
|
16
tests/test.csv
Normal file
16
tests/test.csv
Normal file
@@ -0,0 +1,16 @@
|
||||
agent_id,t_step,key,value,value_type
|
||||
a0,0,hello,w,str
|
||||
a0,1,hello,o,str
|
||||
a0,2,hello,r,str
|
||||
a0,3,hello,l,str
|
||||
a0,4,hello,d,str
|
||||
a0,5,hello,!,str
|
||||
env,1,started,,bool
|
||||
env,2,started,True,bool
|
||||
env,7,started,,bool
|
||||
a0,0,hello,w,str
|
||||
a0,1,hello,o,str
|
||||
a0,2,hello,r,str
|
||||
a0,3,hello,l,str
|
||||
a0,4,hello,d,str
|
||||
a0,5,hello,!,str
|
|
90
tests/test_analysis.py
Normal file
90
tests/test_analysis.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from unittest import TestCase
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from functools import partial
|
||||
|
||||
from os.path import join
|
||||
from soil import simulation, analysis, agents
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
class Ping(agents.FSM):
|
||||
|
||||
defaults = {
|
||||
'count': 0,
|
||||
}
|
||||
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def even(self):
|
||||
self['count'] += 1
|
||||
return self.odd
|
||||
|
||||
@agents.state
|
||||
def odd(self):
|
||||
self['count'] += 1
|
||||
return self.even
|
||||
|
||||
|
||||
class TestAnalysis(TestCase):
|
||||
|
||||
# Code to generate a simple sqlite history
|
||||
def setUp(self):
|
||||
"""
|
||||
The initial states should be applied to the agent and the
|
||||
agent should be able to update its state."""
|
||||
config = {
|
||||
'name': 'analysis',
|
||||
'dry_run': True,
|
||||
'seed': 'seed',
|
||||
'network_params': {
|
||||
'generator': 'complete_graph',
|
||||
'n': 2
|
||||
},
|
||||
'agent_type': Ping,
|
||||
'states': [{'interval': 1}, {'interval': 2}],
|
||||
'max_time': 30,
|
||||
'num_trials': 1,
|
||||
'environment_params': {
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
self.env = s.run_simulation()[0]
|
||||
|
||||
def test_saved(self):
|
||||
env = self.env
|
||||
assert env.get_agent(0)['count', 0] == 1
|
||||
assert env.get_agent(0)['count', 29] == 30
|
||||
assert env.get_agent(1)['count', 0] == 1
|
||||
assert env.get_agent(1)['count', 29] == 15
|
||||
assert env['env', 29, None]['SEED'] == env['env', 29, 'SEED']
|
||||
|
||||
def test_count(self):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history._db)
|
||||
res = analysis.get_count(df, 'SEED', 'id')
|
||||
assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
|
||||
assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
|
||||
assert res['id']['odd'].iloc[0] == 2
|
||||
assert res['id']['even'].iloc[0] == 0
|
||||
assert res['id']['odd'].iloc[-1] == 1
|
||||
assert res['id']['even'].iloc[-1] == 1
|
||||
|
||||
def test_value(self):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history._db)
|
||||
res_sum = analysis.get_value(df, 'count')
|
||||
|
||||
assert res_sum['count'].iloc[0] == 2
|
||||
|
||||
import numpy as np
|
||||
res_mean = analysis.get_value(df, 'count', aggfunc=np.mean)
|
||||
assert res_mean['count'].iloc[0] == 1
|
||||
|
||||
res_total = analysis.get_value(df)
|
||||
|
||||
res_total['SEED'].iloc[0] == 'seedanalysis_trial_0'
|
133
tests/test_history.py
Normal file
133
tests/test_history.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from unittest import TestCase
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from glob import glob
|
||||
|
||||
from soil import history
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
DBROOT = os.path.join(ROOT, 'testdb')
|
||||
|
||||
|
||||
class TestHistory(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not os.path.exists(DBROOT):
|
||||
os.makedirs(DBROOT)
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(DBROOT):
|
||||
shutil.rmtree(DBROOT)
|
||||
|
||||
def test_history(self):
|
||||
"""
|
||||
"""
|
||||
tuples = (
|
||||
('a_0', 0, 'id', 'h'),
|
||||
('a_0', 1, 'id', 'e'),
|
||||
('a_0', 2, 'id', 'l'),
|
||||
('a_0', 3, 'id', 'l'),
|
||||
('a_0', 4, 'id', 'o'),
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 3, 'prob', 2),
|
||||
('env', 5, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
# assert h['env', 0, 'prob'] == 0
|
||||
for i in range(1, 7):
|
||||
assert h['env', i, 'prob'] == ((i-1)//2)+1
|
||||
|
||||
|
||||
for i, k in zip(range(5), 'hello'):
|
||||
assert h['a_0', i, 'id'] == k
|
||||
for record, value in zip(h['a_0', None, 'id'], 'hello'):
|
||||
t_step, val = record
|
||||
assert val == value
|
||||
|
||||
for i, k in zip(range(5), 'value'):
|
||||
assert h['a_1', i, 'id'] == k
|
||||
for i in range(5, 8):
|
||||
assert h['a_1', i, 'id'] == 'e'
|
||||
for i in range(7):
|
||||
assert h['a_2', i, 'finished'] == False
|
||||
assert h['a_2', 7, 'finished']
|
||||
|
||||
def test_history_gen(self):
|
||||
"""
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
for t_step, key, value in h['env', None, None]:
|
||||
assert t_step == value
|
||||
assert key == 'prob'
|
||||
|
||||
records = list(h[None, 7, None])
|
||||
assert len(records) == 3
|
||||
for i in records:
|
||||
agent_id, key, value = i
|
||||
if agent_id == 'a_1':
|
||||
assert key == 'id'
|
||||
assert value == 'e'
|
||||
elif agent_id == 'a_2':
|
||||
assert key == 'finished'
|
||||
assert value
|
||||
else:
|
||||
assert key == 'prob'
|
||||
assert value == 3
|
||||
|
||||
records = h['a_1', 7, None]
|
||||
assert records['id'] == 'e'
|
||||
|
||||
def test_history_file(self):
|
||||
"""
|
||||
History should be saved to a file
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
db_path = os.path.join(DBROOT, 'test')
|
||||
h = history.History(db_path=db_path)
|
||||
h.save_tuples(tuples)
|
||||
assert os.path.exists(db_path)
|
||||
|
||||
# Recover the data
|
||||
recovered = history.History(db_path=db_path, backup=False)
|
||||
assert recovered['a_1', 0, 'id'] == 'v'
|
||||
assert recovered['a_1', 4, 'id'] == 'e'
|
||||
|
||||
# Using the same name should create a backup copy
|
||||
newhistory = history.History(db_path=db_path, backup=True)
|
||||
backuppaths = glob(db_path + '.backup*.sqlite')
|
||||
assert len(backuppaths) == 1
|
||||
backuppath = backuppaths[0]
|
||||
assert newhistory._db_path == h._db_path
|
||||
assert os.path.exists(backuppath)
|
||||
assert not len(newhistory[None, None, None])
|
@@ -2,6 +2,7 @@ from unittest import TestCase
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import networkx as nx
|
||||
from functools import partial
|
||||
|
||||
from os.path import join
|
||||
@@ -21,6 +22,7 @@ class TestMain(TestCase):
|
||||
Raise an exception otherwise.
|
||||
"""
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
}
|
||||
@@ -30,6 +32,7 @@ class TestMain(TestCase):
|
||||
assert len(G) == 2
|
||||
with self.assertRaises(AttributeError):
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'unknown.extension')
|
||||
}
|
||||
@@ -43,6 +46,7 @@ class TestMain(TestCase):
|
||||
should be used to generate a network
|
||||
"""
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'generator': 'barabasi_albert_graph'
|
||||
}
|
||||
@@ -57,6 +61,7 @@ class TestMain(TestCase):
|
||||
def test_empty_simulation(self):
|
||||
"""A simulation with a base behaviour should do nothing"""
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
@@ -65,35 +70,39 @@ class TestMain(TestCase):
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
s.run_simulation()
|
||||
s.run_simulation(dry_run=True)
|
||||
|
||||
def test_counter_agent(self):
|
||||
"""
|
||||
The initial states should be applied to the agent and the
|
||||
agent should be able to update its state."""
|
||||
config = {
|
||||
'name': 'CounterAgent',
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
'agent_type': 'CounterModel',
|
||||
'states': [{'neighbors': 10}, {'total': 12}],
|
||||
'states': [{'times': 10}, {'times': 20}],
|
||||
'max_time': 2,
|
||||
'num_trials': 1,
|
||||
'environment_params': {
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
env = s.run_simulation()[0]
|
||||
assert env.get_agent(0)['neighbors', 0] == 10
|
||||
assert env.get_agent(0)['neighbors', 1] == 1
|
||||
assert env.get_agent(1)['total', 0] == 12
|
||||
assert env.get_agent(1)['neighbors', 1] == 1
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
assert env.get_agent(0)['times', 0] == 11
|
||||
assert env.get_agent(0)['times', 1] == 12
|
||||
assert env.get_agent(1)['times', 0] == 21
|
||||
assert env.get_agent(1)['times', 1] == 22
|
||||
|
||||
def test_counter_agent_history(self):
|
||||
"""
|
||||
The evolution of the state should be recorded in the logging agent
|
||||
"""
|
||||
config = {
|
||||
'name': 'CounterAgent',
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
@@ -108,14 +117,13 @@ class TestMain(TestCase):
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
env = s.run_simulation()[0]
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
for agent in env.network_agents:
|
||||
last = 0
|
||||
assert len(agent[None, None]) == 11
|
||||
for step, total in agent['total', None].items():
|
||||
if step > 0:
|
||||
assert total == last + 2
|
||||
last = total
|
||||
assert len(agent[None, None]) == 10
|
||||
for step, total in sorted(agent['total', None]):
|
||||
assert total == last + 2
|
||||
last = total
|
||||
|
||||
def test_custom_agent(self):
|
||||
"""Allow for search of neighbors with a certain state_id"""
|
||||
@@ -124,6 +132,7 @@ class TestMain(TestCase):
|
||||
self.state['neighbors'] = self.count_agents(state_id=0,
|
||||
limit_neighbors=True)
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
@@ -138,7 +147,7 @@ class TestMain(TestCase):
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
env = s.run_simulation()[0]
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
assert env.get_agent(0).state['neighbors'] == 1
|
||||
|
||||
def test_torvalds_example(self):
|
||||
@@ -147,6 +156,7 @@ class TestMain(TestCase):
|
||||
config['network_params']['path'] = join(EXAMPLES,
|
||||
config['network_params']['path'])
|
||||
s = simulation.from_config(config)
|
||||
s.dry_run = True
|
||||
env = s.run_simulation()[0]
|
||||
for a in env.network_agents:
|
||||
skill_level = a.state['skill_level']
|
||||
@@ -171,6 +181,7 @@ class TestMain(TestCase):
|
||||
with utils.timer('loading'):
|
||||
config = utils.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
s = simulation.from_config(config)
|
||||
s.dry_run = True
|
||||
with utils.timer('serializing'):
|
||||
serial = s.to_yaml()
|
||||
with utils.timer('recovering'):
|
||||
@@ -178,6 +189,7 @@ class TestMain(TestCase):
|
||||
with utils.timer('deleting'):
|
||||
del recovered['topology']
|
||||
del recovered['load_module']
|
||||
del recovered['dry_run']
|
||||
assert config == recovered
|
||||
|
||||
def test_configuration_changes(self):
|
||||
@@ -187,10 +199,12 @@ class TestMain(TestCase):
|
||||
"""
|
||||
config = utils.load_file('examples/complete.yml')[0]
|
||||
s = simulation.from_config(config)
|
||||
s.dry_run = True
|
||||
for i in range(5):
|
||||
s.run_simulation()
|
||||
s.run_simulation(dry_run=True)
|
||||
nconfig = s.to_dict()
|
||||
del nconfig['topology']
|
||||
del nconfig['dry_run']
|
||||
del nconfig['load_module']
|
||||
assert config == nconfig
|
||||
|
||||
@@ -201,23 +215,27 @@ class TestMain(TestCase):
|
||||
pass
|
||||
|
||||
def test_row_conversion(self):
|
||||
sim = simulation.SoilSimulation()
|
||||
env = environment.SoilEnvironment(simulation=sim)
|
||||
env = environment.SoilEnvironment(dry_run=True)
|
||||
env['test'] = 'test_value'
|
||||
env._save_state(now=0)
|
||||
|
||||
res = list(env.history_to_tuples())
|
||||
assert len(res) == len(env.environment_params)
|
||||
assert ('env', 0, 'test', 'test_value', 'str') in res
|
||||
|
||||
env._now = 1
|
||||
env['test'] = 'second_value'
|
||||
env._save_state(now=1)
|
||||
res = list(env.history_to_tuples())
|
||||
|
||||
assert env['env', 0, 'test' ] == 'test_value'
|
||||
assert env['env', 1, 'test' ] == 'second_value'
|
||||
|
||||
|
||||
def test_save_geometric(self):
|
||||
"""
|
||||
There is a bug in networkx that prevents it from creating a GEXF file
|
||||
from geometric models. We should work around it.
|
||||
"""
|
||||
G = nx.random_geometric_graph(20,0.1)
|
||||
env = environment.SoilEnvironment(topology=G, dry_run=True)
|
||||
env.dump_gexf('/tmp/dump-gexf')
|
||||
|
||||
|
||||
def make_example_test(path, config):
|
||||
@@ -225,8 +243,10 @@ def make_example_test(path, config):
|
||||
root = os.getcwd()
|
||||
os.chdir(os.path.dirname(path))
|
||||
s = simulation.from_config(config)
|
||||
envs = s.run_simulation()
|
||||
envs = s.run_simulation(dry_run=True)
|
||||
assert envs
|
||||
for env in envs:
|
||||
assert env
|
||||
try:
|
||||
n = config['network_params']['n']
|
||||
assert len(env.get_agents()) == n
|
||||
|
Reference in New Issue
Block a user