Mirror of https://github.com/gsi-upm/soil, synced 2025-08-23 19:52:19 +00:00
All tests pass
@@ -14,12 +14,11 @@ except NameError:
logging.basicConfig()
from . import agents
from . import simulation
from . import environment
from .simulation import *
from .environment import Environment
from . import utils
from . import analysis

def main():
import argparse
from . import simulation
@@ -46,11 +45,12 @@ def main():
args = parser.parse_args()

if args.module:
if os.getcwd() not in sys.path:
sys.path.append(os.getcwd())
if args.module:
importlib.import_module(args.module)

logging.info('Loading config file: {}'.format(args.file, args.output))
logging.info('Loading config file: {}'.format(args.file))

try:
dump = []
@@ -64,7 +64,7 @@ def main():
dump=dump,
parallel=(not args.synchronous and not args.pdb),
results_dir=args.output)
except Exception as ex:
except Exception:
if args.pdb:
pdb.post_mortem()
else:
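For context, the entry point above boils down to a call into simulation.run_from_config; a hedged sketch of equivalent programmatic usage (the config file name is a placeholder, and the keyword values mirror the arguments main() passes above):

    from soil import simulation
    # 'my_config.yml' is illustrative; dump/parallel/results_dir follow the CLI handling shown above.
    simulation.run_from_config('my_config.yml',
                               dump=[],
                               parallel=False,
                               results_dir='soil_output')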
@@ -10,7 +10,7 @@ class SISaModel(FSM):

neutral_discontent_infected_prob

neutral_content_spong_prob
neutral_content_spon_prob

neutral_content_infected_prob

@@ -29,27 +29,27 @@ class SISaModel(FSM):
standard_variance
"""

def __init__(self, environment=None, agent_id=0, state=()):
def __init__(self, environment, agent_id=0, state=()):
super().__init__(environment=environment, agent_id=agent_id, state=state)

self.neutral_discontent_spon_prob = np.random.normal(environment.environment_params['neutral_discontent_spon_prob'],
environment.environment_params['standard_variance'])
self.neutral_discontent_infected_prob = np.random.normal(environment.environment_params['neutral_discontent_infected_prob'],
environment.environment_params['standard_variance'])
self.neutral_content_spon_prob = np.random.normal(environment.environment_params['neutral_content_spon_prob'],
environment.environment_params['standard_variance'])
self.neutral_content_infected_prob = np.random.normal(environment.environment_params['neutral_content_infected_prob'],
environment.environment_params['standard_variance'])
self.neutral_discontent_spon_prob = np.random.normal(self.env['neutral_discontent_spon_prob'],
self.env['standard_variance'])
self.neutral_discontent_infected_prob = np.random.normal(self.env['neutral_discontent_infected_prob'],
self.env['standard_variance'])
self.neutral_content_spon_prob = np.random.normal(self.env['neutral_content_spon_prob'],
self.env['standard_variance'])
self.neutral_content_infected_prob = np.random.normal(self.env['neutral_content_infected_prob'],
self.env['standard_variance'])

self.discontent_neutral = np.random.normal(environment.environment_params['discontent_neutral'],
environment.environment_params['standard_variance'])
self.discontent_content = np.random.normal(environment.environment_params['discontent_content'],
environment.environment_params['variance_d_c'])
self.discontent_neutral = np.random.normal(self.env['discontent_neutral'],
self.env['standard_variance'])
self.discontent_content = np.random.normal(self.env['discontent_content'],
self.env['variance_d_c'])

self.content_discontent = np.random.normal(environment.environment_params['content_discontent'],
environment.environment_params['variance_c_d'])
self.content_neutral = np.random.normal(environment.environment_params['content_neutral'],
environment.environment_params['standard_variance'])
self.content_discontent = np.random.normal(self.env['content_discontent'],
self.env['variance_c_d'])
self.content_neutral = np.random.normal(self.env['content_neutral'],
self.env['standard_variance'])

@state
def neutral(self):
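The hunk above replaces direct reads of environment.environment_params with the environment's dictionary-style access. A minimal hedged sketch of the pattern, with an invented parameter name ('my_prob'); 'standard_variance' comes from the docstring above:

    # Sketch only: 'my_prob' is a placeholder key, and the agent is assumed to receive a valid environment.
    class MyModel(FSM):
        def __init__(self, environment, agent_id=0, state=()):
            super().__init__(environment=environment, agent_id=agent_id, state=state)
            # self.env[...] reads the same values as environment.environment_params[...]
            self.my_prob = np.random.normal(self.env['my_prob'],
                                            self.env['standard_variance'])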
@@ -16,23 +16,15 @@ from functools import wraps

from .. import utils, history

agent_types = {}


class MetaAgent(type):
def __init__(cls, name, bases, nmspc):
super(MetaAgent, cls).__init__(name, bases, nmspc)
agent_types[name] = cls


class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
class BaseAgent(nxsim.BaseAgent):
"""
A special simpy BaseAgent that keeps track of its state history.
"""

defaults = {}

def __init__(self, environment=None, agent_id=None, state=None,
def __init__(self, environment, agent_id=None, state=None,
name='network_process', interval=None, **state_params):
# Check for REQUIRED arguments
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
@@ -152,14 +144,18 @@ class BaseAgent(nxsim.BaseAgent, metaclass=MetaAgent):
def count_neighboring_agents(self, state_id=None):
return len(super().get_agents(state_id, limit_neighbors=True))

def get_agents(self, state_id=None, limit_neighbors=False, iterator=False, **kwargs):
def get_agents(self, state_id=None, agent_type=None, limit_neighbors=False, iterator=False, **kwargs):
agents = self.env.agents
if limit_neighbors:
agents = super().get_agents(state_id, limit_neighbors)
else:
agents = filter(lambda x: state_id is None or x.state.get('id', None) == state_id,
self.env.agents)

def matches_all(agent):
if state_id is not None:
if agent.state.get('id', None) != state_id:
return False
if agent_type is not None:
if type(agent) != agent_type:
return False
state = agent.state
for k, v in kwargs.items():
if state.get(k, None) != v:
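With the new signature, agents can be filtered by state id, by class, and by arbitrary state keys at once. A short hedged usage sketch (SomeAgent and the 'sentiment' key are placeholders, not part of this commit):

    # Called from inside an agent; names below are illustrative only.
    infected_neighbors = self.get_agents(state_id=1, limit_neighbors=True)
    matching = self.get_agents(agent_type=SomeAgent, sentiment='positive')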
@@ -219,7 +215,7 @@ def default_state(func):
return func


class MetaFSM(MetaAgent):
class MetaFSM(type):
def __init__(cls, name, bases, nmspc):
super(MetaFSM, cls).__init__(name, bases, nmspc)
states = {}
@@ -328,16 +324,42 @@ def calculate_distribution(network_agents=None,
return network_agents


def _serialize_distribution(network_agents):
d = _convert_agent_types(network_agents,
to_string=True)
def serialize_agent_type(agent_type):
if isinstance(agent_type, str):
return agent_type
type_name = agent_type.__name__
if type_name not in globals():
type_name = utils.name(agent_type)
return type_name


def serialize_distribution(network_agents):
'''
When serializing an agent distribution, remove the thresholds, in order
to avoid cluttering the YAML definition file.
'''
d = deepcopy(network_agents)
for v in d:
if 'threshold' in v:
del v['threshold']
v['agent_type'] = serialize_agent_type(v['agent_type'])
return d


def deserialize_type(agent_type, known_modules=[]):
if not isinstance(agent_type, str):
return agent_type
if agent_type in globals():
agent_type = globals()[agent_type]
else:
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
agent_type = utils.deserializer(agent_type, known_modules=known)
return agent_type


def deserialize_distribution(ind):
d = deepcopy(ind)
for v in d:
v['agent_type'] = deserialize_type(v['agent_type'])
return d


@@ -354,14 +376,9 @@ def _validate_states(states, topology):

def _convert_agent_types(ind, to_string=False):
'''Convenience method to allow specifying agents by class or class name.'''
d = deepcopy(ind)
for v in d:
agent_type = v['agent_type']
if to_string and not isinstance(agent_type, str):
v['agent_type'] = str(agent_type.__name__)
elif not to_string and isinstance(agent_type, str):
v['agent_type'] = agent_types[agent_type]
return d
if to_string:
return serialize_distribution(ind)
return deserialize_distribution(ind)


def _agent_from_distribution(distribution, value=-1):
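A rough, hedged sketch of the round trip these helpers enable when writing and reading a YAML definition (the distribution values below are illustrative; CounterModel is one of soil's bundled agents):

    network_agents = [{'agent_type': CounterModel, 'weight': 1, 'threshold': (0, 1)}]
    as_text = serialize_distribution(network_agents)
    # roughly -> [{'agent_type': 'CounterModel', 'weight': 1}]   (threshold dropped)
    back = deserialize_distribution(as_text)
    # roughly -> [{'agent_type': soil.agents.CounterModel, 'weight': 1}]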
@@ -56,7 +56,7 @@ def read_csv(filename, keys=None, convert_types=False, **kwargs):


def convert_row(row):
row['value'] = utils.convert(row['value'], row['value_type'])
row['value'] = utils.deserialize(row['value_type'], row['value'])
return row

@@ -15,7 +15,7 @@ import nxsim
from . import utils, agents, analysis, history


class SoilEnvironment(nxsim.NetworkEnvironment):
class Environment(nxsim.NetworkEnvironment):
"""
The environment is key in a simulation. It contains the network topology,
a reference to network and environment agents, as well as the environment
@@ -23,7 +23,7 @@ class SoilEnvironment(nxsim.NetworkEnvironment):

The environment parameters and the state of every agent can be accessed
both by using the environment as a dictionary or with the environment's
:meth:`soil.environment.SoilEnvironment.get` method.
:meth:`soil.environment.Environment.get` method.
"""
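As that docstring says, the renamed Environment behaves like a dictionary for parameters and agent state; a brief hedged example (the key name is invented, and the default-value form assumes get follows the usual dict-like signature):

    env['prob_infect']            # dictionary-style read of an environment parameter (made-up key)
    env.get('prob_infect', 0.1)   # assumed to accept a default, as dict.get does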
def __init__(self, name=None,
@@ -49,7 +49,8 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
self.dry_run = dry_run
self.interval = interval
self.dir_path = dir_path or tempfile.mkdtemp('soil-env')
self.get_path()
if not dry_run:
self.get_path()
self._history = history.History(name=self.name if not dry_run else None,
dir_path=self.dir_path)
# Add environment agents first, so their events get
@@ -93,17 +94,35 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
if not network_agents:
return
for ix in self.G.nodes():
agent, state = agents._agent_from_distribution(network_agents)
self.set_agent(ix, agent_type=agent, state=state)
self.init_agent(ix, agent_distribution=network_agents)

def init_agent(self, agent_id, agent_distribution):
node = self.G.nodes[agent_id]
init = False
state = dict(node)

agent_type = None
if 'agent_type' in self.states.get(agent_id, {}):
agent_type = self.states[agent_id]
elif 'agent_type' in node:
agent_type = node['agent_type']
elif 'agent_type' in self.default_state:
agent_type = self.default_state['agent_type']

if agent_type:
agent_type = agents.deserialize_agent_type(agent_type)
else:
agent_type, state = agents._agent_from_distribution(agent_distribution)
return self.set_agent(agent_id, agent_type, state)

def set_agent(self, agent_id, agent_type, state=None):
node = self.G.nodes[agent_id]
defstate = deepcopy(self.default_state)
defstate = deepcopy(self.default_state) or {}
defstate.update(self.states.get(agent_id, {}))
defstate.update(node.get('state', {}))
if state:
defstate.update(state)
state = defstate
state.update(node.get('state', {}))
a = agent_type(environment=self,
agent_id=agent_id,
state=state)
@@ -118,6 +137,10 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
return a

def add_edge(self, agent1, agent2, attrs=None):
if hasattr(agent1, 'id'):
agent1 = agent1.id
if hasattr(agent2, 'id'):
agent2 = agent2.id
return self.G.add_edge(agent1, agent2)

def run(self, *args, **kwargs):
@@ -202,7 +225,7 @@ class SoilEnvironment(nxsim.NetworkEnvironment):

with open(csv_name, 'w') as f:
cr = csv.writer(f)
cr.writerow(('agent_id', 't_step', 'key', 'value', 'value_type'))
cr.writerow(('agent_id', 't_step', 'key', 'value'))
for i in self.history_to_tuples():
cr.writerow(i)

@@ -302,7 +325,6 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
state['network_agents'] = agents._serialize_distribution(self.network_agents)
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
to_string=True)
del state['_queue']
return state

def __setstate__(self, state):
@@ -311,3 +333,6 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
self.network_agents = self.calculate_distribution(self._convert_agent_types(self.network_agents))
self.environment_agents = self._convert_agent_types(self.environment_agents)
return state


SoilEnvironment = Environment
soil/history.py
@@ -17,12 +17,12 @@ class History:
if db_path is None and name:
db_path = os.path.join(dir_path or os.getcwd(),
'{}.db.sqlite'.format(name))
if db_path is None:
db_path = ":memory:"
else:
if db_path:
if backup and os.path.exists(db_path):
newname = db_path + '.backup{}.sqlite'.format(time.time())
os.rename(db_path, newname)
else:
db_path = ":memory:"
self.db_path = db_path

self.db = db_path
@@ -34,12 +34,6 @@ class History:
self._dtypes = {}
self._tups = []

def conversors(self, key):
"""Get the serializer and deserializer for a given key."""
if key not in self._dtypes:
self.read_types()
return self._dtypes[key]

@property
def db(self):
try:
@@ -58,55 +52,88 @@ class History:

@property
def dtypes(self):
self.read_types()
return {k:v[0] for k, v in self._dtypes.items()}

def save_tuples(self, tuples):
'''
Save a series of tuples, converting them to records if necessary
'''
self.save_records(Record(*tup) for tup in tuples)

def save_records(self, records):
with self.db:
for rec in records:
if not isinstance(rec, Record):
rec = Record(*rec)
if rec.key not in self._dtypes:
name = utils.name(rec.value)
serializer = utils.serializer(name)
deserializer = utils.deserializer(name)
self._dtypes[rec.key] = (name, serializer, deserializer)
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (rec.key, name))
self.db.execute("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", (rec.agent_id, rec.t_step, rec.key, rec.value))
'''
Save a collection of records
'''
for record in records:
if not isinstance(record, Record):
record = Record(*record)
self.save_record(*record)

def save_record(self, *args, **kwargs):
self._tups.append(Record(*args, **kwargs))
def save_record(self, agent_id, t_step, key, value):
'''
Save a collection of records to the database.
Database writes are cached.
'''
value = self.convert(key, value)
self._tups.append(Record(agent_id=agent_id,
t_step=t_step,
key=key,
value=value))
if len(self._tups) > 100:
self.flush_cache()

def convert(self, key, value):
"""Get the serialized value for a given key."""
if key not in self._dtypes:
self.read_types()
if key not in self._dtypes:
name = utils.name(value)
serializer = utils.serializer(name)
deserializer = utils.deserializer(name)
self._dtypes[key] = (name, serializer, deserializer)
with self.db:
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
return self._dtypes[key][1](value)

def recover(self, key, value):
"""Get the deserialized value for a given key, and the serialized version."""
if key not in self._dtypes:
self.read_types()
if key not in self._dtypes:
raise ValueError("Unknown datatype for {} and {}".format(key, value))
return self._dtypes[key][2](value)


def flush_cache(self):
'''
Use a cache to save state changes to avoid opening a session for every change.
The cache will be flushed at the end of the simulation, and when history is accessed.
'''
self.save_records(self._tups)
with self.db:
for rec in self._tups:
self.db.execute("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", (rec.agent_id, rec.t_step, rec.key, rec.value))
self._tups = list()
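Taken together: save_record now serializes a value once on the way in (convert), recover deserializes it on the way out, and writes are batched until flush_cache runs. A small hedged sketch of the flow (the key and value are arbitrary examples):

    h = History(name='example')                  # backed by example.db.sqlite in the current directory
    h.save_record(agent_id=0, t_step=1, key='id', value='neutral')
    h.flush_cache()                              # writes the cached tuples to sqlite
    h.recover('id', 'neutral')                   # -> 'neutral', using the stored value_type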
def to_tuples(self):
self.flush_cache()
with self.db:
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
for r in res:
agent_id, t_step, key, value = r
_, _ , des = self.conversors(key)
yield agent_id, t_step, key, des(value)
self.flush_cache()
with self.db:
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
for r in res:
agent_id, t_step, key, value = r
value = self.recover(key, value)
yield agent_id, t_step, key, value

def read_types(self):
with self.db:
res = self.db.execute("select key, value_type from value_types ").fetchall()
for k, v in res:
serializer = utils.serializer(v)
deserializer = utils.deserializer(v)
self._dtypes[k] = (v, serializer, deserializer)
with self.db:
res = self.db.execute("select key, value_type from value_types ").fetchall()
for k, v in res:
serializer = utils.serializer(v)
deserializer = utils.deserializer(v)
self._dtypes[k] = (v, serializer, deserializer)

def __getitem__(self, key):
self.flush_cache()
key = Key(*key)
agent_ids = [key.agent_id] if key.agent_id is not None else []
t_steps = [key.t_step] if key.t_step is not None else []
@@ -176,7 +203,7 @@ class History:
for k, v in self._dtypes.items():
if k in df_p:
dtype, _, deserial = v
df_p[k] = df_p[k].fillna(method='ffill').fillna(deserial()).astype(dtype)
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
if t_steps:
df_p = df_p.reindex(t_steps, method='ffill')
return df_p.ffill()
@@ -12,11 +12,12 @@ import pickle

from nxsim import NetworkSimulation

from . import utils, environment, basestring, agents
from . import utils, basestring, agents
from .environment import Environment
from .utils import logger


class SoilSimulation(NetworkSimulation):
class Simulation(NetworkSimulation):
"""
Subclass of nxsim.NetworkSimulation with three main differences:
1) agent type can be specified by name or by class.
@@ -43,13 +44,47 @@ class SoilSimulation(NetworkSimulation):
'agent_type_1'.
3) if no initial state is given, each node's state will be set
to `{'id': 0}`.

Parameters
----------
name : str, optional
name of the Simulation
topology : networkx.Graph instance, optional
network_params : dict
parameters used to create a topology with networkx, if no topology is given
network_agents : dict
definition of agents to populate the topology with
agent_type : NetworkAgent subclass, optional
Default type of NetworkAgent to use for nodes not specified in network_agents
states : list, optional
List of initial states corresponding to the nodes in the topology. Basic form is a list of integers
whose value indicates the state
dir_path : str, optional
Directory path where to save pickled objects
seed : str, optional
Seed to use for the random generator
num_trials : int, optional
Number of independent simulation runs
max_time : int, optional
Time how long the simulation should run
environment_params : dict, optional
Dictionary of globally-shared environmental parameters
environment_agents: dict, optional
Similar to network_agents. Distribution of Agents that control the environment
environment_class: soil.environment.Environment subclass, optional
Class for the environment. It defaults to soil.environment.Environment
load_module : str, module name, deprecated
If specified, soil will load the content of this module under 'soil.agents.custom'

"""
def __init__(self, name=None, topology=None, network_params=None,
network_agents=None, agent_type=None, states=None,
default_state=None, interval=1, dump=None, dry_run=False,
dir_path=None, num_trials=1, max_time=100,
agent_module=None, load_module=None, seed=None,
environment_agents=None, environment_params=None, **kwargs):
load_module=None, seed=None,
environment_agents=None, environment_params=None,
environment_class=None, **kwargs):

if topology is None:
topology = utils.load_network(network_params,
@@ -70,11 +105,15 @@ class SoilSimulation(NetworkSimulation):
self.dump = dump
self.dry_run = dry_run
self.environment_params = environment_params or {}
self.environment_class = utils.deserialize(environment_class,
known_modules=['soil.environment',]) or Environment

self._loaded_module = None

if load_module:
path = sys.path + [self.dir_path, os.getcwd()]
f, fp, desc = imp.find_module(load_module, path)
imp.load_module('soil.agents.custom', f, fp, desc)
self._loaded_module = imp.load_module('soil.agents.custom', f, fp, desc)

environment_agents = environment_agents or []
self.environment_agents = agents._convert_agent_types(environment_agents)
@@ -128,7 +167,7 @@ class SoilSimulation(NetworkSimulation):
'dir_path': self.dir_path,
})
opts.update(kwargs)
env = environment.SoilEnvironment(**opts)
env = self.environment_class(**opts)
return env

def run_trial(self, trial_id=0, until=None, return_env=True, **opts):
@@ -177,11 +216,18 @@ class SoilSimulation(NetworkSimulation):
pickle.dump(self, f)

def __getstate__(self):
state = self.__dict__.copy()
state = {}
for k, v in self.__dict__.items():
if k[0] != '_':
state[k] = v
state['topology'] = json_graph.node_link_data(self.topology)
state['network_agents'] = agents._serialize_distribution(self.network_agents)
state['network_agents'] = agents.serialize_distribution(self.network_agents)
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
to_string=True)
state['environment_class'] = utils.serialize(self.environment_class,
known_modules=['soil.environment', ])[1] # func, name
if state['load_module'] is None:
del state['load_module']
return state

def __setstate__(self, state):
@@ -189,6 +235,8 @@ class SoilSimulation(NetworkSimulation):
self.topology = json_graph.node_link_graph(state['topology'])
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
self.environment_agents = agents._convert_agent_types(self.environment_agents)
self.environment_class = utils.deserialize(self.environment_class,
known_modules=['soil.environment', ]) # func, name
return state


@@ -197,11 +245,11 @@ def from_config(config):
if len(config) > 1:
raise AttributeError('Provide only one configuration')
config = config[0][0]
sim = SoilSimulation(**config)
sim = Simulation(**config)
return sim


def run_from_config(*configs, results_dir='soil_output', dry_run=False, dump=None, timestamp=False, **kwargs):
def run_from_config(*configs, results_dir='soil_output', dump=None, timestamp=False, **kwargs):
for config_def in configs:
# logger.info("Found {} config(s)".format(len(ls)))
for config, _ in utils.load_config(config_def):
@@ -214,6 +262,8 @@ def run_from_config(*configs, results_dir='soil_output', dry_run=False, dump=Non
else:
sim_folder = name
dir_path = os.path.join(results_dir, sim_folder)
sim = SoilSimulation(dir_path=dir_path, dump=dump, **config)
if dump is not None:
config['dump'] = dump
sim = Simulation(dir_path=dir_path, **config)
logger.info('Dumping results to {} : {}'.format(sim.dir_path, sim.dump))
sim.run_simulation(**kwargs)
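For reference, a minimal hedged example of driving the renamed Simulation class programmatically; the keyword names follow the docstring above, but every value is illustrative:

    from soil.simulation import Simulation

    sim = Simulation(name='example',
                     network_params={'generator': 'complete_graph', 'n': 10},
                     network_agents=[{'agent_type': 'CounterModel', 'weight': 1}],
                     environment_params={'prob_neighbor_spread': 0.1},
                     max_time=20,
                     num_trials=1)
    envs = sim.run_simulation()   # run_simulation is the same entry point used by run_from_config above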
@@ -1,8 +1,9 @@
import os
import ast
import yaml
import logging
import importlib
from time import time
import time
from glob import glob
from random import random
from copy import deepcopy
@@ -62,44 +63,89 @@ def load_config(config):

@contextmanager
def timer(name='task', pre="", function=logger.info, to_object=None):
start = time()
function('{}Starting {} at {}.'.format(pre, name, start))
start = time.time()
function('{}Starting {} at {}.'.format(pre, name,
time.strftime("%X", time.gmtime(start))))
yield start
end = time()
function('{}Finished {} in {} seconds'.format(pre, name, str(end-start)))
end = time.time()
function('{}Finished {} at {} in {} seconds'.format(pre, name,
time.strftime("%X", time.gmtime(end)),
str(end-start)))
if to_object:
to_object.start = start
to_object.end = end
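Usage of the context manager is unchanged; only the timestamps in the log messages become human-readable. For example (the label is arbitrary):

    with timer('loading network'):
        G = utils.load_network(network_params)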
def repr(v):
func = serializer(v)
tname = name(v)
return func(v), tname
builtins = importlib.import_module('builtins')


def name(v):
return type(v).__name__
def name(value, known_modules=[]):
'''Return a name that can be imported, to serialize/deserialize an object'''
if value is None:
return 'None'
if not isinstance(value, type): # Get the class name first
value = type(value)
tname = value.__name__
if hasattr(builtins, tname):
return tname
modname = value.__module__
if modname == '__main__':
return tname
if known_modules and modname in known_modules:
return tname
for mod_name in known_modules:
module = importlib.import_module(mod_name)
if hasattr(module, tname):
return tname
return '{}.{}'.format(modname, tname)


def serializer(type_):
if type_ == 'bool':
return lambda x: "true" if x else ""
if type_ != 'str' and hasattr(builtins, type_):
return repr
return lambda x: x


def deserializer(type_):
try:
# Check if it's a builtin type
module = importlib.import_module('builtins')
cls = getattr(module, type_)
except AttributeError:
# if not, separate module and class
def serialize(v, known_modules=[]):
'''Get a text representation of an object.'''
tname = name(v, known_modules=known_modules)
func = serializer(tname)
return func(v), tname

def deserializer(type_, known_modules=[]):
if type_ == 'str':
return lambda x='': x
if type_ == 'None':
return lambda x=None: None
if hasattr(builtins, type_): # Check if it's a builtin type
cls = getattr(builtins, type_)
return lambda x=None: ast.literal_eval(x) if x is not None else cls()
# Otherwise, see if we can find the module and the class
modules = known_modules or []
options = []

for mod in modules:
options.append((mod, type_))

if '.' in type_: # Fully qualified module
module, type_ = type_.rsplit(".", 1)
module = importlib.import_module(module)
cls = getattr(module, type_)
return cls
options.append ((module, type_))

errors = []
for module, name in options:
try:
module = importlib.import_module(module)
cls = getattr(module, name)
return getattr(cls, 'deserialize', cls)
except (ImportError, AttributeError) as ex:
errors.append((module, name, ex))
raise Exception('Could not find module {}. Tried: {}'.format(type_, errors))


def convert(value, type_):
return deserializer(type_)(value)
def deserialize(type_, value=None, **kwargs):
'''Get an object from a text representation'''
if not isinstance(type_, str):
return type_
des = deserializer(type_, **kwargs)
if value is None:
return des
return des(value)
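A short hedged sketch of the new round trip; int and str behave as shown, and class names resolve through known_modules (the CounterModel line assumes that class defines no custom deserialize attribute):

    text, tname = serialize(25)        # -> ('25', 'int')
    deserialize(tname, text)           # -> 25
    deserialize('str', 'hello')        # -> 'hello'
    deserialize('CounterModel', known_modules=['soil.agents'])   # -> the CounterModel class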
@@ -271,4 +271,4 @@ def main():
parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true')
args = parser.parse_args()

run(name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)
run(name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)