mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 11:12:29 +00:00
Fix multithreading
Multithreading needs pickling to work. Pickling/unpickling didn't work in some situations, like when the environment_agents parameter was left blank. This was due to two reasons: 1) agents and history didn't have a setstate method, and some of their attributes cannot be pickled (generators, sqlite connection) 2) the environment was adding generators (agents) to its state. This fixes the situation by restricting the keys that the environment exports when it pickles, and by adding the set/getstate methods in agents. The resulting pickles should contain enough information to inspect them (history, state values, etc), but very limited.
This commit is contained in:
parent
3526fa29d7
commit
9749f4ca14
@ -1 +1 @@
|
|||||||
0.13.1
|
0.13.3
|
||||||
|
@ -11,8 +11,6 @@ try:
|
|||||||
except NameError:
|
except NameError:
|
||||||
basestring = str
|
basestring = str
|
||||||
|
|
||||||
logging.basicConfig()
|
|
||||||
|
|
||||||
from . import agents
|
from . import agents
|
||||||
from .simulation import *
|
from .simulation import *
|
||||||
from .environment import Environment
|
from .environment import Environment
|
||||||
@ -23,6 +21,9 @@ def main():
|
|||||||
import argparse
|
import argparse
|
||||||
from . import simulation
|
from . import simulation
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info('Running SOIL version: {}'.format(__version__))
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
||||||
parser.add_argument('file', type=str,
|
parser.add_argument('file', type=str,
|
||||||
nargs="?",
|
nargs="?",
|
||||||
@ -62,7 +63,7 @@ def main():
|
|||||||
simulation.run_from_config(args.file,
|
simulation.run_from_config(args.file,
|
||||||
dry_run=args.dry_run,
|
dry_run=args.dry_run,
|
||||||
dump=dump,
|
dump=dump,
|
||||||
parallel=(not args.synchronous and not args.pdb),
|
parallel=(not args.synchronous),
|
||||||
results_dir=args.output)
|
results_dir=args.output)
|
||||||
except Exception:
|
except Exception:
|
||||||
if args.pdb:
|
if args.pdb:
|
||||||
|
@ -24,7 +24,7 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
|
|
||||||
defaults = {}
|
defaults = {}
|
||||||
|
|
||||||
def __init__(self, environment, agent_id=None, state=None,
|
def __init__(self, environment, agent_id, state=None,
|
||||||
name='network_process', interval=None, **state_params):
|
name='network_process', interval=None, **state_params):
|
||||||
# Check for REQUIRED arguments
|
# Check for REQUIRED arguments
|
||||||
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
|
assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. '
|
||||||
@ -34,10 +34,6 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.state_params = state_params
|
self.state_params = state_params
|
||||||
|
|
||||||
# Global parameters
|
|
||||||
self.global_topology = environment.G
|
|
||||||
self.environment_params = environment.environment_params
|
|
||||||
|
|
||||||
# Register agent to environment
|
# Register agent to environment
|
||||||
self.env = environment
|
self.env = environment
|
||||||
|
|
||||||
@ -73,6 +69,18 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
self[k] = v
|
self[k] = v
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_topology(self):
|
||||||
|
return self.env.G
|
||||||
|
|
||||||
|
@property
|
||||||
|
def environment_params(self):
|
||||||
|
return self.env.environment_params
|
||||||
|
|
||||||
|
@environment_params.setter
|
||||||
|
def environment_params(self, value):
|
||||||
|
self.env.environment_params = value
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, tuple):
|
if isinstance(key, tuple):
|
||||||
key, t_step = key
|
key, t_step = key
|
||||||
@ -126,9 +134,6 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
def step(self):
|
def step(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def to_json(self):
|
|
||||||
return json.dumps(self.state)
|
|
||||||
|
|
||||||
def count_agents(self, state_id=None, limit_neighbors=False):
|
def count_agents(self, state_id=None, limit_neighbors=False):
|
||||||
if limit_neighbors:
|
if limit_neighbors:
|
||||||
agents = self.global_topology.neighbors(self.id)
|
agents = self.global_topology.neighbors(self.id)
|
||||||
@ -183,6 +188,26 @@ class BaseAgent(nxsim.BaseAgent):
|
|||||||
def info(self, *args, **kwargs):
|
def info(self, *args, **kwargs):
|
||||||
return self.log(*args, level=logging.INFO, **kwargs)
|
return self.log(*args, level=logging.INFO, **kwargs)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
'''
|
||||||
|
Serializing an agent will lose all its running information (you cannot
|
||||||
|
serialize an iterator), but it keeps the state and link to the environment,
|
||||||
|
so it can be used for inspection and dumping to a file
|
||||||
|
'''
|
||||||
|
state = {}
|
||||||
|
state['id'] = self.id
|
||||||
|
state['environment'] = self.env
|
||||||
|
state['_state'] = self._state
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
'''
|
||||||
|
Get back a serialized agent and try to re-compose it
|
||||||
|
'''
|
||||||
|
self.id = state['id']
|
||||||
|
self._state = state['_state']
|
||||||
|
self.env = state['environment']
|
||||||
|
|
||||||
|
|
||||||
def state(func):
|
def state(func):
|
||||||
'''
|
'''
|
||||||
@ -336,7 +361,7 @@ def serialize_distribution(network_agents, known_modules=[]):
|
|||||||
When serializing an agent distribution, remove the thresholds, in order
|
When serializing an agent distribution, remove the thresholds, in order
|
||||||
to avoid cluttering the YAML definition file.
|
to avoid cluttering the YAML definition file.
|
||||||
'''
|
'''
|
||||||
d = deepcopy(network_agents)
|
d = deepcopy(list(network_agents))
|
||||||
for v in d:
|
for v in d:
|
||||||
if 'threshold' in v:
|
if 'threshold' in v:
|
||||||
del v['threshold']
|
del v['threshold']
|
||||||
|
@ -14,6 +14,14 @@ import nxsim
|
|||||||
|
|
||||||
from . import utils, agents, analysis, history
|
from . import utils, agents, analysis, history
|
||||||
|
|
||||||
|
# These properties will be copied when pickling/unpickling the environment
|
||||||
|
_CONFIG_PROPS = [ 'name',
|
||||||
|
'states',
|
||||||
|
'default_state',
|
||||||
|
'interval',
|
||||||
|
'dry_run',
|
||||||
|
'dir_path',
|
||||||
|
]
|
||||||
|
|
||||||
class Environment(nxsim.NetworkEnvironment):
|
class Environment(nxsim.NetworkEnvironment):
|
||||||
"""
|
"""
|
||||||
@ -320,19 +328,20 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
return G
|
return G
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = self.__dict__.copy()
|
state = {}
|
||||||
|
for prop in _CONFIG_PROPS:
|
||||||
|
state[prop] = self.__dict__[prop]
|
||||||
state['G'] = json_graph.node_link_data(self.G)
|
state['G'] = json_graph.node_link_data(self.G)
|
||||||
state['network_agents'] = agents.serialize_distribution(self.network_agents)
|
state['environment_agents'] = self._env_agents
|
||||||
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
state['history'] = self._history
|
||||||
to_string=True)
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
self.__dict__ = state
|
for prop in _CONFIG_PROPS:
|
||||||
|
self.__dict__[prop] = state[prop]
|
||||||
|
self._env_agents = state['environment_agents']
|
||||||
self.G = json_graph.node_link_graph(state['G'])
|
self.G = json_graph.node_link_graph(state['G'])
|
||||||
self.network_agents = self.calculate_distribution(self._convert_agent_types(self.network_agents))
|
self._history = state['history']
|
||||||
self.environment_agents = self._convert_agent_types(self.environment_agents)
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
SoilEnvironment = Environment
|
SoilEnvironment = Environment
|
||||||
|
@ -38,7 +38,7 @@ class History:
|
|||||||
def db(self):
|
def db(self):
|
||||||
try:
|
try:
|
||||||
self._db.cursor()
|
self._db.cursor()
|
||||||
except sqlite3.ProgrammingError:
|
except (sqlite3.ProgrammingError, AttributeError):
|
||||||
self.db = None # Reset the database
|
self.db = None # Reset the database
|
||||||
return self._db
|
return self._db
|
||||||
|
|
||||||
@ -208,6 +208,16 @@ class History:
|
|||||||
df_p = df_p.reindex(t_steps, method='ffill')
|
df_p = df_p.reindex(t_steps, method='ffill')
|
||||||
return df_p.ffill()
|
return df_p.ffill()
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = dict(**self.__dict__)
|
||||||
|
del state['_db']
|
||||||
|
del state['_dtypes']
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__ = state
|
||||||
|
self._dtypes = {}
|
||||||
|
|
||||||
|
|
||||||
class Records():
|
class Records():
|
||||||
|
|
||||||
|
@ -201,7 +201,7 @@ class Simulation(NetworkSimulation):
|
|||||||
return self.run_trial(*args, **kwargs)
|
return self.run_trial(*args, **kwargs)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
c = ex.__cause__
|
c = ex.__cause__
|
||||||
c.message = ''.join(traceback.format_tb(c.__traceback__)[3:])
|
c.message = ''.join(traceback.format_exception(type(c), c, c.__traceback__)[:])
|
||||||
return c
|
return c
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
|
@ -2,6 +2,7 @@ from unittest import TestCase
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
|
import pickle
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@ -248,12 +249,10 @@ class TestMain(TestCase):
|
|||||||
assert name == 'soil.agents.BaseAgent'
|
assert name == 'soil.agents.BaseAgent'
|
||||||
assert ser == agents.BaseAgent
|
assert ser == agents.BaseAgent
|
||||||
|
|
||||||
class CustomAgent(agents.BaseAgent):
|
|
||||||
pass
|
|
||||||
|
|
||||||
ser, name = utils.serialize(CustomAgent)
|
ser, name = utils.serialize(CustomAgent)
|
||||||
assert name == 'test_main.CustomAgent'
|
assert name == 'test_main.CustomAgent'
|
||||||
assert ser == CustomAgent
|
assert ser == CustomAgent
|
||||||
|
pickle.dumps(ser)
|
||||||
|
|
||||||
def test_serialize_builtin_types(self):
|
def test_serialize_builtin_types(self):
|
||||||
|
|
||||||
@ -269,6 +268,7 @@ class TestMain(TestCase):
|
|||||||
assert ser == 'test_main.CustomAgent'
|
assert ser == 'test_main.CustomAgent'
|
||||||
ser = agents.serialize_type(agents.BaseAgent)
|
ser = agents.serialize_type(agents.BaseAgent)
|
||||||
assert ser == 'BaseAgent'
|
assert ser == 'BaseAgent'
|
||||||
|
pickle.dumps(ser)
|
||||||
|
|
||||||
def test_deserialize_agent_distribution(self):
|
def test_deserialize_agent_distribution(self):
|
||||||
agent_distro = [
|
agent_distro = [
|
||||||
@ -284,6 +284,7 @@ class TestMain(TestCase):
|
|||||||
converted = agents.deserialize_distribution(agent_distro)
|
converted = agents.deserialize_distribution(agent_distro)
|
||||||
assert converted[0]['agent_type'] == agents.CounterModel
|
assert converted[0]['agent_type'] == agents.CounterModel
|
||||||
assert converted[1]['agent_type'] == CustomAgent
|
assert converted[1]['agent_type'] == CustomAgent
|
||||||
|
pickle.dumps(converted)
|
||||||
|
|
||||||
def test_serialize_agent_distribution(self):
|
def test_serialize_agent_distribution(self):
|
||||||
agent_distro = [
|
agent_distro = [
|
||||||
@ -299,6 +300,20 @@ class TestMain(TestCase):
|
|||||||
converted = agents.serialize_distribution(agent_distro)
|
converted = agents.serialize_distribution(agent_distro)
|
||||||
assert converted[0]['agent_type'] == 'CounterModel'
|
assert converted[0]['agent_type'] == 'CounterModel'
|
||||||
assert converted[1]['agent_type'] == 'test_main.CustomAgent'
|
assert converted[1]['agent_type'] == 'test_main.CustomAgent'
|
||||||
|
pickle.dumps(converted)
|
||||||
|
|
||||||
|
def test_pickle_agent_environment(self):
|
||||||
|
env = Environment(name='Test')
|
||||||
|
a = agents.BaseAgent(environment=env, agent_id=25)
|
||||||
|
|
||||||
|
a['key'] = 'test'
|
||||||
|
|
||||||
|
pickled = pickle.dumps(a)
|
||||||
|
recovered = pickle.loads(pickled)
|
||||||
|
|
||||||
|
assert recovered.env.name == 'Test'
|
||||||
|
assert recovered['key'] == 'test'
|
||||||
|
assert recovered['key', 0] == 'test'
|
||||||
|
|
||||||
def test_history(self):
|
def test_history(self):
|
||||||
'''Test storing in and retrieving from history (sqlite)'''
|
'''Test storing in and retrieving from history (sqlite)'''
|
||||||
|
Loading…
Reference in New Issue
Block a user