Fix conditionals

mesa
J. Fernando Sánchez 2 years ago
parent 5d759d0072
commit 227fdf050e

@ -64,6 +64,7 @@ class Patron(FSM, NetworkAgent):
drunk = False drunk = False
pints = 0 pints = 0
max_pints = 3 max_pints = 3
kicked_out = False
@default_state @default_state
@state @state
@ -105,7 +106,9 @@ class Patron(FSM, NetworkAgent):
'''I'm out. Take me home!''' '''I'm out. Take me home!'''
self.info('I\'m so drunk. Take me home!') self.info('I\'m so drunk. Take me home!')
self['drunk'] = True self['drunk'] = True
pass # out drunk if self.kicked_out:
return self.at_home
pass # out drun
@state @state
def at_home(self): def at_home(self):
@ -118,7 +121,7 @@ class Patron(FSM, NetworkAgent):
self.debug('Cheers to that') self.debug('Cheers to that')
def kick_out(self): def kick_out(self):
self.set_state(self.at_home) self.kicked_out = True
def befriend(self, other_agent, force=False): def befriend(self, other_agent, force=False):
''' '''

@ -2,3 +2,13 @@ There are two similar implementations of this simulation.
- `basic`. Using simple primites - `basic`. Using simple primites
- `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions. - `improved`. Using more advanced features such as the `time` module to avoid unnecessary computations (i.e., skip steps), and generator functions.
The examples can be run directly in the terminal, and they accept command like arguments.
For example, to enable the CSV exporter and the Summary exporter, while setting `max_time` to `100` and `seed` to `CustomSeed`:
```
python rabbit_agents.py --set max_time=100 --csv -e summary --set 'seed="CustomSeed"'
```
To learn more about how this functionality works, check out the `soil.easy` function.

@ -1,6 +1,4 @@
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment
from soil.time import Delta
from enum import Enum
from collections import Counter from collections import Counter
import logging import logging
import math import math
@ -21,7 +19,7 @@ class RabbitEnv(Environment):
return self.count_agents(agent_class=Female) return self.count_agents(agent_class=Female)
class Rabbit(FSM, NetworkAgent): class Rabbit(NetworkAgent, FSM):
sexual_maturity = 30 sexual_maturity = 30
life_expectancy = 300 life_expectancy = 300
@ -72,7 +70,8 @@ class Male(Rabbit):
class Female(Rabbit): class Female(Rabbit):
gestation = 30 gestation = 10
pregnancy = -1
@state @state
def fertile(self): def fertile(self):
@ -80,46 +79,49 @@ class Female(Rabbit):
self.age += 1 self.age += 1
if self.age > self.life_expectancy: if self.age > self.life_expectancy:
return self.dead return self.dead
if self.pregnancy >= 0:
return self.pregnant
def impregnate(self, male): def impregnate(self, male):
self.info(f'{repr(male)} impregnating female {repr(self)}') self.info(f'impregnated by {repr(male)}')
self.mate = male self.mate = male
self.pregnancy = -1 self.pregnancy = 0
self.set_state(self.pregnant, when=self.now)
self.number_of_babies = int(8+4*self.random.random()) self.number_of_babies = int(8+4*self.random.random())
@state @state
def pregnant(self): def pregnant(self):
self.debug('I am pregnant') self.info('I am pregnant')
self.age += 1 self.age += 1
self.pregnancy += 1
if self.prob(self.age / self.life_expectancy): if self.age >= self.life_expectancy:
return self.die() return self.die()
if self.pregnancy >= self.gestation: if self.pregnancy < self.gestation:
self.info('Having {} babies'.format(self.number_of_babies)) self.pregnancy += 1
for i in range(self.number_of_babies): return
state = {}
agent_class = self.random.choice([Male, Female]) self.info('Having {} babies'.format(self.number_of_babies))
child = self.model.add_node(agent_class=agent_class, for i in range(self.number_of_babies):
**state) state = {}
child.add_edge(self) agent_class = self.random.choice([Male, Female])
try: child = self.model.add_node(agent_class=agent_class,
child.add_edge(self.mate) **state)
self.model.agents[self.mate].offspring += 1 child.add_edge(self)
except ValueError: try:
self.debug('The father has passed away') child.add_edge(self.mate)
self.model.agents[self.mate].offspring += 1
self.offspring += 1 except ValueError:
self.mate = None self.debug('The father has passed away')
return self.fertile
self.offspring += 1
self.mate = None
self.pregnancy = -1
return self.fertile
@state def die(self):
def dead(self):
super().dead()
if 'pregnancy' in self and self['pregnancy'] > -1: if 'pregnancy' in self and self['pregnancy'] > -1:
self.info('A mother has died carrying a baby!!') self.info('A mother has died carrying a baby!!')
return super().die()
class RandomAccident(BaseAgent): class RandomAccident(BaseAgent):
@ -138,11 +140,11 @@ class RandomAccident(BaseAgent):
if self.prob(prob_death): if self.prob(prob_death):
self.info('I killed a rabbit: {}'.format(i.id)) self.info('I killed a rabbit: {}'.format(i.id))
rabbits_alive -= 1 rabbits_alive -= 1
i.set_state(i.dead) i.die()
self.debug('Rabbits alive: {}'.format(rabbits_alive)) self.debug('Rabbits alive: {}'.format(rabbits_alive))
if __name__ == '__main__': if __name__ == '__main__':
from soil import easy from soil import easy
sim = easy('rabbits.yml') with easy('rabbits.yml') as sim:
sim.run() sim.run()

@ -1,130 +1,157 @@
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment
from soil.time import Delta, When, NEVER from soil.time import Delta
from enum import Enum from enum import Enum
from collections import Counter
import logging import logging
import math import math
class RabbitModel(FSM, NetworkAgent): class RabbitEnv(Environment):
mating_prob = 0.005 @property
offspring = 0 def num_rabbits(self):
birth = None return self.count_agents(agent_class=Rabbit)
sexual_maturity = 3 @property
life_expectancy = 30 def num_males(self):
return self.count_agents(agent_class=Male)
@property
def num_females(self):
return self.count_agents(agent_class=Female)
@default_state
@state
def newborn(self):
self.birth = self.now
self.info(f'I am a newborn.')
self.model['rabbits_alive'] = self.model.get('rabbits_alive', 0) + 1
# Here we can skip the `youngling` state by using a coroutine/generator. class Rabbit(FSM, NetworkAgent):
while self.age < self.sexual_maturity:
interval = self.sexual_maturity - self.age
yield Delta(interval)
self.info(f'I am fertile! My age is {self.age}') sexual_maturity = 30
return self.fertile life_expectancy = 300
birth = None
@property @property
def age(self): def age(self):
if self.birth is None:
return None
return self.now - self.birth return self.now - self.birth
@default_state
@state
def newborn(self):
self.info('I am a newborn.')
self.birth = self.now
self.offspring = 0
return self.youngling, Delta(self.sexual_maturity - self.age)
@state
def youngling(self):
if self.age >= self.sexual_maturity:
self.info(f'I am fertile! My age is {self.age}')
return self.fertile
@state @state
def fertile(self): def fertile(self):
raise Exception("Each subclass should define its fertile state") raise Exception("Each subclass should define its fertile state")
def step(self): @state
super().step() def dead(self):
if self.prob(self.age / self.life_expectancy): self.die()
return self.die()
class Male(RabbitModel):
class Male(Rabbit):
max_females = 5 max_females = 5
mating_prob = 0.001
@state @state
def fertile(self): def fertile(self):
if self.age > self.life_expectancy:
return self.dead
# Males try to mate # Males try to mate
for f in self.model.agents(agent_class=Female, for f in self.model.agents(agent_class=Female,
state_id=Female.fertile.id, state_id=Female.fertile.id,
limit=self.max_females): limit=self.max_females):
self.debug('Found a female:', repr(f)) self.debug('FOUND A FEMALE: ', repr(f), self.mating_prob)
if self.prob(self['mating_prob']): if self.prob(self['mating_prob']):
f.impregnate(self) f.impregnate(self)
break # Take a break, don't try to impregnate the rest break # Do not try to impregnate other females
class Female(Rabbit):
class Female(RabbitModel):
due_date = None
age_of_pregnancy = None
gestation = 10 gestation = 10
mate = None conception = None
@state @state
def fertile(self): def fertile(self):
return self.fertile, NEVER # Just wait for a Male
@state
def pregnant(self):
self.info('I am pregnant')
if self.age > self.life_expectancy: if self.age > self.life_expectancy:
return self.dead return self.dead
if self.conception is not None:
return self.pregnant
self.due_date = self.now + self.gestation @property
def pregnancy(self):
if self.conception is None:
return None
return self.now - self.conception
number_of_babies = int(8+4*self.random.random()) def impregnate(self, male):
self.info(f'impregnated by {repr(male)}')
self.mate = male
self.conception = self.now
self.number_of_babies = int(8+4*self.random.random())
while self.now < self.due_date: @state
yield When(self.due_date) def pregnant(self):
self.debug('I am pregnant')
self.info('Having {} babies'.format(number_of_babies)) if self.age > self.life_expectancy:
for i in range(number_of_babies): self.info("Dying before giving birth")
agent_class = self.random.choice([Male, Female]) return self.die()
child = self.model.add_node(agent_class=agent_class,
topology=self.topology)
self.model.add_edge(self, child)
self.model.add_edge(self.mate, child)
self.offspring += 1
self.model.agents[self.mate].offspring += 1
self.mate = None
self.due_date = None
return self.fertile
@state if self.pregnancy >= self.gestation:
def dead(self): self.info('Having {} babies'.format(self.number_of_babies))
super().dead() for i in range(self.number_of_babies):
if self.due_date is not None: state = {}
agent_class = self.random.choice([Male, Female])
child = self.model.add_node(agent_class=agent_class,
**state)
child.add_edge(self)
if self.mate:
child.add_edge(self.mate)
self.mate.offspring += 1
else:
self.debug('The father has passed away')
self.offspring += 1
self.mate = None
return self.fertile
def die(self):
if self.pregnancy is not None:
self.info('A mother has died carrying a baby!!') self.info('A mother has died carrying a baby!!')
return super().die()
def impregnate(self, male):
self.info(f'{repr(male)} impregnating female {repr(self)}')
self.mate = male
self.set_state(self.pregnant, when=self.now)
class RandomAccident(BaseAgent): class RandomAccident(BaseAgent):
level = logging.INFO
def step(self): def step(self):
rabbits_total = self.model.topology.number_of_nodes() rabbits_alive = self.model.G.number_of_nodes()
if 'rabbits_alive' not in self.model:
self.model['rabbits_alive'] = 0 if not rabbits_alive:
rabbits_alive = self.model.get('rabbits_alive', rabbits_total) return self.die()
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
self.debug('Killing some rabbits with prob={}!'.format(prob_death)) self.debug('Killing some rabbits with prob={}!'.format(prob_death))
for i in self.model.network_agents: for i in self.iter_agents(agent_class=Rabbit):
if i.state.id == i.dead.id: if i.state_id == i.dead.id:
continue continue
if self.prob(prob_death): if self.prob(prob_death):
self.info('I killed a rabbit: {}'.format(i.id)) self.info('I killed a rabbit: {}'.format(i.id))
rabbits_alive = self.model['rabbits_alive'] = rabbits_alive -1 rabbits_alive -= 1
i.set_state(i.dead) i.die()
self.debug('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) self.debug('Rabbits alive: {}'.format(rabbits_alive))
if self.model.count_agents(state_id=RabbitModel.dead.id) == self.model.topology.number_of_nodes():
self.die()
if __name__ == '__main__':
from soil import easy
with easy('rabbits.yml') as sim:
sim.run()

@ -7,11 +7,10 @@ description: null
group: null group: null
interval: 1.0 interval: 1.0
max_time: 100 max_time: 100
model_class: soil.environment.Environment model_class: rabbit_agents.RabbitEnv
model_params: model_params:
agents: agents:
topology: true topology: true
agent_class: rabbit_agents.RabbitModel
distribution: distribution:
- agent_class: rabbit_agents.Male - agent_class: rabbit_agents.Male
weight: 1 weight: 1
@ -34,5 +33,10 @@ model_params:
nodes: nodes:
- id: 1 - id: 1
- id: 0 - id: 0
model_reporters:
num_males: 'num_males'
num_females: 'num_females'
num_rabbits: |
py:lambda env: env.num_males + env.num_females
extra: extra:
visualization_params: {} visualization_params: {}

@ -5,6 +5,7 @@ import sys
import os import os
import logging import logging
import traceback import traceback
from contextlib import contextmanager
from .version import __version__ from .version import __version__
@ -30,6 +31,7 @@ def main(
*, *,
do_run=False, do_run=False,
debug=False, debug=False,
pdb=False,
**kwargs, **kwargs,
): ):
import argparse import argparse
@ -154,6 +156,7 @@ def main(
if args.pdb or debug: if args.pdb or debug:
args.synchronous = True args.synchronous = True
os.environ["SOIL_POSTMORTEM"] = "true"
res = [] res = []
try: try:
@ -214,9 +217,20 @@ def main(
return res return res
def easy(cfg, debug=False, **kwargs): @contextmanager
return main(cfg, **kwargs)[0] def easy(cfg, pdb=False, debug=False, **kwargs):
ex = None
try:
yield main(cfg, **kwargs)[0]
except Exception as e:
if os.environ.get("SOIL_POSTMORTEM"):
from .debugging import post_mortem
print(traceback.format_exc())
post_mortem()
ex = e
finally:
if ex:
raise ex
if __name__ == "__main__": if __name__ == "__main__":
main(do_run=True) main(do_run=True)

@ -1,9 +1,7 @@
from . import main as init_main from . import main as init_main
def main(): def main():
init_main(do_run=True) init_main(do_run=True)
if __name__ == '__main__':
if __name__ == "__main__":
init_main(do_run=True) init_main(do_run=True)

@ -29,10 +29,6 @@ def as_node(agent):
IGNORED_FIELDS = ("model", "logger") IGNORED_FIELDS = ("model", "logger")
class DeadAgent(Exception):
pass
class MetaAgent(ABCMeta): class MetaAgent(ABCMeta):
def __new__(mcls, name, bases, namespace): def __new__(mcls, name, bases, namespace):
defaults = {} defaults = {}
@ -198,7 +194,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
def step(self): def step(self):
if not self.alive: if not self.alive:
raise DeadAgent(self.unique_id) raise time.DeadAgent(self.unique_id)
return super().step() or time.Delta(self.interval) return super().step() or time.Delta(self.interval)
def log(self, message, *args, level=logging.INFO, **kwargs): def log(self, message, *args, level=logging.INFO, **kwargs):
@ -264,6 +260,10 @@ class NetworkAgent(BaseAgent):
return list(self.iter_agents(limit_neighbors=True, **kwargs)) return list(self.iter_agents(limit_neighbors=True, **kwargs))
def add_edge(self, other): def add_edge(self, other):
assert self.node_id
assert other.node_id
assert self.node_id in self.G.nodes
assert other.node_id in self.G.nodes
self.topology.add_edge(self.node_id, other.node_id) self.topology.add_edge(self.node_id, other.node_id)
@property @property
@ -303,7 +303,9 @@ class NetworkAgent(BaseAgent):
return G return G
def remove_node(self): def remove_node(self):
print(f'Removing node for {self.unique_id}: {self.node_id}')
self.G.remove_node(self.node_id) self.G.remove_node(self.node_id)
self.node_id = None
def add_edge(self, other, edge_attr_dict=None, *edge_attrs): def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
if self.node_id not in self.G.nodes(data=False): if self.node_id not in self.G.nodes(data=False):
@ -322,6 +324,8 @@ class NetworkAgent(BaseAgent):
) )
def die(self, remove=True): def die(self, remove=True):
if not self.alive:
return
if remove: if remove:
self.remove_node() self.remove_node()
return super().die() return super().die()
@ -351,7 +355,7 @@ def state(name=None):
self._coroutine = None self._coroutine = None
next_state = ex.value next_state = ex.value
if next_state is not None: if next_state is not None:
self.set_state(next_state) self._set_state(next_state)
return next_state return next_state
func.id = name or func.__name__ func.id = name or func.__name__
@ -401,8 +405,8 @@ class MetaFSM(MetaAgent):
class FSM(BaseAgent, metaclass=MetaFSM): class FSM(BaseAgent, metaclass=MetaFSM):
def __init__(self, *args, **kwargs): def __init__(self, **kwargs):
super(FSM, self).__init__(*args, **kwargs) super(FSM, self).__init__(**kwargs)
if not hasattr(self, "state_id"): if not hasattr(self, "state_id"):
if not self._default_state: if not self._default_state:
raise ValueError( raise ValueError(
@ -411,7 +415,7 @@ class FSM(BaseAgent, metaclass=MetaFSM):
self.state_id = self._default_state.id self.state_id = self._default_state.id
self._coroutine = None self._coroutine = None
self.set_state(self.state_id) self._set_state(self.state_id)
def step(self): def step(self):
self.debug(f"Agent {self.unique_id} @ state {self.state_id}") self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
@ -434,11 +438,11 @@ class FSM(BaseAgent, metaclass=MetaFSM):
pass pass
if next_state is not None: if next_state is not None:
self.set_state(next_state) self._set_state(next_state)
return when or default_interval return when or default_interval
def set_state(self, state, when=None): def _set_state(self, state, when=None):
if hasattr(state, "id"): if hasattr(state, "id"):
state = state.id state = state.id
if state not in self._states: if state not in self._states:
@ -576,83 +580,6 @@ def _convert_agent_classs(ind, to_string=False, **kwargs):
return deserialize_definition(ind, **kwargs) return deserialize_definition(ind, **kwargs)
# def _agent_from_definition(definition, random, value=-1, unique_id=None):
# """Used in the initialization of agents given an agent distribution."""
# if value < 0:
# value = random.random()
# for d in sorted(definition, key=lambda x: x.get('threshold')):
# threshold = d.get('threshold', (-1, -1))
# # Check if the definition matches by id (first) or by threshold
# if (unique_id is not None and unique_id in d.get('ids', [])) or \
# (value >= threshold[0] and value < threshold[1]):
# state = {}
# if 'state' in d:
# state = deepcopy(d['state'])
# return d['agent_class'], state
# raise Exception('Definition for value {} not found in: {}'.format(value, definition))
# def _definition_to_dict(definition, random, size=None, default_state=None):
# state = default_state or {}
# agents = {}
# remaining = {}
# if size:
# for ix in range(size):
# remaining[ix] = copy(state)
# else:
# remaining = defaultdict(lambda x: copy(state))
# distro = sorted([item for item in definition if 'weight' in item])
# id = 0
# def init_agent(item, id=ix):
# while id in agents:
# id += 1
# agent = remaining[id]
# agent['state'].update(copy(item.get('state', {})))
# agents[agent.unique_id] = agent
# del remaining[id]
# return agent
# for item in definition:
# if 'ids' in item:
# ids = item['ids']
# del item['ids']
# for id in ids:
# agent = init_agent(item, id)
# for item in definition:
# if 'number' in item:
# times = item['number']
# del item['number']
# for times in range(times):
# if size:
# ix = random.choice(remaining.keys())
# agent = init_agent(item, id)
# else:
# agent = init_agent(item)
# if not size:
# return agents
# if len(remaining) < 0:
# raise Exception('Invalid definition. Too many agents to add')
# total_weight = float(sum(s['weight'] for s in distro))
# unit = size / total_weight
# for item in distro:
# times = unit * item['weight']
# del item['weight']
# for times in range(times):
# ix = random.choice(remaining.keys())
# agent = init_agent(item, id)
# return agents
class AgentView(Mapping, Set): class AgentView(Mapping, Set):
"""A lazy-loaded list of agents.""" """A lazy-loaded list of agents."""

@ -31,8 +31,8 @@ class Debug(pdb.Pdb):
def __init__(self, *args, skip_soil=False, **kwargs): def __init__(self, *args, skip_soil=False, **kwargs):
skip = kwargs.get("skip", []) skip = kwargs.get("skip", [])
skip.append("soil") skip.append("soil")
skip.append("contextlib")
if skip_soil: if skip_soil:
skip.append("soil")
skip.append("soil.*") skip.append("soil.*")
skip.append("mesa.*") skip.append("mesa.*")
super(Debug, self).__init__(*args, skip=skip, **kwargs) super(Debug, self).__init__(*args, skip=skip, **kwargs)
@ -181,7 +181,7 @@ def set_trace(frame=None, **kwargs):
debugger.set_trace(frame) debugger.set_trace(frame)
def post_mortem(traceback=None): def post_mortem(traceback=None, **kwargs):
global debugger global debugger
if debugger is None: if debugger is None:
debugger = Debug(**kwargs) debugger = Debug(**kwargs)

@ -142,12 +142,12 @@ class BaseEnvironment(Model):
"The environment has not been scheduled, so it has no sense of time" "The environment has not been scheduled, so it has no sense of time"
) )
def add_agent(self, agent_class, unique_id=None, **kwargs): def add_agent(self, unique_id=None, **kwargs):
a = None
if unique_id is None: if unique_id is None:
unique_id = self.next_id() unique_id = self.next_id()
a = agent_class(model=self, unique_id=unique_id, **args) kwargs['unique_id'] = unique_id
a = self._agent_from_dict(kwargs)
self.schedule.add(a) self.schedule.add(a)
return a return a
@ -236,6 +236,7 @@ class NetworkEnvironment(BaseEnvironment):
node_id = agent.get("node_id", None) node_id = agent.get("node_id", None)
if node_id is None: if node_id is None:
node_id = network.find_unassigned(self.G, random=self.random) node_id = network.find_unassigned(self.G, random=self.random)
self.G.nodes[node_id]['agent'] = None
agent["node_id"] = node_id agent["node_id"] = node_id
agent["unique_id"] = unique_id agent["unique_id"] = unique_id
agent["topology"] = self.G agent["topology"] = self.G
@ -269,18 +270,29 @@ class NetworkEnvironment(BaseEnvironment):
node_id = network.find_unassigned( node_id = network.find_unassigned(
G=self.G, shuffle=True, random=self.random G=self.G, shuffle=True, random=self.random
) )
if node_id is None:
node_id = f'node_for_{unique_id}'
if node_id in G.nodes: if node_id not in self.G.nodes:
self.G.nodes[node_id]["agent"] = None # Reserve
else:
self.G.add_node(node_id) self.G.add_node(node_id)
assert "agent" not in self.G.nodes[node_id]
self.G.nodes[node_id]["agent"] = None # Reserve
a = self.add_agent( a = self.add_agent(
unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs unique_id=unique_id, agent_class=agent_class, topology=self.G, node_id=node_id, **kwargs
) )
a["visible"] = True a["visible"] = True
return a return a
def add_agent(self, *args, **kwargs):
a = super().add_agent(*args, **kwargs)
if 'node_id' in a:
if a.node_id == 24:
import pdb;pdb.set_trace()
assert self.G.nodes[a.node_id]['agent'] == a
return a
def agent_for_node_id(self, node_id): def agent_for_node_id(self, node_id):
return self.G.nodes[node_id].get("agent") return self.G.nodes[node_id].get("agent")

@ -65,10 +65,8 @@ def find_unassigned(G, shuffle=False, random=random):
random.shuffle(candidates) random.shuffle(candidates)
for next_id, data in candidates: for next_id, data in candidates:
if "agent" not in data: if "agent" not in data:
node_id = next_id return next_id
break return None
return node_id
def dump_gexf(G, f): def dump_gexf(G, f):

@ -226,7 +226,7 @@ Model stats:
) )
model.step() model.step()
if model.schedule.time < until: # Simulation ended (no more steps) before until (i.e., no changes expected) if model.schedule.time < until: # Simulation ended (no more steps) before the expected time
model.schedule.time = until model.schedule.time = until
return model return model

@ -13,6 +13,10 @@ from mesa import Agent as MesaAgent
INFINITY = float("inf") INFINITY = float("inf")
class DeadAgent(Exception):
pass
class When: class When:
def __init__(self, time): def __init__(self, time):
if isinstance(time, When): if isinstance(time, When):
@ -38,23 +42,27 @@ class When:
return self._time > other return self._time > other
return self._time > other.next(self._time) return self._time > other.next(self._time)
def ready(self, time): def ready(self, agent):
return self._time <= time return self._time <= agent.model.schedule.time
class Cond(When): class Cond(When):
def __init__(self, func, delta=1): def __init__(self, func, delta=1):
self._func = func self._func = func
self._delta = delta self._delta = delta
self._checked = False
def next(self, time): def next(self, time):
return time + self._delta if self._checked:
return time + self._delta
return time
def abs(self, time): def abs(self, time):
return self return self
def ready(self, time): def ready(self, agent):
return self._func(time) self._checked = True
return self._func(agent)
def __eq__(self, other): def __eq__(self, other):
return False return False
@ -109,10 +117,12 @@ class TimedActivation(BaseScheduler):
elif not isinstance(when, When): elif not isinstance(when, When):
when = When(when) when = When(when)
if agent.unique_id in self._agents: if agent.unique_id in self._agents:
self._queue.remove((self._next[agent.unique_id], agent))
del self._agents[agent.unique_id] del self._agents[agent.unique_id]
heapify(self._queue) if agent.unique_id in self._next:
self._queue.remove((self._next[agent.unique_id], agent))
heapify(self._queue)
self._next[agent.unique_id] = when
heappush(self._queue, (when, agent)) heappush(self._queue, (when, agent))
super().add(agent) super().add(agent)
@ -139,8 +149,9 @@ class TimedActivation(BaseScheduler):
if when > self.time: if when > self.time:
break break
heappop(self._queue) heappop(self._queue)
if when.ready(self.time): if when.ready(agent):
to_process.append(agent) to_process.append(agent)
self._next.pop(agent.unique_id, None)
continue continue
next_time = min(next_time, when.next(self.time)) next_time = min(next_time, when.next(self.time))
@ -155,13 +166,20 @@ class TimedActivation(BaseScheduler):
for agent in to_process: for agent in to_process:
self.logger.debug(f"Stepping agent {agent}") self.logger.debug(f"Stepping agent {agent}")
returned = ((agent.step() or Delta(1))).abs(self.time) try:
returned = ((agent.step() or Delta(1))).abs(self.time)
except DeadAgent:
if agent.unique_id in self._next:
del self._next[agent.unique_id]
agent.alive = False
continue
if not getattr(agent, "alive", True): if not getattr(agent, "alive", True):
self.remove(agent) self.remove(agent)
continue continue
value = when.next(self.time) value = returned.next(self.time)
if value < self.time: if value < self.time:
raise Exception( raise Exception(
@ -172,6 +190,8 @@ class TimedActivation(BaseScheduler):
self._next[agent.unique_id] = returned self._next[agent.unique_id] = returned
heappush(self._queue, (returned, agent)) heappush(self._queue, (returned, agent))
else:
assert not self._next[agent.unique_id]
self.steps += 1 self.steps += 1
self.logger.debug(f"Updating time step: {self.time} -> {next_time}") self.logger.debug(f"Updating time step: {self.time} -> {next_time}")

@ -24,7 +24,7 @@ class TestMain(TestCase):
'''A dead agent should raise an exception if it is stepped after death''' '''A dead agent should raise an exception if it is stepped after death'''
d = Dead(unique_id=0, model=environment.Environment()) d = Dead(unique_id=0, model=environment.Environment())
d.step() d.step()
with pytest.raises(agents.DeadAgent): with pytest.raises(stime.DeadAgent):
d.step() d.step()

@ -0,0 +1,74 @@
from unittest import TestCase
from soil import time, agents, environment
class TestMain(TestCase):
def test_cond(self):
'''
A condition should match a When if the concition is True
'''
t = time.Cond(lambda t: True)
f = time.Cond(lambda t: False)
for i in range(10):
w = time.When(i)
assert w == t
assert w is not f
def test_cond(self):
'''
Comparing a Cond to a Delta should always return False
'''
c = time.Cond(lambda t: False)
d = time.Delta(1)
assert c is not d
def test_cond_env(self):
'''
'''
times_started = []
times_awakened = []
times = []
done = 0
class CondAgent(agents.BaseAgent):
def step(self):
nonlocal done
times_started.append(self.now)
while True:
yield time.Cond(lambda agent: agent.model.schedule.time >= 10)
times_awakened.append(self.now)
if self.now >= 10:
break
done += 1
env = environment.Environment(agents=[{'agent_class': CondAgent}])
while env.schedule.time < 11:
env.step()
times.append(env.now)
assert env.schedule.time == 11
assert times_started == [0]
assert times_awakened == [10]
assert done == 1
# The first time will produce the Cond.
# Since there are no other agents, time will not advance, but the number
# of steps will.
assert env.schedule.steps == 12
assert len(times) == 12
while env.schedule.time < 12:
env.step()
times.append(env.now)
assert env.schedule.time == 12
assert times_started == [0, 11]
assert times_awakened == [10, 11]
assert done == 2
# Once more to yield the cond, another one to continue
assert env.schedule.steps == 14
assert len(times) == 14
Loading…
Cancel
Save