mirror of
				https://github.com/gsi-upm/soil
				synced 2025-10-31 07:38:17 +00:00 
			
		
		
		
	Add events
This commit is contained in:
		
							
								
								
									
										141
									
								
								examples/events_and_messages/cars.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								examples/events_and_messages/cars.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,141 @@ | ||||
| from __future__ import annotations | ||||
| from soil import * | ||||
| from soil import events | ||||
| from mesa.space import MultiGrid | ||||
| from enum import Enum | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class Journey: | ||||
|     origin: (int, int) | ||||
|     destination: (int, int) | ||||
|     tip: float | ||||
|  | ||||
|     passenger: Passenger = None | ||||
|     driver: Driver = None | ||||
|  | ||||
|  | ||||
| class City(EventedEnvironment): | ||||
|     def __init__(self, *args, n_cars=1, height=100, width=100,  n_passengers=10, agents=None, **kwargs): | ||||
|         self.grid = MultiGrid(width=width, height=height, torus=False) | ||||
|         if agents is None: | ||||
|             agents = [] | ||||
|             for i in range(n_cars): | ||||
|                 agents.append({'agent_class': Driver}) | ||||
|             for i in range(n_passengers): | ||||
|                 agents.append({'agent_class': Passenger}) | ||||
|         super().__init__(*args, agents=agents, **kwargs) | ||||
|         for agent in self.agents: | ||||
|             self.grid.place_agent(agent, (0, 0)) | ||||
|             self.grid.move_to_empty(agent) | ||||
|  | ||||
| class Driver(Evented, FSM): | ||||
|     pos = None | ||||
|     journey = None | ||||
|     earnings = 0 | ||||
|  | ||||
|     def on_receive(self, msg, sender): | ||||
|         if self.journey is None and isinstance(msg, Journey) and msg.driver is None: | ||||
|             msg.driver = self | ||||
|             self.journey = msg | ||||
|  | ||||
|     @default_state | ||||
|     @state | ||||
|     def wandering(self): | ||||
|         target = None | ||||
|         self.check_passengers() | ||||
|         self.journey = None | ||||
|         while self.journey is None: | ||||
|             if target is None or not self.move_towards(target): | ||||
|                 target = self.random.choice(self.model.grid.get_neighborhood(self.pos, moore=False)) | ||||
|             self.check_passengers() | ||||
|             self.check_messages() # This will call on_receive behind the scenes | ||||
|             yield Delta(30) | ||||
|         try: | ||||
|             self.journey = yield self.journey.passenger.ask(self.journey, timeout=60) | ||||
|         except events.TimedOut: | ||||
|             self.journey = None | ||||
|             return | ||||
|         return self.driving | ||||
|  | ||||
|     def check_passengers(self): | ||||
|         c = self.count_agents(agent_class=Passenger) | ||||
|         self.info(f"Passengers left {c}") | ||||
|         if not c: | ||||
|             self.die() | ||||
|  | ||||
|     @state | ||||
|     def driving(self): | ||||
|         #Approaching | ||||
|         while self.move_towards(self.journey.origin): | ||||
|             yield | ||||
|         while self.move_towards(self.journey.destination, with_passenger=True): | ||||
|             yield | ||||
|         self.check_passengers() | ||||
|         return self.wandering | ||||
|  | ||||
|     def move_towards(self, target, with_passenger=False): | ||||
|         '''Move one cell at a time towards a target''' | ||||
|         self.info(f"Moving { self.pos } -> { target }") | ||||
|         if target[0] == self.pos[0] and target[1] == self.pos[1]: | ||||
|             return False | ||||
|  | ||||
|         next_pos = [self.pos[0], self.pos[1]] | ||||
|         for idx in [0, 1]: | ||||
|             if self.pos[idx] < target[idx]: | ||||
|                 next_pos[idx] += 1 | ||||
|                 break | ||||
|             if self.pos[idx] > target[idx]: | ||||
|                 next_pos[idx] -= 1 | ||||
|                 break | ||||
|         self.model.grid.move_agent(self, tuple(next_pos)) | ||||
|         if with_passenger: | ||||
|             self.journey.passenger.pos = self.pos  # This could be communicated through messages | ||||
|         return True | ||||
|              | ||||
|  | ||||
| class Passenger(Evented, FSM): | ||||
|     pos = None | ||||
|  | ||||
|     @default_state | ||||
|     @state | ||||
|     def asking(self): | ||||
|         destination = (self.random.randint(0, self.model.grid.height), self.random.randint(0, self.model.grid.width)) | ||||
|         self.journey = None | ||||
|         journey = Journey(origin=self.pos, | ||||
|                           destination=destination, | ||||
|                           tip=self.random.randint(10, 100), | ||||
|                           passenger=self) | ||||
|  | ||||
|         timeout = 60 | ||||
|         expiration = self.now + timeout | ||||
|         self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver) | ||||
|         while not self.journey: | ||||
|             self.info(f"Passenger at: { self.pos }. Checking for responses.") | ||||
|             try: | ||||
|                 yield self.received(expiration=expiration) | ||||
|             except events.TimedOut: | ||||
|                 self.info(f"Passenger at: { self.pos }. Asking for journey.") | ||||
|                 self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver) | ||||
|                 expiration = self.now + timeout | ||||
|             self.check_messages() | ||||
|         return self.driving_home | ||||
|  | ||||
|     def on_receive(self, msg, sender): | ||||
|         if isinstance(msg, Journey): | ||||
|             self.journey = msg | ||||
|             return msg | ||||
|  | ||||
|     @state | ||||
|     def driving_home(self): | ||||
|         while self.pos[0] != self.journey.destination[0] or self.pos[1] != self.journey.destination[1]: | ||||
|             yield self.received(timeout=60) | ||||
|         self.info("Got home safe!") | ||||
|         self.die() | ||||
|  | ||||
|  | ||||
| simulation = Simulation(model_class=City, model_params={'n_passengers': 2}) | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     with easy(simulation) as s: | ||||
|         s.run() | ||||
| @@ -17,7 +17,7 @@ except NameError: | ||||
| from .agents import * | ||||
| from . import agents | ||||
| from .simulation import * | ||||
| from .environment import Environment | ||||
| from .environment import Environment, EventedEnvironment | ||||
| from . import serialization | ||||
| from .utils import logger | ||||
| from .time import * | ||||
| @@ -34,6 +34,9 @@ def main( | ||||
|     pdb=False, | ||||
|     **kwargs, | ||||
| ): | ||||
|  | ||||
|     if isinstance(cfg, Simulation): | ||||
|         sim = cfg | ||||
|     import argparse | ||||
|     from . import simulation | ||||
|  | ||||
| @@ -44,7 +47,7 @@ def main( | ||||
|         "file", | ||||
|         type=str, | ||||
|         nargs="?", | ||||
|         default=cfg, | ||||
|         default=cfg if sim is None else '', | ||||
|         help="Configuration file for the simulation (e.g., YAML or JSON)", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
| @@ -150,7 +153,7 @@ def main( | ||||
|     if output is None: | ||||
|         output = args.output | ||||
|  | ||||
|     logger.info("Loading config file: {}".format(args.file)) | ||||
|  | ||||
|  | ||||
|     debug = debug or args.debug | ||||
|  | ||||
| @@ -162,19 +165,27 @@ def main( | ||||
|     try: | ||||
|         exp_params = {} | ||||
|  | ||||
|         if not os.path.exists(args.file): | ||||
|             logger.error("Please, input a valid file") | ||||
|             return | ||||
|         if sim: | ||||
|             logger.info("Loading simulation instance") | ||||
|             sims = [sim, ] | ||||
|         else: | ||||
|             logger.info("Loading config file: {}".format(args.file)) | ||||
|             if not os.path.exists(args.file): | ||||
|                 logger.error("Please, input a valid file") | ||||
|                 return | ||||
|  | ||||
|             sims = list(simulation.iter_from_config( | ||||
|                 args.file, | ||||
|                 dry_run=args.dry_run, | ||||
|                 exporters=exporters, | ||||
|                 parallel=parallel, | ||||
|                 outdir=output, | ||||
|                 exporter_params=exp_params, | ||||
|                 **kwargs, | ||||
|             )) | ||||
|  | ||||
|         for sim in sims: | ||||
|  | ||||
|         for sim in simulation.iter_from_config( | ||||
|             args.file, | ||||
|             dry_run=args.dry_run, | ||||
|             exporters=exporters, | ||||
|             parallel=parallel, | ||||
|             outdir=output, | ||||
|             exporter_params=exp_params, | ||||
|             **kwargs, | ||||
|         ): | ||||
|             if args.set: | ||||
|                 for s in args.set: | ||||
|                     k, v = s.split("=", 1)[:2] | ||||
| @@ -219,7 +230,6 @@ def main( | ||||
|  | ||||
| @contextmanager | ||||
| def easy(cfg, pdb=False, debug=False, **kwargs): | ||||
|     ex = None | ||||
|     try: | ||||
|         yield main(cfg, **kwargs)[0] | ||||
|     except Exception as e: | ||||
| @@ -228,10 +238,7 @@ def easy(cfg, pdb=False, debug=False, **kwargs): | ||||
|  | ||||
|             print(traceback.format_exc()) | ||||
|             post_mortem() | ||||
|         ex = e | ||||
|     finally: | ||||
|         if ex: | ||||
|             raise ex | ||||
|         raise | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|   | ||||
| @@ -40,23 +40,31 @@ class MetaAgent(ABCMeta): | ||||
|  | ||||
|         new_nmspc = { | ||||
|             "_defaults": defaults, | ||||
|             "_last_return": None, | ||||
|             "_last_except": None, | ||||
|         } | ||||
|  | ||||
|         for attr, func in namespace.items(): | ||||
|             if attr == "step" and inspect.isgeneratorfunction(func): | ||||
|                 orig_func = func | ||||
|                 new_nmspc["_MetaAgent__coroutine"] = None | ||||
|                 new_nmspc["_coroutine"] = None | ||||
|  | ||||
|                 @wraps(func) | ||||
|                 def func(self): | ||||
|                     while True: | ||||
|                         if not self.__coroutine: | ||||
|                             self.__coroutine = orig_func(self) | ||||
|                         if not self._coroutine: | ||||
|                             self._coroutine = orig_func(self) | ||||
|                         try: | ||||
|                             return next(self.__coroutine) | ||||
|                             if self._last_except: | ||||
|                                 return self._coroutine.throw(self._last_except) | ||||
|                             else: | ||||
|                                 return self._coroutine.send(self._last_return) | ||||
|                         except StopIteration as ex: | ||||
|                             self.__coroutine = None | ||||
|                             self._coroutine = None | ||||
|                             return ex.value | ||||
|                         finally: | ||||
|                             self._last_return = None | ||||
|                             self._last_except = None | ||||
|  | ||||
|                 func.id = name or func.__name__ | ||||
|                 func.is_default = False | ||||
| @@ -190,6 +198,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): | ||||
|     def die(self): | ||||
|         self.info(f"agent dying") | ||||
|         self.alive = False | ||||
|         try: | ||||
|             self.model.schedule.remove(self) | ||||
|         except KeyError: | ||||
|             pass | ||||
|         return time.NEVER | ||||
|  | ||||
|     def step(self): | ||||
| @@ -617,6 +629,7 @@ def _from_distro( | ||||
|  | ||||
| from .network_agents import * | ||||
| from .fsm import * | ||||
| from .evented import * | ||||
| from .BassModel import * | ||||
| from .BigMarketModel import * | ||||
| from .IndependentCascadeModel import * | ||||
|   | ||||
							
								
								
									
										57
									
								
								soil/agents/evented.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								soil/agents/evented.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| from . import BaseAgent | ||||
| from ..events import Message, Tell, Ask, Reply, TimedOut | ||||
| from ..time import Cond | ||||
| from functools import partial | ||||
| from collections import deque | ||||
|  | ||||
|  | ||||
| class Evented(BaseAgent): | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self._inbox = deque() | ||||
|         self._received = 0 | ||||
|         self._processed = 0 | ||||
|  | ||||
|  | ||||
|     def on_receive(self, *args, **kwargs): | ||||
|         pass | ||||
|  | ||||
|     def received(self, expiration=None, timeout=None): | ||||
|         current = self._received | ||||
|         if expiration is None: | ||||
|             expiration = float('inf') if timeout is None else self.now + timeout | ||||
|  | ||||
|         if expiration < self.now: | ||||
|             raise ValueError("Invalid expiration time") | ||||
|  | ||||
|         def ready(agent): | ||||
|             return agent._received > current or agent.now >= expiration | ||||
|  | ||||
|         def value(agent): | ||||
|             if agent.now > expiration: | ||||
|                 raise TimedOut("No message received") | ||||
|  | ||||
|         c = Cond(func=ready, return_func=value) | ||||
|         c._checked = True | ||||
|         return c | ||||
|  | ||||
|     def tell(self, msg, sender): | ||||
|         self._received += 1 | ||||
|         self._inbox.append(Tell(payload=msg, sender=sender)) | ||||
|  | ||||
|     def ask(self, msg, timeout=None): | ||||
|         self._received += 1 | ||||
|         ask = Ask(payload=msg) | ||||
|         self._inbox.append(ask) | ||||
|         expiration = float('inf') if timeout is None else self.now + timeout | ||||
|         return ask.replied(expiration=expiration) | ||||
|  | ||||
|     def check_messages(self): | ||||
|         while self._inbox: | ||||
|             msg = self._inbox.popleft() | ||||
|             self._processed += 1 | ||||
|             if msg.expired(self.now): | ||||
|                 continue | ||||
|             reply = self.on_receive(msg.payload, sender=msg.sender) | ||||
|             if isinstance(msg, Ask): | ||||
|                 msg.reply = reply | ||||
| @@ -1,6 +1,6 @@ | ||||
| from . import MetaAgent, BaseAgent | ||||
|  | ||||
| from functools import partial | ||||
| from functools import partial, wraps | ||||
| import inspect | ||||
|  | ||||
|  | ||||
| @@ -19,17 +19,26 @@ def state(name=None): | ||||
|                 while True: | ||||
|                     if not self._coroutine: | ||||
|                         self._coroutine = orig_func(self) | ||||
|  | ||||
|                     try: | ||||
|                         n = next(self._coroutine) | ||||
|                         if self._last_except: | ||||
|                             n = self._coroutine.throw(self._last_except) | ||||
|                         else: | ||||
|                             n = self._coroutine.send(self._last_return) | ||||
|                         if n: | ||||
|                             return None, n | ||||
|                         return | ||||
|                         return n | ||||
|                     except StopIteration as ex: | ||||
|                         self._coroutine = None | ||||
|                         next_state = ex.value | ||||
|                         if next_state is not None: | ||||
|                             self._set_state(next_state) | ||||
|                         return next_state | ||||
|                     finally: | ||||
|                         self._last_return = None | ||||
|                         self._last_except = None | ||||
|  | ||||
|  | ||||
|  | ||||
|         func.id = name or func.__name__ | ||||
|         func.is_default = False | ||||
|   | ||||
| @@ -30,9 +30,9 @@ def wrapcmd(func): | ||||
| class Debug(pdb.Pdb): | ||||
|     def __init__(self, *args, skip_soil=False, **kwargs): | ||||
|         skip = kwargs.get("skip", []) | ||||
|         skip.append("soil") | ||||
|         skip.append("contextlib") | ||||
|         if skip_soil: | ||||
|             skip.append("soil") | ||||
|             skip.append("contextlib") | ||||
|             skip.append("soil.*") | ||||
|             skip.append("mesa.*") | ||||
|         super(Debug, self).__init__(*args, skip=skip, **kwargs) | ||||
|   | ||||
| @@ -3,7 +3,6 @@ from __future__ import annotations | ||||
| import os | ||||
| import sqlite3 | ||||
| import math | ||||
| import random | ||||
| import logging | ||||
| import inspect | ||||
|  | ||||
| @@ -19,7 +18,7 @@ import networkx as nx | ||||
| from mesa import Model | ||||
| from mesa.datacollection import DataCollector | ||||
|  | ||||
| from . import agents as agentmod, config, serialization, utils, time, network | ||||
| from . import agents as agentmod, config, serialization, utils, time, network, events | ||||
|  | ||||
|  | ||||
| class BaseEnvironment(Model): | ||||
| @@ -294,10 +293,6 @@ class NetworkEnvironment(BaseEnvironment): | ||||
|     def add_agent(self, *args, **kwargs): | ||||
|         a = super().add_agent(*args, **kwargs) | ||||
|         if "node_id" in a: | ||||
|             if a.node_id == 24: | ||||
|                 import pdb | ||||
|  | ||||
|                 pdb.set_trace() | ||||
|             assert self.G.nodes[a.node_id]["agent"] == a | ||||
|         return a | ||||
|  | ||||
| @@ -316,3 +311,14 @@ class NetworkEnvironment(BaseEnvironment): | ||||
|  | ||||
|  | ||||
| Environment = NetworkEnvironment | ||||
|  | ||||
|  | ||||
| class EventedEnvironment(Environment): | ||||
|     def broadcast(self, msg,  sender, expiration=None, ttl=None, **kwargs): | ||||
|         for agent in self.agents(**kwargs): | ||||
|             self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}') | ||||
|             try: | ||||
|                 agent._inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl)) | ||||
|             except AttributeError: | ||||
|                 self.info(f'Agent {agent.unique_id} cannot receive events') | ||||
|  | ||||
|   | ||||
							
								
								
									
										43
									
								
								soil/events.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								soil/events.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| from .time import Cond | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Any | ||||
| from uuid import uuid4 | ||||
|  | ||||
| class Event: | ||||
|     pass | ||||
|  | ||||
| @dataclass | ||||
| class Message: | ||||
|     payload: Any | ||||
|     sender: Any = None | ||||
|     expiration: float = None | ||||
|     id: int = field(default_factory=uuid4) | ||||
|  | ||||
|     def expired(self, when): | ||||
|         return self.expiration is not None and self.expiration < when | ||||
|  | ||||
| class Reply(Message): | ||||
|     source: Message | ||||
|  | ||||
|  | ||||
| class Ask(Message): | ||||
|     reply: Message = None | ||||
|  | ||||
|     def replied(self, expiration=None): | ||||
|         def ready(agent): | ||||
|             return self.reply is not None or agent.now > expiration | ||||
|  | ||||
|         def value(agent): | ||||
|             if agent.now > expiration: | ||||
|                 raise TimedOut(f'No answer received for {self}') | ||||
|             return self.reply | ||||
|  | ||||
|         return Cond(func=ready, return_func=value) | ||||
|  | ||||
|  | ||||
| class Tell(Message): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class TimedOut(Exception): | ||||
|     pass | ||||
| @@ -47,7 +47,7 @@ class Simulation: | ||||
|     max_time: float = float("inf") | ||||
|     max_steps: int = -1 | ||||
|     interval: int = 1 | ||||
|     num_trials: int = 3 | ||||
|     num_trials: int = 1 | ||||
|     parallel: Optional[bool] = None | ||||
|     exporters: Optional[List[str]] = field(default_factory=list) | ||||
|     outdir: Optional[str] = None | ||||
|   | ||||
							
								
								
									
										20
									
								
								soil/time.py
									
									
									
									
									
								
							
							
						
						
									
										20
									
								
								soil/time.py
									
									
									
									
									
								
							| @@ -45,12 +45,16 @@ class When: | ||||
|     def ready(self, agent): | ||||
|         return self._time <= agent.model.schedule.time | ||||
|  | ||||
|     def return_value(self, agent): | ||||
|         return None | ||||
|  | ||||
|  | ||||
| class Cond(When): | ||||
|     def __init__(self, func, delta=1): | ||||
|     def __init__(self, func, delta=1, return_func=lambda agent: None): | ||||
|         self._func = func | ||||
|         self._delta = delta | ||||
|         self._checked = False | ||||
|         self._return_func = return_func | ||||
|  | ||||
|     def next(self, time): | ||||
|         if self._checked: | ||||
| @@ -64,6 +68,9 @@ class Cond(When): | ||||
|         self._checked = True | ||||
|         return self._func(agent) | ||||
|  | ||||
|     def return_value(self, agent): | ||||
|         return self._return_func(agent) | ||||
|  | ||||
|     def __eq__(self, other): | ||||
|         return False | ||||
|  | ||||
| @@ -144,14 +151,21 @@ class TimedActivation(BaseScheduler): | ||||
|  | ||||
|         ix = 0 | ||||
|  | ||||
|         self.logger.debug(f"Queue length: {len(self._queue)}") | ||||
|  | ||||
|         while self._queue: | ||||
|             (when, agent) = self._queue[0] | ||||
|             if when > self.time: | ||||
|                 break | ||||
|             heappop(self._queue) | ||||
|             if when.ready(agent): | ||||
|                 to_process.append(agent) | ||||
|                 try: | ||||
|                     agent._last_return = when.return_value(agent) | ||||
|                 except Exception as ex: | ||||
|                     agent._last_except = ex | ||||
|  | ||||
|                 self._next.pop(agent.unique_id, None) | ||||
|                 to_process.append(agent) | ||||
|                 continue | ||||
|  | ||||
|             next_time = min(next_time, when.next(self.time)) | ||||
| @@ -175,10 +189,10 @@ class TimedActivation(BaseScheduler): | ||||
|                 continue | ||||
|  | ||||
|             if not getattr(agent, "alive", True): | ||||
|                 self.remove(agent) | ||||
|                 continue | ||||
|  | ||||
|             value = returned.next(self.time) | ||||
|             agent._last_return = value | ||||
|  | ||||
|             if value < self.time: | ||||
|                 raise Exception( | ||||
|   | ||||
| @@ -33,18 +33,20 @@ class TestMain(TestCase): | ||||
|         The step function of an agent could be a generator. In that case, the state of the | ||||
|         agent will be resumed after every call to step. | ||||
|         ''' | ||||
|         a = 0 | ||||
|         class Gen(agents.BaseAgent): | ||||
|             def step(self): | ||||
|                 a = 0 | ||||
|                 nonlocal a | ||||
|                 for i in range(5): | ||||
|                     yield a | ||||
|                     yield | ||||
|                     a += 1 | ||||
|         e = environment.Environment() | ||||
|         g = Gen(model=e, unique_id=e.next_id()) | ||||
|         e.schedule.add(g) | ||||
|  | ||||
|         for i in range(5): | ||||
|             t = g.step() | ||||
|             assert t == i | ||||
|             e.step() | ||||
|             assert a == i | ||||
|  | ||||
|     def test_state_decorator(self): | ||||
|         class MyAgent(agents.FSM): | ||||
| @@ -53,6 +55,12 @@ class TestMain(TestCase): | ||||
|             @agents.state('original') | ||||
|             def root(self): | ||||
|                 self.run += 1 | ||||
|                 return self.other | ||||
|  | ||||
|             @agents.state | ||||
|             def other(self): | ||||
|                 self.run += 1 | ||||
|  | ||||
|         e = environment.Environment() | ||||
|         a = MyAgent(model=e, unique_id=e.next_id()) | ||||
|         a.step() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user