mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
Improve exporters
This commit is contained in:
parent
78833a9e08
commit
0efcd24d90
@ -1,4 +1,4 @@
|
|||||||
from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent
|
from soil import FSM, state, default_state, BaseAgent, NetworkAgent, Environment
|
||||||
from soil.time import Delta
|
from soil.time import Delta
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
@ -6,7 +6,23 @@ import logging
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
class RabbitModel(FSM, NetworkAgent):
|
class RabbitEnv(Environment):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_rabbits(self):
|
||||||
|
return self.count_agents(agent_class=Rabbit)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_males(self):
|
||||||
|
return self.count_agents(agent_class=Male)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_females(self):
|
||||||
|
return self.count_agents(agent_class=Female)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Rabbit(FSM, NetworkAgent):
|
||||||
|
|
||||||
sexual_maturity = 30
|
sexual_maturity = 30
|
||||||
life_expectancy = 300
|
life_expectancy = 300
|
||||||
@ -35,7 +51,7 @@ class RabbitModel(FSM, NetworkAgent):
|
|||||||
self.die()
|
self.die()
|
||||||
|
|
||||||
|
|
||||||
class Male(RabbitModel):
|
class Male(Rabbit):
|
||||||
max_females = 5
|
max_females = 5
|
||||||
mating_prob = 0.001
|
mating_prob = 0.001
|
||||||
|
|
||||||
@ -56,7 +72,7 @@ class Male(RabbitModel):
|
|||||||
break # Take a break
|
break # Take a break
|
||||||
|
|
||||||
|
|
||||||
class Female(RabbitModel):
|
class Female(Rabbit):
|
||||||
gestation = 30
|
gestation = 30
|
||||||
|
|
||||||
@state
|
@state
|
||||||
@ -119,7 +135,7 @@ class RandomAccident(BaseAgent):
|
|||||||
|
|
||||||
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
prob_death = self.model.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive)))
|
||||||
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
self.debug('Killing some rabbits with prob={}!'.format(prob_death))
|
||||||
for i in self.iter_agents(agent_class=RabbitModel):
|
for i in self.iter_agents(agent_class=Rabbit):
|
||||||
if i.state_id == i.dead.id:
|
if i.state_id == i.dead.id:
|
||||||
continue
|
continue
|
||||||
if self.prob(prob_death):
|
if self.prob(prob_death):
|
||||||
|
@ -7,11 +7,10 @@ description: null
|
|||||||
group: null
|
group: null
|
||||||
interval: 1.0
|
interval: 1.0
|
||||||
max_time: 100
|
max_time: 100
|
||||||
model_class: soil.environment.Environment
|
model_class: rabbit_agents.RabbitEnv
|
||||||
model_params:
|
model_params:
|
||||||
agents:
|
agents:
|
||||||
topology: true
|
topology: true
|
||||||
agent_class: rabbit_agents.RabbitModel
|
|
||||||
distribution:
|
distribution:
|
||||||
- agent_class: rabbit_agents.Male
|
- agent_class: rabbit_agents.Male
|
||||||
weight: 1
|
weight: 1
|
||||||
@ -34,5 +33,10 @@ model_params:
|
|||||||
nodes:
|
nodes:
|
||||||
- id: 1
|
- id: 1
|
||||||
- id: 0
|
- id: 0
|
||||||
|
model_reporters:
|
||||||
|
num_males: 'num_males'
|
||||||
|
num_females: 'num_females'
|
||||||
|
num_rabbits: |
|
||||||
|
py:lambda env: env.num_males + env.num_females
|
||||||
extra:
|
extra:
|
||||||
visualization_params: {}
|
visualization_params: {}
|
||||||
|
@ -10,7 +10,7 @@ import networkx as nx
|
|||||||
|
|
||||||
|
|
||||||
from .serialization import deserialize
|
from .serialization import deserialize
|
||||||
from .utils import open_or_reuse, logger, timer
|
from .utils import try_backup, open_or_reuse, logger, timer
|
||||||
|
|
||||||
|
|
||||||
from . import utils, network
|
from . import utils, network
|
||||||
@ -91,12 +91,14 @@ class default(Exporter):
|
|||||||
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
|
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
|
||||||
|
|
||||||
def sim_start(self):
|
def sim_start(self):
|
||||||
if not self.dry_run:
|
if self.dry_run:
|
||||||
|
logger.info("NOT dumping results")
|
||||||
|
return
|
||||||
logger.info("Dumping results to %s", self.outdir)
|
logger.info("Dumping results to %s", self.outdir)
|
||||||
with self.output(self.simulation.name + ".dumped.yml") as f:
|
with self.output(self.simulation.name + ".dumped.yml") as f:
|
||||||
f.write(self.simulation.to_yaml())
|
f.write(self.simulation.to_yaml())
|
||||||
else:
|
self.dbpath = os.path.join(self.outdir, f"{self.simulation.name}.sqlite")
|
||||||
logger.info("NOT dumping results")
|
try_backup(self.dbpath, move=True)
|
||||||
|
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
@ -107,21 +109,23 @@ class default(Exporter):
|
|||||||
"Dumping simulation {} trial {}".format(self.simulation.name, env.id)
|
"Dumping simulation {} trial {}".format(self.simulation.name, env.id)
|
||||||
):
|
):
|
||||||
|
|
||||||
fpath = os.path.join(self.outdir, f"{env.id}.sqlite")
|
engine = create_engine(f"sqlite:///{self.dbpath}", echo=False)
|
||||||
engine = create_engine(f"sqlite:///{fpath}", echo=False)
|
|
||||||
|
|
||||||
dc = env.datacollector
|
dc = env.datacollector
|
||||||
for (t, df) in get_dc_dfs(dc):
|
for (t, df) in get_dc_dfs(dc, trial_id=env.id):
|
||||||
df.to_sql(t, con=engine, if_exists="append")
|
df.to_sql(t, con=engine, if_exists="append")
|
||||||
|
|
||||||
|
|
||||||
def get_dc_dfs(dc):
|
def get_dc_dfs(dc, trial_id=None):
|
||||||
dfs = {
|
dfs = {
|
||||||
"env": dc.get_model_vars_dataframe(),
|
"env": dc.get_model_vars_dataframe(),
|
||||||
"agents": dc.get_agent_vars_dataframe(),
|
"agents": dc.get_agent_vars_dataframe(),
|
||||||
}
|
}
|
||||||
for table_name in dc.tables:
|
for table_name in dc.tables:
|
||||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||||
|
if trial_id:
|
||||||
|
for (name, df) in dfs.items():
|
||||||
|
df['trial_id'] = trial_id
|
||||||
yield from dfs.items()
|
yield from dfs.items()
|
||||||
|
|
||||||
|
|
||||||
@ -135,7 +139,7 @@ class csv(Exporter):
|
|||||||
self.simulation.name, env.id, self.outdir
|
self.simulation.name, env.id, self.outdir
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
for (df_name, df) in get_dc_dfs(env.datacollector):
|
for (df_name, df) in get_dc_dfs(env.datacollector, trial_id=env.id):
|
||||||
with self.output("{}.{}.csv".format(env.id, df_name)) as f:
|
with self.output("{}.{}.csv".format(env.id, df_name)) as f:
|
||||||
df.to_csv(f)
|
df.to_csv(f)
|
||||||
|
|
||||||
|
@ -197,7 +197,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
|||||||
return getattr(cls, "deserialize", cls)
|
return getattr(cls, "deserialize", cls)
|
||||||
except (ImportError, AttributeError) as ex:
|
except (ImportError, AttributeError) as ex:
|
||||||
errors.append((modname, tname, ex))
|
errors.append((modname, tname, ex))
|
||||||
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors))
|
raise ValueError('Could not find type "{}". Tried: {}'.format(type_, errors))
|
||||||
|
|
||||||
|
|
||||||
def deserialize(type_, value=None, globs=None, **kwargs):
|
def deserialize(type_, value=None, globs=None, **kwargs):
|
||||||
@ -207,7 +207,13 @@ def deserialize(type_, value=None, globs=None, **kwargs):
|
|||||||
if globs and type_ in globs:
|
if globs and type_ in globs:
|
||||||
des = globs[type_]
|
des = globs[type_]
|
||||||
else:
|
else:
|
||||||
|
try:
|
||||||
des = deserializer(type_, **kwargs)
|
des = deserializer(type_, **kwargs)
|
||||||
|
except ValueError as ex:
|
||||||
|
try:
|
||||||
|
des = eval(type_)
|
||||||
|
except Exception:
|
||||||
|
raise ex
|
||||||
if value is None:
|
if value is None:
|
||||||
return des
|
return des
|
||||||
return des(value)
|
return des(value)
|
||||||
|
@ -151,7 +151,7 @@ class Simulation:
|
|||||||
def deserialize_reporters(reporters):
|
def deserialize_reporters(reporters):
|
||||||
for (k, v) in reporters.items():
|
for (k, v) in reporters.items():
|
||||||
if isinstance(v, str) and v.startswith("py:"):
|
if isinstance(v, str) and v.startswith("py:"):
|
||||||
reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1])
|
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
|
||||||
return reporters
|
return reporters
|
||||||
|
|
||||||
params = self.model_params.copy()
|
params = self.model_params.copy()
|
||||||
|
@ -4,7 +4,7 @@ import os
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from shutil import copyfile
|
from shutil import copyfile, move
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
@ -47,11 +47,12 @@ def timer(name="task", pre="", function=logger.info, to_object=None):
|
|||||||
to_object.end = end
|
to_object.end = end
|
||||||
|
|
||||||
|
|
||||||
def safe_open(path, mode="r", backup=True, **kwargs):
|
def try_backup(path, move=False):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return None
|
||||||
outdir = os.path.dirname(path)
|
outdir = os.path.dirname(path)
|
||||||
if outdir and not os.path.exists(outdir):
|
if outdir and not os.path.exists(outdir):
|
||||||
os.makedirs(outdir)
|
os.makedirs(outdir)
|
||||||
if backup and "w" in mode and os.path.exists(path):
|
|
||||||
creation = os.path.getctime(path)
|
creation = os.path.getctime(path)
|
||||||
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
|
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
|
||||||
|
|
||||||
@ -61,7 +62,19 @@ def safe_open(path, mode="r", backup=True, **kwargs):
|
|||||||
newpath = os.path.join(
|
newpath = os.path.join(
|
||||||
backup_dir, "{}@{}".format(os.path.basename(path), stamp)
|
backup_dir, "{}@{}".format(os.path.basename(path), stamp)
|
||||||
)
|
)
|
||||||
|
if move:
|
||||||
|
move(path, newpath)
|
||||||
|
else:
|
||||||
copyfile(path, newpath)
|
copyfile(path, newpath)
|
||||||
|
return newpath
|
||||||
|
|
||||||
|
|
||||||
|
def safe_open(path, mode="r", backup=True, **kwargs):
|
||||||
|
outdir = os.path.dirname(path)
|
||||||
|
if outdir and not os.path.exists(outdir):
|
||||||
|
os.makedirs(outdir)
|
||||||
|
if backup and "w" in mode:
|
||||||
|
try_backup(path)
|
||||||
return open(path, mode=mode, **kwargs)
|
return open(path, mode=mode, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@ -70,7 +83,7 @@ def open_or_reuse(f, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
with safe_open(f, *args, **kwargs) as f:
|
with safe_open(f, *args, **kwargs) as f:
|
||||||
yield f
|
yield f
|
||||||
except (AttributeError, TypeError):
|
except (AttributeError, TypeError) as ex:
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ class Exporters(TestCase):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
for e in envs:
|
for e in envs:
|
||||||
db = sqlite3.connect(os.path.join(simdir, f"{e.id}.sqlite"))
|
db = sqlite3.connect(os.path.join(simdir, f"{s.name}.sqlite"))
|
||||||
cur = db.cursor()
|
cur = db.cursor()
|
||||||
agent_entries = cur.execute("SELECT * from agents").fetchall()
|
agent_entries = cur.execute("SELECT * from agents").fetchall()
|
||||||
env_entries = cur.execute("SELECT * from env").fetchall()
|
env_entries = cur.execute("SELECT * from env").fetchall()
|
||||||
|
Loading…
Reference in New Issue
Block a user