mirror of
https://github.com/gsi-upm/soil
synced 2025-09-15 04:32:21 +00:00
Compare commits
5 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
5d89827ccf | ||
|
fc48ed7e09 | ||
|
73c90887e8 | ||
|
497c8a55db | ||
|
7d1c800490 |
8
docker-compose.yml
Normal file
8
docker-compose.yml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
version: '3'
|
||||||
|
services:
|
||||||
|
dev:
|
||||||
|
build: .
|
||||||
|
volumes:
|
||||||
|
- .:/usr/src/app
|
||||||
|
tty: true
|
||||||
|
entrypoint: /bin/bash
|
334
examples/NewsSpread.ipynb
Normal file
334
examples/NewsSpread.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -68,7 +68,7 @@ network_agents:
|
|||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: infected
|
id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
@@ -95,7 +95,7 @@ network_agents:
|
|||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: infected
|
id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
@@ -121,7 +121,7 @@ network_agents:
|
|||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: infected
|
id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
|
@@ -1,5 +1,4 @@
|
|||||||
from soil.agents import BaseAgent,FSM, state, default_state
|
from soil.agents import FSM, state, default_state, prob
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -10,70 +9,73 @@ class DumbViewer(FSM):
|
|||||||
'''
|
'''
|
||||||
defaults = {
|
defaults = {
|
||||||
'prob_neighbor_spread': 0.5,
|
'prob_neighbor_spread': 0.5,
|
||||||
'prob_neighbor_cure': 0.25,
|
'prob_tv_spread': 0.1,
|
||||||
}
|
}
|
||||||
|
|
||||||
@default_state
|
@default_state
|
||||||
@state
|
@state
|
||||||
def neutral(self):
|
def neutral(self):
|
||||||
r = random.random()
|
if self['has_tv']:
|
||||||
if self['has_tv'] and r < self.env['prob_tv_spread']:
|
if prob(self.env['prob_tv_spread']):
|
||||||
self.infect()
|
self.set_state(self.infected)
|
||||||
return
|
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
||||||
prob_infect = self.env['prob_neighbor_spread']
|
if prob(self.env['prob_neighbor_spread']):
|
||||||
r = random.random()
|
|
||||||
if r < prob_infect:
|
|
||||||
self.set_state(self.infected.id)
|
|
||||||
neighbor.infect()
|
neighbor.infect()
|
||||||
return
|
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
self.set_state(self.infected)
|
self.set_state(self.infected)
|
||||||
|
|
||||||
|
|
||||||
class HerdViewer(DumbViewer):
|
class HerdViewer(DumbViewer):
|
||||||
'''
|
'''
|
||||||
A viewer whose probability of infection depends on the state of its neighbors.
|
A viewer whose probability of infection depends on the state of its neighbors.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||||
total = self.count_neighboring_agents()
|
total = self.count_neighboring_agents()
|
||||||
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
||||||
self.debug('prob_infect', prob_infect)
|
self.debug('prob_infect', prob_infect)
|
||||||
r = random.random()
|
if prob(prob_infect):
|
||||||
if r < prob_infect:
|
|
||||||
self.set_state(self.infected.id)
|
self.set_state(self.infected.id)
|
||||||
|
|
||||||
|
|
||||||
class WiseViewer(HerdViewer):
|
class WiseViewer(HerdViewer):
|
||||||
'''
|
'''
|
||||||
A viewer that can change its mind.
|
A viewer that can change its mind.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
defaults = {
|
||||||
|
'prob_neighbor_spread': 0.5,
|
||||||
|
'prob_neighbor_cure': 0.25,
|
||||||
|
'prob_tv_spread': 0.1,
|
||||||
|
}
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def cured(self):
|
def cured(self):
|
||||||
prob_cure = self.env['prob_neighbor_cure']
|
prob_cure = self.env['prob_neighbor_cure']
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
||||||
r = random.random()
|
if prob(prob_cure):
|
||||||
if r < prob_cure:
|
|
||||||
try:
|
try:
|
||||||
neighbor.cure()
|
neighbor.cure()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self.debug('Viewer {} cannot be cured'.format(neighbor.id))
|
self.debug('Viewer {} cannot be cured'.format(neighbor.id))
|
||||||
return
|
|
||||||
|
|
||||||
def cure(self):
|
def cure(self):
|
||||||
self.set_state(self.cured.id)
|
self.set_state(self.cured.id)
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
prob_cure = self.env['prob_neighbor_cure']
|
cured = max(self.count_neighboring_agents(self.cured.id),
|
||||||
r = random.random()
|
1.0)
|
||||||
if r < prob_cure:
|
infected = max(self.count_neighboring_agents(self.infected.id),
|
||||||
self.cure()
|
1.0)
|
||||||
return
|
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
||||||
return super().infected()
|
if prob(prob_cure):
|
||||||
|
return self.cure()
|
||||||
|
return self.set_state(super().infected)
|
||||||
|
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
|||||||
nxsim
|
nxsim
|
||||||
simpy
|
simpy
|
||||||
networkx
|
networkx>=2.0
|
||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
pyyaml
|
pyyaml
|
||||||
|
29
setup.py
29
setup.py
@@ -1,20 +1,21 @@
|
|||||||
import pip
|
import os
|
||||||
from setuptools import setup
|
from setuptools import setup
|
||||||
# parse_requirements() returns generator of pip.req.InstallRequirement objects
|
|
||||||
from pip.req import parse_requirements
|
|
||||||
from soil import __version__
|
|
||||||
|
|
||||||
try:
|
|
||||||
install_reqs = parse_requirements(
|
|
||||||
"requirements.txt", session=pip.download.PipSession())
|
|
||||||
test_reqs = parse_requirements(
|
|
||||||
"test-requirements.txt", session=pip.download.PipSession())
|
|
||||||
except AttributeError:
|
|
||||||
install_reqs = parse_requirements("requirements.txt")
|
|
||||||
test_reqs = parse_requirements("test-requirements.txt")
|
|
||||||
|
|
||||||
install_reqs = [str(ir.req) for ir in install_reqs]
|
with open(os.path.join('soil', 'VERSION')) as f:
|
||||||
test_reqs = [str(ir.req) for ir in test_reqs]
|
__version__ = f.readlines()[0].strip()
|
||||||
|
assert __version__
|
||||||
|
|
||||||
|
|
||||||
|
def parse_requirements(filename):
|
||||||
|
""" load requirements from a pip requirements file """
|
||||||
|
with open(filename, 'r') as f:
|
||||||
|
lineiter = list(line.strip() for line in f)
|
||||||
|
return [line for line in lineiter if line and not line.startswith("#")]
|
||||||
|
|
||||||
|
|
||||||
|
install_reqs = parse_requirements("requirements.txt")
|
||||||
|
test_reqs = parse_requirements("test-requirements.txt")
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
1
soil/VERSION
Normal file
1
soil/VERSION
Normal file
@@ -0,0 +1 @@
|
|||||||
|
0.11.1
|
@@ -4,7 +4,7 @@ import os
|
|||||||
import pdb
|
import pdb
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
__version__ = "0.10.1"
|
from .version import __version__
|
||||||
|
|
||||||
try:
|
try:
|
||||||
basestring
|
basestring
|
||||||
@@ -35,8 +35,14 @@ def main():
|
|||||||
help='Do not store the results of the simulation.')
|
help='Do not store the results of the simulation.')
|
||||||
parser.add_argument('--pdb', action='store_true',
|
parser.add_argument('--pdb', action='store_true',
|
||||||
help='Use a pdb console in case of exception.')
|
help='Use a pdb console in case of exception.')
|
||||||
parser.add_argument('--output', '-o', type=str,
|
parser.add_argument('--graph', '-g', action='store_true',
|
||||||
|
help='Dump GEXF graph. Defaults to false.')
|
||||||
|
parser.add_argument('--csv', action='store_true',
|
||||||
|
help='Dump history in CSV format. Defaults to false.')
|
||||||
|
parser.add_argument('--output', '-o', type=str, default="soil_output",
|
||||||
help='folder to write results to. It defaults to the current directory.')
|
help='folder to write results to. It defaults to the current directory.')
|
||||||
|
parser.add_argument('--synchronous', action='store_true',
|
||||||
|
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -47,7 +53,17 @@ def main():
|
|||||||
logging.info('Loading config file: {}'.format(args.file, args.output))
|
logging.info('Loading config file: {}'.format(args.file, args.output))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
simulation.run_from_config(args.file, dump=(not args.dry_run), results_dir=args.output)
|
dump = []
|
||||||
|
if not args.dry_run:
|
||||||
|
if args.csv:
|
||||||
|
dump.append('csv')
|
||||||
|
if args.graph:
|
||||||
|
dump.append('gexf')
|
||||||
|
simulation.run_from_config(args.file,
|
||||||
|
dry_run=args.dry_run,
|
||||||
|
dump=dump,
|
||||||
|
parallel=(not args.synchronous and not args.pdb),
|
||||||
|
results_dir=args.output)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
if args.pdb:
|
if args.pdb:
|
||||||
pdb.post_mortem()
|
pdb.post_mortem()
|
||||||
|
@@ -11,9 +11,9 @@ class CounterModel(BaseAgent):
|
|||||||
# Outside effects
|
# Outside effects
|
||||||
total = len(list(self.get_all_agents()))
|
total = len(list(self.get_all_agents()))
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self.state['times'] = self.state.get('times', 0) + 1
|
self['times'] = self.get('times', 0) + 1
|
||||||
self.state['neighbors'] = neighbors
|
self['neighbors'] = neighbors
|
||||||
self.state['total'] = total
|
self['total'] = total
|
||||||
|
|
||||||
|
|
||||||
class AggregatedCounter(BaseAgent):
|
class AggregatedCounter(BaseAgent):
|
||||||
@@ -26,7 +26,7 @@ class AggregatedCounter(BaseAgent):
|
|||||||
# Outside effects
|
# Outside effects
|
||||||
total = len(list(self.get_all_agents()))
|
total = len(list(self.get_all_agents()))
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self.state['times'] = self.state.get('times', 0) + 1
|
self['times'] = self.get('times', 0) + 1
|
||||||
self.state['neighbors'] = self.state.get('neighbors', 0) + neighbors
|
self['neighbors'] = self.get('neighbors', 0) + neighbors
|
||||||
self.state['total'] = total = self.state.get('total', 0) + total
|
self['total'] = total = self.get('total', 0) + total
|
||||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||||
|
@@ -15,4 +15,4 @@ class DrawingAgent(BaseAgent):
|
|||||||
# Outside effects
|
# Outside effects
|
||||||
f = plt.figure()
|
f = plt.figure()
|
||||||
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
||||||
f.savefig(os.path.join(self.env.sim().dir_path, "graph-"+str(self.env.now)+".png"))
|
f.savefig(os.path.join(self.env.get_path(), "graph-"+str(self.env.now)+".png"))
|
||||||
|
@@ -12,9 +12,9 @@ from copy import deepcopy
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
|
from .. import utils, history
|
||||||
|
|
||||||
agent_types = {}
|
agent_types = {}
|
||||||
|
|
||||||
@@ -32,33 +32,67 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
|||||||
|
|
||||||
defaults = {}
|
defaults = {}
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, environment=None, agent_id=None, state=None,
|
||||||
|
name='network_process', interval=None, **state_params):
|
||||||
|
# Check for REQUIRED arguments
|
||||||
|
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
|
||||||
|
'Cannot be NoneType.')
|
||||||
|
# Initialize agent parameters
|
||||||
|
self.id = agent_id
|
||||||
|
self.name = name
|
||||||
|
self.state_params = state_params
|
||||||
|
|
||||||
|
# Global parameters
|
||||||
|
self.global_topology = environment.G
|
||||||
|
self.environment_params = environment.environment_params
|
||||||
|
|
||||||
|
# Register agent to environment
|
||||||
|
self.env = environment
|
||||||
|
|
||||||
self._neighbors = None
|
self._neighbors = None
|
||||||
self.alive = True
|
self.alive = True
|
||||||
state = deepcopy(self.defaults)
|
real_state = deepcopy(self.defaults)
|
||||||
state.update(kwargs.pop('state', {}))
|
real_state.update(state or {})
|
||||||
kwargs['state'] = state
|
self._state = real_state
|
||||||
super().__init__(**kwargs)
|
self.interval = interval
|
||||||
|
|
||||||
if not hasattr(self, 'level'):
|
if not hasattr(self, 'level'):
|
||||||
self.level = logging.DEBUG
|
self.level = logging.DEBUG
|
||||||
self.logger = logging.getLogger('Agent-{}'.format(self.id))
|
self.logger = logging.getLogger('{}-Agent-{}'.format(self.env.name,
|
||||||
|
self.id))
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
|
|
||||||
|
# initialize every time an instance of the agent is created
|
||||||
|
self.action = self.env.process(self.run())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state(self):
|
||||||
|
return self._state
|
||||||
|
|
||||||
|
@state.setter
|
||||||
|
def state(self, value):
|
||||||
|
for k, v in value.items():
|
||||||
|
self[k] = v
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, tuple):
|
if isinstance(key, tuple):
|
||||||
k, t_step = key
|
key, t_step = key
|
||||||
return self.env[self.id, t_step, k]
|
k = history.Key(key=key, t_step=t_step, agent_id=self.id)
|
||||||
|
return self.env[k]
|
||||||
return self.state.get(key, None)
|
return self.state.get(key, None)
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
del self.state[key]
|
self.state[key] = None
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
return key in self.state
|
return key in self.state
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
self.state[key] = value
|
self.state[key] = value
|
||||||
|
k = history.Key(t_step=self.now,
|
||||||
|
agent_id=self.id,
|
||||||
|
key=key)
|
||||||
|
self.env[k] = value
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
return self[key] if key in self else default
|
return self[key] if key in self else default
|
||||||
@@ -72,7 +106,12 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
interval = self.env.interval
|
if self.interval is not None:
|
||||||
|
interval = self.interval
|
||||||
|
elif 'interval' in self:
|
||||||
|
interval = self['interval']
|
||||||
|
else:
|
||||||
|
interval = self.env.interval
|
||||||
while self.alive:
|
while self.alive:
|
||||||
res = self.step()
|
res = self.step()
|
||||||
yield res or self.env.timeout(interval)
|
yield res or self.env.timeout(interval)
|
||||||
@@ -95,7 +134,7 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
|||||||
agents = self.global_topology.nodes()
|
agents = self.global_topology.nodes()
|
||||||
count = 0
|
count = 0
|
||||||
for agent in agents:
|
for agent in agents:
|
||||||
if state_id and state_id != self.global_topology.node[agent]['agent'].state['id']:
|
if state_id and state_id != self.global_topology.node[agent]['agent']['id']:
|
||||||
continue
|
continue
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
@@ -140,20 +179,24 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
|
|||||||
|
|
||||||
|
|
||||||
def state(func):
|
def state(func):
|
||||||
|
'''
|
||||||
|
A state function should return either a state id, or a tuple (state_id, when)
|
||||||
|
The default value for state_id is the current state id.
|
||||||
|
The default value for when is the interval defined in the nevironment.
|
||||||
|
'''
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def func_wrapper(self):
|
def func_wrapper(self):
|
||||||
when = None
|
|
||||||
next_state = func(self)
|
next_state = func(self)
|
||||||
|
when = None
|
||||||
|
if next_state is None:
|
||||||
|
return when
|
||||||
try:
|
try:
|
||||||
next_state, when = next_state
|
next_state, when = next_state
|
||||||
except TypeError:
|
except (ValueError, TypeError):
|
||||||
pass
|
pass
|
||||||
if next_state:
|
if next_state:
|
||||||
try:
|
self.set_state(next_state)
|
||||||
self.state['id'] = next_state.id
|
|
||||||
except AttributeError:
|
|
||||||
raise ValueError('State id %s is not valid.' % next_state)
|
|
||||||
return when
|
return when
|
||||||
|
|
||||||
func_wrapper.id = func.__name__
|
func_wrapper.id = func.__name__
|
||||||
@@ -193,11 +236,13 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(FSM, self).__init__(*args, **kwargs)
|
super(FSM, self).__init__(*args, **kwargs)
|
||||||
if 'id' not in self.state:
|
if 'id' not in self.state:
|
||||||
self.state['id'] = self.default_state.id
|
if not self.default_state:
|
||||||
|
raise ValueError('No default state specified for {}'.format(self.id))
|
||||||
|
self['id'] = self.default_state.id
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
if 'id' in self.state:
|
if 'id' in self.state:
|
||||||
next_state = self.state['id']
|
next_state = self['id']
|
||||||
elif self.default_state:
|
elif self.default_state:
|
||||||
next_state = self.default_state.id
|
next_state = self.default_state.id
|
||||||
else:
|
else:
|
||||||
@@ -211,7 +256,117 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|||||||
state = state.id
|
state = state.id
|
||||||
if state not in self.states:
|
if state not in self.states:
|
||||||
raise ValueError('{} is not a valid state'.format(state))
|
raise ValueError('{} is not a valid state'.format(state))
|
||||||
self.state['id'] = state
|
self['id'] = state
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def prob(prob=1):
|
||||||
|
'''
|
||||||
|
A true/False uniform distribution with a given probability.
|
||||||
|
To be used like this:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
if prob(0.3):
|
||||||
|
do_something()
|
||||||
|
|
||||||
|
'''
|
||||||
|
r = random.random()
|
||||||
|
return r < prob
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_distribution(network_agents=None,
|
||||||
|
agent_type=None):
|
||||||
|
'''
|
||||||
|
Calculate the threshold values (thresholds for a uniform distribution)
|
||||||
|
of an agent distribution given the weights of each agent type.
|
||||||
|
|
||||||
|
The input has this form: ::
|
||||||
|
|
||||||
|
[
|
||||||
|
{'agent_type': 'agent_type_1',
|
||||||
|
'weight': 0.2,
|
||||||
|
'state': {
|
||||||
|
'id': 0
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{'agent_type': 'agent_type_2',
|
||||||
|
'weight': 0.8,
|
||||||
|
'state': {
|
||||||
|
'id': 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
In this example, 20% of the nodes will be marked as type
|
||||||
|
'agent_type_1'.
|
||||||
|
'''
|
||||||
|
if network_agents:
|
||||||
|
network_agents = deepcopy(network_agents)
|
||||||
|
elif agent_type:
|
||||||
|
network_agents = [{'agent_type': agent_type}]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Calculate the thresholds
|
||||||
|
total = sum(x.get('weight', 1) for x in network_agents)
|
||||||
|
acc = 0
|
||||||
|
for v in network_agents:
|
||||||
|
upper = acc + (v.get('weight', 1)/total)
|
||||||
|
v['threshold'] = [acc, upper]
|
||||||
|
acc = upper
|
||||||
|
return network_agents
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_distribution(network_agents):
|
||||||
|
d = _convert_agent_types(network_agents,
|
||||||
|
to_string=True)
|
||||||
|
'''
|
||||||
|
When serializing an agent distribution, remove the thresholds, in order
|
||||||
|
to avoid cluttering the YAML definition file.
|
||||||
|
'''
|
||||||
|
for v in d:
|
||||||
|
if 'threshold' in v:
|
||||||
|
del v['threshold']
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_states(states, topology):
|
||||||
|
'''Validate states to avoid ignoring states during initialization'''
|
||||||
|
states = states or []
|
||||||
|
if isinstance(states, dict):
|
||||||
|
for x in states:
|
||||||
|
assert x in topology.node
|
||||||
|
else:
|
||||||
|
assert len(states) <= len(topology)
|
||||||
|
return states
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_agent_types(ind, to_string=False):
|
||||||
|
'''Convenience method to allow specifying agents by class or class name.'''
|
||||||
|
d = deepcopy(ind)
|
||||||
|
for v in d:
|
||||||
|
agent_type = v['agent_type']
|
||||||
|
if to_string and not isinstance(agent_type, str):
|
||||||
|
v['agent_type'] = str(agent_type.__name__)
|
||||||
|
elif not to_string and isinstance(agent_type, str):
|
||||||
|
v['agent_type'] = agent_types[agent_type]
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def _agent_from_distribution(distribution, value=-1):
|
||||||
|
"""Used in the initialization of agents given an agent distribution."""
|
||||||
|
if value < 0:
|
||||||
|
value = random.random()
|
||||||
|
for d in distribution:
|
||||||
|
threshold = d['threshold']
|
||||||
|
if value >= threshold[0] and value < threshold[1]:
|
||||||
|
state = {}
|
||||||
|
if 'state' in d:
|
||||||
|
state = deepcopy(d['state'])
|
||||||
|
return d['agent_type'], state
|
||||||
|
|
||||||
|
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||||
|
|
||||||
|
|
||||||
from .BassModel import *
|
from .BassModel import *
|
||||||
|
@@ -4,7 +4,7 @@ import glob
|
|||||||
import yaml
|
import yaml
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from . import utils
|
from . import utils, history
|
||||||
|
|
||||||
|
|
||||||
def read_data(*args, group=False, **kwargs):
|
def read_data(*args, group=False, **kwargs):
|
||||||
@@ -15,8 +15,9 @@ def read_data(*args, group=False, **kwargs):
|
|||||||
return list(iterable)
|
return list(iterable)
|
||||||
|
|
||||||
|
|
||||||
def _read_data(pattern, keys=None, convert_types=False,
|
def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
||||||
process=None, from_csv=False, **kwargs):
|
if not process_args:
|
||||||
|
process_args = {}
|
||||||
for folder in glob.glob(pattern):
|
for folder in glob.glob(pattern):
|
||||||
config_file = glob.glob(join(folder, '*.yml'))[0]
|
config_file = glob.glob(join(folder, '*.yml'))[0]
|
||||||
config = yaml.load(open(config_file))
|
config = yaml.load(open(config_file))
|
||||||
@@ -24,19 +25,20 @@ def _read_data(pattern, keys=None, convert_types=False,
|
|||||||
if from_csv:
|
if from_csv:
|
||||||
for trial_data in sorted(glob.glob(join(folder,
|
for trial_data in sorted(glob.glob(join(folder,
|
||||||
'*.environment.csv'))):
|
'*.environment.csv'))):
|
||||||
df = read_csv(trial_data, convert_types=convert_types)
|
df = read_csv(trial_data, **kwargs)
|
||||||
if process:
|
|
||||||
df = process(df, **kwargs)
|
|
||||||
yield config_file, df, config
|
yield config_file, df, config
|
||||||
else:
|
else:
|
||||||
for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
|
for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
|
||||||
df = read_sql(trial_data, convert_types=convert_types,
|
df = read_sql(trial_data, **kwargs)
|
||||||
keys=keys)
|
|
||||||
if process:
|
|
||||||
df = process(df, **kwargs)
|
|
||||||
yield config_file, df, config
|
yield config_file, df, config
|
||||||
|
|
||||||
|
|
||||||
|
def read_sql(db, *args, **kwargs):
|
||||||
|
h = history.History(db, backup=False)
|
||||||
|
df = h.read_sql(*args, **kwargs)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
||||||
'''
|
'''
|
||||||
Read a CSV in canonical form: ::
|
Read a CSV in canonical form: ::
|
||||||
@@ -49,18 +51,7 @@ def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
|||||||
df = convert_types_slow(df)
|
df = convert_types_slow(df)
|
||||||
if keys:
|
if keys:
|
||||||
df = df[df['key'].isin(keys)]
|
df = df[df['key'].isin(keys)]
|
||||||
return df
|
df = process_one(df)
|
||||||
|
|
||||||
|
|
||||||
def read_sql(filename, keys=None, convert_types=False, limit=-1):
|
|
||||||
condition = ''
|
|
||||||
if keys:
|
|
||||||
k = map(lambda x: "\'{}\'".format(x), keys)
|
|
||||||
condition = 'where key in ({})'.format(','.join(k))
|
|
||||||
query = 'select * from history {} limit {}'.format(condition, limit)
|
|
||||||
df = pd.read_sql_query(query, 'sqlite:///{}'.format(filename))
|
|
||||||
if convert_types:
|
|
||||||
df = convert_types_slow(df)
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
@@ -108,8 +99,9 @@ def get_types(df):
|
|||||||
return {k:v[0] for k,v in dtypes.iteritems()}
|
return {k:v[0] for k,v in dtypes.iteritems()}
|
||||||
|
|
||||||
|
|
||||||
def process_one(df, *keys, columns=['key'], values='value',
|
def process_one(df, *keys, columns=['key', 'agent_id'], values='value',
|
||||||
index=['t_step', 'agent_id'], aggfunc='first', **kwargs):
|
fill=True, index=['t_step',],
|
||||||
|
aggfunc='first', **kwargs):
|
||||||
'''
|
'''
|
||||||
Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into
|
Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into
|
||||||
a dataframe with a column per key
|
a dataframe with a column per key
|
||||||
@@ -119,35 +111,29 @@ def process_one(df, *keys, columns=['key'], values='value',
|
|||||||
if keys:
|
if keys:
|
||||||
df = df[df['key'].isin(keys)]
|
df = df[df['key'].isin(keys)]
|
||||||
|
|
||||||
dtypes = get_types(df)
|
|
||||||
|
|
||||||
df = df.pivot_table(values=values, index=index, columns=columns,
|
df = df.pivot_table(values=values, index=index, columns=columns,
|
||||||
aggfunc=aggfunc, **kwargs)
|
aggfunc=aggfunc, **kwargs)
|
||||||
df = df.fillna(0).astype(dtypes)
|
if fill:
|
||||||
|
df = fillna(df)
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
def get_count_processed(df, *keys):
|
|
||||||
if keys:
|
|
||||||
df = df[list(keys)]
|
|
||||||
# p = df.groupby(level=0).apply(pd.Series.value_counts)
|
|
||||||
p = df.unstack().apply(pd.Series.value_counts, axis=1)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def get_count(df, *keys):
|
def get_count(df, *keys):
|
||||||
if keys:
|
if keys:
|
||||||
df = df[df['key'].isin(keys)]
|
df = df[list(keys)]
|
||||||
p = df.groupby(by=['t_step', 'key', 'value']).size().unstack(level=[1,2]).fillna(0)
|
counts = pd.DataFrame()
|
||||||
return p
|
for key in df.columns.levels[0]:
|
||||||
|
g = df[key].apply(pd.Series.value_counts, axis=1).fillna(0)
|
||||||
|
for value, series in g.iteritems():
|
||||||
|
counts[key, value] = series
|
||||||
|
counts.columns = pd.MultiIndex.from_tuples(counts.columns)
|
||||||
|
return counts
|
||||||
|
|
||||||
|
|
||||||
def get_value(df, *keys, aggfunc='sum'):
|
def get_value(df, *keys, aggfunc='sum'):
|
||||||
if keys:
|
if keys:
|
||||||
df = df[df['key'].isin(keys)]
|
df = df[list(keys)]
|
||||||
p = process_one(df, *keys)
|
return df.groupby(axis=1, level=0).agg(aggfunc, axis=1)
|
||||||
p = p.groupby(level='t_step').agg(aggfunc)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def plot_all(*args, **kwargs):
|
def plot_all(*args, **kwargs):
|
||||||
@@ -175,4 +161,6 @@ def group_trials(trials, aggfunc=['mean', 'min', 'max', 'std']):
|
|||||||
return pd.concat(trials).groupby(level=0).agg(aggfunc).reorder_levels([2, 0,1] ,axis=1)
|
return pd.concat(trials).groupby(level=0).agg(aggfunc).reorder_levels([2, 0,1] ,axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
def fillna(df):
|
||||||
|
new_df = df.ffill(axis=0)
|
||||||
|
return new_df
|
||||||
|
@@ -1,19 +1,30 @@
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import time
|
import time
|
||||||
import weakref
|
|
||||||
import csv
|
import csv
|
||||||
import random
|
import random
|
||||||
import simpy
|
import simpy
|
||||||
|
import tempfile
|
||||||
|
import pandas as pd
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from networkx.readwrite import json_graph
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import nxsim
|
import nxsim
|
||||||
|
|
||||||
from . import utils
|
from . import utils, agents, analysis, history
|
||||||
|
|
||||||
|
|
||||||
class SoilEnvironment(nxsim.NetworkEnvironment):
|
class SoilEnvironment(nxsim.NetworkEnvironment):
|
||||||
|
"""
|
||||||
|
The environment is key in a simulation. It contains the network topology,
|
||||||
|
a reference to network and environment agents, as well as the environment
|
||||||
|
params, which are used as shared state between agents.
|
||||||
|
|
||||||
|
The environment parameters and the state of every agent can be accessed
|
||||||
|
both by using the environment as a dictionary or with the environment's
|
||||||
|
:meth:`soil.environment.SoilEnvironment.get` method.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, name=None,
|
def __init__(self, name=None,
|
||||||
network_agents=None,
|
network_agents=None,
|
||||||
@@ -22,42 +33,31 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
default_state=None,
|
default_state=None,
|
||||||
interval=1,
|
interval=1,
|
||||||
seed=None,
|
seed=None,
|
||||||
dump=False,
|
dry_run=False,
|
||||||
simulation=None,
|
dir_path=None,
|
||||||
|
topology=None,
|
||||||
*args, **kwargs):
|
*args, **kwargs):
|
||||||
self.name = name or 'UnnamedEnvironment'
|
self.name = name or 'UnnamedEnvironment'
|
||||||
if isinstance(states, list):
|
if isinstance(states, list):
|
||||||
states = dict(enumerate(states))
|
states = dict(enumerate(states))
|
||||||
self.states = deepcopy(states) if states else {}
|
self.states = deepcopy(states) if states else {}
|
||||||
self.default_state = deepcopy(default_state) or {}
|
self.default_state = deepcopy(default_state) or {}
|
||||||
self.sim = weakref.ref(simulation)
|
if not topology:
|
||||||
if 'topology' not in kwargs and simulation:
|
topology = nx.Graph()
|
||||||
kwargs['topology'] = self.sim().topology.copy()
|
super().__init__(*args, topology=topology, **kwargs)
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._env_agents = {}
|
self._env_agents = {}
|
||||||
|
self.dry_run = dry_run
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.dump = dump
|
self.dir_path = dir_path or tempfile.mkdtemp('soil-env')
|
||||||
|
self.get_path()
|
||||||
|
self._history = history.History(name=self.name if not dry_run else None,
|
||||||
|
dir_path=self.dir_path)
|
||||||
# Add environment agents first, so their events get
|
# Add environment agents first, so their events get
|
||||||
# executed before network agents
|
# executed before network agents
|
||||||
self['SEED'] = seed or time.time()
|
|
||||||
random.seed(self['SEED'])
|
|
||||||
self.process(self.save_state())
|
|
||||||
self.environment_agents = environment_agents or []
|
self.environment_agents = environment_agents or []
|
||||||
self.network_agents = network_agents or []
|
self.network_agents = network_agents or []
|
||||||
if self.dump:
|
self['SEED'] = seed or time.time()
|
||||||
self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name))
|
random.seed(self['SEED'])
|
||||||
else:
|
|
||||||
self._db_path = ":memory:"
|
|
||||||
self.create_db(self._db_path)
|
|
||||||
|
|
||||||
def create_db(self, db_path=None):
|
|
||||||
db_path = db_path or self._db_path
|
|
||||||
if os.path.exists(db_path):
|
|
||||||
newname = db_path.replace('db.sqlite', 'backup{}.sqlite'.format(time.time()))
|
|
||||||
os.rename(db_path, newname)
|
|
||||||
self._db = sqlite3.connect(db_path)
|
|
||||||
with self._db:
|
|
||||||
self._db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text, value_type text)''')
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agents(self):
|
def agents(self):
|
||||||
@@ -90,11 +90,11 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
|
|
||||||
@network_agents.setter
|
@network_agents.setter
|
||||||
def network_agents(self, network_agents):
|
def network_agents(self, network_agents):
|
||||||
|
if not network_agents:
|
||||||
|
return
|
||||||
for ix in self.G.nodes():
|
for ix in self.G.nodes():
|
||||||
i = ix
|
agent, state = agents._agent_from_distribution(network_agents)
|
||||||
node = self.G.node[i]
|
self.set_agent(ix, agent_type=agent, state=state)
|
||||||
agent, state = utils.agent_from_distribution(network_agents)
|
|
||||||
self.set_agent(i, agent_type=agent, state=state)
|
|
||||||
|
|
||||||
def set_agent(self, agent_id, agent_type, state=None):
|
def set_agent(self, agent_id, agent_type, state=None):
|
||||||
node = self.G.nodes[agent_id]
|
node = self.G.nodes[agent_id]
|
||||||
@@ -121,16 +121,21 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
return self.G.add_edge(agent1, agent2)
|
return self.G.add_edge(agent1, agent2)
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
|
self._save_state()
|
||||||
super().run(*args, **kwargs)
|
super().run(*args, **kwargs)
|
||||||
|
self._history.flush_cache()
|
||||||
|
|
||||||
def _save_state(self, now=None):
|
def _save_state(self, now=None):
|
||||||
# for agent in self.agents:
|
# for agent in self.agents:
|
||||||
# agent.save_state()
|
# agent.save_state()
|
||||||
utils.logger.debug('Saving state @{}'.format(self.now))
|
utils.logger.debug('Saving state @{}'.format(self.now))
|
||||||
with self._db:
|
self._history.save_records(self.state_to_tuples(now=now))
|
||||||
self._db.executemany("insert into history(agent_id, t_step, key, value, value_type) values (?, ?, ?, ?, ?)", self.state_to_tuples(now=now))
|
|
||||||
|
|
||||||
def save_state(self):
|
def save_state(self):
|
||||||
|
'''
|
||||||
|
:DEPRECATED:
|
||||||
|
Periodically save the state of the environment and the agents.
|
||||||
|
'''
|
||||||
self._save_state()
|
self._save_state()
|
||||||
while self.peek() != simpy.core.Infinity:
|
while self.peek() != simpy.core.Infinity:
|
||||||
delay = max(self.peek() - self.now, self.interval)
|
delay = max(self.peek() - self.now, self.interval)
|
||||||
@@ -145,64 +150,44 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, tuple):
|
if isinstance(key, tuple):
|
||||||
values = [("agent_id", key[0]),
|
self._history.flush_cache()
|
||||||
("t_step", key[1]),
|
return self._history[key]
|
||||||
("key", key[2]),
|
|
||||||
("value", None),
|
|
||||||
("value_type", None)]
|
|
||||||
fields = list(k for k, v in values if v is None)
|
|
||||||
conditions = " and ".join("{}='{}'".format(k, v) for k, v in values if v is not None)
|
|
||||||
|
|
||||||
query = """SELECT {fields} from history""".format(fields=",".join(fields))
|
|
||||||
if conditions:
|
|
||||||
query = """{query} where {conditions}""".format(query=query,
|
|
||||||
conditions=conditions)
|
|
||||||
with self._db:
|
|
||||||
rows = self._db.execute(query).fetchall()
|
|
||||||
|
|
||||||
utils.logger.debug(rows)
|
|
||||||
results = self.rows_to_dict(rows)
|
|
||||||
return results
|
|
||||||
|
|
||||||
return self.environment_params[key]
|
return self.environment_params[key]
|
||||||
|
|
||||||
def rows_to_dict(self, rows):
|
|
||||||
if len(rows) < 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
level = len(rows[0])-2
|
|
||||||
|
|
||||||
if level == 0:
|
|
||||||
if len(rows) != 1:
|
|
||||||
raise ValueError('Cannot convert {} to dictionaries'.format(rows))
|
|
||||||
value, value_type = rows[0]
|
|
||||||
return utils.convert(value, value_type)
|
|
||||||
|
|
||||||
results = {}
|
|
||||||
for row in rows:
|
|
||||||
item = results
|
|
||||||
for i in range(level-1):
|
|
||||||
key = row[i]
|
|
||||||
if key not in item:
|
|
||||||
item[key] = {}
|
|
||||||
item = item[key]
|
|
||||||
key, value, value_type = row[level-1:]
|
|
||||||
item[key] = utils.convert(value, value_type)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
|
if isinstance(key, tuple):
|
||||||
|
k = history.Key(*key)
|
||||||
|
self._history.save_record(*k,
|
||||||
|
value=value)
|
||||||
|
return
|
||||||
self.environment_params[key] = value
|
self.environment_params[key] = value
|
||||||
|
self._history.save_record(agent_id='env',
|
||||||
|
t_step=self.now,
|
||||||
|
key=key,
|
||||||
|
value=value)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
return key in self.environment_params
|
return key in self.environment_params
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
|
'''
|
||||||
|
Get the value of an environment attribute in a
|
||||||
|
given point in the simulation (history).
|
||||||
|
If key is an attribute name, this method returns
|
||||||
|
the current value.
|
||||||
|
To get values at other times, use a
|
||||||
|
:meth: `soil.history.Key` tuple.
|
||||||
|
'''
|
||||||
return self[key] if key in self else default
|
return self[key] if key in self else default
|
||||||
|
|
||||||
def get_path(self, dir_path=None):
|
def get_path(self, dir_path=None):
|
||||||
dir_path = dir_path or self.sim().dir_path
|
dir_path = dir_path or self.dir_path
|
||||||
if not os.path.exists(dir_path):
|
if not os.path.exists(dir_path):
|
||||||
os.makedirs(dir_path)
|
try:
|
||||||
|
os.makedirs(dir_path)
|
||||||
|
except FileExistsError:
|
||||||
|
pass
|
||||||
return dir_path
|
return dir_path
|
||||||
|
|
||||||
def get_agent(self, agent_id):
|
def get_agent(self, agent_id):
|
||||||
@@ -225,23 +210,45 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
G = self.history_to_graph()
|
G = self.history_to_graph()
|
||||||
graph_path = os.path.join(self.get_path(dir_path),
|
graph_path = os.path.join(self.get_path(dir_path),
|
||||||
self.name+".gexf")
|
self.name+".gexf")
|
||||||
|
# Workaround for geometric models
|
||||||
|
# See soil/soil#4
|
||||||
|
for node in G.nodes():
|
||||||
|
if 'pos' in G.node[node]:
|
||||||
|
G.node[node]['viz'] = {"position": {"x": G.node[node]['pos'][0], "y": G.node[node]['pos'][1], "z": 0.0}}
|
||||||
|
del (G.node[node]['pos'])
|
||||||
|
|
||||||
nx.write_gexf(G, graph_path, version="1.2draft")
|
nx.write_gexf(G, graph_path, version="1.2draft")
|
||||||
|
|
||||||
|
def dump(self, dir_path=None, formats=None):
|
||||||
|
if not formats:
|
||||||
|
return
|
||||||
|
functions = {
|
||||||
|
'csv': self.dump_csv,
|
||||||
|
'gexf': self.dump_gexf
|
||||||
|
}
|
||||||
|
for f in formats:
|
||||||
|
if f in functions:
|
||||||
|
functions[f](dir_path)
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown format: {}'.format(f))
|
||||||
|
|
||||||
def state_to_tuples(self, now=None):
|
def state_to_tuples(self, now=None):
|
||||||
if now is None:
|
if now is None:
|
||||||
now = self.now
|
now = self.now
|
||||||
for k, v in self.environment_params.items():
|
for k, v in self.environment_params.items():
|
||||||
v, v_t = utils.repr(v)
|
yield history.Record(agent_id='env',
|
||||||
yield 'env', now, k, v, v_t
|
t_step=now,
|
||||||
|
key=k,
|
||||||
|
value=v)
|
||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
for k, v in agent.state.items():
|
for k, v in agent.state.items():
|
||||||
v, v_t = utils.repr(v)
|
yield history.Record(agent_id=agent.id,
|
||||||
yield agent.id, now, k, v, v_t
|
t_step=now,
|
||||||
|
key=k,
|
||||||
|
value=v)
|
||||||
|
|
||||||
def history_to_tuples(self):
|
def history_to_tuples(self):
|
||||||
with self._db:
|
return self._history.to_tuples()
|
||||||
res = self._db.execute("select agent_id, t_step, key, value, value_type from history ").fetchall()
|
|
||||||
yield from res
|
|
||||||
|
|
||||||
def history_to_graph(self):
|
def history_to_graph(self):
|
||||||
G = nx.Graph(self.G)
|
G = nx.Graph(self.G)
|
||||||
@@ -289,3 +296,19 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
|
|||||||
G.add_node(agent.id, **attributes)
|
G.add_node(agent.id, **attributes)
|
||||||
|
|
||||||
return G
|
return G
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state['G'] = json_graph.node_link_data(self.G)
|
||||||
|
state['network_agents'] = agents._serialize_distribution(self.network_agents)
|
||||||
|
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||||
|
to_string=True)
|
||||||
|
del state['_queue']
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__ = state
|
||||||
|
self.G = json_graph.node_link_graph(state['G'])
|
||||||
|
self.network_agents = self.calculate_distribution(self._convert_agent_types(self.network_agents))
|
||||||
|
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
||||||
|
return state
|
||||||
|
231
soil/history.py
Normal file
231
soil/history.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
import time
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import sqlite3
|
||||||
|
import copy
|
||||||
|
from collections import UserDict, Iterable, namedtuple
|
||||||
|
|
||||||
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
|
class History:
|
||||||
|
"""
|
||||||
|
Store and retrieve values from a sqlite database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, db_path=None, name=None, dir_path=None, backup=True):
|
||||||
|
if db_path is None and name:
|
||||||
|
db_path = os.path.join(dir_path or os.getcwd(),
|
||||||
|
'{}.db.sqlite'.format(name))
|
||||||
|
if db_path is None:
|
||||||
|
db_path = ":memory:"
|
||||||
|
else:
|
||||||
|
if backup and os.path.exists(db_path):
|
||||||
|
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
||||||
|
os.rename(db_path, newname)
|
||||||
|
self._db_path = db_path
|
||||||
|
if isinstance(db_path, str):
|
||||||
|
self._db = sqlite3.connect(db_path)
|
||||||
|
else:
|
||||||
|
self._db = db_path
|
||||||
|
|
||||||
|
with self._db:
|
||||||
|
self._db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
|
||||||
|
self._db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
|
||||||
|
self._db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
|
||||||
|
self._dtypes = {}
|
||||||
|
self._tups = []
|
||||||
|
|
||||||
|
def conversors(self, key):
|
||||||
|
"""Get the serializer and deserializer for a given key."""
|
||||||
|
if key not in self._dtypes:
|
||||||
|
self.read_types()
|
||||||
|
return self._dtypes[key]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtypes(self):
|
||||||
|
return {k:v[0] for k, v in self._dtypes.items()}
|
||||||
|
|
||||||
|
def save_tuples(self, tuples):
|
||||||
|
self.save_records(Record(*tup) for tup in tuples)
|
||||||
|
|
||||||
|
def save_records(self, records):
|
||||||
|
with self._db:
|
||||||
|
for rec in records:
|
||||||
|
if not isinstance(rec, Record):
|
||||||
|
rec = Record(*rec)
|
||||||
|
if rec.key not in self._dtypes:
|
||||||
|
name = utils.name(rec.value)
|
||||||
|
serializer = utils.serializer(name)
|
||||||
|
deserializer = utils.deserializer(name)
|
||||||
|
self._dtypes[rec.key] = (name, serializer, deserializer)
|
||||||
|
self._db.execute("replace into value_types (key, value_type) values (?, ?)", (rec.key, name))
|
||||||
|
self._db.execute("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", (rec.agent_id, rec.t_step, rec.key, rec.value))
|
||||||
|
|
||||||
|
def save_record(self, *args, **kwargs):
|
||||||
|
self._tups.append(Record(*args, **kwargs))
|
||||||
|
if len(self._tups) > 100:
|
||||||
|
self.flush_cache()
|
||||||
|
|
||||||
|
def flush_cache(self):
|
||||||
|
'''
|
||||||
|
Use a cache to save state changes to avoid opening a session for every change.
|
||||||
|
The cache will be flushed at the end of the simulation, and when history is accessed.
|
||||||
|
'''
|
||||||
|
self.save_records(self._tups)
|
||||||
|
self._tups = list()
|
||||||
|
|
||||||
|
def to_tuples(self):
|
||||||
|
self.flush_cache()
|
||||||
|
with self._db:
|
||||||
|
res = self._db.execute("select agent_id, t_step, key, value from history ").fetchall()
|
||||||
|
for r in res:
|
||||||
|
agent_id, t_step, key, value = r
|
||||||
|
_, _ , des = self.conversors(key)
|
||||||
|
yield agent_id, t_step, key, des(value)
|
||||||
|
|
||||||
|
def read_types(self):
|
||||||
|
with self._db:
|
||||||
|
res = self._db.execute("select key, value_type from value_types ").fetchall()
|
||||||
|
for k, v in res:
|
||||||
|
serializer = utils.serializer(v)
|
||||||
|
deserializer = utils.deserializer(v)
|
||||||
|
self._dtypes[k] = (v, serializer, deserializer)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
key = Key(*key)
|
||||||
|
agent_ids = [key.agent_id] if key.agent_id is not None else []
|
||||||
|
t_steps = [key.t_step] if key.t_step is not None else []
|
||||||
|
keys = [key.key] if key.key is not None else []
|
||||||
|
|
||||||
|
df = self.read_sql(agent_ids=agent_ids,
|
||||||
|
t_steps=t_steps,
|
||||||
|
keys=keys)
|
||||||
|
r = Records(df, filter=key, dtypes=self._dtypes)
|
||||||
|
return r.value()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
||||||
|
|
||||||
|
self.read_types()
|
||||||
|
|
||||||
|
def escape_and_join(v):
|
||||||
|
if v is None:
|
||||||
|
return
|
||||||
|
return ",".join(map(lambda x: "\'{}\'".format(x), v))
|
||||||
|
|
||||||
|
filters = [("key in ({})".format(escape_and_join(keys)), keys),
|
||||||
|
("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids)
|
||||||
|
]
|
||||||
|
filters = list(k[0] for k in filters if k[1])
|
||||||
|
|
||||||
|
last_df = None
|
||||||
|
if t_steps:
|
||||||
|
# Look for the last value before the minimum step in the query
|
||||||
|
min_step = min(t_steps)
|
||||||
|
last_filters = ['t_step < {}'.format(min_step),]
|
||||||
|
last_filters = last_filters + filters
|
||||||
|
condition = ' and '.join(last_filters)
|
||||||
|
|
||||||
|
last_query = '''
|
||||||
|
select h1.*
|
||||||
|
from history h1
|
||||||
|
inner join (
|
||||||
|
select agent_id, key, max(t_step) as t_step
|
||||||
|
from history
|
||||||
|
where {condition}
|
||||||
|
group by agent_id, key
|
||||||
|
) h2
|
||||||
|
on h1.agent_id = h2.agent_id and
|
||||||
|
h1.key = h2.key and
|
||||||
|
h1.t_step = h2.t_step
|
||||||
|
'''.format(condition=condition)
|
||||||
|
last_df = pd.read_sql_query(last_query, self._db)
|
||||||
|
|
||||||
|
filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps)))
|
||||||
|
|
||||||
|
condition = ''
|
||||||
|
if filters:
|
||||||
|
condition = 'where {} '.format(' and '.join(filters))
|
||||||
|
query = 'select * from history {} limit {}'.format(condition, limit)
|
||||||
|
df = pd.read_sql_query(query, self._db)
|
||||||
|
if last_df is not None:
|
||||||
|
df = pd.concat([df, last_df])
|
||||||
|
|
||||||
|
df_p = df.pivot_table(values='value', index=['t_step'],
|
||||||
|
columns=['key', 'agent_id'],
|
||||||
|
aggfunc='first')
|
||||||
|
|
||||||
|
for k, v in self._dtypes.items():
|
||||||
|
if k in df_p:
|
||||||
|
dtype, _, deserial = v
|
||||||
|
df_p[k] = df_p[k].fillna(method='ffill').fillna(deserial()).astype(dtype)
|
||||||
|
if t_steps:
|
||||||
|
df_p = df_p.reindex(t_steps, method='ffill')
|
||||||
|
return df_p.ffill()
|
||||||
|
|
||||||
|
|
||||||
|
class Records():
|
||||||
|
|
||||||
|
def __init__(self, df, filter=None, dtypes=None):
|
||||||
|
if not filter:
|
||||||
|
filter = Key(agent_id=None,
|
||||||
|
t_step=None,
|
||||||
|
key=None)
|
||||||
|
self._df = df
|
||||||
|
self._filter = filter
|
||||||
|
self.dtypes = dtypes or {}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def mask(self, tup):
|
||||||
|
res = ()
|
||||||
|
for i, k in zip(tup[:-1], self._filter):
|
||||||
|
if k is None:
|
||||||
|
res = res + (i,)
|
||||||
|
res = res + (tup[-1],)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def filter(self, newKey):
|
||||||
|
f = list(self._filter)
|
||||||
|
for ix, i in enumerate(f):
|
||||||
|
if i is None:
|
||||||
|
f[ix] = newKey
|
||||||
|
self._filter = Key(*f)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def resolved(self):
|
||||||
|
return sum(1 for i in self._filter if i is not None) == 3
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for column, series in self._df.iteritems():
|
||||||
|
key, agent_id = column
|
||||||
|
for t_step, value in series.iteritems():
|
||||||
|
r = Record(t_step=t_step,
|
||||||
|
agent_id=agent_id,
|
||||||
|
key=key,
|
||||||
|
value=value)
|
||||||
|
yield self.mask(r)
|
||||||
|
|
||||||
|
def value(self):
|
||||||
|
if self.resolved:
|
||||||
|
f = self._filter
|
||||||
|
try:
|
||||||
|
i = self._df[f.key][str(f.agent_id)]
|
||||||
|
ix = i.index.get_loc(f.t_step, method='ffill')
|
||||||
|
return i.iloc[ix]
|
||||||
|
except KeyError:
|
||||||
|
return self.dtypes[f.key][2]()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __getitem__(self, k):
|
||||||
|
n = copy.copy(self)
|
||||||
|
n.filter(k)
|
||||||
|
return n.value()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._df)
|
||||||
|
|
||||||
|
|
||||||
|
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
||||||
|
Record = namedtuple('Record', 'agent_id t_step key value')
|
@@ -5,14 +5,14 @@ import sys
|
|||||||
import yaml
|
import yaml
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
|
from multiprocessing import Pool
|
||||||
from copy import deepcopy
|
from functools import partial
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
from nxsim import NetworkSimulation
|
from nxsim import NetworkSimulation
|
||||||
|
|
||||||
from . import agents, utils, environment, basestring
|
from . import utils, environment, basestring, agents
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
|
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
"""
|
"""
|
||||||
Subclass of nsim.NetworkSimulation with three main differences:
|
Subclass of nsim.NetworkSimulation with three main differences:
|
||||||
1) agent type can be specified by name or by class.
|
1) agent type can be specified by name or by class.
|
||||||
2) instead of just one type, an network_agents can be used.
|
2) instead of just one type, a network agents distribution can be used.
|
||||||
The distribution specifies the weight (or probability) of each
|
The distribution specifies the weight (or probability) of each
|
||||||
agent type in the topology. This is an example distribution: ::
|
agent type in the topology. This is an example distribution: ::
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, name=None, topology=None, network_params=None,
|
def __init__(self, name=None, topology=None, network_params=None,
|
||||||
network_agents=None, agent_type=None, states=None,
|
network_agents=None, agent_type=None, states=None,
|
||||||
default_state=None, interval=1, dump=False,
|
default_state=None, interval=1, dump=None, dry_run=False,
|
||||||
dir_path=None, num_trials=1, max_time=100,
|
dir_path=None, num_trials=1, max_time=100,
|
||||||
agent_module=None, load_module=None, seed=None,
|
agent_module=None, load_module=None, seed=None,
|
||||||
environment_agents=None, environment_params=None):
|
environment_agents=None, environment_params=None):
|
||||||
@@ -57,7 +57,6 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
||||||
topology = json_graph.node_link_graph(topology)
|
topology = json_graph.node_link_graph(topology)
|
||||||
|
|
||||||
|
|
||||||
self.load_module = load_module
|
self.load_module = load_module
|
||||||
self.topology = nx.Graph(topology)
|
self.topology = nx.Graph(topology)
|
||||||
self.network_params = network_params
|
self.network_params = network_params
|
||||||
@@ -69,94 +68,70 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.seed = str(seed) or str(time.time())
|
self.seed = str(seed) or str(time.time())
|
||||||
self.dump = dump
|
self.dump = dump
|
||||||
|
self.dry_run = dry_run
|
||||||
self.environment_params = environment_params or {}
|
self.environment_params = environment_params or {}
|
||||||
|
|
||||||
if load_module:
|
if load_module:
|
||||||
path = sys.path + [self.dir_path]
|
path = sys.path + [self.dir_path, os.getcwd()]
|
||||||
f, fp, desc = imp.find_module(load_module, path)
|
f, fp, desc = imp.find_module(load_module, path)
|
||||||
imp.load_module('soil.agents.custom', f, fp, desc)
|
imp.load_module('soil.agents.custom', f, fp, desc)
|
||||||
|
|
||||||
environment_agents = environment_agents or []
|
environment_agents = environment_agents or []
|
||||||
self.environment_agents = self._convert_agent_types(environment_agents)
|
self.environment_agents = agents._convert_agent_types(environment_agents)
|
||||||
|
|
||||||
distro = self.calculate_distribution(network_agents,
|
distro = agents.calculate_distribution(network_agents,
|
||||||
agent_type)
|
agent_type)
|
||||||
self.network_agents = self._convert_agent_types(distro)
|
self.network_agents = agents._convert_agent_types(distro)
|
||||||
|
|
||||||
self.states = self.validate_states(states,
|
self.states = agents._validate_states(states,
|
||||||
self.topology)
|
self.topology)
|
||||||
|
|
||||||
def calculate_distribution(self,
|
def run_simulation(self, *args, **kwargs):
|
||||||
network_agents=None,
|
return self.run(*args, **kwargs)
|
||||||
agent_type=None):
|
|
||||||
if network_agents:
|
|
||||||
network_agents = deepcopy(network_agents)
|
|
||||||
elif agent_type:
|
|
||||||
network_agents = [{'agent_type': agent_type}]
|
|
||||||
else:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Calculate the thresholds
|
def run(self, *args, **kwargs):
|
||||||
total = sum(x.get('weight', 1) for x in network_agents)
|
return list(self.run_simulation_gen(*args, **kwargs))
|
||||||
acc = 0
|
|
||||||
for v in network_agents:
|
|
||||||
upper = acc + (v.get('weight', 1)/total)
|
|
||||||
v['threshold'] = [acc, upper]
|
|
||||||
acc = upper
|
|
||||||
return network_agents
|
|
||||||
|
|
||||||
def serialize_distribution(self):
|
def run_simulation_gen(self, *args, parallel=False, dry_run=False,
|
||||||
d = self._convert_agent_types(self.network_agents,
|
**kwargs):
|
||||||
to_string=True)
|
p = Pool()
|
||||||
for v in d:
|
with utils.timer('simulation {}'.format(self.name)):
|
||||||
if 'threshold' in v:
|
if parallel:
|
||||||
del v['threshold']
|
func = partial(self.run_trial, dry_run=dry_run or self.dry_run,
|
||||||
return d
|
return_env=not parallel, **kwargs)
|
||||||
|
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||||
def _convert_agent_types(self, ind, to_string=False):
|
yield i
|
||||||
d = deepcopy(ind)
|
else:
|
||||||
for v in d:
|
for i in range(self.num_trials):
|
||||||
agent_type = v['agent_type']
|
yield self.run_trial(i, dry_run=dry_run or self.dry_run, **kwargs)
|
||||||
if to_string and not isinstance(agent_type, str):
|
if not (dry_run or self.dry_run):
|
||||||
v['agent_type'] = str(agent_type.__name__)
|
|
||||||
elif not to_string and isinstance(agent_type, str):
|
|
||||||
v['agent_type'] = agents.agent_types[agent_type]
|
|
||||||
return d
|
|
||||||
|
|
||||||
def validate_states(self, states, topology):
|
|
||||||
states = states or []
|
|
||||||
# Validate states to avoid ignoring states during
|
|
||||||
# initialization
|
|
||||||
if isinstance(states, dict):
|
|
||||||
for x in states:
|
|
||||||
assert x in self.topology.node
|
|
||||||
else:
|
|
||||||
assert len(states) <= len(self.topology)
|
|
||||||
return states
|
|
||||||
|
|
||||||
def run_simulation(self):
|
|
||||||
return self.run()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
return list(self.run_simulation_gen())
|
|
||||||
|
|
||||||
def run_simulation_gen(self, *args, **kwargs):
|
|
||||||
with utils.timer('simulation'):
|
|
||||||
for i in range(self.num_trials):
|
|
||||||
res = self.run_trial(i)
|
|
||||||
if self.dump:
|
|
||||||
res.dump_gexf(self.dir_path)
|
|
||||||
res.dump_csv(self.dir_path)
|
|
||||||
yield res
|
|
||||||
|
|
||||||
if self.dump:
|
|
||||||
logger.info('Dumping results to {}'.format(self.dir_path))
|
logger.info('Dumping results to {}'.format(self.dir_path))
|
||||||
self.dump_pickle(self.dir_path)
|
self.dump_pickle(self.dir_path)
|
||||||
self.dump_yaml(self.dir_path)
|
self.dump_yaml(self.dir_path)
|
||||||
else:
|
else:
|
||||||
logger.info('NOT dumping results')
|
logger.info('NOT dumping results')
|
||||||
|
|
||||||
def run_trial(self, trial_id=0, dump=False, dir_path=None):
|
def get_env(self, trial_id=0, **kwargs):
|
||||||
|
opts = self.environment_params.copy()
|
||||||
|
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||||
|
opts.update({
|
||||||
|
'name': env_name,
|
||||||
|
'topology': self.topology.copy(),
|
||||||
|
'seed': self.seed+env_name,
|
||||||
|
'initial_time': 0,
|
||||||
|
'dry_run': self.dry_run,
|
||||||
|
'interval': self.interval,
|
||||||
|
'network_agents': self.network_agents,
|
||||||
|
'states': self.states,
|
||||||
|
'default_state': self.default_state,
|
||||||
|
'environment_agents': self.environment_agents,
|
||||||
|
'dir_path': self.dir_path,
|
||||||
|
})
|
||||||
|
opts.update(kwargs)
|
||||||
|
env = environment.SoilEnvironment(**opts)
|
||||||
|
return env
|
||||||
|
|
||||||
|
def run_trial(self, trial_id=0, until=None, return_env=True, **opts):
|
||||||
"""Run a single trial of the simulation
|
"""Run a single trial of the simulation
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -164,25 +139,16 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
trial_id : int
|
trial_id : int
|
||||||
"""
|
"""
|
||||||
# Set-up trial environment and graph
|
# Set-up trial environment and graph
|
||||||
logger.info('Trial: {}'.format(trial_id))
|
until = until or self.max_time
|
||||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
env = self.get_env(trial_id=trial_id, **opts)
|
||||||
env = environment.SoilEnvironment(name=env_name,
|
|
||||||
topology=self.topology.copy(),
|
|
||||||
seed=self.seed+env_name,
|
|
||||||
initial_time=0,
|
|
||||||
dump=self.dump,
|
|
||||||
interval=self.interval,
|
|
||||||
network_agents=self.network_agents,
|
|
||||||
states=self.states,
|
|
||||||
default_state=self.default_state,
|
|
||||||
environment_agents=self.environment_agents,
|
|
||||||
simulation=self,
|
|
||||||
**self.environment_params)
|
|
||||||
# Set up agents on nodes
|
# Set up agents on nodes
|
||||||
logger.info('\tRunning')
|
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
with utils.timer('trial'):
|
env.run(until)
|
||||||
env.run(until=self.max_time)
|
if self.dump and not self.dry_run:
|
||||||
return env
|
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
|
env.dump(formats=self.dump)
|
||||||
|
if return_env:
|
||||||
|
return env
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return self.__getstate__()
|
return self.__getstate__()
|
||||||
@@ -213,20 +179,20 @@ class SoilSimulation(NetworkSimulation):
|
|||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = self.__dict__.copy()
|
state = self.__dict__.copy()
|
||||||
state['topology'] = json_graph.node_link_data(self.topology)
|
state['topology'] = json_graph.node_link_data(self.topology)
|
||||||
state['network_agents'] = self.serialize_distribution()
|
state['network_agents'] = agents._serialize_distribution(self.network_agents)
|
||||||
state['environment_agents'] = self._convert_agent_types(self.environment_agents,
|
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
||||||
to_string=True)
|
to_string=True)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
self.__dict__ = state
|
self.__dict__ = state
|
||||||
self.topology = json_graph.node_link_graph(state['topology'])
|
self.topology = json_graph.node_link_graph(state['topology'])
|
||||||
self.network_agents = self._convert_agent_types(self.network_agents)
|
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
||||||
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
self.environment_agents = agents._convert_agent_types(self.environment_agents)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
def from_config(config, G=None):
|
def from_config(config):
|
||||||
config = list(utils.load_config(config))
|
config = list(utils.load_config(config))
|
||||||
if len(config) > 1:
|
if len(config) > 1:
|
||||||
raise AttributeError('Provide only one configuration')
|
raise AttributeError('Provide only one configuration')
|
||||||
@@ -235,21 +201,19 @@ def from_config(config, G=None):
|
|||||||
return sim
|
return sim
|
||||||
|
|
||||||
|
|
||||||
def run_from_config(*configs, dump=True, results_dir=None, timestamp=False):
|
def run_from_config(*configs, results_dir='soil_output', dry_run=False, dump=None, timestamp=False, **kwargs):
|
||||||
if not results_dir:
|
|
||||||
results_dir = 'soil_output'
|
|
||||||
for config_def in configs:
|
for config_def in configs:
|
||||||
for config, cpath in utils.load_config(config_def):
|
# logger.info("Found {} config(s)".format(len(ls)))
|
||||||
|
for config, _ in utils.load_config(config_def):
|
||||||
name = config.get('name', 'unnamed')
|
name = config.get('name', 'unnamed')
|
||||||
logger.info("Using config(s): {name}".format(name=name))
|
logger.info("Using config(s): {name}".format(name=name))
|
||||||
|
|
||||||
sim = SoilSimulation(**config)
|
|
||||||
if timestamp:
|
if timestamp:
|
||||||
sim_folder = '{}_{}'.format(sim.name,
|
sim_folder = '{}_{}'.format(name,
|
||||||
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
||||||
else:
|
else:
|
||||||
sim_folder = sim.name
|
sim_folder = name
|
||||||
sim.dir_path = os.path.join(results_dir, sim_folder)
|
dir_path = os.path.join(results_dir, sim_folder)
|
||||||
sim.dump = dump
|
sim = SoilSimulation(dir_path=dir_path, dump=dump, **config)
|
||||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, dump))
|
logger.info('Dumping results to {} : {}'.format(sim.dir_path, sim.dump))
|
||||||
results = sim.run_simulation()
|
sim.run_simulation(**kwargs)
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
import logging
|
import logging
|
||||||
|
import importlib
|
||||||
from time import time
|
from time import time
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from random import random
|
from random import random
|
||||||
@@ -11,7 +12,7 @@ import networkx as nx
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger('soil')
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
@@ -62,6 +63,7 @@ def load_config(config):
|
|||||||
@contextmanager
|
@contextmanager
|
||||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||||
start = time()
|
start = time()
|
||||||
|
function('{}Starting {} at {}.'.format(pre, name, start))
|
||||||
yield start
|
yield start
|
||||||
end = time()
|
end = time()
|
||||||
function('{}Finished {} in {} seconds'.format(pre, name, str(end-start)))
|
function('{}Finished {} in {} seconds'.format(pre, name, str(end-start)))
|
||||||
@@ -70,29 +72,23 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
|||||||
to_object.end = end
|
to_object.end = end
|
||||||
|
|
||||||
|
|
||||||
def agent_from_distribution(distribution, value=-1):
|
|
||||||
"""Find the agent """
|
|
||||||
if value < 0:
|
|
||||||
value = random()
|
|
||||||
for d in distribution:
|
|
||||||
threshold = d['threshold']
|
|
||||||
if value >= threshold[0] and value < threshold[1]:
|
|
||||||
state = {}
|
|
||||||
if 'state' in d:
|
|
||||||
state = deepcopy(d['state'])
|
|
||||||
return d['agent_type'], state
|
|
||||||
|
|
||||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
|
||||||
|
|
||||||
|
|
||||||
def repr(v):
|
def repr(v):
|
||||||
if isinstance(v, bool):
|
func = serializer(v)
|
||||||
v = "true" if v else ""
|
tname = name(v)
|
||||||
return v, bool.__name__
|
return func(v), tname
|
||||||
return v, type(v).__name__
|
|
||||||
|
|
||||||
def convert(value, type_):
|
|
||||||
import importlib
|
def name(v):
|
||||||
|
return type(v).__name__
|
||||||
|
|
||||||
|
|
||||||
|
def serializer(type_):
|
||||||
|
if type_ == 'bool':
|
||||||
|
return lambda x: "true" if x else ""
|
||||||
|
return lambda x: x
|
||||||
|
|
||||||
|
|
||||||
|
def deserializer(type_):
|
||||||
try:
|
try:
|
||||||
# Check if it's a builtin type
|
# Check if it's a builtin type
|
||||||
module = importlib.import_module('builtins')
|
module = importlib.import_module('builtins')
|
||||||
@@ -102,4 +98,8 @@ def convert(value, type_):
|
|||||||
module, type_ = type_.rsplit(".", 1)
|
module, type_ = type_.rsplit(".", 1)
|
||||||
module = importlib.import_module(module)
|
module = importlib.import_module(module)
|
||||||
cls = getattr(module, type_)
|
cls = getattr(module, type_)
|
||||||
return cls(value)
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
def convert(value, type_):
|
||||||
|
return deserializer(type_)(value)
|
||||||
|
20
soil/version.py
Normal file
20
soil/version.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ROOT = os.path.dirname(__file__)
|
||||||
|
DEFAULT_FILE = os.path.join(ROOT, 'VERSION')
|
||||||
|
|
||||||
|
|
||||||
|
def read_version(versionfile=DEFAULT_FILE):
|
||||||
|
try:
|
||||||
|
with open(versionfile) as f:
|
||||||
|
return f.read().strip()
|
||||||
|
except IOError: # pragma: no cover
|
||||||
|
logger.error(('Running an unknown version of {}.'
|
||||||
|
'Be careful!.').format(__name__))
|
||||||
|
return '0.0'
|
||||||
|
|
||||||
|
|
||||||
|
__version__ = read_version()
|
16
tests/test.csv
Normal file
16
tests/test.csv
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
agent_id,t_step,key,value,value_type
|
||||||
|
a0,0,hello,w,str
|
||||||
|
a0,1,hello,o,str
|
||||||
|
a0,2,hello,r,str
|
||||||
|
a0,3,hello,l,str
|
||||||
|
a0,4,hello,d,str
|
||||||
|
a0,5,hello,!,str
|
||||||
|
env,1,started,,bool
|
||||||
|
env,2,started,True,bool
|
||||||
|
env,7,started,,bool
|
||||||
|
a0,0,hello,w,str
|
||||||
|
a0,1,hello,o,str
|
||||||
|
a0,2,hello,r,str
|
||||||
|
a0,3,hello,l,str
|
||||||
|
a0,4,hello,d,str
|
||||||
|
a0,5,hello,!,str
|
|
90
tests/test_analysis.py
Normal file
90
tests/test_analysis.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import yaml
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from os.path import join
|
||||||
|
from soil import simulation, analysis, agents
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
class Ping(agents.FSM):
|
||||||
|
|
||||||
|
defaults = {
|
||||||
|
'count': 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
@agents.default_state
|
||||||
|
@agents.state
|
||||||
|
def even(self):
|
||||||
|
self['count'] += 1
|
||||||
|
return self.odd
|
||||||
|
|
||||||
|
@agents.state
|
||||||
|
def odd(self):
|
||||||
|
self['count'] += 1
|
||||||
|
return self.even
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalysis(TestCase):
|
||||||
|
|
||||||
|
# Code to generate a simple sqlite history
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
The initial states should be applied to the agent and the
|
||||||
|
agent should be able to update its state."""
|
||||||
|
config = {
|
||||||
|
'name': 'analysis',
|
||||||
|
'dry_run': True,
|
||||||
|
'seed': 'seed',
|
||||||
|
'network_params': {
|
||||||
|
'generator': 'complete_graph',
|
||||||
|
'n': 2
|
||||||
|
},
|
||||||
|
'agent_type': Ping,
|
||||||
|
'states': [{'interval': 1}, {'interval': 2}],
|
||||||
|
'max_time': 30,
|
||||||
|
'num_trials': 1,
|
||||||
|
'environment_params': {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s = simulation.from_config(config)
|
||||||
|
self.env = s.run_simulation()[0]
|
||||||
|
|
||||||
|
def test_saved(self):
|
||||||
|
env = self.env
|
||||||
|
assert env.get_agent(0)['count', 0] == 1
|
||||||
|
assert env.get_agent(0)['count', 29] == 30
|
||||||
|
assert env.get_agent(1)['count', 0] == 1
|
||||||
|
assert env.get_agent(1)['count', 29] == 15
|
||||||
|
assert env['env', 29, None]['SEED'] == env['env', 29, 'SEED']
|
||||||
|
|
||||||
|
def test_count(self):
|
||||||
|
env = self.env
|
||||||
|
df = analysis.read_sql(env._history._db)
|
||||||
|
res = analysis.get_count(df, 'SEED', 'id')
|
||||||
|
assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
|
||||||
|
assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
|
||||||
|
assert res['id']['odd'].iloc[0] == 2
|
||||||
|
assert res['id']['even'].iloc[0] == 0
|
||||||
|
assert res['id']['odd'].iloc[-1] == 1
|
||||||
|
assert res['id']['even'].iloc[-1] == 1
|
||||||
|
|
||||||
|
def test_value(self):
|
||||||
|
env = self.env
|
||||||
|
df = analysis.read_sql(env._history._db)
|
||||||
|
res_sum = analysis.get_value(df, 'count')
|
||||||
|
|
||||||
|
assert res_sum['count'].iloc[0] == 2
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
res_mean = analysis.get_value(df, 'count', aggfunc=np.mean)
|
||||||
|
assert res_mean['count'].iloc[0] == 1
|
||||||
|
|
||||||
|
res_total = analysis.get_value(df)
|
||||||
|
|
||||||
|
res_total['SEED'].iloc[0] == 'seedanalysis_trial_0'
|
133
tests/test_history.py
Normal file
133
tests/test_history.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
from soil import history
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
DBROOT = os.path.join(ROOT, 'testdb')
|
||||||
|
|
||||||
|
|
||||||
|
class TestHistory(TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
if not os.path.exists(DBROOT):
|
||||||
|
os.makedirs(DBROOT)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
if os.path.exists(DBROOT):
|
||||||
|
shutil.rmtree(DBROOT)
|
||||||
|
|
||||||
|
def test_history(self):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
tuples = (
|
||||||
|
('a_0', 0, 'id', 'h'),
|
||||||
|
('a_0', 1, 'id', 'e'),
|
||||||
|
('a_0', 2, 'id', 'l'),
|
||||||
|
('a_0', 3, 'id', 'l'),
|
||||||
|
('a_0', 4, 'id', 'o'),
|
||||||
|
('a_1', 0, 'id', 'v'),
|
||||||
|
('a_1', 1, 'id', 'a'),
|
||||||
|
('a_1', 2, 'id', 'l'),
|
||||||
|
('a_1', 3, 'id', 'u'),
|
||||||
|
('a_1', 4, 'id', 'e'),
|
||||||
|
('env', 1, 'prob', 1),
|
||||||
|
('env', 3, 'prob', 2),
|
||||||
|
('env', 5, 'prob', 3),
|
||||||
|
('a_2', 7, 'finished', True),
|
||||||
|
)
|
||||||
|
h = history.History()
|
||||||
|
h.save_tuples(tuples)
|
||||||
|
# assert h['env', 0, 'prob'] == 0
|
||||||
|
for i in range(1, 7):
|
||||||
|
assert h['env', i, 'prob'] == ((i-1)//2)+1
|
||||||
|
|
||||||
|
|
||||||
|
for i, k in zip(range(5), 'hello'):
|
||||||
|
assert h['a_0', i, 'id'] == k
|
||||||
|
for record, value in zip(h['a_0', None, 'id'], 'hello'):
|
||||||
|
t_step, val = record
|
||||||
|
assert val == value
|
||||||
|
|
||||||
|
for i, k in zip(range(5), 'value'):
|
||||||
|
assert h['a_1', i, 'id'] == k
|
||||||
|
for i in range(5, 8):
|
||||||
|
assert h['a_1', i, 'id'] == 'e'
|
||||||
|
for i in range(7):
|
||||||
|
assert h['a_2', i, 'finished'] == False
|
||||||
|
assert h['a_2', 7, 'finished']
|
||||||
|
|
||||||
|
def test_history_gen(self):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
tuples = (
|
||||||
|
('a_1', 0, 'id', 'v'),
|
||||||
|
('a_1', 1, 'id', 'a'),
|
||||||
|
('a_1', 2, 'id', 'l'),
|
||||||
|
('a_1', 3, 'id', 'u'),
|
||||||
|
('a_1', 4, 'id', 'e'),
|
||||||
|
('env', 1, 'prob', 1),
|
||||||
|
('env', 2, 'prob', 2),
|
||||||
|
('env', 3, 'prob', 3),
|
||||||
|
('a_2', 7, 'finished', True),
|
||||||
|
)
|
||||||
|
h = history.History()
|
||||||
|
h.save_tuples(tuples)
|
||||||
|
for t_step, key, value in h['env', None, None]:
|
||||||
|
assert t_step == value
|
||||||
|
assert key == 'prob'
|
||||||
|
|
||||||
|
records = list(h[None, 7, None])
|
||||||
|
assert len(records) == 3
|
||||||
|
for i in records:
|
||||||
|
agent_id, key, value = i
|
||||||
|
if agent_id == 'a_1':
|
||||||
|
assert key == 'id'
|
||||||
|
assert value == 'e'
|
||||||
|
elif agent_id == 'a_2':
|
||||||
|
assert key == 'finished'
|
||||||
|
assert value
|
||||||
|
else:
|
||||||
|
assert key == 'prob'
|
||||||
|
assert value == 3
|
||||||
|
|
||||||
|
records = h['a_1', 7, None]
|
||||||
|
assert records['id'] == 'e'
|
||||||
|
|
||||||
|
def test_history_file(self):
|
||||||
|
"""
|
||||||
|
History should be saved to a file
|
||||||
|
"""
|
||||||
|
tuples = (
|
||||||
|
('a_1', 0, 'id', 'v'),
|
||||||
|
('a_1', 1, 'id', 'a'),
|
||||||
|
('a_1', 2, 'id', 'l'),
|
||||||
|
('a_1', 3, 'id', 'u'),
|
||||||
|
('a_1', 4, 'id', 'e'),
|
||||||
|
('env', 1, 'prob', 1),
|
||||||
|
('env', 2, 'prob', 2),
|
||||||
|
('env', 3, 'prob', 3),
|
||||||
|
('a_2', 7, 'finished', True),
|
||||||
|
)
|
||||||
|
db_path = os.path.join(DBROOT, 'test')
|
||||||
|
h = history.History(db_path=db_path)
|
||||||
|
h.save_tuples(tuples)
|
||||||
|
assert os.path.exists(db_path)
|
||||||
|
|
||||||
|
# Recover the data
|
||||||
|
recovered = history.History(db_path=db_path, backup=False)
|
||||||
|
assert recovered['a_1', 0, 'id'] == 'v'
|
||||||
|
assert recovered['a_1', 4, 'id'] == 'e'
|
||||||
|
|
||||||
|
# Using the same name should create a backup copy
|
||||||
|
newhistory = history.History(db_path=db_path, backup=True)
|
||||||
|
backuppaths = glob(db_path + '.backup*.sqlite')
|
||||||
|
assert len(backuppaths) == 1
|
||||||
|
backuppath = backuppaths[0]
|
||||||
|
assert newhistory._db_path == h._db_path
|
||||||
|
assert os.path.exists(backuppath)
|
||||||
|
assert not len(newhistory[None, None, None])
|
@@ -2,6 +2,7 @@ from unittest import TestCase
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
|
import networkx as nx
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from os.path import join
|
from os.path import join
|
||||||
@@ -21,6 +22,7 @@ class TestMain(TestCase):
|
|||||||
Raise an exception otherwise.
|
Raise an exception otherwise.
|
||||||
"""
|
"""
|
||||||
config = {
|
config = {
|
||||||
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
}
|
}
|
||||||
@@ -30,6 +32,7 @@ class TestMain(TestCase):
|
|||||||
assert len(G) == 2
|
assert len(G) == 2
|
||||||
with self.assertRaises(AttributeError):
|
with self.assertRaises(AttributeError):
|
||||||
config = {
|
config = {
|
||||||
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'unknown.extension')
|
'path': join(ROOT, 'unknown.extension')
|
||||||
}
|
}
|
||||||
@@ -43,6 +46,7 @@ class TestMain(TestCase):
|
|||||||
should be used to generate a network
|
should be used to generate a network
|
||||||
"""
|
"""
|
||||||
config = {
|
config = {
|
||||||
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'generator': 'barabasi_albert_graph'
|
'generator': 'barabasi_albert_graph'
|
||||||
}
|
}
|
||||||
@@ -57,6 +61,7 @@ class TestMain(TestCase):
|
|||||||
def test_empty_simulation(self):
|
def test_empty_simulation(self):
|
||||||
"""A simulation with a base behaviour should do nothing"""
|
"""A simulation with a base behaviour should do nothing"""
|
||||||
config = {
|
config = {
|
||||||
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
@@ -65,35 +70,39 @@ class TestMain(TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
s.run_simulation()
|
s.run_simulation(dry_run=True)
|
||||||
|
|
||||||
def test_counter_agent(self):
|
def test_counter_agent(self):
|
||||||
"""
|
"""
|
||||||
The initial states should be applied to the agent and the
|
The initial states should be applied to the agent and the
|
||||||
agent should be able to update its state."""
|
agent should be able to update its state."""
|
||||||
config = {
|
config = {
|
||||||
|
'name': 'CounterAgent',
|
||||||
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
'agent_type': 'CounterModel',
|
'agent_type': 'CounterModel',
|
||||||
'states': [{'neighbors': 10}, {'total': 12}],
|
'states': [{'times': 10}, {'times': 20}],
|
||||||
'max_time': 2,
|
'max_time': 2,
|
||||||
'num_trials': 1,
|
'num_trials': 1,
|
||||||
'environment_params': {
|
'environment_params': {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.run_simulation()[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
assert env.get_agent(0)['neighbors', 0] == 10
|
assert env.get_agent(0)['times', 0] == 11
|
||||||
assert env.get_agent(0)['neighbors', 1] == 1
|
assert env.get_agent(0)['times', 1] == 12
|
||||||
assert env.get_agent(1)['total', 0] == 12
|
assert env.get_agent(1)['times', 0] == 21
|
||||||
assert env.get_agent(1)['neighbors', 1] == 1
|
assert env.get_agent(1)['times', 1] == 22
|
||||||
|
|
||||||
def test_counter_agent_history(self):
|
def test_counter_agent_history(self):
|
||||||
"""
|
"""
|
||||||
The evolution of the state should be recorded in the logging agent
|
The evolution of the state should be recorded in the logging agent
|
||||||
"""
|
"""
|
||||||
config = {
|
config = {
|
||||||
|
'name': 'CounterAgent',
|
||||||
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
@@ -108,14 +117,13 @@ class TestMain(TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.run_simulation()[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
for agent in env.network_agents:
|
for agent in env.network_agents:
|
||||||
last = 0
|
last = 0
|
||||||
assert len(agent[None, None]) == 11
|
assert len(agent[None, None]) == 10
|
||||||
for step, total in agent['total', None].items():
|
for step, total in sorted(agent['total', None]):
|
||||||
if step > 0:
|
assert total == last + 2
|
||||||
assert total == last + 2
|
last = total
|
||||||
last = total
|
|
||||||
|
|
||||||
def test_custom_agent(self):
|
def test_custom_agent(self):
|
||||||
"""Allow for search of neighbors with a certain state_id"""
|
"""Allow for search of neighbors with a certain state_id"""
|
||||||
@@ -124,6 +132,7 @@ class TestMain(TestCase):
|
|||||||
self.state['neighbors'] = self.count_agents(state_id=0,
|
self.state['neighbors'] = self.count_agents(state_id=0,
|
||||||
limit_neighbors=True)
|
limit_neighbors=True)
|
||||||
config = {
|
config = {
|
||||||
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
@@ -138,7 +147,7 @@ class TestMain(TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.run_simulation()[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
assert env.get_agent(0).state['neighbors'] == 1
|
assert env.get_agent(0).state['neighbors'] == 1
|
||||||
|
|
||||||
def test_torvalds_example(self):
|
def test_torvalds_example(self):
|
||||||
@@ -147,6 +156,7 @@ class TestMain(TestCase):
|
|||||||
config['network_params']['path'] = join(EXAMPLES,
|
config['network_params']['path'] = join(EXAMPLES,
|
||||||
config['network_params']['path'])
|
config['network_params']['path'])
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
|
s.dry_run = True
|
||||||
env = s.run_simulation()[0]
|
env = s.run_simulation()[0]
|
||||||
for a in env.network_agents:
|
for a in env.network_agents:
|
||||||
skill_level = a.state['skill_level']
|
skill_level = a.state['skill_level']
|
||||||
@@ -171,6 +181,7 @@ class TestMain(TestCase):
|
|||||||
with utils.timer('loading'):
|
with utils.timer('loading'):
|
||||||
config = utils.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
config = utils.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
|
s.dry_run = True
|
||||||
with utils.timer('serializing'):
|
with utils.timer('serializing'):
|
||||||
serial = s.to_yaml()
|
serial = s.to_yaml()
|
||||||
with utils.timer('recovering'):
|
with utils.timer('recovering'):
|
||||||
@@ -178,6 +189,7 @@ class TestMain(TestCase):
|
|||||||
with utils.timer('deleting'):
|
with utils.timer('deleting'):
|
||||||
del recovered['topology']
|
del recovered['topology']
|
||||||
del recovered['load_module']
|
del recovered['load_module']
|
||||||
|
del recovered['dry_run']
|
||||||
assert config == recovered
|
assert config == recovered
|
||||||
|
|
||||||
def test_configuration_changes(self):
|
def test_configuration_changes(self):
|
||||||
@@ -187,10 +199,12 @@ class TestMain(TestCase):
|
|||||||
"""
|
"""
|
||||||
config = utils.load_file('examples/complete.yml')[0]
|
config = utils.load_file('examples/complete.yml')[0]
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
|
s.dry_run = True
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
s.run_simulation()
|
s.run_simulation(dry_run=True)
|
||||||
nconfig = s.to_dict()
|
nconfig = s.to_dict()
|
||||||
del nconfig['topology']
|
del nconfig['topology']
|
||||||
|
del nconfig['dry_run']
|
||||||
del nconfig['load_module']
|
del nconfig['load_module']
|
||||||
assert config == nconfig
|
assert config == nconfig
|
||||||
|
|
||||||
@@ -201,23 +215,27 @@ class TestMain(TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_row_conversion(self):
|
def test_row_conversion(self):
|
||||||
sim = simulation.SoilSimulation()
|
env = environment.SoilEnvironment(dry_run=True)
|
||||||
env = environment.SoilEnvironment(simulation=sim)
|
|
||||||
env['test'] = 'test_value'
|
env['test'] = 'test_value'
|
||||||
env._save_state(now=0)
|
|
||||||
|
|
||||||
res = list(env.history_to_tuples())
|
res = list(env.history_to_tuples())
|
||||||
assert len(res) == len(env.environment_params)
|
assert len(res) == len(env.environment_params)
|
||||||
assert ('env', 0, 'test', 'test_value', 'str') in res
|
|
||||||
|
|
||||||
|
env._now = 1
|
||||||
env['test'] = 'second_value'
|
env['test'] = 'second_value'
|
||||||
env._save_state(now=1)
|
|
||||||
res = list(env.history_to_tuples())
|
res = list(env.history_to_tuples())
|
||||||
|
|
||||||
assert env['env', 0, 'test' ] == 'test_value'
|
assert env['env', 0, 'test' ] == 'test_value'
|
||||||
assert env['env', 1, 'test' ] == 'second_value'
|
assert env['env', 1, 'test' ] == 'second_value'
|
||||||
|
|
||||||
|
def test_save_geometric(self):
|
||||||
|
"""
|
||||||
|
There is a bug in networkx that prevents it from creating a GEXF file
|
||||||
|
from geometric models. We should work around it.
|
||||||
|
"""
|
||||||
|
G = nx.random_geometric_graph(20,0.1)
|
||||||
|
env = environment.SoilEnvironment(topology=G, dry_run=True)
|
||||||
|
env.dump_gexf('/tmp/dump-gexf')
|
||||||
|
|
||||||
|
|
||||||
def make_example_test(path, config):
|
def make_example_test(path, config):
|
||||||
@@ -225,8 +243,10 @@ def make_example_test(path, config):
|
|||||||
root = os.getcwd()
|
root = os.getcwd()
|
||||||
os.chdir(os.path.dirname(path))
|
os.chdir(os.path.dirname(path))
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
envs = s.run_simulation()
|
envs = s.run_simulation(dry_run=True)
|
||||||
|
assert envs
|
||||||
for env in envs:
|
for env in envs:
|
||||||
|
assert env
|
||||||
try:
|
try:
|
||||||
n = config['network_params']['n']
|
n = config['network_params']['n']
|
||||||
assert len(env.get_agents()) == n
|
assert len(env.get_agents()) == n
|
||||||
|
Reference in New Issue
Block a user