mirror of https://github.com/gsi-upm/soil
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
265 lines
7.9 KiB
Python
265 lines
7.9 KiB
Python
import os
|
|
from time import time as current_time, strftime
|
|
import importlib
|
|
import sys
|
|
import yaml
|
|
import traceback
|
|
import inspect
|
|
import logging
|
|
import networkx as nx
|
|
|
|
from textwrap import dedent
|
|
|
|
from dataclasses import dataclass, field, asdict
|
|
from typing import Any, Dict, Union, Optional, List
|
|
|
|
|
|
from networkx.readwrite import json_graph
|
|
from functools import partial
|
|
import pickle
|
|
|
|
from . import serialization, exporters, utils, basestring, agents
|
|
from .environment import Environment
|
|
from .utils import logger, run_and_return_exceptions
|
|
from .time import INFINITY
|
|
from .config import Config, convert_old
|
|
|
|
|
|
# TODO: change documentation for simulation
|
|
@dataclass
|
|
class Simulation:
|
|
"""
|
|
Parameters
|
|
---------
|
|
config (optional): :class:`config.Config`
|
|
name of the Simulation
|
|
|
|
kwargs: parameters to use to initialize a new configuration, if one not been provided.
|
|
"""
|
|
|
|
version: str = "2"
|
|
name: str = "Unnamed simulation"
|
|
description: Optional[str] = ""
|
|
group: str = None
|
|
model_class: Union[str, type] = "soil.Environment"
|
|
model_params: dict = field(default_factory=dict)
|
|
seed: str = field(default_factory=lambda: current_time())
|
|
dir_path: str = field(default_factory=lambda: os.getcwd())
|
|
max_time: float = float("inf")
|
|
max_steps: int = -1
|
|
interval: int = 1
|
|
num_trials: int = 3
|
|
parallel: Optional[bool] = None
|
|
exporters: Optional[List[str]] = field(default_factory=list)
|
|
outdir: Optional[str] = None
|
|
exporter_params: Optional[Dict[str, Any]] = field(default_factory=dict)
|
|
dry_run: bool = False
|
|
extra: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
@classmethod
|
|
def from_dict(cls, env, **kwargs):
|
|
|
|
ignored = {
|
|
k: v for k, v in env.items() if k not in inspect.signature(cls).parameters
|
|
}
|
|
|
|
d = {k: v for k, v in env.items() if k not in ignored}
|
|
if ignored:
|
|
d.setdefault("extra", {}).update(ignored)
|
|
if ignored:
|
|
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
|
|
d.update(kwargs)
|
|
|
|
return cls(**d)
|
|
|
|
def run_simulation(self, *args, **kwargs):
|
|
return self.run(*args, **kwargs)
|
|
|
|
def run(self, *args, **kwargs):
|
|
"""Run the simulation and return the list of resulting environments"""
|
|
logger.info(
|
|
dedent(
|
|
"""
|
|
Simulation:
|
|
---
|
|
"""
|
|
)
|
|
+ self.to_yaml()
|
|
)
|
|
return list(self.run_gen(*args, **kwargs))
|
|
|
|
def run_gen(
|
|
self,
|
|
parallel=False,
|
|
dry_run=None,
|
|
exporters=None,
|
|
outdir=None,
|
|
exporter_params={},
|
|
log_level=None,
|
|
**kwargs,
|
|
):
|
|
"""Run the simulation and yield the resulting environments."""
|
|
if log_level:
|
|
logger.setLevel(log_level)
|
|
outdir = outdir or self.outdir
|
|
logger.info("Using exporters: %s", exporters or [])
|
|
logger.info("Output directory: %s", outdir)
|
|
if dry_run is None:
|
|
dry_run = self.dry_run
|
|
if exporters is None:
|
|
exporters = self.exporters
|
|
if not exporter_params:
|
|
exporter_params = self.exporter_params
|
|
|
|
exporters = serialization.deserialize_all(
|
|
exporters,
|
|
simulation=self,
|
|
known_modules=[
|
|
"soil.exporters",
|
|
],
|
|
dry_run=dry_run,
|
|
outdir=outdir,
|
|
**exporter_params,
|
|
)
|
|
|
|
with utils.timer("simulation {}".format(self.name)):
|
|
for exporter in exporters:
|
|
exporter.sim_start()
|
|
|
|
for env in utils.run_parallel(
|
|
func=self.run_trial,
|
|
iterable=range(int(self.num_trials)),
|
|
parallel=parallel,
|
|
log_level=log_level,
|
|
**kwargs,
|
|
):
|
|
|
|
for exporter in exporters:
|
|
exporter.trial_start(env)
|
|
|
|
for exporter in exporters:
|
|
exporter.trial_end(env)
|
|
|
|
yield env
|
|
|
|
for exporter in exporters:
|
|
exporter.sim_end()
|
|
|
|
def get_env(self, trial_id=0, model_params=None, **kwargs):
|
|
"""Create an environment for a trial of the simulation"""
|
|
|
|
def deserialize_reporters(reporters):
|
|
for (k, v) in reporters.items():
|
|
if isinstance(v, str) and v.startswith("py:"):
|
|
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
|
|
return reporters
|
|
|
|
params = self.model_params.copy()
|
|
if model_params:
|
|
params.update(model_params)
|
|
params.update(kwargs)
|
|
|
|
agent_reporters = deserialize_reporters(params.pop("agent_reporters", {}))
|
|
model_reporters = deserialize_reporters(params.pop("model_reporters", {}))
|
|
|
|
env = serialization.deserialize(self.model_class)
|
|
return env(
|
|
id=f"{self.name}_trial_{trial_id}",
|
|
seed=f"{self.seed}_trial_{trial_id}",
|
|
dir_path=self.dir_path,
|
|
agent_reporters=agent_reporters,
|
|
model_reporters=model_reporters,
|
|
**params,
|
|
)
|
|
|
|
def run_trial(
|
|
self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts
|
|
):
|
|
"""
|
|
Run a single trial of the simulation
|
|
|
|
"""
|
|
if log_level:
|
|
logger.setLevel(log_level)
|
|
model = self.get_env(trial_id, **opts)
|
|
trial_id = trial_id if trial_id is not None else current_time()
|
|
with utils.timer("Simulation {} trial {}".format(self.name, trial_id)):
|
|
return self.run_model(
|
|
model=model, trial_id=trial_id, until=until, log_level=log_level
|
|
)
|
|
|
|
def run_model(self, model, until=None, **opts):
|
|
# Set-up trial environment and graph
|
|
until = float(until or self.max_time or "inf")
|
|
|
|
# Set up agents on nodes
|
|
def is_done():
|
|
return False
|
|
|
|
if until and hasattr(model.schedule, "time"):
|
|
prev = is_done
|
|
|
|
def is_done():
|
|
return prev() or model.schedule.time >= until
|
|
|
|
if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, "steps"):
|
|
prev_steps = is_done
|
|
|
|
def is_done():
|
|
return prev_steps() or model.schedule.steps >= self.max_steps
|
|
|
|
newline = "\n"
|
|
logger.info(
|
|
dedent(
|
|
f"""
|
|
Model stats:
|
|
Agents (total: { model.schedule.get_agent_count() }):
|
|
- { (newline + ' - ').join(str(a) for a in model.schedule.agents) }
|
|
|
|
Topology size: { len(model.G) if hasattr(model, "G") else 0 }
|
|
"""
|
|
)
|
|
)
|
|
|
|
while not is_done():
|
|
utils.logger.debug(
|
|
f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}'
|
|
)
|
|
model.step()
|
|
return model
|
|
|
|
def to_dict(self):
|
|
d = asdict(self)
|
|
if not isinstance(d["model_class"], str):
|
|
d["model_class"] = serialization.name(d["model_class"])
|
|
d["model_params"] = serialization.serialize_dict(d["model_params"])
|
|
d["dir_path"] = str(d["dir_path"])
|
|
d["version"] = "2"
|
|
return d
|
|
|
|
def to_yaml(self):
|
|
return yaml.dump(self.to_dict())
|
|
|
|
|
|
def iter_from_config(*cfgs, **kwargs):
|
|
for config in cfgs:
|
|
configs = list(serialization.load_config(config))
|
|
for config, path in configs:
|
|
d = dict(config)
|
|
if "dir_path" not in d:
|
|
d["dir_path"] = os.path.dirname(path)
|
|
yield Simulation.from_dict(d, **kwargs)
|
|
|
|
|
|
def from_config(conf_or_path):
|
|
lst = list(iter_from_config(conf_or_path))
|
|
if len(lst) > 1:
|
|
raise AttributeError("Provide only one configuration")
|
|
return lst[0]
|
|
|
|
|
|
def run_from_config(*configs, **kwargs):
|
|
for sim in iter_from_config(*configs):
|
|
logger.info(f"Using config(s): {sim.name}")
|
|
sim.run_simulation(**kwargs)
|