mirror of
https://github.com/gsi-upm/soil
synced 2024-11-24 11:52:29 +00:00
9749f4ca14
Multithreading needs pickling to work. Pickling/unpickling didn't work in some situations, like when the environment_agents parameter was left blank. This was due to two reasons: 1) agents and history didn't have a setstate method, and some of their attributes cannot be pickled (generators, sqlite connection) 2) the environment was adding generators (agents) to its state. This fixes the situation by restricting the keys that the environment exports when it pickles, and by adding the set/getstate methods in agents. The resulting pickles should contain enough information to inspect them (history, state values, etc), but very limited.
292 lines
9.4 KiB
Python
292 lines
9.4 KiB
Python
import time
|
|
import os
|
|
import pandas as pd
|
|
import sqlite3
|
|
import copy
|
|
from collections import UserDict, namedtuple
|
|
|
|
from . import utils
|
|
|
|
|
|
class History:
|
|
"""
|
|
Store and retrieve values from a sqlite database.
|
|
"""
|
|
|
|
def __init__(self, db_path=None, name=None, dir_path=None, backup=True):
|
|
if db_path is None and name:
|
|
db_path = os.path.join(dir_path or os.getcwd(),
|
|
'{}.db.sqlite'.format(name))
|
|
if db_path:
|
|
if backup and os.path.exists(db_path):
|
|
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
|
os.rename(db_path, newname)
|
|
else:
|
|
db_path = ":memory:"
|
|
self.db_path = db_path
|
|
|
|
self.db = db_path
|
|
|
|
with self.db:
|
|
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''')
|
|
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
|
|
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
|
|
self._dtypes = {}
|
|
self._tups = []
|
|
|
|
@property
|
|
def db(self):
|
|
try:
|
|
self._db.cursor()
|
|
except (sqlite3.ProgrammingError, AttributeError):
|
|
self.db = None # Reset the database
|
|
return self._db
|
|
|
|
@db.setter
|
|
def db(self, db_path=None):
|
|
db_path = db_path or self.db_path
|
|
if isinstance(db_path, str):
|
|
self._db = sqlite3.connect(db_path)
|
|
else:
|
|
self._db = db_path
|
|
|
|
@property
|
|
def dtypes(self):
|
|
self.read_types()
|
|
return {k:v[0] for k, v in self._dtypes.items()}
|
|
|
|
def save_tuples(self, tuples):
|
|
'''
|
|
Save a series of tuples, converting them to records if necessary
|
|
'''
|
|
self.save_records(Record(*tup) for tup in tuples)
|
|
|
|
def save_records(self, records):
|
|
'''
|
|
Save a collection of records
|
|
'''
|
|
for record in records:
|
|
if not isinstance(record, Record):
|
|
record = Record(*record)
|
|
self.save_record(*record)
|
|
|
|
def save_record(self, agent_id, t_step, key, value):
|
|
'''
|
|
Save a collection of records to the database.
|
|
Database writes are cached.
|
|
'''
|
|
value = self.convert(key, value)
|
|
self._tups.append(Record(agent_id=agent_id,
|
|
t_step=t_step,
|
|
key=key,
|
|
value=value))
|
|
if len(self._tups) > 100:
|
|
self.flush_cache()
|
|
|
|
def convert(self, key, value):
|
|
"""Get the serialized value for a given key."""
|
|
if key not in self._dtypes:
|
|
self.read_types()
|
|
if key not in self._dtypes:
|
|
name = utils.name(value)
|
|
serializer = utils.serializer(name)
|
|
deserializer = utils.deserializer(name)
|
|
self._dtypes[key] = (name, serializer, deserializer)
|
|
with self.db:
|
|
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
|
|
return self._dtypes[key][1](value)
|
|
|
|
def recover(self, key, value):
|
|
"""Get the deserialized value for a given key, and the serialized version."""
|
|
if key not in self._dtypes:
|
|
self.read_types()
|
|
if key not in self._dtypes:
|
|
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
|
return self._dtypes[key][2](value)
|
|
|
|
|
|
def flush_cache(self):
|
|
'''
|
|
Use a cache to save state changes to avoid opening a session for every change.
|
|
The cache will be flushed at the end of the simulation, and when history is accessed.
|
|
'''
|
|
with self.db:
|
|
for rec in self._tups:
|
|
self.db.execute("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", (rec.agent_id, rec.t_step, rec.key, rec.value))
|
|
self._tups = list()
|
|
|
|
def to_tuples(self):
|
|
self.flush_cache()
|
|
with self.db:
|
|
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
|
|
for r in res:
|
|
agent_id, t_step, key, value = r
|
|
value = self.recover(key, value)
|
|
yield agent_id, t_step, key, value
|
|
|
|
def read_types(self):
|
|
with self.db:
|
|
res = self.db.execute("select key, value_type from value_types ").fetchall()
|
|
for k, v in res:
|
|
serializer = utils.serializer(v)
|
|
deserializer = utils.deserializer(v)
|
|
self._dtypes[k] = (v, serializer, deserializer)
|
|
|
|
def __getitem__(self, key):
|
|
self.flush_cache()
|
|
key = Key(*key)
|
|
agent_ids = [key.agent_id] if key.agent_id is not None else []
|
|
t_steps = [key.t_step] if key.t_step is not None else []
|
|
keys = [key.key] if key.key is not None else []
|
|
|
|
df = self.read_sql(agent_ids=agent_ids,
|
|
t_steps=t_steps,
|
|
keys=keys)
|
|
r = Records(df, filter=key, dtypes=self._dtypes)
|
|
if r.resolved:
|
|
return r.value()
|
|
return r
|
|
|
|
|
|
|
|
def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
|
|
|
self.read_types()
|
|
|
|
def escape_and_join(v):
|
|
if v is None:
|
|
return
|
|
return ",".join(map(lambda x: "\'{}\'".format(x), v))
|
|
|
|
filters = [("key in ({})".format(escape_and_join(keys)), keys),
|
|
("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids)
|
|
]
|
|
filters = list(k[0] for k in filters if k[1])
|
|
|
|
last_df = None
|
|
if t_steps:
|
|
# Look for the last value before the minimum step in the query
|
|
min_step = min(t_steps)
|
|
last_filters = ['t_step < {}'.format(min_step),]
|
|
last_filters = last_filters + filters
|
|
condition = ' and '.join(last_filters)
|
|
|
|
last_query = '''
|
|
select h1.*
|
|
from history h1
|
|
inner join (
|
|
select agent_id, key, max(t_step) as t_step
|
|
from history
|
|
where {condition}
|
|
group by agent_id, key
|
|
) h2
|
|
on h1.agent_id = h2.agent_id and
|
|
h1.key = h2.key and
|
|
h1.t_step = h2.t_step
|
|
'''.format(condition=condition)
|
|
last_df = pd.read_sql_query(last_query, self.db)
|
|
|
|
filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps)))
|
|
|
|
condition = ''
|
|
if filters:
|
|
condition = 'where {} '.format(' and '.join(filters))
|
|
query = 'select * from history {} limit {}'.format(condition, limit)
|
|
df = pd.read_sql_query(query, self.db)
|
|
if last_df is not None:
|
|
df = pd.concat([df, last_df])
|
|
|
|
df_p = df.pivot_table(values='value', index=['t_step'],
|
|
columns=['key', 'agent_id'],
|
|
aggfunc='first')
|
|
|
|
for k, v in self._dtypes.items():
|
|
if k in df_p:
|
|
dtype, _, deserial = v
|
|
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
|
|
if t_steps:
|
|
df_p = df_p.reindex(t_steps, method='ffill')
|
|
return df_p.ffill()
|
|
|
|
def __getstate__(self):
|
|
state = dict(**self.__dict__)
|
|
del state['_db']
|
|
del state['_dtypes']
|
|
return state
|
|
|
|
def __setstate__(self, state):
|
|
self.__dict__ = state
|
|
self._dtypes = {}
|
|
|
|
|
|
class Records():
|
|
|
|
def __init__(self, df, filter=None, dtypes=None):
|
|
if not filter:
|
|
filter = Key(agent_id=None,
|
|
t_step=None,
|
|
key=None)
|
|
self._df = df
|
|
self._filter = filter
|
|
self.dtypes = dtypes or {}
|
|
super().__init__()
|
|
|
|
def mask(self, tup):
|
|
res = ()
|
|
for i, k in zip(tup[:-1], self._filter):
|
|
if k is None:
|
|
res = res + (i,)
|
|
res = res + (tup[-1],)
|
|
return res
|
|
|
|
def filter(self, newKey):
|
|
f = list(self._filter)
|
|
for ix, i in enumerate(f):
|
|
if i is None:
|
|
f[ix] = newKey
|
|
self._filter = Key(*f)
|
|
|
|
@property
|
|
def resolved(self):
|
|
return sum(1 for i in self._filter if i is not None) == 3
|
|
|
|
def __iter__(self):
|
|
for column, series in self._df.iteritems():
|
|
key, agent_id = column
|
|
for t_step, value in series.iteritems():
|
|
r = Record(t_step=t_step,
|
|
agent_id=agent_id,
|
|
key=key,
|
|
value=value)
|
|
yield self.mask(r)
|
|
|
|
def value(self):
|
|
if self.resolved:
|
|
f = self._filter
|
|
try:
|
|
i = self._df[f.key][str(f.agent_id)]
|
|
ix = i.index.get_loc(f.t_step, method='ffill')
|
|
return i.iloc[ix]
|
|
except KeyError:
|
|
return self.dtypes[f.key][2]()
|
|
return list(self)
|
|
|
|
def __getitem__(self, k):
|
|
n = copy.copy(self)
|
|
n.filter(k)
|
|
if n.resolved:
|
|
return n.value()
|
|
return n
|
|
|
|
def __len__(self):
|
|
return len(self._df)
|
|
|
|
def __str__(self):
|
|
if self.resolved:
|
|
return str(self.value())
|
|
return '<Records for [{}]>'.format(self._filter)
|
|
|
|
|
|
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
|
Record = namedtuple('Record', 'agent_id t_step key value')
|