From cd62c23cb955e6695d30fead1556845bd2454ae9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Thu, 13 Oct 2022 22:43:16 +0200 Subject: [PATCH] WIP: all tests pass --- CHANGELOG.md | 2 + docs/soil-vs.rst | 12 + examples/complete.yml | 81 ++- examples/complete_opt2.yml | 63 ++ .../custom_generator/custom_generator.yml | 2 +- examples/custom_generator/mymodule.py | 5 +- examples/mesa/mesa.yml | 22 +- examples/newsspread/NewsSpread.yml | 10 +- examples/newsspread/newsspread.py | 16 +- examples/pubcrawl/pubcrawl.py | 18 +- examples/pubcrawl/pubcrawl.yml | 2 +- examples/rabbits/README.md | 4 + examples/rabbits/basic/rabbit_agents.py | 130 ++++ examples/rabbits/basic/rabbits.yml | 41 ++ examples/rabbits/improved/rabbit_agents.py | 130 ++++ examples/rabbits/improved/rabbits.yml | 41 ++ examples/rabbits/rabbit_agents.py | 133 ---- examples/rabbits/rabbits.yml | 20 - examples/template.yml | 24 +- examples/terrorism/TerroristNetworkModel.py | 26 +- examples/terrorism/TerroristNetworkModel.yml | 50 +- examples/torvalds.yml | 23 +- requirements.txt | 5 +- soil/__init__.py | 79 ++- soil/agents/CounterModel.py | 20 +- soil/agents/__init__.py | 641 ++++++++++-------- soil/analysis.py | 206 ------ soil/config.py | 167 +++-- soil/debugging.py | 151 +++++ soil/environment.py | 202 +++--- soil/exporters.py | 58 +- soil/network.py | 19 +- soil/serialization.py | 31 +- soil/simulation.py | 144 ++-- soil/time.py | 30 +- soil/utils.py | 39 +- tests/complete_converted.yml | 65 +- tests/old_complete.yml | 17 +- tests/test_agents.py | 6 +- tests/test_analysis.py | 91 --- tests/test_config.py | 21 +- tests/test_examples.py | 27 +- tests/test_exporters.py | 37 +- tests/test_history.py | 128 ---- tests/test_main.py | 45 +- tests/test_network.py | 70 +- 46 files changed, 1720 insertions(+), 1434 deletions(-) create mode 100644 docs/soil-vs.rst create mode 100644 examples/complete_opt2.yml create mode 100644 examples/rabbits/README.md create mode 100644 examples/rabbits/basic/rabbit_agents.py create mode 100644 examples/rabbits/basic/rabbits.yml create mode 100644 examples/rabbits/improved/rabbit_agents.py create mode 100644 examples/rabbits/improved/rabbits.yml delete mode 100644 examples/rabbits/rabbit_agents.py delete mode 100644 examples/rabbits/rabbits.yml delete mode 100644 soil/analysis.py create mode 100644 soil/debugging.py delete mode 100644 tests/test_analysis.py delete mode 100644 tests/test_history.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 92c457e..a0a8a2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [0.3 UNRELEASED] +### Added +* Simple debugging capabilities, with a custom `pdb.Debugger` subclass that exposes commands to list agents and their status and set breakpoints on states (for FSM agents) ### Changed * Configuration schema is very different now. Check `soil.config` for more information. We are also using Pydantic for (de)serialization. * There may be more than one topology/network in the simulation diff --git a/docs/soil-vs.rst b/docs/soil-vs.rst new file mode 100644 index 0000000..53b6891 --- /dev/null +++ b/docs/soil-vs.rst @@ -0,0 +1,12 @@ +### MESA + +Starting with version 0.3, Soil has been redesigned to complement Mesa, while remaining compatible with it. 
+That means that every component in Soil (i.e., Models, Environments, etc.) can be mixed with existing mesa components. +In fact, there are examples that show how that integration may be used, in the `examples/mesa` folder in the repository. + +Here are some reasons to use Soil instead of plain mesa: + +- Less boilerplate for common scenarios (by some definitions of common) +- Functions to automatically populate a topology with an agent distribution (i.e., different ratios of agent class and state) +- The `soil.Simulation` class allows you to run multiple instances of the same experiment (i.e., multiple trials with the same parameters but a different randomness seed) +- Reporting functions that aggregate multiple diff --git a/examples/complete.yml b/examples/complete.yml index d33cbaf..2677c22 100644 --- a/examples/complete.yml +++ b/examples/complete.yml @@ -1,46 +1,54 @@ --- version: '2' -general: - id: simple - group: tests - dir_path: "/tmp/" - num_trials: 3 - max_time: 100 - interval: 1 - seed: "CompleteSeed!" -topologies: - default: - params: - generator: complete_graph - n: 10 - another_graph: - params: - generator: complete_graph - n: 2 -environment: - environment_class: Environment - params: - am_i_complete: true -agents: -# Agents are split several groups, each with its own definition - default: # This is a special group. Its values will be used as default values for the rest of the groups +name: simple +group: tests +dir_path: "/tmp/" +num_trials: 3 +max_steps: 100 +interval: 1 +seed: "CompleteSeed!" +model_class: Environment +model_params: + am_i_complete: true + topologies: + default: + params: + generator: complete_graph + n: 10 + another_graph: + params: + generator: complete_graph + n: 2 + environment: + agents: agent_class: CounterModel topology: default state: times: 1 - environment: - # In this group we are not specifying any topology - topology: False + # In this group we are not specifying any topology fixed: - name: 'Environment Agent 1' - agent_class: CounterModel + agent_class: BaseAgent + group: environment + topology: null + hidden: true state: times: 10 - general_counters: - topology: default + - agent_class: CounterModel + id: 0 + group: other_counters + topology: another_graph + state: + times: 1 + total: 0 + - agent_class: CounterModel + topology: another_graph + group: other_counters + id: 1 distribution: - agent_class: CounterModel weight: 1 + group: general_counters state: times: 3 - agent_class: AggregatedCounter @@ -51,16 +59,3 @@ agents: n: 2 state: times: 5 - - other_counters: - topology: another_graph - fixed: - - agent_class: CounterModel - id: 0 - state: - times: 1 - total: 0 - - agent_class: CounterModel - id: 1 - # If not specified, it will use the state set in the default - # state: diff --git a/examples/complete_opt2.yml b/examples/complete_opt2.yml new file mode 100644 index 0000000..b4acc26 --- /dev/null +++ b/examples/complete_opt2.yml @@ -0,0 +1,63 @@ +--- +version: '2' +id: simple +group: tests +dir_path: "/tmp/" +num_trials: 3 +max_steps: 100 +interval: 1 +seed: "CompleteSeed!" 
+model_class: "soil.Environment" +model_params: + topologies: + default: + params: + generator: complete_graph + n: 10 + another_graph: + params: + generator: complete_graph + n: 2 + agents: + # The values here will be used as default values for any agent + agent_class: CounterModel + topology: default + state: + times: 1 + # This specifies a distribution of agents, each with a `weight` or an explicit number of agents + distribution: + - agent_class: CounterModel + weight: 1 + # This is inherited from the default settings + #topology: default + state: + times: 3 + - agent_class: AggregatedCounter + topology: default + weight: 0.2 + fixed: + - name: 'Environment Agent 1' + # All the other agents will assigned to the 'default' group + group: environment + # Do not count this agent towards total limits + hidden: true + agent_class: soil.BaseAgent + topology: null + state: + times: 10 + - agent_class: CounterModel + topology: another_graph + id: 0 + state: + times: 1 + total: 0 + - agent_class: CounterModel + topology: another_graph + id: 1 + override: + # 2 agents that match this filter will be updated to match the state {times: 5} + - filter: + agent_class: AggregatedCounter + n: 2 + state: + times: 5 diff --git a/examples/custom_generator/custom_generator.yml b/examples/custom_generator/custom_generator.yml index 12c130d..81f0314 100644 --- a/examples/custom_generator/custom_generator.yml +++ b/examples/custom_generator/custom_generator.yml @@ -2,7 +2,7 @@ name: custom-generator description: Using a custom generator for the network num_trials: 3 -max_time: 100 +max_steps: 100 interval: 1 network_params: generator: mymodule.mygenerator diff --git a/examples/custom_generator/mymodule.py b/examples/custom_generator/mymodule.py index ef3bacc..85226e0 100644 --- a/examples/custom_generator/mymodule.py +++ b/examples/custom_generator/mymodule.py @@ -1,4 +1,5 @@ from networkx import Graph +import random import networkx as nx def mygenerator(n=5, n_edges=5): @@ -13,9 +14,9 @@ def mygenerator(n=5, n_edges=5): for i in range(n_edges): nodes = list(G.nodes) - n_in = self.random.choice(nodes) + n_in = random.choice(nodes) nodes.remove(n_in) # Avoid loops - n_out = self.random.choice(nodes) + n_out = random.choice(nodes) G.add_edge(n_in, n_out) return G diff --git a/examples/mesa/mesa.yml b/examples/mesa/mesa.yml index a1572f2..6bdae6f 100644 --- a/examples/mesa/mesa.yml +++ b/examples/mesa/mesa.yml @@ -3,17 +3,21 @@ name: mesa_sim group: tests dir_path: "/tmp" num_trials: 3 -max_time: 100 +max_steps: 100 interval: 1 seed: '1' -network_params: - generator: social_wealth.graph_generator - n: 5 -network_agents: - - agent_class: social_wealth.SocialMoneyAgent - weight: 1 -environment_class: social_wealth.MoneyEnv -environment_params: +model_class: social_wealth.MoneyEnv +model_params: + topologies: + default: + params: + generator: social_wealth.graph_generator + n: 5 + agents: + distribution: + - agent_class: social_wealth.SocialMoneyAgent + topology: default + weight: 1 mesa_agent_class: social_wealth.MoneyAgent N: 10 width: 50 diff --git a/examples/newsspread/NewsSpread.yml b/examples/newsspread/NewsSpread.yml index 10ae525..d80a5d5 100644 --- a/examples/newsspread/NewsSpread.yml +++ b/examples/newsspread/NewsSpread.yml @@ -5,7 +5,7 @@ environment_params: prob_neighbor_spread: 0.0 prob_tv_spread: 0.01 interval: 1 -max_time: 300 +max_steps: 300 name: Sim_all_dumb network_agents: - agent_class: newsspread.DumbViewer @@ -28,7 +28,7 @@ environment_params: prob_neighbor_spread: 0.0 prob_tv_spread: 0.01 
interval: 1 -max_time: 300 +max_steps: 300 name: Sim_half_herd network_agents: - agent_class: newsspread.DumbViewer @@ -59,7 +59,7 @@ environment_params: prob_neighbor_spread: 0.0 prob_tv_spread: 0.01 interval: 1 -max_time: 300 +max_steps: 300 name: Sim_all_herd network_agents: - agent_class: newsspread.HerdViewer @@ -85,7 +85,7 @@ environment_params: prob_tv_spread: 0.01 prob_neighbor_cure: 0.1 interval: 1 -max_time: 300 +max_steps: 300 name: Sim_wise_herd network_agents: - agent_class: newsspread.HerdViewer @@ -110,7 +110,7 @@ environment_params: prob_tv_spread: 0.01 prob_neighbor_cure: 0.1 interval: 1 -max_time: 300 +max_steps: 300 name: Sim_all_wise network_agents: - agent_class: newsspread.WiseViewer diff --git a/examples/newsspread/newsspread.py b/examples/newsspread/newsspread.py index c3b5e6b..14d666f 100644 --- a/examples/newsspread/newsspread.py +++ b/examples/newsspread/newsspread.py @@ -16,13 +16,13 @@ class DumbViewer(FSM, NetworkAgent): @state def neutral(self): if self['has_tv']: - if prob(self.env['prob_tv_spread']): + if self.prob(self.model['prob_tv_spread']): return self.infected @state def infected(self): for neighbor in self.get_neighboring_agents(state_id=self.neutral.id): - if prob(self.env['prob_neighbor_spread']): + if self.prob(self.model['prob_neighbor_spread']): neighbor.infect() def infect(self): @@ -44,9 +44,9 @@ class HerdViewer(DumbViewer): '''Notice again that this is NOT a state. See DumbViewer.infect for reference''' infected = self.count_neighboring_agents(state_id=self.infected.id) total = self.count_neighboring_agents() - prob_infect = self.env['prob_neighbor_spread'] * infected/total + prob_infect = self.model['prob_neighbor_spread'] * infected/total self.debug('prob_infect', prob_infect) - if prob(prob_infect): + if self.prob(prob_infect): self.set_state(self.infected) @@ -63,9 +63,9 @@ class WiseViewer(HerdViewer): @state def cured(self): - prob_cure = self.env['prob_neighbor_cure'] + prob_cure = self.model['prob_neighbor_cure'] for neighbor in self.get_neighboring_agents(state_id=self.infected.id): - if prob(prob_cure): + if self.prob(prob_cure): try: neighbor.cure() except AttributeError: @@ -80,7 +80,7 @@ class WiseViewer(HerdViewer): 1.0) infected = max(self.count_neighboring_agents(self.infected.id), 1.0) - prob_cure = self.env['prob_neighbor_cure'] * (cured/infected) - if prob(prob_cure): + prob_cure = self.model['prob_neighbor_cure'] * (cured/infected) + if self.prob(prob_cure): return self.cured return self.set_state(super().infected) diff --git a/examples/pubcrawl/pubcrawl.py b/examples/pubcrawl/pubcrawl.py index 7fc5b5f..e100893 100644 --- a/examples/pubcrawl/pubcrawl.py +++ b/examples/pubcrawl/pubcrawl.py @@ -60,12 +60,10 @@ class Patron(FSM, NetworkAgent): ''' level = logging.DEBUG - defaults = { - 'pub': None, - 'drunk': False, - 'pints': 0, - 'max_pints': 3, - } + pub = None + drunk = False + pints = 0 + max_pints = 3 @default_state @state @@ -89,9 +87,9 @@ class Patron(FSM, NetworkAgent): return self.sober_in_pub self.debug('I am looking for a pub') group = list(self.get_neighboring_agents()) - for pub in self.env.available_pubs(): + for pub in self.model.available_pubs(): self.debug('We\'re trying to get into {}: total: {}'.format(pub, len(group))) - if self.env.enter(pub, self, *group): + if self.model.enter(pub, self, *group): self.info('We\'re all {} getting in {}!'.format(len(group), pub)) return self.sober_in_pub @@ -128,7 +126,7 @@ class Patron(FSM, NetworkAgent): success depend on both agents' openness. 
''' if force or self['openness'] > self.random.random(): - self.env.add_edge(self, other_agent) + self.model.add_edge(self, other_agent) self.info('Made some friend {}'.format(other_agent)) return True return False @@ -150,7 +148,7 @@ class Patron(FSM, NetworkAgent): return befriended -class Police(FSM, NetworkAgent): +class Police(FSM): '''Simple agent to take drunk people out of pubs.''' level = logging.INFO diff --git a/examples/pubcrawl/pubcrawl.yml b/examples/pubcrawl/pubcrawl.yml index 1f83f95..220b705 100644 --- a/examples/pubcrawl/pubcrawl.yml +++ b/examples/pubcrawl/pubcrawl.yml @@ -1,7 +1,7 @@ --- name: pubcrawl num_trials: 3 -max_time: 10 +max_steps: 10 dump: false network_params: # Generate 100 empty nodes. They will be assigned a network agent diff --git a/examples/rabbits/README.md b/examples/rabbits/README.md new file mode 100644 index 0000000..42b6011 --- /dev/null +++ b/examples/rabbits/README.md @@ -0,0 +1,4 @@ +There are two similar implementations of this simulation. + +- `basic`. Using simple primites +- `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions. diff --git a/examples/rabbits/basic/rabbit_agents.py b/examples/rabbits/basic/rabbit_agents.py new file mode 100644 index 0000000..2d5cf40 --- /dev/null +++ b/examples/rabbits/basic/rabbit_agents.py @@ -0,0 +1,130 @@ +from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent +from soil.time import Delta +from enum import Enum +from collections import Counter +import logging +import math + + +class RabbitModel(FSM, NetworkAgent): + + sexual_maturity = 30 + life_expectancy = 300 + + @default_state + @state + def newborn(self): + self.info('I am a newborn.') + self.age = 0 + self.offspring = 0 + return self.youngling + + @state + def youngling(self): + self.age += 1 + if self.age >= self.sexual_maturity: + self.info(f'I am fertile! 
My age is {self.age}') + return self.fertile + + @state + def fertile(self): + raise Exception("Each subclass should define its fertile state") + + @state + def dead(self): + self.die() + + +class Male(RabbitModel): + max_females = 5 + mating_prob = 0.001 + + @state + def fertile(self): + self.age += 1 + + if self.age > self.life_expectancy: + return self.dead + + # Males try to mate + for f in self.model.agents(agent_class=Female, + state_id=Female.fertile.id, + limit=self.max_females): + self.debug('FOUND A FEMALE: ', repr(f), self.mating_prob) + if self.prob(self['mating_prob']): + f.impregnate(self) + break # Take a break + + +class Female(RabbitModel): + gestation = 100 + + @state + def fertile(self): + # Just wait for a Male + self.age += 1 + if self.age > self.life_expectancy: + return self.dead + + def impregnate(self, male): + self.info(f'{repr(male)} impregnating female {repr(self)}') + self.mate = male + self.pregnancy = -1 + self.set_state(self.pregnant, when=self.now) + self.number_of_babies = int(8+4*self.random.random()) + self.debug('I am pregnant') + + @state + def pregnant(self): + self.age += 1 + self.pregnancy += 1 + + if self.prob(self.age / self.life_expectancy): + return self.die() + + if self.pregnancy >= self.gestation: + self.info('Having {} babies'.format(self.number_of_babies)) + for i in range(self.number_of_babies): + state = {} + agent_class = self.random.choice([Male, Female]) + child = self.model.add_node(agent_class=agent_class, + topology=self.topology, + **state) + child.add_edge(self) + try: + child.add_edge(self.mate) + self.model.agents[self.mate].offspring += 1 + except ValueError: + self.debug('The father has passed away') + + self.offspring += 1 + self.mate = None + return self.fertile + + @state + def dead(self): + super().dead() + if 'pregnancy' in self and self['pregnancy'] > -1: + self.info('A mother has died carrying a baby!!') + + +class RandomAccident(BaseAgent): + + level = logging.INFO + + def step(self): + rabbits_alive = self.model.topology.number_of_nodes() + + if not rabbits_alive: + return self.die() + + prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) + self.debug('Killing some rabbits with prob={}!'.format(prob_death)) + for i in self.iter_agents(agent_class=RabbitModel): + if i.state.id == i.dead.id: + continue + if self.prob(prob_death): + self.info('I killed a rabbit: {}'.format(i.id)) + rabbits_alive -= 1 + i.set_state(i.dead) + self.debug('Rabbits alive: {}'.format(rabbits_alive)) diff --git a/examples/rabbits/basic/rabbits.yml b/examples/rabbits/basic/rabbits.yml new file mode 100644 index 0000000..facaefe --- /dev/null +++ b/examples/rabbits/basic/rabbits.yml @@ -0,0 +1,41 @@ +--- +version: '2' +name: rabbits_basic +num_trials: 1 +seed: MySeed +description: null +group: null +interval: 1.0 +max_time: 100 +model_class: soil.environment.Environment +model_params: + agents: + topology: default + agent_class: rabbit_agents.RabbitModel + distribution: + - agent_class: rabbit_agents.Male + topology: default + weight: 1 + - agent_class: rabbit_agents.Female + topology: default + weight: 1 + fixed: + - agent_class: rabbit_agents.RandomAccident + topology: null + hidden: true + state: + group: environment + state: + group: network + mating_prob: 0.1 + prob_death: 0.001 + topologies: + default: + topology: + directed: true + links: [] + nodes: + - id: 1 + - id: 0 +extra: + visualization_params: {} diff --git a/examples/rabbits/improved/rabbit_agents.py 
b/examples/rabbits/improved/rabbit_agents.py new file mode 100644 index 0000000..d97b7e7 --- /dev/null +++ b/examples/rabbits/improved/rabbit_agents.py @@ -0,0 +1,130 @@ +from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent +from soil.time import Delta, When, NEVER +from enum import Enum +import logging +import math + + +class RabbitModel(FSM, NetworkAgent): + + mating_prob = 0.005 + offspring = 0 + birth = None + + sexual_maturity = 3 + life_expectancy = 30 + + @default_state + @state + def newborn(self): + self.birth = self.now + self.info(f'I am a newborn.') + self.model['rabbits_alive'] = self.model.get('rabbits_alive', 0) + 1 + + # Here we can skip the `youngling` state by using a coroutine/generator. + while self.age < self.sexual_maturity: + interval = self.sexual_maturity - self.age + yield Delta(interval) + + self.info(f'I am fertile! My age is {self.age}') + return self.fertile + + @property + def age(self): + return self.now - self.birth + + @state + def fertile(self): + raise Exception("Each subclass should define its fertile state") + + def step(self): + super().step() + if self.prob(self.age / self.life_expectancy): + return self.die() + + +class Male(RabbitModel): + + max_females = 5 + + @state + def fertile(self): + # Males try to mate + for f in self.model.agents(agent_class=Female, + state_id=Female.fertile.id, + limit=self.max_females): + self.debug('Found a female:', repr(f)) + if self.prob(self['mating_prob']): + f.impregnate(self) + break # Take a break, don't try to impregnate the rest + + +class Female(RabbitModel): + due_date = None + age_of_pregnancy = None + gestation = 10 + mate = None + + @state + def fertile(self): + return self.fertile, NEVER + + @state + def pregnant(self): + self.info('I am pregnant') + if self.age > self.life_expectancy: + return self.dead + + self.due_date = self.now + self.gestation + + number_of_babies = int(8+4*self.random.random()) + + while self.now < self.due_date: + yield When(self.due_date) + + self.info('Having {} babies'.format(number_of_babies)) + for i in range(number_of_babies): + agent_class = self.random.choice([Male, Female]) + child = self.model.add_node(agent_class=agent_class, + topology=self.topology) + self.model.add_edge(self, child) + self.model.add_edge(self.mate, child) + self.offspring += 1 + self.model.agents[self.mate].offspring += 1 + self.mate = None + self.due_date = None + return self.fertile + + @state + def dead(self): + super().dead() + if self.due_date is not None: + self.info('A mother has died carrying a baby!!') + + def impregnate(self, male): + self.info(f'{repr(male)} impregnating female {repr(self)}') + self.mate = male + self.set_state(self.pregnant, when=self.now) + + +class RandomAccident(BaseAgent): + + level = logging.INFO + + def step(self): + rabbits_total = self.model.topology.number_of_nodes() + if 'rabbits_alive' not in self.model: + self.model['rabbits_alive'] = 0 + rabbits_alive = self.model.get('rabbits_alive', rabbits_total) + prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) + self.debug('Killing some rabbits with prob={}!'.format(prob_death)) + for i in self.model.network_agents: + if i.state.id == i.dead.id: + continue + if self.prob(prob_death): + self.info('I killed a rabbit: {}'.format(i.id)) + rabbits_alive = self.model['rabbits_alive'] = rabbits_alive -1 + i.set_state(i.dead) + self.debug('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) + if self.model.count_agents(state_id=RabbitModel.dead.id) == 
self.model.topology.number_of_nodes(): + self.die() diff --git a/examples/rabbits/improved/rabbits.yml b/examples/rabbits/improved/rabbits.yml new file mode 100644 index 0000000..ce5dd68 --- /dev/null +++ b/examples/rabbits/improved/rabbits.yml @@ -0,0 +1,41 @@ +--- +version: '2' +name: rabbits_improved +num_trials: 1 +seed: MySeed +description: null +group: null +interval: 1.0 +max_time: 100 +model_class: soil.environment.Environment +model_params: + agents: + topology: default + agent_class: rabbit_agents.RabbitModel + distribution: + - agent_class: rabbit_agents.Male + topology: default + weight: 1 + - agent_class: rabbit_agents.Female + topology: default + weight: 1 + fixed: + - agent_class: rabbit_agents.RandomAccident + topology: null + hidden: true + state: + group: environment + state: + group: network + mating_prob: 0.1 + prob_death: 0.001 + topologies: + default: + topology: + directed: true + links: [] + nodes: + - id: 1 + - id: 0 +extra: + visualization_params: {} diff --git a/examples/rabbits/rabbit_agents.py b/examples/rabbits/rabbit_agents.py deleted file mode 100644 index df371b2..0000000 --- a/examples/rabbits/rabbit_agents.py +++ /dev/null @@ -1,133 +0,0 @@ -from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent -from enum import Enum -import logging -import math - - -class Genders(Enum): - male = 'male' - female = 'female' - - -class RabbitModel(FSM, NetworkAgent): - - defaults = { - 'age': 0, - 'gender': Genders.male.value, - 'mating_prob': 0.001, - 'offspring': 0, - } - - sexual_maturity = 3 #4*30 - life_expectancy = 365 * 3 - gestation = 33 - pregnancy = -1 - max_females = 5 - - @default_state - @state - def newborn(self): - self.debug(f'I am a newborn at age {self["age"]}') - self['age'] += 1 - - if self['age'] >= self.sexual_maturity: - self.debug('I am fertile!') - return self.fertile - @state - def fertile(self): - raise Exception("Each subclass should define its fertile state") - - @state - def dead(self): - self.info('Agent {} is dying'.format(self.id)) - self.die() - - -class Male(RabbitModel): - - @state - def fertile(self): - self['age'] += 1 - if self['age'] > self.life_expectancy: - return self.dead - - if self['gender'] == Genders.female.value: - return - - # Males try to mate - for f in self.get_agents(state_id=Female.fertile.id, - agent_class=Female, - limit_neighbors=False, - limit=self.max_females): - r = self.random.random() - if r < self['mating_prob']: - self.impregnate(f) - break # Take a break - def impregnate(self, whom): - whom['pregnancy'] = 0 - whom['mate'] = self.id - whom.set_state(whom.pregnant) - self.debug('{} impregnating: {}. 
{}'.format(self.id, whom.id, whom.state)) - -class Female(RabbitModel): - @state - def fertile(self): - # Just wait for a Male - pass - - @state - def pregnant(self): - self['age'] += 1 - if self['age'] > self.life_expectancy: - return self.dead - - self['pregnancy'] += 1 - self.debug('Pregnancy: {}'.format(self['pregnancy'])) - if self['pregnancy'] >= self.gestation: - number_of_babies = int(8+4*self.random.random()) - self.info('Having {} babies'.format(number_of_babies)) - for i in range(number_of_babies): - state = {} - state['gender'] = self.random.choice(list(Genders)).value - child = self.env.add_node(self.__class__, state) - self.env.add_edge(self.id, child.id) - self.env.add_edge(self['mate'], child.id) - # self.add_edge() - self.debug('A BABY IS COMING TO LIFE') - self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1 - self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive'])) - self['offspring'] += 1 - self.env.get_agent(self['mate'])['offspring'] += 1 - del self['mate'] - self['pregnancy'] = -1 - return self.fertile - - @state - def dead(self): - super().dead() - if 'pregnancy' in self and self['pregnancy'] > -1: - self.info('A mother has died carrying a baby!!') - - -class RandomAccident(BaseAgent): - - level = logging.DEBUG - - def step(self): - rabbits_total = self.env.topology.number_of_nodes() - if 'rabbits_alive' not in self.env: - self.env['rabbits_alive'] = 0 - rabbits_alive = self.env.get('rabbits_alive', rabbits_total) - prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) - self.debug('Killing some rabbits with prob={}!'.format(prob_death)) - for i in self.env.network_agents: - if i.state['id'] == i.dead.id: - continue - if self.prob(prob_death): - self.debug('I killed a rabbit: {}'.format(i.id)) - rabbits_alive = self.env['rabbits_alive'] = rabbits_alive -1 - self.log('Rabbits alive: {}'.format(self.env['rabbits_alive'])) - i.set_state(i.dead) - self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) - if self.env.count_agents(state_id=RabbitModel.dead.id) == self.env.topology.number_of_nodes(): - self.die() diff --git a/examples/rabbits/rabbits.yml b/examples/rabbits/rabbits.yml deleted file mode 100644 index 1b1d148..0000000 --- a/examples/rabbits/rabbits.yml +++ /dev/null @@ -1,20 +0,0 @@ ---- -name: rabbits_example -max_time: 100 -interval: 1 -seed: MySeed -agent_class: rabbit_agents.RabbitModel -environment_agents: - - agent_class: rabbit_agents.RandomAccident -environment_params: - prob_death: 0.001 -default_state: - mating_prob: 0.1 -topology: - nodes: - - id: 1 - agent_class: rabbit_agents.Male - - id: 0 - agent_class: rabbit_agents.Female - directed: true - links: [] diff --git a/examples/template.yml b/examples/template.yml index a307eff..2b1f2b7 100644 --- a/examples/template.yml +++ b/examples/template.yml @@ -6,20 +6,20 @@ template: group: simple num_trials: 1 interval: 1 - max_time: 2 + max_steps: 2 seed: "CompleteSeed!" 
dump: false - network_params: - generator: complete_graph - n: 10 - network_agents: - - agent_class: CounterModel - weight: "{{ x1 }}" - state: - state_id: 0 - - agent_class: AggregatedCounter - weight: "{{ 1 - x1 }}" - environment_params: + model_params: + network_params: + generator: complete_graph + n: 10 + network_agents: + - agent_class: CounterModel + weight: "{{ x1 }}" + state: + state_id: 0 + - agent_class: AggregatedCounter + weight: "{{ 1 - x1 }}" name: "{{ x3 }}" skip_test: true vars: diff --git a/examples/terrorism/TerroristNetworkModel.py b/examples/terrorism/TerroristNetworkModel.py index 2fa6de4..bf5045f 100644 --- a/examples/terrorism/TerroristNetworkModel.py +++ b/examples/terrorism/TerroristNetworkModel.py @@ -81,6 +81,26 @@ class TerroristSpreadModel(FSM, Geo): return return self.leader + def ego_search(self, steps=1, center=False, node=None, **kwargs): + '''Get a list of nodes in the ego network of *node* of radius *steps*''' + node = as_node(node if node is not None else self) + G = self.subgraph(**kwargs) + return nx.ego_graph(G, node, center=center, radius=steps).nodes() + + def degree(self, node, force=False): + node = as_node(node) + if force or (not hasattr(self.model, '_degree')) or getattr(self.model, '_last_step', 0) < self.now: + self.model._degree = nx.degree_centrality(self.G) + self.model._last_step = self.now + return self.model._degree[node] + + def betweenness(self, node, force=False): + node = as_node(node) + if force or (not hasattr(self.model, '_betweenness')) or getattr(self.model, '_last_step', 0) < self.now: + self.model._betweenness = nx.betweenness_centrality(self.G) + self.model._last_step = self.now + return self.model._betweenness[node] + class TrainingAreaModel(FSM, Geo): """ @@ -194,14 +214,14 @@ class TerroristNetworkModel(TerroristSpreadModel): break def get_distance(self, target): - source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id] - target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target] + source_x, source_y = nx.get_node_attributes(self.G, 'pos')[self.id] + target_x, target_y = nx.get_node_attributes(self.G, 'pos')[target] dx = abs( source_x - target_x ) dy = abs( source_y - target_y ) return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 ) def shortest_path_length(self, target): try: - return nx.shortest_path_length(self.topology, self.id, target) + return nx.shortest_path_length(self.G, self.id, target) except nx.NetworkXNoPath: return float('inf') diff --git a/examples/terrorism/TerroristNetworkModel.yml b/examples/terrorism/TerroristNetworkModel.yml index b5a3d09..f709766 100644 --- a/examples/terrorism/TerroristNetworkModel.yml +++ b/examples/terrorism/TerroristNetworkModel.yml @@ -1,31 +1,31 @@ name: TerroristNetworkModel_sim -max_time: 150 +max_steps: 150 num_trials: 1 -network_params: - generator: random_geometric_graph - radius: 0.2 - # generator: geographical_threshold_graph - # theta: 20 - n: 100 -network_agents: - - agent_class: TerroristNetworkModel.TerroristNetworkModel - weight: 0.8 - state: - id: civilian # Civilians - - agent_class: TerroristNetworkModel.TerroristNetworkModel - weight: 0.1 - state: - id: leader # Leaders - - agent_class: TerroristNetworkModel.TrainingAreaModel - weight: 0.05 - state: - id: terrorist # Terrorism - - agent_class: TerroristNetworkModel.HavenModel - weight: 0.05 - state: - id: civilian # Civilian +model_params: + network_params: + generator: random_geometric_graph + radius: 0.2 + # generator: geographical_threshold_graph + # theta: 20 + n: 100 + network_agents: 
+ - agent_class: TerroristNetworkModel.TerroristNetworkModel + weight: 0.8 + state: + id: civilian # Civilians + - agent_class: TerroristNetworkModel.TerroristNetworkModel + weight: 0.1 + state: + id: leader # Leaders + - agent_class: TerroristNetworkModel.TrainingAreaModel + weight: 0.05 + state: + id: terrorist # Terrorism + - agent_class: TerroristNetworkModel.HavenModel + weight: 0.05 + state: + id: civilian # Civilian -environment_params: # TerroristSpreadModel information_spread_intensity: 0.7 terrorist_additional_influence: 0.035 diff --git a/examples/torvalds.yml b/examples/torvalds.yml index 421e2ac..3073d8c 100644 --- a/examples/torvalds.yml +++ b/examples/torvalds.yml @@ -1,14 +1,15 @@ --- name: torvalds_example -max_time: 10 +max_steps: 10 interval: 2 -agent_class: CounterModel -default_state: - skill_level: 'beginner' -network_params: - path: 'torvalds.edgelist' -states: - Torvalds: - skill_level: 'God' - balkian: - skill_level: 'developer' +model_params: + agent_class: CounterModel + default_state: + skill_level: 'beginner' + network_params: + path: 'torvalds.edgelist' + states: + Torvalds: + skill_level: 'God' + balkian: + skill_level: 'developer' diff --git a/requirements.txt b/requirements.txt index 31f12d5..8383887 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,9 @@ networkx>=2.5 numpy matplotlib pyyaml>=5.1 -pandas>=0.23 +pandas>=1 SALib>=1.3 Jinja2 -Mesa>=0.8.9 +Mesa>=1 pydantic>=1.9 +sqlalchemy>=1.4 diff --git a/soil/__init__.py b/soil/__init__.py index 44b548f..9219e04 100644 --- a/soil/__init__.py +++ b/soil/__init__.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import importlib import sys import os -import pdb import logging +import traceback from .version import __version__ @@ -16,11 +18,10 @@ from . import agents from .simulation import * from .environment import Environment from . import serialization -from . import analysis from .utils import logger from .time import * -def main(): +def main(cfg='simulation.yml', **kwargs): import argparse from . import simulation @@ -29,7 +30,7 @@ def main(): parser = argparse.ArgumentParser(description='Run a SOIL simulation') parser.add_argument('file', type=str, nargs="?", - default='simulation.yml', + default=cfg, help='Configuration file for the simulation (e.g., YAML or JSON)') parser.add_argument('--version', action='store_true', help='Show version info and exit') @@ -39,6 +40,8 @@ def main(): help='Do not store the results of the simulation to disk, show in terminal instead.') parser.add_argument('--pdb', action='store_true', help='Use a pdb console in case of exception.') + parser.add_argument('--debug', action='store_true', + help='Run a customized version of a pdb console to debug a simulation.') parser.add_argument('--graph', '-g', action='store_true', help='Dump each trial\'s network topology as a GEXF graph. Defaults to false.') parser.add_argument('--csv', action='store_true', @@ -51,9 +54,22 @@ def main(): help='Run trials serially and synchronously instead of in parallel. Defaults to false.') parser.add_argument('-e', '--exporter', action='append', help='Export environment and/or simulations using this exporter') + parser.add_argument('--only-convert', '--convert', action='store_true', + help='Do not run the simulation, only convert the configuration file(s) and output them.') + + + parser.add_argument("--set", + metavar="KEY=VALUE", + action='append', + help="Set a number of parameters that will be passed to the simulation." + "(do not put spaces before or after the = sign). 
" + "If a value contains spaces, you should define " + "it with double quotes: " + 'foo="this is a sentence". Note that ' + "values are always treated as strings.") args = parser.parse_args() - logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper())) + logger.setLevel(getattr(logging, (args.level or 'INFO').upper())) if args.version: return @@ -65,9 +81,10 @@ def main(): logger.info('Loading config file: {}'.format(args.file)) - if args.pdb: + if args.pdb or args.debug: args.synchronous = True - + if args.debug: + os.environ['SOIL_DEBUG'] = 'true' try: exporters = list(args.exporter or ['default', ]) @@ -82,18 +99,48 @@ def main(): if not os.path.exists(args.file): logger.error('Please, input a valid file') return - simulation.run_from_config(args.file, - dry_run=args.dry_run, - exporters=exporters, - parallel=(not args.synchronous), - outdir=args.output, - exporter_params=exp_params) - except Exception: + for sim in simulation.iter_from_config(args.file): + if args.set: + for s in args.set: + k, v = s.split('=', 1)[:2] + v = eval(v) + tail, *head = k.rsplit('.', 1)[::-1] + target = sim + if head: + for part in head[0].split('.'): + try: + target = getattr(target, part) + except AttributeError: + target = target[part] + try: + setattr(target, tail, v) + except AttributeError: + target[tail] = v + + if args.only_convert: + print(sim.to_yaml()) + continue + + sim.run_simulation(dry_run=args.dry_run, + exporters=exporters, + parallel=(not args.synchronous), + outdir=args.output, + exporter_params=exp_params, + **kwargs) + + except Exception as ex: if args.pdb: - pdb.post_mortem() + from .debugging import post_mortem + print(traceback.format_exc()) + post_mortem() else: raise - +def easy(cfg, debug=False): + sim = simulation.from_config(cfg) + if debug or os.environ.get('SOIL_DEBUG'): + from .debugging import setup + setup(sys._getframe().f_back) + return sim if __name__ == '__main__': main() diff --git a/soil/agents/CounterModel.py b/soil/agents/CounterModel.py index d2edc1b..97c7356 100644 --- a/soil/agents/CounterModel.py +++ b/soil/agents/CounterModel.py @@ -7,15 +7,13 @@ class CounterModel(NetworkAgent): in each step and adds it to its state. """ - defaults = { - 'times': 0, - 'neighbors': 0, - 'total': 0 - } + times = 0 + neighbors = 0 + total = 0 def step(self): # Outside effects - total = len(list(self.env.agents)) + total = len(list(self.model.schedule._agents)) neighbors = len(list(self.get_neighboring_agents())) self['times'] = self.get('times', 0) + 1 self['neighbors'] = neighbors @@ -28,17 +26,15 @@ class AggregatedCounter(NetworkAgent): in each step and adds it to its state. """ - defaults = { - 'times': 0, - 'neighbors': 0, - 'total': 0 - } + times = 0 + neighbors = 0 + total = 0 def step(self): # Outside effects self['times'] += 1 neighbors = len(list(self.get_neighboring_agents())) self['neighbors'] += neighbors - total = len(list(self.env.agents)) + total = len(list(self.model.schedule.agents)) self['total'] += total self.debug('Running for step: {}. 
Total: {}'.format(self.now, total)) diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index 213c345..c7763f2 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from collections import OrderedDict, defaultdict from collections.abc import MutableMapping, Mapping, Set @@ -5,9 +7,13 @@ from abc import ABCMeta from copy import deepcopy, copy from functools import partial, wraps from itertools import islice, chain -import json +import inspect +import types +import textwrap import networkx as nx +from typing import Any + from mesa import Agent as MesaAgent from typing import Dict, List @@ -27,7 +33,31 @@ class DeadAgent(Exception): pass -class BaseAgent(MesaAgent, MutableMapping): +class MetaAgent(ABCMeta): + def __new__(mcls, name, bases, namespace): + defaults = {} + + # Re-use defaults from inherited classes + for i in bases: + if isinstance(i, MetaAgent): + defaults.update(i._defaults) + + new_nmspc = { + '_defaults': defaults, + } + + for attr, func in namespace.items(): + if isinstance(func, types.FunctionType) or isinstance(func, property) or attr[0] == '_': + new_nmspc[attr] = func + elif attr == 'defaults': + defaults.update(func) + else: + defaults[attr] = copy(func) + + return super().__new__(mcls=mcls, name=name, bases=bases, namespace=new_nmspc) + + +class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): """ A special type of Mesa Agent that: @@ -39,15 +69,12 @@ class BaseAgent(MesaAgent, MutableMapping): Any attribute that is not preceded by an underscore (`_`) will also be added to its state. """ - defaults = {} - def __init__(self, unique_id, model, name=None, interval=None, - **kwargs - ): + **kwargs): # Check for REQUIRED arguments # Initialize agent parameters if isinstance(unique_id, MesaAgent): @@ -58,15 +85,16 @@ class BaseAgent(MesaAgent, MutableMapping): self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id) - self._neighbors = None self.alive = True self.interval = interval or self.get('interval', 1) - self.logger = logging.getLogger(self.model.id).getChild(self.name) + logger = utils.logger.getChild(getattr(self.model, 'id', self.model)).getChild(self.name) + self.logger = logging.LoggerAdapter(logger, {'agent_name': self.name}) if hasattr(self, 'level'): self.logger.setLevel(self.level) - for (k, v) in self.defaults.items(): + + for (k, v) in self._defaults.items(): if not hasattr(self, k) or getattr(self, k) is None: setattr(self, k, deepcopy(v)) @@ -74,10 +102,6 @@ class BaseAgent(MesaAgent, MutableMapping): setattr(self, k, v) - for (k, v) in getattr(self, 'defaults', {}).items(): - if not hasattr(self, k) or getattr(self, k) is None: - setattr(self, k, v) - def __hash__(self): return hash(self.unique_id) @@ -89,14 +113,6 @@ class BaseAgent(MesaAgent, MutableMapping): def id(self): return self.unique_id - @property - def env(self): - return self.model - - @env.setter - def env(self, model): - self.model = model - @property def state(self): ''' @@ -108,19 +124,16 @@ class BaseAgent(MesaAgent, MutableMapping): @state.setter def state(self, value): + if not value: + return for k, v in value.items(): self[k] = v - @property - def environment_params(self): - return self.model.environment_params - - @environment_params.setter - def environment_params(self, value): - self.model.environment_params = value - def __getitem__(self, key): - return getattr(self, key) + try: + return getattr(self, key) + except AttributeError: + raise 
KeyError(f'key {key} not found in agent') def __delitem__(self, key): return delattr(self, key) @@ -138,11 +151,15 @@ class BaseAgent(MesaAgent, MutableMapping): return self.items() def keys(self): - return (k for k in self.__dict__ if k[0] != '_') - - def items(self): - return ((k, v) for (k, v) in self.__dict__.items() if k[0] != '_') + return (k for k in self.__dict__ if k[0] != '_' and k not in IGNORED_FIELDS) + def items(self, keys=None, skip=None): + keys = keys if keys is not None else self.keys() + it = ((k, self.get(k, None)) for k in keys) + if skip: + return filter(lambda x: x[0] not in skip, it) + return it + def get(self, key, default=None): return self[key] if key in self else default @@ -154,11 +171,9 @@ class BaseAgent(MesaAgent, MutableMapping): # No environment return None - def die(self, remove=False): - self.info(f'agent {self.unique_id} is dying') + def die(self): + self.info(f'agent dying') self.alive = False - if remove: - self.remove_node(self.id) return time.NEVER def step(self): @@ -170,7 +185,7 @@ class BaseAgent(MesaAgent, MutableMapping): if not self.logger.isEnabledFor(level): return message = message + " ".join(str(i) for i in args) - message = " @{:>3}: {}".format(self.now, message) + message = "[@{:>4}]\t{:>10}: {}".format(self.now, repr(self), message) for k, v in kwargs: message += " {k}={v} ".format(k, v) extra = {} @@ -179,33 +194,48 @@ class BaseAgent(MesaAgent, MutableMapping): extra['agent_name'] = self.name return self.logger.log(level, message, extra=extra) - - def debug(self, *args, **kwargs): return self.log(*args, level=logging.DEBUG, **kwargs) def info(self, *args, **kwargs): return self.log(*args, level=logging.INFO, **kwargs) -# Alias -# Agent = BaseAgent + def count_agents(self, **kwargs): + return len(list(self.get_agents(**kwargs))) + + def get_agents(self, *args, **kwargs): + it = self.iter_agents(*args, **kwargs) + return list(it) + + def iter_agents(self, *args, **kwargs): + yield from filter_agents(self.model.schedule._agents, *args, **kwargs) + + def __str__(self): + return self.to_str() + + def to_str(self, keys=None, skip=None, pretty=False): + content = dict(self.items(keys=keys)) + if pretty and content: + d = content + content = '\n' + for k, v in d.items(): + content += f'- {k}: {v}\n' + content = textwrap.indent(content, ' ') + return f"{repr(self)}{content}" + + def __repr__(self): + return f"{self.__class__.__name__}({self.unique_id})" + class NetworkAgent(BaseAgent): - @property - def topology(self): - return self.env.topology_for(self.unique_id) + def __init__(self, *args, topology, node_id, **kwargs): + super().__init__(*args, **kwargs) - @property - def node_id(self): - return self.env.node_id_for(self.unique_id) - - @property - def G(self): - return self.model.topologies[self._topology] - - def count_agents(self, **kwargs): - return len(list(self.get_agents(**kwargs))) + self.topology = topology + self.node_id = node_id + self.G = self.model.topologies[topology] + assert self.G def count_neighboring_agents(self, state_id=None, **kwargs): return len(self.get_neighboring_agents(state_id=state_id, **kwargs)) @@ -213,57 +243,47 @@ class NetworkAgent(BaseAgent): def get_neighboring_agents(self, state_id=None, **kwargs): return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs) - def get_agents(self, *args, limit=None, **kwargs): - it = self.iter_agents(*args, **kwargs) - if limit is not None: - it = islice(it, limit) - return list(it) - def iter_agents(self, unique_id=None, limit_neighbors=False, **kwargs): + 
unique_ids = None + if isinstance(unique_id, list): + unique_ids = set(unique_id) + elif unique_id is not None: + unique_ids = set([unique_id,]) + if limit_neighbors: - unique_id = [self.topology.nodes[node]['agent_id'] for node in self.topology.neighbors(self.node_id)] - if not unique_id: + neighbor_ids = set() + for node_id in self.G.neighbors(self.node_id): + if self.G.nodes[node_id].get('agent_id') is not None: + neighbor_ids.add(node_id) + if unique_ids: + unique_ids = unique_ids & neighbor_ids + else: + unique_ids = neighbor_ids + if not unique_ids: return - - yield from self.model.agents(unique_id=unique_id, **kwargs) - + unique_ids = list(unique_ids) + yield from super().iter_agents(unique_id=unique_ids, **kwargs) def subgraph(self, center=True, **kwargs): include = [self] if center else [] - G = self.topology.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include)) + G = self.G.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include)) return G - def remove_node(self, unique_id): - self.topology.remove_node(unique_id) + def remove_node(self): + self.G.remove_node(self.node_id) def add_edge(self, other, edge_attr_dict=None, *edge_attrs): - # return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs) - if self.unique_id not in self.topology.nodes(data=False): + if self.node_id not in self.G.nodes(data=False): raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id)) - if other.unique_id not in self.topology.nodes(data=False): + if other.node_id not in self.G.nodes(data=False): raise ValueError('{} not in list of existing agents in the network'.format(other)) - self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs) + self.G.add_edge(self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs) - def ego_search(self, steps=1, center=False, node=None, **kwargs): - '''Get a list of nodes in the ego network of *node* of radius *steps*''' - node = as_node(node if node is not None else self) - G = self.subgraph(**kwargs) - return nx.ego_graph(G, node, center=center, radius=steps).nodes() - - def degree(self, node, force=False): - node = as_node(node) - if force or (not hasattr(self.model, '_degree')) or getattr(self.model, '_last_step', 0) < self.now: - self.model._degree = nx.degree_centrality(self.topology) - self.model._last_step = self.now - return self.model._degree[node] - - def betweenness(self, node, force=False): - node = as_node(node) - if force or (not hasattr(self.model, '_betweenness')) or getattr(self.model, '_last_step', 0) < self.now: - self.model._betweenness = nx.betweenness_centrality(self.topology) - self.model._last_step = self.now - return self.model._betweenness[node] + def die(self, remove=True): + if remove: + self.remove_node() + return super().die() def state(name=None): @@ -273,24 +293,29 @@ def state(name=None): The default value for state_id is the current state id. The default value for when is the interval defined in the environment. 
''' + if inspect.isgeneratorfunction(func): + orig_func = func - @wraps(func) - def func_wrapper(self): - next_state = func(self) - when = None - if next_state is None: - return when - try: - next_state, when = next_state - except (ValueError, TypeError): - pass - if next_state: - self.set_state(next_state) - return when + @wraps(func) + def func(self): + while True: + if not self._coroutine: + self._coroutine = orig_func(self) + try: + n = next(self._coroutine) + if n: + return None, n + return + except StopIteration as ex: + self._coroutine = None + next_state = ex.value + if next_state is not None: + self.set_state(next_state) + return next_state - func_wrapper.id = name or func.__name__ - func_wrapper.is_default = False - return func_wrapper + func.id = name or func.__name__ + func.is_default = False + return func if callable(name): return decorator(name) @@ -303,60 +328,84 @@ def default_state(func): return func -class MetaFSM(ABCMeta): - def __init__(cls, name, bases, nmspc): - super(MetaFSM, cls).__init__(name, bases, nmspc) +class MetaFSM(MetaAgent): + def __new__(mcls, name, bases, namespace): states = {} # Re-use states from inherited classes default_state = None for i in bases: if isinstance(i, MetaFSM): - for state_id, state in i.states.items(): + for state_id, state in i._states.items(): if state.is_default: default_state = state states[state_id] = state # Add new states - for name, func in nmspc.items(): + for attr, func in namespace.items(): if hasattr(func, 'id'): if func.is_default: default_state = func states[func.id] = func - cls.default_state = default_state - cls.states = states + + namespace.update({ + '_default_state': default_state, + '_states': states, + }) + + return super(MetaFSM, mcls).__new__(mcls=mcls, name=name, bases=bases, namespace=namespace) class FSM(BaseAgent, metaclass=MetaFSM): def __init__(self, *args, **kwargs): super(FSM, self).__init__(*args, **kwargs) if not hasattr(self, 'state_id'): - if not self.default_state: + if not self._default_state: raise ValueError('No default state specified for {}'.format(self.unique_id)) - self.state_id = self.default_state.id + self.state_id = self._default_state.id + self._coroutine = None self.set_state(self.state_id) def step(self): self.debug(f'Agent {self.unique_id} @ state {self.state_id}') - interval = super().step() - if 'id' not in self.state: - if self.default_state: - self.set_state(self.default_state.id) - else: - raise Exception('{} has no valid state id or default state'.format(self)) - interval = self.states[self.state_id](self) or interval - if not self.alive: - return time.NEVER - return interval + default_interval = super().step() - def set_state(self, state): + next_state = self._states[self.state_id](self) + + when = None + try: + next_state, *when = next_state + if not when: + when = None + elif len(when) == 1: + when = when[0] + else: + raise ValueError('Too many values returned. 
Only state (and time) allowed') + except TypeError: + pass + + if next_state is not None: + self.set_state(next_state) + + return when or default_interval + + def set_state(self, state, when=None): if hasattr(state, 'id'): state = state.id - if state not in self.states: + if state not in self._states: raise ValueError('{} is not a valid state'.format(state)) self.state_id = state + if when is not None: + self.model.schedule.add(self, when=when) return state + def die(self): + return self.dead, super().die() + + @state + def dead(self): + return self.die() + def prob(prob, random): ''' @@ -476,81 +525,81 @@ def _convert_agent_classs(ind, to_string=False, **kwargs): return deserialize_definition(ind, **kwargs) -def _agent_from_definition(definition, random, value=-1, unique_id=None): - """Used in the initialization of agents given an agent distribution.""" - if value < 0: - value = random.random() - for d in sorted(definition, key=lambda x: x.get('threshold')): - threshold = d.get('threshold', (-1, -1)) - # Check if the definition matches by id (first) or by threshold - if (unique_id is not None and unique_id in d.get('ids', [])) or \ - (value >= threshold[0] and value < threshold[1]): - state = {} - if 'state' in d: - state = deepcopy(d['state']) - return d['agent_class'], state +# def _agent_from_definition(definition, random, value=-1, unique_id=None): +# """Used in the initialization of agents given an agent distribution.""" +# if value < 0: +# value = random.random() +# for d in sorted(definition, key=lambda x: x.get('threshold')): +# threshold = d.get('threshold', (-1, -1)) +# # Check if the definition matches by id (first) or by threshold +# if (unique_id is not None and unique_id in d.get('ids', [])) or \ +# (value >= threshold[0] and value < threshold[1]): +# state = {} +# if 'state' in d: +# state = deepcopy(d['state']) +# return d['agent_class'], state - raise Exception('Definition for value {} not found in: {}'.format(value, definition)) +# raise Exception('Definition for value {} not found in: {}'.format(value, definition)) -def _definition_to_dict(definition, random, size=None, default_state=None): - state = default_state or {} - agents = {} - remaining = {} - if size: - for ix in range(size): - remaining[ix] = copy(state) - else: - remaining = defaultdict(lambda x: copy(state)) +# def _definition_to_dict(definition, random, size=None, default_state=None): +# state = default_state or {} +# agents = {} +# remaining = {} +# if size: +# for ix in range(size): +# remaining[ix] = copy(state) +# else: +# remaining = defaultdict(lambda x: copy(state)) - distro = sorted([item for item in definition if 'weight' in item]) +# distro = sorted([item for item in definition if 'weight' in item]) - id = 0 +# id = 0 - def init_agent(item, id=ix): - while id in agents: - id += 1 +# def init_agent(item, id=ix): +# while id in agents: +# id += 1 - agent = remaining[id] - agent['state'].update(copy(item.get('state', {}))) - agents[agent.unique_id] = agent - del remaining[id] - return agent +# agent = remaining[id] +# agent['state'].update(copy(item.get('state', {}))) +# agents[agent.unique_id] = agent +# del remaining[id] +# return agent - for item in definition: - if 'ids' in item: - ids = item['ids'] - del item['ids'] - for id in ids: - agent = init_agent(item, id) +# for item in definition: +# if 'ids' in item: +# ids = item['ids'] +# del item['ids'] +# for id in ids: +# agent = init_agent(item, id) - for item in definition: - if 'number' in item: - times = item['number'] - del item['number'] - 
for times in range(times): - if size: - ix = random.choice(remaining.keys()) - agent = init_agent(item, id) - else: - agent = init_agent(item) - if not size: - return agents +# for item in definition: +# if 'number' in item: +# times = item['number'] +# del item['number'] +# for times in range(times): +# if size: +# ix = random.choice(remaining.keys()) +# agent = init_agent(item, id) +# else: +# agent = init_agent(item) +# if not size: +# return agents - if len(remaining) < 0: - raise Exception('Invalid definition. Too many agents to add') +# if len(remaining) < 0: +# raise Exception('Invalid definition. Too many agents to add') - total_weight = float(sum(s['weight'] for s in distro)) - unit = size / total_weight +# total_weight = float(sum(s['weight'] for s in distro)) +# unit = size / total_weight - for item in distro: - times = unit * item['weight'] - del item['weight'] - for times in range(times): - ix = random.choice(remaining.keys()) - agent = init_agent(item, id) - return agents +# for item in distro: +# times = unit * item['weight'] +# del item['weight'] +# for times in range(times): +# ix = random.choice(remaining.keys()) +# agent = init_agent(item, id) +# return agents class AgentView(Mapping, Set): @@ -571,59 +620,43 @@ class AgentView(Mapping, Set): # Mapping methods def __len__(self): - return sum(len(x) for x in self._agents.values()) + return len(self._agents) def __iter__(self): - yield from iter(chain.from_iterable(g.values() for g in self._agents.values())) + yield from self._agents.values() def __getitem__(self, agent_id): if isinstance(agent_id, slice): raise ValueError(f"Slicing is not supported") - for group in self._agents.values(): - if agent_id in group: - return group[agent_id] + if agent_id in self._agents: + return self._agents[agent_id] raise ValueError(f"Agent {agent_id} not found") def filter(self, *args, **kwargs): - yield from filter_groups(self._agents, *args, **kwargs) + yield from filter_agents(self._agents, *args, **kwargs) def one(self, *args, **kwargs): - return next(filter_groups(self._agents, *args, **kwargs)) + return next(filter_agents(self._agents, *args, **kwargs)) def __call__(self, *args, **kwargs): return list(self.filter(*args, **kwargs)) def __contains__(self, agent_id): - return any(agent_id in g for g in self._agents) + return agent_id in self._agents def __str__(self): - return str(list(a.unique_id for a in self)) + return str(list(unique_id for unique_id in self.keys())) def __repr__(self): return f"{self.__class__.__name__}({self})" -def filter_groups(groups, *, group=None, **kwargs): - assert isinstance(groups, dict) - - if group is not None and not isinstance(group, list): - group = [group] - - if group: - groups = list(groups[g] for g in group if g in groups) - else: - groups = list(groups.values()) - - agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups) - - yield from agents - - -def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None, **kwargs): +def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None, + limit=None, **kwargs): ''' Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id). 
''' - assert isinstance(group, dict) + assert isinstance(agents, dict) ids = [] @@ -636,6 +669,11 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non if id_args: ids += id_args + if ids: + f = (agents[aid] for aid in ids if aid in agents) + else: + f = (a for a in agents.values()) + if state_id is not None and not isinstance(state_id, (tuple, list)): state_id = tuple([state_id]) @@ -646,12 +684,6 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non except TypeError: agent_class = tuple([agent_class]) - if ids: - agents = (group[aid] for aid in ids if aid in group) - else: - agents = (a for a in group.values()) - - f = agents if ignore: f = filter(lambda x: x not in ignore, f) @@ -667,83 +699,125 @@ def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=Non for k, v in state.items(): f = filter(lambda agent: agent.state.get(k, None) == v, f) + if limit is not None: + f = islice(f, limit) + yield from f -def from_config(cfg: Dict[str, config.AgentConfig], env, random): +def from_config(cfg: config.AgentConfig, random, topologies: Dict[str, nx.Graph] = None) -> List[Dict[str, Any]]: ''' - Agents are specified in groups. - Each group can be specified in two ways, either through a fixed list in which each item has - has the agent type, number of agents to create, and the other parameters, or through what we call - an `agent distribution`, which is similar but instead of number of agents, it specifies the weight - of each agent type. + This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary + with the parameters that the environment will use to construct each agent. + + This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the + time of calling this function, such as `unique_id`. ''' - default = cfg.get('default', None) - return {k: _group_from_config(c, default=default, env=env, random=random) for (k, c) in cfg.items() if k is not 'default'} + default = cfg or config.AgentConfig() + if not isinstance(cfg, config.AgentConfig): + cfg = config.AgentConfig(**cfg) + return _agents_from_config(cfg, topologies=topologies, random=random) -def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfig, env, random): +def _agents_from_config(cfg: config.AgentConfig, + topologies: Dict[str, nx.Graph], + random) -> List[Dict[str, Any]]: if cfg and not isinstance(cfg, config.AgentConfig): cfg = config.AgentConfig(**cfg) - if default and not isinstance(default, config.SingleAgentConfig): - default = config.SingleAgentConfig(**default) - agents = {} + agents = [] + + assigned = defaultdict(int) + if cfg.fixed is not None: - agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env) - if cfg.distribution: - n = cfg.n or len(env.topologies[cfg.topology or default.topology]) - target = n - len(agents) - agents.update(_from_distro(cfg.distribution, target, - topology=cfg.topology or default.topology, - default=default, - env=env, random=random)) - assert len(agents) == n - if cfg.override: - for attrs in cfg.override: - if attrs.filter: - filtered = list(filter_group(agents, **attrs.filter)) - else: - filtered = list(agents) + agents, counts = _from_fixed(cfg.fixed, topology=cfg.topology, default=cfg) + assigned.update(counts) - if attrs.n > len(filtered): - raise ValueError(f'Not enough agents to sample. 
Got {len(filtered)}, expected >= {attrs.n}') - for agent in random.sample(filtered, attrs.n): - agent.state.update(attrs.state) + n = cfg.n + + if cfg.distribution: + topo_size = {top: len(topologies[top]) for top in topologies} + + grouped = defaultdict(list) + total = [] + + for d in cfg.distribution: + if d.strategy == config.Strategy.topology: + topology = d.topology if ('topology' in d.__fields_set__) else cfg.topology + if not topology: + raise ValueError('The "topology" strategy only works if the topology parameter is specified') + if topology not in topo_size: + raise ValueError(f'Unknown topology selected: { topology }. Make sure the topology has been defined') + + grouped[topology].append(d) + + if d.strategy == config.Strategy.total: + if not cfg.n: + raise ValueError('Cannot use the "total" strategy without providing the total number of agents') + total.append(d) + + + for (topo, distro) in grouped.items(): + if not topologies or topo not in topo_size: + raise ValueError( + 'You need to specify a target number of agents for the distribution \ + or a configuration with a topology, along with a dictionary with \ + all the available topologies') + n = len(topologies[topo]) + target = topo_size[topo] - assigned[topo] + new_agents = _from_distro(cfg.distribution, target, + topology=topo, + default=cfg, + random=random) + assigned[topo] += len(new_agents) + agents += new_agents + + if total: + remaining = n - sum(assigned.values()) + agents += _from_distro(total, remaining, + topology='', # DO NOT assign to any topology + default=cfg, + random=random) + + + if sum(assigned.values()) != sum(topo_size.values()): + utils.logger.warn(f'The total number of agents does not match the total number of nodes in ' + 'every topology. This may be due to a definition error: assigned: ' + f'{ assigned } total sizes: { topo_size }') return agents -def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig, env): - agents = {} +def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig) -> List[Dict[str, Any]]: + agents = [] + + counts = {} for fixed in lst: - agent_id = fixed.agent_id - if agent_id is None: - agent_id = env.next_id() + agent = {} + if default: + agent = default.state.copy() + agent.update(fixed.state) + cls = serialization.deserialize(fixed.agent_class or (default and default.agent_class)) + agent['agent_class'] = cls + topo = fixed.topology if ('topology' in fixed.__fields_set__) else topology or default.topology - cls = serialization.deserialize(fixed.agent_class or default.agent_class) - state = fixed.state.copy() - state.update(default.state) - agent = cls(unique_id=agent_id, - model=env, - **state) - topology = fixed.topology if (fixed.topology is not None) else (topology or default.topology) - if topology: - env.agent_to_node(agent_id, topology, fixed.node_id) - agents[agent.unique_id] = agent + if topo: + agent['topology'] = topo + if not fixed.hidden: + counts[topo] = counts.get(topo, 0) + 1 + agents.append(agent) - return agents + return agents, counts def _from_distro(distro: List[config.AgentDistro], n: int, topology: str, default: config.SingleAgentConfig, - env, - random): + random) -> List[Dict[str, Any]]: - agents = {} + agents = [] if n is None: if any(lambda dist: dist.n is None, distro): @@ -775,19 +849,16 @@ def _from_distro(distro: List[config.AgentDistro], for idx in indices: d = distro[idx] + agent = d.state.copy() cls = classes[idx] - agent_id = env.next_id() - state = 
d.state.copy() + agent['agent_class'] = cls if default: - state.update(default.state) - agent = cls(unique_id=agent_id, model=env, **state) - topology = d.topology if (d.topology is not None) else topology or default.topology + agent.update(default.state) + # agent = cls(unique_id=agent_id, model=env, **state) + topology = d.topology if ('topology' in d.__fields_set__) else topology or default.topology if topology: - env.agent_to_node(agent.unique_id, topology) - assert agent.name is not None - assert agent.name != 'None' - assert agent.name - agents[agent.unique_id] = agent + agent['topology'] = topology + agents.append(agent) return agents diff --git a/soil/analysis.py b/soil/analysis.py deleted file mode 100644 index 65d8468..0000000 --- a/soil/analysis.py +++ /dev/null @@ -1,206 +0,0 @@ -import pandas as pd - -import glob -import yaml -from os.path import join - -from . import serialization -from tsih import History - - -def read_data(*args, group=False, **kwargs): - iterable = _read_data(*args, **kwargs) - if group: - return group_trials(iterable) - else: - return list(iterable) - - -def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs): - if not process_args: - process_args = {} - for folder in glob.glob(pattern): - config_file = glob.glob(join(folder, '*.yml'))[0] - config = yaml.load(open(config_file), Loader=yaml.SafeLoader) - df = None - if from_csv: - for trial_data in sorted(glob.glob(join(folder, - '*.environment.csv'))): - df = read_csv(trial_data, **kwargs) - yield config_file, df, config - else: - for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))): - df = read_sql(trial_data, **kwargs) - yield config_file, df, config - - -def read_sql(db, *args, **kwargs): - h = History(db_path=db, backup=False, readonly=True) - df = h.read_sql(*args, **kwargs) - return df - - -def read_csv(filename, keys=None, convert_types=False, **kwargs): - ''' - Read a CSV in canonical form: :: - - - - ''' - df = pd.read_csv(filename) - if convert_types: - df = convert_types_slow(df) - if keys: - df = df[df['key'].isin(keys)] - df = process_one(df) - return df - - -def convert_row(row): - row['value'] = serialization.deserialize(row['value_type'], row['value']) - return row - - -def convert_types_slow(df): - ''' - Go over every column in a dataframe and convert it to the type determined by the `get_types` - function. - - This is a slow operation. - ''' - dtypes = get_types(df) - for k, v in dtypes.items(): - t = df[df['key']==k] - t['value'] = t['value'].astype(v) - df = df.apply(convert_row, axis=1) - return df - - -def split_processed(df): - env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])] - agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])] - return env, agents - - -def split_df(df): - ''' - Split a dataframe in two dataframes: one with the history of agents, - and one with the environment history - ''' - envmask = (df['agent_id'] == 'env') - n_env = envmask.sum() - if n_env == len(df): - return df, None - elif n_env == 0: - return None, df - agents, env = [x for _, x in df.groupby(envmask)] - return env, agents - - -def process(df, **kwargs): - ''' - Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into - two dataframes with a column per key: one with the history of the agents, and one for the - history of the environment. 
- ''' - env, agents = split_df(df) - return process_one(env, **kwargs), process_one(agents, **kwargs) - - -def get_types(df): - ''' - Get the value type for every key stored in a raw history dataframe. - ''' - dtypes = df.groupby(by=['key'])['value_type'].unique() - return {k:v[0] for k,v in dtypes.iteritems()} - - -def process_one(df, *keys, columns=['key', 'agent_id'], values='value', - fill=True, index=['t_step',], - aggfunc='first', **kwargs): - ''' - Process a dataframe in canonical form ``(t_step, agent_id, key, value, value_type)`` into - a dataframe with a column per key - ''' - if df is None: - return df - if keys: - df = df[df['key'].isin(keys)] - - df = df.pivot_table(values=values, index=index, columns=columns, - aggfunc=aggfunc, **kwargs) - if fill: - df = fillna(df) - return df - - -def get_count(df, *keys): - ''' - For every t_step and key, get the value count. - - The result is a dataframe with `t_step` as index, an a multiindex column based on `key` and the values found for each `key`. - ''' - if keys: - df = df[list(keys)] - df.columns = df.columns.remove_unused_levels() - counts = pd.DataFrame() - for key in df.columns.levels[0]: - g = df[[key]].apply(pd.Series.value_counts, axis=1).fillna(0) - for value, series in g.iteritems(): - counts[key, value] = series - counts.columns = pd.MultiIndex.from_tuples(counts.columns) - return counts - - -def get_majority(df, *keys): - ''' - For every t_step and key, get the value of the majority of agents - - The result is a dataframe with `t_step` as index, and columns based on `key`. - ''' - df = get_count(df, *keys) - return df.stack(level=0).idxmax(axis=1).unstack() - - -def get_value(df, *keys, aggfunc='sum'): - ''' - For every t_step and key, get the value of *numeric columns*, aggregated using a specific function. - ''' - if keys: - df = df[list(keys)] - df.columns = df.columns.remove_unused_levels() - df = df.select_dtypes('number') - return df.groupby(level='key', axis=1).agg(aggfunc) - - -def plot_all(*args, plot_args={}, **kwargs): - ''' - Read all the trial data and plot the result of applying a function on them. - ''' - dfs = do_all(*args, **kwargs) - ps = [] - for line in dfs: - f, df, config = line - if len(df) < 1: - continue - df.plot(title=config['name'], **plot_args) - ps.append(df) - return ps - -def do_all(pattern, func, *keys, include_env=False, **kwargs): - for config_file, df, config in read_data(pattern, keys=keys): - if len(df) < 1: - continue - p = func(df, *keys, **kwargs) - yield config_file, p, config - - -def group_trials(trials, aggfunc=['mean', 'min', 'max', 'std']): - trials = list(trials) - trials = list(map(lambda x: x[1] if isinstance(x, tuple) else x, trials)) - return pd.concat(trials).groupby(level=0).agg(aggfunc).reorder_levels([2, 0,1] ,axis=1) - - -def fillna(df): - new_df = df.ffill(axis=0) - return new_df diff --git a/soil/config.py b/soil/config.py index cf4cee2..20934db 100644 --- a/soil/config.py +++ b/soil/config.py @@ -1,12 +1,18 @@ from __future__ import annotations + +from enum import Enum from pydantic import BaseModel, ValidationError, validator, root_validator import yaml import os import sys + from typing import Any, Callable, Dict, List, Optional, Union, Type from pydantic import BaseModel, Extra + +from . 
import environment, utils + import networkx as nx @@ -36,7 +42,6 @@ class NetParams(BaseModel, extra=Extra.allow): class NetConfig(BaseModel): - group: str = 'network' params: Optional[NetParams] topology: Optional[Union[Topology, nx.Graph]] path: Optional[str] @@ -56,9 +61,6 @@ class NetConfig(BaseModel): class EnvConfig(BaseModel): - environment_class: Union[Type, str] = 'soil.Environment' - params: Dict[str, Any] = {} - schedule: Union[Type, str] = 'soil.time.TimedActivation' @staticmethod def default(): @@ -67,19 +69,19 @@ class EnvConfig(BaseModel): class SingleAgentConfig(BaseModel): agent_class: Optional[Union[Type, str]] = None - agent_id: Optional[int] = None + unique_id: Optional[int] = None topology: Optional[str] = None node_id: Optional[Union[int, str]] = None - name: Optional[str] = None state: Optional[Dict[str, Any]] = {} + class FixedAgentConfig(SingleAgentConfig): n: Optional[int] = 1 + hidden: Optional[bool] = False # Do not count this agent towards total agent count @root_validator def validate_all(cls, values): if values.get('agent_id', None) is not None and values.get('n', 1) > 1: - print(values) raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)") return values @@ -88,13 +90,19 @@ class OverrideAgentConfig(FixedAgentConfig): filter: Optional[Dict[str, Any]] = None +class Strategy(Enum): + topology = 'topology' + total = 'total' + + class AgentDistro(SingleAgentConfig): weight: Optional[float] = 1 + strategy: Strategy = Strategy.topology class AgentConfig(SingleAgentConfig): n: Optional[int] = None - topology: Optional[str] = None + topology: Optional[str] distribution: Optional[List[AgentDistro]] = None fixed: Optional[List[FixedAgentConfig]] = None override: Optional[List[OverrideAgentConfig]] = None @@ -110,19 +118,32 @@ class AgentConfig(SingleAgentConfig): return values -class Config(BaseModel, extra=Extra.forbid): +class Config(BaseModel, extra=Extra.allow): version: Optional[str] = '1' - id: str = 'Unnamed Simulation' + name: str = 'Unnamed Simulation' + description: Optional[str] = None group: str = None dir_path: Optional[str] = None num_trials: int = 1 max_time: float = 100 + max_steps: int = -1 interval: float = 1 seed: str = "" + dry_run: bool = False - model_class: Union[Type, str] - model_parameters: Optiona[Dict[str, Any]] = {} + model_class: Union[Type, str] = environment.Environment + model_params: Optional[Dict[str, Any]] = {} + + visualization_params: Optional[Dict[str, Any]] = {} + + @classmethod + def from_raw(cls, cfg): + if isinstance(cfg, Config): + return cfg + if cfg.get('version', '1') == '1' and any(k in cfg for k in ['agents', 'agent_class', 'topology', 'environment_class']): + return convert_old(cfg) + return Config(**cfg) def convert_old(old, strict=True): @@ -132,87 +153,84 @@ def convert_old(old, strict=True): This is still a work in progress and might not work in many cases. ''' - #TODO: implement actual conversion - print('The old configuration format is no longer supported. \ - Update your config files or run Soil==0.20') - raise NotImplementedError() + utils.logger.warning('The old configuration format is deprecated. 
The converted file MAY NOT yield the right results') - - new = {} - - general = {} - for k in ['id', - 'group', - 'dir_path', - 'num_trials', - 'max_time', - 'interval', - 'seed']: - if k in old: - general[k] = old[k] - - if 'name' in old: - general['id'] = old['name'] + new = old.copy() network = {} + if 'topology' in old: + del new['topology'] + network['topology'] = old['topology'] if 'network_params' in old and old['network_params']: + del new['network_params'] for (k, v) in old['network_params'].items(): if k == 'path': network['path'] = v else: network.setdefault('params', {})[k] = v - if 'topology' in old: - network['topology'] = old['topology'] + topologies = {} + if network: + topologies['default'] = network - agents = { - 'network': {}, - 'default': {}, - } - - if 'agent_class' in old: - agents['default']['agent_class'] = old['agent_class'] - - if 'default_state' in old: - agents['default']['state'] = old['default_state'] + agents = {'fixed': [], 'distribution': []} def updated_agent(agent): + '''Convert an agent definition''' newagent = dict(agent) - newagent['agent_class'] = newagent['agent_class'] - del newagent['agent_class'] return newagent - for agent in old.get('environment_agents', []): - agents['environment'] = {'distribution': [], 'fixed': []} - if 'agent_id' in agent: - agent['name'] = agent['agent_id'] - del agent['agent_id'] - agents['environment']['fixed'].append(updated_agent(agent)) - by_weight = [] fixed = [] override = [] - if 'network_agents' in old: - agents['network']['topology'] = 'default' + if 'environment_agents' in new: - for agent in old['network_agents']: + for agent in new['environment_agents']: + agent.setdefault('state', {})['group'] = 'environment' + if 'agent_id' in agent: + agent['state']['name'] = agent['agent_id'] + del agent['agent_id'] + agent['hidden'] = True + agent['topology'] = None + fixed.append(updated_agent(agent)) + del new['environment_agents'] + + + if 'agent_class' in old: + del new['agent_class'] + agents['agent_class'] = old['agent_class'] + + if 'default_state' in old: + del new['default_state'] + agents['state'] = old['default_state'] + + if 'network_agents' in old: + agents['topology'] = 'default' + + agents.setdefault('state', {})['group'] = 'network' + + for agent in new['network_agents']: agent = updated_agent(agent) if 'agent_id' in agent: + agent['state']['name'] = agent['agent_id'] + del agent['agent_id'] fixed.append(agent) else: by_weight.append(agent) + del new['network_agents'] if 'agent_class' in old and (not fixed and not by_weight): - agents['network']['topology'] = 'default' - by_weight = [{'agent_class': old['agent_class']}] + agents['topology'] = 'default' + by_weight = [{'agent_class': old['agent_class'], 'weight': 1}] # TODO: translate states properly if 'states' in old: + del new['states'] states = old['states'] if isinstance(states, dict): states = states.items() @@ -220,22 +238,29 @@ def convert_old(old, strict=True): states = enumerate(states) for (k, v) in states: override.append({'filter': {'node_id': k}, - 'state': v - }) + 'state': v}) - agents['network']['override'] = override - agents['network']['fixed'] = fixed - agents['network']['distribution'] = by_weight + agents['override'] = override + agents['fixed'] = fixed + agents['distribution'] = by_weight + + + model_params = {} + if 'environment_params' in new: + del new['environment_params'] + model_params = dict(old['environment_params']) - environment = {'params': {}} if 'environment_class' in old: - environment['environment_class'] = 
old['environment_class'] + del new['environment_class'] + new['model_class'] = old['environment_class'] - for (k, v) in old.get('environment_params', {}).items(): - environment['params'][k] = v + if 'dump' in old: + del new['dump'] + new['dry_run'] = not old['dump'] + + model_params['topologies'] = topologies + model_params['agents'] = agents return Config(version='2', - general=general, - topologies={'default': network}, - environment=environment, - agents=agents) + model_params=model_params, + **new) diff --git a/soil/debugging.py b/soil/debugging.py new file mode 100644 index 0000000..98c25e1 --- /dev/null +++ b/soil/debugging.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import pdb +import sys +import os + +from textwrap import indent +from functools import wraps + +from .agents import FSM, MetaFSM + + +def wrapcmd(func): + @wraps(func) + def wrapper(self, arg: str, temporary=False): + sys.settrace(self.trace_dispatch) + + known = globals() + known.update(self.curframe.f_globals) + known.update(self.curframe.f_locals) + known['agent'] = known.get('self', None) + known['model'] = known.get('self', {}).get('model') + known['attrs'] = arg.strip().split() + + exec(func.__code__, known, known) + + return wrapper + + +class Debug(pdb.Pdb): + def __init__(self, *args, skip_soil=False, **kwargs): + skip = kwargs.get('skip', []) + if skip_soil: + skip.append('soil.*') + skip.append('mesa.*') + super(Debug, self).__init__(*args, skip=skip, **kwargs) + self.prompt = "[soil-pdb] " + + @staticmethod + def _soil_agents(model, attrs=None, pretty=True, **kwargs): + for agent in model.agents(**kwargs): + d = agent + print(' - ' + indent(agent.to_str(keys=attrs, pretty=pretty), ' ')) + + @wrapcmd + def do_soil_agents(): + return Debug._soil_agents(model, attrs=attrs or None) + + do_sa = do_soil_agents + + @wrapcmd + def do_soil_list(): + return Debug._soil_agents(model, attrs=['state_id'], pretty=False) + + do_sl = do_soil_list + + @wrapcmd + def do_soil_self(): + if not agent: + print('No agent available') + return + + keys = None + if attrs: + keys = [] + for k in attrs: + for key in agent.keys(): + if key.startswith(k): + keys.append(key) + + print(agent.to_str(pretty=True, keys=keys)) + + do_ss = do_soil_self + + def do_break_state(self, arg: str, temporary=False): + ''' + Break before a specified state is stepped into. 
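A sketch of how this debugger might be driven (the command names come from this file; the surrounding script and the attribute names passed to them are assumptions)::

    import soil.debugging as debugging

    debugging.setup()        # drop into the soil-aware pdb at this point
    # debugging.debug_env() does the same, but only when the SOIL_DEBUG environment variable is set.
    # At the [soil-pdb] prompt the extra commands defined above become available, e.g.:
    #   sl                   list every agent together with its state_id
    #   sa state_id times    print the given attributes for every agent
    #   ss                   dump the attributes of the agent bound to `self` in the current frame
    #   bs MyFSM:running     break when the `running` state of MyFSM is about to be stepped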
+ ''' + + klass = None + state = arg.strip() + if not state: + self.error("Specify at least a state name") + return + + comma = arg.find(':') + if comma > 0: + state = arg[comma+1:].lstrip() + klass = arg[:comma].rstrip() + klass = eval(klass, + self.curframe.f_globals, + self.curframe_locals) + + if klass: + klasses = [klass] + else: + klasses = [k for k in self.curframe.f_globals.values() if isinstance(k, type) and issubclass(k, FSM)] + print(klasses) + if not klasses: + self.error('No agent classes found') + + for klass in klasses: + try: + func = getattr(klass, state) + except AttributeError: + continue + if hasattr(func, '__func__'): + func = func.__func__ + + code = func.__code__ + #use co_name to identify the bkpt (function names + #could be aliased, but co_name is invariant) + funcname = code.co_name + lineno = code.co_firstlineno + filename = code.co_filename + + # Check for reasonable breakpoint + line = self.checkline(filename, lineno) + if not line: + raise ValueError('no line found') + # now set the break point + cond = None + existing = self.get_breaks(filename, line) + if existing: + self.message("Breakpoint already exists at %s:%d" % + (filename, line)) + continue + err = self.set_break(filename, line, temporary, cond, funcname) + if err: + self.error(err) + else: + bp = self.get_breaks(filename, line)[-1] + self.message("Breakpoint %d at %s:%d" % + (bp.number, bp.file, bp.line)) + do_bs = do_break_state + + +def setup(frame=None): + debugger = Debug() + frame = frame or sys._getframe().f_back + debugger.set_trace(frame) + +def debug_env(): + if os.environ.get('SOIL_DEBUG'): + return setup(frame=sys._getframe().f_back) + +def post_mortem(traceback=None): + p = Debug() + t = sys.exc_info()[2] + p.reset() + p.interaction(None, t) diff --git a/soil/environment.py b/soil/environment.py index 2f59553..303a00f 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -1,4 +1,5 @@ from __future__ import annotations + import os import sqlite3 import math @@ -17,9 +18,7 @@ import networkx as nx from mesa import Model from mesa.datacollection import DataCollector -from . import serialization, analysis, utils, time, network - -from .agents import AgentView, BaseAgent, NetworkAgent, from_config as agents_from_config +from . 
import agents as agentmod, config, serialization, utils, time, network Record = namedtuple('Record', 'dict_id t_step key value') @@ -39,12 +38,12 @@ class BaseEnvironment(Model): """ def __init__(self, - env_id='unnamed_env', + id='unnamed_env', seed='default', schedule=None, dir_path=None, interval=1, - agent_class=BaseAgent, + agent_class=None, agents: [tuple[type, Dict[str, Any]]] = {}, agent_reporters: Optional[Any] = None, model_reporters: Optional[Any] = None, @@ -54,7 +53,7 @@ class BaseEnvironment(Model): super().__init__(seed=seed) self.current_id = -1 - self.id = env_id + self.id = id self.dir_path = dir_path or os.getcwd() @@ -62,7 +61,7 @@ class BaseEnvironment(Model): schedule = time.TimedActivation(self) self.schedule = schedule - self.agent_class = agent_class + self.agent_class = agent_class or agentmod.BaseAgent self.init_agents(agents) @@ -78,25 +77,51 @@ class BaseEnvironment(Model): tables=tables, ) - def __read_agent_tuple(self, tup): - cls = self.agent_class - args = tup - if isinstance(tup, tuple): - cls = tup[0] - args = tup[1] - return serialization.deserialize(cls)(unique_id=self.next_id(), - model=self, **args) + def _read_single_agent(self, agent): + agent = dict(**agent) + cls = agent.pop('agent_class', None) or self.agent_class + unique_id = agent.pop('unique_id', None) + if unique_id is None: + unique_id = self.next_id() + + return serialization.deserialize(cls)(unique_id=unique_id, + model=self, **agent) + + def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}): + if not agents: + return + + lst = agents + override = [] + if not isinstance(lst, list): + if not isinstance(agents, config.AgentConfig): + lst = config.AgentConfig(**agents) + if lst.override: + override = lst.override + lst = agentmod.from_config(lst, + topologies=getattr(self, 'topologies', None), + random=self.random) + + #TODO: check override is working again. 
It cannot (easily) be part of agents.from_config anymore, + # because it needs attribute such as unique_id, which are only present after init + new_agents = [self._read_single_agent(agent) for agent in lst] + + + for a in new_agents: + self.schedule.add(a) + + for rule in override: + for agent in agentmod.filter_agents(self.schedule._agents, **rule.filter): + for attr, value in rule.state.items(): + setattr(agent, attr, value) - def init_agents(self, agents: [tuple[type, Dict[str, Any]]] = {}): - agents = [self.__read_agent_tuple(tup) for tup in agents] - self._agents = {'default': {agent.id: agent for agent in agents}} @property def agents(self): - return AgentView(self._agents) + return agentmod.AgentView(self.schedule._agents) def find_one(self, *args, **kwargs): - return AgentView(self._agents).one(*args, **kwargs) + return agentmod.AgentView(self.schedule._agents).one(*args, **kwargs) def count_agents(self, *args, **kwargs): return sum(1 for i in self.agents(*args, **kwargs)) @@ -108,38 +133,12 @@ class BaseEnvironment(Model): raise Exception('The environment has not been scheduled, so it has no sense of time') - # def init_agent(self, agent_id, agent_definitions, state=None): - # state = state or {} - - # agent_class = None - # if 'agent_class' in self.states.get(agent_id, {}): - # agent_class = self.states[agent_id]['agent_class'] - # elif 'agent_class' in self.default_state: - # agent_class = self.default_state['agent_class'] - - # if agent_class: - # agent_class = agents.deserialize_type(agent_class) - # elif agent_definitions: - # agent_class, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id) - # else: - # serialization.logger.debug('Skipping agent {}'.format(agent_id)) - # return - # return self.add_agent(agent_id, agent_class, state) - - - def add_agent(self, agent_id, agent_class, state=None, graph='default'): - defstate = deepcopy(self.default_state) or {} - defstate.update(self.states.get(agent_id, {})) - if state: - defstate.update(state) + def add_agent(self, agent_id, agent_class, **kwargs): a = None if agent_class: - state = defstate a = agent_class(model=self, - unique_id=agent_id) - - for (k, v) in state.items(): - setattr(a, k, v) + unique_id=agent_id, + **kwargs) self.schedule.add(a) return a @@ -153,7 +152,7 @@ class BaseEnvironment(Model): message += " {k}={v} ".format(k, v) extra = {} extra['now'] = self.now - extra['unique_id'] = self.id + extra['id'] = self.id return self.logger.log(level, message, extra=extra) def step(self): @@ -161,6 +160,7 @@ class BaseEnvironment(Model): Advance one step in the simulation, and update the data collection and scheduler appropriately ''' super().step() + self.logger.info(f'--- Step {self.now:^5} ---') self.schedule.step() self.datacollector.collect(self) @@ -207,34 +207,41 @@ class BaseEnvironment(Model): yield from self._agent_to_tuples(agent, now) -class AgentConfigEnvironment(BaseEnvironment): +class NetworkEnvironment(BaseEnvironment): - def __init__(self, *args, - agents: Dict[str, config.AgentConfig] = {}, - **kwargs): - return super().__init__(*args, agents=agents, **kwargs) - - def init_agents(self, agents: Union[Dict[str, config.AgentConfig], [tuple[type, Dict[str, Any]]]] = {}): - if not isinstance(agents, dict): - return BaseEnvironment.init_agents(self, agents) - - self._agents = agents_from_config(agents, - env=self, - random=self.random) - for d in self._agents.values(): - for a in d.values(): - self.schedule.add(a) - - -class NetworkConfigEnvironment(BaseEnvironment): - - def 
__init__(self, *args, topologies: Dict[str, config.NetConfig] = {}, **kwargs): - super().__init__(*args, **kwargs) - self.topologies = {} + def __init__(self, *args, topology: nx.Graph = None, topologies: Dict[str, config.NetConfig] = {}, **kwargs): + agents = kwargs.pop('agents', None) + super().__init__(*args, agents=None, **kwargs) self._node_ids = {} + assert not hasattr(self, 'topologies') + if topology is not None: + if topologies: + raise ValueError('Please, provide either a single topology or a dictionary of them') + topologies = {'default': topology} + + self.topologies = {} for (name, cfg) in topologies.items(): self.set_topology(cfg=cfg, graph=name) + self.init_agents(agents) + + + def _read_single_agent(self, agent, unique_id=None): + agent = dict(agent) + + if agent.get('topology', None) is not None: + topology = agent.get('topology') + if unique_id is None: + unique_id = self.next_id() + if topology: + node_id = self.agent_to_node(unique_id, graph_name=topology, node_id=agent.get('node_id')) + agent['node_id'] = node_id + agent['topology'] = topology + agent['unique_id'] = unique_id + + return super()._read_single_agent(agent) + + @property def topology(self): return self.topologies['default'] @@ -246,51 +253,50 @@ class NetworkConfigEnvironment(BaseEnvironment): self.topologies[graph] = topology - def topology_for(self, agent_id): - return self.topologies[self._node_ids[agent_id][0]] + def topology_for(self, unique_id): + return self.topologies[self._node_ids[unique_id][0]] @property def network_agents(self): - yield from self.agents(agent_class=NetworkAgent) + yield from self.agents(agent_class=agentmod.NetworkAgent) - def agent_to_node(self, agent_id, graph_name='default', node_id=None, shuffle=False): - node_id = network.agent_to_node(G=self.topologies[graph_name], agent_id=agent_id, - node_id=node_id, shuffle=shuffle, + def agent_to_node(self, unique_id, graph_name='default', + node_id=None, shuffle=False): + node_id = network.agent_to_node(G=self.topologies[graph_name], + agent_id=unique_id, + node_id=node_id, + shuffle=shuffle, random=self.random) - self._node_ids[agent_id] = (graph_name, node_id) + self._node_ids[unique_id] = (graph_name, node_id) + return node_id + def add_node(self, agent_class, topology, **kwargs): + unique_id = self.next_id() + self.topologies[topology].add_node(unique_id) + node_id = self.agent_to_node(unique_id=unique_id, node_id=unique_id, graph_name=topology) - def add_node(self, agent_class, state=None, graph='default'): - agent_id = int(len(self.topologies[graph].nodes())) - self.topologies[graph].add_node(agent_id) - a = self.add_agent(agent_id, agent_class, state, graph=graph) + a = self.add_agent(unique_id=unique_id, agent_class=agent_class, node_id=node_id, topology=topology, **kwargs) a['visible'] = True return a def add_edge(self, agent1, agent2, start=None, graph='default', **attrs): - if hasattr(agent1, 'id'): - agent1 = agent1.id - if hasattr(agent2, 'id'): - agent2 = agent2.id - start = start or self.now - return self.topologies[graph].add_edge(agent1, agent2, **attrs) + agent1 = agent1.node_id + agent2 = agent2.node_id + return self.topologies[graph].add_edge(agent1, agent2, start=start) - def add_agent(self, *args, state=None, graph='default', **kwargs): - node = self.topologies[graph].nodes[agent_id] + def add_agent(self, unique_id, state=None, graph='default', **kwargs): + node = self.topologies[graph].nodes[unique_id] node_state = node.get('state', {}) if node_state: node_state.update(state or {}) state = node_state - a = 
super().add_agent(*args, state=state, **kwargs) + a = super().add_agent(unique_id, state=state, **kwargs) node['agent'] = a return a def node_id_for(self, agent_id): return self._node_ids[agent_id][1] -class Environment(AgentConfigEnvironment, NetworkConfigEnvironment): - def __init__(self, *args, **kwargs): - agents = kwargs.pop('agents', {}) - NetworkConfigEnvironment.__init__(self, *args, **kwargs) - AgentConfigEnvironment.__init__(self, *args, agents=agents, **kwargs) + +Environment = NetworkEnvironment diff --git a/soil/exporters.py b/soil/exporters.py index 20a0f92..055afd4 100644 --- a/soil/exporters.py +++ b/soil/exporters.py @@ -12,7 +12,7 @@ from .serialization import deserialize from .utils import open_or_reuse, logger, timer -from . import utils +from . import utils, network class DryRunner(BytesIO): @@ -85,38 +85,28 @@ class Exporter: class default(Exporter): '''Default exporter. Writes sqlite results, as well as the simulation YAML''' - # def sim_start(self): - # if not self.dry_run: - # logger.info('Dumping results to %s', self.outdir) - # self.simulation.dump_yaml(outdir=self.outdir) - # else: - # logger.info('NOT dumping results') + def sim_start(self): + if not self.dry_run: + logger.info('Dumping results to %s', self.outdir) + with self.output(self.simulation.name + '.dumped.yml') as f: + f.write(self.simulation.to_yaml()) + else: + logger.info('NOT dumping results') - # def trial_start(self, env, stats): - # if not self.dry_run: - # with timer('Dumping simulation {} trial {}'.format(self.simulation.name, - # env.name)): - # engine = create_engine('sqlite:///{}.sqlite'.format(env.name), echo=False) + def trial_end(self, env): + if not self.dry_run: + with timer('Dumping simulation {} trial {}'.format(self.simulation.name, + env.id)): + engine = create_engine('sqlite:///{}.sqlite'.format(env.id), echo=False) - # dc = env.datacollector - # tables = {'env': dc.get_model_vars_dataframe(), - # 'agents': dc.get_agent_vars_dataframe(), - # 'agents': dc.get_agent_vars_dataframe()} - # for table in dc.tables: - # tables[table] = dc.get_table_dataframe(table) - # for (t, df) in tables.items(): - # df.to_sql(t, con=engine) - - # def sim_end(self, stats): - # with timer('Dumping simulation {}\'s stats'.format(self.simulation.name)): - # engine = create_engine('sqlite:///{}.sqlite'.format(self.simulation.name), echo=False) - # with self.output('{}.sqlite'.format(self.simulation.name), mode='wb') as f: - # self.simulation.dump_sqlite(f) + dc = env.datacollector + for (t, df) in get_dc_dfs(dc): + df.to_sql(t, con=engine, if_exists='append') def get_dc_dfs(dc): dfs = {'env': dc.get_model_vars_dataframe(), - 'agents': dc.get_agent_vars_dataframe } + 'agents': dc.get_agent_vars_dataframe() } for table_name in dc.tables: dfs[table_name] = dc.get_table_dataframe(table_name) yield from dfs.items() @@ -130,10 +120,11 @@ class csv(Exporter): env.id, self.outdir)): for (df_name, df) in get_dc_dfs(env.datacollector): - with self.output('{}.stats.{}.csv'.format(env.id, df_name)) as f: + with self.output('{}.{}.csv'.format(env.id, df_name)) as f: df.to_csv(f) +#TODO: reimplement GEXF exporting without history class gexf(Exporter): def trial_end(self, env): if self.dry_run: @@ -143,18 +134,9 @@ class gexf(Exporter): with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name, env.id)): with self.output('{}.gexf'.format(env.id), mode='wb') as f: + network.dump_gexf(env.history_to_graph(), f) self.dump_gexf(env, f) - def dump_gexf(self, env, f): - G = env.history_to_graph() - # 
Workaround for geometric models - # See soil/soil#4 - for node in G.nodes(): - if 'pos' in G.nodes[node]: - G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}} - del (G.nodes[node]['pos']) - - nx.write_gexf(G, f, version="1.2draft") class dummy(Exporter): diff --git a/soil/network.py b/soil/network.py index 25b55ab..0836f35 100644 --- a/soil/network.py +++ b/soil/network.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Dict import os import sys @@ -37,8 +39,10 @@ def from_config(cfg: config.NetConfig, dir_path: str = None): known_modules=['networkx.generators',]) return method(**net_args) - if isinstance(cfg.topology, basestring) or isinstance(cfg.topology, dict): - return nx.json_graph.node_link_graph(cfg.topology) + if isinstance(cfg.topology, config.Topology): + cfg = cfg.topology.dict() + if isinstance(cfg, str) or isinstance(cfg, dict): + return nx.json_graph.node_link_graph(cfg) return nx.Graph() @@ -57,9 +61,18 @@ def agent_to_node(G, agent_id, node_id=None, shuffle=False, random=random): for next_id, data in candidates: if data.get('agent_id', None) is None: node_id = next_id - data['agent_id'] = agent_id break if node_id is None: raise ValueError(f"Not enough nodes in topology to assign one to agent {agent_id}") + G.nodes[node_id]['agent_id'] = agent_id return node_id + + +def dump_gexf(G, f): + for node in G.nodes(): + if 'pos' in G.nodes[node]: + G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}} + del (G.nodes[node]['pos']) + + nx.write_gexf(G, f, version="1.2draft") diff --git a/soil/serialization.py b/soil/serialization.py index 328efdd..9c2af63 100644 --- a/soil/serialization.py +++ b/soil/serialization.py @@ -7,6 +7,8 @@ import importlib from glob import glob from itertools import product, chain +from .config import Config + import yaml import networkx as nx @@ -120,22 +122,25 @@ def params_for_template(config): def load_files(*patterns, **kwargs): for pattern in patterns: for i in glob(pattern, **kwargs): - for config in load_file(i): + for cfg in load_file(i): path = os.path.abspath(i) - yield config, path + yield Config.from_raw(cfg), path -def load_config(config): - if isinstance(config, dict): - yield config, os.getcwd() +def load_config(cfg): + if isinstance(cfg, Config): + yield cfg, os.getcwd() + elif isinstance(cfg, dict): + yield Config.from_raw(cfg), os.getcwd() else: - yield from load_files(config) + yield from load_files(cfg) builtins = importlib.import_module('builtins') KNOWN_MODULES = ['soil', ] + def name(value, known_modules=KNOWN_MODULES): '''Return a name that can be imported, to serialize/deserialize an object''' if value is None: @@ -172,8 +177,22 @@ def serialize(v, known_modules=KNOWN_MODULES): return func(v), tname +def serialize_dict(d, known_modules=KNOWN_MODULES): + d = dict(d) + for (k, v) in d.items(): + if isinstance(v, dict): + d[k] = serialize_dict(v, known_modules=known_modules) + elif isinstance(v, list): + for ix in range(len(v)): + v[ix] = serialize_dict(v[ix], known_modules=known_modules) + elif isinstance(v, type): + d[k] = serialize(v, known_modules=known_modules)[1] + return d + + IS_CLASS = re.compile(r"") + def deserializer(type_, known_modules=KNOWN_MODULES): if type(type_) != str: # Already deserialized return type_ diff --git a/soil/simulation.py b/soil/simulation.py index 4d50c30..1ed5dbc 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -4,15 +4,17 @@ import importlib import sys 
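As a usage sketch of the reworked loading and running entry points defined further down in this file (the YAML path is only an example, assumed relative to the repository root; version-1 files are converted on the fly by Config.from_raw)::

    from soil import simulation

    for sim in simulation.iter_from_config('examples/complete.yml'):
        for env in sim.run(dry_run=True):
            print(env.id, env.schedule.steps)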
import yaml import traceback +import inspect import logging import networkx as nx +from textwrap import dedent + from dataclasses import dataclass, field, asdict -from typing import Union +from typing import Any, Dict, Union, Optional from networkx.readwrite import json_graph -from multiprocessing import Pool from functools import partial import pickle @@ -21,7 +23,6 @@ from .environment import Environment from .utils import logger, run_and_return_exceptions from .exporters import default from .time import INFINITY - from .config import Config, convert_old @@ -36,7 +37,9 @@ class Simulation: kwargs: parameters to use to initialize a new configuration, if one has not been provided. """ + version: str = '2' name: str = 'Unnamed simulation' + description: Optional[str] = '' group: str = None model_class: Union[str, type] = 'soil.Environment' model_params: dict = field(default_factory=dict) @@ -44,30 +47,37 @@ class Simulation: dir_path: str = field(default_factory=lambda: os.getcwd()) max_time: float = float('inf') max_steps: int = -1 + interval: int = 1 num_trials: int = 3 dry_run: bool = False + extra: Dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, env): + + ignored = {k: v for k, v in env.items() + if k not in inspect.signature(cls).parameters} + + kwargs = {k:v for k, v in env.items() if k not in ignored} + if ignored: + kwargs.setdefault('extra', {}).update(ignored) + if ignored: + print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }') + + return cls(**kwargs) def run_simulation(self, *args, **kwargs): return self.run(*args, **kwargs) def run(self, *args, **kwargs): '''Run the simulation and return the list of resulting environments''' + logger.info(dedent(''' + Simulation: + --- + ''') + + self.to_yaml()) return list(self.run_gen(*args, **kwargs)) - def _run_sync_or_async(self, parallel=False, **kwargs): - if parallel and not os.environ.get('SENPY_DEBUG', None): - p = Pool() - func = partial(run_and_return_exceptions, self.run_trial, **kwargs) - for i in p.imap_unordered(func, self.num_trials): - if isinstance(i, Exception): - logger.error('Trial failed:\n\t%s', i.message) - continue - yield i - else: - for i in range(self.num_trials): - yield self.run_trial(trial_id=i, - **kwargs) - def run_gen(self, parallel=False, dry_run=False, exporters=[default, ], outdir=None, exporter_params={}, log_level=None, @@ -88,9 +98,11 @@ class Simulation: for exporter in exporters: exporter.sim_start() - for env in self._run_sync_or_async(parallel=parallel, - log_level=log_level, - **kwargs): + for env in utils.run_parallel(func=self.run_trial, + iterable=range(int(self.num_trials)), + parallel=parallel, + log_level=log_level, + **kwargs): for exporter in exporters: exporter.trial_start(env) @@ -103,14 +115,6 @@ class Simulation: for exporter in exporters: exporter.sim_end() - def run_model(self, until=None, *args, **kwargs): - until = until or float('inf') - - while self.schedule.next_time < until: - self.step() - utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. 
Next: {self.schedule.next_time}') - self.schedule.time = until - def get_env(self, trial_id=0, **kwargs): '''Create an environment for a trial of the simulation''' def deserialize_reporters(reporters): @@ -132,56 +136,76 @@ class Simulation: model_reporters=model_reporters, **model_params) - def run_trial(self, trial_id=None, until=None, log_level=logging.INFO, **opts): + def run_trial(self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts): """ Run a single trial of the simulation """ - model = self.get_env(trial_id, **opts) - return self.run_model(model, trial_id=trial_id, until=until, log_level=log_level) - - def run_model(self, model, trial_id=None, until=None, log_level=logging.INFO, **opts): - trial_id = trial_id if trial_id is not None else current_time() if log_level: logger.setLevel(log_level) + model = self.get_env(trial_id, **opts) + trial_id = trial_id if trial_id is not None else current_time() + with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)): + return self.run_model(model=model, trial_id=trial_id, until=until, log_level=log_level) + + def run_model(self, model, until=None, **opts): # Set-up trial environment and graph - until = until or self.max_time + until = float(until or self.max_time or 'inf') # Set up agents on nodes - is_done = lambda: False - if self.max_time and hasattr(self.schedule, 'time'): - is_done = lambda x: is_done() or self.schedule.time >= self.max_time - if self.max_steps and hasattr(self.schedule, 'time'): - is_done = lambda: is_done() or self.schedule.steps >= self.max_steps + def is_done(): + return False - with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)): - while not is_done(): - utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}') - model.step() + if until and hasattr(model.schedule, 'time'): + prev = is_done + + def is_done(): + return prev() or model.schedule.time >= until + + if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, 'steps'): + prev_steps = is_done + + def is_done(): + return prev_steps() or model.schedule.steps >= self.max_steps + + newline = '\n' + logger.info(dedent(f''' +Model stats: + Agents (total: { model.schedule.get_agent_count() }): + - { (newline + ' - ').join(str(a) for a in model.schedule.agents) }''' +f''' + + Topologies (size): + - { dict( (k, len(v)) for (k, v) in model.topologies.items()) } +''' if getattr(model, "topologies", None) else '' +)) + + while not is_done(): + utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. 
Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}') + model.step() return model def to_dict(self): d = asdict(self) - d['model_class'] = serialization.serialize(d['model_class'])[0] - d['model_params'] = serialization.serialize(d['model_params'])[0] + if not isinstance(d['model_class'], str): + d['model_class'] = serialization.name(d['model_class']) + d['model_params'] = serialization.serialize_dict(d['model_params']) d['dir_path'] = str(d['dir_path']) - + d['version'] = '2' return d def to_yaml(self): - return yaml.dump(self.asdict()) + return yaml.dump(self.to_dict()) -def iter_from_config(config): - configs = list(serialization.load_config(config)) - for config, path in configs: - d = dict(config) - if 'dir_path' not in d: - d['dir_path'] = os.path.dirname(path) - if d.get('version', '2') == '1' or 'agents' in d or 'network_agents' in d or 'environment_agents' in d: - d = convert_old(d) - d.pop('version', None) - yield Simulation(**d) +def iter_from_config(*cfgs): + for config in cfgs: + configs = list(serialization.load_config(config)) + for config, path in configs: + d = dict(config) + if 'dir_path' not in d: + d['dir_path'] = os.path.dirname(path) + yield Simulation.from_dict(d) def from_config(conf_or_path): @@ -192,6 +216,6 @@ def from_config(conf_or_path): def run_from_config(*configs, **kwargs): - for sim in iter_from_config(configs): - logger.info(f"Using config(s): {sim.id}") + for sim in iter_from_config(*configs): + logger.info(f"Using config(s): {sim.name}") sim.run_simulation(**kwargs) diff --git a/soil/time.py b/soil/time.py index b2faf46..b95c51e 100644 --- a/soil/time.py +++ b/soil/time.py @@ -1,6 +1,6 @@ from mesa.time import BaseScheduler from queue import Empty -from heapq import heappush, heappop +from heapq import heappush, heappop, heapify import math from .utils import logger from mesa import Agent as MesaAgent @@ -17,6 +17,7 @@ class When: def abs(self, time): return self._time + NEVER = When(INFINITY) @@ -38,14 +39,22 @@ class TimedActivation(BaseScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._next = {} self._queue = [] self.next_time = 0 self.logger = logger.getChild(f'time_{ self.model }') - def add(self, agent: MesaAgent): - if agent.unique_id not in self._agents: - heappush(self._queue, (self.time, agent.unique_id)) - super().add(agent) + def add(self, agent: MesaAgent, when=None): + if when is None: + when = self.time + if agent.unique_id in self._agents: + self._queue.remove((self._next[agent.unique_id], agent.unique_id)) + del self._agents[agent.unique_id] + heapify(self._queue) + + heappush(self._queue, (when, agent.unique_id)) + self._next[agent.unique_id] = when + super().add(agent) def step(self) -> None: """ @@ -64,11 +73,18 @@ class TimedActivation(BaseScheduler): (when, agent_id) = heappop(self._queue) self.logger.debug(f'Stepping agent {agent_id}') - returned = self._agents[agent_id].step() + agent = self._agents[agent_id] + returned = agent.step() + + if not agent.alive: + self.remove(agent) + continue + when = (returned or Delta(1)).abs(self.time) if when < self.time: raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time)) + self._next[agent_id] = when heappush(self._queue, (when, agent_id)) self.steps += 1 @@ -77,7 +93,7 @@ class TimedActivation(BaseScheduler): self.time = INFINITY self.next_time = INFINITY self.model.running = False - return + return self.time self.next_time = self._queue[0][0] self.logger.debug(f'Next step: 
{self.next_time}') diff --git a/soil/utils.py b/soil/utils.py index cd82588..faa34d1 100644 --- a/soil/utils.py +++ b/soil/utils.py @@ -3,13 +3,27 @@ from time import time as current_time, strftime, gmtime, localtime import os import traceback +from functools import partial from shutil import copyfile +from multiprocessing import Pool from contextlib import contextmanager logger = logging.getLogger('soil') -# logging.basicConfig() -# logger.setLevel(logging.INFO) +logger.setLevel(logging.INFO) + +timeformat = "%H:%M:%S" + +if os.environ.get('SOIL_VERBOSE', ''): + logformat = "[%(levelname)-5.5s][%(asctime)s][%(name)s]: %(message)s" +else: + logformat = "[%(levelname)-5.5s][%(asctime)s] %(message)s" + +logFormatter = logging.Formatter(logformat, timeformat) + +consoleHandler = logging.StreamHandler() +consoleHandler.setFormatter(logFormatter) +logger.addHandler(consoleHandler) @contextmanager @@ -27,8 +41,6 @@ def timer(name='task', pre="", function=logger.info, to_object=None): to_object.end = end - - def safe_open(path, mode='r', backup=True, **kwargs): outdir = os.path.dirname(path) if outdir and not os.path.exists(outdir): @@ -41,7 +53,7 @@ def safe_open(path, mode='r', backup=True, **kwargs): if not os.path.exists(backup_dir): os.makedirs(backup_dir) newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path), - stamp)) + stamp)) copyfile(path, newpath) return open(path, mode=mode, **kwargs) @@ -92,7 +104,7 @@ def unflatten_dict(d): return out -def run_and_return_exceptions(self, func, *args, **kwargs): +def run_and_return_exceptions(func, *args, **kwargs): ''' A wrapper for run_trial that catches exceptions and returns them. It is meant for async simulations. @@ -104,3 +116,18 @@ def run_and_return_exceptions(self, func, *args, **kwargs): ex = ex.__cause__ ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:]) return ex + + +def run_parallel(func, iterable, parallel=False, **kwargs): + if parallel and not os.environ.get('SOIL_DEBUG', None): + p = Pool() + wrapped_func = partial(run_and_return_exceptions, + func, **kwargs) + for i in p.imap_unordered(wrapped_func, iterable): + if isinstance(i, Exception): + logger.error('Trial failed:\n\t%s', i.message) + continue + yield i + else: + for i in iterable: + yield func(i, **kwargs) diff --git a/tests/complete_converted.yml b/tests/complete_converted.yml index 36a0a96..d1c3358 100644 --- a/tests/complete_converted.yml +++ b/tests/complete_converted.yml @@ -1,49 +1,50 @@ --- version: '2' -general: - id: simple - group: tests - dir_path: "/tmp/" - num_trials: 3 - max_time: 100 - interval: 1 - seed: "CompleteSeed!" -topologies: - default: - params: - generator: complete_graph - n: 10 -agents: - default: +name: simple +group: tests +dir_path: "/tmp/" +num_trials: 3 +max_time: 100 +interval: 1 +seed: "CompleteSeed!" 
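A minimal sketch of the run_parallel helper added above (the worker function and its arguments are made up)::

    from soil.utils import run_parallel

    def run_one(trial_id, greeting='hi'):
        return f'{greeting}, trial {trial_id}'

    # With parallel=False this is a plain loop; parallel=True uses a multiprocessing.Pool
    # (unless SOIL_DEBUG is set) and failed calls are logged and skipped instead of raising.
    for res in run_parallel(run_one, range(3), parallel=False, greeting='hello'):
        print(res)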
+model_class: Environment +model_params: + topologies: + default: + params: + generator: complete_graph + n: 4 + agents: agent_class: CounterModel state: + group: network times: 1 - network: topology: 'default' distribution: - agent_class: CounterModel - weight: 0.4 + weight: 0.25 state: state_id: 0 + times: 1 - agent_class: AggregatedCounter - weight: 0.6 - override: - - filter: - node_id: 0 + weight: 0.5 state: - name: 'The first node' + times: 2 + override: - filter: node_id: 1 state: - name: 'The second node' - - environment: - fixed: - - name: 'Environment Agent 1' - agent_class: CounterModel + name: 'Node 1' + - filter: + node_id: 2 state: + name: 'Node 2' + fixed: + - agent_class: BaseAgent + hidden: true + topology: null + state: + name: 'Environment Agent 1' times: 10 -environment: - environment_class: Environment - params: - am_i_complete: true + group: environment + am_i_complete: true diff --git a/tests/old_complete.yml b/tests/old_complete.yml index 517abc4..9e4315b 100644 --- a/tests/old_complete.yml +++ b/tests/old_complete.yml @@ -8,17 +8,20 @@ interval: 1 seed: "CompleteSeed!" network_params: generator: complete_graph - n: 10 + n: 4 network_agents: - agent_class: CounterModel - weight: 0.4 + weight: 0.25 state: state_id: 0 + times: 1 - agent_class: AggregatedCounter - weight: 0.6 + weight: 0.5 + state: + times: 2 environment_agents: - agent_id: 'Environment Agent 1' - agent_class: CounterModel + agent_class: BaseAgent state: times: 10 environment_class: Environment @@ -28,5 +31,7 @@ agent_class: CounterModel default_state: times: 1 states: - - name: 'The first node' - - name: 'The second node' + 1: + name: 'Node 1' + 2: + name: 'Node 2' diff --git a/tests/test_agents.py b/tests/test_agents.py index e95c11c..cb33f1f 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -8,7 +8,7 @@ class Dead(agents.FSM): @agents.default_state @agents.state def only(self): - self.die() + return self.die() class TestMain(TestCase): def test_die_raises_exception(self): @@ -19,4 +19,6 @@ class TestMain(TestCase): def test_die_returns_infinity(self): d = Dead(unique_id=0, model=environment.Environment()) - assert d.step().abs(0) == stime.INFINITY + ret = d.step().abs(0) + print(ret, 'next') + assert ret == stime.INFINITY diff --git a/tests/test_analysis.py b/tests/test_analysis.py deleted file mode 100644 index 204b4dd..0000000 --- a/tests/test_analysis.py +++ /dev/null @@ -1,91 +0,0 @@ -from unittest import TestCase - -import os -import pandas as pd -import yaml -from functools import partial - -from os.path import join -from soil import simulation, analysis, agents - - -ROOT = os.path.abspath(os.path.dirname(__file__)) - - -class Ping(agents.FSM): - - defaults = { - 'count': 0, - } - - @agents.default_state - @agents.state - def even(self): - self.debug(f'Even {self["count"]}') - self['count'] += 1 - return self.odd - - @agents.state - def odd(self): - self.debug(f'Odd {self["count"]}') - self['count'] += 1 - return self.even - - -class TestAnalysis(TestCase): - - # Code to generate a simple sqlite history - def setUp(self): - """ - The initial states should be applied to the agent and the - agent should be able to update its state.""" - config = { - 'name': 'analysis', - 'seed': 'seed', - 'network_params': { - 'generator': 'complete_graph', - 'n': 2 - }, - 'agent_class': Ping, - 'states': [{'interval': 1}, {'interval': 2}], - 'max_time': 30, - 'num_trials': 1, - 'history': True, - 'environment_params': { - } - } - s = simulation.from_config(config) - self.env = 
s.run_simulation(dry_run=True)[0] - - def test_saved(self): - env = self.env - assert env.get_agent(0)['count', 0] == 1 - assert env.get_agent(0)['count', 29] == 30 - assert env.get_agent(1)['count', 0] == 1 - assert env.get_agent(1)['count', 29] == 15 - assert env['env', 29, None]['SEED'] == env['env', 29, 'SEED'] - - def test_count(self): - env = self.env - df = analysis.read_sql(env._history.db_path) - res = analysis.get_count(df, 'SEED', 'state_id') - assert res['SEED'][self.env['SEED']].iloc[0] == 1 - assert res['SEED'][self.env['SEED']].iloc[-1] == 1 - assert res['state_id']['odd'].iloc[0] == 2 - assert res['state_id']['even'].iloc[0] == 0 - assert res['state_id']['odd'].iloc[-1] == 1 - assert res['state_id']['even'].iloc[-1] == 1 - - def test_value(self): - env = self.env - df = analysis.read_sql(env._history.db_path) - res_sum = analysis.get_value(df, 'count') - - assert res_sum['count'].iloc[0] == 2 - - import numpy as np - res_mean = analysis.get_value(df, 'count', aggfunc=np.mean) - assert res_mean['count'].iloc[15] == (16+8)/2 - - res_total = analysis.get_majority(df) - res_total['SEED'].iloc[0] == self.env['SEED'] diff --git a/tests/test_config.py b/tests/test_config.py index fd9fc70..3597844 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -29,7 +29,7 @@ class TestConfig(TestCase): expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0] old = serialization.load_file(join(ROOT, "old_complete.yml"))[0] converted_defaults = config.convert_old(old, strict=False) - converted = converted_defaults.dict(skip_defaults=True) + converted = converted_defaults.dict(exclude_unset=True) isequal(converted, expected) @@ -40,10 +40,10 @@ class TestConfig(TestCase): """ config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0] s = simulation.from_config(config) - init_config = copy.copy(s.config) + init_config = copy.copy(s.to_dict()) s.run_simulation(dry_run=True) - nconfig = s.config + nconfig = s.to_dict() # del nconfig['to isequal(init_config, nconfig) @@ -61,7 +61,7 @@ class TestConfig(TestCase): Simple configuration that tests that the graph is loaded, and that network agents are initialized properly. 
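The test below exercises this path through a converted configuration; a more direct sketch, assuming set_topology accepts a prebuilt networkx graph as the new type hint suggests (that acceptance is an assumption, not confirmed by this patch)::

    import networkx as nx
    from soil.environment import Environment

    env = Environment(topology=nx.complete_graph(4))   # registered under the 'default' name
    assert len(env.topologies['default']) == 4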
""" - config = { + cfg = { 'name': 'CounterAgent', 'network_params': { 'path': join(ROOT, 'test.gexf') @@ -74,12 +74,14 @@ class TestConfig(TestCase): 'environment_params': { } } - s = simulation.from_old_config(config) + conf = config.convert_old(cfg) + s = simulation.from_config(conf) + env = s.get_env() assert len(env.topologies['default'].nodes) == 2 assert len(env.topologies['default'].edges) == 1 assert len(env.agents) == 2 - assert env.agents[0].topology == env.topologies['default'] + assert env.agents[0].G == env.topologies['default'] def test_agents_from_config(self): '''We test that the known complete configuration produces @@ -87,12 +89,9 @@ class TestConfig(TestCase): cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0] s = simulation.from_config(cfg) env = s.get_env() - assert len(env.topologies['default'].nodes) == 10 - assert len(env.agents(group='network')) == 10 + assert len(env.topologies['default'].nodes) == 4 + assert len(env.agents(group='network')) == 4 assert len(env.agents(group='environment')) == 1 - - assert sum(1 for a in env.agents(group='network', agent_class=agents.CounterModel)) == 4 - assert sum(1 for a in env.agents(group='network', agent_class=agents.AggregatedCounter)) == 6 def test_yaml(self): """ diff --git a/tests/test_examples.py b/tests/test_examples.py index 1cc4cca..af77c33 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -2,7 +2,7 @@ from unittest import TestCase import os from os.path import join -from soil import serialization, simulation +from soil import serialization, simulation, config ROOT = os.path.abspath(os.path.dirname(__file__)) EXAMPLES = join(ROOT, '..', 'examples') @@ -14,36 +14,37 @@ class TestExamples(TestCase): pass -def make_example_test(path, config): +def make_example_test(path, cfg): def wrapped(self): root = os.getcwd() - for s in simulation.all_from_config(path): - iterations = s.config.general.max_time * s.config.general.num_trials - if iterations > 1000: - s.config.general.max_time = 100 - s.config.general.num_trials = 1 - if config.get('skip_test', False) and not FORCE_TESTS: + for s in simulation.iter_from_config(cfg): + iterations = s.max_steps * s.num_trials + if iterations < 0 or iterations > 1000: + s.max_steps = 100 + s.num_trials = 1 + assert isinstance(cfg, config.Config) + if getattr(cfg, 'skip_test', False) and not FORCE_TESTS: self.skipTest('Example ignored.') envs = s.run_simulation(dry_run=True) assert envs for env in envs: assert env try: - n = config['network_params']['n'] + n = cfg.model_params['network_params']['n'] assert len(list(env.network_agents)) == n - assert env.now > 0 # It has run - assert env.now <= config['max_time'] # But not further than allowed except KeyError: pass + assert env.schedule.steps > 0 # It has run + assert env.schedule.steps <= s.max_steps # But not further than allowed return wrapped def add_example_tests(): - for config, path in serialization.load_files( + for cfg, path in serialization.load_files( join(EXAMPLES, '*', '*.yml'), join(EXAMPLES, '*.yml'), ): - p = make_example_test(path=path, config=config) + p = make_example_test(path=path, cfg=config.Config.from_raw(cfg)) fname = os.path.basename(path) p.__name__ = 'test_example_file_%s' % fname p.__doc__ = '%s should be a valid configuration' % fname diff --git a/tests/test_exporters.py b/tests/test_exporters.py index debc14a..cbd88bd 100644 --- a/tests/test_exporters.py +++ b/tests/test_exporters.py @@ -6,6 +6,8 @@ import shutil from unittest import TestCase from soil import exporters 
 from soil import simulation
+from soil import agents
+
 
 class Dummy(exporters.Exporter):
     started = False
@@ -33,28 +35,36 @@ class Dummy(exporters.Exporter):
 
 class Exporters(TestCase):
     def test_basic(self):
+        # We need to add at least one agent to make sure the scheduler
+        # ticks every step
+        num_trials = 5
+        max_time = 2
         config = {
             'name': 'exporter_sim',
-            'network_params': {},
-            'agent_class': 'CounterModel',
-            'max_time': 2,
-            'num_trials': 5,
-            'environment_params': {}
+            'model_params': {
+                'agents': [{
+                    'agent_class': agents.BaseAgent
+                }]
+            },
+            'max_time': max_time,
+            'num_trials': num_trials,
         }
         s = simulation.from_config(config)
+
         for env in s.run_simulation(exporters=[Dummy], dry_run=True):
-            assert env.now <= 2
+            assert len(env.agents) == 1
+            assert env.now == max_time
 
         assert Dummy.started
         assert Dummy.ended
         assert Dummy.called_start == 1
         assert Dummy.called_end == 1
-        assert Dummy.called_trial == 5
-        assert Dummy.trials == 5
-        assert Dummy.total_time == 2*5
+        assert Dummy.called_trial == num_trials
+        assert Dummy.trials == num_trials
+        assert Dummy.total_time == max_time * num_trials
 
     def test_writing(self):
-        '''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
+        '''Try to write CSV, sqlite and YAML (without dry_run)'''
         n_trials = 5
         config = {
             'name': 'exporter_sim',
@@ -74,7 +84,6 @@ class Exporters(TestCase):
         envs = s.run_simulation(exporters=[
             exporters.default,
             exporters.csv,
-            exporters.gexf,
         ],
                                 dry_run=False,
                                 outdir=tmpdir,
@@ -88,11 +97,7 @@ class Exporters(TestCase):
 
         try:
             for e in envs:
-                with open(os.path.join(simdir, '{}.gexf'.format(e.name))) as f:
-                    result = f.read()
-                    assert result
-
-                with open(os.path.join(simdir, '{}.csv'.format(e.name))) as f:
+                with open(os.path.join(simdir, '{}.env.csv'.format(e.id))) as f:
                     result = f.read()
                     assert result
         finally:
diff --git a/tests/test_history.py b/tests/test_history.py
deleted file mode 100644
index 773cfd6..0000000
--- a/tests/test_history.py
+++ /dev/null
@@ -1,128 +0,0 @@
-from unittest import TestCase
-
-import os
-import io
-import yaml
-import copy
-import pickle
-import networkx as nx
-from functools import partial
-
-from os.path import join
-from soil import (simulation, Environment, agents, serialization,
-                  utils)
-from soil.time import Delta
-from tsih import NoHistory, History
-
-
-ROOT = os.path.abspath(os.path.dirname(__file__))
-EXAMPLES = join(ROOT, '..', 'examples')
-
-
-class CustomAgent(agents.FSM):
-    @agents.default_state
-    @agents.state
-    def normal(self):
-        self.neighbors = self.count_agents(state_id='normal',
-                                           limit_neighbors=True)
-    @agents.state
-    def unreachable(self):
-        return
-
-class TestHistory(TestCase):
-
-    def test_counter_agent_history(self):
-        """
-        The evolution of the state should be recorded in the logging agent
-        """
-        config = {
-            'name': 'CounterAgent',
-            'network_params': {
-                'path': join(ROOT, 'test.gexf')
-            },
-            'network_agents': [{
-                'agent_class': 'AggregatedCounter',
-                'weight': 1,
-                'state': {'state_id': 0}
-
-            }],
-            'max_time': 10,
-            'environment_params': {
-            }
-        }
-        s = simulation.from_config(config)
-        env = s.run_simulation(dry_run=True)[0]
-        for agent in env.network_agents:
-            last = 0
-            assert len(agent[None, None]) == 11
-            for step, total in sorted(agent['total', None]):
-                assert total == last + 2
-                last = total
-
-    def test_row_conversion(self):
-        env = Environment(history=True)
-        env['test'] = 'test_value'
-
-        res = list(env.history_to_tuples())
-        assert len(res) == len(env.environment_params)
-
-        env.schedule.time = 1
-        env['test'] = 'second_value'
-        res = list(env.history_to_tuples())
-
-        assert env['env', 0, 'test' ] == 'test_value'
-        assert env['env', 1, 'test' ] == 'second_value'
-
-    def test_nohistory(self):
-        '''
-        Make sure that no history(/sqlite) is used by default
-        '''
-        env = Environment(topology=nx.Graph(), network_agents=[])
-        assert isinstance(env._history, NoHistory)
-
-    def test_save_graph_history(self):
-        '''
-        The history_to_graph method should return a valid networkx graph.
-
-        The state of the agent should be encoded as intervals in the nx graph.
-        '''
-        G = nx.cycle_graph(5)
-        distribution = agents.calculate_distribution(None, agents.BaseAgent)
-        env = Environment(topology=G, network_agents=distribution, history=True)
-        env[0, 0, 'testvalue'] = 'start'
-        env[0, 10, 'testvalue'] = 'finish'
-        nG = env.history_to_graph()
-        values = nG.nodes[0]['attr_testvalue']
-        assert ('start', 0, 10) in values
-        assert ('finish', 10, None) in values
-
-    def test_save_graph_nohistory(self):
-        '''
-        The history_to_graph method should return a valid networkx graph.
-
-        When NoHistory is used, only the last known value is known
-        '''
-        G = nx.cycle_graph(5)
-        distribution = agents.calculate_distribution(None, agents.BaseAgent)
-        env = Environment(topology=G, network_agents=distribution, history=False)
-        env.get_agent(0)['testvalue'] = 'start'
-        env.schedule.time = 10
-        env.get_agent(0)['testvalue'] = 'finish'
-        nG = env.history_to_graph()
-        values = nG.nodes[0]['attr_testvalue']
-        assert ('start', 0, None) not in values
-        assert ('finish', 10, None) in values
-
-    def test_pickle_agent_environment(self):
-        env = Environment(name='Test', history=True)
-        a = agents.BaseAgent(model=env, unique_id=25)
-
-        a['key'] = 'test'
-
-        pickled = pickle.dumps(a)
-        recovered = pickle.loads(pickled)
-
-        assert recovered.env.name == 'Test'
-        assert list(recovered.env._history.to_tuples())
-        assert recovered['key', 0] == 'test'
-        assert recovered['key'] == 'test'
diff --git a/tests/test_main.py b/tests/test_main.py
index a114b6c..6ac26e4 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -24,6 +24,7 @@ class CustomAgent(agents.FSM, agents.NetworkAgent):
     def unreachable(self):
         return
 
+
 class TestMain(TestCase):
 
     def test_empty_simulation(self):
@@ -79,20 +80,16 @@
                     }
                 },
                 'agents': {
-                    'default': {
-                        'agent_class': 'CounterModel',
-                    },
-                    'counters': {
-                        'topology': 'default',
-                        'fixed': [{'state': {'times': 10}}, {'state': {'times': 20}}],
-                    }
+                    'agent_class': 'CounterModel',
+                    'topology': 'default',
+                    'fixed': [{'state': {'times': 10}}, {'state': {'times': 20}}],
                 }
             }
         }
         s = simulation.from_config(config)
         env = s.get_env()
         assert isinstance(env.agents[0], agents.CounterModel)
-        assert env.agents[0].topology == env.topologies['default']
+        assert env.agents[0].G == env.topologies['default']
         assert env.agents[0]['times'] == 10
         assert env.agents[0]['times'] == 10
         env.step()
@@ -105,8 +102,8 @@
         config = {
             'max_time': 10,
             'model_params': {
-                'agents': [(CustomAgent, {'weight': 1}),
-                           (CustomAgent, {'weight': 3}),
+                'agents': [{'agent_class': CustomAgent, 'weight': 1, 'topology': 'default'},
+                           {'agent_class': CustomAgent, 'weight': 3, 'topology': 'default'},
                 ],
                 'topologies': {
                     'default': {
@@ -128,7 +125,7 @@
         """A complete example from a documentation should work."""
         config = serialization.load_file(join(EXAMPLES, 'torvalds.yml'))[0]
         config['model_params']['network_params']['path'] = join(EXAMPLES,
-                                                                config['network_params']['path'])
+                                                                config['model_params']['network_params']['path'])
         s = simulation.from_config(config)
         env = s.run_simulation(dry_run=True)[0]
         for a in env.network_agents:
@@ -208,24 +205,6 @@ class TestMain(TestCase):
         assert converted[1]['agent_class'] == 'test_main.CustomAgent'
         pickle.dumps(converted)
 
-    def test_subgraph(self):
-        '''An agent should be able to subgraph the global topology'''
-        G = nx.Graph()
-        G.add_node(3)
-        G.add_edge(1, 2)
-        distro = agents.calculate_distribution(agent_class=agents.NetworkAgent)
-        distro[0]['topology'] = 'default'
-        aconfig = config.AgentConfig(distribution=distro, topology='default')
-        env = Environment(name='Test', topologies={'default': G}, agents={'network': aconfig})
-        lst = list(env.network_agents)
-
-        a2 = env.find_one(node_id=2)
-        a3 = env.find_one(node_id=3)
-        assert len(a2.subgraph(limit_neighbors=True)) == 2
-        assert len(a3.subgraph(limit_neighbors=True)) == 1
-        assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
-        assert len(a3.subgraph(agent_class=agents.NetworkAgent)) == 3
-
     def test_templates(self):
         '''Loading a template should result in several configs'''
         configs = serialization.load_file(join(EXAMPLES, 'template.yml'))
@@ -236,14 +215,18 @@ class TestMain(TestCase):
             'name': 'until_sim',
             'model_params': {
                 'network_params': {},
-                'agent_class': 'CounterModel',
+                'agents': {
+                    'fixed': [{
+                        'agent_class': agents.BaseAgent,
+                    }]
+                },
             },
             'max_time': 2,
             'num_trials': 50,
         }
         s = simulation.from_config(config)
         runs = list(s.run_simulation(dry_run=True))
-        over = list(x.now for x in runs if x.now>2)
+        over = list(x.now for x in runs if x.now > 2)
         assert len(runs) == config['num_trials']
         assert len(over) == 0
diff --git a/tests/test_network.py b/tests/test_network.py
index b111a94..d984320 100644
--- a/tests/test_network.py
+++ b/tests/test_network.py
@@ -6,7 +6,8 @@ import networkx as nx
 
 from os.path import join
 
-from soil import network, environment
+from soil import config, network, environment, agents, simulation
+from test_main import CustomAgent
 
 ROOT = os.path.abspath(os.path.dirname(__file__))
 EXAMPLES = join(ROOT, '..', 'examples')
@@ -60,22 +61,53 @@ class TestNetwork(TestCase):
         G = nx.random_geometric_graph(20, 0.1)
         env = environment.NetworkEnvironment(topology=G)
         f = io.BytesIO()
-        env.dump_gexf(f)
+        assert env.topologies['default']
+        network.dump_gexf(env.topologies['default'], f)
+
+    def test_networkenvironment_creation(self):
+        """Networkenvironment should accept netconfig as parameters"""
+        model_params = {
+            'topologies': {
+                'default': {
+                    'path': join(ROOT, 'test.gexf')
+                }
+            },
+            'agents': {
+                'topology': 'default',
+                'distribution': [{
+                    'agent_class': CustomAgent,
+                }]
+            }
+        }
+        env = environment.Environment(**model_params)
+        assert env.topologies
+        env.step()
+        assert len(env.topologies['default']) == 2
+        assert len(env.agents) == 2
+        assert env.agents[1].count_agents(state_id='normal') == 2
+        assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
+        assert env.agents[0].neighbors == 1
 
     def test_custom_agent_neighbors(self):
         """Allow for search of neighbors with a certain state_id"""
         config = {
-            'network_params': {
-                'path': join(ROOT, 'test.gexf')
+            'model_params': {
+                'topologies': {
+                    'default': {
+                        'path': join(ROOT, 'test.gexf')
+                    }
+                },
+                'agents': {
+                    'topology': 'default',
+                    'distribution': [
+                        {
+                            'weight': 1,
+                            'agent_class': CustomAgent
+                        }
+                    ]
+                }
             },
-            'network_agents': [{
-                'agent_class': CustomAgent,
-                'weight': 1
-
-            }],
             'max_time': 10,
-            'environment_params': {
-            }
         }
         s = simulation.from_config(config)
         env = s.run_simulation(dry_run=True)[0]
@@ -83,3 +115,19 @@ class TestNetwork(TestCase):
         assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
         assert env.agents[0].neighbors == 1
 
+    def test_subgraph(self):
+        '''An agent should be able to subgraph the global topology'''
+        G = nx.Graph()
+        G.add_node(3)
+        G.add_edge(1, 2)
+        distro = agents.calculate_distribution(agent_class=agents.NetworkAgent)
+        aconfig = config.AgentConfig(distribution=distro, topology='default')
+        env = environment.Environment(name='Test', topologies={'default': G}, agents=aconfig)
+        lst = list(env.network_agents)
+
+        a2 = env.find_one(node_id=2)
+        a3 = env.find_one(node_id=3)
+        assert len(a2.subgraph(limit_neighbors=True)) == 2
+        assert len(a3.subgraph(limit_neighbors=True)) == 1
+        assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
+        assert len(a3.subgraph(agent_class=agents.NetworkAgent)) == 3
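
For reference, a minimal end-to-end sketch of the new-style configuration exercised by the tests above. The key names ('model_params', 'agents', 'fixed', 'max_time', 'num_trials') and the calls follow tests/test_main.py and tests/test_exporters.py; the simulation name and the trial counts are only illustrative.

    # Sketch assembled from the test suite in this patch; not part of the commit itself.
    from soil import agents, simulation

    config = {
        'name': 'minimal_example',
        'model_params': {
            'agents': {
                # A single fixed agent is enough to make the scheduler tick every step
                'fixed': [{
                    'agent_class': agents.BaseAgent,
                }]
            },
        },
        'max_time': 2,
        'num_trials': 1,
    }

    s = simulation.from_config(config)
    # run_simulation returns one environment per trial; dry_run skips exporter output
    for env in s.run_simulation(dry_run=True):
        assert env.schedule.steps > 0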