You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
soil/soil/serialization.py

283 lines
7.9 KiB
Python

5 years ago
import os
import logging
import ast
import sys
2 years ago
import re
5 years ago
import importlib
import importlib.machinery, importlib.util
5 years ago
from glob import glob
from itertools import product, chain
from contextlib import contextmanager
5 years ago
import yaml
import networkx as nx
from . import config
5 years ago
from jinja2 import Template
# Module-level logger; messages emitted from this file go to the "soil" logger.
logger = logging.getLogger("soil")
5 years ago
def load_file(infile):
    """Return every configuration defined in the YAML file *infile*.

    The directory that contains the file is appended to ``sys.path`` when it
    is not already listed (presumably so that modules stored next to the
    configuration can be imported -- confirm with callers).  Each YAML
    document in the file is passed through ``expand_template`` and all the
    resulting configurations are collected into a single flat list.
    """
    directory = os.path.dirname(infile)
    if directory not in sys.path:
        sys.path.append(directory)

    configs = []
    with open(infile, "r") as stream:
        for document in load_string(stream):
            # A document may expand into zero, one or many configurations.
            configs.extend(expand_template(document))
    return configs
5 years ago
5 years ago
def load_string(string):
    """Lazily yield each YAML document found in *string*.

    The argument is handed unchanged to ``yaml.load_all`` with the
    ``FullLoader``; ``load_file`` passes an open file object here.
    """
    # NOTE(review): FullLoader is not PyYAML's safe loader -- confirm that
    # configuration files are always trusted input.
    documents = yaml.load_all(string, Loader=yaml.FullLoader)
    for document in documents:
        yield document
5 years ago
def expand_template(config):
    """Yield the configuration(s) described by *config*.

    A mapping without a ``"template"`` key is yielded unchanged.  Otherwise
    the template (a string, or any object that is first dumped to YAML) is
    rendered with Jinja2 once per parameter set returned by
    ``params_for_template``, and every YAML document produced by each
    rendering is yielded.

    Raises:
        ValueError: if a template is given without ``"vars"``, if a rendering
            produces more than one document, or if the rendered document
            contains a ``"name"`` key.
    """
    if "template" not in config:
        yield config
        return
    if "vars" not in config:
        raise ValueError(
            "You must provide a definition of variables for the template."
        )

    source = config["template"]
    if not isinstance(source, str):
        source = yaml.dump(source)
    template = Template(source)

    params = params_for_template(config)

    # Dry run with every variable set to 0, used only to validate the shape
    # of what the template produces before the real expansion starts.
    placeholder = dict.fromkeys(params[0], 0)
    blank = list(load_string(template.render(placeholder)))
    if len(blank) > 1:
        raise ValueError("Templates must not return more than one configuration")
    if "name" in blank[0]:
        raise ValueError("Templates cannot be named, use group instead")

    for values in params:
        yield from load_string(template.render(values))
def params_for_template(config):
    """Build the list of parameter sets used to render a template.

    ``config["vars"]["bounds"]`` maps variable names to bounds; together with
    the variable count and names it forms the ``problem`` dict handed to the
    sampler.  The sampler is looked up with ``deserializer`` from
    ``config["sampler"]["method"]`` (default
    ``"SALib.sample.morris.sample"``) and called with the remaining sampler
    options as keyword arguments (default ``N=100``).

    ``config["vars"]["lists"]`` (optional) maps variable names to explicit
    lists of values.  Every combination of those values is crossed with every
    sample row.

    Returns:
        A list of dicts, each mapping every variable name (list variables
        first, then bounded variables) to one value.
    """
    # Work on a copy: popping "method" from the caller's own dict would make
    # a second expansion of the same configuration silently fall back to the
    # default sampler.
    sampler_config = dict(config.get("sampler", {"N": 100}))
    sampler = deserializer(
        sampler_config.pop("method", "SALib.sample.morris.sample")
    )
    bounds = config["vars"]["bounds"]

    problem = {
        "num_vars": len(bounds),
        "names": list(bounds.keys()),
        "bounds": list(bounds.values()),
    }
    samples = sampler(problem, **sampler_config)

    lists = config["vars"].get("lists", {})
    names = list(lists.keys())
    # With no list variables this is [()], so each sample is still used once.
    combs = list(product(*lists.values()))

    allnames = names + problem["names"]
    return [
        dict(zip(allnames, list(comb) + list(sample)))
        for comb, sample in product(combs, samples)
    ]
5 years ago
def load_files(*patterns, **kwargs):
    """Yield ``(configuration, absolute_path)`` for every matching file.

    Each pattern is expanded with ``glob`` and every matching file is loaded
    with ``load_file``; one pair is yielded per configuration found in it.

    Args:
        *patterns: glob patterns to expand.
        **kwargs: extra keyword arguments forwarded to ``glob``.  ``recursive``
            defaults to ``True`` but may be overridden by the caller.
    """
    # setdefault instead of a hard-coded keyword: passing ``recursive=...``
    # used to fail with "multiple values for keyword argument".
    kwargs.setdefault("recursive", True)
    for pattern in patterns:
        for filename in glob(pattern, **kwargs):
            # The path is the same for every configuration in the file.
            path = os.path.abspath(filename)
            for cfg in load_file(filename):
                yield cfg, path
5 years ago
def load_config(cfg):
    """Yield ``(configuration, path)`` pairs for *cfg*.

    A dict is converted with ``config.load_config`` and paired with the
    current working directory.  Anything else is treated as a glob pattern
    and delegated to ``load_files``.
    """
    if not isinstance(cfg, dict):
        yield from load_files(cfg)
        return
    yield config.load_config(cfg), os.getcwd()
5 years ago
1 year ago
# Cache for the ``builtins`` module; filled in by the first call to builtins().
_BUILTINS = None


def builtins():
    """Return the ``builtins`` module, importing it on first use."""
    global _BUILTINS
    if _BUILTINS:
        return _BUILTINS
    _BUILTINS = importlib.import_module("builtins")
    return _BUILTINS
5 years ago
# Registry of modules searched when (de)serializing names.  Keys are module
# names; a value of None means "not imported yet" (get_module imports it on
# first use).
KNOWN_MODULES = {
    'soil': None,
}

# Maps the absolute path of every registered source file to the module name
# under which it is stored in KNOWN_MODULES.
MODULE_FILES = {}


def _add_source_file(file):
    """Add a file to the list of known modules.

    The file is executed as a module named ``imported_module_<n>`` and
    registered in both ``MODULE_FILES`` and ``KNOWN_MODULES``.  A file that
    is already registered is removed first and then loaded again.
    """
    file = os.path.abspath(file)
    if file in MODULE_FILES:
        logger.warning(f"File {file} already added as module {MODULE_FILES[file]}. Reloading")
        _remove_source_file(file)

    # Start from the number of registered files, but skip names still in use:
    # after an out-of-order removal the count alone would hand out the name
    # of a module that is still registered and overwrite it.
    index = len(MODULE_FILES)
    modname = f"imported_module_{index}"
    while modname in KNOWN_MODULES:
        index += 1
        modname = f"imported_module_{index}"

    loader = importlib.machinery.SourceFileLoader(modname, file)
    spec = importlib.util.spec_from_loader(loader.name, loader)
    my_module = importlib.util.module_from_spec(spec)
    loader.exec_module(my_module)
    # Register only after the module executed without raising.
    MODULE_FILES[file] = modname
    KNOWN_MODULES[modname] = my_module


def _remove_source_file(file):
    """Remove a file from the list of known modules.

    Raises:
        ValueError: if the file is not registered.
    """
    file = os.path.abspath(file)
    try:
        modname = MODULE_FILES.pop(file)
        KNOWN_MODULES.pop(modname)
    except KeyError as ex:
        raise ValueError(f"File {file} had not been added as a module: {ex}") from ex


@contextmanager
def with_source(file=None):
    """Add a file to the list of known modules, and remove it afterwards.

    With a falsy *file* the context manager does nothing.
    """
    if file:
        _add_source_file(file)
    try:
        yield
    finally:
        if file:
            _remove_source_file(file)
def get_module(modname):
    """Return the module registered as *modname*, importing it if needed.

    A name that is missing from ``KNOWN_MODULES``, or registered with a
    ``None`` placeholder, is imported with ``importlib`` and stored in
    ``KNOWN_MODULES`` before being returned.
    """
    module = KNOWN_MODULES.get(modname)
    if module is None:
        module = importlib.import_module(modname)
        KNOWN_MODULES[modname] = module
    return module
2 years ago
def name(value, known_modules=KNOWN_MODULES):
    """Return a name that can be imported, to serialize/deserialize an object.

    Args:
        value: an object or a class; for an object, its class is used.
        known_modules: module names whose members may be referred to by
            their bare class name.  ``None`` or an empty collection means
            "no known modules".

    Returns:
        ``"None"`` for ``None``; the bare class name when it is also the name
        of a builtin, when the class is defined in ``__main__``, when its
        module is one of *known_modules*, or when any known module has an
        attribute with that name; otherwise ``"<module>.<class>"``.
    """
    if value is None:
        return "None"
    if not isinstance(value, type):  # Get the class name first
        value = type(value)
    tname = value.__name__
    if hasattr(builtins(), tname):
        return tname
    modname = value.__module__
    if modname == "__main__":
        return tname
    if known_modules and modname in known_modules:
        return tname
    # ``or ()`` keeps known_modules=None working, as the guard above implies.
    for kmod in known_modules or ():
        try:
            module = get_module(kmod)
        except ImportError:
            # A known module that cannot be imported cannot provide the
            # class; skip it (deserializer tolerates the same failure) and
            # fall back to the fully qualified name below.
            continue
        if hasattr(module, tname):
            return tname
    return "{}.{}".format(modname, tname)
5 years ago
def serializer(type_):
    """Return the function used to turn a value of type *type_* into text.

    Strings (``type_ == "str"``) are passed through unchanged; every other
    type name is serialized with ``repr``.
    """
    if type_ == "str":
        return lambda x: x
    return repr
2 years ago
def serialize(v, known_modules=KNOWN_MODULES):
    """Get a text representation of an object.

    Returns:
        A ``(text, type_name)`` tuple, where *type_name* comes from ``name``
        and *text* is produced by the matching ``serializer``.
    """
    type_name = name(v, known_modules=known_modules)
    to_text = serializer(type_name)
    return to_text(v), type_name
2 years ago
def serialize_dict(d, known_modules=KNOWN_MODULES):
    """Return a copy of *d* with nested values replaced by serialized forms.

    * Nested dicts and the items of nested lists are processed recursively.
    * Values that are classes are replaced by their serialized type name.
    * Any other value is left as it is.

    If *d* cannot be converted with ``dict()`` it is treated as a plain value
    and its text representation (see ``serialize``) is returned instead.
    NOTE(review): anything ``dict()`` accepts is converted, so a list item
    that is itself a sequence of pairs (or an empty sequence) becomes a dict.

    Neither the input mapping nor the lists inside it are modified.
    """
    try:
        d = dict(d)
    except (ValueError, TypeError):
        # Forward known_modules so the fallback resolves names the same way
        # as the rest of the function.
        return serialize(d, known_modules=known_modules)[0]
    # Iterate over a snapshot, since entries are reassigned while looping.
    for k, v in reversed(list(d.items())):
        if isinstance(v, dict):
            d[k] = serialize_dict(v, known_modules=known_modules)
        elif isinstance(v, list):
            # Build a new list: writing into ``v`` would modify the caller's
            # list, because only the top-level dict is copied.
            d[k] = [serialize_dict(item, known_modules=known_modules) for item in v]
        elif isinstance(v, type):
            d[k] = serialize(v, known_modules=known_modules)[1]
    return d
2 years ago
# Matches text of the form "<class 'package.module.Name'>" and captures what is
# between the quotes (the dotted path of the class).
IS_CLASS = re.compile(r"<class '(.*)'>")
2 years ago
def deserializer(type_, known_modules=KNOWN_MODULES):
    """Return a callable that builds objects of the type named *type_*.

    Resolution order:

    1. Anything that is not a string is returned unchanged (it is assumed to
       be deserialized already).
    2. ``"str"`` and ``"None"`` map to trivial functions.
    3. Names of builtins map to a function that parses its argument with
       ``ast.literal_eval`` (or calls the builtin with no arguments when the
       argument is ``None``).
    4. Text of the form ``"<class 'pkg.mod.Name'>"`` is resolved by importing
       the module; a class without a module part is looked up in builtins.
    5. Otherwise the name is looked up in every known module and, if it
       contains a dot, as ``module.attribute``.

    For steps 4 and 5 the ``deserialize`` attribute of the class is returned
    when it has one, otherwise the class itself.

    Raises:
        ValueError: if step 5 finds no candidate; the message lists every
            ``(module, name, error)`` that was tried.
    """
    if not isinstance(type_, str):  # Already deserialized
        return type_
    if type_ == "str":
        return lambda x="": x
    if type_ == "None":
        return lambda x=None: None
    if hasattr(builtins(), type_):  # Check if it's a builtin type
        cls = getattr(builtins(), type_)
        return lambda x=None: ast.literal_eval(x) if x is not None else cls()

    match = IS_CLASS.match(type_)
    if match:
        # rpartition copes with classes that have no module part, such as
        # "<class 'int'>", which used to fail with an unpacking error.
        modname, _, tname = match.group(1).rpartition(".")
        module = get_module(modname) if modname else builtins()
        cls = getattr(module, tname)
        return getattr(cls, "deserialize", cls)

    # Otherwise, see if we can find the module and the class
    options = [(mod, type_) for mod in (known_modules or ()) if mod]
    if "." in type_:  # Fully qualified module
        # Keep ``type_`` intact so the error below reports the full name
        # that was requested, not just its last component.
        options.append(tuple(type_.rsplit(".", 1)))

    errors = []
    for modname, tname in options:
        try:
            module = get_module(modname)
            cls = getattr(module, tname)
            return getattr(cls, "deserialize", cls)
        except (ImportError, AttributeError) as ex:
            errors.append((modname, tname, ex))
    raise ValueError('Could not find type "{}". Tried: {}'.format(type_, errors))
5 years ago
def deserialize(type_, value=None, globs=None, **kwargs):
    """Get an object from a text representation.

    Args:
        type_: name of the type; anything that is not a string is returned
            unchanged.
        value: text to convert.  When ``None``, the resolved callable itself
            is returned instead of being called.
        globs: optional mapping of names to callables, consulted before any
            other lookup.
        **kwargs: forwarded to ``deserializer``.

    Raises:
        ValueError: the error from ``deserializer`` when the name can be
            neither resolved nor evaluated.
    """
    if not isinstance(type_, str):
        return type_
    if not globs or type_ not in globs:
        try:
            des = deserializer(type_, **kwargs)
        except ValueError as ex:
            # NOTE(review): last resort is eval() on the type name -- this
            # executes arbitrary expressions, so configuration text must be
            # trusted.
            try:
                des = eval(type_)
            except Exception:
                raise ex
    else:
        des = globs[type_]
    return des if value is None else des(value)
2 years ago
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
    """Return the list of deserialized objects.

    Every entry of *names* is resolved with ``deserialize`` and the result
    is called with ``*args`` and ``**kwargs``; the return values are
    collected in the same order as *names*.
    """
    return [
        deserialize(type_name, known_modules=known_modules)(*args, **kwargs)
        for type_name in names
    ]