mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
Use context manager to add source files
This commit is contained in:
parent
4e296e0cf1
commit
3802578ad5
@ -8,6 +8,8 @@ import importlib.machinery, importlib.util
|
|||||||
from glob import glob
|
from glob import glob
|
||||||
from itertools import product, chain
|
from itertools import product, chain
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
@ -110,12 +112,12 @@ KNOWN_MODULES = {
|
|||||||
|
|
||||||
MODULE_FILES = {}
|
MODULE_FILES = {}
|
||||||
|
|
||||||
def add_source_file(file):
|
def _add_source_file(file):
|
||||||
"""Add a file to the list of known modules"""
|
"""Add a file to the list of known modules"""
|
||||||
file = os.path.abspath(file)
|
file = os.path.abspath(file)
|
||||||
if file in MODULE_FILES:
|
if file in MODULE_FILES:
|
||||||
logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading")
|
logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading")
|
||||||
remove_source_file(file)
|
_remove_source_file(file)
|
||||||
modname = f"imported_module_{len(MODULE_FILES)}"
|
modname = f"imported_module_{len(MODULE_FILES)}"
|
||||||
loader = importlib.machinery.SourceFileLoader(modname, file)
|
loader = importlib.machinery.SourceFileLoader(modname, file)
|
||||||
spec = importlib.util.spec_from_loader(loader.name, loader)
|
spec = importlib.util.spec_from_loader(loader.name, loader)
|
||||||
@ -124,7 +126,7 @@ def add_source_file(file):
|
|||||||
MODULE_FILES[file] = modname
|
MODULE_FILES[file] = modname
|
||||||
KNOWN_MODULES[modname] = my_module
|
KNOWN_MODULES[modname] = my_module
|
||||||
|
|
||||||
def remove_source_file(file):
|
def _remove_source_file(file):
|
||||||
"""Remove a file from the list of known modules"""
|
"""Remove a file from the list of known modules"""
|
||||||
file = os.path.abspath(file)
|
file = os.path.abspath(file)
|
||||||
modname = None
|
modname = None
|
||||||
@ -134,6 +136,18 @@ def remove_source_file(file):
|
|||||||
except KeyError as ex:
|
except KeyError as ex:
|
||||||
raise ValueError(f"File {file} had not been added as a module: {ex}")
|
raise ValueError(f"File {file} had not been added as a module: {ex}")
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def with_source(file=None):
|
||||||
|
"""Add a file to the list of known modules, and remove it afterwards"""
|
||||||
|
if file:
|
||||||
|
_add_source_file(file)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
if file:
|
||||||
|
_remove_source_file(file)
|
||||||
|
|
||||||
def get_module(modname):
|
def get_module(modname):
|
||||||
"""Get a module from the list of known modules"""
|
"""Get a module from the list of known modules"""
|
||||||
if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None:
|
if modname not in KNOWN_MODULES or KNOWN_MODULES[modname] is None:
|
||||||
|
@ -119,12 +119,9 @@ class Simulation:
|
|||||||
self.logger = logger.getChild(self.name)
|
self.logger = logger.getChild(self.name)
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
|
|
||||||
if self.source_file:
|
if self.source_file and (not os.path.isabs(self.source_file)):
|
||||||
source_file = self.source_file
|
self.source_file = os.path.abspath(os.path.join(self.dir_path, self.source_file))
|
||||||
if not os.path.isabs(source_file):
|
with serialization.with_source(self.source_file):
|
||||||
source_file = os.path.abspath(os.path.join(self.dir_path, source_file))
|
|
||||||
serialization.add_source_file(source_file)
|
|
||||||
self.source_file = source_file
|
|
||||||
|
|
||||||
if isinstance(self.model, str):
|
if isinstance(self.model, str):
|
||||||
self.model = serialization.deserialize(self.model)
|
self.model = serialization.deserialize(self.model)
|
||||||
@ -138,8 +135,6 @@ class Simulation:
|
|||||||
self.agent_reporters = deserialize_reporters(self.agent_reporters)
|
self.agent_reporters = deserialize_reporters(self.agent_reporters)
|
||||||
self.model_reporters = deserialize_reporters(self.model_reporters)
|
self.model_reporters = deserialize_reporters(self.model_reporters)
|
||||||
self.tables = deserialize_reporters(self.tables)
|
self.tables = deserialize_reporters(self.tables)
|
||||||
if self.source_file:
|
|
||||||
serialization.remove_source_file(self.source_file)
|
|
||||||
self.id = f"{self.name}_{current_time()}"
|
self.id = f"{self.name}_{current_time()}"
|
||||||
|
|
||||||
def run(self, **kwargs):
|
def run(self, **kwargs):
|
||||||
@ -217,10 +212,7 @@ class Simulation:
|
|||||||
):
|
):
|
||||||
"""Run the simulation and yield the resulting environments."""
|
"""Run the simulation and yield the resulting environments."""
|
||||||
|
|
||||||
try:
|
with serialization.with_source(self.source_file):
|
||||||
if self.source_file:
|
|
||||||
serialization.add_source_file(self.source_file)
|
|
||||||
|
|
||||||
with utils.timer(f"running for config {params}"):
|
with utils.timer(f"running for config {params}"):
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
def func(*args, **kwargs):
|
def func(*args, **kwargs):
|
||||||
@ -237,9 +229,6 @@ class Simulation:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
yield env
|
yield env
|
||||||
finally:
|
|
||||||
if self.source_file:
|
|
||||||
serialization.remove_source_file(self.source_file)
|
|
||||||
|
|
||||||
def _get_env(self, iteration_id, params):
|
def _get_env(self, iteration_id, params):
|
||||||
"""Create an environment for a iteration of the simulation"""
|
"""Create an environment for a iteration of the simulation"""
|
||||||
|
Loading…
Reference in New Issue
Block a user