mirror of
https://github.com/gsi-upm/soil
synced 2024-11-24 20:02:28 +00:00
WIP: all tests pass
This commit is contained in:
parent
f811ee18c5
commit
cd62c23cb9
@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.3 UNRELEASED]
|
||||
### Added
|
||||
* Simple debugging capabilities, with a custom `pdb.Debugger` subclass that exposes commands to list agents and their status and set breakpoints on states (for FSM agents)
|
||||
### Changed
|
||||
* Configuration schema is very different now. Check `soil.config` for more information. We are also using Pydantic for (de)serialization.
|
||||
* There may be more than one topology/network in the simulation
|
||||
|
12
docs/soil-vs.rst
Normal file
12
docs/soil-vs.rst
Normal file
@ -0,0 +1,12 @@
|
||||
### MESA
|
||||
|
||||
Starting with version 0.3, Soil has been redesigned to complement Mesa, while remaining compatible with it.
|
||||
That means that every component in Soil (i.e., Models, Environments, etc.) can be mixed with existing mesa components.
|
||||
In fact, there are examples that show how that integration may be used, in the `examples/mesa` folder in the repository.
|
||||
|
||||
Here are some reasons to use Soil instead of plain mesa:
|
||||
|
||||
- Less boilerplate for common scenarios (by some definitions of common)
|
||||
- Functions to automatically populate a topology with an agent distribution (i.e., different ratios of agent class and state)
|
||||
- The `soil.Simulation` class allows you to run multiple instances of the same experiment (i.e., multiple trials with the same parameters but a different randomness seed)
|
||||
- Reporting functions that aggregate multiple
|
@ -1,14 +1,16 @@
|
||||
---
|
||||
version: '2'
|
||||
general:
|
||||
id: simple
|
||||
group: tests
|
||||
dir_path: "/tmp/"
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
topologies:
|
||||
name: simple
|
||||
group: tests
|
||||
dir_path: "/tmp/"
|
||||
num_trials: 3
|
||||
max_steps: 100
|
||||
interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
model_class: Environment
|
||||
model_params:
|
||||
am_i_complete: true
|
||||
topologies:
|
||||
default:
|
||||
params:
|
||||
generator: complete_graph
|
||||
@ -17,30 +19,36 @@ topologies:
|
||||
params:
|
||||
generator: complete_graph
|
||||
n: 2
|
||||
environment:
|
||||
environment_class: Environment
|
||||
params:
|
||||
am_i_complete: true
|
||||
agents:
|
||||
# Agents are split several groups, each with its own definition
|
||||
default: # This is a special group. Its values will be used as default values for the rest of the groups
|
||||
environment:
|
||||
agents:
|
||||
agent_class: CounterModel
|
||||
topology: default
|
||||
state:
|
||||
times: 1
|
||||
environment:
|
||||
# In this group we are not specifying any topology
|
||||
topology: False
|
||||
fixed:
|
||||
- name: 'Environment Agent 1'
|
||||
agent_class: CounterModel
|
||||
agent_class: BaseAgent
|
||||
group: environment
|
||||
topology: null
|
||||
hidden: true
|
||||
state:
|
||||
times: 10
|
||||
general_counters:
|
||||
topology: default
|
||||
- agent_class: CounterModel
|
||||
id: 0
|
||||
group: other_counters
|
||||
topology: another_graph
|
||||
state:
|
||||
times: 1
|
||||
total: 0
|
||||
- agent_class: CounterModel
|
||||
topology: another_graph
|
||||
group: other_counters
|
||||
id: 1
|
||||
distribution:
|
||||
- agent_class: CounterModel
|
||||
weight: 1
|
||||
group: general_counters
|
||||
state:
|
||||
times: 3
|
||||
- agent_class: AggregatedCounter
|
||||
@ -51,16 +59,3 @@ agents:
|
||||
n: 2
|
||||
state:
|
||||
times: 5
|
||||
|
||||
other_counters:
|
||||
topology: another_graph
|
||||
fixed:
|
||||
- agent_class: CounterModel
|
||||
id: 0
|
||||
state:
|
||||
times: 1
|
||||
total: 0
|
||||
- agent_class: CounterModel
|
||||
id: 1
|
||||
# If not specified, it will use the state set in the default
|
||||
# state:
|
||||
|
63
examples/complete_opt2.yml
Normal file
63
examples/complete_opt2.yml
Normal file
@ -0,0 +1,63 @@
|
||||
---
|
||||
version: '2'
|
||||
id: simple
|
||||
group: tests
|
||||
dir_path: "/tmp/"
|
||||
num_trials: 3
|
||||
max_steps: 100
|
||||
interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
model_class: "soil.Environment"
|
||||
model_params:
|
||||
topologies:
|
||||
default:
|
||||
params:
|
||||
generator: complete_graph
|
||||
n: 10
|
||||
another_graph:
|
||||
params:
|
||||
generator: complete_graph
|
||||
n: 2
|
||||
agents:
|
||||
# The values here will be used as default values for any agent
|
||||
agent_class: CounterModel
|
||||
topology: default
|
||||
state:
|
||||
times: 1
|
||||
# This specifies a distribution of agents, each with a `weight` or an explicit number of agents
|
||||
distribution:
|
||||
- agent_class: CounterModel
|
||||
weight: 1
|
||||
# This is inherited from the default settings
|
||||
#topology: default
|
||||
state:
|
||||
times: 3
|
||||
- agent_class: AggregatedCounter
|
||||
topology: default
|
||||
weight: 0.2
|
||||
fixed:
|
||||
- name: 'Environment Agent 1'
|
||||
# All the other agents will assigned to the 'default' group
|
||||
group: environment
|
||||
# Do not count this agent towards total limits
|
||||
hidden: true
|
||||
agent_class: soil.BaseAgent
|
||||
topology: null
|
||||
state:
|
||||
times: 10
|
||||
- agent_class: CounterModel
|
||||
topology: another_graph
|
||||
id: 0
|
||||
state:
|
||||
times: 1
|
||||
total: 0
|
||||
- agent_class: CounterModel
|
||||
topology: another_graph
|
||||
id: 1
|
||||
override:
|
||||
# 2 agents that match this filter will be updated to match the state {times: 5}
|
||||
- filter:
|
||||
agent_class: AggregatedCounter
|
||||
n: 2
|
||||
state:
|
||||
times: 5
|
@ -2,7 +2,7 @@
|
||||
name: custom-generator
|
||||
description: Using a custom generator for the network
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
max_steps: 100
|
||||
interval: 1
|
||||
network_params:
|
||||
generator: mymodule.mygenerator
|
||||
|
@ -1,4 +1,5 @@
|
||||
from networkx import Graph
|
||||
import random
|
||||
import networkx as nx
|
||||
|
||||
def mygenerator(n=5, n_edges=5):
|
||||
@ -13,9 +14,9 @@ def mygenerator(n=5, n_edges=5):
|
||||
|
||||
for i in range(n_edges):
|
||||
nodes = list(G.nodes)
|
||||
n_in = self.random.choice(nodes)
|
||||
n_in = random.choice(nodes)
|
||||
nodes.remove(n_in) # Avoid loops
|
||||
n_out = self.random.choice(nodes)
|
||||
n_out = random.choice(nodes)
|
||||
G.add_edge(n_in, n_out)
|
||||
return G
|
||||
|
||||
|
@ -3,17 +3,21 @@ name: mesa_sim
|
||||
group: tests
|
||||
dir_path: "/tmp"
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
max_steps: 100
|
||||
interval: 1
|
||||
seed: '1'
|
||||
network_params:
|
||||
model_class: social_wealth.MoneyEnv
|
||||
model_params:
|
||||
topologies:
|
||||
default:
|
||||
params:
|
||||
generator: social_wealth.graph_generator
|
||||
n: 5
|
||||
network_agents:
|
||||
agents:
|
||||
distribution:
|
||||
- agent_class: social_wealth.SocialMoneyAgent
|
||||
topology: default
|
||||
weight: 1
|
||||
environment_class: social_wealth.MoneyEnv
|
||||
environment_params:
|
||||
mesa_agent_class: social_wealth.MoneyAgent
|
||||
N: 10
|
||||
width: 50
|
||||
|
@ -5,7 +5,7 @@ environment_params:
|
||||
prob_neighbor_spread: 0.0
|
||||
prob_tv_spread: 0.01
|
||||
interval: 1
|
||||
max_time: 300
|
||||
max_steps: 300
|
||||
name: Sim_all_dumb
|
||||
network_agents:
|
||||
- agent_class: newsspread.DumbViewer
|
||||
@ -28,7 +28,7 @@ environment_params:
|
||||
prob_neighbor_spread: 0.0
|
||||
prob_tv_spread: 0.01
|
||||
interval: 1
|
||||
max_time: 300
|
||||
max_steps: 300
|
||||
name: Sim_half_herd
|
||||
network_agents:
|
||||
- agent_class: newsspread.DumbViewer
|
||||
@ -59,7 +59,7 @@ environment_params:
|
||||
prob_neighbor_spread: 0.0
|
||||
prob_tv_spread: 0.01
|
||||
interval: 1
|
||||
max_time: 300
|
||||
max_steps: 300
|
||||
name: Sim_all_herd
|
||||
network_agents:
|
||||
- agent_class: newsspread.HerdViewer
|
||||
@ -85,7 +85,7 @@ environment_params:
|
||||
prob_tv_spread: 0.01
|
||||
prob_neighbor_cure: 0.1
|
||||
interval: 1
|
||||
max_time: 300
|
||||
max_steps: 300
|
||||
name: Sim_wise_herd
|
||||
network_agents:
|
||||
- agent_class: newsspread.HerdViewer
|
||||
@ -110,7 +110,7 @@ environment_params:
|
||||
prob_tv_spread: 0.01
|
||||
prob_neighbor_cure: 0.1
|
||||
interval: 1
|
||||
max_time: 300
|
||||
max_steps: 300
|
||||
name: Sim_all_wise
|
||||
network_agents:
|
||||
- agent_class: newsspread.WiseViewer
|
||||
|
@ -16,13 +16,13 @@ class DumbViewer(FSM, NetworkAgent):
|
||||
@state
|
||||
def neutral(self):
|
||||
if self['has_tv']:
|
||||
if prob(self.env['prob_tv_spread']):
|
||||
if self.prob(self.model['prob_tv_spread']):
|
||||
return self.infected
|
||||
|
||||
@state
|
||||
def infected(self):
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
||||
if prob(self.env['prob_neighbor_spread']):
|
||||
if self.prob(self.model['prob_neighbor_spread']):
|
||||
neighbor.infect()
|
||||
|
||||
def infect(self):
|
||||
@ -44,9 +44,9 @@ class HerdViewer(DumbViewer):
|
||||
'''Notice again that this is NOT a state. See DumbViewer.infect for reference'''
|
||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||
total = self.count_neighboring_agents()
|
||||
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
||||
prob_infect = self.model['prob_neighbor_spread'] * infected/total
|
||||
self.debug('prob_infect', prob_infect)
|
||||
if prob(prob_infect):
|
||||
if self.prob(prob_infect):
|
||||
self.set_state(self.infected)
|
||||
|
||||
|
||||
@ -63,9 +63,9 @@ class WiseViewer(HerdViewer):
|
||||
|
||||
@state
|
||||
def cured(self):
|
||||
prob_cure = self.env['prob_neighbor_cure']
|
||||
prob_cure = self.model['prob_neighbor_cure']
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
||||
if prob(prob_cure):
|
||||
if self.prob(prob_cure):
|
||||
try:
|
||||
neighbor.cure()
|
||||
except AttributeError:
|
||||
@ -80,7 +80,7 @@ class WiseViewer(HerdViewer):
|
||||
1.0)
|
||||
infected = max(self.count_neighboring_agents(self.infected.id),
|
||||
1.0)
|
||||
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
||||
if prob(prob_cure):
|
||||
prob_cure = self.model['prob_neighbor_cure'] * (cured/infected)
|
||||
if self.prob(prob_cure):
|
||||
return self.cured
|
||||
return self.set_state(super().infected)
|
||||
|
@ -60,12 +60,10 @@ class Patron(FSM, NetworkAgent):
|
||||
'''
|
||||
level = logging.DEBUG
|
||||
|
||||
defaults = {
|
||||
'pub': None,
|
||||
'drunk': False,
|
||||
'pints': 0,
|
||||
'max_pints': 3,
|
||||
}
|
||||
pub = None
|
||||
drunk = False
|
||||
pints = 0
|
||||
max_pints = 3
|
||||
|
||||
@default_state
|
||||
@state
|
||||
@ -89,9 +87,9 @@ class Patron(FSM, NetworkAgent):
|
||||
return self.sober_in_pub
|
||||
self.debug('I am looking for a pub')
|
||||
group = list(self.get_neighboring_agents())
|
||||
for pub in self.env.available_pubs():
|
||||
for pub in self.model.available_pubs():
|
||||
self.debug('We\'re trying to get into {}: total: {}'.format(pub, len(group)))
|
||||
if self.env.enter(pub, self, *group):
|
||||
if self.model.enter(pub, self, *group):
|
||||
self.info('We\'re all {} getting in {}!'.format(len(group), pub))
|
||||
return self.sober_in_pub
|
||||
|
||||
@ -128,7 +126,7 @@ class Patron(FSM, NetworkAgent):
|
||||
success depend on both agents' openness.
|
||||
'''
|
||||
if force or self['openness'] > self.random.random():
|
||||
self.env.add_edge(self, other_agent)
|
||||
self.model.add_edge(self, other_agent)
|
||||
self.info('Made some friend {}'.format(other_agent))
|
||||
return True
|
||||
return False
|
||||
@ -150,7 +148,7 @@ class Patron(FSM, NetworkAgent):
|
||||
return befriended
|
||||
|
||||
|
||||
class Police(FSM, NetworkAgent):
|
||||
class Police(FSM):
|
||||
'''Simple agent to take drunk people out of pubs.'''
|
||||
level = logging.INFO
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
---
|
||||
name: pubcrawl
|
||||
num_trials: 3
|
||||
max_time: 10
|
||||
max_steps: 10
|
||||
dump: false
|
||||
network_params:
|
||||
# Generate 100 empty nodes. They will be assigned a network agent
|
||||
|
4
examples/rabbits/README.md
Normal file
4
examples/rabbits/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
There are two similar implementations of this simulation.
|
||||
|
||||
- `basic`. Using simple primites
|
||||
- `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions.
|
130
examples/rabbits/basic/rabbit_agents.py
Normal file
130
examples/rabbits/basic/rabbit_agents.py
Normal file
@ -0,0 +1,130 @@
|
||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||
from soil.time import Delta
|
||||
from enum import Enum
|
||||
from collections import Counter
|
||||
import logging
|
||||
import math
|
||||
|
||||
|
||||
class RabbitModel(FSM, NetworkAgent):
|
||||
|
||||
sexual_maturity = 30
|
||||
life_expectancy = 300
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.info('I am a newborn.')
|
||||
self.age = 0
|
||||
self.offspring = 0
|
||||
return self.youngling
|
||||
|
||||
@state
|
||||
def youngling(self):
|
||||
self.age += 1
|
||||
if self.age >= self.sexual_maturity:
|
||||
self.info(f'I am fertile! My age is {self.age}')
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
raise Exception("Each subclass should define its fertile state")
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
self.die()
|
||||
|
||||
|
||||
class Male(RabbitModel):
|
||||
max_females = 5
|
||||
mating_prob = 0.001
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
self.age += 1
|
||||
|
||||
if self.age > self.life_expectancy:
|
||||
return self.dead
|
||||
|
||||
# Males try to mate
|
||||
for f in self.model.agents(agent_class=Female,
|
||||
state_id=Female.fertile.id,
|
||||
limit=self.max_females):
|
||||
self.debug('FOUND A FEMALE: ', repr(f), self.mating_prob)
|
||||
if self.prob(self['mating_prob']):
|
||||
f.impregnate(self)
|
||||
break # Take a break
|
||||
|
||||
|
||||
class Female(RabbitModel):
|
||||
gestation = 100
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
# Just wait for a Male
|
||||
self.age += 1
|
||||
if self.age > self.life_expectancy:
|
||||
return self.dead
|
||||
|
||||
def impregnate(self, male):
|
||||
self.info(f'{repr(male)} impregnating female {repr(self)}')
|
||||
self.mate = male
|
||||
self.pregnancy = -1
|
||||
self.set_state(self.pregnant, when=self.now)
|
||||
self.number_of_babies = int(8+4*self.random.random())
|
||||
self.debug('I am pregnant')
|
||||
|
||||
@state
|
||||
def pregnant(self):
|
||||
self.age += 1
|
||||
self.pregnancy += 1
|
||||
|
||||
if self.prob(self.age / self.life_expectancy):
|
||||
return self.die()
|
||||
|
||||
if self.pregnancy >= self.gestation:
|
||||
self.info('Having {} babies'.format(self.number_of_babies))
|
||||
for i in range(self.number_of_babies):
|
||||
state = {}
|
||||
agent_class = self.random.choice([Male, Female])
|
||||
child = self.model.add_node(agent_class=agent_class,
|
||||
topology=self.topology,
|
||||
**state)
|
||||
child.add_edge(self)
|
||||
try:
|
||||
child.add_edge(self.mate)
|
||||
self.model.agents[self.mate].offspring += 1
|
||||
except ValueError:
|
||||
self.debug('The father has passed away')
|
||||
|
||||
self.offspring += 1
|
||||
self.mate = None
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
super().dead()
|
||||
if 'pregnancy' in self and self['pregnancy'] > -1:
|
||||
self.info('A mother has died carrying a baby!!')
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
|
||||
level = logging.INFO
|
||||
|
||||
def step(self):
|
||||
rabbits_alive = self.model.topology.number_of_nodes()
|
||||
|
||||
if not rabbits_alive:
|
||||
return self.die()
|
||||
|
||||
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||
for i in self.iter_agents(agent_class=RabbitModel):
|
||||
if i.state.id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
self.info('I killed a rabbit: {}'.format(i.id))
|
||||
rabbits_alive -= 1
|
||||
i.set_state(i.dead)
|
||||
self.debug('Rabbits alive: {}'.format(rabbits_alive))
|
41
examples/rabbits/basic/rabbits.yml
Normal file
41
examples/rabbits/basic/rabbits.yml
Normal file
@ -0,0 +1,41 @@
|
||||
---
|
||||
version: '2'
|
||||
name: rabbits_basic
|
||||
num_trials: 1
|
||||
seed: MySeed
|
||||
description: null
|
||||
group: null
|
||||
interval: 1.0
|
||||
max_time: 100
|
||||
model_class: soil.environment.Environment
|
||||
model_params:
|
||||
agents:
|
||||
topology: default
|
||||
agent_class: rabbit_agents.RabbitModel
|
||||
distribution:
|
||||
- agent_class: rabbit_agents.Male
|
||||
topology: default
|
||||
weight: 1
|
||||
- agent_class: rabbit_agents.Female
|
||||
topology: default
|
||||
weight: 1
|
||||
fixed:
|
||||
- agent_class: rabbit_agents.RandomAccident
|
||||
topology: null
|
||||
hidden: true
|
||||
state:
|
||||
group: environment
|
||||
state:
|
||||
group: network
|
||||
mating_prob: 0.1
|
||||
prob_death: 0.001
|
||||
topologies:
|
||||
default:
|
||||
topology:
|
||||
directed: true
|
||||
links: []
|
||||
nodes:
|
||||
- id: 1
|
||||
- id: 0
|
||||
extra:
|
||||
visualization_params: {}
|
130
examples/rabbits/improved/rabbit_agents.py
Normal file
130
examples/rabbits/improved/rabbit_agents.py
Normal file
@ -0,0 +1,130 @@
|
||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||
from soil.time import Delta, When, NEVER
|
||||
from enum import Enum
|
||||
import logging
|
||||
import math
|
||||
|
||||
|
||||
class RabbitModel(FSM, NetworkAgent):
|
||||
|
||||
mating_prob = 0.005
|
||||
offspring = 0
|
||||
birth = None
|
||||
|
||||
sexual_maturity = 3
|
||||
life_expectancy = 30
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.birth = self.now
|
||||
self.info(f'I am a newborn.')
|
||||
self.model['rabbits_alive'] = self.model.get('rabbits_alive', 0) + 1
|
||||
|
||||
# Here we can skip the `youngling` state by using a coroutine/generator.
|
||||
while self.age < self.sexual_maturity:
|
||||
interval = self.sexual_maturity - self.age
|
||||
yield Delta(interval)
|
||||
|
||||
self.info(f'I am fertile! My age is {self.age}')
|
||||
return self.fertile
|
||||
|
||||
@property
|
||||
def age(self):
|
||||
return self.now - self.birth
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
raise Exception("Each subclass should define its fertile state")
|
||||
|
||||
def step(self):
|
||||
super().step()
|
||||
if self.prob(self.age / self.life_expectancy):
|
||||
return self.die()
|
||||
|
||||
|
||||
class Male(RabbitModel):
|
||||
|
||||
max_females = 5
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
# Males try to mate
|
||||
for f in self.model.agents(agent_class=Female,
|
||||
state_id=Female.fertile.id,
|
||||
limit=self.max_females):
|
||||
self.debug('Found a female:', repr(f))
|
||||
if self.prob(self['mating_prob']):
|
||||
f.impregnate(self)
|
||||
break # Take a break, don't try to impregnate the rest
|
||||
|
||||
|
||||
class Female(RabbitModel):
|
||||
due_date = None
|
||||
age_of_pregnancy = None
|
||||
gestation = 10
|
||||
mate = None
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
return self.fertile, NEVER
|
||||
|
||||
@state
|
||||
def pregnant(self):
|
||||
self.info('I am pregnant')
|
||||
if self.age > self.life_expectancy:
|
||||
return self.dead
|
||||
|
||||
self.due_date = self.now + self.gestation
|
||||
|
||||
number_of_babies = int(8+4*self.random.random())
|
||||
|
||||
while self.now < self.due_date:
|
||||
yield When(self.due_date)
|
||||
|
||||
self.info('Having {} babies'.format(number_of_babies))
|
||||
for i in range(number_of_babies):
|
||||
agent_class = self.random.choice([Male, Female])
|
||||
child = self.model.add_node(agent_class=agent_class,
|
||||
topology=self.topology)
|
||||
self.model.add_edge(self, child)
|
||||
self.model.add_edge(self.mate, child)
|
||||
self.offspring += 1
|
||||
self.model.agents[self.mate].offspring += 1
|
||||
self.mate = None
|
||||
self.due_date = None
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
super().dead()
|
||||
if self.due_date is not None:
|
||||
self.info('A mother has died carrying a baby!!')
|
||||
|
||||
def impregnate(self, male):
|
||||
self.info(f'{repr(male)} impregnating female {repr(self)}')
|
||||
self.mate = male
|
||||
self.set_state(self.pregnant, when=self.now)
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
|
||||
level = logging.INFO
|
||||
|
||||
def step(self):
|
||||
rabbits_total = self.model.topology.number_of_nodes()
|
||||
if 'rabbits_alive' not in self.model:
|
||||
self.model['rabbits_alive'] = 0
|
||||
rabbits_alive = self.model.get('rabbits_alive', rabbits_total)
|
||||
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||
for i in self.model.network_agents:
|
||||
if i.state.id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
self.info('I killed a rabbit: {}'.format(i.id))
|
||||
rabbits_alive = self.model['rabbits_alive'] = rabbits_alive -1
|
||||
i.set_state(i.dead)
|
||||
self.debug('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
||||
if self.model.count_agents(state_id=RabbitModel.dead.id) == self.model.topology.number_of_nodes():
|
||||
self.die()
|
41
examples/rabbits/improved/rabbits.yml
Normal file
41
examples/rabbits/improved/rabbits.yml
Normal file
@ -0,0 +1,41 @@
|
||||
---
|
||||
version: '2'
|
||||
name: rabbits_improved
|
||||
num_trials: 1
|
||||
seed: MySeed
|
||||
description: null
|
||||
group: null
|
||||
interval: 1.0
|
||||
max_time: 100
|
||||
model_class: soil.environment.Environment
|
||||
model_params:
|
||||
agents:
|
||||
topology: default
|
||||
agent_class: rabbit_agents.RabbitModel
|
||||
distribution:
|
||||
- agent_class: rabbit_agents.Male
|
||||
topology: default
|
||||
weight: 1
|
||||
- agent_class: rabbit_agents.Female
|
||||
topology: default
|
||||
weight: 1
|
||||
fixed:
|
||||
- agent_class: rabbit_agents.RandomAccident
|
||||
topology: null
|
||||
hidden: true
|
||||
state:
|
||||
group: environment
|
||||
state:
|
||||
group: network
|
||||
mating_prob: 0.1
|
||||
prob_death: 0.001
|
||||
topologies:
|
||||
default:
|
||||
topology:
|
||||
directed: true
|
||||
links: []
|
||||
nodes:
|
||||
- id: 1
|
||||
- id: 0
|
||||
extra:
|
||||
visualization_params: {}
|
@ -1,133 +0,0 @@
|
||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||
from enum import Enum
|
||||
import logging
|
||||
import math
|
||||
|
||||
|
||||
class Genders(Enum):
|
||||
male = 'male'
|
||||
female = 'female'
|
||||
|
||||
|
||||
class RabbitModel(FSM, NetworkAgent):
|
||||
|
||||
defaults = {
|
||||
'age': 0,
|
||||
'gender': Genders.male.value,
|
||||
'mating_prob': 0.001,
|
||||
'offspring': 0,
|
||||
}
|
||||
|
||||
sexual_maturity = 3 #4*30
|
||||
life_expectancy = 365 * 3
|
||||
gestation = 33
|
||||
pregnancy = -1
|
||||
max_females = 5
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.debug(f'I am a newborn at age {self["age"]}')
|
||||
self['age'] += 1
|
||||
|
||||
if self['age'] >= self.sexual_maturity:
|
||||
self.debug('I am fertile!')
|
||||
return self.fertile
|
||||
@state
|
||||
def fertile(self):
|
||||
raise Exception("Each subclass should define its fertile state")
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
self.info('Agent {} is dying'.format(self.id))
|
||||
self.die()
|
||||
|
||||
|
||||
class Male(RabbitModel):
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
self['age'] += 1
|
||||
if self['age'] > self.life_expectancy:
|
||||
return self.dead
|
||||
|
||||
if self['gender'] == Genders.female.value:
|
||||
return
|
||||
|
||||
# Males try to mate
|
||||
for f in self.get_agents(state_id=Female.fertile.id,
|
||||
agent_class=Female,
|
||||
limit_neighbors=False,
|
||||
limit=self.max_females):
|
||||
r = self.random.random()
|
||||
if r < self['mating_prob']:
|
||||
self.impregnate(f)
|
||||
break # Take a break
|
||||
def impregnate(self, whom):
|
||||
whom['pregnancy'] = 0
|
||||
whom['mate'] = self.id
|
||||
whom.set_state(whom.pregnant)
|
||||
self.debug('{} impregnating: {}. {}'.format(self.id, whom.id, whom.state))
|
||||
|
||||
class Female(RabbitModel):
|
||||
@state
|
||||
def fertile(self):
|
||||
# Just wait for a Male
|
||||
pass
|
||||
|
||||
@state
|
||||
def pregnant(self):
|
||||
self['age'] += 1
|
||||
if self['age'] > self.life_expectancy:
|
||||
return self.dead
|
||||
|
||||
self['pregnancy'] += 1
|
||||
self.debug('Pregnancy: {}'.format(self['pregnancy']))
|
||||
if self['pregnancy'] >= self.gestation:
|
||||
number_of_babies = int(8+4*self.random.random())
|
||||
self.info('Having {} babies'.format(number_of_babies))
|
||||
for i in range(number_of_babies):
|
||||
state = {}
|
||||
state['gender'] = self.random.choice(list(Genders)).value
|
||||
child = self.env.add_node(self.__class__, state)
|
||||
self.env.add_edge(self.id, child.id)
|
||||
self.env.add_edge(self['mate'], child.id)
|
||||
# self.add_edge()
|
||||
self.debug('A BABY IS COMING TO LIFE')
|
||||
self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1
|
||||
self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
||||
self['offspring'] += 1
|
||||
self.env.get_agent(self['mate'])['offspring'] += 1
|
||||
del self['mate']
|
||||
self['pregnancy'] = -1
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
super().dead()
|
||||
if 'pregnancy' in self and self['pregnancy'] > -1:
|
||||
self.info('A mother has died carrying a baby!!')
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
|
||||
level = logging.DEBUG
|
||||
|
||||
def step(self):
|
||||
rabbits_total = self.env.topology.number_of_nodes()
|
||||
if 'rabbits_alive' not in self.env:
|
||||
self.env['rabbits_alive'] = 0
|
||||
rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
|
||||
prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||
for i in self.env.network_agents:
|
||||
if i.state['id'] == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
self.debug('I killed a rabbit: {}'.format(i.id))
|
||||
rabbits_alive = self.env['rabbits_alive'] = rabbits_alive -1
|
||||
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
||||
i.set_state(i.dead)
|
||||
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
||||
if self.env.count_agents(state_id=RabbitModel.dead.id) == self.env.topology.number_of_nodes():
|
||||
self.die()
|
@ -1,20 +0,0 @@
|
||||
---
|
||||
name: rabbits_example
|
||||
max_time: 100
|
||||
interval: 1
|
||||
seed: MySeed
|
||||
agent_class: rabbit_agents.RabbitModel
|
||||
environment_agents:
|
||||
- agent_class: rabbit_agents.RandomAccident
|
||||
environment_params:
|
||||
prob_death: 0.001
|
||||
default_state:
|
||||
mating_prob: 0.1
|
||||
topology:
|
||||
nodes:
|
||||
- id: 1
|
||||
agent_class: rabbit_agents.Male
|
||||
- id: 0
|
||||
agent_class: rabbit_agents.Female
|
||||
directed: true
|
||||
links: []
|
@ -6,9 +6,10 @@ template:
|
||||
group: simple
|
||||
num_trials: 1
|
||||
interval: 1
|
||||
max_time: 2
|
||||
max_steps: 2
|
||||
seed: "CompleteSeed!"
|
||||
dump: false
|
||||
model_params:
|
||||
network_params:
|
||||
generator: complete_graph
|
||||
n: 10
|
||||
@ -19,7 +20,6 @@ template:
|
||||
state_id: 0
|
||||
- agent_class: AggregatedCounter
|
||||
weight: "{{ 1 - x1 }}"
|
||||
environment_params:
|
||||
name: "{{ x3 }}"
|
||||
skip_test: true
|
||||
vars:
|
||||
|
@ -81,6 +81,26 @@ class TerroristSpreadModel(FSM, Geo):
|
||||
return
|
||||
return self.leader
|
||||
|
||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||
node = as_node(node if node is not None else self)
|
||||
G = self.subgraph(**kwargs)
|
||||
return nx.ego_graph(G, node, center=center, radius=steps).nodes()
|
||||
|
||||
def degree(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.model, '_degree')) or getattr(self.model, '_last_step', 0) < self.now:
|
||||
self.model._degree = nx.degree_centrality(self.G)
|
||||
self.model._last_step = self.now
|
||||
return self.model._degree[node]
|
||||
|
||||
def betweenness(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.model, '_betweenness')) or getattr(self.model, '_last_step', 0) < self.now:
|
||||
self.model._betweenness = nx.betweenness_centrality(self.G)
|
||||
self.model._last_step = self.now
|
||||
return self.model._betweenness[node]
|
||||
|
||||
|
||||
class TrainingAreaModel(FSM, Geo):
|
||||
"""
|
||||
@ -194,14 +214,14 @@ class TerroristNetworkModel(TerroristSpreadModel):
|
||||
break
|
||||
|
||||
def get_distance(self, target):
|
||||
source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id]
|
||||
target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target]
|
||||
source_x, source_y = nx.get_node_attributes(self.G, 'pos')[self.id]
|
||||
target_x, target_y = nx.get_node_attributes(self.G, 'pos')[target]
|
||||
dx = abs( source_x - target_x )
|
||||
dy = abs( source_y - target_y )
|
||||
return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 )
|
||||
|
||||
def shortest_path_length(self, target):
|
||||
try:
|
||||
return nx.shortest_path_length(self.topology, self.id, target)
|
||||
return nx.shortest_path_length(self.G, self.id, target)
|
||||
except nx.NetworkXNoPath:
|
||||
return float('inf')
|
||||
|
@ -1,13 +1,14 @@
|
||||
name: TerroristNetworkModel_sim
|
||||
max_time: 150
|
||||
max_steps: 150
|
||||
num_trials: 1
|
||||
network_params:
|
||||
model_params:
|
||||
network_params:
|
||||
generator: random_geometric_graph
|
||||
radius: 0.2
|
||||
# generator: geographical_threshold_graph
|
||||
# theta: 20
|
||||
n: 100
|
||||
network_agents:
|
||||
network_agents:
|
||||
- agent_class: TerroristNetworkModel.TerroristNetworkModel
|
||||
weight: 0.8
|
||||
state:
|
||||
@ -25,7 +26,6 @@ network_agents:
|
||||
state:
|
||||
id: civilian # Civilian
|
||||
|
||||
environment_params:
|
||||
# TerroristSpreadModel
|
||||
information_spread_intensity: 0.7
|
||||
terrorist_additional_influence: 0.035
|
||||
|
@ -1,13 +1,14 @@
|
||||
---
|
||||
name: torvalds_example
|
||||
max_time: 10
|
||||
max_steps: 10
|
||||
interval: 2
|
||||
agent_class: CounterModel
|
||||
default_state:
|
||||
model_params:
|
||||
agent_class: CounterModel
|
||||
default_state:
|
||||
skill_level: 'beginner'
|
||||
network_params:
|
||||
network_params:
|
||||
path: 'torvalds.edgelist'
|
||||
states:
|
||||
states:
|
||||
Torvalds:
|
||||
skill_level: 'God'
|
||||
balkian:
|
||||
|
@ -2,8 +2,9 @@ networkx>=2.5
|
||||
numpy
|
||||
matplotlib
|
||||
pyyaml>=5.1
|
||||
pandas>=0.23
|
||||
pandas>=1
|
||||
SALib>=1.3
|
||||
Jinja2
|
||||
Mesa>=0.8.9
|
||||
Mesa>=1
|
||||
pydantic>=1.9
|
||||
sqlalchemy>=1.4
|
||||
|
@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import os
|
||||
import pdb
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
from .version import __version__
|
||||
|
||||
@ -16,11 +18,10 @@ from . import agents
|
||||
from .simulation import *
|
||||
from .environment import Environment
|
||||
from . import serialization
|
||||
from . import analysis
|
||||
from .utils import logger
|
||||
from .time import *
|
||||
|
||||
def main():
|
||||
def main(cfg='simulation.yml', **kwargs):
|
||||
import argparse
|
||||
from . import simulation
|
||||
|
||||
@ -29,7 +30,7 @@ def main():
|
||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
||||
parser.add_argument('file', type=str,
|
||||
nargs="?",
|
||||
default='simulation.yml',
|
||||
default=cfg,
|
||||
help='Configuration file for the simulation (e.g., YAML or JSON)')
|
||||
parser.add_argument('--version', action='store_true',
|
||||
help='Show version info and exit')
|
||||
@ -39,6 +40,8 @@ def main():
|
||||
help='Do not store the results of the simulation to disk, show in terminal instead.')
|
||||
parser.add_argument('--pdb', action='store_true',
|
||||
help='Use a pdb console in case of exception.')
|
||||
parser.add_argument('--debug', action='store_true',
|
||||
help='Run a customized version of a pdb console to debug a simulation.')
|
||||
parser.add_argument('--graph', '-g', action='store_true',
|
||||
help='Dump each trial\'s network topology as a GEXF graph. Defaults to false.')
|
||||
parser.add_argument('--csv', action='store_true',
|
||||
@ -51,9 +54,22 @@ def main():
|
||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||
parser.add_argument('-e', '--exporter', action='append',
|
||||
help='Export environment and/or simulations using this exporter')
|
||||
parser.add_argument('--only-convert', '--convert', action='store_true',
|
||||
help='Do not run the simulation, only convert the configuration file(s) and output them.')
|
||||
|
||||
|
||||
parser.add_argument("--set",
|
||||
metavar="KEY=VALUE",
|
||||
action='append',
|
||||
help="Set a number of parameters that will be passed to the simulation."
|
||||
"(do not put spaces before or after the = sign). "
|
||||
"If a value contains spaces, you should define "
|
||||
"it with double quotes: "
|
||||
'foo="this is a sentence". Note that '
|
||||
"values are always treated as strings.")
|
||||
|
||||
args = parser.parse_args()
|
||||
logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper()))
|
||||
logger.setLevel(getattr(logging, (args.level or 'INFO').upper()))
|
||||
|
||||
if args.version:
|
||||
return
|
||||
@ -65,9 +81,10 @@ def main():
|
||||
|
||||
logger.info('Loading config file: {}'.format(args.file))
|
||||
|
||||
if args.pdb:
|
||||
if args.pdb or args.debug:
|
||||
args.synchronous = True
|
||||
|
||||
if args.debug:
|
||||
os.environ['SOIL_DEBUG'] = 'true'
|
||||
|
||||
try:
|
||||
exporters = list(args.exporter or ['default', ])
|
||||
@ -82,18 +99,48 @@ def main():
|
||||
if not os.path.exists(args.file):
|
||||
logger.error('Please, input a valid file')
|
||||
return
|
||||
simulation.run_from_config(args.file,
|
||||
dry_run=args.dry_run,
|
||||
for sim in simulation.iter_from_config(args.file):
|
||||
if args.set:
|
||||
for s in args.set:
|
||||
k, v = s.split('=', 1)[:2]
|
||||
v = eval(v)
|
||||
tail, *head = k.rsplit('.', 1)[::-1]
|
||||
target = sim
|
||||
if head:
|
||||
for part in head[0].split('.'):
|
||||
try:
|
||||
target = getattr(target, part)
|
||||
except AttributeError:
|
||||
target = target[part]
|
||||
try:
|
||||
setattr(target, tail, v)
|
||||
except AttributeError:
|
||||
target[tail] = v
|
||||
|
||||
if args.only_convert:
|
||||
print(sim.to_yaml())
|
||||
continue
|
||||
|
||||
sim.run_simulation(dry_run=args.dry_run,
|
||||
exporters=exporters,
|
||||
parallel=(not args.synchronous),
|
||||
outdir=args.output,
|
||||
exporter_params=exp_params)
|
||||
except Exception:
|
||||
exporter_params=exp_params,
|
||||
**kwargs)
|
||||
|
||||
except Exception as ex:
|
||||
if args.pdb:
|
||||
pdb.post_mortem()
|
||||
from .debugging import post_mortem
|
||||
print(traceback.format_exc())
|
||||
post_mortem()
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def easy(cfg, debug=False):
|
||||
sim = simulation.from_config(cfg)
|
||||
if debug or os.environ.get('SOIL_DEBUG'):
|
||||
from .debugging import setup
|
||||
setup(sys._getframe().f_back)
|
||||
return sim
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -7,15 +7,13 @@ class CounterModel(NetworkAgent):
|
||||
in each step and adds it to its state.
|
||||
"""
|
||||
|
||||
defaults = {
|
||||
'times': 0,
|
||||
'neighbors': 0,
|
||||
'total': 0
|
||||
}
|
||||
times = 0
|
||||
neighbors = 0
|
||||
total = 0
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
total = len(list(self.env.agents))
|
||||
total = len(list(self.model.schedule._agents))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = neighbors
|
||||
@ -28,17 +26,15 @@ class AggregatedCounter(NetworkAgent):
|
||||
in each step and adds it to its state.
|
||||
"""
|
||||
|
||||
defaults = {
|
||||
'times': 0,
|
||||
'neighbors': 0,
|
||||
'total': 0
|
||||
}
|
||||
times = 0
|
||||
neighbors = 0
|
||||
total = 0
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
self['times'] += 1
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['neighbors'] += neighbors
|
||||
total = len(list(self.env.agents))
|
||||
total = len(list(self.model.schedule.agents))
|
||||
self['total'] += total
|
||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import MutableMapping, Mapping, Set
|
||||
@ -5,9 +7,13 @@ from abc import ABCMeta
|
||||
from copy import deepcopy, copy
|
||||
from functools import partial, wraps
|
||||
from itertools import islice, chain
|
||||
import json
|
||||
import inspect
|
||||
import types
|
||||
import textwrap
|
||||
import networkx as nx
|
||||
|
||||
from typing import Any
|
||||
|
||||
from mesa import Agent as MesaAgent
|
||||
from typing import Dict, List
|
||||
|
||||
@ -27,7 +33,31 @@ class DeadAgent(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseAgent(MesaAgent, MutableMapping):
|
||||
class MetaAgent(ABCMeta):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
defaults = {}
|
||||
|
||||
# Re-use defaults from inherited classes
|
||||
for i in bases:
|
||||
if isinstance(i, MetaAgent):
|
||||
defaults.update(i._defaults)
|
||||
|
||||
new_nmspc = {
|
||||
'_defaults': defaults,
|
||||
}
|
||||
|
||||
for attr, func in namespace.items():
|
||||
if isinstance(func, types.FunctionType) or isinstance(func, property) or attr[0] == '_':
|
||||
new_nmspc[attr] = func
|
||||
elif attr == 'defaults':
|
||||
defaults.update(func)
|
||||
else:
|
||||
defaults[attr] = copy(func)
|
||||
|
||||
return super().__new__(mcls=mcls, name=name, bases=bases, namespace=new_nmspc)
|
||||
|
||||
|
||||
class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
"""
|
||||
A special type of Mesa Agent that:
|
||||
|
||||
@ -39,15 +69,12 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
||||
"""
|
||||
|
||||
defaults = {}
|
||||
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
model,
|
||||
name=None,
|
||||
interval=None,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs):
|
||||
# Check for REQUIRED arguments
|
||||
# Initialize agent parameters
|
||||
if isinstance(unique_id, MesaAgent):
|
||||
@ -58,15 +85,16 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||
|
||||
|
||||
self._neighbors = None
|
||||
self.alive = True
|
||||
|
||||
self.interval = interval or self.get('interval', 1)
|
||||
self.logger = logging.getLogger(self.model.id).getChild(self.name)
|
||||
logger = utils.logger.getChild(getattr(self.model, 'id', self.model)).getChild(self.name)
|
||||
self.logger = logging.LoggerAdapter(logger, {'agent_name': self.name})
|
||||
|
||||
if hasattr(self, 'level'):
|
||||
self.logger.setLevel(self.level)
|
||||
for (k, v) in self.defaults.items():
|
||||
|
||||
for (k, v) in self._defaults.items():
|
||||
if not hasattr(self, k) or getattr(self, k) is None:
|
||||
setattr(self, k, deepcopy(v))
|
||||
|
||||
@ -74,10 +102,6 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
|
||||
setattr(self, k, v)
|
||||
|
||||
for (k, v) in getattr(self, 'defaults', {}).items():
|
||||
if not hasattr(self, k) or getattr(self, k) is None:
|
||||
setattr(self, k, v)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.unique_id)
|
||||
|
||||
@ -89,14 +113,6 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
def id(self):
|
||||
return self.unique_id
|
||||
|
||||
@property
|
||||
def env(self):
|
||||
return self.model
|
||||
|
||||
@env.setter
|
||||
def env(self, model):
|
||||
self.model = model
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
'''
|
||||
@ -108,19 +124,16 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
|
||||
@state.setter
|
||||
def state(self, value):
|
||||
if not value:
|
||||
return
|
||||
for k, v in value.items():
|
||||
self[k] = v
|
||||
|
||||
@property
|
||||
def environment_params(self):
|
||||
return self.model.environment_params
|
||||
|
||||
@environment_params.setter
|
||||
def environment_params(self, value):
|
||||
self.model.environment_params = value
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
raise KeyError(f'key {key} not found in agent')
|
||||
|
||||
def __delitem__(self, key):
|
||||
return delattr(self, key)
|
||||
@ -138,10 +151,14 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
return self.items()
|
||||
|
||||
def keys(self):
|
||||
return (k for k in self.__dict__ if k[0] != '_')
|
||||
return (k for k in self.__dict__ if k[0] != '_' and k not in IGNORED_FIELDS)
|
||||
|
||||
def items(self):
|
||||
return ((k, v) for (k, v) in self.__dict__.items() if k[0] != '_')
|
||||
def items(self, keys=None, skip=None):
|
||||
keys = keys if keys is not None else self.keys()
|
||||
it = ((k, self.get(k, None)) for k in keys)
|
||||
if skip:
|
||||
return filter(lambda x: x[0] not in skip, it)
|
||||
return it
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self[key] if key in self else default
|
||||
@ -154,11 +171,9 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
# No environment
|
||||
return None
|
||||
|
||||
def die(self, remove=False):
|
||||
self.info(f'agent {self.unique_id} is dying')
|
||||
def die(self):
|
||||
self.info(f'agent dying')
|
||||
self.alive = False
|
||||
if remove:
|
||||
self.remove_node(self.id)
|
||||
return time.NEVER
|
||||
|
||||
def step(self):
|
||||
@ -170,7 +185,7 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
if not self.logger.isEnabledFor(level):
|
||||
return
|
||||
message = message + " ".join(str(i) for i in args)
|
||||
message = " @{:>3}: {}".format(self.now, message)
|
||||
message = "[@{:>4}]\t{:>10}: {}".format(self.now, repr(self), message)
|
||||
for k, v in kwargs:
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
@ -179,33 +194,48 @@ class BaseAgent(MesaAgent, MutableMapping):
|
||||
extra['agent_name'] = self.name
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
|
||||
|
||||
def debug(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.DEBUG, **kwargs)
|
||||
|
||||
def info(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.INFO, **kwargs)
|
||||
|
||||
# Alias
|
||||
# Agent = BaseAgent
|
||||
def count_agents(self, **kwargs):
|
||||
return len(list(self.get_agents(**kwargs)))
|
||||
|
||||
def get_agents(self, *args, **kwargs):
|
||||
it = self.iter_agents(*args, **kwargs)
|
||||
return list(it)
|
||||
|
||||
def iter_agents(self, *args, **kwargs):
|
||||
yield from filter_agents(self.model.schedule._agents, *args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_str()
|
||||
|
||||
def to_str(self, keys=None, skip=None, pretty=False):
|
||||
content = dict(self.items(keys=keys))
|
||||
if pretty and content:
|
||||
d = content
|
||||
content = '\n'
|
||||
for k, v in d.items():
|
||||
content += f'- {k}: {v}\n'
|
||||
content = textwrap.indent(content, ' ')
|
||||
return f"{repr(self)}{content}"
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self.unique_id})"
|
||||
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
return self.env.topology_for(self.unique_id)
|
||||
def __init__(self, *args, topology, node_id, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def node_id(self):
|
||||
return self.env.node_id_for(self.unique_id)
|
||||
|
||||
@property
|
||||
def G(self):
|
||||
return self.model.topologies[self._topology]
|
||||
|
||||
def count_agents(self, **kwargs):
|
||||
return len(list(self.get_agents(**kwargs)))
|
||||
self.topology = topology
|
||||
self.node_id = node_id
|
||||
self.G = self.model.topologies[topology]
|
||||
assert self.G
|
||||
|
||||
def count_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
|
||||
@ -213,57 +243,47 @@ class NetworkAgent(BaseAgent):
|
||||
def get_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs)
|
||||
|
||||
def get_agents(self, *args, limit=None, **kwargs):
|
||||
it = self.iter_agents(*args, **kwargs)
|
||||
if limit is not None:
|
||||
it = islice(it, limit)
|
||||
return list(it)
|
||||
|
||||
def iter_agents(self, unique_id=None, limit_neighbors=False, **kwargs):
|
||||
unique_ids = None
|
||||
if isinstance(unique_id, list):
|
||||
unique_ids = set(unique_id)
|
||||
elif unique_id is not None:
|
||||
unique_ids = set([unique_id,])
|
||||
|
||||
if limit_neighbors:
|
||||
unique_id = [self.topology.nodes[node]['agent_id'] for node in self.topology.neighbors(self.node_id)]
|
||||
if not unique_id:
|
||||
neighbor_ids = set()
|
||||
for node_id in self.G.neighbors(self.node_id):
|
||||
if self.G.nodes[node_id].get('agent_id') is not None:
|
||||
neighbor_ids.add(node_id)
|
||||
if unique_ids:
|
||||
unique_ids = unique_ids & neighbor_ids
|
||||
else:
|
||||
unique_ids = neighbor_ids
|
||||
if not unique_ids:
|
||||
return
|
||||
|
||||
yield from self.model.agents(unique_id=unique_id, **kwargs)
|
||||
|
||||
unique_ids = list(unique_ids)
|
||||
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
G = self.topology.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
||||
G = self.G.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
||||
return G
|
||||
|
||||
def remove_node(self, unique_id):
|
||||
self.topology.remove_node(unique_id)
|
||||
def remove_node(self):
|
||||
self.G.remove_node(self.node_id)
|
||||
|
||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||
# return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs)
|
||||
if self.unique_id not in self.topology.nodes(data=False):
|
||||
if self.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
||||
if other.unique_id not in self.topology.nodes(data=False):
|
||||
if other.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
||||
|
||||
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
self.G.add_edge(self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
|
||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||
node = as_node(node if node is not None else self)
|
||||
G = self.subgraph(**kwargs)
|
||||
return nx.ego_graph(G, node, center=center, radius=steps).nodes()
|
||||
|
||||
def degree(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.model, '_degree')) or getattr(self.model, '_last_step', 0) < self.now:
|
||||
self.model._degree = nx.degree_centrality(self.topology)
|
||||
self.model._last_step = self.now
|
||||
return self.model._degree[node]
|
||||
|
||||
def betweenness(self, node, force=False):
|
||||
node = as_node(node)
|
||||
if force or (not hasattr(self.model, '_betweenness')) or getattr(self.model, '_last_step', 0) < self.now:
|
||||
self.model._betweenness = nx.betweenness_centrality(self.topology)
|
||||
self.model._last_step = self.now
|
||||
return self.model._betweenness[node]
|
||||
def die(self, remove=True):
|
||||
if remove:
|
||||
self.remove_node()
|
||||
return super().die()
|
||||
|
||||
|
||||
def state(name=None):
|
||||
@ -273,24 +293,29 @@ def state(name=None):
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the environment.
|
||||
'''
|
||||
if inspect.isgeneratorfunction(func):
|
||||
orig_func = func
|
||||
|
||||
@wraps(func)
|
||||
def func_wrapper(self):
|
||||
next_state = func(self)
|
||||
when = None
|
||||
if next_state is None:
|
||||
return when
|
||||
def func(self):
|
||||
while True:
|
||||
if not self._coroutine:
|
||||
self._coroutine = orig_func(self)
|
||||
try:
|
||||
next_state, when = next_state
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
if next_state:
|
||||
n = next(self._coroutine)
|
||||
if n:
|
||||
return None, n
|
||||
return
|
||||
except StopIteration as ex:
|
||||
self._coroutine = None
|
||||
next_state = ex.value
|
||||
if next_state is not None:
|
||||
self.set_state(next_state)
|
||||
return when
|
||||
return next_state
|
||||
|
||||
func_wrapper.id = name or func.__name__
|
||||
func_wrapper.is_default = False
|
||||
return func_wrapper
|
||||
func.id = name or func.__name__
|
||||
func.is_default = False
|
||||
return func
|
||||
|
||||
if callable(name):
|
||||
return decorator(name)
|
||||
@ -303,60 +328,84 @@ def default_state(func):
|
||||
return func
|
||||
|
||||
|
||||
class MetaFSM(ABCMeta):
|
||||
def __init__(cls, name, bases, nmspc):
|
||||
super(MetaFSM, cls).__init__(name, bases, nmspc)
|
||||
class MetaFSM(MetaAgent):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
states = {}
|
||||
# Re-use states from inherited classes
|
||||
default_state = None
|
||||
for i in bases:
|
||||
if isinstance(i, MetaFSM):
|
||||
for state_id, state in i.states.items():
|
||||
for state_id, state in i._states.items():
|
||||
if state.is_default:
|
||||
default_state = state
|
||||
states[state_id] = state
|
||||
|
||||
# Add new states
|
||||
for name, func in nmspc.items():
|
||||
for attr, func in namespace.items():
|
||||
if hasattr(func, 'id'):
|
||||
if func.is_default:
|
||||
default_state = func
|
||||
states[func.id] = func
|
||||
cls.default_state = default_state
|
||||
cls.states = states
|
||||
|
||||
namespace.update({
|
||||
'_default_state': default_state,
|
||||
'_states': states,
|
||||
})
|
||||
|
||||
return super(MetaFSM, mcls).__new__(mcls=mcls, name=name, bases=bases, namespace=namespace)
|
||||
|
||||
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FSM, self).__init__(*args, **kwargs)
|
||||
if not hasattr(self, 'state_id'):
|
||||
if not self.default_state:
|
||||
if not self._default_state:
|
||||
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
||||
self.state_id = self.default_state.id
|
||||
self.state_id = self._default_state.id
|
||||
|
||||
self._coroutine = None
|
||||
self.set_state(self.state_id)
|
||||
|
||||
def step(self):
|
||||
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
||||
interval = super().step()
|
||||
if 'id' not in self.state:
|
||||
if self.default_state:
|
||||
self.set_state(self.default_state.id)
|
||||
else:
|
||||
raise Exception('{} has no valid state id or default state'.format(self))
|
||||
interval = self.states[self.state_id](self) or interval
|
||||
if not self.alive:
|
||||
return time.NEVER
|
||||
return interval
|
||||
default_interval = super().step()
|
||||
|
||||
def set_state(self, state):
|
||||
next_state = self._states[self.state_id](self)
|
||||
|
||||
when = None
|
||||
try:
|
||||
next_state, *when = next_state
|
||||
if not when:
|
||||
when = None
|
||||
elif len(when) == 1:
|
||||
when = when[0]
|
||||
else:
|
||||
raise ValueError('Too many values returned. Only state (and time) allowed')
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
if next_state is not None:
|
||||
self.set_state(next_state)
|
||||
|
||||
return when or default_interval
|
||||
|
||||
def set_state(self, state, when=None):
|
||||
if hasattr(state, 'id'):
|
||||
state = state.id
|
||||
if state not in self.states:
|
||||
if state not in self._states:
|
||||
raise ValueError('{} is not a valid state'.format(state))
|
||||
self.state_id = state
|
||||
if when is not None:
|
||||
self.model.schedule.add(self, when=when)
|
||||
return state
|
||||
|
||||
def die(self):
|
||||
return self.dead, super().die()
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
return self.die()
|
||||
|
||||
|
||||
def prob(prob, random):
|
||||
'''
|
||||
@ -476,81 +525,81 @@ def _convert_agent_classs(ind, to_string=False, **kwargs):
|
||||
return deserialize_definition(ind, **kwargs)
|
||||
|
||||
|
||||
def _agent_from_definition(definition, random, value=-1, unique_id=None):
|
||||
"""Used in the initialization of agents given an agent distribution."""
|
||||
if value < 0:
|
||||
value = random.random()
|
||||
for d in sorted(definition, key=lambda x: x.get('threshold')):
|
||||
threshold = d.get('threshold', (-1, -1))
|
||||
# Check if the definition matches by id (first) or by threshold
|
||||
if (unique_id is not None and unique_id in d.get('ids', [])) or \
|
||||
(value >= threshold[0] and value < threshold[1]):
|
||||
state = {}
|
||||
if 'state' in d:
|
||||
state = deepcopy(d['state'])
|
||||
return d['agent_class'], state
|
||||
# def _agent_from_definition(definition, random, value=-1, unique_id=None):
|
||||
# """Used in the initialization of agents given an agent distribution."""
|
||||
# if value < 0:
|
||||
# value = random.random()
|
||||
# for d in sorted(definition, key=lambda x: x.get('threshold')):
|
||||
# threshold = d.get('threshold', (-1, -1))
|
||||
# # Check if the definition matches by id (first) or by threshold
|
||||
# if (unique_id is not None and unique_id in d.get('ids', [])) or \
|
||||
# (value >= threshold[0] and value < threshold[1]):
|
||||
# state = {}
|
||||
# if 'state' in d:
|
||||
# state = deepcopy(d['state'])
|
||||
# return d['agent_class'], state
|
||||
|
||||
raise Exception('Definition for value {} not found in: {}'.format(value, definition))
|
||||
# raise Exception('Definition for value {} not found in: {}'.format(value, definition))
|
||||
|
||||
|
||||
def _definition_to_dict(definition, random, size=None, default_state=None):
|
||||
state = default_state or {}
|
||||
agents = {}
|
||||
remaining = {}
|
||||
if size:
|
||||
for ix in range(size):
|
||||
remaining[ix] = copy(state)
|
||||
else:
|
||||
remaining = defaultdict(lambda x: copy(state))
|
||||
# def _definition_to_dict(definition, random, size=None, default_state=None):
|
||||
# state = default_state or {}
|
||||
# agents = {}
|
||||
# remaining = {}
|
||||
# if size:
|
||||
# for ix in range(size):
|
||||
# remaining[ix] = copy(state)
|
||||
# else:
|
||||
# remaining = defaultdict(lambda x: copy(state))
|
||||
|
||||
distro = sorted([item for item in definition if 'weight' in item])
|
||||
# distro = sorted([item for item in definition if 'weight' in item])
|
||||
|
||||
id = 0
|
||||
# id = 0
|
||||
|
||||
def init_agent(item, id=ix):
|
||||
while id in agents:
|
||||
id += 1
|
||||
# def init_agent(item, id=ix):
|
||||
# while id in agents:
|
||||
# id += 1
|
||||
|
||||
agent = remaining[id]
|
||||
agent['state'].update(copy(item.get('state', {})))
|
||||
agents[agent.unique_id] = agent
|
||||
del remaining[id]
|
||||
return agent
|
||||
# agent = remaining[id]
|
||||
# agent['state'].update(copy(item.get('state', {})))
|
||||
# agents[agent.unique_id] = agent
|
||||
# del remaining[id]
|
||||
# return agent
|
||||
|
||||
for item in definition:
|
||||
if 'ids' in item:
|
||||
ids = item['ids']
|
||||
del item['ids']
|
||||
for id in ids:
|
||||
agent = init_agent(item, id)
|
||||
# for item in definition:
|
||||
# if 'ids' in item:
|
||||
# ids = item['ids']
|
||||
# del item['ids']
|
||||
# for id in ids:
|
||||
# agent = init_agent(item, id)
|
||||
|
||||
for item in definition:
|
||||
if 'number' in item:
|
||||
times = item['number']
|
||||
del item['number']
|
||||
for times in range(times):
|
||||
if size:
|
||||
ix = random.choice(remaining.keys())
|
||||
agent = init_agent(item, id)
|
||||
else:
|
||||
agent = init_agent(item)
|
||||
if not size:
|
||||
return agents
|
||||
# for item in definition:
|
||||
# if 'number' in item:
|
||||
# times = item['number']
|
||||
# del item['number']
|
||||
# for times in range(times):
|
||||
# if size:
|
||||
# ix = random.choice(remaining.keys())
|
||||
# agent = init_agent(item, id)
|
||||
# else:
|
||||
# agent = init_agent(item)
|
||||
# if not size:
|
||||
# return agents
|
||||
|
||||
if len(remaining) < 0:
|
||||
raise Exception('Invalid definition. Too many agents to add')
|
||||
# if len(remaining) < 0:
|
||||
# raise Exception('Invalid definition. Too many agents to add')
|
||||
|
||||
|
||||
total_weight = float(sum(s['weight'] for s in distro))
|
||||
unit = size / total_weight
|
||||
# total_weight = float(sum(s['weight'] for s in distro))
|
||||
# unit = size / total_weight
|
||||
|
||||
for item in distro:
|
||||
times = unit * item['weight']
|
||||
del item['weight']
|
||||
for times in range(times):
|
||||
ix = random.choice(remaining.keys())
|
||||
agent = init_agent(item, id)
|
||||
return agents
|
||||
# for item in distro:
|
||||
# times = unit * item['weight']
|
||||
# del item['weight']
|
||||
# for times in range(times):
|
||||
# ix = random.choice(remaining.keys())
|
||||
# agent = init_agent(item, id)
|
||||
# return agents
|
||||
|
||||
|
||||
class AgentView(Mapping, Set):
|
||||
@ -571,59 +620,43 @@ class AgentView(Mapping, Set):
|
||||
|
||||
# Mapping methods
|
||||
def __len__(self):
|
||||
return sum(len(x) for x in self._agents.values())
|
||||
return len(self._agents)
|
||||
|
||||
def __iter__(self):
|
||||
yield from iter(chain.from_iterable(g.values() for g in self._agents.values()))
|
||||
yield from self._agents.values()
|
||||
|
||||
def __getitem__(self, agent_id):
|
||||
if isinstance(agent_id, slice):
|
||||
raise ValueError(f"Slicing is not supported")
|
||||
for group in self._agents.values():
|
||||
if agent_id in group:
|
||||
return group[agent_id]
|
||||
if agent_id in self._agents:
|
||||
return self._agents[agent_id]
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
def filter(self, *args, **kwargs):
|
||||
yield from filter_groups(self._agents, *args, **kwargs)
|
||||
yield from filter_agents(self._agents, *args, **kwargs)
|
||||
|
||||
def one(self, *args, **kwargs):
|
||||
return next(filter_groups(self._agents, *args, **kwargs))
|
||||
return next(filter_agents(self._agents, *args, **kwargs))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return list(self.filter(*args, **kwargs))
|
||||
|
||||
def __contains__(self, agent_id):
|
||||
return any(agent_id in g for g in self._agents)
|
||||
return agent_id in self._agents
|
||||
|
||||
def __str__(self):
|
||||
return str(list(a.unique_id for a in self))
|
||||
return str(list(unique_id for unique_id in self.keys()))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def filter_groups(groups, *, group=None, **kwargs):
|
||||
assert isinstance(groups, dict)
|
||||
|
||||
if group is not None and not isinstance(group, list):
|
||||
group = [group]
|
||||
|
||||
if group:
|
||||
groups = list(groups[g] for g in group if g in groups)
|
||||
else:
|
||||
groups = list(groups.values())
|
||||
|
||||
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
|
||||
|
||||
yield from agents
|
||||
|
||||
|
||||
def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None, **kwargs):
|
||||
def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None,
|
||||
limit=None, **kwargs):
|
||||
'''
|
||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||
'''
|
||||
assert isinstance(group, dict)
|
||||
assert isinstance(agents, dict)
|
||||
|
||||
ids = []
|
||||
|
||||
@ -636,6 +669,11 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
|
||||
if id_args:
|
||||
ids += id_args
|
||||
|
||||
if ids:
|
||||
f = (agents[aid] for aid in ids if aid in agents)
|
||||
else:
|
||||
f = (a for a in agents.values())
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
|
||||
@ -646,12 +684,6 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
|
||||
except TypeError:
|
||||
agent_class = tuple([agent_class])
|
||||
|
||||
if ids:
|
||||
agents = (group[aid] for aid in ids if aid in group)
|
||||
else:
|
||||
agents = (a for a in group.values())
|
||||
|
||||
f = agents
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
@ -667,83 +699,125 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non
|
||||
for k, v in state.items():
|
||||
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
||||
|
||||
if limit is not None:
|
||||
f = islice(f, limit)
|
||||
|
||||
yield from f
|
||||
|
||||
|
||||
def from_config(cfg: Dict[str, config.AgentConfig], env, random):
|
||||
def from_config(cfg: config.AgentConfig, random, topologies: Dict[str, nx.Graph] = None) -> List[Dict[str, Any]]:
|
||||
'''
|
||||
Agents are specified in groups.
|
||||
Each group can be specified in two ways, either through a fixed list in which each item has
|
||||
has the agent type, number of agents to create, and the other parameters, or through what we call
|
||||
an `agent distribution`, which is similar but instead of number of agents, it specifies the weight
|
||||
of each agent type.
|
||||
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
||||
with the parameters that the environment will use to construct each agent.
|
||||
|
||||
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
||||
time of calling this function, such as `unique_id`.
|
||||
'''
|
||||
default = cfg.get('default', None)
|
||||
return {k: _group_from_config(c, default=default, env=env, random=random) for (k, c) in cfg.items() if k is not 'default'}
|
||||
default = cfg or config.AgentConfig()
|
||||
if not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
return _agents_from_config(cfg, topologies=topologies, random=random)
|
||||
|
||||
|
||||
def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfig, env, random):
|
||||
def _agents_from_config(cfg: config.AgentConfig,
|
||||
topologies: Dict[str, nx.Graph],
|
||||
random) -> List[Dict[str, Any]]:
|
||||
if cfg and not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
if default and not isinstance(default, config.SingleAgentConfig):
|
||||
default = config.SingleAgentConfig(**default)
|
||||
|
||||
agents = {}
|
||||
agents = []
|
||||
|
||||
assigned = defaultdict(int)
|
||||
|
||||
if cfg.fixed is not None:
|
||||
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
||||
if cfg.distribution:
|
||||
n = cfg.n or len(env.topologies[cfg.topology or default.topology])
|
||||
target = n - len(agents)
|
||||
agents.update(_from_distro(cfg.distribution, target,
|
||||
topology=cfg.topology or default.topology,
|
||||
default=default,
|
||||
env=env, random=random))
|
||||
assert len(agents) == n
|
||||
if cfg.override:
|
||||
for attrs in cfg.override:
|
||||
if attrs.filter:
|
||||
filtered = list(filter_group(agents, **attrs.filter))
|
||||
else:
|
||||
filtered = list(agents)
|
||||
agents, counts = _from_fixed(cfg.fixed, topology=cfg.topology, default=cfg)
|
||||
assigned.update(counts)
|
||||
|
||||
if attrs.n > len(filtered):
|
||||
raise ValueError(f'Not enough agents to sample. Got {len(filtered)}, expected >= {attrs.n}')
|
||||
for agent in random.sample(filtered, attrs.n):
|
||||
agent.state.update(attrs.state)
|
||||
n = cfg.n
|
||||
|
||||
if cfg.distribution:
|
||||
topo_size = {top: len(topologies[top]) for top in topologies}
|
||||
|
||||
grouped = defaultdict(list)
|
||||
total = []
|
||||
|
||||
for d in cfg.distribution:
|
||||
if d.strategy == config.Strategy.topology:
|
||||
topology = d.topology if ('topology' in d.__fields_set__) else cfg.topology
|
||||
if not topology:
|
||||
raise ValueError('The "topology" strategy only works if the topology parameter is specified')
|
||||
if topology not in topo_size:
|
||||
raise ValueError(f'Unknown topology selected: { topology }. Make sure the topology has been defined')
|
||||
|
||||
grouped[topology].append(d)
|
||||
|
||||
if d.strategy == config.Strategy.total:
|
||||
if not cfg.n:
|
||||
raise ValueError('Cannot use the "total" strategy without providing the total number of agents')
|
||||
total.append(d)
|
||||
|
||||
|
||||
for (topo, distro) in grouped.items():
|
||||
if not topologies or topo not in topo_size:
|
||||
raise ValueError(
|
||||
'You need to specify a target number of agents for the distribution \
|
||||
or a configuration with a topology, along with a dictionary with \
|
||||
all the available topologies')
|
||||
n = len(topologies[topo])
|
||||
target = topo_size[topo] - assigned[topo]
|
||||
new_agents = _from_distro(cfg.distribution, target,
|
||||
topology=topo,
|
||||
default=cfg,
|
||||
random=random)
|
||||
assigned[topo] += len(new_agents)
|
||||
agents += new_agents
|
||||
|
||||
if total:
|
||||
remaining = n - sum(assigned.values())
|
||||
agents += _from_distro(total, remaining,
|
||||
topology='', # DO NOT assign to any topology
|
||||
default=cfg,
|
||||
random=random)
|
||||
|
||||
|
||||
if sum(assigned.values()) != sum(topo_size.values()):
|
||||
utils.logger.warn(f'The total number of agents does not match the total number of nodes in '
|
||||
'every topology. This may be due to a definition error: assigned: '
|
||||
f'{ assigned } total sizes: { topo_size }')
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig, env):
|
||||
agents = {}
|
||||
def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig) -> List[Dict[str, Any]]:
|
||||
agents = []
|
||||
|
||||
counts = {}
|
||||
|
||||
for fixed in lst:
|
||||
agent_id = fixed.agent_id
|
||||
if agent_id is None:
|
||||
agent_id = env.next_id()
|
||||
agent = {}
|
||||
if default:
|
||||
agent = default.state.copy()
|
||||
agent.update(fixed.state)
|
||||
cls = serialization.deserialize(fixed.agent_class or (default and default.agent_class))
|
||||
agent['agent_class'] = cls
|
||||
topo = fixed.topology if ('topology' in fixed.__fields_set__) else topology or default.topology
|
||||
|
||||
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
||||
state = fixed.state.copy()
|
||||
state.update(default.state)
|
||||
agent = cls(unique_id=agent_id,
|
||||
model=env,
|
||||
**state)
|
||||
topology = fixed.topology if (fixed.topology is not None) else (topology or default.topology)
|
||||
if topology:
|
||||
env.agent_to_node(agent_id, topology, fixed.node_id)
|
||||
agents[agent.unique_id] = agent
|
||||
if topo:
|
||||
agent['topology'] = topo
|
||||
if not fixed.hidden:
|
||||
counts[topo] = counts.get(topo, 0) + 1
|
||||
agents.append(agent)
|
||||
|
||||
return agents
|
||||
return agents, counts
|
||||
|
||||
|
||||
def _from_distro(distro: List[config.AgentDistro],
|
||||
n: int,
|
||||
topology: str,
|
||||
default: config.SingleAgentConfig,
|
||||
env,
|
||||
random):
|
||||
random) -> List[Dict[str, Any]]:
|
||||
|
||||
agents = {}
|
||||
agents = []
|
||||
|
||||
if n is None:
|
||||
if any(lambda dist: dist.n is None, distro):
|
||||
@ -775,19 +849,16 @@ def _from_distro(distro: List[config.AgentDistro],
|
||||
|
||||
for idx in indices:
|
||||
d = distro[idx]
|
||||
agent = d.state.copy()
|
||||
cls = classes[idx]
|
||||
agent_id = env.next_id()
|
||||
state = d.state.copy()
|
||||
agent['agent_class'] = cls
|
||||
if default:
|
||||
state.update(default.state)
|
||||
agent = cls(unique_id=agent_id, model=env, **state)
|
||||
topology = d.topology if (d.topology is not None) else topology or default.topology
|
||||
agent.update(default.state)
|
||||
# agent = cls(unique_id=agent_id, model=env, **state)
|
||||
topology = d.topology if ('topology' in d.__fields_set__) else topology or default.topology
|
||||
if topology:
|
||||
env.agent_to_node(agent.unique_id, topology)
|
||||
assert agent.name is not None
|
||||
assert agent.name != 'None'
|
||||
assert agent.name
|
||||
agents[agent.unique_id] = agent
|
||||
agent['topology'] = topology
|
||||
agents.append(agent)
|
||||
|
||||
return agents
|
||||
|
||||
|
206
soil/analysis.py
206
soil/analysis.py
@ -1,206 +0,0 @@
|
||||
import pandas as pd
|
||||
|
||||
import glob
|
||||
import yaml
|
||||
from os.path import join
|
||||
|
||||
from . import serialization
|
||||
from tsih import History
|
||||
|
||||
|
||||
def read_data(*args, group=False, **kwargs):
|
||||
iterable = _read_data(*args, **kwargs)
|
||||
if group:
|
||||
return group_trials(iterable)
|
||||
else:
|
||||
return list(iterable)
|
||||
|
||||
|
||||
def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
||||
if not process_args:
|
||||
process_args = {}
|
||||
for folder in glob.glob(pattern):
|
||||
config_file = glob.glob(join(folder, '*.yml'))[0]
|
||||
config = yaml.load(open(config_file), Loader=yaml.SafeLoader)
|
||||
df = None
|
||||
if from_csv:
|
||||
for trial_data in sorted(glob.glob(join(folder,
|
||||
'*.environment.csv'))):
|
||||
df = read_csv(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
else:
|
||||
for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))):
|
||||
df = read_sql(trial_data, **kwargs)
|
||||
yield config_file, df, config
|
||||
|
||||
|
||||
def read_sql(db, *args, **kwargs):
|
||||
h = History(db_path=db, backup=False, readonly=True)
|
||||
df = h.read_sql(*args, **kwargs)
|
||||
return df
|
||||
|
||||
|
||||
def read_csv(filename, keys=None, convert_types=False, **kwargs):
|
||||
'''
|
||||
Read a CSV in canonical form: ::
|
||||
|
||||
<agent_id, t_step, key, value, value_type>
|
||||
|
||||
'''
|
||||
df = pd.read_csv(filename)
|
||||
if convert_types:
|
||||
df = convert_types_slow(df)
|
||||
if keys:
|
||||
df = df[df['key'].isin(keys)]
|
||||
df = process_one(df)
|
||||
return df
|
||||
|
||||
|
||||
def convert_row(row):
|
||||
row['value'] = serialization.deserialize(row['value_type'], row['value'])
|
||||
return row
|
||||
|
||||
|
||||
def convert_types_slow(df):
|
||||
'''
|
||||
Go over every column in a dataframe and convert it to the type determined by the `get_types`
|
||||
function.
|
||||
|
||||
This is a slow operation.
|
||||
'''
|
||||
dtypes = get_types(df)
|
||||
for k, v in dtypes.items():
|
||||
t = df[df['key']==k]
|
||||
t['value'] = t['value'].astype(v)
|
||||
df = df.apply(convert_row, axis=1)
|
||||
return df
|
||||
|
||||
|
||||
def split_processed(df):
|
||||
env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])]
|
||||
agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])]
|
||||
return env, agents
|
||||
|
||||
|
||||
def split_df(df):
|
||||
'''
|
||||
Split a dataframe in two dataframes: one with the history of agents,
|
||||
and one with the environment history
|
||||
'''
|
||||
envmask = (df['agent_id'] == 'env')
|
||||
n_env = envmask.sum()
|
||||
if n_env == len(df):
|
||||
return df, None
|
||||
elif n_env == 0:
|
||||
return None, df
|
||||
agents, env = [x for _, x in df.groupby(envmask)]
|
||||
return env, agents
|
||||
|
||||
|
||||
def process(df, **kwargs):
|
||||
'''
|
||||
Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into
|
||||
two dataframes with a column per key: one with the history of the agents, and one for the
|
||||
history of the environment.
|
||||
'''
|
||||
env, agents = split_df(df)
|
||||
return process_one(env, **kwargs), process_one(agents, **kwargs)
|
||||
|
||||
|
||||
def get_types(df):
|
||||
'''
|
||||
Get the value type for every key stored in a raw history dataframe.
|
||||
'''
|
||||
dtypes = df.groupby(by=['key'])['value_type'].unique()
|
||||
return {k:v[0] for k,v in dtypes.iteritems()}
|
||||
|
||||
|
||||
def process_one(df, *keys, columns=['key', 'agent_id'], values='value',
|
||||
fill=True, index=['t_step',],
|
||||
aggfunc='first', **kwargs):
|
||||
'''
|
||||
Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into
|
||||
a dataframe with a column per key
|
||||
'''
|
||||
if df is None:
|
||||
return df
|
||||
if keys:
|
||||
df = df[df['key'].isin(keys)]
|
||||
|
||||
df = df.pivot_table(values=values, index=index, columns=columns,
|
||||
aggfunc=aggfunc, **kwargs)
|
||||
if fill:
|
||||
df = fillna(df)
|
||||
return df
|
||||
|
||||
|
||||
def get_count(df, *keys):
|
||||
'''
|
||||
For every t_step and key, get the value count.
|
||||
|
||||
The result is a dataframe with `t_step` as index, an a multiindex column based on `key` and the values found for each `key`.
|
||||
'''
|
||||
if keys:
|
||||
df = df[list(keys)]
|
||||
df.columns = df.columns.remove_unused_levels()
|
||||
counts = pd.DataFrame()
|
||||
for key in df.columns.levels[0]:
|
||||
g = df[[key]].apply(pd.Series.value_counts, axis=1).fillna(0)
|
||||
for value, series in g.iteritems():
|
||||
counts[key, value] = series
|
||||
counts.columns = pd.MultiIndex.from_tuples(counts.columns)
|
||||
return counts
|
||||
|
||||
|
||||
def get_majority(df, *keys):
|
||||
'''
|
||||
For every t_step and key, get the value of the majority of agents
|
||||
|
||||
The result is a dataframe with `t_step` as index, and columns based on `key`.
|
||||
'''
|
||||
df = get_count(df, *keys)
|
||||
return df.stack(level=0).idxmax(axis=1).unstack()
|
||||
|
||||
|
||||
def get_value(df, *keys, aggfunc='sum'):
|
||||
'''
|
||||
For every t_step and key, get the value of *numeric columns*, aggregated using a specific function.
|
||||
'''
|
||||
if keys:
|
||||
df = df[list(keys)]
|
||||
df.columns = df.columns.remove_unused_levels()
|
||||
df = df.select_dtypes('number')
|
||||
return df.groupby(level='key', axis=1).agg(aggfunc)
|
||||
|
||||
|
||||
def plot_all(*args, plot_args={}, **kwargs):
|
||||
'''
|
||||
Read all the trial data and plot the result of applying a function on them.
|
||||
'''
|
||||
dfs = do_all(*args, **kwargs)
|
||||
ps = []
|
||||
for line in dfs:
|
||||
f, df, config = line
|
||||
if len(df) < 1:
|
||||
continue
|
||||
df.plot(title=config['name'], **plot_args)
|
||||
ps.append(df)
|
||||
return ps
|
||||
|
||||
def do_all(pattern, func, *keys, include_env=False, **kwargs):
|
||||
for config_file, df, config in read_data(pattern, keys=keys):
|
||||
if len(df) < 1:
|
||||
continue
|
||||
p = func(df, *keys, **kwargs)
|
||||
yield config_file, p, config
|
||||
|
||||
|
||||
def group_trials(trials, aggfunc=['mean', 'min', 'max', 'std']):
|
||||
trials = list(trials)
|
||||
trials = list(map(lambda x: x[1] if isinstance(x, tuple) else x, trials))
|
||||
return pd.concat(trials).groupby(level=0).agg(aggfunc).reorder_levels([2, 0,1] ,axis=1)
|
||||
|
||||
|
||||
def fillna(df):
|
||||
new_df = df.ffill(axis=0)
|
||||
return new_df
|
167
soil/config.py
167
soil/config.py
@ -1,12 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, ValidationError, validator, root_validator
|
||||
|
||||
import yaml
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, Type
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from . import environment, utils
|
||||
|
||||
import networkx as nx
|
||||
|
||||
|
||||
@ -36,7 +42,6 @@ class NetParams(BaseModel, extra=Extra.allow):
|
||||
|
||||
|
||||
class NetConfig(BaseModel):
|
||||
group: str = 'network'
|
||||
params: Optional[NetParams]
|
||||
topology: Optional[Union[Topology, nx.Graph]]
|
||||
path: Optional[str]
|
||||
@ -56,9 +61,6 @@ class NetConfig(BaseModel):
|
||||
|
||||
|
||||
class EnvConfig(BaseModel):
|
||||
environment_class: Union[Type, str] = 'soil.Environment'
|
||||
params: Dict[str, Any] = {}
|
||||
schedule: Union[Type, str] = 'soil.time.TimedActivation'
|
||||
|
||||
@staticmethod
|
||||
def default():
|
||||
@ -67,19 +69,19 @@ class EnvConfig(BaseModel):
|
||||
|
||||
class SingleAgentConfig(BaseModel):
|
||||
agent_class: Optional[Union[Type, str]] = None
|
||||
agent_id: Optional[int] = None
|
||||
unique_id: Optional[int] = None
|
||||
topology: Optional[str] = None
|
||||
node_id: Optional[Union[int, str]] = None
|
||||
name: Optional[str] = None
|
||||
state: Optional[Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class FixedAgentConfig(SingleAgentConfig):
|
||||
n: Optional[int] = 1
|
||||
hidden: Optional[bool] = False # Do not count this agent towards total agent count
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if values.get('agent_id', None) is not None and values.get('n', 1) > 1:
|
||||
print(values)
|
||||
raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)")
|
||||
return values
|
||||
|
||||
@ -88,13 +90,19 @@ class OverrideAgentConfig(FixedAgentConfig):
|
||||
filter: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class Strategy(Enum):
|
||||
topology = 'topology'
|
||||
total = 'total'
|
||||
|
||||
|
||||
class AgentDistro(SingleAgentConfig):
|
||||
weight: Optional[float] = 1
|
||||
strategy: Strategy = Strategy.topology
|
||||
|
||||
|
||||
class AgentConfig(SingleAgentConfig):
|
||||
n: Optional[int] = None
|
||||
topology: Optional[str] = None
|
||||
topology: Optional[str]
|
||||
distribution: Optional[List[AgentDistro]] = None
|
||||
fixed: Optional[List[FixedAgentConfig]] = None
|
||||
override: Optional[List[OverrideAgentConfig]] = None
|
||||
@ -110,19 +118,32 @@ class AgentConfig(SingleAgentConfig):
|
||||
return values
|
||||
|
||||
|
||||
class Config(BaseModel, extra=Extra.forbid):
|
||||
class Config(BaseModel, extra=Extra.allow):
|
||||
version: Optional[str] = '1'
|
||||
|
||||
id: str = 'Unnamed Simulation'
|
||||
name: str = 'Unnamed Simulation'
|
||||
description: Optional[str] = None
|
||||
group: str = None
|
||||
dir_path: Optional[str] = None
|
||||
num_trials: int = 1
|
||||
max_time: float = 100
|
||||
max_steps: int = -1
|
||||
interval: float = 1
|
||||
seed: str = ""
|
||||
dry_run: bool = False
|
||||
|
||||
model_class: Union[Type, str]
|
||||
model_parameters: Optiona[Dict[str, Any]] = {}
|
||||
model_class: Union[Type, str] = environment.Environment
|
||||
model_params: Optional[Dict[str, Any]] = {}
|
||||
|
||||
visualization_params: Optional[Dict[str, Any]] = {}
|
||||
|
||||
@classmethod
|
||||
def from_raw(cls, cfg):
|
||||
if isinstance(cfg, Config):
|
||||
return cfg
|
||||
if cfg.get('version', '1') == '1' and any(k in cfg for k in ['agents', 'agent_class', 'topology', 'environment_class']):
|
||||
return convert_old(cfg)
|
||||
return Config(**cfg)
|
||||
|
||||
|
||||
def convert_old(old, strict=True):
|
||||
@ -132,87 +153,84 @@ def convert_old(old, strict=True):
|
||||
This is still a work in progress and might not work in many cases.
|
||||
'''
|
||||
|
||||
#TODO: implement actual conversion
|
||||
print('The old configuration format is no longer supported. \
|
||||
Update your config files or run Soil==0.20')
|
||||
raise NotImplementedError()
|
||||
utils.logger.warning('The old configuration format is deprecated. The converted file MAY NOT yield the right results')
|
||||
|
||||
|
||||
new = {}
|
||||
|
||||
general = {}
|
||||
for k in ['id',
|
||||
'group',
|
||||
'dir_path',
|
||||
'num_trials',
|
||||
'max_time',
|
||||
'interval',
|
||||
'seed']:
|
||||
if k in old:
|
||||
general[k] = old[k]
|
||||
|
||||
if 'name' in old:
|
||||
general['id'] = old['name']
|
||||
new = old.copy()
|
||||
|
||||
network = {}
|
||||
|
||||
if 'topology' in old:
|
||||
del new['topology']
|
||||
network['topology'] = old['topology']
|
||||
|
||||
if 'network_params' in old and old['network_params']:
|
||||
del new['network_params']
|
||||
for (k, v) in old['network_params'].items():
|
||||
if k == 'path':
|
||||
network['path'] = v
|
||||
else:
|
||||
network.setdefault('params', {})[k] = v
|
||||
|
||||
if 'topology' in old:
|
||||
network['topology'] = old['topology']
|
||||
topologies = {}
|
||||
if network:
|
||||
topologies['default'] = network
|
||||
|
||||
agents = {
|
||||
'network': {},
|
||||
'default': {},
|
||||
}
|
||||
|
||||
if 'agent_class' in old:
|
||||
agents['default']['agent_class'] = old['agent_class']
|
||||
|
||||
if 'default_state' in old:
|
||||
agents['default']['state'] = old['default_state']
|
||||
|
||||
agents = {'fixed': [], 'distribution': []}
|
||||
|
||||
def updated_agent(agent):
|
||||
'''Convert an agent definition'''
|
||||
newagent = dict(agent)
|
||||
newagent['agent_class'] = newagent['agent_class']
|
||||
del newagent['agent_class']
|
||||
return newagent
|
||||
|
||||
for agent in old.get('environment_agents', []):
|
||||
agents['environment'] = {'distribution': [], 'fixed': []}
|
||||
if 'agent_id' in agent:
|
||||
agent['name'] = agent['agent_id']
|
||||
del agent['agent_id']
|
||||
agents['environment']['fixed'].append(updated_agent(agent))
|
||||
|
||||
by_weight = []
|
||||
fixed = []
|
||||
override = []
|
||||
|
||||
if 'network_agents' in old:
|
||||
agents['network']['topology'] = 'default'
|
||||
if 'environment_agents' in new:
|
||||
|
||||
for agent in old['network_agents']:
|
||||
for agent in new['environment_agents']:
|
||||
agent.setdefault('state', {})['group'] = 'environment'
|
||||
if 'agent_id' in agent:
|
||||
agent['state']['name'] = agent['agent_id']
|
||||
del agent['agent_id']
|
||||
agent['hidden'] = True
|
||||
agent['topology'] = None
|
||||
fixed.append(updated_agent(agent))
|
||||
del new['environment_agents']
|
||||
|
||||
|
||||
if 'agent_class' in old:
|
||||
del new['agent_class']
|
||||
agents['agent_class'] = old['agent_class']
|
||||
|
||||
if 'default_state' in old:
|
||||
del new['default_state']
|
||||
agents['state'] = old['default_state']
|
||||
|
||||
if 'network_agents' in old:
|
||||
agents['topology'] = 'default'
|
||||
|
||||
agents.setdefault('state', {})['group'] = 'network'
|
||||
|
||||
for agent in new['network_agents']:
|
||||
agent = updated_agent(agent)
|
||||
if 'agent_id' in agent:
|
||||
agent['state']['name'] = agent['agent_id']
|
||||
del agent['agent_id']
|
||||
fixed.append(agent)
|
||||
else:
|
||||
by_weight.append(agent)
|
||||
del new['network_agents']
|
||||
|
||||
if 'agent_class' in old and (not fixed and not by_weight):
|
||||
agents['network']['topology'] = 'default'
|
||||
by_weight = [{'agent_class': old['agent_class']}]
|
||||
agents['topology'] = 'default'
|
||||
by_weight = [{'agent_class': old['agent_class'], 'weight': 1}]
|
||||
|
||||
|
||||
# TODO: translate states properly
|
||||
if 'states' in old:
|
||||
del new['states']
|
||||
states = old['states']
|
||||
if isinstance(states, dict):
|
||||
states = states.items()
|
||||
@ -220,22 +238,29 @@ def convert_old(old, strict=True):
|
||||
states = enumerate(states)
|
||||
for (k, v) in states:
|
||||
override.append({'filter': {'node_id': k},
|
||||
'state': v
|
||||
})
|
||||
'state': v})
|
||||
|
||||
agents['network']['override'] = override
|
||||
agents['network']['fixed'] = fixed
|
||||
agents['network']['distribution'] = by_weight
|
||||
agents['override'] = override
|
||||
agents['fixed'] = fixed
|
||||
agents['distribution'] = by_weight
|
||||
|
||||
|
||||
model_params = {}
|
||||
if 'environment_params' in new:
|
||||
del new['environment_params']
|
||||
model_params = dict(old['environment_params'])
|
||||
|
||||
environment = {'params': {}}
|
||||
if 'environment_class' in old:
|
||||
environment['environment_class'] = old['environment_class']
|
||||
del new['environment_class']
|
||||
new['model_class'] = old['environment_class']
|
||||
|
||||
for (k, v) in old.get('environment_params', {}).items():
|
||||
environment['params'][k] = v
|
||||
if 'dump' in old:
|
||||
del new['dump']
|
||||
new['dry_run'] = not old['dump']
|
||||
|
||||
model_params['topologies'] = topologies
|
||||
model_params['agents'] = agents
|
||||
|
||||
return Config(version='2',
|
||||
general=general,
|
||||
topologies={'default': network},
|
||||
environment=environment,
|
||||
agents=agents)
|
||||
model_params=model_params,
|
||||
**new)
|
||||
|
151
soil/debugging.py
Normal file
151
soil/debugging.py
Normal file
@ -0,0 +1,151 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pdb
|
||||
import sys
|
||||
import os
|
||||
|
||||
from textwrap import indent
|
||||
from functools import wraps
|
||||
|
||||
from .agents import FSM, MetaFSM
|
||||
|
||||
|
||||
def wrapcmd(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, arg: str, temporary=False):
|
||||
sys.settrace(self.trace_dispatch)
|
||||
|
||||
known = globals()
|
||||
known.update(self.curframe.f_globals)
|
||||
known.update(self.curframe.f_locals)
|
||||
known['agent'] = known.get('self', None)
|
||||
known['model'] = known.get('self', {}).get('model')
|
||||
known['attrs'] = arg.strip().split()
|
||||
|
||||
exec(func.__code__, known, known)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Debug(pdb.Pdb):
|
||||
def __init__(self, *args, skip_soil=False, **kwargs):
|
||||
skip = kwargs.get('skip', [])
|
||||
if skip_soil:
|
||||
skip.append('soil.*')
|
||||
skip.append('mesa.*')
|
||||
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
||||
self.prompt = "[soil-pdb] "
|
||||
|
||||
@staticmethod
|
||||
def _soil_agents(model, attrs=None, pretty=True, **kwargs):
|
||||
for agent in model.agents(**kwargs):
|
||||
d = agent
|
||||
print(' - ' + indent(agent.to_str(keys=attrs, pretty=pretty), ' '))
|
||||
|
||||
@wrapcmd
|
||||
def do_soil_agents():
|
||||
return Debug._soil_agents(model, attrs=attrs or None)
|
||||
|
||||
do_sa = do_soil_agents
|
||||
|
||||
@wrapcmd
|
||||
def do_soil_list():
|
||||
return Debug._soil_agents(model, attrs=['state_id'], pretty=False)
|
||||
|
||||
do_sl = do_soil_list
|
||||
|
||||
@wrapcmd
|
||||
def do_soil_self():
|
||||
if not agent:
|
||||
print('No agent available')
|
||||
return
|
||||
|
||||
keys = None
|
||||
if attrs:
|
||||
keys = []
|
||||
for k in attrs:
|
||||
for key in agent.keys():
|
||||
if key.startswith(k):
|
||||
keys.append(key)
|
||||
|
||||
print(agent.to_str(pretty=True, keys=keys))
|
||||
|
||||
do_ss = do_soil_self
|
||||
|
||||
def do_break_state(self, arg: str, temporary=False):
|
||||
'''
|
||||
Break before a specified state is stepped into.
|
||||
'''
|
||||
|
||||
klass = None
|
||||
state = arg.strip()
|
||||
if not state:
|
||||
self.error("Specify at least a state name")
|
||||
return
|
||||
|
||||
comma = arg.find(':')
|
||||
if comma > 0:
|
||||
state = arg[comma+1:].lstrip()
|
||||
klass = arg[:comma].rstrip()
|
||||
klass = eval(klass,
|
||||
self.curframe.f_globals,
|
||||
self.curframe_locals)
|
||||
|
||||
if klass:
|
||||
klasses = [klass]
|
||||
else:
|
||||
klasses = [k for k in self.curframe.f_globals.values() if isinstance(k, type) and issubclass(k, FSM)]
|
||||
print(klasses)
|
||||
if not klasses:
|
||||
self.error('No agent classes found')
|
||||
|
||||
for klass in klasses:
|
||||
try:
|
||||
func = getattr(klass, state)
|
||||
except AttributeError:
|
||||
continue
|
||||
if hasattr(func, '__func__'):
|
||||
func = func.__func__
|
||||
|
||||
code = func.__code__
|
||||
#use co_name to identify the bkpt (function names
|
||||
#could be aliased, but co_name is invariant)
|
||||
funcname = code.co_name
|
||||
lineno = code.co_firstlineno
|
||||
filename = code.co_filename
|
||||
|
||||
# Check for reasonable breakpoint
|
||||
line = self.checkline(filename, lineno)
|
||||
if not line:
|
||||
raise ValueError('no line found')
|
||||
# now set the break point
|
||||
cond = None
|
||||
existing = self.get_breaks(filename, line)
|
||||
if existing:
|
||||
self.message("Breakpoint already exists at %s:%d" %
|
||||
(filename, line))
|
||||
continue
|
||||
err = self.set_break(filename, line, temporary, cond, funcname)
|
||||
if err:
|
||||
self.error(err)
|
||||
else:
|
||||
bp = self.get_breaks(filename, line)[-1]
|
||||
self.message("Breakpoint %d at %s:%d" %
|
||||
(bp.number, bp.file, bp.line))
|
||||
do_bs = do_break_state
|
||||
|
||||
|
||||
def setup(frame=None):
|
||||
debugger = Debug()
|
||||
frame = frame or sys._getframe().f_back
|
||||
debugger.set_trace(frame)
|
||||
|
||||
def debug_env():
|
||||
if os.environ.get('SOIL_DEBUG'):
|
||||
return setup(frame=sys._getframe().f_back)
|
||||
|
||||
def post_mortem(traceback=None):
|
||||
p = Debug()
|
||||
t = sys.exc_info()[2]
|
||||
p.reset()
|
||||
p.interaction(None, t)
|
@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import math
|
||||
@ -17,9 +18,7 @@ import networkx as nx
|
||||
from mesa import Model
|
||||
from mesa.datacollection import DataCollector
|
||||
|
||||
from . import serialization, analysis, utils, time, network
|
||||
|
||||
from .agents import AgentView, BaseAgent, NetworkAgent, from_config as agents_from_config
|
||||
from . import agents as agentmod, config, serialization, utils, time, network
|
||||
|
||||
|
||||
Record = namedtuple('Record', 'dict_id t_step key value')
|
||||
@ -39,12 +38,12 @@ class BaseEnvironment(Model):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
env_id='unnamed_env',
|
||||
id='unnamed_env',
|
||||
seed='default',
|
||||
schedule=None,
|
||||
dir_path=None,
|
||||
interval=1,
|
||||
agent_class=BaseAgent,
|
||||
agent_class=None,
|
||||
agents: [tuple[type, Dict[str, Any]]] = {},
|
||||
agent_reporters: Optional[Any] = None,
|
||||
model_reporters: Optional[Any] = None,
|
||||
@ -54,7 +53,7 @@ class BaseEnvironment(Model):
|
||||
super().__init__(seed=seed)
|
||||
self.current_id = -1
|
||||
|
||||
self.id = env_id
|
||||
self.id = id
|
||||
|
||||
self.dir_path = dir_path or os.getcwd()
|
||||
|
||||
@ -62,7 +61,7 @@ class BaseEnvironment(Model):
|
||||
schedule = time.TimedActivation(self)
|
||||
self.schedule = schedule
|
||||
|
||||
self.agent_class = agent_class
|
||||
self.agent_class = agent_class or agentmod.BaseAgent
|
||||
|
||||
self.init_agents(agents)
|
||||
|
||||
@ -78,25 +77,51 @@ class BaseEnvironment(Model):
|
||||
tables=tables,
|
||||
)
|
||||
|
||||
def __read_agent_tuple(self, tup):
|
||||
cls = self.agent_class
|
||||
args = tup
|
||||
if isinstance(tup, tuple):
|
||||
cls = tup[0]
|
||||
args = tup[1]
|
||||
return serialization.deserialize(cls)(unique_id=self.next_id(),
|
||||
model=self, **args)
|
||||
def _read_single_agent(self, agent):
|
||||
agent = dict(**agent)
|
||||
cls = agent.pop('agent_class', None) or self.agent_class
|
||||
unique_id = agent.pop('unique_id', None)
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
|
||||
return serialization.deserialize(cls)(unique_id=unique_id,
|
||||
model=self, **agent)
|
||||
|
||||
def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}):
|
||||
if not agents:
|
||||
return
|
||||
|
||||
lst = agents
|
||||
override = []
|
||||
if not isinstance(lst, list):
|
||||
if not isinstance(agents, config.AgentConfig):
|
||||
lst = config.AgentConfig(**agents)
|
||||
if lst.override:
|
||||
override = lst.override
|
||||
lst = agentmod.from_config(lst,
|
||||
topologies=getattr(self, 'topologies', None),
|
||||
random=self.random)
|
||||
|
||||
#TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore,
|
||||
# because it needs attribute such as unique_id, which are only present after init
|
||||
new_agents = [self._read_single_agent(agent) for agent in lst]
|
||||
|
||||
|
||||
for a in new_agents:
|
||||
self.schedule.add(a)
|
||||
|
||||
for rule in override:
|
||||
for agent in agentmod.filter_agents(self.schedule._agents, **rule.filter):
|
||||
for attr, value in rule.state.items():
|
||||
setattr(agent, attr, value)
|
||||
|
||||
def init_agents(self, agents: [tuple[type, Dict[str, Any]]] = {}):
|
||||
agents = [self.__read_agent_tuple(tup) for tup in agents]
|
||||
self._agents = {'default': {agent.id: agent for agent in agents}}
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
return AgentView(self._agents)
|
||||
return agentmod.AgentView(self.schedule._agents)
|
||||
|
||||
def find_one(self, *args, **kwargs):
|
||||
return AgentView(self._agents).one(*args, **kwargs)
|
||||
return agentmod.AgentView(self.schedule._agents).one(*args, **kwargs)
|
||||
|
||||
def count_agents(self, *args, **kwargs):
|
||||
return sum(1 for i in self.agents(*args, **kwargs))
|
||||
@ -108,38 +133,12 @@ class BaseEnvironment(Model):
|
||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
||||
|
||||
|
||||
# def init_agent(self, agent_id, agent_definitions, state=None):
|
||||
# state = state or {}
|
||||
|
||||
# agent_class = None
|
||||
# if 'agent_class' in self.states.get(agent_id, {}):
|
||||
# agent_class = self.states[agent_id]['agent_class']
|
||||
# elif 'agent_class' in self.default_state:
|
||||
# agent_class = self.default_state['agent_class']
|
||||
|
||||
# if agent_class:
|
||||
# agent_class = agents.deserialize_type(agent_class)
|
||||
# elif agent_definitions:
|
||||
# agent_class, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id)
|
||||
# else:
|
||||
# serialization.logger.debug('Skipping agent {}'.format(agent_id))
|
||||
# return
|
||||
# return self.add_agent(agent_id, agent_class, state)
|
||||
|
||||
|
||||
def add_agent(self, agent_id, agent_class, state=None, graph='default'):
|
||||
defstate = deepcopy(self.default_state) or {}
|
||||
defstate.update(self.states.get(agent_id, {}))
|
||||
if state:
|
||||
defstate.update(state)
|
||||
def add_agent(self, agent_id, agent_class, **kwargs):
|
||||
a = None
|
||||
if agent_class:
|
||||
state = defstate
|
||||
a = agent_class(model=self,
|
||||
unique_id=agent_id)
|
||||
|
||||
for (k, v) in state.items():
|
||||
setattr(a, k, v)
|
||||
unique_id=agent_id,
|
||||
**kwargs)
|
||||
|
||||
self.schedule.add(a)
|
||||
return a
|
||||
@ -153,7 +152,7 @@ class BaseEnvironment(Model):
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['unique_id'] = self.id
|
||||
extra['id'] = self.id
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
def step(self):
|
||||
@ -161,6 +160,7 @@ class BaseEnvironment(Model):
|
||||
Advance one step in the simulation, and update the data collection and scheduler appropriately
|
||||
'''
|
||||
super().step()
|
||||
self.logger.info(f'--- Step {self.now:^5} ---')
|
||||
self.schedule.step()
|
||||
self.datacollector.collect(self)
|
||||
|
||||
@ -207,34 +207,41 @@ class BaseEnvironment(Model):
|
||||
yield from self._agent_to_tuples(agent, now)
|
||||
|
||||
|
||||
class AgentConfigEnvironment(BaseEnvironment):
|
||||
class NetworkEnvironment(BaseEnvironment):
|
||||
|
||||
def __init__(self, *args,
|
||||
agents: Dict[str, config.AgentConfig] = {},
|
||||
**kwargs):
|
||||
return super().__init__(*args, agents=agents, **kwargs)
|
||||
|
||||
def init_agents(self, agents: Union[Dict[str, config.AgentConfig], [tuple[type, Dict[str, Any]]]] = {}):
|
||||
if not isinstance(agents, dict):
|
||||
return BaseEnvironment.init_agents(self, agents)
|
||||
|
||||
self._agents = agents_from_config(agents,
|
||||
env=self,
|
||||
random=self.random)
|
||||
for d in self._agents.values():
|
||||
for a in d.values():
|
||||
self.schedule.add(a)
|
||||
|
||||
|
||||
class NetworkConfigEnvironment(BaseEnvironment):
|
||||
|
||||
def __init__(self, *args, topologies: Dict[str, config.NetConfig] = {}, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.topologies = {}
|
||||
def __init__(self, *args, topology: nx.Graph = None, topologies: Dict[str, config.NetConfig] = {}, **kwargs):
|
||||
agents = kwargs.pop('agents', None)
|
||||
super().__init__(*args, agents=None, **kwargs)
|
||||
self._node_ids = {}
|
||||
assert not hasattr(self, 'topologies')
|
||||
if topology is not None:
|
||||
if topologies:
|
||||
raise ValueError('Please, provide either a single topology or a dictionary of them')
|
||||
topologies = {'default': topology}
|
||||
|
||||
self.topologies = {}
|
||||
for (name, cfg) in topologies.items():
|
||||
self.set_topology(cfg=cfg, graph=name)
|
||||
|
||||
self.init_agents(agents)
|
||||
|
||||
|
||||
def _read_single_agent(self, agent, unique_id=None):
|
||||
agent = dict(agent)
|
||||
|
||||
if agent.get('topology', None) is not None:
|
||||
topology = agent.get('topology')
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
if topology:
|
||||
node_id = self.agent_to_node(unique_id, graph_name=topology, node_id=agent.get('node_id'))
|
||||
agent['node_id'] = node_id
|
||||
agent['topology'] = topology
|
||||
agent['unique_id'] = unique_id
|
||||
|
||||
return super()._read_single_agent(agent)
|
||||
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
return self.topologies['default']
|
||||
@ -246,51 +253,50 @@ class NetworkConfigEnvironment(BaseEnvironment):
|
||||
|
||||
self.topologies[graph] = topology
|
||||
|
||||
def topology_for(self, agent_id):
|
||||
return self.topologies[self._node_ids[agent_id][0]]
|
||||
def topology_for(self, unique_id):
|
||||
return self.topologies[self._node_ids[unique_id][0]]
|
||||
|
||||
@property
|
||||
def network_agents(self):
|
||||
yield from self.agents(agent_class=NetworkAgent)
|
||||
yield from self.agents(agent_class=agentmod.NetworkAgent)
|
||||
|
||||
def agent_to_node(self, agent_id, graph_name='default', node_id=None, shuffle=False):
|
||||
node_id = network.agent_to_node(G=self.topologies[graph_name], agent_id=agent_id,
|
||||
node_id=node_id, shuffle=shuffle,
|
||||
def agent_to_node(self, unique_id, graph_name='default',
|
||||
node_id=None, shuffle=False):
|
||||
node_id = network.agent_to_node(G=self.topologies[graph_name],
|
||||
agent_id=unique_id,
|
||||
node_id=node_id,
|
||||
shuffle=shuffle,
|
||||
random=self.random)
|
||||
|
||||
self._node_ids[agent_id] = (graph_name, node_id)
|
||||
self._node_ids[unique_id] = (graph_name, node_id)
|
||||
return node_id
|
||||
|
||||
def add_node(self, agent_class, topology, **kwargs):
|
||||
unique_id = self.next_id()
|
||||
self.topologies[topology].add_node(unique_id)
|
||||
node_id = self.agent_to_node(unique_id=unique_id, node_id=unique_id, graph_name=topology)
|
||||
|
||||
def add_node(self, agent_class, state=None, graph='default'):
|
||||
agent_id = int(len(self.topologies[graph].nodes()))
|
||||
self.topologies[graph].add_node(agent_id)
|
||||
a = self.add_agent(agent_id, agent_class, state, graph=graph)
|
||||
a = self.add_agent(unique_id=unique_id, agent_class=agent_class, node_id=node_id, topology=topology, **kwargs)
|
||||
a['visible'] = True
|
||||
return a
|
||||
|
||||
def add_edge(self, agent1, agent2, start=None, graph='default', **attrs):
|
||||
if hasattr(agent1, 'id'):
|
||||
agent1 = agent1.id
|
||||
if hasattr(agent2, 'id'):
|
||||
agent2 = agent2.id
|
||||
start = start or self.now
|
||||
return self.topologies[graph].add_edge(agent1, agent2, **attrs)
|
||||
agent1 = agent1.node_id
|
||||
agent2 = agent2.node_id
|
||||
return self.topologies[graph].add_edge(agent1, agent2, start=start)
|
||||
|
||||
def add_agent(self, *args, state=None, graph='default', **kwargs):
|
||||
node = self.topologies[graph].nodes[agent_id]
|
||||
def add_agent(self, unique_id, state=None, graph='default', **kwargs):
|
||||
node = self.topologies[graph].nodes[unique_id]
|
||||
node_state = node.get('state', {})
|
||||
if node_state:
|
||||
node_state.update(state or {})
|
||||
state = node_state
|
||||
a = super().add_agent(*args, state=state, **kwargs)
|
||||
a = super().add_agent(unique_id, state=state, **kwargs)
|
||||
node['agent'] = a
|
||||
return a
|
||||
|
||||
def node_id_for(self, agent_id):
|
||||
return self._node_ids[agent_id][1]
|
||||
|
||||
class Environment(AgentConfigEnvironment, NetworkConfigEnvironment):
|
||||
def __init__(self, *args, **kwargs):
|
||||
agents = kwargs.pop('agents', {})
|
||||
NetworkConfigEnvironment.__init__(self, *args, **kwargs)
|
||||
AgentConfigEnvironment.__init__(self, *args, agents=agents, **kwargs)
|
||||
|
||||
Environment = NetworkEnvironment
|
||||
|
@ -12,7 +12,7 @@ from .serialization import deserialize
|
||||
from .utils import open_or_reuse, logger, timer
|
||||
|
||||
|
||||
from . import utils
|
||||
from . import utils, network
|
||||
|
||||
|
||||
class DryRunner(BytesIO):
|
||||
@ -85,38 +85,28 @@ class Exporter:
|
||||
class default(Exporter):
|
||||
'''Default exporter. Writes sqlite results, as well as the simulation YAML'''
|
||||
|
||||
# def sim_start(self):
|
||||
# if not self.dry_run:
|
||||
# logger.info('Dumping results to %s', self.outdir)
|
||||
# self.simulation.dump_yaml(outdir=self.outdir)
|
||||
# else:
|
||||
# logger.info('NOT dumping results')
|
||||
def sim_start(self):
|
||||
if not self.dry_run:
|
||||
logger.info('Dumping results to %s', self.outdir)
|
||||
with self.output(self.simulation.name + '.dumped.yml') as f:
|
||||
f.write(self.simulation.to_yaml())
|
||||
else:
|
||||
logger.info('NOT dumping results')
|
||||
|
||||
# def trial_start(self, env, stats):
|
||||
# if not self.dry_run:
|
||||
# with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
# env.name)):
|
||||
# engine = create_engine('sqlite:///{}.sqlite'.format(env.name), echo=False)
|
||||
def trial_end(self, env):
|
||||
if not self.dry_run:
|
||||
with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.id)):
|
||||
engine = create_engine('sqlite:///{}.sqlite'.format(env.id), echo=False)
|
||||
|
||||
# dc = env.datacollector
|
||||
# tables = {'env': dc.get_model_vars_dataframe(),
|
||||
# 'agents': dc.get_agent_vars_dataframe(),
|
||||
# 'agents': dc.get_agent_vars_dataframe()}
|
||||
# for table in dc.tables:
|
||||
# tables[table] = dc.get_table_dataframe(table)
|
||||
# for (t, df) in tables.items():
|
||||
# df.to_sql(t, con=engine)
|
||||
|
||||
# def sim_end(self, stats):
|
||||
# with timer('Dumping simulation {}\'s stats'.format(self.simulation.name)):
|
||||
# engine = create_engine('sqlite:///{}.sqlite'.format(self.simulation.name), echo=False)
|
||||
# with self.output('{}.sqlite'.format(self.simulation.name), mode='wb') as f:
|
||||
# self.simulation.dump_sqlite(f)
|
||||
dc = env.datacollector
|
||||
for (t, df) in get_dc_dfs(dc):
|
||||
df.to_sql(t, con=engine, if_exists='append')
|
||||
|
||||
|
||||
def get_dc_dfs(dc):
|
||||
dfs = {'env': dc.get_model_vars_dataframe(),
|
||||
'agents': dc.get_agent_vars_dataframe }
|
||||
'agents': dc.get_agent_vars_dataframe() }
|
||||
for table_name in dc.tables:
|
||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||
yield from dfs.items()
|
||||
@ -130,10 +120,11 @@ class csv(Exporter):
|
||||
env.id,
|
||||
self.outdir)):
|
||||
for (df_name, df) in get_dc_dfs(env.datacollector):
|
||||
with self.output('{}.stats.{}.csv'.format(env.id, df_name)) as f:
|
||||
with self.output('{}.{}.csv'.format(env.id, df_name)) as f:
|
||||
df.to_csv(f)
|
||||
|
||||
|
||||
#TODO: reimplement GEXF exporting without history
|
||||
class gexf(Exporter):
|
||||
def trial_end(self, env):
|
||||
if self.dry_run:
|
||||
@ -143,18 +134,9 @@ class gexf(Exporter):
|
||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.id)):
|
||||
with self.output('{}.gexf'.format(env.id), mode='wb') as f:
|
||||
network.dump_gexf(env.history_to_graph(), f)
|
||||
self.dump_gexf(env, f)
|
||||
|
||||
def dump_gexf(self, env, f):
|
||||
G = env.history_to_graph()
|
||||
# Workaround for geometric models
|
||||
# See soil/soil#4
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
|
||||
nx.write_gexf(G, f, version="1.2draft")
|
||||
|
||||
class dummy(Exporter):
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
import os
|
||||
import sys
|
||||
@ -37,8 +39,10 @@ def from_config(cfg: config.NetConfig, dir_path: str = None):
|
||||
known_modules=['networkx.generators',])
|
||||
return method(**net_args)
|
||||
|
||||
if isinstance(cfg.topology, basestring) or isinstance(cfg.topology, dict):
|
||||
return nx.json_graph.node_link_graph(cfg.topology)
|
||||
if isinstance(cfg.topology, config.Topology):
|
||||
cfg = cfg.topology.dict()
|
||||
if isinstance(cfg, str) or isinstance(cfg, dict):
|
||||
return nx.json_graph.node_link_graph(cfg)
|
||||
|
||||
return nx.Graph()
|
||||
|
||||
@ -57,9 +61,18 @@ def agent_to_node(G, agent_id, node_id=None, shuffle=False, random=random):
|
||||
for next_id, data in candidates:
|
||||
if data.get('agent_id', None) is None:
|
||||
node_id = next_id
|
||||
data['agent_id'] = agent_id
|
||||
break
|
||||
|
||||
if node_id is None:
|
||||
raise ValueError(f"Not enough nodes in topology to assign one to agent {agent_id}")
|
||||
G.nodes[node_id]['agent_id'] = agent_id
|
||||
return node_id
|
||||
|
||||
|
||||
def dump_gexf(G, f):
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
|
||||
nx.write_gexf(G, f, version="1.2draft")
|
||||
|
@ -7,6 +7,8 @@ import importlib
|
||||
from glob import glob
|
||||
from itertools import product, chain
|
||||
|
||||
from .config import Config
|
||||
|
||||
import yaml
|
||||
import networkx as nx
|
||||
|
||||
@ -120,22 +122,25 @@ def params_for_template(config):
|
||||
def load_files(*patterns, **kwargs):
|
||||
for pattern in patterns:
|
||||
for i in glob(pattern, **kwargs):
|
||||
for config in load_file(i):
|
||||
for cfg in load_file(i):
|
||||
path = os.path.abspath(i)
|
||||
yield config, path
|
||||
yield Config.from_raw(cfg), path
|
||||
|
||||
|
||||
def load_config(config):
|
||||
if isinstance(config, dict):
|
||||
yield config, os.getcwd()
|
||||
def load_config(cfg):
|
||||
if isinstance(cfg, Config):
|
||||
yield cfg, os.getcwd()
|
||||
elif isinstance(cfg, dict):
|
||||
yield Config.from_raw(cfg), os.getcwd()
|
||||
else:
|
||||
yield from load_files(config)
|
||||
yield from load_files(cfg)
|
||||
|
||||
|
||||
builtins = importlib.import_module('builtins')
|
||||
|
||||
KNOWN_MODULES = ['soil', ]
|
||||
|
||||
|
||||
def name(value, known_modules=KNOWN_MODULES):
|
||||
'''Return a name that can be imported, to serialize/deserialize an object'''
|
||||
if value is None:
|
||||
@ -172,8 +177,22 @@ def serialize(v, known_modules=KNOWN_MODULES):
|
||||
return func(v), tname
|
||||
|
||||
|
||||
def serialize_dict(d, known_modules=KNOWN_MODULES):
|
||||
d = dict(d)
|
||||
for (k, v) in d.items():
|
||||
if isinstance(v, dict):
|
||||
d[k] = serialize_dict(v, known_modules=known_modules)
|
||||
elif isinstance(v, list):
|
||||
for ix in range(len(v)):
|
||||
v[ix] = serialize_dict(v[ix], known_modules=known_modules)
|
||||
elif isinstance(v, type):
|
||||
d[k] = serialize(v, known_modules=known_modules)[1]
|
||||
return d
|
||||
|
||||
|
||||
IS_CLASS = re.compile(r"<class '(.*)'>")
|
||||
|
||||
|
||||
def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||
if type(type_) != str: # Already deserialized
|
||||
return type_
|
||||
|
@ -4,15 +4,17 @@ import importlib
|
||||
import sys
|
||||
import yaml
|
||||
import traceback
|
||||
import inspect
|
||||
import logging
|
||||
import networkx as nx
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import Union
|
||||
from typing import Any, Dict, Union, Optional
|
||||
|
||||
|
||||
from networkx.readwrite import json_graph
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
import pickle
|
||||
|
||||
@ -21,7 +23,6 @@ from .environment import Environment
|
||||
from .utils import logger, run_and_return_exceptions
|
||||
from .exporters import default
|
||||
from .time import INFINITY
|
||||
|
||||
from .config import Config, convert_old
|
||||
|
||||
|
||||
@ -36,7 +37,9 @@ class Simulation:
|
||||
|
||||
kwargs: parameters to use to initialize a new configuration, if one has not been provided.
|
||||
"""
|
||||
version: str = '2'
|
||||
name: str = 'Unnamed simulation'
|
||||
description: Optional[str] = ''
|
||||
group: str = None
|
||||
model_class: Union[str, type] = 'soil.Environment'
|
||||
model_params: dict = field(default_factory=dict)
|
||||
@ -44,30 +47,37 @@ class Simulation:
|
||||
dir_path: str = field(default_factory=lambda: os.getcwd())
|
||||
max_time: float = float('inf')
|
||||
max_steps: int = -1
|
||||
interval: int = 1
|
||||
num_trials: int = 3
|
||||
dry_run: bool = False
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, env):
|
||||
|
||||
ignored = {k: v for k, v in env.items()
|
||||
if k not in inspect.signature(cls).parameters}
|
||||
|
||||
kwargs = {k:v for k, v in env.items() if k not in ignored}
|
||||
if ignored:
|
||||
kwargs.setdefault('extra', {}).update(ignored)
|
||||
if ignored:
|
||||
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
|
||||
|
||||
return cls(**kwargs)
|
||||
|
||||
def run_simulation(self, *args, **kwargs):
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
'''Run the simulation and return the list of resulting environments'''
|
||||
logger.info(dedent('''
|
||||
Simulation:
|
||||
---
|
||||
''') +
|
||||
self.to_yaml())
|
||||
return list(self.run_gen(*args, **kwargs))
|
||||
|
||||
def _run_sync_or_async(self, parallel=False, **kwargs):
|
||||
if parallel and not os.environ.get('SENPY_DEBUG', None):
|
||||
p = Pool()
|
||||
func = partial(run_and_return_exceptions, self.run_trial, **kwargs)
|
||||
for i in p.imap_unordered(func, self.num_trials):
|
||||
if isinstance(i, Exception):
|
||||
logger.error('Trial failed:\n\t%s', i.message)
|
||||
continue
|
||||
yield i
|
||||
else:
|
||||
for i in range(self.num_trials):
|
||||
yield self.run_trial(trial_id=i,
|
||||
**kwargs)
|
||||
|
||||
def run_gen(self, parallel=False, dry_run=False,
|
||||
exporters=[default, ], outdir=None, exporter_params={},
|
||||
log_level=None,
|
||||
@ -88,7 +98,9 @@ class Simulation:
|
||||
for exporter in exporters:
|
||||
exporter.sim_start()
|
||||
|
||||
for env in self._run_sync_or_async(parallel=parallel,
|
||||
for env in utils.run_parallel(func=self.run_trial,
|
||||
iterable=range(int(self.num_trials)),
|
||||
parallel=parallel,
|
||||
log_level=log_level,
|
||||
**kwargs):
|
||||
|
||||
@ -103,14 +115,6 @@ class Simulation:
|
||||
for exporter in exporters:
|
||||
exporter.sim_end()
|
||||
|
||||
def run_model(self, until=None, *args, **kwargs):
|
||||
until = until or float('inf')
|
||||
|
||||
while self.schedule.next_time < until:
|
||||
self.step()
|
||||
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
||||
self.schedule.time = until
|
||||
|
||||
def get_env(self, trial_id=0, **kwargs):
|
||||
'''Create an environment for a trial of the simulation'''
|
||||
def deserialize_reporters(reporters):
|
||||
@ -132,29 +136,50 @@ class Simulation:
|
||||
model_reporters=model_reporters,
|
||||
**model_params)
|
||||
|
||||
def run_trial(self, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
||||
def run_trial(self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts):
|
||||
"""
|
||||
Run a single trial of the simulation
|
||||
|
||||
"""
|
||||
model = self.get_env(trial_id, **opts)
|
||||
return self.run_model(model, trial_id=trial_id, until=until, log_level=log_level)
|
||||
|
||||
def run_model(self, model, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
||||
trial_id = trial_id if trial_id is not None else current_time()
|
||||
if log_level:
|
||||
logger.setLevel(log_level)
|
||||
model = self.get_env(trial_id, **opts)
|
||||
trial_id = trial_id if trial_id is not None else current_time()
|
||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||
return self.run_model(model=model, trial_id=trial_id, until=until, log_level=log_level)
|
||||
|
||||
def run_model(self, model, until=None, **opts):
|
||||
# Set-up trial environment and graph
|
||||
until = until or self.max_time
|
||||
until = float(until or self.max_time or 'inf')
|
||||
|
||||
# Set up agents on nodes
|
||||
is_done = lambda: False
|
||||
if self.max_time and hasattr(self.schedule, 'time'):
|
||||
is_done = lambda x: is_done() or self.schedule.time >= self.max_time
|
||||
if self.max_steps and hasattr(self.schedule, 'time'):
|
||||
is_done = lambda: is_done() or self.schedule.steps >= self.max_steps
|
||||
def is_done():
|
||||
return False
|
||||
|
||||
if until and hasattr(model.schedule, 'time'):
|
||||
prev = is_done
|
||||
|
||||
def is_done():
|
||||
return prev() or model.schedule.time >= until
|
||||
|
||||
if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, 'steps'):
|
||||
prev_steps = is_done
|
||||
|
||||
def is_done():
|
||||
return prev_steps() or model.schedule.steps >= self.max_steps
|
||||
|
||||
newline = '\n'
|
||||
logger.info(dedent(f'''
|
||||
Model stats:
|
||||
Agents (total: { model.schedule.get_agent_count() }):
|
||||
- { (newline + ' - ').join(str(a) for a in model.schedule.agents) }'''
|
||||
f'''
|
||||
|
||||
Topologies (size):
|
||||
- { dict( (k, len(v)) for (k, v) in model.topologies.items()) }
|
||||
''' if getattr(model, "topologies", None) else ''
|
||||
))
|
||||
|
||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||
while not is_done():
|
||||
utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}')
|
||||
model.step()
|
||||
@ -162,26 +187,25 @@ class Simulation:
|
||||
|
||||
def to_dict(self):
|
||||
d = asdict(self)
|
||||
d['model_class'] = serialization.serialize(d['model_class'])[0]
|
||||
d['model_params'] = serialization.serialize(d['model_params'])[0]
|
||||
if not isinstance(d['model_class'], str):
|
||||
d['model_class'] = serialization.name(d['model_class'])
|
||||
d['model_params'] = serialization.serialize_dict(d['model_params'])
|
||||
d['dir_path'] = str(d['dir_path'])
|
||||
|
||||
d['version'] = '2'
|
||||
return d
|
||||
|
||||
def to_yaml(self):
|
||||
return yaml.dump(self.asdict())
|
||||
return yaml.dump(self.to_dict())
|
||||
|
||||
|
||||
def iter_from_config(config):
|
||||
def iter_from_config(*cfgs):
|
||||
for config in cfgs:
|
||||
configs = list(serialization.load_config(config))
|
||||
for config, path in configs:
|
||||
d = dict(config)
|
||||
if 'dir_path' not in d:
|
||||
d['dir_path'] = os.path.dirname(path)
|
||||
if d.get('version', '2') == '1' or 'agents' in d or 'network_agents' in d or 'environment_agents' in d:
|
||||
d = convert_old(d)
|
||||
d.pop('version', None)
|
||||
yield Simulation(**d)
|
||||
yield Simulation.from_dict(d)
|
||||
|
||||
|
||||
def from_config(conf_or_path):
|
||||
@ -192,6 +216,6 @@ def from_config(conf_or_path):
|
||||
|
||||
|
||||
def run_from_config(*configs, **kwargs):
|
||||
for sim in iter_from_config(configs):
|
||||
logger.info(f"Using config(s): {sim.id}")
|
||||
for sim in iter_from_config(*configs):
|
||||
logger.info(f"Using config(s): {sim.name}")
|
||||
sim.run_simulation(**kwargs)
|
||||
|
28
soil/time.py
28
soil/time.py
@ -1,6 +1,6 @@
|
||||
from mesa.time import BaseScheduler
|
||||
from queue import Empty
|
||||
from heapq import heappush, heappop
|
||||
from heapq import heappush, heappop, heapify
|
||||
import math
|
||||
from .utils import logger
|
||||
from mesa import Agent as MesaAgent
|
||||
@ -17,6 +17,7 @@ class When:
|
||||
def abs(self, time):
|
||||
return self._time
|
||||
|
||||
|
||||
NEVER = When(INFINITY)
|
||||
|
||||
|
||||
@ -38,13 +39,21 @@ class TimedActivation(BaseScheduler):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._next = {}
|
||||
self._queue = []
|
||||
self.next_time = 0
|
||||
self.logger = logger.getChild(f'time_{ self.model }')
|
||||
|
||||
def add(self, agent: MesaAgent):
|
||||
if agent.unique_id not in self._agents:
|
||||
heappush(self._queue, (self.time, agent.unique_id))
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
if when is None:
|
||||
when = self.time
|
||||
if agent.unique_id in self._agents:
|
||||
self._queue.remove((self._next[agent.unique_id], agent.unique_id))
|
||||
del self._agents[agent.unique_id]
|
||||
heapify(self._queue)
|
||||
|
||||
heappush(self._queue, (when, agent.unique_id))
|
||||
self._next[agent.unique_id] = when
|
||||
super().add(agent)
|
||||
|
||||
def step(self) -> None:
|
||||
@ -64,11 +73,18 @@ class TimedActivation(BaseScheduler):
|
||||
(when, agent_id) = heappop(self._queue)
|
||||
self.logger.debug(f'Stepping agent {agent_id}')
|
||||
|
||||
returned = self._agents[agent_id].step()
|
||||
agent = self._agents[agent_id]
|
||||
returned = agent.step()
|
||||
|
||||
if not agent.alive:
|
||||
self.remove(agent)
|
||||
continue
|
||||
|
||||
when = (returned or Delta(1)).abs(self.time)
|
||||
if when < self.time:
|
||||
raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time))
|
||||
|
||||
self._next[agent_id] = when
|
||||
heappush(self._queue, (when, agent_id))
|
||||
|
||||
self.steps += 1
|
||||
@ -77,7 +93,7 @@ class TimedActivation(BaseScheduler):
|
||||
self.time = INFINITY
|
||||
self.next_time = INFINITY
|
||||
self.model.running = False
|
||||
return
|
||||
return self.time
|
||||
|
||||
self.next_time = self._queue[0][0]
|
||||
self.logger.debug(f'Next step: {self.next_time}')
|
||||
|
@ -3,13 +3,27 @@ from time import time as current_time, strftime, gmtime, localtime
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from functools import partial
|
||||
from shutil import copyfile
|
||||
from multiprocessing import Pool
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger('soil')
|
||||
# logging.basicConfig()
|
||||
# logger.setLevel(logging.INFO)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
timeformat = "%H:%M:%S"
|
||||
|
||||
if os.environ.get('SOIL_VERBOSE', ''):
|
||||
logformat = "[%(levelname)-5.5s][%(asctime)s][%(name)s]: %(message)s"
|
||||
else:
|
||||
logformat = "[%(levelname)-5.5s][%(asctime)s] %(message)s"
|
||||
|
||||
logFormatter = logging.Formatter(logformat, timeformat)
|
||||
|
||||
consoleHandler = logging.StreamHandler()
|
||||
consoleHandler.setFormatter(logFormatter)
|
||||
logger.addHandler(consoleHandler)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -27,8 +41,6 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
to_object.end = end
|
||||
|
||||
|
||||
|
||||
|
||||
def safe_open(path, mode='r', backup=True, **kwargs):
|
||||
outdir = os.path.dirname(path)
|
||||
if outdir and not os.path.exists(outdir):
|
||||
@ -92,7 +104,7 @@ def unflatten_dict(d):
|
||||
return out
|
||||
|
||||
|
||||
def run_and_return_exceptions(self, func, *args, **kwargs):
|
||||
def run_and_return_exceptions(func, *args, **kwargs):
|
||||
'''
|
||||
A wrapper for run_trial that catches exceptions and returns them.
|
||||
It is meant for async simulations.
|
||||
@ -104,3 +116,18 @@ def run_and_return_exceptions(self, func, *args, **kwargs):
|
||||
ex = ex.__cause__
|
||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||
return ex
|
||||
|
||||
|
||||
def run_parallel(func, iterable, parallel=False, **kwargs):
|
||||
if parallel and not os.environ.get('SOIL_DEBUG', None):
|
||||
p = Pool()
|
||||
wrapped_func = partial(run_and_return_exceptions,
|
||||
func, **kwargs)
|
||||
for i in p.imap_unordered(wrapped_func, iterable):
|
||||
if isinstance(i, Exception):
|
||||
logger.error('Trial failed:\n\t%s', i.message)
|
||||
continue
|
||||
yield i
|
||||
else:
|
||||
for i in iterable:
|
||||
yield func(i, **kwargs)
|
||||
|
@ -1,49 +1,50 @@
|
||||
---
|
||||
version: '2'
|
||||
general:
|
||||
id: simple
|
||||
group: tests
|
||||
dir_path: "/tmp/"
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
topologies:
|
||||
name: simple
|
||||
group: tests
|
||||
dir_path: "/tmp/"
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
model_class: Environment
|
||||
model_params:
|
||||
topologies:
|
||||
default:
|
||||
params:
|
||||
generator: complete_graph
|
||||
n: 10
|
||||
agents:
|
||||
default:
|
||||
n: 4
|
||||
agents:
|
||||
agent_class: CounterModel
|
||||
state:
|
||||
group: network
|
||||
times: 1
|
||||
network:
|
||||
topology: 'default'
|
||||
distribution:
|
||||
- agent_class: CounterModel
|
||||
weight: 0.4
|
||||
weight: 0.25
|
||||
state:
|
||||
state_id: 0
|
||||
times: 1
|
||||
- agent_class: AggregatedCounter
|
||||
weight: 0.6
|
||||
override:
|
||||
- filter:
|
||||
node_id: 0
|
||||
weight: 0.5
|
||||
state:
|
||||
name: 'The first node'
|
||||
times: 2
|
||||
override:
|
||||
- filter:
|
||||
node_id: 1
|
||||
state:
|
||||
name: 'The second node'
|
||||
|
||||
environment:
|
||||
fixed:
|
||||
- name: 'Environment Agent 1'
|
||||
agent_class: CounterModel
|
||||
name: 'Node 1'
|
||||
- filter:
|
||||
node_id: 2
|
||||
state:
|
||||
name: 'Node 2'
|
||||
fixed:
|
||||
- agent_class: BaseAgent
|
||||
hidden: true
|
||||
topology: null
|
||||
state:
|
||||
name: 'Environment Agent 1'
|
||||
times: 10
|
||||
environment:
|
||||
environment_class: Environment
|
||||
params:
|
||||
group: environment
|
||||
am_i_complete: true
|
||||
|
@ -8,17 +8,20 @@ interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
network_params:
|
||||
generator: complete_graph
|
||||
n: 10
|
||||
n: 4
|
||||
network_agents:
|
||||
- agent_class: CounterModel
|
||||
weight: 0.4
|
||||
weight: 0.25
|
||||
state:
|
||||
state_id: 0
|
||||
times: 1
|
||||
- agent_class: AggregatedCounter
|
||||
weight: 0.6
|
||||
weight: 0.5
|
||||
state:
|
||||
times: 2
|
||||
environment_agents:
|
||||
- agent_id: 'Environment Agent 1'
|
||||
agent_class: CounterModel
|
||||
agent_class: BaseAgent
|
||||
state:
|
||||
times: 10
|
||||
environment_class: Environment
|
||||
@ -28,5 +31,7 @@ agent_class: CounterModel
|
||||
default_state:
|
||||
times: 1
|
||||
states:
|
||||
- name: 'The first node'
|
||||
- name: 'The second node'
|
||||
1:
|
||||
name: 'Node 1'
|
||||
2:
|
||||
name: 'Node 2'
|
||||
|
@ -8,7 +8,7 @@ class Dead(agents.FSM):
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def only(self):
|
||||
self.die()
|
||||
return self.die()
|
||||
|
||||
class TestMain(TestCase):
|
||||
def test_die_raises_exception(self):
|
||||
@ -19,4 +19,6 @@ class TestMain(TestCase):
|
||||
|
||||
def test_die_returns_infinity(self):
|
||||
d = Dead(unique_id=0, model=environment.Environment())
|
||||
assert d.step().abs(0) == stime.INFINITY
|
||||
ret = d.step().abs(0)
|
||||
print(ret, 'next')
|
||||
assert ret == stime.INFINITY
|
||||
|
@ -1,91 +0,0 @@
|
||||
from unittest import TestCase
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from functools import partial
|
||||
|
||||
from os.path import join
|
||||
from soil import simulation, analysis, agents
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
class Ping(agents.FSM):
|
||||
|
||||
defaults = {
|
||||
'count': 0,
|
||||
}
|
||||
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def even(self):
|
||||
self.debug(f'Even {self["count"]}')
|
||||
self['count'] += 1
|
||||
return self.odd
|
||||
|
||||
@agents.state
|
||||
def odd(self):
|
||||
self.debug(f'Odd {self["count"]}')
|
||||
self['count'] += 1
|
||||
return self.even
|
||||
|
||||
|
||||
class TestAnalysis(TestCase):
|
||||
|
||||
# Code to generate a simple sqlite history
|
||||
def setUp(self):
|
||||
"""
|
||||
The initial states should be applied to the agent and the
|
||||
agent should be able to update its state."""
|
||||
config = {
|
||||
'name': 'analysis',
|
||||
'seed': 'seed',
|
||||
'network_params': {
|
||||
'generator': 'complete_graph',
|
||||
'n': 2
|
||||
},
|
||||
'agent_class': Ping,
|
||||
'states': [{'interval': 1}, {'interval': 2}],
|
||||
'max_time': 30,
|
||||
'num_trials': 1,
|
||||
'history': True,
|
||||
'environment_params': {
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
self.env = s.run_simulation(dry_run=True)[0]
|
||||
|
||||
def test_saved(self):
|
||||
env = self.env
|
||||
assert env.get_agent(0)['count', 0] == 1
|
||||
assert env.get_agent(0)['count', 29] == 30
|
||||
assert env.get_agent(1)['count', 0] == 1
|
||||
assert env.get_agent(1)['count', 29] == 15
|
||||
assert env['env', 29, None]['SEED'] == env['env', 29, 'SEED']
|
||||
|
||||
def test_count(self):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history.db_path)
|
||||
res = analysis.get_count(df, 'SEED', 'state_id')
|
||||
assert res['SEED'][self.env['SEED']].iloc[0] == 1
|
||||
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
|
||||
assert res['state_id']['odd'].iloc[0] == 2
|
||||
assert res['state_id']['even'].iloc[0] == 0
|
||||
assert res['state_id']['odd'].iloc[-1] == 1
|
||||
assert res['state_id']['even'].iloc[-1] == 1
|
||||
|
||||
def test_value(self):
|
||||
env = self.env
|
||||
df = analysis.read_sql(env._history.db_path)
|
||||
res_sum = analysis.get_value(df, 'count')
|
||||
|
||||
assert res_sum['count'].iloc[0] == 2
|
||||
|
||||
import numpy as np
|
||||
res_mean = analysis.get_value(df, 'count', aggfunc=np.mean)
|
||||
assert res_mean['count'].iloc[15] == (16+8)/2
|
||||
|
||||
res_total = analysis.get_majority(df)
|
||||
res_total['SEED'].iloc[0] == self.env['SEED']
|
@ -29,7 +29,7 @@ class TestConfig(TestCase):
|
||||
expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
||||
converted_defaults = config.convert_old(old, strict=False)
|
||||
converted = converted_defaults.dict(skip_defaults=True)
|
||||
converted = converted_defaults.dict(exclude_unset=True)
|
||||
|
||||
isequal(converted, expected)
|
||||
|
||||
@ -40,10 +40,10 @@ class TestConfig(TestCase):
|
||||
"""
|
||||
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
s = simulation.from_config(config)
|
||||
init_config = copy.copy(s.config)
|
||||
init_config = copy.copy(s.to_dict())
|
||||
|
||||
s.run_simulation(dry_run=True)
|
||||
nconfig = s.config
|
||||
nconfig = s.to_dict()
|
||||
# del nconfig['to
|
||||
isequal(init_config, nconfig)
|
||||
|
||||
@ -61,7 +61,7 @@ class TestConfig(TestCase):
|
||||
Simple configuration that tests that the graph is loaded, and that
|
||||
network agents are initialized properly.
|
||||
"""
|
||||
config = {
|
||||
cfg = {
|
||||
'name': 'CounterAgent',
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
@ -74,12 +74,14 @@ class TestConfig(TestCase):
|
||||
'environment_params': {
|
||||
}
|
||||
}
|
||||
s = simulation.from_old_config(config)
|
||||
conf = config.convert_old(cfg)
|
||||
s = simulation.from_config(conf)
|
||||
|
||||
env = s.get_env()
|
||||
assert len(env.topologies['default'].nodes) == 2
|
||||
assert len(env.topologies['default'].edges) == 1
|
||||
assert len(env.agents) == 2
|
||||
assert env.agents[0].topology == env.topologies['default']
|
||||
assert env.agents[0].G == env.topologies['default']
|
||||
|
||||
def test_agents_from_config(self):
|
||||
'''We test that the known complete configuration produces
|
||||
@ -87,13 +89,10 @@ class TestConfig(TestCase):
|
||||
cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||
s = simulation.from_config(cfg)
|
||||
env = s.get_env()
|
||||
assert len(env.topologies['default'].nodes) == 10
|
||||
assert len(env.agents(group='network')) == 10
|
||||
assert len(env.topologies['default'].nodes) == 4
|
||||
assert len(env.agents(group='network')) == 4
|
||||
assert len(env.agents(group='environment')) == 1
|
||||
|
||||
assert sum(1 for a in env.agents(group='network', agent_class=agents.CounterModel)) == 4
|
||||
assert sum(1 for a in env.agents(group='network', agent_class=agents.AggregatedCounter)) == 6
|
||||
|
||||
def test_yaml(self):
|
||||
"""
|
||||
The YAML version of a newly created configuration should be equivalent
|
||||
|
@ -2,7 +2,7 @@ from unittest import TestCase
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
from soil import serialization, simulation
|
||||
from soil import serialization, simulation, config
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
EXAMPLES = join(ROOT, '..', 'examples')
|
||||
@ -14,36 +14,37 @@ class TestExamples(TestCase):
|
||||
pass
|
||||
|
||||
|
||||
def make_example_test(path, config):
|
||||
def make_example_test(path, cfg):
|
||||
def wrapped(self):
|
||||
root = os.getcwd()
|
||||
for s in simulation.all_from_config(path):
|
||||
iterations = s.config.general.max_time * s.config.general.num_trials
|
||||
if iterations > 1000:
|
||||
s.config.general.max_time = 100
|
||||
s.config.general.num_trials = 1
|
||||
if config.get('skip_test', False) and not FORCE_TESTS:
|
||||
for s in simulation.iter_from_config(cfg):
|
||||
iterations = s.max_steps * s.num_trials
|
||||
if iterations < 0 or iterations > 1000:
|
||||
s.max_steps = 100
|
||||
s.num_trials = 1
|
||||
assert isinstance(cfg, config.Config)
|
||||
if getattr(cfg, 'skip_test', False) and not FORCE_TESTS:
|
||||
self.skipTest('Example ignored.')
|
||||
envs = s.run_simulation(dry_run=True)
|
||||
assert envs
|
||||
for env in envs:
|
||||
assert env
|
||||
try:
|
||||
n = config['network_params']['n']
|
||||
n = cfg.model_params['network_params']['n']
|
||||
assert len(list(env.network_agents)) == n
|
||||
assert env.now > 0 # It has run
|
||||
assert env.now <= config['max_time'] # But not further than allowed
|
||||
except KeyError:
|
||||
pass
|
||||
assert env.schedule.steps > 0 # It has run
|
||||
assert env.schedule.steps <= s.max_steps # But not further than allowed
|
||||
return wrapped
|
||||
|
||||
|
||||
def add_example_tests():
|
||||
for config, path in serialization.load_files(
|
||||
for cfg, path in serialization.load_files(
|
||||
join(EXAMPLES, '*', '*.yml'),
|
||||
join(EXAMPLES, '*.yml'),
|
||||
):
|
||||
p = make_example_test(path=path, config=config)
|
||||
p = make_example_test(path=path, cfg=config.Config.from_raw(cfg))
|
||||
fname = os.path.basename(path)
|
||||
p.__name__ = 'test_example_file_%s' % fname
|
||||
p.__doc__ = '%s should be a valid configuration' % fname
|
||||
|
@ -6,6 +6,8 @@ import shutil
|
||||
from unittest import TestCase
|
||||
from soil import exporters
|
||||
from soil import simulation
|
||||
from soil import agents
|
||||
|
||||
|
||||
class Dummy(exporters.Exporter):
|
||||
started = False
|
||||
@ -33,28 +35,36 @@ class Dummy(exporters.Exporter):
|
||||
|
||||
class Exporters(TestCase):
|
||||
def test_basic(self):
|
||||
# We need to add at least one agent to make sure the scheduler
|
||||
# ticks every step
|
||||
num_trials = 5
|
||||
max_time = 2
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'network_params': {},
|
||||
'agent_class': 'CounterModel',
|
||||
'max_time': 2,
|
||||
'num_trials': 5,
|
||||
'environment_params': {}
|
||||
'model_params': {
|
||||
'agents': [{
|
||||
'agent_class': agents.BaseAgent
|
||||
}]
|
||||
},
|
||||
'max_time': max_time,
|
||||
'num_trials': num_trials,
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
|
||||
for env in s.run_simulation(exporters=[Dummy], dry_run=True):
|
||||
assert env.now <= 2
|
||||
assert len(env.agents) == 1
|
||||
assert env.now == max_time
|
||||
|
||||
assert Dummy.started
|
||||
assert Dummy.ended
|
||||
assert Dummy.called_start == 1
|
||||
assert Dummy.called_end == 1
|
||||
assert Dummy.called_trial == 5
|
||||
assert Dummy.trials == 5
|
||||
assert Dummy.total_time == 2*5
|
||||
assert Dummy.called_trial == num_trials
|
||||
assert Dummy.trials == num_trials
|
||||
assert Dummy.total_time == max_time * num_trials
|
||||
|
||||
def test_writing(self):
|
||||
'''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
|
||||
'''Try to write CSV, sqlite and YAML (without dry_run)'''
|
||||
n_trials = 5
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
@ -74,7 +84,6 @@ class Exporters(TestCase):
|
||||
envs = s.run_simulation(exporters=[
|
||||
exporters.default,
|
||||
exporters.csv,
|
||||
exporters.gexf,
|
||||
],
|
||||
dry_run=False,
|
||||
outdir=tmpdir,
|
||||
@ -88,11 +97,7 @@ class Exporters(TestCase):
|
||||
|
||||
try:
|
||||
for e in envs:
|
||||
with open(os.path.join(simdir, '{}.gexf'.format(e.name))) as f:
|
||||
result = f.read()
|
||||
assert result
|
||||
|
||||
with open(os.path.join(simdir, '{}.csv'.format(e.name))) as f:
|
||||
with open(os.path.join(simdir, '{}.env.csv'.format(e.id))) as f:
|
||||
result = f.read()
|
||||
assert result
|
||||
finally:
|
||||
|
@ -1,128 +0,0 @@
|
||||
from unittest import TestCase
|
||||
|
||||
import os
|
||||
import io
|
||||
import yaml
|
||||
import copy
|
||||
import pickle
|
||||
import networkx as nx
|
||||
from functools import partial
|
||||
|
||||
from os.path import join
|
||||
from soil import (simulation, Environment, agents, serialization,
|
||||
utils)
|
||||
from soil.time import Delta
|
||||
from tsih import NoHistory, History
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
EXAMPLES = join(ROOT, '..', 'examples')
|
||||
|
||||
|
||||
class CustomAgent(agents.FSM):
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def normal(self):
|
||||
self.neighbors = self.count_agents(state_id='normal',
|
||||
limit_neighbors=True)
|
||||
@agents.state
|
||||
def unreachable(self):
|
||||
return
|
||||
|
||||
class TestHistory(TestCase):
|
||||
|
||||
def test_counter_agent_history(self):
|
||||
"""
|
||||
The evolution of the state should be recorded in the logging agent
|
||||
"""
|
||||
config = {
|
||||
'name': 'CounterAgent',
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
'network_agents': [{
|
||||
'agent_class': 'AggregatedCounter',
|
||||
'weight': 1,
|
||||
'state': {'state_id': 0}
|
||||
|
||||
}],
|
||||
'max_time': 10,
|
||||
'environment_params': {
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
for agent in env.network_agents:
|
||||
last = 0
|
||||
assert len(agent[None, None]) == 11
|
||||
for step, total in sorted(agent['total', None]):
|
||||
assert total == last + 2
|
||||
last = total
|
||||
|
||||
def test_row_conversion(self):
|
||||
env = Environment(history=True)
|
||||
env['test'] = 'test_value'
|
||||
|
||||
res = list(env.history_to_tuples())
|
||||
assert len(res) == len(env.environment_params)
|
||||
|
||||
env.schedule.time = 1
|
||||
env['test'] = 'second_value'
|
||||
res = list(env.history_to_tuples())
|
||||
|
||||
assert env['env', 0, 'test' ] == 'test_value'
|
||||
assert env['env', 1, 'test' ] == 'second_value'
|
||||
|
||||
def test_nohistory(self):
|
||||
'''
|
||||
Make sure that no history(/sqlite) is used by default
|
||||
'''
|
||||
env = Environment(topology=nx.Graph(), network_agents=[])
|
||||
assert isinstance(env._history, NoHistory)
|
||||
|
||||
def test_save_graph_history(self):
|
||||
'''
|
||||
The history_to_graph method should return a valid networkx graph.
|
||||
|
||||
The state of the agent should be encoded as intervals in the nx graph.
|
||||
'''
|
||||
G = nx.cycle_graph(5)
|
||||
distribution = agents.calculate_distribution(None, agents.BaseAgent)
|
||||
env = Environment(topology=G, network_agents=distribution, history=True)
|
||||
env[0, 0, 'testvalue'] = 'start'
|
||||
env[0, 10, 'testvalue'] = 'finish'
|
||||
nG = env.history_to_graph()
|
||||
values = nG.nodes[0]['attr_testvalue']
|
||||
assert ('start', 0, 10) in values
|
||||
assert ('finish', 10, None) in values
|
||||
|
||||
def test_save_graph_nohistory(self):
|
||||
'''
|
||||
The history_to_graph method should return a valid networkx graph.
|
||||
|
||||
When NoHistory is used, only the last known value is known
|
||||
'''
|
||||
G = nx.cycle_graph(5)
|
||||
distribution = agents.calculate_distribution(None, agents.BaseAgent)
|
||||
env = Environment(topology=G, network_agents=distribution, history=False)
|
||||
env.get_agent(0)['testvalue'] = 'start'
|
||||
env.schedule.time = 10
|
||||
env.get_agent(0)['testvalue'] = 'finish'
|
||||
nG = env.history_to_graph()
|
||||
values = nG.nodes[0]['attr_testvalue']
|
||||
assert ('start', 0, None) not in values
|
||||
assert ('finish', 10, None) in values
|
||||
|
||||
def test_pickle_agent_environment(self):
|
||||
env = Environment(name='Test', history=True)
|
||||
a = agents.BaseAgent(model=env, unique_id=25)
|
||||
|
||||
a['key'] = 'test'
|
||||
|
||||
pickled = pickle.dumps(a)
|
||||
recovered = pickle.loads(pickled)
|
||||
|
||||
assert recovered.env.name == 'Test'
|
||||
assert list(recovered.env._history.to_tuples())
|
||||
assert recovered['key', 0] == 'test'
|
||||
assert recovered['key'] == 'test'
|
@ -24,6 +24,7 @@ class CustomAgent(agents.FSM, agents.NetworkAgent):
|
||||
def unreachable(self):
|
||||
return
|
||||
|
||||
|
||||
class TestMain(TestCase):
|
||||
|
||||
def test_empty_simulation(self):
|
||||
@ -79,20 +80,16 @@ class TestMain(TestCase):
|
||||
}
|
||||
},
|
||||
'agents': {
|
||||
'default': {
|
||||
'agent_class': 'CounterModel',
|
||||
},
|
||||
'counters': {
|
||||
'topology': 'default',
|
||||
'fixed': [{'state': {'times': 10}}, {'state': {'times': 20}}],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
env = s.get_env()
|
||||
assert isinstance(env.agents[0], agents.CounterModel)
|
||||
assert env.agents[0].topology == env.topologies['default']
|
||||
assert env.agents[0].G == env.topologies['default']
|
||||
assert env.agents[0]['times'] == 10
|
||||
assert env.agents[0]['times'] == 10
|
||||
env.step()
|
||||
@ -105,8 +102,8 @@ class TestMain(TestCase):
|
||||
config = {
|
||||
'max_time': 10,
|
||||
'model_params': {
|
||||
'agents': [(CustomAgent, {'weight': 1}),
|
||||
(CustomAgent, {'weight': 3}),
|
||||
'agents': [{'agent_class': CustomAgent, 'weight': 1, 'topology': 'default'},
|
||||
{'agent_class': CustomAgent, 'weight': 3, 'topology': 'default'},
|
||||
],
|
||||
'topologies': {
|
||||
'default': {
|
||||
@ -128,7 +125,7 @@ class TestMain(TestCase):
|
||||
"""A complete example from a documentation should work."""
|
||||
config = serialization.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
|
||||
config['model_params']['network_params']['path'] = join(EXAMPLES,
|
||||
config['network_params']['path'])
|
||||
config['model_params']['network_params']['path'])
|
||||
s = simulation.from_config(config)
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
for a in env.network_agents:
|
||||
@ -208,24 +205,6 @@ class TestMain(TestCase):
|
||||
assert converted[1]['agent_class'] == 'test_main.CustomAgent'
|
||||
pickle.dumps(converted)
|
||||
|
||||
def test_subgraph(self):
|
||||
'''An agent should be able to subgraph the global topology'''
|
||||
G = nx.Graph()
|
||||
G.add_node(3)
|
||||
G.add_edge(1, 2)
|
||||
distro = agents.calculate_distribution(agent_class=agents.NetworkAgent)
|
||||
distro[0]['topology'] = 'default'
|
||||
aconfig = config.AgentConfig(distribution=distro, topology='default')
|
||||
env = Environment(name='Test', topologies={'default': G}, agents={'network': aconfig})
|
||||
lst = list(env.network_agents)
|
||||
|
||||
a2 = env.find_one(node_id=2)
|
||||
a3 = env.find_one(node_id=3)
|
||||
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
||||
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
||||
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
||||
assert len(a3.subgraph(agent_class=agents.NetworkAgent)) == 3
|
||||
|
||||
def test_templates(self):
|
||||
'''Loading a template should result in several configs'''
|
||||
configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
|
||||
@ -236,14 +215,18 @@ class TestMain(TestCase):
|
||||
'name': 'until_sim',
|
||||
'model_params': {
|
||||
'network_params': {},
|
||||
'agent_class': 'CounterModel',
|
||||
'agents': {
|
||||
'fixed': [{
|
||||
'agent_class': agents.BaseAgent,
|
||||
}]
|
||||
},
|
||||
},
|
||||
'max_time': 2,
|
||||
'num_trials': 50,
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
runs = list(s.run_simulation(dry_run=True))
|
||||
over = list(x.now for x in runs if x.now>2)
|
||||
over = list(x.now for x in runs if x.now > 2)
|
||||
assert len(runs) == config['num_trials']
|
||||
assert len(over) == 0
|
||||
|
||||
|
@ -6,7 +6,8 @@ import networkx as nx
|
||||
|
||||
from os.path import join
|
||||
|
||||
from soil import network, environment
|
||||
from soil import config, network, environment, agents, simulation
|
||||
from test_main import CustomAgent
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
EXAMPLES = join(ROOT, '..', 'examples')
|
||||
@ -60,22 +61,53 @@ class TestNetwork(TestCase):
|
||||
G = nx.random_geometric_graph(20, 0.1)
|
||||
env = environment.NetworkEnvironment(topology=G)
|
||||
f = io.BytesIO()
|
||||
env.dump_gexf(f)
|
||||
assert env.topologies['default']
|
||||
network.dump_gexf(env.topologies['default'], f)
|
||||
|
||||
def test_networkenvironment_creation(self):
|
||||
"""Networkenvironment should accept netconfig as parameters"""
|
||||
model_params = {
|
||||
'topologies': {
|
||||
'default': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
}
|
||||
},
|
||||
'agents': {
|
||||
'topology': 'default',
|
||||
'distribution': [{
|
||||
'agent_class': CustomAgent,
|
||||
}]
|
||||
}
|
||||
}
|
||||
env = environment.Environment(**model_params)
|
||||
assert env.topologies
|
||||
env.step()
|
||||
assert len(env.topologies['default']) == 2
|
||||
assert len(env.agents) == 2
|
||||
assert env.agents[1].count_agents(state_id='normal') == 2
|
||||
assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||
assert env.agents[0].neighbors == 1
|
||||
|
||||
def test_custom_agent_neighbors(self):
|
||||
"""Allow for search of neighbors with a certain state_id"""
|
||||
config = {
|
||||
'network_params': {
|
||||
'model_params': {
|
||||
'topologies': {
|
||||
'default': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
'network_agents': [{
|
||||
'agent_class': CustomAgent,
|
||||
'weight': 1
|
||||
|
||||
}],
|
||||
'max_time': 10,
|
||||
'environment_params': {
|
||||
}
|
||||
},
|
||||
'agents': {
|
||||
'topology': 'default',
|
||||
'distribution': [
|
||||
{
|
||||
'weight': 1,
|
||||
'agent_class': CustomAgent
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
'max_time': 10,
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
@ -83,3 +115,19 @@ class TestNetwork(TestCase):
|
||||
assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||
assert env.agents[0].neighbors == 1
|
||||
|
||||
def test_subgraph(self):
|
||||
'''An agent should be able to subgraph the global topology'''
|
||||
G = nx.Graph()
|
||||
G.add_node(3)
|
||||
G.add_edge(1, 2)
|
||||
distro = agents.calculate_distribution(agent_class=agents.NetworkAgent)
|
||||
aconfig = config.AgentConfig(distribution=distro, topology='default')
|
||||
env = environment.Environment(name='Test', topologies={'default': G}, agents=aconfig)
|
||||
lst = list(env.network_agents)
|
||||
|
||||
a2 = env.find_one(node_id=2)
|
||||
a3 = env.find_one(node_id=3)
|
||||
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
||||
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
||||
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
||||
assert len(a3.subgraph(agent_class=agents.NetworkAgent)) == 3
|
||||
|
Loading…
Reference in New Issue
Block a user