1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-21 18:52:28 +00:00

WIP: removed stats

This commit is contained in:
J. Fernando Sánchez 2022-09-16 18:13:39 +02:00
parent 3dc56892c1
commit 0a9c6d8b19
17 changed files with 224 additions and 589 deletions

View File

@ -14,7 +14,6 @@ network_agents:
weight: 1 weight: 1
environment_class: social_wealth.MoneyEnv environment_class: social_wealth.MoneyEnv
environment_params: environment_params:
num_mesa_agents: 5
mesa_agent_type: social_wealth.MoneyAgent mesa_agent_type: social_wealth.MoneyAgent
N: 10 N: 10
width: 50 width: 50

View File

@ -71,10 +71,9 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
class MoneyEnv(Environment): class MoneyEnv(Environment):
"""A model with some number of agents.""" """A model with some number of agents."""
def __init__(self, N, width, height, *args, network_params, **kwargs): def __init__(self, width, height, *args, topologies, **kwargs):
network_params['n'] = N super().__init__(*args, topologies=topologies, **kwargs)
super().__init__(*args, network_params=network_params, **kwargs)
self.grid = MultiGrid(width, height, False) self.grid = MultiGrid(width, height, False)
# Create agents # Create agents

View File

@ -1,8 +1,8 @@
from soil.agents import FSM, state, default_state, prob from soil.agents import FSM, NetworkAgent, state, default_state, prob
import logging import logging
class DumbViewer(FSM): class DumbViewer(FSM, NetworkAgent):
''' '''
A viewer that gets infected via TV (if it has one) and tries to infect A viewer that gets infected via TV (if it has one) and tries to infect
its neighbors once it's infected. its neighbors once it's infected.

View File

@ -1,4 +1,4 @@
from soil.agents import FSM, state, default_state from soil.agents import FSM, NetworkAgent, state, default_state
from soil import Environment from soil import Environment
from random import random, shuffle from random import random, shuffle
from itertools import islice from itertools import islice
@ -53,7 +53,7 @@ class CityPubs(Environment):
pub['occupancy'] -= 1 pub['occupancy'] -= 1
class Patron(FSM): class Patron(FSM, NetworkAgent):
'''Agent that looks for friends to drink with. It will do three things: '''Agent that looks for friends to drink with. It will do three things:
1) Look for other patrons to drink with 1) Look for other patrons to drink with
2) Look for a bar where the agent and other agents in the same group can get in. 2) Look for a bar where the agent and other agents in the same group can get in.
@ -151,7 +151,7 @@ class Patron(FSM):
return befriended return befriended
class Police(FSM): class Police(FSM, NetworkAgent):
'''Simple agent to take drunk people out of pubs.''' '''Simple agent to take drunk people out of pubs.'''
level = logging.INFO level = logging.INFO

View File

@ -10,7 +10,7 @@ class Genders(Enum):
female = 'female' female = 'female'
class RabbitModel(FSM): class RabbitModel(FSM, NetworkAgent):
defaults = { defaults = {
'age': 0, 'age': 0,
@ -110,12 +110,12 @@ class Female(RabbitModel):
self.info('A mother has died carrying a baby!!') self.info('A mother has died carrying a baby!!')
class RandomAccident(NetworkAgent): class RandomAccident(BaseAgent):
level = logging.DEBUG level = logging.DEBUG
def step(self): def step(self):
rabbits_total = self.topology.number_of_nodes() rabbits_total = self.env.topology.number_of_nodes()
if 'rabbits_alive' not in self.env: if 'rabbits_alive' not in self.env:
self.env['rabbits_alive'] = 0 self.env['rabbits_alive'] = 0
rabbits_alive = self.env.get('rabbits_alive', rabbits_total) rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
@ -131,5 +131,5 @@ class RandomAccident(NetworkAgent):
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive'])) self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
i.set_state(i.dead) i.set_state(i.dead)
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes(): if self.env.count_agents(state_id=RabbitModel.dead.id) == self.env.topology.number_of_nodes():
self.die() self.die()

View File

@ -55,6 +55,7 @@ class BaseAgent(MesaAgent, MutableMapping):
raise Exception() raise Exception()
assert isinstance(unique_id, int) assert isinstance(unique_id, int)
super().__init__(unique_id=unique_id, model=model) super().__init__(unique_id=unique_id, model=model)
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id) self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
@ -78,6 +79,9 @@ class BaseAgent(MesaAgent, MutableMapping):
if not hasattr(self, k) or getattr(self, k) is None: if not hasattr(self, k) or getattr(self, k) is None:
setattr(self, k, v) setattr(self, k, v)
def __hash__(self):
return hash(self.unique_id)
# TODO: refactor to clean up mesa compatibility # TODO: refactor to clean up mesa compatibility
@property @property
def id(self): def id(self):
@ -185,16 +189,14 @@ class BaseAgent(MesaAgent, MutableMapping):
# Agent = BaseAgent # Agent = BaseAgent
class NetworkAgent(BaseAgent): class NetworkAgent(BaseAgent):
def __init__(self,
*args, @property
graph_name: str, def topology(self):
node_id: int = None, return self.env.topology_for(self.unique_id)
**kwargs,
): @property
super().__init__(*args, **kwargs) def node_id(self):
self.graph_name = graph_name return self.env.node_id_for(self.unique_id)
self.topology = self.env.topologies[self.graph_name]
self.node_id = node_id
@property @property
def G(self): def G(self):
@ -215,15 +217,19 @@ class NetworkAgent(BaseAgent):
it = islice(it, limit) it = islice(it, limit)
return list(it) return list(it)
def iter_agents(self, agents=None, limit_neighbors=False, **kwargs): def iter_agents(self, unique_id=None, limit_neighbors=False, **kwargs):
if limit_neighbors: if limit_neighbors:
agents = self.topology.neighbors(self.unique_id) unique_id = [self.topology.nodes[node]['agent_id'] for node in self.topology.neighbors(self.node_id)]
if not unique_id:
return
yield from self.model.agents(unique_id=unique_id, **kwargs)
return self.model.agents(ids=agents, **kwargs)
def subgraph(self, center=True, **kwargs): def subgraph(self, center=True, **kwargs):
include = [self] if center else [] include = [self] if center else []
return self.topology.subgraph(n.unique_id for n in list(self.get_agents(**kwargs))+include) G = self.topology.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
return G
def remove_node(self, unique_id): def remove_node(self, unique_id):
self.topology.remove_node(unique_id) self.topology.remove_node(unique_id)
@ -366,7 +372,7 @@ def prob(prob=1):
def calculate_distribution(network_agents=None, def calculate_distribution(network_agents=None,
agent_type=None): agent_class=None):
''' '''
Calculate the threshold values (thresholds for a uniform distribution) Calculate the threshold values (thresholds for a uniform distribution)
of an agent distribution given the weights of each agent type. of an agent distribution given the weights of each agent type.
@ -374,13 +380,13 @@ def calculate_distribution(network_agents=None,
The input has this form: :: The input has this form: ::
[ [
{'agent_type': 'agent_type_1', {'agent_class': 'agent_class_1',
'weight': 0.2, 'weight': 0.2,
'state': { 'state': {
'id': 0 'id': 0
} }
}, },
{'agent_type': 'agent_type_2', {'agent_class': 'agent_class_2',
'weight': 0.8, 'weight': 0.8,
'state': { 'state': {
'id': 1 'id': 1
@ -389,12 +395,12 @@ def calculate_distribution(network_agents=None,
] ]
In this example, 20% of the nodes will be marked as type In this example, 20% of the nodes will be marked as type
'agent_type_1'. 'agent_class_1'.
''' '''
if network_agents: if network_agents:
network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')] network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')]
elif agent_type: elif agent_class:
network_agents = [{'agent_type': agent_type}] network_agents = [{'agent_class': agent_class}]
else: else:
raise ValueError('Specify a distribution or a default agent type') raise ValueError('Specify a distribution or a default agent type')
@ -414,11 +420,11 @@ def calculate_distribution(network_agents=None,
return network_agents return network_agents
def serialize_type(agent_type, known_modules=[], **kwargs): def serialize_type(agent_class, known_modules=[], **kwargs):
if isinstance(agent_type, str): if isinstance(agent_class, str):
return agent_type return agent_class
known_modules += ['soil.agents'] known_modules += ['soil.agents']
return serialization.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[1] # Get the name of the class
def serialize_definition(network_agents, known_modules=[]): def serialize_definition(network_agents, known_modules=[]):
@ -430,23 +436,23 @@ def serialize_definition(network_agents, known_modules=[]):
for v in d: for v in d:
if 'threshold' in v: if 'threshold' in v:
del v['threshold'] del v['threshold']
v['agent_type'] = serialize_type(v['agent_type'], v['agent_class'] = serialize_type(v['agent_class'],
known_modules=known_modules) known_modules=known_modules)
return d return d
def deserialize_type(agent_type, known_modules=[]): def deserialize_type(agent_class, known_modules=[]):
if not isinstance(agent_type, str): if not isinstance(agent_class, str):
return agent_type return agent_class
known = known_modules + ['soil.agents', 'soil.agents.custom' ] known = known_modules + ['soil.agents', 'soil.agents.custom' ]
agent_type = serialization.deserializer(agent_type, known_modules=known) agent_class = serialization.deserializer(agent_class, known_modules=known)
return agent_type return agent_class
def deserialize_definition(ind, **kwargs): def deserialize_definition(ind, **kwargs):
d = deepcopy(ind) d = deepcopy(ind)
for v in d: for v in d:
v['agent_type'] = deserialize_type(v['agent_type'], **kwargs) v['agent_class'] = deserialize_type(v['agent_class'], **kwargs)
return d return d
@ -461,7 +467,7 @@ def _validate_states(states, topology):
return states return states
def _convert_agent_types(ind, to_string=False, **kwargs): def _convert_agent_classs(ind, to_string=False, **kwargs):
'''Convenience method to allow specifying agents by class or class name.''' '''Convenience method to allow specifying agents by class or class name.'''
if to_string: if to_string:
return serialize_definition(ind, **kwargs) return serialize_definition(ind, **kwargs)
@ -480,7 +486,7 @@ def _agent_from_definition(definition, value=-1, unique_id=None):
state = {} state = {}
if 'state' in d: if 'state' in d:
state = deepcopy(d['state']) state = deepcopy(d['state'])
return d['agent_type'], state return d['agent_class'], state
raise Exception('Definition for value {} not found in: {}'.format(value, definition)) raise Exception('Definition for value {} not found in: {}'.format(value, definition))
@ -576,8 +582,11 @@ class AgentView(Mapping, Set):
return group[agent_id] return group[agent_id]
raise ValueError(f"Agent {agent_id} not found") raise ValueError(f"Agent {agent_id} not found")
def filter(self, *group_ids, **kwargs): def filter(self, *args, **kwargs):
yield from filter_groups(self._agents, group_ids=group_ids, **kwargs) yield from filter_groups(self._agents, *args, **kwargs)
def one(self, *args, **kwargs):
return next(filter_groups(self._agents, *args, **kwargs))
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return list(self.filter(*args, **kwargs)) return list(self.filter(*args, **kwargs))
@ -586,41 +595,57 @@ class AgentView(Mapping, Set):
return any(agent_id in g for g in self._agents) return any(agent_id in g for g in self._agents)
def __str__(self): def __str__(self):
return str(list(a.id for a in self)) return str(list(a.unique_id for a in self))
def __repr__(self): def __repr__(self):
return f"{self.__class__.__name__}({self})" return f"{self.__class__.__name__}({self})"
def filter_groups(groups, group_ids=None, **kwargs): def filter_groups(groups, *, group=None, **kwargs):
assert isinstance(groups, dict) assert isinstance(groups, dict)
if group_ids:
groups = list(groups[g] for g in group_ids if g in groups) if group is not None and not isinstance(group, list):
group = [group]
if group:
groups = list(groups[g] for g in group if g in groups)
else: else:
groups = list(groups.values()) groups = list(groups.values())
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups) agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
yield from agents yield from agents
def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, state=None, **kwargs): def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None, **kwargs):
''' '''
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id). Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
''' '''
assert isinstance(group, dict) assert isinstance(group, dict)
ids = []
if unique_id is not None:
if isinstance(unique_id, list):
ids += unique_id
else:
ids.append(unique_id)
if id_args:
ids += id_args
if state_id is not None and not isinstance(state_id, (tuple, list)): if state_id is not None and not isinstance(state_id, (tuple, list)):
state_id = tuple([state_id]) state_id = tuple([state_id])
if agent_type is not None: if agent_class is not None:
agent_class = deserialize_type(agent_class)
try: try:
agent_type = tuple(agent_type) agent_class = tuple(agent_class)
except TypeError: except TypeError:
agent_type = tuple([agent_type]) agent_class = tuple([agent_class])
if ids: if ids:
agents = (v[aid] for aid in ids if aid in group) agents = (group[aid] for aid in ids if aid in group)
else: else:
agents = (a for a in group.values()) agents = (a for a in group.values())
@ -631,8 +656,8 @@ def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, s
if state_id is not None: if state_id is not None:
f = filter(lambda agent: agent.get('state_id', None) in state_id, f) f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
if agent_type is not None: if agent_class is not None:
f = filter(lambda agent: isinstance(agent, agent_type), f) f = filter(lambda agent: isinstance(agent, agent_class), f)
state = state or dict() state = state or dict()
state.update(kwargs) state.update(kwargs)
@ -660,7 +685,7 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
if cfg.fixed is not None: if cfg.fixed is not None:
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env) agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
if cfg.distribution: if cfg.distribution:
n = cfg.n or len(env.topologies[cfg.topology]) n = cfg.n or len(env.topologies[cfg.topology or default.topology])
target = n - len(agents) target = n - len(agents)
agents.update(_from_distro(cfg.distribution, target, agents.update(_from_distro(cfg.distribution, target,
topology=cfg.topology or default.topology, topology=cfg.topology or default.topology,
@ -674,6 +699,8 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
else: else:
filtered = list(agents) filtered = list(agents)
if attrs.n > len(filtered):
raise ValueError(f'Not enough agents to sample. Got {len(filtered)}, expected >= {attrs.n}')
for agent in random.sample(filtered, attrs.n): for agent in random.sample(filtered, attrs.n):
agent.state.update(attrs.state) agent.state.update(attrs.state)
@ -684,18 +711,20 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: conf
agents = {} agents = {}
for fixed in lst: for fixed in lst:
agent_id = fixed.agent_id agent_id = fixed.agent_id
if agent_id is None: if agent_id is None:
agent_id = env.next_id() agent_id = env.next_id()
cls = serialization.deserialize(fixed.agent_class or default.agent_class) cls = serialization.deserialize(fixed.agent_class or default.agent_class)
state = fixed.state.copy() state = fixed.state.copy()
state.update(default.state) state.update(default.state)
agent = cls(unique_id=agent_id, agent = cls(unique_id=agent_id,
model=env, model=env,
graph_name=fixed.topology or topology or default.topology, **state)
**state) topology = fixed.topology if (fixed.topology is not None) else (topology or default.topology)
agents[agent.unique_id] = agent if topology:
env.agent_to_node(agent_id, topology, fixed.node_id)
agents[agent.unique_id] = agent
return agents return agents
@ -741,8 +770,12 @@ def _from_distro(distro: List[config.AgentDistro],
cls = classes[idx] cls = classes[idx]
agent_id = env.next_id() agent_id = env.next_id()
state = d.state.copy() state = d.state.copy()
state.update(default.state) if default:
agent = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state) state.update(default.state)
agent = cls(unique_id=agent_id, model=env, **state)
topology = d.topology if (d.topology is not None) else topology or default.topology
if topology:
env.agent_to_node(agent.unique_id, topology)
assert agent.name is not None assert agent.name is not None
assert agent.name != 'None' assert agent.name != 'None'
assert agent.name assert agent.name

View File

@ -7,6 +7,7 @@ import sys
from typing import Any, Callable, Dict, List, Optional, Union, Type from typing import Any, Callable, Dict, List, Optional, Union, Type
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
import networkx as nx
class General(BaseModel): class General(BaseModel):
id: str = 'Unnamed Simulation' id: str = 'Unnamed Simulation'
@ -50,9 +51,12 @@ class NetParams(BaseModel, extra=Extra.allow):
class NetConfig(BaseModel): class NetConfig(BaseModel):
group: str = 'network' group: str = 'network'
params: Optional[NetParams] params: Optional[NetParams]
topology: Optional[Topology] topology: Optional[Union[Topology, nx.Graph]]
path: Optional[str] path: Optional[str]
class Config:
arbitrary_types_allowed = True
@staticmethod @staticmethod
def default(): def default():
return NetConfig(topology=None, params=None) return NetConfig(topology=None, params=None)
@ -77,7 +81,8 @@ class EnvConfig(BaseModel):
class SingleAgentConfig(BaseModel): class SingleAgentConfig(BaseModel):
agent_class: Optional[Union[Type, str]] = None agent_class: Optional[Union[Type, str]] = None
agent_id: Optional[int] = None agent_id: Optional[int] = None
topology: Optional[str] = 'default' topology: Optional[str] = None
node_id: Optional[Union[int, str]] = None
name: Optional[str] = None name: Optional[str] = None
state: Optional[Dict[str, Any]] = {} state: Optional[Dict[str, Any]] = {}
@ -186,9 +191,7 @@ def convert_old(old, strict=True):
if 'agent_id' in agent: if 'agent_id' in agent:
agent['name'] = agent['agent_id'] agent['name'] = agent['agent_id']
del agent['agent_id'] del agent['agent_id']
agents['environment']['fixed'].append(updated_agent(agent)) agents['environment']['fixed'].append(updated_agent(agent))
else:
agents['environment']['distribution'].append(updated_agent(agent))
by_weight = [] by_weight = []
fixed = [] fixed = []
@ -206,10 +209,10 @@ def convert_old(old, strict=True):
if 'agent_type' in old and (not fixed and not by_weight): if 'agent_type' in old and (not fixed and not by_weight):
agents['network']['topology'] = 'default' agents['network']['topology'] = 'default'
by_weight = [{'agent_type': old['agent_type']}] by_weight = [{'agent_class': old['agent_type']}]
# TODO: translate states # TODO: translate states properly
if 'states' in old: if 'states' in old:
states = old['states'] states = old['states']
if isinstance(states, dict): if isinstance(states, dict):
@ -217,7 +220,7 @@ def convert_old(old, strict=True):
else: else:
states = enumerate(states) states = enumerate(states)
for (k, v) in states: for (k, v) in states:
override.append({'filter': {'id': k}, override.append({'filter': {'node_id': k},
'state': v 'state': v
}) })

View File

@ -1,264 +0,0 @@
from pydantic import BaseModel, ValidationError, validator
import yaml
import os
import sys
import networkx as nx
import collections.abc
from . import serialization, utils, basestring, agents
class Config(collections.abc.Mapping):
"""
1) agent type can be specified by name or by class.
2) instead of just one type, a network agents distribution can be used.
The distribution specifies the weight (or probability) of each
agent type in the topology. This is an example distribution: ::
[
{'agent_type': 'agent_type_1',
'weight': 0.2,
'state': {
'id': 0
}
},
{'agent_type': 'agent_type_2',
'weight': 0.8,
'state': {
'id': 1
}
}
]
In this example, 20% of the nodes will be marked as type
'agent_type_1'.
3) if no initial state is given, each node's state will be set
to `{'id': 0}`.
Parameters
---------
name : str, optional
name of the Simulation
group : str, optional
a group name can be used to link simulations
topology (optional): networkx.Graph instance or Node-Link topology as a dict or string (will be loaded with `json_graph.node_link_graph(topology`).
network_params : dict
parameters used to create a topology with networkx, if no topology is given
network_agents : dict
definition of agents to populate the topology with
agent_type : NetworkAgent subclass, optional
Default type of NetworkAgent to use for nodes not specified in network_agents
states : list, optional
List of initial states corresponding to the nodes in the topology. Basic form is a list of integers
whose value indicates the state
dir_path: str, optional
Directory path to load simulation assets (files, modules...)
seed : str, optional
Seed to use for the random generator
num_trials : int, optional
Number of independent simulation runs
max_time : int, optional
Maximum step/time for each simulation
environment_params : dict, optional
Dictionary of globally-shared environmental parameters
environment_agents: dict, optional
Similar to network_agents. Distribution of Agents that control the environment
environment_class: soil.environment.Environment subclass, optional
Class for the environment. It defailts to soil.environment.Environment
"""
__slots__ = 'name', 'agent_type', 'group', 'description', 'network_agents', 'environment_agents', 'states', 'default_state', 'interval', 'network_params', 'seed', 'num_trials', 'max_time', 'topology', 'schedule', 'initial_time', 'environment_params', 'environment_class', 'dir_path', '_added_to_path', 'visualization_params'
def __init__(self, name=None,
group=None,
agent_type='BaseAgent',
network_agents=None,
environment_agents=None,
states=None,
description=None,
default_state=None,
interval=1,
network_params=None,
seed=None,
num_trials=1,
max_time=None,
topology=None,
schedule=None,
initial_time=0,
environment_params={},
environment_class='soil.Environment',
dir_path=None,
visualization_params=None,
):
self.network_params = network_params
self.name = name or 'Unnamed'
self.description = description or 'No simulation description available'
self.seed = str(seed or name)
self.group = group or ''
self.num_trials = num_trials
self.max_time = max_time
self.default_state = default_state or {}
self.dir_path = dir_path or os.getcwd()
self.interval = interval
self.visualization_params = visualization_params or {}
self._added_to_path = list(x for x in [os.getcwd(), self.dir_path] if x not in sys.path)
sys.path += self._added_to_path
self.topology = topology
self.schedule = schedule
self.initial_time = initial_time
self.environment_class = environment_class
self.environment_params = dict(environment_params)
#TODO: Check agent distro vs fixed agents
self.environment_agents = environment_agents or []
self.agent_type = agent_type
self.network_agents = network_agents or {}
self.states = states or {}
def validate(self):
agents._validate_states(self.states,
self._topology)
def calculate(self):
return CalculatedConfig(self)
def restore_path(self):
for added in self._added_to_path:
sys.path.remove(added)
def to_yaml(self):
return yaml.dump(self.to_dict())
def dump_yaml(self, f=None, outdir=None):
if not f and not outdir:
raise ValueError('specify a file or an output directory')
if not f:
f = os.path.join(outdir, '{}.dumped.yml'.format(self.name))
with utils.open_or_reuse(f, 'w') as f:
f.write(self.to_yaml())
def to_yaml(self):
return yaml.dump(self.to_dict())
# TODO: See note on getstate
def to_dict(self):
return dict(self)
def __repr__(self):
return self.to_yaml()
def dump_yaml(self, f=None, outdir=None):
if not f and not outdir:
raise ValueError('specify a file or an output directory')
if not f:
f = os.path.join(outdir, '{}.dumped.yml'.format(self.name))
with utils.open_or_reuse(f, 'w') as f:
f.write(self.to_yaml())
def __getitem__(self, key):
return getattr(self, key)
def __iter__(self):
return (k for k in self.__slots__ if k[0] != '_')
def __len__(self):
return len(self.__slots__)
def dump_pickle(self, f=None, outdir=None):
if not outdir and not f:
raise ValueError('specify a file or an output directory')
if not f:
f = os.path.join(outdir,
'{}.simulation.pickle'.format(self.name))
with utils.open_or_reuse(f, 'wb') as f:
pickle.dump(self, f)
# TODO: remove this. A config should be sendable regardless. Non-pickable objects could be computed via properties and the like
# def __getstate__(self):
# state={}
# for k, v in self.__dict__.items():
# if k[0] != '_':
# state[k] = v
# state['topology'] = json_graph.node_link_data(self.topology)
# state['network_agents'] = agents.serialize_definition(self.network_agents,
# known_modules = [])
# state['environment_agents'] = agents.serialize_definition(self.environment_agents,
# known_modules = [])
# state['environment_class'] = serialization.serialize(self.environment_class,
# known_modules=['soil.environment'])[1] # func, name
# if state['load_module'] is None:
# del state['load_module']
# return state
# # TODO: remove, same as __getstate__
# def __setstate__(self, state):
# self.__dict__ = state
# self.load_module = getattr(self, 'load_module', None)
# if self.dir_path not in sys.path:
# sys.path += [self.dir_path, os.getcwd()]
# self.topology = json_graph.node_link_graph(state['topology'])
# self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
# self.environment_agents = agents._convert_agent_types(self.environment_agents,
# known_modules=[self.load_module])
# self.environment_class = serialization.deserialize(self.environment_class,
# known_modules=[self.load_module,
# 'soil.environment', ]) # func, name
class CalculatedConfig(Config):
def __init__(self, config):
"""
Returns a configuration object that replaces some "plain" attributes (e.g., `environment_class` string) into
a Python object (`soil.environment.Environment` class).
"""
self._config = config
values = dict(config)
values['environment_class'] = self._environment_class()
values['environment_agents'] = self._environment_agents()
values['topology'] = self._topology()
values['network_agents'] = self._network_agents()
values['agent_type'] = serialization.deserialize(self.agent_type, known_modules=['soil.agents'])
return values
def _topology(self):
topology = self._config.topology
if topology is None:
topology = serialization.load_network(self._config.network_params,
dir_path=self._config.dir_path)
elif isinstance(topology, basestring) or isinstance(topology, dict):
topology = json_graph.node_link_graph(topology)
return nx.Graph(topology)
def _environment_class(self):
return serialization.deserialize(self._config.environment_class,
known_modules=['soil.environment', ]) or Environment
def _environment_agents(self):
return agents._convert_agent_types(self._config.environment_agents)
def _network_agents(self):
distro = agents.calculate_distribution(self._config.network_agents,
self._config.agent_type)
return agents._convert_agent_types(distro)
def _environment_class(self):
return serialization.deserialize(self._config.environment_class,
known_modules=['soil.environment', ]) # func, name

View File

@ -15,6 +15,7 @@ from networkx.readwrite import json_graph
import networkx as nx import networkx as nx
from mesa import Model from mesa import Model
from mesa.datacollection import DataCollector
from . import serialization, agents, analysis, utils, time, config, network from . import serialization, agents, analysis, utils, time, config, network
@ -41,6 +42,9 @@ class Environment(Model):
interval=1, interval=1,
agents: Dict[str, config.AgentConfig] = {}, agents: Dict[str, config.AgentConfig] = {},
topologies: Dict[str, config.NetConfig] = {}, topologies: Dict[str, config.NetConfig] = {},
agent_reporters: Optional[Any] = None,
model_reporters: Optional[Any] = None,
tables: Optional[Any] = None,
**env_params): **env_params):
super().__init__() super().__init__()
@ -61,6 +65,7 @@ class Environment(Model):
self.topologies = {} self.topologies = {}
self._node_ids = {}
for (name, cfg) in topologies.items(): for (name, cfg) in topologies.items():
self.set_topology(cfg=cfg, self.set_topology(cfg=cfg,
graph=name) graph=name)
@ -72,6 +77,7 @@ class Environment(Model):
self['SEED'] = seed self['SEED'] = seed
self.logger = utils.logger.getChild(self.id) self.logger = utils.logger.getChild(self.id)
self.datacollector = DataCollector(model_reporters, agent_reporters, tables)
@property @property
def topology(self): def topology(self):
@ -79,8 +85,7 @@ class Environment(Model):
@property @property
def network_agents(self): def network_agents(self):
yield from self.agents(agent_type=agents.NetworkAgent, iterator=False) yield from self.agents(agent_class=agents.NetworkAgent)
@staticmethod @staticmethod
def from_config(conf: config.Config, trial_id, **kwargs) -> Environment: def from_config(conf: config.Config, trial_id, **kwargs) -> Environment:
@ -91,9 +96,10 @@ class Environment(Model):
seed = '{}_{}'.format(conf.general.seed, trial_id) seed = '{}_{}'.format(conf.general.seed, trial_id)
id = '{}_trial_{}'.format(conf.general.id, trial_id).replace('.', '-') id = '{}_trial_{}'.format(conf.general.id, trial_id).replace('.', '-')
opts = conf.environment.params.copy() opts = conf.environment.params.copy()
dir_path = conf.general.dir_path
opts.update(conf) opts.update(conf)
opts.update(kwargs) opts.update(kwargs)
env = serialization.deserialize(conf.environment.environment_class)(env_id=id, seed=seed, **opts) env = serialization.deserialize(conf.environment.environment_class)(env_id=id, seed=seed, dir_path=dir_path, **opts)
return env return env
@property @property
@ -103,12 +109,31 @@ class Environment(Model):
raise Exception('The environment has not been scheduled, so it has no sense of time') raise Exception('The environment has not been scheduled, so it has no sense of time')
def topology_for(self, agent_id):
return self.topologies[self._node_ids[agent_id][0]]
def node_id_for(self, agent_id):
return self._node_ids[agent_id][1]
def set_topology(self, cfg=None, dir_path=None, graph='default'): def set_topology(self, cfg=None, dir_path=None, graph='default'):
self.topologies[graph] = network.from_config(cfg, dir_path=dir_path) topology = cfg
if not isinstance(cfg, nx.Graph):
topology = network.from_config(cfg, dir_path=dir_path or self.dir_path)
self.topologies[graph] = topology
@property @property
def agents(self): def agents(self):
return agents.AgentView(self._agents) return agents.AgentView(self._agents)
def count_agents(self, *args, **kwargs):
return sum(1 for i in self.find_all(*args, **kwargs))
def find_all(self, *args, **kwargs):
return agents.AgentView(self._agents).filter(*args, **kwargs)
def find_one(self, *args, **kwargs):
return agents.AgentView(self._agents).one(*args, **kwargs)
@agents.setter @agents.setter
def agents(self, agents_def: Dict[str, config.AgentConfig]): def agents(self, agents_def: Dict[str, config.AgentConfig]):
@ -117,37 +142,47 @@ class Environment(Model):
for a in d.values(): for a in d.values():
self.schedule.add(a) self.schedule.add(a)
# @property
# def network_agents(self):
# for i in self.G.nodes():
# node = self.G.nodes[i]
# if 'agent' in node:
# yield node['agent']
def init_agent(self, agent_id, agent_definitions, graph='default'): def init_agent(self, agent_id, agent_definitions, graph='default'):
node = self.topologies[graph].nodes[agent_id] node = self.topologies[graph].nodes[agent_id]
init = False init = False
state = dict(node) state = dict(node)
agent_type = None agent_class = None
if 'agent_type' in self.states.get(agent_id, {}): if 'agent_class' in self.states.get(agent_id, {}):
agent_type = self.states[agent_id]['agent_type'] agent_class = self.states[agent_id]['agent_class']
elif 'agent_type' in node: elif 'agent_class' in node:
agent_type = node['agent_type'] agent_class = node['agent_class']
elif 'agent_type' in self.default_state: elif 'agent_class' in self.default_state:
agent_type = self.default_state['agent_type'] agent_class = self.default_state['agent_class']
if agent_type: if agent_class:
agent_type = agents.deserialize_type(agent_type) agent_class = agents.deserialize_type(agent_class)
elif agent_definitions: elif agent_definitions:
agent_type, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id) agent_class, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id)
else: else:
serialization.logger.debug('Skipping node {}'.format(agent_id)) serialization.logger.debug('Skipping node {}'.format(agent_id))
return return
return self.set_agent(agent_id, agent_type, state) return self.set_agent(agent_id, agent_class, state)
def set_agent(self, agent_id, agent_type, state=None, graph='default'): def agent_to_node(self, agent_id, graph_name='default', node_id=None, shuffle=False):
#TODO: test
if node_id is None:
G = self.topologies[graph_name]
candidates = list(G.nodes(data=True))
if shuffle:
random.shuffle(candidates)
for next_id, data in candidates:
if data.get('agent_id', None) is None:
node_id = next_id
data['agent_id'] = agent_id
break
self._node_ids[agent_id] = (graph_name, node_id)
print(self._node_ids)
def set_agent(self, agent_id, agent_class, state=None, graph='default'):
node = self.topologies[graph].nodes[agent_id] node = self.topologies[graph].nodes[agent_id]
defstate = deepcopy(self.default_state) or {} defstate = deepcopy(self.default_state) or {}
defstate.update(self.states.get(agent_id, {})) defstate.update(self.states.get(agent_id, {}))
@ -155,9 +190,9 @@ class Environment(Model):
if state: if state:
defstate.update(state) defstate.update(state)
a = None a = None
if agent_type: if agent_class:
state = defstate state = defstate
a = agent_type(model=self, a = agent_class(model=self,
unique_id=agent_id unique_id=agent_id
) )
@ -168,10 +203,10 @@ class Environment(Model):
self.schedule.add(a) self.schedule.add(a)
return a return a
def add_node(self, agent_type, state=None, graph='default'): def add_node(self, agent_class, state=None, graph='default'):
agent_id = int(len(self.topologies[graph].nodes())) agent_id = int(len(self.topologies[graph].nodes()))
self.topologies[graph].add_node(agent_id) self.topologies[graph].add_node(agent_id)
a = self.set_agent(agent_id, agent_type, state, graph=graph) a = self.set_agent(agent_id, agent_class, state, graph=graph)
a['visible'] = True a['visible'] = True
return a return a
@ -201,6 +236,7 @@ class Environment(Model):
''' '''
super().step() super().step()
self.schedule.step() self.schedule.step()
self.datacollector.collect(self)
def run(self, until, *args, **kwargs): def run(self, until, *args, **kwargs):
until = until or float('inf') until = until or float('inf')

View File

@ -1,5 +1,4 @@
import os import os
import csv as csvlib
from time import time as current_time from time import time as current_time
from io import BytesIO from io import BytesIO
from sqlalchemy import create_engine from sqlalchemy import create_engine
@ -59,7 +58,7 @@ class Exporter:
'''Method to call when the simulation starts''' '''Method to call when the simulation starts'''
pass pass
def sim_end(self, stats): def sim_end(self):
'''Method to call when the simulation ends''' '''Method to call when the simulation ends'''
pass pass
@ -67,7 +66,7 @@ class Exporter:
'''Method to call when a trial start''' '''Method to call when a trial start'''
pass pass
def trial_end(self, env, stats): def trial_end(self, env):
'''Method to call when a trial ends''' '''Method to call when a trial ends'''
pass pass
@ -115,31 +114,35 @@ class default(Exporter):
# self.simulation.dump_sqlite(f) # self.simulation.dump_sqlite(f)
def get_dc_dfs(dc):
dfs = {'env': dc.get_model_vars_dataframe(),
'agents': dc.get_agent_vars_dataframe }
for table_name in dc.tables:
dfs[table_name] = dc.get_table_dataframe(table_name)
yield from dfs.items()
class csv(Exporter): class csv(Exporter):
'''Export the state of each environment (and its agents) in a separate CSV file''' '''Export the state of each environment (and its agents) in a separate CSV file'''
def trial_end(self, env, stats): def trial_end(self, env):
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name, with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
env.name, env.id,
self.outdir)): self.outdir)):
for (df_name, df) in get_dc_dfs(env.datacollector):
with self.output('{}.stats.{}.csv'.format(env.name, stats.name)) as f: with self.output('{}.stats.{}.csv'.format(env.id, df_name)) as f:
statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL) df.to_csv(f)
for stat in stats:
statwriter.writerow(stat)
class gexf(Exporter): class gexf(Exporter):
def trial_end(self, env, stats): def trial_end(self, env):
if self.dry_run: if self.dry_run:
logger.info('Not dumping GEXF in dry_run mode') logger.info('Not dumping GEXF in dry_run mode')
return return
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name, with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
env.name)): env.id)):
with self.output('{}.gexf'.format(env.name), mode='wb') as f: with self.output('{}.gexf'.format(env.id), mode='wb') as f:
self.dump_gexf(env, f) self.dump_gexf(env, f)
def dump_gexf(self, env, f): def dump_gexf(self, env, f):
@ -159,25 +162,25 @@ class dummy(Exporter):
with self.output('dummy', 'w') as f: with self.output('dummy', 'w') as f:
f.write('simulation started @ {}\n'.format(current_time())) f.write('simulation started @ {}\n'.format(current_time()))
def trial_end(self, env, stats): def trial_start(self, env):
with self.output('dummy', 'w') as f: with self.output('dummy', 'w') as f:
for i in stats: f.write('trial started@ {}\n'.format(current_time()))
f.write(','.join(map(str, i)))
f.write('\n')
def sim_end(self, stats): def trial_end(self, env):
with self.output('dummy', 'w') as f:
f.write('trial ended@ {}\n'.format(current_time()))
def sim_end(self):
with self.output('dummy', 'a') as f: with self.output('dummy', 'a') as f:
f.write('simulation ended @ {}\n'.format(current_time())) f.write('simulation ended @ {}\n'.format(current_time()))
class graphdrawing(Exporter): class graphdrawing(Exporter):
def trial_end(self, env, stats): def trial_end(self, env):
# Outside effects # Outside effects
f = plt.figure() f = plt.figure()
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111)) nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
with open('graph-{}.png'.format(env.name)) as f: with open('graph-{}.png'.format(env.id)) as f:
f.savefig(f) f.savefig(f)
''' '''

View File

@ -16,7 +16,6 @@ from . import serialization, utils, basestring, agents
from .environment import Environment from .environment import Environment
from .utils import logger from .utils import logger
from .exporters import default from .exporters import default
from .stats import defaultStats
from .config import Config, convert_old from .config import Config, convert_old
@ -71,8 +70,8 @@ class Simulation:
**kwargs) **kwargs)
def run_gen(self, parallel=False, dry_run=False, def run_gen(self, parallel=False, dry_run=False,
exporters=[default, ], stats=[], outdir=None, exporter_params={}, exporters=[default, ], outdir=None, exporter_params={},
stats_params={}, log_level=None, log_level=None,
**kwargs): **kwargs):
'''Run the simulation and yield the resulting environments.''' '''Run the simulation and yield the resulting environments.'''
if log_level: if log_level:
@ -85,15 +84,8 @@ class Simulation:
dry_run=dry_run, dry_run=dry_run,
outdir=outdir, outdir=outdir,
**exporter_params) **exporter_params)
stats = serialization.deserialize_all(simulation=self,
names=stats,
known_modules=['soil.stats',],
**stats_params)
with utils.timer('simulation {}'.format(self.config.general.id)): with utils.timer('simulation {}'.format(self.config.general.id)):
for stat in stats:
stat.sim_start()
for exporter in exporters: for exporter in exporters:
exporter.sim_start() exporter.sim_start()
@ -104,32 +96,13 @@ class Simulation:
for exporter in exporters: for exporter in exporters:
exporter.trial_start(env) exporter.trial_start(env)
collected = list(stat.trial_end(env) for stat in stats)
saved = self._update_stats(collected, t_step=env.now, trial_id=env.id)
for exporter in exporters: for exporter in exporters:
exporter.trial_end(env, saved) exporter.trial_end(env)
yield env yield env
collected = list(stat.end() for stat in stats)
saved = self._update_stats(collected)
for stat in stats:
stat.sim_end()
for exporter in exporters: for exporter in exporters:
exporter.sim_end(saved) exporter.sim_end()
def _update_stats(self, collection, **kwargs):
stats = dict(kwargs)
for stat in collection:
stats.update(stat)
return stats
def log_stats(self, stats):
logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
def get_env(self, trial_id=0, **kwargs): def get_env(self, trial_id=0, **kwargs):
'''Create an environment for a trial of the simulation''' '''Create an environment for a trial of the simulation'''

View File

@ -1,111 +0,0 @@
import pandas as pd
from collections import Counter
class Stats:
'''
Interface for all stats. It is not necessary, but it is useful
if you don't plan to implement all the methods.
'''
def __init__(self, simulation, name=None):
self.name = name or type(self).__name__
self.simulation = simulation
def sim_start(self):
'''Method to call when the simulation starts'''
pass
def sim_end(self):
'''Method to call when the simulation ends'''
return {}
def trial_start(self, env):
'''Method to call when a trial starts'''
return {}
def trial_end(self, env):
'''Method to call when a trial ends'''
return {}
class distribution(Stats):
'''
Calculate the distribution of agent states at the end of each trial,
the mean value, and its deviation.
'''
def sim_start(self):
self.means = []
self.counts = []
def trial_end(self, env):
df = pd.DataFrame(env.state_to_tuples())
df = df.drop('SEED', axis=1)
ix = df.index[-1]
attrs = df.columns.get_level_values(0)
vc = {}
stats = {
'mean': {},
'count': {},
}
for a in attrs:
t = df.loc[(ix, a)]
try:
stats['mean'][a] = t.mean()
self.means.append(('mean', a, t.mean()))
except TypeError:
pass
for name, count in t.value_counts().iteritems():
if a not in stats['count']:
stats['count'][a] = {}
stats['count'][a][name] = count
self.counts.append(('count', a, name, count))
return stats
def sim_end(self):
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
count = {}
mean = {}
if self.means:
res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
mean = res['value'].to_dict()
if self.counts:
res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
for k,v in res['count'].to_dict().items():
if k not in count:
count[k] = {}
for tup, times in v.items():
subkey, subcount = tup
if subkey not in count[k]:
count[k][subkey] = {}
count[k][subkey][subcount] = times
return {'count': count, 'mean': mean}
class defaultStats(Stats):
def trial_end(self, env):
c = Counter()
c.update(a.__class__.__name__ for a in env.network_agents)
c2 = Counter()
c2.update(a['id'] for a in env.network_agents)
return {
'network ': {
'n_nodes': env.G.number_of_nodes(),
'n_edges': env.G.number_of_edges(),
},
'agents': {
'model_count': dict(c),
'state_count': dict(c2),
}
}

View File

@ -29,11 +29,11 @@ agents:
weight: 0.6 weight: 0.6
override: override:
- filter: - filter:
id: 0 node_id: 0
state: state:
name: 'The first node' name: 'The first node'
- filter: - filter:
id: 1 node_id: 1
state: state:
name: 'The second node' name: 'The second node'

View File

@ -71,11 +71,11 @@ class TestConfig(TestCase):
s = simulation.from_config(cfg) s = simulation.from_config(cfg)
env = s.get_env() env = s.get_env()
assert len(env.topologies['default'].nodes) == 10 assert len(env.topologies['default'].nodes) == 10
assert len(env.agents('network')) == 10 assert len(env.agents(group='network')) == 10
assert len(env.agents('environment')) == 1 assert len(env.agents(group='environment')) == 1
assert sum(1 for a in env.agents('network') if isinstance(a, agents.CounterModel)) == 4 assert sum(1 for a in env.agents(group='network', agent_type=agents.CounterModel)) == 4
assert sum(1 for a in env.agents('network') if isinstance(a, agents.AggregatedCounter)) == 6 assert sum(1 for a in env.agents(group='network', agent_type=agents.AggregatedCounter)) == 6
def make_example_test(path, cfg): def make_example_test(path, cfg):
def wrapped(self): def wrapped(self):

View File

@ -7,8 +7,6 @@ from unittest import TestCase
from soil import exporters from soil import exporters
from soil import simulation from soil import simulation
from soil.stats import distribution
class Dummy(exporters.Exporter): class Dummy(exporters.Exporter):
started = False started = False
trials = 0 trials = 0
@ -22,13 +20,13 @@ class Dummy(exporters.Exporter):
self.__class__.called_start += 1 self.__class__.called_start += 1
self.__class__.started = True self.__class__.started = True
def trial_end(self, env, stats): def trial_end(self, env):
assert env assert env
self.__class__.trials += 1 self.__class__.trials += 1
self.__class__.total_time += env.now self.__class__.total_time += env.now
self.__class__.called_trial += 1 self.__class__.called_trial += 1
def sim_end(self, stats): def sim_end(self):
self.__class__.ended = True self.__class__.ended = True
self.__class__.called_end += 1 self.__class__.called_end += 1
@ -78,7 +76,6 @@ class Exporters(TestCase):
exporters.csv, exporters.csv,
exporters.gexf, exporters.gexf,
], ],
stats=[distribution,],
dry_run=False, dry_run=False,
outdir=tmpdir, outdir=tmpdir,
exporter_params={'copy_to': output}) exporter_params={'copy_to': output})

View File

@ -10,7 +10,7 @@ from functools import partial
from os.path import join from os.path import join
from soil import (simulation, Environment, agents, network, serialization, from soil import (simulation, Environment, agents, network, serialization,
utils) utils, config)
from soil.time import Delta from soil.time import Delta
ROOT = os.path.abspath(os.path.dirname(__file__)) ROOT = os.path.abspath(os.path.dirname(__file__))
@ -200,7 +200,6 @@ class TestMain(TestCase):
recovered = yaml.load(serial, Loader=yaml.SafeLoader) recovered = yaml.load(serial, Loader=yaml.SafeLoader)
for (k, v) in config.items(): for (k, v) in config.items():
assert recovered[k] == v assert recovered[k] == v
# assert config == recovered
def test_configuration_changes(self): def test_configuration_changes(self):
""" """
@ -294,11 +293,13 @@ class TestMain(TestCase):
G.add_node(3) G.add_node(3)
G.add_edge(1, 2) G.add_edge(1, 2)
distro = agents.calculate_distribution(agent_type=agents.NetworkAgent) distro = agents.calculate_distribution(agent_type=agents.NetworkAgent)
env = Environment(name='Test', topology=G, network_agents=distro) distro[0]['topology'] = 'default'
aconfig = config.AgentConfig(distribution=distro, topology='default')
env = Environment(name='Test', topologies={'default': G}, agents={'network': aconfig})
lst = list(env.network_agents) lst = list(env.network_agents)
a2 = env.get_agent(2) a2 = env.find_one(node_id=2)
a3 = env.get_agent(3) a3 = env.find_one(node_id=3)
assert len(a2.subgraph(limit_neighbors=True)) == 2 assert len(a2.subgraph(limit_neighbors=True)) == 2
assert len(a3.subgraph(limit_neighbors=True)) == 1 assert len(a3.subgraph(limit_neighbors=True)) == 1
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0 assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0

View File

@ -1,34 +0,0 @@
from unittest import TestCase
from soil import simulation, stats
from soil.utils import unflatten_dict
class Stats(TestCase):
def test_distribution(self):
'''The distribution exporter should write the number of agents in each state'''
config = {
'name': 'exporter_sim',
'network_params': {
'generator': 'complete_graph',
'n': 4
},
'agent_type': 'CounterModel',
'max_time': 2,
'num_trials': 5,
'environment_params': {}
}
s = simulation.from_config(config)
for env in s.run_simulation(stats=[stats.distribution]):
pass
# stats_res = unflatten_dict(dict(env._history['stats', -1, None]))
allstats = s.get_stats()
for stat in allstats:
assert 'count' in stat
assert 'mean' in stat
if 'trial_id' in stat:
assert stat['mean']['neighbors'] == 3
assert stat['count']['total']['4'] == 4
else:
assert stat['count']['count']['neighbors']['3'] == 20
assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors']