mirror of
				https://github.com/gsi-upm/soil
				synced 2025-11-04 01:18:17 +00:00 
			
		
		
		
	Python3.7, testing and bug fixes
* Upgrade to python3.7 and pandas 0.3.4 because pandas has dropped support for python 3.4 -> There are some API changes in pandas, and I've update the code accordingly. * Set pytest as the default test runner
This commit is contained in:
		@@ -1,4 +1,11 @@
 | 
			
		||||
FROM python:3.4-onbuild
 | 
			
		||||
FROM python:3.7
 | 
			
		||||
 | 
			
		||||
WORKDIR /usr/src/app
 | 
			
		||||
 | 
			
		||||
COPY test-requirements.txt requirements.txt /usr/src/app/
 | 
			
		||||
RUN pip install --no-cache-dir -r test-requirements.txt -r requirements.txt
 | 
			
		||||
 | 
			
		||||
COPY ./ /usr/src/app
 | 
			
		||||
 | 
			
		||||
RUN pip install '.[web]'
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,8 @@ version: '3'
 | 
			
		||||
services:
 | 
			
		||||
  dev:
 | 
			
		||||
    build: .
 | 
			
		||||
    environment:
 | 
			
		||||
      PYTHONDONTWRITEBYTECODE: 1
 | 
			
		||||
    volumes:
 | 
			
		||||
      - .:/usr/src/app
 | 
			
		||||
    tty: true
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,8 @@ from random import random, shuffle
 | 
			
		||||
from itertools import islice
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
import other_module
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CityPubs(Environment):
 | 
			
		||||
    '''Environment with Pubs'''
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										4
									
								
								setup.cfg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								setup.cfg
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
			
		||||
[aliases]
 | 
			
		||||
test=pytest
 | 
			
		||||
[tool:pytest]
 | 
			
		||||
addopts = --verbose
 | 
			
		||||
@@ -1 +1 @@
 | 
			
		||||
0.12.0
 | 
			
		||||
0.13.0
 | 
			
		||||
@@ -16,7 +16,7 @@ class SentimentCorrelationModel(BaseAgent):
 | 
			
		||||
        disgust_prob
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, environment=None, agent_id=0, state=()):
 | 
			
		||||
    def __init__(self, environment, agent_id=0, state=()):
 | 
			
		||||
        super().__init__(environment=environment, agent_id=agent_id, state=state)
 | 
			
		||||
        self.outside_effects_prob = environment.environment_params['outside_effects_prob']
 | 
			
		||||
        self.anger_prob = environment.environment_params['anger_prob']
 | 
			
		||||
 
 | 
			
		||||
@@ -324,15 +324,14 @@ def calculate_distribution(network_agents=None,
 | 
			
		||||
    return network_agents
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def serialize_agent_type(agent_type):
 | 
			
		||||
def serialize_type(agent_type, known_modules=[], **kwargs):
 | 
			
		||||
    if isinstance(agent_type, str):
 | 
			
		||||
        return agent_type
 | 
			
		||||
    type_name = agent_type.__name__
 | 
			
		||||
    if type_name not in globals():
 | 
			
		||||
        type_name = utils.name(agent_type)
 | 
			
		||||
    return type_name
 | 
			
		||||
    known_modules += ['soil.agents']
 | 
			
		||||
    return utils.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class
 | 
			
		||||
 | 
			
		||||
def serialize_distribution(network_agents):
 | 
			
		||||
 | 
			
		||||
def serialize_distribution(network_agents, known_modules=[]):
 | 
			
		||||
    '''
 | 
			
		||||
    When serializing an agent distribution, remove the thresholds, in order
 | 
			
		||||
    to avoid cluttering the YAML definition file.
 | 
			
		||||
@@ -341,25 +340,23 @@ def serialize_distribution(network_agents):
 | 
			
		||||
    for v in d:
 | 
			
		||||
        if 'threshold' in v:
 | 
			
		||||
            del v['threshold']
 | 
			
		||||
        v['agent_type'] = serialize_agent_type(v['agent_type'])
 | 
			
		||||
        v['agent_type'] = serialize_type(v['agent_type'],
 | 
			
		||||
                                         known_modules=known_modules)
 | 
			
		||||
    return d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deserialize_type(agent_type, known_modules=[]):
 | 
			
		||||
    if not isinstance(agent_type, str):
 | 
			
		||||
        return agent_type
 | 
			
		||||
    if agent_type in globals():
 | 
			
		||||
        agent_type = globals()[agent_type]
 | 
			
		||||
    else:
 | 
			
		||||
    known = known_modules + ['soil.agents', 'soil.agents.custom' ]
 | 
			
		||||
    agent_type = utils.deserializer(agent_type, known_modules=known)
 | 
			
		||||
    return agent_type
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deserialize_distribution(ind):
 | 
			
		||||
def deserialize_distribution(ind, **kwargs):
 | 
			
		||||
    d = deepcopy(ind)
 | 
			
		||||
    for v in d:
 | 
			
		||||
        v['agent_type'] = deserialize_type(v['agent_type'])
 | 
			
		||||
        v['agent_type'] = deserialize_type(v['agent_type'], **kwargs)
 | 
			
		||||
    return d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -374,11 +371,11 @@ def _validate_states(states, topology):
 | 
			
		||||
    return states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _convert_agent_types(ind, to_string=False):
 | 
			
		||||
def _convert_agent_types(ind, to_string=False, **kwargs):
 | 
			
		||||
    '''Convenience method to allow specifying agents by class or class name.'''
 | 
			
		||||
    if to_string:
 | 
			
		||||
        return serialize_distribution(ind)
 | 
			
		||||
    return deserialize_distribution(ind)
 | 
			
		||||
        return serialize_distribution(ind, **kwargs)
 | 
			
		||||
    return deserialize_distribution(ind, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _agent_from_distribution(distribution, value=-1):
 | 
			
		||||
 
 | 
			
		||||
@@ -123,7 +123,7 @@ def get_count(df, *keys):
 | 
			
		||||
        df = df[list(keys)]
 | 
			
		||||
    counts = pd.DataFrame()
 | 
			
		||||
    for key in df.columns.levels[0]:
 | 
			
		||||
        g = df[key].apply(pd.Series.value_counts, axis=1).fillna(0)
 | 
			
		||||
        g = df[[key]].apply(pd.Series.value_counts, axis=1).fillna(0)
 | 
			
		||||
        for value, series in g.iteritems():
 | 
			
		||||
            counts[key, value] = series
 | 
			
		||||
    counts.columns = pd.MultiIndex.from_tuples(counts.columns)
 | 
			
		||||
 
 | 
			
		||||
@@ -110,7 +110,7 @@ class Environment(nxsim.NetworkEnvironment):
 | 
			
		||||
            agent_type = self.default_state['agent_type']
 | 
			
		||||
 | 
			
		||||
        if agent_type:
 | 
			
		||||
            agent_type = agents.deserialize_agent_type(agent_type)
 | 
			
		||||
            agent_type = agents.deserialize_type(agent_type)
 | 
			
		||||
        else:
 | 
			
		||||
            agent_type, state = agents._agent_from_distribution(agent_distribution)
 | 
			
		||||
        return self.set_agent(agent_id, agent_type, state)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@ import os
 | 
			
		||||
import pandas as pd
 | 
			
		||||
import sqlite3
 | 
			
		||||
import copy
 | 
			
		||||
from collections import UserDict, Iterable, namedtuple
 | 
			
		||||
from collections import UserDict, namedtuple
 | 
			
		||||
 | 
			
		||||
from . import utils
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,9 @@
 | 
			
		||||
import os
 | 
			
		||||
import time
 | 
			
		||||
import imp
 | 
			
		||||
import importlib
 | 
			
		||||
import sys
 | 
			
		||||
import yaml
 | 
			
		||||
import traceback
 | 
			
		||||
import networkx as nx
 | 
			
		||||
from networkx.readwrite import json_graph
 | 
			
		||||
from multiprocessing import Pool
 | 
			
		||||
@@ -78,6 +79,7 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name=None, topology=None, network_params=None,
 | 
			
		||||
                 network_agents=None, agent_type=None, states=None,
 | 
			
		||||
                 default_state=None, interval=1, dump=None, dry_run=False,
 | 
			
		||||
@@ -104,23 +106,21 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
        self.seed = str(seed) or str(time.time())
 | 
			
		||||
        self.dump = dump
 | 
			
		||||
        self.dry_run = dry_run
 | 
			
		||||
 | 
			
		||||
        sys.path += [self.dir_path, os.getcwd()]
 | 
			
		||||
 | 
			
		||||
        self.environment_params = environment_params or {}
 | 
			
		||||
        self.environment_class = utils.deserialize(environment_class,
 | 
			
		||||
                                                   known_modules=['soil.environment',]) or Environment
 | 
			
		||||
 | 
			
		||||
        self._loaded_module = None
 | 
			
		||||
 | 
			
		||||
        if load_module:
 | 
			
		||||
            path = sys.path + [self.dir_path, os.getcwd()]
 | 
			
		||||
            f, fp, desc = imp.find_module(load_module, path)
 | 
			
		||||
            self._loaded_module = imp.load_module('soil.agents.custom', f, fp, desc)
 | 
			
		||||
                                                   known_modules=['soil.environment', ]) or Environment
 | 
			
		||||
 | 
			
		||||
        environment_agents = environment_agents or []
 | 
			
		||||
        self.environment_agents = agents._convert_agent_types(environment_agents)
 | 
			
		||||
        self.environment_agents = agents._convert_agent_types(environment_agents,
 | 
			
		||||
                                                              known_modules=[self.load_module])
 | 
			
		||||
 | 
			
		||||
        distro = agents.calculate_distribution(network_agents,
 | 
			
		||||
                                               agent_type)
 | 
			
		||||
        self.network_agents = agents._convert_agent_types(distro)
 | 
			
		||||
        self.network_agents = agents._convert_agent_types(distro,
 | 
			
		||||
                                                          known_modules=[self.load_module])
 | 
			
		||||
 | 
			
		||||
        self.states = agents._validate_states(states,
 | 
			
		||||
                                              self.topology)
 | 
			
		||||
@@ -136,13 +136,17 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
        p = Pool()
 | 
			
		||||
        with utils.timer('simulation {}'.format(self.name)):
 | 
			
		||||
            if parallel:
 | 
			
		||||
                func = partial(self.run_trial, dry_run=dry_run or self.dry_run,
 | 
			
		||||
                               return_env=not parallel, **kwargs)
 | 
			
		||||
                func = partial(self.run_trial_exceptions, dry_run=dry_run or self.dry_run,
 | 
			
		||||
                                                         return_env=True,
 | 
			
		||||
                                                         **kwargs)
 | 
			
		||||
                for i in p.imap_unordered(func, range(self.num_trials)):
 | 
			
		||||
                    if isinstance(i, Exception):
 | 
			
		||||
                        logger.error('Trial failed:\n\t{}'.format(i.message))
 | 
			
		||||
                        continue
 | 
			
		||||
                    yield i
 | 
			
		||||
            else:
 | 
			
		||||
                for i in range(self.num_trials):
 | 
			
		||||
                    yield self.run_trial(i, dry_run=dry_run or self.dry_run, **kwargs)
 | 
			
		||||
                    yield self.run_trial(i, dry_run = dry_run or self.dry_run, **kwargs)
 | 
			
		||||
            if not (dry_run or self.dry_run):
 | 
			
		||||
                logger.info('Dumping results to {}'.format(self.dir_path))
 | 
			
		||||
                self.dump_pickle(self.dir_path)
 | 
			
		||||
@@ -150,9 +154,9 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
            else:
 | 
			
		||||
                logger.info('NOT dumping results')
 | 
			
		||||
 | 
			
		||||
    def get_env(self, trial_id=0, **kwargs):
 | 
			
		||||
        opts = self.environment_params.copy()
 | 
			
		||||
        env_name = '{}_trial_{}'.format(self.name, trial_id)
 | 
			
		||||
    def get_env(self, trial_id = 0, **kwargs):
 | 
			
		||||
        opts=self.environment_params.copy()
 | 
			
		||||
        env_name='{}_trial_{}'.format(self.name, trial_id)
 | 
			
		||||
        opts.update({
 | 
			
		||||
            'name': env_name,
 | 
			
		||||
            'topology': self.topology.copy(),
 | 
			
		||||
@@ -167,10 +171,10 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
            'dir_path': self.dir_path,
 | 
			
		||||
        })
 | 
			
		||||
        opts.update(kwargs)
 | 
			
		||||
        env = self.environment_class(**opts)
 | 
			
		||||
        env=self.environment_class(**opts)
 | 
			
		||||
        return env
 | 
			
		||||
 | 
			
		||||
    def run_trial(self, trial_id=0, until=None, return_env=True, **opts):
 | 
			
		||||
    def run_trial(self, trial_id = 0, until = None, return_env = True, **opts):
 | 
			
		||||
        """Run a single trial of the simulation
 | 
			
		||||
 | 
			
		||||
        Parameters
 | 
			
		||||
@@ -178,16 +182,27 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
        trial_id : int
 | 
			
		||||
        """
 | 
			
		||||
        # Set-up trial environment and graph
 | 
			
		||||
        until = until or self.max_time
 | 
			
		||||
        env = self.get_env(trial_id=trial_id, **opts)
 | 
			
		||||
        until=until or self.max_time
 | 
			
		||||
        env=self.get_env(trial_id = trial_id, **opts)
 | 
			
		||||
        # Set up agents on nodes
 | 
			
		||||
        with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
 | 
			
		||||
            env.run(until)
 | 
			
		||||
        if self.dump and not self.dry_run:
 | 
			
		||||
            with utils.timer('Dumping simulation {} trial {}'.format(self.name, trial_id)):
 | 
			
		||||
                env.dump(formats=self.dump)
 | 
			
		||||
                env.dump(formats = self.dump)
 | 
			
		||||
        if return_env:
 | 
			
		||||
            return env
 | 
			
		||||
    def run_trial_exceptions(self, *args, **kwargs):
 | 
			
		||||
        '''
 | 
			
		||||
        A wrapper for run_trial that catches exceptions and returns them.
 | 
			
		||||
        It is meant for async simulations
 | 
			
		||||
        '''
 | 
			
		||||
        try:
 | 
			
		||||
            return self.run_trial(*args, **kwargs)
 | 
			
		||||
        except Exception as ex:
 | 
			
		||||
            c = ex.__cause__
 | 
			
		||||
            c.message = ''.join(traceback.format_tb(c.__traceback__)[3:])
 | 
			
		||||
            return c
 | 
			
		||||
 | 
			
		||||
    def to_dict(self):
 | 
			
		||||
        return self.__getstate__()
 | 
			
		||||
@@ -195,48 +210,53 @@ class Simulation(NetworkSimulation):
 | 
			
		||||
    def to_yaml(self):
 | 
			
		||||
        return yaml.dump(self.to_dict())
 | 
			
		||||
 | 
			
		||||
    def dump_yaml(self, dir_path=None, file_name=None):
 | 
			
		||||
        dir_path = dir_path or self.dir_path
 | 
			
		||||
    def dump_yaml(self, dir_path = None, file_name = None):
 | 
			
		||||
        dir_path=dir_path or self.dir_path
 | 
			
		||||
        if not os.path.exists(dir_path):
 | 
			
		||||
            os.makedirs(dir_path)
 | 
			
		||||
        if not file_name:
 | 
			
		||||
            file_name = os.path.join(dir_path,
 | 
			
		||||
            file_name=os.path.join(dir_path,
 | 
			
		||||
                                     '{}.dumped.yml'.format(self.name))
 | 
			
		||||
        with open(file_name, 'w') as f:
 | 
			
		||||
            f.write(self.to_yaml())
 | 
			
		||||
 | 
			
		||||
    def dump_pickle(self, dir_path=None, pickle_name=None):
 | 
			
		||||
        dir_path = dir_path or self.dir_path
 | 
			
		||||
    def dump_pickle(self, dir_path = None, pickle_name = None):
 | 
			
		||||
        dir_path=dir_path or self.dir_path
 | 
			
		||||
        if not os.path.exists(dir_path):
 | 
			
		||||
            os.makedirs(dir_path)
 | 
			
		||||
        if not pickle_name:
 | 
			
		||||
            pickle_name = os.path.join(dir_path,
 | 
			
		||||
            pickle_name=os.path.join(dir_path,
 | 
			
		||||
                                       '{}.simulation.pickle'.format(self.name))
 | 
			
		||||
        with open(pickle_name, 'wb') as f:
 | 
			
		||||
            pickle.dump(self, f)
 | 
			
		||||
 | 
			
		||||
    def __getstate__(self):
 | 
			
		||||
        state = {}
 | 
			
		||||
        state={}
 | 
			
		||||
        for k, v in self.__dict__.items():
 | 
			
		||||
            if k[0] != '_':
 | 
			
		||||
                state[k] = v
 | 
			
		||||
        state['topology'] = json_graph.node_link_data(self.topology)
 | 
			
		||||
        state['network_agents'] = agents.serialize_distribution(self.network_agents)
 | 
			
		||||
        state['environment_agents'] = agents._convert_agent_types(self.environment_agents,
 | 
			
		||||
                                                                 to_string=True)
 | 
			
		||||
        state['environment_class'] = utils.serialize(self.environment_class,
 | 
			
		||||
                                                     known_modules=['soil.environment', ])[1]  # func, name
 | 
			
		||||
                state[k]=v
 | 
			
		||||
        state['topology']=json_graph.node_link_data(self.topology)
 | 
			
		||||
        state['network_agents']=agents.serialize_distribution(self.network_agents,
 | 
			
		||||
                                                                known_modules = [])
 | 
			
		||||
        state['environment_agents']=agents.serialize_distribution(self.environment_agents,
 | 
			
		||||
                                                                    known_modules = [])
 | 
			
		||||
        state['environment_class']=utils.serialize(self.environment_class,
 | 
			
		||||
                                                     known_modules=['soil.environment'])[1]  # func, name
 | 
			
		||||
        if state['load_module'] is None:
 | 
			
		||||
            del state['load_module']
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
    def __setstate__(self, state):
 | 
			
		||||
        self.__dict__ = state
 | 
			
		||||
        self.load_module = getattr(self, 'load_module', None)
 | 
			
		||||
        if self.dir_path not in sys.path:
 | 
			
		||||
            sys.path += [self.dir_path, os.getcwd()]
 | 
			
		||||
        self.topology = json_graph.node_link_graph(state['topology'])
 | 
			
		||||
        self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
 | 
			
		||||
        self.environment_agents = agents._convert_agent_types(self.environment_agents)
 | 
			
		||||
        self.environment_agents = agents._convert_agent_types(self.environment_agents,
 | 
			
		||||
                                                              known_modules=[self.load_module])
 | 
			
		||||
        self.environment_class = utils.deserialize(self.environment_class,
 | 
			
		||||
                                                     known_modules=['soil.environment', ])  # func, name
 | 
			
		||||
                                                   known_modules=[self.load_module, 'soil.environment', ])  # func, name
 | 
			
		||||
        return state
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -92,8 +92,10 @@ def name(value, known_modules=[]):
 | 
			
		||||
        return tname
 | 
			
		||||
    if known_modules and modname in known_modules:
 | 
			
		||||
        return tname
 | 
			
		||||
    for mod_name in known_modules:
 | 
			
		||||
        module = importlib.import_module(mod_name)
 | 
			
		||||
    for kmod in known_modules:
 | 
			
		||||
        if not kmod:
 | 
			
		||||
            continue
 | 
			
		||||
        module = importlib.import_module(kmod)
 | 
			
		||||
        if hasattr(module, tname):
 | 
			
		||||
            return tname
 | 
			
		||||
    return '{}.{}'.format(modname, tname)
 | 
			
		||||
@@ -124,6 +126,7 @@ def deserializer(type_, known_modules=[]):
 | 
			
		||||
    options = []
 | 
			
		||||
 | 
			
		||||
    for mod in modules:
 | 
			
		||||
        if mod:
 | 
			
		||||
            options.append((mod, type_))
 | 
			
		||||
 | 
			
		||||
    if '.' in type_:  # Fully qualified module
 | 
			
		||||
@@ -131,14 +134,14 @@ def deserializer(type_, known_modules=[]):
 | 
			
		||||
        options.append ((module, type_))
 | 
			
		||||
 | 
			
		||||
    errors = []
 | 
			
		||||
    for module, name in options:
 | 
			
		||||
    for modname, tname in options:
 | 
			
		||||
        try:
 | 
			
		||||
            module = importlib.import_module(module)
 | 
			
		||||
            cls = getattr(module, name)
 | 
			
		||||
            module = importlib.import_module(modname)
 | 
			
		||||
            cls = getattr(module, tname)
 | 
			
		||||
            return getattr(cls, 'deserialize', cls)
 | 
			
		||||
        except (ImportError, AttributeError) as ex:
 | 
			
		||||
            errors.append((module, name, ex))
 | 
			
		||||
    raise Exception('Could not find module {}. Tried: {}'.format(type_, errors))
 | 
			
		||||
            errors.append((modname, tname, ex))
 | 
			
		||||
    raise Exception('Could not find type {}. Tried: {}'.format(type_, errors))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deserialize(type_, value=None, **kwargs):
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1 @@
 | 
			
		||||
pytest
 | 
			
		||||
@@ -129,7 +129,7 @@ class TestHistory(TestCase):
 | 
			
		||||
        backuppaths = glob(db_path + '.backup*.sqlite')
 | 
			
		||||
        assert len(backuppaths) == 1
 | 
			
		||||
        backuppath = backuppaths[0]
 | 
			
		||||
        assert newhistory._db_path == h._db_path
 | 
			
		||||
        assert newhistory.db_path == h.db_path
 | 
			
		||||
        assert os.path.exists(backuppath)
 | 
			
		||||
        assert not len(newhistory[None, None, None])
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -12,6 +12,12 @@ from soil import simulation, Environment, agents, utils, history
 | 
			
		||||
ROOT = os.path.abspath(os.path.dirname(__file__))
 | 
			
		||||
EXAMPLES = join(ROOT, '..', 'examples')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CustomAgent(agents.BaseAgent):
 | 
			
		||||
    def step(self):
 | 
			
		||||
        self.state['neighbors'] = self.count_agents(state_id=0,
 | 
			
		||||
                                                    limit_neighbors=True)
 | 
			
		||||
 | 
			
		||||
class TestMain(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_load_graph(self):
 | 
			
		||||
@@ -125,10 +131,6 @@ class TestMain(TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_custom_agent(self):
 | 
			
		||||
        """Allow for search of neighbors with a certain state_id"""
 | 
			
		||||
        class CustomAgent(agents.BaseAgent):
 | 
			
		||||
            def step(self):
 | 
			
		||||
                self.state['neighbors'] = self.count_agents(state_id=0,
 | 
			
		||||
                                                            limit_neighbors=True)
 | 
			
		||||
        config = {
 | 
			
		||||
            'dry_run': True,
 | 
			
		||||
            'network_params': {
 | 
			
		||||
@@ -261,6 +263,13 @@ class TestMain(TestCase):
 | 
			
		||||
            des = utils.deserialize(name, ser)
 | 
			
		||||
            assert i == des
 | 
			
		||||
 | 
			
		||||
    def test_serialize_agent_type(self):
 | 
			
		||||
        '''A class from soil.agents should be serialized without the module part'''
 | 
			
		||||
        ser = agents.serialize_type(CustomAgent)
 | 
			
		||||
        assert ser == 'test_main.CustomAgent'
 | 
			
		||||
        ser = agents.serialize_type(agents.BaseAgent)
 | 
			
		||||
        assert ser == 'BaseAgent'
 | 
			
		||||
 | 
			
		||||
    def test_deserialize_agent_distribution(self):
 | 
			
		||||
        agent_distro = [
 | 
			
		||||
            {
 | 
			
		||||
@@ -268,13 +277,13 @@ class TestMain(TestCase):
 | 
			
		||||
                'weight': 1
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                'agent_type': 'BaseAgent',
 | 
			
		||||
                'agent_type': 'test_main.CustomAgent',
 | 
			
		||||
                'weight': 2
 | 
			
		||||
            },
 | 
			
		||||
        ]
 | 
			
		||||
        converted = agents.deserialize_distribution(agent_distro)
 | 
			
		||||
        assert converted[0]['agent_type'] == agents.CounterModel
 | 
			
		||||
        assert converted[1]['agent_type'] == agents.BaseAgent
 | 
			
		||||
        assert converted[1]['agent_type'] == CustomAgent
 | 
			
		||||
 | 
			
		||||
    def test_serialize_agent_distribution(self):
 | 
			
		||||
        agent_distro = [
 | 
			
		||||
@@ -283,13 +292,13 @@ class TestMain(TestCase):
 | 
			
		||||
                'weight': 1
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                'agent_type': agents.BaseAgent,
 | 
			
		||||
                'agent_type': CustomAgent,
 | 
			
		||||
                'weight': 2
 | 
			
		||||
            },
 | 
			
		||||
        ]
 | 
			
		||||
        converted = agents.serialize_distribution(agent_distro)
 | 
			
		||||
        assert converted[0]['agent_type'] == 'CounterModel'
 | 
			
		||||
        assert converted[1]['agent_type'] == 'BaseAgent'
 | 
			
		||||
        assert converted[1]['agent_type'] == 'test_main.CustomAgent'
 | 
			
		||||
 | 
			
		||||
    def test_history(self):
 | 
			
		||||
        '''Test storing in and retrieving from history (sqlite)'''
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user