mirror of
https://github.com/gsi-upm/soil
synced 2025-09-13 19:52:20 +00:00
Compare commits
6 Commits
0.14.6
...
e860bdb922
Author | SHA1 | Date | |
---|---|---|---|
|
e860bdb922 | ||
|
d6b684c1c1 | ||
|
05f7f49233 | ||
|
3b2c6a3db5 | ||
|
6118f917ee | ||
|
6adc8d36ba |
@@ -1,4 +1,5 @@
|
||||
**/soil_output
|
||||
.*
|
||||
**/__pycache__
|
||||
__pycache__
|
||||
*.pyc
|
||||
|
41
CHANGELOG.md
41
CHANGELOG.md
@@ -3,6 +3,47 @@ All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.15.2]
|
||||
### Fixed
|
||||
* Pass the right known_modules and parameters to stats discovery in simulation
|
||||
* The configuration file must exist when launching through the CLI. If it doesn't, an error will be logged
|
||||
* Minor changes in the documentation of the CLI arguments
|
||||
### Changed
|
||||
* Stats are now exported by default
|
||||
## [0.15.1]
|
||||
### Added
|
||||
* read-only `History`
|
||||
### Fixed
|
||||
* Serialization problem with the `Environment` on parallel mode.
|
||||
* Analysis functions now work as they should in the tutorial
|
||||
## [0.15.0]
|
||||
### Added
|
||||
* Control logging level in CLI and simulation
|
||||
* `Stats` to calculate trial and simulation-wide statistics
|
||||
* Simulation statistics are stored in a separate table in history (see `History.get_stats` and `History.save_stats`, as well as `soil.stats`)
|
||||
* Aliased `NetworkAgent.G` to `NetworkAgent.topology`.
|
||||
### Changed
|
||||
* Templates in config files can be given as dictionaries in addition to strings
|
||||
* Samplers are used more explicitly
|
||||
* Removed nxsim dependency. We had already made a lot of changes, and nxsim has not been updated in 5 years.
|
||||
* Exporter methods renamed to `trial` and `end`. Added `start`.
|
||||
* `Distribution` exporter now a stats class
|
||||
* `global_topology` renamed to `topology`
|
||||
* Moved topology-related methods to `NetworkAgent`
|
||||
### Fixed
|
||||
* Temporary files used for history in dry_run mode are not longer left open
|
||||
|
||||
## [0.14.9]
|
||||
### Changed
|
||||
* Seed random before environment initialization
|
||||
## [0.14.8]
|
||||
### Fixed
|
||||
* Invalid directory names in Windows gsi-upm/soil#5
|
||||
## [0.14.7]
|
||||
### Changed
|
||||
* Minor change to traceback handling in async simulations
|
||||
### Fixed
|
||||
* Incomplete example in the docs (example.yml) caused an exception
|
||||
## [0.14.6]
|
||||
### Fixed
|
||||
* Bug with newer versions of networkx (0.24) where the Graph.node attribute has been removed. We have updated our calls, but the code in nxsim is not under our control, so we have pinned the networkx version until that issue is solved.
|
||||
|
@@ -30,5 +30,5 @@ If you use Soil in your research, don't forget to cite this paper:
|
||||
|
||||
@Copyright GSI - Universidad Politécnica de Madrid 2017
|
||||
|
||||
[](https://www.gsi.dit.upm.es)
|
||||
[](https://www.gsi.upm.es)
|
||||
|
||||
|
@@ -31,7 +31,7 @@
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = []
|
||||
extensions = ['IPython.sphinxext.ipython_console_highlighting']
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
@@ -69,7 +69,7 @@ language = None
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This patterns also effect to html_static_path and html_extra_path
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = 'sphinx'
|
||||
|
@@ -8,32 +8,8 @@ The advantage of a configuration file is that it is a clean declarative descript
|
||||
Simulation configuration files can be formatted in ``json`` or ``yaml`` and they define all the parameters of a simulation.
|
||||
Here's an example (``example.yml``).
|
||||
|
||||
.. code:: yaml
|
||||
|
||||
---
|
||||
name: MyExampleSimulation
|
||||
max_time: 50
|
||||
num_trials: 3
|
||||
interval: 2
|
||||
network_params:
|
||||
generator: barabasi_albert_graph
|
||||
n: 100
|
||||
m: 2
|
||||
network_agents:
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: content
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: discontent
|
||||
- agent_type: SISaModel
|
||||
weight: 8
|
||||
state:
|
||||
id: neutral
|
||||
environment_params:
|
||||
prob_infect: 0.075
|
||||
.. literalinclude:: example.yml
|
||||
:language: yaml
|
||||
|
||||
|
||||
This example configuration will run three trials (``num_trials``) of a simulation containing a randomly generated network (``network_params``).
|
||||
@@ -242,3 +218,24 @@ These agents are programmed in much the same way as network agents, the only dif
|
||||
|
||||
You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance.
|
||||
They are also useful to add behavior that has little to do with the network and the interactions within that network.
|
||||
|
||||
Templating
|
||||
==========
|
||||
|
||||
Sometimes, it is useful to parameterize a simulation and run it over a range of values in order to compare each run and measure the effect of those parameters in the simulation.
|
||||
For instance, you may want to run a simulation with different agent distributions.
|
||||
|
||||
This can be done in Soil using **templates**.
|
||||
A template is a configuration where some of the values are specified with a variable.
|
||||
e.g., ``weight: "{{ var1 }}"`` instead of ``weight: 1``.
|
||||
There are two types of variables, depending on how their values are decided:
|
||||
|
||||
* Fixed. A list of values is provided, and a new simulation is run for each possible value. If more than a variable is given, a new simulation will be run per combination of values.
|
||||
* Bounded/Sampled. The bounds of the variable are provided, along with a sampler method, which will be used to compute all the configuration combinations.
|
||||
|
||||
When fixed and bounded variables are mixed, Soil generates a new configuration per combination of fixed values and bounded values.
|
||||
|
||||
Here is an example with a single fixed variable and two bounded variable:
|
||||
|
||||
.. literalinclude:: ../examples/template.yml
|
||||
:language: yaml
|
||||
|
35
docs/example.yml
Normal file
35
docs/example.yml
Normal file
@@ -0,0 +1,35 @@
|
||||
---
|
||||
name: MyExampleSimulation
|
||||
max_time: 50
|
||||
num_trials: 3
|
||||
interval: 2
|
||||
network_params:
|
||||
generator: barabasi_albert_graph
|
||||
n: 100
|
||||
m: 2
|
||||
network_agents:
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: content
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: discontent
|
||||
- agent_type: SISaModel
|
||||
weight: 8
|
||||
state:
|
||||
id: neutral
|
||||
environment_params:
|
||||
prob_infect: 0.075
|
||||
neutral_discontent_spon_prob: 0.1
|
||||
neutral_discontent_infected_prob: 0.3
|
||||
neutral_content_spon_prob: 0.3
|
||||
neutral_content_infected_prob: 0.4
|
||||
discontent_neutral: 0.5
|
||||
discontent_content: 0.5
|
||||
variance_d_c: 0.2
|
||||
content_discontent: 0.2
|
||||
variance_c_d: 0.2
|
||||
content_neutral: 0.2
|
||||
standard_variance: 1
|
@@ -14,11 +14,11 @@ Now test that it worked by running the command line tool
|
||||
|
||||
soil --help
|
||||
|
||||
Or using soil programmatically:
|
||||
Or, if you're using using soil programmatically:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import soil
|
||||
print(soil.__version__)
|
||||
|
||||
The latest version can be installed through `GitLab <https://lab.cluster.gsi.dit.upm.es/soil/soil.git>`_.
|
||||
The latest version can be installed through `GitLab <https://lab.gsi.upm.es/soil/soil.git>`_ or `GitHub <https://github.com/gsi-upm/soil>`_.
|
||||
|
1
docs/requirements.txt
Normal file
1
docs/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
ipython==7.23
|
@@ -500,7 +500,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
"version": "3.8.5"
|
||||
},
|
||||
"toc": {
|
||||
"colors": {
|
||||
|
@@ -80800,7 +80800,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
"version": "3.8.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from soil.agents import FSM, state, default_state, BaseAgent
|
||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||
from enum import Enum
|
||||
from random import random, choice
|
||||
from itertools import islice
|
||||
@@ -80,7 +80,7 @@ class RabbitModel(FSM):
|
||||
self.env.add_edge(self['mate'], child.id)
|
||||
# self.add_edge()
|
||||
self.debug('A BABY IS COMING TO LIFE')
|
||||
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.global_topology.number_of_nodes())+1
|
||||
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1
|
||||
self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
||||
self['offspring'] += 1
|
||||
self.env.get_agent(self['mate'])['offspring'] += 1
|
||||
@@ -97,12 +97,14 @@ class RabbitModel(FSM):
|
||||
return
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
class RandomAccident(NetworkAgent):
|
||||
|
||||
level = logging.DEBUG
|
||||
|
||||
def step(self):
|
||||
rabbits_total = self.global_topology.number_of_nodes()
|
||||
rabbits_total = self.topology.number_of_nodes()
|
||||
if 'rabbits_alive' not in self.env:
|
||||
self.env['rabbits_alive'] = 0
|
||||
rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
|
||||
prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||
@@ -116,5 +118,5 @@ class RandomAccident(BaseAgent):
|
||||
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
||||
i.set_state(i.dead)
|
||||
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
||||
if self.count_agents(state_id=RabbitModel.dead.id) == self.global_topology.number_of_nodes():
|
||||
if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes():
|
||||
self.die()
|
||||
|
@@ -1,13 +1,8 @@
|
||||
---
|
||||
vars:
|
||||
bounds:
|
||||
x1: [0, 1]
|
||||
x2: [1, 2]
|
||||
fixed:
|
||||
x3: ["a", "b", "c"]
|
||||
sampler: "SALib.sample.morris.sample"
|
||||
samples: 10
|
||||
template: |
|
||||
sampler:
|
||||
method: "SALib.sample.morris.sample"
|
||||
N: 10
|
||||
template:
|
||||
group: simple
|
||||
num_trials: 1
|
||||
interval: 1
|
||||
@@ -19,11 +14,17 @@ template: |
|
||||
n: 10
|
||||
network_agents:
|
||||
- agent_type: CounterModel
|
||||
weight: {{ x1 }}
|
||||
weight: "{{ x1 }}"
|
||||
state:
|
||||
id: 0
|
||||
- agent_type: AggregatedCounter
|
||||
weight: {{ 1 - x1 }}
|
||||
weight: "{{ 1 - x1 }}"
|
||||
environment_params:
|
||||
name: {{ x3 }}
|
||||
name: "{{ x3 }}"
|
||||
skip_test: true
|
||||
vars:
|
||||
bounds:
|
||||
x1: [0, 1]
|
||||
x2: [1, 2]
|
||||
fixed:
|
||||
x3: ["a", "b", "c"]
|
||||
|
@@ -195,14 +195,14 @@ class TerroristNetworkModel(TerroristSpreadModel):
|
||||
break
|
||||
|
||||
def get_distance(self, target):
|
||||
source_x, source_y = nx.get_node_attributes(self.global_topology, 'pos')[self.id]
|
||||
target_x, target_y = nx.get_node_attributes(self.global_topology, 'pos')[target]
|
||||
source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id]
|
||||
target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target]
|
||||
dx = abs( source_x - target_x )
|
||||
dy = abs( source_y - target_y )
|
||||
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
||||
|
||||
def shortest_path_length(self, target):
|
||||
try:
|
||||
return nx.shortest_path_length(self.global_topology, self.id, target)
|
||||
return nx.shortest_path_length(self.topology, self.id, target)
|
||||
except nx.NetworkXNoPath:
|
||||
return float('inf')
|
||||
|
File diff suppressed because one or more lines are too long
@@ -1,10 +1,9 @@
|
||||
nxsim>=0.1.2
|
||||
simpy
|
||||
networkx>=2.0,<2.4
|
||||
simpy>=4.0
|
||||
networkx>=2.5
|
||||
numpy
|
||||
matplotlib
|
||||
pyyaml>=5.1
|
||||
pandas>=0.23
|
||||
scipy==1.2.1 # scipy 1.3.0rc1 is not compatible with salib
|
||||
scipy>=1.3
|
||||
SALib>=1.3
|
||||
Jinja2
|
||||
|
@@ -1 +1 @@
|
||||
0.14.6
|
||||
0.15.2
|
@@ -17,19 +17,21 @@ from .environment import Environment
|
||||
from .history import History
|
||||
from . import serialization
|
||||
from . import analysis
|
||||
from .utils import logger
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
from . import simulation
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('Running SOIL version: {}'.format(__version__))
|
||||
logger.info('Running SOIL version: {}'.format(__version__))
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
||||
parser.add_argument('file', type=str,
|
||||
nargs="?",
|
||||
default='simulation.yml',
|
||||
help='python module containing the simulation configuration.')
|
||||
help='Configuration file for the simulation (e.g., YAML or JSON)')
|
||||
parser.add_argument('--version', action='store_true',
|
||||
help='Show version info and exit')
|
||||
parser.add_argument('--module', '-m', type=str,
|
||||
help='file containing the code of any custom agents.')
|
||||
parser.add_argument('--dry-run', '--dry', action='store_true',
|
||||
@@ -40,6 +42,8 @@ def main():
|
||||
help='Dump GEXF graph. Defaults to false.')
|
||||
parser.add_argument('--csv', action='store_true',
|
||||
help='Dump history in CSV format. Defaults to false.')
|
||||
parser.add_argument('--level', type=str,
|
||||
help='Logging level')
|
||||
parser.add_argument('--output', '-o', type=str, default="soil_output",
|
||||
help='folder to write results to. It defaults to the current directory.')
|
||||
parser.add_argument('--synchronous', action='store_true',
|
||||
@@ -48,13 +52,17 @@ def main():
|
||||
help='Export environment and/or simulations using this exporter')
|
||||
|
||||
args = parser.parse_args()
|
||||
logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper()))
|
||||
|
||||
if args.version:
|
||||
return
|
||||
|
||||
if os.getcwd() not in sys.path:
|
||||
sys.path.append(os.getcwd())
|
||||
if args.module:
|
||||
importlib.import_module(args.module)
|
||||
|
||||
logging.info('Loading config file: {}'.format(args.file))
|
||||
logger.info('Loading config file: {}'.format(args.file))
|
||||
|
||||
try:
|
||||
exporters = list(args.exporter or ['default', ])
|
||||
@@ -65,6 +73,10 @@ def main():
|
||||
exp_params = {}
|
||||
if args.dry_run:
|
||||
exp_params['copy_to'] = sys.stdout
|
||||
|
||||
if not os.path.exists(args.file):
|
||||
logger.error('Please, input a valid file')
|
||||
return
|
||||
simulation.run_from_config(args.file,
|
||||
dry_run=args.dry_run,
|
||||
exporters=exporters,
|
||||
|
@@ -9,7 +9,7 @@ class BassModel(BaseAgent):
|
||||
imitation_prob
|
||||
"""
|
||||
|
||||
def __init__(self, environment, agent_id, state):
|
||||
def __init__(self, environment, agent_id, state, **kwargs):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
env_params = environment.environment_params
|
||||
self.state['sentimentCorrelation'] = 0
|
||||
@@ -19,7 +19,7 @@ class BassModel(BaseAgent):
|
||||
|
||||
def behaviour(self):
|
||||
# Outside effects
|
||||
if random.random() < self.state_params['innovation_prob']:
|
||||
if random.random() < self['innovation_prob']:
|
||||
if self.state['id'] == 0:
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
@@ -32,7 +32,7 @@ class BassModel(BaseAgent):
|
||||
if self.state['id'] == 0:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
num_neighbors_aware = len(aware_neighbors)
|
||||
if random.random() < (self.state_params['imitation_prob']*num_neighbors_aware):
|
||||
if random.random() < (self['imitation_prob']*num_neighbors_aware):
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from . import BaseAgent
|
||||
from . import NetworkAgent
|
||||
|
||||
|
||||
class CounterModel(BaseAgent):
|
||||
class CounterModel(NetworkAgent):
|
||||
"""
|
||||
Dummy behaviour. It counts the number of nodes in the network and neighbors
|
||||
in each step and adds it to its state.
|
||||
@@ -9,14 +9,14 @@ class CounterModel(BaseAgent):
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
total = len(list(self.get_all_agents()))
|
||||
total = len(list(self.get_agents()))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = neighbors
|
||||
self['total'] = total
|
||||
|
||||
|
||||
class AggregatedCounter(BaseAgent):
|
||||
class AggregatedCounter(NetworkAgent):
|
||||
"""
|
||||
Dummy behaviour. It counts the number of nodes in the network and neighbors
|
||||
in each step and adds it to its state.
|
||||
@@ -33,6 +33,6 @@ class AggregatedCounter(BaseAgent):
|
||||
self['times'] += 1
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['neighbors'] += neighbors
|
||||
total = len(list(self.get_all_agents()))
|
||||
total = len(list(self.get_agents()))
|
||||
self['total'] += total
|
||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||
|
@@ -3,19 +3,19 @@
|
||||
# for x in range(0, settings.network_params["number_of_nodes"]):
|
||||
# sentimentCorrelationNodeArray.append({'id': x})
|
||||
# Initialize agent states. Let's assume everyone is normal.
|
||||
|
||||
|
||||
import nxsim
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from scipy.spatial import cKDTree as KDTree
|
||||
import json
|
||||
import simpy
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from .. import serialization, history
|
||||
from .. import serialization, history, utils
|
||||
|
||||
|
||||
def as_node(agent):
|
||||
@@ -24,7 +24,7 @@ def as_node(agent):
|
||||
return agent
|
||||
|
||||
|
||||
class BaseAgent(nxsim.BaseAgent):
|
||||
class BaseAgent:
|
||||
"""
|
||||
A special simpy BaseAgent that keeps track of its state history.
|
||||
"""
|
||||
@@ -32,14 +32,13 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
defaults = {}
|
||||
|
||||
def __init__(self, environment, agent_id, state=None,
|
||||
name=None, interval=None, **state_params):
|
||||
name=None, interval=None):
|
||||
# Check for REQUIRED arguments
|
||||
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
|
||||
'Cannot be NoneType.')
|
||||
# Initialize agent parameters
|
||||
self.id = agent_id
|
||||
self.name = name or '{}[{}]'.format(type(self).__name__, self.id)
|
||||
self.state_params = state_params
|
||||
|
||||
# Register agent to environment
|
||||
self.env = environment
|
||||
@@ -51,10 +50,10 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
self.state = real_state
|
||||
self.interval = interval
|
||||
|
||||
if not hasattr(self, 'level'):
|
||||
self.level = logging.DEBUG
|
||||
self.logger = logging.getLogger(self.env.name)
|
||||
self.logger.setLevel(self.level)
|
||||
self.logger = logging.getLogger(self.env.name).getChild(self.name)
|
||||
|
||||
if hasattr(self, 'level'):
|
||||
self.logger.setLevel(self.level)
|
||||
|
||||
# initialize every time an instance of the agent is created
|
||||
self.action = self.env.process(self.run())
|
||||
@@ -75,14 +74,10 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
for k, v in value.items():
|
||||
self[k] = v
|
||||
|
||||
@property
|
||||
def global_topology(self):
|
||||
return self.env.G
|
||||
|
||||
@property
|
||||
def environment_params(self):
|
||||
return self.env.environment_params
|
||||
|
||||
|
||||
@environment_params.setter
|
||||
def environment_params(self, value):
|
||||
self.env.environment_params = value
|
||||
@@ -135,36 +130,10 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
def die(self, remove=False):
|
||||
self.alive = False
|
||||
if remove:
|
||||
super().die()
|
||||
self.remove_node(self.id)
|
||||
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
def count_agents(self, **kwargs):
|
||||
return len(list(self.get_agents(**kwargs)))
|
||||
|
||||
def count_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return len(super().get_neighboring_agents(state_id=state_id, **kwargs))
|
||||
|
||||
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||
|
||||
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
||||
if limit_neighbors:
|
||||
agents = super().get_agents(limit_neighbors=limit_neighbors)
|
||||
else:
|
||||
agents = self.env.get_agents(agents)
|
||||
return select(agents, **kwargs)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
message = message + " ".join(str(i) for i in args)
|
||||
message = "\t{:10}@{:>5}:\t{}".format(self.name, self.now, message)
|
||||
for k, v in kwargs:
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['id'] = self.id
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
return
|
||||
|
||||
def debug(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.DEBUG, **kwargs)
|
||||
@@ -192,24 +161,59 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
self._state = state['_state']
|
||||
self.env = state['environment']
|
||||
|
||||
def add_edge(self, node1, node2, **attrs):
|
||||
node1 = as_node(node1)
|
||||
node2 = as_node(node2)
|
||||
class NetworkAgent(BaseAgent):
|
||||
|
||||
for n in [node1, node2]:
|
||||
if n not in self.global_topology.nodes(data=False):
|
||||
raise ValueError('"{}" not in the graph'.format(n))
|
||||
return self.global_topology.add_edge(node1, node2, **attrs)
|
||||
@property
|
||||
def topology(self):
|
||||
return self.env.G
|
||||
|
||||
@property
|
||||
def G(self):
|
||||
return self.env.G
|
||||
|
||||
def count_agents(self, **kwargs):
|
||||
return len(list(self.get_agents(**kwargs)))
|
||||
|
||||
def count_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
|
||||
|
||||
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||
|
||||
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
||||
if limit_neighbors:
|
||||
agents = self.topology.neighbors(self.id)
|
||||
|
||||
agents = self.env.get_agents(agents)
|
||||
return select(agents, **kwargs)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
message = message + " ".join(str(i) for i in args)
|
||||
message = " @{:>3}: {}".format(self.now, message)
|
||||
for k, v in kwargs:
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['agent_id'] = self.id
|
||||
extra['agent_name'] = self.name
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
return self.global_topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
|
||||
return self.topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
|
||||
|
||||
def remove_node(self, agent_id):
|
||||
self.topology.remove_node(agent_id)
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
||||
if self.id not in self.topology.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(self.id))
|
||||
if other not in self.topology.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
||||
|
||||
self.topology.add_edge(self.id, other, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
|
||||
def add_edge(self, other, **kwargs):
|
||||
return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
||||
|
||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||
@@ -220,14 +224,14 @@ class NetworkAgent(BaseAgent):
|
||||
def degree(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now:
|
||||
self.env._degree = nx.degree_centrality(self.global_topology)
|
||||
self.env._degree = nx.degree_centrality(self.topology)
|
||||
self.env._last_step = self.now
|
||||
return self.env._degree[node]
|
||||
|
||||
def betweenness(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now:
|
||||
self.env._betweenness = nx.betweenness_centrality(self.global_topology)
|
||||
self.env._betweenness = nx.betweenness_centrality(self.topology)
|
||||
self.env._last_step = self.now
|
||||
return self.env._betweenness[node]
|
||||
|
||||
@@ -292,16 +296,22 @@ class MetaFSM(type):
|
||||
cls.states = states
|
||||
|
||||
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
class FSM(NetworkAgent, metaclass=MetaFSM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FSM, self).__init__(*args, **kwargs)
|
||||
if 'id' not in self.state:
|
||||
if not self.default_state:
|
||||
raise ValueError('No default state specified for {}'.format(self.id))
|
||||
self['id'] = self.default_state.id
|
||||
self._next_change = simpy.core.Infinity
|
||||
self._next_state = self.state
|
||||
|
||||
def step(self):
|
||||
if 'id' in self.state:
|
||||
if self._next_change < self.now:
|
||||
next_state = self._next_state
|
||||
self._next_change = simpy.core.Infinity
|
||||
self['id'] = next_state
|
||||
elif 'id' in self.state:
|
||||
next_state = self['id']
|
||||
elif self.default_state:
|
||||
next_state = self.default_state.id
|
||||
@@ -311,6 +321,10 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
raise Exception('{} is not a valid id for {}'.format(next_state, self))
|
||||
return self.states[next_state](self)
|
||||
|
||||
def next_state(self, state):
|
||||
self._next_change = self.now
|
||||
self._next_state = state
|
||||
|
||||
def set_state(self, state):
|
||||
if hasattr(state, 'id'):
|
||||
state = state.id
|
||||
@@ -371,14 +385,18 @@ def calculate_distribution(network_agents=None,
|
||||
else:
|
||||
raise ValueError('Specify a distribution or a default agent type')
|
||||
|
||||
# Fix missing weights and incompatible types
|
||||
for x in network_agents:
|
||||
x['weight'] = float(x.get('weight', 1))
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x.get('weight', 1) for x in network_agents)
|
||||
total = sum(x['weight'] for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
if 'ids' in v:
|
||||
v['threshold'] = STATIC_THRESHOLD
|
||||
continue
|
||||
upper = acc + (v.get('weight', 1)/total)
|
||||
upper = acc + (v['weight']/total)
|
||||
v['threshold'] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
@@ -425,7 +443,7 @@ def _validate_states(states, topology):
|
||||
states = states or []
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
assert x in topology.node
|
||||
assert x in topology.nodes
|
||||
else:
|
||||
assert len(states) <= len(topology)
|
||||
return states
|
||||
|
@@ -28,13 +28,13 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
||||
df = read_csv(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
else:
|
||||
for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
|
||||
for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))):
|
||||
df = read_sql(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
|
||||
|
||||
def read_sql(db, *args, **kwargs):
|
||||
h = history.History(db_path=db, backup=False)
|
||||
h = history.History(db_path=db, backup=False, readonly=True)
|
||||
df = h.read_sql(*args, **kwargs)
|
||||
return df
|
||||
|
||||
@@ -69,6 +69,13 @@ def convert_types_slow(df):
|
||||
df = df.apply(convert_row, axis=1)
|
||||
return df
|
||||
|
||||
|
||||
def split_processed(df):
|
||||
env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])]
|
||||
agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])]
|
||||
return env, agents
|
||||
|
||||
|
||||
def split_df(df):
|
||||
'''
|
||||
Split a dataframe in two dataframes: one with the history of agents,
|
||||
@@ -136,7 +143,7 @@ def get_value(df, *keys, aggfunc='sum'):
|
||||
return df.groupby(axis=1, level=0).agg(aggfunc)
|
||||
|
||||
|
||||
def plot_all(*args, **kwargs):
|
||||
def plot_all(*args, plot_args={}, **kwargs):
|
||||
'''
|
||||
Read all the trial data and plot the result of applying a function on them.
|
||||
'''
|
||||
@@ -144,14 +151,17 @@ def plot_all(*args, **kwargs):
|
||||
ps = []
|
||||
for line in dfs:
|
||||
f, df, config = line
|
||||
df.plot(title=config['name'])
|
||||
if len(df) < 1:
|
||||
continue
|
||||
df.plot(title=config['name'], **plot_args)
|
||||
ps.append(df)
|
||||
return ps
|
||||
|
||||
def do_all(pattern, func, *keys, include_env=False, **kwargs):
|
||||
for config_file, df, config in read_data(pattern, keys=keys):
|
||||
if len(df) < 1:
|
||||
continue
|
||||
p = func(df, *keys, **kwargs)
|
||||
p.plot(title=config['name'])
|
||||
yield config_file, p, config
|
||||
|
||||
|
||||
|
@@ -8,11 +8,10 @@ import yaml
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
from copy import deepcopy
|
||||
from collections import Counter
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
import networkx as nx
|
||||
import nxsim
|
||||
import simpy
|
||||
|
||||
from . import serialization, agents, analysis, history, utils
|
||||
|
||||
@@ -23,7 +22,7 @@ _CONFIG_PROPS = [ 'name',
|
||||
'interval',
|
||||
]
|
||||
|
||||
class Environment(nxsim.NetworkEnvironment):
|
||||
class Environment(simpy.Environment):
|
||||
"""
|
||||
The environment is key in a simulation. It contains the network topology,
|
||||
a reference to network and environment agents, as well as the environment
|
||||
@@ -42,25 +41,33 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
interval=1,
|
||||
seed=None,
|
||||
topology=None,
|
||||
*args, **kwargs):
|
||||
initial_time=0,
|
||||
**environment_params):
|
||||
|
||||
|
||||
self.name = name or 'UnnamedEnvironment'
|
||||
seed = seed or time.time()
|
||||
random.seed(seed)
|
||||
if isinstance(states, list):
|
||||
states = dict(enumerate(states))
|
||||
self.states = deepcopy(states) if states else {}
|
||||
self.default_state = deepcopy(default_state) or {}
|
||||
if not topology:
|
||||
topology = nx.Graph()
|
||||
super().__init__(*args, topology=topology, **kwargs)
|
||||
self.G = nx.Graph(topology)
|
||||
|
||||
super().__init__(initial_time=initial_time)
|
||||
self.environment_params = environment_params
|
||||
|
||||
self._env_agents = {}
|
||||
self.interval = interval
|
||||
self._history = history.History(name=self.name,
|
||||
backup=True)
|
||||
self['SEED'] = seed
|
||||
# Add environment agents first, so their events get
|
||||
# executed before network agents
|
||||
self.environment_agents = environment_agents or []
|
||||
self.network_agents = network_agents or []
|
||||
self['SEED'] = seed or time.time()
|
||||
random.seed(self['SEED'])
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
@@ -150,12 +157,10 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
start = start or self.now
|
||||
return self.G.add_edge(agent1, agent2, **attrs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
def run(self, until, *args, **kwargs):
|
||||
self._save_state()
|
||||
self.log_stats()
|
||||
super().run(*args, **kwargs)
|
||||
super().run(until, *args, **kwargs)
|
||||
self._history.flush_cache()
|
||||
self.log_stats()
|
||||
|
||||
def _save_state(self, now=None):
|
||||
serialization.logger.debug('Saving state @{}'.format(self.now))
|
||||
@@ -317,25 +322,6 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
|
||||
return G
|
||||
|
||||
def stats(self):
|
||||
stats = {}
|
||||
stats['network'] = {}
|
||||
stats['network']['n_nodes'] = self.G.number_of_nodes()
|
||||
stats['network']['n_edges'] = self.G.number_of_edges()
|
||||
c = Counter()
|
||||
c.update(a.__class__.__name__ for a in self.network_agents)
|
||||
stats['agents'] = {}
|
||||
stats['agents']['model_count'] = dict(c)
|
||||
c2 = Counter()
|
||||
c2.update(a['id'] for a in self.network_agents)
|
||||
stats['agents']['state_count'] = dict(c2)
|
||||
stats['params'] = self.environment_params
|
||||
return stats
|
||||
|
||||
def log_stats(self):
|
||||
stats = self.stats()
|
||||
serialization.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
|
||||
|
||||
def __getstate__(self):
|
||||
state = {}
|
||||
for prop in _CONFIG_PROPS:
|
||||
@@ -343,6 +329,7 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
state['G'] = json_graph.node_link_data(self.G)
|
||||
state['environment_agents'] = self._env_agents
|
||||
state['history'] = self._history
|
||||
state['_now'] = self._now
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
@@ -351,6 +338,8 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
self._env_agents = state['environment_agents']
|
||||
self.G = json_graph.node_link_graph(state['G'])
|
||||
self._history = state['history']
|
||||
self._now = state['_now']
|
||||
self._queue = []
|
||||
|
||||
|
||||
SoilEnvironment = Environment
|
||||
|
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import csv as csvlib
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
|
||||
from .serialization import deserialize
|
||||
from .utils import open_or_reuse, logger, timer
|
||||
@@ -13,15 +14,6 @@ from .utils import open_or_reuse, logger, timer
|
||||
from . import utils
|
||||
|
||||
|
||||
def for_sim(simulation, names, *args, **kwargs):
|
||||
'''Return the set of exporters for a simulation, given the exporter names'''
|
||||
exporters = []
|
||||
for name in names:
|
||||
mod = deserialize(name, known_modules=['soil.exporters'])
|
||||
exporters.append(mod(simulation, *args, **kwargs))
|
||||
return exporters
|
||||
|
||||
|
||||
class DryRunner(BytesIO):
|
||||
def __init__(self, fname, *args, copy_to=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -37,8 +29,12 @@ class DryRunner(BytesIO):
|
||||
super().write(bytes(txt, 'utf-8'))
|
||||
|
||||
def close(self):
|
||||
logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname,
|
||||
self.getvalue().decode()))
|
||||
content = '(binary data not shown)'
|
||||
try:
|
||||
content = self.getvalue().decode()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname, content))
|
||||
super().close()
|
||||
|
||||
|
||||
@@ -49,7 +45,7 @@ class Exporter:
|
||||
'''
|
||||
|
||||
def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
|
||||
self.sim = simulation
|
||||
self.simulation = simulation
|
||||
outdir = outdir or os.path.join(os.getcwd(), 'soil_output')
|
||||
self.outdir = os.path.join(outdir,
|
||||
simulation.group or '',
|
||||
@@ -59,12 +55,15 @@ class Exporter:
|
||||
|
||||
def start(self):
|
||||
'''Method to call when the simulation starts'''
|
||||
pass
|
||||
|
||||
def end(self):
|
||||
def end(self, stats):
|
||||
'''Method to call when the simulation ends'''
|
||||
pass
|
||||
|
||||
def trial_end(self, env):
|
||||
def trial(self, env, stats):
|
||||
'''Method to call when a trial ends'''
|
||||
pass
|
||||
|
||||
def output(self, f, mode='w', **kwargs):
|
||||
if self.dry_run:
|
||||
@@ -84,35 +83,47 @@ class default(Exporter):
|
||||
def start(self):
|
||||
if not self.dry_run:
|
||||
logger.info('Dumping results to %s', self.outdir)
|
||||
self.sim.dump_yaml(outdir=self.outdir)
|
||||
self.simulation.dump_yaml(outdir=self.outdir)
|
||||
else:
|
||||
logger.info('NOT dumping results')
|
||||
|
||||
def trial_end(self, env):
|
||||
def trial(self, env, stats):
|
||||
if not self.dry_run:
|
||||
with timer('Dumping simulation {} trial {}'.format(self.sim.name,
|
||||
with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.name)):
|
||||
with self.output('{}.sqlite'.format(env.name), mode='wb') as f:
|
||||
env.dump_sqlite(f)
|
||||
|
||||
def end(self, stats):
|
||||
with timer('Dumping simulation {}\'s stats'.format(self.simulation.name)):
|
||||
with self.output('{}.sqlite'.format(self.simulation.name), mode='wb') as f:
|
||||
self.simulation.dump_sqlite(f)
|
||||
|
||||
|
||||
|
||||
class csv(Exporter):
|
||||
'''Export the state of each environment (and its agents) in a separate CSV file'''
|
||||
def trial_end(self, env):
|
||||
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.sim.name,
|
||||
def trial(self, env, stats):
|
||||
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
|
||||
env.name,
|
||||
self.outdir)):
|
||||
with self.output('{}.csv'.format(env.name)) as f:
|
||||
env.dump_csv(f)
|
||||
|
||||
with self.output('{}.stats.csv'.format(env.name)) as f:
|
||||
statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL)
|
||||
|
||||
for stat in stats:
|
||||
statwriter.writerow(stat)
|
||||
|
||||
|
||||
class gexf(Exporter):
|
||||
def trial_end(self, env):
|
||||
def trial(self, env, stats):
|
||||
if self.dry_run:
|
||||
logger.info('Not dumping GEXF in dry_run mode')
|
||||
return
|
||||
|
||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.sim.name,
|
||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.name)):
|
||||
with self.output('{}.gexf'.format(env.name), mode='wb') as f:
|
||||
env.dump_gexf(f)
|
||||
@@ -124,56 +135,24 @@ class dummy(Exporter):
|
||||
with self.output('dummy', 'w') as f:
|
||||
f.write('simulation started @ {}\n'.format(time.time()))
|
||||
|
||||
def trial_end(self, env):
|
||||
def trial(self, env, stats):
|
||||
with self.output('dummy', 'w') as f:
|
||||
for i in env.history_to_tuples():
|
||||
f.write(','.join(map(str, i)))
|
||||
f.write('\n')
|
||||
|
||||
def end(self):
|
||||
def sim(self, stats):
|
||||
with self.output('dummy', 'a') as f:
|
||||
f.write('simulation ended @ {}\n'.format(time.time()))
|
||||
|
||||
|
||||
class distribution(Exporter):
|
||||
'''
|
||||
Write the distribution of agent states at the end of each trial,
|
||||
the mean value, and its deviation.
|
||||
'''
|
||||
|
||||
def start(self):
|
||||
self.means = []
|
||||
self.counts = []
|
||||
|
||||
def trial_end(self, env):
|
||||
df = env[None, None, None].df()
|
||||
ix = df.index[-1]
|
||||
attrs = df.columns.levels[0]
|
||||
vc = {}
|
||||
stats = {}
|
||||
for a in attrs:
|
||||
t = df.loc[(ix, a)]
|
||||
try:
|
||||
self.means.append(('mean', a, t.mean()))
|
||||
except TypeError:
|
||||
for name, count in t.value_counts().iteritems():
|
||||
self.counts.append(('count', a, name, count))
|
||||
|
||||
def end(self):
|
||||
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
|
||||
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
|
||||
dfm = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||
dfc = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||
with self.output('counts.csv') as f:
|
||||
dfc.to_csv(f)
|
||||
with self.output('metrics.csv') as f:
|
||||
dfm.to_csv(f)
|
||||
|
||||
class graphdrawing(Exporter):
|
||||
|
||||
def trial_end(self, env):
|
||||
def trial(self, env, stats):
|
||||
# Outside effects
|
||||
f = plt.figure()
|
||||
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
|
||||
with open('graph-{}.png'.format(env.name)) as f:
|
||||
f.savefig(f)
|
||||
|
||||
|
142
soil/history.py
142
soil/history.py
@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
||||
from collections import UserDict, namedtuple
|
||||
|
||||
from . import serialization
|
||||
from .utils import open_or_reuse
|
||||
from .utils import open_or_reuse, unflatten_dict
|
||||
|
||||
|
||||
class History:
|
||||
@@ -19,29 +19,43 @@ class History:
|
||||
Store and retrieve values from a sqlite database.
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, db_path=None, backup=False):
|
||||
self._db = None
|
||||
def __init__(self, name=None, db_path=None, backup=False, readonly=False):
|
||||
if readonly and (not os.path.exists(db_path)):
|
||||
raise Exception('The DB file does not exist. Cannot open in read-only mode')
|
||||
|
||||
if db_path is None:
|
||||
self._db = None
|
||||
self._temp = db_path is None
|
||||
self._stats_columns = None
|
||||
self.readonly = readonly
|
||||
|
||||
if self._temp:
|
||||
if not name:
|
||||
name = time.time()
|
||||
_, db_path = tempfile.mkstemp(suffix='{}.sqlite'.format(name))
|
||||
# The file will be deleted as soon as it's closed
|
||||
# Normally, that will be on destruction
|
||||
db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
|
||||
|
||||
|
||||
if backup and os.path.exists(db_path):
|
||||
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
||||
os.rename(db_path, newname)
|
||||
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
||||
os.rename(db_path, newname)
|
||||
|
||||
self.db_path = db_path
|
||||
|
||||
self.db = db_path
|
||||
self._dtypes = {}
|
||||
self._tups = []
|
||||
|
||||
|
||||
if self.readonly:
|
||||
return
|
||||
|
||||
with self.db:
|
||||
logger.debug('Creating database {}'.format(self.db_path))
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text)''')
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
|
||||
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
|
||||
self._dtypes = {}
|
||||
self._tups = []
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
@@ -58,6 +72,7 @@ class History:
|
||||
if isinstance(db_path, str):
|
||||
logger.debug('Connecting to database {}'.format(db_path))
|
||||
self._db = sqlite3.connect(db_path)
|
||||
self._db.row_factory = sqlite3.Row
|
||||
else:
|
||||
self._db = db_path
|
||||
|
||||
@@ -68,9 +83,56 @@ class History:
|
||||
self._db.close()
|
||||
self._db = None
|
||||
|
||||
def save_stats(self, stat):
|
||||
if self.readonly:
|
||||
print('DB in readonly mode')
|
||||
return
|
||||
if not stat:
|
||||
return
|
||||
with self.db:
|
||||
if not self._stats_columns:
|
||||
self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
|
||||
|
||||
for column, value in stat.items():
|
||||
if column in self._stats_columns:
|
||||
continue
|
||||
dtype = 'text'
|
||||
if not isinstance(value, str):
|
||||
try:
|
||||
float(value)
|
||||
dtype = 'real'
|
||||
int(value)
|
||||
dtype = 'int'
|
||||
except ValueError:
|
||||
pass
|
||||
self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
|
||||
self._stats_columns.append(column)
|
||||
|
||||
columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
|
||||
values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
|
||||
query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
|
||||
columns=columns,
|
||||
values=values
|
||||
)
|
||||
self.db.execute(query)
|
||||
|
||||
def get_stats(self, unflatten=True):
|
||||
rows = self.db.execute("select * from stats").fetchall()
|
||||
res = []
|
||||
for row in rows:
|
||||
d = {}
|
||||
for k in row.keys():
|
||||
if row[k] is None:
|
||||
continue
|
||||
d[k] = row[k]
|
||||
if unflatten:
|
||||
d = unflatten_dict(d)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
@property
|
||||
def dtypes(self):
|
||||
self.read_types()
|
||||
self._read_types()
|
||||
return {k:v[0] for k, v in self._dtypes.items()}
|
||||
|
||||
def save_tuples(self, tuples):
|
||||
@@ -93,18 +155,10 @@ class History:
|
||||
Save a collection of records to the database.
|
||||
Database writes are cached.
|
||||
'''
|
||||
value = self.convert(key, value)
|
||||
self._tups.append(Record(agent_id=agent_id,
|
||||
t_step=t_step,
|
||||
key=key,
|
||||
value=value))
|
||||
if len(self._tups) > 100:
|
||||
self.flush_cache()
|
||||
|
||||
def convert(self, key, value):
|
||||
"""Get the serialized value for a given key."""
|
||||
if self.readonly:
|
||||
raise Exception('DB in readonly mode')
|
||||
if key not in self._dtypes:
|
||||
self.read_types()
|
||||
self._read_types()
|
||||
if key not in self._dtypes:
|
||||
name = serialization.name(value)
|
||||
serializer = serialization.serializer(name)
|
||||
@@ -112,21 +166,21 @@ class History:
|
||||
self._dtypes[key] = (name, serializer, deserializer)
|
||||
with self.db:
|
||||
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
|
||||
return self._dtypes[key][1](value)
|
||||
|
||||
def recover(self, key, value):
|
||||
"""Get the deserialized value for a given key, and the serialized version."""
|
||||
if key not in self._dtypes:
|
||||
self.read_types()
|
||||
if key not in self._dtypes:
|
||||
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
||||
return self._dtypes[key][2](value)
|
||||
value = self._dtypes[key][1](value)
|
||||
self._tups.append(Record(agent_id=agent_id,
|
||||
t_step=t_step,
|
||||
key=key,
|
||||
value=value))
|
||||
if len(self._tups) > 100:
|
||||
self.flush_cache()
|
||||
|
||||
def flush_cache(self):
|
||||
'''
|
||||
Use a cache to save state changes to avoid opening a session for every change.
|
||||
The cache will be flushed at the end of the simulation, and when history is accessed.
|
||||
'''
|
||||
if self.readonly:
|
||||
raise Exception('DB in readonly mode')
|
||||
logger.debug('Flushing cache {}'.format(self.db_path))
|
||||
with self.db:
|
||||
for rec in self._tups:
|
||||
@@ -139,10 +193,14 @@ class History:
|
||||
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
|
||||
for r in res:
|
||||
agent_id, t_step, key, value = r
|
||||
value = self.recover(key, value)
|
||||
if key not in self._dtypes:
|
||||
self._read_types()
|
||||
if key not in self._dtypes:
|
||||
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
||||
value = self._dtypes[key][2](value)
|
||||
yield agent_id, t_step, key, value
|
||||
|
||||
def read_types(self):
|
||||
def _read_types(self):
|
||||
with self.db:
|
||||
res = self.db.execute("select key, value_type from value_types ").fetchall()
|
||||
for k, v in res:
|
||||
@@ -167,7 +225,7 @@ class History:
|
||||
|
||||
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
||||
|
||||
self.read_types()
|
||||
self._read_types()
|
||||
|
||||
def escape_and_join(v):
|
||||
if v is None:
|
||||
@@ -181,7 +239,13 @@ class History:
|
||||
|
||||
last_df = None
|
||||
if t_steps:
|
||||
# Look for the last value before the minimum step in the query
|
||||
# Convert negative indices into positive
|
||||
if any(x<0 for x in t_steps):
|
||||
max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
|
||||
t_steps = [t if t>0 else max_t+1+t for t in t_steps]
|
||||
|
||||
# We will be doing ffill interpolation, so we need to look for
|
||||
# the last value before the minimum step in the query
|
||||
min_step = min(t_steps)
|
||||
last_filters = ['t_step < {}'.format(min_step),]
|
||||
last_filters = last_filters + filters
|
||||
@@ -219,7 +283,11 @@ class History:
|
||||
for k, v in self._dtypes.items():
|
||||
if k in df_p:
|
||||
dtype, _, deserial = v
|
||||
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
|
||||
try:
|
||||
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
|
||||
except (TypeError, ValueError):
|
||||
# Avoid forward-filling unknown/incompatible types
|
||||
continue
|
||||
if t_steps:
|
||||
df_p = df_p.reindex(t_steps, method='ffill')
|
||||
return df_p.ffill()
|
||||
@@ -313,3 +381,5 @@ class Records():
|
||||
|
||||
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
||||
Record = namedtuple('Record', 'agent_id t_step key value')
|
||||
|
||||
Stat = namedtuple('Stat', 'trial_id')
|
||||
|
@@ -17,10 +17,10 @@ logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def load_network(network_params, dir_path=None):
|
||||
if network_params is None:
|
||||
return nx.Graph()
|
||||
path = network_params.get('path', None)
|
||||
if path:
|
||||
G = nx.Graph()
|
||||
|
||||
if 'path' in network_params:
|
||||
path = network_params['path']
|
||||
if dir_path and not os.path.isabs(path):
|
||||
path = os.path.join(dir_path, path)
|
||||
extension = os.path.splitext(path)[1][1:]
|
||||
@@ -32,21 +32,22 @@ def load_network(network_params, dir_path=None):
|
||||
method = getattr(nx.readwrite, 'read_' + extension)
|
||||
except AttributeError:
|
||||
raise AttributeError('Unknown format')
|
||||
return method(path, **kwargs)
|
||||
G = method(path, **kwargs)
|
||||
|
||||
net_args = network_params.copy()
|
||||
if 'generator' not in net_args:
|
||||
return nx.Graph()
|
||||
elif 'generator' in network_params:
|
||||
net_args = network_params.copy()
|
||||
net_gen = net_args.pop('generator')
|
||||
|
||||
net_gen = net_args.pop('generator')
|
||||
if dir_path not in sys.path:
|
||||
sys.path.append(dir_path)
|
||||
|
||||
if dir_path not in sys.path:
|
||||
sys.path.append(dir_path)
|
||||
method = deserializer(net_gen,
|
||||
known_modules=['networkx.generators',])
|
||||
G = method(**net_args)
|
||||
|
||||
return G
|
||||
|
||||
method = deserializer(net_gen,
|
||||
known_modules=['networkx.generators',])
|
||||
|
||||
return method(**net_args)
|
||||
|
||||
|
||||
def load_file(infile):
|
||||
@@ -66,11 +67,32 @@ def expand_template(config):
|
||||
raise ValueError(('You must provide a definition of variables'
|
||||
' for the template.'))
|
||||
|
||||
template = Template(config['template'])
|
||||
template = config['template']
|
||||
|
||||
sampler_name = config.get('sampler', 'SALib.sample.morris.sample')
|
||||
n_samples = int(config.get('samples', 100))
|
||||
sampler = deserializer(sampler_name)
|
||||
if not isinstance(template, str):
|
||||
template = yaml.dump(template)
|
||||
|
||||
template = Template(template)
|
||||
|
||||
params = params_for_template(config)
|
||||
|
||||
blank_str = template.render({k: 0 for k in params[0].keys()})
|
||||
blank = list(load_string(blank_str))
|
||||
if len(blank) > 1:
|
||||
raise ValueError('Templates must not return more than one configuration')
|
||||
if 'name' in blank[0]:
|
||||
raise ValueError('Templates cannot be named, use group instead')
|
||||
|
||||
for ps in params:
|
||||
string = template.render(ps)
|
||||
for c in load_string(string):
|
||||
yield c
|
||||
|
||||
|
||||
def params_for_template(config):
|
||||
sampler_config = config.get('sampler', {'N': 100})
|
||||
sampler = sampler_config.pop('method', 'SALib.sample.morris.sample')
|
||||
sampler = deserializer(sampler)
|
||||
bounds = config['vars']['bounds']
|
||||
|
||||
problem = {
|
||||
@@ -78,7 +100,7 @@ def expand_template(config):
|
||||
'names': list(bounds.keys()),
|
||||
'bounds': list(v for v in bounds.values())
|
||||
}
|
||||
samples = sampler(problem, n_samples)
|
||||
samples = sampler(problem, **sampler_config)
|
||||
|
||||
lists = config['vars'].get('lists', {})
|
||||
names = list(lists.keys())
|
||||
@@ -88,20 +110,7 @@ def expand_template(config):
|
||||
allnames = names + problem['names']
|
||||
allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
|
||||
params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
|
||||
|
||||
|
||||
blank_str = template.render({k: 0 for k in allnames})
|
||||
blank = list(load_string(blank_str))
|
||||
if len(blank) > 1:
|
||||
raise ValueError('Templates must not return more than one configuration')
|
||||
if 'name' in blank[0]:
|
||||
raise ValueError('Templates cannot be named, use group instead')
|
||||
|
||||
confs = []
|
||||
for ps in params:
|
||||
string = template.render(ps)
|
||||
for c in load_string(string):
|
||||
yield c
|
||||
return params
|
||||
|
||||
|
||||
def load_files(*patterns, **kwargs):
|
||||
@@ -116,7 +125,7 @@ def load_files(*patterns, **kwargs):
|
||||
|
||||
def load_config(config):
|
||||
if isinstance(config, dict):
|
||||
yield config, None
|
||||
yield config, os.getcwd()
|
||||
else:
|
||||
yield from load_files(config)
|
||||
|
||||
@@ -199,3 +208,13 @@ def deserialize(type_, value=None, **kwargs):
|
||||
if value is None:
|
||||
return des
|
||||
return des(value)
|
||||
|
||||
|
||||
def deserialize_all(names, *args, known_modules=['soil'], **kwargs):
|
||||
'''Return the set of exporters for a simulation, given the exporter names'''
|
||||
exporters = []
|
||||
for name in names:
|
||||
mod = deserialize(name, known_modules=known_modules)
|
||||
exporters.append(mod(*args, **kwargs))
|
||||
return exporters
|
||||
|
||||
|
@@ -4,6 +4,7 @@ import importlib
|
||||
import sys
|
||||
import yaml
|
||||
import traceback
|
||||
import logging
|
||||
import networkx as nx
|
||||
from networkx.readwrite import json_graph
|
||||
from multiprocessing import Pool
|
||||
@@ -11,17 +12,19 @@ from functools import partial
|
||||
|
||||
import pickle
|
||||
|
||||
from nxsim import NetworkSimulation
|
||||
|
||||
from . import serialization, utils, basestring, agents
|
||||
from .environment import Environment
|
||||
from .utils import logger
|
||||
from .exporters import for_sim as exporters_for_sim
|
||||
from .exporters import default
|
||||
from .stats import defaultStats
|
||||
from .history import History
|
||||
|
||||
|
||||
class Simulation(NetworkSimulation):
|
||||
#TODO: change documentation for simulation
|
||||
|
||||
class Simulation:
|
||||
"""
|
||||
Subclass of nsim.NetworkSimulation with three main differences:
|
||||
Similar to nsim.NetworkSimulation with three main differences:
|
||||
1) agent type can be specified by name or by class.
|
||||
2) instead of just one type, a network agents distribution can be used.
|
||||
The distribution specifies the weight (or probability) of each
|
||||
@@ -91,11 +94,12 @@ class Simulation(NetworkSimulation):
|
||||
environment_params=None, environment_class=None,
|
||||
**kwargs):
|
||||
|
||||
self.seed = str(seed) or str(time.time())
|
||||
self.load_module = load_module
|
||||
self.network_params = network_params
|
||||
self.name = name or 'Unnamed_' + time.strftime("%Y-%m-%d_%H:%M:%S")
|
||||
self.group = group or None
|
||||
self.name = name or 'Unnamed'
|
||||
self.seed = str(seed or name)
|
||||
self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S"))
|
||||
self.group = group or ''
|
||||
self.num_trials = num_trials
|
||||
self.max_time = max_time
|
||||
self.default_state = default_state or {}
|
||||
@@ -128,12 +132,15 @@ class Simulation(NetworkSimulation):
|
||||
self.states = agents._validate_states(states,
|
||||
self.topology)
|
||||
|
||||
self._history = History(name=self.name,
|
||||
backup=False)
|
||||
|
||||
def run_simulation(self, *args, **kwargs):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
'''Run the simulation and return the list of resulting environments'''
|
||||
return list(self._run_simulation_gen(*args, **kwargs))
|
||||
return list(self.run_gen(*args, **kwargs))
|
||||
|
||||
def _run_sync_or_async(self, parallel=False, *args, **kwargs):
|
||||
if parallel:
|
||||
@@ -148,44 +155,82 @@ class Simulation(NetworkSimulation):
|
||||
yield i
|
||||
else:
|
||||
for i in range(self.num_trials):
|
||||
yield self.run_trial(i,
|
||||
*args,
|
||||
yield self.run_trial(*args,
|
||||
**kwargs)
|
||||
|
||||
def _run_simulation_gen(self, *args, parallel=False, dry_run=False,
|
||||
exporters=['default', ], outdir=None, exporter_params={}, **kwargs):
|
||||
def run_gen(self, *args, parallel=False, dry_run=False,
|
||||
exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
|
||||
stats_params={}, log_level=None,
|
||||
**kwargs):
|
||||
'''Run the simulation and yield the resulting environments.'''
|
||||
if log_level:
|
||||
logger.setLevel(log_level)
|
||||
logger.info('Using exporters: %s', exporters or [])
|
||||
logger.info('Output directory: %s', outdir)
|
||||
exporters = exporters_for_sim(self,
|
||||
exporters,
|
||||
dry_run=dry_run,
|
||||
outdir=outdir,
|
||||
**exporter_params)
|
||||
exporters = serialization.deserialize_all(exporters,
|
||||
simulation=self,
|
||||
known_modules=['soil.exporters',],
|
||||
dry_run=dry_run,
|
||||
outdir=outdir,
|
||||
**exporter_params)
|
||||
stats = serialization.deserialize_all(simulation=self,
|
||||
names=stats,
|
||||
known_modules=['soil.stats',],
|
||||
**stats_params)
|
||||
|
||||
with utils.timer('simulation {}'.format(self.name)):
|
||||
for stat in stats:
|
||||
stat.start()
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.start()
|
||||
|
||||
for env in self._run_sync_or_async(*args, parallel=parallel,
|
||||
for env in self._run_sync_or_async(*args,
|
||||
parallel=parallel,
|
||||
log_level=log_level,
|
||||
**kwargs):
|
||||
|
||||
collected = list(stat.trial(env) for stat in stats)
|
||||
|
||||
saved = self.save_stats(collected, t_step=env.now, trial_id=env.name)
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.trial_end(env)
|
||||
exporter.trial(env, saved)
|
||||
|
||||
yield env
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.end()
|
||||
|
||||
def get_env(self, trial_id = 0, **kwargs):
|
||||
collected = list(stat.end() for stat in stats)
|
||||
saved = self.save_stats(collected)
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.end(saved)
|
||||
|
||||
|
||||
def save_stats(self, collection, **kwargs):
|
||||
stats = dict(kwargs)
|
||||
for stat in collection:
|
||||
stats.update(stat)
|
||||
self._history.save_stats(utils.flatten_dict(stats))
|
||||
return stats
|
||||
|
||||
def get_stats(self, **kwargs):
|
||||
return self._history.get_stats(**kwargs)
|
||||
|
||||
def log_stats(self, stats):
|
||||
logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
|
||||
|
||||
|
||||
def get_env(self, trial_id=0, **kwargs):
|
||||
'''Create an environment for a trial of the simulation'''
|
||||
opts = self.environment_params.copy()
|
||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||
opts.update({
|
||||
'name': env_name,
|
||||
'name': trial_id,
|
||||
'topology': self.topology.copy(),
|
||||
'seed': self.seed+env_name,
|
||||
'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||
'initial_time': 0,
|
||||
'interval': self.interval,
|
||||
'network_agents': self.network_agents,
|
||||
'initial_time': 0,
|
||||
'states': self.states,
|
||||
'default_state': self.default_state,
|
||||
'environment_agents': self.environment_agents,
|
||||
@@ -194,20 +239,22 @@ class Simulation(NetworkSimulation):
|
||||
env = self.environment_class(**opts)
|
||||
return env
|
||||
|
||||
def run_trial(self, trial_id=0, until=None, **opts):
|
||||
"""Run a single trial of the simulation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trial_id : int
|
||||
def run_trial(self, until=None, log_level=logging.INFO, **opts):
|
||||
"""
|
||||
Run a single trial of the simulation
|
||||
|
||||
"""
|
||||
trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-')
|
||||
if log_level:
|
||||
logger.setLevel(log_level)
|
||||
# Set-up trial environment and graph
|
||||
until = until or self.max_time
|
||||
env = self.get_env(trial_id = trial_id, **opts)
|
||||
env = self.get_env(trial_id=trial_id, **opts)
|
||||
# Set up agents on nodes
|
||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||
env.run(until)
|
||||
return env
|
||||
|
||||
def run_trial_exceptions(self, *args, **kwargs):
|
||||
'''
|
||||
A wrapper for run_trial that catches exceptions and returns them.
|
||||
@@ -216,9 +263,10 @@ class Simulation(NetworkSimulation):
|
||||
try:
|
||||
return self.run_trial(*args, **kwargs)
|
||||
except Exception as ex:
|
||||
c = ex.__cause__
|
||||
c.message = ''.join(traceback.format_exception(type(c), c, c.__traceback__)[:])
|
||||
return c
|
||||
if ex.__cause__ is not None:
|
||||
ex = ex.__cause__
|
||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||
return ex
|
||||
|
||||
def to_dict(self):
|
||||
return self.__getstate__()
|
||||
@@ -247,6 +295,9 @@ class Simulation(NetworkSimulation):
|
||||
with utils.open_or_reuse(f, 'wb') as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
def dump_sqlite(self, f):
|
||||
return self._history.dump(f)
|
||||
|
||||
def __getstate__(self):
|
||||
state={}
|
||||
for k, v in self.__dict__.items():
|
||||
|
106
soil/stats.py
Normal file
106
soil/stats.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import pandas as pd
|
||||
|
||||
from collections import Counter
|
||||
|
||||
class Stats:
|
||||
'''
|
||||
Interface for all stats. It is not necessary, but it is useful
|
||||
if you don't plan to implement all the methods.
|
||||
'''
|
||||
|
||||
def __init__(self, simulation):
|
||||
self.simulation = simulation
|
||||
|
||||
def start(self):
|
||||
'''Method to call when the simulation starts'''
|
||||
pass
|
||||
|
||||
def end(self):
|
||||
'''Method to call when the simulation ends'''
|
||||
return {}
|
||||
|
||||
def trial(self, env):
|
||||
'''Method to call when a trial ends'''
|
||||
return {}
|
||||
|
||||
|
||||
class distribution(Stats):
|
||||
'''
|
||||
Calculate the distribution of agent states at the end of each trial,
|
||||
the mean value, and its deviation.
|
||||
'''
|
||||
|
||||
def start(self):
|
||||
self.means = []
|
||||
self.counts = []
|
||||
|
||||
def trial(self, env):
|
||||
df = env[None, None, None].df()
|
||||
df = df.drop('SEED', axis=1)
|
||||
ix = df.index[-1]
|
||||
attrs = df.columns.get_level_values(0)
|
||||
vc = {}
|
||||
stats = {
|
||||
'mean': {},
|
||||
'count': {},
|
||||
}
|
||||
for a in attrs:
|
||||
t = df.loc[(ix, a)]
|
||||
try:
|
||||
stats['mean'][a] = t.mean()
|
||||
self.means.append(('mean', a, t.mean()))
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
for name, count in t.value_counts().iteritems():
|
||||
if a not in stats['count']:
|
||||
stats['count'][a] = {}
|
||||
stats['count'][a][name] = count
|
||||
self.counts.append(('count', a, name, count))
|
||||
|
||||
return stats
|
||||
|
||||
def end(self):
|
||||
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
|
||||
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
|
||||
|
||||
count = {}
|
||||
mean = {}
|
||||
|
||||
if self.means:
|
||||
res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||
mean = res['value'].to_dict()
|
||||
if self.counts:
|
||||
res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||
for k,v in res['count'].to_dict().items():
|
||||
if k not in count:
|
||||
count[k] = {}
|
||||
for tup, times in v.items():
|
||||
subkey, subcount = tup
|
||||
if subkey not in count[k]:
|
||||
count[k][subkey] = {}
|
||||
count[k][subkey][subcount] = times
|
||||
|
||||
|
||||
return {'count': count, 'mean': mean}
|
||||
|
||||
|
||||
class defaultStats(Stats):
|
||||
|
||||
def trial(self, env):
|
||||
c = Counter()
|
||||
c.update(a.__class__.__name__ for a in env.network_agents)
|
||||
|
||||
c2 = Counter()
|
||||
c2.update(a['id'] for a in env.network_agents)
|
||||
|
||||
return {
|
||||
'network ': {
|
||||
'n_nodes': env.G.number_of_nodes(),
|
||||
'n_edges': env.G.number_of_nodes(),
|
||||
},
|
||||
'agents': {
|
||||
'model_count': dict(c),
|
||||
'state_count': dict(c2),
|
||||
}
|
||||
}
|
@@ -7,6 +7,7 @@ from shutil import copyfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger('soil')
|
||||
logging.basicConfig()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
@@ -31,14 +32,13 @@ def safe_open(path, mode='r', backup=True, **kwargs):
|
||||
os.makedirs(outdir)
|
||||
if backup and 'w' in mode and os.path.exists(path):
|
||||
creation = os.path.getctime(path)
|
||||
stamp = time.strftime('%Y-%m-%d_%H:%M', time.localtime(creation))
|
||||
stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation))
|
||||
|
||||
backup_dir = os.path.join(outdir, stamp)
|
||||
backup_dir = os.path.join(outdir, 'backup')
|
||||
if not os.path.exists(backup_dir):
|
||||
os.makedirs(backup_dir)
|
||||
newpath = os.path.join(backup_dir, os.path.basename(path))
|
||||
if os.path.exists(newpath):
|
||||
newpath = '{}@{}'.format(newpath, time.time())
|
||||
newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path),
|
||||
stamp))
|
||||
copyfile(path, newpath)
|
||||
return open(path, mode=mode, **kwargs)
|
||||
|
||||
@@ -48,3 +48,40 @@ def open_or_reuse(f, *args, **kwargs):
|
||||
return safe_open(f, *args, **kwargs)
|
||||
except (AttributeError, TypeError):
|
||||
return f
|
||||
|
||||
def flatten_dict(d):
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
return dict(_flatten_dict(d))
|
||||
|
||||
def _flatten_dict(d, prefix=''):
|
||||
if not isinstance(d, dict):
|
||||
# print('END:', prefix, d)
|
||||
yield prefix, d
|
||||
return
|
||||
if prefix:
|
||||
prefix = prefix + '.'
|
||||
for k, v in d.items():
|
||||
# print(k, v)
|
||||
res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k)))
|
||||
# print('RES:', res)
|
||||
yield from res
|
||||
|
||||
|
||||
def unflatten_dict(d):
|
||||
out = {}
|
||||
for k, v in d.items():
|
||||
target = out
|
||||
if not isinstance(k, str):
|
||||
target[k] = v
|
||||
continue
|
||||
tokens = k.split('.')
|
||||
if len(tokens) < 2:
|
||||
target[k] = v
|
||||
continue
|
||||
for token in tokens[:-1]:
|
||||
if token not in target:
|
||||
target[token] = {}
|
||||
target = target[token]
|
||||
target[tokens[-1]] = v
|
||||
return out
|
||||
|
@@ -66,8 +66,8 @@ class TestAnalysis(TestCase):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history.db_path)
|
||||
res = analysis.get_count(df, 'SEED', 'id')
|
||||
assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
|
||||
assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
|
||||
assert res['SEED'][self.env['SEED']].iloc[0] == 1
|
||||
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
|
||||
assert res['id']['odd'].iloc[0] == 2
|
||||
assert res['id']['even'].iloc[0] == 0
|
||||
assert res['id']['odd'].iloc[-1] == 1
|
||||
@@ -75,7 +75,7 @@ class TestAnalysis(TestCase):
|
||||
|
||||
def test_value(self):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history._db)
|
||||
df = analysis.read_sql(env._history.db_path)
|
||||
res_sum = analysis.get_value(df, 'count')
|
||||
|
||||
assert res_sum['count'].iloc[0] == 2
|
||||
@@ -86,4 +86,4 @@ class TestAnalysis(TestCase):
|
||||
|
||||
res_total = analysis.get_value(df)
|
||||
|
||||
res_total['SEED'].iloc[0] == 'seedanalysis_trial_0'
|
||||
res_total['SEED'].iloc[0] == self.env['SEED']
|
||||
|
@@ -31,7 +31,7 @@ def make_example_test(path, config):
|
||||
try:
|
||||
n = config['network_params']['n']
|
||||
assert len(list(env.network_agents)) == n
|
||||
assert env.now > 2 # It has run
|
||||
assert env.now > 0 # It has run
|
||||
assert env.now <= config['max_time'] # But not further than allowed
|
||||
except KeyError:
|
||||
pass
|
||||
|
@@ -6,26 +6,32 @@ from time import time
|
||||
|
||||
from unittest import TestCase
|
||||
from soil import exporters
|
||||
from soil.utils import safe_open
|
||||
from soil import simulation
|
||||
|
||||
from soil.stats import distribution
|
||||
|
||||
class Dummy(exporters.Exporter):
|
||||
started = False
|
||||
trials = 0
|
||||
ended = False
|
||||
total_time = 0
|
||||
called_start = 0
|
||||
called_trial = 0
|
||||
called_end = 0
|
||||
|
||||
def start(self):
|
||||
self.__class__.called_start += 1
|
||||
self.__class__.started = True
|
||||
|
||||
def trial_end(self, env):
|
||||
def trial(self, env, stats):
|
||||
assert env
|
||||
self.__class__.trials += 1
|
||||
self.__class__.total_time += env.now
|
||||
self.__class__.called_trial += 1
|
||||
|
||||
def end(self):
|
||||
def end(self, stats):
|
||||
self.__class__.ended = True
|
||||
self.__class__.called_end += 1
|
||||
|
||||
|
||||
class Exporters(TestCase):
|
||||
@@ -39,32 +45,17 @@ class Exporters(TestCase):
|
||||
'environment_params': {}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
s.run_simulation(exporters=[Dummy], dry_run=True)
|
||||
for env in s.run_simulation(exporters=[Dummy], dry_run=True):
|
||||
assert env.now <= 2
|
||||
|
||||
assert Dummy.started
|
||||
assert Dummy.ended
|
||||
assert Dummy.called_start == 1
|
||||
assert Dummy.called_end == 1
|
||||
assert Dummy.called_trial == 5
|
||||
assert Dummy.trials == 5
|
||||
assert Dummy.total_time == 2*5
|
||||
|
||||
def test_distribution(self):
|
||||
'''The distribution exporter should write the number of agents in each state'''
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {
|
||||
'generator': 'complete_graph',
|
||||
'n': 4
|
||||
},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': 5,
|
||||
'environment_params': {}
|
||||
}
|
||||
output = io.StringIO()
|
||||
s = simulation.from_config(config)
|
||||
s.run_simulation(exporters=[exporters.distribution], dry_run=True, exporter_params={'copy_to': output})
|
||||
result = output.getvalue()
|
||||
assert 'count' in result
|
||||
assert 'SEED,Noneexporter_sim_trial_3,1,,1,1,1,1' in result
|
||||
|
||||
def test_writing(self):
|
||||
'''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
|
||||
n_trials = 5
|
||||
@@ -83,11 +74,11 @@ class Exporters(TestCase):
|
||||
s = simulation.from_config(config)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
envs = s.run_simulation(exporters=[
|
||||
exporters.default,
|
||||
exporters.csv,
|
||||
exporters.gexf,
|
||||
exporters.distribution,
|
||||
],
|
||||
exporters.default,
|
||||
exporters.csv,
|
||||
exporters.gexf,
|
||||
],
|
||||
stats=[distribution,],
|
||||
outdir=tmpdir,
|
||||
exporter_params={'copy_to': output})
|
||||
result = output.getvalue()
|
||||
|
@@ -5,6 +5,7 @@ import shutil
|
||||
from glob import glob
|
||||
|
||||
from soil import history
|
||||
from soil import utils
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
@@ -154,3 +155,49 @@ class TestHistory(TestCase):
|
||||
assert recovered
|
||||
for i in recovered:
|
||||
assert i in tuples
|
||||
|
||||
def test_stats(self):
|
||||
"""
|
||||
The data recovered should be equal to the one recorded.
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
stat_tuples = [
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'new': '40'},
|
||||
]
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
for stat in stat_tuples:
|
||||
h.save_stats(stat)
|
||||
recovered = h.get_stats()
|
||||
assert recovered
|
||||
assert recovered[0]['num_infected'] == 5
|
||||
assert recovered[1]['runtime'] == 0.2
|
||||
assert recovered[2]['new'] == '40'
|
||||
|
||||
def test_unflatten(self):
|
||||
ex = {'count.neighbors.3': 4,
|
||||
'count.times.2': 4,
|
||||
'count.total.4': 4,
|
||||
'mean.neighbors': 3,
|
||||
'mean.times': 2,
|
||||
'mean.total': 4,
|
||||
't_step': 2,
|
||||
'trial_id': 'exporter_sim_trial_1605817956-4475424'}
|
||||
res = utils.unflatten_dict(ex)
|
||||
|
||||
assert 'count' in res
|
||||
assert 'mean' in res
|
||||
assert 't_step' in res
|
||||
assert 'trial_id' in res
|
||||
|
@@ -343,4 +343,16 @@ class TestMain(TestCase):
|
||||
configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
|
||||
assert len(configs) > 0
|
||||
|
||||
|
||||
def test_until(self):
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': 100,
|
||||
'environment_params': {}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
runs = list(s.run_simulation(dry_run=True))
|
||||
over = list(x.now for x in runs if x.now>2)
|
||||
assert len(over) == 0
|
||||
|
34
tests/test_stats.py
Normal file
34
tests/test_stats.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from soil import simulation, stats
|
||||
from soil.utils import unflatten_dict
|
||||
|
||||
class Stats(TestCase):
|
||||
|
||||
def test_distribution(self):
|
||||
'''The distribution exporter should write the number of agents in each state'''
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {
|
||||
'generator': 'complete_graph',
|
||||
'n': 4
|
||||
},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': 5,
|
||||
'environment_params': {}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
for env in s.run_simulation(stats=[stats.distribution]):
|
||||
pass
|
||||
# stats_res = unflatten_dict(dict(env._history['stats', -1, None]))
|
||||
allstats = s.get_stats()
|
||||
for stat in allstats:
|
||||
assert 'count' in stat
|
||||
assert 'mean' in stat
|
||||
if 'trial_id' in stat:
|
||||
assert stat['mean']['neighbors'] == 3
|
||||
assert stat['count']['total']['4'] == 4
|
||||
else:
|
||||
assert stat['count']['count']['neighbors']['3'] == 20
|
||||
assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors']
|
Reference in New Issue
Block a user