mirror of
https://github.com/gsi-upm/soil
synced 2024-12-22 08:18:13 +00:00
Python3.7, testing and bug fixes
* Upgrade to python3.7 and pandas 0.3.4 because pandas has dropped support for python 3.4 -> There are some API changes in pandas, and I've update the code accordingly. * Set pytest as the default test runner
This commit is contained in:
parent
bd4700567e
commit
2e28b36f6e
@ -1,4 +1,11 @@
|
|||||||
FROM python:3.4-onbuild
|
FROM python:3.7
|
||||||
|
|
||||||
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
|
COPY test-requirements.txt requirements.txt /usr/src/app/
|
||||||
|
RUN pip install --no-cache-dir -r test-requirements.txt -r requirements.txt
|
||||||
|
|
||||||
|
COPY ./ /usr/src/app
|
||||||
|
|
||||||
RUN pip install '.[web]'
|
RUN pip install '.[web]'
|
||||||
|
|
||||||
|
@ -2,6 +2,8 @@ version: '3'
|
|||||||
services:
|
services:
|
||||||
dev:
|
dev:
|
||||||
build: .
|
build: .
|
||||||
|
environment:
|
||||||
|
PYTHONDONTWRITEBYTECODE: 1
|
||||||
volumes:
|
volumes:
|
||||||
- .:/usr/src/app
|
- .:/usr/src/app
|
||||||
tty: true
|
tty: true
|
||||||
|
@ -4,6 +4,8 @@ from random import random, shuffle
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import other_module
|
||||||
|
|
||||||
|
|
||||||
class CityPubs(Environment):
|
class CityPubs(Environment):
|
||||||
'''Environment with Pubs'''
|
'''Environment with Pubs'''
|
||||||
|
4
setup.cfg
Normal file
4
setup.cfg
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
[aliases]
|
||||||
|
test=pytest
|
||||||
|
[tool:pytest]
|
||||||
|
addopts = --verbose
|
@ -1 +1 @@
|
|||||||
0.12.0
|
0.13.0
|
@ -16,7 +16,7 @@ class SentimentCorrelationModel(BaseAgent):
|
|||||||
disgust_prob
|
disgust_prob
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, environment=None, agent_id=0, state=()):
|
def __init__(self, environment, agent_id=0, state=()):
|
||||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||||
self.outside_effects_prob = environment.environment_params['outside_effects_prob']
|
self.outside_effects_prob = environment.environment_params['outside_effects_prob']
|
||||||
self.anger_prob = environment.environment_params['anger_prob']
|
self.anger_prob = environment.environment_params['anger_prob']
|
||||||
|
@ -324,15 +324,14 @@ def calculate_distribution(network_agents=None,
|
|||||||
return network_agents
|
return network_agents
|
||||||
|
|
||||||
|
|
||||||
def serialize_agent_type(agent_type):
|
def serialize_type(agent_type, known_modules=[], **kwargs):
|
||||||
if isinstance(agent_type, str):
|
if isinstance(agent_type, str):
|
||||||
return agent_type
|
return agent_type
|
||||||
type_name = agent_type.__name__
|
known_modules += ['soil.agents']
|
||||||
if type_name not in globals():
|
return utils.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
||||||
type_name = utils.name(agent_type)
|
|
||||||
return type_name
|
|
||||||
|
|
||||||
def serialize_distribution(network_agents):
|
|
||||||
|
def serialize_distribution(network_agents, known_modules=[]):
|
||||||
'''
|
'''
|
||||||
When serializing an agent distribution, remove the thresholds, in order
|
When serializing an agent distribution, remove the thresholds, in order
|
||||||
to avoid cluttering the YAML definition file.
|
to avoid cluttering the YAML definition file.
|
||||||
@ -341,25 +340,23 @@ def serialize_distribution(network_agents):
|
|||||||
for v in d:
|
for v in d:
|
||||||
if 'threshold' in v:
|
if 'threshold' in v:
|
||||||
del v['threshold']
|
del v['threshold']
|
||||||
v['agent_type'] = serialize_agent_type(v['agent_type'])
|
v['agent_type'] = serialize_type(v['agent_type'],
|
||||||
|
known_modules=known_modules)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def deserialize_type(agent_type, known_modules=[]):
|
def deserialize_type(agent_type, known_modules=[]):
|
||||||
if not isinstance(agent_type, str):
|
if not isinstance(agent_type, str):
|
||||||
return agent_type
|
return agent_type
|
||||||
if agent_type in globals():
|
|
||||||
agent_type = globals()[agent_type]
|
|
||||||
else:
|
|
||||||
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
||||||
agent_type = utils.deserializer(agent_type, known_modules=known)
|
agent_type = utils.deserializer(agent_type, known_modules=known)
|
||||||
return agent_type
|
return agent_type
|
||||||
|
|
||||||
|
|
||||||
def deserialize_distribution(ind):
|
def deserialize_distribution(ind, **kwargs):
|
||||||
d = deepcopy(ind)
|
d = deepcopy(ind)
|
||||||
for v in d:
|
for v in d:
|
||||||
v['agent_type'] = deserialize_type(v['agent_type'])
|
v['agent_type'] = deserialize_type(v['agent_type'], **kwargs)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
@ -374,11 +371,11 @@ def _validate_states(states, topology):
|
|||||||
return states
|
return states
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_types(ind, to_string=False):
|
def _convert_agent_types(ind, to_string=False, **kwargs):
|
||||||
'''Convenience method to allow specifying agents by class or class name.'''
|
'''Convenience method to allow specifying agents by class or class name.'''
|
||||||
if to_string:
|
if to_string:
|
||||||
return serialize_distribution(ind)
|
return serialize_distribution(ind, **kwargs)
|
||||||
return deserialize_distribution(ind)
|
return deserialize_distribution(ind, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _agent_from_distribution(distribution, value=-1):
|
def _agent_from_distribution(distribution, value=-1):
|
||||||
|
@ -123,7 +123,7 @@ def get_count(df, *keys):
|
|||||||
df = df[list(keys)]
|
df = df[list(keys)]
|
||||||
counts = pd.DataFrame()
|
counts = pd.DataFrame()
|
||||||
for key in df.columns.levels[0]:
|
for key in df.columns.levels[0]:
|
||||||
g = df[key].apply(pd.Series.value_counts, axis=1).fillna(0)
|
g = df[[key]].apply(pd.Series.value_counts, axis=1).fillna(0)
|
||||||
for value, series in g.iteritems():
|
for value, series in g.iteritems():
|
||||||
counts[key, value] = series
|
counts[key, value] = series
|
||||||
counts.columns = pd.MultiIndex.from_tuples(counts.columns)
|
counts.columns = pd.MultiIndex.from_tuples(counts.columns)
|
||||||
|
@ -110,7 +110,7 @@ class Environment(nxsim.NetworkEnvironment):
|
|||||||
agent_type = self.default_state['agent_type']
|
agent_type = self.default_state['agent_type']
|
||||||
|
|
||||||
if agent_type:
|
if agent_type:
|
||||||
agent_type = agents.deserialize_agent_type(agent_type)
|
agent_type = agents.deserialize_type(agent_type)
|
||||||
else:
|
else:
|
||||||
agent_type, state = agents._agent_from_distribution(agent_distribution)
|
agent_type, state = agents._agent_from_distribution(agent_distribution)
|
||||||
return self.set_agent(agent_id, agent_type, state)
|
return self.set_agent(agent_id, agent_type, state)
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import copy
|
import copy
|
||||||
from collections import UserDict, Iterable, namedtuple
|
from collections import UserDict, namedtuple
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import imp
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
import yaml
|
import yaml
|
||||||
|
import traceback
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
@ -78,6 +79,7 @@ class Simulation(NetworkSimulation):
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name=None, topology=None, network_params=None,
|
def __init__(self, name=None, topology=None, network_params=None,
|
||||||
network_agents=None, agent_type=None, states=None,
|
network_agents=None, agent_type=None, states=None,
|
||||||
default_state=None, interval=1, dump=None, dry_run=False,
|
default_state=None, interval=1, dump=None, dry_run=False,
|
||||||
@ -104,23 +106,21 @@ class Simulation(NetworkSimulation):
|
|||||||
self.seed = str(seed) or str(time.time())
|
self.seed = str(seed) or str(time.time())
|
||||||
self.dump = dump
|
self.dump = dump
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
|
|
||||||
|
sys.path += [self.dir_path, os.getcwd()]
|
||||||
|
|
||||||
self.environment_params = environment_params or {}
|
self.environment_params = environment_params or {}
|
||||||
self.environment_class = utils.deserialize(environment_class,
|
self.environment_class = utils.deserialize(environment_class,
|
||||||
known_modules=['soil.environment',]) or Environment
|
known_modules=['soil.environment', ]) or Environment
|
||||||
|
|
||||||
self._loaded_module = None
|
|
||||||
|
|
||||||
if load_module:
|
|
||||||
path = sys.path + [self.dir_path, os.getcwd()]
|
|
||||||
f, fp, desc = imp.find_module(load_module, path)
|
|
||||||
self._loaded_module = imp.load_module('soil.agents.custom', f, fp, desc)
|
|
||||||
|
|
||||||
environment_agents = environment_agents or []
|
environment_agents = environment_agents or []
|
||||||
self.environment_agents = agents._convert_agent_types(environment_agents)
|
self.environment_agents = agents._convert_agent_types(environment_agents,
|
||||||
|
known_modules=[self.load_module])
|
||||||
|
|
||||||
distro = agents.calculate_distribution(network_agents,
|
distro = agents.calculate_distribution(network_agents,
|
||||||
agent_type)
|
agent_type)
|
||||||
self.network_agents = agents._convert_agent_types(distro)
|
self.network_agents = agents._convert_agent_types(distro,
|
||||||
|
known_modules=[self.load_module])
|
||||||
|
|
||||||
self.states = agents._validate_states(states,
|
self.states = agents._validate_states(states,
|
||||||
self.topology)
|
self.topology)
|
||||||
@ -136,13 +136,17 @@ class Simulation(NetworkSimulation):
|
|||||||
p = Pool()
|
p = Pool()
|
||||||
with utils.timer('simulation {}'.format(self.name)):
|
with utils.timer('simulation {}'.format(self.name)):
|
||||||
if parallel:
|
if parallel:
|
||||||
func = partial(self.run_trial, dry_run=dry_run or self.dry_run,
|
func = partial(self.run_trial_exceptions, dry_run=dry_run or self.dry_run,
|
||||||
return_env=not parallel, **kwargs)
|
return_env=True,
|
||||||
|
**kwargs)
|
||||||
for i in p.imap_unordered(func, range(self.num_trials)):
|
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||||
|
if isinstance(i, Exception):
|
||||||
|
logger.error('Trial failed:\n\t{}'.format(i.message))
|
||||||
|
continue
|
||||||
yield i
|
yield i
|
||||||
else:
|
else:
|
||||||
for i in range(self.num_trials):
|
for i in range(self.num_trials):
|
||||||
yield self.run_trial(i, dry_run=dry_run or self.dry_run, **kwargs)
|
yield self.run_trial(i, dry_run = dry_run or self.dry_run, **kwargs)
|
||||||
if not (dry_run or self.dry_run):
|
if not (dry_run or self.dry_run):
|
||||||
logger.info('Dumping results to {}'.format(self.dir_path))
|
logger.info('Dumping results to {}'.format(self.dir_path))
|
||||||
self.dump_pickle(self.dir_path)
|
self.dump_pickle(self.dir_path)
|
||||||
@ -150,9 +154,9 @@ class Simulation(NetworkSimulation):
|
|||||||
else:
|
else:
|
||||||
logger.info('NOT dumping results')
|
logger.info('NOT dumping results')
|
||||||
|
|
||||||
def get_env(self, trial_id=0, **kwargs):
|
def get_env(self, trial_id = 0, **kwargs):
|
||||||
opts = self.environment_params.copy()
|
opts=self.environment_params.copy()
|
||||||
env_name = '{}_trial_{}'.format(self.name, trial_id)
|
env_name='{}_trial_{}'.format(self.name, trial_id)
|
||||||
opts.update({
|
opts.update({
|
||||||
'name': env_name,
|
'name': env_name,
|
||||||
'topology': self.topology.copy(),
|
'topology': self.topology.copy(),
|
||||||
@ -167,10 +171,10 @@ class Simulation(NetworkSimulation):
|
|||||||
'dir_path': self.dir_path,
|
'dir_path': self.dir_path,
|
||||||
})
|
})
|
||||||
opts.update(kwargs)
|
opts.update(kwargs)
|
||||||
env = self.environment_class(**opts)
|
env=self.environment_class(**opts)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def run_trial(self, trial_id=0, until=None, return_env=True, **opts):
|
def run_trial(self, trial_id = 0, until = None, return_env = True, **opts):
|
||||||
"""Run a single trial of the simulation
|
"""Run a single trial of the simulation
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -178,16 +182,27 @@ class Simulation(NetworkSimulation):
|
|||||||
trial_id : int
|
trial_id : int
|
||||||
"""
|
"""
|
||||||
# Set-up trial environment and graph
|
# Set-up trial environment and graph
|
||||||
until = until or self.max_time
|
until=until or self.max_time
|
||||||
env = self.get_env(trial_id=trial_id, **opts)
|
env=self.get_env(trial_id = trial_id, **opts)
|
||||||
# Set up agents on nodes
|
# Set up agents on nodes
|
||||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
env.run(until)
|
env.run(until)
|
||||||
if self.dump and not self.dry_run:
|
if self.dump and not self.dry_run:
|
||||||
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
|
||||||
env.dump(formats=self.dump)
|
env.dump(formats = self.dump)
|
||||||
if return_env:
|
if return_env:
|
||||||
return env
|
return env
|
||||||
|
def run_trial_exceptions(self, *args, **kwargs):
|
||||||
|
'''
|
||||||
|
A wrapper for run_trial that catches exceptions and returns them.
|
||||||
|
It is meant for async simulations
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
return self.run_trial(*args, **kwargs)
|
||||||
|
except Exception as ex:
|
||||||
|
c = ex.__cause__
|
||||||
|
c.message = ''.join(traceback.format_tb(c.__traceback__)[3:])
|
||||||
|
return c
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
return self.__getstate__()
|
return self.__getstate__()
|
||||||
@ -195,48 +210,53 @@ class Simulation(NetworkSimulation):
|
|||||||
def to_yaml(self):
|
def to_yaml(self):
|
||||||
return yaml.dump(self.to_dict())
|
return yaml.dump(self.to_dict())
|
||||||
|
|
||||||
def dump_yaml(self, dir_path=None, file_name=None):
|
def dump_yaml(self, dir_path = None, file_name = None):
|
||||||
dir_path = dir_path or self.dir_path
|
dir_path=dir_path or self.dir_path
|
||||||
if not os.path.exists(dir_path):
|
if not os.path.exists(dir_path):
|
||||||
os.makedirs(dir_path)
|
os.makedirs(dir_path)
|
||||||
if not file_name:
|
if not file_name:
|
||||||
file_name = os.path.join(dir_path,
|
file_name=os.path.join(dir_path,
|
||||||
'{}.dumped.yml'.format(self.name))
|
'{}.dumped.yml'.format(self.name))
|
||||||
with open(file_name, 'w') as f:
|
with open(file_name, 'w') as f:
|
||||||
f.write(self.to_yaml())
|
f.write(self.to_yaml())
|
||||||
|
|
||||||
def dump_pickle(self, dir_path=None, pickle_name=None):
|
def dump_pickle(self, dir_path = None, pickle_name = None):
|
||||||
dir_path = dir_path or self.dir_path
|
dir_path=dir_path or self.dir_path
|
||||||
if not os.path.exists(dir_path):
|
if not os.path.exists(dir_path):
|
||||||
os.makedirs(dir_path)
|
os.makedirs(dir_path)
|
||||||
if not pickle_name:
|
if not pickle_name:
|
||||||
pickle_name = os.path.join(dir_path,
|
pickle_name=os.path.join(dir_path,
|
||||||
'{}.simulation.pickle'.format(self.name))
|
'{}.simulation.pickle'.format(self.name))
|
||||||
with open(pickle_name, 'wb') as f:
|
with open(pickle_name, 'wb') as f:
|
||||||
pickle.dump(self, f)
|
pickle.dump(self, f)
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
state = {}
|
state={}
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if k[0] != '_':
|
if k[0] != '_':
|
||||||
state[k] = v
|
state[k]=v
|
||||||
state['topology'] = json_graph.node_link_data(self.topology)
|
state['topology']=json_graph.node_link_data(self.topology)
|
||||||
state['network_agents'] = agents.serialize_distribution(self.network_agents)
|
state['network_agents']=agents.serialize_distribution(self.network_agents,
|
||||||
state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
|
known_modules = [])
|
||||||
to_string=True)
|
state['environment_agents']=agents.serialize_distribution(self.environment_agents,
|
||||||
state['environment_class'] = utils.serialize(self.environment_class,
|
known_modules = [])
|
||||||
known_modules=['soil.environment', ])[1] # func, name
|
state['environment_class']=utils.serialize(self.environment_class,
|
||||||
|
known_modules=['soil.environment'])[1] # func, name
|
||||||
if state['load_module'] is None:
|
if state['load_module'] is None:
|
||||||
del state['load_module']
|
del state['load_module']
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def __setstate__(self, state):
|
def __setstate__(self, state):
|
||||||
self.__dict__ = state
|
self.__dict__ = state
|
||||||
|
self.load_module = getattr(self, 'load_module', None)
|
||||||
|
if self.dir_path not in sys.path:
|
||||||
|
sys.path += [self.dir_path, os.getcwd()]
|
||||||
self.topology = json_graph.node_link_graph(state['topology'])
|
self.topology = json_graph.node_link_graph(state['topology'])
|
||||||
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
||||||
self.environment_agents = agents._convert_agent_types(self.environment_agents)
|
self.environment_agents = agents._convert_agent_types(self.environment_agents,
|
||||||
|
known_modules=[self.load_module])
|
||||||
self.environment_class = utils.deserialize(self.environment_class,
|
self.environment_class = utils.deserialize(self.environment_class,
|
||||||
known_modules=['soil.environment', ]) # func, name
|
known_modules=[self.load_module, 'soil.environment', ]) # func, name
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,8 +92,10 @@ def name(value, known_modules=[]):
|
|||||||
return tname
|
return tname
|
||||||
if known_modules and modname in known_modules:
|
if known_modules and modname in known_modules:
|
||||||
return tname
|
return tname
|
||||||
for mod_name in known_modules:
|
for kmod in known_modules:
|
||||||
module = importlib.import_module(mod_name)
|
if not kmod:
|
||||||
|
continue
|
||||||
|
module = importlib.import_module(kmod)
|
||||||
if hasattr(module, tname):
|
if hasattr(module, tname):
|
||||||
return tname
|
return tname
|
||||||
return '{}.{}'.format(modname, tname)
|
return '{}.{}'.format(modname, tname)
|
||||||
@ -124,6 +126,7 @@ def deserializer(type_, known_modules=[]):
|
|||||||
options = []
|
options = []
|
||||||
|
|
||||||
for mod in modules:
|
for mod in modules:
|
||||||
|
if mod:
|
||||||
options.append((mod, type_))
|
options.append((mod, type_))
|
||||||
|
|
||||||
if '.' in type_: # Fully qualified module
|
if '.' in type_: # Fully qualified module
|
||||||
@ -131,14 +134,14 @@ def deserializer(type_, known_modules=[]):
|
|||||||
options.append ((module, type_))
|
options.append ((module, type_))
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
for module, name in options:
|
for modname, tname in options:
|
||||||
try:
|
try:
|
||||||
module = importlib.import_module(module)
|
module = importlib.import_module(modname)
|
||||||
cls = getattr(module, name)
|
cls = getattr(module, tname)
|
||||||
return getattr(cls, 'deserialize', cls)
|
return getattr(cls, 'deserialize', cls)
|
||||||
except (ImportError, AttributeError) as ex:
|
except (ImportError, AttributeError) as ex:
|
||||||
errors.append((module, name, ex))
|
errors.append((modname, tname, ex))
|
||||||
raise Exception('Could not find module {}. Tried: {}'.format(type_, errors))
|
raise Exception('Could not find type {}. Tried: {}'.format(type_, errors))
|
||||||
|
|
||||||
|
|
||||||
def deserialize(type_, value=None, **kwargs):
|
def deserialize(type_, value=None, **kwargs):
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
pytest
|
@ -129,7 +129,7 @@ class TestHistory(TestCase):
|
|||||||
backuppaths = glob(db_path + '.backup*.sqlite')
|
backuppaths = glob(db_path + '.backup*.sqlite')
|
||||||
assert len(backuppaths) == 1
|
assert len(backuppaths) == 1
|
||||||
backuppath = backuppaths[0]
|
backuppath = backuppaths[0]
|
||||||
assert newhistory._db_path == h._db_path
|
assert newhistory.db_path == h.db_path
|
||||||
assert os.path.exists(backuppath)
|
assert os.path.exists(backuppath)
|
||||||
assert not len(newhistory[None, None, None])
|
assert not len(newhistory[None, None, None])
|
||||||
|
|
||||||
|
@ -12,6 +12,12 @@ from soil import simulation, Environment, agents, utils, history
|
|||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
EXAMPLES = join(ROOT, '..', 'examples')
|
||||||
|
|
||||||
|
|
||||||
|
class CustomAgent(agents.BaseAgent):
|
||||||
|
def step(self):
|
||||||
|
self.state['neighbors'] = self.count_agents(state_id=0,
|
||||||
|
limit_neighbors=True)
|
||||||
|
|
||||||
class TestMain(TestCase):
|
class TestMain(TestCase):
|
||||||
|
|
||||||
def test_load_graph(self):
|
def test_load_graph(self):
|
||||||
@ -125,10 +131,6 @@ class TestMain(TestCase):
|
|||||||
|
|
||||||
def test_custom_agent(self):
|
def test_custom_agent(self):
|
||||||
"""Allow for search of neighbors with a certain state_id"""
|
"""Allow for search of neighbors with a certain state_id"""
|
||||||
class CustomAgent(agents.BaseAgent):
|
|
||||||
def step(self):
|
|
||||||
self.state['neighbors'] = self.count_agents(state_id=0,
|
|
||||||
limit_neighbors=True)
|
|
||||||
config = {
|
config = {
|
||||||
'dry_run': True,
|
'dry_run': True,
|
||||||
'network_params': {
|
'network_params': {
|
||||||
@ -261,6 +263,13 @@ class TestMain(TestCase):
|
|||||||
des = utils.deserialize(name, ser)
|
des = utils.deserialize(name, ser)
|
||||||
assert i == des
|
assert i == des
|
||||||
|
|
||||||
|
def test_serialize_agent_type(self):
|
||||||
|
'''A class from soil.agents should be serialized without the module part'''
|
||||||
|
ser = agents.serialize_type(CustomAgent)
|
||||||
|
assert ser == 'test_main.CustomAgent'
|
||||||
|
ser = agents.serialize_type(agents.BaseAgent)
|
||||||
|
assert ser == 'BaseAgent'
|
||||||
|
|
||||||
def test_deserialize_agent_distribution(self):
|
def test_deserialize_agent_distribution(self):
|
||||||
agent_distro = [
|
agent_distro = [
|
||||||
{
|
{
|
||||||
@ -268,13 +277,13 @@ class TestMain(TestCase):
|
|||||||
'weight': 1
|
'weight': 1
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'agent_type': 'BaseAgent',
|
'agent_type': 'test_main.CustomAgent',
|
||||||
'weight': 2
|
'weight': 2
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
converted = agents.deserialize_distribution(agent_distro)
|
converted = agents.deserialize_distribution(agent_distro)
|
||||||
assert converted[0]['agent_type'] == agents.CounterModel
|
assert converted[0]['agent_type'] == agents.CounterModel
|
||||||
assert converted[1]['agent_type'] == agents.BaseAgent
|
assert converted[1]['agent_type'] == CustomAgent
|
||||||
|
|
||||||
def test_serialize_agent_distribution(self):
|
def test_serialize_agent_distribution(self):
|
||||||
agent_distro = [
|
agent_distro = [
|
||||||
@ -283,13 +292,13 @@ class TestMain(TestCase):
|
|||||||
'weight': 1
|
'weight': 1
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'agent_type': agents.BaseAgent,
|
'agent_type': CustomAgent,
|
||||||
'weight': 2
|
'weight': 2
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
converted = agents.serialize_distribution(agent_distro)
|
converted = agents.serialize_distribution(agent_distro)
|
||||||
assert converted[0]['agent_type'] == 'CounterModel'
|
assert converted[0]['agent_type'] == 'CounterModel'
|
||||||
assert converted[1]['agent_type'] == 'BaseAgent'
|
assert converted[1]['agent_type'] == 'test_main.CustomAgent'
|
||||||
|
|
||||||
def test_history(self):
|
def test_history(self):
|
||||||
'''Test storing in and retrieving from history (sqlite)'''
|
'''Test storing in and retrieving from history (sqlite)'''
|
||||||
|
Loading…
Reference in New Issue
Block a user