mirror of
https://github.com/gsi-upm/soil
synced 2024-11-25 04:12:29 +00:00
Add events
This commit is contained in:
parent
3776c4e5c5
commit
159c9a9077
141
examples/events_and_messages/cars.py
Normal file
141
examples/events_and_messages/cars.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from soil import *
|
||||||
|
from soil import events
|
||||||
|
from mesa.space import MultiGrid
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Journey:
|
||||||
|
origin: (int, int)
|
||||||
|
destination: (int, int)
|
||||||
|
tip: float
|
||||||
|
|
||||||
|
passenger: Passenger = None
|
||||||
|
driver: Driver = None
|
||||||
|
|
||||||
|
|
||||||
|
class City(EventedEnvironment):
|
||||||
|
def __init__(self, *args, n_cars=1, height=100, width=100, n_passengers=10, agents=None, **kwargs):
|
||||||
|
self.grid = MultiGrid(width=width, height=height, torus=False)
|
||||||
|
if agents is None:
|
||||||
|
agents = []
|
||||||
|
for i in range(n_cars):
|
||||||
|
agents.append({'agent_class': Driver})
|
||||||
|
for i in range(n_passengers):
|
||||||
|
agents.append({'agent_class': Passenger})
|
||||||
|
super().__init__(*args, agents=agents, **kwargs)
|
||||||
|
for agent in self.agents:
|
||||||
|
self.grid.place_agent(agent, (0, 0))
|
||||||
|
self.grid.move_to_empty(agent)
|
||||||
|
|
||||||
|
class Driver(Evented, FSM):
|
||||||
|
pos = None
|
||||||
|
journey = None
|
||||||
|
earnings = 0
|
||||||
|
|
||||||
|
def on_receive(self, msg, sender):
|
||||||
|
if self.journey is None and isinstance(msg, Journey) and msg.driver is None:
|
||||||
|
msg.driver = self
|
||||||
|
self.journey = msg
|
||||||
|
|
||||||
|
@default_state
|
||||||
|
@state
|
||||||
|
def wandering(self):
|
||||||
|
target = None
|
||||||
|
self.check_passengers()
|
||||||
|
self.journey = None
|
||||||
|
while self.journey is None:
|
||||||
|
if target is None or not self.move_towards(target):
|
||||||
|
target = self.random.choice(self.model.grid.get_neighborhood(self.pos, moore=False))
|
||||||
|
self.check_passengers()
|
||||||
|
self.check_messages() # This will call on_receive behind the scenes
|
||||||
|
yield Delta(30)
|
||||||
|
try:
|
||||||
|
self.journey = yield self.journey.passenger.ask(self.journey, timeout=60)
|
||||||
|
except events.TimedOut:
|
||||||
|
self.journey = None
|
||||||
|
return
|
||||||
|
return self.driving
|
||||||
|
|
||||||
|
def check_passengers(self):
|
||||||
|
c = self.count_agents(agent_class=Passenger)
|
||||||
|
self.info(f"Passengers left {c}")
|
||||||
|
if not c:
|
||||||
|
self.die()
|
||||||
|
|
||||||
|
@state
|
||||||
|
def driving(self):
|
||||||
|
#Approaching
|
||||||
|
while self.move_towards(self.journey.origin):
|
||||||
|
yield
|
||||||
|
while self.move_towards(self.journey.destination, with_passenger=True):
|
||||||
|
yield
|
||||||
|
self.check_passengers()
|
||||||
|
return self.wandering
|
||||||
|
|
||||||
|
def move_towards(self, target, with_passenger=False):
|
||||||
|
'''Move one cell at a time towards a target'''
|
||||||
|
self.info(f"Moving { self.pos } -> { target }")
|
||||||
|
if target[0] == self.pos[0] and target[1] == self.pos[1]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
next_pos = [self.pos[0], self.pos[1]]
|
||||||
|
for idx in [0, 1]:
|
||||||
|
if self.pos[idx] < target[idx]:
|
||||||
|
next_pos[idx] += 1
|
||||||
|
break
|
||||||
|
if self.pos[idx] > target[idx]:
|
||||||
|
next_pos[idx] -= 1
|
||||||
|
break
|
||||||
|
self.model.grid.move_agent(self, tuple(next_pos))
|
||||||
|
if with_passenger:
|
||||||
|
self.journey.passenger.pos = self.pos # This could be communicated through messages
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class Passenger(Evented, FSM):
|
||||||
|
pos = None
|
||||||
|
|
||||||
|
@default_state
|
||||||
|
@state
|
||||||
|
def asking(self):
|
||||||
|
destination = (self.random.randint(0, self.model.grid.height), self.random.randint(0, self.model.grid.width))
|
||||||
|
self.journey = None
|
||||||
|
journey = Journey(origin=self.pos,
|
||||||
|
destination=destination,
|
||||||
|
tip=self.random.randint(10, 100),
|
||||||
|
passenger=self)
|
||||||
|
|
||||||
|
timeout = 60
|
||||||
|
expiration = self.now + timeout
|
||||||
|
self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver)
|
||||||
|
while not self.journey:
|
||||||
|
self.info(f"Passenger at: { self.pos }. Checking for responses.")
|
||||||
|
try:
|
||||||
|
yield self.received(expiration=expiration)
|
||||||
|
except events.TimedOut:
|
||||||
|
self.info(f"Passenger at: { self.pos }. Asking for journey.")
|
||||||
|
self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver)
|
||||||
|
expiration = self.now + timeout
|
||||||
|
self.check_messages()
|
||||||
|
return self.driving_home
|
||||||
|
|
||||||
|
def on_receive(self, msg, sender):
|
||||||
|
if isinstance(msg, Journey):
|
||||||
|
self.journey = msg
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@state
|
||||||
|
def driving_home(self):
|
||||||
|
while self.pos[0] != self.journey.destination[0] or self.pos[1] != self.journey.destination[1]:
|
||||||
|
yield self.received(timeout=60)
|
||||||
|
self.info("Got home safe!")
|
||||||
|
self.die()
|
||||||
|
|
||||||
|
|
||||||
|
simulation = Simulation(model_class=City, model_params={'n_passengers': 2})
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with easy(simulation) as s:
|
||||||
|
s.run()
|
@ -17,7 +17,7 @@ except NameError:
|
|||||||
from .agents import *
|
from .agents import *
|
||||||
from . import agents
|
from . import agents
|
||||||
from .simulation import *
|
from .simulation import *
|
||||||
from .environment import Environment
|
from .environment import Environment, EventedEnvironment
|
||||||
from . import serialization
|
from . import serialization
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
from .time import *
|
from .time import *
|
||||||
@ -34,6 +34,9 @@ def main(
|
|||||||
pdb=False,
|
pdb=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
if isinstance(cfg, Simulation):
|
||||||
|
sim = cfg
|
||||||
import argparse
|
import argparse
|
||||||
from . import simulation
|
from . import simulation
|
||||||
|
|
||||||
@ -44,7 +47,7 @@ def main(
|
|||||||
"file",
|
"file",
|
||||||
type=str,
|
type=str,
|
||||||
nargs="?",
|
nargs="?",
|
||||||
default=cfg,
|
default=cfg if sim is None else '',
|
||||||
help="Configuration file for the simulation (e.g., YAML or JSON)",
|
help="Configuration file for the simulation (e.g., YAML or JSON)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -150,7 +153,7 @@ def main(
|
|||||||
if output is None:
|
if output is None:
|
||||||
output = args.output
|
output = args.output
|
||||||
|
|
||||||
logger.info("Loading config file: {}".format(args.file))
|
|
||||||
|
|
||||||
debug = debug or args.debug
|
debug = debug or args.debug
|
||||||
|
|
||||||
@ -162,11 +165,16 @@ def main(
|
|||||||
try:
|
try:
|
||||||
exp_params = {}
|
exp_params = {}
|
||||||
|
|
||||||
|
if sim:
|
||||||
|
logger.info("Loading simulation instance")
|
||||||
|
sims = [sim, ]
|
||||||
|
else:
|
||||||
|
logger.info("Loading config file: {}".format(args.file))
|
||||||
if not os.path.exists(args.file):
|
if not os.path.exists(args.file):
|
||||||
logger.error("Please, input a valid file")
|
logger.error("Please, input a valid file")
|
||||||
return
|
return
|
||||||
|
|
||||||
for sim in simulation.iter_from_config(
|
sims = list(simulation.iter_from_config(
|
||||||
args.file,
|
args.file,
|
||||||
dry_run=args.dry_run,
|
dry_run=args.dry_run,
|
||||||
exporters=exporters,
|
exporters=exporters,
|
||||||
@ -174,7 +182,10 @@ def main(
|
|||||||
outdir=output,
|
outdir=output,
|
||||||
exporter_params=exp_params,
|
exporter_params=exp_params,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
))
|
||||||
|
|
||||||
|
for sim in sims:
|
||||||
|
|
||||||
if args.set:
|
if args.set:
|
||||||
for s in args.set:
|
for s in args.set:
|
||||||
k, v = s.split("=", 1)[:2]
|
k, v = s.split("=", 1)[:2]
|
||||||
@ -219,7 +230,6 @@ def main(
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def easy(cfg, pdb=False, debug=False, **kwargs):
|
def easy(cfg, pdb=False, debug=False, **kwargs):
|
||||||
ex = None
|
|
||||||
try:
|
try:
|
||||||
yield main(cfg, **kwargs)[0]
|
yield main(cfg, **kwargs)[0]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -228,10 +238,7 @@ def easy(cfg, pdb=False, debug=False, **kwargs):
|
|||||||
|
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
post_mortem()
|
post_mortem()
|
||||||
ex = e
|
raise
|
||||||
finally:
|
|
||||||
if ex:
|
|
||||||
raise ex
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -40,23 +40,31 @@ class MetaAgent(ABCMeta):
|
|||||||
|
|
||||||
new_nmspc = {
|
new_nmspc = {
|
||||||
"_defaults": defaults,
|
"_defaults": defaults,
|
||||||
|
"_last_return": None,
|
||||||
|
"_last_except": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
for attr, func in namespace.items():
|
for attr, func in namespace.items():
|
||||||
if attr == "step" and inspect.isgeneratorfunction(func):
|
if attr == "step" and inspect.isgeneratorfunction(func):
|
||||||
orig_func = func
|
orig_func = func
|
||||||
new_nmspc["_MetaAgent__coroutine"] = None
|
new_nmspc["_coroutine"] = None
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def func(self):
|
def func(self):
|
||||||
while True:
|
while True:
|
||||||
if not self.__coroutine:
|
if not self._coroutine:
|
||||||
self.__coroutine = orig_func(self)
|
self._coroutine = orig_func(self)
|
||||||
try:
|
try:
|
||||||
return next(self.__coroutine)
|
if self._last_except:
|
||||||
|
return self._coroutine.throw(self._last_except)
|
||||||
|
else:
|
||||||
|
return self._coroutine.send(self._last_return)
|
||||||
except StopIteration as ex:
|
except StopIteration as ex:
|
||||||
self.__coroutine = None
|
self._coroutine = None
|
||||||
return ex.value
|
return ex.value
|
||||||
|
finally:
|
||||||
|
self._last_return = None
|
||||||
|
self._last_except = None
|
||||||
|
|
||||||
func.id = name or func.__name__
|
func.id = name or func.__name__
|
||||||
func.is_default = False
|
func.is_default = False
|
||||||
@ -190,6 +198,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
def die(self):
|
def die(self):
|
||||||
self.info(f"agent dying")
|
self.info(f"agent dying")
|
||||||
self.alive = False
|
self.alive = False
|
||||||
|
try:
|
||||||
|
self.model.schedule.remove(self)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
return time.NEVER
|
return time.NEVER
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
@ -617,6 +629,7 @@ def _from_distro(
|
|||||||
|
|
||||||
from .network_agents import *
|
from .network_agents import *
|
||||||
from .fsm import *
|
from .fsm import *
|
||||||
|
from .evented import *
|
||||||
from .BassModel import *
|
from .BassModel import *
|
||||||
from .BigMarketModel import *
|
from .BigMarketModel import *
|
||||||
from .IndependentCascadeModel import *
|
from .IndependentCascadeModel import *
|
||||||
|
57
soil/agents/evented.py
Normal file
57
soil/agents/evented.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from . import BaseAgent
|
||||||
|
from ..events import Message, Tell, Ask, Reply, TimedOut
|
||||||
|
from ..time import Cond
|
||||||
|
from functools import partial
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class Evented(BaseAgent):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._inbox = deque()
|
||||||
|
self._received = 0
|
||||||
|
self._processed = 0
|
||||||
|
|
||||||
|
|
||||||
|
def on_receive(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def received(self, expiration=None, timeout=None):
|
||||||
|
current = self._received
|
||||||
|
if expiration is None:
|
||||||
|
expiration = float('inf') if timeout is None else self.now + timeout
|
||||||
|
|
||||||
|
if expiration < self.now:
|
||||||
|
raise ValueError("Invalid expiration time")
|
||||||
|
|
||||||
|
def ready(agent):
|
||||||
|
return agent._received > current or agent.now >= expiration
|
||||||
|
|
||||||
|
def value(agent):
|
||||||
|
if agent.now > expiration:
|
||||||
|
raise TimedOut("No message received")
|
||||||
|
|
||||||
|
c = Cond(func=ready, return_func=value)
|
||||||
|
c._checked = True
|
||||||
|
return c
|
||||||
|
|
||||||
|
def tell(self, msg, sender):
|
||||||
|
self._received += 1
|
||||||
|
self._inbox.append(Tell(payload=msg, sender=sender))
|
||||||
|
|
||||||
|
def ask(self, msg, timeout=None):
|
||||||
|
self._received += 1
|
||||||
|
ask = Ask(payload=msg)
|
||||||
|
self._inbox.append(ask)
|
||||||
|
expiration = float('inf') if timeout is None else self.now + timeout
|
||||||
|
return ask.replied(expiration=expiration)
|
||||||
|
|
||||||
|
def check_messages(self):
|
||||||
|
while self._inbox:
|
||||||
|
msg = self._inbox.popleft()
|
||||||
|
self._processed += 1
|
||||||
|
if msg.expired(self.now):
|
||||||
|
continue
|
||||||
|
reply = self.on_receive(msg.payload, sender=msg.sender)
|
||||||
|
if isinstance(msg, Ask):
|
||||||
|
msg.reply = reply
|
@ -1,6 +1,6 @@
|
|||||||
from . import MetaAgent, BaseAgent
|
from . import MetaAgent, BaseAgent
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial, wraps
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
@ -19,17 +19,26 @@ def state(name=None):
|
|||||||
while True:
|
while True:
|
||||||
if not self._coroutine:
|
if not self._coroutine:
|
||||||
self._coroutine = orig_func(self)
|
self._coroutine = orig_func(self)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
n = next(self._coroutine)
|
if self._last_except:
|
||||||
|
n = self._coroutine.throw(self._last_except)
|
||||||
|
else:
|
||||||
|
n = self._coroutine.send(self._last_return)
|
||||||
if n:
|
if n:
|
||||||
return None, n
|
return None, n
|
||||||
return
|
return n
|
||||||
except StopIteration as ex:
|
except StopIteration as ex:
|
||||||
self._coroutine = None
|
self._coroutine = None
|
||||||
next_state = ex.value
|
next_state = ex.value
|
||||||
if next_state is not None:
|
if next_state is not None:
|
||||||
self._set_state(next_state)
|
self._set_state(next_state)
|
||||||
return next_state
|
return next_state
|
||||||
|
finally:
|
||||||
|
self._last_return = None
|
||||||
|
self._last_except = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
func.id = name or func.__name__
|
func.id = name or func.__name__
|
||||||
func.is_default = False
|
func.is_default = False
|
||||||
|
@ -30,9 +30,9 @@ def wrapcmd(func):
|
|||||||
class Debug(pdb.Pdb):
|
class Debug(pdb.Pdb):
|
||||||
def __init__(self, *args, skip_soil=False, **kwargs):
|
def __init__(self, *args, skip_soil=False, **kwargs):
|
||||||
skip = kwargs.get("skip", [])
|
skip = kwargs.get("skip", [])
|
||||||
|
if skip_soil:
|
||||||
skip.append("soil")
|
skip.append("soil")
|
||||||
skip.append("contextlib")
|
skip.append("contextlib")
|
||||||
if skip_soil:
|
|
||||||
skip.append("soil.*")
|
skip.append("soil.*")
|
||||||
skip.append("mesa.*")
|
skip.append("mesa.*")
|
||||||
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
||||||
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import math
|
import math
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
@ -19,7 +18,7 @@ import networkx as nx
|
|||||||
from mesa import Model
|
from mesa import Model
|
||||||
from mesa.datacollection import DataCollector
|
from mesa.datacollection import DataCollector
|
||||||
|
|
||||||
from . import agents as agentmod, config, serialization, utils, time, network
|
from . import agents as agentmod, config, serialization, utils, time, network, events
|
||||||
|
|
||||||
|
|
||||||
class BaseEnvironment(Model):
|
class BaseEnvironment(Model):
|
||||||
@ -294,10 +293,6 @@ class NetworkEnvironment(BaseEnvironment):
|
|||||||
def add_agent(self, *args, **kwargs):
|
def add_agent(self, *args, **kwargs):
|
||||||
a = super().add_agent(*args, **kwargs)
|
a = super().add_agent(*args, **kwargs)
|
||||||
if "node_id" in a:
|
if "node_id" in a:
|
||||||
if a.node_id == 24:
|
|
||||||
import pdb
|
|
||||||
|
|
||||||
pdb.set_trace()
|
|
||||||
assert self.G.nodes[a.node_id]["agent"] == a
|
assert self.G.nodes[a.node_id]["agent"] == a
|
||||||
return a
|
return a
|
||||||
|
|
||||||
@ -316,3 +311,14 @@ class NetworkEnvironment(BaseEnvironment):
|
|||||||
|
|
||||||
|
|
||||||
Environment = NetworkEnvironment
|
Environment = NetworkEnvironment
|
||||||
|
|
||||||
|
|
||||||
|
class EventedEnvironment(Environment):
|
||||||
|
def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs):
|
||||||
|
for agent in self.agents(**kwargs):
|
||||||
|
self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}')
|
||||||
|
try:
|
||||||
|
agent._inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl))
|
||||||
|
except AttributeError:
|
||||||
|
self.info(f'Agent {agent.unique_id} cannot receive events')
|
||||||
|
|
||||||
|
43
soil/events.py
Normal file
43
soil/events.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from .time import Cond
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
class Event:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message:
|
||||||
|
payload: Any
|
||||||
|
sender: Any = None
|
||||||
|
expiration: float = None
|
||||||
|
id: int = field(default_factory=uuid4)
|
||||||
|
|
||||||
|
def expired(self, when):
|
||||||
|
return self.expiration is not None and self.expiration < when
|
||||||
|
|
||||||
|
class Reply(Message):
|
||||||
|
source: Message
|
||||||
|
|
||||||
|
|
||||||
|
class Ask(Message):
|
||||||
|
reply: Message = None
|
||||||
|
|
||||||
|
def replied(self, expiration=None):
|
||||||
|
def ready(agent):
|
||||||
|
return self.reply is not None or agent.now > expiration
|
||||||
|
|
||||||
|
def value(agent):
|
||||||
|
if agent.now > expiration:
|
||||||
|
raise TimedOut(f'No answer received for {self}')
|
||||||
|
return self.reply
|
||||||
|
|
||||||
|
return Cond(func=ready, return_func=value)
|
||||||
|
|
||||||
|
|
||||||
|
class Tell(Message):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TimedOut(Exception):
|
||||||
|
pass
|
@ -47,7 +47,7 @@ class Simulation:
|
|||||||
max_time: float = float("inf")
|
max_time: float = float("inf")
|
||||||
max_steps: int = -1
|
max_steps: int = -1
|
||||||
interval: int = 1
|
interval: int = 1
|
||||||
num_trials: int = 3
|
num_trials: int = 1
|
||||||
parallel: Optional[bool] = None
|
parallel: Optional[bool] = None
|
||||||
exporters: Optional[List[str]] = field(default_factory=list)
|
exporters: Optional[List[str]] = field(default_factory=list)
|
||||||
outdir: Optional[str] = None
|
outdir: Optional[str] = None
|
||||||
|
20
soil/time.py
20
soil/time.py
@ -45,12 +45,16 @@ class When:
|
|||||||
def ready(self, agent):
|
def ready(self, agent):
|
||||||
return self._time <= agent.model.schedule.time
|
return self._time <= agent.model.schedule.time
|
||||||
|
|
||||||
|
def return_value(self, agent):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Cond(When):
|
class Cond(When):
|
||||||
def __init__(self, func, delta=1):
|
def __init__(self, func, delta=1, return_func=lambda agent: None):
|
||||||
self._func = func
|
self._func = func
|
||||||
self._delta = delta
|
self._delta = delta
|
||||||
self._checked = False
|
self._checked = False
|
||||||
|
self._return_func = return_func
|
||||||
|
|
||||||
def next(self, time):
|
def next(self, time):
|
||||||
if self._checked:
|
if self._checked:
|
||||||
@ -64,6 +68,9 @@ class Cond(When):
|
|||||||
self._checked = True
|
self._checked = True
|
||||||
return self._func(agent)
|
return self._func(agent)
|
||||||
|
|
||||||
|
def return_value(self, agent):
|
||||||
|
return self._return_func(agent)
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -144,14 +151,21 @@ class TimedActivation(BaseScheduler):
|
|||||||
|
|
||||||
ix = 0
|
ix = 0
|
||||||
|
|
||||||
|
self.logger.debug(f"Queue length: {len(self._queue)}")
|
||||||
|
|
||||||
while self._queue:
|
while self._queue:
|
||||||
(when, agent) = self._queue[0]
|
(when, agent) = self._queue[0]
|
||||||
if when > self.time:
|
if when > self.time:
|
||||||
break
|
break
|
||||||
heappop(self._queue)
|
heappop(self._queue)
|
||||||
if when.ready(agent):
|
if when.ready(agent):
|
||||||
to_process.append(agent)
|
try:
|
||||||
|
agent._last_return = when.return_value(agent)
|
||||||
|
except Exception as ex:
|
||||||
|
agent._last_except = ex
|
||||||
|
|
||||||
self._next.pop(agent.unique_id, None)
|
self._next.pop(agent.unique_id, None)
|
||||||
|
to_process.append(agent)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
next_time = min(next_time, when.next(self.time))
|
next_time = min(next_time, when.next(self.time))
|
||||||
@ -175,10 +189,10 @@ class TimedActivation(BaseScheduler):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if not getattr(agent, "alive", True):
|
if not getattr(agent, "alive", True):
|
||||||
self.remove(agent)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
value = returned.next(self.time)
|
value = returned.next(self.time)
|
||||||
|
agent._last_return = value
|
||||||
|
|
||||||
if value < self.time:
|
if value < self.time:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@ -33,18 +33,20 @@ class TestMain(TestCase):
|
|||||||
The step function of an agent could be a generator. In that case, the state of the
|
The step function of an agent could be a generator. In that case, the state of the
|
||||||
agent will be resumed after every call to step.
|
agent will be resumed after every call to step.
|
||||||
'''
|
'''
|
||||||
|
a = 0
|
||||||
class Gen(agents.BaseAgent):
|
class Gen(agents.BaseAgent):
|
||||||
def step(self):
|
def step(self):
|
||||||
a = 0
|
nonlocal a
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
yield a
|
yield
|
||||||
a += 1
|
a += 1
|
||||||
e = environment.Environment()
|
e = environment.Environment()
|
||||||
g = Gen(model=e, unique_id=e.next_id())
|
g = Gen(model=e, unique_id=e.next_id())
|
||||||
|
e.schedule.add(g)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
t = g.step()
|
e.step()
|
||||||
assert t == i
|
assert a == i
|
||||||
|
|
||||||
def test_state_decorator(self):
|
def test_state_decorator(self):
|
||||||
class MyAgent(agents.FSM):
|
class MyAgent(agents.FSM):
|
||||||
@ -53,6 +55,12 @@ class TestMain(TestCase):
|
|||||||
@agents.state('original')
|
@agents.state('original')
|
||||||
def root(self):
|
def root(self):
|
||||||
self.run += 1
|
self.run += 1
|
||||||
|
return self.other
|
||||||
|
|
||||||
|
@agents.state
|
||||||
|
def other(self):
|
||||||
|
self.run += 1
|
||||||
|
|
||||||
e = environment.Environment()
|
e = environment.Environment()
|
||||||
a = MyAgent(model=e, unique_id=e.next_id())
|
a = MyAgent(model=e, unique_id=e.next_id())
|
||||||
a.step()
|
a.step()
|
||||||
|
Loading…
Reference in New Issue
Block a user