mirror of
https://github.com/gsi-upm/soil
synced 2024-11-21 18:52:28 +00:00
WIP: removed stats
This commit is contained in:
parent
3dc56892c1
commit
0a9c6d8b19
@ -14,7 +14,6 @@ network_agents:
|
|||||||
weight: 1
|
weight: 1
|
||||||
environment_class: social_wealth.MoneyEnv
|
environment_class: social_wealth.MoneyEnv
|
||||||
environment_params:
|
environment_params:
|
||||||
num_mesa_agents: 5
|
|
||||||
mesa_agent_type: social_wealth.MoneyAgent
|
mesa_agent_type: social_wealth.MoneyAgent
|
||||||
N: 10
|
N: 10
|
||||||
width: 50
|
width: 50
|
||||||
|
@ -71,10 +71,9 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
|
|||||||
|
|
||||||
class MoneyEnv(Environment):
|
class MoneyEnv(Environment):
|
||||||
"""A model with some number of agents."""
|
"""A model with some number of agents."""
|
||||||
def __init__(self, N, width, height, *args, network_params, **kwargs):
|
def __init__(self, width, height, *args, topologies, **kwargs):
|
||||||
|
|
||||||
network_params['n'] = N
|
super().__init__(*args, topologies=topologies, **kwargs)
|
||||||
super().__init__(*args, network_params=network_params, **kwargs)
|
|
||||||
self.grid = MultiGrid(width, height, False)
|
self.grid = MultiGrid(width, height, False)
|
||||||
|
|
||||||
# Create agents
|
# Create agents
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from soil.agents import FSM, state, default_state, prob
|
from soil.agents import FSM, NetworkAgent, state, default_state, prob
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class DumbViewer(FSM):
|
class DumbViewer(FSM, NetworkAgent):
|
||||||
'''
|
'''
|
||||||
A viewer that gets infected via TV (if it has one) and tries to infect
|
A viewer that gets infected via TV (if it has one) and tries to infect
|
||||||
its neighbors once it's infected.
|
its neighbors once it's infected.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from soil.agents import FSM, state, default_state
|
from soil.agents import FSM, NetworkAgent, state, default_state
|
||||||
from soil import Environment
|
from soil import Environment
|
||||||
from random import random, shuffle
|
from random import random, shuffle
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
@ -53,7 +53,7 @@ class CityPubs(Environment):
|
|||||||
pub['occupancy'] -= 1
|
pub['occupancy'] -= 1
|
||||||
|
|
||||||
|
|
||||||
class Patron(FSM):
|
class Patron(FSM, NetworkAgent):
|
||||||
'''Agent that looks for friends to drink with. It will do three things:
|
'''Agent that looks for friends to drink with. It will do three things:
|
||||||
1) Look for other patrons to drink with
|
1) Look for other patrons to drink with
|
||||||
2) Look for a bar where the agent and other agents in the same group can get in.
|
2) Look for a bar where the agent and other agents in the same group can get in.
|
||||||
@ -151,7 +151,7 @@ class Patron(FSM):
|
|||||||
return befriended
|
return befriended
|
||||||
|
|
||||||
|
|
||||||
class Police(FSM):
|
class Police(FSM, NetworkAgent):
|
||||||
'''Simple agent to take drunk people out of pubs.'''
|
'''Simple agent to take drunk people out of pubs.'''
|
||||||
level = logging.INFO
|
level = logging.INFO
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ class Genders(Enum):
|
|||||||
female = 'female'
|
female = 'female'
|
||||||
|
|
||||||
|
|
||||||
class RabbitModel(FSM):
|
class RabbitModel(FSM, NetworkAgent):
|
||||||
|
|
||||||
defaults = {
|
defaults = {
|
||||||
'age': 0,
|
'age': 0,
|
||||||
@ -110,12 +110,12 @@ class Female(RabbitModel):
|
|||||||
self.info('A mother has died carrying a baby!!')
|
self.info('A mother has died carrying a baby!!')
|
||||||
|
|
||||||
|
|
||||||
class RandomAccident(NetworkAgent):
|
class RandomAccident(BaseAgent):
|
||||||
|
|
||||||
level = logging.DEBUG
|
level = logging.DEBUG
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
rabbits_total = self.topology.number_of_nodes()
|
rabbits_total = self.env.topology.number_of_nodes()
|
||||||
if 'rabbits_alive' not in self.env:
|
if 'rabbits_alive' not in self.env:
|
||||||
self.env['rabbits_alive'] = 0
|
self.env['rabbits_alive'] = 0
|
||||||
rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
|
rabbits_alive = self.env.get('rabbits_alive', rabbits_total)
|
||||||
@ -131,5 +131,5 @@ class RandomAccident(NetworkAgent):
|
|||||||
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
self.log('Rabbits alive: {}'.format(self.env['rabbits_alive']))
|
||||||
i.set_state(i.dead)
|
i.set_state(i.dead)
|
||||||
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total))
|
||||||
if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes():
|
if self.env.count_agents(state_id=RabbitModel.dead.id) == self.env.topology.number_of_nodes():
|
||||||
self.die()
|
self.die()
|
||||||
|
@ -55,6 +55,7 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
raise Exception()
|
raise Exception()
|
||||||
assert isinstance(unique_id, int)
|
assert isinstance(unique_id, int)
|
||||||
super().__init__(unique_id=unique_id, model=model)
|
super().__init__(unique_id=unique_id, model=model)
|
||||||
|
|
||||||
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||||
|
|
||||||
|
|
||||||
@ -78,6 +79,9 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
if not hasattr(self, k) or getattr(self, k) is None:
|
if not hasattr(self, k) or getattr(self, k) is None:
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.unique_id)
|
||||||
|
|
||||||
# TODO: refactor to clean up mesa compatibility
|
# TODO: refactor to clean up mesa compatibility
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self):
|
||||||
@ -185,16 +189,14 @@ class BaseAgent(MesaAgent, MutableMapping):
|
|||||||
# Agent = BaseAgent
|
# Agent = BaseAgent
|
||||||
|
|
||||||
class NetworkAgent(BaseAgent):
|
class NetworkAgent(BaseAgent):
|
||||||
def __init__(self,
|
|
||||||
*args,
|
@property
|
||||||
graph_name: str,
|
def topology(self):
|
||||||
node_id: int = None,
|
return self.env.topology_for(self.unique_id)
|
||||||
**kwargs,
|
|
||||||
):
|
@property
|
||||||
super().__init__(*args, **kwargs)
|
def node_id(self):
|
||||||
self.graph_name = graph_name
|
return self.env.node_id_for(self.unique_id)
|
||||||
self.topology = self.env.topologies[self.graph_name]
|
|
||||||
self.node_id = node_id
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def G(self):
|
def G(self):
|
||||||
@ -215,15 +217,19 @@ class NetworkAgent(BaseAgent):
|
|||||||
it = islice(it, limit)
|
it = islice(it, limit)
|
||||||
return list(it)
|
return list(it)
|
||||||
|
|
||||||
def iter_agents(self, agents=None, limit_neighbors=False, **kwargs):
|
def iter_agents(self, unique_id=None, limit_neighbors=False, **kwargs):
|
||||||
if limit_neighbors:
|
if limit_neighbors:
|
||||||
agents = self.topology.neighbors(self.unique_id)
|
unique_id = [self.topology.nodes[node]['agent_id'] for node in self.topology.neighbors(self.node_id)]
|
||||||
|
if not unique_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
yield from self.model.agents(unique_id=unique_id, **kwargs)
|
||||||
|
|
||||||
return self.model.agents(ids=agents, **kwargs)
|
|
||||||
|
|
||||||
def subgraph(self, center=True, **kwargs):
|
def subgraph(self, center=True, **kwargs):
|
||||||
include = [self] if center else []
|
include = [self] if center else []
|
||||||
return self.topology.subgraph(n.unique_id for n in list(self.get_agents(**kwargs))+include)
|
G = self.topology.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
||||||
|
return G
|
||||||
|
|
||||||
def remove_node(self, unique_id):
|
def remove_node(self, unique_id):
|
||||||
self.topology.remove_node(unique_id)
|
self.topology.remove_node(unique_id)
|
||||||
@ -366,7 +372,7 @@ def prob(prob=1):
|
|||||||
|
|
||||||
|
|
||||||
def calculate_distribution(network_agents=None,
|
def calculate_distribution(network_agents=None,
|
||||||
agent_type=None):
|
agent_class=None):
|
||||||
'''
|
'''
|
||||||
Calculate the threshold values (thresholds for a uniform distribution)
|
Calculate the threshold values (thresholds for a uniform distribution)
|
||||||
of an agent distribution given the weights of each agent type.
|
of an agent distribution given the weights of each agent type.
|
||||||
@ -374,13 +380,13 @@ def calculate_distribution(network_agents=None,
|
|||||||
The input has this form: ::
|
The input has this form: ::
|
||||||
|
|
||||||
[
|
[
|
||||||
{'agent_type': 'agent_type_1',
|
{'agent_class': 'agent_class_1',
|
||||||
'weight': 0.2,
|
'weight': 0.2,
|
||||||
'state': {
|
'state': {
|
||||||
'id': 0
|
'id': 0
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{'agent_type': 'agent_type_2',
|
{'agent_class': 'agent_class_2',
|
||||||
'weight': 0.8,
|
'weight': 0.8,
|
||||||
'state': {
|
'state': {
|
||||||
'id': 1
|
'id': 1
|
||||||
@ -389,12 +395,12 @@ def calculate_distribution(network_agents=None,
|
|||||||
]
|
]
|
||||||
|
|
||||||
In this example, 20% of the nodes will be marked as type
|
In this example, 20% of the nodes will be marked as type
|
||||||
'agent_type_1'.
|
'agent_class_1'.
|
||||||
'''
|
'''
|
||||||
if network_agents:
|
if network_agents:
|
||||||
network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')]
|
network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')]
|
||||||
elif agent_type:
|
elif agent_class:
|
||||||
network_agents = [{'agent_type': agent_type}]
|
network_agents = [{'agent_class': agent_class}]
|
||||||
else:
|
else:
|
||||||
raise ValueError('Specify a distribution or a default agent type')
|
raise ValueError('Specify a distribution or a default agent type')
|
||||||
|
|
||||||
@ -414,11 +420,11 @@ def calculate_distribution(network_agents=None,
|
|||||||
return network_agents
|
return network_agents
|
||||||
|
|
||||||
|
|
||||||
def serialize_type(agent_type, known_modules=[], **kwargs):
|
def serialize_type(agent_class, known_modules=[], **kwargs):
|
||||||
if isinstance(agent_type, str):
|
if isinstance(agent_class, str):
|
||||||
return agent_type
|
return agent_class
|
||||||
known_modules += ['soil.agents']
|
known_modules += ['soil.agents']
|
||||||
return serialization.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
||||||
|
|
||||||
|
|
||||||
def serialize_definition(network_agents, known_modules=[]):
|
def serialize_definition(network_agents, known_modules=[]):
|
||||||
@ -430,23 +436,23 @@ def serialize_definition(network_agents, known_modules=[]):
|
|||||||
for v in d:
|
for v in d:
|
||||||
if 'threshold' in v:
|
if 'threshold' in v:
|
||||||
del v['threshold']
|
del v['threshold']
|
||||||
v['agent_type'] = serialize_type(v['agent_type'],
|
v['agent_class'] = serialize_type(v['agent_class'],
|
||||||
known_modules=known_modules)
|
known_modules=known_modules)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def deserialize_type(agent_type, known_modules=[]):
|
def deserialize_type(agent_class, known_modules=[]):
|
||||||
if not isinstance(agent_type, str):
|
if not isinstance(agent_class, str):
|
||||||
return agent_type
|
return agent_class
|
||||||
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
||||||
agent_type = serialization.deserializer(agent_type, known_modules=known)
|
agent_class = serialization.deserializer(agent_class, known_modules=known)
|
||||||
return agent_type
|
return agent_class
|
||||||
|
|
||||||
|
|
||||||
def deserialize_definition(ind, **kwargs):
|
def deserialize_definition(ind, **kwargs):
|
||||||
d = deepcopy(ind)
|
d = deepcopy(ind)
|
||||||
for v in d:
|
for v in d:
|
||||||
v['agent_type'] = deserialize_type(v['agent_type'], **kwargs)
|
v['agent_class'] = deserialize_type(v['agent_class'], **kwargs)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
@ -461,7 +467,7 @@ def _validate_states(states, topology):
|
|||||||
return states
|
return states
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_types(ind, to_string=False, **kwargs):
|
def _convert_agent_classs(ind, to_string=False, **kwargs):
|
||||||
'''Convenience method to allow specifying agents by class or class name.'''
|
'''Convenience method to allow specifying agents by class or class name.'''
|
||||||
if to_string:
|
if to_string:
|
||||||
return serialize_definition(ind, **kwargs)
|
return serialize_definition(ind, **kwargs)
|
||||||
@ -480,7 +486,7 @@ def _agent_from_definition(definition, value=-1, unique_id=None):
|
|||||||
state = {}
|
state = {}
|
||||||
if 'state' in d:
|
if 'state' in d:
|
||||||
state = deepcopy(d['state'])
|
state = deepcopy(d['state'])
|
||||||
return d['agent_type'], state
|
return d['agent_class'], state
|
||||||
|
|
||||||
raise Exception('Definition for value {} not found in: {}'.format(value, definition))
|
raise Exception('Definition for value {} not found in: {}'.format(value, definition))
|
||||||
|
|
||||||
@ -576,8 +582,11 @@ class AgentView(Mapping, Set):
|
|||||||
return group[agent_id]
|
return group[agent_id]
|
||||||
raise ValueError(f"Agent {agent_id} not found")
|
raise ValueError(f"Agent {agent_id} not found")
|
||||||
|
|
||||||
def filter(self, *group_ids, **kwargs):
|
def filter(self, *args, **kwargs):
|
||||||
yield from filter_groups(self._agents, group_ids=group_ids, **kwargs)
|
yield from filter_groups(self._agents, *args, **kwargs)
|
||||||
|
|
||||||
|
def one(self, *args, **kwargs):
|
||||||
|
return next(filter_groups(self._agents, *args, **kwargs))
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return list(self.filter(*args, **kwargs))
|
return list(self.filter(*args, **kwargs))
|
||||||
@ -586,41 +595,57 @@ class AgentView(Mapping, Set):
|
|||||||
return any(agent_id in g for g in self._agents)
|
return any(agent_id in g for g in self._agents)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(list(a.id for a in self))
|
return str(list(a.unique_id for a in self))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.__class__.__name__}({self})"
|
return f"{self.__class__.__name__}({self})"
|
||||||
|
|
||||||
|
|
||||||
def filter_groups(groups, group_ids=None, **kwargs):
|
def filter_groups(groups, *, group=None, **kwargs):
|
||||||
assert isinstance(groups, dict)
|
assert isinstance(groups, dict)
|
||||||
if group_ids:
|
|
||||||
groups = list(groups[g] for g in group_ids if g in groups)
|
if group is not None and not isinstance(group, list):
|
||||||
|
group = [group]
|
||||||
|
|
||||||
|
if group:
|
||||||
|
groups = list(groups[g] for g in group if g in groups)
|
||||||
else:
|
else:
|
||||||
groups = list(groups.values())
|
groups = list(groups.values())
|
||||||
|
|
||||||
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
|
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
|
||||||
|
|
||||||
yield from agents
|
yield from agents
|
||||||
|
|
||||||
|
|
||||||
def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, state=None, **kwargs):
|
def filter_group(group, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None, **kwargs):
|
||||||
'''
|
'''
|
||||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||||
'''
|
'''
|
||||||
assert isinstance(group, dict)
|
assert isinstance(group, dict)
|
||||||
|
|
||||||
|
ids = []
|
||||||
|
|
||||||
|
if unique_id is not None:
|
||||||
|
if isinstance(unique_id, list):
|
||||||
|
ids += unique_id
|
||||||
|
else:
|
||||||
|
ids.append(unique_id)
|
||||||
|
|
||||||
|
if id_args:
|
||||||
|
ids += id_args
|
||||||
|
|
||||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||||
state_id = tuple([state_id])
|
state_id = tuple([state_id])
|
||||||
|
|
||||||
if agent_type is not None:
|
if agent_class is not None:
|
||||||
|
agent_class = deserialize_type(agent_class)
|
||||||
try:
|
try:
|
||||||
agent_type = tuple(agent_type)
|
agent_class = tuple(agent_class)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
agent_type = tuple([agent_type])
|
agent_class = tuple([agent_class])
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
agents = (v[aid] for aid in ids if aid in group)
|
agents = (group[aid] for aid in ids if aid in group)
|
||||||
else:
|
else:
|
||||||
agents = (a for a in group.values())
|
agents = (a for a in group.values())
|
||||||
|
|
||||||
@ -631,8 +656,8 @@ def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, s
|
|||||||
if state_id is not None:
|
if state_id is not None:
|
||||||
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
||||||
|
|
||||||
if agent_type is not None:
|
if agent_class is not None:
|
||||||
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
||||||
|
|
||||||
state = state or dict()
|
state = state or dict()
|
||||||
state.update(kwargs)
|
state.update(kwargs)
|
||||||
@ -660,7 +685,7 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
|
|||||||
if cfg.fixed is not None:
|
if cfg.fixed is not None:
|
||||||
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
||||||
if cfg.distribution:
|
if cfg.distribution:
|
||||||
n = cfg.n or len(env.topologies[cfg.topology])
|
n = cfg.n or len(env.topologies[cfg.topology or default.topology])
|
||||||
target = n - len(agents)
|
target = n - len(agents)
|
||||||
agents.update(_from_distro(cfg.distribution, target,
|
agents.update(_from_distro(cfg.distribution, target,
|
||||||
topology=cfg.topology or default.topology,
|
topology=cfg.topology or default.topology,
|
||||||
@ -674,6 +699,8 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
|
|||||||
else:
|
else:
|
||||||
filtered = list(agents)
|
filtered = list(agents)
|
||||||
|
|
||||||
|
if attrs.n > len(filtered):
|
||||||
|
raise ValueError(f'Not enough agents to sample. Got {len(filtered)}, expected >= {attrs.n}')
|
||||||
for agent in random.sample(filtered, attrs.n):
|
for agent in random.sample(filtered, attrs.n):
|
||||||
agent.state.update(attrs.state)
|
agent.state.update(attrs.state)
|
||||||
|
|
||||||
@ -684,18 +711,20 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: conf
|
|||||||
agents = {}
|
agents = {}
|
||||||
|
|
||||||
for fixed in lst:
|
for fixed in lst:
|
||||||
agent_id = fixed.agent_id
|
agent_id = fixed.agent_id
|
||||||
if agent_id is None:
|
if agent_id is None:
|
||||||
agent_id = env.next_id()
|
agent_id = env.next_id()
|
||||||
|
|
||||||
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
||||||
state = fixed.state.copy()
|
state = fixed.state.copy()
|
||||||
state.update(default.state)
|
state.update(default.state)
|
||||||
agent = cls(unique_id=agent_id,
|
agent = cls(unique_id=agent_id,
|
||||||
model=env,
|
model=env,
|
||||||
graph_name=fixed.topology or topology or default.topology,
|
**state)
|
||||||
**state)
|
topology = fixed.topology if (fixed.topology is not None) else (topology or default.topology)
|
||||||
agents[agent.unique_id] = agent
|
if topology:
|
||||||
|
env.agent_to_node(agent_id, topology, fixed.node_id)
|
||||||
|
agents[agent.unique_id] = agent
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
@ -741,8 +770,12 @@ def _from_distro(distro: List[config.AgentDistro],
|
|||||||
cls = classes[idx]
|
cls = classes[idx]
|
||||||
agent_id = env.next_id()
|
agent_id = env.next_id()
|
||||||
state = d.state.copy()
|
state = d.state.copy()
|
||||||
state.update(default.state)
|
if default:
|
||||||
agent = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
|
state.update(default.state)
|
||||||
|
agent = cls(unique_id=agent_id, model=env, **state)
|
||||||
|
topology = d.topology if (d.topology is not None) else topology or default.topology
|
||||||
|
if topology:
|
||||||
|
env.agent_to_node(agent.unique_id, topology)
|
||||||
assert agent.name is not None
|
assert agent.name is not None
|
||||||
assert agent.name != 'None'
|
assert agent.name != 'None'
|
||||||
assert agent.name
|
assert agent.name
|
||||||
|
@ -7,6 +7,7 @@ import sys
|
|||||||
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union, Type
|
from typing import Any, Callable, Dict, List, Optional, Union, Type
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
class General(BaseModel):
|
class General(BaseModel):
|
||||||
id: str = 'Unnamed Simulation'
|
id: str = 'Unnamed Simulation'
|
||||||
@ -50,9 +51,12 @@ class NetParams(BaseModel, extra=Extra.allow):
|
|||||||
class NetConfig(BaseModel):
|
class NetConfig(BaseModel):
|
||||||
group: str = 'network'
|
group: str = 'network'
|
||||||
params: Optional[NetParams]
|
params: Optional[NetParams]
|
||||||
topology: Optional[Topology]
|
topology: Optional[Union[Topology, nx.Graph]]
|
||||||
path: Optional[str]
|
path: Optional[str]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default():
|
def default():
|
||||||
return NetConfig(topology=None, params=None)
|
return NetConfig(topology=None, params=None)
|
||||||
@ -77,7 +81,8 @@ class EnvConfig(BaseModel):
|
|||||||
class SingleAgentConfig(BaseModel):
|
class SingleAgentConfig(BaseModel):
|
||||||
agent_class: Optional[Union[Type, str]] = None
|
agent_class: Optional[Union[Type, str]] = None
|
||||||
agent_id: Optional[int] = None
|
agent_id: Optional[int] = None
|
||||||
topology: Optional[str] = 'default'
|
topology: Optional[str] = None
|
||||||
|
node_id: Optional[Union[int, str]] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
state: Optional[Dict[str, Any]] = {}
|
state: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
@ -186,9 +191,7 @@ def convert_old(old, strict=True):
|
|||||||
if 'agent_id' in agent:
|
if 'agent_id' in agent:
|
||||||
agent['name'] = agent['agent_id']
|
agent['name'] = agent['agent_id']
|
||||||
del agent['agent_id']
|
del agent['agent_id']
|
||||||
agents['environment']['fixed'].append(updated_agent(agent))
|
agents['environment']['fixed'].append(updated_agent(agent))
|
||||||
else:
|
|
||||||
agents['environment']['distribution'].append(updated_agent(agent))
|
|
||||||
|
|
||||||
by_weight = []
|
by_weight = []
|
||||||
fixed = []
|
fixed = []
|
||||||
@ -206,10 +209,10 @@ def convert_old(old, strict=True):
|
|||||||
|
|
||||||
if 'agent_type' in old and (not fixed and not by_weight):
|
if 'agent_type' in old and (not fixed and not by_weight):
|
||||||
agents['network']['topology'] = 'default'
|
agents['network']['topology'] = 'default'
|
||||||
by_weight = [{'agent_type': old['agent_type']}]
|
by_weight = [{'agent_class': old['agent_type']}]
|
||||||
|
|
||||||
|
|
||||||
# TODO: translate states
|
# TODO: translate states properly
|
||||||
if 'states' in old:
|
if 'states' in old:
|
||||||
states = old['states']
|
states = old['states']
|
||||||
if isinstance(states, dict):
|
if isinstance(states, dict):
|
||||||
@ -217,7 +220,7 @@ def convert_old(old, strict=True):
|
|||||||
else:
|
else:
|
||||||
states = enumerate(states)
|
states = enumerate(states)
|
||||||
for (k, v) in states:
|
for (k, v) in states:
|
||||||
override.append({'filter': {'id': k},
|
override.append({'filter': {'node_id': k},
|
||||||
'state': v
|
'state': v
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -1,264 +0,0 @@
|
|||||||
from pydantic import BaseModel, ValidationError, validator
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import networkx as nx
|
|
||||||
import collections.abc
|
|
||||||
|
|
||||||
from . import serialization, utils, basestring, agents
|
|
||||||
|
|
||||||
class Config(collections.abc.Mapping):
|
|
||||||
"""
|
|
||||||
|
|
||||||
1) agent type can be specified by name or by class.
|
|
||||||
2) instead of just one type, a network agents distribution can be used.
|
|
||||||
The distribution specifies the weight (or probability) of each
|
|
||||||
agent type in the topology. This is an example distribution: ::
|
|
||||||
|
|
||||||
[
|
|
||||||
{'agent_type': 'agent_type_1',
|
|
||||||
'weight': 0.2,
|
|
||||||
'state': {
|
|
||||||
'id': 0
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{'agent_type': 'agent_type_2',
|
|
||||||
'weight': 0.8,
|
|
||||||
'state': {
|
|
||||||
'id': 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
In this example, 20% of the nodes will be marked as type
|
|
||||||
'agent_type_1'.
|
|
||||||
3) if no initial state is given, each node's state will be set
|
|
||||||
to `{'id': 0}`.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
---------
|
|
||||||
name : str, optional
|
|
||||||
name of the Simulation
|
|
||||||
group : str, optional
|
|
||||||
a group name can be used to link simulations
|
|
||||||
topology (optional): networkx.Graph instance or Node-Link topology as a dict or string (will be loaded with `json_graph.node_link_graph(topology`).
|
|
||||||
network_params : dict
|
|
||||||
parameters used to create a topology with networkx, if no topology is given
|
|
||||||
network_agents : dict
|
|
||||||
definition of agents to populate the topology with
|
|
||||||
agent_type : NetworkAgent subclass, optional
|
|
||||||
Default type of NetworkAgent to use for nodes not specified in network_agents
|
|
||||||
states : list, optional
|
|
||||||
List of initial states corresponding to the nodes in the topology. Basic form is a list of integers
|
|
||||||
whose value indicates the state
|
|
||||||
dir_path: str, optional
|
|
||||||
Directory path to load simulation assets (files, modules...)
|
|
||||||
seed : str, optional
|
|
||||||
Seed to use for the random generator
|
|
||||||
num_trials : int, optional
|
|
||||||
Number of independent simulation runs
|
|
||||||
max_time : int, optional
|
|
||||||
Maximum step/time for each simulation
|
|
||||||
environment_params : dict, optional
|
|
||||||
Dictionary of globally-shared environmental parameters
|
|
||||||
environment_agents: dict, optional
|
|
||||||
Similar to network_agents. Distribution of Agents that control the environment
|
|
||||||
environment_class: soil.environment.Environment subclass, optional
|
|
||||||
Class for the environment. It defailts to soil.environment.Environment
|
|
||||||
"""
|
|
||||||
__slots__ = 'name', 'agent_type', 'group', 'description', 'network_agents', 'environment_agents', 'states', 'default_state', 'interval', 'network_params', 'seed', 'num_trials', 'max_time', 'topology', 'schedule', 'initial_time', 'environment_params', 'environment_class', 'dir_path', '_added_to_path', 'visualization_params'
|
|
||||||
|
|
||||||
def __init__(self, name=None,
|
|
||||||
group=None,
|
|
||||||
agent_type='BaseAgent',
|
|
||||||
network_agents=None,
|
|
||||||
environment_agents=None,
|
|
||||||
states=None,
|
|
||||||
description=None,
|
|
||||||
default_state=None,
|
|
||||||
interval=1,
|
|
||||||
network_params=None,
|
|
||||||
seed=None,
|
|
||||||
num_trials=1,
|
|
||||||
max_time=None,
|
|
||||||
topology=None,
|
|
||||||
schedule=None,
|
|
||||||
initial_time=0,
|
|
||||||
environment_params={},
|
|
||||||
environment_class='soil.Environment',
|
|
||||||
dir_path=None,
|
|
||||||
visualization_params=None,
|
|
||||||
):
|
|
||||||
|
|
||||||
self.network_params = network_params
|
|
||||||
self.name = name or 'Unnamed'
|
|
||||||
self.description = description or 'No simulation description available'
|
|
||||||
self.seed = str(seed or name)
|
|
||||||
self.group = group or ''
|
|
||||||
self.num_trials = num_trials
|
|
||||||
self.max_time = max_time
|
|
||||||
self.default_state = default_state or {}
|
|
||||||
self.dir_path = dir_path or os.getcwd()
|
|
||||||
self.interval = interval
|
|
||||||
self.visualization_params = visualization_params or {}
|
|
||||||
|
|
||||||
self._added_to_path = list(x for x in [os.getcwd(), self.dir_path] if x not in sys.path)
|
|
||||||
sys.path += self._added_to_path
|
|
||||||
|
|
||||||
self.topology = topology
|
|
||||||
|
|
||||||
self.schedule = schedule
|
|
||||||
self.initial_time = initial_time
|
|
||||||
|
|
||||||
|
|
||||||
self.environment_class = environment_class
|
|
||||||
self.environment_params = dict(environment_params)
|
|
||||||
|
|
||||||
#TODO: Check agent distro vs fixed agents
|
|
||||||
self.environment_agents = environment_agents or []
|
|
||||||
|
|
||||||
self.agent_type = agent_type
|
|
||||||
|
|
||||||
self.network_agents = network_agents or {}
|
|
||||||
|
|
||||||
self.states = states or {}
|
|
||||||
|
|
||||||
|
|
||||||
def validate(self):
|
|
||||||
agents._validate_states(self.states,
|
|
||||||
self._topology)
|
|
||||||
|
|
||||||
def calculate(self):
|
|
||||||
return CalculatedConfig(self)
|
|
||||||
|
|
||||||
def restore_path(self):
|
|
||||||
for added in self._added_to_path:
|
|
||||||
sys.path.remove(added)
|
|
||||||
|
|
||||||
def to_yaml(self):
|
|
||||||
return yaml.dump(self.to_dict())
|
|
||||||
|
|
||||||
def dump_yaml(self, f=None, outdir=None):
|
|
||||||
if not f and not outdir:
|
|
||||||
raise ValueError('specify a file or an output directory')
|
|
||||||
|
|
||||||
if not f:
|
|
||||||
f = os.path.join(outdir, '{}.dumped.yml'.format(self.name))
|
|
||||||
|
|
||||||
with utils.open_or_reuse(f, 'w') as f:
|
|
||||||
f.write(self.to_yaml())
|
|
||||||
|
|
||||||
def to_yaml(self):
|
|
||||||
return yaml.dump(self.to_dict())
|
|
||||||
|
|
||||||
# TODO: See note on getstate
|
|
||||||
def to_dict(self):
|
|
||||||
return dict(self)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return self.to_yaml()
|
|
||||||
|
|
||||||
def dump_yaml(self, f=None, outdir=None):
|
|
||||||
if not f and not outdir:
|
|
||||||
raise ValueError('specify a file or an output directory')
|
|
||||||
|
|
||||||
if not f:
|
|
||||||
f = os.path.join(outdir, '{}.dumped.yml'.format(self.name))
|
|
||||||
|
|
||||||
with utils.open_or_reuse(f, 'w') as f:
|
|
||||||
f.write(self.to_yaml())
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return (k for k in self.__slots__ if k[0] != '_')
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.__slots__)
|
|
||||||
|
|
||||||
def dump_pickle(self, f=None, outdir=None):
|
|
||||||
if not outdir and not f:
|
|
||||||
raise ValueError('specify a file or an output directory')
|
|
||||||
|
|
||||||
if not f:
|
|
||||||
f = os.path.join(outdir,
|
|
||||||
'{}.simulation.pickle'.format(self.name))
|
|
||||||
with utils.open_or_reuse(f, 'wb') as f:
|
|
||||||
pickle.dump(self, f)
|
|
||||||
|
|
||||||
# TODO: remove this. A config should be sendable regardless. Non-pickable objects could be computed via properties and the like
|
|
||||||
# def __getstate__(self):
|
|
||||||
# state={}
|
|
||||||
# for k, v in self.__dict__.items():
|
|
||||||
# if k[0] != '_':
|
|
||||||
# state[k] = v
|
|
||||||
# state['topology'] = json_graph.node_link_data(self.topology)
|
|
||||||
# state['network_agents'] = agents.serialize_definition(self.network_agents,
|
|
||||||
# known_modules = [])
|
|
||||||
# state['environment_agents'] = agents.serialize_definition(self.environment_agents,
|
|
||||||
# known_modules = [])
|
|
||||||
# state['environment_class'] = serialization.serialize(self.environment_class,
|
|
||||||
# known_modules=['soil.environment'])[1] # func, name
|
|
||||||
# if state['load_module'] is None:
|
|
||||||
# del state['load_module']
|
|
||||||
# return state
|
|
||||||
|
|
||||||
# # TODO: remove, same as __getstate__
|
|
||||||
# def __setstate__(self, state):
|
|
||||||
# self.__dict__ = state
|
|
||||||
# self.load_module = getattr(self, 'load_module', None)
|
|
||||||
# if self.dir_path not in sys.path:
|
|
||||||
# sys.path += [self.dir_path, os.getcwd()]
|
|
||||||
# self.topology = json_graph.node_link_graph(state['topology'])
|
|
||||||
# self.network_agents = agents.calculate_distribution(agents._convert_agent_types(self.network_agents))
|
|
||||||
# self.environment_agents = agents._convert_agent_types(self.environment_agents,
|
|
||||||
# known_modules=[self.load_module])
|
|
||||||
# self.environment_class = serialization.deserialize(self.environment_class,
|
|
||||||
# known_modules=[self.load_module,
|
|
||||||
# 'soil.environment', ]) # func, name
|
|
||||||
|
|
||||||
class CalculatedConfig(Config):
|
|
||||||
def __init__(self, config):
|
|
||||||
"""
|
|
||||||
Returns a configuration object that replaces some "plain" attributes (e.g., `environment_class` string) into
|
|
||||||
a Python object (`soil.environment.Environment` class).
|
|
||||||
"""
|
|
||||||
self._config = config
|
|
||||||
values = dict(config)
|
|
||||||
values['environment_class'] = self._environment_class()
|
|
||||||
values['environment_agents'] = self._environment_agents()
|
|
||||||
values['topology'] = self._topology()
|
|
||||||
values['network_agents'] = self._network_agents()
|
|
||||||
values['agent_type'] = serialization.deserialize(self.agent_type, known_modules=['soil.agents'])
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
def _topology(self):
|
|
||||||
topology = self._config.topology
|
|
||||||
if topology is None:
|
|
||||||
topology = serialization.load_network(self._config.network_params,
|
|
||||||
dir_path=self._config.dir_path)
|
|
||||||
|
|
||||||
elif isinstance(topology, basestring) or isinstance(topology, dict):
|
|
||||||
topology = json_graph.node_link_graph(topology)
|
|
||||||
|
|
||||||
return nx.Graph(topology)
|
|
||||||
|
|
||||||
def _environment_class(self):
|
|
||||||
return serialization.deserialize(self._config.environment_class,
|
|
||||||
known_modules=['soil.environment', ]) or Environment
|
|
||||||
|
|
||||||
def _environment_agents(self):
|
|
||||||
return agents._convert_agent_types(self._config.environment_agents)
|
|
||||||
|
|
||||||
def _network_agents(self):
|
|
||||||
distro = agents.calculate_distribution(self._config.network_agents,
|
|
||||||
self._config.agent_type)
|
|
||||||
return agents._convert_agent_types(distro)
|
|
||||||
|
|
||||||
def _environment_class(self):
|
|
||||||
return serialization.deserialize(self._config.environment_class,
|
|
||||||
known_modules=['soil.environment', ]) # func, name
|
|
||||||
|
|
@ -15,6 +15,7 @@ from networkx.readwrite import json_graph
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
from mesa import Model
|
from mesa import Model
|
||||||
|
from mesa.datacollection import DataCollector
|
||||||
|
|
||||||
from . import serialization, agents, analysis, utils, time, config, network
|
from . import serialization, agents, analysis, utils, time, config, network
|
||||||
|
|
||||||
@ -41,6 +42,9 @@ class Environment(Model):
|
|||||||
interval=1,
|
interval=1,
|
||||||
agents: Dict[str, config.AgentConfig] = {},
|
agents: Dict[str, config.AgentConfig] = {},
|
||||||
topologies: Dict[str, config.NetConfig] = {},
|
topologies: Dict[str, config.NetConfig] = {},
|
||||||
|
agent_reporters: Optional[Any] = None,
|
||||||
|
model_reporters: Optional[Any] = None,
|
||||||
|
tables: Optional[Any] = None,
|
||||||
**env_params):
|
**env_params):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -61,6 +65,7 @@ class Environment(Model):
|
|||||||
|
|
||||||
|
|
||||||
self.topologies = {}
|
self.topologies = {}
|
||||||
|
self._node_ids = {}
|
||||||
for (name, cfg) in topologies.items():
|
for (name, cfg) in topologies.items():
|
||||||
self.set_topology(cfg=cfg,
|
self.set_topology(cfg=cfg,
|
||||||
graph=name)
|
graph=name)
|
||||||
@ -72,6 +77,7 @@ class Environment(Model):
|
|||||||
self['SEED'] = seed
|
self['SEED'] = seed
|
||||||
|
|
||||||
self.logger = utils.logger.getChild(self.id)
|
self.logger = utils.logger.getChild(self.id)
|
||||||
|
self.datacollector = DataCollector(model_reporters, agent_reporters, tables)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def topology(self):
|
def topology(self):
|
||||||
@ -79,8 +85,7 @@ class Environment(Model):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def network_agents(self):
|
def network_agents(self):
|
||||||
yield from self.agents(agent_type=agents.NetworkAgent, iterator=False)
|
yield from self.agents(agent_class=agents.NetworkAgent)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(conf: config.Config, trial_id, **kwargs) -> Environment:
|
def from_config(conf: config.Config, trial_id, **kwargs) -> Environment:
|
||||||
@ -91,9 +96,10 @@ class Environment(Model):
|
|||||||
seed = '{}_{}'.format(conf.general.seed, trial_id)
|
seed = '{}_{}'.format(conf.general.seed, trial_id)
|
||||||
id = '{}_trial_{}'.format(conf.general.id, trial_id).replace('.', '-')
|
id = '{}_trial_{}'.format(conf.general.id, trial_id).replace('.', '-')
|
||||||
opts = conf.environment.params.copy()
|
opts = conf.environment.params.copy()
|
||||||
|
dir_path = conf.general.dir_path
|
||||||
opts.update(conf)
|
opts.update(conf)
|
||||||
opts.update(kwargs)
|
opts.update(kwargs)
|
||||||
env = serialization.deserialize(conf.environment.environment_class)(env_id=id, seed=seed, **opts)
|
env = serialization.deserialize(conf.environment.environment_class)(env_id=id, seed=seed, dir_path=dir_path, **opts)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -103,12 +109,31 @@ class Environment(Model):
|
|||||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
||||||
|
|
||||||
|
|
||||||
|
def topology_for(self, agent_id):
|
||||||
|
return self.topologies[self._node_ids[agent_id][0]]
|
||||||
|
|
||||||
|
def node_id_for(self, agent_id):
|
||||||
|
return self._node_ids[agent_id][1]
|
||||||
|
|
||||||
def set_topology(self, cfg=None, dir_path=None, graph='default'):
|
def set_topology(self, cfg=None, dir_path=None, graph='default'):
|
||||||
self.topologies[graph] = network.from_config(cfg, dir_path=dir_path)
|
topology = cfg
|
||||||
|
if not isinstance(cfg, nx.Graph):
|
||||||
|
topology = network.from_config(cfg, dir_path=dir_path or self.dir_path)
|
||||||
|
|
||||||
|
self.topologies[graph] = topology
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agents(self):
|
def agents(self):
|
||||||
return agents.AgentView(self._agents)
|
return agents.AgentView(self._agents)
|
||||||
|
|
||||||
|
def count_agents(self, *args, **kwargs):
|
||||||
|
return sum(1 for i in self.find_all(*args, **kwargs))
|
||||||
|
|
||||||
|
def find_all(self, *args, **kwargs):
|
||||||
|
return agents.AgentView(self._agents).filter(*args, **kwargs)
|
||||||
|
|
||||||
|
def find_one(self, *args, **kwargs):
|
||||||
|
return agents.AgentView(self._agents).one(*args, **kwargs)
|
||||||
|
|
||||||
@agents.setter
|
@agents.setter
|
||||||
def agents(self, agents_def: Dict[str, config.AgentConfig]):
|
def agents(self, agents_def: Dict[str, config.AgentConfig]):
|
||||||
@ -117,37 +142,47 @@ class Environment(Model):
|
|||||||
for a in d.values():
|
for a in d.values():
|
||||||
self.schedule.add(a)
|
self.schedule.add(a)
|
||||||
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def network_agents(self):
|
|
||||||
# for i in self.G.nodes():
|
|
||||||
# node = self.G.nodes[i]
|
|
||||||
# if 'agent' in node:
|
|
||||||
# yield node['agent']
|
|
||||||
|
|
||||||
def init_agent(self, agent_id, agent_definitions, graph='default'):
|
def init_agent(self, agent_id, agent_definitions, graph='default'):
|
||||||
node = self.topologies[graph].nodes[agent_id]
|
node = self.topologies[graph].nodes[agent_id]
|
||||||
init = False
|
init = False
|
||||||
state = dict(node)
|
state = dict(node)
|
||||||
|
|
||||||
agent_type = None
|
agent_class = None
|
||||||
if 'agent_type' in self.states.get(agent_id, {}):
|
if 'agent_class' in self.states.get(agent_id, {}):
|
||||||
agent_type = self.states[agent_id]['agent_type']
|
agent_class = self.states[agent_id]['agent_class']
|
||||||
elif 'agent_type' in node:
|
elif 'agent_class' in node:
|
||||||
agent_type = node['agent_type']
|
agent_class = node['agent_class']
|
||||||
elif 'agent_type' in self.default_state:
|
elif 'agent_class' in self.default_state:
|
||||||
agent_type = self.default_state['agent_type']
|
agent_class = self.default_state['agent_class']
|
||||||
|
|
||||||
if agent_type:
|
if agent_class:
|
||||||
agent_type = agents.deserialize_type(agent_type)
|
agent_class = agents.deserialize_type(agent_class)
|
||||||
elif agent_definitions:
|
elif agent_definitions:
|
||||||
agent_type, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id)
|
agent_class, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id)
|
||||||
else:
|
else:
|
||||||
serialization.logger.debug('Skipping node {}'.format(agent_id))
|
serialization.logger.debug('Skipping node {}'.format(agent_id))
|
||||||
return
|
return
|
||||||
return self.set_agent(agent_id, agent_type, state)
|
return self.set_agent(agent_id, agent_class, state)
|
||||||
|
|
||||||
def set_agent(self, agent_id, agent_type, state=None, graph='default'):
|
def agent_to_node(self, agent_id, graph_name='default', node_id=None, shuffle=False):
|
||||||
|
#TODO: test
|
||||||
|
if node_id is None:
|
||||||
|
G = self.topologies[graph_name]
|
||||||
|
candidates = list(G.nodes(data=True))
|
||||||
|
if shuffle:
|
||||||
|
random.shuffle(candidates)
|
||||||
|
for next_id, data in candidates:
|
||||||
|
if data.get('agent_id', None) is None:
|
||||||
|
node_id = next_id
|
||||||
|
data['agent_id'] = agent_id
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
self._node_ids[agent_id] = (graph_name, node_id)
|
||||||
|
print(self._node_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def set_agent(self, agent_id, agent_class, state=None, graph='default'):
|
||||||
node = self.topologies[graph].nodes[agent_id]
|
node = self.topologies[graph].nodes[agent_id]
|
||||||
defstate = deepcopy(self.default_state) or {}
|
defstate = deepcopy(self.default_state) or {}
|
||||||
defstate.update(self.states.get(agent_id, {}))
|
defstate.update(self.states.get(agent_id, {}))
|
||||||
@ -155,9 +190,9 @@ class Environment(Model):
|
|||||||
if state:
|
if state:
|
||||||
defstate.update(state)
|
defstate.update(state)
|
||||||
a = None
|
a = None
|
||||||
if agent_type:
|
if agent_class:
|
||||||
state = defstate
|
state = defstate
|
||||||
a = agent_type(model=self,
|
a = agent_class(model=self,
|
||||||
unique_id=agent_id
|
unique_id=agent_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -168,10 +203,10 @@ class Environment(Model):
|
|||||||
self.schedule.add(a)
|
self.schedule.add(a)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def add_node(self, agent_type, state=None, graph='default'):
|
def add_node(self, agent_class, state=None, graph='default'):
|
||||||
agent_id = int(len(self.topologies[graph].nodes()))
|
agent_id = int(len(self.topologies[graph].nodes()))
|
||||||
self.topologies[graph].add_node(agent_id)
|
self.topologies[graph].add_node(agent_id)
|
||||||
a = self.set_agent(agent_id, agent_type, state, graph=graph)
|
a = self.set_agent(agent_id, agent_class, state, graph=graph)
|
||||||
a['visible'] = True
|
a['visible'] = True
|
||||||
return a
|
return a
|
||||||
|
|
||||||
@ -201,6 +236,7 @@ class Environment(Model):
|
|||||||
'''
|
'''
|
||||||
super().step()
|
super().step()
|
||||||
self.schedule.step()
|
self.schedule.step()
|
||||||
|
self.datacollector.collect(self)
|
||||||
|
|
||||||
def run(self, until, *args, **kwargs):
|
def run(self, until, *args, **kwargs):
|
||||||
until = until or float('inf')
|
until = until or float('inf')
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import csv as csvlib
|
|
||||||
from time import time as current_time
|
from time import time as current_time
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
@ -59,7 +58,7 @@ class Exporter:
|
|||||||
'''Method to call when the simulation starts'''
|
'''Method to call when the simulation starts'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sim_end(self, stats):
|
def sim_end(self):
|
||||||
'''Method to call when the simulation ends'''
|
'''Method to call when the simulation ends'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -67,7 +66,7 @@ class Exporter:
|
|||||||
'''Method to call when a trial start'''
|
'''Method to call when a trial start'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def trial_end(self, env, stats):
|
def trial_end(self, env):
|
||||||
'''Method to call when a trial ends'''
|
'''Method to call when a trial ends'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -115,31 +114,35 @@ class default(Exporter):
|
|||||||
# self.simulation.dump_sqlite(f)
|
# self.simulation.dump_sqlite(f)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dc_dfs(dc):
|
||||||
|
dfs = {'env': dc.get_model_vars_dataframe(),
|
||||||
|
'agents': dc.get_agent_vars_dataframe }
|
||||||
|
for table_name in dc.tables:
|
||||||
|
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||||
|
yield from dfs.items()
|
||||||
|
|
||||||
|
|
||||||
class csv(Exporter):
|
class csv(Exporter):
|
||||||
|
|
||||||
'''Export the state of each environment (and its agents) in a separate CSV file'''
|
'''Export the state of each environment (and its agents) in a separate CSV file'''
|
||||||
def trial_end(self, env, stats):
|
def trial_end(self, env):
|
||||||
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
|
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
|
||||||
env.name,
|
env.id,
|
||||||
self.outdir)):
|
self.outdir)):
|
||||||
|
for (df_name, df) in get_dc_dfs(env.datacollector):
|
||||||
with self.output('{}.stats.{}.csv'.format(env.name, stats.name)) as f:
|
with self.output('{}.stats.{}.csv'.format(env.id, df_name)) as f:
|
||||||
statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL)
|
df.to_csv(f)
|
||||||
|
|
||||||
for stat in stats:
|
|
||||||
statwriter.writerow(stat)
|
|
||||||
|
|
||||||
|
|
||||||
class gexf(Exporter):
|
class gexf(Exporter):
|
||||||
def trial_end(self, env, stats):
|
def trial_end(self, env):
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
logger.info('Not dumping GEXF in dry_run mode')
|
logger.info('Not dumping GEXF in dry_run mode')
|
||||||
return
|
return
|
||||||
|
|
||||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||||
env.name)):
|
env.id)):
|
||||||
with self.output('{}.gexf'.format(env.name), mode='wb') as f:
|
with self.output('{}.gexf'.format(env.id), mode='wb') as f:
|
||||||
self.dump_gexf(env, f)
|
self.dump_gexf(env, f)
|
||||||
|
|
||||||
def dump_gexf(self, env, f):
|
def dump_gexf(self, env, f):
|
||||||
@ -159,25 +162,25 @@ class dummy(Exporter):
|
|||||||
with self.output('dummy', 'w') as f:
|
with self.output('dummy', 'w') as f:
|
||||||
f.write('simulation started @ {}\n'.format(current_time()))
|
f.write('simulation started @ {}\n'.format(current_time()))
|
||||||
|
|
||||||
def trial_end(self, env, stats):
|
def trial_start(self, env):
|
||||||
with self.output('dummy', 'w') as f:
|
with self.output('dummy', 'w') as f:
|
||||||
for i in stats:
|
f.write('trial started@ {}\n'.format(current_time()))
|
||||||
f.write(','.join(map(str, i)))
|
|
||||||
f.write('\n')
|
|
||||||
|
|
||||||
def sim_end(self, stats):
|
def trial_end(self, env):
|
||||||
|
with self.output('dummy', 'w') as f:
|
||||||
|
f.write('trial ended@ {}\n'.format(current_time()))
|
||||||
|
|
||||||
|
def sim_end(self):
|
||||||
with self.output('dummy', 'a') as f:
|
with self.output('dummy', 'a') as f:
|
||||||
f.write('simulation ended @ {}\n'.format(current_time()))
|
f.write('simulation ended @ {}\n'.format(current_time()))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class graphdrawing(Exporter):
|
class graphdrawing(Exporter):
|
||||||
|
|
||||||
def trial_end(self, env, stats):
|
def trial_end(self, env):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
f = plt.figure()
|
f = plt.figure()
|
||||||
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
|
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
|
||||||
with open('graph-{}.png'.format(env.name)) as f:
|
with open('graph-{}.png'.format(env.id)) as f:
|
||||||
f.savefig(f)
|
f.savefig(f)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
@ -16,7 +16,6 @@ from . import serialization, utils, basestring, agents
|
|||||||
from .environment import Environment
|
from .environment import Environment
|
||||||
from .utils import logger
|
from .utils import logger
|
||||||
from .exporters import default
|
from .exporters import default
|
||||||
from .stats import defaultStats
|
|
||||||
|
|
||||||
from .config import Config, convert_old
|
from .config import Config, convert_old
|
||||||
|
|
||||||
@ -71,8 +70,8 @@ class Simulation:
|
|||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
def run_gen(self, parallel=False, dry_run=False,
|
def run_gen(self, parallel=False, dry_run=False,
|
||||||
exporters=[default, ], stats=[], outdir=None, exporter_params={},
|
exporters=[default, ], outdir=None, exporter_params={},
|
||||||
stats_params={}, log_level=None,
|
log_level=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
'''Run the simulation and yield the resulting environments.'''
|
'''Run the simulation and yield the resulting environments.'''
|
||||||
if log_level:
|
if log_level:
|
||||||
@ -85,15 +84,8 @@ class Simulation:
|
|||||||
dry_run=dry_run,
|
dry_run=dry_run,
|
||||||
outdir=outdir,
|
outdir=outdir,
|
||||||
**exporter_params)
|
**exporter_params)
|
||||||
stats = serialization.deserialize_all(simulation=self,
|
|
||||||
names=stats,
|
|
||||||
known_modules=['soil.stats',],
|
|
||||||
**stats_params)
|
|
||||||
|
|
||||||
with utils.timer('simulation {}'.format(self.config.general.id)):
|
with utils.timer('simulation {}'.format(self.config.general.id)):
|
||||||
for stat in stats:
|
|
||||||
stat.sim_start()
|
|
||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.sim_start()
|
exporter.sim_start()
|
||||||
|
|
||||||
@ -104,32 +96,13 @@ class Simulation:
|
|||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.trial_start(env)
|
exporter.trial_start(env)
|
||||||
|
|
||||||
collected = list(stat.trial_end(env) for stat in stats)
|
|
||||||
|
|
||||||
saved = self._update_stats(collected, t_step=env.now, trial_id=env.id)
|
|
||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.trial_end(env, saved)
|
exporter.trial_end(env)
|
||||||
|
|
||||||
yield env
|
yield env
|
||||||
|
|
||||||
collected = list(stat.end() for stat in stats)
|
|
||||||
saved = self._update_stats(collected)
|
|
||||||
|
|
||||||
for stat in stats:
|
|
||||||
stat.sim_end()
|
|
||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.sim_end(saved)
|
exporter.sim_end()
|
||||||
|
|
||||||
def _update_stats(self, collection, **kwargs):
|
|
||||||
stats = dict(kwargs)
|
|
||||||
for stat in collection:
|
|
||||||
stats.update(stat)
|
|
||||||
return stats
|
|
||||||
|
|
||||||
def log_stats(self, stats):
|
|
||||||
logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False)))
|
|
||||||
|
|
||||||
def get_env(self, trial_id=0, **kwargs):
|
def get_env(self, trial_id=0, **kwargs):
|
||||||
'''Create an environment for a trial of the simulation'''
|
'''Create an environment for a trial of the simulation'''
|
||||||
|
111
soil/stats.py
111
soil/stats.py
@ -1,111 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
|
|
||||||
from collections import Counter
|
|
||||||
|
|
||||||
class Stats:
|
|
||||||
'''
|
|
||||||
Interface for all stats. It is not necessary, but it is useful
|
|
||||||
if you don't plan to implement all the methods.
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, simulation, name=None):
|
|
||||||
self.name = name or type(self).__name__
|
|
||||||
self.simulation = simulation
|
|
||||||
|
|
||||||
def sim_start(self):
|
|
||||||
'''Method to call when the simulation starts'''
|
|
||||||
pass
|
|
||||||
|
|
||||||
def sim_end(self):
|
|
||||||
'''Method to call when the simulation ends'''
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def trial_start(self, env):
|
|
||||||
'''Method to call when a trial starts'''
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def trial_end(self, env):
|
|
||||||
'''Method to call when a trial ends'''
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class distribution(Stats):
|
|
||||||
'''
|
|
||||||
Calculate the distribution of agent states at the end of each trial,
|
|
||||||
the mean value, and its deviation.
|
|
||||||
'''
|
|
||||||
|
|
||||||
def sim_start(self):
|
|
||||||
self.means = []
|
|
||||||
self.counts = []
|
|
||||||
|
|
||||||
def trial_end(self, env):
|
|
||||||
df = pd.DataFrame(env.state_to_tuples())
|
|
||||||
df = df.drop('SEED', axis=1)
|
|
||||||
ix = df.index[-1]
|
|
||||||
attrs = df.columns.get_level_values(0)
|
|
||||||
vc = {}
|
|
||||||
stats = {
|
|
||||||
'mean': {},
|
|
||||||
'count': {},
|
|
||||||
}
|
|
||||||
for a in attrs:
|
|
||||||
t = df.loc[(ix, a)]
|
|
||||||
try:
|
|
||||||
stats['mean'][a] = t.mean()
|
|
||||||
self.means.append(('mean', a, t.mean()))
|
|
||||||
except TypeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
for name, count in t.value_counts().iteritems():
|
|
||||||
if a not in stats['count']:
|
|
||||||
stats['count'][a] = {}
|
|
||||||
stats['count'][a][name] = count
|
|
||||||
self.counts.append(('count', a, name, count))
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
def sim_end(self):
|
|
||||||
dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value'])
|
|
||||||
dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count'])
|
|
||||||
|
|
||||||
count = {}
|
|
||||||
mean = {}
|
|
||||||
|
|
||||||
if self.means:
|
|
||||||
res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
|
||||||
mean = res['value'].to_dict()
|
|
||||||
if self.counts:
|
|
||||||
res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min'])
|
|
||||||
for k,v in res['count'].to_dict().items():
|
|
||||||
if k not in count:
|
|
||||||
count[k] = {}
|
|
||||||
for tup, times in v.items():
|
|
||||||
subkey, subcount = tup
|
|
||||||
if subkey not in count[k]:
|
|
||||||
count[k][subkey] = {}
|
|
||||||
count[k][subkey][subcount] = times
|
|
||||||
|
|
||||||
|
|
||||||
return {'count': count, 'mean': mean}
|
|
||||||
|
|
||||||
|
|
||||||
class defaultStats(Stats):
|
|
||||||
|
|
||||||
def trial_end(self, env):
|
|
||||||
c = Counter()
|
|
||||||
c.update(a.__class__.__name__ for a in env.network_agents)
|
|
||||||
|
|
||||||
c2 = Counter()
|
|
||||||
c2.update(a['id'] for a in env.network_agents)
|
|
||||||
|
|
||||||
return {
|
|
||||||
'network ': {
|
|
||||||
'n_nodes': env.G.number_of_nodes(),
|
|
||||||
'n_edges': env.G.number_of_edges(),
|
|
||||||
},
|
|
||||||
'agents': {
|
|
||||||
'model_count': dict(c),
|
|
||||||
'state_count': dict(c2),
|
|
||||||
}
|
|
||||||
}
|
|
@ -29,11 +29,11 @@ agents:
|
|||||||
weight: 0.6
|
weight: 0.6
|
||||||
override:
|
override:
|
||||||
- filter:
|
- filter:
|
||||||
id: 0
|
node_id: 0
|
||||||
state:
|
state:
|
||||||
name: 'The first node'
|
name: 'The first node'
|
||||||
- filter:
|
- filter:
|
||||||
id: 1
|
node_id: 1
|
||||||
state:
|
state:
|
||||||
name: 'The second node'
|
name: 'The second node'
|
||||||
|
|
||||||
|
@ -71,11 +71,11 @@ class TestConfig(TestCase):
|
|||||||
s = simulation.from_config(cfg)
|
s = simulation.from_config(cfg)
|
||||||
env = s.get_env()
|
env = s.get_env()
|
||||||
assert len(env.topologies['default'].nodes) == 10
|
assert len(env.topologies['default'].nodes) == 10
|
||||||
assert len(env.agents('network')) == 10
|
assert len(env.agents(group='network')) == 10
|
||||||
assert len(env.agents('environment')) == 1
|
assert len(env.agents(group='environment')) == 1
|
||||||
|
|
||||||
assert sum(1 for a in env.agents('network') if isinstance(a, agents.CounterModel)) == 4
|
assert sum(1 for a in env.agents(group='network', agent_type=agents.CounterModel)) == 4
|
||||||
assert sum(1 for a in env.agents('network') if isinstance(a, agents.AggregatedCounter)) == 6
|
assert sum(1 for a in env.agents(group='network', agent_type=agents.AggregatedCounter)) == 6
|
||||||
|
|
||||||
def make_example_test(path, cfg):
|
def make_example_test(path, cfg):
|
||||||
def wrapped(self):
|
def wrapped(self):
|
||||||
|
@ -7,8 +7,6 @@ from unittest import TestCase
|
|||||||
from soil import exporters
|
from soil import exporters
|
||||||
from soil import simulation
|
from soil import simulation
|
||||||
|
|
||||||
from soil.stats import distribution
|
|
||||||
|
|
||||||
class Dummy(exporters.Exporter):
|
class Dummy(exporters.Exporter):
|
||||||
started = False
|
started = False
|
||||||
trials = 0
|
trials = 0
|
||||||
@ -22,13 +20,13 @@ class Dummy(exporters.Exporter):
|
|||||||
self.__class__.called_start += 1
|
self.__class__.called_start += 1
|
||||||
self.__class__.started = True
|
self.__class__.started = True
|
||||||
|
|
||||||
def trial_end(self, env, stats):
|
def trial_end(self, env):
|
||||||
assert env
|
assert env
|
||||||
self.__class__.trials += 1
|
self.__class__.trials += 1
|
||||||
self.__class__.total_time += env.now
|
self.__class__.total_time += env.now
|
||||||
self.__class__.called_trial += 1
|
self.__class__.called_trial += 1
|
||||||
|
|
||||||
def sim_end(self, stats):
|
def sim_end(self):
|
||||||
self.__class__.ended = True
|
self.__class__.ended = True
|
||||||
self.__class__.called_end += 1
|
self.__class__.called_end += 1
|
||||||
|
|
||||||
@ -78,7 +76,6 @@ class Exporters(TestCase):
|
|||||||
exporters.csv,
|
exporters.csv,
|
||||||
exporters.gexf,
|
exporters.gexf,
|
||||||
],
|
],
|
||||||
stats=[distribution,],
|
|
||||||
dry_run=False,
|
dry_run=False,
|
||||||
outdir=tmpdir,
|
outdir=tmpdir,
|
||||||
exporter_params={'copy_to': output})
|
exporter_params={'copy_to': output})
|
||||||
|
@ -10,7 +10,7 @@ from functools import partial
|
|||||||
|
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from soil import (simulation, Environment, agents, network, serialization,
|
from soil import (simulation, Environment, agents, network, serialization,
|
||||||
utils)
|
utils, config)
|
||||||
from soil.time import Delta
|
from soil.time import Delta
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
@ -200,7 +200,6 @@ class TestMain(TestCase):
|
|||||||
recovered = yaml.load(serial, Loader=yaml.SafeLoader)
|
recovered = yaml.load(serial, Loader=yaml.SafeLoader)
|
||||||
for (k, v) in config.items():
|
for (k, v) in config.items():
|
||||||
assert recovered[k] == v
|
assert recovered[k] == v
|
||||||
# assert config == recovered
|
|
||||||
|
|
||||||
def test_configuration_changes(self):
|
def test_configuration_changes(self):
|
||||||
"""
|
"""
|
||||||
@ -294,11 +293,13 @@ class TestMain(TestCase):
|
|||||||
G.add_node(3)
|
G.add_node(3)
|
||||||
G.add_edge(1, 2)
|
G.add_edge(1, 2)
|
||||||
distro = agents.calculate_distribution(agent_type=agents.NetworkAgent)
|
distro = agents.calculate_distribution(agent_type=agents.NetworkAgent)
|
||||||
env = Environment(name='Test', topology=G, network_agents=distro)
|
distro[0]['topology'] = 'default'
|
||||||
|
aconfig = config.AgentConfig(distribution=distro, topology='default')
|
||||||
|
env = Environment(name='Test', topologies={'default': G}, agents={'network': aconfig})
|
||||||
lst = list(env.network_agents)
|
lst = list(env.network_agents)
|
||||||
|
|
||||||
a2 = env.get_agent(2)
|
a2 = env.find_one(node_id=2)
|
||||||
a3 = env.get_agent(3)
|
a3 = env.find_one(node_id=3)
|
||||||
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
assert len(a2.subgraph(limit_neighbors=True)) == 2
|
||||||
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
assert len(a3.subgraph(limit_neighbors=True)) == 1
|
||||||
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
assert len(a3.subgraph(limit_neighbors=True, center=False)) == 0
|
||||||
|
@ -1,34 +0,0 @@
|
|||||||
from unittest import TestCase
|
|
||||||
|
|
||||||
from soil import simulation, stats
|
|
||||||
from soil.utils import unflatten_dict
|
|
||||||
|
|
||||||
class Stats(TestCase):
|
|
||||||
|
|
||||||
def test_distribution(self):
|
|
||||||
'''The distribution exporter should write the number of agents in each state'''
|
|
||||||
config = {
|
|
||||||
'name': 'exporter_sim',
|
|
||||||
'network_params': {
|
|
||||||
'generator': 'complete_graph',
|
|
||||||
'n': 4
|
|
||||||
},
|
|
||||||
'agent_type': 'CounterModel',
|
|
||||||
'max_time': 2,
|
|
||||||
'num_trials': 5,
|
|
||||||
'environment_params': {}
|
|
||||||
}
|
|
||||||
s = simulation.from_config(config)
|
|
||||||
for env in s.run_simulation(stats=[stats.distribution]):
|
|
||||||
pass
|
|
||||||
# stats_res = unflatten_dict(dict(env._history['stats', -1, None]))
|
|
||||||
allstats = s.get_stats()
|
|
||||||
for stat in allstats:
|
|
||||||
assert 'count' in stat
|
|
||||||
assert 'mean' in stat
|
|
||||||
if 'trial_id' in stat:
|
|
||||||
assert stat['mean']['neighbors'] == 3
|
|
||||||
assert stat['count']['total']['4'] == 4
|
|
||||||
else:
|
|
||||||
assert stat['count']['count']['neighbors']['3'] == 20
|
|
||||||
assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors']
|
|
Loading…
Reference in New Issue
Block a user