1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-14 15:32:29 +00:00
soil/tests/test_exporters.py

124 lines
3.7 KiB
Python
Raw Normal View History

2019-04-30 07:28:25 +00:00
import os
import io
import tempfile
import shutil
import sqlite3
2019-04-30 07:28:25 +00:00
from unittest import TestCase
from soil import exporters
from soil import environment
2019-04-30 07:28:25 +00:00
from soil import simulation
2022-10-13 20:43:16 +00:00
from soil import agents
2023-04-20 15:56:44 +00:00
from soil import decorators
2022-10-13 20:43:16 +00:00
from mesa import Agent as MesaAgent
2019-04-30 07:28:25 +00:00
class Dummy(exporters.Exporter):
    """Exporter stub that records which simulation hooks were invoked.

    Every counter/flag is stored on the class rather than the instance: the
    simulation is handed the class itself, so the test can only inspect the
    results through class attributes. NOTE: this state is process-global and
    accumulates across every simulation that uses this exporter.
    """

    # Flags flipped by the start/end hooks.
    started = False
    ended = False
    # Number of iteration_end calls, and the sum of env.now seen in them.
    iterations = 0
    total_time = 0
    # Raw per-hook call counters.
    called_start = 0
    called_iteration = 0
    called_end = 0

    def sim_start(self):
        """Count the call and mark the simulation as started."""
        cls = type(self)
        cls.called_start += 1
        cls.started = True

    def iteration_end(self, env, *args, **kwargs):
        """Count one finished iteration and accumulate its final ``env.now``."""
        assert env
        cls = type(self)
        cls.iterations += 1
        cls.total_time += env.now
        cls.called_iteration += 1

    def sim_end(self):
        """Mark the simulation as ended and count the call."""
        cls = type(self)
        cls.ended = True
        cls.called_end += 1
2019-04-30 07:28:25 +00:00
class Exporters(TestCase):
    """End-to-end checks that simulations drive their exporters correctly."""

    def test_basic(self):
        """Exporter hooks fire once per simulation and once per iteration."""
        # Dummy keeps its counters on the class, so values from any earlier
        # use in this process would leak into the exact counts asserted below.
        # Start from a clean slate.
        Dummy.started = False
        Dummy.ended = False
        Dummy.iterations = 0
        Dummy.total_time = 0
        Dummy.called_start = 0
        Dummy.called_iteration = 0
        Dummy.called_end = 0

        # We need to add at least one agent to make sure the scheduler
        # ticks every step
        class SimpleEnv(environment.Environment):
            def init(self):
                self.add_agent(agent_class=MesaAgent)

        iterations = 5
        max_time = 2
        s = simulation.Simulation(
            iterations=iterations,
            max_time=max_time,
            name="exporter_sim",
            exporters=[Dummy],
            dump=False,
            model=SimpleEnv,
        )

        for env in s.run():
            assert len(env.agents) == 1

        assert Dummy.started
        assert Dummy.ended
        assert Dummy.called_start == 1
        assert Dummy.called_end == 1
        assert Dummy.called_iteration == iterations
        assert Dummy.iterations == iterations
        assert Dummy.total_time == max_time * iterations

    def test_writing(self):
        """Write CSV, SQLite and YAML output (dump=True) and check the files."""
        n_iterations = 5
        n_nodes = 4
        max_time = 2
        # In-memory sink handed to the exporters as ``copy_to``; its contents
        # are not inspected by this test.
        output = io.StringIO()

        class ConstantEnv(environment.Environment):
            @decorators.report
            @property
            def constant(self):
                return 1

        # The temporary directory is removed even if building or running the
        # simulation fails, not only when the file checks fail.
        with tempfile.TemporaryDirectory() as tmpdir:
            s = simulation.Simulation(
                model=ConstantEnv,
                name="exporter_sim",
                exporters=[
                    exporters.YAML,
                    exporters.SQLite,
                    exporters.csv,
                ],
                exporter_params={"copy_to": output},
                parameters=dict(
                    network_generator="complete_graph",
                    network_params={"n": n_nodes},
                    agent_class=agents.CounterModel,
                    agent_reporters={"times": "times"},
                ),
                max_time=max_time,
                outdir=tmpdir,
                iterations=n_iterations,
                dump=True,
            )
            envs = s.run()
            simdir = os.path.join(tmpdir, s.group or "", s.name)

            with open(os.path.join(simdir, "{}.dumped.yml".format(s.id))) as f:
                assert f.read()

            dbpath = os.path.join(simdir, f"{s.name}.sqlite")
            db = sqlite3.connect(dbpath)
            try:
                cur = db.cursor()
                agent_entries = cur.execute("SELECT times FROM agents WHERE times > 0").fetchall()
                env_entries = cur.execute("SELECT constant from env WHERE constant == 1").fetchall()
            finally:
                # sqlite3's connection context manager does not close the
                # connection, so close it explicitly; a lingering handle can
                # prevent deleting the temporary directory (notably on Windows).
                db.close()
            assert len(agent_entries) == n_nodes * n_iterations * max_time
            assert len(env_entries) == n_iterations * (max_time + 1)  # +1 for the initial state

            for e in envs:
                with open(os.path.join(simdir, "{}.env.csv".format(e.id))) as f:
                    assert f.read()