1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-22 03:02:28 +00:00

WIP: working config

This commit is contained in:
J. Fernando Sánchez 2022-09-15 19:27:17 +02:00
parent e41dc3dae2
commit 3dc56892c1
10 changed files with 284 additions and 100 deletions

View File

@ -22,7 +22,7 @@ environment:
params:
am_i_complete: true
agents:
# Agents are split into groups, each with its own definition
# Agents are split several groups, each with its own definition
default: # This is a special group. Its values will be used as default values for the rest of the groups
agent_class: CounterModel
topology: default
@ -31,7 +31,7 @@ agents:
environment:
# In this group we are not specifying any topology
fixed:
- agent_id: 'Environment Agent 1'
- name: 'Environment Agent 1'
agent_class: CounterModel
state:
times: 10
@ -41,10 +41,16 @@ agents:
- agent_class: CounterModel
weight: 1
state:
id: 0
times: 3
- agent_class: AggregatedCounter
weight: 0.2
override:
- filter:
agent_class: AggregatedCounter
n: 2
state:
times: 5
other_counters:
topology: another_graph
fixed:

View File

@ -1,6 +1,7 @@
import logging
from collections import OrderedDict, defaultdict
from collections.abc import Mapping, Set
from collections.abc import MutableMapping, Mapping, Set
from abc import ABCMeta
from copy import deepcopy
from functools import partial, wraps
from itertools import islice, chain
@ -10,6 +11,8 @@ import networkx as nx
from mesa import Agent as MesaAgent
from typing import Dict, List
from random import shuffle
from .. import serialization, utils, time, config
@ -25,7 +28,7 @@ IGNORED_FIELDS = ('model', 'logger')
class DeadAgent(Exception):
pass
class BaseAgent(MesaAgent):
class BaseAgent(MesaAgent, MutableMapping):
"""
A special type of Mesa Agent that:
@ -50,8 +53,10 @@ class BaseAgent(MesaAgent):
# Initialize agent parameters
if isinstance(unique_id, MesaAgent):
raise Exception()
assert isinstance(unique_id, int)
super().__init__(unique_id=unique_id, model=model)
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
self._neighbors = None
self.alive = True
@ -120,6 +125,12 @@ class BaseAgent(MesaAgent):
def __setitem__(self, key, value):
setattr(self, key, value)
def __len__(self):
return sum(1 for n in self.keys())
def __iter__(self):
return self.items()
def keys(self):
return (k for k in self.__dict__ if k[0] != '_')
@ -284,7 +295,7 @@ def default_state(func):
return func
class MetaFSM(type):
class MetaFSM(ABCMeta):
def __init__(cls, name, bases, nmspc):
super(MetaFSM, cls).__init__(name, bases, nmspc)
states = {}
@ -486,14 +497,15 @@ def _definition_to_dict(definition, size=None, default_state=None):
distro = sorted([item for item in definition if 'weight' in item])
ix = 0
id = 0
def init_agent(item, id=ix):
while id in agents:
id += 1
agent = remaining[id]
agent['state'].update(copy(item.get('state', {})))
agents[id] = agent
agents[agent.unique_id] = agent
del remaining[id]
return agent
@ -554,7 +566,7 @@ class AgentView(Mapping, Set):
return sum(len(x) for x in self._agents.values())
def __iter__(self):
return iter(chain.from_iterable(g.values() for g in self._agents.values()))
yield from iter(chain.from_iterable(g.values() for g in self._agents.values()))
def __getitem__(self, agent_id):
if isinstance(agent_id, slice):
@ -564,16 +576,43 @@ class AgentView(Mapping, Set):
return group[agent_id]
raise ValueError(f"Agent {agent_id} not found")
def filter(self, ids=None, groups=None, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
def filter(self, *group_ids, **kwargs):
yield from filter_groups(self._agents, group_ids=group_ids, **kwargs)
def __call__(self, *args, **kwargs):
return list(self.filter(*args, **kwargs))
def __contains__(self, agent_id):
return any(agent_id in g for g in self._agents)
def __str__(self):
return str(list(a.id for a in self))
def __repr__(self):
return f"{self.__class__.__name__}({self})"
def filter_groups(groups, group_ids=None, **kwargs):
assert isinstance(groups, dict)
if group_ids:
groups = list(groups[g] for g in group_ids if g in groups)
else:
groups = list(groups.values())
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
yield from agents
def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, state=None, **kwargs):
'''
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
'''
assert isinstance(group, dict)
if state_id is not None and not isinstance(state_id, (tuple, list)):
state_id = tuple([state_id])
agents = self._agents
if groups:
agents = {(k,v) for (k, v) in agents.items() if k in groups}
if agent_type is not None:
try:
agent_type = tuple(agent_type)
@ -581,9 +620,9 @@ class AgentView(Mapping, Set):
agent_type = tuple([agent_type])
if ids:
agents = (v[aid] for v in agents.values() for aid in ids if aid in v)
agents = (v[aid] for aid in ids if aid in group)
else:
agents = (a for v in agents.values() for a in v.values())
agents = (a for a in group.values())
f = agents
if ignore:
@ -594,24 +633,14 @@ class AgentView(Mapping, Set):
if agent_type is not None:
f = filter(lambda agent: isinstance(agent, agent_type), f)
for k, v in kwargs.items():
state = state or dict()
state.update(kwargs)
for k, v in state.items():
f = filter(lambda agent: agent.state.get(k, None) == v, f)
if iterator:
return f
return list(f)
def __call__(self, *args, **kwargs):
return self.filter(*args, **kwargs)
def __contains__(self, agent_id):
return any(agent_id in g for g in self._agents)
def __str__(self):
return str(list(a.id for a in self))
def __repr__(self):
return f"{self.__class__.__name__}({self})"
yield from f
def from_config(cfg: Dict[str, config.AgentConfig], env):
@ -632,10 +661,22 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
if cfg.distribution:
n = cfg.n or len(env.topologies[cfg.topology])
agents.update(_from_distro(cfg.distribution, n - len(agents),
target = n - len(agents)
agents.update(_from_distro(cfg.distribution, target,
topology=cfg.topology or default.topology,
default=default,
env=env))
assert len(agents) == n
if cfg.override:
for attrs in cfg.override:
if attrs.filter:
filtered = list(filter_group(agents, **attrs.filter))
else:
filtered = list(agents)
for agent in random.sample(filtered, attrs.n):
agent.state.update(attrs.state)
return agents
@ -650,10 +691,11 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: conf
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
state = fixed.state.copy()
state.update(default.state)
agents[agent_id] = cls(unique_id=agent_id,
agent = cls(unique_id=agent_id,
model=env,
graph_name=fixed.topology or topology or default.topology,
**state)
agents[agent.unique_id] = agent
return agents
@ -671,31 +713,40 @@ def _from_distro(distro: List[config.AgentDistro],
raise ValueError('You must provide a total number of agents, or the number of each type')
n = sum(dist.n for dist in distro)
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
minw = min(weights)
norm = list(weight / minw for weight in weights)
total = sum(norm)
chunk = n // total
total = sum((dist.weight if dist.weight is not None else 1) for dist in distro)
thres = {}
last = 0
for i in sorted(distro, key=lambda x: x.weight):
# random.choices would be enough to get a weighted distribution. But it can vary a lot for smaller k
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
cls = serialization.deserialize(i.agent_class or default.agent_class)
thres[(last, last + i.weight/total)] = (cls, i)
# Calculate how many times each has to appear
indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm)))
acc = 0
# Complete with random agents following the original weight distribution
if len(indices) < n:
indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices))
# using np.choice would be more efficient, but this allows us to use soil without
# numpy
for i in range(n):
r = random.random()
for (t, (cls, d)) in thres.items():
if r >= t[0] and r <= t[1]:
agent_id = d.agent_id
if agent_id is None:
# Deserialize classes for efficiency
classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro)
# Add them in random order
random.shuffle(indices)
for idx in indices:
d = distro[idx]
cls = classes[idx]
agent_id = env.next_id()
state = d.state.copy()
state.update(default.state)
agents[agent_id] = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
break
agent = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
assert agent.name is not None
assert agent.name != 'None'
assert agent.name
agents[agent.unique_id] = agent
return agents

View File

@ -76,8 +76,9 @@ class EnvConfig(BaseModel):
class SingleAgentConfig(BaseModel):
agent_class: Optional[Union[Type, str]] = None
agent_id: Optional[Union[str, int]] = None
topology: Optional[str] = None
agent_id: Optional[int] = None
topology: Optional[str] = 'default'
name: Optional[str] = None
state: Optional[Dict[str, Any]] = {}
class FixedAgentConfig(SingleAgentConfig):
@ -85,11 +86,16 @@ class FixedAgentConfig(SingleAgentConfig):
@root_validator
def validate_all(cls, values):
if 'agent_id' in values and values.get('n', 1) > 1:
raise ValueError("An agent_id can only be provided when there is only one agent")
if values.get('agent_id', None) is not None and values.get('n', 1) > 1:
print(values)
raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)")
return values
class OverrideAgentConfig(FixedAgentConfig):
filter: Optional[Dict[str, Any]] = None
class AgentDistro(SingleAgentConfig):
weight: Optional[float] = 1
@ -99,6 +105,7 @@ class AgentConfig(SingleAgentConfig):
topology: Optional[str] = None
distribution: Optional[List[AgentDistro]] = None
fixed: Optional[List[FixedAgentConfig]] = None
override: Optional[List[OverrideAgentConfig]] = None
@staticmethod
def default():
@ -118,7 +125,6 @@ class Config(BaseModel, extra=Extra.forbid):
environment: EnvConfig = EnvConfig.default()
agents: Optional[Dict[str, AgentConfig]] = {}
def convert_old(old, strict=True):
'''
Try to convert old style configs into the new format.
@ -126,10 +132,6 @@ def convert_old(old, strict=True):
This is still a work in progress and might not work in many cases.
'''
# TODO: translate states
if strict and old.get('states', {}):
raise ValueError('Custom (i.e., manual) agent states cannot be translated to v2 configuration files. Please, convert your configuration file to the new format.')
new = {}
@ -181,13 +183,16 @@ def convert_old(old, strict=True):
for agent in old.get('environment_agents', []):
agents['environment'] = {'distribution': [], 'fixed': []}
if 'agent_id' not in agent:
agents['environment']['distribution'].append(updated_agent(agent))
else:
if 'agent_id' in agent:
agent['name'] = agent['agent_id']
del agent['agent_id']
agents['environment']['fixed'].append(updated_agent(agent))
else:
agents['environment']['distribution'].append(updated_agent(agent))
by_weight = []
fixed = []
override = []
if 'network_agents' in old:
agents['network']['topology'] = 'default'
@ -203,6 +208,20 @@ def convert_old(old, strict=True):
agents['network']['topology'] = 'default'
by_weight = [{'agent_type': old['agent_type']}]
# TODO: translate states
if 'states' in old:
states = old['states']
if isinstance(states, dict):
states = states.items()
else:
states = enumerate(states)
for (k, v) in states:
override.append({'filter': {'id': k},
'state': v
})
agents['network']['override'] = override
agents['network']['fixed'] = fixed
agents['network']['distribution'] = by_weight

View File

@ -79,7 +79,7 @@ class Environment(Model):
@property
def network_agents(self):
yield from self.agents(agent_type=agents.NetworkAgent, iterator=True)
yield from self.agents(agent_type=agents.NetworkAgent, iterator=False)
@staticmethod

42
soil/network.py Normal file
View File

@ -0,0 +1,42 @@
from typing import Dict
import os
import sys
import networkx as nx
from . import config, serialization, basestring
def from_config(cfg: config.NetConfig, dir_path: str = None):
if not isinstance(cfg, config.NetConfig):
cfg = config.NetConfig(**cfg)
if cfg.path:
path = cfg.path
if dir_path and not os.path.isabs(path):
path = os.path.join(dir_path, path)
extension = os.path.splitext(path)[1][1:]
kwargs = {}
if extension == 'gexf':
kwargs['version'] = '1.2draft'
kwargs['node_type'] = int
try:
method = getattr(nx.readwrite, 'read_' + extension)
except AttributeError:
raise AttributeError('Unknown format')
return method(path, **kwargs)
if cfg.params:
net_args = cfg.params.dict()
net_gen = net_args.pop('generator')
if dir_path not in sys.path:
sys.path.append(dir_path)
method = serialization.deserializer(net_gen,
known_modules=['networkx.generators',])
return method(**net_args)
if isinstance(cfg.topology, basestring) or isinstance(cfg.topology, dict):
return nx.json_graph.node_link_graph(cfg.topology)
return nx.Graph()

View File

@ -42,7 +42,6 @@ class Simulation:
config = Config(**cfg)
if not config:
raise ValueError("You need to specify a simulation configuration")
self.config = config
@ -151,6 +150,7 @@ class Simulation:
# 'environment_agents': self.environment_agents,
# })
# opts.update(kwargs)
print(self.config)
env = Environment.from_config(self.config, trial_id=trial_id, **kwargs)
return env

View File

@ -0,0 +1,49 @@
---
version: '2'
general:
id: simple
group: tests
dir_path: "/tmp/"
num_trials: 3
max_time: 100
interval: 1
seed: "CompleteSeed!"
topologies:
default:
params:
generator: complete_graph
n: 10
agents:
default:
agent_class: CounterModel
state:
times: 1
network:
topology: 'default'
distribution:
- agent_class: CounterModel
weight: 0.4
state:
state_id: 0
- agent_class: AggregatedCounter
weight: 0.6
override:
- filter:
id: 0
state:
name: 'The first node'
- filter:
id: 1
state:
name: 'The second node'
environment:
fixed:
- name: 'Environment Agent 1'
agent_class: CounterModel
state:
times: 10
environment:
environment_class: Environment
params:
am_i_complete: true

View File

@ -11,11 +11,11 @@ network_params:
n: 10
network_agents:
- agent_type: CounterModel
weight: 1
weight: 0.4
state:
state_id: 0
- agent_type: AggregatedCounter
weight: 0.2
weight: 0.6
environment_agents:
- agent_id: 'Environment Agent 1'
agent_type: CounterModel

View File

@ -2,7 +2,7 @@ from unittest import TestCase
import os
from os.path import join
from soil import simulation, serialization, config, network
from soil import simulation, serialization, config, network, agents
ROOT = os.path.abspath(os.path.dirname(__file__))
EXAMPLES = join(ROOT, '..', 'examples')
@ -13,18 +13,22 @@ FORCE_TESTS = os.environ.get('FORCE_TESTS', '')
class TestConfig(TestCase):
def test_conversion(self):
new = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
converted_defaults = config.convert_old(old, strict=False)
converted = converted_defaults.dict(skip_defaults=True)
def isequal(old, new):
if isinstance(old, dict):
for (k, v) in old.items():
isequal(old[k], new[k])
return
assert old == new
isequal(new, converted)
def isequal(a, b):
if isinstance(a, dict):
for (k, v) in a.items():
if v:
isequal(a[k], b[k])
else:
assert not b.get(k, None)
return
assert a == b
isequal(converted, expected)
def test_topology_config(self):
netconfig = config.NetConfig(**{
@ -60,6 +64,19 @@ class TestConfig(TestCase):
assert env.agents[0].topology == env.topologies['default']
def test_agents_from_config(self):
'''We test that the known complete configuration produces
the right agents in the right groups'''
cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
s = simulation.from_config(cfg)
env = s.get_env()
assert len(env.topologies['default'].nodes) == 10
assert len(env.agents('network')) == 10
assert len(env.agents('environment')) == 1
assert sum(1 for a in env.agents('network') if isinstance(a, agents.CounterModel)) == 4
assert sum(1 for a in env.agents('network') if isinstance(a, agents.AggregatedCounter)) == 6
def make_example_test(path, cfg):
def wrapped(self):
root = os.getcwd()

View File

@ -18,10 +18,10 @@ def make_example_test(path, config):
def wrapped(self):
root = os.getcwd()
for s in simulation.all_from_config(path):
iterations = s.config.max_time * s.config.num_trials
iterations = s.config.general.max_time * s.config.general.num_trials
if iterations > 1000:
s.config.max_time = 100
s.config.num_trials = 1
s.config.general.max_time = 100
s.config.general.num_trials = 1
if config.get('skip_test', False) and not FORCE_TESTS:
self.skipTest('Example ignored.')
envs = s.run_simulation(dry_run=True)