Add rescheduling for received

wip-1.0
J. Fernando Sánchez 1 year ago
parent ee0c4517cb
commit 189836408f

@ -11,6 +11,7 @@ def run_sim(model, **kwargs):
dump=False, dump=False,
num_processes=1, num_processes=1,
parameters={'num_agents': NUM_AGENTS}, parameters={'num_agents': NUM_AGENTS},
seed="",
max_steps=MAX_STEPS, max_steps=MAX_STEPS,
iterations=NUM_ITERS) iterations=NUM_ITERS)
opts.update(kwargs) opts.update(kwargs)

@ -8,7 +8,6 @@ class NoopAgent(Agent):
self.num_calls = 0 self.num_calls = 0
def step(self): def step(self):
# import pdb;pdb.set_trace()
self.num_calls += 1 self.num_calls += 1

@ -10,7 +10,6 @@ class NoopAgent(Agent):
self.num_calls = 0 self.num_calls = 0
def step(self): def step(self):
# import pdb;pdb.set_trace()
self.num_calls += 1 self.num_calls += 1

@ -0,0 +1,21 @@
from soil import Agent, Environment, Simulation, state
class NoopAgent(Agent):
num_calls = 0
@state(default=True)
def unique(self):
self.num_calls += 1
class NoopEnvironment(Environment):
num_agents = 100
def init(self):
self.add_agents(NoopAgent, k=self.num_agents)
self.add_agent_reporter("num_calls")
from _config import *
run_sim(model=NoopEnvironment)

@ -1,7 +1,7 @@
from soil import BaseAgent, Environment, Simulation from soil import Agent, Environment, Simulation
class NoopAgent(BaseAgent): class NoopAgent(Agent):
num_calls = 0 num_calls = 0
def step(self): def step(self):
@ -15,7 +15,6 @@ class NoopEnvironment(Environment):
self.add_agent_reporter("num_calls") self.add_agent_reporter("num_calls")
if __name__ == "__main__": from _config import *
from _config import *
run_sim(model=NoopEnvironment) run_sim(model=NoopEnvironment)

@ -1,5 +1,5 @@
from soil import Agent, Environment, Simulation from soil import Agent, Environment, Simulation
from soilent import Scheduler from soil.time import SoilentActivation
class NoopAgent(Agent): class NoopAgent(Agent):
@ -14,7 +14,7 @@ class NoopAgent(Agent):
class NoopEnvironment(Environment): class NoopEnvironment(Environment):
num_agents = 100 num_agents = 100
schedule_class = Scheduler schedule_class = SoilentActivation
def init(self): def init(self):
self.add_agents(NoopAgent, k=self.num_agents) self.add_agents(NoopAgent, k=self.num_agents)
@ -26,4 +26,4 @@ if __name__ == "__main__":
res = run_sim(model=NoopEnvironment) res = run_sim(model=NoopEnvironment)
for r in res: for r in res:
assert isinstance(r.schedule, Scheduler) assert isinstance(r.schedule, SoilentActivation)

@ -1,5 +1,5 @@
from soil import Agent, Environment from soil import Agent, Environment
from soilent import PQueueScheduler from soil.time import SoilentPQueueActivation
class NoopAgent(Agent): class NoopAgent(Agent):
@ -12,7 +12,7 @@ class NoopAgent(Agent):
class NoopEnvironment(Environment): class NoopEnvironment(Environment):
num_agents = 100 num_agents = 100
schedule_class = PQueueScheduler schedule_class = SoilentPQueueActivation
def init(self): def init(self):
self.add_agents(NoopAgent, k=self.num_agents) self.add_agents(NoopAgent, k=self.num_agents)
@ -24,4 +24,4 @@ if __name__ == "__main__":
res = run_sim(model=NoopEnvironment) res = run_sim(model=NoopEnvironment)
for r in res: for r in res:
assert isinstance(r.schedule, PQueueScheduler) assert isinstance(r.schedule, SoilentPQueueActivation)

@ -1,5 +1,5 @@
from soil import Agent, Environment, Simulation from soil import Agent, Environment, Simulation
from soilent import Scheduler from soil.time import SoilentActivation
class NoopAgent(Agent): class NoopAgent(Agent):
@ -13,7 +13,7 @@ class NoopAgent(Agent):
class NoopEnvironment(Environment): class NoopEnvironment(Environment):
num_agents = 100 num_agents = 100
schedule_class = Scheduler schedule_class = SoilentActivation
def init(self): def init(self):
self.add_agents(NoopAgent, k=self.num_agents) self.add_agents(NoopAgent, k=self.num_agents)
@ -25,4 +25,4 @@ if __name__ == "__main__":
res = run_sim(model=NoopEnvironment) res = run_sim(model=NoopEnvironment)
for r in res: for r in res:
assert isinstance(r.schedule, Scheduler) assert isinstance(r.schedule, SoilentActivation)

@ -1,5 +1,5 @@
from soil import Agent, Environment from soil import Agent, Environment
from soilent import PQueueScheduler from soil.time import SoilentPQueueActivation
class NoopAgent(Agent): class NoopAgent(Agent):
@ -13,7 +13,7 @@ class NoopAgent(Agent):
class NoopEnvironment(Environment): class NoopEnvironment(Environment):
num_agents = 100 num_agents = 100
schedule_class = PQueueScheduler schedule_class = SoilentPQueueActivation
def init(self): def init(self):
self.add_agents(NoopAgent, k=self.num_agents) self.add_agents(NoopAgent, k=self.num_agents)
@ -25,4 +25,4 @@ if __name__ == "__main__":
res = run_sim(model=NoopEnvironment) res = run_sim(model=NoopEnvironment)
for r in res: for r in res:
assert isinstance(r.schedule, PQueueScheduler) assert isinstance(r.schedule, SoilentPQueueActivation)

@ -0,0 +1,30 @@
from soil import Agent, Environment, Simulation, state
from soil.time import SoilentActivation
class NoopAgent(Agent):
num_calls = 0
@state(default=True)
async def unique(self):
while True:
self.num_calls += 1
# yield self.delay(1)
await self.delay()
class NoopEnvironment(Environment):
num_agents = 100
schedule_class = SoilentActivation
def init(self):
self.add_agents(NoopAgent, k=self.num_agents)
self.add_agent_reporter("num_calls")
if __name__ == "__main__":
from _config import *
res = run_sim(model=NoopEnvironment)
for r in res:
assert isinstance(r.schedule, SoilentActivation)

@ -1,5 +1,5 @@
from soil import BaseAgent, Environment, Simulation from soil import BaseAgent, Environment, Simulation
from soilent import Scheduler from soil.time import SoilentActivation
class NoopAgent(BaseAgent): class NoopAgent(BaseAgent):
@ -10,7 +10,7 @@ class NoopAgent(BaseAgent):
class NoopEnvironment(Environment): class NoopEnvironment(Environment):
num_agents = 100 num_agents = 100
schedule_class = Scheduler schedule_class = SoilentActivation
def init(self): def init(self):
self.add_agents(NoopAgent, k=self.num_agents) self.add_agents(NoopAgent, k=self.num_agents)
@ -21,4 +21,4 @@ if __name__ == "__main__":
from _config import * from _config import *
res = run_sim(model=NoopEnvironment) res = run_sim(model=NoopEnvironment)
for r in res: for r in res:
assert isinstance(r.schedule, Scheduler) assert isinstance(r.schedule, SoilentActivation)

@ -1,5 +1,5 @@
from soil import BaseAgent, Environment, Simulation from soil import BaseAgent, Environment, Simulation
from soilent import PQueueScheduler from soil.time import SoilentPQueueActivation
class NoopAgent(BaseAgent): class NoopAgent(BaseAgent):
@ -10,7 +10,7 @@ class NoopAgent(BaseAgent):
class NoopEnvironment(Environment): class NoopEnvironment(Environment):
num_agents = 100 num_agents = 100
schedule_class = PQueueScheduler schedule_class = SoilentPQueueActivation
def init(self): def init(self):
self.add_agents(NoopAgent, k=self.num_agents) self.add_agents(NoopAgent, k=self.num_agents)
@ -21,4 +21,4 @@ if __name__ == "__main__":
from _config import * from _config import *
res = run_sim(model=NoopEnvironment) res = run_sim(model=NoopEnvironment)
for r in res: for r in res:
assert isinstance(r.schedule, PQueueScheduler) assert isinstance(r.schedule, SoilentPqueueActivation)

@ -1,8 +1,9 @@
import os import os
from soil import simulation
NUM_AGENTS = int(os.environ.get('NUM_AGENTS', 100)) NUM_AGENTS = int(os.environ.get('NUM_AGENTS', 100))
NUM_ITERS = int(os.environ.get('NUM_ITERS', 10)) NUM_ITERS = int(os.environ.get('NUM_ITERS', 10))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 1000)) MAX_STEPS = int(os.environ.get('MAX_STEPS', 500))
def run_sim(model, **kwargs): def run_sim(model, **kwargs):
@ -22,11 +23,16 @@ def run_sim(model, **kwargs):
iterations=NUM_ITERS) iterations=NUM_ITERS)
opts.update(kwargs) opts.update(kwargs)
its = Simulation(**opts).run() its = Simulation(**opts).run()
assert len(its) == NUM_ITERS
assert all(it.schedule.steps == MAX_STEPS for it in its) if not simulation._AVOID_RUNNING:
ratios = list(it.resistant_susceptible_ratio() for it in its) ratios = list(it.resistant_susceptible_ratio for it in its)
print("Max - Avg - Min ratio:", max(ratios), sum(ratios)/len(ratios), min(ratios)) print("Max - Avg - Min ratio:", max(ratios), sum(ratios)/len(ratios), min(ratios))
assert all(sum([it.number_susceptible, infected = list(it.number_infected for it in its)
it.number_infected, print("Max - Avg - Min infected:", max(infected), sum(infected)/len(infected), min(infected))
it.number_resistant]) == NUM_AGENTS for it in its)
return its assert all((it.schedule.steps == MAX_STEPS or it.number_infected == 0) for it in its)
assert all(sum([it.number_susceptible,
it.number_infected,
it.number_resistant]) == NUM_AGENTS for it in its)
return its

@ -100,6 +100,7 @@ class VirusOnNetwork(mesa.Model):
def number_infected(self): def number_infected(self):
return number_infected(self) return number_infected(self)
@property
def resistant_susceptible_ratio(self): def resistant_susceptible_ratio(self):
try: try:
return number_state(self, State.RESISTANT) / number_state( return number_state(self, State.RESISTANT) / number_state(
@ -176,5 +177,4 @@ class VirusAgent(mesa.Agent):
from _config import run_sim from _config import run_sim
run_sim(model=VirusOnNetwork) run_sim(model=VirusOnNetwork)

@ -30,8 +30,12 @@ class VirusOnNetwork(Environment):
for a in self.agents(node_id=infected_nodes): for a in self.agents(node_id=infected_nodes):
a.set_state(VirusAgent.infected) a.set_state(VirusAgent.infected)
assert self.number_infected == self.initial_outbreak_size assert self.number_infected == self.initial_outbreak_size
def step(self):
super().step()
@report @report
@property
def resistant_susceptible_ratio(self): def resistant_susceptible_ratio(self):
try: try:
return self.number_resistant / self.number_susceptible return self.number_resistant / self.number_susceptible
@ -59,34 +63,29 @@ class VirusAgent(Agent):
virus_check_frequency = None # Inherit from model virus_check_frequency = None # Inherit from model
recovery_chance = None # Inherit from model recovery_chance = None # Inherit from model
gain_resistance_chance = None # Inherit from model gain_resistance_chance = None # Inherit from model
just_been_infected = False
@state(default=True) @state(default=True)
def susceptible(self): async def susceptible(self):
if self.just_been_infected: await self.received()
self.just_been_infected = False return self.infected
return self.infected
@state @state
def infected(self): def infected(self):
susceptible_neighbors = self.get_neighbors(state_id=self.susceptible.id) susceptible_neighbors = self.get_neighbors(state_id=self.susceptible.id)
for a in susceptible_neighbors: for a in susceptible_neighbors:
if self.prob(self.virus_spread_chance): if self.prob(self.virus_spread_chance):
a.just_been_infected = True a.tell(True, sender=self)
if self.prob(self.virus_check_frequency): if self.prob(self.virus_check_frequency):
if self.prob(self.recovery_chance): if self.prob(self.recovery_chance):
if self.prob(self.gain_resistance_chance): if self.prob(self.gain_resistance_chance):
return self.resistant return self.resistant
else: else:
return self.susceptible return self.susceptible
else:
return self.infected
@state @state
def resistant(self): def resistant(self):
return self.at(INFINITY) return self.at(INFINITY)
if __name__ == "__main__": from _config import run_sim
from _config import run_sim run_sim(model=VirusOnNetwork)
run_sim(model=VirusOnNetwork)

@ -38,6 +38,7 @@ class VirusOnNetwork(Environment):
assert self.number_infected == self.initial_outbreak_size assert self.number_infected == self.initial_outbreak_size
@report @report
@property
def resistant_susceptible_ratio(self): def resistant_susceptible_ratio(self):
try: try:
return self.number_resistant / self.number_susceptible return self.number_resistant / self.number_susceptible
@ -99,6 +100,5 @@ class VirusAgent(Agent):
if __name__ == "__main__": from _config import run_sim
from _config import run_sim run_sim(model=VirusOnNetwork)
run_sim(model=VirusOnNetwork)

File diff suppressed because one or more lines are too long

@ -167,7 +167,7 @@ class RandomAccident(BaseAgent):
if self.prob(prob_death): if self.prob(prob_death):
self.debug("I killed a rabbit: {}".format(i.unique_id)) self.debug("I killed a rabbit: {}".format(i.unique_id))
num_alive -= 1 num_alive -= 1
i.die() self.model.remove_agent(i)
self.debug("Rabbits alive: {}".format(num_alive)) self.debug("Rabbits alive: {}".format(num_alive))

@ -142,13 +142,15 @@ class RandomAccident(BaseAgent):
prob_death = min(1, self.prob_death * num_alive/10) prob_death = min(1, self.prob_death * num_alive/10)
self.debug("Killing some rabbits with prob={}!".format(prob_death)) self.debug("Killing some rabbits with prob={}!".format(prob_death))
for i in self.get_agents(agent_class=Rabbit): for i in alive:
if i.state_id == i.dead.id: if i.state_id == i.dead.id:
continue continue
if self.prob(prob_death): if self.prob(prob_death):
self.debug("I killed a rabbit: {}".format(i.unique_id)) self.debug("I killed a rabbit: {}".format(i.unique_id))
num_alive -= 1 num_alive -= 1
i.die() self.model.remove_agent(i)
i.alive = False
i.killed = True
self.debug("Rabbits alive: {}".format(num_alive)) self.debug("Rabbits alive: {}".format(num_alive))

@ -259,7 +259,6 @@ def main(
except Exception as ex: except Exception as ex:
if args.pdb: if args.pdb:
from .debugging import post_mortem from .debugging import post_mortem
print(traceback.format_exc()) print(traceback.format_exc())
post_mortem() post_mortem()
else: else:

@ -30,8 +30,11 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
Any attribute that is not preceded by an underscore (`_`) will also be added to its state. Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
""" """
def __init__(self, unique_id, model, name=None, init=True, **kwargs): def __init__(self, unique_id=None, model=None, name=None, init=True, **kwargs):
assert isinstance(unique_id, int) # Ideally, model should be the first argument, but Mesa's Agent class has unique_id first
assert not (model is None), "Must provide a model"
if unique_id is None:
unique_id = model.next_id()
super().__init__(unique_id=unique_id, model=model) super().__init__(unique_id=unique_id, model=model)
self.name = ( self.name = (
@ -191,25 +194,25 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}({self.unique_id})" return f"{self.__class__.__name__}({self.unique_id})"
def at(self, at): def at(self, at):
return time.Delay(float(at) - self.now) return time.Delay(float(at) - self.now)
def delay(self, delay=1): def delay(self, delay=1):
return time.Delay(delay) return time.Delay(delay)
class Noop(BaseAgent):
def step(self):
return
from .network_agents import * from .network_agents import *
from .fsm import * from .fsm import *
from .evented import * from .evented import *
from .view import * from .view import *
class Noop(EventedAgent, BaseAgent):
def step(self):
return
class Agent(FSM, EventedAgent, NetworkAgent): class Agent(FSM, EventedAgent, NetworkAgent):
"""Default agent class, has network, FSM and event capabilities""" """Default agent class, has network, FSM and event capabilities"""

@ -16,7 +16,7 @@ class EventedAgent(BaseAgent):
self.model.register(self) self.model.register(self)
def received(self, **kwargs): def received(self, **kwargs):
return self.model.received(self, **kwargs) return self.model.received(agent=self, **kwargs)
def tell(self, msg, **kwargs): def tell(self, msg, **kwargs):
return self.model.tell(msg, recipient=self, **kwargs) return self.model.tell(msg, recipient=self, **kwargs)

@ -6,39 +6,38 @@ import inspect
class State: class State:
__slots__ = ("awaitable", "f", "generator", "name", "default") __slots__ = ("awaitable", "f", "attribute", "generator", "name", "default")
def __init__(self, f, name, default, generator, awaitable): def __init__(self, f, name, default, generator, awaitable):
self.f = f self.f = f
self.name = name self.name = name
self.attribute = "_{}".format(name)
self.generator = generator self.generator = generator
self.awaitable = awaitable self.awaitable = awaitable
self.default = default self.default = default
@coroutine
def step(self, obj):
if self.generator or self.awaitable:
f = self.f
next_state = yield from f(obj)
return next_state
else:
return self.f(obj)
@property @property
def id(self): def id(self):
return self.name return self.name
def __call__(self, *args, **kwargs): def __get__(self, obj, owner=None):
raise Exception("States should not be called directly") if obj is None:
return self
class UnboundState(State): try:
return getattr(obj, self.attribute)
except AttributeError:
b = self.bind(obj)
setattr(obj, self.attribute, b)
return b
def bind(self, obj): def bind(self, obj):
bs = BoundState(self.f, self.name, self.default, self.generator, self.awaitable, obj=obj) bs = BoundState(self.f, self.name, self.default, self.generator, self.awaitable, obj=obj)
setattr(obj, self.name, bs) setattr(obj, self.name, bs)
return bs return bs
def __call__(self, *args, **kwargs):
raise Exception("States should not be called directly")
class BoundState(State): class BoundState(State):
__slots__ = ("obj", ) __slots__ = ("obj", )
@ -46,10 +45,21 @@ class BoundState(State):
def __init__(self, *args, obj): def __init__(self, *args, obj):
super().__init__(*args) super().__init__(*args)
self.obj = obj self.obj = obj
@coroutine
def __call__(self):
if self.generator or self.awaitable:
f = self.f
next_state = yield from f(self.obj)
return next_state
else:
return self.f(self.obj)
def delay(self, delta=0): def delay(self, delta=0):
return self, self.obj.delay(delta) return self, self.obj.delay(delta)
def at(self, when): def at(self, when):
return self, self.obj.at(when) return self, self.obj.at(when)
@ -63,7 +73,7 @@ def state(name=None, default=False):
name = name or func.__name__ name = name or func.__name__
generator = inspect.isgeneratorfunction(func) generator = inspect.isgeneratorfunction(func)
awaitable = inspect.iscoroutinefunction(func) or inspect.isasyncgen(func) awaitable = inspect.iscoroutinefunction(func) or inspect.isasyncgen(func)
return UnboundState(func, name, default, generator, awaitable) return State(func, name, default, generator, awaitable)
if callable(name): if callable(name):
return decorator(name) return decorator(name)
@ -113,15 +123,24 @@ class MetaFSM(MetaAgent):
class FSM(BaseAgent, metaclass=MetaFSM): class FSM(BaseAgent, metaclass=MetaFSM):
def __init__(self, init=True, state_id=None, **kwargs): def __init__(self, init=True, state_id=None, **kwargs):
super().__init__(**kwargs, init=False) super().__init__(**kwargs, init=False)
bound_states = {}
for (k, v) in list(self._states.items()):
if isinstance(v, State):
v = v.bind(self)
bound_states[k] = v
setattr(self, k, v)
self._states = bound_states
if state_id is not None: if state_id is not None:
self._set_state(state_id) self._set_state(state_id)
else:
self._set_state(self._state)
# If more than "dead" state is defined, but no default state # If more than "dead" state is defined, but no default state
if len(self._states) > 1 and not self._state: if len(self._states) > 1 and not self._state:
raise ValueError( raise ValueError(
f"No default state specified for {type(self)}({self.unique_id})" f"No default state specified for {type(self)}({self.unique_id})"
) )
for (k, v) in self._states.items():
setattr(self, k, v.bind(self))
if init: if init:
self.init() self.init()
@ -139,6 +158,7 @@ class FSM(BaseAgent, metaclass=MetaFSM):
raise ValueError("Cannot change state after init") raise ValueError("Cannot change state after init")
self._set_state(value) self._set_state(value)
@coroutine
def step(self): def step(self):
if self._state is None: if self._state is None:
if len(self._states) == 1: if len(self._states) == 1:
@ -146,8 +166,7 @@ class FSM(BaseAgent, metaclass=MetaFSM):
else: else:
raise Exception("Invalid state (None) for agent {}".format(self)) raise Exception("Invalid state (None) for agent {}".format(self))
self._check_alive() next_state = yield from self._state()
next_state = yield from self._state.step(self)
try: try:
next_state, when = next_state next_state, when = next_state
@ -167,7 +186,9 @@ class FSM(BaseAgent, metaclass=MetaFSM):
if state not in self._states: if state not in self._states:
raise ValueError("{} is not a valid state".format(state)) raise ValueError("{} is not a valid state".format(state))
state = self._states[state] state = self._states[state]
if not isinstance(state, State): if isinstance(state, State):
state = state.bind(self)
elif not isinstance(state, BoundState):
raise ValueError("{} is not a valid state".format(state)) raise ValueError("{} is not a valid state".format(state))
self._state = state self._state = state
@ -177,4 +198,4 @@ class FSM(BaseAgent, metaclass=MetaFSM):
@state @state
def dead(self): def dead(self):
return time.INFINITY return time.INFINITY

@ -2,44 +2,14 @@ from abc import ABCMeta
from copy import copy from copy import copy
from functools import wraps from functools import wraps
from .. import time from .. import time
from ..decorators import syncify, while_alive
import types import types
import inspect import inspect
def decorate_generator_step(func, name):
@wraps(func)
def decorated(self):
if not self.alive:
return time.INFINITY
if self._coroutine is None: class MetaAnnotations(ABCMeta):
self._coroutine = func(self) """This metaclass sets default values for agents based on class attributes"""
try:
if self._last_except:
val = self._coroutine.throw(self._last_except)
else:
val = self._coroutine.send(self._last_return)
except StopIteration as ex:
self._coroutine = None
val = ex.value
finally:
self._last_return = None
self._last_except = None
return float(val) if val is not None else val
return decorated
def decorate_normal_step(func, name):
@wraps(func)
def decorated(self):
# if not self.alive:
# return time.INFINITY
val = func(self)
return float(val) if val is not None else val
return decorated
class MetaAgent(ABCMeta):
def __new__(mcls, name, bases, namespace): def __new__(mcls, name, bases, namespace):
defaults = {} defaults = {}
@ -53,22 +23,7 @@ class MetaAgent(ABCMeta):
} }
for attr, func in namespace.items(): for attr, func in namespace.items():
if attr == "step": if (
if inspect.isgeneratorfunction(func) or inspect.iscoroutinefunction(func):
func = decorate_generator_step(func, attr)
new_nmspc.update({
"_last_return": None,
"_last_except": None,
"_coroutine": None,
})
elif inspect.isasyncgenfunction(func):
raise ValueError("Illegal step function: {}. It probably mixes both async/await and yield".format(func))
elif inspect.isfunction(func):
func = decorate_normal_step(func, attr)
else:
raise ValueError("Illegal step function: {}".format(func))
new_nmspc[attr] = func
elif (
isinstance(func, types.FunctionType) isinstance(func, types.FunctionType)
or isinstance(func, property) or isinstance(func, property)
or isinstance(func, classmethod) or isinstance(func, classmethod)
@ -82,6 +37,28 @@ class MetaAgent(ABCMeta):
else: else:
defaults[attr] = copy(func) defaults[attr] = copy(func)
return super().__new__(mcls, name, bases, new_nmspc)
class AutoAgent(ABCMeta):
def __new__(mcls, name, bases, namespace):
if "step" in namespace:
func = namespace["step"]
namespace["_orig_step"] = func
if inspect.isfunction(func):
if inspect.isgeneratorfunction(func) or inspect.iscoroutinefunction(func):
func = syncify(func, method=True)
namespace["step"] = while_alive(func)
elif inspect.isasyncgenfunction(func):
raise ValueError("Illegal step function: {}. It probably mixes both async/await and yield".format(func))
else:
raise ValueError("Illegal step function: {}".format(func))
# Add attributes for their use in the decorated functions # Add attributes for their use in the decorated functions
return super().__new__(mcls, name, bases, new_nmspc) return super().__new__(mcls, name, bases, namespace)
class MetaAgent(AutoAgent, MetaAnnotations):
"""This metaclass sets default values for agents based on class attributes"""
pass

@ -1,5 +1,6 @@
from collections.abc import Mapping, Set from collections.abc import Mapping, Set
from itertools import islice from itertools import islice
from mesa import Agent
class AgentView(Mapping, Set): class AgentView(Mapping, Set):
@ -55,6 +56,8 @@ class AgentView(Mapping, Set):
return list(self.filter(*args, **kwargs)) return list(self.filter(*args, **kwargs))
def __contains__(self, agent_id): def __contains__(self, agent_id):
if isinstance(agent_id, Agent):
agent_id = agent_id.unique_id
return agent_id in self._agents return agent_id in self._agents
def __str__(self): def __str__(self):

@ -19,7 +19,8 @@ def plot(env, agent_df=None, model_df=None, steps=False, ignore=["agent_count",
try: try:
agent_df = env.agent_df() agent_df = env.agent_df()
except UserWarning: except UserWarning:
print("No agent dataframe provided and no agent reporters found. Skipping agent plot.", file=sys.stderr) print("No agent dataframe provided and no agent reporters found. "
"Skipping agent plot.", file=sys.stderr)
return return
if not agent_df.empty: if not agent_df.empty:
agent_df.unstack().apply(lambda x: x.value_counts(), agent_df.unstack().apply(lambda x: x.value_counts(),
@ -48,9 +49,5 @@ def read_sql(fpath=None, name=None, include_agents=False):
agents = pd.read_sql_table("agents", con=conn, index_col=["params_id", "iteration_id", "step", "agent_id"]) agents = pd.read_sql_table("agents", con=conn, index_col=["params_id", "iteration_id", "step", "agent_id"])
config = pd.read_sql_table("configuration", con=conn, index_col="simulation_id") config = pd.read_sql_table("configuration", con=conn, index_col="simulation_id")
parameters = pd.read_sql_table("parameters", con=conn, index_col=["simulation_id", "params_id", "iteration_id"]) parameters = pd.read_sql_table("parameters", con=conn, index_col=["simulation_id", "params_id", "iteration_id"])
# try:
# parameters = parameters.pivot(columns="key", values="value")
# except Exception as e:
# print(f"warning: coult not pivot parameters: {e}")
return Results(config, parameters, env, agents) return Results(config, parameters, env, agents)

@ -1,6 +1,42 @@
from functools import wraps
from .time import INFINITY
def report(f: property): def report(f: property):
if isinstance(f, property): if isinstance(f, property):
setattr(f.fget, "add_to_report", True) setattr(f.fget, "add_to_report", True)
else: else:
setattr(f, "add_to_report", True) setattr(f, "add_to_report", True)
return f return f
def syncify(func, method=True):
_coroutine = None
@wraps(func)
def wrapped(*args, **kwargs):
if not method:
nonlocal _coroutine
else:
_coroutine = getattr(args[0], "_coroutine", None)
_coroutine = _coroutine or func(*args, **kwargs)
try:
val = _coroutine.send(None)
except StopIteration as ex:
_coroutine = None
val = ex.value
finally:
if method:
args[0]._coroutine = _coroutine
return val
return wrapped
def while_alive(func):
@wraps(func)
def wrapped(self, *args, **kwargs):
if self.alive:
return func(self, *args, **kwargs)
return INFINITY
return wrapped

@ -11,6 +11,7 @@ import networkx as nx
from mesa import Model from mesa import Model
from time import time as current_time
from . import agents as agentmod, datacollection, utils, time, network, events from . import agents as agentmod, datacollection, utils, time, network, events
@ -43,6 +44,7 @@ class BaseEnvironment(Model):
tables: Optional[Any] = None, tables: Optional[Any] = None,
**kwargs: Any) -> Any: **kwargs: Any) -> Any:
"""Create a new model with a default seed value""" """Create a new model with a default seed value"""
seed = seed or str(current_time())
self = super().__new__(cls, *args, seed=seed, **kwargs) self = super().__new__(cls, *args, seed=seed, **kwargs)
self.dir_path = dir_path or os.getcwd() self.dir_path = dir_path or os.getcwd()
collector_class = collector_class or cls.collector_class collector_class = collector_class or cls.collector_class
@ -136,7 +138,7 @@ class BaseEnvironment(Model):
@property @property
def now(self): def now(self):
if self.schedule: if self.schedule is not None:
return self.schedule.time return self.schedule.time
raise Exception( raise Exception(
"The environment has not been scheduled, so it has no sense of time" "The environment has not been scheduled, so it has no sense of time"
@ -160,6 +162,10 @@ class BaseEnvironment(Model):
self.schedule.add(a) self.schedule.add(a)
return a return a
def remove_agent(self, agent):
agent.alive = False
self.schedule.remove(agent)
def add_agents(self, agent_classes: List[type], k, weights: Optional[List[float]] = None, **kwargs): def add_agents(self, agent_classes: List[type], k, weights: Optional[List[float]] = None, **kwargs):
if isinstance(agent_classes, type): if isinstance(agent_classes, type):
agent_classes = [agent_classes] agent_classes = [agent_classes]
@ -188,12 +194,15 @@ class BaseEnvironment(Model):
super().step() super().step()
self.schedule.step() self.schedule.step()
self.datacollector.collect(self) self.datacollector.collect(self)
if self.now == time.INFINITY:
self.running = False
if self.logger.isEnabledFor(logging.DEBUG): if self.logger.isEnabledFor(logging.DEBUG):
msg = "Model data:\n" msg = "Model data:\n"
max_width = max(len(k) for k in self.datacollector.model_vars.keys()) max_width = max(len(k) for k in self.datacollector.model_vars.keys())
for (k, v) in self.datacollector.model_vars.items(): for (k, v) in self.datacollector.model_vars.items():
msg += f"\t{k:<{max_width}}: {v[-1]:>6}\n" # msg += f"\t{k:<{max_width}}"
msg += f"\t{k:<{max_width}}: {v[-1]}\n"
self.logger.debug(f"--- Steps: {self.schedule.steps:^5} - Time: {self.now:^5} --- " + msg) self.logger.debug(f"--- Steps: {self.schedule.steps:^5} - Time: {self.now:^5} --- " + msg)
def add_model_reporter(self, name, func=None): def add_model_reporter(self, name, func=None):
@ -297,6 +306,11 @@ class NetworkEnvironment(BaseEnvironment):
self.G.nodes[node_id]["agent"] = a self.G.nodes[node_id]["agent"] = a
return a return a
def remove_agent(self, agent, remove_node=True):
super().remove_agent(agent)
if remove_node and hasattr(agent, "remove_node"):
agent.remove_node()
def add_agents(self, *args, k=None, **kwargs): def add_agents(self, *args, k=None, **kwargs):
if not k and not self.G: if not k and not self.G:
raise ValueError("Cannot add agents to an empty network") raise ValueError("Cannot add agents to an empty network")
@ -344,6 +358,7 @@ class NetworkEnvironment(BaseEnvironment):
) )
if node_id is None: if node_id is None:
node_id = f"Node_for_agent_{unique_id}" node_id = f"Node_for_agent_{unique_id}"
assert node_id not in self.G.nodes
if node_id not in self.G.nodes: if node_id not in self.G.nodes:
self.G.add_node(node_id) self.G.add_node(node_id)
@ -417,7 +432,10 @@ class EventedEnvironment(BaseEnvironment):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._inbox = dict() self._inbox = dict()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._can_reschedule = hasattr(self.schedule, "add_callback") and hasattr(self.schedule, "remove_callback")
self._can_reschedule = True
self._callbacks = {}
def register(self, agent): def register(self, agent):
self._inbox[agent.unique_id] = [] self._inbox[agent.unique_id] = []
@ -429,24 +447,47 @@ class EventedEnvironment(BaseEnvironment):
"Make sure your agent is of type EventedAgent and it is registered with the environment.") "Make sure your agent is of type EventedAgent and it is registered with the environment.")
@coroutine @coroutine
def received(self, agent, expiration=None, timeout=60, delay=1): def _polling_callback(self, agent, expiration, delay):
if not expiration: # this wakes the agent up at every step. It is better to wait until timeout (or inf)
expiration = self.now + timeout # and if a message is received before that, reschedule the agent
# (That is implemented in the `received` method)
inbox = self.inbox_for(agent) inbox = self.inbox_for(agent)
if inbox:
return self.process_messages(inbox)
while self.now < expiration: while self.now < expiration:
# TODO: this wakes the agent up at every step. It would be better to wait until timeout (or inf)
# and if a message is received before that, reschedule the agent when
if inbox: if inbox:
return self.process_messages(inbox) return self.process_messages(inbox)
yield time.Delay(delay) yield time.Delay(delay)
raise events.TimedOut("No message received") raise events.TimedOut("No message received")
def tell(self, msg, sender, recipient, expiration=None, timeout=None, **kwargs): @coroutine
def received(self, agent, expiration=None, timeout=None, delay=1):
if not expiration:
if timeout:
expiration = self.now + timeout
else:
expiration = float("inf")
inbox = self.inbox_for(agent)
if inbox:
return self.process_messages(inbox)
if self._can_reschedule:
checked = False
def cb():
nonlocal checked
if checked:
return time.INFINITY
checked = True
self.schedule.add_callback(self.now, agent.step)
self.schedule.add_callback(expiration, cb)
self._callbacks[agent.unique_id] = cb
yield time.INFINITY
res = yield from self._polling_callback(agent, expiration, delay)
return res
def tell(self, msg, recipient, sender=None, expiration=None, timeout=None, **kwargs):
if expiration is None: if expiration is None:
expiration = float("inf") if timeout is None else self.now + timeout expiration = float("inf") if timeout is None else self.now + timeout
self.inbox_for(recipient).append( self._add_to_inbox(recipient.unique_id,
events.Tell(timestamp=self.now, events.Tell(timestamp=self.now,
payload=msg, payload=msg,
sender=sender, sender=sender,
@ -463,18 +504,23 @@ class EventedEnvironment(BaseEnvironment):
if agent_class and not isinstance(self.agents(unique_id=agent_id), agent_class): if agent_class and not isinstance(self.agents(unique_id=agent_id), agent_class):
continue continue
self.logger.debug(f"Telling {agent_id}: {msg} ttl={ttl}") self.logger.debug(f"Telling {agent_id}: {msg} ttl={ttl}")
inbox.append( self._add_to_inbox(agent_id,
events.Tell( events.Tell(
payload=msg, payload=msg,
sender=sender, sender=sender,
expiration=expiration, expiration=expiration,
) )
) )
def _add_to_inbox(self, inbox_id, msg):
self._inbox[inbox_id].append(msg)
if inbox_id in self._callbacks:
cb = self._callbacks.pop(inbox_id)
cb()
@coroutine @coroutine
def ask(self, msg, recipient, sender=None, expiration=None, timeout=None, delay=1): def ask(self, msg, recipient, sender=None, expiration=None, timeout=None, delay=1):
ask = events.Ask(timestamp=self.now, payload=msg, sender=sender) ask = events.Ask(timestamp=self.now, payload=msg, sender=sender)
self.inbox_for(recipient).append(ask) self._add_to_inbox(recipient.unique_id, ask)
expiration = float("inf") if timeout is None else self.now + timeout expiration = float("inf") if timeout is None else self.now + timeout
while self.now < expiration: while self.now < expiration:
if ask.reply: if ask.reply:
@ -493,4 +539,4 @@ class EventedEnvironment(BaseEnvironment):
class Environment(EventedEnvironment, NetworkEnvironment): class Environment(EventedEnvironment, NetworkEnvironment):
pass pass

@ -75,6 +75,13 @@ class Exporter:
def iteration_end(self, env, params, params_id): def iteration_end(self, env, params, params_id):
"""Method to call when a iteration ends""" """Method to call when a iteration ends"""
pass pass
def env_id(self, env):
try:
return env.id
except AttributeError:
return f"{env.__class__.__name__}_{current_time()}"
def output(self, f, mode="w", **kwargs): def output(self, f, mode="w", **kwargs):
if not self.dump: if not self.dump:
@ -90,7 +97,7 @@ class Exporter:
def get_dfs(self, env, params_id, **kwargs): def get_dfs(self, env, params_id, **kwargs):
yield from get_dc_dfs(env.datacollector, yield from get_dc_dfs(env.datacollector,
params_id, params_id,
iteration_id=env.id, iteration_id=self.env_id(env),
**kwargs) **kwargs)
@ -157,11 +164,11 @@ class SQLite(Exporter):
return return
with timer( with timer(
"Dumping simulation {} iteration {}".format(self.simulation.name, env.id) "Dumping simulation {} iteration {}".format(self.simulation.name, self.env_id(env))
): ):
d = {"simulation_id": self.simulation.id, d = {"simulation_id": self.simulation.id,
"params_id": params_id, "params_id": params_id,
"iteration_id": env.id, "iteration_id": self.env_id(env),
} }
for (k,v) in params.items(): for (k,v) in params.items():
d[k] = serialize(v)[0] d[k] = serialize(v)[0]
@ -173,7 +180,7 @@ class SQLite(Exporter):
pd.DataFrame([{ pd.DataFrame([{
"simulation_id": self.simulation.id, "simulation_id": self.simulation.id,
"params_id": params_id, "params_id": params_id,
"iteration_id": env.id, "iteration_id": self.env_id(env),
}]).reset_index().to_sql("iterations", }]).reset_index().to_sql("iterations",
con=self.engine, con=self.engine,
if_exists="append", if_exists="append",
@ -191,11 +198,11 @@ class csv(Exporter):
def iteration_end(self, env, params, params_id, *args, **kwargs): def iteration_end(self, env, params, params_id, *args, **kwargs):
with timer( with timer(
"[CSV] Dumping simulation {} iteration {} @ dir {}".format( "[CSV] Dumping simulation {} iteration {} @ dir {}".format(
self.simulation.name, env.id, self.outdir self.simulation.name, self.env_id(env), self.outdir
) )
): ):
for (df_name, df) in self.get_dfs(env, params_id=params_id): for (df_name, df) in self.get_dfs(env, params_id=params_id):
with self.output("{}.{}.csv".format(env.id, df_name), mode="a") as f: with self.output("{}.{}.csv".format(self.env_id(env), df_name), mode="a") as f:
df.to_csv(f) df.to_csv(f)
@ -206,9 +213,9 @@ class gexf(Exporter):
return return
with timer( with timer(
"[GEXF] Dumping simulation {} iteration {}".format(self.simulation.name, env.id) "[GEXF] Dumping simulation {} iteration {}".format(self.simulation.name, self.env_id(env))
): ):
with self.output("{}.gexf".format(env.id), mode="wb") as f: with self.output("{}.gexf".format(self.env_id(env)), mode="wb") as f:
nx.write_gexf(env.G, f) nx.write_gexf(env.G, f)
@ -242,7 +249,7 @@ class graphdrawing(Exporter):
pos=nx.spring_layout(env.G, scale=100), pos=nx.spring_layout(env.G, scale=100),
ax=f.add_subplot(111), ax=f.add_subplot(111),
) )
with open("graph-{}.png".format(env.id)) as f: with open("graph-{}.png".format(self.env_id(env))) as f:
f.savefig(f) f.savefig(f)

@ -44,8 +44,8 @@ def do_not_run():
def _iter_queued(): def _iter_queued():
while _QUEUED: while _QUEUED:
(cls, params) = _QUEUED.pop(0) slf = _QUEUED.pop(0)
yield replace(cls, parameters=params) yield slf
# TODO: change documentation for simulation # TODO: change documentation for simulation
@ -130,11 +130,11 @@ class Simulation:
def run(self, **kwargs): def run(self, **kwargs):
"""Run the simulation and return the list of resulting environments""" """Run the simulation and return the list of resulting environments"""
if kwargs: if kwargs:
return replace(self, **kwargs).run() res = replace(self, **kwargs)
return res.run()
param_combinations = self._collect_params(**kwargs)
if _AVOID_RUNNING: if _AVOID_RUNNING:
_QUEUED.extend((self, param) for param in param_combinations) _QUEUED.append(self)
return [] return []
self.logger.debug("Using exporters: %s", self.exporters or []) self.logger.debug("Using exporters: %s", self.exporters or [])
@ -154,6 +154,8 @@ class Simulation:
for exporter in exporters: for exporter in exporters:
exporter.sim_start() exporter.sim_start()
param_combinations = self._collect_params(**kwargs)
for params in tqdm(param_combinations, desc=self.name, unit="configuration"): for params in tqdm(param_combinations, desc=self.name, unit="configuration"):
for (k, v) in params.items(): for (k, v) in params.items():
tqdm.write(f"{k} = {v}") tqdm.write(f"{k} = {v}")
@ -204,6 +206,7 @@ class Simulation:
for env in tqdm(utils.run_parallel( for env in tqdm(utils.run_parallel(
func=func, func=func,
iterable=range(self.iterations), iterable=range(self.iterations),
num_processes=self.num_processes,
**params, **params,
), total=self.iterations, leave=False): ), total=self.iterations, leave=False):
if env is None and self.dry_run: if env is None and self.dry_run:
@ -338,12 +341,13 @@ def iter_from_py(pyfile, module_name='imported_file', **kwargs):
sims.append(sim) sims.append(sim)
for sim in _iter_queued(): for sim in _iter_queued():
sims.append(sim) sims.append(sim)
# Try to find environments to run, because we did not import a script that ran simulations
if not sims: if not sims:
for (_name, env) in inspect.getmembers(module, for (_name, env) in inspect.getmembers(module,
lambda x: inspect.isclass(x) and lambda x: inspect.isclass(x) and
issubclass(x, environment.Environment) and issubclass(x, environment.Environment) and
(getattr(x, "__module__", None) != environment.__name__)): (getattr(x, "__module__", None) != environment.__name__)):
sims.append(Simulation(model=env, **kwargs)) sims.append(Simulation(model=env, **kwargs))
del sys.modules[module_name] del sys.modules[module_name]
assert not _AVOID_RUNNING assert not _AVOID_RUNNING
if not sims: if not sims:

@ -24,7 +24,10 @@ class Delay:
def __float__(self): def __float__(self):
return self.delta return self.delta
def __eq__(self, other):
return float(self) == float(other)
def __await__(self): def __await__(self):
return (yield self.delta) return (yield self.delta)
@ -87,6 +90,9 @@ class PQueueSchedule:
del self._queue[i] del self._queue[i]
break break
def __len__(self):
return len(self._queue)
def step(self) -> None: def step(self) -> None:
""" """
Executes events in order, one at a time. After each step, Executes events in order, one at a time. After each step,
@ -107,7 +113,8 @@ class PQueueSchedule:
next_time = when next_time = when
break break
when = event.func() or 1 when = event.func()
when = float(when) if when is not None else 1.0
if when == INFINITY: if when == INFINITY:
heappop(self._queue) heappop(self._queue)
@ -153,12 +160,18 @@ class Schedule:
return lst return lst
def insert(self, when, func, replace=False): def insert(self, when, func, replace=False):
if when == INFINITY:
return
lst = self._find_loc(when) lst = self._find_loc(when)
lst.append(func) lst.append(func)
def add_bulk(self, funcs, when=None): def add_bulk(self, funcs, when=None):
lst = self._find_loc(when) lst = self._find_loc(when)
n = len(funcs)
#TODO: remove for performance
before = len(self)
lst.extend(funcs) lst.extend(funcs)
assert len(self) == before + n
def remove(self, func): def remove(self, func):
for bucket in self._queue: for bucket in self._queue:
@ -167,6 +180,9 @@ class Schedule:
bucket.remove(ix) bucket.remove(ix)
return return
def __len__(self):
return sum(len(bucket[1]) for bucket in self._queue)
def step(self) -> None: def step(self) -> None:
""" """
Executes events in order, one at a time. After each step, Executes events in order, one at a time. After each step,
@ -188,11 +204,14 @@ class Schedule:
self.random.shuffle(bucket) self.random.shuffle(bucket)
next_batch = defaultdict(list) next_batch = defaultdict(list)
for func in bucket: for func in bucket:
when = func() or 1 when = func()
when = float(when) if when is not None else 1
if when == INFINITY:
continue
if when != INFINITY: when += now
when += now next_batch[when].append(func)
next_batch[when].append(func)
for (when, bucket) in next_batch.items(): for (when, bucket) in next_batch.items():
self.add_bulk(bucket, when) self.add_bulk(bucket, when)
@ -229,6 +248,12 @@ class InnerActivation(BaseScheduler):
self.agents_by_type[agent_class][agent.unique_id] = agent self.agents_by_type[agent_class][agent.unique_id] = agent
super().add(agent) super().add(agent)
def add_callback(self, when, cb):
self.inner.insert(when, cb)
def remove_callback(self, when, cb):
self.inner.remove(cb)
def remove(self, agent): def remove(self, agent):
del self._agents[agent.unique_id] del self._agents[agent.unique_id]
del self.agents_by_type[type(agent)][agent.unique_id] del self.agents_by_type[type(agent)][agent.unique_id]
@ -241,6 +266,9 @@ class InnerActivation(BaseScheduler):
""" """
self.inner.step() self.inner.step()
def __len__(self):
return len(self.inner)
class BucketTimedActivation(InnerActivation): class BucketTimedActivation(InnerActivation):
inner_class = Schedule inner_class = Schedule
@ -250,16 +278,19 @@ class PQueueActivation(InnerActivation):
inner_class = PQueueSchedule inner_class = PQueueSchedule
# Set the bucket implementation as default #Set the bucket implementation as default
TimedActivation = BucketTimedActivation
try: try:
from soilent.soilent import BucketScheduler from soilent.soilent import BucketScheduler, PQueueScheduler
class SoilBucketActivation(InnerActivation): class SoilentActivation(InnerActivation):
inner_class = BucketScheduler inner_class = BucketScheduler
class SoilentPQueueActivation(InnerActivation):
inner_class = PQueueScheduler
TimedActivation = SoilBucketActivation # TimedActivation = SoilentBucketActivation
except ImportError: except ImportError:
TimedActivation = BucketTimedActivation
pass pass

@ -93,15 +93,12 @@ def flatten_dict(d):
def _flatten_dict(d, prefix=""): def _flatten_dict(d, prefix=""):
if not isinstance(d, dict): if not isinstance(d, dict):
# print('END:', prefix, d)
yield prefix, d yield prefix, d
return return
if prefix: if prefix:
prefix = prefix + "." prefix = prefix + "."
for k, v in d.items(): for k, v in d.items():
# print(k, v)
res = list(_flatten_dict(v, prefix="{}{}".format(prefix, k))) res = list(_flatten_dict(v, prefix="{}{}".format(prefix, k)))
# print('RES:', res)
yield from res yield from res
@ -142,6 +139,7 @@ def run_and_return_exceptions(func, *args, **kwargs):
def run_parallel(func, iterable, num_processes=1, **kwargs): def run_parallel(func, iterable, num_processes=1, **kwargs):
if num_processes > 1 and not os.environ.get("SOIL_DEBUG", None): if num_processes > 1 and not os.environ.get("SOIL_DEBUG", None):
logger.info("Running simulations in {} processes".format(num_processes))
if num_processes < 1: if num_processes < 1:
num_processes = cpu_count() - num_processes num_processes = cpu_count() - num_processes
p = Pool(processes=num_processes) p = Pool(processes=num_processes)

@ -1,7 +1,7 @@
from unittest import TestCase from unittest import TestCase
import pytest import pytest
from soil import agents, environment from soil import agents, events, environment
from soil import time as stime from soil import time as stime
@ -25,7 +25,7 @@ class TestAgents(TestCase):
assert d.alive assert d.alive
d.step() d.step()
assert not d.alive assert not d.alive
when = d.step() when = float(d.step())
assert not d.alive assert not d.alive
assert when == stime.INFINITY assert when == stime.INFINITY
@ -63,6 +63,7 @@ class TestAgents(TestCase):
def other(self): def other(self):
self.times_run += 1 self.times_run += 1
assert MyAgent.other.id == "other"
e = environment.Environment() e = environment.Environment()
a = e.add_agent(MyAgent) a = e.add_agent(MyAgent)
e.step() e.step()
@ -73,6 +74,53 @@ class TestAgents(TestCase):
a.step() a.step()
assert a.times_run == 2 assert a.times_run == 2
def test_state_decorator_multiple(self):
class MyAgent(agents.FSM):
times_run = 0
@agents.state(default=True)
def one(self):
return self.two
@agents.state
def two(self):
return self.one
e = environment.Environment()
first = e.add_agent(MyAgent, state_id=MyAgent.one)
second = e.add_agent(MyAgent, state_id=MyAgent.two)
assert first.state_id == MyAgent.one.id
assert second.state_id == MyAgent.two.id
e.step()
assert first.state_id == MyAgent.two.id
assert second.state_id == MyAgent.one.id
def test_state_decorator_multiple_async(self):
class MyAgent(agents.FSM):
times_run = 0
@agents.state(default=True)
def one(self):
yield self.delay(1)
return self.two
@agents.state
def two(self):
yield self.delay(1)
return self.one
e = environment.Environment()
first = e.add_agent(MyAgent, state_id=MyAgent.one)
second = e.add_agent(MyAgent, state_id=MyAgent.two)
for i in range(2):
assert first.state_id == MyAgent.one.id
assert second.state_id == MyAgent.two.id
e.step()
for i in range(2):
assert first.state_id == MyAgent.two.id
assert second.state_id == MyAgent.one.id
e.step()
def test_broadcast(self): def test_broadcast(self):
""" """
An agent should be able to broadcast messages to every other agent, AND each receiver should be able An agent should be able to broadcast messages to every other agent, AND each receiver should be able
@ -372,22 +420,105 @@ class TestAgents(TestCase):
assert a.now == 17 assert a.now == 17
assert a.my_state == 5 assert a.my_state == 5
def test_send_nonevent(self): def test_receive(self):
'''
An agent should be able to receive a message after waiting
'''
model = environment.Environment()
class TestAgent(agents.Agent):
sent = False
woken = 0
def step(self):
self.woken += 1
return super().step()
@agents.state(default=True)
async def one(self):
try:
self.sent = await self.received(timeout=15)
return self.two.at(20)
except events.TimedOut:
pass
@agents.state
def two(self):
return self.die()
a = model.add_agent(TestAgent)
class Sender(agents.Agent):
async def step(self):
await self.delay(10)
a.tell(1)
return stime.INFINITY
b = model.add_agent(Sender)
# Start and wait
model.step()
assert model.now == 10
assert a.woken == 1
assert not a.sent
# Sending the message
model.step()
assert model.now == 10
assert a.woken == 1
assert not a.sent
# The receiver callback
model.step()
assert model.now == 15
assert a.woken == 2
assert a.sent[0].payload == 1
# The timeout
model.step()
assert model.now == 20
assert a.woken == 2
# The last state of the agent
model.step()
assert a.woken == 3
assert model.now == float('inf')
def test_receive_timeout(self):
''' '''
Sending a non-event should raise an error. A timeout should be raised if no messages are received after an expiration time
''' '''
model = environment.Environment() model = environment.Environment()
a = model.add_agent(agents.Noop) timedout = False
class TestAgent(agents.Agent): class TestAgent(agents.Agent):
@agents.state(default=True) @agents.state(default=True)
def one(self): def one(self):
try: try:
a.tell(b, 1) yield from self.received(timeout=10)
raise AssertionError('Should have raised an error.') raise AssertionError('Should have raised an error.')
except AttributeError: except events.TimedOut:
self.model.tell(1, sender=self, recipient=a) nonlocal timedout
timedout = True
model.add_agent(TestAgent) a = model.add_agent(TestAgent)
with pytest.raises(ValueError): model.step()
model.step() assert model.now == 10
model.step()
# Wake up the callback
assert model.now == 10
assert not timedout
# The actual timeout
model.step()
assert model.now == 11
assert timedout
def test_attributes(self):
"""Attributes should be individual per agent"""
class MyAgent(agents.Agent):
my_attribute = 0
model = environment.Environment()
a = MyAgent(model=model)
assert a.my_attribute == 0
b = MyAgent(model=model, my_attribute=1)
assert b.my_attribute == 1
assert a.my_attribute == 0

@ -6,7 +6,7 @@ import networkx as nx
from functools import partial from functools import partial
from os.path import join from os.path import join
from soil import simulation, Environment, agents, network, serialization, utils, config, from_file from soil import simulation, Environment, agents, serialization, from_file, time
from mesa import Agent as MesaAgent from mesa import Agent as MesaAgent
ROOT = os.path.abspath(os.path.dirname(__file__)) ROOT = os.path.abspath(os.path.dirname(__file__))
@ -194,7 +194,7 @@ class TestMain(TestCase):
return self.ping return self.ping
a = ToggleAgent(unique_id=1, model=Environment()) a = ToggleAgent(unique_id=1, model=Environment())
when = a.step() when = float(a.step())
assert when == 2 assert when == 2
when = a.step() when = a.step()
assert when == None assert when == None
@ -252,4 +252,34 @@ class TestMain(TestCase):
assert df["base"][(0,0)] == "base" assert df["base"][(0,0)] == "base"
assert df["base"][(0,1)] == "base" assert df["base"][(0,1)] == "base"
assert df["subclass"][(0,0)] is None assert df["subclass"][(0,0)] is None
assert df["subclass"][(0,1)] == "subclass" assert df["subclass"][(0,1)] == "subclass"
def test_remove_agent(self):
"""An agent that is scheduled should be removed from the schedule"""
model = Environment()
model.add_agent(agents.Noop)
model.step()
model.remove_agent(model.agents[0])
assert not model.agents
when = model.step()
assert when == None
assert not model.running
def test_remove_agent(self):
"""An agent that is scheduled should be removed from the schedule"""
allagents = []
class Removed(agents.BaseAgent):
def step(self):
nonlocal allagents
assert self.alive
assert self in self.model.agents
for agent in allagents:
self.model.remove_agent(agent)
model = Environment()
a1 = model.add_agent(Removed)
a2 = model.add_agent(Removed)
allagents = [a1, a2]
model.step()
assert not model.agents
Loading…
Cancel
Save