mirror of
https://github.com/gsi-upm/soil
synced 2024-12-22 00:08:12 +00:00
Refactored time
Treating time and conditions as the same entity was getting confusing, and it added a lot of unnecessary abstraction in a critical part (the scheduler). The scheduling queue now has the time as a floating number (faster), the agent id (for ties) and the condition, as well as the agent. The first three elements (time, id, condition) can be considered as the "key" for the event. To allow for agent execution to be "randomized" within every step, a new parameter has been added to the scheduler, which makes it add a random number to the key in order to change the ordering. `EventedAgent.received` now checks the messages before returning control to the user by default.
This commit is contained in:
parent
cbbaf73538
commit
a1262edd2a
@ -127,7 +127,8 @@ class Driver(Evented, FSM):
|
||||
)
|
||||
|
||||
self.check_passengers()
|
||||
self.check_messages() # This will call on_receive behind the scenes, and the agent's status will be updated
|
||||
# This will call on_receive behind the scenes, and the agent's status will be updated
|
||||
self.check_messages()
|
||||
yield Delta(30) # Wait at least 30 seconds before checking again
|
||||
|
||||
try:
|
||||
@ -204,6 +205,8 @@ class Passenger(Evented, FSM):
|
||||
while not self.journey:
|
||||
self.info(f"Passenger at: { self.pos }. Checking for responses.")
|
||||
try:
|
||||
# This will call check_messages behind the scenes, and the agent's status will be updated
|
||||
# If you want to avoid that, you can call it with: check=False
|
||||
yield self.received(expiration=expiration)
|
||||
except events.TimedOut:
|
||||
self.info(f"Passenger at: { self.pos }. Asking for journey.")
|
||||
@ -211,7 +214,6 @@ class Passenger(Evented, FSM):
|
||||
journey, ttl=timeout, sender=self, agent_class=Driver
|
||||
)
|
||||
expiration = self.now + timeout
|
||||
self.check_messages()
|
||||
return self.driving_home
|
||||
|
||||
@state
|
||||
@ -220,7 +222,11 @@ class Passenger(Evented, FSM):
|
||||
self.pos[0] != self.journey.destination[0]
|
||||
or self.pos[1] != self.journey.destination[1]
|
||||
):
|
||||
yield self.received(timeout=60)
|
||||
try:
|
||||
yield self.received(timeout=60)
|
||||
except events.TimedOut:
|
||||
pass
|
||||
|
||||
self.info("Got home safe!")
|
||||
self.die()
|
||||
|
||||
|
@ -133,7 +133,7 @@ class RandomAccident(BaseAgent):
|
||||
math.log10(max(1, rabbits_alive))
|
||||
)
|
||||
self.debug("Killing some rabbits with prob={}!".format(prob_death))
|
||||
for i in self.iter_agents(agent_class=Rabbit):
|
||||
for i in self.get_agents(agent_class=Rabbit):
|
||||
if i.state_id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
|
@ -20,12 +20,6 @@ from typing import Dict, List
|
||||
from .. import serialization, utils, time, config
|
||||
|
||||
|
||||
def as_node(agent):
|
||||
if isinstance(agent, BaseAgent):
|
||||
return agent.id
|
||||
return agent
|
||||
|
||||
|
||||
IGNORED_FIELDS = ("model", "logger")
|
||||
|
||||
|
||||
@ -97,10 +91,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
"""
|
||||
|
||||
def __init__(self, unique_id, model, name=None, interval=None, **kwargs):
|
||||
# Check for REQUIRED arguments
|
||||
# Initialize agent parameters
|
||||
if isinstance(unique_id, MesaAgent):
|
||||
raise Exception()
|
||||
assert isinstance(unique_id, int)
|
||||
super().__init__(unique_id=unique_id, model=model)
|
||||
|
||||
@ -207,7 +197,8 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
def step(self):
|
||||
if not self.alive:
|
||||
raise time.DeadAgent(self.unique_id)
|
||||
return super().step() or time.Delta(self.interval)
|
||||
super().step()
|
||||
return time.Delta(self.interval)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
if not self.logger.isEnabledFor(level):
|
||||
@ -270,6 +261,7 @@ def prob(prob, random):
|
||||
return r < prob
|
||||
|
||||
|
||||
|
||||
def calculate_distribution(network_agents=None, agent_class=None):
|
||||
"""
|
||||
Calculate the threshold values (thresholds for a uniform distribution)
|
||||
@ -414,7 +406,7 @@ def filter_agents(
|
||||
if ids:
|
||||
f = (agents[aid] for aid in ids if aid in agents)
|
||||
else:
|
||||
f = (a for a in agents.values())
|
||||
f = agents.values()
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
@ -638,6 +630,10 @@ from .SentimentCorrelationModel import *
|
||||
from .SISaModel import *
|
||||
from .CounterModel import *
|
||||
|
||||
|
||||
class Agent(NetworkAgent, EventedAgent):
|
||||
'''Default agent class, has both network and event capabilities'''
|
||||
|
||||
try:
|
||||
import scipy
|
||||
from .Geo import Geo
|
||||
|
@ -1,57 +1,74 @@
|
||||
from . import BaseAgent
|
||||
from ..events import Message, Tell, Ask, Reply, TimedOut
|
||||
from ..time import Cond
|
||||
from ..events import Message, Tell, Ask, TimedOut
|
||||
from ..time import BaseCond
|
||||
from functools import partial
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Evented(BaseAgent):
|
||||
class ReceivedOrTimeout(BaseCond):
|
||||
def __init__(self, agent, expiration=None, timeout=None, check=True, ignore=False, **kwargs):
|
||||
if expiration is None:
|
||||
if timeout is not None:
|
||||
expiration = agent.now + timeout
|
||||
self.expiration = expiration
|
||||
self.ignore = ignore
|
||||
self.check = check
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def expired(self, time):
|
||||
return self.expiration and self.expiration < time
|
||||
|
||||
def ready(self, agent, time):
|
||||
return len(agent._inbox) or self.expired(time)
|
||||
|
||||
def return_value(self, agent):
|
||||
if not self.ignore and self.expired(agent.now):
|
||||
raise TimedOut('No messages received')
|
||||
if self.check:
|
||||
agent.check_messages()
|
||||
return None
|
||||
|
||||
def schedule_next(self, time, delta, first=False):
|
||||
if self._delta is not None:
|
||||
delta = self._delta
|
||||
return (time + delta, self)
|
||||
|
||||
def __repr__(self):
|
||||
return f'ReceivedOrTimeout(expires={self.expiration})'
|
||||
|
||||
|
||||
class EventedAgent(BaseAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._inbox = deque()
|
||||
self._received = 0
|
||||
self._processed = 0
|
||||
|
||||
|
||||
def on_receive(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def received(self, expiration=None, timeout=None):
|
||||
current = self._received
|
||||
if expiration is None:
|
||||
expiration = float('inf') if timeout is None else self.now + timeout
|
||||
def received(self, *args, **kwargs):
|
||||
return ReceivedOrTimeout(self, *args, **kwargs)
|
||||
|
||||
if expiration < self.now:
|
||||
raise ValueError("Invalid expiration time")
|
||||
|
||||
def ready(agent):
|
||||
return agent._received > current or agent.now >= expiration
|
||||
|
||||
def value(agent):
|
||||
if agent.now > expiration:
|
||||
raise TimedOut("No message received")
|
||||
|
||||
c = Cond(func=ready, return_func=value)
|
||||
c._checked = True
|
||||
return c
|
||||
|
||||
def tell(self, msg, sender):
|
||||
self._received += 1
|
||||
self._inbox.append(Tell(payload=msg, sender=sender))
|
||||
def tell(self, msg, sender=None):
|
||||
self._inbox.append(Tell(timestamp=self.now, payload=msg, sender=sender))
|
||||
|
||||
def ask(self, msg, timeout=None):
|
||||
self._received += 1
|
||||
ask = Ask(payload=msg)
|
||||
ask = Ask(timestamp=self.now, payload=msg)
|
||||
self._inbox.append(ask)
|
||||
expiration = float('inf') if timeout is None else self.now + timeout
|
||||
return ask.replied(expiration=expiration)
|
||||
|
||||
def check_messages(self):
|
||||
changed = False
|
||||
while self._inbox:
|
||||
msg = self._inbox.popleft()
|
||||
self._processed += 1
|
||||
if msg.expired(self.now):
|
||||
continue
|
||||
changed = True
|
||||
reply = self.on_receive(msg.payload, sender=msg.sender)
|
||||
if isinstance(msg, Ask):
|
||||
msg.reply = reply
|
||||
return changed
|
||||
|
||||
Evented = EventedAgent
|
||||
|
@ -54,7 +54,7 @@ class NetworkAgent(BaseAgent):
|
||||
return G
|
||||
|
||||
def remove_node(self):
|
||||
print(f"Removing node for {self.unique_id}: {self.node_id}")
|
||||
self.debug(f"Removing node for {self.unique_id}: {self.node_id}")
|
||||
self.G.remove_node(self.node_id)
|
||||
self.node_id = None
|
||||
|
||||
@ -80,3 +80,5 @@ class NetworkAgent(BaseAgent):
|
||||
if remove:
|
||||
self.remove_node()
|
||||
return super().die()
|
||||
|
||||
NetAgent = NetworkAgent
|
||||
|
@ -310,17 +310,20 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
self.add_agent(node_id=node_id, agent_class=a_class, **agent_params)
|
||||
|
||||
|
||||
Environment = NetworkEnvironment
|
||||
|
||||
|
||||
class EventedEnvironment(Environment):
|
||||
def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs):
|
||||
class EventedEnvironment(BaseEnvironment):
|
||||
def broadcast(self, msg, sender=None, expiration=None, ttl=None, **kwargs):
|
||||
for agent in self.agents(**kwargs):
|
||||
if agent == sender:
|
||||
continue
|
||||
self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}')
|
||||
try:
|
||||
inbox = agent._inbox
|
||||
except AttributeError:
|
||||
self.logger.info(f'Agent {agent.unique_id} cannot receive events because it does not have an inbox')
|
||||
continue
|
||||
# Allow for AttributeError exceptions in this part of the code
|
||||
inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl))
|
||||
|
||||
|
||||
class Environment(NetworkEnvironment, EventedEnvironment):
|
||||
'''Default environment class, has both network and event capabilities'''
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .time import Cond
|
||||
from .time import BaseCond
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
@ -11,6 +11,7 @@ class Message:
|
||||
payload: Any
|
||||
sender: Any = None
|
||||
expiration: float = None
|
||||
timestamp: float = None
|
||||
id: int = field(default_factory=uuid4)
|
||||
|
||||
def expired(self, when):
|
||||
@ -20,19 +21,28 @@ class Reply(Message):
|
||||
source: Message
|
||||
|
||||
|
||||
class ReplyCond(BaseCond):
|
||||
def __init__(self, ask, *args, **kwargs):
|
||||
self._ask = ask
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def ready(self, agent, time):
|
||||
return self._ask.reply is not None or self._ask.expired(time)
|
||||
|
||||
def return_value(self, agent):
|
||||
if self._ask.expired(agent.now):
|
||||
raise TimedOut()
|
||||
return self._ask.reply
|
||||
|
||||
def __repr__(self):
|
||||
return f"ReplyCond({self._ask.id})"
|
||||
|
||||
|
||||
class Ask(Message):
|
||||
reply: Message = None
|
||||
|
||||
def replied(self, expiration=None):
|
||||
def ready(agent):
|
||||
return self.reply is not None or agent.now > expiration
|
||||
|
||||
def value(agent):
|
||||
if agent.now > expiration:
|
||||
raise TimedOut(f'No answer received for {self}')
|
||||
return self.reply
|
||||
|
||||
return Cond(func=ready, return_func=value)
|
||||
return ReplyCond(self)
|
||||
|
||||
|
||||
class Tell(Message):
|
||||
|
@ -59,7 +59,6 @@ def find_unassigned(G, shuffle=False, random=random):
|
||||
|
||||
If node_id is None, a node without an agent_id will be found.
|
||||
"""
|
||||
# TODO: test
|
||||
candidates = list(G.nodes(data=True))
|
||||
if shuffle:
|
||||
random.shuffle(candidates)
|
||||
|
@ -221,8 +221,6 @@ def deserialize(type_, value=None, globs=None, **kwargs):
|
||||
|
||||
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
|
||||
"""Return the list of deserialized objects"""
|
||||
# TODO: remove
|
||||
print("SERIALIZATION", kwargs)
|
||||
objects = []
|
||||
for name in names:
|
||||
mod = deserialize(name, known_modules=known_modules)
|
||||
|
@ -66,7 +66,7 @@ class Simulation:
|
||||
if ignored:
|
||||
d.setdefault("extra", {}).update(ignored)
|
||||
if ignored:
|
||||
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
|
||||
logger.warning(f'Ignoring these parameters (added to "extra"): { ignored }')
|
||||
d.update(kwargs)
|
||||
|
||||
return cls(**d)
|
||||
|
219
soil/time.py
219
soil/time.py
@ -1,10 +1,11 @@
|
||||
from mesa.time import BaseScheduler
|
||||
from queue import Empty
|
||||
from heapq import heappush, heappop, heapify
|
||||
from heapq import heappush, heappop
|
||||
import math
|
||||
|
||||
from inspect import getsource
|
||||
from numbers import Number
|
||||
from textwrap import dedent
|
||||
|
||||
from .utils import logger
|
||||
from mesa import Agent as MesaAgent
|
||||
@ -23,65 +24,11 @@ class When:
|
||||
return time
|
||||
self._time = time
|
||||
|
||||
def next(self, time):
|
||||
def abs(self, time):
|
||||
return self._time
|
||||
|
||||
def abs(self, time):
|
||||
return self
|
||||
|
||||
def __repr__(self):
|
||||
return str(f"When({self._time})")
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, Number):
|
||||
return self._time < other
|
||||
return self._time < other.next(self._time)
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, Number):
|
||||
return self._time > other
|
||||
return self._time > other.next(self._time)
|
||||
|
||||
def ready(self, agent):
|
||||
return self._time <= agent.model.schedule.time
|
||||
|
||||
def return_value(self, agent):
|
||||
return None
|
||||
|
||||
|
||||
class Cond(When):
|
||||
def __init__(self, func, delta=1, return_func=lambda agent: None):
|
||||
self._func = func
|
||||
self._delta = delta
|
||||
self._checked = False
|
||||
self._return_func = return_func
|
||||
|
||||
def next(self, time):
|
||||
if self._checked:
|
||||
return time + self._delta
|
||||
return time
|
||||
|
||||
def abs(self, time):
|
||||
return self
|
||||
|
||||
def ready(self, agent):
|
||||
self._checked = True
|
||||
return self._func(agent)
|
||||
|
||||
def return_value(self, agent):
|
||||
return self._return_func(agent)
|
||||
|
||||
def __eq__(self, other):
|
||||
return False
|
||||
|
||||
def __lt__(self, other):
|
||||
return True
|
||||
|
||||
def __gt__(self, other):
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return str(f'Cond("{getsource(self._func)}")')
|
||||
def schedule_next(self, time, delta, first=False):
|
||||
return (self._time, None)
|
||||
|
||||
|
||||
NEVER = When(INFINITY)
|
||||
@ -91,48 +38,94 @@ class Delta(When):
|
||||
def __init__(self, delta):
|
||||
self._delta = delta
|
||||
|
||||
def abs(self, time):
|
||||
return self._time + self._delta
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, Delta):
|
||||
return self._delta == other._delta
|
||||
return False
|
||||
|
||||
def abs(self, time):
|
||||
return When(self._delta + time)
|
||||
|
||||
def next(self, time):
|
||||
return time + self._delta
|
||||
def schedule_next(self, time, delta, first=False):
|
||||
return (time + self._delta, None)
|
||||
|
||||
def __repr__(self):
|
||||
return str(f"Delta({self._delta})")
|
||||
|
||||
|
||||
class BaseCond:
|
||||
def __init__(self, msg=None, delta=None, eager=False):
|
||||
self._msg = msg
|
||||
self._delta = delta
|
||||
self.eager = eager
|
||||
|
||||
def schedule_next(self, time, delta, first=False):
|
||||
if first and self.eager:
|
||||
return (time, self)
|
||||
if self._delta:
|
||||
delta = self._delta
|
||||
return (time + delta, self)
|
||||
|
||||
def return_value(self, agent):
|
||||
return None
|
||||
|
||||
def __repr__(self):
|
||||
return self._msg or self.__class__.__name__
|
||||
|
||||
|
||||
class Cond(BaseCond):
|
||||
def __init__(self, func, *args, **kwargs):
|
||||
self._func = func
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def ready(self, agent, time):
|
||||
return self._func(agent)
|
||||
|
||||
def __repr__(self):
|
||||
if self._msg:
|
||||
return self._msg
|
||||
return str(f'Cond("{dedent(getsource(self._func)).strip()}")')
|
||||
|
||||
|
||||
class TimedActivation(BaseScheduler):
|
||||
"""A scheduler which activates each agent when the agent requests.
|
||||
In each activation, each agent will update its 'next_time'.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args, shuffle=True, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._next = {}
|
||||
self._queue = []
|
||||
self.next_time = 0
|
||||
self._shuffle = shuffle
|
||||
self.step_interval = 1
|
||||
self.logger = logger.getChild(f"time_{ self.model }")
|
||||
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
if when is None:
|
||||
when = When(self.time)
|
||||
elif not isinstance(when, When):
|
||||
when = When(when)
|
||||
if agent.unique_id in self._agents:
|
||||
del self._agents[agent.unique_id]
|
||||
if agent.unique_id in self._next:
|
||||
self._queue.remove((self._next[agent.unique_id], agent))
|
||||
heapify(self._queue)
|
||||
|
||||
self._next[agent.unique_id] = when
|
||||
heappush(self._queue, (when, agent))
|
||||
when = self.time
|
||||
elif isinstance(when, When):
|
||||
when = when.abs()
|
||||
|
||||
self._schedule(agent, None, when)
|
||||
super().add(agent)
|
||||
|
||||
def _schedule(self, agent, condition=None, when=None):
|
||||
if condition:
|
||||
if not when:
|
||||
when, condition = condition.schedule_next(when or self.time,
|
||||
self.step_interval)
|
||||
else:
|
||||
if when is None:
|
||||
when = self.time + self.step_interval
|
||||
condition = None
|
||||
if self._shuffle:
|
||||
key = (when, self.model.random.random(), condition)
|
||||
else:
|
||||
key = (when, agent.unique_id, condition)
|
||||
self._next[agent.unique_id] = key
|
||||
heappush(self._queue, (key, agent))
|
||||
|
||||
|
||||
def step(self) -> None:
|
||||
"""
|
||||
Executes agents in order, one at a time. After each step,
|
||||
@ -143,73 +136,59 @@ class TimedActivation(BaseScheduler):
|
||||
if not self.model.running:
|
||||
return
|
||||
|
||||
when = NEVER
|
||||
|
||||
to_process = []
|
||||
skipped = []
|
||||
next_time = INFINITY
|
||||
|
||||
ix = 0
|
||||
|
||||
self.logger.debug(f"Queue length: {len(self._queue)}")
|
||||
|
||||
while self._queue:
|
||||
(when, agent) = self._queue[0]
|
||||
((when, _id, cond), agent) = self._queue[0]
|
||||
if when > self.time:
|
||||
break
|
||||
|
||||
heappop(self._queue)
|
||||
if when.ready(agent):
|
||||
if cond:
|
||||
if not cond.ready(agent, self.time):
|
||||
self._schedule(agent, cond)
|
||||
continue
|
||||
try:
|
||||
agent._last_return = when.return_value(agent)
|
||||
agent._last_return = cond.return_value(agent)
|
||||
except Exception as ex:
|
||||
agent._last_except = ex
|
||||
else:
|
||||
agent._last_return = None
|
||||
agent._last_except = None
|
||||
|
||||
self._next.pop(agent.unique_id, None)
|
||||
to_process.append(agent)
|
||||
continue
|
||||
|
||||
next_time = min(next_time, when.next(self.time))
|
||||
self._next[agent.unique_id] = next_time
|
||||
skipped.append((when, agent))
|
||||
|
||||
if self._queue:
|
||||
next_time = min(next_time, self._queue[0][0].next(self.time))
|
||||
|
||||
self._queue = [*skipped, *self._queue]
|
||||
|
||||
for agent in to_process:
|
||||
self.logger.debug(f"Stepping agent {agent}")
|
||||
self._next.pop(agent.unique_id, None)
|
||||
|
||||
try:
|
||||
returned = ((agent.step() or Delta(1))).abs(self.time)
|
||||
returned = agent.step()
|
||||
except DeadAgent:
|
||||
if agent.unique_id in self._next:
|
||||
del self._next[agent.unique_id]
|
||||
agent.alive = False
|
||||
continue
|
||||
|
||||
# Check status for MESA agents
|
||||
if not getattr(agent, "alive", True):
|
||||
continue
|
||||
|
||||
value = returned.next(self.time)
|
||||
agent._last_return = value
|
||||
|
||||
if value < self.time:
|
||||
raise Exception(
|
||||
f"Cannot schedule an agent for a time in the past ({when} < {self.time})"
|
||||
)
|
||||
if value < INFINITY:
|
||||
next_time = min(value, next_time)
|
||||
|
||||
self._next[agent.unique_id] = returned
|
||||
heappush(self._queue, (returned, agent))
|
||||
if returned:
|
||||
next_check = returned.schedule_next(self.time, self.step_interval, first=True)
|
||||
self._schedule(agent, when=next_check[0], condition=next_check[1])
|
||||
else:
|
||||
assert not self._next[agent.unique_id]
|
||||
next_check = (self.time + self.step_interval, None)
|
||||
|
||||
self._schedule(agent)
|
||||
|
||||
self.steps += 1
|
||||
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
|
||||
self.time = next_time
|
||||
|
||||
if not self._queue or next_time == INFINITY:
|
||||
if not self._queue:
|
||||
self.time = INFINITY
|
||||
self.model.running = False
|
||||
return self.time
|
||||
|
||||
next_time = self._queue[0][0][0]
|
||||
if next_time < self.time:
|
||||
raise Exception(
|
||||
f"An agent has been scheduled for a time in the past, there is probably an error ({when} < {self.time})"
|
||||
)
|
||||
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
|
||||
|
||||
self.time = next_time
|
||||
|
@ -12,12 +12,11 @@ class Dead(agents.FSM):
|
||||
return self.die()
|
||||
|
||||
|
||||
class TestMain(TestCase):
|
||||
class TestAgents(TestCase):
|
||||
def test_die_returns_infinity(self):
|
||||
'''The last step of a dead agent should return time.INFINITY'''
|
||||
d = Dead(unique_id=0, model=environment.Environment())
|
||||
ret = d.step().abs(0)
|
||||
print(ret, "next")
|
||||
ret = d.step()
|
||||
assert ret == stime.NEVER
|
||||
|
||||
def test_die_raises_exception(self):
|
||||
@ -66,4 +65,95 @@ class TestMain(TestCase):
|
||||
a.step()
|
||||
assert a.run == 1
|
||||
a.step()
|
||||
assert a.run == 2
|
||||
|
||||
|
||||
def test_broadcast(self):
|
||||
'''
|
||||
An agent should be able to broadcast messages to every other agent, AND each receiver should be able
|
||||
to process it
|
||||
'''
|
||||
class BCast(agents.Evented):
|
||||
pings_received = 0
|
||||
|
||||
def step(self):
|
||||
print(self.model.broadcast)
|
||||
try:
|
||||
self.model.broadcast('PING')
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
while True:
|
||||
self.check_messages()
|
||||
yield
|
||||
|
||||
def on_receive(self, msg, sender=None):
|
||||
self.pings_received += 1
|
||||
|
||||
e = environment.EventedEnvironment()
|
||||
|
||||
for i in range(10):
|
||||
e.add_agent(agent_class=BCast)
|
||||
e.step()
|
||||
pings_received = lambda: [a.pings_received for a in e.agents]
|
||||
assert sorted(pings_received()) == list(range(1, 11))
|
||||
e.step()
|
||||
assert all(x==10 for x in pings_received())
|
||||
|
||||
def test_ask_messages(self):
|
||||
'''
|
||||
An agent should be able to ask another agent, and wait for a response.
|
||||
'''
|
||||
|
||||
# #Results depend on ordering (agents are shuffled), so force the first agent
|
||||
pings = []
|
||||
pongs = []
|
||||
responses = []
|
||||
|
||||
class Ping(agents.EventedAgent):
|
||||
def step(self):
|
||||
target_id = (self.unique_id + 1) % self.count_agents()
|
||||
target = self.model.agents[target_id]
|
||||
print('starting')
|
||||
while True:
|
||||
print('Pings: ', pings, responses or not pings, self.model.schedule._queue)
|
||||
if pongs or not pings:
|
||||
pings.append(self.now)
|
||||
response = yield target.ask('PING')
|
||||
responses.append(response)
|
||||
else:
|
||||
print('NOT sending ping')
|
||||
print('Checking msgs')
|
||||
# Do not advance until we have received a message.
|
||||
# warning: it will wait at least until the next time in the simulation
|
||||
yield self.received(check=True)
|
||||
print('done')
|
||||
|
||||
def on_receive(self, msg, sender=None):
|
||||
if msg == 'PING':
|
||||
pongs.append(self.now)
|
||||
return 'PONG'
|
||||
|
||||
e = environment.EventedEnvironment()
|
||||
for i in range(2):
|
||||
e.add_agent(agent_class=Ping)
|
||||
assert e.now == 0
|
||||
|
||||
# There should be a delay of one step between agent 0 and 1
|
||||
# On the first step:
|
||||
# Agent 0 sends a PING, but blocks before a PONG
|
||||
# Agent 1 sends a PONG, and blocks after its PING
|
||||
# After that step, every agent can both receive (there are pending messages) and then send.
|
||||
|
||||
e.step()
|
||||
assert e.now == 1
|
||||
assert pings == [0]
|
||||
assert pongs == []
|
||||
|
||||
e.step()
|
||||
assert e.now == 2
|
||||
assert pings == [0, 1]
|
||||
assert pongs == [1]
|
||||
|
||||
e.step()
|
||||
assert e.now == 3
|
||||
assert pings == [0, 1, 2]
|
||||
assert pongs == [1, 2]
|
||||
|
@ -44,6 +44,8 @@ def add_example_tests():
|
||||
for cfg, path in serialization.load_files(
|
||||
join(EXAMPLES, "**", "*.yml"),
|
||||
):
|
||||
if 'soil_output' in path:
|
||||
continue
|
||||
p = make_example_test(path=path, cfg=config.Config.from_raw(cfg))
|
||||
fname = os.path.basename(path)
|
||||
p.__name__ = "test_example_file_%s" % fname
|
||||
|
@ -172,25 +172,21 @@ class TestMain(TestCase):
|
||||
assert len(configs) > 0
|
||||
|
||||
def test_until(self):
|
||||
config = {
|
||||
"name": "until_sim",
|
||||
"model_params": {
|
||||
"network_params": {},
|
||||
"agents": {
|
||||
"fixed": [
|
||||
{
|
||||
"agent_class": agents.BaseAgent,
|
||||
}
|
||||
]
|
||||
},
|
||||
},
|
||||
"max_time": 2,
|
||||
"num_trials": 50,
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
n_runs = 0
|
||||
|
||||
class CheckRun(agents.BaseAgent):
|
||||
def step(self):
|
||||
nonlocal n_runs
|
||||
n_runs += 1
|
||||
return super().step()
|
||||
|
||||
n_trials = 50
|
||||
max_time = 2
|
||||
s = simulation.Simulation(model_params={'agents': [{'agent_class': CheckRun}]},
|
||||
num_trials=n_trials, max_time=max_time)
|
||||
runs = list(s.run_simulation(dry_run=True))
|
||||
over = list(x.now for x in runs if x.now > 2)
|
||||
assert len(runs) == config["num_trials"]
|
||||
assert len(runs) == n_trials
|
||||
assert len(over) == 0
|
||||
|
||||
def test_fsm(self):
|
||||
|
@ -72,7 +72,7 @@ class TestNetwork(TestCase):
|
||||
assert len(env.agents) == 2
|
||||
assert env.agents[1].count_agents(state_id="normal") == 2
|
||||
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
|
||||
assert env.agents[0].neighbors == 1
|
||||
assert env.agents[0].count_neighbors() == 1
|
||||
|
||||
def test_custom_agent_neighbors(self):
|
||||
"""Allow for search of neighbors with a certain state_id"""
|
||||
@ -90,7 +90,7 @@ class TestNetwork(TestCase):
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
assert env.agents[1].count_agents(state_id="normal") == 2
|
||||
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
|
||||
assert env.agents[0].neighbors == 1
|
||||
assert env.agents[0].count_neighbors() == 1
|
||||
|
||||
def test_subgraph(self):
|
||||
"""An agent should be able to subgraph the global topology"""
|
||||
|
@ -30,8 +30,9 @@ class TestMain(TestCase):
|
||||
|
||||
times_started = []
|
||||
times_awakened = []
|
||||
times_asleep = []
|
||||
times = []
|
||||
done = 0
|
||||
done = []
|
||||
|
||||
class CondAgent(agents.BaseAgent):
|
||||
|
||||
@ -39,36 +40,38 @@ class TestMain(TestCase):
|
||||
nonlocal done
|
||||
times_started.append(self.now)
|
||||
while True:
|
||||
yield time.Cond(lambda agent: agent.model.schedule.time >= 10)
|
||||
times_asleep.append(self.now)
|
||||
yield time.Cond(lambda agent: agent.now >= 10,
|
||||
delta=2)
|
||||
times_awakened.append(self.now)
|
||||
if self.now >= 10:
|
||||
break
|
||||
done += 1
|
||||
done.append(self.now)
|
||||
|
||||
env = environment.Environment(agents=[{'agent_class': CondAgent}])
|
||||
|
||||
|
||||
while env.schedule.time < 11:
|
||||
env.step()
|
||||
times.append(env.now)
|
||||
env.step()
|
||||
|
||||
assert env.schedule.time == 11
|
||||
assert times_started == [0]
|
||||
assert times_awakened == [10]
|
||||
assert done == 1
|
||||
assert done == [10]
|
||||
# The first time will produce the Cond.
|
||||
# Since there are no other agents, time will not advance, but the number
|
||||
# of steps will.
|
||||
assert env.schedule.steps == 12
|
||||
assert len(times) == 12
|
||||
assert env.schedule.steps == 6
|
||||
assert len(times) == 6
|
||||
|
||||
while env.schedule.time < 12:
|
||||
env.step()
|
||||
while env.schedule.time < 13:
|
||||
times.append(env.now)
|
||||
env.step()
|
||||
|
||||
assert env.schedule.time == 12
|
||||
assert times == [0, 2, 4, 6, 8, 10, 11]
|
||||
assert env.schedule.time == 13
|
||||
assert times_started == [0, 11]
|
||||
assert times_awakened == [10, 11]
|
||||
assert done == 2
|
||||
assert times_awakened == [10]
|
||||
assert done == [10]
|
||||
# Once more to yield the cond, another one to continue
|
||||
assert env.schedule.steps == 14
|
||||
assert len(times) == 14
|
||||
assert env.schedule.steps == 7
|
||||
assert len(times) == 7
|
||||
|
Loading…
Reference in New Issue
Block a user