mirror of
				https://github.com/gsi-upm/soil
				synced 2025-11-04 09:28:16 +00:00 
			
		
		
		
	Add tests exporters
This commit is contained in:
		@@ -21,6 +21,7 @@ def for_sim(simulation, names, *args, **kwargs):
 | 
			
		||||
        exporters.append(mod(simulation, *args, **kwargs))
 | 
			
		||||
    return exporters
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DryRunner(BytesIO):
 | 
			
		||||
    def __init__(self, fname, *args, copy_to=None, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										110
									
								
								tests/test_exporters.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								tests/test_exporters.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,110 @@
 | 
			
		||||
import os
 | 
			
		||||
import io
 | 
			
		||||
import tempfile
 | 
			
		||||
import shutil
 | 
			
		||||
from time import time
 | 
			
		||||
 | 
			
		||||
from unittest import TestCase
 | 
			
		||||
from soil import exporters
 | 
			
		||||
from soil.utils import safe_open
 | 
			
		||||
from soil import simulation
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Dummy(exporters.Exporter):
 | 
			
		||||
    started = False
 | 
			
		||||
    trials = 0
 | 
			
		||||
    ended = False
 | 
			
		||||
    total_time = 0
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        self.__class__.started = True
 | 
			
		||||
 | 
			
		||||
    def trial_end(self, env):
 | 
			
		||||
        assert env
 | 
			
		||||
        self.__class__.trials += 1
 | 
			
		||||
        self.__class__.total_time += env.now
 | 
			
		||||
 | 
			
		||||
    def end(self):
 | 
			
		||||
        self.__class__.ended = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Exporters(TestCase):
 | 
			
		||||
    def test_basic(self):
 | 
			
		||||
        config = {
 | 
			
		||||
            'name': 'exporter_sim',
 | 
			
		||||
            'network_params': {},
 | 
			
		||||
            'agent_type': 'CounterModel',
 | 
			
		||||
            'max_time': 2,
 | 
			
		||||
            'num_trials': 5,
 | 
			
		||||
            'environment_params': {}
 | 
			
		||||
        }
 | 
			
		||||
        s = simulation.from_config(config)
 | 
			
		||||
        s.run_simulation(exporters=[Dummy], dry_run=True)
 | 
			
		||||
        assert Dummy.started
 | 
			
		||||
        assert Dummy.ended
 | 
			
		||||
        assert Dummy.trials == 5
 | 
			
		||||
        assert Dummy.total_time == 2*5
 | 
			
		||||
 | 
			
		||||
    def test_distribution(self):
 | 
			
		||||
        '''The distribution exporter should write the number of agents in each state'''
 | 
			
		||||
        config = {
 | 
			
		||||
            'name': 'exporter_sim',
 | 
			
		||||
            'network_params': {
 | 
			
		||||
                'generator': 'complete_graph',
 | 
			
		||||
                'n': 4
 | 
			
		||||
            },
 | 
			
		||||
            'agent_type': 'CounterModel',
 | 
			
		||||
            'max_time': 2,
 | 
			
		||||
            'num_trials': 5,
 | 
			
		||||
            'environment_params': {}
 | 
			
		||||
        }
 | 
			
		||||
        output = io.StringIO()
 | 
			
		||||
        s = simulation.from_config(config)
 | 
			
		||||
        s.run_simulation(exporters=[exporters.Distribution], dry_run=True, exporter_params={'copy_to': output})
 | 
			
		||||
        result = output.getvalue()
 | 
			
		||||
        assert 'count' in result
 | 
			
		||||
        assert 'SEED,Noneexporter_sim_trial_3,1,,1,1,1,1' in result
 | 
			
		||||
 | 
			
		||||
    def test_writing(self):
 | 
			
		||||
        '''Try to write CSV, GEXF, sqlite and YAML (without dry_run)'''
 | 
			
		||||
        n_trials = 5
 | 
			
		||||
        config = {
 | 
			
		||||
            'name': 'exporter_sim',
 | 
			
		||||
            'network_params': {
 | 
			
		||||
                'generator': 'complete_graph',
 | 
			
		||||
                'n': 4
 | 
			
		||||
            },
 | 
			
		||||
            'agent_type': 'CounterModel',
 | 
			
		||||
            'max_time': 2,
 | 
			
		||||
            'num_trials': n_trials,
 | 
			
		||||
            'environment_params': {}
 | 
			
		||||
        }
 | 
			
		||||
        output = io.StringIO()
 | 
			
		||||
        s = simulation.from_config(config)
 | 
			
		||||
        tmpdir = tempfile.mkdtemp()
 | 
			
		||||
        envs = s.run_simulation(exporters=[
 | 
			
		||||
            exporters.Default,
 | 
			
		||||
            exporters.CSV,
 | 
			
		||||
            exporters.Gexf,
 | 
			
		||||
            exporters.Distribution,
 | 
			
		||||
        ],
 | 
			
		||||
                                outdir=tmpdir,
 | 
			
		||||
                                exporter_params={'copy_to': output})
 | 
			
		||||
        result = output.getvalue()
 | 
			
		||||
 | 
			
		||||
        simdir = os.path.join(tmpdir, s.group or '', s.name)
 | 
			
		||||
        with open(os.path.join(simdir, '{}.dumped.yml'.format(s.name))) as f:
 | 
			
		||||
            result = f.read()
 | 
			
		||||
            assert result
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            for e in envs:
 | 
			
		||||
                with open(os.path.join(simdir, '{}.gexf'.format(e.name))) as f:
 | 
			
		||||
                    result = f.read()
 | 
			
		||||
                    assert result
 | 
			
		||||
 | 
			
		||||
                with open(os.path.join(simdir, '{}.csv'.format(e.name))) as f:
 | 
			
		||||
                    result = f.read()
 | 
			
		||||
                    assert result
 | 
			
		||||
        finally:
 | 
			
		||||
            shutil.rmtree(tmpdir)
 | 
			
		||||
		Reference in New Issue
	
	Block a user