1
0
mirror of https://github.com/gsi-upm/soil synced 2024-11-21 18:52:28 +00:00

WIP: working config

This commit is contained in:
J. Fernando Sánchez 2022-09-15 19:27:17 +02:00
parent e41dc3dae2
commit 3dc56892c1
10 changed files with 284 additions and 100 deletions

View File

@ -22,7 +22,7 @@ environment:
params: params:
am_i_complete: true am_i_complete: true
agents: agents:
# Agents are split into groups, each with its own definition # Agents are split several groups, each with its own definition
default: # This is a special group. Its values will be used as default values for the rest of the groups default: # This is a special group. Its values will be used as default values for the rest of the groups
agent_class: CounterModel agent_class: CounterModel
topology: default topology: default
@ -31,7 +31,7 @@ agents:
environment: environment:
# In this group we are not specifying any topology # In this group we are not specifying any topology
fixed: fixed:
- agent_id: 'Environment Agent 1' - name: 'Environment Agent 1'
agent_class: CounterModel agent_class: CounterModel
state: state:
times: 10 times: 10
@ -41,10 +41,16 @@ agents:
- agent_class: CounterModel - agent_class: CounterModel
weight: 1 weight: 1
state: state:
id: 0
times: 3 times: 3
- agent_class: AggregatedCounter - agent_class: AggregatedCounter
weight: 0.2 weight: 0.2
override:
- filter:
agent_class: AggregatedCounter
n: 2
state:
times: 5
other_counters: other_counters:
topology: another_graph topology: another_graph
fixed: fixed:

View File

@ -1,6 +1,7 @@
import logging import logging
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from collections.abc import Mapping, Set from collections.abc import MutableMapping, Mapping, Set
from abc import ABCMeta
from copy import deepcopy from copy import deepcopy
from functools import partial, wraps from functools import partial, wraps
from itertools import islice, chain from itertools import islice, chain
@ -10,6 +11,8 @@ import networkx as nx
from mesa import Agent as MesaAgent from mesa import Agent as MesaAgent
from typing import Dict, List from typing import Dict, List
from random import shuffle
from .. import serialization, utils, time, config from .. import serialization, utils, time, config
@ -25,7 +28,7 @@ IGNORED_FIELDS = ('model', 'logger')
class DeadAgent(Exception): class DeadAgent(Exception):
pass pass
class BaseAgent(MesaAgent): class BaseAgent(MesaAgent, MutableMapping):
""" """
A special type of Mesa Agent that: A special type of Mesa Agent that:
@ -50,8 +53,10 @@ class BaseAgent(MesaAgent):
# Initialize agent parameters # Initialize agent parameters
if isinstance(unique_id, MesaAgent): if isinstance(unique_id, MesaAgent):
raise Exception() raise Exception()
assert isinstance(unique_id, int)
super().__init__(unique_id=unique_id, model=model) super().__init__(unique_id=unique_id, model=model)
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id) self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
self._neighbors = None self._neighbors = None
self.alive = True self.alive = True
@ -120,6 +125,12 @@ class BaseAgent(MesaAgent):
def __setitem__(self, key, value): def __setitem__(self, key, value):
setattr(self, key, value) setattr(self, key, value)
def __len__(self):
return sum(1 for n in self.keys())
def __iter__(self):
return self.items()
def keys(self): def keys(self):
return (k for k in self.__dict__ if k[0] != '_') return (k for k in self.__dict__ if k[0] != '_')
@ -284,7 +295,7 @@ def default_state(func):
return func return func
class MetaFSM(type): class MetaFSM(ABCMeta):
def __init__(cls, name, bases, nmspc): def __init__(cls, name, bases, nmspc):
super(MetaFSM, cls).__init__(name, bases, nmspc) super(MetaFSM, cls).__init__(name, bases, nmspc)
states = {} states = {}
@ -486,14 +497,15 @@ def _definition_to_dict(definition, size=None, default_state=None):
distro = sorted([item for item in definition if 'weight' in item]) distro = sorted([item for item in definition if 'weight' in item])
ix = 0 id = 0
def init_agent(item, id=ix): def init_agent(item, id=ix):
while id in agents: while id in agents:
id += 1 id += 1
agent = remaining[id] agent = remaining[id]
agent['state'].update(copy(item.get('state', {}))) agent['state'].update(copy(item.get('state', {})))
agents[id] = agent agents[agent.unique_id] = agent
del remaining[id] del remaining[id]
return agent return agent
@ -554,7 +566,7 @@ class AgentView(Mapping, Set):
return sum(len(x) for x in self._agents.values()) return sum(len(x) for x in self._agents.values())
def __iter__(self): def __iter__(self):
return iter(chain.from_iterable(g.values() for g in self._agents.values())) yield from iter(chain.from_iterable(g.values() for g in self._agents.values()))
def __getitem__(self, agent_id): def __getitem__(self, agent_id):
if isinstance(agent_id, slice): if isinstance(agent_id, slice):
@ -564,45 +576,11 @@ class AgentView(Mapping, Set):
return group[agent_id] return group[agent_id]
raise ValueError(f"Agent {agent_id} not found") raise ValueError(f"Agent {agent_id} not found")
def filter(self, ids=None, groups=None, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs): def filter(self, *group_ids, **kwargs):
yield from filter_groups(self._agents, group_ids=group_ids, **kwargs)
if state_id is not None and not isinstance(state_id, (tuple, list)):
state_id = tuple([state_id])
agents = self._agents
if groups:
agents = {(k,v) for (k, v) in agents.items() if k in groups}
if agent_type is not None:
try:
agent_type = tuple(agent_type)
except TypeError:
agent_type = tuple([agent_type])
if ids:
agents = (v[aid] for v in agents.values() for aid in ids if aid in v)
else:
agents = (a for v in agents.values() for a in v.values())
f = agents
if ignore:
f = filter(lambda x: x not in ignore, f)
if state_id is not None:
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
if agent_type is not None:
f = filter(lambda agent: isinstance(agent, agent_type), f)
for k, v in kwargs.items():
f = filter(lambda agent: agent.state.get(k, None) == v, f)
if iterator:
return f
return list(f)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.filter(*args, **kwargs) return list(self.filter(*args, **kwargs))
def __contains__(self, agent_id): def __contains__(self, agent_id):
return any(agent_id in g for g in self._agents) return any(agent_id in g for g in self._agents)
@ -614,6 +592,57 @@ class AgentView(Mapping, Set):
return f"{self.__class__.__name__}({self})" return f"{self.__class__.__name__}({self})"
def filter_groups(groups, group_ids=None, **kwargs):
assert isinstance(groups, dict)
if group_ids:
groups = list(groups[g] for g in group_ids if g in groups)
else:
groups = list(groups.values())
agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups)
yield from agents
def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, state=None, **kwargs):
'''
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
'''
assert isinstance(group, dict)
if state_id is not None and not isinstance(state_id, (tuple, list)):
state_id = tuple([state_id])
if agent_type is not None:
try:
agent_type = tuple(agent_type)
except TypeError:
agent_type = tuple([agent_type])
if ids:
agents = (v[aid] for aid in ids if aid in group)
else:
agents = (a for a in group.values())
f = agents
if ignore:
f = filter(lambda x: x not in ignore, f)
if state_id is not None:
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
if agent_type is not None:
f = filter(lambda agent: isinstance(agent, agent_type), f)
state = state or dict()
state.update(kwargs)
for k, v in state.items():
f = filter(lambda agent: agent.state.get(k, None) == v, f)
yield from f
def from_config(cfg: Dict[str, config.AgentConfig], env): def from_config(cfg: Dict[str, config.AgentConfig], env):
''' '''
Agents are specified in groups. Agents are specified in groups.
@ -632,10 +661,22 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env) agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
if cfg.distribution: if cfg.distribution:
n = cfg.n or len(env.topologies[cfg.topology]) n = cfg.n or len(env.topologies[cfg.topology])
agents.update(_from_distro(cfg.distribution, n - len(agents), target = n - len(agents)
agents.update(_from_distro(cfg.distribution, target,
topology=cfg.topology or default.topology, topology=cfg.topology or default.topology,
default=default, default=default,
env=env)) env=env))
assert len(agents) == n
if cfg.override:
for attrs in cfg.override:
if attrs.filter:
filtered = list(filter_group(agents, **attrs.filter))
else:
filtered = list(agents)
for agent in random.sample(filtered, attrs.n):
agent.state.update(attrs.state)
return agents return agents
@ -650,10 +691,11 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: conf
cls = serialization.deserialize(fixed.agent_class or default.agent_class) cls = serialization.deserialize(fixed.agent_class or default.agent_class)
state = fixed.state.copy() state = fixed.state.copy()
state.update(default.state) state.update(default.state)
agents[agent_id] = cls(unique_id=agent_id, agent = cls(unique_id=agent_id,
model=env, model=env,
graph_name=fixed.topology or topology or default.topology, graph_name=fixed.topology or topology or default.topology,
**state) **state)
agents[agent.unique_id] = agent
return agents return agents
@ -671,31 +713,40 @@ def _from_distro(distro: List[config.AgentDistro],
raise ValueError('You must provide a total number of agents, or the number of each type') raise ValueError('You must provide a total number of agents, or the number of each type')
n = sum(dist.n for dist in distro) n = sum(dist.n for dist in distro)
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
minw = min(weights)
norm = list(weight / minw for weight in weights)
total = sum(norm)
chunk = n // total
total = sum((dist.weight if dist.weight is not None else 1) for dist in distro) # random.choices would be enough to get a weighted distribution. But it can vary a lot for smaller k
thres = {} # So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
last = 0
for i in sorted(distro, key=lambda x: x.weight):
cls = serialization.deserialize(i.agent_class or default.agent_class) # Calculate how many times each has to appear
thres[(last, last + i.weight/total)] = (cls, i) indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm)))
acc = 0 # Complete with random agents following the original weight distribution
if len(indices) < n:
indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices))
# using np.choice would be more efficient, but this allows us to use soil without # Deserialize classes for efficiency
# numpy classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro)
for i in range(n):
r = random.random()
for (t, (cls, d)) in thres.items():
if r >= t[0] and r <= t[1]:
agent_id = d.agent_id
if agent_id is None:
agent_id = env.next_id()
state = d.state.copy() # Add them in random order
state.update(default.state) random.shuffle(indices)
agents[agent_id] = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
break
for idx in indices:
d = distro[idx]
cls = classes[idx]
agent_id = env.next_id()
state = d.state.copy()
state.update(default.state)
agent = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
assert agent.name is not None
assert agent.name != 'None'
assert agent.name
agents[agent.unique_id] = agent
return agents return agents

View File

@ -76,8 +76,9 @@ class EnvConfig(BaseModel):
class SingleAgentConfig(BaseModel): class SingleAgentConfig(BaseModel):
agent_class: Optional[Union[Type, str]] = None agent_class: Optional[Union[Type, str]] = None
agent_id: Optional[Union[str, int]] = None agent_id: Optional[int] = None
topology: Optional[str] = None topology: Optional[str] = 'default'
name: Optional[str] = None
state: Optional[Dict[str, Any]] = {} state: Optional[Dict[str, Any]] = {}
class FixedAgentConfig(SingleAgentConfig): class FixedAgentConfig(SingleAgentConfig):
@ -85,11 +86,16 @@ class FixedAgentConfig(SingleAgentConfig):
@root_validator @root_validator
def validate_all(cls, values): def validate_all(cls, values):
if 'agent_id' in values and values.get('n', 1) > 1: if values.get('agent_id', None) is not None and values.get('n', 1) > 1:
raise ValueError("An agent_id can only be provided when there is only one agent") print(values)
raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)")
return values return values
class OverrideAgentConfig(FixedAgentConfig):
filter: Optional[Dict[str, Any]] = None
class AgentDistro(SingleAgentConfig): class AgentDistro(SingleAgentConfig):
weight: Optional[float] = 1 weight: Optional[float] = 1
@ -99,6 +105,7 @@ class AgentConfig(SingleAgentConfig):
topology: Optional[str] = None topology: Optional[str] = None
distribution: Optional[List[AgentDistro]] = None distribution: Optional[List[AgentDistro]] = None
fixed: Optional[List[FixedAgentConfig]] = None fixed: Optional[List[FixedAgentConfig]] = None
override: Optional[List[OverrideAgentConfig]] = None
@staticmethod @staticmethod
def default(): def default():
@ -118,7 +125,6 @@ class Config(BaseModel, extra=Extra.forbid):
environment: EnvConfig = EnvConfig.default() environment: EnvConfig = EnvConfig.default()
agents: Optional[Dict[str, AgentConfig]] = {} agents: Optional[Dict[str, AgentConfig]] = {}
def convert_old(old, strict=True): def convert_old(old, strict=True):
''' '''
Try to convert old style configs into the new format. Try to convert old style configs into the new format.
@ -126,10 +132,6 @@ def convert_old(old, strict=True):
This is still a work in progress and might not work in many cases. This is still a work in progress and might not work in many cases.
''' '''
# TODO: translate states
if strict and old.get('states', {}):
raise ValueError('Custom (i.e., manual) agent states cannot be translated to v2 configuration files. Please, convert your configuration file to the new format.')
new = {} new = {}
@ -181,13 +183,16 @@ def convert_old(old, strict=True):
for agent in old.get('environment_agents', []): for agent in old.get('environment_agents', []):
agents['environment'] = {'distribution': [], 'fixed': []} agents['environment'] = {'distribution': [], 'fixed': []}
if 'agent_id' not in agent: if 'agent_id' in agent:
agents['environment']['distribution'].append(updated_agent(agent)) agent['name'] = agent['agent_id']
else: del agent['agent_id']
agents['environment']['fixed'].append(updated_agent(agent)) agents['environment']['fixed'].append(updated_agent(agent))
else:
agents['environment']['distribution'].append(updated_agent(agent))
by_weight = [] by_weight = []
fixed = [] fixed = []
override = []
if 'network_agents' in old: if 'network_agents' in old:
agents['network']['topology'] = 'default' agents['network']['topology'] = 'default'
@ -203,6 +208,20 @@ def convert_old(old, strict=True):
agents['network']['topology'] = 'default' agents['network']['topology'] = 'default'
by_weight = [{'agent_type': old['agent_type']}] by_weight = [{'agent_type': old['agent_type']}]
# TODO: translate states
if 'states' in old:
states = old['states']
if isinstance(states, dict):
states = states.items()
else:
states = enumerate(states)
for (k, v) in states:
override.append({'filter': {'id': k},
'state': v
})
agents['network']['override'] = override
agents['network']['fixed'] = fixed agents['network']['fixed'] = fixed
agents['network']['distribution'] = by_weight agents['network']['distribution'] = by_weight

View File

@ -79,7 +79,7 @@ class Environment(Model):
@property @property
def network_agents(self): def network_agents(self):
yield from self.agents(agent_type=agents.NetworkAgent, iterator=True) yield from self.agents(agent_type=agents.NetworkAgent, iterator=False)
@staticmethod @staticmethod

42
soil/network.py Normal file
View File

@ -0,0 +1,42 @@
from typing import Dict
import os
import sys
import networkx as nx
from . import config, serialization, basestring
def from_config(cfg: config.NetConfig, dir_path: str = None):
if not isinstance(cfg, config.NetConfig):
cfg = config.NetConfig(**cfg)
if cfg.path:
path = cfg.path
if dir_path and not os.path.isabs(path):
path = os.path.join(dir_path, path)
extension = os.path.splitext(path)[1][1:]
kwargs = {}
if extension == 'gexf':
kwargs['version'] = '1.2draft'
kwargs['node_type'] = int
try:
method = getattr(nx.readwrite, 'read_' + extension)
except AttributeError:
raise AttributeError('Unknown format')
return method(path, **kwargs)
if cfg.params:
net_args = cfg.params.dict()
net_gen = net_args.pop('generator')
if dir_path not in sys.path:
sys.path.append(dir_path)
method = serialization.deserializer(net_gen,
known_modules=['networkx.generators',])
return method(**net_args)
if isinstance(cfg.topology, basestring) or isinstance(cfg.topology, dict):
return nx.json_graph.node_link_graph(cfg.topology)
return nx.Graph()

View File

@ -42,7 +42,6 @@ class Simulation:
config = Config(**cfg) config = Config(**cfg)
if not config: if not config:
raise ValueError("You need to specify a simulation configuration") raise ValueError("You need to specify a simulation configuration")
self.config = config self.config = config
@ -151,6 +150,7 @@ class Simulation:
# 'environment_agents': self.environment_agents, # 'environment_agents': self.environment_agents,
# }) # })
# opts.update(kwargs) # opts.update(kwargs)
print(self.config)
env = Environment.from_config(self.config, trial_id=trial_id, **kwargs) env = Environment.from_config(self.config, trial_id=trial_id, **kwargs)
return env return env

View File

@ -0,0 +1,49 @@
---
version: '2'
general:
id: simple
group: tests
dir_path: "/tmp/"
num_trials: 3
max_time: 100
interval: 1
seed: "CompleteSeed!"
topologies:
default:
params:
generator: complete_graph
n: 10
agents:
default:
agent_class: CounterModel
state:
times: 1
network:
topology: 'default'
distribution:
- agent_class: CounterModel
weight: 0.4
state:
state_id: 0
- agent_class: AggregatedCounter
weight: 0.6
override:
- filter:
id: 0
state:
name: 'The first node'
- filter:
id: 1
state:
name: 'The second node'
environment:
fixed:
- name: 'Environment Agent 1'
agent_class: CounterModel
state:
times: 10
environment:
environment_class: Environment
params:
am_i_complete: true

View File

@ -11,11 +11,11 @@ network_params:
n: 10 n: 10
network_agents: network_agents:
- agent_type: CounterModel - agent_type: CounterModel
weight: 1 weight: 0.4
state: state:
state_id: 0 state_id: 0
- agent_type: AggregatedCounter - agent_type: AggregatedCounter
weight: 0.2 weight: 0.6
environment_agents: environment_agents:
- agent_id: 'Environment Agent 1' - agent_id: 'Environment Agent 1'
agent_type: CounterModel agent_type: CounterModel

View File

@ -2,7 +2,7 @@ from unittest import TestCase
import os import os
from os.path import join from os.path import join
from soil import simulation, serialization, config, network from soil import simulation, serialization, config, network, agents
ROOT = os.path.abspath(os.path.dirname(__file__)) ROOT = os.path.abspath(os.path.dirname(__file__))
EXAMPLES = join(ROOT, '..', 'examples') EXAMPLES = join(ROOT, '..', 'examples')
@ -13,18 +13,22 @@ FORCE_TESTS = os.environ.get('FORCE_TESTS', '')
class TestConfig(TestCase): class TestConfig(TestCase):
def test_conversion(self): def test_conversion(self):
new = serialization.load_file(join(ROOT, "complete_converted.yml"))[0] expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
old = serialization.load_file(join(ROOT, "old_complete.yml"))[0] old = serialization.load_file(join(ROOT, "old_complete.yml"))[0]
converted_defaults = config.convert_old(old, strict=False) converted_defaults = config.convert_old(old, strict=False)
converted = converted_defaults.dict(skip_defaults=True) converted = converted_defaults.dict(skip_defaults=True)
def isequal(old, new):
if isinstance(old, dict):
for (k, v) in old.items():
isequal(old[k], new[k])
return
assert old == new
isequal(new, converted) def isequal(a, b):
if isinstance(a, dict):
for (k, v) in a.items():
if v:
isequal(a[k], b[k])
else:
assert not b.get(k, None)
return
assert a == b
isequal(converted, expected)
def test_topology_config(self): def test_topology_config(self):
netconfig = config.NetConfig(**{ netconfig = config.NetConfig(**{
@ -60,6 +64,19 @@ class TestConfig(TestCase):
assert env.agents[0].topology == env.topologies['default'] assert env.agents[0].topology == env.topologies['default']
def test_agents_from_config(self):
'''We test that the known complete configuration produces
the right agents in the right groups'''
cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0]
s = simulation.from_config(cfg)
env = s.get_env()
assert len(env.topologies['default'].nodes) == 10
assert len(env.agents('network')) == 10
assert len(env.agents('environment')) == 1
assert sum(1 for a in env.agents('network') if isinstance(a, agents.CounterModel)) == 4
assert sum(1 for a in env.agents('network') if isinstance(a, agents.AggregatedCounter)) == 6
def make_example_test(path, cfg): def make_example_test(path, cfg):
def wrapped(self): def wrapped(self):
root = os.getcwd() root = os.getcwd()

View File

@ -18,10 +18,10 @@ def make_example_test(path, config):
def wrapped(self): def wrapped(self):
root = os.getcwd() root = os.getcwd()
for s in simulation.all_from_config(path): for s in simulation.all_from_config(path):
iterations = s.config.max_time * s.config.num_trials iterations = s.config.general.max_time * s.config.general.num_trials
if iterations > 1000: if iterations > 1000:
s.config.max_time = 100 s.config.general.max_time = 100
s.config.num_trials = 1 s.config.general.num_trials = 1
if config.get('skip_test', False) and not FORCE_TESTS: if config.get('skip_test', False) and not FORCE_TESTS:
self.skipTest('Example ignored.') self.skipTest('Example ignored.')
envs = s.run_simulation(dry_run=True) envs = s.run_simulation(dry_run=True)