diff --git a/.gitignore b/.gitignore
index 17a6cc4..40bafb6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ soil_output
 docs/_build*
 build/*
 dist/*
+prof
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index fa94150..26e0702 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,23 +5,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ## [Unreleased]
 ### Added
-* [WIP] Integration with MESA
+* Integration with MESA
 * `not_agent_ids` paramter to get sql in history
 ### Changed
 * `soil.Environment` now also inherits from `mesa.Model`
 * `soil.Agent` now also inherits from `mesa.Agent`
 * `soil.time` to replace `simpy` events, delays, duration, etc.
+* `agent.id` is now `agent.unique_id` to be compatible with `mesa`. A property `BaseAgent.id` has been added for compatibility.
+* `agent.environment` is now `agent.model`, for the same reason as above. The parameter name in `BaseAgent.__init__` has also been renamed.
 ### Removed
 * `simpy` dependency and compatibility. Each agent used to be a simpy generator, but that made debugging and error handling more complex. That has been replaced by a scheduler within the `soil.Environment` class, similar to how `mesa` does it.
+* `soil.history` is now a separate package named `tsih`. The keys namedtuple uses `dict_id` instead of `agent_id`.
-### TODO:
-* agent_id -> unique_id?
-* mesa has Agent.model and soil has Agent.env
-* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents.
Might rename environment_agents to other_agents or sth like that
-* soil.History should mimic a mesa.datacollector :/
-* soil.Simulation *could* mimic a mesa.batchrunner
-* DONE include scheduler in environment
-* DONE environment inherits from `mesa.Model`
+### Added
+* An option to choose whether a database should be used for history
 ## [0.15.2]
diff --git a/README.md b/README.md
index 919624c..714d4df 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,9 @@ Learn how to run your own simulations with our [documentation](http://soilsim.re
 Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models.
+## Citation
+
+
 If you use Soil in your research, don't forget to cite this paper:
 ```bibtex
@@ -28,7 +31,24 @@ If you use Soil in your research, don't forget to cite this paper:
 ```
-@Copyright GSI - Universidad Politécnica de Madrid 2017
+## Mesa compatibility
-[![SOIL](logo_gsi.png)](https://www.gsi.upm.es)
+Soil is in the process of becoming fully compatible with MESA.
+As of this writing,
+
+This is a non-exhaustive list of tasks to achieve compatibility:
+* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that
+- [ ] Integrate `soil.Simulation` with mesa's runners:
+  - [ ] `soil.Simulation` could mimic/become a `mesa.batchrunner`
+- [ ] Integrate `soil.Environment` with `mesa.Model`:
+  - [x] `Soil.Environment` inherits from `mesa.Model`
+  - [x] `Soil.Environment` includes a Mesa-like Scheduler (see the `soil.time` module).
+- [ ] Integrate `soil.Agent` with `mesa.Agent`:
+  - [x] Rename agent.id to unique_id?
+  - [x] mesa agents can be used in soil simulations (see `examples/mesa`)
+- [ ] Document the new APIs and usage
+
+@Copyright GSI - Universidad Politécnica de Madrid 2017-2021
+
+[![SOIL](logo_gsi.png)](https://www.gsi.upm.es)
diff --git a/examples/complete.yml b/examples/complete.yml
index ad563a4..b3d388a 100644
--- a/examples/complete.yml
+++ b/examples/complete.yml
@@ -13,7 +13,7 @@ network_agents:
   - agent_type: CounterModel
     weight: 1
     state:
-      id: 0
+      state_id: 0
   - agent_type: AggregatedCounter
     weight: 0.2
 environment_agents: []
diff --git a/examples/custom_generator/custom_generator.yml b/examples/custom_generator/custom_generator.yml
index 1f8fa36..8c128f3 100644
--- a/examples/custom_generator/custom_generator.yml
+++ b/examples/custom_generator/custom_generator.yml
@@ -13,4 +13,4 @@ network_agents:
   - agent_type: CounterModel
     weight: 1
     state:
-      id: 0
+      state_id: 0
diff --git a/examples/mesa/server.py b/examples/mesa/server.py
index 9946ead..e6afecd 100644
--- a/examples/mesa/server.py
+++ b/examples/mesa/server.py
@@ -23,7 +23,6 @@ def network_portrayal(env):
         }
         for (agent_id) in env.G.nodes
     ]
-    # import pdb;pdb.set_trace()
     portrayal["edges"] = [
         {"id": edge_id, "source": source, "target": target, "color": "#000000"}
@@ -65,7 +64,7 @@ model_params = {
     "N": UserSettableParameter(
         "slider",
         "N",
-        1,
+        5,
         1,
         10,
         1,
diff --git a/examples/mesa/social_wealth.py b/examples/mesa/social_wealth.py
index 5105897..3398884 100644
--- a/examples/mesa/social_wealth.py
+++ b/examples/mesa/social_wealth.py
@@ -34,9 +34,7 @@ class MoneyAgent(MesaAgent):
             self.pos,
             moore=True,
             include_center=False)
-        print(self.pos, possible_steps)
         new_position = self.random.choice(possible_steps)
-        print(self.pos, new_position)
         self.model.grid.move_agent(self, new_position)
     def give_money(self):
@@ -74,21 +72,13 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
 class MoneyEnv(Environment):
     """A model with some number of agents."""
     def __init__(self, N, width, height, *args, network_params,
**kwargs):
-        self.initialized = True
-        # import pdb;pdb.set_trace()
         network_params['n'] = N
         super().__init__(*args, network_params=network_params, **kwargs)
         self.grid = MultiGrid(width, height, False)
-        # self.schedule = RandomActivation(self)
-        self.running = True
         # Create agents
         for agent in self.agents:
-            self.schedule.add(agent)
-            # a = MoneyAgent(i, self)
-            # self.schedule.add(a)
-            # Add the agent to a random grid cell
             x = self.random.randrange(self.grid.width)
             y = self.random.randrange(self.grid.height)
             self.grid.place_agent(agent, (x, y))
@@ -97,10 +87,6 @@ class MoneyEnv(Environment):
             model_reporters={"Gini": compute_gini},
             agent_reporters={"Wealth": "wealth"})
-    def step(self):
-        super().step()
-        self.datacollector.collect(self)
-        self.schedule.step()
 def graph_generator(n=5):
     G = nx.Graph()
diff --git a/examples/newsspread/NewsSpread.yml b/examples/newsspread/NewsSpread.yml
index b3bd7ba..ffb1778 100644
--- a/examples/newsspread/NewsSpread.yml
+++ b/examples/newsspread/NewsSpread.yml
@@ -68,12 +68,12 @@ network_agents:
 - agent_type: HerdViewer
   state:
     has_tv: true
-    id: neutral
+    state_id: neutral
   weight: 1
 - agent_type: HerdViewer
   state:
     has_tv: true
-    id: neutral
+    state_id: neutral
   weight: 1
 network_params:
   generator: barabasi_albert_graph
@@ -95,7 +95,7 @@ network_agents:
 - agent_type: HerdViewer
   state:
     has_tv: true
-    id: neutral
+    state_id: neutral
   weight: 1
 - agent_type: WiseViewer
   state:
@@ -121,7 +121,7 @@ network_agents:
 - agent_type: WiseViewer
   state:
     has_tv: true
-    id: neutral
+    state_id: neutral
   weight: 1
 - agent_type: WiseViewer
   state:
diff --git a/examples/newsspread/newsspread.py b/examples/newsspread/newsspread.py
index 8934245..f6188eb 100644
--- a/examples/newsspread/newsspread.py
+++ b/examples/newsspread/newsspread.py
@@ -34,8 +34,6 @@ class HerdViewer(DumbViewer):
     A viewer whose probability of infection depends on the state of its neighbors.
     '''
-    level = logging.DEBUG
-
     def infect(self):
         infected = self.count_neighboring_agents(state_id=self.infected.id)
         total = self.count_neighboring_agents()
diff --git a/examples/rabbits/rabbits.yml b/examples/rabbits/rabbits.yml
index 204aa7a..1d9421f 100644
--- a/examples/rabbits/rabbits.yml
+++ b/examples/rabbits/rabbits.yml
@@ -1,7 +1,7 @@
 ---
 load_module: rabbit_agents
 name: rabbits_example
-max_time: 200
+max_time: 150
 interval: 1
 seed: MySeed
 agent_type: RabbitModel
diff --git a/examples/template.yml b/examples/template.yml
index 9ab7548..f61757d 100644
--- a/examples/template.yml
+++ b/examples/template.yml
@@ -16,7 +16,7 @@ template:
     - agent_type: CounterModel
       weight: "{{ x1 }}"
       state:
-        id: 0
+        state_id: 0
     - agent_type: AggregatedCounter
       weight: "{{ 1 - x1 }}"
   environment_params:
diff --git a/requirements.txt b/requirements.txt
index 8b2f0f1..5a1d973 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ pandas>=0.23
 SALib>=1.3
 Jinja2
 Mesa>=0.8
+tsih>=0.1.5
diff --git a/soil/VERSION b/soil/VERSION
index a12760e..a881cf7 100644
--- a/soil/VERSION
+++ b/soil/VERSION
@@ -1 +1 @@
-0.15.2
\ No newline at end of file
+0.20.0
\ No newline at end of file
diff --git a/soil/__init__.py b/soil/__init__.py
index c9b7f1e..c02d744 100644
--- a/soil/__init__.py
+++ b/soil/__init__.py
@@ -15,7 +15,6 @@ from .agents import *
 from . import agents
 from .simulation import *
 from .environment import Environment
-from .history import History
 from . import serialization
 from . import analysis
 from .utils import logger
diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py
index bc8b685..d260fec 100644
--- a/soil/agents/__init__.py
+++ b/soil/agents/__init__.py
@@ -6,7 +6,9 @@ from itertools import islice
 import json
 import networkx as nx
-from .. import serialization, history, utils, time
+from ..
import serialization, utils, time
+
+from tsih import Key
 from mesa import Agent
@@ -16,6 +18,7 @@ def as_node(agent):
         return agent.id
     return agent
+IGNORED_FIELDS = ('model', 'logger')
 class BaseAgent(Agent):
     """
@@ -27,23 +30,20 @@ class BaseAgent(Agent):
     def __init__(self,
                  unique_id,
                  model,
-                 state=None,
                  name=None,
                  interval=None):
         # Check for REQUIRED arguments
         # Initialize agent parameters
         if isinstance(unique_id, Agent):
             raise Exception()
+        self._saved = set()
         super().__init__(unique_id=unique_id, model=model)
         self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
         self._neighbors = None
         self.alive = True
-        real_state = deepcopy(self.defaults)
-        real_state.update(state or {})
-        self.state = real_state
-        self.interval = interval or self.get('interval', getattr(self.model, 'interval', 1))
+        self.interval = interval or self.get('interval', 1)
         self.logger = logging.getLogger(self.model.name).getChild(self.name)
         if hasattr(self, 'level'):
@@ -75,7 +75,6 @@ class BaseAgent(Agent):
     @state.setter
     def state(self, value):
-        self._state = {}
         for k, v in value.items():
             self[k] = v
@@ -87,28 +86,36 @@ class BaseAgent(Agent):
     def environment_params(self, value):
         self.model.environment_params = value
+    def __setattr__(self, key, value):
+        if not key.startswith('_') and key not in IGNORED_FIELDS:
+            try:
+                k = Key(t_step=self.now,
+                        dict_id=self.unique_id,
+                        key=key)
+                self._saved.add(key)
+                self.model[k] = value
+            except AttributeError:
+                pass
+        super().__setattr__(key, value)
+
     def __getitem__(self, key):
         if isinstance(key, tuple):
             key, t_step = key
-            k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id)
+            k = Key(key=key, t_step=t_step, dict_id=self.unique_id)
             return self.model[k]
-        return self._state.get(key, None)
+        return getattr(self, key)
     def __delitem__(self, key):
-        self._state[key] = None
+        return delattr(self, key)
     def __contains__(self, key):
-        return key in self._state
+        return hasattr(self, key)
     def __setitem__(self, key, value):
-
self._state[key] = value
-        k = history.Key(t_step=self.now,
-                        agent_id=self.id,
-                        key=key)
-        self.model[k] = value
+        setattr(self, key, value)
     def items(self):
-        return self._state.items()
+        return ((k, getattr(self, k)) for k in self._saved)
     def get(self, key, default=None):
         return self[key] if key in self else default
@@ -150,25 +157,6 @@ class BaseAgent(Agent):
     def info(self, *args, **kwargs):
         return self.log(*args, level=logging.INFO, **kwargs)
-    # def __getstate__(self):
-    #     '''
-    #     Serializing an agent will lose all its running information (you cannot
-    #     serialize an iterator), but it keeps the state and link to the environment,
-    #     so it can be used for inspection and dumping to a file
-    #     '''
-    #     state = {}
-    #     state['id'] = self.id
-    #     state['environment'] = self.model
-    #     state['_state'] = self._state
-    #     return state
-
-    # def __setstate__(self, state):
-    #     '''
-    #     Get back a serialized agent and try to re-compose it
-    #     '''
-    #     self.state_id = state['id']
-    #     self._state = state['_state']
-    #     self.model = state['environment']
 class NetworkAgent(BaseAgent):
@@ -303,15 +291,15 @@ class MetaFSM(type):
 class FSM(NetworkAgent, metaclass=MetaFSM):
     def __init__(self, *args, **kwargs):
         super(FSM, self).__init__(*args, **kwargs)
-        if 'id' not in self.state:
+        if not hasattr(self, 'state_id'):
             if not self.default_state:
                 raise ValueError('No default state specified for {}'.format(self.unique_id))
-            self['id'] = self.default_state.id
+            self.state_id = self.default_state.id
-        self.set_state(self.state['id'])
+        self.set_state(self.state_id)
     def step(self):
-        self.debug(f'Agent {self.unique_id} @ state {self["id"]}')
+        self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
         interval = super().step()
         if 'id' not in self.state:
             # if 'id' in self.state:
@@ -320,14 +308,14 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
                 self.set_state(self.default_state.id)
             else:
                 raise Exception('{} has no valid state id or default state'.format(self))
-        return
self.states[self.state['id']](self) or interval
+        return self.states[self.state_id](self) or interval
     def set_state(self, state):
         if hasattr(state, 'id'):
             state = state.id
         if state not in self.states:
             raise ValueError('{} is not a valid state'.format(state))
-        self.state['id'] = state
+        self.state_id = state
         return state
@@ -541,7 +529,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
         f = filter(lambda x: x not in ignore, f)
     if state_id is not None:
-        f = filter(lambda agent: agent.state.get('id', None) in state_id, f)
+        f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
     if agent_type is not None:
         f = filter(lambda agent: isinstance(agent, agent_type), f)
diff --git a/soil/analysis.py b/soil/analysis.py
index 1d07eb5..65d8468 100644
--- a/soil/analysis.py
+++ b/soil/analysis.py
@@ -4,7 +4,8 @@ import glob
 import yaml
 from os.path import join
-from . import serialization, history
+from . import serialization
+from tsih import History
 def read_data(*args, group=False, **kwargs):
@@ -34,7 +35,7 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
 def read_sql(db, *args, **kwargs):
-    h = history.History(db_path=db, backup=False, readonly=True)
+    h = History(db_path=db, backup=False, readonly=True)
     df = h.read_sql(*args, **kwargs)
     return df
diff --git a/soil/environment.py b/soil/environment.py
index 50c0745..cead20b 100644
--- a/soil/environment.py
+++ b/soil/environment.py
@@ -12,9 +12,11 @@ from networkx.readwrite import json_graph
 import networkx as nx
+from tsih import History, Record, Key, NoHistory
+
 from mesa import Model
-from . import serialization, agents, analysis, history, utils, time
+from .
import serialization, agents, analysis, utils, time
 # These properties will be copied when pickling/unpickling the environment
 _CONFIG_PROPS = [ 'name',
@@ -46,6 +48,7 @@ class Environment(Model):
                  schedule=None,
                  initial_time=0,
                  environment_params=None,
+                 history=True,
                  dir_path=None,
                  **kwargs):
@@ -78,8 +81,12 @@ class Environment(Model):
         self._env_agents = {}
         self.interval = interval
-        self._history = history.History(name=self.name,
-                                        backup=True)
+        if history:
+            history = History
+        else:
+            history = NoHistory
+        self._history = history(name=self.name,
+                                backup=True)
         self['SEED'] = seed
         if network_agents:
@@ -162,8 +169,15 @@ class Environment(Model):
         if agent_type:
             state = defstate
             a = agent_type(model=self,
-                           unique_id=agent_id,
-                           state=state)
+                           unique_id=agent_id)
+
+        for (k, v) in getattr(a, 'defaults', {}).items():
+            if not hasattr(a, k) or getattr(a, k) is None:
+                setattr(a, k, v)
+
+        for (k, v) in state.items():
+            setattr(a, k, v)
+
         node['agent'] = a
         self.schedule.add(a)
         return a
@@ -183,6 +197,11 @@ class Environment(Model):
         start = start or self.now
         return self.G.add_edge(agent1, agent2, **attrs)
+    def step(self):
+        super().step()
+        self.datacollector.collect(self)
+        self.schedule.step()
+
     def run(self, until, *args, **kwargs):
         self._save_state()
@@ -204,12 +223,12 @@ class Environment(Model):
     def __setitem__(self, key, value):
         if isinstance(key, tuple):
-            k = history.Key(*key)
+            k = Key(*key)
             self._history.save_record(*k, value=value)
             return
         self.environment_params[key] = value
-        self._history.save_record(agent_id='env',
+        self._history.save_record(dict_id='env',
                                   t_step=self.now,
                                   key=key,
                                   value=value)
@@ -274,16 +293,16 @@ class Environment(Model):
         if now is None:
             now = self.now
         for k, v in self.environment_params.items():
-            yield history.Record(agent_id='env',
-                                 t_step=now,
-                                 key=k,
-                                 value=v)
+            yield Record(dict_id='env',
+                         t_step=now,
+                         key=k,
+                         value=v)
         for agent in self.agents:
             for k, v in agent.state.items():
-                yield history.Record(agent_id=agent.id,
-                                     t_step=now,
-                                     key=k,
-                                     value=v)
+                yield Record(dict_id=agent.id,
+                             t_step=now,
+                             key=k,
+                             value=v)
     def history_to_tuples(self):
         return self._history.to_tuples()
diff --git a/soil/history.py b/soil/history.py
deleted file mode 100644
index 984bc04..0000000
--- a/soil/history.py
+++ /dev/null
@@ -1,388 +0,0 @@
-import time
-import os
-import pandas as pd
-import sqlite3
-import copy
-import logging
-import tempfile
-
-logger = logging.getLogger(__name__)
-
-from collections import UserDict, namedtuple
-
-from . import serialization
-from .utils import open_or_reuse, unflatten_dict
-
-
-class History:
-    """
-    Store and retrieve values from a sqlite database.
-    """
-
-    def __init__(self, name=None, db_path=None, backup=False, readonly=False):
-        if readonly and (not os.path.exists(db_path)):
-            raise Exception('The DB file does not exist. Cannot open in read-only mode')
-
-        self._db = None
-        self._temp = db_path is None
-        self._stats_columns = None
-        self.readonly = readonly
-
-        if self._temp:
-            if not name:
-                name = time.time()
-            # The file will be deleted as soon as it's closed
-            # Normally, that will be on destruction
-            db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
-
-
-        if backup and os.path.exists(db_path):
-            newname = db_path + '.backup{}.sqlite'.format(time.time())
-            os.rename(db_path, newname)
-
-        self.db_path = db_path
-
-        self.db = db_path
-        self._dtypes = {}
-        self._tups = []
-
-
-        if self.readonly:
-            return
-
-        with self.db:
-            logger.debug('Creating database {}'.format(self.db_path))
-            self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step real, key text, value text)''')
-            self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
-            self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
-            self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
-
-    @property
-    def db(self):
-        try:
-            self._db.cursor()
-        except
(sqlite3.ProgrammingError, AttributeError):
-            self.db = None  # Reset the database
-        return self._db
-
-    @db.setter
-    def db(self, db_path=None):
-        self._close()
-        db_path = db_path or self.db_path
-        if isinstance(db_path, str):
-            logger.debug('Connecting to database {}'.format(db_path))
-            self._db = sqlite3.connect(db_path)
-            self._db.row_factory = sqlite3.Row
-        else:
-            self._db = db_path
-
-    def _close(self):
-        if self._db is None:
-            return
-        self.flush_cache()
-        self._db.close()
-        self._db = None
-
-    def save_stats(self, stat):
-        if self.readonly:
-            print('DB in readonly mode')
-            return
-        if not stat:
-            return
-        with self.db:
-            if not self._stats_columns:
-                self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
-
-            for column, value in stat.items():
-                if column in self._stats_columns:
-                    continue
-                dtype = 'text'
-                if not isinstance(value, str):
-                    try:
-                        float(value)
-                        dtype = 'real'
-                        int(value)
-                        dtype = 'int'
-                    except (ValueError, OverflowError):
-                        pass
-                self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
-                self._stats_columns.append(column)
-
-            columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
-            values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
-            query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
-                columns=columns,
-                values=values
-            )
-            self.db.execute(query)
-
-    def get_stats(self, unflatten=True):
-        rows = self.db.execute("select * from stats").fetchall()
-        res = []
-        for row in rows:
-            d = {}
-            for k in row.keys():
-                if row[k] is None:
-                    continue
-                d[k] = row[k]
-            if unflatten:
-                d = unflatten_dict(d)
-            res.append(d)
-        return res
-
-    @property
-    def dtypes(self):
-        self._read_types()
-        return {k:v[0] for k, v in self._dtypes.items()}
-
-    def save_tuples(self, tuples):
-        '''
-        Save a series of tuples, converting them to records if necessary
-        '''
-        self.save_records(Record(*tup) for tup in tuples)
-
-    def save_records(self, records):
-        '''
-        Save a
collection of records
-        '''
-        for record in records:
-            if not isinstance(record, Record):
-                record = Record(*record)
-            self.save_record(*record)
-
-    def save_record(self, agent_id, t_step, key, value):
-        '''
-        Save a collection of records to the database.
-        Database writes are cached.
-        '''
-        if self.readonly:
-            raise Exception('DB in readonly mode')
-        if key not in self._dtypes:
-            self._read_types()
-            if key not in self._dtypes:
-                name = serialization.name(value)
-                serializer = serialization.serializer(name)
-                deserializer = serialization.deserializer(name)
-                self._dtypes[key] = (name, serializer, deserializer)
-                with self.db:
-                    self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
-        value = self._dtypes[key][1](value)
-
-        self._tups.append(Record(agent_id=agent_id,
-                                 t_step=t_step,
-                                 key=key,
-                                 value=value))
-        if len(self._tups) > 100:
-            self.flush_cache()
-
-    def flush_cache(self):
-        '''
-        Use a cache to save state changes to avoid opening a session for every change.
-        The cache will be flushed at the end of the simulation, and when history is accessed.
-        '''
-        if self.readonly:
-            raise Exception('DB in readonly mode')
-        logger.debug('Flushing cache {}'.format(self.db_path))
-        with self.db:
-            self.db.executemany("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", self._tups)
-            # (rec.agent_id, rec.t_step, rec.key, rec.value))
-        self._tups.clear()
-
-    def to_tuples(self):
-        self.flush_cache()
-        with self.db:
-            res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
-        for r in res:
-            agent_id, t_step, key, value = r
-            if key not in self._dtypes:
-                self._read_types()
-                if key not in self._dtypes:
-                    raise ValueError("Unknown datatype for {} and {}".format(key, value))
-            value = self._dtypes[key][2](value)
-            yield agent_id, t_step, key, value
-
-    def _read_types(self):
-        with self.db:
-            res = self.db.execute("select key, value_type from value_types ").fetchall()
-        for k, v in res:
-            serializer = serialization.serializer(v)
-            deserializer = serialization.deserializer(v)
-            self._dtypes[k] = (v, serializer, deserializer)
-
-    def __getitem__(self, key):
-        # raise NotImplementedError()
-        self.flush_cache()
-        key = Key(*key)
-        agent_ids = [key.agent_id] if key.agent_id is not None else []
-        t_steps = [key.t_step] if key.t_step is not None else []
-        keys = [key.key] if key.key is not None else []
-
-        df = self.read_sql(agent_ids=agent_ids,
-                           t_steps=t_steps,
-                           keys=keys)
-        r = Records(df, filter=key, dtypes=self._dtypes)
-        if r.resolved:
-            return r.value()
-        return r
-
-    def read_sql(self, keys=None, agent_ids=None, not_agent_ids=None, t_steps=None, convert_types=False, limit=-1):
-
-        self._read_types()
-
-        def escape_and_join(v):
-            if v is None:
-                return
-            return ",".join(map(lambda x: "\'{}\'".format(x), v))
-
-        filters = [("key in ({})".format(escape_and_join(keys)), keys),
-                   ("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids),
-                   ("agent_id not in ({})".format(escape_and_join(not_agent_ids)), not_agent_ids)
-                   ]
-        filters = list(k[0] for k in filters if
k[1])
-
-        last_df = None
-        if t_steps:
-            # Convert negative indices into positive
-            if any(x<0 for x in t_steps):
-                max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
-                t_steps = [t if t>0 else max_t+1+t for t in t_steps]
-
-            # We will be doing ffill interpolation, so we need to look for
-            # the last value before the minimum step in the query
-            min_step = min(t_steps)
-            last_filters = ['t_step < {}'.format(min_step),]
-            last_filters = last_filters + filters
-            condition = ' and '.join(last_filters)
-
-            last_query = '''
-            select h1.*
-            from history h1
-            inner join (
-                select agent_id, key, max(t_step) as t_step
-                from history
-                where {condition}
-                group by agent_id, key
-            ) h2
-            on h1.agent_id = h2.agent_id and
-                h1.key = h2.key and
-                h1.t_step = h2.t_step
-            '''.format(condition=condition)
-            last_df = pd.read_sql_query(last_query, self.db)
-
-            filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps)))
-
-        condition = ''
-        if filters:
-            condition = 'where {} '.format(' and '.join(filters))
-        query = 'select * from history {} limit {}'.format(condition, limit)
-        df = pd.read_sql_query(query, self.db)
-        if last_df is not None:
-            df = pd.concat([df, last_df])
-
-        df_p = df.pivot_table(values='value', index=['t_step'],
-                              columns=['key', 'agent_id'],
-                              aggfunc='first')
-
-        for k, v in self._dtypes.items():
-            if k in df_p:
-                dtype, _, deserial = v
-                try:
-                    df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
-                except (TypeError, ValueError):
-                    # Avoid forward-filling unknown/incompatible types
-                    continue
-        if t_steps:
-            df_p = df_p.reindex(t_steps, method='ffill')
-        return df_p.ffill()
-
-    def __getstate__(self):
-        state = dict(**self.__dict__)
-        del state['_db']
-        del state['_dtypes']
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__ = state
-        self._dtypes = {}
-        self._db = None
-
-    def dump(self, f):
-        self._close()
-        for line in open_or_reuse(self.db_path, 'rb'):
-            f.write(line)
-
-
-class Records():
-
-    def __init__(self, df, filter=None, dtypes=None):
-        if not filter:
-            filter = Key(agent_id=None,
-                         t_step=None,
-                         key=None)
-        self._df = df
-        self._filter = filter
-        self.dtypes = dtypes or {}
-        super().__init__()
-
-    def mask(self, tup):
-        res = ()
-        for i, k in zip(tup[:-1], self._filter):
-            if k is None:
-                res = res + (i,)
-        res = res + (tup[-1],)
-        return res
-
-    def filter(self, newKey):
-        f = list(self._filter)
-        for ix, i in enumerate(f):
-            if i is None:
-                f[ix] = newKey
-        self._filter = Key(*f)
-
-    @property
-    def resolved(self):
-        return sum(1 for i in self._filter if i is not None) == 3
-
-    def __iter__(self):
-        for column, series in self._df.iteritems():
-            key, agent_id = column
-            for t_step, value in series.iteritems():
-                r = Record(t_step=t_step,
-                           agent_id=agent_id,
-                           key=key,
-                           value=value)
-                yield self.mask(r)
-
-    def value(self):
-        if self.resolved:
-            f = self._filter
-            try:
-                i = self._df[f.key][str(f.agent_id)]
-                ix = i.index.get_loc(f.t_step, method='ffill')
-                return i.iloc[ix]
-            except KeyError as ex:
-                return self.dtypes[f.key][2]()
-        return list(self)
-
-    def df(self):
-        return self._df
-
-    def __getitem__(self, k):
-        n = copy.copy(self)
-        n.filter(k)
-        if n.resolved:
-            return n.value()
-        return n
-
-    def __len__(self):
-        return len(self._df)
-
-    def __str__(self):
-        if self.resolved:
-            return str(self.value())
-        return ''.format(self._filter)
-
-Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
-Record = namedtuple('Record', 'agent_id t_step key value')
-
-Stat = namedtuple('Stat', 'trial_id')
diff --git a/soil/simulation.py b/soil/simulation.py
index cb19f1d..5aa6374 100644
--- a/soil/simulation.py
+++ b/soil/simulation.py
@@ -9,6 +9,7 @@ import networkx as nx
 from networkx.readwrite import json_graph
 from multiprocessing import Pool
 from functools import partial
+from tsih import History
 import pickle
@@ -17,7 +18,6 @@ from .environment import Environment
 from .utils import logger
 from .exporters import default
 from .stats import
defaultStats
-from .history import History
 #TODO: change documentation for simulation
@@ -159,7 +159,7 @@ class Simulation:
                                      **kwargs)
     def run_gen(self, *args, parallel=False, dry_run=False,
-                exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
+                exporters=[default, ], stats=[], outdir=None, exporter_params={},
                 stats_params={}, log_level=None,
                 **kwargs):
         '''Run the simulation and yield the resulting environments.'''
diff --git a/soil/stats.py b/soil/stats.py
index 50a8d29..2a7636f 100644
--- a/soil/stats.py
+++ b/soil/stats.py
@@ -97,7 +97,7 @@ class defaultStats(Stats):
         return {
             'network ': {
                 'n_nodes': env.G.number_of_nodes(),
-                'n_edges': env.G.number_of_nodes(),
+                'n_edges': env.G.number_of_edges(),
             },
             'agents': {
                 'model_count': dict(c),
diff --git a/soil/utils.py b/soil/utils.py
index 22ee024..e95758c 100644
--- a/soil/utils.py
+++ b/soil/utils.py
@@ -26,6 +26,8 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
         to_object.end = end
+
+
 def safe_open(path, mode='r', backup=True, **kwargs):
     outdir = os.path.dirname(path)
     if outdir and not os.path.exists(outdir):
diff --git a/tests/test_analysis.py b/tests/test_analysis.py
index 425f2cc..47c649b 100644
--- a/tests/test_analysis.py
+++ b/tests/test_analysis.py
@@ -67,13 +67,13 @@ class TestAnalysis(TestCase):
     def test_count(self):
         env = self.env
         df = analysis.read_sql(env._history.db_path)
-        res = analysis.get_count(df, 'SEED', 'id')
+        res = analysis.get_count(df, 'SEED', 'state_id')
         assert res['SEED'][self.env['SEED']].iloc[0] == 1
         assert res['SEED'][self.env['SEED']].iloc[-1] == 1
-        assert res['id']['odd'].iloc[0] == 2
-        assert res['id']['even'].iloc[0] == 0
-        assert res['id']['odd'].iloc[-1] == 1
-        assert res['id']['even'].iloc[-1] == 1
+        assert res['state_id']['odd'].iloc[0] == 2
+        assert res['state_id']['even'].iloc[0] == 0
+        assert res['state_id']['odd'].iloc[-1] == 1
+        assert res['state_id']['even'].iloc[-1] == 1
     def test_value(self):
         env = self.env
diff --git
a/tests/test_main.py b/tests/test_main.py
index db28e19..349c3e3 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -9,7 +9,7 @@ from functools import partial
 from os.path import join
 from soil import (simulation, Environment, agents, serialization,
-                  history, utils)
+                  utils)
 from soil.time import Delta
@@ -21,8 +21,8 @@ class CustomAgent(agents.FSM):
     @agents.default_state
     @agents.state
     def normal(self):
-        self.state['neighbors'] = self.count_agents(state_id='normal',
-                                                    limit_neighbors=True)
+        self.neighbors = self.count_agents(state_id='normal',
+                                           limit_neighbors=True)
     @agents.state
     def unreachable(self):
         return
@@ -116,7 +116,7 @@ class TestMain(TestCase):
             'network_agents': [{
                 'agent_type': 'AggregatedCounter',
                 'weight': 1,
-                'state': {'id': 0}
+                'state': {'state_id': 0}
             }],
             'max_time': 10,
@@ -149,10 +149,9 @@ class TestMain(TestCase):
         }
         s = simulation.from_config(config)
         env = s.run_simulation(dry_run=True)[0]
-        assert env.get_agent(0).state['neighbors'] == 1
-        assert env.get_agent(0).state['neighbors'] == 1
         assert env.get_agent(1).count_agents(state_id='normal') == 2
         assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
+        assert env.get_agent(0).neighbors == 1
     def test_torvalds_example(self):
         """A complete example from a documentation should work."""
@@ -317,12 +316,6 @@ class TestMain(TestCase):
         assert recovered['key', 0] == 'test'
         assert recovered['key'] == 'test'
-    def test_history(self):
-        '''Test storing in and retrieving from history (sqlite)'''
-        h = history.History()
-        h.save_record(agent_id=0, t_step=0, key="test", value="hello")
-        assert h[0, 0, "test"] == "hello"
-
     def test_subgraph(self):
         '''An agent should be able to subgraph the global topology'''
         G = nx.Graph()
@@ -350,12 +343,13 @@ class TestMain(TestCase):
             'network_params': {},
             'agent_type': 'CounterModel',
             'max_time': 2,
-            'num_trials': 100,
+            'num_trials': 50,
             'environment_params': {}
         }
         s = simulation.from_config(config)
         runs =
list(s.run_simulation(dry_run=True))
         over = list(x.now for x in runs if x.now>2)
+        assert len(runs) == config['num_trials']
         assert len(over) == 0
@@ -372,11 +366,11 @@ class TestMain(TestCase):
                 return self.ping
         a = ToggleAgent(unique_id=1, model=Environment())
-        assert a.state["id"] == a.ping.id
+        assert a.state_id == a.ping.id
         a.step()
-        assert a.state["id"] == a.pong.id
+        assert a.state_id == a.pong.id
         a.step()
-        assert a.state["id"] == a.ping.id
+        assert a.state_id == a.ping.id
     def test_fsm_when(self):
         '''Basic state change'''