import os
import shutil
from glob import glob
from unittest import TestCase

from tsih import *
from tsih import utils

ROOT = os.path.abspath(os.path.dirname(__file__))
DBROOT = os.path.join(ROOT, 'testdb')


class TestHistory(TestCase):
    """Tests for the tsih ``History`` store and the ``utils`` helpers."""

    def setUp(self):
        # Some tests write database files under DBROOT; make sure it exists.
        # exist_ok avoids a separate check-then-create step.
        os.makedirs(DBROOT, exist_ok=True)

    def tearDown(self):
        # Remove every database file written during the test.
        if os.path.exists(DBROOT):
            shutil.rmtree(DBROOT)

    def test_history(self):
        """
        Values are looked up by (agent, t_step, key) and hold until the
        next recorded value for that key.
        """
        tuples = (
            ('a_0', 0, 'id', 'h'),
            ('a_0', 1, 'id', 'e'),
            ('a_0', 2, 'id', 'l'),
            ('a_0', 3, 'id', 'l'),
            ('a_0', 4, 'id', 'o'),
            ('a_1', 0, 'id', 'v'),
            ('a_1', 1, 'id', 'a'),
            ('a_1', 2, 'id', 'l'),
            ('a_1', 3, 'id', 'u'),
            ('a_1', 4, 'id', 'e'),
            ('env', 1, 'prob', 1),
            ('env', 3, 'prob', 2),
            ('env', 5, 'prob', 3),
            ('a_2', 7, 'finished', True),
        )
        h = History()
        h.save_tuples(tuples)
        # NOTE(review): t_step 0 precedes the first 'env' record; what a lookup
        # returns there is not pinned by this test (a check for 0 was disabled).

        # 'prob' changes at t_steps 1, 3 and 5 and holds in between.
        for i in range(1, 7):
            assert h['env', i, 'prob'] == ((i - 1) // 2) + 1

        for i, k in zip(range(5), 'hello'):
            assert h['a_0', i, 'id'] == k

        # With t_step=None the lookup yields (t_step, value) records.
        for record, value in zip(h['a_0', None, 'id'], 'hello'):
            _, val = record
            assert val == value

        for i, k in zip(range(5), 'value'):
            assert h['a_1', i, 'id'] == k

        # After its last record (t_step 4), 'a_1' keeps its final value.
        for i in range(5, 8):
            assert h['a_1', i, 'id'] == 'e'

        # 'a_2' only records 'finished' at t_step 7; earlier lookups must
        # compare equal to False.
        for i in range(7):
            assert h['a_2', i, 'finished'] == False  # noqa: E712
        assert h['a_2', 7, 'finished']

    def test_history_gen(self):
        """
        Wildcard (None) components in a lookup return every matching record.
        """
        tuples = (
            ('a_1', 0, 'id', 'v'),
            ('a_1', 1, 'id', 'a'),
            ('a_1', 2, 'id', 'l'),
            ('a_1', 3, 'id', 'u'),
            ('a_1', 4, 'id', 'e'),
            ('env', 1, 'prob', 1),
            ('env', 2, 'prob', 2),
            ('env', 3, 'prob', 3),
            ('a_2', 7, 'finished', True),
        )
        h = History()
        h.save_tuples(tuples)

        # Agent fixed; t_step and key wildcarded: (t_step, key, value) rows.
        for t_step, key, value in h['env', None, None]:
            assert t_step == value
            assert key == 'prob'

        # t_step fixed; agent and key wildcarded: one (agent, key, value) row
        # per agent, carrying each agent's latest value as of t_step 7.
        records = list(h[None, 7, None])
        assert len(records) == 3
        for agent_id, key, value in records:
            if agent_id == 'a_1':
                assert key == 'id'
                assert value == 'e'
            elif agent_id == 'a_2':
                assert key == 'finished'
                assert value
            else:
                assert key == 'prob'
                assert value == 3

        # Agent and t_step fixed; key wildcarded: the result is indexable by key.
        records = h['a_1', 7, None]
        assert records['id'] == 'e'

    def test_history_file(self):
        """
        History should be saved to a database file and be recoverable from it.
        """
        tuples = (
            ('a_1', 0, 'id', 'v'),
            ('a_1', 1, 'id', 'a'),
            ('a_1', 2, 'id', 'l'),
            ('a_1', 3, 'id', 'u'),
            ('a_1', 4, 'id', 'e'),
            ('env', 1, 'prob', 1),
            ('env', 2, 'prob', 2),
            ('env', 3, 'prob', 3),
            ('a_2', 7, 'finished', True),
        )
        db_path = os.path.join(DBROOT, 'test')
        h = History(db_path=db_path)
        h.save_tuples(tuples)
        h.flush_cache()
        assert os.path.exists(db_path)

        # Recover the data from a fresh History pointing at the same file.
        recovered = History(db_path=db_path)
        assert recovered['a_1', 0, 'id'] == 'v'
        assert recovered['a_1', 4, 'id'] == 'e'

        # backup=True should move the existing file to a backup copy and
        # start an empty history at the same path.
        newhistory = History(db_path=db_path, backup=True)
        backuppaths = glob(db_path + '.backup*.sqlite')
        assert len(backuppaths) == 1
        backuppath = backuppaths[0]
        assert newhistory.db_path == h.db_path
        assert os.path.exists(backuppath)
        assert len(newhistory[None, None, None]) == 0

    def test_interpolation(self):
        """
        Values for a key are valid until a new value is introduced at a later version.
        """
        tuples = (
            ('a_1', 0, 'id', 'a'),
            ('a_1', 4, 'id', 'b'),
        )
        db_path = os.path.join(DBROOT, 'test')
        h = History(db_path=db_path)
        h.save_tuples(tuples)
        h.flush_cache()
        assert os.path.exists(db_path)
        assert h['a_1', 2, 'id'] == 'a'

        # The same interpolation must hold after reloading from disk.
        recovered = History(db_path=db_path)
        assert recovered['a_1', 0, 'id'] == 'a'
        assert recovered['a_1', 4, 'id'] == 'b'
        assert recovered['a_1', 2, 'id'] == 'a'

    def test_history_tuples(self):
        """
        Every tuple returned by ``to_tuples()`` must be one of the recorded tuples.
        """
        tuples = (
            ('a_1', 0, 'id', 'v'),
            ('a_1', 1, 'id', 'a'),
            ('a_1', 2, 'id', 'l'),
            ('a_1', 3, 'id', 'u'),
            ('a_1', 4, 'id', 'e'),
            ('env', 1, 'prob', 1),
            ('env', 2, 'prob', 2),
            ('env', 3, 'prob', 3),
            ('a_2', 7, 'finished', True),
        )
        h = History()
        h.save_tuples(tuples)
        recovered = list(h.to_tuples())
        assert recovered
        # NOTE(review): this only checks recovered ⊆ recorded; it does not
        # verify that every recorded tuple comes back.
        for i in recovered:
            assert i in tuples

    def test_stats(self):
        """
        Stats dicts saved with ``save_stats()`` are returned by ``get_stats()``
        in the order they were saved.
        """
        tuples = (
            ('a_1', 0, 'id', 'v'),
            ('a_1', 1, 'id', 'a'),
            ('a_1', 2, 'id', 'l'),
            ('a_1', 3, 'id', 'u'),
            ('a_1', 4, 'id', 'e'),
            ('env', 1, 'prob', 1),
            ('env', 2, 'prob', 2),
            ('env', 3, 'prob', 3),
            ('a_2', 7, 'finished', True),
        )
        stats = [
            {'num_infected': 5, 'runtime': 0.2},
            {'num_infected': 5, 'runtime': 0.2},
            {'new': '40'},
        ]
        h = History()
        h.save_tuples(tuples)
        for stat in stats:
            h.save_stats(stat)
        recovered = h.get_stats()
        assert recovered
        assert recovered[0]['num_infected'] == 5
        assert recovered[1]['runtime'] == 0.2
        assert recovered[2]['new'] == '40'

    def test_unflatten(self):
        """
        ``utils.unflatten_dict`` nests dotted keys into sub-dicts and keeps
        undotted keys at the top level.
        """
        ex = {'count.neighbors.3': 4,
              'count.times.2': 4,
              'count.total.4': 4,
              'mean.neighbors': 3,
              'mean.times': 2,
              'mean.total': 4,
              't_step': 2,
              'trial_id': 'exporter_sim_trial_1605817956-4475424'}
        res = utils.unflatten_dict(ex)

        assert 'count' in res
        assert all(x in res['count'] for x in ['times', 'total', 'neighbors'])
        assert res['count']['times']['2'] == 4
        assert 'mean' in res
        assert all(x in res['mean'] for x in ['times', 'total', 'neighbors'])
        assert 't_step' in res
        assert 'trial_id' in res