mirror of
https://github.com/gsi-upm/soil
synced 2024-12-22 00:08:12 +00:00
Use context manager to add source files
This commit is contained in:
parent
4e296e0cf1
commit
3802578ad5
@ -8,6 +8,8 @@ import importlib.machinery, importlib.util
|
||||
from glob import glob
|
||||
from itertools import product, chain
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
import yaml
|
||||
import networkx as nx
|
||||
|
||||
@ -110,12 +112,12 @@ KNOWN_MODULES = {
|
||||
|
||||
MODULE_FILES = {}
|
||||
|
||||
def add_source_file(file):
|
||||
def _add_source_file(file):
|
||||
"""Add a file to the list of known modules"""
|
||||
file = os.path.abspath(file)
|
||||
if file in MODULE_FILES:
|
||||
logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading")
|
||||
remove_source_file(file)
|
||||
_remove_source_file(file)
|
||||
modname = f"imported_module_{len(MODULE_FILES)}"
|
||||
loader = importlib.machinery.SourceFileLoader(modname, file)
|
||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||
@ -124,7 +126,7 @@ def add_source_file(file):
|
||||
MODULE_FILES[file] = modname
|
||||
KNOWN_MODULES[modname] = my_module
|
||||
|
||||
def remove_source_file(file):
|
||||
def _remove_source_file(file):
|
||||
"""Remove a file from the list of known modules"""
|
||||
file = os.path.abspath(file)
|
||||
modname = None
|
||||
@ -134,6 +136,18 @@ def remove_source_file(file):
|
||||
except KeyError as ex:
|
||||
raise ValueError(f"File {file} had not been added as a module: {ex}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def with_source(file=None):
|
||||
"""Add a file to the list of known modules, and remove it afterwards"""
|
||||
if file:
|
||||
_add_source_file(file)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if file:
|
||||
_remove_source_file(file)
|
||||
|
||||
def get_module(modname):
|
||||
"""Get a module from the list of known modules"""
|
||||
if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None:
|
||||
|
@ -119,28 +119,23 @@ class Simulation:
|
||||
self.logger = logger.getChild(self.name)
|
||||
self.logger.setLevel(self.level)
|
||||
|
||||
if self.source_file:
|
||||
source_file = self.source_file
|
||||
if not os.path.isabs(source_file):
|
||||
source_file = os.path.abspath(os.path.join(self.dir_path, source_file))
|
||||
serialization.add_source_file(source_file)
|
||||
self.source_file = source_file
|
||||
if self.source_file and (not os.path.isabs(self.source_file)):
|
||||
self.source_file = os.path.abspath(os.path.join(self.dir_path, self.source_file))
|
||||
with serialization.with_source(self.source_file):
|
||||
|
||||
if isinstance(self.model, str):
|
||||
self.model = serialization.deserialize(self.model)
|
||||
if isinstance(self.model, str):
|
||||
self.model = serialization.deserialize(self.model)
|
||||
|
||||
def deserialize_reporters(reporters):
|
||||
for (k, v) in reporters.items():
|
||||
if isinstance(v, str) and v.startswith("py:"):
|
||||
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
|
||||
return reporters
|
||||
def deserialize_reporters(reporters):
|
||||
for (k, v) in reporters.items():
|
||||
if isinstance(v, str) and v.startswith("py:"):
|
||||
reporters[k] = serialization.deserialize(v.split(":", 1)[1])
|
||||
return reporters
|
||||
|
||||
self.agent_reporters = deserialize_reporters(self.agent_reporters)
|
||||
self.model_reporters = deserialize_reporters(self.model_reporters)
|
||||
self.tables = deserialize_reporters(self.tables)
|
||||
if self.source_file:
|
||||
serialization.remove_source_file(self.source_file)
|
||||
self.id = f"{self.name}_{current_time()}"
|
||||
self.agent_reporters = deserialize_reporters(self.agent_reporters)
|
||||
self.model_reporters = deserialize_reporters(self.model_reporters)
|
||||
self.tables = deserialize_reporters(self.tables)
|
||||
self.id = f"{self.name}_{current_time()}"
|
||||
|
||||
def run(self, **kwargs):
|
||||
"""Run the simulation and return the list of resulting environments"""
|
||||
@ -217,10 +212,7 @@ class Simulation:
|
||||
):
|
||||
"""Run the simulation and yield the resulting environments."""
|
||||
|
||||
try:
|
||||
if self.source_file:
|
||||
serialization.add_source_file(self.source_file)
|
||||
|
||||
with serialization.with_source(self.source_file):
|
||||
with utils.timer(f"running for config {params}"):
|
||||
if self.dry_run:
|
||||
def func(*args, **kwargs):
|
||||
@ -237,9 +229,6 @@ class Simulation:
|
||||
continue
|
||||
|
||||
yield env
|
||||
finally:
|
||||
if self.source_file:
|
||||
serialization.remove_source_file(self.source_file)
|
||||
|
||||
def _get_env(self, iteration_id, params):
|
||||
"""Create an environment for a iteration of the simulation"""
|
||||
|
Loading…
Reference in New Issue
Block a user