1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-21 18:52:28 +00:00

Refactored time

Treating time and conditions as the same entity was getting confusing, and it
added a lot of unnecessary abstraction in a critical part (the scheduler).

The scheduling queue now has the time as a floating number (faster), the agent
id (for ties) and the condition, as well as the agent. The first three
elements (time, id, condition) can be considered as the "key" for the event.

To allow for agent execution to be "randomized" within every step, a new
parameter has been added to the scheduler, which makes it add a random number to
the key in order to change the ordering.

`EventedAgent.received` now checks the messages before returning control to the
user by default.
This commit is contained in:
J. Fernando Sánchez 2022-10-20 09:14:50 +02:00
parent cbbaf73538
commit a1262edd2a
16 changed files with 324 additions and 223 deletions

View File

@ -127,7 +127,8 @@ class Driver(Evented, FSM):
)
self.check_passengers()
self.check_messages() # This will call on_receive behind the scenes, and the agent's status will be updated
# This will call on_receive behind the scenes, and the agent's status will be updated
self.check_messages()
yield Delta(30) # Wait at least 30 seconds before checking again
try:
@ -204,6 +205,8 @@ class Passenger(Evented, FSM):
while not self.journey:
self.info(f"Passenger at: { self.pos }. Checking for responses.")
try:
# This will call check_messages behind the scenes, and the agent's status will be updated
# If you want to avoid that, you can call it with: check=False
yield self.received(expiration=expiration)
except events.TimedOut:
self.info(f"Passenger at: { self.pos }. Asking for journey.")
@ -211,7 +214,6 @@ class Passenger(Evented, FSM):
journey, ttl=timeout, sender=self, agent_class=Driver
)
expiration = self.now + timeout
self.check_messages()
return self.driving_home
@state
@ -220,7 +222,11 @@ class Passenger(Evented, FSM):
self.pos[0] != self.journey.destination[0]
or self.pos[1] != self.journey.destination[1]
):
yield self.received(timeout=60)
try:
yield self.received(timeout=60)
except events.TimedOut:
pass
self.info("Got home safe!")
self.die()

View File

@ -133,7 +133,7 @@ class RandomAccident(BaseAgent):
math.log10(max(1, rabbits_alive))
)
self.debug("Killing some rabbits with prob={}!".format(prob_death))
for i in self.iter_agents(agent_class=Rabbit):
for i in self.get_agents(agent_class=Rabbit):
if i.state_id == i.dead.id:
continue
if self.prob(prob_death):

View File

@ -20,12 +20,6 @@ from typing import Dict, List
from .. import serialization, utils, time, config
def as_node(agent):
if isinstance(agent, BaseAgent):
return agent.id
return agent
IGNORED_FIELDS = ("model", "logger")
@ -97,10 +91,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
"""
def __init__(self, unique_id, model, name=None, interval=None, **kwargs):
# Check for REQUIRED arguments
# Initialize agent parameters
if isinstance(unique_id, MesaAgent):
raise Exception()
assert isinstance(unique_id, int)
super().__init__(unique_id=unique_id, model=model)
@ -207,7 +197,8 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
def step(self):
if not self.alive:
raise time.DeadAgent(self.unique_id)
return super().step() or time.Delta(self.interval)
super().step()
return time.Delta(self.interval)
def log(self, message, *args, level=logging.INFO, **kwargs):
if not self.logger.isEnabledFor(level):
@ -270,6 +261,7 @@ def prob(prob, random):
return r < prob
def calculate_distribution(network_agents=None, agent_class=None):
"""
Calculate the threshold values (thresholds for a uniform distribution)
@ -414,7 +406,7 @@ def filter_agents(
if ids:
f = (agents[aid] for aid in ids if aid in agents)
else:
f = (a for a in agents.values())
f = agents.values()
if state_id is not None and not isinstance(state_id, (tuple, list)):
state_id = tuple([state_id])
@ -638,6 +630,10 @@ from .SentimentCorrelationModel import *
from .SISaModel import *
from .CounterModel import *
class Agent(NetworkAgent, EventedAgent):
'''Default agent class, has both network and event capabilities'''
try:
import scipy
from .Geo import Geo

View File

@ -1,57 +1,74 @@
from . import BaseAgent
from ..events import Message, Tell, Ask, Reply, TimedOut
from ..time import Cond
from ..events import Message, Tell, Ask, TimedOut
from ..time import BaseCond
from functools import partial
from collections import deque
class Evented(BaseAgent):
class ReceivedOrTimeout(BaseCond):
def __init__(self, agent, expiration=None, timeout=None, check=True, ignore=False, **kwargs):
if expiration is None:
if timeout is not None:
expiration = agent.now + timeout
self.expiration = expiration
self.ignore = ignore
self.check = check
super().__init__(**kwargs)
def expired(self, time):
return self.expiration and self.expiration < time
def ready(self, agent, time):
return len(agent._inbox) or self.expired(time)
def return_value(self, agent):
if not self.ignore and self.expired(agent.now):
raise TimedOut('No messages received')
if self.check:
agent.check_messages()
return None
def schedule_next(self, time, delta, first=False):
if self._delta is not None:
delta = self._delta
return (time + delta, self)
def __repr__(self):
return f'ReceivedOrTimeout(expires={self.expiration})'
class EventedAgent(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._inbox = deque()
self._received = 0
self._processed = 0
def on_receive(self, *args, **kwargs):
pass
def received(self, expiration=None, timeout=None):
current = self._received
if expiration is None:
expiration = float('inf') if timeout is None else self.now + timeout
def received(self, *args, **kwargs):
return ReceivedOrTimeout(self, *args, **kwargs)
if expiration < self.now:
raise ValueError("Invalid expiration time")
def ready(agent):
return agent._received > current or agent.now >= expiration
def value(agent):
if agent.now > expiration:
raise TimedOut("No message received")
c = Cond(func=ready, return_func=value)
c._checked = True
return c
def tell(self, msg, sender):
self._received += 1
self._inbox.append(Tell(payload=msg, sender=sender))
def tell(self, msg, sender=None):
self._inbox.append(Tell(timestamp=self.now, payload=msg, sender=sender))
def ask(self, msg, timeout=None):
self._received += 1
ask = Ask(payload=msg)
ask = Ask(timestamp=self.now, payload=msg)
self._inbox.append(ask)
expiration = float('inf') if timeout is None else self.now + timeout
return ask.replied(expiration=expiration)
def check_messages(self):
changed = False
while self._inbox:
msg = self._inbox.popleft()
self._processed += 1
if msg.expired(self.now):
continue
changed = True
reply = self.on_receive(msg.payload, sender=msg.sender)
if isinstance(msg, Ask):
msg.reply = reply
return changed
Evented = EventedAgent

View File

@ -54,7 +54,7 @@ class NetworkAgent(BaseAgent):
return G
def remove_node(self):
print(f"Removing node for {self.unique_id}: {self.node_id}")
self.debug(f"Removing node for {self.unique_id}: {self.node_id}")
self.G.remove_node(self.node_id)
self.node_id = None
@ -80,3 +80,5 @@ class NetworkAgent(BaseAgent):
if remove:
self.remove_node()
return super().die()
NetAgent = NetworkAgent

View File

@ -310,17 +310,20 @@ class NetworkEnvironment(BaseEnvironment):
self.add_agent(node_id=node_id, agent_class=a_class, **agent_params)
Environment = NetworkEnvironment
class EventedEnvironment(Environment):
def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs):
class EventedEnvironment(BaseEnvironment):
def broadcast(self, msg, sender=None, expiration=None, ttl=None, **kwargs):
for agent in self.agents(**kwargs):
if agent == sender:
continue
self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}')
try:
inbox = agent._inbox
except AttributeError:
self.logger.info(f'Agent {agent.unique_id} cannot receive events because it does not have an inbox')
continue
# Allow for AttributeError exceptions in this part of the code
inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl))
class Environment(NetworkEnvironment, EventedEnvironment):
'''Default environment class, has both network and event capabilities'''

View File

@ -1,4 +1,4 @@
from .time import Cond
from .time import BaseCond
from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4
@ -11,6 +11,7 @@ class Message:
payload: Any
sender: Any = None
expiration: float = None
timestamp: float = None
id: int = field(default_factory=uuid4)
def expired(self, when):
@ -20,19 +21,28 @@ class Reply(Message):
source: Message
class ReplyCond(BaseCond):
def __init__(self, ask, *args, **kwargs):
self._ask = ask
super().__init__(*args, **kwargs)
def ready(self, agent, time):
return self._ask.reply is not None or self._ask.expired(time)
def return_value(self, agent):
if self._ask.expired(agent.now):
raise TimedOut()
return self._ask.reply
def __repr__(self):
return f"ReplyCond({self._ask.id})"
class Ask(Message):
reply: Message = None
def replied(self, expiration=None):
def ready(agent):
return self.reply is not None or agent.now > expiration
def value(agent):
if agent.now > expiration:
raise TimedOut(f'No answer received for {self}')
return self.reply
return Cond(func=ready, return_func=value)
return ReplyCond(self)
class Tell(Message):

View File

@ -59,7 +59,6 @@ def find_unassigned(G, shuffle=False, random=random):
If node_id is None, a node without an agent_id will be found.
"""
# TODO: test
candidates = list(G.nodes(data=True))
if shuffle:
random.shuffle(candidates)

View File

@ -221,8 +221,6 @@ def deserialize(type_, value=None, globs=None, **kwargs):
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
"""Return the list of deserialized objects"""
# TODO: remove
print("SERIALIZATION", kwargs)
objects = []
for name in names:
mod = deserialize(name, known_modules=known_modules)

View File

@ -66,7 +66,7 @@ class Simulation:
if ignored:
d.setdefault("extra", {}).update(ignored)
if ignored:
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
logger.warning(f'Ignoring these parameters (added to "extra"): { ignored }')
d.update(kwargs)
return cls(**d)

View File

@ -1,10 +1,11 @@
from mesa.time import BaseScheduler
from queue import Empty
from heapq import heappush, heappop, heapify
from heapq import heappush, heappop
import math
from inspect import getsource
from numbers import Number
from textwrap import dedent
from .utils import logger
from mesa import Agent as MesaAgent
@ -23,65 +24,11 @@ class When:
return time
self._time = time
def next(self, time):
def abs(self, time):
return self._time
def abs(self, time):
return self
def __repr__(self):
return str(f"When({self._time})")
def __lt__(self, other):
if isinstance(other, Number):
return self._time < other
return self._time < other.next(self._time)
def __gt__(self, other):
if isinstance(other, Number):
return self._time > other
return self._time > other.next(self._time)
def ready(self, agent):
return self._time <= agent.model.schedule.time
def return_value(self, agent):
return None
class Cond(When):
def __init__(self, func, delta=1, return_func=lambda agent: None):
self._func = func
self._delta = delta
self._checked = False
self._return_func = return_func
def next(self, time):
if self._checked:
return time + self._delta
return time
def abs(self, time):
return self
def ready(self, agent):
self._checked = True
return self._func(agent)
def return_value(self, agent):
return self._return_func(agent)
def __eq__(self, other):
return False
def __lt__(self, other):
return True
def __gt__(self, other):
return False
def __repr__(self):
return str(f'Cond("{getsource(self._func)}")')
def schedule_next(self, time, delta, first=False):
return (self._time, None)
NEVER = When(INFINITY)
@ -91,48 +38,94 @@ class Delta(When):
def __init__(self, delta):
self._delta = delta
def abs(self, time):
return self._time + self._delta
def __eq__(self, other):
if isinstance(other, Delta):
return self._delta == other._delta
return False
def abs(self, time):
return When(self._delta + time)
def next(self, time):
return time + self._delta
def schedule_next(self, time, delta, first=False):
return (time + self._delta, None)
def __repr__(self):
return str(f"Delta({self._delta})")
class BaseCond:
def __init__(self, msg=None, delta=None, eager=False):
self._msg = msg
self._delta = delta
self.eager = eager
def schedule_next(self, time, delta, first=False):
if first and self.eager:
return (time, self)
if self._delta:
delta = self._delta
return (time + delta, self)
def return_value(self, agent):
return None
def __repr__(self):
return self._msg or self.__class__.__name__
class Cond(BaseCond):
def __init__(self, func, *args, **kwargs):
self._func = func
super().__init__(*args, **kwargs)
def ready(self, agent, time):
return self._func(agent)
def __repr__(self):
if self._msg:
return self._msg
return str(f'Cond("{dedent(getsource(self._func)).strip()}")')
class TimedActivation(BaseScheduler):
"""A scheduler which activates each agent when the agent requests.
In each activation, each agent will update its 'next_time'.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args, shuffle=True, **kwargs):
super().__init__(*args, **kwargs)
self._next = {}
self._queue = []
self.next_time = 0
self._shuffle = shuffle
self.step_interval = 1
self.logger = logger.getChild(f"time_{ self.model }")
def add(self, agent: MesaAgent, when=None):
if when is None:
when = When(self.time)
elif not isinstance(when, When):
when = When(when)
if agent.unique_id in self._agents:
del self._agents[agent.unique_id]
if agent.unique_id in self._next:
self._queue.remove((self._next[agent.unique_id], agent))
heapify(self._queue)
self._next[agent.unique_id] = when
heappush(self._queue, (when, agent))
when = self.time
elif isinstance(when, When):
when = when.abs()
self._schedule(agent, None, when)
super().add(agent)
def _schedule(self, agent, condition=None, when=None):
if condition:
if not when:
when, condition = condition.schedule_next(when or self.time,
self.step_interval)
else:
if when is None:
when = self.time + self.step_interval
condition = None
if self._shuffle:
key = (when, self.model.random.random(), condition)
else:
key = (when, agent.unique_id, condition)
self._next[agent.unique_id] = key
heappush(self._queue, (key, agent))
def step(self) -> None:
"""
Executes agents in order, one at a time. After each step,
@ -143,73 +136,59 @@ class TimedActivation(BaseScheduler):
if not self.model.running:
return
when = NEVER
to_process = []
skipped = []
next_time = INFINITY
ix = 0
self.logger.debug(f"Queue length: {len(self._queue)}")
while self._queue:
(when, agent) = self._queue[0]
((when, _id, cond), agent) = self._queue[0]
if when > self.time:
break
heappop(self._queue)
if when.ready(agent):
if cond:
if not cond.ready(agent, self.time):
self._schedule(agent, cond)
continue
try:
agent._last_return = when.return_value(agent)
agent._last_return = cond.return_value(agent)
except Exception as ex:
agent._last_except = ex
else:
agent._last_return = None
agent._last_except = None
self._next.pop(agent.unique_id, None)
to_process.append(agent)
continue
next_time = min(next_time, when.next(self.time))
self._next[agent.unique_id] = next_time
skipped.append((when, agent))
if self._queue:
next_time = min(next_time, self._queue[0][0].next(self.time))
self._queue = [*skipped, *self._queue]
for agent in to_process:
self.logger.debug(f"Stepping agent {agent}")
self._next.pop(agent.unique_id, None)
try:
returned = ((agent.step() or Delta(1))).abs(self.time)
returned = agent.step()
except DeadAgent:
if agent.unique_id in self._next:
del self._next[agent.unique_id]
agent.alive = False
continue
# Check status for MESA agents
if not getattr(agent, "alive", True):
continue
value = returned.next(self.time)
agent._last_return = value
if value < self.time:
raise Exception(
f"Cannot schedule an agent for a time in the past ({when} < {self.time})"
)
if value < INFINITY:
next_time = min(value, next_time)
self._next[agent.unique_id] = returned
heappush(self._queue, (returned, agent))
if returned:
next_check = returned.schedule_next(self.time, self.step_interval, first=True)
self._schedule(agent, when=next_check[0], condition=next_check[1])
else:
assert not self._next[agent.unique_id]
next_check = (self.time + self.step_interval, None)
self._schedule(agent)
self.steps += 1
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
self.time = next_time
if not self._queue or next_time == INFINITY:
if not self._queue:
self.time = INFINITY
self.model.running = False
return self.time
next_time = self._queue[0][0][0]
if next_time < self.time:
raise Exception(
f"An agent has been scheduled for a time in the past, there is probably an error ({when} < {self.time})"
)
self.logger.debug(f"Updating time step: {self.time} -> {next_time}")
self.time = next_time

View File

@ -12,12 +12,11 @@ class Dead(agents.FSM):
return self.die()
class TestMain(TestCase):
class TestAgents(TestCase):
def test_die_returns_infinity(self):
'''The last step of a dead agent should return time.INFINITY'''
d = Dead(unique_id=0, model=environment.Environment())
ret = d.step().abs(0)
print(ret, "next")
ret = d.step()
assert ret == stime.NEVER
def test_die_raises_exception(self):
@ -66,4 +65,95 @@ class TestMain(TestCase):
a.step()
assert a.run == 1
a.step()
assert a.run == 2
def test_broadcast(self):
'''
An agent should be able to broadcast messages to every other agent, AND each receiver should be able
to process it
'''
class BCast(agents.Evented):
pings_received = 0
def step(self):
print(self.model.broadcast)
try:
self.model.broadcast('PING')
except Exception as ex:
print(ex)
while True:
self.check_messages()
yield
def on_receive(self, msg, sender=None):
self.pings_received += 1
e = environment.EventedEnvironment()
for i in range(10):
e.add_agent(agent_class=BCast)
e.step()
pings_received = lambda: [a.pings_received for a in e.agents]
assert sorted(pings_received()) == list(range(1, 11))
e.step()
assert all(x==10 for x in pings_received())
def test_ask_messages(self):
'''
An agent should be able to ask another agent, and wait for a response.
'''
# #Results depend on ordering (agents are shuffled), so force the first agent
pings = []
pongs = []
responses = []
class Ping(agents.EventedAgent):
def step(self):
target_id = (self.unique_id + 1) % self.count_agents()
target = self.model.agents[target_id]
print('starting')
while True:
print('Pings: ', pings, responses or not pings, self.model.schedule._queue)
if pongs or not pings:
pings.append(self.now)
response = yield target.ask('PING')
responses.append(response)
else:
print('NOT sending ping')
print('Checking msgs')
# Do not advance until we have received a message.
# warning: it will wait at least until the next time in the simulation
yield self.received(check=True)
print('done')
def on_receive(self, msg, sender=None):
if msg == 'PING':
pongs.append(self.now)
return 'PONG'
e = environment.EventedEnvironment()
for i in range(2):
e.add_agent(agent_class=Ping)
assert e.now == 0
# There should be a delay of one step between agent 0 and 1
# On the first step:
# Agent 0 sends a PING, but blocks before a PONG
# Agent 1 sends a PONG, and blocks after its PING
# After that step, every agent can both receive (there are pending messages) and then send.
e.step()
assert e.now == 1
assert pings == [0]
assert pongs == []
e.step()
assert e.now == 2
assert pings == [0, 1]
assert pongs == [1]
e.step()
assert e.now == 3
assert pings == [0, 1, 2]
assert pongs == [1, 2]

View File

@ -44,6 +44,8 @@ def add_example_tests():
for cfg, path in serialization.load_files(
join(EXAMPLES, "**", "*.yml"),
):
if 'soil_output' in path:
continue
p = make_example_test(path=path, cfg=config.Config.from_raw(cfg))
fname = os.path.basename(path)
p.__name__ = "test_example_file_%s" % fname

View File

@ -172,25 +172,21 @@ class TestMain(TestCase):
assert len(configs) > 0
def test_until(self):
config = {
"name": "until_sim",
"model_params": {
"network_params": {},
"agents": {
"fixed": [
{
"agent_class": agents.BaseAgent,
}
]
},
},
"max_time": 2,
"num_trials": 50,
}
s = simulation.from_config(config)
n_runs = 0
class CheckRun(agents.BaseAgent):
def step(self):
nonlocal n_runs
n_runs += 1
return super().step()
n_trials = 50
max_time = 2
s = simulation.Simulation(model_params={'agents': [{'agent_class': CheckRun}]},
num_trials=n_trials, max_time=max_time)
runs = list(s.run_simulation(dry_run=True))
over = list(x.now for x in runs if x.now > 2)
assert len(runs) == config["num_trials"]
assert len(runs) == n_trials
assert len(over) == 0
def test_fsm(self):

View File

@ -72,7 +72,7 @@ class TestNetwork(TestCase):
assert len(env.agents) == 2
assert env.agents[1].count_agents(state_id="normal") == 2
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
assert env.agents[0].neighbors == 1
assert env.agents[0].count_neighbors() == 1
def test_custom_agent_neighbors(self):
"""Allow for search of neighbors with a certain state_id"""
@ -90,7 +90,7 @@ class TestNetwork(TestCase):
env = s.run_simulation(dry_run=True)[0]
assert env.agents[1].count_agents(state_id="normal") == 2
assert env.agents[1].count_agents(state_id="normal", limit_neighbors=True) == 1
assert env.agents[0].neighbors == 1
assert env.agents[0].count_neighbors() == 1
def test_subgraph(self):
"""An agent should be able to subgraph the global topology"""

View File

@ -30,8 +30,9 @@ class TestMain(TestCase):
times_started = []
times_awakened = []
times_asleep = []
times = []
done = 0
done = []
class CondAgent(agents.BaseAgent):
@ -39,36 +40,38 @@ class TestMain(TestCase):
nonlocal done
times_started.append(self.now)
while True:
yield time.Cond(lambda agent: agent.model.schedule.time >= 10)
times_asleep.append(self.now)
yield time.Cond(lambda agent: agent.now >= 10,
delta=2)
times_awakened.append(self.now)
if self.now >= 10:
break
done += 1
done.append(self.now)
env = environment.Environment(agents=[{'agent_class': CondAgent}])
while env.schedule.time < 11:
env.step()
times.append(env.now)
env.step()
assert env.schedule.time == 11
assert times_started == [0]
assert times_awakened == [10]
assert done == 1
assert done == [10]
# The first time will produce the Cond.
# Since there are no other agents, time will not advance, but the number
# of steps will.
assert env.schedule.steps == 12
assert len(times) == 12
assert env.schedule.steps == 6
assert len(times) == 6
while env.schedule.time < 12:
env.step()
while env.schedule.time < 13:
times.append(env.now)
env.step()
assert env.schedule.time == 12
assert times == [0, 2, 4, 6, 8, 10, 11]
assert env.schedule.time == 13
assert times_started == [0, 11]
assert times_awakened == [10, 11]
assert done == 2
assert times_awakened == [10]
assert done == [10]
# Once more to yield the cond, another one to continue
assert env.schedule.steps == 14
assert len(times) == 14
assert env.schedule.steps == 7
assert len(times) == 7