mirror of
https://github.com/gsi-upm/soil
synced 2024-11-21 18:52:28 +00:00
Partial MESA compatibility and several fixes
Documentation for the new APIs is still a work in progress :)
This commit is contained in:
parent
af9a392a93
commit
6c4f44b4cb
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,3 +8,4 @@ soil_output
|
|||||||
docs/_build*
|
docs/_build*
|
||||||
build/*
|
build/*
|
||||||
dist/*
|
dist/*
|
||||||
|
prof
|
15
CHANGELOG.md
15
CHANGELOG.md
@ -5,23 +5,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
### Added
|
### Added
|
||||||
* [WIP] Integration with MESA
|
* Integration with MESA
|
||||||
* `not_agent_ids` paramter to get sql in history
|
* `not_agent_ids` paramter to get sql in history
|
||||||
### Changed
|
### Changed
|
||||||
* `soil.Environment` now also inherits from `mesa.Model`
|
* `soil.Environment` now also inherits from `mesa.Model`
|
||||||
* `soil.Agent` now also inherits from `mesa.Agent`
|
* `soil.Agent` now also inherits from `mesa.Agent`
|
||||||
* `soil.time` to replace `simpy` events, delays, duration, etc.
|
* `soil.time` to replace `simpy` events, delays, duration, etc.
|
||||||
|
* `agent.id` is not `agent.unique_id` to be compatible with `mesa`. A property `BaseAgent.id` has been added for compatibility.
|
||||||
|
* `agent.environment` is now `agent.model`, for the same reason as above. The parameter name in `BaseAgent.__init__` has also been renamed.
|
||||||
### Removed
|
### Removed
|
||||||
* `simpy` dependency and compatibility. Each agent used to be a simpy generator, but that made debugging and error handling more complex. That has been replaced by a scheduler within the `soil.Environment` class, similar to how `mesa` does it.
|
* `simpy` dependency and compatibility. Each agent used to be a simpy generator, but that made debugging and error handling more complex. That has been replaced by a scheduler within the `soil.Environment` class, similar to how `mesa` does it.
|
||||||
|
* `soil.history` is now a separate package named `tsih`. The keys namedtuple uses `dict_id` instead of `agent_id`.
|
||||||
|
|
||||||
### TODO:
|
### Added
|
||||||
* agent_id -> unique_id?
|
* An option to choose whether a database should be used for history
|
||||||
* mesa has Agent.model and soil has Agent.env
|
|
||||||
* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that
|
|
||||||
* soil.History should mimic a mesa.datacollector :/
|
|
||||||
* soil.Simulation *could* mimic a mesa.batchrunner
|
|
||||||
* DONE include scheduler in environment
|
|
||||||
* DONE environment inherits from `mesa.Model`
|
|
||||||
|
|
||||||
|
|
||||||
## [0.15.2]
|
## [0.15.2]
|
||||||
|
24
README.md
24
README.md
@ -5,6 +5,9 @@ Learn how to run your own simulations with our [documentation](http://soilsim.re
|
|||||||
|
|
||||||
Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models.
|
Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
|
||||||
If you use Soil in your research, don't forget to cite this paper:
|
If you use Soil in your research, don't forget to cite this paper:
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@ -28,7 +31,24 @@ If you use Soil in your research, don't forget to cite this paper:
|
|||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
@Copyright GSI - Universidad Politécnica de Madrid 2017
|
## Mesa compatibility
|
||||||
|
|
||||||
|
Soil is in the process of becoming fully compatible with MESA.
|
||||||
|
As of this writing,
|
||||||
|
|
||||||
|
This is a non-exhaustive list of tasks to achieve compatibility:
|
||||||
|
|
||||||
|
* Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that
|
||||||
|
- [ ] Integrate `soil.Simulation` with mesa's runners:
|
||||||
|
- [ ] `soil.Simulation` could mimic/become a `mesa.batchrunner`
|
||||||
|
- [ ] Integrate `soil.Environment` with `mesa.Model`:
|
||||||
|
- [x] `Soil.Environment` inherits from `mesa.Model`
|
||||||
|
- [x] `Soil.Environment` includes a Mesa-like Scheduler (see the `soil.time` module.
|
||||||
|
- [ ] Integrate `soil.Agent` with `mesa.Agent`:
|
||||||
|
- [x] Rename agent.id to unique_id?
|
||||||
|
- [x] mesa agents can be used in soil simulations (see `examples/mesa`)
|
||||||
|
- [ ] Document the new APIs and usage
|
||||||
|
|
||||||
|
@Copyright GSI - Universidad Politécnica de Madrid 2017-2021
|
||||||
|
|
||||||
[![SOIL](logo_gsi.png)](https://www.gsi.upm.es)
|
[![SOIL](logo_gsi.png)](https://www.gsi.upm.es)
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ network_agents:
|
|||||||
- agent_type: CounterModel
|
- agent_type: CounterModel
|
||||||
weight: 1
|
weight: 1
|
||||||
state:
|
state:
|
||||||
id: 0
|
state_id: 0
|
||||||
- agent_type: AggregatedCounter
|
- agent_type: AggregatedCounter
|
||||||
weight: 0.2
|
weight: 0.2
|
||||||
environment_agents: []
|
environment_agents: []
|
||||||
|
@ -13,4 +13,4 @@ network_agents:
|
|||||||
- agent_type: CounterModel
|
- agent_type: CounterModel
|
||||||
weight: 1
|
weight: 1
|
||||||
state:
|
state:
|
||||||
id: 0
|
state_id: 0
|
||||||
|
@ -23,7 +23,6 @@ def network_portrayal(env):
|
|||||||
}
|
}
|
||||||
for (agent_id) in env.G.nodes
|
for (agent_id) in env.G.nodes
|
||||||
]
|
]
|
||||||
# import pdb;pdb.set_trace()
|
|
||||||
|
|
||||||
portrayal["edges"] = [
|
portrayal["edges"] = [
|
||||||
{"id": edge_id, "source": source, "target": target, "color": "#000000"}
|
{"id": edge_id, "source": source, "target": target, "color": "#000000"}
|
||||||
@ -65,7 +64,7 @@ model_params = {
|
|||||||
"N": UserSettableParameter(
|
"N": UserSettableParameter(
|
||||||
"slider",
|
"slider",
|
||||||
"N",
|
"N",
|
||||||
1,
|
5,
|
||||||
1,
|
1,
|
||||||
10,
|
10,
|
||||||
1,
|
1,
|
||||||
|
@ -34,9 +34,7 @@ class MoneyAgent(MesaAgent):
|
|||||||
self.pos,
|
self.pos,
|
||||||
moore=True,
|
moore=True,
|
||||||
include_center=False)
|
include_center=False)
|
||||||
print(self.pos, possible_steps)
|
|
||||||
new_position = self.random.choice(possible_steps)
|
new_position = self.random.choice(possible_steps)
|
||||||
print(self.pos, new_position)
|
|
||||||
self.model.grid.move_agent(self, new_position)
|
self.model.grid.move_agent(self, new_position)
|
||||||
|
|
||||||
def give_money(self):
|
def give_money(self):
|
||||||
@ -74,21 +72,13 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
|
|||||||
class MoneyEnv(Environment):
|
class MoneyEnv(Environment):
|
||||||
"""A model with some number of agents."""
|
"""A model with some number of agents."""
|
||||||
def __init__(self, N, width, height, *args, network_params, **kwargs):
|
def __init__(self, N, width, height, *args, network_params, **kwargs):
|
||||||
self.initialized = True
|
|
||||||
# import pdb;pdb.set_trace()
|
|
||||||
|
|
||||||
network_params['n'] = N
|
network_params['n'] = N
|
||||||
super().__init__(*args, network_params=network_params, **kwargs)
|
super().__init__(*args, network_params=network_params, **kwargs)
|
||||||
self.grid = MultiGrid(width, height, False)
|
self.grid = MultiGrid(width, height, False)
|
||||||
# self.schedule = RandomActivation(self)
|
|
||||||
self.running = True
|
|
||||||
|
|
||||||
# Create agents
|
# Create agents
|
||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
self.schedule.add(agent)
|
|
||||||
# a = MoneyAgent(i, self)
|
|
||||||
# self.schedule.add(a)
|
|
||||||
# Add the agent to a random grid cell
|
|
||||||
x = self.random.randrange(self.grid.width)
|
x = self.random.randrange(self.grid.width)
|
||||||
y = self.random.randrange(self.grid.height)
|
y = self.random.randrange(self.grid.height)
|
||||||
self.grid.place_agent(agent, (x, y))
|
self.grid.place_agent(agent, (x, y))
|
||||||
@ -97,10 +87,6 @@ class MoneyEnv(Environment):
|
|||||||
model_reporters={"Gini": compute_gini},
|
model_reporters={"Gini": compute_gini},
|
||||||
agent_reporters={"Wealth": "wealth"})
|
agent_reporters={"Wealth": "wealth"})
|
||||||
|
|
||||||
def step(self):
|
|
||||||
super().step()
|
|
||||||
self.datacollector.collect(self)
|
|
||||||
self.schedule.step()
|
|
||||||
|
|
||||||
def graph_generator(n=5):
|
def graph_generator(n=5):
|
||||||
G = nx.Graph()
|
G = nx.Graph()
|
||||||
|
@ -68,12 +68,12 @@ network_agents:
|
|||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: neutral
|
state_id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: neutral
|
state_id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
network_params:
|
network_params:
|
||||||
generator: barabasi_albert_graph
|
generator: barabasi_albert_graph
|
||||||
@ -95,7 +95,7 @@ network_agents:
|
|||||||
- agent_type: HerdViewer
|
- agent_type: HerdViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: neutral
|
state_id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
@ -121,7 +121,7 @@ network_agents:
|
|||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
has_tv: true
|
has_tv: true
|
||||||
id: neutral
|
state_id: neutral
|
||||||
weight: 1
|
weight: 1
|
||||||
- agent_type: WiseViewer
|
- agent_type: WiseViewer
|
||||||
state:
|
state:
|
||||||
|
@ -34,8 +34,6 @@ class HerdViewer(DumbViewer):
|
|||||||
A viewer whose probability of infection depends on the state of its neighbors.
|
A viewer whose probability of infection depends on the state of its neighbors.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
level = logging.DEBUG
|
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||||
total = self.count_neighboring_agents()
|
total = self.count_neighboring_agents()
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
---
|
---
|
||||||
load_module: rabbit_agents
|
load_module: rabbit_agents
|
||||||
name: rabbits_example
|
name: rabbits_example
|
||||||
max_time: 200
|
max_time: 150
|
||||||
interval: 1
|
interval: 1
|
||||||
seed: MySeed
|
seed: MySeed
|
||||||
agent_type: RabbitModel
|
agent_type: RabbitModel
|
||||||
|
@ -16,7 +16,7 @@ template:
|
|||||||
- agent_type: CounterModel
|
- agent_type: CounterModel
|
||||||
weight: "{{ x1 }}"
|
weight: "{{ x1 }}"
|
||||||
state:
|
state:
|
||||||
id: 0
|
state_id: 0
|
||||||
- agent_type: AggregatedCounter
|
- agent_type: AggregatedCounter
|
||||||
weight: "{{ 1 - x1 }}"
|
weight: "{{ 1 - x1 }}"
|
||||||
environment_params:
|
environment_params:
|
||||||
|
@ -6,3 +6,4 @@ pandas>=0.23
|
|||||||
SALib>=1.3
|
SALib>=1.3
|
||||||
Jinja2
|
Jinja2
|
||||||
Mesa>=0.8
|
Mesa>=0.8
|
||||||
|
tsih>=0.1.5
|
||||||
|
@ -1 +1 @@
|
|||||||
0.15.2
|
0.20.0
|
@ -15,7 +15,6 @@ from .agents import *
|
|||||||
from . import agents
|
from . import agents
|
||||||
from .simulation import *
|
from .simulation import *
|
||||||
from .environment import Environment
|
from .environment import Environment
|
||||||
from .history import History
|
|
||||||
from . import serialization
|
from . import serialization
|
||||||
from . import analysis
|
from . import analysis
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
|
@ -6,7 +6,9 @@ from itertools import islice
|
|||||||
import json
|
import json
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
from .. import serialization, history, utils, time
|
from .. import serialization, utils, time
|
||||||
|
|
||||||
|
from tsih import Key
|
||||||
|
|
||||||
from mesa import Agent
|
from mesa import Agent
|
||||||
|
|
||||||
@ -16,6 +18,7 @@ def as_node(agent):
|
|||||||
return agent.id
|
return agent.id
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
IGNORED_FIELDS = ('model', 'logger')
|
||||||
|
|
||||||
class BaseAgent(Agent):
|
class BaseAgent(Agent):
|
||||||
"""
|
"""
|
||||||
@ -27,23 +30,20 @@ class BaseAgent(Agent):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
unique_id,
|
unique_id,
|
||||||
model,
|
model,
|
||||||
state=None,
|
|
||||||
name=None,
|
name=None,
|
||||||
interval=None):
|
interval=None):
|
||||||
# Check for REQUIRED arguments
|
# Check for REQUIRED arguments
|
||||||
# Initialize agent parameters
|
# Initialize agent parameters
|
||||||
if isinstance(unique_id, Agent):
|
if isinstance(unique_id, Agent):
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
self._saved = set()
|
||||||
super().__init__(unique_id=unique_id, model=model)
|
super().__init__(unique_id=unique_id, model=model)
|
||||||
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
|
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||||
|
|
||||||
self._neighbors = None
|
self._neighbors = None
|
||||||
self.alive = True
|
self.alive = True
|
||||||
real_state = deepcopy(self.defaults)
|
|
||||||
real_state.update(state or {})
|
|
||||||
self.state = real_state
|
|
||||||
|
|
||||||
self.interval = interval or self.get('interval', getattr(self.model, 'interval', 1))
|
self.interval = interval or self.get('interval', 1)
|
||||||
self.logger = logging.getLogger(self.model.name).getChild(self.name)
|
self.logger = logging.getLogger(self.model.name).getChild(self.name)
|
||||||
|
|
||||||
if hasattr(self, 'level'):
|
if hasattr(self, 'level'):
|
||||||
@ -75,7 +75,6 @@ class BaseAgent(Agent):
|
|||||||
|
|
||||||
@state.setter
|
@state.setter
|
||||||
def state(self, value):
|
def state(self, value):
|
||||||
self._state = {}
|
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
self[k] = v
|
self[k] = v
|
||||||
|
|
||||||
@ -87,28 +86,36 @@ class BaseAgent(Agent):
|
|||||||
def environment_params(self, value):
|
def environment_params(self, value):
|
||||||
self.model.environment_params = value
|
self.model.environment_params = value
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if not key.startswith('_') and key not in IGNORED_FIELDS:
|
||||||
|
try:
|
||||||
|
k = Key(t_step=self.now,
|
||||||
|
dict_id=self.unique_id,
|
||||||
|
key=key)
|
||||||
|
self._saved.add(key)
|
||||||
|
self.model[k] = value
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, tuple):
|
if isinstance(key, tuple):
|
||||||
key, t_step = key
|
key, t_step = key
|
||||||
k = history.Key(key=key, t_step=t_step, agent_id=self.unique_id)
|
k = Key(key=key, t_step=t_step, dict_id=self.unique_id)
|
||||||
return self.model[k]
|
return self.model[k]
|
||||||
return self._state.get(key, None)
|
return getattr(self, key)
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
self._state[key] = None
|
return delattr(self, key)
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
return key in self._state
|
return hasattr(self, key)
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
self._state[key] = value
|
setattr(self, key, value)
|
||||||
k = history.Key(t_step=self.now,
|
|
||||||
agent_id=self.id,
|
|
||||||
key=key)
|
|
||||||
self.model[k] = value
|
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return self._state.items()
|
return ((k, getattr(self, k)) for k in self._saved)
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
return self[key] if key in self else default
|
return self[key] if key in self else default
|
||||||
@ -150,25 +157,6 @@ class BaseAgent(Agent):
|
|||||||
def info(self, *args, **kwargs):
|
def info(self, *args, **kwargs):
|
||||||
return self.log(*args, level=logging.INFO, **kwargs)
|
return self.log(*args, level=logging.INFO, **kwargs)
|
||||||
|
|
||||||
# def __getstate__(self):
|
|
||||||
# '''
|
|
||||||
# Serializing an agent will lose all its running information (you cannot
|
|
||||||
# serialize an iterator), but it keeps the state and link to the environment,
|
|
||||||
# so it can be used for inspection and dumping to a file
|
|
||||||
# '''
|
|
||||||
# state = {}
|
|
||||||
# state['id'] = self.id
|
|
||||||
# state['environment'] = self.model
|
|
||||||
# state['_state'] = self._state
|
|
||||||
# return state
|
|
||||||
|
|
||||||
# def __setstate__(self, state):
|
|
||||||
# '''
|
|
||||||
# Get back a serialized agent and try to re-compose it
|
|
||||||
# '''
|
|
||||||
# self.state_id = state['id']
|
|
||||||
# self._state = state['_state']
|
|
||||||
# self.model = state['environment']
|
|
||||||
|
|
||||||
class NetworkAgent(BaseAgent):
|
class NetworkAgent(BaseAgent):
|
||||||
|
|
||||||
@ -303,15 +291,15 @@ class MetaFSM(type):
|
|||||||
class FSM(NetworkAgent, metaclass=MetaFSM):
|
class FSM(NetworkAgent, metaclass=MetaFSM):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(FSM, self).__init__(*args, **kwargs)
|
super(FSM, self).__init__(*args, **kwargs)
|
||||||
if 'id' not in self.state:
|
if not hasattr(self, 'state_id'):
|
||||||
if not self.default_state:
|
if not self.default_state:
|
||||||
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
||||||
self['id'] = self.default_state.id
|
self.state_id = self.default_state.id
|
||||||
|
|
||||||
self.set_state(self.state['id'])
|
self.set_state(self.state_id)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.debug(f'Agent {self.unique_id} @ state {self["id"]}')
|
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
||||||
interval = super().step()
|
interval = super().step()
|
||||||
if 'id' not in self.state:
|
if 'id' not in self.state:
|
||||||
# if 'id' in self.state:
|
# if 'id' in self.state:
|
||||||
@ -320,14 +308,14 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
|
|||||||
self.set_state(self.default_state.id)
|
self.set_state(self.default_state.id)
|
||||||
else:
|
else:
|
||||||
raise Exception('{} has no valid state id or default state'.format(self))
|
raise Exception('{} has no valid state id or default state'.format(self))
|
||||||
return self.states[self.state['id']](self) or interval
|
return self.states[self.state_id](self) or interval
|
||||||
|
|
||||||
def set_state(self, state):
|
def set_state(self, state):
|
||||||
if hasattr(state, 'id'):
|
if hasattr(state, 'id'):
|
||||||
state = state.id
|
state = state.id
|
||||||
if state not in self.states:
|
if state not in self.states:
|
||||||
raise ValueError('{} is not a valid state'.format(state))
|
raise ValueError('{} is not a valid state'.format(state))
|
||||||
self.state['id'] = state
|
self.state_id = state
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
@ -541,7 +529,7 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
|
|||||||
f = filter(lambda x: x not in ignore, f)
|
f = filter(lambda x: x not in ignore, f)
|
||||||
|
|
||||||
if state_id is not None:
|
if state_id is not None:
|
||||||
f = filter(lambda agent: agent.state.get('id', None) in state_id, f)
|
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
||||||
|
|
||||||
if agent_type is not None:
|
if agent_type is not None:
|
||||||
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
||||||
|
@ -4,7 +4,8 @@ import glob
|
|||||||
import yaml
|
import yaml
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from . import serialization, history
|
from . import serialization
|
||||||
|
from tsih import History
|
||||||
|
|
||||||
|
|
||||||
def read_data(*args, group=False, **kwargs):
|
def read_data(*args, group=False, **kwargs):
|
||||||
@ -34,7 +35,7 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def read_sql(db, *args, **kwargs):
|
def read_sql(db, *args, **kwargs):
|
||||||
h = history.History(db_path=db, backup=False, readonly=True)
|
h = History(db_path=db, backup=False, readonly=True)
|
||||||
df = h.read_sql(*args, **kwargs)
|
df = h.read_sql(*args, **kwargs)
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
@ -12,9 +12,11 @@ from networkx.readwrite import json_graph
|
|||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
|
from tsih import History, Record, Key, NoHistory
|
||||||
|
|
||||||
from mesa import Model
|
from mesa import Model
|
||||||
|
|
||||||
from . import serialization, agents, analysis, history, utils, time
|
from . import serialization, agents, analysis, utils, time
|
||||||
|
|
||||||
# These properties will be copied when pickling/unpickling the environment
|
# These properties will be copied when pickling/unpickling the environment
|
||||||
_CONFIG_PROPS = [ 'name',
|
_CONFIG_PROPS = [ 'name',
|
||||||
@ -46,6 +48,7 @@ class Environment(Model):
|
|||||||
schedule=None,
|
schedule=None,
|
||||||
initial_time=0,
|
initial_time=0,
|
||||||
environment_params=None,
|
environment_params=None,
|
||||||
|
history=True,
|
||||||
dir_path=None,
|
dir_path=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
@ -78,7 +81,11 @@ class Environment(Model):
|
|||||||
|
|
||||||
self._env_agents = {}
|
self._env_agents = {}
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self._history = history.History(name=self.name,
|
if history:
|
||||||
|
history = History
|
||||||
|
else:
|
||||||
|
history = NoHistory
|
||||||
|
self._history = history(name=self.name,
|
||||||
backup=True)
|
backup=True)
|
||||||
self['SEED'] = seed
|
self['SEED'] = seed
|
||||||
|
|
||||||
@ -162,8 +169,15 @@ class Environment(Model):
|
|||||||
if agent_type:
|
if agent_type:
|
||||||
state = defstate
|
state = defstate
|
||||||
a = agent_type(model=self,
|
a = agent_type(model=self,
|
||||||
unique_id=agent_id,
|
unique_id=agent_id)
|
||||||
state=state)
|
|
||||||
|
for (k, v) in getattr(a, 'defaults', {}).items():
|
||||||
|
if not hasattr(a, k) or getattr(a, k) is None:
|
||||||
|
setattr(a, k, v)
|
||||||
|
|
||||||
|
for (k, v) in state.items():
|
||||||
|
setattr(a, k, v)
|
||||||
|
|
||||||
node['agent'] = a
|
node['agent'] = a
|
||||||
self.schedule.add(a)
|
self.schedule.add(a)
|
||||||
return a
|
return a
|
||||||
@ -183,6 +197,11 @@ class Environment(Model):
|
|||||||
start = start or self.now
|
start = start or self.now
|
||||||
return self.G.add_edge(agent1, agent2, **attrs)
|
return self.G.add_edge(agent1, agent2, **attrs)
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
super().step()
|
||||||
|
self.datacollector.collect(self)
|
||||||
|
self.schedule.step()
|
||||||
|
|
||||||
def run(self, until, *args, **kwargs):
|
def run(self, until, *args, **kwargs):
|
||||||
self._save_state()
|
self._save_state()
|
||||||
|
|
||||||
@ -204,12 +223,12 @@ class Environment(Model):
|
|||||||
|
|
||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
if isinstance(key, tuple):
|
if isinstance(key, tuple):
|
||||||
k = history.Key(*key)
|
k = Key(*key)
|
||||||
self._history.save_record(*k,
|
self._history.save_record(*k,
|
||||||
value=value)
|
value=value)
|
||||||
return
|
return
|
||||||
self.environment_params[key] = value
|
self.environment_params[key] = value
|
||||||
self._history.save_record(agent_id='env',
|
self._history.save_record(dict_id='env',
|
||||||
t_step=self.now,
|
t_step=self.now,
|
||||||
key=key,
|
key=key,
|
||||||
value=value)
|
value=value)
|
||||||
@ -274,13 +293,13 @@ class Environment(Model):
|
|||||||
if now is None:
|
if now is None:
|
||||||
now = self.now
|
now = self.now
|
||||||
for k, v in self.environment_params.items():
|
for k, v in self.environment_params.items():
|
||||||
yield history.Record(agent_id='env',
|
yield Record(dict_id='env',
|
||||||
t_step=now,
|
t_step=now,
|
||||||
key=k,
|
key=k,
|
||||||
value=v)
|
value=v)
|
||||||
for agent in self.agents:
|
for agent in self.agents:
|
||||||
for k, v in agent.state.items():
|
for k, v in agent.state.items():
|
||||||
yield history.Record(agent_id=agent.id,
|
yield Record(dict_id=agent.id,
|
||||||
t_step=now,
|
t_step=now,
|
||||||
key=k,
|
key=k,
|
||||||
value=v)
|
value=v)
|
||||||
|
388
soil/history.py
388
soil/history.py
@ -1,388 +0,0 @@
|
|||||||
import time
|
|
||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
import sqlite3
|
|
||||||
import copy
|
|
||||||
import logging
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from collections import UserDict, namedtuple
|
|
||||||
|
|
||||||
from . import serialization
|
|
||||||
from .utils import open_or_reuse, unflatten_dict
|
|
||||||
|
|
||||||
|
|
||||||
class History:
|
|
||||||
"""
|
|
||||||
Store and retrieve values from a sqlite database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, name=None, db_path=None, backup=False, readonly=False):
|
|
||||||
if readonly and (not os.path.exists(db_path)):
|
|
||||||
raise Exception('The DB file does not exist. Cannot open in read-only mode')
|
|
||||||
|
|
||||||
self._db = None
|
|
||||||
self._temp = db_path is None
|
|
||||||
self._stats_columns = None
|
|
||||||
self.readonly = readonly
|
|
||||||
|
|
||||||
if self._temp:
|
|
||||||
if not name:
|
|
||||||
name = time.time()
|
|
||||||
# The file will be deleted as soon as it's closed
|
|
||||||
# Normally, that will be on destruction
|
|
||||||
db_path = tempfile.NamedTemporaryFile(suffix='{}.sqlite'.format(name)).name
|
|
||||||
|
|
||||||
|
|
||||||
if backup and os.path.exists(db_path):
|
|
||||||
newname = db_path + '.backup{}.sqlite'.format(time.time())
|
|
||||||
os.rename(db_path, newname)
|
|
||||||
|
|
||||||
self.db_path = db_path
|
|
||||||
|
|
||||||
self.db = db_path
|
|
||||||
self._dtypes = {}
|
|
||||||
self._tups = []
|
|
||||||
|
|
||||||
|
|
||||||
if self.readonly:
|
|
||||||
return
|
|
||||||
|
|
||||||
with self.db:
|
|
||||||
logger.debug('Creating database {}'.format(self.db_path))
|
|
||||||
self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step real, key text, value text)''')
|
|
||||||
self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''')
|
|
||||||
self.db.execute('''CREATE TABLE IF NOT EXISTS stats (trial_id text)''')
|
|
||||||
self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''')
|
|
||||||
|
|
||||||
@property
|
|
||||||
def db(self):
|
|
||||||
try:
|
|
||||||
self._db.cursor()
|
|
||||||
except (sqlite3.ProgrammingError, AttributeError):
|
|
||||||
self.db = None # Reset the database
|
|
||||||
return self._db
|
|
||||||
|
|
||||||
@db.setter
|
|
||||||
def db(self, db_path=None):
|
|
||||||
self._close()
|
|
||||||
db_path = db_path or self.db_path
|
|
||||||
if isinstance(db_path, str):
|
|
||||||
logger.debug('Connecting to database {}'.format(db_path))
|
|
||||||
self._db = sqlite3.connect(db_path)
|
|
||||||
self._db.row_factory = sqlite3.Row
|
|
||||||
else:
|
|
||||||
self._db = db_path
|
|
||||||
|
|
||||||
def _close(self):
|
|
||||||
if self._db is None:
|
|
||||||
return
|
|
||||||
self.flush_cache()
|
|
||||||
self._db.close()
|
|
||||||
self._db = None
|
|
||||||
|
|
||||||
def save_stats(self, stat):
|
|
||||||
if self.readonly:
|
|
||||||
print('DB in readonly mode')
|
|
||||||
return
|
|
||||||
if not stat:
|
|
||||||
return
|
|
||||||
with self.db:
|
|
||||||
if not self._stats_columns:
|
|
||||||
self._stats_columns = list(c['name'] for c in self.db.execute('PRAGMA table_info(stats)'))
|
|
||||||
|
|
||||||
for column, value in stat.items():
|
|
||||||
if column in self._stats_columns:
|
|
||||||
continue
|
|
||||||
dtype = 'text'
|
|
||||||
if not isinstance(value, str):
|
|
||||||
try:
|
|
||||||
float(value)
|
|
||||||
dtype = 'real'
|
|
||||||
int(value)
|
|
||||||
dtype = 'int'
|
|
||||||
except (ValueError, OverflowError):
|
|
||||||
pass
|
|
||||||
self.db.execute('ALTER TABLE stats ADD "{}" "{}"'.format(column, dtype))
|
|
||||||
self._stats_columns.append(column)
|
|
||||||
|
|
||||||
columns = ", ".join(map(lambda x: '"{}"'.format(x), stat.keys()))
|
|
||||||
values = ", ".join(['"{0}"'.format(col) for col in stat.values()])
|
|
||||||
query = "INSERT INTO stats ({columns}) VALUES ({values})".format(
|
|
||||||
columns=columns,
|
|
||||||
values=values
|
|
||||||
)
|
|
||||||
self.db.execute(query)
|
|
||||||
|
|
||||||
def get_stats(self, unflatten=True):
|
|
||||||
rows = self.db.execute("select * from stats").fetchall()
|
|
||||||
res = []
|
|
||||||
for row in rows:
|
|
||||||
d = {}
|
|
||||||
for k in row.keys():
|
|
||||||
if row[k] is None:
|
|
||||||
continue
|
|
||||||
d[k] = row[k]
|
|
||||||
if unflatten:
|
|
||||||
d = unflatten_dict(d)
|
|
||||||
res.append(d)
|
|
||||||
return res
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtypes(self):
|
|
||||||
self._read_types()
|
|
||||||
return {k:v[0] for k, v in self._dtypes.items()}
|
|
||||||
|
|
||||||
def save_tuples(self, tuples):
|
|
||||||
'''
|
|
||||||
Save a series of tuples, converting them to records if necessary
|
|
||||||
'''
|
|
||||||
self.save_records(Record(*tup) for tup in tuples)
|
|
||||||
|
|
||||||
def save_records(self, records):
|
|
||||||
'''
|
|
||||||
Save a collection of records
|
|
||||||
'''
|
|
||||||
for record in records:
|
|
||||||
if not isinstance(record, Record):
|
|
||||||
record = Record(*record)
|
|
||||||
self.save_record(*record)
|
|
||||||
|
|
||||||
def save_record(self, agent_id, t_step, key, value):
|
|
||||||
'''
|
|
||||||
Save a collection of records to the database.
|
|
||||||
Database writes are cached.
|
|
||||||
'''
|
|
||||||
if self.readonly:
|
|
||||||
raise Exception('DB in readonly mode')
|
|
||||||
if key not in self._dtypes:
|
|
||||||
self._read_types()
|
|
||||||
if key not in self._dtypes:
|
|
||||||
name = serialization.name(value)
|
|
||||||
serializer = serialization.serializer(name)
|
|
||||||
deserializer = serialization.deserializer(name)
|
|
||||||
self._dtypes[key] = (name, serializer, deserializer)
|
|
||||||
with self.db:
|
|
||||||
self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name))
|
|
||||||
value = self._dtypes[key][1](value)
|
|
||||||
|
|
||||||
self._tups.append(Record(agent_id=agent_id,
|
|
||||||
t_step=t_step,
|
|
||||||
key=key,
|
|
||||||
value=value))
|
|
||||||
if len(self._tups) > 100:
|
|
||||||
self.flush_cache()
|
|
||||||
|
|
||||||
def flush_cache(self):
|
|
||||||
'''
|
|
||||||
Use a cache to save state changes to avoid opening a session for every change.
|
|
||||||
The cache will be flushed at the end of the simulation, and when history is accessed.
|
|
||||||
'''
|
|
||||||
if self.readonly:
|
|
||||||
raise Exception('DB in readonly mode')
|
|
||||||
logger.debug('Flushing cache {}'.format(self.db_path))
|
|
||||||
with self.db:
|
|
||||||
self.db.executemany("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", self._tups)
|
|
||||||
# (rec.agent_id, rec.t_step, rec.key, rec.value))
|
|
||||||
self._tups.clear()
|
|
||||||
|
|
||||||
def to_tuples(self):
|
|
||||||
self.flush_cache()
|
|
||||||
with self.db:
|
|
||||||
res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall()
|
|
||||||
for r in res:
|
|
||||||
agent_id, t_step, key, value = r
|
|
||||||
if key not in self._dtypes:
|
|
||||||
self._read_types()
|
|
||||||
if key not in self._dtypes:
|
|
||||||
raise ValueError("Unknown datatype for {} and {}".format(key, value))
|
|
||||||
value = self._dtypes[key][2](value)
|
|
||||||
yield agent_id, t_step, key, value
|
|
||||||
|
|
||||||
def _read_types(self):
|
|
||||||
with self.db:
|
|
||||||
res = self.db.execute("select key, value_type from value_types ").fetchall()
|
|
||||||
for k, v in res:
|
|
||||||
serializer = serialization.serializer(v)
|
|
||||||
deserializer = serialization.deserializer(v)
|
|
||||||
self._dtypes[k] = (v, serializer, deserializer)
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
# raise NotImplementedError()
|
|
||||||
self.flush_cache()
|
|
||||||
key = Key(*key)
|
|
||||||
agent_ids = [key.agent_id] if key.agent_id is not None else []
|
|
||||||
t_steps = [key.t_step] if key.t_step is not None else []
|
|
||||||
keys = [key.key] if key.key is not None else []
|
|
||||||
|
|
||||||
df = self.read_sql(agent_ids=agent_ids,
|
|
||||||
t_steps=t_steps,
|
|
||||||
keys=keys)
|
|
||||||
r = Records(df, filter=key, dtypes=self._dtypes)
|
|
||||||
if r.resolved:
|
|
||||||
return r.value()
|
|
||||||
return r
|
|
||||||
|
|
||||||
def read_sql(self, keys=None, agent_ids=None, not_agent_ids=None, t_steps=None, convert_types=False, limit=-1):
|
|
||||||
|
|
||||||
self._read_types()
|
|
||||||
|
|
||||||
def escape_and_join(v):
|
|
||||||
if v is None:
|
|
||||||
return
|
|
||||||
return ",".join(map(lambda x: "\'{}\'".format(x), v))
|
|
||||||
|
|
||||||
filters = [("key in ({})".format(escape_and_join(keys)), keys),
|
|
||||||
("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids),
|
|
||||||
("agent_id not in ({})".format(escape_and_join(not_agent_ids)), not_agent_ids)
|
|
||||||
]
|
|
||||||
filters = list(k[0] for k in filters if k[1])
|
|
||||||
|
|
||||||
last_df = None
|
|
||||||
if t_steps:
|
|
||||||
# Convert negative indices into positive
|
|
||||||
if any(x<0 for x in t_steps):
|
|
||||||
max_t = int(self.db.execute("select max(t_step) from history").fetchone()[0])
|
|
||||||
t_steps = [t if t>0 else max_t+1+t for t in t_steps]
|
|
||||||
|
|
||||||
# We will be doing ffill interpolation, so we need to look for
|
|
||||||
# the last value before the minimum step in the query
|
|
||||||
min_step = min(t_steps)
|
|
||||||
last_filters = ['t_step < {}'.format(min_step),]
|
|
||||||
last_filters = last_filters + filters
|
|
||||||
condition = ' and '.join(last_filters)
|
|
||||||
|
|
||||||
last_query = '''
|
|
||||||
select h1.*
|
|
||||||
from history h1
|
|
||||||
inner join (
|
|
||||||
select agent_id, key, max(t_step) as t_step
|
|
||||||
from history
|
|
||||||
where {condition}
|
|
||||||
group by agent_id, key
|
|
||||||
) h2
|
|
||||||
on h1.agent_id = h2.agent_id and
|
|
||||||
h1.key = h2.key and
|
|
||||||
h1.t_step = h2.t_step
|
|
||||||
'''.format(condition=condition)
|
|
||||||
last_df = pd.read_sql_query(last_query, self.db)
|
|
||||||
|
|
||||||
filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps)))
|
|
||||||
|
|
||||||
condition = ''
|
|
||||||
if filters:
|
|
||||||
condition = 'where {} '.format(' and '.join(filters))
|
|
||||||
query = 'select * from history {} limit {}'.format(condition, limit)
|
|
||||||
df = pd.read_sql_query(query, self.db)
|
|
||||||
if last_df is not None:
|
|
||||||
df = pd.concat([df, last_df])
|
|
||||||
|
|
||||||
df_p = df.pivot_table(values='value', index=['t_step'],
|
|
||||||
columns=['key', 'agent_id'],
|
|
||||||
aggfunc='first')
|
|
||||||
|
|
||||||
for k, v in self._dtypes.items():
|
|
||||||
if k in df_p:
|
|
||||||
dtype, _, deserial = v
|
|
||||||
try:
|
|
||||||
df_p[k] = df_p[k].fillna(method='ffill').astype(dtype)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
# Avoid forward-filling unknown/incompatible types
|
|
||||||
continue
|
|
||||||
if t_steps:
|
|
||||||
df_p = df_p.reindex(t_steps, method='ffill')
|
|
||||||
return df_p.ffill()
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
state = dict(**self.__dict__)
|
|
||||||
del state['_db']
|
|
||||||
del state['_dtypes']
|
|
||||||
return state
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
self.__dict__ = state
|
|
||||||
self._dtypes = {}
|
|
||||||
self._db = None
|
|
||||||
|
|
||||||
def dump(self, f):
|
|
||||||
self._close()
|
|
||||||
for line in open_or_reuse(self.db_path, 'rb'):
|
|
||||||
f.write(line)
|
|
||||||
|
|
||||||
|
|
||||||
class Records():
|
|
||||||
|
|
||||||
def __init__(self, df, filter=None, dtypes=None):
|
|
||||||
if not filter:
|
|
||||||
filter = Key(agent_id=None,
|
|
||||||
t_step=None,
|
|
||||||
key=None)
|
|
||||||
self._df = df
|
|
||||||
self._filter = filter
|
|
||||||
self.dtypes = dtypes or {}
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def mask(self, tup):
|
|
||||||
res = ()
|
|
||||||
for i, k in zip(tup[:-1], self._filter):
|
|
||||||
if k is None:
|
|
||||||
res = res + (i,)
|
|
||||||
res = res + (tup[-1],)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def filter(self, newKey):
|
|
||||||
f = list(self._filter)
|
|
||||||
for ix, i in enumerate(f):
|
|
||||||
if i is None:
|
|
||||||
f[ix] = newKey
|
|
||||||
self._filter = Key(*f)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def resolved(self):
|
|
||||||
return sum(1 for i in self._filter if i is not None) == 3
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
for column, series in self._df.iteritems():
|
|
||||||
key, agent_id = column
|
|
||||||
for t_step, value in series.iteritems():
|
|
||||||
r = Record(t_step=t_step,
|
|
||||||
agent_id=agent_id,
|
|
||||||
key=key,
|
|
||||||
value=value)
|
|
||||||
yield self.mask(r)
|
|
||||||
|
|
||||||
def value(self):
|
|
||||||
if self.resolved:
|
|
||||||
f = self._filter
|
|
||||||
try:
|
|
||||||
i = self._df[f.key][str(f.agent_id)]
|
|
||||||
ix = i.index.get_loc(f.t_step, method='ffill')
|
|
||||||
return i.iloc[ix]
|
|
||||||
except KeyError as ex:
|
|
||||||
return self.dtypes[f.key][2]()
|
|
||||||
return list(self)
|
|
||||||
|
|
||||||
def df(self):
|
|
||||||
return self._df
|
|
||||||
|
|
||||||
def __getitem__(self, k):
|
|
||||||
n = copy.copy(self)
|
|
||||||
n.filter(k)
|
|
||||||
if n.resolved:
|
|
||||||
return n.value()
|
|
||||||
return n
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._df)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
if self.resolved:
|
|
||||||
return str(self.value())
|
|
||||||
return '<Records for [{}]>'.format(self._filter)
|
|
||||||
|
|
||||||
Key = namedtuple('Key', ['agent_id', 't_step', 'key'])
|
|
||||||
Record = namedtuple('Record', 'agent_id t_step key value')
|
|
||||||
|
|
||||||
Stat = namedtuple('Stat', 'trial_id')
|
|
@ -9,6 +9,7 @@ import networkx as nx
|
|||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from tsih import History
|
||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
@ -17,7 +18,6 @@ from .environment import Environment
|
|||||||
from .utils import logger
|
from .utils import logger
|
||||||
from .exporters import default
|
from .exporters import default
|
||||||
from .stats import defaultStats
|
from .stats import defaultStats
|
||||||
from .history import History
|
|
||||||
|
|
||||||
|
|
||||||
#TODO: change documentation for simulation
|
#TODO: change documentation for simulation
|
||||||
@ -159,7 +159,7 @@ class Simulation:
|
|||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
def run_gen(self, *args, parallel=False, dry_run=False,
|
def run_gen(self, *args, parallel=False, dry_run=False,
|
||||||
exporters=[default, ], stats=[defaultStats], outdir=None, exporter_params={},
|
exporters=[default, ], stats=[], outdir=None, exporter_params={},
|
||||||
stats_params={}, log_level=None,
|
stats_params={}, log_level=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
'''Run the simulation and yield the resulting environments.'''
|
'''Run the simulation and yield the resulting environments.'''
|
||||||
|
@ -97,7 +97,7 @@ class defaultStats(Stats):
|
|||||||
return {
|
return {
|
||||||
'network ': {
|
'network ': {
|
||||||
'n_nodes': env.G.number_of_nodes(),
|
'n_nodes': env.G.number_of_nodes(),
|
||||||
'n_edges': env.G.number_of_nodes(),
|
'n_edges': env.G.number_of_edges(),
|
||||||
},
|
},
|
||||||
'agents': {
|
'agents': {
|
||||||
'model_count': dict(c),
|
'model_count': dict(c),
|
||||||
|
@ -26,6 +26,8 @@ def timer(name='task', pre="", function=logger.info, to_object=None):
|
|||||||
to_object.end = end
|
to_object.end = end
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def safe_open(path, mode='r', backup=True, **kwargs):
|
def safe_open(path, mode='r', backup=True, **kwargs):
|
||||||
outdir = os.path.dirname(path)
|
outdir = os.path.dirname(path)
|
||||||
if outdir and not os.path.exists(outdir):
|
if outdir and not os.path.exists(outdir):
|
||||||
|
@ -67,13 +67,13 @@ class TestAnalysis(TestCase):
|
|||||||
def test_count(self):
|
def test_count(self):
|
||||||
env = self.env
|
env = self.env
|
||||||
df = analysis.read_sql(env._history.db_path)
|
df = analysis.read_sql(env._history.db_path)
|
||||||
res = analysis.get_count(df, 'SEED', 'id')
|
res = analysis.get_count(df, 'SEED', 'state_id')
|
||||||
assert res['SEED'][self.env['SEED']].iloc[0] == 1
|
assert res['SEED'][self.env['SEED']].iloc[0] == 1
|
||||||
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
|
assert res['SEED'][self.env['SEED']].iloc[-1] == 1
|
||||||
assert res['id']['odd'].iloc[0] == 2
|
assert res['state_id']['odd'].iloc[0] == 2
|
||||||
assert res['id']['even'].iloc[0] == 0
|
assert res['state_id']['even'].iloc[0] == 0
|
||||||
assert res['id']['odd'].iloc[-1] == 1
|
assert res['state_id']['odd'].iloc[-1] == 1
|
||||||
assert res['id']['even'].iloc[-1] == 1
|
assert res['state_id']['even'].iloc[-1] == 1
|
||||||
|
|
||||||
def test_value(self):
|
def test_value(self):
|
||||||
env = self.env
|
env = self.env
|
||||||
|
@ -9,7 +9,7 @@ from functools import partial
|
|||||||
|
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from soil import (simulation, Environment, agents, serialization,
|
from soil import (simulation, Environment, agents, serialization,
|
||||||
history, utils)
|
utils)
|
||||||
from soil.time import Delta
|
from soil.time import Delta
|
||||||
|
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ class CustomAgent(agents.FSM):
|
|||||||
@agents.default_state
|
@agents.default_state
|
||||||
@agents.state
|
@agents.state
|
||||||
def normal(self):
|
def normal(self):
|
||||||
self.state['neighbors'] = self.count_agents(state_id='normal',
|
self.neighbors = self.count_agents(state_id='normal',
|
||||||
limit_neighbors=True)
|
limit_neighbors=True)
|
||||||
@agents.state
|
@agents.state
|
||||||
def unreachable(self):
|
def unreachable(self):
|
||||||
@ -116,7 +116,7 @@ class TestMain(TestCase):
|
|||||||
'network_agents': [{
|
'network_agents': [{
|
||||||
'agent_type': 'AggregatedCounter',
|
'agent_type': 'AggregatedCounter',
|
||||||
'weight': 1,
|
'weight': 1,
|
||||||
'state': {'id': 0}
|
'state': {'state_id': 0}
|
||||||
|
|
||||||
}],
|
}],
|
||||||
'max_time': 10,
|
'max_time': 10,
|
||||||
@ -149,10 +149,9 @@ class TestMain(TestCase):
|
|||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.run_simulation(dry_run=True)[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
assert env.get_agent(0).state['neighbors'] == 1
|
|
||||||
assert env.get_agent(0).state['neighbors'] == 1
|
|
||||||
assert env.get_agent(1).count_agents(state_id='normal') == 2
|
assert env.get_agent(1).count_agents(state_id='normal') == 2
|
||||||
assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
|
assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||||
|
assert env.get_agent(0).neighbors == 1
|
||||||
|
|
||||||
def test_torvalds_example(self):
|
def test_torvalds_example(self):
|
||||||
"""A complete example from a documentation should work."""
|
"""A complete example from a documentation should work."""
|
||||||
@ -317,12 +316,6 @@ class TestMain(TestCase):
|
|||||||
assert recovered['key', 0] == 'test'
|
assert recovered['key', 0] == 'test'
|
||||||
assert recovered['key'] == 'test'
|
assert recovered['key'] == 'test'
|
||||||
|
|
||||||
def test_history(self):
|
|
||||||
'''Test storing in and retrieving from history (sqlite)'''
|
|
||||||
h = history.History()
|
|
||||||
h.save_record(agent_id=0, t_step=0, key="test", value="hello")
|
|
||||||
assert h[0, 0, "test"] == "hello"
|
|
||||||
|
|
||||||
def test_subgraph(self):
|
def test_subgraph(self):
|
||||||
'''An agent should be able to subgraph the global topology'''
|
'''An agent should be able to subgraph the global topology'''
|
||||||
G = nx.Graph()
|
G = nx.Graph()
|
||||||
@ -350,12 +343,13 @@ class TestMain(TestCase):
|
|||||||
'network_params': {},
|
'network_params': {},
|
||||||
'agent_type': 'CounterModel',
|
'agent_type': 'CounterModel',
|
||||||
'max_time': 2,
|
'max_time': 2,
|
||||||
'num_trials': 100,
|
'num_trials': 50,
|
||||||
'environment_params': {}
|
'environment_params': {}
|
||||||
}
|
}
|
||||||
s = simulation.from_config(config)
|
s = simulation.from_config(config)
|
||||||
runs = list(s.run_simulation(dry_run=True))
|
runs = list(s.run_simulation(dry_run=True))
|
||||||
over = list(x.now for x in runs if x.now>2)
|
over = list(x.now for x in runs if x.now>2)
|
||||||
|
assert len(runs) == config['num_trials']
|
||||||
assert len(over) == 0
|
assert len(over) == 0
|
||||||
|
|
||||||
|
|
||||||
@ -372,11 +366,11 @@ class TestMain(TestCase):
|
|||||||
return self.ping
|
return self.ping
|
||||||
|
|
||||||
a = ToggleAgent(unique_id=1, model=Environment())
|
a = ToggleAgent(unique_id=1, model=Environment())
|
||||||
assert a.state["id"] == a.ping.id
|
assert a.state_id == a.ping.id
|
||||||
a.step()
|
a.step()
|
||||||
assert a.state["id"] == a.pong.id
|
assert a.state_id == a.pong.id
|
||||||
a.step()
|
a.step()
|
||||||
assert a.state["id"] == a.ping.id
|
assert a.state_id == a.ping.id
|
||||||
|
|
||||||
def test_fsm_when(self):
|
def test_fsm_when(self):
|
||||||
'''Basic state change'''
|
'''Basic state change'''
|
||||||
|
Loading…
Reference in New Issue
Block a user