diff --git a/CHANGELOG.md b/CHANGELOG.md index ad22c8c..b2599b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [UNRELEASED] +## [0.20.6] +### Fixed +* Agents now return `time.INFINITY` when dead, instead of 'inf' +* `soil.__init__` does not re-export built-in time (change in `soil.simulation`. It used to create subtle import conflicts when importing soil.time. +* Parallel simulations were broken because lambdas cannot be pickled properly, which is needed for multiprocessing. +### Changed +* Some internal simulation methods do not accept `*args` anymore, to avoid ambiguity and bugs. ## [0.20.5] ### Changed * Defaults are now set in the agent __init__, not in the environment. This decouples both classes a bit more, and it is more intuitive diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index c9bcd23..c4e1023 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -145,6 +145,7 @@ class BaseAgent(Agent): self.alive = False if remove: self.remove_node(self.id) + return time.INFINITY def step(self): if not self.alive: @@ -313,18 +314,16 @@ class FSM(NetworkAgent, metaclass=MetaFSM): def step(self): self.debug(f'Agent {self.unique_id} @ state {self.state_id}') - try: - interval = super().step() - except DeadAgent: - return time.When('inf') + interval = super().step() if 'id' not in self.state: - # if 'id' in self.state: - # self.set_state(self.state['id']) if self.default_state: self.set_state(self.default_state.id) else: raise Exception('{} has no valid state id or default state'.format(self)) - return self.states[self.state_id](self) or interval + interval = self.states[self.state_id](self) or interval + if not self.alive: + return time.NEVER + return interval def set_state(self, state): if hasattr(state, 'id'): diff --git a/soil/simulation.py b/soil/simulation.py index 427adb8..39d909d 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -1,11 +1,12 @@ import os -import time import importlib import sys import yaml import traceback import logging import networkx as nx + +from time import strftime from networkx.readwrite import json_graph from multiprocessing import Pool from functools import partial @@ -98,7 +99,7 @@ class Simulation: self.network_params = network_params self.name = name or 'Unnamed' self.seed = str(seed or name) - self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S")) + self._id = '{}_{}'.format(self.name, strftime("%Y-%m-%d_%H.%M.%S")) self.group = group or '' self.num_trials = num_trials self.max_time = max_time @@ -142,10 +143,10 @@ class Simulation: '''Run the simulation and return the list of resulting environments''' return list(self.run_gen(*args, **kwargs)) - def _run_sync_or_async(self, parallel=False, *args, **kwargs): + def _run_sync_or_async(self, parallel=False, **kwargs): if parallel and not os.environ.get('SENPY_DEBUG', None): p = Pool() - func = lambda x: self.run_trial_exceptions(trial_id=x, *args, **kwargs) + func = partial(self.run_trial_exceptions, **kwargs) for i in p.imap_unordered(func, range(self.num_trials)): if isinstance(i, Exception): logger.error('Trial failed:\n\t%s', i.message) @@ -154,10 +155,9 @@ class Simulation: else: for i in range(self.num_trials): yield self.run_trial(trial_id=i, - *args, **kwargs) - def run_gen(self, *args, parallel=False, dry_run=False, + def run_gen(self, parallel=False, dry_run=False, exporters=[default, ], stats=[], outdir=None, exporter_params={}, stats_params={}, log_level=None, **kwargs): @@ -183,8 +183,7 @@ class Simulation: for exporter in exporters: exporter.start() - for env in self._run_sync_or_async(*args, - parallel=parallel, + for env in self._run_sync_or_async(parallel=parallel, log_level=log_level, **kwargs): diff --git a/soil/time.py b/soil/time.py index 889c7c8..29ae31f 100644 --- a/soil/time.py +++ b/soil/time.py @@ -15,6 +15,8 @@ class When: def abs(self, time): return self._time +NEVER = When(INFINITY) + class Delta: def __init__(self, delta): diff --git a/tests/test_agents.py b/tests/test_agents.py new file mode 100644 index 0000000..76e0998 --- /dev/null +++ b/tests/test_agents.py @@ -0,0 +1,24 @@ +from unittest import TestCase +import pytest + +from soil import agents, environment +from soil import time as stime + +class Dead(agents.FSM): + @agents.default_state + @agents.state + def only(self): + self.die() + +class TestMain(TestCase): + def test_die_raises_exception(self): + d = Dead(unique_id=0, model=environment.Environment()) + d.step() + with pytest.raises(agents.DeadAgent): + d.step() + def test_die_returns_infinity(self): + d = Dead(unique_id=0, model=environment.Environment()) + assert d.step().abs(0) == stime.INFINITY + + +