import os
import io
import tempfile
import shutil

from unittest import TestCase
from soil import exporters
from soil import simulation
from soil import agents


class Dummy(exporters.Exporter):
    """Exporter that records which lifecycle hooks were invoked.

    All bookkeeping is stored on the class (via ``self.__class__``), not on
    the instance, so the test can inspect the totals through ``Dummy``
    itself without holding a reference to the exporter instance.
    """

    # Flags flipped by sim_start / sim_end.
    started = False
    ended = False
    # Number of trial_end calls and the sum of ``env.now`` over those calls.
    trials = 0
    total_time = 0
    # Per-hook invocation counters.
    called_start = 0
    called_trial = 0
    called_end = 0

    def sim_start(self):
        """Record that the start hook ran."""
        self.__class__.called_start += 1
        self.__class__.started = True

    def trial_end(self, env):
        """Record one finished trial and accumulate its final time."""
        assert env
        self.__class__.trials += 1
        self.__class__.total_time += env.now
        self.__class__.called_trial += 1

    def sim_end(self):
        """Record that the end hook ran."""
        self.__class__.ended = True
        self.__class__.called_end += 1


class Exporters(TestCase):
    """Tests for the exporter hooks and for exporters that write files."""

    def test_basic(self):
        """Run a dry simulation and check the hook counts seen by Dummy."""
        # We need to add at least one agent to make sure the scheduler
        # ticks every step
        num_trials = 5
        max_time = 2
        config = {
            'name': 'exporter_sim',
            'model_params': {
                'agents': [{
                    'agent_class': agents.BaseAgent
                }]
            },
            'max_time': max_time,
            'num_trials': num_trials,
        }
        s = simulation.from_config(config)

        for env in s.run_simulation(exporters=[Dummy], dry_run=True):
            assert len(env.agents) == 1
            assert env.now == max_time

        assert Dummy.started
        assert Dummy.ended
        # The start/end hooks are expected once per simulation, the trial
        # hook once per trial.
        assert Dummy.called_start == 1
        assert Dummy.called_end == 1
        assert Dummy.called_trial == num_trials
        assert Dummy.trials == num_trials
        assert Dummy.total_time == max_time * num_trials

    def test_writing(self):
        '''Try to write CSV, sqlite and YAML (without dry_run)'''
        n_trials = 5
        config = {
            'name': 'exporter_sim',
            'network_params': {
                'generator': 'complete_graph',
                'n': 4
            },
            'agent_class': 'CounterModel',
            'max_time': 2,
            'num_trials': n_trials,
            'dry_run': False,
            'environment_params': {}
        }
        output = io.StringIO()
        s = simulation.from_config(config)
        tmpdir = tempfile.mkdtemp()
        # Everything that can fail after the directory exists runs inside
        # the try block, so the temporary directory is removed even when the
        # simulation or one of the assertions raises.
        try:
            envs = s.run_simulation(
                exporters=[
                    exporters.default,
                    exporters.csv,
                ],
                dry_run=False,
                outdir=tmpdir,
                exporter_params={'copy_to': output},
            )
            # NOTE(review): the copy_to stream is read back but was never
            # asserted on; confirm what the exporters are expected to write
            # to it before adding a check.
            copied = output.getvalue()

            simdir = os.path.join(tmpdir, s.group or '', s.name)
            dumped_path = os.path.join(simdir, '{}.dumped.yml'.format(s.name))
            with open(dumped_path) as f:
                result = f.read()
                assert result

            # One non-empty CSV file is expected per returned environment.
            for e in envs:
                csv_path = os.path.join(simdir, '{}.env.csv'.format(e.id))
                with open(csv_path) as f:
                    result = f.read()
                    assert result
        finally:
            shutil.rmtree(tmpdir)