diff --git a/examples/rabbits/basic/rabbit_agents.py b/examples/rabbits/basic/rabbit_agents.py index bd05057..284c08a 100644 --- a/examples/rabbits/basic/rabbit_agents.py +++ b/examples/rabbits/basic/rabbit_agents.py @@ -21,7 +21,6 @@ class RabbitEnv(Environment): return self.count_agents(agent_class=Female) - class Rabbit(FSM, NetworkAgent): sexual_maturity = 30 @@ -125,8 +124,6 @@ class Female(Rabbit): class RandomAccident(BaseAgent): - level = logging.INFO - def step(self): rabbits_alive = self.model.G.number_of_nodes() @@ -144,6 +141,7 @@ class RandomAccident(BaseAgent): i.set_state(i.dead) self.debug('Rabbits alive: {}'.format(rabbits_alive)) + if __name__ == '__main__': from soil import easy sim = easy('rabbits.yml') diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index c284604..0ed5bf3 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -47,11 +47,29 @@ class MetaAgent(ABCMeta): } for attr, func in namespace.items(): - if ( - isinstance(func, types.FunctionType) - or isinstance(func, property) - or isinstance(func, classmethod) - or attr[0] == "_" + if attr == 'step' and inspect.isgeneratorfunction(func): + orig_func = func + new_nmspc['_MetaAgent__coroutine'] = None + + @wraps(func) + def func(self): + while True: + if not self.__coroutine: + self.__coroutine = orig_func(self) + try: + return next(self.__coroutine) + except StopIteration as ex: + self.__coroutine = None + return ex.value + + func.id = name or func.__name__ + func.is_default = False + new_nmspc[attr] = func + elif ( + isinstance(func, types.FunctionType) + or isinstance(func, property) + or isinstance(func, classmethod) + or attr[0] == "_" ): new_nmspc[attr] = func elif attr == "defaults": diff --git a/soil/exporters.py b/soil/exporters.py index b1850f4..55a5597 100644 --- a/soil/exporters.py +++ b/soil/exporters.py @@ -125,7 +125,7 @@ def get_dc_dfs(dc, trial_id=None): dfs[table_name] = dc.get_table_dataframe(table_name) if trial_id: for (name, df) in dfs.items(): - df['trial_id'] = trial_id + df["trial_id"] = trial_id yield from dfs.items() diff --git a/soil/utils.py b/soil/utils.py index 0422f48..92d9d74 100644 --- a/soil/utils.py +++ b/soil/utils.py @@ -59,9 +59,7 @@ def try_backup(path, move=False): backup_dir = os.path.join(outdir, "backup") if not os.path.exists(backup_dir): os.makedirs(backup_dir) - newpath = os.path.join( - backup_dir, "{}@{}".format(os.path.basename(path), stamp) - ) + newpath = os.path.join(backup_dir, "{}@{}".format(os.path.basename(path), stamp)) if move: move(path, newpath) else: diff --git a/tests/test_agents.py b/tests/test_agents.py index bee9a9a..8603b1e 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -13,14 +13,35 @@ class Dead(agents.FSM): class TestMain(TestCase): + def test_die_returns_infinity(self): + '''The last step of a dead agent should return time.INFINITY''' + d = Dead(unique_id=0, model=environment.Environment()) + ret = d.step().abs(0) + print(ret, "next") + assert ret == stime.INFINITY + def test_die_raises_exception(self): + '''A dead agent should raise an exception if it is stepped after death''' d = Dead(unique_id=0, model=environment.Environment()) d.step() with pytest.raises(agents.DeadAgent): d.step() - def test_die_returns_infinity(self): - d = Dead(unique_id=0, model=environment.Environment()) - ret = d.step().abs(0) - print(ret, "next") - assert ret == stime.INFINITY + + def test_agent_generator(self): + ''' + The step function of an agent could be a generator. In that case, the state of the + agent will be resumed after every call to step. + ''' + class Gen(agents.BaseAgent): + def step(self): + a = 0 + for i in range(5): + yield a + a += 1 + e = environment.Environment() + g = Gen(model=e, unique_id=e.next_id()) + + for i in range(5): + t = g.step() + assert t == i