From a1262edd2a54431e6e8a09711e8b274c42c7b05c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Thu, 20 Oct 2022 09:14:50 +0200 Subject: [PATCH] Refactored time Treating time and conditions as the same entity was getting confusing, and it added a lot of unnecessary abstraction in a critical part (the scheduler). The scheduling queue now has the time as a floating number (faster), the agent id (for ties) and the condition, as well as the agent. The first three elements (time, id, condition) can be considered as the "key" for the event. To allow for agent execution to be "randomized" within every step, a new parameter has been added to the scheduler, which makes it add a random number to the key in order to change the ordering. `EventedAgent.received` now checks the messages before returning control to the user by default. --- examples/events_and_messages/cars.py | 12 +- examples/rabbits/basic/rabbit_agents.py | 2 +- soil/agents/__init__.py | 20 +-- soil/agents/evented.py | 73 +++++---- soil/agents/network_agents.py | 4 +- soil/environment.py | 13 +- soil/events.py | 30 ++-- soil/network.py | 1 - soil/serialization.py | 2 - soil/simulation.py | 2 +- soil/time.py | 209 +++++++++++------------- tests/test_agents.py | 98 ++++++++++- tests/test_examples.py | 2 + tests/test_main.py | 30 ++-- tests/test_network.py | 4 +- tests/test_time.py | 35 ++-- 16 files changed, 319 insertions(+), 218 deletions(-) diff --git a/examples/events_and_messages/cars.py b/examples/events_and_messages/cars.py index 04d7894..c612f70 100644 --- a/examples/events_and_messages/cars.py +++ b/examples/events_and_messages/cars.py @@ -127,7 +127,8 @@ class Driver(Evented, FSM): ) self.check_passengers() - self.check_messages() # This will call on_receive behind the scenes, and the agent's status will be updated + # This will call on_receive behind the scenes, and the agent's status will be updated + self.check_messages() yield Delta(30) # Wait at least 30 seconds before checking again try: @@ -204,6 +205,8 @@ class Passenger(Evented, FSM): while not self.journey: self.info(f"Passenger at: { self.pos }. Checking for responses.") try: + # This will call check_messages behind the scenes, and the agent's status will be updated + # If you want to avoid that, you can call it with: check=False yield self.received(expiration=expiration) except events.TimedOut: self.info(f"Passenger at: { self.pos }. Asking for journey.") @@ -211,7 +214,6 @@ class Passenger(Evented, FSM): journey, ttl=timeout, sender=self, agent_class=Driver ) expiration = self.now + timeout - self.check_messages() return self.driving_home @state @@ -220,7 +222,11 @@ class Passenger(Evented, FSM): self.pos[0] != self.journey.destination[0] or self.pos[1] != self.journey.destination[1] ): - yield self.received(timeout=60) + try: + yield self.received(timeout=60) + except events.TimedOut: + pass + self.info("Got home safe!") self.die() diff --git a/examples/rabbits/basic/rabbit_agents.py b/examples/rabbits/basic/rabbit_agents.py index b28d2e9..4c0981b 100644 --- a/examples/rabbits/basic/rabbit_agents.py +++ b/examples/rabbits/basic/rabbit_agents.py @@ -133,7 +133,7 @@ class RandomAccident(BaseAgent): math.log10(max(1, rabbits_alive)) ) self.debug("Killing some rabbits with prob={}!".format(prob_death)) - for i in self.iter_agents(agent_class=Rabbit): + for i in self.get_agents(agent_class=Rabbit): if i.state_id == i.dead.id: continue if self.prob(prob_death): diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index c13999f..a9c1fc3 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -20,12 +20,6 @@ from typing import Dict, List from .. import serialization, utils, time, config -def as_node(agent): - if isinstance(agent, BaseAgent): - return agent.id - return agent - - IGNORED_FIELDS = ("model", "logger") @@ -97,10 +91,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): """ def __init__(self, unique_id, model, name=None, interval=None, **kwargs): - # Check for REQUIRED arguments - # Initialize agent parameters - if isinstance(unique_id, MesaAgent): - raise Exception() assert isinstance(unique_id, int) super().__init__(unique_id=unique_id, model=model) @@ -207,7 +197,8 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): def step(self): if not self.alive: raise time.DeadAgent(self.unique_id) - return super().step() or time.Delta(self.interval) + super().step() + return time.Delta(self.interval) def log(self, message, *args, level=logging.INFO, **kwargs): if not self.logger.isEnabledFor(level): @@ -270,6 +261,7 @@ def prob(prob, random): return r < prob + def calculate_distribution(network_agents=None, agent_class=None): """ Calculate the threshold values (thresholds for a uniform distribution) @@ -414,7 +406,7 @@ def filter_agents( if ids: f = (agents[aid] for aid in ids if aid in agents) else: - f = (a for a in agents.values()) + f = agents.values() if state_id is not None and not isinstance(state_id, (tuple, list)): state_id = tuple([state_id]) @@ -638,6 +630,10 @@ from .SentimentCorrelationModel import * from .SISaModel import * from .CounterModel import * + +class Agent(NetworkAgent, EventedAgent): + '''Default agent class, has both network and event capabilities''' + try: import scipy from .Geo import Geo diff --git a/soil/agents/evented.py b/soil/agents/evented.py index 451b570..340c29a 100644 --- a/soil/agents/evented.py +++ b/soil/agents/evented.py @@ -1,57 +1,74 @@ from . import BaseAgent -from ..events import Message, Tell, Ask, Reply, TimedOut -from ..time import Cond +from ..events import Message, Tell, Ask, TimedOut +from ..time import BaseCond from functools import partial from collections import deque -class Evented(BaseAgent): +class ReceivedOrTimeout(BaseCond): + def __init__(self, agent, expiration=None, timeout=None, check=True, ignore=False, **kwargs): + if expiration is None: + if timeout is not None: + expiration = agent.now + timeout + self.expiration = expiration + self.ignore = ignore + self.check = check + super().__init__(**kwargs) + + def expired(self, time): + return self.expiration and self.expiration < time + + def ready(self, agent, time): + return len(agent._inbox) or self.expired(time) + + def return_value(self, agent): + if not self.ignore and self.expired(agent.now): + raise TimedOut('No messages received') + if self.check: + agent.check_messages() + return None + + def schedule_next(self, time, delta, first=False): + if self._delta is not None: + delta = self._delta + return (time + delta, self) + + def __repr__(self): + return f'ReceivedOrTimeout(expires={self.expiration})' + + +class EventedAgent(BaseAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._inbox = deque() - self._received = 0 self._processed = 0 - def on_receive(self, *args, **kwargs): pass - def received(self, expiration=None, timeout=None): - current = self._received - if expiration is None: - expiration = float('inf') if timeout is None else self.now + timeout - - if expiration < self.now: - raise ValueError("Invalid expiration time") - - def ready(agent): - return agent._received > current or agent.now >= expiration + def received(self, *args, **kwargs): + return ReceivedOrTimeout(self, *args, **kwargs) - def value(agent): - if agent.now > expiration: - raise TimedOut("No message received") - - c = Cond(func=ready, return_func=value) - c._checked = True - return c - - def tell(self, msg, sender): - self._received += 1 - self._inbox.append(Tell(payload=msg, sender=sender)) + def tell(self, msg, sender=None): + self._inbox.append(Tell(timestamp=self.now, payload=msg, sender=sender)) def ask(self, msg, timeout=None): - self._received += 1 - ask = Ask(payload=msg) + ask = Ask(timestamp=self.now, payload=msg) self._inbox.append(ask) expiration = float('inf') if timeout is None else self.now + timeout return ask.replied(expiration=expiration) def check_messages(self): + changed = False while self._inbox: msg = self._inbox.popleft() self._processed += 1 if msg.expired(self.now): continue + changed = True reply = self.on_receive(msg.payload, sender=msg.sender) if isinstance(msg, Ask): msg.reply = reply + return changed + +Evented = EventedAgent diff --git a/soil/agents/network_agents.py b/soil/agents/network_agents.py index d9950cf..cd57943 100644 --- a/soil/agents/network_agents.py +++ b/soil/agents/network_agents.py @@ -54,7 +54,7 @@ class NetworkAgent(BaseAgent): return G def remove_node(self): - print(f"Removing node for {self.unique_id}: {self.node_id}") + self.debug(f"Removing node for {self.unique_id}: {self.node_id}") self.G.remove_node(self.node_id) self.node_id = None @@ -80,3 +80,5 @@ class NetworkAgent(BaseAgent): if remove: self.remove_node() return super().die() + +NetAgent = NetworkAgent diff --git a/soil/environment.py b/soil/environment.py index 3a48d30..4be1625 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -310,17 +310,20 @@ class NetworkEnvironment(BaseEnvironment): self.add_agent(node_id=node_id, agent_class=a_class, **agent_params) -Environment = NetworkEnvironment - - -class EventedEnvironment(Environment): - def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs): +class EventedEnvironment(BaseEnvironment): + def broadcast(self, msg, sender=None, expiration=None, ttl=None, **kwargs): for agent in self.agents(**kwargs): + if agent == sender: + continue self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}') try: inbox = agent._inbox except AttributeError: self.logger.info(f'Agent {agent.unique_id} cannot receive events because it does not have an inbox') continue + # Allow for AttributeError exceptions in this part of the code inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl)) + +class Environment(NetworkEnvironment, EventedEnvironment): + '''Default environment class, has both network and event capabilities''' diff --git a/soil/events.py b/soil/events.py index 3bc50eb..25f471a 100644 --- a/soil/events.py +++ b/soil/events.py @@ -1,4 +1,4 @@ -from .time import Cond +from .time import BaseCond from dataclasses import dataclass, field from typing import Any from uuid import uuid4 @@ -11,6 +11,7 @@ class Message: payload: Any sender: Any = None expiration: float = None + timestamp: float = None id: int = field(default_factory=uuid4) def expired(self, when): @@ -20,19 +21,28 @@ class Reply(Message): source: Message +class ReplyCond(BaseCond): + def __init__(self, ask, *args, **kwargs): + self._ask = ask + super().__init__(*args, **kwargs) + + def ready(self, agent, time): + return self._ask.reply is not None or self._ask.expired(time) + + def return_value(self, agent): + if self._ask.expired(agent.now): + raise TimedOut() + return self._ask.reply + + def __repr__(self): + return f"ReplyCond({self._ask.id})" + + class Ask(Message): reply: Message = None def replied(self, expiration=None): - def ready(agent): - return self.reply is not None or agent.now > expiration - - def value(agent): - if agent.now > expiration: - raise TimedOut(f'No answer received for {self}') - return self.reply - - return Cond(func=ready, return_func=value) + return ReplyCond(self) class Tell(Message): diff --git a/soil/network.py b/soil/network.py index be7d96f..a717021 100644 --- a/soil/network.py +++ b/soil/network.py @@ -59,7 +59,6 @@ def find_unassigned(G, shuffle=False, random=random): If node_id is None, a node without an agent_id will be found. """ - # TODO: test candidates = list(G.nodes(data=True)) if shuffle: random.shuffle(candidates) diff --git a/soil/serialization.py b/soil/serialization.py index f0a98df..cd34a02 100644 --- a/soil/serialization.py +++ b/soil/serialization.py @@ -221,8 +221,6 @@ def deserialize(type_, value=None, globs=None, **kwargs): def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs): """Return the list of deserialized objects""" - # TODO: remove - print("SERIALIZATION", kwargs) objects = [] for name in names: mod = deserialize(name, known_modules=known_modules) diff --git a/soil/simulation.py b/soil/simulation.py index 5746e67..75947de 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -66,7 +66,7 @@ class Simulation: if ignored: d.setdefault("extra", {}).update(ignored) if ignored: - print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }') + logger.warning(f'Ignoring these parameters (added to "extra"): { ignored }') d.update(kwargs) return cls(**d) diff --git a/soil/time.py b/soil/time.py index ec1dc02..7e11201 100644 --- a/soil/time.py +++ b/soil/time.py @@ -1,10 +1,11 @@ from mesa.time import BaseScheduler from queue import Empty -from heapq import heappush, heappop, heapify +from heapq import heappush, heappop import math from inspect import getsource from numbers import Number +from textwrap import dedent from .utils import logger from mesa import Agent as MesaAgent @@ -23,87 +24,67 @@ class When: return time self._time = time - def next(self, time): - return self._time - def abs(self, time): - return self - - def __repr__(self): - return str(f"When({self._time})") - - def __lt__(self, other): - if isinstance(other, Number): - return self._time < other - return self._time < other.next(self._time) + return self._time - def __gt__(self, other): - if isinstance(other, Number): - return self._time > other - return self._time > other.next(self._time) + def schedule_next(self, time, delta, first=False): + return (self._time, None) - def ready(self, agent): - return self._time <= agent.model.schedule.time - def return_value(self, agent): - return None +NEVER = When(INFINITY) -class Cond(When): - def __init__(self, func, delta=1, return_func=lambda agent: None): - self._func = func +class Delta(When): + def __init__(self, delta): self._delta = delta - self._checked = False - self._return_func = return_func - - def next(self, time): - if self._checked: - return time + self._delta - return time def abs(self, time): - return self - - def ready(self, agent): - self._checked = True - return self._func(agent) - - def return_value(self, agent): - return self._return_func(agent) + return self._time + self._delta def __eq__(self, other): + if isinstance(other, Delta): + return self._delta == other._delta return False - def __lt__(self, other): - return True - - def __gt__(self, other): - return False + def schedule_next(self, time, delta, first=False): + return (time + self._delta, None) def __repr__(self): - return str(f'Cond("{getsource(self._func)}")') + return str(f"Delta({self._delta})") -NEVER = When(INFINITY) +class BaseCond: + def __init__(self, msg=None, delta=None, eager=False): + self._msg = msg + self._delta = delta + self.eager = eager + def schedule_next(self, time, delta, first=False): + if first and self.eager: + return (time, self) + if self._delta: + delta = self._delta + return (time + delta, self) -class Delta(When): - def __init__(self, delta): - self._delta = delta + def return_value(self, agent): + return None - def __eq__(self, other): - if isinstance(other, Delta): - return self._delta == other._delta - return False + def __repr__(self): + return self._msg or self.__class__.__name__ - def abs(self, time): - return When(self._delta + time) - def next(self, time): - return time + self._delta +class Cond(BaseCond): + def __init__(self, func, *args, **kwargs): + self._func = func + super().__init__(*args, **kwargs) + + def ready(self, agent, time): + return self._func(agent) def __repr__(self): - return str(f"Delta({self._delta})") + if self._msg: + return self._msg + return str(f'Cond("{dedent(getsource(self._func)).strip()}")') class TimedActivation(BaseScheduler): @@ -111,28 +92,40 @@ class TimedActivation(BaseScheduler): In each activation, each agent will update its 'next_time'. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, shuffle=True, **kwargs): super().__init__(*args, **kwargs) self._next = {} self._queue = [] - self.next_time = 0 + self._shuffle = shuffle + self.step_interval = 1 self.logger = logger.getChild(f"time_{ self.model }") def add(self, agent: MesaAgent, when=None): if when is None: - when = When(self.time) - elif not isinstance(when, When): - when = When(when) - if agent.unique_id in self._agents: - del self._agents[agent.unique_id] - if agent.unique_id in self._next: - self._queue.remove((self._next[agent.unique_id], agent)) - heapify(self._queue) - - self._next[agent.unique_id] = when - heappush(self._queue, (when, agent)) + when = self.time + elif isinstance(when, When): + when = when.abs() + + self._schedule(agent, None, when) super().add(agent) + def _schedule(self, agent, condition=None, when=None): + if condition: + if not when: + when, condition = condition.schedule_next(when or self.time, + self.step_interval) + else: + if when is None: + when = self.time + self.step_interval + condition = None + if self._shuffle: + key = (when, self.model.random.random(), condition) + else: + key = (when, agent.unique_id, condition) + self._next[agent.unique_id] = key + heappush(self._queue, (key, agent)) + + def step(self) -> None: """ Executes agents in order, one at a time. After each step, @@ -143,73 +136,59 @@ class TimedActivation(BaseScheduler): if not self.model.running: return - when = NEVER - - to_process = [] - skipped = [] - next_time = INFINITY - - ix = 0 - self.logger.debug(f"Queue length: {len(self._queue)}") while self._queue: - (when, agent) = self._queue[0] + ((when, _id, cond), agent) = self._queue[0] if when > self.time: break + heappop(self._queue) - if when.ready(agent): + if cond: + if not cond.ready(agent, self.time): + self._schedule(agent, cond) + continue try: - agent._last_return = when.return_value(agent) + agent._last_return = cond.return_value(agent) except Exception as ex: agent._last_except = ex + else: + agent._last_return = None + agent._last_except = None - self._next.pop(agent.unique_id, None) - to_process.append(agent) - continue - - next_time = min(next_time, when.next(self.time)) - self._next[agent.unique_id] = next_time - skipped.append((when, agent)) - - if self._queue: - next_time = min(next_time, self._queue[0][0].next(self.time)) - - self._queue = [*skipped, *self._queue] - - for agent in to_process: self.logger.debug(f"Stepping agent {agent}") + self._next.pop(agent.unique_id, None) try: - returned = ((agent.step() or Delta(1))).abs(self.time) + returned = agent.step() except DeadAgent: - if agent.unique_id in self._next: - del self._next[agent.unique_id] agent.alive = False continue + # Check status for MESA agents if not getattr(agent, "alive", True): continue - value = returned.next(self.time) - agent._last_return = value - - if value < self.time: - raise Exception( - f"Cannot schedule an agent for a time in the past ({when} < {self.time})" - ) - if value < INFINITY: - next_time = min(value, next_time) - - self._next[agent.unique_id] = returned - heappush(self._queue, (returned, agent)) + if returned: + next_check = returned.schedule_next(self.time, self.step_interval, first=True) + self._schedule(agent, when=next_check[0], condition=next_check[1]) else: - assert not self._next[agent.unique_id] + next_check = (self.time + self.step_interval, None) + + self._schedule(agent) self.steps += 1 - self.logger.debug(f"Updating time step: {self.time} -> {next_time}") - self.time = next_time - if not self._queue or next_time == INFINITY: + if not self._queue: + self.time = INFINITY self.model.running = False return self.time + + next_time = self._queue[0][0][0] + if next_time < self.time: + raise Exception( + f"An agent has been scheduled for a time in the past, there is probably an error ({when} < {self.time})" + ) + self.logger.debug(f"Updating time step: {self.time} -> {next_time}") + + self.time = next_time diff --git a/tests/test_agents.py b/tests/test_agents.py index c6a603e..a32d91a 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -12,12 +12,11 @@ class Dead(agents.FSM): return self.die() -class TestMain(TestCase): +class TestAgents(TestCase): def test_die_returns_infinity(self): '''The last step of a dead agent should return time.INFINITY''' d = Dead(unique_id=0, model=environment.Environment()) - ret = d.step().abs(0) - print(ret, "next") + ret = d.step() assert ret == stime.NEVER def test_die_raises_exception(self): @@ -66,4 +65,95 @@ class TestMain(TestCase): a.step() assert a.run == 1 a.step() - assert a.run == 2 + + + def test_broadcast(self): + ''' + An agent should be able to broadcast messages to every other agent, AND each receiver should be able + to process it + ''' + class BCast(agents.Evented): + pings_received = 0 + + def step(self): + print(self.model.broadcast) + try: + self.model.broadcast('PING') + except Exception as ex: + print(ex) + while True: + self.check_messages() + yield + + def on_receive(self, msg, sender=None): + self.pings_received += 1 + + e = environment.EventedEnvironment() + + for i in range(10): + e.add_agent(agent_class=BCast) + e.step() + pings_received = lambda: [a.pings_received for a in e.agents] + assert sorted(pings_received()) == list(range(1, 11)) + e.step() + assert all(x==10 for x in pings_received()) + + def test_ask_messages(self): + ''' + An agent should be able to ask another agent, and wait for a response. + ''' + + # #Results depend on ordering (agents are shuffled), so force the first agent + pings = [] + pongs = [] + responses = [] + + class Ping(agents.EventedAgent): + def step(self): + target_id = (self.unique_id + 1) % self.count_agents() + target = self.model.agents[target_id] + print('starting') + while True: + print('Pings: ', pings, responses or not pings, self.model.schedule._queue) + if pongs or not pings: + pings.append(self.now) + response = yield target.ask('PING') + responses.append(response) + else: + print('NOT sending ping') + print('Checking msgs') + # Do not advance until we have received a message. + # warning: it will wait at least until the next time in the simulation + yield self.received(check=True) + print('done') + + def on_receive(self, msg, sender=None): + if msg == 'PING': + pongs.append(self.now) + return 'PONG' + + e = environment.EventedEnvironment() + for i in range(2): + e.add_agent(agent_class=Ping) + assert e.now == 0 + + # There should be a delay of one step between agent 0 and 1 + # On the first step: + # Agent 0 sends a PING, but blocks before a PONG + # Agent 1 sends a PONG, and blocks after its PING + # After that step, every agent can both receive (there are pending messages) and then send. + + e.step() + assert e.now == 1 + assert pings == [0] + assert pongs == [] + + e.step() + assert e.now == 2 + assert pings == [0, 1] + assert pongs == [1] + + e.step() + assert e.now == 3 + assert pings == [0, 1, 2] + assert pongs == [1, 2] diff --git a/tests/test_examples.py b/tests/test_examples.py index a0a2bd5..b2d2750 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -44,6 +44,8 @@ def add_example_tests(): for cfg, path in serialization.load_files( join(EXAMPLES, "**", "*.yml"), ): + if 'soil_output' in path: + continue p = make_example_test(path=path, cfg=config.Config.from_raw(cfg)) fname = os.path.basename(path) p.__name__ = "test_example_file_%s" % fname diff --git a/tests/test_main.py b/tests/test_main.py index 8f4f97c..d100b97 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -172,25 +172,21 @@ class TestMain(TestCase): assert len(configs) > 0 def test_until(self): - config = { - "name": "until_sim", - "model_params": { - "network_params": {}, - "agents": { - "fixed": [ - { - "agent_class": agents.BaseAgent, - } - ] - }, - }, - "max_time": 2, - "num_trials": 50, - } - s = simulation.from_config(config) + n_runs = 0 + + class CheckRun(agents.BaseAgent): + def step(self): + nonlocal n_runs + n_runs += 1 + return super().step() + + n_trials = 50 + max_time = 2 + s = simulation.Simulation(model_params={'agents': [{'agent_class': CheckRun}]}, + num_trials=n_trials, max_time=max_time) runs = list(s.run_simulation(dry_run=True)) over = list(x.now for x in runs if x.now > 2) - assert len(runs) == config["num_trials"] + assert len(runs) == n_trials assert len(over) == 0 def test_fsm(self): diff --git a/tests/test_network.py b/tests/test_network.py index a860b14..89ff4a0 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -72,7 +72,7 @@ class TestNetwork(TestCase): assert len(env.agents) == 2 assert env.agents[1].count_agents(state_id="normal") == 2 assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1 - assert env.agents[0].neighbors == 1 + assert env.agents[0].count_neighbors() == 1 def test_custom_agent_neighbors(self): """Allow for search of neighbors with a certain state_id""" @@ -90,7 +90,7 @@ class TestNetwork(TestCase): env = s.run_simulation(dry_run=True)[0] assert env.agents[1].count_agents(state_id="normal") == 2 assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1 - assert env.agents[0].neighbors == 1 + assert env.agents[0].count_neighbors() == 1 def test_subgraph(self): """An agent should be able to subgraph the global topology""" diff --git a/tests/test_time.py b/tests/test_time.py index db16609..458b734 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -30,8 +30,9 @@ class TestMain(TestCase): times_started = [] times_awakened = [] + times_asleep = [] times = [] - done = 0 + done = [] class CondAgent(agents.BaseAgent): @@ -39,36 +40,38 @@ class TestMain(TestCase): nonlocal done times_started.append(self.now) while True: - yield time.Cond(lambda agent: agent.model.schedule.time >= 10) + times_asleep.append(self.now) + yield time.Cond(lambda agent: agent.now >= 10, + delta=2) times_awakened.append(self.now) if self.now >= 10: break - done += 1 + done.append(self.now) env = environment.Environment(agents=[{'agent_class': CondAgent}]) while env.schedule.time < 11: - env.step() times.append(env.now) + env.step() + assert env.schedule.time == 11 assert times_started == [0] assert times_awakened == [10] - assert done == 1 + assert done == [10] # The first time will produce the Cond. - # Since there are no other agents, time will not advance, but the number - # of steps will. - assert env.schedule.steps == 12 - assert len(times) == 12 + assert env.schedule.steps == 6 + assert len(times) == 6 - while env.schedule.time < 12: - env.step() + while env.schedule.time < 13: times.append(env.now) + env.step() - assert env.schedule.time == 12 + assert times == [0, 2, 4, 6, 8, 10, 11] + assert env.schedule.time == 13 assert times_started == [0, 11] - assert times_awakened == [10, 11] - assert done == 2 + assert times_awakened == [10] + assert done == [10] # Once more to yield the cond, another one to continue - assert env.schedule.steps == 14 - assert len(times) == 14 + assert env.schedule.steps == 7 + assert len(times) == 7