mirror of
				https://github.com/gsi-upm/soil
				synced 2025-11-04 09:28:16 +00:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			0.14.9
			...
			05f7f49233
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					05f7f49233 | 
							
								
								
									
										23
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								CHANGELOG.md
									
									
									
									
									
								
							@@ -3,6 +3,29 @@ All notable changes to this project will be documented in this file.
 | 
			
		||||
 | 
			
		||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 | 
			
		||||
 | 
			
		||||
## [0.15.1]
 | 
			
		||||
### Added
 | 
			
		||||
* read-only `History`
 | 
			
		||||
### Fixed
 | 
			
		||||
* Serialization problem with the `Environment` on parallel mode.
 | 
			
		||||
* Analysis functions now work as they should in the tutorial
 | 
			
		||||
## [0.15.0]
 | 
			
		||||
### Added
 | 
			
		||||
* Control logging level in CLI and simulation
 | 
			
		||||
* `Stats` to calculate trial and simulation-wide statistics
 | 
			
		||||
* Simulation statistics are stored in a separate table in history (see `History.get_stats` and `History.save_stats`, as well as `soil.stats`)
 | 
			
		||||
* Aliased `NetworkAgent.G` to `NetworkAgent.topology`.
 | 
			
		||||
### Changed
 | 
			
		||||
* Templates in config files can be given as dictionaries in addition to strings
 | 
			
		||||
* Samplers are used more explicitly
 | 
			
		||||
* Removed nxsim dependency. We had already made a lot of changes, and nxsim has not been updated in 5 years.
 | 
			
		||||
* Exporter methods renamed to `trial` and `end`. Added `start`.
 | 
			
		||||
* `Distribution` exporter now a stats class
 | 
			
		||||
* `global_topology` renamed to `topology`
 | 
			
		||||
* Moved topology-related methods to `NetworkAgent`
 | 
			
		||||
### Fixed
 | 
			
		||||
* Temporary files used for history in dry_run mode are not longer left open 
 | 
			
		||||
 | 
			
		||||
## [0.14.9]
 | 
			
		||||
### Changed
 | 
			
		||||
* Seed random before environment initialization
 | 
			
		||||
 
 | 
			
		||||
@@ -31,7 +31,7 @@
 | 
			
		||||
# Add any Sphinx extension module names here, as strings. They can be
 | 
			
		||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 | 
			
		||||
# ones.
 | 
			
		||||
extensions = []
 | 
			
		||||
extensions = ['IPython.sphinxext.ipython_console_highlighting']
 | 
			
		||||
 | 
			
		||||
# Add any paths that contain templates here, relative to this directory.
 | 
			
		||||
templates_path = ['_templates']
 | 
			
		||||
@@ -69,7 +69,7 @@ language = None
 | 
			
		||||
# List of patterns, relative to source directory, that match files and
 | 
			
		||||
# directories to ignore when looking for source files.
 | 
			
		||||
# This patterns also effect to html_static_path and html_extra_path
 | 
			
		||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
 | 
			
		||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']
 | 
			
		||||
 | 
			
		||||
# The name of the Pygments (syntax highlighting) style to use.
 | 
			
		||||
pygments_style = 'sphinx'
 | 
			
		||||
 
 | 
			
		||||
@@ -218,3 +218,24 @@ These agents are programmed in much the same way as network agents, the only dif
 | 
			
		||||
 | 
			
		||||
You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance.
 | 
			
		||||
They are also useful to add behavior that has little to do with the network and the interactions within that network.
 | 
			
		||||
 | 
			
		||||
Templating
 | 
			
		||||
==========
 | 
			
		||||
 | 
			
		||||
Sometimes, it is useful to parameterize a simulation and run it over a range of values in order to compare each run and measure the effect of those parameters in the simulation.
 | 
			
		||||
For instance, you may want to run a simulation with different agent distributions.
 | 
			
		||||
 | 
			
		||||
This can be done in Soil using **templates**.
 | 
			
		||||
A template is a configuration where some of the values are specified with a variable.
 | 
			
		||||
e.g.,  ``weight: "{{ var1 }}"`` instead of ``weight: 1``.
 | 
			
		||||
There are two types of variables, depending on how their values are decided:
 | 
			
		||||
 | 
			
		||||
* Fixed. A list of values is provided, and a new simulation is run for each possible value. If more than a variable is given, a new simulation will be run per combination of values.
 | 
			
		||||
* Bounded/Sampled. The bounds of the variable are provided, along with a sampler method, which will be used to compute all the configuration combinations.
 | 
			
		||||
 | 
			
		||||
When fixed and bounded variables are mixed, Soil generates a new configuration per combination of fixed values and bounded values.
 | 
			
		||||
 | 
			
		||||
Here is an example with a single fixed variable and two bounded variable:
 | 
			
		||||
 | 
			
		||||
.. literalinclude:: ../examples/template.yml
 | 
			
		||||
   :language: yaml
 | 
			
		||||
 
 | 
			
		||||
@@ -500,7 +500,7 @@
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.6.5"
 | 
			
		||||
   "version": "3.8.5"
 | 
			
		||||
  },
 | 
			
		||||
  "toc": {
 | 
			
		||||
   "colors": {
 | 
			
		||||
 
 | 
			
		||||
@@ -80800,7 +80800,7 @@
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.6.5"
 | 
			
		||||
   "version": "3.8.6"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
from soil.agents import FSM, state, default_state, BaseAgent
 | 
			
		||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from random import random, choice
 | 
			
		||||
from itertools import islice
 | 
			
		||||
@@ -80,7 +80,7 @@ class RabbitModel(FSM):
 | 
			
		||||
                self.env.add_edge(self['mate'], child.id)
 | 
			
		||||
                # self.add_edge()
 | 
			
		||||
                self.debug('A BABY IS COMING TO LIFE')
 | 
			
		||||
                self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.global_topology.number_of_nodes())+1
 | 
			
		||||
                self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1
 | 
			
		||||
                self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive']))
 | 
			
		||||
                self['offspring'] += 1
 | 
			
		||||
                self.env.get_agent(self['mate'])['offspring'] += 1
 | 
			
		||||
@@ -97,12 +97,14 @@ class RabbitModel(FSM):
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RandomAccident(BaseAgent):
 | 
			
		||||
class RandomAccident(NetworkAgent):
 | 
			
		||||
 | 
			
		||||
    level = logging.DEBUG
 | 
			
		||||
 | 
			
		||||
    def step(self):
 | 
			
		||||
        rabbits_total = self.global_topology.number_of_nodes()
 | 
			
		||||
        rabbits_total = self.topology.number_of_nodes()
 | 
			
		||||
        if 'rabbits_alive' not in self.env:
 | 
			
		||||
            self.env['rabbits_alive'] = 0
 | 
			
		||||
        rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
 | 
			
		||||
        prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
 | 
			
		||||
        self.debug('Killing some rabbits with prob={}!'.format(prob_death))
 | 
			
		||||
@@ -116,5 +118,5 @@ class RandomAccident(BaseAgent):
 | 
			
		||||
                self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
 | 
			
		||||
                i.set_state(i.dead)
 | 
			
		||||
        self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
 | 
			
		||||
        if self.count_agents(state_id=RabbitModel.dead.id) == self.global_topology.number_of_nodes():
 | 
			
		||||
        if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes():
 | 
			
		||||
            self.die()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,13 +1,8 @@
 | 
			
		||||
---
 | 
			
		||||
vars:
 | 
			
		||||
  bounds:
 | 
			
		||||
    x1: [0, 1]
 | 
			
		||||
    x2: [1, 2]
 | 
			
		||||
  fixed:
 | 
			
		||||
    x3: ["a", "b", "c"]
 | 
			
		||||
sampler: "SALib.sample.morris.sample"
 | 
			
		||||
samples: 10
 | 
			
		||||
template: |
 | 
			
		||||
sampler:
 | 
			
		||||
  method: "SALib.sample.morris.sample"
 | 
			
		||||
  N: 10
 | 
			
		||||
template:
 | 
			
		||||
  group: simple
 | 
			
		||||
  num_trials: 1
 | 
			
		||||
  interval: 1
 | 
			
		||||
@@ -19,11 +14,17 @@ template: |
 | 
			
		||||
    n: 10
 | 
			
		||||
  network_agents:
 | 
			
		||||
    - agent_type: CounterModel
 | 
			
		||||
      weight: {{ x1 }}
 | 
			
		||||
      weight: "{{ x1 }}"
 | 
			
		||||
      state:
 | 
			
		||||
        id: 0
 | 
			
		||||
    - agent_type: AggregatedCounter
 | 
			
		||||
      weight: {{ 1 - x1 }}
 | 
			
		||||
      weight: "{{ 1 - x1 }}"
 | 
			
		||||
  environment_params:
 | 
			
		||||
    name: {{ x3 }}
 | 
			
		||||
    name: "{{ x3 }}"
 | 
			
		||||
  skip_test: true
 | 
			
		||||
vars:
 | 
			
		||||
  bounds:
 | 
			
		||||
    x1: [0, 1]
 | 
			
		||||
    x2: [1, 2]
 | 
			
		||||
  fixed:
 | 
			
		||||
    x3: ["a", "b", "c"]
 | 
			
		||||
 
 | 
			
		||||
@@ -195,14 +195,14 @@ class TerroristNetworkModel(TerroristSpreadModel):
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
    def get_distance(self, target):
 | 
			
		||||
        source_x, source_y = nx.get_node_attributes(self.global_topology, 'pos')[self.id]
 | 
			
		||||
        target_x, target_y = nx.get_node_attributes(self.global_topology, 'pos')[target]
 | 
			
		||||
        source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id]
 | 
			
		||||
        target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target]
 | 
			
		||||
        dx = abs( source_x - target_x )
 | 
			
		||||
        dy = abs( source_y - target_y )
 | 
			
		||||
        return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
 | 
			
		||||
 | 
			
		||||
    def shortest_path_length(self, target):
 | 
			
		||||
        try:
 | 
			
		||||
            return nx.shortest_path_length(self.global_topology, self.id, target)
 | 
			
		||||
            return nx.shortest_path_length(self.topology, self.id, target)
 | 
			
		||||
        except nx.NetworkXNoPath:
 | 
			
		||||
            return float('inf')
 | 
			
		||||
 
 | 
			
		||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							@@ -1,6 +1,5 @@
 | 
			
		||||
nxsim>=0.1.2
 | 
			
		||||
simpy
 | 
			
		||||
networkx>=2.0,<2.4
 | 
			
		||||
simpy>=4.0
 | 
			
		||||
networkx>=2.5
 | 
			
		||||
numpy
 | 
			
		||||
matplotlib
 | 
			
		||||
pyyaml>=5.1
 | 
			
		||||
 
 | 
			
		||||
@@ -1 +1 @@
 | 
			
		||||
0.14.9
 | 
			
		||||
0.15.1
 | 
			
		||||
@@ -17,12 +17,12 @@ from .environment import Environment
 | 
			
		||||
from .history import History
 | 
			
		||||
from . import serialization
 | 
			
		||||
from . import analysis
 | 
			
		||||
from .utils import logger
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    import argparse
 | 
			
		||||
    from . import simulation
 | 
			
		||||
 | 
			
		||||
    logging.basicConfig(level=logging.INFO)
 | 
			
		||||
    logging.info('Running SOIL version: {}'.format(__version__))
 | 
			
		||||
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Run a SOIL simulation')
 | 
			
		||||
@@ -40,6 +40,8 @@ def main():
 | 
			
		||||
                        help='Dump GEXF graph. Defaults to false.')
 | 
			
		||||
    parser.add_argument('--csv', action='store_true',
 | 
			
		||||
                        help='Dump history in CSV format. Defaults to false.')
 | 
			
		||||
    parser.add_argument('--level', type=str,
 | 
			
		||||
                        help='Logging level')
 | 
			
		||||
    parser.add_argument('--output', '-o', type=str, default="soil_output",
 | 
			
		||||
                        help='folder to write results to. It defaults to the current directory.')
 | 
			
		||||
    parser.add_argument('--synchronous', action='store_true',
 | 
			
		||||
@@ -48,6 +50,7 @@ def main():
 | 
			
		||||
                        help='Export environment and/or simulations using this exporter')
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper()))
 | 
			
		||||
 | 
			
		||||
    if os.getcwd() not in sys.path:
 | 
			
		||||
        sys.path.append(os.getcwd())
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ class BassModel(BaseAgent):
 | 
			
		||||
        imitation_prob
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, environment, agent_id, state):
 | 
			
		||||
    def __init__(self, environment, agent_id, state, **kwargs):
 | 
			
		||||
        super().__init__(environment=environment, agent_id=agent_id, state=state)
 | 
			
		||||
        env_params = environment.environment_params
 | 
			
		||||
        self.state['sentimentCorrelation'] = 0
 | 
			
		||||
@@ -19,7 +19,7 @@ class BassModel(BaseAgent):
 | 
			
		||||
 | 
			
		||||
    def behaviour(self):
 | 
			
		||||
        # Outside effects
 | 
			
		||||
        if random.random() < self.state_params['innovation_prob']:
 | 
			
		||||
        if random.random() < self['innovation_prob']:
 | 
			
		||||
            if self.state['id'] == 0:
 | 
			
		||||
                self.state['id'] = 1
 | 
			
		||||
                self.state['sentimentCorrelation'] = 1
 | 
			
		||||
@@ -32,7 +32,7 @@ class BassModel(BaseAgent):
 | 
			
		||||
        if self.state['id'] == 0:
 | 
			
		||||
            aware_neighbors = self.get_neighboring_agents(state_id=1)
 | 
			
		||||
            num_neighbors_aware = len(aware_neighbors)
 | 
			
		||||
            if random.random() < (self.state_params['imitation_prob']*num_neighbors_aware):
 | 
			
		||||
            if random.random() < (self['imitation_prob']*num_neighbors_aware):
 | 
			
		||||
                self.state['id'] = 1
 | 
			
		||||
                self.state['sentimentCorrelation'] = 1
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
from . import BaseAgent
 | 
			
		||||
from . import NetworkAgent
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CounterModel(BaseAgent):
 | 
			
		||||
class CounterModel(NetworkAgent):
 | 
			
		||||
    """
 | 
			
		||||
    Dummy behaviour. It counts the number of nodes in the network and neighbors
 | 
			
		||||
    in each step and adds it to its state.
 | 
			
		||||
@@ -9,14 +9,14 @@ class CounterModel(BaseAgent):
 | 
			
		||||
 | 
			
		||||
    def step(self):
 | 
			
		||||
        # Outside effects
 | 
			
		||||
        total = len(list(self.get_all_agents()))
 | 
			
		||||
        total = len(list(self.get_agents()))
 | 
			
		||||
        neighbors = len(list(self.get_neighboring_agents()))
 | 
			
		||||
        self['times'] = self.get('times', 0) + 1
 | 
			
		||||
        self['neighbors'] = neighbors
 | 
			
		||||
        self['total'] = total
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AggregatedCounter(BaseAgent):
 | 
			
		||||
class AggregatedCounter(NetworkAgent):
 | 
			
		||||
    """
 | 
			
		||||
    Dummy behaviour. It counts the number of nodes in the network and neighbors
 | 
			
		||||
    in each step and adds it to its state.
 | 
			
		||||
@@ -33,6 +33,6 @@ class AggregatedCounter(BaseAgent):
 | 
			
		||||
        self['times'] += 1
 | 
			
		||||
        neighbors = len(list(self.get_neighboring_agents()))
 | 
			
		||||
        self['neighbors'] += neighbors
 | 
			
		||||
        total = len(list(self.get_all_agents()))
 | 
			
		||||
        total = len(list(self.get_agents()))
 | 
			
		||||
        self['total'] += total
 | 
			
		||||
        self.debug('Running for step: {}. Total: {}'.format(self.now, total))
 | 
			
		||||
 
 | 
			
		||||
@@ -3,19 +3,19 @@
 | 
			
		||||
# for x in range(0, settings.network_params["number_of_nodes"]):
 | 
			
		||||
#     sentimentCorrelationNodeArray.append({'id': x})
 | 
			
		||||
# Initialize agent states. Let's assume everyone is normal.
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
import nxsim
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from collections import OrderedDict
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from functools import partial
 | 
			
		||||
from scipy.spatial import cKDTree as KDTree
 | 
			
		||||
import json
 | 
			
		||||
import simpy
 | 
			
		||||
 | 
			
		||||
from functools import wraps
 | 
			
		||||
 | 
			
		||||
from .. import serialization, history
 | 
			
		||||
from .. import serialization, history, utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def as_node(agent):
 | 
			
		||||
@@ -24,7 +24,7 @@ def as_node(agent):
 | 
			
		||||
    return agent
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaseAgent(nxsim.BaseAgent):
 | 
			
		||||
class BaseAgent:
 | 
			
		||||
    """
 | 
			
		||||
    A special simpy BaseAgent that keeps track of its state history.
 | 
			
		||||
    """
 | 
			
		||||
@@ -32,14 +32,13 @@ class BaseAgent(nxsim.BaseAgent):
 | 
			
		||||
    defaults = {}
 | 
			
		||||
 | 
			
		||||
    def __init__(self, environment, agent_id, state=None,
 | 
			
		||||
                 name=None, interval=None, **state_params):
 | 
			
		||||
                 name=None, interval=None):
 | 
			
		||||
        # Check for REQUIRED arguments
 | 
			
		||||
        assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
 | 
			
		||||
                                                  'Cannot be NoneType.')
 | 
			
		||||
        # Initialize agent parameters
 | 
			
		||||
        self.id = agent_id
 | 
			
		||||
        self.name = name or '{}[{}]'.format(type(self).__name__, self.id)
 | 
			
		||||
        self.state_params = state_params
 | 
			
		||||
 | 
			
		||||
        # Register agent to environment
 | 
			
		||||
        self.env = environment
 | 
			
		||||
@@ -51,10 +50,10 @@ class BaseAgent(nxsim.BaseAgent):
 | 
			
		||||
        self.state = real_state
 | 
			
		||||
        self.interval = interval
 | 
			
		||||
 | 
			
		||||
        if not hasattr(self, 'level'):
 | 
			
		||||
            self.level = logging.DEBUG
 | 
			
		||||
        self.logger = logging.getLogger(self.env.name)
 | 
			
		||||
        self.logger.setLevel(self.level)
 | 
			
		||||
        self.logger = logging.getLogger(self.env.name).getChild(self.name)
 | 
			
		||||
 | 
			
		||||
        if hasattr(self, 'level'):
 | 
			
		||||
            self.logger.setLevel(self.level)
 | 
			
		||||
 | 
			
		||||
        # initialize every time an instance of the agent is created
 | 
			
		||||
        self.action = self.env.process(self.run())
 | 
			
		||||
@@ -75,14 +74,10 @@ class BaseAgent(nxsim.BaseAgent):
 | 
			
		||||
        for k, v in value.items():
 | 
			
		||||
            self[k] = v
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def global_topology(self):
 | 
			
		||||
        return self.env.G
 | 
			
		||||
    
 | 
			
		||||
    @property
 | 
			
		||||
    def environment_params(self):
 | 
			
		||||
        return self.env.environment_params
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    @environment_params.setter
 | 
			
		||||
    def environment_params(self, value):
 | 
			
		||||
        self.env.environment_params = value
 | 
			
		||||
@@ -135,36 +130,10 @@ class BaseAgent(nxsim.BaseAgent):
 | 
			
		||||
    def die(self, remove=False):
 | 
			
		||||
        self.alive = False
 | 
			
		||||
        if remove:
 | 
			
		||||
            super().die()
 | 
			
		||||
            self.remove_node(self.id)
 | 
			
		||||
 | 
			
		||||
    def step(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def count_agents(self, **kwargs):
 | 
			
		||||
        return len(list(self.get_agents(**kwargs)))
 | 
			
		||||
 | 
			
		||||
    def count_neighboring_agents(self, state_id=None, **kwargs):
 | 
			
		||||
        return len(super().get_neighboring_agents(state_id=state_id, **kwargs))
 | 
			
		||||
 | 
			
		||||
    def get_neighboring_agents(self, state_id=None, **kwargs):
 | 
			
		||||
        return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
 | 
			
		||||
        if limit_neighbors:
 | 
			
		||||
            agents = super().get_agents(limit_neighbors=limit_neighbors)
 | 
			
		||||
        else:
 | 
			
		||||
            agents = self.env.get_agents(agents)
 | 
			
		||||
        return select(agents, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def log(self, message, *args, level=logging.INFO, **kwargs):
 | 
			
		||||
        message = message + " ".join(str(i) for i in args)
 | 
			
		||||
        message = "\t{:10}@{:>5}:\t{}".format(self.name, self.now, message)
 | 
			
		||||
        for k, v in kwargs:
 | 
			
		||||
            message += " {k}={v} ".format(k, v)
 | 
			
		||||
        extra = {}
 | 
			
		||||
        extra['now'] = self.now
 | 
			
		||||
        extra['id'] = self.id
 | 
			
		||||
        return self.logger.log(level, message, extra=extra)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    def debug(self, *args, **kwargs):
 | 
			
		||||
        return self.log(*args, level=logging.DEBUG, **kwargs)
 | 
			
		||||
@@ -192,24 +161,59 @@ class BaseAgent(nxsim.BaseAgent):
 | 
			
		||||
        self._state = state['_state']
 | 
			
		||||
        self.env = state['environment']
 | 
			
		||||
 | 
			
		||||
    def add_edge(self, node1, node2, **attrs):
 | 
			
		||||
        node1 = as_node(node1)
 | 
			
		||||
        node2 = as_node(node2)
 | 
			
		||||
class NetworkAgent(BaseAgent):
 | 
			
		||||
 | 
			
		||||
        for n in [node1, node2]:
 | 
			
		||||
            if n not in self.global_topology.nodes(data=False):
 | 
			
		||||
                raise ValueError('"{}" not in the graph'.format(n))
 | 
			
		||||
        return self.global_topology.add_edge(node1, node2, **attrs)
 | 
			
		||||
    @property
 | 
			
		||||
    def topology(self):
 | 
			
		||||
        return self.env.G
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def G(self):
 | 
			
		||||
        return self.env.G
 | 
			
		||||
 | 
			
		||||
    def count_agents(self, **kwargs):
 | 
			
		||||
        return len(list(self.get_agents(**kwargs)))
 | 
			
		||||
 | 
			
		||||
    def count_neighboring_agents(self, state_id=None, **kwargs):
 | 
			
		||||
        return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
 | 
			
		||||
 | 
			
		||||
    def get_neighboring_agents(self, state_id=None, **kwargs):
 | 
			
		||||
        return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
 | 
			
		||||
        if limit_neighbors:
 | 
			
		||||
            agents = self.topology.neighbors(self.id)
 | 
			
		||||
 | 
			
		||||
        agents = self.env.get_agents(agents)
 | 
			
		||||
        return select(agents, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def log(self, message, *args, level=logging.INFO, **kwargs):
 | 
			
		||||
        message = message + " ".join(str(i) for i in args)
 | 
			
		||||
        message = " @{:>3}: {}".format(self.now, message)
 | 
			
		||||
        for k, v in kwargs:
 | 
			
		||||
            message += " {k}={v} ".format(k, v)
 | 
			
		||||
        extra = {}
 | 
			
		||||
        extra['now'] = self.now
 | 
			
		||||
        extra['agent_id'] = self.id
 | 
			
		||||
        extra['agent_name'] = self.name
 | 
			
		||||
        return self.logger.log(level, message, extra=extra)
 | 
			
		||||
 | 
			
		||||
    def subgraph(self, center=True, **kwargs):
 | 
			
		||||
        include = [self] if center else []
 | 
			
		||||
        return self.global_topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
 | 
			
		||||
        return self.topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
 | 
			
		||||
 | 
			
		||||
    def remove_node(self, agent_id):
 | 
			
		||||
        self.topology.remove_node(agent_id)
 | 
			
		||||
 | 
			
		||||
class NetworkAgent(BaseAgent):
 | 
			
		||||
    def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
 | 
			
		||||
        # return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
 | 
			
		||||
        if self.id not in self.topology.nodes(data=False):
 | 
			
		||||
            raise ValueError('{} not in list of existing agents in the network'.format(self.id))
 | 
			
		||||
        if other not in self.topology.nodes(data=False):
 | 
			
		||||
            raise ValueError('{} not in list of existing agents in the network'.format(other))
 | 
			
		||||
 | 
			
		||||
        self.topology.add_edge(self.id, other, edge_attr_dict=edge_attr_dict, *edge_attrs)
 | 
			
		||||
 | 
			
		||||
    def add_edge(self, other, **kwargs):
 | 
			
		||||
        return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def ego_search(self, steps=1, center=False, node=None, **kwargs):
 | 
			
		||||
        '''Get a list of nodes in the ego network of *node* of radius *steps*'''
 | 
			
		||||
@@ -220,14 +224,14 @@ class NetworkAgent(BaseAgent):
 | 
			
		||||
    def degree(self, node, force=False):
 | 
			
		||||
        node = as_node(node)
 | 
			
		||||
        if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now:
 | 
			
		||||
            self.env._degree = nx.degree_centrality(self.global_topology)
 | 
			
		||||
            self.env._degree = nx.degree_centrality(self.topology)
 | 
			
		||||
            self.env._last_step = self.now
 | 
			
		||||
        return self.env._degree[node]
 | 
			
		||||
 | 
			
		||||
    def betweenness(self, node, force=False):
 | 
			
		||||
        node = as_node(node)
 | 
			
		||||
        if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now:
 | 
			
		||||
            self.env._betweenness = nx.betweenness_centrality(self.global_topology)
 | 
			
		||||
            self.env._betweenness = nx.betweenness_centrality(self.topology)
 | 
			
		||||
            self.env._last_step = self.now
 | 
			
		||||
        return self.env._betweenness[node]
 | 
			
		||||
 | 
			
		||||
@@ -292,16 +296,22 @@ class MetaFSM(type):
 | 
			
		||||
        cls.states = states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FSM(BaseAgent, metaclass=MetaFSM):
 | 
			
		||||
class FSM(NetworkAgent, metaclass=MetaFSM):
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super(FSM, self).__init__(*args, **kwargs)
 | 
			
		||||
        if 'id' not in self.state:
 | 
			
		||||
            if not self.default_state:
 | 
			
		||||
                raise ValueError('No default state specified for {}'.format(self.id))
 | 
			
		||||
            self['id'] = self.default_state.id
 | 
			
		||||
        self._next_change = simpy.core.Infinity
 | 
			
		||||
        self._next_state = self.state
 | 
			
		||||
 | 
			
		||||
    def step(self):
 | 
			
		||||
        if 'id' in self.state:
 | 
			
		||||
        if self._next_change < self.now:
 | 
			
		||||
            next_state = self._next_state
 | 
			
		||||
            self._next_change = simpy.core.Infinity
 | 
			
		||||
            self['id'] = next_state
 | 
			
		||||
        elif 'id' in self.state:
 | 
			
		||||
            next_state = self['id']
 | 
			
		||||
        elif self.default_state:
 | 
			
		||||
            next_state = self.default_state.id
 | 
			
		||||
@@ -311,6 +321,10 @@ class FSM(BaseAgent, metaclass=MetaFSM):
 | 
			
		||||
            raise Exception('{} is not a valid id for {}'.format(next_state, self))
 | 
			
		||||
        return self.states[next_state](self)
 | 
			
		||||
 | 
			
		||||
    def next_state(self, state):
 | 
			
		||||
        self._next_change = self.now
 | 
			
		||||
        self._next_state = state
 | 
			
		||||
 | 
			
		||||
    def set_state(self, state):
 | 
			
		||||
        if hasattr(state, 'id'):
 | 
			
		||||
            state = state.id
 | 
			
		||||
@@ -371,14 +385,18 @@ def calculate_distribution(network_agents=None,
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError('Specify a distribution or a default agent type')
 | 
			
		||||
 | 
			
		||||
    # Fix missing weights and incompatible types
 | 
			
		||||
    for x in network_agents:
 | 
			
		||||
        x['weight'] = float(x.get('weight', 1))
 | 
			
		||||
 | 
			
		||||
    # Calculate the thresholds
 | 
			
		||||
    total = sum(x.get('weight', 1) for x in network_agents)
 | 
			
		||||
    total = sum(x['weight'] for x in network_agents)
 | 
			
		||||
    acc = 0
 | 
			
		||||
    for v in network_agents:
 | 
			
		||||
        if 'ids' in v:
 | 
			
		||||
            v['threshold'] = STATIC_THRESHOLD
 | 
			
		||||
            continue
 | 
			
		||||
        upper = acc + (v.get('weight', 1)/total)
 | 
			
		||||
        upper = acc + (v['weight']/total)
 | 
			
		||||
        v['threshold'] = [acc, upper]
 | 
			
		||||
        acc = upper
 | 
			
		||||
    return network_agents
 | 
			
		||||
@@ -425,7 +443,7 @@ def _validate_states(states, topology):
 | 
			
		||||
    states = states or []
 | 
			
		||||
    if isinstance(states, dict):
 | 
			
		||||
        for x in states:
 | 
			
		||||
            assert x in topology.node
 | 
			
		||||
            assert x in topology.nodes
 | 
			
		||||
    else:
 | 
			
		||||
        assert len(states) <= len(topology)
 | 
			
		||||
    return states
 | 
			
		||||
 
 | 
			
		||||
@@ -28,13 +28,13 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
 | 
			
		||||
                df = read_csv(trial_data, **kwargs)
 | 
			
		||||
                yield config_file, df, config
 | 
			
		||||
        else:
 | 
			
		||||
            for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
 | 
			
		||||
            for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))):
 | 
			
		||||
                df = read_sql(trial_data, **kwargs)
 | 
			
		||||
                yield config_file, df, config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_sql(db, *args, **kwargs):
 | 
			
		||||
    h = history.History(db_path=db, backup=False)
 | 
			
		||||
    h = history.History(db_path=db, backup=False, readonly=True)
 | 
			
		||||
    df = h.read_sql(*args, **kwargs)
 | 
			
		||||
    return df
 | 
			
		||||
 | 
			
		||||
@@ -69,6 +69,13 @@ def convert_types_slow(df):
 | 
			
		||||
    df = df.apply(convert_row, axis=1)
 | 
			
		||||
    return df
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_processed(df):
 | 
			
		||||
    env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])]
 | 
			
		||||
    agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])]
 | 
			
		||||
    return env, agents
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def split_df(df):
 | 
			
		||||
    '''
 | 
			
		||||
    Split a dataframe in two dataframes: one with the history of agents,
 | 
			
		||||
@@ -136,7 +143,7 @@ def get_value(df, *keys, aggfunc='sum'):
 | 
			
		||||
    return df.groupby(axis=1, level=0).agg(aggfunc)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_all(*args, **kwargs):
 | 
			
		||||
def plot_all(*args, plot_args={}, **kwargs):
 | 
			
		||||
    '''
 | 
			
		||||
    Read all the trial data and plot the result of applying a function on them.
 | 
			
		||||
    '''
 | 
			
		||||
@@ -144,14 +151,17 @@ def plot_all(*args, **kwargs):
 | 
			
		||||
    ps = []
 | 
			
		||||
    for line in dfs:
 | 
			
		||||
        f, df, config = line
 | 
			
		||||
        df.plot(title=config['name'])
 | 
			
		||||
        if len(df) < 1:
 | 
			
		||||
            continue
 | 
			
		||||
        df.plot(title=config['name'], **plot_args)
 | 
			
		||||
        ps.append(df)
 | 
			
		||||
    return ps
 | 
			
		||||
 | 
			
		||||
def do_all(pattern, func, *keys, include_env=False, **kwargs):
 | 
			
		||||
    for config_file, df, config in read_data(pattern, keys=keys):
 | 
			
		||||
        if len(df) < 1:
 | 
			
		||||
            continue
 | 
			
		||||
        p = func(df, *keys, **kwargs)
 | 
			
		||||
        p.plot(title=config['name'])
 | 
			
		||||
        yield config_file, p, config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,11 +8,10 @@ import yaml
 | 
			
		||||
import tempfile
 | 
			
		||||
import pandas as pd
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from collections import Counter
 | 
			
		||||
from networkx.readwrite import json_graph
 | 
			
		||||
 | 
			
		||||
import networkx as nx
 | 
			
		||||
import nxsim
 | 
			
		||||
import simpy
 | 
			
		||||
 | 
			
		||||
from . import serialization, agents, analysis, history, utils
 | 
			
		||||
 | 
			
		||||
@@ -23,7 +22,7 @@ _CONFIG_PROPS = [ 'name',
 | 
			
		||||
                 'interval',
 | 
			
		||||
                 ]
 | 
			
		||||
 | 
			
		||||
class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
class Environment(simpy.Environment):
 | 
			
		||||
    """
 | 
			
		||||
    The environment is key in a simulation. It contains the network topology,
 | 
			
		||||
    a reference to network and environment agents, as well as the environment
 | 
			
		||||
@@ -42,7 +41,10 @@ class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
                 interval=1,
 | 
			
		||||
                 seed=None,
 | 
			
		||||
                 topology=None,
 | 
			
		||||
                 *args, **kwargs):
 | 
			
		||||
                 initial_time=0,
 | 
			
		||||
                 **environment_params):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        self.name = name or 'UnnamedEnvironment'
 | 
			
		||||
        seed = seed or time.time()
 | 
			
		||||
        random.seed(seed)
 | 
			
		||||
@@ -52,7 +54,11 @@ class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
        self.default_state = deepcopy(default_state) or {}
 | 
			
		||||
        if not topology:
 | 
			
		||||
            topology = nx.Graph()
 | 
			
		||||
        super().__init__(*args, topology=topology, **kwargs)
 | 
			
		||||
        self.G = nx.Graph(topology) 
 | 
			
		||||
 | 
			
		||||
        super().__init__(initial_time=initial_time)
 | 
			
		||||
        self.environment_params = environment_params
 | 
			
		||||
 | 
			
		||||
        self._env_agents = {}
 | 
			
		||||
        self.interval = interval
 | 
			
		||||
        self._history = history.History(name=self.name,
 | 
			
		||||
@@ -151,12 +157,10 @@ class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
        start = start or self.now
 | 
			
		||||
        return self.G.add_edge(agent1, agent2, **attrs)
 | 
			
		||||
 | 
			
		||||
    def run(self, *args, **kwargs):
 | 
			
		||||
    def run(self, until, *args, **kwargs):
 | 
			
		||||
        self._save_state()
 | 
			
		||||
        self.log_stats()
 | 
			
		||||
        super().run(*args, **kwargs)
 | 
			
		||||
        super().run(until, *args, **kwargs)
 | 
			
		||||
        self._history.flush_cache()
 | 
			
		||||
        self.log_stats()
 | 
			
		||||
 | 
			
		||||
    def _save_state(self, now=None):
 | 
			
		||||
        serialization.logger.debug('Saving state @{}'.format(self.now))
 | 
			
		||||
@@ -318,25 +322,6 @@ class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
 | 
			
		||||
        return G
 | 
			
		||||
 | 
			
		||||
    def stats(self):
 | 
			
		||||
        stats = {}
 | 
			
		||||
        stats['network'] = {}
 | 
			
		||||
        stats['network']['n_nodes'] = self.G.number_of_nodes()
 | 
			
		||||
        stats['network']['n_edges'] = self.G.number_of_edges()
 | 
			
		||||
        c = Counter()
 | 
			
		||||
        c.update(a.__class__.__name__ for a in self.network_agents)
 | 
			
		||||
        stats['agents'] = {}
 | 
			
		||||
        stats['agents']['model_count'] = dict(c)
 | 
			
		||||
        c2 = Counter()
 | 
			
		||||
        c2.update(a['id'] for a in self.network_agents)
 | 
			
		||||
        stats['agents']['state_count'] = dict(c2)
 | 
			
		||||
        stats['params'] = self.environment_params
 | 
			
		||||
        return stats
 | 
			
		||||
 | 
			
		||||
    def log_stats(self):
 | 
			
		||||
        stats = self.stats()
 | 
			
		||||
        serialization.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
 | 
			
		||||
    
 | 
			
		||||
    def __getstate__(self):
 | 
			
		||||
        state = {}
 | 
			
		||||
        for prop in _CONFIG_PROPS:
 | 
			
		||||
@@ -344,6 +329,7 @@ class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
        state['G'] = json_graph.node_link_data(self.G)
 | 
			
		||||
        state['environment_agents'] = self._env_agents
 | 
			
		||||
        state['history'] = self._history
 | 
			
		||||
        state['_now'] = self._now
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
    def __setstate__(self, state):
 | 
			
		||||
@@ -352,6 +338,8 @@ class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
        self._env_agents = state['environment_agents']
 | 
			
		||||
        self.G = json_graph.node_link_graph(state['G'])
 | 
			
		||||
        self._history = state['history']
 | 
			
		||||
        self._now = state['_now']
 | 
			
		||||
        self._queue = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
SoilEnvironment = Environment
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,11 @@
 | 
			
		||||
import os
 | 
			
		||||
import csv as csvlib
 | 
			
		||||
import time
 | 
			
		||||
from io import BytesIO
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
import networkx as nx
 | 
			
		||||
import pandas as pd
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from .serialization import deserialize
 | 
			
		||||
from .utils import open_or_reuse, logger, timer
 | 
			
		||||
@@ -49,7 +50,7 @@ class Exporter:
 | 
			
		||||
    '''
 | 
			
		||||
 | 
			
		||||
    def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
 | 
			
		||||
        self.sim = simulation
 | 
			
		||||
        self.simulation = simulation
 | 
			
		||||
        outdir = outdir or os.path.join(os.getcwd(), 'soil_output')
 | 
			
		||||
        self.outdir = os.path.join(outdir,
 | 
			
		||||
                                   simulation.group or '',
 | 
			
		||||
@@ -59,12 +60,15 @@ class Exporter:
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        '''Method to call when the simulation starts'''
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def end(self):
 | 
			
		||||
    def end(self, stats):
 | 
			
		||||
        '''Method to call when the simulation ends'''
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
    def trial(self, env, stats):
 | 
			
		||||
        '''Method to call when a trial ends'''
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def output(self, f, mode='w', **kwargs):
 | 
			
		||||
        if self.dry_run:
 | 
			
		||||
@@ -84,13 +88,13 @@ class default(Exporter):
 | 
			
		||||
    def start(self):
 | 
			
		||||
        if not self.dry_run:
 | 
			
		||||
            logger.info('Dumping results to %s', self.outdir)
 | 
			
		||||
            self.sim.dump_yaml(outdir=self.outdir)
 | 
			
		||||
            self.simulation.dump_yaml(outdir=self.outdir)
 | 
			
		||||
        else:
 | 
			
		||||
            logger.info('NOT dumping results')
 | 
			
		||||
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
    def trial(self, env, stats):
 | 
			
		||||
        if not self.dry_run:
 | 
			
		||||
            with timer('Dumping simulation {} trial {}'.format(self.sim.name,
 | 
			
		||||
            with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
 | 
			
		||||
                                                               env.name)):
 | 
			
		||||
                with self.output('{}.sqlite'.format(env.name), mode='wb') as f:
 | 
			
		||||
                    env.dump_sqlite(f)
 | 
			
		||||
@@ -98,21 +102,27 @@ class default(Exporter):
 | 
			
		||||
 | 
			
		||||
class csv(Exporter):
 | 
			
		||||
    '''Export the state of each environment (and its agents) in a separate CSV file'''
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
        with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.sim.name,
 | 
			
		||||
    def trial(self, env, stats):
 | 
			
		||||
        with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
 | 
			
		||||
                                                                          env.name,
 | 
			
		||||
                                                                          self.outdir)):
 | 
			
		||||
            with self.output('{}.csv'.format(env.name)) as f:
 | 
			
		||||
                env.dump_csv(f)
 | 
			
		||||
 | 
			
		||||
            with self.output('{}.stats.csv'.format(env.name)) as f:
 | 
			
		||||
                statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL)
 | 
			
		||||
 | 
			
		||||
                for stat in stats:
 | 
			
		||||
                    statwriter.writerow(stat)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class gexf(Exporter):
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
    def trial(self, env, stats):
 | 
			
		||||
        if self.dry_run:
 | 
			
		||||
            logger.info('Not dumping GEXF in dry_run mode')
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        with timer('[GEXF] Dumping simulation {} trial {}'.format(self.sim.name,
 | 
			
		||||
        with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
 | 
			
		||||
                                                                  env.name)):
 | 
			
		||||
            with self.output('{}.gexf'.format(env.name), mode='wb') as f:
 | 
			
		||||
                env.dump_gexf(f)
 | 
			
		||||
@@ -124,56 +134,24 @@ class dummy(Exporter):
 | 
			
		||||
        with self.output('dummy', 'w') as f:
 | 
			
		||||
            f.write('simulation started @ {}\n'.format(time.time()))
 | 
			
		||||
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
    def trial(self, env, stats):
 | 
			
		||||
        with self.output('dummy', 'w') as f:
 | 
			
		||||
            for i in env.history_to_tuples():
 | 
			
		||||
                f.write(','.join(map(str, i)))
 | 
			
		||||
                f.write('\n')
 | 
			
		||||
 | 
			
		||||
    def end(self):
 | 
			
		||||
    def sim(self, stats):
 | 
			
		||||
        with self.output('dummy', 'a') as f:
 | 
			
		||||
            f.write('simulation ended @ {}\n'.format(time.time()))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class distribution(Exporter):
 | 
			
		||||
    '''
 | 
			
		||||
    Write the distribution of agent states at the end of each trial,
 | 
			
		||||
    the mean value, and its deviation.
 | 
			
		||||
    '''
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        self.means = []
 | 
			
		||||
        self.counts = []
 | 
			
		||||
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
        df = env[None, None, None].df()
 | 
			
		||||
        ix = df.index[-1]
 | 
			
		||||
        attrs = df.columns.levels[0]
 | 
			
		||||
        vc = {}
 | 
			
		||||
        stats = {}
 | 
			
		||||
        for a in attrs:
 | 
			
		||||
            t = df.loc[(ix, a)]
 | 
			
		||||
            try:
 | 
			
		||||
                self.means.append(('mean', a, t.mean()))
 | 
			
		||||
            except TypeError:
 | 
			
		||||
                for name, count in t.value_counts().iteritems():
 | 
			
		||||
                    self.counts.append(('count', a, name, count))
 | 
			
		||||
 | 
			
		||||
    def end(self):
 | 
			
		||||
        dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
 | 
			
		||||
        dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
 | 
			
		||||
        dfm = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
 | 
			
		||||
        dfc = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
 | 
			
		||||
        with self.output('counts.csv') as f:
 | 
			
		||||
            dfc.to_csv(f)
 | 
			
		||||
        with self.output('metrics.csv') as f:
 | 
			
		||||
            dfm.to_csv(f)
 | 
			
		||||
 | 
			
		||||
class graphdrawing(Exporter):
 | 
			
		||||
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
    def trial(self, env, stats):
 | 
			
		||||
        # Outside effects
 | 
			
		||||
        f = plt.figure()
 | 
			
		||||
        nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
 | 
			
		||||
        with open('graph-{}.png'.format(env.name)) as f:
 | 
			
		||||
            f.savefig(f)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										142
									
								
								soil/history.py
									
									
									
									
									
								
							
							
						
						
									
										142
									
								
								soil/history.py
									
									
									
									
									
								
							@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 | 
			
		||||
from collections import UserDict, namedtuple
 | 
			
		||||
 | 
			
		||||
from . import serialization
 | 
			
		||||
from .utils import open_or_reuse
 | 
			
		||||
from .utils import open_or_reuse, unflatten_dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class History:
 | 
			
		||||
@@ -19,29 +19,43 @@ class History:
 | 
			
		||||
    Store and retrieve values from a sqlite database.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name=None, db_path=None, backup=False):
 | 
			
		||||
        self._db = None
 | 
			
		||||
    def __init__(self, name=None, db_path=None, backup=False, readonly=False):
 | 
			
		||||
        if readonly and (not os.path.exists(db_path)):
 | 
			
		||||
            raise Exception('The DB file does not exist. Cannot open in read-only mode')
 | 
			
		||||
 | 
			
		||||
        if db_path is None:
 | 
			
		||||
        self._db = None
 | 
			
		||||
        self._temp = db_path is None
 | 
			
		||||
        self._stats_columns = None
 | 
			
		||||
        self.readonly = readonly
 | 
			
		||||
 | 
			
		||||
        if self._temp:
 | 
			
		||||
            if not name:
 | 
			
		||||
                name = time.time()
 | 
			
		||||
            _, db_path = tempfile.mkstemp(suffix='{}.sqlite'.format(name))
 | 
			
		||||
            # The file will be deleted as soon as it's closed
 | 
			
		||||
            # Normally, that will be on destruction
 | 
			
		||||
            db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if backup and os.path.exists(db_path):
 | 
			
		||||
            newname = db_path + '.backup{}.sqlite'.format(time.time())
 | 
			
		||||
            os.rename(db_path, newname)
 | 
			
		||||
                newname = db_path + '.backup{}.sqlite'.format(time.time())
 | 
			
		||||
                os.rename(db_path, newname)
 | 
			
		||||
 | 
			
		||||
        self.db_path = db_path
 | 
			
		||||
 | 
			
		||||
        self.db = db_path
 | 
			
		||||
        self._dtypes = {}
 | 
			
		||||
        self._tups = []
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        if self.readonly:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        with self.db:
 | 
			
		||||
            logger.debug('Creating database {}'.format(self.db_path))
 | 
			
		||||
            self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
 | 
			
		||||
            self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text)''')
 | 
			
		||||
            self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
 | 
			
		||||
            self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
 | 
			
		||||
            self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
 | 
			
		||||
        self._dtypes = {}
 | 
			
		||||
        self._tups = []
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def db(self):
 | 
			
		||||
@@ -58,6 +72,7 @@ class History:
 | 
			
		||||
        if isinstance(db_path, str):
 | 
			
		||||
            logger.debug('Connecting to database {}'.format(db_path))
 | 
			
		||||
            self._db = sqlite3.connect(db_path)
 | 
			
		||||
            self._db.row_factory = sqlite3.Row
 | 
			
		||||
        else:
 | 
			
		||||
            self._db = db_path
 | 
			
		||||
 | 
			
		||||
@@ -68,9 +83,56 @@ class History:
 | 
			
		||||
        self._db.close()
 | 
			
		||||
        self._db = None
 | 
			
		||||
 | 
			
		||||
    def save_stats(self, stat):
 | 
			
		||||
        if self.readonly:
 | 
			
		||||
            print('DB in readonly mode')
 | 
			
		||||
            return
 | 
			
		||||
        if not stat:
 | 
			
		||||
            return
 | 
			
		||||
        with self.db:
 | 
			
		||||
            if not self._stats_columns:
 | 
			
		||||
                self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
 | 
			
		||||
 | 
			
		||||
            for column, value in stat.items():
 | 
			
		||||
                if column in self._stats_columns:
 | 
			
		||||
                    continue
 | 
			
		||||
                dtype = 'text'
 | 
			
		||||
                if not isinstance(value, str):
 | 
			
		||||
                    try:
 | 
			
		||||
                        float(value)
 | 
			
		||||
                        dtype = 'real'
 | 
			
		||||
                        int(value)
 | 
			
		||||
                        dtype = 'int'
 | 
			
		||||
                    except ValueError:
 | 
			
		||||
                        pass
 | 
			
		||||
                self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
 | 
			
		||||
                self._stats_columns.append(column)
 | 
			
		||||
 | 
			
		||||
            columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
 | 
			
		||||
            values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
 | 
			
		||||
            query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
 | 
			
		||||
                columns=columns,
 | 
			
		||||
                values=values
 | 
			
		||||
            )
 | 
			
		||||
            self.db.execute(query)
 | 
			
		||||
 | 
			
		||||
    def get_stats(self, unflatten=True):
 | 
			
		||||
        rows = self.db.execute("select * from stats").fetchall()
 | 
			
		||||
        res = []
 | 
			
		||||
        for row in rows:
 | 
			
		||||
            d = {}
 | 
			
		||||
            for k in row.keys():
 | 
			
		||||
                if row[k] is None:
 | 
			
		||||
                    continue
 | 
			
		||||
                d[k] = row[k]
 | 
			
		||||
            if unflatten:
 | 
			
		||||
                d = unflatten_dict(d)
 | 
			
		||||
            res.append(d)
 | 
			
		||||
        return res
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dtypes(self):
 | 
			
		||||
        self.read_types()
 | 
			
		||||
        self._read_types()
 | 
			
		||||
        return {k:v[0] for k, v in self._dtypes.items()}
 | 
			
		||||
 | 
			
		||||
    def save_tuples(self, tuples):
 | 
			
		||||
@@ -93,18 +155,10 @@ class History:
 | 
			
		||||
        Save a collection of records to the database.
 | 
			
		||||
        Database writes are cached.
 | 
			
		||||
        '''
 | 
			
		||||
        value = self.convert(key, value)
 | 
			
		||||
        self._tups.append(Record(agent_id=agent_id,
 | 
			
		||||
                                 t_step=t_step,
 | 
			
		||||
                                 key=key,
 | 
			
		||||
                                 value=value))
 | 
			
		||||
        if len(self._tups) > 100:
 | 
			
		||||
            self.flush_cache()
 | 
			
		||||
 | 
			
		||||
    def convert(self, key, value):
 | 
			
		||||
        """Get the serialized value for a given key."""
 | 
			
		||||
        if self.readonly:
 | 
			
		||||
            raise Exception('DB in readonly mode')
 | 
			
		||||
        if key not in self._dtypes:
 | 
			
		||||
            self.read_types()
 | 
			
		||||
            self._read_types()
 | 
			
		||||
            if key not in self._dtypes:
 | 
			
		||||
                name = serialization.name(value)
 | 
			
		||||
                serializer = serialization.serializer(name)
 | 
			
		||||
@@ -112,21 +166,21 @@ class History:
 | 
			
		||||
                self._dtypes[key] = (name, serializer, deserializer)
 | 
			
		||||
                with self.db:
 | 
			
		||||
                    self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
 | 
			
		||||
        return self._dtypes[key][1](value)
 | 
			
		||||
 | 
			
		||||
    def recover(self, key, value):
 | 
			
		||||
        """Get the deserialized value for a given key, and the serialized version."""
 | 
			
		||||
        if key not in self._dtypes:
 | 
			
		||||
            self.read_types()
 | 
			
		||||
        if key not in self._dtypes:
 | 
			
		||||
            raise ValueError("Unknown datatype for {} and {}".format(key, value))
 | 
			
		||||
        return self._dtypes[key][2](value)
 | 
			
		||||
        value = self._dtypes[key][1](value)
 | 
			
		||||
        self._tups.append(Record(agent_id=agent_id,
 | 
			
		||||
                                 t_step=t_step,
 | 
			
		||||
                                 key=key,
 | 
			
		||||
                                 value=value))
 | 
			
		||||
        if len(self._tups) > 100:
 | 
			
		||||
            self.flush_cache()
 | 
			
		||||
 | 
			
		||||
    def flush_cache(self):
 | 
			
		||||
        '''
 | 
			
		||||
        Use a cache to save state changes to avoid opening a session for every change.
 | 
			
		||||
        The cache will be flushed at the end of the simulation, and when history is accessed.
 | 
			
		||||
        '''
 | 
			
		||||
        if self.readonly:
 | 
			
		||||
            raise Exception('DB in readonly mode')
 | 
			
		||||
        logger.debug('Flushing cache {}'.format(self.db_path))
 | 
			
		||||
        with self.db:
 | 
			
		||||
            for rec in self._tups:
 | 
			
		||||
@@ -139,10 +193,14 @@ class History:
 | 
			
		||||
            res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
 | 
			
		||||
        for r in res:
 | 
			
		||||
            agent_id, t_step, key, value = r
 | 
			
		||||
            value = self.recover(key, value)
 | 
			
		||||
            if key not in self._dtypes:
 | 
			
		||||
                self._read_types()
 | 
			
		||||
            if key not in self._dtypes:
 | 
			
		||||
                raise ValueError("Unknown datatype for {} and {}".format(key, value))
 | 
			
		||||
            value = self._dtypes[key][2](value)
 | 
			
		||||
            yield agent_id, t_step, key, value
 | 
			
		||||
 | 
			
		||||
    def read_types(self):
 | 
			
		||||
    def _read_types(self):
 | 
			
		||||
        with self.db:
 | 
			
		||||
            res = self.db.execute("select key, value_type from value_types ").fetchall()
 | 
			
		||||
        for k, v in res:
 | 
			
		||||
@@ -167,7 +225,7 @@ class History:
 | 
			
		||||
 | 
			
		||||
    def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
 | 
			
		||||
 | 
			
		||||
        self.read_types()
 | 
			
		||||
        self._read_types()
 | 
			
		||||
 | 
			
		||||
        def escape_and_join(v):
 | 
			
		||||
            if v is None:
 | 
			
		||||
@@ -181,7 +239,13 @@ class History:
 | 
			
		||||
 | 
			
		||||
        last_df = None
 | 
			
		||||
        if t_steps:
 | 
			
		||||
            # Look for the last value before the minimum step in the query
 | 
			
		||||
            # Convert negative indices into positive
 | 
			
		||||
            if any(x<0 for x in t_steps):
 | 
			
		||||
                max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
 | 
			
		||||
                t_steps = [t if t>0 else max_t+1+t for t in t_steps]
 | 
			
		||||
 | 
			
		||||
            # We will be doing ffill interpolation, so we need to look for
 | 
			
		||||
            # the last value before the minimum step in the query
 | 
			
		||||
            min_step = min(t_steps)
 | 
			
		||||
            last_filters = ['t_step < {}'.format(min_step),]
 | 
			
		||||
            last_filters = last_filters + filters
 | 
			
		||||
@@ -219,7 +283,11 @@ class History:
 | 
			
		||||
        for k, v in self._dtypes.items():
 | 
			
		||||
            if k in df_p:
 | 
			
		||||
                dtype, _, deserial = v
 | 
			
		||||
                df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
 | 
			
		||||
                try:
 | 
			
		||||
                    df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
 | 
			
		||||
                except (TypeError, ValueError):
 | 
			
		||||
                    # Avoid forward-filling unknown/incompatible types
 | 
			
		||||
                    continue
 | 
			
		||||
        if t_steps:
 | 
			
		||||
            df_p = df_p.reindex(t_steps, method='ffill')
 | 
			
		||||
        return df_p.ffill()
 | 
			
		||||
@@ -313,3 +381,5 @@ class Records():
 | 
			
		||||
 | 
			
		||||
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
 | 
			
		||||
Record = namedtuple('Record', 'agent_id t_step key value')
 | 
			
		||||
 | 
			
		||||
Stat = namedtuple('Stat', 'trial_id')
 | 
			
		||||
 
 | 
			
		||||
@@ -17,10 +17,10 @@ logger.setLevel(logging.INFO)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_network(network_params, dir_path=None):
 | 
			
		||||
    if network_params is None:
 | 
			
		||||
        return nx.Graph()
 | 
			
		||||
    path = network_params.get('path', None)
 | 
			
		||||
    if path:
 | 
			
		||||
    G = nx.Graph()
 | 
			
		||||
 | 
			
		||||
    if 'path' in network_params:
 | 
			
		||||
        path = network_params['path']
 | 
			
		||||
        if dir_path and not os.path.isabs(path):
 | 
			
		||||
            path = os.path.join(dir_path, path)
 | 
			
		||||
        extension = os.path.splitext(path)[1][1:]
 | 
			
		||||
@@ -32,21 +32,22 @@ def load_network(network_params, dir_path=None):
 | 
			
		||||
            method = getattr(nx.readwrite, 'read_' + extension)
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            raise AttributeError('Unknown format')
 | 
			
		||||
        return method(path, **kwargs)
 | 
			
		||||
        G = method(path, **kwargs)
 | 
			
		||||
 | 
			
		||||
    net_args = network_params.copy()
 | 
			
		||||
    if 'generator' not in net_args:
 | 
			
		||||
        return nx.Graph()
 | 
			
		||||
    elif 'generator' in network_params:
 | 
			
		||||
        net_args = network_params.copy()
 | 
			
		||||
        net_gen = net_args.pop('generator')
 | 
			
		||||
 | 
			
		||||
    net_gen = net_args.pop('generator')
 | 
			
		||||
        if dir_path not in sys.path:
 | 
			
		||||
            sys.path.append(dir_path)
 | 
			
		||||
 | 
			
		||||
    if dir_path not in sys.path:
 | 
			
		||||
        sys.path.append(dir_path)
 | 
			
		||||
        method = deserializer(net_gen,
 | 
			
		||||
                              known_modules=['networkx.generators',])
 | 
			
		||||
        G = method(**net_args)
 | 
			
		||||
 | 
			
		||||
    return G
 | 
			
		||||
 | 
			
		||||
    method = deserializer(net_gen,
 | 
			
		||||
                          known_modules=['networkx.generators',])
 | 
			
		||||
 | 
			
		||||
    return method(**net_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_file(infile):
 | 
			
		||||
@@ -66,11 +67,32 @@ def expand_template(config):
 | 
			
		||||
        raise ValueError(('You must provide a definition of variables'
 | 
			
		||||
                          ' for the template.'))
 | 
			
		||||
 | 
			
		||||
    template = Template(config['template'])
 | 
			
		||||
    template = config['template']
 | 
			
		||||
 | 
			
		||||
    sampler_name = config.get('sampler', 'SALib.sample.morris.sample')
 | 
			
		||||
    n_samples = int(config.get('samples', 100))
 | 
			
		||||
    sampler = deserializer(sampler_name)
 | 
			
		||||
    if not isinstance(template, str):
 | 
			
		||||
        template = yaml.dump(template)
 | 
			
		||||
 | 
			
		||||
    template = Template(template)
 | 
			
		||||
 | 
			
		||||
    params = params_for_template(config)
 | 
			
		||||
 | 
			
		||||
    blank_str = template.render({k: 0 for k in params[0].keys()})
 | 
			
		||||
    blank = list(load_string(blank_str))
 | 
			
		||||
    if len(blank) > 1:
 | 
			
		||||
        raise ValueError('Templates must not return more than one configuration')
 | 
			
		||||
    if 'name' in blank[0]:
 | 
			
		||||
        raise ValueError('Templates cannot be named, use group instead')
 | 
			
		||||
 | 
			
		||||
    for ps in params:
 | 
			
		||||
        string = template.render(ps)
 | 
			
		||||
        for c in load_string(string):
 | 
			
		||||
            yield c
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def params_for_template(config):
 | 
			
		||||
    sampler_config = config.get('sampler', {'N': 100})
 | 
			
		||||
    sampler = sampler_config.pop('method', 'SALib.sample.morris.sample')
 | 
			
		||||
    sampler = deserializer(sampler)
 | 
			
		||||
    bounds = config['vars']['bounds']
 | 
			
		||||
 | 
			
		||||
    problem = {
 | 
			
		||||
@@ -78,7 +100,7 @@ def expand_template(config):
 | 
			
		||||
        'names': list(bounds.keys()),
 | 
			
		||||
        'bounds': list(v for v in bounds.values())
 | 
			
		||||
    }
 | 
			
		||||
    samples = sampler(problem, n_samples)
 | 
			
		||||
    samples = sampler(problem, **sampler_config)
 | 
			
		||||
 | 
			
		||||
    lists = config['vars'].get('lists', {})
 | 
			
		||||
    names = list(lists.keys())
 | 
			
		||||
@@ -88,20 +110,7 @@ def expand_template(config):
 | 
			
		||||
    allnames = names + problem['names']
 | 
			
		||||
    allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
 | 
			
		||||
    params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    blank_str = template.render({k: 0 for k in allnames})
 | 
			
		||||
    blank = list(load_string(blank_str))
 | 
			
		||||
    if len(blank) > 1:
 | 
			
		||||
        raise ValueError('Templates must not return more than one configuration')
 | 
			
		||||
    if 'name' in blank[0]:
 | 
			
		||||
        raise ValueError('Templates cannot be named, use group instead')
 | 
			
		||||
 | 
			
		||||
    confs = []
 | 
			
		||||
    for ps in params:
 | 
			
		||||
        string = template.render(ps)
 | 
			
		||||
        for c in load_string(string):
 | 
			
		||||
            yield c
 | 
			
		||||
    return params
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_files(*patterns, **kwargs):
 | 
			
		||||
@@ -116,7 +125,7 @@ def load_files(*patterns, **kwargs):
 | 
			
		||||
 | 
			
		||||
def load_config(config):
 | 
			
		||||
    if isinstance(config, dict):
 | 
			
		||||
        yield config, None
 | 
			
		||||
        yield config, os.getcwd()
 | 
			
		||||
    else:
 | 
			
		||||
        yield from load_files(config)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import importlib
 | 
			
		||||
import sys
 | 
			
		||||
import yaml
 | 
			
		||||
import traceback
 | 
			
		||||
import logging
 | 
			
		||||
import networkx as nx
 | 
			
		||||
from networkx.readwrite import json_graph
 | 
			
		||||
from multiprocessing import Pool
 | 
			
		||||
@@ -11,17 +12,19 @@ from functools import partial
 | 
			
		||||
 | 
			
		||||
import pickle
 | 
			
		||||
 | 
			
		||||
from nxsim import NetworkSimulation
 | 
			
		||||
 | 
			
		||||
from . import serialization, utils, basestring, agents
 | 
			
		||||
from .environment import Environment
 | 
			
		||||
from .utils import logger
 | 
			
		||||
from .exporters import for_sim as exporters_for_sim
 | 
			
		||||
from .exporters import default, for_sim as exporters_for_sim
 | 
			
		||||
from .stats import defaultStats
 | 
			
		||||
from .history import History
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Simulation(NetworkSimulation):
 | 
			
		||||
#TODO: change documentation for simulation
 | 
			
		||||
 | 
			
		||||
class Simulation:
 | 
			
		||||
    """
 | 
			
		||||
    Subclass of nsim.NetworkSimulation with three main differences:
 | 
			
		||||
    Similar to nsim.NetworkSimulation with three main differences:
 | 
			
		||||
        1) agent type can be specified by name or by class.
 | 
			
		||||
        2) instead of just one type, a network agents distribution can be used.
 | 
			
		||||
           The distribution specifies the weight (or probability) of each
 | 
			
		||||
@@ -91,11 +94,12 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
                 environment_params=None, environment_class=None,
 | 
			
		||||
                 **kwargs):
 | 
			
		||||
 | 
			
		||||
        self.seed = str(seed) or str(time.time())
 | 
			
		||||
        self.load_module = load_module
 | 
			
		||||
        self.network_params = network_params
 | 
			
		||||
        self.name = name or 'Unnamed_' + time.strftime("%Y-%m-%d_%H.%M.%S")
 | 
			
		||||
        self.group = group or None
 | 
			
		||||
        self.name = name or 'Unnamed'
 | 
			
		||||
        self.seed = str(seed or name)
 | 
			
		||||
        self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S"))
 | 
			
		||||
        self.group = group or ''
 | 
			
		||||
        self.num_trials = num_trials
 | 
			
		||||
        self.max_time = max_time
 | 
			
		||||
        self.default_state = default_state or {}
 | 
			
		||||
@@ -128,12 +132,15 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
        self.states = agents._validate_states(states,
 | 
			
		||||
                                              self.topology)
 | 
			
		||||
 | 
			
		||||
        self._history = History(name=self.name,
 | 
			
		||||
                               backup=False)
 | 
			
		||||
 | 
			
		||||
    def run_simulation(self, *args, **kwargs):
 | 
			
		||||
        return self.run(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def run(self, *args, **kwargs):
 | 
			
		||||
        '''Run the simulation and return the list of resulting environments'''
 | 
			
		||||
        return list(self._run_simulation_gen(*args, **kwargs))
 | 
			
		||||
        return list(self.run_gen(*args, **kwargs))
 | 
			
		||||
 | 
			
		||||
    def _run_sync_or_async(self, parallel=False, *args, **kwargs):
 | 
			
		||||
        if parallel:
 | 
			
		||||
@@ -148,12 +155,16 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
                yield i
 | 
			
		||||
        else:
 | 
			
		||||
            for i in range(self.num_trials):
 | 
			
		||||
                yield self.run_trial(i,
 | 
			
		||||
                                     *args,
 | 
			
		||||
                yield self.run_trial(*args,
 | 
			
		||||
                                     **kwargs)
 | 
			
		||||
 | 
			
		||||
    def _run_simulation_gen(self, *args, parallel=False, dry_run=False,
 | 
			
		||||
                            exporters=['default', ], outdir=None, exporter_params={}, **kwargs):
 | 
			
		||||
    def run_gen(self, *args, parallel=False, dry_run=False,
 | 
			
		||||
                exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
 | 
			
		||||
                stats_params={}, log_level=None,
 | 
			
		||||
                **kwargs):
 | 
			
		||||
        '''Run the simulation and yield the resulting environments.'''
 | 
			
		||||
        if log_level:
 | 
			
		||||
            logger.setLevel(log_level)
 | 
			
		||||
        logger.info('Using exporters: %s', exporters or [])
 | 
			
		||||
        logger.info('Output directory: %s', outdir)
 | 
			
		||||
        exporters = exporters_for_sim(self,
 | 
			
		||||
@@ -161,31 +172,63 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
                                      dry_run=dry_run,
 | 
			
		||||
                                      outdir=outdir,
 | 
			
		||||
                                      **exporter_params)
 | 
			
		||||
        stats = exporters_for_sim(self,
 | 
			
		||||
                                  stats,
 | 
			
		||||
                                  **stats_params)
 | 
			
		||||
 | 
			
		||||
        with utils.timer('simulation {}'.format(self.name)):
 | 
			
		||||
            for stat in stats:
 | 
			
		||||
                stat.start()
 | 
			
		||||
 | 
			
		||||
            for exporter in exporters:
 | 
			
		||||
                exporter.start()
 | 
			
		||||
 | 
			
		||||
            for env in self._run_sync_or_async(*args, parallel=parallel,
 | 
			
		||||
            for env in self._run_sync_or_async(*args,
 | 
			
		||||
                                               parallel=parallel,
 | 
			
		||||
                                               log_level=log_level,
 | 
			
		||||
                                               **kwargs):
 | 
			
		||||
 | 
			
		||||
                collected = list(stat.trial(env) for stat in stats)
 | 
			
		||||
 | 
			
		||||
                saved = self.save_stats(collected, t_step=env.now, trial_id=env.name)
 | 
			
		||||
 | 
			
		||||
                for exporter in exporters:
 | 
			
		||||
                    exporter.trial_end(env)
 | 
			
		||||
                    exporter.trial(env, saved)
 | 
			
		||||
 | 
			
		||||
                yield env
 | 
			
		||||
 | 
			
		||||
            for exporter in exporters:
 | 
			
		||||
                exporter.end()
 | 
			
		||||
 | 
			
		||||
    def get_env(self, trial_id = 0, **kwargs):
 | 
			
		||||
            collected = list(stat.end() for stat in stats)
 | 
			
		||||
            saved = self.save_stats(collected)
 | 
			
		||||
 | 
			
		||||
            for exporter in exporters:
 | 
			
		||||
                exporter.end(saved)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def save_stats(self, collection, **kwargs):
 | 
			
		||||
        stats = dict(kwargs)
 | 
			
		||||
        for stat in collection:
 | 
			
		||||
            stats.update(stat)
 | 
			
		||||
        self._history.save_stats(utils.flatten_dict(stats))
 | 
			
		||||
        return stats
 | 
			
		||||
 | 
			
		||||
    def get_stats(self, **kwargs):
 | 
			
		||||
        return self._history.get_stats(**kwargs)
 | 
			
		||||
 | 
			
		||||
    def log_stats(self, stats):
 | 
			
		||||
        logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
    def get_env(self, trial_id=0, **kwargs):
 | 
			
		||||
        '''Create an environment for a trial of the simulation'''
 | 
			
		||||
        opts = self.environment_params.copy()
 | 
			
		||||
        env_name = '{}_trial_{}'.format(self.name, trial_id)
 | 
			
		||||
        opts.update({
 | 
			
		||||
            'name': env_name,
 | 
			
		||||
            'name': trial_id,
 | 
			
		||||
            'topology': self.topology.copy(),
 | 
			
		||||
            'seed': self.seed+env_name,
 | 
			
		||||
            'seed': '{}_trial_{}'.format(self.seed, trial_id),
 | 
			
		||||
            'initial_time': 0,
 | 
			
		||||
            'interval': self.interval,
 | 
			
		||||
            'network_agents': self.network_agents,
 | 
			
		||||
            'initial_time': 0,
 | 
			
		||||
            'states': self.states,
 | 
			
		||||
            'default_state': self.default_state,
 | 
			
		||||
            'environment_agents': self.environment_agents,
 | 
			
		||||
@@ -194,20 +237,22 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
        env = self.environment_class(**opts)
 | 
			
		||||
        return env
 | 
			
		||||
 | 
			
		||||
    def run_trial(self, trial_id=0, until=None, **opts):
 | 
			
		||||
        """Run a single trial of the simulation
 | 
			
		||||
 | 
			
		||||
        Parameters
 | 
			
		||||
        ----------
 | 
			
		||||
        trial_id : int
 | 
			
		||||
    def run_trial(self, until=None, log_level=logging.INFO, **opts):
 | 
			
		||||
        """
 | 
			
		||||
        Run a single trial of the simulation
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
        trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-')
 | 
			
		||||
        if log_level:
 | 
			
		||||
            logger.setLevel(log_level)
 | 
			
		||||
        # Set-up trial environment and graph
 | 
			
		||||
        until = until or self.max_time
 | 
			
		||||
        env = self.get_env(trial_id = trial_id, **opts)
 | 
			
		||||
        env = self.get_env(trial_id=trial_id, **opts)
 | 
			
		||||
        # Set up agents on nodes
 | 
			
		||||
        with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
 | 
			
		||||
            env.run(until)
 | 
			
		||||
        return env
 | 
			
		||||
 | 
			
		||||
    def run_trial_exceptions(self, *args, **kwargs):
 | 
			
		||||
        '''
 | 
			
		||||
        A wrapper for run_trial that catches exceptions and returns them.
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										106
									
								
								soil/stats.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								soil/stats.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,106 @@
 | 
			
		||||
import pandas as pd
 | 
			
		||||
 | 
			
		||||
from collections import Counter
 | 
			
		||||
 | 
			
		||||
class Stats:
 | 
			
		||||
    '''
 | 
			
		||||
    Interface for all stats. It is not necessary, but it is useful
 | 
			
		||||
    if you don't plan to implement all the methods.
 | 
			
		||||
    '''
 | 
			
		||||
 | 
			
		||||
    def __init__(self, simulation):
 | 
			
		||||
        self.simulation = simulation
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        '''Method to call when the simulation starts'''
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def end(self):
 | 
			
		||||
        '''Method to call when the simulation ends'''
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
    def trial(self, env):
 | 
			
		||||
        '''Method to call when a trial ends'''
 | 
			
		||||
        return {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class distribution(Stats):
 | 
			
		||||
    '''
 | 
			
		||||
    Calculate the distribution of agent states at the end of each trial,
 | 
			
		||||
    the mean value, and its deviation.
 | 
			
		||||
    '''
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        self.means = []
 | 
			
		||||
        self.counts = []
 | 
			
		||||
 | 
			
		||||
    def trial(self, env):
 | 
			
		||||
        df = env[None, None, None].df()
 | 
			
		||||
        df = df.drop('SEED', axis=1)
 | 
			
		||||
        ix = df.index[-1]
 | 
			
		||||
        attrs = df.columns.get_level_values(0)
 | 
			
		||||
        vc = {}
 | 
			
		||||
        stats = {
 | 
			
		||||
            'mean': {},
 | 
			
		||||
            'count': {},
 | 
			
		||||
        }
 | 
			
		||||
        for a in attrs:
 | 
			
		||||
            t = df.loc[(ix, a)]
 | 
			
		||||
            try:
 | 
			
		||||
                stats['mean'][a] = t.mean()
 | 
			
		||||
                self.means.append(('mean', a, t.mean()))
 | 
			
		||||
            except TypeError:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
            for name, count in t.value_counts().iteritems():
 | 
			
		||||
                if a not in stats['count']:
 | 
			
		||||
                    stats['count'][a] = {}
 | 
			
		||||
                stats['count'][a][name] = count
 | 
			
		||||
                self.counts.append(('count', a, name, count))
 | 
			
		||||
 | 
			
		||||
        return stats
 | 
			
		||||
 | 
			
		||||
    def end(self):
 | 
			
		||||
        dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
 | 
			
		||||
        dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
 | 
			
		||||
 | 
			
		||||
        count = {}
 | 
			
		||||
        mean = {}
 | 
			
		||||
 | 
			
		||||
        if self.means:
 | 
			
		||||
            res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
 | 
			
		||||
            mean = res['value'].to_dict()
 | 
			
		||||
        if self.counts:
 | 
			
		||||
            res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
 | 
			
		||||
            for k,v in res['count'].to_dict().items():
 | 
			
		||||
                if k not in count:
 | 
			
		||||
                    count[k] = {}
 | 
			
		||||
                for tup, times in v.items():
 | 
			
		||||
                    subkey, subcount = tup
 | 
			
		||||
                    if subkey not in count[k]:
 | 
			
		||||
                        count[k][subkey] = {}
 | 
			
		||||
                    count[k][subkey][subcount] = times
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        return {'count': count, 'mean': mean}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class defaultStats(Stats):
 | 
			
		||||
 | 
			
		||||
    def trial(self, env):
 | 
			
		||||
        c = Counter()
 | 
			
		||||
        c.update(a.__class__.__name__ for a in env.network_agents)
 | 
			
		||||
 | 
			
		||||
        c2 = Counter()
 | 
			
		||||
        c2.update(a['id'] for a in env.network_agents)
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            'network ': {
 | 
			
		||||
                'n_nodes': env.G.number_of_nodes(),
 | 
			
		||||
                'n_edges': env.G.number_of_nodes(),
 | 
			
		||||
            },
 | 
			
		||||
            'agents': {
 | 
			
		||||
                'model_count': dict(c),
 | 
			
		||||
                'state_count': dict(c2),
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
@@ -7,6 +7,7 @@ from shutil import copyfile
 | 
			
		||||
from contextlib import contextmanager
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger('soil')
 | 
			
		||||
logging.basicConfig()
 | 
			
		||||
logger.setLevel(logging.INFO)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -31,14 +32,13 @@ def safe_open(path, mode='r', backup=True, **kwargs):
 | 
			
		||||
        os.makedirs(outdir)
 | 
			
		||||
    if backup and 'w' in mode and os.path.exists(path):
 | 
			
		||||
        creation = os.path.getctime(path)
 | 
			
		||||
        stamp = time.strftime('%Y-%m-%d_%H.%M', time.localtime(creation))
 | 
			
		||||
        stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation))
 | 
			
		||||
 | 
			
		||||
        backup_dir = os.path.join(outdir, stamp)
 | 
			
		||||
        backup_dir = os.path.join(outdir, 'backup')
 | 
			
		||||
        if not os.path.exists(backup_dir):
 | 
			
		||||
            os.makedirs(backup_dir)
 | 
			
		||||
        newpath = os.path.join(backup_dir, os.path.basename(path))
 | 
			
		||||
        if os.path.exists(newpath):
 | 
			
		||||
            newpath = '{}@{}'.format(newpath, time.time())
 | 
			
		||||
        newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path),
 | 
			
		||||
                                                               stamp))
 | 
			
		||||
        copyfile(path, newpath)
 | 
			
		||||
    return open(path, mode=mode, **kwargs)
 | 
			
		||||
 | 
			
		||||
@@ -48,3 +48,40 @@ def open_or_reuse(f, *args, **kwargs):
 | 
			
		||||
        return safe_open(f, *args, **kwargs)
 | 
			
		||||
    except (AttributeError, TypeError):
 | 
			
		||||
        return f
 | 
			
		||||
 | 
			
		||||
def flatten_dict(d):
 | 
			
		||||
    if not isinstance(d, dict):
 | 
			
		||||
        return d
 | 
			
		||||
    return dict(_flatten_dict(d))
 | 
			
		||||
 | 
			
		||||
def _flatten_dict(d, prefix=''):
 | 
			
		||||
    if not isinstance(d, dict):
 | 
			
		||||
        # print('END:', prefix, d)
 | 
			
		||||
        yield prefix, d
 | 
			
		||||
        return
 | 
			
		||||
    if prefix:
 | 
			
		||||
        prefix = prefix + '.'
 | 
			
		||||
    for k, v in d.items():
 | 
			
		||||
        # print(k, v)
 | 
			
		||||
        res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k)))
 | 
			
		||||
        # print('RES:', res)
 | 
			
		||||
        yield from res
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def unflatten_dict(d):
 | 
			
		||||
    out = {}
 | 
			
		||||
    for k, v in d.items():
 | 
			
		||||
        target = out
 | 
			
		||||
        if not isinstance(k, str):
 | 
			
		||||
            target[k] = v
 | 
			
		||||
            continue
 | 
			
		||||
        tokens = k.split('.')
 | 
			
		||||
        if len(tokens) < 2:
 | 
			
		||||
            target[k] = v
 | 
			
		||||
            continue
 | 
			
		||||
        for token in tokens[:-1]:
 | 
			
		||||
            if token not in target:
 | 
			
		||||
                target[token] = {}
 | 
			
		||||
            target = target[token]
 | 
			
		||||
        target[tokens[-1]] = v
 | 
			
		||||
    return out
 | 
			
		||||
 
 | 
			
		||||
@@ -66,8 +66,8 @@ class TestAnalysis(TestCase):
 | 
			
		||||
        env = self.env
 | 
			
		||||
        df = analysis.read_sql(env._history.db_path)
 | 
			
		||||
        res = analysis.get_count(df, 'SEED', 'id')
 | 
			
		||||
        assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
 | 
			
		||||
        assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
 | 
			
		||||
        assert res['SEED'][self.env['SEED']].iloc[0] == 1
 | 
			
		||||
        assert res['SEED'][self.env['SEED']].iloc[-1] == 1
 | 
			
		||||
        assert res['id']['odd'].iloc[0] == 2
 | 
			
		||||
        assert res['id']['even'].iloc[0] == 0
 | 
			
		||||
        assert res['id']['odd'].iloc[-1] == 1
 | 
			
		||||
@@ -75,7 +75,7 @@ class TestAnalysis(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_value(self):
 | 
			
		||||
        env = self.env
 | 
			
		||||
        df = analysis.read_sql(env._history._db)
 | 
			
		||||
        df = analysis.read_sql(env._history.db_path)
 | 
			
		||||
        res_sum = analysis.get_value(df, 'count')
 | 
			
		||||
 | 
			
		||||
        assert res_sum['count'].iloc[0] == 2
 | 
			
		||||
@@ -86,4 +86,4 @@ class TestAnalysis(TestCase):
 | 
			
		||||
 | 
			
		||||
        res_total = analysis.get_value(df)
 | 
			
		||||
 | 
			
		||||
        res_total['SEED'].iloc[0] == 'seedanalysis_trial_0'
 | 
			
		||||
        res_total['SEED'].iloc[0] == self.env['SEED']
 | 
			
		||||
 
 | 
			
		||||
@@ -31,7 +31,7 @@ def make_example_test(path, config):
 | 
			
		||||
                try:
 | 
			
		||||
                    n = config['network_params']['n']
 | 
			
		||||
                    assert len(list(env.network_agents)) == n
 | 
			
		||||
                    assert env.now > 2  # It has run
 | 
			
		||||
                    assert env.now > 0  # It has run
 | 
			
		||||
                    assert env.now <= config['max_time']  # But not further than allowed
 | 
			
		||||
                except KeyError:
 | 
			
		||||
                    pass
 | 
			
		||||
 
 | 
			
		||||
@@ -6,26 +6,32 @@ from time import time
 | 
			
		||||
 | 
			
		||||
from unittest import TestCase
 | 
			
		||||
from soil import exporters
 | 
			
		||||
from soil.utils import safe_open
 | 
			
		||||
from soil import simulation
 | 
			
		||||
 | 
			
		||||
from soil.stats import distribution
 | 
			
		||||
 | 
			
		||||
class Dummy(exporters.Exporter):
 | 
			
		||||
    started = False
 | 
			
		||||
    trials = 0
 | 
			
		||||
    ended = False
 | 
			
		||||
    total_time = 0
 | 
			
		||||
    called_start = 0
 | 
			
		||||
    called_trial = 0
 | 
			
		||||
    called_end = 0
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        self.__class__.called_start += 1
 | 
			
		||||
        self.__class__.started = True
 | 
			
		||||
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
    def trial(self, env, stats):
 | 
			
		||||
        assert env
 | 
			
		||||
        self.__class__.trials += 1
 | 
			
		||||
        self.__class__.total_time += env.now
 | 
			
		||||
        self.__class__.called_trial += 1
 | 
			
		||||
 | 
			
		||||
    def end(self):
 | 
			
		||||
    def end(self, stats):
 | 
			
		||||
        self.__class__.ended = True
 | 
			
		||||
        self.__class__.called_end += 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Exporters(TestCase):
 | 
			
		||||
@@ -39,32 +45,17 @@ class Exporters(TestCase):
 | 
			
		||||
            'environment_params': {}
 | 
			
		||||
        }
 | 
			
		||||
        s = simulation.from_config(config)
 | 
			
		||||
        s.run_simulation(exporters=[Dummy], dry_run=True)
 | 
			
		||||
        for env in s.run_simulation(exporters=[Dummy], dry_run=True):
 | 
			
		||||
            assert env.now <= 2
 | 
			
		||||
 | 
			
		||||
        assert Dummy.started
 | 
			
		||||
        assert Dummy.ended
 | 
			
		||||
        assert Dummy.called_start == 1
 | 
			
		||||
        assert Dummy.called_end == 1
 | 
			
		||||
        assert Dummy.called_trial == 5
 | 
			
		||||
        assert Dummy.trials == 5
 | 
			
		||||
        assert Dummy.total_time == 2*5
 | 
			
		||||
 | 
			
		||||
    def test_distribution(self):
 | 
			
		||||
        '''The distribution exporter should write the number of agents in each state'''
 | 
			
		||||
        config = {
 | 
			
		||||
            'name': 'exporter_sim',
 | 
			
		||||
            'network_params': {
 | 
			
		||||
                'generator': 'complete_graph',
 | 
			
		||||
                'n': 4
 | 
			
		||||
            },
 | 
			
		||||
            'agent_type': 'CounterModel',
 | 
			
		||||
            'max_time': 2,
 | 
			
		||||
            'num_trials': 5,
 | 
			
		||||
            'environment_params': {}
 | 
			
		||||
        }
 | 
			
		||||
        output = io.StringIO()
 | 
			
		||||
        s = simulation.from_config(config)
 | 
			
		||||
        s.run_simulation(exporters=[exporters.distribution], dry_run=True, exporter_params={'copy_to': output})
 | 
			
		||||
        result = output.getvalue()
 | 
			
		||||
        assert 'count' in result
 | 
			
		||||
        assert 'SEED,Noneexporter_sim_trial_3,1,,1,1,1,1' in result
 | 
			
		||||
 | 
			
		||||
    def test_writing(self):
 | 
			
		||||
        '''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
 | 
			
		||||
        n_trials = 5
 | 
			
		||||
@@ -86,8 +77,8 @@ class Exporters(TestCase):
 | 
			
		||||
            exporters.default,
 | 
			
		||||
            exporters.csv,
 | 
			
		||||
            exporters.gexf,
 | 
			
		||||
            exporters.distribution,
 | 
			
		||||
        ],
 | 
			
		||||
                                stats=[distribution,],
 | 
			
		||||
                                outdir=tmpdir,
 | 
			
		||||
                                exporter_params={'copy_to': output})
 | 
			
		||||
        result = output.getvalue()
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import shutil
 | 
			
		||||
from glob import glob
 | 
			
		||||
 | 
			
		||||
from soil import history
 | 
			
		||||
from soil import utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
ROOT = os.path.abspath(os.path.dirname(__file__))
 | 
			
		||||
@@ -154,3 +155,49 @@ class TestHistory(TestCase):
 | 
			
		||||
        assert recovered
 | 
			
		||||
        for i in recovered:
 | 
			
		||||
            assert i in tuples
 | 
			
		||||
 | 
			
		||||
    def test_stats(self):
 | 
			
		||||
        """
 | 
			
		||||
        The data recovered should be equal to the one recorded.
 | 
			
		||||
        """
 | 
			
		||||
        tuples = (
 | 
			
		||||
            ('a_1', 0, 'id', 'v'),
 | 
			
		||||
            ('a_1', 1, 'id', 'a'),
 | 
			
		||||
            ('a_1', 2, 'id', 'l'),
 | 
			
		||||
            ('a_1', 3, 'id', 'u'),
 | 
			
		||||
            ('a_1', 4, 'id', 'e'),
 | 
			
		||||
            ('env', 1, 'prob', 1),
 | 
			
		||||
            ('env', 2, 'prob', 2),
 | 
			
		||||
            ('env', 3, 'prob', 3),
 | 
			
		||||
            ('a_2', 7, 'finished', True),
 | 
			
		||||
        )
 | 
			
		||||
        stat_tuples = [
 | 
			
		||||
            {'num_infected': 5, 'runtime': 0.2},
 | 
			
		||||
            {'num_infected': 5, 'runtime': 0.2},
 | 
			
		||||
            {'new': '40'},
 | 
			
		||||
        ]
 | 
			
		||||
        h = history.History()
 | 
			
		||||
        h.save_tuples(tuples)
 | 
			
		||||
        for stat in stat_tuples:
 | 
			
		||||
            h.save_stats(stat)
 | 
			
		||||
        recovered = h.get_stats()
 | 
			
		||||
        assert recovered
 | 
			
		||||
        assert recovered[0]['num_infected'] == 5
 | 
			
		||||
        assert recovered[1]['runtime'] == 0.2
 | 
			
		||||
        assert recovered[2]['new'] == '40'
 | 
			
		||||
 | 
			
		||||
    def test_unflatten(self):
 | 
			
		||||
        ex = {'count.neighbors.3': 4,
 | 
			
		||||
              'count.times.2': 4,
 | 
			
		||||
              'count.total.4': 4,
 | 
			
		||||
              'mean.neighbors': 3,
 | 
			
		||||
              'mean.times': 2,
 | 
			
		||||
              'mean.total': 4,
 | 
			
		||||
              't_step': 2,
 | 
			
		||||
              'trial_id': 'exporter_sim_trial_1605817956-4475424'}
 | 
			
		||||
        res = utils.unflatten_dict(ex)
 | 
			
		||||
 | 
			
		||||
        assert 'count' in res
 | 
			
		||||
        assert 'mean' in res
 | 
			
		||||
        assert 't_step' in res
 | 
			
		||||
        assert 'trial_id' in res
 | 
			
		||||
 
 | 
			
		||||
@@ -343,4 +343,16 @@ class TestMain(TestCase):
 | 
			
		||||
        configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
 | 
			
		||||
        assert len(configs) > 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def test_until(self):
 | 
			
		||||
        config = {
 | 
			
		||||
            'name': 'exporter_sim',
 | 
			
		||||
            'network_params': {},
 | 
			
		||||
            'agent_type': 'CounterModel',
 | 
			
		||||
            'max_time': 2,
 | 
			
		||||
            'num_trials': 100,
 | 
			
		||||
            'environment_params': {}
 | 
			
		||||
        }
 | 
			
		||||
        s = simulation.from_config(config)
 | 
			
		||||
        runs = list(s.run_simulation(dry_run=True))
 | 
			
		||||
        over = list(x.now for x in runs if x.now>2)
 | 
			
		||||
        assert len(over) == 0
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										34
									
								
								tests/test_stats.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								tests/test_stats.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,34 @@
 | 
			
		||||
from unittest import TestCase
 | 
			
		||||
 | 
			
		||||
from soil import simulation, stats
 | 
			
		||||
from soil.utils import unflatten_dict
 | 
			
		||||
 | 
			
		||||
class Stats(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_distribution(self):
 | 
			
		||||
        '''The distribution exporter should write the number of agents in each state'''
 | 
			
		||||
        config = {
 | 
			
		||||
            'name': 'exporter_sim',
 | 
			
		||||
            'network_params': {
 | 
			
		||||
                'generator': 'complete_graph',
 | 
			
		||||
                'n': 4
 | 
			
		||||
            },
 | 
			
		||||
            'agent_type': 'CounterModel',
 | 
			
		||||
            'max_time': 2,
 | 
			
		||||
            'num_trials': 5,
 | 
			
		||||
            'environment_params': {}
 | 
			
		||||
        }
 | 
			
		||||
        s = simulation.from_config(config)
 | 
			
		||||
        for env in s.run_simulation(stats=[stats.distribution]):
 | 
			
		||||
            pass
 | 
			
		||||
            # stats_res = unflatten_dict(dict(env._history['stats', -1, None]))
 | 
			
		||||
        allstats = s.get_stats()
 | 
			
		||||
        for stat in allstats:
 | 
			
		||||
            assert 'count' in stat
 | 
			
		||||
            assert 'mean' in stat
 | 
			
		||||
            if 'trial_id' in stat:
 | 
			
		||||
                assert stat['mean']['neighbors'] == 3
 | 
			
		||||
                assert stat['count']['total']['4'] == 4
 | 
			
		||||
            else:
 | 
			
		||||
                assert stat['count']['count']['neighbors']['3'] == 20
 | 
			
		||||
                assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors']
 | 
			
		||||
		Reference in New Issue
	
	Block a user