1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-22 03:02:28 +00:00

Improve exporters

This commit is contained in:
J. Fernando Sánchez 2022-10-16 21:57:30 +02:00
parent 78833a9e08
commit 0efcd24d90
7 changed files with 78 additions and 35 deletions

View File

@ -1,4 +1,4 @@
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment
from soil.time import Delta from soil.time import Delta
from enum import Enum from enum import Enum
from collections import Counter from collections import Counter
@ -6,7 +6,23 @@ import logging
import math import math
class RabbitModel(FSM, NetworkAgent): class RabbitEnv(Environment):
@property
def num_rabbits(self):
return self.count_agents(agent_class=Rabbit)
@property
def num_males(self):
return self.count_agents(agent_class=Male)
@property
def num_females(self):
return self.count_agents(agent_class=Female)
class Rabbit(FSM, NetworkAgent):
sexual_maturity = 30 sexual_maturity = 30
life_expectancy = 300 life_expectancy = 300
@ -35,7 +51,7 @@ class RabbitModel(FSM, NetworkAgent):
self.die() self.die()
class Male(RabbitModel): class Male(Rabbit):
max_females = 5 max_females = 5
mating_prob = 0.001 mating_prob = 0.001
@ -56,7 +72,7 @@ class Male(RabbitModel):
break # Take a break break # Take a break
class Female(RabbitModel): class Female(Rabbit):
gestation = 30 gestation = 30
@state @state
@ -119,7 +135,7 @@ class RandomAccident(BaseAgent):
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
self.debug('Killing some rabbits with prob={}!'.format(prob_death)) self.debug('Killing some rabbits with prob={}!'.format(prob_death))
for i in self.iter_agents(agent_class=RabbitModel): for i in self.iter_agents(agent_class=Rabbit):
if i.state_id == i.dead.id: if i.state_id == i.dead.id:
continue continue
if self.prob(prob_death): if self.prob(prob_death):

View File

@ -7,11 +7,10 @@ description: null
group: null group: null
interval: 1.0 interval: 1.0
max_time: 100 max_time: 100
model_class: soil.environment.Environment model_class: rabbit_agents.RabbitEnv
model_params: model_params:
agents: agents:
topology: true topology: true
agent_class: rabbit_agents.RabbitModel
distribution: distribution:
- agent_class: rabbit_agents.Male - agent_class: rabbit_agents.Male
weight: 1 weight: 1
@ -34,5 +33,10 @@ model_params:
nodes: nodes:
- id: 1 - id: 1
- id: 0 - id: 0
model_reporters:
num_males: 'num_males'
num_females: 'num_females'
num_rabbits: |
py:lambda env: env.num_males + env.num_females
extra: extra:
visualization_params: {} visualization_params: {}

View File

@ -10,7 +10,7 @@ import networkx as nx
from .serialization import deserialize from .serialization import deserialize
from .utils import open_or_reuse, logger, timer from .utils import try_backup, open_or_reuse, logger, timer
from . import utils, network from . import utils, network
@ -91,12 +91,14 @@ class default(Exporter):
"""Default exporter. Writes sqlite results, as well as the simulation YAML""" """Default exporter. Writes sqlite results, as well as the simulation YAML"""
def sim_start(self): def sim_start(self):
if not self.dry_run: if self.dry_run:
logger.info("NOT dumping results")
return
logger.info("Dumping results to %s", self.outdir) logger.info("Dumping results to %s", self.outdir)
with self.output(self.simulation.name + ".dumped.yml") as f: with self.output(self.simulation.name + ".dumped.yml") as f:
f.write(self.simulation.to_yaml()) f.write(self.simulation.to_yaml())
else: self.dbpath = os.path.join(self.outdir, f"{self.simulation.name}.sqlite")
logger.info("NOT dumping results") try_backup(self.dbpath, move=True)
def trial_end(self, env): def trial_end(self, env):
if self.dry_run: if self.dry_run:
@ -107,21 +109,23 @@ class default(Exporter):
"Dumping simulation {} trial {}".format(self.simulation.name, env.id) "Dumping simulation {} trial {}".format(self.simulation.name, env.id)
): ):
fpath = os.path.join(self.outdir, f"{env.id}.sqlite") engine = create_engine(f"sqlite:///{self.dbpath}", echo=False)
engine = create_engine(f"sqlite:///{fpath}", echo=False)
dc = env.datacollector dc = env.datacollector
for (t, df) in get_dc_dfs(dc): for (t, df) in get_dc_dfs(dc, trial_id=env.id):
df.to_sql(t, con=engine, if_exists="append") df.to_sql(t, con=engine, if_exists="append")
def get_dc_dfs(dc): def get_dc_dfs(dc, trial_id=None):
dfs = { dfs = {
"env": dc.get_model_vars_dataframe(), "env": dc.get_model_vars_dataframe(),
"agents": dc.get_agent_vars_dataframe(), "agents": dc.get_agent_vars_dataframe(),
} }
for table_name in dc.tables: for table_name in dc.tables:
dfs[table_name] = dc.get_table_dataframe(table_name) dfs[table_name] = dc.get_table_dataframe(table_name)
if trial_id:
for (name, df) in dfs.items():
df['trial_id'] = trial_id
yield from dfs.items() yield from dfs.items()
@ -135,7 +139,7 @@ class csv(Exporter):
self.simulation.name, env.id, self.outdir self.simulation.name, env.id, self.outdir
) )
): ):
for (df_name, df) in get_dc_dfs(env.datacollector): for (df_name, df) in get_dc_dfs(env.datacollector, trial_id=env.id):
with self.output("{}.{}.csv".format(env.id, df_name)) as f: with self.output("{}.{}.csv".format(env.id, df_name)) as f:
df.to_csv(f) df.to_csv(f)

View File

@ -197,7 +197,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
return getattr(cls, "deserialize", cls) return getattr(cls, "deserialize", cls)
except (ImportError, AttributeError) as ex: except (ImportError, AttributeError) as ex:
errors.append((modname, tname, ex)) errors.append((modname, tname, ex))
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors)) raise ValueError('Could not find type "{}". Tried: {}'.format(type_, errors))
def deserialize(type_, value=None, globs=None, **kwargs): def deserialize(type_, value=None, globs=None, **kwargs):
@ -207,7 +207,13 @@ def deserialize(type_, value=None, globs=None, **kwargs):
if globs and type_ in globs: if globs and type_ in globs:
des = globs[type_] des = globs[type_]
else: else:
try:
des = deserializer(type_, **kwargs) des = deserializer(type_, **kwargs)
except ValueError as ex:
try:
des = eval(type_)
except Exception:
raise ex
if value is None: if value is None:
return des return des
return des(value) return des(value)

View File

@ -151,7 +151,7 @@ class Simulation:
def deserialize_reporters(reporters): def deserialize_reporters(reporters):
for (k, v) in reporters.items(): for (k, v) in reporters.items():
if isinstance(v, str) and v.startswith("py:"): if isinstance(v, str) and v.startswith("py:"):
reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1]) reporters[k] = serialization.deserialize(v.split(":", 1)[1])
return reporters return reporters
params = self.model_params.copy() params = self.model_params.copy()

View File

@ -4,7 +4,7 @@ import os
import traceback import traceback
from functools import partial from functools import partial
from shutil import copyfile from shutil import copyfile, move
from multiprocessing import Pool from multiprocessing import Pool
from contextlib import contextmanager from contextlib import contextmanager
@ -47,11 +47,12 @@ def timer(name="task", pre="", function=logger.info, to_object=None):
to_object.end = end to_object.end = end
def safe_open(path, mode="r", backup=True, **kwargs): def try_backup(path, move=False):
if not os.path.exists(path):
return None
outdir = os.path.dirname(path) outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir): if outdir and not os.path.exists(outdir):
os.makedirs(outdir) os.makedirs(outdir)
if backup and "w" in mode and os.path.exists(path):
creation = os.path.getctime(path) creation = os.path.getctime(path)
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation)) stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
@ -61,7 +62,19 @@ def safe_open(path, mode="r", backup=True, **kwargs):
newpath = os.path.join( newpath = os.path.join(
backup_dir, "{}@{}".format(os.path.basename(path), stamp) backup_dir, "{}@{}".format(os.path.basename(path), stamp)
) )
if move:
move(path, newpath)
else:
copyfile(path, newpath) copyfile(path, newpath)
return newpath
def safe_open(path, mode="r", backup=True, **kwargs):
outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir):
os.makedirs(outdir)
if backup and "w" in mode:
try_backup(path)
return open(path, mode=mode, **kwargs) return open(path, mode=mode, **kwargs)
@ -70,7 +83,7 @@ def open_or_reuse(f, *args, **kwargs):
try: try:
with safe_open(f, *args, **kwargs) as f: with safe_open(f, *args, **kwargs) as f:
yield f yield f
except (AttributeError, TypeError): except (AttributeError, TypeError) as ex:
yield f yield f

View File

@ -99,7 +99,7 @@ class Exporters(TestCase):
try: try:
for e in envs: for e in envs:
db = sqlite3.connect(os.path.join(simdir, f"{e.id}.sqlite")) db = sqlite3.connect(os.path.join(simdir, f"{s.name}.sqlite"))
cur = db.cursor() cur = db.cursor()
agent_entries = cur.execute("SELECT * from agents").fetchall() agent_entries = cur.execute("SELECT * from agents").fetchall()
env_entries = cur.execute("SELECT * from env").fetchall() env_entries = cur.execute("SELECT * from env").fetchall()