From 5d759d0072568815fd8a806066961c0fa5024e04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Mon, 17 Oct 2022 13:58:14 +0200 Subject: [PATCH] Add conditional time values --- soil/environment.py | 2 +- soil/exporters.py | 93 +++++++++------------------- soil/simulation.py | 6 +- soil/time.py | 132 ++++++++++++++++++++++++++++++++-------- soil/utils.py | 2 +- tests/test_agents.py | 2 +- tests/test_exporters.py | 1 - 7 files changed, 141 insertions(+), 97 deletions(-) diff --git a/soil/environment.py b/soil/environment.py index d89585e..238c494 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -169,7 +169,7 @@ class BaseEnvironment(Model): Advance one step in the simulation, and update the data collection and scheduler appropriately """ super().step() - self.logger.info(f"--- Step {self.now:^5} ---") + self.logger.info(f"--- Step: {self.schedule.steps:^5} - Time: {self.now:^5} ---") self.schedule.step() self.datacollector.collect(self) diff --git a/soil/exporters.py b/soil/exporters.py index 55a5597..405b2f8 100644 --- a/soil/exporters.py +++ b/soil/exporters.py @@ -3,6 +3,7 @@ import sys from time import time as current_time from io import BytesIO from sqlalchemy import create_engine +from textwrap import dedent, indent import matplotlib.pyplot as plt @@ -86,6 +87,22 @@ class Exporter: pass return open_or_reuse(f, mode=mode, **kwargs) + def get_dfs(self, env): + yield from get_dc_dfs(env.datacollector, trial_id=env.id) + + +def get_dc_dfs(dc, trial_id=None): + dfs = { + "env": dc.get_model_vars_dataframe(), + "agents": dc.get_agent_vars_dataframe(), + } + for table_name in dc.tables: + dfs[table_name] = dc.get_table_dataframe(table_name) + if trial_id: + for (name, df) in dfs.items(): + df["trial_id"] = trial_id + yield from dfs.items() + class default(Exporter): """Default exporter. Writes sqlite results, as well as the simulation YAML""" @@ -98,7 +115,7 @@ class default(Exporter): with self.output(self.simulation.name + ".dumped.yml") as f: f.write(self.simulation.to_yaml()) self.dbpath = os.path.join(self.outdir, f"{self.simulation.name}.sqlite") - try_backup(self.dbpath, move=True) + try_backup(self.dbpath, remove=True) def trial_end(self, env): if self.dry_run: @@ -111,24 +128,10 @@ class default(Exporter): engine = create_engine(f"sqlite:///{self.dbpath}", echo=False) - dc = env.datacollector - for (t, df) in get_dc_dfs(dc, trial_id=env.id): + for (t, df) in self.get_dfs(env): df.to_sql(t, con=engine, if_exists="append") -def get_dc_dfs(dc, trial_id=None): - dfs = { - "env": dc.get_model_vars_dataframe(), - "agents": dc.get_agent_vars_dataframe(), - } - for table_name in dc.tables: - dfs[table_name] = dc.get_table_dataframe(table_name) - if trial_id: - for (name, df) in dfs.items(): - df["trial_id"] = trial_id - yield from dfs.items() - - class csv(Exporter): """Export the state of each environment (and its agents) in a separate CSV file""" @@ -139,7 +142,7 @@ class csv(Exporter): self.simulation.name, env.id, self.outdir ) ): - for (df_name, df) in get_dc_dfs(env.datacollector, trial_id=env.id): + for (df_name, df) in self.get_dfs(env): with self.output("{}.{}.csv".format(env.id, df_name)) as f: df.to_csv(f) @@ -192,52 +195,14 @@ class graphdrawing(Exporter): f.savefig(f) -""" -Convert an environment into a NetworkX graph -""" - - -def env_to_graph(env, history=None): - G = nx.Graph(env.G) - - for agent in env.network_agents: +class summary(Exporter): + """Print a summary of each trial to sys.stdout""" - attributes = {"agent": str(agent.__class__)} - lastattributes = {} - spells = [] - lastvisible = False - laststep = None - if not history: - history = sorted(list(env.state_to_tuples())) - for _, t_step, attribute, value in history: - if attribute == "visible": - nowvisible = value - if nowvisible and not lastvisible: - laststep = t_step - if not nowvisible and lastvisible: - spells.append((laststep, t_step)) - - lastvisible = nowvisible + def trial_end(self, env): + for (t, df) in self.get_dfs(env): + if not len(df): continue - key = "attr_" + attribute - if key not in attributes: - attributes[key] = list() - if key not in lastattributes: - lastattributes[key] = (value, t_step) - elif lastattributes[key][0] != value: - last_value, laststep = lastattributes[key] - commit_value = (last_value, laststep, t_step) - if key not in attributes: - attributes[key] = list() - attributes[key].append(commit_value) - lastattributes[key] = (value, t_step) - for k, v in lastattributes.items(): - attributes[k].append((v[0], v[1], None)) - if lastvisible: - spells.append((laststep, None)) - if spells: - G.add_node(agent.id, spells=spells, **attributes) - else: - G.add_node(agent.id, **attributes) - - return G + msg = indent(str(df.describe()), ' ') + logger.info(dedent(f''' + Dataframe {t}: + ''') + msg) diff --git a/soil/simulation.py b/soil/simulation.py index e5f5526..946023f 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -21,7 +21,6 @@ import pickle from . import serialization, exporters, utils, basestring, agents from .environment import Environment from .utils import logger, run_and_return_exceptions -from .time import INFINITY from .config import Config, convert_old @@ -194,7 +193,7 @@ class Simulation: # Set up agents on nodes def is_done(): - return False + return not model.running if until and hasattr(model.schedule, "time"): prev = is_done @@ -226,6 +225,9 @@ Model stats: f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}' ) model.step() + + if model.schedule.time < until: # Simulation ended (no more steps) before until (i.e., no changes expected) + model.schedule.time = until return model def to_dict(self): diff --git a/soil/time.py b/soil/time.py index 11e3178..661e35e 100644 --- a/soil/time.py +++ b/soil/time.py @@ -2,6 +2,10 @@ from mesa.time import BaseScheduler from queue import Empty from heapq import heappush, heappop, heapify import math + +from inspect import getsource +from numbers import Number + from .utils import logger from mesa import Agent as MesaAgent @@ -15,9 +19,55 @@ class When: return time self._time = time - def abs(self, time): + def next(self, time): return self._time + def abs(self, time): + return self + + def __repr__(self): + return str(f"When({self._time})") + + def __lt__(self, other): + if isinstance(other, Number): + return self._time < other + return self._time < other.next(self._time) + + def __gt__(self, other): + if isinstance(other, Number): + return self._time > other + return self._time > other.next(self._time) + + def ready(self, time): + return self._time <= time + + +class Cond(When): + def __init__(self, func, delta=1): + self._func = func + self._delta = delta + + def next(self, time): + return time + self._delta + + def abs(self, time): + return self + + def ready(self, time): + return self._func(time) + + def __eq__(self, other): + return False + + def __lt__(self, other): + return True + + def __gt__(self, other): + return False + + def __repr__(self): + return str(f'Cond("{getsource(self._func)}")') + NEVER = When(INFINITY) @@ -27,11 +77,19 @@ class Delta(When): self._delta = delta def __eq__(self, other): - return self._delta == other._delta + if isinstance(other, Delta): + return self._delta == other._delta + return False def abs(self, time): + return When(self._delta + time) + + def next(self, time): return time + self._delta + def __repr__(self): + return str(f"Delta({self._delta})") + class TimedActivation(BaseScheduler): """A scheduler which activates each agent when the agent requests. @@ -47,14 +105,15 @@ class TimedActivation(BaseScheduler): def add(self, agent: MesaAgent, when=None): if when is None: - when = self.time + when = When(self.time) + elif not isinstance(when, When): + when = When(when) if agent.unique_id in self._agents: - self._queue.remove((self._next[agent.unique_id], agent.unique_id)) + self._queue.remove((self._next[agent.unique_id], agent)) del self._agents[agent.unique_id] heapify(self._queue) - heappush(self._queue, (when, agent.unique_id)) - self._next[agent.unique_id] = when + heappush(self._queue, (when, agent)) super().add(agent) def step(self) -> None: @@ -63,42 +122,61 @@ class TimedActivation(BaseScheduler): an agent will signal when it wants to be scheduled next. """ - self.logger.debug(f"Simulation step {self.next_time}") + self.logger.debug(f"Simulation step {self.time}") if not self.model.running: return - self.time = self.next_time - when = self.time + when = NEVER + + to_process = [] + skipped = [] + next_time = INFINITY - while self._queue and self._queue[0][0] == self.time: - (when, agent_id) = heappop(self._queue) - self.logger.debug(f"Stepping agent {agent_id}") + ix = 0 - agent = self._agents[agent_id] - returned = agent.step() + while self._queue: + (when, agent) = self._queue[0] + if when > self.time: + break + heappop(self._queue) + if when.ready(self.time): + to_process.append(agent) + continue + + next_time = min(next_time, when.next(self.time)) + self._next[agent.unique_id] = next_time + skipped.append((when, agent)) + + if self._queue: + next_time = min(next_time, self._queue[0][0].next(self.time)) + + self._queue = [*skipped, *self._queue] + + for agent in to_process: + self.logger.debug(f"Stepping agent {agent}") + + returned = ((agent.step() or Delta(1))).abs(self.time) if not getattr(agent, "alive", True): self.remove(agent) continue - when = (returned or Delta(1)).abs(self.time) - if when < self.time: + value = when.next(self.time) + + if value < self.time: raise Exception( - "Cannot schedule an agent for a time in the past ({} < {})".format( - when, self.time - ) + f"Cannot schedule an agent for a time in the past ({when} < {self.time})" ) + if value < INFINITY: + next_time = min(value, next_time) - self._next[agent_id] = when - heappush(self._queue, (when, agent_id)) + self._next[agent.unique_id] = returned + heappush(self._queue, (returned, agent)) self.steps += 1 + self.logger.debug(f"Updating time step: {self.time} -> {next_time}") + self.time = next_time - if not self._queue: - self.time = INFINITY - self.next_time = INFINITY + if not self._queue or next_time == INFINITY: self.model.running = False return self.time - - self.next_time = self._queue[0][0] - self.logger.debug(f"Next step: {self.next_time}") diff --git a/soil/utils.py b/soil/utils.py index 92d9d74..e1b3580 100644 --- a/soil/utils.py +++ b/soil/utils.py @@ -47,7 +47,7 @@ def timer(name="task", pre="", function=logger.info, to_object=None): to_object.end = end -def try_backup(path, move=False): +def try_backup(path, remove=False): if not os.path.exists(path): return None outdir = os.path.dirname(path) diff --git a/tests/test_agents.py b/tests/test_agents.py index 8603b1e..d3db80e 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -18,7 +18,7 @@ class TestMain(TestCase): d = Dead(unique_id=0, model=environment.Environment()) ret = d.step().abs(0) print(ret, "next") - assert ret == stime.INFINITY + assert ret == stime.NEVER def test_die_raises_exception(self): '''A dead agent should raise an exception if it is stepped after death''' diff --git a/tests/test_exporters.py b/tests/test_exporters.py index baf9d83..1b1b072 100644 --- a/tests/test_exporters.py +++ b/tests/test_exporters.py @@ -50,7 +50,6 @@ class Exporters(TestCase): for env in s.run_simulation(exporters=[Dummy], dry_run=True): assert len(env.agents) == 1 - assert env.now == max_time assert Dummy.started assert Dummy.ended