diff --git a/examples/complete.yml b/examples/complete.yml index f017075..adf1e7e 100644 --- a/examples/complete.yml +++ b/examples/complete.yml @@ -22,7 +22,7 @@ environment: params: am_i_complete: true agents: -# Agents are split into groups, each with its own definition +# Agents are split several groups, each with its own definition default: # This is a special group. Its values will be used as default values for the rest of the groups agent_class: CounterModel topology: default @@ -31,7 +31,7 @@ agents: environment: # In this group we are not specifying any topology fixed: - - agent_id: 'Environment Agent 1' + - name: 'Environment Agent 1' agent_class: CounterModel state: times: 10 @@ -41,10 +41,16 @@ agents: - agent_class: CounterModel weight: 1 state: - id: 0 times: 3 - agent_class: AggregatedCounter weight: 0.2 + override: + - filter: + agent_class: AggregatedCounter + n: 2 + state: + times: 5 + other_counters: topology: another_graph fixed: diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index 1b9b714..3b585f1 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -1,6 +1,7 @@ import logging from collections import OrderedDict, defaultdict -from collections.abc import Mapping, Set +from collections.abc import MutableMapping, Mapping, Set +from abc import ABCMeta from copy import deepcopy from functools import partial, wraps from itertools import islice, chain @@ -10,6 +11,8 @@ import networkx as nx from mesa import Agent as MesaAgent from typing import Dict, List +from random import shuffle + from .. import serialization, utils, time, config @@ -25,7 +28,7 @@ IGNORED_FIELDS = ('model', 'logger') class DeadAgent(Exception): pass -class BaseAgent(MesaAgent): +class BaseAgent(MesaAgent, MutableMapping): """ A special type of Mesa Agent that: @@ -50,8 +53,10 @@ class BaseAgent(MesaAgent): # Initialize agent parameters if isinstance(unique_id, MesaAgent): raise Exception() + assert isinstance(unique_id, int) super().__init__(unique_id=unique_id, model=model) - self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id) + self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id) + self._neighbors = None self.alive = True @@ -120,6 +125,12 @@ class BaseAgent(MesaAgent): def __setitem__(self, key, value): setattr(self, key, value) + def __len__(self): + return sum(1 for n in self.keys()) + + def __iter__(self): + return self.items() + def keys(self): return (k for k in self.__dict__ if k[0] != '_') @@ -284,7 +295,7 @@ def default_state(func): return func -class MetaFSM(type): +class MetaFSM(ABCMeta): def __init__(cls, name, bases, nmspc): super(MetaFSM, cls).__init__(name, bases, nmspc) states = {} @@ -486,14 +497,15 @@ def _definition_to_dict(definition, size=None, default_state=None): distro = sorted([item for item in definition if 'weight' in item]) - ix = 0 + id = 0 + def init_agent(item, id=ix): while id in agents: id += 1 agent = remaining[id] agent['state'].update(copy(item.get('state', {}))) - agents[id] = agent + agents[agent.unique_id] = agent del remaining[id] return agent @@ -554,7 +566,7 @@ class AgentView(Mapping, Set): return sum(len(x) for x in self._agents.values()) def __iter__(self): - return iter(chain.from_iterable(g.values() for g in self._agents.values())) + yield from iter(chain.from_iterable(g.values() for g in self._agents.values())) def __getitem__(self, agent_id): if isinstance(agent_id, slice): @@ -564,45 +576,11 @@ class AgentView(Mapping, Set): return group[agent_id] raise ValueError(f"Agent {agent_id} not found") - def filter(self, ids=None, groups=None, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs): - - if state_id is not None and not isinstance(state_id, (tuple, list)): - state_id = tuple([state_id]) - - agents = self._agents - - if groups: - agents = {(k,v) for (k, v) in agents.items() if k in groups} - - if agent_type is not None: - try: - agent_type = tuple(agent_type) - except TypeError: - agent_type = tuple([agent_type]) - - if ids: - agents = (v[aid] for v in agents.values() for aid in ids if aid in v) - else: - agents = (a for v in agents.values() for a in v.values()) - - f = agents - if ignore: - f = filter(lambda x: x not in ignore, f) - - if state_id is not None: - f = filter(lambda agent: agent.get('state_id', None) in state_id, f) - - if agent_type is not None: - f = filter(lambda agent: isinstance(agent, agent_type), f) - for k, v in kwargs.items(): - f = filter(lambda agent: agent.state.get(k, None) == v, f) - - if iterator: - return f - return list(f) + def filter(self, *group_ids, **kwargs): + yield from filter_groups(self._agents, group_ids=group_ids, **kwargs) def __call__(self, *args, **kwargs): - return self.filter(*args, **kwargs) + return list(self.filter(*args, **kwargs)) def __contains__(self, agent_id): return any(agent_id in g for g in self._agents) @@ -614,6 +592,57 @@ class AgentView(Mapping, Set): return f"{self.__class__.__name__}({self})" +def filter_groups(groups, group_ids=None, **kwargs): + assert isinstance(groups, dict) + if group_ids: + groups = list(groups[g] for g in group_ids if g in groups) + else: + groups = list(groups.values()) + + agents = chain.from_iterable(filter_group(g, **kwargs) for g in groups) + + yield from agents + + +def filter_group(group, ids=None, state_id=None, agent_type=None, ignore=None, state=None, **kwargs): + ''' + Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id). + ''' + assert isinstance(group, dict) + + if state_id is not None and not isinstance(state_id, (tuple, list)): + state_id = tuple([state_id]) + + if agent_type is not None: + try: + agent_type = tuple(agent_type) + except TypeError: + agent_type = tuple([agent_type]) + + if ids: + agents = (v[aid] for aid in ids if aid in group) + else: + agents = (a for a in group.values()) + + f = agents + if ignore: + f = filter(lambda x: x not in ignore, f) + + if state_id is not None: + f = filter(lambda agent: agent.get('state_id', None) in state_id, f) + + if agent_type is not None: + f = filter(lambda agent: isinstance(agent, agent_type), f) + + state = state or dict() + state.update(kwargs) + + for k, v in state.items(): + f = filter(lambda agent: agent.state.get(k, None) == v, f) + + yield from f + + def from_config(cfg: Dict[str, config.AgentConfig], env): ''' Agents are specified in groups. @@ -632,10 +661,22 @@ def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfi agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env) if cfg.distribution: n = cfg.n or len(env.topologies[cfg.topology]) - agents.update(_from_distro(cfg.distribution, n - len(agents), + target = n - len(agents) + agents.update(_from_distro(cfg.distribution, target, topology=cfg.topology or default.topology, default=default, env=env)) + assert len(agents) == n + if cfg.override: + for attrs in cfg.override: + if attrs.filter: + filtered = list(filter_group(agents, **attrs.filter)) + else: + filtered = list(agents) + + for agent in random.sample(filtered, attrs.n): + agent.state.update(attrs.state) + return agents @@ -650,10 +691,11 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: conf cls = serialization.deserialize(fixed.agent_class or default.agent_class) state = fixed.state.copy() state.update(default.state) - agents[agent_id] = cls(unique_id=agent_id, - model=env, - graph_name=fixed.topology or topology or default.topology, - **state) + agent = cls(unique_id=agent_id, + model=env, + graph_name=fixed.topology or topology or default.topology, + **state) + agents[agent.unique_id] = agent return agents @@ -671,31 +713,40 @@ def _from_distro(distro: List[config.AgentDistro], raise ValueError('You must provide a total number of agents, or the number of each type') n = sum(dist.n for dist in distro) + weights = list(dist.weight if dist.weight is not None else 1 for dist in distro) + minw = min(weights) + norm = list(weight / minw for weight in weights) + total = sum(norm) + chunk = n // total - total = sum((dist.weight if dist.weight is not None else 1) for dist in distro) - thres = {} - last = 0 - for i in sorted(distro, key=lambda x: x.weight): + # random.choices would be enough to get a weighted distribution. But it can vary a lot for smaller k + # So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect - cls = serialization.deserialize(i.agent_class or default.agent_class) - thres[(last, last + i.weight/total)] = (cls, i) + # Calculate how many times each has to appear + indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm))) - acc = 0 + # Complete with random agents following the original weight distribution + if len(indices) < n: + indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices)) - # using np.choice would be more efficient, but this allows us to use soil without - # numpy - for i in range(n): - r = random.random() - for (t, (cls, d)) in thres.items(): - if r >= t[0] and r <= t[1]: - agent_id = d.agent_id - if agent_id is None: - agent_id = env.next_id() + # Deserialize classes for efficiency + classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro) - state = d.state.copy() - state.update(default.state) - agents[agent_id] = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state) - break + # Add them in random order + random.shuffle(indices) + + + for idx in indices: + d = distro[idx] + cls = classes[idx] + agent_id = env.next_id() + state = d.state.copy() + state.update(default.state) + agent = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state) + assert agent.name is not None + assert agent.name != 'None' + assert agent.name + agents[agent.unique_id] = agent return agents diff --git a/soil/config.py b/soil/config.py index a47e5a2..eabb43f 100644 --- a/soil/config.py +++ b/soil/config.py @@ -76,8 +76,9 @@ class EnvConfig(BaseModel): class SingleAgentConfig(BaseModel): agent_class: Optional[Union[Type, str]] = None - agent_id: Optional[Union[str, int]] = None - topology: Optional[str] = None + agent_id: Optional[int] = None + topology: Optional[str] = 'default' + name: Optional[str] = None state: Optional[Dict[str, Any]] = {} class FixedAgentConfig(SingleAgentConfig): @@ -85,11 +86,16 @@ class FixedAgentConfig(SingleAgentConfig): @root_validator def validate_all(cls, values): - if 'agent_id' in values and values.get('n', 1) > 1: - raise ValueError("An agent_id can only be provided when there is only one agent") + if values.get('agent_id', None) is not None and values.get('n', 1) > 1: + print(values) + raise ValueError(f"An agent_id can only be provided when there is only one agent ({values.get('n')} given)") return values +class OverrideAgentConfig(FixedAgentConfig): + filter: Optional[Dict[str, Any]] = None + + class AgentDistro(SingleAgentConfig): weight: Optional[float] = 1 @@ -99,6 +105,7 @@ class AgentConfig(SingleAgentConfig): topology: Optional[str] = None distribution: Optional[List[AgentDistro]] = None fixed: Optional[List[FixedAgentConfig]] = None + override: Optional[List[OverrideAgentConfig]] = None @staticmethod def default(): @@ -118,7 +125,6 @@ class Config(BaseModel, extra=Extra.forbid): environment: EnvConfig = EnvConfig.default() agents: Optional[Dict[str, AgentConfig]] = {} - def convert_old(old, strict=True): ''' Try to convert old style configs into the new format. @@ -126,10 +132,6 @@ def convert_old(old, strict=True): This is still a work in progress and might not work in many cases. ''' - # TODO: translate states - - if strict and old.get('states', {}): - raise ValueError('Custom (i.e., manual) agent states cannot be translated to v2 configuration files. Please, convert your configuration file to the new format.') new = {} @@ -181,13 +183,16 @@ def convert_old(old, strict=True): for agent in old.get('environment_agents', []): agents['environment'] = {'distribution': [], 'fixed': []} - if 'agent_id' not in agent: - agents['environment']['distribution'].append(updated_agent(agent)) - else: + if 'agent_id' in agent: + agent['name'] = agent['agent_id'] + del agent['agent_id'] agents['environment']['fixed'].append(updated_agent(agent)) + else: + agents['environment']['distribution'].append(updated_agent(agent)) by_weight = [] fixed = [] + override = [] if 'network_agents' in old: agents['network']['topology'] = 'default' @@ -203,6 +208,20 @@ def convert_old(old, strict=True): agents['network']['topology'] = 'default' by_weight = [{'agent_type': old['agent_type']}] + + # TODO: translate states + if 'states' in old: + states = old['states'] + if isinstance(states, dict): + states = states.items() + else: + states = enumerate(states) + for (k, v) in states: + override.append({'filter': {'id': k}, + 'state': v + }) + + agents['network']['override'] = override agents['network']['fixed'] = fixed agents['network']['distribution'] = by_weight diff --git a/soil/environment.py b/soil/environment.py index f95b901..ddb46cb 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -79,7 +79,7 @@ class Environment(Model): @property def network_agents(self): - yield from self.agents(agent_type=agents.NetworkAgent, iterator=True) + yield from self.agents(agent_type=agents.NetworkAgent, iterator=False) @staticmethod diff --git a/soil/network.py b/soil/network.py new file mode 100644 index 0000000..0eb3688 --- /dev/null +++ b/soil/network.py @@ -0,0 +1,42 @@ +from typing import Dict +import os +import sys + +import networkx as nx + +from . import config, serialization, basestring + +def from_config(cfg: config.NetConfig, dir_path: str = None): + if not isinstance(cfg, config.NetConfig): + cfg = config.NetConfig(**cfg) + + if cfg.path: + path = cfg.path + if dir_path and not os.path.isabs(path): + path = os.path.join(dir_path, path) + extension = os.path.splitext(path)[1][1:] + kwargs = {} + if extension == 'gexf': + kwargs['version'] = '1.2draft' + kwargs['node_type'] = int + try: + method = getattr(nx.readwrite, 'read_' + extension) + except AttributeError: + raise AttributeError('Unknown format') + return method(path, **kwargs) + + if cfg.params: + net_args = cfg.params.dict() + net_gen = net_args.pop('generator') + + if dir_path not in sys.path: + sys.path.append(dir_path) + + method = serialization.deserializer(net_gen, + known_modules=['networkx.generators',]) + return method(**net_args) + + if isinstance(cfg.topology, basestring) or isinstance(cfg.topology, dict): + return nx.json_graph.node_link_graph(cfg.topology) + + return nx.Graph() diff --git a/soil/simulation.py b/soil/simulation.py index 3826459..0892731 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -42,7 +42,6 @@ class Simulation: config = Config(**cfg) if not config: raise ValueError("You need to specify a simulation configuration") - self.config = config @@ -151,6 +150,7 @@ class Simulation: # 'environment_agents': self.environment_agents, # }) # opts.update(kwargs) + print(self.config) env = Environment.from_config(self.config, trial_id=trial_id, **kwargs) return env diff --git a/tests/complete_converted.yml b/tests/complete_converted.yml new file mode 100644 index 0000000..ffb5a16 --- /dev/null +++ b/tests/complete_converted.yml @@ -0,0 +1,49 @@ +--- +version: '2' +general: + id: simple + group: tests + dir_path: "/tmp/" + num_trials: 3 + max_time: 100 + interval: 1 + seed: "CompleteSeed!" +topologies: + default: + params: + generator: complete_graph + n: 10 +agents: + default: + agent_class: CounterModel + state: + times: 1 + network: + topology: 'default' + distribution: + - agent_class: CounterModel + weight: 0.4 + state: + state_id: 0 + - agent_class: AggregatedCounter + weight: 0.6 + override: + - filter: + id: 0 + state: + name: 'The first node' + - filter: + id: 1 + state: + name: 'The second node' + + environment: + fixed: + - name: 'Environment Agent 1' + agent_class: CounterModel + state: + times: 10 +environment: + environment_class: Environment + params: + am_i_complete: true diff --git a/tests/old_complete.yml b/tests/old_complete.yml index 4382935..8609eb9 100644 --- a/tests/old_complete.yml +++ b/tests/old_complete.yml @@ -11,11 +11,11 @@ network_params: n: 10 network_agents: - agent_type: CounterModel - weight: 1 + weight: 0.4 state: state_id: 0 - agent_type: AggregatedCounter - weight: 0.2 + weight: 0.6 environment_agents: - agent_id: 'Environment Agent 1' agent_type: CounterModel diff --git a/tests/test_config.py b/tests/test_config.py index f4ad32e..7cba6af 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,7 +2,7 @@ from unittest import TestCase import os from os.path import join -from soil import simulation, serialization, config, network +from soil import simulation, serialization, config, network, agents ROOT = os.path.abspath(os.path.dirname(__file__)) EXAMPLES = join(ROOT, '..', 'examples') @@ -13,18 +13,22 @@ FORCE_TESTS = os.environ.get('FORCE_TESTS', '') class TestConfig(TestCase): def test_conversion(self): - new = serialization.load_file(join(ROOT, "complete_converted.yml"))[0] + expected = serialization.load_file(join(ROOT, "complete_converted.yml"))[0] old = serialization.load_file(join(ROOT, "old_complete.yml"))[0] converted_defaults = config.convert_old(old, strict=False) converted = converted_defaults.dict(skip_defaults=True) - def isequal(old, new): - if isinstance(old, dict): - for (k, v) in old.items(): - isequal(old[k], new[k]) - return - assert old == new - isequal(new, converted) + def isequal(a, b): + if isinstance(a, dict): + for (k, v) in a.items(): + if v: + isequal(a[k], b[k]) + else: + assert not b.get(k, None) + return + assert a == b + + isequal(converted, expected) def test_topology_config(self): netconfig = config.NetConfig(**{ @@ -60,6 +64,19 @@ class TestConfig(TestCase): assert env.agents[0].topology == env.topologies['default'] + def test_agents_from_config(self): + '''We test that the known complete configuration produces + the right agents in the right groups''' + cfg = serialization.load_file(join(ROOT, "complete_converted.yml"))[0] + s = simulation.from_config(cfg) + env = s.get_env() + assert len(env.topologies['default'].nodes) == 10 + assert len(env.agents('network')) == 10 + assert len(env.agents('environment')) == 1 + + assert sum(1 for a in env.agents('network') if isinstance(a, agents.CounterModel)) == 4 + assert sum(1 for a in env.agents('network') if isinstance(a, agents.AggregatedCounter)) == 6 + def make_example_test(path, cfg): def wrapped(self): root = os.getcwd() diff --git a/tests/test_examples.py b/tests/test_examples.py index a516c27..1cc4cca 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -18,10 +18,10 @@ def make_example_test(path, config): def wrapped(self): root = os.getcwd() for s in simulation.all_from_config(path): - iterations = s.config.max_time * s.config.num_trials + iterations = s.config.general.max_time * s.config.general.num_trials if iterations > 1000: - s.config.max_time = 100 - s.config.num_trials = 1 + s.config.general.max_time = 100 + s.config.general.num_trials = 1 if config.get('skip_test', False) and not FORCE_TESTS: self.skipTest('Example ignored.') envs = s.run_simulation(dry_run=True)