mirror of
https://github.com/gsi-upm/soil
synced 2024-11-14 15:32:29 +00:00
Formatted with black
This commit is contained in:
parent
d9947c2c52
commit
78833a9e08
177
soil/__init__.py
177
soil/__init__.py
@ -22,58 +22,107 @@ from .utils import logger
|
||||
from .time import *
|
||||
|
||||
|
||||
def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_output", *, do_run=False, debug=False, **kwargs):
|
||||
def main(
|
||||
cfg="simulation.yml",
|
||||
exporters=None,
|
||||
parallel=None,
|
||||
output="soil_output",
|
||||
*,
|
||||
do_run=False,
|
||||
debug=False,
|
||||
**kwargs,
|
||||
):
|
||||
import argparse
|
||||
from . import simulation
|
||||
|
||||
logger.info('Running SOIL version: {}'.format(__version__))
|
||||
logger.info("Running SOIL version: {}".format(__version__))
|
||||
|
||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
||||
parser.add_argument('file', type=str,
|
||||
nargs="?",
|
||||
default=cfg,
|
||||
help='Configuration file for the simulation (e.g., YAML or JSON)')
|
||||
parser.add_argument('--version', action='store_true',
|
||||
help='Show version info and exit')
|
||||
parser.add_argument('--module', '-m', type=str,
|
||||
help='file containing the code of any custom agents.')
|
||||
parser.add_argument('--dry-run', '--dry', action='store_true',
|
||||
help='Do not store the results of the simulation to disk, show in terminal instead.')
|
||||
parser.add_argument('--pdb', action='store_true',
|
||||
help='Use a pdb console in case of exception.')
|
||||
parser.add_argument('--debug', action='store_true',
|
||||
help='Run a customized version of a pdb console to debug a simulation.')
|
||||
parser.add_argument('--graph', '-g', action='store_true',
|
||||
help='Dump each trial\'s network topology as a GEXF graph. Defaults to false.')
|
||||
parser.add_argument('--csv', action='store_true',
|
||||
help='Dump all data collected in CSV format. Defaults to false.')
|
||||
parser.add_argument('--level', type=str,
|
||||
help='Logging level')
|
||||
parser.add_argument('--output', '-o', type=str, default=output or "soil_output",
|
||||
help='folder to write results to. It defaults to the current directory.')
|
||||
parser = argparse.ArgumentParser(description="Run a SOIL simulation")
|
||||
parser.add_argument(
|
||||
"file",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default=cfg,
|
||||
help="Configuration file for the simulation (e.g., YAML or JSON)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version", action="store_true", help="Show version info and exit"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--module",
|
||||
"-m",
|
||||
type=str,
|
||||
help="file containing the code of any custom agents.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
"--dry",
|
||||
action="store_true",
|
||||
help="Do not store the results of the simulation to disk, show in terminal instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pdb", action="store_true", help="Use a pdb console in case of exception."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Run a customized version of a pdb console to debug a simulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--graph",
|
||||
"-g",
|
||||
action="store_true",
|
||||
help="Dump each trial's network topology as a GEXF graph. Defaults to false.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv",
|
||||
action="store_true",
|
||||
help="Dump all data collected in CSV format. Defaults to false.",
|
||||
)
|
||||
parser.add_argument("--level", type=str, help="Logging level")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
type=str,
|
||||
default=output or "soil_output",
|
||||
help="folder to write results to. It defaults to the current directory.",
|
||||
)
|
||||
if parallel is None:
|
||||
parser.add_argument('--synchronous', action='store_true',
|
||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
||||
parser.add_argument(
|
||||
"--synchronous",
|
||||
action="store_true",
|
||||
help="Run trials serially and synchronously instead of in parallel. Defaults to false.",
|
||||
)
|
||||
|
||||
parser.add_argument('-e', '--exporter', action='append',
|
||||
default=[],
|
||||
help='Export environment and/or simulations using this exporter')
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--exporter",
|
||||
action="append",
|
||||
default=[],
|
||||
help="Export environment and/or simulations using this exporter",
|
||||
)
|
||||
|
||||
parser.add_argument('--only-convert', '--convert', action='store_true',
|
||||
help='Do not run the simulation, only convert the configuration file(s) and output them.')
|
||||
parser.add_argument(
|
||||
"--only-convert",
|
||||
"--convert",
|
||||
action="store_true",
|
||||
help="Do not run the simulation, only convert the configuration file(s) and output them.",
|
||||
)
|
||||
|
||||
parser.add_argument("--set",
|
||||
metavar="KEY=VALUE",
|
||||
action='append',
|
||||
help="Set a number of parameters that will be passed to the simulation."
|
||||
"(do not put spaces before or after the = sign). "
|
||||
"If a value contains spaces, you should define "
|
||||
"it with double quotes: "
|
||||
'foo="this is a sentence". Note that '
|
||||
"values are always treated as strings.")
|
||||
parser.add_argument(
|
||||
"--set",
|
||||
metavar="KEY=VALUE",
|
||||
action="append",
|
||||
help="Set a number of parameters that will be passed to the simulation."
|
||||
"(do not put spaces before or after the = sign). "
|
||||
"If a value contains spaces, you should define "
|
||||
"it with double quotes: "
|
||||
'foo="this is a sentence". Note that '
|
||||
"values are always treated as strings.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.setLevel(getattr(logging, (args.level or 'INFO').upper()))
|
||||
logger.setLevel(getattr(logging, (args.level or "INFO").upper()))
|
||||
|
||||
if args.version:
|
||||
return
|
||||
@ -81,14 +130,16 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu
|
||||
if parallel is None:
|
||||
parallel = not args.synchronous
|
||||
|
||||
exporters = exporters or ['default', ]
|
||||
exporters = exporters or [
|
||||
"default",
|
||||
]
|
||||
for exp in args.exporter:
|
||||
if exp not in exporters:
|
||||
exporters.append(exp)
|
||||
if args.csv:
|
||||
exporters.append('csv')
|
||||
exporters.append("csv")
|
||||
if args.graph:
|
||||
exporters.append('gexf')
|
||||
exporters.append("gexf")
|
||||
|
||||
if os.getcwd() not in sys.path:
|
||||
sys.path.append(os.getcwd())
|
||||
@ -97,38 +148,38 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu
|
||||
if output is None:
|
||||
output = args.output
|
||||
|
||||
|
||||
logger.info('Loading config file: {}'.format(args.file))
|
||||
logger.info("Loading config file: {}".format(args.file))
|
||||
|
||||
debug = debug or args.debug
|
||||
|
||||
if args.pdb or debug:
|
||||
args.synchronous = True
|
||||
|
||||
|
||||
res = []
|
||||
try:
|
||||
exp_params = {}
|
||||
|
||||
if not os.path.exists(args.file):
|
||||
logger.error('Please, input a valid file')
|
||||
logger.error("Please, input a valid file")
|
||||
return
|
||||
|
||||
for sim in simulation.iter_from_config(args.file,
|
||||
dry_run=args.dry_run,
|
||||
exporters=exporters,
|
||||
parallel=parallel,
|
||||
outdir=output,
|
||||
exporter_params=exp_params,
|
||||
**kwargs):
|
||||
for sim in simulation.iter_from_config(
|
||||
args.file,
|
||||
dry_run=args.dry_run,
|
||||
exporters=exporters,
|
||||
parallel=parallel,
|
||||
outdir=output,
|
||||
exporter_params=exp_params,
|
||||
**kwargs,
|
||||
):
|
||||
if args.set:
|
||||
for s in args.set:
|
||||
k, v = s.split('=', 1)[:2]
|
||||
k, v = s.split("=", 1)[:2]
|
||||
v = eval(v)
|
||||
tail, *head = k.rsplit('.', 1)[::-1]
|
||||
tail, *head = k.rsplit(".", 1)[::-1]
|
||||
target = sim
|
||||
if head:
|
||||
for part in head[0].split('.'):
|
||||
for part in head[0].split("."):
|
||||
try:
|
||||
target = getattr(target, part)
|
||||
except AttributeError:
|
||||
@ -144,19 +195,21 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu
|
||||
if do_run:
|
||||
res.append(sim.run())
|
||||
else:
|
||||
print('not running')
|
||||
print("not running")
|
||||
res.append(sim)
|
||||
|
||||
except Exception as ex:
|
||||
if args.pdb:
|
||||
from .debugging import post_mortem
|
||||
|
||||
print(traceback.format_exc())
|
||||
post_mortem()
|
||||
else:
|
||||
raise
|
||||
if debug:
|
||||
from .debugging import set_trace
|
||||
os.environ['SOIL_DEBUG'] = 'true'
|
||||
|
||||
os.environ["SOIL_DEBUG"] = "true"
|
||||
set_trace()
|
||||
return res
|
||||
|
||||
@ -165,5 +218,5 @@ def easy(cfg, debug=False, **kwargs):
|
||||
return main(cfg, **kwargs)[0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main(do_run=True)
|
||||
|
@ -1,7 +1,9 @@
|
||||
from . import main as init_main
|
||||
|
||||
|
||||
def main():
|
||||
init_main(do_run=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_main(do_run=True)
|
||||
|
@ -7,6 +7,7 @@ class BassModel(FSM):
|
||||
innovation_prob
|
||||
imitation_prob
|
||||
"""
|
||||
|
||||
sentimentCorrelation = 0
|
||||
|
||||
def step(self):
|
||||
@ -21,7 +22,7 @@ class BassModel(FSM):
|
||||
else:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id)
|
||||
num_neighbors_aware = len(aware_neighbors)
|
||||
if self.prob((self['imitation_prob']*num_neighbors_aware)):
|
||||
if self.prob((self["imitation_prob"] * num_neighbors_aware)):
|
||||
self.sentimentCorrelation = 1
|
||||
return self.aware
|
||||
|
||||
|
@ -20,28 +20,40 @@ class BigMarketModel(FSM):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.enterprises = self.env.environment_params['enterprises']
|
||||
self.enterprises = self.env.environment_params["enterprises"]
|
||||
self.type = ""
|
||||
|
||||
if self.id < len(self.enterprises): # Enterprises
|
||||
self.set_state(self.enterprise.id)
|
||||
self.type = "Enterprise"
|
||||
self.tweet_probability = environment.environment_params['tweet_probability_enterprises'][self.id]
|
||||
self.tweet_probability = environment.environment_params[
|
||||
"tweet_probability_enterprises"
|
||||
][self.id]
|
||||
else: # normal users
|
||||
self.type = "User"
|
||||
self.set_state(self.user.id)
|
||||
self.tweet_probability = environment.environment_params['tweet_probability_users']
|
||||
self.tweet_relevant_probability = environment.environment_params['tweet_relevant_probability']
|
||||
self.tweet_probability_about = environment.environment_params['tweet_probability_about'] # List
|
||||
self.sentiment_about = environment.environment_params['sentiment_about'] # List
|
||||
self.tweet_probability = environment.environment_params[
|
||||
"tweet_probability_users"
|
||||
]
|
||||
self.tweet_relevant_probability = environment.environment_params[
|
||||
"tweet_relevant_probability"
|
||||
]
|
||||
self.tweet_probability_about = environment.environment_params[
|
||||
"tweet_probability_about"
|
||||
] # List
|
||||
self.sentiment_about = environment.environment_params[
|
||||
"sentiment_about"
|
||||
] # List
|
||||
|
||||
@state
|
||||
def enterprise(self):
|
||||
|
||||
if self.random.random() < self.tweet_probability: # Tweets
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises) # Nodes neighbour users
|
||||
aware_neighbors = self.get_neighboring_agents(
|
||||
state_id=self.number_of_enterprises
|
||||
) # Nodes neighbour users
|
||||
for x in aware_neighbors:
|
||||
if self.random.uniform(0,10) < 5:
|
||||
if self.random.uniform(0, 10) < 5:
|
||||
x.sentiment_about[self.id] += 0.1 # Increments for enterprise
|
||||
else:
|
||||
x.sentiment_about[self.id] -= 0.1 # Decrements for enterprise
|
||||
@ -49,15 +61,19 @@ class BigMarketModel(FSM):
|
||||
# Establecemos limites
|
||||
if x.sentiment_about[self.id] > 1:
|
||||
x.sentiment_about[self.id] = 1
|
||||
if x.sentiment_about[self.id]< -1:
|
||||
if x.sentiment_about[self.id] < -1:
|
||||
x.sentiment_about[self.id] = -1
|
||||
|
||||
x.attrs['sentiment_enterprise_%s'% self.enterprises[self.id]] = x.sentiment_about[self.id]
|
||||
x.attrs[
|
||||
"sentiment_enterprise_%s" % self.enterprises[self.id]
|
||||
] = x.sentiment_about[self.id]
|
||||
|
||||
@state
|
||||
def user(self):
|
||||
if self.random.random() < self.tweet_probability: # Tweets
|
||||
if self.random.random() < self.tweet_relevant_probability: # Tweets something relevant
|
||||
if (
|
||||
self.random.random() < self.tweet_relevant_probability
|
||||
): # Tweets something relevant
|
||||
# Tweet probability per enterprise
|
||||
for i in range(len(self.enterprises)):
|
||||
random_num = self.random.random()
|
||||
@ -65,23 +81,29 @@ class BigMarketModel(FSM):
|
||||
# The condition is fulfilled, sentiments are evaluated towards that enterprise
|
||||
if self.sentiment_about[i] < 0:
|
||||
# NEGATIVO
|
||||
self.userTweets("negative",i)
|
||||
self.userTweets("negative", i)
|
||||
elif self.sentiment_about[i] == 0:
|
||||
# NEUTRO
|
||||
pass
|
||||
else:
|
||||
# POSITIVO
|
||||
self.userTweets("positive",i)
|
||||
for i in range(len(self.enterprises)): # So that it never is set to 0 if there are not changes (logs)
|
||||
self.attrs['sentiment_enterprise_%s'% self.enterprises[i]] = self.sentiment_about[i]
|
||||
self.userTweets("positive", i)
|
||||
for i in range(
|
||||
len(self.enterprises)
|
||||
): # So that it never is set to 0 if there are not changes (logs)
|
||||
self.attrs[
|
||||
"sentiment_enterprise_%s" % self.enterprises[i]
|
||||
] = self.sentiment_about[i]
|
||||
|
||||
def userTweets(self, sentiment,enterprise):
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises) # Nodes neighbours users
|
||||
def userTweets(self, sentiment, enterprise):
|
||||
aware_neighbors = self.get_neighboring_agents(
|
||||
state_id=self.number_of_enterprises
|
||||
) # Nodes neighbours users
|
||||
for x in aware_neighbors:
|
||||
if sentiment == "positive":
|
||||
x.sentiment_about[enterprise] +=0.003
|
||||
x.sentiment_about[enterprise] += 0.003
|
||||
elif sentiment == "negative":
|
||||
x.sentiment_about[enterprise] -=0.003
|
||||
x.sentiment_about[enterprise] -= 0.003
|
||||
else:
|
||||
pass
|
||||
|
||||
@ -91,4 +113,6 @@ class BigMarketModel(FSM):
|
||||
if x.sentiment_about[enterprise] < -1:
|
||||
x.sentiment_about[enterprise] = -1
|
||||
|
||||
x.attrs['sentiment_enterprise_%s'% self.enterprises[enterprise]] = x.sentiment_about[enterprise]
|
||||
x.attrs[
|
||||
"sentiment_enterprise_%s" % self.enterprises[enterprise]
|
||||
] = x.sentiment_about[enterprise]
|
||||
|
@ -15,9 +15,9 @@ class CounterModel(NetworkAgent):
|
||||
# Outside effects
|
||||
total = len(list(self.model.schedule._agents))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['times'] = self.get('times', 0) + 1
|
||||
self['neighbors'] = neighbors
|
||||
self['total'] = total
|
||||
self["times"] = self.get("times", 0) + 1
|
||||
self["neighbors"] = neighbors
|
||||
self["total"] = total
|
||||
|
||||
|
||||
class AggregatedCounter(NetworkAgent):
|
||||
@ -32,9 +32,9 @@ class AggregatedCounter(NetworkAgent):
|
||||
|
||||
def step(self):
|
||||
# Outside effects
|
||||
self['times'] += 1
|
||||
self["times"] += 1
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
self['neighbors'] += neighbors
|
||||
self["neighbors"] += neighbors
|
||||
total = len(list(self.model.schedule.agents))
|
||||
self['total'] += total
|
||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
||||
self["total"] += total
|
||||
self.debug("Running for step: {}. Total: {}".format(self.now, total))
|
||||
|
@ -2,20 +2,20 @@ from scipy.spatial import cKDTree as KDTree
|
||||
import networkx as nx
|
||||
from . import NetworkAgent, as_node
|
||||
|
||||
|
||||
class Geo(NetworkAgent):
|
||||
'''In this type of network, nodes have a "pos" attribute.'''
|
||||
"""In this type of network, nodes have a "pos" attribute."""
|
||||
|
||||
def geo_search(self, radius, node=None, center=False, **kwargs):
|
||||
'''Get a list of nodes whose coordinates are closer than *radius* to *node*.'''
|
||||
"""Get a list of nodes whose coordinates are closer than *radius* to *node*."""
|
||||
node = as_node(node if node is not None else self)
|
||||
|
||||
G = self.subgraph(**kwargs)
|
||||
|
||||
pos = nx.get_node_attributes(G, 'pos')
|
||||
pos = nx.get_node_attributes(G, "pos")
|
||||
if not pos:
|
||||
return []
|
||||
nodes, coords = list(zip(*pos.items()))
|
||||
kdtree = KDTree(coords) # Cannot provide generator.
|
||||
indices = kdtree.query_ball_point(pos[node], radius)
|
||||
return [nodes[i] for i in indices if center or (nodes[i] != node)]
|
||||
|
||||
|
@ -11,10 +11,10 @@ class IndependentCascadeModel(BaseAgent):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.innovation_prob = self.env.environment_params['innovation_prob']
|
||||
self.imitation_prob = self.env.environment_params['imitation_prob']
|
||||
self.state['time_awareness'] = 0
|
||||
self.state['sentimentCorrelation'] = 0
|
||||
self.innovation_prob = self.env.environment_params["innovation_prob"]
|
||||
self.imitation_prob = self.env.environment_params["imitation_prob"]
|
||||
self.state["time_awareness"] = 0
|
||||
self.state["sentimentCorrelation"] = 0
|
||||
|
||||
def step(self):
|
||||
self.behaviour()
|
||||
@ -23,25 +23,27 @@ class IndependentCascadeModel(BaseAgent):
|
||||
aware_neighbors_1_time_step = []
|
||||
# Outside effects
|
||||
if self.prob(self.innovation_prob):
|
||||
if self.state['id'] == 0:
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
self.state['time_awareness'] = self.env.now # To know when they have been infected
|
||||
if self.state["id"] == 0:
|
||||
self.state["id"] = 1
|
||||
self.state["sentimentCorrelation"] = 1
|
||||
self.state[
|
||||
"time_awareness"
|
||||
] = self.env.now # To know when they have been infected
|
||||
else:
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
# Imitation effects
|
||||
if self.state['id'] == 0:
|
||||
if self.state["id"] == 0:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
for x in aware_neighbors:
|
||||
if x.state['time_awareness'] == (self.env.now-1):
|
||||
if x.state["time_awareness"] == (self.env.now - 1):
|
||||
aware_neighbors_1_time_step.append(x)
|
||||
num_neighbors_aware = len(aware_neighbors_1_time_step)
|
||||
if self.prob(self.imitation_prob*num_neighbors_aware):
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
if self.prob(self.imitation_prob * num_neighbors_aware):
|
||||
self.state["id"] = 1
|
||||
self.state["sentimentCorrelation"] = 1
|
||||
else:
|
||||
pass
|
||||
|
||||
|
@ -23,36 +23,49 @@ class SpreadModelM2(BaseAgent):
|
||||
def __init__(self, model=None, unique_id=0, state=()):
|
||||
super().__init__(model=environment, unique_id=unique_id, state=state)
|
||||
|
||||
|
||||
# Use a single generator with the same seed as `self.random`
|
||||
random = np.random.default_rng(seed=self._seed)
|
||||
self.prob_neutral_making_denier = random.normal(environment.environment_params['prob_neutral_making_denier'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_neutral_making_denier = random.normal(
|
||||
environment.environment_params["prob_neutral_making_denier"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
self.prob_infect = random.normal(environment.environment_params['prob_infect'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_infect = random.normal(
|
||||
environment.environment_params["prob_infect"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
self.prob_cured_healing_infected = random.normal(environment.environment_params['prob_cured_healing_infected'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_cured_vaccinate_neutral = random.normal(environment.environment_params['prob_cured_vaccinate_neutral'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_cured_healing_infected = random.normal(
|
||||
environment.environment_params["prob_cured_healing_infected"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
self.prob_cured_vaccinate_neutral = random.normal(
|
||||
environment.environment_params["prob_cured_vaccinate_neutral"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
self.prob_vaccinated_healing_infected = random.normal(environment.environment_params['prob_vaccinated_healing_infected'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_vaccinated_vaccinate_neutral = random.normal(environment.environment_params['prob_vaccinated_vaccinate_neutral'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_generate_anti_rumor = random.normal(environment.environment_params['prob_generate_anti_rumor'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_vaccinated_healing_infected = random.normal(
|
||||
environment.environment_params["prob_vaccinated_healing_infected"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
self.prob_vaccinated_vaccinate_neutral = random.normal(
|
||||
environment.environment_params["prob_vaccinated_vaccinate_neutral"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
self.prob_generate_anti_rumor = random.normal(
|
||||
environment.environment_params["prob_generate_anti_rumor"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
def step(self):
|
||||
|
||||
if self.state['id'] == 0: # Neutral
|
||||
if self.state["id"] == 0: # Neutral
|
||||
self.neutral_behaviour()
|
||||
elif self.state['id'] == 1: # Infected
|
||||
elif self.state["id"] == 1: # Infected
|
||||
self.infected_behaviour()
|
||||
elif self.state['id'] == 2: # Cured
|
||||
elif self.state["id"] == 2: # Cured
|
||||
self.cured_behaviour()
|
||||
elif self.state['id'] == 3: # Vaccinated
|
||||
elif self.state["id"] == 3: # Vaccinated
|
||||
self.vaccinated_behaviour()
|
||||
|
||||
def neutral_behaviour(self):
|
||||
@ -61,7 +74,7 @@ class SpreadModelM2(BaseAgent):
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
if len(infected_neighbors) > 0:
|
||||
if self.prob(self.prob_neutral_making_denier):
|
||||
self.state['id'] = 3 # Vaccinated making denier
|
||||
self.state["id"] = 3 # Vaccinated making denier
|
||||
|
||||
def infected_behaviour(self):
|
||||
|
||||
@ -69,7 +82,7 @@ class SpreadModelM2(BaseAgent):
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_infect):
|
||||
neighbor.state['id'] = 1 # Infected
|
||||
neighbor.state["id"] = 1 # Infected
|
||||
|
||||
def cured_behaviour(self):
|
||||
|
||||
@ -77,13 +90,13 @@ class SpreadModelM2(BaseAgent):
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state['id'] = 3 # Vaccinated
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Cure
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
def vaccinated_behaviour(self):
|
||||
|
||||
@ -91,19 +104,19 @@ class SpreadModelM2(BaseAgent):
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state['id'] = 3 # Vaccinated
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Generate anti-rumor
|
||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors_2:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
|
||||
class ControlModelM2(BaseAgent):
|
||||
@ -124,51 +137,64 @@ class ControlModelM2(BaseAgent):
|
||||
prob_generate_anti_rumor
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, model=None, unique_id=0, state=()):
|
||||
super().__init__(model=environment, unique_id=unique_id, state=state)
|
||||
|
||||
self.prob_neutral_making_denier = np.random.normal(environment.environment_params['prob_neutral_making_denier'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_neutral_making_denier = np.random.normal(
|
||||
environment.environment_params["prob_neutral_making_denier"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
self.prob_infect = np.random.normal(environment.environment_params['prob_infect'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_infect = np.random.normal(
|
||||
environment.environment_params["prob_infect"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
self.prob_cured_healing_infected = np.random.normal(environment.environment_params['prob_cured_healing_infected'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_cured_vaccinate_neutral = np.random.normal(environment.environment_params['prob_cured_vaccinate_neutral'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_cured_healing_infected = np.random.normal(
|
||||
environment.environment_params["prob_cured_healing_infected"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
self.prob_cured_vaccinate_neutral = np.random.normal(
|
||||
environment.environment_params["prob_cured_vaccinate_neutral"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
self.prob_vaccinated_healing_infected = np.random.normal(environment.environment_params['prob_vaccinated_healing_infected'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_vaccinated_vaccinate_neutral = np.random.normal(environment.environment_params['prob_vaccinated_vaccinate_neutral'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_generate_anti_rumor = np.random.normal(environment.environment_params['prob_generate_anti_rumor'],
|
||||
environment.environment_params['standard_variance'])
|
||||
self.prob_vaccinated_healing_infected = np.random.normal(
|
||||
environment.environment_params["prob_vaccinated_healing_infected"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
self.prob_vaccinated_vaccinate_neutral = np.random.normal(
|
||||
environment.environment_params["prob_vaccinated_vaccinate_neutral"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
self.prob_generate_anti_rumor = np.random.normal(
|
||||
environment.environment_params["prob_generate_anti_rumor"],
|
||||
environment.environment_params["standard_variance"],
|
||||
)
|
||||
|
||||
def step(self):
|
||||
|
||||
if self.state['id'] == 0: # Neutral
|
||||
if self.state["id"] == 0: # Neutral
|
||||
self.neutral_behaviour()
|
||||
elif self.state['id'] == 1: # Infected
|
||||
elif self.state["id"] == 1: # Infected
|
||||
self.infected_behaviour()
|
||||
elif self.state['id'] == 2: # Cured
|
||||
elif self.state["id"] == 2: # Cured
|
||||
self.cured_behaviour()
|
||||
elif self.state['id'] == 3: # Vaccinated
|
||||
elif self.state["id"] == 3: # Vaccinated
|
||||
self.vaccinated_behaviour()
|
||||
elif self.state['id'] == 4: # Beacon-off
|
||||
elif self.state["id"] == 4: # Beacon-off
|
||||
self.beacon_off_behaviour()
|
||||
elif self.state['id'] == 5: # Beacon-on
|
||||
elif self.state["id"] == 5: # Beacon-on
|
||||
self.beacon_on_behaviour()
|
||||
|
||||
def neutral_behaviour(self):
|
||||
self.state['visible'] = False
|
||||
self.state["visible"] = False
|
||||
|
||||
# Infected
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
if len(infected_neighbors) > 0:
|
||||
if self.random(self.prob_neutral_making_denier):
|
||||
self.state['id'] = 3 # Vaccinated making denier
|
||||
self.state["id"] = 3 # Vaccinated making denier
|
||||
|
||||
def infected_behaviour(self):
|
||||
|
||||
@ -176,69 +202,69 @@ class ControlModelM2(BaseAgent):
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_infect):
|
||||
neighbor.state['id'] = 1 # Infected
|
||||
self.state['visible'] = False
|
||||
neighbor.state["id"] = 1 # Infected
|
||||
self.state["visible"] = False
|
||||
|
||||
def cured_behaviour(self):
|
||||
|
||||
self.state['visible'] = True
|
||||
self.state["visible"] = True
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state['id'] = 3 # Vaccinated
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Cure
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
def vaccinated_behaviour(self):
|
||||
self.state['visible'] = True
|
||||
self.state["visible"] = True
|
||||
|
||||
# Cure
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state['id'] = 3 # Vaccinated
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Generate anti-rumor
|
||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors_2:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
def beacon_off_behaviour(self):
|
||||
self.state['visible'] = False
|
||||
self.state["visible"] = False
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
if len(infected_neighbors) > 0:
|
||||
self.state['id'] == 5 # Beacon on
|
||||
self.state["id"] == 5 # Beacon on
|
||||
|
||||
def beacon_on_behaviour(self):
|
||||
self.state['visible'] = False
|
||||
self.state["visible"] = False
|
||||
# Cure (M2 feature added)
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
neutral_neighbors_infected = neighbor.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors_infected:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state['id'] = 3 # Vaccinated
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
infected_neighbors_infected = neighbor.get_neighboring_agents(state_id=1)
|
||||
for neighbor in infected_neighbors_infected:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state['id'] = 2 # Cured
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state['id'] = 3 # Vaccinated
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
@ -33,24 +33,32 @@ class SISaModel(FSM):
|
||||
|
||||
random = np.random.default_rng(seed=self._seed)
|
||||
|
||||
self.neutral_discontent_spon_prob = random.normal(self.env['neutral_discontent_spon_prob'],
|
||||
self.env['standard_variance'])
|
||||
self.neutral_discontent_infected_prob = random.normal(self.env['neutral_discontent_infected_prob'],
|
||||
self.env['standard_variance'])
|
||||
self.neutral_content_spon_prob = random.normal(self.env['neutral_content_spon_prob'],
|
||||
self.env['standard_variance'])
|
||||
self.neutral_content_infected_prob = random.normal(self.env['neutral_content_infected_prob'],
|
||||
self.env['standard_variance'])
|
||||
self.neutral_discontent_spon_prob = random.normal(
|
||||
self.env["neutral_discontent_spon_prob"], self.env["standard_variance"]
|
||||
)
|
||||
self.neutral_discontent_infected_prob = random.normal(
|
||||
self.env["neutral_discontent_infected_prob"], self.env["standard_variance"]
|
||||
)
|
||||
self.neutral_content_spon_prob = random.normal(
|
||||
self.env["neutral_content_spon_prob"], self.env["standard_variance"]
|
||||
)
|
||||
self.neutral_content_infected_prob = random.normal(
|
||||
self.env["neutral_content_infected_prob"], self.env["standard_variance"]
|
||||
)
|
||||
|
||||
self.discontent_neutral = random.normal(self.env['discontent_neutral'],
|
||||
self.env['standard_variance'])
|
||||
self.discontent_content = random.normal(self.env['discontent_content'],
|
||||
self.env['variance_d_c'])
|
||||
self.discontent_neutral = random.normal(
|
||||
self.env["discontent_neutral"], self.env["standard_variance"]
|
||||
)
|
||||
self.discontent_content = random.normal(
|
||||
self.env["discontent_content"], self.env["variance_d_c"]
|
||||
)
|
||||
|
||||
self.content_discontent = random.normal(self.env['content_discontent'],
|
||||
self.env['variance_c_d'])
|
||||
self.content_neutral = random.normal(self.env['content_neutral'],
|
||||
self.env['standard_variance'])
|
||||
self.content_discontent = random.normal(
|
||||
self.env["content_discontent"], self.env["variance_c_d"]
|
||||
)
|
||||
self.content_neutral = random.normal(
|
||||
self.env["content_neutral"], self.env["standard_variance"]
|
||||
)
|
||||
|
||||
@state
|
||||
def neutral(self):
|
||||
@ -88,7 +96,9 @@ class SISaModel(FSM):
|
||||
return self.neutral
|
||||
|
||||
# Superinfected
|
||||
discontent_neighbors = self.count_neighboring_agents(state_id=self.discontent.id)
|
||||
discontent_neighbors = self.count_neighboring_agents(
|
||||
state_id=self.discontent.id
|
||||
)
|
||||
if self.prob(scontent_neighbors * self.content_discontent):
|
||||
self.discontent
|
||||
return self.content
|
||||
|
@ -17,15 +17,19 @@ class SentimentCorrelationModel(BaseAgent):
|
||||
|
||||
def __init__(self, environment, unique_id=0, state=()):
|
||||
super().__init__(model=environment, unique_id=unique_id, state=state)
|
||||
self.outside_effects_prob = environment.environment_params['outside_effects_prob']
|
||||
self.anger_prob = environment.environment_params['anger_prob']
|
||||
self.joy_prob = environment.environment_params['joy_prob']
|
||||
self.sadness_prob = environment.environment_params['sadness_prob']
|
||||
self.disgust_prob = environment.environment_params['disgust_prob']
|
||||
self.state['time_awareness'] = []
|
||||
self.outside_effects_prob = environment.environment_params[
|
||||
"outside_effects_prob"
|
||||
]
|
||||
self.anger_prob = environment.environment_params["anger_prob"]
|
||||
self.joy_prob = environment.environment_params["joy_prob"]
|
||||
self.sadness_prob = environment.environment_params["sadness_prob"]
|
||||
self.disgust_prob = environment.environment_params["disgust_prob"]
|
||||
self.state["time_awareness"] = []
|
||||
for i in range(4): # In this model we have 4 sentiments
|
||||
self.state['time_awareness'].append(0) # 0-> Anger, 1-> joy, 2->sadness, 3 -> disgust
|
||||
self.state['sentimentCorrelation'] = 0
|
||||
self.state["time_awareness"].append(
|
||||
0
|
||||
) # 0-> Anger, 1-> joy, 2->sadness, 3 -> disgust
|
||||
self.state["sentimentCorrelation"] = 0
|
||||
|
||||
def step(self):
|
||||
self.behaviour()
|
||||
@ -39,63 +43,73 @@ class SentimentCorrelationModel(BaseAgent):
|
||||
|
||||
angry_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
for x in angry_neighbors:
|
||||
if x.state['time_awareness'][0] > (self.env.now-500):
|
||||
if x.state["time_awareness"][0] > (self.env.now - 500):
|
||||
angry_neighbors_1_time_step.append(x)
|
||||
num_neighbors_angry = len(angry_neighbors_1_time_step)
|
||||
|
||||
joyful_neighbors = self.get_neighboring_agents(state_id=2)
|
||||
for x in joyful_neighbors:
|
||||
if x.state['time_awareness'][1] > (self.env.now-500):
|
||||
if x.state["time_awareness"][1] > (self.env.now - 500):
|
||||
joyful_neighbors_1_time_step.append(x)
|
||||
num_neighbors_joyful = len(joyful_neighbors_1_time_step)
|
||||
|
||||
sad_neighbors = self.get_neighboring_agents(state_id=3)
|
||||
for x in sad_neighbors:
|
||||
if x.state['time_awareness'][2] > (self.env.now-500):
|
||||
if x.state["time_awareness"][2] > (self.env.now - 500):
|
||||
sad_neighbors_1_time_step.append(x)
|
||||
num_neighbors_sad = len(sad_neighbors_1_time_step)
|
||||
|
||||
disgusted_neighbors = self.get_neighboring_agents(state_id=4)
|
||||
for x in disgusted_neighbors:
|
||||
if x.state['time_awareness'][3] > (self.env.now-500):
|
||||
if x.state["time_awareness"][3] > (self.env.now - 500):
|
||||
disgusted_neighbors_1_time_step.append(x)
|
||||
num_neighbors_disgusted = len(disgusted_neighbors_1_time_step)
|
||||
|
||||
anger_prob = self.anger_prob+(len(angry_neighbors_1_time_step)*self.anger_prob)
|
||||
joy_prob = self.joy_prob+(len(joyful_neighbors_1_time_step)*self.joy_prob)
|
||||
sadness_prob = self.sadness_prob+(len(sad_neighbors_1_time_step)*self.sadness_prob)
|
||||
disgust_prob = self.disgust_prob+(len(disgusted_neighbors_1_time_step)*self.disgust_prob)
|
||||
anger_prob = self.anger_prob + (
|
||||
len(angry_neighbors_1_time_step) * self.anger_prob
|
||||
)
|
||||
joy_prob = self.joy_prob + (len(joyful_neighbors_1_time_step) * self.joy_prob)
|
||||
sadness_prob = self.sadness_prob + (
|
||||
len(sad_neighbors_1_time_step) * self.sadness_prob
|
||||
)
|
||||
disgust_prob = self.disgust_prob + (
|
||||
len(disgusted_neighbors_1_time_step) * self.disgust_prob
|
||||
)
|
||||
outside_effects_prob = self.outside_effects_prob
|
||||
|
||||
num = self.random.random()
|
||||
|
||||
if num<outside_effects_prob:
|
||||
self.state['id'] = self.random.randint(1, 4)
|
||||
if num < outside_effects_prob:
|
||||
self.state["id"] = self.random.randint(1, 4)
|
||||
|
||||
self.state['sentimentCorrelation'] = self.state['id'] # It is stored when it has been infected for the dynamic network
|
||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
||||
self.state['sentiment'] = self.state['id']
|
||||
self.state["sentimentCorrelation"] = self.state[
|
||||
"id"
|
||||
] # It is stored when it has been infected for the dynamic network
|
||||
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||
self.state["sentiment"] = self.state["id"]
|
||||
|
||||
if num < anger_prob:
|
||||
|
||||
if(num<anger_prob):
|
||||
self.state["id"] = 1
|
||||
self.state["sentimentCorrelation"] = 1
|
||||
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||
elif num < joy_prob + anger_prob and num > anger_prob:
|
||||
|
||||
self.state['id'] = 1
|
||||
self.state['sentimentCorrelation'] = 1
|
||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
||||
elif (num<joy_prob+anger_prob and num>anger_prob):
|
||||
self.state["id"] = 2
|
||||
self.state["sentimentCorrelation"] = 2
|
||||
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||
elif num < sadness_prob + anger_prob + joy_prob and num > joy_prob + anger_prob:
|
||||
|
||||
self.state['id'] = 2
|
||||
self.state['sentimentCorrelation'] = 2
|
||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
||||
elif (num<sadness_prob+anger_prob+joy_prob and num>joy_prob+anger_prob):
|
||||
self.state["id"] = 3
|
||||
self.state["sentimentCorrelation"] = 3
|
||||
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||
elif (
|
||||
num < disgust_prob + sadness_prob + anger_prob + joy_prob
|
||||
and num > sadness_prob + anger_prob + joy_prob
|
||||
):
|
||||
|
||||
self.state['id'] = 3
|
||||
self.state['sentimentCorrelation'] = 3
|
||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
||||
elif (num<disgust_prob+sadness_prob+anger_prob+joy_prob and num>sadness_prob+anger_prob+joy_prob):
|
||||
self.state["id"] = 4
|
||||
self.state["sentimentCorrelation"] = 4
|
||||
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||
|
||||
self.state['id'] = 4
|
||||
self.state['sentimentCorrelation'] = 4
|
||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
||||
|
||||
self.state['sentiment'] = self.state['id']
|
||||
self.state["sentiment"] = self.state["id"]
|
||||
|
@ -20,13 +20,13 @@ from typing import Dict, List
|
||||
from .. import serialization, utils, time, config
|
||||
|
||||
|
||||
|
||||
def as_node(agent):
|
||||
if isinstance(agent, BaseAgent):
|
||||
return agent.id
|
||||
return agent
|
||||
|
||||
IGNORED_FIELDS = ('model', 'logger')
|
||||
|
||||
IGNORED_FIELDS = ("model", "logger")
|
||||
|
||||
|
||||
class DeadAgent(Exception):
|
||||
@ -43,13 +43,18 @@ class MetaAgent(ABCMeta):
|
||||
defaults.update(i._defaults)
|
||||
|
||||
new_nmspc = {
|
||||
'_defaults': defaults,
|
||||
"_defaults": defaults,
|
||||
}
|
||||
|
||||
for attr, func in namespace.items():
|
||||
if isinstance(func, types.FunctionType) or isinstance(func, property) or isinstance(func, classmethod) or attr[0] == '_':
|
||||
if (
|
||||
isinstance(func, types.FunctionType)
|
||||
or isinstance(func, property)
|
||||
or isinstance(func, classmethod)
|
||||
or attr[0] == "_"
|
||||
):
|
||||
new_nmspc[attr] = func
|
||||
elif attr == 'defaults':
|
||||
elif attr == "defaults":
|
||||
defaults.update(func)
|
||||
else:
|
||||
defaults[attr] = copy(func)
|
||||
@ -69,12 +74,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
unique_id,
|
||||
model,
|
||||
name=None,
|
||||
interval=None,
|
||||
**kwargs):
|
||||
def __init__(self, unique_id, model, name=None, interval=None, **kwargs):
|
||||
# Check for REQUIRED arguments
|
||||
# Initialize agent parameters
|
||||
if isinstance(unique_id, MesaAgent):
|
||||
@ -82,16 +82,19 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
assert isinstance(unique_id, int)
|
||||
super().__init__(unique_id=unique_id, model=model)
|
||||
|
||||
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
||||
|
||||
self.name = (
|
||||
str(name) if name else "{}[{}]".format(type(self).__name__, self.unique_id)
|
||||
)
|
||||
|
||||
self.alive = True
|
||||
|
||||
self.interval = interval or self.get('interval', 1)
|
||||
logger = utils.logger.getChild(getattr(self.model, 'id', self.model)).getChild(self.name)
|
||||
self.logger = logging.LoggerAdapter(logger, {'agent_name': self.name})
|
||||
self.interval = interval or self.get("interval", 1)
|
||||
logger = utils.logger.getChild(getattr(self.model, "id", self.model)).getChild(
|
||||
self.name
|
||||
)
|
||||
self.logger = logging.LoggerAdapter(logger, {"agent_name": self.name})
|
||||
|
||||
if hasattr(self, 'level'):
|
||||
if hasattr(self, "level"):
|
||||
self.logger.setLevel(self.level)
|
||||
|
||||
for (k, v) in self._defaults.items():
|
||||
@ -117,20 +120,22 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
def from_dict(cls, model, attrs, warn_extra=True):
|
||||
ignored = {}
|
||||
args = {}
|
||||
for k, v in attrs.items():
|
||||
for k, v in attrs.items():
|
||||
if k in inspect.signature(cls).parameters:
|
||||
args[k] = v
|
||||
else:
|
||||
ignored[k] = v
|
||||
if ignored and warn_extra:
|
||||
utils.logger.info(f'Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }')
|
||||
utils.logger.info(
|
||||
f"Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }"
|
||||
)
|
||||
return cls(model=model, **args)
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
raise KeyError(f'key {key} not found in agent')
|
||||
raise KeyError(f"key {key} not found in agent")
|
||||
|
||||
def __delitem__(self, key):
|
||||
return delattr(self, key)
|
||||
@ -148,7 +153,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
return self.items()
|
||||
|
||||
def keys(self):
|
||||
return (k for k in self.__dict__ if k[0] != '_' and k not in IGNORED_FIELDS)
|
||||
return (k for k in self.__dict__ if k[0] != "_" and k not in IGNORED_FIELDS)
|
||||
|
||||
def items(self, keys=None, skip=None):
|
||||
keys = keys if keys is not None else self.keys()
|
||||
@ -169,7 +174,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
return None
|
||||
|
||||
def die(self):
|
||||
self.info(f'agent dying')
|
||||
self.info(f"agent dying")
|
||||
self.alive = False
|
||||
return time.NEVER
|
||||
|
||||
@ -186,9 +191,9 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
for k, v in kwargs:
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['unique_id'] = self.unique_id
|
||||
extra['agent_name'] = self.name
|
||||
extra["now"] = self.now
|
||||
extra["unique_id"] = self.unique_id
|
||||
extra["agent_name"] = self.name
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
def debug(self, *args, **kwargs):
|
||||
@ -214,10 +219,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
content = dict(self.items(keys=keys))
|
||||
if pretty and content:
|
||||
d = content
|
||||
content = '\n'
|
||||
content = "\n"
|
||||
for k, v in d.items():
|
||||
content += f'- {k}: {v}\n'
|
||||
content = textwrap.indent(content, ' ')
|
||||
content += f"- {k}: {v}\n"
|
||||
content = textwrap.indent(content, " ")
|
||||
return f"{repr(self)}{content}"
|
||||
|
||||
def __repr__(self):
|
||||
@ -225,7 +230,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
|
||||
def __init__(self, *args, topology, node_id, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@ -248,18 +252,21 @@ class NetworkAgent(BaseAgent):
|
||||
def node(self):
|
||||
return self.topology.nodes[self.node_id]
|
||||
|
||||
|
||||
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
||||
unique_ids = None
|
||||
if isinstance(unique_id, list):
|
||||
unique_ids = set(unique_id)
|
||||
elif unique_id is not None:
|
||||
unique_ids = set([unique_id,])
|
||||
unique_ids = set(
|
||||
[
|
||||
unique_id,
|
||||
]
|
||||
)
|
||||
|
||||
if limit_neighbors:
|
||||
neighbor_ids = set()
|
||||
for node_id in self.G.neighbors(self.node_id):
|
||||
if self.G.nodes[node_id].get('agent') is not None:
|
||||
if self.G.nodes[node_id].get("agent") is not None:
|
||||
neighbor_ids.add(node_id)
|
||||
if unique_ids:
|
||||
unique_ids = unique_ids & neighbor_ids
|
||||
@ -272,7 +279,9 @@ class NetworkAgent(BaseAgent):
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
G = self.G.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
||||
G = self.G.subgraph(
|
||||
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
||||
)
|
||||
return G
|
||||
|
||||
def remove_node(self):
|
||||
@ -280,11 +289,19 @@ class NetworkAgent(BaseAgent):
|
||||
|
||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||
if self.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(
|
||||
self.unique_id
|
||||
)
|
||||
)
|
||||
if other.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(other)
|
||||
)
|
||||
|
||||
self.G.add_edge(self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
||||
self.G.add_edge(
|
||||
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
||||
)
|
||||
|
||||
def die(self, remove=True):
|
||||
if remove:
|
||||
@ -294,11 +311,11 @@ class NetworkAgent(BaseAgent):
|
||||
|
||||
def state(name=None):
|
||||
def decorator(func, name=None):
|
||||
'''
|
||||
"""
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the environment.
|
||||
'''
|
||||
"""
|
||||
if inspect.isgeneratorfunction(func):
|
||||
orig_func = func
|
||||
|
||||
@ -348,32 +365,38 @@ class MetaFSM(MetaAgent):
|
||||
|
||||
# Add new states
|
||||
for attr, func in namespace.items():
|
||||
if hasattr(func, 'id'):
|
||||
if hasattr(func, "id"):
|
||||
if func.is_default:
|
||||
default_state = func
|
||||
states[func.id] = func
|
||||
|
||||
namespace.update({
|
||||
'_default_state': default_state,
|
||||
'_states': states,
|
||||
})
|
||||
namespace.update(
|
||||
{
|
||||
"_default_state": default_state,
|
||||
"_states": states,
|
||||
}
|
||||
)
|
||||
|
||||
return super(MetaFSM, mcls).__new__(mcls=mcls, name=name, bases=bases, namespace=namespace)
|
||||
return super(MetaFSM, mcls).__new__(
|
||||
mcls=mcls, name=name, bases=bases, namespace=namespace
|
||||
)
|
||||
|
||||
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(FSM, self).__init__(*args, **kwargs)
|
||||
if not hasattr(self, 'state_id'):
|
||||
if not hasattr(self, "state_id"):
|
||||
if not self._default_state:
|
||||
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
||||
raise ValueError(
|
||||
"No default state specified for {}".format(self.unique_id)
|
||||
)
|
||||
self.state_id = self._default_state.id
|
||||
|
||||
self._coroutine = None
|
||||
self.set_state(self.state_id)
|
||||
|
||||
def step(self):
|
||||
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
||||
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
||||
default_interval = super().step()
|
||||
|
||||
next_state = self._states[self.state_id](self)
|
||||
@ -386,7 +409,9 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
elif len(when) == 1:
|
||||
when = when[0]
|
||||
else:
|
||||
raise ValueError('Too many values returned. Only state (and time) allowed')
|
||||
raise ValueError(
|
||||
"Too many values returned. Only state (and time) allowed"
|
||||
)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
@ -396,10 +421,10 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
return when or default_interval
|
||||
|
||||
def set_state(self, state, when=None):
|
||||
if hasattr(state, 'id'):
|
||||
if hasattr(state, "id"):
|
||||
state = state.id
|
||||
if state not in self._states:
|
||||
raise ValueError('{} is not a valid state'.format(state))
|
||||
raise ValueError("{} is not a valid state".format(state))
|
||||
self.state_id = state
|
||||
if when is not None:
|
||||
self.model.schedule.add(self, when=when)
|
||||
@ -414,7 +439,7 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
|
||||
|
||||
def prob(prob, random):
|
||||
'''
|
||||
"""
|
||||
A true/False uniform distribution with a given probability.
|
||||
To be used like this:
|
||||
|
||||
@ -423,14 +448,13 @@ def prob(prob, random):
|
||||
if prob(0.3):
|
||||
do_something()
|
||||
|
||||
'''
|
||||
"""
|
||||
r = random.random()
|
||||
return r < prob
|
||||
|
||||
|
||||
def calculate_distribution(network_agents=None,
|
||||
agent_class=None):
|
||||
'''
|
||||
def calculate_distribution(network_agents=None, agent_class=None):
|
||||
"""
|
||||
Calculate the threshold values (thresholds for a uniform distribution)
|
||||
of an agent distribution given the weights of each agent type.
|
||||
|
||||
@ -453,26 +477,28 @@ def calculate_distribution(network_agents=None,
|
||||
|
||||
In this example, 20% of the nodes will be marked as type
|
||||
'agent_class_1'.
|
||||
'''
|
||||
"""
|
||||
if network_agents:
|
||||
network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')]
|
||||
network_agents = [
|
||||
deepcopy(agent) for agent in network_agents if not hasattr(agent, "id")
|
||||
]
|
||||
elif agent_class:
|
||||
network_agents = [{'agent_class': agent_class}]
|
||||
network_agents = [{"agent_class": agent_class}]
|
||||
else:
|
||||
raise ValueError('Specify a distribution or a default agent type')
|
||||
raise ValueError("Specify a distribution or a default agent type")
|
||||
|
||||
# Fix missing weights and incompatible types
|
||||
for x in network_agents:
|
||||
x['weight'] = float(x.get('weight', 1))
|
||||
x["weight"] = float(x.get("weight", 1))
|
||||
|
||||
# Calculate the thresholds
|
||||
total = sum(x['weight'] for x in network_agents)
|
||||
total = sum(x["weight"] for x in network_agents)
|
||||
acc = 0
|
||||
for v in network_agents:
|
||||
if 'ids' in v:
|
||||
if "ids" in v:
|
||||
continue
|
||||
upper = acc + (v['weight']/total)
|
||||
v['threshold'] = [acc, upper]
|
||||
upper = acc + (v["weight"] / total)
|
||||
v["threshold"] = [acc, upper]
|
||||
acc = upper
|
||||
return network_agents
|
||||
|
||||
@ -480,28 +506,29 @@ def calculate_distribution(network_agents=None,
|
||||
def serialize_type(agent_class, known_modules=[], **kwargs):
|
||||
if isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known_modules += ['soil.agents']
|
||||
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
||||
known_modules += ["soil.agents"]
|
||||
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[
|
||||
1
|
||||
] # Get the name of the class
|
||||
|
||||
|
||||
def serialize_definition(network_agents, known_modules=[]):
|
||||
'''
|
||||
"""
|
||||
When serializing an agent distribution, remove the thresholds, in order
|
||||
to avoid cluttering the YAML definition file.
|
||||
'''
|
||||
"""
|
||||
d = deepcopy(list(network_agents))
|
||||
for v in d:
|
||||
if 'threshold' in v:
|
||||
del v['threshold']
|
||||
v['agent_class'] = serialize_type(v['agent_class'],
|
||||
known_modules=known_modules)
|
||||
if "threshold" in v:
|
||||
del v["threshold"]
|
||||
v["agent_class"] = serialize_type(v["agent_class"], known_modules=known_modules)
|
||||
return d
|
||||
|
||||
|
||||
def deserialize_type(agent_class, known_modules=[]):
|
||||
if not isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
||||
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
||||
agent_class = serialization.deserializer(agent_class, known_modules=known)
|
||||
return agent_class
|
||||
|
||||
@ -509,12 +536,12 @@ def deserialize_type(agent_class, known_modules=[]):
|
||||
def deserialize_definition(ind, **kwargs):
|
||||
d = deepcopy(ind)
|
||||
for v in d:
|
||||
v['agent_class'] = deserialize_type(v['agent_class'], **kwargs)
|
||||
v["agent_class"] = deserialize_type(v["agent_class"], **kwargs)
|
||||
return d
|
||||
|
||||
|
||||
def _validate_states(states, topology):
|
||||
'''Validate states to avoid ignoring states during initialization'''
|
||||
"""Validate states to avoid ignoring states during initialization"""
|
||||
states = states or []
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
@ -525,7 +552,7 @@ def _validate_states(states, topology):
|
||||
|
||||
|
||||
def _convert_agent_classs(ind, to_string=False, **kwargs):
|
||||
'''Convenience method to allow specifying agents by class or class name.'''
|
||||
"""Convenience method to allow specifying agents by class or class name."""
|
||||
if to_string:
|
||||
return serialize_definition(ind, **kwargs)
|
||||
return deserialize_definition(ind, **kwargs)
|
||||
@ -609,12 +636,10 @@ def _convert_agent_classs(ind, to_string=False, **kwargs):
|
||||
|
||||
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents.
|
||||
"""
|
||||
"""A lazy-loaded list of agents."""
|
||||
|
||||
__slots__ = ("_agents",)
|
||||
|
||||
|
||||
def __init__(self, agents):
|
||||
self._agents = agents
|
||||
|
||||
@ -657,11 +682,20 @@ class AgentView(Mapping, Set):
|
||||
return f"{self.__class__.__name__}({self})"
|
||||
|
||||
|
||||
def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None,
|
||||
limit=None, **kwargs):
|
||||
'''
|
||||
def filter_agents(
|
||||
agents,
|
||||
*id_args,
|
||||
unique_id=None,
|
||||
state_id=None,
|
||||
agent_class=None,
|
||||
ignore=None,
|
||||
state=None,
|
||||
limit=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||
'''
|
||||
"""
|
||||
assert isinstance(agents, dict)
|
||||
|
||||
ids = []
|
||||
@ -694,7 +728,7 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N
|
||||
f = filter(lambda x: x not in ignore, f)
|
||||
|
||||
if state_id is not None:
|
||||
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
||||
f = filter(lambda agent: agent.get("state_id", None) in state_id, f)
|
||||
|
||||
if agent_class is not None:
|
||||
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
||||
@ -711,23 +745,25 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N
|
||||
yield from f
|
||||
|
||||
|
||||
def from_config(cfg: config.AgentConfig, random, topology: nx.Graph = None) -> List[Dict[str, Any]]:
|
||||
'''
|
||||
def from_config(
|
||||
cfg: config.AgentConfig, random, topology: nx.Graph = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
||||
with the parameters that the environment will use to construct each agent.
|
||||
|
||||
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
||||
time of calling this function, such as `unique_id`.
|
||||
'''
|
||||
"""
|
||||
default = cfg or config.AgentConfig()
|
||||
if not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
return _agents_from_config(cfg, topology=topology, random=random)
|
||||
|
||||
|
||||
def _agents_from_config(cfg: config.AgentConfig,
|
||||
topology: nx.Graph,
|
||||
random) -> List[Dict[str, Any]]:
|
||||
def _agents_from_config(
|
||||
cfg: config.AgentConfig, topology: nx.Graph, random
|
||||
) -> List[Dict[str, Any]]:
|
||||
if cfg and not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
|
||||
@ -737,7 +773,9 @@ def _agents_from_config(cfg: config.AgentConfig,
|
||||
assigned_network = 0
|
||||
|
||||
if cfg.fixed is not None:
|
||||
agents, assigned_total, assigned_network = _from_fixed(cfg.fixed, topology=cfg.topology, default=cfg)
|
||||
agents, assigned_total, assigned_network = _from_fixed(
|
||||
cfg.fixed, topology=cfg.topology, default=cfg
|
||||
)
|
||||
|
||||
n = cfg.n
|
||||
|
||||
@ -749,46 +787,56 @@ def _agents_from_config(cfg: config.AgentConfig,
|
||||
|
||||
for d in cfg.distribution:
|
||||
if d.strategy == config.Strategy.topology:
|
||||
topo = d.topology if ('topology' in d.__fields_set__) else cfg.topology
|
||||
topo = d.topology if ("topology" in d.__fields_set__) else cfg.topology
|
||||
if not topo:
|
||||
raise ValueError('The "topology" strategy only works if the topology parameter is set to True')
|
||||
raise ValueError(
|
||||
'The "topology" strategy only works if the topology parameter is set to True'
|
||||
)
|
||||
if not topo_size:
|
||||
raise ValueError(f'Topology does not have enough free nodes to assign one to the agent')
|
||||
raise ValueError(
|
||||
f"Topology does not have enough free nodes to assign one to the agent"
|
||||
)
|
||||
|
||||
networked.append(d)
|
||||
|
||||
if d.strategy == config.Strategy.total:
|
||||
if not cfg.n:
|
||||
raise ValueError('Cannot use the "total" strategy without providing the total number of agents')
|
||||
raise ValueError(
|
||||
'Cannot use the "total" strategy without providing the total number of agents'
|
||||
)
|
||||
total.append(d)
|
||||
|
||||
|
||||
if networked:
|
||||
new_agents = _from_distro(networked,
|
||||
n= topo_size - assigned_network,
|
||||
topology=topo,
|
||||
default=cfg,
|
||||
random=random)
|
||||
new_agents = _from_distro(
|
||||
networked,
|
||||
n=topo_size - assigned_network,
|
||||
topology=topo,
|
||||
default=cfg,
|
||||
random=random,
|
||||
)
|
||||
assigned_total += len(new_agents)
|
||||
assigned_network += len(new_agents)
|
||||
agents += new_agents
|
||||
|
||||
if total:
|
||||
remaining = n - assigned_total
|
||||
agents += _from_distro(total, n=remaining,
|
||||
default=cfg,
|
||||
random=random)
|
||||
|
||||
remaining = n - assigned_total
|
||||
agents += _from_distro(total, n=remaining, default=cfg, random=random)
|
||||
|
||||
if assigned_network < topo_size:
|
||||
utils.logger.warn(f'The total number of agents does not match the total number of nodes in '
|
||||
'every topology. This may be due to a definition error: assigned: '
|
||||
f'{ assigned } total size: { topo_size }')
|
||||
utils.logger.warn(
|
||||
f"The total number of agents does not match the total number of nodes in "
|
||||
"every topology. This may be due to a definition error: assigned: "
|
||||
f"{ assigned } total size: { topo_size }"
|
||||
)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: config.SingleAgentConfig) -> List[Dict[str, Any]]:
|
||||
def _from_fixed(
|
||||
lst: List[config.FixedAgentConfig],
|
||||
topology: bool,
|
||||
default: config.SingleAgentConfig,
|
||||
) -> List[Dict[str, Any]]:
|
||||
agents = []
|
||||
|
||||
counts_total = 0
|
||||
@ -799,12 +847,18 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con
|
||||
if default:
|
||||
agent = default.state.copy()
|
||||
agent.update(fixed.state)
|
||||
cls = serialization.deserialize(fixed.agent_class or (default and default.agent_class))
|
||||
agent['agent_class'] = cls
|
||||
topo = fixed.topology if ('topology' in fixed.__fields_set__) else topology or default.topology
|
||||
cls = serialization.deserialize(
|
||||
fixed.agent_class or (default and default.agent_class)
|
||||
)
|
||||
agent["agent_class"] = cls
|
||||
topo = (
|
||||
fixed.topology
|
||||
if ("topology" in fixed.__fields_set__)
|
||||
else topology or default.topology
|
||||
)
|
||||
|
||||
if topo:
|
||||
agent['topology'] = True
|
||||
agent["topology"] = True
|
||||
counts_network += 1
|
||||
if not fixed.hidden:
|
||||
counts_total += 1
|
||||
@ -813,17 +867,21 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con
|
||||
return agents, counts_total, counts_network
|
||||
|
||||
|
||||
def _from_distro(distro: List[config.AgentDistro],
|
||||
n: int,
|
||||
topology: str,
|
||||
default: config.SingleAgentConfig,
|
||||
random) -> List[Dict[str, Any]]:
|
||||
def _from_distro(
|
||||
distro: List[config.AgentDistro],
|
||||
n: int,
|
||||
topology: str,
|
||||
default: config.SingleAgentConfig,
|
||||
random,
|
||||
) -> List[Dict[str, Any]]:
|
||||
|
||||
agents = []
|
||||
|
||||
if n is None:
|
||||
if any(lambda dist: dist.n is None, distro):
|
||||
raise ValueError('You must provide a total number of agents, or the number of each type')
|
||||
raise ValueError(
|
||||
"You must provide a total number of agents, or the number of each type"
|
||||
)
|
||||
n = sum(dist.n for dist in distro)
|
||||
|
||||
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
|
||||
@ -836,29 +894,40 @@ def _from_distro(distro: List[config.AgentDistro],
|
||||
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
|
||||
|
||||
# Calculate how many times each has to appear
|
||||
indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm)))
|
||||
indices = list(
|
||||
chain.from_iterable([idx] * int(n * chunk) for (idx, n) in enumerate(norm))
|
||||
)
|
||||
|
||||
# Complete with random agents following the original weight distribution
|
||||
if len(indices) < n:
|
||||
indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices))
|
||||
indices += random.choices(
|
||||
list(range(len(distro))),
|
||||
weights=[d.weight for d in distro],
|
||||
k=n - len(indices),
|
||||
)
|
||||
|
||||
# Deserialize classes for efficiency
|
||||
classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro)
|
||||
classes = list(
|
||||
serialization.deserialize(i.agent_class or default.agent_class) for i in distro
|
||||
)
|
||||
|
||||
# Add them in random order
|
||||
random.shuffle(indices)
|
||||
|
||||
|
||||
for idx in indices:
|
||||
d = distro[idx]
|
||||
agent = d.state.copy()
|
||||
cls = classes[idx]
|
||||
agent['agent_class'] = cls
|
||||
agent["agent_class"] = cls
|
||||
if default:
|
||||
agent.update(default.state)
|
||||
topology = d.topology if ('topology' in d.__fields_set__) else topology or default.topology
|
||||
topology = (
|
||||
d.topology
|
||||
if ("topology" in d.__fields_set__)
|
||||
else topology or default.topology
|
||||
)
|
||||
if topology:
|
||||
agent['topology'] = topology
|
||||
agent["topology"] = topology
|
||||
agents.append(agent)
|
||||
|
||||
return agents
|
||||
@ -877,4 +946,5 @@ try:
|
||||
from .Geo import Geo
|
||||
except ImportError:
|
||||
import sys
|
||||
print('Could not load the Geo Agent, scipy is not installed', file=sys.stderr)
|
||||
|
||||
print("Could not load the Geo Agent, scipy is not installed", file=sys.stderr)
|
||||
|
167
soil/config.py
167
soil/config.py
@ -19,6 +19,7 @@ import networkx as nx
|
||||
# Could use TypeAlias in python >= 3.10
|
||||
nodeId = int
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
id: nodeId
|
||||
state: Optional[Dict[str, Any]] = {}
|
||||
@ -54,14 +55,15 @@ class NetConfig(BaseModel):
|
||||
return NetConfig(topology=None, params=None)
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if 'params' not in values and 'topology' not in values:
|
||||
raise ValueError('You must specify either a topology or the parameters to generate a graph')
|
||||
def validate_all(cls, values):
|
||||
if "params" not in values and "topology" not in values:
|
||||
raise ValueError(
|
||||
"You must specify either a topology or the parameters to generate a graph"
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class EnvConfig(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def default():
|
||||
return EnvConfig()
|
||||
@ -80,9 +82,11 @@ class FixedAgentConfig(SingleAgentConfig):
|
||||
hidden: Optional[bool] = False # Do not count this agent towards total agent count
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if values.get('unique_id', None) is not None and values.get('n', 1) > 1:
|
||||
raise ValueError(f"An unique_id can only be provided when there is only one agent ({values.get('n')} given)")
|
||||
def validate_all(cls, values):
|
||||
if values.get("unique_id", None) is not None and values.get("n", 1) > 1:
|
||||
raise ValueError(
|
||||
f"An unique_id can only be provided when there is only one agent ({values.get('n')} given)"
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
@ -91,8 +95,8 @@ class OverrideAgentConfig(FixedAgentConfig):
|
||||
|
||||
|
||||
class Strategy(Enum):
|
||||
topology = 'topology'
|
||||
total = 'total'
|
||||
topology = "topology"
|
||||
total = "total"
|
||||
|
||||
|
||||
class AgentDistro(SingleAgentConfig):
|
||||
@ -111,16 +115,20 @@ class AgentConfig(SingleAgentConfig):
|
||||
return AgentConfig()
|
||||
|
||||
@root_validator
|
||||
def validate_all(cls, values):
|
||||
if 'distribution' in values and ('n' not in values and 'topology' not in values):
|
||||
raise ValueError("You need to provide the number of agents or a topology to extract the value from.")
|
||||
def validate_all(cls, values):
|
||||
if "distribution" in values and (
|
||||
"n" not in values and "topology" not in values
|
||||
):
|
||||
raise ValueError(
|
||||
"You need to provide the number of agents or a topology to extract the value from."
|
||||
)
|
||||
return values
|
||||
|
||||
|
||||
class Config(BaseModel, extra=Extra.allow):
|
||||
version: Optional[str] = '1'
|
||||
version: Optional[str] = "1"
|
||||
|
||||
name: str = 'Unnamed Simulation'
|
||||
name: str = "Unnamed Simulation"
|
||||
description: Optional[str] = None
|
||||
group: str = None
|
||||
dir_path: Optional[str] = None
|
||||
@ -140,45 +148,48 @@ class Config(BaseModel, extra=Extra.allow):
|
||||
def from_raw(cls, cfg):
|
||||
if isinstance(cfg, Config):
|
||||
return cfg
|
||||
if cfg.get('version', '1') == '1' and any(k in cfg for k in ['agents', 'agent_class', 'topology', 'environment_class']):
|
||||
if cfg.get("version", "1") == "1" and any(
|
||||
k in cfg for k in ["agents", "agent_class", "topology", "environment_class"]
|
||||
):
|
||||
return convert_old(cfg)
|
||||
return Config(**cfg)
|
||||
|
||||
|
||||
def convert_old(old, strict=True):
|
||||
'''
|
||||
"""
|
||||
Try to convert old style configs into the new format.
|
||||
|
||||
This is still a work in progress and might not work in many cases.
|
||||
'''
|
||||
"""
|
||||
|
||||
utils.logger.warning('The old configuration format is deprecated. The converted file MAY NOT yield the right results')
|
||||
utils.logger.warning(
|
||||
"The old configuration format is deprecated. The converted file MAY NOT yield the right results"
|
||||
)
|
||||
|
||||
new = old.copy()
|
||||
|
||||
network = {}
|
||||
|
||||
if 'topology' in old:
|
||||
del new['topology']
|
||||
network['topology'] = old['topology']
|
||||
if "topology" in old:
|
||||
del new["topology"]
|
||||
network["topology"] = old["topology"]
|
||||
|
||||
if 'network_params' in old and old['network_params']:
|
||||
del new['network_params']
|
||||
for (k, v) in old['network_params'].items():
|
||||
if k == 'path':
|
||||
network['path'] = v
|
||||
if "network_params" in old and old["network_params"]:
|
||||
del new["network_params"]
|
||||
for (k, v) in old["network_params"].items():
|
||||
if k == "path":
|
||||
network["path"] = v
|
||||
else:
|
||||
network.setdefault('params', {})[k] = v
|
||||
network.setdefault("params", {})[k] = v
|
||||
|
||||
topology = None
|
||||
if network:
|
||||
topology = network
|
||||
|
||||
|
||||
agents = {'fixed': [], 'distribution': []}
|
||||
agents = {"fixed": [], "distribution": []}
|
||||
|
||||
def updated_agent(agent):
|
||||
'''Convert an agent definition'''
|
||||
"""Convert an agent definition"""
|
||||
newagent = dict(agent)
|
||||
return newagent
|
||||
|
||||
@ -186,80 +197,74 @@ def convert_old(old, strict=True):
|
||||
fixed = []
|
||||
override = []
|
||||
|
||||
if 'environment_agents' in new:
|
||||
if "environment_agents" in new:
|
||||
|
||||
for agent in new['environment_agents']:
|
||||
agent.setdefault('state', {})['group'] = 'environment'
|
||||
if 'agent_id' in agent:
|
||||
agent['state']['name'] = agent['agent_id']
|
||||
del agent['agent_id']
|
||||
agent['hidden'] = True
|
||||
agent['topology'] = False
|
||||
for agent in new["environment_agents"]:
|
||||
agent.setdefault("state", {})["group"] = "environment"
|
||||
if "agent_id" in agent:
|
||||
agent["state"]["name"] = agent["agent_id"]
|
||||
del agent["agent_id"]
|
||||
agent["hidden"] = True
|
||||
agent["topology"] = False
|
||||
fixed.append(updated_agent(agent))
|
||||
del new['environment_agents']
|
||||
del new["environment_agents"]
|
||||
|
||||
if "agent_class" in old:
|
||||
del new["agent_class"]
|
||||
agents["agent_class"] = old["agent_class"]
|
||||
|
||||
if 'agent_class' in old:
|
||||
del new['agent_class']
|
||||
agents['agent_class'] = old['agent_class']
|
||||
if "default_state" in old:
|
||||
del new["default_state"]
|
||||
agents["state"] = old["default_state"]
|
||||
|
||||
if 'default_state' in old:
|
||||
del new['default_state']
|
||||
agents['state'] = old['default_state']
|
||||
if "network_agents" in old:
|
||||
agents["topology"] = True
|
||||
|
||||
if 'network_agents' in old:
|
||||
agents['topology'] = True
|
||||
agents.setdefault("state", {})["group"] = "network"
|
||||
|
||||
agents.setdefault('state', {})['group'] = 'network'
|
||||
|
||||
for agent in new['network_agents']:
|
||||
for agent in new["network_agents"]:
|
||||
agent = updated_agent(agent)
|
||||
if 'agent_id' in agent:
|
||||
agent['state']['name'] = agent['agent_id']
|
||||
del agent['agent_id']
|
||||
if "agent_id" in agent:
|
||||
agent["state"]["name"] = agent["agent_id"]
|
||||
del agent["agent_id"]
|
||||
fixed.append(agent)
|
||||
else:
|
||||
by_weight.append(agent)
|
||||
del new['network_agents']
|
||||
|
||||
if 'agent_class' in old and (not fixed and not by_weight):
|
||||
agents['topology'] = True
|
||||
by_weight = [{'agent_class': old['agent_class'], 'weight': 1}]
|
||||
del new["network_agents"]
|
||||
|
||||
if "agent_class" in old and (not fixed and not by_weight):
|
||||
agents["topology"] = True
|
||||
by_weight = [{"agent_class": old["agent_class"], "weight": 1}]
|
||||
|
||||
# TODO: translate states properly
|
||||
if 'states' in old:
|
||||
del new['states']
|
||||
states = old['states']
|
||||
if "states" in old:
|
||||
del new["states"]
|
||||
states = old["states"]
|
||||
if isinstance(states, dict):
|
||||
states = states.items()
|
||||
else:
|
||||
states = enumerate(states)
|
||||
for (k, v) in states:
|
||||
override.append({'filter': {'node_id': k},
|
||||
'state': v})
|
||||
|
||||
agents['override'] = override
|
||||
agents['fixed'] = fixed
|
||||
agents['distribution'] = by_weight
|
||||
override.append({"filter": {"node_id": k}, "state": v})
|
||||
|
||||
agents["override"] = override
|
||||
agents["fixed"] = fixed
|
||||
agents["distribution"] = by_weight
|
||||
|
||||
model_params = {}
|
||||
if 'environment_params' in new:
|
||||
del new['environment_params']
|
||||
model_params = dict(old['environment_params'])
|
||||
if "environment_params" in new:
|
||||
del new["environment_params"]
|
||||
model_params = dict(old["environment_params"])
|
||||
|
||||
if 'environment_class' in old:
|
||||
del new['environment_class']
|
||||
new['model_class'] = old['environment_class']
|
||||
if "environment_class" in old:
|
||||
del new["environment_class"]
|
||||
new["model_class"] = old["environment_class"]
|
||||
|
||||
if 'dump' in old:
|
||||
del new['dump']
|
||||
new['dry_run'] = not old['dump']
|
||||
if "dump" in old:
|
||||
del new["dump"]
|
||||
new["dry_run"] = not old["dump"]
|
||||
|
||||
model_params['topology'] = topology
|
||||
model_params['agents'] = agents
|
||||
model_params["topology"] = topology
|
||||
model_params["agents"] = agents
|
||||
|
||||
return Config(version='2',
|
||||
model_params=model_params,
|
||||
**new)
|
||||
return Config(version="2", model_params=model_params, **new)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from mesa import DataCollector as MDC
|
||||
|
||||
class SoilDataCollector(MDC):
|
||||
|
||||
class SoilDataCollector(MDC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -18,9 +18,9 @@ def wrapcmd(func):
|
||||
known = globals()
|
||||
known.update(self.curframe.f_globals)
|
||||
known.update(self.curframe.f_locals)
|
||||
known['agent'] = known.get('self', None)
|
||||
known['model'] = known.get('self', {}).get('model')
|
||||
known['attrs'] = arg.strip().split()
|
||||
known["agent"] = known.get("self", None)
|
||||
known["model"] = known.get("self", {}).get("model")
|
||||
known["attrs"] = arg.strip().split()
|
||||
|
||||
exec(func.__code__, known, known)
|
||||
|
||||
@ -29,12 +29,12 @@ def wrapcmd(func):
|
||||
|
||||
class Debug(pdb.Pdb):
|
||||
def __init__(self, *args, skip_soil=False, **kwargs):
|
||||
skip = kwargs.get('skip', [])
|
||||
skip.append('soil')
|
||||
skip = kwargs.get("skip", [])
|
||||
skip.append("soil")
|
||||
if skip_soil:
|
||||
skip.append('soil')
|
||||
skip.append('soil.*')
|
||||
skip.append('mesa.*')
|
||||
skip.append("soil")
|
||||
skip.append("soil.*")
|
||||
skip.append("mesa.*")
|
||||
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
||||
self.prompt = "[soil-pdb] "
|
||||
|
||||
@ -42,7 +42,7 @@ class Debug(pdb.Pdb):
|
||||
def _soil_agents(model, attrs=None, pretty=True, **kwargs):
|
||||
for agent in model.agents(**kwargs):
|
||||
d = agent
|
||||
print(' - ' + indent(agent.to_str(keys=attrs, pretty=pretty), ' '))
|
||||
print(" - " + indent(agent.to_str(keys=attrs, pretty=pretty), " "))
|
||||
|
||||
@wrapcmd
|
||||
def do_soil_agents():
|
||||
@ -52,20 +52,20 @@ class Debug(pdb.Pdb):
|
||||
|
||||
@wrapcmd
|
||||
def do_soil_list():
|
||||
return Debug._soil_agents(model, attrs=['state_id'], pretty=False)
|
||||
return Debug._soil_agents(model, attrs=["state_id"], pretty=False)
|
||||
|
||||
do_sl = do_soil_list
|
||||
|
||||
def do_continue_state(self, arg):
|
||||
self.do_break_state(arg, temporary=True)
|
||||
return self.do_continue('')
|
||||
return self.do_continue("")
|
||||
|
||||
do_cs = do_continue_state
|
||||
|
||||
@wrapcmd
|
||||
def do_soil_agent():
|
||||
if not agent:
|
||||
print('No agent available')
|
||||
print("No agent available")
|
||||
return
|
||||
|
||||
keys = None
|
||||
@ -81,9 +81,9 @@ class Debug(pdb.Pdb):
|
||||
do_aa = do_soil_agent
|
||||
|
||||
def do_break_state(self, arg: str, instances=None, temporary=False):
|
||||
'''
|
||||
"""
|
||||
Break before a specified state is stepped into.
|
||||
'''
|
||||
"""
|
||||
|
||||
klass = None
|
||||
state = arg
|
||||
@ -95,39 +95,39 @@ class Debug(pdb.Pdb):
|
||||
if tokens:
|
||||
instances = list(eval(token) for token in tokens)
|
||||
|
||||
colon = state.find(':')
|
||||
colon = state.find(":")
|
||||
|
||||
if colon > 0:
|
||||
klass = state[:colon].rstrip()
|
||||
state = state[colon+1:].strip()
|
||||
|
||||
state = state[colon + 1 :].strip()
|
||||
|
||||
print(klass, state, tokens)
|
||||
klass = eval(klass,
|
||||
self.curframe.f_globals,
|
||||
self.curframe_locals)
|
||||
klass = eval(klass, self.curframe.f_globals, self.curframe_locals)
|
||||
|
||||
if klass:
|
||||
klasses = [klass]
|
||||
else:
|
||||
klasses = [k for k in self.curframe.f_globals.values() if isinstance(k, type) and issubclass(k, FSM)]
|
||||
klasses = [
|
||||
k
|
||||
for k in self.curframe.f_globals.values()
|
||||
if isinstance(k, type) and issubclass(k, FSM)
|
||||
]
|
||||
|
||||
if not klasses:
|
||||
self.error('No agent classes found')
|
||||
|
||||
self.error("No agent classes found")
|
||||
|
||||
for klass in klasses:
|
||||
try:
|
||||
func = getattr(klass, state)
|
||||
except AttributeError:
|
||||
self.error(f'State {state} not found in class {klass}')
|
||||
self.error(f"State {state} not found in class {klass}")
|
||||
continue
|
||||
if hasattr(func, '__func__'):
|
||||
if hasattr(func, "__func__"):
|
||||
func = func.__func__
|
||||
|
||||
code = func.__code__
|
||||
#use co_name to identify the bkpt (function names
|
||||
#could be aliased, but co_name is invariant)
|
||||
# use co_name to identify the bkpt (function names
|
||||
# could be aliased, but co_name is invariant)
|
||||
funcname = code.co_name
|
||||
lineno = code.co_firstlineno
|
||||
filename = code.co_filename
|
||||
@ -135,38 +135,36 @@ class Debug(pdb.Pdb):
|
||||
# Check for reasonable breakpoint
|
||||
line = self.checkline(filename, lineno)
|
||||
if not line:
|
||||
raise ValueError('no line found')
|
||||
raise ValueError("no line found")
|
||||
# now set the break point
|
||||
cond = None
|
||||
if instances:
|
||||
cond = f'self.unique_id in { repr(instances) }'
|
||||
cond = f"self.unique_id in { repr(instances) }"
|
||||
|
||||
existing = self.get_breaks(filename, line)
|
||||
if existing:
|
||||
self.message("Breakpoint already exists at %s:%d" %
|
||||
(filename, line))
|
||||
self.message("Breakpoint already exists at %s:%d" % (filename, line))
|
||||
continue
|
||||
err = self.set_break(filename, line, temporary, cond, funcname)
|
||||
if err:
|
||||
self.error(err)
|
||||
else:
|
||||
bp = self.get_breaks(filename, line)[-1]
|
||||
self.message("Breakpoint %d at %s:%d" %
|
||||
(bp.number, bp.file, bp.line))
|
||||
self.message("Breakpoint %d at %s:%d" % (bp.number, bp.file, bp.line))
|
||||
|
||||
do_bs = do_break_state
|
||||
|
||||
def do_break_state_self(self, arg: str, temporary=False):
|
||||
'''
|
||||
"""
|
||||
Break before a specified state is stepped into, for the current agent
|
||||
'''
|
||||
agent = self.curframe.f_locals.get('self')
|
||||
"""
|
||||
agent = self.curframe.f_locals.get("self")
|
||||
if not agent:
|
||||
self.error('No current agent.')
|
||||
self.error('Try this again when the debugger is stopped inside an agent')
|
||||
self.error("No current agent.")
|
||||
self.error("Try this again when the debugger is stopped inside an agent")
|
||||
return
|
||||
|
||||
arg = f'{agent.__class__.__name__}:{ arg } {agent.unique_id}'
|
||||
arg = f"{agent.__class__.__name__}:{ arg } {agent.unique_id}"
|
||||
return self.do_break_state(arg)
|
||||
|
||||
do_bss = do_break_state_self
|
||||
@ -174,6 +172,7 @@ class Debug(pdb.Pdb):
|
||||
|
||||
debugger = None
|
||||
|
||||
|
||||
def set_trace(frame=None, **kwargs):
|
||||
global debugger
|
||||
if debugger is None:
|
||||
|
@ -35,18 +35,20 @@ class BaseEnvironment(Model):
|
||||
:meth:`soil.environment.Environment.get` method.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
id='unnamed_env',
|
||||
seed='default',
|
||||
schedule=None,
|
||||
dir_path=None,
|
||||
interval=1,
|
||||
agent_class=None,
|
||||
agents: [tuple[type, Dict[str, Any]]] = {},
|
||||
agent_reporters: Optional[Any] = None,
|
||||
model_reporters: Optional[Any] = None,
|
||||
tables: Optional[Any] = None,
|
||||
**env_params):
|
||||
def __init__(
|
||||
self,
|
||||
id="unnamed_env",
|
||||
seed="default",
|
||||
schedule=None,
|
||||
dir_path=None,
|
||||
interval=1,
|
||||
agent_class=None,
|
||||
agents: [tuple[type, Dict[str, Any]]] = {},
|
||||
agent_reporters: Optional[Any] = None,
|
||||
model_reporters: Optional[Any] = None,
|
||||
tables: Optional[Any] = None,
|
||||
**env_params,
|
||||
):
|
||||
|
||||
super().__init__(seed=seed)
|
||||
self.env_params = env_params or {}
|
||||
@ -75,27 +77,26 @@ class BaseEnvironment(Model):
|
||||
)
|
||||
|
||||
def _agent_from_dict(self, agent):
|
||||
'''
|
||||
"""
|
||||
Translate an agent dictionary into an agent
|
||||
'''
|
||||
"""
|
||||
agent = dict(**agent)
|
||||
cls = agent.pop('agent_class', None) or self.agent_class
|
||||
unique_id = agent.pop('unique_id', None)
|
||||
cls = agent.pop("agent_class", None) or self.agent_class
|
||||
unique_id = agent.pop("unique_id", None)
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
|
||||
return serialization.deserialize(cls)(unique_id=unique_id,
|
||||
model=self, **agent)
|
||||
return serialization.deserialize(cls)(unique_id=unique_id, model=self, **agent)
|
||||
|
||||
def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}):
|
||||
'''
|
||||
"""
|
||||
Initialize the agents in the model from either a `soil.config.AgentConfig` or a list of
|
||||
dictionaries that each describes an agent.
|
||||
|
||||
If given a list of dictionaries, an agent will be created for each dictionary. The agent
|
||||
class can be specified through the `agent_class` key. The rest of the items will be used
|
||||
as parameters to the agent.
|
||||
'''
|
||||
"""
|
||||
if not agents:
|
||||
return
|
||||
|
||||
@ -108,11 +109,10 @@ class BaseEnvironment(Model):
|
||||
override = lst.override
|
||||
lst = self._agent_dict_from_config(lst)
|
||||
|
||||
#TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore,
|
||||
# TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore,
|
||||
# because it needs attribute such as unique_id, which are only present after init
|
||||
new_agents = [self._agent_from_dict(agent) for agent in lst]
|
||||
|
||||
|
||||
for a in new_agents:
|
||||
self.schedule.add(a)
|
||||
|
||||
@ -122,8 +122,7 @@ class BaseEnvironment(Model):
|
||||
setattr(agent, attr, value)
|
||||
|
||||
def _agent_dict_from_config(self, cfg):
|
||||
return agentmod.from_config(cfg,
|
||||
random=self.random)
|
||||
return agentmod.from_config(cfg, random=self.random)
|
||||
|
||||
@property
|
||||
def agents(self):
|
||||
@ -139,18 +138,16 @@ class BaseEnvironment(Model):
|
||||
def now(self):
|
||||
if self.schedule:
|
||||
return self.schedule.time
|
||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
||||
|
||||
raise Exception(
|
||||
"The environment has not been scheduled, so it has no sense of time"
|
||||
)
|
||||
|
||||
def add_agent(self, agent_class, unique_id=None, **kwargs):
|
||||
a = None
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
|
||||
|
||||
a = agent_class(model=self,
|
||||
unique_id=unique_id,
|
||||
**args)
|
||||
a = agent_class(model=self, unique_id=unique_id, **args)
|
||||
|
||||
self.schedule.add(a)
|
||||
return a
|
||||
@ -163,16 +160,16 @@ class BaseEnvironment(Model):
|
||||
for k, v in kwargs:
|
||||
message += " {k}={v} ".format(k, v)
|
||||
extra = {}
|
||||
extra['now'] = self.now
|
||||
extra['id'] = self.id
|
||||
extra["now"] = self.now
|
||||
extra["id"] = self.id
|
||||
return self.logger.log(level, message, extra=extra)
|
||||
|
||||
def step(self):
|
||||
'''
|
||||
"""
|
||||
Advance one step in the simulation, and update the data collection and scheduler appropriately
|
||||
'''
|
||||
"""
|
||||
super().step()
|
||||
self.logger.info(f'--- Step {self.now:^5} ---')
|
||||
self.logger.info(f"--- Step {self.now:^5} ---")
|
||||
self.schedule.step()
|
||||
self.datacollector.collect(self)
|
||||
|
||||
@ -180,10 +177,10 @@ class BaseEnvironment(Model):
|
||||
return key in self.env_params
|
||||
|
||||
def get(self, key, default=None):
|
||||
'''
|
||||
"""
|
||||
Get the value of an environment attribute.
|
||||
Return `default` if the value is not set.
|
||||
'''
|
||||
"""
|
||||
return self.env_params.get(key, default)
|
||||
|
||||
def __getitem__(self, key):
|
||||
@ -197,13 +194,15 @@ class BaseEnvironment(Model):
|
||||
|
||||
|
||||
class NetworkEnvironment(BaseEnvironment):
|
||||
'''
|
||||
"""
|
||||
The NetworkEnvironment is an environment that includes one or more networkx.Graph intances
|
||||
and methods to associate agents to nodes and vice versa.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, *args, topology: Union[config.NetConfig, nx.Graph] = None, **kwargs):
|
||||
agents = kwargs.pop('agents', None)
|
||||
def __init__(
|
||||
self, *args, topology: Union[config.NetConfig, nx.Graph] = None, **kwargs
|
||||
):
|
||||
agents = kwargs.pop("agents", None)
|
||||
super().__init__(*args, agents=None, **kwargs)
|
||||
|
||||
self._set_topology(topology)
|
||||
@ -211,37 +210,35 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
self.init_agents(agents)
|
||||
|
||||
def init_agents(self, *args, **kwargs):
|
||||
'''Initialize the agents from a '''
|
||||
"""Initialize the agents from a"""
|
||||
super().init_agents(*args, **kwargs)
|
||||
for agent in self.schedule._agents.values():
|
||||
if hasattr(agent, 'node_id'):
|
||||
if hasattr(agent, "node_id"):
|
||||
self._init_node(agent)
|
||||
|
||||
def _init_node(self, agent):
|
||||
'''
|
||||
"""
|
||||
Make sure the node for a given agent has the proper attributes.
|
||||
'''
|
||||
self.G.nodes[agent.node_id]['agent'] = agent
|
||||
"""
|
||||
self.G.nodes[agent.node_id]["agent"] = agent
|
||||
|
||||
def _agent_dict_from_config(self, cfg):
|
||||
return agentmod.from_config(cfg,
|
||||
topology=self.G,
|
||||
random=self.random)
|
||||
return agentmod.from_config(cfg, topology=self.G, random=self.random)
|
||||
|
||||
def _agent_from_dict(self, agent, unique_id=None):
|
||||
agent = dict(agent)
|
||||
|
||||
if not agent.get('topology', False):
|
||||
if not agent.get("topology", False):
|
||||
return super()._agent_from_dict(agent)
|
||||
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
node_id = agent.get('node_id', None)
|
||||
node_id = agent.get("node_id", None)
|
||||
if node_id is None:
|
||||
node_id = network.find_unassigned(self.G, random=self.random)
|
||||
agent['node_id'] = node_id
|
||||
agent['unique_id'] = unique_id
|
||||
agent['topology'] = self.G
|
||||
agent["node_id"] = node_id
|
||||
agent["unique_id"] = unique_id
|
||||
agent["topology"] = self.G
|
||||
node_attrs = self.G.nodes[node_id]
|
||||
node_attrs.update(agent)
|
||||
agent = node_attrs
|
||||
@ -269,32 +266,33 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
if unique_id is None:
|
||||
unique_id = self.next_id()
|
||||
if node_id is None:
|
||||
node_id = network.find_unassigned(G=self.G,
|
||||
shuffle=True,
|
||||
random=self.random)
|
||||
node_id = network.find_unassigned(
|
||||
G=self.G, shuffle=True, random=self.random
|
||||
)
|
||||
|
||||
if node_id in G.nodes:
|
||||
self.G.nodes[node_id]['agent'] = None # Reserve
|
||||
self.G.nodes[node_id]["agent"] = None # Reserve
|
||||
else:
|
||||
self.G.add_node(node_id)
|
||||
|
||||
a = self.add_agent(unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs)
|
||||
a['visible'] = True
|
||||
a = self.add_agent(
|
||||
unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs
|
||||
)
|
||||
a["visible"] = True
|
||||
return a
|
||||
|
||||
def agent_for_node_id(self, node_id):
|
||||
return self.G.nodes[node_id].get('agent')
|
||||
return self.G.nodes[node_id].get("agent")
|
||||
|
||||
def populate_network(self, agent_class, weights=None, **agent_params):
|
||||
if not hasattr(agent_class, 'len'):
|
||||
if not hasattr(agent_class, "len"):
|
||||
agent_class = [agent_class]
|
||||
weights = None
|
||||
for (node_id, node) in self.G.nodes(data=True):
|
||||
if 'agent' in node:
|
||||
if "agent" in node:
|
||||
continue
|
||||
a_class = self.random.choices(agent_class, weights)[0]
|
||||
self.add_agent(node_id=node_id,
|
||||
agent_class=a_class, **agent_params)
|
||||
self.add_agent(node_id=node_id, agent_class=a_class, **agent_params)
|
||||
|
||||
|
||||
Environment = NetworkEnvironment
|
||||
|
@ -24,56 +24,58 @@ class DryRunner(BytesIO):
|
||||
|
||||
def write(self, txt):
|
||||
if self.__copy_to:
|
||||
self.__copy_to.write('{}:::{}'.format(self.__fname, txt))
|
||||
self.__copy_to.write("{}:::{}".format(self.__fname, txt))
|
||||
try:
|
||||
super().write(txt)
|
||||
except TypeError:
|
||||
super().write(bytes(txt, 'utf-8'))
|
||||
super().write(bytes(txt, "utf-8"))
|
||||
|
||||
def close(self):
|
||||
content = '(binary data not shown)'
|
||||
content = "(binary data not shown)"
|
||||
try:
|
||||
content = self.getvalue().decode()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname, content))
|
||||
logger.info(
|
||||
"**Not** written to {} (dry run mode):\n\n{}\n\n".format(
|
||||
self.__fname, content
|
||||
)
|
||||
)
|
||||
super().close()
|
||||
|
||||
|
||||
class Exporter:
|
||||
'''
|
||||
"""
|
||||
Interface for all exporters. It is not necessary, but it is useful
|
||||
if you don't plan to implement all the methods.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
|
||||
self.simulation = simulation
|
||||
outdir = outdir or os.path.join(os.getcwd(), 'soil_output')
|
||||
self.outdir = os.path.join(outdir,
|
||||
simulation.group or '',
|
||||
simulation.name)
|
||||
outdir = outdir or os.path.join(os.getcwd(), "soil_output")
|
||||
self.outdir = os.path.join(outdir, simulation.group or "", simulation.name)
|
||||
self.dry_run = dry_run
|
||||
if copy_to is None and dry_run:
|
||||
copy_to = sys.stdout
|
||||
self.copy_to = copy_to
|
||||
|
||||
def sim_start(self):
|
||||
'''Method to call when the simulation starts'''
|
||||
"""Method to call when the simulation starts"""
|
||||
pass
|
||||
|
||||
def sim_end(self):
|
||||
'''Method to call when the simulation ends'''
|
||||
"""Method to call when the simulation ends"""
|
||||
pass
|
||||
|
||||
def trial_start(self, env):
|
||||
'''Method to call when a trial start'''
|
||||
"""Method to call when a trial start"""
|
||||
pass
|
||||
|
||||
def trial_end(self, env):
|
||||
'''Method to call when a trial ends'''
|
||||
"""Method to call when a trial ends"""
|
||||
pass
|
||||
|
||||
def output(self, f, mode='w', **kwargs):
|
||||
def output(self, f, mode="w", **kwargs):
|
||||
if self.dry_run:
|
||||
f = DryRunner(f, copy_to=self.copy_to)
|
||||
else:
|
||||
@ -86,35 +88,38 @@ class Exporter:
|
||||
|
||||
|
||||
class default(Exporter):
|
||||
'''Default exporter. Writes sqlite results, as well as the simulation YAML'''
|
||||
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
|
||||
|
||||
def sim_start(self):
|
||||
if not self.dry_run:
|
||||
logger.info('Dumping results to %s', self.outdir)
|
||||
with self.output(self.simulation.name + '.dumped.yml') as f:
|
||||
logger.info("Dumping results to %s", self.outdir)
|
||||
with self.output(self.simulation.name + ".dumped.yml") as f:
|
||||
f.write(self.simulation.to_yaml())
|
||||
else:
|
||||
logger.info('NOT dumping results')
|
||||
logger.info("NOT dumping results")
|
||||
|
||||
def trial_end(self, env):
|
||||
if self.dry_run:
|
||||
logger.info('Running in DRY_RUN mode, the database will NOT be created')
|
||||
logger.info("Running in DRY_RUN mode, the database will NOT be created")
|
||||
return
|
||||
|
||||
with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.id)):
|
||||
with timer(
|
||||
"Dumping simulation {} trial {}".format(self.simulation.name, env.id)
|
||||
):
|
||||
|
||||
fpath = os.path.join(self.outdir, f'{env.id}.sqlite')
|
||||
engine = create_engine(f'sqlite:///{fpath}', echo=False)
|
||||
fpath = os.path.join(self.outdir, f"{env.id}.sqlite")
|
||||
engine = create_engine(f"sqlite:///{fpath}", echo=False)
|
||||
|
||||
dc = env.datacollector
|
||||
for (t, df) in get_dc_dfs(dc):
|
||||
df.to_sql(t, con=engine, if_exists='append')
|
||||
df.to_sql(t, con=engine, if_exists="append")
|
||||
|
||||
|
||||
def get_dc_dfs(dc):
|
||||
dfs = {'env': dc.get_model_vars_dataframe(),
|
||||
'agents': dc.get_agent_vars_dataframe() }
|
||||
dfs = {
|
||||
"env": dc.get_model_vars_dataframe(),
|
||||
"agents": dc.get_agent_vars_dataframe(),
|
||||
}
|
||||
for table_name in dc.tables:
|
||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||
yield from dfs.items()
|
||||
@ -122,66 +127,78 @@ def get_dc_dfs(dc):
|
||||
|
||||
class csv(Exporter):
|
||||
|
||||
'''Export the state of each environment (and its agents) in a separate CSV file'''
|
||||
"""Export the state of each environment (and its agents) in a separate CSV file"""
|
||||
|
||||
def trial_end(self, env):
|
||||
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
|
||||
env.id,
|
||||
self.outdir)):
|
||||
with timer(
|
||||
"[CSV] Dumping simulation {} trial {} @ dir {}".format(
|
||||
self.simulation.name, env.id, self.outdir
|
||||
)
|
||||
):
|
||||
for (df_name, df) in get_dc_dfs(env.datacollector):
|
||||
with self.output('{}.{}.csv'.format(env.id, df_name)) as f:
|
||||
with self.output("{}.{}.csv".format(env.id, df_name)) as f:
|
||||
df.to_csv(f)
|
||||
|
||||
|
||||
#TODO: reimplement GEXF exporting without history
|
||||
# TODO: reimplement GEXF exporting without history
|
||||
class gexf(Exporter):
|
||||
def trial_end(self, env):
|
||||
if self.dry_run:
|
||||
logger.info('Not dumping GEXF in dry_run mode')
|
||||
logger.info("Not dumping GEXF in dry_run mode")
|
||||
return
|
||||
|
||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
||||
env.id)):
|
||||
with self.output('{}.gexf'.format(env.id), mode='wb') as f:
|
||||
with timer(
|
||||
"[GEXF] Dumping simulation {} trial {}".format(self.simulation.name, env.id)
|
||||
):
|
||||
with self.output("{}.gexf".format(env.id), mode="wb") as f:
|
||||
network.dump_gexf(env.history_to_graph(), f)
|
||||
self.dump_gexf(env, f)
|
||||
|
||||
|
||||
class dummy(Exporter):
|
||||
|
||||
def sim_start(self):
|
||||
with self.output('dummy', 'w') as f:
|
||||
f.write('simulation started @ {}\n'.format(current_time()))
|
||||
with self.output("dummy", "w") as f:
|
||||
f.write("simulation started @ {}\n".format(current_time()))
|
||||
|
||||
def trial_start(self, env):
|
||||
with self.output('dummy', 'w') as f:
|
||||
f.write('trial started@ {}\n'.format(current_time()))
|
||||
with self.output("dummy", "w") as f:
|
||||
f.write("trial started@ {}\n".format(current_time()))
|
||||
|
||||
def trial_end(self, env):
|
||||
with self.output('dummy', 'w') as f:
|
||||
f.write('trial ended@ {}\n'.format(current_time()))
|
||||
with self.output("dummy", "w") as f:
|
||||
f.write("trial ended@ {}\n".format(current_time()))
|
||||
|
||||
def sim_end(self):
|
||||
with self.output('dummy', 'a') as f:
|
||||
f.write('simulation ended @ {}\n'.format(current_time()))
|
||||
with self.output("dummy", "a") as f:
|
||||
f.write("simulation ended @ {}\n".format(current_time()))
|
||||
|
||||
|
||||
class graphdrawing(Exporter):
|
||||
|
||||
def trial_end(self, env):
|
||||
# Outside effects
|
||||
f = plt.figure()
|
||||
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
|
||||
with open('graph-{}.png'.format(env.id)) as f:
|
||||
nx.draw(
|
||||
env.G,
|
||||
node_size=10,
|
||||
width=0.2,
|
||||
pos=nx.spring_layout(env.G, scale=100),
|
||||
ax=f.add_subplot(111),
|
||||
)
|
||||
with open("graph-{}.png".format(env.id)) as f:
|
||||
f.savefig(f)
|
||||
|
||||
'''
|
||||
|
||||
"""
|
||||
Convert an environment into a NetworkX graph
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def env_to_graph(env, history=None):
|
||||
G = nx.Graph(env.G)
|
||||
|
||||
for agent in env.network_agents:
|
||||
|
||||
attributes = {'agent': str(agent.__class__)}
|
||||
attributes = {"agent": str(agent.__class__)}
|
||||
lastattributes = {}
|
||||
spells = []
|
||||
lastvisible = False
|
||||
@ -189,7 +206,7 @@ def env_to_graph(env, history=None):
|
||||
if not history:
|
||||
history = sorted(list(env.state_to_tuples()))
|
||||
for _, t_step, attribute, value in history:
|
||||
if attribute == 'visible':
|
||||
if attribute == "visible":
|
||||
nowvisible = value
|
||||
if nowvisible and not lastvisible:
|
||||
laststep = t_step
|
||||
@ -198,7 +215,7 @@ def env_to_graph(env, history=None):
|
||||
|
||||
lastvisible = nowvisible
|
||||
continue
|
||||
key = 'attr_' + attribute
|
||||
key = "attr_" + attribute
|
||||
if key not in attributes:
|
||||
attributes[key] = list()
|
||||
if key not in lastattributes:
|
||||
|
@ -9,6 +9,7 @@ import networkx as nx
|
||||
|
||||
from . import config, serialization, basestring
|
||||
|
||||
|
||||
def from_config(cfg: config.NetConfig, dir_path: str = None):
|
||||
if not isinstance(cfg, config.NetConfig):
|
||||
cfg = config.NetConfig(**cfg)
|
||||
@ -19,24 +20,28 @@ def from_config(cfg: config.NetConfig, dir_path: str = None):
|
||||
path = os.path.join(dir_path, path)
|
||||
extension = os.path.splitext(path)[1][1:]
|
||||
kwargs = {}
|
||||
if extension == 'gexf':
|
||||
kwargs['version'] = '1.2draft'
|
||||
kwargs['node_type'] = int
|
||||
if extension == "gexf":
|
||||
kwargs["version"] = "1.2draft"
|
||||
kwargs["node_type"] = int
|
||||
try:
|
||||
method = getattr(nx.readwrite, 'read_' + extension)
|
||||
method = getattr(nx.readwrite, "read_" + extension)
|
||||
except AttributeError:
|
||||
raise AttributeError('Unknown format')
|
||||
raise AttributeError("Unknown format")
|
||||
return method(path, **kwargs)
|
||||
|
||||
if cfg.params:
|
||||
net_args = cfg.params.dict()
|
||||
net_gen = net_args.pop('generator')
|
||||
net_gen = net_args.pop("generator")
|
||||
|
||||
if dir_path not in sys.path:
|
||||
sys.path.append(dir_path)
|
||||
|
||||
method = serialization.deserializer(net_gen,
|
||||
known_modules=['networkx.generators',])
|
||||
method = serialization.deserializer(
|
||||
net_gen,
|
||||
known_modules=[
|
||||
"networkx.generators",
|
||||
],
|
||||
)
|
||||
return method(**net_args)
|
||||
|
||||
if isinstance(cfg.fixed, config.Topology):
|
||||
@ -49,17 +54,17 @@ def from_config(cfg: config.NetConfig, dir_path: str = None):
|
||||
|
||||
|
||||
def find_unassigned(G, shuffle=False, random=random):
|
||||
'''
|
||||
"""
|
||||
Link an agent to a node in a topology.
|
||||
|
||||
If node_id is None, a node without an agent_id will be found.
|
||||
'''
|
||||
#TODO: test
|
||||
"""
|
||||
# TODO: test
|
||||
candidates = list(G.nodes(data=True))
|
||||
if shuffle:
|
||||
random.shuffle(candidates)
|
||||
for next_id, data in candidates:
|
||||
if 'agent' not in data:
|
||||
if "agent" not in data:
|
||||
node_id = next_id
|
||||
break
|
||||
|
||||
@ -68,8 +73,14 @@ def find_unassigned(G, shuffle=False, random=random):
|
||||
|
||||
def dump_gexf(G, f):
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
if "pos" in G.nodes[node]:
|
||||
G.nodes[node]["viz"] = {
|
||||
"position": {
|
||||
"x": G.nodes[node]["pos"][0],
|
||||
"y": G.nodes[node]["pos"][1],
|
||||
"z": 0.0,
|
||||
}
|
||||
}
|
||||
del G.nodes[node]["pos"]
|
||||
|
||||
nx.write_gexf(G, f, version="1.2draft")
|
||||
|
@ -15,13 +15,14 @@ import networkx as nx
|
||||
from jinja2 import Template
|
||||
|
||||
|
||||
logger = logging.getLogger('soil')
|
||||
logger = logging.getLogger("soil")
|
||||
|
||||
|
||||
def load_file(infile):
|
||||
folder = os.path.dirname(infile)
|
||||
if folder not in sys.path:
|
||||
sys.path.append(folder)
|
||||
with open(infile, 'r') as f:
|
||||
with open(infile, "r") as f:
|
||||
return list(chain.from_iterable(map(expand_template, load_string(f))))
|
||||
|
||||
|
||||
@ -30,14 +31,15 @@ def load_string(string):
|
||||
|
||||
|
||||
def expand_template(config):
|
||||
if 'template' not in config:
|
||||
if "template" not in config:
|
||||
yield config
|
||||
return
|
||||
if 'vars' not in config:
|
||||
raise ValueError(('You must provide a definition of variables'
|
||||
' for the template.'))
|
||||
if "vars" not in config:
|
||||
raise ValueError(
|
||||
("You must provide a definition of variables" " for the template.")
|
||||
)
|
||||
|
||||
template = config['template']
|
||||
template = config["template"]
|
||||
|
||||
if not isinstance(template, str):
|
||||
template = yaml.dump(template)
|
||||
@ -49,9 +51,9 @@ def expand_template(config):
|
||||
blank_str = template.render({k: 0 for k in params[0].keys()})
|
||||
blank = list(load_string(blank_str))
|
||||
if len(blank) > 1:
|
||||
raise ValueError('Templates must not return more than one configuration')
|
||||
if 'name' in blank[0]:
|
||||
raise ValueError('Templates cannot be named, use group instead')
|
||||
raise ValueError("Templates must not return more than one configuration")
|
||||
if "name" in blank[0]:
|
||||
raise ValueError("Templates cannot be named, use group instead")
|
||||
|
||||
for ps in params:
|
||||
string = template.render(ps)
|
||||
@ -60,25 +62,25 @@ def expand_template(config):
|
||||
|
||||
|
||||
def params_for_template(config):
|
||||
sampler_config = config.get('sampler', {'N': 100})
|
||||
sampler = sampler_config.pop('method', 'SALib.sample.morris.sample')
|
||||
sampler_config = config.get("sampler", {"N": 100})
|
||||
sampler = sampler_config.pop("method", "SALib.sample.morris.sample")
|
||||
sampler = deserializer(sampler)
|
||||
bounds = config['vars']['bounds']
|
||||
bounds = config["vars"]["bounds"]
|
||||
|
||||
problem = {
|
||||
'num_vars': len(bounds),
|
||||
'names': list(bounds.keys()),
|
||||
'bounds': list(v for v in bounds.values())
|
||||
"num_vars": len(bounds),
|
||||
"names": list(bounds.keys()),
|
||||
"bounds": list(v for v in bounds.values()),
|
||||
}
|
||||
samples = sampler(problem, **sampler_config)
|
||||
|
||||
lists = config['vars'].get('lists', {})
|
||||
lists = config["vars"].get("lists", {})
|
||||
names = list(lists.keys())
|
||||
values = list(lists.values())
|
||||
combs = list(product(*values))
|
||||
|
||||
allnames = names + problem['names']
|
||||
allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
|
||||
allnames = names + problem["names"]
|
||||
allvalues = [(list(i[0]) + list(i[1])) for i in product(combs, samples)]
|
||||
params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
|
||||
return params
|
||||
|
||||
@ -100,22 +102,24 @@ def load_config(cfg):
|
||||
yield from load_files(cfg)
|
||||
|
||||
|
||||
builtins = importlib.import_module('builtins')
|
||||
builtins = importlib.import_module("builtins")
|
||||
|
||||
KNOWN_MODULES = ['soil', ]
|
||||
KNOWN_MODULES = [
|
||||
"soil",
|
||||
]
|
||||
|
||||
|
||||
def name(value, known_modules=KNOWN_MODULES):
|
||||
'''Return a name that can be imported, to serialize/deserialize an object'''
|
||||
"""Return a name that can be imported, to serialize/deserialize an object"""
|
||||
if value is None:
|
||||
return 'None'
|
||||
return "None"
|
||||
if not isinstance(value, type): # Get the class name first
|
||||
value = type(value)
|
||||
tname = value.__name__
|
||||
if hasattr(builtins, tname):
|
||||
return tname
|
||||
modname = value.__module__
|
||||
if modname == '__main__':
|
||||
if modname == "__main__":
|
||||
return tname
|
||||
if known_modules and modname in known_modules:
|
||||
return tname
|
||||
@ -125,17 +129,17 @@ def name(value, known_modules=KNOWN_MODULES):
|
||||
module = importlib.import_module(kmod)
|
||||
if hasattr(module, tname):
|
||||
return tname
|
||||
return '{}.{}'.format(modname, tname)
|
||||
return "{}.{}".format(modname, tname)
|
||||
|
||||
|
||||
def serializer(type_):
|
||||
if type_ != 'str' and hasattr(builtins, type_):
|
||||
if type_ != "str" and hasattr(builtins, type_):
|
||||
return repr
|
||||
return lambda x: x
|
||||
|
||||
|
||||
def serialize(v, known_modules=KNOWN_MODULES):
|
||||
'''Get a text representation of an object.'''
|
||||
"""Get a text representation of an object."""
|
||||
tname = name(v, known_modules=known_modules)
|
||||
func = serializer(tname)
|
||||
return func(v), tname
|
||||
@ -160,9 +164,9 @@ IS_CLASS = re.compile(r"<class '(.*)'>")
|
||||
def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||
if type(type_) != str: # Already deserialized
|
||||
return type_
|
||||
if type_ == 'str':
|
||||
return lambda x='': x
|
||||
if type_ == 'None':
|
||||
if type_ == "str":
|
||||
return lambda x="": x
|
||||
if type_ == "None":
|
||||
return lambda x=None: None
|
||||
if hasattr(builtins, type_): # Check if it's a builtin type
|
||||
cls = getattr(builtins, type_)
|
||||
@ -172,7 +176,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||
modname, tname = match.group(1).rsplit(".", 1)
|
||||
module = importlib.import_module(modname)
|
||||
cls = getattr(module, tname)
|
||||
return getattr(cls, 'deserialize', cls)
|
||||
return getattr(cls, "deserialize", cls)
|
||||
|
||||
# Otherwise, see if we can find the module and the class
|
||||
options = []
|
||||
@ -181,7 +185,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||
if mod:
|
||||
options.append((mod, type_))
|
||||
|
||||
if '.' in type_: # Fully qualified module
|
||||
if "." in type_: # Fully qualified module
|
||||
module, type_ = type_.rsplit(".", 1)
|
||||
options.append((module, type_))
|
||||
|
||||
@ -190,32 +194,31 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||
try:
|
||||
module = importlib.import_module(modname)
|
||||
cls = getattr(module, tname)
|
||||
return getattr(cls, 'deserialize', cls)
|
||||
return getattr(cls, "deserialize", cls)
|
||||
except (ImportError, AttributeError) as ex:
|
||||
errors.append((modname, tname, ex))
|
||||
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors))
|
||||
|
||||
|
||||
def deserialize(type_, value=None, globs=None, **kwargs):
|
||||
'''Get an object from a text representation'''
|
||||
"""Get an object from a text representation"""
|
||||
if not isinstance(type_, str):
|
||||
return type_
|
||||
if globs and type_ in globs:
|
||||
des = globs[type_]
|
||||
else:
|
||||
des = deserializer(type_, **kwargs)
|
||||
des = deserializer(type_, **kwargs)
|
||||
if value is None:
|
||||
return des
|
||||
return des(value)
|
||||
|
||||
|
||||
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
|
||||
'''Return the list of deserialized objects'''
|
||||
#TODO: remove
|
||||
print('SERIALIZATION', kwargs)
|
||||
"""Return the list of deserialized objects"""
|
||||
# TODO: remove
|
||||
print("SERIALIZATION", kwargs)
|
||||
objects = []
|
||||
for name in names:
|
||||
mod = deserialize(name, known_modules=known_modules)
|
||||
objects.append(mod(*args, **kwargs))
|
||||
return objects
|
||||
|
||||
|
@ -25,7 +25,7 @@ from .time import INFINITY
|
||||
from .config import Config, convert_old
|
||||
|
||||
|
||||
#TODO: change documentation for simulation
|
||||
# TODO: change documentation for simulation
|
||||
@dataclass
|
||||
class Simulation:
|
||||
"""
|
||||
@ -36,15 +36,16 @@ class Simulation:
|
||||
|
||||
kwargs: parameters to use to initialize a new configuration, if one not been provided.
|
||||
"""
|
||||
version: str = '2'
|
||||
name: str = 'Unnamed simulation'
|
||||
description: Optional[str] = ''
|
||||
|
||||
version: str = "2"
|
||||
name: str = "Unnamed simulation"
|
||||
description: Optional[str] = ""
|
||||
group: str = None
|
||||
model_class: Union[str, type] = 'soil.Environment'
|
||||
model_class: Union[str, type] = "soil.Environment"
|
||||
model_params: dict = field(default_factory=dict)
|
||||
seed: str = field(default_factory=lambda: current_time())
|
||||
dir_path: str = field(default_factory=lambda: os.getcwd())
|
||||
max_time: float = float('inf')
|
||||
max_time: float = float("inf")
|
||||
max_steps: int = -1
|
||||
interval: int = 1
|
||||
num_trials: int = 3
|
||||
@ -58,12 +59,13 @@ class Simulation:
|
||||
@classmethod
|
||||
def from_dict(cls, env, **kwargs):
|
||||
|
||||
ignored = {k: v for k, v in env.items()
|
||||
if k not in inspect.signature(cls).parameters}
|
||||
ignored = {
|
||||
k: v for k, v in env.items() if k not in inspect.signature(cls).parameters
|
||||
}
|
||||
|
||||
d = {k:v for k, v in env.items() if k not in ignored}
|
||||
d = {k: v for k, v in env.items() if k not in ignored}
|
||||
if ignored:
|
||||
d.setdefault('extra', {}).update(ignored)
|
||||
d.setdefault("extra", {}).update(ignored)
|
||||
if ignored:
|
||||
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
|
||||
d.update(kwargs)
|
||||
@ -74,24 +76,34 @@ class Simulation:
|
||||
return self.run(*args, **kwargs)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
'''Run the simulation and return the list of resulting environments'''
|
||||
logger.info(dedent('''
|
||||
"""Run the simulation and return the list of resulting environments"""
|
||||
logger.info(
|
||||
dedent(
|
||||
"""
|
||||
Simulation:
|
||||
---
|
||||
''') +
|
||||
self.to_yaml())
|
||||
"""
|
||||
)
|
||||
+ self.to_yaml()
|
||||
)
|
||||
return list(self.run_gen(*args, **kwargs))
|
||||
|
||||
def run_gen(self, parallel=False, dry_run=None,
|
||||
exporters=None, outdir=None, exporter_params={},
|
||||
log_level=None,
|
||||
**kwargs):
|
||||
'''Run the simulation and yield the resulting environments.'''
|
||||
def run_gen(
|
||||
self,
|
||||
parallel=False,
|
||||
dry_run=None,
|
||||
exporters=None,
|
||||
outdir=None,
|
||||
exporter_params={},
|
||||
log_level=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Run the simulation and yield the resulting environments."""
|
||||
if log_level:
|
||||
logger.setLevel(log_level)
|
||||
outdir = outdir or self.outdir
|
||||
logger.info('Using exporters: %s', exporters or [])
|
||||
logger.info('Output directory: %s', outdir)
|
||||
logger.info("Using exporters: %s", exporters or [])
|
||||
logger.info("Output directory: %s", outdir)
|
||||
if dry_run is None:
|
||||
dry_run = self.dry_run
|
||||
if exporters is None:
|
||||
@ -99,22 +111,28 @@ class Simulation:
|
||||
if not exporter_params:
|
||||
exporter_params = self.exporter_params
|
||||
|
||||
exporters = serialization.deserialize_all(exporters,
|
||||
simulation=self,
|
||||
known_modules=['soil.exporters', ],
|
||||
dry_run=dry_run,
|
||||
outdir=outdir,
|
||||
**exporter_params)
|
||||
exporters = serialization.deserialize_all(
|
||||
exporters,
|
||||
simulation=self,
|
||||
known_modules=[
|
||||
"soil.exporters",
|
||||
],
|
||||
dry_run=dry_run,
|
||||
outdir=outdir,
|
||||
**exporter_params,
|
||||
)
|
||||
|
||||
with utils.timer('simulation {}'.format(self.name)):
|
||||
with utils.timer("simulation {}".format(self.name)):
|
||||
for exporter in exporters:
|
||||
exporter.sim_start()
|
||||
|
||||
for env in utils.run_parallel(func=self.run_trial,
|
||||
iterable=range(int(self.num_trials)),
|
||||
parallel=parallel,
|
||||
log_level=log_level,
|
||||
**kwargs):
|
||||
for env in utils.run_parallel(
|
||||
func=self.run_trial,
|
||||
iterable=range(int(self.num_trials)),
|
||||
parallel=parallel,
|
||||
log_level=log_level,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
for exporter in exporters:
|
||||
exporter.trial_start(env)
|
||||
@ -128,11 +146,12 @@ class Simulation:
|
||||
exporter.sim_end()
|
||||
|
||||
def get_env(self, trial_id=0, model_params=None, **kwargs):
|
||||
'''Create an environment for a trial of the simulation'''
|
||||
"""Create an environment for a trial of the simulation"""
|
||||
|
||||
def deserialize_reporters(reporters):
|
||||
for (k, v) in reporters.items():
|
||||
if isinstance(v, str) and v.startswith('py:'):
|
||||
reporters[k] = serialization.deserialize(value.lsplit(':', 1)[1])
|
||||
if isinstance(v, str) and v.startswith("py:"):
|
||||
reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1])
|
||||
return reporters
|
||||
|
||||
params = self.model_params.copy()
|
||||
@ -140,18 +159,22 @@ class Simulation:
|
||||
params.update(model_params)
|
||||
params.update(kwargs)
|
||||
|
||||
agent_reporters = deserialize_reporters(params.pop('agent_reporters', {}))
|
||||
model_reporters = deserialize_reporters(params.pop('model_reporters', {}))
|
||||
agent_reporters = deserialize_reporters(params.pop("agent_reporters", {}))
|
||||
model_reporters = deserialize_reporters(params.pop("model_reporters", {}))
|
||||
|
||||
env = serialization.deserialize(self.model_class)
|
||||
return env(id=f'{self.name}_trial_{trial_id}',
|
||||
seed=f'{self.seed}_trial_{trial_id}',
|
||||
dir_path=self.dir_path,
|
||||
agent_reporters=agent_reporters,
|
||||
model_reporters=model_reporters,
|
||||
**params)
|
||||
env = serialization.deserialize(self.model_class)
|
||||
return env(
|
||||
id=f"{self.name}_trial_{trial_id}",
|
||||
seed=f"{self.seed}_trial_{trial_id}",
|
||||
dir_path=self.dir_path,
|
||||
agent_reporters=agent_reporters,
|
||||
model_reporters=model_reporters,
|
||||
**params,
|
||||
)
|
||||
|
||||
def run_trial(self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts):
|
||||
def run_trial(
|
||||
self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts
|
||||
):
|
||||
"""
|
||||
Run a single trial of the simulation
|
||||
|
||||
@ -160,50 +183,58 @@ class Simulation:
|
||||
logger.setLevel(log_level)
|
||||
model = self.get_env(trial_id, **opts)
|
||||
trial_id = trial_id if trial_id is not None else current_time()
|
||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
||||
return self.run_model(model=model, trial_id=trial_id, until=until, log_level=log_level)
|
||||
with utils.timer("Simulation {} trial {}".format(self.name, trial_id)):
|
||||
return self.run_model(
|
||||
model=model, trial_id=trial_id, until=until, log_level=log_level
|
||||
)
|
||||
|
||||
def run_model(self, model, until=None, **opts):
|
||||
# Set-up trial environment and graph
|
||||
until = float(until or self.max_time or 'inf')
|
||||
until = float(until or self.max_time or "inf")
|
||||
|
||||
# Set up agents on nodes
|
||||
def is_done():
|
||||
return False
|
||||
|
||||
if until and hasattr(model.schedule, 'time'):
|
||||
if until and hasattr(model.schedule, "time"):
|
||||
prev = is_done
|
||||
|
||||
def is_done():
|
||||
return prev() or model.schedule.time >= until
|
||||
|
||||
if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, 'steps'):
|
||||
if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, "steps"):
|
||||
prev_steps = is_done
|
||||
|
||||
def is_done():
|
||||
return prev_steps() or model.schedule.steps >= self.max_steps
|
||||
|
||||
newline = '\n'
|
||||
logger.info(dedent(f'''
|
||||
newline = "\n"
|
||||
logger.info(
|
||||
dedent(
|
||||
f"""
|
||||
Model stats:
|
||||
Agents (total: { model.schedule.get_agent_count() }):
|
||||
- { (newline + ' - ').join(str(a) for a in model.schedule.agents) }
|
||||
|
||||
Topology size: { len(model.G) if hasattr(model, "G") else 0 }
|
||||
'''))
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
while not is_done():
|
||||
utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}')
|
||||
utils.logger.debug(
|
||||
f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}'
|
||||
)
|
||||
model.step()
|
||||
return model
|
||||
|
||||
def to_dict(self):
|
||||
d = asdict(self)
|
||||
if not isinstance(d['model_class'], str):
|
||||
d['model_class'] = serialization.name(d['model_class'])
|
||||
d['model_params'] = serialization.serialize_dict(d['model_params'])
|
||||
d['dir_path'] = str(d['dir_path'])
|
||||
d['version'] = '2'
|
||||
if not isinstance(d["model_class"], str):
|
||||
d["model_class"] = serialization.name(d["model_class"])
|
||||
d["model_params"] = serialization.serialize_dict(d["model_params"])
|
||||
d["dir_path"] = str(d["dir_path"])
|
||||
d["version"] = "2"
|
||||
return d
|
||||
|
||||
def to_yaml(self):
|
||||
@ -215,15 +246,15 @@ def iter_from_config(*cfgs, **kwargs):
|
||||
configs = list(serialization.load_config(config))
|
||||
for config, path in configs:
|
||||
d = dict(config)
|
||||
if 'dir_path' not in d:
|
||||
d['dir_path'] = os.path.dirname(path)
|
||||
if "dir_path" not in d:
|
||||
d["dir_path"] = os.path.dirname(path)
|
||||
yield Simulation.from_dict(d, **kwargs)
|
||||
|
||||
|
||||
def from_config(conf_or_path):
|
||||
lst = list(iter_from_config(conf_or_path))
|
||||
if len(lst) > 1:
|
||||
raise AttributeError('Provide only one configuration')
|
||||
raise AttributeError("Provide only one configuration")
|
||||
return lst[0]
|
||||
|
||||
|
||||
|
19
soil/time.py
19
soil/time.py
@ -6,7 +6,8 @@ from .utils import logger
|
||||
from mesa import Agent as MesaAgent
|
||||
|
||||
|
||||
INFINITY = float('inf')
|
||||
INFINITY = float("inf")
|
||||
|
||||
|
||||
class When:
|
||||
def __init__(self, time):
|
||||
@ -42,7 +43,7 @@ class TimedActivation(BaseScheduler):
|
||||
self._next = {}
|
||||
self._queue = []
|
||||
self.next_time = 0
|
||||
self.logger = logger.getChild(f'time_{ self.model }')
|
||||
self.logger = logger.getChild(f"time_{ self.model }")
|
||||
|
||||
def add(self, agent: MesaAgent, when=None):
|
||||
if when is None:
|
||||
@ -62,7 +63,7 @@ class TimedActivation(BaseScheduler):
|
||||
an agent will signal when it wants to be scheduled next.
|
||||
"""
|
||||
|
||||
self.logger.debug(f'Simulation step {self.next_time}')
|
||||
self.logger.debug(f"Simulation step {self.next_time}")
|
||||
if not self.model.running:
|
||||
return
|
||||
|
||||
@ -71,18 +72,22 @@ class TimedActivation(BaseScheduler):
|
||||
|
||||
while self._queue and self._queue[0][0] == self.time:
|
||||
(when, agent_id) = heappop(self._queue)
|
||||
self.logger.debug(f'Stepping agent {agent_id}')
|
||||
self.logger.debug(f"Stepping agent {agent_id}")
|
||||
|
||||
agent = self._agents[agent_id]
|
||||
returned = agent.step()
|
||||
|
||||
if not getattr(agent, 'alive', True):
|
||||
if not getattr(agent, "alive", True):
|
||||
self.remove(agent)
|
||||
continue
|
||||
|
||||
when = (returned or Delta(1)).abs(self.time)
|
||||
if when < self.time:
|
||||
raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time))
|
||||
raise Exception(
|
||||
"Cannot schedule an agent for a time in the past ({} < {})".format(
|
||||
when, self.time
|
||||
)
|
||||
)
|
||||
|
||||
self._next[agent_id] = when
|
||||
heappush(self._queue, (when, agent_id))
|
||||
@ -96,4 +101,4 @@ class TimedActivation(BaseScheduler):
|
||||
return self.time
|
||||
|
||||
self.next_time = self._queue[0][0]
|
||||
self.logger.debug(f'Next step: {self.next_time}')
|
||||
self.logger.debug(f"Next step: {self.next_time}")
|
||||
|
@ -9,12 +9,12 @@ from multiprocessing import Pool
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = logging.getLogger('soil')
|
||||
logger = logging.getLogger("soil")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
timeformat = "%H:%M:%S"
|
||||
|
||||
if os.environ.get('SOIL_VERBOSE', ''):
|
||||
if os.environ.get("SOIL_VERBOSE", ""):
|
||||
logformat = "[%(levelname)-5.5s][%(asctime)s][%(name)s]: %(message)s"
|
||||
else:
|
||||
logformat = "[%(levelname)-5.5s][%(asctime)s] %(message)s"
|
||||
@ -23,38 +23,44 @@ logFormatter = logging.Formatter(logformat, timeformat)
|
||||
consoleHandler = logging.StreamHandler()
|
||||
consoleHandler.setFormatter(logFormatter)
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
handlers=[consoleHandler,])
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[
|
||||
consoleHandler,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
||||
def timer(name="task", pre="", function=logger.info, to_object=None):
|
||||
start = current_time()
|
||||
function('{}Starting {} at {}.'.format(pre, name,
|
||||
strftime("%X", gmtime(start))))
|
||||
function("{}Starting {} at {}.".format(pre, name, strftime("%X", gmtime(start))))
|
||||
yield start
|
||||
end = current_time()
|
||||
function('{}Finished {} at {} in {} seconds'.format(pre, name,
|
||||
strftime("%X", gmtime(end)),
|
||||
str(end-start)))
|
||||
function(
|
||||
"{}Finished {} at {} in {} seconds".format(
|
||||
pre, name, strftime("%X", gmtime(end)), str(end - start)
|
||||
)
|
||||
)
|
||||
if to_object:
|
||||
to_object.start = start
|
||||
to_object.end = end
|
||||
|
||||
|
||||
def safe_open(path, mode='r', backup=True, **kwargs):
|
||||
def safe_open(path, mode="r", backup=True, **kwargs):
|
||||
outdir = os.path.dirname(path)
|
||||
if outdir and not os.path.exists(outdir):
|
||||
os.makedirs(outdir)
|
||||
if backup and 'w' in mode and os.path.exists(path):
|
||||
if backup and "w" in mode and os.path.exists(path):
|
||||
creation = os.path.getctime(path)
|
||||
stamp = strftime('%Y-%m-%d_%H.%M.%S', localtime(creation))
|
||||
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
|
||||
|
||||
backup_dir = os.path.join(outdir, 'backup')
|
||||
backup_dir = os.path.join(outdir, "backup")
|
||||
if not os.path.exists(backup_dir):
|
||||
os.makedirs(backup_dir)
|
||||
newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path),
|
||||
stamp))
|
||||
newpath = os.path.join(
|
||||
backup_dir, "{}@{}".format(os.path.basename(path), stamp)
|
||||
)
|
||||
copyfile(path, newpath)
|
||||
return open(path, mode=mode, **kwargs)
|
||||
|
||||
@ -67,21 +73,23 @@ def open_or_reuse(f, *args, **kwargs):
|
||||
except (AttributeError, TypeError):
|
||||
yield f
|
||||
|
||||
|
||||
def flatten_dict(d):
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
return dict(_flatten_dict(d))
|
||||
|
||||
def _flatten_dict(d, prefix=''):
|
||||
|
||||
def _flatten_dict(d, prefix=""):
|
||||
if not isinstance(d, dict):
|
||||
# print('END:', prefix, d)
|
||||
yield prefix, d
|
||||
return
|
||||
if prefix:
|
||||
prefix = prefix + '.'
|
||||
prefix = prefix + "."
|
||||
for k, v in d.items():
|
||||
# print(k, v)
|
||||
res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k)))
|
||||
res = list(_flatten_dict(v, prefix="{}{}".format(prefix, k)))
|
||||
# print('RES:', res)
|
||||
yield from res
|
||||
|
||||
@ -93,7 +101,7 @@ def unflatten_dict(d):
|
||||
if not isinstance(k, str):
|
||||
target[k] = v
|
||||
continue
|
||||
tokens = k.split('.')
|
||||
tokens = k.split(".")
|
||||
if len(tokens) < 2:
|
||||
target[k] = v
|
||||
continue
|
||||
@ -106,27 +114,28 @@ def unflatten_dict(d):
|
||||
|
||||
|
||||
def run_and_return_exceptions(func, *args, **kwargs):
|
||||
'''
|
||||
"""
|
||||
A wrapper for run_trial that catches exceptions and returns them.
|
||||
It is meant for async simulations.
|
||||
'''
|
||||
"""
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as ex:
|
||||
if ex.__cause__ is not None:
|
||||
ex = ex.__cause__
|
||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
||||
ex.message = "".join(
|
||||
traceback.format_exception(type(ex), ex, ex.__traceback__)[:]
|
||||
)
|
||||
return ex
|
||||
|
||||
|
||||
def run_parallel(func, iterable, parallel=False, **kwargs):
|
||||
if parallel and not os.environ.get('SOIL_DEBUG', None):
|
||||
if parallel and not os.environ.get("SOIL_DEBUG", None):
|
||||
p = Pool()
|
||||
wrapped_func = partial(run_and_return_exceptions,
|
||||
func, **kwargs)
|
||||
wrapped_func = partial(run_and_return_exceptions, func, **kwargs)
|
||||
for i in p.imap_unordered(wrapped_func, iterable):
|
||||
if isinstance(i, Exception):
|
||||
logger.error('Trial failed:\n\t%s', i.message)
|
||||
logger.error("Trial failed:\n\t%s", i.message)
|
||||
continue
|
||||
yield i
|
||||
else:
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ROOT = os.path.dirname(__file__)
|
||||
DEFAULT_FILE = os.path.join(ROOT, 'VERSION')
|
||||
DEFAULT_FILE = os.path.join(ROOT, "VERSION")
|
||||
|
||||
|
||||
def read_version(versionfile=DEFAULT_FILE):
|
||||
@ -12,9 +12,10 @@ def read_version(versionfile=DEFAULT_FILE):
|
||||
with open(versionfile) as f:
|
||||
return f.read().strip()
|
||||
except IOError: # pragma: no cover
|
||||
logger.error(('Running an unknown version of {}.'
|
||||
'Be careful!.').format(__name__))
|
||||
return '0.0'
|
||||
logger.error(
|
||||
("Running an unknown version of {}." "Be careful!.").format(__name__)
|
||||
)
|
||||
return "0.0"
|
||||
|
||||
|
||||
__version__ = read_version()
|
||||
|
@ -1,5 +1,6 @@
|
||||
from mesa.visualization.UserParam import UserSettableParameter
|
||||
|
||||
|
||||
class UserSettableParameter(UserSettableParameter):
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
@ -20,6 +20,7 @@ from tornado.concurrent import run_on_executor
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from ..simulation import Simulation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
@ -31,140 +32,183 @@ LOGGING_INTERVAL = 0.5
|
||||
# Workaround to let Soil load the required modules
|
||||
sys.path.append(ROOT)
|
||||
|
||||
|
||||
class PageHandler(tornado.web.RequestHandler):
|
||||
""" Handler for the HTML template which holds the visualization. """
|
||||
"""Handler for the HTML template which holds the visualization."""
|
||||
|
||||
def get(self):
|
||||
self.render('index.html', port=self.application.port,
|
||||
name=self.application.name)
|
||||
self.render(
|
||||
"index.html", port=self.application.port, name=self.application.name
|
||||
)
|
||||
|
||||
|
||||
class SocketHandler(tornado.websocket.WebSocketHandler):
|
||||
""" Handler for websocket. """
|
||||
"""Handler for websocket."""
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
||||
|
||||
def open(self):
|
||||
if self.application.verbose:
|
||||
logger.info('Socket opened!')
|
||||
logger.info("Socket opened!")
|
||||
|
||||
def check_origin(self, origin):
|
||||
return True
|
||||
|
||||
def on_message(self, message):
|
||||
""" Receiving a message from the websocket, parse, and act accordingly. """
|
||||
"""Receiving a message from the websocket, parse, and act accordingly."""
|
||||
|
||||
msg = tornado.escape.json_decode(message)
|
||||
|
||||
if msg['type'] == 'config_file':
|
||||
if msg["type"] == "config_file":
|
||||
|
||||
if self.application.verbose:
|
||||
print(msg['data'])
|
||||
print(msg["data"])
|
||||
|
||||
self.config = list(yaml.load_all(msg['data']))
|
||||
self.config = list(yaml.load_all(msg["data"]))
|
||||
|
||||
if len(self.config) > 1:
|
||||
error = 'Please, provide only one configuration.'
|
||||
error = "Please, provide only one configuration."
|
||||
if self.application.verbose:
|
||||
logger.error(error)
|
||||
self.write_message({'type': 'error',
|
||||
'error': error})
|
||||
self.write_message({"type": "error", "error": error})
|
||||
return
|
||||
|
||||
self.config = self.config[0]
|
||||
self.send_log('INFO.' + self.simulation_name,
|
||||
'Using config: {name}'.format(name=self.config['name']))
|
||||
self.send_log(
|
||||
"INFO." + self.simulation_name,
|
||||
"Using config: {name}".format(name=self.config["name"]),
|
||||
)
|
||||
|
||||
if 'visualization_params' in self.config:
|
||||
self.write_message({'type': 'visualization_params',
|
||||
'data': self.config['visualization_params']})
|
||||
self.name = self.config['name']
|
||||
if "visualization_params" in self.config:
|
||||
self.write_message(
|
||||
{
|
||||
"type": "visualization_params",
|
||||
"data": self.config["visualization_params"],
|
||||
}
|
||||
)
|
||||
self.name = self.config["name"]
|
||||
self.run_simulation()
|
||||
|
||||
settings = []
|
||||
for key in self.config['environment_params']:
|
||||
if type(self.config['environment_params'][key]) == float or type(self.config['environment_params'][key]) == int:
|
||||
if self.config['environment_params'][key] <= 1:
|
||||
setting_type = 'number'
|
||||
for key in self.config["environment_params"]:
|
||||
if (
|
||||
type(self.config["environment_params"][key]) == float
|
||||
or type(self.config["environment_params"][key]) == int
|
||||
):
|
||||
if self.config["environment_params"][key] <= 1:
|
||||
setting_type = "number"
|
||||
else:
|
||||
setting_type = 'great_number'
|
||||
elif type(self.config['environment_params'][key]) == bool:
|
||||
setting_type = 'boolean'
|
||||
setting_type = "great_number"
|
||||
elif type(self.config["environment_params"][key]) == bool:
|
||||
setting_type = "boolean"
|
||||
else:
|
||||
setting_type = 'undefined'
|
||||
setting_type = "undefined"
|
||||
|
||||
settings.append({
|
||||
'label': key,
|
||||
'type': setting_type,
|
||||
'value': self.config['environment_params'][key]
|
||||
})
|
||||
settings.append(
|
||||
{
|
||||
"label": key,
|
||||
"type": setting_type,
|
||||
"value": self.config["environment_params"][key],
|
||||
}
|
||||
)
|
||||
|
||||
self.write_message({'type': 'settings',
|
||||
'data': settings})
|
||||
self.write_message({"type": "settings", "data": settings})
|
||||
|
||||
elif msg['type'] == 'get_trial':
|
||||
elif msg["type"] == "get_trial":
|
||||
if self.application.verbose:
|
||||
logger.info('Trial {} requested!'.format(msg['data']))
|
||||
self.send_log('INFO.' + __name__, 'Trial {} requested!'.format(msg['data']))
|
||||
self.write_message({'type': 'get_trial',
|
||||
'data': self.get_trial(int(msg['data']))})
|
||||
logger.info("Trial {} requested!".format(msg["data"]))
|
||||
self.send_log("INFO." + __name__, "Trial {} requested!".format(msg["data"]))
|
||||
self.write_message(
|
||||
{"type": "get_trial", "data": self.get_trial(int(msg["data"]))}
|
||||
)
|
||||
|
||||
elif msg['type'] == 'run_simulation':
|
||||
elif msg["type"] == "run_simulation":
|
||||
if self.application.verbose:
|
||||
logger.info('Running new simulation for {name}'.format(name=self.config['name']))
|
||||
self.send_log('INFO.' + self.simulation_name, 'Running new simulation for {name}'.format(name=self.config['name']))
|
||||
self.config['environment_params'] = msg['data']
|
||||
logger.info(
|
||||
"Running new simulation for {name}".format(name=self.config["name"])
|
||||
)
|
||||
self.send_log(
|
||||
"INFO." + self.simulation_name,
|
||||
"Running new simulation for {name}".format(name=self.config["name"]),
|
||||
)
|
||||
self.config["environment_params"] = msg["data"]
|
||||
self.run_simulation()
|
||||
|
||||
elif msg['type'] == 'download_gexf':
|
||||
G = self.trials[ int(msg['data']) ].history_to_graph()
|
||||
elif msg["type"] == "download_gexf":
|
||||
G = self.trials[int(msg["data"])].history_to_graph()
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
writer = nx.readwrite.gexf.GEXFWriter(version='1.2draft')
|
||||
if "pos" in G.nodes[node]:
|
||||
G.nodes[node]["viz"] = {
|
||||
"position": {
|
||||
"x": G.nodes[node]["pos"][0],
|
||||
"y": G.nodes[node]["pos"][1],
|
||||
"z": 0.0,
|
||||
}
|
||||
}
|
||||
del G.nodes[node]["pos"]
|
||||
writer = nx.readwrite.gexf.GEXFWriter(version="1.2draft")
|
||||
writer.add_graph(G)
|
||||
self.write_message({'type': 'download_gexf',
|
||||
'filename': self.config['name'] + '_trial_' + str(msg['data']),
|
||||
'data': tostring(writer.xml).decode(writer.encoding) })
|
||||
self.write_message(
|
||||
{
|
||||
"type": "download_gexf",
|
||||
"filename": self.config["name"] + "_trial_" + str(msg["data"]),
|
||||
"data": tostring(writer.xml).decode(writer.encoding),
|
||||
}
|
||||
)
|
||||
|
||||
elif msg['type'] == 'download_json':
|
||||
G = self.trials[ int(msg['data']) ].history_to_graph()
|
||||
elif msg["type"] == "download_json":
|
||||
G = self.trials[int(msg["data"])].history_to_graph()
|
||||
for node in G.nodes():
|
||||
if 'pos' in G.nodes[node]:
|
||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
||||
del (G.nodes[node]['pos'])
|
||||
self.write_message({'type': 'download_json',
|
||||
'filename': self.config['name'] + '_trial_' + str(msg['data']),
|
||||
'data': nx.node_link_data(G) })
|
||||
if "pos" in G.nodes[node]:
|
||||
G.nodes[node]["viz"] = {
|
||||
"position": {
|
||||
"x": G.nodes[node]["pos"][0],
|
||||
"y": G.nodes[node]["pos"][1],
|
||||
"z": 0.0,
|
||||
}
|
||||
}
|
||||
del G.nodes[node]["pos"]
|
||||
self.write_message(
|
||||
{
|
||||
"type": "download_json",
|
||||
"filename": self.config["name"] + "_trial_" + str(msg["data"]),
|
||||
"data": nx.node_link_data(G),
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
if self.application.verbose:
|
||||
logger.info('Unexpected message!')
|
||||
logger.info("Unexpected message!")
|
||||
|
||||
def update_logging(self):
|
||||
try:
|
||||
if (not self.log_capture_string.closed and self.log_capture_string.getvalue()):
|
||||
for i in range(len(self.log_capture_string.getvalue().split('\n')) - 1):
|
||||
self.send_log('INFO.' + self.simulation_name, self.log_capture_string.getvalue().split('\n')[i])
|
||||
if (
|
||||
not self.log_capture_string.closed
|
||||
and self.log_capture_string.getvalue()
|
||||
):
|
||||
for i in range(len(self.log_capture_string.getvalue().split("\n")) - 1):
|
||||
self.send_log(
|
||||
"INFO." + self.simulation_name,
|
||||
self.log_capture_string.getvalue().split("\n")[i],
|
||||
)
|
||||
self.log_capture_string.truncate(0)
|
||||
self.log_capture_string.seek(0)
|
||||
finally:
|
||||
if self.capture_logging:
|
||||
tornado.ioloop.IOLoop.current().call_later(LOGGING_INTERVAL, self.update_logging)
|
||||
|
||||
tornado.ioloop.IOLoop.current().call_later(
|
||||
LOGGING_INTERVAL, self.update_logging
|
||||
)
|
||||
|
||||
def on_close(self):
|
||||
if self.application.verbose:
|
||||
logger.info('Socket closed!')
|
||||
logger.info("Socket closed!")
|
||||
|
||||
def send_log(self, logger, logging):
|
||||
self.write_message({'type': 'log',
|
||||
'logger': logger,
|
||||
'logging': logging})
|
||||
self.write_message({"type": "log", "logger": logger, "logging": logging})
|
||||
|
||||
@property
|
||||
def simulation_name(self):
|
||||
return self.config.get('name', 'NoSimulationRunning')
|
||||
return self.config.get("name", "NoSimulationRunning")
|
||||
|
||||
@run_on_executor
|
||||
def nonblocking(self, config):
|
||||
@ -174,28 +218,31 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
||||
@tornado.gen.coroutine
|
||||
def run_simulation(self):
|
||||
# Run simulation and capture logs
|
||||
logger.info('Running simulation!')
|
||||
if 'visualization_params' in self.config:
|
||||
del self.config['visualization_params']
|
||||
logger.info("Running simulation!")
|
||||
if "visualization_params" in self.config:
|
||||
del self.config["visualization_params"]
|
||||
with self.logging(self.simulation_name):
|
||||
try:
|
||||
config = dict(**self.config)
|
||||
config['outdir'] = os.path.join(self.application.outdir, config['name'])
|
||||
config['dump'] = self.application.dump
|
||||
config["outdir"] = os.path.join(self.application.outdir, config["name"])
|
||||
config["dump"] = self.application.dump
|
||||
self.trials = yield self.nonblocking(config)
|
||||
|
||||
self.write_message({'type': 'trials',
|
||||
'data': list(trial.name for trial in self.trials) })
|
||||
self.write_message(
|
||||
{
|
||||
"type": "trials",
|
||||
"data": list(trial.name for trial in self.trials),
|
||||
}
|
||||
)
|
||||
except Exception as ex:
|
||||
error = 'Something went wrong:\n\t{}'.format(ex)
|
||||
error = "Something went wrong:\n\t{}".format(ex)
|
||||
logging.info(error)
|
||||
self.write_message({'type': 'error',
|
||||
'error': error})
|
||||
self.send_log('ERROR.' + self.simulation_name, error)
|
||||
self.write_message({"type": "error", "error": error})
|
||||
self.send_log("ERROR." + self.simulation_name, error)
|
||||
|
||||
def get_trial(self, trial):
|
||||
logger.info('Available trials: %s ' % len(self.trials))
|
||||
logger.info('Ask for : %s' % trial)
|
||||
logger.info("Available trials: %s " % len(self.trials))
|
||||
logger.info("Ask for : %s" % trial)
|
||||
trial = self.trials[trial]
|
||||
G = trial.history_to_graph()
|
||||
return nx.node_link_data(G)
|
||||
@ -218,21 +265,24 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
||||
|
||||
|
||||
class ModularServer(tornado.web.Application):
|
||||
""" Main visualization application. """
|
||||
"""Main visualization application."""
|
||||
|
||||
port = 8001
|
||||
page_handler = (r'/', PageHandler)
|
||||
socket_handler = (r'/ws', SocketHandler)
|
||||
static_handler = (r'/(.*)', tornado.web.StaticFileHandler,
|
||||
{'path': os.path.join(ROOT, 'static')})
|
||||
local_handler = (r'/local/(.*)', tornado.web.StaticFileHandler,
|
||||
{'path': ''})
|
||||
page_handler = (r"/", PageHandler)
|
||||
socket_handler = (r"/ws", SocketHandler)
|
||||
static_handler = (
|
||||
r"/(.*)",
|
||||
tornado.web.StaticFileHandler,
|
||||
{"path": os.path.join(ROOT, "static")},
|
||||
)
|
||||
local_handler = (r"/local/(.*)", tornado.web.StaticFileHandler, {"path": ""})
|
||||
|
||||
handlers = [page_handler, socket_handler, static_handler, local_handler]
|
||||
settings = {'debug': True,
|
||||
'template_path': ROOT + '/templates'}
|
||||
settings = {"debug": True, "template_path": ROOT + "/templates"}
|
||||
|
||||
def __init__(self, dump=False, outdir='output', name='SOIL', verbose=True, *args, **kwargs):
|
||||
def __init__(
|
||||
self, dump=False, outdir="output", name="SOIL", verbose=True, *args, **kwargs
|
||||
):
|
||||
|
||||
self.verbose = verbose
|
||||
self.name = name
|
||||
@ -243,12 +293,12 @@ class ModularServer(tornado.web.Application):
|
||||
super().__init__(self.handlers, **self.settings)
|
||||
|
||||
def launch(self, port=None):
|
||||
""" Run the app. """
|
||||
"""Run the app."""
|
||||
|
||||
if port is not None:
|
||||
self.port = port
|
||||
url = 'http://127.0.0.1:{PORT}'.format(PORT=self.port)
|
||||
print('Interface starting at {url}'.format(url=url))
|
||||
url = "http://127.0.0.1:{PORT}".format(PORT=self.port)
|
||||
print("Interface starting at {url}".format(url=url))
|
||||
self.listen(self.port)
|
||||
# webbrowser.open(url)
|
||||
tornado.ioloop.IOLoop.instance().start()
|
||||
@ -263,12 +313,22 @@ def run(*args, **kwargs):
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Visualization of a Graph Model')
|
||||
parser = argparse.ArgumentParser(description="Visualization of a Graph Model")
|
||||
|
||||
parser.add_argument('--name', '-n', nargs=1, default='SOIL', help='name of the simulation')
|
||||
parser.add_argument('--dump', '-d', help='dumping results in folder output', action='store_true')
|
||||
parser.add_argument('--port', '-p', nargs=1, default=8001, help='port for launching the server')
|
||||
parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true')
|
||||
parser.add_argument(
|
||||
"--name", "-n", nargs=1, default="SOIL", help="name of the simulation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dump", "-d", help="dumping results in folder output", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", "-p", nargs=1, default=8001, help="port for launching the server"
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", help="verbose mode", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
run(name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)
|
||||
run(
|
||||
name=args.name,
|
||||
port=(args.port[0] if isinstance(args.port, list) else args.port),
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
@ -4,20 +4,33 @@ from simulator import Simulator
|
||||
|
||||
|
||||
def run(simulator, name="SOIL", port=8001, verbose=False):
|
||||
server = ModularServer(simulator, name=(name[0] if isinstance(name, list) else name), verbose=verbose)
|
||||
server = ModularServer(
|
||||
simulator, name=(name[0] if isinstance(name, list) else name), verbose=verbose
|
||||
)
|
||||
server.port = port
|
||||
server.launch()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description='Visualization of a Graph Model')
|
||||
parser = argparse.ArgumentParser(description="Visualization of a Graph Model")
|
||||
|
||||
parser.add_argument('--name', '-n', nargs=1, default='SOIL', help='name of the simulation')
|
||||
parser.add_argument('--dump', '-d', help='dumping results in folder output', action='store_true')
|
||||
parser.add_argument('--port', '-p', nargs=1, default=8001, help='port for launching the server')
|
||||
parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true')
|
||||
parser.add_argument(
|
||||
"--name", "-n", nargs=1, default="SOIL", help="name of the simulation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dump", "-d", help="dumping results in folder output", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", "-p", nargs=1, default=8001, help="port for launching the server"
|
||||
)
|
||||
parser.add_argument("--verbose", "-v", help="verbose mode", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
soil = Simulator(dump=args.dump)
|
||||
run(soil, name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)
|
||||
run(
|
||||
soil,
|
||||
name=args.name,
|
||||
port=(args.port[0] if isinstance(args.port, list) else args.port),
|
||||
verbose=args.verbose,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user