1
0
mirror of https://github.com/gsi-upm/soil synced 2024-10-06 14:21:43 +00:00
soil/tests/test_exporters.py
J. Fernando Sánchez d9947c2c52 WIP: all tests pass
Documentation needs some improvement

The API has been simplified to only allow for ONE topology per
NetworkEnvironment.
This covers the main use case, and simplifies the code.
2022-10-16 17:56:23 +02:00

114 lines
3.3 KiB
Python

import os
import io
import tempfile
import shutil
import sqlite3
from unittest import TestCase
from soil import exporters
from soil import simulation
from soil import agents
class Dummy(exporters.Exporter):
    """Exporter stub that counts calls to its own hook methods.

    All bookkeeping is written to the class (not the instance), so the
    totals can be read straight from ``Dummy`` once a simulation has run.
    """

    # Flags set by sim_start() and sim_end() respectively.
    started = False
    ended = False
    # Number of trial_end() calls and the sum of the ``env.now`` values seen.
    trials = 0
    total_time = 0
    # Per-method invocation counters.
    called_start = 0
    called_trial = 0
    called_end = 0

    def sim_start(self):
        """Count the call and flag the simulation as started."""
        cls = type(self)
        cls.called_start += 1
        cls.started = True

    def trial_end(self, env):
        """Count one trial and add its ``env.now`` to the running total."""
        assert env
        cls = type(self)
        cls.trials += 1
        cls.total_time += env.now
        cls.called_trial += 1

    def sim_end(self):
        """Flag the simulation as ended and count the call."""
        cls = type(self)
        cls.ended = True
        cls.called_end += 1
class Exporters(TestCase):
    """Tests that run small simulations with exporters attached."""

    def test_basic(self):
        """Run a dry-run simulation with ``Dummy`` and check its counters."""
        # We need to add at least one agent to make sure the scheduler
        # ticks every step
        num_trials = 5
        max_time = 2
        config = {
            "name": "exporter_sim",
            "model_params": {"agents": [{"agent_class": agents.BaseAgent}]},
            "max_time": max_time,
            "num_trials": num_trials,
        }
        s = simulation.from_config(config)

        for env in s.run_simulation(exporters=[Dummy], dry_run=True):
            assert len(env.agents) == 1
            assert env.now == max_time

        # NOTE: Dummy stores its counters on the class, so the exact-count
        # assertions below assume no other simulation has used Dummy earlier
        # in the same process.
        assert Dummy.started
        assert Dummy.ended
        assert Dummy.called_start == 1
        assert Dummy.called_end == 1
        assert Dummy.called_trial == num_trials
        assert Dummy.trials == num_trials
        assert Dummy.total_time == max_time * num_trials

    def test_writing(self):
        """Try to write CSV, sqlite and YAML (without dry_run)"""
        n_trials = 5
        config = {
            "name": "exporter_sim",
            "network_params": {"generator": "complete_graph", "n": 4},
            "agent_class": "CounterModel",
            "max_time": 2,
            "num_trials": n_trials,
            "dry_run": False,
            "environment_params": {},
        }
        # Buffer handed to the exporters through ``copy_to``; its contents
        # are not inspected by this test.
        output = io.StringIO()
        s = simulation.from_config(config)
        tmpdir = tempfile.mkdtemp()
        # Everything that can fail after the directory is created runs inside
        # the try block, so the directory is removed even when the simulation
        # or one of the assertions raises.
        try:
            envs = s.run_simulation(
                exporters=[
                    exporters.default,
                    exporters.csv,
                ],
                model_params={
                    "agent_reporters": {"times": "times"},
                    "model_reporters": {
                        "constant": lambda x: 1,
                    },
                },
                dry_run=False,
                outdir=tmpdir,
                exporter_params={"copy_to": output},
            )

            simdir = os.path.join(tmpdir, s.group or "", s.name)
            with open(os.path.join(simdir, f"{s.name}.dumped.yml")) as f:
                assert f.read()

            for e in envs:
                db = sqlite3.connect(os.path.join(simdir, f"{e.id}.sqlite"))
                # Close the connection explicitly: an open handle would leak
                # and can prevent the directory from being deleted below.
                try:
                    cur = db.cursor()
                    agent_entries = cur.execute("SELECT * from agents").fetchall()
                    env_entries = cur.execute("SELECT * from env").fetchall()
                finally:
                    db.close()
                assert len(agent_entries) > 0
                assert len(env_entries) > 0

                with open(os.path.join(simdir, f"{e.id}.env.csv")) as f:
                    assert f.read()
        finally:
            shutil.rmtree(tmpdir)