You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
soil/soil/serialization.py

234 lines
6.6 KiB
Python

5 years ago
import os
import logging
import ast
import sys
2 years ago
import re
5 years ago
import importlib
from glob import glob
from itertools import product, chain
5 years ago
import yaml
import networkx as nx
5 years ago
from jinja2 import Template
# Module logger, registered under the package-wide 'soil' name rather than
# this module's __name__.
logger = logging.getLogger(name='soil')
2 years ago
# def load_network(network_params, dir_path=None):
# G = nx.Graph()
2 years ago
# if not network_params:
# return G
2 years ago
2 years ago
# if 'path' in network_params:
# path = network_params['path']
# if dir_path and not os.path.isabs(path):
# path = os.path.join(dir_path, path)
# extension = os.path.splitext(path)[1][1:]
# kwargs = {}
# if extension == 'gexf':
# kwargs['version'] = '1.2draft'
# kwargs['node_type'] = int
# try:
# method = getattr(nx.readwrite, 'read_' + extension)
# except AttributeError:
# raise AttributeError('Unknown format')
# G = method(path, **kwargs)
2 years ago
# elif 'generator' in network_params:
# net_args = network_params.copy()
# net_gen = net_args.pop('generator')
5 years ago
2 years ago
# if dir_path not in sys.path:
# sys.path.append(dir_path)
5 years ago
2 years ago
# method = deserializer(net_gen,
# known_modules=['networkx.generators',])
# G = method(**net_args)
5 years ago
2 years ago
# return G
5 years ago
def load_file(infile):
    '''Load every configuration contained in the YAML file ``infile``.

    The directory holding ``infile`` is appended to ``sys.path`` (once), so
    that modules stored next to the file can be imported by name later on
    (e.g. through ``importlib`` in deserializer()). Each YAML document is
    passed through expand_template(), and all resulting configurations are
    returned as one flat list.
    '''
    # Resolve to an absolute directory: a relative entry such as '' or
    # 'configs' would point somewhere else once the working directory changes.
    folder = os.path.dirname(os.path.abspath(infile))
    if folder not in sys.path:
        sys.path.append(folder)

    # YAML streams are Unicode; do not depend on the platform locale encoding.
    with open(infile, 'r', encoding='utf-8') as f:
        # Materialise everything while the file is still open: load_string()
        # and expand_template() are lazy generators.
        return list(chain.from_iterable(map(expand_template, load_string(f))))
5 years ago
5 years ago
def load_string(string):
    '''Lazily parse every YAML document found in ``string``.

    ``string`` may be a str or an open stream (load_file() passes a file
    object). Documents are yielded one at a time, in order of appearance.

    NOTE(review): ``yaml.FullLoader`` can build more than plain data types;
    only feed it trusted configuration files, or consider ``yaml.SafeLoader``
    if these inputs may come from untrusted sources.
    '''
    documents = yaml.load_all(string, Loader=yaml.FullLoader)
    for document in documents:
        yield document
5 years ago
def expand_template(config):
    '''Expand a (possibly templated) configuration into concrete ones.

    A config without a ``template`` key is yielded unchanged. Otherwise the
    template (a string, or any structure that is YAML-dumped to a string
    first) is rendered with Jinja2 once per parameter set returned by
    params_for_template(), and every YAML document produced is yielded.

    Being a generator, it raises ValueError on first iteration when ``vars``
    is missing, when the template renders to more than one document, or when
    the rendered document defines ``name``.
    '''
    if 'template' in config:
        if 'vars' not in config:
            raise ValueError(('You must provide a definition of variables'
                              ' for the template.'))
        source = config['template']
        if not isinstance(source, str):
            source = yaml.dump(source)
        template = Template(source)

        params = params_for_template(config)

        # Dry run with every variable set to 0, to validate the shape of the
        # template before producing the real configurations.
        placeholders = dict.fromkeys(params[0].keys(), 0)
        blank = list(load_string(template.render(placeholders)))
        if len(blank) > 1:
            raise ValueError('Templates must not return more than one configuration')
        if 'name' in blank[0]:
            raise ValueError('Templates cannot be named, use group instead')

        for assignment in params:
            yield from load_string(template.render(assignment))
    else:
        yield config
def params_for_template(config):
    '''Build the list of variable assignments used to render a template.

    ``config['vars']['bounds']`` maps variable names to their bounds and is
    sampled with a SALib-style sampler. ``config['sampler']`` holds the
    sampler's keyword arguments plus an optional ``method`` (a dotted name or
    a callable, resolved with deserializer(); default
    ``'SALib.sample.morris.sample'``). Without a ``sampler`` section,
    ``{'N': 100}`` is used.

    ``config['vars']['lists']`` (optional) maps names to explicit value
    lists; every combination of those values is crossed with every sample.

    Returns a list of dicts, one per (combination, sample) pair, mapping each
    list variable and each bounded variable to its value. ``config`` is not
    modified.
    '''
    # Work on a copy: popping 'method' from the caller's dict would make a
    # second expansion of the same config silently use the default sampler.
    sampler_config = dict(config.get('sampler', {'N': 100}))
    method = sampler_config.pop('method', 'SALib.sample.morris.sample')
    sampler = deserializer(method)

    bounds = config['vars']['bounds']
    problem = {
        'num_vars': len(bounds),
        'names': list(bounds.keys()),
        'bounds': list(bounds.values()),
    }
    samples = sampler(problem, **sampler_config)

    lists = config['vars'].get('lists', {})
    combinations = list(product(*lists.values()))

    all_names = list(lists.keys()) + problem['names']
    return [
        dict(zip(all_names, list(combination) + list(sample)))
        for combination, sample in product(combinations, samples)
    ]
5 years ago
def load_files(*patterns, **kwargs):
    '''Yield ``(config, path)`` pairs for every file matching ``patterns``.

    Each pattern is expanded with glob() (``kwargs`` are forwarded to it,
    e.g. ``recursive=True``). Every matching file is loaded with load_file(),
    and each of its configurations is yielded together with the file's
    absolute path.
    '''
    for pattern in patterns:
        for match in glob(pattern, **kwargs):
            # Resolve once per file, before loading: it is the same for every
            # config in the file, and doing it now keeps it consistent with
            # the working directory the file was found and read from, even if
            # the consumer changes directory between yields.
            path = os.path.abspath(match)
            for config in load_file(match):
                yield config, path
def load_config(config):
    '''Yield ``(config, base_path)`` pairs from a dict or a file pattern.

    A dict is taken as an already-loaded configuration. It is yielded once,
    paired with the current working directory. Anything else is passed to
    load_files() as a single glob pattern.
    '''
    if not isinstance(config, dict):
        yield from load_files(config)
        return
    yield config, os.getcwd()
# The builtins module, used to recognise builtin types by name (and to
# rebuild them).
import builtins

# Modules whose names may be serialized/deserialized without a module prefix.
KNOWN_MODULES = ['soil']
def name(value, known_modules=KNOWN_MODULES):
    '''Return a name that can be imported, to serialize/deserialize an object.

    ``value`` may be a class or an instance (whose class is then used). The
    result is:

    * ``'None'`` for ``None``;
    * the bare class name for builtins, for classes defined in ``__main__``,
      for classes whose module is listed in ``known_modules``, and for
      classes exposed under that name by one of ``known_modules``;
    * ``'<module>.<class>'`` otherwise.

    A bare builtin or known-module name is only produced when it refers to
    this very class. A class that merely shares its name with a builtin
    (which deserializer() checks first), or with an attribute of a known
    module, keeps its qualified name. Known modules that cannot be imported
    are skipped, and ``known_modules=None`` behaves like an empty list.
    '''
    if value is None:
        return 'None'
    if not isinstance(value, type):  # Get the class name first
        value = type(value)
    tname = value.__name__
    modname = value.__module__
    if hasattr(builtins, tname):
        if getattr(builtins, tname) is value:
            return tname
        # A bare name would be read back as the builtin, not as this class.
        return '{}.{}'.format(modname, tname)
    if modname == '__main__':
        return tname
    known_modules = known_modules or []
    if modname in known_modules:
        return tname
    for kmod in known_modules:
        if not kmod:
            continue
        try:
            module = importlib.import_module(kmod)
        except ImportError:
            continue  # Unavailable module: skip it, as deserializer() does.
        if getattr(module, tname, None) is value:
            return tname
    return '{}.{}'.format(modname, tname)
def serializer(type_):
    '''Return the function that turns a value named ``type_`` into text.

    Builtins other than ``str`` are serialized with repr(), which the
    builtin branch of deserializer() reads back with ast.literal_eval. Any
    other value, including strings, is returned unchanged.
    '''
    if type_ == 'str' or not hasattr(builtins, type_):
        return lambda x: x
    return repr
2 years ago
def serialize(v, known_modules=KNOWN_MODULES):
    '''Get a text representation of an object.

    Returns a ``(text, type_name)`` pair: ``type_name`` comes from name()
    and ``text`` from the matching serializer().
    '''
    type_name = name(v, known_modules=known_modules)
    to_text = serializer(type_name)
    text = to_text(v)
    return text, type_name
2 years ago
# Matches the str() of a class, e.g. "<class 'collections.OrderedDict'>",
# capturing whatever sits between the quotes.
IS_CLASS = re.compile(pattern=r"<class '(.*)'>")
2 years ago
def deserializer(type_, known_modules=KNOWN_MODULES):
    '''Return a callable that rebuilds objects of the type named ``type_``.

    ``type_`` may be:

    * anything that is not a string: returned unchanged (already deserialized);
    * ``'str'`` or ``'None'``: trivial factories;
    * the name of a builtin: a function that ast.literal_eval()s its
      argument, or returns ``cls()`` when called without one;
    * the str() of a class, e.g. ``"<class 'pkg.mod.Cls'>"``; builtins such
      as ``"<class 'int'>"`` are supported too;
    * a bare name, looked up in each of ``known_modules`` in order, or a
      dotted ``'module.Class'`` path.

    For classes, a ``deserialize`` attribute is preferred over the class
    itself. ``known_modules=None`` behaves like an empty list.

    Raises Exception when the type cannot be found in any candidate module.
    '''
    if not isinstance(type_, str):  # Already deserialized
        return type_
    if type_ == 'str':
        return lambda x='': x
    if type_ == 'None':
        return lambda x=None: None
    if hasattr(builtins, type_):  # Check if it's a builtin type
        cls = getattr(builtins, type_)
        return lambda x=None: ast.literal_eval(x) if x is not None else cls()

    match = IS_CLASS.match(type_)
    if match:
        # rpartition copes with builtins, whose str() has no module prefix
        # ("<class 'int'>"); an empty module part means the builtins module.
        modname, _, tname = match.group(1).rpartition('.')
        module = importlib.import_module(modname or 'builtins')
        cls = getattr(module, tname)
        return getattr(cls, 'deserialize', cls)

    # Otherwise, see if we can find the module and the class
    options = [(mod, type_) for mod in (known_modules or []) if mod]
    if '.' in type_:  # Fully qualified module
        options.append(tuple(type_.rsplit('.', 1)))

    errors = []
    for modname, tname in options:
        try:
            module = importlib.import_module(modname)
            cls = getattr(module, tname)
            return getattr(cls, 'deserialize', cls)
        except (ImportError, AttributeError) as ex:
            errors.append((modname, tname, ex))
    # type_ is never rebound above, so the message shows the full name that
    # was requested.
    raise Exception('Could not find type {}. Tried: {}'.format(type_, errors))
def deserialize(type_, value=None, **kwargs):
    '''Get an object from a text representation.

    A non-string ``type_`` is returned as-is. Otherwise the matching
    deserializer() is looked up (``kwargs`` are forwarded to it, e.g.
    ``known_modules``). That callable is returned itself when ``value`` is
    None, and applied to ``value`` otherwise.
    '''
    if isinstance(type_, str):
        factory = deserializer(type_, **kwargs)
        return factory if value is None else factory(value)
    return type_
2 years ago
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
    '''Return the list of deserialized objects.

    Each entry of ``names`` is resolved with deserialize() and the result is
    called with ``*args`` and ``**kwargs``. The call results are returned in
    the same order as ``names``.
    '''
    return [
        deserialize(type_name, known_modules=known_modules)(*args, **kwargs)
        for type_name in names
    ]