mirror of
https://github.com/gsi-upm/soil
synced 2025-09-18 22:22:20 +00:00
Pre-release version of v1.0
This commit is contained in:
@@ -34,11 +34,13 @@ class BaseEnvironment(Model):
|
||||
:meth:`soil.environment.Environment.get` method.
|
||||
"""
|
||||
|
||||
collector_class = datacollection.SoilCollector
|
||||
|
||||
def __new__(cls,
|
||||
*args: Any,
|
||||
seed="default",
|
||||
dir_path=None,
|
||||
collector_class: type = datacollection.SoilCollector,
|
||||
collector_class: type = None,
|
||||
agent_reporters: Optional[Any] = None,
|
||||
model_reporters: Optional[Any] = None,
|
||||
tables: Optional[Any] = None,
|
||||
@@ -46,6 +48,7 @@ class BaseEnvironment(Model):
|
||||
"""Create a new model with a default seed value"""
|
||||
self = super().__new__(cls, *args, seed=seed, **kwargs)
|
||||
self.dir_path = dir_path or os.getcwd()
|
||||
collector_class = collector_class or cls.collector_class
|
||||
collector_class = serialization.deserialize(collector_class)
|
||||
self.datacollector = collector_class(
|
||||
model_reporters=model_reporters,
|
||||
@@ -69,6 +72,7 @@ class BaseEnvironment(Model):
|
||||
dir_path=None,
|
||||
schedule_class=time.TimedActivation,
|
||||
interval=1,
|
||||
logger = None,
|
||||
agents: Optional[Dict] = None,
|
||||
collector_class: type = datacollection.SoilCollector,
|
||||
agent_reporters: Optional[Any] = None,
|
||||
@@ -80,10 +84,15 @@ class BaseEnvironment(Model):
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.current_id = -1
|
||||
|
||||
self.id = id
|
||||
|
||||
if logger:
|
||||
self.logger = logger
|
||||
else:
|
||||
self.logger = utils.logger.getChild(self.id)
|
||||
|
||||
if schedule_class is None:
|
||||
schedule_class = time.TimedActivation
|
||||
@@ -93,8 +102,6 @@ class BaseEnvironment(Model):
|
||||
self.interval = interval
|
||||
self.schedule = schedule_class(self)
|
||||
|
||||
self.logger = utils.logger.getChild(self.id)
|
||||
|
||||
for (k, v) in env_params.items():
|
||||
self[k] = v
|
||||
|
||||
@@ -102,6 +109,7 @@ class BaseEnvironment(Model):
|
||||
self.add_agents(**agents)
|
||||
if init:
|
||||
self.init()
|
||||
self.datacollector.collect(self)
|
||||
|
||||
def init(self):
|
||||
pass
|
||||
@@ -115,6 +123,22 @@ class BaseEnvironment(Model):
|
||||
|
||||
def count_agents(self, *args, **kwargs):
|
||||
return sum(1 for i in self.agents(*args, **kwargs))
|
||||
|
||||
def agent_df(self, steps=False):
|
||||
df = self.datacollector.get_agent_vars_dataframe()
|
||||
if steps:
|
||||
df.index.rename(["step", "agent_id"], inplace=True)
|
||||
return df
|
||||
model_df = self.datacollector.get_model_vars_dataframe()
|
||||
df.index = df.index.set_levels(model_df.time, level=0).rename(["time", "agent_id"])
|
||||
return df
|
||||
|
||||
def model_df(self, steps=False):
|
||||
df = self.datacollector.get_model_vars_dataframe()
|
||||
if steps:
|
||||
return df
|
||||
df.index.rename("step", inplace=True)
|
||||
return df.reset_index().set_index("time")
|
||||
|
||||
@property
|
||||
def now(self):
|
||||
@@ -171,11 +195,12 @@ class BaseEnvironment(Model):
|
||||
self.schedule.step()
|
||||
self.datacollector.collect(self)
|
||||
|
||||
msg = "Model data:\n"
|
||||
max_width = max(len(k) for k in self.datacollector.model_vars.keys())
|
||||
for (k, v) in self.datacollector.model_vars.items():
|
||||
msg += f"\t{k:<{max_width}}: {v[-1]:>6}\n"
|
||||
self.logger.info(f"--- Steps: {self.schedule.steps:^5} - Time: {self.now:^5} --- " + msg)
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
msg = "Model data:\n"
|
||||
max_width = max(len(k) for k in self.datacollector.model_vars.keys())
|
||||
for (k, v) in self.datacollector.model_vars.items():
|
||||
msg += f"\t{k:<{max_width}}: {v[-1]:>6}\n"
|
||||
self.logger.debug(f"--- Steps: {self.schedule.steps:^5} - Time: {self.now:^5} --- " + msg)
|
||||
|
||||
def add_model_reporter(self, name, func=None):
|
||||
if not func:
|
||||
@@ -186,9 +211,18 @@ class BaseEnvironment(Model):
|
||||
if agent_type:
|
||||
reporter = lambda a: getattr(a, name) if isinstance(a, agent_type) else None
|
||||
else:
|
||||
reporter = name
|
||||
reporter = lambda a: getattr(a, name, None)
|
||||
self.datacollector._new_agent_reporter(name, reporter)
|
||||
|
||||
@classmethod
|
||||
def run(cls, *,
|
||||
iterations=1,
|
||||
num_processes=1, **kwargs):
|
||||
from .simulation import Simulation
|
||||
return Simulation(name=cls.__name__,
|
||||
model=cls, iterations=iterations,
|
||||
num_processes=num_processes, **kwargs).run()
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return getattr(self, key)
|
||||
@@ -250,6 +284,7 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
self._check_agent_nodes()
|
||||
if init:
|
||||
self.init()
|
||||
self.datacollector.collect(self)
|
||||
|
||||
def add_agent(self, agent_class, *args, node_id=None, topology=None, **kwargs):
|
||||
if node_id is None and topology is None:
|
||||
@@ -373,7 +408,7 @@ class EventedEnvironment(BaseEnvironment):
|
||||
for agent in self.agents(**kwargs):
|
||||
if agent == sender:
|
||||
continue
|
||||
self.logger.info(f"Telling {repr(agent)}: {msg} ttl={ttl}")
|
||||
self.logger.debug(f"Telling {repr(agent)}: {msg} ttl={ttl}")
|
||||
try:
|
||||
inbox = agent._inbox
|
||||
except AttributeError:
|
||||
|
Reference in New Issue
Block a user