1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-22 03:02:28 +00:00

Refactored time

Treating time and conditions as the same entity was getting confusing, and it
added a lot of unnecessary abstraction in a critical part (the scheduler).

The scheduling queue now has the time as a floating number (faster), the agent
id (for ties) and the condition, as well as the agent. The first three
elements (time, id, condition) can be considered as the "key" for the event.

To allow for agent execution to be "randomized" within every step, a new
parameter has been added to the scheduler, which makes it add a random number to
the key in order to change the ordering.

`EventedAgent.received` now checks the messages before returning control to the
user by default.
This commit is contained in:
J. Fernando Sánchez 2022-10-20 09:14:50 +02:00
parent cbbaf73538
commit a1262edd2a
16 changed files with 324 additions and 223 deletions

View File

@ -127,7 +127,8 @@ class Driver(Evented, FSM):
) )
self.check_passengers() self.check_passengers()
self.check_messages() # This will call on_receive behind the scenes, and the agent's status will be updated # This will call on_receive behind the scenes, and the agent's status will be updated
self.check_messages()
yield Delta(30) # Wait at least 30 seconds before checking again yield Delta(30) # Wait at least 30 seconds before checking again
try: try:
@ -204,6 +205,8 @@ class Passenger(Evented, FSM):
while not self.journey: while not self.journey:
self.info(f"Passenger at: { self.pos }. Checking for responses.") self.info(f"Passenger at: { self.pos }. Checking for responses.")
try: try:
# This will call check_messages behind the scenes, and the agent's status will be updated
# If you want to avoid that, you can call it with: check=False
yield self.received(expiration=expiration) yield self.received(expiration=expiration)
except events.TimedOut: except events.TimedOut:
self.info(f"Passenger at: { self.pos }. Asking for journey.") self.info(f"Passenger at: { self.pos }. Asking for journey.")
@ -211,7 +214,6 @@ class Passenger(Evented, FSM):
journey, ttl=timeout, sender=self, agent_class=Driver journey, ttl=timeout, sender=self, agent_class=Driver
) )
expiration = self.now + timeout expiration = self.now + timeout
self.check_messages()
return self.driving_home return self.driving_home
@state @state
@ -220,7 +222,11 @@ class Passenger(Evented, FSM):
self.pos[0] != self.journey.destination[0] self.pos[0] != self.journey.destination[0]
or self.pos[1] != self.journey.destination[1] or self.pos[1] != self.journey.destination[1]
): ):
try:
yield self.received(timeout=60) yield self.received(timeout=60)
except events.TimedOut:
pass
self.info("Got home safe!") self.info("Got home safe!")
self.die() self.die()

View File

@ -133,7 +133,7 @@ class RandomAccident(BaseAgent):
math.log10(max(1, rabbits_alive)) math.log10(max(1, rabbits_alive))
) )
self.debug("Killing some rabbits with prob={}!".format(prob_death)) self.debug("Killing some rabbits with prob={}!".format(prob_death))
for i in self.iter_agents(agent_class=Rabbit): for i in self.get_agents(agent_class=Rabbit):
if i.state_id == i.dead.id: if i.state_id == i.dead.id:
continue continue
if self.prob(prob_death): if self.prob(prob_death):

View File

@ -20,12 +20,6 @@ from typing import Dict, List
from .. import serialization, utils, time, config from .. import serialization, utils, time, config
def as_node(agent):
if isinstance(agent, BaseAgent):
return agent.id
return agent
IGNORED_FIELDS = ("model", "logger") IGNORED_FIELDS = ("model", "logger")
@ -97,10 +91,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
""" """
def __init__(self, unique_id, model, name=None, interval=None, **kwargs): def __init__(self, unique_id, model, name=None, interval=None, **kwargs):
# Check for REQUIRED arguments
# Initialize agent parameters
if isinstance(unique_id, MesaAgent):
raise Exception()
assert isinstance(unique_id, int) assert isinstance(unique_id, int)
super().__init__(unique_id=unique_id, model=model) super().__init__(unique_id=unique_id, model=model)
@ -207,7 +197,8 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
def step(self): def step(self):
if not self.alive: if not self.alive:
raise time.DeadAgent(self.unique_id) raise time.DeadAgent(self.unique_id)
return super().step() or time.Delta(self.interval) super().step()
return time.Delta(self.interval)
def log(self, message, *args, level=logging.INFO, **kwargs): def log(self, message, *args, level=logging.INFO, **kwargs):
if not self.logger.isEnabledFor(level): if not self.logger.isEnabledFor(level):
@ -270,6 +261,7 @@ def prob(prob, random):
return r < prob return r < prob
def calculate_distribution(network_agents=None, agent_class=None): def calculate_distribution(network_agents=None, agent_class=None):
""" """
Calculate the threshold values (thresholds for a uniform distribution) Calculate the threshold values (thresholds for a uniform distribution)
@ -414,7 +406,7 @@ def filter_agents(
if ids: if ids:
f = (agents[aid] for aid in ids if aid in agents) f = (agents[aid] for aid in ids if aid in agents)
else: else:
f = (a for a in agents.values()) f = agents.values()
if state_id is not None and not isinstance(state_id, (tuple, list)): if state_id is not None and not isinstance(state_id, (tuple, list)):
state_id = tuple([state_id]) state_id = tuple([state_id])
@ -638,6 +630,10 @@ from .SentimentCorrelationModel import *
from .SISaModel import * from .SISaModel import *
from .CounterModel import * from .CounterModel import *
class Agent(NetworkAgent, EventedAgent):
'''Default agent class, has both network and event capabilities'''
try: try:
import scipy import scipy
from .Geo import Geo from .Geo import Geo

View File

@ -1,57 +1,74 @@
from . import BaseAgent from . import BaseAgent
from ..events import Message, Tell, Ask, Reply, TimedOut from ..events import Message, Tell, Ask, TimedOut
from ..time import Cond from ..time import BaseCond
from functools import partial from functools import partial
from collections import deque from collections import deque
class Evented(BaseAgent): class ReceivedOrTimeout(BaseCond):
def __init__(self, agent, expiration=None, timeout=None, check=True, ignore=False, **kwargs):
if expiration is None:
if timeout is not None:
expiration = agent.now + timeout
self.expiration = expiration
self.ignore = ignore
self.check = check
super().__init__(**kwargs)
def expired(self, time):
return self.expiration and self.expiration < time
def ready(self, agent, time):
return len(agent._inbox) or self.expired(time)
def return_value(self, agent):
if not self.ignore and self.expired(agent.now):
raise TimedOut('No messages received')
if self.check:
agent.check_messages()
return None
def schedule_next(self, time, delta, first=False):
if self._delta is not None:
delta = self._delta
return (time + delta, self)
def __repr__(self):
return f'ReceivedOrTimeout(expires={self.expiration})'
class EventedAgent(BaseAgent):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._inbox = deque() self._inbox = deque()
self._received = 0
self._processed = 0 self._processed = 0
def on_receive(self, *args, **kwargs): def on_receive(self, *args, **kwargs):
pass pass
def received(self, expiration=None, timeout=None): def received(self, *args, **kwargs):
current = self._received return ReceivedOrTimeout(self, *args, **kwargs)
if expiration is None:
expiration = float('inf') if timeout is None else self.now + timeout
if expiration < self.now: def tell(self, msg, sender=None):
raise ValueError("Invalid expiration time") self._inbox.append(Tell(timestamp=self.now, payload=msg, sender=sender))
def ready(agent):
return agent._received > current or agent.now >= expiration
def value(agent):
if agent.now > expiration:
raise TimedOut("No message received")
c = Cond(func=ready, return_func=value)
c._checked = True
return c
def tell(self, msg, sender):
self._received += 1
self._inbox.append(Tell(payload=msg, sender=sender))
def ask(self, msg, timeout=None): def ask(self, msg, timeout=None):
self._received += 1 ask = Ask(timestamp=self.now, payload=msg)
ask = Ask(payload=msg)
self._inbox.append(ask) self._inbox.append(ask)
expiration = float('inf') if timeout is None else self.now + timeout expiration = float('inf') if timeout is None else self.now + timeout
return ask.replied(expiration=expiration) return ask.replied(expiration=expiration)
def check_messages(self): def check_messages(self):
changed = False
while self._inbox: while self._inbox:
msg = self._inbox.popleft() msg = self._inbox.popleft()
self._processed += 1 self._processed += 1
if msg.expired(self.now): if msg.expired(self.now):
continue continue
changed = True
reply = self.on_receive(msg.payload, sender=msg.sender) reply = self.on_receive(msg.payload, sender=msg.sender)
if isinstance(msg, Ask): if isinstance(msg, Ask):
msg.reply = reply msg.reply = reply
return changed
Evented = EventedAgent

View File

@ -54,7 +54,7 @@ class NetworkAgent(BaseAgent):
return G return G
def remove_node(self): def remove_node(self):
print(f"Removing node for {self.unique_id}: {self.node_id}") self.debug(f"Removing node for {self.unique_id}: {self.node_id}")
self.G.remove_node(self.node_id) self.G.remove_node(self.node_id)
self.node_id = None self.node_id = None
@ -80,3 +80,5 @@ class NetworkAgent(BaseAgent):
if remove: if remove:
self.remove_node() self.remove_node()
return super().die() return super().die()
NetAgent = NetworkAgent

View File

@ -310,17 +310,20 @@ class NetworkEnvironment(BaseEnvironment):
self.add_agent(node_id=node_id, agent_class=a_class, **agent_params) self.add_agent(node_id=node_id, agent_class=a_class, **agent_params)
Environment = NetworkEnvironment class EventedEnvironment(BaseEnvironment):
def broadcast(self, msg, sender=None, expiration=None, ttl=None, **kwargs):
class EventedEnvironment(Environment):
def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs):
for agent in self.agents(**kwargs): for agent in self.agents(**kwargs):
if agent == sender:
continue
self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}') self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}')
try: try:
inbox = agent._inbox inbox = agent._inbox
except AttributeError: except AttributeError:
self.logger.info(f'Agent {agent.unique_id} cannot receive events because it does not have an inbox') self.logger.info(f'Agent {agent.unique_id} cannot receive events because it does not have an inbox')
continue continue
# Allow for AttributeError exceptions in this part of the code
inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl)) inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl))
class Environment(NetworkEnvironment, EventedEnvironment):
'''Default environment class, has both network and event capabilities'''

View File

@ -1,4 +1,4 @@
from .time import Cond from .time import BaseCond
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from uuid import uuid4 from uuid import uuid4
@ -11,6 +11,7 @@ class Message:
payload: Any payload: Any
sender: Any = None sender: Any = None
expiration: float = None expiration: float = None
timestamp: float = None
id: int = field(default_factory=uuid4) id: int = field(default_factory=uuid4)
def expired(self, when): def expired(self, when):
@ -20,19 +21,28 @@ class Reply(Message):
source: Message source: Message
class ReplyCond(BaseCond):
def __init__(self, ask, *args, **kwargs):
self._ask = ask
super().__init__(*args, **kwargs)
def ready(self, agent, time):
return self._ask.reply is not None or self._ask.expired(time)
def return_value(self, agent):
if self._ask.expired(agent.now):
raise TimedOut()
return self._ask.reply
def __repr__(self):
return f"ReplyCond({self._ask.id})"
class Ask(Message): class Ask(Message):
reply: Message = None reply: Message = None
def replied(self, expiration=None): def replied(self, expiration=None):
def ready(agent): return ReplyCond(self)
return self.reply is not None or agent.now > expiration
def value(agent):
if agent.now > expiration:
raise TimedOut(f'No answer received for {self}')
return self.reply
return Cond(func=ready, return_func=value)
class Tell(Message): class Tell(Message):

View File

@ -59,7 +59,6 @@ def find_unassigned(G, shuffle=False, random=random):
If node_id is None, a node without an agent_id will be found. If node_id is None, a node without an agent_id will be found.
""" """
# TODO: test
candidates = list(G.nodes(data=True)) candidates = list(G.nodes(data=True))
if shuffle: if shuffle:
random.shuffle(candidates) random.shuffle(candidates)

View File

@ -221,8 +221,6 @@ def deserialize(type_, value=None, globs=None, **kwargs):
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs): def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
"""Return the list of deserialized objects""" """Return the list of deserialized objects"""
# TODO: remove
print("SERIALIZATION", kwargs)
objects = [] objects = []
for name in names: for name in names:
mod = deserialize(name, known_modules=known_modules) mod = deserialize(name, known_modules=known_modules)

View File

@ -66,7 +66,7 @@ class Simulation:
if ignored: if ignored:
d.setdefault("extra", {}).update(ignored) d.setdefault("extra", {}).update(ignored)
if ignored: if ignored:
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }') logger.warning(f'Ignoring these parameters (added to "extra"): { ignored }')
d.update(kwargs) d.update(kwargs)
return cls(**d) return cls(**d)

View File

@ -1,10 +1,11 @@
from mesa.time import BaseScheduler from mesa.time import BaseScheduler
from queue import Empty from queue import Empty
from heapq import heappush, heappop, heapify from heapq import heappush, heappop
import math import math
from inspect import getsource from inspect import getsource
from numbers import Number from numbers import Number
from textwrap import dedent
from .utils import logger from .utils import logger
from mesa import Agent as MesaAgent from mesa import Agent as MesaAgent
@ -23,65 +24,11 @@ class When:
return time return time
self._time = time self._time = time
def next(self, time): def abs(self, time):
return self._time return self._time
def abs(self, time): def schedule_next(self, time, delta, first=False):
return self return (self._time, None)
def __repr__(self):
return str(f"When({self._time})")
def __lt__(self, other):
if isinstance(other, Number):
return self._time < other
return self._time < other.next(self._time)
def __gt__(self, other):
if isinstance(other, Number):
return self._time > other
return self._time > other.next(self._time)
def ready(self, agent):
return self._time <= agent.model.schedule.time
def return_value(self, agent):
return None
class Cond(When):
def __init__(self, func, delta=1, return_func=lambda agent: None):
self._func = func
self._delta = delta
self._checked = False
self._return_func = return_func
def next(self, time):
if self._checked:
return time + self._delta
return time
def abs(self, time):
return self
def ready(self, agent):
self._checked = True
return self._func(agent)
def return_value(self, agent):
return self._return_func(agent)
def __eq__(self, other):
return False
def __lt__(self, other):
return True
def __gt__(self, other):
return False
def __repr__(self):
return str(f'Cond("{getsource(self._func)}")')
NEVER = When(INFINITY) NEVER = When(INFINITY)
@ -91,48 +38,94 @@ class Delta(When):
def __init__(self, delta): def __init__(self, delta):
self._delta = delta self._delta = delta
def abs(self, time):
return self._time + self._delta
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, Delta): if isinstance(other, Delta):
return self._delta == other._delta return self._delta == other._delta
return False return False
def abs(self, time): def schedule_next(self, time, delta, first=False):
return When(self._delta + time) return (time + self._delta, None)
def next(self, time):
return time + self._delta
def __repr__(self): def __repr__(self):
return str(f"Delta({self._delta})") return str(f"Delta({self._delta})")
class BaseCond:
def __init__(self, msg=None, delta=None, eager=False):
self._msg = msg
self._delta = delta
self.eager = eager
def schedule_next(self, time, delta, first=False):
if first and self.eager:
return (time, self)
if self._delta:
delta = self._delta
return (time + delta, self)
def return_value(self, agent):
return None
def __repr__(self):
return self._msg or self.__class__.__name__
class Cond(BaseCond):
def __init__(self, func, *args, **kwargs):
self._func = func
super().__init__(*args, **kwargs)
def ready(self, agent, time):
return self._func(agent)
def __repr__(self):
if self._msg:
return self._msg
return str(f'Cond("{dedent(getsource(self._func)).strip()}")')
class TimedActivation(BaseScheduler): class TimedActivation(BaseScheduler):
"""A scheduler which activates each agent when the agent requests. """A scheduler which activates each agent when the agent requests.
In each activation, each agent will update its 'next_time'. In each activation, each agent will update its 'next_time'.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, shuffle=True, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._next = {} self._next = {}
self._queue = [] self._queue = []
self.next_time = 0 self._shuffle = shuffle
self.step_interval = 1
self.logger = logger.getChild(f"time_{ self.model }") self.logger = logger.getChild(f"time_{ self.model }")
def add(self, agent: MesaAgent, when=None): def add(self, agent: MesaAgent, when=None):
if when is None: if when is None:
when = When(self.time) when = self.time
elif not isinstance(when, When): elif isinstance(when, When):
when = When(when) when = when.abs()
if agent.unique_id in self._agents:
del self._agents[agent.unique_id]
if agent.unique_id in self._next:
self._queue.remove((self._next[agent.unique_id], agent))
heapify(self._queue)
self._next[agent.unique_id] = when self._schedule(agent, None, when)
heappush(self._queue, (when, agent))
super().add(agent) super().add(agent)
def _schedule(self, agent, condition=None, when=None):
if condition:
if not when:
when, condition = condition.schedule_next(when or self.time,
self.step_interval)
else:
if when is None:
when = self.time + self.step_interval
condition = None
if self._shuffle:
key = (when, self.model.random.random(), condition)
else:
key = (when, agent.unique_id, condition)
self._next[agent.unique_id] = key
heappush(self._queue, (key, agent))
def step(self) -> None: def step(self) -> None:
""" """
Executes agents in order, one at a time. After each step, Executes agents in order, one at a time. After each step,
@ -143,73 +136,59 @@ class TimedActivation(BaseScheduler):
if not self.model.running: if not self.model.running:
return return
when = NEVER
to_process = []
skipped = []
next_time = INFINITY
ix = 0
self.logger.debug(f"Queue length: {len(self._queue)}") self.logger.debug(f"Queue length: {len(self._queue)}")
while self._queue: while self._queue:
(when, agent) = self._queue[0] ((when, _id, cond), agent) = self._queue[0]
if when > self.time: if when > self.time:
break break
heappop(self._queue) heappop(self._queue)
if when.ready(agent): if cond:
if not cond.ready(agent, self.time):
self._schedule(agent, cond)
continue
try: try:
agent._last_return = when.return_value(agent) agent._last_return = cond.return_value(agent)
except Exception as ex: except Exception as ex:
agent._last_except = ex agent._last_except = ex
else:
agent._last_return = None
agent._last_except = None
self._next.pop(agent.unique_id, None)
to_process.append(agent)
continue
next_time = min(next_time, when.next(self.time))
self._next[agent.unique_id] = next_time
skipped.append((when, agent))
if self._queue:
next_time = min(next_time, self._queue[0][0].next(self.time))
self._queue = [*skipped, *self._queue]
for agent in to_process:
self.logger.debug(f"Stepping agent {agent}") self.logger.debug(f"Stepping agent {agent}")
self._next.pop(agent.unique_id, None)
try: try:
returned = ((agent.step() or Delta(1))).abs(self.time) returned = agent.step()
except DeadAgent: except DeadAgent:
if agent.unique_id in self._next:
del self._next[agent.unique_id]
agent.alive = False agent.alive = False
continue continue
# Check status for MESA agents
if not getattr(agent, "alive", True): if not getattr(agent, "alive", True):
continue continue
value = returned.next(self.time) if returned:
agent._last_return = value next_check = returned.schedule_next(self.time, self.step_interval, first=True)
self._schedule(agent, when=next_check[0], condition=next_check[1])
if value < self.time:
raise Exception(
f"Cannot schedule an agent for a time in the past ({when} < {self.time})"
)
if value < INFINITY:
next_time = min(value, next_time)
self._next[agent.unique_id] = returned
heappush(self._queue, (returned, agent))
else: else:
assert not self._next[agent.unique_id] next_check = (self.time + self.step_interval, None)
self._schedule(agent)
self.steps += 1 self.steps += 1
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
self.time = next_time
if not self._queue or next_time == INFINITY: if not self._queue:
self.time = INFINITY
self.model.running = False self.model.running = False
return self.time return self.time
next_time = self._queue[0][0][0]
if next_time < self.time:
raise Exception(
f"An agent has been scheduled for a time in the past, there is probably an error ({when} < {self.time})"
)
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
self.time = next_time

View File

@ -12,12 +12,11 @@ class Dead(agents.FSM):
return self.die() return self.die()
class TestMain(TestCase): class TestAgents(TestCase):
def test_die_returns_infinity(self): def test_die_returns_infinity(self):
'''The last step of a dead agent should return time.INFINITY''' '''The last step of a dead agent should return time.INFINITY'''
d = Dead(unique_id=0, model=environment.Environment()) d = Dead(unique_id=0, model=environment.Environment())
ret = d.step().abs(0) ret = d.step()
print(ret, "next")
assert ret == stime.NEVER assert ret == stime.NEVER
def test_die_raises_exception(self): def test_die_raises_exception(self):
@ -66,4 +65,95 @@ class TestMain(TestCase):
a.step() a.step()
assert a.run == 1 assert a.run == 1
a.step() a.step()
assert a.run == 2
def test_broadcast(self):
'''
An agent should be able to broadcast messages to every other agent, AND each receiver should be able
to process it
'''
class BCast(agents.Evented):
pings_received = 0
def step(self):
print(self.model.broadcast)
try:
self.model.broadcast('PING')
except Exception as ex:
print(ex)
while True:
self.check_messages()
yield
def on_receive(self, msg, sender=None):
self.pings_received += 1
e = environment.EventedEnvironment()
for i in range(10):
e.add_agent(agent_class=BCast)
e.step()
pings_received = lambda: [a.pings_received for a in e.agents]
assert sorted(pings_received()) == list(range(1, 11))
e.step()
assert all(x==10 for x in pings_received())
def test_ask_messages(self):
'''
An agent should be able to ask another agent, and wait for a response.
'''
# #Results depend on ordering (agents are shuffled), so force the first agent
pings = []
pongs = []
responses = []
class Ping(agents.EventedAgent):
def step(self):
target_id = (self.unique_id + 1) % self.count_agents()
target = self.model.agents[target_id]
print('starting')
while True:
print('Pings: ', pings, responses or not pings, self.model.schedule._queue)
if pongs or not pings:
pings.append(self.now)
response = yield target.ask('PING')
responses.append(response)
else:
print('NOT sending ping')
print('Checking msgs')
# Do not advance until we have received a message.
# warning: it will wait at least until the next time in the simulation
yield self.received(check=True)
print('done')
def on_receive(self, msg, sender=None):
if msg == 'PING':
pongs.append(self.now)
return 'PONG'
e = environment.EventedEnvironment()
for i in range(2):
e.add_agent(agent_class=Ping)
assert e.now == 0
# There should be a delay of one step between agent 0 and 1
# On the first step:
# Agent 0 sends a PING, but blocks before a PONG
# Agent 1 sends a PONG, and blocks after its PING
# After that step, every agent can both receive (there are pending messages) and then send.
e.step()
assert e.now == 1
assert pings == [0]
assert pongs == []
e.step()
assert e.now == 2
assert pings == [0, 1]
assert pongs == [1]
e.step()
assert e.now == 3
assert pings == [0, 1, 2]
assert pongs == [1, 2]

View File

@ -44,6 +44,8 @@ def add_example_tests():
for cfg, path in serialization.load_files( for cfg, path in serialization.load_files(
join(EXAMPLES, "**", "*.yml"), join(EXAMPLES, "**", "*.yml"),
): ):
if 'soil_output' in path:
continue
p = make_example_test(path=path, cfg=config.Config.from_raw(cfg)) p = make_example_test(path=path, cfg=config.Config.from_raw(cfg))
fname = os.path.basename(path) fname = os.path.basename(path)
p.__name__ = "test_example_file_%s" % fname p.__name__ = "test_example_file_%s" % fname

View File

@ -172,25 +172,21 @@ class TestMain(TestCase):
assert len(configs) > 0 assert len(configs) > 0
def test_until(self): def test_until(self):
config = { n_runs = 0
"name": "until_sim",
"model_params": { class CheckRun(agents.BaseAgent):
"network_params": {}, def step(self):
"agents": { nonlocal n_runs
"fixed": [ n_runs += 1
{ return super().step()
"agent_class": agents.BaseAgent,
} n_trials = 50
] max_time = 2
}, s = simulation.Simulation(model_params={'agents': [{'agent_class': CheckRun}]},
}, num_trials=n_trials, max_time=max_time)
"max_time": 2,
"num_trials": 50,
}
s = simulation.from_config(config)
runs = list(s.run_simulation(dry_run=True)) runs = list(s.run_simulation(dry_run=True))
over = list(x.now for x in runs if x.now > 2) over = list(x.now for x in runs if x.now > 2)
assert len(runs) == config["num_trials"] assert len(runs) == n_trials
assert len(over) == 0 assert len(over) == 0
def test_fsm(self): def test_fsm(self):

View File

@ -72,7 +72,7 @@ class TestNetwork(TestCase):
assert len(env.agents) == 2 assert len(env.agents) == 2
assert env.agents[1].count_agents(state_id="normal") == 2 assert env.agents[1].count_agents(state_id="normal") == 2
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1 assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
assert env.agents[0].neighbors == 1 assert env.agents[0].count_neighbors() == 1
def test_custom_agent_neighbors(self): def test_custom_agent_neighbors(self):
"""Allow for search of neighbors with a certain state_id""" """Allow for search of neighbors with a certain state_id"""
@ -90,7 +90,7 @@ class TestNetwork(TestCase):
env = s.run_simulation(dry_run=True)[0] env = s.run_simulation(dry_run=True)[0]
assert env.agents[1].count_agents(state_id="normal") == 2 assert env.agents[1].count_agents(state_id="normal") == 2
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1 assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
assert env.agents[0].neighbors == 1 assert env.agents[0].count_neighbors() == 1
def test_subgraph(self): def test_subgraph(self):
"""An agent should be able to subgraph the global topology""" """An agent should be able to subgraph the global topology"""

View File

@ -30,8 +30,9 @@ class TestMain(TestCase):
times_started = [] times_started = []
times_awakened = [] times_awakened = []
times_asleep = []
times = [] times = []
done = 0 done = []
class CondAgent(agents.BaseAgent): class CondAgent(agents.BaseAgent):
@ -39,36 +40,38 @@ class TestMain(TestCase):
nonlocal done nonlocal done
times_started.append(self.now) times_started.append(self.now)
while True: while True:
yield time.Cond(lambda agent: agent.model.schedule.time >= 10) times_asleep.append(self.now)
yield time.Cond(lambda agent: agent.now >= 10,
delta=2)
times_awakened.append(self.now) times_awakened.append(self.now)
if self.now >= 10: if self.now >= 10:
break break
done += 1 done.append(self.now)
env = environment.Environment(agents=[{'agent_class': CondAgent}]) env = environment.Environment(agents=[{'agent_class': CondAgent}])
while env.schedule.time < 11: while env.schedule.time < 11:
env.step()
times.append(env.now) times.append(env.now)
env.step()
assert env.schedule.time == 11 assert env.schedule.time == 11
assert times_started == [0] assert times_started == [0]
assert times_awakened == [10] assert times_awakened == [10]
assert done == 1 assert done == [10]
# The first time will produce the Cond. # The first time will produce the Cond.
# Since there are no other agents, time will not advance, but the number assert env.schedule.steps == 6
# of steps will. assert len(times) == 6
assert env.schedule.steps == 12
assert len(times) == 12
while env.schedule.time < 12: while env.schedule.time < 13:
env.step()
times.append(env.now) times.append(env.now)
env.step()
assert env.schedule.time == 12 assert times == [0, 2, 4, 6, 8, 10, 11]
assert env.schedule.time == 13
assert times_started == [0, 11] assert times_started == [0, 11]
assert times_awakened == [10, 11] assert times_awakened == [10]
assert done == 2 assert done == [10]
# Once more to yield the cond, another one to continue # Once more to yield the cond, another one to continue
assert env.schedule.steps == 14 assert env.schedule.steps == 7
assert len(times) == 14 assert len(times) == 7