From 0efcd24d904b2bee289e4b63d45478e610984d63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Sun, 16 Oct 2022 21:57:30 +0200 Subject: [PATCH] Improve exporters --- examples/rabbits/basic/rabbit_agents.py | 26 +++++++++++++---- examples/rabbits/basic/rabbits.yml | 8 +++-- soil/exporters.py | 26 ++++++++++------- soil/serialization.py | 10 +++++-- soil/simulation.py | 2 +- soil/utils.py | 39 ++++++++++++++++--------- tests/test_exporters.py | 2 +- 7 files changed, 78 insertions(+), 35 deletions(-) diff --git a/examples/rabbits/basic/rabbit_agents.py b/examples/rabbits/basic/rabbit_agents.py index fc7b73b..bd05057 100644 --- a/examples/rabbits/basic/rabbit_agents.py +++ b/examples/rabbits/basic/rabbit_agents.py @@ -1,4 +1,4 @@ -from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent +from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment from soil.time import Delta from enum import Enum from collections import Counter @@ -6,7 +6,23 @@ import logging import math -class RabbitModel(FSM, NetworkAgent): +class RabbitEnv(Environment): + + @property + def num_rabbits(self): + return self.count_agents(agent_class=Rabbit) + + @property + def num_males(self): + return self.count_agents(agent_class=Male) + + @property + def num_females(self): + return self.count_agents(agent_class=Female) + + + +class Rabbit(FSM, NetworkAgent): sexual_maturity = 30 life_expectancy = 300 @@ -35,7 +51,7 @@ class RabbitModel(FSM, NetworkAgent): self.die() -class Male(RabbitModel): +class Male(Rabbit): max_females = 5 mating_prob = 0.001 @@ -56,7 +72,7 @@ class Male(RabbitModel): break # Take a break -class Female(RabbitModel): +class Female(Rabbit): gestation = 30 @state @@ -119,7 +135,7 @@ class RandomAccident(BaseAgent): prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) self.debug('Killing some rabbits with prob={}!'.format(prob_death)) - for i in self.iter_agents(agent_class=RabbitModel): + for i in self.iter_agents(agent_class=Rabbit): if i.state_id == i.dead.id: continue if self.prob(prob_death): diff --git a/examples/rabbits/basic/rabbits.yml b/examples/rabbits/basic/rabbits.yml index 6945f67..a137844 100644 --- a/examples/rabbits/basic/rabbits.yml +++ b/examples/rabbits/basic/rabbits.yml @@ -7,11 +7,10 @@ description: null group: null interval: 1.0 max_time: 100 -model_class: soil.environment.Environment +model_class: rabbit_agents.RabbitEnv model_params: agents: topology: true - agent_class: rabbit_agents.RabbitModel distribution: - agent_class: rabbit_agents.Male weight: 1 @@ -34,5 +33,10 @@ model_params: nodes: - id: 1 - id: 0 + model_reporters: + num_males: 'num_males' + num_females: 'num_females' + num_rabbits: | + py:lambda env: env.num_males + env.num_females extra: visualization_params: {} diff --git a/soil/exporters.py b/soil/exporters.py index a31921d..b1850f4 100644 --- a/soil/exporters.py +++ b/soil/exporters.py @@ -10,7 +10,7 @@ import networkx as nx from .serialization import deserialize -from .utils import open_or_reuse, logger, timer +from .utils import try_backup, open_or_reuse, logger, timer from . import utils, network @@ -91,12 +91,14 @@ class default(Exporter): """Default exporter. Writes sqlite results, as well as the simulation YAML""" def sim_start(self): - if not self.dry_run: - logger.info("Dumping results to %s", self.outdir) - with self.output(self.simulation.name + ".dumped.yml") as f: - f.write(self.simulation.to_yaml()) - else: + if self.dry_run: logger.info("NOT dumping results") + return + logger.info("Dumping results to %s", self.outdir) + with self.output(self.simulation.name + ".dumped.yml") as f: + f.write(self.simulation.to_yaml()) + self.dbpath = os.path.join(self.outdir, f"{self.simulation.name}.sqlite") + try_backup(self.dbpath, move=True) def trial_end(self, env): if self.dry_run: @@ -107,21 +109,23 @@ class default(Exporter): "Dumping simulation {} trial {}".format(self.simulation.name, env.id) ): - fpath = os.path.join(self.outdir, f"{env.id}.sqlite") - engine = create_engine(f"sqlite:///{fpath}", echo=False) + engine = create_engine(f"sqlite:///{self.dbpath}", echo=False) dc = env.datacollector - for (t, df) in get_dc_dfs(dc): + for (t, df) in get_dc_dfs(dc, trial_id=env.id): df.to_sql(t, con=engine, if_exists="append") -def get_dc_dfs(dc): +def get_dc_dfs(dc, trial_id=None): dfs = { "env": dc.get_model_vars_dataframe(), "agents": dc.get_agent_vars_dataframe(), } for table_name in dc.tables: dfs[table_name] = dc.get_table_dataframe(table_name) + if trial_id: + for (name, df) in dfs.items(): + df['trial_id'] = trial_id yield from dfs.items() @@ -135,7 +139,7 @@ class csv(Exporter): self.simulation.name, env.id, self.outdir ) ): - for (df_name, df) in get_dc_dfs(env.datacollector): + for (df_name, df) in get_dc_dfs(env.datacollector, trial_id=env.id): with self.output("{}.{}.csv".format(env.id, df_name)) as f: df.to_csv(f) diff --git a/soil/serialization.py b/soil/serialization.py index b728983..f0a98df 100644 --- a/soil/serialization.py +++ b/soil/serialization.py @@ -197,7 +197,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES): return getattr(cls, "deserialize", cls) except (ImportError, AttributeError) as ex: errors.append((modname, tname, ex)) - raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors)) + raise ValueError('Could not find type "{}". Tried: {}'.format(type_, errors)) def deserialize(type_, value=None, globs=None, **kwargs): @@ -207,7 +207,13 @@ def deserialize(type_, value=None, globs=None, **kwargs): if globs and type_ in globs: des = globs[type_] else: - des = deserializer(type_, **kwargs) + try: + des = deserializer(type_, **kwargs) + except ValueError as ex: + try: + des = eval(type_) + except Exception: + raise ex if value is None: return des return des(value) diff --git a/soil/simulation.py b/soil/simulation.py index 7c79d92..e5f5526 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -151,7 +151,7 @@ class Simulation: def deserialize_reporters(reporters): for (k, v) in reporters.items(): if isinstance(v, str) and v.startswith("py:"): - reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1]) + reporters[k] = serialization.deserialize(v.split(":", 1)[1]) return reporters params = self.model_params.copy() diff --git a/soil/utils.py b/soil/utils.py index 9c4bcc7..0422f48 100644 --- a/soil/utils.py +++ b/soil/utils.py @@ -4,7 +4,7 @@ import os import traceback from functools import partial -from shutil import copyfile +from shutil import copyfile, move from multiprocessing import Pool from contextlib import contextmanager @@ -47,21 +47,34 @@ def timer(name="task", pre="", function=logger.info, to_object=None): to_object.end = end +def try_backup(path, move=False): + if not os.path.exists(path): + return None + outdir = os.path.dirname(path) + if outdir and not os.path.exists(outdir): + os.makedirs(outdir) + creation = os.path.getctime(path) + stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation)) + + backup_dir = os.path.join(outdir, "backup") + if not os.path.exists(backup_dir): + os.makedirs(backup_dir) + newpath = os.path.join( + backup_dir, "{}@{}".format(os.path.basename(path), stamp) + ) + if move: + move(path, newpath) + else: + copyfile(path, newpath) + return newpath + + def safe_open(path, mode="r", backup=True, **kwargs): outdir = os.path.dirname(path) if outdir and not os.path.exists(outdir): os.makedirs(outdir) - if backup and "w" in mode and os.path.exists(path): - creation = os.path.getctime(path) - stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation)) - - backup_dir = os.path.join(outdir, "backup") - if not os.path.exists(backup_dir): - os.makedirs(backup_dir) - newpath = os.path.join( - backup_dir, "{}@{}".format(os.path.basename(path), stamp) - ) - copyfile(path, newpath) + if backup and "w" in mode: + try_backup(path) return open(path, mode=mode, **kwargs) @@ -70,7 +83,7 @@ def open_or_reuse(f, *args, **kwargs): try: with safe_open(f, *args, **kwargs) as f: yield f - except (AttributeError, TypeError): + except (AttributeError, TypeError) as ex: yield f diff --git a/tests/test_exporters.py b/tests/test_exporters.py index 973bd06..baf9d83 100644 --- a/tests/test_exporters.py +++ b/tests/test_exporters.py @@ -99,7 +99,7 @@ class Exporters(TestCase): try: for e in envs: - db = sqlite3.connect(os.path.join(simdir, f"{e.id}.sqlite")) + db = sqlite3.connect(os.path.join(simdir, f"{s.name}.sqlite")) cur = db.cursor() agent_entries = cur.execute("SELECT * from agents").fetchall() env_entries = cur.execute("SELECT * from env").fetchall()