mirror of https://github.com/gsi-upm/soil
WIP: mesa compatibility
parent
e860bdb922
commit
5d7e57675a
@ -0,0 +1,21 @@
|
|||||||
|
---
|
||||||
|
name: mesa_sim
|
||||||
|
group: tests
|
||||||
|
dir_path: "/tmp"
|
||||||
|
num_trials: 3
|
||||||
|
max_time: 100
|
||||||
|
interval: 1
|
||||||
|
seed: '1'
|
||||||
|
network_params:
|
||||||
|
generator: social_wealth.graph_generator
|
||||||
|
n: 5
|
||||||
|
network_agents:
|
||||||
|
- agent_type: social_wealth.SocialMoneyAgent
|
||||||
|
weight: 1
|
||||||
|
environment_class: social_wealth.MoneyEnv
|
||||||
|
environment_params:
|
||||||
|
num_mesa_agents: 5
|
||||||
|
mesa_agent_type: social_wealth.MoneyAgent
|
||||||
|
N: 10
|
||||||
|
width: 50
|
||||||
|
height: 50
|
@ -0,0 +1,106 @@
|
|||||||
|
from mesa.visualization.ModularVisualization import ModularServer
|
||||||
|
from soil.visualization import UserSettableParameter
|
||||||
|
from mesa.visualization.modules import ChartModule, NetworkModule, CanvasGrid
|
||||||
|
from social_wealth import MoneyEnv, graph_generator, SocialMoneyAgent
|
||||||
|
|
||||||
|
|
||||||
|
class MyNetwork(NetworkModule):
|
||||||
|
def render(self, model):
|
||||||
|
return self.portrayal_method(model)
|
||||||
|
|
||||||
|
|
||||||
|
def network_portrayal(env):
|
||||||
|
# The model ensures there is 0 or 1 agent per node
|
||||||
|
|
||||||
|
portrayal = dict()
|
||||||
|
portrayal["nodes"] = [
|
||||||
|
{
|
||||||
|
"id": agent_id,
|
||||||
|
"size": env.get_agent(agent_id).wealth,
|
||||||
|
# "color": "#CC0000" if not agents or agents[0].wealth == 0 else "#007959",
|
||||||
|
"color": "#CC0000",
|
||||||
|
"label": f"{agent_id}: {env.get_agent(agent_id).wealth}",
|
||||||
|
}
|
||||||
|
for (agent_id) in env.G.nodes
|
||||||
|
]
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
|
||||||
|
portrayal["edges"] = [
|
||||||
|
{"id": edge_id, "source": source, "target": target, "color": "#000000"}
|
||||||
|
for edge_id, (source, target) in enumerate(env.G.edges)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
return portrayal
|
||||||
|
|
||||||
|
|
||||||
|
def gridPortrayal(agent):
|
||||||
|
"""
|
||||||
|
This function is registered with the visualization server to be called
|
||||||
|
each tick to indicate how to draw the agent in its current state.
|
||||||
|
:param agent: the agent in the simulation
|
||||||
|
:return: the portrayal dictionary
|
||||||
|
"""
|
||||||
|
color = max(10, min(agent.wealth*10, 100))
|
||||||
|
return {
|
||||||
|
"Shape": "rect",
|
||||||
|
"w": 1,
|
||||||
|
"h": 1,
|
||||||
|
"Filled": "true",
|
||||||
|
"Layer": 0,
|
||||||
|
"Label": agent.unique_id,
|
||||||
|
"Text": agent.unique_id,
|
||||||
|
"x": agent.pos[0],
|
||||||
|
"y": agent.pos[1],
|
||||||
|
"Color": f"rgba(31, 10, 255, 0.{color})"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
grid = MyNetwork(network_portrayal, 500, 500, library="sigma")
|
||||||
|
chart = ChartModule(
|
||||||
|
[{"Label": "Gini", "Color": "Black"}], data_collector_name="datacollector"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_params = {
|
||||||
|
"N": UserSettableParameter(
|
||||||
|
"slider",
|
||||||
|
"N",
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
description="Choose how many agents to include in the model",
|
||||||
|
),
|
||||||
|
"network_agents": [{"agent_type": SocialMoneyAgent}],
|
||||||
|
"height": UserSettableParameter(
|
||||||
|
"slider",
|
||||||
|
"height",
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
description="Grid height",
|
||||||
|
),
|
||||||
|
"width": UserSettableParameter(
|
||||||
|
"slider",
|
||||||
|
"width",
|
||||||
|
5,
|
||||||
|
5,
|
||||||
|
10,
|
||||||
|
1,
|
||||||
|
description="Grid width",
|
||||||
|
),
|
||||||
|
"network_params": {
|
||||||
|
'generator': graph_generator
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
canvas_element = CanvasGrid(gridPortrayal, model_params["width"].value, model_params["height"].value, 500, 500)
|
||||||
|
|
||||||
|
|
||||||
|
server = ModularServer(
|
||||||
|
MoneyEnv, [grid, chart, canvas_element], "Money Model", model_params
|
||||||
|
)
|
||||||
|
server.port = 8521
|
||||||
|
|
||||||
|
server.launch(open_browser=False)
|
@ -0,0 +1,134 @@
|
|||||||
|
'''
|
||||||
|
This is an example that adds soil agents and environment in a normal
|
||||||
|
mesa workflow.
|
||||||
|
'''
|
||||||
|
from mesa import Agent as MesaAgent
|
||||||
|
from mesa.space import MultiGrid
|
||||||
|
# from mesa.time import RandomActivation
|
||||||
|
from mesa.datacollection import DataCollector
|
||||||
|
from mesa.batchrunner import BatchRunner
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
|
from soil import NetworkAgent, Environment
|
||||||
|
|
||||||
|
def compute_gini(model):
|
||||||
|
agent_wealths = [agent.wealth for agent in model.agents]
|
||||||
|
x = sorted(agent_wealths)
|
||||||
|
N = len(list(model.agents))
|
||||||
|
B = sum( xi * (N-i) for i,xi in enumerate(x) ) / (N*sum(x))
|
||||||
|
return (1 + (1/N) - 2*B)
|
||||||
|
|
||||||
|
class MoneyAgent(MesaAgent):
|
||||||
|
"""
|
||||||
|
A MESA agent with fixed initial wealth.
|
||||||
|
It will only share wealth with neighbors based on grid proximity
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, unique_id, model):
|
||||||
|
super().__init__(unique_id=unique_id, model=model)
|
||||||
|
self.wealth = 1
|
||||||
|
|
||||||
|
def move(self):
|
||||||
|
possible_steps = self.model.grid.get_neighborhood(
|
||||||
|
self.pos,
|
||||||
|
moore=True,
|
||||||
|
include_center=False)
|
||||||
|
print(self.pos, possible_steps)
|
||||||
|
new_position = self.random.choice(possible_steps)
|
||||||
|
print(self.pos, new_position)
|
||||||
|
self.model.grid.move_agent(self, new_position)
|
||||||
|
|
||||||
|
def give_money(self):
|
||||||
|
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||||
|
if len(cellmates) > 1:
|
||||||
|
other = self.random.choice(cellmates)
|
||||||
|
other.wealth += 1
|
||||||
|
self.wealth -= 1
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.info("Crying wolf", self.pos)
|
||||||
|
self.move()
|
||||||
|
if self.wealth > 0:
|
||||||
|
self.give_money()
|
||||||
|
|
||||||
|
|
||||||
|
class SocialMoneyAgent(NetworkAgent, MoneyAgent):
|
||||||
|
wealth = 1
|
||||||
|
|
||||||
|
def give_money(self):
|
||||||
|
cellmates = set(self.model.grid.get_cell_list_contents([self.pos]))
|
||||||
|
friends = set(self.get_neighboring_agents())
|
||||||
|
self.info("Trying to give money")
|
||||||
|
self.debug("Cellmates: ", cellmates)
|
||||||
|
self.debug("Friends: ", friends)
|
||||||
|
|
||||||
|
nearby_friends = list(cellmates & friends)
|
||||||
|
|
||||||
|
if len(nearby_friends):
|
||||||
|
other = self.random.choice(nearby_friends)
|
||||||
|
other.wealth += 1
|
||||||
|
self.wealth -= 1
|
||||||
|
|
||||||
|
|
||||||
|
class MoneyEnv(Environment):
|
||||||
|
"""A model with some number of agents."""
|
||||||
|
def __init__(self, N, width, height, *args, network_params, **kwargs):
|
||||||
|
self.initialized = True
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
|
||||||
|
network_params['n'] = N
|
||||||
|
super().__init__(*args, network_params=network_params, **kwargs)
|
||||||
|
self.grid = MultiGrid(width, height, False)
|
||||||
|
# self.schedule = RandomActivation(self)
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
# Create agents
|
||||||
|
for agent in self.agents:
|
||||||
|
self.schedule.add(agent)
|
||||||
|
# a = MoneyAgent(i, self)
|
||||||
|
# self.schedule.add(a)
|
||||||
|
# Add the agent to a random grid cell
|
||||||
|
x = self.random.randrange(self.grid.width)
|
||||||
|
y = self.random.randrange(self.grid.height)
|
||||||
|
self.grid.place_agent(agent, (x, y))
|
||||||
|
|
||||||
|
self.datacollector = DataCollector(
|
||||||
|
model_reporters={"Gini": compute_gini},
|
||||||
|
agent_reporters={"Wealth": "wealth"})
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
super().step()
|
||||||
|
self.datacollector.collect(self)
|
||||||
|
self.schedule.step()
|
||||||
|
|
||||||
|
def graph_generator(n=5):
|
||||||
|
G = nx.Graph()
|
||||||
|
for ix in range(n):
|
||||||
|
G.add_edge(0, ix)
|
||||||
|
return G
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|
||||||
|
G = graph_generator()
|
||||||
|
fixed_params = {"topology": G,
|
||||||
|
"width": 10,
|
||||||
|
"network_agents": [{"agent_type": SocialMoneyAgent,
|
||||||
|
'weight': 1}],
|
||||||
|
"height": 10}
|
||||||
|
|
||||||
|
variable_params = {"N": range(10, 100, 10)}
|
||||||
|
|
||||||
|
batch_run = BatchRunner(MoneyEnv,
|
||||||
|
variable_parameters=variable_params,
|
||||||
|
fixed_parameters=fixed_params,
|
||||||
|
iterations=5,
|
||||||
|
max_steps=100,
|
||||||
|
model_reporters={"Gini": compute_gini})
|
||||||
|
batch_run.run_all()
|
||||||
|
|
||||||
|
run_data = batch_run.get_model_vars_dataframe()
|
||||||
|
run_data.head()
|
||||||
|
print(run_data.Gini)
|
||||||
|
|
@ -0,0 +1,83 @@
|
|||||||
|
from mesa import Agent, Model
|
||||||
|
from mesa.space import MultiGrid
|
||||||
|
from mesa.time import RandomActivation
|
||||||
|
from mesa.datacollection import DataCollector
|
||||||
|
from mesa.batchrunner import BatchRunner
|
||||||
|
|
||||||
|
def compute_gini(model):
|
||||||
|
agent_wealths = [agent.wealth for agent in model.schedule.agents]
|
||||||
|
x = sorted(agent_wealths)
|
||||||
|
N = model.num_agents
|
||||||
|
B = sum( xi * (N-i) for i,xi in enumerate(x) ) / (N*sum(x))
|
||||||
|
return (1 + (1/N) - 2*B)
|
||||||
|
|
||||||
|
class MoneyAgent(Agent):
|
||||||
|
""" An agent with fixed initial wealth."""
|
||||||
|
def __init__(self, unique_id, model):
|
||||||
|
super().__init__(unique_id, model)
|
||||||
|
self.wealth = 1
|
||||||
|
|
||||||
|
def move(self):
|
||||||
|
possible_steps = self.model.grid.get_neighborhood(
|
||||||
|
self.pos,
|
||||||
|
moore=True,
|
||||||
|
include_center=False)
|
||||||
|
new_position = self.random.choice(possible_steps)
|
||||||
|
self.model.grid.move_agent(self, new_position)
|
||||||
|
|
||||||
|
def give_money(self):
|
||||||
|
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||||
|
if len(cellmates) > 1:
|
||||||
|
other = self.random.choice(cellmates)
|
||||||
|
other.wealth += 1
|
||||||
|
self.wealth -= 1
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.move()
|
||||||
|
if self.wealth > 0:
|
||||||
|
self.give_money()
|
||||||
|
|
||||||
|
class MoneyModel(Model):
|
||||||
|
"""A model with some number of agents."""
|
||||||
|
def __init__(self, N, width, height):
|
||||||
|
self.num_agents = N
|
||||||
|
self.grid = MultiGrid(width, height, True)
|
||||||
|
self.schedule = RandomActivation(self)
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
# Create agents
|
||||||
|
for i in range(self.num_agents):
|
||||||
|
a = MoneyAgent(i, self)
|
||||||
|
self.schedule.add(a)
|
||||||
|
# Add the agent to a random grid cell
|
||||||
|
x = self.random.randrange(self.grid.width)
|
||||||
|
y = self.random.randrange(self.grid.height)
|
||||||
|
self.grid.place_agent(a, (x, y))
|
||||||
|
|
||||||
|
self.datacollector = DataCollector(
|
||||||
|
model_reporters={"Gini": compute_gini},
|
||||||
|
agent_reporters={"Wealth": "wealth"})
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.datacollector.collect(self)
|
||||||
|
self.schedule.step()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
fixed_params = {"width": 10,
|
||||||
|
"height": 10}
|
||||||
|
variable_params = {"N": range(10, 500, 10)}
|
||||||
|
|
||||||
|
batch_run = BatchRunner(MoneyModel,
|
||||||
|
variable_params,
|
||||||
|
fixed_params,
|
||||||
|
iterations=5,
|
||||||
|
max_steps=100,
|
||||||
|
model_reporters={"Gini": compute_gini})
|
||||||
|
batch_run.run_all()
|
||||||
|
|
||||||
|
run_data = batch_run.get_model_vars_dataframe()
|
||||||
|
run_data.head()
|
||||||
|
print(run_data.Gini)
|
||||||
|
|
@ -1,9 +1,8 @@
|
|||||||
simpy>=4.0
|
|
||||||
networkx>=2.5
|
networkx>=2.5
|
||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
pyyaml>=5.1
|
pyyaml>=5.1
|
||||||
pandas>=0.23
|
pandas>=0.23
|
||||||
scipy>=1.3
|
|
||||||
SALib>=1.3
|
SALib>=1.3
|
||||||
Jinja2
|
Jinja2
|
||||||
|
Mesa>=0.8
|
||||||
|
@ -1,40 +1,31 @@
|
|||||||
import random
|
import random
|
||||||
from . import BaseAgent
|
from . import FSM, state, default_state
|
||||||
|
|
||||||
|
|
||||||
class BassModel(BaseAgent):
|
class BassModel(FSM):
|
||||||
"""
|
"""
|
||||||
Settings:
|
Settings:
|
||||||
innovation_prob
|
innovation_prob
|
||||||
imitation_prob
|
imitation_prob
|
||||||
"""
|
"""
|
||||||
|
sentimentCorrelation = 0
|
||||||
def __init__(self, environment, agent_id, state, **kwargs):
|
|
||||||
super().__init__(environment=environment, agent_id=agent_id, state=state)
|
|
||||||
env_params = environment.environment_params
|
|
||||||
self.state['sentimentCorrelation'] = 0
|
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.behaviour()
|
self.behaviour()
|
||||||
|
|
||||||
def behaviour(self):
|
@default_state
|
||||||
# Outside effects
|
@state
|
||||||
if random.random() < self['innovation_prob']:
|
def innovation(self):
|
||||||
if self.state['id'] == 0:
|
if random.random() < self.innovation_prob:
|
||||||
self.state['id'] = 1
|
self.sentimentCorrelation = 1
|
||||||
self.state['sentimentCorrelation'] = 1
|
return self.aware
|
||||||
else:
|
else:
|
||||||
pass
|
aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id)
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
# Imitation effects
|
|
||||||
if self.state['id'] == 0:
|
|
||||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
|
||||||
num_neighbors_aware = len(aware_neighbors)
|
num_neighbors_aware = len(aware_neighbors)
|
||||||
if random.random() < (self['imitation_prob']*num_neighbors_aware):
|
if random.random() < (self['imitation_prob']*num_neighbors_aware):
|
||||||
self.state['id'] = 1
|
self.sentimentCorrelation = 1
|
||||||
self.state['sentimentCorrelation'] = 1
|
return self.aware
|
||||||
|
|
||||||
else:
|
@state
|
||||||
pass
|
def aware(self):
|
||||||
|
self.die()
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
from scipy.spatial import cKDTree as KDTree
|
||||||
|
from . import NetworkAgent
|
||||||
|
|
||||||
|
class Geo(NetworkAgent):
|
||||||
|
'''In this type of network, nodes have a "pos" attribute.'''
|
||||||
|
|
||||||
|
def geo_search(self, radius, node=None, center=False, **kwargs):
|
||||||
|
'''Get a list of nodes whose coordinates are closer than *radius* to *node*.'''
|
||||||
|
node = as_node(node if node is not None else self)
|
||||||
|
|
||||||
|
G = self.subgraph(**kwargs)
|
||||||
|
|
||||||
|
pos = nx.get_node_attributes(G, 'pos')
|
||||||
|
if not pos:
|
||||||
|
return []
|
||||||
|
nodes, coords = list(zip(*pos.items()))
|
||||||
|
kdtree = KDTree(coords) # Cannot provide generator.
|
||||||
|
indices = kdtree.query_ball_point(pos[node], radius)
|
||||||
|
return [nodes[i] for i in indices if center or (nodes[i] != node)]
|
||||||
|
|
@ -0,0 +1,26 @@
|
|||||||
|
from mesa import DataCollector as MDC
|
||||||
|
|
||||||
|
class SoilDataCollector(MDC):
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, environment, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
# Populate model and env reporters so they have a key per
|
||||||
|
# So they can be shown in the web interface
|
||||||
|
self.environment = environment
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_vars(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@model_vars.setter
|
||||||
|
def model_vars(self, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agent_reporters(self):
|
||||||
|
self.model._history._
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
@ -0,0 +1,84 @@
|
|||||||
|
from mesa.time import BaseScheduler
|
||||||
|
from queue import Empty
|
||||||
|
from heapq import heappush, heappop
|
||||||
|
import math
|
||||||
|
from .utils import logger
|
||||||
|
from mesa import Agent
|
||||||
|
|
||||||
|
|
||||||
|
class When:
|
||||||
|
def __init__(self, time):
|
||||||
|
self._time = float(time)
|
||||||
|
|
||||||
|
def abs(self, time):
|
||||||
|
return self._time
|
||||||
|
|
||||||
|
|
||||||
|
class Delta:
|
||||||
|
def __init__(self, delta):
|
||||||
|
self._delta = delta
|
||||||
|
|
||||||
|
def abs(self, time):
|
||||||
|
return time + self._delta
|
||||||
|
|
||||||
|
|
||||||
|
class TimedActivation(BaseScheduler):
|
||||||
|
"""A scheduler which activates each agent when the agent requests.
|
||||||
|
In each activation, each agent will update its 'next_time'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(self)
|
||||||
|
self._queue = []
|
||||||
|
self.next_time = 0
|
||||||
|
|
||||||
|
def add(self, agent: Agent):
|
||||||
|
if agent.unique_id not in self._agents:
|
||||||
|
heappush(self._queue, (self.time, agent.unique_id))
|
||||||
|
super().add(agent)
|
||||||
|
|
||||||
|
def step(self, until: float =float('inf')) -> None:
|
||||||
|
"""
|
||||||
|
Executes agents in order, one at a time. After each step,
|
||||||
|
an agent will signal when it wants to be scheduled next.
|
||||||
|
"""
|
||||||
|
|
||||||
|
when = None
|
||||||
|
agent_id = None
|
||||||
|
unsched = []
|
||||||
|
until = until or float('inf')
|
||||||
|
|
||||||
|
if not self._queue:
|
||||||
|
self.time = until
|
||||||
|
self.next_time = float('inf')
|
||||||
|
return
|
||||||
|
|
||||||
|
(when, agent_id) = self._queue[0]
|
||||||
|
|
||||||
|
if until and when > until:
|
||||||
|
self.time = until
|
||||||
|
self.next_time = when
|
||||||
|
return
|
||||||
|
|
||||||
|
self.time = when
|
||||||
|
next_time = float("inf")
|
||||||
|
|
||||||
|
while when == self.time:
|
||||||
|
heappop(self._queue)
|
||||||
|
logger.debug(f'Stepping agent {agent_id}')
|
||||||
|
when = (self._agents[agent_id].step() or Delta(1)).abs(self.time)
|
||||||
|
heappush(self._queue, (when, agent_id))
|
||||||
|
if when < next_time:
|
||||||
|
next_time = when
|
||||||
|
|
||||||
|
if not self._queue or self._queue[0][0] > self.time:
|
||||||
|
agent_id = None
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
(when, agent_id) = self._queue[0]
|
||||||
|
|
||||||
|
if when and when < self.time:
|
||||||
|
raise Exception("Invalid scheduling time")
|
||||||
|
|
||||||
|
self.next_time = next_time
|
||||||
|
self.steps += 1
|
@ -0,0 +1,5 @@
|
|||||||
|
from mesa.visualization.UserParam import UserSettableParameter
|
||||||
|
|
||||||
|
class UserSettableParameter(UserSettableParameter):
|
||||||
|
def __str__(self):
|
||||||
|
return self.value
|
@ -1 +1,4 @@
|
|||||||
pytest
|
pytest
|
||||||
|
mesa>=0.8.9
|
||||||
|
scipy>=1.3
|
||||||
|
tornado
|
||||||
|
@ -1,203 +0,0 @@
|
|||||||
from unittest import TestCase
|
|
||||||
|
|
||||||
import os
|
|
||||||
import shutil
|
|
||||||
from glob import glob
|
|
||||||
|
|
||||||
from soil import history
|
|
||||||
from soil import utils
|
|
||||||
|
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
|
||||||
DBROOT = os.path.join(ROOT, 'testdb')
|
|
||||||
|
|
||||||
|
|
||||||
class TestHistory(TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
if not os.path.exists(DBROOT):
|
|
||||||
os.makedirs(DBROOT)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
if os.path.exists(DBROOT):
|
|
||||||
shutil.rmtree(DBROOT)
|
|
||||||
|
|
||||||
def test_history(self):
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
tuples = (
|
|
||||||
('a_0', 0, 'id', 'h'),
|
|
||||||
('a_0', 1, 'id', 'e'),
|
|
||||||
('a_0', 2, 'id', 'l'),
|
|
||||||
('a_0', 3, 'id', 'l'),
|
|
||||||
('a_0', 4, 'id', 'o'),
|
|
||||||
('a_1', 0, 'id', 'v'),
|
|
||||||
('a_1', 1, 'id', 'a'),
|
|
||||||
('a_1', 2, 'id', 'l'),
|
|
||||||
('a_1', 3, 'id', 'u'),
|
|
||||||
('a_1', 4, 'id', 'e'),
|
|
||||||
('env', 1, 'prob', 1),
|
|
||||||
('env', 3, 'prob', 2),
|
|
||||||
('env', 5, 'prob', 3),
|
|
||||||
('a_2', 7, 'finished', True),
|
|
||||||
)
|
|
||||||
h = history.History()
|
|
||||||
h.save_tuples(tuples)
|
|
||||||
# assert h['env', 0, 'prob'] == 0
|
|
||||||
for i in range(1, 7):
|
|
||||||
assert h['env', i, 'prob'] == ((i-1)//2)+1
|
|
||||||
|
|
||||||
|
|
||||||
for i, k in zip(range(5), 'hello'):
|
|
||||||
assert h['a_0', i, 'id'] == k
|
|
||||||
for record, value in zip(h['a_0', None, 'id'], 'hello'):
|
|
||||||
t_step, val = record
|
|
||||||
assert val == value
|
|
||||||
|
|
||||||
for i, k in zip(range(5), 'value'):
|
|
||||||
assert h['a_1', i, 'id'] == k
|
|
||||||
for i in range(5, 8):
|
|
||||||
assert h['a_1', i, 'id'] == 'e'
|
|
||||||
for i in range(7):
|
|
||||||
assert h['a_2', i, 'finished'] == False
|
|
||||||
assert h['a_2', 7, 'finished']
|
|
||||||
|
|
||||||
def test_history_gen(self):
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
tuples = (
|
|
||||||
('a_1', 0, 'id', 'v'),
|
|
||||||
('a_1', 1, 'id', 'a'),
|
|
||||||
('a_1', 2, 'id', 'l'),
|
|
||||||
('a_1', 3, 'id', 'u'),
|
|
||||||
('a_1', 4, 'id', 'e'),
|
|
||||||
('env', 1, 'prob', 1),
|
|
||||||
('env', 2, 'prob', 2),
|
|
||||||
('env', 3, 'prob', 3),
|
|
||||||
('a_2', 7, 'finished', True),
|
|
||||||
)
|
|
||||||
h = history.History()
|
|
||||||
h.save_tuples(tuples)
|
|
||||||
for t_step, key, value in h['env', None, None]:
|
|
||||||
assert t_step == value
|
|
||||||
assert key == 'prob'
|
|
||||||
|
|
||||||
records = list(h[None, 7, None])
|
|
||||||
assert len(records) == 3
|
|
||||||
for i in records:
|
|
||||||
agent_id, key, value = i
|
|
||||||
if agent_id == 'a_1':
|
|
||||||
assert key == 'id'
|
|
||||||
assert value == 'e'
|
|
||||||
elif agent_id == 'a_2':
|
|
||||||
assert key == 'finished'
|
|
||||||
assert value
|
|
||||||
else:
|
|
||||||
assert key == 'prob'
|
|
||||||
assert value == 3
|
|
||||||
|
|
||||||
records = h['a_1', 7, None]
|
|
||||||
assert records['id'] == 'e'
|
|
||||||
|
|
||||||
def test_history_file(self):
|
|
||||||
"""
|
|
||||||
History should be saved to a file
|
|
||||||
"""
|
|
||||||
tuples = (
|
|
||||||
('a_1', 0, 'id', 'v'),
|
|
||||||
('a_1', 1, 'id', 'a'),
|
|
||||||
('a_1', 2, 'id', 'l'),
|
|
||||||
('a_1', 3, 'id', 'u'),
|
|
||||||
('a_1', 4, 'id', 'e'),
|
|
||||||
('env', 1, 'prob', 1),
|
|
||||||
('env', 2, 'prob', 2),
|
|
||||||
('env', 3, 'prob', 3),
|
|
||||||
('a_2', 7, 'finished', True),
|
|
||||||
)
|
|
||||||
db_path = os.path.join(DBROOT, 'test')
|
|
||||||
h = history.History(db_path=db_path)
|
|
||||||
h.save_tuples(tuples)
|
|
||||||
h.flush_cache()
|
|
||||||
assert os.path.exists(db_path)
|
|
||||||
|
|
||||||
# Recover the data
|
|
||||||
recovered = history.History(db_path=db_path)
|
|
||||||
assert recovered['a_1', 0, 'id'] == 'v'
|
|
||||||
assert recovered['a_1', 4, 'id'] == 'e'
|
|
||||||
|
|
||||||
# Using backup=True should create a backup copy, and initialize an empty history
|
|
||||||
newhistory = history.History(db_path=db_path, backup=True)
|
|
||||||
backuppaths = glob(db_path + '.backup*.sqlite')
|
|
||||||
assert len(backuppaths) == 1
|
|
||||||
backuppath = backuppaths[0]
|
|
||||||
assert newhistory.db_path == h.db_path
|
|
||||||
assert os.path.exists(backuppath)
|
|
||||||
assert len(newhistory[None, None, None]) == 0
|
|
||||||
|
|
||||||
def test_history_tuples(self):
|
|
||||||
"""
|
|
||||||
The data recovered should be equal to the one recorded.
|
|
||||||
"""
|
|
||||||
tuples = (
|
|
||||||
('a_1', 0, 'id', 'v'),
|
|
||||||
('a_1', 1, 'id', 'a'),
|
|
||||||
('a_1', 2, 'id', 'l'),
|
|
||||||
('a_1', 3, 'id', 'u'),
|
|
||||||
('a_1', 4, 'id', 'e'),
|
|
||||||
('env', 1, 'prob', 1),
|
|
||||||
('env', 2, 'prob', 2),
|
|
||||||
('env', 3, 'prob', 3),
|
|
||||||
('a_2', 7, 'finished', True),
|
|
||||||
)
|
|
||||||
h = history.History()
|
|
||||||
h.save_tuples(tuples)
|
|
||||||
recovered = list(h.to_tuples())
|
|
||||||
assert recovered
|
|
||||||
for i in recovered:
|
|
||||||
assert i in tuples
|
|
||||||
|
|
||||||
def test_stats(self):
|
|
||||||
"""
|
|
||||||
The data recovered should be equal to the one recorded.
|
|
||||||
"""
|
|
||||||
tuples = (
|
|
||||||
('a_1', 0, 'id', 'v'),
|
|
||||||
('a_1', 1, 'id', 'a'),
|
|
||||||
('a_1', 2, 'id', 'l'),
|
|
||||||
('a_1', 3, 'id', 'u'),
|
|
||||||
('a_1', 4, 'id', 'e'),
|
|
||||||
('env', 1, 'prob', 1),
|
|
||||||
('env', 2, 'prob', 2),
|
|
||||||
('env', 3, 'prob', 3),
|
|
||||||
('a_2', 7, 'finished', True),
|
|
||||||
)
|
|
||||||
stat_tuples = [
|
|
||||||
{'num_infected': 5, 'runtime': 0.2},
|
|
||||||
{'num_infected': 5, 'runtime': 0.2},
|
|
||||||
{'new': '40'},
|
|
||||||
]
|
|
||||||
h = history.History()
|
|
||||||
h.save_tuples(tuples)
|
|
||||||
for stat in stat_tuples:
|
|
||||||
h.save_stats(stat)
|
|
||||||
recovered = h.get_stats()
|
|
||||||
assert recovered
|
|
||||||
assert recovered[0]['num_infected'] == 5
|
|
||||||
assert recovered[1]['runtime'] == 0.2
|
|
||||||
assert recovered[2]['new'] == '40'
|
|
||||||
|
|
||||||
def test_unflatten(self):
|
|
||||||
ex = {'count.neighbors.3': 4,
|
|
||||||
'count.times.2': 4,
|
|
||||||
'count.total.4': 4,
|
|
||||||
'mean.neighbors': 3,
|
|
||||||
'mean.times': 2,
|
|
||||||
'mean.total': 4,
|
|
||||||
't_step': 2,
|
|
||||||
'trial_id': 'exporter_sim_trial_1605817956-4475424'}
|
|
||||||
res = utils.unflatten_dict(ex)
|
|
||||||
|
|
||||||
assert 'count' in res
|
|
||||||
assert 'mean' in res
|
|
||||||
assert 't_step' in res
|
|
||||||
assert 'trial_id' in res
|
|
@ -0,0 +1,69 @@
|
|||||||
|
'''
|
||||||
|
Mesa-SOIL integration tests
|
||||||
|
|
||||||
|
We have to test that:
|
||||||
|
- Mesa agents can be used in SOIL
|
||||||
|
- Simplified soil agents can be used in mesa simulations
|
||||||
|
- Mesa and soil agents can interact in a simulation
|
||||||
|
|
||||||
|
- Mesa visualizations work with SOIL simulations
|
||||||
|
|
||||||
|
'''
|
||||||
|
from mesa import Agent, Model
|
||||||
|
from mesa.time import RandomActivation
|
||||||
|
from mesa.space import MultiGrid
|
||||||
|
|
||||||
|
class MoneyAgent(Agent):
|
||||||
|
""" An agent with fixed initial wealth."""
|
||||||
|
def __init__(self, unique_id, model):
|
||||||
|
super().__init__(unique_id, model)
|
||||||
|
self.wealth = 1
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.move()
|
||||||
|
if self.wealth > 0:
|
||||||
|
self.give_money()
|
||||||
|
|
||||||
|
def give_money(self):
|
||||||
|
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||||
|
if len(cellmates) > 1:
|
||||||
|
other = self.random.choice(cellmates)
|
||||||
|
other.wealth += 1
|
||||||
|
self.wealth -= 1
|
||||||
|
|
||||||
|
def move(self):
|
||||||
|
possible_steps = self.model.grid.get_neighborhood(
|
||||||
|
self.pos,
|
||||||
|
moore=True,
|
||||||
|
include_center=False)
|
||||||
|
new_position = self.random.choice(possible_steps)
|
||||||
|
self.model.grid.move_agent(self, new_position)
|
||||||
|
|
||||||
|
|
||||||
|
class MoneyModel(Model):
|
||||||
|
"""A model with some number of agents."""
|
||||||
|
def __init__(self, N, width, height):
|
||||||
|
self.num_agents = N
|
||||||
|
self.grid = MultiGrid(width, height, True)
|
||||||
|
self.schedule = RandomActivation(self)
|
||||||
|
|
||||||
|
# Create agents
|
||||||
|
for i in range(self.num_agents):
|
||||||
|
a = MoneyAgent(i, self)
|
||||||
|
self.schedule.add(a)
|
||||||
|
|
||||||
|
# Add the agent to a random grid cell
|
||||||
|
x = self.random.randrange(self.grid.width)
|
||||||
|
y = self.random.randrange(self.grid.height)
|
||||||
|
self.grid.place_agent(a, (x, y))
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
'''Advance the model by one step.'''
|
||||||
|
self.schedule.step()
|
||||||
|
|
||||||
|
|
||||||
|
# model = MoneyModel(10)
|
||||||
|
# for i in range(10):
|
||||||
|
# model.step()
|
||||||
|
|
||||||
|
# agent_wealth = [a.wealth for a in model.schedule.agents]
|
Loading…
Reference in New Issue