mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
WIP
This commit is contained in:
parent
bbaed636a8
commit
e41dc3dae2
@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [UNRELEASED]
|
## [UNRELEASED]
|
||||||
|
### Changed
|
||||||
|
* Configuration schema is very different now. Check `soil.config` for more information. We are using Pydantic for (de)serialization.
|
||||||
|
* There may be more than one topology/network in the simulation
|
||||||
|
* Agents are split into groups now. Each group may be assigned a given set of agents or an agent distribution, and a network topology to be assigned to.
|
||||||
|
### Removed
|
||||||
|
* Any `tsih` and `History` integration in the main classes. To record the state of environments/agents, just use a datacollector. In some cases this may be slower or consume more memory than the previous system. However, few cases actually used the full potential of the history, and it came at the cost of unnecessary complexity and worse performance for the majority of cases.
|
||||||
## [0.20.7]
|
## [0.20.7]
|
||||||
### Changed
|
### Changed
|
||||||
* Creating a `time.When` from another `time.When` does not nest them anymore (it returns the argument)
|
* Creating a `time.When` from another `time.When` does not nest them anymore (it returns the argument)
|
||||||
|
@ -1,38 +1,59 @@
|
|||||||
---
|
---
|
||||||
|
version: '2'
|
||||||
general:
|
general:
|
||||||
name: simple
|
id: simple
|
||||||
group: tests
|
group: tests
|
||||||
dir_path: "/tmp/"
|
dir_path: "/tmp/"
|
||||||
num_trials: 3
|
num_trials: 3
|
||||||
max_time: 100
|
max_time: 100
|
||||||
interval: 1
|
interval: 1
|
||||||
seed: "CompleteSeed!"
|
seed: "CompleteSeed!"
|
||||||
network:
|
topologies:
|
||||||
group:
|
default:
|
||||||
network
|
params:
|
||||||
params:
|
generator: complete_graph
|
||||||
generator: complete_graph
|
n: 10
|
||||||
n: 10
|
another_graph:
|
||||||
|
params:
|
||||||
|
generator: complete_graph
|
||||||
|
n: 2
|
||||||
environment:
|
environment:
|
||||||
environment_class: Environment
|
environment_class: Environment
|
||||||
params:
|
params:
|
||||||
am_i_complete: true
|
am_i_complete: true
|
||||||
agents:
|
agents:
|
||||||
default:
|
# Agents are split into groups, each with its own definition
|
||||||
|
default: # This is a special group. Its values will be used as default values for the rest of the groups
|
||||||
agent_class: CounterModel
|
agent_class: CounterModel
|
||||||
|
topology: default
|
||||||
state:
|
state:
|
||||||
times: 1
|
times: 1
|
||||||
environment:
|
environment:
|
||||||
|
# In this group we are not specifying any topology
|
||||||
fixed:
|
fixed:
|
||||||
- agent_id: 'Environment Agent 1'
|
- agent_id: 'Environment Agent 1'
|
||||||
agent_class: CounterModel
|
agent_class: CounterModel
|
||||||
state:
|
state:
|
||||||
times: 10
|
times: 10
|
||||||
network:
|
general_counters:
|
||||||
|
topology: default
|
||||||
distribution:
|
distribution:
|
||||||
- agent_class: CounterModel
|
- agent_class: CounterModel
|
||||||
weight: 1
|
weight: 1
|
||||||
state:
|
state:
|
||||||
state_id: 0
|
id: 0
|
||||||
|
times: 3
|
||||||
- agent_class: AggregatedCounter
|
- agent_class: AggregatedCounter
|
||||||
weight: 0.2
|
weight: 0.2
|
||||||
|
other_counters:
|
||||||
|
topology: another_graph
|
||||||
|
fixed:
|
||||||
|
- agent_class: CounterModel
|
||||||
|
id: 0
|
||||||
|
state:
|
||||||
|
times: 1
|
||||||
|
total: 0
|
||||||
|
- agent_class: CounterModel
|
||||||
|
id: 1
|
||||||
|
# If not specified, it will use the state set in the default
|
||||||
|
# state:
|
||||||
|
@ -6,5 +6,4 @@ pandas>=0.23
|
|||||||
SALib>=1.3
|
SALib>=1.3
|
||||||
Jinja2
|
Jinja2
|
||||||
Mesa>=0.8.9
|
Mesa>=0.8.9
|
||||||
tsih>=0.1.6
|
|
||||||
pydantic>=1.9
|
pydantic>=1.9
|
||||||
|
@ -7,9 +7,15 @@ class CounterModel(NetworkAgent):
|
|||||||
in each step and adds it to its state.
|
in each step and adds it to its state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
defaults = {
|
||||||
|
'times': 0,
|
||||||
|
'neighbors': 0,
|
||||||
|
'total': 0
|
||||||
|
}
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
total = len(list(self.get_agents()))
|
total = len(list(self.env.agents))
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self['times'] = self.get('times', 0) + 1
|
self['times'] = self.get('times', 0) + 1
|
||||||
self['neighbors'] = neighbors
|
self['neighbors'] = neighbors
|
||||||
@ -33,6 +39,6 @@ class AggregatedCounter(NetworkAgent):
|
|||||||
self['times'] += 1
|
self['times'] += 1
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self['neighbors'] += neighbors
|
self['neighbors'] += neighbors
|
||||||
total = len(list(self.get_agents()))
|
total = len(list(self.env.agents))
|
||||||
self['total'] += total
|
self['total'] += total
|
||||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
|
from collections.abc import Mapping, Set
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from itertools import islice
|
from itertools import islice, chain
|
||||||
import json
|
import json
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
from .. import serialization, utils, time
|
from mesa import Agent as MesaAgent
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
from tsih import Key
|
from .. import serialization, utils, time, config
|
||||||
|
|
||||||
from mesa import Agent
|
|
||||||
|
|
||||||
|
|
||||||
def as_node(agent):
|
def as_node(agent):
|
||||||
@ -24,7 +25,7 @@ IGNORED_FIELDS = ('model', 'logger')
|
|||||||
class DeadAgent(Exception):
|
class DeadAgent(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class BaseAgent(Agent):
|
class BaseAgent(MesaAgent):
|
||||||
"""
|
"""
|
||||||
A special type of Mesa Agent that:
|
A special type of Mesa Agent that:
|
||||||
|
|
||||||
@ -47,9 +48,8 @@ class BaseAgent(Agent):
|
|||||||
):
|
):
|
||||||
# Check for REQUIRED arguments
|
# Check for REQUIRED arguments
|
||||||
# Initialize agent parameters
|
# Initialize agent parameters
|
||||||
if isinstance(unique_id, Agent):
|
if isinstance(unique_id, MesaAgent):
|
||||||
raise Exception()
|
raise Exception()
|
||||||
self._saved = set()
|
|
||||||
super().__init__(unique_id=unique_id, model=model)
|
super().__init__(unique_id=unique_id, model=model)
|
||||||
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
|
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ class BaseAgent(Agent):
|
|||||||
self.alive = True
|
self.alive = True
|
||||||
|
|
||||||
self.interval = interval or self.get('interval', 1)
|
self.interval = interval or self.get('interval', 1)
|
||||||
self.logger = logging.getLogger(self.model.name).getChild(self.name)
|
self.logger = logging.getLogger(self.model.id).getChild(self.name)
|
||||||
|
|
||||||
if hasattr(self, 'level'):
|
if hasattr(self, 'level'):
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
@ -66,6 +66,7 @@ class BaseAgent(Agent):
|
|||||||
setattr(self, k, deepcopy(v))
|
setattr(self, k, deepcopy(v))
|
||||||
|
|
||||||
for (k, v) in kwargs.items():
|
for (k, v) in kwargs.items():
|
||||||
|
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
|
|
||||||
for (k, v) in getattr(self, 'defaults', {}).items():
|
for (k, v) in getattr(self, 'defaults', {}).items():
|
||||||
@ -107,23 +108,7 @@ class BaseAgent(Agent):
|
|||||||
def environment_params(self, value):
|
def environment_params(self, value):
|
||||||
self.model.environment_params = value
|
self.model.environment_params = value
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
|
||||||
if not key.startswith('_') and key not in IGNORED_FIELDS:
|
|
||||||
try:
|
|
||||||
k = Key(t_step=self.now,
|
|
||||||
dict_id=self.unique_id,
|
|
||||||
key=key)
|
|
||||||
self._saved.add(key)
|
|
||||||
self.model[k] = value
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
super().__setattr__(key, value)
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if isinstance(key, tuple):
|
|
||||||
key, t_step = key
|
|
||||||
k = Key(key=key, t_step=t_step, dict_id=self.unique_id)
|
|
||||||
return self.model[k]
|
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
@ -135,8 +120,11 @@ class BaseAgent(Agent):
|
|||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return (k for k in self.__dict__ if k[0] != '_')
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return ((k, getattr(self, k)) for k in self._saved)
|
return ((k, v) for (k, v) in self.__dict__.items() if k[0] != '_')
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
return self[key] if key in self else default
|
return self[key] if key in self else default
|
||||||
@ -174,22 +162,32 @@ class BaseAgent(Agent):
|
|||||||
extra['agent_name'] = self.name
|
extra['agent_name'] = self.name
|
||||||
return self.logger.log(level, message, extra=extra)
|
return self.logger.log(level, message, extra=extra)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def debug(self, *args, **kwargs):
|
def debug(self, *args, **kwargs):
|
||||||
return self.log(*args, level=logging.DEBUG, **kwargs)
|
return self.log(*args, level=logging.DEBUG, **kwargs)
|
||||||
|
|
||||||
def info(self, *args, **kwargs):
|
def info(self, *args, **kwargs):
|
||||||
return self.log(*args, level=logging.INFO, **kwargs)
|
return self.log(*args, level=logging.INFO, **kwargs)
|
||||||
|
|
||||||
|
# Alias
|
||||||
|
# Agent = BaseAgent
|
||||||
|
|
||||||
class NetworkAgent(BaseAgent):
|
class NetworkAgent(BaseAgent):
|
||||||
|
def __init__(self,
|
||||||
@property
|
*args,
|
||||||
def topology(self):
|
graph_name: str,
|
||||||
return self.model.G
|
node_id: int = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.graph_name = graph_name
|
||||||
|
self.topology = self.env.topologies[self.graph_name]
|
||||||
|
self.node_id = node_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def G(self):
|
def G(self):
|
||||||
return self.model.G
|
return self.model.topologies[self._topology]
|
||||||
|
|
||||||
def count_agents(self, **kwargs):
|
def count_agents(self, **kwargs):
|
||||||
return len(list(self.get_agents(**kwargs)))
|
return len(list(self.get_agents(**kwargs)))
|
||||||
@ -210,8 +208,7 @@ class NetworkAgent(BaseAgent):
|
|||||||
if limit_neighbors:
|
if limit_neighbors:
|
||||||
agents = self.topology.neighbors(self.unique_id)
|
agents = self.topology.neighbors(self.unique_id)
|
||||||
|
|
||||||
agents = self.model.get_agents(agents)
|
return self.model.agents(ids=agents, **kwargs)
|
||||||
return select(agents, **kwargs)
|
|
||||||
|
|
||||||
def subgraph(self, center=True, **kwargs):
|
def subgraph(self, center=True, **kwargs):
|
||||||
include = [self] if center else []
|
include = [self] if center else []
|
||||||
@ -229,7 +226,6 @@ class NetworkAgent(BaseAgent):
|
|||||||
|
|
||||||
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||||
|
|
||||||
|
|
||||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||||
node = as_node(node if node is not None else self)
|
node = as_node(node if node is not None else self)
|
||||||
@ -311,7 +307,7 @@ class MetaFSM(type):
|
|||||||
cls.states = states
|
cls.states = states
|
||||||
|
|
||||||
|
|
||||||
class FSM(NetworkAgent, metaclass=MetaFSM):
|
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(FSM, self).__init__(*args, **kwargs)
|
super(FSM, self).__init__(*args, **kwargs)
|
||||||
if not hasattr(self, 'state_id'):
|
if not hasattr(self, 'state_id'):
|
||||||
@ -537,32 +533,171 @@ def _definition_to_dict(definition, size=None, default_state=None):
|
|||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
|
||||||
def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
class AgentView(Mapping, Set):
|
||||||
|
"""A lazy-loaded list of agents.
|
||||||
|
"""
|
||||||
|
|
||||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
__slots__ = ("_agents",)
|
||||||
state_id = tuple([state_id])
|
|
||||||
if agent_type is not None:
|
|
||||||
try:
|
|
||||||
agent_type = tuple(agent_type)
|
|
||||||
except TypeError:
|
|
||||||
agent_type = tuple([agent_type])
|
|
||||||
|
|
||||||
f = agents
|
|
||||||
|
|
||||||
if ignore:
|
def __init__(self, agents):
|
||||||
f = filter(lambda x: x not in ignore, f)
|
self._agents = agents
|
||||||
|
|
||||||
if state_id is not None:
|
def __getstate__(self):
|
||||||
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
return {"_agents": self._agents}
|
||||||
|
|
||||||
if agent_type is not None:
|
def __setstate__(self, state):
|
||||||
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
self._agents = state["_agents"]
|
||||||
for k, v in kwargs.items():
|
|
||||||
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
|
||||||
|
|
||||||
if iterator:
|
# Mapping methods
|
||||||
return f
|
def __len__(self):
|
||||||
return f
|
return sum(len(x) for x in self._agents.values())
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(chain.from_iterable(g.values() for g in self._agents.values()))
|
||||||
|
|
||||||
|
def __getitem__(self, agent_id):
|
||||||
|
if isinstance(agent_id, slice):
|
||||||
|
raise ValueError(f"Slicing is not supported")
|
||||||
|
for group in self._agents.values():
|
||||||
|
if agent_id in group:
|
||||||
|
return group[agent_id]
|
||||||
|
raise ValueError(f"Agent {agent_id} not found")
|
||||||
|
|
||||||
|
def filter(self, ids=None, groups=None, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
||||||
|
|
||||||
|
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||||
|
state_id = tuple([state_id])
|
||||||
|
|
||||||
|
agents = self._agents
|
||||||
|
|
||||||
|
if groups:
|
||||||
|
agents = {(k,v) for (k, v) in agents.items() if k in groups}
|
||||||
|
|
||||||
|
if agent_type is not None:
|
||||||
|
try:
|
||||||
|
agent_type = tuple(agent_type)
|
||||||
|
except TypeError:
|
||||||
|
agent_type = tuple([agent_type])
|
||||||
|
|
||||||
|
if ids:
|
||||||
|
agents = (v[aid] for v in agents.values() for aid in ids if aid in v)
|
||||||
|
else:
|
||||||
|
agents = (a for v in agents.values() for a in v.values())
|
||||||
|
|
||||||
|
f = agents
|
||||||
|
if ignore:
|
||||||
|
f = filter(lambda x: x not in ignore, f)
|
||||||
|
|
||||||
|
if state_id is not None:
|
||||||
|
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
||||||
|
|
||||||
|
if agent_type is not None:
|
||||||
|
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
||||||
|
|
||||||
|
if iterator:
|
||||||
|
return f
|
||||||
|
return list(f)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self.filter(*args, **kwargs)
|
||||||
|
|
||||||
|
def __contains__(self, agent_id):
|
||||||
|
return any(agent_id in g for g in self._agents)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(list(a.id for a in self))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.__class__.__name__}({self})"
|
||||||
|
|
||||||
|
|
||||||
|
def from_config(cfg: Dict[str, config.AgentConfig], env):
|
||||||
|
'''
|
||||||
|
Agents are specified in groups.
|
||||||
|
Each group can be specified in two ways, either through a fixed list in which each item has
|
||||||
|
has the agent type, number of agents to create, and the other parameters, or through what we call
|
||||||
|
an `agent distribution`, which is similar but instead of number of agents, it specifies the weight
|
||||||
|
of each agent type.
|
||||||
|
'''
|
||||||
|
default = cfg.get('default', None)
|
||||||
|
return {k: _group_from_config(c, default=default, env=env) for (k, c) in cfg.items() if k is not 'default'}
|
||||||
|
|
||||||
|
|
||||||
|
def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfig, env):
|
||||||
|
agents = {}
|
||||||
|
if cfg.fixed is not None:
|
||||||
|
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
||||||
|
if cfg.distribution:
|
||||||
|
n = cfg.n or len(env.topologies[cfg.topology])
|
||||||
|
agents.update(_from_distro(cfg.distribution, n - len(agents),
|
||||||
|
topology=cfg.topology or default.topology,
|
||||||
|
default=default,
|
||||||
|
env=env))
|
||||||
|
return agents
|
||||||
|
|
||||||
|
|
||||||
|
def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig, env):
|
||||||
|
agents = {}
|
||||||
|
|
||||||
|
for fixed in lst:
|
||||||
|
agent_id = fixed.agent_id
|
||||||
|
if agent_id is None:
|
||||||
|
agent_id = env.next_id()
|
||||||
|
|
||||||
|
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
||||||
|
state = fixed.state.copy()
|
||||||
|
state.update(default.state)
|
||||||
|
agents[agent_id] = cls(unique_id=agent_id,
|
||||||
|
model=env,
|
||||||
|
graph_name=fixed.topology or topology or default.topology,
|
||||||
|
**state)
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
|
||||||
|
def _from_distro(distro: List[config.AgentDistro],
|
||||||
|
n: int,
|
||||||
|
topology: str,
|
||||||
|
default: config.SingleAgentConfig,
|
||||||
|
env):
|
||||||
|
|
||||||
|
agents = {}
|
||||||
|
|
||||||
|
if n is None:
|
||||||
|
if any(lambda dist: dist.n is None, distro):
|
||||||
|
raise ValueError('You must provide a total number of agents, or the number of each type')
|
||||||
|
n = sum(dist.n for dist in distro)
|
||||||
|
|
||||||
|
|
||||||
|
total = sum((dist.weight if dist.weight is not None else 1) for dist in distro)
|
||||||
|
thres = {}
|
||||||
|
last = 0
|
||||||
|
for i in sorted(distro, key=lambda x: x.weight):
|
||||||
|
|
||||||
|
cls = serialization.deserialize(i.agent_class or default.agent_class)
|
||||||
|
thres[(last, last + i.weight/total)] = (cls, i)
|
||||||
|
|
||||||
|
acc = 0
|
||||||
|
|
||||||
|
# using np.choice would be more efficient, but this allows us to use soil without
|
||||||
|
# numpy
|
||||||
|
for i in range(n):
|
||||||
|
r = random.random()
|
||||||
|
for (t, (cls, d)) in thres.items():
|
||||||
|
if r >= t[0] and r <= t[1]:
|
||||||
|
agent_id = d.agent_id
|
||||||
|
if agent_id is None:
|
||||||
|
agent_id = env.next_id()
|
||||||
|
|
||||||
|
state = d.state.copy()
|
||||||
|
state.update(default.state)
|
||||||
|
agents[agent_id] = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
|
||||||
|
break
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
|
||||||
from .BassModel import *
|
from .BassModel import *
|
||||||
|
@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra
|
|||||||
class General(BaseModel):
|
class General(BaseModel):
|
||||||
id: str = 'Unnamed Simulation'
|
id: str = 'Unnamed Simulation'
|
||||||
group: str = None
|
group: str = None
|
||||||
dir_path: str = None
|
dir_path: Optional[str] = None
|
||||||
num_trials: int = 1
|
num_trials: int = 1
|
||||||
max_time: float = 100
|
max_time: float = 100
|
||||||
interval: float = 1
|
interval: float = 1
|
||||||
@ -27,13 +27,13 @@ nodeId = int
|
|||||||
|
|
||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
id: nodeId
|
id: nodeId
|
||||||
state: Dict[str, Any]
|
state: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
source: nodeId
|
source: nodeId
|
||||||
target: nodeId
|
target: nodeId
|
||||||
value: float = 1
|
value: Optional[float] = 1
|
||||||
|
|
||||||
|
|
||||||
class Topology(BaseModel):
|
class Topology(BaseModel):
|
||||||
@ -75,46 +75,62 @@ class EnvConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SingleAgentConfig(BaseModel):
|
class SingleAgentConfig(BaseModel):
|
||||||
agent_class: Union[Type, str] = 'soil.Agent'
|
agent_class: Optional[Union[Type, str]] = None
|
||||||
agent_id: Optional[Union[str, int]] = None
|
agent_id: Optional[Union[str, int]] = None
|
||||||
params: Dict[str, Any] = {}
|
topology: Optional[str] = None
|
||||||
state: Dict[str, Any] = {}
|
state: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
class FixedAgentConfig(SingleAgentConfig):
|
||||||
class AgentDistro(SingleAgentConfig):
|
n: Optional[int] = 1
|
||||||
weight: Optional[float] = None
|
|
||||||
n: Optional[int] = None
|
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_all(cls, values):
|
def validate_all(cls, values):
|
||||||
if 'weight' in values and 'count' in values:
|
if 'agent_id' in values and values.get('n', 1) > 1:
|
||||||
raise ValueError("You may either specify a weight in the distribution or an agent count")
|
raise ValueError("An agent_id can only be provided when there is only one agent")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class AgentDistro(SingleAgentConfig):
|
||||||
|
weight: Optional[float] = 1
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(SingleAgentConfig):
|
class AgentConfig(SingleAgentConfig):
|
||||||
n: Optional[int] = None
|
n: Optional[int] = None
|
||||||
|
topology: Optional[str] = None
|
||||||
distribution: Optional[List[AgentDistro]] = None
|
distribution: Optional[List[AgentDistro]] = None
|
||||||
fixed: Optional[List[SingleAgentConfig]] = None
|
fixed: Optional[List[FixedAgentConfig]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default():
|
def default():
|
||||||
return AgentConfig()
|
return AgentConfig()
|
||||||
|
|
||||||
|
@root_validator
|
||||||
|
def validate_all(cls, values):
|
||||||
|
if 'distribution' in values and ('n' not in values and 'topology' not in values):
|
||||||
|
raise ValueError("You need to provide the number of agents or a topology to extract the value from.")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel, extra=Extra.forbid):
|
class Config(BaseModel, extra=Extra.forbid):
|
||||||
|
version: Optional[str] = '1'
|
||||||
general: General = General.default()
|
general: General = General.default()
|
||||||
network: Optional[NetConfig] = None
|
topologies: Optional[Dict[str, NetConfig]] = {}
|
||||||
environment: EnvConfig = EnvConfig.default()
|
environment: EnvConfig = EnvConfig.default()
|
||||||
agents: Dict[str, AgentConfig] = {}
|
agents: Optional[Dict[str, AgentConfig]] = {}
|
||||||
|
|
||||||
|
|
||||||
def convert_old(old):
|
def convert_old(old, strict=True):
|
||||||
'''
|
'''
|
||||||
Try to convert old style configs into the new format.
|
Try to convert old style configs into the new format.
|
||||||
|
|
||||||
This is still a work in progress and might not work in many cases.
|
This is still a work in progress and might not work in many cases.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
# TODO: translate states
|
||||||
|
|
||||||
|
if strict and old.get('states', {}):
|
||||||
|
raise ValueError('Custom (i.e., manual) agent states cannot be translated to v2 configuration files. Please, convert your configuration file to the new format.')
|
||||||
|
|
||||||
new = {}
|
new = {}
|
||||||
|
|
||||||
|
|
||||||
@ -129,7 +145,10 @@ def convert_old(old):
|
|||||||
if k in old:
|
if k in old:
|
||||||
general[k] = old[k]
|
general[k] = old[k]
|
||||||
|
|
||||||
network = {'group': 'network'}
|
if 'name' in old:
|
||||||
|
general['id'] = old['name']
|
||||||
|
|
||||||
|
network = {}
|
||||||
|
|
||||||
|
|
||||||
if 'network_params' in old and old['network_params']:
|
if 'network_params' in old and old['network_params']:
|
||||||
@ -143,9 +162,6 @@ def convert_old(old):
|
|||||||
network['topology'] = old['topology']
|
network['topology'] = old['topology']
|
||||||
|
|
||||||
agents = {
|
agents = {
|
||||||
'environment': {
|
|
||||||
'fixed': []
|
|
||||||
},
|
|
||||||
'network': {},
|
'network': {},
|
||||||
'default': {},
|
'default': {},
|
||||||
}
|
}
|
||||||
@ -164,10 +180,31 @@ def convert_old(old):
|
|||||||
return newagent
|
return newagent
|
||||||
|
|
||||||
for agent in old.get('environment_agents', []):
|
for agent in old.get('environment_agents', []):
|
||||||
agents['environment']['fixed'].append(updated_agent(agent))
|
agents['environment'] = {'distribution': [], 'fixed': []}
|
||||||
|
if 'agent_id' not in agent:
|
||||||
|
agents['environment']['distribution'].append(updated_agent(agent))
|
||||||
|
else:
|
||||||
|
agents['environment']['fixed'].append(updated_agent(agent))
|
||||||
|
|
||||||
for agent in old.get('network_agents', []):
|
by_weight = []
|
||||||
agents['network'].setdefault('distribution', []).append(updated_agent(agent))
|
fixed = []
|
||||||
|
|
||||||
|
if 'network_agents' in old:
|
||||||
|
agents['network']['topology'] = 'default'
|
||||||
|
|
||||||
|
for agent in old['network_agents']:
|
||||||
|
agent = updated_agent(agent)
|
||||||
|
if 'agent_id' in agent:
|
||||||
|
fixed.append(agent)
|
||||||
|
else:
|
||||||
|
by_weight.append(agent)
|
||||||
|
|
||||||
|
if 'agent_type' in old and (not fixed and not by_weight):
|
||||||
|
agents['network']['topology'] = 'default'
|
||||||
|
by_weight = [{'agent_type': old['agent_type']}]
|
||||||
|
|
||||||
|
agents['network']['fixed'] = fixed
|
||||||
|
agents['network']['distribution'] = by_weight
|
||||||
|
|
||||||
environment = {'params': {}}
|
environment = {'params': {}}
|
||||||
if 'environment_class' in old:
|
if 'environment_class' in old:
|
||||||
@ -176,8 +213,8 @@ def convert_old(old):
|
|||||||
for (k, v) in old.get('environment_params', {}).items():
|
for (k, v) in old.get('environment_params', {}).items():
|
||||||
environment['params'][k] = v
|
environment['params'][k] = v
|
||||||
|
|
||||||
|
return Config(version='2',
|
||||||
return Config(general=general,
|
general=general,
|
||||||
network=network,
|
topologies={'default': network},
|
||||||
environment=environment,
|
environment=environment,
|
||||||
agents=agents)
|
agents=agents)
|
||||||
|
@ -3,6 +3,10 @@ import os
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
from collections import namedtuple
|
||||||
from time import time as current_time
|
from time import time as current_time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
@ -12,9 +16,11 @@ import networkx as nx
|
|||||||
|
|
||||||
from mesa import Model
|
from mesa import Model
|
||||||
|
|
||||||
from tsih import Record
|
from . import serialization, agents, analysis, utils, time, config, network
|
||||||
|
|
||||||
|
|
||||||
|
Record = namedtuple('Record', 'dict_id t_step key value')
|
||||||
|
|
||||||
from . import serialization, agents, analysis, utils, time, config
|
|
||||||
|
|
||||||
class Environment(Model):
|
class Environment(Model):
|
||||||
"""
|
"""
|
||||||
@ -28,15 +34,17 @@ class Environment(Model):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
env_id,
|
env_id='unnamed_env',
|
||||||
seed='default',
|
seed='default',
|
||||||
schedule=None,
|
schedule=None,
|
||||||
env_params=None,
|
|
||||||
dir_path=None,
|
dir_path=None,
|
||||||
**kwargs):
|
interval=1,
|
||||||
|
agents: Dict[str, config.AgentConfig] = {},
|
||||||
|
topologies: Dict[str, config.NetConfig] = {},
|
||||||
|
**env_params):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.current_id = -1
|
||||||
|
|
||||||
self.seed = '{}_{}'.format(seed, env_id)
|
self.seed = '{}_{}'.format(seed, env_id)
|
||||||
self.id = env_id
|
self.id = env_id
|
||||||
@ -51,25 +59,28 @@ class Environment(Model):
|
|||||||
|
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
if isinstance(states, list):
|
|
||||||
states = dict(enumerate(states))
|
|
||||||
self.states = deepcopy(states) if states else {}
|
|
||||||
self.default_state = deepcopy(default_state) or {}
|
|
||||||
|
|
||||||
|
|
||||||
self.set_topology(topology=topology,
|
|
||||||
network_params=network_params)
|
|
||||||
|
|
||||||
|
self.topologies = {}
|
||||||
|
for (name, cfg) in topologies.items():
|
||||||
|
self.set_topology(cfg=cfg,
|
||||||
|
graph=name)
|
||||||
self.agents = agents or {}
|
self.agents = agents or {}
|
||||||
|
|
||||||
self.env_params = env_params or {}
|
self.env_params = env_params or {}
|
||||||
self.env_params.update(kwargs)
|
|
||||||
|
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self['SEED'] = seed
|
self['SEED'] = seed
|
||||||
|
|
||||||
|
self.logger = utils.logger.getChild(self.id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def topology(self):
|
||||||
|
return self.topologies['default']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def network_agents(self):
|
||||||
|
yield from self.agents(agent_type=agents.NetworkAgent, iterator=True)
|
||||||
|
|
||||||
self.logger = utils.logger.getChild(self.name)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(conf: config.Config, trial_id, **kwargs) -> Environment:
|
def from_config(conf: config.Config, trial_id, **kwargs) -> Environment:
|
||||||
@ -92,39 +103,30 @@ class Environment(Model):
|
|||||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
||||||
|
|
||||||
|
|
||||||
def set_topology(self, topology, network_params=None, dir_path=None):
|
def set_topology(self, cfg=None, dir_path=None, graph='default'):
|
||||||
if topology is None:
|
self.topologies[graph] = network.from_config(cfg, dir_path=dir_path)
|
||||||
network_params = network_params or {}
|
|
||||||
topology = serialization.load_network(network_params,
|
|
||||||
dir_path=dir_path or self.dir_path)
|
|
||||||
if not topology:
|
|
||||||
topology = nx.Graph()
|
|
||||||
self.G = nx.Graph(topology)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agents(self):
|
def agents(self):
|
||||||
for agents in self.agents.values():
|
return agents.AgentView(self._agents)
|
||||||
yield from agents
|
|
||||||
|
|
||||||
@agents.setter
|
@agents.setter
|
||||||
def agents(self, agents):
|
def agents(self, agents_def: Dict[str, config.AgentConfig]):
|
||||||
self.agents = {}
|
self._agents = agents.from_config(agents_def, env=self)
|
||||||
|
for d in self._agents.values():
|
||||||
|
for a in d.values():
|
||||||
|
self.schedule.add(a)
|
||||||
|
|
||||||
for (k, v) in agents.items():
|
|
||||||
self.agents[k] = agents.from_config(v)
|
|
||||||
for agent in self.agents.get('network', []):
|
|
||||||
node = self.G.nodes[agent.unique_id]
|
|
||||||
node['agent'] = agent
|
|
||||||
|
|
||||||
@property
|
# @property
|
||||||
def network_agents(self):
|
# def network_agents(self):
|
||||||
for i in self.G.nodes():
|
# for i in self.G.nodes():
|
||||||
node = self.G.nodes[i]
|
# node = self.G.nodes[i]
|
||||||
if 'agent' in node:
|
# if 'agent' in node:
|
||||||
yield node['agent']
|
# yield node['agent']
|
||||||
|
|
||||||
def init_agent(self, agent_id, agent_definitions):
|
def init_agent(self, agent_id, agent_definitions, graph='default'):
|
||||||
node = self.G.nodes[agent_id]
|
node = self.topologies[graph].nodes[agent_id]
|
||||||
init = False
|
init = False
|
||||||
state = dict(node)
|
state = dict(node)
|
||||||
|
|
||||||
@ -145,8 +147,8 @@ class Environment(Model):
|
|||||||
return
|
return
|
||||||
return self.set_agent(agent_id, agent_type, state)
|
return self.set_agent(agent_id, agent_type, state)
|
||||||
|
|
||||||
def set_agent(self, agent_id, agent_type, state=None):
|
def set_agent(self, agent_id, agent_type, state=None, graph='default'):
|
||||||
node = self.G.nodes[agent_id]
|
node = self.topologies[graph].nodes[agent_id]
|
||||||
defstate = deepcopy(self.default_state) or {}
|
defstate = deepcopy(self.default_state) or {}
|
||||||
defstate.update(self.states.get(agent_id, {}))
|
defstate.update(self.states.get(agent_id, {}))
|
||||||
defstate.update(node.get('state', {}))
|
defstate.update(node.get('state', {}))
|
||||||
@ -166,20 +168,20 @@ class Environment(Model):
|
|||||||
self.schedule.add(a)
|
self.schedule.add(a)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def add_node(self, agent_type, state=None):
|
def add_node(self, agent_type, state=None, graph='default'):
|
||||||
agent_id = int(len(self.G.nodes()))
|
agent_id = int(len(self.topologies[graph].nodes()))
|
||||||
self.G.add_node(agent_id)
|
self.topologies[graph].add_node(agent_id)
|
||||||
a = self.set_agent(agent_id, agent_type, state)
|
a = self.set_agent(agent_id, agent_type, state, graph=graph)
|
||||||
a['visible'] = True
|
a['visible'] = True
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def add_edge(self, agent1, agent2, start=None, **attrs):
|
def add_edge(self, agent1, agent2, start=None, graph='default', **attrs):
|
||||||
if hasattr(agent1, 'id'):
|
if hasattr(agent1, 'id'):
|
||||||
agent1 = agent1.id
|
agent1 = agent1.id
|
||||||
if hasattr(agent2, 'id'):
|
if hasattr(agent2, 'id'):
|
||||||
agent2 = agent2.id
|
agent2 = agent2.id
|
||||||
start = start or self.now
|
start = start or self.now
|
||||||
return self.G.add_edge(agent1, agent2, **attrs)
|
return self.topologies[graph].add_edge(agent1, agent2, **attrs)
|
||||||
|
|
||||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||||
if not self.logger.isEnabledFor(level):
|
if not self.logger.isEnabledFor(level):
|
||||||
@ -190,7 +192,7 @@ class Environment(Model):
|
|||||||
message += " {k}={v} ".format(k, v)
|
message += " {k}={v} ".format(k, v)
|
||||||
extra = {}
|
extra = {}
|
||||||
extra['now'] = self.now
|
extra['now'] = self.now
|
||||||
extra['unique_id'] = self.name
|
extra['unique_id'] = self.id
|
||||||
return self.logger.log(level, message, extra=extra)
|
return self.logger.log(level, message, extra=extra)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
@ -207,30 +209,6 @@ class Environment(Model):
|
|||||||
self.step()
|
self.step()
|
||||||
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
||||||
self.schedule.time = until
|
self.schedule.time = until
|
||||||
self._history.flush_cache()
|
|
||||||
|
|
||||||
def _save_state(self, now=None):
|
|
||||||
serialization.logger.debug('Saving state @{}'.format(self.now))
|
|
||||||
self._history.save_records(self.state_to_tuples(now=now))
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
if isinstance(key, tuple):
|
|
||||||
self._history.flush_cache()
|
|
||||||
return self._history[key]
|
|
||||||
|
|
||||||
return self.environment_params[key]
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
if isinstance(key, tuple):
|
|
||||||
k = Key(*key)
|
|
||||||
self._history.save_record(*k,
|
|
||||||
value=value)
|
|
||||||
return
|
|
||||||
self.environment_params[key] = value
|
|
||||||
self._history.save_record(dict_id='env',
|
|
||||||
t_step=self.now,
|
|
||||||
key=key,
|
|
||||||
value=value)
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
return key in self.env_params
|
return key in self.env_params
|
||||||
@ -248,14 +226,6 @@ class Environment(Model):
|
|||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
return self.env_params.__setitem__(key, value)
|
return self.env_params.__setitem__(key, value)
|
||||||
|
|
||||||
def get_agent(self, agent_id):
|
|
||||||
return self.G.nodes[agent_id]['agent']
|
|
||||||
|
|
||||||
def get_agents(self, nodes=None):
|
|
||||||
if nodes is None:
|
|
||||||
return self.agents
|
|
||||||
return (self.G.nodes[i]['agent'] for i in nodes)
|
|
||||||
|
|
||||||
def _agent_to_tuples(self, agent, now=None):
|
def _agent_to_tuples(self, agent, now=None):
|
||||||
if now is None:
|
if now is None:
|
||||||
now = self.now
|
now = self.now
|
||||||
@ -270,7 +240,7 @@ class Environment(Model):
|
|||||||
now = self.now
|
now = self.now
|
||||||
|
|
||||||
if agent_id:
|
if agent_id:
|
||||||
agent = self.get_agent(agent_id)
|
agent = self.agents[agent_id]
|
||||||
yield from self._agent_to_tuples(agent, now)
|
yield from self._agent_to_tuples(agent, now)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -283,4 +253,5 @@ class Environment(Model):
|
|||||||
yield from self._agent_to_tuples(agent, now)
|
yield from self._agent_to_tuples(agent, now)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SoilEnvironment = Environment
|
SoilEnvironment = Environment
|
||||||
|
@ -16,39 +16,39 @@ from jinja2 import Template
|
|||||||
logger = logging.getLogger('soil')
|
logger = logging.getLogger('soil')
|
||||||
|
|
||||||
|
|
||||||
def load_network(network_params, dir_path=None):
|
# def load_network(network_params, dir_path=None):
|
||||||
G = nx.Graph()
|
# G = nx.Graph()
|
||||||
|
|
||||||
if not network_params:
|
# if not network_params:
|
||||||
return G
|
# return G
|
||||||
|
|
||||||
if 'path' in network_params:
|
# if 'path' in network_params:
|
||||||
path = network_params['path']
|
# path = network_params['path']
|
||||||
if dir_path and not os.path.isabs(path):
|
# if dir_path and not os.path.isabs(path):
|
||||||
path = os.path.join(dir_path, path)
|
# path = os.path.join(dir_path, path)
|
||||||
extension = os.path.splitext(path)[1][1:]
|
# extension = os.path.splitext(path)[1][1:]
|
||||||
kwargs = {}
|
# kwargs = {}
|
||||||
if extension == 'gexf':
|
# if extension == 'gexf':
|
||||||
kwargs['version'] = '1.2draft'
|
# kwargs['version'] = '1.2draft'
|
||||||
kwargs['node_type'] = int
|
# kwargs['node_type'] = int
|
||||||
try:
|
# try:
|
||||||
method = getattr(nx.readwrite, 'read_' + extension)
|
# method = getattr(nx.readwrite, 'read_' + extension)
|
||||||
except AttributeError:
|
# except AttributeError:
|
||||||
raise AttributeError('Unknown format')
|
# raise AttributeError('Unknown format')
|
||||||
G = method(path, **kwargs)
|
# G = method(path, **kwargs)
|
||||||
|
|
||||||
elif 'generator' in network_params:
|
# elif 'generator' in network_params:
|
||||||
net_args = network_params.copy()
|
# net_args = network_params.copy()
|
||||||
net_gen = net_args.pop('generator')
|
# net_gen = net_args.pop('generator')
|
||||||
|
|
||||||
if dir_path not in sys.path:
|
# if dir_path not in sys.path:
|
||||||
sys.path.append(dir_path)
|
# sys.path.append(dir_path)
|
||||||
|
|
||||||
method = deserializer(net_gen,
|
# method = deserializer(net_gen,
|
||||||
known_modules=['networkx.generators',])
|
# known_modules=['networkx.generators',])
|
||||||
G = method(**net_args)
|
# G = method(**net_args)
|
||||||
|
|
||||||
return G
|
# return G
|
||||||
|
|
||||||
|
|
||||||
def load_file(infile):
|
def load_file(infile):
|
||||||
@ -122,8 +122,8 @@ def load_files(*patterns, **kwargs):
|
|||||||
for i in glob(pattern, **kwargs):
|
for i in glob(pattern, **kwargs):
|
||||||
for config in load_file(i):
|
for config in load_file(i):
|
||||||
path = os.path.abspath(i)
|
path = os.path.abspath(i)
|
||||||
if 'dir_path' not in config:
|
if 'general' in config and 'dir_path' not in config['general']:
|
||||||
config['dir_path'] = os.path.dirname(path)
|
config['general']['dir_path'] = os.path.dirname(path)
|
||||||
yield config, path
|
yield config, path
|
||||||
|
|
||||||
|
|
||||||
|
@ -96,7 +96,7 @@ class Simulation:
|
|||||||
stat.sim_start()
|
stat.sim_start()
|
||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.start()
|
exporter.sim_start()
|
||||||
|
|
||||||
for env in self._run_sync_or_async(parallel=parallel,
|
for env in self._run_sync_or_async(parallel=parallel,
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
@ -107,7 +107,7 @@ class Simulation:
|
|||||||
|
|
||||||
collected = list(stat.trial_end(env) for stat in stats)
|
collected = list(stat.trial_end(env) for stat in stats)
|
||||||
|
|
||||||
saved = self._update_stats(collected, t_step=env.now, trial_id=env.name)
|
saved = self._update_stats(collected, t_step=env.now, trial_id=env.id)
|
||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.trial_end(env, saved)
|
exporter.trial_end(env, saved)
|
||||||
@ -117,6 +117,9 @@ class Simulation:
|
|||||||
collected = list(stat.end() for stat in stats)
|
collected = list(stat.end() for stat in stats)
|
||||||
saved = self._update_stats(collected)
|
saved = self._update_stats(collected)
|
||||||
|
|
||||||
|
for stat in stats:
|
||||||
|
stat.sim_end()
|
||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.sim_end(saved)
|
exporter.sim_end(saved)
|
||||||
|
|
||||||
@ -131,24 +134,24 @@ class Simulation:
|
|||||||
|
|
||||||
def get_env(self, trial_id=0, **kwargs):
|
def get_env(self, trial_id=0, **kwargs):
|
||||||
'''Create an environment for a trial of the simulation'''
|
'''Create an environment for a trial of the simulation'''
|
||||||
opts = self.environment_params.copy()
|
# opts = self.environment_params.copy()
|
||||||
opts.update({
|
# opts.update({
|
||||||
'name': '{}_trial_{}'.format(self.name, trial_id),
|
# 'name': '{}_trial_{}'.format(self.name, trial_id),
|
||||||
'topology': self.topology.copy(),
|
# 'topology': self.topology.copy(),
|
||||||
'network_params': self.network_params,
|
# 'network_params': self.network_params,
|
||||||
'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
# 'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||||
'initial_time': 0,
|
# 'initial_time': 0,
|
||||||
'interval': self.interval,
|
# 'interval': self.interval,
|
||||||
'network_agents': self.network_agents,
|
# 'network_agents': self.network_agents,
|
||||||
'initial_time': 0,
|
# 'initial_time': 0,
|
||||||
'states': self.states,
|
# 'states': self.states,
|
||||||
'dir_path': self.dir_path,
|
# 'dir_path': self.dir_path,
|
||||||
'default_state': self.default_state,
|
# 'default_state': self.default_state,
|
||||||
'history': bool(self._history),
|
# 'history': bool(self._history),
|
||||||
'environment_agents': self.environment_agents,
|
# 'environment_agents': self.environment_agents,
|
||||||
})
|
# })
|
||||||
opts.update(kwargs)
|
# opts.update(kwargs)
|
||||||
env = self.environment_class(**opts)
|
env = Environment.from_config(self.config, trial_id=trial_id, **kwargs)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def run_trial(self, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
def run_trial(self, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
||||||
@ -162,7 +165,7 @@ class Simulation:
|
|||||||
# Set-up trial environment and graph
|
# Set-up trial environment and graph
|
||||||
until = until or self.config.general.max_time
|
until = until or self.config.general.max_time
|
||||||
|
|
||||||
env = Environment.from_config(self.config, trial_id=trial_id)
|
env = self.get_env(trial_id, **opts)
|
||||||
# Set up agents on nodes
|
# Set up agents on nodes
|
||||||
with utils.timer('Simulation {} trial {}'.format(self.config.general.id, trial_id)):
|
with utils.timer('Simulation {} trial {}'.format(self.config.general.id, trial_id)):
|
||||||
env.run(until)
|
env.run(until)
|
||||||
@ -181,21 +184,31 @@ class Simulation:
|
|||||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||||
return ex
|
return ex
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return self.config.dict()
|
||||||
|
|
||||||
|
def to_yaml(self):
|
||||||
|
return yaml.dump(self.config.dict())
|
||||||
|
|
||||||
|
|
||||||
def all_from_config(config):
|
def all_from_config(config):
|
||||||
configs = list(serialization.load_config(config))
|
configs = list(serialization.load_config(config))
|
||||||
for config, _ in configs:
|
for config, path in configs:
|
||||||
sim = Simulation(**config)
|
if config.get('version', '1') == '1':
|
||||||
|
config = convert_old(config)
|
||||||
|
if not isinstance(config, Config):
|
||||||
|
config = Config(**config)
|
||||||
|
if not config.general.dir_path:
|
||||||
|
config.general.dir_path = os.path.dirname(path)
|
||||||
|
sim = Simulation(config=config)
|
||||||
yield sim
|
yield sim
|
||||||
|
|
||||||
|
|
||||||
def from_config(conf_or_path):
|
def from_config(conf_or_path):
|
||||||
config = list(serialization.load_config(conf_or_path))
|
lst = list(all_from_config(conf_or_path))
|
||||||
if len(config) > 1:
|
if len(lst) > 1:
|
||||||
raise AttributeError('Provide only one configuration')
|
raise AttributeError('Provide only one configuration')
|
||||||
config = config[0][0]
|
return lst[0]
|
||||||
sim = Simulation(**config)
|
|
||||||
return sim
|
|
||||||
|
|
||||||
def from_old_config(conf_or_path):
|
def from_old_config(conf_or_path):
|
||||||
config = list(serialization.load_config(conf_or_path))
|
config = list(serialization.load_config(conf_or_path))
|
||||||
@ -206,13 +219,7 @@ def from_old_config(conf_or_path):
|
|||||||
|
|
||||||
|
|
||||||
def run_from_config(*configs, **kwargs):
|
def run_from_config(*configs, **kwargs):
|
||||||
for config_def in configs:
|
for sim in all_from_config(configs):
|
||||||
# logger.info("Found {} config(s)".format(len(ls)))
|
name = config.general.id
|
||||||
for config, path in serialization.load_config(config_def):
|
logger.info("Using config(s): {name}".format(name=name))
|
||||||
name = config.general.id
|
sim.run_simulation(**kwargs)
|
||||||
logger.info("Using config(s): {name}".format(name=name))
|
|
||||||
|
|
||||||
dir_path = config.general.dir_path or os.path.dirname(path)
|
|
||||||
sim = Simulation(dir_path=dir_path,
|
|
||||||
**config)
|
|
||||||
sim.run_simulation(**kwargs)
|
|
||||||
|
@ -3,7 +3,7 @@ from queue import Empty
|
|||||||
from heapq import heappush, heappop
|
from heapq import heappush, heappop
|
||||||
import math
|
import math
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
from mesa import Agent
|
from mesa import Agent as MesaAgent
|
||||||
|
|
||||||
|
|
||||||
INFINITY = float('inf')
|
INFINITY = float('inf')
|
||||||
@ -41,7 +41,7 @@ class TimedActivation(BaseScheduler):
|
|||||||
self._queue = []
|
self._queue = []
|
||||||
self.next_time = 0
|
self.next_time = 0
|
||||||
|
|
||||||
def add(self, agent: Agent):
|
def add(self, agent: MesaAgent):
|
||||||
if agent.unique_id not in self._agents:
|
if agent.unique_id not in self._agents:
|
||||||
heappush(self._queue, (self.time, agent.unique_id))
|
heappush(self._queue, (self.time, agent.unique_id))
|
||||||
super().add(agent)
|
super().add(agent)
|
||||||
|
@ -2,7 +2,7 @@ from unittest import TestCase
|
|||||||
import os
|
import os
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from soil import serialization, config
|
from soil import simulation, serialization, config, network
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
EXAMPLES = join(ROOT, '..', 'examples')
|
||||||
@ -13,18 +13,58 @@ FORCE_TESTS = os.environ.get('FORCE_TESTS', '')
|
|||||||
class TestConfig(TestCase):
|
class TestConfig(TestCase):
|
||||||
|
|
||||||
def test_conversion(self):
|
def test_conversion(self):
|
||||||
new = serialization.load_file(join(EXAMPLES, "complete.yml"))[0]
|
new = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||||
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
||||||
converted = config.convert_old(old).dict(skip_defaults=True)
|
converted_defaults = config.convert_old(old, strict=False)
|
||||||
for (k, v) in new.items():
|
converted = converted_defaults.dict(skip_defaults=True)
|
||||||
assert v == converted[k]
|
def isequal(old, new):
|
||||||
|
if isinstance(old, dict):
|
||||||
|
for (k, v) in old.items():
|
||||||
|
isequal(old[k], new[k])
|
||||||
|
return
|
||||||
|
assert old == new
|
||||||
|
|
||||||
|
isequal(new, converted)
|
||||||
|
|
||||||
|
def test_topology_config(self):
|
||||||
|
netconfig = config.NetConfig(**{
|
||||||
|
'path': join(ROOT, 'test.gexf')
|
||||||
|
})
|
||||||
|
net = network.from_config(netconfig, dir_path=ROOT)
|
||||||
|
assert len(net.nodes) == 2
|
||||||
|
assert len(net.edges) == 1
|
||||||
|
|
||||||
|
def test_env_from_config(self):
|
||||||
|
"""
|
||||||
|
Simple configuration that tests that the graph is loaded, and that
|
||||||
|
network agents are initialized properly.
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
'name': 'CounterAgent',
|
||||||
|
'network_params': {
|
||||||
|
'path': join(ROOT, 'test.gexf')
|
||||||
|
},
|
||||||
|
'agent_type': 'CounterModel',
|
||||||
|
# 'states': [{'times': 10}, {'times': 20}],
|
||||||
|
'max_time': 2,
|
||||||
|
'dry_run': True,
|
||||||
|
'num_trials': 1,
|
||||||
|
'environment_params': {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s = simulation.from_old_config(config)
|
||||||
|
env = s.get_env()
|
||||||
|
assert len(env.topologies['default'].nodes) == 2
|
||||||
|
assert len(env.topologies['default'].edges) == 1
|
||||||
|
assert len(env.agents) == 2
|
||||||
|
assert env.agents[0].topology == env.topologies['default']
|
||||||
|
|
||||||
|
|
||||||
def make_example_test(path, cfg):
|
def make_example_test(path, cfg):
|
||||||
def wrapped(self):
|
def wrapped(self):
|
||||||
root = os.getcwd()
|
root = os.getcwd()
|
||||||
s = config.Config(**cfg)
|
print(path)
|
||||||
import pdb;pdb.set_trace()
|
s = simulation.from_config(cfg)
|
||||||
# for s in simulation.all_from_config(path):
|
# for s in simulation.all_from_config(path):
|
||||||
# iterations = s.config.max_time * s.config.num_trials
|
# iterations = s.config.max_time * s.config.num_trials
|
||||||
# if iterations > 1000:
|
# if iterations > 1000:
|
||||||
|
@ -18,17 +18,17 @@ class Dummy(exporters.Exporter):
|
|||||||
called_trial = 0
|
called_trial = 0
|
||||||
called_end = 0
|
called_end = 0
|
||||||
|
|
||||||
def start(self):
|
def sim_start(self):
|
||||||
self.__class__.called_start += 1
|
self.__class__.called_start += 1
|
||||||
self.__class__.started = True
|
self.__class__.started = True
|
||||||
|
|
||||||
def trial(self, env, stats):
|
def trial_end(self, env, stats):
|
||||||
assert env
|
assert env
|
||||||
self.__class__.trials += 1
|
self.__class__.trials += 1
|
||||||
self.__class__.total_time += env.now
|
self.__class__.total_time += env.now
|
||||||
self.__class__.called_trial += 1
|
self.__class__.called_trial += 1
|
||||||
|
|
||||||
def end(self, stats):
|
def sim_end(self, stats):
|
||||||
self.__class__.ended = True
|
self.__class__.ended = True
|
||||||
self.__class__.called_end += 1
|
self.__class__.called_end += 1
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import networkx as nx
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from soil import (simulation, Environment, agents, serialization,
|
from soil import (simulation, Environment, agents, network, serialization,
|
||||||
utils)
|
utils)
|
||||||
from soil.time import Delta
|
from soil.time import Delta
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ ROOT = os.path.abspath(os.path.dirname(__file__))
|
|||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
EXAMPLES = join(ROOT, '..', 'examples')
|
||||||
|
|
||||||
|
|
||||||
class CustomAgent(agents.FSM):
|
class CustomAgent(agents.FSM, agents.NetworkAgent):
|
||||||
@agents.default_state
|
@agents.default_state
|
||||||
@agents.state
|
@agents.state
|
||||||
def normal(self):
|
def normal(self):
|
||||||
@ -39,7 +39,7 @@ class TestMain(TestCase):
|
|||||||
'path': join(ROOT, 'test.gexf')
|
'path': join(ROOT, 'test.gexf')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
G = serialization.load_network(config['network_params'])
|
G = network.from_config(config['network_params'])
|
||||||
assert G
|
assert G
|
||||||
assert len(G) == 2
|
assert len(G) == 2
|
||||||
with self.assertRaises(AttributeError):
|
with self.assertRaises(AttributeError):
|
||||||
@ -48,7 +48,7 @@ class TestMain(TestCase):
|
|||||||
'path': join(ROOT, 'unknown.extension')
|
'path': join(ROOT, 'unknown.extension')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
G = serialization.load_network(config['network_params'])
|
G = network.from_config(config['network_params'])
|
||||||
print(G)
|
print(G)
|
||||||
|
|
||||||
def test_generate_barabasi(self):
|
def test_generate_barabasi(self):
|
||||||
@ -56,16 +56,16 @@ class TestMain(TestCase):
|
|||||||
If no path is given, a generator and network parameters
|
If no path is given, a generator and network parameters
|
||||||
should be used to generate a network
|
should be used to generate a network
|
||||||
"""
|
"""
|
||||||
config = {
|
cfg = {
|
||||||
'network_params': {
|
'params': {
|
||||||
'generator': 'barabasi_albert_graph'
|
'generator': 'barabasi_albert_graph'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(Exception):
|
||||||
G = serialization.load_network(config['network_params'])
|
G = network.from_config(cfg)
|
||||||
config['network_params']['n'] = 100
|
cfg['params']['n'] = 100
|
||||||
config['network_params']['m'] = 10
|
cfg['params']['m'] = 10
|
||||||
G = serialization.load_network(config['network_params'])
|
G = network.from_config(cfg)
|
||||||
assert len(G) == 100
|
assert len(G) == 100
|
||||||
|
|
||||||
def test_empty_simulation(self):
|
def test_empty_simulation(self):
|
||||||
@ -103,28 +103,43 @@ class TestMain(TestCase):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_old_config(config)
|
s = simulation.from_old_config(config)
|
||||||
|
|
||||||
def test_counter_agent(self):
|
def test_counter_agent(self):
|
||||||
"""
|
"""
|
||||||
The initial states should be applied to the agent and the
|
The initial states should be applied to the agent and the
|
||||||
agent should be able to update its state."""
|
agent should be able to update its state."""
|
||||||
config = {
|
config = {
|
||||||
'name': 'CounterAgent',
|
'version': '2',
|
||||||
'network_params': {
|
'general': {
|
||||||
'path': join(ROOT, 'test.gexf')
|
'name': 'CounterAgent',
|
||||||
|
'max_time': 2,
|
||||||
|
'dry_run': True,
|
||||||
|
'num_trials': 1,
|
||||||
},
|
},
|
||||||
'agent_type': 'CounterModel',
|
'topologies': {
|
||||||
'states': [{'times': 10}, {'times': 20}],
|
'default': {
|
||||||
'max_time': 2,
|
'path': join(ROOT, 'test.gexf')
|
||||||
'num_trials': 1,
|
}
|
||||||
'environment_params': {
|
},
|
||||||
|
'agents': {
|
||||||
|
'default': {
|
||||||
|
'agent_class': 'CounterModel',
|
||||||
|
},
|
||||||
|
'counters': {
|
||||||
|
'topology': 'default',
|
||||||
|
'fixed': [{'state': {'times': 10}}, {'state': {'times': 20}}],
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s = simulation.from_old_config(config)
|
s = simulation.from_config(config)
|
||||||
env = s.run_simulation(dry_run=True)[0]
|
env = s.get_env()
|
||||||
assert env.get_agent(0)['times', 0] == 11
|
assert isinstance(env.agents[0], agents.CounterModel)
|
||||||
assert env.get_agent(0)['times', 1] == 12
|
assert env.agents[0].topology == env.topologies['default']
|
||||||
assert env.get_agent(1)['times', 0] == 21
|
assert env.agents[0]['times'] == 10
|
||||||
assert env.get_agent(1)['times', 1] == 22
|
assert env.agents[0]['times'] == 10
|
||||||
|
env.step()
|
||||||
|
assert env.agents[0]['times'] == 11
|
||||||
|
assert env.agents[1]['times'] == 21
|
||||||
|
|
||||||
def test_custom_agent(self):
|
def test_custom_agent(self):
|
||||||
"""Allow for search of neighbors with a certain state_id"""
|
"""Allow for search of neighbors with a certain state_id"""
|
||||||
@ -143,9 +158,9 @@ class TestMain(TestCase):
|
|||||||
}
|
}
|
||||||
s = simulation.from_old_config(config)
|
s = simulation.from_old_config(config)
|
||||||
env = s.run_simulation(dry_run=True)[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
assert env.get_agent(1).count_agents(state_id='normal') == 2
|
assert env.agents[1].count_agents(state_id='normal') == 2
|
||||||
assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
|
assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||||
assert env.get_agent(0).neighbors == 1
|
assert env.agents[0].neighbors == 1
|
||||||
|
|
||||||
def test_torvalds_example(self):
|
def test_torvalds_example(self):
|
||||||
"""A complete example from a documentation should work."""
|
"""A complete example from a documentation should work."""
|
||||||
@ -180,11 +195,9 @@ class TestMain(TestCase):
|
|||||||
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||||
s = simulation.from_old_config(config)
|
s = simulation.from_old_config(config)
|
||||||
with utils.timer('serializing'):
|
with utils.timer('serializing'):
|
||||||
serial = s.config.to_yaml()
|
serial = s.to_yaml()
|
||||||
with utils.timer('recovering'):
|
with utils.timer('recovering'):
|
||||||
recovered = yaml.load(serial, Loader=yaml.SafeLoader)
|
recovered = yaml.load(serial, Loader=yaml.SafeLoader)
|
||||||
with utils.timer('deleting'):
|
|
||||||
del recovered['topology']
|
|
||||||
for (k, v) in config.items():
|
for (k, v) in config.items():
|
||||||
assert recovered[k] == v
|
assert recovered[k] == v
|
||||||
# assert config == recovered
|
# assert config == recovered
|
||||||
|
Loading…
Reference in New Issue
Block a user