Compare commits

...

6 Commits

Author SHA1 Message Date
J. Fernando Sánchez d3cee18635 Add seed to cars example 2 years ago
J. Fernando Sánchez 9a7b62e88e Release 0.30.0rc3 2 years ago
J. Fernando Sánchez c09e480d37 black formatting 2 years ago
J. Fernando Sánchez b2d48cb4df Add test cases for 'ASK' 2 years ago
J. Fernando Sánchez a1262edd2a Refactored time
Treating time and conditions as the same entity was getting confusing, and it
added a lot of unnecessary abstraction in a critical part (the scheduler).

The scheduling queue now has the time as a floating number (faster), the agent
id (for ties) and the condition, as well as the agent. The first three
elements (time, id, condition) can be considered as the "key" for the event.

To allow for agent execution to be "randomized" within every step, a new
parameter has been added to the scheduler, which makes it add a random number to
the key in order to change the ordering.

`EventedAgent.received` now checks the messages before returning control to the
user by default.
2 years ago
J. Fernando Sánchez cbbaf73538 Fix bug EventedEnvironment 2 years ago

@ -127,7 +127,8 @@ class Driver(Evented, FSM):
) )
self.check_passengers() self.check_passengers()
self.check_messages() # This will call on_receive behind the scenes, and the agent's status will be updated # This will call on_receive behind the scenes, and the agent's status will be updated
self.check_messages()
yield Delta(30) # Wait at least 30 seconds before checking again yield Delta(30) # Wait at least 30 seconds before checking again
try: try:
@ -204,6 +205,8 @@ class Passenger(Evented, FSM):
while not self.journey: while not self.journey:
self.info(f"Passenger at: { self.pos }. Checking for responses.") self.info(f"Passenger at: { self.pos }. Checking for responses.")
try: try:
# This will call check_messages behind the scenes, and the agent's status will be updated
# If you want to avoid that, you can call it with: check=False
yield self.received(expiration=expiration) yield self.received(expiration=expiration)
except events.TimedOut: except events.TimedOut:
self.info(f"Passenger at: { self.pos }. Asking for journey.") self.info(f"Passenger at: { self.pos }. Asking for journey.")
@ -211,7 +214,6 @@ class Passenger(Evented, FSM):
journey, ttl=timeout, sender=self, agent_class=Driver journey, ttl=timeout, sender=self, agent_class=Driver
) )
expiration = self.now + timeout expiration = self.now + timeout
self.check_messages()
return self.driving_home return self.driving_home
@state @state
@ -220,13 +222,20 @@ class Passenger(Evented, FSM):
self.pos[0] != self.journey.destination[0] self.pos[0] != self.journey.destination[0]
or self.pos[1] != self.journey.destination[1] or self.pos[1] != self.journey.destination[1]
): ):
yield self.received(timeout=60) try:
yield self.received(timeout=60)
except events.TimedOut:
pass
self.info("Got home safe!") self.info("Got home safe!")
self.die() self.die()
simulation = Simulation( simulation = Simulation(
name="RideHailing", model_class=City, model_params={"n_passengers": 2} name="RideHailing",
model_class=City,
model_params={"n_passengers": 2},
seed="carsSeed",
) )
if __name__ == "__main__": if __name__ == "__main__":

@ -133,7 +133,7 @@ class RandomAccident(BaseAgent):
math.log10(max(1, rabbits_alive)) math.log10(max(1, rabbits_alive))
) )
self.debug("Killing some rabbits with prob={}!".format(prob_death)) self.debug("Killing some rabbits with prob={}!".format(prob_death))
for i in self.iter_agents(agent_class=Rabbit): for i in self.get_agents(agent_class=Rabbit):
if i.state_id == i.dead.id: if i.state_id == i.dead.id:
continue continue
if self.prob(prob_death): if self.prob(prob_death):

@ -258,9 +258,7 @@ class TerroristNetworkModel(TerroristSpreadModel):
) )
neighbours = set( neighbours = set(
agent.id agent.id
for agent in self.get_neighbors( for agent in self.get_neighbors(agent_class=TerroristNetworkModel)
agent_class=TerroristNetworkModel
)
) )
search = (close_ups | step_neighbours) - neighbours search = (close_ups | step_neighbours) - neighbours
for agent in self.get_agents(search): for agent in self.get_agents(search):

@ -1 +1 @@
0.30.0rc2 0.30.0rc3

@ -47,7 +47,7 @@ def main(
"file", "file",
type=str, type=str,
nargs="?", nargs="?",
default=cfg if sim is None else '', default=cfg if sim is None else "",
help="Configuration file for the simulation (e.g., YAML or JSON)", help="Configuration file for the simulation (e.g., YAML or JSON)",
) )
parser.add_argument( parser.add_argument(
@ -169,22 +169,26 @@ def main(
sim.exporters = exporters sim.exporters = exporters
sim.parallel = parallel sim.parallel = parallel
sim.outdir = output sim.outdir = output
sims = [sim, ] sims = [
sim,
]
else: else:
logger.info("Loading config file: {}".format(args.file)) logger.info("Loading config file: {}".format(args.file))
if not os.path.exists(args.file): if not os.path.exists(args.file):
logger.error("Please, input a valid file") logger.error("Please, input a valid file")
return return
sims = list(simulation.iter_from_config( sims = list(
args.file, simulation.iter_from_config(
dry_run=args.dry_run, args.file,
exporters=exporters, dry_run=args.dry_run,
parallel=parallel, exporters=exporters,
outdir=output, parallel=parallel,
exporter_params=exp_params, outdir=output,
**kwargs, exporter_params=exp_params,
)) **kwargs,
)
)
for sim in sims: for sim in sims:

@ -20,12 +20,6 @@ from typing import Dict, List
from .. import serialization, utils, time, config from .. import serialization, utils, time, config
def as_node(agent):
if isinstance(agent, BaseAgent):
return agent.id
return agent
IGNORED_FIELDS = ("model", "logger") IGNORED_FIELDS = ("model", "logger")
@ -97,10 +91,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
""" """
def __init__(self, unique_id, model, name=None, interval=None, **kwargs): def __init__(self, unique_id, model, name=None, interval=None, **kwargs):
# Check for REQUIRED arguments
# Initialize agent parameters
if isinstance(unique_id, MesaAgent):
raise Exception()
assert isinstance(unique_id, int) assert isinstance(unique_id, int)
super().__init__(unique_id=unique_id, model=model) super().__init__(unique_id=unique_id, model=model)
@ -207,7 +197,8 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
def step(self): def step(self):
if not self.alive: if not self.alive:
raise time.DeadAgent(self.unique_id) raise time.DeadAgent(self.unique_id)
return super().step() or time.Delta(self.interval) super().step()
return time.Delta(self.interval)
def log(self, message, *args, level=logging.INFO, **kwargs): def log(self, message, *args, level=logging.INFO, **kwargs):
if not self.logger.isEnabledFor(level): if not self.logger.isEnabledFor(level):
@ -414,7 +405,7 @@ def filter_agents(
if ids: if ids:
f = (agents[aid] for aid in ids if aid in agents) f = (agents[aid] for aid in ids if aid in agents)
else: else:
f = (a for a in agents.values()) f = agents.values()
if state_id is not None and not isinstance(state_id, (tuple, list)): if state_id is not None and not isinstance(state_id, (tuple, list)):
state_id = tuple([state_id]) state_id = tuple([state_id])
@ -638,6 +629,11 @@ from .SentimentCorrelationModel import *
from .SISaModel import * from .SISaModel import *
from .CounterModel import * from .CounterModel import *
class Agent(NetworkAgent, EventedAgent):
"""Default agent class, has both network and event capabilities"""
try: try:
import scipy import scipy
from .Geo import Geo from .Geo import Geo

@ -1,57 +1,77 @@
from . import BaseAgent from . import BaseAgent
from ..events import Message, Tell, Ask, Reply, TimedOut from ..events import Message, Tell, Ask, TimedOut
from ..time import Cond from ..time import BaseCond
from functools import partial from functools import partial
from collections import deque from collections import deque
class Evented(BaseAgent): class ReceivedOrTimeout(BaseCond):
def __init__(
self, agent, expiration=None, timeout=None, check=True, ignore=False, **kwargs
):
if expiration is None:
if timeout is not None:
expiration = agent.now + timeout
self.expiration = expiration
self.ignore = ignore
self.check = check
super().__init__(**kwargs)
def expired(self, time):
return self.expiration and self.expiration < time
def ready(self, agent, time):
return len(agent._inbox) or self.expired(time)
def return_value(self, agent):
if not self.ignore and self.expired(agent.now):
raise TimedOut("No messages received")
if self.check:
agent.check_messages()
return None
def schedule_next(self, time, delta, first=False):
if self._delta is not None:
delta = self._delta
return (time + delta, self)
def __repr__(self):
return f"ReceivedOrTimeout(expires={self.expiration})"
class EventedAgent(BaseAgent):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._inbox = deque() self._inbox = deque()
self._received = 0
self._processed = 0 self._processed = 0
def on_receive(self, *args, **kwargs): def on_receive(self, *args, **kwargs):
pass pass
def received(self, expiration=None, timeout=None): def received(self, *args, **kwargs):
current = self._received return ReceivedOrTimeout(self, *args, **kwargs)
if expiration is None:
expiration = float('inf') if timeout is None else self.now + timeout
if expiration < self.now:
raise ValueError("Invalid expiration time")
def ready(agent):
return agent._received > current or agent.now >= expiration
def value(agent): def tell(self, msg, sender=None):
if agent.now > expiration: self._inbox.append(Tell(timestamp=self.now, payload=msg, sender=sender))
raise TimedOut("No message received")
c = Cond(func=ready, return_func=value) def ask(self, msg, timeout=None, **kwargs):
c._checked = True ask = Ask(timestamp=self.now, payload=msg, sender=self)
return c
def tell(self, msg, sender):
self._received += 1
self._inbox.append(Tell(payload=msg, sender=sender))
def ask(self, msg, timeout=None):
self._received += 1
ask = Ask(payload=msg)
self._inbox.append(ask) self._inbox.append(ask)
expiration = float('inf') if timeout is None else self.now + timeout expiration = float("inf") if timeout is None else self.now + timeout
return ask.replied(expiration=expiration) return ask.replied(expiration=expiration, **kwargs)
def check_messages(self): def check_messages(self):
changed = False
while self._inbox: while self._inbox:
msg = self._inbox.popleft() msg = self._inbox.popleft()
self._processed += 1 self._processed += 1
if msg.expired(self.now): if msg.expired(self.now):
continue continue
changed = True
reply = self.on_receive(msg.payload, sender=msg.sender) reply = self.on_receive(msg.payload, sender=msg.sender)
if isinstance(msg, Ask): if isinstance(msg, Ask):
msg.reply = reply msg.reply = reply
return changed
Evented = EventedAgent

@ -38,8 +38,6 @@ def state(name=None):
self._last_return = None self._last_return = None
self._last_except = None self._last_except = None
func.id = name or func.__name__ func.id = name or func.__name__
func.is_default = False func.is_default = False
return func return func

@ -54,7 +54,7 @@ class NetworkAgent(BaseAgent):
return G return G
def remove_node(self): def remove_node(self):
print(f"Removing node for {self.unique_id}: {self.node_id}") self.debug(f"Removing node for {self.unique_id}: {self.node_id}")
self.G.remove_node(self.node_id) self.G.remove_node(self.node_id)
self.node_id = None self.node_id = None
@ -80,3 +80,6 @@ class NetworkAgent(BaseAgent):
if remove: if remove:
self.remove_node() self.remove_node()
return super().die() return super().die()
NetAgent = NetworkAgent

@ -38,7 +38,7 @@ class BaseEnvironment(Model):
self, self,
id="unnamed_env", id="unnamed_env",
seed="default", seed="default",
schedule=None, schedule_class=time.TimedActivation,
dir_path=None, dir_path=None,
interval=1, interval=1,
agent_class=None, agent_class=None,
@ -58,9 +58,11 @@ class BaseEnvironment(Model):
self.dir_path = dir_path or os.getcwd() self.dir_path = dir_path or os.getcwd()
if schedule is None: if schedule_class is None:
schedule = time.TimedActivation(self) schedule_class = time.TimedActivation
self.schedule = schedule else:
schedule_class = serialization.deserialize(schedule_class)
self.schedule = schedule_class(self)
self.agent_class = agent_class or agentmod.BaseAgent self.agent_class = agent_class or agentmod.BaseAgent
@ -310,15 +312,28 @@ class NetworkEnvironment(BaseEnvironment):
self.add_agent(node_id=node_id, agent_class=a_class, **agent_params) self.add_agent(node_id=node_id, agent_class=a_class, **agent_params)
Environment = NetworkEnvironment class EventedEnvironment(BaseEnvironment):
def broadcast(self, msg, sender=None, expiration=None, ttl=None, **kwargs):
class EventedEnvironment(Environment):
def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs):
for agent in self.agents(**kwargs): for agent in self.agents(**kwargs):
self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}') if agent == sender:
continue
self.logger.info(f"Telling {repr(agent)}: {msg} ttl={ttl}")
try: try:
agent._inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl)) inbox = agent._inbox
except AttributeError: except AttributeError:
self.info(f'Agent {agent.unique_id} cannot receive events') self.logger.info(
f"Agent {agent.unique_id} cannot receive events because it does not have an inbox"
)
continue
# Allow for AttributeError exceptions in this part of the code
inbox.append(
events.Tell(
payload=msg,
sender=sender,
expiration=expiration if ttl is None else self.now + ttl,
)
)
class Environment(NetworkEnvironment, EventedEnvironment):
"""Default environment class, has both network and event capabilities"""

@ -1,38 +1,51 @@
from .time import Cond from .time import BaseCond
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from uuid import uuid4 from uuid import uuid4
class Event: class Event:
pass pass
@dataclass @dataclass
class Message: class Message:
payload: Any payload: Any
sender: Any = None sender: Any = None
expiration: float = None expiration: float = None
timestamp: float = None
id: int = field(default_factory=uuid4) id: int = field(default_factory=uuid4)
def expired(self, when): def expired(self, when):
return self.expiration is not None and self.expiration < when return self.expiration is not None and self.expiration < when
class Reply(Message): class Reply(Message):
source: Message source: Message
class ReplyCond(BaseCond):
def __init__(self, ask, *args, **kwargs):
self._ask = ask
super().__init__(*args, **kwargs)
def ready(self, agent, time):
return self._ask.reply is not None or self._ask.expired(time)
def return_value(self, agent):
if self._ask.expired(agent.now):
raise TimedOut()
return self._ask.reply
def __repr__(self):
return f"ReplyCond({self._ask.id})"
class Ask(Message): class Ask(Message):
reply: Message = None reply: Message = None
def replied(self, expiration=None): def replied(self, expiration=None):
def ready(agent): return ReplyCond(self)
return self.reply is not None or agent.now > expiration
def value(agent):
if agent.now > expiration:
raise TimedOut(f'No answer received for {self}')
return self.reply
return Cond(func=ready, return_func=value)
class Tell(Message): class Tell(Message):

@ -59,7 +59,6 @@ def find_unassigned(G, shuffle=False, random=random):
If node_id is None, a node without an agent_id will be found. If node_id is None, a node without an agent_id will be found.
""" """
# TODO: test
candidates = list(G.nodes(data=True)) candidates = list(G.nodes(data=True))
if shuffle: if shuffle:
random.shuffle(candidates) random.shuffle(candidates)

@ -221,8 +221,6 @@ def deserialize(type_, value=None, globs=None, **kwargs):
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs): def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
"""Return the list of deserialized objects""" """Return the list of deserialized objects"""
# TODO: remove
print("SERIALIZATION", kwargs)
objects = [] objects = []
for name in names: for name in names:
mod = deserialize(name, known_modules=known_modules) mod = deserialize(name, known_modules=known_modules)

@ -66,7 +66,7 @@ class Simulation:
if ignored: if ignored:
d.setdefault("extra", {}).update(ignored) d.setdefault("extra", {}).update(ignored)
if ignored: if ignored:
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }') logger.warning(f'Ignoring these parameters (added to "extra"): { ignored }')
d.update(kwargs) d.update(kwargs)
return cls(**d) return cls(**d)

@ -1,10 +1,11 @@
from mesa.time import BaseScheduler from mesa.time import BaseScheduler
from queue import Empty from queue import Empty
from heapq import heappush, heappop, heapify from heapq import heappush, heappop
import math import math
from inspect import getsource from inspect import getsource
from numbers import Number from numbers import Number
from textwrap import dedent
from .utils import logger from .utils import logger
from mesa import Agent as MesaAgent from mesa import Agent as MesaAgent
@ -23,87 +24,67 @@ class When:
return time return time
self._time = time self._time = time
def next(self, time):
return self._time
def abs(self, time): def abs(self, time):
return self return self._time
def __repr__(self):
return str(f"When({self._time})")
def __lt__(self, other):
if isinstance(other, Number):
return self._time < other
return self._time < other.next(self._time)
def __gt__(self, other): def schedule_next(self, time, delta, first=False):
if isinstance(other, Number): return (self._time, None)
return self._time > other
return self._time > other.next(self._time)
def ready(self, agent):
return self._time <= agent.model.schedule.time
def return_value(self, agent): NEVER = When(INFINITY)
return None
class Cond(When): class Delta(When):
def __init__(self, func, delta=1, return_func=lambda agent: None): def __init__(self, delta):
self._func = func
self._delta = delta self._delta = delta
self._checked = False
self._return_func = return_func
def next(self, time):
if self._checked:
return time + self._delta
return time
def abs(self, time): def abs(self, time):
return self return self._time + self._delta
def ready(self, agent):
self._checked = True
return self._func(agent)
def return_value(self, agent):
return self._return_func(agent)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Delta):
return self._delta == other._delta
return False return False
def __lt__(self, other): def schedule_next(self, time, delta, first=False):
return True return (time + self._delta, None)
def __gt__(self, other):
return False
def __repr__(self): def __repr__(self):
return str(f'Cond("{getsource(self._func)}")') return str(f"Delta({self._delta})")
NEVER = When(INFINITY) class BaseCond:
def __init__(self, msg=None, delta=None, eager=False):
self._msg = msg
self._delta = delta
self.eager = eager
def schedule_next(self, time, delta, first=False):
if first and self.eager:
return (time, self)
if self._delta:
delta = self._delta
return (time + delta, self)
class Delta(When): def return_value(self, agent):
def __init__(self, delta): return None
self._delta = delta
def __eq__(self, other): def __repr__(self):
if isinstance(other, Delta): return self._msg or self.__class__.__name__
return self._delta == other._delta
return False
def abs(self, time):
return When(self._delta + time)
def next(self, time): class Cond(BaseCond):
return time + self._delta def __init__(self, func, *args, **kwargs):
self._func = func
super().__init__(*args, **kwargs)
def ready(self, agent, time):
return self._func(agent)
def __repr__(self): def __repr__(self):
return str(f"Delta({self._delta})") if self._msg:
return self._msg
return str(f'Cond("{dedent(getsource(self._func)).strip()}")')
class TimedActivation(BaseScheduler): class TimedActivation(BaseScheduler):
@ -111,28 +92,40 @@ class TimedActivation(BaseScheduler):
In each activation, each agent will update its 'next_time'. In each activation, each agent will update its 'next_time'.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, shuffle=True, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._next = {} self._next = {}
self._queue = [] self._queue = []
self.next_time = 0 self._shuffle = shuffle
self.step_interval = 1
self.logger = logger.getChild(f"time_{ self.model }") self.logger = logger.getChild(f"time_{ self.model }")
def add(self, agent: MesaAgent, when=None): def add(self, agent: MesaAgent, when=None):
if when is None: if when is None:
when = When(self.time) when = self.time
elif not isinstance(when, When): elif isinstance(when, When):
when = When(when) when = when.abs()
if agent.unique_id in self._agents:
del self._agents[agent.unique_id] self._schedule(agent, None, when)
if agent.unique_id in self._next:
self._queue.remove((self._next[agent.unique_id], agent))
heapify(self._queue)
self._next[agent.unique_id] = when
heappush(self._queue, (when, agent))
super().add(agent) super().add(agent)
def _schedule(self, agent, condition=None, when=None):
if condition:
if not when:
when, condition = condition.schedule_next(
when or self.time, self.step_interval
)
else:
if when is None:
when = self.time + self.step_interval
condition = None
if self._shuffle:
key = (when, self.model.random.random(), condition)
else:
key = (when, agent.unique_id, condition)
self._next[agent.unique_id] = key
heappush(self._queue, (key, agent))
def step(self) -> None: def step(self) -> None:
""" """
Executes agents in order, one at a time. After each step, Executes agents in order, one at a time. After each step,
@ -143,73 +136,71 @@ class TimedActivation(BaseScheduler):
if not self.model.running: if not self.model.running:
return return
when = NEVER
to_process = []
skipped = []
next_time = INFINITY
ix = 0
self.logger.debug(f"Queue length: {len(self._queue)}") self.logger.debug(f"Queue length: {len(self._queue)}")
while self._queue: while self._queue:
(when, agent) = self._queue[0] ((when, _id, cond), agent) = self._queue[0]
if when > self.time: if when > self.time:
break break
heappop(self._queue) heappop(self._queue)
if when.ready(agent): if cond:
if not cond.ready(agent, self.time):
self._schedule(agent, cond)
continue
try: try:
agent._last_return = when.return_value(agent) agent._last_return = cond.return_value(agent)
except Exception as ex: except Exception as ex:
agent._last_except = ex agent._last_except = ex
else:
agent._last_return = None
agent._last_except = None
self._next.pop(agent.unique_id, None)
to_process.append(agent)
continue
next_time = min(next_time, when.next(self.time))
self._next[agent.unique_id] = next_time
skipped.append((when, agent))
if self._queue:
next_time = min(next_time, self._queue[0][0].next(self.time))
self._queue = [*skipped, *self._queue]
for agent in to_process:
self.logger.debug(f"Stepping agent {agent}") self.logger.debug(f"Stepping agent {agent}")
self._next.pop(agent.unique_id, None)
try: try:
returned = ((agent.step() or Delta(1))).abs(self.time) returned = agent.step()
except DeadAgent: except DeadAgent:
if agent.unique_id in self._next:
del self._next[agent.unique_id]
agent.alive = False agent.alive = False
continue continue
# Check status for MESA agents
if not getattr(agent, "alive", True): if not getattr(agent, "alive", True):
continue continue
value = returned.next(self.time) if returned:
agent._last_return = value next_check = returned.schedule_next(
self.time, self.step_interval, first=True
if value < self.time:
raise Exception(
f"Cannot schedule an agent for a time in the past ({when} < {self.time})"
) )
if value < INFINITY: self._schedule(agent, when=next_check[0], condition=next_check[1])
next_time = min(value, next_time)
self._next[agent.unique_id] = returned
heappush(self._queue, (returned, agent))
else: else:
assert not self._next[agent.unique_id] next_check = (self.time + self.step_interval, None)
self._schedule(agent)
self.steps += 1 self.steps += 1
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
self.time = next_time
if not self._queue or next_time == INFINITY: if not self._queue:
self.time = INFINITY
self.model.running = False self.model.running = False
return self.time return self.time
next_time = self._queue[0][0][0]
if next_time < self.time:
raise Exception(
f"An agent has been scheduled for a time in the past, there is probably an error ({when} < {self.time})"
)
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
self.time = next_time
class ShuffledTimedActivation(TimedActivation):
def __init__(self, *args, **kwargs):
super().__init__(*args, shuffle=True, **kwargs)
class OrderedTimedActivation(TimedActivation):
def __init__(self, *args, **kwargs):
super().__init__(*args, shuffle=False, **kwargs)

@ -12,34 +12,34 @@ class Dead(agents.FSM):
return self.die() return self.die()
class TestMain(TestCase): class TestAgents(TestCase):
def test_die_returns_infinity(self): def test_die_returns_infinity(self):
'''The last step of a dead agent should return time.INFINITY''' """The last step of a dead agent should return time.INFINITY"""
d = Dead(unique_id=0, model=environment.Environment()) d = Dead(unique_id=0, model=environment.Environment())
ret = d.step().abs(0) ret = d.step()
print(ret, "next")
assert ret == stime.NEVER assert ret == stime.NEVER
def test_die_raises_exception(self): def test_die_raises_exception(self):
'''A dead agent should raise an exception if it is stepped after death''' """A dead agent should raise an exception if it is stepped after death"""
d = Dead(unique_id=0, model=environment.Environment()) d = Dead(unique_id=0, model=environment.Environment())
d.step() d.step()
with pytest.raises(stime.DeadAgent): with pytest.raises(stime.DeadAgent):
d.step() d.step()
def test_agent_generator(self): def test_agent_generator(self):
''' """
The step function of an agent could be a generator. In that case, the state of the The step function of an agent could be a generator. In that case, the state of the
agent will be resumed after every call to step. agent will be resumed after every call to step.
''' """
a = 0 a = 0
class Gen(agents.BaseAgent): class Gen(agents.BaseAgent):
def step(self): def step(self):
nonlocal a nonlocal a
for i in range(5): for i in range(5):
yield yield
a += 1 a += 1
e = environment.Environment() e = environment.Environment()
g = Gen(model=e, unique_id=e.next_id()) g = Gen(model=e, unique_id=e.next_id())
e.schedule.add(g) e.schedule.add(g)
@ -51,8 +51,9 @@ class TestMain(TestCase):
def test_state_decorator(self): def test_state_decorator(self):
class MyAgent(agents.FSM): class MyAgent(agents.FSM):
run = 0 run = 0
@agents.default_state @agents.default_state
@agents.state('original') @agents.state("original")
def root(self): def root(self):
self.run += 1 self.run += 1
return self.other return self.other
@ -66,4 +67,97 @@ class TestMain(TestCase):
a.step() a.step()
assert a.run == 1 assert a.run == 1
a.step() a.step()
assert a.run == 2
def test_broadcast(self):
"""
An agent should be able to broadcast messages to every other agent, AND each receiver should be able
to process it
"""
class BCast(agents.Evented):
pings_received = 0
def step(self):
print(self.model.broadcast)
try:
self.model.broadcast("PING")
except Exception as ex:
print(ex)
while True:
self.check_messages()
yield
def on_receive(self, msg, sender=None):
self.pings_received += 1
e = environment.EventedEnvironment()
for i in range(10):
e.add_agent(agent_class=BCast)
e.step()
pings_received = lambda: [a.pings_received for a in e.agents]
assert sorted(pings_received()) == list(range(1, 11))
e.step()
assert all(x == 10 for x in pings_received())
def test_ask_messages(self):
"""
An agent should be able to ask another agent, and wait for a response.
"""
# There are two agents, they try to send pings
# This is arguably a very contrived example. In practice, the or
# There should be a delay of one step between agent 0 and 1
# On the first step:
# Agent 0 sends a PING, but blocks before a PONG
# Agent 1 detects the PING, responds with a PONG, and blocks after its own PING
# After that step, every agent can both receive (there are pending messages) and send.
# In each step, for each agent, one message is sent, and another one is received
# (although not necessarily in that order).
# Results depend on ordering (agents are normally shuffled)
# so we force the timedactivation not to be shuffled
pings = []
pongs = []
responses = []
class Ping(agents.EventedAgent):
def step(self):
target_id = (self.unique_id + 1) % self.count_agents()
target = self.model.agents[target_id]
print("starting")
while True:
if pongs or not pings: # First agent, or anyone after that
pings.append(self.now)
response = yield target.ask("PING")
responses.append(response)
else:
print("NOT sending ping")
print("Checking msgs")
# Do not block if we have already received a PING
if not self.check_messages():
yield self.received()
print("done")
def on_receive(self, msg, sender=None):
if msg == "PING":
pongs.append(self.now)
return "PONG"
raise Exception("This should never happen")
e = environment.EventedEnvironment(schedule_class=stime.OrderedTimedActivation)
for i in range(2):
e.add_agent(agent_class=Ping)
assert e.now == 0
for i in range(5):
e.step()
time = i + 1
assert e.now == time
assert len(pings) == 2 * time
assert len(pongs) == (2 * time) - 1
# Every step between 0 and t appears twice
assert sum(pings) == sum(range(time)) * 2
# It is the same as pings, without the leading 0
assert sum(pongs) == sum(range(time)) * 2

@ -44,6 +44,8 @@ def add_example_tests():
for cfg, path in serialization.load_files( for cfg, path in serialization.load_files(
join(EXAMPLES, "**", "*.yml"), join(EXAMPLES, "**", "*.yml"),
): ):
if "soil_output" in path:
continue
p = make_example_test(path=path, cfg=config.Config.from_raw(cfg)) p = make_example_test(path=path, cfg=config.Config.from_raw(cfg))
fname = os.path.basename(path) fname = os.path.basename(path)
p.__name__ = "test_example_file_%s" % fname p.__name__ = "test_example_file_%s" % fname

@ -172,25 +172,24 @@ class TestMain(TestCase):
assert len(configs) > 0 assert len(configs) > 0
def test_until(self): def test_until(self):
config = { n_runs = 0
"name": "until_sim",
"model_params": { class CheckRun(agents.BaseAgent):
"network_params": {}, def step(self):
"agents": { nonlocal n_runs
"fixed": [ n_runs += 1
{ return super().step()
"agent_class": agents.BaseAgent,
} n_trials = 50
] max_time = 2
}, s = simulation.Simulation(
}, model_params={"agents": [{"agent_class": CheckRun}]},
"max_time": 2, num_trials=n_trials,
"num_trials": 50, max_time=max_time,
} )
s = simulation.from_config(config)
runs = list(s.run_simulation(dry_run=True)) runs = list(s.run_simulation(dry_run=True))
over = list(x.now for x in runs if x.now > 2) over = list(x.now for x in runs if x.now > 2)
assert len(runs) == config["num_trials"] assert len(runs) == n_trials
assert len(over) == 0 assert len(over) == 0
def test_fsm(self): def test_fsm(self):

@ -72,7 +72,7 @@ class TestNetwork(TestCase):
assert len(env.agents) == 2 assert len(env.agents) == 2
assert env.agents[1].count_agents(state_id="normal") == 2 assert env.agents[1].count_agents(state_id="normal") == 2
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1 assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
assert env.agents[0].neighbors == 1 assert env.agents[0].count_neighbors() == 1
def test_custom_agent_neighbors(self): def test_custom_agent_neighbors(self):
"""Allow for search of neighbors with a certain state_id""" """Allow for search of neighbors with a certain state_id"""
@ -90,7 +90,7 @@ class TestNetwork(TestCase):
env = s.run_simulation(dry_run=True)[0] env = s.run_simulation(dry_run=True)[0]
assert env.agents[1].count_agents(state_id="normal") == 2 assert env.agents[1].count_agents(state_id="normal") == 2
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1 assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
assert env.agents[0].neighbors == 1 assert env.agents[0].count_neighbors() == 1
def test_subgraph(self): def test_subgraph(self):
"""An agent should be able to subgraph the global topology""" """An agent should be able to subgraph the global topology"""

@ -2,11 +2,12 @@ from unittest import TestCase
from soil import time, agents, environment from soil import time, agents, environment
class TestMain(TestCase): class TestMain(TestCase):
def test_cond(self): def test_cond(self):
''' """
A condition should match a When if the concition is True A condition should match a When if the concition is True
''' """
t = time.Cond(lambda t: True) t = time.Cond(lambda t: True)
f = time.Cond(lambda t: False) f = time.Cond(lambda t: False)
@ -16,59 +17,58 @@ class TestMain(TestCase):
assert w is not f assert w is not f
def test_cond(self): def test_cond(self):
''' """
Comparing a Cond to a Delta should always return False Comparing a Cond to a Delta should always return False
''' """
c = time.Cond(lambda t: False) c = time.Cond(lambda t: False)
d = time.Delta(1) d = time.Delta(1)
assert c is not d assert c is not d
def test_cond_env(self): def test_cond_env(self):
''' """ """
'''
times_started = [] times_started = []
times_awakened = [] times_awakened = []
times_asleep = []
times = [] times = []
done = 0 done = []
class CondAgent(agents.BaseAgent): class CondAgent(agents.BaseAgent):
def step(self): def step(self):
nonlocal done nonlocal done
times_started.append(self.now) times_started.append(self.now)
while True: while True:
yield time.Cond(lambda agent: agent.model.schedule.time >= 10) times_asleep.append(self.now)
yield time.Cond(lambda agent: agent.now >= 10, delta=2)
times_awakened.append(self.now) times_awakened.append(self.now)
if self.now >= 10: if self.now >= 10:
break break
done += 1 done.append(self.now)
env = environment.Environment(agents=[{'agent_class': CondAgent}])
env = environment.Environment(agents=[{"agent_class": CondAgent}])
while env.schedule.time < 11: while env.schedule.time < 11:
env.step()
times.append(env.now) times.append(env.now)
env.step()
assert env.schedule.time == 11 assert env.schedule.time == 11
assert times_started == [0] assert times_started == [0]
assert times_awakened == [10] assert times_awakened == [10]
assert done == 1 assert done == [10]
# The first time will produce the Cond. # The first time will produce the Cond.
# Since there are no other agents, time will not advance, but the number assert env.schedule.steps == 6
# of steps will. assert len(times) == 6
assert env.schedule.steps == 12
assert len(times) == 12
while env.schedule.time < 12: while env.schedule.time < 13:
env.step()
times.append(env.now) times.append(env.now)
env.step()
assert env.schedule.time == 12 assert times == [0, 2, 4, 6, 8, 10, 11]
assert env.schedule.time == 13
assert times_started == [0, 11] assert times_started == [0, 11]
assert times_awakened == [10, 11] assert times_awakened == [10]
assert done == 2 assert done == [10]
# Once more to yield the cond, another one to continue # Once more to yield the cond, another one to continue
assert env.schedule.steps == 14 assert env.schedule.steps == 7
assert len(times) == 14 assert len(times) == 7

Loading…
Cancel
Save