|
|
|
@ -20,13 +20,13 @@ from typing import Dict, List
|
|
|
|
|
from .. import serialization, utils, time, config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def as_node(agent):
|
|
|
|
|
if isinstance(agent, BaseAgent):
|
|
|
|
|
return agent.id
|
|
|
|
|
return agent
|
|
|
|
|
|
|
|
|
|
IGNORED_FIELDS = ('model', 'logger')
|
|
|
|
|
|
|
|
|
|
IGNORED_FIELDS = ("model", "logger")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeadAgent(Exception):
|
|
|
|
@ -43,13 +43,18 @@ class MetaAgent(ABCMeta):
|
|
|
|
|
defaults.update(i._defaults)
|
|
|
|
|
|
|
|
|
|
new_nmspc = {
|
|
|
|
|
'_defaults': defaults,
|
|
|
|
|
"_defaults": defaults,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for attr, func in namespace.items():
|
|
|
|
|
if isinstance(func, types.FunctionType) or isinstance(func, property) or isinstance(func, classmethod) or attr[0] == '_':
|
|
|
|
|
if (
|
|
|
|
|
isinstance(func, types.FunctionType)
|
|
|
|
|
or isinstance(func, property)
|
|
|
|
|
or isinstance(func, classmethod)
|
|
|
|
|
or attr[0] == "_"
|
|
|
|
|
):
|
|
|
|
|
new_nmspc[attr] = func
|
|
|
|
|
elif attr == 'defaults':
|
|
|
|
|
elif attr == "defaults":
|
|
|
|
|
defaults.update(func)
|
|
|
|
|
else:
|
|
|
|
|
defaults[attr] = copy(func)
|
|
|
|
@ -69,12 +74,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
unique_id,
|
|
|
|
|
model,
|
|
|
|
|
name=None,
|
|
|
|
|
interval=None,
|
|
|
|
|
**kwargs):
|
|
|
|
|
def __init__(self, unique_id, model, name=None, interval=None, **kwargs):
|
|
|
|
|
# Check for REQUIRED arguments
|
|
|
|
|
# Initialize agent parameters
|
|
|
|
|
if isinstance(unique_id, MesaAgent):
|
|
|
|
@ -82,16 +82,19 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
assert isinstance(unique_id, int)
|
|
|
|
|
super().__init__(unique_id=unique_id, model=model)
|
|
|
|
|
|
|
|
|
|
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
|
|
|
|
|
|
|
|
|
self.name = (
|
|
|
|
|
str(name) if name else "{}[{}]".format(type(self).__name__, self.unique_id)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.alive = True
|
|
|
|
|
|
|
|
|
|
self.interval = interval or self.get('interval', 1)
|
|
|
|
|
logger = utils.logger.getChild(getattr(self.model, 'id', self.model)).getChild(self.name)
|
|
|
|
|
self.logger = logging.LoggerAdapter(logger, {'agent_name': self.name})
|
|
|
|
|
self.interval = interval or self.get("interval", 1)
|
|
|
|
|
logger = utils.logger.getChild(getattr(self.model, "id", self.model)).getChild(
|
|
|
|
|
self.name
|
|
|
|
|
)
|
|
|
|
|
self.logger = logging.LoggerAdapter(logger, {"agent_name": self.name})
|
|
|
|
|
|
|
|
|
|
if hasattr(self, 'level'):
|
|
|
|
|
if hasattr(self, "level"):
|
|
|
|
|
self.logger.setLevel(self.level)
|
|
|
|
|
|
|
|
|
|
for (k, v) in self._defaults.items():
|
|
|
|
@ -117,20 +120,22 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
def from_dict(cls, model, attrs, warn_extra=True):
|
|
|
|
|
ignored = {}
|
|
|
|
|
args = {}
|
|
|
|
|
for k, v in attrs.items():
|
|
|
|
|
for k, v in attrs.items():
|
|
|
|
|
if k in inspect.signature(cls).parameters:
|
|
|
|
|
args[k] = v
|
|
|
|
|
else:
|
|
|
|
|
ignored[k] = v
|
|
|
|
|
if ignored and warn_extra:
|
|
|
|
|
utils.logger.info(f'Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }')
|
|
|
|
|
utils.logger.info(
|
|
|
|
|
f"Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }"
|
|
|
|
|
)
|
|
|
|
|
return cls(model=model, **args)
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, key):
|
|
|
|
|
try:
|
|
|
|
|
return getattr(self, key)
|
|
|
|
|
except AttributeError:
|
|
|
|
|
raise KeyError(f'key {key} not found in agent')
|
|
|
|
|
raise KeyError(f"key {key} not found in agent")
|
|
|
|
|
|
|
|
|
|
def __delitem__(self, key):
|
|
|
|
|
return delattr(self, key)
|
|
|
|
@ -148,7 +153,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
return self.items()
|
|
|
|
|
|
|
|
|
|
def keys(self):
|
|
|
|
|
return (k for k in self.__dict__ if k[0] != '_' and k not in IGNORED_FIELDS)
|
|
|
|
|
return (k for k in self.__dict__ if k[0] != "_" and k not in IGNORED_FIELDS)
|
|
|
|
|
|
|
|
|
|
def items(self, keys=None, skip=None):
|
|
|
|
|
keys = keys if keys is not None else self.keys()
|
|
|
|
@ -156,7 +161,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
if skip:
|
|
|
|
|
return filter(lambda x: x[0] not in skip, it)
|
|
|
|
|
return it
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get(self, key, default=None):
|
|
|
|
|
return self[key] if key in self else default
|
|
|
|
|
|
|
|
|
@ -169,7 +174,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def die(self):
|
|
|
|
|
self.info(f'agent dying')
|
|
|
|
|
self.info(f"agent dying")
|
|
|
|
|
self.alive = False
|
|
|
|
|
return time.NEVER
|
|
|
|
|
|
|
|
|
@ -186,9 +191,9 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
for k, v in kwargs:
|
|
|
|
|
message += " {k}={v} ".format(k, v)
|
|
|
|
|
extra = {}
|
|
|
|
|
extra['now'] = self.now
|
|
|
|
|
extra['unique_id'] = self.unique_id
|
|
|
|
|
extra['agent_name'] = self.name
|
|
|
|
|
extra["now"] = self.now
|
|
|
|
|
extra["unique_id"] = self.unique_id
|
|
|
|
|
extra["agent_name"] = self.name
|
|
|
|
|
return self.logger.log(level, message, extra=extra)
|
|
|
|
|
|
|
|
|
|
def debug(self, *args, **kwargs):
|
|
|
|
@ -214,10 +219,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
content = dict(self.items(keys=keys))
|
|
|
|
|
if pretty and content:
|
|
|
|
|
d = content
|
|
|
|
|
content = '\n'
|
|
|
|
|
content = "\n"
|
|
|
|
|
for k, v in d.items():
|
|
|
|
|
content += f'- {k}: {v}\n'
|
|
|
|
|
content = textwrap.indent(content, ' ')
|
|
|
|
|
content += f"- {k}: {v}\n"
|
|
|
|
|
content = textwrap.indent(content, " ")
|
|
|
|
|
return f"{repr(self)}{content}"
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
@ -225,7 +230,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NetworkAgent(BaseAgent):
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, topology, node_id, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
@ -248,18 +252,21 @@ class NetworkAgent(BaseAgent):
|
|
|
|
|
def node(self):
|
|
|
|
|
return self.topology.nodes[self.node_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
|
|
|
|
unique_ids = None
|
|
|
|
|
if isinstance(unique_id, list):
|
|
|
|
|
unique_ids = set(unique_id)
|
|
|
|
|
elif unique_id is not None:
|
|
|
|
|
unique_ids = set([unique_id,])
|
|
|
|
|
unique_ids = set(
|
|
|
|
|
[
|
|
|
|
|
unique_id,
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if limit_neighbors:
|
|
|
|
|
neighbor_ids = set()
|
|
|
|
|
for node_id in self.G.neighbors(self.node_id):
|
|
|
|
|
if self.G.nodes[node_id].get('agent') is not None:
|
|
|
|
|
if self.G.nodes[node_id].get("agent") is not None:
|
|
|
|
|
neighbor_ids.add(node_id)
|
|
|
|
|
if unique_ids:
|
|
|
|
|
unique_ids = unique_ids & neighbor_ids
|
|
|
|
@ -272,7 +279,9 @@ class NetworkAgent(BaseAgent):
|
|
|
|
|
|
|
|
|
|
def subgraph(self, center=True, **kwargs):
|
|
|
|
|
include = [self] if center else []
|
|
|
|
|
G = self.G.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
|
|
|
|
G = self.G.subgraph(
|
|
|
|
|
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
|
|
|
|
)
|
|
|
|
|
return G
|
|
|
|
|
|
|
|
|
|
def remove_node(self):
|
|
|
|
@ -280,11 +289,19 @@ class NetworkAgent(BaseAgent):
|
|
|
|
|
|
|
|
|
|
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
|
|
|
|
if self.node_id not in self.G.nodes(data=False):
|
|
|
|
|
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"{} not in list of existing agents in the network".format(
|
|
|
|
|
self.unique_id
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
if other.node_id not in self.G.nodes(data=False):
|
|
|
|
|
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"{} not in list of existing agents in the network".format(other)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.G.add_edge(self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
|
|
|
|
self.G.add_edge(
|
|
|
|
|
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def die(self, remove=True):
|
|
|
|
|
if remove:
|
|
|
|
@ -294,11 +311,11 @@ class NetworkAgent(BaseAgent):
|
|
|
|
|
|
|
|
|
|
def state(name=None):
|
|
|
|
|
def decorator(func, name=None):
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
A state function should return either a state id, or a tuple (state_id, when)
|
|
|
|
|
The default value for state_id is the current state id.
|
|
|
|
|
The default value for when is the interval defined in the environment.
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
if inspect.isgeneratorfunction(func):
|
|
|
|
|
orig_func = func
|
|
|
|
|
|
|
|
|
@ -348,32 +365,38 @@ class MetaFSM(MetaAgent):
|
|
|
|
|
|
|
|
|
|
# Add new states
|
|
|
|
|
for attr, func in namespace.items():
|
|
|
|
|
if hasattr(func, 'id'):
|
|
|
|
|
if hasattr(func, "id"):
|
|
|
|
|
if func.is_default:
|
|
|
|
|
default_state = func
|
|
|
|
|
states[func.id] = func
|
|
|
|
|
|
|
|
|
|
namespace.update({
|
|
|
|
|
'_default_state': default_state,
|
|
|
|
|
'_states': states,
|
|
|
|
|
})
|
|
|
|
|
namespace.update(
|
|
|
|
|
{
|
|
|
|
|
"_default_state": default_state,
|
|
|
|
|
"_states": states,
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return super(MetaFSM, mcls).__new__(mcls=mcls, name=name, bases=bases, namespace=namespace)
|
|
|
|
|
return super(MetaFSM, mcls).__new__(
|
|
|
|
|
mcls=mcls, name=name, bases=bases, namespace=namespace
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FSM(BaseAgent, metaclass=MetaFSM):
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super(FSM, self).__init__(*args, **kwargs)
|
|
|
|
|
if not hasattr(self, 'state_id'):
|
|
|
|
|
if not hasattr(self, "state_id"):
|
|
|
|
|
if not self._default_state:
|
|
|
|
|
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"No default state specified for {}".format(self.unique_id)
|
|
|
|
|
)
|
|
|
|
|
self.state_id = self._default_state.id
|
|
|
|
|
|
|
|
|
|
self._coroutine = None
|
|
|
|
|
self.set_state(self.state_id)
|
|
|
|
|
|
|
|
|
|
def step(self):
|
|
|
|
|
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
|
|
|
|
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
|
|
|
|
default_interval = super().step()
|
|
|
|
|
|
|
|
|
|
next_state = self._states[self.state_id](self)
|
|
|
|
@ -386,7 +409,9 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|
|
|
|
elif len(when) == 1:
|
|
|
|
|
when = when[0]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('Too many values returned. Only state (and time) allowed')
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Too many values returned. Only state (and time) allowed"
|
|
|
|
|
)
|
|
|
|
|
except TypeError:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
@ -396,10 +421,10 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|
|
|
|
return when or default_interval
|
|
|
|
|
|
|
|
|
|
def set_state(self, state, when=None):
|
|
|
|
|
if hasattr(state, 'id'):
|
|
|
|
|
if hasattr(state, "id"):
|
|
|
|
|
state = state.id
|
|
|
|
|
if state not in self._states:
|
|
|
|
|
raise ValueError('{} is not a valid state'.format(state))
|
|
|
|
|
raise ValueError("{} is not a valid state".format(state))
|
|
|
|
|
self.state_id = state
|
|
|
|
|
if when is not None:
|
|
|
|
|
self.model.schedule.add(self, when=when)
|
|
|
|
@ -414,23 +439,22 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prob(prob, random):
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
A true/False uniform distribution with a given probability.
|
|
|
|
|
To be used like this:
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if prob(0.3):
|
|
|
|
|
do_something()
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
r = random.random()
|
|
|
|
|
return r < prob
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_distribution(network_agents=None,
|
|
|
|
|
agent_class=None):
|
|
|
|
|
'''
|
|
|
|
|
def calculate_distribution(network_agents=None, agent_class=None):
|
|
|
|
|
"""
|
|
|
|
|
Calculate the threshold values (thresholds for a uniform distribution)
|
|
|
|
|
of an agent distribution given the weights of each agent type.
|
|
|
|
|
|
|
|
|
@ -453,26 +477,28 @@ def calculate_distribution(network_agents=None,
|
|
|
|
|
|
|
|
|
|
In this example, 20% of the nodes will be marked as type
|
|
|
|
|
'agent_class_1'.
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
if network_agents:
|
|
|
|
|
network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')]
|
|
|
|
|
network_agents = [
|
|
|
|
|
deepcopy(agent) for agent in network_agents if not hasattr(agent, "id")
|
|
|
|
|
]
|
|
|
|
|
elif agent_class:
|
|
|
|
|
network_agents = [{'agent_class': agent_class}]
|
|
|
|
|
network_agents = [{"agent_class": agent_class}]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('Specify a distribution or a default agent type')
|
|
|
|
|
raise ValueError("Specify a distribution or a default agent type")
|
|
|
|
|
|
|
|
|
|
# Fix missing weights and incompatible types
|
|
|
|
|
for x in network_agents:
|
|
|
|
|
x['weight'] = float(x.get('weight', 1))
|
|
|
|
|
x["weight"] = float(x.get("weight", 1))
|
|
|
|
|
|
|
|
|
|
# Calculate the thresholds
|
|
|
|
|
total = sum(x['weight'] for x in network_agents)
|
|
|
|
|
total = sum(x["weight"] for x in network_agents)
|
|
|
|
|
acc = 0
|
|
|
|
|
for v in network_agents:
|
|
|
|
|
if 'ids' in v:
|
|
|
|
|
if "ids" in v:
|
|
|
|
|
continue
|
|
|
|
|
upper = acc + (v['weight']/total)
|
|
|
|
|
v['threshold'] = [acc, upper]
|
|
|
|
|
upper = acc + (v["weight"] / total)
|
|
|
|
|
v["threshold"] = [acc, upper]
|
|
|
|
|
acc = upper
|
|
|
|
|
return network_agents
|
|
|
|
|
|
|
|
|
@ -480,28 +506,29 @@ def calculate_distribution(network_agents=None,
|
|
|
|
|
def serialize_type(agent_class, known_modules=[], **kwargs):
|
|
|
|
|
if isinstance(agent_class, str):
|
|
|
|
|
return agent_class
|
|
|
|
|
known_modules += ['soil.agents']
|
|
|
|
|
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
|
|
|
|
known_modules += ["soil.agents"]
|
|
|
|
|
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[
|
|
|
|
|
1
|
|
|
|
|
] # Get the name of the class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def serialize_definition(network_agents, known_modules=[]):
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
When serializing an agent distribution, remove the thresholds, in order
|
|
|
|
|
to avoid cluttering the YAML definition file.
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
d = deepcopy(list(network_agents))
|
|
|
|
|
for v in d:
|
|
|
|
|
if 'threshold' in v:
|
|
|
|
|
del v['threshold']
|
|
|
|
|
v['agent_class'] = serialize_type(v['agent_class'],
|
|
|
|
|
known_modules=known_modules)
|
|
|
|
|
if "threshold" in v:
|
|
|
|
|
del v["threshold"]
|
|
|
|
|
v["agent_class"] = serialize_type(v["agent_class"], known_modules=known_modules)
|
|
|
|
|
return d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def deserialize_type(agent_class, known_modules=[]):
|
|
|
|
|
if not isinstance(agent_class, str):
|
|
|
|
|
return agent_class
|
|
|
|
|
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
|
|
|
|
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
|
|
|
|
agent_class = serialization.deserializer(agent_class, known_modules=known)
|
|
|
|
|
return agent_class
|
|
|
|
|
|
|
|
|
@ -509,12 +536,12 @@ def deserialize_type(agent_class, known_modules=[]):
|
|
|
|
|
def deserialize_definition(ind, **kwargs):
|
|
|
|
|
d = deepcopy(ind)
|
|
|
|
|
for v in d:
|
|
|
|
|
v['agent_class'] = deserialize_type(v['agent_class'], **kwargs)
|
|
|
|
|
v["agent_class"] = deserialize_type(v["agent_class"], **kwargs)
|
|
|
|
|
return d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _validate_states(states, topology):
|
|
|
|
|
'''Validate states to avoid ignoring states during initialization'''
|
|
|
|
|
"""Validate states to avoid ignoring states during initialization"""
|
|
|
|
|
states = states or []
|
|
|
|
|
if isinstance(states, dict):
|
|
|
|
|
for x in states:
|
|
|
|
@ -525,7 +552,7 @@ def _validate_states(states, topology):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_agent_classs(ind, to_string=False, **kwargs):
|
|
|
|
|
'''Convenience method to allow specifying agents by class or class name.'''
|
|
|
|
|
"""Convenience method to allow specifying agents by class or class name."""
|
|
|
|
|
if to_string:
|
|
|
|
|
return serialize_definition(ind, **kwargs)
|
|
|
|
|
return deserialize_definition(ind, **kwargs)
|
|
|
|
@ -609,12 +636,10 @@ def _convert_agent_classs(ind, to_string=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentView(Mapping, Set):
|
|
|
|
|
"""A lazy-loaded list of agents.
|
|
|
|
|
"""
|
|
|
|
|
"""A lazy-loaded list of agents."""
|
|
|
|
|
|
|
|
|
|
__slots__ = ("_agents",)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, agents):
|
|
|
|
|
self._agents = agents
|
|
|
|
|
|
|
|
|
@ -657,11 +682,20 @@ class AgentView(Mapping, Set):
|
|
|
|
|
return f"{self.__class__.__name__}({self})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None,
|
|
|
|
|
limit=None, **kwargs):
|
|
|
|
|
'''
|
|
|
|
|
def filter_agents(
|
|
|
|
|
agents,
|
|
|
|
|
*id_args,
|
|
|
|
|
unique_id=None,
|
|
|
|
|
state_id=None,
|
|
|
|
|
agent_class=None,
|
|
|
|
|
ignore=None,
|
|
|
|
|
state=None,
|
|
|
|
|
limit=None,
|
|
|
|
|
**kwargs,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(agents, dict)
|
|
|
|
|
|
|
|
|
|
ids = []
|
|
|
|
@ -694,7 +728,7 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N
|
|
|
|
|
f = filter(lambda x: x not in ignore, f)
|
|
|
|
|
|
|
|
|
|
if state_id is not None:
|
|
|
|
|
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
|
|
|
|
f = filter(lambda agent: agent.get("state_id", None) in state_id, f)
|
|
|
|
|
|
|
|
|
|
if agent_class is not None:
|
|
|
|
|
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
|
|
|
@ -711,23 +745,25 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N
|
|
|
|
|
yield from f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def from_config(cfg: config.AgentConfig, random, topology: nx.Graph = None) -> List[Dict[str, Any]]:
|
|
|
|
|
'''
|
|
|
|
|
def from_config(
|
|
|
|
|
cfg: config.AgentConfig, random, topology: nx.Graph = None
|
|
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
|
|
"""
|
|
|
|
|
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
|
|
|
|
with the parameters that the environment will use to construct each agent.
|
|
|
|
|
|
|
|
|
|
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
|
|
|
|
time of calling this function, such as `unique_id`.
|
|
|
|
|
'''
|
|
|
|
|
"""
|
|
|
|
|
default = cfg or config.AgentConfig()
|
|
|
|
|
if not isinstance(cfg, config.AgentConfig):
|
|
|
|
|
cfg = config.AgentConfig(**cfg)
|
|
|
|
|
return _agents_from_config(cfg, topology=topology, random=random)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _agents_from_config(cfg: config.AgentConfig,
|
|
|
|
|
topology: nx.Graph,
|
|
|
|
|
random) -> List[Dict[str, Any]]:
|
|
|
|
|
def _agents_from_config(
|
|
|
|
|
cfg: config.AgentConfig, topology: nx.Graph, random
|
|
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
|
|
if cfg and not isinstance(cfg, config.AgentConfig):
|
|
|
|
|
cfg = config.AgentConfig(**cfg)
|
|
|
|
|
|
|
|
|
@ -737,7 +773,9 @@ def _agents_from_config(cfg: config.AgentConfig,
|
|
|
|
|
assigned_network = 0
|
|
|
|
|
|
|
|
|
|
if cfg.fixed is not None:
|
|
|
|
|
agents, assigned_total, assigned_network = _from_fixed(cfg.fixed, topology=cfg.topology, default=cfg)
|
|
|
|
|
agents, assigned_total, assigned_network = _from_fixed(
|
|
|
|
|
cfg.fixed, topology=cfg.topology, default=cfg
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
n = cfg.n
|
|
|
|
|
|
|
|
|
@ -749,46 +787,56 @@ def _agents_from_config(cfg: config.AgentConfig,
|
|
|
|
|
|
|
|
|
|
for d in cfg.distribution:
|
|
|
|
|
if d.strategy == config.Strategy.topology:
|
|
|
|
|
topo = d.topology if ('topology' in d.__fields_set__) else cfg.topology
|
|
|
|
|
topo = d.topology if ("topology" in d.__fields_set__) else cfg.topology
|
|
|
|
|
if not topo:
|
|
|
|
|
raise ValueError('The "topology" strategy only works if the topology parameter is set to True')
|
|
|
|
|
raise ValueError(
|
|
|
|
|
'The "topology" strategy only works if the topology parameter is set to True'
|
|
|
|
|
)
|
|
|
|
|
if not topo_size:
|
|
|
|
|
raise ValueError(f'Topology does not have enough free nodes to assign one to the agent')
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Topology does not have enough free nodes to assign one to the agent"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
networked.append(d)
|
|
|
|
|
|
|
|
|
|
if d.strategy == config.Strategy.total:
|
|
|
|
|
if not cfg.n:
|
|
|
|
|
raise ValueError('Cannot use the "total" strategy without providing the total number of agents')
|
|
|
|
|
raise ValueError(
|
|
|
|
|
'Cannot use the "total" strategy without providing the total number of agents'
|
|
|
|
|
)
|
|
|
|
|
total.append(d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if networked:
|
|
|
|
|
new_agents = _from_distro(networked,
|
|
|
|
|
n= topo_size - assigned_network,
|
|
|
|
|
topology=topo,
|
|
|
|
|
default=cfg,
|
|
|
|
|
random=random)
|
|
|
|
|
new_agents = _from_distro(
|
|
|
|
|
networked,
|
|
|
|
|
n=topo_size - assigned_network,
|
|
|
|
|
topology=topo,
|
|
|
|
|
default=cfg,
|
|
|
|
|
random=random,
|
|
|
|
|
)
|
|
|
|
|
assigned_total += len(new_agents)
|
|
|
|
|
assigned_network += len(new_agents)
|
|
|
|
|
agents += new_agents
|
|
|
|
|
|
|
|
|
|
if total:
|
|
|
|
|
remaining = n - assigned_total
|
|
|
|
|
agents += _from_distro(total, n=remaining,
|
|
|
|
|
default=cfg,
|
|
|
|
|
random=random)
|
|
|
|
|
|
|
|
|
|
remaining = n - assigned_total
|
|
|
|
|
agents += _from_distro(total, n=remaining, default=cfg, random=random)
|
|
|
|
|
|
|
|
|
|
if assigned_network < topo_size:
|
|
|
|
|
utils.logger.warn(f'The total number of agents does not match the total number of nodes in '
|
|
|
|
|
'every topology. This may be due to a definition error: assigned: '
|
|
|
|
|
f'{ assigned } total size: { topo_size }')
|
|
|
|
|
utils.logger.warn(
|
|
|
|
|
f"The total number of agents does not match the total number of nodes in "
|
|
|
|
|
"every topology. This may be due to a definition error: assigned: "
|
|
|
|
|
f"{ assigned } total size: { topo_size }"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return agents
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: config.SingleAgentConfig) -> List[Dict[str, Any]]:
|
|
|
|
|
def _from_fixed(
|
|
|
|
|
lst: List[config.FixedAgentConfig],
|
|
|
|
|
topology: bool,
|
|
|
|
|
default: config.SingleAgentConfig,
|
|
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
|
|
agents = []
|
|
|
|
|
|
|
|
|
|
counts_total = 0
|
|
|
|
@ -799,12 +847,18 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con
|
|
|
|
|
if default:
|
|
|
|
|
agent = default.state.copy()
|
|
|
|
|
agent.update(fixed.state)
|
|
|
|
|
cls = serialization.deserialize(fixed.agent_class or (default and default.agent_class))
|
|
|
|
|
agent['agent_class'] = cls
|
|
|
|
|
topo = fixed.topology if ('topology' in fixed.__fields_set__) else topology or default.topology
|
|
|
|
|
cls = serialization.deserialize(
|
|
|
|
|
fixed.agent_class or (default and default.agent_class)
|
|
|
|
|
)
|
|
|
|
|
agent["agent_class"] = cls
|
|
|
|
|
topo = (
|
|
|
|
|
fixed.topology
|
|
|
|
|
if ("topology" in fixed.__fields_set__)
|
|
|
|
|
else topology or default.topology
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if topo:
|
|
|
|
|
agent['topology'] = True
|
|
|
|
|
agent["topology"] = True
|
|
|
|
|
counts_network += 1
|
|
|
|
|
if not fixed.hidden:
|
|
|
|
|
counts_total += 1
|
|
|
|
@ -813,17 +867,21 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con
|
|
|
|
|
return agents, counts_total, counts_network
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _from_distro(distro: List[config.AgentDistro],
|
|
|
|
|
n: int,
|
|
|
|
|
topology: str,
|
|
|
|
|
default: config.SingleAgentConfig,
|
|
|
|
|
random) -> List[Dict[str, Any]]:
|
|
|
|
|
def _from_distro(
|
|
|
|
|
distro: List[config.AgentDistro],
|
|
|
|
|
n: int,
|
|
|
|
|
topology: str,
|
|
|
|
|
default: config.SingleAgentConfig,
|
|
|
|
|
random,
|
|
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
|
|
|
|
|
|
|
agents = []
|
|
|
|
|
|
|
|
|
|
if n is None:
|
|
|
|
|
if any(lambda dist: dist.n is None, distro):
|
|
|
|
|
raise ValueError('You must provide a total number of agents, or the number of each type')
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"You must provide a total number of agents, or the number of each type"
|
|
|
|
|
)
|
|
|
|
|
n = sum(dist.n for dist in distro)
|
|
|
|
|
|
|
|
|
|
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
|
|
|
|
@ -836,29 +894,40 @@ def _from_distro(distro: List[config.AgentDistro],
|
|
|
|
|
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
|
|
|
|
|
|
|
|
|
|
# Calculate how many times each has to appear
|
|
|
|
|
indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm)))
|
|
|
|
|
indices = list(
|
|
|
|
|
chain.from_iterable([idx] * int(n * chunk) for (idx, n) in enumerate(norm))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Complete with random agents following the original weight distribution
|
|
|
|
|
if len(indices) < n:
|
|
|
|
|
indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices))
|
|
|
|
|
indices += random.choices(
|
|
|
|
|
list(range(len(distro))),
|
|
|
|
|
weights=[d.weight for d in distro],
|
|
|
|
|
k=n - len(indices),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Deserialize classes for efficiency
|
|
|
|
|
classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro)
|
|
|
|
|
classes = list(
|
|
|
|
|
serialization.deserialize(i.agent_class or default.agent_class) for i in distro
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Add them in random order
|
|
|
|
|
random.shuffle(indices)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for idx in indices:
|
|
|
|
|
d = distro[idx]
|
|
|
|
|
agent = d.state.copy()
|
|
|
|
|
cls = classes[idx]
|
|
|
|
|
agent['agent_class'] = cls
|
|
|
|
|
agent["agent_class"] = cls
|
|
|
|
|
if default:
|
|
|
|
|
agent.update(default.state)
|
|
|
|
|
topology = d.topology if ('topology' in d.__fields_set__) else topology or default.topology
|
|
|
|
|
topology = (
|
|
|
|
|
d.topology
|
|
|
|
|
if ("topology" in d.__fields_set__)
|
|
|
|
|
else topology or default.topology
|
|
|
|
|
)
|
|
|
|
|
if topology:
|
|
|
|
|
agent['topology'] = topology
|
|
|
|
|
agent["topology"] = topology
|
|
|
|
|
agents.append(agent)
|
|
|
|
|
|
|
|
|
|
return agents
|
|
|
|
@ -877,4 +946,5 @@ try:
|
|
|
|
|
from .Geo import Geo
|
|
|
|
|
except ImportError:
|
|
|
|
|
import sys
|
|
|
|
|
print('Could not load the Geo Agent, scipy is not installed', file=sys.stderr)
|
|
|
|
|
|
|
|
|
|
print("Could not load the Geo Agent, scipy is not installed", file=sys.stderr)
|
|
|
|
|