From 5d89827ccf8fe413e62395ecd13e41fd987f6e00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Fri, 4 May 2018 11:21:23 +0200 Subject: [PATCH] Fix history bug --- soil/VERSION | 2 +- soil/environment.py | 14 -------- soil/history.py | 6 ++-- tests/test_history.py | 81 +++++++++++++++++++++++++++++++++---------- 4 files changed, 66 insertions(+), 37 deletions(-) diff --git a/soil/VERSION b/soil/VERSION index 0eb4182..027934e 100644 --- a/soil/VERSION +++ b/soil/VERSION @@ -1 +1 @@ -0.11 \ No newline at end of file +0.11.1 \ No newline at end of file diff --git a/soil/environment.py b/soil/environment.py index de2f8b1..552bf36 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -56,23 +56,9 @@ class SoilEnvironment(nxsim.NetworkEnvironment): # executed before network agents self.environment_agents = environment_agents or [] self.network_agents = network_agents or [] - if self.dry_run: - self._db_path = ":memory:" - else: - self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name)) - self.create_db(self._db_path) self['SEED'] = seed or time.time() random.seed(self['SEED']) - def create_db(self, db_path=None): - db_path = db_path or self._db_path - if os.path.exists(db_path): - newname = db_path.replace('db.sqlite', 'backup{}.sqlite'.format(time.time())) - os.rename(db_path, newname) - self._db = sqlite3.connect(db_path) - with self._db: - self._db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text, value_type text)''') - @property def agents(self): yield from self.environment_agents diff --git a/soil/history.py b/soil/history.py index 73eed13..cf8b725 100644 --- a/soil/history.py +++ b/soil/history.py @@ -15,13 +15,13 @@ class History: def __init__(self, db_path=None, name=None, dir_path=None, backup=True): if db_path is None and name: - db_path = os.path.join(dir_path or os.getcwd(), '{}.db.sqlite'.format(name)) - + db_path = os.path.join(dir_path or os.getcwd(), + '{}.db.sqlite'.format(name)) if db_path is None: db_path = ":memory:" else: if backup and os.path.exists(db_path): - newname = db_path.replace('db.sqlite', 'backup{}.sqlite'.format(time.time())) + newname = db_path + '.backup{}.sqlite'.format(time.time()) os.rename(db_path, newname) self._db_path = db_path if isinstance(db_path, str): diff --git a/tests/test_history.py b/tests/test_history.py index 51d4e3e..19d0893 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -1,30 +1,40 @@ from unittest import TestCase import os -import pandas as pd +import shutil +from glob import glob -from soil import history, analysis +from soil import history ROOT = os.path.abspath(os.path.dirname(__file__)) +DBROOT = os.path.join(ROOT, 'testdb') class TestHistory(TestCase): + def setUp(self): + if not os.path.exists(DBROOT): + os.makedirs(DBROOT) + + def tearDown(self): + if os.path.exists(DBROOT): + shutil.rmtree(DBROOT) + def test_history(self): """ """ tuples = ( - ('a_0', 0, 'id', 'h', ), - ('a_0', 1, 'id', 'e', ), - ('a_0', 2, 'id', 'l', ), - ('a_0', 3, 'id', 'l', ), - ('a_0', 4, 'id', 'o', ), - ('a_1', 0, 'id', 'v', ), - ('a_1', 1, 'id', 'a', ), - ('a_1', 2, 'id', 'l', ), - ('a_1', 3, 'id', 'u', ), - ('a_1', 4, 'id', 'e', ), + ('a_0', 0, 'id', 'h'), + ('a_0', 1, 'id', 'e'), + ('a_0', 2, 'id', 'l'), + ('a_0', 3, 'id', 'l'), + ('a_0', 4, 'id', 'o'), + ('a_1', 0, 'id', 'v'), + ('a_1', 1, 'id', 'a'), + ('a_1', 2, 'id', 'l'), + ('a_1', 3, 'id', 'u'), + ('a_1', 4, 'id', 'e'), ('env', 1, 'prob', 1), ('env', 3, 'prob', 2), ('env', 5, 'prob', 3), @@ -55,11 +65,11 @@ class TestHistory(TestCase): """ """ tuples = ( - ('a_1', 0, 'id', 'v', ), - ('a_1', 1, 'id', 'a', ), - ('a_1', 2, 'id', 'l', ), - ('a_1', 3, 'id', 'u', ), - ('a_1', 4, 'id', 'e', ), + ('a_1', 0, 'id', 'v'), + ('a_1', 1, 'id', 'a'), + ('a_1', 2, 'id', 'l'), + ('a_1', 3, 'id', 'u'), + ('a_1', 4, 'id', 'e'), ('env', 1, 'prob', 1), ('env', 2, 'prob', 2), ('env', 3, 'prob', 3), @@ -80,11 +90,44 @@ class TestHistory(TestCase): assert value == 'e' elif agent_id == 'a_2': assert key == 'finished' - assert value == True + assert value else: assert key == 'prob' assert value == 3 - records = h['a_1', 7, None] assert records['id'] == 'e' + + def test_history_file(self): + """ + History should be saved to a file + """ + tuples = ( + ('a_1', 0, 'id', 'v'), + ('a_1', 1, 'id', 'a'), + ('a_1', 2, 'id', 'l'), + ('a_1', 3, 'id', 'u'), + ('a_1', 4, 'id', 'e'), + ('env', 1, 'prob', 1), + ('env', 2, 'prob', 2), + ('env', 3, 'prob', 3), + ('a_2', 7, 'finished', True), + ) + db_path = os.path.join(DBROOT, 'test') + h = history.History(db_path=db_path) + h.save_tuples(tuples) + assert os.path.exists(db_path) + + # Recover the data + recovered = history.History(db_path=db_path, backup=False) + assert recovered['a_1', 0, 'id'] == 'v' + assert recovered['a_1', 4, 'id'] == 'e' + + # Using the same name should create a backup copy + newhistory = history.History(db_path=db_path, backup=True) + backuppaths = glob(db_path + '.backup*.sqlite') + assert len(backuppaths) == 1 + backuppath = backuppaths[0] + assert newhistory._db_path == h._db_path + assert os.path.exists(backuppath) + assert not len(newhistory[None, None, None])