mirror of
https://github.com/gsi-upm/soil
synced 2024-12-22 16:28:11 +00:00
WIP
This commit is contained in:
parent
f49be3af68
commit
3041156f19
File diff suppressed because one or more lines are too long
@ -43,7 +43,7 @@ class Journey:
|
||||
driver: Optional[Driver] = None
|
||||
|
||||
|
||||
class City(EventedEnvironment):
|
||||
class City(Environment):
|
||||
"""
|
||||
An environment with a grid where drivers and passengers will be placed.
|
||||
|
||||
@ -85,11 +85,12 @@ class Driver(Evented, FSM):
|
||||
journey = None
|
||||
earnings = 0
|
||||
|
||||
def on_receive(self, msg, sender):
|
||||
"""This is not a state. It will run (and block) every time process_messages is invoked"""
|
||||
if self.journey is None and isinstance(msg, Journey) and msg.driver is None:
|
||||
msg.driver = self
|
||||
self.journey = msg
|
||||
# TODO: remove
|
||||
# def on_receive(self, msg, sender):
|
||||
# """This is not a state. It will run (and block) every time process_messages is invoked"""
|
||||
# if self.journey is None and isinstance(msg, Journey) and msg.driver is None:
|
||||
# msg.driver = self
|
||||
# self.journey = msg
|
||||
|
||||
def check_passengers(self):
|
||||
"""If there are no more passengers, stop forever"""
|
||||
@ -104,7 +105,7 @@ class Driver(Evented, FSM):
|
||||
if not self.check_passengers():
|
||||
return self.die("No passengers left")
|
||||
self.journey = None
|
||||
while self.journey is None: # No potential journeys detected (see on_receive)
|
||||
while self.journey is None: # No potential journeys detected
|
||||
if target is None or not self.move_towards(target):
|
||||
target = self.random.choice(
|
||||
self.model.grid.get_neighborhood(self.pos, moore=False)
|
||||
@ -113,7 +114,7 @@ class Driver(Evented, FSM):
|
||||
if not self.check_passengers():
|
||||
return self.die("No passengers left")
|
||||
# This will call on_receive behind the scenes, and the agent's status will be updated
|
||||
self.process_messages()
|
||||
|
||||
await self.delay(30) # Wait at least 30 seconds before checking again
|
||||
|
||||
try:
|
||||
@ -167,12 +168,13 @@ class Driver(Evented, FSM):
|
||||
class Passenger(Evented, FSM):
|
||||
pos = None
|
||||
|
||||
def on_receive(self, msg, sender):
|
||||
"""This is not a state. It will be run synchronously every time `process_messages` is run"""
|
||||
# TODO: Remove
|
||||
# def on_receive(self, msg, sender):
|
||||
# """This is not a state. It will be run synchronously every time `process_messages` is run"""
|
||||
|
||||
if isinstance(msg, Journey):
|
||||
self.journey = msg
|
||||
return msg
|
||||
# if isinstance(msg, Journey):
|
||||
# self.journey = msg
|
||||
# return msg
|
||||
|
||||
@default_state
|
||||
@state
|
||||
@ -192,17 +194,34 @@ class Passenger(Evented, FSM):
|
||||
timeout = 60
|
||||
expiration = self.now + timeout
|
||||
self.info(f"Asking for journey at: { self.pos }")
|
||||
self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver)
|
||||
self.broadcast(journey, ttl=timeout, agent_class=Driver)
|
||||
while not self.journey:
|
||||
self.debug(f"Waiting for responses at: { self.pos }")
|
||||
try:
|
||||
# This will call process_messages behind the scenes, and the agent's status will be updated
|
||||
# If you want to avoid that, you can call it with: check=False
|
||||
await self.received(expiration=expiration, delay=10)
|
||||
offers = await self.received(expiration=expiration, delay=10)
|
||||
accepted = None
|
||||
for event in offers:
|
||||
offer = event.payload
|
||||
if isinstance(offer, Journey):
|
||||
self.journey = offer
|
||||
assert isinstance(event.sender, Driver)
|
||||
try:
|
||||
answer = await event.sender.ask(True, sender=self, timeout=60, delay=5)
|
||||
if answer:
|
||||
accepted = offer
|
||||
self.journey = offer
|
||||
break
|
||||
except events.TimedOut:
|
||||
pass
|
||||
if accepted:
|
||||
for event in offers:
|
||||
if event.payload != accepted:
|
||||
event.sender.tell(False, timeout=60, delay=5)
|
||||
|
||||
except events.TimedOut:
|
||||
self.info(f"Still no response. Waiting at: { self.pos }")
|
||||
self.model.broadcast(
|
||||
journey, ttl=timeout, sender=self, agent_class=Driver
|
||||
self.broadcast(
|
||||
journey, ttl=timeout, agent_class=Driver
|
||||
)
|
||||
expiration = self.now + timeout
|
||||
self.info(f"Got a response! Waiting for driver")
|
||||
|
@ -1,4 +1,4 @@
|
||||
from soil.agents import FSM, NetworkAgent, state, default_state, prob
|
||||
from soil.agents import FSM, NetworkAgent, state, default_state
|
||||
from soil.parameters import *
|
||||
import logging
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
There are two similar implementations of this simulation.
|
||||
|
||||
- `basic`. Using simple primites
|
||||
- `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions.
|
||||
- `improved`. Using more advanced features such as the delays to avoid unnecessary computations (i.e., skip steps).
|
||||
|
||||
The examples can be run directly in the terminal, and they accept command like arguments.
|
||||
For example, to enable the CSV exporter and the Summary exporter, while setting `max_time` to `100` and `seed` to `CustomSeed`:
|
||||
|
@ -1,22 +1,36 @@
|
||||
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment, Simulation
|
||||
from enum import Enum
|
||||
from collections import Counter
|
||||
import logging
|
||||
from soil import Evented, FSM, state, default_state, BaseAgent, NetworkAgent, Environment, parameters, report, TimedOut
|
||||
import math
|
||||
|
||||
from rabbits_basic_sim import RabbitEnv
|
||||
from soilent import Scheduler
|
||||
|
||||
|
||||
class RabbitsImprovedEnv(RabbitEnv):
|
||||
class RabbitsImprovedEnv(Environment):
|
||||
prob_death: parameters.probability = 1e-3
|
||||
schedule_class = Scheduler
|
||||
|
||||
def init(self):
|
||||
"""Initialize the environment with the new versions of the agents"""
|
||||
a1 = self.add_node(Male)
|
||||
a2 = self.add_node(Female)
|
||||
a1.add_edge(a2)
|
||||
self.add_agent(RandomAccident)
|
||||
|
||||
@report
|
||||
@property
|
||||
def num_rabbits(self):
|
||||
return self.count_agents(agent_class=Rabbit)
|
||||
|
||||
class Rabbit(FSM, NetworkAgent):
|
||||
@report
|
||||
@property
|
||||
def num_males(self):
|
||||
return self.count_agents(agent_class=Male)
|
||||
|
||||
@report
|
||||
@property
|
||||
def num_females(self):
|
||||
return self.count_agents(agent_class=Female)
|
||||
|
||||
|
||||
class Rabbit(Evented, FSM, NetworkAgent):
|
||||
|
||||
sexual_maturity = 30
|
||||
life_expectancy = 300
|
||||
@ -31,42 +45,40 @@ class Rabbit(FSM, NetworkAgent):
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.info("I am a newborn.")
|
||||
self.debug("I am a newborn.")
|
||||
self.birth = self.now
|
||||
self.offspring = 0
|
||||
return self.youngling.delay(self.sexual_maturity - self.age)
|
||||
return self.youngling
|
||||
|
||||
@state
|
||||
def youngling(self):
|
||||
if self.age >= self.sexual_maturity:
|
||||
self.info(f"I am fertile! My age is {self.age}")
|
||||
async def youngling(self):
|
||||
self.debug("I am a youngling.")
|
||||
await self.delay(self.sexual_maturity - self.age)
|
||||
assert self.age >= self.sexual_maturity
|
||||
self.debug(f"I am fertile! My age is {self.age}")
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
raise Exception("Each subclass should define its fertile state")
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
self.die()
|
||||
|
||||
|
||||
class Male(Rabbit):
|
||||
max_females = 5
|
||||
mating_prob = 0.001
|
||||
mating_prob = 0.005
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
if self.age > self.life_expectancy:
|
||||
return self.dead
|
||||
return self.die()
|
||||
|
||||
# Males try to mate
|
||||
for f in self.model.agents(
|
||||
agent_class=Female, state_id=Female.fertile.id, limit=self.max_females
|
||||
):
|
||||
self.debug("FOUND A FEMALE: ", repr(f), self.mating_prob)
|
||||
self.debug(f"FOUND A FEMALE: {repr(f)}. Mating with prob {self.mating_prob}")
|
||||
if self.prob(self["mating_prob"]):
|
||||
f.impregnate(self)
|
||||
f.tell(self.unique_id, sender=self, timeout=1)
|
||||
break # Do not try to impregnate other females
|
||||
|
||||
|
||||
@ -75,78 +87,91 @@ class Female(Rabbit):
|
||||
conception = None
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
async def fertile(self):
|
||||
# Just wait for a Male
|
||||
if self.age > self.life_expectancy:
|
||||
return self.dead
|
||||
if self.conception is not None:
|
||||
return self.pregnant
|
||||
|
||||
@property
|
||||
def pregnancy(self):
|
||||
if self.conception is None:
|
||||
return None
|
||||
return self.now - self.conception
|
||||
|
||||
def impregnate(self, male):
|
||||
self.info(f"impregnated by {repr(male)}")
|
||||
try:
|
||||
timeout = self.life_expectancy - self.age
|
||||
while timeout > 0:
|
||||
mates = await self.received(timeout=timeout)
|
||||
# assert self.age <= self.life_expectancy
|
||||
for mate in mates:
|
||||
try:
|
||||
male = self.model.agents[mate.payload]
|
||||
except ValueError:
|
||||
continue
|
||||
self.debug(f"impregnated by {repr(male)}")
|
||||
self.mate = male
|
||||
self.conception = self.now
|
||||
self.number_of_babies = int(8 + 4 * self.random.random())
|
||||
|
||||
@state
|
||||
def pregnant(self):
|
||||
self.debug("I am pregnant")
|
||||
|
||||
if self.age > self.life_expectancy:
|
||||
self.info("Dying before giving birth")
|
||||
self.conception = self.now
|
||||
return self.pregnant
|
||||
except TimedOut:
|
||||
pass
|
||||
return self.die()
|
||||
|
||||
if self.pregnancy >= self.gestation:
|
||||
self.info("Having {} babies".format(self.number_of_babies))
|
||||
@state
|
||||
async def pregnant(self):
|
||||
self.debug("I am pregnant")
|
||||
# assert self.mate is not None
|
||||
|
||||
when = min(self.gestation, self.life_expectancy - self.age)
|
||||
if when < 0:
|
||||
return self.die()
|
||||
await self.delay(when)
|
||||
|
||||
if self.age > self.life_expectancy:
|
||||
self.debug("Dying before giving birth")
|
||||
return self.die()
|
||||
|
||||
# assert self.now - self.conception >= self.gestation
|
||||
if not self.alive:
|
||||
return self.die()
|
||||
|
||||
self.debug("Having {} babies".format(self.number_of_babies))
|
||||
for i in range(self.number_of_babies):
|
||||
state = {}
|
||||
agent_class = self.random.choice([Male, Female])
|
||||
child = self.model.add_node(agent_class=agent_class, **state)
|
||||
child.add_edge(self)
|
||||
if self.mate:
|
||||
try:
|
||||
child.add_edge(self.mate)
|
||||
self.mate.offspring += 1
|
||||
else:
|
||||
self.model.agents[self.mate].offspring += 1
|
||||
except ValueError:
|
||||
self.debug("The father has passed away")
|
||||
|
||||
self.offspring += 1
|
||||
self.mate = None
|
||||
self.conception = None
|
||||
return self.fertile
|
||||
|
||||
def die(self):
|
||||
if self.pregnancy is not None:
|
||||
self.info("A mother has died carrying a baby!!")
|
||||
if self.conception is not None:
|
||||
self.debug("A mother has died carrying a baby!!")
|
||||
return super().die()
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
# Default value, but the value from the environment takes precedence
|
||||
prob_death = 1e-3
|
||||
|
||||
def step(self):
|
||||
rabbits_alive = self.model.G.number_of_nodes()
|
||||
|
||||
if not rabbits_alive:
|
||||
return self.die()
|
||||
alive = self.get_agents(agent_class=Rabbit, alive=True)
|
||||
|
||||
prob_death = self.model.get("prob_death", 1e-100) * math.floor(
|
||||
math.log10(max(1, rabbits_alive))
|
||||
)
|
||||
if not alive:
|
||||
return self.die("No more rabbits to kill")
|
||||
|
||||
num_alive = len(alive)
|
||||
prob_death = min(1, self.prob_death * num_alive/10)
|
||||
self.debug("Killing some rabbits with prob={}!".format(prob_death))
|
||||
for i in self.iter_agents(agent_class=Rabbit):
|
||||
|
||||
for i in alive:
|
||||
if i.state_id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
self.info("I killed a rabbit: {}".format(i.id))
|
||||
rabbits_alive -= 1
|
||||
self.debug("I killed a rabbit: {}".format(i.unique_id))
|
||||
num_alive -= 1
|
||||
i.die()
|
||||
self.debug("Rabbits alive: {}".format(rabbits_alive))
|
||||
self.debug("Rabbits alive: {}".format(num_alive))
|
||||
|
||||
|
||||
sim = Simulation(model=RabbitsImprovedEnv, max_time=100, seed="MySeed", iterations=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim.run()
|
||||
RabbitsImprovedEnv.run(max_time=1000, seed="MySeed", iterations=1)
|
@ -1,11 +1,9 @@
|
||||
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment, Simulation, report, parameters as params
|
||||
from collections import Counter
|
||||
import logging
|
||||
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment, report, parameters as params
|
||||
import math
|
||||
|
||||
|
||||
class RabbitEnv(Environment):
|
||||
prob_death: params.probability = 1e-100
|
||||
prob_death: params.probability = 1e-3
|
||||
|
||||
def init(self):
|
||||
a1 = self.add_node(Male)
|
||||
@ -37,7 +35,7 @@ class Rabbit(NetworkAgent, FSM):
|
||||
@default_state
|
||||
@state
|
||||
def newborn(self):
|
||||
self.info("I am a newborn.")
|
||||
self.debug("I am a newborn.")
|
||||
self.age = 0
|
||||
self.offspring = 0
|
||||
return self.youngling
|
||||
@ -46,7 +44,7 @@ class Rabbit(NetworkAgent, FSM):
|
||||
def youngling(self):
|
||||
self.age += 1
|
||||
if self.age >= self.sexual_maturity:
|
||||
self.info(f"I am fertile! My age is {self.age}")
|
||||
self.debug(f"I am fertile! My age is {self.age}")
|
||||
return self.fertile
|
||||
|
||||
@state
|
||||
@ -60,7 +58,7 @@ class Rabbit(NetworkAgent, FSM):
|
||||
|
||||
class Male(Rabbit):
|
||||
max_females = 5
|
||||
mating_prob = 0.001
|
||||
mating_prob = 0.005
|
||||
|
||||
@state
|
||||
def fertile(self):
|
||||
@ -70,9 +68,8 @@ class Male(Rabbit):
|
||||
return self.dead
|
||||
|
||||
# Males try to mate
|
||||
for f in self.model.agents(
|
||||
agent_class=Female, state_id=Female.fertile.id, limit=self.max_females
|
||||
):
|
||||
for f in self.model.agents.filter(
|
||||
agent_class=Female, state_id=Female.fertile.id).limit(self.max_females):
|
||||
self.debug("FOUND A FEMALE: ", repr(f), self.mating_prob)
|
||||
if self.prob(self["mating_prob"]):
|
||||
f.impregnate(self)
|
||||
@ -93,14 +90,14 @@ class Female(Rabbit):
|
||||
return self.pregnant
|
||||
|
||||
def impregnate(self, male):
|
||||
self.info(f"impregnated by {repr(male)}")
|
||||
self.debug(f"impregnated by {repr(male)}")
|
||||
self.mate = male
|
||||
self.pregnancy = 0
|
||||
self.number_of_babies = int(8 + 4 * self.random.random())
|
||||
|
||||
@state
|
||||
def pregnant(self):
|
||||
self.info("I am pregnant")
|
||||
self.debug("I am pregnant")
|
||||
self.age += 1
|
||||
|
||||
if self.age >= self.life_expectancy:
|
||||
@ -110,7 +107,7 @@ class Female(Rabbit):
|
||||
self.pregnancy += 1
|
||||
return
|
||||
|
||||
self.info("Having {} babies".format(self.number_of_babies))
|
||||
self.debug("Having {} babies".format(self.number_of_babies))
|
||||
for i in range(self.number_of_babies):
|
||||
state = {}
|
||||
agent_class = self.random.choice([Male, Female])
|
||||
@ -129,33 +126,30 @@ class Female(Rabbit):
|
||||
|
||||
def die(self):
|
||||
if "pregnancy" in self and self["pregnancy"] > -1:
|
||||
self.info("A mother has died carrying a baby!!")
|
||||
self.debug("A mother has died carrying a baby!!")
|
||||
return super().die()
|
||||
|
||||
|
||||
class RandomAccident(BaseAgent):
|
||||
prob_death = None
|
||||
def step(self):
|
||||
rabbits_alive = self.model.G.number_of_nodes()
|
||||
alive = self.get_agents(agent_class=Rabbit, alive=True)
|
||||
|
||||
if not rabbits_alive:
|
||||
return self.die()
|
||||
if not alive:
|
||||
return self.die("No more rabbits to kill")
|
||||
|
||||
prob_death = self.model.prob_death * math.floor(
|
||||
math.log10(max(1, rabbits_alive))
|
||||
)
|
||||
num_alive = len(alive)
|
||||
prob_death = min(1, self.prob_death * num_alive/10)
|
||||
self.debug("Killing some rabbits with prob={}!".format(prob_death))
|
||||
|
||||
for i in self.get_agents(agent_class=Rabbit):
|
||||
if i.state_id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
self.info("I killed a rabbit: {}".format(i.id))
|
||||
rabbits_alive -= 1
|
||||
self.debug("I killed a rabbit: {}".format(i.unique_id))
|
||||
num_alive -= 1
|
||||
i.die()
|
||||
self.debug("Rabbits alive: {}".format(rabbits_alive))
|
||||
self.debug("Rabbits alive: {}".format(num_alive))
|
||||
|
||||
|
||||
|
||||
sim = Simulation(model=RabbitEnv, max_time=100, seed="MySeed", iterations=1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
sim.run()
|
||||
RabbitEnv.run(max_time=1000, seed="MySeed", iterations=1)
|
@ -1,5 +1,5 @@
|
||||
import networkx as nx
|
||||
from soil.agents import NetworkAgent, FSM, custom, state, default_state
|
||||
from soil.agents import FSM, state, default_state
|
||||
from soil.agents.geo import Geo
|
||||
from soil import Environment, Simulation
|
||||
from soil.parameters import *
|
||||
|
@ -1 +1 @@
|
||||
1.0.0rc3
|
||||
1.0.0rc6
|
||||
|
@ -19,7 +19,7 @@ from pathlib import Path
|
||||
from .agents import *
|
||||
from . import agents
|
||||
from .simulation import *
|
||||
from .environment import Environment, EventedEnvironment
|
||||
from .environment import Environment
|
||||
from .datacollection import SoilCollector
|
||||
from . import serialization
|
||||
from .utils import logger
|
||||
@ -117,13 +117,13 @@ def main(
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_time",
|
||||
default="-1",
|
||||
default="",
|
||||
help="Set maximum time for the simulation to run. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default="-1",
|
||||
default="",
|
||||
help="Set maximum number of steps for the simulation to run.",
|
||||
)
|
||||
|
||||
@ -249,9 +249,12 @@ def main(
|
||||
if args.only_convert:
|
||||
print(sim.to_yaml())
|
||||
continue
|
||||
max_time = float(args.max_time) if args.max_time != "-1" else None
|
||||
max_steps = float(args.max_steps) if args.max_steps != "-1" else None
|
||||
res.append(sim.run(max_time=max_time, max_steps=max_steps))
|
||||
d = {}
|
||||
if args.max_time:
|
||||
d["max_time"] = float(args.max_time)
|
||||
if args.max_steps:
|
||||
d["max_steps"] = int(args.max_steps)
|
||||
res.append(sim.run(**d))
|
||||
|
||||
except Exception as ex:
|
||||
if args.pdb:
|
||||
|
@ -1,107 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import MutableMapping, Mapping, Set
|
||||
from abc import ABCMeta
|
||||
from copy import deepcopy, copy
|
||||
from functools import partial, wraps
|
||||
from itertools import islice, chain
|
||||
from collections.abc import MutableMapping
|
||||
from copy import deepcopy
|
||||
import inspect
|
||||
import types
|
||||
import textwrap
|
||||
import networkx as nx
|
||||
import warnings
|
||||
import sys
|
||||
|
||||
from typing import Any
|
||||
from mesa import Agent as MesaAgent
|
||||
|
||||
from mesa import Agent as MesaAgent, Model
|
||||
from typing import Dict, List
|
||||
from .. import utils, time
|
||||
|
||||
from .. import serialization, network, utils, time, config
|
||||
from .meta import MetaAgent
|
||||
|
||||
|
||||
IGNORED_FIELDS = ("model", "logger")
|
||||
|
||||
|
||||
def decorate_generator_step(func, name):
|
||||
@wraps(func)
|
||||
def decorated(self):
|
||||
while True:
|
||||
if self._coroutine is None:
|
||||
self._coroutine = func(self)
|
||||
try:
|
||||
if self._last_except:
|
||||
val = self._coroutine.throw(self._last_except)
|
||||
else:
|
||||
val = self._coroutine.send(self._last_return)
|
||||
except StopIteration as ex:
|
||||
self._coroutine = None
|
||||
val = ex.value
|
||||
finally:
|
||||
self._last_return = None
|
||||
self._last_except = None
|
||||
return float(val) if val is not None else val
|
||||
return decorated
|
||||
|
||||
|
||||
def decorate_normal_func(func, name):
|
||||
@wraps(func)
|
||||
def decorated(self):
|
||||
val = func(self)
|
||||
return float(val) if val is not None else val
|
||||
return decorated
|
||||
|
||||
|
||||
class MetaAgent(ABCMeta):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
defaults = {}
|
||||
|
||||
# Re-use defaults from inherited classes
|
||||
for i in bases:
|
||||
if isinstance(i, MetaAgent):
|
||||
defaults.update(i._defaults)
|
||||
|
||||
new_nmspc = {
|
||||
"_defaults": defaults,
|
||||
}
|
||||
|
||||
for attr, func in namespace.items():
|
||||
if attr == "step":
|
||||
if inspect.isgeneratorfunction(func) or inspect.iscoroutinefunction(func):
|
||||
func = decorate_generator_step(func, attr)
|
||||
new_nmspc.update({
|
||||
"_last_return": None,
|
||||
"_last_except": None,
|
||||
"_coroutine": None,
|
||||
})
|
||||
elif inspect.isasyncgenfunction(func):
|
||||
raise ValueError("Illegal step function: {}. It probably mixes both async/await and yield".format(func))
|
||||
elif inspect.isfunction(func):
|
||||
func = decorate_normal_func(func, attr)
|
||||
else:
|
||||
raise ValueError("Illegal step function: {}".format(func))
|
||||
new_nmspc[attr] = func
|
||||
elif (
|
||||
isinstance(func, types.FunctionType)
|
||||
or isinstance(func, property)
|
||||
or isinstance(func, classmethod)
|
||||
or attr[0] == "_"
|
||||
):
|
||||
new_nmspc[attr] = func
|
||||
elif attr == "defaults":
|
||||
defaults.update(func)
|
||||
elif inspect.isfunction(func):
|
||||
new_nmspc[attr] = func
|
||||
else:
|
||||
defaults[attr] = copy(func)
|
||||
|
||||
|
||||
# Add attributes for their use in the decorated functions
|
||||
return super().__new__(mcls, name, bases, new_nmspc)
|
||||
|
||||
|
||||
class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
"""
|
||||
A special type of Mesa Agent that:
|
||||
@ -154,11 +70,11 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
return hash(self.unique_id)
|
||||
|
||||
def prob(self, probability):
|
||||
return prob(probability, self.model.random)
|
||||
return utils.prob(probability, self.model.random)
|
||||
|
||||
@classmethod
|
||||
def w(cls, **kwargs):
|
||||
return custom(cls, **kwargs)
|
||||
return utils.custom(cls, **kwargs)
|
||||
|
||||
# TODO: refactor to clean up mesa compatibility
|
||||
@property
|
||||
@ -168,21 +84,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
print(msg, file=sys.stderr)
|
||||
return self.unique_id
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, model, attrs, warn_extra=True):
|
||||
ignored = {}
|
||||
args = {}
|
||||
for k, v in attrs.items():
|
||||
if k in inspect.signature(cls).parameters:
|
||||
args[k] = v
|
||||
else:
|
||||
ignored[k] = v
|
||||
if ignored and warn_extra:
|
||||
utils.logger.info(
|
||||
f"Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }"
|
||||
)
|
||||
return cls(model=model, **args)
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return getattr(self, key)
|
||||
@ -302,399 +203,23 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
return time.Delay(delay)
|
||||
|
||||
|
||||
def prob(prob, random):
|
||||
"""
|
||||
A true/False uniform distribution with a given probability.
|
||||
To be used like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
if prob(0.3):
|
||||
do_something()
|
||||
|
||||
"""
|
||||
r = random.random()
|
||||
return r < prob
|
||||
|
||||
|
||||
def calculate_distribution(network_agents=None, agent_class=None):
|
||||
"""
|
||||
Calculate the threshold values (thresholds for a uniform distribution)
|
||||
of an agent distribution given the weights of each agent type.
|
||||
|
||||
The input has this form: ::
|
||||
|
||||
[
|
||||
{'agent_class': 'agent_class_1',
|
||||
'weight': 0.2,
|
||||
'state': {
|
||||
'id': 0
|
||||
}
|
||||
},
|
||||
{'agent_class': 'agent_class_2',
|
||||
'weight': 0.8,
|
||||
'state': {
|
||||
'id': 1
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
In this example, 20% of the nodes will be marked as type
|
||||
'agent_class_1'.
|
||||
"""
|
||||
if network_agents:
|
||||
network_agents = [
|
||||
deepcopy(agent) for agent in network_agents if not hasattr(agent, "id")
|
||||
]
|
||||
elif agent_class:
|
||||
network_agents = [{"agent_class": agent_class}]
|
||||
else:
|
||||
raise ValueError("Specify a distribution or a default agent type")
|
||||
|
||||
# Fix missing weights and incompatible types
|
||||
for x in network_agents:
|
||||
x["weight"] = float(x.get("weight", 1))
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x["weight"] for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
if "ids" in v:
|
||||
continue
|
||||
upper = acc + (v["weight"] / total)
|
||||
v["threshold"] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
|
||||
|
||||
def _serialize_type(agent_class, known_modules=[], **kwargs):
|
||||
if isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known_modules += ["soil.agents"]
|
||||
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[
|
||||
1
|
||||
] # Get the name of the class
|
||||
|
||||
|
||||
def _deserialize_type(agent_class, known_modules=[]):
|
||||
if not isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
||||
agent_class = serialization.deserializer(agent_class, known_modules=known)
|
||||
return agent_class
|
||||
|
||||
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents."""
|
||||
|
||||
__slots__ = ("_agents",)
|
||||
|
||||
def __init__(self, agents):
|
||||
self._agents = agents
|
||||
|
||||
def __getstate__(self):
|
||||
return {"_agents": self._agents}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._agents = state["_agents"]
|
||||
|
||||
# Mapping methods
|
||||
def __len__(self):
|
||||
return len(self._agents)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._agents.values()
|
||||
|
||||
def __getitem__(self, agent_id):
|
||||
if isinstance(agent_id, slice):
|
||||
raise ValueError(f"Slicing is not supported")
|
||||
if agent_id in self._agents:
|
||||
return self._agents[agent_id]
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
def filter(self, *args, **kwargs):
|
||||
yield from filter_agents(self._agents, *args, **kwargs)
|
||||
|
||||
def one(self, *args, **kwargs):
|
||||
return next(filter_agents(self._agents, *args, **kwargs))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return list(self.filter(*args, **kwargs))
|
||||
|
||||
def __contains__(self, agent_id):
|
||||
return agent_id in self._agents
|
||||
|
||||
def __str__(self):
|
||||
return str(list(unique_id for unique_id in self.keys()))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def filter_agents(
|
||||
agents: dict,
|
||||
*id_args,
|
||||
unique_id=None,
|
||||
state_id=None,
|
||||
agent_class=None,
|
||||
ignore=None,
|
||||
state=None,
|
||||
limit=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||
"""
|
||||
assert isinstance(agents, dict)
|
||||
|
||||
ids = []
|
||||
|
||||
if unique_id is not None:
|
||||
if isinstance(unique_id, list):
|
||||
ids += unique_id
|
||||
else:
|
||||
ids.append(unique_id)
|
||||
|
||||
if id_args:
|
||||
ids += id_args
|
||||
|
||||
if ids:
|
||||
f = (agents[aid] for aid in ids if aid in agents)
|
||||
else:
|
||||
f = agents.values()
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
|
||||
if agent_class is not None:
|
||||
agent_class = _deserialize_type(agent_class)
|
||||
try:
|
||||
agent_class = tuple(agent_class)
|
||||
except TypeError:
|
||||
agent_class = tuple([agent_class])
|
||||
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
if state_id is not None:
|
||||
f = filter(lambda agent: agent.get("state_id", None) in state_id, f)
|
||||
|
||||
if agent_class is not None:
|
||||
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
||||
|
||||
state = state or dict()
|
||||
state.update(kwargs)
|
||||
|
||||
for k, vs in state.items():
|
||||
if not isinstance(vs, list):
|
||||
vs = [vs]
|
||||
f = filter(lambda agent: any(getattr(agent, k, None) == v for v in vs), f)
|
||||
|
||||
if limit is not None:
|
||||
f = islice(f, limit)
|
||||
|
||||
yield from f
|
||||
|
||||
|
||||
def from_config(
|
||||
cfg: config.AgentConfig, random, topology: nx.Graph = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
||||
with the parameters that the environment will use to construct each agent.
|
||||
|
||||
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
||||
time of calling this function, such as `unique_id`.
|
||||
"""
|
||||
default = cfg or config.AgentConfig()
|
||||
if not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
|
||||
agents = []
|
||||
|
||||
assigned_total = 0
|
||||
assigned_network = 0
|
||||
|
||||
if cfg.fixed is not None:
|
||||
agents, assigned_total, assigned_network = _from_fixed(
|
||||
cfg.fixed, topology=cfg.topology, default=cfg
|
||||
)
|
||||
|
||||
n = cfg.n
|
||||
|
||||
if cfg.distribution:
|
||||
topo_size = len(topology) if topology else 0
|
||||
|
||||
networked = []
|
||||
total = []
|
||||
|
||||
for d in cfg.distribution:
|
||||
if d.strategy == config.Strategy.topology:
|
||||
topo = d.topology if ("topology" in d.__fields_set__) else cfg.topology
|
||||
if not topo:
|
||||
raise ValueError(
|
||||
'The "topology" strategy only works if the topology parameter is set to True'
|
||||
)
|
||||
if not topo_size:
|
||||
raise ValueError(
|
||||
f"Topology does not have enough free nodes to assign one to the agent"
|
||||
)
|
||||
|
||||
networked.append(d)
|
||||
|
||||
if d.strategy == config.Strategy.total:
|
||||
if not cfg.n:
|
||||
raise ValueError(
|
||||
'Cannot use the "total" strategy without providing the total number of agents'
|
||||
)
|
||||
total.append(d)
|
||||
|
||||
if networked:
|
||||
new_agents = _from_distro(
|
||||
networked,
|
||||
n=topo_size - assigned_network,
|
||||
topology=topo,
|
||||
default=cfg,
|
||||
random=random,
|
||||
)
|
||||
assigned_total += len(new_agents)
|
||||
assigned_network += len(new_agents)
|
||||
agents += new_agents
|
||||
|
||||
if total:
|
||||
remaining = n - assigned_total
|
||||
agents += _from_distro(total, n=remaining, default=cfg, random=random)
|
||||
|
||||
if assigned_network < topo_size:
|
||||
utils.logger.warn(
|
||||
f"The total number of agents does not match the total number of nodes in "
|
||||
"every topology. This may be due to a definition error: assigned: "
|
||||
f"{ assigned } total size: { topo_size }"
|
||||
)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def _from_fixed(
|
||||
lst: List[config.FixedAgentConfig],
|
||||
topology: bool,
|
||||
default: config.SingleAgentConfig,
|
||||
) -> List[Dict[str, Any]]:
|
||||
agents = []
|
||||
|
||||
counts_total = 0
|
||||
counts_network = 0
|
||||
|
||||
for fixed in lst:
|
||||
agent = {}
|
||||
if default:
|
||||
agent = default.state.copy()
|
||||
agent.update(fixed.state)
|
||||
cls = serialization.deserialize(
|
||||
fixed.agent_class or (default and default.agent_class)
|
||||
)
|
||||
agent["agent_class"] = cls
|
||||
topo = (
|
||||
fixed.topology
|
||||
if ("topology" in fixed.__fields_set__)
|
||||
else topology or default.topology
|
||||
)
|
||||
|
||||
if topo:
|
||||
agent["topology"] = True
|
||||
counts_network += 1
|
||||
if not fixed.hidden:
|
||||
counts_total += 1
|
||||
agents.append(agent)
|
||||
|
||||
return agents, counts_total, counts_network
|
||||
|
||||
|
||||
def _from_distro(
|
||||
distro: List[config.AgentDistro],
|
||||
n: int,
|
||||
default: config.SingleAgentConfig,
|
||||
random,
|
||||
topology: str = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
||||
agents = []
|
||||
|
||||
if n is None:
|
||||
if any(lambda dist: dist.n is None, distro):
|
||||
raise ValueError(
|
||||
"You must provide a total number of agents, or the number of each type"
|
||||
)
|
||||
n = sum(dist.n for dist in distro)
|
||||
|
||||
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
|
||||
minw = min(weights)
|
||||
norm = list(weight / minw for weight in weights)
|
||||
total = sum(norm)
|
||||
chunk = n // total
|
||||
|
||||
# random.choices would be enough to get a weighted distribution. But it can vary a lot for smaller k
|
||||
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
|
||||
|
||||
# Calculate how many times each has to appear
|
||||
indices = list(
|
||||
chain.from_iterable([idx] * int(n * chunk) for (idx, n) in enumerate(norm))
|
||||
)
|
||||
|
||||
# Complete with random agents following the original weight distribution
|
||||
if len(indices) < n:
|
||||
indices += random.choices(
|
||||
list(range(len(distro))),
|
||||
weights=[d.weight for d in distro],
|
||||
k=n - len(indices),
|
||||
)
|
||||
|
||||
# Deserialize classes for efficiency
|
||||
classes = list(
|
||||
serialization.deserialize(i.agent_class or default.agent_class) for i in distro
|
||||
)
|
||||
|
||||
# Add them in random order
|
||||
random.shuffle(indices)
|
||||
|
||||
for idx in indices:
|
||||
d = distro[idx]
|
||||
agent = d.state.copy()
|
||||
cls = classes[idx]
|
||||
agent["agent_class"] = cls
|
||||
if default:
|
||||
agent.update(default.state)
|
||||
topology = (
|
||||
d.topology
|
||||
if ("topology" in d.__fields_set__)
|
||||
else topology or default.topology
|
||||
)
|
||||
if topology:
|
||||
agent["topology"] = topology
|
||||
agents.append(agent)
|
||||
|
||||
return agents
|
||||
class Noop(BaseAgent):
|
||||
def step(self):
|
||||
return
|
||||
|
||||
|
||||
from .network_agents import *
|
||||
from .fsm import *
|
||||
from .evented import *
|
||||
from typing import Optional
|
||||
from .view import *
|
||||
|
||||
|
||||
class Agent(NetworkAgent, FSM, EventedAgent):
|
||||
"""Default agent class, has both network and event capabilities"""
|
||||
|
||||
|
||||
from ..environment import NetworkEnvironment
|
||||
class Agent(FSM, EventedAgent, NetworkAgent):
|
||||
"""Default agent class, has network, FSM and event capabilities"""
|
||||
|
||||
|
||||
# Additional types of agents
|
||||
from .BassModel import *
|
||||
from .IndependentCascadeModel import *
|
||||
from .SISaModel import *
|
||||
from .CounterModel import *
|
||||
|
||||
|
||||
def custom(cls, **kwargs):
|
||||
"""Create a new class from a template class and keyword arguments"""
|
||||
return type(cls.__name__, (cls,), kwargs)
|
||||
|
@ -1,58 +1,34 @@
|
||||
from . import BaseAgent
|
||||
from ..events import Message, Tell, Ask, TimedOut
|
||||
from .. import environment, events
|
||||
from functools import partial
|
||||
from collections import deque
|
||||
from types import coroutine
|
||||
|
||||
# from soilent import Scheduler
|
||||
|
||||
|
||||
class EventedAgent(BaseAgent):
|
||||
# scheduler_class = Scheduler
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._inbox = deque()
|
||||
self._processed = 0
|
||||
assert isinstance(self.model, environment.EventedEnvironment), "EventedAgent requires an EventedEnvironment"
|
||||
self.model.register(self)
|
||||
|
||||
def on_receive(self, *args, **kwargs):
|
||||
pass
|
||||
def received(self, **kwargs):
|
||||
return self.model.received(self, **kwargs)
|
||||
|
||||
@coroutine
|
||||
def received(self, expiration=None, timeout=60, delay=1, process=True):
|
||||
if not expiration:
|
||||
expiration = self.now + timeout
|
||||
while self.now < expiration:
|
||||
if self._inbox:
|
||||
msgs = self._inbox
|
||||
if process:
|
||||
self.process_messages()
|
||||
return msgs
|
||||
yield self.delay(delay)
|
||||
raise TimedOut("No message received")
|
||||
def tell(self, msg, **kwargs):
|
||||
return self.model.tell(msg, recipient=self, **kwargs)
|
||||
|
||||
def tell(self, msg, sender=None):
|
||||
self._inbox.append(Tell(timestamp=self.now, payload=msg, sender=sender))
|
||||
def broadcast(self, msg, **kwargs):
|
||||
return self.model.broadcast(msg, sender=self, **kwargs)
|
||||
|
||||
@coroutine
|
||||
def ask(self, msg, expiration=None, timeout=None, delay=1):
|
||||
ask = Ask(timestamp=self.now, payload=msg, sender=self)
|
||||
self._inbox.append(ask)
|
||||
expiration = float("inf") if timeout is None else self.now + timeout
|
||||
while self.now < expiration:
|
||||
if ask.reply:
|
||||
return ask.reply
|
||||
yield self.delay(delay)
|
||||
raise TimedOut("No reply received")
|
||||
def ask(self, msg, **kwargs):
|
||||
return self.model.ask(msg, recipient=self, **kwargs)
|
||||
|
||||
def process_messages(self):
|
||||
valid = list()
|
||||
for msg in self._inbox:
|
||||
self._processed += 1
|
||||
if msg.expired(self.now):
|
||||
continue
|
||||
valid.append(msg)
|
||||
reply = self.on_receive(msg.payload, sender=msg.sender)
|
||||
if isinstance(msg, Ask):
|
||||
msg.reply = reply
|
||||
self._inbox.clear()
|
||||
return valid
|
||||
return self.model.process_messages(self.model.inbox_for(self))
|
||||
|
||||
|
||||
Evented = EventedAgent
|
||||
|
@ -140,11 +140,19 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
self._set_state(value)
|
||||
|
||||
def step(self):
|
||||
if self._state is None:
|
||||
if len(self._states) == 1:
|
||||
raise Exception("Agent class has no valid states: {}. Make sure to define states or define a custom step function".format(self.__class__.__name__))
|
||||
else:
|
||||
raise Exception("Invalid state (None) for agent {}".format(self))
|
||||
|
||||
self._check_alive()
|
||||
next_state = yield from self._state.step(self)
|
||||
|
||||
try:
|
||||
next_state, when = next_state
|
||||
self._set_state(next_state)
|
||||
return when
|
||||
except (TypeError, ValueError) as ex:
|
||||
try:
|
||||
self._set_state(next_state)
|
||||
@ -152,9 +160,6 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
except ValueError:
|
||||
return next_state
|
||||
|
||||
self._set_state(next_state)
|
||||
return when
|
||||
|
||||
def _set_state(self, state):
|
||||
if state is None:
|
||||
return
|
||||
|
87
soil/agents/meta.py
Normal file
87
soil/agents/meta.py
Normal file
@ -0,0 +1,87 @@
|
||||
from abc import ABCMeta
|
||||
from copy import copy
|
||||
from functools import wraps
|
||||
from .. import time
|
||||
|
||||
import types
|
||||
import inspect
|
||||
|
||||
def decorate_generator_step(func, name):
|
||||
@wraps(func)
|
||||
def decorated(self):
|
||||
if not self.alive:
|
||||
return time.INFINITY
|
||||
|
||||
if self._coroutine is None:
|
||||
self._coroutine = func(self)
|
||||
try:
|
||||
if self._last_except:
|
||||
val = self._coroutine.throw(self._last_except)
|
||||
else:
|
||||
val = self._coroutine.send(self._last_return)
|
||||
except StopIteration as ex:
|
||||
self._coroutine = None
|
||||
val = ex.value
|
||||
finally:
|
||||
self._last_return = None
|
||||
self._last_except = None
|
||||
return float(val) if val is not None else val
|
||||
return decorated
|
||||
|
||||
|
||||
def decorate_normal_step(func, name):
|
||||
@wraps(func)
|
||||
def decorated(self):
|
||||
# if not self.alive:
|
||||
# return time.INFINITY
|
||||
val = func(self)
|
||||
return float(val) if val is not None else val
|
||||
return decorated
|
||||
|
||||
|
||||
class MetaAgent(ABCMeta):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
defaults = {}
|
||||
|
||||
# Re-use defaults from inherited classes
|
||||
for i in bases:
|
||||
if isinstance(i, MetaAgent):
|
||||
defaults.update(i._defaults)
|
||||
|
||||
new_nmspc = {
|
||||
"_defaults": defaults,
|
||||
}
|
||||
|
||||
for attr, func in namespace.items():
|
||||
if attr == "step":
|
||||
if inspect.isgeneratorfunction(func) or inspect.iscoroutinefunction(func):
|
||||
func = decorate_generator_step(func, attr)
|
||||
new_nmspc.update({
|
||||
"_last_return": None,
|
||||
"_last_except": None,
|
||||
"_coroutine": None,
|
||||
})
|
||||
elif inspect.isasyncgenfunction(func):
|
||||
raise ValueError("Illegal step function: {}. It probably mixes both async/await and yield".format(func))
|
||||
elif inspect.isfunction(func):
|
||||
func = decorate_normal_step(func, attr)
|
||||
else:
|
||||
raise ValueError("Illegal step function: {}".format(func))
|
||||
new_nmspc[attr] = func
|
||||
elif (
|
||||
isinstance(func, types.FunctionType)
|
||||
or isinstance(func, property)
|
||||
or isinstance(func, classmethod)
|
||||
or attr[0] == "_"
|
||||
):
|
||||
new_nmspc[attr] = func
|
||||
elif attr == "defaults":
|
||||
defaults.update(func)
|
||||
elif inspect.isfunction(func):
|
||||
new_nmspc[attr] = func
|
||||
else:
|
||||
defaults[attr] = copy(func)
|
||||
|
||||
|
||||
# Add attributes for their use in the decorated functions
|
||||
return super().__new__(mcls, name, bases, new_nmspc)
|
@ -1,9 +1,11 @@
|
||||
from . import BaseAgent
|
||||
from .. import environment
|
||||
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
def __init__(self, *args, topology=None, init=True, node_id=None, **kwargs):
|
||||
super().__init__(*args, init=False, **kwargs)
|
||||
assert isinstance(self.model, environment.NetworkEnvironment), "NetworkAgent requires a NetworkEnvironment"
|
||||
|
||||
self.G = topology or self.model.G
|
||||
assert self.G is not None, "Network agents should have a network"
|
||||
@ -16,7 +18,7 @@ class NetworkAgent(BaseAgent):
|
||||
else:
|
||||
node_id = len(self.G)
|
||||
self.info(f"All nodes ({len(self.G)}) have an agent assigned, adding a new node to the graph for agent {self.unique_id}")
|
||||
self.G.add_node(node_id)
|
||||
self.G.add_node(node_id, find_unassigned=True)
|
||||
assert node_id is not None
|
||||
self.G.nodes[node_id]["agent"] = self
|
||||
self.node_id = node_id
|
||||
|
136
soil/agents/view.py
Normal file
136
soil/agents/view.py
Normal file
@ -0,0 +1,136 @@
|
||||
from collections.abc import Mapping, Set
|
||||
from itertools import islice
|
||||
|
||||
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents."""
|
||||
|
||||
__slots__ = ("_agents", "agents_by_type")
|
||||
|
||||
def __init__(self, agents, agents_by_type):
|
||||
self._agents = agents
|
||||
self.agents_by_type = agents_by_type
|
||||
|
||||
def __getstate__(self):
|
||||
return {"_agents": self._agents}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._agents = state["_agents"]
|
||||
|
||||
# Mapping methods
|
||||
def __len__(self):
|
||||
return len(self._agents)
|
||||
|
||||
def __iter__(self):
|
||||
yield from self._agents.values()
|
||||
|
||||
def __getitem__(self, agent_id):
|
||||
if isinstance(agent_id, slice):
|
||||
raise ValueError(f"Slicing is not supported")
|
||||
if agent_id in self._agents:
|
||||
return self._agents[agent_id]
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
def filter(self, agent_class=None, include_subclasses=True, **kwargs):
|
||||
if agent_class and self.agents_by_type:
|
||||
if not include_subclasses:
|
||||
return filter_agents(self.agents_by_type[agent_class],
|
||||
**kwargs)
|
||||
else:
|
||||
d = {}
|
||||
for k, v in self.agents_by_type.items():
|
||||
if (k == agent_class) or issubclass(k, agent_class):
|
||||
d.update(v)
|
||||
return filter_agents(d, **kwargs)
|
||||
return filter_agents(self._agents, agent_class=agent_class, **kwargs)
|
||||
|
||||
|
||||
def one(self, *args, **kwargs):
|
||||
try:
|
||||
return next(self.filter(*args, **kwargs))
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return list(self.filter(*args, **kwargs))
|
||||
|
||||
def __contains__(self, agent_id):
|
||||
return agent_id in self._agents
|
||||
|
||||
def __str__(self):
|
||||
return str(list(unique_id for unique_id in self.keys()))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def filter_agents(
|
||||
agents: dict,
|
||||
*id_args,
|
||||
unique_id=None,
|
||||
state_id=None,
|
||||
agent_class=None,
|
||||
ignore=None,
|
||||
state=None,
|
||||
limit=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||
"""
|
||||
assert isinstance(agents, dict)
|
||||
|
||||
ids = []
|
||||
|
||||
if unique_id is not None:
|
||||
if isinstance(unique_id, list):
|
||||
ids += unique_id
|
||||
else:
|
||||
ids.append(unique_id)
|
||||
|
||||
if id_args:
|
||||
ids += id_args
|
||||
|
||||
if ids:
|
||||
f = list(agents[aid] for aid in ids if aid in agents)
|
||||
else:
|
||||
f = agents.values()
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
if state_id is not None:
|
||||
f = filter(lambda agent: agent.get("state_id", None) in state_id, f)
|
||||
|
||||
if agent_class is not None:
|
||||
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
||||
|
||||
state = state or dict()
|
||||
state.update(kwargs)
|
||||
|
||||
for k, vs in state.items():
|
||||
if not isinstance(vs, list):
|
||||
vs = [vs]
|
||||
f = filter(lambda agent: any(getattr(agent, k, None) == v for v in vs), f)
|
||||
|
||||
if limit is not None:
|
||||
f = islice(f, limit)
|
||||
|
||||
return AgentResult(f)
|
||||
|
||||
class AgentResult:
|
||||
def __init__(self, iterator):
|
||||
self.iterator = iterator
|
||||
|
||||
def limit(self, limit):
|
||||
self.iterator = islice(self.iterator, limit)
|
||||
return self
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.iterator)
|
||||
|
||||
def __next__(self):
|
||||
return next(self.iterator)
|
@ -43,15 +43,14 @@ def read_sql(fpath=None, name=None, include_agents=False):
|
||||
with engine.connect() as conn:
|
||||
env = pd.read_sql_table("env", con=conn,
|
||||
index_col="step").reset_index().set_index([
|
||||
"simulation_id", "params_id",
|
||||
"iteration_id", "step"
|
||||
"params_id", "iteration_id", "step"
|
||||
])
|
||||
agents = pd.read_sql_table("agents", con=conn, index_col=["simulation_id", "params_id", "iteration_id", "step", "agent_id"])
|
||||
agents = pd.read_sql_table("agents", con=conn, index_col=["params_id", "iteration_id", "step", "agent_id"])
|
||||
config = pd.read_sql_table("configuration", con=conn, index_col="simulation_id")
|
||||
parameters = pd.read_sql_table("parameters", con=conn, index_col=["iteration_id", "params_id", "simulation_id"])
|
||||
try:
|
||||
parameters = parameters.pivot(columns="key", values="value")
|
||||
except Exception as e:
|
||||
print(f"warning: coult not pivot parameters: {e}")
|
||||
parameters = pd.read_sql_table("parameters", con=conn, index_col=["simulation_id", "params_id", "iteration_id"])
|
||||
# try:
|
||||
# parameters = parameters.pivot(columns="key", values="value")
|
||||
# except Exception as e:
|
||||
# print(f"warning: coult not pivot parameters: {e}")
|
||||
|
||||
return Results(config, parameters, env, agents)
|
||||
|
@ -1,20 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import math
|
||||
import sys
|
||||
import logging
|
||||
import inspect
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Union, List, Type
|
||||
from collections import namedtuple
|
||||
from time import time as current_time
|
||||
from copy import deepcopy
|
||||
|
||||
from types import coroutine
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from mesa import Model, Agent
|
||||
|
||||
from mesa import Model
|
||||
|
||||
from . import agents as agentmod, datacollection, utils, time, network, events
|
||||
|
||||
@ -114,10 +110,10 @@ class BaseEnvironment(Model):
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
return agentmod.AgentView(self.schedule._agents)
|
||||
return agentmod.AgentView(self.schedule._agents, getattr(self.schedule, "agents_by_type", None))
|
||||
|
||||
def agent(self, *args, **kwargs):
|
||||
return agentmod.AgentView(self.schedule._agents).one(*args, **kwargs)
|
||||
return agentmod.AgentView(self.schedule._agents, self.schedule.agents_by_type).one(*args, **kwargs)
|
||||
|
||||
def count_agents(self, *args, **kwargs):
|
||||
return sum(1 for i in self.agents(*args, **kwargs))
|
||||
@ -205,11 +201,18 @@ class BaseEnvironment(Model):
|
||||
func = name
|
||||
self.datacollector._new_model_reporter(name, func)
|
||||
|
||||
def add_agent_reporter(self, name, reporter=None, agent_type=None):
|
||||
if not agent_type and not reporter:
|
||||
def add_agent_reporter(self, name, reporter=None, agent_class=None, *, agent_type=None):
|
||||
if agent_type:
|
||||
print("agent_type is deprecated, use agent_class instead", file=sys.stderr)
|
||||
agent_class = agent_type or agent_class
|
||||
if not reporter and not agent_class:
|
||||
reporter = name
|
||||
elif agent_type:
|
||||
reporter = lambda a: reporter(a) if isinstance(a, agent_type) else None
|
||||
if agent_class:
|
||||
if reporter:
|
||||
_orig = reporter
|
||||
else:
|
||||
_orig = lambda a: getattr(a, name)
|
||||
reporter = lambda a: (_orig(a) if isinstance(a, agent_class) else None)
|
||||
self.datacollector._new_agent_reporter(name, reporter)
|
||||
|
||||
@classmethod
|
||||
@ -331,15 +334,16 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
if getattr(agent, "alive", True):
|
||||
yield agent
|
||||
|
||||
def add_node(self, agent_class, unique_id=None, node_id=None, **kwargs):
|
||||
def add_node(self, agent_class, unique_id=None, node_id=None, find_unassigned=False, **kwargs):
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
if node_id is None:
|
||||
if find_unassigned:
|
||||
node_id = network.find_unassigned(
|
||||
G=self.G, shuffle=True, random=self.random
|
||||
)
|
||||
if node_id is None:
|
||||
node_id = f"node_for_{unique_id}"
|
||||
node_id = f"Node_for_agent_{unique_id}"
|
||||
|
||||
if node_id not in self.G.nodes:
|
||||
self.G.add_node(node_id)
|
||||
@ -409,27 +413,84 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
|
||||
|
||||
class EventedEnvironment(BaseEnvironment):
|
||||
def broadcast(self, msg, sender=None, expiration=None, ttl=None, **kwargs):
|
||||
for agent in self.agents(**kwargs):
|
||||
if agent == sender:
|
||||
continue
|
||||
self.logger.debug(f"Telling {repr(agent)}: {msg} ttl={ttl}")
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._inbox = dict()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def register(self, agent):
|
||||
self._inbox[agent.unique_id] = []
|
||||
|
||||
def inbox_for(self, agent):
|
||||
try:
|
||||
inbox = agent._inbox
|
||||
except AttributeError:
|
||||
self.logger.info(
|
||||
f"Agent {agent.unique_id} cannot receive events because it does not have an inbox"
|
||||
)
|
||||
return self._inbox[agent.unique_id]
|
||||
except KeyError:
|
||||
raise ValueError(f"Trying to access inbox for unregistered agent: {agent} (class: {type(agent)}). "
|
||||
"Make sure your agent is of type EventedAgent and it is registered with the environment.")
|
||||
|
||||
@coroutine
|
||||
def received(self, agent, expiration=None, timeout=60, delay=1):
|
||||
if not expiration:
|
||||
expiration = self.now + timeout
|
||||
inbox = self.inbox_for(agent)
|
||||
if inbox:
|
||||
return self.process_messages(inbox)
|
||||
while self.now < expiration:
|
||||
# TODO: this wakes the agent up at every step. It would be better to wait until timeout (or inf)
|
||||
# and if a message is received before that, reschedule the agent when
|
||||
if inbox:
|
||||
return self.process_messages(inbox)
|
||||
yield time.Delay(delay)
|
||||
raise events.TimedOut("No message received")
|
||||
|
||||
def tell(self, msg, sender, recipient, expiration=None, timeout=None, **kwargs):
|
||||
if expiration is None:
|
||||
expiration = float("inf") if timeout is None else self.now + timeout
|
||||
self.inbox_for(recipient).append(
|
||||
events.Tell(timestamp=self.now,
|
||||
payload=msg,
|
||||
sender=sender,
|
||||
expiration=expiration,
|
||||
**kwargs))
|
||||
|
||||
def broadcast(self, msg, sender, ttl=None, expiration=None, agent_class=None):
|
||||
expiration = expiration if ttl is None else self.now + ttl
|
||||
# This only works for Soil environments. Mesa agents do not have an `agents` method
|
||||
sender_id = sender.unique_id
|
||||
for (agent_id, inbox) in self._inbox.items():
|
||||
if agent_id == sender_id:
|
||||
continue
|
||||
# Allow for AttributeError exceptions in this part of the code
|
||||
if agent_class and not isinstance(self.agents(unique_id=agent_id), agent_class):
|
||||
continue
|
||||
self.logger.debug(f"Telling {agent_id}: {msg} ttl={ttl}")
|
||||
inbox.append(
|
||||
events.Tell(
|
||||
payload=msg,
|
||||
sender=sender,
|
||||
expiration=expiration if ttl is None else self.now + ttl,
|
||||
expiration=expiration,
|
||||
)
|
||||
)
|
||||
|
||||
@coroutine
|
||||
def ask(self, msg, recipient, sender=None, expiration=None, timeout=None, delay=1):
|
||||
ask = events.Ask(timestamp=self.now, payload=msg, sender=sender)
|
||||
self.inbox_for(recipient).append(ask)
|
||||
expiration = float("inf") if timeout is None else self.now + timeout
|
||||
while self.now < expiration:
|
||||
if ask.reply:
|
||||
return ask.reply
|
||||
yield time.Delay(delay)
|
||||
raise events.TimedOut("No reply received")
|
||||
|
||||
class Environment(NetworkEnvironment, EventedEnvironment):
|
||||
"""Default environment class, has both network and event capabilities"""
|
||||
def process_messages(self, inbox):
|
||||
valid = list()
|
||||
for msg in inbox:
|
||||
if msg.expired(self.now):
|
||||
continue
|
||||
valid.append(msg)
|
||||
inbox.clear()
|
||||
return valid
|
||||
|
||||
|
||||
class Environment(EventedEnvironment, NetworkEnvironment):
|
||||
pass
|
@ -18,7 +18,6 @@ class Message:
|
||||
def expired(self, when):
|
||||
return self.expiration is not None and self.expiration < when
|
||||
|
||||
|
||||
class Reply(Message):
|
||||
source: Message
|
||||
|
||||
@ -28,7 +27,9 @@ class Ask(Message):
|
||||
|
||||
|
||||
class Tell(Message):
|
||||
pass
|
||||
def __post_init__(self):
|
||||
assert self.sender is not None, "Tell requires a sender"
|
||||
|
||||
|
||||
|
||||
class TimedOut(Exception):
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
from time import time as current_time
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from textwrap import dedent, indent
|
||||
|
||||
@ -86,18 +87,19 @@ class Exporter:
|
||||
pass
|
||||
return open_or_reuse(f, mode=mode, backup=self.simulation.backup, **kwargs)
|
||||
|
||||
def get_dfs(self, env, **kwargs):
|
||||
def get_dfs(self, env, params_id, **kwargs):
|
||||
yield from get_dc_dfs(env.datacollector,
|
||||
simulation_id=self.simulation.id,
|
||||
params_id,
|
||||
iteration_id=env.id,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def get_dc_dfs(dc, **kwargs):
|
||||
def get_dc_dfs(dc, params_id, **kwargs):
|
||||
dfs = {}
|
||||
dfe = dc.get_model_vars_dataframe()
|
||||
dfe.index.rename("step", inplace=True)
|
||||
dfs["env"] = dfe
|
||||
kwargs["params_id"] = params_id
|
||||
try:
|
||||
dfa = dc.get_agent_vars_dataframe()
|
||||
dfa.index.rename(["step", "agent_id"], inplace=True)
|
||||
@ -108,8 +110,12 @@ def get_dc_dfs(dc, **kwargs):
|
||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||
for (name, df) in dfs.items():
|
||||
for (k, v) in kwargs.items():
|
||||
if v:
|
||||
df[k] = v
|
||||
df.set_index(["simulation_id", "iteration_id"], append=True, inplace=True)
|
||||
else:
|
||||
df[k] = pd.Series(dtype="object")
|
||||
df.reset_index(inplace=True)
|
||||
df.set_index(["params_id", "iteration_id"], inplace=True)
|
||||
|
||||
yield from dfs.items()
|
||||
|
||||
@ -134,12 +140,16 @@ class SQLite(Exporter):
|
||||
if os.path.exists(self.dbpath):
|
||||
os.remove(self.dbpath)
|
||||
|
||||
outdir = os.path.dirname(self.dbpath)
|
||||
if outdir and not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
|
||||
self.engine = create_engine(f"sqlite:///{self.dbpath}", echo=False)
|
||||
|
||||
sim_dict = {k: serialize(v)[0] for (k,v) in self.simulation.to_dict().items()}
|
||||
sim_dict["simulation_id"] = self.simulation.id
|
||||
df = pd.DataFrame([sim_dict])
|
||||
df.to_sql("configuration", con=self.engine, if_exists="append")
|
||||
df.reset_index().to_sql("configuration", con=self.engine, if_exists="append", index=False)
|
||||
|
||||
def iteration_end(self, env, params, params_id, *args, **kwargs):
|
||||
if not self.dump:
|
||||
@ -149,15 +159,28 @@ class SQLite(Exporter):
|
||||
with timer(
|
||||
"Dumping simulation {} iteration {}".format(self.simulation.name, env.id)
|
||||
):
|
||||
|
||||
pd.DataFrame([{"simulation_id": self.simulation.id,
|
||||
d = {"simulation_id": self.simulation.id,
|
||||
"params_id": params_id,
|
||||
"iteration_id": env.id,
|
||||
"key": k,
|
||||
"value": serialize(v)[0]} for (k,v) in params.items()]).to_sql("parameters", con=self.engine, if_exists="append")
|
||||
}
|
||||
for (k,v) in params.items():
|
||||
d[k] = serialize(v)[0]
|
||||
|
||||
pd.DataFrame([d]).reset_index().to_sql("parameters",
|
||||
con=self.engine,
|
||||
if_exists="append",
|
||||
index=False)
|
||||
pd.DataFrame([{
|
||||
"simulation_id": self.simulation.id,
|
||||
"params_id": params_id,
|
||||
"iteration_id": env.id,
|
||||
}]).reset_index().to_sql("iterations",
|
||||
con=self.engine,
|
||||
if_exists="append",
|
||||
index=False)
|
||||
|
||||
for (t, df) in self.get_dfs(env, params_id=params_id):
|
||||
df.to_sql(t, con=self.engine, if_exists="append")
|
||||
df.reset_index().to_sql(t, con=self.engine, if_exists="append", index=False)
|
||||
|
||||
class csv(Exporter):
|
||||
"""Export the state of each environment (and its agents) a CSV file for the simulation"""
|
||||
@ -226,9 +249,9 @@ class graphdrawing(Exporter):
|
||||
class summary(Exporter):
|
||||
"""Print a summary of each iteration to sys.stdout"""
|
||||
|
||||
def iteration_end(self, env, *args, **kwargs):
|
||||
def iteration_end(self, env, params_id, *args, **kwargs):
|
||||
msg = ""
|
||||
for (t, df) in self.get_dfs(env):
|
||||
for (t, df) in self.get_dfs(env, params_id):
|
||||
if not len(df):
|
||||
continue
|
||||
tabs = "\t" * 2
|
||||
@ -262,21 +285,5 @@ class YAML(Exporter):
|
||||
logger.info(f"Dumping simulation configuration to {self.outdir}")
|
||||
f.write(self.simulation.to_yaml())
|
||||
|
||||
class default(Exporter):
|
||||
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
|
||||
|
||||
def __init__(self, *args, exporter_cls=[], **kwargs):
|
||||
exporter_cls = exporter_cls or [YAML, SQLite]
|
||||
self.inner = [cls(*args, **kwargs) for cls in exporter_cls]
|
||||
|
||||
def sim_start(self, *args, **kwargs):
|
||||
for exporter in self.inner:
|
||||
exporter.sim_start(*args, **kwargs)
|
||||
|
||||
def sim_end(self, *args, **kwargs):
|
||||
for exporter in self.inner:
|
||||
exporter.sim_end(*args, **kwargs)
|
||||
|
||||
def iteration_end(self, *args, **kwargs):
|
||||
for exporter in self.inner:
|
||||
exporter.iteration_end(*args, **kwargs)
|
||||
default = SQLite
|
86
soil/time.py
86
soil/time.py
@ -1,7 +1,7 @@
|
||||
from mesa.time import BaseScheduler
|
||||
from queue import Empty
|
||||
from heapq import heappush, heappop, heapreplace
|
||||
from collections import deque
|
||||
from collections import deque, defaultdict
|
||||
import math
|
||||
import logging
|
||||
|
||||
@ -27,6 +27,12 @@ class Delay:
|
||||
def __await__(self):
|
||||
return (yield self.delta)
|
||||
|
||||
class When:
|
||||
def __init__(self, when):
|
||||
raise Exception("The use of When is deprecated. Use the `Agent.at` and `Agent.delay` methods instead")
|
||||
class Delta:
|
||||
def __init__(self, delta):
|
||||
raise Exception("The use of Delay is deprecated. Use the `Agent.at` and `Agent.delay` methods instead")
|
||||
|
||||
class DeadAgent(Exception):
|
||||
pass
|
||||
@ -46,27 +52,39 @@ class PQueueActivation(BaseScheduler):
|
||||
self._shuffle = shuffle
|
||||
self.logger = getattr(self.model, "logger", logger).getChild(f"time_{ self.model }")
|
||||
self.next_time = self.time
|
||||
self.agents_by_type = defaultdict(dict)
|
||||
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
if when is None:
|
||||
when = self.time
|
||||
else:
|
||||
when = float(when)
|
||||
|
||||
self._schedule(agent, None, when)
|
||||
agent_class = type(agent)
|
||||
self.agents_by_type[agent_class][agent.unique_id] = agent
|
||||
super().add(agent)
|
||||
self.add_callback(agent.step, when)
|
||||
|
||||
def _schedule(self, agent, when=None, replace=False):
|
||||
def add_callback(self, callback, when=None, replace=False):
|
||||
if when is None:
|
||||
when = self.time
|
||||
else:
|
||||
when = float(when)
|
||||
if self._shuffle:
|
||||
key = (when, self.model.random.random())
|
||||
else:
|
||||
key = (when, agent.unique_id)
|
||||
key = when
|
||||
if replace:
|
||||
heapreplace(self._queue, (key, agent))
|
||||
heapreplace(self._queue, (key, callback))
|
||||
else:
|
||||
heappush(self._queue, (key, agent))
|
||||
heappush(self._queue, (key, callback))
|
||||
|
||||
def remove(self, agent):
|
||||
del self._agents[agent.unique_id]
|
||||
del self._agents[type(agent)][agent.unique_id]
|
||||
for i, (key, callback) in enumerate(self._queue):
|
||||
if callback == agent.step:
|
||||
del self._queue[i]
|
||||
break
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
@ -87,18 +105,14 @@ class PQueueActivation(BaseScheduler):
|
||||
next_time = when
|
||||
break
|
||||
|
||||
try:
|
||||
when = agent.step() or 1
|
||||
when += now
|
||||
except DeadAgent:
|
||||
heappop(self._queue)
|
||||
continue
|
||||
|
||||
if when == INFINITY:
|
||||
heappop(self._queue)
|
||||
continue
|
||||
when += now
|
||||
|
||||
self._schedule(agent, when, replace=True)
|
||||
self.add_callback(agent, when, replace=True)
|
||||
|
||||
self.steps += 1
|
||||
|
||||
@ -117,26 +131,42 @@ class TimedActivation(BaseScheduler):
|
||||
self._shuffle = shuffle
|
||||
self.logger = getattr(self.model, "logger", logger).getChild(f"time_{ self.model }")
|
||||
self.next_time = self.time
|
||||
self.agents_by_type = defaultdict(dict)
|
||||
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
self.add_callback(agent.step, when)
|
||||
agent_class = type(agent)
|
||||
self.agents_by_type[agent_class][agent.unique_id] = agent
|
||||
super().add(agent)
|
||||
|
||||
def _find_loc(self, when=None):
|
||||
if when is None:
|
||||
when = self.time
|
||||
else:
|
||||
when = float(when)
|
||||
self._schedule(agent, None, when)
|
||||
super().add(agent)
|
||||
|
||||
def _schedule(self, agent, when=None, replace=False):
|
||||
when = when or self.time
|
||||
pos = len(self._queue)
|
||||
for (ix, l) in enumerate(self._queue):
|
||||
if l[0] == when:
|
||||
l[1].append(agent)
|
||||
return
|
||||
return l[1]
|
||||
if l[0] > when:
|
||||
pos = ix
|
||||
break
|
||||
self._queue.insert(pos, (when, [agent]))
|
||||
lst = []
|
||||
self._queue.insert(pos, (when, lst))
|
||||
return lst
|
||||
|
||||
def add_callback(self, func, when=None, replace=False):
|
||||
lst = self._find_loc(when)
|
||||
lst.append(func)
|
||||
|
||||
def add_bulk(self, funcs, when=None):
|
||||
lst = self._find_loc(when)
|
||||
lst.extend(funcs)
|
||||
|
||||
def remove(self, agent):
|
||||
del self._agents[agent.unique_id]
|
||||
del self.agents_by_type[type(agent)][agent.unique_id]
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
@ -157,20 +187,22 @@ class TimedActivation(BaseScheduler):
|
||||
bucket = self._queue.popleft()[1]
|
||||
if self._shuffle:
|
||||
self.model.random.shuffle(bucket)
|
||||
for agent in bucket:
|
||||
try:
|
||||
when = agent.step() or 1
|
||||
when += now
|
||||
except DeadAgent:
|
||||
continue
|
||||
next_batch = defaultdict(list)
|
||||
for func in bucket:
|
||||
when = func() or 1
|
||||
|
||||
if when != INFINITY:
|
||||
self._schedule(agent, when, replace=True)
|
||||
when += now
|
||||
next_batch[when].append(func)
|
||||
|
||||
for (when, bucket) in next_batch.items():
|
||||
self.add_bulk(bucket, when)
|
||||
|
||||
self.steps += 1
|
||||
if self._queue:
|
||||
self.time = self._queue[0][0]
|
||||
else:
|
||||
self.model.running = False
|
||||
self.time = INFINITY
|
||||
|
||||
|
||||
|
@ -158,3 +158,23 @@ def run_parallel(func, iterable, num_processes=1, **kwargs):
|
||||
|
||||
def int_seed(seed: str):
|
||||
return int.from_bytes(seed.encode(), "little")
|
||||
|
||||
|
||||
def prob(prob, random):
|
||||
"""
|
||||
A true/False uniform distribution with a given probability.
|
||||
To be used like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
if prob(0.3):
|
||||
do_something()
|
||||
|
||||
"""
|
||||
r = random.random()
|
||||
return r < prob
|
||||
|
||||
|
||||
def custom(cls, **kwargs):
|
||||
"""Create a new class from a template class and keyword arguments"""
|
||||
return type(cls.__name__, (cls,), kwargs)
|
@ -20,13 +20,14 @@ class TestAgents(TestCase):
|
||||
assert ret == stime.INFINITY
|
||||
|
||||
def test_die_raises_exception(self):
|
||||
"""A dead agent should raise an exception if it is stepped after death"""
|
||||
"""A dead agent should continue returning INFINITY after death"""
|
||||
d = Dead(unique_id=0, model=environment.Environment())
|
||||
assert d.alive
|
||||
d.step()
|
||||
assert not d.alive
|
||||
with pytest.raises(stime.DeadAgent):
|
||||
d.step()
|
||||
when = d.step()
|
||||
assert not d.alive
|
||||
assert when == stime.INFINITY
|
||||
|
||||
def test_agent_generator(self):
|
||||
"""
|
||||
@ -79,30 +80,28 @@ class TestAgents(TestCase):
|
||||
"""
|
||||
|
||||
class BCast(agents.Evented):
|
||||
pings_received = 0
|
||||
pings_received = []
|
||||
|
||||
def step(self):
|
||||
print(self.model.broadcast)
|
||||
try:
|
||||
self.model.broadcast("PING")
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
async def step(self):
|
||||
self.broadcast("PING")
|
||||
print("PING sent")
|
||||
while True:
|
||||
self.process_messages()
|
||||
yield
|
||||
msgs = await self.received()
|
||||
self.pings_received += msgs
|
||||
|
||||
def on_receive(self, msg, sender=None):
|
||||
self.pings_received += 1
|
||||
e = environment.Environment()
|
||||
|
||||
e = environment.EventedEnvironment()
|
||||
|
||||
for i in range(10):
|
||||
num_agents = 10
|
||||
for i in range(num_agents):
|
||||
e.add_agent(agent_class=BCast)
|
||||
e.step()
|
||||
pings_received = lambda: [a.pings_received for a in e.agents]
|
||||
assert sorted(pings_received()) == list(range(1, 11))
|
||||
# Agents are executed in order, so the first agent should have not received any messages
|
||||
pings_received = lambda: [len(a.pings_received) for a in e.agents]
|
||||
assert sorted(pings_received()) == list(range(0, num_agents))
|
||||
e.step()
|
||||
assert all(x == 10 for x in pings_received())
|
||||
# After the second step, every agent should have received a broadcast from every other agent
|
||||
received = pings_received()
|
||||
assert all(x == (num_agents - 1) for x in received)
|
||||
|
||||
def test_ask_messages(self):
|
||||
"""
|
||||
@ -140,17 +139,16 @@ class TestAgents(TestCase):
|
||||
print("NOT sending ping")
|
||||
print("Checking msgs")
|
||||
# Do not block if we have already received a PING
|
||||
if not self.process_messages():
|
||||
yield from self.received()
|
||||
print("done")
|
||||
|
||||
def on_receive(self, msg, sender=None):
|
||||
if msg == "PING":
|
||||
msgs = yield from self.received()
|
||||
for ping in msgs:
|
||||
if ping.payload == "PING":
|
||||
ping.reply = "PONG"
|
||||
pongs.append(self.now)
|
||||
return "PONG"
|
||||
else:
|
||||
raise Exception("This should never happen")
|
||||
|
||||
e = environment.EventedEnvironment(schedule_class=stime.OrderedTimedActivation)
|
||||
|
||||
e = environment.Environment(schedule_class=stime.OrderedTimedActivation)
|
||||
for i in range(2):
|
||||
e.add_agent(agent_class=Ping)
|
||||
assert e.now == 0
|
||||
@ -373,3 +371,23 @@ class TestAgents(TestCase):
|
||||
model.step()
|
||||
assert a.now == 17
|
||||
assert a.my_state == 5
|
||||
|
||||
def test_send_nonevent(self):
|
||||
'''
|
||||
Sending a non-event should raise an error.
|
||||
'''
|
||||
model = environment.Environment()
|
||||
a = model.add_agent(agents.Noop)
|
||||
class TestAgent(agents.Agent):
|
||||
@agents.state(default=True)
|
||||
def one(self):
|
||||
try:
|
||||
a.tell(b, 1)
|
||||
raise AssertionError('Should have raised an error.')
|
||||
except AttributeError:
|
||||
self.model.tell(1, sender=self, recipient=a)
|
||||
|
||||
model.add_agent(TestAgent)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
model.step()
|
@ -81,7 +81,8 @@ class Exporters(TestCase):
|
||||
model=ConstantEnv,
|
||||
name="exporter_sim",
|
||||
exporters=[
|
||||
exporters.default,
|
||||
exporters.YAML,
|
||||
exporters.SQLite,
|
||||
exporters.csv,
|
||||
],
|
||||
exporter_params={"copy_to": output},
|
||||
|
@ -135,9 +135,9 @@ class TestMain(TestCase):
|
||||
|
||||
def test_serialize_agent_class(self):
|
||||
"""A class from soil.agents should be serialized without the module part"""
|
||||
ser = agents._serialize_type(CustomAgent)
|
||||
ser = serialization.serialize(CustomAgent, known_modules=["soil.agents"])[1]
|
||||
assert ser == "test_main.CustomAgent"
|
||||
ser = agents._serialize_type(agents.BaseAgent)
|
||||
ser = serialization.serialize(agents.BaseAgent, known_modules=["soil.agents"])[1]
|
||||
assert ser == "BaseAgent"
|
||||
pickle.dumps(ser)
|
||||
|
||||
@ -227,3 +227,29 @@ class TestMain(TestCase):
|
||||
for i in a:
|
||||
for j in b:
|
||||
assert {"a": i, "b": j} in configs
|
||||
|
||||
def test_agent_reporters(self):
|
||||
"""An environment should be able to set its own reporters"""
|
||||
class Noop2(agents.Noop):
|
||||
pass
|
||||
|
||||
e = Environment()
|
||||
e.add_agent(agents.Noop)
|
||||
e.add_agent(Noop2)
|
||||
e.add_agent_reporter("now")
|
||||
e.add_agent_reporter("base", lambda a: "base", agent_class=agents.Noop)
|
||||
e.add_agent_reporter("subclass", lambda a:"subclass", agent_class=Noop2)
|
||||
e.step()
|
||||
|
||||
# Step 0 is not present because we added the reporters
|
||||
# after initialization.
|
||||
df = e.agent_df()
|
||||
assert "now" in df.columns
|
||||
assert "base" in df.columns
|
||||
assert "subclass" in df.columns
|
||||
assert df["now"][(0,0)] == 1
|
||||
assert df["now"][(0,1)] == 1
|
||||
assert df["base"][(0,0)] == "base"
|
||||
assert df["base"][(0,1)] == "base"
|
||||
assert df["subclass"][(0,0)] is None
|
||||
assert df["subclass"][(0,1)] == "subclass"
|
Loading…
Reference in New Issue
Block a user