From 227fdf050e307de3032aa471d79a2459d273d392 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Mon, 17 Oct 2022 19:29:39 +0200 Subject: [PATCH] Fix conditionals --- examples/pubcrawl/pubcrawl.py | 7 +- examples/rabbits/README.md | 10 ++ examples/rabbits/basic/rabbit_agents.py | 68 ++++---- examples/rabbits/improved/rabbit_agents.py | 179 ++++++++++++--------- examples/rabbits/improved/rabbits.yml | 8 +- soil/__init__.py | 20 ++- soil/__main__.py | 4 +- soil/agents/__init__.py | 103 ++---------- soil/debugging.py | 4 +- soil/environment.py | 26 ++- soil/network.py | 6 +- soil/simulation.py | 2 +- soil/time.py | 40 +++-- tests/test_agents.py | 2 +- tests/test_time.py | 74 +++++++++ 15 files changed, 321 insertions(+), 232 deletions(-) create mode 100644 tests/test_time.py diff --git a/examples/pubcrawl/pubcrawl.py b/examples/pubcrawl/pubcrawl.py index b220856..9fd1b04 100644 --- a/examples/pubcrawl/pubcrawl.py +++ b/examples/pubcrawl/pubcrawl.py @@ -64,6 +64,7 @@ class Patron(FSM, NetworkAgent): drunk = False pints = 0 max_pints = 3 + kicked_out = False @default_state @state @@ -105,7 +106,9 @@ class Patron(FSM, NetworkAgent): '''I'm out. Take me home!''' self.info('I\'m so drunk. Take me home!') self['drunk'] = True - pass # out drunk + if self.kicked_out: + return self.at_home + pass # out drun @state def at_home(self): @@ -118,7 +121,7 @@ class Patron(FSM, NetworkAgent): self.debug('Cheers to that') def kick_out(self): - self.set_state(self.at_home) + self.kicked_out = True def befriend(self, other_agent, force=False): ''' diff --git a/examples/rabbits/README.md b/examples/rabbits/README.md index 42b6011..dfee8ef 100644 --- a/examples/rabbits/README.md +++ b/examples/rabbits/README.md @@ -2,3 +2,13 @@ There are two similar implementations of this simulation. - `basic`. Using simple primites - `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions. + +The examples can be run directly in the terminal, and they accept command like arguments. +For example, to enable the CSV exporter and the Summary exporter, while setting `max_time` to `100` and `seed` to `CustomSeed`: + +``` +python rabbit_agents.py --set max_time=100 --csv -e summary --set 'seed="CustomSeed"' +``` + +To learn more about how this functionality works, check out the `soil.easy` function. + diff --git a/examples/rabbits/basic/rabbit_agents.py b/examples/rabbits/basic/rabbit_agents.py index 284c08a..60a0d15 100644 --- a/examples/rabbits/basic/rabbit_agents.py +++ b/examples/rabbits/basic/rabbit_agents.py @@ -1,6 +1,4 @@ from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment -from soil.time import Delta -from enum import Enum from collections import Counter import logging import math @@ -21,7 +19,7 @@ class RabbitEnv(Environment): return self.count_agents(agent_class=Female) -class Rabbit(FSM, NetworkAgent): +class Rabbit(NetworkAgent, FSM): sexual_maturity = 30 life_expectancy = 300 @@ -72,7 +70,8 @@ class Male(Rabbit): class Female(Rabbit): - gestation = 30 + gestation = 10 + pregnancy = -1 @state def fertile(self): @@ -80,46 +79,49 @@ class Female(Rabbit): self.age += 1 if self.age > self.life_expectancy: return self.dead + if self.pregnancy >= 0: + return self.pregnant def impregnate(self, male): - self.info(f'{repr(male)} impregnating female {repr(self)}') + self.info(f'impregnated by {repr(male)}') self.mate = male - self.pregnancy = -1 - self.set_state(self.pregnant, when=self.now) + self.pregnancy = 0 self.number_of_babies = int(8+4*self.random.random()) @state def pregnant(self): - self.debug('I am pregnant') + self.info('I am pregnant') self.age += 1 - self.pregnancy += 1 - if self.prob(self.age / self.life_expectancy): + if self.age >= self.life_expectancy: return self.die() - if self.pregnancy >= self.gestation: - self.info('Having {} babies'.format(self.number_of_babies)) - for i in range(self.number_of_babies): - state = {} - agent_class = self.random.choice([Male, Female]) - child = self.model.add_node(agent_class=agent_class, - **state) - child.add_edge(self) - try: - child.add_edge(self.mate) - self.model.agents[self.mate].offspring += 1 - except ValueError: - self.debug('The father has passed away') - - self.offspring += 1 - self.mate = None - return self.fertile + if self.pregnancy < self.gestation: + self.pregnancy += 1 + return + + self.info('Having {} babies'.format(self.number_of_babies)) + for i in range(self.number_of_babies): + state = {} + agent_class = self.random.choice([Male, Female]) + child = self.model.add_node(agent_class=agent_class, + **state) + child.add_edge(self) + try: + child.add_edge(self.mate) + self.model.agents[self.mate].offspring += 1 + except ValueError: + self.debug('The father has passed away') + + self.offspring += 1 + self.mate = None + self.pregnancy = -1 + return self.fertile - @state - def dead(self): - super().dead() + def die(self): if 'pregnancy' in self and self['pregnancy'] > -1: self.info('A mother has died carrying a baby!!') + return super().die() class RandomAccident(BaseAgent): @@ -138,11 +140,11 @@ class RandomAccident(BaseAgent): if self.prob(prob_death): self.info('I killed a rabbit: {}'.format(i.id)) rabbits_alive -= 1 - i.set_state(i.dead) + i.die() self.debug('Rabbits alive: {}'.format(rabbits_alive)) if __name__ == '__main__': from soil import easy - sim = easy('rabbits.yml') - sim.run() + with easy('rabbits.yml') as sim: + sim.run() diff --git a/examples/rabbits/improved/rabbit_agents.py b/examples/rabbits/improved/rabbit_agents.py index d97b7e7..c7d995d 100644 --- a/examples/rabbits/improved/rabbit_agents.py +++ b/examples/rabbits/improved/rabbit_agents.py @@ -1,130 +1,157 @@ -from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent -from soil.time import Delta, When, NEVER +from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment +from soil.time import Delta from enum import Enum +from collections import Counter import logging import math -class RabbitModel(FSM, NetworkAgent): +class RabbitEnv(Environment): - mating_prob = 0.005 - offspring = 0 - birth = None + @property + def num_rabbits(self): + return self.count_agents(agent_class=Rabbit) - sexual_maturity = 3 - life_expectancy = 30 + @property + def num_males(self): + return self.count_agents(agent_class=Male) + + @property + def num_females(self): + return self.count_agents(agent_class=Female) - @default_state - @state - def newborn(self): - self.birth = self.now - self.info(f'I am a newborn.') - self.model['rabbits_alive'] = self.model.get('rabbits_alive', 0) + 1 - # Here we can skip the `youngling` state by using a coroutine/generator. - while self.age < self.sexual_maturity: - interval = self.sexual_maturity - self.age - yield Delta(interval) +class Rabbit(FSM, NetworkAgent): - self.info(f'I am fertile! My age is {self.age}') - return self.fertile + sexual_maturity = 30 + life_expectancy = 300 + birth = None @property def age(self): + if self.birth is None: + return None return self.now - self.birth + @default_state + @state + def newborn(self): + self.info('I am a newborn.') + self.birth = self.now + self.offspring = 0 + return self.youngling, Delta(self.sexual_maturity - self.age) + + @state + def youngling(self): + if self.age >= self.sexual_maturity: + self.info(f'I am fertile! My age is {self.age}') + return self.fertile + @state def fertile(self): raise Exception("Each subclass should define its fertile state") - def step(self): - super().step() - if self.prob(self.age / self.life_expectancy): - return self.die() - + @state + def dead(self): + self.die() -class Male(RabbitModel): +class Male(Rabbit): max_females = 5 + mating_prob = 0.001 @state def fertile(self): + if self.age > self.life_expectancy: + return self.dead + # Males try to mate for f in self.model.agents(agent_class=Female, state_id=Female.fertile.id, limit=self.max_females): - self.debug('Found a female:', repr(f)) + self.debug('FOUND A FEMALE: ', repr(f), self.mating_prob) if self.prob(self['mating_prob']): f.impregnate(self) - break # Take a break, don't try to impregnate the rest + break # Do not try to impregnate other females + - -class Female(RabbitModel): - due_date = None - age_of_pregnancy = None +class Female(Rabbit): gestation = 10 - mate = None + conception = None @state def fertile(self): - return self.fertile, NEVER - - @state - def pregnant(self): - self.info('I am pregnant') + # Just wait for a Male if self.age > self.life_expectancy: return self.dead + if self.conception is not None: + return self.pregnant - self.due_date = self.now + self.gestation + @property + def pregnancy(self): + if self.conception is None: + return None + return self.now - self.conception - number_of_babies = int(8+4*self.random.random()) + def impregnate(self, male): + self.info(f'impregnated by {repr(male)}') + self.mate = male + self.conception = self.now + self.number_of_babies = int(8+4*self.random.random()) - while self.now < self.due_date: - yield When(self.due_date) + @state + def pregnant(self): + self.debug('I am pregnant') - self.info('Having {} babies'.format(number_of_babies)) - for i in range(number_of_babies): - agent_class = self.random.choice([Male, Female]) - child = self.model.add_node(agent_class=agent_class, - topology=self.topology) - self.model.add_edge(self, child) - self.model.add_edge(self.mate, child) - self.offspring += 1 - self.model.agents[self.mate].offspring += 1 - self.mate = None - self.due_date = None - return self.fertile + if self.age > self.life_expectancy: + self.info("Dying before giving birth") + return self.die() - @state - def dead(self): - super().dead() - if self.due_date is not None: + if self.pregnancy >= self.gestation: + self.info('Having {} babies'.format(self.number_of_babies)) + for i in range(self.number_of_babies): + state = {} + agent_class = self.random.choice([Male, Female]) + child = self.model.add_node(agent_class=agent_class, + **state) + child.add_edge(self) + if self.mate: + child.add_edge(self.mate) + self.mate.offspring += 1 + else: + self.debug('The father has passed away') + + self.offspring += 1 + self.mate = None + return self.fertile + + def die(self): + if self.pregnancy is not None: self.info('A mother has died carrying a baby!!') - - def impregnate(self, male): - self.info(f'{repr(male)} impregnating female {repr(self)}') - self.mate = male - self.set_state(self.pregnant, when=self.now) + return super().die() class RandomAccident(BaseAgent): - level = logging.INFO - def step(self): - rabbits_total = self.model.topology.number_of_nodes() - if 'rabbits_alive' not in self.model: - self.model['rabbits_alive'] = 0 - rabbits_alive = self.model.get('rabbits_alive', rabbits_total) + rabbits_alive = self.model.G.number_of_nodes() + + if not rabbits_alive: + return self.die() + prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) self.debug('Killing some rabbits with prob={}!'.format(prob_death)) - for i in self.model.network_agents: - if i.state.id == i.dead.id: + for i in self.iter_agents(agent_class=Rabbit): + if i.state_id == i.dead.id: continue if self.prob(prob_death): self.info('I killed a rabbit: {}'.format(i.id)) - rabbits_alive = self.model['rabbits_alive'] = rabbits_alive -1 - i.set_state(i.dead) - self.debug('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) - if self.model.count_agents(state_id=RabbitModel.dead.id) == self.model.topology.number_of_nodes(): - self.die() + rabbits_alive -= 1 + i.die() + self.debug('Rabbits alive: {}'.format(rabbits_alive)) + + +if __name__ == '__main__': + from soil import easy + with easy('rabbits.yml') as sim: + sim.run() diff --git a/examples/rabbits/improved/rabbits.yml b/examples/rabbits/improved/rabbits.yml index dd13c4e..204270c 100644 --- a/examples/rabbits/improved/rabbits.yml +++ b/examples/rabbits/improved/rabbits.yml @@ -7,11 +7,10 @@ description: null group: null interval: 1.0 max_time: 100 -model_class: soil.environment.Environment +model_class: rabbit_agents.RabbitEnv model_params: agents: topology: true - agent_class: rabbit_agents.RabbitModel distribution: - agent_class: rabbit_agents.Male weight: 1 @@ -34,5 +33,10 @@ model_params: nodes: - id: 1 - id: 0 + model_reporters: + num_males: 'num_males' + num_females: 'num_females' + num_rabbits: | + py:lambda env: env.num_males + env.num_females extra: visualization_params: {} diff --git a/soil/__init__.py b/soil/__init__.py index 46d56bd..a49462b 100644 --- a/soil/__init__.py +++ b/soil/__init__.py @@ -5,6 +5,7 @@ import sys import os import logging import traceback +from contextlib import contextmanager from .version import __version__ @@ -30,6 +31,7 @@ def main( *, do_run=False, debug=False, + pdb=False, **kwargs, ): import argparse @@ -154,6 +156,7 @@ def main( if args.pdb or debug: args.synchronous = True + os.environ["SOIL_POSTMORTEM"] = "true" res = [] try: @@ -214,9 +217,20 @@ def main( return res -def easy(cfg, debug=False, **kwargs): - return main(cfg, **kwargs)[0] - +@contextmanager +def easy(cfg, pdb=False, debug=False, **kwargs): + ex = None + try: + yield main(cfg, **kwargs)[0] + except Exception as e: + if os.environ.get("SOIL_POSTMORTEM"): + from .debugging import post_mortem + print(traceback.format_exc()) + post_mortem() + ex = e + finally: + if ex: + raise ex if __name__ == "__main__": main(do_run=True) diff --git a/soil/__main__.py b/soil/__main__.py index 0c76791..9ad5c4f 100644 --- a/soil/__main__.py +++ b/soil/__main__.py @@ -1,9 +1,7 @@ from . import main as init_main - def main(): init_main(do_run=True) - -if __name__ == "__main__": +if __name__ == '__main__': init_main(do_run=True) diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index 0ed5bf3..b316caa 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -29,10 +29,6 @@ def as_node(agent): IGNORED_FIELDS = ("model", "logger") -class DeadAgent(Exception): - pass - - class MetaAgent(ABCMeta): def __new__(mcls, name, bases, namespace): defaults = {} @@ -198,7 +194,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): def step(self): if not self.alive: - raise DeadAgent(self.unique_id) + raise time.DeadAgent(self.unique_id) return super().step() or time.Delta(self.interval) def log(self, message, *args, level=logging.INFO, **kwargs): @@ -264,6 +260,10 @@ class NetworkAgent(BaseAgent): return list(self.iter_agents(limit_neighbors=True, **kwargs)) def add_edge(self, other): + assert self.node_id + assert other.node_id + assert self.node_id in self.G.nodes + assert other.node_id in self.G.nodes self.topology.add_edge(self.node_id, other.node_id) @property @@ -303,7 +303,9 @@ class NetworkAgent(BaseAgent): return G def remove_node(self): + print(f'Removing node for {self.unique_id}: {self.node_id}') self.G.remove_node(self.node_id) + self.node_id = None def add_edge(self, other, edge_attr_dict=None, *edge_attrs): if self.node_id not in self.G.nodes(data=False): @@ -322,6 +324,8 @@ class NetworkAgent(BaseAgent): ) def die(self, remove=True): + if not self.alive: + return if remove: self.remove_node() return super().die() @@ -351,7 +355,7 @@ def state(name=None): self._coroutine = None next_state = ex.value if next_state is not None: - self.set_state(next_state) + self._set_state(next_state) return next_state func.id = name or func.__name__ @@ -401,8 +405,8 @@ class MetaFSM(MetaAgent): class FSM(BaseAgent, metaclass=MetaFSM): - def __init__(self, *args, **kwargs): - super(FSM, self).__init__(*args, **kwargs) + def __init__(self, **kwargs): + super(FSM, self).__init__(**kwargs) if not hasattr(self, "state_id"): if not self._default_state: raise ValueError( @@ -411,7 +415,7 @@ class FSM(BaseAgent, metaclass=MetaFSM): self.state_id = self._default_state.id self._coroutine = None - self.set_state(self.state_id) + self._set_state(self.state_id) def step(self): self.debug(f"Agent {self.unique_id} @ state {self.state_id}") @@ -434,11 +438,11 @@ class FSM(BaseAgent, metaclass=MetaFSM): pass if next_state is not None: - self.set_state(next_state) + self._set_state(next_state) return when or default_interval - def set_state(self, state, when=None): + def _set_state(self, state, when=None): if hasattr(state, "id"): state = state.id if state not in self._states: @@ -576,83 +580,6 @@ def _convert_agent_classs(ind, to_string=False, **kwargs): return deserialize_definition(ind, **kwargs) -# def _agent_from_definition(definition, random, value=-1, unique_id=None): -# """Used in the initialization of agents given an agent distribution.""" -# if value < 0: -# value = random.random() -# for d in sorted(definition, key=lambda x: x.get('threshold')): -# threshold = d.get('threshold', (-1, -1)) -# # Check if the definition matches by id (first) or by threshold -# if (unique_id is not None and unique_id in d.get('ids', [])) or \ -# (value >= threshold[0] and value < threshold[1]): -# state = {} -# if 'state' in d: -# state = deepcopy(d['state']) -# return d['agent_class'], state - -# raise Exception('Definition for value {} not found in: {}'.format(value, definition)) - - -# def _definition_to_dict(definition, random, size=None, default_state=None): -# state = default_state or {} -# agents = {} -# remaining = {} -# if size: -# for ix in range(size): -# remaining[ix] = copy(state) -# else: -# remaining = defaultdict(lambda x: copy(state)) - -# distro = sorted([item for item in definition if 'weight' in item]) - -# id = 0 - -# def init_agent(item, id=ix): -# while id in agents: -# id += 1 - -# agent = remaining[id] -# agent['state'].update(copy(item.get('state', {}))) -# agents[agent.unique_id] = agent -# del remaining[id] -# return agent - -# for item in definition: -# if 'ids' in item: -# ids = item['ids'] -# del item['ids'] -# for id in ids: -# agent = init_agent(item, id) - -# for item in definition: -# if 'number' in item: -# times = item['number'] -# del item['number'] -# for times in range(times): -# if size: -# ix = random.choice(remaining.keys()) -# agent = init_agent(item, id) -# else: -# agent = init_agent(item) -# if not size: -# return agents - -# if len(remaining) < 0: -# raise Exception('Invalid definition. Too many agents to add') - - -# total_weight = float(sum(s['weight'] for s in distro)) -# unit = size / total_weight - -# for item in distro: -# times = unit * item['weight'] -# del item['weight'] -# for times in range(times): -# ix = random.choice(remaining.keys()) -# agent = init_agent(item, id) -# return agents - - class AgentView(Mapping, Set): """A lazy-loaded list of agents.""" diff --git a/soil/debugging.py b/soil/debugging.py index 607996b..f5a43e7 100644 --- a/soil/debugging.py +++ b/soil/debugging.py @@ -31,8 +31,8 @@ class Debug(pdb.Pdb): def __init__(self, *args, skip_soil=False, **kwargs): skip = kwargs.get("skip", []) skip.append("soil") + skip.append("contextlib") if skip_soil: - skip.append("soil") skip.append("soil.*") skip.append("mesa.*") super(Debug, self).__init__(*args, skip=skip, **kwargs) @@ -181,7 +181,7 @@ def set_trace(frame=None, **kwargs): debugger.set_trace(frame) -def post_mortem(traceback=None): +def post_mortem(traceback=None, **kwargs): global debugger if debugger is None: debugger = Debug(**kwargs) diff --git a/soil/environment.py b/soil/environment.py index 238c494..fb4823f 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -142,12 +142,12 @@ class BaseEnvironment(Model): "The environment has not been scheduled, so it has no sense of time" ) - def add_agent(self, agent_class, unique_id=None, **kwargs): - a = None + def add_agent(self, unique_id=None, **kwargs): if unique_id is None: unique_id = self.next_id() - a = agent_class(model=self, unique_id=unique_id, **args) + kwargs['unique_id'] = unique_id + a = self._agent_from_dict(kwargs) self.schedule.add(a) return a @@ -236,6 +236,7 @@ class NetworkEnvironment(BaseEnvironment): node_id = agent.get("node_id", None) if node_id is None: node_id = network.find_unassigned(self.G, random=self.random) + self.G.nodes[node_id]['agent'] = None agent["node_id"] = node_id agent["unique_id"] = unique_id agent["topology"] = self.G @@ -269,18 +270,29 @@ class NetworkEnvironment(BaseEnvironment): node_id = network.find_unassigned( G=self.G, shuffle=True, random=self.random ) + if node_id is None: + node_id = f'node_for_{unique_id}' - if node_id in G.nodes: - self.G.nodes[node_id]["agent"] = None # Reserve - else: + if node_id not in self.G.nodes: self.G.add_node(node_id) + assert "agent" not in self.G.nodes[node_id] + self.G.nodes[node_id]["agent"] = None # Reserve + a = self.add_agent( - unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs + unique_id=unique_id, agent_class=agent_class, topology=self.G, node_id=node_id, **kwargs ) a["visible"] = True return a + def add_agent(self, *args, **kwargs): + a = super().add_agent(*args, **kwargs) + if 'node_id' in a: + if a.node_id == 24: + import pdb;pdb.set_trace() + assert self.G.nodes[a.node_id]['agent'] == a + return a + def agent_for_node_id(self, node_id): return self.G.nodes[node_id].get("agent") diff --git a/soil/network.py b/soil/network.py index 5c0b005..be7d96f 100644 --- a/soil/network.py +++ b/soil/network.py @@ -65,10 +65,8 @@ def find_unassigned(G, shuffle=False, random=random): random.shuffle(candidates) for next_id, data in candidates: if "agent" not in data: - node_id = next_id - break - - return node_id + return next_id + return None def dump_gexf(G, f): diff --git a/soil/simulation.py b/soil/simulation.py index 946023f..fc50ab8 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -226,7 +226,7 @@ Model stats: ) model.step() - if model.schedule.time < until: # Simulation ended (no more steps) before until (i.e., no changes expected) + if model.schedule.time < until: # Simulation ended (no more steps) before the expected time model.schedule.time = until return model diff --git a/soil/time.py b/soil/time.py index 661e35e..26c4259 100644 --- a/soil/time.py +++ b/soil/time.py @@ -13,6 +13,10 @@ from mesa import Agent as MesaAgent INFINITY = float("inf") +class DeadAgent(Exception): + pass + + class When: def __init__(self, time): if isinstance(time, When): @@ -38,23 +42,27 @@ class When: return self._time > other return self._time > other.next(self._time) - def ready(self, time): - return self._time <= time + def ready(self, agent): + return self._time <= agent.model.schedule.time class Cond(When): def __init__(self, func, delta=1): self._func = func self._delta = delta + self._checked = False def next(self, time): - return time + self._delta + if self._checked: + return time + self._delta + return time def abs(self, time): return self - def ready(self, time): - return self._func(time) + def ready(self, agent): + self._checked = True + return self._func(agent) def __eq__(self, other): return False @@ -109,10 +117,12 @@ class TimedActivation(BaseScheduler): elif not isinstance(when, When): when = When(when) if agent.unique_id in self._agents: - self._queue.remove((self._next[agent.unique_id], agent)) del self._agents[agent.unique_id] - heapify(self._queue) + if agent.unique_id in self._next: + self._queue.remove((self._next[agent.unique_id], agent)) + heapify(self._queue) + self._next[agent.unique_id] = when heappush(self._queue, (when, agent)) super().add(agent) @@ -139,8 +149,9 @@ class TimedActivation(BaseScheduler): if when > self.time: break heappop(self._queue) - if when.ready(self.time): + if when.ready(agent): to_process.append(agent) + self._next.pop(agent.unique_id, None) continue next_time = min(next_time, when.next(self.time)) @@ -155,13 +166,20 @@ class TimedActivation(BaseScheduler): for agent in to_process: self.logger.debug(f"Stepping agent {agent}") - returned = ((agent.step() or Delta(1))).abs(self.time) + try: + returned = ((agent.step() or Delta(1))).abs(self.time) + except DeadAgent: + if agent.unique_id in self._next: + del self._next[agent.unique_id] + agent.alive = False + continue + if not getattr(agent, "alive", True): self.remove(agent) continue - value = when.next(self.time) + value = returned.next(self.time) if value < self.time: raise Exception( @@ -172,6 +190,8 @@ class TimedActivation(BaseScheduler): self._next[agent.unique_id] = returned heappush(self._queue, (returned, agent)) + else: + assert not self._next[agent.unique_id] self.steps += 1 self.logger.debug(f"Updating time step: {self.time} -> {next_time}") diff --git a/tests/test_agents.py b/tests/test_agents.py index d3db80e..4006e9d 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -24,7 +24,7 @@ class TestMain(TestCase): '''A dead agent should raise an exception if it is stepped after death''' d = Dead(unique_id=0, model=environment.Environment()) d.step() - with pytest.raises(agents.DeadAgent): + with pytest.raises(stime.DeadAgent): d.step() diff --git a/tests/test_time.py b/tests/test_time.py new file mode 100644 index 0000000..db16609 --- /dev/null +++ b/tests/test_time.py @@ -0,0 +1,74 @@ +from unittest import TestCase + +from soil import time, agents, environment + +class TestMain(TestCase): + def test_cond(self): + ''' + A condition should match a When if the concition is True + ''' + + t = time.Cond(lambda t: True) + f = time.Cond(lambda t: False) + for i in range(10): + w = time.When(i) + assert w == t + assert w is not f + + def test_cond(self): + ''' + Comparing a Cond to a Delta should always return False + ''' + + c = time.Cond(lambda t: False) + d = time.Delta(1) + assert c is not d + + def test_cond_env(self): + ''' + ''' + + times_started = [] + times_awakened = [] + times = [] + done = 0 + + class CondAgent(agents.BaseAgent): + + def step(self): + nonlocal done + times_started.append(self.now) + while True: + yield time.Cond(lambda agent: agent.model.schedule.time >= 10) + times_awakened.append(self.now) + if self.now >= 10: + break + done += 1 + + env = environment.Environment(agents=[{'agent_class': CondAgent}]) + + + while env.schedule.time < 11: + env.step() + times.append(env.now) + assert env.schedule.time == 11 + assert times_started == [0] + assert times_awakened == [10] + assert done == 1 + # The first time will produce the Cond. + # Since there are no other agents, time will not advance, but the number + # of steps will. + assert env.schedule.steps == 12 + assert len(times) == 12 + + while env.schedule.time < 12: + env.step() + times.append(env.now) + + assert env.schedule.time == 12 + assert times_started == [0, 11] + assert times_awakened == [10, 11] + assert done == 2 + # Once more to yield the cond, another one to continue + assert env.schedule.steps == 14 + assert len(times) == 14