mirror of
https://github.com/gsi-upm/soil
synced 2024-11-21 18:52:28 +00:00
WIP: working config
This commit is contained in:
parent
e41dc3dae2
commit
3dc56892c1
@ -22,7 +22,7 @@ environment:
|
|||||||
params:
|
params:
|
||||||
am_i_complete: true
|
am_i_complete: true
|
||||||
agents:
|
agents:
|
||||||
# Agents are split into groups, each with its own definition
|
# Agents are split several groups, each with its own definition
|
||||||
default: # This is a special group. Its values will be used as default values for the rest of the groups
|
default: # This is a special group. Its values will be used as default values for the rest of the groups
|
||||||
agent_class: CounterModel
|
agent_class: CounterModel
|
||||||
topology: default
|
topology: default
|
||||||
@ -31,7 +31,7 @@ agents:
|
|||||||
environment:
|
environment:
|
||||||
# In this group we are not specifying any topology
|
# In this group we are not specifying any topology
|
||||||
fixed:
|
fixed:
|
||||||
- agent_id: 'Environment Agent 1'
|
- name: 'Environment Agent 1'
|
||||||
agent_class: CounterModel
|
agent_class: CounterModel
|
||||||
state:
|
state:
|
||||||
times: 10
|
times: 10
|
||||||
@ -41,10 +41,16 @@ agents:
|
|||||||
- agent_class: CounterModel
|
- agent_class: CounterModel
|
||||||
weight: 1
|
weight: 1
|
||||||
state:
|
state:
|
||||||
id: 0
|
|
||||||
times: 3
|
times: 3
|
||||||
- agent_class: AggregatedCounter
|
- agent_class: AggregatedCounter
|
||||||
weight: 0.2
|
weight: 0.2
|
||||||
|
override:
|
||||||
|
- filter:
|
||||||
|
agent_class: AggregatedCounter
|
||||||
|
n: 2
|
||||||
|
state:
|
||||||
|
times: 5
|
||||||
|
|
||||||
other_counters:
|
other_counters:
|
||||||
topology: another_graph
|
topology: another_graph
|
||||||
fixed:
|
fixed:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
from collections.abc import Mapping, Set
|
from collections.abc import MutableMapping, Mapping, Set
|
||||||
|
from abc import ABCMeta
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from itertools import islice, chain
|
from itertools import islice, chain
|
||||||
@ -10,6 +11,8 @@ import networkx as nx
|
|||||||
from mesa import Agent as MesaAgent
|
from mesa import Agent as MesaAgent
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from random import shuffle
|
||||||
|
|
||||||
from .. import serialization, utils, time, config
|
from .. import serialization, utils, time, config
|
||||||
|
|
||||||
|
|
||||||
@ -25,7 +28,7 @@ IGNORED_FIELDS = ('model', 'logger')
|
|||||||
class DeadAgent(Exception):
|
class DeadAgent(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class BaseAgent(MesaAgent):
|
class BaseAgent(MesaAgent, MutableMapping):
|
||||||
"""
|
"""
|
||||||
A special type of Mesa Agent that:
|
A special type of Mesa Agent that:
|
||||||
|
|
||||||
@ -50,8 +53,10 @@ class BaseAgent(MesaAgent):
|
|||||||
# Initialize agent parameters
|
# Initialize agent parameters
|
||||||
if isinstance(unique_id, MesaAgent):
|
if isinstance(unique_id, MesaAgent):
|
||||||
raise Exception()
|
raise Exception()
|
||||||
|
assert isinstance(unique_id, int)
|
||||||
super().__init__(unique_id=unique_id, model=model)
|
super().__init__(unique_id=unique_id, model=model)
|
||||||
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
|
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||||
|
|
||||||
|
|
||||||
self._neighbors = None
|
self._neighbors = None
|
||||||
self.alive = True
|
self.alive = True
|
||||||
@ -120,6 +125,12 @@ class BaseAgent(MesaAgent):
|
|||||||
def __setitem__(self, key, value):
|
def __setitem__(self, key, value):
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return sum(1 for n in self.keys())
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self.items()
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return (k for k in self.__dict__ if k[0] != '_')
|
return (k for k in self.__dict__ if k[0] != '_')
|
||||||
|
|
||||||
@ -284,7 +295,7 @@ def default_state(func):
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
class MetaFSM(type):
|
class MetaFSM(ABCMeta):
|
||||||
def __init__(cls, name, bases, nmspc):
|
def __init__(cls, name, bases, nmspc):
|
||||||
super(MetaFSM, cls).__init__(name, bases, nmspc)
|
super(MetaFSM, cls).__init__(name, bases, nmspc)
|
||||||
states = {}
|
states = {}
|
||||||
@ -486,14 +497,15 @@ def _definition_to_dict(definition, size=None, default_state=None):
|
|||||||
|
|
||||||
distro = sorted([item for item in definition if 'weight' in item])
|
distro = sorted([item for item in definition if 'weight' in item])
|
||||||
|
|
||||||
ix = 0
|
id = 0
|
||||||
|
|
||||||
def init_agent(item, id=ix):
|
def init_agent(item, id=ix):
|
||||||
while id in agents:
|
while id in agents:
|
||||||
id += 1
|
id += 1
|
||||||
|
|
||||||
agent = remaining[id]
|
agent = remaining[id]
|
||||||
agent['state'].update(copy(item.get('state', {})))
|
agent['state'].update(copy(item.get('state', {})))
|
||||||
agents[id] = agent
|
agents[agent.unique_id] = agent
|
||||||
del remaining[id]
|
del remaining[id]
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
@ -554,7 +566,7 @@ class AgentView(Mapping, Set):
|
|||||||
return sum(len(x) for x in self._agents.values())
|
return sum(len(x) for x in self._agents.values())
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(chain.from_iterable(g.values() for g in self._agents.values()))
|
yield from iter(chain.from_iterable(g.values() for g in self._agents.values()))
|
||||||
|
|
||||||
def __getitem__(self, agent_id):
|
def __getitem__(self, agent_id):
|
||||||
if isinstance(agent_id, slice):
|
if isinstance(agent_id, slice):
|
||||||
@ -564,45 +576,11 @@ class AgentView(Mapping, Set):
|
|||||||
return group[agent_id]
|
return group[agent_id]
|
||||||
raise ValueError(f"Agent {agent_id} not found")
|
raise ValueError(f"Agent {agent_id} not found")
|
||||||
|
|
||||||
def filter(self, ids=None, groups=None, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
def filter(self, *group_ids, **kwargs):
|
||||||
|
yield from filter_groups(self._agents, group_ids=group_ids, **kwargs)
|
||||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
|
||||||
state_id = tuple([state_id])
|
|
||||||
|
|
||||||
agents = self._agents
|
|
||||||
|
|
||||||
if groups:
|
|
||||||
agents = {(k,v) for (k, v) in agents.items() if k in groups}
|
|
||||||
|
|
||||||
if agent_type is not None:
|
|
||||||
try:
|
|
||||||
agent_type = tuple(agent_type)
|
|
||||||
except TypeError:
|
|
||||||
agent_type = tuple([agent_type])
|
|
||||||
|
|
||||||
if ids:
|
|
||||||
agents = (v[aid] for v in agents.values() for aid in ids if aid in v)
|
|
||||||
else:
|
|
||||||
agents = (a for v in agents.values() for a in v.values())
|
|
||||||
|
|
||||||
f = agents
|
|
||||||
if ignore:
|
|
||||||
f = filter(lambda x: x not in ignore, f)
|
|
||||||
|
|
||||||
if state_id is not None:
|
|
||||||
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
|
||||||
|
|
||||||
if agent_type is not None:
|
|
||||||
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
|
||||||
for k, v in kwargs.items():
|
|
||||||
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
|
||||||
|
|
||||||
if iterator:
|
|
||||||
return f
|
|
||||||
return list(f)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.filter(*args, **kwargs)
|
return list(self.filter(*args, **kwargs))
|
||||||
|
|
||||||
def __contains__(self, agent_id):
|
def __contains__(self, agent_id):
|
||||||
return any(agent_id in g for g in self._agents)
|
return any(agent_id in g for g in self._agents)
|
||||||
@ -614,6 +592,57 @@ class AgentView(Mapping, Set):
|
|||||||
return f"{self.__class__.__name__}({self})"
|
return f"{self.__class__.__name__}({self})"
|
||||||
|
|
||||||
|
|
||||||
|
def filter_groups(groups, group_ids=None, **kwargs):
|
||||||
|
assert isinstance(groups, dict)
|
||||||
|
if group_ids:
|
||||||
|
groups = list(groups[g] for g in group_ids if g in groups)
|
||||||
|
else:
|
||||||
|
groups = list(groups.values())
|
||||||
|
|
||||||
|
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
|
||||||
|
|
||||||
|
yield from agents
|
||||||
|
|
||||||
|
|
||||||
|
def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, state=None, **kwargs):
|
||||||
|
'''
|
||||||
|
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||||
|
'''
|
||||||
|
assert isinstance(group, dict)
|
||||||
|
|
||||||
|
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||||
|
state_id = tuple([state_id])
|
||||||
|
|
||||||
|
if agent_type is not None:
|
||||||
|
try:
|
||||||
|
agent_type = tuple(agent_type)
|
||||||
|
except TypeError:
|
||||||
|
agent_type = tuple([agent_type])
|
||||||
|
|
||||||
|
if ids:
|
||||||
|
agents = (v[aid] for aid in ids if aid in group)
|
||||||
|
else:
|
||||||
|
agents = (a for a in group.values())
|
||||||
|
|
||||||
|
f = agents
|
||||||
|
if ignore:
|
||||||
|
f = filter(lambda x: x not in ignore, f)
|
||||||
|
|
||||||
|
if state_id is not None:
|
||||||
|
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
||||||
|
|
||||||
|
if agent_type is not None:
|
||||||
|
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
||||||
|
|
||||||
|
state = state or dict()
|
||||||
|
state.update(kwargs)
|
||||||
|
|
||||||
|
for k, v in state.items():
|
||||||
|
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
||||||
|
|
||||||
|
yield from f
|
||||||
|
|
||||||
|
|
||||||
def from_config(cfg: Dict[str, config.AgentConfig], env):
|
def from_config(cfg: Dict[str, config.AgentConfig], env):
|
||||||
'''
|
'''
|
||||||
Agents are specified in groups.
|
Agents are specified in groups.
|
||||||
@ -632,10 +661,22 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
|
|||||||
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
||||||
if cfg.distribution:
|
if cfg.distribution:
|
||||||
n = cfg.n or len(env.topologies[cfg.topology])
|
n = cfg.n or len(env.topologies[cfg.topology])
|
||||||
agents.update(_from_distro(cfg.distribution, n - len(agents),
|
target = n - len(agents)
|
||||||
|
agents.update(_from_distro(cfg.distribution, target,
|
||||||
topology=cfg.topology or default.topology,
|
topology=cfg.topology or default.topology,
|
||||||
default=default,
|
default=default,
|
||||||
env=env))
|
env=env))
|
||||||
|
assert len(agents) == n
|
||||||
|
if cfg.override:
|
||||||
|
for attrs in cfg.override:
|
||||||
|
if attrs.filter:
|
||||||
|
filtered = list(filter_group(agents, **attrs.filter))
|
||||||
|
else:
|
||||||
|
filtered = list(agents)
|
||||||
|
|
||||||
|
for agent in random.sample(filtered, attrs.n):
|
||||||
|
agent.state.update(attrs.state)
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
|
||||||
@ -650,10 +691,11 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: conf
|
|||||||
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
||||||
state = fixed.state.copy()
|
state = fixed.state.copy()
|
||||||
state.update(default.state)
|
state.update(default.state)
|
||||||
agents[agent_id] = cls(unique_id=agent_id,
|
agent = cls(unique_id=agent_id,
|
||||||
model=env,
|
model=env,
|
||||||
graph_name=fixed.topology or topology or default.topology,
|
graph_name=fixed.topology or topology or default.topology,
|
||||||
**state)
|
**state)
|
||||||
|
agents[agent.unique_id] = agent
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
@ -671,31 +713,40 @@ def _from_distro(distro: List[config.AgentDistro],
|
|||||||
raise ValueError('You must provide a total number of agents, or the number of each type')
|
raise ValueError('You must provide a total number of agents, or the number of each type')
|
||||||
n = sum(dist.n for dist in distro)
|
n = sum(dist.n for dist in distro)
|
||||||
|
|
||||||
|
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
|
||||||
|
minw = min(weights)
|
||||||
|
norm = list(weight / minw for weight in weights)
|
||||||
|
total = sum(norm)
|
||||||
|
chunk = n // total
|
||||||
|
|
||||||
total = sum((dist.weight if dist.weight is not None else 1) for dist in distro)
|
# random.choices would be enough to get a weighted distribution. But it can vary a lot for smaller k
|
||||||
thres = {}
|
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
|
||||||
last = 0
|
|
||||||
for i in sorted(distro, key=lambda x: x.weight):
|
|
||||||
|
|
||||||
cls = serialization.deserialize(i.agent_class or default.agent_class)
|
# Calculate how many times each has to appear
|
||||||
thres[(last, last + i.weight/total)] = (cls, i)
|
indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm)))
|
||||||
|
|
||||||
acc = 0
|
# Complete with random agents following the original weight distribution
|
||||||
|
if len(indices) < n:
|
||||||
|
indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices))
|
||||||
|
|
||||||
# using np.choice would be more efficient, but this allows us to use soil without
|
# Deserialize classes for efficiency
|
||||||
# numpy
|
classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro)
|
||||||
for i in range(n):
|
|
||||||
r = random.random()
|
|
||||||
for (t, (cls, d)) in thres.items():
|
|
||||||
if r >= t[0] and r <= t[1]:
|
|
||||||
agent_id = d.agent_id
|
|
||||||
if agent_id is None:
|
|
||||||
agent_id = env.next_id()
|
|
||||||
|
|
||||||
state = d.state.copy()
|
# Add them in random order
|
||||||
state.update(default.state)
|
random.shuffle(indices)
|
||||||
agents[agent_id] = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
|
|
||||||
break
|
|
||||||
|
for idx in indices:
|
||||||
|
d = distro[idx]
|
||||||
|
cls = classes[idx]
|
||||||
|
agent_id = env.next_id()
|
||||||
|
state = d.state.copy()
|
||||||
|
state.update(default.state)
|
||||||
|
agent = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
|
||||||
|
assert agent.name is not None
|
||||||
|
assert agent.name != 'None'
|
||||||
|
assert agent.name
|
||||||
|
agents[agent.unique_id] = agent
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
@ -76,8 +76,9 @@ class EnvConfig(BaseModel):
|
|||||||
|
|
||||||
class SingleAgentConfig(BaseModel):
|
class SingleAgentConfig(BaseModel):
|
||||||
agent_class: Optional[Union[Type, str]] = None
|
agent_class: Optional[Union[Type, str]] = None
|
||||||
agent_id: Optional[Union[str, int]] = None
|
agent_id: Optional[int] = None
|
||||||
topology: Optional[str] = None
|
topology: Optional[str] = 'default'
|
||||||
|
name: Optional[str] = None
|
||||||
state: Optional[Dict[str, Any]] = {}
|
state: Optional[Dict[str, Any]] = {}
|
||||||
|
|
||||||
class FixedAgentConfig(SingleAgentConfig):
|
class FixedAgentConfig(SingleAgentConfig):
|
||||||
@ -85,11 +86,16 @@ class FixedAgentConfig(SingleAgentConfig):
|
|||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_all(cls, values):
|
def validate_all(cls, values):
|
||||||
if 'agent_id' in values and values.get('n', 1) > 1:
|
if values.get('agent_id', None) is not None and values.get('n', 1) > 1:
|
||||||
raise ValueError("An agent_id can only be provided when there is only one agent")
|
print(values)
|
||||||
|
raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)")
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class OverrideAgentConfig(FixedAgentConfig):
|
||||||
|
filter: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
class AgentDistro(SingleAgentConfig):
|
class AgentDistro(SingleAgentConfig):
|
||||||
weight: Optional[float] = 1
|
weight: Optional[float] = 1
|
||||||
|
|
||||||
@ -99,6 +105,7 @@ class AgentConfig(SingleAgentConfig):
|
|||||||
topology: Optional[str] = None
|
topology: Optional[str] = None
|
||||||
distribution: Optional[List[AgentDistro]] = None
|
distribution: Optional[List[AgentDistro]] = None
|
||||||
fixed: Optional[List[FixedAgentConfig]] = None
|
fixed: Optional[List[FixedAgentConfig]] = None
|
||||||
|
override: Optional[List[OverrideAgentConfig]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default():
|
def default():
|
||||||
@ -118,7 +125,6 @@ class Config(BaseModel, extra=Extra.forbid):
|
|||||||
environment: EnvConfig = EnvConfig.default()
|
environment: EnvConfig = EnvConfig.default()
|
||||||
agents: Optional[Dict[str, AgentConfig]] = {}
|
agents: Optional[Dict[str, AgentConfig]] = {}
|
||||||
|
|
||||||
|
|
||||||
def convert_old(old, strict=True):
|
def convert_old(old, strict=True):
|
||||||
'''
|
'''
|
||||||
Try to convert old style configs into the new format.
|
Try to convert old style configs into the new format.
|
||||||
@ -126,10 +132,6 @@ def convert_old(old, strict=True):
|
|||||||
This is still a work in progress and might not work in many cases.
|
This is still a work in progress and might not work in many cases.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
# TODO: translate states
|
|
||||||
|
|
||||||
if strict and old.get('states', {}):
|
|
||||||
raise ValueError('Custom (i.e., manual) agent states cannot be translated to v2 configuration files. Please, convert your configuration file to the new format.')
|
|
||||||
|
|
||||||
new = {}
|
new = {}
|
||||||
|
|
||||||
@ -181,13 +183,16 @@ def convert_old(old, strict=True):
|
|||||||
|
|
||||||
for agent in old.get('environment_agents', []):
|
for agent in old.get('environment_agents', []):
|
||||||
agents['environment'] = {'distribution': [], 'fixed': []}
|
agents['environment'] = {'distribution': [], 'fixed': []}
|
||||||
if 'agent_id' not in agent:
|
if 'agent_id' in agent:
|
||||||
agents['environment']['distribution'].append(updated_agent(agent))
|
agent['name'] = agent['agent_id']
|
||||||
else:
|
del agent['agent_id']
|
||||||
agents['environment']['fixed'].append(updated_agent(agent))
|
agents['environment']['fixed'].append(updated_agent(agent))
|
||||||
|
else:
|
||||||
|
agents['environment']['distribution'].append(updated_agent(agent))
|
||||||
|
|
||||||
by_weight = []
|
by_weight = []
|
||||||
fixed = []
|
fixed = []
|
||||||
|
override = []
|
||||||
|
|
||||||
if 'network_agents' in old:
|
if 'network_agents' in old:
|
||||||
agents['network']['topology'] = 'default'
|
agents['network']['topology'] = 'default'
|
||||||
@ -203,6 +208,20 @@ def convert_old(old, strict=True):
|
|||||||
agents['network']['topology'] = 'default'
|
agents['network']['topology'] = 'default'
|
||||||
by_weight = [{'agent_type': old['agent_type']}]
|
by_weight = [{'agent_type': old['agent_type']}]
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: translate states
|
||||||
|
if 'states' in old:
|
||||||
|
states = old['states']
|
||||||
|
if isinstance(states, dict):
|
||||||
|
states = states.items()
|
||||||
|
else:
|
||||||
|
states = enumerate(states)
|
||||||
|
for (k, v) in states:
|
||||||
|
override.append({'filter': {'id': k},
|
||||||
|
'state': v
|
||||||
|
})
|
||||||
|
|
||||||
|
agents['network']['override'] = override
|
||||||
agents['network']['fixed'] = fixed
|
agents['network']['fixed'] = fixed
|
||||||
agents['network']['distribution'] = by_weight
|
agents['network']['distribution'] = by_weight
|
||||||
|
|
||||||
|
@ -79,7 +79,7 @@ class Environment(Model):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def network_agents(self):
|
def network_agents(self):
|
||||||
yield from self.agents(agent_type=agents.NetworkAgent, iterator=True)
|
yield from self.agents(agent_type=agents.NetworkAgent, iterator=False)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
42
soil/network.py
Normal file
42
soil/network.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from typing import Dict
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
|
from . import config, serialization, basestring
|
||||||
|
|
||||||
|
def from_config(cfg: config.NetConfig, dir_path: str = None):
|
||||||
|
if not isinstance(cfg, config.NetConfig):
|
||||||
|
cfg = config.NetConfig(**cfg)
|
||||||
|
|
||||||
|
if cfg.path:
|
||||||
|
path = cfg.path
|
||||||
|
if dir_path and not os.path.isabs(path):
|
||||||
|
path = os.path.join(dir_path, path)
|
||||||
|
extension = os.path.splitext(path)[1][1:]
|
||||||
|
kwargs = {}
|
||||||
|
if extension == 'gexf':
|
||||||
|
kwargs['version'] = '1.2draft'
|
||||||
|
kwargs['node_type'] = int
|
||||||
|
try:
|
||||||
|
method = getattr(nx.readwrite, 'read_' + extension)
|
||||||
|
except AttributeError:
|
||||||
|
raise AttributeError('Unknown format')
|
||||||
|
return method(path, **kwargs)
|
||||||
|
|
||||||
|
if cfg.params:
|
||||||
|
net_args = cfg.params.dict()
|
||||||
|
net_gen = net_args.pop('generator')
|
||||||
|
|
||||||
|
if dir_path not in sys.path:
|
||||||
|
sys.path.append(dir_path)
|
||||||
|
|
||||||
|
method = serialization.deserializer(net_gen,
|
||||||
|
known_modules=['networkx.generators',])
|
||||||
|
return method(**net_args)
|
||||||
|
|
||||||
|
if isinstance(cfg.topology, basestring) or isinstance(cfg.topology, dict):
|
||||||
|
return nx.json_graph.node_link_graph(cfg.topology)
|
||||||
|
|
||||||
|
return nx.Graph()
|
@ -42,7 +42,6 @@ class Simulation:
|
|||||||
config = Config(**cfg)
|
config = Config(**cfg)
|
||||||
if not config:
|
if not config:
|
||||||
raise ValueError("You need to specify a simulation configuration")
|
raise ValueError("You need to specify a simulation configuration")
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|
||||||
@ -151,6 +150,7 @@ class Simulation:
|
|||||||
# 'environment_agents': self.environment_agents,
|
# 'environment_agents': self.environment_agents,
|
||||||
# })
|
# })
|
||||||
# opts.update(kwargs)
|
# opts.update(kwargs)
|
||||||
|
print(self.config)
|
||||||
env = Environment.from_config(self.config, trial_id=trial_id, **kwargs)
|
env = Environment.from_config(self.config, trial_id=trial_id, **kwargs)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
49
tests/complete_converted.yml
Normal file
49
tests/complete_converted.yml
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
---
|
||||||
|
version: '2'
|
||||||
|
general:
|
||||||
|
id: simple
|
||||||
|
group: tests
|
||||||
|
dir_path: "/tmp/"
|
||||||
|
num_trials: 3
|
||||||
|
max_time: 100
|
||||||
|
interval: 1
|
||||||
|
seed: "CompleteSeed!"
|
||||||
|
topologies:
|
||||||
|
default:
|
||||||
|
params:
|
||||||
|
generator: complete_graph
|
||||||
|
n: 10
|
||||||
|
agents:
|
||||||
|
default:
|
||||||
|
agent_class: CounterModel
|
||||||
|
state:
|
||||||
|
times: 1
|
||||||
|
network:
|
||||||
|
topology: 'default'
|
||||||
|
distribution:
|
||||||
|
- agent_class: CounterModel
|
||||||
|
weight: 0.4
|
||||||
|
state:
|
||||||
|
state_id: 0
|
||||||
|
- agent_class: AggregatedCounter
|
||||||
|
weight: 0.6
|
||||||
|
override:
|
||||||
|
- filter:
|
||||||
|
id: 0
|
||||||
|
state:
|
||||||
|
name: 'The first node'
|
||||||
|
- filter:
|
||||||
|
id: 1
|
||||||
|
state:
|
||||||
|
name: 'The second node'
|
||||||
|
|
||||||
|
environment:
|
||||||
|
fixed:
|
||||||
|
- name: 'Environment Agent 1'
|
||||||
|
agent_class: CounterModel
|
||||||
|
state:
|
||||||
|
times: 10
|
||||||
|
environment:
|
||||||
|
environment_class: Environment
|
||||||
|
params:
|
||||||
|
am_i_complete: true
|
@ -11,11 +11,11 @@ network_params:
|
|||||||
n: 10
|
n: 10
|
||||||
network_agents:
|
network_agents:
|
||||||
- agent_type: CounterModel
|
- agent_type: CounterModel
|
||||||
weight: 1
|
weight: 0.4
|
||||||
state:
|
state:
|
||||||
state_id: 0
|
state_id: 0
|
||||||
- agent_type: AggregatedCounter
|
- agent_type: AggregatedCounter
|
||||||
weight: 0.2
|
weight: 0.6
|
||||||
environment_agents:
|
environment_agents:
|
||||||
- agent_id: 'Environment Agent 1'
|
- agent_id: 'Environment Agent 1'
|
||||||
agent_type: CounterModel
|
agent_type: CounterModel
|
||||||
|
@ -2,7 +2,7 @@ from unittest import TestCase
|
|||||||
import os
|
import os
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
from soil import simulation, serialization, config, network
|
from soil import simulation, serialization, config, network, agents
|
||||||
|
|
||||||
ROOT = os.path.abspath(os.path.dirname(__file__))
|
ROOT = os.path.abspath(os.path.dirname(__file__))
|
||||||
EXAMPLES = join(ROOT, '..', 'examples')
|
EXAMPLES = join(ROOT, '..', 'examples')
|
||||||
@ -13,18 +13,22 @@ FORCE_TESTS = os.environ.get('FORCE_TESTS', '')
|
|||||||
class TestConfig(TestCase):
|
class TestConfig(TestCase):
|
||||||
|
|
||||||
def test_conversion(self):
|
def test_conversion(self):
|
||||||
new = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||||
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
|
||||||
converted_defaults = config.convert_old(old, strict=False)
|
converted_defaults = config.convert_old(old, strict=False)
|
||||||
converted = converted_defaults.dict(skip_defaults=True)
|
converted = converted_defaults.dict(skip_defaults=True)
|
||||||
def isequal(old, new):
|
|
||||||
if isinstance(old, dict):
|
|
||||||
for (k, v) in old.items():
|
|
||||||
isequal(old[k], new[k])
|
|
||||||
return
|
|
||||||
assert old == new
|
|
||||||
|
|
||||||
isequal(new, converted)
|
def isequal(a, b):
|
||||||
|
if isinstance(a, dict):
|
||||||
|
for (k, v) in a.items():
|
||||||
|
if v:
|
||||||
|
isequal(a[k], b[k])
|
||||||
|
else:
|
||||||
|
assert not b.get(k, None)
|
||||||
|
return
|
||||||
|
assert a == b
|
||||||
|
|
||||||
|
isequal(converted, expected)
|
||||||
|
|
||||||
def test_topology_config(self):
|
def test_topology_config(self):
|
||||||
netconfig = config.NetConfig(**{
|
netconfig = config.NetConfig(**{
|
||||||
@ -60,6 +64,19 @@ class TestConfig(TestCase):
|
|||||||
assert env.agents[0].topology == env.topologies['default']
|
assert env.agents[0].topology == env.topologies['default']
|
||||||
|
|
||||||
|
|
||||||
|
def test_agents_from_config(self):
|
||||||
|
'''We test that the known complete configuration produces
|
||||||
|
the right agents in the right groups'''
|
||||||
|
cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
|
||||||
|
s = simulation.from_config(cfg)
|
||||||
|
env = s.get_env()
|
||||||
|
assert len(env.topologies['default'].nodes) == 10
|
||||||
|
assert len(env.agents('network')) == 10
|
||||||
|
assert len(env.agents('environment')) == 1
|
||||||
|
|
||||||
|
assert sum(1 for a in env.agents('network') if isinstance(a, agents.CounterModel)) == 4
|
||||||
|
assert sum(1 for a in env.agents('network') if isinstance(a, agents.AggregatedCounter)) == 6
|
||||||
|
|
||||||
def make_example_test(path, cfg):
|
def make_example_test(path, cfg):
|
||||||
def wrapped(self):
|
def wrapped(self):
|
||||||
root = os.getcwd()
|
root = os.getcwd()
|
||||||
|
@ -18,10 +18,10 @@ def make_example_test(path, config):
|
|||||||
def wrapped(self):
|
def wrapped(self):
|
||||||
root = os.getcwd()
|
root = os.getcwd()
|
||||||
for s in simulation.all_from_config(path):
|
for s in simulation.all_from_config(path):
|
||||||
iterations = s.config.max_time * s.config.num_trials
|
iterations = s.config.general.max_time * s.config.general.num_trials
|
||||||
if iterations > 1000:
|
if iterations > 1000:
|
||||||
s.config.max_time = 100
|
s.config.general.max_time = 100
|
||||||
s.config.num_trials = 1
|
s.config.general.num_trials = 1
|
||||||
if config.get('skip_test', False) and not FORCE_TESTS:
|
if config.get('skip_test', False) and not FORCE_TESTS:
|
||||||
self.skipTest('Example ignored.')
|
self.skipTest('Example ignored.')
|
||||||
envs = s.run_simulation(dry_run=True)
|
envs = s.run_simulation(dry_run=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user