mirror of
https://github.com/gsi-upm/soil
synced 2025-08-23 19:52:19 +00:00
WIP
This commit is contained in:
@@ -7,9 +7,15 @@ class CounterModel(NetworkAgent):
|
||||
in each step and adds it to its state.
|
||||
"""
|
||||
|
||||
defaults = {
|
||||
'times': 0,
|
||||
'neighbors': 0,
|
||||
'total': 0
|
||||
}
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
total = len(list(self.get_agents()))
|
||||
total = len(list(self.env.agents))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = neighbors
|
||||
@@ -33,6 +39,6 @@ class AggregatedCounter(NetworkAgent):
|
||||
self['times'] += 1
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['neighbors'] += neighbors
|
||||
total = len(list(self.get_agents()))
|
||||
total = len(list(self.env.agents))
|
||||
self['total'] += total
|
||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||
|
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
from collections import OrderedDict, defaultdict
|
||||
from collections.abc import Mapping, Set
|
||||
from copy import deepcopy
|
||||
from functools import partial, wraps
|
||||
from itertools import islice
|
||||
from itertools import islice, chain
|
||||
import json
|
||||
import networkx as nx
|
||||
|
||||
from .. import serialization, utils, time
|
||||
from mesa import Agent as MesaAgent
|
||||
from typing import Dict, List
|
||||
|
||||
from tsih import Key
|
||||
from .. import serialization, utils, time, config
|
||||
|
||||
from mesa import Agent
|
||||
|
||||
|
||||
def as_node(agent):
|
||||
@@ -24,7 +25,7 @@ IGNORED_FIELDS = ('model', 'logger')
|
||||
class DeadAgent(Exception):
|
||||
pass
|
||||
|
||||
class BaseAgent(Agent):
|
||||
class BaseAgent(MesaAgent):
|
||||
"""
|
||||
A special type of Mesa Agent that:
|
||||
|
||||
@@ -47,9 +48,8 @@ class BaseAgent(Agent):
|
||||
):
|
||||
# Check for REQUIRED arguments
|
||||
# Initialize agent parameters
|
||||
if isinstance(unique_id, Agent):
|
||||
if isinstance(unique_id, MesaAgent):
|
||||
raise Exception()
|
||||
self._saved = set()
|
||||
super().__init__(unique_id=unique_id, model=model)
|
||||
self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||
|
||||
@@ -57,7 +57,7 @@ class BaseAgent(Agent):
|
||||
self.alive = True
|
||||
|
||||
self.interval = interval or self.get('interval', 1)
|
||||
self.logger = logging.getLogger(self.model.name).getChild(self.name)
|
||||
self.logger = logging.getLogger(self.model.id).getChild(self.name)
|
||||
|
||||
if hasattr(self, 'level'):
|
||||
self.logger.setLevel(self.level)
|
||||
@@ -66,6 +66,7 @@ class BaseAgent(Agent):
|
||||
setattr(self, k, deepcopy(v))
|
||||
|
||||
for (k, v) in kwargs.items():
|
||||
|
||||
setattr(self, k, v)
|
||||
|
||||
for (k, v) in getattr(self, 'defaults', {}).items():
|
||||
@@ -107,23 +108,7 @@ class BaseAgent(Agent):
|
||||
def environment_params(self, value):
|
||||
self.model.environment_params = value
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if not key.startswith('_') and key not in IGNORED_FIELDS:
|
||||
try:
|
||||
k = Key(t_step=self.now,
|
||||
dict_id=self.unique_id,
|
||||
key=key)
|
||||
self._saved.add(key)
|
||||
self.model[k] = value
|
||||
except AttributeError:
|
||||
pass
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, tuple):
|
||||
key, t_step = key
|
||||
k = Key(key=key, t_step=t_step, dict_id=self.unique_id)
|
||||
return self.model[k]
|
||||
return getattr(self, key)
|
||||
|
||||
def __delitem__(self, key):
|
||||
@@ -135,8 +120,11 @@ class BaseAgent(Agent):
|
||||
def __setitem__(self, key, value):
|
||||
setattr(self, key, value)
|
||||
|
||||
def keys(self):
|
||||
return (k for k in self.__dict__ if k[0] != '_')
|
||||
|
||||
def items(self):
|
||||
return ((k, getattr(self, k)) for k in self._saved)
|
||||
return ((k, v) for (k, v) in self.__dict__.items() if k[0] != '_')
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self[key] if key in self else default
|
||||
@@ -174,22 +162,32 @@ class BaseAgent(Agent):
|
||||
extra['agent_name'] = self.name
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
|
||||
|
||||
def debug(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.DEBUG, **kwargs)
|
||||
|
||||
def info(self, *args, **kwargs):
|
||||
return self.log(*args, level=logging.INFO, **kwargs)
|
||||
|
||||
# Alias
|
||||
# Agent = BaseAgent
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
return self.model.G
|
||||
def __init__(self,
|
||||
*args,
|
||||
graph_name: str,
|
||||
node_id: int = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.graph_name = graph_name
|
||||
self.topology = self.env.topologies[self.graph_name]
|
||||
self.node_id = node_id
|
||||
|
||||
@property
|
||||
def G(self):
|
||||
return self.model.G
|
||||
return self.model.topologies[self._topology]
|
||||
|
||||
def count_agents(self, **kwargs):
|
||||
return len(list(self.get_agents(**kwargs)))
|
||||
@@ -210,8 +208,7 @@ class NetworkAgent(BaseAgent):
|
||||
if limit_neighbors:
|
||||
agents = self.topology.neighbors(self.unique_id)
|
||||
|
||||
agents = self.model.get_agents(agents)
|
||||
return select(agents, **kwargs)
|
||||
return self.model.agents(ids=agents, **kwargs)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
@@ -229,7 +226,6 @@ class NetworkAgent(BaseAgent):
|
||||
|
||||
self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
|
||||
|
||||
def ego_search(self, steps=1, center=False, node=None, **kwargs):
|
||||
'''Get a list of nodes in the ego network of *node* of radius *steps*'''
|
||||
node = as_node(node if node is not None else self)
|
||||
@@ -311,7 +307,7 @@ class MetaFSM(type):
|
||||
cls.states = states
|
||||
|
||||
|
||||
class FSM(NetworkAgent, metaclass=MetaFSM):
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FSM, self).__init__(*args, **kwargs)
|
||||
if not hasattr(self, 'state_id'):
|
||||
@@ -537,32 +533,171 @@ def _definition_to_dict(definition, size=None, default_state=None):
|
||||
return agents
|
||||
|
||||
|
||||
def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents.
|
||||
"""
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
if agent_type is not None:
|
||||
try:
|
||||
agent_type = tuple(agent_type)
|
||||
except TypeError:
|
||||
agent_type = tuple([agent_type])
|
||||
__slots__ = ("_agents",)
|
||||
|
||||
f = agents
|
||||
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
def __init__(self, agents):
|
||||
self._agents = agents
|
||||
|
||||
if state_id is not None:
|
||||
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
||||
def __getstate__(self):
|
||||
return {"_agents": self._agents}
|
||||
|
||||
if agent_type is not None:
|
||||
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
||||
for k, v in kwargs.items():
|
||||
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
||||
def __setstate__(self, state):
|
||||
self._agents = state["_agents"]
|
||||
|
||||
if iterator:
|
||||
return f
|
||||
return f
|
||||
# Mapping methods
|
||||
def __len__(self):
|
||||
return sum(len(x) for x in self._agents.values())
|
||||
|
||||
def __iter__(self):
|
||||
return iter(chain.from_iterable(g.values() for g in self._agents.values()))
|
||||
|
||||
def __getitem__(self, agent_id):
|
||||
if isinstance(agent_id, slice):
|
||||
raise ValueError(f"Slicing is not supported")
|
||||
for group in self._agents.values():
|
||||
if agent_id in group:
|
||||
return group[agent_id]
|
||||
raise ValueError(f"Agent {agent_id} not found")
|
||||
|
||||
def filter(self, ids=None, groups=None, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs):
|
||||
|
||||
if state_id is not None and not isinstance(state_id, (tuple, list)):
|
||||
state_id = tuple([state_id])
|
||||
|
||||
agents = self._agents
|
||||
|
||||
if groups:
|
||||
agents = {(k,v) for (k, v) in agents.items() if k in groups}
|
||||
|
||||
if agent_type is not None:
|
||||
try:
|
||||
agent_type = tuple(agent_type)
|
||||
except TypeError:
|
||||
agent_type = tuple([agent_type])
|
||||
|
||||
if ids:
|
||||
agents = (v[aid] for v in agents.values() for aid in ids if aid in v)
|
||||
else:
|
||||
agents = (a for v in agents.values() for a in v.values())
|
||||
|
||||
f = agents
|
||||
if ignore:
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
if state_id is not None:
|
||||
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
||||
|
||||
if agent_type is not None:
|
||||
f = filter(lambda agent: isinstance(agent, agent_type), f)
|
||||
for k, v in kwargs.items():
|
||||
f = filter(lambda agent: agent.state.get(k, None) == v, f)
|
||||
|
||||
if iterator:
|
||||
return f
|
||||
return list(f)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.filter(*args, **kwargs)
|
||||
|
||||
def __contains__(self, agent_id):
|
||||
return any(agent_id in g for g in self._agents)
|
||||
|
||||
def __str__(self):
|
||||
return str(list(a.id for a in self))
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def from_config(cfg: Dict[str, config.AgentConfig], env):
|
||||
'''
|
||||
Agents are specified in groups.
|
||||
Each group can be specified in two ways, either through a fixed list in which each item has
|
||||
has the agent type, number of agents to create, and the other parameters, or through what we call
|
||||
an `agent distribution`, which is similar but instead of number of agents, it specifies the weight
|
||||
of each agent type.
|
||||
'''
|
||||
default = cfg.get('default', None)
|
||||
return {k: _group_from_config(c, default=default, env=env) for (k, c) in cfg.items() if k is not 'default'}
|
||||
|
||||
|
||||
def _group_from_config(cfg: config.AgentConfig, default: config.SingleAgentConfig, env):
|
||||
agents = {}
|
||||
if cfg.fixed is not None:
|
||||
agents = _from_fixed(cfg.fixed, topology=cfg.topology, default=default, env=env)
|
||||
if cfg.distribution:
|
||||
n = cfg.n or len(env.topologies[cfg.topology])
|
||||
agents.update(_from_distro(cfg.distribution, n - len(agents),
|
||||
topology=cfg.topology or default.topology,
|
||||
default=default,
|
||||
env=env))
|
||||
return agents
|
||||
|
||||
|
||||
def _from_fixed(lst: List[config.FixedAgentConfig], topology: str, default: config.SingleAgentConfig, env):
|
||||
agents = {}
|
||||
|
||||
for fixed in lst:
|
||||
agent_id = fixed.agent_id
|
||||
if agent_id is None:
|
||||
agent_id = env.next_id()
|
||||
|
||||
cls = serialization.deserialize(fixed.agent_class or default.agent_class)
|
||||
state = fixed.state.copy()
|
||||
state.update(default.state)
|
||||
agents[agent_id] = cls(unique_id=agent_id,
|
||||
model=env,
|
||||
graph_name=fixed.topology or topology or default.topology,
|
||||
**state)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def _from_distro(distro: List[config.AgentDistro],
|
||||
n: int,
|
||||
topology: str,
|
||||
default: config.SingleAgentConfig,
|
||||
env):
|
||||
|
||||
agents = {}
|
||||
|
||||
if n is None:
|
||||
if any(lambda dist: dist.n is None, distro):
|
||||
raise ValueError('You must provide a total number of agents, or the number of each type')
|
||||
n = sum(dist.n for dist in distro)
|
||||
|
||||
|
||||
total = sum((dist.weight if dist.weight is not None else 1) for dist in distro)
|
||||
thres = {}
|
||||
last = 0
|
||||
for i in sorted(distro, key=lambda x: x.weight):
|
||||
|
||||
cls = serialization.deserialize(i.agent_class or default.agent_class)
|
||||
thres[(last, last + i.weight/total)] = (cls, i)
|
||||
|
||||
acc = 0
|
||||
|
||||
# using np.choice would be more efficient, but this allows us to use soil without
|
||||
# numpy
|
||||
for i in range(n):
|
||||
r = random.random()
|
||||
for (t, (cls, d)) in thres.items():
|
||||
if r >= t[0] and r <= t[1]:
|
||||
agent_id = d.agent_id
|
||||
if agent_id is None:
|
||||
agent_id = env.next_id()
|
||||
|
||||
state = d.state.copy()
|
||||
state.update(default.state)
|
||||
agents[agent_id] = cls(unique_id=agent_id, model=env, graph_name=d.topology or topology or default.topology, **state)
|
||||
break
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
from .BassModel import *
|
||||
|
@@ -11,7 +11,7 @@ from pydantic import BaseModel, Extra
|
||||
class General(BaseModel):
|
||||
id: str = 'Unnamed Simulation'
|
||||
group: str = None
|
||||
dir_path: str = None
|
||||
dir_path: Optional[str] = None
|
||||
num_trials: int = 1
|
||||
max_time: float = 100
|
||||
interval: float = 1
|
||||
@@ -27,13 +27,13 @@ nodeId = int
|
||||
|
||||
class Node(BaseModel):
|
||||
id: nodeId
|
||||
state: Dict[str, Any]
|
||||
state: Optional[Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class Edge(BaseModel):
|
||||
source: nodeId
|
||||
target: nodeId
|
||||
value: float = 1
|
||||
value: Optional[float] = 1
|
||||
|
||||
|
||||
class Topology(BaseModel):
|
||||
@@ -75,46 +75,62 @@ class EnvConfig(BaseModel):
|
||||
|
||||
|
||||
class SingleAgentConfig(BaseModel):
|
||||
agent_class: Union[Type, str] = 'soil.Agent'
|
||||
agent_class: Optional[Union[Type, str]] = None
|
||||
agent_id: Optional[Union[str, int]] = None
|
||||
params: Dict[str, Any] = {}
|
||||
state: Dict[str, Any] = {}
|
||||
topology: Optional[str] = None
|
||||
state: Optional[Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class AgentDistro(SingleAgentConfig):
|
||||
weight: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
class FixedAgentConfig(SingleAgentConfig):
|
||||
n: Optional[int] = 1
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if 'weight' in values and 'count' in values:
|
||||
raise ValueError("You may either specify a weight in the distribution or an agent count")
|
||||
if 'agent_id' in values and values.get('n', 1) > 1:
|
||||
raise ValueError("An agent_id can only be provided when there is only one agent")
|
||||
return values
|
||||
|
||||
|
||||
class AgentDistro(SingleAgentConfig):
|
||||
weight: Optional[float] = 1
|
||||
|
||||
|
||||
class AgentConfig(SingleAgentConfig):
|
||||
n: Optional[int] = None
|
||||
topology: Optional[str] = None
|
||||
distribution: Optional[List[AgentDistro]] = None
|
||||
fixed: Optional[List[SingleAgentConfig]] = None
|
||||
fixed: Optional[List[FixedAgentConfig]] = None
|
||||
|
||||
@staticmethod
|
||||
def default():
|
||||
return AgentConfig()
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if 'distribution' in values and ('n' not in values and 'topology' not in values):
|
||||
raise ValueError("You need to provide the number of agents or a topology to extract the value from.")
|
||||
return values
|
||||
|
||||
|
||||
class Config(BaseModel, extra=Extra.forbid):
|
||||
version: Optional[str] = '1'
|
||||
general: General = General.default()
|
||||
network: Optional[NetConfig] = None
|
||||
topologies: Optional[Dict[str, NetConfig]] = {}
|
||||
environment: EnvConfig = EnvConfig.default()
|
||||
agents: Dict[str, AgentConfig] = {}
|
||||
agents: Optional[Dict[str, AgentConfig]] = {}
|
||||
|
||||
|
||||
def convert_old(old):
|
||||
def convert_old(old, strict=True):
|
||||
'''
|
||||
Try to convert old style configs into the new format.
|
||||
|
||||
This is still a work in progress and might not work in many cases.
|
||||
'''
|
||||
|
||||
# TODO: translate states
|
||||
|
||||
if strict and old.get('states', {}):
|
||||
raise ValueError('Custom (i.e., manual) agent states cannot be translated to v2 configuration files. Please, convert your configuration file to the new format.')
|
||||
|
||||
new = {}
|
||||
|
||||
|
||||
@@ -129,7 +145,10 @@ def convert_old(old):
|
||||
if k in old:
|
||||
general[k] = old[k]
|
||||
|
||||
network = {'group': 'network'}
|
||||
if 'name' in old:
|
||||
general['id'] = old['name']
|
||||
|
||||
network = {}
|
||||
|
||||
|
||||
if 'network_params' in old and old['network_params']:
|
||||
@@ -143,9 +162,6 @@ def convert_old(old):
|
||||
network['topology'] = old['topology']
|
||||
|
||||
agents = {
|
||||
'environment': {
|
||||
'fixed': []
|
||||
},
|
||||
'network': {},
|
||||
'default': {},
|
||||
}
|
||||
@@ -164,10 +180,31 @@ def convert_old(old):
|
||||
return newagent
|
||||
|
||||
for agent in old.get('environment_agents', []):
|
||||
agents['environment']['fixed'].append(updated_agent(agent))
|
||||
agents['environment'] = {'distribution': [], 'fixed': []}
|
||||
if 'agent_id' not in agent:
|
||||
agents['environment']['distribution'].append(updated_agent(agent))
|
||||
else:
|
||||
agents['environment']['fixed'].append(updated_agent(agent))
|
||||
|
||||
for agent in old.get('network_agents', []):
|
||||
agents['network'].setdefault('distribution', []).append(updated_agent(agent))
|
||||
by_weight = []
|
||||
fixed = []
|
||||
|
||||
if 'network_agents' in old:
|
||||
agents['network']['topology'] = 'default'
|
||||
|
||||
for agent in old['network_agents']:
|
||||
agent = updated_agent(agent)
|
||||
if 'agent_id' in agent:
|
||||
fixed.append(agent)
|
||||
else:
|
||||
by_weight.append(agent)
|
||||
|
||||
if 'agent_type' in old and (not fixed and not by_weight):
|
||||
agents['network']['topology'] = 'default'
|
||||
by_weight = [{'agent_type': old['agent_type']}]
|
||||
|
||||
agents['network']['fixed'] = fixed
|
||||
agents['network']['distribution'] = by_weight
|
||||
|
||||
environment = {'params': {}}
|
||||
if 'environment_class' in old:
|
||||
@@ -176,8 +213,8 @@ def convert_old(old):
|
||||
for (k, v) in old.get('environment_params', {}).items():
|
||||
environment['params'][k] = v
|
||||
|
||||
|
||||
return Config(general=general,
|
||||
network=network,
|
||||
return Config(version='2',
|
||||
general=general,
|
||||
topologies={'default': network},
|
||||
environment=environment,
|
||||
agents=agents)
|
||||
|
@@ -3,6 +3,10 @@ import os
|
||||
import sqlite3
|
||||
import math
|
||||
import random
|
||||
import logging
|
||||
|
||||
from typing import Dict
|
||||
from collections import namedtuple
|
||||
from time import time as current_time
|
||||
from copy import deepcopy
|
||||
from networkx.readwrite import json_graph
|
||||
@@ -12,9 +16,11 @@ import networkx as nx
|
||||
|
||||
from mesa import Model
|
||||
|
||||
from tsih import Record
|
||||
from . import serialization, agents, analysis, utils, time, config, network
|
||||
|
||||
|
||||
Record = namedtuple('Record', 'dict_id t_step key value')
|
||||
|
||||
from . import serialization, agents, analysis, utils, time, config
|
||||
|
||||
class Environment(Model):
|
||||
"""
|
||||
@@ -28,15 +34,17 @@ class Environment(Model):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
env_id,
|
||||
env_id='unnamed_env',
|
||||
seed='default',
|
||||
schedule=None,
|
||||
env_params=None,
|
||||
dir_path=None,
|
||||
**kwargs):
|
||||
interval=1,
|
||||
agents: Dict[str, config.AgentConfig] = {},
|
||||
topologies: Dict[str, config.NetConfig] = {},
|
||||
**env_params):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.current_id = -1
|
||||
|
||||
self.seed = '{}_{}'.format(seed, env_id)
|
||||
self.id = env_id
|
||||
@@ -51,25 +59,28 @@ class Environment(Model):
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
if isinstance(states, list):
|
||||
states = dict(enumerate(states))
|
||||
self.states = deepcopy(states) if states else {}
|
||||
self.default_state = deepcopy(default_state) or {}
|
||||
|
||||
|
||||
self.set_topology(topology=topology,
|
||||
network_params=network_params)
|
||||
|
||||
self.topologies = {}
|
||||
for (name, cfg) in topologies.items():
|
||||
self.set_topology(cfg=cfg,
|
||||
graph=name)
|
||||
self.agents = agents or {}
|
||||
|
||||
self.env_params = env_params or {}
|
||||
self.env_params.update(kwargs)
|
||||
|
||||
self.interval = interval
|
||||
self['SEED'] = seed
|
||||
|
||||
self.logger = utils.logger.getChild(self.id)
|
||||
|
||||
@property
|
||||
def topology(self):
|
||||
return self.topologies['default']
|
||||
|
||||
@property
|
||||
def network_agents(self):
|
||||
yield from self.agents(agent_type=agents.NetworkAgent, iterator=True)
|
||||
|
||||
self.logger = utils.logger.getChild(self.name)
|
||||
|
||||
@staticmethod
|
||||
def from_config(conf: config.Config, trial_id, **kwargs) -> Environment:
|
||||
@@ -92,39 +103,30 @@ class Environment(Model):
|
||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
||||
|
||||
|
||||
def set_topology(self, topology, network_params=None, dir_path=None):
|
||||
if topology is None:
|
||||
network_params = network_params or {}
|
||||
topology = serialization.load_network(network_params,
|
||||
dir_path=dir_path or self.dir_path)
|
||||
if not topology:
|
||||
topology = nx.Graph()
|
||||
self.G = nx.Graph(topology)
|
||||
def set_topology(self, cfg=None, dir_path=None, graph='default'):
|
||||
self.topologies[graph] = network.from_config(cfg, dir_path=dir_path)
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
for agents in self.agents.values():
|
||||
yield from agents
|
||||
return agents.AgentView(self._agents)
|
||||
|
||||
@agents.setter
|
||||
def agents(self, agents):
|
||||
self.agents = {}
|
||||
def agents(self, agents_def: Dict[str, config.AgentConfig]):
|
||||
self._agents = agents.from_config(agents_def, env=self)
|
||||
for d in self._agents.values():
|
||||
for a in d.values():
|
||||
self.schedule.add(a)
|
||||
|
||||
for (k, v) in agents.items():
|
||||
self.agents[k] = agents.from_config(v)
|
||||
for agent in self.agents.get('network', []):
|
||||
node = self.G.nodes[agent.unique_id]
|
||||
node['agent'] = agent
|
||||
|
||||
@property
|
||||
def network_agents(self):
|
||||
for i in self.G.nodes():
|
||||
node = self.G.nodes[i]
|
||||
if 'agent' in node:
|
||||
yield node['agent']
|
||||
# @property
|
||||
# def network_agents(self):
|
||||
# for i in self.G.nodes():
|
||||
# node = self.G.nodes[i]
|
||||
# if 'agent' in node:
|
||||
# yield node['agent']
|
||||
|
||||
def init_agent(self, agent_id, agent_definitions):
|
||||
node = self.G.nodes[agent_id]
|
||||
def init_agent(self, agent_id, agent_definitions, graph='default'):
|
||||
node = self.topologies[graph].nodes[agent_id]
|
||||
init = False
|
||||
state = dict(node)
|
||||
|
||||
@@ -145,8 +147,8 @@ class Environment(Model):
|
||||
return
|
||||
return self.set_agent(agent_id, agent_type, state)
|
||||
|
||||
def set_agent(self, agent_id, agent_type, state=None):
|
||||
node = self.G.nodes[agent_id]
|
||||
def set_agent(self, agent_id, agent_type, state=None, graph='default'):
|
||||
node = self.topologies[graph].nodes[agent_id]
|
||||
defstate = deepcopy(self.default_state) or {}
|
||||
defstate.update(self.states.get(agent_id, {}))
|
||||
defstate.update(node.get('state', {}))
|
||||
@@ -166,20 +168,20 @@ class Environment(Model):
|
||||
self.schedule.add(a)
|
||||
return a
|
||||
|
||||
def add_node(self, agent_type, state=None):
|
||||
agent_id = int(len(self.G.nodes()))
|
||||
self.G.add_node(agent_id)
|
||||
a = self.set_agent(agent_id, agent_type, state)
|
||||
def add_node(self, agent_type, state=None, graph='default'):
|
||||
agent_id = int(len(self.topologies[graph].nodes()))
|
||||
self.topologies[graph].add_node(agent_id)
|
||||
a = self.set_agent(agent_id, agent_type, state, graph=graph)
|
||||
a['visible'] = True
|
||||
return a
|
||||
|
||||
def add_edge(self, agent1, agent2, start=None, **attrs):
|
||||
def add_edge(self, agent1, agent2, start=None, graph='default', **attrs):
|
||||
if hasattr(agent1, 'id'):
|
||||
agent1 = agent1.id
|
||||
if hasattr(agent2, 'id'):
|
||||
agent2 = agent2.id
|
||||
start = start or self.now
|
||||
return self.G.add_edge(agent1, agent2, **attrs)
|
||||
return self.topologies[graph].add_edge(agent1, agent2, **attrs)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
if not self.logger.isEnabledFor(level):
|
||||
@@ -190,7 +192,7 @@ class Environment(Model):
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['unique_id'] = self.name
|
||||
extra['unique_id'] = self.id
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
def step(self):
|
||||
@@ -207,30 +209,6 @@ class Environment(Model):
|
||||
self.step()
|
||||
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
||||
self.schedule.time = until
|
||||
self._history.flush_cache()
|
||||
|
||||
def _save_state(self, now=None):
|
||||
serialization.logger.debug('Saving state @{}'.format(self.now))
|
||||
self._history.save_records(self.state_to_tuples(now=now))
|
||||
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, tuple):
|
||||
self._history.flush_cache()
|
||||
return self._history[key]
|
||||
|
||||
return self.environment_params[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
if isinstance(key, tuple):
|
||||
k = Key(*key)
|
||||
self._history.save_record(*k,
|
||||
value=value)
|
||||
return
|
||||
self.environment_params[key] = value
|
||||
self._history.save_record(dict_id='env',
|
||||
t_step=self.now,
|
||||
key=key,
|
||||
value=value)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.env_params
|
||||
@@ -248,14 +226,6 @@ class Environment(Model):
|
||||
def __setitem__(self, key, value):
|
||||
return self.env_params.__setitem__(key, value)
|
||||
|
||||
def get_agent(self, agent_id):
|
||||
return self.G.nodes[agent_id]['agent']
|
||||
|
||||
def get_agents(self, nodes=None):
|
||||
if nodes is None:
|
||||
return self.agents
|
||||
return (self.G.nodes[i]['agent'] for i in nodes)
|
||||
|
||||
def _agent_to_tuples(self, agent, now=None):
|
||||
if now is None:
|
||||
now = self.now
|
||||
@@ -270,7 +240,7 @@ class Environment(Model):
|
||||
now = self.now
|
||||
|
||||
if agent_id:
|
||||
agent = self.get_agent(agent_id)
|
||||
agent = self.agents[agent_id]
|
||||
yield from self._agent_to_tuples(agent, now)
|
||||
return
|
||||
|
||||
@@ -283,4 +253,5 @@ class Environment(Model):
|
||||
yield from self._agent_to_tuples(agent, now)
|
||||
|
||||
|
||||
|
||||
SoilEnvironment = Environment
|
||||
|
@@ -16,39 +16,39 @@ from jinja2 import Template
|
||||
logger = logging.getLogger('soil')
|
||||
|
||||
|
||||
def load_network(network_params, dir_path=None):
|
||||
G = nx.Graph()
|
||||
# def load_network(network_params, dir_path=None):
|
||||
# G = nx.Graph()
|
||||
|
||||
if not network_params:
|
||||
return G
|
||||
# if not network_params:
|
||||
# return G
|
||||
|
||||
if 'path' in network_params:
|
||||
path = network_params['path']
|
||||
if dir_path and not os.path.isabs(path):
|
||||
path = os.path.join(dir_path, path)
|
||||
extension = os.path.splitext(path)[1][1:]
|
||||
kwargs = {}
|
||||
if extension == 'gexf':
|
||||
kwargs['version'] = '1.2draft'
|
||||
kwargs['node_type'] = int
|
||||
try:
|
||||
method = getattr(nx.readwrite, 'read_' + extension)
|
||||
except AttributeError:
|
||||
raise AttributeError('Unknown format')
|
||||
G = method(path, **kwargs)
|
||||
# if 'path' in network_params:
|
||||
# path = network_params['path']
|
||||
# if dir_path and not os.path.isabs(path):
|
||||
# path = os.path.join(dir_path, path)
|
||||
# extension = os.path.splitext(path)[1][1:]
|
||||
# kwargs = {}
|
||||
# if extension == 'gexf':
|
||||
# kwargs['version'] = '1.2draft'
|
||||
# kwargs['node_type'] = int
|
||||
# try:
|
||||
# method = getattr(nx.readwrite, 'read_' + extension)
|
||||
# except AttributeError:
|
||||
# raise AttributeError('Unknown format')
|
||||
# G = method(path, **kwargs)
|
||||
|
||||
elif 'generator' in network_params:
|
||||
net_args = network_params.copy()
|
||||
net_gen = net_args.pop('generator')
|
||||
# elif 'generator' in network_params:
|
||||
# net_args = network_params.copy()
|
||||
# net_gen = net_args.pop('generator')
|
||||
|
||||
if dir_path not in sys.path:
|
||||
sys.path.append(dir_path)
|
||||
# if dir_path not in sys.path:
|
||||
# sys.path.append(dir_path)
|
||||
|
||||
method = deserializer(net_gen,
|
||||
known_modules=['networkx.generators',])
|
||||
G = method(**net_args)
|
||||
# method = deserializer(net_gen,
|
||||
# known_modules=['networkx.generators',])
|
||||
# G = method(**net_args)
|
||||
|
||||
return G
|
||||
# return G
|
||||
|
||||
|
||||
def load_file(infile):
|
||||
@@ -122,8 +122,8 @@ def load_files(*patterns, **kwargs):
|
||||
for i in glob(pattern, **kwargs):
|
||||
for config in load_file(i):
|
||||
path = os.path.abspath(i)
|
||||
if 'dir_path' not in config:
|
||||
config['dir_path'] = os.path.dirname(path)
|
||||
if 'general' in config and 'dir_path' not in config['general']:
|
||||
config['general']['dir_path'] = os.path.dirname(path)
|
||||
yield config, path
|
||||
|
||||
|
||||
|
@@ -96,7 +96,7 @@ class Simulation:
|
||||
stat.sim_start()
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.start()
|
||||
exporter.sim_start()
|
||||
|
||||
for env in self._run_sync_or_async(parallel=parallel,
|
||||
log_level=log_level,
|
||||
@@ -107,7 +107,7 @@ class Simulation:
|
||||
|
||||
collected = list(stat.trial_end(env) for stat in stats)
|
||||
|
||||
saved = self._update_stats(collected, t_step=env.now, trial_id=env.name)
|
||||
saved = self._update_stats(collected, t_step=env.now, trial_id=env.id)
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.trial_end(env, saved)
|
||||
@@ -117,6 +117,9 @@ class Simulation:
|
||||
collected = list(stat.end() for stat in stats)
|
||||
saved = self._update_stats(collected)
|
||||
|
||||
for stat in stats:
|
||||
stat.sim_end()
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.sim_end(saved)
|
||||
|
||||
@@ -131,24 +134,24 @@ class Simulation:
|
||||
|
||||
def get_env(self, trial_id=0, **kwargs):
|
||||
'''Create an environment for a trial of the simulation'''
|
||||
opts = self.environment_params.copy()
|
||||
opts.update({
|
||||
'name': '{}_trial_{}'.format(self.name, trial_id),
|
||||
'topology': self.topology.copy(),
|
||||
'network_params': self.network_params,
|
||||
'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||
'initial_time': 0,
|
||||
'interval': self.interval,
|
||||
'network_agents': self.network_agents,
|
||||
'initial_time': 0,
|
||||
'states': self.states,
|
||||
'dir_path': self.dir_path,
|
||||
'default_state': self.default_state,
|
||||
'history': bool(self._history),
|
||||
'environment_agents': self.environment_agents,
|
||||
})
|
||||
opts.update(kwargs)
|
||||
env = self.environment_class(**opts)
|
||||
# opts = self.environment_params.copy()
|
||||
# opts.update({
|
||||
# 'name': '{}_trial_{}'.format(self.name, trial_id),
|
||||
# 'topology': self.topology.copy(),
|
||||
# 'network_params': self.network_params,
|
||||
# 'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||
# 'initial_time': 0,
|
||||
# 'interval': self.interval,
|
||||
# 'network_agents': self.network_agents,
|
||||
# 'initial_time': 0,
|
||||
# 'states': self.states,
|
||||
# 'dir_path': self.dir_path,
|
||||
# 'default_state': self.default_state,
|
||||
# 'history': bool(self._history),
|
||||
# 'environment_agents': self.environment_agents,
|
||||
# })
|
||||
# opts.update(kwargs)
|
||||
env = Environment.from_config(self.config, trial_id=trial_id, **kwargs)
|
||||
return env
|
||||
|
||||
def run_trial(self, trial_id=None, until=None, log_level=logging.INFO, **opts):
|
||||
@@ -162,7 +165,7 @@ class Simulation:
|
||||
# Set-up trial environment and graph
|
||||
until = until or self.config.general.max_time
|
||||
|
||||
env = Environment.from_config(self.config, trial_id=trial_id)
|
||||
env = self.get_env(trial_id, **opts)
|
||||
# Set up agents on nodes
|
||||
with utils.timer('Simulation {} trial {}'.format(self.config.general.id, trial_id)):
|
||||
env.run(until)
|
||||
@@ -181,21 +184,31 @@ class Simulation:
|
||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||
return ex
|
||||
|
||||
def to_dict(self):
|
||||
return self.config.dict()
|
||||
|
||||
def to_yaml(self):
|
||||
return yaml.dump(self.config.dict())
|
||||
|
||||
|
||||
def all_from_config(config):
|
||||
configs = list(serialization.load_config(config))
|
||||
for config, _ in configs:
|
||||
sim = Simulation(**config)
|
||||
for config, path in configs:
|
||||
if config.get('version', '1') == '1':
|
||||
config = convert_old(config)
|
||||
if not isinstance(config, Config):
|
||||
config = Config(**config)
|
||||
if not config.general.dir_path:
|
||||
config.general.dir_path = os.path.dirname(path)
|
||||
sim = Simulation(config=config)
|
||||
yield sim
|
||||
|
||||
|
||||
def from_config(conf_or_path):
|
||||
config = list(serialization.load_config(conf_or_path))
|
||||
if len(config) > 1:
|
||||
lst = list(all_from_config(conf_or_path))
|
||||
if len(lst) > 1:
|
||||
raise AttributeError('Provide only one configuration')
|
||||
config = config[0][0]
|
||||
sim = Simulation(**config)
|
||||
return sim
|
||||
return lst[0]
|
||||
|
||||
def from_old_config(conf_or_path):
|
||||
config = list(serialization.load_config(conf_or_path))
|
||||
@@ -206,13 +219,7 @@ def from_old_config(conf_or_path):
|
||||
|
||||
|
||||
def run_from_config(*configs, **kwargs):
|
||||
for config_def in configs:
|
||||
# logger.info("Found {} config(s)".format(len(ls)))
|
||||
for config, path in serialization.load_config(config_def):
|
||||
name = config.general.id
|
||||
logger.info("Using config(s): {name}".format(name=name))
|
||||
|
||||
dir_path = config.general.dir_path or os.path.dirname(path)
|
||||
sim = Simulation(dir_path=dir_path,
|
||||
**config)
|
||||
sim.run_simulation(**kwargs)
|
||||
for sim in all_from_config(configs):
|
||||
name = config.general.id
|
||||
logger.info("Using config(s): {name}".format(name=name))
|
||||
sim.run_simulation(**kwargs)
|
||||
|
@@ -3,7 +3,7 @@ from queue import Empty
|
||||
from heapq import heappush, heappop
|
||||
import math
|
||||
from .utils import logger
|
||||
from mesa import Agent
|
||||
from mesa import Agent as MesaAgent
|
||||
|
||||
|
||||
INFINITY = float('inf')
|
||||
@@ -41,7 +41,7 @@ class TimedActivation(BaseScheduler):
|
||||
self._queue = []
|
||||
self.next_time = 0
|
||||
|
||||
def add(self, agent: Agent):
|
||||
def add(self, agent: MesaAgent):
|
||||
if agent.unique_id not in self._agents:
|
||||
heappush(self._queue, (self.time, agent.unique_id))
|
||||
super().add(agent)
|
||||
|
Reference in New Issue
Block a user