mirror of
https://github.com/gsi-upm/soil
synced 2024-11-24 20:02:28 +00:00
WIP
This commit is contained in:
parent
bbaed636a8
commit
e41dc3dae2
@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [UNRELEASED]
|
||||
### Changed
|
||||
* Configuration schema is very different now. Check `soil.config` for more information. We are using Pydantic for (de)serialization.
|
||||
* There may be more than one topology/network in the simulation
|
||||
* Agents are split into groups now. Each group may be assigned a given set of agents or an agent distribution, and a network topology to be assigned to.
|
||||
### Removed
|
||||
* Any `tsih` and `History` integration in the main classes. To record the state of environments/agents, just use a datacollector. In some cases this may be slower or consume more memory than the previous system. However, few cases actually used the full potential of the history, and it came at the cost of unnecessary complexity and worse performance for the majority of cases.
|
||||
## [0.20.7]
|
||||
### Changed
|
||||
* Creating a `time.When` from another `time.When` does not nest them anymore (it returns the argument)
|
||||
|
@ -1,38 +1,59 @@
|
||||
---
|
||||
version: '2'
|
||||
general:
|
||||
name: simple
|
||||
id: simple
|
||||
group: tests
|
||||
dir_path: "/tmp/"
|
||||
num_trials: 3
|
||||
max_time: 100
|
||||
interval: 1
|
||||
seed: "CompleteSeed!"
|
||||
network:
|
||||
group:
|
||||
network
|
||||
topologies:
|
||||
default:
|
||||
params:
|
||||
generator: complete_graph
|
||||
n: 10
|
||||
another_graph:
|
||||
params:
|
||||
generator: complete_graph
|
||||
n: 2
|
||||
environment:
|
||||
environment_class: Environment
|
||||
params:
|
||||
am_i_complete: true
|
||||
agents:
|
||||
default:
|
||||
# Agents are split into groups, each with its own definition
|
||||
default: # This is a special group. Its values will be used as default values for the rest of the groups
|
||||
agent_class: CounterModel
|
||||
topology: default
|
||||
state:
|
||||
times: 1
|
||||
environment:
|
||||
# In this group we are not specifying any topology
|
||||
fixed:
|
||||
- agent_id: 'Environment Agent 1'
|
||||
agent_class: CounterModel
|
||||
state:
|
||||
times: 10
|
||||
network:
|
||||
general_counters:
|
||||
topology: default
|
||||
distribution:
|
||||
- agent_class: CounterModel
|
||||
weight: 1
|
||||
state:
|
||||
state_id: 0
|
||||
id: 0
|
||||
times: 3
|
||||
- agent_class: AggregatedCounter
|
||||
weight: 0.2
|
||||
other_counters:
|
||||
topology: another_graph
|
||||
fixed:
|
||||
- agent_class: CounterModel
|
||||
id: 0
|
||||
state:
|
||||
times: 1
|
||||
total: 0
|
||||
- agent_class: CounterModel
|
||||
id: 1
|
||||
# If not specified, it will use the state set in the default
|
||||
# state:
|
||||
|
@ -6,5 +6,4 @@ pandas>=0.23
|
||||
SALib>=1.3
|
||||
Jinja2
|
||||
Mesa>=0.8.9
|
||||
tsih>=0.1.6
|
||||
pydantic>=1.9
|
||||
|
@ -7,9 +7,15 @@ class CounterModel(NetworkAgent):
|
||||
in each step and adds it to its state.
|
||||
"""
|
||||
|
||||
defaults = {
|
||||
'times': 0,
|
||||
'neighbors': 0,
|
||||
'total': 0
|
||||
}
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
total = len(list(self.get_agents()))
|
||||
total = len(list(self.env.agents))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = neighbors
|
||||
@ -33,6 +39,6 @@ class AggregatedCounter(NetworkAgent):
|
||||
self['times'] += 1
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['neighbors'] += neighbors
|
||||
total = len(list(self.get_agents()))
|
||||
total = len(list(self.env.agents))
|
||||
self['total'] += total
|
||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||
|
@ -1,16 +1,17 @@
|
||||
import logging
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Mapping, Set
|
||||
from copy import deepcopy
|
||||
from functools import partial, wraps
|
||||
from itertools import islice
|
||||
from itertools import islice, chain
|
||||
import json
|
||||
import networkx as nx
|
||||
|
||||
from .. import serialization, utils, time
|
||||
from mesa import Agent as MesaAgent
|
||||
from typing import Dict, List
|
||||
|
||||
from tsih import Key
|
||||
from .. import serialization, utils, time, config
|
||||
|
||||
from mesa import Agent
|
||||
|
||||
|
||||
def as_node(agent):
|
||||
@ -24,7 +25,7 @@ IGNORED_FIELDS = ('model', 'logger')
|
||||
class DeadAgent(Exception):
|
||||
pass
|
||||
|
||||
class BaseAgent(Agent):
|
||||
class BaseAgent(MesaAgent):
|
||||
"""
|
||||
A special type of Mesa Agent that:
|
||||
|
||||
@ -47,9 +48,8 @@ class BaseAgent(Agent):
|
||||
):
|
||||
# Check for REQUIRED arguments
|
||||
# Initialize agent parameters
|
||||
if isinstance(unique_id, Agent):
|
||||
if isinstance(unique_id, MesaAgent):
|
||||
raise Exception()
|
||||
self._saved = set()
|
||||
super().__init__(unique_id=unique_id, model=model)
|
||||
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||
|
||||
@ -57,7 +57,7 @@ class BaseAgent(Agent):
|
||||
self.alive = True
|
||||
|
||||
self.interval = interval or self.get('interval', 1)
|
||||
self.logger = logging.getLogger(self.model.name).getChild(self.name)
|
||||
self.logger = logging.getLogger(self.model.id).getChild(self.name)
|
||||
|
||||
if hasattr(self, 'level'):
|
||||
self.logger.setLevel(self.level)
|
||||
@ -66,6 +66,7 @@ class BaseAgent(Agent):
|
||||
setattr(self, k, deepcopy(v))
|
||||
|
||||
for (k, v) in kwargs.items():
|
||||
|
||||
setattr(self, k, v)
|
||||
|
||||
for (k, v) in getattr(self, 'defaults', {}).items():
|
||||
@ -107,23 +108,7 @@ class BaseAgent(Agent):
|
||||
def environment_params(self, value):
|
||||
self.model.environment_params = value
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if not key.startswith('_') and key not in IGNORED_FIELDS:
|
||||
try:
|
||||
k = Key(t_step=self.now,
|
||||
dict_id=self.unique_id,
|
||||
key=key)
|
||||
self._saved.add(key)
|
||||
self.model[k] = value
|
||||
except AttributeError:
|
||||
pass
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, tuple):
|
||||
key, t_step = key
|
||||
k = Key(key=key, t_step=t_step, dict_id=self.unique_id)
|
||||
return self.model[k]
|
||||
return getattr(self, key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
@ -135,8 +120,11 @@ class BaseAgent(Agent):
|
||||
def __setitem__(self, key, value):
|
||||
setattr(self, key, value)
|
||||
|
||||
def keys(self):
|
||||
return (k for k in self.__dict__ if k[0] != '_')
|
||||
|
||||
def items(self):
|
||||
return ((k, getattr(self, k)) for k in self._saved)
|
||||
return ((k, v) for (k, v) in self.__dict__.items() if k[0] != '_')
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self[key] if key in self else default
|
||||
@ -174,22 +162,32 @@ class BaseAgent(Agent):
|
||||
extra['agent_name'] = self.name
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
|
||||
|
||||
def debug(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.DEBUG, **kwargs)
|
||||
|
||||
def info(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.INFO, **kwargs)
|
||||
|
||||
# Alias
|
||||
# Agent = BaseAgent
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
return self.model.G
|
||||
def __init__(self,
|
||||
*args,
|
||||
graph_name: str,
|
||||
node_id: int = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.graph_name = graph_name
|
||||
self.topology = self.env.topologies[self.graph_name]
|
||||
self.node_id = node_id
|
||||
|
||||
@property
|
||||
def G(self):
|
||||
return self.model.G
|
||||
return self.model.topologies[self._topology]
|
||||
|
||||
def count_agents(self, **kwargs):
|
||||
return len(list(self.get_agents(**kwargs)))
|
||||
@ -210,8 +208,7 @@ class NetworkAgent(BaseAgent):
|
||||
if limit_neighbors:
|
||||
agents = self.topology.neighbors(self.unique_id)
|
||||
|
||||
agents = self.model.get_agents(agents)
|
||||
return select(agents, **kwargs)
|
||||
return self.model.agents(ids=agents, **kwargs)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
@ -229,7 +226,6 @@ class NetworkAgent(BaseAgent):
|
||||
|
||||
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
|
||||
|
||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||
node = as_node(node if node is not None else self)
|
||||
@ -311,7 +307,7 @@ class MetaFSM(type):
|
||||
cls.states = states
|
||||
|
||||
|
||||
class FSM(NetworkAgent, metaclass=MetaFSM):
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FSM, self).__init__(*args, **kwargs)
|
||||
if not hasattr(self, 'state_id'):
|
||||
@ -537,18 +533,59 @@ def _definition_to_dict(definition, size=None, default_state=None):
|
||||
return agents
|
||||
|
||||
|
||||
def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents.
|
||||
"""
|
||||
|
||||
__slots__ = ("_agents",)
|
||||
|
||||
|
||||
def __init__(self, agents):
|
||||
self._agents = agents
|
||||
|
||||
def __getstate__(self):
|
||||
return {"_agents": self._agents}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._agents = state["_agents"]
|
||||
|
||||
# Mapping methods
|
||||
def __len__(self):
|
||||
return sum(len(x) for x in self._agents.values())
|
||||
|
||||
def __iter__(self):
|
||||
return iter(chain.from_iterable(g.values() for g in self._agents.values()))
|
||||
|
||||
def __getitem__(self, agent_id):
|
||||
if isinstance(agent_id, slice):
|
||||
raise ValueError(f"Slicing is not supported")
|
||||
for group in self._agents.values():
|
||||
if agent_id in group:
|
||||
return group[agent_id]
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
def filter(self, ids=None, groups=None, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
|
||||
agents = self._agents
|
||||
|
||||
if groups:
|
||||
agents = {(k,v) for (k, v) in agents.items() if k in groups}
|
||||
|
||||
if agent_type is not None:
|
||||
try:
|
||||
agent_type = tuple(agent_type)
|
||||
except TypeError:
|
||||
agent_type = tuple([agent_type])
|
||||
|
||||
f = agents
|
||||
if ids:
|
||||
agents = (v[aid] for v in agents.values() for aid in ids if aid in v)
|
||||
else:
|
||||
agents = (a for v in agents.values() for a in v.values())
|
||||
|
||||
f = agents
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
@ -562,7 +599,105 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False,
|
||||
|
||||
if iterator:
|
||||
return f
|
||||
return f
|
||||
return list(f)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.filter(*args, **kwargs)
|
||||
|
||||
def __contains__(self, agent_id):
|
||||
return any(agent_id in g for g in self._agents)
|
||||
|
||||
def __str__(self):
|
||||
return str(list(a.id for a in self))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def from_config(cfg: Dict[str, config.AgentConfig], env):
|
||||
'''
|
||||
Agents are specified in groups.
|
||||
Each group can be specified in two ways, either through a fixed list in which each item has
|
||||
has the agent type, number of agents to create, and the other parameters, or through what we call
|
||||
an `agent distribution`, which is similar but instead of number of agents, it specifies the weight
|
||||
of each agent type.
|
||||
'''
|
||||
default = cfg.get('default', None)
|
||||
return {k: _group_from_config(c, default=default, env=env) for (k, c) in cfg.items() if k is not 'default'}
|
||||
|
||||
|
||||
def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfig, env):
|
||||
agents = {}
|
||||
if cfg.fixed is not None:
|
||||
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
||||
if cfg.distribution:
|
||||
n = cfg.n or len(env.topologies[cfg.topology])
|
||||
agents.update(_from_distro(cfg.distribution, n - len(agents),
|
||||
topology=cfg.topology or default.topology,
|
||||
default=default,
|
||||
env=env))
|
||||
return agents
|
||||
|
||||
|
||||
def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig, env):
|
||||
agents = {}
|
||||
|
||||
for fixed in lst:
|
||||
agent_id = fixed.agent_id
|
||||
if agent_id is None:
|
||||
agent_id = env.next_id()
|
||||
|
||||
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
||||
state = fixed.state.copy()
|
||||
state.update(default.state)
|
||||
agents[agent_id] = cls(unique_id=agent_id,
|
||||
model=env,
|
||||
graph_name=fixed.topology or topology or default.topology,
|
||||
**state)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def _from_distro(distro: List[config.AgentDistro],
|
||||
n: int,
|
||||
topology: str,
|
||||
default: config.SingleAgentConfig,
|
||||
env):
|
||||
|
||||
agents = {}
|
||||
|
||||
if n is None:
|
||||
if any(lambda dist: dist.n is None, distro):
|
||||
raise ValueError('You must provide a total number of agents, or the number of each type')
|
||||
n = sum(dist.n for dist in distro)
|
||||
|
||||
|
||||
total = sum((dist.weight if dist.weight is not None else 1) for dist in distro)
|
||||
thres = {}
|
||||
last = 0
|
||||
for i in sorted(distro, key=lambda x: x.weight):
|
||||
|
||||
cls = serialization.deserialize(i.agent_class or default.agent_class)
|
||||
thres[(last, last + i.weight/total)] = (cls, i)
|
||||
|
||||
acc = 0
|
||||
|
||||
# using np.choice would be more efficient, but this allows us to use soil without
|
||||
# numpy
|
||||
for i in range(n):
|
||||
r = random.random()
|
||||
for (t, (cls, d)) in thres.items():
|
||||
if r >= t[0] and r <= t[1]:
|
||||
agent_id = d.agent_id
|
||||
if agent_id is None:
|
||||
agent_id = env.next_id()
|
||||
|
||||
state = d.state.copy()
|
||||
state.update(default.state)
|
||||
agents[agent_id] = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
|
||||
break
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
from .BassModel import *
|
||||
|
@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra
|
||||
class General(BaseModel):
|
||||
id: str = 'Unnamed Simulation'
|
||||
group: str = None
|
||||
dir_path: str = None
|
||||
dir_path: Optional[str] = None
|
||||
num_trials: int = 1
|
||||
max_time: float = 100
|
||||
interval: float = 1
|
||||
@ -27,13 +27,13 @@ nodeId = int
|
||||
|
||||
class Node(BaseModel):
|
||||
id: nodeId
|
||||
state: Dict[str, Any]
|
||||
state: Optional[Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
source: nodeId
|
||||
target: nodeId
|
||||
value: float = 1
|
||||
value: Optional[float] = 1
|
||||
|
||||
|
||||
class Topology(BaseModel):
|
||||
@ -75,46 +75,62 @@ class EnvConfig(BaseModel):
|
||||
|
||||
|
||||
class SingleAgentConfig(BaseModel):
|
||||
agent_class: Union[Type, str] = 'soil.Agent'
|
||||
agent_class: Optional[Union[Type, str]] = None
|
||||
agent_id: Optional[Union[str, int]] = None
|
||||
params: Dict[str, Any] = {}
|
||||
state: Dict[str, Any] = {}
|
||||
topology: Optional[str] = None
|
||||
state: Optional[Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class AgentDistro(SingleAgentConfig):
|
||||
weight: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
class FixedAgentConfig(SingleAgentConfig):
|
||||
n: Optional[int] = 1
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if 'weight' in values and 'count' in values:
|
||||
raise ValueError("You may either specify a weight in the distribution or an agent count")
|
||||
if 'agent_id' in values and values.get('n', 1) > 1:
|
||||
raise ValueError("An agent_id can only be provided when there is only one agent")
|
||||
return values
|
||||
|
||||
|
||||
class AgentDistro(SingleAgentConfig):
|
||||
weight: Optional[float] = 1
|
||||
|
||||
|
||||
class AgentConfig(SingleAgentConfig):
|
||||
n: Optional[int] = None
|
||||
topology: Optional[str] = None
|
||||
distribution: Optional[List[AgentDistro]] = None
|
||||
fixed: Optional[List[SingleAgentConfig]] = None
|
||||
fixed: Optional[List[FixedAgentConfig]] = None
|
||||
|
||||
@staticmethod
|
||||
def default():
|
||||
return AgentConfig()
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if 'distribution' in values and ('n' not in values and 'topology' not in values):
|
||||
raise ValueError("You need to provide the number of agents or a topology to extract the value from.")
|
||||
return values
|
||||
|
||||
|
||||
class Config(BaseModel, extra=Extra.forbid):
|
||||
version: Optional[str] = '1'
|
||||
general: General = General.default()
|
||||
network: Optional[NetConfig] = None
|
||||
topologies: Optional[Dict[str, NetConfig]] = {}
|
||||
environment: EnvConfig = EnvConfig.default()
|
||||
agents: Dict[str, AgentConfig] = {}
|
||||
agents: Optional[Dict[str, AgentConfig]] = {}
|
||||
|
||||
|
||||
def convert_old(old):
|
||||
def convert_old(old, strict=True):
|
||||
'''
|
||||
Try to convert old style configs into the new format.
|
||||
|
||||
This is still a work in progress and might not work in many cases.
|
||||
'''
|
||||
|
||||
# TODO: translate states
|
||||
|
||||
if strict and old.get('states', {}):
|
||||
raise ValueError('Custom (i.e., manual) agent states cannot be translated to v2 configuration files. Please, convert your configuration file to the new format.')
|
||||
|
||||
new = {}
|
||||
|
||||
|
||||
@ -129,7 +145,10 @@ def convert_old(old):
|
||||
if k in old:
|
||||
general[k] = old[k]
|
||||
|
||||
network = {'group': 'network'}
|
||||
if 'name' in old:
|
||||
general['id'] = old['name']
|
||||
|
||||
network = {}
|
||||
|
||||
|
||||
if 'network_params' in old and old['network_params']:
|
||||
@ -143,9 +162,6 @@ def convert_old(old):
|
||||
network['topology'] = old['topology']
|
||||
|
||||
agents = {
|
||||
'environment': {
|
||||
'fixed': []
|
||||
},
|
||||
'network': {},
|
||||
'default': {},
|
||||
}
|
||||
@ -164,10 +180,31 @@ def convert_old(old):
|
||||
return newagent
|
||||
|
||||
for agent in old.get('environment_agents', []):
|
||||
agents['environment'] = {'distribution': [], 'fixed': []}
|
||||
if 'agent_id' not in agent:
|
||||
agents['environment']['distribution'].append(updated_agent(agent))
|
||||
else:
|
||||
agents['environment']['fixed'].append(updated_agent(agent))
|
||||
|
||||
for agent in old.get('network_agents', []):
|
||||
agents['network'].setdefault('distribution', []).append(updated_agent(agent))
|
||||
by_weight = []
|
||||
fixed = []
|
||||
|
||||
if 'network_agents' in old:
|
||||
agents['network']['topology'] = 'default'
|
||||
|
||||
for agent in old['network_agents']:
|
||||
agent = updated_agent(agent)
|
||||
if 'agent_id' in agent:
|
||||
fixed.append(agent)
|
||||
else:
|
||||
by_weight.append(agent)
|
||||
|
||||
if 'agent_type' in old and (not fixed and not by_weight):
|
||||
agents['network']['topology'] = 'default'
|
||||
by_weight = [{'agent_type': old['agent_type']}]
|
||||
|
||||
agents['network']['fixed'] = fixed
|
||||
agents['network']['distribution'] = by_weight
|
||||
|
||||
environment = {'params': {}}
|
||||
if 'environment_class' in old:
|
||||
@ -176,8 +213,8 @@ def convert_old(old):
|
||||
for (k, v) in old.get('environment_params', {}).items():
|
||||
environment['params'][k] = v
|
||||
|
||||
|
||||
return Config(general=general,
|
||||
network=network,
|
||||
return Config(version='2',
|
||||
general=general,
|
||||
topologies={'default': network},
|
||||
environment=environment,
|
||||
agents=agents)
|
||||
|
@ -3,6 +3,10 @@ import os
|
||||
import sqlite3
|
||||
import math
|
||||
import random
|
||||
import logging
|
||||
|
||||
from typing import Dict
|
||||
from collections import namedtuple
|
||||
from time import time as current_time
|
||||
from copy import deepcopy
|
||||
from networkx.readwrite import json_graph
|
||||
@ -12,9 +16,11 @@ import networkx as nx
|
||||
|
||||
from mesa import Model
|
||||
|
||||
from tsih import Record
|
||||
from . import serialization, agents, analysis, utils, time, config, network
|
||||
|
||||
|
||||
Record = namedtuple('Record', 'dict_id t_step key value')
|
||||
|
||||
from . import serialization, agents, analysis, utils, time, config
|
||||
|
||||
class Environment(Model):
|
||||
"""
|
||||
@ -28,15 +34,17 @@ class Environment(Model):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
env_id,
|
||||
env_id='unnamed_env',
|
||||
seed='default',
|
||||
schedule=None,
|
||||
env_params=None,
|
||||
dir_path=None,
|
||||
**kwargs):
|
||||
interval=1,
|
||||
agents: Dict[str, config.AgentConfig] = {},
|
||||
topologies: Dict[str, config.NetConfig] = {},
|
||||
**env_params):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.current_id = -1
|
||||
|
||||
self.seed = '{}_{}'.format(seed, env_id)
|
||||
self.id = env_id
|
||||
@ -51,25 +59,28 @@ class Environment(Model):
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
if isinstance(states, list):
|
||||
states = dict(enumerate(states))
|
||||
self.states = deepcopy(states) if states else {}
|
||||
self.default_state = deepcopy(default_state) or {}
|
||||
|
||||
|
||||
self.set_topology(topology=topology,
|
||||
network_params=network_params)
|
||||
|
||||
self.topologies = {}
|
||||
for (name, cfg) in topologies.items():
|
||||
self.set_topology(cfg=cfg,
|
||||
graph=name)
|
||||
self.agents = agents or {}
|
||||
|
||||
self.env_params = env_params or {}
|
||||
self.env_params.update(kwargs)
|
||||
|
||||
self.interval = interval
|
||||
self['SEED'] = seed
|
||||
|
||||
self.logger = utils.logger.getChild(self.id)
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
return self.topologies['default']
|
||||
|
||||
@property
|
||||
def network_agents(self):
|
||||
yield from self.agents(agent_type=agents.NetworkAgent, iterator=True)
|
||||
|
||||
self.logger = utils.logger.getChild(self.name)
|
||||
|
||||
@staticmethod
|
||||
def from_config(conf: config.Config, trial_id, **kwargs) -> Environment:
|
||||
@ -92,39 +103,30 @@ class Environment(Model):
|
||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
||||
|
||||
|
||||
def set_topology(self, topology, network_params=None, dir_path=None):
|
||||
if topology is None:
|
||||
network_params = network_params or {}
|
||||
topology = serialization.load_network(network_params,
|
||||
dir_path=dir_path or self.dir_path)
|
||||
if not topology:
|
||||
topology = nx.Graph()
|
||||
self.G = nx.Graph(topology)
|
||||
def set_topology(self, cfg=None, dir_path=None, graph='default'):
|
||||
self.topologies[graph] = network.from_config(cfg, dir_path=dir_path)
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
for agents in self.agents.values():
|
||||
yield from agents
|
||||
return agents.AgentView(self._agents)
|
||||
|
||||
@agents.setter
|
||||
def agents(self, agents):
|
||||
self.agents = {}
|
||||
def agents(self, agents_def: Dict[str, config.AgentConfig]):
|
||||
self._agents = agents.from_config(agents_def, env=self)
|
||||
for d in self._agents.values():
|
||||
for a in d.values():
|
||||
self.schedule.add(a)
|
||||
|
||||
for (k, v) in agents.items():
|
||||
self.agents[k] = agents.from_config(v)
|
||||
for agent in self.agents.get('network', []):
|
||||
node = self.G.nodes[agent.unique_id]
|
||||
node['agent'] = agent
|
||||
|
||||
@property
|
||||
def network_agents(self):
|
||||
for i in self.G.nodes():
|
||||
node = self.G.nodes[i]
|
||||
if 'agent' in node:
|
||||
yield node['agent']
|
||||
# @property
|
||||
# def network_agents(self):
|
||||
# for i in self.G.nodes():
|
||||
# node = self.G.nodes[i]
|
||||
# if 'agent' in node:
|
||||
# yield node['agent']
|
||||
|
||||
def init_agent(self, agent_id, agent_definitions):
|
||||
node = self.G.nodes[agent_id]
|
||||
def init_agent(self, agent_id, agent_definitions, graph='default'):
|
||||
node = self.topologies[graph].nodes[agent_id]
|
||||
init = False
|
||||
state = dict(node)
|
||||
|
||||
@ -145,8 +147,8 @@ class Environment(Model):
|
||||
return
|
||||
return self.set_agent(agent_id, agent_type, state)
|
||||
|
||||
def set_agent(self, agent_id, agent_type, state=None):
|
||||
node = self.G.nodes[agent_id]
|
||||
def set_agent(self, agent_id, agent_type, state=None, graph='default'):
|
||||
node = self.topologies[graph].nodes[agent_id]
|
||||
defstate = deepcopy(self.default_state) or {}
|
||||
defstate.update(self.states.get(agent_id, {}))
|
||||
defstate.update(node.get('state', {}))
|
||||
@ -166,20 +168,20 @@ class Environment(Model):
|
||||
self.schedule.add(a)
|
||||
return a
|
||||
|
||||
def add_node(self, agent_type, state=None):
|
||||
agent_id = int(len(self.G.nodes()))
|
||||
self.G.add_node(agent_id)
|
||||
a = self.set_agent(agent_id, agent_type, state)
|
||||
def add_node(self, agent_type, state=None, graph='default'):
|
||||
agent_id = int(len(self.topologies[graph].nodes()))
|
||||
self.topologies[graph].add_node(agent_id)
|
||||
a = self.set_agent(agent_id, agent_type, state, graph=graph)
|
||||
a['visible'] = True
|
||||
return a
|
||||
|
||||
def add_edge(self, agent1, agent2, start=None, **attrs):
|
||||
def add_edge(self, agent1, agent2, start=None, graph='default', **attrs):
|
||||
if hasattr(agent1, 'id'):
|
||||
agent1 = agent1.id
|
||||
if hasattr(agent2, 'id'):
|
||||
agent2 = agent2.id
|
||||
start = start or self.now
|
||||
return self.G.add_edge(agent1, agent2, **attrs)
|
||||
return self.topologies[graph].add_edge(agent1, agent2, **attrs)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
if not self.logger.isEnabledFor(level):
|
||||
@ -190,7 +192,7 @@ class Environment(Model):
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['unique_id'] = self.name
|
||||
extra['unique_id'] = self.id
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
def step(self):
|
||||
@ -207,30 +209,6 @@ class Environment(Model):
|
||||
self.step()
|
||||
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
||||
self.schedule.time = until
|
||||
self._history.flush_cache()
|
||||
|
||||
def _save_state(self, now=None):
|
||||
serialization.logger.debug('Saving state @{}'.format(self.now))
|
||||
self._history.save_records(self.state_to_tuples(now=now))
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, tuple):
|
||||
self._history.flush_cache()
|
||||
return self._history[key]
|
||||
|
||||
return self.environment_params[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if isinstance(key, tuple):
|
||||
k = Key(*key)
|
||||
self._history.save_record(*k,
|
||||
value=value)
|
||||
return
|
||||
self.environment_params[key] = value
|
||||
self._history.save_record(dict_id='env',
|
||||
t_step=self.now,
|
||||
key=key,
|
||||
value=value)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.env_params
|
||||
@ -248,14 +226,6 @@ class Environment(Model):
|
||||
def __setitem__(self, key, value):
|
||||
return self.env_params.__setitem__(key, value)
|
||||
|
||||
def get_agent(self, agent_id):
|
||||
return self.G.nodes[agent_id]['agent']
|
||||
|
||||
def get_agents(self, nodes=None):
|
||||
if nodes is None:
|
||||
return self.agents
|
||||
return (self.G.nodes[i]['agent'] for i in nodes)
|
||||
|
||||
def _agent_to_tuples(self, agent, now=None):
|
||||
if now is None:
|
||||
now = self.now
|
||||
@ -270,7 +240,7 @@ class Environment(Model):
|
||||
now = self.now
|
||||
|
||||
if agent_id:
|
||||
agent = self.get_agent(agent_id)
|
||||
agent = self.agents[agent_id]
|
||||
yield from self._agent_to_tuples(agent, now)
|
||||
return
|
||||
|
||||
@ -283,4 +253,5 @@ class Environment(Model):
|
||||
yield from self._agent_to_tuples(agent, now)
|
||||
|
||||
|
||||
|
||||
SoilEnvironment = Environment
|
||||
|
@ -16,39 +16,39 @@ from jinja2 import Template
|
||||
logger = logging.getLogger('soil')
|
||||
|
||||
|
||||
def load_network(network_params, dir_path=None):
|
||||
G = nx.Graph()
|
||||
# def load_network(network_params, dir_path=None):
|
||||
# G = nx.Graph()
|
||||
|
||||
if not network_params:
|
||||
return G
|
||||
# if not network_params:
|
||||
# return G
|
||||
|
||||
if 'path' in network_params:
|
||||
path = network_params['path']
|
||||
if dir_path and not os.path.isabs(path):
|
||||
path = os.path.join(dir_path, path)
|
||||
extension = os.path.splitext(path)[1][1:]
|
||||
kwargs = {}
|
||||
if extension == 'gexf':
|
||||
kwargs['version'] = '1.2draft'
|
||||
kwargs['node_type'] = int
|
||||
try:
|
||||
method = getattr(nx.readwrite, 'read_' + extension)
|
||||
except AttributeError:
|
||||
raise AttributeError('Unknown format')
|
||||
G = method(path, **kwargs)
|
||||
# if 'path' in network_params:
|
||||
# path = network_params['path']
|
||||
# if dir_path and not os.path.isabs(path):
|
||||
# path = os.path.join(dir_path, path)
|
||||
# extension = os.path.splitext(path)[1][1:]
|
||||
# kwargs = {}
|
||||
# if extension == 'gexf':
|
||||
# kwargs['version'] = '1.2draft'
|
||||
# kwargs['node_type'] = int
|
||||
# try:
|
||||
# method = getattr(nx.readwrite, 'read_' + extension)
|
||||
# except AttributeError:
|
||||
# raise AttributeError('Unknown format')
|
||||
# G = method(path, **kwargs)
|
||||
|
||||
elif 'generator' in network_params:
|
||||
net_args = network_params.copy()
|
||||
net_gen = net_args.pop('generator')
|
||||
# elif 'generator' in network_params:
|
||||
# net_args = network_params.copy()
|
||||
# net_gen = net_args.pop('generator')
|
||||
|
||||
if dir_path not in sys.path:
|
||||
sys.path.append(dir_path)
|
||||
# if dir_path not in sys.path:
|
||||
# sys.path.append(dir_path)
|
||||
|
||||
method = deserializer(net_gen,
|
||||
known_modules=['networkx.generators',])
|
||||
G = method(**net_args)
|
||||
# method = deserializer(net_gen,
|
||||
# known_modules=['networkx.generators',])
|
||||
# G = method(**net_args)
|
||||
|
||||
return G
|
||||
# return G
|
||||
|
||||
|
||||
def load_file(infile):
|
||||
@ -122,8 +122,8 @@ def load_files(*patterns, **kwargs):
|
||||
for i in glob(pattern, **kwargs):
|
||||
for config in load_file(i):
|
||||
path = os.path.abspath(i)
|
||||
if 'dir_path' not in config:
|
||||
config['dir_path'] = os.path.dirname(path)
|
||||
if 'general' in config and 'dir_path' not in config['general']:
|
||||
config['general']['dir_path'] = os.path.dirname(path)
|
||||
yield config, path
|
||||
|
||||
|
||||
|
@ -96,7 +96,7 @@ class Simulation:
|
||||
stat.sim_start()
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.start()
|
||||
exporter.sim_start()
|
||||
|
||||
for env in self._run_sync_or_async(parallel=parallel,
|
||||
log_level=log_level,
|
||||
@ -107,7 +107,7 @@ class Simulation:
|
||||
|
||||
collected = list(stat.trial_end(env) for stat in stats)
|
||||
|
||||
saved = self._update_stats(collected, t_step=env.now, trial_id=env.name)
|
||||
saved = self._update_stats(collected, t_step=env.now, trial_id=env.id)
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.trial_end(env, saved)
|
||||
@ -117,6 +117,9 @@ class Simulation:
|
||||
collected = list(stat.end() for stat in stats)
|
||||
saved = self._update_stats(collected)
|
||||
|
||||
for stat in stats:
|
||||
stat.sim_end()
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.sim_end(saved)
|
||||
|
||||
@ -131,24 +134,24 @@ class Simulation:
|
||||
|
||||
def get_env(self, trial_id=0, **kwargs):
|
||||
'''Create an environment for a trial of the simulation'''
|
||||
opts = self.environment_params.copy()
|
||||
opts.update({
|
||||
'name': '{}_trial_{}'.format(self.name, trial_id),
|
||||
'topology': self.topology.copy(),
|
||||
'network_params': self.network_params,
|
||||
'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||
'initial_time': 0,
|
||||
'interval': self.interval,
|
||||
'network_agents': self.network_agents,
|
||||
'initial_time': 0,
|
||||
'states': self.states,
|
||||
'dir_path': self.dir_path,
|
||||
'default_state': self.default_state,
|
||||
'history': bool(self._history),
|
||||
'environment_agents': self.environment_agents,
|
||||
})
|
||||
opts.update(kwargs)
|
||||
env = self.environment_class(**opts)
|
||||
# opts = self.environment_params.copy()
|
||||
# opts.update({
|
||||
# 'name': '{}_trial_{}'.format(self.name, trial_id),
|
||||
# 'topology': self.topology.copy(),
|
||||
# 'network_params': self.network_params,
|
||||
# 'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||
# 'initial_time': 0,
|
||||
# 'interval': self.interval,
|
||||
# 'network_agents': self.network_agents,
|
||||
# 'initial_time': 0,
|
||||
# 'states': self.states,
|
||||
# 'dir_path': self.dir_path,
|
||||
# 'default_state': self.default_state,
|
||||
# 'history': bool(self._history),
|
||||
# 'environment_agents': self.environment_agents,
|
||||
# })
|
||||
# opts.update(kwargs)
|
||||
env = Environment.from_config(self.config, trial_id=trial_id, **kwargs)
|
||||
return env
|
||||
|
||||
def run_trial(self, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
||||
@ -162,7 +165,7 @@ class Simulation:
|
||||
# Set-up trial environment and graph
|
||||
until = until or self.config.general.max_time
|
||||
|
||||
env = Environment.from_config(self.config, trial_id=trial_id)
|
||||
env = self.get_env(trial_id, **opts)
|
||||
# Set up agents on nodes
|
||||
with utils.timer('Simulation {} trial {}'.format(self.config.general.id, trial_id)):
|
||||
env.run(until)
|
||||
@ -181,21 +184,31 @@ class Simulation:
|
||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||
return ex
|
||||
|
||||
def to_dict(self):
|
||||
return self.config.dict()
|
||||
|
||||
def to_yaml(self):
|
||||
return yaml.dump(self.config.dict())
|
||||
|
||||
|
||||
def all_from_config(config):
|
||||
configs = list(serialization.load_config(config))
|
||||
for config, _ in configs:
|
||||
sim = Simulation(**config)
|
||||
for config, path in configs:
|
||||
if config.get('version', '1') == '1':
|
||||
config = convert_old(config)
|
||||
if not isinstance(config, Config):
|
||||
config = Config(**config)
|
||||
if not config.general.dir_path:
|
||||
config.general.dir_path = os.path.dirname(path)
|
||||
sim = Simulation(config=config)
|
||||
yield sim
|
||||
|
||||
|
||||
def from_config(conf_or_path):
|
||||
config = list(serialization.load_config(conf_or_path))
|
||||
if len(config) > 1:
|
||||
lst = list(all_from_config(conf_or_path))
|
||||
if len(lst) > 1:
|
||||
raise AttributeError('Provide only one configuration')
|
||||
config = config[0][0]
|
||||
sim = Simulation(**config)
|
||||
return sim
|
||||
return lst[0]
|
||||
|
||||
def from_old_config(conf_or_path):
|
||||
config = list(serialization.load_config(conf_or_path))
|
||||
@ -206,13 +219,7 @@ def from_old_config(conf_or_path):
|
||||
|
||||
|
||||
def run_from_config(*configs, **kwargs):
|
||||
for config_def in configs:
|
||||
# logger.info("Found {} config(s)".format(len(ls)))
|
||||
for config, path in serialization.load_config(config_def):
|
||||
for sim in all_from_config(configs):
|
||||
name = config.general.id
|
||||
logger.info("Using config(s): {name}".format(name=name))
|
||||
|
||||
dir_path = config.general.dir_path or os.path.dirname(path)
|
||||
sim = Simulation(dir_path=dir_path,
|
||||
**config)
|
||||
sim.run_simulation(**kwargs)
|
||||
|
@ -3,7 +3,7 @@ from queue import Empty
|
||||
from heapq import heappush, heappop
|
||||
import math
|
||||
from .utils import logger
|
||||
from mesa import Agent
|
||||
from mesa import Agent as MesaAgent
|
||||
|
||||
|
||||
INFINITY = float('inf')
|
||||
@ -41,7 +41,7 @@ class TimedActivation(BaseScheduler):
|
||||
self._queue = []
|
||||
self.next_time = 0
|
||||
|
||||
def add(self, agent: Agent):
|
||||
def add(self, agent: MesaAgent):
|
||||
if agent.unique_id not in self._agents:
|
||||
heappush(self._queue, (self.time, agent.unique_id))
|
||||
super().add(agent)
|
||||
|
@ -2,7 +2,7 @@ from unittest import TestCase
|
||||
import os
|
||||
from os.path import join
|
||||
|
||||
from soil import serialization, config
|
||||
from soil import simulation, serialization, config, network
|
||||
|
||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
EXAMPLES = join(ROOT, '..', 'examples')
|
||||
@ -13,18 +13,58 @@ FORCE_TESTS = os.environ.get('FORCE_TESTS', '')
|
||||
class TestConfig(TestCase):
|
||||
|
||||
def test_conversion(self):
|
||||
new = serialization.load_file(join(EXAMPLES, "complete.yml"))[0]
|
||||
new = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
||||
converted = config.convert_old(old).dict(skip_defaults=True)
|
||||
for (k, v) in new.items():
|
||||
assert v == converted[k]
|
||||
converted_defaults = config.convert_old(old, strict=False)
|
||||
converted = converted_defaults.dict(skip_defaults=True)
|
||||
def isequal(old, new):
|
||||
if isinstance(old, dict):
|
||||
for (k, v) in old.items():
|
||||
isequal(old[k], new[k])
|
||||
return
|
||||
assert old == new
|
||||
|
||||
isequal(new, converted)
|
||||
|
||||
def test_topology_config(self):
|
||||
netconfig = config.NetConfig(**{
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
})
|
||||
net = network.from_config(netconfig, dir_path=ROOT)
|
||||
assert len(net.nodes) == 2
|
||||
assert len(net.edges) == 1
|
||||
|
||||
def test_env_from_config(self):
|
||||
"""
|
||||
Simple configuration that tests that the graph is loaded, and that
|
||||
network agents are initialized properly.
|
||||
"""
|
||||
config = {
|
||||
'name': 'CounterAgent',
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
'agent_type': 'CounterModel',
|
||||
# 'states': [{'times': 10}, {'times': 20}],
|
||||
'max_time': 2,
|
||||
'dry_run': True,
|
||||
'num_trials': 1,
|
||||
'environment_params': {
|
||||
}
|
||||
}
|
||||
s = simulation.from_old_config(config)
|
||||
env = s.get_env()
|
||||
assert len(env.topologies['default'].nodes) == 2
|
||||
assert len(env.topologies['default'].edges) == 1
|
||||
assert len(env.agents) == 2
|
||||
assert env.agents[0].topology == env.topologies['default']
|
||||
|
||||
|
||||
def make_example_test(path, cfg):
|
||||
def wrapped(self):
|
||||
root = os.getcwd()
|
||||
s = config.Config(**cfg)
|
||||
import pdb;pdb.set_trace()
|
||||
print(path)
|
||||
s = simulation.from_config(cfg)
|
||||
# for s in simulation.all_from_config(path):
|
||||
# iterations = s.config.max_time * s.config.num_trials
|
||||
# if iterations > 1000:
|
||||
|
@ -18,17 +18,17 @@ class Dummy(exporters.Exporter):
|
||||
called_trial = 0
|
||||
called_end = 0
|
||||
|
||||
def start(self):
|
||||
def sim_start(self):
|
||||
self.__class__.called_start += 1
|
||||
self.__class__.started = True
|
||||
|
||||
def trial(self, env, stats):
|
||||
def trial_end(self, env, stats):
|
||||
assert env
|
||||
self.__class__.trials += 1
|
||||
self.__class__.total_time += env.now
|
||||
self.__class__.called_trial += 1
|
||||
|
||||
def end(self, stats):
|
||||
def sim_end(self, stats):
|
||||
self.__class__.ended = True
|
||||
self.__class__.called_end += 1
|
||||
|
||||
|
@ -9,7 +9,7 @@ import networkx as nx
|
||||
from functools import partial
|
||||
|
||||
from os.path import join
|
||||
from soil import (simulation, Environment, agents, serialization,
|
||||
from soil import (simulation, Environment, agents, network, serialization,
|
||||
utils)
|
||||
from soil.time import Delta
|
||||
|
||||
@ -17,7 +17,7 @@ ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||
EXAMPLES = join(ROOT, '..', 'examples')
|
||||
|
||||
|
||||
class CustomAgent(agents.FSM):
|
||||
class CustomAgent(agents.FSM, agents.NetworkAgent):
|
||||
@agents.default_state
|
||||
@agents.state
|
||||
def normal(self):
|
||||
@ -39,7 +39,7 @@ class TestMain(TestCase):
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
}
|
||||
}
|
||||
G = serialization.load_network(config['network_params'])
|
||||
G = network.from_config(config['network_params'])
|
||||
assert G
|
||||
assert len(G) == 2
|
||||
with self.assertRaises(AttributeError):
|
||||
@ -48,7 +48,7 @@ class TestMain(TestCase):
|
||||
'path': join(ROOT, 'unknown.extension')
|
||||
}
|
||||
}
|
||||
G = serialization.load_network(config['network_params'])
|
||||
G = network.from_config(config['network_params'])
|
||||
print(G)
|
||||
|
||||
def test_generate_barabasi(self):
|
||||
@ -56,16 +56,16 @@ class TestMain(TestCase):
|
||||
If no path is given, a generator and network parameters
|
||||
should be used to generate a network
|
||||
"""
|
||||
config = {
|
||||
'network_params': {
|
||||
cfg = {
|
||||
'params': {
|
||||
'generator': 'barabasi_albert_graph'
|
||||
}
|
||||
}
|
||||
with self.assertRaises(TypeError):
|
||||
G = serialization.load_network(config['network_params'])
|
||||
config['network_params']['n'] = 100
|
||||
config['network_params']['m'] = 10
|
||||
G = serialization.load_network(config['network_params'])
|
||||
with self.assertRaises(Exception):
|
||||
G = network.from_config(cfg)
|
||||
cfg['params']['n'] = 100
|
||||
cfg['params']['m'] = 10
|
||||
G = network.from_config(cfg)
|
||||
assert len(G) == 100
|
||||
|
||||
def test_empty_simulation(self):
|
||||
@ -103,28 +103,43 @@ class TestMain(TestCase):
|
||||
}
|
||||
}
|
||||
s = simulation.from_old_config(config)
|
||||
|
||||
def test_counter_agent(self):
|
||||
"""
|
||||
The initial states should be applied to the agent and the
|
||||
agent should be able to update its state."""
|
||||
config = {
|
||||
'version': '2',
|
||||
'general': {
|
||||
'name': 'CounterAgent',
|
||||
'network_params': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
},
|
||||
'agent_type': 'CounterModel',
|
||||
'states': [{'times': 10}, {'times': 20}],
|
||||
'max_time': 2,
|
||||
'dry_run': True,
|
||||
'num_trials': 1,
|
||||
'environment_params': {
|
||||
},
|
||||
'topologies': {
|
||||
'default': {
|
||||
'path': join(ROOT, 'test.gexf')
|
||||
}
|
||||
},
|
||||
'agents': {
|
||||
'default': {
|
||||
'agent_class': 'CounterModel',
|
||||
},
|
||||
'counters': {
|
||||
'topology': 'default',
|
||||
'fixed': [{'state': {'times': 10}}, {'state': {'times': 20}}],
|
||||
}
|
||||
}
|
||||
s = simulation.from_old_config(config)
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
assert env.get_agent(0)['times', 0] == 11
|
||||
assert env.get_agent(0)['times', 1] == 12
|
||||
assert env.get_agent(1)['times', 0] == 21
|
||||
assert env.get_agent(1)['times', 1] == 22
|
||||
}
|
||||
s = simulation.from_config(config)
|
||||
env = s.get_env()
|
||||
assert isinstance(env.agents[0], agents.CounterModel)
|
||||
assert env.agents[0].topology == env.topologies['default']
|
||||
assert env.agents[0]['times'] == 10
|
||||
assert env.agents[0]['times'] == 10
|
||||
env.step()
|
||||
assert env.agents[0]['times'] == 11
|
||||
assert env.agents[1]['times'] == 21
|
||||
|
||||
def test_custom_agent(self):
|
||||
"""Allow for search of neighbors with a certain state_id"""
|
||||
@ -143,9 +158,9 @@ class TestMain(TestCase):
|
||||
}
|
||||
s = simulation.from_old_config(config)
|
||||
env = s.run_simulation(dry_run=True)[0]
|
||||
assert env.get_agent(1).count_agents(state_id='normal') == 2
|
||||
assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||
assert env.get_agent(0).neighbors == 1
|
||||
assert env.agents[1].count_agents(state_id='normal') == 2
|
||||
assert env.agents[1].count_agents(state_id='normal', limit_neighbors=True) == 1
|
||||
assert env.agents[0].neighbors == 1
|
||||
|
||||
def test_torvalds_example(self):
|
||||
"""A complete example from a documentation should work."""
|
||||
@ -180,11 +195,9 @@ class TestMain(TestCase):
|
||||
config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0]
|
||||
s = simulation.from_old_config(config)
|
||||
with utils.timer('serializing'):
|
||||
serial = s.config.to_yaml()
|
||||
serial = s.to_yaml()
|
||||
with utils.timer('recovering'):
|
||||
recovered = yaml.load(serial, Loader=yaml.SafeLoader)
|
||||
with utils.timer('deleting'):
|
||||
del recovered['topology']
|
||||
for (k, v) in config.items():
|
||||
assert recovered[k] == v
|
||||
# assert config == recovered
|
||||
|
Loading…
Reference in New Issue
Block a user