mirror of https://github.com/gsi-upm/soil
WIP: mesa compatibility
parent
e860bdb922
commit
5d7e57675a
@ -0,0 +1,21 @@
|
||||
---
|
||||
name: mesa_sim
|
||||
group: tests
|
||||
dir_path: "/tmp"
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
interval: 1
|
||||
seed: '1'
|
||||
network_params:
|
||||
generator: social_wealth.graph_generator
|
||||
n: 5
|
||||
network_agents:
|
||||
- agent_type: social_wealth.SocialMoneyAgent
|
||||
weight: 1
|
||||
environment_class: social_wealth.MoneyEnv
|
||||
environment_params:
|
||||
num_mesa_agents: 5
|
||||
mesa_agent_type: social_wealth.MoneyAgent
|
||||
N: 10
|
||||
width: 50
|
||||
height: 50
|
@ -0,0 +1,106 @@
|
||||
from mesa.visualization.ModularVisualization import ModularServer
|
||||
from soil.visualization import UserSettableParameter
|
||||
from mesa.visualization.modules import ChartModule, NetworkModule, CanvasGrid
|
||||
from social_wealth import MoneyEnv, graph_generator, SocialMoneyAgent
|
||||
|
||||
|
||||
class MyNetwork(NetworkModule):
|
||||
def render(self, model):
|
||||
return self.portrayal_method(model)
|
||||
|
||||
|
||||
def network_portrayal(env):
|
||||
# The model ensures there is 0 or 1 agent per node
|
||||
|
||||
portrayal = dict()
|
||||
portrayal["nodes"] = [
|
||||
{
|
||||
"id": agent_id,
|
||||
"size": env.get_agent(agent_id).wealth,
|
||||
# "color": "#CC0000" if not agents or agents[0].wealth == 0 else "#007959",
|
||||
"color": "#CC0000",
|
||||
"label": f"{agent_id}: {env.get_agent(agent_id).wealth}",
|
||||
}
|
||||
for (agent_id) in env.G.nodes
|
||||
]
|
||||
# import pdb;pdb.set_trace()
|
||||
|
||||
portrayal["edges"] = [
|
||||
{"id": edge_id, "source": source, "target": target, "color": "#000000"}
|
||||
for edge_id, (source, target) in enumerate(env.G.edges)
|
||||
]
|
||||
|
||||
|
||||
return portrayal
|
||||
|
||||
|
||||
def gridPortrayal(agent):
|
||||
"""
|
||||
This function is registered with the visualization server to be called
|
||||
each tick to indicate how to draw the agent in its current state.
|
||||
:param agent: the agent in the simulation
|
||||
:return: the portrayal dictionary
|
||||
"""
|
||||
color = max(10, min(agent.wealth*10, 100))
|
||||
return {
|
||||
"Shape": "rect",
|
||||
"w": 1,
|
||||
"h": 1,
|
||||
"Filled": "true",
|
||||
"Layer": 0,
|
||||
"Label": agent.unique_id,
|
||||
"Text": agent.unique_id,
|
||||
"x": agent.pos[0],
|
||||
"y": agent.pos[1],
|
||||
"Color": f"rgba(31, 10, 255, 0.{color})"
|
||||
}
|
||||
|
||||
|
||||
grid = MyNetwork(network_portrayal, 500, 500, library="sigma")
|
||||
chart = ChartModule(
|
||||
[{"Label": "Gini", "Color": "Black"}], data_collector_name="datacollector"
|
||||
)
|
||||
|
||||
model_params = {
|
||||
"N": UserSettableParameter(
|
||||
"slider",
|
||||
"N",
|
||||
1,
|
||||
1,
|
||||
10,
|
||||
1,
|
||||
description="Choose how many agents to include in the model",
|
||||
),
|
||||
"network_agents": [{"agent_type": SocialMoneyAgent}],
|
||||
"height": UserSettableParameter(
|
||||
"slider",
|
||||
"height",
|
||||
5,
|
||||
5,
|
||||
10,
|
||||
1,
|
||||
description="Grid height",
|
||||
),
|
||||
"width": UserSettableParameter(
|
||||
"slider",
|
||||
"width",
|
||||
5,
|
||||
5,
|
||||
10,
|
||||
1,
|
||||
description="Grid width",
|
||||
),
|
||||
"network_params": {
|
||||
'generator': graph_generator
|
||||
},
|
||||
}
|
||||
|
||||
canvas_element = CanvasGrid(gridPortrayal, model_params["width"].value, model_params["height"].value, 500, 500)
|
||||
|
||||
|
||||
server = ModularServer(
|
||||
MoneyEnv, [grid, chart, canvas_element], "Money Model", model_params
|
||||
)
|
||||
server.port = 8521
|
||||
|
||||
server.launch(open_browser=False)
|
@ -0,0 +1,134 @@
|
||||
'''
|
||||
This is an example that adds soil agents and environment in a normal
|
||||
mesa workflow.
|
||||
'''
|
||||
from mesa import Agent as MesaAgent
|
||||
from mesa.space import MultiGrid
|
||||
# from mesa.time import RandomActivation
|
||||
from mesa.datacollection import DataCollector
|
||||
from mesa.batchrunner import BatchRunner
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from soil import NetworkAgent, Environment
|
||||
|
||||
def compute_gini(model):
|
||||
agent_wealths = [agent.wealth for agent in model.agents]
|
||||
x = sorted(agent_wealths)
|
||||
N = len(list(model.agents))
|
||||
B = sum( xi * (N-i) for i,xi in enumerate(x) ) / (N*sum(x))
|
||||
return (1 + (1/N) - 2*B)
|
||||
|
||||
class MoneyAgent(MesaAgent):
|
||||
"""
|
||||
A MESA agent with fixed initial wealth.
|
||||
It will only share wealth with neighbors based on grid proximity
|
||||
"""
|
||||
|
||||
def __init__(self, unique_id, model):
|
||||
super().__init__(unique_id=unique_id, model=model)
|
||||
self.wealth = 1
|
||||
|
||||
def move(self):
|
||||
possible_steps = self.model.grid.get_neighborhood(
|
||||
self.pos,
|
||||
moore=True,
|
||||
include_center=False)
|
||||
print(self.pos, possible_steps)
|
||||
new_position = self.random.choice(possible_steps)
|
||||
print(self.pos, new_position)
|
||||
self.model.grid.move_agent(self, new_position)
|
||||
|
||||
def give_money(self):
|
||||
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||
if len(cellmates) > 1:
|
||||
other = self.random.choice(cellmates)
|
||||
other.wealth += 1
|
||||
self.wealth -= 1
|
||||
|
||||
def step(self):
|
||||
self.info("Crying wolf", self.pos)
|
||||
self.move()
|
||||
if self.wealth > 0:
|
||||
self.give_money()
|
||||
|
||||
|
||||
class SocialMoneyAgent(NetworkAgent, MoneyAgent):
|
||||
wealth = 1
|
||||
|
||||
def give_money(self):
|
||||
cellmates = set(self.model.grid.get_cell_list_contents([self.pos]))
|
||||
friends = set(self.get_neighboring_agents())
|
||||
self.info("Trying to give money")
|
||||
self.debug("Cellmates: ", cellmates)
|
||||
self.debug("Friends: ", friends)
|
||||
|
||||
nearby_friends = list(cellmates & friends)
|
||||
|
||||
if len(nearby_friends):
|
||||
other = self.random.choice(nearby_friends)
|
||||
other.wealth += 1
|
||||
self.wealth -= 1
|
||||
|
||||
|
||||
class MoneyEnv(Environment):
|
||||
"""A model with some number of agents."""
|
||||
def __init__(self, N, width, height, *args, network_params, **kwargs):
|
||||
self.initialized = True
|
||||
# import pdb;pdb.set_trace()
|
||||
|
||||
network_params['n'] = N
|
||||
super().__init__(*args, network_params=network_params, **kwargs)
|
||||
self.grid = MultiGrid(width, height, False)
|
||||
# self.schedule = RandomActivation(self)
|
||||
self.running = True
|
||||
|
||||
# Create agents
|
||||
for agent in self.agents:
|
||||
self.schedule.add(agent)
|
||||
# a = MoneyAgent(i, self)
|
||||
# self.schedule.add(a)
|
||||
# Add the agent to a random grid cell
|
||||
x = self.random.randrange(self.grid.width)
|
||||
y = self.random.randrange(self.grid.height)
|
||||
self.grid.place_agent(agent, (x, y))
|
||||
|
||||
self.datacollector = DataCollector(
|
||||
model_reporters={"Gini": compute_gini},
|
||||
agent_reporters={"Wealth": "wealth"})
|
||||
|
||||
def step(self):
|
||||
super().step()
|
||||
self.datacollector.collect(self)
|
||||
self.schedule.step()
|
||||
|
||||
def graph_generator(n=5):
|
||||
G = nx.Graph()
|
||||
for ix in range(n):
|
||||
G.add_edge(0, ix)
|
||||
return G
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
G = graph_generator()
|
||||
fixed_params = {"topology": G,
|
||||
"width": 10,
|
||||
"network_agents": [{"agent_type": SocialMoneyAgent,
|
||||
'weight': 1}],
|
||||
"height": 10}
|
||||
|
||||
variable_params = {"N": range(10, 100, 10)}
|
||||
|
||||
batch_run = BatchRunner(MoneyEnv,
|
||||
variable_parameters=variable_params,
|
||||
fixed_parameters=fixed_params,
|
||||
iterations=5,
|
||||
max_steps=100,
|
||||
model_reporters={"Gini": compute_gini})
|
||||
batch_run.run_all()
|
||||
|
||||
run_data = batch_run.get_model_vars_dataframe()
|
||||
run_data.head()
|
||||
print(run_data.Gini)
|
||||
|
@ -0,0 +1,83 @@
|
||||
from mesa import Agent, Model
|
||||
from mesa.space import MultiGrid
|
||||
from mesa.time import RandomActivation
|
||||
from mesa.datacollection import DataCollector
|
||||
from mesa.batchrunner import BatchRunner
|
||||
|
||||
def compute_gini(model):
|
||||
agent_wealths = [agent.wealth for agent in model.schedule.agents]
|
||||
x = sorted(agent_wealths)
|
||||
N = model.num_agents
|
||||
B = sum( xi * (N-i) for i,xi in enumerate(x) ) / (N*sum(x))
|
||||
return (1 + (1/N) - 2*B)
|
||||
|
||||
class MoneyAgent(Agent):
|
||||
""" An agent with fixed initial wealth."""
|
||||
def __init__(self, unique_id, model):
|
||||
super().__init__(unique_id, model)
|
||||
self.wealth = 1
|
||||
|
||||
def move(self):
|
||||
possible_steps = self.model.grid.get_neighborhood(
|
||||
self.pos,
|
||||
moore=True,
|
||||
include_center=False)
|
||||
new_position = self.random.choice(possible_steps)
|
||||
self.model.grid.move_agent(self, new_position)
|
||||
|
||||
def give_money(self):
|
||||
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||
if len(cellmates) > 1:
|
||||
other = self.random.choice(cellmates)
|
||||
other.wealth += 1
|
||||
self.wealth -= 1
|
||||
|
||||
def step(self):
|
||||
self.move()
|
||||
if self.wealth > 0:
|
||||
self.give_money()
|
||||
|
||||
class MoneyModel(Model):
|
||||
"""A model with some number of agents."""
|
||||
def __init__(self, N, width, height):
|
||||
self.num_agents = N
|
||||
self.grid = MultiGrid(width, height, True)
|
||||
self.schedule = RandomActivation(self)
|
||||
self.running = True
|
||||
|
||||
# Create agents
|
||||
for i in range(self.num_agents):
|
||||
a = MoneyAgent(i, self)
|
||||
self.schedule.add(a)
|
||||
# Add the agent to a random grid cell
|
||||
x = self.random.randrange(self.grid.width)
|
||||
y = self.random.randrange(self.grid.height)
|
||||
self.grid.place_agent(a, (x, y))
|
||||
|
||||
self.datacollector = DataCollector(
|
||||
model_reporters={"Gini": compute_gini},
|
||||
agent_reporters={"Wealth": "wealth"})
|
||||
|
||||
def step(self):
|
||||
self.datacollector.collect(self)
|
||||
self.schedule.step()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
fixed_params = {"width": 10,
|
||||
"height": 10}
|
||||
variable_params = {"N": range(10, 500, 10)}
|
||||
|
||||
batch_run = BatchRunner(MoneyModel,
|
||||
variable_params,
|
||||
fixed_params,
|
||||
iterations=5,
|
||||
max_steps=100,
|
||||
model_reporters={"Gini": compute_gini})
|
||||
batch_run.run_all()
|
||||
|
||||
run_data = batch_run.get_model_vars_dataframe()
|
||||
run_data.head()
|
||||
print(run_data.Gini)
|
||||
|
@ -1,9 +1,8 @@
|
||||
simpy>=4.0
|
||||
networkx>=2.5
|
||||
numpy
|
||||
matplotlib
|
||||
pyyaml>=5.1
|
||||
pandas>=0.23
|
||||
scipy>=1.3
|
||||
SALib>=1.3
|
||||
Jinja2
|
||||
Mesa>=0.8
|
||||
|
@ -1,40 +1,31 @@
|
||||
import random
|
||||
from . import BaseAgent
|
||||
from . import FSM, state, default_state
|
||||
|
||||
|
||||
class BassModel(BaseAgent):
|
||||
class BassModel(FSM):
|
||||
"""
|
||||
Settings:
|
||||
innovation_prob
|
||||
imitation_prob
|
||||
"""
|
||||
|
||||
def __init__(self, environment, agent_id, state, **kwargs):
|
||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
||||
env_params = environment.environment_params
|
||||
self.state['sentimentCorrelation'] = 0
|
||||
sentimentCorrelation = 0
|
||||
|
||||
def step(self):
|
||||
self.behaviour()
|
||||
|
||||
def behaviour(self):
|
||||
# Outside effects
|
||||
if random.random() < self['innovation_prob']:
|
||||
if self.state['id'] == 0:
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
else:
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
# Imitation effects
|
||||
if self.state['id'] == 0:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
@default_state
|
||||
@state
|
||||
def innovation(self):
|
||||
if random.random() < self.innovation_prob:
|
||||
self.sentimentCorrelation = 1
|
||||
return self.aware
|
||||
else:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id)
|
||||
num_neighbors_aware = len(aware_neighbors)
|
||||
if random.random() < (self['imitation_prob']*num_neighbors_aware):
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
self.sentimentCorrelation = 1
|
||||
return self.aware
|
||||
|
||||
else:
|
||||
pass
|
||||
@state
|
||||
def aware(self):
|
||||
self.die()
|
||||
|
@ -0,0 +1,20 @@
|
||||
from scipy.spatial import cKDTree as KDTree
|
||||
from . import NetworkAgent
|
||||
|
||||
class Geo(NetworkAgent):
|
||||
'''In this type of network, nodes have a "pos" attribute.'''
|
||||
|
||||
def geo_search(self, radius, node=None, center=False, **kwargs):
|
||||
'''Get a list of nodes whose coordinates are closer than *radius* to *node*.'''
|
||||
node = as_node(node if node is not None else self)
|
||||
|
||||
G = self.subgraph(**kwargs)
|
||||
|
||||
pos = nx.get_node_attributes(G, 'pos')
|
||||
if not pos:
|
||||
return []
|
||||
nodes, coords = list(zip(*pos.items()))
|
||||
kdtree = KDTree(coords) # Cannot provide generator.
|
||||
indices = kdtree.query_ball_point(pos[node], radius)
|
||||
return [nodes[i] for i in indices if center or (nodes[i] != node)]
|
||||
|
@ -0,0 +1,26 @@
|
||||
from mesa import DataCollector as MDC
|
||||
|
||||
class SoilDataCollector(MDC):
|
||||
|
||||
|
||||
def __init__(self, environment, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# Populate model and env reporters so they have a key per
|
||||
# So they can be shown in the web interface
|
||||
self.environment = environment
|
||||
|
||||
|
||||
@property
|
||||
def model_vars(self):
|
||||
pass
|
||||
|
||||
@model_vars.setter
|
||||
def model_vars(self, value):
|
||||
pass
|
||||
|
||||
@property
|
||||
def agent_reporters(self):
|
||||
self.model._history._
|
||||
|
||||
pass
|
||||
|
@ -0,0 +1,84 @@
|
||||
from mesa.time import BaseScheduler
|
||||
from queue import Empty
|
||||
from heapq import heappush, heappop
|
||||
import math
|
||||
from .utils import logger
|
||||
from mesa import Agent
|
||||
|
||||
|
||||
class When:
|
||||
def __init__(self, time):
|
||||
self._time = float(time)
|
||||
|
||||
def abs(self, time):
|
||||
return self._time
|
||||
|
||||
|
||||
class Delta:
|
||||
def __init__(self, delta):
|
||||
self._delta = delta
|
||||
|
||||
def abs(self, time):
|
||||
return time + self._delta
|
||||
|
||||
|
||||
class TimedActivation(BaseScheduler):
|
||||
"""A scheduler which activates each agent when the agent requests.
|
||||
In each activation, each agent will update its 'next_time'.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(self)
|
||||
self._queue = []
|
||||
self.next_time = 0
|
||||
|
||||
def add(self, agent: Agent):
|
||||
if agent.unique_id not in self._agents:
|
||||
heappush(self._queue, (self.time, agent.unique_id))
|
||||
super().add(agent)
|
||||
|
||||
def step(self, until: float =float('inf')) -> None:
|
||||
"""
|
||||
Executes agents in order, one at a time. After each step,
|
||||
an agent will signal when it wants to be scheduled next.
|
||||
"""
|
||||
|
||||
when = None
|
||||
agent_id = None
|
||||
unsched = []
|
||||
until = until or float('inf')
|
||||
|
||||
if not self._queue:
|
||||
self.time = until
|
||||
self.next_time = float('inf')
|
||||
return
|
||||
|
||||
(when, agent_id) = self._queue[0]
|
||||
|
||||
if until and when > until:
|
||||
self.time = until
|
||||
self.next_time = when
|
||||
return
|
||||
|
||||
self.time = when
|
||||
next_time = float("inf")
|
||||
|
||||
while when == self.time:
|
||||
heappop(self._queue)
|
||||
logger.debug(f'Stepping agent {agent_id}')
|
||||
when = (self._agents[agent_id].step() or Delta(1)).abs(self.time)
|
||||
heappush(self._queue, (when, agent_id))
|
||||
if when < next_time:
|
||||
next_time = when
|
||||
|
||||
if not self._queue or self._queue[0][0] > self.time:
|
||||
agent_id = None
|
||||
break
|
||||
else:
|
||||
(when, agent_id) = self._queue[0]
|
||||
|
||||
if when and when < self.time:
|
||||
raise Exception("Invalid scheduling time")
|
||||
|
||||
self.next_time = next_time
|
||||
self.steps += 1
|
@ -0,0 +1,5 @@
|
||||
from mesa.visualization.UserParam import UserSettableParameter
|
||||
|
||||
class UserSettableParameter(UserSettableParameter):
|
||||
def __str__(self):
|
||||
return self.value
|
@ -1 +1,4 @@
|
||||
pytest
|
||||
pytest
|
||||
mesa>=0.8.9
|
||||
scipy>=1.3
|
||||
tornado
|
||||
|
@ -1,203 +0,0 @@
|
||||
from unittest import TestCase
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from glob import glob
|
||||
|
||||
from soil import history
|
||||
from soil import utils
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
DBROOT = os.path.join(ROOT, 'testdb')
|
||||
|
||||
|
||||
class TestHistory(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not os.path.exists(DBROOT):
|
||||
os.makedirs(DBROOT)
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(DBROOT):
|
||||
shutil.rmtree(DBROOT)
|
||||
|
||||
def test_history(self):
|
||||
"""
|
||||
"""
|
||||
tuples = (
|
||||
('a_0', 0, 'id', 'h'),
|
||||
('a_0', 1, 'id', 'e'),
|
||||
('a_0', 2, 'id', 'l'),
|
||||
('a_0', 3, 'id', 'l'),
|
||||
('a_0', 4, 'id', 'o'),
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 3, 'prob', 2),
|
||||
('env', 5, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
# assert h['env', 0, 'prob'] == 0
|
||||
for i in range(1, 7):
|
||||
assert h['env', i, 'prob'] == ((i-1)//2)+1
|
||||
|
||||
|
||||
for i, k in zip(range(5), 'hello'):
|
||||
assert h['a_0', i, 'id'] == k
|
||||
for record, value in zip(h['a_0', None, 'id'], 'hello'):
|
||||
t_step, val = record
|
||||
assert val == value
|
||||
|
||||
for i, k in zip(range(5), 'value'):
|
||||
assert h['a_1', i, 'id'] == k
|
||||
for i in range(5, 8):
|
||||
assert h['a_1', i, 'id'] == 'e'
|
||||
for i in range(7):
|
||||
assert h['a_2', i, 'finished'] == False
|
||||
assert h['a_2', 7, 'finished']
|
||||
|
||||
def test_history_gen(self):
|
||||
"""
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
for t_step, key, value in h['env', None, None]:
|
||||
assert t_step == value
|
||||
assert key == 'prob'
|
||||
|
||||
records = list(h[None, 7, None])
|
||||
assert len(records) == 3
|
||||
for i in records:
|
||||
agent_id, key, value = i
|
||||
if agent_id == 'a_1':
|
||||
assert key == 'id'
|
||||
assert value == 'e'
|
||||
elif agent_id == 'a_2':
|
||||
assert key == 'finished'
|
||||
assert value
|
||||
else:
|
||||
assert key == 'prob'
|
||||
assert value == 3
|
||||
|
||||
records = h['a_1', 7, None]
|
||||
assert records['id'] == 'e'
|
||||
|
||||
def test_history_file(self):
|
||||
"""
|
||||
History should be saved to a file
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
db_path = os.path.join(DBROOT, 'test')
|
||||
h = history.History(db_path=db_path)
|
||||
h.save_tuples(tuples)
|
||||
h.flush_cache()
|
||||
assert os.path.exists(db_path)
|
||||
|
||||
# Recover the data
|
||||
recovered = history.History(db_path=db_path)
|
||||
assert recovered['a_1', 0, 'id'] == 'v'
|
||||
assert recovered['a_1', 4, 'id'] == 'e'
|
||||
|
||||
# Using backup=True should create a backup copy, and initialize an empty history
|
||||
newhistory = history.History(db_path=db_path, backup=True)
|
||||
backuppaths = glob(db_path + '.backup*.sqlite')
|
||||
assert len(backuppaths) == 1
|
||||
backuppath = backuppaths[0]
|
||||
assert newhistory.db_path == h.db_path
|
||||
assert os.path.exists(backuppath)
|
||||
assert len(newhistory[None, None, None]) == 0
|
||||
|
||||
def test_history_tuples(self):
|
||||
"""
|
||||
The data recovered should be equal to the one recorded.
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
recovered = list(h.to_tuples())
|
||||
assert recovered
|
||||
for i in recovered:
|
||||
assert i in tuples
|
||||
|
||||
def test_stats(self):
|
||||
"""
|
||||
The data recovered should be equal to the one recorded.
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
stat_tuples = [
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'new': '40'},
|
||||
]
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
for stat in stat_tuples:
|
||||
h.save_stats(stat)
|
||||
recovered = h.get_stats()
|
||||
assert recovered
|
||||
assert recovered[0]['num_infected'] == 5
|
||||
assert recovered[1]['runtime'] == 0.2
|
||||
assert recovered[2]['new'] == '40'
|
||||
|
||||
def test_unflatten(self):
|
||||
ex = {'count.neighbors.3': 4,
|
||||
'count.times.2': 4,
|
||||
'count.total.4': 4,
|
||||
'mean.neighbors': 3,
|
||||
'mean.times': 2,
|
||||
'mean.total': 4,
|
||||
't_step': 2,
|
||||
'trial_id': 'exporter_sim_trial_1605817956-4475424'}
|
||||
res = utils.unflatten_dict(ex)
|
||||
|
||||
assert 'count' in res
|
||||
assert 'mean' in res
|
||||
assert 't_step' in res
|
||||
assert 'trial_id' in res
|
@ -0,0 +1,69 @@
|
||||
'''
|
||||
Mesa-SOIL integration tests
|
||||
|
||||
We have to test that:
|
||||
- Mesa agents can be used in SOIL
|
||||
- Simplified soil agents can be used in mesa simulations
|
||||
- Mesa and soil agents can interact in a simulation
|
||||
|
||||
- Mesa visualizations work with SOIL simulations
|
||||
|
||||
'''
|
||||
from mesa import Agent, Model
|
||||
from mesa.time import RandomActivation
|
||||
from mesa.space import MultiGrid
|
||||
|
||||
class MoneyAgent(Agent):
|
||||
""" An agent with fixed initial wealth."""
|
||||
def __init__(self, unique_id, model):
|
||||
super().__init__(unique_id, model)
|
||||
self.wealth = 1
|
||||
|
||||
def step(self):
|
||||
self.move()
|
||||
if self.wealth > 0:
|
||||
self.give_money()
|
||||
|
||||
def give_money(self):
|
||||
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||
if len(cellmates) > 1:
|
||||
other = self.random.choice(cellmates)
|
||||
other.wealth += 1
|
||||
self.wealth -= 1
|
||||
|
||||
def move(self):
|
||||
possible_steps = self.model.grid.get_neighborhood(
|
||||
self.pos,
|
||||
moore=True,
|
||||
include_center=False)
|
||||
new_position = self.random.choice(possible_steps)
|
||||
self.model.grid.move_agent(self, new_position)
|
||||
|
||||
|
||||
class MoneyModel(Model):
|
||||
"""A model with some number of agents."""
|
||||
def __init__(self, N, width, height):
|
||||
self.num_agents = N
|
||||
self.grid = MultiGrid(width, height, True)
|
||||
self.schedule = RandomActivation(self)
|
||||
|
||||
# Create agents
|
||||
for i in range(self.num_agents):
|
||||
a = MoneyAgent(i, self)
|
||||
self.schedule.add(a)
|
||||
|
||||
# Add the agent to a random grid cell
|
||||
x = self.random.randrange(self.grid.width)
|
||||
y = self.random.randrange(self.grid.height)
|
||||
self.grid.place_agent(a, (x, y))
|
||||
|
||||
def step(self):
|
||||
'''Advance the model by one step.'''
|
||||
self.schedule.step()
|
||||
|
||||
|
||||
# model = MoneyModel(10)
|
||||
# for i in range(10):
|
||||
# model.step()
|
||||
|
||||
# agent_wealth = [a.wealth for a in model.schedule.agents]
|
Loading…
Reference in New Issue