mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
WIP: all tests pass
This commit is contained in:
parent
f811ee18c5
commit
cd62c23cb9
@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [0.3 UNRELEASED]
|
## [0.3 UNRELEASED]
|
||||||
|
### Added
|
||||||
|
* Simple debugging capabilities, with a custom `pdb.Debugger` subclass that exposes commands to list agents and their status and set breakpoints on states (for FSM agents)
|
||||||
### Changed
|
### Changed
|
||||||
* Configuration schema is very different now. Check `soil.config` for more information. We are also using Pydantic for (de)serialization.
|
* Configuration schema is very different now. Check `soil.config` for more information. We are also using Pydantic for (de)serialization.
|
||||||
* There may be more than one topology/network in the simulation
|
* There may be more than one topology/network in the simulation
|
||||||
|
12
docs/soil-vs.rst
Normal file
12
docs/soil-vs.rst
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
### MESA
|
||||||
|
|
||||||
|
Starting with version 0.3, Soil has been redesigned to complement Mesa, while remaining compatible with it.
|
||||||
|
That means that every component in Soil (i.e., Models, Environments, etc.) can be mixed with existing mesa components.
|
||||||
|
In fact, there are examples that show how that integration may be used, in the `examples/mesa` folder in the repository.
|
||||||
|
|
||||||
|
Here are some reasons to use Soil instead of plain mesa:
|
||||||
|
|
||||||
|
- Less boilerplate for common scenarios (by some definitions of common)
|
||||||
|
- Functions to automatically populate a topology with an agent distribution (i.e., different ratios of agent class and state)
|
||||||
|
- The `soil.Simulation` class allows you to run multiple instances of the same experiment (i.e., multiple trials with the same parameters but a different randomness seed)
|
||||||
|
- Reporting functions that aggregate multiple
|
@ -1,14 +1,16 @@
|
|||||||
---
|
---
|
||||||
version: '2'
|
version: '2'
|
||||||
general:
|
name: simple
|
||||||
id: simple
|
group: tests
|
||||||
group: tests
|
dir_path: "/tmp/"
|
||||||
dir_path: "/tmp/"
|
num_trials: 3
|
||||||
num_trials: 3
|
max_steps: 100
|
||||||
max_time: 100
|
interval: 1
|
||||||
interval: 1
|
seed: "CompleteSeed!"
|
||||||
seed: "CompleteSeed!"
|
model_class: Environment
|
||||||
topologies:
|
model_params:
|
||||||
|
am_i_complete: true
|
||||||
|
topologies:
|
||||||
default:
|
default:
|
||||||
params:
|
params:
|
||||||
generator: complete_graph
|
generator: complete_graph
|
||||||
@ -17,30 +19,36 @@ topologies:
|
|||||||
params:
|
params:
|
||||||
generator: complete_graph
|
generator: complete_graph
|
||||||
n: 2
|
n: 2
|
||||||
environment:
|
environment:
|
||||||
environment_class: Environment
|
agents:
|
||||||
params:
|
|
||||||
am_i_complete: true
|
|
||||||
agents:
|
|
||||||
# Agents are split several groups, each with its own definition
|
|
||||||
default: # This is a special group. Its values will be used as default values for the rest of the groups
|
|
||||||
agent_class: CounterModel
|
agent_class: CounterModel
|
||||||
topology: default
|
topology: default
|
||||||
state:
|
state:
|
||||||
times: 1
|
times: 1
|
||||||
environment:
|
|
||||||
# In this group we are not specifying any topology
|
# In this group we are not specifying any topology
|
||||||
topology: False
|
|
||||||
fixed:
|
fixed:
|
||||||
- name: 'Environment Agent 1'
|
- name: 'Environment Agent 1'
|
||||||
agent_class: CounterModel
|
agent_class: BaseAgent
|
||||||
|
group: environment
|
||||||
|
topology: null
|
||||||
|
hidden: true
|
||||||
state:
|
state:
|
||||||
times: 10
|
times: 10
|
||||||
general_counters:
|
- agent_class: CounterModel
|
||||||
topology: default
|
id: 0
|
||||||
|
group: other_counters
|
||||||
|
topology: another_graph
|
||||||
|
state:
|
||||||
|
times: 1
|
||||||
|
total: 0
|
||||||
|
- agent_class: CounterModel
|
||||||
|
topology: another_graph
|
||||||
|
group: other_counters
|
||||||
|
id: 1
|
||||||
distribution:
|
distribution:
|
||||||
- agent_class: CounterModel
|
- agent_class: CounterModel
|
||||||
weight: 1
|
weight: 1
|
||||||
|
group: general_counters
|
||||||
state:
|
state:
|
||||||
times: 3
|
times: 3
|
||||||
- agent_class: AggregatedCounter
|
- agent_class: AggregatedCounter
|
||||||
@ -51,16 +59,3 @@ agents:
|
|||||||
n: 2
|
n: 2
|
||||||
state:
|
state:
|
||||||
times: 5
|
times: 5
|
||||||
|
|
||||||
other_counters:
|
|
||||||
topology: another_graph
|
|
||||||
fixed:
|
|
||||||
- agent_class: CounterModel
|
|
||||||
id: 0
|
|
||||||
state:
|
|
||||||
times: 1
|
|
||||||
total: 0
|
|
||||||
- agent_class: CounterModel
|
|
||||||
id: 1
|
|
||||||
# If not specified, it will use the state set in the default
|
|
||||||
# state:
|
|
||||||
|
63
examples/complete_opt2.yml
Normal file
63
examples/complete_opt2.yml
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
---
|
||||||
|
version: '2'
|
||||||
|
id: simple
|
||||||
|
group: tests
|
||||||
|
dir_path: "/tmp/"
|
||||||
|
num_trials: 3
|
||||||
|
max_steps: 100
|
||||||
|
interval: 1
|
||||||
|
seed: "CompleteSeed!"
|
||||||
|
model_class: "soil.Environment"
|
||||||
|
model_params:
|
||||||
|
topologies:
|
||||||
|
default:
|
||||||
|
params:
|
||||||
|
generator: complete_graph
|
||||||
|
n: 10
|
||||||
|
another_graph:
|
||||||
|
params:
|
||||||
|
generator: complete_graph
|
||||||
|
n: 2
|
||||||
|
agents:
|
||||||
|
# The values here will be used as default values for any agent
|
||||||
|
agent_class: CounterModel
|
||||||
|
topology: default
|
||||||
|
state:
|
||||||
|
times: 1
|
||||||
|
# This specifies a distribution of agents, each with a `weight` or an explicit number of agents
|
||||||
|
distribution:
|
||||||
|
- agent_class: CounterModel
|
||||||
|
weight: 1
|
||||||
|
# This is inherited from the default settings
|
||||||
|
#topology: default
|
||||||
|
state:
|
||||||
|
times: 3
|
||||||
|
- agent_class: AggregatedCounter
|
||||||
|
topology: default
|
||||||
|
weight: 0.2
|
||||||
|
fixed:
|
||||||
|
- name: 'Environment Agent 1'
|
||||||
|
# All the other agents will assigned to the 'default' group
|
||||||
|
group: environment
|
||||||
|
# Do not count this agent towards total limits
|
||||||
|
hidden: true
|
||||||
|
agent_class: soil.BaseAgent
|
||||||
|
topology: null
|
||||||
|
state:
|
||||||
|
times: 10
|
||||||
|
- agent_class: CounterModel
|
||||||
|
topology: another_graph
|
||||||
|
id: 0
|
||||||
|
state:
|
||||||
|
times: 1
|
||||||
|
total: 0
|
||||||
|
- agent_class: CounterModel
|
||||||
|
topology: another_graph
|
||||||
|
id: 1
|
||||||
|
override:
|
||||||
|
# 2 agents that match this filter will be updated to match the state {times: 5}
|
||||||
|
- filter:
|
||||||
|
agent_class: AggregatedCounter
|
||||||
|
n: 2
|
||||||
|
state:
|
||||||
|
times: 5
|
@ -2,7 +2,7 @@
|
|||||||
name: custom-generator
|
name: custom-generator
|
||||||
description: Using a custom generator for the network
|
description: Using a custom generator for the network
|
||||||
num_trials: 3
|
num_trials: 3
|
||||||
max_time: 100
|
max_steps: 100
|
||||||
interval: 1
|
interval: 1
|
||||||
network_params:
|
network_params:
|
||||||
generator: mymodule.mygenerator
|
generator: mymodule.mygenerator
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from networkx import Graph
|
from networkx import Graph
|
||||||
|
import random
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
def mygenerator(n=5, n_edges=5):
|
def mygenerator(n=5, n_edges=5):
|
||||||
@ -13,9 +14,9 @@ def mygenerator(n=5, n_edges=5):
|
|||||||
|
|
||||||
for i in range(n_edges):
|
for i in range(n_edges):
|
||||||
nodes = list(G.nodes)
|
nodes = list(G.nodes)
|
||||||
n_in = self.random.choice(nodes)
|
n_in = random.choice(nodes)
|
||||||
nodes.remove(n_in) # Avoid loops
|
nodes.remove(n_in) # Avoid loops
|
||||||
n_out = self.random.choice(nodes)
|
n_out = random.choice(nodes)
|
||||||
G.add_edge(n_in, n_out)
|
G.add_edge(n_in, n_out)
|
||||||
return G
|
return G
|
||||||
|
|
||||||
|
@ -3,17 +3,21 @@ name: mesa_sim
|
|||||||
group: tests
|
group: tests
|
||||||
dir_path: "/tmp"
|
dir_path: "/tmp"
|
||||||
num_trials: 3
|
num_trials: 3
|
||||||
max_time: 100
|
max_steps: 100
|
||||||
interval: 1
|
interval: 1
|
||||||
seed: '1'
|
seed: '1'
|
||||||
network_params:
|
model_class: social_wealth.MoneyEnv
|
||||||
|
model_params:
|
||||||
|
topologies:
|
||||||
|
default:
|
||||||
|
params:
|
||||||
generator: social_wealth.graph_generator
|
generator: social_wealth.graph_generator
|
||||||
n: 5
|
n: 5
|
||||||
network_agents:
|
agents:
|
||||||
|
distribution:
|
||||||
- agent_class: social_wealth.SocialMoneyAgent
|
- agent_class: social_wealth.SocialMoneyAgent
|
||||||
|
topology: default
|
||||||
weight: 1
|
weight: 1
|
||||||
environment_class: social_wealth.MoneyEnv
|
|
||||||
environment_params:
|
|
||||||
mesa_agent_class: social_wealth.MoneyAgent
|
mesa_agent_class: social_wealth.MoneyAgent
|
||||||
N: 10
|
N: 10
|
||||||
width: 50
|
width: 50
|
||||||
|
@ -5,7 +5,7 @@ environment_params:
|
|||||||
prob_neighbor_spread: 0.0
|
prob_neighbor_spread: 0.0
|
||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 300
|
max_steps: 300
|
||||||
name: Sim_all_dumb
|
name: Sim_all_dumb
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_class: newsspread.DumbViewer
|
- agent_class: newsspread.DumbViewer
|
||||||
@ -28,7 +28,7 @@ environment_params:
|
|||||||
prob_neighbor_spread: 0.0
|
prob_neighbor_spread: 0.0
|
||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 300
|
max_steps: 300
|
||||||
name: Sim_half_herd
|
name: Sim_half_herd
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_class: newsspread.DumbViewer
|
- agent_class: newsspread.DumbViewer
|
||||||
@ -59,7 +59,7 @@ environment_params:
|
|||||||
prob_neighbor_spread: 0.0
|
prob_neighbor_spread: 0.0
|
||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 300
|
max_steps: 300
|
||||||
name: Sim_all_herd
|
name: Sim_all_herd
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_class: newsspread.HerdViewer
|
- agent_class: newsspread.HerdViewer
|
||||||
@ -85,7 +85,7 @@ environment_params:
|
|||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
prob_neighbor_cure: 0.1
|
prob_neighbor_cure: 0.1
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 300
|
max_steps: 300
|
||||||
name: Sim_wise_herd
|
name: Sim_wise_herd
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_class: newsspread.HerdViewer
|
- agent_class: newsspread.HerdViewer
|
||||||
@ -110,7 +110,7 @@ environment_params:
|
|||||||
prob_tv_spread: 0.01
|
prob_tv_spread: 0.01
|
||||||
prob_neighbor_cure: 0.1
|
prob_neighbor_cure: 0.1
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 300
|
max_steps: 300
|
||||||
name: Sim_all_wise
|
name: Sim_all_wise
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_class: newsspread.WiseViewer
|
- agent_class: newsspread.WiseViewer
|
||||||
|
@ -16,13 +16,13 @@ class DumbViewer(FSM, NetworkAgent):
|
|||||||
@state
|
@state
|
||||||
def neutral(self):
|
def neutral(self):
|
||||||
if self['has_tv']:
|
if self['has_tv']:
|
||||||
if prob(self.env['prob_tv_spread']):
|
if self.prob(self.model['prob_tv_spread']):
|
||||||
return self.infected
|
return self.infected
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
||||||
if prob(self.env['prob_neighbor_spread']):
|
if self.prob(self.model['prob_neighbor_spread']):
|
||||||
neighbor.infect()
|
neighbor.infect()
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
@ -44,9 +44,9 @@ class HerdViewer(DumbViewer):
|
|||||||
'''Notice again that this is NOT a state. See DumbViewer.infect for reference'''
|
'''Notice again that this is NOT a state. See DumbViewer.infect for reference'''
|
||||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||||
total = self.count_neighboring_agents()
|
total = self.count_neighboring_agents()
|
||||||
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
prob_infect = self.model['prob_neighbor_spread'] * infected/total
|
||||||
self.debug('prob_infect', prob_infect)
|
self.debug('prob_infect', prob_infect)
|
||||||
if prob(prob_infect):
|
if self.prob(prob_infect):
|
||||||
self.set_state(self.infected)
|
self.set_state(self.infected)
|
||||||
|
|
||||||
|
|
||||||
@ -63,9 +63,9 @@ class WiseViewer(HerdViewer):
|
|||||||
|
|
||||||
@state
|
@state
|
||||||
def cured(self):
|
def cured(self):
|
||||||
prob_cure = self.env['prob_neighbor_cure']
|
prob_cure = self.model['prob_neighbor_cure']
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
||||||
if prob(prob_cure):
|
if self.prob(prob_cure):
|
||||||
try:
|
try:
|
||||||
neighbor.cure()
|
neighbor.cure()
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -80,7 +80,7 @@ class WiseViewer(HerdViewer):
|
|||||||
1.0)
|
1.0)
|
||||||
infected = max(self.count_neighboring_agents(self.infected.id),
|
infected = max(self.count_neighboring_agents(self.infected.id),
|
||||||
1.0)
|
1.0)
|
||||||
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
prob_cure = self.model['prob_neighbor_cure'] * (cured/infected)
|
||||||
if prob(prob_cure):
|
if self.prob(prob_cure):
|
||||||
return self.cured
|
return self.cured
|
||||||
return self.set_state(super().infected)
|
return self.set_state(super().infected)
|
||||||
|
@ -60,12 +60,10 @@ class Patron(FSM, NetworkAgent):
|
|||||||
'''
|
'''
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
|
|
||||||
defaults = {
|
pub = None
|
||||||
'pub': None,
|
drunk = False
|
||||||
'drunk': False,
|
pints = 0
|
||||||
'pints': 0,
|
max_pints = 3
|
||||||
'max_pints': 3,
|
|
||||||
}
|
|
||||||
|
|
||||||
@default_state
|
@default_state
|
||||||
@state
|
@state
|
||||||
@ -89,9 +87,9 @@ class Patron(FSM, NetworkAgent):
|
|||||||
return self.sober_in_pub
|
return self.sober_in_pub
|
||||||
self.debug('I am looking for a pub')
|
self.debug('I am looking for a pub')
|
||||||
group = list(self.get_neighboring_agents())
|
group = list(self.get_neighboring_agents())
|
||||||
for pub in self.env.available_pubs():
|
for pub in self.model.available_pubs():
|
||||||
self.debug('We\'re trying to get into {}: total: {}'.format(pub, len(group)))
|
self.debug('We\'re trying to get into {}: total: {}'.format(pub, len(group)))
|
||||||
if self.env.enter(pub, self, *group):
|
if self.model.enter(pub, self, *group):
|
||||||
self.info('We\'re all {} getting in {}!'.format(len(group), pub))
|
self.info('We\'re all {} getting in {}!'.format(len(group), pub))
|
||||||
return self.sober_in_pub
|
return self.sober_in_pub
|
||||||
|
|
||||||
@ -128,7 +126,7 @@ class Patron(FSM, NetworkAgent):
|
|||||||
success depend on both agents' openness.
|
success depend on both agents' openness.
|
||||||
'''
|
'''
|
||||||
if force or self['openness'] > self.random.random():
|
if force or self['openness'] > self.random.random():
|
||||||
self.env.add_edge(self, other_agent)
|
self.model.add_edge(self, other_agent)
|
||||||
self.info('Made some friend {}'.format(other_agent))
|
self.info('Made some friend {}'.format(other_agent))
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
@ -150,7 +148,7 @@ class Patron(FSM, NetworkAgent):
|
|||||||
return befriended
|
return befriended
|
||||||
|
|
||||||
|
|
||||||
class Police(FSM, NetworkAgent):
|
class Police(FSM):
|
||||||
'''Simple agent to take drunk people out of pubs.'''
|
'''Simple agent to take drunk people out of pubs.'''
|
||||||
level = logging.INFO
|
level = logging.INFO
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
---
|
---
|
||||||
name: pubcrawl
|
name: pubcrawl
|
||||||
num_trials: 3
|
num_trials: 3
|
||||||
max_time: 10
|
max_steps: 10
|
||||||
dump: false
|
dump: false
|
||||||
network_params:
|
network_params:
|
||||||
# Generate 100 empty nodes. They will be assigned a network agent
|
# Generate 100 empty nodes. They will be assigned a network agent
|
||||||
|
4
examples/rabbits/README.md
Normal file
4
examples/rabbits/README.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
There are two similar implementations of this simulation.
|
||||||
|
|
||||||
|
- `basic`. Using simple primites
|
||||||
|
- `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions.
|
130
examples/rabbits/basic/rabbit_agents.py
Normal file
130
examples/rabbits/basic/rabbit_agents.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||||
|
from soil.time import Delta
|
||||||
|
from enum import Enum
|
||||||
|
from collections import Counter
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class RabbitModel(FSM, NetworkAgent):
|
||||||
|
|
||||||
|
sexual_maturity = 30
|
||||||
|
life_expectancy = 300
|
||||||
|
|
||||||
|
@default_state
|
||||||
|
@state
|
||||||
|
def newborn(self):
|
||||||
|
self.info('I am a newborn.')
|
||||||
|
self.age = 0
|
||||||
|
self.offspring = 0
|
||||||
|
return self.youngling
|
||||||
|
|
||||||
|
@state
|
||||||
|
def youngling(self):
|
||||||
|
self.age += 1
|
||||||
|
if self.age >= self.sexual_maturity:
|
||||||
|
self.info(f'I am fertile! My age is {self.age}')
|
||||||
|
return self.fertile
|
||||||
|
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
raise Exception("Each subclass should define its fertile state")
|
||||||
|
|
||||||
|
@state
|
||||||
|
def dead(self):
|
||||||
|
self.die()
|
||||||
|
|
||||||
|
|
||||||
|
class Male(RabbitModel):
|
||||||
|
max_females = 5
|
||||||
|
mating_prob = 0.001
|
||||||
|
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
self.age += 1
|
||||||
|
|
||||||
|
if self.age > self.life_expectancy:
|
||||||
|
return self.dead
|
||||||
|
|
||||||
|
# Males try to mate
|
||||||
|
for f in self.model.agents(agent_class=Female,
|
||||||
|
state_id=Female.fertile.id,
|
||||||
|
limit=self.max_females):
|
||||||
|
self.debug('FOUND A FEMALE: ', repr(f), self.mating_prob)
|
||||||
|
if self.prob(self['mating_prob']):
|
||||||
|
f.impregnate(self)
|
||||||
|
break # Take a break
|
||||||
|
|
||||||
|
|
||||||
|
class Female(RabbitModel):
|
||||||
|
gestation = 100
|
||||||
|
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
# Just wait for a Male
|
||||||
|
self.age += 1
|
||||||
|
if self.age > self.life_expectancy:
|
||||||
|
return self.dead
|
||||||
|
|
||||||
|
def impregnate(self, male):
|
||||||
|
self.info(f'{repr(male)} impregnating female {repr(self)}')
|
||||||
|
self.mate = male
|
||||||
|
self.pregnancy = -1
|
||||||
|
self.set_state(self.pregnant, when=self.now)
|
||||||
|
self.number_of_babies = int(8+4*self.random.random())
|
||||||
|
self.debug('I am pregnant')
|
||||||
|
|
||||||
|
@state
|
||||||
|
def pregnant(self):
|
||||||
|
self.age += 1
|
||||||
|
self.pregnancy += 1
|
||||||
|
|
||||||
|
if self.prob(self.age / self.life_expectancy):
|
||||||
|
return self.die()
|
||||||
|
|
||||||
|
if self.pregnancy >= self.gestation:
|
||||||
|
self.info('Having {} babies'.format(self.number_of_babies))
|
||||||
|
for i in range(self.number_of_babies):
|
||||||
|
state = {}
|
||||||
|
agent_class = self.random.choice([Male, Female])
|
||||||
|
child = self.model.add_node(agent_class=agent_class,
|
||||||
|
topology=self.topology,
|
||||||
|
**state)
|
||||||
|
child.add_edge(self)
|
||||||
|
try:
|
||||||
|
child.add_edge(self.mate)
|
||||||
|
self.model.agents[self.mate].offspring += 1
|
||||||
|
except ValueError:
|
||||||
|
self.debug('The father has passed away')
|
||||||
|
|
||||||
|
self.offspring += 1
|
||||||
|
self.mate = None
|
||||||
|
return self.fertile
|
||||||
|
|
||||||
|
@state
|
||||||
|
def dead(self):
|
||||||
|
super().dead()
|
||||||
|
if 'pregnancy' in self and self['pregnancy'] > -1:
|
||||||
|
self.info('A mother has died carrying a baby!!')
|
||||||
|
|
||||||
|
|
||||||
|
class RandomAccident(BaseAgent):
|
||||||
|
|
||||||
|
level = logging.INFO
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
rabbits_alive = self.model.topology.number_of_nodes()
|
||||||
|
|
||||||
|
if not rabbits_alive:
|
||||||
|
return self.die()
|
||||||
|
|
||||||
|
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||||
|
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||||
|
for i in self.iter_agents(agent_class=RabbitModel):
|
||||||
|
if i.state.id == i.dead.id:
|
||||||
|
continue
|
||||||
|
if self.prob(prob_death):
|
||||||
|
self.info('I killed a rabbit: {}'.format(i.id))
|
||||||
|
rabbits_alive -= 1
|
||||||
|
i.set_state(i.dead)
|
||||||
|
self.debug('Rabbits alive: {}'.format(rabbits_alive))
|
41
examples/rabbits/basic/rabbits.yml
Normal file
41
examples/rabbits/basic/rabbits.yml
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
---
|
||||||
|
version: '2'
|
||||||
|
name: rabbits_basic
|
||||||
|
num_trials: 1
|
||||||
|
seed: MySeed
|
||||||
|
description: null
|
||||||
|
group: null
|
||||||
|
interval: 1.0
|
||||||
|
max_time: 100
|
||||||
|
model_class: soil.environment.Environment
|
||||||
|
model_params:
|
||||||
|
agents:
|
||||||
|
topology: default
|
||||||
|
agent_class: rabbit_agents.RabbitModel
|
||||||
|
distribution:
|
||||||
|
- agent_class: rabbit_agents.Male
|
||||||
|
topology: default
|
||||||
|
weight: 1
|
||||||
|
- agent_class: rabbit_agents.Female
|
||||||
|
topology: default
|
||||||
|
weight: 1
|
||||||
|
fixed:
|
||||||
|
- agent_class: rabbit_agents.RandomAccident
|
||||||
|
topology: null
|
||||||
|
hidden: true
|
||||||
|
state:
|
||||||
|
group: environment
|
||||||
|
state:
|
||||||
|
group: network
|
||||||
|
mating_prob: 0.1
|
||||||
|
prob_death: 0.001
|
||||||
|
topologies:
|
||||||
|
default:
|
||||||
|
topology:
|
||||||
|
directed: true
|
||||||
|
links: []
|
||||||
|
nodes:
|
||||||
|
- id: 1
|
||||||
|
- id: 0
|
||||||
|
extra:
|
||||||
|
visualization_params: {}
|
130
examples/rabbits/improved/rabbit_agents.py
Normal file
130
examples/rabbits/improved/rabbit_agents.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||||
|
from soil.time import Delta, When, NEVER
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class RabbitModel(FSM, NetworkAgent):
|
||||||
|
|
||||||
|
mating_prob = 0.005
|
||||||
|
offspring = 0
|
||||||
|
birth = None
|
||||||
|
|
||||||
|
sexual_maturity = 3
|
||||||
|
life_expectancy = 30
|
||||||
|
|
||||||
|
@default_state
|
||||||
|
@state
|
||||||
|
def newborn(self):
|
||||||
|
self.birth = self.now
|
||||||
|
self.info(f'I am a newborn.')
|
||||||
|
self.model['rabbits_alive'] = self.model.get('rabbits_alive', 0) + 1
|
||||||
|
|
||||||
|
# Here we can skip the `youngling` state by using a coroutine/generator.
|
||||||
|
while self.age < self.sexual_maturity:
|
||||||
|
interval = self.sexual_maturity - self.age
|
||||||
|
yield Delta(interval)
|
||||||
|
|
||||||
|
self.info(f'I am fertile! My age is {self.age}')
|
||||||
|
return self.fertile
|
||||||
|
|
||||||
|
@property
|
||||||
|
def age(self):
|
||||||
|
return self.now - self.birth
|
||||||
|
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
raise Exception("Each subclass should define its fertile state")
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
super().step()
|
||||||
|
if self.prob(self.age / self.life_expectancy):
|
||||||
|
return self.die()
|
||||||
|
|
||||||
|
|
||||||
|
class Male(RabbitModel):
|
||||||
|
|
||||||
|
max_females = 5
|
||||||
|
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
# Males try to mate
|
||||||
|
for f in self.model.agents(agent_class=Female,
|
||||||
|
state_id=Female.fertile.id,
|
||||||
|
limit=self.max_females):
|
||||||
|
self.debug('Found a female:', repr(f))
|
||||||
|
if self.prob(self['mating_prob']):
|
||||||
|
f.impregnate(self)
|
||||||
|
break # Take a break, don't try to impregnate the rest
|
||||||
|
|
||||||
|
|
||||||
|
class Female(RabbitModel):
|
||||||
|
due_date = None
|
||||||
|
age_of_pregnancy = None
|
||||||
|
gestation = 10
|
||||||
|
mate = None
|
||||||
|
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
return self.fertile, NEVER
|
||||||
|
|
||||||
|
@state
|
||||||
|
def pregnant(self):
|
||||||
|
self.info('I am pregnant')
|
||||||
|
if self.age > self.life_expectancy:
|
||||||
|
return self.dead
|
||||||
|
|
||||||
|
self.due_date = self.now + self.gestation
|
||||||
|
|
||||||
|
number_of_babies = int(8+4*self.random.random())
|
||||||
|
|
||||||
|
while self.now < self.due_date:
|
||||||
|
yield When(self.due_date)
|
||||||
|
|
||||||
|
self.info('Having {} babies'.format(number_of_babies))
|
||||||
|
for i in range(number_of_babies):
|
||||||
|
agent_class = self.random.choice([Male, Female])
|
||||||
|
child = self.model.add_node(agent_class=agent_class,
|
||||||
|
topology=self.topology)
|
||||||
|
self.model.add_edge(self, child)
|
||||||
|
self.model.add_edge(self.mate, child)
|
||||||
|
self.offspring += 1
|
||||||
|
self.model.agents[self.mate].offspring += 1
|
||||||
|
self.mate = None
|
||||||
|
self.due_date = None
|
||||||
|
return self.fertile
|
||||||
|
|
||||||
|
@state
|
||||||
|
def dead(self):
|
||||||
|
super().dead()
|
||||||
|
if self.due_date is not None:
|
||||||
|
self.info('A mother has died carrying a baby!!')
|
||||||
|
|
||||||
|
def impregnate(self, male):
|
||||||
|
self.info(f'{repr(male)} impregnating female {repr(self)}')
|
||||||
|
self.mate = male
|
||||||
|
self.set_state(self.pregnant, when=self.now)
|
||||||
|
|
||||||
|
|
||||||
|
class RandomAccident(BaseAgent):
|
||||||
|
|
||||||
|
level = logging.INFO
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
rabbits_total = self.model.topology.number_of_nodes()
|
||||||
|
if 'rabbits_alive' not in self.model:
|
||||||
|
self.model['rabbits_alive'] = 0
|
||||||
|
rabbits_alive = self.model.get('rabbits_alive', rabbits_total)
|
||||||
|
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||||
|
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||||
|
for i in self.model.network_agents:
|
||||||
|
if i.state.id == i.dead.id:
|
||||||
|
continue
|
||||||
|
if self.prob(prob_death):
|
||||||
|
self.info('I killed a rabbit: {}'.format(i.id))
|
||||||
|
rabbits_alive = self.model['rabbits_alive'] = rabbits_alive -1
|
||||||
|
i.set_state(i.dead)
|
||||||
|
self.debug('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
||||||
|
if self.model.count_agents(state_id=RabbitModel.dead.id) == self.model.topology.number_of_nodes():
|
||||||
|
self.die()
|
41
examples/rabbits/improved/rabbits.yml
Normal file
41
examples/rabbits/improved/rabbits.yml
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
---
|
||||||
|
version: '2'
|
||||||
|
name: rabbits_improved
|
||||||
|
num_trials: 1
|
||||||
|
seed: MySeed
|
||||||
|
description: null
|
||||||
|
group: null
|
||||||
|
interval: 1.0
|
||||||
|
max_time: 100
|
||||||
|
model_class: soil.environment.Environment
|
||||||
|
model_params:
|
||||||
|
agents:
|
||||||
|
topology: default
|
||||||
|
agent_class: rabbit_agents.RabbitModel
|
||||||
|
distribution:
|
||||||
|
- agent_class: rabbit_agents.Male
|
||||||
|
topology: default
|
||||||
|
weight: 1
|
||||||
|
- agent_class: rabbit_agents.Female
|
||||||
|
topology: default
|
||||||
|
weight: 1
|
||||||
|
fixed:
|
||||||
|
- agent_class: rabbit_agents.RandomAccident
|
||||||
|
topology: null
|
||||||
|
hidden: true
|
||||||
|
state:
|
||||||
|
group: environment
|
||||||
|
state:
|
||||||
|
group: network
|
||||||
|
mating_prob: 0.1
|
||||||
|
prob_death: 0.001
|
||||||
|
topologies:
|
||||||
|
default:
|
||||||
|
topology:
|
||||||
|
directed: true
|
||||||
|
links: []
|
||||||
|
nodes:
|
||||||
|
- id: 1
|
||||||
|
- id: 0
|
||||||
|
extra:
|
||||||
|
visualization_params: {}
|
@ -1,133 +0,0 @@
|
|||||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
|
||||||
from enum import Enum
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
class Genders(Enum):
|
|
||||||
male = 'male'
|
|
||||||
female = 'female'
|
|
||||||
|
|
||||||
|
|
||||||
class RabbitModel(FSM, NetworkAgent):
|
|
||||||
|
|
||||||
defaults = {
|
|
||||||
'age': 0,
|
|
||||||
'gender': Genders.male.value,
|
|
||||||
'mating_prob': 0.001,
|
|
||||||
'offspring': 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
sexual_maturity = 3 #4*30
|
|
||||||
life_expectancy = 365 * 3
|
|
||||||
gestation = 33
|
|
||||||
pregnancy = -1
|
|
||||||
max_females = 5
|
|
||||||
|
|
||||||
@default_state
|
|
||||||
@state
|
|
||||||
def newborn(self):
|
|
||||||
self.debug(f'I am a newborn at age {self["age"]}')
|
|
||||||
self['age'] += 1
|
|
||||||
|
|
||||||
if self['age'] >= self.sexual_maturity:
|
|
||||||
self.debug('I am fertile!')
|
|
||||||
return self.fertile
|
|
||||||
@state
|
|
||||||
def fertile(self):
|
|
||||||
raise Exception("Each subclass should define its fertile state")
|
|
||||||
|
|
||||||
@state
|
|
||||||
def dead(self):
|
|
||||||
self.info('Agent {} is dying'.format(self.id))
|
|
||||||
self.die()
|
|
||||||
|
|
||||||
|
|
||||||
class Male(RabbitModel):
|
|
||||||
|
|
||||||
@state
|
|
||||||
def fertile(self):
|
|
||||||
self['age'] += 1
|
|
||||||
if self['age'] > self.life_expectancy:
|
|
||||||
return self.dead
|
|
||||||
|
|
||||||
if self['gender'] == Genders.female.value:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Males try to mate
|
|
||||||
for f in self.get_agents(state_id=Female.fertile.id,
|
|
||||||
agent_class=Female,
|
|
||||||
limit_neighbors=False,
|
|
||||||
limit=self.max_females):
|
|
||||||
r = self.random.random()
|
|
||||||
if r < self['mating_prob']:
|
|
||||||
self.impregnate(f)
|
|
||||||
break # Take a break
|
|
||||||
def impregnate(self, whom):
|
|
||||||
whom['pregnancy'] = 0
|
|
||||||
whom['mate'] = self.id
|
|
||||||
whom.set_state(whom.pregnant)
|
|
||||||
self.debug('{} impregnating: {}. {}'.format(self.id, whom.id, whom.state))
|
|
||||||
|
|
||||||
class Female(RabbitModel):
|
|
||||||
@state
|
|
||||||
def fertile(self):
|
|
||||||
# Just wait for a Male
|
|
||||||
pass
|
|
||||||
|
|
||||||
@state
|
|
||||||
def pregnant(self):
|
|
||||||
self['age'] += 1
|
|
||||||
if self['age'] > self.life_expectancy:
|
|
||||||
return self.dead
|
|
||||||
|
|
||||||
self['pregnancy'] += 1
|
|
||||||
self.debug('Pregnancy: {}'.format(self['pregnancy']))
|
|
||||||
if self['pregnancy'] >= self.gestation:
|
|
||||||
number_of_babies = int(8+4*self.random.random())
|
|
||||||
self.info('Having {} babies'.format(number_of_babies))
|
|
||||||
for i in range(number_of_babies):
|
|
||||||
state = {}
|
|
||||||
state['gender'] = self.random.choice(list(Genders)).value
|
|
||||||
child = self.env.add_node(self.__class__, state)
|
|
||||||
self.env.add_edge(self.id, child.id)
|
|
||||||
self.env.add_edge(self['mate'], child.id)
|
|
||||||
# self.add_edge()
|
|
||||||
self.debug('A BABY IS COMING TO LIFE')
|
|
||||||
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1
|
|
||||||
self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
|
||||||
self['offspring'] += 1
|
|
||||||
self.env.get_agent(self['mate'])['offspring'] += 1
|
|
||||||
del self['mate']
|
|
||||||
self['pregnancy'] = -1
|
|
||||||
return self.fertile
|
|
||||||
|
|
||||||
@state
|
|
||||||
def dead(self):
|
|
||||||
super().dead()
|
|
||||||
if 'pregnancy' in self and self['pregnancy'] > -1:
|
|
||||||
self.info('A mother has died carrying a baby!!')
|
|
||||||
|
|
||||||
|
|
||||||
class RandomAccident(BaseAgent):
|
|
||||||
|
|
||||||
level = logging.DEBUG
|
|
||||||
|
|
||||||
def step(self):
|
|
||||||
rabbits_total = self.env.topology.number_of_nodes()
|
|
||||||
if 'rabbits_alive' not in self.env:
|
|
||||||
self.env['rabbits_alive'] = 0
|
|
||||||
rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
|
|
||||||
prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
|
||||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
|
||||||
for i in self.env.network_agents:
|
|
||||||
if i.state['id'] == i.dead.id:
|
|
||||||
continue
|
|
||||||
if self.prob(prob_death):
|
|
||||||
self.debug('I killed a rabbit: {}'.format(i.id))
|
|
||||||
rabbits_alive = self.env['rabbits_alive'] = rabbits_alive -1
|
|
||||||
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
|
||||||
i.set_state(i.dead)
|
|
||||||
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
|
||||||
if self.env.count_agents(state_id=RabbitModel.dead.id) == self.env.topology.number_of_nodes():
|
|
||||||
self.die()
|
|
@ -1,20 +0,0 @@
|
|||||||
---
|
|
||||||
name: rabbits_example
|
|
||||||
max_time: 100
|
|
||||||
interval: 1
|
|
||||||
seed: MySeed
|
|
||||||
agent_class: rabbit_agents.RabbitModel
|
|
||||||
environment_agents:
|
|
||||||
- agent_class: rabbit_agents.RandomAccident
|
|
||||||
environment_params:
|
|
||||||
prob_death: 0.001
|
|
||||||
default_state:
|
|
||||||
mating_prob: 0.1
|
|
||||||
topology:
|
|
||||||
nodes:
|
|
||||||
- id: 1
|
|
||||||
agent_class: rabbit_agents.Male
|
|
||||||
- id: 0
|
|
||||||
agent_class: rabbit_agents.Female
|
|
||||||
directed: true
|
|
||||||
links: []
|
|
@ -6,9 +6,10 @@ template:
|
|||||||
group: simple
|
group: simple
|
||||||
num_trials: 1
|
num_trials: 1
|
||||||
interval: 1
|
interval: 1
|
||||||
max_time: 2
|
max_steps: 2
|
||||||
seed: "CompleteSeed!"
|
seed: "CompleteSeed!"
|
||||||
dump: false
|
dump: false
|
||||||
|
model_params:
|
||||||
network_params:
|
network_params:
|
||||||
generator: complete_graph
|
generator: complete_graph
|
||||||
n: 10
|
n: 10
|
||||||
@ -19,7 +20,6 @@ template:
|
|||||||
state_id: 0
|
state_id: 0
|
||||||
- agent_class: AggregatedCounter
|
- agent_class: AggregatedCounter
|
||||||
weight: "{{ 1 - x1 }}"
|
weight: "{{ 1 - x1 }}"
|
||||||
environment_params:
|
|
||||||
name: "{{ x3 }}"
|
name: "{{ x3 }}"
|
||||||
skip_test: true
|
skip_test: true
|
||||||
vars:
|
vars:
|
||||||
|
@ -81,6 +81,26 @@ class TerroristSpreadModel(FSM, Geo):
|
|||||||
return
|
return
|
||||||
return self.leader
|
return self.leader
|
||||||
|
|
||||||
|
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||||
|
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||||
|
node = as_node(node if node is not None else self)
|
||||||
|
G = self.subgraph(**kwargs)
|
||||||
|
return nx.ego_graph(G, node, center=center, radius=steps).nodes()
|
||||||
|
|
||||||
|
def degree(self, node, force=False):
|
||||||
|
node = as_node(node)
|
||||||
|
if force or (not hasattr(self.model, '_degree')) or getattr(self.model, '_last_step', 0) < self.now:
|
||||||
|
self.model._degree = nx.degree_centrality(self.G)
|
||||||
|
self.model._last_step = self.now
|
||||||
|
return self.model._degree[node]
|
||||||
|
|
||||||
|
def betweenness(self, node, force=False):
|
||||||
|
node = as_node(node)
|
||||||
|
if force or (not hasattr(self.model, '_betweenness')) or getattr(self.model, '_last_step', 0) < self.now:
|
||||||
|
self.model._betweenness = nx.betweenness_centrality(self.G)
|
||||||
|
self.model._last_step = self.now
|
||||||
|
return self.model._betweenness[node]
|
||||||
|
|
||||||
|
|
||||||
class TrainingAreaModel(FSM, Geo):
|
class TrainingAreaModel(FSM, Geo):
|
||||||
"""
|
"""
|
||||||
@ -194,14 +214,14 @@ class TerroristNetworkModel(TerroristSpreadModel):
|
|||||||
break
|
break
|
||||||
|
|
||||||
def get_distance(self, target):
|
def get_distance(self, target):
|
||||||
source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id]
|
source_x, source_y = nx.get_node_attributes(self.G, 'pos')[self.id]
|
||||||
target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target]
|
target_x, target_y = nx.get_node_attributes(self.G, 'pos')[target]
|
||||||
dx = abs( source_x - target_x )
|
dx = abs( source_x - target_x )
|
||||||
dy = abs( source_y - target_y )
|
dy = abs( source_y - target_y )
|
||||||
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
||||||
|
|
||||||
def shortest_path_length(self, target):
|
def shortest_path_length(self, target):
|
||||||
try:
|
try:
|
||||||
return nx.shortest_path_length(self.topology, self.id, target)
|
return nx.shortest_path_length(self.G, self.id, target)
|
||||||
except nx.NetworkXNoPath:
|
except nx.NetworkXNoPath:
|
||||||
return float('inf')
|
return float('inf')
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
name: TerroristNetworkModel_sim
|
name: TerroristNetworkModel_sim
|
||||||
max_time: 150
|
max_steps: 150
|
||||||
num_trials: 1
|
num_trials: 1
|
||||||
network_params:
|
model_params:
|
||||||
|
network_params:
|
||||||
generator: random_geometric_graph
|
generator: random_geometric_graph
|
||||||
radius: 0.2
|
radius: 0.2
|
||||||
# generator: geographical_threshold_graph
|
# generator: geographical_threshold_graph
|
||||||
# theta: 20
|
# theta: 20
|
||||||
n: 100
|
n: 100
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_class: TerroristNetworkModel.TerroristNetworkModel
|
- agent_class: TerroristNetworkModel.TerroristNetworkModel
|
||||||
weight: 0.8
|
weight: 0.8
|
||||||
state:
|
state:
|
||||||
@ -25,7 +26,6 @@ network_agents:
|
|||||||
state:
|
state:
|
||||||
id: civilian # Civilian
|
id: civilian # Civilian
|
||||||
|
|
||||||
environment_params:
|
|
||||||
# TerroristSpreadModel
|
# TerroristSpreadModel
|
||||||
information_spread_intensity: 0.7
|
information_spread_intensity: 0.7
|
||||||
terrorist_additional_influence: 0.035
|
terrorist_additional_influence: 0.035
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
---
|
---
|
||||||
name: torvalds_example
|
name: torvalds_example
|
||||||
max_time: 10
|
max_steps: 10
|
||||||
interval: 2
|
interval: 2
|
||||||
agent_class: CounterModel
|
model_params:
|
||||||
default_state:
|
agent_class: CounterModel
|
||||||
|
default_state:
|
||||||
skill_level: 'beginner'
|
skill_level: 'beginner'
|
||||||
network_params:
|
network_params:
|
||||||
path: 'torvalds.edgelist'
|
path: 'torvalds.edgelist'
|
||||||
states:
|
states:
|
||||||
Torvalds:
|
Torvalds:
|
||||||
skill_level: 'God'
|
skill_level: 'God'
|
||||||
balkian:
|
balkian:
|
||||||
|
@ -2,8 +2,9 @@ networkx>=2.5
|
|||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
pyyaml>=5.1
|
pyyaml>=5.1
|
||||||
pandas>=0.23
|
pandas>=1
|
||||||
SALib>=1.3
|
SALib>=1.3
|
||||||
Jinja2
|
Jinja2
|
||||||
Mesa>=0.8.9
|
Mesa>=1
|
||||||
pydantic>=1.9
|
pydantic>=1.9
|
||||||
|
sqlalchemy>=1.4
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import pdb
|
|
||||||
import logging
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
from .version import __version__
|
from .version import __version__
|
||||||
|
|
||||||
@ -16,11 +18,10 @@ from . import agents
|
|||||||
from .simulation import *
|
from .simulation import *
|
||||||
from .environment import Environment
|
from .environment import Environment
|
||||||
from . import serialization
|
from . import serialization
|
||||||
from . import analysis
|
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
from .time import *
|
from .time import *
|
||||||
|
|
||||||
def main():
|
def main(cfg='simulation.yml', **kwargs):
|
||||||
import argparse
|
import argparse
|
||||||
from . import simulation
|
from . import simulation
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ def main():
|
|||||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
||||||
parser.add_argument('file', type=str,
|
parser.add_argument('file', type=str,
|
||||||
nargs="?",
|
nargs="?",
|
||||||
default='simulation.yml',
|
default=cfg,
|
||||||
help='Configuration file for the simulation (e.g., YAML or JSON)')
|
help='Configuration file for the simulation (e.g., YAML or JSON)')
|
||||||
parser.add_argument('--version', action='store_true',
|
parser.add_argument('--version', action='store_true',
|
||||||
help='Show version info and exit')
|
help='Show version info and exit')
|
||||||
@ -39,6 +40,8 @@ def main():
|
|||||||
help='Do not store the results of the simulation to disk, show in terminal instead.')
|
help='Do not store the results of the simulation to disk, show in terminal instead.')
|
||||||
parser.add_argument('--pdb', action='store_true',
|
parser.add_argument('--pdb', action='store_true',
|
||||||
help='Use a pdb console in case of exception.')
|
help='Use a pdb console in case of exception.')
|
||||||
|
parser.add_argument('--debug', action='store_true',
|
||||||
|
help='Run a customized version of a pdb console to debug a simulation.')
|
||||||
parser.add_argument('--graph', '-g', action='store_true',
|
parser.add_argument('--graph', '-g', action='store_true',
|
||||||
help='Dump each trial\'s network topology as a GEXF graph. Defaults to false.')
|
help='Dump each trial\'s network topology as a GEXF graph. Defaults to false.')
|
||||||
parser.add_argument('--csv', action='store_true',
|
parser.add_argument('--csv', action='store_true',
|
||||||
@ -51,9 +54,22 @@ def main():
|
|||||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||||
parser.add_argument('-e', '--exporter', action='append',
|
parser.add_argument('-e', '--exporter', action='append',
|
||||||
help='Export environment and/or simulations using this exporter')
|
help='Export environment and/or simulations using this exporter')
|
||||||
|
parser.add_argument('--only-convert', '--convert', action='store_true',
|
||||||
|
help='Do not run the simulation, only convert the configuration file(s) and output them.')
|
||||||
|
|
||||||
|
|
||||||
|
parser.add_argument("--set",
|
||||||
|
metavar="KEY=VALUE",
|
||||||
|
action='append',
|
||||||
|
help="Set a number of parameters that will be passed to the simulation."
|
||||||
|
"(do not put spaces before or after the = sign). "
|
||||||
|
"If a value contains spaces, you should define "
|
||||||
|
"it with double quotes: "
|
||||||
|
'foo="this is a sentence". Note that '
|
||||||
|
"values are always treated as strings.")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper()))
|
logger.setLevel(getattr(logging, (args.level or 'INFO').upper()))
|
||||||
|
|
||||||
if args.version:
|
if args.version:
|
||||||
return
|
return
|
||||||
@ -65,9 +81,10 @@ def main():
|
|||||||
|
|
||||||
logger.info('Loading config file: {}'.format(args.file))
|
logger.info('Loading config file: {}'.format(args.file))
|
||||||
|
|
||||||
if args.pdb:
|
if args.pdb or args.debug:
|
||||||
args.synchronous = True
|
args.synchronous = True
|
||||||
|
if args.debug:
|
||||||
|
os.environ['SOIL_DEBUG'] = 'true'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
exporters = list(args.exporter or ['default', ])
|
exporters = list(args.exporter or ['default', ])
|
||||||
@ -82,18 +99,48 @@ def main():
|
|||||||
if not os.path.exists(args.file):
|
if not os.path.exists(args.file):
|
||||||
logger.error('Please, input a valid file')
|
logger.error('Please, input a valid file')
|
||||||
return
|
return
|
||||||
simulation.run_from_config(args.file,
|
for sim in simulation.iter_from_config(args.file):
|
||||||
dry_run=args.dry_run,
|
if args.set:
|
||||||
|
for s in args.set:
|
||||||
|
k, v = s.split('=', 1)[:2]
|
||||||
|
v = eval(v)
|
||||||
|
tail, *head = k.rsplit('.', 1)[::-1]
|
||||||
|
target = sim
|
||||||
|
if head:
|
||||||
|
for part in head[0].split('.'):
|
||||||
|
try:
|
||||||
|
target = getattr(target, part)
|
||||||
|
except AttributeError:
|
||||||
|
target = target[part]
|
||||||
|
try:
|
||||||
|
setattr(target, tail, v)
|
||||||
|
except AttributeError:
|
||||||
|
target[tail] = v
|
||||||
|
|
||||||
|
if args.only_convert:
|
||||||
|
print(sim.to_yaml())
|
||||||
|
continue
|
||||||
|
|
||||||
|
sim.run_simulation(dry_run=args.dry_run,
|
||||||
exporters=exporters,
|
exporters=exporters,
|
||||||
parallel=(not args.synchronous),
|
parallel=(not args.synchronous),
|
||||||
outdir=args.output,
|
outdir=args.output,
|
||||||
exporter_params=exp_params)
|
exporter_params=exp_params,
|
||||||
except Exception:
|
**kwargs)
|
||||||
|
|
||||||
|
except Exception as ex:
|
||||||
if args.pdb:
|
if args.pdb:
|
||||||
pdb.post_mortem()
|
from .debugging import post_mortem
|
||||||
|
print(traceback.format_exc())
|
||||||
|
post_mortem()
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def easy(cfg, debug=False):
|
||||||
|
sim = simulation.from_config(cfg)
|
||||||
|
if debug or os.environ.get('SOIL_DEBUG'):
|
||||||
|
from .debugging import setup
|
||||||
|
setup(sys._getframe().f_back)
|
||||||
|
return sim
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
@ -7,15 +7,13 @@ class CounterModel(NetworkAgent):
|
|||||||
in each step and adds it to its state.
|
in each step and adds it to its state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
defaults = {
|
times = 0
|
||||||
'times': 0,
|
neighbors = 0
|
||||||
'neighbors': 0,
|
total = 0
|
||||||
'total': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
total = len(list(self.env.agents))
|
total = len(list(self.model.schedule._agents))
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self['times'] = self.get('times', 0) + 1
|
self['times'] = self.get('times', 0) + 1
|
||||||
self['neighbors'] = neighbors
|
self['neighbors'] = neighbors
|
||||||
@ -28,17 +26,15 @@ class AggregatedCounter(NetworkAgent):
|
|||||||
in each step and adds it to its state.
|
in each step and adds it to its state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
defaults = {
|
times = 0
|
||||||
'times': 0,
|
neighbors = 0
|
||||||
'neighbors': 0,
|
total = 0
|
||||||
'total': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
self['times'] += 1
|
self['times'] += 1
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self['neighbors'] += neighbors
|
self['neighbors'] += neighbors
|
||||||
total = len(list(self.env.agents))
|
total = len(list(self.model.schedule.agents))
|
||||||
self['total'] += total
|
self['total'] += total
|
||||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from collections.abc import MutableMapping, Mapping, Set
|
from collections.abc import MutableMapping, Mapping, Set
|
||||||
@ -5,9 +7,13 @@ from abc import ABCMeta
|
|||||||
from copy import deepcopy, copy
|
from copy import deepcopy, copy
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from itertools import islice, chain
|
from itertools import islice, chain
|
||||||
import json
|
import inspect
|
||||||
|
import types
|
||||||
|
import textwrap
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from mesa import Agent as MesaAgent
|
from mesa import Agent as MesaAgent
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
@ -27,7 +33,31 @@ class DeadAgent(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(MesaAgent, MutableMapping):
|
class MetaAgent(ABCMeta):
|
||||||
|
def __new__(mcls, name, bases, namespace):
|
||||||
|
defaults = {}
|
||||||
|
|
||||||
|
# Re-use defaults from inherited classes
|
||||||
|
for i in bases:
|
||||||
|
if isinstance(i, MetaAgent):
|
||||||
|
defaults.update(i._defaults)
|
||||||
|
|
||||||
|
new_nmspc = {
|
||||||
|
'_defaults': defaults,
|
||||||
|
}
|
||||||
|
|
||||||
|
for attr, func in namespace.items():
|
||||||
|
if isinstance(func, types.FunctionType) or isinstance(func, property) or attr[0] == '_':
|
||||||
|
new_nmspc[attr] = func
|
||||||
|
elif attr == 'defaults':
|
||||||
|
defaults.update(func)
|
||||||
|
else:
|
||||||
|
defaults[attr] = copy(func)
|
||||||
|
|
||||||
|
return super().__new__(mcls=mcls, name=name, bases=bases, namespace=new_nmspc)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||||
"""
|
"""
|
||||||
A special type of Mesa Agent that:
|
A special type of Mesa Agent that:
|
||||||
|
|
||||||
@ -39,15 +69,12 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
defaults = {}
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
unique_id,
|
unique_id,
|
||||||
model,
|
model,
|
||||||
name=None,
|
name=None,
|
||||||
interval=None,
|
interval=None,
|
||||||
**kwargs
|
**kwargs):
|
||||||
):
|
|
||||||
# Check for REQUIRED arguments
|
# Check for REQUIRED arguments
|
||||||
# Initialize agent parameters
|
# Initialize agent parameters
|
||||||
if isinstance(unique_id, MesaAgent):
|
if isinstance(unique_id, MesaAgent):
|
||||||
@ -58,15 +85,16 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||||
|
|
||||||
|
|
||||||
self._neighbors = None
|
|
||||||
self.alive = True
|
self.alive = True
|
||||||
|
|
||||||
self.interval = interval or self.get('interval', 1)
|
self.interval = interval or self.get('interval', 1)
|
||||||
self.logger = logging.getLogger(self.model.id).getChild(self.name)
|
logger = utils.logger.getChild(getattr(self.model, 'id', self.model)).getChild(self.name)
|
||||||
|
self.logger = logging.LoggerAdapter(logger, {'agent_name': self.name})
|
||||||
|
|
||||||
if hasattr(self, 'level'):
|
if hasattr(self, 'level'):
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
for (k, v) in self.defaults.items():
|
|
||||||
|
for (k, v) in self._defaults.items():
|
||||||
if not hasattr(self, k) or getattr(self, k) is None:
|
if not hasattr(self, k) or getattr(self, k) is None:
|
||||||
setattr(self, k, deepcopy(v))
|
setattr(self, k, deepcopy(v))
|
||||||
|
|
||||||
@ -74,10 +102,6 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
|
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
|
|
||||||
for (k, v) in getattr(self, 'defaults', {}).items():
|
|
||||||
if not hasattr(self, k) or getattr(self, k) is None:
|
|
||||||
setattr(self, k, v)
|
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.unique_id)
|
return hash(self.unique_id)
|
||||||
|
|
||||||
@ -89,14 +113,6 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
def id(self):
|
def id(self):
|
||||||
return self.unique_id
|
return self.unique_id
|
||||||
|
|
||||||
@property
|
|
||||||
def env(self):
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
@env.setter
|
|
||||||
def env(self, model):
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self):
|
def state(self):
|
||||||
'''
|
'''
|
||||||
@ -108,19 +124,16 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
|
|
||||||
@state.setter
|
@state.setter
|
||||||
def state(self, value):
|
def state(self, value):
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
self[k] = v
|
self[k] = v
|
||||||
|
|
||||||
@property
|
|
||||||
def environment_params(self):
|
|
||||||
return self.model.environment_params
|
|
||||||
|
|
||||||
@environment_params.setter
|
|
||||||
def environment_params(self, value):
|
|
||||||
self.model.environment_params = value
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
|
try:
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
except AttributeError:
|
||||||
|
raise KeyError(f'key {key} not found in agent')
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
return delattr(self, key)
|
return delattr(self, key)
|
||||||
@ -138,10 +151,14 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
return self.items()
|
return self.items()
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return (k for k in self.__dict__ if k[0] != '_')
|
return (k for k in self.__dict__ if k[0] != '_' and k not in IGNORED_FIELDS)
|
||||||
|
|
||||||
def items(self):
|
def items(self, keys=None, skip=None):
|
||||||
return ((k, v) for (k, v) in self.__dict__.items() if k[0] != '_')
|
keys = keys if keys is not None else self.keys()
|
||||||
|
it = ((k, self.get(k, None)) for k in keys)
|
||||||
|
if skip:
|
||||||
|
return filter(lambda x: x[0] not in skip, it)
|
||||||
|
return it
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
return self[key] if key in self else default
|
return self[key] if key in self else default
|
||||||
@ -154,11 +171,9 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
# No environment
|
# No environment
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def die(self, remove=False):
|
def die(self):
|
||||||
self.info(f'agent {self.unique_id} is dying')
|
self.info(f'agent dying')
|
||||||
self.alive = False
|
self.alive = False
|
||||||
if remove:
|
|
||||||
self.remove_node(self.id)
|
|
||||||
return time.NEVER
|
return time.NEVER
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
@ -170,7 +185,7 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
if not self.logger.isEnabledFor(level):
|
if not self.logger.isEnabledFor(level):
|
||||||
return
|
return
|
||||||
message = message + " ".join(str(i) for i in args)
|
message = message + " ".join(str(i) for i in args)
|
||||||
message = " @{:>3}: {}".format(self.now, message)
|
message = "[@{:>4}]\t{:>10}: {}".format(self.now, repr(self), message)
|
||||||
for k, v in kwargs:
|
for k, v in kwargs:
|
||||||
message += " {k}={v} ".format(k, v)
|
message += " {k}={v} ".format(k, v)
|
||||||
extra = {}
|
extra = {}
|
||||||
@ -179,33 +194,48 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
extra['agent_name'] = self.name
|
extra['agent_name'] = self.name
|
||||||
return self.logger.log(level, message, extra=extra)
|
return self.logger.log(level, message, extra=extra)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def debug(self, *args, **kwargs):
|
def debug(self, *args, **kwargs):
|
||||||
return self.log(*args, level=logging.DEBUG, **kwargs)
|
return self.log(*args, level=logging.DEBUG, **kwargs)
|
||||||
|
|
||||||
def info(self, *args, **kwargs):
|
def info(self, *args, **kwargs):
|
||||||
return self.log(*args, level=logging.INFO, **kwargs)
|
return self.log(*args, level=logging.INFO, **kwargs)
|
||||||
|
|
||||||
# Alias
|
def count_agents(self, **kwargs):
|
||||||
# Agent = BaseAgent
|
return len(list(self.get_agents(**kwargs)))
|
||||||
|
|
||||||
|
def get_agents(self, *args, **kwargs):
|
||||||
|
it = self.iter_agents(*args, **kwargs)
|
||||||
|
return list(it)
|
||||||
|
|
||||||
|
def iter_agents(self, *args, **kwargs):
|
||||||
|
yield from filter_agents(self.model.schedule._agents, *args, **kwargs)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.to_str()
|
||||||
|
|
||||||
|
def to_str(self, keys=None, skip=None, pretty=False):
|
||||||
|
content = dict(self.items(keys=keys))
|
||||||
|
if pretty and content:
|
||||||
|
d = content
|
||||||
|
content = '\n'
|
||||||
|
for k, v in d.items():
|
||||||
|
content += f'- {k}: {v}\n'
|
||||||
|
content = textwrap.indent(content, ' ')
|
||||||
|
return f"{repr(self)}{content}"
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}({self.unique_id})"
|
||||||
|
|
||||||
|
|
||||||
class NetworkAgent(BaseAgent):
|
class NetworkAgent(BaseAgent):
|
||||||
|
|
||||||
@property
|
def __init__(self, *args, topology, node_id, **kwargs):
|
||||||
def topology(self):
|
super().__init__(*args, **kwargs)
|
||||||
return self.env.topology_for(self.unique_id)
|
|
||||||
|
|
||||||
@property
|
self.topology = topology
|
||||||
def node_id(self):
|
self.node_id = node_id
|
||||||
return self.env.node_id_for(self.unique_id)
|
self.G = self.model.topologies[topology]
|
||||||
|
assert self.G
|
||||||
@property
|
|
||||||
def G(self):
|
|
||||||
return self.model.topologies[self._topology]
|
|
||||||
|
|
||||||
def count_agents(self, **kwargs):
|
|
||||||
return len(list(self.get_agents(**kwargs)))
|
|
||||||
|
|
||||||
def count_neighboring_agents(self, state_id=None, **kwargs):
|
def count_neighboring_agents(self, state_id=None, **kwargs):
|
||||||
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
|
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
|
||||||
@ -213,57 +243,47 @@ class NetworkAgent(BaseAgent):
|
|||||||
def get_neighboring_agents(self, state_id=None, **kwargs):
|
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||||
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||||
|
|
||||||
def get_agents(self, *args, limit=None, **kwargs):
|
|
||||||
it = self.iter_agents(*args, **kwargs)
|
|
||||||
if limit is not None:
|
|
||||||
it = islice(it, limit)
|
|
||||||
return list(it)
|
|
||||||
|
|
||||||
def iter_agents(self, unique_id=None, limit_neighbors=False, **kwargs):
|
def iter_agents(self, unique_id=None, limit_neighbors=False, **kwargs):
|
||||||
|
unique_ids = None
|
||||||
|
if isinstance(unique_id, list):
|
||||||
|
unique_ids = set(unique_id)
|
||||||
|
elif unique_id is not None:
|
||||||
|
unique_ids = set([unique_id,])
|
||||||
|
|
||||||
if limit_neighbors:
|
if limit_neighbors:
|
||||||
unique_id = [self.topology.nodes[node]['agent_id'] for node in self.topology.neighbors(self.node_id)]
|
neighbor_ids = set()
|
||||||
if not unique_id:
|
for node_id in self.G.neighbors(self.node_id):
|
||||||
|
if self.G.nodes[node_id].get('agent_id') is not None:
|
||||||
|
neighbor_ids.add(node_id)
|
||||||
|
if unique_ids:
|
||||||
|
unique_ids = unique_ids & neighbor_ids
|
||||||
|
else:
|
||||||
|
unique_ids = neighbor_ids
|
||||||
|
if not unique_ids:
|
||||||
return
|
return
|
||||||
|
unique_ids = list(unique_ids)
|
||||||
yield from self.model.agents(unique_id=unique_id, **kwargs)
|
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def subgraph(self, center=True, **kwargs):
|
def subgraph(self, center=True, **kwargs):
|
||||||
include = [self] if center else []
|
include = [self] if center else []
|
||||||
G = self.topology.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
G = self.G.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
||||||
return G
|
return G
|
||||||
|
|
||||||
def remove_node(self, unique_id):
|
def remove_node(self):
|
||||||
self.topology.remove_node(unique_id)
|
self.G.remove_node(self.node_id)
|
||||||
|
|
||||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||||
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
if self.node_id not in self.G.nodes(data=False):
|
||||||
if self.unique_id not in self.topology.nodes(data=False):
|
|
||||||
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
||||||
if other.unique_id not in self.topology.nodes(data=False):
|
if other.node_id not in self.G.nodes(data=False):
|
||||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
||||||
|
|
||||||
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
self.G.add_edge(self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||||
|
|
||||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
def die(self, remove=True):
|
||||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
if remove:
|
||||||
node = as_node(node if node is not None else self)
|
self.remove_node()
|
||||||
G = self.subgraph(**kwargs)
|
return super().die()
|
||||||
return nx.ego_graph(G, node, center=center, radius=steps).nodes()
|
|
||||||
|
|
||||||
def degree(self, node, force=False):
|
|
||||||
node = as_node(node)
|
|
||||||
if force or (not hasattr(self.model, '_degree')) or getattr(self.model, '_last_step', 0) < self.now:
|
|
||||||
self.model._degree = nx.degree_centrality(self.topology)
|
|
||||||
self.model._last_step = self.now
|
|
||||||
return self.model._degree[node]
|
|
||||||
|
|
||||||
def betweenness(self, node, force=False):
|
|
||||||
node = as_node(node)
|
|
||||||
if force or (not hasattr(self.model, '_betweenness')) or getattr(self.model, '_last_step', 0) < self.now:
|
|
||||||
self.model._betweenness = nx.betweenness_centrality(self.topology)
|
|
||||||
self.model._last_step = self.now
|
|
||||||
return self.model._betweenness[node]
|
|
||||||
|
|
||||||
|
|
||||||
def state(name=None):
|
def state(name=None):
|
||||||
@ -273,24 +293,29 @@ def state(name=None):
|
|||||||
The default value for state_id is the current state id.
|
The default value for state_id is the current state id.
|
||||||
The default value for when is the interval defined in the environment.
|
The default value for when is the interval defined in the environment.
|
||||||
'''
|
'''
|
||||||
|
if inspect.isgeneratorfunction(func):
|
||||||
|
orig_func = func
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def func_wrapper(self):
|
def func(self):
|
||||||
next_state = func(self)
|
while True:
|
||||||
when = None
|
if not self._coroutine:
|
||||||
if next_state is None:
|
self._coroutine = orig_func(self)
|
||||||
return when
|
|
||||||
try:
|
try:
|
||||||
next_state, when = next_state
|
n = next(self._coroutine)
|
||||||
except (ValueError, TypeError):
|
if n:
|
||||||
pass
|
return None, n
|
||||||
if next_state:
|
return
|
||||||
|
except StopIteration as ex:
|
||||||
|
self._coroutine = None
|
||||||
|
next_state = ex.value
|
||||||
|
if next_state is not None:
|
||||||
self.set_state(next_state)
|
self.set_state(next_state)
|
||||||
return when
|
return next_state
|
||||||
|
|
||||||
func_wrapper.id = name or func.__name__
|
func.id = name or func.__name__
|
||||||
func_wrapper.is_default = False
|
func.is_default = False
|
||||||
return func_wrapper
|
return func
|
||||||
|
|
||||||
if callable(name):
|
if callable(name):
|
||||||
return decorator(name)
|
return decorator(name)
|
||||||
@ -303,60 +328,84 @@ def default_state(func):
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
class MetaFSM(ABCMeta):
|
class MetaFSM(MetaAgent):
|
||||||
def __init__(cls, name, bases, nmspc):
|
def __new__(mcls, name, bases, namespace):
|
||||||
super(MetaFSM, cls).__init__(name, bases, nmspc)
|
|
||||||
states = {}
|
states = {}
|
||||||
# Re-use states from inherited classes
|
# Re-use states from inherited classes
|
||||||
default_state = None
|
default_state = None
|
||||||
for i in bases:
|
for i in bases:
|
||||||
if isinstance(i, MetaFSM):
|
if isinstance(i, MetaFSM):
|
||||||
for state_id, state in i.states.items():
|
for state_id, state in i._states.items():
|
||||||
if state.is_default:
|
if state.is_default:
|
||||||
default_state = state
|
default_state = state
|
||||||
states[state_id] = state
|
states[state_id] = state
|
||||||
|
|
||||||
# Add new states
|
# Add new states
|
||||||
for name, func in nmspc.items():
|
for attr, func in namespace.items():
|
||||||
if hasattr(func, 'id'):
|
if hasattr(func, 'id'):
|
||||||
if func.is_default:
|
if func.is_default:
|
||||||
default_state = func
|
default_state = func
|
||||||
states[func.id] = func
|
states[func.id] = func
|
||||||
cls.default_state = default_state
|
|
||||||
cls.states = states
|
namespace.update({
|
||||||
|
'_default_state': default_state,
|
||||||
|
'_states': states,
|
||||||
|
})
|
||||||
|
|
||||||
|
return super(MetaFSM, mcls).__new__(mcls=mcls, name=name, bases=bases, namespace=namespace)
|
||||||
|
|
||||||
|
|
||||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(FSM, self).__init__(*args, **kwargs)
|
super(FSM, self).__init__(*args, **kwargs)
|
||||||
if not hasattr(self, 'state_id'):
|
if not hasattr(self, 'state_id'):
|
||||||
if not self.default_state:
|
if not self._default_state:
|
||||||
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
||||||
self.state_id = self.default_state.id
|
self.state_id = self._default_state.id
|
||||||
|
|
||||||
|
self._coroutine = None
|
||||||
self.set_state(self.state_id)
|
self.set_state(self.state_id)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
||||||
interval = super().step()
|
default_interval = super().step()
|
||||||
if 'id' not in self.state:
|
|
||||||
if self.default_state:
|
|
||||||
self.set_state(self.default_state.id)
|
|
||||||
else:
|
|
||||||
raise Exception('{} has no valid state id or default state'.format(self))
|
|
||||||
interval = self.states[self.state_id](self) or interval
|
|
||||||
if not self.alive:
|
|
||||||
return time.NEVER
|
|
||||||
return interval
|
|
||||||
|
|
||||||
def set_state(self, state):
|
next_state = self._states[self.state_id](self)
|
||||||
|
|
||||||
|
when = None
|
||||||
|
try:
|
||||||
|
next_state, *when = next_state
|
||||||
|
if not when:
|
||||||
|
when = None
|
||||||
|
elif len(when) == 1:
|
||||||
|
when = when[0]
|
||||||
|
else:
|
||||||
|
raise ValueError('Too many values returned. Only state (and time) allowed')
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if next_state is not None:
|
||||||
|
self.set_state(next_state)
|
||||||
|
|
||||||
|
return when or default_interval
|
||||||
|
|
||||||
|
def set_state(self, state, when=None):
|
||||||
if hasattr(state, 'id'):
|
if hasattr(state, 'id'):
|
||||||
state = state.id
|
state = state.id
|
||||||
if state not in self.states:
|
if state not in self._states:
|
||||||
raise ValueError('{} is not a valid state'.format(state))
|
raise ValueError('{} is not a valid state'.format(state))
|
||||||
self.state_id = state
|
self.state_id = state
|
||||||
|
if when is not None:
|
||||||
|
self.model.schedule.add(self, when=when)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def die(self):
|
||||||
|
return self.dead, super().die()
|
||||||
|
|
||||||
|
@state
|
||||||
|
def dead(self):
|
||||||
|
return self.die()
|
||||||
|
|
||||||
|
|
||||||
def prob(prob, random):
|
def prob(prob, random):
|
||||||
'''
|
'''
|
||||||
@ -476,81 +525,81 @@ def _convert_agent_classs(ind, to_string=False, **kwargs):
|
|||||||
return deserialize_definition(ind, **kwargs)
|
return deserialize_definition(ind, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _agent_from_definition(definition, random, value=-1, unique_id=None):
|
# def _agent_from_definition(definition, random, value=-1, unique_id=None):
|
||||||
"""Used in the initialization of agents given an agent distribution."""
|
# """Used in the initialization of agents given an agent distribution."""
|
||||||
if value < 0:
|
# if value < 0:
|
||||||
value = random.random()
|
# value = random.random()
|
||||||
for d in sorted(definition, key=lambda x: x.get('threshold')):
|
# for d in sorted(definition, key=lambda x: x.get('threshold')):
|
||||||
threshold = d.get('threshold', (-1, -1))
|
# threshold = d.get('threshold', (-1, -1))
|
||||||
# Check if the definition matches by id (first) or by threshold
|
# # Check if the definition matches by id (first) or by threshold
|
||||||
if (unique_id is not None and unique_id in d.get('ids', [])) or \
|
# if (unique_id is not None and unique_id in d.get('ids', [])) or \
|
||||||
(value >= threshold[0] and value < threshold[1]):
|
# (value >= threshold[0] and value < threshold[1]):
|
||||||
state = {}
|
# state = {}
|
||||||
if 'state' in d:
|
# if 'state' in d:
|
||||||
state = deepcopy(d['state'])
|
# state = deepcopy(d['state'])
|
||||||
return d['agent_class'], state
|
# return d['agent_class'], state
|
||||||
|
|
||||||
raise Exception('Definition for value {} not found in: {}'.format(value, definition))
|
# raise Exception('Definition for value {} not found in: {}'.format(value, definition))
|
||||||
|
|
||||||
|
|
||||||
def _definition_to_dict(definition, random, size=None, default_state=None):
|
# def _definition_to_dict(definition, random, size=None, default_state=None):
|
||||||
state = default_state or {}
|
# state = default_state or {}
|
||||||
agents = {}
|
# agents = {}
|
||||||
remaining = {}
|
# remaining = {}
|
||||||
if size:
|
# if size:
|
||||||
for ix in range(size):
|
# for ix in range(size):
|
||||||
remaining[ix] = copy(state)
|
# remaining[ix] = copy(state)
|
||||||
else:
|
# else:
|
||||||
remaining = defaultdict(lambda x: copy(state))
|
# remaining = defaultdict(lambda x: copy(state))
|
||||||
|
|
||||||
distro = sorted([item for item in definition if 'weight' in item])
|
# distro = sorted([item for item in definition if 'weight' in item])
|
||||||
|
|
||||||
id = 0
|
# id = 0
|
||||||
|
|
||||||
def init_agent(item, id=ix):
|
# def init_agent(item, id=ix):
|
||||||
while id in agents:
|
# while id in agents:
|
||||||
id += 1
|
# id += 1
|
||||||
|
|
||||||
agent = remaining[id]
|
# agent = remaining[id]
|
||||||
agent['state'].update(copy(item.get('state', {})))
|
# agent['state'].update(copy(item.get('state', {})))
|
||||||
agents[agent.unique_id] = agent
|
# agents[agent.unique_id] = agent
|
||||||
del remaining[id]
|
# del remaining[id]
|
||||||
return agent
|
# return agent
|
||||||
|
|
||||||
for item in definition:
|
# for item in definition:
|
||||||
if 'ids' in item:
|
# if 'ids' in item:
|
||||||
ids = item['ids']
|
# ids = item['ids']
|
||||||
del item['ids']
|
# del item['ids']
|
||||||
for id in ids:
|
# for id in ids:
|
||||||
agent = init_agent(item, id)
|
# agent = init_agent(item, id)
|
||||||
|
|
||||||
for item in definition:
|
# for item in definition:
|
||||||
if 'number' in item:
|
# if 'number' in item:
|
||||||
times = item['number']
|
# times = item['number']
|
||||||
del item['number']
|
# del item['number']
|
||||||
for times in range(times):
|
# for times in range(times):
|
||||||
if size:
|
# if size:
|
||||||
ix = random.choice(remaining.keys())
|
# ix = random.choice(remaining.keys())
|
||||||
agent = init_agent(item, id)
|
# agent = init_agent(item, id)
|
||||||
else:
|
# else:
|
||||||
agent = init_agent(item)
|
# agent = init_agent(item)
|
||||||
if not size:
|
# if not size:
|
||||||
return agents
|
# return agents
|
||||||
|
|
||||||
if len(remaining) < 0:
|
# if len(remaining) < 0:
|
||||||
raise Exception('Invalid definition. Too many agents to add')
|
# raise Exception('Invalid definition. Too many agents to add')
|
||||||
|
|
||||||
|
|
||||||
total_weight = float(sum(s['weight'] for s in distro))
|
# total_weight = float(sum(s['weight'] for s in distro))
|
||||||
unit = size / total_weight
|
# unit = size / total_weight
|
||||||
|
|
||||||
for item in distro:
|
# for item in distro:
|
||||||
times = unit * item['weight']
|
# times = unit * item['weight']
|
||||||
del item['weight']
|
# del item['weight']
|
||||||
for times in range(times):
|
# for times in range(times):
|
||||||
ix = random.choice(remaining.keys())
|
# ix = random.choice(remaining.keys())
|
||||||
agent = init_agent(item, id)
|
# agent = init_agent(item, id)
|
||||||
return agents
|
# return agents
|
||||||
|
|
||||||
|
|
||||||
class AgentView(Mapping, Set):
|
class AgentView(Mapping, Set):
|
||||||
@ -571,59 +620,43 @@ class AgentView(Mapping, Set):
|
|||||||
|
|
||||||
# Mapping methods
|
# Mapping methods
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return sum(len(x) for x in self._agents.values())
|
return len(self._agents)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
yield from iter(chain.from_iterable(g.values() for g in self._agents.values()))
|
yield from self._agents.values()
|
||||||
|
|
||||||
def __getitem__(self, agent_id):
|
def __getitem__(self, agent_id):
|
||||||
if isinstance(agent_id, slice):
|
if isinstance(agent_id, slice):
|
||||||
raise ValueError(f"Slicing is not supported")
|
raise ValueError(f"Slicing is not supported")
|
||||||
for group in self._agents.values():
|
if agent_id in self._agents:
|
||||||
if agent_id in group:
|
return self._agents[agent_id]
|
||||||
return group[agent_id]
|
|
||||||
raise ValueError(f"Agent {agent_id} not found")
|
raise ValueError(f"Agent {agent_id} not found")
|
||||||
|
|
||||||
def filter(self, *args, **kwargs):
|
def filter(self, *args, **kwargs):
|
||||||
yield from filter_groups(self._agents, *args, **kwargs)
|
yield from filter_agents(self._agents, *args, **kwargs)
|
||||||
|
|
||||||
def one(self, *args, **kwargs):
|
def one(self, *args, **kwargs):
|
||||||
return next(filter_groups(self._agents, *args, **kwargs))
|
return next(filter_agents(self._agents, *args, **kwargs))
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return list(self.filter(*args, **kwargs))
|
return list(self.filter(*args, **kwargs))
|
||||||
|
|
||||||
def __contains__(self, agent_id):
|
def __contains__(self, agent_id):
|
||||||
return any(agent_id in g for g in self._agents)
|
return agent_id in self._agents
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(list(a.unique_id for a in self))
|
return str(list(unique_id for unique_id in self.keys()))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.__class__.__name__}({self})"
|
return f"{self.__class__.__name__}({self})"
|
||||||
|
|
||||||
|
|
||||||
def filter_groups(groups, *, group=None, **kwargs):
|
def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None,
|
||||||
assert isinstance(groups, dict)
|
limit=None, **kwargs):
|
||||||
|
|
||||||
if group is not None and not isinstance(group, list):
|
|
||||||
group = [group]
|
|
||||||
|
|
||||||
if group:
|
|
||||||
groups = list(groups[g] for g in group if g in groups)
|
|
||||||
else:
|
|
||||||
groups = list(groups.values())
|
|
||||||
|
|
||||||
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
|
|
||||||
|
|
||||||
yield from agents
|
|
||||||
|
|
||||||
|
|
||||||
def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None, **kwargs):
|
|
||||||
'''
|
'''
|
||||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||||
'''
|
'''
|
||||||
assert isinstance(group, dict)
|
assert isinstance(agents, dict)
|
||||||
|
|
||||||
ids = []
|
ids = []
|
||||||
|
|
||||||
@ -636,6 +669,11 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
|
|||||||
if id_args:
|
if id_args:
|
||||||
ids += id_args
|
ids += id_args
|
||||||
|
|
||||||
|
if ids:
|
||||||
|
f = (agents[aid] for aid in ids if aid in agents)
|
||||||
|
else:
|
||||||
|
f = (a for a in agents.values())
|
||||||
|
|
||||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||||
state_id = tuple([state_id])
|
state_id = tuple([state_id])
|
||||||
|
|
||||||
@ -646,12 +684,6 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
agent_class = tuple([agent_class])
|
agent_class = tuple([agent_class])
|
||||||
|
|
||||||
if ids:
|
|
||||||
agents = (group[aid] for aid in ids if aid in group)
|
|
||||||
else:
|
|
||||||
agents = (a for a in group.values())
|
|
||||||
|
|
||||||
f = agents
|
|
||||||
if ignore:
|
if ignore:
|
||||||
f = filter(lambda x: x not in ignore, f)
|
f = filter(lambda x: x not in ignore, f)
|
||||||
|
|
||||||
@ -667,83 +699,125 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
|
|||||||
for k, v in state.items():
|
for k, v in state.items():
|
||||||
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
||||||
|
|
||||||
|
if limit is not None:
|
||||||
|
f = islice(f, limit)
|
||||||
|
|
||||||
yield from f
|
yield from f
|
||||||
|
|
||||||
|
|
||||||
def from_config(cfg: Dict[str, config.AgentConfig], env, random):
|
def from_config(cfg: config.AgentConfig, random, topologies: Dict[str, nx.Graph] = None) -> List[Dict[str, Any]]:
|
||||||
'''
|
'''
|
||||||
Agents are specified in groups.
|
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
||||||
Each group can be specified in two ways, either through a fixed list in which each item has
|
with the parameters that the environment will use to construct each agent.
|
||||||
has the agent type, number of agents to create, and the other parameters, or through what we call
|
|
||||||
an `agent distribution`, which is similar but instead of number of agents, it specifies the weight
|
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
||||||
of each agent type.
|
time of calling this function, such as `unique_id`.
|
||||||
'''
|
'''
|
||||||
default = cfg.get('default', None)
|
default = cfg or config.AgentConfig()
|
||||||
return {k: _group_from_config(c, default=default, env=env, random=random) for (k, c) in cfg.items() if k is not 'default'}
|
if not isinstance(cfg, config.AgentConfig):
|
||||||
|
cfg = config.AgentConfig(**cfg)
|
||||||
|
return _agents_from_config(cfg, topologies=topologies, random=random)
|
||||||
|
|
||||||
|
|
||||||
def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfig, env, random):
|
def _agents_from_config(cfg: config.AgentConfig,
|
||||||
|
topologies: Dict[str, nx.Graph],
|
||||||
|
random) -> List[Dict[str, Any]]:
|
||||||
if cfg and not isinstance(cfg, config.AgentConfig):
|
if cfg and not isinstance(cfg, config.AgentConfig):
|
||||||
cfg = config.AgentConfig(**cfg)
|
cfg = config.AgentConfig(**cfg)
|
||||||
if default and not isinstance(default, config.SingleAgentConfig):
|
|
||||||
default = config.SingleAgentConfig(**default)
|
|
||||||
|
|
||||||
agents = {}
|
agents = []
|
||||||
|
|
||||||
|
assigned = defaultdict(int)
|
||||||
|
|
||||||
if cfg.fixed is not None:
|
if cfg.fixed is not None:
|
||||||
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
agents, counts = _from_fixed(cfg.fixed, topology=cfg.topology, default=cfg)
|
||||||
if cfg.distribution:
|
assigned.update(counts)
|
||||||
n = cfg.n or len(env.topologies[cfg.topology or default.topology])
|
|
||||||
target = n - len(agents)
|
|
||||||
agents.update(_from_distro(cfg.distribution, target,
|
|
||||||
topology=cfg.topology or default.topology,
|
|
||||||
default=default,
|
|
||||||
env=env, random=random))
|
|
||||||
assert len(agents) == n
|
|
||||||
if cfg.override:
|
|
||||||
for attrs in cfg.override:
|
|
||||||
if attrs.filter:
|
|
||||||
filtered = list(filter_group(agents, **attrs.filter))
|
|
||||||
else:
|
|
||||||
filtered = list(agents)
|
|
||||||
|
|
||||||
if attrs.n > len(filtered):
|
n = cfg.n
|
||||||
raise ValueError(f'Not enough agents to sample. Got {len(filtered)}, expected >= {attrs.n}')
|
|
||||||
for agent in random.sample(filtered, attrs.n):
|
if cfg.distribution:
|
||||||
agent.state.update(attrs.state)
|
topo_size = {top: len(topologies[top]) for top in topologies}
|
||||||
|
|
||||||
|
grouped = defaultdict(list)
|
||||||
|
total = []
|
||||||
|
|
||||||
|
for d in cfg.distribution:
|
||||||
|
if d.strategy == config.Strategy.topology:
|
||||||
|
topology = d.topology if ('topology' in d.__fields_set__) else cfg.topology
|
||||||
|
if not topology:
|
||||||
|
raise ValueError('The "topology" strategy only works if the topology parameter is specified')
|
||||||
|
if topology not in topo_size:
|
||||||
|
raise ValueError(f'Unknown topology selected: { topology }. Make sure the topology has been defined')
|
||||||
|
|
||||||
|
grouped[topology].append(d)
|
||||||
|
|
||||||
|
if d.strategy == config.Strategy.total:
|
||||||
|
if not cfg.n:
|
||||||
|
raise ValueError('Cannot use the "total" strategy without providing the total number of agents')
|
||||||
|
total.append(d)
|
||||||
|
|
||||||
|
|
||||||
|
for (topo, distro) in grouped.items():
|
||||||
|
if not topologies or topo not in topo_size:
|
||||||
|
raise ValueError(
|
||||||
|
'You need to specify a target number of agents for the distribution \
|
||||||
|
or a configuration with a topology, along with a dictionary with \
|
||||||
|
all the available topologies')
|
||||||
|
n = len(topologies[topo])
|
||||||
|
target = topo_size[topo] - assigned[topo]
|
||||||
|
new_agents = _from_distro(cfg.distribution, target,
|
||||||
|
topology=topo,
|
||||||
|
default=cfg,
|
||||||
|
random=random)
|
||||||
|
assigned[topo] += len(new_agents)
|
||||||
|
agents += new_agents
|
||||||
|
|
||||||
|
if total:
|
||||||
|
remaining = n - sum(assigned.values())
|
||||||
|
agents += _from_distro(total, remaining,
|
||||||
|
topology='', # DO NOT assign to any topology
|
||||||
|
default=cfg,
|
||||||
|
random=random)
|
||||||
|
|
||||||
|
|
||||||
|
if sum(assigned.values()) != sum(topo_size.values()):
|
||||||
|
utils.logger.warn(f'The total number of agents does not match the total number of nodes in '
|
||||||
|
'every topology. This may be due to a definition error: assigned: '
|
||||||
|
f'{ assigned } total sizes: { topo_size }')
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
|
||||||
def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig, env):
|
def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig) -> List[Dict[str, Any]]:
|
||||||
agents = {}
|
agents = []
|
||||||
|
|
||||||
|
counts = {}
|
||||||
|
|
||||||
for fixed in lst:
|
for fixed in lst:
|
||||||
agent_id = fixed.agent_id
|
agent = {}
|
||||||
if agent_id is None:
|
if default:
|
||||||
agent_id = env.next_id()
|
agent = default.state.copy()
|
||||||
|
agent.update(fixed.state)
|
||||||
|
cls = serialization.deserialize(fixed.agent_class or (default and default.agent_class))
|
||||||
|
agent['agent_class'] = cls
|
||||||
|
topo = fixed.topology if ('topology' in fixed.__fields_set__) else topology or default.topology
|
||||||
|
|
||||||
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
if topo:
|
||||||
state = fixed.state.copy()
|
agent['topology'] = topo
|
||||||
state.update(default.state)
|
if not fixed.hidden:
|
||||||
agent = cls(unique_id=agent_id,
|
counts[topo] = counts.get(topo, 0) + 1
|
||||||
model=env,
|
agents.append(agent)
|
||||||
**state)
|
|
||||||
topology = fixed.topology if (fixed.topology is not None) else (topology or default.topology)
|
|
||||||
if topology:
|
|
||||||
env.agent_to_node(agent_id, topology, fixed.node_id)
|
|
||||||
agents[agent.unique_id] = agent
|
|
||||||
|
|
||||||
return agents
|
return agents, counts
|
||||||
|
|
||||||
|
|
||||||
def _from_distro(distro: List[config.AgentDistro],
|
def _from_distro(distro: List[config.AgentDistro],
|
||||||
n: int,
|
n: int,
|
||||||
topology: str,
|
topology: str,
|
||||||
default: config.SingleAgentConfig,
|
default: config.SingleAgentConfig,
|
||||||
env,
|
random) -> List[Dict[str, Any]]:
|
||||||
random):
|
|
||||||
|
|
||||||
agents = {}
|
agents = []
|
||||||
|
|
||||||
if n is None:
|
if n is None:
|
||||||
if any(lambda dist: dist.n is None, distro):
|
if any(lambda dist: dist.n is None, distro):
|
||||||
@ -775,19 +849,16 @@ def _from_distro(distro: List[config.AgentDistro],
|
|||||||
|
|
||||||
for idx in indices:
|
for idx in indices:
|
||||||
d = distro[idx]
|
d = distro[idx]
|
||||||
|
agent = d.state.copy()
|
||||||
cls = classes[idx]
|
cls = classes[idx]
|
||||||
agent_id = env.next_id()
|
agent['agent_class'] = cls
|
||||||
state = d.state.copy()
|
|
||||||
if default:
|
if default:
|
||||||
state.update(default.state)
|
agent.update(default.state)
|
||||||
agent = cls(unique_id=agent_id, model=env, **state)
|
# agent = cls(unique_id=agent_id, model=env, **state)
|
||||||
topology = d.topology if (d.topology is not None) else topology or default.topology
|
topology = d.topology if ('topology' in d.__fields_set__) else topology or default.topology
|
||||||
if topology:
|
if topology:
|
||||||
env.agent_to_node(agent.unique_id, topology)
|
agent['topology'] = topology
|
||||||
assert agent.name is not None
|
agents.append(agent)
|
||||||
assert agent.name != 'None'
|
|
||||||
assert agent.name
|
|
||||||
agents[agent.unique_id] = agent
|
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
206
soil/analysis.py
206
soil/analysis.py
@ -1,206 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
|
|
||||||
import glob
|
|
||||||
import yaml
|
|
||||||
from os.path import join
|
|
||||||
|
|
||||||
from . import serialization
|
|
||||||
from tsih import History
|
|
||||||
|
|
||||||
|
|
||||||
def read_data(*args, group=False, **kwargs):
|
|
||||||
iterable = _read_data(*args, **kwargs)
|
|
||||||
if group:
|
|
||||||
return group_trials(iterable)
|
|
||||||
else:
|
|
||||||
return list(iterable)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
|
||||||
if not process_args:
|
|
||||||
process_args = {}
|
|
||||||
for folder in glob.glob(pattern):
|
|
||||||
config_file = glob.glob(join(folder, '*.yml'))[0]
|
|
||||||
config = yaml.load(open(config_file), Loader=yaml.SafeLoader)
|
|
||||||
df = None
|
|
||||||
if from_csv:
|
|
||||||
for trial_data in sorted(glob.glob(join(folder,
|
|
||||||
'*.environment.csv'))):
|
|
||||||
df = read_csv(trial_data, **kwargs)
|
|
||||||
yield config_file, df, config
|
|
||||||
else:
|
|
||||||
for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))):
|
|
||||||
df = read_sql(trial_data, **kwargs)
|
|
||||||
yield config_file, df, config
|
|
||||||
|
|
||||||
|
|
||||||
def read_sql(db, *args, **kwargs):
|
|
||||||
h = History(db_path=db, backup=False, readonly=True)
|
|
||||||
df = h.read_sql(*args, **kwargs)
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
|
||||||
'''
|
|
||||||
Read a CSV in canonical form: ::
|
|
||||||
|
|
||||||
<agent_id, t_step, key, value, value_type>
|
|
||||||
|
|
||||||
'''
|
|
||||||
df = pd.read_csv(filename)
|
|
||||||
if convert_types:
|
|
||||||
df = convert_types_slow(df)
|
|
||||||
if keys:
|
|
||||||
df = df[df['key'].isin(keys)]
|
|
||||||
df = process_one(df)
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def convert_row(row):
|
|
||||||
row['value'] = serialization.deserialize(row['value_type'], row['value'])
|
|
||||||
return row
|
|
||||||
|
|
||||||
|
|
||||||
def convert_types_slow(df):
|
|
||||||
'''
|
|
||||||
Go over every column in a dataframe and convert it to the type determined by the `get_types`
|
|
||||||
function.
|
|
||||||
|
|
||||||
This is a slow operation.
|
|
||||||
'''
|
|
||||||
dtypes = get_types(df)
|
|
||||||
for k, v in dtypes.items():
|
|
||||||
t = df[df['key']==k]
|
|
||||||
t['value'] = t['value'].astype(v)
|
|
||||||
df = df.apply(convert_row, axis=1)
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def split_processed(df):
|
|
||||||
env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])]
|
|
||||||
agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])]
|
|
||||||
return env, agents
|
|
||||||
|
|
||||||
|
|
||||||
def split_df(df):
|
|
||||||
'''
|
|
||||||
Split a dataframe in two dataframes: one with the history of agents,
|
|
||||||
and one with the environment history
|
|
||||||
'''
|
|
||||||
envmask = (df['agent_id'] == 'env')
|
|
||||||
n_env = envmask.sum()
|
|
||||||
if n_env == len(df):
|
|
||||||
return df, None
|
|
||||||
elif n_env == 0:
|
|
||||||
return None, df
|
|
||||||
agents, env = [x for _, x in df.groupby(envmask)]
|
|
||||||
return env, agents
|
|
||||||
|
|
||||||
|
|
||||||
def process(df, **kwargs):
|
|
||||||
'''
|
|
||||||
Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into
|
|
||||||
two dataframes with a column per key: one with the history of the agents, and one for the
|
|
||||||
history of the environment.
|
|
||||||
'''
|
|
||||||
env, agents = split_df(df)
|
|
||||||
return process_one(env, **kwargs), process_one(agents, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_types(df):
|
|
||||||
'''
|
|
||||||
Get the value type for every key stored in a raw history dataframe.
|
|
||||||
'''
|
|
||||||
dtypes = df.groupby(by=['key'])['value_type'].unique()
|
|
||||||
return {k:v[0] for k,v in dtypes.iteritems()}
|
|
||||||
|
|
||||||
|
|
||||||
def process_one(df, *keys, columns=['key', 'agent_id'], values='value',
|
|
||||||
fill=True, index=['t_step',],
|
|
||||||
aggfunc='first', **kwargs):
|
|
||||||
'''
|
|
||||||
Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into
|
|
||||||
a dataframe with a column per key
|
|
||||||
'''
|
|
||||||
if df is None:
|
|
||||||
return df
|
|
||||||
if keys:
|
|
||||||
df = df[df['key'].isin(keys)]
|
|
||||||
|
|
||||||
df = df.pivot_table(values=values, index=index, columns=columns,
|
|
||||||
aggfunc=aggfunc, **kwargs)
|
|
||||||
if fill:
|
|
||||||
df = fillna(df)
|
|
||||||
return df
|
|
||||||
|
|
||||||
|
|
||||||
def get_count(df, *keys):
|
|
||||||
'''
|
|
||||||
For every t_step and key, get the value count.
|
|
||||||
|
|
||||||
The result is a dataframe with `t_step` as index, an a multiindex column based on `key` and the values found for each `key`.
|
|
||||||
'''
|
|
||||||
if keys:
|
|
||||||
df = df[list(keys)]
|
|
||||||
df.columns = df.columns.remove_unused_levels()
|
|
||||||
counts = pd.DataFrame()
|
|
||||||
for key in df.columns.levels[0]:
|
|
||||||
g = df[[key]].apply(pd.Series.value_counts, axis=1).fillna(0)
|
|
||||||
for value, series in g.iteritems():
|
|
||||||
counts[key, value] = series
|
|
||||||
counts.columns = pd.MultiIndex.from_tuples(counts.columns)
|
|
||||||
return counts
|
|
||||||
|
|
||||||
|
|
||||||
def get_majority(df, *keys):
|
|
||||||
'''
|
|
||||||
For every t_step and key, get the value of the majority of agents
|
|
||||||
|
|
||||||
The result is a dataframe with `t_step` as index, and columns based on `key`.
|
|
||||||
'''
|
|
||||||
df = get_count(df, *keys)
|
|
||||||
return df.stack(level=0).idxmax(axis=1).unstack()
|
|
||||||
|
|
||||||
|
|
||||||
def get_value(df, *keys, aggfunc='sum'):
|
|
||||||
'''
|
|
||||||
For every t_step and key, get the value of *numeric columns*, aggregated using a specific function.
|
|
||||||
'''
|
|
||||||
if keys:
|
|
||||||
df = df[list(keys)]
|
|
||||||
df.columns = df.columns.remove_unused_levels()
|
|
||||||
df = df.select_dtypes('number')
|
|
||||||
return df.groupby(level='key', axis=1).agg(aggfunc)
|
|
||||||
|
|
||||||
|
|
||||||
def plot_all(*args, plot_args={}, **kwargs):
|
|
||||||
'''
|
|
||||||
Read all the trial data and plot the result of applying a function on them.
|
|
||||||
'''
|
|
||||||
dfs = do_all(*args, **kwargs)
|
|
||||||
ps = []
|
|
||||||
for line in dfs:
|
|
||||||
f, df, config = line
|
|
||||||
if len(df) < 1:
|
|
||||||
continue
|
|
||||||
df.plot(title=config['name'], **plot_args)
|
|
||||||
ps.append(df)
|
|
||||||
return ps
|
|
||||||
|
|
||||||
def do_all(pattern, func, *keys, include_env=False, **kwargs):
|
|
||||||
for config_file, df, config in read_data(pattern, keys=keys):
|
|
||||||
if len(df) < 1:
|
|
||||||
continue
|
|
||||||
p = func(df, *keys, **kwargs)
|
|
||||||
yield config_file, p, config
|
|
||||||
|
|
||||||
|
|
||||||
def group_trials(trials, aggfunc=['mean', 'min', 'max', 'std']):
|
|
||||||
trials = list(trials)
|
|
||||||
trials = list(map(lambda x: x[1] if isinstance(x, tuple) else x, trials))
|
|
||||||
return pd.concat(trials).groupby(level=0).agg(aggfunc).reorder_levels([2, 0,1] ,axis=1)
|
|
||||||
|
|
||||||
|
|
||||||
def fillna(df):
|
|
||||||
new_df = df.ffill(axis=0)
|
|
||||||
return new_df
|
|
167
soil/config.py
167
soil/config.py
@ -1,12 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from pydantic import BaseModel, ValidationError, validator, root_validator
|
from pydantic import BaseModel, ValidationError, validator, root_validator
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union, Type
|
from typing import Any, Callable, Dict, List, Optional, Union, Type
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra
|
||||||
|
|
||||||
|
from . import environment, utils
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
|
|
||||||
@ -36,7 +42,6 @@ class NetParams(BaseModel, extra=Extra.allow):
|
|||||||
|
|
||||||
|
|
||||||
class NetConfig(BaseModel):
|
class NetConfig(BaseModel):
|
||||||
group: str = 'network'
|
|
||||||
params: Optional[NetParams]
|
params: Optional[NetParams]
|
||||||
topology: Optional[Union[Topology, nx.Graph]]
|
topology: Optional[Union[Topology, nx.Graph]]
|
||||||
path: Optional[str]
|
path: Optional[str]
|
||||||
@ -56,9 +61,6 @@ class NetConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class EnvConfig(BaseModel):
|
class EnvConfig(BaseModel):
|
||||||
environment_class: Union[Type, str] = 'soil.Environment'
|
|
||||||
params: Dict[str, Any] = {}
|
|
||||||
schedule: Union[Type, str] = 'soil.time.TimedActivation'
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default():
|
def default():
|
||||||
@ -67,19 +69,19 @@ class EnvConfig(BaseModel):
|
|||||||
|
|
||||||
class SingleAgentConfig(BaseModel):
|
class SingleAgentConfig(BaseModel):
|
||||||
agent_class: Optional[Union[Type, str]] = None
|
agent_class: Optional[Union[Type, str]] = None
|
||||||
agent_id: Optional[int] = None
|
unique_id: Optional[int] = None
|
||||||
topology: Optional[str] = None
|
topology: Optional[str] = None
|
||||||
node_id: Optional[Union[int, str]] = None
|
node_id: Optional[Union[int, str]] = None
|
||||||
name: Optional[str] = None
|
|
||||||
state: Optional[Dict[str, Any]] = {}
|
state: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
class FixedAgentConfig(SingleAgentConfig):
|
class FixedAgentConfig(SingleAgentConfig):
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
|
hidden: Optional[bool] = False # Do not count this agent towards total agent count
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_all(cls, values):
|
def validate_all(cls, values):
|
||||||
if values.get('agent_id', None) is not None and values.get('n', 1) > 1:
|
if values.get('agent_id', None) is not None and values.get('n', 1) > 1:
|
||||||
print(values)
|
|
||||||
raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)")
|
raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@ -88,13 +90,19 @@ class OverrideAgentConfig(FixedAgentConfig):
|
|||||||
filter: Optional[Dict[str, Any]] = None
|
filter: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Strategy(Enum):
|
||||||
|
topology = 'topology'
|
||||||
|
total = 'total'
|
||||||
|
|
||||||
|
|
||||||
class AgentDistro(SingleAgentConfig):
|
class AgentDistro(SingleAgentConfig):
|
||||||
weight: Optional[float] = 1
|
weight: Optional[float] = 1
|
||||||
|
strategy: Strategy = Strategy.topology
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(SingleAgentConfig):
|
class AgentConfig(SingleAgentConfig):
|
||||||
n: Optional[int] = None
|
n: Optional[int] = None
|
||||||
topology: Optional[str] = None
|
topology: Optional[str]
|
||||||
distribution: Optional[List[AgentDistro]] = None
|
distribution: Optional[List[AgentDistro]] = None
|
||||||
fixed: Optional[List[FixedAgentConfig]] = None
|
fixed: Optional[List[FixedAgentConfig]] = None
|
||||||
override: Optional[List[OverrideAgentConfig]] = None
|
override: Optional[List[OverrideAgentConfig]] = None
|
||||||
@ -110,19 +118,32 @@ class AgentConfig(SingleAgentConfig):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel, extra=Extra.forbid):
|
class Config(BaseModel, extra=Extra.allow):
|
||||||
version: Optional[str] = '1'
|
version: Optional[str] = '1'
|
||||||
|
|
||||||
id: str = 'Unnamed Simulation'
|
name: str = 'Unnamed Simulation'
|
||||||
|
description: Optional[str] = None
|
||||||
group: str = None
|
group: str = None
|
||||||
dir_path: Optional[str] = None
|
dir_path: Optional[str] = None
|
||||||
num_trials: int = 1
|
num_trials: int = 1
|
||||||
max_time: float = 100
|
max_time: float = 100
|
||||||
|
max_steps: int = -1
|
||||||
interval: float = 1
|
interval: float = 1
|
||||||
seed: str = ""
|
seed: str = ""
|
||||||
|
dry_run: bool = False
|
||||||
|
|
||||||
model_class: Union[Type, str]
|
model_class: Union[Type, str] = environment.Environment
|
||||||
model_parameters: Optiona[Dict[str, Any]] = {}
|
model_params: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
visualization_params: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_raw(cls, cfg):
|
||||||
|
if isinstance(cfg, Config):
|
||||||
|
return cfg
|
||||||
|
if cfg.get('version', '1') == '1' and any(k in cfg for k in ['agents', 'agent_class', 'topology', 'environment_class']):
|
||||||
|
return convert_old(cfg)
|
||||||
|
return Config(**cfg)
|
||||||
|
|
||||||
|
|
||||||
def convert_old(old, strict=True):
|
def convert_old(old, strict=True):
|
||||||
@ -132,87 +153,84 @@ def convert_old(old, strict=True):
|
|||||||
This is still a work in progress and might not work in many cases.
|
This is still a work in progress and might not work in many cases.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
#TODO: implement actual conversion
|
utils.logger.warning('The old configuration format is deprecated. The converted file MAY NOT yield the right results')
|
||||||
print('The old configuration format is no longer supported. \
|
|
||||||
Update your config files or run Soil==0.20')
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
new = old.copy()
|
||||||
new = {}
|
|
||||||
|
|
||||||
general = {}
|
|
||||||
for k in ['id',
|
|
||||||
'group',
|
|
||||||
'dir_path',
|
|
||||||
'num_trials',
|
|
||||||
'max_time',
|
|
||||||
'interval',
|
|
||||||
'seed']:
|
|
||||||
if k in old:
|
|
||||||
general[k] = old[k]
|
|
||||||
|
|
||||||
if 'name' in old:
|
|
||||||
general['id'] = old['name']
|
|
||||||
|
|
||||||
network = {}
|
network = {}
|
||||||
|
|
||||||
|
if 'topology' in old:
|
||||||
|
del new['topology']
|
||||||
|
network['topology'] = old['topology']
|
||||||
|
|
||||||
if 'network_params' in old and old['network_params']:
|
if 'network_params' in old and old['network_params']:
|
||||||
|
del new['network_params']
|
||||||
for (k, v) in old['network_params'].items():
|
for (k, v) in old['network_params'].items():
|
||||||
if k == 'path':
|
if k == 'path':
|
||||||
network['path'] = v
|
network['path'] = v
|
||||||
else:
|
else:
|
||||||
network.setdefault('params', {})[k] = v
|
network.setdefault('params', {})[k] = v
|
||||||
|
|
||||||
if 'topology' in old:
|
topologies = {}
|
||||||
network['topology'] = old['topology']
|
if network:
|
||||||
|
topologies['default'] = network
|
||||||
|
|
||||||
agents = {
|
|
||||||
'network': {},
|
|
||||||
'default': {},
|
|
||||||
}
|
|
||||||
|
|
||||||
if 'agent_class' in old:
|
|
||||||
agents['default']['agent_class'] = old['agent_class']
|
|
||||||
|
|
||||||
if 'default_state' in old:
|
|
||||||
agents['default']['state'] = old['default_state']
|
|
||||||
|
|
||||||
|
agents = {'fixed': [], 'distribution': []}
|
||||||
|
|
||||||
def updated_agent(agent):
|
def updated_agent(agent):
|
||||||
|
'''Convert an agent definition'''
|
||||||
newagent = dict(agent)
|
newagent = dict(agent)
|
||||||
newagent['agent_class'] = newagent['agent_class']
|
|
||||||
del newagent['agent_class']
|
|
||||||
return newagent
|
return newagent
|
||||||
|
|
||||||
for agent in old.get('environment_agents', []):
|
|
||||||
agents['environment'] = {'distribution': [], 'fixed': []}
|
|
||||||
if 'agent_id' in agent:
|
|
||||||
agent['name'] = agent['agent_id']
|
|
||||||
del agent['agent_id']
|
|
||||||
agents['environment']['fixed'].append(updated_agent(agent))
|
|
||||||
|
|
||||||
by_weight = []
|
by_weight = []
|
||||||
fixed = []
|
fixed = []
|
||||||
override = []
|
override = []
|
||||||
|
|
||||||
if 'network_agents' in old:
|
if 'environment_agents' in new:
|
||||||
agents['network']['topology'] = 'default'
|
|
||||||
|
|
||||||
for agent in old['network_agents']:
|
for agent in new['environment_agents']:
|
||||||
|
agent.setdefault('state', {})['group'] = 'environment'
|
||||||
|
if 'agent_id' in agent:
|
||||||
|
agent['state']['name'] = agent['agent_id']
|
||||||
|
del agent['agent_id']
|
||||||
|
agent['hidden'] = True
|
||||||
|
agent['topology'] = None
|
||||||
|
fixed.append(updated_agent(agent))
|
||||||
|
del new['environment_agents']
|
||||||
|
|
||||||
|
|
||||||
|
if 'agent_class' in old:
|
||||||
|
del new['agent_class']
|
||||||
|
agents['agent_class'] = old['agent_class']
|
||||||
|
|
||||||
|
if 'default_state' in old:
|
||||||
|
del new['default_state']
|
||||||
|
agents['state'] = old['default_state']
|
||||||
|
|
||||||
|
if 'network_agents' in old:
|
||||||
|
agents['topology'] = 'default'
|
||||||
|
|
||||||
|
agents.setdefault('state', {})['group'] = 'network'
|
||||||
|
|
||||||
|
for agent in new['network_agents']:
|
||||||
agent = updated_agent(agent)
|
agent = updated_agent(agent)
|
||||||
if 'agent_id' in agent:
|
if 'agent_id' in agent:
|
||||||
|
agent['state']['name'] = agent['agent_id']
|
||||||
|
del agent['agent_id']
|
||||||
fixed.append(agent)
|
fixed.append(agent)
|
||||||
else:
|
else:
|
||||||
by_weight.append(agent)
|
by_weight.append(agent)
|
||||||
|
del new['network_agents']
|
||||||
|
|
||||||
if 'agent_class' in old and (not fixed and not by_weight):
|
if 'agent_class' in old and (not fixed and not by_weight):
|
||||||
agents['network']['topology'] = 'default'
|
agents['topology'] = 'default'
|
||||||
by_weight = [{'agent_class': old['agent_class']}]
|
by_weight = [{'agent_class': old['agent_class'], 'weight': 1}]
|
||||||
|
|
||||||
|
|
||||||
# TODO: translate states properly
|
# TODO: translate states properly
|
||||||
if 'states' in old:
|
if 'states' in old:
|
||||||
|
del new['states']
|
||||||
states = old['states']
|
states = old['states']
|
||||||
if isinstance(states, dict):
|
if isinstance(states, dict):
|
||||||
states = states.items()
|
states = states.items()
|
||||||
@ -220,22 +238,29 @@ def convert_old(old, strict=True):
|
|||||||
states = enumerate(states)
|
states = enumerate(states)
|
||||||
for (k, v) in states:
|
for (k, v) in states:
|
||||||
override.append({'filter': {'node_id': k},
|
override.append({'filter': {'node_id': k},
|
||||||
'state': v
|
'state': v})
|
||||||
})
|
|
||||||
|
|
||||||
agents['network']['override'] = override
|
agents['override'] = override
|
||||||
agents['network']['fixed'] = fixed
|
agents['fixed'] = fixed
|
||||||
agents['network']['distribution'] = by_weight
|
agents['distribution'] = by_weight
|
||||||
|
|
||||||
|
|
||||||
|
model_params = {}
|
||||||
|
if 'environment_params' in new:
|
||||||
|
del new['environment_params']
|
||||||
|
model_params = dict(old['environment_params'])
|
||||||
|
|
||||||
environment = {'params': {}}
|
|
||||||
if 'environment_class' in old:
|
if 'environment_class' in old:
|
||||||
environment['environment_class'] = old['environment_class']
|
del new['environment_class']
|
||||||
|
new['model_class'] = old['environment_class']
|
||||||
|
|
||||||
for (k, v) in old.get('environment_params', {}).items():
|
if 'dump' in old:
|
||||||
environment['params'][k] = v
|
del new['dump']
|
||||||
|
new['dry_run'] = not old['dump']
|
||||||
|
|
||||||
|
model_params['topologies'] = topologies
|
||||||
|
model_params['agents'] = agents
|
||||||
|
|
||||||
return Config(version='2',
|
return Config(version='2',
|
||||||
general=general,
|
model_params=model_params,
|
||||||
topologies={'default': network},
|
**new)
|
||||||
environment=environment,
|
|
||||||
agents=agents)
|
|
||||||
|
151
soil/debugging.py
Normal file
151
soil/debugging.py
Normal file
@ -0,0 +1,151 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pdb
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
from textwrap import indent
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from .agents import FSM, MetaFSM
|
||||||
|
|
||||||
|
|
||||||
|
def wrapcmd(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(self, arg: str, temporary=False):
|
||||||
|
sys.settrace(self.trace_dispatch)
|
||||||
|
|
||||||
|
known = globals()
|
||||||
|
known.update(self.curframe.f_globals)
|
||||||
|
known.update(self.curframe.f_locals)
|
||||||
|
known['agent'] = known.get('self', None)
|
||||||
|
known['model'] = known.get('self', {}).get('model')
|
||||||
|
known['attrs'] = arg.strip().split()
|
||||||
|
|
||||||
|
exec(func.__code__, known, known)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class Debug(pdb.Pdb):
|
||||||
|
def __init__(self, *args, skip_soil=False, **kwargs):
|
||||||
|
skip = kwargs.get('skip', [])
|
||||||
|
if skip_soil:
|
||||||
|
skip.append('soil.*')
|
||||||
|
skip.append('mesa.*')
|
||||||
|
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
||||||
|
self.prompt = "[soil-pdb] "
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _soil_agents(model, attrs=None, pretty=True, **kwargs):
|
||||||
|
for agent in model.agents(**kwargs):
|
||||||
|
d = agent
|
||||||
|
print(' - ' + indent(agent.to_str(keys=attrs, pretty=pretty), ' '))
|
||||||
|
|
||||||
|
@wrapcmd
|
||||||
|
def do_soil_agents():
|
||||||
|
return Debug._soil_agents(model, attrs=attrs or None)
|
||||||
|
|
||||||
|
do_sa = do_soil_agents
|
||||||
|
|
||||||
|
@wrapcmd
|
||||||
|
def do_soil_list():
|
||||||
|
return Debug._soil_agents(model, attrs=['state_id'], pretty=False)
|
||||||
|
|
||||||
|
do_sl = do_soil_list
|
||||||
|
|
||||||
|
@wrapcmd
|
||||||
|
def do_soil_self():
|
||||||
|
if not agent:
|
||||||
|
print('No agent available')
|
||||||
|
return
|
||||||
|
|
||||||
|
keys = None
|
||||||
|
if attrs:
|
||||||
|
keys = []
|
||||||
|
for k in attrs:
|
||||||
|
for key in agent.keys():
|
||||||
|
if key.startswith(k):
|
||||||
|
keys.append(key)
|
||||||
|
|
||||||
|
print(agent.to_str(pretty=True, keys=keys))
|
||||||
|
|
||||||
|
do_ss = do_soil_self
|
||||||
|
|
||||||
|
def do_break_state(self, arg: str, temporary=False):
|
||||||
|
'''
|
||||||
|
Break before a specified state is stepped into.
|
||||||
|
'''
|
||||||
|
|
||||||
|
klass = None
|
||||||
|
state = arg.strip()
|
||||||
|
if not state:
|
||||||
|
self.error("Specify at least a state name")
|
||||||
|
return
|
||||||
|
|
||||||
|
comma = arg.find(':')
|
||||||
|
if comma > 0:
|
||||||
|
state = arg[comma+1:].lstrip()
|
||||||
|
klass = arg[:comma].rstrip()
|
||||||
|
klass = eval(klass,
|
||||||
|
self.curframe.f_globals,
|
||||||
|
self.curframe_locals)
|
||||||
|
|
||||||
|
if klass:
|
||||||
|
klasses = [klass]
|
||||||
|
else:
|
||||||
|
klasses = [k for k in self.curframe.f_globals.values() if isinstance(k, type) and issubclass(k, FSM)]
|
||||||
|
print(klasses)
|
||||||
|
if not klasses:
|
||||||
|
self.error('No agent classes found')
|
||||||
|
|
||||||
|
for klass in klasses:
|
||||||
|
try:
|
||||||
|
func = getattr(klass, state)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
if hasattr(func, '__func__'):
|
||||||
|
func = func.__func__
|
||||||
|
|
||||||
|
code = func.__code__
|
||||||
|
#use co_name to identify the bkpt (function names
|
||||||
|
#could be aliased, but co_name is invariant)
|
||||||
|
funcname = code.co_name
|
||||||
|
lineno = code.co_firstlineno
|
||||||
|
filename = code.co_filename
|
||||||
|
|
||||||
|
# Check for reasonable breakpoint
|
||||||
|
line = self.checkline(filename, lineno)
|
||||||
|
if not line:
|
||||||
|
raise ValueError('no line found')
|
||||||
|
# now set the break point
|
||||||
|
cond = None
|
||||||
|
existing = self.get_breaks(filename, line)
|
||||||
|
if existing:
|
||||||
|
self.message("Breakpoint already exists at %s:%d" %
|
||||||
|
(filename, line))
|
||||||
|
continue
|
||||||
|
err = self.set_break(filename, line, temporary, cond, funcname)
|
||||||
|
if err:
|
||||||
|
self.error(err)
|
||||||
|
else:
|
||||||
|
bp = self.get_breaks(filename, line)[-1]
|
||||||
|
self.message("Breakpoint %d at %s:%d" %
|
||||||
|
(bp.number, bp.file, bp.line))
|
||||||
|
do_bs = do_break_state
|
||||||
|
|
||||||
|
|
||||||
|
def setup(frame=None):
|
||||||
|
debugger = Debug()
|
||||||
|
frame = frame or sys._getframe().f_back
|
||||||
|
debugger.set_trace(frame)
|
||||||
|
|
||||||
|
def debug_env():
|
||||||
|
if os.environ.get('SOIL_DEBUG'):
|
||||||
|
return setup(frame=sys._getframe().f_back)
|
||||||
|
|
||||||
|
def post_mortem(traceback=None):
|
||||||
|
p = Debug()
|
||||||
|
t = sys.exc_info()[2]
|
||||||
|
p.reset()
|
||||||
|
p.interaction(None, t)
|
@ -1,4 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import math
|
import math
|
||||||
@ -17,9 +18,7 @@ import networkx as nx
|
|||||||
from mesa import Model
|
from mesa import Model
|
||||||
from mesa.datacollection import DataCollector
|
from mesa.datacollection import DataCollector
|
||||||
|
|
||||||
from . import serialization, analysis, utils, time, network
|
from . import agents as agentmod, config, serialization, utils, time, network
|
||||||
|
|
||||||
from .agents import AgentView, BaseAgent, NetworkAgent, from_config as agents_from_config
|
|
||||||
|
|
||||||
|
|
||||||
Record = namedtuple('Record', 'dict_id t_step key value')
|
Record = namedtuple('Record', 'dict_id t_step key value')
|
||||||
@ -39,12 +38,12 @@ class BaseEnvironment(Model):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
env_id='unnamed_env',
|
id='unnamed_env',
|
||||||
seed='default',
|
seed='default',
|
||||||
schedule=None,
|
schedule=None,
|
||||||
dir_path=None,
|
dir_path=None,
|
||||||
interval=1,
|
interval=1,
|
||||||
agent_class=BaseAgent,
|
agent_class=None,
|
||||||
agents: [tuple[type, Dict[str, Any]]] = {},
|
agents: [tuple[type, Dict[str, Any]]] = {},
|
||||||
agent_reporters: Optional[Any] = None,
|
agent_reporters: Optional[Any] = None,
|
||||||
model_reporters: Optional[Any] = None,
|
model_reporters: Optional[Any] = None,
|
||||||
@ -54,7 +53,7 @@ class BaseEnvironment(Model):
|
|||||||
super().__init__(seed=seed)
|
super().__init__(seed=seed)
|
||||||
self.current_id = -1
|
self.current_id = -1
|
||||||
|
|
||||||
self.id = env_id
|
self.id = id
|
||||||
|
|
||||||
self.dir_path = dir_path or os.getcwd()
|
self.dir_path = dir_path or os.getcwd()
|
||||||
|
|
||||||
@ -62,7 +61,7 @@ class BaseEnvironment(Model):
|
|||||||
schedule = time.TimedActivation(self)
|
schedule = time.TimedActivation(self)
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
|
||||||
self.agent_class = agent_class
|
self.agent_class = agent_class or agentmod.BaseAgent
|
||||||
|
|
||||||
self.init_agents(agents)
|
self.init_agents(agents)
|
||||||
|
|
||||||
@ -78,25 +77,51 @@ class BaseEnvironment(Model):
|
|||||||
tables=tables,
|
tables=tables,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __read_agent_tuple(self, tup):
|
def _read_single_agent(self, agent):
|
||||||
cls = self.agent_class
|
agent = dict(**agent)
|
||||||
args = tup
|
cls = agent.pop('agent_class', None) or self.agent_class
|
||||||
if isinstance(tup, tuple):
|
unique_id = agent.pop('unique_id', None)
|
||||||
cls = tup[0]
|
if unique_id is None:
|
||||||
args = tup[1]
|
unique_id = self.next_id()
|
||||||
return serialization.deserialize(cls)(unique_id=self.next_id(),
|
|
||||||
model=self, **args)
|
return serialization.deserialize(cls)(unique_id=unique_id,
|
||||||
|
model=self, **agent)
|
||||||
|
|
||||||
|
def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}):
|
||||||
|
if not agents:
|
||||||
|
return
|
||||||
|
|
||||||
|
lst = agents
|
||||||
|
override = []
|
||||||
|
if not isinstance(lst, list):
|
||||||
|
if not isinstance(agents, config.AgentConfig):
|
||||||
|
lst = config.AgentConfig(**agents)
|
||||||
|
if lst.override:
|
||||||
|
override = lst.override
|
||||||
|
lst = agentmod.from_config(lst,
|
||||||
|
topologies=getattr(self, 'topologies', None),
|
||||||
|
random=self.random)
|
||||||
|
|
||||||
|
#TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore,
|
||||||
|
# because it needs attribute such as unique_id, which are only present after init
|
||||||
|
new_agents = [self._read_single_agent(agent) for agent in lst]
|
||||||
|
|
||||||
|
|
||||||
|
for a in new_agents:
|
||||||
|
self.schedule.add(a)
|
||||||
|
|
||||||
|
for rule in override:
|
||||||
|
for agent in agentmod.filter_agents(self.schedule._agents, **rule.filter):
|
||||||
|
for attr, value in rule.state.items():
|
||||||
|
setattr(agent, attr, value)
|
||||||
|
|
||||||
def init_agents(self, agents: [tuple[type, Dict[str, Any]]] = {}):
|
|
||||||
agents = [self.__read_agent_tuple(tup) for tup in agents]
|
|
||||||
self._agents = {'default': {agent.id: agent for agent in agents}}
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agents(self):
|
def agents(self):
|
||||||
return AgentView(self._agents)
|
return agentmod.AgentView(self.schedule._agents)
|
||||||
|
|
||||||
def find_one(self, *args, **kwargs):
|
def find_one(self, *args, **kwargs):
|
||||||
return AgentView(self._agents).one(*args, **kwargs)
|
return agentmod.AgentView(self.schedule._agents).one(*args, **kwargs)
|
||||||
|
|
||||||
def count_agents(self, *args, **kwargs):
|
def count_agents(self, *args, **kwargs):
|
||||||
return sum(1 for i in self.agents(*args, **kwargs))
|
return sum(1 for i in self.agents(*args, **kwargs))
|
||||||
@ -108,38 +133,12 @@ class BaseEnvironment(Model):
|
|||||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
||||||
|
|
||||||
|
|
||||||
# def init_agent(self, agent_id, agent_definitions, state=None):
|
def add_agent(self, agent_id, agent_class, **kwargs):
|
||||||
# state = state or {}
|
|
||||||
|
|
||||||
# agent_class = None
|
|
||||||
# if 'agent_class' in self.states.get(agent_id, {}):
|
|
||||||
# agent_class = self.states[agent_id]['agent_class']
|
|
||||||
# elif 'agent_class' in self.default_state:
|
|
||||||
# agent_class = self.default_state['agent_class']
|
|
||||||
|
|
||||||
# if agent_class:
|
|
||||||
# agent_class = agents.deserialize_type(agent_class)
|
|
||||||
# elif agent_definitions:
|
|
||||||
# agent_class, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id)
|
|
||||||
# else:
|
|
||||||
# serialization.logger.debug('Skipping agent {}'.format(agent_id))
|
|
||||||
# return
|
|
||||||
# return self.add_agent(agent_id, agent_class, state)
|
|
||||||
|
|
||||||
|
|
||||||
def add_agent(self, agent_id, agent_class, state=None, graph='default'):
|
|
||||||
defstate = deepcopy(self.default_state) or {}
|
|
||||||
defstate.update(self.states.get(agent_id, {}))
|
|
||||||
if state:
|
|
||||||
defstate.update(state)
|
|
||||||
a = None
|
a = None
|
||||||
if agent_class:
|
if agent_class:
|
||||||
state = defstate
|
|
||||||
a = agent_class(model=self,
|
a = agent_class(model=self,
|
||||||
unique_id=agent_id)
|
unique_id=agent_id,
|
||||||
|
**kwargs)
|
||||||
for (k, v) in state.items():
|
|
||||||
setattr(a, k, v)
|
|
||||||
|
|
||||||
self.schedule.add(a)
|
self.schedule.add(a)
|
||||||
return a
|
return a
|
||||||
@ -153,7 +152,7 @@ class BaseEnvironment(Model):
|
|||||||
message += " {k}={v} ".format(k, v)
|
message += " {k}={v} ".format(k, v)
|
||||||
extra = {}
|
extra = {}
|
||||||
extra['now'] = self.now
|
extra['now'] = self.now
|
||||||
extra['unique_id'] = self.id
|
extra['id'] = self.id
|
||||||
return self.logger.log(level, message, extra=extra)
|
return self.logger.log(level, message, extra=extra)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
@ -161,6 +160,7 @@ class BaseEnvironment(Model):
|
|||||||
Advance one step in the simulation, and update the data collection and scheduler appropriately
|
Advance one step in the simulation, and update the data collection and scheduler appropriately
|
||||||
'''
|
'''
|
||||||
super().step()
|
super().step()
|
||||||
|
self.logger.info(f'--- Step {self.now:^5} ---')
|
||||||
self.schedule.step()
|
self.schedule.step()
|
||||||
self.datacollector.collect(self)
|
self.datacollector.collect(self)
|
||||||
|
|
||||||
@ -207,34 +207,41 @@ class BaseEnvironment(Model):
|
|||||||
yield from self._agent_to_tuples(agent, now)
|
yield from self._agent_to_tuples(agent, now)
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigEnvironment(BaseEnvironment):
|
class NetworkEnvironment(BaseEnvironment):
|
||||||
|
|
||||||
def __init__(self, *args,
|
def __init__(self, *args, topology: nx.Graph = None, topologies: Dict[str, config.NetConfig] = {}, **kwargs):
|
||||||
agents: Dict[str, config.AgentConfig] = {},
|
agents = kwargs.pop('agents', None)
|
||||||
**kwargs):
|
super().__init__(*args, agents=None, **kwargs)
|
||||||
return super().__init__(*args, agents=agents, **kwargs)
|
|
||||||
|
|
||||||
def init_agents(self, agents: Union[Dict[str, config.AgentConfig], [tuple[type, Dict[str, Any]]]] = {}):
|
|
||||||
if not isinstance(agents, dict):
|
|
||||||
return BaseEnvironment.init_agents(self, agents)
|
|
||||||
|
|
||||||
self._agents = agents_from_config(agents,
|
|
||||||
env=self,
|
|
||||||
random=self.random)
|
|
||||||
for d in self._agents.values():
|
|
||||||
for a in d.values():
|
|
||||||
self.schedule.add(a)
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkConfigEnvironment(BaseEnvironment):
|
|
||||||
|
|
||||||
def __init__(self, *args, topologies: Dict[str, config.NetConfig] = {}, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.topologies = {}
|
|
||||||
self._node_ids = {}
|
self._node_ids = {}
|
||||||
|
assert not hasattr(self, 'topologies')
|
||||||
|
if topology is not None:
|
||||||
|
if topologies:
|
||||||
|
raise ValueError('Please, provide either a single topology or a dictionary of them')
|
||||||
|
topologies = {'default': topology}
|
||||||
|
|
||||||
|
self.topologies = {}
|
||||||
for (name, cfg) in topologies.items():
|
for (name, cfg) in topologies.items():
|
||||||
self.set_topology(cfg=cfg, graph=name)
|
self.set_topology(cfg=cfg, graph=name)
|
||||||
|
|
||||||
|
self.init_agents(agents)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_single_agent(self, agent, unique_id=None):
|
||||||
|
agent = dict(agent)
|
||||||
|
|
||||||
|
if agent.get('topology', None) is not None:
|
||||||
|
topology = agent.get('topology')
|
||||||
|
if unique_id is None:
|
||||||
|
unique_id = self.next_id()
|
||||||
|
if topology:
|
||||||
|
node_id = self.agent_to_node(unique_id, graph_name=topology, node_id=agent.get('node_id'))
|
||||||
|
agent['node_id'] = node_id
|
||||||
|
agent['topology'] = topology
|
||||||
|
agent['unique_id'] = unique_id
|
||||||
|
|
||||||
|
return super()._read_single_agent(agent)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def topology(self):
|
def topology(self):
|
||||||
return self.topologies['default']
|
return self.topologies['default']
|
||||||
@ -246,51 +253,50 @@ class NetworkConfigEnvironment(BaseEnvironment):
|
|||||||
|
|
||||||
self.topologies[graph] = topology
|
self.topologies[graph] = topology
|
||||||
|
|
||||||
def topology_for(self, agent_id):
|
def topology_for(self, unique_id):
|
||||||
return self.topologies[self._node_ids[agent_id][0]]
|
return self.topologies[self._node_ids[unique_id][0]]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def network_agents(self):
|
def network_agents(self):
|
||||||
yield from self.agents(agent_class=NetworkAgent)
|
yield from self.agents(agent_class=agentmod.NetworkAgent)
|
||||||
|
|
||||||
def agent_to_node(self, agent_id, graph_name='default', node_id=None, shuffle=False):
|
def agent_to_node(self, unique_id, graph_name='default',
|
||||||
node_id = network.agent_to_node(G=self.topologies[graph_name], agent_id=agent_id,
|
node_id=None, shuffle=False):
|
||||||
node_id=node_id, shuffle=shuffle,
|
node_id = network.agent_to_node(G=self.topologies[graph_name],
|
||||||
|
agent_id=unique_id,
|
||||||
|
node_id=node_id,
|
||||||
|
shuffle=shuffle,
|
||||||
random=self.random)
|
random=self.random)
|
||||||
|
|
||||||
self._node_ids[agent_id] = (graph_name, node_id)
|
self._node_ids[unique_id] = (graph_name, node_id)
|
||||||
|
return node_id
|
||||||
|
|
||||||
|
def add_node(self, agent_class, topology, **kwargs):
|
||||||
|
unique_id = self.next_id()
|
||||||
|
self.topologies[topology].add_node(unique_id)
|
||||||
|
node_id = self.agent_to_node(unique_id=unique_id, node_id=unique_id, graph_name=topology)
|
||||||
|
|
||||||
def add_node(self, agent_class, state=None, graph='default'):
|
a = self.add_agent(unique_id=unique_id, agent_class=agent_class, node_id=node_id, topology=topology, **kwargs)
|
||||||
agent_id = int(len(self.topologies[graph].nodes()))
|
|
||||||
self.topologies[graph].add_node(agent_id)
|
|
||||||
a = self.add_agent(agent_id, agent_class, state, graph=graph)
|
|
||||||
a['visible'] = True
|
a['visible'] = True
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def add_edge(self, agent1, agent2, start=None, graph='default', **attrs):
|
def add_edge(self, agent1, agent2, start=None, graph='default', **attrs):
|
||||||
if hasattr(agent1, 'id'):
|
agent1 = agent1.node_id
|
||||||
agent1 = agent1.id
|
agent2 = agent2.node_id
|
||||||
if hasattr(agent2, 'id'):
|
return self.topologies[graph].add_edge(agent1, agent2, start=start)
|
||||||
agent2 = agent2.id
|
|
||||||
start = start or self.now
|
|
||||||
return self.topologies[graph].add_edge(agent1, agent2, **attrs)
|
|
||||||
|
|
||||||
def add_agent(self, *args, state=None, graph='default', **kwargs):
|
def add_agent(self, unique_id, state=None, graph='default', **kwargs):
|
||||||
node = self.topologies[graph].nodes[agent_id]
|
node = self.topologies[graph].nodes[unique_id]
|
||||||
node_state = node.get('state', {})
|
node_state = node.get('state', {})
|
||||||
if node_state:
|
if node_state:
|
||||||
node_state.update(state or {})
|
node_state.update(state or {})
|
||||||
state = node_state
|
state = node_state
|
||||||
a = super().add_agent(*args, state=state, **kwargs)
|
a = super().add_agent(unique_id, state=state, **kwargs)
|
||||||
node['agent'] = a
|
node['agent'] = a
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def node_id_for(self, agent_id):
|
def node_id_for(self, agent_id):
|
||||||
return self._node_ids[agent_id][1]
|
return self._node_ids[agent_id][1]
|
||||||
|
|
||||||
class Environment(AgentConfigEnvironment, NetworkConfigEnvironment):
|
|
||||||
def __init__(self, *args, **kwargs):
|
Environment = NetworkEnvironment
|
||||||
agents = kwargs.pop('agents', {})
|
|
||||||
NetworkConfigEnvironment.__init__(self, *args, **kwargs)
|
|
||||||
AgentConfigEnvironment.__init__(self, *args, agents=agents, **kwargs)
|
|
||||||
|
@ -12,7 +12,7 @@ from .serialization import deserialize
|
|||||||
from .utils import open_or_reuse, logger, timer
|
from .utils import open_or_reuse, logger, timer
|
||||||
|
|
||||||
|
|
||||||
from . import utils
|
from . import utils, network
|
||||||
|
|
||||||
|
|
||||||
class DryRunner(BytesIO):
|
class DryRunner(BytesIO):
|
||||||
@ -85,38 +85,28 @@ class Exporter:
|
|||||||
class default(Exporter):
|
class default(Exporter):
|
||||||
'''Default exporter. Writes sqlite results, as well as the simulation YAML'''
|
'''Default exporter. Writes sqlite results, as well as the simulation YAML'''
|
||||||
|
|
||||||
# def sim_start(self):
|
def sim_start(self):
|
||||||
# if not self.dry_run:
|
if not self.dry_run:
|
||||||
# logger.info('Dumping results to %s', self.outdir)
|
logger.info('Dumping results to %s', self.outdir)
|
||||||
# self.simulation.dump_yaml(outdir=self.outdir)
|
with self.output(self.simulation.name + '.dumped.yml') as f:
|
||||||
# else:
|
f.write(self.simulation.to_yaml())
|
||||||
# logger.info('NOT dumping results')
|
else:
|
||||||
|
logger.info('NOT dumping results')
|
||||||
|
|
||||||
# def trial_start(self, env, stats):
|
def trial_end(self, env):
|
||||||
# if not self.dry_run:
|
if not self.dry_run:
|
||||||
# with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||||
# env.name)):
|
env.id)):
|
||||||
# engine = create_engine('sqlite:///{}.sqlite'.format(env.name), echo=False)
|
engine = create_engine('sqlite:///{}.sqlite'.format(env.id), echo=False)
|
||||||
|
|
||||||
# dc = env.datacollector
|
dc = env.datacollector
|
||||||
# tables = {'env': dc.get_model_vars_dataframe(),
|
for (t, df) in get_dc_dfs(dc):
|
||||||
# 'agents': dc.get_agent_vars_dataframe(),
|
df.to_sql(t, con=engine, if_exists='append')
|
||||||
# 'agents': dc.get_agent_vars_dataframe()}
|
|
||||||
# for table in dc.tables:
|
|
||||||
# tables[table] = dc.get_table_dataframe(table)
|
|
||||||
# for (t, df) in tables.items():
|
|
||||||
# df.to_sql(t, con=engine)
|
|
||||||
|
|
||||||
# def sim_end(self, stats):
|
|
||||||
# with timer('Dumping simulation {}\'s stats'.format(self.simulation.name)):
|
|
||||||
# engine = create_engine('sqlite:///{}.sqlite'.format(self.simulation.name), echo=False)
|
|
||||||
# with self.output('{}.sqlite'.format(self.simulation.name), mode='wb') as f:
|
|
||||||
# self.simulation.dump_sqlite(f)
|
|
||||||
|
|
||||||
|
|
||||||
def get_dc_dfs(dc):
|
def get_dc_dfs(dc):
|
||||||
dfs = {'env': dc.get_model_vars_dataframe(),
|
dfs = {'env': dc.get_model_vars_dataframe(),
|
||||||
'agents': dc.get_agent_vars_dataframe }
|
'agents': dc.get_agent_vars_dataframe() }
|
||||||
for table_name in dc.tables:
|
for table_name in dc.tables:
|
||||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||||
yield from dfs.items()
|
yield from dfs.items()
|
||||||
@ -130,10 +120,11 @@ class csv(Exporter):
|
|||||||
env.id,
|
env.id,
|
||||||
self.outdir)):
|
self.outdir)):
|
||||||
for (df_name, df) in get_dc_dfs(env.datacollector):
|
for (df_name, df) in get_dc_dfs(env.datacollector):
|
||||||
with self.output('{}.stats.{}.csv'.format(env.id, df_name)) as f:
|
with self.output('{}.{}.csv'.format(env.id, df_name)) as f:
|
||||||
df.to_csv(f)
|
df.to_csv(f)
|
||||||
|
|
||||||
|
|
||||||
|
#TODO: reimplement GEXF exporting without history
|
||||||
class gexf(Exporter):
|
class gexf(Exporter):
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
@ -143,18 +134,9 @@ class gexf(Exporter):
|
|||||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||||
env.id)):
|
env.id)):
|
||||||
with self.output('{}.gexf'.format(env.id), mode='wb') as f:
|
with self.output('{}.gexf'.format(env.id), mode='wb') as f:
|
||||||
|
network.dump_gexf(env.history_to_graph(), f)
|
||||||
self.dump_gexf(env, f)
|
self.dump_gexf(env, f)
|
||||||
|
|
||||||
def dump_gexf(self, env, f):
|
|
||||||
G = env.history_to_graph()
|
|
||||||
# Workaround for geometric models
|
|
||||||
# See soil/soil#4
|
|
||||||
for node in G.nodes():
|
|
||||||
if 'pos' in G.nodes[node]:
|
|
||||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
|
||||||
del (G.nodes[node]['pos'])
|
|
||||||
|
|
||||||
nx.write_gexf(G, f, version="1.2draft")
|
|
||||||
|
|
||||||
class dummy(Exporter):
|
class dummy(Exporter):
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@ -37,8 +39,10 @@ def from_config(cfg: config.NetConfig, dir_path: str = None):
|
|||||||
known_modules=['networkx.generators',])
|
known_modules=['networkx.generators',])
|
||||||
return method(**net_args)
|
return method(**net_args)
|
||||||
|
|
||||||
if isinstance(cfg.topology, basestring) or isinstance(cfg.topology, dict):
|
if isinstance(cfg.topology, config.Topology):
|
||||||
return nx.json_graph.node_link_graph(cfg.topology)
|
cfg = cfg.topology.dict()
|
||||||
|
if isinstance(cfg, str) or isinstance(cfg, dict):
|
||||||
|
return nx.json_graph.node_link_graph(cfg)
|
||||||
|
|
||||||
return nx.Graph()
|
return nx.Graph()
|
||||||
|
|
||||||
@ -57,9 +61,18 @@ def agent_to_node(G, agent_id, node_id=None, shuffle=False, random=random):
|
|||||||
for next_id, data in candidates:
|
for next_id, data in candidates:
|
||||||
if data.get('agent_id', None) is None:
|
if data.get('agent_id', None) is None:
|
||||||
node_id = next_id
|
node_id = next_id
|
||||||
data['agent_id'] = agent_id
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
raise ValueError(f"Not enough nodes in topology to assign one to agent {agent_id}")
|
raise ValueError(f"Not enough nodes in topology to assign one to agent {agent_id}")
|
||||||
|
G.nodes[node_id]['agent_id'] = agent_id
|
||||||
return node_id
|
return node_id
|
||||||
|
|
||||||
|
|
||||||
|
def dump_gexf(G, f):
|
||||||
|
for node in G.nodes():
|
||||||
|
if 'pos' in G.nodes[node]:
|
||||||
|
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||||
|
del (G.nodes[node]['pos'])
|
||||||
|
|
||||||
|
nx.write_gexf(G, f, version="1.2draft")
|
||||||
|
@ -7,6 +7,8 @@ import importlib
|
|||||||
from glob import glob
|
from glob import glob
|
||||||
from itertools import product, chain
|
from itertools import product, chain
|
||||||
|
|
||||||
|
from .config import Config
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
@ -120,22 +122,25 @@ def params_for_template(config):
|
|||||||
def load_files(*patterns, **kwargs):
|
def load_files(*patterns, **kwargs):
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
for i in glob(pattern, **kwargs):
|
for i in glob(pattern, **kwargs):
|
||||||
for config in load_file(i):
|
for cfg in load_file(i):
|
||||||
path = os.path.abspath(i)
|
path = os.path.abspath(i)
|
||||||
yield config, path
|
yield Config.from_raw(cfg), path
|
||||||
|
|
||||||
|
|
||||||
def load_config(config):
|
def load_config(cfg):
|
||||||
if isinstance(config, dict):
|
if isinstance(cfg, Config):
|
||||||
yield config, os.getcwd()
|
yield cfg, os.getcwd()
|
||||||
|
elif isinstance(cfg, dict):
|
||||||
|
yield Config.from_raw(cfg), os.getcwd()
|
||||||
else:
|
else:
|
||||||
yield from load_files(config)
|
yield from load_files(cfg)
|
||||||
|
|
||||||
|
|
||||||
builtins = importlib.import_module('builtins')
|
builtins = importlib.import_module('builtins')
|
||||||
|
|
||||||
KNOWN_MODULES = ['soil', ]
|
KNOWN_MODULES = ['soil', ]
|
||||||
|
|
||||||
|
|
||||||
def name(value, known_modules=KNOWN_MODULES):
|
def name(value, known_modules=KNOWN_MODULES):
|
||||||
'''Return a name that can be imported, to serialize/deserialize an object'''
|
'''Return a name that can be imported, to serialize/deserialize an object'''
|
||||||
if value is None:
|
if value is None:
|
||||||
@ -172,8 +177,22 @@ def serialize(v, known_modules=KNOWN_MODULES):
|
|||||||
return func(v), tname
|
return func(v), tname
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_dict(d, known_modules=KNOWN_MODULES):
|
||||||
|
d = dict(d)
|
||||||
|
for (k, v) in d.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
d[k] = serialize_dict(v, known_modules=known_modules)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
for ix in range(len(v)):
|
||||||
|
v[ix] = serialize_dict(v[ix], known_modules=known_modules)
|
||||||
|
elif isinstance(v, type):
|
||||||
|
d[k] = serialize(v, known_modules=known_modules)[1]
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
IS_CLASS = re.compile(r"<class '(.*)'>")
|
IS_CLASS = re.compile(r"<class '(.*)'>")
|
||||||
|
|
||||||
|
|
||||||
def deserializer(type_, known_modules=KNOWN_MODULES):
|
def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||||
if type(type_) != str: # Already deserialized
|
if type(type_) != str: # Already deserialized
|
||||||
return type_
|
return type_
|
||||||
|
@ -4,15 +4,17 @@ import importlib
|
|||||||
import sys
|
import sys
|
||||||
import yaml
|
import yaml
|
||||||
import traceback
|
import traceback
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
|
from textwrap import dedent
|
||||||
|
|
||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from typing import Union
|
from typing import Any, Dict, Union, Optional
|
||||||
|
|
||||||
|
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
from multiprocessing import Pool
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
@ -21,7 +23,6 @@ from .environment import Environment
|
|||||||
from .utils import logger, run_and_return_exceptions
|
from .utils import logger, run_and_return_exceptions
|
||||||
from .exporters import default
|
from .exporters import default
|
||||||
from .time import INFINITY
|
from .time import INFINITY
|
||||||
|
|
||||||
from .config import Config, convert_old
|
from .config import Config, convert_old
|
||||||
|
|
||||||
|
|
||||||
@ -36,7 +37,9 @@ class Simulation:
|
|||||||
|
|
||||||
kwargs: parameters to use to initialize a new configuration, if one has not been provided.
|
kwargs: parameters to use to initialize a new configuration, if one has not been provided.
|
||||||
"""
|
"""
|
||||||
|
version: str = '2'
|
||||||
name: str = 'Unnamed simulation'
|
name: str = 'Unnamed simulation'
|
||||||
|
description: Optional[str] = ''
|
||||||
group: str = None
|
group: str = None
|
||||||
model_class: Union[str, type] = 'soil.Environment'
|
model_class: Union[str, type] = 'soil.Environment'
|
||||||
model_params: dict = field(default_factory=dict)
|
model_params: dict = field(default_factory=dict)
|
||||||
@ -44,30 +47,37 @@ class Simulation:
|
|||||||
dir_path: str = field(default_factory=lambda: os.getcwd())
|
dir_path: str = field(default_factory=lambda: os.getcwd())
|
||||||
max_time: float = float('inf')
|
max_time: float = float('inf')
|
||||||
max_steps: int = -1
|
max_steps: int = -1
|
||||||
|
interval: int = 1
|
||||||
num_trials: int = 3
|
num_trials: int = 3
|
||||||
dry_run: bool = False
|
dry_run: bool = False
|
||||||
|
extra: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, env):
|
||||||
|
|
||||||
|
ignored = {k: v for k, v in env.items()
|
||||||
|
if k not in inspect.signature(cls).parameters}
|
||||||
|
|
||||||
|
kwargs = {k:v for k, v in env.items() if k not in ignored}
|
||||||
|
if ignored:
|
||||||
|
kwargs.setdefault('extra', {}).update(ignored)
|
||||||
|
if ignored:
|
||||||
|
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
|
||||||
|
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
def run_simulation(self, *args, **kwargs):
|
def run_simulation(self, *args, **kwargs):
|
||||||
return self.run(*args, **kwargs)
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
'''Run the simulation and return the list of resulting environments'''
|
'''Run the simulation and return the list of resulting environments'''
|
||||||
|
logger.info(dedent('''
|
||||||
|
Simulation:
|
||||||
|
---
|
||||||
|
''') +
|
||||||
|
self.to_yaml())
|
||||||
return list(self.run_gen(*args, **kwargs))
|
return list(self.run_gen(*args, **kwargs))
|
||||||
|
|
||||||
def _run_sync_or_async(self, parallel=False, **kwargs):
|
|
||||||
if parallel and not os.environ.get('SENPY_DEBUG', None):
|
|
||||||
p = Pool()
|
|
||||||
func = partial(run_and_return_exceptions, self.run_trial, **kwargs)
|
|
||||||
for i in p.imap_unordered(func, self.num_trials):
|
|
||||||
if isinstance(i, Exception):
|
|
||||||
logger.error('Trial failed:\n\t%s', i.message)
|
|
||||||
continue
|
|
||||||
yield i
|
|
||||||
else:
|
|
||||||
for i in range(self.num_trials):
|
|
||||||
yield self.run_trial(trial_id=i,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
def run_gen(self, parallel=False, dry_run=False,
|
def run_gen(self, parallel=False, dry_run=False,
|
||||||
exporters=[default, ], outdir=None, exporter_params={},
|
exporters=[default, ], outdir=None, exporter_params={},
|
||||||
log_level=None,
|
log_level=None,
|
||||||
@ -88,7 +98,9 @@ class Simulation:
|
|||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.sim_start()
|
exporter.sim_start()
|
||||||
|
|
||||||
for env in self._run_sync_or_async(parallel=parallel,
|
for env in utils.run_parallel(func=self.run_trial,
|
||||||
|
iterable=range(int(self.num_trials)),
|
||||||
|
parallel=parallel,
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
@ -103,14 +115,6 @@ class Simulation:
|
|||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.sim_end()
|
exporter.sim_end()
|
||||||
|
|
||||||
def run_model(self, until=None, *args, **kwargs):
|
|
||||||
until = until or float('inf')
|
|
||||||
|
|
||||||
while self.schedule.next_time < until:
|
|
||||||
self.step()
|
|
||||||
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
|
||||||
self.schedule.time = until
|
|
||||||
|
|
||||||
def get_env(self, trial_id=0, **kwargs):
|
def get_env(self, trial_id=0, **kwargs):
|
||||||
'''Create an environment for a trial of the simulation'''
|
'''Create an environment for a trial of the simulation'''
|
||||||
def deserialize_reporters(reporters):
|
def deserialize_reporters(reporters):
|
||||||
@ -132,29 +136,50 @@ class Simulation:
|
|||||||
model_reporters=model_reporters,
|
model_reporters=model_reporters,
|
||||||
**model_params)
|
**model_params)
|
||||||
|
|
||||||
def run_trial(self, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
def run_trial(self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts):
|
||||||
"""
|
"""
|
||||||
Run a single trial of the simulation
|
Run a single trial of the simulation
|
||||||
|
|
||||||
"""
|
"""
|
||||||
model = self.get_env(trial_id, **opts)
|
|
||||||
return self.run_model(model, trial_id=trial_id, until=until, log_level=log_level)
|
|
||||||
|
|
||||||
def run_model(self, model, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
|
||||||
trial_id = trial_id if trial_id is not None else current_time()
|
|
||||||
if log_level:
|
if log_level:
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
model = self.get_env(trial_id, **opts)
|
||||||
|
trial_id = trial_id if trial_id is not None else current_time()
|
||||||
|
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
|
return self.run_model(model=model, trial_id=trial_id, until=until, log_level=log_level)
|
||||||
|
|
||||||
|
def run_model(self, model, until=None, **opts):
|
||||||
# Set-up trial environment and graph
|
# Set-up trial environment and graph
|
||||||
until = until or self.max_time
|
until = float(until or self.max_time or 'inf')
|
||||||
|
|
||||||
# Set up agents on nodes
|
# Set up agents on nodes
|
||||||
is_done = lambda: False
|
def is_done():
|
||||||
if self.max_time and hasattr(self.schedule, 'time'):
|
return False
|
||||||
is_done = lambda x: is_done() or self.schedule.time >= self.max_time
|
|
||||||
if self.max_steps and hasattr(self.schedule, 'time'):
|
if until and hasattr(model.schedule, 'time'):
|
||||||
is_done = lambda: is_done() or self.schedule.steps >= self.max_steps
|
prev = is_done
|
||||||
|
|
||||||
|
def is_done():
|
||||||
|
return prev() or model.schedule.time >= until
|
||||||
|
|
||||||
|
if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, 'steps'):
|
||||||
|
prev_steps = is_done
|
||||||
|
|
||||||
|
def is_done():
|
||||||
|
return prev_steps() or model.schedule.steps >= self.max_steps
|
||||||
|
|
||||||
|
newline = '\n'
|
||||||
|
logger.info(dedent(f'''
|
||||||
|
Model stats:
|
||||||
|
Agents (total: { model.schedule.get_agent_count() }):
|
||||||
|
- { (newline + ' - ').join(str(a) for a in model.schedule.agents) }'''
|
||||||
|
f'''
|
||||||
|
|
||||||
|
Topologies (size):
|
||||||
|
- { dict( (k, len(v)) for (k, v) in model.topologies.items()) }
|
||||||
|
''' if getattr(model, "topologies", None) else ''
|
||||||
|
))
|
||||||
|
|
||||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
|
||||||
while not is_done():
|
while not is_done():
|
||||||
utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}')
|
utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}')
|
||||||
model.step()
|
model.step()
|
||||||
@ -162,26 +187,25 @@ class Simulation:
|
|||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
d = asdict(self)
|
d = asdict(self)
|
||||||
d['model_class'] = serialization.serialize(d['model_class'])[0]
|
if not isinstance(d['model_class'], str):
|
||||||
d['model_params'] = serialization.serialize(d['model_params'])[0]
|
d['model_class'] = serialization.name(d['model_class'])
|
||||||
|
d['model_params'] = serialization.serialize_dict(d['model_params'])
|
||||||
d['dir_path'] = str(d['dir_path'])
|
d['dir_path'] = str(d['dir_path'])
|
||||||
|
d['version'] = '2'
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def to_yaml(self):
|
def to_yaml(self):
|
||||||
return yaml.dump(self.asdict())
|
return yaml.dump(self.to_dict())
|
||||||
|
|
||||||
|
|
||||||
def iter_from_config(config):
|
def iter_from_config(*cfgs):
|
||||||
|
for config in cfgs:
|
||||||
configs = list(serialization.load_config(config))
|
configs = list(serialization.load_config(config))
|
||||||
for config, path in configs:
|
for config, path in configs:
|
||||||
d = dict(config)
|
d = dict(config)
|
||||||
if 'dir_path' not in d:
|
if 'dir_path' not in d:
|
||||||
d['dir_path'] = os.path.dirname(path)
|
d['dir_path'] = os.path.dirname(path)
|
||||||
if d.get('version', '2') == '1' or 'agents' in d or 'network_agents' in d or 'environment_agents' in d:
|
yield Simulation.from_dict(d)
|
||||||
d = convert_old(d)
|
|
||||||
d.pop('version', None)
|
|
||||||
yield Simulation(**d)
|
|
||||||
|
|
||||||
|
|
||||||
def from_config(conf_or_path):
|
def from_config(conf_or_path):
|
||||||
@ -192,6 +216,6 @@ def from_config(conf_or_path):
|
|||||||
|
|
||||||
|
|
||||||
def run_from_config(*configs, **kwargs):
|
def run_from_config(*configs, **kwargs):
|
||||||
for sim in iter_from_config(configs):
|
for sim in iter_from_config(*configs):
|
||||||
logger.info(f"Using config(s): {sim.id}")
|
logger.info(f"Using config(s): {sim.name}")
|
||||||
sim.run_simulation(**kwargs)
|
sim.run_simulation(**kwargs)
|
||||||
|
28
soil/time.py
28
soil/time.py
@ -1,6 +1,6 @@
|
|||||||
from mesa.time import BaseScheduler
|
from mesa.time import BaseScheduler
|
||||||
from queue import Empty
|
from queue import Empty
|
||||||
from heapq import heappush, heappop
|
from heapq import heappush, heappop, heapify
|
||||||
import math
|
import math
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
from mesa import Agent as MesaAgent
|
from mesa import Agent as MesaAgent
|
||||||
@ -17,6 +17,7 @@ class When:
|
|||||||
def abs(self, time):
|
def abs(self, time):
|
||||||
return self._time
|
return self._time
|
||||||
|
|
||||||
|
|
||||||
NEVER = When(INFINITY)
|
NEVER = When(INFINITY)
|
||||||
|
|
||||||
|
|
||||||
@ -38,13 +39,21 @@ class TimedActivation(BaseScheduler):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self._next = {}
|
||||||
self._queue = []
|
self._queue = []
|
||||||
self.next_time = 0
|
self.next_time = 0
|
||||||
self.logger = logger.getChild(f'time_{ self.model }')
|
self.logger = logger.getChild(f'time_{ self.model }')
|
||||||
|
|
||||||
def add(self, agent: MesaAgent):
|
def add(self, agent: MesaAgent, when=None):
|
||||||
if agent.unique_id not in self._agents:
|
if when is None:
|
||||||
heappush(self._queue, (self.time, agent.unique_id))
|
when = self.time
|
||||||
|
if agent.unique_id in self._agents:
|
||||||
|
self._queue.remove((self._next[agent.unique_id], agent.unique_id))
|
||||||
|
del self._agents[agent.unique_id]
|
||||||
|
heapify(self._queue)
|
||||||
|
|
||||||
|
heappush(self._queue, (when, agent.unique_id))
|
||||||
|
self._next[agent.unique_id] = when
|
||||||
super().add(agent)
|
super().add(agent)
|
||||||
|
|
||||||
def step(self) -> None:
|
def step(self) -> None:
|
||||||
@ -64,11 +73,18 @@ class TimedActivation(BaseScheduler):
|
|||||||
(when, agent_id) = heappop(self._queue)
|
(when, agent_id) = heappop(self._queue)
|
||||||
self.logger.debug(f'Stepping agent {agent_id}')
|
self.logger.debug(f'Stepping agent {agent_id}')
|
||||||
|
|
||||||
returned = self._agents[agent_id].step()
|
agent = self._agents[agent_id]
|
||||||
|
returned = agent.step()
|
||||||
|
|
||||||
|
if not agent.alive:
|
||||||
|
self.remove(agent)
|
||||||
|
continue
|
||||||
|
|
||||||
when = (returned or Delta(1)).abs(self.time)
|
when = (returned or Delta(1)).abs(self.time)
|
||||||
if when < self.time:
|
if when < self.time:
|
||||||
raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time))
|
raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time))
|
||||||
|
|
||||||
|
self._next[agent_id] = when
|
||||||
heappush(self._queue, (when, agent_id))
|
heappush(self._queue, (when, agent_id))
|
||||||
|
|
||||||
self.steps += 1
|
self.steps += 1
|
||||||
@ -77,7 +93,7 @@ class TimedActivation(BaseScheduler):
|
|||||||
self.time = INFINITY
|
self.time = INFINITY
|
||||||
self.next_time = INFINITY
|
self.next_time = INFINITY
|
||||||
self.model.running = False
|
self.model.running = False
|
||||||
return
|
return self.time
|
||||||
|
|
||||||
self.next_time = self._queue[0][0]
|
self.next_time = self._queue[0][0]
|
||||||
self.logger.debug(f'Next step: {self.next_time}')
|
self.logger.debug(f'Next step: {self.next_time}')
|
||||||
|
@ -3,13 +3,27 @@ from time import time as current_time, strftime, gmtime, localtime
|
|||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
from shutil import copyfile
|
from shutil import copyfile
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
logger = logging.getLogger('soil')
|
logger = logging.getLogger('soil')
|
||||||
# logging.basicConfig()
|
logger.setLevel(logging.INFO)
|
||||||
# logger.setLevel(logging.INFO)
|
|
||||||
|
timeformat = "%H:%M:%S"
|
||||||
|
|
||||||
|
if os.environ.get('SOIL_VERBOSE', ''):
|
||||||
|
logformat = "[%(levelname)-5.5s][%(asctime)s][%(name)s]: %(message)s"
|
||||||
|
else:
|
||||||
|
logformat = "[%(levelname)-5.5s][%(asctime)s] %(message)s"
|
||||||
|
|
||||||
|
logFormatter = logging.Formatter(logformat, timeformat)
|
||||||
|
|
||||||
|
consoleHandler = logging.StreamHandler()
|
||||||
|
consoleHandler.setFormatter(logFormatter)
|
||||||
|
logger.addHandler(consoleHandler)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@ -27,8 +41,6 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
|||||||
to_object.end = end
|
to_object.end = end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def safe_open(path, mode='r', backup=True, **kwargs):
|
def safe_open(path, mode='r', backup=True, **kwargs):
|
||||||
outdir = os.path.dirname(path)
|
outdir = os.path.dirname(path)
|
||||||
if outdir and not os.path.exists(outdir):
|
if outdir and not os.path.exists(outdir):
|
||||||
@ -92,7 +104,7 @@ def unflatten_dict(d):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def run_and_return_exceptions(self, func, *args, **kwargs):
|
def run_and_return_exceptions(func, *args, **kwargs):
|
||||||
'''
|
'''
|
||||||
A wrapper for run_trial that catches exceptions and returns them.
|
A wrapper for run_trial that catches exceptions and returns them.
|
||||||
It is meant for async simulations.
|
It is meant for async simulations.
|
||||||
@ -104,3 +116,18 @@ def run_and_return_exceptions(self, func, *args, **kwargs):
|
|||||||
ex = ex.__cause__
|
ex = ex.__cause__
|
||||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||||
return ex
|
return ex
|
||||||
|
|
||||||
|
|
||||||
|
def run_parallel(func, iterable, parallel=False, **kwargs):
|
||||||
|
if parallel and not os.environ.get('SOIL_DEBUG', None):
|
||||||
|
p = Pool()
|
||||||
|
wrapped_func = partial(run_and_return_exceptions,
|
||||||
|
func, **kwargs)
|
||||||
|
for i in p.imap_unordered(wrapped_func, iterable):
|
||||||
|
if isinstance(i, Exception):
|
||||||
|
logger.error('Trial failed:\n\t%s', i.message)
|
||||||
|
continue
|
||||||
|
yield i
|
||||||
|
else:
|
||||||
|
for i in iterable:
|
||||||
|
yield func(i, **kwargs)
|
||||||
|
@ -1,49 +1,50 @@
|
|||||||
---
|
---
|
||||||
version: '2'
|
version: '2'
|
||||||
general:
|
name: simple
|
||||||
id: simple
|
group: tests
|
||||||
group: tests
|
dir_path: "/tmp/"
|
||||||
dir_path: "/tmp/"
|
num_trials: 3
|
||||||
num_trials: 3
|
max_time: 100
|
||||||
max_time: 100
|
interval: 1
|
||||||
interval: 1
|
seed: "CompleteSeed!"
|
||||||
seed: "CompleteSeed!"
|
model_class: Environment
|
||||||
topologies:
|
model_params:
|
||||||
|
topologies:
|
||||||
default:
|
default:
|
||||||
params:
|
params:
|
||||||
generator: complete_graph
|
generator: complete_graph
|
||||||
n: 10
|
n: 4
|
||||||
agents:
|
agents:
|
||||||
default:
|
|
||||||
agent_class: CounterModel
|
agent_class: CounterModel
|
||||||
state:
|
state:
|
||||||
|
group: network
|
||||||
times: 1
|
times: 1
|
||||||
network:
|
|
||||||
topology: 'default'
|
topology: 'default'
|
||||||
distribution:
|
distribution:
|
||||||
- agent_class: CounterModel
|
- agent_class: CounterModel
|
||||||
weight: 0.4
|
weight: 0.25
|
||||||
state:
|
state:
|
||||||
state_id: 0
|
state_id: 0
|
||||||
|
times: 1
|
||||||
- agent_class: AggregatedCounter
|
- agent_class: AggregatedCounter
|
||||||
weight: 0.6
|
weight: 0.5
|
||||||
override:
|
|
||||||
- filter:
|
|
||||||
node_id: 0
|
|
||||||
state:
|
state:
|
||||||
name: 'The first node'
|
times: 2
|
||||||
|
override:
|
||||||
- filter:
|
- filter:
|
||||||
node_id: 1
|
node_id: 1
|
||||||
state:
|
state:
|
||||||
name: 'The second node'
|
name: 'Node 1'
|
||||||
|
- filter:
|
||||||
environment:
|
node_id: 2
|
||||||
fixed:
|
|
||||||
- name: 'Environment Agent 1'
|
|
||||||
agent_class: CounterModel
|
|
||||||
state:
|
state:
|
||||||
|
name: 'Node 2'
|
||||||
|
fixed:
|
||||||
|
- agent_class: BaseAgent
|
||||||
|
hidden: true
|
||||||
|
topology: null
|
||||||
|
state:
|
||||||
|
name: 'Environment Agent 1'
|
||||||
times: 10
|
times: 10
|
||||||
environment:
|
group: environment
|
||||||
environment_class: Environment
|
|
||||||
params:
|
|
||||||
am_i_complete: true
|
am_i_complete: true
|
||||||
|
@ -8,17 +8,20 @@ interval: 1
|
|||||||
seed: "CompleteSeed!"
|
seed: "CompleteSeed!"
|
||||||
network_params:
|
network_params:
|
||||||
generator: complete_graph
|
generator: complete_graph
|
||||||
n: 10
|
n: 4
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_class: CounterModel
|
- agent_class: CounterModel
|
||||||
weight: 0.4
|
weight: 0.25
|
||||||
state:
|
state:
|
||||||
state_id: 0
|
state_id: 0
|
||||||
|
times: 1
|
||||||
- agent_class: AggregatedCounter
|
- agent_class: AggregatedCounter
|
||||||
weight: 0.6
|
weight: 0.5
|
||||||
|
state:
|
||||||
|
times: 2
|
||||||
environment_agents:
|
environment_agents:
|
||||||
- agent_id: 'Environment Agent 1'
|
- agent_id: 'Environment Agent 1'
|
||||||
agent_class: CounterModel
|
agent_class: BaseAgent
|
||||||
state:
|
state:
|
||||||
times: 10
|
times: 10
|
||||||
environment_class: Environment
|
environment_class: Environment
|
||||||
@ -28,5 +31,7 @@ agent_class: CounterModel
|
|||||||
default_state:
|
default_state:
|
||||||
times: 1
|
times: 1
|
||||||
states:
|
states:
|
||||||
- name: 'The first node'
|
1:
|
||||||
- name: 'The second node'
|
name: 'Node 1'
|
||||||
|
2:
|
||||||
|
name: 'Node 2'
|
||||||
|
@ -8,7 +8,7 @@ class Dead(agents.FSM):
|
|||||||
@agents.default_state
|
@agents.default_state
|
||||||
@agents.state
|
@agents.state
|
||||||
def only(self):
|
def only(self):
|
||||||
self.die()
|
return self.die()
|
||||||
|
|
||||||
class TestMain(TestCase):
|
class TestMain(TestCase):
|
||||||
def test_die_raises_exception(self):
|
def test_die_raises_exception(self):
|
||||||
@ -19,4 +19,6 @@ class TestMain(TestCase):
|
|||||||
|
|
||||||
def test_die_returns_infinity(self):
|
def test_die_returns_infinity(self):
|
||||||
d = Dead(unique_id=0, model=environment.Environment())
|
d = Dead(unique_id=0, model=environment.Environment())
|
||||||
assert d.step().abs(0) == stime.INFINITY
|
ret = d.step().abs(0)
|
||||||
|
print(ret, 'next')
|
||||||
|
assert ret == stime.INFINITY
|
||||||
|
@ -1,91 +0,0 @@
|
|||||||
from unittest import TestCase
|
|
||||||
|
|
||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
import yaml
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from os.path import join
|
|
||||||
from soil import simulation, analysis, agents
|
|
||||||
|
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
|
||||||
|
|
||||||
|
|
||||||
class Ping(agents.FSM):
|
|
||||||
|
|
||||||
defaults = {
|
|
||||||
'count': 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
@agents.default_state
|
|
||||||
@agents.state
|
|
||||||
def even(self):
|
|
||||||
self.debug(f'Even {self["count"]}')
|
|
||||||
self['count'] += 1
|
|
||||||
return self.odd
|
|
||||||
|
|
||||||
@agents.state
|
|
||||||
def odd(self):
|
|
||||||
self.debug(f'Odd {self["count"]}')
|
|
||||||
self['count'] += 1
|
|
||||||
return self.even
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnalysis(TestCase):
|
|
||||||
|
|
||||||
# Code to generate a simple sqlite history
|
|
||||||
def setUp(self):
|
|
||||||
"""
|
|
||||||
The initial states should be applied to the agent and the
|
|
||||||
agent should be able to update its state."""
|
|
||||||
config = {
|
|
||||||
'name': 'analysis',
|
|
||||||
'seed': 'seed',
|
|
||||||
'network_params': {
|
|
||||||
'generator': 'complete_graph',
|
|
||||||
'n': 2
|
|
||||||
},
|
|
||||||
'agent_class': Ping,
|
|
||||||
'states': [{'interval': 1}, {'interval': 2}],
|
|
||||||
'max_time': 30,
|
|
||||||
'num_trials': 1,
|
|
||||||
'history': True,
|
|
||||||
'environment_params': {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s = simulation.from_config(config)
|
|
||||||
self.env = s.run_simulation(dry_run=True)[0]
|
|
||||||
|
|
||||||
def test_saved(self):
|
|
||||||
env = self.env
|
|
||||||
assert env.get_agent(0)['count', 0] == 1
|
|
||||||
assert env.get_agent(0)['count', 29] == 30
|
|
||||||
assert env.get_agent(1)['count', 0] == 1
|
|
||||||
assert env.get_agent(1)['count', 29] == 15
|
|
||||||
assert env['env', 29, None]['SEED'] == env['env', 29, 'SEED']
|
|
||||||
|
|
||||||
def test_count(self):
|
|
||||||
env = self.env
|
|
||||||
df = analysis.read_sql(env._history.db_path)
|
|
||||||
res = analysis.get_count(df, 'SEED', 'state_id')
|
|
||||||
assert res['SEED'][self.env['SEED']].iloc[0] == 1
|
|
||||||
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
|
|
||||||
assert res['state_id']['odd'].iloc[0] == 2
|
|
||||||
assert res['state_id']['even'].iloc[0] == 0
|
|
||||||
assert res['state_id']['odd'].iloc[-1] == 1
|
|
||||||
assert res['state_id']['even'].iloc[-1] == 1
|
|
||||||
|
|
||||||
def test_value(self):
|
|
||||||
env = self.env
|
|
||||||
df = analysis.read_sql(env._history.db_path)
|
|
||||||
res_sum = analysis.get_value(df, 'count')
|
|
||||||
|
|
||||||
assert res_sum['count'].iloc[0] == 2
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
res_mean = analysis.get_value(df, 'count', aggfunc=np.mean)
|
|
||||||
assert res_mean['count'].iloc[15] == (16+8)/2
|
|
||||||
|
|
||||||
res_total = analysis.get_majority(df)
|
|
||||||
res_total['SEED'].iloc[0] == self.env['SEED']
|
|
@ -29,7 +29,7 @@ class TestConfig(TestCase):
|
|||||||
expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||||
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
||||||
converted_defaults = config.convert_old(old, strict=False)
|
converted_defaults = config.convert_old(old, strict=False)
|
||||||
converted = converted_defaults.dict(skip_defaults=True)
|
converted = converted_defaults.dict(exclude_unset=True)
|
||||||
|
|
||||||
isequal(converted, expected)
|
isequal(converted, expected)
|
||||||
|
|
||||||
@ -40,10 +40,10 @@ class TestConfig(TestCase):
|
|||||||
"""
|
"""
|
||||||
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
init_config = copy.copy(s.config)
|
init_config = copy.copy(s.to_dict())
|
||||||
|
|
||||||
s.run_simulation(dry_run=True)
|
s.run_simulation(dry_run=True)
|
||||||
nconfig = s.config
|
nconfig = s.to_dict()
|
||||||
# del nconfig['to
|
# del nconfig['to
|
||||||
isequal(init_config, nconfig)
|
isequal(init_config, nconfig)
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ class TestConfig(TestCase):
|
|||||||
Simple configuration that tests that the graph is loaded, and that
|
Simple configuration that tests that the graph is loaded, and that
|
||||||
network agents are initialized properly.
|
network agents are initialized properly.
|
||||||
"""
|
"""
|
||||||
config = {
|
cfg = {
|
||||||
'name': 'CounterAgent',
|
'name': 'CounterAgent',
|
||||||
'network_params': {
|
'network_params': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
@ -74,12 +74,14 @@ class TestConfig(TestCase):
|
|||||||
'environment_params': {
|
'environment_params': {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_old_config(config)
|
conf = config.convert_old(cfg)
|
||||||
|
s = simulation.from_config(conf)
|
||||||
|
|
||||||
env = s.get_env()
|
env = s.get_env()
|
||||||
assert len(env.topologies['default'].nodes) == 2
|
assert len(env.topologies['default'].nodes) == 2
|
||||||
assert len(env.topologies['default'].edges) == 1
|
assert len(env.topologies['default'].edges) == 1
|
||||||
assert len(env.agents) == 2
|
assert len(env.agents) == 2
|
||||||
assert env.agents[0].topology == env.topologies['default']
|
assert env.agents[0].G == env.topologies['default']
|
||||||
|
|
||||||
def test_agents_from_config(self):
|
def test_agents_from_config(self):
|
||||||
'''We test that the known complete configuration produces
|
'''We test that the known complete configuration produces
|
||||||
@ -87,13 +89,10 @@ class TestConfig(TestCase):
|
|||||||
cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||||
s = simulation.from_config(cfg)
|
s = simulation.from_config(cfg)
|
||||||
env = s.get_env()
|
env = s.get_env()
|
||||||
assert len(env.topologies['default'].nodes) == 10
|
assert len(env.topologies['default'].nodes) == 4
|
||||||
assert len(env.agents(group='network')) == 10
|
assert len(env.agents(group='network')) == 4
|
||||||
assert len(env.agents(group='environment')) == 1
|
assert len(env.agents(group='environment')) == 1
|
||||||
|
|
||||||
assert sum(1 for a in env.agents(group='network', agent_class=agents.CounterModel)) == 4
|
|
||||||
assert sum(1 for a in env.agents(group='network', agent_class=agents.AggregatedCounter)) == 6
|
|
||||||
|
|
||||||
def test_yaml(self):
|
def test_yaml(self):
|
||||||
"""
|
"""
|
||||||
The YAML version of a newly created configuration should be equivalent
|
The YAML version of a newly created configuration should be equivalent
|
||||||
|
@ -2,7 +2,7 @@ from unittest import TestCase
|
|||||||
import os
|
import os
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from soil import serialization, simulation
|
from soil import serialization, simulation, config
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
EXAMPLES = join(ROOT, '..', 'examples')
|
||||||
@ -14,36 +14,37 @@ class TestExamples(TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def make_example_test(path, config):
|
def make_example_test(path, cfg):
|
||||||
def wrapped(self):
|
def wrapped(self):
|
||||||
root = os.getcwd()
|
root = os.getcwd()
|
||||||
for s in simulation.all_from_config(path):
|
for s in simulation.iter_from_config(cfg):
|
||||||
iterations = s.config.general.max_time * s.config.general.num_trials
|
iterations = s.max_steps * s.num_trials
|
||||||
if iterations > 1000:
|
if iterations < 0 or iterations > 1000:
|
||||||
s.config.general.max_time = 100
|
s.max_steps = 100
|
||||||
s.config.general.num_trials = 1
|
s.num_trials = 1
|
||||||
if config.get('skip_test', False) and not FORCE_TESTS:
|
assert isinstance(cfg, config.Config)
|
||||||
|
if getattr(cfg, 'skip_test', False) and not FORCE_TESTS:
|
||||||
self.skipTest('Example ignored.')
|
self.skipTest('Example ignored.')
|
||||||
envs = s.run_simulation(dry_run=True)
|
envs = s.run_simulation(dry_run=True)
|
||||||
assert envs
|
assert envs
|
||||||
for env in envs:
|
for env in envs:
|
||||||
assert env
|
assert env
|
||||||
try:
|
try:
|
||||||
n = config['network_params']['n']
|
n = cfg.model_params['network_params']['n']
|
||||||
assert len(list(env.network_agents)) == n
|
assert len(list(env.network_agents)) == n
|
||||||
assert env.now > 0 # It has run
|
|
||||||
assert env.now <= config['max_time'] # But not further than allowed
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
assert env.schedule.steps > 0 # It has run
|
||||||
|
assert env.schedule.steps <= s.max_steps # But not further than allowed
|
||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
def add_example_tests():
|
def add_example_tests():
|
||||||
for config, path in serialization.load_files(
|
for cfg, path in serialization.load_files(
|
||||||
join(EXAMPLES, '*', '*.yml'),
|
join(EXAMPLES, '*', '*.yml'),
|
||||||
join(EXAMPLES, '*.yml'),
|
join(EXAMPLES, '*.yml'),
|
||||||
):
|
):
|
||||||
p = make_example_test(path=path, config=config)
|
p = make_example_test(path=path, cfg=config.Config.from_raw(cfg))
|
||||||
fname = os.path.basename(path)
|
fname = os.path.basename(path)
|
||||||
p.__name__ = 'test_example_file_%s' % fname
|
p.__name__ = 'test_example_file_%s' % fname
|
||||||
p.__doc__ = '%s should be a valid configuration' % fname
|
p.__doc__ = '%s should be a valid configuration' % fname
|
||||||
|
@ -6,6 +6,8 @@ import shutil
|
|||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
from soil import exporters
|
from soil import exporters
|
||||||
from soil import simulation
|
from soil import simulation
|
||||||
|
from soil import agents
|
||||||
|
|
||||||
|
|
||||||
class Dummy(exporters.Exporter):
|
class Dummy(exporters.Exporter):
|
||||||
started = False
|
started = False
|
||||||
@ -33,28 +35,36 @@ class Dummy(exporters.Exporter):
|
|||||||
|
|
||||||
class Exporters(TestCase):
|
class Exporters(TestCase):
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
|
# We need to add at least one agent to make sure the scheduler
|
||||||
|
# ticks every step
|
||||||
|
num_trials = 5
|
||||||
|
max_time = 2
|
||||||
config = {
|
config = {
|
||||||
'name': 'exporter_sim',
|
'name': 'exporter_sim',
|
||||||
'network_params': {},
|
'model_params': {
|
||||||
'agent_class': 'CounterModel',
|
'agents': [{
|
||||||
'max_time': 2,
|
'agent_class': agents.BaseAgent
|
||||||
'num_trials': 5,
|
}]
|
||||||
'environment_params': {}
|
},
|
||||||
|
'max_time': max_time,
|
||||||
|
'num_trials': num_trials,
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
|
|
||||||
for env in s.run_simulation(exporters=[Dummy], dry_run=True):
|
for env in s.run_simulation(exporters=[Dummy], dry_run=True):
|
||||||
assert env.now <= 2
|
assert len(env.agents) == 1
|
||||||
|
assert env.now == max_time
|
||||||
|
|
||||||
assert Dummy.started
|
assert Dummy.started
|
||||||
assert Dummy.ended
|
assert Dummy.ended
|
||||||
assert Dummy.called_start == 1
|
assert Dummy.called_start == 1
|
||||||
assert Dummy.called_end == 1
|
assert Dummy.called_end == 1
|
||||||
assert Dummy.called_trial == 5
|
assert Dummy.called_trial == num_trials
|
||||||
assert Dummy.trials == 5
|
assert Dummy.trials == num_trials
|
||||||
assert Dummy.total_time == 2*5
|
assert Dummy.total_time == max_time * num_trials
|
||||||
|
|
||||||
def test_writing(self):
|
def test_writing(self):
|
||||||
'''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
|
'''Try to write CSV, sqlite and YAML (without dry_run)'''
|
||||||
n_trials = 5
|
n_trials = 5
|
||||||
config = {
|
config = {
|
||||||
'name': 'exporter_sim',
|
'name': 'exporter_sim',
|
||||||
@ -74,7 +84,6 @@ class Exporters(TestCase):
|
|||||||
envs = s.run_simulation(exporters=[
|
envs = s.run_simulation(exporters=[
|
||||||
exporters.default,
|
exporters.default,
|
||||||
exporters.csv,
|
exporters.csv,
|
||||||
exporters.gexf,
|
|
||||||
],
|
],
|
||||||
dry_run=False,
|
dry_run=False,
|
||||||
outdir=tmpdir,
|
outdir=tmpdir,
|
||||||
@ -88,11 +97,7 @@ class Exporters(TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for e in envs:
|
for e in envs:
|
||||||
with open(os.path.join(simdir, '{}.gexf'.format(e.name))) as f:
|
with open(os.path.join(simdir, '{}.env.csv'.format(e.id))) as f:
|
||||||
result = f.read()
|
|
||||||
assert result
|
|
||||||
|
|
||||||
with open(os.path.join(simdir, '{}.csv'.format(e.name))) as f:
|
|
||||||
result = f.read()
|
result = f.read()
|
||||||
assert result
|
assert result
|
||||||
finally:
|
finally:
|
||||||
|
@ -1,128 +0,0 @@
|
|||||||
from unittest import TestCase
|
|
||||||
|
|
||||||
import os
|
|
||||||
import io
|
|
||||||
import yaml
|
|
||||||
import copy
|
|
||||||
import pickle
|
|
||||||
import networkx as nx
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
from os.path import join
|
|
||||||
from soil import (simulation, Environment, agents, serialization,
|
|
||||||
utils)
|
|
||||||
from soil.time import Delta
|
|
||||||
from tsih import NoHistory, History
|
|
||||||
|
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
|
||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
|
||||||
|
|
||||||
|
|
||||||
class CustomAgent(agents.FSM):
|
|
||||||
@agents.default_state
|
|
||||||
@agents.state
|
|
||||||
def normal(self):
|
|
||||||
self.neighbors = self.count_agents(state_id='normal',
|
|
||||||
limit_neighbors=True)
|
|
||||||
@agents.state
|
|
||||||
def unreachable(self):
|
|
||||||
return
|
|
||||||
|
|
||||||
class TestHistory(TestCase):
|
|
||||||
|
|
||||||
def test_counter_agent_history(self):
|
|
||||||
"""
|
|
||||||
The evolution of the state should be recorded in the logging agent
|
|
||||||
"""
|
|
||||||
config = {
|
|
||||||
'name': 'CounterAgent',
|
|
||||||
'network_params': {
|
|
||||||
'path': join(ROOT, 'test.gexf')
|
|
||||||
},
|
|
||||||
'network_agents': [{
|
|
||||||
'agent_class': 'AggregatedCounter',
|
|
||||||
'weight': 1,
|
|
||||||
'state': {'state_id': 0}
|
|
||||||
|
|
||||||
}],
|
|
||||||
'max_time': 10,
|
|
||||||
'environment_params': {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s = simulation.from_config(config)
|
|
||||||
env = s.run_simulation(dry_run=True)[0]
|
|
||||||
for agent in env.network_agents:
|
|
||||||
last = 0
|
|
||||||
assert len(agent[None, None]) == 11
|
|
||||||
for step, total in sorted(agent['total', None]):
|
|
||||||
assert total == last + 2
|
|
||||||
last = total
|
|
||||||
|
|
||||||
def test_row_conversion(self):
|
|
||||||
env = Environment(history=True)
|
|
||||||
env['test'] = 'test_value'
|
|
||||||
|
|
||||||
res = list(env.history_to_tuples())
|
|
||||||
assert len(res) == len(env.environment_params)
|
|
||||||
|
|
||||||
env.schedule.time = 1
|
|
||||||
env['test'] = 'second_value'
|
|
||||||
res = list(env.history_to_tuples())
|
|
||||||
|
|
||||||
assert env['env', 0, 'test' ] == 'test_value'
|
|
||||||
assert env['env', 1, 'test' ] == 'second_value'
|
|
||||||
|
|
||||||
def test_nohistory(self):
|
|
||||||
'''
|
|
||||||
Make sure that no history(/sqlite) is used by default
|
|
||||||
'''
|
|
||||||
env = Environment(topology=nx.Graph(), network_agents=[])
|
|
||||||
assert isinstance(env._history, NoHistory)
|
|
||||||
|
|
||||||
def test_save_graph_history(self):
|
|
||||||
'''
|
|
||||||
The history_to_graph method should return a valid networkx graph.
|
|
||||||
|
|
||||||
The state of the agent should be encoded as intervals in the nx graph.
|
|
||||||
'''
|
|
||||||
G = nx.cycle_graph(5)
|
|
||||||
distribution = agents.calculate_distribution(None, agents.BaseAgent)
|
|
||||||
env = Environment(topology=G, network_agents=distribution, history=True)
|
|
||||||
env[0, 0, 'testvalue'] = 'start'
|
|
||||||
env[0, 10, 'testvalue'] = 'finish'
|
|
||||||
nG = env.history_to_graph()
|
|
||||||
values = nG.nodes[0]['attr_testvalue']
|
|
||||||
assert ('start', 0, 10) in values
|
|
||||||
assert ('finish', 10, None) in values
|
|
||||||
|
|
||||||
def test_save_graph_nohistory(self):
|
|
||||||
'''
|
|
||||||
The history_to_graph method should return a valid networkx graph.
|
|
||||||
|
|
||||||
When NoHistory is used, only the last known value is known
|
|
||||||
'''
|
|
||||||
G = nx.cycle_graph(5)
|
|
||||||
distribution = agents.calculate_distribution(None, agents.BaseAgent)
|
|
||||||
env = Environment(topology=G, network_agents=distribution, history=False)
|
|
||||||
env.get_agent(0)['testvalue'] = 'start'
|
|
||||||
env.schedule.time = 10
|
|
||||||
env.get_agent(0)['testvalue'] = 'finish'
|
|
||||||
nG = env.history_to_graph()
|
|
||||||
values = nG.nodes[0]['attr_testvalue']
|
|
||||||
assert ('start', 0, None) not in values
|
|
||||||
assert ('finish', 10, None) in values
|
|
||||||
|
|
||||||
def test_pickle_agent_environment(self):
|
|
||||||
env = Environment(name='Test', history=True)
|
|
||||||
a = agents.BaseAgent(model=env, unique_id=25)
|
|
||||||
|
|
||||||
a['key'] = 'test'
|
|
||||||
|
|
||||||
pickled = pickle.dumps(a)
|
|
||||||
recovered = pickle.loads(pickled)
|
|
||||||
|
|
||||||
assert recovered.env.name == 'Test'
|
|
||||||
assert list(recovered.env._history.to_tuples())
|
|
||||||
assert recovered['key', 0] == 'test'
|
|
||||||
assert recovered['key'] == 'test'
|
|
@ -24,6 +24,7 @@ class CustomAgent(agents.FSM, agents.NetworkAgent):
|
|||||||
def unreachable(self):
|
def unreachable(self):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
class TestMain(TestCase):
|
class TestMain(TestCase):
|
||||||
|
|
||||||
def test_empty_simulation(self):
|
def test_empty_simulation(self):
|
||||||
@ -79,20 +80,16 @@ class TestMain(TestCase):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'agents': {
|
'agents': {
|
||||||
'default': {
|
|
||||||
'agent_class': 'CounterModel',
|
'agent_class': 'CounterModel',
|
||||||
},
|
|
||||||
'counters': {
|
|
||||||
'topology': 'default',
|
'topology': 'default',
|
||||||
'fixed': [{'state': {'times': 10}}, {'state': {'times': 20}}],
|
'fixed': [{'state': {'times': 10}}, {'state': {'times': 20}}],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.get_env()
|
env = s.get_env()
|
||||||
assert isinstance(env.agents[0], agents.CounterModel)
|
assert isinstance(env.agents[0], agents.CounterModel)
|
||||||
assert env.agents[0].topology == env.topologies['default']
|
assert env.agents[0].G == env.topologies['default']
|
||||||
assert env.agents[0]['times'] == 10
|
assert env.agents[0]['times'] == 10
|
||||||
assert env.agents[0]['times'] == 10
|
assert env.agents[0]['times'] == 10
|
||||||
env.step()
|
env.step()
|
||||||
@ -105,8 +102,8 @@ class TestMain(TestCase):
|
|||||||
config = {
|
config = {
|
||||||
'max_time': 10,
|
'max_time': 10,
|
||||||
'model_params': {
|
'model_params': {
|
||||||
'agents': [(CustomAgent, {'weight': 1}),
|
'agents': [{'agent_class': CustomAgent, 'weight': 1, 'topology': 'default'},
|
||||||
(CustomAgent, {'weight': 3}),
|
{'agent_class': CustomAgent, 'weight': 3, 'topology': 'default'},
|
||||||
],
|
],
|
||||||
'topologies': {
|
'topologies': {
|
||||||
'default': {
|
'default': {
|
||||||
@ -128,7 +125,7 @@ class TestMain(TestCase):
|
|||||||
"""A complete example from a documentation should work."""
|
"""A complete example from a documentation should work."""
|
||||||
config = serialization.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
|
config = serialization.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
|
||||||
config['model_params']['network_params']['path'] = join(EXAMPLES,
|
config['model_params']['network_params']['path'] = join(EXAMPLES,
|
||||||
config['network_params']['path'])
|
config['model_params']['network_params']['path'])
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.run_simulation(dry_run=True)[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
for a in env.network_agents:
|
for a in env.network_agents:
|
||||||
@ -208,24 +205,6 @@ class TestMain(TestCase):
|
|||||||
assert converted[1]['agent_class'] == 'test_main.CustomAgent'
|
assert converted[1]['agent_class'] == 'test_main.CustomAgent'
|
||||||
pickle.dumps(converted)
|
pickle.dumps(converted)
|
||||||
|
|
||||||
def test_subgraph(self):
|
|
||||||
'''An agent should be able to subgraph the global topology'''
|
|
||||||
G = nx.Graph()
|
|
||||||
G.add_node(3)
|
|
||||||
G.add_edge(1, 2)
|
|
||||||
distro = agents.calculate_distribution(agent_class=agents.NetworkAgent)
|
|
||||||
distro[0]['topology'] = 'default'
|
|
||||||
aconfig = config.AgentConfig(distribution=distro, topology='default')
|
|
||||||
env = Environment(name='Test', topologies={'default': G}, agents={'network': aconfig})
|
|
||||||
lst = list(env.network_agents)
|
|
||||||
|
|
||||||
a2 = env.find_one(node_id=2)
|
|
||||||
a3 = env.find_one(node_id=3)
|
|
||||||
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
|
||||||
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
|
||||||
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
|
||||||
assert len(a3.subgraph(agent_class=agents.NetworkAgent)) == 3
|
|
||||||
|
|
||||||
def test_templates(self):
|
def test_templates(self):
|
||||||
'''Loading a template should result in several configs'''
|
'''Loading a template should result in several configs'''
|
||||||
configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
|
configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
|
||||||
@ -236,14 +215,18 @@ class TestMain(TestCase):
|
|||||||
'name': 'until_sim',
|
'name': 'until_sim',
|
||||||
'model_params': {
|
'model_params': {
|
||||||
'network_params': {},
|
'network_params': {},
|
||||||
'agent_class': 'CounterModel',
|
'agents': {
|
||||||
|
'fixed': [{
|
||||||
|
'agent_class': agents.BaseAgent,
|
||||||
|
}]
|
||||||
|
},
|
||||||
},
|
},
|
||||||
'max_time': 2,
|
'max_time': 2,
|
||||||
'num_trials': 50,
|
'num_trials': 50,
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
runs = list(s.run_simulation(dry_run=True))
|
runs = list(s.run_simulation(dry_run=True))
|
||||||
over = list(x.now for x in runs if x.now>2)
|
over = list(x.now for x in runs if x.now > 2)
|
||||||
assert len(runs) == config['num_trials']
|
assert len(runs) == config['num_trials']
|
||||||
assert len(over) == 0
|
assert len(over) == 0
|
||||||
|
|
||||||
|
@ -6,7 +6,8 @@ import networkx as nx
|
|||||||
|
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from soil import network, environment
|
from soil import config, network, environment, agents, simulation
|
||||||
|
from test_main import CustomAgent
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
EXAMPLES = join(ROOT, '..', 'examples')
|
||||||
@ -60,22 +61,53 @@ class TestNetwork(TestCase):
|
|||||||
G = nx.random_geometric_graph(20, 0.1)
|
G = nx.random_geometric_graph(20, 0.1)
|
||||||
env = environment.NetworkEnvironment(topology=G)
|
env = environment.NetworkEnvironment(topology=G)
|
||||||
f = io.BytesIO()
|
f = io.BytesIO()
|
||||||
env.dump_gexf(f)
|
assert env.topologies['default']
|
||||||
|
network.dump_gexf(env.topologies['default'], f)
|
||||||
|
|
||||||
|
def test_networkenvironment_creation(self):
|
||||||
|
"""Networkenvironment should accept netconfig as parameters"""
|
||||||
|
model_params = {
|
||||||
|
'topologies': {
|
||||||
|
'default': {
|
||||||
|
'path': join(ROOT, 'test.gexf')
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'agents': {
|
||||||
|
'topology': 'default',
|
||||||
|
'distribution': [{
|
||||||
|
'agent_class': CustomAgent,
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
env = environment.Environment(**model_params)
|
||||||
|
assert env.topologies
|
||||||
|
env.step()
|
||||||
|
assert len(env.topologies['default']) == 2
|
||||||
|
assert len(env.agents) == 2
|
||||||
|
assert env.agents[1].count_agents(state_id='normal') == 2
|
||||||
|
assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||||
|
assert env.agents[0].neighbors == 1
|
||||||
|
|
||||||
def test_custom_agent_neighbors(self):
|
def test_custom_agent_neighbors(self):
|
||||||
"""Allow for search of neighbors with a certain state_id"""
|
"""Allow for search of neighbors with a certain state_id"""
|
||||||
config = {
|
config = {
|
||||||
'network_params': {
|
'model_params': {
|
||||||
|
'topologies': {
|
||||||
|
'default': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
},
|
|
||||||
'network_agents': [{
|
|
||||||
'agent_class': CustomAgent,
|
|
||||||
'weight': 1
|
|
||||||
|
|
||||||
}],
|
|
||||||
'max_time': 10,
|
|
||||||
'environment_params': {
|
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
'agents': {
|
||||||
|
'topology': 'default',
|
||||||
|
'distribution': [
|
||||||
|
{
|
||||||
|
'weight': 1,
|
||||||
|
'agent_class': CustomAgent
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'max_time': 10,
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.run_simulation(dry_run=True)[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
@ -83,3 +115,19 @@ class TestNetwork(TestCase):
|
|||||||
assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
|
assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||||
assert env.agents[0].neighbors == 1
|
assert env.agents[0].neighbors == 1
|
||||||
|
|
||||||
|
def test_subgraph(self):
|
||||||
|
'''An agent should be able to subgraph the global topology'''
|
||||||
|
G = nx.Graph()
|
||||||
|
G.add_node(3)
|
||||||
|
G.add_edge(1, 2)
|
||||||
|
distro = agents.calculate_distribution(agent_class=agents.NetworkAgent)
|
||||||
|
aconfig = config.AgentConfig(distribution=distro, topology='default')
|
||||||
|
env = environment.Environment(name='Test', topologies={'default': G}, agents=aconfig)
|
||||||
|
lst = list(env.network_agents)
|
||||||
|
|
||||||
|
a2 = env.find_one(node_id=2)
|
||||||
|
a3 = env.find_one(node_id=3)
|
||||||
|
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
||||||
|
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
||||||
|
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
||||||
|
assert len(a3.subgraph(agent_class=agents.NetworkAgent)) == 3
|
||||||
|
Loading…
Reference in New Issue
Block a user