1
0
mirror of https://github.com/gsi-upm/soil synced 2024-12-22 00:08:12 +00:00

Add events

This commit is contained in:
J. Fernando Sánchez 2022-10-18 13:11:01 +02:00
parent 3776c4e5c5
commit 159c9a9077
11 changed files with 342 additions and 44 deletions

View File

@ -0,0 +1,141 @@
from __future__ import annotations
from soil import *
from soil import events
from mesa.space import MultiGrid
from enum import Enum
@dataclass
class Journey:
origin: (int, int)
destination: (int, int)
tip: float
passenger: Passenger = None
driver: Driver = None
class City(EventedEnvironment):
def __init__(self, *args, n_cars=1, height=100, width=100, n_passengers=10, agents=None, **kwargs):
self.grid = MultiGrid(width=width, height=height, torus=False)
if agents is None:
agents = []
for i in range(n_cars):
agents.append({'agent_class': Driver})
for i in range(n_passengers):
agents.append({'agent_class': Passenger})
super().__init__(*args, agents=agents, **kwargs)
for agent in self.agents:
self.grid.place_agent(agent, (0, 0))
self.grid.move_to_empty(agent)
class Driver(Evented, FSM):
pos = None
journey = None
earnings = 0
def on_receive(self, msg, sender):
if self.journey is None and isinstance(msg, Journey) and msg.driver is None:
msg.driver = self
self.journey = msg
@default_state
@state
def wandering(self):
target = None
self.check_passengers()
self.journey = None
while self.journey is None:
if target is None or not self.move_towards(target):
target = self.random.choice(self.model.grid.get_neighborhood(self.pos, moore=False))
self.check_passengers()
self.check_messages() # This will call on_receive behind the scenes
yield Delta(30)
try:
self.journey = yield self.journey.passenger.ask(self.journey, timeout=60)
except events.TimedOut:
self.journey = None
return
return self.driving
def check_passengers(self):
c = self.count_agents(agent_class=Passenger)
self.info(f"Passengers left {c}")
if not c:
self.die()
@state
def driving(self):
#Approaching
while self.move_towards(self.journey.origin):
yield
while self.move_towards(self.journey.destination, with_passenger=True):
yield
self.check_passengers()
return self.wandering
def move_towards(self, target, with_passenger=False):
'''Move one cell at a time towards a target'''
self.info(f"Moving { self.pos } -> { target }")
if target[0] == self.pos[0] and target[1] == self.pos[1]:
return False
next_pos = [self.pos[0], self.pos[1]]
for idx in [0, 1]:
if self.pos[idx] < target[idx]:
next_pos[idx] += 1
break
if self.pos[idx] > target[idx]:
next_pos[idx] -= 1
break
self.model.grid.move_agent(self, tuple(next_pos))
if with_passenger:
self.journey.passenger.pos = self.pos # This could be communicated through messages
return True
class Passenger(Evented, FSM):
pos = None
@default_state
@state
def asking(self):
destination = (self.random.randint(0, self.model.grid.height), self.random.randint(0, self.model.grid.width))
self.journey = None
journey = Journey(origin=self.pos,
destination=destination,
tip=self.random.randint(10, 100),
passenger=self)
timeout = 60
expiration = self.now + timeout
self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver)
while not self.journey:
self.info(f"Passenger at: { self.pos }. Checking for responses.")
try:
yield self.received(expiration=expiration)
except events.TimedOut:
self.info(f"Passenger at: { self.pos }. Asking for journey.")
self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver)
expiration = self.now + timeout
self.check_messages()
return self.driving_home
def on_receive(self, msg, sender):
if isinstance(msg, Journey):
self.journey = msg
return msg
@state
def driving_home(self):
while self.pos[0] != self.journey.destination[0] or self.pos[1] != self.journey.destination[1]:
yield self.received(timeout=60)
self.info("Got home safe!")
self.die()
simulation = Simulation(model_class=City, model_params={'n_passengers': 2})
if __name__ == "__main__":
with easy(simulation) as s:
s.run()

View File

@ -17,7 +17,7 @@ except NameError:
from .agents import *
from . import agents
from .simulation import *
from .environment import Environment
from .environment import Environment, EventedEnvironment
from . import serialization
from .utils import logger
from .time import *
@ -34,6 +34,9 @@ def main(
pdb=False,
**kwargs,
):
if isinstance(cfg, Simulation):
sim = cfg
import argparse
from . import simulation
@ -44,7 +47,7 @@ def main(
"file",
type=str,
nargs="?",
default=cfg,
default=cfg if sim is None else '',
help="Configuration file for the simulation (e.g., YAML or JSON)",
)
parser.add_argument(
@ -150,7 +153,7 @@ def main(
if output is None:
output = args.output
logger.info("Loading config file: {}".format(args.file))
debug = debug or args.debug
@ -162,19 +165,27 @@ def main(
try:
exp_params = {}
if not os.path.exists(args.file):
logger.error("Please, input a valid file")
return
if sim:
logger.info("Loading simulation instance")
sims = [sim, ]
else:
logger.info("Loading config file: {}".format(args.file))
if not os.path.exists(args.file):
logger.error("Please, input a valid file")
return
sims = list(simulation.iter_from_config(
args.file,
dry_run=args.dry_run,
exporters=exporters,
parallel=parallel,
outdir=output,
exporter_params=exp_params,
**kwargs,
))
for sim in sims:
for sim in simulation.iter_from_config(
args.file,
dry_run=args.dry_run,
exporters=exporters,
parallel=parallel,
outdir=output,
exporter_params=exp_params,
**kwargs,
):
if args.set:
for s in args.set:
k, v = s.split("=", 1)[:2]
@ -219,7 +230,6 @@ def main(
@contextmanager
def easy(cfg, pdb=False, debug=False, **kwargs):
ex = None
try:
yield main(cfg, **kwargs)[0]
except Exception as e:
@ -228,10 +238,7 @@ def easy(cfg, pdb=False, debug=False, **kwargs):
print(traceback.format_exc())
post_mortem()
ex = e
finally:
if ex:
raise ex
raise
if __name__ == "__main__":

View File

@ -40,23 +40,31 @@ class MetaAgent(ABCMeta):
new_nmspc = {
"_defaults": defaults,
"_last_return": None,
"_last_except": None,
}
for attr, func in namespace.items():
if attr == "step" and inspect.isgeneratorfunction(func):
orig_func = func
new_nmspc["_MetaAgent__coroutine"] = None
new_nmspc["_coroutine"] = None
@wraps(func)
def func(self):
while True:
if not self.__coroutine:
self.__coroutine = orig_func(self)
if not self._coroutine:
self._coroutine = orig_func(self)
try:
return next(self.__coroutine)
if self._last_except:
return self._coroutine.throw(self._last_except)
else:
return self._coroutine.send(self._last_return)
except StopIteration as ex:
self.__coroutine = None
self._coroutine = None
return ex.value
finally:
self._last_return = None
self._last_except = None
func.id = name or func.__name__
func.is_default = False
@ -190,6 +198,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
def die(self):
self.info(f"agent dying")
self.alive = False
try:
self.model.schedule.remove(self)
except KeyError:
pass
return time.NEVER
def step(self):
@ -617,6 +629,7 @@ def _from_distro(
from .network_agents import *
from .fsm import *
from .evented import *
from .BassModel import *
from .BigMarketModel import *
from .IndependentCascadeModel import *

57
soil/agents/evented.py Normal file
View File

@ -0,0 +1,57 @@
from . import BaseAgent
from ..events import Message, Tell, Ask, Reply, TimedOut
from ..time import Cond
from functools import partial
from collections import deque
class Evented(BaseAgent):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._inbox = deque()
self._received = 0
self._processed = 0
def on_receive(self, *args, **kwargs):
pass
def received(self, expiration=None, timeout=None):
current = self._received
if expiration is None:
expiration = float('inf') if timeout is None else self.now + timeout
if expiration < self.now:
raise ValueError("Invalid expiration time")
def ready(agent):
return agent._received > current or agent.now >= expiration
def value(agent):
if agent.now > expiration:
raise TimedOut("No message received")
c = Cond(func=ready, return_func=value)
c._checked = True
return c
def tell(self, msg, sender):
self._received += 1
self._inbox.append(Tell(payload=msg, sender=sender))
def ask(self, msg, timeout=None):
self._received += 1
ask = Ask(payload=msg)
self._inbox.append(ask)
expiration = float('inf') if timeout is None else self.now + timeout
return ask.replied(expiration=expiration)
def check_messages(self):
while self._inbox:
msg = self._inbox.popleft()
self._processed += 1
if msg.expired(self.now):
continue
reply = self.on_receive(msg.payload, sender=msg.sender)
if isinstance(msg, Ask):
msg.reply = reply

View File

@ -1,6 +1,6 @@
from . import MetaAgent, BaseAgent
from functools import partial
from functools import partial, wraps
import inspect
@ -19,17 +19,26 @@ def state(name=None):
while True:
if not self._coroutine:
self._coroutine = orig_func(self)
try:
n = next(self._coroutine)
if self._last_except:
n = self._coroutine.throw(self._last_except)
else:
n = self._coroutine.send(self._last_return)
if n:
return None, n
return
return n
except StopIteration as ex:
self._coroutine = None
next_state = ex.value
if next_state is not None:
self._set_state(next_state)
return next_state
finally:
self._last_return = None
self._last_except = None
func.id = name or func.__name__
func.is_default = False

View File

@ -30,9 +30,9 @@ def wrapcmd(func):
class Debug(pdb.Pdb):
def __init__(self, *args, skip_soil=False, **kwargs):
skip = kwargs.get("skip", [])
skip.append("soil")
skip.append("contextlib")
if skip_soil:
skip.append("soil")
skip.append("contextlib")
skip.append("soil.*")
skip.append("mesa.*")
super(Debug, self).__init__(*args, skip=skip, **kwargs)

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import os
import sqlite3
import math
import random
import logging
import inspect
@ -19,7 +18,7 @@ import networkx as nx
from mesa import Model
from mesa.datacollection import DataCollector
from . import agents as agentmod, config, serialization, utils, time, network
from . import agents as agentmod, config, serialization, utils, time, network, events
class BaseEnvironment(Model):
@ -294,10 +293,6 @@ class NetworkEnvironment(BaseEnvironment):
def add_agent(self, *args, **kwargs):
a = super().add_agent(*args, **kwargs)
if "node_id" in a:
if a.node_id == 24:
import pdb
pdb.set_trace()
assert self.G.nodes[a.node_id]["agent"] == a
return a
@ -316,3 +311,14 @@ class NetworkEnvironment(BaseEnvironment):
Environment = NetworkEnvironment
class EventedEnvironment(Environment):
def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs):
for agent in self.agents(**kwargs):
self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}')
try:
agent._inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl))
except AttributeError:
self.info(f'Agent {agent.unique_id} cannot receive events')

43
soil/events.py Normal file
View File

@ -0,0 +1,43 @@
from .time import Cond
from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4
class Event:
pass
@dataclass
class Message:
payload: Any
sender: Any = None
expiration: float = None
id: int = field(default_factory=uuid4)
def expired(self, when):
return self.expiration is not None and self.expiration < when
class Reply(Message):
source: Message
class Ask(Message):
reply: Message = None
def replied(self, expiration=None):
def ready(agent):
return self.reply is not None or agent.now > expiration
def value(agent):
if agent.now > expiration:
raise TimedOut(f'No answer received for {self}')
return self.reply
return Cond(func=ready, return_func=value)
class Tell(Message):
pass
class TimedOut(Exception):
pass

View File

@ -47,7 +47,7 @@ class Simulation:
max_time: float = float("inf")
max_steps: int = -1
interval: int = 1
num_trials: int = 3
num_trials: int = 1
parallel: Optional[bool] = None
exporters: Optional[List[str]] = field(default_factory=list)
outdir: Optional[str] = None

View File

@ -45,12 +45,16 @@ class When:
def ready(self, agent):
return self._time <= agent.model.schedule.time
def return_value(self, agent):
return None
class Cond(When):
def __init__(self, func, delta=1):
def __init__(self, func, delta=1, return_func=lambda agent: None):
self._func = func
self._delta = delta
self._checked = False
self._return_func = return_func
def next(self, time):
if self._checked:
@ -64,6 +68,9 @@ class Cond(When):
self._checked = True
return self._func(agent)
def return_value(self, agent):
return self._return_func(agent)
def __eq__(self, other):
return False
@ -144,14 +151,21 @@ class TimedActivation(BaseScheduler):
ix = 0
self.logger.debug(f"Queue length: {len(self._queue)}")
while self._queue:
(when, agent) = self._queue[0]
if when > self.time:
break
heappop(self._queue)
if when.ready(agent):
to_process.append(agent)
try:
agent._last_return = when.return_value(agent)
except Exception as ex:
agent._last_except = ex
self._next.pop(agent.unique_id, None)
to_process.append(agent)
continue
next_time = min(next_time, when.next(self.time))
@ -175,10 +189,10 @@ class TimedActivation(BaseScheduler):
continue
if not getattr(agent, "alive", True):
self.remove(agent)
continue
value = returned.next(self.time)
agent._last_return = value
if value < self.time:
raise Exception(

View File

@ -33,18 +33,20 @@ class TestMain(TestCase):
The step function of an agent could be a generator. In that case, the state of the
agent will be resumed after every call to step.
'''
a = 0
class Gen(agents.BaseAgent):
def step(self):
a = 0
nonlocal a
for i in range(5):
yield a
yield
a += 1
e = environment.Environment()
g = Gen(model=e, unique_id=e.next_id())
e.schedule.add(g)
for i in range(5):
t = g.step()
assert t == i
e.step()
assert a == i
def test_state_decorator(self):
class MyAgent(agents.FSM):
@ -53,6 +55,12 @@ class TestMain(TestCase):
@agents.state('original')
def root(self):
self.run += 1
return self.other
@agents.state
def other(self):
self.run += 1
e = environment.Environment()
a = MyAgent(model=e, unique_id=e.next_id())
a.step()