Fix history bug

history
J. Fernando Sánchez 6 years ago
parent fc48ed7e09
commit 5d89827ccf

@ -1 +1 @@
0.11 0.11.1

@ -56,23 +56,9 @@ class SoilEnvironment(nxsim.NetworkEnvironment):
# executed before network agents # executed before network agents
self.environment_agents = environment_agents or [] self.environment_agents = environment_agents or []
self.network_agents = network_agents or [] self.network_agents = network_agents or []
if self.dry_run:
self._db_path = ":memory:"
else:
self._db_path = os.path.join(self.get_path(), '{}.db.sqlite'.format(self.name))
self.create_db(self._db_path)
self['SEED'] = seed or time.time() self['SEED'] = seed or time.time()
random.seed(self['SEED']) random.seed(self['SEED'])
def create_db(self, db_path=None):
db_path = db_path or self._db_path
if os.path.exists(db_path):
newname = db_path.replace('db.sqlite', 'backup{}.sqlite'.format(time.time()))
os.rename(db_path, newname)
self._db = sqlite3.connect(db_path)
with self._db:
self._db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text, value_type text)''')
@property @property
def agents(self): def agents(self):
yield from self.environment_agents yield from self.environment_agents

@ -15,13 +15,13 @@ class History:
def __init__(self, db_path=None, name=None, dir_path=None, backup=True): def __init__(self, db_path=None, name=None, dir_path=None, backup=True):
if db_path is None and name: if db_path is None and name:
db_path = os.path.join(dir_path or os.getcwd(), '{}.db.sqlite'.format(name)) db_path = os.path.join(dir_path or os.getcwd(),
'{}.db.sqlite'.format(name))
if db_path is None: if db_path is None:
db_path = ":memory:" db_path = ":memory:"
else: else:
if backup and os.path.exists(db_path): if backup and os.path.exists(db_path):
newname = db_path.replace('db.sqlite', 'backup{}.sqlite'.format(time.time())) newname = db_path + '.backup{}.sqlite'.format(time.time())
os.rename(db_path, newname) os.rename(db_path, newname)
self._db_path = db_path self._db_path = db_path
if isinstance(db_path, str): if isinstance(db_path, str):

@ -1,30 +1,40 @@
from unittest import TestCase from unittest import TestCase
import os import os
import pandas as pd import shutil
from glob import glob
from soil import history, analysis from soil import history
ROOT = os.path.abspath(os.path.dirname(__file__)) ROOT = os.path.abspath(os.path.dirname(__file__))
DBROOT = os.path.join(ROOT, 'testdb')
class TestHistory(TestCase): class TestHistory(TestCase):
def setUp(self):
if not os.path.exists(DBROOT):
os.makedirs(DBROOT)
def tearDown(self):
if os.path.exists(DBROOT):
shutil.rmtree(DBROOT)
def test_history(self): def test_history(self):
""" """
""" """
tuples = ( tuples = (
('a_0', 0, 'id', 'h', ), ('a_0', 0, 'id', 'h'),
('a_0', 1, 'id', 'e', ), ('a_0', 1, 'id', 'e'),
('a_0', 2, 'id', 'l', ), ('a_0', 2, 'id', 'l'),
('a_0', 3, 'id', 'l', ), ('a_0', 3, 'id', 'l'),
('a_0', 4, 'id', 'o', ), ('a_0', 4, 'id', 'o'),
('a_1', 0, 'id', 'v', ), ('a_1', 0, 'id', 'v'),
('a_1', 1, 'id', 'a', ), ('a_1', 1, 'id', 'a'),
('a_1', 2, 'id', 'l', ), ('a_1', 2, 'id', 'l'),
('a_1', 3, 'id', 'u', ), ('a_1', 3, 'id', 'u'),
('a_1', 4, 'id', 'e', ), ('a_1', 4, 'id', 'e'),
('env', 1, 'prob', 1), ('env', 1, 'prob', 1),
('env', 3, 'prob', 2), ('env', 3, 'prob', 2),
('env', 5, 'prob', 3), ('env', 5, 'prob', 3),
@ -55,11 +65,11 @@ class TestHistory(TestCase):
""" """
""" """
tuples = ( tuples = (
('a_1', 0, 'id', 'v', ), ('a_1', 0, 'id', 'v'),
('a_1', 1, 'id', 'a', ), ('a_1', 1, 'id', 'a'),
('a_1', 2, 'id', 'l', ), ('a_1', 2, 'id', 'l'),
('a_1', 3, 'id', 'u', ), ('a_1', 3, 'id', 'u'),
('a_1', 4, 'id', 'e', ), ('a_1', 4, 'id', 'e'),
('env', 1, 'prob', 1), ('env', 1, 'prob', 1),
('env', 2, 'prob', 2), ('env', 2, 'prob', 2),
('env', 3, 'prob', 3), ('env', 3, 'prob', 3),
@ -80,11 +90,44 @@ class TestHistory(TestCase):
assert value == 'e' assert value == 'e'
elif agent_id == 'a_2': elif agent_id == 'a_2':
assert key == 'finished' assert key == 'finished'
assert value == True assert value
else: else:
assert key == 'prob' assert key == 'prob'
assert value == 3 assert value == 3
records = h['a_1', 7, None] records = h['a_1', 7, None]
assert records['id'] == 'e' assert records['id'] == 'e'
def test_history_file(self):
"""
History should be saved to a file
"""
tuples = (
('a_1', 0, 'id', 'v'),
('a_1', 1, 'id', 'a'),
('a_1', 2, 'id', 'l'),
('a_1', 3, 'id', 'u'),
('a_1', 4, 'id', 'e'),
('env', 1, 'prob', 1),
('env', 2, 'prob', 2),
('env', 3, 'prob', 3),
('a_2', 7, 'finished', True),
)
db_path = os.path.join(DBROOT, 'test')
h = history.History(db_path=db_path)
h.save_tuples(tuples)
assert os.path.exists(db_path)
# Recover the data
recovered = history.History(db_path=db_path, backup=False)
assert recovered['a_1', 0, 'id'] == 'v'
assert recovered['a_1', 4, 'id'] == 'e'
# Using the same name should create a backup copy
newhistory = history.History(db_path=db_path, backup=True)
backuppaths = glob(db_path + '.backup*.sqlite')
assert len(backuppaths) == 1
backuppath = backuppaths[0]
assert newhistory._db_path == h._db_path
assert os.path.exists(backuppath)
assert not len(newhistory[None, None, None])

Loading…
Cancel
Save