mirror of
https://github.com/gsi-upm/soil
synced 2024-11-21 18:52:28 +00:00
Improve exporters
This commit is contained in:
parent
78833a9e08
commit
0efcd24d90
@ -1,4 +1,4 @@
|
||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
||||
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment
|
||||
from soil.time import Delta
|
||||
from enum import Enum
|
||||
from collections import Counter
|
||||
@ -6,7 +6,23 @@ import logging
|
||||
import math
|
||||
|
||||
|
||||
class RabbitModel(FSM, NetworkAgent):
|
||||
class RabbitEnv(Environment):
|
||||
|
||||
@property
|
||||
def num_rabbits(self):
|
||||
return self.count_agents(agent_class=Rabbit)
|
||||
|
||||
@property
|
||||
def num_males(self):
|
||||
return self.count_agents(agent_class=Male)
|
||||
|
||||
@property
|
||||
def num_females(self):
|
||||
return self.count_agents(agent_class=Female)
|
||||
|
||||
|
||||
|
||||
class Rabbit(FSM, NetworkAgent):
|
||||
|
||||
sexual_maturity = 30
|
||||
life_expectancy = 300
|
||||
@ -35,7 +51,7 @@ class RabbitModel(FSM, NetworkAgent):
|
||||
self.die()
|
||||
|
||||
|
||||
class Male(RabbitModel):
|
||||
class Male(Rabbit):
|
||||
max_females = 5
|
||||
mating_prob = 0.001
|
||||
|
||||
@ -56,7 +72,7 @@ class Male(RabbitModel):
|
||||
break # Take a break
|
||||
|
||||
|
||||
class Female(RabbitModel):
|
||||
class Female(Rabbit):
|
||||
gestation = 30
|
||||
|
||||
@state
|
||||
@ -119,7 +135,7 @@ class RandomAccident(BaseAgent):
|
||||
|
||||
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||
for i in self.iter_agents(agent_class=RabbitModel):
|
||||
for i in self.iter_agents(agent_class=Rabbit):
|
||||
if i.state_id == i.dead.id:
|
||||
continue
|
||||
if self.prob(prob_death):
|
||||
|
@ -7,11 +7,10 @@ description: null
|
||||
group: null
|
||||
interval: 1.0
|
||||
max_time: 100
|
||||
model_class: soil.environment.Environment
|
||||
model_class: rabbit_agents.RabbitEnv
|
||||
model_params:
|
||||
agents:
|
||||
topology: true
|
||||
agent_class: rabbit_agents.RabbitModel
|
||||
distribution:
|
||||
- agent_class: rabbit_agents.Male
|
||||
weight: 1
|
||||
@ -34,5 +33,10 @@ model_params:
|
||||
nodes:
|
||||
- id: 1
|
||||
- id: 0
|
||||
model_reporters:
|
||||
num_males: 'num_males'
|
||||
num_females: 'num_females'
|
||||
num_rabbits: |
|
||||
py:lambda env: env.num_males + env.num_females
|
||||
extra:
|
||||
visualization_params: {}
|
||||
|
@ -10,7 +10,7 @@ import networkx as nx
|
||||
|
||||
|
||||
from .serialization import deserialize
|
||||
from .utils import open_or_reuse, logger, timer
|
||||
from .utils import try_backup, open_or_reuse, logger, timer
|
||||
|
||||
|
||||
from . import utils, network
|
||||
@ -91,12 +91,14 @@ class default(Exporter):
|
||||
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
|
||||
|
||||
def sim_start(self):
|
||||
if not self.dry_run:
|
||||
logger.info("Dumping results to %s", self.outdir)
|
||||
with self.output(self.simulation.name + ".dumped.yml") as f:
|
||||
f.write(self.simulation.to_yaml())
|
||||
else:
|
||||
if self.dry_run:
|
||||
logger.info("NOT dumping results")
|
||||
return
|
||||
logger.info("Dumping results to %s", self.outdir)
|
||||
with self.output(self.simulation.name + ".dumped.yml") as f:
|
||||
f.write(self.simulation.to_yaml())
|
||||
self.dbpath = os.path.join(self.outdir, f"{self.simulation.name}.sqlite")
|
||||
try_backup(self.dbpath, move=True)
|
||||
|
||||
def trial_end(self, env):
|
||||
if self.dry_run:
|
||||
@ -107,21 +109,23 @@ class default(Exporter):
|
||||
"Dumping simulation {} trial {}".format(self.simulation.name, env.id)
|
||||
):
|
||||
|
||||
fpath = os.path.join(self.outdir, f"{env.id}.sqlite")
|
||||
engine = create_engine(f"sqlite:///{fpath}", echo=False)
|
||||
engine = create_engine(f"sqlite:///{self.dbpath}", echo=False)
|
||||
|
||||
dc = env.datacollector
|
||||
for (t, df) in get_dc_dfs(dc):
|
||||
for (t, df) in get_dc_dfs(dc, trial_id=env.id):
|
||||
df.to_sql(t, con=engine, if_exists="append")
|
||||
|
||||
|
||||
def get_dc_dfs(dc):
|
||||
def get_dc_dfs(dc, trial_id=None):
|
||||
dfs = {
|
||||
"env": dc.get_model_vars_dataframe(),
|
||||
"agents": dc.get_agent_vars_dataframe(),
|
||||
}
|
||||
for table_name in dc.tables:
|
||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||
if trial_id:
|
||||
for (name, df) in dfs.items():
|
||||
df['trial_id'] = trial_id
|
||||
yield from dfs.items()
|
||||
|
||||
|
||||
@ -135,7 +139,7 @@ class csv(Exporter):
|
||||
self.simulation.name, env.id, self.outdir
|
||||
)
|
||||
):
|
||||
for (df_name, df) in get_dc_dfs(env.datacollector):
|
||||
for (df_name, df) in get_dc_dfs(env.datacollector, trial_id=env.id):
|
||||
with self.output("{}.{}.csv".format(env.id, df_name)) as f:
|
||||
df.to_csv(f)
|
||||
|
||||
|
@ -197,7 +197,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||
return getattr(cls, "deserialize", cls)
|
||||
except (ImportError, AttributeError) as ex:
|
||||
errors.append((modname, tname, ex))
|
||||
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors))
|
||||
raise ValueError('Could not find type "{}". Tried: {}'.format(type_, errors))
|
||||
|
||||
|
||||
def deserialize(type_, value=None, globs=None, **kwargs):
|
||||
@ -207,7 +207,13 @@ def deserialize(type_, value=None, globs=None, **kwargs):
|
||||
if globs and type_ in globs:
|
||||
des = globs[type_]
|
||||
else:
|
||||
des = deserializer(type_, **kwargs)
|
||||
try:
|
||||
des = deserializer(type_, **kwargs)
|
||||
except ValueError as ex:
|
||||
try:
|
||||
des = eval(type_)
|
||||
except Exception:
|
||||
raise ex
|
||||
if value is None:
|
||||
return des
|
||||
return des(value)
|
||||
|
@ -151,7 +151,7 @@ class Simulation:
|
||||
def deserialize_reporters(reporters):
|
||||
for (k, v) in reporters.items():
|
||||
if isinstance(v, str) and v.startswith("py:"):
|
||||
reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1])
|
||||
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
|
||||
return reporters
|
||||
|
||||
params = self.model_params.copy()
|
||||
|
@ -4,7 +4,7 @@ import os
|
||||
import traceback
|
||||
|
||||
from functools import partial
|
||||
from shutil import copyfile
|
||||
from shutil import copyfile, move
|
||||
from multiprocessing import Pool
|
||||
|
||||
from contextlib import contextmanager
|
||||
@ -47,21 +47,34 @@ def timer(name="task", pre="", function=logger.info, to_object=None):
|
||||
to_object.end = end
|
||||
|
||||
|
||||
def try_backup(path, move=False):
|
||||
if not os.path.exists(path):
|
||||
return None
|
||||
outdir = os.path.dirname(path)
|
||||
if outdir and not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
creation = os.path.getctime(path)
|
||||
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
|
||||
|
||||
backup_dir = os.path.join(outdir, "backup")
|
||||
if not os.path.exists(backup_dir):
|
||||
os.makedirs(backup_dir)
|
||||
newpath = os.path.join(
|
||||
backup_dir, "{}@{}".format(os.path.basename(path), stamp)
|
||||
)
|
||||
if move:
|
||||
move(path, newpath)
|
||||
else:
|
||||
copyfile(path, newpath)
|
||||
return newpath
|
||||
|
||||
|
||||
def safe_open(path, mode="r", backup=True, **kwargs):
|
||||
outdir = os.path.dirname(path)
|
||||
if outdir and not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
if backup and "w" in mode and os.path.exists(path):
|
||||
creation = os.path.getctime(path)
|
||||
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
|
||||
|
||||
backup_dir = os.path.join(outdir, "backup")
|
||||
if not os.path.exists(backup_dir):
|
||||
os.makedirs(backup_dir)
|
||||
newpath = os.path.join(
|
||||
backup_dir, "{}@{}".format(os.path.basename(path), stamp)
|
||||
)
|
||||
copyfile(path, newpath)
|
||||
if backup and "w" in mode:
|
||||
try_backup(path)
|
||||
return open(path, mode=mode, **kwargs)
|
||||
|
||||
|
||||
@ -70,7 +83,7 @@ def open_or_reuse(f, *args, **kwargs):
|
||||
try:
|
||||
with safe_open(f, *args, **kwargs) as f:
|
||||
yield f
|
||||
except (AttributeError, TypeError):
|
||||
except (AttributeError, TypeError) as ex:
|
||||
yield f
|
||||
|
||||
|
||||
|
@ -99,7 +99,7 @@ class Exporters(TestCase):
|
||||
|
||||
try:
|
||||
for e in envs:
|
||||
db = sqlite3.connect(os.path.join(simdir, f"{e.id}.sqlite"))
|
||||
db = sqlite3.connect(os.path.join(simdir, f"{s.name}.sqlite"))
|
||||
cur = db.cursor()
|
||||
agent_entries = cur.execute("SELECT * from agents").fetchall()
|
||||
env_entries = cur.execute("SELECT * from env").fetchall()
|
||||
|
Loading…
Reference in New Issue
Block a user