mirror of
https://github.com/gsi-upm/soil
synced 2025-09-18 22:22:20 +00:00
Big refactor v0.30
All test pass, except for the TestConfig suite, which is not too critical as the plan for this version onwards is to avoid configuration as much as possible.
This commit is contained in:
@@ -6,20 +6,22 @@ import math
|
||||
import logging
|
||||
import inspect
|
||||
|
||||
from typing import Any, Dict, Optional, Union, List
|
||||
from typing import Any, Callable, Dict, Optional, Union, List, Type
|
||||
from collections import namedtuple
|
||||
from time import time as current_time
|
||||
from copy import deepcopy
|
||||
from networkx.readwrite import json_graph
|
||||
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from mesa import Model
|
||||
from mesa import Model, Agent
|
||||
|
||||
from . import agents as agentmod, config, datacollection, serialization, utils, time, network, events
|
||||
from . import agents as agentmod, datacollection, serialization, utils, time, network, events
|
||||
|
||||
|
||||
# TODO: add metaclass to read attributes of a model
|
||||
# TODO: read "report" attributes from the model
|
||||
|
||||
class BaseEnvironment(Model):
|
||||
"""
|
||||
The environment is key in a simulation. It controls how agents interact,
|
||||
@@ -33,29 +35,35 @@ class BaseEnvironment(Model):
|
||||
:meth:`soil.environment.Environment.get` method.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args: Any, seed="default", dir_path=None, **kwargs: Any) -> Any:
|
||||
"""Create a new model with a default seed value"""
|
||||
self = super().__new__(cls, *args, seed=seed, **kwargs)
|
||||
self.dir_path = dir_path or os.getcwd()
|
||||
return self
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
id="unnamed_env",
|
||||
seed="default",
|
||||
schedule_class=time.TimedActivation,
|
||||
dir_path=None,
|
||||
schedule_class=time.TimedActivation,
|
||||
interval=1,
|
||||
agent_class=None,
|
||||
agents: List[tuple[type, Dict[str, Any]]] = {},
|
||||
agents: Optional[Dict] = None,
|
||||
collector_class: type = datacollection.SoilCollector,
|
||||
agent_reporters: Optional[Any] = None,
|
||||
model_reporters: Optional[Any] = None,
|
||||
tables: Optional[Any] = None,
|
||||
init: bool = True,
|
||||
**env_params,
|
||||
):
|
||||
|
||||
super().__init__(seed=seed)
|
||||
super().__init__()
|
||||
|
||||
self.current_id = -1
|
||||
|
||||
self.id = id
|
||||
|
||||
self.dir_path = dir_path or os.getcwd()
|
||||
|
||||
if schedule_class is None:
|
||||
schedule_class = time.TimedActivation
|
||||
@@ -63,10 +71,7 @@ class BaseEnvironment(Model):
|
||||
schedule_class = serialization.deserialize(schedule_class)
|
||||
self.schedule = schedule_class(self)
|
||||
|
||||
self.agent_class = agent_class or agentmod.BaseAgent
|
||||
|
||||
self.interval = interval
|
||||
self.init_agents(agents)
|
||||
|
||||
self.logger = utils.logger.getChild(self.id)
|
||||
|
||||
@@ -79,53 +84,13 @@ class BaseEnvironment(Model):
|
||||
for (k, v) in env_params.items():
|
||||
self[k] = v
|
||||
|
||||
def _agent_from_dict(self, agent):
|
||||
"""
|
||||
Translate an agent dictionary into an agent
|
||||
"""
|
||||
agent = dict(**agent)
|
||||
cls = agent.pop("agent_class", None) or self.agent_class
|
||||
unique_id = agent.pop("unique_id", None)
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
if agents:
|
||||
self.add_agents(**agents)
|
||||
if init:
|
||||
self.init()
|
||||
|
||||
return serialization.deserialize(cls)(unique_id=unique_id, model=self, **agent)
|
||||
|
||||
def init_agents(self, agents: Union[config.AgentConfig, List[Dict[str, Any]]] = {}):
|
||||
"""
|
||||
Initialize the agents in the model from either a `soil.config.AgentConfig` or a list of
|
||||
dictionaries that each describes an agent.
|
||||
|
||||
If given a list of dictionaries, an agent will be created for each dictionary. The agent
|
||||
class can be specified through the `agent_class` key. The rest of the items will be used
|
||||
as parameters to the agent.
|
||||
"""
|
||||
if not agents:
|
||||
return
|
||||
|
||||
lst = agents
|
||||
override = []
|
||||
if not isinstance(lst, list):
|
||||
if not isinstance(agents, config.AgentConfig):
|
||||
lst = config.AgentConfig(**agents)
|
||||
if lst.override:
|
||||
override = lst.override
|
||||
lst = self._agent_dict_from_config(lst)
|
||||
|
||||
# TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore,
|
||||
# because it needs attribute such as unique_id, which are only present after init
|
||||
new_agents = [self._agent_from_dict(agent) for agent in lst]
|
||||
|
||||
for a in new_agents:
|
||||
self.schedule.add(a)
|
||||
|
||||
for rule in override:
|
||||
for agent in agentmod.filter_agents(self.schedule._agents, **rule.filter):
|
||||
for attr, value in rule.state.items():
|
||||
setattr(agent, attr, value)
|
||||
|
||||
def _agent_dict_from_config(self, cfg):
|
||||
return agentmod.from_config(cfg, random=self.random)
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
@@ -145,16 +110,29 @@ class BaseEnvironment(Model):
|
||||
"The environment has not been scheduled, so it has no sense of time"
|
||||
)
|
||||
|
||||
def add_agent(self, unique_id=None, **kwargs):
|
||||
def add_agent(self, agent_class, unique_id=None, **agent):
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
|
||||
kwargs["unique_id"] = unique_id
|
||||
a = self._agent_from_dict(kwargs)
|
||||
agent["unique_id"] = unique_id
|
||||
|
||||
agent = dict(**agent)
|
||||
unique_id = agent.pop("unique_id", None)
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
|
||||
a = serialization.deserialize(agent_class)(unique_id=unique_id, model=self, **agent)
|
||||
|
||||
self.schedule.add(a)
|
||||
return a
|
||||
|
||||
def add_agents(self, agent_classes: List[type], k, weights: Optional[List[float]] = None, **kwargs):
|
||||
if weights is None:
|
||||
weights = [1] * len(agent_classes)
|
||||
|
||||
for cls in self.random.choices(agent_classes, weights=weights, k=k):
|
||||
self.add_agent(agent_class=cls, **kwargs)
|
||||
|
||||
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||
if not self.logger.isEnabledFor(level):
|
||||
return
|
||||
@@ -215,61 +193,58 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, *args, topology: Union[config.NetConfig, nx.Graph] = None, **kwargs
|
||||
self, *args,
|
||||
topology: Optional[Union[nx.Graph, str]] = None,
|
||||
agent_class: Optional[Type[agentmod.Agent]] = None,
|
||||
network_generator: Optional[Callable] = None,
|
||||
network_params: Optional[Dict] = None, **kwargs
|
||||
):
|
||||
agents = kwargs.pop("agents", None)
|
||||
super().__init__(*args, agents=None, **kwargs)
|
||||
self.topology = topology
|
||||
self.network_generator = network_generator
|
||||
self.network_params = network_params
|
||||
if topology or network_params or network_generator:
|
||||
self.create_network(topology, network_params=network_params, network_generator=network_generator)
|
||||
else:
|
||||
self.G = nx.Graph()
|
||||
super().__init__(*args, **kwargs, init=False)
|
||||
|
||||
if topology is None:
|
||||
topology = nx.Graph()
|
||||
elif not isinstance(topology, nx.Graph):
|
||||
topology = network.from_config(topology, dir_path=self.dir_path)
|
||||
self.agent_class = agent_class
|
||||
if agent_class:
|
||||
self.agent_class = serialization.deserialize(agent_class)
|
||||
self.init()
|
||||
if self.agent_class:
|
||||
self.populate_network(self.agent_class)
|
||||
|
||||
|
||||
def add_agents(self, *args, k=None, **kwargs):
|
||||
if not k and not self.G:
|
||||
raise ValueError("Cannot add agents to an empty network")
|
||||
super().add_agents(*args, k=k or len(self.G), **kwargs)
|
||||
|
||||
def create_network(self, topology=None, network_generator=None, path=None, network_params=None):
|
||||
if topology is not None:
|
||||
topology = network.from_topology(topology, dir_path=self.dir_path)
|
||||
elif path is not None:
|
||||
topology = network.from_topology(path, dir_path=self.dir_path)
|
||||
elif network_generator is not None:
|
||||
topology = network.from_params(network_generator, dir_path=self.dir_path, **network_params)
|
||||
else:
|
||||
raise ValueError("topology must be a networkx.Graph or a string, or network_generator must be provided")
|
||||
self.G = topology
|
||||
|
||||
self.init_agents(agents)
|
||||
|
||||
def init_agents(self, *args, **kwargs):
|
||||
"""Initialize the agents from a"""
|
||||
super().init_agents(*args, **kwargs)
|
||||
for agent in self.schedule._agents.values():
|
||||
self._init_node(agent)
|
||||
self._assign_node(agent)
|
||||
|
||||
def _init_node(self, agent):
|
||||
def _assign_node(self, agent):
|
||||
"""
|
||||
Make sure the node for a given agent has the proper attributes.
|
||||
"""
|
||||
if hasattr(agent, "node_id"):
|
||||
self.G.nodes[agent.node_id]["agent"] = agent
|
||||
|
||||
def _agent_dict_from_config(self, cfg):
|
||||
return agentmod.from_config(cfg, topology=self.G, random=self.random)
|
||||
|
||||
def _agent_from_dict(self, agent, unique_id=None):
|
||||
agent = dict(agent)
|
||||
|
||||
if not agent.get("topology", False):
|
||||
return super()._agent_from_dict(agent)
|
||||
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
node_id = agent.get("node_id", None)
|
||||
if node_id is None:
|
||||
node_id = network.find_unassigned(self.G, random=self.random)
|
||||
self.G.nodes[node_id]["agent"] = None
|
||||
agent["node_id"] = node_id
|
||||
agent["unique_id"] = unique_id
|
||||
agent["topology"] = self.G
|
||||
node_attrs = self.G.nodes[node_id]
|
||||
node_attrs.pop('agent', None)
|
||||
node_attrs.update(agent)
|
||||
agent = node_attrs
|
||||
|
||||
a = super()._agent_from_dict(agent)
|
||||
self._init_node(a)
|
||||
|
||||
return a
|
||||
|
||||
@property
|
||||
def network_agents(self):
|
||||
for a in self.schedule._agents.values():
|
||||
@@ -302,24 +277,37 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
a["visible"] = True
|
||||
return a
|
||||
|
||||
def add_agent(self, *args, **kwargs):
|
||||
a = super().add_agent(*args, **kwargs)
|
||||
def add_agent(self, agent_class, *args, **kwargs):
|
||||
if issubclass(agent_class, agentmod.NetworkAgent) and "node_id" not in kwargs:
|
||||
return self.add_node(agent_class, *args, **kwargs)
|
||||
a = super().add_agent(agent_class, *args, **kwargs)
|
||||
if hasattr(a, "node_id"):
|
||||
assert self.G.nodes[a.node_id]["agent"] == a
|
||||
assigned = self.G.nodes[a.node_id].get("agent")
|
||||
if not assigned:
|
||||
self.G.nodes[a.node_id]["agent"] = a
|
||||
elif assigned != a:
|
||||
raise ValueError(f"Node {a.node_id} already has an agent assigned: {assigned}")
|
||||
return a
|
||||
|
||||
def agent_for_node_id(self, node_id):
|
||||
return self.G.nodes[node_id].get("agent")
|
||||
|
||||
def populate_network(self, agent_class, weights=None, **agent_params):
|
||||
if not hasattr(agent_class, "len"):
|
||||
def populate_network(self, agent_class: List[Model], weights: List[float] = None, **agent_params):
|
||||
if isinstance(agent_class, type):
|
||||
agent_class = [agent_class]
|
||||
weights = None
|
||||
for (node_id, node) in self.G.nodes(data=True):
|
||||
else:
|
||||
agent_class = list(agent_class)
|
||||
if not weights:
|
||||
weights = [1] * len(agent_class)
|
||||
assert len(self.G)
|
||||
classes = self.random.choices(agent_class, weights, k=len(self.G))
|
||||
for (cls, (node_id, node)) in zip(classes, self.G.nodes(data=True)):
|
||||
if "agent" in node:
|
||||
continue
|
||||
a_class = self.random.choices(agent_class, weights)[0]
|
||||
self.add_agent(node_id=node_id, topology=self.G, agent_class=a_class, **agent_params)
|
||||
a = self.add_agent(node_id=node_id, topology=self.G, agent_class=cls, **agent_params)
|
||||
node["agent"] = a
|
||||
assert all("agent" in node for (_, node) in self.G.nodes(data=True))
|
||||
assert len(list(self.network_agents))
|
||||
|
||||
|
||||
class EventedEnvironment(BaseEnvironment):
|
||||
|
Reference in New Issue
Block a user