From 3802578ad57f74669b012c492692311707ed2731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Mon, 24 Apr 2023 17:40:00 +0200 Subject: [PATCH] Use context manager to add source files --- soil/serialization.py | 20 +++++++++++++++--- soil/simulation.py | 47 +++++++++++++++++-------------------------- 2 files changed, 35 insertions(+), 32 deletions(-) diff --git a/soil/serialization.py b/soil/serialization.py index 34e7768..45d0083 100644 --- a/soil/serialization.py +++ b/soil/serialization.py @@ -8,6 +8,8 @@ import importlib.machinery, importlib.util from glob import glob from itertools import product, chain +from contextlib import contextmanager + import yaml import networkx as nx @@ -110,12 +112,12 @@ KNOWN_MODULES = { MODULE_FILES = {} -def add_source_file(file): +def _add_source_file(file): """Add a file to the list of known modules""" file = os.path.abspath(file) if file in MODULE_FILES: logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading") - remove_source_file(file) + _remove_source_file(file) modname = f"imported_module_{len(MODULE_FILES)}" loader = importlib.machinery.SourceFileLoader(modname, file) spec = importlib.util.spec_from_loader(loader.name, loader) @@ -124,7 +126,7 @@ def add_source_file(file): MODULE_FILES[file] = modname KNOWN_MODULES[modname] = my_module -def remove_source_file(file): +def _remove_source_file(file): """Remove a file from the list of known modules""" file = os.path.abspath(file) modname = None @@ -134,6 +136,18 @@ def remove_source_file(file): except KeyError as ex: raise ValueError(f"File {file} had not been added as a module: {ex}") + +@contextmanager +def with_source(file=None): + """Add a file to the list of known modules, and remove it afterwards""" + if file: + _add_source_file(file) + try: + yield + finally: + if file: + _remove_source_file(file) + def get_module(modname): """Get a module from the list of known modules""" if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None: diff --git a/soil/simulation.py b/soil/simulation.py index 636769e..c243208 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -119,28 +119,23 @@ class Simulation: self.logger = logger.getChild(self.name) self.logger.setLevel(self.level) - if self.source_file: - source_file = self.source_file - if not os.path.isabs(source_file): - source_file = os.path.abspath(os.path.join(self.dir_path, source_file)) - serialization.add_source_file(source_file) - self.source_file = source_file - - if isinstance(self.model, str): - self.model = serialization.deserialize(self.model) - - def deserialize_reporters(reporters): - for (k, v) in reporters.items(): - if isinstance(v, str) and v.startswith("py:"): - reporters[k] = serialization.deserialize(v.split(":", 1)[1]) - return reporters - - self.agent_reporters = deserialize_reporters(self.agent_reporters) - self.model_reporters = deserialize_reporters(self.model_reporters) - self.tables = deserialize_reporters(self.tables) - if self.source_file: - serialization.remove_source_file(self.source_file) - self.id = f"{self.name}_{current_time()}" + if self.source_file and (not os.path.isabs(self.source_file)): + self.source_file = os.path.abspath(os.path.join(self.dir_path, self.source_file)) + with serialization.with_source(self.source_file): + + if isinstance(self.model, str): + self.model = serialization.deserialize(self.model) + + def deserialize_reporters(reporters): + for (k, v) in reporters.items(): + if isinstance(v, str) and v.startswith("py:"): + reporters[k] = serialization.deserialize(v.split(":", 1)[1]) + return reporters + + self.agent_reporters = deserialize_reporters(self.agent_reporters) + self.model_reporters = deserialize_reporters(self.model_reporters) + self.tables = deserialize_reporters(self.tables) + self.id = f"{self.name}_{current_time()}" def run(self, **kwargs): """Run the simulation and return the list of resulting environments""" @@ -217,10 +212,7 @@ class Simulation: ): """Run the simulation and yield the resulting environments.""" - try: - if self.source_file: - serialization.add_source_file(self.source_file) - + with serialization.with_source(self.source_file): with utils.timer(f"running for config {params}"): if self.dry_run: def func(*args, **kwargs): @@ -237,9 +229,6 @@ class Simulation: continue yield env - finally: - if self.source_file: - serialization.remove_source_file(self.source_file) def _get_env(self, iteration_id, params): """Create an environment for a iteration of the simulation"""