Use context manager to add source files

master
J. Fernando Sánchez 1 year ago
parent 4e296e0cf1
commit 3802578ad5

@ -8,6 +8,8 @@ import importlib.machinery, importlib.util
from glob import glob from glob import glob
from itertools import product, chain from itertools import product, chain
from contextlib import contextmanager
import yaml import yaml
import networkx as nx import networkx as nx
@ -110,12 +112,12 @@ KNOWN_MODULES = {
MODULE_FILES = {} MODULE_FILES = {}
def add_source_file(file): def _add_source_file(file):
"""Add a file to the list of known modules""" """Add a file to the list of known modules"""
file = os.path.abspath(file) file = os.path.abspath(file)
if file in MODULE_FILES: if file in MODULE_FILES:
logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading") logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading")
remove_source_file(file) _remove_source_file(file)
modname = f"imported_module_{len(MODULE_FILES)}" modname = f"imported_module_{len(MODULE_FILES)}"
loader = importlib.machinery.SourceFileLoader(modname, file) loader = importlib.machinery.SourceFileLoader(modname, file)
spec = importlib.util.spec_from_loader(loader.name, loader) spec = importlib.util.spec_from_loader(loader.name, loader)
@ -124,7 +126,7 @@ def add_source_file(file):
MODULE_FILES[file] = modname MODULE_FILES[file] = modname
KNOWN_MODULES[modname] = my_module KNOWN_MODULES[modname] = my_module
def remove_source_file(file): def _remove_source_file(file):
"""Remove a file from the list of known modules""" """Remove a file from the list of known modules"""
file = os.path.abspath(file) file = os.path.abspath(file)
modname = None modname = None
@ -134,6 +136,18 @@ def remove_source_file(file):
except KeyError as ex: except KeyError as ex:
raise ValueError(f"File {file} had not been added as a module: {ex}") raise ValueError(f"File {file} had not been added as a module: {ex}")
@contextmanager
def with_source(file=None):
"""Add a file to the list of known modules, and remove it afterwards"""
if file:
_add_source_file(file)
try:
yield
finally:
if file:
_remove_source_file(file)
def get_module(modname): def get_module(modname):
"""Get a module from the list of known modules""" """Get a module from the list of known modules"""
if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None: if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None:

@ -119,28 +119,23 @@ class Simulation:
self.logger = logger.getChild(self.name) self.logger = logger.getChild(self.name)
self.logger.setLevel(self.level) self.logger.setLevel(self.level)
if self.source_file: if self.source_file and (not os.path.isabs(self.source_file)):
source_file = self.source_file self.source_file = os.path.abspath(os.path.join(self.dir_path, self.source_file))
if not os.path.isabs(source_file): with serialization.with_source(self.source_file):
source_file = os.path.abspath(os.path.join(self.dir_path, source_file))
serialization.add_source_file(source_file) if isinstance(self.model, str):
self.source_file = source_file self.model = serialization.deserialize(self.model)
if isinstance(self.model, str): def deserialize_reporters(reporters):
self.model = serialization.deserialize(self.model) for (k, v) in reporters.items():
if isinstance(v, str) and v.startswith("py:"):
def deserialize_reporters(reporters): reporters[k] = serialization.deserialize(v.split(":", 1)[1])
for (k, v) in reporters.items(): return reporters
if isinstance(v, str) and v.startswith("py:"):
reporters[k] = serialization.deserialize(v.split(":", 1)[1]) self.agent_reporters = deserialize_reporters(self.agent_reporters)
return reporters self.model_reporters = deserialize_reporters(self.model_reporters)
self.tables = deserialize_reporters(self.tables)
self.agent_reporters = deserialize_reporters(self.agent_reporters) self.id = f"{self.name}_{current_time()}"
self.model_reporters = deserialize_reporters(self.model_reporters)
self.tables = deserialize_reporters(self.tables)
if self.source_file:
serialization.remove_source_file(self.source_file)
self.id = f"{self.name}_{current_time()}"
def run(self, **kwargs): def run(self, **kwargs):
"""Run the simulation and return the list of resulting environments""" """Run the simulation and return the list of resulting environments"""
@ -217,10 +212,7 @@ class Simulation:
): ):
"""Run the simulation and yield the resulting environments.""" """Run the simulation and yield the resulting environments."""
try: with serialization.with_source(self.source_file):
if self.source_file:
serialization.add_source_file(self.source_file)
with utils.timer(f"running for config {params}"): with utils.timer(f"running for config {params}"):
if self.dry_run: if self.dry_run:
def func(*args, **kwargs): def func(*args, **kwargs):
@ -237,9 +229,6 @@ class Simulation:
continue continue
yield env yield env
finally:
if self.source_file:
serialization.remove_source_file(self.source_file)
def _get_env(self, iteration_id, params): def _get_env(self, iteration_id, params):
"""Create an environment for a iteration of the simulation""" """Create an environment for a iteration of the simulation"""

Loading…
Cancel
Save