Use context manager to add source files

master
J. Fernando Sánchez 1 year ago
parent 4e296e0cf1
commit 3802578ad5

@ -8,6 +8,8 @@ import importlib.machinery, importlib.util
from glob import glob
from itertools import product, chain
from contextlib import contextmanager
import yaml
import networkx as nx
@ -110,12 +112,12 @@ KNOWN_MODULES = {
MODULE_FILES = {}
def add_source_file(file):
def _add_source_file(file):
"""Add a file to the list of known modules"""
file = os.path.abspath(file)
if file in MODULE_FILES:
logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading")
remove_source_file(file)
_remove_source_file(file)
modname = f"imported_module_{len(MODULE_FILES)}"
loader = importlib.machinery.SourceFileLoader(modname, file)
spec = importlib.util.spec_from_loader(loader.name, loader)
@ -124,7 +126,7 @@ def add_source_file(file):
MODULE_FILES[file] = modname
KNOWN_MODULES[modname] = my_module
def remove_source_file(file):
def _remove_source_file(file):
"""Remove a file from the list of known modules"""
file = os.path.abspath(file)
modname = None
@ -134,6 +136,18 @@ def remove_source_file(file):
except KeyError as ex:
raise ValueError(f"File {file} had not been added as a module: {ex}")
@contextmanager
def with_source(file=None):
"""Add a file to the list of known modules, and remove it afterwards"""
if file:
_add_source_file(file)
try:
yield
finally:
if file:
_remove_source_file(file)
def get_module(modname):
"""Get a module from the list of known modules"""
if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None:

@ -119,28 +119,23 @@ class Simulation:
self.logger = logger.getChild(self.name)
self.logger.setLevel(self.level)
if self.source_file:
source_file = self.source_file
if not os.path.isabs(source_file):
source_file = os.path.abspath(os.path.join(self.dir_path, source_file))
serialization.add_source_file(source_file)
self.source_file = source_file
if isinstance(self.model, str):
self.model = serialization.deserialize(self.model)
def deserialize_reporters(reporters):
for (k, v) in reporters.items():
if isinstance(v, str) and v.startswith("py:"):
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
return reporters
self.agent_reporters = deserialize_reporters(self.agent_reporters)
self.model_reporters = deserialize_reporters(self.model_reporters)
self.tables = deserialize_reporters(self.tables)
if self.source_file:
serialization.remove_source_file(self.source_file)
self.id = f"{self.name}_{current_time()}"
if self.source_file and (not os.path.isabs(self.source_file)):
self.source_file = os.path.abspath(os.path.join(self.dir_path, self.source_file))
with serialization.with_source(self.source_file):
if isinstance(self.model, str):
self.model = serialization.deserialize(self.model)
def deserialize_reporters(reporters):
for (k, v) in reporters.items():
if isinstance(v, str) and v.startswith("py:"):
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
return reporters
self.agent_reporters = deserialize_reporters(self.agent_reporters)
self.model_reporters = deserialize_reporters(self.model_reporters)
self.tables = deserialize_reporters(self.tables)
self.id = f"{self.name}_{current_time()}"
def run(self, **kwargs):
"""Run the simulation and return the list of resulting environments"""
@ -217,10 +212,7 @@ class Simulation:
):
"""Run the simulation and yield the resulting environments."""
try:
if self.source_file:
serialization.add_source_file(self.source_file)
with serialization.with_source(self.source_file):
with utils.timer(f"running for config {params}"):
if self.dry_run:
def func(*args, **kwargs):
@ -237,9 +229,6 @@ class Simulation:
continue
yield env
finally:
if self.source_file:
serialization.remove_source_file(self.source_file)
def _get_env(self, iteration_id, params):
"""Create an environment for a iteration of the simulation"""

Loading…
Cancel
Save