mirror of
				https://github.com/gsi-upm/soil
				synced 2025-10-23 03:38:24 +00:00 
			
		
		
		
	Compare commits
	
		
			1 Commits
		
	
	
		
			3b2c6a3db5
			...
			0.15.1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 05f7f49233 | 
							
								
								
									
										23
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								CHANGELOG.md
									
									
									
									
									
								
							| @@ -3,6 +3,29 @@ All notable changes to this project will be documented in this file. | |||||||
|  |  | ||||||
| The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | ||||||
|  |  | ||||||
|  | ## [0.15.1] | ||||||
|  | ### Added | ||||||
|  | * read-only `History` | ||||||
|  | ### Fixed | ||||||
|  | * Serialization problem with the `Environment` on parallel mode. | ||||||
|  | * Analysis functions now work as they should in the tutorial | ||||||
|  | ## [0.15.0] | ||||||
|  | ### Added | ||||||
|  | * Control logging level in CLI and simulation | ||||||
|  | * `Stats` to calculate trial and simulation-wide statistics | ||||||
|  | * Simulation statistics are stored in a separate table in history (see `History.get_stats` and `History.save_stats`, as well as `soil.stats`) | ||||||
|  | * Aliased `NetworkAgent.G` to `NetworkAgent.topology`. | ||||||
|  | ### Changed | ||||||
|  | * Templates in config files can be given as dictionaries in addition to strings | ||||||
|  | * Samplers are used more explicitly | ||||||
|  | * Removed nxsim dependency. We had already made a lot of changes, and nxsim has not been updated in 5 years. | ||||||
|  | * Exporter methods renamed to `trial` and `end`. Added `start`. | ||||||
|  | * `Distribution` exporter now a stats class | ||||||
|  | * `global_topology` renamed to `topology` | ||||||
|  | * Moved topology-related methods to `NetworkAgent` | ||||||
|  | ### Fixed | ||||||
|  | * Temporary files used for history in dry_run mode are not longer left open  | ||||||
|  |  | ||||||
| ## [0.14.9] | ## [0.14.9] | ||||||
| ### Changed | ### Changed | ||||||
| * Seed random before environment initialization | * Seed random before environment initialization | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ | |||||||
| # Add any Sphinx extension module names here, as strings. They can be | # Add any Sphinx extension module names here, as strings. They can be | ||||||
| # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | ||||||
| # ones. | # ones. | ||||||
| extensions = [] | extensions = ['IPython.sphinxext.ipython_console_highlighting'] | ||||||
|  |  | ||||||
| # Add any paths that contain templates here, relative to this directory. | # Add any paths that contain templates here, relative to this directory. | ||||||
| templates_path = ['_templates'] | templates_path = ['_templates'] | ||||||
| @@ -69,7 +69,7 @@ language = None | |||||||
| # List of patterns, relative to source directory, that match files and | # List of patterns, relative to source directory, that match files and | ||||||
| # directories to ignore when looking for source files. | # directories to ignore when looking for source files. | ||||||
| # This patterns also effect to html_static_path and html_extra_path | # This patterns also effect to html_static_path and html_extra_path | ||||||
| exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints'] | ||||||
|  |  | ||||||
| # The name of the Pygments (syntax highlighting) style to use. | # The name of the Pygments (syntax highlighting) style to use. | ||||||
| pygments_style = 'sphinx' | pygments_style = 'sphinx' | ||||||
|   | |||||||
| @@ -218,3 +218,24 @@ These agents are programmed in much the same way as network agents, the only dif | |||||||
|  |  | ||||||
| You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance. | You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance. | ||||||
| They are also useful to add behavior that has little to do with the network and the interactions within that network. | They are also useful to add behavior that has little to do with the network and the interactions within that network. | ||||||
|  |  | ||||||
|  | Templating | ||||||
|  | ========== | ||||||
|  |  | ||||||
|  | Sometimes, it is useful to parameterize a simulation and run it over a range of values in order to compare each run and measure the effect of those parameters in the simulation. | ||||||
|  | For instance, you may want to run a simulation with different agent distributions. | ||||||
|  |  | ||||||
|  | This can be done in Soil using **templates**. | ||||||
|  | A template is a configuration where some of the values are specified with a variable. | ||||||
|  | e.g.,  ``weight: "{{ var1 }}"`` instead of ``weight: 1``. | ||||||
|  | There are two types of variables, depending on how their values are decided: | ||||||
|  |  | ||||||
|  | * Fixed. A list of values is provided, and a new simulation is run for each possible value. If more than a variable is given, a new simulation will be run per combination of values. | ||||||
|  | * Bounded/Sampled. The bounds of the variable are provided, along with a sampler method, which will be used to compute all the configuration combinations. | ||||||
|  |  | ||||||
|  | When fixed and bounded variables are mixed, Soil generates a new configuration per combination of fixed values and bounded values. | ||||||
|  |  | ||||||
|  | Here is an example with a single fixed variable and two bounded variable: | ||||||
|  |  | ||||||
|  | .. literalinclude:: ../examples/template.yml | ||||||
|  |    :language: yaml | ||||||
|   | |||||||
| @@ -500,7 +500,7 @@ | |||||||
|    "name": "python", |    "name": "python", | ||||||
|    "nbconvert_exporter": "python", |    "nbconvert_exporter": "python", | ||||||
|    "pygments_lexer": "ipython3", |    "pygments_lexer": "ipython3", | ||||||
|    "version": "3.6.5" |    "version": "3.8.5" | ||||||
|   }, |   }, | ||||||
|   "toc": { |   "toc": { | ||||||
|    "colors": { |    "colors": { | ||||||
|   | |||||||
| @@ -80800,7 +80800,7 @@ | |||||||
|    "name": "python", |    "name": "python", | ||||||
|    "nbconvert_exporter": "python", |    "nbconvert_exporter": "python", | ||||||
|    "pygments_lexer": "ipython3", |    "pygments_lexer": "ipython3", | ||||||
|    "version": "3.6.5" |    "version": "3.8.6" | ||||||
|   } |   } | ||||||
|  }, |  }, | ||||||
|  "nbformat": 4, |  "nbformat": 4, | ||||||
|   | |||||||
| @@ -1,4 +1,4 @@ | |||||||
| from soil.agents import FSM, state, default_state, BaseAgent | from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from random import random, choice | from random import random, choice | ||||||
| from itertools import islice | from itertools import islice | ||||||
| @@ -80,7 +80,7 @@ class RabbitModel(FSM): | |||||||
|                 self.env.add_edge(self['mate'], child.id) |                 self.env.add_edge(self['mate'], child.id) | ||||||
|                 # self.add_edge() |                 # self.add_edge() | ||||||
|                 self.debug('A BABY IS COMING TO LIFE') |                 self.debug('A BABY IS COMING TO LIFE') | ||||||
|                 self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.global_topology.number_of_nodes())+1 |                 self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1 | ||||||
|                 self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive'])) |                 self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive'])) | ||||||
|                 self['offspring'] += 1 |                 self['offspring'] += 1 | ||||||
|                 self.env.get_agent(self['mate'])['offspring'] += 1 |                 self.env.get_agent(self['mate'])['offspring'] += 1 | ||||||
| @@ -97,12 +97,14 @@ class RabbitModel(FSM): | |||||||
|         return |         return | ||||||
|  |  | ||||||
|  |  | ||||||
| class RandomAccident(BaseAgent): | class RandomAccident(NetworkAgent): | ||||||
|  |  | ||||||
|     level = logging.DEBUG |     level = logging.DEBUG | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         rabbits_total = self.global_topology.number_of_nodes() |         rabbits_total = self.topology.number_of_nodes() | ||||||
|  |         if 'rabbits_alive' not in self.env: | ||||||
|  |             self.env['rabbits_alive'] = 0 | ||||||
|         rabbits_alive = self.env.get('rabbits_alive', rabbits_total) |         rabbits_alive = self.env.get('rabbits_alive', rabbits_total) | ||||||
|         prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) |         prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) | ||||||
|         self.debug('Killing some rabbits with prob={}!'.format(prob_death)) |         self.debug('Killing some rabbits with prob={}!'.format(prob_death)) | ||||||
| @@ -116,5 +118,5 @@ class RandomAccident(BaseAgent): | |||||||
|                 self.log('Rabbits alive: {}'.format(self.env['rabbits_alive'])) |                 self.log('Rabbits alive: {}'.format(self.env['rabbits_alive'])) | ||||||
|                 i.set_state(i.dead) |                 i.set_state(i.dead) | ||||||
|         self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) |         self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) | ||||||
|         if self.count_agents(state_id=RabbitModel.dead.id) == self.global_topology.number_of_nodes(): |         if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes(): | ||||||
|             self.die() |             self.die() | ||||||
|   | |||||||
| @@ -1,13 +1,8 @@ | |||||||
| --- | --- | ||||||
| vars: | sampler: | ||||||
|   bounds: |   method: "SALib.sample.morris.sample" | ||||||
|     x1: [0, 1] |   N: 10 | ||||||
|     x2: [1, 2] | template: | ||||||
|   fixed: |  | ||||||
|     x3: ["a", "b", "c"] |  | ||||||
| sampler: "SALib.sample.morris.sample" |  | ||||||
| samples: 10 |  | ||||||
| template: | |  | ||||||
|   group: simple |   group: simple | ||||||
|   num_trials: 1 |   num_trials: 1 | ||||||
|   interval: 1 |   interval: 1 | ||||||
| @@ -19,11 +14,17 @@ template: | | |||||||
|     n: 10 |     n: 10 | ||||||
|   network_agents: |   network_agents: | ||||||
|     - agent_type: CounterModel |     - agent_type: CounterModel | ||||||
|       weight: {{ x1 }} |       weight: "{{ x1 }}" | ||||||
|       state: |       state: | ||||||
|         id: 0 |         id: 0 | ||||||
|     - agent_type: AggregatedCounter |     - agent_type: AggregatedCounter | ||||||
|       weight: {{ 1 - x1 }} |       weight: "{{ 1 - x1 }}" | ||||||
|   environment_params: |   environment_params: | ||||||
|     name: {{ x3 }} |     name: "{{ x3 }}" | ||||||
|   skip_test: true |   skip_test: true | ||||||
|  | vars: | ||||||
|  |   bounds: | ||||||
|  |     x1: [0, 1] | ||||||
|  |     x2: [1, 2] | ||||||
|  |   fixed: | ||||||
|  |     x3: ["a", "b", "c"] | ||||||
|   | |||||||
| @@ -195,14 +195,14 @@ class TerroristNetworkModel(TerroristSpreadModel): | |||||||
|                     break |                     break | ||||||
|  |  | ||||||
|     def get_distance(self, target): |     def get_distance(self, target): | ||||||
|         source_x, source_y = nx.get_node_attributes(self.global_topology, 'pos')[self.id] |         source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id] | ||||||
|         target_x, target_y = nx.get_node_attributes(self.global_topology, 'pos')[target] |         target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target] | ||||||
|         dx = abs( source_x - target_x ) |         dx = abs( source_x - target_x ) | ||||||
|         dy = abs( source_y - target_y ) |         dy = abs( source_y - target_y ) | ||||||
|         return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 ) |         return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 ) | ||||||
|  |  | ||||||
|     def shortest_path_length(self, target): |     def shortest_path_length(self, target): | ||||||
|         try: |         try: | ||||||
|             return nx.shortest_path_length(self.global_topology, self.id, target) |             return nx.shortest_path_length(self.topology, self.id, target) | ||||||
|         except nx.NetworkXNoPath: |         except nx.NetworkXNoPath: | ||||||
|             return float('inf') |             return float('inf') | ||||||
|   | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -1,6 +1,5 @@ | |||||||
| nxsim>=0.1.2 | simpy>=4.0 | ||||||
| simpy | networkx>=2.5 | ||||||
| networkx>=2.0,<2.4 |  | ||||||
| numpy | numpy | ||||||
| matplotlib | matplotlib | ||||||
| pyyaml>=5.1 | pyyaml>=5.1 | ||||||
|   | |||||||
| @@ -1 +1 @@ | |||||||
| 0.14.9 | 0.15.1 | ||||||
| @@ -17,12 +17,12 @@ from .environment import Environment | |||||||
| from .history import History | from .history import History | ||||||
| from . import serialization | from . import serialization | ||||||
| from . import analysis | from . import analysis | ||||||
|  | from .utils import logger | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|     import argparse |     import argparse | ||||||
|     from . import simulation |     from . import simulation | ||||||
|  |  | ||||||
|     logging.basicConfig(level=logging.INFO) |  | ||||||
|     logging.info('Running SOIL version: {}'.format(__version__)) |     logging.info('Running SOIL version: {}'.format(__version__)) | ||||||
|  |  | ||||||
|     parser = argparse.ArgumentParser(description='Run a SOIL simulation') |     parser = argparse.ArgumentParser(description='Run a SOIL simulation') | ||||||
| @@ -40,6 +40,8 @@ def main(): | |||||||
|                         help='Dump GEXF graph. Defaults to false.') |                         help='Dump GEXF graph. Defaults to false.') | ||||||
|     parser.add_argument('--csv', action='store_true', |     parser.add_argument('--csv', action='store_true', | ||||||
|                         help='Dump history in CSV format. Defaults to false.') |                         help='Dump history in CSV format. Defaults to false.') | ||||||
|  |     parser.add_argument('--level', type=str, | ||||||
|  |                         help='Logging level') | ||||||
|     parser.add_argument('--output', '-o', type=str, default="soil_output", |     parser.add_argument('--output', '-o', type=str, default="soil_output", | ||||||
|                         help='folder to write results to. It defaults to the current directory.') |                         help='folder to write results to. It defaults to the current directory.') | ||||||
|     parser.add_argument('--synchronous', action='store_true', |     parser.add_argument('--synchronous', action='store_true', | ||||||
| @@ -48,6 +50,7 @@ def main(): | |||||||
|                         help='Export environment and/or simulations using this exporter') |                         help='Export environment and/or simulations using this exporter') | ||||||
|  |  | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |     logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper())) | ||||||
|  |  | ||||||
|     if os.getcwd() not in sys.path: |     if os.getcwd() not in sys.path: | ||||||
|         sys.path.append(os.getcwd()) |         sys.path.append(os.getcwd()) | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ class BassModel(BaseAgent): | |||||||
|         imitation_prob |         imitation_prob | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment, agent_id, state): |     def __init__(self, environment, agent_id, state, **kwargs): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(environment=environment, agent_id=agent_id, state=state) | ||||||
|         env_params = environment.environment_params |         env_params = environment.environment_params | ||||||
|         self.state['sentimentCorrelation'] = 0 |         self.state['sentimentCorrelation'] = 0 | ||||||
| @@ -19,7 +19,7 @@ class BassModel(BaseAgent): | |||||||
|  |  | ||||||
|     def behaviour(self): |     def behaviour(self): | ||||||
|         # Outside effects |         # Outside effects | ||||||
|         if random.random() < self.state_params['innovation_prob']: |         if random.random() < self['innovation_prob']: | ||||||
|             if self.state['id'] == 0: |             if self.state['id'] == 0: | ||||||
|                 self.state['id'] = 1 |                 self.state['id'] = 1 | ||||||
|                 self.state['sentimentCorrelation'] = 1 |                 self.state['sentimentCorrelation'] = 1 | ||||||
| @@ -32,7 +32,7 @@ class BassModel(BaseAgent): | |||||||
|         if self.state['id'] == 0: |         if self.state['id'] == 0: | ||||||
|             aware_neighbors = self.get_neighboring_agents(state_id=1) |             aware_neighbors = self.get_neighboring_agents(state_id=1) | ||||||
|             num_neighbors_aware = len(aware_neighbors) |             num_neighbors_aware = len(aware_neighbors) | ||||||
|             if random.random() < (self.state_params['imitation_prob']*num_neighbors_aware): |             if random.random() < (self['imitation_prob']*num_neighbors_aware): | ||||||
|                 self.state['id'] = 1 |                 self.state['id'] = 1 | ||||||
|                 self.state['sentimentCorrelation'] = 1 |                 self.state['sentimentCorrelation'] = 1 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| from . import BaseAgent | from . import NetworkAgent | ||||||
|  |  | ||||||
|  |  | ||||||
| class CounterModel(BaseAgent): | class CounterModel(NetworkAgent): | ||||||
|     """ |     """ | ||||||
|     Dummy behaviour. It counts the number of nodes in the network and neighbors |     Dummy behaviour. It counts the number of nodes in the network and neighbors | ||||||
|     in each step and adds it to its state. |     in each step and adds it to its state. | ||||||
| @@ -9,14 +9,14 @@ class CounterModel(BaseAgent): | |||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         # Outside effects |         # Outside effects | ||||||
|         total = len(list(self.get_all_agents())) |         total = len(list(self.get_agents())) | ||||||
|         neighbors = len(list(self.get_neighboring_agents())) |         neighbors = len(list(self.get_neighboring_agents())) | ||||||
|         self['times'] = self.get('times', 0) + 1 |         self['times'] = self.get('times', 0) + 1 | ||||||
|         self['neighbors'] = neighbors |         self['neighbors'] = neighbors | ||||||
|         self['total'] = total |         self['total'] = total | ||||||
|  |  | ||||||
|  |  | ||||||
| class AggregatedCounter(BaseAgent): | class AggregatedCounter(NetworkAgent): | ||||||
|     """ |     """ | ||||||
|     Dummy behaviour. It counts the number of nodes in the network and neighbors |     Dummy behaviour. It counts the number of nodes in the network and neighbors | ||||||
|     in each step and adds it to its state. |     in each step and adds it to its state. | ||||||
| @@ -33,6 +33,6 @@ class AggregatedCounter(BaseAgent): | |||||||
|         self['times'] += 1 |         self['times'] += 1 | ||||||
|         neighbors = len(list(self.get_neighboring_agents())) |         neighbors = len(list(self.get_neighboring_agents())) | ||||||
|         self['neighbors'] += neighbors |         self['neighbors'] += neighbors | ||||||
|         total = len(list(self.get_all_agents())) |         total = len(list(self.get_agents())) | ||||||
|         self['total'] += total |         self['total'] += total | ||||||
|         self.debug('Running for step: {}. Total: {}'.format(self.now, total)) |         self.debug('Running for step: {}. Total: {}'.format(self.now, total)) | ||||||
|   | |||||||
| @@ -5,17 +5,17 @@ | |||||||
| # Initialize agent states. Let's assume everyone is normal. | # Initialize agent states. Let's assume everyone is normal. | ||||||
|  |  | ||||||
|  |  | ||||||
| import nxsim |  | ||||||
| import logging | import logging | ||||||
| from collections import OrderedDict | from collections import OrderedDict | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from functools import partial | from functools import partial | ||||||
| from scipy.spatial import cKDTree as KDTree | from scipy.spatial import cKDTree as KDTree | ||||||
| import json | import json | ||||||
|  | import simpy | ||||||
|  |  | ||||||
| from functools import wraps | from functools import wraps | ||||||
|  |  | ||||||
| from .. import serialization, history | from .. import serialization, history, utils | ||||||
|  |  | ||||||
|  |  | ||||||
| def as_node(agent): | def as_node(agent): | ||||||
| @@ -24,7 +24,7 @@ def as_node(agent): | |||||||
|     return agent |     return agent | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseAgent(nxsim.BaseAgent): | class BaseAgent: | ||||||
|     """ |     """ | ||||||
|     A special simpy BaseAgent that keeps track of its state history. |     A special simpy BaseAgent that keeps track of its state history. | ||||||
|     """ |     """ | ||||||
| @@ -32,14 +32,13 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|     defaults = {} |     defaults = {} | ||||||
|  |  | ||||||
|     def __init__(self, environment, agent_id, state=None, |     def __init__(self, environment, agent_id, state=None, | ||||||
|                  name=None, interval=None, **state_params): |                  name=None, interval=None): | ||||||
|         # Check for REQUIRED arguments |         # Check for REQUIRED arguments | ||||||
|         assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. ' |         assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. ' | ||||||
|                                                   'Cannot be NoneType.') |                                                   'Cannot be NoneType.') | ||||||
|         # Initialize agent parameters |         # Initialize agent parameters | ||||||
|         self.id = agent_id |         self.id = agent_id | ||||||
|         self.name = name or '{}[{}]'.format(type(self).__name__, self.id) |         self.name = name or '{}[{}]'.format(type(self).__name__, self.id) | ||||||
|         self.state_params = state_params |  | ||||||
|  |  | ||||||
|         # Register agent to environment |         # Register agent to environment | ||||||
|         self.env = environment |         self.env = environment | ||||||
| @@ -51,9 +50,9 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|         self.state = real_state |         self.state = real_state | ||||||
|         self.interval = interval |         self.interval = interval | ||||||
|  |  | ||||||
|         if not hasattr(self, 'level'): |         self.logger = logging.getLogger(self.env.name).getChild(self.name) | ||||||
|             self.level = logging.DEBUG |  | ||||||
|         self.logger = logging.getLogger(self.env.name) |         if hasattr(self, 'level'): | ||||||
|             self.logger.setLevel(self.level) |             self.logger.setLevel(self.level) | ||||||
|  |  | ||||||
|         # initialize every time an instance of the agent is created |         # initialize every time an instance of the agent is created | ||||||
| @@ -75,10 +74,6 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|         for k, v in value.items(): |         for k, v in value.items(): | ||||||
|             self[k] = v |             self[k] = v | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def global_topology(self): |  | ||||||
|         return self.env.G |  | ||||||
|      |  | ||||||
|     @property |     @property | ||||||
|     def environment_params(self): |     def environment_params(self): | ||||||
|         return self.env.environment_params |         return self.env.environment_params | ||||||
| @@ -135,36 +130,10 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|     def die(self, remove=False): |     def die(self, remove=False): | ||||||
|         self.alive = False |         self.alive = False | ||||||
|         if remove: |         if remove: | ||||||
|             super().die() |             self.remove_node(self.id) | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         pass |         return | ||||||
|  |  | ||||||
|     def count_agents(self, **kwargs): |  | ||||||
|         return len(list(self.get_agents(**kwargs))) |  | ||||||
|  |  | ||||||
|     def count_neighboring_agents(self, state_id=None, **kwargs): |  | ||||||
|         return len(super().get_neighboring_agents(state_id=state_id, **kwargs)) |  | ||||||
|  |  | ||||||
|     def get_neighboring_agents(self, state_id=None, **kwargs): |  | ||||||
|         return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs) |  | ||||||
|  |  | ||||||
|     def get_agents(self, agents=None, limit_neighbors=False, **kwargs): |  | ||||||
|         if limit_neighbors: |  | ||||||
|             agents = super().get_agents(limit_neighbors=limit_neighbors) |  | ||||||
|         else: |  | ||||||
|             agents = self.env.get_agents(agents) |  | ||||||
|         return select(agents, **kwargs) |  | ||||||
|  |  | ||||||
|     def log(self, message, *args, level=logging.INFO, **kwargs): |  | ||||||
|         message = message + " ".join(str(i) for i in args) |  | ||||||
|         message = "\t{:10}@{:>5}:\t{}".format(self.name, self.now, message) |  | ||||||
|         for k, v in kwargs: |  | ||||||
|             message += " {k}={v} ".format(k, v) |  | ||||||
|         extra = {} |  | ||||||
|         extra['now'] = self.now |  | ||||||
|         extra['id'] = self.id |  | ||||||
|         return self.logger.log(level, message, extra=extra) |  | ||||||
|  |  | ||||||
|     def debug(self, *args, **kwargs): |     def debug(self, *args, **kwargs): | ||||||
|         return self.log(*args, level=logging.DEBUG, **kwargs) |         return self.log(*args, level=logging.DEBUG, **kwargs) | ||||||
| @@ -192,24 +161,59 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|         self._state = state['_state'] |         self._state = state['_state'] | ||||||
|         self.env = state['environment'] |         self.env = state['environment'] | ||||||
|  |  | ||||||
|     def add_edge(self, node1, node2, **attrs): | class NetworkAgent(BaseAgent): | ||||||
|         node1 = as_node(node1) |  | ||||||
|         node2 = as_node(node2) |  | ||||||
|  |  | ||||||
|         for n in [node1, node2]: |     @property | ||||||
|             if n not in self.global_topology.nodes(data=False): |     def topology(self): | ||||||
|                 raise ValueError('"{}" not in the graph'.format(n)) |         return self.env.G | ||||||
|         return self.global_topology.add_edge(node1, node2, **attrs) |  | ||||||
|  |     @property | ||||||
|  |     def G(self): | ||||||
|  |         return self.env.G | ||||||
|  |  | ||||||
|  |     def count_agents(self, **kwargs): | ||||||
|  |         return len(list(self.get_agents(**kwargs))) | ||||||
|  |  | ||||||
|  |     def count_neighboring_agents(self, state_id=None, **kwargs): | ||||||
|  |         return len(self.get_neighboring_agents(state_id=state_id, **kwargs)) | ||||||
|  |  | ||||||
|  |     def get_neighboring_agents(self, state_id=None, **kwargs): | ||||||
|  |         return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs) | ||||||
|  |  | ||||||
|  |     def get_agents(self, agents=None, limit_neighbors=False, **kwargs): | ||||||
|  |         if limit_neighbors: | ||||||
|  |             agents = self.topology.neighbors(self.id) | ||||||
|  |  | ||||||
|  |         agents = self.env.get_agents(agents) | ||||||
|  |         return select(agents, **kwargs) | ||||||
|  |  | ||||||
|  |     def log(self, message, *args, level=logging.INFO, **kwargs): | ||||||
|  |         message = message + " ".join(str(i) for i in args) | ||||||
|  |         message = " @{:>3}: {}".format(self.now, message) | ||||||
|  |         for k, v in kwargs: | ||||||
|  |             message += " {k}={v} ".format(k, v) | ||||||
|  |         extra = {} | ||||||
|  |         extra['now'] = self.now | ||||||
|  |         extra['agent_id'] = self.id | ||||||
|  |         extra['agent_name'] = self.name | ||||||
|  |         return self.logger.log(level, message, extra=extra) | ||||||
|  |  | ||||||
|     def subgraph(self, center=True, **kwargs): |     def subgraph(self, center=True, **kwargs): | ||||||
|         include = [self] if center else [] |         include = [self] if center else [] | ||||||
|         return self.global_topology.subgraph(n.id for n in self.get_agents(**kwargs)+include) |         return self.topology.subgraph(n.id for n in self.get_agents(**kwargs)+include) | ||||||
|  |  | ||||||
|  |     def remove_node(self, agent_id): | ||||||
|  |         self.topology.remove_node(agent_id) | ||||||
|  |  | ||||||
| class NetworkAgent(BaseAgent): |     def add_edge(self, other, edge_attr_dict=None, *edge_attrs): | ||||||
|  |         # return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs) | ||||||
|  |         if self.id not in self.topology.nodes(data=False): | ||||||
|  |             raise ValueError('{} not in list of existing agents in the network'.format(self.id)) | ||||||
|  |         if other not in self.topology.nodes(data=False): | ||||||
|  |             raise ValueError('{} not in list of existing agents in the network'.format(other)) | ||||||
|  |  | ||||||
|  |         self.topology.add_edge(self.id, other, edge_attr_dict=edge_attr_dict, *edge_attrs) | ||||||
|  |  | ||||||
|     def add_edge(self, other, **kwargs): |  | ||||||
|         return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs) |  | ||||||
|  |  | ||||||
|     def ego_search(self, steps=1, center=False, node=None, **kwargs): |     def ego_search(self, steps=1, center=False, node=None, **kwargs): | ||||||
|         '''Get a list of nodes in the ego network of *node* of radius *steps*''' |         '''Get a list of nodes in the ego network of *node* of radius *steps*''' | ||||||
| @@ -220,14 +224,14 @@ class NetworkAgent(BaseAgent): | |||||||
|     def degree(self, node, force=False): |     def degree(self, node, force=False): | ||||||
|         node = as_node(node) |         node = as_node(node) | ||||||
|         if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now: |         if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now: | ||||||
|             self.env._degree = nx.degree_centrality(self.global_topology) |             self.env._degree = nx.degree_centrality(self.topology) | ||||||
|             self.env._last_step = self.now |             self.env._last_step = self.now | ||||||
|         return self.env._degree[node] |         return self.env._degree[node] | ||||||
|  |  | ||||||
|     def betweenness(self, node, force=False): |     def betweenness(self, node, force=False): | ||||||
|         node = as_node(node) |         node = as_node(node) | ||||||
|         if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now: |         if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now: | ||||||
|             self.env._betweenness = nx.betweenness_centrality(self.global_topology) |             self.env._betweenness = nx.betweenness_centrality(self.topology) | ||||||
|             self.env._last_step = self.now |             self.env._last_step = self.now | ||||||
|         return self.env._betweenness[node] |         return self.env._betweenness[node] | ||||||
|  |  | ||||||
| @@ -292,16 +296,22 @@ class MetaFSM(type): | |||||||
|         cls.states = states |         cls.states = states | ||||||
|  |  | ||||||
|  |  | ||||||
| class FSM(BaseAgent, metaclass=MetaFSM): | class FSM(NetworkAgent, metaclass=MetaFSM): | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(FSM, self).__init__(*args, **kwargs) |         super(FSM, self).__init__(*args, **kwargs) | ||||||
|         if 'id' not in self.state: |         if 'id' not in self.state: | ||||||
|             if not self.default_state: |             if not self.default_state: | ||||||
|                 raise ValueError('No default state specified for {}'.format(self.id)) |                 raise ValueError('No default state specified for {}'.format(self.id)) | ||||||
|             self['id'] = self.default_state.id |             self['id'] = self.default_state.id | ||||||
|  |         self._next_change = simpy.core.Infinity | ||||||
|  |         self._next_state = self.state | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         if 'id' in self.state: |         if self._next_change < self.now: | ||||||
|  |             next_state = self._next_state | ||||||
|  |             self._next_change = simpy.core.Infinity | ||||||
|  |             self['id'] = next_state | ||||||
|  |         elif 'id' in self.state: | ||||||
|             next_state = self['id'] |             next_state = self['id'] | ||||||
|         elif self.default_state: |         elif self.default_state: | ||||||
|             next_state = self.default_state.id |             next_state = self.default_state.id | ||||||
| @@ -311,6 +321,10 @@ class FSM(BaseAgent, metaclass=MetaFSM): | |||||||
|             raise Exception('{} is not a valid id for {}'.format(next_state, self)) |             raise Exception('{} is not a valid id for {}'.format(next_state, self)) | ||||||
|         return self.states[next_state](self) |         return self.states[next_state](self) | ||||||
|  |  | ||||||
|  |     def next_state(self, state): | ||||||
|  |         self._next_change = self.now | ||||||
|  |         self._next_state = state | ||||||
|  |  | ||||||
|     def set_state(self, state): |     def set_state(self, state): | ||||||
|         if hasattr(state, 'id'): |         if hasattr(state, 'id'): | ||||||
|             state = state.id |             state = state.id | ||||||
| @@ -371,14 +385,18 @@ def calculate_distribution(network_agents=None, | |||||||
|     else: |     else: | ||||||
|         raise ValueError('Specify a distribution or a default agent type') |         raise ValueError('Specify a distribution or a default agent type') | ||||||
|  |  | ||||||
|  |     # Fix missing weights and incompatible types | ||||||
|  |     for x in network_agents: | ||||||
|  |         x['weight'] = float(x.get('weight', 1)) | ||||||
|  |  | ||||||
|     # Calculate the thresholds |     # Calculate the thresholds | ||||||
|     total = sum(x.get('weight', 1) for x in network_agents) |     total = sum(x['weight'] for x in network_agents) | ||||||
|     acc = 0 |     acc = 0 | ||||||
|     for v in network_agents: |     for v in network_agents: | ||||||
|         if 'ids' in v: |         if 'ids' in v: | ||||||
|             v['threshold'] = STATIC_THRESHOLD |             v['threshold'] = STATIC_THRESHOLD | ||||||
|             continue |             continue | ||||||
|         upper = acc + (v.get('weight', 1)/total) |         upper = acc + (v['weight']/total) | ||||||
|         v['threshold'] = [acc, upper] |         v['threshold'] = [acc, upper] | ||||||
|         acc = upper |         acc = upper | ||||||
|     return network_agents |     return network_agents | ||||||
| @@ -425,7 +443,7 @@ def _validate_states(states, topology): | |||||||
|     states = states or [] |     states = states or [] | ||||||
|     if isinstance(states, dict): |     if isinstance(states, dict): | ||||||
|         for x in states: |         for x in states: | ||||||
|             assert x in topology.node |             assert x in topology.nodes | ||||||
|     else: |     else: | ||||||
|         assert len(states) <= len(topology) |         assert len(states) <= len(topology) | ||||||
|     return states |     return states | ||||||
|   | |||||||
| @@ -28,13 +28,13 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs): | |||||||
|                 df = read_csv(trial_data, **kwargs) |                 df = read_csv(trial_data, **kwargs) | ||||||
|                 yield config_file, df, config |                 yield config_file, df, config | ||||||
|         else: |         else: | ||||||
|             for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))): |             for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))): | ||||||
|                 df = read_sql(trial_data, **kwargs) |                 df = read_sql(trial_data, **kwargs) | ||||||
|                 yield config_file, df, config |                 yield config_file, df, config | ||||||
|  |  | ||||||
|  |  | ||||||
| def read_sql(db, *args, **kwargs): | def read_sql(db, *args, **kwargs): | ||||||
|     h = history.History(db_path=db, backup=False) |     h = history.History(db_path=db, backup=False, readonly=True) | ||||||
|     df = h.read_sql(*args, **kwargs) |     df = h.read_sql(*args, **kwargs) | ||||||
|     return df |     return df | ||||||
|  |  | ||||||
| @@ -69,6 +69,13 @@ def convert_types_slow(df): | |||||||
|     df = df.apply(convert_row, axis=1) |     df = df.apply(convert_row, axis=1) | ||||||
|     return df |     return df | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def split_processed(df): | ||||||
|  |     env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])] | ||||||
|  |     agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])] | ||||||
|  |     return env, agents | ||||||
|  |  | ||||||
|  |  | ||||||
| def split_df(df): | def split_df(df): | ||||||
|     ''' |     ''' | ||||||
|     Split a dataframe in two dataframes: one with the history of agents, |     Split a dataframe in two dataframes: one with the history of agents, | ||||||
| @@ -136,7 +143,7 @@ def get_value(df, *keys, aggfunc='sum'): | |||||||
|     return df.groupby(axis=1, level=0).agg(aggfunc) |     return df.groupby(axis=1, level=0).agg(aggfunc) | ||||||
|  |  | ||||||
|  |  | ||||||
| def plot_all(*args, **kwargs): | def plot_all(*args, plot_args={}, **kwargs): | ||||||
|     ''' |     ''' | ||||||
|     Read all the trial data and plot the result of applying a function on them. |     Read all the trial data and plot the result of applying a function on them. | ||||||
|     ''' |     ''' | ||||||
| @@ -144,14 +151,17 @@ def plot_all(*args, **kwargs): | |||||||
|     ps = [] |     ps = [] | ||||||
|     for line in dfs: |     for line in dfs: | ||||||
|         f, df, config = line |         f, df, config = line | ||||||
|         df.plot(title=config['name']) |         if len(df) < 1: | ||||||
|  |             continue | ||||||
|  |         df.plot(title=config['name'], **plot_args) | ||||||
|         ps.append(df) |         ps.append(df) | ||||||
|     return ps |     return ps | ||||||
|  |  | ||||||
| def do_all(pattern, func, *keys, include_env=False, **kwargs): | def do_all(pattern, func, *keys, include_env=False, **kwargs): | ||||||
|     for config_file, df, config in read_data(pattern, keys=keys): |     for config_file, df, config in read_data(pattern, keys=keys): | ||||||
|  |         if len(df) < 1: | ||||||
|  |             continue | ||||||
|         p = func(df, *keys, **kwargs) |         p = func(df, *keys, **kwargs) | ||||||
|         p.plot(title=config['name']) |  | ||||||
|         yield config_file, p, config |         yield config_file, p, config | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -8,11 +8,10 @@ import yaml | |||||||
| import tempfile | import tempfile | ||||||
| import pandas as pd | import pandas as pd | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from collections import Counter |  | ||||||
| from networkx.readwrite import json_graph | from networkx.readwrite import json_graph | ||||||
|  |  | ||||||
| import networkx as nx | import networkx as nx | ||||||
| import nxsim | import simpy | ||||||
|  |  | ||||||
| from . import serialization, agents, analysis, history, utils | from . import serialization, agents, analysis, history, utils | ||||||
|  |  | ||||||
| @@ -23,7 +22,7 @@ _CONFIG_PROPS = [ 'name', | |||||||
|                  'interval', |                  'interval', | ||||||
|                  ] |                  ] | ||||||
|  |  | ||||||
| class Environment(nxsim.NetworkEnvironment): | class Environment(simpy.Environment): | ||||||
|     """ |     """ | ||||||
|     The environment is key in a simulation. It contains the network topology, |     The environment is key in a simulation. It contains the network topology, | ||||||
|     a reference to network and environment agents, as well as the environment |     a reference to network and environment agents, as well as the environment | ||||||
| @@ -42,7 +41,10 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|                  interval=1, |                  interval=1, | ||||||
|                  seed=None, |                  seed=None, | ||||||
|                  topology=None, |                  topology=None, | ||||||
|                  *args, **kwargs): |                  initial_time=0, | ||||||
|  |                  **environment_params): | ||||||
|  |  | ||||||
|  |  | ||||||
|         self.name = name or 'UnnamedEnvironment' |         self.name = name or 'UnnamedEnvironment' | ||||||
|         seed = seed or time.time() |         seed = seed or time.time() | ||||||
|         random.seed(seed) |         random.seed(seed) | ||||||
| @@ -52,7 +54,11 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         self.default_state = deepcopy(default_state) or {} |         self.default_state = deepcopy(default_state) or {} | ||||||
|         if not topology: |         if not topology: | ||||||
|             topology = nx.Graph() |             topology = nx.Graph() | ||||||
|         super().__init__(*args, topology=topology, **kwargs) |         self.G = nx.Graph(topology)  | ||||||
|  |  | ||||||
|  |         super().__init__(initial_time=initial_time) | ||||||
|  |         self.environment_params = environment_params | ||||||
|  |  | ||||||
|         self._env_agents = {} |         self._env_agents = {} | ||||||
|         self.interval = interval |         self.interval = interval | ||||||
|         self._history = history.History(name=self.name, |         self._history = history.History(name=self.name, | ||||||
| @@ -151,12 +157,10 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         start = start or self.now |         start = start or self.now | ||||||
|         return self.G.add_edge(agent1, agent2, **attrs) |         return self.G.add_edge(agent1, agent2, **attrs) | ||||||
|  |  | ||||||
|     def run(self, *args, **kwargs): |     def run(self, until, *args, **kwargs): | ||||||
|         self._save_state() |         self._save_state() | ||||||
|         self.log_stats() |         super().run(until, *args, **kwargs) | ||||||
|         super().run(*args, **kwargs) |  | ||||||
|         self._history.flush_cache() |         self._history.flush_cache() | ||||||
|         self.log_stats() |  | ||||||
|  |  | ||||||
|     def _save_state(self, now=None): |     def _save_state(self, now=None): | ||||||
|         serialization.logger.debug('Saving state @{}'.format(self.now)) |         serialization.logger.debug('Saving state @{}'.format(self.now)) | ||||||
| @@ -318,25 +322,6 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|  |  | ||||||
|         return G |         return G | ||||||
|  |  | ||||||
|     def stats(self): |  | ||||||
|         stats = {} |  | ||||||
|         stats['network'] = {} |  | ||||||
|         stats['network']['n_nodes'] = self.G.number_of_nodes() |  | ||||||
|         stats['network']['n_edges'] = self.G.number_of_edges() |  | ||||||
|         c = Counter() |  | ||||||
|         c.update(a.__class__.__name__ for a in self.network_agents) |  | ||||||
|         stats['agents'] = {} |  | ||||||
|         stats['agents']['model_count'] = dict(c) |  | ||||||
|         c2 = Counter() |  | ||||||
|         c2.update(a['id'] for a in self.network_agents) |  | ||||||
|         stats['agents']['state_count'] = dict(c2) |  | ||||||
|         stats['params'] = self.environment_params |  | ||||||
|         return stats |  | ||||||
|  |  | ||||||
|     def log_stats(self): |  | ||||||
|         stats = self.stats() |  | ||||||
|         serialization.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False))) |  | ||||||
|      |  | ||||||
|     def __getstate__(self): |     def __getstate__(self): | ||||||
|         state = {} |         state = {} | ||||||
|         for prop in _CONFIG_PROPS: |         for prop in _CONFIG_PROPS: | ||||||
| @@ -344,6 +329,7 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         state['G'] = json_graph.node_link_data(self.G) |         state['G'] = json_graph.node_link_data(self.G) | ||||||
|         state['environment_agents'] = self._env_agents |         state['environment_agents'] = self._env_agents | ||||||
|         state['history'] = self._history |         state['history'] = self._history | ||||||
|  |         state['_now'] = self._now | ||||||
|         return state |         return state | ||||||
|  |  | ||||||
|     def __setstate__(self, state): |     def __setstate__(self, state): | ||||||
| @@ -352,6 +338,8 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         self._env_agents = state['environment_agents'] |         self._env_agents = state['environment_agents'] | ||||||
|         self.G = json_graph.node_link_graph(state['G']) |         self.G = json_graph.node_link_graph(state['G']) | ||||||
|         self._history = state['history'] |         self._history = state['history'] | ||||||
|  |         self._now = state['_now'] | ||||||
|  |         self._queue = [] | ||||||
|  |  | ||||||
|  |  | ||||||
| SoilEnvironment = Environment | SoilEnvironment = Environment | ||||||
|   | |||||||
| @@ -1,10 +1,11 @@ | |||||||
| import os | import os | ||||||
|  | import csv as csvlib | ||||||
| import time | import time | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
|  |  | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
| import networkx as nx | import networkx as nx | ||||||
| import pandas as pd |  | ||||||
|  |  | ||||||
| from .serialization import deserialize | from .serialization import deserialize | ||||||
| from .utils import open_or_reuse, logger, timer | from .utils import open_or_reuse, logger, timer | ||||||
| @@ -49,7 +50,7 @@ class Exporter: | |||||||
|     ''' |     ''' | ||||||
|  |  | ||||||
|     def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None): |     def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None): | ||||||
|         self.sim = simulation |         self.simulation = simulation | ||||||
|         outdir = outdir or os.path.join(os.getcwd(), 'soil_output') |         outdir = outdir or os.path.join(os.getcwd(), 'soil_output') | ||||||
|         self.outdir = os.path.join(outdir, |         self.outdir = os.path.join(outdir, | ||||||
|                                    simulation.group or '', |                                    simulation.group or '', | ||||||
| @@ -59,12 +60,15 @@ class Exporter: | |||||||
|  |  | ||||||
|     def start(self): |     def start(self): | ||||||
|         '''Method to call when the simulation starts''' |         '''Method to call when the simulation starts''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     def end(self): |     def end(self, stats): | ||||||
|         '''Method to call when the simulation ends''' |         '''Method to call when the simulation ends''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         '''Method to call when a trial ends''' |         '''Method to call when a trial ends''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     def output(self, f, mode='w', **kwargs): |     def output(self, f, mode='w', **kwargs): | ||||||
|         if self.dry_run: |         if self.dry_run: | ||||||
| @@ -84,13 +88,13 @@ class default(Exporter): | |||||||
|     def start(self): |     def start(self): | ||||||
|         if not self.dry_run: |         if not self.dry_run: | ||||||
|             logger.info('Dumping results to %s', self.outdir) |             logger.info('Dumping results to %s', self.outdir) | ||||||
|             self.sim.dump_yaml(outdir=self.outdir) |             self.simulation.dump_yaml(outdir=self.outdir) | ||||||
|         else: |         else: | ||||||
|             logger.info('NOT dumping results') |             logger.info('NOT dumping results') | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         if not self.dry_run: |         if not self.dry_run: | ||||||
|             with timer('Dumping simulation {} trial {}'.format(self.sim.name, |             with timer('Dumping simulation {} trial {}'.format(self.simulation.name, | ||||||
|                                                                env.name)): |                                                                env.name)): | ||||||
|                 with self.output('{}.sqlite'.format(env.name), mode='wb') as f: |                 with self.output('{}.sqlite'.format(env.name), mode='wb') as f: | ||||||
|                     env.dump_sqlite(f) |                     env.dump_sqlite(f) | ||||||
| @@ -98,21 +102,27 @@ class default(Exporter): | |||||||
|  |  | ||||||
| class csv(Exporter): | class csv(Exporter): | ||||||
|     '''Export the state of each environment (and its agents) in a separate CSV file''' |     '''Export the state of each environment (and its agents) in a separate CSV file''' | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.sim.name, |         with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name, | ||||||
|                                                                           env.name, |                                                                           env.name, | ||||||
|                                                                           self.outdir)): |                                                                           self.outdir)): | ||||||
|             with self.output('{}.csv'.format(env.name)) as f: |             with self.output('{}.csv'.format(env.name)) as f: | ||||||
|                 env.dump_csv(f) |                 env.dump_csv(f) | ||||||
|  |  | ||||||
|  |             with self.output('{}.stats.csv'.format(env.name)) as f: | ||||||
|  |                 statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL) | ||||||
|  |  | ||||||
|  |                 for stat in stats: | ||||||
|  |                     statwriter.writerow(stat) | ||||||
|  |  | ||||||
|  |  | ||||||
| class gexf(Exporter): | class gexf(Exporter): | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         if self.dry_run: |         if self.dry_run: | ||||||
|             logger.info('Not dumping GEXF in dry_run mode') |             logger.info('Not dumping GEXF in dry_run mode') | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         with timer('[GEXF] Dumping simulation {} trial {}'.format(self.sim.name, |         with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name, | ||||||
|                                                                   env.name)): |                                                                   env.name)): | ||||||
|             with self.output('{}.gexf'.format(env.name), mode='wb') as f: |             with self.output('{}.gexf'.format(env.name), mode='wb') as f: | ||||||
|                 env.dump_gexf(f) |                 env.dump_gexf(f) | ||||||
| @@ -124,56 +134,24 @@ class dummy(Exporter): | |||||||
|         with self.output('dummy', 'w') as f: |         with self.output('dummy', 'w') as f: | ||||||
|             f.write('simulation started @ {}\n'.format(time.time())) |             f.write('simulation started @ {}\n'.format(time.time())) | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         with self.output('dummy', 'w') as f: |         with self.output('dummy', 'w') as f: | ||||||
|             for i in env.history_to_tuples(): |             for i in env.history_to_tuples(): | ||||||
|                 f.write(','.join(map(str, i))) |                 f.write(','.join(map(str, i))) | ||||||
|                 f.write('\n') |                 f.write('\n') | ||||||
|  |  | ||||||
|     def end(self): |     def sim(self, stats): | ||||||
|         with self.output('dummy', 'a') as f: |         with self.output('dummy', 'a') as f: | ||||||
|             f.write('simulation ended @ {}\n'.format(time.time())) |             f.write('simulation ended @ {}\n'.format(time.time())) | ||||||
|  |  | ||||||
|  |  | ||||||
| class distribution(Exporter): |  | ||||||
|     ''' |  | ||||||
|     Write the distribution of agent states at the end of each trial, |  | ||||||
|     the mean value, and its deviation. |  | ||||||
|     ''' |  | ||||||
|  |  | ||||||
|     def start(self): |  | ||||||
|         self.means = [] |  | ||||||
|         self.counts = [] |  | ||||||
|  |  | ||||||
|     def trial_end(self, env): |  | ||||||
|         df = env[None, None, None].df() |  | ||||||
|         ix = df.index[-1] |  | ||||||
|         attrs = df.columns.levels[0] |  | ||||||
|         vc = {} |  | ||||||
|         stats = {} |  | ||||||
|         for a in attrs: |  | ||||||
|             t = df.loc[(ix, a)] |  | ||||||
|             try: |  | ||||||
|                 self.means.append(('mean', a, t.mean())) |  | ||||||
|             except TypeError: |  | ||||||
|                 for name, count in t.value_counts().iteritems(): |  | ||||||
|                     self.counts.append(('count', a, name, count)) |  | ||||||
|  |  | ||||||
|     def end(self): |  | ||||||
|         dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value']) |  | ||||||
|         dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count']) |  | ||||||
|         dfm = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) |  | ||||||
|         dfc = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) |  | ||||||
|         with self.output('counts.csv') as f: |  | ||||||
|             dfc.to_csv(f) |  | ||||||
|         with self.output('metrics.csv') as f: |  | ||||||
|             dfm.to_csv(f) |  | ||||||
|  |  | ||||||
| class graphdrawing(Exporter): | class graphdrawing(Exporter): | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         # Outside effects |         # Outside effects | ||||||
|         f = plt.figure() |         f = plt.figure() | ||||||
|         nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111)) |         nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111)) | ||||||
|         with open('graph-{}.png'.format(env.name)) as f: |         with open('graph-{}.png'.format(env.name)) as f: | ||||||
|             f.savefig(f) |             f.savefig(f) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										136
									
								
								soil/history.py
									
									
									
									
									
								
							
							
						
						
									
										136
									
								
								soil/history.py
									
									
									
									
									
								
							| @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) | |||||||
| from collections import UserDict, namedtuple | from collections import UserDict, namedtuple | ||||||
|  |  | ||||||
| from . import serialization | from . import serialization | ||||||
| from .utils import open_or_reuse | from .utils import open_or_reuse, unflatten_dict | ||||||
|  |  | ||||||
|  |  | ||||||
| class History: | class History: | ||||||
| @@ -19,13 +19,22 @@ class History: | |||||||
|     Store and retrieve values from a sqlite database. |     Store and retrieve values from a sqlite database. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, name=None, db_path=None, backup=False): |     def __init__(self, name=None, db_path=None, backup=False, readonly=False): | ||||||
|         self._db = None |         if readonly and (not os.path.exists(db_path)): | ||||||
|  |             raise Exception('The DB file does not exist. Cannot open in read-only mode') | ||||||
|  |  | ||||||
|         if db_path is None: |         self._db = None | ||||||
|  |         self._temp = db_path is None | ||||||
|  |         self._stats_columns = None | ||||||
|  |         self.readonly = readonly | ||||||
|  |  | ||||||
|  |         if self._temp: | ||||||
|             if not name: |             if not name: | ||||||
|                 name = time.time() |                 name = time.time() | ||||||
|             _, db_path = tempfile.mkstemp(suffix='{}.sqlite'.format(name)) |             # The file will be deleted as soon as it's closed | ||||||
|  |             # Normally, that will be on destruction | ||||||
|  |             db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name | ||||||
|  |  | ||||||
|  |  | ||||||
|         if backup and os.path.exists(db_path): |         if backup and os.path.exists(db_path): | ||||||
|                 newname = db_path + '.backup{}.sqlite'.format(time.time()) |                 newname = db_path + '.backup{}.sqlite'.format(time.time()) | ||||||
| @@ -34,14 +43,19 @@ class History: | |||||||
|         self.db_path = db_path |         self.db_path = db_path | ||||||
|  |  | ||||||
|         self.db = db_path |         self.db = db_path | ||||||
|  |         self._dtypes = {} | ||||||
|  |         self._tups = [] | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         if self.readonly: | ||||||
|  |             return | ||||||
|  |  | ||||||
|         with self.db: |         with self.db: | ||||||
|             logger.debug('Creating database {}'.format(self.db_path)) |             logger.debug('Creating database {}'.format(self.db_path)) | ||||||
|             self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''') |             self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text)''') | ||||||
|             self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''') |             self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''') | ||||||
|  |             self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''') | ||||||
|             self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''') |             self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''') | ||||||
|         self._dtypes = {} |  | ||||||
|         self._tups = [] |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def db(self): |     def db(self): | ||||||
| @@ -58,6 +72,7 @@ class History: | |||||||
|         if isinstance(db_path, str): |         if isinstance(db_path, str): | ||||||
|             logger.debug('Connecting to database {}'.format(db_path)) |             logger.debug('Connecting to database {}'.format(db_path)) | ||||||
|             self._db = sqlite3.connect(db_path) |             self._db = sqlite3.connect(db_path) | ||||||
|  |             self._db.row_factory = sqlite3.Row | ||||||
|         else: |         else: | ||||||
|             self._db = db_path |             self._db = db_path | ||||||
|  |  | ||||||
| @@ -68,9 +83,56 @@ class History: | |||||||
|         self._db.close() |         self._db.close() | ||||||
|         self._db = None |         self._db = None | ||||||
|  |  | ||||||
|  |     def save_stats(self, stat): | ||||||
|  |         if self.readonly: | ||||||
|  |             print('DB in readonly mode') | ||||||
|  |             return | ||||||
|  |         if not stat: | ||||||
|  |             return | ||||||
|  |         with self.db: | ||||||
|  |             if not self._stats_columns: | ||||||
|  |                 self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)')) | ||||||
|  |  | ||||||
|  |             for column, value in stat.items(): | ||||||
|  |                 if column in self._stats_columns: | ||||||
|  |                     continue | ||||||
|  |                 dtype = 'text' | ||||||
|  |                 if not isinstance(value, str): | ||||||
|  |                     try: | ||||||
|  |                         float(value) | ||||||
|  |                         dtype = 'real' | ||||||
|  |                         int(value) | ||||||
|  |                         dtype = 'int' | ||||||
|  |                     except ValueError: | ||||||
|  |                         pass | ||||||
|  |                 self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype)) | ||||||
|  |                 self._stats_columns.append(column) | ||||||
|  |  | ||||||
|  |             columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys())) | ||||||
|  |             values = ", ".join(['"{0}"'.format(col) for col in stat.values()]) | ||||||
|  |             query = "INSERT INTO stats ({columns}) VALUES ({values})".format( | ||||||
|  |                 columns=columns, | ||||||
|  |                 values=values | ||||||
|  |             ) | ||||||
|  |             self.db.execute(query) | ||||||
|  |  | ||||||
|  |     def get_stats(self, unflatten=True): | ||||||
|  |         rows = self.db.execute("select * from stats").fetchall() | ||||||
|  |         res = [] | ||||||
|  |         for row in rows: | ||||||
|  |             d = {} | ||||||
|  |             for k in row.keys(): | ||||||
|  |                 if row[k] is None: | ||||||
|  |                     continue | ||||||
|  |                 d[k] = row[k] | ||||||
|  |             if unflatten: | ||||||
|  |                 d = unflatten_dict(d) | ||||||
|  |             res.append(d) | ||||||
|  |         return res | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def dtypes(self): |     def dtypes(self): | ||||||
|         self.read_types() |         self._read_types() | ||||||
|         return {k:v[0] for k, v in self._dtypes.items()} |         return {k:v[0] for k, v in self._dtypes.items()} | ||||||
|  |  | ||||||
|     def save_tuples(self, tuples): |     def save_tuples(self, tuples): | ||||||
| @@ -93,18 +155,10 @@ class History: | |||||||
|         Save a collection of records to the database. |         Save a collection of records to the database. | ||||||
|         Database writes are cached. |         Database writes are cached. | ||||||
|         ''' |         ''' | ||||||
|         value = self.convert(key, value) |         if self.readonly: | ||||||
|         self._tups.append(Record(agent_id=agent_id, |             raise Exception('DB in readonly mode') | ||||||
|                                  t_step=t_step, |  | ||||||
|                                  key=key, |  | ||||||
|                                  value=value)) |  | ||||||
|         if len(self._tups) > 100: |  | ||||||
|             self.flush_cache() |  | ||||||
|  |  | ||||||
|     def convert(self, key, value): |  | ||||||
|         """Get the serialized value for a given key.""" |  | ||||||
|         if key not in self._dtypes: |         if key not in self._dtypes: | ||||||
|             self.read_types() |             self._read_types() | ||||||
|             if key not in self._dtypes: |             if key not in self._dtypes: | ||||||
|                 name = serialization.name(value) |                 name = serialization.name(value) | ||||||
|                 serializer = serialization.serializer(name) |                 serializer = serialization.serializer(name) | ||||||
| @@ -112,21 +166,21 @@ class History: | |||||||
|                 self._dtypes[key] = (name, serializer, deserializer) |                 self._dtypes[key] = (name, serializer, deserializer) | ||||||
|                 with self.db: |                 with self.db: | ||||||
|                     self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name)) |                     self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name)) | ||||||
|         return self._dtypes[key][1](value) |         value = self._dtypes[key][1](value) | ||||||
|  |         self._tups.append(Record(agent_id=agent_id, | ||||||
|     def recover(self, key, value): |                                  t_step=t_step, | ||||||
|         """Get the deserialized value for a given key, and the serialized version.""" |                                  key=key, | ||||||
|         if key not in self._dtypes: |                                  value=value)) | ||||||
|             self.read_types() |         if len(self._tups) > 100: | ||||||
|         if key not in self._dtypes: |             self.flush_cache() | ||||||
|             raise ValueError("Unknown datatype for {} and {}".format(key, value)) |  | ||||||
|         return self._dtypes[key][2](value) |  | ||||||
|  |  | ||||||
|     def flush_cache(self): |     def flush_cache(self): | ||||||
|         ''' |         ''' | ||||||
|         Use a cache to save state changes to avoid opening a session for every change. |         Use a cache to save state changes to avoid opening a session for every change. | ||||||
|         The cache will be flushed at the end of the simulation, and when history is accessed. |         The cache will be flushed at the end of the simulation, and when history is accessed. | ||||||
|         ''' |         ''' | ||||||
|  |         if self.readonly: | ||||||
|  |             raise Exception('DB in readonly mode') | ||||||
|         logger.debug('Flushing cache {}'.format(self.db_path)) |         logger.debug('Flushing cache {}'.format(self.db_path)) | ||||||
|         with self.db: |         with self.db: | ||||||
|             for rec in self._tups: |             for rec in self._tups: | ||||||
| @@ -139,10 +193,14 @@ class History: | |||||||
|             res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall() |             res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall() | ||||||
|         for r in res: |         for r in res: | ||||||
|             agent_id, t_step, key, value = r |             agent_id, t_step, key, value = r | ||||||
|             value = self.recover(key, value) |             if key not in self._dtypes: | ||||||
|  |                 self._read_types() | ||||||
|  |             if key not in self._dtypes: | ||||||
|  |                 raise ValueError("Unknown datatype for {} and {}".format(key, value)) | ||||||
|  |             value = self._dtypes[key][2](value) | ||||||
|             yield agent_id, t_step, key, value |             yield agent_id, t_step, key, value | ||||||
|  |  | ||||||
|     def read_types(self): |     def _read_types(self): | ||||||
|         with self.db: |         with self.db: | ||||||
|             res = self.db.execute("select key, value_type from value_types ").fetchall() |             res = self.db.execute("select key, value_type from value_types ").fetchall() | ||||||
|         for k, v in res: |         for k, v in res: | ||||||
| @@ -167,7 +225,7 @@ class History: | |||||||
|  |  | ||||||
|     def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1): |     def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1): | ||||||
|  |  | ||||||
|         self.read_types() |         self._read_types() | ||||||
|  |  | ||||||
|         def escape_and_join(v): |         def escape_and_join(v): | ||||||
|             if v is None: |             if v is None: | ||||||
| @@ -181,7 +239,13 @@ class History: | |||||||
|  |  | ||||||
|         last_df = None |         last_df = None | ||||||
|         if t_steps: |         if t_steps: | ||||||
|             # Look for the last value before the minimum step in the query |             # Convert negative indices into positive | ||||||
|  |             if any(x<0 for x in t_steps): | ||||||
|  |                 max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0]) | ||||||
|  |                 t_steps = [t if t>0 else max_t+1+t for t in t_steps] | ||||||
|  |  | ||||||
|  |             # We will be doing ffill interpolation, so we need to look for | ||||||
|  |             # the last value before the minimum step in the query | ||||||
|             min_step = min(t_steps) |             min_step = min(t_steps) | ||||||
|             last_filters = ['t_step < {}'.format(min_step),] |             last_filters = ['t_step < {}'.format(min_step),] | ||||||
|             last_filters = last_filters + filters |             last_filters = last_filters + filters | ||||||
| @@ -219,7 +283,11 @@ class History: | |||||||
|         for k, v in self._dtypes.items(): |         for k, v in self._dtypes.items(): | ||||||
|             if k in df_p: |             if k in df_p: | ||||||
|                 dtype, _, deserial = v |                 dtype, _, deserial = v | ||||||
|  |                 try: | ||||||
|                     df_p[k] = df_p[k].fillna(method='ffill').astype(dtype) |                     df_p[k] = df_p[k].fillna(method='ffill').astype(dtype) | ||||||
|  |                 except (TypeError, ValueError): | ||||||
|  |                     # Avoid forward-filling unknown/incompatible types | ||||||
|  |                     continue | ||||||
|         if t_steps: |         if t_steps: | ||||||
|             df_p = df_p.reindex(t_steps, method='ffill') |             df_p = df_p.reindex(t_steps, method='ffill') | ||||||
|         return df_p.ffill() |         return df_p.ffill() | ||||||
| @@ -313,3 +381,5 @@ class Records(): | |||||||
|  |  | ||||||
| Key = namedtuple('Key', ['agent_id', 't_step', 'key']) | Key = namedtuple('Key', ['agent_id', 't_step', 'key']) | ||||||
| Record = namedtuple('Record', 'agent_id t_step key value') | Record = namedtuple('Record', 'agent_id t_step key value') | ||||||
|  |  | ||||||
|  | Stat = namedtuple('Stat', 'trial_id') | ||||||
|   | |||||||
| @@ -17,10 +17,10 @@ logger.setLevel(logging.INFO) | |||||||
|  |  | ||||||
|  |  | ||||||
| def load_network(network_params, dir_path=None): | def load_network(network_params, dir_path=None): | ||||||
|     if network_params is None: |     G = nx.Graph() | ||||||
|         return nx.Graph() |  | ||||||
|     path = network_params.get('path', None) |     if 'path' in network_params: | ||||||
|     if path: |         path = network_params['path'] | ||||||
|         if dir_path and not os.path.isabs(path): |         if dir_path and not os.path.isabs(path): | ||||||
|             path = os.path.join(dir_path, path) |             path = os.path.join(dir_path, path) | ||||||
|         extension = os.path.splitext(path)[1][1:] |         extension = os.path.splitext(path)[1][1:] | ||||||
| @@ -32,12 +32,10 @@ def load_network(network_params, dir_path=None): | |||||||
|             method = getattr(nx.readwrite, 'read_' + extension) |             method = getattr(nx.readwrite, 'read_' + extension) | ||||||
|         except AttributeError: |         except AttributeError: | ||||||
|             raise AttributeError('Unknown format') |             raise AttributeError('Unknown format') | ||||||
|         return method(path, **kwargs) |         G = method(path, **kwargs) | ||||||
|  |  | ||||||
|  |     elif 'generator' in network_params: | ||||||
|         net_args = network_params.copy() |         net_args = network_params.copy() | ||||||
|     if 'generator' not in net_args: |  | ||||||
|         return nx.Graph() |  | ||||||
|  |  | ||||||
|         net_gen = net_args.pop('generator') |         net_gen = net_args.pop('generator') | ||||||
|  |  | ||||||
|         if dir_path not in sys.path: |         if dir_path not in sys.path: | ||||||
| @@ -45,8 +43,11 @@ def load_network(network_params, dir_path=None): | |||||||
|  |  | ||||||
|         method = deserializer(net_gen, |         method = deserializer(net_gen, | ||||||
|                               known_modules=['networkx.generators',]) |                               known_modules=['networkx.generators',]) | ||||||
|  |         G = method(**net_args) | ||||||
|  |  | ||||||
|  |     return G | ||||||
|  |  | ||||||
|  |  | ||||||
|     return method(**net_args) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_file(infile): | def load_file(infile): | ||||||
| @@ -66,11 +67,32 @@ def expand_template(config): | |||||||
|         raise ValueError(('You must provide a definition of variables' |         raise ValueError(('You must provide a definition of variables' | ||||||
|                           ' for the template.')) |                           ' for the template.')) | ||||||
|  |  | ||||||
|     template = Template(config['template']) |     template = config['template'] | ||||||
|  |  | ||||||
|     sampler_name = config.get('sampler', 'SALib.sample.morris.sample') |     if not isinstance(template, str): | ||||||
|     n_samples = int(config.get('samples', 100)) |         template = yaml.dump(template) | ||||||
|     sampler = deserializer(sampler_name) |  | ||||||
|  |     template = Template(template) | ||||||
|  |  | ||||||
|  |     params = params_for_template(config) | ||||||
|  |  | ||||||
|  |     blank_str = template.render({k: 0 for k in params[0].keys()}) | ||||||
|  |     blank = list(load_string(blank_str)) | ||||||
|  |     if len(blank) > 1: | ||||||
|  |         raise ValueError('Templates must not return more than one configuration') | ||||||
|  |     if 'name' in blank[0]: | ||||||
|  |         raise ValueError('Templates cannot be named, use group instead') | ||||||
|  |  | ||||||
|  |     for ps in params: | ||||||
|  |         string = template.render(ps) | ||||||
|  |         for c in load_string(string): | ||||||
|  |             yield c | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def params_for_template(config): | ||||||
|  |     sampler_config = config.get('sampler', {'N': 100}) | ||||||
|  |     sampler = sampler_config.pop('method', 'SALib.sample.morris.sample') | ||||||
|  |     sampler = deserializer(sampler) | ||||||
|     bounds = config['vars']['bounds'] |     bounds = config['vars']['bounds'] | ||||||
|  |  | ||||||
|     problem = { |     problem = { | ||||||
| @@ -78,7 +100,7 @@ def expand_template(config): | |||||||
|         'names': list(bounds.keys()), |         'names': list(bounds.keys()), | ||||||
|         'bounds': list(v for v in bounds.values()) |         'bounds': list(v for v in bounds.values()) | ||||||
|     } |     } | ||||||
|     samples = sampler(problem, n_samples) |     samples = sampler(problem, **sampler_config) | ||||||
|  |  | ||||||
|     lists = config['vars'].get('lists', {}) |     lists = config['vars'].get('lists', {}) | ||||||
|     names = list(lists.keys()) |     names = list(lists.keys()) | ||||||
| @@ -88,20 +110,7 @@ def expand_template(config): | |||||||
|     allnames = names + problem['names'] |     allnames = names + problem['names'] | ||||||
|     allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)] |     allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)] | ||||||
|     params = list(map(lambda x: dict(zip(allnames, x)), allvalues)) |     params = list(map(lambda x: dict(zip(allnames, x)), allvalues)) | ||||||
|  |     return params | ||||||
|  |  | ||||||
|     blank_str = template.render({k: 0 for k in allnames}) |  | ||||||
|     blank = list(load_string(blank_str)) |  | ||||||
|     if len(blank) > 1: |  | ||||||
|         raise ValueError('Templates must not return more than one configuration') |  | ||||||
|     if 'name' in blank[0]: |  | ||||||
|         raise ValueError('Templates cannot be named, use group instead') |  | ||||||
|  |  | ||||||
|     confs = [] |  | ||||||
|     for ps in params: |  | ||||||
|         string = template.render(ps) |  | ||||||
|         for c in load_string(string): |  | ||||||
|             yield c |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_files(*patterns, **kwargs): | def load_files(*patterns, **kwargs): | ||||||
| @@ -116,7 +125,7 @@ def load_files(*patterns, **kwargs): | |||||||
|  |  | ||||||
| def load_config(config): | def load_config(config): | ||||||
|     if isinstance(config, dict): |     if isinstance(config, dict): | ||||||
|         yield config, None |         yield config, os.getcwd() | ||||||
|     else: |     else: | ||||||
|         yield from load_files(config) |         yield from load_files(config) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ import importlib | |||||||
| import sys | import sys | ||||||
| import yaml | import yaml | ||||||
| import traceback | import traceback | ||||||
|  | import logging | ||||||
| import networkx as nx | import networkx as nx | ||||||
| from networkx.readwrite import json_graph | from networkx.readwrite import json_graph | ||||||
| from multiprocessing import Pool | from multiprocessing import Pool | ||||||
| @@ -11,17 +12,19 @@ from functools import partial | |||||||
|  |  | ||||||
| import pickle | import pickle | ||||||
|  |  | ||||||
| from nxsim import NetworkSimulation |  | ||||||
|  |  | ||||||
| from . import serialization, utils, basestring, agents | from . import serialization, utils, basestring, agents | ||||||
| from .environment import Environment | from .environment import Environment | ||||||
| from .utils import logger | from .utils import logger | ||||||
| from .exporters import for_sim as exporters_for_sim | from .exporters import default, for_sim as exporters_for_sim | ||||||
|  | from .stats import defaultStats | ||||||
|  | from .history import History | ||||||
|  |  | ||||||
|  |  | ||||||
| class Simulation(NetworkSimulation): | #TODO: change documentation for simulation | ||||||
|  |  | ||||||
|  | class Simulation: | ||||||
|     """ |     """ | ||||||
|     Subclass of nsim.NetworkSimulation with three main differences: |     Similar to nsim.NetworkSimulation with three main differences: | ||||||
|         1) agent type can be specified by name or by class. |         1) agent type can be specified by name or by class. | ||||||
|         2) instead of just one type, a network agents distribution can be used. |         2) instead of just one type, a network agents distribution can be used. | ||||||
|            The distribution specifies the weight (or probability) of each |            The distribution specifies the weight (or probability) of each | ||||||
| @@ -91,11 +94,12 @@ class Simulation(NetworkSimulation): | |||||||
|                  environment_params=None, environment_class=None, |                  environment_params=None, environment_class=None, | ||||||
|                  **kwargs): |                  **kwargs): | ||||||
|  |  | ||||||
|         self.seed = str(seed) or str(time.time()) |  | ||||||
|         self.load_module = load_module |         self.load_module = load_module | ||||||
|         self.network_params = network_params |         self.network_params = network_params | ||||||
|         self.name = name or 'Unnamed_' + time.strftime("%Y-%m-%d_%H.%M.%S") |         self.name = name or 'Unnamed' | ||||||
|         self.group = group or None |         self.seed = str(seed or name) | ||||||
|  |         self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S")) | ||||||
|  |         self.group = group or '' | ||||||
|         self.num_trials = num_trials |         self.num_trials = num_trials | ||||||
|         self.max_time = max_time |         self.max_time = max_time | ||||||
|         self.default_state = default_state or {} |         self.default_state = default_state or {} | ||||||
| @@ -128,12 +132,15 @@ class Simulation(NetworkSimulation): | |||||||
|         self.states = agents._validate_states(states, |         self.states = agents._validate_states(states, | ||||||
|                                               self.topology) |                                               self.topology) | ||||||
|  |  | ||||||
|  |         self._history = History(name=self.name, | ||||||
|  |                                backup=False) | ||||||
|  |  | ||||||
|     def run_simulation(self, *args, **kwargs): |     def run_simulation(self, *args, **kwargs): | ||||||
|         return self.run(*args, **kwargs) |         return self.run(*args, **kwargs) | ||||||
|  |  | ||||||
|     def run(self, *args, **kwargs): |     def run(self, *args, **kwargs): | ||||||
|         '''Run the simulation and return the list of resulting environments''' |         '''Run the simulation and return the list of resulting environments''' | ||||||
|         return list(self._run_simulation_gen(*args, **kwargs)) |         return list(self.run_gen(*args, **kwargs)) | ||||||
|  |  | ||||||
|     def _run_sync_or_async(self, parallel=False, *args, **kwargs): |     def _run_sync_or_async(self, parallel=False, *args, **kwargs): | ||||||
|         if parallel: |         if parallel: | ||||||
| @@ -148,12 +155,16 @@ class Simulation(NetworkSimulation): | |||||||
|                 yield i |                 yield i | ||||||
|         else: |         else: | ||||||
|             for i in range(self.num_trials): |             for i in range(self.num_trials): | ||||||
|                 yield self.run_trial(i, |                 yield self.run_trial(*args, | ||||||
|                                      *args, |  | ||||||
|                                      **kwargs) |                                      **kwargs) | ||||||
|  |  | ||||||
|     def _run_simulation_gen(self, *args, parallel=False, dry_run=False, |     def run_gen(self, *args, parallel=False, dry_run=False, | ||||||
|                             exporters=['default', ], outdir=None, exporter_params={}, **kwargs): |                 exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={}, | ||||||
|  |                 stats_params={}, log_level=None, | ||||||
|  |                 **kwargs): | ||||||
|  |         '''Run the simulation and yield the resulting environments.''' | ||||||
|  |         if log_level: | ||||||
|  |             logger.setLevel(log_level) | ||||||
|         logger.info('Using exporters: %s', exporters or []) |         logger.info('Using exporters: %s', exporters or []) | ||||||
|         logger.info('Output directory: %s', outdir) |         logger.info('Output directory: %s', outdir) | ||||||
|         exporters = exporters_for_sim(self, |         exporters = exporters_for_sim(self, | ||||||
| @@ -161,31 +172,63 @@ class Simulation(NetworkSimulation): | |||||||
|                                       dry_run=dry_run, |                                       dry_run=dry_run, | ||||||
|                                       outdir=outdir, |                                       outdir=outdir, | ||||||
|                                       **exporter_params) |                                       **exporter_params) | ||||||
|  |         stats = exporters_for_sim(self, | ||||||
|  |                                   stats, | ||||||
|  |                                   **stats_params) | ||||||
|  |  | ||||||
|         with utils.timer('simulation {}'.format(self.name)): |         with utils.timer('simulation {}'.format(self.name)): | ||||||
|  |             for stat in stats: | ||||||
|  |                 stat.start() | ||||||
|  |  | ||||||
|             for exporter in exporters: |             for exporter in exporters: | ||||||
|                 exporter.start() |                 exporter.start() | ||||||
|  |             for env in self._run_sync_or_async(*args, | ||||||
|             for env in self._run_sync_or_async(*args, parallel=parallel, |                                                parallel=parallel, | ||||||
|  |                                                log_level=log_level, | ||||||
|                                                **kwargs): |                                                **kwargs): | ||||||
|  |  | ||||||
|  |                 collected = list(stat.trial(env) for stat in stats) | ||||||
|  |  | ||||||
|  |                 saved = self.save_stats(collected, t_step=env.now, trial_id=env.name) | ||||||
|  |  | ||||||
|                 for exporter in exporters: |                 for exporter in exporters: | ||||||
|                     exporter.trial_end(env) |                     exporter.trial(env, saved) | ||||||
|  |  | ||||||
|                 yield env |                 yield env | ||||||
|  |  | ||||||
|  |  | ||||||
|  |             collected = list(stat.end() for stat in stats) | ||||||
|  |             saved = self.save_stats(collected) | ||||||
|  |  | ||||||
|             for exporter in exporters: |             for exporter in exporters: | ||||||
|                 exporter.end() |                 exporter.end(saved) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def save_stats(self, collection, **kwargs): | ||||||
|  |         stats = dict(kwargs) | ||||||
|  |         for stat in collection: | ||||||
|  |             stats.update(stat) | ||||||
|  |         self._history.save_stats(utils.flatten_dict(stats)) | ||||||
|  |         return stats | ||||||
|  |  | ||||||
|  |     def get_stats(self, **kwargs): | ||||||
|  |         return self._history.get_stats(**kwargs) | ||||||
|  |  | ||||||
|  |     def log_stats(self, stats): | ||||||
|  |         logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False))) | ||||||
|  |      | ||||||
|  |  | ||||||
|     def get_env(self, trial_id=0, **kwargs): |     def get_env(self, trial_id=0, **kwargs): | ||||||
|         '''Create an environment for a trial of the simulation''' |         '''Create an environment for a trial of the simulation''' | ||||||
|         opts = self.environment_params.copy() |         opts = self.environment_params.copy() | ||||||
|         env_name = '{}_trial_{}'.format(self.name, trial_id) |  | ||||||
|         opts.update({ |         opts.update({ | ||||||
|             'name': env_name, |             'name': trial_id, | ||||||
|             'topology': self.topology.copy(), |             'topology': self.topology.copy(), | ||||||
|             'seed': self.seed+env_name, |             'seed': '{}_trial_{}'.format(self.seed, trial_id), | ||||||
|             'initial_time': 0, |             'initial_time': 0, | ||||||
|             'interval': self.interval, |             'interval': self.interval, | ||||||
|             'network_agents': self.network_agents, |             'network_agents': self.network_agents, | ||||||
|  |             'initial_time': 0, | ||||||
|             'states': self.states, |             'states': self.states, | ||||||
|             'default_state': self.default_state, |             'default_state': self.default_state, | ||||||
|             'environment_agents': self.environment_agents, |             'environment_agents': self.environment_agents, | ||||||
| @@ -194,13 +237,14 @@ class Simulation(NetworkSimulation): | |||||||
|         env = self.environment_class(**opts) |         env = self.environment_class(**opts) | ||||||
|         return env |         return env | ||||||
|  |  | ||||||
|     def run_trial(self, trial_id=0, until=None, **opts): |     def run_trial(self, until=None, log_level=logging.INFO, **opts): | ||||||
|         """Run a single trial of the simulation |  | ||||||
|  |  | ||||||
|         Parameters |  | ||||||
|         ---------- |  | ||||||
|         trial_id : int |  | ||||||
|         """ |         """ | ||||||
|  |         Run a single trial of the simulation | ||||||
|  |  | ||||||
|  |         """ | ||||||
|  |         trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-') | ||||||
|  |         if log_level: | ||||||
|  |             logger.setLevel(log_level) | ||||||
|         # Set-up trial environment and graph |         # Set-up trial environment and graph | ||||||
|         until = until or self.max_time |         until = until or self.max_time | ||||||
|         env = self.get_env(trial_id=trial_id, **opts) |         env = self.get_env(trial_id=trial_id, **opts) | ||||||
| @@ -208,6 +252,7 @@ class Simulation(NetworkSimulation): | |||||||
|         with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)): |         with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)): | ||||||
|             env.run(until) |             env.run(until) | ||||||
|         return env |         return env | ||||||
|  |  | ||||||
|     def run_trial_exceptions(self, *args, **kwargs): |     def run_trial_exceptions(self, *args, **kwargs): | ||||||
|         ''' |         ''' | ||||||
|         A wrapper for run_trial that catches exceptions and returns them. |         A wrapper for run_trial that catches exceptions and returns them. | ||||||
|   | |||||||
							
								
								
									
										106
									
								
								soil/stats.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								soil/stats.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | |||||||
|  | import pandas as pd | ||||||
|  |  | ||||||
|  | from collections import Counter | ||||||
|  |  | ||||||
|  | class Stats: | ||||||
|  |     ''' | ||||||
|  |     Interface for all stats. It is not necessary, but it is useful | ||||||
|  |     if you don't plan to implement all the methods. | ||||||
|  |     ''' | ||||||
|  |  | ||||||
|  |     def __init__(self, simulation): | ||||||
|  |         self.simulation = simulation | ||||||
|  |  | ||||||
|  |     def start(self): | ||||||
|  |         '''Method to call when the simulation starts''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     def end(self): | ||||||
|  |         '''Method to call when the simulation ends''' | ||||||
|  |         return {} | ||||||
|  |  | ||||||
|  |     def trial(self, env): | ||||||
|  |         '''Method to call when a trial ends''' | ||||||
|  |         return {} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class distribution(Stats): | ||||||
|  |     ''' | ||||||
|  |     Calculate the distribution of agent states at the end of each trial, | ||||||
|  |     the mean value, and its deviation. | ||||||
|  |     ''' | ||||||
|  |  | ||||||
|  |     def start(self): | ||||||
|  |         self.means = [] | ||||||
|  |         self.counts = [] | ||||||
|  |  | ||||||
|  |     def trial(self, env): | ||||||
|  |         df = env[None, None, None].df() | ||||||
|  |         df = df.drop('SEED', axis=1) | ||||||
|  |         ix = df.index[-1] | ||||||
|  |         attrs = df.columns.get_level_values(0) | ||||||
|  |         vc = {} | ||||||
|  |         stats = { | ||||||
|  |             'mean': {}, | ||||||
|  |             'count': {}, | ||||||
|  |         } | ||||||
|  |         for a in attrs: | ||||||
|  |             t = df.loc[(ix, a)] | ||||||
|  |             try: | ||||||
|  |                 stats['mean'][a] = t.mean() | ||||||
|  |                 self.means.append(('mean', a, t.mean())) | ||||||
|  |             except TypeError: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             for name, count in t.value_counts().iteritems(): | ||||||
|  |                 if a not in stats['count']: | ||||||
|  |                     stats['count'][a] = {} | ||||||
|  |                 stats['count'][a][name] = count | ||||||
|  |                 self.counts.append(('count', a, name, count)) | ||||||
|  |  | ||||||
|  |         return stats | ||||||
|  |  | ||||||
|  |     def end(self): | ||||||
|  |         dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value']) | ||||||
|  |         dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count']) | ||||||
|  |  | ||||||
|  |         count = {} | ||||||
|  |         mean = {} | ||||||
|  |  | ||||||
|  |         if self.means: | ||||||
|  |             res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) | ||||||
|  |             mean = res['value'].to_dict() | ||||||
|  |         if self.counts: | ||||||
|  |             res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) | ||||||
|  |             for k,v in res['count'].to_dict().items(): | ||||||
|  |                 if k not in count: | ||||||
|  |                     count[k] = {} | ||||||
|  |                 for tup, times in v.items(): | ||||||
|  |                     subkey, subcount = tup | ||||||
|  |                     if subkey not in count[k]: | ||||||
|  |                         count[k][subkey] = {} | ||||||
|  |                     count[k][subkey][subcount] = times | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         return {'count': count, 'mean': mean} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class defaultStats(Stats): | ||||||
|  |  | ||||||
|  |     def trial(self, env): | ||||||
|  |         c = Counter() | ||||||
|  |         c.update(a.__class__.__name__ for a in env.network_agents) | ||||||
|  |  | ||||||
|  |         c2 = Counter() | ||||||
|  |         c2.update(a['id'] for a in env.network_agents) | ||||||
|  |  | ||||||
|  |         return { | ||||||
|  |             'network ': { | ||||||
|  |                 'n_nodes': env.G.number_of_nodes(), | ||||||
|  |                 'n_edges': env.G.number_of_nodes(), | ||||||
|  |             }, | ||||||
|  |             'agents': { | ||||||
|  |                 'model_count': dict(c), | ||||||
|  |                 'state_count': dict(c2), | ||||||
|  |             } | ||||||
|  |         } | ||||||
| @@ -7,6 +7,7 @@ from shutil import copyfile | |||||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||||
|  |  | ||||||
| logger = logging.getLogger('soil') | logger = logging.getLogger('soil') | ||||||
|  | logging.basicConfig() | ||||||
| logger.setLevel(logging.INFO) | logger.setLevel(logging.INFO) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -31,14 +32,13 @@ def safe_open(path, mode='r', backup=True, **kwargs): | |||||||
|         os.makedirs(outdir) |         os.makedirs(outdir) | ||||||
|     if backup and 'w' in mode and os.path.exists(path): |     if backup and 'w' in mode and os.path.exists(path): | ||||||
|         creation = os.path.getctime(path) |         creation = os.path.getctime(path) | ||||||
|         stamp = time.strftime('%Y-%m-%d_%H.%M', time.localtime(creation)) |         stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation)) | ||||||
|  |  | ||||||
|         backup_dir = os.path.join(outdir, stamp) |         backup_dir = os.path.join(outdir, 'backup') | ||||||
|         if not os.path.exists(backup_dir): |         if not os.path.exists(backup_dir): | ||||||
|             os.makedirs(backup_dir) |             os.makedirs(backup_dir) | ||||||
|         newpath = os.path.join(backup_dir, os.path.basename(path)) |         newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path), | ||||||
|         if os.path.exists(newpath): |                                                                stamp)) | ||||||
|             newpath = '{}@{}'.format(newpath, time.time()) |  | ||||||
|         copyfile(path, newpath) |         copyfile(path, newpath) | ||||||
|     return open(path, mode=mode, **kwargs) |     return open(path, mode=mode, **kwargs) | ||||||
|  |  | ||||||
| @@ -48,3 +48,40 @@ def open_or_reuse(f, *args, **kwargs): | |||||||
|         return safe_open(f, *args, **kwargs) |         return safe_open(f, *args, **kwargs) | ||||||
|     except (AttributeError, TypeError): |     except (AttributeError, TypeError): | ||||||
|         return f |         return f | ||||||
|  |  | ||||||
|  | def flatten_dict(d): | ||||||
|  |     if not isinstance(d, dict): | ||||||
|  |         return d | ||||||
|  |     return dict(_flatten_dict(d)) | ||||||
|  |  | ||||||
|  | def _flatten_dict(d, prefix=''): | ||||||
|  |     if not isinstance(d, dict): | ||||||
|  |         # print('END:', prefix, d) | ||||||
|  |         yield prefix, d | ||||||
|  |         return | ||||||
|  |     if prefix: | ||||||
|  |         prefix = prefix + '.' | ||||||
|  |     for k, v in d.items(): | ||||||
|  |         # print(k, v) | ||||||
|  |         res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k))) | ||||||
|  |         # print('RES:', res) | ||||||
|  |         yield from res | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def unflatten_dict(d): | ||||||
|  |     out = {} | ||||||
|  |     for k, v in d.items(): | ||||||
|  |         target = out | ||||||
|  |         if not isinstance(k, str): | ||||||
|  |             target[k] = v | ||||||
|  |             continue | ||||||
|  |         tokens = k.split('.') | ||||||
|  |         if len(tokens) < 2: | ||||||
|  |             target[k] = v | ||||||
|  |             continue | ||||||
|  |         for token in tokens[:-1]: | ||||||
|  |             if token not in target: | ||||||
|  |                 target[token] = {} | ||||||
|  |             target = target[token] | ||||||
|  |         target[tokens[-1]] = v | ||||||
|  |     return out | ||||||
|   | |||||||
| @@ -66,8 +66,8 @@ class TestAnalysis(TestCase): | |||||||
|         env = self.env |         env = self.env | ||||||
|         df = analysis.read_sql(env._history.db_path) |         df = analysis.read_sql(env._history.db_path) | ||||||
|         res = analysis.get_count(df, 'SEED', 'id') |         res = analysis.get_count(df, 'SEED', 'id') | ||||||
|         assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1 |         assert res['SEED'][self.env['SEED']].iloc[0] == 1 | ||||||
|         assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1 |         assert res['SEED'][self.env['SEED']].iloc[-1] == 1 | ||||||
|         assert res['id']['odd'].iloc[0] == 2 |         assert res['id']['odd'].iloc[0] == 2 | ||||||
|         assert res['id']['even'].iloc[0] == 0 |         assert res['id']['even'].iloc[0] == 0 | ||||||
|         assert res['id']['odd'].iloc[-1] == 1 |         assert res['id']['odd'].iloc[-1] == 1 | ||||||
| @@ -75,7 +75,7 @@ class TestAnalysis(TestCase): | |||||||
|  |  | ||||||
|     def test_value(self): |     def test_value(self): | ||||||
|         env = self.env |         env = self.env | ||||||
|         df = analysis.read_sql(env._history._db) |         df = analysis.read_sql(env._history.db_path) | ||||||
|         res_sum = analysis.get_value(df, 'count') |         res_sum = analysis.get_value(df, 'count') | ||||||
|  |  | ||||||
|         assert res_sum['count'].iloc[0] == 2 |         assert res_sum['count'].iloc[0] == 2 | ||||||
| @@ -86,4 +86,4 @@ class TestAnalysis(TestCase): | |||||||
|  |  | ||||||
|         res_total = analysis.get_value(df) |         res_total = analysis.get_value(df) | ||||||
|  |  | ||||||
|         res_total['SEED'].iloc[0] == 'seedanalysis_trial_0' |         res_total['SEED'].iloc[0] == self.env['SEED'] | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ def make_example_test(path, config): | |||||||
|                 try: |                 try: | ||||||
|                     n = config['network_params']['n'] |                     n = config['network_params']['n'] | ||||||
|                     assert len(list(env.network_agents)) == n |                     assert len(list(env.network_agents)) == n | ||||||
|                     assert env.now > 2  # It has run |                     assert env.now > 0  # It has run | ||||||
|                     assert env.now <= config['max_time']  # But not further than allowed |                     assert env.now <= config['max_time']  # But not further than allowed | ||||||
|                 except KeyError: |                 except KeyError: | ||||||
|                     pass |                     pass | ||||||
|   | |||||||
| @@ -6,26 +6,32 @@ from time import time | |||||||
|  |  | ||||||
| from unittest import TestCase | from unittest import TestCase | ||||||
| from soil import exporters | from soil import exporters | ||||||
| from soil.utils import safe_open |  | ||||||
| from soil import simulation | from soil import simulation | ||||||
|  |  | ||||||
|  | from soil.stats import distribution | ||||||
|  |  | ||||||
| class Dummy(exporters.Exporter): | class Dummy(exporters.Exporter): | ||||||
|     started = False |     started = False | ||||||
|     trials = 0 |     trials = 0 | ||||||
|     ended = False |     ended = False | ||||||
|     total_time = 0 |     total_time = 0 | ||||||
|  |     called_start = 0 | ||||||
|  |     called_trial = 0 | ||||||
|  |     called_end = 0 | ||||||
|  |  | ||||||
|     def start(self): |     def start(self): | ||||||
|  |         self.__class__.called_start += 1 | ||||||
|         self.__class__.started = True |         self.__class__.started = True | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         assert env |         assert env | ||||||
|         self.__class__.trials += 1 |         self.__class__.trials += 1 | ||||||
|         self.__class__.total_time += env.now |         self.__class__.total_time += env.now | ||||||
|  |         self.__class__.called_trial += 1 | ||||||
|  |  | ||||||
|     def end(self): |     def end(self, stats): | ||||||
|         self.__class__.ended = True |         self.__class__.ended = True | ||||||
|  |         self.__class__.called_end += 1 | ||||||
|  |  | ||||||
|  |  | ||||||
| class Exporters(TestCase): | class Exporters(TestCase): | ||||||
| @@ -39,32 +45,17 @@ class Exporters(TestCase): | |||||||
|             'environment_params': {} |             'environment_params': {} | ||||||
|         } |         } | ||||||
|         s = simulation.from_config(config) |         s = simulation.from_config(config) | ||||||
|         s.run_simulation(exporters=[Dummy], dry_run=True) |         for env in s.run_simulation(exporters=[Dummy], dry_run=True): | ||||||
|  |             assert env.now <= 2 | ||||||
|  |  | ||||||
|         assert Dummy.started |         assert Dummy.started | ||||||
|         assert Dummy.ended |         assert Dummy.ended | ||||||
|  |         assert Dummy.called_start == 1 | ||||||
|  |         assert Dummy.called_end == 1 | ||||||
|  |         assert Dummy.called_trial == 5 | ||||||
|         assert Dummy.trials == 5 |         assert Dummy.trials == 5 | ||||||
|         assert Dummy.total_time == 2*5 |         assert Dummy.total_time == 2*5 | ||||||
|  |  | ||||||
|     def test_distribution(self): |  | ||||||
|         '''The distribution exporter should write the number of agents in each state''' |  | ||||||
|         config = { |  | ||||||
|             'name': 'exporter_sim', |  | ||||||
|             'network_params': { |  | ||||||
|                 'generator': 'complete_graph', |  | ||||||
|                 'n': 4 |  | ||||||
|             }, |  | ||||||
|             'agent_type': 'CounterModel', |  | ||||||
|             'max_time': 2, |  | ||||||
|             'num_trials': 5, |  | ||||||
|             'environment_params': {} |  | ||||||
|         } |  | ||||||
|         output = io.StringIO() |  | ||||||
|         s = simulation.from_config(config) |  | ||||||
|         s.run_simulation(exporters=[exporters.distribution], dry_run=True, exporter_params={'copy_to': output}) |  | ||||||
|         result = output.getvalue() |  | ||||||
|         assert 'count' in result |  | ||||||
|         assert 'SEED,Noneexporter_sim_trial_3,1,,1,1,1,1' in result |  | ||||||
|  |  | ||||||
|     def test_writing(self): |     def test_writing(self): | ||||||
|         '''Try to write CSV, GEXF, sqlite and YAML (without dry_run)''' |         '''Try to write CSV, GEXF, sqlite and YAML (without dry_run)''' | ||||||
|         n_trials = 5 |         n_trials = 5 | ||||||
| @@ -86,8 +77,8 @@ class Exporters(TestCase): | |||||||
|             exporters.default, |             exporters.default, | ||||||
|             exporters.csv, |             exporters.csv, | ||||||
|             exporters.gexf, |             exporters.gexf, | ||||||
|             exporters.distribution, |  | ||||||
|         ], |         ], | ||||||
|  |                                 stats=[distribution,], | ||||||
|                                 outdir=tmpdir, |                                 outdir=tmpdir, | ||||||
|                                 exporter_params={'copy_to': output}) |                                 exporter_params={'copy_to': output}) | ||||||
|         result = output.getvalue() |         result = output.getvalue() | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import shutil | |||||||
| from glob import glob | from glob import glob | ||||||
|  |  | ||||||
| from soil import history | from soil import history | ||||||
|  | from soil import utils | ||||||
|  |  | ||||||
|  |  | ||||||
| ROOT = os.path.abspath(os.path.dirname(__file__)) | ROOT = os.path.abspath(os.path.dirname(__file__)) | ||||||
| @@ -154,3 +155,49 @@ class TestHistory(TestCase): | |||||||
|         assert recovered |         assert recovered | ||||||
|         for i in recovered: |         for i in recovered: | ||||||
|             assert i in tuples |             assert i in tuples | ||||||
|  |  | ||||||
|  |     def test_stats(self): | ||||||
|  |         """ | ||||||
|  |         The data recovered should be equal to the one recorded. | ||||||
|  |         """ | ||||||
|  |         tuples = ( | ||||||
|  |             ('a_1', 0, 'id', 'v'), | ||||||
|  |             ('a_1', 1, 'id', 'a'), | ||||||
|  |             ('a_1', 2, 'id', 'l'), | ||||||
|  |             ('a_1', 3, 'id', 'u'), | ||||||
|  |             ('a_1', 4, 'id', 'e'), | ||||||
|  |             ('env', 1, 'prob', 1), | ||||||
|  |             ('env', 2, 'prob', 2), | ||||||
|  |             ('env', 3, 'prob', 3), | ||||||
|  |             ('a_2', 7, 'finished', True), | ||||||
|  |         ) | ||||||
|  |         stat_tuples = [ | ||||||
|  |             {'num_infected': 5, 'runtime': 0.2}, | ||||||
|  |             {'num_infected': 5, 'runtime': 0.2}, | ||||||
|  |             {'new': '40'}, | ||||||
|  |         ] | ||||||
|  |         h = history.History() | ||||||
|  |         h.save_tuples(tuples) | ||||||
|  |         for stat in stat_tuples: | ||||||
|  |             h.save_stats(stat) | ||||||
|  |         recovered = h.get_stats() | ||||||
|  |         assert recovered | ||||||
|  |         assert recovered[0]['num_infected'] == 5 | ||||||
|  |         assert recovered[1]['runtime'] == 0.2 | ||||||
|  |         assert recovered[2]['new'] == '40' | ||||||
|  |  | ||||||
|  |     def test_unflatten(self): | ||||||
|  |         ex = {'count.neighbors.3': 4, | ||||||
|  |               'count.times.2': 4, | ||||||
|  |               'count.total.4': 4, | ||||||
|  |               'mean.neighbors': 3, | ||||||
|  |               'mean.times': 2, | ||||||
|  |               'mean.total': 4, | ||||||
|  |               't_step': 2, | ||||||
|  |               'trial_id': 'exporter_sim_trial_1605817956-4475424'} | ||||||
|  |         res = utils.unflatten_dict(ex) | ||||||
|  |  | ||||||
|  |         assert 'count' in res | ||||||
|  |         assert 'mean' in res | ||||||
|  |         assert 't_step' in res | ||||||
|  |         assert 'trial_id' in res | ||||||
|   | |||||||
| @@ -343,4 +343,16 @@ class TestMain(TestCase): | |||||||
|         configs = serialization.load_file(join(EXAMPLES, 'template.yml')) |         configs = serialization.load_file(join(EXAMPLES, 'template.yml')) | ||||||
|         assert len(configs) > 0 |         assert len(configs) > 0 | ||||||
|  |  | ||||||
|  |     def test_until(self): | ||||||
|  |         config = { | ||||||
|  |             'name': 'exporter_sim', | ||||||
|  |             'network_params': {}, | ||||||
|  |             'agent_type': 'CounterModel', | ||||||
|  |             'max_time': 2, | ||||||
|  |             'num_trials': 100, | ||||||
|  |             'environment_params': {} | ||||||
|  |         } | ||||||
|  |         s = simulation.from_config(config) | ||||||
|  |         runs = list(s.run_simulation(dry_run=True)) | ||||||
|  |         over = list(x.now for x in runs if x.now>2) | ||||||
|  |         assert len(over) == 0 | ||||||
|   | |||||||
							
								
								
									
										34
									
								
								tests/test_stats.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								tests/test_stats.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | |||||||
|  | from unittest import TestCase | ||||||
|  |  | ||||||
|  | from soil import simulation, stats | ||||||
|  | from soil.utils import unflatten_dict | ||||||
|  |  | ||||||
|  | class Stats(TestCase): | ||||||
|  |  | ||||||
|  |     def test_distribution(self): | ||||||
|  |         '''The distribution exporter should write the number of agents in each state''' | ||||||
|  |         config = { | ||||||
|  |             'name': 'exporter_sim', | ||||||
|  |             'network_params': { | ||||||
|  |                 'generator': 'complete_graph', | ||||||
|  |                 'n': 4 | ||||||
|  |             }, | ||||||
|  |             'agent_type': 'CounterModel', | ||||||
|  |             'max_time': 2, | ||||||
|  |             'num_trials': 5, | ||||||
|  |             'environment_params': {} | ||||||
|  |         } | ||||||
|  |         s = simulation.from_config(config) | ||||||
|  |         for env in s.run_simulation(stats=[stats.distribution]): | ||||||
|  |             pass | ||||||
|  |             # stats_res = unflatten_dict(dict(env._history['stats', -1, None])) | ||||||
|  |         allstats = s.get_stats() | ||||||
|  |         for stat in allstats: | ||||||
|  |             assert 'count' in stat | ||||||
|  |             assert 'mean' in stat | ||||||
|  |             if 'trial_id' in stat: | ||||||
|  |                 assert stat['mean']['neighbors'] == 3 | ||||||
|  |                 assert stat['count']['total']['4'] == 4 | ||||||
|  |             else: | ||||||
|  |                 assert stat['count']['count']['neighbors']['3'] == 20 | ||||||
|  |                 assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors'] | ||||||
		Reference in New Issue
	
	Block a user