1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-21 18:52:28 +00:00

Use context manager to add source files

This commit is contained in:
J. Fernando Sánchez 2023-04-24 17:40:00 +02:00
parent 4e296e0cf1
commit 3802578ad5
2 changed files with 32 additions and 29 deletions

View File

@ -8,6 +8,8 @@ import importlib.machinery, importlib.util
from glob import glob from glob import glob
from itertools import product, chain from itertools import product, chain
from contextlib import contextmanager
import yaml import yaml
import networkx as nx import networkx as nx
@ -110,12 +112,12 @@ KNOWN_MODULES = {
MODULE_FILES = {} MODULE_FILES = {}
def add_source_file(file): def _add_source_file(file):
"""Add a file to the list of known modules""" """Add a file to the list of known modules"""
file = os.path.abspath(file) file = os.path.abspath(file)
if file in MODULE_FILES: if file in MODULE_FILES:
logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading") logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading")
remove_source_file(file) _remove_source_file(file)
modname = f"imported_module_{len(MODULE_FILES)}" modname = f"imported_module_{len(MODULE_FILES)}"
loader = importlib.machinery.SourceFileLoader(modname, file) loader = importlib.machinery.SourceFileLoader(modname, file)
spec = importlib.util.spec_from_loader(loader.name, loader) spec = importlib.util.spec_from_loader(loader.name, loader)
@ -124,7 +126,7 @@ def add_source_file(file):
MODULE_FILES[file] = modname MODULE_FILES[file] = modname
KNOWN_MODULES[modname] = my_module KNOWN_MODULES[modname] = my_module
def remove_source_file(file): def _remove_source_file(file):
"""Remove a file from the list of known modules""" """Remove a file from the list of known modules"""
file = os.path.abspath(file) file = os.path.abspath(file)
modname = None modname = None
@ -134,6 +136,18 @@ def remove_source_file(file):
except KeyError as ex: except KeyError as ex:
raise ValueError(f"File {file} had not been added as a module: {ex}") raise ValueError(f"File {file} had not been added as a module: {ex}")
@contextmanager
def with_source(file=None):
"""Add a file to the list of known modules, and remove it afterwards"""
if file:
_add_source_file(file)
try:
yield
finally:
if file:
_remove_source_file(file)
def get_module(modname): def get_module(modname):
"""Get a module from the list of known modules""" """Get a module from the list of known modules"""
if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None: if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None:

View File

@ -119,28 +119,23 @@ class Simulation:
self.logger = logger.getChild(self.name) self.logger = logger.getChild(self.name)
self.logger.setLevel(self.level) self.logger.setLevel(self.level)
if self.source_file: if self.source_file and (not os.path.isabs(self.source_file)):
source_file = self.source_file self.source_file = os.path.abspath(os.path.join(self.dir_path, self.source_file))
if not os.path.isabs(source_file): with serialization.with_source(self.source_file):
source_file = os.path.abspath(os.path.join(self.dir_path, source_file))
serialization.add_source_file(source_file)
self.source_file = source_file
if isinstance(self.model, str): if isinstance(self.model, str):
self.model = serialization.deserialize(self.model) self.model = serialization.deserialize(self.model)
def deserialize_reporters(reporters): def deserialize_reporters(reporters):
for (k, v) in reporters.items(): for (k, v) in reporters.items():
if isinstance(v, str) and v.startswith("py:"): if isinstance(v, str) and v.startswith("py:"):
reporters[k] = serialization.deserialize(v.split(":", 1)[1]) reporters[k] = serialization.deserialize(v.split(":", 1)[1])
return reporters return reporters
self.agent_reporters = deserialize_reporters(self.agent_reporters) self.agent_reporters = deserialize_reporters(self.agent_reporters)
self.model_reporters = deserialize_reporters(self.model_reporters) self.model_reporters = deserialize_reporters(self.model_reporters)
self.tables = deserialize_reporters(self.tables) self.tables = deserialize_reporters(self.tables)
if self.source_file: self.id = f"{self.name}_{current_time()}"
serialization.remove_source_file(self.source_file)
self.id = f"{self.name}_{current_time()}"
def run(self, **kwargs): def run(self, **kwargs):
"""Run the simulation and return the list of resulting environments""" """Run the simulation and return the list of resulting environments"""
@ -217,10 +212,7 @@ class Simulation:
): ):
"""Run the simulation and yield the resulting environments.""" """Run the simulation and yield the resulting environments."""
try: with serialization.with_source(self.source_file):
if self.source_file:
serialization.add_source_file(self.source_file)
with utils.timer(f"running for config {params}"): with utils.timer(f"running for config {params}"):
if self.dry_run: if self.dry_run:
def func(*args, **kwargs): def func(*args, **kwargs):
@ -237,9 +229,6 @@ class Simulation:
continue continue
yield env yield env
finally:
if self.source_file:
serialization.remove_source_file(self.source_file)
def _get_env(self, iteration_id, params): def _get_env(self, iteration_id, params):
"""Create an environment for a iteration of the simulation""" """Create an environment for a iteration of the simulation"""