mirror of
https://github.com/gsi-upm/soil
synced 2025-09-18 22:22:20 +00:00
Clean-up
* Removed old/unnecessary models * Added a `simulation.{iter_}from_py` method to load simulations from python files * Changed tests of examples to run programmatic simulations * Fixed programmatic examples
This commit is contained in:
@@ -6,7 +6,7 @@ import math
|
||||
import logging
|
||||
import inspect
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union, List
|
||||
from collections import namedtuple
|
||||
from time import time as current_time
|
||||
from copy import deepcopy
|
||||
@@ -16,9 +16,8 @@ from networkx.readwrite import json_graph
|
||||
import networkx as nx
|
||||
|
||||
from mesa import Model
|
||||
from mesa.datacollection import DataCollector
|
||||
|
||||
from . import agents as agentmod, config, serialization, utils, time, network, events
|
||||
from . import agents as agentmod, config, datacollection, serialization, utils, time, network, events
|
||||
|
||||
|
||||
class BaseEnvironment(Model):
|
||||
@@ -42,7 +41,8 @@ class BaseEnvironment(Model):
|
||||
dir_path=None,
|
||||
interval=1,
|
||||
agent_class=None,
|
||||
agents: [tuple[type, Dict[str, Any]]] = {},
|
||||
agents: List[tuple[type, Dict[str, Any]]] = {},
|
||||
collector_class: type = datacollection.SoilCollector,
|
||||
agent_reporters: Optional[Any] = None,
|
||||
model_reporters: Optional[Any] = None,
|
||||
tables: Optional[Any] = None,
|
||||
@@ -50,7 +50,6 @@ class BaseEnvironment(Model):
|
||||
):
|
||||
|
||||
super().__init__(seed=seed)
|
||||
self.env_params = env_params or {}
|
||||
|
||||
self.current_id = -1
|
||||
|
||||
@@ -71,11 +70,14 @@ class BaseEnvironment(Model):
|
||||
|
||||
self.logger = utils.logger.getChild(self.id)
|
||||
|
||||
self.datacollector = DataCollector(
|
||||
collector_class = serialization.deserialize(collector_class)
|
||||
self.datacollector = collector_class(
|
||||
model_reporters=model_reporters,
|
||||
agent_reporters=agent_reporters,
|
||||
tables=tables,
|
||||
)
|
||||
for (k, v) in env_params.items():
|
||||
self[k] = v
|
||||
|
||||
def _agent_from_dict(self, agent):
|
||||
"""
|
||||
@@ -89,7 +91,7 @@ class BaseEnvironment(Model):
|
||||
|
||||
return serialization.deserialize(cls)(unique_id=unique_id, model=self, **agent)
|
||||
|
||||
def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}):
|
||||
def init_agents(self, agents: Union[config.AgentConfig, List[Dict[str, Any]]] = {}):
|
||||
"""
|
||||
Initialize the agents in the model from either a `soil.config.AgentConfig` or a list of
|
||||
dictionaries that each describes an agent.
|
||||
@@ -170,31 +172,41 @@ class BaseEnvironment(Model):
|
||||
Advance one step in the simulation, and update the data collection and scheduler appropriately
|
||||
"""
|
||||
super().step()
|
||||
self.logger.info(
|
||||
f"--- Step: {self.schedule.steps:^5} - Time: {self.now:^5} ---"
|
||||
)
|
||||
# self.logger.info(
|
||||
# "--- Step: {:^5} - Time: {now:^5} ---", steps=self.schedule.steps, now=self.now
|
||||
# )
|
||||
self.schedule.step()
|
||||
self.datacollector.collect(self)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self.env_params
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Get the value of an environment attribute.
|
||||
Return `default` if the value is not set.
|
||||
"""
|
||||
return self.env_params.get(key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.env_params.get(key)
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
raise KeyError(f"key {key} not found in environment")
|
||||
|
||||
def __delitem__(self, key):
|
||||
return delattr(self, key)
|
||||
|
||||
def __contains__(self, key):
|
||||
return hasattr(self, key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
return self.env_params.__setitem__(key, value)
|
||||
setattr(self, key, value)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.env_params)
|
||||
return str(dict(self))
|
||||
|
||||
def __len__(self):
|
||||
return sum(1 for n in self.keys())
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.agents())
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self[key] if key in self else default
|
||||
|
||||
def keys(self):
|
||||
return (k for k in self.__dict__ if k[0] != "_")
|
||||
|
||||
class NetworkEnvironment(BaseEnvironment):
|
||||
"""
|
||||
@@ -208,7 +220,12 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
agents = kwargs.pop("agents", None)
|
||||
super().__init__(*args, agents=None, **kwargs)
|
||||
|
||||
self._set_topology(topology)
|
||||
if topology is None:
|
||||
topology = nx.Graph()
|
||||
elif not isinstance(topology, nx.Graph):
|
||||
topology = network.from_config(topology, dir_path=self.dir_path)
|
||||
|
||||
self.G = topology
|
||||
|
||||
self.init_agents(agents)
|
||||
|
||||
@@ -216,14 +233,14 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
"""Initialize the agents from a"""
|
||||
super().init_agents(*args, **kwargs)
|
||||
for agent in self.schedule._agents.values():
|
||||
if hasattr(agent, "node_id"):
|
||||
self._init_node(agent)
|
||||
self._init_node(agent)
|
||||
|
||||
def _init_node(self, agent):
|
||||
"""
|
||||
Make sure the node for a given agent has the proper attributes.
|
||||
"""
|
||||
self.G.nodes[agent.node_id]["agent"] = agent
|
||||
if hasattr(agent, "node_id"):
|
||||
self.G.nodes[agent.node_id]["agent"] = agent
|
||||
|
||||
def _agent_dict_from_config(self, cfg):
|
||||
return agentmod.from_config(cfg, topology=self.G, random=self.random)
|
||||
@@ -244,6 +261,7 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
agent["unique_id"] = unique_id
|
||||
agent["topology"] = self.G
|
||||
node_attrs = self.G.nodes[node_id]
|
||||
node_attrs.pop('agent', None)
|
||||
node_attrs.update(agent)
|
||||
agent = node_attrs
|
||||
|
||||
@@ -252,17 +270,9 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
|
||||
return a
|
||||
|
||||
def _set_topology(self, cfg=None, dir_path=None):
|
||||
if cfg is None:
|
||||
cfg = nx.Graph()
|
||||
elif not isinstance(cfg, nx.Graph):
|
||||
cfg = network.from_config(cfg, dir_path=dir_path or self.dir_path)
|
||||
|
||||
self.G = cfg
|
||||
|
||||
@property
|
||||
def network_agents(self):
|
||||
for a in self.schedule._agents:
|
||||
for a in self.schedule._agents.values():
|
||||
if isinstance(a, agentmod.NetworkAgent):
|
||||
yield a
|
||||
|
||||
@@ -294,7 +304,7 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
|
||||
def add_agent(self, *args, **kwargs):
|
||||
a = super().add_agent(*args, **kwargs)
|
||||
if "node_id" in a:
|
||||
if hasattr(a, "node_id"):
|
||||
assert self.G.nodes[a.node_id]["agent"] == a
|
||||
return a
|
||||
|
||||
@@ -309,7 +319,7 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
if "agent" in node:
|
||||
continue
|
||||
a_class = self.random.choices(agent_class, weights)[0]
|
||||
self.add_agent(node_id=node_id, agent_class=a_class, **agent_params)
|
||||
self.add_agent(node_id=node_id, topology=self.G, agent_class=a_class, **agent_params)
|
||||
|
||||
|
||||
class EventedEnvironment(BaseEnvironment):
|
||||
|
Reference in New Issue
Block a user