Improve exporters

mesa
J. Fernando Sánchez 2 years ago
parent 78833a9e08
commit 0efcd24d90

@ -1,4 +1,4 @@
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment
from soil.time import Delta from soil.time import Delta
from enum import Enum from enum import Enum
from collections import Counter from collections import Counter
@ -6,7 +6,23 @@ import logging
import math import math
class RabbitModel(FSM, NetworkAgent): class RabbitEnv(Environment):
@property
def num_rabbits(self):
return self.count_agents(agent_class=Rabbit)
@property
def num_males(self):
return self.count_agents(agent_class=Male)
@property
def num_females(self):
return self.count_agents(agent_class=Female)
class Rabbit(FSM, NetworkAgent):
sexual_maturity = 30 sexual_maturity = 30
life_expectancy = 300 life_expectancy = 300
@ -35,7 +51,7 @@ class RabbitModel(FSM, NetworkAgent):
self.die() self.die()
class Male(RabbitModel): class Male(Rabbit):
max_females = 5 max_females = 5
mating_prob = 0.001 mating_prob = 0.001
@ -56,7 +72,7 @@ class Male(RabbitModel):
break # Take a break break # Take a break
class Female(RabbitModel): class Female(Rabbit):
gestation = 30 gestation = 30
@state @state
@ -119,7 +135,7 @@ class RandomAccident(BaseAgent):
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
self.debug('Killing some rabbits with prob={}!'.format(prob_death)) self.debug('Killing some rabbits with prob={}!'.format(prob_death))
for i in self.iter_agents(agent_class=RabbitModel): for i in self.iter_agents(agent_class=Rabbit):
if i.state_id == i.dead.id: if i.state_id == i.dead.id:
continue continue
if self.prob(prob_death): if self.prob(prob_death):

@ -7,11 +7,10 @@ description: null
group: null group: null
interval: 1.0 interval: 1.0
max_time: 100 max_time: 100
model_class: soil.environment.Environment model_class: rabbit_agents.RabbitEnv
model_params: model_params:
agents: agents:
topology: true topology: true
agent_class: rabbit_agents.RabbitModel
distribution: distribution:
- agent_class: rabbit_agents.Male - agent_class: rabbit_agents.Male
weight: 1 weight: 1
@ -34,5 +33,10 @@ model_params:
nodes: nodes:
- id: 1 - id: 1
- id: 0 - id: 0
model_reporters:
num_males: 'num_males'
num_females: 'num_females'
num_rabbits: |
py:lambda env: env.num_males + env.num_females
extra: extra:
visualization_params: {} visualization_params: {}

@ -10,7 +10,7 @@ import networkx as nx
from .serialization import deserialize from .serialization import deserialize
from .utils import open_or_reuse, logger, timer from .utils import try_backup, open_or_reuse, logger, timer
from . import utils, network from . import utils, network
@ -91,12 +91,14 @@ class default(Exporter):
"""Default exporter. Writes sqlite results, as well as the simulation YAML""" """Default exporter. Writes sqlite results, as well as the simulation YAML"""
def sim_start(self): def sim_start(self):
if not self.dry_run: if self.dry_run:
logger.info("Dumping results to %s", self.outdir)
with self.output(self.simulation.name + ".dumped.yml") as f:
f.write(self.simulation.to_yaml())
else:
logger.info("NOT dumping results") logger.info("NOT dumping results")
return
logger.info("Dumping results to %s", self.outdir)
with self.output(self.simulation.name + ".dumped.yml") as f:
f.write(self.simulation.to_yaml())
self.dbpath = os.path.join(self.outdir, f"{self.simulation.name}.sqlite")
try_backup(self.dbpath, move=True)
def trial_end(self, env): def trial_end(self, env):
if self.dry_run: if self.dry_run:
@ -107,21 +109,23 @@ class default(Exporter):
"Dumping simulation {} trial {}".format(self.simulation.name, env.id) "Dumping simulation {} trial {}".format(self.simulation.name, env.id)
): ):
fpath = os.path.join(self.outdir, f"{env.id}.sqlite") engine = create_engine(f"sqlite:///{self.dbpath}", echo=False)
engine = create_engine(f"sqlite:///{fpath}", echo=False)
dc = env.datacollector dc = env.datacollector
for (t, df) in get_dc_dfs(dc): for (t, df) in get_dc_dfs(dc, trial_id=env.id):
df.to_sql(t, con=engine, if_exists="append") df.to_sql(t, con=engine, if_exists="append")
def get_dc_dfs(dc): def get_dc_dfs(dc, trial_id=None):
dfs = { dfs = {
"env": dc.get_model_vars_dataframe(), "env": dc.get_model_vars_dataframe(),
"agents": dc.get_agent_vars_dataframe(), "agents": dc.get_agent_vars_dataframe(),
} }
for table_name in dc.tables: for table_name in dc.tables:
dfs[table_name] = dc.get_table_dataframe(table_name) dfs[table_name] = dc.get_table_dataframe(table_name)
if trial_id:
for (name, df) in dfs.items():
df['trial_id'] = trial_id
yield from dfs.items() yield from dfs.items()
@ -135,7 +139,7 @@ class csv(Exporter):
self.simulation.name, env.id, self.outdir self.simulation.name, env.id, self.outdir
) )
): ):
for (df_name, df) in get_dc_dfs(env.datacollector): for (df_name, df) in get_dc_dfs(env.datacollector, trial_id=env.id):
with self.output("{}.{}.csv".format(env.id, df_name)) as f: with self.output("{}.{}.csv".format(env.id, df_name)) as f:
df.to_csv(f) df.to_csv(f)

@ -197,7 +197,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
return getattr(cls, "deserialize", cls) return getattr(cls, "deserialize", cls)
except (ImportError, AttributeError) as ex: except (ImportError, AttributeError) as ex:
errors.append((modname, tname, ex)) errors.append((modname, tname, ex))
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors)) raise ValueError('Could not find type "{}". Tried: {}'.format(type_, errors))
def deserialize(type_, value=None, globs=None, **kwargs): def deserialize(type_, value=None, globs=None, **kwargs):
@ -207,7 +207,13 @@ def deserialize(type_, value=None, globs=None, **kwargs):
if globs and type_ in globs: if globs and type_ in globs:
des = globs[type_] des = globs[type_]
else: else:
des = deserializer(type_, **kwargs) try:
des = deserializer(type_, **kwargs)
except ValueError as ex:
try:
des = eval(type_)
except Exception:
raise ex
if value is None: if value is None:
return des return des
return des(value) return des(value)

@ -151,7 +151,7 @@ class Simulation:
def deserialize_reporters(reporters): def deserialize_reporters(reporters):
for (k, v) in reporters.items(): for (k, v) in reporters.items():
if isinstance(v, str) and v.startswith("py:"): if isinstance(v, str) and v.startswith("py:"):
reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1]) reporters[k] = serialization.deserialize(v.split(":", 1)[1])
return reporters return reporters
params = self.model_params.copy() params = self.model_params.copy()

@ -4,7 +4,7 @@ import os
import traceback import traceback
from functools import partial from functools import partial
from shutil import copyfile from shutil import copyfile, move
from multiprocessing import Pool from multiprocessing import Pool
from contextlib import contextmanager from contextlib import contextmanager
@ -47,21 +47,34 @@ def timer(name="task", pre="", function=logger.info, to_object=None):
to_object.end = end to_object.end = end
def safe_open(path, mode="r", backup=True, **kwargs): def try_backup(path, move=False):
if not os.path.exists(path):
return None
outdir = os.path.dirname(path) outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir): if outdir and not os.path.exists(outdir):
os.makedirs(outdir) os.makedirs(outdir)
if backup and "w" in mode and os.path.exists(path): creation = os.path.getctime(path)
creation = os.path.getctime(path) stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
backup_dir = os.path.join(outdir, "backup")
backup_dir = os.path.join(outdir, "backup") if not os.path.exists(backup_dir):
if not os.path.exists(backup_dir): os.makedirs(backup_dir)
os.makedirs(backup_dir) newpath = os.path.join(
newpath = os.path.join( backup_dir, "{}@{}".format(os.path.basename(path), stamp)
backup_dir, "{}@{}".format(os.path.basename(path), stamp) )
) if move:
move(path, newpath)
else:
copyfile(path, newpath) copyfile(path, newpath)
return newpath
def safe_open(path, mode="r", backup=True, **kwargs):
outdir = os.path.dirname(path)
if outdir and not os.path.exists(outdir):
os.makedirs(outdir)
if backup and "w" in mode:
try_backup(path)
return open(path, mode=mode, **kwargs) return open(path, mode=mode, **kwargs)
@ -70,7 +83,7 @@ def open_or_reuse(f, *args, **kwargs):
try: try:
with safe_open(f, *args, **kwargs) as f: with safe_open(f, *args, **kwargs) as f:
yield f yield f
except (AttributeError, TypeError): except (AttributeError, TypeError) as ex:
yield f yield f

@ -99,7 +99,7 @@ class Exporters(TestCase):
try: try:
for e in envs: for e in envs:
db = sqlite3.connect(os.path.join(simdir, f"{e.id}.sqlite")) db = sqlite3.connect(os.path.join(simdir, f"{s.name}.sqlite"))
cur = db.cursor() cur = db.cursor()
agent_entries = cur.execute("SELECT * from agents").fetchall() agent_entries = cur.execute("SELECT * from agents").fetchall()
env_entries = cur.execute("SELECT * from env").fetchall() env_entries = cur.execute("SELECT * from env").fetchall()

Loading…
Cancel
Save