mirror of
https://github.com/gsi-upm/soil
synced 2025-08-23 19:52:19 +00:00
WIP: mesa compatibility
This commit is contained in:
@@ -21,11 +21,13 @@ class Ping(agents.FSM):
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def even(self):
|
||||
self.debug(f'Even {self["count"]}')
|
||||
self['count'] += 1
|
||||
return self.odd
|
||||
|
||||
@agents.state
|
||||
def odd(self):
|
||||
self.debug(f'Odd {self["count"]}')
|
||||
self['count'] += 1
|
||||
return self.even
|
||||
|
||||
@@ -82,8 +84,7 @@ class TestAnalysis(TestCase):
|
||||
|
||||
import numpy as np
|
||||
res_mean = analysis.get_value(df, 'count', aggfunc=np.mean)
|
||||
assert res_mean['count'].iloc[0] == 1
|
||||
|
||||
res_total = analysis.get_value(df)
|
||||
assert res_mean['count'].iloc[15] == (16+8)/2
|
||||
|
||||
res_total = analysis.get_majority(df)
|
||||
res_total['SEED'].iloc[0] == self.env['SEED']
|
||||
|
@@ -1,203 +0,0 @@
|
||||
from unittest import TestCase
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from glob import glob
|
||||
|
||||
from soil import history
|
||||
from soil import utils
|
||||
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
DBROOT = os.path.join(ROOT, 'testdb')
|
||||
|
||||
|
||||
class TestHistory(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if not os.path.exists(DBROOT):
|
||||
os.makedirs(DBROOT)
|
||||
|
||||
def tearDown(self):
|
||||
if os.path.exists(DBROOT):
|
||||
shutil.rmtree(DBROOT)
|
||||
|
||||
def test_history(self):
|
||||
"""
|
||||
"""
|
||||
tuples = (
|
||||
('a_0', 0, 'id', 'h'),
|
||||
('a_0', 1, 'id', 'e'),
|
||||
('a_0', 2, 'id', 'l'),
|
||||
('a_0', 3, 'id', 'l'),
|
||||
('a_0', 4, 'id', 'o'),
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 3, 'prob', 2),
|
||||
('env', 5, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
# assert h['env', 0, 'prob'] == 0
|
||||
for i in range(1, 7):
|
||||
assert h['env', i, 'prob'] == ((i-1)//2)+1
|
||||
|
||||
|
||||
for i, k in zip(range(5), 'hello'):
|
||||
assert h['a_0', i, 'id'] == k
|
||||
for record, value in zip(h['a_0', None, 'id'], 'hello'):
|
||||
t_step, val = record
|
||||
assert val == value
|
||||
|
||||
for i, k in zip(range(5), 'value'):
|
||||
assert h['a_1', i, 'id'] == k
|
||||
for i in range(5, 8):
|
||||
assert h['a_1', i, 'id'] == 'e'
|
||||
for i in range(7):
|
||||
assert h['a_2', i, 'finished'] == False
|
||||
assert h['a_2', 7, 'finished']
|
||||
|
||||
def test_history_gen(self):
|
||||
"""
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
for t_step, key, value in h['env', None, None]:
|
||||
assert t_step == value
|
||||
assert key == 'prob'
|
||||
|
||||
records = list(h[None, 7, None])
|
||||
assert len(records) == 3
|
||||
for i in records:
|
||||
agent_id, key, value = i
|
||||
if agent_id == 'a_1':
|
||||
assert key == 'id'
|
||||
assert value == 'e'
|
||||
elif agent_id == 'a_2':
|
||||
assert key == 'finished'
|
||||
assert value
|
||||
else:
|
||||
assert key == 'prob'
|
||||
assert value == 3
|
||||
|
||||
records = h['a_1', 7, None]
|
||||
assert records['id'] == 'e'
|
||||
|
||||
def test_history_file(self):
|
||||
"""
|
||||
History should be saved to a file
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
db_path = os.path.join(DBROOT, 'test')
|
||||
h = history.History(db_path=db_path)
|
||||
h.save_tuples(tuples)
|
||||
h.flush_cache()
|
||||
assert os.path.exists(db_path)
|
||||
|
||||
# Recover the data
|
||||
recovered = history.History(db_path=db_path)
|
||||
assert recovered['a_1', 0, 'id'] == 'v'
|
||||
assert recovered['a_1', 4, 'id'] == 'e'
|
||||
|
||||
# Using backup=True should create a backup copy, and initialize an empty history
|
||||
newhistory = history.History(db_path=db_path, backup=True)
|
||||
backuppaths = glob(db_path + '.backup*.sqlite')
|
||||
assert len(backuppaths) == 1
|
||||
backuppath = backuppaths[0]
|
||||
assert newhistory.db_path == h.db_path
|
||||
assert os.path.exists(backuppath)
|
||||
assert len(newhistory[None, None, None]) == 0
|
||||
|
||||
def test_history_tuples(self):
|
||||
"""
|
||||
The data recovered should be equal to the one recorded.
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
recovered = list(h.to_tuples())
|
||||
assert recovered
|
||||
for i in recovered:
|
||||
assert i in tuples
|
||||
|
||||
def test_stats(self):
|
||||
"""
|
||||
The data recovered should be equal to the one recorded.
|
||||
"""
|
||||
tuples = (
|
||||
('a_1', 0, 'id', 'v'),
|
||||
('a_1', 1, 'id', 'a'),
|
||||
('a_1', 2, 'id', 'l'),
|
||||
('a_1', 3, 'id', 'u'),
|
||||
('a_1', 4, 'id', 'e'),
|
||||
('env', 1, 'prob', 1),
|
||||
('env', 2, 'prob', 2),
|
||||
('env', 3, 'prob', 3),
|
||||
('a_2', 7, 'finished', True),
|
||||
)
|
||||
stat_tuples = [
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'num_infected': 5, 'runtime': 0.2},
|
||||
{'new': '40'},
|
||||
]
|
||||
h = history.History()
|
||||
h.save_tuples(tuples)
|
||||
for stat in stat_tuples:
|
||||
h.save_stats(stat)
|
||||
recovered = h.get_stats()
|
||||
assert recovered
|
||||
assert recovered[0]['num_infected'] == 5
|
||||
assert recovered[1]['runtime'] == 0.2
|
||||
assert recovered[2]['new'] == '40'
|
||||
|
||||
def test_unflatten(self):
|
||||
ex = {'count.neighbors.3': 4,
|
||||
'count.times.2': 4,
|
||||
'count.total.4': 4,
|
||||
'mean.neighbors': 3,
|
||||
'mean.times': 2,
|
||||
'mean.total': 4,
|
||||
't_step': 2,
|
||||
'trial_id': 'exporter_sim_trial_1605817956-4475424'}
|
||||
res = utils.unflatten_dict(ex)
|
||||
|
||||
assert 'count' in res
|
||||
assert 'mean' in res
|
||||
assert 't_step' in res
|
||||
assert 'trial_id' in res
|
@@ -126,7 +126,7 @@ class TestMain(TestCase):
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
for agent in env.network_agents:
|
||||
last = 0
|
||||
assert len(agent[None, None]) == 10
|
||||
assert len(agent[None, None]) == 11
|
||||
for step, total in sorted(agent['total', None]):
|
||||
assert total == last + 2
|
||||
last = total
|
||||
@@ -198,11 +198,11 @@ class TestMain(TestCase):
|
||||
"""
|
||||
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
s = simulation.from_config(config)
|
||||
for i in range(5):
|
||||
s.run_simulation(dry_run=True)
|
||||
nconfig = s.to_dict()
|
||||
del nconfig['topology']
|
||||
assert config == nconfig
|
||||
|
||||
s.run_simulation(dry_run=True)
|
||||
nconfig = s.to_dict()
|
||||
del nconfig['topology']
|
||||
assert config == nconfig
|
||||
|
||||
def test_row_conversion(self):
|
||||
env = Environment()
|
||||
@@ -211,7 +211,7 @@ class TestMain(TestCase):
|
||||
res = list(env.history_to_tuples())
|
||||
assert len(res) == len(env.environment_params)
|
||||
|
||||
env._now = 1
|
||||
env.schedule.time = 1
|
||||
env['test'] = 'second_value'
|
||||
res = list(env.history_to_tuples())
|
||||
|
||||
@@ -281,7 +281,7 @@ class TestMain(TestCase):
|
||||
'weight': 2
|
||||
},
|
||||
]
|
||||
converted = agents.deserialize_distribution(agent_distro)
|
||||
converted = agents.deserialize_definition(agent_distro)
|
||||
assert converted[0]['agent_type'] == agents.CounterModel
|
||||
assert converted[1]['agent_type'] == CustomAgent
|
||||
pickle.dumps(converted)
|
||||
@@ -297,14 +297,14 @@ class TestMain(TestCase):
|
||||
'weight': 2
|
||||
},
|
||||
]
|
||||
converted = agents.serialize_distribution(agent_distro)
|
||||
converted = agents.serialize_definition(agent_distro)
|
||||
assert converted[0]['agent_type'] == 'CounterModel'
|
||||
assert converted[1]['agent_type'] == 'test_main.CustomAgent'
|
||||
pickle.dumps(converted)
|
||||
|
||||
def test_pickle_agent_environment(self):
|
||||
env = Environment(name='Test')
|
||||
a = agents.BaseAgent(environment=env, agent_id=25)
|
||||
a = agents.BaseAgent(model=env, unique_id=25)
|
||||
|
||||
a['key'] = 'test'
|
||||
|
||||
@@ -345,7 +345,7 @@ class TestMain(TestCase):
|
||||
|
||||
def test_until(self):
|
||||
config = {
|
||||
'name': 'exporter_sim',
|
||||
'name': 'until_sim',
|
||||
'network_params': {},
|
||||
'agent_type': 'CounterModel',
|
||||
'max_time': 2,
|
||||
|
69
tests/test_mesa.py
Normal file
69
tests/test_mesa.py
Normal file
@@ -0,0 +1,69 @@
|
||||
'''
|
||||
Mesa-SOIL integration tests
|
||||
|
||||
We have to test that:
|
||||
- Mesa agents can be used in SOIL
|
||||
- Simplified soil agents can be used in mesa simulations
|
||||
- Mesa and soil agents can interact in a simulation
|
||||
|
||||
- Mesa visualizations work with SOIL simulations
|
||||
|
||||
'''
|
||||
from mesa import Agent, Model
|
||||
from mesa.time import RandomActivation
|
||||
from mesa.space import MultiGrid
|
||||
|
||||
class MoneyAgent(Agent):
|
||||
""" An agent with fixed initial wealth."""
|
||||
def __init__(self, unique_id, model):
|
||||
super().__init__(unique_id, model)
|
||||
self.wealth = 1
|
||||
|
||||
def step(self):
|
||||
self.move()
|
||||
if self.wealth > 0:
|
||||
self.give_money()
|
||||
|
||||
def give_money(self):
|
||||
cellmates = self.model.grid.get_cell_list_contents([self.pos])
|
||||
if len(cellmates) > 1:
|
||||
other = self.random.choice(cellmates)
|
||||
other.wealth += 1
|
||||
self.wealth -= 1
|
||||
|
||||
def move(self):
|
||||
possible_steps = self.model.grid.get_neighborhood(
|
||||
self.pos,
|
||||
moore=True,
|
||||
include_center=False)
|
||||
new_position = self.random.choice(possible_steps)
|
||||
self.model.grid.move_agent(self, new_position)
|
||||
|
||||
|
||||
class MoneyModel(Model):
|
||||
"""A model with some number of agents."""
|
||||
def __init__(self, N, width, height):
|
||||
self.num_agents = N
|
||||
self.grid = MultiGrid(width, height, True)
|
||||
self.schedule = RandomActivation(self)
|
||||
|
||||
# Create agents
|
||||
for i in range(self.num_agents):
|
||||
a = MoneyAgent(i, self)
|
||||
self.schedule.add(a)
|
||||
|
||||
# Add the agent to a random grid cell
|
||||
x = self.random.randrange(self.grid.width)
|
||||
y = self.random.randrange(self.grid.height)
|
||||
self.grid.place_agent(a, (x, y))
|
||||
|
||||
def step(self):
|
||||
'''Advance the model by one step.'''
|
||||
self.schedule.step()
|
||||
|
||||
|
||||
# model = MoneyModel(10)
|
||||
# for i in range(10):
|
||||
# model.step()
|
||||
|
||||
# agent_wealth = [a.wealth for a in model.schedule.agents]
|
Reference in New Issue
Block a user