mirror of
https://github.com/gsi-upm/soil
synced 2025-11-04 17:38:16 +00:00
Compare commits
18 Commits
0.13.4-fix
...
05f7f49233
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05f7f49233 | ||
|
|
3b2c6a3db5 | ||
|
|
6118f917ee | ||
|
|
6adc8d36ba | ||
|
|
c8b8149a17 | ||
|
|
6690b6ee5f | ||
|
|
97835b3d10 | ||
|
|
b0add8552e | ||
|
|
1cf85ea450 | ||
|
|
c32e167fb8 | ||
|
|
5f68b5321d | ||
|
|
2a2843bd19 | ||
|
|
d1006bd55c | ||
|
|
9bc036d185 | ||
|
|
a3ea434f23 | ||
|
|
65f6aa72f3 | ||
|
|
09e14c6e84 | ||
|
|
8593ac999d |
@@ -1,2 +1,5 @@
|
||||
**/soil_output
|
||||
.*
|
||||
**/__pycache__
|
||||
__pycache__
|
||||
*.pyc
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
stages:
|
||||
- build
|
||||
- test
|
||||
- build
|
||||
|
||||
build:
|
||||
stage: build
|
||||
@@ -18,6 +18,8 @@ build:
|
||||
|
||||
|
||||
test:
|
||||
except:
|
||||
- tags # Avoid running tests for tags, because they are already run for the branch
|
||||
tags:
|
||||
- docker
|
||||
image: python:3.7
|
||||
|
||||
104
CHANGELOG.md
Normal file
104
CHANGELOG.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Changelog
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.15.1]
|
||||
### Added
|
||||
* read-only `History`
|
||||
### Fixed
|
||||
* Serialization problem with the `Environment` on parallel mode.
|
||||
* Analysis functions now work as they should in the tutorial
|
||||
## [0.15.0]
|
||||
### Added
|
||||
* Control logging level in CLI and simulation
|
||||
* `Stats` to calculate trial and simulation-wide statistics
|
||||
* Simulation statistics are stored in a separate table in history (see `History.get_stats` and `History.save_stats`, as well as `soil.stats`)
|
||||
* Aliased `NetworkAgent.G` to `NetworkAgent.topology`.
|
||||
### Changed
|
||||
* Templates in config files can be given as dictionaries in addition to strings
|
||||
* Samplers are used more explicitly
|
||||
* Removed nxsim dependency. We had already made a lot of changes, and nxsim has not been updated in 5 years.
|
||||
* Exporter methods renamed to `trial` and `end`. Added `start`.
|
||||
* `Distribution` exporter now a stats class
|
||||
* `global_topology` renamed to `topology`
|
||||
* Moved topology-related methods to `NetworkAgent`
|
||||
### Fixed
|
||||
* Temporary files used for history in dry_run mode are not longer left open
|
||||
|
||||
## [0.14.9]
|
||||
### Changed
|
||||
* Seed random before environment initialization
|
||||
## [0.14.8]
|
||||
### Fixed
|
||||
* Invalid directory names in Windows gsi-upm/soil#5
|
||||
## [0.14.7]
|
||||
### Changed
|
||||
* Minor change to traceback handling in async simulations
|
||||
### Fixed
|
||||
* Incomplete example in the docs (example.yml) caused an exception
|
||||
## [0.14.6]
|
||||
### Fixed
|
||||
* Bug with newer versions of networkx (0.24) where the Graph.node attribute has been removed. We have updated our calls, but the code in nxsim is not under our control, so we have pinned the networkx version until that issue is solved.
|
||||
### Changed
|
||||
* Explicit yaml.SafeLoader to avoid deprecation warnings when using yaml.load. It should not break any existing setups, but we could move to the FullLoader in the future if needed.
|
||||
|
||||
## [0.14.4]
|
||||
### Fixed
|
||||
* Bug in `agent.get_agents()` when `state_id` is passed as a string. The tests have been modified accordingly.
|
||||
## [0.14.3]
|
||||
### Fixed
|
||||
* Incompatibility with py3.3-3.6 due to ModuleNotFoundError and TypeError in DryRunner
|
||||
## [0.14.2]
|
||||
### Fixed
|
||||
* Output path for exporters is now soil_output
|
||||
### Changed
|
||||
* CSV output to stdout in dry_run mode
|
||||
## [0.14.1]
|
||||
### Changed
|
||||
* Exporter names in lower case
|
||||
* Add default exporter in runs
|
||||
## [0.14.0]
|
||||
### Added
|
||||
* Loading configuration from template definitions in the yaml, in preparation for SALib support.
|
||||
The definition of the variables and their possible values (i.e., a problem in SALib terms), as well as a sampler function, can be provided.
|
||||
Soil uses this definition and the template to generate a set of configurations.
|
||||
* Simulation group names, to link related simulations. For now, they are only used to group all simulations in the same group under the same folder.
|
||||
* Exporters unify exporting/dumping results and other files to disk. If `dry_run` is set to `True`, exporters will write to stdout instead of a file (useful for testing/debugging).
|
||||
* Distribution exporter, to write statistics about values and value_counts in every simulation. The results are dumped to two CSV files.
|
||||
|
||||
### Changed
|
||||
* `dir_path` is now the directory for resources (modules, files)
|
||||
* Environments and simulations do not export or write anything by default. That task is delegated to Exporters
|
||||
|
||||
### Removed
|
||||
* The output dir for environments and simulations (see Exporters)
|
||||
* DrawingAgent, because it wrote to disk and was not being used. We provide a partial alternative in the form of the GraphDrawing exporter. A complete alternative will be provided once the network at each state can be accessed by exporters.
|
||||
|
||||
## Fixed
|
||||
* Modules with custom agents/environments failed to load when they were run from outside the directory of the definition file. Modules are now loaded from the directory of the simulation file in addition to the working directory
|
||||
* Memory databases (in history) can now be shared between threads.
|
||||
* Testing all examples, not just subdirectories
|
||||
|
||||
## [0.13.8]
|
||||
### Changed
|
||||
* Moved TerroristNetworkModel to examples
|
||||
### Added
|
||||
* `get_agents` and `count_agents` methods now accept lists as inputs. They can be used to retrieve agents from node ids
|
||||
* `subgraph` in BaseAgent
|
||||
* `agents.select` method, to filter out agents
|
||||
* `skip_test` property in yaml definitions, to force skipping some examples
|
||||
* `agents.Geo`, with a search function based on postition
|
||||
* `BaseAgent.ego_search` to get nodes from the ego network of a node
|
||||
* `BaseAgent.degree` and `BaseAgent.betweenness`
|
||||
### Fixed
|
||||
|
||||
## [0.13.7]
|
||||
### Changed
|
||||
* History now defaults to not backing up! This makes it more intuitive to load the history for examination, at the expense of rewriting something. That should not happen because History is only created in the Environment, and that has `backup=True`.
|
||||
### Added
|
||||
* Agent names are assigned based on their agent types
|
||||
* Agent logging uses the agent name.
|
||||
* FSM agents can now return a timeout in addition to a new state. e.g. `return self.idle, self.env.timeout(2)` will execute the *different_state* in 2 *units of time* (`t_step=now+2`).
|
||||
* Example of using timeouts in FSM (custom_timeouts)
|
||||
* `network_agents` entries may include an `ids` entry. If set, it should be a list of node ids that should be assigned that agent type. This complements the previous behavior of setting agent type with `weights`.
|
||||
@@ -1,4 +1,7 @@
|
||||
include requirements.txt
|
||||
include test-requirements.txt
|
||||
include README.rst
|
||||
graft soil
|
||||
graft soil
|
||||
global-exclude __pycache__
|
||||
global-exclude soil_output
|
||||
global-exclude *.py[co]
|
||||
|
||||
7
Makefile
7
Makefile
@@ -1,4 +1,7 @@
|
||||
test:
|
||||
quick-test:
|
||||
docker-compose exec dev python -m pytest -s -v
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
docker run -t -v $$PWD:/usr/src/app -w /usr/src/app python:3.7 python setup.py test
|
||||
|
||||
.PHONY: test
|
||||
|
||||
@@ -30,5 +30,5 @@ If you use Soil in your research, don't forget to cite this paper:
|
||||
|
||||
@Copyright GSI - Universidad Politécnica de Madrid 2017
|
||||
|
||||
[](https://www.gsi.dit.upm.es)
|
||||
[](https://www.gsi.upm.es)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = []
|
||||
extensions = ['IPython.sphinxext.ipython_console_highlighting']
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
@@ -69,7 +69,7 @@ language = None
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This patterns also effect to html_static_path and html_extra_path
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints']
|
||||
|
||||
# The name of the Pygments (syntax highlighting) style to use.
|
||||
pygments_style = 'sphinx'
|
||||
|
||||
@@ -8,32 +8,8 @@ The advantage of a configuration file is that it is a clean declarative descript
|
||||
Simulation configuration files can be formatted in ``json`` or ``yaml`` and they define all the parameters of a simulation.
|
||||
Here's an example (``example.yml``).
|
||||
|
||||
.. code:: yaml
|
||||
|
||||
---
|
||||
name: MyExampleSimulation
|
||||
max_time: 50
|
||||
num_trials: 3
|
||||
interval: 2
|
||||
network_params:
|
||||
generator: barabasi_albert_graph
|
||||
n: 100
|
||||
m: 2
|
||||
network_agents:
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: content
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: discontent
|
||||
- agent_type: SISaModel
|
||||
weight: 8
|
||||
state:
|
||||
id: neutral
|
||||
environment_params:
|
||||
prob_infect: 0.075
|
||||
.. literalinclude:: example.yml
|
||||
:language: yaml
|
||||
|
||||
|
||||
This example configuration will run three trials (``num_trials``) of a simulation containing a randomly generated network (``network_params``).
|
||||
@@ -242,3 +218,24 @@ These agents are programmed in much the same way as network agents, the only dif
|
||||
|
||||
You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance.
|
||||
They are also useful to add behavior that has little to do with the network and the interactions within that network.
|
||||
|
||||
Templating
|
||||
==========
|
||||
|
||||
Sometimes, it is useful to parameterize a simulation and run it over a range of values in order to compare each run and measure the effect of those parameters in the simulation.
|
||||
For instance, you may want to run a simulation with different agent distributions.
|
||||
|
||||
This can be done in Soil using **templates**.
|
||||
A template is a configuration where some of the values are specified with a variable.
|
||||
e.g., ``weight: "{{ var1 }}"`` instead of ``weight: 1``.
|
||||
There are two types of variables, depending on how their values are decided:
|
||||
|
||||
* Fixed. A list of values is provided, and a new simulation is run for each possible value. If more than a variable is given, a new simulation will be run per combination of values.
|
||||
* Bounded/Sampled. The bounds of the variable are provided, along with a sampler method, which will be used to compute all the configuration combinations.
|
||||
|
||||
When fixed and bounded variables are mixed, Soil generates a new configuration per combination of fixed values and bounded values.
|
||||
|
||||
Here is an example with a single fixed variable and two bounded variable:
|
||||
|
||||
.. literalinclude:: ../examples/template.yml
|
||||
:language: yaml
|
||||
|
||||
35
docs/example.yml
Normal file
35
docs/example.yml
Normal file
@@ -0,0 +1,35 @@
|
||||
---
|
||||
name: MyExampleSimulation
|
||||
max_time: 50
|
||||
num_trials: 3
|
||||
interval: 2
|
||||
network_params:
|
||||
generator: barabasi_albert_graph
|
||||
n: 100
|
||||
m: 2
|
||||
network_agents:
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: content
|
||||
- agent_type: SISaModel
|
||||
weight: 1
|
||||
state:
|
||||
id: discontent
|
||||
- agent_type: SISaModel
|
||||
weight: 8
|
||||
state:
|
||||
id: neutral
|
||||
environment_params:
|
||||
prob_infect: 0.075
|
||||
neutral_discontent_spon_prob: 0.1
|
||||
neutral_discontent_infected_prob: 0.3
|
||||
neutral_content_spon_prob: 0.3
|
||||
neutral_content_infected_prob: 0.4
|
||||
discontent_neutral: 0.5
|
||||
discontent_content: 0.5
|
||||
variance_d_c: 0.2
|
||||
content_discontent: 0.2
|
||||
variance_c_d: 0.2
|
||||
content_neutral: 0.2
|
||||
standard_variance: 1
|
||||
@@ -14,11 +14,11 @@ Now test that it worked by running the command line tool
|
||||
|
||||
soil --help
|
||||
|
||||
Or using soil programmatically:
|
||||
Or, if you're using using soil programmatically:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import soil
|
||||
print(soil.__version__)
|
||||
|
||||
The latest version can be installed through `GitLab <https://lab.cluster.gsi.dit.upm.es/soil/soil.git>`_.
|
||||
The latest version can be installed through `GitLab <https://lab.gsi.upm.es/soil/soil.git>`_ or `GitHub <https://github.com/gsi-upm/soil>`_.
|
||||
|
||||
@@ -26,7 +26,7 @@ But before that, let's import the soil module and networkx.
|
||||
%autoreload 2
|
||||
|
||||
%pylab inline
|
||||
# To display plots in the notebooed_
|
||||
# To display plots in the notebook_
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
@@ -323,7 +323,7 @@ Let's run our simulation:
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
soil.simulation.run_from_config(config, dump=False)
|
||||
soil.simulation.run_from_config(config)
|
||||
|
||||
|
||||
.. parsed-literal::
|
||||
@@ -2531,7 +2531,7 @@ Dealing with bigger data
|
||||
|
||||
.. parsed-literal::
|
||||
|
||||
267M ../rabbits/soil_output/rabbits_example/
|
||||
267M ../rabbits/soil_output/rabbits_example/
|
||||
|
||||
|
||||
If we tried to load the entire history, we would probably run out of
|
||||
|
||||
@@ -500,7 +500,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
"version": "3.8.5"
|
||||
},
|
||||
"toc": {
|
||||
"colors": {
|
||||
|
||||
@@ -80800,7 +80800,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.5"
|
||||
"version": "3.8.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
---
|
||||
name: simple
|
||||
group: tests
|
||||
dir_path: "/tmp/"
|
||||
num_trials: 3
|
||||
dry_run: True
|
||||
max_time: 100
|
||||
interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
dump: false
|
||||
network_params:
|
||||
generator: complete_graph
|
||||
n: 10
|
||||
|
||||
16
examples/custom_generator/custom_generator.yml
Normal file
16
examples/custom_generator/custom_generator.yml
Normal file
@@ -0,0 +1,16 @@
|
||||
---
|
||||
name: custom-generator
|
||||
description: Using a custom generator for the network
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
interval: 1
|
||||
network_params:
|
||||
generator: mymodule.mygenerator
|
||||
# These are custom parameters
|
||||
n: 10
|
||||
n_edges: 5
|
||||
network_agents:
|
||||
- agent_type: CounterModel
|
||||
weight: 1
|
||||
state:
|
||||
id: 0
|
||||
27
examples/custom_generator/mymodule.py
Normal file
27
examples/custom_generator/mymodule.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from networkx import Graph
|
||||
import networkx as nx
|
||||
from random import choice
|
||||
|
||||
def mygenerator(n=5, n_edges=5):
|
||||
'''
|
||||
Just a simple generator that creates a network with n nodes and
|
||||
n_edges edges. Edges are assigned randomly, only avoiding self loops.
|
||||
'''
|
||||
G = nx.Graph()
|
||||
|
||||
for i in range(n):
|
||||
G.add_node(i)
|
||||
|
||||
for i in range(n_edges):
|
||||
nodes = list(G.nodes)
|
||||
n_in = choice(nodes)
|
||||
nodes.remove(n_in) # Avoid loops
|
||||
n_out = choice(nodes)
|
||||
G.add_edge(n_in, n_out)
|
||||
return G
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
35
examples/custom_timeouts/custom_timeouts.py
Normal file
35
examples/custom_timeouts/custom_timeouts.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from soil.agents import FSM, state, default_state
|
||||
|
||||
|
||||
class Fibonacci(FSM):
|
||||
'''Agent that only executes in t_steps that are Fibonacci numbers'''
|
||||
|
||||
defaults = {
|
||||
'prev': 1
|
||||
}
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def counting(self):
|
||||
self.log('Stopping at {}'.format(self.now))
|
||||
prev, self['prev'] = self['prev'], max([self.now, self['prev']])
|
||||
return None, self.env.timeout(prev)
|
||||
|
||||
class Odds(FSM):
|
||||
'''Agent that only executes in odd t_steps'''
|
||||
@default_state
|
||||
@state
|
||||
def odds(self):
|
||||
self.log('Stopping at {}'.format(self.now))
|
||||
return None, self.env.timeout(1+self.now%2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
from soil import Simulation
|
||||
s = Simulation(network_agents=[{'ids': [0], 'agent_type': Fibonacci},
|
||||
{'ids': [1], 'agent_type': Odds}],
|
||||
network_params={"generator": "complete_graph", "n": 2},
|
||||
max_time=100,
|
||||
)
|
||||
s.run(dry_run=True)
|
||||
@@ -6,7 +6,7 @@ environment_params:
|
||||
prob_neighbor_spread: 0.0
|
||||
prob_tv_spread: 0.01
|
||||
interval: 1
|
||||
max_time: 30
|
||||
max_time: 300
|
||||
name: Sim_all_dumb
|
||||
network_agents:
|
||||
- agent_type: DumbViewer
|
||||
@@ -30,7 +30,7 @@ environment_params:
|
||||
prob_neighbor_spread: 0.0
|
||||
prob_tv_spread: 0.01
|
||||
interval: 1
|
||||
max_time: 30
|
||||
max_time: 300
|
||||
name: Sim_half_herd
|
||||
network_agents:
|
||||
- agent_type: DumbViewer
|
||||
@@ -62,7 +62,7 @@ environment_params:
|
||||
prob_neighbor_spread: 0.0
|
||||
prob_tv_spread: 0.01
|
||||
interval: 1
|
||||
max_time: 30
|
||||
max_time: 300
|
||||
name: Sim_all_herd
|
||||
network_agents:
|
||||
- agent_type: HerdViewer
|
||||
@@ -89,7 +89,7 @@ environment_params:
|
||||
prob_tv_spread: 0.01
|
||||
prob_neighbor_cure: 0.1
|
||||
interval: 1
|
||||
max_time: 30
|
||||
max_time: 300
|
||||
name: Sim_wise_herd
|
||||
network_agents:
|
||||
- agent_type: HerdViewer
|
||||
@@ -115,7 +115,7 @@ environment_params:
|
||||
prob_tv_spread: 0.01
|
||||
prob_neighbor_cure: 0.1
|
||||
interval: 1
|
||||
max_time: 30
|
||||
max_time: 300
|
||||
name: Sim_all_wise
|
||||
network_agents:
|
||||
- agent_type: WiseViewer
|
||||
|
||||
1
examples/programmatic/.gitignore
vendored
Normal file
1
examples/programmatic/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Programmatic*
|
||||
38
examples/programmatic/programmatic.py
Normal file
38
examples/programmatic/programmatic.py
Normal file
@@ -0,0 +1,38 @@
|
||||
'''
|
||||
Example of a fully programmatic simulation, without definition files.
|
||||
'''
|
||||
from soil import Simulation, agents
|
||||
from networkx import Graph
|
||||
import logging
|
||||
|
||||
|
||||
def mygenerator():
|
||||
# Add only a node
|
||||
G = Graph()
|
||||
G.add_node(1)
|
||||
return G
|
||||
|
||||
|
||||
class MyAgent(agents.FSM):
|
||||
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def neutral(self):
|
||||
self.info('I am running')
|
||||
|
||||
|
||||
s = Simulation(name='Programmatic',
|
||||
network_params={'generator': mygenerator},
|
||||
num_trials=1,
|
||||
max_time=100,
|
||||
agent_type=MyAgent,
|
||||
dry_run=True)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
envs = s.run()
|
||||
|
||||
s.dump_yaml()
|
||||
|
||||
for env in envs:
|
||||
env.dump_csv()
|
||||
@@ -59,7 +59,7 @@ class Patron(FSM):
|
||||
2) Look for a bar where the agent and other agents in the same group can get in.
|
||||
3) While in the bar, patrons only drink, until they get drunk and taken home.
|
||||
'''
|
||||
level = logging.INFO
|
||||
level = logging.DEBUG
|
||||
|
||||
defaults = {
|
||||
'pub': None,
|
||||
@@ -113,7 +113,8 @@ class Patron(FSM):
|
||||
@state
|
||||
def at_home(self):
|
||||
'''The end'''
|
||||
self.debug('Life sucks. I\'m home!')
|
||||
others = self.get_agents(state_id=Patron.at_home.id, limit_neighbors=True)
|
||||
self.debug('I\'m home. Just like {} of my friends'.format(len(others)))
|
||||
|
||||
def drink(self):
|
||||
self['pints'] += 1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from soil.agents import FSM, state, default_state, BaseAgent
|
||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||
from enum import Enum
|
||||
from random import random, choice
|
||||
from itertools import islice
|
||||
@@ -80,7 +80,7 @@ class RabbitModel(FSM):
|
||||
self.env.add_edge(self['mate'], child.id)
|
||||
# self.add_edge()
|
||||
self.debug('A BABY IS COMING TO LIFE')
|
||||
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.global_topology.number_of_nodes())+1
|
||||
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1
|
||||
self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
||||
self['offspring'] += 1
|
||||
self.env.get_agent(self['mate'])['offspring'] += 1
|
||||
@@ -97,12 +97,14 @@ class RabbitModel(FSM):
|
||||
return
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
class RandomAccident(NetworkAgent):
|
||||
|
||||
level = logging.DEBUG
|
||||
|
||||
def step(self):
|
||||
rabbits_total = self.global_topology.number_of_nodes()
|
||||
rabbits_total = self.topology.number_of_nodes()
|
||||
if 'rabbits_alive' not in self.env:
|
||||
self.env['rabbits_alive'] = 0
|
||||
rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
|
||||
prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||
@@ -116,5 +118,5 @@ class RandomAccident(BaseAgent):
|
||||
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
||||
i.set_state(i.dead)
|
||||
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
||||
if self.count_agents(state_id=RabbitModel.dead.id) == self.global_topology.number_of_nodes():
|
||||
if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes():
|
||||
self.die()
|
||||
|
||||
30
examples/template.yml
Normal file
30
examples/template.yml
Normal file
@@ -0,0 +1,30 @@
|
||||
---
|
||||
sampler:
|
||||
method: "SALib.sample.morris.sample"
|
||||
N: 10
|
||||
template:
|
||||
group: simple
|
||||
num_trials: 1
|
||||
interval: 1
|
||||
max_time: 2
|
||||
seed: "CompleteSeed!"
|
||||
dump: false
|
||||
network_params:
|
||||
generator: complete_graph
|
||||
n: 10
|
||||
network_agents:
|
||||
- agent_type: CounterModel
|
||||
weight: "{{ x1 }}"
|
||||
state:
|
||||
id: 0
|
||||
- agent_type: AggregatedCounter
|
||||
weight: "{{ 1 - x1 }}"
|
||||
environment_params:
|
||||
name: "{{ x3 }}"
|
||||
skip_test: true
|
||||
vars:
|
||||
bounds:
|
||||
x1: [0, 1]
|
||||
x2: [1, 2]
|
||||
fixed:
|
||||
x3: ["a", "b", "c"]
|
||||
208
examples/terrorism/TerroristNetworkModel.py
Normal file
208
examples/terrorism/TerroristNetworkModel.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import random
|
||||
import networkx as nx
|
||||
from soil.agents import Geo, NetworkAgent, FSM, state, default_state
|
||||
from soil import Environment
|
||||
|
||||
|
||||
class TerroristSpreadModel(FSM, Geo):
|
||||
"""
|
||||
Settings:
|
||||
information_spread_intensity
|
||||
|
||||
terrorist_additional_influence
|
||||
|
||||
min_vulnerability (optional else zero)
|
||||
|
||||
max_vulnerability
|
||||
|
||||
prob_interaction
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
|
||||
self.information_spread_intensity = environment.environment_params['information_spread_intensity']
|
||||
self.terrorist_additional_influence = environment.environment_params['terrorist_additional_influence']
|
||||
self.prob_interaction = environment.environment_params['prob_interaction']
|
||||
|
||||
if self['id'] == self.civilian.id: # Civilian
|
||||
self.mean_belief = random.uniform(0.00, 0.5)
|
||||
elif self['id'] == self.terrorist.id: # Terrorist
|
||||
self.mean_belief = random.uniform(0.8, 1.00)
|
||||
elif self['id'] == self.leader.id: # Leader
|
||||
self.mean_belief = 1.00
|
||||
else:
|
||||
raise Exception('Invalid state id: {}'.format(self['id']))
|
||||
|
||||
if 'min_vulnerability' in environment.environment_params:
|
||||
self.vulnerability = random.uniform( environment.environment_params['min_vulnerability'], environment.environment_params['max_vulnerability'] )
|
||||
else :
|
||||
self.vulnerability = random.uniform( 0, environment.environment_params['max_vulnerability'] )
|
||||
|
||||
|
||||
@state
|
||||
def civilian(self):
|
||||
neighbours = list(self.get_neighboring_agents(agent_type=TerroristSpreadModel))
|
||||
if len(neighbours) > 0:
|
||||
# Only interact with some of the neighbors
|
||||
interactions = list(n for n in neighbours if random.random() <= self.prob_interaction)
|
||||
influence = sum( self.degree(i) for i in interactions )
|
||||
mean_belief = sum( i.mean_belief * self.degree(i) / influence for i in interactions )
|
||||
mean_belief = mean_belief * self.information_spread_intensity + self.mean_belief * ( 1 - self.information_spread_intensity )
|
||||
self.mean_belief = mean_belief * self.vulnerability + self.mean_belief * ( 1 - self.vulnerability )
|
||||
|
||||
if self.mean_belief >= 0.8:
|
||||
return self.terrorist
|
||||
|
||||
@state
|
||||
def leader(self):
|
||||
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
||||
for neighbour in self.get_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]):
|
||||
if self.betweenness(neighbour) > self.betweenness(self):
|
||||
return self.terrorist
|
||||
|
||||
@state
|
||||
def terrorist(self):
|
||||
neighbours = self.get_agents(state_id=[self.terrorist.id, self.leader.id],
|
||||
agent_type=TerroristSpreadModel,
|
||||
limit_neighbors=True)
|
||||
if len(neighbours) > 0:
|
||||
influence = sum( self.degree(n) for n in neighbours )
|
||||
mean_belief = sum( n.mean_belief * self.degree(n) / influence for n in neighbours )
|
||||
mean_belief = mean_belief * self.vulnerability + self.mean_belief * ( 1 - self.vulnerability )
|
||||
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
||||
|
||||
# Check if there are any leaders in the group
|
||||
leaders = list(filter(lambda x: x.state.id == self.leader.id, neighbours))
|
||||
if not leaders:
|
||||
# Check if this is the potential leader
|
||||
# Stop once it's found. Otherwise, set self as leader
|
||||
for neighbour in neighbours:
|
||||
if self.betweenness(self) < self.betweenness(neighbour):
|
||||
return
|
||||
return self.leader
|
||||
|
||||
|
||||
class TrainingAreaModel(FSM, Geo):
|
||||
"""
|
||||
Settings:
|
||||
training_influence
|
||||
|
||||
min_vulnerability
|
||||
|
||||
Requires TerroristSpreadModel.
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
self.training_influence = environment.environment_params['training_influence']
|
||||
if 'min_vulnerability' in environment.environment_params:
|
||||
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
||||
else: self.min_vulnerability = 0
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def terrorist(self):
|
||||
for neighbour in self.get_neighboring_agents(agent_type=TerroristSpreadModel):
|
||||
if neighbour.vulnerability > self.min_vulnerability:
|
||||
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.training_influence )
|
||||
|
||||
|
||||
class HavenModel(FSM, Geo):
|
||||
"""
|
||||
Settings:
|
||||
haven_influence
|
||||
|
||||
min_vulnerability
|
||||
|
||||
max_vulnerability
|
||||
|
||||
Requires TerroristSpreadModel.
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
self.haven_influence = environment.environment_params['haven_influence']
|
||||
if 'min_vulnerability' in environment.environment_params:
|
||||
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
||||
else: self.min_vulnerability = 0
|
||||
self.max_vulnerability = environment.environment_params['max_vulnerability']
|
||||
|
||||
def get_occupants(self, **kwargs):
|
||||
return self.get_neighboring_agents(agent_type=TerroristSpreadModel, **kwargs)
|
||||
|
||||
@state
|
||||
def civilian(self):
|
||||
civilians = self.get_occupants(state_id=self.civilian.id)
|
||||
if not civilians:
|
||||
return self.terrorist
|
||||
|
||||
for neighbour in self.get_occupants():
|
||||
if neighbour.vulnerability > self.min_vulnerability:
|
||||
neighbour.vulnerability = neighbour.vulnerability * ( 1 - self.haven_influence )
|
||||
return self.civilian
|
||||
|
||||
@state
|
||||
def terrorist(self):
|
||||
for neighbour in self.get_occupants():
|
||||
if neighbour.vulnerability < self.max_vulnerability:
|
||||
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.haven_influence )
|
||||
return self.terrorist
|
||||
|
||||
|
||||
class TerroristNetworkModel(TerroristSpreadModel):
|
||||
"""
|
||||
Settings:
|
||||
sphere_influence
|
||||
|
||||
vision_range
|
||||
|
||||
weight_social_distance
|
||||
|
||||
weight_link_distance
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
|
||||
self.vision_range = environment.environment_params['vision_range']
|
||||
self.sphere_influence = environment.environment_params['sphere_influence']
|
||||
self.weight_social_distance = environment.environment_params['weight_social_distance']
|
||||
self.weight_link_distance = environment.environment_params['weight_link_distance']
|
||||
|
||||
@state
|
||||
def terrorist(self):
|
||||
self.update_relationships()
|
||||
return super().terrorist()
|
||||
|
||||
@state
|
||||
def leader(self):
|
||||
self.update_relationships()
|
||||
return super().leader()
|
||||
|
||||
def update_relationships(self):
|
||||
if self.count_neighboring_agents(state_id=self.civilian.id) == 0:
|
||||
close_ups = set(self.geo_search(radius=self.vision_range, agent_type=TerroristNetworkModel))
|
||||
step_neighbours = set(self.ego_search(self.sphere_influence, agent_type=TerroristNetworkModel, center=False))
|
||||
neighbours = set(agent.id for agent in self.get_neighboring_agents(agent_type=TerroristNetworkModel))
|
||||
search = (close_ups | step_neighbours) - neighbours
|
||||
for agent in self.get_agents(search):
|
||||
social_distance = 1 / self.shortest_path_length(agent.id)
|
||||
spatial_proximity = ( 1 - self.get_distance(agent.id) )
|
||||
prob_new_interaction = self.weight_social_distance * social_distance + self.weight_link_distance * spatial_proximity
|
||||
if agent['id'] == agent.civilian.id and random.random() < prob_new_interaction:
|
||||
self.add_edge(agent)
|
||||
break
|
||||
|
||||
def get_distance(self, target):
|
||||
source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id]
|
||||
target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target]
|
||||
dx = abs( source_x - target_x )
|
||||
dy = abs( source_y - target_y )
|
||||
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
||||
|
||||
def shortest_path_length(self, target):
|
||||
try:
|
||||
return nx.shortest_path_length(self.topology, self.id, target)
|
||||
except nx.NetworkXNoPath:
|
||||
return float('inf')
|
||||
@@ -60,3 +60,4 @@ visualization_params:
|
||||
background_image: 'map_4800x2860.jpg'
|
||||
background_opacity: '0.9'
|
||||
background_filter_color: 'blue'
|
||||
skip_test: true # This simulation takes too long for automated tests.
|
||||
File diff suppressed because one or more lines are too long
@@ -1,7 +1,9 @@
|
||||
nxsim
|
||||
simpy
|
||||
networkx>=2.0
|
||||
simpy>=4.0
|
||||
networkx>=2.5
|
||||
numpy
|
||||
matplotlib
|
||||
pyyaml
|
||||
pandas
|
||||
pyyaml>=5.1
|
||||
pandas>=0.23
|
||||
scipy>=1.3
|
||||
SALib>=1.3
|
||||
Jinja2
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.13.4
|
||||
0.15.1
|
||||
@@ -14,14 +14,15 @@ except NameError:
|
||||
from . import agents
|
||||
from .simulation import *
|
||||
from .environment import Environment
|
||||
from . import utils
|
||||
from .history import History
|
||||
from . import serialization
|
||||
from . import analysis
|
||||
from .utils import logger
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
from . import simulation
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info('Running SOIL version: {}'.format(__version__))
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
||||
@@ -39,12 +40,17 @@ def main():
|
||||
help='Dump GEXF graph. Defaults to false.')
|
||||
parser.add_argument('--csv', action='store_true',
|
||||
help='Dump history in CSV format. Defaults to false.')
|
||||
parser.add_argument('--level', type=str,
|
||||
help='Logging level')
|
||||
parser.add_argument('--output', '-o', type=str, default="soil_output",
|
||||
help='folder to write results to. It defaults to the current directory.')
|
||||
parser.add_argument('--synchronous', action='store_true',
|
||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||
parser.add_argument('-e', '--exporter', action='append',
|
||||
help='Export environment and/or simulations using this exporter')
|
||||
|
||||
args = parser.parse_args()
|
||||
logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper()))
|
||||
|
||||
if os.getcwd() not in sys.path:
|
||||
sys.path.append(os.getcwd())
|
||||
@@ -54,17 +60,20 @@ def main():
|
||||
logging.info('Loading config file: {}'.format(args.file))
|
||||
|
||||
try:
|
||||
dump = []
|
||||
if not args.dry_run:
|
||||
if args.csv:
|
||||
dump.append('csv')
|
||||
if args.graph:
|
||||
dump.append('gexf')
|
||||
exporters = list(args.exporter or ['default', ])
|
||||
if args.csv:
|
||||
exporters.append('csv')
|
||||
if args.graph:
|
||||
exporters.append('gexf')
|
||||
exp_params = {}
|
||||
if args.dry_run:
|
||||
exp_params['copy_to'] = sys.stdout
|
||||
simulation.run_from_config(args.file,
|
||||
dry_run=args.dry_run,
|
||||
dump=dump,
|
||||
exporters=exporters,
|
||||
parallel=(not args.synchronous),
|
||||
results_dir=args.output)
|
||||
outdir=args.output,
|
||||
exporter_params=exp_params)
|
||||
except Exception:
|
||||
if args.pdb:
|
||||
pdb.post_mortem()
|
||||
|
||||
@@ -9,7 +9,7 @@ class BassModel(BaseAgent):
|
||||
imitation_prob
|
||||
"""
|
||||
|
||||
def __init__(self, environment, agent_id, state):
|
||||
def __init__(self, environment, agent_id, state, **kwargs):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
env_params = environment.environment_params
|
||||
self.state['sentimentCorrelation'] = 0
|
||||
@@ -19,7 +19,7 @@ class BassModel(BaseAgent):
|
||||
|
||||
def behaviour(self):
|
||||
# Outside effects
|
||||
if random.random() < self.state_params['innovation_prob']:
|
||||
if random.random() < self['innovation_prob']:
|
||||
if self.state['id'] == 0:
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
@@ -32,7 +32,7 @@ class BassModel(BaseAgent):
|
||||
if self.state['id'] == 0:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
num_neighbors_aware = len(aware_neighbors)
|
||||
if random.random() < (self.state_params['imitation_prob']*num_neighbors_aware):
|
||||
if random.random() < (self['imitation_prob']*num_neighbors_aware):
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from . import BaseAgent
|
||||
from . import NetworkAgent
|
||||
|
||||
|
||||
class CounterModel(BaseAgent):
|
||||
class CounterModel(NetworkAgent):
|
||||
"""
|
||||
Dummy behaviour. It counts the number of nodes in the network and neighbors
|
||||
in each step and adds it to its state.
|
||||
@@ -9,24 +9,30 @@ class CounterModel(BaseAgent):
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
total = len(list(self.get_all_agents()))
|
||||
total = len(list(self.get_agents()))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = neighbors
|
||||
self['total'] = total
|
||||
|
||||
|
||||
class AggregatedCounter(BaseAgent):
|
||||
class AggregatedCounter(NetworkAgent):
|
||||
"""
|
||||
Dummy behaviour. It counts the number of nodes in the network and neighbors
|
||||
in each step and adds it to its state.
|
||||
"""
|
||||
|
||||
defaults = {
|
||||
'times': 0,
|
||||
'neighbors': 0,
|
||||
'total': 0
|
||||
}
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
total = len(list(self.get_all_agents()))
|
||||
self['times'] += 1
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = self.get('neighbors', 0) + neighbors
|
||||
self['total'] = total = self.get('total', 0) + total
|
||||
self['neighbors'] += neighbors
|
||||
total = len(list(self.get_agents()))
|
||||
self['total'] += total
|
||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from . import BaseAgent
|
||||
|
||||
import os.path
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
|
||||
|
||||
class DrawingAgent(BaseAgent):
|
||||
"""
|
||||
Agent that draws the state of the network.
|
||||
"""
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
f = plt.figure()
|
||||
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
||||
f.savefig(os.path.join(self.env.get_path(), "graph-"+str(self.env.now)+".png"))
|
||||
@@ -3,21 +3,28 @@
|
||||
# for x in range(0, settings.network_params["number_of_nodes"]):
|
||||
# sentimentCorrelationNodeArray.append({'id': x})
|
||||
# Initialize agent states. Let's assume everyone is normal.
|
||||
|
||||
|
||||
import nxsim
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from scipy.spatial import cKDTree as KDTree
|
||||
import json
|
||||
import simpy
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from .. import utils, history
|
||||
from .. import serialization, history, utils
|
||||
|
||||
|
||||
class BaseAgent(nxsim.BaseAgent):
|
||||
def as_node(agent):
|
||||
if isinstance(agent, BaseAgent):
|
||||
return agent.id
|
||||
return agent
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
"""
|
||||
A special simpy BaseAgent that keeps track of its state history.
|
||||
"""
|
||||
@@ -25,14 +32,13 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
defaults = {}
|
||||
|
||||
def __init__(self, environment, agent_id, state=None,
|
||||
name='network_process', interval=None, **state_params):
|
||||
name=None, interval=None):
|
||||
# Check for REQUIRED arguments
|
||||
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
|
||||
'Cannot be NoneType.')
|
||||
# Initialize agent parameters
|
||||
self.id = agent_id
|
||||
self.name = name
|
||||
self.state_params = state_params
|
||||
self.name = name or '{}[{}]'.format(type(self).__name__, self.id)
|
||||
|
||||
# Register agent to environment
|
||||
self.env = environment
|
||||
@@ -44,11 +50,10 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
self.state = real_state
|
||||
self.interval = interval
|
||||
|
||||
if not hasattr(self, 'level'):
|
||||
self.level = logging.DEBUG
|
||||
self.logger = logging.getLogger('{}-Agent-{}'.format(self.env.name,
|
||||
self.id))
|
||||
self.logger.setLevel(self.level)
|
||||
self.logger = logging.getLogger(self.env.name).getChild(self.name)
|
||||
|
||||
if hasattr(self, 'level'):
|
||||
self.logger.setLevel(self.level)
|
||||
|
||||
# initialize every time an instance of the agent is created
|
||||
self.action = self.env.process(self.run())
|
||||
@@ -69,14 +74,10 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
for k, v in value.items():
|
||||
self[k] = v
|
||||
|
||||
@property
|
||||
def global_topology(self):
|
||||
return self.env.G
|
||||
|
||||
@property
|
||||
def environment_params(self):
|
||||
return self.env.environment_params
|
||||
|
||||
|
||||
@environment_params.setter
|
||||
def environment_params(self, value):
|
||||
self.env.environment_params = value
|
||||
@@ -129,65 +130,17 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
def die(self, remove=False):
|
||||
self.alive = False
|
||||
if remove:
|
||||
super().die()
|
||||
self.remove_node(self.id)
|
||||
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
def count_agents(self, state_id=None, limit_neighbors=False):
|
||||
if limit_neighbors:
|
||||
agents = self.global_topology.neighbors(self.id)
|
||||
else:
|
||||
agents = self.global_topology.nodes()
|
||||
count = 0
|
||||
for agent in agents:
|
||||
if state_id and state_id != self.global_topology.node[agent]['agent']['id']:
|
||||
continue
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def count_neighboring_agents(self, state_id=None):
|
||||
return len(super().get_agents(state_id, limit_neighbors=True))
|
||||
|
||||
def get_agents(self, state_id=None, agent_type=None, limit_neighbors=False, iterator=False, **kwargs):
|
||||
agents = self.env.agents
|
||||
if limit_neighbors:
|
||||
agents = super().get_agents(state_id, limit_neighbors)
|
||||
|
||||
def matches_all(agent):
|
||||
if state_id is not None:
|
||||
if agent.state.get('id', None) != state_id:
|
||||
return False
|
||||
if agent_type is not None:
|
||||
if type(agent) != agent_type:
|
||||
return False
|
||||
state = agent.state
|
||||
for k, v in kwargs.items():
|
||||
if state.get(k, None) != v:
|
||||
return False
|
||||
return True
|
||||
|
||||
f = filter(matches_all, agents)
|
||||
if iterator:
|
||||
return f
|
||||
return list(f)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
message = message + " ".join(str(i) for i in args)
|
||||
message = "\t@{:>5}:\t{}".format(self.now, message)
|
||||
for k, v in kwargs:
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['id'] = self.id
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
return
|
||||
|
||||
def debug(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.DEBUG, **kwargs)
|
||||
|
||||
def info(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.INFO, **kwargs)
|
||||
|
||||
|
||||
def __getstate__(self):
|
||||
'''
|
||||
Serializing an agent will lose all its running information (you cannot
|
||||
@@ -208,31 +161,111 @@ class BaseAgent(nxsim.BaseAgent):
|
||||
self._state = state['_state']
|
||||
self.env = state['environment']
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
|
||||
def state(func):
|
||||
'''
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the nevironment.
|
||||
'''
|
||||
@property
|
||||
def topology(self):
|
||||
return self.env.G
|
||||
|
||||
@wraps(func)
|
||||
def func_wrapper(self):
|
||||
next_state = func(self)
|
||||
when = None
|
||||
if next_state is None:
|
||||
@property
|
||||
def G(self):
|
||||
return self.env.G
|
||||
|
||||
def count_agents(self, **kwargs):
|
||||
return len(list(self.get_agents(**kwargs)))
|
||||
|
||||
def count_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
|
||||
|
||||
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||
|
||||
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
||||
if limit_neighbors:
|
||||
agents = self.topology.neighbors(self.id)
|
||||
|
||||
agents = self.env.get_agents(agents)
|
||||
return select(agents, **kwargs)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
message = message + " ".join(str(i) for i in args)
|
||||
message = " @{:>3}: {}".format(self.now, message)
|
||||
for k, v in kwargs:
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['agent_id'] = self.id
|
||||
extra['agent_name'] = self.name
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
return self.topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
|
||||
|
||||
def remove_node(self, agent_id):
|
||||
self.topology.remove_node(agent_id)
|
||||
|
||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
||||
if self.id not in self.topology.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(self.id))
|
||||
if other not in self.topology.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
||||
|
||||
self.topology.add_edge(self.id, other, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
|
||||
|
||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||
node = as_node(node if node is not None else self)
|
||||
G = self.subgraph(**kwargs)
|
||||
return nx.ego_graph(G, node, center=center, radius=steps).nodes()
|
||||
|
||||
def degree(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now:
|
||||
self.env._degree = nx.degree_centrality(self.topology)
|
||||
self.env._last_step = self.now
|
||||
return self.env._degree[node]
|
||||
|
||||
def betweenness(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now:
|
||||
self.env._betweenness = nx.betweenness_centrality(self.topology)
|
||||
self.env._last_step = self.now
|
||||
return self.env._betweenness[node]
|
||||
|
||||
|
||||
def state(name=None):
|
||||
def decorator(func, name=None):
|
||||
'''
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the environment.
|
||||
'''
|
||||
|
||||
@wraps(func)
|
||||
def func_wrapper(self):
|
||||
next_state = func(self)
|
||||
when = None
|
||||
if next_state is None:
|
||||
return when
|
||||
try:
|
||||
next_state, when = next_state
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if next_state:
|
||||
self.set_state(next_state)
|
||||
return when
|
||||
try:
|
||||
next_state, when = next_state
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if next_state:
|
||||
self.set_state(next_state)
|
||||
return when
|
||||
|
||||
func_wrapper.id = func.__name__
|
||||
func_wrapper.is_default = False
|
||||
return func_wrapper
|
||||
func_wrapper.id = name or func.__name__
|
||||
func_wrapper.is_default = False
|
||||
return func_wrapper
|
||||
|
||||
if callable(name):
|
||||
return decorator(name)
|
||||
else:
|
||||
return partial(decorator, name=name)
|
||||
|
||||
|
||||
def default_state(func):
|
||||
@@ -263,16 +296,22 @@ class MetaFSM(type):
|
||||
cls.states = states
|
||||
|
||||
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
class FSM(NetworkAgent, metaclass=MetaFSM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FSM, self).__init__(*args, **kwargs)
|
||||
if 'id' not in self.state:
|
||||
if not self.default_state:
|
||||
raise ValueError('No default state specified for {}'.format(self.id))
|
||||
self['id'] = self.default_state.id
|
||||
self._next_change = simpy.core.Infinity
|
||||
self._next_state = self.state
|
||||
|
||||
def step(self):
|
||||
if 'id' in self.state:
|
||||
if self._next_change < self.now:
|
||||
next_state = self._next_state
|
||||
self._next_change = simpy.core.Infinity
|
||||
self['id'] = next_state
|
||||
elif 'id' in self.state:
|
||||
next_state = self['id']
|
||||
elif self.default_state:
|
||||
next_state = self.default_state.id
|
||||
@@ -280,7 +319,11 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
raise Exception('{} has no valid state id or default state'.format(self))
|
||||
if next_state not in self.states:
|
||||
raise Exception('{} is not a valid id for {}'.format(next_state, self))
|
||||
self.states[next_state](self)
|
||||
return self.states[next_state](self)
|
||||
|
||||
def next_state(self, state):
|
||||
self._next_change = self.now
|
||||
self._next_state = state
|
||||
|
||||
def set_state(self, state):
|
||||
if hasattr(state, 'id'):
|
||||
@@ -306,6 +349,9 @@ def prob(prob=1):
|
||||
return r < prob
|
||||
|
||||
|
||||
STATIC_THRESHOLD = (-1, -1)
|
||||
|
||||
|
||||
def calculate_distribution(network_agents=None,
|
||||
agent_type=None):
|
||||
'''
|
||||
@@ -337,13 +383,20 @@ def calculate_distribution(network_agents=None,
|
||||
elif agent_type:
|
||||
network_agents = [{'agent_type': agent_type}]
|
||||
else:
|
||||
return []
|
||||
raise ValueError('Specify a distribution or a default agent type')
|
||||
|
||||
# Fix missing weights and incompatible types
|
||||
for x in network_agents:
|
||||
x['weight'] = float(x.get('weight', 1))
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x.get('weight', 1) for x in network_agents)
|
||||
total = sum(x['weight'] for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
upper = acc + (v.get('weight', 1)/total)
|
||||
if 'ids' in v:
|
||||
v['threshold'] = STATIC_THRESHOLD
|
||||
continue
|
||||
upper = acc + (v['weight']/total)
|
||||
v['threshold'] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
@@ -353,7 +406,7 @@ def serialize_type(agent_type, known_modules=[], **kwargs):
|
||||
if isinstance(agent_type, str):
|
||||
return agent_type
|
||||
known_modules += ['soil.agents']
|
||||
return utils.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
||||
return serialization.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
||||
|
||||
|
||||
def serialize_distribution(network_agents, known_modules=[]):
|
||||
@@ -374,7 +427,7 @@ def deserialize_type(agent_type, known_modules=[]):
|
||||
if not isinstance(agent_type, str):
|
||||
return agent_type
|
||||
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
||||
agent_type = utils.deserializer(agent_type, known_modules=known)
|
||||
agent_type = serialization.deserializer(agent_type, known_modules=known)
|
||||
return agent_type
|
||||
|
||||
|
||||
@@ -390,7 +443,7 @@ def _validate_states(states, topology):
|
||||
states = states or []
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
assert x in topology.node
|
||||
assert x in topology.nodes
|
||||
else:
|
||||
assert len(states) <= len(topology)
|
||||
return states
|
||||
@@ -403,21 +456,73 @@ def _convert_agent_types(ind, to_string=False, **kwargs):
|
||||
return deserialize_distribution(ind, **kwargs)
|
||||
|
||||
|
||||
def _agent_from_distribution(distribution, value=-1):
|
||||
def _agent_from_distribution(distribution, value=-1, agent_id=None):
|
||||
"""Used in the initialization of agents given an agent distribution."""
|
||||
if value < 0:
|
||||
value = random.random()
|
||||
for d in distribution:
|
||||
for d in sorted(distribution, key=lambda x: x['threshold']):
|
||||
threshold = d['threshold']
|
||||
if value >= threshold[0] and value < threshold[1]:
|
||||
state = {}
|
||||
if 'state' in d:
|
||||
state = deepcopy(d['state'])
|
||||
return d['agent_type'], state
|
||||
# Check if the definition matches by id (first) or by threshold
|
||||
if not ((agent_id is not None and threshold == STATIC_THRESHOLD and agent_id in d['ids']) or \
|
||||
(value >= threshold[0] and value < threshold[1])):
|
||||
continue
|
||||
state = {}
|
||||
if 'state' in d:
|
||||
state = deepcopy(d['state'])
|
||||
return d['agent_type'], state
|
||||
|
||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||
|
||||
|
||||
class Geo(NetworkAgent):
|
||||
'''In this type of network, nodes have a "pos" attribute.'''
|
||||
|
||||
def geo_search(self, radius, node=None, center=False, **kwargs):
|
||||
'''Get a list of nodes whose coordinates are closer than *radius* to *node*.'''
|
||||
node = as_node(node if node is not None else self)
|
||||
|
||||
G = self.subgraph(**kwargs)
|
||||
|
||||
pos = nx.get_node_attributes(G, 'pos')
|
||||
if not pos:
|
||||
return []
|
||||
nodes, coords = list(zip(*pos.items()))
|
||||
kdtree = KDTree(coords) # Cannot provide generator.
|
||||
indices = kdtree.query_ball_point(pos[node], radius)
|
||||
return [nodes[i] for i in indices if center or (nodes[i] != node)]
|
||||
|
||||
|
||||
def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
if agent_type is not None:
|
||||
try:
|
||||
agent_type = tuple(agent_type)
|
||||
except TypeError:
|
||||
agent_type = tuple([agent_type])
|
||||
|
||||
def matches_all(agent):
|
||||
if state_id is not None:
|
||||
if agent.state.get('id', None) not in state_id:
|
||||
return False
|
||||
if agent_type is not None:
|
||||
if not isinstance(agent, agent_type):
|
||||
return False
|
||||
state = agent.state
|
||||
for k, v in kwargs.items():
|
||||
if state.get(k, None) != v:
|
||||
return False
|
||||
return True
|
||||
|
||||
f = filter(matches_all, agents)
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
if iterator:
|
||||
return f
|
||||
return list(f)
|
||||
|
||||
|
||||
from .BassModel import *
|
||||
from .BigMarketModel import *
|
||||
from .IndependentCascadeModel import *
|
||||
@@ -425,4 +530,3 @@ from .ModelM2 import *
|
||||
from .SentimentCorrelationModel import *
|
||||
from .SISaModel import *
|
||||
from .CounterModel import *
|
||||
from .DrawingAgent import *
|
||||
|
||||
@@ -4,7 +4,7 @@ import glob
|
||||
import yaml
|
||||
from os.path import join
|
||||
|
||||
from . import utils, history
|
||||
from . import serialization, history
|
||||
|
||||
|
||||
def read_data(*args, group=False, **kwargs):
|
||||
@@ -20,7 +20,7 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
||||
process_args = {}
|
||||
for folder in glob.glob(pattern):
|
||||
config_file = glob.glob(join(folder, '*.yml'))[0]
|
||||
config = yaml.load(open(config_file))
|
||||
config = yaml.load(open(config_file), Loader=yaml.SafeLoader)
|
||||
df = None
|
||||
if from_csv:
|
||||
for trial_data in sorted(glob.glob(join(folder,
|
||||
@@ -28,13 +28,13 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
||||
df = read_csv(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
else:
|
||||
for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))):
|
||||
for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))):
|
||||
df = read_sql(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
|
||||
|
||||
def read_sql(db, *args, **kwargs):
|
||||
h = history.History(db, backup=False)
|
||||
h = history.History(db_path=db, backup=False, readonly=True)
|
||||
df = h.read_sql(*args, **kwargs)
|
||||
return df
|
||||
|
||||
@@ -56,7 +56,7 @@ def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
||||
|
||||
|
||||
def convert_row(row):
|
||||
row['value'] = utils.deserialize(row['value_type'], row['value'])
|
||||
row['value'] = serialization.deserialize(row['value_type'], row['value'])
|
||||
return row
|
||||
|
||||
|
||||
@@ -69,6 +69,13 @@ def convert_types_slow(df):
|
||||
df = df.apply(convert_row, axis=1)
|
||||
return df
|
||||
|
||||
|
||||
def split_processed(df):
|
||||
env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])]
|
||||
agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])]
|
||||
return env, agents
|
||||
|
||||
|
||||
def split_df(df):
|
||||
'''
|
||||
Split a dataframe in two dataframes: one with the history of agents,
|
||||
@@ -133,10 +140,10 @@ def get_count(df, *keys):
|
||||
def get_value(df, *keys, aggfunc='sum'):
|
||||
if keys:
|
||||
df = df[list(keys)]
|
||||
return df.groupby(axis=1, level=0).agg(aggfunc, axis=1)
|
||||
return df.groupby(axis=1, level=0).agg(aggfunc)
|
||||
|
||||
|
||||
def plot_all(*args, **kwargs):
|
||||
def plot_all(*args, plot_args={}, **kwargs):
|
||||
'''
|
||||
Read all the trial data and plot the result of applying a function on them.
|
||||
'''
|
||||
@@ -144,14 +151,17 @@ def plot_all(*args, **kwargs):
|
||||
ps = []
|
||||
for line in dfs:
|
||||
f, df, config = line
|
||||
df.plot(title=config['name'])
|
||||
if len(df) < 1:
|
||||
continue
|
||||
df.plot(title=config['name'], **plot_args)
|
||||
ps.append(df)
|
||||
return ps
|
||||
|
||||
def do_all(pattern, func, *keys, include_env=False, **kwargs):
|
||||
for config_file, df, config in read_data(pattern, keys=keys):
|
||||
if len(df) < 1:
|
||||
continue
|
||||
p = func(df, *keys, **kwargs)
|
||||
p.plot(title=config['name'])
|
||||
yield config_file, p, config
|
||||
|
||||
|
||||
|
||||
@@ -4,26 +4,25 @@ import time
|
||||
import csv
|
||||
import random
|
||||
import simpy
|
||||
import yaml
|
||||
import tempfile
|
||||
import pandas as pd
|
||||
from copy import deepcopy
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
import networkx as nx
|
||||
import nxsim
|
||||
import simpy
|
||||
|
||||
from . import utils, agents, analysis, history
|
||||
from . import serialization, agents, analysis, history, utils
|
||||
|
||||
# These properties will be copied when pickling/unpickling the environment
|
||||
_CONFIG_PROPS = [ 'name',
|
||||
'states',
|
||||
'default_state',
|
||||
'interval',
|
||||
'dry_run',
|
||||
'dir_path',
|
||||
]
|
||||
|
||||
class Environment(nxsim.NetworkEnvironment):
|
||||
class Environment(simpy.Environment):
|
||||
"""
|
||||
The environment is key in a simulation. It contains the network topology,
|
||||
a reference to network and environment agents, as well as the environment
|
||||
@@ -41,32 +40,34 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
default_state=None,
|
||||
interval=1,
|
||||
seed=None,
|
||||
dry_run=False,
|
||||
dir_path=None,
|
||||
topology=None,
|
||||
*args, **kwargs):
|
||||
initial_time=0,
|
||||
**environment_params):
|
||||
|
||||
|
||||
self.name = name or 'UnnamedEnvironment'
|
||||
seed = seed or time.time()
|
||||
random.seed(seed)
|
||||
if isinstance(states, list):
|
||||
states = dict(enumerate(states))
|
||||
self.states = deepcopy(states) if states else {}
|
||||
self.default_state = deepcopy(default_state) or {}
|
||||
if not topology:
|
||||
topology = nx.Graph()
|
||||
super().__init__(*args, topology=topology, **kwargs)
|
||||
self.G = nx.Graph(topology)
|
||||
|
||||
super().__init__(initial_time=initial_time)
|
||||
self.environment_params = environment_params
|
||||
|
||||
self._env_agents = {}
|
||||
self.dry_run = dry_run
|
||||
self.interval = interval
|
||||
self.dir_path = dir_path or tempfile.mkdtemp('soil-env')
|
||||
if not dry_run:
|
||||
self.get_path()
|
||||
self._history = history.History(name=self.name if not dry_run else None,
|
||||
dir_path=self.dir_path)
|
||||
self._history = history.History(name=self.name,
|
||||
backup=True)
|
||||
self['SEED'] = seed
|
||||
# Add environment agents first, so their events get
|
||||
# executed before network agents
|
||||
self.environment_agents = environment_agents or []
|
||||
self.network_agents = network_agents or []
|
||||
self['SEED'] = seed or time.time()
|
||||
random.seed(self['SEED'])
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
@@ -93,14 +94,13 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
@property
|
||||
def network_agents(self):
|
||||
for i in self.G.nodes():
|
||||
node = self.G.node[i]
|
||||
node = self.G.nodes[i]
|
||||
if 'agent' in node:
|
||||
yield node['agent']
|
||||
|
||||
@network_agents.setter
|
||||
def network_agents(self, network_agents):
|
||||
if not network_agents:
|
||||
return
|
||||
self._network_agents = network_agents
|
||||
for ix in self.G.nodes():
|
||||
self.init_agent(ix, agent_distribution=network_agents)
|
||||
|
||||
@@ -111,7 +111,7 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
|
||||
agent_type = None
|
||||
if 'agent_type' in self.states.get(agent_id, {}):
|
||||
agent_type = self.states[agent_id]
|
||||
agent_type = self.states[agent_id]['agent_type']
|
||||
elif 'agent_type' in node:
|
||||
agent_type = node['agent_type']
|
||||
elif 'agent_type' in self.default_state:
|
||||
@@ -119,8 +119,11 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
|
||||
if agent_type:
|
||||
agent_type = agents.deserialize_type(agent_type)
|
||||
elif agent_distribution:
|
||||
agent_type, state = agents._agent_from_distribution(agent_distribution, agent_id=agent_id)
|
||||
else:
|
||||
agent_type, state = agents._agent_from_distribution(agent_distribution)
|
||||
serialization.logger.debug('Skipping node {}'.format(agent_id))
|
||||
return
|
||||
return self.set_agent(agent_id, agent_type, state)
|
||||
|
||||
def set_agent(self, agent_id, agent_type, state=None):
|
||||
@@ -130,10 +133,12 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
defstate.update(node.get('state', {}))
|
||||
if state:
|
||||
defstate.update(state)
|
||||
state = defstate
|
||||
a = agent_type(environment=self,
|
||||
agent_id=agent_id,
|
||||
state=state)
|
||||
a = None
|
||||
if agent_type:
|
||||
state = defstate
|
||||
a = agent_type(environment=self,
|
||||
agent_id=agent_id,
|
||||
state=state)
|
||||
node['agent'] = a
|
||||
return a
|
||||
|
||||
@@ -144,22 +149,21 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
a['visible'] = True
|
||||
return a
|
||||
|
||||
def add_edge(self, agent1, agent2, attrs=None):
|
||||
def add_edge(self, agent1, agent2, start=None, **attrs):
|
||||
if hasattr(agent1, 'id'):
|
||||
agent1 = agent1.id
|
||||
if hasattr(agent2, 'id'):
|
||||
agent2 = agent2.id
|
||||
return self.G.add_edge(agent1, agent2)
|
||||
start = start or self.now
|
||||
return self.G.add_edge(agent1, agent2, **attrs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
def run(self, until, *args, **kwargs):
|
||||
self._save_state()
|
||||
super().run(*args, **kwargs)
|
||||
super().run(until, *args, **kwargs)
|
||||
self._history.flush_cache()
|
||||
|
||||
def _save_state(self, now=None):
|
||||
# for agent in self.agents:
|
||||
# agent.save_state()
|
||||
utils.logger.debug('Saving state @{}'.format(self.now))
|
||||
serialization.logger.debug('Saving state @{}'.format(self.now))
|
||||
self._history.save_records(self.state_to_tuples(now=now))
|
||||
|
||||
def save_state(self):
|
||||
@@ -170,7 +174,7 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
self._save_state()
|
||||
while self.peek() != simpy.core.Infinity:
|
||||
delay = max(self.peek() - self.now, self.interval)
|
||||
utils.logger.debug('Step: {}'.format(self.now))
|
||||
serialization.logger.debug('Step: {}'.format(self.now))
|
||||
ev = self.event()
|
||||
ev._ok = True
|
||||
# Schedule the event with minimum priority so
|
||||
@@ -212,45 +216,33 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
'''
|
||||
return self[key] if key in self else default
|
||||
|
||||
def get_path(self, dir_path=None):
|
||||
dir_path = dir_path or self.dir_path
|
||||
if not os.path.exists(dir_path):
|
||||
try:
|
||||
os.makedirs(dir_path)
|
||||
except FileExistsError:
|
||||
pass
|
||||
return dir_path
|
||||
|
||||
def get_agent(self, agent_id):
|
||||
return self.G.node[agent_id]['agent']
|
||||
return self.G.nodes[agent_id]['agent']
|
||||
|
||||
def get_agents(self):
|
||||
return list(self.agents)
|
||||
def get_agents(self, nodes=None):
|
||||
if nodes is None:
|
||||
return list(self.agents)
|
||||
return [self.G.nodes[i]['agent'] for i in nodes]
|
||||
|
||||
def dump_csv(self, dir_path=None):
|
||||
csv_name = os.path.join(self.get_path(dir_path),
|
||||
'{}.environment.csv'.format(self.name))
|
||||
|
||||
with open(csv_name, 'w') as f:
|
||||
def dump_csv(self, f):
|
||||
with utils.open_or_reuse(f, 'w') as f:
|
||||
cr = csv.writer(f)
|
||||
cr.writerow(('agent_id', 't_step', 'key', 'value'))
|
||||
for i in self.history_to_tuples():
|
||||
cr.writerow(i)
|
||||
|
||||
def dump_gexf(self, dir_path=None):
|
||||
def dump_gexf(self, f):
|
||||
G = self.history_to_graph()
|
||||
graph_path = os.path.join(self.get_path(dir_path),
|
||||
self.name+".gexf")
|
||||
# Workaround for geometric models
|
||||
# See soil/soil#4
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.node[node]:
|
||||
G.node[node]['viz'] = {"position": {"x": G.node[node]['pos'][0], "y": G.node[node]['pos'][1], "z": 0.0}}
|
||||
del (G.node[node]['pos'])
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
|
||||
nx.write_gexf(G, graph_path, version="1.2draft")
|
||||
nx.write_gexf(G, f, version="1.2draft")
|
||||
|
||||
def dump(self, dir_path=None, formats=None):
|
||||
def dump(self, *args, formats=None, **kwargs):
|
||||
if not formats:
|
||||
return
|
||||
functions = {
|
||||
@@ -259,10 +251,13 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
}
|
||||
for f in formats:
|
||||
if f in functions:
|
||||
functions[f](dir_path)
|
||||
functions[f](*args, **kwargs)
|
||||
else:
|
||||
raise ValueError('Unknown format: {}'.format(f))
|
||||
|
||||
def dump_sqlite(self, f):
|
||||
return self._history.dump(f)
|
||||
|
||||
def state_to_tuples(self, now=None):
|
||||
if now is None:
|
||||
now = self.now
|
||||
@@ -326,7 +321,7 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
G.add_node(agent.id, **attributes)
|
||||
|
||||
return G
|
||||
|
||||
|
||||
def __getstate__(self):
|
||||
state = {}
|
||||
for prop in _CONFIG_PROPS:
|
||||
@@ -334,6 +329,7 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
state['G'] = json_graph.node_link_data(self.G)
|
||||
state['environment_agents'] = self._env_agents
|
||||
state['history'] = self._history
|
||||
state['_now'] = self._now
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
@@ -342,6 +338,8 @@ class Environment(nxsim.NetworkEnvironment):
|
||||
self._env_agents = state['environment_agents']
|
||||
self.G = json_graph.node_link_graph(state['G'])
|
||||
self._history = state['history']
|
||||
self._now = state['_now']
|
||||
self._queue = []
|
||||
|
||||
|
||||
SoilEnvironment = Environment
|
||||
|
||||
157
soil/exporters.py
Normal file
157
soil/exporters.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import os
|
||||
import csv as csvlib
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
|
||||
|
||||
from .serialization import deserialize
|
||||
from .utils import open_or_reuse, logger, timer
|
||||
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
def for_sim(simulation, names, *args, **kwargs):
|
||||
'''Return the set of exporters for a simulation, given the exporter names'''
|
||||
exporters = []
|
||||
for name in names:
|
||||
mod = deserialize(name, known_modules=['soil.exporters'])
|
||||
exporters.append(mod(simulation, *args, **kwargs))
|
||||
return exporters
|
||||
|
||||
|
||||
class DryRunner(BytesIO):
|
||||
def __init__(self, fname, *args, copy_to=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__fname = fname
|
||||
self.__copy_to = copy_to
|
||||
|
||||
def write(self, txt):
|
||||
if self.__copy_to:
|
||||
self.__copy_to.write('{}:::{}'.format(self.__fname, txt))
|
||||
try:
|
||||
super().write(txt)
|
||||
except TypeError:
|
||||
super().write(bytes(txt, 'utf-8'))
|
||||
|
||||
def close(self):
|
||||
logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname,
|
||||
self.getvalue().decode()))
|
||||
super().close()
|
||||
|
||||
|
||||
class Exporter:
|
||||
'''
|
||||
Interface for all exporters. It is not necessary, but it is useful
|
||||
if you don't plan to implement all the methods.
|
||||
'''
|
||||
|
||||
def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
|
||||
self.simulation = simulation
|
||||
outdir = outdir or os.path.join(os.getcwd(), 'soil_output')
|
||||
self.outdir = os.path.join(outdir,
|
||||
simulation.group or '',
|
||||
simulation.name)
|
||||
self.dry_run = dry_run
|
||||
self.copy_to = copy_to
|
||||
|
||||
def start(self):
|
||||
'''Method to call when the simulation starts'''
|
||||
pass
|
||||
|
||||
def end(self, stats):
|
||||
'''Method to call when the simulation ends'''
|
||||
pass
|
||||
|
||||
def trial(self, env, stats):
|
||||
'''Method to call when a trial ends'''
|
||||
pass
|
||||
|
||||
def output(self, f, mode='w', **kwargs):
|
||||
if self.dry_run:
|
||||
f = DryRunner(f, copy_to=self.copy_to)
|
||||
else:
|
||||
try:
|
||||
if not os.path.isabs(f):
|
||||
f = os.path.join(self.outdir, f)
|
||||
except TypeError:
|
||||
pass
|
||||
return open_or_reuse(f, mode=mode, **kwargs)
|
||||
|
||||
|
||||
class default(Exporter):
|
||||
'''Default exporter. Writes sqlite results, as well as the simulation YAML'''
|
||||
|
||||
def start(self):
|
||||
if not self.dry_run:
|
||||
logger.info('Dumping results to %s', self.outdir)
|
||||
self.simulation.dump_yaml(outdir=self.outdir)
|
||||
else:
|
||||
logger.info('NOT dumping results')
|
||||
|
||||
def trial(self, env, stats):
|
||||
if not self.dry_run:
|
||||
with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.name)):
|
||||
with self.output('{}.sqlite'.format(env.name), mode='wb') as f:
|
||||
env.dump_sqlite(f)
|
||||
|
||||
|
||||
class csv(Exporter):
|
||||
'''Export the state of each environment (and its agents) in a separate CSV file'''
|
||||
def trial(self, env, stats):
|
||||
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
|
||||
env.name,
|
||||
self.outdir)):
|
||||
with self.output('{}.csv'.format(env.name)) as f:
|
||||
env.dump_csv(f)
|
||||
|
||||
with self.output('{}.stats.csv'.format(env.name)) as f:
|
||||
statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL)
|
||||
|
||||
for stat in stats:
|
||||
statwriter.writerow(stat)
|
||||
|
||||
|
||||
class gexf(Exporter):
|
||||
def trial(self, env, stats):
|
||||
if self.dry_run:
|
||||
logger.info('Not dumping GEXF in dry_run mode')
|
||||
return
|
||||
|
||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.name)):
|
||||
with self.output('{}.gexf'.format(env.name), mode='wb') as f:
|
||||
env.dump_gexf(f)
|
||||
|
||||
|
||||
class dummy(Exporter):
|
||||
|
||||
def start(self):
|
||||
with self.output('dummy', 'w') as f:
|
||||
f.write('simulation started @ {}\n'.format(time.time()))
|
||||
|
||||
def trial(self, env, stats):
|
||||
with self.output('dummy', 'w') as f:
|
||||
for i in env.history_to_tuples():
|
||||
f.write(','.join(map(str, i)))
|
||||
f.write('\n')
|
||||
|
||||
def sim(self, stats):
|
||||
with self.output('dummy', 'a') as f:
|
||||
f.write('simulation ended @ {}\n'.format(time.time()))
|
||||
|
||||
|
||||
|
||||
class graphdrawing(Exporter):
|
||||
|
||||
def trial(self, env, stats):
|
||||
# Outside effects
|
||||
f = plt.figure()
|
||||
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
|
||||
with open('graph-{}.png'.format(env.name)) as f:
|
||||
f.savefig(f)
|
||||
|
||||
196
soil/history.py
196
soil/history.py
@@ -3,9 +3,15 @@ import os
|
||||
import pandas as pd
|
||||
import sqlite3
|
||||
import copy
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from collections import UserDict, namedtuple
|
||||
|
||||
from . import utils
|
||||
from . import serialization
|
||||
from .utils import open_or_reuse, unflatten_dict
|
||||
|
||||
|
||||
class History:
|
||||
@@ -13,27 +19,44 @@ class History:
|
||||
Store and retrieve values from a sqlite database.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path=None, name=None, dir_path=None, backup=True):
|
||||
if db_path is None and name:
|
||||
db_path = os.path.join(dir_path or os.getcwd(),
|
||||
'{}.db.sqlite'.format(name))
|
||||
if db_path:
|
||||
if backup and os.path.exists(db_path):
|
||||
def __init__(self, name=None, db_path=None, backup=False, readonly=False):
|
||||
if readonly and (not os.path.exists(db_path)):
|
||||
raise Exception('The DB file does not exist. Cannot open in read-only mode')
|
||||
|
||||
self._db = None
|
||||
self._temp = db_path is None
|
||||
self._stats_columns = None
|
||||
self.readonly = readonly
|
||||
|
||||
if self._temp:
|
||||
if not name:
|
||||
name = time.time()
|
||||
# The file will be deleted as soon as it's closed
|
||||
# Normally, that will be on destruction
|
||||
db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
|
||||
|
||||
|
||||
if backup and os.path.exists(db_path):
|
||||
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
||||
os.rename(db_path, newname)
|
||||
else:
|
||||
db_path = ":memory:"
|
||||
|
||||
self.db_path = db_path
|
||||
|
||||
self.db = db_path
|
||||
|
||||
with self.db:
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
|
||||
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
|
||||
self._dtypes = {}
|
||||
self._tups = []
|
||||
|
||||
|
||||
if self.readonly:
|
||||
return
|
||||
|
||||
with self.db:
|
||||
logger.debug('Creating database {}'.format(self.db_path))
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text)''')
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
|
||||
self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
|
||||
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
try:
|
||||
@@ -44,15 +67,72 @@ class History:
|
||||
|
||||
@db.setter
|
||||
def db(self, db_path=None):
|
||||
self._close()
|
||||
db_path = db_path or self.db_path
|
||||
if isinstance(db_path, str):
|
||||
logger.debug('Connecting to database {}'.format(db_path))
|
||||
self._db = sqlite3.connect(db_path)
|
||||
self._db.row_factory = sqlite3.Row
|
||||
else:
|
||||
self._db = db_path
|
||||
|
||||
def _close(self):
|
||||
if self._db is None:
|
||||
return
|
||||
self.flush_cache()
|
||||
self._db.close()
|
||||
self._db = None
|
||||
|
||||
def save_stats(self, stat):
|
||||
if self.readonly:
|
||||
print('DB in readonly mode')
|
||||
return
|
||||
if not stat:
|
||||
return
|
||||
with self.db:
|
||||
if not self._stats_columns:
|
||||
self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
|
||||
|
||||
for column, value in stat.items():
|
||||
if column in self._stats_columns:
|
||||
continue
|
||||
dtype = 'text'
|
||||
if not isinstance(value, str):
|
||||
try:
|
||||
float(value)
|
||||
dtype = 'real'
|
||||
int(value)
|
||||
dtype = 'int'
|
||||
except ValueError:
|
||||
pass
|
||||
self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
|
||||
self._stats_columns.append(column)
|
||||
|
||||
columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
|
||||
values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
|
||||
query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
|
||||
columns=columns,
|
||||
values=values
|
||||
)
|
||||
self.db.execute(query)
|
||||
|
||||
def get_stats(self, unflatten=True):
|
||||
rows = self.db.execute("select * from stats").fetchall()
|
||||
res = []
|
||||
for row in rows:
|
||||
d = {}
|
||||
for k in row.keys():
|
||||
if row[k] is None:
|
||||
continue
|
||||
d[k] = row[k]
|
||||
if unflatten:
|
||||
d = unflatten_dict(d)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
@property
|
||||
def dtypes(self):
|
||||
self.read_types()
|
||||
self._read_types()
|
||||
return {k:v[0] for k, v in self._dtypes.items()}
|
||||
|
||||
def save_tuples(self, tuples):
|
||||
@@ -75,7 +155,18 @@ class History:
|
||||
Save a collection of records to the database.
|
||||
Database writes are cached.
|
||||
'''
|
||||
value = self.convert(key, value)
|
||||
if self.readonly:
|
||||
raise Exception('DB in readonly mode')
|
||||
if key not in self._dtypes:
|
||||
self._read_types()
|
||||
if key not in self._dtypes:
|
||||
name = serialization.name(value)
|
||||
serializer = serialization.serializer(name)
|
||||
deserializer = serialization.deserializer(name)
|
||||
self._dtypes[key] = (name, serializer, deserializer)
|
||||
with self.db:
|
||||
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
|
||||
value = self._dtypes[key][1](value)
|
||||
self._tups.append(Record(agent_id=agent_id,
|
||||
t_step=t_step,
|
||||
key=key,
|
||||
@@ -83,33 +174,14 @@ class History:
|
||||
if len(self._tups) > 100:
|
||||
self.flush_cache()
|
||||
|
||||
def convert(self, key, value):
|
||||
"""Get the serialized value for a given key."""
|
||||
if key not in self._dtypes:
|
||||
self.read_types()
|
||||
if key not in self._dtypes:
|
||||
name = utils.name(value)
|
||||
serializer = utils.serializer(name)
|
||||
deserializer = utils.deserializer(name)
|
||||
self._dtypes[key] = (name, serializer, deserializer)
|
||||
with self.db:
|
||||
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
|
||||
return self._dtypes[key][1](value)
|
||||
|
||||
def recover(self, key, value):
|
||||
"""Get the deserialized value for a given key, and the serialized version."""
|
||||
if key not in self._dtypes:
|
||||
self.read_types()
|
||||
if key not in self._dtypes:
|
||||
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
||||
return self._dtypes[key][2](value)
|
||||
|
||||
|
||||
def flush_cache(self):
|
||||
'''
|
||||
Use a cache to save state changes to avoid opening a session for every change.
|
||||
The cache will be flushed at the end of the simulation, and when history is accessed.
|
||||
'''
|
||||
if self.readonly:
|
||||
raise Exception('DB in readonly mode')
|
||||
logger.debug('Flushing cache {}'.format(self.db_path))
|
||||
with self.db:
|
||||
for rec in self._tups:
|
||||
self.db.execute("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", (rec.agent_id, rec.t_step, rec.key, rec.value))
|
||||
@@ -121,15 +193,19 @@ class History:
|
||||
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
|
||||
for r in res:
|
||||
agent_id, t_step, key, value = r
|
||||
value = self.recover(key, value)
|
||||
if key not in self._dtypes:
|
||||
self._read_types()
|
||||
if key not in self._dtypes:
|
||||
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
||||
value = self._dtypes[key][2](value)
|
||||
yield agent_id, t_step, key, value
|
||||
|
||||
def read_types(self):
|
||||
def _read_types(self):
|
||||
with self.db:
|
||||
res = self.db.execute("select key, value_type from value_types ").fetchall()
|
||||
for k, v in res:
|
||||
serializer = utils.serializer(v)
|
||||
deserializer = utils.deserializer(v)
|
||||
serializer = serialization.serializer(v)
|
||||
deserializer = serialization.deserializer(v)
|
||||
self._dtypes[k] = (v, serializer, deserializer)
|
||||
|
||||
def __getitem__(self, key):
|
||||
@@ -147,11 +223,9 @@ class History:
|
||||
return r.value()
|
||||
return r
|
||||
|
||||
|
||||
|
||||
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
||||
|
||||
self.read_types()
|
||||
self._read_types()
|
||||
|
||||
def escape_and_join(v):
|
||||
if v is None:
|
||||
@@ -165,7 +239,13 @@ class History:
|
||||
|
||||
last_df = None
|
||||
if t_steps:
|
||||
# Look for the last value before the minimum step in the query
|
||||
# Convert negative indices into positive
|
||||
if any(x<0 for x in t_steps):
|
||||
max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
|
||||
t_steps = [t if t>0 else max_t+1+t for t in t_steps]
|
||||
|
||||
# We will be doing ffill interpolation, so we need to look for
|
||||
# the last value before the minimum step in the query
|
||||
min_step = min(t_steps)
|
||||
last_filters = ['t_step < {}'.format(min_step),]
|
||||
last_filters = last_filters + filters
|
||||
@@ -203,20 +283,30 @@ class History:
|
||||
for k, v in self._dtypes.items():
|
||||
if k in df_p:
|
||||
dtype, _, deserial = v
|
||||
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
|
||||
try:
|
||||
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
|
||||
except (TypeError, ValueError):
|
||||
# Avoid forward-filling unknown/incompatible types
|
||||
continue
|
||||
if t_steps:
|
||||
df_p = df_p.reindex(t_steps, method='ffill')
|
||||
return df_p.ffill()
|
||||
|
||||
|
||||
def __getstate__(self):
|
||||
state = dict(**self.__dict__)
|
||||
del state['_db']
|
||||
del state['_dtypes']
|
||||
return state
|
||||
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
self._dtypes = {}
|
||||
self._db = None
|
||||
|
||||
def dump(self, f):
|
||||
self._close()
|
||||
for line in open_or_reuse(self.db_path, 'rb'):
|
||||
f.write(line)
|
||||
|
||||
|
||||
class Records():
|
||||
@@ -267,10 +357,13 @@ class Records():
|
||||
i = self._df[f.key][str(f.agent_id)]
|
||||
ix = i.index.get_loc(f.t_step, method='ffill')
|
||||
return i.iloc[ix]
|
||||
except KeyError:
|
||||
except KeyError as ex:
|
||||
return self.dtypes[f.key][2]()
|
||||
return list(self)
|
||||
|
||||
def df(self):
|
||||
return self._df
|
||||
|
||||
def __getitem__(self, k):
|
||||
n = copy.copy(self)
|
||||
n.filter(k)
|
||||
@@ -286,6 +379,7 @@ class Records():
|
||||
return str(self.value())
|
||||
return '<Records for [{}]>'.format(self._filter)
|
||||
|
||||
|
||||
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
||||
Record = namedtuple('Record', 'agent_id t_step key value')
|
||||
|
||||
Stat = namedtuple('Stat', 'trial_id')
|
||||
|
||||
210
soil/serialization.py
Normal file
210
soil/serialization.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import os
|
||||
import logging
|
||||
import ast
|
||||
import sys
|
||||
import importlib
|
||||
from glob import glob
|
||||
from itertools import product, chain
|
||||
|
||||
import yaml
|
||||
import networkx as nx
|
||||
|
||||
from jinja2 import Template
|
||||
|
||||
|
||||
logger = logging.getLogger('soil')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def load_network(network_params, dir_path=None):
|
||||
G = nx.Graph()
|
||||
|
||||
if 'path' in network_params:
|
||||
path = network_params['path']
|
||||
if dir_path and not os.path.isabs(path):
|
||||
path = os.path.join(dir_path, path)
|
||||
extension = os.path.splitext(path)[1][1:]
|
||||
kwargs = {}
|
||||
if extension == 'gexf':
|
||||
kwargs['version'] = '1.2draft'
|
||||
kwargs['node_type'] = int
|
||||
try:
|
||||
method = getattr(nx.readwrite, 'read_' + extension)
|
||||
except AttributeError:
|
||||
raise AttributeError('Unknown format')
|
||||
G = method(path, **kwargs)
|
||||
|
||||
elif 'generator' in network_params:
|
||||
net_args = network_params.copy()
|
||||
net_gen = net_args.pop('generator')
|
||||
|
||||
if dir_path not in sys.path:
|
||||
sys.path.append(dir_path)
|
||||
|
||||
method = deserializer(net_gen,
|
||||
known_modules=['networkx.generators',])
|
||||
G = method(**net_args)
|
||||
|
||||
return G
|
||||
|
||||
|
||||
|
||||
|
||||
def load_file(infile):
|
||||
with open(infile, 'r') as f:
|
||||
return list(chain.from_iterable(map(expand_template, load_string(f))))
|
||||
|
||||
|
||||
def load_string(string):
|
||||
yield from yaml.load_all(string, Loader=yaml.FullLoader)
|
||||
|
||||
|
||||
def expand_template(config):
|
||||
if 'template' not in config:
|
||||
yield config
|
||||
return
|
||||
if 'vars' not in config:
|
||||
raise ValueError(('You must provide a definition of variables'
|
||||
' for the template.'))
|
||||
|
||||
template = config['template']
|
||||
|
||||
if not isinstance(template, str):
|
||||
template = yaml.dump(template)
|
||||
|
||||
template = Template(template)
|
||||
|
||||
params = params_for_template(config)
|
||||
|
||||
blank_str = template.render({k: 0 for k in params[0].keys()})
|
||||
blank = list(load_string(blank_str))
|
||||
if len(blank) > 1:
|
||||
raise ValueError('Templates must not return more than one configuration')
|
||||
if 'name' in blank[0]:
|
||||
raise ValueError('Templates cannot be named, use group instead')
|
||||
|
||||
for ps in params:
|
||||
string = template.render(ps)
|
||||
for c in load_string(string):
|
||||
yield c
|
||||
|
||||
|
||||
def params_for_template(config):
|
||||
sampler_config = config.get('sampler', {'N': 100})
|
||||
sampler = sampler_config.pop('method', 'SALib.sample.morris.sample')
|
||||
sampler = deserializer(sampler)
|
||||
bounds = config['vars']['bounds']
|
||||
|
||||
problem = {
|
||||
'num_vars': len(bounds),
|
||||
'names': list(bounds.keys()),
|
||||
'bounds': list(v for v in bounds.values())
|
||||
}
|
||||
samples = sampler(problem, **sampler_config)
|
||||
|
||||
lists = config['vars'].get('lists', {})
|
||||
names = list(lists.keys())
|
||||
values = list(lists.values())
|
||||
combs = list(product(*values))
|
||||
|
||||
allnames = names + problem['names']
|
||||
allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
|
||||
params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
|
||||
return params
|
||||
|
||||
|
||||
def load_files(*patterns, **kwargs):
|
||||
for pattern in patterns:
|
||||
for i in glob(pattern, **kwargs):
|
||||
for config in load_file(i):
|
||||
path = os.path.abspath(i)
|
||||
if 'dir_path' not in config:
|
||||
config['dir_path'] = os.path.dirname(path)
|
||||
yield config, path
|
||||
|
||||
|
||||
def load_config(config):
|
||||
if isinstance(config, dict):
|
||||
yield config, os.getcwd()
|
||||
else:
|
||||
yield from load_files(config)
|
||||
|
||||
|
||||
builtins = importlib.import_module('builtins')
|
||||
|
||||
def name(value, known_modules=[]):
|
||||
'''Return a name that can be imported, to serialize/deserialize an object'''
|
||||
if value is None:
|
||||
return 'None'
|
||||
if not isinstance(value, type): # Get the class name first
|
||||
value = type(value)
|
||||
tname = value.__name__
|
||||
if hasattr(builtins, tname):
|
||||
return tname
|
||||
modname = value.__module__
|
||||
if modname == '__main__':
|
||||
return tname
|
||||
if known_modules and modname in known_modules:
|
||||
return tname
|
||||
for kmod in known_modules:
|
||||
if not kmod:
|
||||
continue
|
||||
module = importlib.import_module(kmod)
|
||||
if hasattr(module, tname):
|
||||
return tname
|
||||
return '{}.{}'.format(modname, tname)
|
||||
|
||||
|
||||
def serializer(type_):
|
||||
if type_ != 'str' and hasattr(builtins, type_):
|
||||
return repr
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def serialize(v, known_modules=[]):
|
||||
'''Get a text representation of an object.'''
|
||||
tname = name(v, known_modules=known_modules)
|
||||
func = serializer(tname)
|
||||
return func(v), tname
|
||||
|
||||
def deserializer(type_, known_modules=[]):
|
||||
if type(type_) != str: # Already deserialized
|
||||
return type_
|
||||
if type_ == 'str':
|
||||
return lambda x='': x
|
||||
if type_ == 'None':
|
||||
return lambda x=None: None
|
||||
if hasattr(builtins, type_): # Check if it's a builtin type
|
||||
cls = getattr(builtins, type_)
|
||||
return lambda x=None: ast.literal_eval(x) if x is not None else cls()
|
||||
# Otherwise, see if we can find the module and the class
|
||||
modules = known_modules or []
|
||||
options = []
|
||||
|
||||
for mod in modules:
|
||||
if mod:
|
||||
options.append((mod, type_))
|
||||
|
||||
if '.' in type_: # Fully qualified module
|
||||
module, type_ = type_.rsplit(".", 1)
|
||||
options.append ((module, type_))
|
||||
|
||||
errors = []
|
||||
for modname, tname in options:
|
||||
try:
|
||||
module = importlib.import_module(modname)
|
||||
cls = getattr(module, tname)
|
||||
return getattr(cls, 'deserialize', cls)
|
||||
except (ImportError, AttributeError) as ex:
|
||||
errors.append((modname, tname, ex))
|
||||
raise Exception('Could not find type {}. Tried: {}'.format(type_, errors))
|
||||
|
||||
|
||||
def deserialize(type_, value=None, **kwargs):
|
||||
'''Get an object from a text representation'''
|
||||
if not isinstance(type_, str):
|
||||
return type_
|
||||
des = deserializer(type_, **kwargs)
|
||||
if value is None:
|
||||
return des
|
||||
return des(value)
|
||||
@@ -4,6 +4,7 @@ import importlib
|
||||
import sys
|
||||
import yaml
|
||||
import traceback
|
||||
import logging
|
||||
import networkx as nx
|
||||
from networkx.readwrite import json_graph
|
||||
from multiprocessing import Pool
|
||||
@@ -11,16 +12,19 @@ from functools import partial
|
||||
|
||||
import pickle
|
||||
|
||||
from nxsim import NetworkSimulation
|
||||
|
||||
from . import utils, basestring, agents
|
||||
from . import serialization, utils, basestring, agents
|
||||
from .environment import Environment
|
||||
from .utils import logger
|
||||
from .exporters import default, for_sim as exporters_for_sim
|
||||
from .stats import defaultStats
|
||||
from .history import History
|
||||
|
||||
|
||||
class Simulation(NetworkSimulation):
|
||||
#TODO: change documentation for simulation
|
||||
|
||||
class Simulation:
|
||||
"""
|
||||
Subclass of nsim.NetworkSimulation with three main differences:
|
||||
Similar to nsim.NetworkSimulation with three main differences:
|
||||
1) agent type can be specified by name or by class.
|
||||
2) instead of just one type, a network agents distribution can be used.
|
||||
The distribution specifies the weight (or probability) of each
|
||||
@@ -50,6 +54,8 @@ class Simulation(NetworkSimulation):
|
||||
---------
|
||||
name : str, optional
|
||||
name of the Simulation
|
||||
group : str, optional
|
||||
a group name can be used to link simulations
|
||||
topology : networkx.Graph instance, optional
|
||||
network_params : dict
|
||||
parameters used to create a topology with networkx, if no topology is given
|
||||
@@ -60,8 +66,8 @@ class Simulation(NetworkSimulation):
|
||||
states : list, optional
|
||||
List of initial states corresponding to the nodes in the topology. Basic form is a list of integers
|
||||
whose value indicates the state
|
||||
dir_path : str, optional
|
||||
Directory path where to save pickled objects
|
||||
dir_path: str, optional
|
||||
Directory path to load simulation assets (files, modules...)
|
||||
seed : str, optional
|
||||
Seed to use for the random generator
|
||||
num_trials : int, optional
|
||||
@@ -80,37 +86,38 @@ class Simulation(NetworkSimulation):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, topology=None, network_params=None,
|
||||
def __init__(self, name=None, group=None, topology=None, network_params=None,
|
||||
network_agents=None, agent_type=None, states=None,
|
||||
default_state=None, interval=1, dump=None, dry_run=False,
|
||||
dir_path=None, num_trials=1, max_time=100,
|
||||
load_module=None, seed=None,
|
||||
environment_agents=None, environment_params=None,
|
||||
environment_class=None, **kwargs):
|
||||
|
||||
if topology is None:
|
||||
topology = utils.load_network(network_params,
|
||||
dir_path=dir_path)
|
||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
||||
topology = json_graph.node_link_graph(topology)
|
||||
default_state=None, interval=1, num_trials=1,
|
||||
max_time=100, load_module=None, seed=None,
|
||||
dir_path=None, environment_agents=None,
|
||||
environment_params=None, environment_class=None,
|
||||
**kwargs):
|
||||
|
||||
self.load_module = load_module
|
||||
self.topology = nx.Graph(topology)
|
||||
self.network_params = network_params
|
||||
self.name = name or 'UnnamedSimulation'
|
||||
self.name = name or 'Unnamed'
|
||||
self.seed = str(seed or name)
|
||||
self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S"))
|
||||
self.group = group or ''
|
||||
self.num_trials = num_trials
|
||||
self.max_time = max_time
|
||||
self.default_state = default_state or {}
|
||||
self.dir_path = dir_path or os.getcwd()
|
||||
self.interval = interval
|
||||
self.seed = str(seed) or str(time.time())
|
||||
self.dump = dump
|
||||
self.dry_run = dry_run
|
||||
|
||||
sys.path += [self.dir_path, os.getcwd()]
|
||||
sys.path += list(x for x in [os.getcwd(), self.dir_path] if x not in sys.path)
|
||||
|
||||
if topology is None:
|
||||
topology = serialization.load_network(network_params,
|
||||
dir_path=self.dir_path)
|
||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
||||
topology = json_graph.node_link_graph(topology)
|
||||
self.topology = nx.Graph(topology)
|
||||
|
||||
|
||||
self.environment_params = environment_params or {}
|
||||
self.environment_class = utils.deserialize(environment_class,
|
||||
self.environment_class = serialization.deserialize(environment_class,
|
||||
known_modules=['soil.environment', ]) or Environment
|
||||
|
||||
environment_agents = environment_agents or []
|
||||
@@ -125,73 +132,127 @@ class Simulation(NetworkSimulation):
|
||||
self.states = agents._validate_states(states,
|
||||
self.topology)
|
||||
|
||||
self._history = History(name=self.name,
|
||||
backup=False)
|
||||
|
||||
def run_simulation(self, *args, **kwargs):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return list(self.run_simulation_gen(*args, **kwargs))
|
||||
'''Run the simulation and return the list of resulting environments'''
|
||||
return list(self.run_gen(*args, **kwargs))
|
||||
|
||||
def _run_sync_or_async(self, parallel=False, *args, **kwargs):
|
||||
if parallel:
|
||||
p = Pool()
|
||||
func = partial(self.run_trial_exceptions,
|
||||
*args,
|
||||
**kwargs)
|
||||
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||
if isinstance(i, Exception):
|
||||
logger.error('Trial failed:\n\t%s', i.message)
|
||||
continue
|
||||
yield i
|
||||
else:
|
||||
for i in range(self.num_trials):
|
||||
yield self.run_trial(*args,
|
||||
**kwargs)
|
||||
|
||||
def run_gen(self, *args, parallel=False, dry_run=False,
|
||||
exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
|
||||
stats_params={}, log_level=None,
|
||||
**kwargs):
|
||||
'''Run the simulation and yield the resulting environments.'''
|
||||
if log_level:
|
||||
logger.setLevel(log_level)
|
||||
logger.info('Using exporters: %s', exporters or [])
|
||||
logger.info('Output directory: %s', outdir)
|
||||
exporters = exporters_for_sim(self,
|
||||
exporters,
|
||||
dry_run=dry_run,
|
||||
outdir=outdir,
|
||||
**exporter_params)
|
||||
stats = exporters_for_sim(self,
|
||||
stats,
|
||||
**stats_params)
|
||||
|
||||
def run_simulation_gen(self, *args, parallel=False, dry_run=False,
|
||||
**kwargs):
|
||||
p = Pool()
|
||||
with utils.timer('simulation {}'.format(self.name)):
|
||||
if parallel:
|
||||
func = partial(self.run_trial_exceptions, dry_run=dry_run or self.dry_run,
|
||||
return_env=True,
|
||||
**kwargs)
|
||||
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||
if isinstance(i, Exception):
|
||||
logger.error('Trial failed:\n\t{}'.format(i.message))
|
||||
continue
|
||||
yield i
|
||||
else:
|
||||
for i in range(self.num_trials):
|
||||
yield self.run_trial(i, dry_run = dry_run or self.dry_run, **kwargs)
|
||||
if not (dry_run or self.dry_run):
|
||||
logger.info('Dumping results to {}'.format(self.dir_path))
|
||||
self.dump_pickle(self.dir_path)
|
||||
self.dump_yaml(self.dir_path)
|
||||
else:
|
||||
logger.info('NOT dumping results')
|
||||
for stat in stats:
|
||||
stat.start()
|
||||
|
||||
def get_env(self, trial_id = 0, **kwargs):
|
||||
opts=self.environment_params.copy()
|
||||
env_name='{}_trial_{}'.format(self.name, trial_id)
|
||||
for exporter in exporters:
|
||||
exporter.start()
|
||||
for env in self._run_sync_or_async(*args,
|
||||
parallel=parallel,
|
||||
log_level=log_level,
|
||||
**kwargs):
|
||||
|
||||
collected = list(stat.trial(env) for stat in stats)
|
||||
|
||||
saved = self.save_stats(collected, t_step=env.now, trial_id=env.name)
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.trial(env, saved)
|
||||
|
||||
yield env
|
||||
|
||||
|
||||
collected = list(stat.end() for stat in stats)
|
||||
saved = self.save_stats(collected)
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.end(saved)
|
||||
|
||||
|
||||
def save_stats(self, collection, **kwargs):
|
||||
stats = dict(kwargs)
|
||||
for stat in collection:
|
||||
stats.update(stat)
|
||||
self._history.save_stats(utils.flatten_dict(stats))
|
||||
return stats
|
||||
|
||||
def get_stats(self, **kwargs):
|
||||
return self._history.get_stats(**kwargs)
|
||||
|
||||
def log_stats(self, stats):
|
||||
logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
|
||||
|
||||
|
||||
def get_env(self, trial_id=0, **kwargs):
|
||||
'''Create an environment for a trial of the simulation'''
|
||||
opts = self.environment_params.copy()
|
||||
opts.update({
|
||||
'name': env_name,
|
||||
'name': trial_id,
|
||||
'topology': self.topology.copy(),
|
||||
'seed': self.seed+env_name,
|
||||
'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||
'initial_time': 0,
|
||||
'dry_run': self.dry_run,
|
||||
'interval': self.interval,
|
||||
'network_agents': self.network_agents,
|
||||
'initial_time': 0,
|
||||
'states': self.states,
|
||||
'default_state': self.default_state,
|
||||
'environment_agents': self.environment_agents,
|
||||
'dir_path': self.dir_path,
|
||||
})
|
||||
opts.update(kwargs)
|
||||
env=self.environment_class(**opts)
|
||||
env = self.environment_class(**opts)
|
||||
return env
|
||||
|
||||
def run_trial(self, trial_id = 0, until = None, return_env = True, **opts):
|
||||
"""Run a single trial of the simulation
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trial_id : int
|
||||
def run_trial(self, until=None, log_level=logging.INFO, **opts):
|
||||
"""
|
||||
Run a single trial of the simulation
|
||||
|
||||
"""
|
||||
trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-')
|
||||
if log_level:
|
||||
logger.setLevel(log_level)
|
||||
# Set-up trial environment and graph
|
||||
until=until or self.max_time
|
||||
env=self.get_env(trial_id = trial_id, **opts)
|
||||
until = until or self.max_time
|
||||
env = self.get_env(trial_id=trial_id, **opts)
|
||||
# Set up agents on nodes
|
||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||
env.run(until)
|
||||
if self.dump and not self.dry_run:
|
||||
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
||||
env.dump(formats = self.dump)
|
||||
if return_env:
|
||||
return env
|
||||
return env
|
||||
|
||||
def run_trial_exceptions(self, *args, **kwargs):
|
||||
'''
|
||||
A wrapper for run_trial that catches exceptions and returns them.
|
||||
@@ -200,9 +261,10 @@ class Simulation(NetworkSimulation):
|
||||
try:
|
||||
return self.run_trial(*args, **kwargs)
|
||||
except Exception as ex:
|
||||
c = ex.__cause__
|
||||
c.message = ''.join(traceback.format_exception(type(c), c, c.__traceback__)[:])
|
||||
return c
|
||||
if ex.__cause__ is not None:
|
||||
ex = ex.__cause__
|
||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||
return ex
|
||||
|
||||
def to_dict(self):
|
||||
return self.__getstate__()
|
||||
@@ -210,38 +272,39 @@ class Simulation(NetworkSimulation):
|
||||
def to_yaml(self):
|
||||
return yaml.dump(self.to_dict())
|
||||
|
||||
def dump_yaml(self, dir_path = None, file_name = None):
|
||||
dir_path=dir_path or self.dir_path
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
if not file_name:
|
||||
file_name=os.path.join(dir_path,
|
||||
'{}.dumped.yml'.format(self.name))
|
||||
with open(file_name, 'w') as f:
|
||||
|
||||
def dump_yaml(self, f=None, outdir=None):
|
||||
if not f and not outdir:
|
||||
raise ValueError('specify a file or an output directory')
|
||||
|
||||
if not f:
|
||||
f = os.path.join(outdir, '{}.dumped.yml'.format(self.name))
|
||||
|
||||
with utils.open_or_reuse(f, 'w') as f:
|
||||
f.write(self.to_yaml())
|
||||
|
||||
def dump_pickle(self, dir_path = None, pickle_name = None):
|
||||
dir_path=dir_path or self.dir_path
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
if not pickle_name:
|
||||
pickle_name=os.path.join(dir_path,
|
||||
'{}.simulation.pickle'.format(self.name))
|
||||
with open(pickle_name, 'wb') as f:
|
||||
def dump_pickle(self, f=None, outdir=None):
|
||||
if not outdir and not f:
|
||||
raise ValueError('specify a file or an output directory')
|
||||
|
||||
if not f:
|
||||
f = os.path.join(outdir,
|
||||
'{}.simulation.pickle'.format(self.name))
|
||||
with utils.open_or_reuse(f, 'wb') as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
def __getstate__(self):
|
||||
state={}
|
||||
for k, v in self.__dict__.items():
|
||||
if k[0] != '_':
|
||||
state[k]=v
|
||||
state['topology']=json_graph.node_link_data(self.topology)
|
||||
state['network_agents']=agents.serialize_distribution(self.network_agents,
|
||||
known_modules = [])
|
||||
state['environment_agents']=agents.serialize_distribution(self.environment_agents,
|
||||
known_modules = [])
|
||||
state['environment_class']=utils.serialize(self.environment_class,
|
||||
known_modules=['soil.environment'])[1] # func, name
|
||||
state[k] = v
|
||||
state['topology'] = json_graph.node_link_data(self.topology)
|
||||
state['network_agents'] = agents.serialize_distribution(self.network_agents,
|
||||
known_modules = [])
|
||||
state['environment_agents'] = agents.serialize_distribution(self.environment_agents,
|
||||
known_modules = [])
|
||||
state['environment_class'] = serialization.serialize(self.environment_class,
|
||||
known_modules=['soil.environment'])[1] # func, name
|
||||
if state['load_module'] is None:
|
||||
del state['load_module']
|
||||
return state
|
||||
@@ -255,13 +318,20 @@ class Simulation(NetworkSimulation):
|
||||
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
||||
self.environment_agents = agents._convert_agent_types(self.environment_agents,
|
||||
known_modules=[self.load_module])
|
||||
self.environment_class = utils.deserialize(self.environment_class,
|
||||
self.environment_class = serialization.deserialize(self.environment_class,
|
||||
known_modules=[self.load_module, 'soil.environment', ]) # func, name
|
||||
return state
|
||||
|
||||
|
||||
def from_config(config):
|
||||
config = list(utils.load_config(config))
|
||||
def all_from_config(config):
|
||||
configs = list(serialization.load_config(config))
|
||||
for config, _ in configs:
|
||||
sim = Simulation(**config)
|
||||
yield sim
|
||||
|
||||
|
||||
def from_config(conf_or_path):
|
||||
config = list(serialization.load_config(conf_or_path))
|
||||
if len(config) > 1:
|
||||
raise AttributeError('Provide only one configuration')
|
||||
config = config[0][0]
|
||||
@@ -269,21 +339,14 @@ def from_config(config):
|
||||
return sim
|
||||
|
||||
|
||||
def run_from_config(*configs, results_dir='soil_output', dump=None, timestamp=False, **kwargs):
|
||||
def run_from_config(*configs, **kwargs):
|
||||
for config_def in configs:
|
||||
# logger.info("Found {} config(s)".format(len(ls)))
|
||||
for config, _ in utils.load_config(config_def):
|
||||
for config, path in serialization.load_config(config_def):
|
||||
name = config.get('name', 'unnamed')
|
||||
logger.info("Using config(s): {name}".format(name=name))
|
||||
|
||||
if timestamp:
|
||||
sim_folder = '{}_{}'.format(name,
|
||||
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
||||
else:
|
||||
sim_folder = name
|
||||
dir_path = os.path.join(results_dir, sim_folder)
|
||||
if dump is not None:
|
||||
config['dump'] = dump
|
||||
sim = Simulation(dir_path=dir_path, **config)
|
||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, sim.dump))
|
||||
dir_path = config.pop('dir_path', os.path.dirname(path))
|
||||
sim = Simulation(dir_path=dir_path,
|
||||
**config)
|
||||
sim.run_simulation(**kwargs)
|
||||
|
||||
106
soil/stats.py
Normal file
106
soil/stats.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import pandas as pd
|
||||
|
||||
from collections import Counter
|
||||
|
||||
class Stats:
|
||||
'''
|
||||
Interface for all stats. It is not necessary, but it is useful
|
||||
if you don't plan to implement all the methods.
|
||||
'''
|
||||
|
||||
def __init__(self, simulation):
|
||||
self.simulation = simulation
|
||||
|
||||
def start(self):
|
||||
'''Method to call when the simulation starts'''
|
||||
pass
|
||||
|
||||
def end(self):
|
||||
'''Method to call when the simulation ends'''
|
||||
return {}
|
||||
|
||||
def trial(self, env):
|
||||
'''Method to call when a trial ends'''
|
||||
return {}
|
||||
|
||||
|
||||
class distribution(Stats):
|
||||
'''
|
||||
Calculate the distribution of agent states at the end of each trial,
|
||||
the mean value, and its deviation.
|
||||
'''
|
||||
|
||||
def start(self):
|
||||
self.means = []
|
||||
self.counts = []
|
||||
|
||||
def trial(self, env):
|
||||
df = env[None, None, None].df()
|
||||
df = df.drop('SEED', axis=1)
|
||||
ix = df.index[-1]
|
||||
attrs = df.columns.get_level_values(0)
|
||||
vc = {}
|
||||
stats = {
|
||||
'mean': {},
|
||||
'count': {},
|
||||
}
|
||||
for a in attrs:
|
||||
t = df.loc[(ix, a)]
|
||||
try:
|
||||
stats['mean'][a] = t.mean()
|
||||
self.means.append(('mean', a, t.mean()))
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
for name, count in t.value_counts().iteritems():
|
||||
if a not in stats['count']:
|
||||
stats['count'][a] = {}
|
||||
stats['count'][a][name] = count
|
||||
self.counts.append(('count', a, name, count))
|
||||
|
||||
return stats
|
||||
|
||||
def end(self):
|
||||
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
|
||||
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
|
||||
|
||||
count = {}
|
||||
mean = {}
|
||||
|
||||
if self.means:
|
||||
res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||
mean = res['value'].to_dict()
|
||||
if self.counts:
|
||||
res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||
for k,v in res['count'].to_dict().items():
|
||||
if k not in count:
|
||||
count[k] = {}
|
||||
for tup, times in v.items():
|
||||
subkey, subcount = tup
|
||||
if subkey not in count[k]:
|
||||
count[k][subkey] = {}
|
||||
count[k][subkey][subcount] = times
|
||||
|
||||
|
||||
return {'count': count, 'mean': mean}
|
||||
|
||||
|
||||
class defaultStats(Stats):
|
||||
|
||||
def trial(self, env):
|
||||
c = Counter()
|
||||
c.update(a.__class__.__name__ for a in env.network_agents)
|
||||
|
||||
c2 = Counter()
|
||||
c2.update(a['id'] for a in env.network_agents)
|
||||
|
||||
return {
|
||||
'network ': {
|
||||
'n_nodes': env.G.number_of_nodes(),
|
||||
'n_edges': env.G.number_of_nodes(),
|
||||
},
|
||||
'agents': {
|
||||
'model_count': dict(c),
|
||||
'state_count': dict(c2),
|
||||
}
|
||||
}
|
||||
187
soil/utils.py
187
soil/utils.py
@@ -1,66 +1,16 @@
|
||||
import os
|
||||
import ast
|
||||
import yaml
|
||||
import logging
|
||||
import importlib
|
||||
import time
|
||||
from glob import glob
|
||||
from random import random
|
||||
from copy import deepcopy
|
||||
import os
|
||||
|
||||
import networkx as nx
|
||||
from shutil import copyfile
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
logger = logging.getLogger('soil')
|
||||
logging.basicConfig()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def load_network(network_params, dir_path=None):
|
||||
if network_params is None:
|
||||
return nx.Graph()
|
||||
path = network_params.get('path', None)
|
||||
if path:
|
||||
if dir_path and not os.path.isabs(path):
|
||||
path = os.path.join(dir_path, path)
|
||||
extension = os.path.splitext(path)[1][1:]
|
||||
kwargs = {}
|
||||
if extension == 'gexf':
|
||||
kwargs['version'] = '1.2draft'
|
||||
kwargs['node_type'] = int
|
||||
try:
|
||||
method = getattr(nx.readwrite, 'read_' + extension)
|
||||
except AttributeError:
|
||||
raise AttributeError('Unknown format')
|
||||
return method(path, **kwargs)
|
||||
|
||||
net_args = network_params.copy()
|
||||
net_type = net_args.pop('generator')
|
||||
|
||||
method = getattr(nx.generators, net_type)
|
||||
return method(**net_args)
|
||||
|
||||
|
||||
def load_file(infile):
|
||||
with open(infile, 'r') as f:
|
||||
return list(yaml.load_all(f))
|
||||
|
||||
|
||||
def load_files(*patterns):
|
||||
for pattern in patterns:
|
||||
for i in glob(pattern):
|
||||
for config in load_file(i):
|
||||
yield config, os.path.abspath(i)
|
||||
|
||||
|
||||
def load_config(config):
|
||||
if isinstance(config, dict):
|
||||
yield config, None
|
||||
else:
|
||||
yield from load_files(config)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
start = time.time()
|
||||
@@ -76,79 +26,62 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
to_object.end = end
|
||||
|
||||
|
||||
builtins = importlib.import_module('builtins')
|
||||
def safe_open(path, mode='r', backup=True, **kwargs):
|
||||
outdir = os.path.dirname(path)
|
||||
if outdir and not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
if backup and 'w' in mode and os.path.exists(path):
|
||||
creation = os.path.getctime(path)
|
||||
stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation))
|
||||
|
||||
def name(value, known_modules=[]):
|
||||
'''Return a name that can be imported, to serialize/deserialize an object'''
|
||||
if value is None:
|
||||
return 'None'
|
||||
if not isinstance(value, type): # Get the class name first
|
||||
value = type(value)
|
||||
tname = value.__name__
|
||||
if hasattr(builtins, tname):
|
||||
return tname
|
||||
modname = value.__module__
|
||||
if modname == '__main__':
|
||||
return tname
|
||||
if known_modules and modname in known_modules:
|
||||
return tname
|
||||
for kmod in known_modules:
|
||||
if not kmod:
|
||||
backup_dir = os.path.join(outdir, 'backup')
|
||||
if not os.path.exists(backup_dir):
|
||||
os.makedirs(backup_dir)
|
||||
newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path),
|
||||
stamp))
|
||||
copyfile(path, newpath)
|
||||
return open(path, mode=mode, **kwargs)
|
||||
|
||||
|
||||
def open_or_reuse(f, *args, **kwargs):
|
||||
try:
|
||||
return safe_open(f, *args, **kwargs)
|
||||
except (AttributeError, TypeError):
|
||||
return f
|
||||
|
||||
def flatten_dict(d):
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
return dict(_flatten_dict(d))
|
||||
|
||||
def _flatten_dict(d, prefix=''):
|
||||
if not isinstance(d, dict):
|
||||
# print('END:', prefix, d)
|
||||
yield prefix, d
|
||||
return
|
||||
if prefix:
|
||||
prefix = prefix + '.'
|
||||
for k, v in d.items():
|
||||
# print(k, v)
|
||||
res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k)))
|
||||
# print('RES:', res)
|
||||
yield from res
|
||||
|
||||
|
||||
def unflatten_dict(d):
|
||||
out = {}
|
||||
for k, v in d.items():
|
||||
target = out
|
||||
if not isinstance(k, str):
|
||||
target[k] = v
|
||||
continue
|
||||
module = importlib.import_module(kmod)
|
||||
if hasattr(module, tname):
|
||||
return tname
|
||||
return '{}.{}'.format(modname, tname)
|
||||
|
||||
|
||||
def serializer(type_):
|
||||
if type_ != 'str' and hasattr(builtins, type_):
|
||||
return repr
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def serialize(v, known_modules=[]):
|
||||
'''Get a text representation of an object.'''
|
||||
tname = name(v, known_modules=known_modules)
|
||||
func = serializer(tname)
|
||||
return func(v), tname
|
||||
|
||||
def deserializer(type_, known_modules=[]):
|
||||
if type_ == 'str':
|
||||
return lambda x='': x
|
||||
if type_ == 'None':
|
||||
return lambda x=None: None
|
||||
if hasattr(builtins, type_): # Check if it's a builtin type
|
||||
cls = getattr(builtins, type_)
|
||||
return lambda x=None: ast.literal_eval(x) if x is not None else cls()
|
||||
# Otherwise, see if we can find the module and the class
|
||||
modules = known_modules or []
|
||||
options = []
|
||||
|
||||
for mod in modules:
|
||||
if mod:
|
||||
options.append((mod, type_))
|
||||
|
||||
if '.' in type_: # Fully qualified module
|
||||
module, type_ = type_.rsplit(".", 1)
|
||||
options.append ((module, type_))
|
||||
|
||||
errors = []
|
||||
for modname, tname in options:
|
||||
try:
|
||||
module = importlib.import_module(modname)
|
||||
cls = getattr(module, tname)
|
||||
return getattr(cls, 'deserialize', cls)
|
||||
except (ImportError, AttributeError) as ex:
|
||||
errors.append((modname, tname, ex))
|
||||
raise Exception('Could not find type {}. Tried: {}'.format(type_, errors))
|
||||
|
||||
|
||||
def deserialize(type_, value=None, **kwargs):
|
||||
'''Get an object from a text representation'''
|
||||
if not isinstance(type_, str):
|
||||
return type_
|
||||
des = deserializer(type_, **kwargs)
|
||||
if value is None:
|
||||
return des
|
||||
return des(value)
|
||||
tokens = k.split('.')
|
||||
if len(tokens) < 2:
|
||||
target[k] = v
|
||||
continue
|
||||
for token in tokens[:-1]:
|
||||
if token not in target:
|
||||
target[token] = {}
|
||||
target = target[token]
|
||||
target[tokens[-1]] = v
|
||||
return out
|
||||
|
||||
@@ -1,255 +0,0 @@
|
||||
import random
|
||||
import networkx as nx
|
||||
from soil.agents import BaseAgent, FSM, state, default_state
|
||||
from scipy.spatial import cKDTree as KDTree
|
||||
|
||||
global betweenness_centrality_global
|
||||
global degree_centrality_global
|
||||
|
||||
betweenness_centrality_global = None
|
||||
degree_centrality_global = None
|
||||
|
||||
class TerroristSpreadModel(FSM):
|
||||
"""
|
||||
Settings:
|
||||
information_spread_intensity
|
||||
|
||||
terrorist_additional_influence
|
||||
|
||||
min_vulnerability (optional else zero)
|
||||
|
||||
max_vulnerability
|
||||
|
||||
prob_interaction
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
|
||||
global betweenness_centrality_global
|
||||
global degree_centrality_global
|
||||
|
||||
if betweenness_centrality_global == None:
|
||||
betweenness_centrality_global = nx.betweenness_centrality(self.global_topology)
|
||||
if degree_centrality_global == None:
|
||||
degree_centrality_global = nx.degree_centrality(self.global_topology)
|
||||
|
||||
self.information_spread_intensity = environment.environment_params['information_spread_intensity']
|
||||
self.terrorist_additional_influence = environment.environment_params['terrorist_additional_influence']
|
||||
self.prob_interaction = environment.environment_params['prob_interaction']
|
||||
|
||||
if self['id'] == self.civilian.id: # Civilian
|
||||
self.initial_belief = random.uniform(0.00, 0.5)
|
||||
elif self['id'] == self.terrorist.id: # Terrorist
|
||||
self.initial_belief = random.uniform(0.8, 1.00)
|
||||
elif self['id'] == self.leader.id: # Leader
|
||||
self.initial_belief = 1.00
|
||||
else:
|
||||
raise Exception('Invalid state id: {}'.format(self['id']))
|
||||
|
||||
if 'min_vulnerability' in environment.environment_params:
|
||||
self.vulnerability = random.uniform( environment.environment_params['min_vulnerability'], environment.environment_params['max_vulnerability'] )
|
||||
else :
|
||||
self.vulnerability = random.uniform( 0, environment.environment_params['max_vulnerability'] )
|
||||
|
||||
self.mean_belief = self.initial_belief
|
||||
self.betweenness_centrality = betweenness_centrality_global[self.id]
|
||||
self.degree_centrality = degree_centrality_global[self.id]
|
||||
|
||||
# self.state['radicalism'] = self.mean_belief
|
||||
|
||||
def count_neighboring_agents(self, state_id=None):
|
||||
if isinstance(state_id, list):
|
||||
return len(self.get_neighboring_agents(state_id))
|
||||
else:
|
||||
return len(super().get_agents(state_id, limit_neighbors=True))
|
||||
|
||||
def get_neighboring_agents(self, state_id=None):
|
||||
if isinstance(state_id, list):
|
||||
_list = []
|
||||
for i in state_id:
|
||||
_list += super().get_agents(i, limit_neighbors=True)
|
||||
return [ neighbour for neighbour in _list if isinstance(neighbour, TerroristSpreadModel) ]
|
||||
else:
|
||||
_list = super().get_agents(state_id, limit_neighbors=True)
|
||||
return [ neighbour for neighbour in _list if isinstance(neighbour, TerroristSpreadModel) ]
|
||||
|
||||
@state
|
||||
def civilian(self):
|
||||
if self.count_neighboring_agents() > 0:
|
||||
neighbours = []
|
||||
for neighbour in self.get_neighboring_agents():
|
||||
if random.random() < self.prob_interaction:
|
||||
neighbours.append(neighbour)
|
||||
influence = sum( neighbour.degree_centrality for neighbour in neighbours )
|
||||
mean_belief = sum( neighbour.mean_belief * neighbour.degree_centrality / influence for neighbour in neighbours )
|
||||
self.initial_belief = self.mean_belief
|
||||
mean_belief = mean_belief * self.information_spread_intensity + self.initial_belief * ( 1 - self.information_spread_intensity )
|
||||
self.mean_belief = mean_belief * self.vulnerability + self.initial_belief * ( 1 - self.vulnerability )
|
||||
|
||||
if self.mean_belief >= 0.8:
|
||||
return self.terrorist
|
||||
|
||||
@state
|
||||
def leader(self):
|
||||
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
||||
if self.count_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]) > 0:
|
||||
for neighbour in self.get_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]):
|
||||
if neighbour.betweenness_centrality > self.betweenness_centrality:
|
||||
return self.terrorist
|
||||
|
||||
@state
|
||||
def terrorist(self):
|
||||
if self.count_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]) > 0:
|
||||
neighbours = self.get_neighboring_agents(state_id=[self.terrorist.id, self.leader.id])
|
||||
influence = sum( neighbour.degree_centrality for neighbour in neighbours )
|
||||
mean_belief = sum( neighbour.mean_belief * neighbour.degree_centrality / influence for neighbour in neighbours )
|
||||
self.initial_belief = self.mean_belief
|
||||
self.mean_belief = mean_belief * self.vulnerability + self.initial_belief * ( 1 - self.vulnerability )
|
||||
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
||||
|
||||
if self.count_neighboring_agents(state_id=self.leader.id) == 0 and self.count_neighboring_agents(state_id=self.terrorist.id) > 0:
|
||||
max_betweenness_centrality = self
|
||||
for neighbour in self.get_neighboring_agents(state_id=self.terrorist.id):
|
||||
if neighbour.betweenness_centrality > max_betweenness_centrality.betweenness_centrality:
|
||||
max_betweenness_centrality = neighbour
|
||||
if max_betweenness_centrality == self:
|
||||
return self.leader
|
||||
|
||||
def add_edge(self, G, source, target):
|
||||
G.add_edge(source.id, target.id, start=self.env._now)
|
||||
|
||||
def link_search(self, G, node, radius):
|
||||
pos = nx.get_node_attributes(G, 'pos')
|
||||
nodes, coords = list(zip(*pos.items()))
|
||||
kdtree = KDTree(coords) # Cannot provide generator.
|
||||
edge_indexes = kdtree.query_pairs(radius, 2)
|
||||
_list = [ edge[int(not edge.index(node))] for edge in edge_indexes if node in edge ]
|
||||
return [ G.nodes()[index]['agent'] for index in _list ]
|
||||
|
||||
def social_search(self, G, node, steps):
|
||||
nodes = list(nx.ego_graph(G, node, radius=steps).nodes())
|
||||
nodes.remove(node)
|
||||
return [ G.nodes()[index]['agent'] for index in nodes ]
|
||||
|
||||
|
||||
class TrainingAreaModel(FSM):
|
||||
"""
|
||||
Settings:
|
||||
training_influence
|
||||
|
||||
min_vulnerability
|
||||
|
||||
Requires TerroristSpreadModel.
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
self.training_influence = environment.environment_params['training_influence']
|
||||
if 'min_vulnerability' in environment.environment_params:
|
||||
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
||||
else: self.min_vulnerability = 0
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def terrorist(self):
|
||||
for neighbour in self.get_neighboring_agents():
|
||||
if isinstance(neighbour, TerroristSpreadModel) and neighbour.vulnerability > self.min_vulnerability:
|
||||
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.training_influence )
|
||||
|
||||
|
||||
class HavenModel(FSM):
|
||||
"""
|
||||
Settings:
|
||||
haven_influence
|
||||
|
||||
min_vulnerability
|
||||
|
||||
max_vulnerability
|
||||
|
||||
Requires TerroristSpreadModel.
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
self.haven_influence = environment.environment_params['haven_influence']
|
||||
if 'min_vulnerability' in environment.environment_params:
|
||||
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
||||
else: self.min_vulnerability = 0
|
||||
self.max_vulnerability = environment.environment_params['max_vulnerability']
|
||||
|
||||
@state
|
||||
def civilian(self):
|
||||
for neighbour_agent in self.get_neighboring_agents():
|
||||
if isinstance(neighbour_agent, TerroristSpreadModel) and neighbour_agent['id'] == neighbour_agent.civilian.id:
|
||||
for neighbour in self.get_neighboring_agents():
|
||||
if isinstance(neighbour, TerroristSpreadModel) and neighbour.vulnerability > self.min_vulnerability:
|
||||
neighbour.vulnerability = neighbour.vulnerability * ( 1 - self.haven_influence )
|
||||
return self.civilian
|
||||
return self.terrorist
|
||||
|
||||
@state
|
||||
def terrorist(self):
|
||||
for neighbour in self.get_neighboring_agents():
|
||||
if isinstance(neighbour, TerroristSpreadModel) and neighbour.vulnerability < self.max_vulnerability:
|
||||
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.haven_influence )
|
||||
return self.terrorist
|
||||
|
||||
|
||||
class TerroristNetworkModel(TerroristSpreadModel):
|
||||
"""
|
||||
Settings:
|
||||
sphere_influence
|
||||
|
||||
vision_range
|
||||
|
||||
weight_social_distance
|
||||
|
||||
weight_link_distance
|
||||
"""
|
||||
|
||||
def __init__(self, environment=None, agent_id=0, state=()):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
|
||||
self.vision_range = environment.environment_params['vision_range']
|
||||
self.sphere_influence = environment.environment_params['sphere_influence']
|
||||
self.weight_social_distance = environment.environment_params['weight_social_distance']
|
||||
self.weight_link_distance = environment.environment_params['weight_link_distance']
|
||||
|
||||
@state
|
||||
def terrorist(self):
|
||||
self.update_relationships()
|
||||
return super().terrorist()
|
||||
|
||||
@state
|
||||
def leader(self):
|
||||
self.update_relationships()
|
||||
return super().leader()
|
||||
|
||||
def update_relationships(self):
|
||||
if self.count_neighboring_agents(state_id=self.civilian.id) == 0:
|
||||
close_ups = self.link_search(self.global_topology, self.id, self.vision_range)
|
||||
step_neighbours = self.social_search(self.global_topology, self.id, self.sphere_influence)
|
||||
search = list(set(close_ups).union(step_neighbours))
|
||||
neighbours = self.get_neighboring_agents()
|
||||
search = [item for item in search if not item in neighbours and isinstance(item, TerroristNetworkModel)]
|
||||
for agent in search:
|
||||
social_distance = 1 / self.shortest_path_length(self.global_topology, self.id, agent.id)
|
||||
spatial_proximity = ( 1 - self.get_distance(self.global_topology, self.id, agent.id) )
|
||||
prob_new_interaction = self.weight_social_distance * social_distance + self.weight_link_distance * spatial_proximity
|
||||
if agent['id'] == agent.civilian.id and random.random() < prob_new_interaction:
|
||||
self.add_edge(self.global_topology, self, agent)
|
||||
break
|
||||
|
||||
def get_distance(self, G, source, target):
|
||||
source_x, source_y = nx.get_node_attributes(G, 'pos')[source]
|
||||
target_x, target_y = nx.get_node_attributes(G, 'pos')[target]
|
||||
dx = abs( source_x - target_x )
|
||||
dy = abs( source_y - target_y )
|
||||
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
||||
|
||||
def shortest_path_length(self, G, source, target):
|
||||
try:
|
||||
return nx.shortest_path_length(G, source, target)
|
||||
except nx.NetworkXNoPath:
|
||||
return float('inf')
|
||||
@@ -118,9 +118,9 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
||||
elif msg['type'] == 'download_gexf':
|
||||
G = self.trials[ int(msg['data']) ].history_to_graph()
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.node[node]:
|
||||
G.node[node]['viz'] = {"position": {"x": G.node[node]['pos'][0], "y": G.node[node]['pos'][1], "z": 0.0}}
|
||||
del (G.node[node]['pos'])
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
writer = nx.readwrite.gexf.GEXFWriter(version='1.2draft')
|
||||
writer.add_graph(G)
|
||||
self.write_message({'type': 'download_gexf',
|
||||
@@ -130,9 +130,9 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
||||
elif msg['type'] == 'download_json':
|
||||
G = self.trials[ int(msg['data']) ].history_to_graph()
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.node[node]:
|
||||
G.node[node]['viz'] = {"position": {"x": G.node[node]['pos'][0], "y": G.node[node]['pos'][1], "z": 0.0}}
|
||||
del (G.node[node]['pos'])
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
self.write_message({'type': 'download_json',
|
||||
'filename': self.config['name'] + '_trial_' + str(msg['data']),
|
||||
'data': nx.node_link_data(G) })
|
||||
@@ -180,7 +180,7 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
||||
with self.logging(self.simulation_name):
|
||||
try:
|
||||
config = dict(**self.config)
|
||||
config['dir_path'] = os.path.join(self.application.dir_path, config['name'])
|
||||
config['outdir'] = os.path.join(self.application.outdir, config['name'])
|
||||
config['dump'] = self.application.dump
|
||||
self.trials = yield self.nonblocking(config)
|
||||
|
||||
@@ -232,12 +232,12 @@ class ModularServer(tornado.web.Application):
|
||||
settings = {'debug': True,
|
||||
'template_path': ROOT + '/templates'}
|
||||
|
||||
def __init__(self, dump=False, dir_path='output', name='SOIL', verbose=True, *args, **kwargs):
|
||||
def __init__(self, dump=False, outdir='output', name='SOIL', verbose=True, *args, **kwargs):
|
||||
|
||||
self.verbose = verbose
|
||||
self.name = name
|
||||
self.dump = dump
|
||||
self.dir_path = dir_path
|
||||
self.outdir = outdir
|
||||
|
||||
# Initializing the application itself:
|
||||
super().__init__(self.handlers, **self.settings)
|
||||
@@ -271,4 +271,4 @@ def main():
|
||||
parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
run(name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)
|
||||
run(name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)
|
||||
|
||||
@@ -39,7 +39,6 @@ class TestAnalysis(TestCase):
|
||||
agent should be able to update its state."""
|
||||
config = {
|
||||
'name': 'analysis',
|
||||
'dry_run': True,
|
||||
'seed': 'seed',
|
||||
'network_params': {
|
||||
'generator': 'complete_graph',
|
||||
@@ -53,7 +52,7 @@ class TestAnalysis(TestCase):
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
self.env = s.run_simulation()[0]
|
||||
self.env = s.run_simulation(dry_run=True)[0]
|
||||
|
||||
def test_saved(self):
|
||||
env = self.env
|
||||
@@ -65,10 +64,10 @@ class TestAnalysis(TestCase):
|
||||
|
||||
def test_count(self):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history._db)
|
||||
df = analysis.read_sql(env._history.db_path)
|
||||
res = analysis.get_count(df, 'SEED', 'id')
|
||||
assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
|
||||
assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
|
||||
assert res['SEED'][self.env['SEED']].iloc[0] == 1
|
||||
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
|
||||
assert res['id']['odd'].iloc[0] == 2
|
||||
assert res['id']['even'].iloc[0] == 0
|
||||
assert res['id']['odd'].iloc[-1] == 1
|
||||
@@ -76,7 +75,7 @@ class TestAnalysis(TestCase):
|
||||
|
||||
def test_value(self):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history._db)
|
||||
df = analysis.read_sql(env._history.db_path)
|
||||
res_sum = analysis.get_value(df, 'count')
|
||||
|
||||
assert res_sum['count'].iloc[0] == 2
|
||||
@@ -87,4 +86,4 @@ class TestAnalysis(TestCase):
|
||||
|
||||
res_total = analysis.get_value(df)
|
||||
|
||||
res_total['SEED'].iloc[0] == 'seedanalysis_trial_0'
|
||||
res_total['SEED'].iloc[0] == self.env['SEED']
|
||||
|
||||
@@ -2,11 +2,13 @@ from unittest import TestCase
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
from soil import utils, simulation
|
||||
from soil import serialization, simulation
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
EXAMPLES = join(ROOT, '..', 'examples')
|
||||
|
||||
FORCE_TESTS = os.environ.get('FORCE_TESTS', '')
|
||||
|
||||
|
||||
class TestExamples(TestCase):
|
||||
pass
|
||||
@@ -15,28 +17,32 @@ class TestExamples(TestCase):
|
||||
def make_example_test(path, config):
|
||||
def wrapped(self):
|
||||
root = os.getcwd()
|
||||
os.chdir(os.path.dirname(path))
|
||||
s = simulation.from_config(config)
|
||||
iterations = s.max_time * s.num_trials
|
||||
if iterations > 1000:
|
||||
self.skipTest('This example would probably take too long')
|
||||
envs = s.run_simulation(dry_run=True)
|
||||
assert envs
|
||||
for env in envs:
|
||||
assert env
|
||||
try:
|
||||
n = config['network_params']['n']
|
||||
assert len(list(env.network_agents)) == n
|
||||
assert env.now > 2 # It has run
|
||||
assert env.now <= config['max_time'] # But not further than allowed
|
||||
except KeyError:
|
||||
pass
|
||||
os.chdir(root)
|
||||
for s in simulation.all_from_config(path):
|
||||
iterations = s.max_time * s.num_trials
|
||||
if iterations > 1000:
|
||||
s.max_time = 100
|
||||
s.num_trials = 1
|
||||
if config.get('skip_test', False) and not FORCE_TESTS:
|
||||
self.skipTest('Example ignored.')
|
||||
envs = s.run_simulation(dry_run=True)
|
||||
assert envs
|
||||
for env in envs:
|
||||
assert env
|
||||
try:
|
||||
n = config['network_params']['n']
|
||||
assert len(list(env.network_agents)) == n
|
||||
assert env.now > 0 # It has run
|
||||
assert env.now <= config['max_time'] # But not further than allowed
|
||||
except KeyError:
|
||||
pass
|
||||
return wrapped
|
||||
|
||||
|
||||
def add_example_tests():
|
||||
for config, path in utils.load_config(join(EXAMPLES, '**', '*.yml')):
|
||||
for config, path in serialization.load_files(
|
||||
join(EXAMPLES, '*', '*.yml'),
|
||||
join(EXAMPLES, '*.yml'),
|
||||
):
|
||||
p = make_example_test(path=path, config=config)
|
||||
fname = os.path.basename(path)
|
||||
p.__name__ = 'test_example_file_%s' % fname
|
||||
|
||||
101
tests/test_exporters.py
Normal file
101
tests/test_exporters.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import os
|
||||
import io
|
||||
import tempfile
|
||||
import shutil
|
||||
from time import time
|
||||
|
||||
from unittest import TestCase
|
||||
from soil import exporters
|
||||
from soil import simulation
|
||||
|
||||
from soil.stats import distribution
|
||||
|
||||
class Dummy(exporters.Exporter):
|
||||
started = False
|
||||
trials = 0
|
||||
ended = False
|
||||
total_time = 0
|
||||
called_start = 0
|
||||
called_trial = 0
|
||||
called_end = 0
|
||||
|
||||
def start(self):
|
||||
self.__class__.called_start += 1
|
||||
self.__class__.started = True
|
||||
|
||||
def trial(self, env, stats):
|
||||
assert env
|
||||
self.__class__.trials += 1
|
||||
self.__class__.total_time += env.now
|
||||
self.__class__.called_trial += 1
|
||||
|
||||
def end(self, stats):
|
||||
self.__class__.ended = True
|
||||
self.__class__.called_end += 1
|
||||
|
||||
|
||||
class Exporters(TestCase):
|
||||
def test_basic(self):
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': 5,
|
||||
'environment_params': {}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
for env in s.run_simulation(exporters=[Dummy], dry_run=True):
|
||||
assert env.now <= 2
|
||||
|
||||
assert Dummy.started
|
||||
assert Dummy.ended
|
||||
assert Dummy.called_start == 1
|
||||
assert Dummy.called_end == 1
|
||||
assert Dummy.called_trial == 5
|
||||
assert Dummy.trials == 5
|
||||
assert Dummy.total_time == 2*5
|
||||
|
||||
def test_writing(self):
|
||||
'''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
|
||||
n_trials = 5
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {
|
||||
'generator': 'complete_graph',
|
||||
'n': 4
|
||||
},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': n_trials,
|
||||
'environment_params': {}
|
||||
}
|
||||
output = io.StringIO()
|
||||
s = simulation.from_config(config)
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
envs = s.run_simulation(exporters=[
|
||||
exporters.default,
|
||||
exporters.csv,
|
||||
exporters.gexf,
|
||||
],
|
||||
stats=[distribution,],
|
||||
outdir=tmpdir,
|
||||
exporter_params={'copy_to': output})
|
||||
result = output.getvalue()
|
||||
|
||||
simdir = os.path.join(tmpdir, s.group or '', s.name)
|
||||
with open(os.path.join(simdir, '{}.dumped.yml'.format(s.name))) as f:
|
||||
result = f.read()
|
||||
assert result
|
||||
|
||||
try:
|
||||
for e in envs:
|
||||
with open(os.path.join(simdir, '{}.gexf'.format(e.name))) as f:
|
||||
result = f.read()
|
||||
assert result
|
||||
|
||||
with open(os.path.join(simdir, '{}.csv'.format(e.name))) as f:
|
||||
result = f.read()
|
||||
assert result
|
||||
finally:
|
||||
shutil.rmtree(tmpdir)
|
||||
@@ -5,6 +5,7 @@ import shutil
|
||||
from glob import glob
|
||||
|
||||
from soil import history
|
||||
from soil import utils
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
@@ -120,18 +121,18 @@ class TestHistory(TestCase):
|
||||
assert os.path.exists(db_path)
|
||||
|
||||
# Recover the data
|
||||
recovered = history.History(db_path=db_path, backup=False)
|
||||
recovered = history.History(db_path=db_path)
|
||||
assert recovered['a_1', 0, 'id'] == 'v'
|
||||
assert recovered['a_1', 4, 'id'] == 'e'
|
||||
|
||||
# Using the same name should create a backup copy
|
||||
# Using backup=True should create a backup copy, and initialize an empty history
|
||||
newhistory = history.History(db_path=db_path, backup=True)
|
||||
backuppaths = glob(db_path + '.backup*.sqlite')
|
||||
assert len(backuppaths) == 1
|
||||
backuppath = backuppaths[0]
|
||||
assert newhistory.db_path == h.db_path
|
||||
assert os.path.exists(backuppath)
|
||||
assert not len(newhistory[None, None, None])
|
||||
assert len(newhistory[None, None, None]) == 0
|
||||
|
||||
def test_history_tuples(self):
|
||||
"""
|
||||
@@ -154,3 +155,49 @@ class TestHistory(TestCase):
|
||||
assert recovered
|
||||
for i in recovered:
|
||||
assert i in tuples
|
||||
|
||||
def test_stats(self):
|
||||
"""
|
||||
The data recovered should be equal to the one recorded.
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
stat_tuples = [
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'new': '40'},
|
||||
]
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
for stat in stat_tuples:
|
||||
h.save_stats(stat)
|
||||
recovered = h.get_stats()
|
||||
assert recovered
|
||||
assert recovered[0]['num_infected'] == 5
|
||||
assert recovered[1]['runtime'] == 0.2
|
||||
assert recovered[2]['new'] == '40'
|
||||
|
||||
def test_unflatten(self):
|
||||
ex = {'count.neighbors.3': 4,
|
||||
'count.times.2': 4,
|
||||
'count.total.4': 4,
|
||||
'mean.neighbors': 3,
|
||||
'mean.times': 2,
|
||||
'mean.total': 4,
|
||||
't_step': 2,
|
||||
'trial_id': 'exporter_sim_trial_1605817956-4475424'}
|
||||
res = utils.unflatten_dict(ex)
|
||||
|
||||
assert 'count' in res
|
||||
assert 'mean' in res
|
||||
assert 't_step' in res
|
||||
assert 'trial_id' in res
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
from unittest import TestCase
|
||||
|
||||
import os
|
||||
import io
|
||||
import yaml
|
||||
import pickle
|
||||
import networkx as nx
|
||||
from functools import partial
|
||||
|
||||
from os.path import join
|
||||
from soil import simulation, Environment, agents, utils, history
|
||||
from soil import (simulation, Environment, agents, serialization,
|
||||
history, utils)
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
EXAMPLES = join(ROOT, '..', 'examples')
|
||||
|
||||
|
||||
class CustomAgent(agents.BaseAgent):
|
||||
def step(self):
|
||||
self.state['neighbors'] = self.count_agents(state_id=0,
|
||||
class CustomAgent(agents.FSM):
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def normal(self):
|
||||
self.state['neighbors'] = self.count_agents(state_id='normal',
|
||||
limit_neighbors=True)
|
||||
@agents.state
|
||||
def unreachable(self):
|
||||
return
|
||||
|
||||
class TestMain(TestCase):
|
||||
|
||||
@@ -27,22 +34,20 @@ class TestMain(TestCase):
|
||||
Raise an exception otherwise.
|
||||
"""
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
}
|
||||
}
|
||||
G = utils.load_network(config['network_params'])
|
||||
G = serialization.load_network(config['network_params'])
|
||||
assert G
|
||||
assert len(G) == 2
|
||||
with self.assertRaises(AttributeError):
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'unknown.extension')
|
||||
}
|
||||
}
|
||||
G = utils.load_network(config['network_params'])
|
||||
G = serialization.load_network(config['network_params'])
|
||||
print(G)
|
||||
|
||||
def test_generate_barabasi(self):
|
||||
@@ -51,22 +56,20 @@ class TestMain(TestCase):
|
||||
should be used to generate a network
|
||||
"""
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'generator': 'barabasi_albert_graph'
|
||||
}
|
||||
}
|
||||
with self.assertRaises(TypeError):
|
||||
G = utils.load_network(config['network_params'])
|
||||
G = serialization.load_network(config['network_params'])
|
||||
config['network_params']['n'] = 100
|
||||
config['network_params']['m'] = 10
|
||||
G = utils.load_network(config['network_params'])
|
||||
G = serialization.load_network(config['network_params'])
|
||||
assert len(G) == 100
|
||||
|
||||
def test_empty_simulation(self):
|
||||
"""A simulation with a base behaviour should do nothing"""
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
@@ -83,7 +86,6 @@ class TestMain(TestCase):
|
||||
agent should be able to update its state."""
|
||||
config = {
|
||||
'name': 'CounterAgent',
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
@@ -107,7 +109,6 @@ class TestMain(TestCase):
|
||||
"""
|
||||
config = {
|
||||
'name': 'CounterAgent',
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
@@ -133,14 +134,12 @@ class TestMain(TestCase):
|
||||
def test_custom_agent(self):
|
||||
"""Allow for search of neighbors with a certain state_id"""
|
||||
config = {
|
||||
'dry_run': True,
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
'network_agents': [{
|
||||
'agent_type': CustomAgent,
|
||||
'weight': 1,
|
||||
'state': {'id': 0}
|
||||
'weight': 1
|
||||
|
||||
}],
|
||||
'max_time': 10,
|
||||
@@ -150,15 +149,17 @@ class TestMain(TestCase):
|
||||
s = simulation.from_config(config)
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
assert env.get_agent(0).state['neighbors'] == 1
|
||||
assert env.get_agent(0).state['neighbors'] == 1
|
||||
assert env.get_agent(1).count_agents(state_id='normal') == 2
|
||||
assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||
|
||||
def test_torvalds_example(self):
|
||||
"""A complete example from a documentation should work."""
|
||||
config = utils.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
|
||||
config = serialization.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
|
||||
config['network_params']['path'] = join(EXAMPLES,
|
||||
config['network_params']['path'])
|
||||
s = simulation.from_config(config)
|
||||
s.dry_run = True
|
||||
env = s.run_simulation()[0]
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
for a in env.network_agents:
|
||||
skill_level = a.state['skill_level']
|
||||
if a.id == 'Torvalds':
|
||||
@@ -180,13 +181,12 @@ class TestMain(TestCase):
|
||||
should be equivalent to the configuration file used
|
||||
"""
|
||||
with utils.timer('loading'):
|
||||
config = utils.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
s = simulation.from_config(config)
|
||||
s.dry_run = True
|
||||
with utils.timer('serializing'):
|
||||
serial = s.to_yaml()
|
||||
with utils.timer('recovering'):
|
||||
recovered = yaml.load(serial)
|
||||
recovered = yaml.load(serial, Loader=yaml.SafeLoader)
|
||||
with utils.timer('deleting'):
|
||||
del recovered['topology']
|
||||
assert config == recovered
|
||||
@@ -196,9 +196,8 @@ class TestMain(TestCase):
|
||||
The configuration should not change after running
|
||||
the simulation.
|
||||
"""
|
||||
config = utils.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
s = simulation.from_config(config)
|
||||
s.dry_run = True
|
||||
for i in range(5):
|
||||
s.run_simulation(dry_run=True)
|
||||
nconfig = s.to_dict()
|
||||
@@ -206,7 +205,7 @@ class TestMain(TestCase):
|
||||
assert config == nconfig
|
||||
|
||||
def test_row_conversion(self):
|
||||
env = Environment(dry_run=True)
|
||||
env = Environment()
|
||||
env['test'] = 'test_value'
|
||||
|
||||
res = list(env.history_to_tuples())
|
||||
@@ -225,8 +224,9 @@ class TestMain(TestCase):
|
||||
from geometric models. We should work around it.
|
||||
"""
|
||||
G = nx.random_geometric_graph(20, 0.1)
|
||||
env = Environment(topology=G, dry_run=True)
|
||||
env.dump_gexf('/tmp/dump-gexf')
|
||||
env = Environment(topology=G)
|
||||
f = io.BytesIO()
|
||||
env.dump_gexf(f)
|
||||
|
||||
def test_save_graph(self):
|
||||
'''
|
||||
@@ -236,20 +236,20 @@ class TestMain(TestCase):
|
||||
'''
|
||||
G = nx.cycle_graph(5)
|
||||
distribution = agents.calculate_distribution(None, agents.BaseAgent)
|
||||
env = Environment(topology=G, network_agents=distribution, dry_run=True)
|
||||
env = Environment(topology=G, network_agents=distribution)
|
||||
env[0, 0, 'testvalue'] = 'start'
|
||||
env[0, 10, 'testvalue'] = 'finish'
|
||||
nG = env.history_to_graph()
|
||||
values = nG.node[0]['attr_testvalue']
|
||||
values = nG.nodes[0]['attr_testvalue']
|
||||
assert ('start', 0, 10) in values
|
||||
assert ('finish', 10, None) in values
|
||||
|
||||
def test_serialize_class(self):
|
||||
ser, name = utils.serialize(agents.BaseAgent)
|
||||
ser, name = serialization.serialize(agents.BaseAgent)
|
||||
assert name == 'soil.agents.BaseAgent'
|
||||
assert ser == agents.BaseAgent
|
||||
|
||||
ser, name = utils.serialize(CustomAgent)
|
||||
ser, name = serialization.serialize(CustomAgent)
|
||||
assert name == 'test_main.CustomAgent'
|
||||
assert ser == CustomAgent
|
||||
pickle.dumps(ser)
|
||||
@@ -257,9 +257,9 @@ class TestMain(TestCase):
|
||||
def test_serialize_builtin_types(self):
|
||||
|
||||
for i in [1, None, True, False, {}, [], list(), dict()]:
|
||||
ser, name = utils.serialize(i)
|
||||
ser, name = serialization.serialize(i)
|
||||
assert type(ser) == str
|
||||
des = utils.deserialize(name, ser)
|
||||
des = serialization.deserialize(name, ser)
|
||||
assert i == des
|
||||
|
||||
def test_serialize_agent_type(self):
|
||||
@@ -312,11 +312,47 @@ class TestMain(TestCase):
|
||||
recovered = pickle.loads(pickled)
|
||||
|
||||
assert recovered.env.name == 'Test'
|
||||
assert recovered['key'] == 'test'
|
||||
assert list(recovered.env._history.to_tuples())
|
||||
assert recovered['key', 0] == 'test'
|
||||
assert recovered['key'] == 'test'
|
||||
|
||||
def test_history(self):
|
||||
'''Test storing in and retrieving from history (sqlite)'''
|
||||
h = history.History()
|
||||
h.save_record(agent_id=0, t_step=0, key="test", value="hello")
|
||||
assert h[0, 0, "test"] == "hello"
|
||||
|
||||
def test_subgraph(self):
|
||||
'''An agent should be able to subgraph the global topology'''
|
||||
G = nx.Graph()
|
||||
G.add_node(3)
|
||||
G.add_edge(1, 2)
|
||||
distro = agents.calculate_distribution(agent_type=agents.NetworkAgent)
|
||||
env = Environment(name='Test', topology=G, network_agents=distro)
|
||||
lst = list(env.network_agents)
|
||||
|
||||
a2 = env.get_agent(2)
|
||||
a3 = env.get_agent(3)
|
||||
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
||||
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
||||
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
||||
assert len(a3.subgraph(agent_type=agents.NetworkAgent)) == 3
|
||||
|
||||
def test_templates(self):
|
||||
'''Loading a template should result in several configs'''
|
||||
configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
|
||||
assert len(configs) > 0
|
||||
|
||||
def test_until(self):
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': 100,
|
||||
'environment_params': {}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
runs = list(s.run_simulation(dry_run=True))
|
||||
over = list(x.now for x in runs if x.now>2)
|
||||
assert len(over) == 0
|
||||
|
||||
34
tests/test_stats.py
Normal file
34
tests/test_stats.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from soil import simulation, stats
|
||||
from soil.utils import unflatten_dict
|
||||
|
||||
class Stats(TestCase):
|
||||
|
||||
def test_distribution(self):
|
||||
'''The distribution exporter should write the number of agents in each state'''
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {
|
||||
'generator': 'complete_graph',
|
||||
'n': 4
|
||||
},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': 5,
|
||||
'environment_params': {}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
for env in s.run_simulation(stats=[stats.distribution]):
|
||||
pass
|
||||
# stats_res = unflatten_dict(dict(env._history['stats', -1, None]))
|
||||
allstats = s.get_stats()
|
||||
for stat in allstats:
|
||||
assert 'count' in stat
|
||||
assert 'mean' in stat
|
||||
if 'trial_id' in stat:
|
||||
assert stat['mean']['neighbors'] == 3
|
||||
assert stat['count']['total']['4'] == 4
|
||||
else:
|
||||
assert stat['count']['count']['neighbors']['3'] == 20
|
||||
assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors']
|
||||
Reference in New Issue
Block a user