Mirror of https://github.com/gsi-upm/soil (synced 2025-09-14 04:02:21 +00:00)

Compare commits: 5 commits
Commits compared (SHA1):

- 092591fb97
- 263b4e0e33
- 189836408f
- ee0c4517cb
- 3041156f19
.gitignore (vendored): 1 line changed
@@ -10,3 +10,4 @@ build/*
dist/*
prof
backup
*.egg-info
@@ -11,6 +11,7 @@ def run_sim(model, **kwargs):
        dump=False,
        num_processes=1,
        parameters={'num_agents': NUM_AGENTS},
        seed="",
        max_steps=MAX_STEPS,
        iterations=NUM_ITERS)
    opts.update(kwargs)
@@ -8,7 +8,6 @@ class NoopAgent(Agent):
        self.num_calls = 0

    def step(self):
        # import pdb;pdb.set_trace()
        self.num_calls += 1
@@ -10,7 +10,6 @@ class NoopAgent(Agent):
        self.num_calls = 0

    def step(self):
        # import pdb;pdb.set_trace()
        self.num_calls += 1
benchmarks/noop/soil_state.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from soil import Agent, Environment, Simulation, state


class NoopAgent(Agent):
    num_calls = 0

    @state(default=True)
    def unique(self):
        self.num_calls += 1

class NoopEnvironment(Environment):
    num_agents = 100

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
        self.add_agent_reporter("num_calls")


from _config import *

run_sim(model=NoopEnvironment)
@@ -1,7 +1,7 @@
from soil import BaseAgent, Environment, Simulation
from soil import Agent, Environment, Simulation


class NoopAgent(BaseAgent):
class NoopAgent(Agent):
    num_calls = 0

    def step(self):
@@ -15,7 +15,6 @@ class NoopEnvironment(Environment):
        self.add_agent_reporter("num_calls")


if __name__ == "__main__":
    from _config import *
    from _config import *

    run_sim(model=NoopEnvironment)
    run_sim(model=NoopEnvironment)
@@ -1,5 +1,5 @@
from soil import Agent, Environment, Simulation
from soilent import Scheduler
from soil.time import SoilentActivation


class NoopAgent(Agent):
@@ -14,7 +14,7 @@ class NoopAgent(Agent):

class NoopEnvironment(Environment):
    num_agents = 100
    schedule_class = Scheduler
    schedule_class = SoilentActivation

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
@@ -26,4 +26,4 @@ if __name__ == "__main__":

    res = run_sim(model=NoopEnvironment)
    for r in res:
        assert isinstance(r.schedule, Scheduler)
        assert isinstance(r.schedule, SoilentActivation)
@@ -1,5 +1,5 @@
from soil import Agent, Environment
from soilent import PQueueScheduler
from soil.time import SoilentPQueueActivation


class NoopAgent(Agent):
@@ -12,7 +12,7 @@ class NoopAgent(Agent):

class NoopEnvironment(Environment):
    num_agents = 100
    schedule_class = PQueueScheduler
    schedule_class = SoilentPQueueActivation

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
@@ -24,4 +24,4 @@ if __name__ == "__main__":

    res = run_sim(model=NoopEnvironment)
    for r in res:
        assert isinstance(r.schedule, PQueueScheduler)
        assert isinstance(r.schedule, SoilentPQueueActivation)
@@ -1,5 +1,5 @@
from soil import Agent, Environment, Simulation
from soilent import Scheduler
from soil.time import SoilentActivation


class NoopAgent(Agent):
@@ -13,7 +13,7 @@ class NoopAgent(Agent):

class NoopEnvironment(Environment):
    num_agents = 100
    schedule_class = Scheduler
    schedule_class = SoilentActivation

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
@@ -25,4 +25,4 @@ if __name__ == "__main__":

    res = run_sim(model=NoopEnvironment)
    for r in res:
        assert isinstance(r.schedule, Scheduler)
        assert isinstance(r.schedule, SoilentActivation)
@@ -1,5 +1,5 @@
from soil import Agent, Environment
from soilent import PQueueScheduler
from soil.time import SoilentPQueueActivation


class NoopAgent(Agent):
@@ -13,7 +13,7 @@ class NoopAgent(Agent):

class NoopEnvironment(Environment):
    num_agents = 100
    schedule_class = PQueueScheduler
    schedule_class = SoilentPQueueActivation

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
@@ -25,4 +25,4 @@ if __name__ == "__main__":

    res = run_sim(model=NoopEnvironment)
    for r in res:
        assert isinstance(r.schedule, PQueueScheduler)
        assert isinstance(r.schedule, SoilentPQueueActivation)
benchmarks/noop/soilent_state.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from soil import Agent, Environment, Simulation, state
from soil.time import SoilentActivation


class NoopAgent(Agent):
    num_calls = 0

    @state(default=True)
    async def unique(self):
        while True:
            self.num_calls += 1
            # yield self.delay(1)
            await self.delay()


class NoopEnvironment(Environment):
    num_agents = 100
    schedule_class = SoilentActivation

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
        self.add_agent_reporter("num_calls")


if __name__ == "__main__":
    from _config import *

    res = run_sim(model=NoopEnvironment)
    for r in res:
        assert isinstance(r.schedule, SoilentActivation)
@@ -1,5 +1,5 @@
from soil import BaseAgent, Environment, Simulation
from soilent import Scheduler
from soil.time import SoilentActivation


class NoopAgent(BaseAgent):
@@ -10,7 +10,7 @@ class NoopAgent(BaseAgent):

class NoopEnvironment(Environment):
    num_agents = 100
    schedule_class = Scheduler
    schedule_class = SoilentActivation

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
@@ -21,4 +21,4 @@ if __name__ == "__main__":
    from _config import *
    res = run_sim(model=NoopEnvironment)
    for r in res:
        assert isinstance(r.schedule, Scheduler)
        assert isinstance(r.schedule, SoilentActivation)
@@ -1,5 +1,5 @@
from soil import BaseAgent, Environment, Simulation
from soilent import PQueueScheduler
from soil.time import SoilentPQueueActivation


class NoopAgent(BaseAgent):
@@ -10,7 +10,7 @@ class NoopAgent(BaseAgent):

class NoopEnvironment(Environment):
    num_agents = 100
    schedule_class = PQueueScheduler
    schedule_class = SoilentPQueueActivation

    def init(self):
        self.add_agents(NoopAgent, k=self.num_agents)
@@ -21,4 +21,4 @@ if __name__ == "__main__":
    from _config import *
    res = run_sim(model=NoopEnvironment)
    for r in res:
        assert isinstance(r.schedule, PQueueScheduler)
        assert isinstance(r.schedule, SoilentPQueueActivation)
@@ -1,8 +1,9 @@
import os
from soil import simulation

NUM_AGENTS = int(os.environ.get('NUM_AGENTS', 100))
NUM_ITERS = int(os.environ.get('NUM_ITERS', 10))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 500))


def run_sim(model, **kwargs):
@@ -22,11 +23,16 @@ def run_sim(model, **kwargs):
                iterations=NUM_ITERS)
    opts.update(kwargs)
    its = Simulation(**opts).run()
    assert len(its) == NUM_ITERS

    assert all(it.schedule.steps == MAX_STEPS for it in its)
    ratios = list(it.resistant_susceptible_ratio() for it in its)
    print("Max - Avg - Min ratio:", max(ratios), sum(ratios)/len(ratios), min(ratios))
    assert all(sum([it.number_susceptible,
                    it.number_infected,
                    it.number_resistant]) == NUM_AGENTS for it in its)
    if not simulation._AVOID_RUNNING:
        ratios = list(it.resistant_susceptible_ratio for it in its)
        print("Max - Avg - Min ratio:", max(ratios), sum(ratios)/len(ratios), min(ratios))
        infected = list(it.number_infected for it in its)
        print("Max - Avg - Min infected:", max(infected), sum(infected)/len(infected), min(infected))

        assert all((it.schedule.steps == MAX_STEPS or it.number_infected == 0) for it in its)
        assert all(sum([it.number_susceptible,
                        it.number_infected,
                        it.number_resistant]) == NUM_AGENTS for it in its)
    return its
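Since the hunk above shows the harness reading its defaults from environment variables, here is a hedged sketch of how a run could override them. `run_sim` and `VirusOnNetwork` come from the hunks in this comparison; the `virus` module name is an assumption, since the example's file name is not visible in this view.

```python
# Hedged sketch: override the harness defaults via environment variables
# before _config is imported (the module reads os.environ at import time).
import os

os.environ["NUM_AGENTS"] = "500"   # default is 100
os.environ["NUM_ITERS"] = "2"      # default is 10
os.environ["MAX_STEPS"] = "200"    # default is now 500

from _config import run_sim           # picks up the overridden values
from virus import VirusOnNetwork      # assumed module name for the example below

run_sim(model=VirusOnNetwork)
```

Exporting the same variables in the shell before launching the script would have the same effect.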
@@ -100,6 +100,7 @@ class VirusOnNetwork(mesa.Model):
    def number_infected(self):
        return number_infected(self)

    @property
    def resistant_susceptible_ratio(self):
        try:
            return number_state(self, State.RESISTANT) / number_state(
@@ -176,5 +177,4 @@ class VirusAgent(mesa.Agent):


from _config import run_sim

run_sim(model=VirusOnNetwork)
@@ -31,7 +31,11 @@ class VirusOnNetwork(Environment):
            a.set_state(VirusAgent.infected)
        assert self.number_infected == self.initial_outbreak_size

    def step(self):
        super().step()

    @report
    @property
    def resistant_susceptible_ratio(self):
        try:
            return self.number_resistant / self.number_susceptible
@@ -59,34 +63,29 @@ class VirusAgent(Agent):
    virus_check_frequency = None  # Inherit from model
    recovery_chance = None  # Inherit from model
    gain_resistance_chance = None  # Inherit from model
    just_been_infected = False

    @state(default=True)
    def susceptible(self):
        if self.just_been_infected:
            self.just_been_infected = False
            return self.infected
    async def susceptible(self):
        await self.received()
        return self.infected

    @state
    def infected(self):
        susceptible_neighbors = self.get_neighbors(state_id=self.susceptible.id)
        for a in susceptible_neighbors:
            if self.prob(self.virus_spread_chance):
                a.just_been_infected = True
                a.tell(True, sender=self)
        if self.prob(self.virus_check_frequency):
            if self.prob(self.recovery_chance):
                if self.prob(self.gain_resistance_chance):
                    return self.resistant
                else:
                    return self.susceptible
            else:
                return self.infected

    @state
    def resistant(self):
        return self.at(INFINITY)


if __name__ == "__main__":
    from _config import run_sim
    run_sim(model=VirusOnNetwork)
    from _config import run_sim
    run_sim(model=VirusOnNetwork)
@@ -38,6 +38,7 @@ class VirusOnNetwork(Environment):
        assert self.number_infected == self.initial_outbreak_size

    @report
    @property
    def resistant_susceptible_ratio(self):
        try:
            return self.number_resistant / self.number_susceptible
@@ -99,6 +100,5 @@ class VirusAgent(Agent):



if __name__ == "__main__":
    from _config import run_sim
    run_sim(model=VirusOnNetwork)
    from _config import run_sim
    run_sim(model=VirusOnNetwork)
File diff suppressed because one or more lines are too long
@@ -43,7 +43,7 @@ class Journey:
|
||||
driver: Optional[Driver] = None
|
||||
|
||||
|
||||
class City(EventedEnvironment):
|
||||
class City(Environment):
|
||||
"""
|
||||
An environment with a grid where drivers and passengers will be placed.
|
||||
|
||||
@@ -85,11 +85,12 @@ class Driver(Evented, FSM):
|
||||
journey = None
|
||||
earnings = 0
|
||||
|
||||
def on_receive(self, msg, sender):
|
||||
"""This is not a state. It will run (and block) every time process_messages is invoked"""
|
||||
if self.journey is None and isinstance(msg, Journey) and msg.driver is None:
|
||||
msg.driver = self
|
||||
self.journey = msg
|
||||
# TODO: remove
|
||||
# def on_receive(self, msg, sender):
|
||||
# """This is not a state. It will run (and block) every time process_messages is invoked"""
|
||||
# if self.journey is None and isinstance(msg, Journey) and msg.driver is None:
|
||||
# msg.driver = self
|
||||
# self.journey = msg
|
||||
|
||||
def check_passengers(self):
|
||||
"""If there are no more passengers, stop forever"""
|
||||
@@ -104,7 +105,7 @@ class Driver(Evented, FSM):
|
||||
if not self.check_passengers():
|
||||
return self.die("No passengers left")
|
||||
self.journey = None
|
||||
while self.journey is None: # No potential journeys detected (see on_receive)
|
||||
while self.journey is None: # No potential journeys detected
|
||||
if target is None or not self.move_towards(target):
|
||||
target = self.random.choice(
|
||||
self.model.grid.get_neighborhood(self.pos, moore=False)
|
||||
@@ -113,7 +114,7 @@ class Driver(Evented, FSM):
|
||||
if not self.check_passengers():
|
||||
return self.die("No passengers left")
|
||||
# This will call on_receive behind the scenes, and the agent's status will be updated
|
||||
self.process_messages()
|
||||
|
||||
await self.delay(30) # Wait at least 30 seconds before checking again
|
||||
|
||||
try:
|
||||
@@ -167,12 +168,13 @@ class Driver(Evented, FSM):
|
||||
class Passenger(Evented, FSM):
|
||||
pos = None
|
||||
|
||||
def on_receive(self, msg, sender):
|
||||
"""This is not a state. It will be run synchronously every time `process_messages` is run"""
|
||||
# TODO: Remove
|
||||
# def on_receive(self, msg, sender):
|
||||
# """This is not a state. It will be run synchronously every time `process_messages` is run"""
|
||||
|
||||
if isinstance(msg, Journey):
|
||||
self.journey = msg
|
||||
return msg
|
||||
# if isinstance(msg, Journey):
|
||||
# self.journey = msg
|
||||
# return msg
|
||||
|
||||
@default_state
|
||||
@state
|
||||
@@ -192,17 +194,34 @@ class Passenger(Evented, FSM):
|
||||
timeout = 60
|
||||
expiration = self.now + timeout
|
||||
self.info(f"Asking for journey at: { self.pos }")
|
||||
self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver)
|
||||
self.broadcast(journey, ttl=timeout, agent_class=Driver)
|
||||
while not self.journey:
|
||||
self.debug(f"Waiting for responses at: { self.pos }")
|
||||
try:
|
||||
# This will call process_messages behind the scenes, and the agent's status will be updated
|
||||
# If you want to avoid that, you can call it with: check=False
|
||||
await self.received(expiration=expiration, delay=10)
|
||||
offers = await self.received(expiration=expiration, delay=10)
|
||||
accepted = None
|
||||
for event in offers:
|
||||
offer = event.payload
|
||||
if isinstance(offer, Journey):
|
||||
self.journey = offer
|
||||
assert isinstance(event.sender, Driver)
|
||||
try:
|
||||
answer = await event.sender.ask(True, sender=self, timeout=60, delay=5)
|
||||
if answer:
|
||||
accepted = offer
|
||||
self.journey = offer
|
||||
break
|
||||
except events.TimedOut:
|
||||
pass
|
||||
if accepted:
|
||||
for event in offers:
|
||||
if event.payload != accepted:
|
||||
event.sender.tell(False, timeout=60, delay=5)
|
||||
|
||||
except events.TimedOut:
|
||||
self.info(f"Still no response. Waiting at: { self.pos }")
|
||||
self.model.broadcast(
|
||||
journey, ttl=timeout, sender=self, agent_class=Driver
|
||||
self.broadcast(
|
||||
journey, ttl=timeout, agent_class=Driver
|
||||
)
|
||||
expiration = self.now + timeout
|
||||
self.info(f"Got a response! Waiting for driver")
|
||||
|
@@ -1,4 +1,4 @@
from soil.agents import FSM, NetworkAgent, state, default_state, prob
from soil.agents import FSM, NetworkAgent, state, default_state
from soil.parameters import *
import logging

@@ -1,7 +1,7 @@
There are two similar implementations of this simulation.

- `basic`. Using simple primitives
- `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions.
- `improved`. Using more advanced features such as delays to avoid unnecessary computations (i.e., skipping steps).

The examples can be run directly in the terminal, and they accept command-line arguments.
For example, to enable the CSV exporter and the Summary exporter, while setting `max_time` to `100` and `seed` to `CustomSeed`:
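The command that originally followed this paragraph is not reproduced in this view. As a rough, hedged sketch of an equivalent run, the snippet below drives one of the example environments programmatically through the `Simulation` API that appears elsewhere in this changeset; the `virus_on_network` module name is an assumption, and exporter configuration is omitted because the exact exporter arguments are not shown here.

```python
# Hedged sketch (not the original command): run an example environment
# with the max_time and seed values mentioned above.
from soil import Simulation
from virus_on_network import VirusOnNetwork  # assumed module name

sim = Simulation(
    model=VirusOnNetwork,  # environment class from the examples above
    max_time=100,          # same value as the --max_time CLI flag
    seed="CustomSeed",
    iterations=1,
)
results = sim.run()
```

The same keyword arguments also appear in the `soil/__main__.py` hunk later in this comparison, where the CLI's `--max_time`/`--max_steps` values are forwarded to `sim.run(**d)`.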
@@ -1,22 +1,33 @@
|
||||
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment, Simulation
|
||||
from enum import Enum
|
||||
from collections import Counter
|
||||
import logging
|
||||
from soil import Evented, FSM, state, default_state, BaseAgent, NetworkAgent, Environment, parameters, report, TimedOut
|
||||
import math
|
||||
|
||||
from rabbits_basic_sim import RabbitEnv
|
||||
|
||||
class RabbitsImprovedEnv(Environment):
|
||||
prob_death: parameters.probability = 1e-3
|
||||
|
||||
class RabbitsImprovedEnv(RabbitEnv):
|
||||
def init(self):
|
||||
"""Initialize the environment with the new versions of the agents"""
|
||||
a1 = self.add_node(Male)
|
||||
a2 = self.add_node(Female)
|
||||
a1.add_edge(a2)
|
||||
self.add_agent(RandomAccident)
|
||||
|
||||
@report
|
||||
@property
|
||||
def num_rabbits(self):
|
||||
return self.count_agents(agent_class=Rabbit)
|
||||
|
||||
class Rabbit(FSM, NetworkAgent):
|
||||
@report
|
||||
@property
|
||||
def num_males(self):
|
||||
return self.count_agents(agent_class=Male)
|
||||
|
||||
@report
|
||||
@property
|
||||
def num_females(self):
|
||||
return self.count_agents(agent_class=Female)
|
||||
|
||||
|
||||
class Rabbit(Evented, FSM, NetworkAgent):
|
||||
|
||||
sexual_maturity = 30
|
||||
life_expectancy = 300
|
||||
@@ -31,42 +42,40 @@ class Rabbit(FSM, NetworkAgent):
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.info("I am a newborn.")
|
||||
self.debug("I am a newborn.")
|
||||
self.birth = self.now
|
||||
self.offspring = 0
|
||||
return self.youngling.delay(self.sexual_maturity - self.age)
|
||||
return self.youngling
|
||||
|
||||
@state
|
||||
def youngling(self):
|
||||
if self.age >= self.sexual_maturity:
|
||||
self.info(f"I am fertile! My age is {self.age}")
|
||||
return self.fertile
|
||||
async def youngling(self):
|
||||
self.debug("I am a youngling.")
|
||||
await self.delay(self.sexual_maturity - self.age)
|
||||
assert self.age >= self.sexual_maturity
|
||||
self.debug(f"I am fertile! My age is {self.age}")
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
raise Exception("Each subclass should define its fertile state")
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
self.die()
|
||||
|
||||
|
||||
class Male(Rabbit):
|
||||
max_females = 5
|
||||
mating_prob = 0.001
|
||||
mating_prob = 0.005
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
if self.age > self.life_expectancy:
|
||||
return self.dead
|
||||
return self.die()
|
||||
|
||||
# Males try to mate
|
||||
for f in self.model.agents(
|
||||
agent_class=Female, state_id=Female.fertile.id, limit=self.max_females
|
||||
):
|
||||
self.debug("FOUND A FEMALE: ", repr(f), self.mating_prob)
|
||||
self.debug(f"FOUND A FEMALE: {repr(f)}. Mating with prob {self.mating_prob}")
|
||||
if self.prob(self["mating_prob"]):
|
||||
f.impregnate(self)
|
||||
f.tell(self.unique_id, sender=self, timeout=1)
|
||||
break # Do not try to impregnate other females
|
||||
|
||||
|
||||
@@ -75,78 +84,91 @@ class Female(Rabbit):
|
||||
conception = None
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
async def fertile(self):
|
||||
# Just wait for a Male
|
||||
if self.age > self.life_expectancy:
|
||||
return self.dead
|
||||
if self.conception is not None:
|
||||
return self.pregnant
|
||||
|
||||
@property
|
||||
def pregnancy(self):
|
||||
if self.conception is None:
|
||||
return None
|
||||
return self.now - self.conception
|
||||
|
||||
def impregnate(self, male):
|
||||
self.info(f"impregnated by {repr(male)}")
|
||||
self.mate = male
|
||||
self.conception = self.now
|
||||
self.number_of_babies = int(8 + 4 * self.random.random())
|
||||
try:
|
||||
timeout = self.life_expectancy - self.age
|
||||
while timeout > 0:
|
||||
mates = await self.received(timeout=timeout)
|
||||
# assert self.age <= self.life_expectancy
|
||||
for mate in mates:
|
||||
try:
|
||||
male = self.model.agents[mate.payload]
|
||||
except ValueError:
|
||||
continue
|
||||
self.debug(f"impregnated by {repr(male)}")
|
||||
self.mate = male
|
||||
self.number_of_babies = int(8 + 4 * self.random.random())
|
||||
self.conception = self.now
|
||||
return self.pregnant
|
||||
except TimedOut:
|
||||
pass
|
||||
return self.die()
|
||||
|
||||
@state
|
||||
def pregnant(self):
|
||||
async def pregnant(self):
|
||||
self.debug("I am pregnant")
|
||||
# assert self.mate is not None
|
||||
|
||||
when = min(self.gestation, self.life_expectancy - self.age)
|
||||
if when < 0:
|
||||
return self.die()
|
||||
await self.delay(when)
|
||||
|
||||
if self.age > self.life_expectancy:
|
||||
self.info("Dying before giving birth")
|
||||
self.debug("Dying before giving birth")
|
||||
return self.die()
|
||||
|
||||
if self.pregnancy >= self.gestation:
|
||||
self.info("Having {} babies".format(self.number_of_babies))
|
||||
for i in range(self.number_of_babies):
|
||||
state = {}
|
||||
agent_class = self.random.choice([Male, Female])
|
||||
child = self.model.add_node(agent_class=agent_class, **state)
|
||||
child.add_edge(self)
|
||||
if self.mate:
|
||||
child.add_edge(self.mate)
|
||||
self.mate.offspring += 1
|
||||
else:
|
||||
self.debug("The father has passed away")
|
||||
# assert self.now - self.conception >= self.gestation
|
||||
if not self.alive:
|
||||
return self.die()
|
||||
|
||||
self.offspring += 1
|
||||
self.mate = None
|
||||
return self.fertile
|
||||
self.debug("Having {} babies".format(self.number_of_babies))
|
||||
for i in range(self.number_of_babies):
|
||||
state = {}
|
||||
agent_class = self.random.choice([Male, Female])
|
||||
child = self.model.add_node(agent_class=agent_class, **state)
|
||||
child.add_edge(self)
|
||||
try:
|
||||
child.add_edge(self.mate)
|
||||
self.model.agents[self.mate].offspring += 1
|
||||
except ValueError:
|
||||
self.debug("The father has passed away")
|
||||
|
||||
self.offspring += 1
|
||||
self.mate = None
|
||||
self.conception = None
|
||||
return self.fertile
|
||||
|
||||
def die(self):
|
||||
if self.pregnancy is not None:
|
||||
self.info("A mother has died carrying a baby!!")
|
||||
if self.conception is not None:
|
||||
self.debug("A mother has died carrying a baby!!")
|
||||
return super().die()
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
# Default value, but the value from the environment takes precedence
|
||||
prob_death = 1e-3
|
||||
|
||||
def step(self):
|
||||
rabbits_alive = self.model.G.number_of_nodes()
|
||||
|
||||
if not rabbits_alive:
|
||||
return self.die()
|
||||
alive = self.get_agents(agent_class=Rabbit, alive=True)
|
||||
|
||||
prob_death = self.model.get("prob_death", 1e-100) * math.floor(
|
||||
math.log10(max(1, rabbits_alive))
|
||||
)
|
||||
if not alive:
|
||||
return self.die("No more rabbits to kill")
|
||||
|
||||
num_alive = len(alive)
|
||||
prob_death = min(1, self.prob_death * num_alive/10)
|
||||
self.debug("Killing some rabbits with prob={}!".format(prob_death))
|
||||
for i in self.iter_agents(agent_class=Rabbit):
|
||||
|
||||
for i in alive:
|
||||
if i.state_id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
self.info("I killed a rabbit: {}".format(i.id))
|
||||
rabbits_alive -= 1
|
||||
i.die()
|
||||
self.debug("Rabbits alive: {}".format(rabbits_alive))
|
||||
self.debug("I killed a rabbit: {}".format(i.unique_id))
|
||||
num_alive -= 1
|
||||
self.model.remove_agent(i)
|
||||
self.debug("Rabbits alive: {}".format(num_alive))
|
||||
|
||||
|
||||
sim = Simulation(model=RabbitsImprovedEnv, max_time=100, seed="MySeed", iterations=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim.run()
|
||||
RabbitsImprovedEnv.run(max_time=1000, seed="MySeed", iterations=1)
|
||||
|
@@ -1,11 +1,9 @@
|
||||
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment, Simulation, report, parameters as params
|
||||
from collections import Counter
|
||||
import logging
|
||||
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment, report, parameters as params
|
||||
import math
|
||||
|
||||
|
||||
class RabbitEnv(Environment):
|
||||
prob_death: params.probability = 1e-100
|
||||
prob_death: params.probability = 1e-3
|
||||
|
||||
def init(self):
|
||||
a1 = self.add_node(Male)
|
||||
@@ -37,7 +35,7 @@ class Rabbit(NetworkAgent, FSM):
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.info("I am a newborn.")
|
||||
self.debug("I am a newborn.")
|
||||
self.age = 0
|
||||
self.offspring = 0
|
||||
return self.youngling
|
||||
@@ -46,7 +44,7 @@ class Rabbit(NetworkAgent, FSM):
|
||||
def youngling(self):
|
||||
self.age += 1
|
||||
if self.age >= self.sexual_maturity:
|
||||
self.info(f"I am fertile! My age is {self.age}")
|
||||
self.debug(f"I am fertile! My age is {self.age}")
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
@@ -60,7 +58,7 @@ class Rabbit(NetworkAgent, FSM):
|
||||
|
||||
class Male(Rabbit):
|
||||
max_females = 5
|
||||
mating_prob = 0.001
|
||||
mating_prob = 0.005
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
@@ -70,9 +68,8 @@ class Male(Rabbit):
|
||||
return self.dead
|
||||
|
||||
# Males try to mate
|
||||
for f in self.model.agents(
|
||||
agent_class=Female, state_id=Female.fertile.id, limit=self.max_females
|
||||
):
|
||||
for f in self.model.agents.filter(
|
||||
agent_class=Female, state_id=Female.fertile.id).limit(self.max_females):
|
||||
self.debug("FOUND A FEMALE: ", repr(f), self.mating_prob)
|
||||
if self.prob(self["mating_prob"]):
|
||||
f.impregnate(self)
|
||||
@@ -93,14 +90,14 @@ class Female(Rabbit):
|
||||
return self.pregnant
|
||||
|
||||
def impregnate(self, male):
|
||||
self.info(f"impregnated by {repr(male)}")
|
||||
self.debug(f"impregnated by {repr(male)}")
|
||||
self.mate = male
|
||||
self.pregnancy = 0
|
||||
self.number_of_babies = int(8 + 4 * self.random.random())
|
||||
|
||||
@state
|
||||
def pregnant(self):
|
||||
self.info("I am pregnant")
|
||||
self.debug("I am pregnant")
|
||||
self.age += 1
|
||||
|
||||
if self.age >= self.life_expectancy:
|
||||
@@ -110,7 +107,7 @@ class Female(Rabbit):
|
||||
self.pregnancy += 1
|
||||
return
|
||||
|
||||
self.info("Having {} babies".format(self.number_of_babies))
|
||||
self.debug("Having {} babies".format(self.number_of_babies))
|
||||
for i in range(self.number_of_babies):
|
||||
state = {}
|
||||
agent_class = self.random.choice([Male, Female])
|
||||
@@ -129,33 +126,32 @@ class Female(Rabbit):
|
||||
|
||||
def die(self):
|
||||
if "pregnancy" in self and self["pregnancy"] > -1:
|
||||
self.info("A mother has died carrying a baby!!")
|
||||
self.debug("A mother has died carrying a baby!!")
|
||||
return super().die()
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
prob_death = None
|
||||
def step(self):
|
||||
rabbits_alive = self.model.G.number_of_nodes()
|
||||
alive = self.get_agents(agent_class=Rabbit, alive=True)
|
||||
|
||||
if not rabbits_alive:
|
||||
return self.die()
|
||||
if not alive:
|
||||
return self.die("No more rabbits to kill")
|
||||
|
||||
prob_death = self.model.prob_death * math.floor(
|
||||
math.log10(max(1, rabbits_alive))
|
||||
)
|
||||
num_alive = len(alive)
|
||||
prob_death = min(1, self.prob_death * num_alive/10)
|
||||
self.debug("Killing some rabbits with prob={}!".format(prob_death))
|
||||
for i in self.get_agents(agent_class=Rabbit):
|
||||
|
||||
for i in alive:
|
||||
if i.state_id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
self.info("I killed a rabbit: {}".format(i.id))
|
||||
rabbits_alive -= 1
|
||||
i.die()
|
||||
self.debug("Rabbits alive: {}".format(rabbits_alive))
|
||||
self.debug("I killed a rabbit: {}".format(i.unique_id))
|
||||
num_alive -= 1
|
||||
self.model.remove_agent(i)
|
||||
i.alive = False
|
||||
i.killed = True
|
||||
self.debug("Rabbits alive: {}".format(num_alive))
|
||||
|
||||
|
||||
|
||||
sim = Simulation(model=RabbitEnv, max_time=100, seed="MySeed", iterations=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim.run()
|
||||
RabbitEnv.run(max_time=1000, seed="MySeed", iterations=1)
|
@@ -1,5 +1,5 @@
|
||||
import networkx as nx
|
||||
from soil.agents import NetworkAgent, FSM, custom, state, default_state
|
||||
from soil.agents import FSM, state, default_state
|
||||
from soil.agents.geo import Geo
|
||||
from soil import Environment, Simulation
|
||||
from soil.parameters import *
|
||||
|
@@ -6,7 +6,6 @@ pandas>=1
SALib>=1.3
Jinja2
Mesa>=1.2
pydantic>=1.9
sqlalchemy>=1.4
typing-extensions>=4.4
annotated-types>=0.4
@@ -1 +1 @@
1.0.0rc3
1.0.0rc10
@@ -19,7 +19,7 @@ from pathlib import Path
from .agents import *
from . import agents
from .simulation import *
from .environment import Environment, EventedEnvironment
from .environment import Environment
from .datacollection import SoilCollector
from . import serialization
from .utils import logger
@@ -117,13 +117,13 @@ def main(
    )
    parser.add_argument(
        "--max_time",
        default="-1",
        default="",
        help="Set maximum time for the simulation to run. ",
    )

    parser.add_argument(
        "--max_steps",
        default="-1",
        default="",
        help="Set maximum number of steps for the simulation to run.",
    )

@@ -249,14 +249,16 @@ def main(
            if args.only_convert:
                print(sim.to_yaml())
                continue
            max_time = float(args.max_time) if args.max_time != "-1" else None
            max_steps = float(args.max_steps) if args.max_steps != "-1" else None
            res.append(sim.run(max_time=max_time, max_steps=max_steps))
            d = {}
            if args.max_time:
                d["max_time"] = float(args.max_time)
            if args.max_steps:
                d["max_steps"] = int(args.max_steps)
            res.append(sim.run(**d))

        except Exception as ex:
            if args.pdb:
                from .debugging import post_mortem

                print(traceback.format_exc())
                post_mortem()
            else:
@@ -1,107 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import MutableMapping, Mapping, Set
|
||||
from abc import ABCMeta
|
||||
from copy import deepcopy, copy
|
||||
from functools import partial, wraps
|
||||
from itertools import islice, chain
|
||||
from collections.abc import MutableMapping
|
||||
from copy import deepcopy
|
||||
import inspect
|
||||
import types
|
||||
import textwrap
|
||||
import networkx as nx
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
from typing import Any
|
||||
from mesa import Agent as MesaAgent
|
||||
|
||||
from mesa import Agent as MesaAgent, Model
|
||||
from typing import Dict, List
|
||||
from .. import utils, time
|
||||
|
||||
from .. import serialization, network, utils, time, config
|
||||
from .meta import MetaAgent
|
||||
|
||||
|
||||
IGNORED_FIELDS = ("model", "logger")
|
||||
|
||||
|
||||
def decorate_generator_step(func, name):
|
||||
@wraps(func)
|
||||
def decorated(self):
|
||||
while True:
|
||||
if self._coroutine is None:
|
||||
self._coroutine = func(self)
|
||||
try:
|
||||
if self._last_except:
|
||||
val = self._coroutine.throw(self._last_except)
|
||||
else:
|
||||
val = self._coroutine.send(self._last_return)
|
||||
except StopIteration as ex:
|
||||
self._coroutine = None
|
||||
val = ex.value
|
||||
finally:
|
||||
self._last_return = None
|
||||
self._last_except = None
|
||||
return float(val) if val is not None else val
|
||||
return decorated
|
||||
|
||||
|
||||
def decorate_normal_func(func, name):
|
||||
@wraps(func)
|
||||
def decorated(self):
|
||||
val = func(self)
|
||||
return float(val) if val is not None else val
|
||||
return decorated
|
||||
|
||||
|
||||
class MetaAgent(ABCMeta):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
defaults = {}
|
||||
|
||||
# Re-use defaults from inherited classes
|
||||
for i in bases:
|
||||
if isinstance(i, MetaAgent):
|
||||
defaults.update(i._defaults)
|
||||
|
||||
new_nmspc = {
|
||||
"_defaults": defaults,
|
||||
}
|
||||
|
||||
for attr, func in namespace.items():
|
||||
if attr == "step":
|
||||
if inspect.isgeneratorfunction(func) or inspect.iscoroutinefunction(func):
|
||||
func = decorate_generator_step(func, attr)
|
||||
new_nmspc.update({
|
||||
"_last_return": None,
|
||||
"_last_except": None,
|
||||
"_coroutine": None,
|
||||
})
|
||||
elif inspect.isasyncgenfunction(func):
|
||||
raise ValueError("Illegal step function: {}. It probably mixes both async/await and yield".format(func))
|
||||
elif inspect.isfunction(func):
|
||||
func = decorate_normal_func(func, attr)
|
||||
else:
|
||||
raise ValueError("Illegal step function: {}".format(func))
|
||||
new_nmspc[attr] = func
|
||||
elif (
|
||||
isinstance(func, types.FunctionType)
|
||||
or isinstance(func, property)
|
||||
or isinstance(func, classmethod)
|
||||
or attr[0] == "_"
|
||||
):
|
||||
new_nmspc[attr] = func
|
||||
elif attr == "defaults":
|
||||
defaults.update(func)
|
||||
elif inspect.isfunction(func):
|
||||
new_nmspc[attr] = func
|
||||
else:
|
||||
defaults[attr] = copy(func)
|
||||
|
||||
|
||||
# Add attributes for their use in the decorated functions
|
||||
return super().__new__(mcls, name, bases, new_nmspc)
|
||||
|
||||
|
||||
class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
"""
|
||||
A special type of Mesa Agent that:
|
||||
@@ -114,8 +30,11 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
||||
"""
|
||||
|
||||
def __init__(self, unique_id, model, name=None, init=True, **kwargs):
|
||||
assert isinstance(unique_id, int)
|
||||
def __init__(self, unique_id=None, model=None, name=None, init=True, **kwargs):
|
||||
# Ideally, model should be the first argument, but Mesa's Agent class has unique_id first
|
||||
assert not (model is None), "Must provide a model"
|
||||
if unique_id is None:
|
||||
unique_id = model.next_id()
|
||||
super().__init__(unique_id=unique_id, model=model)
|
||||
|
||||
self.name = (
|
||||
@@ -154,11 +73,11 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
return hash(self.unique_id)
|
||||
|
||||
def prob(self, probability):
|
||||
return prob(probability, self.model.random)
|
||||
return utils.prob(probability, self.model.random)
|
||||
|
||||
@classmethod
|
||||
def w(cls, **kwargs):
|
||||
return custom(cls, **kwargs)
|
||||
return utils.custom(cls, **kwargs)
|
||||
|
||||
# TODO: refactor to clean up mesa compatibility
|
||||
@property
|
||||
@@ -168,20 +87,12 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
print(msg, file=sys.stderr)
|
||||
return self.unique_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model, attrs, warn_extra=True):
|
||||
ignored = {}
|
||||
args = {}
|
||||
for k, v in attrs.items():
|
||||
if k in inspect.signature(cls).parameters:
|
||||
args[k] = v
|
||||
else:
|
||||
ignored[k] = v
|
||||
if ignored and warn_extra:
|
||||
utils.logger.info(
|
||||
f"Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }"
|
||||
)
|
||||
return cls(model=model, **args)
|
||||
@property
|
||||
def env(self):
|
||||
msg = "This attribute is deprecated. Use `model` instead"
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
print(msg, file=sys.stderr)
|
||||
return self.model
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
@@ -237,10 +148,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
else:
|
||||
self.debug(f"agent dying")
|
||||
self.alive = False
|
||||
try:
|
||||
self.model.schedule.remove(self)
|
||||
except KeyError:
|
||||
pass
|
||||
return time.Delay(time.INFINITY)
|
||||
|
||||
def step(self):
|
||||
@@ -302,399 +209,23 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
return time.Delay(delay)
|
||||
|
||||
|
||||
def prob(prob, random):
|
||||
"""
|
||||
A true/False uniform distribution with a given probability.
|
||||
To be used like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
if prob(0.3):
|
||||
do_something()
|
||||
|
||||
"""
|
||||
r = random.random()
|
||||
return r < prob
|
||||
|
||||
|
||||
def calculate_distribution(network_agents=None, agent_class=None):
|
||||
"""
|
||||
Calculate the threshold values (thresholds for a uniform distribution)
|
||||
of an agent distribution given the weights of each agent type.
|
||||
|
||||
The input has this form: ::
|
||||
|
||||
[
|
||||
{'agent_class': 'agent_class_1',
|
||||
'weight': 0.2,
|
||||
'state': {
|
||||
'id': 0
|
||||
}
|
||||
},
|
||||
{'agent_class': 'agent_class_2',
|
||||
'weight': 0.8,
|
||||
'state': {
|
||||
'id': 1
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
In this example, 20% of the nodes will be marked as type
|
||||
'agent_class_1'.
|
||||
"""
|
||||
if network_agents:
|
||||
network_agents = [
|
||||
deepcopy(agent) for agent in network_agents if not hasattr(agent, "id")
|
||||
]
|
||||
elif agent_class:
|
||||
network_agents = [{"agent_class": agent_class}]
|
||||
else:
|
||||
raise ValueError("Specify a distribution or a default agent type")
|
||||
|
||||
# Fix missing weights and incompatible types
|
||||
for x in network_agents:
|
||||
x["weight"] = float(x.get("weight", 1))
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x["weight"] for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
if "ids" in v:
|
||||
continue
|
||||
upper = acc + (v["weight"] / total)
|
||||
v["threshold"] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
|
||||
|
||||
def _serialize_type(agent_class, known_modules=[], **kwargs):
|
||||
if isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known_modules += ["soil.agents"]
|
||||
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[
|
||||
1
|
||||
] # Get the name of the class
|
||||
|
||||
|
||||
def _deserialize_type(agent_class, known_modules=[]):
|
||||
if not isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
||||
agent_class = serialization.deserializer(agent_class, known_modules=known)
|
||||
return agent_class
|
||||
|
||||
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents."""
|
||||
|
||||
__slots__ = ("_agents",)
|
||||
|
||||
def __init__(self, agents):
|
||||
self._agents = agents
|
||||
|
||||
def __getstate__(self):
|
||||
return {"_agents": self._agents}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._agents = state["_agents"]
|
||||
|
||||
# Mapping methods
|
||||
def __len__(self):
|
||||
return len(self._agents)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._agents.values()
|
||||
|
||||
def __getitem__(self, agent_id):
|
||||
if isinstance(agent_id, slice):
|
||||
raise ValueError(f"Slicing is not supported")
|
||||
if agent_id in self._agents:
|
||||
return self._agents[agent_id]
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
def filter(self, *args, **kwargs):
|
||||
yield from filter_agents(self._agents, *args, **kwargs)
|
||||
|
||||
def one(self, *args, **kwargs):
|
||||
return next(filter_agents(self._agents, *args, **kwargs))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return list(self.filter(*args, **kwargs))
|
||||
|
||||
def __contains__(self, agent_id):
|
||||
return agent_id in self._agents
|
||||
|
||||
def __str__(self):
|
||||
return str(list(unique_id for unique_id in self.keys()))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def filter_agents(
|
||||
agents: dict,
|
||||
*id_args,
|
||||
unique_id=None,
|
||||
state_id=None,
|
||||
agent_class=None,
|
||||
ignore=None,
|
||||
state=None,
|
||||
limit=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||
"""
|
||||
assert isinstance(agents, dict)
|
||||
|
||||
ids = []
|
||||
|
||||
if unique_id is not None:
|
||||
if isinstance(unique_id, list):
|
||||
ids += unique_id
|
||||
else:
|
||||
ids.append(unique_id)
|
||||
|
||||
if id_args:
|
||||
ids += id_args
|
||||
|
||||
if ids:
|
||||
f = (agents[aid] for aid in ids if aid in agents)
|
||||
else:
|
||||
f = agents.values()
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
|
||||
if agent_class is not None:
|
||||
agent_class = _deserialize_type(agent_class)
|
||||
try:
|
||||
agent_class = tuple(agent_class)
|
||||
except TypeError:
|
||||
agent_class = tuple([agent_class])
|
||||
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
if state_id is not None:
|
||||
f = filter(lambda agent: agent.get("state_id", None) in state_id, f)
|
||||
|
||||
if agent_class is not None:
|
||||
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
||||
|
||||
state = state or dict()
|
||||
state.update(kwargs)
|
||||
|
||||
for k, vs in state.items():
|
||||
if not isinstance(vs, list):
|
||||
vs = [vs]
|
||||
f = filter(lambda agent: any(getattr(agent, k, None) == v for v in vs), f)
|
||||
|
||||
if limit is not None:
|
||||
f = islice(f, limit)
|
||||
|
||||
yield from f
|
||||
|
||||
|
||||
def from_config(
|
||||
cfg: config.AgentConfig, random, topology: nx.Graph = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
||||
with the parameters that the environment will use to construct each agent.
|
||||
|
||||
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
||||
time of calling this function, such as `unique_id`.
|
||||
"""
|
||||
default = cfg or config.AgentConfig()
|
||||
if not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
|
||||
agents = []
|
||||
|
||||
assigned_total = 0
|
||||
assigned_network = 0
|
||||
|
||||
if cfg.fixed is not None:
|
||||
agents, assigned_total, assigned_network = _from_fixed(
|
||||
cfg.fixed, topology=cfg.topology, default=cfg
|
||||
)
|
||||
|
||||
n = cfg.n
|
||||
|
||||
if cfg.distribution:
|
||||
topo_size = len(topology) if topology else 0
|
||||
|
||||
networked = []
|
||||
total = []
|
||||
|
||||
for d in cfg.distribution:
|
||||
if d.strategy == config.Strategy.topology:
|
||||
topo = d.topology if ("topology" in d.__fields_set__) else cfg.topology
|
||||
if not topo:
|
||||
raise ValueError(
|
||||
'The "topology" strategy only works if the topology parameter is set to True'
|
||||
)
|
||||
if not topo_size:
|
||||
raise ValueError(
|
||||
f"Topology does not have enough free nodes to assign one to the agent"
|
||||
)
|
||||
|
||||
networked.append(d)
|
||||
|
||||
if d.strategy == config.Strategy.total:
|
||||
if not cfg.n:
|
||||
raise ValueError(
|
||||
'Cannot use the "total" strategy without providing the total number of agents'
|
||||
)
|
||||
total.append(d)
|
||||
|
||||
if networked:
|
||||
new_agents = _from_distro(
|
||||
networked,
|
||||
n=topo_size - assigned_network,
|
||||
topology=topo,
|
||||
default=cfg,
|
||||
random=random,
|
||||
)
|
||||
assigned_total += len(new_agents)
|
||||
assigned_network += len(new_agents)
|
||||
agents += new_agents
|
||||
|
||||
if total:
|
||||
remaining = n - assigned_total
|
||||
agents += _from_distro(total, n=remaining, default=cfg, random=random)
|
||||
|
||||
if assigned_network < topo_size:
|
||||
utils.logger.warn(
|
||||
f"The total number of agents does not match the total number of nodes in "
|
||||
"every topology. This may be due to a definition error: assigned: "
|
||||
f"{ assigned } total size: { topo_size }"
|
||||
)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def _from_fixed(
|
||||
lst: List[config.FixedAgentConfig],
|
||||
topology: bool,
|
||||
default: config.SingleAgentConfig,
|
||||
) -> List[Dict[str, Any]]:
|
||||
agents = []
|
||||
|
||||
counts_total = 0
|
||||
counts_network = 0
|
||||
|
||||
for fixed in lst:
|
||||
agent = {}
|
||||
if default:
|
||||
agent = default.state.copy()
|
||||
agent.update(fixed.state)
|
||||
cls = serialization.deserialize(
|
||||
fixed.agent_class or (default and default.agent_class)
|
||||
)
|
||||
agent["agent_class"] = cls
|
||||
topo = (
|
||||
fixed.topology
|
||||
if ("topology" in fixed.__fields_set__)
|
||||
else topology or default.topology
|
||||
)
|
||||
|
||||
if topo:
|
||||
agent["topology"] = True
|
||||
counts_network += 1
|
||||
if not fixed.hidden:
|
||||
counts_total += 1
|
||||
agents.append(agent)
|
||||
|
||||
return agents, counts_total, counts_network
|
||||
|
||||
|
||||
def _from_distro(
|
||||
distro: List[config.AgentDistro],
|
||||
n: int,
|
||||
default: config.SingleAgentConfig,
|
||||
random,
|
||||
topology: str = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
||||
agents = []
|
||||
|
||||
if n is None:
|
||||
if any(lambda dist: dist.n is None, distro):
|
||||
raise ValueError(
|
||||
"You must provide a total number of agents, or the number of each type"
|
||||
)
|
||||
n = sum(dist.n for dist in distro)
|
||||
|
||||
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
|
||||
minw = min(weights)
|
||||
norm = list(weight / minw for weight in weights)
|
||||
total = sum(norm)
|
||||
chunk = n // total
|
||||
|
||||
# random.choices would be enough to get a weighted distribution. But it can vary a lot for smaller k
|
||||
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
|
||||
|
||||
# Calculate how many times each has to appear
|
||||
indices = list(
|
||||
chain.from_iterable([idx] * int(n * chunk) for (idx, n) in enumerate(norm))
|
||||
)
|
||||
|
||||
# Complete with random agents following the original weight distribution
|
||||
if len(indices) < n:
|
||||
indices += random.choices(
|
||||
list(range(len(distro))),
|
||||
weights=[d.weight for d in distro],
|
||||
k=n - len(indices),
|
||||
)
|
||||
|
||||
# Deserialize classes for efficiency
|
||||
classes = list(
|
||||
serialization.deserialize(i.agent_class or default.agent_class) for i in distro
|
||||
)
|
||||
|
||||
# Add them in random order
|
||||
random.shuffle(indices)
|
||||
|
||||
for idx in indices:
|
||||
d = distro[idx]
|
||||
agent = d.state.copy()
|
||||
cls = classes[idx]
|
||||
agent["agent_class"] = cls
|
||||
if default:
|
||||
agent.update(default.state)
|
||||
topology = (
|
||||
d.topology
|
||||
if ("topology" in d.__fields_set__)
|
||||
else topology or default.topology
|
||||
)
|
||||
if topology:
|
||||
agent["topology"] = topology
|
||||
agents.append(agent)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
from .network_agents import *
|
||||
from .fsm import *
|
||||
from .evented import *
|
||||
from typing import Optional
|
||||
from .view import *
|
||||
|
||||
|
||||
class Agent(NetworkAgent, FSM, EventedAgent):
|
||||
"""Default agent class, has both network and event capabilities"""
|
||||
class Noop(EventedAgent, BaseAgent):
|
||||
def step(self):
|
||||
return
|
||||
|
||||
|
||||
from ..environment import NetworkEnvironment
|
||||
class Agent(FSM, EventedAgent, NetworkAgent):
|
||||
"""Default agent class, has network, FSM and event capabilities"""
|
||||
|
||||
|
||||
# Additional types of agents
|
||||
from .BassModel import *
|
||||
from .IndependentCascadeModel import *
|
||||
from .SISaModel import *
|
||||
from .CounterModel import *
|
||||
|
||||
|
||||
def custom(cls, **kwargs):
|
||||
"""Create a new class from a template class and keyword arguments"""
|
||||
return type(cls.__name__, (cls,), kwargs)
|
||||
|
@@ -1,58 +1,34 @@
|
||||
from . import BaseAgent
|
||||
from ..events import Message, Tell, Ask, TimedOut
|
||||
from .. import environment, events
|
||||
from functools import partial
|
||||
from collections import deque
|
||||
from types import coroutine
|
||||
|
||||
# from soilent import Scheduler
|
||||
|
||||
|
||||
class EventedAgent(BaseAgent):
|
||||
# scheduler_class = Scheduler
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._inbox = deque()
|
||||
self._processed = 0
|
||||
assert isinstance(self.model, environment.EventedEnvironment), "EventedAgent requires an EventedEnvironment"
|
||||
self.model.register(self)
|
||||
|
||||
def on_receive(self, *args, **kwargs):
|
||||
pass
|
||||
def received(self, **kwargs):
|
||||
return self.model.received(agent=self, **kwargs)
|
||||
|
||||
@coroutine
|
||||
def received(self, expiration=None, timeout=60, delay=1, process=True):
|
||||
if not expiration:
|
||||
expiration = self.now + timeout
|
||||
while self.now < expiration:
|
||||
if self._inbox:
|
||||
msgs = self._inbox
|
||||
if process:
|
||||
self.process_messages()
|
||||
return msgs
|
||||
yield self.delay(delay)
|
||||
raise TimedOut("No message received")
|
||||
def tell(self, msg, **kwargs):
|
||||
return self.model.tell(msg, recipient=self, **kwargs)
|
||||
|
||||
def tell(self, msg, sender=None):
|
||||
self._inbox.append(Tell(timestamp=self.now, payload=msg, sender=sender))
|
||||
def broadcast(self, msg, **kwargs):
|
||||
return self.model.broadcast(msg, sender=self, **kwargs)
|
||||
|
||||
@coroutine
|
||||
def ask(self, msg, expiration=None, timeout=None, delay=1):
|
||||
ask = Ask(timestamp=self.now, payload=msg, sender=self)
|
||||
self._inbox.append(ask)
|
||||
expiration = float("inf") if timeout is None else self.now + timeout
|
||||
while self.now < expiration:
|
||||
if ask.reply:
|
||||
return ask.reply
|
||||
yield self.delay(delay)
|
||||
raise TimedOut("No reply received")
|
||||
def ask(self, msg, **kwargs):
|
||||
return self.model.ask(msg, recipient=self, **kwargs)
|
||||
|
||||
def process_messages(self):
|
||||
valid = list()
|
||||
for msg in self._inbox:
|
||||
self._processed += 1
|
||||
if msg.expired(self.now):
|
||||
continue
|
||||
valid.append(msg)
|
||||
reply = self.on_receive(msg.payload, sender=msg.sender)
|
||||
if isinstance(msg, Ask):
|
||||
msg.reply = reply
|
||||
self._inbox.clear()
|
||||
return valid
|
||||
return self.model.process_messages(self.model.inbox_for(self))
|
||||
|
||||
|
||||
Evented = EventedAgent
|
||||
|
@@ -6,39 +6,38 @@ import inspect
|
||||
|
||||
|
||||
class State:
|
||||
__slots__ = ("awaitable", "f", "generator", "name", "default")
|
||||
__slots__ = ("awaitable", "f", "attribute", "generator", "name", "default")
|
||||
|
||||
def __init__(self, f, name, default, generator, awaitable):
|
||||
self.f = f
|
||||
self.name = name
|
||||
self.attribute = "_{}".format(name)
|
||||
self.generator = generator
|
||||
self.awaitable = awaitable
|
||||
self.default = default
|
||||
|
||||
@coroutine
|
||||
def step(self, obj):
|
||||
if self.generator or self.awaitable:
|
||||
f = self.f
|
||||
next_state = yield from f(obj)
|
||||
return next_state
|
||||
|
||||
else:
|
||||
return self.f(obj)
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("States should not be called directly")
|
||||
|
||||
class UnboundState(State):
|
||||
def __get__(self, obj, owner=None):
|
||||
if obj is None:
|
||||
return self
|
||||
try:
|
||||
return getattr(obj, self.attribute)
|
||||
except AttributeError:
|
||||
b = self.bind(obj)
|
||||
setattr(obj, self.attribute, b)
|
||||
return b
|
||||
|
||||
def bind(self, obj):
|
||||
bs = BoundState(self.f, self.name, self.default, self.generator, self.awaitable, obj=obj)
|
||||
setattr(obj, self.name, bs)
|
||||
return bs
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("States should not be called directly")
|
||||
|
||||
|
||||
class BoundState(State):
|
||||
__slots__ = ("obj", )
|
||||
@@ -47,6 +46,17 @@ class BoundState(State):
|
||||
super().__init__(*args)
|
||||
self.obj = obj
|
||||
|
||||
@coroutine
|
||||
def __call__(self):
|
||||
if self.generator or self.awaitable:
|
||||
f = self.f
|
||||
next_state = yield from f(self.obj)
|
||||
return next_state
|
||||
|
||||
else:
|
||||
return self.f(self.obj)
|
||||
|
||||
|
||||
def delay(self, delta=0):
|
||||
return self, self.obj.delay(delta)
|
||||
|
||||
@@ -63,7 +73,7 @@ def state(name=None, default=False):
|
||||
name = name or func.__name__
|
||||
generator = inspect.isgeneratorfunction(func)
|
||||
awaitable = inspect.iscoroutinefunction(func) or inspect.isasyncgen(func)
|
||||
return UnboundState(func, name, default, generator, awaitable)
|
||||
return State(func, name, default, generator, awaitable)
|
||||
|
||||
if callable(name):
|
||||
return decorator(name)
|
||||
@@ -113,15 +123,24 @@ class MetaFSM(MetaAgent):
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, init=True, state_id=None, **kwargs):
|
||||
super().__init__(**kwargs, init=False)
|
||||
bound_states = {}
|
||||
for (k, v) in list(self._states.items()):
|
||||
if isinstance(v, State):
|
||||
v = v.bind(self)
|
||||
bound_states[k] = v
|
||||
setattr(self, k, v)
|
||||
|
||||
self._states = bound_states
|
||||
|
||||
if state_id is not None:
|
||||
self._set_state(state_id)
|
||||
else:
|
||||
self._set_state(self._state)
|
||||
# If more than "dead" state is defined, but no default state
|
||||
if len(self._states) > 1 and not self._state:
|
||||
raise ValueError(
|
||||
f"No default state specified for {type(self)}({self.unique_id})"
|
||||
)
|
||||
for (k, v) in self._states.items():
|
||||
setattr(self, k, v.bind(self))
|
||||
|
||||
if init:
|
||||
self.init()
|
||||
@@ -139,12 +158,20 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
raise ValueError("Cannot change state after init")
|
||||
self._set_state(value)
|
||||
|
||||
@coroutine
|
||||
def step(self):
|
||||
self._check_alive()
|
||||
next_state = yield from self._state.step(self)
|
||||
if self._state is None:
|
||||
if len(self._states) == 1:
|
||||
raise Exception("Agent class has no valid states: {}. Make sure to define states or define a custom step function".format(self.__class__.__name__))
|
||||
else:
|
||||
raise Exception("Invalid state (None) for agent {}".format(self))
|
||||
|
||||
next_state = yield from self._state()
|
||||
|
||||
try:
|
||||
next_state, when = next_state
|
||||
self._set_state(next_state)
|
||||
return when
|
||||
except (TypeError, ValueError) as ex:
|
||||
try:
|
||||
self._set_state(next_state)
|
||||
@@ -152,9 +179,6 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
except ValueError:
|
||||
return next_state
|
||||
|
||||
self._set_state(next_state)
|
||||
return when
|
||||
|
||||
def _set_state(self, state):
|
||||
if state is None:
|
||||
return
|
||||
@@ -162,7 +186,9 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
if state not in self._states:
|
||||
raise ValueError("{} is not a valid state".format(state))
|
||||
state = self._states[state]
|
||||
if not isinstance(state, State):
|
||||
if isinstance(state, State):
|
||||
state = state.bind(self)
|
||||
elif not isinstance(state, BoundState):
|
||||
raise ValueError("{} is not a valid state".format(state))
|
||||
self._state = state
|
||||
|
||||
|
soil/agents/meta.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
from abc import ABCMeta
|
||||
from copy import copy
|
||||
from functools import wraps
|
||||
from .. import time
|
||||
from ..decorators import syncify, while_alive
|
||||
|
||||
import types
|
||||
import inspect
|
||||
|
||||
|
||||
class MetaAnnotations(ABCMeta):
|
||||
"""This metaclass sets default values for agents based on class attributes"""
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
defaults = {}
|
||||
|
||||
# Re-use defaults from inherited classes
|
||||
for i in bases:
|
||||
if isinstance(i, MetaAgent):
|
||||
defaults.update(i._defaults)
|
||||
|
||||
new_nmspc = {
|
||||
"_defaults": defaults,
|
||||
}
|
||||
|
||||
for attr, func in namespace.items():
|
||||
if (
|
||||
isinstance(func, types.FunctionType)
|
||||
or isinstance(func, property)
|
||||
or isinstance(func, classmethod)
|
||||
or attr[0] == "_"
|
||||
):
|
||||
new_nmspc[attr] = func
|
||||
elif attr == "defaults":
|
||||
defaults.update(func)
|
||||
elif inspect.isfunction(func):
|
||||
new_nmspc[attr] = func
|
||||
else:
|
||||
defaults[attr] = copy(func)
|
||||
|
||||
return super().__new__(mcls, name, bases, new_nmspc)
|
||||
|
||||
|
||||
class AutoAgent(ABCMeta):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
if "step" in namespace:
|
||||
func = namespace["step"]
|
||||
namespace["_orig_step"] = func
|
||||
if inspect.isfunction(func):
|
||||
if inspect.isgeneratorfunction(func) or inspect.iscoroutinefunction(func):
|
||||
func = syncify(func, method=True)
|
||||
namespace["step"] = while_alive(func)
|
||||
elif inspect.isasyncgenfunction(func):
|
||||
raise ValueError("Illegal step function: {}. It probably mixes both async/await and yield".format(func))
|
||||
else:
|
||||
raise ValueError("Illegal step function: {}".format(func))
|
||||
|
||||
# Add attributes for their use in the decorated functions
|
||||
return super().__new__(mcls, name, bases, namespace)
|
||||
|
||||
|
||||
class MetaAgent(AutoAgent, MetaAnnotations):
|
||||
"""This metaclass sets default values for agents based on class attributes"""
|
||||
pass
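# --- Illustrative sketch (editor's example, not part of the patch) ---
# What the combined metaclass buys: class attributes become per-instance defaults,
# and a generator step() is wrapped with syncify/while_alive so each activation
# advances it by one yield. Walker and speed are illustrative names.
from soil import agents

class Walker(agents.Agent):
    speed = 1                      # copied per instance by MetaAnnotations

    def step(self):
        while self.alive:
            yield self.delay(self.speed)

# given an existing Environment instance `env`:
w = Walker(model=env)              # speed == 1
fast = Walker(model=env, speed=3)  # keyword argument overrides the class default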
|
||||
|
@@ -1,9 +1,11 @@
|
||||
from . import BaseAgent
|
||||
from .. import environment
|
||||
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
def __init__(self, *args, topology=None, init=True, node_id=None, **kwargs):
|
||||
super().__init__(*args, init=False, **kwargs)
|
||||
assert isinstance(self.model, environment.NetworkEnvironment), "NetworkAgent requires a NetworkEnvironment"
|
||||
|
||||
self.G = topology or self.model.G
|
||||
assert self.G is not None, "Network agents should have a network"
|
||||
@@ -16,7 +18,7 @@ class NetworkAgent(BaseAgent):
|
||||
else:
|
||||
node_id = len(self.G)
|
||||
self.info(f"All nodes ({len(self.G)}) have an agent assigned, adding a new node to the graph for agent {self.unique_id}")
|
||||
self.G.add_node(node_id)
|
||||
self.G.add_node(node_id, find_unassigned=True)
|
||||
assert node_id is not None
|
||||
self.G.nodes[node_id]["agent"] = self
|
||||
self.node_id = node_id
|
||||
|
soil/agents/view.py (new file, 139 lines)
@@ -0,0 +1,139 @@
|
||||
from collections.abc import Mapping, Set
|
||||
from itertools import islice
|
||||
from mesa import Agent
|
||||
|
||||
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents."""
|
||||
|
||||
__slots__ = ("_agents", "agents_by_type")
|
||||
|
||||
def __init__(self, agents, agents_by_type):
|
||||
self._agents = agents
|
||||
self.agents_by_type = agents_by_type
|
||||
|
||||
def __getstate__(self):
|
||||
return {"_agents": self._agents}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._agents = state["_agents"]
|
||||
|
||||
# Mapping methods
|
||||
def __len__(self):
|
||||
return len(self._agents)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._agents.values()
|
||||
|
||||
def __getitem__(self, agent_id):
|
||||
if isinstance(agent_id, slice):
|
||||
raise ValueError(f"Slicing is not supported")
|
||||
if agent_id in self._agents:
|
||||
return self._agents[agent_id]
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
def filter(self, agent_class=None, include_subclasses=True, **kwargs):
|
||||
if agent_class and self.agents_by_type:
|
||||
if not include_subclasses:
|
||||
return filter_agents(self.agents_by_type[agent_class],
|
||||
**kwargs)
|
||||
else:
|
||||
d = {}
|
||||
for k, v in self.agents_by_type.items():
|
||||
if (k == agent_class) or issubclass(k, agent_class):
|
||||
d.update(v)
|
||||
return filter_agents(d, **kwargs)
|
||||
return filter_agents(self._agents, agent_class=agent_class, **kwargs)
|
||||
|
||||
|
||||
def one(self, *args, **kwargs):
|
||||
try:
|
||||
return next(self.filter(*args, **kwargs))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return list(self.filter(*args, **kwargs))
|
||||
|
||||
def __contains__(self, agent_id):
|
||||
if isinstance(agent_id, Agent):
|
||||
agent_id = agent_id.unique_id
|
||||
return agent_id in self._agents
|
||||
|
||||
def __str__(self):
|
||||
return str(list(unique_id for unique_id in self.keys()))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def filter_agents(
|
||||
agents: dict,
|
||||
*id_args,
|
||||
unique_id=None,
|
||||
state_id=None,
|
||||
agent_class=None,
|
||||
ignore=None,
|
||||
state=None,
|
||||
limit=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Filter agents given as a dict by the criteria given as arguments (e.g., a certain agent class or state id).
|
||||
"""
|
||||
assert isinstance(agents, dict)
|
||||
|
||||
ids = []
|
||||
|
||||
if unique_id is not None:
|
||||
if isinstance(unique_id, list):
|
||||
ids += unique_id
|
||||
else:
|
||||
ids.append(unique_id)
|
||||
|
||||
if id_args:
|
||||
ids += id_args
|
||||
|
||||
if ids:
|
||||
f = list(agents[aid] for aid in ids if aid in agents)
|
||||
else:
|
||||
f = agents.values()
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
if state_id is not None:
|
||||
f = filter(lambda agent: agent.get("state_id", None) in state_id, f)
|
||||
|
||||
if agent_class is not None:
|
||||
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
||||
|
||||
state = state or dict()
|
||||
state.update(kwargs)
|
||||
|
||||
for k, vs in state.items():
|
||||
if not isinstance(vs, list):
|
||||
vs = [vs]
|
||||
f = filter(lambda agent: any(getattr(agent, k, None) == v for v in vs), f)
|
||||
|
||||
if limit is not None:
|
||||
f = islice(f, limit)
|
||||
|
||||
return AgentResult(f)
|
||||
|
||||
class AgentResult:
|
||||
def __init__(self, iterator):
|
||||
self.iterator = iterator
|
||||
|
||||
def limit(self, limit):
|
||||
self.iterator = islice(self.iterator, limit)
|
||||
return self
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.iterator)
|
||||
|
||||
def __next__(self):
|
||||
return next(self.iterator)
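# --- Illustrative sketch (editor's example, not part of the patch) ---
# Typical queries against the new lazy view; MyAgent is an illustrative class name.
env.agents                                   # AgentView over every scheduled agent
env.agents(agent_class=MyAgent)              # list; subclasses included by default
env.agents.one(unique_id=0)                  # a single agent, or None
env.agents.filter(state_id="one", limit=5)   # lazy AgentResult capped at 5 items
env.count_agents(alive=True)                 # arbitrary attribute filters via **kwargs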
|
@@ -19,7 +19,8 @@ def plot(env, agent_df=None, model_df=None, steps=False, ignore=["agent_count",
|
||||
try:
|
||||
agent_df = env.agent_df()
|
||||
except UserWarning:
|
||||
print("No agent dataframe provided and no agent reporters found. Skipping agent plot.", file=sys.stderr)
|
||||
print("No agent dataframe provided and no agent reporters found. "
|
||||
"Skipping agent plot.", file=sys.stderr)
|
||||
return
|
||||
if not agent_df.empty:
|
||||
agent_df.unstack().apply(lambda x: x.value_counts(),
|
||||
@@ -43,15 +44,10 @@ def read_sql(fpath=None, name=None, include_agents=False):
|
||||
with engine.connect() as conn:
|
||||
env = pd.read_sql_table("env", con=conn,
|
||||
index_col="step").reset_index().set_index([
|
||||
"simulation_id", "params_id",
|
||||
"iteration_id", "step"
|
||||
"params_id", "iteration_id", "step"
|
||||
])
|
||||
agents = pd.read_sql_table("agents", con=conn, index_col=["simulation_id", "params_id", "iteration_id", "step", "agent_id"])
|
||||
agents = pd.read_sql_table("agents", con=conn, index_col=["params_id", "iteration_id", "step", "agent_id"])
|
||||
config = pd.read_sql_table("configuration", con=conn, index_col="simulation_id")
|
||||
parameters = pd.read_sql_table("parameters", con=conn, index_col=["iteration_id", "params_id", "simulation_id"])
|
||||
try:
|
||||
parameters = parameters.pivot(columns="key", values="value")
|
||||
except Exception as e:
|
||||
print(f"warning: coult not pivot parameters: {e}")
|
||||
parameters = pd.read_sql_table("parameters", con=conn, index_col=["simulation_id", "params_id", "iteration_id"])
|
||||
|
||||
return Results(config, parameters, env, agents)
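# --- Illustrative sketch (editor's example, not part of the patch) ---
# Reading a dump back; the sqlite path is hypothetical.
results = read_sql(fpath="soil_output/mysim/mysim.sqlite", include_agents=True)
# env and agent tables now come back indexed by (params_id, iteration_id, step[, agent_id])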
|
||||
|
@@ -1,6 +1,42 @@
|
||||
from functools import wraps
|
||||
from .time import INFINITY
|
||||
|
||||
def report(f: property):
|
||||
if isinstance(f, property):
|
||||
setattr(f.fget, "add_to_report", True)
|
||||
else:
|
||||
setattr(f, "add_to_report", True)
|
||||
return f
|
||||
|
||||
|
||||
def syncify(func, method=True):
|
||||
_coroutine = None
|
||||
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
if not method:
|
||||
nonlocal _coroutine
|
||||
else:
|
||||
_coroutine = getattr(args[0], "_coroutine", None)
|
||||
_coroutine = _coroutine or func(*args, **kwargs)
|
||||
try:
|
||||
val = _coroutine.send(None)
|
||||
except StopIteration as ex:
|
||||
_coroutine = None
|
||||
val = ex.value
|
||||
finally:
|
||||
if method:
|
||||
args[0]._coroutine = _coroutine
|
||||
return val
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def while_alive(func):
|
||||
@wraps(func)
|
||||
def wrapped(self, *args, **kwargs):
|
||||
if self.alive:
|
||||
return func(self, *args, **kwargs)
|
||||
return INFINITY
|
||||
|
||||
return wrapped
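# --- Illustrative sketch (editor's example, not part of the patch) ---
# What syncify does in isolation (method=False form): each call resumes the same
# generator by one send() until it is exhausted, then returns its final value.
def count():
    for i in range(3):
        yield i

step = syncify(count, method=False)
step()   # 0
step()   # 1
step()   # 2
step()   # None (StopIteration reached, internal coroutine reset)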
|
@@ -1,20 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import math
|
||||
import sys
|
||||
import logging
|
||||
import inspect
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Union, List, Type
|
||||
from collections import namedtuple
|
||||
from time import time as current_time
|
||||
from copy import deepcopy
|
||||
|
||||
from types import coroutine
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from mesa import Model, Agent
|
||||
|
||||
from mesa import Model
|
||||
from time import time as current_time
|
||||
|
||||
from . import agents as agentmod, datacollection, utils, time, network, events
|
||||
|
||||
@@ -36,6 +34,8 @@ class BaseEnvironment(Model):
|
||||
|
||||
collector_class = datacollection.SoilCollector
|
||||
schedule_class = time.TimedActivation
|
||||
start_time = 0
|
||||
time_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
def __new__(cls,
|
||||
*args: Any,
|
||||
@@ -47,6 +47,7 @@ class BaseEnvironment(Model):
|
||||
tables: Optional[Any] = None,
|
||||
**kwargs: Any) -> Any:
|
||||
"""Create a new model with a default seed value"""
|
||||
seed = seed or str(current_time())
|
||||
self = super().__new__(cls, *args, seed=seed, **kwargs)
|
||||
self.dir_path = dir_path or os.getcwd()
|
||||
collector_class = collector_class or cls.collector_class
|
||||
@@ -77,6 +78,8 @@ class BaseEnvironment(Model):
|
||||
collector_class: type = datacollection.SoilCollector,
|
||||
agent_reporters: Optional[Any] = None,
|
||||
model_reporters: Optional[Any] = None,
|
||||
start_time=None,
|
||||
time_format=None,
|
||||
tables: Optional[Any] = None,
|
||||
init: bool = True,
|
||||
**env_params,
|
||||
@@ -93,12 +96,21 @@ class BaseEnvironment(Model):
|
||||
self.logger = logger
|
||||
else:
|
||||
self.logger = utils.logger.getChild(self.id)
|
||||
if start_time is not None:
|
||||
self.start_time = start_time
|
||||
if time_format is not None:
|
||||
self.time_format = time_format
|
||||
|
||||
if isinstance(self.start_time, str):
|
||||
self.start_time = datetime.strptime(self.start_time, self.time_format)
|
||||
if isinstance(self.start_time, datetime):
|
||||
self.start_time = self.start_time.timestamp()
|
||||
|
||||
self.schedule = schedule
|
||||
if schedule is None:
|
||||
if schedule_class is None:
|
||||
schedule_class = self.schedule_class
|
||||
self.schedule = schedule_class(self)
|
||||
self.schedule = schedule_class(self, time=self.start_time)
|
||||
|
||||
for (k, v) in env_params.items():
|
||||
self[k] = v
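# --- Illustrative sketch (editor's example, not part of the patch) ---
# The new start_time/time_format handling lets the model clock start at a date;
# the string below uses the default time_format declared above.
env = Environment(start_time="2024-01-01 00:00:00")
env.now   # POSIX timestamp for that date; the schedule starts from it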
|
||||
@@ -114,22 +126,23 @@ class BaseEnvironment(Model):
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
return agentmod.AgentView(self.schedule._agents)
|
||||
return agentmod.AgentView(self.schedule._agents, getattr(self.schedule, "agents_by_type", None))
|
||||
|
||||
def agent(self, *args, **kwargs):
|
||||
return agentmod.AgentView(self.schedule._agents).one(*args, **kwargs)
|
||||
return agentmod.AgentView(self.schedule._agents, self.schedule.agents_by_type).one(*args, **kwargs)
|
||||
|
||||
def count_agents(self, *args, **kwargs):
|
||||
return sum(1 for i in self.agents(*args, **kwargs))
|
||||
|
||||
def agent_df(self, steps=False):
|
||||
df = self.datacollector.get_agent_vars_dataframe()
|
||||
df.index.rename(["step", "agent_id"], inplace=True)
|
||||
if steps:
|
||||
df.index.rename(["step", "agent_id"], inplace=True)
|
||||
return df
|
||||
df = df.reset_index()
|
||||
model_df = self.datacollector.get_model_vars_dataframe()
|
||||
df.index = df.index.set_levels(model_df.time, level=0).rename(["time", "agent_id"])
|
||||
return df
|
||||
df['time'] = df.apply(lambda row: model_df.loc[row.step].time, axis=1)
|
||||
return df.groupby(["time", "agent_id"]).last()
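# --- Illustrative sketch (editor's example, not part of the patch) ---
env.agent_df(steps=True)   # indexed by (step, agent_id)
env.agent_df()             # re-indexed by (time, agent_id) using the model clock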
|
||||
|
||||
def model_df(self, steps=False):
|
||||
df = self.datacollector.get_model_vars_dataframe()
|
||||
@@ -140,7 +153,7 @@ class BaseEnvironment(Model):
|
||||
|
||||
@property
|
||||
def now(self):
|
||||
if self.schedule:
|
||||
if self.schedule is not None:
|
||||
return self.schedule.time
|
||||
raise Exception(
|
||||
"The environment has not been scheduled, so it has no sense of time"
|
||||
@@ -164,6 +177,10 @@ class BaseEnvironment(Model):
|
||||
self.schedule.add(a)
|
||||
return a
|
||||
|
||||
def remove_agent(self, agent):
|
||||
agent.alive = False
|
||||
self.schedule.remove(agent)
|
||||
|
||||
def add_agents(self, agent_classes: List[type], k, weights: Optional[List[float]] = None, **kwargs):
|
||||
if isinstance(agent_classes, type):
|
||||
agent_classes = [agent_classes]
|
||||
@@ -192,12 +209,15 @@ class BaseEnvironment(Model):
|
||||
super().step()
|
||||
self.schedule.step()
|
||||
self.datacollector.collect(self)
|
||||
if self.now == time.INFINITY:
|
||||
self.running = False
|
||||
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
msg = "Model data:\n"
|
||||
max_width = max(len(k) for k in self.datacollector.model_vars.keys())
|
||||
for (k, v) in self.datacollector.model_vars.items():
|
||||
msg += f"\t{k:<{max_width}}: {v[-1]:>6}\n"
|
||||
# msg += f"\t{k:<{max_width}}"
|
||||
msg += f"\t{k:<{max_width}}: {v[-1]}\n"
|
||||
self.logger.debug(f"--- Steps: {self.schedule.steps:^5} - Time: {self.now:^5} --- " + msg)
|
||||
|
||||
def add_model_reporter(self, name, func=None):
|
||||
@@ -205,11 +225,18 @@ class BaseEnvironment(Model):
|
||||
func = name
|
||||
self.datacollector._new_model_reporter(name, func)
|
||||
|
||||
def add_agent_reporter(self, name, reporter=None, agent_type=None):
|
||||
if not agent_type and not reporter:
|
||||
def add_agent_reporter(self, name, reporter=None, agent_class=None, *, agent_type=None):
|
||||
if agent_type:
|
||||
print("agent_type is deprecated, use agent_class instead", file=sys.stderr)
|
||||
agent_class = agent_type or agent_class
|
||||
if not reporter and not agent_class:
|
||||
reporter = name
|
||||
elif agent_type:
|
||||
reporter = lambda a: reporter(a) if isinstance(a, agent_type) else None
|
||||
if agent_class:
|
||||
if reporter:
|
||||
_orig = reporter
|
||||
else:
|
||||
_orig = lambda a: getattr(a, name)
|
||||
reporter = lambda a: (_orig(a) if isinstance(a, agent_class) else None)
|
||||
self.datacollector._new_agent_reporter(name, reporter)
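# --- Illustrative sketch (editor's example, not part of the patch) ---
# Mirrors test_agent_reporters further down; MyAgent is an illustrative class name.
env.add_agent_reporter("now")                                          # attribute by name
env.add_agent_reporter("base", lambda a: "base", agent_class=MyAgent)  # other classes get None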
|
||||
|
||||
@classmethod
|
||||
@@ -294,6 +321,11 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
self.G.nodes[node_id]["agent"] = a
|
||||
return a
|
||||
|
||||
def remove_agent(self, agent, remove_node=True):
|
||||
super().remove_agent(agent)
|
||||
if remove_node and hasattr(agent, "remove_node"):
|
||||
agent.remove_node()
|
||||
|
||||
def add_agents(self, *args, k=None, **kwargs):
|
||||
if not k and not self.G:
|
||||
raise ValueError("Cannot add agents to an empty network")
|
||||
@@ -331,15 +363,17 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
if getattr(agent, "alive", True):
|
||||
yield agent
|
||||
|
||||
def add_node(self, agent_class, unique_id=None, node_id=None, **kwargs):
|
||||
def add_node(self, agent_class, unique_id=None, node_id=None, find_unassigned=False, **kwargs):
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
if node_id is None:
|
||||
node_id = network.find_unassigned(
|
||||
G=self.G, shuffle=True, random=self.random
|
||||
)
|
||||
if find_unassigned:
|
||||
node_id = network.find_unassigned(
|
||||
G=self.G, shuffle=True, random=self.random
|
||||
)
|
||||
if node_id is None:
|
||||
node_id = f"node_for_{unique_id}"
|
||||
node_id = f"Node_for_agent_{unique_id}"
|
||||
assert node_id not in self.G.nodes
|
||||
|
||||
if node_id not in self.G.nodes:
|
||||
self.G.add_node(node_id)
|
||||
@@ -409,27 +443,115 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
|
||||
|
||||
class EventedEnvironment(BaseEnvironment):
|
||||
def broadcast(self, msg, sender=None, expiration=None, ttl=None, **kwargs):
|
||||
for agent in self.agents(**kwargs):
|
||||
if agent == sender:
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._inbox = dict()
|
||||
super().__init__(*args, **kwargs)
|
||||
self._can_reschedule = hasattr(self.schedule, "add_callback") and hasattr(self.schedule, "remove_callback")
|
||||
self._can_reschedule = True
|
||||
self._callbacks = {}
|
||||
|
||||
def register(self, agent):
|
||||
self._inbox[agent.unique_id] = []
|
||||
|
||||
def inbox_for(self, agent):
|
||||
try:
|
||||
return self._inbox[agent.unique_id]
|
||||
except KeyError:
|
||||
raise ValueError(f"Trying to access inbox for unregistered agent: {agent} (class: {type(agent)}). "
|
||||
"Make sure your agent is of type EventedAgent and it is registered with the environment.")
|
||||
|
||||
@coroutine
|
||||
def _polling_callback(self, agent, expiration, delay):
|
||||
# this wakes the agent up at every step. It is better to wait until timeout (or inf)
|
||||
# and if a message is received before that, reschedule the agent
|
||||
# (That is implemented in the `received` method)
|
||||
inbox = self.inbox_for(agent)
|
||||
while self.now < expiration:
|
||||
if inbox:
|
||||
return self.process_messages(inbox)
|
||||
yield time.Delay(delay)
|
||||
raise events.TimedOut("No message received")
|
||||
|
||||
@coroutine
|
||||
def received(self, agent, expiration=None, timeout=None, delay=1):
|
||||
if not expiration:
|
||||
if timeout:
|
||||
expiration = self.now + timeout
|
||||
else:
|
||||
expiration = float("inf")
|
||||
inbox = self.inbox_for(agent)
|
||||
if inbox:
|
||||
return self.process_messages(inbox)
|
||||
|
||||
if self._can_reschedule:
|
||||
checked = False
|
||||
def cb():
|
||||
nonlocal checked
|
||||
if checked:
|
||||
return time.INFINITY
|
||||
checked = True
|
||||
self.schedule.add_callback(self.now, agent.step)
|
||||
self.schedule.add_callback(expiration, cb)
|
||||
self._callbacks[agent.unique_id] = cb
|
||||
yield time.INFINITY
|
||||
res = yield from self._polling_callback(agent, expiration, delay)
|
||||
return res
|
||||
|
||||
|
||||
def tell(self, msg, recipient, sender=None, expiration=None, timeout=None, **kwargs):
|
||||
if expiration is None:
|
||||
expiration = float("inf") if timeout is None else self.now + timeout
|
||||
self._add_to_inbox(recipient.unique_id,
|
||||
events.Tell(timestamp=self.now,
|
||||
payload=msg,
|
||||
sender=sender,
|
||||
expiration=expiration,
|
||||
**kwargs))
|
||||
|
||||
def broadcast(self, msg, sender, ttl=None, expiration=None, agent_class=None):
|
||||
expiration = expiration if ttl is None else self.now + ttl
|
||||
# This only works for Soil environments. Mesa agents do not have an `agents` method
|
||||
sender_id = sender.unique_id
|
||||
for (agent_id, inbox) in self._inbox.items():
|
||||
if agent_id == sender_id:
|
||||
continue
|
||||
self.logger.debug(f"Telling {repr(agent)}: {msg} ttl={ttl}")
|
||||
try:
|
||||
inbox = agent._inbox
|
||||
except AttributeError:
|
||||
self.logger.info(
|
||||
f"Agent {agent.unique_id} cannot receive events because it does not have an inbox"
|
||||
)
|
||||
if agent_class and not isinstance(self.agents(unique_id=agent_id), agent_class):
|
||||
continue
|
||||
# Allow for AttributeError exceptions in this part of the code
|
||||
inbox.append(
|
||||
self.logger.debug(f"Telling {agent_id}: {msg} ttl={ttl}")
|
||||
self._add_to_inbox(agent_id,
|
||||
events.Tell(
|
||||
payload=msg,
|
||||
sender=sender,
|
||||
expiration=expiration if ttl is None else self.now + ttl,
|
||||
expiration=expiration,
|
||||
)
|
||||
)
|
||||
def _add_to_inbox(self, inbox_id, msg):
|
||||
self._inbox[inbox_id].append(msg)
|
||||
if inbox_id in self._callbacks:
|
||||
cb = self._callbacks.pop(inbox_id)
|
||||
cb()
|
||||
|
||||
@coroutine
|
||||
def ask(self, msg, recipient, sender=None, expiration=None, timeout=None, delay=1):
|
||||
ask = events.Ask(timestamp=self.now, payload=msg, sender=sender)
|
||||
self._add_to_inbox(recipient.unique_id, ask)
|
||||
expiration = float("inf") if timeout is None else self.now + timeout
|
||||
while self.now < expiration:
|
||||
if ask.reply:
|
||||
return ask.reply
|
||||
yield time.Delay(delay)
|
||||
raise events.TimedOut("No reply received")
|
||||
|
||||
def process_messages(self, inbox):
|
||||
valid = list()
|
||||
for msg in inbox:
|
||||
if msg.expired(self.now):
|
||||
continue
|
||||
valid.append(msg)
|
||||
inbox.clear()
|
||||
return valid
|
||||
|
||||
|
||||
class Environment(NetworkEnvironment, EventedEnvironment):
|
||||
"""Default environment class, has both network and event capabilities"""
|
||||
class Environment(EventedEnvironment, NetworkEnvironment):
|
||||
pass
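# --- Illustrative sketch (editor's example, not part of the patch) ---
# Agent-side view of the messaging API, modeled on the tests further down in this diff.
class Pinger(agents.Agent):
    async def step(self):
        self.broadcast("PING")                  # delivered to every other registered inbox
        msgs = await self.received(timeout=10)  # sleeps until mail arrives or TimedOut
        for msg in msgs:
            if isinstance(msg, events.Ask):
                msg.reply = "PONG"              # answers a pending ask()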
|
||||
|
@@ -18,7 +18,6 @@ class Message:
|
||||
def expired(self, when):
|
||||
return self.expiration is not None and self.expiration < when
|
||||
|
||||
|
||||
class Reply(Message):
|
||||
source: Message
|
||||
|
||||
@@ -28,7 +27,9 @@ class Ask(Message):
|
||||
|
||||
|
||||
class Tell(Message):
|
||||
pass
|
||||
def __post_init__(self):
|
||||
assert self.sender is not None, "Tell requires a sender"
|
||||
|
||||
|
||||
|
||||
class TimedOut(Exception):
|
||||
|
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
from time import time as current_time
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from textwrap import dedent, indent
|
||||
|
||||
@@ -75,6 +76,13 @@ class Exporter:
|
||||
"""Method to call when a iteration ends"""
|
||||
pass
|
||||
|
||||
def env_id(self, env):
|
||||
try:
|
||||
return env.id
|
||||
except AttributeError:
|
||||
return f"{env.__class__.__name__}_{current_time()}"
|
||||
|
||||
|
||||
def output(self, f, mode="w", **kwargs):
|
||||
if not self.dump:
|
||||
f = DryRunner(f, copy_to=self.copy_to)
|
||||
@@ -86,18 +94,19 @@ class Exporter:
|
||||
pass
|
||||
return open_or_reuse(f, mode=mode, backup=self.simulation.backup, **kwargs)
|
||||
|
||||
def get_dfs(self, env, **kwargs):
|
||||
def get_dfs(self, env, params_id, **kwargs):
|
||||
yield from get_dc_dfs(env.datacollector,
|
||||
simulation_id=self.simulation.id,
|
||||
iteration_id=env.id,
|
||||
params_id,
|
||||
iteration_id=self.env_id(env),
|
||||
**kwargs)
|
||||
|
||||
|
||||
def get_dc_dfs(dc, **kwargs):
|
||||
def get_dc_dfs(dc, params_id, **kwargs):
|
||||
dfs = {}
|
||||
dfe = dc.get_model_vars_dataframe()
|
||||
dfe.index.rename("step", inplace=True)
|
||||
dfs["env"] = dfe
|
||||
kwargs["params_id"] = params_id
|
||||
try:
|
||||
dfa = dc.get_agent_vars_dataframe()
|
||||
dfa.index.rename(["step", "agent_id"], inplace=True)
|
||||
@@ -108,8 +117,12 @@ def get_dc_dfs(dc, **kwargs):
|
||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||
for (name, df) in dfs.items():
|
||||
for (k, v) in kwargs.items():
|
||||
df[k] = v
|
||||
df.set_index(["simulation_id", "iteration_id"], append=True, inplace=True)
|
||||
if v:
|
||||
df[k] = v
|
||||
else:
|
||||
df[k] = pd.Series(dtype="object")
|
||||
df.reset_index(inplace=True)
|
||||
df.set_index(["params_id", "iteration_id"], inplace=True)
|
||||
|
||||
yield from dfs.items()
|
||||
|
||||
@@ -134,12 +147,16 @@ class SQLite(Exporter):
|
||||
if os.path.exists(self.dbpath):
|
||||
os.remove(self.dbpath)
|
||||
|
||||
outdir = os.path.dirname(self.dbpath)
|
||||
if outdir and not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
|
||||
self.engine = create_engine(f"sqlite:///{self.dbpath}", echo=False)
|
||||
|
||||
sim_dict = {k: serialize(v)[0] for (k,v) in self.simulation.to_dict().items()}
|
||||
sim_dict["simulation_id"] = self.simulation.id
|
||||
df = pd.DataFrame([sim_dict])
|
||||
df.to_sql("configuration", con=self.engine, if_exists="append")
|
||||
df.reset_index().to_sql("configuration", con=self.engine, if_exists="append", index=False)
|
||||
|
||||
def iteration_end(self, env, params, params_id, *args, **kwargs):
|
||||
if not self.dump:
|
||||
@@ -147,17 +164,30 @@ class SQLite(Exporter):
|
||||
return
|
||||
|
||||
with timer(
|
||||
"Dumping simulation {} iteration {}".format(self.simulation.name, env.id)
|
||||
"Dumping simulation {} iteration {}".format(self.simulation.name, self.env_id(env))
|
||||
):
|
||||
|
||||
pd.DataFrame([{"simulation_id": self.simulation.id,
|
||||
d = {"simulation_id": self.simulation.id,
|
||||
"params_id": params_id,
|
||||
"iteration_id": env.id,
|
||||
"key": k,
|
||||
"value": serialize(v)[0]} for (k,v) in params.items()]).to_sql("parameters", con=self.engine, if_exists="append")
|
||||
"iteration_id": self.env_id(env),
|
||||
}
|
||||
for (k,v) in params.items():
|
||||
d[k] = serialize(v)[0]
|
||||
|
||||
pd.DataFrame([d]).reset_index().to_sql("parameters",
|
||||
con=self.engine,
|
||||
if_exists="append",
|
||||
index=False)
|
||||
pd.DataFrame([{
|
||||
"simulation_id": self.simulation.id,
|
||||
"params_id": params_id,
|
||||
"iteration_id": self.env_id(env),
|
||||
}]).reset_index().to_sql("iterations",
|
||||
con=self.engine,
|
||||
if_exists="append",
|
||||
index=False)
|
||||
|
||||
for (t, df) in self.get_dfs(env, params_id=params_id):
|
||||
df.to_sql(t, con=self.engine, if_exists="append")
|
||||
df.reset_index().to_sql(t, con=self.engine, if_exists="append", index=False)
|
||||
|
||||
class csv(Exporter):
|
||||
"""Export the state of each environment (and its agents) a CSV file for the simulation"""
|
||||
@@ -168,11 +198,11 @@ class csv(Exporter):
|
||||
def iteration_end(self, env, params, params_id, *args, **kwargs):
|
||||
with timer(
|
||||
"[CSV] Dumping simulation {} iteration {} @ dir {}".format(
|
||||
self.simulation.name, env.id, self.outdir
|
||||
self.simulation.name, self.env_id(env), self.outdir
|
||||
)
|
||||
):
|
||||
for (df_name, df) in self.get_dfs(env, params_id=params_id):
|
||||
with self.output("{}.{}.csv".format(env.id, df_name), mode="a") as f:
|
||||
with self.output("{}.{}.csv".format(self.env_id(env), df_name), mode="a") as f:
|
||||
df.to_csv(f)
|
||||
|
||||
|
||||
@@ -183,9 +213,9 @@ class gexf(Exporter):
|
||||
return
|
||||
|
||||
with timer(
|
||||
"[GEXF] Dumping simulation {} iteration {}".format(self.simulation.name, env.id)
|
||||
"[GEXF] Dumping simulation {} iteration {}".format(self.simulation.name, self.env_id(env))
|
||||
):
|
||||
with self.output("{}.gexf".format(env.id), mode="wb") as f:
|
||||
with self.output("{}.gexf".format(self.env_id(env)), mode="wb") as f:
|
||||
nx.write_gexf(env.G, f)
|
||||
|
||||
|
||||
@@ -219,16 +249,16 @@ class graphdrawing(Exporter):
|
||||
pos=nx.spring_layout(env.G, scale=100),
|
||||
ax=f.add_subplot(111),
|
||||
)
|
||||
with open("graph-{}.png".format(env.id)) as f:
|
||||
with open("graph-{}.png".format(self.env_id(env))) as f:
|
||||
f.savefig(f)
|
||||
|
||||
|
||||
class summary(Exporter):
|
||||
"""Print a summary of each iteration to sys.stdout"""
|
||||
|
||||
def iteration_end(self, env, *args, **kwargs):
|
||||
def iteration_end(self, env, params_id, *args, **kwargs):
|
||||
msg = ""
|
||||
for (t, df) in self.get_dfs(env):
|
||||
for (t, df) in self.get_dfs(env, params_id):
|
||||
if not len(df):
|
||||
continue
|
||||
tabs = "\t" * 2
|
||||
@@ -262,21 +292,5 @@ class YAML(Exporter):
|
||||
logger.info(f"Dumping simulation configuration to {self.outdir}")
|
||||
f.write(self.simulation.to_yaml())
|
||||
|
||||
class default(Exporter):
|
||||
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
|
||||
|
||||
def __init__(self, *args, exporter_cls=[], **kwargs):
|
||||
exporter_cls = exporter_cls or [YAML, SQLite]
|
||||
self.inner = [cls(*args, **kwargs) for cls in exporter_cls]
|
||||
|
||||
def sim_start(self, *args, **kwargs):
|
||||
for exporter in self.inner:
|
||||
exporter.sim_start(*args, **kwargs)
|
||||
|
||||
def sim_end(self, *args, **kwargs):
|
||||
for exporter in self.inner:
|
||||
exporter.sim_end(*args, **kwargs)
|
||||
|
||||
def iteration_end(self, *args, **kwargs):
|
||||
for exporter in self.inner:
|
||||
exporter.iteration_end(*args, **kwargs)
|
||||
default = SQLite
|
@@ -44,8 +44,8 @@ def do_not_run():
|
||||
|
||||
def _iter_queued():
|
||||
while _QUEUED:
|
||||
(cls, params) = _QUEUED.pop(0)
|
||||
yield replace(cls, parameters=params)
|
||||
slf = _QUEUED.pop(0)
|
||||
yield slf
|
||||
|
||||
|
||||
# TODO: change documentation for simulation
|
||||
@@ -130,11 +130,11 @@ class Simulation:
|
||||
def run(self, **kwargs):
|
||||
"""Run the simulation and return the list of resulting environments"""
|
||||
if kwargs:
|
||||
return replace(self, **kwargs).run()
|
||||
res = replace(self, **kwargs)
|
||||
return res.run()
|
||||
|
||||
param_combinations = self._collect_params(**kwargs)
|
||||
if _AVOID_RUNNING:
|
||||
_QUEUED.extend((self, param) for param in param_combinations)
|
||||
_QUEUED.append(self)
|
||||
return []
|
||||
|
||||
self.logger.debug("Using exporters: %s", self.exporters or [])
|
||||
@@ -154,6 +154,8 @@ class Simulation:
|
||||
for exporter in exporters:
|
||||
exporter.sim_start()
|
||||
|
||||
param_combinations = self._collect_params(**kwargs)
|
||||
|
||||
for params in tqdm(param_combinations, desc=self.name, unit="configuration"):
|
||||
for (k, v) in params.items():
|
||||
tqdm.write(f"{k} = {v}")
|
||||
@@ -204,6 +206,7 @@ class Simulation:
|
||||
for env in tqdm(utils.run_parallel(
|
||||
func=func,
|
||||
iterable=range(self.iterations),
|
||||
num_processes=self.num_processes,
|
||||
**params,
|
||||
), total=self.iterations, leave=False):
|
||||
if env is None and self.dry_run:
|
||||
@@ -338,12 +341,13 @@ def iter_from_py(pyfile, module_name='imported_file', **kwargs):
|
||||
sims.append(sim)
|
||||
for sim in _iter_queued():
|
||||
sims.append(sim)
|
||||
# Try to find environments to run, because we did not import a script that ran simulations
|
||||
if not sims:
|
||||
for (_name, env) in inspect.getmembers(module,
|
||||
lambda x: inspect.isclass(x) and
|
||||
issubclass(x, environment.Environment) and
|
||||
(getattr(x, "__module__", None) != environment.__name__)):
|
||||
sims.append(Simulation(model=env, **kwargs))
|
||||
lambda x: inspect.isclass(x) and
|
||||
issubclass(x, environment.Environment) and
|
||||
(getattr(x, "__module__", None) != environment.__name__)):
|
||||
sims.append(Simulation(model=env, **kwargs))
|
||||
del sys.modules[module_name]
|
||||
assert not _AVOID_RUNNING
|
||||
if not sims:
|
||||
|
soil/time.py (230 lines changed)
@@ -1,13 +1,14 @@
|
||||
from mesa.time import BaseScheduler
|
||||
from queue import Empty
|
||||
from heapq import heappush, heappop, heapreplace
|
||||
from collections import deque
|
||||
from collections import deque, defaultdict
|
||||
import math
|
||||
import logging
|
||||
|
||||
from inspect import getsource
|
||||
from numbers import Number
|
||||
from textwrap import dedent
|
||||
import random as random_std
|
||||
|
||||
from .utils import logger
|
||||
from mesa import Agent as MesaAgent
|
||||
@@ -24,54 +25,78 @@ class Delay:
|
||||
def __float__(self):
|
||||
return self.delta
|
||||
|
||||
def __eq__(self, other):
|
||||
return float(self) == float(other)
|
||||
|
||||
def __await__(self):
|
||||
return (yield self.delta)
|
||||
|
||||
class When:
|
||||
def __init__(self, when):
|
||||
raise Exception("The use of When is deprecated. Use the `Agent.at` and `Agent.delay` methods instead")
|
||||
|
||||
class Delta:
|
||||
def __init__(self, delta):
|
||||
raise Exception("The use of Delay is deprecated. Use the `Agent.at` and `Agent.delay` methods instead")
|
||||
|
||||
class DeadAgent(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PQueueActivation(BaseScheduler):
|
||||
class Event(object):
|
||||
def __init__(self, when: float, func, order=1):
|
||||
self.when = when
|
||||
self.func = func
|
||||
self.order = order
|
||||
|
||||
def __repr__(self):
|
||||
return f'Event @ {self.when} - Func: {self.func}'
|
||||
|
||||
def __lt__(self, other):
|
||||
return (self.when < other.when) or (self.when == other.when and self.order < other.order)
|
||||
|
||||
|
||||
class PQueueSchedule:
|
||||
"""
|
||||
A scheduler which activates each agent with a delay returned by the agent's step method.
|
||||
A scheduler which activates each function with a delay returned by the function at each step.
|
||||
If no delay is returned, a default of 1 is used.
|
||||
|
||||
In each activation, each agent will update its 'next_time'.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, shuffle=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, shuffle=True, seed=None, time=0, **kwargs):
|
||||
self._queue = []
|
||||
self._shuffle = shuffle
|
||||
self.logger = getattr(self.model, "logger", logger).getChild(f"time_{ self.model }")
|
||||
self.time = time
|
||||
self.steps = 0
|
||||
self.random = random_std.Random(seed)
|
||||
self.next_time = self.time
|
||||
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
def insert(self, when, callback, replace=False):
|
||||
if when is None:
|
||||
when = self.time
|
||||
else:
|
||||
when = float(when)
|
||||
|
||||
self._schedule(agent, None, when)
|
||||
super().add(agent)
|
||||
|
||||
def _schedule(self, agent, when=None, replace=False):
|
||||
if when is None:
|
||||
when = self.time
|
||||
order = 1
|
||||
if self._shuffle:
|
||||
key = (when, self.model.random.random())
|
||||
else:
|
||||
key = (when, agent.unique_id)
|
||||
order = self.random.random()
|
||||
event = Event(when, callback, order=order)
|
||||
if replace:
|
||||
heapreplace(self._queue, (key, agent))
|
||||
heapreplace(self._queue, event)
|
||||
else:
|
||||
heappush(self._queue, (key, agent))
|
||||
heappush(self._queue, event)
|
||||
|
||||
def remove(self, callback):
|
||||
for i, event in enumerate(self._queue):
|
||||
if callback == event.func:
|
||||
del self._queue[i]
|
||||
break
|
||||
|
||||
def __len__(self):
|
||||
return len(self._queue)
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
Executes agents in order, one at a time. After each step,
|
||||
an agent will signal when it wants to be scheduled next.
|
||||
Executes events in order, one at a time. After each step,
|
||||
an event will signal when it wants to be scheduled next.
|
||||
"""
|
||||
|
||||
if self.time == INFINITY:
|
||||
@@ -82,66 +107,86 @@ class PQueueActivation(BaseScheduler):
|
||||
now = self.time
|
||||
|
||||
while self._queue:
|
||||
((when, _id), agent) = self._queue[0]
|
||||
event = self._queue[0]
|
||||
when = event.when
|
||||
if when > now:
|
||||
next_time = when
|
||||
break
|
||||
|
||||
try:
|
||||
when = agent.step() or 1
|
||||
when += now
|
||||
except DeadAgent:
|
||||
heappop(self._queue)
|
||||
continue
|
||||
when = event.func()
|
||||
when = float(when) if when is not None else 1.0
|
||||
|
||||
if when == INFINITY:
|
||||
heappop(self._queue)
|
||||
continue
|
||||
|
||||
self._schedule(agent, when, replace=True)
|
||||
when += now
|
||||
|
||||
self.insert(when, event.func, replace=True)
|
||||
|
||||
self.steps += 1
|
||||
|
||||
self.time = next_time
|
||||
|
||||
if next_time == INFINITY:
|
||||
self.model.running = False
|
||||
self.time = INFINITY
|
||||
return
|
||||
|
||||
|
||||
class TimedActivation(BaseScheduler):
|
||||
def __init__(self, *args, shuffle=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
class Schedule:
|
||||
def __init__(self, shuffle=True, seed=None, time=0, **kwargs):
|
||||
self._queue = deque()
|
||||
self._shuffle = shuffle
|
||||
self.logger = getattr(self.model, "logger", logger).getChild(f"time_{ self.model }")
|
||||
self.time = time
|
||||
self.steps = 0
|
||||
self.random = random_std.Random(seed)
|
||||
self.next_time = self.time
|
||||
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
def _find_loc(self, when=None):
|
||||
if when is None:
|
||||
when = self.time
|
||||
else:
|
||||
when = float(when)
|
||||
self._schedule(agent, None, when)
|
||||
super().add(agent)
|
||||
|
||||
def _schedule(self, agent, when=None, replace=False):
|
||||
when = when or self.time
|
||||
pos = len(self._queue)
|
||||
for (ix, l) in enumerate(self._queue):
|
||||
if l[0] == when:
|
||||
l[1].append(agent)
|
||||
return
|
||||
return l[1]
|
||||
if l[0] > when:
|
||||
pos = ix
|
||||
break
|
||||
self._queue.insert(pos, (when, [agent]))
|
||||
lst = []
|
||||
self._queue.insert(pos, (when, lst))
|
||||
return lst
|
||||
|
||||
def insert(self, when, func, replace=False):
|
||||
if when == INFINITY:
|
||||
return
|
||||
lst = self._find_loc(when)
|
||||
lst.append(func)
|
||||
|
||||
def add_bulk(self, funcs, when=None):
|
||||
lst = self._find_loc(when)
|
||||
n = len(funcs)
|
||||
#TODO: remove for performance
|
||||
before = len(self)
|
||||
lst.extend(funcs)
|
||||
assert len(self) == before + n
|
||||
|
||||
def remove(self, func):
|
||||
for bucket in self._queue:
|
||||
for (ix, e) in enumerate(bucket):
|
||||
if e == func:
|
||||
bucket.remove(ix)
|
||||
return
|
||||
|
||||
def __len__(self):
|
||||
return sum(len(bucket[1]) for bucket in self._queue)
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
Executes agents in order, one at a time. After each step,
|
||||
an agent will signal when it wants to be scheduled next.
|
||||
Executes events in order, one at a time. After each step,
|
||||
an event will signal when it wants to be scheduled next.
|
||||
"""
|
||||
if not self._queue:
|
||||
return
|
||||
@@ -156,16 +201,20 @@ class TimedActivation(BaseScheduler):
|
||||
|
||||
bucket = self._queue.popleft()[1]
|
||||
if self._shuffle:
|
||||
self.model.random.shuffle(bucket)
|
||||
for agent in bucket:
|
||||
try:
|
||||
when = agent.step() or 1
|
||||
when += now
|
||||
except DeadAgent:
|
||||
self.random.shuffle(bucket)
|
||||
next_batch = defaultdict(list)
|
||||
for func in bucket:
|
||||
when = func()
|
||||
when = float(when) if when is not None else 1
|
||||
|
||||
if when == INFINITY:
|
||||
continue
|
||||
|
||||
if when != INFINITY:
|
||||
self._schedule(agent, when, replace=True)
|
||||
when += now
|
||||
next_batch[when].append(func)
|
||||
|
||||
for (when, bucket) in next_batch.items():
|
||||
self.add_bulk(bucket, when)
|
||||
|
||||
self.steps += 1
|
||||
if self._queue:
|
||||
@@ -174,6 +223,77 @@ class TimedActivation(BaseScheduler):
|
||||
self.time = INFINITY
|
||||
|
||||
|
||||
class InnerActivation(BaseScheduler):
|
||||
inner_class = Schedule
|
||||
|
||||
def __init__(self, model, shuffle=True, time=0, **kwargs):
|
||||
self.model = model
|
||||
self.logger = getattr(self.model, "logger", logger).getChild(f"time_{ self.model }")
|
||||
self._agents = {}
|
||||
self.agents_by_type = defaultdict(dict)
|
||||
self.inner = self.inner_class(shuffle=shuffle, seed=self.model._seed, time=time)
|
||||
|
||||
@property
|
||||
def steps(self):
|
||||
return self.inner.steps
|
||||
|
||||
@property
|
||||
def time(self):
|
||||
return self.inner.time
|
||||
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
when = when or self.inner.time
|
||||
self.inner.insert(when, agent.step)
|
||||
agent_class = type(agent)
|
||||
self.agents_by_type[agent_class][agent.unique_id] = agent
|
||||
super().add(agent)
|
||||
|
||||
def add_callback(self, when, cb):
|
||||
self.inner.insert(when, cb)
|
||||
|
||||
def remove_callback(self, when, cb):
|
||||
self.inner.remove(cb)
|
||||
|
||||
def remove(self, agent):
|
||||
del self._agents[agent.unique_id]
|
||||
del self.agents_by_type[type(agent)][agent.unique_id]
|
||||
self.inner.remove(agent.step)
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
Executes agents in order, one at a time. After each step,
|
||||
an agent will signal when it wants to be scheduled next.
|
||||
"""
|
||||
self.inner.step()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.inner)
|
||||
|
||||
|
||||
class BucketTimedActivation(InnerActivation):
|
||||
inner_class = Schedule
|
||||
|
||||
|
||||
class PQueueActivation(InnerActivation):
|
||||
inner_class = PQueueSchedule
|
||||
|
||||
|
||||
# Set the bucket implementation as the default
|
||||
TimedActivation = BucketTimedActivation
|
||||
|
||||
try:
|
||||
from soilent.soilent import BucketScheduler, PQueueScheduler
|
||||
|
||||
class SoilentActivation(InnerActivation):
|
||||
inner_class = BucketScheduler
|
||||
class SoilentPQueueActivation(InnerActivation):
|
||||
inner_class = PQueueScheduler
|
||||
|
||||
# TimedActivation = SoilentBucketActivation
|
||||
except ImportError:
|
||||
pass
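# --- Illustrative sketch (editor's example, not part of the patch) ---
# The scheduler is pluggable per environment; the soilent-backed variants are only
# available when the optional soilent extension is installed.
env = Environment(schedule_class=PQueueActivation)   # or keep the default TimedActivation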
|
||||
|
||||
|
||||
class ShuffledTimedActivation(TimedActivation):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, shuffle=True, **kwargs)
|
||||
@@ -182,5 +302,3 @@ class ShuffledTimedActivation(TimedActivation):
|
||||
class OrderedTimedActivation(TimedActivation):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, shuffle=False, **kwargs)
|
||||
|
||||
|
||||
|
@@ -93,15 +93,12 @@ def flatten_dict(d):
|
||||
|
||||
def _flatten_dict(d, prefix=""):
|
||||
if not isinstance(d, dict):
|
||||
# print('END:', prefix, d)
|
||||
yield prefix, d
|
||||
return
|
||||
if prefix:
|
||||
prefix = prefix + "."
|
||||
for k, v in d.items():
|
||||
# print(k, v)
|
||||
res = list(_flatten_dict(v, prefix="{}{}".format(prefix, k)))
|
||||
# print('RES:', res)
|
||||
yield from res
|
||||
|
||||
|
||||
@@ -142,6 +139,7 @@ def run_and_return_exceptions(func, *args, **kwargs):
|
||||
|
||||
def run_parallel(func, iterable, num_processes=1, **kwargs):
|
||||
if num_processes > 1 and not os.environ.get("SOIL_DEBUG", None):
|
||||
logger.info("Running simulations in {} processes".format(num_processes))
|
||||
if num_processes < 1:
|
||||
num_processes = cpu_count() - num_processes
|
||||
p = Pool(processes=num_processes)
|
||||
@@ -158,3 +156,23 @@ def run_parallel(func, iterable, num_processes=1, **kwargs):
|
||||
|
||||
def int_seed(seed: str):
|
||||
return int.from_bytes(seed.encode(), "little")
|
||||
|
||||
|
||||
def prob(prob, random):
|
||||
"""
|
||||
A true/False uniform distribution with a given probability.
|
||||
To be used like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
if prob(0.3):
|
||||
do_something()
|
||||
|
||||
"""
|
||||
r = random.random()
|
||||
return r < prob
|
||||
|
||||
|
||||
def custom(cls, **kwargs):
|
||||
"""Create a new class from a template class and keyword arguments"""
|
||||
return type(cls.__name__, (cls,), kwargs)
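# --- Illustrative sketch (editor's example, not part of the patch; Walker is a made-up class) ---
FastWalker = custom(Walker, speed=5)   # a Walker subclass with a different default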
|
@@ -1,7 +1,7 @@
|
||||
from unittest import TestCase
|
||||
import pytest
|
||||
|
||||
from soil import agents, environment
|
||||
from soil import agents, events, environment
|
||||
from soil import time as stime
|
||||
|
||||
|
||||
@@ -20,13 +20,14 @@ class TestAgents(TestCase):
|
||||
assert ret == stime.INFINITY
|
||||
|
||||
def test_die_raises_exception(self):
|
||||
"""A dead agent should raise an exception if it is stepped after death"""
|
||||
"""A dead agent should continue returning INFINITY after death"""
|
||||
d = Dead(unique_id=0, model=environment.Environment())
|
||||
assert d.alive
|
||||
d.step()
|
||||
assert not d.alive
|
||||
with pytest.raises(stime.DeadAgent):
|
||||
d.step()
|
||||
when = float(d.step())
|
||||
assert not d.alive
|
||||
assert when == stime.INFINITY
|
||||
|
||||
def test_agent_generator(self):
|
||||
"""
|
||||
@@ -62,6 +63,7 @@ class TestAgents(TestCase):
|
||||
def other(self):
|
||||
self.times_run += 1
|
||||
|
||||
assert MyAgent.other.id == "other"
|
||||
e = environment.Environment()
|
||||
a = e.add_agent(MyAgent)
|
||||
e.step()
|
||||
@@ -72,6 +74,53 @@ class TestAgents(TestCase):
|
||||
a.step()
|
||||
assert a.times_run == 2
|
||||
|
||||
def test_state_decorator_multiple(self):
|
||||
class MyAgent(agents.FSM):
|
||||
times_run = 0
|
||||
|
||||
@agents.state(default=True)
|
||||
def one(self):
|
||||
return self.two
|
||||
|
||||
@agents.state
|
||||
def two(self):
|
||||
return self.one
|
||||
|
||||
e = environment.Environment()
|
||||
first = e.add_agent(MyAgent, state_id=MyAgent.one)
|
||||
second = e.add_agent(MyAgent, state_id=MyAgent.two)
|
||||
assert first.state_id == MyAgent.one.id
|
||||
assert second.state_id == MyAgent.two.id
|
||||
e.step()
|
||||
assert first.state_id == MyAgent.two.id
|
||||
assert second.state_id == MyAgent.one.id
|
||||
|
||||
def test_state_decorator_multiple_async(self):
|
||||
class MyAgent(agents.FSM):
|
||||
times_run = 0
|
||||
|
||||
@agents.state(default=True)
|
||||
def one(self):
|
||||
yield self.delay(1)
|
||||
return self.two
|
||||
|
||||
@agents.state
|
||||
def two(self):
|
||||
yield self.delay(1)
|
||||
return self.one
|
||||
|
||||
e = environment.Environment()
|
||||
first = e.add_agent(MyAgent, state_id=MyAgent.one)
|
||||
second = e.add_agent(MyAgent, state_id=MyAgent.two)
|
||||
for i in range(2):
|
||||
assert first.state_id == MyAgent.one.id
|
||||
assert second.state_id == MyAgent.two.id
|
||||
e.step()
|
||||
for i in range(2):
|
||||
assert first.state_id == MyAgent.two.id
|
||||
assert second.state_id == MyAgent.one.id
|
||||
e.step()
|
||||
|
||||
def test_broadcast(self):
|
||||
"""
|
||||
An agent should be able to broadcast messages to every other agent, AND each receiver should be able
|
||||
@@ -79,30 +128,28 @@ class TestAgents(TestCase):
|
||||
"""
|
||||
|
||||
class BCast(agents.Evented):
|
||||
pings_received = 0
|
||||
pings_received = []
|
||||
|
||||
def step(self):
|
||||
print(self.model.broadcast)
|
||||
try:
|
||||
self.model.broadcast("PING")
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
async def step(self):
|
||||
self.broadcast("PING")
|
||||
print("PING sent")
|
||||
while True:
|
||||
self.process_messages()
|
||||
yield
|
||||
msgs = await self.received()
|
||||
self.pings_received += msgs
|
||||
|
||||
def on_receive(self, msg, sender=None):
|
||||
self.pings_received += 1
|
||||
e = environment.Environment()
|
||||
|
||||
e = environment.EventedEnvironment()
|
||||
|
||||
for i in range(10):
|
||||
num_agents = 10
|
||||
for i in range(num_agents):
|
||||
e.add_agent(agent_class=BCast)
|
||||
e.step()
|
||||
pings_received = lambda: [a.pings_received for a in e.agents]
|
||||
assert sorted(pings_received()) == list(range(1, 11))
|
||||
# Agents are executed in order, so the first agent should not have received any messages
|
||||
pings_received = lambda: [len(a.pings_received) for a in e.agents]
|
||||
assert sorted(pings_received()) == list(range(0, num_agents))
|
||||
e.step()
|
||||
assert all(x == 10 for x in pings_received())
|
||||
# After the second step, every agent should have received a broadcast from every other agent
|
||||
received = pings_received()
|
||||
assert all(x == (num_agents - 1) for x in received)
|
||||
|
||||
def test_ask_messages(self):
|
||||
"""
|
||||
@@ -140,17 +187,16 @@ class TestAgents(TestCase):
|
||||
print("NOT sending ping")
|
||||
print("Checking msgs")
|
||||
# Do not block if we have already received a PING
|
||||
if not self.process_messages():
|
||||
yield from self.received()
|
||||
print("done")
|
||||
msgs = yield from self.received()
|
||||
for ping in msgs:
|
||||
if ping.payload == "PING":
|
||||
ping.reply = "PONG"
|
||||
pongs.append(self.now)
|
||||
else:
|
||||
raise Exception("This should never happen")
|
||||
|
||||
def on_receive(self, msg, sender=None):
|
||||
if msg == "PING":
|
||||
pongs.append(self.now)
|
||||
return "PONG"
|
||||
raise Exception("This should never happen")
|
||||
|
||||
e = environment.EventedEnvironment(schedule_class=stime.OrderedTimedActivation)
|
||||
e = environment.Environment(schedule_class=stime.OrderedTimedActivation)
|
||||
for i in range(2):
|
||||
e.add_agent(agent_class=Ping)
|
||||
assert e.now == 0
|
||||
@@ -373,3 +419,106 @@ class TestAgents(TestCase):
|
||||
model.step()
|
||||
assert a.now == 17
|
||||
assert a.my_state == 5
|
||||
|
||||
def test_receive(self):
|
||||
'''
|
||||
An agent should be able to receive a message after waiting
|
||||
'''
|
||||
model = environment.Environment()
|
||||
class TestAgent(agents.Agent):
|
||||
sent = False
|
||||
woken = 0
|
||||
def step(self):
|
||||
self.woken += 1
|
||||
return super().step()
|
||||
|
||||
@agents.state(default=True)
|
||||
async def one(self):
|
||||
try:
|
||||
self.sent = await self.received(timeout=15)
|
||||
return self.two.at(20)
|
||||
except events.TimedOut:
|
||||
pass
|
||||
@agents.state
|
||||
def two(self):
|
||||
return self.die()
|
||||
|
||||
a = model.add_agent(TestAgent)
|
||||
|
||||
class Sender(agents.Agent):
|
||||
async def step(self):
|
||||
await self.delay(10)
|
||||
a.tell(1)
|
||||
return stime.INFINITY
|
||||
|
||||
b = model.add_agent(Sender)
|
||||
|
||||
# Start and wait
|
||||
model.step()
|
||||
assert model.now == 10
|
||||
assert a.woken == 1
|
||||
assert not a.sent
|
||||
|
||||
# Sending the message
|
||||
model.step()
|
||||
assert model.now == 10
|
||||
assert a.woken == 1
|
||||
assert not a.sent
|
||||
|
||||
# The receiver callback
|
||||
model.step()
|
||||
assert model.now == 15
|
||||
assert a.woken == 2
|
||||
assert a.sent[0].payload == 1
|
||||
|
||||
# The timeout
|
||||
model.step()
|
||||
assert model.now == 20
|
||||
assert a.woken == 2
|
||||
|
||||
# The last state of the agent
|
||||
model.step()
|
||||
assert a.woken == 3
|
||||
assert model.now == float('inf')
|
||||
|
||||
def test_receive_timeout(self):
|
||||
'''
|
||||
A timeout should be raised if no messages are received after an expiration time
|
||||
'''
|
||||
model = environment.Environment()
|
||||
timedout = False
|
||||
class TestAgent(agents.Agent):
|
||||
@agents.state(default=True)
|
||||
def one(self):
|
||||
try:
|
||||
yield from self.received(timeout=10)
|
||||
raise AssertionError('Should have raised an error.')
|
||||
except events.TimedOut:
|
||||
nonlocal timedout
|
||||
timedout = True
|
||||
|
||||
a = model.add_agent(TestAgent)
|
||||
|
||||
model.step()
|
||||
assert model.now == 10
|
||||
model.step()
|
||||
# Wake up the callback
|
||||
assert model.now == 10
|
||||
assert not timedout
|
||||
# The actual timeout
|
||||
model.step()
|
||||
assert model.now == 11
|
||||
assert timedout
|
||||
|
||||
def test_attributes(self):
|
||||
"""Attributes should be individual per agent"""
|
||||
|
||||
class MyAgent(agents.Agent):
|
||||
my_attribute = 0
|
||||
|
||||
model = environment.Environment()
|
||||
a = MyAgent(model=model)
|
||||
assert a.my_attribute == 0
|
||||
b = MyAgent(model=model, my_attribute=1)
|
||||
assert b.my_attribute == 1
|
||||
assert a.my_attribute == 0
|
||||
|
@@ -81,7 +81,8 @@ class Exporters(TestCase):
|
||||
model=ConstantEnv,
|
||||
name="exporter_sim",
|
||||
exporters=[
|
||||
exporters.default,
|
||||
exporters.YAML,
|
||||
exporters.SQLite,
|
||||
exporters.csv,
|
||||
],
|
||||
exporter_params={"copy_to": output},
|
||||
|
@@ -6,7 +6,7 @@ import networkx as nx
|
||||
from functools import partial
|
||||
|
||||
from os.path import join
|
||||
from soil import simulation, Environment, agents, network, serialization, utils, config, from_file
|
||||
from soil import simulation, Environment, agents, serialization, from_file, time
|
||||
from mesa import Agent as MesaAgent
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
@@ -135,9 +135,9 @@ class TestMain(TestCase):
|
||||
|
||||
def test_serialize_agent_class(self):
|
||||
"""A class from soil.agents should be serialized without the module part"""
|
||||
ser = agents._serialize_type(CustomAgent)
|
||||
ser = serialization.serialize(CustomAgent, known_modules=["soil.agents"])[1]
|
||||
assert ser == "test_main.CustomAgent"
|
||||
ser = agents._serialize_type(agents.BaseAgent)
|
||||
ser = serialization.serialize(agents.BaseAgent, known_modules=["soil.agents"])[1]
|
||||
assert ser == "BaseAgent"
|
||||
pickle.dumps(ser)
|
||||
|
||||
@@ -194,7 +194,7 @@ class TestMain(TestCase):
|
||||
return self.ping
|
||||
|
||||
a = ToggleAgent(unique_id=1, model=Environment())
|
||||
when = a.step()
|
||||
when = float(a.step())
|
||||
assert when == 2
|
||||
when = a.step()
|
||||
assert when == None
|
||||
@@ -227,3 +227,86 @@ class TestMain(TestCase):
|
||||
for i in a:
|
||||
for j in b:
|
||||
assert {"a": i, "b": j} in configs
|
||||
|
||||
def test_agent_reporters(self):
|
||||
"""An environment should be able to set its own reporters"""
|
||||
class Noop2(agents.Noop):
|
||||
pass
|
||||
|
||||
e = Environment()
|
||||
e.add_agent(agents.Noop)
|
||||
e.add_agent(Noop2)
|
||||
e.add_agent_reporter("now")
|
||||
e.add_agent_reporter("base", lambda a: "base", agent_class=agents.Noop)
|
||||
e.add_agent_reporter("subclass", lambda a:"subclass", agent_class=Noop2)
|
||||
e.step()
|
||||
|
||||
# Step 0 is not present because we added the reporters
|
||||
# after initialization.
|
||||
df = e.agent_df()
|
||||
assert "now" in df.columns
|
||||
assert "base" in df.columns
|
||||
assert "subclass" in df.columns
|
||||
assert df["now"][(1,0)] == 1
|
||||
assert df["now"][(1,1)] == 1
|
||||
assert df["base"][(1,0)] == "base"
|
||||
assert df["base"][(1,1)] == "base"
|
||||
assert df["subclass"][(1,0)] is None
|
||||
assert df["subclass"][(1,1)] == "subclass"
|
||||
|
||||
def test_remove_agent(self):
|
||||
"""An agent that is scheduled should be removed from the schedule"""
|
||||
model = Environment()
|
||||
model.add_agent(agents.Noop)
|
||||
model.step()
|
||||
model.remove_agent(model.agents[0])
|
||||
assert not model.agents
|
||||
when = model.step()
|
||||
assert when == None
|
||||
assert not model.running
|
||||
|
||||
def test_remove_agent(self):
|
||||
"""An agent that is scheduled should be removed from the schedule"""
|
||||
|
||||
allagents = []
|
||||
class Removed(agents.BaseAgent):
|
||||
def step(self):
|
||||
nonlocal allagents
|
||||
assert self.alive
|
||||
assert self in self.model.agents
|
||||
for agent in allagents:
|
||||
self.model.remove_agent(agent)
|
||||
|
||||
model = Environment()
|
||||
a1 = model.add_agent(Removed)
|
||||
a2 = model.add_agent(Removed)
|
||||
allagents = [a1, a2]
|
||||
model.step()
|
||||
assert not model.agents
|
||||
|
||||
def test_agent_df(self):
|
||||
'''The agent dataframe should have the right columns'''
|
||||
|
||||
class PeterPan(agents.BaseAgent):
|
||||
steps = 0
|
||||
|
||||
def step(self):
|
||||
self.steps += 1
|
||||
return self.delay(0)
|
||||
|
||||
class AgentDF(Environment):
|
||||
def init(self):
|
||||
self.add_agent(PeterPan)
|
||||
self.add_agent_reporter("steps")
|
||||
|
||||
e = AgentDF()
|
||||
df = e.agent_df()
|
||||
assert df["steps"][(0,0)] == 0
|
||||
e.step()
|
||||
df = e.agent_df()
|
||||
assert len(df) == 1
|
||||
assert df["steps"][(0,0)] == 1
|
||||
e.step()
|
||||
df = e.agent_df()
|
||||
assert len(df) == 1
|
||||
assert df["steps"][(0,0)] == 2
|