diff --git a/examples/events_and_messages/cars.py b/examples/events_and_messages/cars.py new file mode 100644 index 0000000..ecb5b17 --- /dev/null +++ b/examples/events_and_messages/cars.py @@ -0,0 +1,141 @@ +from __future__ import annotations +from soil import * +from soil import events +from mesa.space import MultiGrid +from enum import Enum + + +@dataclass +class Journey: + origin: (int, int) + destination: (int, int) + tip: float + + passenger: Passenger = None + driver: Driver = None + + +class City(EventedEnvironment): + def __init__(self, *args, n_cars=1, height=100, width=100, n_passengers=10, agents=None, **kwargs): + self.grid = MultiGrid(width=width, height=height, torus=False) + if agents is None: + agents = [] + for i in range(n_cars): + agents.append({'agent_class': Driver}) + for i in range(n_passengers): + agents.append({'agent_class': Passenger}) + super().__init__(*args, agents=agents, **kwargs) + for agent in self.agents: + self.grid.place_agent(agent, (0, 0)) + self.grid.move_to_empty(agent) + +class Driver(Evented, FSM): + pos = None + journey = None + earnings = 0 + + def on_receive(self, msg, sender): + if self.journey is None and isinstance(msg, Journey) and msg.driver is None: + msg.driver = self + self.journey = msg + + @default_state + @state + def wandering(self): + target = None + self.check_passengers() + self.journey = None + while self.journey is None: + if target is None or not self.move_towards(target): + target = self.random.choice(self.model.grid.get_neighborhood(self.pos, moore=False)) + self.check_passengers() + self.check_messages() # This will call on_receive behind the scenes + yield Delta(30) + try: + self.journey = yield self.journey.passenger.ask(self.journey, timeout=60) + except events.TimedOut: + self.journey = None + return + return self.driving + + def check_passengers(self): + c = self.count_agents(agent_class=Passenger) + self.info(f"Passengers left {c}") + if not c: + self.die() + + @state + def driving(self): + #Approaching + while self.move_towards(self.journey.origin): + yield + while self.move_towards(self.journey.destination, with_passenger=True): + yield + self.check_passengers() + return self.wandering + + def move_towards(self, target, with_passenger=False): + '''Move one cell at a time towards a target''' + self.info(f"Moving { self.pos } -> { target }") + if target[0] == self.pos[0] and target[1] == self.pos[1]: + return False + + next_pos = [self.pos[0], self.pos[1]] + for idx in [0, 1]: + if self.pos[idx] < target[idx]: + next_pos[idx] += 1 + break + if self.pos[idx] > target[idx]: + next_pos[idx] -= 1 + break + self.model.grid.move_agent(self, tuple(next_pos)) + if with_passenger: + self.journey.passenger.pos = self.pos # This could be communicated through messages + return True + + +class Passenger(Evented, FSM): + pos = None + + @default_state + @state + def asking(self): + destination = (self.random.randint(0, self.model.grid.height), self.random.randint(0, self.model.grid.width)) + self.journey = None + journey = Journey(origin=self.pos, + destination=destination, + tip=self.random.randint(10, 100), + passenger=self) + + timeout = 60 + expiration = self.now + timeout + self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver) + while not self.journey: + self.info(f"Passenger at: { self.pos }. Checking for responses.") + try: + yield self.received(expiration=expiration) + except events.TimedOut: + self.info(f"Passenger at: { self.pos }. Asking for journey.") + self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver) + expiration = self.now + timeout + self.check_messages() + return self.driving_home + + def on_receive(self, msg, sender): + if isinstance(msg, Journey): + self.journey = msg + return msg + + @state + def driving_home(self): + while self.pos[0] != self.journey.destination[0] or self.pos[1] != self.journey.destination[1]: + yield self.received(timeout=60) + self.info("Got home safe!") + self.die() + + +simulation = Simulation(model_class=City, model_params={'n_passengers': 2}) + +if __name__ == "__main__": + with easy(simulation) as s: + s.run() diff --git a/soil/__init__.py b/soil/__init__.py index 92bc79f..b6b62ee 100644 --- a/soil/__init__.py +++ b/soil/__init__.py @@ -17,7 +17,7 @@ except NameError: from .agents import * from . import agents from .simulation import * -from .environment import Environment +from .environment import Environment, EventedEnvironment from . import serialization from .utils import logger from .time import * @@ -34,6 +34,9 @@ def main( pdb=False, **kwargs, ): + + if isinstance(cfg, Simulation): + sim = cfg import argparse from . import simulation @@ -44,7 +47,7 @@ def main( "file", type=str, nargs="?", - default=cfg, + default=cfg if sim is None else '', help="Configuration file for the simulation (e.g., YAML or JSON)", ) parser.add_argument( @@ -150,7 +153,7 @@ def main( if output is None: output = args.output - logger.info("Loading config file: {}".format(args.file)) + debug = debug or args.debug @@ -162,19 +165,27 @@ def main( try: exp_params = {} - if not os.path.exists(args.file): - logger.error("Please, input a valid file") - return + if sim: + logger.info("Loading simulation instance") + sims = [sim, ] + else: + logger.info("Loading config file: {}".format(args.file)) + if not os.path.exists(args.file): + logger.error("Please, input a valid file") + return + + sims = list(simulation.iter_from_config( + args.file, + dry_run=args.dry_run, + exporters=exporters, + parallel=parallel, + outdir=output, + exporter_params=exp_params, + **kwargs, + )) + + for sim in sims: - for sim in simulation.iter_from_config( - args.file, - dry_run=args.dry_run, - exporters=exporters, - parallel=parallel, - outdir=output, - exporter_params=exp_params, - **kwargs, - ): if args.set: for s in args.set: k, v = s.split("=", 1)[:2] @@ -219,7 +230,6 @@ def main( @contextmanager def easy(cfg, pdb=False, debug=False, **kwargs): - ex = None try: yield main(cfg, **kwargs)[0] except Exception as e: @@ -228,10 +238,7 @@ def easy(cfg, pdb=False, debug=False, **kwargs): print(traceback.format_exc()) post_mortem() - ex = e - finally: - if ex: - raise ex + raise if __name__ == "__main__": diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index 9b5736b..c13999f 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -40,23 +40,31 @@ class MetaAgent(ABCMeta): new_nmspc = { "_defaults": defaults, + "_last_return": None, + "_last_except": None, } for attr, func in namespace.items(): if attr == "step" and inspect.isgeneratorfunction(func): orig_func = func - new_nmspc["_MetaAgent__coroutine"] = None + new_nmspc["_coroutine"] = None @wraps(func) def func(self): while True: - if not self.__coroutine: - self.__coroutine = orig_func(self) + if not self._coroutine: + self._coroutine = orig_func(self) try: - return next(self.__coroutine) + if self._last_except: + return self._coroutine.throw(self._last_except) + else: + return self._coroutine.send(self._last_return) except StopIteration as ex: - self.__coroutine = None + self._coroutine = None return ex.value + finally: + self._last_return = None + self._last_except = None func.id = name or func.__name__ func.is_default = False @@ -190,6 +198,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): def die(self): self.info(f"agent dying") self.alive = False + try: + self.model.schedule.remove(self) + except KeyError: + pass return time.NEVER def step(self): @@ -617,6 +629,7 @@ def _from_distro( from .network_agents import * from .fsm import * +from .evented import * from .BassModel import * from .BigMarketModel import * from .IndependentCascadeModel import * diff --git a/soil/agents/evented.py b/soil/agents/evented.py new file mode 100644 index 0000000..451b570 --- /dev/null +++ b/soil/agents/evented.py @@ -0,0 +1,57 @@ +from . import BaseAgent +from ..events import Message, Tell, Ask, Reply, TimedOut +from ..time import Cond +from functools import partial +from collections import deque + + +class Evented(BaseAgent): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._inbox = deque() + self._received = 0 + self._processed = 0 + + + def on_receive(self, *args, **kwargs): + pass + + def received(self, expiration=None, timeout=None): + current = self._received + if expiration is None: + expiration = float('inf') if timeout is None else self.now + timeout + + if expiration < self.now: + raise ValueError("Invalid expiration time") + + def ready(agent): + return agent._received > current or agent.now >= expiration + + def value(agent): + if agent.now > expiration: + raise TimedOut("No message received") + + c = Cond(func=ready, return_func=value) + c._checked = True + return c + + def tell(self, msg, sender): + self._received += 1 + self._inbox.append(Tell(payload=msg, sender=sender)) + + def ask(self, msg, timeout=None): + self._received += 1 + ask = Ask(payload=msg) + self._inbox.append(ask) + expiration = float('inf') if timeout is None else self.now + timeout + return ask.replied(expiration=expiration) + + def check_messages(self): + while self._inbox: + msg = self._inbox.popleft() + self._processed += 1 + if msg.expired(self.now): + continue + reply = self.on_receive(msg.payload, sender=msg.sender) + if isinstance(msg, Ask): + msg.reply = reply diff --git a/soil/agents/fsm.py b/soil/agents/fsm.py index 729313d..4b64364 100644 --- a/soil/agents/fsm.py +++ b/soil/agents/fsm.py @@ -1,6 +1,6 @@ from . import MetaAgent, BaseAgent -from functools import partial +from functools import partial, wraps import inspect @@ -19,17 +19,26 @@ def state(name=None): while True: if not self._coroutine: self._coroutine = orig_func(self) + try: - n = next(self._coroutine) + if self._last_except: + n = self._coroutine.throw(self._last_except) + else: + n = self._coroutine.send(self._last_return) if n: return None, n - return + return n except StopIteration as ex: self._coroutine = None next_state = ex.value if next_state is not None: self._set_state(next_state) return next_state + finally: + self._last_return = None + self._last_except = None + + func.id = name or func.__name__ func.is_default = False diff --git a/soil/debugging.py b/soil/debugging.py index f5a43e7..4344df0 100644 --- a/soil/debugging.py +++ b/soil/debugging.py @@ -30,9 +30,9 @@ def wrapcmd(func): class Debug(pdb.Pdb): def __init__(self, *args, skip_soil=False, **kwargs): skip = kwargs.get("skip", []) - skip.append("soil") - skip.append("contextlib") if skip_soil: + skip.append("soil") + skip.append("contextlib") skip.append("soil.*") skip.append("mesa.*") super(Debug, self).__init__(*args, skip=skip, **kwargs) diff --git a/soil/environment.py b/soil/environment.py index 8245ca0..926f26f 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -3,7 +3,6 @@ from __future__ import annotations import os import sqlite3 import math -import random import logging import inspect @@ -19,7 +18,7 @@ import networkx as nx from mesa import Model from mesa.datacollection import DataCollector -from . import agents as agentmod, config, serialization, utils, time, network +from . import agents as agentmod, config, serialization, utils, time, network, events class BaseEnvironment(Model): @@ -294,10 +293,6 @@ class NetworkEnvironment(BaseEnvironment): def add_agent(self, *args, **kwargs): a = super().add_agent(*args, **kwargs) if "node_id" in a: - if a.node_id == 24: - import pdb - - pdb.set_trace() assert self.G.nodes[a.node_id]["agent"] == a return a @@ -316,3 +311,14 @@ class NetworkEnvironment(BaseEnvironment): Environment = NetworkEnvironment + + +class EventedEnvironment(Environment): + def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs): + for agent in self.agents(**kwargs): + self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}') + try: + agent._inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl)) + except AttributeError: + self.info(f'Agent {agent.unique_id} cannot receive events') + diff --git a/soil/events.py b/soil/events.py new file mode 100644 index 0000000..3bc50eb --- /dev/null +++ b/soil/events.py @@ -0,0 +1,43 @@ +from .time import Cond +from dataclasses import dataclass, field +from typing import Any +from uuid import uuid4 + +class Event: + pass + +@dataclass +class Message: + payload: Any + sender: Any = None + expiration: float = None + id: int = field(default_factory=uuid4) + + def expired(self, when): + return self.expiration is not None and self.expiration < when + +class Reply(Message): + source: Message + + +class Ask(Message): + reply: Message = None + + def replied(self, expiration=None): + def ready(agent): + return self.reply is not None or agent.now > expiration + + def value(agent): + if agent.now > expiration: + raise TimedOut(f'No answer received for {self}') + return self.reply + + return Cond(func=ready, return_func=value) + + +class Tell(Message): + pass + + +class TimedOut(Exception): + pass diff --git a/soil/simulation.py b/soil/simulation.py index f5738d4..5746e67 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -47,7 +47,7 @@ class Simulation: max_time: float = float("inf") max_steps: int = -1 interval: int = 1 - num_trials: int = 3 + num_trials: int = 1 parallel: Optional[bool] = None exporters: Optional[List[str]] = field(default_factory=list) outdir: Optional[str] = None diff --git a/soil/time.py b/soil/time.py index e7acbac..ec1dc02 100644 --- a/soil/time.py +++ b/soil/time.py @@ -45,12 +45,16 @@ class When: def ready(self, agent): return self._time <= agent.model.schedule.time + def return_value(self, agent): + return None + class Cond(When): - def __init__(self, func, delta=1): + def __init__(self, func, delta=1, return_func=lambda agent: None): self._func = func self._delta = delta self._checked = False + self._return_func = return_func def next(self, time): if self._checked: @@ -64,6 +68,9 @@ class Cond(When): self._checked = True return self._func(agent) + def return_value(self, agent): + return self._return_func(agent) + def __eq__(self, other): return False @@ -144,14 +151,21 @@ class TimedActivation(BaseScheduler): ix = 0 + self.logger.debug(f"Queue length: {len(self._queue)}") + while self._queue: (when, agent) = self._queue[0] if when > self.time: break heappop(self._queue) if when.ready(agent): - to_process.append(agent) + try: + agent._last_return = when.return_value(agent) + except Exception as ex: + agent._last_except = ex + self._next.pop(agent.unique_id, None) + to_process.append(agent) continue next_time = min(next_time, when.next(self.time)) @@ -175,10 +189,10 @@ class TimedActivation(BaseScheduler): continue if not getattr(agent, "alive", True): - self.remove(agent) continue value = returned.next(self.time) + agent._last_return = value if value < self.time: raise Exception( diff --git a/tests/test_agents.py b/tests/test_agents.py index 35526e3..c6a603e 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -33,18 +33,20 @@ class TestMain(TestCase): The step function of an agent could be a generator. In that case, the state of the agent will be resumed after every call to step. ''' + a = 0 class Gen(agents.BaseAgent): def step(self): - a = 0 + nonlocal a for i in range(5): - yield a + yield a += 1 e = environment.Environment() g = Gen(model=e, unique_id=e.next_id()) + e.schedule.add(g) for i in range(5): - t = g.step() - assert t == i + e.step() + assert a == i def test_state_decorator(self): class MyAgent(agents.FSM): @@ -53,6 +55,12 @@ class TestMain(TestCase): @agents.state('original') def root(self): self.run += 1 + return self.other + + @agents.state + def other(self): + self.run += 1 + e = environment.Environment() a = MyAgent(model=e, unique_id=e.next_id()) a.step()