1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-24 20:02:28 +00:00

Improve exporters

This commit is contained in:
J. Fernando Sánchez 2022-10-16 21:57:30 +02:00
parent 78833a9e08
commit 0efcd24d90
7 changed files with 78 additions and 35 deletions

View File

@ -1,4 +1,4 @@
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment
from soil.time import Delta
from enum import Enum
from collections import Counter
@ -6,7 +6,23 @@ import logging
import math
class RabbitModel(FSM, NetworkAgent):
class RabbitEnv(Environment):
@property
def num_rabbits(self):
return self.count_agents(agent_class=Rabbit)
@property
def num_males(self):
return self.count_agents(agent_class=Male)
@property
def num_females(self):
return self.count_agents(agent_class=Female)
class Rabbit(FSM, NetworkAgent):
sexual_maturity = 30
life_expectancy = 300
@ -35,7 +51,7 @@ class RabbitModel(FSM, NetworkAgent):
self.die()
class Male(RabbitModel):
class Male(Rabbit):
max_females = 5
mating_prob = 0.001
@ -56,7 +72,7 @@ class Male(RabbitModel):
break # Take a break
class Female(RabbitModel):
class Female(Rabbit):
gestation = 30
@state
@ -119,7 +135,7 @@ class RandomAccident(BaseAgent):
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
for i in self.iter_agents(agent_class=RabbitModel):
for i in self.iter_agents(agent_class=Rabbit):
if i.state_id == i.dead.id:
continue
if self.prob(prob_death):

View File

@ -7,11 +7,10 @@ description: null
group: null
interval: 1.0
max_time: 100
model_class: soil.environment.Environment
model_class: rabbit_agents.RabbitEnv
model_params:
agents:
topology: true
agent_class: rabbit_agents.RabbitModel
distribution:
- agent_class: rabbit_agents.Male
weight: 1
@ -34,5 +33,10 @@ model_params:
nodes:
- id: 1
- id: 0
model_reporters:
num_males: 'num_males'
num_females: 'num_females'
num_rabbits: |
py:lambda env: env.num_males + env.num_females
extra:
visualization_params: {}

View File

@ -10,7 +10,7 @@ import networkx as nx
from .serialization import deserialize
from .utils import open_or_reuse, logger, timer
from .utils import try_backup, open_or_reuse, logger, timer
from . import utils, network
@ -91,12 +91,14 @@ class default(Exporter):
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
def sim_start(self):
if not self.dry_run:
if self.dry_run:
logger.info("NOT dumping results")
return
logger.info("Dumping results to %s", self.outdir)
with self.output(self.simulation.name + ".dumped.yml") as f:
f.write(self.simulation.to_yaml())
else:
logger.info("NOT dumping results")
self.dbpath = os.path.join(self.outdir, f"{self.simulation.name}.sqlite")
try_backup(self.dbpath, move=True)
def trial_end(self, env):
if self.dry_run:
@ -107,21 +109,23 @@ class default(Exporter):
"Dumping simulation {} trial {}".format(self.simulation.name, env.id)
):
fpath = os.path.join(self.outdir, f"{env.id}.sqlite")
engine = create_engine(f"sqlite:///{fpath}", echo=False)
engine = create_engine(f"sqlite:///{self.dbpath}", echo=False)
dc = env.datacollector
for (t, df) in get_dc_dfs(dc):
for (t, df) in get_dc_dfs(dc, trial_id=env.id):
df.to_sql(t, con=engine, if_exists="append")
def get_dc_dfs(dc):
def get_dc_dfs(dc, trial_id=None):
dfs = {
"env": dc.get_model_vars_dataframe(),
"agents": dc.get_agent_vars_dataframe(),
}
for table_name in dc.tables:
dfs[table_name] = dc.get_table_dataframe(table_name)
if trial_id:
for (name, df) in dfs.items():
df['trial_id'] = trial_id
yield from dfs.items()
@ -135,7 +139,7 @@ class csv(Exporter):
self.simulation.name, env.id, self.outdir
)
):
for (df_name, df) in get_dc_dfs(env.datacollector):
for (df_name, df) in get_dc_dfs(env.datacollector, trial_id=env.id):
with self.output("{}.{}.csv".format(env.id, df_name)) as f:
df.to_csv(f)

View File

@ -197,7 +197,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
return getattr(cls, "deserialize", cls)
except (ImportError, AttributeError) as ex:
errors.append((modname, tname, ex))
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors))
raise ValueError('Could not find type "{}". Tried: {}'.format(type_, errors))
def deserialize(type_, value=None, globs=None, **kwargs):
@ -207,7 +207,13 @@ def deserialize(type_, value=None, globs=None, **kwargs):
if globs and type_ in globs:
des = globs[type_]
else:
try:
des = deserializer(type_, **kwargs)
except ValueError as ex:
try:
des = eval(type_)
except Exception:
raise ex
if value is None:
return des
return des(value)

View File

@ -151,7 +151,7 @@ class Simulation:
def deserialize_reporters(reporters):
for (k, v) in reporters.items():
if isinstance(v, str) and v.startswith("py:"):
reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1])
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
return reporters
params = self.model_params.copy()

View File

@ -4,7 +4,7 @@ import os
import traceback
from functools import partial
from shutil import copyfile
from shutil import copyfile, move
from multiprocessing import Pool
from contextlib import contextmanager
@ -47,11 +47,12 @@ def timer(name="task", pre="", function=logger.info, to_object=None):
to_object.end = end
def safe_open(path, mode="r", backup=True, **kwargs):
def try_backup(path, move=False):
if not os.path.exists(path):
return None
outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir):
os.makedirs(outdir)
if backup and "w" in mode and os.path.exists(path):
creation = os.path.getctime(path)
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
@ -61,7 +62,19 @@ def safe_open(path, mode="r", backup=True, **kwargs):
newpath = os.path.join(
backup_dir, "{}@{}".format(os.path.basename(path), stamp)
)
if move:
move(path, newpath)
else:
copyfile(path, newpath)
return newpath
def safe_open(path, mode="r", backup=True, **kwargs):
outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir):
os.makedirs(outdir)
if backup and "w" in mode:
try_backup(path)
return open(path, mode=mode, **kwargs)
@ -70,7 +83,7 @@ def open_or_reuse(f, *args, **kwargs):
try:
with safe_open(f, *args, **kwargs) as f:
yield f
except (AttributeError, TypeError):
except (AttributeError, TypeError) as ex:
yield f

View File

@ -99,7 +99,7 @@ class Exporters(TestCase):
try:
for e in envs:
db = sqlite3.connect(os.path.join(simdir, f"{e.id}.sqlite"))
db = sqlite3.connect(os.path.join(simdir, f"{s.name}.sqlite"))
cur = db.cursor()
agent_entries = cur.execute("SELECT * from agents").fetchall()
env_entries = cur.execute("SELECT * from env").fetchall()