mirror of
				https://github.com/gsi-upm/soil
				synced 2025-10-31 15:48:16 +00:00 
			
		
		
		
	The examples weren't being properly tested in the last commit. When we fixed that, we found a lot of bugs in the new implementation of environment and agent; fixing them accounts for most of these changes. The main functional difference is the mechanism for loading simulations from a configuration file. For that to work, we had to rework our module-loading code in `serialization` and add a `source_file` attribute to configurations (and to simulations, for that matter).
		
			
				
	
	
		
			121 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			121 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import io
 | |
| import tempfile
 | |
| import shutil
 | |
| import sqlite3
 | |
| 
 | |
| from unittest import TestCase
 | |
| from soil import exporters
 | |
| from soil import environment
 | |
| from soil import simulation
 | |
| from soil import agents
 | |
| 
 | |
| from mesa import Agent as MesaAgent
 | |
| 
 | |
| 
 | |
class Dummy(exporters.Exporter):
    """Test exporter that records each hook invocation in class attributes.

    The bookkeeping lives on the class rather than on the instance, so a test
    can inspect it through ``Dummy`` itself without holding a reference to the
    exporter object. Nothing here resets the counters; callers that need a
    clean slate must reset them before running a simulation.
    """

    started = False  # becomes True once sim_start has been called
    trials = 0  # number of trial_end calls seen
    ended = False  # becomes True once sim_end has been called
    total_time = 0  # running sum of env.now across trial_end calls
    called_start = 0  # how many times sim_start ran
    called_trial = 0  # how many times trial_end ran
    called_end = 0  # how many times sim_end ran

    def sim_start(self):
        """Count a simulation start and flag that it happened."""
        klass = type(self)
        klass.called_start = klass.called_start + 1
        klass.started = True

    def trial_end(self, env):
        """Count a finished trial and add its final time to ``total_time``.

        ``env`` must be truthy; its ``now`` attribute is accumulated.
        """
        assert env
        klass = type(self)
        klass.trials = klass.trials + 1
        klass.total_time = klass.total_time + env.now
        klass.called_trial = klass.called_trial + 1

    def sim_end(self):
        """Flag that the simulation ended and count the call."""
        klass = type(self)
        klass.ended = True
        klass.called_end = klass.called_end + 1
 | |
| 
 | |
| 
 | |
class Exporters(TestCase):
    """Tests for exporter hook dispatch and for the files exporters write."""

    def test_basic(self):
        """Check that each exporter hook fires the expected number of times."""
        # Dummy accumulates its counters on the class itself, so reset them
        # first: otherwise state left by an earlier run (a re-run, or another
        # test using Dummy) leaks into the exact-count assertions below.
        Dummy.started = False
        Dummy.ended = False
        Dummy.trials = 0
        Dummy.total_time = 0
        Dummy.called_start = 0
        Dummy.called_trial = 0
        Dummy.called_end = 0

        # We need to add at least one agent to make sure the scheduler
        # ticks every step
        class SimpleEnv(environment.Environment):
            def init(self):
                self.add_agent(agent_class=MesaAgent)

        num_trials = 5
        max_time = 2
        s = simulation.Simulation(num_trials=num_trials, max_time=max_time, name="exporter_sim",
                                  dump=False, model=SimpleEnv)

        for env in s.run_simulation(exporters=[Dummy], dump=False):
            assert len(env.agents) == 1

        assert Dummy.started
        assert Dummy.ended
        assert Dummy.called_start == 1
        assert Dummy.called_end == 1
        assert Dummy.called_trial == num_trials
        assert Dummy.trials == num_trials
        assert Dummy.total_time == max_time * num_trials

    def test_writing(self):
        """Write CSV, SQLite and YAML output (with dump=True) and check it.

        Runs a small complete-graph simulation with the ``default`` and
        ``csv`` exporters into a temporary directory. It then checks the
        dumped YAML file, the row counts in the SQLite database, and one env
        CSV per returned environment. The temporary directory is always
        removed, even if the run or a check fails.
        """
        n_trials = 5
        n_nodes = 4
        max_time = 2
        config = {
            "name": "exporter_sim",
            "model_params": {
                "network_generator": "complete_graph",
                "network_params": {"n": n_nodes},
                "agent_class": "CounterModel",
            },
            "max_time": max_time,
            "num_trials": n_trials,
            "dump": True,
        }
        output = io.StringIO()
        s = simulation.from_config(config)
        tmpdir = tempfile.mkdtemp()
        # Everything that writes to or reads from tmpdir runs inside the try,
        # so the directory is cleaned up even when the simulation itself fails.
        try:
            envs = s.run_simulation(
                exporters=[
                    exporters.default,
                    exporters.csv,
                ],
                model_params={
                    "agent_reporters": {"times": "times"},
                    "model_reporters": {
                        "constant": lambda x: 1,
                    },
                },
                dump=True,
                outdir=tmpdir,
                exporter_params={"copy_to": output},
            )

            simdir = os.path.join(tmpdir, s.group or "", s.name)
            with open(os.path.join(simdir, "{}.dumped.yml".format(s.name))) as f:
                result = f.read()
                assert result

            # The database path depends only on the simulation name, so it is
            # the same file for every env: query it once, and close the
            # connection so the file is not held open when tmpdir is removed.
            dbpath = os.path.join(simdir, f"{s.name}.sqlite")
            db = sqlite3.connect(dbpath)
            try:
                cur = db.cursor()
                agent_entries = cur.execute("SELECT times FROM agents WHERE times > 0").fetchall()
                env_entries = cur.execute("SELECT constant from env WHERE constant == 1").fetchall()
            finally:
                db.close()
            assert len(agent_entries) == n_nodes * n_trials * max_time
            assert len(env_entries) == n_trials * max_time

            for e in envs:
                with open(os.path.join(simdir, "{}.env.csv".format(e.id))) as f:
                    result = f.read()
                    assert result
        finally:
            shutil.rmtree(tmpdir)
 |