1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-21 18:52:28 +00:00
This commit is contained in:
J. Fernando Sánchez 2022-03-21 12:53:40 +01:00
parent a40aa55b6a
commit 1a8313e4f6
12 changed files with 81 additions and 33 deletions

View File

@ -1 +1 @@
ipython==7.31.1 ipython>=7.31.1

View File

@ -5,5 +5,5 @@ pyyaml>=5.1
pandas>=0.23 pandas>=0.23
SALib>=1.3 SALib>=1.3
Jinja2 Jinja2
Mesa>=0.8 Mesa>=0.8.9
tsih>=0.1.5 tsih>=0.1.6

View File

@ -49,6 +49,7 @@ setup(
extras_require=extras_require, extras_require=extras_require,
tests_require=test_reqs, tests_require=test_reqs,
setup_requires=['pytest-runner', ], setup_requires=['pytest-runner', ],
pytest_plugins = ['pytest_profiling'],
include_package_data=True, include_package_data=True,
entry_points={ entry_points={
'console_scripts': 'console_scripts':

View File

@ -13,7 +13,7 @@ from networkx.readwrite import json_graph
import networkx as nx import networkx as nx
from tsih import History, Record, Key, NoHistory from tsih import History, NoHistory, Record, Key
from mesa import Model from mesa import Model
@ -49,7 +49,7 @@ class Environment(Model):
schedule=None, schedule=None,
initial_time=0, initial_time=0,
environment_params=None, environment_params=None,
history=True, history=False,
dir_path=None, dir_path=None,
**kwargs): **kwargs):
@ -82,10 +82,12 @@ class Environment(Model):
self._env_agents = {} self._env_agents = {}
self.interval = interval self.interval = interval
if history: if history:
history = History history = History
else: else:
history = NoHistory history = NoHistory
self._history = history(name=self.name, self._history = history(name=self.name,
backup=True) backup=True)
self['SEED'] = seed self['SEED'] = seed
@ -298,6 +300,9 @@ class Environment(Model):
else: else:
raise ValueError('Unknown format: {}'.format(f)) raise ValueError('Unknown format: {}'.format(f))
def df(self):
return self._history[None, None, None].df()
def dump_sqlite(self, f): def dump_sqlite(self, f):
return self._history.dump(f) return self._history.dump(f)
@ -316,8 +321,14 @@ class Environment(Model):
key=k, key=k,
value=v) value=v)
def history_to_tuples(self): def history_to_tuples(self, agent_id=None):
return self._history.to_tuples() if isinstance(self._history, NoHistory):
tuples = self.state_to_tuples()
else:
tuples = self._history.to_tuples()
if agent_id is None:
return tuples
return filter(lambda x: str(x[0]) == str(agent_id), tuples)
def history_to_graph(self): def history_to_graph(self):
G = nx.Graph(self.G) G = nx.Graph(self.G)
@ -329,10 +340,10 @@ class Environment(Model):
spells = [] spells = []
lastvisible = False lastvisible = False
laststep = None laststep = None
history = self[agent.id, None, None] history = sorted(list(self.history_to_tuples(agent_id=agent.id)))
if not history: if not history:
continue continue
for t_step, attribute, value in sorted(list(history)): for _, t_step, attribute, value in history:
if attribute == 'visible': if attribute == 'visible':
nowvisible = value nowvisible = value
if nowvisible and not lastvisible: if nowvisible and not lastvisible:

View File

@ -1,6 +1,6 @@
import os import os
import csv as csvlib import csv as csvlib
import time from time import time as current_time
from io import BytesIO from io import BytesIO
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -133,7 +133,7 @@ class dummy(Exporter):
def start(self): def start(self):
with self.output('dummy', 'w') as f: with self.output('dummy', 'w') as f:
f.write('simulation started @ {}\n'.format(time.time())) f.write('simulation started @ {}\n'.format(current_time()))
def trial(self, env, stats): def trial(self, env, stats):
with self.output('dummy', 'w') as f: with self.output('dummy', 'w') as f:
@ -143,7 +143,7 @@ class dummy(Exporter):
def sim(self, stats): def sim(self, stats):
with self.output('dummy', 'a') as f: with self.output('dummy', 'a') as f:
f.write('simulation ended @ {}\n'.format(time.time())) f.write('simulation ended @ {}\n'.format(current_time()))

View File

@ -1,4 +1,5 @@
import os import os
from time import time as current_time, strftime
import importlib import importlib
import sys import sys
import yaml import yaml
@ -6,7 +7,6 @@ import traceback
import logging import logging
import networkx as nx import networkx as nx
from time import strftime
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
from multiprocessing import Pool from multiprocessing import Pool
from functools import partial from functools import partial
@ -83,8 +83,9 @@ class Simulation:
Class for the environment. It defailts to soil.environment.Environment Class for the environment. It defailts to soil.environment.Environment
load_module : str, module name, deprecated load_module : str, module name, deprecated
If specified, soil will load the content of this module under 'soil.agents.custom' If specified, soil will load the content of this module under 'soil.agents.custom'
history: tsih.History subclass, optional
Class to use to store the history of the simulation (and environments). It defailts to tsih.History
If set to True, tsih.History will be used. If set to False or None, tsih.NoHistory will be used.
""" """
def __init__(self, name=None, group=None, topology=None, network_params=None, def __init__(self, name=None, group=None, topology=None, network_params=None,
@ -93,7 +94,7 @@ class Simulation:
max_time=100, load_module=None, seed=None, max_time=100, load_module=None, seed=None,
dir_path=None, environment_agents=None, dir_path=None, environment_agents=None,
environment_params=None, environment_class=None, environment_params=None, environment_class=None,
**kwargs): history=History, **kwargs):
self.load_module = load_module self.load_module = load_module
self.network_params = network_params self.network_params = network_params
@ -133,7 +134,12 @@ class Simulation:
self.states = agents._validate_states(states, self.states = agents._validate_states(states,
self.topology) self.topology)
self._history = History(name=self.name, if history == True:
history = History
elif not history:
history = NoHistory
self._history = history(name=self.name,
backup=False) backup=False)
def run_simulation(self, *args, **kwargs): def run_simulation(self, *args, **kwargs):
@ -233,6 +239,7 @@ class Simulation:
'states': self.states, 'states': self.states,
'dir_path': self.dir_path, 'dir_path': self.dir_path,
'default_state': self.default_state, 'default_state': self.default_state,
'history': bool(self._history),
'environment_agents': self.environment_agents, 'environment_agents': self.environment_agents,
}) })
opts.update(kwargs) opts.update(kwargs)

View File

@ -35,7 +35,7 @@ class distribution(Stats):
self.counts = [] self.counts = []
def trial(self, env): def trial(self, env):
df = env[None, None, None].df() df = env.df()
df = df.drop('SEED', axis=1) df = df.drop('SEED', axis=1)
ix = df.index[-1] ix = df.index[-1]
attrs = df.columns.get_level_values(0) attrs = df.columns.get_level_values(0)

View File

@ -1,5 +1,5 @@
import logging import logging
import time from time import time as current_time, strftime, gmtime, localtime
import os import os
from shutil import copyfile from shutil import copyfile
@ -13,13 +13,13 @@ logger = logging.getLogger('soil')
@contextmanager @contextmanager
def timer(name='task', pre="", function=logger.info, to_object=None): def timer(name='task', pre="", function=logger.info, to_object=None):
start = time.time() start = current_time()
function('{}Starting {} at {}.'.format(pre, name, function('{}Starting {} at {}.'.format(pre, name,
time.strftime("%X", time.gmtime(start)))) strftime("%X", gmtime(start))))
yield start yield start
end = time.time() end = current_time()
function('{}Finished {} at {} in {} seconds'.format(pre, name, function('{}Finished {} at {} in {} seconds'.format(pre, name,
time.strftime("%X", time.gmtime(end)), strftime("%X", gmtime(end)),
str(end-start))) str(end-start)))
if to_object: if to_object:
to_object.start = start to_object.start = start
@ -34,7 +34,7 @@ def safe_open(path, mode='r', backup=True, **kwargs):
os.makedirs(outdir) os.makedirs(outdir)
if backup and 'w' in mode and os.path.exists(path): if backup and 'w' in mode and os.path.exists(path):
creation = os.path.getctime(path) creation = os.path.getctime(path)
stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation)) stamp = strftime('%Y-%m-%d_%H.%M.%S', localtime(creation))
backup_dir = os.path.join(outdir, 'backup') backup_dir = os.path.join(outdir, 'backup')
if not os.path.exists(backup_dir): if not os.path.exists(backup_dir):
@ -45,11 +45,13 @@ def safe_open(path, mode='r', backup=True, **kwargs):
return open(path, mode=mode, **kwargs) return open(path, mode=mode, **kwargs)
@contextmanager
def open_or_reuse(f, *args, **kwargs): def open_or_reuse(f, *args, **kwargs):
try: try:
return safe_open(f, *args, **kwargs) with safe_open(f, *args, **kwargs) as f:
yield f
except (AttributeError, TypeError): except (AttributeError, TypeError):
return f yield f
def flatten_dict(d): def flatten_dict(d):
if not isinstance(d, dict): if not isinstance(d, dict):

View File

@ -1,4 +1,4 @@
pytest pytest
mesa>=0.8.9 pytest-profiling
scipy>=1.3 scipy>=1.3
tornado tornado

View File

@ -50,6 +50,7 @@ class TestAnalysis(TestCase):
'states': [{'interval': 1}, {'interval': 2}], 'states': [{'interval': 1}, {'interval': 2}],
'max_time': 30, 'max_time': 30,
'num_trials': 1, 'num_trials': 1,
'history': True,
'environment_params': { 'environment_params': {
} }
} }

View File

@ -2,7 +2,6 @@ import os
import io import io
import tempfile import tempfile
import shutil import shutil
from time import time
from unittest import TestCase from unittest import TestCase
from soil import exporters from soil import exporters
@ -68,6 +67,7 @@ class Exporters(TestCase):
'agent_type': 'CounterModel', 'agent_type': 'CounterModel',
'max_time': 2, 'max_time': 2,
'num_trials': n_trials, 'num_trials': n_trials,
'dry_run': False,
'environment_params': {} 'environment_params': {}
} }
output = io.StringIO() output = io.StringIO()
@ -79,6 +79,7 @@ class Exporters(TestCase):
exporters.gexf, exporters.gexf,
], ],
stats=[distribution,], stats=[distribution,],
dry_run=False,
outdir=tmpdir, outdir=tmpdir,
exporter_params={'copy_to': output}) exporter_params={'copy_to': output})
result = output.getvalue() result = output.getvalue()

View File

@ -11,6 +11,7 @@ from os.path import join
from soil import (simulation, Environment, agents, serialization, from soil import (simulation, Environment, agents, serialization,
utils) utils)
from soil.time import Delta from soil.time import Delta
from tsih import NoHistory, History
ROOT = os.path.abspath(os.path.dirname(__file__)) ROOT = os.path.abspath(os.path.dirname(__file__))
@ -205,7 +206,7 @@ class TestMain(TestCase):
assert config == nconfig assert config == nconfig
def test_row_conversion(self): def test_row_conversion(self):
env = Environment() env = Environment(history=True)
env['test'] = 'test_value' env['test'] = 'test_value'
res = list(env.history_to_tuples()) res = list(env.history_to_tuples())
@ -228,7 +229,14 @@ class TestMain(TestCase):
f = io.BytesIO() f = io.BytesIO()
env.dump_gexf(f) env.dump_gexf(f)
def test_save_graph(self): def test_nohistory(self):
'''
Make sure that no history(/sqlite) is used by default
'''
env = Environment(topology=nx.Graph(), network_agents=[])
assert isinstance(env._history, NoHistory)
def test_save_graph_history(self):
''' '''
The history_to_graph method should return a valid networkx graph. The history_to_graph method should return a valid networkx graph.
@ -236,7 +244,7 @@ class TestMain(TestCase):
''' '''
G = nx.cycle_graph(5) G = nx.cycle_graph(5)
distribution = agents.calculate_distribution(None, agents.BaseAgent) distribution = agents.calculate_distribution(None, agents.BaseAgent)
env = Environment(topology=G, network_agents=distribution) env = Environment(topology=G, network_agents=distribution, history=True)
env[0, 0, 'testvalue'] = 'start' env[0, 0, 'testvalue'] = 'start'
env[0, 10, 'testvalue'] = 'finish' env[0, 10, 'testvalue'] = 'finish'
nG = env.history_to_graph() nG = env.history_to_graph()
@ -244,6 +252,23 @@ class TestMain(TestCase):
assert ('start', 0, 10) in values assert ('start', 0, 10) in values
assert ('finish', 10, None) in values assert ('finish', 10, None) in values
def test_save_graph_nohistory(self):
'''
The history_to_graph method should return a valid networkx graph.
When NoHistory is used, only the last known value is known
'''
G = nx.cycle_graph(5)
distribution = agents.calculate_distribution(None, agents.BaseAgent)
env = Environment(topology=G, network_agents=distribution, history=False)
env.get_agent(0)['testvalue'] = 'start'
env.schedule.time = 10
env.get_agent(0)['testvalue'] = 'finish'
nG = env.history_to_graph()
values = nG.nodes[0]['attr_testvalue']
assert ('start', 0, None) not in values
assert ('finish', 10, None) in values
def test_serialize_class(self): def test_serialize_class(self):
ser, name = serialization.serialize(agents.BaseAgent) ser, name = serialization.serialize(agents.BaseAgent)
assert name == 'soil.agents.BaseAgent' assert name == 'soil.agents.BaseAgent'
@ -303,7 +328,7 @@ class TestMain(TestCase):
pickle.dumps(converted) pickle.dumps(converted)
def test_pickle_agent_environment(self): def test_pickle_agent_environment(self):
env = Environment(name='Test') env = Environment(name='Test', history=True)
a = agents.BaseAgent(model=env, unique_id=25) a = agents.BaseAgent(model=env, unique_id=25)
a['key'] = 'test' a['key'] = 'test'