You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
soil/tests/test_exporters.py

105 lines
3.0 KiB
Python

import os
import io
import tempfile
import shutil
from unittest import TestCase
from soil import exporters
from soil import simulation
from soil import agents
class Dummy(exporters.Exporter):
    '''Exporter stub that records which lifecycle hooks were invoked.

    All bookkeeping is stored on the class rather than on the instance, so a
    test can inspect it through ``Dummy`` itself (the test passes the class,
    not an instance, to ``run_simulation``).
    '''

    # Flags flipped by sim_start / sim_end.
    started = False
    ended = False
    # Number of trial_end calls and the running sum of env.now they saw.
    trials = 0
    total_time = 0
    # Per-hook call counters.
    called_start = 0
    called_trial = 0
    called_end = 0

    def _bump(self, counter, amount=1):
        '''Add ``amount`` to the class-level attribute named ``counter``.'''
        owner = type(self)
        setattr(owner, counter, getattr(owner, counter) + amount)

    def sim_start(self):
        '''Flag the simulation as started and count the call.'''
        type(self).started = True
        self._bump('called_start')

    def trial_end(self, env):
        '''Count a finished trial and add its ``env.now`` to the total.'''
        assert env
        self._bump('trials')
        self._bump('called_trial')
        self._bump('total_time', env.now)

    def sim_end(self):
        '''Flag the simulation as ended and count the call.'''
        self._bump('called_end')
        type(self).ended = True
class Exporters(TestCase):
    '''Tests for exporter lifecycle hooks and exporter output on disk.'''

    def test_basic(self):
        '''Check each exporter hook fires the expected number of times (dry run).'''
        # Dummy keeps its counters in class attributes, which survive between
        # runs in the same process; reset them so the exact counts asserted
        # below do not depend on test order or re-runs.
        Dummy.started = Dummy.ended = False
        Dummy.trials = Dummy.total_time = 0
        Dummy.called_start = Dummy.called_trial = Dummy.called_end = 0

        # We need to add at least one agent to make sure the scheduler
        # ticks every step
        num_trials = 5
        max_time = 2
        config = {
            'name': 'exporter_sim',
            'model_params': {
                'agents': [{
                    'agent_class': agents.BaseAgent
                }]
            },
            'max_time': max_time,
            'num_trials': num_trials,
        }
        s = simulation.from_config(config)
        for env in s.run_simulation(exporters=[Dummy], dry_run=True):
            assert len(env.agents) == 1
            assert env.now == max_time

        assert Dummy.started
        assert Dummy.ended
        assert Dummy.called_start == 1
        assert Dummy.called_end == 1
        assert Dummy.called_trial == num_trials
        assert Dummy.trials == num_trials
        assert Dummy.total_time == max_time * num_trials

    def test_writing(self):
        '''Run the default and CSV exporters (without dry_run) and check that
        the YAML dump and a CSV file for every returned environment are
        written and non-empty.'''
        n_trials = 5
        config = {
            'name': 'exporter_sim',
            'network_params': {
                'generator': 'complete_graph',
                'n': 4
            },
            'agent_class': 'CounterModel',
            'max_time': 2,
            'num_trials': n_trials,
            'dry_run': False,
            'environment_params': {}
        }
        # Handed to the exporters as copy_to; its contents are not checked here.
        output = io.StringIO()
        s = simulation.from_config(config)
        tmpdir = tempfile.mkdtemp()
        # Register removal right away so the directory is deleted even when the
        # simulation run or one of the file checks below fails.
        self.addCleanup(shutil.rmtree, tmpdir)
        envs = s.run_simulation(exporters=[
            exporters.default,
            exporters.csv,
        ],
            dry_run=False,
            outdir=tmpdir,
            exporter_params={'copy_to': output})

        simdir = os.path.join(tmpdir, s.group or '', s.name)
        with open(os.path.join(simdir, '{}.dumped.yml'.format(s.name))) as f:
            assert f.read()

        for e in envs:
            with open(os.path.join(simdir, '{}.env.csv'.format(e.id))) as f:
                assert f.read()