mirror of
https://github.com/gsi-upm/soil
synced 2025-10-24 04:08:21 +00:00
Compare commits
8 Commits
0.13.7
...
b0add8552e
Author | SHA1 | Date | |
---|---|---|---|
|
b0add8552e | ||
|
1cf85ea450 | ||
|
c32e167fb8 | ||
|
5f68b5321d | ||
|
2a2843bd19 | ||
|
d1006bd55c | ||
|
9bc036d185 | ||
|
a3ea434f23 |
34
CHANGELOG.md
34
CHANGELOG.md
@@ -3,9 +3,39 @@ All notable changes to this project will be documented in this file.
|
|||||||
|
|
||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [Unreleased]
|
## [0.14.0]
|
||||||
### Changed
|
|
||||||
### Added
|
### Added
|
||||||
|
* Loading configuration from template definitions in the yaml, in preparation for SALib support.
|
||||||
|
The definition of the variables and their possible values (i.e., a problem in SALib terms), as well as a sampler function, can be provided.
|
||||||
|
Soil uses this definition and the template to generate a set of configurations.
|
||||||
|
* Simulation group names, to link related simulations. For now, they are only used to group all simulations in the same group under the same folder.
|
||||||
|
* Exporters unify exporting/dumping results and other files to disk. If `dry_run` is set to `True`, exporters will write to stdout instead of a file (useful for testing/debugging).
|
||||||
|
* Distribution exporter, to write statistics about values and value_counts in every simulation. The results are dumped to two CSV files.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
* `dir_path` is now the directory for resources (modules, files)
|
||||||
|
* Environments and simulations do not export or write anything by default. That task is delegated to Exporters
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
* The output dir for environments and simulations (see Exporters)
|
||||||
|
* DrawingAgent, because it wrote to disk and was not being used. We provide a partial alternative in the form of the GraphDrawing exporter. A complete alternative will be provided once the network at each state can be accessed by exporters.
|
||||||
|
|
||||||
|
## Fixed
|
||||||
|
* Modules with custom agents/environments failed to load when they were run from outside the directory of the definition file. Modules are now loaded from the directory of the simulation file in addition to the working directory
|
||||||
|
* Memory databases (in history) can now be shared between threads.
|
||||||
|
* Testing all examples, not just subdirectories
|
||||||
|
|
||||||
|
## [0.13.8]
|
||||||
|
### Changed
|
||||||
|
* Moved TerroristNetworkModel to examples
|
||||||
|
### Added
|
||||||
|
* `get_agents` and `count_agents` methods now accept lists as inputs. They can be used to retrieve agents from node ids
|
||||||
|
* `subgraph` in BaseAgent
|
||||||
|
* `agents.select` method, to filter out agents
|
||||||
|
* `skip_test` property in yaml definitions, to force skipping some examples
|
||||||
|
* `agents.Geo`, with a search function based on postition
|
||||||
|
* `BaseAgent.ego_search` to get nodes from the ego network of a node
|
||||||
|
* `BaseAgent.degree` and `BaseAgent.betweenness`
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
## [0.13.7]
|
## [0.13.7]
|
||||||
|
@@ -2,3 +2,6 @@ include requirements.txt
|
|||||||
include test-requirements.txt
|
include test-requirements.txt
|
||||||
include README.rst
|
include README.rst
|
||||||
graft soil
|
graft soil
|
||||||
|
global-exclude __pycache__
|
||||||
|
global-exclude soil_output
|
||||||
|
global-exclude *.py[co]
|
||||||
|
5
Makefile
5
Makefile
@@ -1,4 +1,7 @@
|
|||||||
test:
|
quick-test:
|
||||||
docker-compose exec dev python -m pytest -s -v
|
docker-compose exec dev python -m pytest -s -v
|
||||||
|
|
||||||
|
test:
|
||||||
|
docker run -t -v $$PWD:/usr/src/app -w /usr/src/app python:3.7 python setup.py test
|
||||||
|
|
||||||
.PHONY: test
|
.PHONY: test
|
@@ -1,12 +1,11 @@
|
|||||||
---
|
---
|
||||||
name: simple
|
name: simple
|
||||||
|
group: tests
|
||||||
dir_path: "/tmp/"
|
dir_path: "/tmp/"
|
||||||
num_trials: 3
|
num_trials: 3
|
||||||
dry_run: True
|
|
||||||
max_time: 100
|
max_time: 100
|
||||||
interval: 1
|
interval: 1
|
||||||
seed: "CompleteSeed!"
|
seed: "CompleteSeed!"
|
||||||
dump: false
|
|
||||||
network_params:
|
network_params:
|
||||||
generator: complete_graph
|
generator: complete_graph
|
||||||
n: 10
|
n: 10
|
||||||
|
@@ -2,7 +2,6 @@
|
|||||||
name: custom-generator
|
name: custom-generator
|
||||||
description: Using a custom generator for the network
|
description: Using a custom generator for the network
|
||||||
num_trials: 3
|
num_trials: 3
|
||||||
dry_run: True
|
|
||||||
max_time: 100
|
max_time: 100
|
||||||
interval: 1
|
interval: 1
|
||||||
network_params:
|
network_params:
|
||||||
|
@@ -29,8 +29,7 @@ if __name__ == '__main__':
|
|||||||
from soil import Simulation
|
from soil import Simulation
|
||||||
s = Simulation(network_agents=[{'ids': [0], 'agent_type': Fibonacci},
|
s = Simulation(network_agents=[{'ids': [0], 'agent_type': Fibonacci},
|
||||||
{'ids': [1], 'agent_type': Odds}],
|
{'ids': [1], 'agent_type': Odds}],
|
||||||
dry_run=True,
|
|
||||||
network_params={"generator": "complete_graph", "n": 2},
|
network_params={"generator": "complete_graph", "n": 2},
|
||||||
max_time=100,
|
max_time=100,
|
||||||
)
|
)
|
||||||
s.run()
|
s.run(dry_run=True)
|
||||||
|
@@ -6,7 +6,7 @@ environment_params:
|
|||||||
prob_neighbor_spread: 0.0
|
prob_neighbor_spread: 0.0
|
||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 30
|
max_time: 300
|
||||||
name: Sim_all_dumb
|
name: Sim_all_dumb
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_type: DumbViewer
|
- agent_type: DumbViewer
|
||||||
@@ -30,7 +30,7 @@ environment_params:
|
|||||||
prob_neighbor_spread: 0.0
|
prob_neighbor_spread: 0.0
|
||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 30
|
max_time: 300
|
||||||
name: Sim_half_herd
|
name: Sim_half_herd
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_type: DumbViewer
|
- agent_type: DumbViewer
|
||||||
@@ -62,7 +62,7 @@ environment_params:
|
|||||||
prob_neighbor_spread: 0.0
|
prob_neighbor_spread: 0.0
|
||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 30
|
max_time: 300
|
||||||
name: Sim_all_herd
|
name: Sim_all_herd
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
@@ -89,7 +89,7 @@ environment_params:
|
|||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
prob_neighbor_cure: 0.1
|
prob_neighbor_cure: 0.1
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 30
|
max_time: 300
|
||||||
name: Sim_wise_herd
|
name: Sim_wise_herd
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
@@ -115,7 +115,7 @@ environment_params:
|
|||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
prob_neighbor_cure: 0.1
|
prob_neighbor_cure: 0.1
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 30
|
max_time: 300
|
||||||
name: Sim_all_wise
|
name: Sim_all_wise
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
|
29
examples/template.yml
Normal file
29
examples/template.yml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
---
|
||||||
|
vars:
|
||||||
|
bounds:
|
||||||
|
x1: [0, 1]
|
||||||
|
x2: [1, 2]
|
||||||
|
fixed:
|
||||||
|
x3: ["a", "b", "c"]
|
||||||
|
sampler: "SALib.sample.morris.sample"
|
||||||
|
samples: 10
|
||||||
|
template: |
|
||||||
|
group: simple
|
||||||
|
num_trials: 1
|
||||||
|
interval: 1
|
||||||
|
max_time: 2
|
||||||
|
seed: "CompleteSeed!"
|
||||||
|
dump: false
|
||||||
|
network_params:
|
||||||
|
generator: complete_graph
|
||||||
|
n: 10
|
||||||
|
network_agents:
|
||||||
|
- agent_type: CounterModel
|
||||||
|
weight: {{ x1 }}
|
||||||
|
state:
|
||||||
|
id: 0
|
||||||
|
- agent_type: AggregatedCounter
|
||||||
|
weight: {{ 1 - x1 }}
|
||||||
|
environment_params:
|
||||||
|
name: {{ x3 }}
|
||||||
|
skip_test: true
|
208
examples/terrorism/TerroristNetworkModel.py
Normal file
208
examples/terrorism/TerroristNetworkModel.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
import random
|
||||||
|
import networkx as nx
|
||||||
|
from soil.agents import Geo, NetworkAgent, FSM, state, default_state
|
||||||
|
from soil import Environment
|
||||||
|
|
||||||
|
|
||||||
|
class TerroristSpreadModel(FSM, Geo):
|
||||||
|
"""
|
||||||
|
Settings:
|
||||||
|
information_spread_intensity
|
||||||
|
|
||||||
|
terrorist_additional_influence
|
||||||
|
|
||||||
|
min_vulnerability (optional else zero)
|
||||||
|
|
||||||
|
max_vulnerability
|
||||||
|
|
||||||
|
prob_interaction
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, environment=None, agent_id=0, state=()):
|
||||||
|
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||||
|
|
||||||
|
self.information_spread_intensity = environment.environment_params['information_spread_intensity']
|
||||||
|
self.terrorist_additional_influence = environment.environment_params['terrorist_additional_influence']
|
||||||
|
self.prob_interaction = environment.environment_params['prob_interaction']
|
||||||
|
|
||||||
|
if self['id'] == self.civilian.id: # Civilian
|
||||||
|
self.mean_belief = random.uniform(0.00, 0.5)
|
||||||
|
elif self['id'] == self.terrorist.id: # Terrorist
|
||||||
|
self.mean_belief = random.uniform(0.8, 1.00)
|
||||||
|
elif self['id'] == self.leader.id: # Leader
|
||||||
|
self.mean_belief = 1.00
|
||||||
|
else:
|
||||||
|
raise Exception('Invalid state id: {}'.format(self['id']))
|
||||||
|
|
||||||
|
if 'min_vulnerability' in environment.environment_params:
|
||||||
|
self.vulnerability = random.uniform( environment.environment_params['min_vulnerability'], environment.environment_params['max_vulnerability'] )
|
||||||
|
else :
|
||||||
|
self.vulnerability = random.uniform( 0, environment.environment_params['max_vulnerability'] )
|
||||||
|
|
||||||
|
|
||||||
|
@state
|
||||||
|
def civilian(self):
|
||||||
|
neighbours = list(self.get_neighboring_agents(agent_type=TerroristSpreadModel))
|
||||||
|
if len(neighbours) > 0:
|
||||||
|
# Only interact with some of the neighbors
|
||||||
|
interactions = list(n for n in neighbours if random.random() <= self.prob_interaction)
|
||||||
|
influence = sum( self.degree(i) for i in interactions )
|
||||||
|
mean_belief = sum( i.mean_belief * self.degree(i) / influence for i in interactions )
|
||||||
|
mean_belief = mean_belief * self.information_spread_intensity + self.mean_belief * ( 1 - self.information_spread_intensity )
|
||||||
|
self.mean_belief = mean_belief * self.vulnerability + self.mean_belief * ( 1 - self.vulnerability )
|
||||||
|
|
||||||
|
if self.mean_belief >= 0.8:
|
||||||
|
return self.terrorist
|
||||||
|
|
||||||
|
@state
|
||||||
|
def leader(self):
|
||||||
|
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
||||||
|
for neighbour in self.get_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]):
|
||||||
|
if self.betweenness(neighbour) > self.betweenness(self):
|
||||||
|
return self.terrorist
|
||||||
|
|
||||||
|
@state
|
||||||
|
def terrorist(self):
|
||||||
|
neighbours = self.get_agents(state_id=[self.terrorist.id, self.leader.id],
|
||||||
|
agent_type=TerroristSpreadModel,
|
||||||
|
limit_neighbors=True)
|
||||||
|
if len(neighbours) > 0:
|
||||||
|
influence = sum( self.degree(n) for n in neighbours )
|
||||||
|
mean_belief = sum( n.mean_belief * self.degree(n) / influence for n in neighbours )
|
||||||
|
mean_belief = mean_belief * self.vulnerability + self.mean_belief * ( 1 - self.vulnerability )
|
||||||
|
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
||||||
|
|
||||||
|
# Check if there are any leaders in the group
|
||||||
|
leaders = list(filter(lambda x: x.state.id == self.leader.id, neighbours))
|
||||||
|
if not leaders:
|
||||||
|
# Check if this is the potential leader
|
||||||
|
# Stop once it's found. Otherwise, set self as leader
|
||||||
|
for neighbour in neighbours:
|
||||||
|
if self.betweenness(self) < self.betweenness(neighbour):
|
||||||
|
return
|
||||||
|
return self.leader
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingAreaModel(FSM, Geo):
|
||||||
|
"""
|
||||||
|
Settings:
|
||||||
|
training_influence
|
||||||
|
|
||||||
|
min_vulnerability
|
||||||
|
|
||||||
|
Requires TerroristSpreadModel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, environment=None, agent_id=0, state=()):
|
||||||
|
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||||
|
self.training_influence = environment.environment_params['training_influence']
|
||||||
|
if 'min_vulnerability' in environment.environment_params:
|
||||||
|
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
||||||
|
else: self.min_vulnerability = 0
|
||||||
|
|
||||||
|
@default_state
|
||||||
|
@state
|
||||||
|
def terrorist(self):
|
||||||
|
for neighbour in self.get_neighboring_agents(agent_type=TerroristSpreadModel):
|
||||||
|
if neighbour.vulnerability > self.min_vulnerability:
|
||||||
|
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.training_influence )
|
||||||
|
|
||||||
|
|
||||||
|
class HavenModel(FSM, Geo):
|
||||||
|
"""
|
||||||
|
Settings:
|
||||||
|
haven_influence
|
||||||
|
|
||||||
|
min_vulnerability
|
||||||
|
|
||||||
|
max_vulnerability
|
||||||
|
|
||||||
|
Requires TerroristSpreadModel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, environment=None, agent_id=0, state=()):
|
||||||
|
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||||
|
self.haven_influence = environment.environment_params['haven_influence']
|
||||||
|
if 'min_vulnerability' in environment.environment_params:
|
||||||
|
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
||||||
|
else: self.min_vulnerability = 0
|
||||||
|
self.max_vulnerability = environment.environment_params['max_vulnerability']
|
||||||
|
|
||||||
|
def get_occupants(self, **kwargs):
|
||||||
|
return self.get_neighboring_agents(agent_type=TerroristSpreadModel, **kwargs)
|
||||||
|
|
||||||
|
@state
|
||||||
|
def civilian(self):
|
||||||
|
civilians = self.get_occupants(state_id=self.civilian.id)
|
||||||
|
if not civilians:
|
||||||
|
return self.terrorist
|
||||||
|
|
||||||
|
for neighbour in self.get_occupants():
|
||||||
|
if neighbour.vulnerability > self.min_vulnerability:
|
||||||
|
neighbour.vulnerability = neighbour.vulnerability * ( 1 - self.haven_influence )
|
||||||
|
return self.civilian
|
||||||
|
|
||||||
|
@state
|
||||||
|
def terrorist(self):
|
||||||
|
for neighbour in self.get_occupants():
|
||||||
|
if neighbour.vulnerability < self.max_vulnerability:
|
||||||
|
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.haven_influence )
|
||||||
|
return self.terrorist
|
||||||
|
|
||||||
|
|
||||||
|
class TerroristNetworkModel(TerroristSpreadModel):
|
||||||
|
"""
|
||||||
|
Settings:
|
||||||
|
sphere_influence
|
||||||
|
|
||||||
|
vision_range
|
||||||
|
|
||||||
|
weight_social_distance
|
||||||
|
|
||||||
|
weight_link_distance
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, environment=None, agent_id=0, state=()):
|
||||||
|
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||||
|
|
||||||
|
self.vision_range = environment.environment_params['vision_range']
|
||||||
|
self.sphere_influence = environment.environment_params['sphere_influence']
|
||||||
|
self.weight_social_distance = environment.environment_params['weight_social_distance']
|
||||||
|
self.weight_link_distance = environment.environment_params['weight_link_distance']
|
||||||
|
|
||||||
|
@state
|
||||||
|
def terrorist(self):
|
||||||
|
self.update_relationships()
|
||||||
|
return super().terrorist()
|
||||||
|
|
||||||
|
@state
|
||||||
|
def leader(self):
|
||||||
|
self.update_relationships()
|
||||||
|
return super().leader()
|
||||||
|
|
||||||
|
def update_relationships(self):
|
||||||
|
if self.count_neighboring_agents(state_id=self.civilian.id) == 0:
|
||||||
|
close_ups = set(self.geo_search(radius=self.vision_range, agent_type=TerroristNetworkModel))
|
||||||
|
step_neighbours = set(self.ego_search(self.sphere_influence, agent_type=TerroristNetworkModel, center=False))
|
||||||
|
neighbours = set(agent.id for agent in self.get_neighboring_agents(agent_type=TerroristNetworkModel))
|
||||||
|
search = (close_ups | step_neighbours) - neighbours
|
||||||
|
for agent in self.get_agents(search):
|
||||||
|
social_distance = 1 / self.shortest_path_length(agent.id)
|
||||||
|
spatial_proximity = ( 1 - self.get_distance(agent.id) )
|
||||||
|
prob_new_interaction = self.weight_social_distance * social_distance + self.weight_link_distance * spatial_proximity
|
||||||
|
if agent['id'] == agent.civilian.id and random.random() < prob_new_interaction:
|
||||||
|
self.add_edge(agent)
|
||||||
|
break
|
||||||
|
|
||||||
|
def get_distance(self, target):
|
||||||
|
source_x, source_y = nx.get_node_attributes(self.global_topology, 'pos')[self.id]
|
||||||
|
target_x, target_y = nx.get_node_attributes(self.global_topology, 'pos')[target]
|
||||||
|
dx = abs( source_x - target_x )
|
||||||
|
dy = abs( source_y - target_y )
|
||||||
|
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
||||||
|
|
||||||
|
def shortest_path_length(self, target):
|
||||||
|
try:
|
||||||
|
return nx.shortest_path_length(self.global_topology, self.id, target)
|
||||||
|
except nx.NetworkXNoPath:
|
||||||
|
return float('inf')
|
@@ -60,3 +60,4 @@ visualization_params:
|
|||||||
background_image: 'map_4800x2860.jpg'
|
background_image: 'map_4800x2860.jpg'
|
||||||
background_opacity: '0.9'
|
background_opacity: '0.9'
|
||||||
background_filter_color: 'blue'
|
background_filter_color: 'blue'
|
||||||
|
skip_test: true # This simulation takes too long for automated tests.
|
@@ -1,7 +1,10 @@
|
|||||||
nxsim
|
nxsim>=0.1.2
|
||||||
simpy
|
simpy
|
||||||
networkx>=2.0
|
networkx>=2.0
|
||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
pyyaml
|
pyyaml>=5.1
|
||||||
pandas
|
pandas>=0.23
|
||||||
|
scipy==1.2.1 # scipy 1.3.0rc1 is not compatible with salib
|
||||||
|
SALib>=1.3
|
||||||
|
Jinja2
|
||||||
|
@@ -1 +1 @@
|
|||||||
0.13.7
|
0.14.0
|
||||||
|
@@ -15,7 +15,7 @@ from . import agents
|
|||||||
from .simulation import *
|
from .simulation import *
|
||||||
from .environment import Environment
|
from .environment import Environment
|
||||||
from .history import History
|
from .history import History
|
||||||
from . import utils
|
from . import serialization
|
||||||
from . import analysis
|
from . import analysis
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -44,6 +44,8 @@ def main():
|
|||||||
help='folder to write results to. It defaults to the current directory.')
|
help='folder to write results to. It defaults to the current directory.')
|
||||||
parser.add_argument('--synchronous', action='store_true',
|
parser.add_argument('--synchronous', action='store_true',
|
||||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||||
|
parser.add_argument('-e', '--exporter', action='append',
|
||||||
|
help='Export environment and/or simulations using this exporter')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -55,17 +57,20 @@ def main():
|
|||||||
logging.info('Loading config file: {}'.format(args.file))
|
logging.info('Loading config file: {}'.format(args.file))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dump = []
|
exporters = list(args.exporter or [])
|
||||||
if not args.dry_run:
|
|
||||||
if args.csv:
|
if args.csv:
|
||||||
dump.append('csv')
|
exporters.append('CSV')
|
||||||
if args.graph:
|
if args.graph:
|
||||||
dump.append('gexf')
|
exporters.append('Gexf')
|
||||||
|
exp_params = {}
|
||||||
|
if args.dry_run:
|
||||||
|
exp_params['copy_to'] = sys.stdout
|
||||||
simulation.run_from_config(args.file,
|
simulation.run_from_config(args.file,
|
||||||
dry_run=args.dry_run,
|
dry_run=args.dry_run,
|
||||||
dump=dump,
|
exporters=exporters,
|
||||||
parallel=(not args.synchronous),
|
parallel=(not args.synchronous),
|
||||||
results_dir=args.output)
|
outdir=args.output,
|
||||||
|
exporter_params=exp_params)
|
||||||
except Exception:
|
except Exception:
|
||||||
if args.pdb:
|
if args.pdb:
|
||||||
pdb.post_mortem()
|
pdb.post_mortem()
|
||||||
|
@@ -22,11 +22,17 @@ class AggregatedCounter(BaseAgent):
|
|||||||
in each step and adds it to its state.
|
in each step and adds it to its state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
defaults = {
|
||||||
|
'times': 0,
|
||||||
|
'neighbors': 0,
|
||||||
|
'total': 0
|
||||||
|
}
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
total = len(list(self.get_all_agents()))
|
self['times'] += 1
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self['times'] = self.get('times', 0) + 1
|
self['neighbors'] += neighbors
|
||||||
self['neighbors'] = self.get('neighbors', 0) + neighbors
|
total = len(list(self.get_all_agents()))
|
||||||
self['total'] = total = self.get('total', 0) + total
|
self['total'] += total
|
||||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||||
|
@@ -1,18 +0,0 @@
|
|||||||
from . import BaseAgent
|
|
||||||
|
|
||||||
import os.path
|
|
||||||
import matplotlib
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import networkx as nx
|
|
||||||
|
|
||||||
|
|
||||||
class DrawingAgent(BaseAgent):
|
|
||||||
"""
|
|
||||||
Agent that draws the state of the network.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def step(self):
|
|
||||||
# Outside effects
|
|
||||||
f = plt.figure()
|
|
||||||
nx.draw(self.env.G, node_size=10, width=0.2, pos=nx.spring_layout(self.env.G, scale=100), ax=f.add_subplot(111))
|
|
||||||
f.savefig(os.path.join(self.env.get_path(), "graph-"+str(self.env.now)+".png"))
|
|
@@ -10,11 +10,18 @@ import logging
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from scipy.spatial import cKDTree as KDTree
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from .. import utils, history
|
from .. import serialization, history
|
||||||
|
|
||||||
|
|
||||||
|
def as_node(agent):
|
||||||
|
if isinstance(agent, BaseAgent):
|
||||||
|
return agent.id
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(nxsim.BaseAgent):
|
class BaseAgent(nxsim.BaseAgent):
|
||||||
@@ -46,8 +53,7 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
|
|
||||||
if not hasattr(self, 'level'):
|
if not hasattr(self, 'level'):
|
||||||
self.level = logging.DEBUG
|
self.level = logging.DEBUG
|
||||||
self.logger = logging.getLogger('{}.{}'.format(self.env.name,
|
self.logger = logging.getLogger(self.env.name)
|
||||||
self.id))
|
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
|
|
||||||
# initialize every time an instance of the agent is created
|
# initialize every time an instance of the agent is created
|
||||||
@@ -134,43 +140,21 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
def step(self):
|
def step(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def count_agents(self, state_id=None, limit_neighbors=False):
|
def count_agents(self, **kwargs):
|
||||||
|
return len(list(self.get_agents(**kwargs)))
|
||||||
|
|
||||||
|
def count_neighboring_agents(self, state_id=None, **kwargs):
|
||||||
|
return len(super().get_neighboring_agents(state_id=state_id, **kwargs))
|
||||||
|
|
||||||
|
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||||
|
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||||
|
|
||||||
|
def get_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
||||||
if limit_neighbors:
|
if limit_neighbors:
|
||||||
agents = self.global_topology.neighbors(self.id)
|
agents = super().get_agents(limit_neighbors=limit_neighbors)
|
||||||
else:
|
else:
|
||||||
agents = self.global_topology.nodes()
|
agents = self.env.get_agents(agents)
|
||||||
count = 0
|
return select(agents, **kwargs)
|
||||||
for agent in agents:
|
|
||||||
if state_id and state_id != self.global_topology.node[agent]['agent']['id']:
|
|
||||||
continue
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
def count_neighboring_agents(self, state_id=None):
|
|
||||||
return len(super().get_agents(state_id, limit_neighbors=True))
|
|
||||||
|
|
||||||
def get_agents(self, state_id=None, agent_type=None, limit_neighbors=False, iterator=False, **kwargs):
|
|
||||||
agents = self.env.agents
|
|
||||||
if limit_neighbors:
|
|
||||||
agents = super().get_agents(state_id, limit_neighbors)
|
|
||||||
|
|
||||||
def matches_all(agent):
|
|
||||||
if state_id is not None:
|
|
||||||
if agent.state.get('id', None) != state_id:
|
|
||||||
return False
|
|
||||||
if agent_type is not None:
|
|
||||||
if type(agent) != agent_type:
|
|
||||||
return False
|
|
||||||
state = agent.state
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
if state.get(k, None) != v:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
f = filter(matches_all, agents)
|
|
||||||
if iterator:
|
|
||||||
return f
|
|
||||||
return list(f)
|
|
||||||
|
|
||||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||||
message = message + " ".join(str(i) for i in args)
|
message = message + " ".join(str(i) for i in args)
|
||||||
@@ -208,12 +192,52 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
self._state = state['_state']
|
self._state = state['_state']
|
||||||
self.env = state['environment']
|
self.env = state['environment']
|
||||||
|
|
||||||
|
def add_edge(self, node1, node2, **attrs):
|
||||||
|
node1 = as_node(node1)
|
||||||
|
node2 = as_node(node2)
|
||||||
|
|
||||||
def state(func):
|
for n in [node1, node2]:
|
||||||
|
if n not in self.global_topology.nodes(data=False):
|
||||||
|
raise ValueError('"{}" not in the graph'.format(n))
|
||||||
|
return self.global_topology.add_edge(node1, node2, **attrs)
|
||||||
|
|
||||||
|
def subgraph(self, center=True, **kwargs):
|
||||||
|
include = [self] if center else []
|
||||||
|
return self.global_topology.subgraph(n.id for n in self.get_agents(**kwargs)+include)
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkAgent(BaseAgent):
|
||||||
|
|
||||||
|
def add_edge(self, other, **kwargs):
|
||||||
|
return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
||||||
|
|
||||||
|
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||||
|
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||||
|
node = as_node(node if node is not None else self)
|
||||||
|
G = self.subgraph(**kwargs)
|
||||||
|
return nx.ego_graph(G, node, center=center, radius=steps).nodes()
|
||||||
|
|
||||||
|
def degree(self, node, force=False):
|
||||||
|
node = as_node(node)
|
||||||
|
if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now:
|
||||||
|
self.env._degree = nx.degree_centrality(self.global_topology)
|
||||||
|
self.env._last_step = self.now
|
||||||
|
return self.env._degree[node]
|
||||||
|
|
||||||
|
def betweenness(self, node, force=False):
|
||||||
|
node = as_node(node)
|
||||||
|
if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now:
|
||||||
|
self.env._betweenness = nx.betweenness_centrality(self.global_topology)
|
||||||
|
self.env._last_step = self.now
|
||||||
|
return self.env._betweenness[node]
|
||||||
|
|
||||||
|
|
||||||
|
def state(name=None):
|
||||||
|
def decorator(func, name=None):
|
||||||
'''
|
'''
|
||||||
A state function should return either a state id, or a tuple (state_id, when)
|
A state function should return either a state id, or a tuple (state_id, when)
|
||||||
The default value for state_id is the current state id.
|
The default value for state_id is the current state id.
|
||||||
The default value for when is the interval defined in the nevironment.
|
The default value for when is the interval defined in the environment.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@@ -230,10 +254,15 @@ def state(func):
|
|||||||
self.set_state(next_state)
|
self.set_state(next_state)
|
||||||
return when
|
return when
|
||||||
|
|
||||||
func_wrapper.id = func.__name__
|
func_wrapper.id = name or func.__name__
|
||||||
func_wrapper.is_default = False
|
func_wrapper.is_default = False
|
||||||
return func_wrapper
|
return func_wrapper
|
||||||
|
|
||||||
|
if callable(name):
|
||||||
|
return decorator(name)
|
||||||
|
else:
|
||||||
|
return partial(decorator, name=name)
|
||||||
|
|
||||||
|
|
||||||
def default_state(func):
|
def default_state(func):
|
||||||
func.is_default = True
|
func.is_default = True
|
||||||
@@ -340,7 +369,7 @@ def calculate_distribution(network_agents=None,
|
|||||||
elif agent_type:
|
elif agent_type:
|
||||||
network_agents = [{'agent_type': agent_type}]
|
network_agents = [{'agent_type': agent_type}]
|
||||||
else:
|
else:
|
||||||
return []
|
raise ValueError('Specify a distribution or a default agent type')
|
||||||
|
|
||||||
# Calculate the thresholds
|
# Calculate the thresholds
|
||||||
total = sum(x.get('weight', 1) for x in network_agents)
|
total = sum(x.get('weight', 1) for x in network_agents)
|
||||||
@@ -359,7 +388,7 @@ def serialize_type(agent_type, known_modules=[], **kwargs):
|
|||||||
if isinstance(agent_type, str):
|
if isinstance(agent_type, str):
|
||||||
return agent_type
|
return agent_type
|
||||||
known_modules += ['soil.agents']
|
known_modules += ['soil.agents']
|
||||||
return utils.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
return serialization.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
||||||
|
|
||||||
|
|
||||||
def serialize_distribution(network_agents, known_modules=[]):
|
def serialize_distribution(network_agents, known_modules=[]):
|
||||||
@@ -380,7 +409,7 @@ def deserialize_type(agent_type, known_modules=[]):
|
|||||||
if not isinstance(agent_type, str):
|
if not isinstance(agent_type, str):
|
||||||
return agent_type
|
return agent_type
|
||||||
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
||||||
agent_type = utils.deserializer(agent_type, known_modules=known)
|
agent_type = serialization.deserializer(agent_type, known_modules=known)
|
||||||
return agent_type
|
return agent_type
|
||||||
|
|
||||||
|
|
||||||
@@ -427,6 +456,58 @@ def _agent_from_distribution(distribution, value=-1, agent_id=None):
|
|||||||
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
raise Exception('Distribution for value {} not found in: {}'.format(value, distribution))
|
||||||
|
|
||||||
|
|
||||||
|
class Geo(NetworkAgent):
|
||||||
|
'''In this type of network, nodes have a "pos" attribute.'''
|
||||||
|
|
||||||
|
def geo_search(self, radius, node=None, center=False, **kwargs):
|
||||||
|
'''Get a list of nodes whose coordinates are closer than *radius* to *node*.'''
|
||||||
|
node = as_node(node if node is not None else self)
|
||||||
|
|
||||||
|
G = self.subgraph(**kwargs)
|
||||||
|
|
||||||
|
pos = nx.get_node_attributes(G, 'pos')
|
||||||
|
if not pos:
|
||||||
|
return []
|
||||||
|
nodes, coords = list(zip(*pos.items()))
|
||||||
|
kdtree = KDTree(coords) # Cannot provide generator.
|
||||||
|
indices = kdtree.query_ball_point(pos[node], radius)
|
||||||
|
return [nodes[i] for i in indices if center or (nodes[i] != node)]
|
||||||
|
|
||||||
|
|
||||||
|
def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
||||||
|
|
||||||
|
if state_id is not None:
|
||||||
|
try:
|
||||||
|
state_id = tuple(state_id)
|
||||||
|
except TypeError:
|
||||||
|
state_id = tuple([state_id])
|
||||||
|
if agent_type is not None:
|
||||||
|
try:
|
||||||
|
agent_type = tuple(agent_type)
|
||||||
|
except TypeError:
|
||||||
|
agent_type = tuple([agent_type])
|
||||||
|
|
||||||
|
def matches_all(agent):
|
||||||
|
if state_id is not None:
|
||||||
|
if agent.state.get('id', None) not in state_id:
|
||||||
|
return False
|
||||||
|
if agent_type is not None:
|
||||||
|
if not isinstance(agent, agent_type):
|
||||||
|
return False
|
||||||
|
state = agent.state
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if state.get(k, None) != v:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
f = filter(matches_all, agents)
|
||||||
|
if ignore:
|
||||||
|
f = filter(lambda x: x not in ignore, f)
|
||||||
|
if iterator:
|
||||||
|
return f
|
||||||
|
return list(f)
|
||||||
|
|
||||||
|
|
||||||
from .BassModel import *
|
from .BassModel import *
|
||||||
from .BigMarketModel import *
|
from .BigMarketModel import *
|
||||||
from .IndependentCascadeModel import *
|
from .IndependentCascadeModel import *
|
||||||
@@ -434,4 +515,3 @@ from .ModelM2 import *
|
|||||||
from .SentimentCorrelationModel import *
|
from .SentimentCorrelationModel import *
|
||||||
from .SISaModel import *
|
from .SISaModel import *
|
||||||
from .CounterModel import *
|
from .CounterModel import *
|
||||||
from .DrawingAgent import *
|
|
||||||
|
@@ -4,7 +4,7 @@ import glob
|
|||||||
import yaml
|
import yaml
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from . import utils, history
|
from . import serialization, history
|
||||||
|
|
||||||
|
|
||||||
def read_data(*args, group=False, **kwargs):
|
def read_data(*args, group=False, **kwargs):
|
||||||
@@ -34,7 +34,7 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def read_sql(db, *args, **kwargs):
|
def read_sql(db, *args, **kwargs):
|
||||||
h = history.History(db, backup=False)
|
h = history.History(db_path=db, backup=False)
|
||||||
df = h.read_sql(*args, **kwargs)
|
df = h.read_sql(*args, **kwargs)
|
||||||
return df
|
return df
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def convert_row(row):
|
def convert_row(row):
|
||||||
row['value'] = utils.deserialize(row['value_type'], row['value'])
|
row['value'] = serialization.deserialize(row['value_type'], row['value'])
|
||||||
return row
|
return row
|
||||||
|
|
||||||
|
|
||||||
|
@@ -14,15 +14,13 @@ from networkx.readwrite import json_graph
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
import nxsim
|
import nxsim
|
||||||
|
|
||||||
from . import utils, agents, analysis, history
|
from . import serialization, agents, analysis, history, utils
|
||||||
|
|
||||||
# These properties will be copied when pickling/unpickling the environment
|
# These properties will be copied when pickling/unpickling the environment
|
||||||
_CONFIG_PROPS = [ 'name',
|
_CONFIG_PROPS = [ 'name',
|
||||||
'states',
|
'states',
|
||||||
'default_state',
|
'default_state',
|
||||||
'interval',
|
'interval',
|
||||||
'dry_run',
|
|
||||||
'dir_path',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
class Environment(nxsim.NetworkEnvironment):
|
class Environment(nxsim.NetworkEnvironment):
|
||||||
@@ -43,8 +41,6 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
default_state=None,
|
default_state=None,
|
||||||
interval=1,
|
interval=1,
|
||||||
seed=None,
|
seed=None,
|
||||||
dry_run=False,
|
|
||||||
dir_path=None,
|
|
||||||
topology=None,
|
topology=None,
|
||||||
*args, **kwargs):
|
*args, **kwargs):
|
||||||
self.name = name or 'UnnamedEnvironment'
|
self.name = name or 'UnnamedEnvironment'
|
||||||
@@ -56,13 +52,8 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
topology = nx.Graph()
|
topology = nx.Graph()
|
||||||
super().__init__(*args, topology=topology, **kwargs)
|
super().__init__(*args, topology=topology, **kwargs)
|
||||||
self._env_agents = {}
|
self._env_agents = {}
|
||||||
self.dry_run = dry_run
|
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.dir_path = dir_path or tempfile.mkdtemp('soil-env')
|
self._history = history.History(name=self.name,
|
||||||
if not dry_run:
|
|
||||||
self.get_path()
|
|
||||||
self._history = history.History(name=self.name if not dry_run else None,
|
|
||||||
dir_path=self.dir_path,
|
|
||||||
backup=True)
|
backup=True)
|
||||||
# Add environment agents first, so their events get
|
# Add environment agents first, so their events get
|
||||||
# executed before network agents
|
# executed before network agents
|
||||||
@@ -102,8 +93,7 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
|
|
||||||
@network_agents.setter
|
@network_agents.setter
|
||||||
def network_agents(self, network_agents):
|
def network_agents(self, network_agents):
|
||||||
if not network_agents:
|
self._network_agents = network_agents
|
||||||
return
|
|
||||||
for ix in self.G.nodes():
|
for ix in self.G.nodes():
|
||||||
self.init_agent(ix, agent_distribution=network_agents)
|
self.init_agent(ix, agent_distribution=network_agents)
|
||||||
|
|
||||||
@@ -124,6 +114,9 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
agent_type = agents.deserialize_type(agent_type)
|
agent_type = agents.deserialize_type(agent_type)
|
||||||
elif agent_distribution:
|
elif agent_distribution:
|
||||||
agent_type, state = agents._agent_from_distribution(agent_distribution, agent_id=agent_id)
|
agent_type, state = agents._agent_from_distribution(agent_distribution, agent_id=agent_id)
|
||||||
|
else:
|
||||||
|
serialization.logger.debug('Skipping node {}'.format(agent_id))
|
||||||
|
return
|
||||||
return self.set_agent(agent_id, agent_type, state)
|
return self.set_agent(agent_id, agent_type, state)
|
||||||
|
|
||||||
def set_agent(self, agent_id, agent_type, state=None):
|
def set_agent(self, agent_id, agent_type, state=None):
|
||||||
@@ -149,12 +142,13 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
a['visible'] = True
|
a['visible'] = True
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def add_edge(self, agent1, agent2, attrs=None):
|
def add_edge(self, agent1, agent2, start=None, **attrs):
|
||||||
if hasattr(agent1, 'id'):
|
if hasattr(agent1, 'id'):
|
||||||
agent1 = agent1.id
|
agent1 = agent1.id
|
||||||
if hasattr(agent2, 'id'):
|
if hasattr(agent2, 'id'):
|
||||||
agent2 = agent2.id
|
agent2 = agent2.id
|
||||||
return self.G.add_edge(agent1, agent2)
|
start = start or self.now
|
||||||
|
return self.G.add_edge(agent1, agent2, **attrs)
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
self._save_state()
|
self._save_state()
|
||||||
@@ -164,9 +158,7 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
self.log_stats()
|
self.log_stats()
|
||||||
|
|
||||||
def _save_state(self, now=None):
|
def _save_state(self, now=None):
|
||||||
# for agent in self.agents:
|
serialization.logger.debug('Saving state @{}'.format(self.now))
|
||||||
# agent.save_state()
|
|
||||||
utils.logger.debug('Saving state @{}'.format(self.now))
|
|
||||||
self._history.save_records(self.state_to_tuples(now=now))
|
self._history.save_records(self.state_to_tuples(now=now))
|
||||||
|
|
||||||
def save_state(self):
|
def save_state(self):
|
||||||
@@ -177,7 +169,7 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
self._save_state()
|
self._save_state()
|
||||||
while self.peek() != simpy.core.Infinity:
|
while self.peek() != simpy.core.Infinity:
|
||||||
delay = max(self.peek() - self.now, self.interval)
|
delay = max(self.peek() - self.now, self.interval)
|
||||||
utils.logger.debug('Step: {}'.format(self.now))
|
serialization.logger.debug('Step: {}'.format(self.now))
|
||||||
ev = self.event()
|
ev = self.event()
|
||||||
ev._ok = True
|
ev._ok = True
|
||||||
# Schedule the event with minimum priority so
|
# Schedule the event with minimum priority so
|
||||||
@@ -219,35 +211,23 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
'''
|
'''
|
||||||
return self[key] if key in self else default
|
return self[key] if key in self else default
|
||||||
|
|
||||||
def get_path(self, dir_path=None):
|
|
||||||
dir_path = dir_path or self.dir_path
|
|
||||||
if not os.path.exists(dir_path):
|
|
||||||
try:
|
|
||||||
os.makedirs(dir_path)
|
|
||||||
except FileExistsError:
|
|
||||||
pass
|
|
||||||
return dir_path
|
|
||||||
|
|
||||||
def get_agent(self, agent_id):
|
def get_agent(self, agent_id):
|
||||||
return self.G.node[agent_id]['agent']
|
return self.G.node[agent_id]['agent']
|
||||||
|
|
||||||
def get_agents(self):
|
def get_agents(self, nodes=None):
|
||||||
|
if nodes is None:
|
||||||
return list(self.agents)
|
return list(self.agents)
|
||||||
|
return [self.G.node[i]['agent'] for i in nodes]
|
||||||
|
|
||||||
def dump_csv(self, dir_path=None):
|
def dump_csv(self, f):
|
||||||
csv_name = os.path.join(self.get_path(dir_path),
|
with utils.open_or_reuse(f, 'w') as f:
|
||||||
'{}.environment.csv'.format(self.name))
|
|
||||||
|
|
||||||
with open(csv_name, 'w') as f:
|
|
||||||
cr = csv.writer(f)
|
cr = csv.writer(f)
|
||||||
cr.writerow(('agent_id', 't_step', 'key', 'value'))
|
cr.writerow(('agent_id', 't_step', 'key', 'value'))
|
||||||
for i in self.history_to_tuples():
|
for i in self.history_to_tuples():
|
||||||
cr.writerow(i)
|
cr.writerow(i)
|
||||||
|
|
||||||
def dump_gexf(self, dir_path=None):
|
def dump_gexf(self, f):
|
||||||
G = self.history_to_graph()
|
G = self.history_to_graph()
|
||||||
graph_path = os.path.join(self.get_path(dir_path),
|
|
||||||
self.name+".gexf")
|
|
||||||
# Workaround for geometric models
|
# Workaround for geometric models
|
||||||
# See soil/soil#4
|
# See soil/soil#4
|
||||||
for node in G.nodes():
|
for node in G.nodes():
|
||||||
@@ -255,9 +235,9 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
G.node[node]['viz'] = {"position": {"x": G.node[node]['pos'][0], "y": G.node[node]['pos'][1], "z": 0.0}}
|
G.node[node]['viz'] = {"position": {"x": G.node[node]['pos'][0], "y": G.node[node]['pos'][1], "z": 0.0}}
|
||||||
del (G.node[node]['pos'])
|
del (G.node[node]['pos'])
|
||||||
|
|
||||||
nx.write_gexf(G, graph_path, version="1.2draft")
|
nx.write_gexf(G, f, version="1.2draft")
|
||||||
|
|
||||||
def dump(self, dir_path=None, formats=None):
|
def dump(self, *args, formats=None, **kwargs):
|
||||||
if not formats:
|
if not formats:
|
||||||
return
|
return
|
||||||
functions = {
|
functions = {
|
||||||
@@ -266,10 +246,13 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
}
|
}
|
||||||
for f in formats:
|
for f in formats:
|
||||||
if f in functions:
|
if f in functions:
|
||||||
functions[f](dir_path)
|
functions[f](*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Unknown format: {}'.format(f))
|
raise ValueError('Unknown format: {}'.format(f))
|
||||||
|
|
||||||
|
def dump_sqlite(self, f):
|
||||||
|
return self._history.dump(f)
|
||||||
|
|
||||||
def state_to_tuples(self, now=None):
|
def state_to_tuples(self, now=None):
|
||||||
if now is None:
|
if now is None:
|
||||||
now = self.now
|
now = self.now
|
||||||
@@ -351,7 +334,7 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
|
|
||||||
def log_stats(self):
|
def log_stats(self):
|
||||||
stats = self.stats()
|
stats = self.stats()
|
||||||
utils.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
|
serialization.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = {}
|
state = {}
|
||||||
|
175
soil/exporters.py
Normal file
175
soil/exporters.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import networkx as nx
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from .serialization import deserialize
|
||||||
|
from .utils import open_or_reuse, logger, timer
|
||||||
|
|
||||||
|
|
||||||
|
from . import utils
|
||||||
|
|
||||||
|
|
||||||
|
def for_sim(simulation, names, *args, **kwargs):
|
||||||
|
'''Return the set of exporters for a simulation, given the exporter names'''
|
||||||
|
exporters = []
|
||||||
|
for name in names:
|
||||||
|
mod = deserialize(name, known_modules=['soil.exporters'])
|
||||||
|
exporters.append(mod(simulation, *args, **kwargs))
|
||||||
|
return exporters
|
||||||
|
|
||||||
|
|
||||||
|
class DryRunner(BytesIO):
|
||||||
|
def __init__(self, fname, *args, copy_to=None, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.__fname = fname
|
||||||
|
self.__copy_to = copy_to
|
||||||
|
|
||||||
|
def write(self, txt):
|
||||||
|
if self.__copy_to:
|
||||||
|
self.__copy_to.write('{}:::{}'.format(self.__fname, txt))
|
||||||
|
try:
|
||||||
|
super().write(txt)
|
||||||
|
except TypeError:
|
||||||
|
super().write(bytes(txt, 'utf-8'))
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname,
|
||||||
|
self.getvalue().decode()))
|
||||||
|
super().close()
|
||||||
|
|
||||||
|
|
||||||
|
class Exporter:
|
||||||
|
'''
|
||||||
|
Interface for all exporters. It is not necessary, but it is useful
|
||||||
|
if you don't plan to implement all the methods.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
|
||||||
|
self.sim = simulation
|
||||||
|
outdir = outdir or os.getcwd()
|
||||||
|
self.outdir = os.path.join(outdir,
|
||||||
|
simulation.group or '',
|
||||||
|
simulation.name)
|
||||||
|
self.dry_run = dry_run
|
||||||
|
self.copy_to = copy_to
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
'''Method to call when the simulation starts'''
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
'''Method to call when the simulation ends'''
|
||||||
|
|
||||||
|
def trial_end(self, env):
|
||||||
|
'''Method to call when a trial ends'''
|
||||||
|
|
||||||
|
def output(self, f, mode='w', **kwargs):
|
||||||
|
if self.dry_run:
|
||||||
|
f = DryRunner(f, copy_to=self.copy_to)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
if not os.path.isabs(f):
|
||||||
|
f = os.path.join(self.outdir, f)
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
return open_or_reuse(f, mode=mode, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Default(Exporter):
|
||||||
|
'''Default exporter. Writes CSV and sqlite results, as well as the simulation YAML'''
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if not self.dry_run:
|
||||||
|
logger.info('Dumping results to %s', self.outdir)
|
||||||
|
self.sim.dump_yaml(outdir=self.outdir)
|
||||||
|
else:
|
||||||
|
logger.info('NOT dumping results')
|
||||||
|
|
||||||
|
def trial_end(self, env):
|
||||||
|
if not self.dry_run:
|
||||||
|
with timer('Dumping simulation {} trial {}'.format(self.sim.name,
|
||||||
|
env.name)):
|
||||||
|
with self.output('{}.sqlite'.format(env.name), mode='wb') as f:
|
||||||
|
env.dump_sqlite(f)
|
||||||
|
|
||||||
|
|
||||||
|
class CSV(Exporter):
|
||||||
|
def trial_end(self, env):
|
||||||
|
if not self.dry_run:
|
||||||
|
with timer('[CSV] Dumping simulation {} trial {}'.format(self.sim.name,
|
||||||
|
env.name)):
|
||||||
|
with self.output('{}.csv'.format(env.name)) as f:
|
||||||
|
env.dump_csv(f)
|
||||||
|
|
||||||
|
|
||||||
|
class Gexf(Exporter):
|
||||||
|
def trial_end(self, env):
|
||||||
|
if not self.dry_run:
|
||||||
|
with timer('[CSV] Dumping simulation {} trial {}'.format(self.sim.name,
|
||||||
|
env.name)):
|
||||||
|
with self.output('{}.gexf'.format(env.name), mode='wb') as f:
|
||||||
|
env.dump_gexf(f)
|
||||||
|
|
||||||
|
|
||||||
|
class Dummy(Exporter):
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
with self.output('dummy', 'w') as f:
|
||||||
|
f.write('simulation started @ {}\n'.format(time.time()))
|
||||||
|
|
||||||
|
def trial_end(self, env):
|
||||||
|
with self.output('dummy', 'w') as f:
|
||||||
|
for i in env.history_to_tuples():
|
||||||
|
f.write(','.join(map(str, i)))
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
with self.output('dummy', 'a') as f:
|
||||||
|
f.write('simulation ended @ {}\n'.format(time.time()))
|
||||||
|
|
||||||
|
|
||||||
|
class Distribution(Exporter):
|
||||||
|
'''
|
||||||
|
Write the distribution of agent states at the end of each trial,
|
||||||
|
the mean value, and its deviation.
|
||||||
|
'''
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.means = []
|
||||||
|
self.counts = []
|
||||||
|
|
||||||
|
def trial_end(self, env):
|
||||||
|
df = env[None, None, None].df()
|
||||||
|
ix = df.index[-1]
|
||||||
|
attrs = df.columns.levels[0]
|
||||||
|
vc = {}
|
||||||
|
stats = {}
|
||||||
|
for a in attrs:
|
||||||
|
t = df.loc[(ix, a)]
|
||||||
|
try:
|
||||||
|
self.means.append(('mean', a, t.mean()))
|
||||||
|
except TypeError:
|
||||||
|
for name, count in t.value_counts().iteritems():
|
||||||
|
self.counts.append(('count', a, name, count))
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
|
||||||
|
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
|
||||||
|
dfm = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||||
|
dfc = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
||||||
|
with self.output('counts.csv') as f:
|
||||||
|
dfc.to_csv(f)
|
||||||
|
with self.output('metrics.csv') as f:
|
||||||
|
dfm.to_csv(f)
|
||||||
|
|
||||||
|
class GraphDrawing(Exporter):
|
||||||
|
|
||||||
|
def trial_end(self, env):
|
||||||
|
# Outside effects
|
||||||
|
f = plt.figure()
|
||||||
|
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
|
||||||
|
with open('graph-{}.png'.format(env.name)) as f:
|
||||||
|
f.savefig(f)
|
@@ -4,12 +4,13 @@ import pandas as pd
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import tempfile
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from collections import UserDict, namedtuple
|
from collections import UserDict, namedtuple
|
||||||
|
|
||||||
from . import utils
|
from . import serialization
|
||||||
|
|
||||||
|
|
||||||
class History:
|
class History:
|
||||||
@@ -17,16 +18,18 @@ class History:
|
|||||||
Store and retrieve values from a sqlite database.
|
Store and retrieve values from a sqlite database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, db_path=None, name=None, dir_path=None, backup=False):
|
def __init__(self, name=None, db_path=None, backup=False):
|
||||||
if db_path is None and name:
|
self._db = None
|
||||||
db_path = os.path.join(dir_path or os.getcwd(),
|
|
||||||
'{}.db.sqlite'.format(name))
|
if db_path is None:
|
||||||
if db_path:
|
if not name:
|
||||||
|
name = time.time()
|
||||||
|
_, db_path = tempfile.mkstemp(suffix='{}.sqlite'.format(name))
|
||||||
|
|
||||||
if backup and os.path.exists(db_path):
|
if backup and os.path.exists(db_path):
|
||||||
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
||||||
os.rename(db_path, newname)
|
os.rename(db_path, newname)
|
||||||
else:
|
|
||||||
db_path = ":memory:"
|
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
|
||||||
self.db = db_path
|
self.db = db_path
|
||||||
@@ -49,6 +52,7 @@ class History:
|
|||||||
|
|
||||||
@db.setter
|
@db.setter
|
||||||
def db(self, db_path=None):
|
def db(self, db_path=None):
|
||||||
|
self._close()
|
||||||
db_path = db_path or self.db_path
|
db_path = db_path or self.db_path
|
||||||
if isinstance(db_path, str):
|
if isinstance(db_path, str):
|
||||||
logger.debug('Connecting to database {}'.format(db_path))
|
logger.debug('Connecting to database {}'.format(db_path))
|
||||||
@@ -56,6 +60,13 @@ class History:
|
|||||||
else:
|
else:
|
||||||
self._db = db_path
|
self._db = db_path
|
||||||
|
|
||||||
|
def _close(self):
|
||||||
|
if self._db is None:
|
||||||
|
return
|
||||||
|
self.flush_cache()
|
||||||
|
self._db.close()
|
||||||
|
self._db = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtypes(self):
|
def dtypes(self):
|
||||||
self.read_types()
|
self.read_types()
|
||||||
@@ -94,9 +105,9 @@ class History:
|
|||||||
if key not in self._dtypes:
|
if key not in self._dtypes:
|
||||||
self.read_types()
|
self.read_types()
|
||||||
if key not in self._dtypes:
|
if key not in self._dtypes:
|
||||||
name = utils.name(value)
|
name = serialization.name(value)
|
||||||
serializer = utils.serializer(name)
|
serializer = serialization.serializer(name)
|
||||||
deserializer = utils.deserializer(name)
|
deserializer = serialization.deserializer(name)
|
||||||
self._dtypes[key] = (name, serializer, deserializer)
|
self._dtypes[key] = (name, serializer, deserializer)
|
||||||
with self.db:
|
with self.db:
|
||||||
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
|
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
|
||||||
@@ -110,7 +121,6 @@ class History:
|
|||||||
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
||||||
return self._dtypes[key][2](value)
|
return self._dtypes[key][2](value)
|
||||||
|
|
||||||
|
|
||||||
def flush_cache(self):
|
def flush_cache(self):
|
||||||
'''
|
'''
|
||||||
Use a cache to save state changes to avoid opening a session for every change.
|
Use a cache to save state changes to avoid opening a session for every change.
|
||||||
@@ -135,8 +145,8 @@ class History:
|
|||||||
with self.db:
|
with self.db:
|
||||||
res = self.db.execute("select key, value_type from value_types ").fetchall()
|
res = self.db.execute("select key, value_type from value_types ").fetchall()
|
||||||
for k, v in res:
|
for k, v in res:
|
||||||
serializer = utils.serializer(v)
|
serializer = serialization.serializer(v)
|
||||||
deserializer = utils.deserializer(v)
|
deserializer = serialization.deserializer(v)
|
||||||
self._dtypes[k] = (v, serializer, deserializer)
|
self._dtypes[k] = (v, serializer, deserializer)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
@@ -154,8 +164,6 @@ class History:
|
|||||||
return r.value()
|
return r.value()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
||||||
|
|
||||||
self.read_types()
|
self.read_types()
|
||||||
@@ -224,6 +232,12 @@ class History:
|
|||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
self.__dict__ = state
|
self.__dict__ = state
|
||||||
self._dtypes = {}
|
self._dtypes = {}
|
||||||
|
self._db = None
|
||||||
|
|
||||||
|
def dump(self, f):
|
||||||
|
self._close()
|
||||||
|
for line in open(self.db_path, 'rb'):
|
||||||
|
f.write(line)
|
||||||
|
|
||||||
|
|
||||||
class Records():
|
class Records():
|
||||||
@@ -274,10 +288,13 @@ class Records():
|
|||||||
i = self._df[f.key][str(f.agent_id)]
|
i = self._df[f.key][str(f.agent_id)]
|
||||||
ix = i.index.get_loc(f.t_step, method='ffill')
|
ix = i.index.get_loc(f.t_step, method='ffill')
|
||||||
return i.iloc[ix]
|
return i.iloc[ix]
|
||||||
except KeyError:
|
except KeyError as ex:
|
||||||
return self.dtypes[f.key][2]()
|
return self.dtypes[f.key][2]()
|
||||||
return list(self)
|
return list(self)
|
||||||
|
|
||||||
|
def df(self):
|
||||||
|
return self._df
|
||||||
|
|
||||||
def __getitem__(self, k):
|
def __getitem__(self, k):
|
||||||
n = copy.copy(self)
|
n = copy.copy(self)
|
||||||
n.filter(k)
|
n.filter(k)
|
||||||
@@ -293,6 +310,5 @@ class Records():
|
|||||||
return str(self.value())
|
return str(self.value())
|
||||||
return '<Records for [{}]>'.format(self._filter)
|
return '<Records for [{}]>'.format(self._filter)
|
||||||
|
|
||||||
|
|
||||||
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
||||||
Record = namedtuple('Record', 'agent_id t_step key value')
|
Record = namedtuple('Record', 'agent_id t_step key value')
|
||||||
|
201
soil/serialization.py
Normal file
201
soil/serialization.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import ast
|
||||||
|
import sys
|
||||||
|
import importlib
|
||||||
|
from glob import glob
|
||||||
|
from itertools import product, chain
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
|
from jinja2 import Template
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger('soil')
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def load_network(network_params, dir_path=None):
|
||||||
|
if network_params is None:
|
||||||
|
return nx.Graph()
|
||||||
|
path = network_params.get('path', None)
|
||||||
|
if path:
|
||||||
|
if dir_path and not os.path.isabs(path):
|
||||||
|
path = os.path.join(dir_path, path)
|
||||||
|
extension = os.path.splitext(path)[1][1:]
|
||||||
|
kwargs = {}
|
||||||
|
if extension == 'gexf':
|
||||||
|
kwargs['version'] = '1.2draft'
|
||||||
|
kwargs['node_type'] = int
|
||||||
|
try:
|
||||||
|
method = getattr(nx.readwrite, 'read_' + extension)
|
||||||
|
except AttributeError:
|
||||||
|
raise AttributeError('Unknown format')
|
||||||
|
return method(path, **kwargs)
|
||||||
|
|
||||||
|
net_args = network_params.copy()
|
||||||
|
if 'generator' not in net_args:
|
||||||
|
return nx.Graph()
|
||||||
|
|
||||||
|
net_gen = net_args.pop('generator')
|
||||||
|
|
||||||
|
if dir_path not in sys.path:
|
||||||
|
sys.path.append(dir_path)
|
||||||
|
|
||||||
|
method = deserializer(net_gen,
|
||||||
|
known_modules=['networkx.generators',])
|
||||||
|
|
||||||
|
return method(**net_args)
|
||||||
|
|
||||||
|
|
||||||
|
def load_file(infile):
|
||||||
|
with open(infile, 'r') as f:
|
||||||
|
return list(chain.from_iterable(map(expand_template, load_string(f))))
|
||||||
|
|
||||||
|
|
||||||
|
def load_string(string):
|
||||||
|
yield from yaml.load_all(string, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_template(config):
|
||||||
|
if 'template' not in config:
|
||||||
|
yield config
|
||||||
|
return
|
||||||
|
if 'vars' not in config:
|
||||||
|
raise ValueError(('You must provide a definition of variables'
|
||||||
|
' for the template.'))
|
||||||
|
|
||||||
|
template = Template(config['template'])
|
||||||
|
|
||||||
|
sampler_name = config.get('sampler', 'SALib.sample.morris.sample')
|
||||||
|
n_samples = int(config.get('samples', 100))
|
||||||
|
sampler = deserializer(sampler_name)
|
||||||
|
bounds = config['vars']['bounds']
|
||||||
|
|
||||||
|
problem = {
|
||||||
|
'num_vars': len(bounds),
|
||||||
|
'names': list(bounds.keys()),
|
||||||
|
'bounds': list(v for v in bounds.values())
|
||||||
|
}
|
||||||
|
samples = sampler(problem, n_samples)
|
||||||
|
|
||||||
|
lists = config['vars'].get('lists', {})
|
||||||
|
names = list(lists.keys())
|
||||||
|
values = list(lists.values())
|
||||||
|
combs = list(product(*values))
|
||||||
|
|
||||||
|
allnames = names + problem['names']
|
||||||
|
allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
|
||||||
|
params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
|
||||||
|
|
||||||
|
|
||||||
|
blank_str = template.render({k: 0 for k in allnames})
|
||||||
|
blank = list(load_string(blank_str))
|
||||||
|
if len(blank) > 1:
|
||||||
|
raise ValueError('Templates must not return more than one configuration')
|
||||||
|
if 'name' in blank[0]:
|
||||||
|
raise ValueError('Templates cannot be named, use group instead')
|
||||||
|
|
||||||
|
confs = []
|
||||||
|
for ps in params:
|
||||||
|
string = template.render(ps)
|
||||||
|
for c in load_string(string):
|
||||||
|
yield c
|
||||||
|
|
||||||
|
|
||||||
|
def load_files(*patterns, **kwargs):
|
||||||
|
for pattern in patterns:
|
||||||
|
for i in glob(pattern, **kwargs):
|
||||||
|
for config in load_file(i):
|
||||||
|
path = os.path.abspath(i)
|
||||||
|
if 'dir_path' not in config:
|
||||||
|
config['dir_path'] = os.path.dirname(path)
|
||||||
|
yield config, path
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config):
|
||||||
|
if isinstance(config, dict):
|
||||||
|
yield config, None
|
||||||
|
else:
|
||||||
|
yield from load_files(config)
|
||||||
|
|
||||||
|
|
||||||
|
builtins = importlib.import_module('builtins')
|
||||||
|
|
||||||
|
def name(value, known_modules=[]):
|
||||||
|
'''Return a name that can be imported, to serialize/deserialize an object'''
|
||||||
|
if value is None:
|
||||||
|
return 'None'
|
||||||
|
if not isinstance(value, type): # Get the class name first
|
||||||
|
value = type(value)
|
||||||
|
tname = value.__name__
|
||||||
|
if hasattr(builtins, tname):
|
||||||
|
return tname
|
||||||
|
modname = value.__module__
|
||||||
|
if modname == '__main__':
|
||||||
|
return tname
|
||||||
|
if known_modules and modname in known_modules:
|
||||||
|
return tname
|
||||||
|
for kmod in known_modules:
|
||||||
|
if not kmod:
|
||||||
|
continue
|
||||||
|
module = importlib.import_module(kmod)
|
||||||
|
if hasattr(module, tname):
|
||||||
|
return tname
|
||||||
|
return '{}.{}'.format(modname, tname)
|
||||||
|
|
||||||
|
|
||||||
|
def serializer(type_):
|
||||||
|
if type_ != 'str' and hasattr(builtins, type_):
|
||||||
|
return repr
|
||||||
|
return lambda x: x
|
||||||
|
|
||||||
|
|
||||||
|
def serialize(v, known_modules=[]):
|
||||||
|
'''Get a text representation of an object.'''
|
||||||
|
tname = name(v, known_modules=known_modules)
|
||||||
|
func = serializer(tname)
|
||||||
|
return func(v), tname
|
||||||
|
|
||||||
|
def deserializer(type_, known_modules=[]):
|
||||||
|
if type(type_) != str: # Already deserialized
|
||||||
|
return type_
|
||||||
|
if type_ == 'str':
|
||||||
|
return lambda x='': x
|
||||||
|
if type_ == 'None':
|
||||||
|
return lambda x=None: None
|
||||||
|
if hasattr(builtins, type_): # Check if it's a builtin type
|
||||||
|
cls = getattr(builtins, type_)
|
||||||
|
return lambda x=None: ast.literal_eval(x) if x is not None else cls()
|
||||||
|
# Otherwise, see if we can find the module and the class
|
||||||
|
modules = known_modules or []
|
||||||
|
options = []
|
||||||
|
|
||||||
|
for mod in modules:
|
||||||
|
if mod:
|
||||||
|
options.append((mod, type_))
|
||||||
|
|
||||||
|
if '.' in type_: # Fully qualified module
|
||||||
|
module, type_ = type_.rsplit(".", 1)
|
||||||
|
options.append ((module, type_))
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
for modname, tname in options:
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(modname)
|
||||||
|
cls = getattr(module, tname)
|
||||||
|
return getattr(cls, 'deserialize', cls)
|
||||||
|
except (ModuleNotFoundError, AttributeError) as ex:
|
||||||
|
errors.append((modname, tname, ex))
|
||||||
|
raise Exception('Could not find type {}. Tried: {}'.format(type_, errors))
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize(type_, value=None, **kwargs):
|
||||||
|
'''Get an object from a text representation'''
|
||||||
|
if not isinstance(type_, str):
|
||||||
|
return type_
|
||||||
|
des = deserializer(type_, **kwargs)
|
||||||
|
if value is None:
|
||||||
|
return des
|
||||||
|
return des(value)
|
@@ -13,9 +13,10 @@ import pickle
|
|||||||
|
|
||||||
from nxsim import NetworkSimulation
|
from nxsim import NetworkSimulation
|
||||||
|
|
||||||
from . import utils, basestring, agents
|
from . import serialization, utils, basestring, agents
|
||||||
from .environment import Environment
|
from .environment import Environment
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
|
from .exporters import for_sim as exporters_for_sim
|
||||||
|
|
||||||
|
|
||||||
class Simulation(NetworkSimulation):
|
class Simulation(NetworkSimulation):
|
||||||
@@ -50,6 +51,8 @@ class Simulation(NetworkSimulation):
|
|||||||
---------
|
---------
|
||||||
name : str, optional
|
name : str, optional
|
||||||
name of the Simulation
|
name of the Simulation
|
||||||
|
group : str, optional
|
||||||
|
a group name can be used to link simulations
|
||||||
topology : networkx.Graph instance, optional
|
topology : networkx.Graph instance, optional
|
||||||
network_params : dict
|
network_params : dict
|
||||||
parameters used to create a topology with networkx, if no topology is given
|
parameters used to create a topology with networkx, if no topology is given
|
||||||
@@ -61,7 +64,7 @@ class Simulation(NetworkSimulation):
|
|||||||
List of initial states corresponding to the nodes in the topology. Basic form is a list of integers
|
List of initial states corresponding to the nodes in the topology. Basic form is a list of integers
|
||||||
whose value indicates the state
|
whose value indicates the state
|
||||||
dir_path: str, optional
|
dir_path: str, optional
|
||||||
Directory path where to save pickled objects
|
Directory path to load simulation assets (files, modules...)
|
||||||
seed : str, optional
|
seed : str, optional
|
||||||
Seed to use for the random generator
|
Seed to use for the random generator
|
||||||
num_trials : int, optional
|
num_trials : int, optional
|
||||||
@@ -80,30 +83,29 @@ class Simulation(NetworkSimulation):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name=None, topology=None, network_params=None,
|
def __init__(self, name=None, group=None, topology=None, network_params=None,
|
||||||
network_agents=None, agent_type=None, states=None,
|
network_agents=None, agent_type=None, states=None,
|
||||||
default_state=None, interval=1, dump=None, dry_run=False,
|
default_state=None, interval=1, num_trials=1,
|
||||||
dir_path=None, num_trials=1, max_time=100,
|
max_time=100, load_module=None, seed=None,
|
||||||
load_module=None, seed=None,
|
dir_path=None, environment_agents=None,
|
||||||
environment_agents=None, environment_params=None,
|
environment_params=None, environment_class=None,
|
||||||
environment_class=None, **kwargs):
|
**kwargs):
|
||||||
|
|
||||||
self.seed = str(seed) or str(time.time())
|
self.seed = str(seed) or str(time.time())
|
||||||
self.load_module = load_module
|
self.load_module = load_module
|
||||||
self.network_params = network_params
|
self.network_params = network_params
|
||||||
self.name = name or 'UnnamedSimulation'
|
self.name = name or 'Unnamed_' + time.strftime("%Y-%m-%d_%H:%M:%S")
|
||||||
|
self.group = group or None
|
||||||
self.num_trials = num_trials
|
self.num_trials = num_trials
|
||||||
self.max_time = max_time
|
self.max_time = max_time
|
||||||
self.default_state = default_state or {}
|
self.default_state = default_state or {}
|
||||||
self.dir_path = dir_path or os.getcwd()
|
self.dir_path = dir_path or os.getcwd()
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.dump = dump
|
|
||||||
self.dry_run = dry_run
|
|
||||||
|
|
||||||
sys.path += [self.dir_path, os.getcwd()]
|
sys.path += list(x for x in [os.getcwd(), self.dir_path] if x not in sys.path)
|
||||||
|
|
||||||
if topology is None:
|
if topology is None:
|
||||||
topology = utils.load_network(network_params,
|
topology = serialization.load_network(network_params,
|
||||||
dir_path=self.dir_path)
|
dir_path=self.dir_path)
|
||||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
||||||
topology = json_graph.node_link_graph(topology)
|
topology = json_graph.node_link_graph(topology)
|
||||||
@@ -111,7 +113,7 @@ class Simulation(NetworkSimulation):
|
|||||||
|
|
||||||
|
|
||||||
self.environment_params = environment_params or {}
|
self.environment_params = environment_params or {}
|
||||||
self.environment_class = utils.deserialize(environment_class,
|
self.environment_class = serialization.deserialize(environment_class,
|
||||||
known_modules=['soil.environment', ]) or Environment
|
known_modules=['soil.environment', ]) or Environment
|
||||||
|
|
||||||
environment_agents = environment_agents or []
|
environment_agents = environment_agents or []
|
||||||
@@ -130,32 +132,51 @@ class Simulation(NetworkSimulation):
|
|||||||
return self.run(*args, **kwargs)
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
return list(self.run_simulation_gen(*args, **kwargs))
|
'''Run the simulation and return the list of resulting environments'''
|
||||||
|
return list(self._run_simulation_gen(*args, **kwargs))
|
||||||
|
|
||||||
def run_simulation_gen(self, *args, parallel=False, dry_run=False,
|
def _run_sync_or_async(self, parallel=False, *args, **kwargs):
|
||||||
**kwargs):
|
|
||||||
p = Pool()
|
|
||||||
with utils.timer('simulation {}'.format(self.name)):
|
|
||||||
if parallel:
|
if parallel:
|
||||||
func = partial(self.run_trial_exceptions, dry_run=dry_run or self.dry_run,
|
p = Pool()
|
||||||
return_env=True,
|
func = partial(self.run_trial_exceptions,
|
||||||
|
*args,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
for i in p.imap_unordered(func, range(self.num_trials)):
|
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||||
if isinstance(i, Exception):
|
if isinstance(i, Exception):
|
||||||
logger.error('Trial failed:\n\t{}'.format(i.message))
|
logger.error('Trial failed:\n\t%s', i.message)
|
||||||
continue
|
continue
|
||||||
yield i
|
yield i
|
||||||
else:
|
else:
|
||||||
for i in range(self.num_trials):
|
for i in range(self.num_trials):
|
||||||
yield self.run_trial(i, dry_run = dry_run or self.dry_run, **kwargs)
|
yield self.run_trial(i,
|
||||||
if not (dry_run or self.dry_run):
|
*args,
|
||||||
logger.info('Dumping results to {}'.format(self.dir_path))
|
**kwargs)
|
||||||
self.dump_pickle(self.dir_path)
|
|
||||||
self.dump_yaml(self.dir_path)
|
def _run_simulation_gen(self, *args, parallel=False, dry_run=False,
|
||||||
else:
|
exporters=None, outdir=None, exporter_params={}, **kwargs):
|
||||||
logger.info('NOT dumping results')
|
logger.info('Using exporters: %s', exporters or [])
|
||||||
|
logger.info('Output directory: %s', outdir)
|
||||||
|
exporters = exporters_for_sim(self,
|
||||||
|
exporters or [],
|
||||||
|
dry_run=dry_run,
|
||||||
|
outdir=outdir,
|
||||||
|
**exporter_params)
|
||||||
|
|
||||||
|
with utils.timer('simulation {}'.format(self.name)):
|
||||||
|
for exporter in exporters:
|
||||||
|
exporter.start()
|
||||||
|
|
||||||
|
for env in self._run_sync_or_async(*args, parallel=parallel,
|
||||||
|
**kwargs):
|
||||||
|
for exporter in exporters:
|
||||||
|
exporter.trial_end(env)
|
||||||
|
yield env
|
||||||
|
|
||||||
|
for exporter in exporters:
|
||||||
|
exporter.end()
|
||||||
|
|
||||||
def get_env(self, trial_id = 0, **kwargs):
|
def get_env(self, trial_id = 0, **kwargs):
|
||||||
|
'''Create an environment for a trial of the simulation'''
|
||||||
opts = self.environment_params.copy()
|
opts = self.environment_params.copy()
|
||||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
||||||
opts.update({
|
opts.update({
|
||||||
@@ -163,19 +184,17 @@ class Simulation(NetworkSimulation):
|
|||||||
'topology': self.topology.copy(),
|
'topology': self.topology.copy(),
|
||||||
'seed': self.seed+env_name,
|
'seed': self.seed+env_name,
|
||||||
'initial_time': 0,
|
'initial_time': 0,
|
||||||
'dry_run': self.dry_run,
|
|
||||||
'interval': self.interval,
|
'interval': self.interval,
|
||||||
'network_agents': self.network_agents,
|
'network_agents': self.network_agents,
|
||||||
'states': self.states,
|
'states': self.states,
|
||||||
'default_state': self.default_state,
|
'default_state': self.default_state,
|
||||||
'environment_agents': self.environment_agents,
|
'environment_agents': self.environment_agents,
|
||||||
'dir_path': self.dir_path,
|
|
||||||
})
|
})
|
||||||
opts.update(kwargs)
|
opts.update(kwargs)
|
||||||
env = self.environment_class(**opts)
|
env = self.environment_class(**opts)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def run_trial(self, trial_id = 0, until = None, return_env = True, **opts):
|
def run_trial(self, trial_id=0, until=None, **opts):
|
||||||
"""Run a single trial of the simulation
|
"""Run a single trial of the simulation
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -188,10 +207,6 @@ class Simulation(NetworkSimulation):
|
|||||||
# Set up agents on nodes
|
# Set up agents on nodes
|
||||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
env.run(until)
|
env.run(until)
|
||||||
if self.dump and not self.dry_run:
|
|
||||||
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
|
||||||
env.dump(formats = self.dump)
|
|
||||||
if return_env:
|
|
||||||
return env
|
return env
|
||||||
def run_trial_exceptions(self, *args, **kwargs):
|
def run_trial_exceptions(self, *args, **kwargs):
|
||||||
'''
|
'''
|
||||||
@@ -211,24 +226,25 @@ class Simulation(NetworkSimulation):
|
|||||||
def to_yaml(self):
|
def to_yaml(self):
|
||||||
return yaml.dump(self.to_dict())
|
return yaml.dump(self.to_dict())
|
||||||
|
|
||||||
def dump_yaml(self, dir_path = None, file_name = None):
|
|
||||||
dir_path=dir_path or self.dir_path
|
def dump_yaml(self, f=None, outdir=None):
|
||||||
if not os.path.exists(dir_path):
|
if not f and not outdir:
|
||||||
os.makedirs(dir_path)
|
raise ValueError('specify a file or an output directory')
|
||||||
if not file_name:
|
|
||||||
file_name=os.path.join(dir_path,
|
if not f:
|
||||||
'{}.dumped.yml'.format(self.name))
|
f = os.path.join(outdir, '{}.dumped.yml'.format(self.name))
|
||||||
with open(file_name, 'w') as f:
|
|
||||||
|
with utils.open_or_reuse(f, 'w') as f:
|
||||||
f.write(self.to_yaml())
|
f.write(self.to_yaml())
|
||||||
|
|
||||||
def dump_pickle(self, dir_path = None, pickle_name = None):
|
def dump_pickle(self, f=None, outdir=None):
|
||||||
dir_path=dir_path or self.dir_path
|
if not outdir and not f:
|
||||||
if not os.path.exists(dir_path):
|
raise ValueError('specify a file or an output directory')
|
||||||
os.makedirs(dir_path)
|
|
||||||
if not pickle_name:
|
if not f:
|
||||||
pickle_name=os.path.join(dir_path,
|
f = os.path.join(outdir,
|
||||||
'{}.simulation.pickle'.format(self.name))
|
'{}.simulation.pickle'.format(self.name))
|
||||||
with open(pickle_name, 'wb') as f:
|
with utils.open_or_reuse(f, 'wb') as f:
|
||||||
pickle.dump(self, f)
|
pickle.dump(self, f)
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
@@ -241,7 +257,7 @@ class Simulation(NetworkSimulation):
|
|||||||
known_modules = [])
|
known_modules = [])
|
||||||
state['environment_agents'] = agents.serialize_distribution(self.environment_agents,
|
state['environment_agents'] = agents.serialize_distribution(self.environment_agents,
|
||||||
known_modules = [])
|
known_modules = [])
|
||||||
state['environment_class']=utils.serialize(self.environment_class,
|
state['environment_class'] = serialization.serialize(self.environment_class,
|
||||||
known_modules=['soil.environment'])[1] # func, name
|
known_modules=['soil.environment'])[1] # func, name
|
||||||
if state['load_module'] is None:
|
if state['load_module'] is None:
|
||||||
del state['load_module']
|
del state['load_module']
|
||||||
@@ -256,13 +272,20 @@ class Simulation(NetworkSimulation):
|
|||||||
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
||||||
self.environment_agents = agents._convert_agent_types(self.environment_agents,
|
self.environment_agents = agents._convert_agent_types(self.environment_agents,
|
||||||
known_modules=[self.load_module])
|
known_modules=[self.load_module])
|
||||||
self.environment_class = utils.deserialize(self.environment_class,
|
self.environment_class = serialization.deserialize(self.environment_class,
|
||||||
known_modules=[self.load_module, 'soil.environment', ]) # func, name
|
known_modules=[self.load_module, 'soil.environment', ]) # func, name
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
def from_config(config):
|
def all_from_config(config):
|
||||||
config = list(utils.load_config(config))
|
configs = list(serialization.load_config(config))
|
||||||
|
for config, _ in configs:
|
||||||
|
sim = Simulation(**config)
|
||||||
|
yield sim
|
||||||
|
|
||||||
|
|
||||||
|
def from_config(conf_or_path):
|
||||||
|
config = list(serialization.load_config(conf_or_path))
|
||||||
if len(config) > 1:
|
if len(config) > 1:
|
||||||
raise AttributeError('Provide only one configuration')
|
raise AttributeError('Provide only one configuration')
|
||||||
config = config[0][0]
|
config = config[0][0]
|
||||||
@@ -270,21 +293,14 @@ def from_config(config):
|
|||||||
return sim
|
return sim
|
||||||
|
|
||||||
|
|
||||||
def run_from_config(*configs, results_dir='soil_output', dump=None, timestamp=False, **kwargs):
|
def run_from_config(*configs, **kwargs):
|
||||||
for config_def in configs:
|
for config_def in configs:
|
||||||
# logger.info("Found {} config(s)".format(len(ls)))
|
# logger.info("Found {} config(s)".format(len(ls)))
|
||||||
for config, _ in utils.load_config(config_def):
|
for config, path in serialization.load_config(config_def):
|
||||||
name = config.get('name', 'unnamed')
|
name = config.get('name', 'unnamed')
|
||||||
logger.info("Using config(s): {name}".format(name=name))
|
logger.info("Using config(s): {name}".format(name=name))
|
||||||
|
|
||||||
if timestamp:
|
dir_path = config.pop('dir_path', os.path.dirname(path))
|
||||||
sim_folder = '{}_{}'.format(name,
|
sim = Simulation(dir_path=dir_path,
|
||||||
time.strftime("%Y-%m-%d_%H:%M:%S"))
|
**config)
|
||||||
else:
|
|
||||||
sim_folder = name
|
|
||||||
dir_path = os.path.join(results_dir, sim_folder)
|
|
||||||
if dump is not None:
|
|
||||||
config['dump'] = dump
|
|
||||||
sim = Simulation(dir_path=dir_path, **config)
|
|
||||||
logger.info('Dumping results to {} : {}'.format(sim.dir_path, sim.dump))
|
|
||||||
sim.run_simulation(**kwargs)
|
sim.run_simulation(**kwargs)
|
||||||
|
145
soil/utils.py
145
soil/utils.py
@@ -1,72 +1,13 @@
|
|||||||
import os
|
|
||||||
import ast
|
|
||||||
import sys
|
|
||||||
import yaml
|
|
||||||
import logging
|
import logging
|
||||||
import importlib
|
|
||||||
import time
|
import time
|
||||||
from glob import glob
|
import os
|
||||||
from random import random
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
import networkx as nx
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger('soil')
|
logger = logging.getLogger('soil')
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def load_network(network_params, dir_path=None):
|
|
||||||
if network_params is None:
|
|
||||||
return nx.Graph()
|
|
||||||
path = network_params.get('path', None)
|
|
||||||
if path:
|
|
||||||
if dir_path and not os.path.isabs(path):
|
|
||||||
path = os.path.join(dir_path, path)
|
|
||||||
extension = os.path.splitext(path)[1][1:]
|
|
||||||
kwargs = {}
|
|
||||||
if extension == 'gexf':
|
|
||||||
kwargs['version'] = '1.2draft'
|
|
||||||
kwargs['node_type'] = int
|
|
||||||
try:
|
|
||||||
method = getattr(nx.readwrite, 'read_' + extension)
|
|
||||||
except AttributeError:
|
|
||||||
raise AttributeError('Unknown format')
|
|
||||||
return method(path, **kwargs)
|
|
||||||
|
|
||||||
net_args = network_params.copy()
|
|
||||||
net_gen = net_args.pop('generator')
|
|
||||||
|
|
||||||
if dir_path not in sys.path:
|
|
||||||
sys.path.append(dir_path)
|
|
||||||
|
|
||||||
method = deserializer(net_gen,
|
|
||||||
known_modules=['networkx.generators',])
|
|
||||||
|
|
||||||
return method(**net_args)
|
|
||||||
|
|
||||||
|
|
||||||
def load_file(infile):
|
|
||||||
with open(infile, 'r') as f:
|
|
||||||
return list(yaml.load_all(f))
|
|
||||||
|
|
||||||
|
|
||||||
def load_files(*patterns):
|
|
||||||
for pattern in patterns:
|
|
||||||
for i in glob(pattern):
|
|
||||||
for config in load_file(i):
|
|
||||||
yield config, os.path.abspath(i)
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(config):
|
|
||||||
if isinstance(config, dict):
|
|
||||||
yield config, None
|
|
||||||
else:
|
|
||||||
yield from load_files(config)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -82,81 +23,15 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
|||||||
to_object.end = end
|
to_object.end = end
|
||||||
|
|
||||||
|
|
||||||
builtins = importlib.import_module('builtins')
|
def safe_open(path, *args, **kwargs):
|
||||||
|
outdir = os.path.dirname(path)
|
||||||
def name(value, known_modules=[]):
|
if outdir and not os.path.exists(outdir):
|
||||||
'''Return a name that can be imported, to serialize/deserialize an object'''
|
os.makedirs(outdir)
|
||||||
if value is None:
|
return open(path, *args, **kwargs)
|
||||||
return 'None'
|
|
||||||
if not isinstance(value, type): # Get the class name first
|
|
||||||
value = type(value)
|
|
||||||
tname = value.__name__
|
|
||||||
if hasattr(builtins, tname):
|
|
||||||
return tname
|
|
||||||
modname = value.__module__
|
|
||||||
if modname == '__main__':
|
|
||||||
return tname
|
|
||||||
if known_modules and modname in known_modules:
|
|
||||||
return tname
|
|
||||||
for kmod in known_modules:
|
|
||||||
if not kmod:
|
|
||||||
continue
|
|
||||||
module = importlib.import_module(kmod)
|
|
||||||
if hasattr(module, tname):
|
|
||||||
return tname
|
|
||||||
return '{}.{}'.format(modname, tname)
|
|
||||||
|
|
||||||
|
|
||||||
def serializer(type_):
|
def open_or_reuse(f, *args, **kwargs):
|
||||||
if type_ != 'str' and hasattr(builtins, type_):
|
|
||||||
return repr
|
|
||||||
return lambda x: x
|
|
||||||
|
|
||||||
|
|
||||||
def serialize(v, known_modules=[]):
|
|
||||||
'''Get a text representation of an object.'''
|
|
||||||
tname = name(v, known_modules=known_modules)
|
|
||||||
func = serializer(tname)
|
|
||||||
return func(v), tname
|
|
||||||
|
|
||||||
def deserializer(type_, known_modules=[]):
|
|
||||||
if type(type_) != str: # Already deserialized
|
|
||||||
return type_
|
|
||||||
if type_ == 'str':
|
|
||||||
return lambda x='': x
|
|
||||||
if type_ == 'None':
|
|
||||||
return lambda x=None: None
|
|
||||||
if hasattr(builtins, type_): # Check if it's a builtin type
|
|
||||||
cls = getattr(builtins, type_)
|
|
||||||
return lambda x=None: ast.literal_eval(x) if x is not None else cls()
|
|
||||||
# Otherwise, see if we can find the module and the class
|
|
||||||
modules = known_modules or []
|
|
||||||
options = []
|
|
||||||
|
|
||||||
for mod in modules:
|
|
||||||
if mod:
|
|
||||||
options.append((mod, type_))
|
|
||||||
|
|
||||||
if '.' in type_: # Fully qualified module
|
|
||||||
module, type_ = type_.rsplit(".", 1)
|
|
||||||
options.append ((module, type_))
|
|
||||||
|
|
||||||
errors = []
|
|
||||||
for modname, tname in options:
|
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(modname)
|
return safe_open(f, *args, **kwargs)
|
||||||
cls = getattr(module, tname)
|
except TypeError:
|
||||||
return getattr(cls, 'deserialize', cls)
|
return f
|
||||||
except (ImportError, AttributeError) as ex:
|
|
||||||
errors.append((modname, tname, ex))
|
|
||||||
raise Exception('Could not find type {}. Tried: {}'.format(type_, errors))
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize(type_, value=None, **kwargs):
|
|
||||||
'''Get an object from a text representation'''
|
|
||||||
if not isinstance(type_, str):
|
|
||||||
return type_
|
|
||||||
des = deserializer(type_, **kwargs)
|
|
||||||
if value is None:
|
|
||||||
return des
|
|
||||||
return des(value)
|
|
||||||
|
@@ -1,255 +0,0 @@
|
|||||||
import random
|
|
||||||
import networkx as nx
|
|
||||||
from soil.agents import BaseAgent, FSM, state, default_state
|
|
||||||
from scipy.spatial import cKDTree as KDTree
|
|
||||||
|
|
||||||
global betweenness_centrality_global
|
|
||||||
global degree_centrality_global
|
|
||||||
|
|
||||||
betweenness_centrality_global = None
|
|
||||||
degree_centrality_global = None
|
|
||||||
|
|
||||||
class TerroristSpreadModel(FSM):
|
|
||||||
"""
|
|
||||||
Settings:
|
|
||||||
information_spread_intensity
|
|
||||||
|
|
||||||
terrorist_additional_influence
|
|
||||||
|
|
||||||
min_vulnerability (optional else zero)
|
|
||||||
|
|
||||||
max_vulnerability
|
|
||||||
|
|
||||||
prob_interaction
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, environment=None, agent_id=0, state=()):
|
|
||||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
|
||||||
|
|
||||||
global betweenness_centrality_global
|
|
||||||
global degree_centrality_global
|
|
||||||
|
|
||||||
if betweenness_centrality_global == None:
|
|
||||||
betweenness_centrality_global = nx.betweenness_centrality(self.global_topology)
|
|
||||||
if degree_centrality_global == None:
|
|
||||||
degree_centrality_global = nx.degree_centrality(self.global_topology)
|
|
||||||
|
|
||||||
self.information_spread_intensity = environment.environment_params['information_spread_intensity']
|
|
||||||
self.terrorist_additional_influence = environment.environment_params['terrorist_additional_influence']
|
|
||||||
self.prob_interaction = environment.environment_params['prob_interaction']
|
|
||||||
|
|
||||||
if self['id'] == self.civilian.id: # Civilian
|
|
||||||
self.initial_belief = random.uniform(0.00, 0.5)
|
|
||||||
elif self['id'] == self.terrorist.id: # Terrorist
|
|
||||||
self.initial_belief = random.uniform(0.8, 1.00)
|
|
||||||
elif self['id'] == self.leader.id: # Leader
|
|
||||||
self.initial_belief = 1.00
|
|
||||||
else:
|
|
||||||
raise Exception('Invalid state id: {}'.format(self['id']))
|
|
||||||
|
|
||||||
if 'min_vulnerability' in environment.environment_params:
|
|
||||||
self.vulnerability = random.uniform( environment.environment_params['min_vulnerability'], environment.environment_params['max_vulnerability'] )
|
|
||||||
else :
|
|
||||||
self.vulnerability = random.uniform( 0, environment.environment_params['max_vulnerability'] )
|
|
||||||
|
|
||||||
self.mean_belief = self.initial_belief
|
|
||||||
self.betweenness_centrality = betweenness_centrality_global[self.id]
|
|
||||||
self.degree_centrality = degree_centrality_global[self.id]
|
|
||||||
|
|
||||||
# self.state['radicalism'] = self.mean_belief
|
|
||||||
|
|
||||||
def count_neighboring_agents(self, state_id=None):
|
|
||||||
if isinstance(state_id, list):
|
|
||||||
return len(self.get_neighboring_agents(state_id))
|
|
||||||
else:
|
|
||||||
return len(super().get_agents(state_id, limit_neighbors=True))
|
|
||||||
|
|
||||||
def get_neighboring_agents(self, state_id=None):
|
|
||||||
if isinstance(state_id, list):
|
|
||||||
_list = []
|
|
||||||
for i in state_id:
|
|
||||||
_list += super().get_agents(i, limit_neighbors=True)
|
|
||||||
return [ neighbour for neighbour in _list if isinstance(neighbour, TerroristSpreadModel) ]
|
|
||||||
else:
|
|
||||||
_list = super().get_agents(state_id, limit_neighbors=True)
|
|
||||||
return [ neighbour for neighbour in _list if isinstance(neighbour, TerroristSpreadModel) ]
|
|
||||||
|
|
||||||
@state
|
|
||||||
def civilian(self):
|
|
||||||
if self.count_neighboring_agents() > 0:
|
|
||||||
neighbours = []
|
|
||||||
for neighbour in self.get_neighboring_agents():
|
|
||||||
if random.random() < self.prob_interaction:
|
|
||||||
neighbours.append(neighbour)
|
|
||||||
influence = sum( neighbour.degree_centrality for neighbour in neighbours )
|
|
||||||
mean_belief = sum( neighbour.mean_belief * neighbour.degree_centrality / influence for neighbour in neighbours )
|
|
||||||
self.initial_belief = self.mean_belief
|
|
||||||
mean_belief = mean_belief * self.information_spread_intensity + self.initial_belief * ( 1 - self.information_spread_intensity )
|
|
||||||
self.mean_belief = mean_belief * self.vulnerability + self.initial_belief * ( 1 - self.vulnerability )
|
|
||||||
|
|
||||||
if self.mean_belief >= 0.8:
|
|
||||||
return self.terrorist
|
|
||||||
|
|
||||||
@state
|
|
||||||
def leader(self):
|
|
||||||
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
|
||||||
if self.count_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]) > 0:
|
|
||||||
for neighbour in self.get_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]):
|
|
||||||
if neighbour.betweenness_centrality > self.betweenness_centrality:
|
|
||||||
return self.terrorist
|
|
||||||
|
|
||||||
@state
|
|
||||||
def terrorist(self):
|
|
||||||
if self.count_neighboring_agents(state_id=[self.terrorist.id, self.leader.id]) > 0:
|
|
||||||
neighbours = self.get_neighboring_agents(state_id=[self.terrorist.id, self.leader.id])
|
|
||||||
influence = sum( neighbour.degree_centrality for neighbour in neighbours )
|
|
||||||
mean_belief = sum( neighbour.mean_belief * neighbour.degree_centrality / influence for neighbour in neighbours )
|
|
||||||
self.initial_belief = self.mean_belief
|
|
||||||
self.mean_belief = mean_belief * self.vulnerability + self.initial_belief * ( 1 - self.vulnerability )
|
|
||||||
self.mean_belief = self.mean_belief ** ( 1 - self.terrorist_additional_influence )
|
|
||||||
|
|
||||||
if self.count_neighboring_agents(state_id=self.leader.id) == 0 and self.count_neighboring_agents(state_id=self.terrorist.id) > 0:
|
|
||||||
max_betweenness_centrality = self
|
|
||||||
for neighbour in self.get_neighboring_agents(state_id=self.terrorist.id):
|
|
||||||
if neighbour.betweenness_centrality > max_betweenness_centrality.betweenness_centrality:
|
|
||||||
max_betweenness_centrality = neighbour
|
|
||||||
if max_betweenness_centrality == self:
|
|
||||||
return self.leader
|
|
||||||
|
|
||||||
def add_edge(self, G, source, target):
|
|
||||||
G.add_edge(source.id, target.id, start=self.env._now)
|
|
||||||
|
|
||||||
def link_search(self, G, node, radius):
|
|
||||||
pos = nx.get_node_attributes(G, 'pos')
|
|
||||||
nodes, coords = list(zip(*pos.items()))
|
|
||||||
kdtree = KDTree(coords) # Cannot provide generator.
|
|
||||||
edge_indexes = kdtree.query_pairs(radius, 2)
|
|
||||||
_list = [ edge[int(not edge.index(node))] for edge in edge_indexes if node in edge ]
|
|
||||||
return [ G.nodes()[index]['agent'] for index in _list ]
|
|
||||||
|
|
||||||
def social_search(self, G, node, steps):
|
|
||||||
nodes = list(nx.ego_graph(G, node, radius=steps).nodes())
|
|
||||||
nodes.remove(node)
|
|
||||||
return [ G.nodes()[index]['agent'] for index in nodes ]
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingAreaModel(FSM):
|
|
||||||
"""
|
|
||||||
Settings:
|
|
||||||
training_influence
|
|
||||||
|
|
||||||
min_vulnerability
|
|
||||||
|
|
||||||
Requires TerroristSpreadModel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, environment=None, agent_id=0, state=()):
|
|
||||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
|
||||||
self.training_influence = environment.environment_params['training_influence']
|
|
||||||
if 'min_vulnerability' in environment.environment_params:
|
|
||||||
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
|
||||||
else: self.min_vulnerability = 0
|
|
||||||
|
|
||||||
@default_state
|
|
||||||
@state
|
|
||||||
def terrorist(self):
|
|
||||||
for neighbour in self.get_neighboring_agents():
|
|
||||||
if isinstance(neighbour, TerroristSpreadModel) and neighbour.vulnerability > self.min_vulnerability:
|
|
||||||
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.training_influence )
|
|
||||||
|
|
||||||
|
|
||||||
class HavenModel(FSM):
|
|
||||||
"""
|
|
||||||
Settings:
|
|
||||||
haven_influence
|
|
||||||
|
|
||||||
min_vulnerability
|
|
||||||
|
|
||||||
max_vulnerability
|
|
||||||
|
|
||||||
Requires TerroristSpreadModel.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, environment=None, agent_id=0, state=()):
|
|
||||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
|
||||||
self.haven_influence = environment.environment_params['haven_influence']
|
|
||||||
if 'min_vulnerability' in environment.environment_params:
|
|
||||||
self.min_vulnerability = environment.environment_params['min_vulnerability']
|
|
||||||
else: self.min_vulnerability = 0
|
|
||||||
self.max_vulnerability = environment.environment_params['max_vulnerability']
|
|
||||||
|
|
||||||
@state
|
|
||||||
def civilian(self):
|
|
||||||
for neighbour_agent in self.get_neighboring_agents():
|
|
||||||
if isinstance(neighbour_agent, TerroristSpreadModel) and neighbour_agent['id'] == neighbour_agent.civilian.id:
|
|
||||||
for neighbour in self.get_neighboring_agents():
|
|
||||||
if isinstance(neighbour, TerroristSpreadModel) and neighbour.vulnerability > self.min_vulnerability:
|
|
||||||
neighbour.vulnerability = neighbour.vulnerability * ( 1 - self.haven_influence )
|
|
||||||
return self.civilian
|
|
||||||
return self.terrorist
|
|
||||||
|
|
||||||
@state
|
|
||||||
def terrorist(self):
|
|
||||||
for neighbour in self.get_neighboring_agents():
|
|
||||||
if isinstance(neighbour, TerroristSpreadModel) and neighbour.vulnerability < self.max_vulnerability:
|
|
||||||
neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.haven_influence )
|
|
||||||
return self.terrorist
|
|
||||||
|
|
||||||
|
|
||||||
class TerroristNetworkModel(TerroristSpreadModel):
|
|
||||||
"""
|
|
||||||
Settings:
|
|
||||||
sphere_influence
|
|
||||||
|
|
||||||
vision_range
|
|
||||||
|
|
||||||
weight_social_distance
|
|
||||||
|
|
||||||
weight_link_distance
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, environment=None, agent_id=0, state=()):
|
|
||||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
|
||||||
|
|
||||||
self.vision_range = environment.environment_params['vision_range']
|
|
||||||
self.sphere_influence = environment.environment_params['sphere_influence']
|
|
||||||
self.weight_social_distance = environment.environment_params['weight_social_distance']
|
|
||||||
self.weight_link_distance = environment.environment_params['weight_link_distance']
|
|
||||||
|
|
||||||
@state
|
|
||||||
def terrorist(self):
|
|
||||||
self.update_relationships()
|
|
||||||
return super().terrorist()
|
|
||||||
|
|
||||||
@state
|
|
||||||
def leader(self):
|
|
||||||
self.update_relationships()
|
|
||||||
return super().leader()
|
|
||||||
|
|
||||||
def update_relationships(self):
|
|
||||||
if self.count_neighboring_agents(state_id=self.civilian.id) == 0:
|
|
||||||
close_ups = self.link_search(self.global_topology, self.id, self.vision_range)
|
|
||||||
step_neighbours = self.social_search(self.global_topology, self.id, self.sphere_influence)
|
|
||||||
search = list(set(close_ups).union(step_neighbours))
|
|
||||||
neighbours = self.get_neighboring_agents()
|
|
||||||
search = [item for item in search if not item in neighbours and isinstance(item, TerroristNetworkModel)]
|
|
||||||
for agent in search:
|
|
||||||
social_distance = 1 / self.shortest_path_length(self.global_topology, self.id, agent.id)
|
|
||||||
spatial_proximity = ( 1 - self.get_distance(self.global_topology, self.id, agent.id) )
|
|
||||||
prob_new_interaction = self.weight_social_distance * social_distance + self.weight_link_distance * spatial_proximity
|
|
||||||
if agent['id'] == agent.civilian.id and random.random() < prob_new_interaction:
|
|
||||||
self.add_edge(self.global_topology, self, agent)
|
|
||||||
break
|
|
||||||
|
|
||||||
def get_distance(self, G, source, target):
|
|
||||||
source_x, source_y = nx.get_node_attributes(G, 'pos')[source]
|
|
||||||
target_x, target_y = nx.get_node_attributes(G, 'pos')[target]
|
|
||||||
dx = abs( source_x - target_x )
|
|
||||||
dy = abs( source_y - target_y )
|
|
||||||
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
|
||||||
|
|
||||||
def shortest_path_length(self, G, source, target):
|
|
||||||
try:
|
|
||||||
return nx.shortest_path_length(G, source, target)
|
|
||||||
except nx.NetworkXNoPath:
|
|
||||||
return float('inf')
|
|
@@ -180,7 +180,7 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
|||||||
with self.logging(self.simulation_name):
|
with self.logging(self.simulation_name):
|
||||||
try:
|
try:
|
||||||
config = dict(**self.config)
|
config = dict(**self.config)
|
||||||
config['dir_path'] = os.path.join(self.application.dir_path, config['name'])
|
config['outdir'] = os.path.join(self.application.outdir, config['name'])
|
||||||
config['dump'] = self.application.dump
|
config['dump'] = self.application.dump
|
||||||
self.trials = yield self.nonblocking(config)
|
self.trials = yield self.nonblocking(config)
|
||||||
|
|
||||||
@@ -232,12 +232,12 @@ class ModularServer(tornado.web.Application):
|
|||||||
settings = {'debug': True,
|
settings = {'debug': True,
|
||||||
'template_path': ROOT + '/templates'}
|
'template_path': ROOT + '/templates'}
|
||||||
|
|
||||||
def __init__(self, dump=False, dir_path='output', name='SOIL', verbose=True, *args, **kwargs):
|
def __init__(self, dump=False, outdir='output', name='SOIL', verbose=True, *args, **kwargs):
|
||||||
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.name = name
|
self.name = name
|
||||||
self.dump = dump
|
self.dump = dump
|
||||||
self.dir_path = dir_path
|
self.outdir = outdir
|
||||||
|
|
||||||
# Initializing the application itself:
|
# Initializing the application itself:
|
||||||
super().__init__(self.handlers, **self.settings)
|
super().__init__(self.handlers, **self.settings)
|
||||||
|
@@ -39,7 +39,6 @@ class TestAnalysis(TestCase):
|
|||||||
agent should be able to update its state."""
|
agent should be able to update its state."""
|
||||||
config = {
|
config = {
|
||||||
'name': 'analysis',
|
'name': 'analysis',
|
||||||
'dry_run': True,
|
|
||||||
'seed': 'seed',
|
'seed': 'seed',
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'generator': 'complete_graph',
|
'generator': 'complete_graph',
|
||||||
@@ -53,7 +52,7 @@ class TestAnalysis(TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
self.env = s.run_simulation()[0]
|
self.env = s.run_simulation(dry_run=True)[0]
|
||||||
|
|
||||||
def test_saved(self):
|
def test_saved(self):
|
||||||
env = self.env
|
env = self.env
|
||||||
@@ -65,7 +64,7 @@ class TestAnalysis(TestCase):
|
|||||||
|
|
||||||
def test_count(self):
|
def test_count(self):
|
||||||
env = self.env
|
env = self.env
|
||||||
df = analysis.read_sql(env._history._db)
|
df = analysis.read_sql(env._history.db_path)
|
||||||
res = analysis.get_count(df, 'SEED', 'id')
|
res = analysis.get_count(df, 'SEED', 'id')
|
||||||
assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
|
assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1
|
||||||
assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
|
assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1
|
||||||
|
@@ -2,11 +2,13 @@ from unittest import TestCase
|
|||||||
import os
|
import os
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from soil import utils, simulation
|
from soil import serialization, simulation
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
EXAMPLES = join(ROOT, '..', 'examples')
|
||||||
|
|
||||||
|
FORCE_TESTS = os.environ.get('FORCE_TESTS', '')
|
||||||
|
|
||||||
|
|
||||||
class TestExamples(TestCase):
|
class TestExamples(TestCase):
|
||||||
pass
|
pass
|
||||||
@@ -15,11 +17,13 @@ class TestExamples(TestCase):
|
|||||||
def make_example_test(path, config):
|
def make_example_test(path, config):
|
||||||
def wrapped(self):
|
def wrapped(self):
|
||||||
root = os.getcwd()
|
root = os.getcwd()
|
||||||
os.chdir(os.path.dirname(path))
|
for s in simulation.all_from_config(path):
|
||||||
s = simulation.from_config(config)
|
|
||||||
iterations = s.max_time * s.num_trials
|
iterations = s.max_time * s.num_trials
|
||||||
if iterations > 1000:
|
if iterations > 1000:
|
||||||
self.skipTest('This example would probably take too long')
|
s.max_time = 100
|
||||||
|
s.num_trials = 1
|
||||||
|
if config.get('skip_test', False) and not FORCE_TESTS:
|
||||||
|
self.skipTest('Example ignored.')
|
||||||
envs = s.run_simulation(dry_run=True)
|
envs = s.run_simulation(dry_run=True)
|
||||||
assert envs
|
assert envs
|
||||||
for env in envs:
|
for env in envs:
|
||||||
@@ -31,12 +35,14 @@ def make_example_test(path, config):
|
|||||||
assert env.now <= config['max_time'] # But not further than allowed
|
assert env.now <= config['max_time'] # But not further than allowed
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
os.chdir(root)
|
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def add_example_tests():
|
def add_example_tests():
|
||||||
for config, path in utils.load_config(join(EXAMPLES, '**', '*.yml')):
|
for config, path in serialization.load_files(
|
||||||
|
join(EXAMPLES, '*', '*.yml'),
|
||||||
|
join(EXAMPLES, '*.yml'),
|
||||||
|
):
|
||||||
p = make_example_test(path=path, config=config)
|
p = make_example_test(path=path, config=config)
|
||||||
fname = os.path.basename(path)
|
fname = os.path.basename(path)
|
||||||
p.__name__ = 'test_example_file_%s' % fname
|
p.__name__ = 'test_example_file_%s' % fname
|
||||||
|
110
tests/test_exporters.py
Normal file
110
tests/test_exporters.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import os
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
from unittest import TestCase
|
||||||
|
from soil import exporters
|
||||||
|
from soil.utils import safe_open
|
||||||
|
from soil import simulation
|
||||||
|
|
||||||
|
|
||||||
|
class Dummy(exporters.Exporter):
|
||||||
|
started = False
|
||||||
|
trials = 0
|
||||||
|
ended = False
|
||||||
|
total_time = 0
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.__class__.started = True
|
||||||
|
|
||||||
|
def trial_end(self, env):
|
||||||
|
assert env
|
||||||
|
self.__class__.trials += 1
|
||||||
|
self.__class__.total_time += env.now
|
||||||
|
|
||||||
|
def end(self):
|
||||||
|
self.__class__.ended = True
|
||||||
|
|
||||||
|
|
||||||
|
class Exporters(TestCase):
|
||||||
|
def test_basic(self):
|
||||||
|
config = {
|
||||||
|
'name': 'exporter_sim',
|
||||||
|
'network_params': {},
|
||||||
|
'agent_type': 'CounterModel',
|
||||||
|
'max_time': 2,
|
||||||
|
'num_trials': 5,
|
||||||
|
'environment_params': {}
|
||||||
|
}
|
||||||
|
s = simulation.from_config(config)
|
||||||
|
s.run_simulation(exporters=[Dummy], dry_run=True)
|
||||||
|
assert Dummy.started
|
||||||
|
assert Dummy.ended
|
||||||
|
assert Dummy.trials == 5
|
||||||
|
assert Dummy.total_time == 2*5
|
||||||
|
|
||||||
|
def test_distribution(self):
|
||||||
|
'''The distribution exporter should write the number of agents in each state'''
|
||||||
|
config = {
|
||||||
|
'name': 'exporter_sim',
|
||||||
|
'network_params': {
|
||||||
|
'generator': 'complete_graph',
|
||||||
|
'n': 4
|
||||||
|
},
|
||||||
|
'agent_type': 'CounterModel',
|
||||||
|
'max_time': 2,
|
||||||
|
'num_trials': 5,
|
||||||
|
'environment_params': {}
|
||||||
|
}
|
||||||
|
output = io.StringIO()
|
||||||
|
s = simulation.from_config(config)
|
||||||
|
s.run_simulation(exporters=[exporters.Distribution], dry_run=True, exporter_params={'copy_to': output})
|
||||||
|
result = output.getvalue()
|
||||||
|
assert 'count' in result
|
||||||
|
assert 'SEED,Noneexporter_sim_trial_3,1,,1,1,1,1' in result
|
||||||
|
|
||||||
|
def test_writing(self):
|
||||||
|
'''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
|
||||||
|
n_trials = 5
|
||||||
|
config = {
|
||||||
|
'name': 'exporter_sim',
|
||||||
|
'network_params': {
|
||||||
|
'generator': 'complete_graph',
|
||||||
|
'n': 4
|
||||||
|
},
|
||||||
|
'agent_type': 'CounterModel',
|
||||||
|
'max_time': 2,
|
||||||
|
'num_trials': n_trials,
|
||||||
|
'environment_params': {}
|
||||||
|
}
|
||||||
|
output = io.StringIO()
|
||||||
|
s = simulation.from_config(config)
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
envs = s.run_simulation(exporters=[
|
||||||
|
exporters.Default,
|
||||||
|
exporters.CSV,
|
||||||
|
exporters.Gexf,
|
||||||
|
exporters.Distribution,
|
||||||
|
],
|
||||||
|
outdir=tmpdir,
|
||||||
|
exporter_params={'copy_to': output})
|
||||||
|
result = output.getvalue()
|
||||||
|
|
||||||
|
simdir = os.path.join(tmpdir, s.group or '', s.name)
|
||||||
|
with open(os.path.join(simdir, '{}.dumped.yml'.format(s.name))) as f:
|
||||||
|
result = f.read()
|
||||||
|
assert result
|
||||||
|
|
||||||
|
try:
|
||||||
|
for e in envs:
|
||||||
|
with open(os.path.join(simdir, '{}.gexf'.format(e.name))) as f:
|
||||||
|
result = f.read()
|
||||||
|
assert result
|
||||||
|
|
||||||
|
with open(os.path.join(simdir, '{}.csv'.format(e.name))) as f:
|
||||||
|
result = f.read()
|
||||||
|
assert result
|
||||||
|
finally:
|
||||||
|
shutil.rmtree(tmpdir)
|
@@ -1,13 +1,15 @@
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import io
|
||||||
import yaml
|
import yaml
|
||||||
import pickle
|
import pickle
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from soil import simulation, Environment, agents, utils, history
|
from soil import (simulation, Environment, agents, serialization,
|
||||||
|
history, utils)
|
||||||
|
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
@@ -27,22 +29,20 @@ class TestMain(TestCase):
|
|||||||
Raise an exception otherwise.
|
Raise an exception otherwise.
|
||||||
"""
|
"""
|
||||||
config = {
|
config = {
|
||||||
'dry_run': True,
|
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
G = utils.load_network(config['network_params'])
|
G = serialization.load_network(config['network_params'])
|
||||||
assert G
|
assert G
|
||||||
assert len(G) == 2
|
assert len(G) == 2
|
||||||
with self.assertRaises(AttributeError):
|
with self.assertRaises(AttributeError):
|
||||||
config = {
|
config = {
|
||||||
'dry_run': True,
|
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'unknown.extension')
|
'path': join(ROOT, 'unknown.extension')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
G = utils.load_network(config['network_params'])
|
G = serialization.load_network(config['network_params'])
|
||||||
print(G)
|
print(G)
|
||||||
|
|
||||||
def test_generate_barabasi(self):
|
def test_generate_barabasi(self):
|
||||||
@@ -51,22 +51,20 @@ class TestMain(TestCase):
|
|||||||
should be used to generate a network
|
should be used to generate a network
|
||||||
"""
|
"""
|
||||||
config = {
|
config = {
|
||||||
'dry_run': True,
|
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'generator': 'barabasi_albert_graph'
|
'generator': 'barabasi_albert_graph'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
G = utils.load_network(config['network_params'])
|
G = serialization.load_network(config['network_params'])
|
||||||
config['network_params']['n'] = 100
|
config['network_params']['n'] = 100
|
||||||
config['network_params']['m'] = 10
|
config['network_params']['m'] = 10
|
||||||
G = utils.load_network(config['network_params'])
|
G = serialization.load_network(config['network_params'])
|
||||||
assert len(G) == 100
|
assert len(G) == 100
|
||||||
|
|
||||||
def test_empty_simulation(self):
|
def test_empty_simulation(self):
|
||||||
"""A simulation with a base behaviour should do nothing"""
|
"""A simulation with a base behaviour should do nothing"""
|
||||||
config = {
|
config = {
|
||||||
'dry_run': True,
|
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
@@ -83,7 +81,6 @@ class TestMain(TestCase):
|
|||||||
agent should be able to update its state."""
|
agent should be able to update its state."""
|
||||||
config = {
|
config = {
|
||||||
'name': 'CounterAgent',
|
'name': 'CounterAgent',
|
||||||
'dry_run': True,
|
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
@@ -107,7 +104,6 @@ class TestMain(TestCase):
|
|||||||
"""
|
"""
|
||||||
config = {
|
config = {
|
||||||
'name': 'CounterAgent',
|
'name': 'CounterAgent',
|
||||||
'dry_run': True,
|
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
@@ -133,7 +129,6 @@ class TestMain(TestCase):
|
|||||||
def test_custom_agent(self):
|
def test_custom_agent(self):
|
||||||
"""Allow for search of neighbors with a certain state_id"""
|
"""Allow for search of neighbors with a certain state_id"""
|
||||||
config = {
|
config = {
|
||||||
'dry_run': True,
|
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
},
|
||||||
@@ -153,12 +148,11 @@ class TestMain(TestCase):
|
|||||||
|
|
||||||
def test_torvalds_example(self):
|
def test_torvalds_example(self):
|
||||||
"""A complete example from a documentation should work."""
|
"""A complete example from a documentation should work."""
|
||||||
config = utils.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
|
config = serialization.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
|
||||||
config['network_params']['path'] = join(EXAMPLES,
|
config['network_params']['path'] = join(EXAMPLES,
|
||||||
config['network_params']['path'])
|
config['network_params']['path'])
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
s.dry_run = True
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
env = s.run_simulation()[0]
|
|
||||||
for a in env.network_agents:
|
for a in env.network_agents:
|
||||||
skill_level = a.state['skill_level']
|
skill_level = a.state['skill_level']
|
||||||
if a.id == 'Torvalds':
|
if a.id == 'Torvalds':
|
||||||
@@ -180,9 +174,8 @@ class TestMain(TestCase):
|
|||||||
should be equivalent to the configuration file used
|
should be equivalent to the configuration file used
|
||||||
"""
|
"""
|
||||||
with utils.timer('loading'):
|
with utils.timer('loading'):
|
||||||
config = utils.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
s.dry_run = True
|
|
||||||
with utils.timer('serializing'):
|
with utils.timer('serializing'):
|
||||||
serial = s.to_yaml()
|
serial = s.to_yaml()
|
||||||
with utils.timer('recovering'):
|
with utils.timer('recovering'):
|
||||||
@@ -196,9 +189,8 @@ class TestMain(TestCase):
|
|||||||
The configuration should not change after running
|
The configuration should not change after running
|
||||||
the simulation.
|
the simulation.
|
||||||
"""
|
"""
|
||||||
config = utils.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
s.dry_run = True
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
s.run_simulation(dry_run=True)
|
s.run_simulation(dry_run=True)
|
||||||
nconfig = s.to_dict()
|
nconfig = s.to_dict()
|
||||||
@@ -206,7 +198,7 @@ class TestMain(TestCase):
|
|||||||
assert config == nconfig
|
assert config == nconfig
|
||||||
|
|
||||||
def test_row_conversion(self):
|
def test_row_conversion(self):
|
||||||
env = Environment(dry_run=True)
|
env = Environment()
|
||||||
env['test'] = 'test_value'
|
env['test'] = 'test_value'
|
||||||
|
|
||||||
res = list(env.history_to_tuples())
|
res = list(env.history_to_tuples())
|
||||||
@@ -225,8 +217,9 @@ class TestMain(TestCase):
|
|||||||
from geometric models. We should work around it.
|
from geometric models. We should work around it.
|
||||||
"""
|
"""
|
||||||
G = nx.random_geometric_graph(20, 0.1)
|
G = nx.random_geometric_graph(20, 0.1)
|
||||||
env = Environment(topology=G, dry_run=True)
|
env = Environment(topology=G)
|
||||||
env.dump_gexf('/tmp/dump-gexf')
|
f = io.BytesIO()
|
||||||
|
env.dump_gexf(f)
|
||||||
|
|
||||||
def test_save_graph(self):
|
def test_save_graph(self):
|
||||||
'''
|
'''
|
||||||
@@ -236,7 +229,7 @@ class TestMain(TestCase):
|
|||||||
'''
|
'''
|
||||||
G = nx.cycle_graph(5)
|
G = nx.cycle_graph(5)
|
||||||
distribution = agents.calculate_distribution(None, agents.BaseAgent)
|
distribution = agents.calculate_distribution(None, agents.BaseAgent)
|
||||||
env = Environment(topology=G, network_agents=distribution, dry_run=True)
|
env = Environment(topology=G, network_agents=distribution)
|
||||||
env[0, 0, 'testvalue'] = 'start'
|
env[0, 0, 'testvalue'] = 'start'
|
||||||
env[0, 10, 'testvalue'] = 'finish'
|
env[0, 10, 'testvalue'] = 'finish'
|
||||||
nG = env.history_to_graph()
|
nG = env.history_to_graph()
|
||||||
@@ -245,11 +238,11 @@ class TestMain(TestCase):
|
|||||||
assert ('finish', 10, None) in values
|
assert ('finish', 10, None) in values
|
||||||
|
|
||||||
def test_serialize_class(self):
|
def test_serialize_class(self):
|
||||||
ser, name = utils.serialize(agents.BaseAgent)
|
ser, name = serialization.serialize(agents.BaseAgent)
|
||||||
assert name == 'soil.agents.BaseAgent'
|
assert name == 'soil.agents.BaseAgent'
|
||||||
assert ser == agents.BaseAgent
|
assert ser == agents.BaseAgent
|
||||||
|
|
||||||
ser, name = utils.serialize(CustomAgent)
|
ser, name = serialization.serialize(CustomAgent)
|
||||||
assert name == 'test_main.CustomAgent'
|
assert name == 'test_main.CustomAgent'
|
||||||
assert ser == CustomAgent
|
assert ser == CustomAgent
|
||||||
pickle.dumps(ser)
|
pickle.dumps(ser)
|
||||||
@@ -257,9 +250,9 @@ class TestMain(TestCase):
|
|||||||
def test_serialize_builtin_types(self):
|
def test_serialize_builtin_types(self):
|
||||||
|
|
||||||
for i in [1, None, True, False, {}, [], list(), dict()]:
|
for i in [1, None, True, False, {}, [], list(), dict()]:
|
||||||
ser, name = utils.serialize(i)
|
ser, name = serialization.serialize(i)
|
||||||
assert type(ser) == str
|
assert type(ser) == str
|
||||||
des = utils.deserialize(name, ser)
|
des = serialization.deserialize(name, ser)
|
||||||
assert i == des
|
assert i == des
|
||||||
|
|
||||||
def test_serialize_agent_type(self):
|
def test_serialize_agent_type(self):
|
||||||
@@ -312,11 +305,35 @@ class TestMain(TestCase):
|
|||||||
recovered = pickle.loads(pickled)
|
recovered = pickle.loads(pickled)
|
||||||
|
|
||||||
assert recovered.env.name == 'Test'
|
assert recovered.env.name == 'Test'
|
||||||
assert recovered['key'] == 'test'
|
assert list(recovered.env._history.to_tuples())
|
||||||
assert recovered['key', 0] == 'test'
|
assert recovered['key', 0] == 'test'
|
||||||
|
assert recovered['key'] == 'test'
|
||||||
|
|
||||||
def test_history(self):
|
def test_history(self):
|
||||||
'''Test storing in and retrieving from history (sqlite)'''
|
'''Test storing in and retrieving from history (sqlite)'''
|
||||||
h = history.History()
|
h = history.History()
|
||||||
h.save_record(agent_id=0, t_step=0, key="test", value="hello")
|
h.save_record(agent_id=0, t_step=0, key="test", value="hello")
|
||||||
assert h[0, 0, "test"] == "hello"
|
assert h[0, 0, "test"] == "hello"
|
||||||
|
|
||||||
|
def test_subgraph(self):
|
||||||
|
'''An agent should be able to subgraph the global topology'''
|
||||||
|
G = nx.Graph()
|
||||||
|
G.add_node(3)
|
||||||
|
G.add_edge(1, 2)
|
||||||
|
distro = agents.calculate_distribution(agent_type=agents.NetworkAgent)
|
||||||
|
env = Environment(name='Test', topology=G, network_agents=distro)
|
||||||
|
lst = list(env.network_agents)
|
||||||
|
|
||||||
|
a2 = env.get_agent(2)
|
||||||
|
a3 = env.get_agent(3)
|
||||||
|
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
||||||
|
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
||||||
|
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
||||||
|
assert len(a3.subgraph(agent_type=agents.NetworkAgent)) == 3
|
||||||
|
|
||||||
|
def test_templates(self):
|
||||||
|
'''Loading a template should result in several configs'''
|
||||||
|
configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
|
||||||
|
assert len(configs) > 0
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user