mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
Formatted with black
This commit is contained in:
parent
d9947c2c52
commit
78833a9e08
177
soil/__init__.py
177
soil/__init__.py
@ -22,58 +22,107 @@ from .utils import logger
|
|||||||
from .time import *
|
from .time import *
|
||||||
|
|
||||||
|
|
||||||
def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_output", *, do_run=False, debug=False, **kwargs):
|
def main(
|
||||||
|
cfg="simulation.yml",
|
||||||
|
exporters=None,
|
||||||
|
parallel=None,
|
||||||
|
output="soil_output",
|
||||||
|
*,
|
||||||
|
do_run=False,
|
||||||
|
debug=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
import argparse
|
import argparse
|
||||||
from . import simulation
|
from . import simulation
|
||||||
|
|
||||||
logger.info('Running SOIL version: {}'.format(__version__))
|
logger.info("Running SOIL version: {}".format(__version__))
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Run a SOIL simulation')
|
parser = argparse.ArgumentParser(description="Run a SOIL simulation")
|
||||||
parser.add_argument('file', type=str,
|
parser.add_argument(
|
||||||
nargs="?",
|
"file",
|
||||||
default=cfg,
|
type=str,
|
||||||
help='Configuration file for the simulation (e.g., YAML or JSON)')
|
nargs="?",
|
||||||
parser.add_argument('--version', action='store_true',
|
default=cfg,
|
||||||
help='Show version info and exit')
|
help="Configuration file for the simulation (e.g., YAML or JSON)",
|
||||||
parser.add_argument('--module', '-m', type=str,
|
)
|
||||||
help='file containing the code of any custom agents.')
|
parser.add_argument(
|
||||||
parser.add_argument('--dry-run', '--dry', action='store_true',
|
"--version", action="store_true", help="Show version info and exit"
|
||||||
help='Do not store the results of the simulation to disk, show in terminal instead.')
|
)
|
||||||
parser.add_argument('--pdb', action='store_true',
|
parser.add_argument(
|
||||||
help='Use a pdb console in case of exception.')
|
"--module",
|
||||||
parser.add_argument('--debug', action='store_true',
|
"-m",
|
||||||
help='Run a customized version of a pdb console to debug a simulation.')
|
type=str,
|
||||||
parser.add_argument('--graph', '-g', action='store_true',
|
help="file containing the code of any custom agents.",
|
||||||
help='Dump each trial\'s network topology as a GEXF graph. Defaults to false.')
|
)
|
||||||
parser.add_argument('--csv', action='store_true',
|
parser.add_argument(
|
||||||
help='Dump all data collected in CSV format. Defaults to false.')
|
"--dry-run",
|
||||||
parser.add_argument('--level', type=str,
|
"--dry",
|
||||||
help='Logging level')
|
action="store_true",
|
||||||
parser.add_argument('--output', '-o', type=str, default=output or "soil_output",
|
help="Do not store the results of the simulation to disk, show in terminal instead.",
|
||||||
help='folder to write results to. It defaults to the current directory.')
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pdb", action="store_true", help="Use a pdb console in case of exception."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
action="store_true",
|
||||||
|
help="Run a customized version of a pdb console to debug a simulation.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--graph",
|
||||||
|
"-g",
|
||||||
|
action="store_true",
|
||||||
|
help="Dump each trial's network topology as a GEXF graph. Defaults to false.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv",
|
||||||
|
action="store_true",
|
||||||
|
help="Dump all data collected in CSV format. Defaults to false.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--level", type=str, help="Logging level")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
"-o",
|
||||||
|
type=str,
|
||||||
|
default=output or "soil_output",
|
||||||
|
help="folder to write results to. It defaults to the current directory.",
|
||||||
|
)
|
||||||
if parallel is None:
|
if parallel is None:
|
||||||
parser.add_argument('--synchronous', action='store_true',
|
parser.add_argument(
|
||||||
help='Run trials serially and synchronously instead of in parallel. Defaults to false.')
|
"--synchronous",
|
||||||
|
action="store_true",
|
||||||
|
help="Run trials serially and synchronously instead of in parallel. Defaults to false.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('-e', '--exporter', action='append',
|
parser.add_argument(
|
||||||
default=[],
|
"-e",
|
||||||
help='Export environment and/or simulations using this exporter')
|
"--exporter",
|
||||||
|
action="append",
|
||||||
|
default=[],
|
||||||
|
help="Export environment and/or simulations using this exporter",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument('--only-convert', '--convert', action='store_true',
|
parser.add_argument(
|
||||||
help='Do not run the simulation, only convert the configuration file(s) and output them.')
|
"--only-convert",
|
||||||
|
"--convert",
|
||||||
|
action="store_true",
|
||||||
|
help="Do not run the simulation, only convert the configuration file(s) and output them.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument("--set",
|
parser.add_argument(
|
||||||
metavar="KEY=VALUE",
|
"--set",
|
||||||
action='append',
|
metavar="KEY=VALUE",
|
||||||
help="Set a number of parameters that will be passed to the simulation."
|
action="append",
|
||||||
"(do not put spaces before or after the = sign). "
|
help="Set a number of parameters that will be passed to the simulation."
|
||||||
"If a value contains spaces, you should define "
|
"(do not put spaces before or after the = sign). "
|
||||||
"it with double quotes: "
|
"If a value contains spaces, you should define "
|
||||||
'foo="this is a sentence". Note that '
|
"it with double quotes: "
|
||||||
"values are always treated as strings.")
|
'foo="this is a sentence". Note that '
|
||||||
|
"values are always treated as strings.",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.setLevel(getattr(logging, (args.level or 'INFO').upper()))
|
logger.setLevel(getattr(logging, (args.level or "INFO").upper()))
|
||||||
|
|
||||||
if args.version:
|
if args.version:
|
||||||
return
|
return
|
||||||
@ -81,14 +130,16 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu
|
|||||||
if parallel is None:
|
if parallel is None:
|
||||||
parallel = not args.synchronous
|
parallel = not args.synchronous
|
||||||
|
|
||||||
exporters = exporters or ['default', ]
|
exporters = exporters or [
|
||||||
|
"default",
|
||||||
|
]
|
||||||
for exp in args.exporter:
|
for exp in args.exporter:
|
||||||
if exp not in exporters:
|
if exp not in exporters:
|
||||||
exporters.append(exp)
|
exporters.append(exp)
|
||||||
if args.csv:
|
if args.csv:
|
||||||
exporters.append('csv')
|
exporters.append("csv")
|
||||||
if args.graph:
|
if args.graph:
|
||||||
exporters.append('gexf')
|
exporters.append("gexf")
|
||||||
|
|
||||||
if os.getcwd() not in sys.path:
|
if os.getcwd() not in sys.path:
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
@ -97,38 +148,38 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu
|
|||||||
if output is None:
|
if output is None:
|
||||||
output = args.output
|
output = args.output
|
||||||
|
|
||||||
|
logger.info("Loading config file: {}".format(args.file))
|
||||||
logger.info('Loading config file: {}'.format(args.file))
|
|
||||||
|
|
||||||
debug = debug or args.debug
|
debug = debug or args.debug
|
||||||
|
|
||||||
if args.pdb or debug:
|
if args.pdb or debug:
|
||||||
args.synchronous = True
|
args.synchronous = True
|
||||||
|
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
try:
|
try:
|
||||||
exp_params = {}
|
exp_params = {}
|
||||||
|
|
||||||
if not os.path.exists(args.file):
|
if not os.path.exists(args.file):
|
||||||
logger.error('Please, input a valid file')
|
logger.error("Please, input a valid file")
|
||||||
return
|
return
|
||||||
|
|
||||||
for sim in simulation.iter_from_config(args.file,
|
for sim in simulation.iter_from_config(
|
||||||
dry_run=args.dry_run,
|
args.file,
|
||||||
exporters=exporters,
|
dry_run=args.dry_run,
|
||||||
parallel=parallel,
|
exporters=exporters,
|
||||||
outdir=output,
|
parallel=parallel,
|
||||||
exporter_params=exp_params,
|
outdir=output,
|
||||||
**kwargs):
|
exporter_params=exp_params,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
if args.set:
|
if args.set:
|
||||||
for s in args.set:
|
for s in args.set:
|
||||||
k, v = s.split('=', 1)[:2]
|
k, v = s.split("=", 1)[:2]
|
||||||
v = eval(v)
|
v = eval(v)
|
||||||
tail, *head = k.rsplit('.', 1)[::-1]
|
tail, *head = k.rsplit(".", 1)[::-1]
|
||||||
target = sim
|
target = sim
|
||||||
if head:
|
if head:
|
||||||
for part in head[0].split('.'):
|
for part in head[0].split("."):
|
||||||
try:
|
try:
|
||||||
target = getattr(target, part)
|
target = getattr(target, part)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -144,19 +195,21 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu
|
|||||||
if do_run:
|
if do_run:
|
||||||
res.append(sim.run())
|
res.append(sim.run())
|
||||||
else:
|
else:
|
||||||
print('not running')
|
print("not running")
|
||||||
res.append(sim)
|
res.append(sim)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
if args.pdb:
|
if args.pdb:
|
||||||
from .debugging import post_mortem
|
from .debugging import post_mortem
|
||||||
|
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
post_mortem()
|
post_mortem()
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
if debug:
|
if debug:
|
||||||
from .debugging import set_trace
|
from .debugging import set_trace
|
||||||
os.environ['SOIL_DEBUG'] = 'true'
|
|
||||||
|
os.environ["SOIL_DEBUG"] = "true"
|
||||||
set_trace()
|
set_trace()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -165,5 +218,5 @@ def easy(cfg, debug=False, **kwargs):
|
|||||||
return main(cfg, **kwargs)[0]
|
return main(cfg, **kwargs)[0]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
main(do_run=True)
|
main(do_run=True)
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
from . import main as init_main
|
from . import main as init_main
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
init_main(do_run=True)
|
init_main(do_run=True)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
init_main(do_run=True)
|
init_main(do_run=True)
|
||||||
|
@ -7,6 +7,7 @@ class BassModel(FSM):
|
|||||||
innovation_prob
|
innovation_prob
|
||||||
imitation_prob
|
imitation_prob
|
||||||
"""
|
"""
|
||||||
|
|
||||||
sentimentCorrelation = 0
|
sentimentCorrelation = 0
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
@ -21,7 +22,7 @@ class BassModel(FSM):
|
|||||||
else:
|
else:
|
||||||
aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id)
|
aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id)
|
||||||
num_neighbors_aware = len(aware_neighbors)
|
num_neighbors_aware = len(aware_neighbors)
|
||||||
if self.prob((self['imitation_prob']*num_neighbors_aware)):
|
if self.prob((self["imitation_prob"] * num_neighbors_aware)):
|
||||||
self.sentimentCorrelation = 1
|
self.sentimentCorrelation = 1
|
||||||
return self.aware
|
return self.aware
|
||||||
|
|
||||||
|
@ -20,28 +20,40 @@ class BigMarketModel(FSM):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.enterprises = self.env.environment_params['enterprises']
|
self.enterprises = self.env.environment_params["enterprises"]
|
||||||
self.type = ""
|
self.type = ""
|
||||||
|
|
||||||
if self.id < len(self.enterprises): # Enterprises
|
if self.id < len(self.enterprises): # Enterprises
|
||||||
self.set_state(self.enterprise.id)
|
self.set_state(self.enterprise.id)
|
||||||
self.type = "Enterprise"
|
self.type = "Enterprise"
|
||||||
self.tweet_probability = environment.environment_params['tweet_probability_enterprises'][self.id]
|
self.tweet_probability = environment.environment_params[
|
||||||
|
"tweet_probability_enterprises"
|
||||||
|
][self.id]
|
||||||
else: # normal users
|
else: # normal users
|
||||||
self.type = "User"
|
self.type = "User"
|
||||||
self.set_state(self.user.id)
|
self.set_state(self.user.id)
|
||||||
self.tweet_probability = environment.environment_params['tweet_probability_users']
|
self.tweet_probability = environment.environment_params[
|
||||||
self.tweet_relevant_probability = environment.environment_params['tweet_relevant_probability']
|
"tweet_probability_users"
|
||||||
self.tweet_probability_about = environment.environment_params['tweet_probability_about'] # List
|
]
|
||||||
self.sentiment_about = environment.environment_params['sentiment_about'] # List
|
self.tweet_relevant_probability = environment.environment_params[
|
||||||
|
"tweet_relevant_probability"
|
||||||
|
]
|
||||||
|
self.tweet_probability_about = environment.environment_params[
|
||||||
|
"tweet_probability_about"
|
||||||
|
] # List
|
||||||
|
self.sentiment_about = environment.environment_params[
|
||||||
|
"sentiment_about"
|
||||||
|
] # List
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def enterprise(self):
|
def enterprise(self):
|
||||||
|
|
||||||
if self.random.random() < self.tweet_probability: # Tweets
|
if self.random.random() < self.tweet_probability: # Tweets
|
||||||
aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises) # Nodes neighbour users
|
aware_neighbors = self.get_neighboring_agents(
|
||||||
|
state_id=self.number_of_enterprises
|
||||||
|
) # Nodes neighbour users
|
||||||
for x in aware_neighbors:
|
for x in aware_neighbors:
|
||||||
if self.random.uniform(0,10) < 5:
|
if self.random.uniform(0, 10) < 5:
|
||||||
x.sentiment_about[self.id] += 0.1 # Increments for enterprise
|
x.sentiment_about[self.id] += 0.1 # Increments for enterprise
|
||||||
else:
|
else:
|
||||||
x.sentiment_about[self.id] -= 0.1 # Decrements for enterprise
|
x.sentiment_about[self.id] -= 0.1 # Decrements for enterprise
|
||||||
@ -49,15 +61,19 @@ class BigMarketModel(FSM):
|
|||||||
# Establecemos limites
|
# Establecemos limites
|
||||||
if x.sentiment_about[self.id] > 1:
|
if x.sentiment_about[self.id] > 1:
|
||||||
x.sentiment_about[self.id] = 1
|
x.sentiment_about[self.id] = 1
|
||||||
if x.sentiment_about[self.id]< -1:
|
if x.sentiment_about[self.id] < -1:
|
||||||
x.sentiment_about[self.id] = -1
|
x.sentiment_about[self.id] = -1
|
||||||
|
|
||||||
x.attrs['sentiment_enterprise_%s'% self.enterprises[self.id]] = x.sentiment_about[self.id]
|
x.attrs[
|
||||||
|
"sentiment_enterprise_%s" % self.enterprises[self.id]
|
||||||
|
] = x.sentiment_about[self.id]
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def user(self):
|
def user(self):
|
||||||
if self.random.random() < self.tweet_probability: # Tweets
|
if self.random.random() < self.tweet_probability: # Tweets
|
||||||
if self.random.random() < self.tweet_relevant_probability: # Tweets something relevant
|
if (
|
||||||
|
self.random.random() < self.tweet_relevant_probability
|
||||||
|
): # Tweets something relevant
|
||||||
# Tweet probability per enterprise
|
# Tweet probability per enterprise
|
||||||
for i in range(len(self.enterprises)):
|
for i in range(len(self.enterprises)):
|
||||||
random_num = self.random.random()
|
random_num = self.random.random()
|
||||||
@ -65,23 +81,29 @@ class BigMarketModel(FSM):
|
|||||||
# The condition is fulfilled, sentiments are evaluated towards that enterprise
|
# The condition is fulfilled, sentiments are evaluated towards that enterprise
|
||||||
if self.sentiment_about[i] < 0:
|
if self.sentiment_about[i] < 0:
|
||||||
# NEGATIVO
|
# NEGATIVO
|
||||||
self.userTweets("negative",i)
|
self.userTweets("negative", i)
|
||||||
elif self.sentiment_about[i] == 0:
|
elif self.sentiment_about[i] == 0:
|
||||||
# NEUTRO
|
# NEUTRO
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# POSITIVO
|
# POSITIVO
|
||||||
self.userTweets("positive",i)
|
self.userTweets("positive", i)
|
||||||
for i in range(len(self.enterprises)): # So that it never is set to 0 if there are not changes (logs)
|
for i in range(
|
||||||
self.attrs['sentiment_enterprise_%s'% self.enterprises[i]] = self.sentiment_about[i]
|
len(self.enterprises)
|
||||||
|
): # So that it never is set to 0 if there are not changes (logs)
|
||||||
|
self.attrs[
|
||||||
|
"sentiment_enterprise_%s" % self.enterprises[i]
|
||||||
|
] = self.sentiment_about[i]
|
||||||
|
|
||||||
def userTweets(self, sentiment,enterprise):
|
def userTweets(self, sentiment, enterprise):
|
||||||
aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises) # Nodes neighbours users
|
aware_neighbors = self.get_neighboring_agents(
|
||||||
|
state_id=self.number_of_enterprises
|
||||||
|
) # Nodes neighbours users
|
||||||
for x in aware_neighbors:
|
for x in aware_neighbors:
|
||||||
if sentiment == "positive":
|
if sentiment == "positive":
|
||||||
x.sentiment_about[enterprise] +=0.003
|
x.sentiment_about[enterprise] += 0.003
|
||||||
elif sentiment == "negative":
|
elif sentiment == "negative":
|
||||||
x.sentiment_about[enterprise] -=0.003
|
x.sentiment_about[enterprise] -= 0.003
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -91,4 +113,6 @@ class BigMarketModel(FSM):
|
|||||||
if x.sentiment_about[enterprise] < -1:
|
if x.sentiment_about[enterprise] < -1:
|
||||||
x.sentiment_about[enterprise] = -1
|
x.sentiment_about[enterprise] = -1
|
||||||
|
|
||||||
x.attrs['sentiment_enterprise_%s'% self.enterprises[enterprise]] = x.sentiment_about[enterprise]
|
x.attrs[
|
||||||
|
"sentiment_enterprise_%s" % self.enterprises[enterprise]
|
||||||
|
] = x.sentiment_about[enterprise]
|
||||||
|
@ -15,9 +15,9 @@ class CounterModel(NetworkAgent):
|
|||||||
# Outside effects
|
# Outside effects
|
||||||
total = len(list(self.model.schedule._agents))
|
total = len(list(self.model.schedule._agents))
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self['times'] = self.get('times', 0) + 1
|
self["times"] = self.get("times", 0) + 1
|
||||||
self['neighbors'] = neighbors
|
self["neighbors"] = neighbors
|
||||||
self['total'] = total
|
self["total"] = total
|
||||||
|
|
||||||
|
|
||||||
class AggregatedCounter(NetworkAgent):
|
class AggregatedCounter(NetworkAgent):
|
||||||
@ -32,9 +32,9 @@ class AggregatedCounter(NetworkAgent):
|
|||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
self['times'] += 1
|
self["times"] += 1
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighboring_agents()))
|
||||||
self['neighbors'] += neighbors
|
self["neighbors"] += neighbors
|
||||||
total = len(list(self.model.schedule.agents))
|
total = len(list(self.model.schedule.agents))
|
||||||
self['total'] += total
|
self["total"] += total
|
||||||
self.debug('Running for step: {}. Total: {}'.format(self.now, total))
|
self.debug("Running for step: {}. Total: {}".format(self.now, total))
|
||||||
|
@ -2,20 +2,20 @@ from scipy.spatial import cKDTree as KDTree
|
|||||||
import networkx as nx
|
import networkx as nx
|
||||||
from . import NetworkAgent, as_node
|
from . import NetworkAgent, as_node
|
||||||
|
|
||||||
|
|
||||||
class Geo(NetworkAgent):
|
class Geo(NetworkAgent):
|
||||||
'''In this type of network, nodes have a "pos" attribute.'''
|
"""In this type of network, nodes have a "pos" attribute."""
|
||||||
|
|
||||||
def geo_search(self, radius, node=None, center=False, **kwargs):
|
def geo_search(self, radius, node=None, center=False, **kwargs):
|
||||||
'''Get a list of nodes whose coordinates are closer than *radius* to *node*.'''
|
"""Get a list of nodes whose coordinates are closer than *radius* to *node*."""
|
||||||
node = as_node(node if node is not None else self)
|
node = as_node(node if node is not None else self)
|
||||||
|
|
||||||
G = self.subgraph(**kwargs)
|
G = self.subgraph(**kwargs)
|
||||||
|
|
||||||
pos = nx.get_node_attributes(G, 'pos')
|
pos = nx.get_node_attributes(G, "pos")
|
||||||
if not pos:
|
if not pos:
|
||||||
return []
|
return []
|
||||||
nodes, coords = list(zip(*pos.items()))
|
nodes, coords = list(zip(*pos.items()))
|
||||||
kdtree = KDTree(coords) # Cannot provide generator.
|
kdtree = KDTree(coords) # Cannot provide generator.
|
||||||
indices = kdtree.query_ball_point(pos[node], radius)
|
indices = kdtree.query_ball_point(pos[node], radius)
|
||||||
return [nodes[i] for i in indices if center or (nodes[i] != node)]
|
return [nodes[i] for i in indices if center or (nodes[i] != node)]
|
||||||
|
|
||||||
|
@ -11,10 +11,10 @@ class IndependentCascadeModel(BaseAgent):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.innovation_prob = self.env.environment_params['innovation_prob']
|
self.innovation_prob = self.env.environment_params["innovation_prob"]
|
||||||
self.imitation_prob = self.env.environment_params['imitation_prob']
|
self.imitation_prob = self.env.environment_params["imitation_prob"]
|
||||||
self.state['time_awareness'] = 0
|
self.state["time_awareness"] = 0
|
||||||
self.state['sentimentCorrelation'] = 0
|
self.state["sentimentCorrelation"] = 0
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.behaviour()
|
self.behaviour()
|
||||||
@ -23,25 +23,27 @@ class IndependentCascadeModel(BaseAgent):
|
|||||||
aware_neighbors_1_time_step = []
|
aware_neighbors_1_time_step = []
|
||||||
# Outside effects
|
# Outside effects
|
||||||
if self.prob(self.innovation_prob):
|
if self.prob(self.innovation_prob):
|
||||||
if self.state['id'] == 0:
|
if self.state["id"] == 0:
|
||||||
self.state['id'] = 1
|
self.state["id"] = 1
|
||||||
self.state['sentimentCorrelation'] = 1
|
self.state["sentimentCorrelation"] = 1
|
||||||
self.state['time_awareness'] = self.env.now # To know when they have been infected
|
self.state[
|
||||||
|
"time_awareness"
|
||||||
|
] = self.env.now # To know when they have been infected
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Imitation effects
|
# Imitation effects
|
||||||
if self.state['id'] == 0:
|
if self.state["id"] == 0:
|
||||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
for x in aware_neighbors:
|
for x in aware_neighbors:
|
||||||
if x.state['time_awareness'] == (self.env.now-1):
|
if x.state["time_awareness"] == (self.env.now - 1):
|
||||||
aware_neighbors_1_time_step.append(x)
|
aware_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_aware = len(aware_neighbors_1_time_step)
|
num_neighbors_aware = len(aware_neighbors_1_time_step)
|
||||||
if self.prob(self.imitation_prob*num_neighbors_aware):
|
if self.prob(self.imitation_prob * num_neighbors_aware):
|
||||||
self.state['id'] = 1
|
self.state["id"] = 1
|
||||||
self.state['sentimentCorrelation'] = 1
|
self.state["sentimentCorrelation"] = 1
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -23,36 +23,49 @@ class SpreadModelM2(BaseAgent):
|
|||||||
def __init__(self, model=None, unique_id=0, state=()):
|
def __init__(self, model=None, unique_id=0, state=()):
|
||||||
super().__init__(model=environment, unique_id=unique_id, state=state)
|
super().__init__(model=environment, unique_id=unique_id, state=state)
|
||||||
|
|
||||||
|
|
||||||
# Use a single generator with the same seed as `self.random`
|
# Use a single generator with the same seed as `self.random`
|
||||||
random = np.random.default_rng(seed=self._seed)
|
random = np.random.default_rng(seed=self._seed)
|
||||||
self.prob_neutral_making_denier = random.normal(environment.environment_params['prob_neutral_making_denier'],
|
self.prob_neutral_making_denier = random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_neutral_making_denier"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
self.prob_infect = random.normal(environment.environment_params['prob_infect'],
|
self.prob_infect = random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_infect"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
self.prob_cured_healing_infected = random.normal(environment.environment_params['prob_cured_healing_infected'],
|
self.prob_cured_healing_infected = random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_cured_healing_infected"],
|
||||||
self.prob_cured_vaccinate_neutral = random.normal(environment.environment_params['prob_cured_vaccinate_neutral'],
|
environment.environment_params["standard_variance"],
|
||||||
environment.environment_params['standard_variance'])
|
)
|
||||||
|
self.prob_cured_vaccinate_neutral = random.normal(
|
||||||
|
environment.environment_params["prob_cured_vaccinate_neutral"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
self.prob_vaccinated_healing_infected = random.normal(environment.environment_params['prob_vaccinated_healing_infected'],
|
self.prob_vaccinated_healing_infected = random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_vaccinated_healing_infected"],
|
||||||
self.prob_vaccinated_vaccinate_neutral = random.normal(environment.environment_params['prob_vaccinated_vaccinate_neutral'],
|
environment.environment_params["standard_variance"],
|
||||||
environment.environment_params['standard_variance'])
|
)
|
||||||
self.prob_generate_anti_rumor = random.normal(environment.environment_params['prob_generate_anti_rumor'],
|
self.prob_vaccinated_vaccinate_neutral = random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_vaccinated_vaccinate_neutral"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
self.prob_generate_anti_rumor = random.normal(
|
||||||
|
environment.environment_params["prob_generate_anti_rumor"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
|
|
||||||
if self.state['id'] == 0: # Neutral
|
if self.state["id"] == 0: # Neutral
|
||||||
self.neutral_behaviour()
|
self.neutral_behaviour()
|
||||||
elif self.state['id'] == 1: # Infected
|
elif self.state["id"] == 1: # Infected
|
||||||
self.infected_behaviour()
|
self.infected_behaviour()
|
||||||
elif self.state['id'] == 2: # Cured
|
elif self.state["id"] == 2: # Cured
|
||||||
self.cured_behaviour()
|
self.cured_behaviour()
|
||||||
elif self.state['id'] == 3: # Vaccinated
|
elif self.state["id"] == 3: # Vaccinated
|
||||||
self.vaccinated_behaviour()
|
self.vaccinated_behaviour()
|
||||||
|
|
||||||
def neutral_behaviour(self):
|
def neutral_behaviour(self):
|
||||||
@ -61,7 +74,7 @@ class SpreadModelM2(BaseAgent):
|
|||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
if len(infected_neighbors) > 0:
|
if len(infected_neighbors) > 0:
|
||||||
if self.prob(self.prob_neutral_making_denier):
|
if self.prob(self.prob_neutral_making_denier):
|
||||||
self.state['id'] = 3 # Vaccinated making denier
|
self.state["id"] = 3 # Vaccinated making denier
|
||||||
|
|
||||||
def infected_behaviour(self):
|
def infected_behaviour(self):
|
||||||
|
|
||||||
@ -69,7 +82,7 @@ class SpreadModelM2(BaseAgent):
|
|||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_infect):
|
if self.prob(self.prob_infect):
|
||||||
neighbor.state['id'] = 1 # Infected
|
neighbor.state["id"] = 1 # Infected
|
||||||
|
|
||||||
def cured_behaviour(self):
|
def cured_behaviour(self):
|
||||||
|
|
||||||
@ -77,13 +90,13 @@ class SpreadModelM2(BaseAgent):
|
|||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state['id'] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Cure
|
# Cure
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
def vaccinated_behaviour(self):
|
def vaccinated_behaviour(self):
|
||||||
|
|
||||||
@ -91,19 +104,19 @@ class SpreadModelM2(BaseAgent):
|
|||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state['id'] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Generate anti-rumor
|
# Generate anti-rumor
|
||||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors_2:
|
for neighbor in infected_neighbors_2:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
|
|
||||||
class ControlModelM2(BaseAgent):
|
class ControlModelM2(BaseAgent):
|
||||||
@ -124,51 +137,64 @@ class ControlModelM2(BaseAgent):
|
|||||||
prob_generate_anti_rumor
|
prob_generate_anti_rumor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, model=None, unique_id=0, state=()):
|
def __init__(self, model=None, unique_id=0, state=()):
|
||||||
super().__init__(model=environment, unique_id=unique_id, state=state)
|
super().__init__(model=environment, unique_id=unique_id, state=state)
|
||||||
|
|
||||||
self.prob_neutral_making_denier = np.random.normal(environment.environment_params['prob_neutral_making_denier'],
|
self.prob_neutral_making_denier = np.random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_neutral_making_denier"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
self.prob_infect = np.random.normal(environment.environment_params['prob_infect'],
|
self.prob_infect = np.random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_infect"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
self.prob_cured_healing_infected = np.random.normal(environment.environment_params['prob_cured_healing_infected'],
|
self.prob_cured_healing_infected = np.random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_cured_healing_infected"],
|
||||||
self.prob_cured_vaccinate_neutral = np.random.normal(environment.environment_params['prob_cured_vaccinate_neutral'],
|
environment.environment_params["standard_variance"],
|
||||||
environment.environment_params['standard_variance'])
|
)
|
||||||
|
self.prob_cured_vaccinate_neutral = np.random.normal(
|
||||||
|
environment.environment_params["prob_cured_vaccinate_neutral"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
self.prob_vaccinated_healing_infected = np.random.normal(environment.environment_params['prob_vaccinated_healing_infected'],
|
self.prob_vaccinated_healing_infected = np.random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_vaccinated_healing_infected"],
|
||||||
self.prob_vaccinated_vaccinate_neutral = np.random.normal(environment.environment_params['prob_vaccinated_vaccinate_neutral'],
|
environment.environment_params["standard_variance"],
|
||||||
environment.environment_params['standard_variance'])
|
)
|
||||||
self.prob_generate_anti_rumor = np.random.normal(environment.environment_params['prob_generate_anti_rumor'],
|
self.prob_vaccinated_vaccinate_neutral = np.random.normal(
|
||||||
environment.environment_params['standard_variance'])
|
environment.environment_params["prob_vaccinated_vaccinate_neutral"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
self.prob_generate_anti_rumor = np.random.normal(
|
||||||
|
environment.environment_params["prob_generate_anti_rumor"],
|
||||||
|
environment.environment_params["standard_variance"],
|
||||||
|
)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
|
|
||||||
if self.state['id'] == 0: # Neutral
|
if self.state["id"] == 0: # Neutral
|
||||||
self.neutral_behaviour()
|
self.neutral_behaviour()
|
||||||
elif self.state['id'] == 1: # Infected
|
elif self.state["id"] == 1: # Infected
|
||||||
self.infected_behaviour()
|
self.infected_behaviour()
|
||||||
elif self.state['id'] == 2: # Cured
|
elif self.state["id"] == 2: # Cured
|
||||||
self.cured_behaviour()
|
self.cured_behaviour()
|
||||||
elif self.state['id'] == 3: # Vaccinated
|
elif self.state["id"] == 3: # Vaccinated
|
||||||
self.vaccinated_behaviour()
|
self.vaccinated_behaviour()
|
||||||
elif self.state['id'] == 4: # Beacon-off
|
elif self.state["id"] == 4: # Beacon-off
|
||||||
self.beacon_off_behaviour()
|
self.beacon_off_behaviour()
|
||||||
elif self.state['id'] == 5: # Beacon-on
|
elif self.state["id"] == 5: # Beacon-on
|
||||||
self.beacon_on_behaviour()
|
self.beacon_on_behaviour()
|
||||||
|
|
||||||
def neutral_behaviour(self):
|
def neutral_behaviour(self):
|
||||||
self.state['visible'] = False
|
self.state["visible"] = False
|
||||||
|
|
||||||
# Infected
|
# Infected
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
if len(infected_neighbors) > 0:
|
if len(infected_neighbors) > 0:
|
||||||
if self.random(self.prob_neutral_making_denier):
|
if self.random(self.prob_neutral_making_denier):
|
||||||
self.state['id'] = 3 # Vaccinated making denier
|
self.state["id"] = 3 # Vaccinated making denier
|
||||||
|
|
||||||
def infected_behaviour(self):
|
def infected_behaviour(self):
|
||||||
|
|
||||||
@ -176,69 +202,69 @@ class ControlModelM2(BaseAgent):
|
|||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_infect):
|
if self.prob(self.prob_infect):
|
||||||
neighbor.state['id'] = 1 # Infected
|
neighbor.state["id"] = 1 # Infected
|
||||||
self.state['visible'] = False
|
self.state["visible"] = False
|
||||||
|
|
||||||
def cured_behaviour(self):
|
def cured_behaviour(self):
|
||||||
|
|
||||||
self.state['visible'] = True
|
self.state["visible"] = True
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state['id'] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Cure
|
# Cure
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
def vaccinated_behaviour(self):
|
def vaccinated_behaviour(self):
|
||||||
self.state['visible'] = True
|
self.state["visible"] = True
|
||||||
|
|
||||||
# Cure
|
# Cure
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state['id'] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Generate anti-rumor
|
# Generate anti-rumor
|
||||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors_2:
|
for neighbor in infected_neighbors_2:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
def beacon_off_behaviour(self):
|
def beacon_off_behaviour(self):
|
||||||
self.state['visible'] = False
|
self.state["visible"] = False
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
if len(infected_neighbors) > 0:
|
if len(infected_neighbors) > 0:
|
||||||
self.state['id'] == 5 # Beacon on
|
self.state["id"] == 5 # Beacon on
|
||||||
|
|
||||||
def beacon_on_behaviour(self):
|
def beacon_on_behaviour(self):
|
||||||
self.state['visible'] = False
|
self.state["visible"] = False
|
||||||
# Cure (M2 feature added)
|
# Cure (M2 feature added)
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
neutral_neighbors_infected = neighbor.get_neighboring_agents(state_id=0)
|
neutral_neighbors_infected = neighbor.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors_infected:
|
for neighbor in neutral_neighbors_infected:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state['id'] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
infected_neighbors_infected = neighbor.get_neighboring_agents(state_id=1)
|
infected_neighbors_infected = neighbor.get_neighboring_agents(state_id=1)
|
||||||
for neighbor in infected_neighbors_infected:
|
for neighbor in infected_neighbors_infected:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state['id'] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state['id'] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
@ -33,24 +33,32 @@ class SISaModel(FSM):
|
|||||||
|
|
||||||
random = np.random.default_rng(seed=self._seed)
|
random = np.random.default_rng(seed=self._seed)
|
||||||
|
|
||||||
self.neutral_discontent_spon_prob = random.normal(self.env['neutral_discontent_spon_prob'],
|
self.neutral_discontent_spon_prob = random.normal(
|
||||||
self.env['standard_variance'])
|
self.env["neutral_discontent_spon_prob"], self.env["standard_variance"]
|
||||||
self.neutral_discontent_infected_prob = random.normal(self.env['neutral_discontent_infected_prob'],
|
)
|
||||||
self.env['standard_variance'])
|
self.neutral_discontent_infected_prob = random.normal(
|
||||||
self.neutral_content_spon_prob = random.normal(self.env['neutral_content_spon_prob'],
|
self.env["neutral_discontent_infected_prob"], self.env["standard_variance"]
|
||||||
self.env['standard_variance'])
|
)
|
||||||
self.neutral_content_infected_prob = random.normal(self.env['neutral_content_infected_prob'],
|
self.neutral_content_spon_prob = random.normal(
|
||||||
self.env['standard_variance'])
|
self.env["neutral_content_spon_prob"], self.env["standard_variance"]
|
||||||
|
)
|
||||||
|
self.neutral_content_infected_prob = random.normal(
|
||||||
|
self.env["neutral_content_infected_prob"], self.env["standard_variance"]
|
||||||
|
)
|
||||||
|
|
||||||
self.discontent_neutral = random.normal(self.env['discontent_neutral'],
|
self.discontent_neutral = random.normal(
|
||||||
self.env['standard_variance'])
|
self.env["discontent_neutral"], self.env["standard_variance"]
|
||||||
self.discontent_content = random.normal(self.env['discontent_content'],
|
)
|
||||||
self.env['variance_d_c'])
|
self.discontent_content = random.normal(
|
||||||
|
self.env["discontent_content"], self.env["variance_d_c"]
|
||||||
|
)
|
||||||
|
|
||||||
self.content_discontent = random.normal(self.env['content_discontent'],
|
self.content_discontent = random.normal(
|
||||||
self.env['variance_c_d'])
|
self.env["content_discontent"], self.env["variance_c_d"]
|
||||||
self.content_neutral = random.normal(self.env['content_neutral'],
|
)
|
||||||
self.env['standard_variance'])
|
self.content_neutral = random.normal(
|
||||||
|
self.env["content_neutral"], self.env["standard_variance"]
|
||||||
|
)
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def neutral(self):
|
def neutral(self):
|
||||||
@ -88,7 +96,9 @@ class SISaModel(FSM):
|
|||||||
return self.neutral
|
return self.neutral
|
||||||
|
|
||||||
# Superinfected
|
# Superinfected
|
||||||
discontent_neighbors = self.count_neighboring_agents(state_id=self.discontent.id)
|
discontent_neighbors = self.count_neighboring_agents(
|
||||||
|
state_id=self.discontent.id
|
||||||
|
)
|
||||||
if self.prob(scontent_neighbors * self.content_discontent):
|
if self.prob(scontent_neighbors * self.content_discontent):
|
||||||
self.discontent
|
self.discontent
|
||||||
return self.content
|
return self.content
|
||||||
|
@ -17,15 +17,19 @@ class SentimentCorrelationModel(BaseAgent):
|
|||||||
|
|
||||||
def __init__(self, environment, unique_id=0, state=()):
|
def __init__(self, environment, unique_id=0, state=()):
|
||||||
super().__init__(model=environment, unique_id=unique_id, state=state)
|
super().__init__(model=environment, unique_id=unique_id, state=state)
|
||||||
self.outside_effects_prob = environment.environment_params['outside_effects_prob']
|
self.outside_effects_prob = environment.environment_params[
|
||||||
self.anger_prob = environment.environment_params['anger_prob']
|
"outside_effects_prob"
|
||||||
self.joy_prob = environment.environment_params['joy_prob']
|
]
|
||||||
self.sadness_prob = environment.environment_params['sadness_prob']
|
self.anger_prob = environment.environment_params["anger_prob"]
|
||||||
self.disgust_prob = environment.environment_params['disgust_prob']
|
self.joy_prob = environment.environment_params["joy_prob"]
|
||||||
self.state['time_awareness'] = []
|
self.sadness_prob = environment.environment_params["sadness_prob"]
|
||||||
|
self.disgust_prob = environment.environment_params["disgust_prob"]
|
||||||
|
self.state["time_awareness"] = []
|
||||||
for i in range(4): # In this model we have 4 sentiments
|
for i in range(4): # In this model we have 4 sentiments
|
||||||
self.state['time_awareness'].append(0) # 0-> Anger, 1-> joy, 2->sadness, 3 -> disgust
|
self.state["time_awareness"].append(
|
||||||
self.state['sentimentCorrelation'] = 0
|
0
|
||||||
|
) # 0-> Anger, 1-> joy, 2->sadness, 3 -> disgust
|
||||||
|
self.state["sentimentCorrelation"] = 0
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.behaviour()
|
self.behaviour()
|
||||||
@ -39,63 +43,73 @@ class SentimentCorrelationModel(BaseAgent):
|
|||||||
|
|
||||||
angry_neighbors = self.get_neighboring_agents(state_id=1)
|
angry_neighbors = self.get_neighboring_agents(state_id=1)
|
||||||
for x in angry_neighbors:
|
for x in angry_neighbors:
|
||||||
if x.state['time_awareness'][0] > (self.env.now-500):
|
if x.state["time_awareness"][0] > (self.env.now - 500):
|
||||||
angry_neighbors_1_time_step.append(x)
|
angry_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_angry = len(angry_neighbors_1_time_step)
|
num_neighbors_angry = len(angry_neighbors_1_time_step)
|
||||||
|
|
||||||
joyful_neighbors = self.get_neighboring_agents(state_id=2)
|
joyful_neighbors = self.get_neighboring_agents(state_id=2)
|
||||||
for x in joyful_neighbors:
|
for x in joyful_neighbors:
|
||||||
if x.state['time_awareness'][1] > (self.env.now-500):
|
if x.state["time_awareness"][1] > (self.env.now - 500):
|
||||||
joyful_neighbors_1_time_step.append(x)
|
joyful_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_joyful = len(joyful_neighbors_1_time_step)
|
num_neighbors_joyful = len(joyful_neighbors_1_time_step)
|
||||||
|
|
||||||
sad_neighbors = self.get_neighboring_agents(state_id=3)
|
sad_neighbors = self.get_neighboring_agents(state_id=3)
|
||||||
for x in sad_neighbors:
|
for x in sad_neighbors:
|
||||||
if x.state['time_awareness'][2] > (self.env.now-500):
|
if x.state["time_awareness"][2] > (self.env.now - 500):
|
||||||
sad_neighbors_1_time_step.append(x)
|
sad_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_sad = len(sad_neighbors_1_time_step)
|
num_neighbors_sad = len(sad_neighbors_1_time_step)
|
||||||
|
|
||||||
disgusted_neighbors = self.get_neighboring_agents(state_id=4)
|
disgusted_neighbors = self.get_neighboring_agents(state_id=4)
|
||||||
for x in disgusted_neighbors:
|
for x in disgusted_neighbors:
|
||||||
if x.state['time_awareness'][3] > (self.env.now-500):
|
if x.state["time_awareness"][3] > (self.env.now - 500):
|
||||||
disgusted_neighbors_1_time_step.append(x)
|
disgusted_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_disgusted = len(disgusted_neighbors_1_time_step)
|
num_neighbors_disgusted = len(disgusted_neighbors_1_time_step)
|
||||||
|
|
||||||
anger_prob = self.anger_prob+(len(angry_neighbors_1_time_step)*self.anger_prob)
|
anger_prob = self.anger_prob + (
|
||||||
joy_prob = self.joy_prob+(len(joyful_neighbors_1_time_step)*self.joy_prob)
|
len(angry_neighbors_1_time_step) * self.anger_prob
|
||||||
sadness_prob = self.sadness_prob+(len(sad_neighbors_1_time_step)*self.sadness_prob)
|
)
|
||||||
disgust_prob = self.disgust_prob+(len(disgusted_neighbors_1_time_step)*self.disgust_prob)
|
joy_prob = self.joy_prob + (len(joyful_neighbors_1_time_step) * self.joy_prob)
|
||||||
|
sadness_prob = self.sadness_prob + (
|
||||||
|
len(sad_neighbors_1_time_step) * self.sadness_prob
|
||||||
|
)
|
||||||
|
disgust_prob = self.disgust_prob + (
|
||||||
|
len(disgusted_neighbors_1_time_step) * self.disgust_prob
|
||||||
|
)
|
||||||
outside_effects_prob = self.outside_effects_prob
|
outside_effects_prob = self.outside_effects_prob
|
||||||
|
|
||||||
num = self.random.random()
|
num = self.random.random()
|
||||||
|
|
||||||
if num<outside_effects_prob:
|
if num < outside_effects_prob:
|
||||||
self.state['id'] = self.random.randint(1, 4)
|
self.state["id"] = self.random.randint(1, 4)
|
||||||
|
|
||||||
self.state['sentimentCorrelation'] = self.state['id'] # It is stored when it has been infected for the dynamic network
|
self.state["sentimentCorrelation"] = self.state[
|
||||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
"id"
|
||||||
self.state['sentiment'] = self.state['id']
|
] # It is stored when it has been infected for the dynamic network
|
||||||
|
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||||
|
self.state["sentiment"] = self.state["id"]
|
||||||
|
|
||||||
|
if num < anger_prob:
|
||||||
|
|
||||||
if(num<anger_prob):
|
self.state["id"] = 1
|
||||||
|
self.state["sentimentCorrelation"] = 1
|
||||||
|
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||||
|
elif num < joy_prob + anger_prob and num > anger_prob:
|
||||||
|
|
||||||
self.state['id'] = 1
|
self.state["id"] = 2
|
||||||
self.state['sentimentCorrelation'] = 1
|
self.state["sentimentCorrelation"] = 2
|
||||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||||
elif (num<joy_prob+anger_prob and num>anger_prob):
|
elif num < sadness_prob + anger_prob + joy_prob and num > joy_prob + anger_prob:
|
||||||
|
|
||||||
self.state['id'] = 2
|
self.state["id"] = 3
|
||||||
self.state['sentimentCorrelation'] = 2
|
self.state["sentimentCorrelation"] = 3
|
||||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||||
elif (num<sadness_prob+anger_prob+joy_prob and num>joy_prob+anger_prob):
|
elif (
|
||||||
|
num < disgust_prob + sadness_prob + anger_prob + joy_prob
|
||||||
|
and num > sadness_prob + anger_prob + joy_prob
|
||||||
|
):
|
||||||
|
|
||||||
self.state['id'] = 3
|
self.state["id"] = 4
|
||||||
self.state['sentimentCorrelation'] = 3
|
self.state["sentimentCorrelation"] = 4
|
||||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
self.state["time_awareness"][self.state["id"] - 1] = self.env.now
|
||||||
elif (num<disgust_prob+sadness_prob+anger_prob+joy_prob and num>sadness_prob+anger_prob+joy_prob):
|
|
||||||
|
|
||||||
self.state['id'] = 4
|
self.state["sentiment"] = self.state["id"]
|
||||||
self.state['sentimentCorrelation'] = 4
|
|
||||||
self.state['time_awareness'][self.state['id']-1] = self.env.now
|
|
||||||
|
|
||||||
self.state['sentiment'] = self.state['id']
|
|
||||||
|
@ -20,13 +20,13 @@ from typing import Dict, List
|
|||||||
from .. import serialization, utils, time, config
|
from .. import serialization, utils, time, config
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def as_node(agent):
|
def as_node(agent):
|
||||||
if isinstance(agent, BaseAgent):
|
if isinstance(agent, BaseAgent):
|
||||||
return agent.id
|
return agent.id
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
IGNORED_FIELDS = ('model', 'logger')
|
|
||||||
|
IGNORED_FIELDS = ("model", "logger")
|
||||||
|
|
||||||
|
|
||||||
class DeadAgent(Exception):
|
class DeadAgent(Exception):
|
||||||
@ -43,13 +43,18 @@ class MetaAgent(ABCMeta):
|
|||||||
defaults.update(i._defaults)
|
defaults.update(i._defaults)
|
||||||
|
|
||||||
new_nmspc = {
|
new_nmspc = {
|
||||||
'_defaults': defaults,
|
"_defaults": defaults,
|
||||||
}
|
}
|
||||||
|
|
||||||
for attr, func in namespace.items():
|
for attr, func in namespace.items():
|
||||||
if isinstance(func, types.FunctionType) or isinstance(func, property) or isinstance(func, classmethod) or attr[0] == '_':
|
if (
|
||||||
|
isinstance(func, types.FunctionType)
|
||||||
|
or isinstance(func, property)
|
||||||
|
or isinstance(func, classmethod)
|
||||||
|
or attr[0] == "_"
|
||||||
|
):
|
||||||
new_nmspc[attr] = func
|
new_nmspc[attr] = func
|
||||||
elif attr == 'defaults':
|
elif attr == "defaults":
|
||||||
defaults.update(func)
|
defaults.update(func)
|
||||||
else:
|
else:
|
||||||
defaults[attr] = copy(func)
|
defaults[attr] = copy(func)
|
||||||
@ -69,12 +74,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
Any attribute that is not preceded by an underscore (`_`) will also be added to its state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, unique_id, model, name=None, interval=None, **kwargs):
|
||||||
unique_id,
|
|
||||||
model,
|
|
||||||
name=None,
|
|
||||||
interval=None,
|
|
||||||
**kwargs):
|
|
||||||
# Check for REQUIRED arguments
|
# Check for REQUIRED arguments
|
||||||
# Initialize agent parameters
|
# Initialize agent parameters
|
||||||
if isinstance(unique_id, MesaAgent):
|
if isinstance(unique_id, MesaAgent):
|
||||||
@ -82,16 +82,19 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
assert isinstance(unique_id, int)
|
assert isinstance(unique_id, int)
|
||||||
super().__init__(unique_id=unique_id, model=model)
|
super().__init__(unique_id=unique_id, model=model)
|
||||||
|
|
||||||
self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id)
|
self.name = (
|
||||||
|
str(name) if name else "{}[{}]".format(type(self).__name__, self.unique_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.alive = True
|
self.alive = True
|
||||||
|
|
||||||
self.interval = interval or self.get('interval', 1)
|
self.interval = interval or self.get("interval", 1)
|
||||||
logger = utils.logger.getChild(getattr(self.model, 'id', self.model)).getChild(self.name)
|
logger = utils.logger.getChild(getattr(self.model, "id", self.model)).getChild(
|
||||||
self.logger = logging.LoggerAdapter(logger, {'agent_name': self.name})
|
self.name
|
||||||
|
)
|
||||||
|
self.logger = logging.LoggerAdapter(logger, {"agent_name": self.name})
|
||||||
|
|
||||||
if hasattr(self, 'level'):
|
if hasattr(self, "level"):
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
|
|
||||||
for (k, v) in self._defaults.items():
|
for (k, v) in self._defaults.items():
|
||||||
@ -117,20 +120,22 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
def from_dict(cls, model, attrs, warn_extra=True):
|
def from_dict(cls, model, attrs, warn_extra=True):
|
||||||
ignored = {}
|
ignored = {}
|
||||||
args = {}
|
args = {}
|
||||||
for k, v in attrs.items():
|
for k, v in attrs.items():
|
||||||
if k in inspect.signature(cls).parameters:
|
if k in inspect.signature(cls).parameters:
|
||||||
args[k] = v
|
args[k] = v
|
||||||
else:
|
else:
|
||||||
ignored[k] = v
|
ignored[k] = v
|
||||||
if ignored and warn_extra:
|
if ignored and warn_extra:
|
||||||
utils.logger.info(f'Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }')
|
utils.logger.info(
|
||||||
|
f"Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }"
|
||||||
|
)
|
||||||
return cls(model=model, **args)
|
return cls(model=model, **args)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
try:
|
try:
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise KeyError(f'key {key} not found in agent')
|
raise KeyError(f"key {key} not found in agent")
|
||||||
|
|
||||||
def __delitem__(self, key):
|
def __delitem__(self, key):
|
||||||
return delattr(self, key)
|
return delattr(self, key)
|
||||||
@ -148,7 +153,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
return self.items()
|
return self.items()
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return (k for k in self.__dict__ if k[0] != '_' and k not in IGNORED_FIELDS)
|
return (k for k in self.__dict__ if k[0] != "_" and k not in IGNORED_FIELDS)
|
||||||
|
|
||||||
def items(self, keys=None, skip=None):
|
def items(self, keys=None, skip=None):
|
||||||
keys = keys if keys is not None else self.keys()
|
keys = keys if keys is not None else self.keys()
|
||||||
@ -169,7 +174,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def die(self):
|
def die(self):
|
||||||
self.info(f'agent dying')
|
self.info(f"agent dying")
|
||||||
self.alive = False
|
self.alive = False
|
||||||
return time.NEVER
|
return time.NEVER
|
||||||
|
|
||||||
@ -186,9 +191,9 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
for k, v in kwargs:
|
for k, v in kwargs:
|
||||||
message += " {k}={v} ".format(k, v)
|
message += " {k}={v} ".format(k, v)
|
||||||
extra = {}
|
extra = {}
|
||||||
extra['now'] = self.now
|
extra["now"] = self.now
|
||||||
extra['unique_id'] = self.unique_id
|
extra["unique_id"] = self.unique_id
|
||||||
extra['agent_name'] = self.name
|
extra["agent_name"] = self.name
|
||||||
return self.logger.log(level, message, extra=extra)
|
return self.logger.log(level, message, extra=extra)
|
||||||
|
|
||||||
def debug(self, *args, **kwargs):
|
def debug(self, *args, **kwargs):
|
||||||
@ -214,10 +219,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
content = dict(self.items(keys=keys))
|
content = dict(self.items(keys=keys))
|
||||||
if pretty and content:
|
if pretty and content:
|
||||||
d = content
|
d = content
|
||||||
content = '\n'
|
content = "\n"
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
content += f'- {k}: {v}\n'
|
content += f"- {k}: {v}\n"
|
||||||
content = textwrap.indent(content, ' ')
|
content = textwrap.indent(content, " ")
|
||||||
return f"{repr(self)}{content}"
|
return f"{repr(self)}{content}"
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -225,7 +230,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
|
|
||||||
|
|
||||||
class NetworkAgent(BaseAgent):
|
class NetworkAgent(BaseAgent):
|
||||||
|
|
||||||
def __init__(self, *args, topology, node_id, **kwargs):
|
def __init__(self, *args, topology, node_id, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@ -248,18 +252,21 @@ class NetworkAgent(BaseAgent):
|
|||||||
def node(self):
|
def node(self):
|
||||||
return self.topology.nodes[self.node_id]
|
return self.topology.nodes[self.node_id]
|
||||||
|
|
||||||
|
|
||||||
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
||||||
unique_ids = None
|
unique_ids = None
|
||||||
if isinstance(unique_id, list):
|
if isinstance(unique_id, list):
|
||||||
unique_ids = set(unique_id)
|
unique_ids = set(unique_id)
|
||||||
elif unique_id is not None:
|
elif unique_id is not None:
|
||||||
unique_ids = set([unique_id,])
|
unique_ids = set(
|
||||||
|
[
|
||||||
|
unique_id,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if limit_neighbors:
|
if limit_neighbors:
|
||||||
neighbor_ids = set()
|
neighbor_ids = set()
|
||||||
for node_id in self.G.neighbors(self.node_id):
|
for node_id in self.G.neighbors(self.node_id):
|
||||||
if self.G.nodes[node_id].get('agent') is not None:
|
if self.G.nodes[node_id].get("agent") is not None:
|
||||||
neighbor_ids.add(node_id)
|
neighbor_ids.add(node_id)
|
||||||
if unique_ids:
|
if unique_ids:
|
||||||
unique_ids = unique_ids & neighbor_ids
|
unique_ids = unique_ids & neighbor_ids
|
||||||
@ -272,7 +279,9 @@ class NetworkAgent(BaseAgent):
|
|||||||
|
|
||||||
def subgraph(self, center=True, **kwargs):
|
def subgraph(self, center=True, **kwargs):
|
||||||
include = [self] if center else []
|
include = [self] if center else []
|
||||||
G = self.G.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include))
|
G = self.G.subgraph(
|
||||||
|
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
||||||
|
)
|
||||||
return G
|
return G
|
||||||
|
|
||||||
def remove_node(self):
|
def remove_node(self):
|
||||||
@ -280,11 +289,19 @@ class NetworkAgent(BaseAgent):
|
|||||||
|
|
||||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||||
if self.node_id not in self.G.nodes(data=False):
|
if self.node_id not in self.G.nodes(data=False):
|
||||||
raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id))
|
raise ValueError(
|
||||||
|
"{} not in list of existing agents in the network".format(
|
||||||
|
self.unique_id
|
||||||
|
)
|
||||||
|
)
|
||||||
if other.node_id not in self.G.nodes(data=False):
|
if other.node_id not in self.G.nodes(data=False):
|
||||||
raise ValueError('{} not in list of existing agents in the network'.format(other))
|
raise ValueError(
|
||||||
|
"{} not in list of existing agents in the network".format(other)
|
||||||
|
)
|
||||||
|
|
||||||
self.G.add_edge(self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs)
|
self.G.add_edge(
|
||||||
|
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
||||||
|
)
|
||||||
|
|
||||||
def die(self, remove=True):
|
def die(self, remove=True):
|
||||||
if remove:
|
if remove:
|
||||||
@ -294,11 +311,11 @@ class NetworkAgent(BaseAgent):
|
|||||||
|
|
||||||
def state(name=None):
|
def state(name=None):
|
||||||
def decorator(func, name=None):
|
def decorator(func, name=None):
|
||||||
'''
|
"""
|
||||||
A state function should return either a state id, or a tuple (state_id, when)
|
A state function should return either a state id, or a tuple (state_id, when)
|
||||||
The default value for state_id is the current state id.
|
The default value for state_id is the current state id.
|
||||||
The default value for when is the interval defined in the environment.
|
The default value for when is the interval defined in the environment.
|
||||||
'''
|
"""
|
||||||
if inspect.isgeneratorfunction(func):
|
if inspect.isgeneratorfunction(func):
|
||||||
orig_func = func
|
orig_func = func
|
||||||
|
|
||||||
@ -348,32 +365,38 @@ class MetaFSM(MetaAgent):
|
|||||||
|
|
||||||
# Add new states
|
# Add new states
|
||||||
for attr, func in namespace.items():
|
for attr, func in namespace.items():
|
||||||
if hasattr(func, 'id'):
|
if hasattr(func, "id"):
|
||||||
if func.is_default:
|
if func.is_default:
|
||||||
default_state = func
|
default_state = func
|
||||||
states[func.id] = func
|
states[func.id] = func
|
||||||
|
|
||||||
namespace.update({
|
namespace.update(
|
||||||
'_default_state': default_state,
|
{
|
||||||
'_states': states,
|
"_default_state": default_state,
|
||||||
})
|
"_states": states,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return super(MetaFSM, mcls).__new__(mcls=mcls, name=name, bases=bases, namespace=namespace)
|
return super(MetaFSM, mcls).__new__(
|
||||||
|
mcls=mcls, name=name, bases=bases, namespace=namespace
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(FSM, self).__init__(*args, **kwargs)
|
super(FSM, self).__init__(*args, **kwargs)
|
||||||
if not hasattr(self, 'state_id'):
|
if not hasattr(self, "state_id"):
|
||||||
if not self._default_state:
|
if not self._default_state:
|
||||||
raise ValueError('No default state specified for {}'.format(self.unique_id))
|
raise ValueError(
|
||||||
|
"No default state specified for {}".format(self.unique_id)
|
||||||
|
)
|
||||||
self.state_id = self._default_state.id
|
self.state_id = self._default_state.id
|
||||||
|
|
||||||
self._coroutine = None
|
self._coroutine = None
|
||||||
self.set_state(self.state_id)
|
self.set_state(self.state_id)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
||||||
default_interval = super().step()
|
default_interval = super().step()
|
||||||
|
|
||||||
next_state = self._states[self.state_id](self)
|
next_state = self._states[self.state_id](self)
|
||||||
@ -386,7 +409,9 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|||||||
elif len(when) == 1:
|
elif len(when) == 1:
|
||||||
when = when[0]
|
when = when[0]
|
||||||
else:
|
else:
|
||||||
raise ValueError('Too many values returned. Only state (and time) allowed')
|
raise ValueError(
|
||||||
|
"Too many values returned. Only state (and time) allowed"
|
||||||
|
)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -396,10 +421,10 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|||||||
return when or default_interval
|
return when or default_interval
|
||||||
|
|
||||||
def set_state(self, state, when=None):
|
def set_state(self, state, when=None):
|
||||||
if hasattr(state, 'id'):
|
if hasattr(state, "id"):
|
||||||
state = state.id
|
state = state.id
|
||||||
if state not in self._states:
|
if state not in self._states:
|
||||||
raise ValueError('{} is not a valid state'.format(state))
|
raise ValueError("{} is not a valid state".format(state))
|
||||||
self.state_id = state
|
self.state_id = state
|
||||||
if when is not None:
|
if when is not None:
|
||||||
self.model.schedule.add(self, when=when)
|
self.model.schedule.add(self, when=when)
|
||||||
@ -414,7 +439,7 @@ class FSM(BaseAgent, metaclass=MetaFSM):
|
|||||||
|
|
||||||
|
|
||||||
def prob(prob, random):
|
def prob(prob, random):
|
||||||
'''
|
"""
|
||||||
A true/False uniform distribution with a given probability.
|
A true/False uniform distribution with a given probability.
|
||||||
To be used like this:
|
To be used like this:
|
||||||
|
|
||||||
@ -423,14 +448,13 @@ def prob(prob, random):
|
|||||||
if prob(0.3):
|
if prob(0.3):
|
||||||
do_something()
|
do_something()
|
||||||
|
|
||||||
'''
|
"""
|
||||||
r = random.random()
|
r = random.random()
|
||||||
return r < prob
|
return r < prob
|
||||||
|
|
||||||
|
|
||||||
def calculate_distribution(network_agents=None,
|
def calculate_distribution(network_agents=None, agent_class=None):
|
||||||
agent_class=None):
|
"""
|
||||||
'''
|
|
||||||
Calculate the threshold values (thresholds for a uniform distribution)
|
Calculate the threshold values (thresholds for a uniform distribution)
|
||||||
of an agent distribution given the weights of each agent type.
|
of an agent distribution given the weights of each agent type.
|
||||||
|
|
||||||
@ -453,26 +477,28 @@ def calculate_distribution(network_agents=None,
|
|||||||
|
|
||||||
In this example, 20% of the nodes will be marked as type
|
In this example, 20% of the nodes will be marked as type
|
||||||
'agent_class_1'.
|
'agent_class_1'.
|
||||||
'''
|
"""
|
||||||
if network_agents:
|
if network_agents:
|
||||||
network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')]
|
network_agents = [
|
||||||
|
deepcopy(agent) for agent in network_agents if not hasattr(agent, "id")
|
||||||
|
]
|
||||||
elif agent_class:
|
elif agent_class:
|
||||||
network_agents = [{'agent_class': agent_class}]
|
network_agents = [{"agent_class": agent_class}]
|
||||||
else:
|
else:
|
||||||
raise ValueError('Specify a distribution or a default agent type')
|
raise ValueError("Specify a distribution or a default agent type")
|
||||||
|
|
||||||
# Fix missing weights and incompatible types
|
# Fix missing weights and incompatible types
|
||||||
for x in network_agents:
|
for x in network_agents:
|
||||||
x['weight'] = float(x.get('weight', 1))
|
x["weight"] = float(x.get("weight", 1))
|
||||||
|
|
||||||
# Calculate the thresholds
|
# Calculate the thresholds
|
||||||
total = sum(x['weight'] for x in network_agents)
|
total = sum(x["weight"] for x in network_agents)
|
||||||
acc = 0
|
acc = 0
|
||||||
for v in network_agents:
|
for v in network_agents:
|
||||||
if 'ids' in v:
|
if "ids" in v:
|
||||||
continue
|
continue
|
||||||
upper = acc + (v['weight']/total)
|
upper = acc + (v["weight"] / total)
|
||||||
v['threshold'] = [acc, upper]
|
v["threshold"] = [acc, upper]
|
||||||
acc = upper
|
acc = upper
|
||||||
return network_agents
|
return network_agents
|
||||||
|
|
||||||
@ -480,28 +506,29 @@ def calculate_distribution(network_agents=None,
|
|||||||
def serialize_type(agent_class, known_modules=[], **kwargs):
|
def serialize_type(agent_class, known_modules=[], **kwargs):
|
||||||
if isinstance(agent_class, str):
|
if isinstance(agent_class, str):
|
||||||
return agent_class
|
return agent_class
|
||||||
known_modules += ['soil.agents']
|
known_modules += ["soil.agents"]
|
||||||
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[1] # Get the name of the class
|
return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[
|
||||||
|
1
|
||||||
|
] # Get the name of the class
|
||||||
|
|
||||||
|
|
||||||
def serialize_definition(network_agents, known_modules=[]):
|
def serialize_definition(network_agents, known_modules=[]):
|
||||||
'''
|
"""
|
||||||
When serializing an agent distribution, remove the thresholds, in order
|
When serializing an agent distribution, remove the thresholds, in order
|
||||||
to avoid cluttering the YAML definition file.
|
to avoid cluttering the YAML definition file.
|
||||||
'''
|
"""
|
||||||
d = deepcopy(list(network_agents))
|
d = deepcopy(list(network_agents))
|
||||||
for v in d:
|
for v in d:
|
||||||
if 'threshold' in v:
|
if "threshold" in v:
|
||||||
del v['threshold']
|
del v["threshold"]
|
||||||
v['agent_class'] = serialize_type(v['agent_class'],
|
v["agent_class"] = serialize_type(v["agent_class"], known_modules=known_modules)
|
||||||
known_modules=known_modules)
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def deserialize_type(agent_class, known_modules=[]):
|
def deserialize_type(agent_class, known_modules=[]):
|
||||||
if not isinstance(agent_class, str):
|
if not isinstance(agent_class, str):
|
||||||
return agent_class
|
return agent_class
|
||||||
known = known_modules + ['soil.agents', 'soil.agents.custom' ]
|
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
||||||
agent_class = serialization.deserializer(agent_class, known_modules=known)
|
agent_class = serialization.deserializer(agent_class, known_modules=known)
|
||||||
return agent_class
|
return agent_class
|
||||||
|
|
||||||
@ -509,12 +536,12 @@ def deserialize_type(agent_class, known_modules=[]):
|
|||||||
def deserialize_definition(ind, **kwargs):
|
def deserialize_definition(ind, **kwargs):
|
||||||
d = deepcopy(ind)
|
d = deepcopy(ind)
|
||||||
for v in d:
|
for v in d:
|
||||||
v['agent_class'] = deserialize_type(v['agent_class'], **kwargs)
|
v["agent_class"] = deserialize_type(v["agent_class"], **kwargs)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def _validate_states(states, topology):
|
def _validate_states(states, topology):
|
||||||
'''Validate states to avoid ignoring states during initialization'''
|
"""Validate states to avoid ignoring states during initialization"""
|
||||||
states = states or []
|
states = states or []
|
||||||
if isinstance(states, dict):
|
if isinstance(states, dict):
|
||||||
for x in states:
|
for x in states:
|
||||||
@ -525,7 +552,7 @@ def _validate_states(states, topology):
|
|||||||
|
|
||||||
|
|
||||||
def _convert_agent_classs(ind, to_string=False, **kwargs):
|
def _convert_agent_classs(ind, to_string=False, **kwargs):
|
||||||
'''Convenience method to allow specifying agents by class or class name.'''
|
"""Convenience method to allow specifying agents by class or class name."""
|
||||||
if to_string:
|
if to_string:
|
||||||
return serialize_definition(ind, **kwargs)
|
return serialize_definition(ind, **kwargs)
|
||||||
return deserialize_definition(ind, **kwargs)
|
return deserialize_definition(ind, **kwargs)
|
||||||
@ -609,12 +636,10 @@ def _convert_agent_classs(ind, to_string=False, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
class AgentView(Mapping, Set):
|
class AgentView(Mapping, Set):
|
||||||
"""A lazy-loaded list of agents.
|
"""A lazy-loaded list of agents."""
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_agents",)
|
__slots__ = ("_agents",)
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, agents):
|
def __init__(self, agents):
|
||||||
self._agents = agents
|
self._agents = agents
|
||||||
|
|
||||||
@ -657,11 +682,20 @@ class AgentView(Mapping, Set):
|
|||||||
return f"{self.__class__.__name__}({self})"
|
return f"{self.__class__.__name__}({self})"
|
||||||
|
|
||||||
|
|
||||||
def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None,
|
def filter_agents(
|
||||||
limit=None, **kwargs):
|
agents,
|
||||||
'''
|
*id_args,
|
||||||
|
unique_id=None,
|
||||||
|
state_id=None,
|
||||||
|
agent_class=None,
|
||||||
|
ignore=None,
|
||||||
|
state=None,
|
||||||
|
limit=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id).
|
||||||
'''
|
"""
|
||||||
assert isinstance(agents, dict)
|
assert isinstance(agents, dict)
|
||||||
|
|
||||||
ids = []
|
ids = []
|
||||||
@ -694,7 +728,7 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N
|
|||||||
f = filter(lambda x: x not in ignore, f)
|
f = filter(lambda x: x not in ignore, f)
|
||||||
|
|
||||||
if state_id is not None:
|
if state_id is not None:
|
||||||
f = filter(lambda agent: agent.get('state_id', None) in state_id, f)
|
f = filter(lambda agent: agent.get("state_id", None) in state_id, f)
|
||||||
|
|
||||||
if agent_class is not None:
|
if agent_class is not None:
|
||||||
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
f = filter(lambda agent: isinstance(agent, agent_class), f)
|
||||||
@ -711,23 +745,25 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N
|
|||||||
yield from f
|
yield from f
|
||||||
|
|
||||||
|
|
||||||
def from_config(cfg: config.AgentConfig, random, topology: nx.Graph = None) -> List[Dict[str, Any]]:
|
def from_config(
|
||||||
'''
|
cfg: config.AgentConfig, random, topology: nx.Graph = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary
|
||||||
with the parameters that the environment will use to construct each agent.
|
with the parameters that the environment will use to construct each agent.
|
||||||
|
|
||||||
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the
|
||||||
time of calling this function, such as `unique_id`.
|
time of calling this function, such as `unique_id`.
|
||||||
'''
|
"""
|
||||||
default = cfg or config.AgentConfig()
|
default = cfg or config.AgentConfig()
|
||||||
if not isinstance(cfg, config.AgentConfig):
|
if not isinstance(cfg, config.AgentConfig):
|
||||||
cfg = config.AgentConfig(**cfg)
|
cfg = config.AgentConfig(**cfg)
|
||||||
return _agents_from_config(cfg, topology=topology, random=random)
|
return _agents_from_config(cfg, topology=topology, random=random)
|
||||||
|
|
||||||
|
|
||||||
def _agents_from_config(cfg: config.AgentConfig,
|
def _agents_from_config(
|
||||||
topology: nx.Graph,
|
cfg: config.AgentConfig, topology: nx.Graph, random
|
||||||
random) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
if cfg and not isinstance(cfg, config.AgentConfig):
|
if cfg and not isinstance(cfg, config.AgentConfig):
|
||||||
cfg = config.AgentConfig(**cfg)
|
cfg = config.AgentConfig(**cfg)
|
||||||
|
|
||||||
@ -737,7 +773,9 @@ def _agents_from_config(cfg: config.AgentConfig,
|
|||||||
assigned_network = 0
|
assigned_network = 0
|
||||||
|
|
||||||
if cfg.fixed is not None:
|
if cfg.fixed is not None:
|
||||||
agents, assigned_total, assigned_network = _from_fixed(cfg.fixed, topology=cfg.topology, default=cfg)
|
agents, assigned_total, assigned_network = _from_fixed(
|
||||||
|
cfg.fixed, topology=cfg.topology, default=cfg
|
||||||
|
)
|
||||||
|
|
||||||
n = cfg.n
|
n = cfg.n
|
||||||
|
|
||||||
@ -749,46 +787,56 @@ def _agents_from_config(cfg: config.AgentConfig,
|
|||||||
|
|
||||||
for d in cfg.distribution:
|
for d in cfg.distribution:
|
||||||
if d.strategy == config.Strategy.topology:
|
if d.strategy == config.Strategy.topology:
|
||||||
topo = d.topology if ('topology' in d.__fields_set__) else cfg.topology
|
topo = d.topology if ("topology" in d.__fields_set__) else cfg.topology
|
||||||
if not topo:
|
if not topo:
|
||||||
raise ValueError('The "topology" strategy only works if the topology parameter is set to True')
|
raise ValueError(
|
||||||
|
'The "topology" strategy only works if the topology parameter is set to True'
|
||||||
|
)
|
||||||
if not topo_size:
|
if not topo_size:
|
||||||
raise ValueError(f'Topology does not have enough free nodes to assign one to the agent')
|
raise ValueError(
|
||||||
|
f"Topology does not have enough free nodes to assign one to the agent"
|
||||||
|
)
|
||||||
|
|
||||||
networked.append(d)
|
networked.append(d)
|
||||||
|
|
||||||
if d.strategy == config.Strategy.total:
|
if d.strategy == config.Strategy.total:
|
||||||
if not cfg.n:
|
if not cfg.n:
|
||||||
raise ValueError('Cannot use the "total" strategy without providing the total number of agents')
|
raise ValueError(
|
||||||
|
'Cannot use the "total" strategy without providing the total number of agents'
|
||||||
|
)
|
||||||
total.append(d)
|
total.append(d)
|
||||||
|
|
||||||
|
|
||||||
if networked:
|
if networked:
|
||||||
new_agents = _from_distro(networked,
|
new_agents = _from_distro(
|
||||||
n= topo_size - assigned_network,
|
networked,
|
||||||
topology=topo,
|
n=topo_size - assigned_network,
|
||||||
default=cfg,
|
topology=topo,
|
||||||
random=random)
|
default=cfg,
|
||||||
|
random=random,
|
||||||
|
)
|
||||||
assigned_total += len(new_agents)
|
assigned_total += len(new_agents)
|
||||||
assigned_network += len(new_agents)
|
assigned_network += len(new_agents)
|
||||||
agents += new_agents
|
agents += new_agents
|
||||||
|
|
||||||
if total:
|
if total:
|
||||||
remaining = n - assigned_total
|
remaining = n - assigned_total
|
||||||
agents += _from_distro(total, n=remaining,
|
agents += _from_distro(total, n=remaining, default=cfg, random=random)
|
||||||
default=cfg,
|
|
||||||
random=random)
|
|
||||||
|
|
||||||
|
|
||||||
if assigned_network < topo_size:
|
if assigned_network < topo_size:
|
||||||
utils.logger.warn(f'The total number of agents does not match the total number of nodes in '
|
utils.logger.warn(
|
||||||
'every topology. This may be due to a definition error: assigned: '
|
f"The total number of agents does not match the total number of nodes in "
|
||||||
f'{ assigned } total size: { topo_size }')
|
"every topology. This may be due to a definition error: assigned: "
|
||||||
|
f"{ assigned } total size: { topo_size }"
|
||||||
|
)
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
|
||||||
def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: config.SingleAgentConfig) -> List[Dict[str, Any]]:
|
def _from_fixed(
|
||||||
|
lst: List[config.FixedAgentConfig],
|
||||||
|
topology: bool,
|
||||||
|
default: config.SingleAgentConfig,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
agents = []
|
agents = []
|
||||||
|
|
||||||
counts_total = 0
|
counts_total = 0
|
||||||
@ -799,12 +847,18 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con
|
|||||||
if default:
|
if default:
|
||||||
agent = default.state.copy()
|
agent = default.state.copy()
|
||||||
agent.update(fixed.state)
|
agent.update(fixed.state)
|
||||||
cls = serialization.deserialize(fixed.agent_class or (default and default.agent_class))
|
cls = serialization.deserialize(
|
||||||
agent['agent_class'] = cls
|
fixed.agent_class or (default and default.agent_class)
|
||||||
topo = fixed.topology if ('topology' in fixed.__fields_set__) else topology or default.topology
|
)
|
||||||
|
agent["agent_class"] = cls
|
||||||
|
topo = (
|
||||||
|
fixed.topology
|
||||||
|
if ("topology" in fixed.__fields_set__)
|
||||||
|
else topology or default.topology
|
||||||
|
)
|
||||||
|
|
||||||
if topo:
|
if topo:
|
||||||
agent['topology'] = True
|
agent["topology"] = True
|
||||||
counts_network += 1
|
counts_network += 1
|
||||||
if not fixed.hidden:
|
if not fixed.hidden:
|
||||||
counts_total += 1
|
counts_total += 1
|
||||||
@ -813,17 +867,21 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con
|
|||||||
return agents, counts_total, counts_network
|
return agents, counts_total, counts_network
|
||||||
|
|
||||||
|
|
||||||
def _from_distro(distro: List[config.AgentDistro],
|
def _from_distro(
|
||||||
n: int,
|
distro: List[config.AgentDistro],
|
||||||
topology: str,
|
n: int,
|
||||||
default: config.SingleAgentConfig,
|
topology: str,
|
||||||
random) -> List[Dict[str, Any]]:
|
default: config.SingleAgentConfig,
|
||||||
|
random,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
|
||||||
agents = []
|
agents = []
|
||||||
|
|
||||||
if n is None:
|
if n is None:
|
||||||
if any(lambda dist: dist.n is None, distro):
|
if any(lambda dist: dist.n is None, distro):
|
||||||
raise ValueError('You must provide a total number of agents, or the number of each type')
|
raise ValueError(
|
||||||
|
"You must provide a total number of agents, or the number of each type"
|
||||||
|
)
|
||||||
n = sum(dist.n for dist in distro)
|
n = sum(dist.n for dist in distro)
|
||||||
|
|
||||||
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
|
weights = list(dist.weight if dist.weight is not None else 1 for dist in distro)
|
||||||
@ -836,29 +894,40 @@ def _from_distro(distro: List[config.AgentDistro],
|
|||||||
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
|
# So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect
|
||||||
|
|
||||||
# Calculate how many times each has to appear
|
# Calculate how many times each has to appear
|
||||||
indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm)))
|
indices = list(
|
||||||
|
chain.from_iterable([idx] * int(n * chunk) for (idx, n) in enumerate(norm))
|
||||||
|
)
|
||||||
|
|
||||||
# Complete with random agents following the original weight distribution
|
# Complete with random agents following the original weight distribution
|
||||||
if len(indices) < n:
|
if len(indices) < n:
|
||||||
indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices))
|
indices += random.choices(
|
||||||
|
list(range(len(distro))),
|
||||||
|
weights=[d.weight for d in distro],
|
||||||
|
k=n - len(indices),
|
||||||
|
)
|
||||||
|
|
||||||
# Deserialize classes for efficiency
|
# Deserialize classes for efficiency
|
||||||
classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro)
|
classes = list(
|
||||||
|
serialization.deserialize(i.agent_class or default.agent_class) for i in distro
|
||||||
|
)
|
||||||
|
|
||||||
# Add them in random order
|
# Add them in random order
|
||||||
random.shuffle(indices)
|
random.shuffle(indices)
|
||||||
|
|
||||||
|
|
||||||
for idx in indices:
|
for idx in indices:
|
||||||
d = distro[idx]
|
d = distro[idx]
|
||||||
agent = d.state.copy()
|
agent = d.state.copy()
|
||||||
cls = classes[idx]
|
cls = classes[idx]
|
||||||
agent['agent_class'] = cls
|
agent["agent_class"] = cls
|
||||||
if default:
|
if default:
|
||||||
agent.update(default.state)
|
agent.update(default.state)
|
||||||
topology = d.topology if ('topology' in d.__fields_set__) else topology or default.topology
|
topology = (
|
||||||
|
d.topology
|
||||||
|
if ("topology" in d.__fields_set__)
|
||||||
|
else topology or default.topology
|
||||||
|
)
|
||||||
if topology:
|
if topology:
|
||||||
agent['topology'] = topology
|
agent["topology"] = topology
|
||||||
agents.append(agent)
|
agents.append(agent)
|
||||||
|
|
||||||
return agents
|
return agents
|
||||||
@ -877,4 +946,5 @@ try:
|
|||||||
from .Geo import Geo
|
from .Geo import Geo
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import sys
|
import sys
|
||||||
print('Could not load the Geo Agent, scipy is not installed', file=sys.stderr)
|
|
||||||
|
print("Could not load the Geo Agent, scipy is not installed", file=sys.stderr)
|
||||||
|
167
soil/config.py
167
soil/config.py
@ -19,6 +19,7 @@ import networkx as nx
|
|||||||
# Could use TypeAlias in python >= 3.10
|
# Could use TypeAlias in python >= 3.10
|
||||||
nodeId = int
|
nodeId = int
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
id: nodeId
|
id: nodeId
|
||||||
state: Optional[Dict[str, Any]] = {}
|
state: Optional[Dict[str, Any]] = {}
|
||||||
@ -54,14 +55,15 @@ class NetConfig(BaseModel):
|
|||||||
return NetConfig(topology=None, params=None)
|
return NetConfig(topology=None, params=None)
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_all(cls, values):
|
def validate_all(cls, values):
|
||||||
if 'params' not in values and 'topology' not in values:
|
if "params" not in values and "topology" not in values:
|
||||||
raise ValueError('You must specify either a topology or the parameters to generate a graph')
|
raise ValueError(
|
||||||
|
"You must specify either a topology or the parameters to generate a graph"
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
class EnvConfig(BaseModel):
|
class EnvConfig(BaseModel):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default():
|
def default():
|
||||||
return EnvConfig()
|
return EnvConfig()
|
||||||
@ -80,9 +82,11 @@ class FixedAgentConfig(SingleAgentConfig):
|
|||||||
hidden: Optional[bool] = False # Do not count this agent towards total agent count
|
hidden: Optional[bool] = False # Do not count this agent towards total agent count
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_all(cls, values):
|
def validate_all(cls, values):
|
||||||
if values.get('unique_id', None) is not None and values.get('n', 1) > 1:
|
if values.get("unique_id", None) is not None and values.get("n", 1) > 1:
|
||||||
raise ValueError(f"An unique_id can only be provided when there is only one agent ({values.get('n')} given)")
|
raise ValueError(
|
||||||
|
f"An unique_id can only be provided when there is only one agent ({values.get('n')} given)"
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@ -91,8 +95,8 @@ class OverrideAgentConfig(FixedAgentConfig):
|
|||||||
|
|
||||||
|
|
||||||
class Strategy(Enum):
|
class Strategy(Enum):
|
||||||
topology = 'topology'
|
topology = "topology"
|
||||||
total = 'total'
|
total = "total"
|
||||||
|
|
||||||
|
|
||||||
class AgentDistro(SingleAgentConfig):
|
class AgentDistro(SingleAgentConfig):
|
||||||
@ -111,16 +115,20 @@ class AgentConfig(SingleAgentConfig):
|
|||||||
return AgentConfig()
|
return AgentConfig()
|
||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_all(cls, values):
|
def validate_all(cls, values):
|
||||||
if 'distribution' in values and ('n' not in values and 'topology' not in values):
|
if "distribution" in values and (
|
||||||
raise ValueError("You need to provide the number of agents or a topology to extract the value from.")
|
"n" not in values and "topology" not in values
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"You need to provide the number of agents or a topology to extract the value from."
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel, extra=Extra.allow):
|
class Config(BaseModel, extra=Extra.allow):
|
||||||
version: Optional[str] = '1'
|
version: Optional[str] = "1"
|
||||||
|
|
||||||
name: str = 'Unnamed Simulation'
|
name: str = "Unnamed Simulation"
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
group: str = None
|
group: str = None
|
||||||
dir_path: Optional[str] = None
|
dir_path: Optional[str] = None
|
||||||
@ -140,45 +148,48 @@ class Config(BaseModel, extra=Extra.allow):
|
|||||||
def from_raw(cls, cfg):
|
def from_raw(cls, cfg):
|
||||||
if isinstance(cfg, Config):
|
if isinstance(cfg, Config):
|
||||||
return cfg
|
return cfg
|
||||||
if cfg.get('version', '1') == '1' and any(k in cfg for k in ['agents', 'agent_class', 'topology', 'environment_class']):
|
if cfg.get("version", "1") == "1" and any(
|
||||||
|
k in cfg for k in ["agents", "agent_class", "topology", "environment_class"]
|
||||||
|
):
|
||||||
return convert_old(cfg)
|
return convert_old(cfg)
|
||||||
return Config(**cfg)
|
return Config(**cfg)
|
||||||
|
|
||||||
|
|
||||||
def convert_old(old, strict=True):
|
def convert_old(old, strict=True):
|
||||||
'''
|
"""
|
||||||
Try to convert old style configs into the new format.
|
Try to convert old style configs into the new format.
|
||||||
|
|
||||||
This is still a work in progress and might not work in many cases.
|
This is still a work in progress and might not work in many cases.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
utils.logger.warning('The old configuration format is deprecated. The converted file MAY NOT yield the right results')
|
utils.logger.warning(
|
||||||
|
"The old configuration format is deprecated. The converted file MAY NOT yield the right results"
|
||||||
|
)
|
||||||
|
|
||||||
new = old.copy()
|
new = old.copy()
|
||||||
|
|
||||||
network = {}
|
network = {}
|
||||||
|
|
||||||
if 'topology' in old:
|
if "topology" in old:
|
||||||
del new['topology']
|
del new["topology"]
|
||||||
network['topology'] = old['topology']
|
network["topology"] = old["topology"]
|
||||||
|
|
||||||
if 'network_params' in old and old['network_params']:
|
if "network_params" in old and old["network_params"]:
|
||||||
del new['network_params']
|
del new["network_params"]
|
||||||
for (k, v) in old['network_params'].items():
|
for (k, v) in old["network_params"].items():
|
||||||
if k == 'path':
|
if k == "path":
|
||||||
network['path'] = v
|
network["path"] = v
|
||||||
else:
|
else:
|
||||||
network.setdefault('params', {})[k] = v
|
network.setdefault("params", {})[k] = v
|
||||||
|
|
||||||
topology = None
|
topology = None
|
||||||
if network:
|
if network:
|
||||||
topology = network
|
topology = network
|
||||||
|
|
||||||
|
agents = {"fixed": [], "distribution": []}
|
||||||
agents = {'fixed': [], 'distribution': []}
|
|
||||||
|
|
||||||
def updated_agent(agent):
|
def updated_agent(agent):
|
||||||
'''Convert an agent definition'''
|
"""Convert an agent definition"""
|
||||||
newagent = dict(agent)
|
newagent = dict(agent)
|
||||||
return newagent
|
return newagent
|
||||||
|
|
||||||
@ -186,80 +197,74 @@ def convert_old(old, strict=True):
|
|||||||
fixed = []
|
fixed = []
|
||||||
override = []
|
override = []
|
||||||
|
|
||||||
if 'environment_agents' in new:
|
if "environment_agents" in new:
|
||||||
|
|
||||||
for agent in new['environment_agents']:
|
for agent in new["environment_agents"]:
|
||||||
agent.setdefault('state', {})['group'] = 'environment'
|
agent.setdefault("state", {})["group"] = "environment"
|
||||||
if 'agent_id' in agent:
|
if "agent_id" in agent:
|
||||||
agent['state']['name'] = agent['agent_id']
|
agent["state"]["name"] = agent["agent_id"]
|
||||||
del agent['agent_id']
|
del agent["agent_id"]
|
||||||
agent['hidden'] = True
|
agent["hidden"] = True
|
||||||
agent['topology'] = False
|
agent["topology"] = False
|
||||||
fixed.append(updated_agent(agent))
|
fixed.append(updated_agent(agent))
|
||||||
del new['environment_agents']
|
del new["environment_agents"]
|
||||||
|
|
||||||
|
if "agent_class" in old:
|
||||||
|
del new["agent_class"]
|
||||||
|
agents["agent_class"] = old["agent_class"]
|
||||||
|
|
||||||
if 'agent_class' in old:
|
if "default_state" in old:
|
||||||
del new['agent_class']
|
del new["default_state"]
|
||||||
agents['agent_class'] = old['agent_class']
|
agents["state"] = old["default_state"]
|
||||||
|
|
||||||
if 'default_state' in old:
|
if "network_agents" in old:
|
||||||
del new['default_state']
|
agents["topology"] = True
|
||||||
agents['state'] = old['default_state']
|
|
||||||
|
|
||||||
if 'network_agents' in old:
|
agents.setdefault("state", {})["group"] = "network"
|
||||||
agents['topology'] = True
|
|
||||||
|
|
||||||
agents.setdefault('state', {})['group'] = 'network'
|
for agent in new["network_agents"]:
|
||||||
|
|
||||||
for agent in new['network_agents']:
|
|
||||||
agent = updated_agent(agent)
|
agent = updated_agent(agent)
|
||||||
if 'agent_id' in agent:
|
if "agent_id" in agent:
|
||||||
agent['state']['name'] = agent['agent_id']
|
agent["state"]["name"] = agent["agent_id"]
|
||||||
del agent['agent_id']
|
del agent["agent_id"]
|
||||||
fixed.append(agent)
|
fixed.append(agent)
|
||||||
else:
|
else:
|
||||||
by_weight.append(agent)
|
by_weight.append(agent)
|
||||||
del new['network_agents']
|
del new["network_agents"]
|
||||||
|
|
||||||
if 'agent_class' in old and (not fixed and not by_weight):
|
|
||||||
agents['topology'] = True
|
|
||||||
by_weight = [{'agent_class': old['agent_class'], 'weight': 1}]
|
|
||||||
|
|
||||||
|
if "agent_class" in old and (not fixed and not by_weight):
|
||||||
|
agents["topology"] = True
|
||||||
|
by_weight = [{"agent_class": old["agent_class"], "weight": 1}]
|
||||||
|
|
||||||
# TODO: translate states properly
|
# TODO: translate states properly
|
||||||
if 'states' in old:
|
if "states" in old:
|
||||||
del new['states']
|
del new["states"]
|
||||||
states = old['states']
|
states = old["states"]
|
||||||
if isinstance(states, dict):
|
if isinstance(states, dict):
|
||||||
states = states.items()
|
states = states.items()
|
||||||
else:
|
else:
|
||||||
states = enumerate(states)
|
states = enumerate(states)
|
||||||
for (k, v) in states:
|
for (k, v) in states:
|
||||||
override.append({'filter': {'node_id': k},
|
override.append({"filter": {"node_id": k}, "state": v})
|
||||||
'state': v})
|
|
||||||
|
|
||||||
agents['override'] = override
|
|
||||||
agents['fixed'] = fixed
|
|
||||||
agents['distribution'] = by_weight
|
|
||||||
|
|
||||||
|
agents["override"] = override
|
||||||
|
agents["fixed"] = fixed
|
||||||
|
agents["distribution"] = by_weight
|
||||||
|
|
||||||
model_params = {}
|
model_params = {}
|
||||||
if 'environment_params' in new:
|
if "environment_params" in new:
|
||||||
del new['environment_params']
|
del new["environment_params"]
|
||||||
model_params = dict(old['environment_params'])
|
model_params = dict(old["environment_params"])
|
||||||
|
|
||||||
if 'environment_class' in old:
|
if "environment_class" in old:
|
||||||
del new['environment_class']
|
del new["environment_class"]
|
||||||
new['model_class'] = old['environment_class']
|
new["model_class"] = old["environment_class"]
|
||||||
|
|
||||||
if 'dump' in old:
|
if "dump" in old:
|
||||||
del new['dump']
|
del new["dump"]
|
||||||
new['dry_run'] = not old['dump']
|
new["dry_run"] = not old["dump"]
|
||||||
|
|
||||||
model_params['topology'] = topology
|
model_params["topology"] = topology
|
||||||
model_params['agents'] = agents
|
model_params["agents"] = agents
|
||||||
|
|
||||||
return Config(version='2',
|
return Config(version="2", model_params=model_params, **new)
|
||||||
model_params=model_params,
|
|
||||||
**new)
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from mesa import DataCollector as MDC
|
from mesa import DataCollector as MDC
|
||||||
|
|
||||||
class SoilDataCollector(MDC):
|
|
||||||
|
|
||||||
|
class SoilDataCollector(MDC):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
@ -18,9 +18,9 @@ def wrapcmd(func):
|
|||||||
known = globals()
|
known = globals()
|
||||||
known.update(self.curframe.f_globals)
|
known.update(self.curframe.f_globals)
|
||||||
known.update(self.curframe.f_locals)
|
known.update(self.curframe.f_locals)
|
||||||
known['agent'] = known.get('self', None)
|
known["agent"] = known.get("self", None)
|
||||||
known['model'] = known.get('self', {}).get('model')
|
known["model"] = known.get("self", {}).get("model")
|
||||||
known['attrs'] = arg.strip().split()
|
known["attrs"] = arg.strip().split()
|
||||||
|
|
||||||
exec(func.__code__, known, known)
|
exec(func.__code__, known, known)
|
||||||
|
|
||||||
@ -29,12 +29,12 @@ def wrapcmd(func):
|
|||||||
|
|
||||||
class Debug(pdb.Pdb):
|
class Debug(pdb.Pdb):
|
||||||
def __init__(self, *args, skip_soil=False, **kwargs):
|
def __init__(self, *args, skip_soil=False, **kwargs):
|
||||||
skip = kwargs.get('skip', [])
|
skip = kwargs.get("skip", [])
|
||||||
skip.append('soil')
|
skip.append("soil")
|
||||||
if skip_soil:
|
if skip_soil:
|
||||||
skip.append('soil')
|
skip.append("soil")
|
||||||
skip.append('soil.*')
|
skip.append("soil.*")
|
||||||
skip.append('mesa.*')
|
skip.append("mesa.*")
|
||||||
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
||||||
self.prompt = "[soil-pdb] "
|
self.prompt = "[soil-pdb] "
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ class Debug(pdb.Pdb):
|
|||||||
def _soil_agents(model, attrs=None, pretty=True, **kwargs):
|
def _soil_agents(model, attrs=None, pretty=True, **kwargs):
|
||||||
for agent in model.agents(**kwargs):
|
for agent in model.agents(**kwargs):
|
||||||
d = agent
|
d = agent
|
||||||
print(' - ' + indent(agent.to_str(keys=attrs, pretty=pretty), ' '))
|
print(" - " + indent(agent.to_str(keys=attrs, pretty=pretty), " "))
|
||||||
|
|
||||||
@wrapcmd
|
@wrapcmd
|
||||||
def do_soil_agents():
|
def do_soil_agents():
|
||||||
@ -52,20 +52,20 @@ class Debug(pdb.Pdb):
|
|||||||
|
|
||||||
@wrapcmd
|
@wrapcmd
|
||||||
def do_soil_list():
|
def do_soil_list():
|
||||||
return Debug._soil_agents(model, attrs=['state_id'], pretty=False)
|
return Debug._soil_agents(model, attrs=["state_id"], pretty=False)
|
||||||
|
|
||||||
do_sl = do_soil_list
|
do_sl = do_soil_list
|
||||||
|
|
||||||
def do_continue_state(self, arg):
|
def do_continue_state(self, arg):
|
||||||
self.do_break_state(arg, temporary=True)
|
self.do_break_state(arg, temporary=True)
|
||||||
return self.do_continue('')
|
return self.do_continue("")
|
||||||
|
|
||||||
do_cs = do_continue_state
|
do_cs = do_continue_state
|
||||||
|
|
||||||
@wrapcmd
|
@wrapcmd
|
||||||
def do_soil_agent():
|
def do_soil_agent():
|
||||||
if not agent:
|
if not agent:
|
||||||
print('No agent available')
|
print("No agent available")
|
||||||
return
|
return
|
||||||
|
|
||||||
keys = None
|
keys = None
|
||||||
@ -81,9 +81,9 @@ class Debug(pdb.Pdb):
|
|||||||
do_aa = do_soil_agent
|
do_aa = do_soil_agent
|
||||||
|
|
||||||
def do_break_state(self, arg: str, instances=None, temporary=False):
|
def do_break_state(self, arg: str, instances=None, temporary=False):
|
||||||
'''
|
"""
|
||||||
Break before a specified state is stepped into.
|
Break before a specified state is stepped into.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
klass = None
|
klass = None
|
||||||
state = arg
|
state = arg
|
||||||
@ -95,39 +95,39 @@ class Debug(pdb.Pdb):
|
|||||||
if tokens:
|
if tokens:
|
||||||
instances = list(eval(token) for token in tokens)
|
instances = list(eval(token) for token in tokens)
|
||||||
|
|
||||||
colon = state.find(':')
|
colon = state.find(":")
|
||||||
|
|
||||||
if colon > 0:
|
if colon > 0:
|
||||||
klass = state[:colon].rstrip()
|
klass = state[:colon].rstrip()
|
||||||
state = state[colon+1:].strip()
|
state = state[colon + 1 :].strip()
|
||||||
|
|
||||||
|
|
||||||
print(klass, state, tokens)
|
print(klass, state, tokens)
|
||||||
klass = eval(klass,
|
klass = eval(klass, self.curframe.f_globals, self.curframe_locals)
|
||||||
self.curframe.f_globals,
|
|
||||||
self.curframe_locals)
|
|
||||||
|
|
||||||
if klass:
|
if klass:
|
||||||
klasses = [klass]
|
klasses = [klass]
|
||||||
else:
|
else:
|
||||||
klasses = [k for k in self.curframe.f_globals.values() if isinstance(k, type) and issubclass(k, FSM)]
|
klasses = [
|
||||||
|
k
|
||||||
|
for k in self.curframe.f_globals.values()
|
||||||
|
if isinstance(k, type) and issubclass(k, FSM)
|
||||||
|
]
|
||||||
|
|
||||||
if not klasses:
|
if not klasses:
|
||||||
self.error('No agent classes found')
|
self.error("No agent classes found")
|
||||||
|
|
||||||
|
|
||||||
for klass in klasses:
|
for klass in klasses:
|
||||||
try:
|
try:
|
||||||
func = getattr(klass, state)
|
func = getattr(klass, state)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
self.error(f'State {state} not found in class {klass}')
|
self.error(f"State {state} not found in class {klass}")
|
||||||
continue
|
continue
|
||||||
if hasattr(func, '__func__'):
|
if hasattr(func, "__func__"):
|
||||||
func = func.__func__
|
func = func.__func__
|
||||||
|
|
||||||
code = func.__code__
|
code = func.__code__
|
||||||
#use co_name to identify the bkpt (function names
|
# use co_name to identify the bkpt (function names
|
||||||
#could be aliased, but co_name is invariant)
|
# could be aliased, but co_name is invariant)
|
||||||
funcname = code.co_name
|
funcname = code.co_name
|
||||||
lineno = code.co_firstlineno
|
lineno = code.co_firstlineno
|
||||||
filename = code.co_filename
|
filename = code.co_filename
|
||||||
@ -135,38 +135,36 @@ class Debug(pdb.Pdb):
|
|||||||
# Check for reasonable breakpoint
|
# Check for reasonable breakpoint
|
||||||
line = self.checkline(filename, lineno)
|
line = self.checkline(filename, lineno)
|
||||||
if not line:
|
if not line:
|
||||||
raise ValueError('no line found')
|
raise ValueError("no line found")
|
||||||
# now set the break point
|
# now set the break point
|
||||||
cond = None
|
cond = None
|
||||||
if instances:
|
if instances:
|
||||||
cond = f'self.unique_id in { repr(instances) }'
|
cond = f"self.unique_id in { repr(instances) }"
|
||||||
|
|
||||||
existing = self.get_breaks(filename, line)
|
existing = self.get_breaks(filename, line)
|
||||||
if existing:
|
if existing:
|
||||||
self.message("Breakpoint already exists at %s:%d" %
|
self.message("Breakpoint already exists at %s:%d" % (filename, line))
|
||||||
(filename, line))
|
|
||||||
continue
|
continue
|
||||||
err = self.set_break(filename, line, temporary, cond, funcname)
|
err = self.set_break(filename, line, temporary, cond, funcname)
|
||||||
if err:
|
if err:
|
||||||
self.error(err)
|
self.error(err)
|
||||||
else:
|
else:
|
||||||
bp = self.get_breaks(filename, line)[-1]
|
bp = self.get_breaks(filename, line)[-1]
|
||||||
self.message("Breakpoint %d at %s:%d" %
|
self.message("Breakpoint %d at %s:%d" % (bp.number, bp.file, bp.line))
|
||||||
(bp.number, bp.file, bp.line))
|
|
||||||
|
|
||||||
do_bs = do_break_state
|
do_bs = do_break_state
|
||||||
|
|
||||||
def do_break_state_self(self, arg: str, temporary=False):
|
def do_break_state_self(self, arg: str, temporary=False):
|
||||||
'''
|
"""
|
||||||
Break before a specified state is stepped into, for the current agent
|
Break before a specified state is stepped into, for the current agent
|
||||||
'''
|
"""
|
||||||
agent = self.curframe.f_locals.get('self')
|
agent = self.curframe.f_locals.get("self")
|
||||||
if not agent:
|
if not agent:
|
||||||
self.error('No current agent.')
|
self.error("No current agent.")
|
||||||
self.error('Try this again when the debugger is stopped inside an agent')
|
self.error("Try this again when the debugger is stopped inside an agent")
|
||||||
return
|
return
|
||||||
|
|
||||||
arg = f'{agent.__class__.__name__}:{ arg } {agent.unique_id}'
|
arg = f"{agent.__class__.__name__}:{ arg } {agent.unique_id}"
|
||||||
return self.do_break_state(arg)
|
return self.do_break_state(arg)
|
||||||
|
|
||||||
do_bss = do_break_state_self
|
do_bss = do_break_state_self
|
||||||
@ -174,6 +172,7 @@ class Debug(pdb.Pdb):
|
|||||||
|
|
||||||
debugger = None
|
debugger = None
|
||||||
|
|
||||||
|
|
||||||
def set_trace(frame=None, **kwargs):
|
def set_trace(frame=None, **kwargs):
|
||||||
global debugger
|
global debugger
|
||||||
if debugger is None:
|
if debugger is None:
|
||||||
|
@ -35,18 +35,20 @@ class BaseEnvironment(Model):
|
|||||||
:meth:`soil.environment.Environment.get` method.
|
:meth:`soil.environment.Environment.get` method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
id='unnamed_env',
|
self,
|
||||||
seed='default',
|
id="unnamed_env",
|
||||||
schedule=None,
|
seed="default",
|
||||||
dir_path=None,
|
schedule=None,
|
||||||
interval=1,
|
dir_path=None,
|
||||||
agent_class=None,
|
interval=1,
|
||||||
agents: [tuple[type, Dict[str, Any]]] = {},
|
agent_class=None,
|
||||||
agent_reporters: Optional[Any] = None,
|
agents: [tuple[type, Dict[str, Any]]] = {},
|
||||||
model_reporters: Optional[Any] = None,
|
agent_reporters: Optional[Any] = None,
|
||||||
tables: Optional[Any] = None,
|
model_reporters: Optional[Any] = None,
|
||||||
**env_params):
|
tables: Optional[Any] = None,
|
||||||
|
**env_params,
|
||||||
|
):
|
||||||
|
|
||||||
super().__init__(seed=seed)
|
super().__init__(seed=seed)
|
||||||
self.env_params = env_params or {}
|
self.env_params = env_params or {}
|
||||||
@ -75,27 +77,26 @@ class BaseEnvironment(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _agent_from_dict(self, agent):
|
def _agent_from_dict(self, agent):
|
||||||
'''
|
"""
|
||||||
Translate an agent dictionary into an agent
|
Translate an agent dictionary into an agent
|
||||||
'''
|
"""
|
||||||
agent = dict(**agent)
|
agent = dict(**agent)
|
||||||
cls = agent.pop('agent_class', None) or self.agent_class
|
cls = agent.pop("agent_class", None) or self.agent_class
|
||||||
unique_id = agent.pop('unique_id', None)
|
unique_id = agent.pop("unique_id", None)
|
||||||
if unique_id is None:
|
if unique_id is None:
|
||||||
unique_id = self.next_id()
|
unique_id = self.next_id()
|
||||||
|
|
||||||
return serialization.deserialize(cls)(unique_id=unique_id,
|
return serialization.deserialize(cls)(unique_id=unique_id, model=self, **agent)
|
||||||
model=self, **agent)
|
|
||||||
|
|
||||||
def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}):
|
def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}):
|
||||||
'''
|
"""
|
||||||
Initialize the agents in the model from either a `soil.config.AgentConfig` or a list of
|
Initialize the agents in the model from either a `soil.config.AgentConfig` or a list of
|
||||||
dictionaries that each describes an agent.
|
dictionaries that each describes an agent.
|
||||||
|
|
||||||
If given a list of dictionaries, an agent will be created for each dictionary. The agent
|
If given a list of dictionaries, an agent will be created for each dictionary. The agent
|
||||||
class can be specified through the `agent_class` key. The rest of the items will be used
|
class can be specified through the `agent_class` key. The rest of the items will be used
|
||||||
as parameters to the agent.
|
as parameters to the agent.
|
||||||
'''
|
"""
|
||||||
if not agents:
|
if not agents:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -108,11 +109,10 @@ class BaseEnvironment(Model):
|
|||||||
override = lst.override
|
override = lst.override
|
||||||
lst = self._agent_dict_from_config(lst)
|
lst = self._agent_dict_from_config(lst)
|
||||||
|
|
||||||
#TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore,
|
# TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore,
|
||||||
# because it needs attribute such as unique_id, which are only present after init
|
# because it needs attribute such as unique_id, which are only present after init
|
||||||
new_agents = [self._agent_from_dict(agent) for agent in lst]
|
new_agents = [self._agent_from_dict(agent) for agent in lst]
|
||||||
|
|
||||||
|
|
||||||
for a in new_agents:
|
for a in new_agents:
|
||||||
self.schedule.add(a)
|
self.schedule.add(a)
|
||||||
|
|
||||||
@ -122,8 +122,7 @@ class BaseEnvironment(Model):
|
|||||||
setattr(agent, attr, value)
|
setattr(agent, attr, value)
|
||||||
|
|
||||||
def _agent_dict_from_config(self, cfg):
|
def _agent_dict_from_config(self, cfg):
|
||||||
return agentmod.from_config(cfg,
|
return agentmod.from_config(cfg, random=self.random)
|
||||||
random=self.random)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def agents(self):
|
def agents(self):
|
||||||
@ -139,18 +138,16 @@ class BaseEnvironment(Model):
|
|||||||
def now(self):
|
def now(self):
|
||||||
if self.schedule:
|
if self.schedule:
|
||||||
return self.schedule.time
|
return self.schedule.time
|
||||||
raise Exception('The environment has not been scheduled, so it has no sense of time')
|
raise Exception(
|
||||||
|
"The environment has not been scheduled, so it has no sense of time"
|
||||||
|
)
|
||||||
|
|
||||||
def add_agent(self, agent_class, unique_id=None, **kwargs):
|
def add_agent(self, agent_class, unique_id=None, **kwargs):
|
||||||
a = None
|
a = None
|
||||||
if unique_id is None:
|
if unique_id is None:
|
||||||
unique_id = self.next_id()
|
unique_id = self.next_id()
|
||||||
|
|
||||||
|
a = agent_class(model=self, unique_id=unique_id, **args)
|
||||||
a = agent_class(model=self,
|
|
||||||
unique_id=unique_id,
|
|
||||||
**args)
|
|
||||||
|
|
||||||
self.schedule.add(a)
|
self.schedule.add(a)
|
||||||
return a
|
return a
|
||||||
@ -163,16 +160,16 @@ class BaseEnvironment(Model):
|
|||||||
for k, v in kwargs:
|
for k, v in kwargs:
|
||||||
message += " {k}={v} ".format(k, v)
|
message += " {k}={v} ".format(k, v)
|
||||||
extra = {}
|
extra = {}
|
||||||
extra['now'] = self.now
|
extra["now"] = self.now
|
||||||
extra['id'] = self.id
|
extra["id"] = self.id
|
||||||
return self.logger.log(level, message, extra=extra)
|
return self.logger.log(level, message, extra=extra)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
'''
|
"""
|
||||||
Advance one step in the simulation, and update the data collection and scheduler appropriately
|
Advance one step in the simulation, and update the data collection and scheduler appropriately
|
||||||
'''
|
"""
|
||||||
super().step()
|
super().step()
|
||||||
self.logger.info(f'--- Step {self.now:^5} ---')
|
self.logger.info(f"--- Step {self.now:^5} ---")
|
||||||
self.schedule.step()
|
self.schedule.step()
|
||||||
self.datacollector.collect(self)
|
self.datacollector.collect(self)
|
||||||
|
|
||||||
@ -180,10 +177,10 @@ class BaseEnvironment(Model):
|
|||||||
return key in self.env_params
|
return key in self.env_params
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
'''
|
"""
|
||||||
Get the value of an environment attribute.
|
Get the value of an environment attribute.
|
||||||
Return `default` if the value is not set.
|
Return `default` if the value is not set.
|
||||||
'''
|
"""
|
||||||
return self.env_params.get(key, default)
|
return self.env_params.get(key, default)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
@ -197,13 +194,15 @@ class BaseEnvironment(Model):
|
|||||||
|
|
||||||
|
|
||||||
class NetworkEnvironment(BaseEnvironment):
|
class NetworkEnvironment(BaseEnvironment):
|
||||||
'''
|
"""
|
||||||
The NetworkEnvironment is an environment that includes one or more networkx.Graph intances
|
The NetworkEnvironment is an environment that includes one or more networkx.Graph intances
|
||||||
and methods to associate agents to nodes and vice versa.
|
and methods to associate agents to nodes and vice versa.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, topology: Union[config.NetConfig, nx.Graph] = None, **kwargs):
|
def __init__(
|
||||||
agents = kwargs.pop('agents', None)
|
self, *args, topology: Union[config.NetConfig, nx.Graph] = None, **kwargs
|
||||||
|
):
|
||||||
|
agents = kwargs.pop("agents", None)
|
||||||
super().__init__(*args, agents=None, **kwargs)
|
super().__init__(*args, agents=None, **kwargs)
|
||||||
|
|
||||||
self._set_topology(topology)
|
self._set_topology(topology)
|
||||||
@ -211,37 +210,35 @@ class NetworkEnvironment(BaseEnvironment):
|
|||||||
self.init_agents(agents)
|
self.init_agents(agents)
|
||||||
|
|
||||||
def init_agents(self, *args, **kwargs):
|
def init_agents(self, *args, **kwargs):
|
||||||
'''Initialize the agents from a '''
|
"""Initialize the agents from a"""
|
||||||
super().init_agents(*args, **kwargs)
|
super().init_agents(*args, **kwargs)
|
||||||
for agent in self.schedule._agents.values():
|
for agent in self.schedule._agents.values():
|
||||||
if hasattr(agent, 'node_id'):
|
if hasattr(agent, "node_id"):
|
||||||
self._init_node(agent)
|
self._init_node(agent)
|
||||||
|
|
||||||
def _init_node(self, agent):
|
def _init_node(self, agent):
|
||||||
'''
|
"""
|
||||||
Make sure the node for a given agent has the proper attributes.
|
Make sure the node for a given agent has the proper attributes.
|
||||||
'''
|
"""
|
||||||
self.G.nodes[agent.node_id]['agent'] = agent
|
self.G.nodes[agent.node_id]["agent"] = agent
|
||||||
|
|
||||||
def _agent_dict_from_config(self, cfg):
|
def _agent_dict_from_config(self, cfg):
|
||||||
return agentmod.from_config(cfg,
|
return agentmod.from_config(cfg, topology=self.G, random=self.random)
|
||||||
topology=self.G,
|
|
||||||
random=self.random)
|
|
||||||
|
|
||||||
def _agent_from_dict(self, agent, unique_id=None):
|
def _agent_from_dict(self, agent, unique_id=None):
|
||||||
agent = dict(agent)
|
agent = dict(agent)
|
||||||
|
|
||||||
if not agent.get('topology', False):
|
if not agent.get("topology", False):
|
||||||
return super()._agent_from_dict(agent)
|
return super()._agent_from_dict(agent)
|
||||||
|
|
||||||
if unique_id is None:
|
if unique_id is None:
|
||||||
unique_id = self.next_id()
|
unique_id = self.next_id()
|
||||||
node_id = agent.get('node_id', None)
|
node_id = agent.get("node_id", None)
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
node_id = network.find_unassigned(self.G, random=self.random)
|
node_id = network.find_unassigned(self.G, random=self.random)
|
||||||
agent['node_id'] = node_id
|
agent["node_id"] = node_id
|
||||||
agent['unique_id'] = unique_id
|
agent["unique_id"] = unique_id
|
||||||
agent['topology'] = self.G
|
agent["topology"] = self.G
|
||||||
node_attrs = self.G.nodes[node_id]
|
node_attrs = self.G.nodes[node_id]
|
||||||
node_attrs.update(agent)
|
node_attrs.update(agent)
|
||||||
agent = node_attrs
|
agent = node_attrs
|
||||||
@ -269,32 +266,33 @@ class NetworkEnvironment(BaseEnvironment):
|
|||||||
if unique_id is None:
|
if unique_id is None:
|
||||||
unique_id = self.next_id()
|
unique_id = self.next_id()
|
||||||
if node_id is None:
|
if node_id is None:
|
||||||
node_id = network.find_unassigned(G=self.G,
|
node_id = network.find_unassigned(
|
||||||
shuffle=True,
|
G=self.G, shuffle=True, random=self.random
|
||||||
random=self.random)
|
)
|
||||||
|
|
||||||
if node_id in G.nodes:
|
if node_id in G.nodes:
|
||||||
self.G.nodes[node_id]['agent'] = None # Reserve
|
self.G.nodes[node_id]["agent"] = None # Reserve
|
||||||
else:
|
else:
|
||||||
self.G.add_node(node_id)
|
self.G.add_node(node_id)
|
||||||
|
|
||||||
a = self.add_agent(unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs)
|
a = self.add_agent(
|
||||||
a['visible'] = True
|
unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs
|
||||||
|
)
|
||||||
|
a["visible"] = True
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def agent_for_node_id(self, node_id):
|
def agent_for_node_id(self, node_id):
|
||||||
return self.G.nodes[node_id].get('agent')
|
return self.G.nodes[node_id].get("agent")
|
||||||
|
|
||||||
def populate_network(self, agent_class, weights=None, **agent_params):
|
def populate_network(self, agent_class, weights=None, **agent_params):
|
||||||
if not hasattr(agent_class, 'len'):
|
if not hasattr(agent_class, "len"):
|
||||||
agent_class = [agent_class]
|
agent_class = [agent_class]
|
||||||
weights = None
|
weights = None
|
||||||
for (node_id, node) in self.G.nodes(data=True):
|
for (node_id, node) in self.G.nodes(data=True):
|
||||||
if 'agent' in node:
|
if "agent" in node:
|
||||||
continue
|
continue
|
||||||
a_class = self.random.choices(agent_class, weights)[0]
|
a_class = self.random.choices(agent_class, weights)[0]
|
||||||
self.add_agent(node_id=node_id,
|
self.add_agent(node_id=node_id, agent_class=a_class, **agent_params)
|
||||||
agent_class=a_class, **agent_params)
|
|
||||||
|
|
||||||
|
|
||||||
Environment = NetworkEnvironment
|
Environment = NetworkEnvironment
|
||||||
|
@ -24,56 +24,58 @@ class DryRunner(BytesIO):
|
|||||||
|
|
||||||
def write(self, txt):
|
def write(self, txt):
|
||||||
if self.__copy_to:
|
if self.__copy_to:
|
||||||
self.__copy_to.write('{}:::{}'.format(self.__fname, txt))
|
self.__copy_to.write("{}:::{}".format(self.__fname, txt))
|
||||||
try:
|
try:
|
||||||
super().write(txt)
|
super().write(txt)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
super().write(bytes(txt, 'utf-8'))
|
super().write(bytes(txt, "utf-8"))
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
content = '(binary data not shown)'
|
content = "(binary data not shown)"
|
||||||
try:
|
try:
|
||||||
content = self.getvalue().decode()
|
content = self.getvalue().decode()
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
pass
|
pass
|
||||||
logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname, content))
|
logger.info(
|
||||||
|
"**Not** written to {} (dry run mode):\n\n{}\n\n".format(
|
||||||
|
self.__fname, content
|
||||||
|
)
|
||||||
|
)
|
||||||
super().close()
|
super().close()
|
||||||
|
|
||||||
|
|
||||||
class Exporter:
|
class Exporter:
|
||||||
'''
|
"""
|
||||||
Interface for all exporters. It is not necessary, but it is useful
|
Interface for all exporters. It is not necessary, but it is useful
|
||||||
if you don't plan to implement all the methods.
|
if you don't plan to implement all the methods.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
|
def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None):
|
||||||
self.simulation = simulation
|
self.simulation = simulation
|
||||||
outdir = outdir or os.path.join(os.getcwd(), 'soil_output')
|
outdir = outdir or os.path.join(os.getcwd(), "soil_output")
|
||||||
self.outdir = os.path.join(outdir,
|
self.outdir = os.path.join(outdir, simulation.group or "", simulation.name)
|
||||||
simulation.group or '',
|
|
||||||
simulation.name)
|
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
if copy_to is None and dry_run:
|
if copy_to is None and dry_run:
|
||||||
copy_to = sys.stdout
|
copy_to = sys.stdout
|
||||||
self.copy_to = copy_to
|
self.copy_to = copy_to
|
||||||
|
|
||||||
def sim_start(self):
|
def sim_start(self):
|
||||||
'''Method to call when the simulation starts'''
|
"""Method to call when the simulation starts"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sim_end(self):
|
def sim_end(self):
|
||||||
'''Method to call when the simulation ends'''
|
"""Method to call when the simulation ends"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def trial_start(self, env):
|
def trial_start(self, env):
|
||||||
'''Method to call when a trial start'''
|
"""Method to call when a trial start"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
'''Method to call when a trial ends'''
|
"""Method to call when a trial ends"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def output(self, f, mode='w', **kwargs):
|
def output(self, f, mode="w", **kwargs):
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
f = DryRunner(f, copy_to=self.copy_to)
|
f = DryRunner(f, copy_to=self.copy_to)
|
||||||
else:
|
else:
|
||||||
@ -86,35 +88,38 @@ class Exporter:
|
|||||||
|
|
||||||
|
|
||||||
class default(Exporter):
|
class default(Exporter):
|
||||||
'''Default exporter. Writes sqlite results, as well as the simulation YAML'''
|
"""Default exporter. Writes sqlite results, as well as the simulation YAML"""
|
||||||
|
|
||||||
def sim_start(self):
|
def sim_start(self):
|
||||||
if not self.dry_run:
|
if not self.dry_run:
|
||||||
logger.info('Dumping results to %s', self.outdir)
|
logger.info("Dumping results to %s", self.outdir)
|
||||||
with self.output(self.simulation.name + '.dumped.yml') as f:
|
with self.output(self.simulation.name + ".dumped.yml") as f:
|
||||||
f.write(self.simulation.to_yaml())
|
f.write(self.simulation.to_yaml())
|
||||||
else:
|
else:
|
||||||
logger.info('NOT dumping results')
|
logger.info("NOT dumping results")
|
||||||
|
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
logger.info('Running in DRY_RUN mode, the database will NOT be created')
|
logger.info("Running in DRY_RUN mode, the database will NOT be created")
|
||||||
return
|
return
|
||||||
|
|
||||||
with timer('Dumping simulation {} trial {}'.format(self.simulation.name,
|
with timer(
|
||||||
env.id)):
|
"Dumping simulation {} trial {}".format(self.simulation.name, env.id)
|
||||||
|
):
|
||||||
|
|
||||||
fpath = os.path.join(self.outdir, f'{env.id}.sqlite')
|
fpath = os.path.join(self.outdir, f"{env.id}.sqlite")
|
||||||
engine = create_engine(f'sqlite:///{fpath}', echo=False)
|
engine = create_engine(f"sqlite:///{fpath}", echo=False)
|
||||||
|
|
||||||
dc = env.datacollector
|
dc = env.datacollector
|
||||||
for (t, df) in get_dc_dfs(dc):
|
for (t, df) in get_dc_dfs(dc):
|
||||||
df.to_sql(t, con=engine, if_exists='append')
|
df.to_sql(t, con=engine, if_exists="append")
|
||||||
|
|
||||||
|
|
||||||
def get_dc_dfs(dc):
|
def get_dc_dfs(dc):
|
||||||
dfs = {'env': dc.get_model_vars_dataframe(),
|
dfs = {
|
||||||
'agents': dc.get_agent_vars_dataframe() }
|
"env": dc.get_model_vars_dataframe(),
|
||||||
|
"agents": dc.get_agent_vars_dataframe(),
|
||||||
|
}
|
||||||
for table_name in dc.tables:
|
for table_name in dc.tables:
|
||||||
dfs[table_name] = dc.get_table_dataframe(table_name)
|
dfs[table_name] = dc.get_table_dataframe(table_name)
|
||||||
yield from dfs.items()
|
yield from dfs.items()
|
||||||
@ -122,66 +127,78 @@ def get_dc_dfs(dc):
|
|||||||
|
|
||||||
class csv(Exporter):
|
class csv(Exporter):
|
||||||
|
|
||||||
'''Export the state of each environment (and its agents) in a separate CSV file'''
|
"""Export the state of each environment (and its agents) in a separate CSV file"""
|
||||||
|
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name,
|
with timer(
|
||||||
env.id,
|
"[CSV] Dumping simulation {} trial {} @ dir {}".format(
|
||||||
self.outdir)):
|
self.simulation.name, env.id, self.outdir
|
||||||
|
)
|
||||||
|
):
|
||||||
for (df_name, df) in get_dc_dfs(env.datacollector):
|
for (df_name, df) in get_dc_dfs(env.datacollector):
|
||||||
with self.output('{}.{}.csv'.format(env.id, df_name)) as f:
|
with self.output("{}.{}.csv".format(env.id, df_name)) as f:
|
||||||
df.to_csv(f)
|
df.to_csv(f)
|
||||||
|
|
||||||
|
|
||||||
#TODO: reimplement GEXF exporting without history
|
# TODO: reimplement GEXF exporting without history
|
||||||
class gexf(Exporter):
|
class gexf(Exporter):
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
logger.info('Not dumping GEXF in dry_run mode')
|
logger.info("Not dumping GEXF in dry_run mode")
|
||||||
return
|
return
|
||||||
|
|
||||||
with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name,
|
with timer(
|
||||||
env.id)):
|
"[GEXF] Dumping simulation {} trial {}".format(self.simulation.name, env.id)
|
||||||
with self.output('{}.gexf'.format(env.id), mode='wb') as f:
|
):
|
||||||
|
with self.output("{}.gexf".format(env.id), mode="wb") as f:
|
||||||
network.dump_gexf(env.history_to_graph(), f)
|
network.dump_gexf(env.history_to_graph(), f)
|
||||||
self.dump_gexf(env, f)
|
self.dump_gexf(env, f)
|
||||||
|
|
||||||
|
|
||||||
class dummy(Exporter):
|
class dummy(Exporter):
|
||||||
|
|
||||||
def sim_start(self):
|
def sim_start(self):
|
||||||
with self.output('dummy', 'w') as f:
|
with self.output("dummy", "w") as f:
|
||||||
f.write('simulation started @ {}\n'.format(current_time()))
|
f.write("simulation started @ {}\n".format(current_time()))
|
||||||
|
|
||||||
def trial_start(self, env):
|
def trial_start(self, env):
|
||||||
with self.output('dummy', 'w') as f:
|
with self.output("dummy", "w") as f:
|
||||||
f.write('trial started@ {}\n'.format(current_time()))
|
f.write("trial started@ {}\n".format(current_time()))
|
||||||
|
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
with self.output('dummy', 'w') as f:
|
with self.output("dummy", "w") as f:
|
||||||
f.write('trial ended@ {}\n'.format(current_time()))
|
f.write("trial ended@ {}\n".format(current_time()))
|
||||||
|
|
||||||
def sim_end(self):
|
def sim_end(self):
|
||||||
with self.output('dummy', 'a') as f:
|
with self.output("dummy", "a") as f:
|
||||||
f.write('simulation ended @ {}\n'.format(current_time()))
|
f.write("simulation ended @ {}\n".format(current_time()))
|
||||||
|
|
||||||
|
|
||||||
class graphdrawing(Exporter):
|
class graphdrawing(Exporter):
|
||||||
|
|
||||||
def trial_end(self, env):
|
def trial_end(self, env):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
f = plt.figure()
|
f = plt.figure()
|
||||||
nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111))
|
nx.draw(
|
||||||
with open('graph-{}.png'.format(env.id)) as f:
|
env.G,
|
||||||
|
node_size=10,
|
||||||
|
width=0.2,
|
||||||
|
pos=nx.spring_layout(env.G, scale=100),
|
||||||
|
ax=f.add_subplot(111),
|
||||||
|
)
|
||||||
|
with open("graph-{}.png".format(env.id)) as f:
|
||||||
f.savefig(f)
|
f.savefig(f)
|
||||||
|
|
||||||
'''
|
|
||||||
|
"""
|
||||||
Convert an environment into a NetworkX graph
|
Convert an environment into a NetworkX graph
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def env_to_graph(env, history=None):
|
def env_to_graph(env, history=None):
|
||||||
G = nx.Graph(env.G)
|
G = nx.Graph(env.G)
|
||||||
|
|
||||||
for agent in env.network_agents:
|
for agent in env.network_agents:
|
||||||
|
|
||||||
attributes = {'agent': str(agent.__class__)}
|
attributes = {"agent": str(agent.__class__)}
|
||||||
lastattributes = {}
|
lastattributes = {}
|
||||||
spells = []
|
spells = []
|
||||||
lastvisible = False
|
lastvisible = False
|
||||||
@ -189,7 +206,7 @@ def env_to_graph(env, history=None):
|
|||||||
if not history:
|
if not history:
|
||||||
history = sorted(list(env.state_to_tuples()))
|
history = sorted(list(env.state_to_tuples()))
|
||||||
for _, t_step, attribute, value in history:
|
for _, t_step, attribute, value in history:
|
||||||
if attribute == 'visible':
|
if attribute == "visible":
|
||||||
nowvisible = value
|
nowvisible = value
|
||||||
if nowvisible and not lastvisible:
|
if nowvisible and not lastvisible:
|
||||||
laststep = t_step
|
laststep = t_step
|
||||||
@ -198,7 +215,7 @@ def env_to_graph(env, history=None):
|
|||||||
|
|
||||||
lastvisible = nowvisible
|
lastvisible = nowvisible
|
||||||
continue
|
continue
|
||||||
key = 'attr_' + attribute
|
key = "attr_" + attribute
|
||||||
if key not in attributes:
|
if key not in attributes:
|
||||||
attributes[key] = list()
|
attributes[key] = list()
|
||||||
if key not in lastattributes:
|
if key not in lastattributes:
|
||||||
|
@ -9,6 +9,7 @@ import networkx as nx
|
|||||||
|
|
||||||
from . import config, serialization, basestring
|
from . import config, serialization, basestring
|
||||||
|
|
||||||
|
|
||||||
def from_config(cfg: config.NetConfig, dir_path: str = None):
|
def from_config(cfg: config.NetConfig, dir_path: str = None):
|
||||||
if not isinstance(cfg, config.NetConfig):
|
if not isinstance(cfg, config.NetConfig):
|
||||||
cfg = config.NetConfig(**cfg)
|
cfg = config.NetConfig(**cfg)
|
||||||
@ -19,24 +20,28 @@ def from_config(cfg: config.NetConfig, dir_path: str = None):
|
|||||||
path = os.path.join(dir_path, path)
|
path = os.path.join(dir_path, path)
|
||||||
extension = os.path.splitext(path)[1][1:]
|
extension = os.path.splitext(path)[1][1:]
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if extension == 'gexf':
|
if extension == "gexf":
|
||||||
kwargs['version'] = '1.2draft'
|
kwargs["version"] = "1.2draft"
|
||||||
kwargs['node_type'] = int
|
kwargs["node_type"] = int
|
||||||
try:
|
try:
|
||||||
method = getattr(nx.readwrite, 'read_' + extension)
|
method = getattr(nx.readwrite, "read_" + extension)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise AttributeError('Unknown format')
|
raise AttributeError("Unknown format")
|
||||||
return method(path, **kwargs)
|
return method(path, **kwargs)
|
||||||
|
|
||||||
if cfg.params:
|
if cfg.params:
|
||||||
net_args = cfg.params.dict()
|
net_args = cfg.params.dict()
|
||||||
net_gen = net_args.pop('generator')
|
net_gen = net_args.pop("generator")
|
||||||
|
|
||||||
if dir_path not in sys.path:
|
if dir_path not in sys.path:
|
||||||
sys.path.append(dir_path)
|
sys.path.append(dir_path)
|
||||||
|
|
||||||
method = serialization.deserializer(net_gen,
|
method = serialization.deserializer(
|
||||||
known_modules=['networkx.generators',])
|
net_gen,
|
||||||
|
known_modules=[
|
||||||
|
"networkx.generators",
|
||||||
|
],
|
||||||
|
)
|
||||||
return method(**net_args)
|
return method(**net_args)
|
||||||
|
|
||||||
if isinstance(cfg.fixed, config.Topology):
|
if isinstance(cfg.fixed, config.Topology):
|
||||||
@ -49,17 +54,17 @@ def from_config(cfg: config.NetConfig, dir_path: str = None):
|
|||||||
|
|
||||||
|
|
||||||
def find_unassigned(G, shuffle=False, random=random):
|
def find_unassigned(G, shuffle=False, random=random):
|
||||||
'''
|
"""
|
||||||
Link an agent to a node in a topology.
|
Link an agent to a node in a topology.
|
||||||
|
|
||||||
If node_id is None, a node without an agent_id will be found.
|
If node_id is None, a node without an agent_id will be found.
|
||||||
'''
|
"""
|
||||||
#TODO: test
|
# TODO: test
|
||||||
candidates = list(G.nodes(data=True))
|
candidates = list(G.nodes(data=True))
|
||||||
if shuffle:
|
if shuffle:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
for next_id, data in candidates:
|
for next_id, data in candidates:
|
||||||
if 'agent' not in data:
|
if "agent" not in data:
|
||||||
node_id = next_id
|
node_id = next_id
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -68,8 +73,14 @@ def find_unassigned(G, shuffle=False, random=random):
|
|||||||
|
|
||||||
def dump_gexf(G, f):
|
def dump_gexf(G, f):
|
||||||
for node in G.nodes():
|
for node in G.nodes():
|
||||||
if 'pos' in G.nodes[node]:
|
if "pos" in G.nodes[node]:
|
||||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
G.nodes[node]["viz"] = {
|
||||||
del (G.nodes[node]['pos'])
|
"position": {
|
||||||
|
"x": G.nodes[node]["pos"][0],
|
||||||
|
"y": G.nodes[node]["pos"][1],
|
||||||
|
"z": 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
del G.nodes[node]["pos"]
|
||||||
|
|
||||||
nx.write_gexf(G, f, version="1.2draft")
|
nx.write_gexf(G, f, version="1.2draft")
|
||||||
|
@ -15,13 +15,14 @@ import networkx as nx
|
|||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger('soil')
|
logger = logging.getLogger("soil")
|
||||||
|
|
||||||
|
|
||||||
def load_file(infile):
|
def load_file(infile):
|
||||||
folder = os.path.dirname(infile)
|
folder = os.path.dirname(infile)
|
||||||
if folder not in sys.path:
|
if folder not in sys.path:
|
||||||
sys.path.append(folder)
|
sys.path.append(folder)
|
||||||
with open(infile, 'r') as f:
|
with open(infile, "r") as f:
|
||||||
return list(chain.from_iterable(map(expand_template, load_string(f))))
|
return list(chain.from_iterable(map(expand_template, load_string(f))))
|
||||||
|
|
||||||
|
|
||||||
@ -30,14 +31,15 @@ def load_string(string):
|
|||||||
|
|
||||||
|
|
||||||
def expand_template(config):
|
def expand_template(config):
|
||||||
if 'template' not in config:
|
if "template" not in config:
|
||||||
yield config
|
yield config
|
||||||
return
|
return
|
||||||
if 'vars' not in config:
|
if "vars" not in config:
|
||||||
raise ValueError(('You must provide a definition of variables'
|
raise ValueError(
|
||||||
' for the template.'))
|
("You must provide a definition of variables" " for the template.")
|
||||||
|
)
|
||||||
|
|
||||||
template = config['template']
|
template = config["template"]
|
||||||
|
|
||||||
if not isinstance(template, str):
|
if not isinstance(template, str):
|
||||||
template = yaml.dump(template)
|
template = yaml.dump(template)
|
||||||
@ -49,9 +51,9 @@ def expand_template(config):
|
|||||||
blank_str = template.render({k: 0 for k in params[0].keys()})
|
blank_str = template.render({k: 0 for k in params[0].keys()})
|
||||||
blank = list(load_string(blank_str))
|
blank = list(load_string(blank_str))
|
||||||
if len(blank) > 1:
|
if len(blank) > 1:
|
||||||
raise ValueError('Templates must not return more than one configuration')
|
raise ValueError("Templates must not return more than one configuration")
|
||||||
if 'name' in blank[0]:
|
if "name" in blank[0]:
|
||||||
raise ValueError('Templates cannot be named, use group instead')
|
raise ValueError("Templates cannot be named, use group instead")
|
||||||
|
|
||||||
for ps in params:
|
for ps in params:
|
||||||
string = template.render(ps)
|
string = template.render(ps)
|
||||||
@ -60,25 +62,25 @@ def expand_template(config):
|
|||||||
|
|
||||||
|
|
||||||
def params_for_template(config):
|
def params_for_template(config):
|
||||||
sampler_config = config.get('sampler', {'N': 100})
|
sampler_config = config.get("sampler", {"N": 100})
|
||||||
sampler = sampler_config.pop('method', 'SALib.sample.morris.sample')
|
sampler = sampler_config.pop("method", "SALib.sample.morris.sample")
|
||||||
sampler = deserializer(sampler)
|
sampler = deserializer(sampler)
|
||||||
bounds = config['vars']['bounds']
|
bounds = config["vars"]["bounds"]
|
||||||
|
|
||||||
problem = {
|
problem = {
|
||||||
'num_vars': len(bounds),
|
"num_vars": len(bounds),
|
||||||
'names': list(bounds.keys()),
|
"names": list(bounds.keys()),
|
||||||
'bounds': list(v for v in bounds.values())
|
"bounds": list(v for v in bounds.values()),
|
||||||
}
|
}
|
||||||
samples = sampler(problem, **sampler_config)
|
samples = sampler(problem, **sampler_config)
|
||||||
|
|
||||||
lists = config['vars'].get('lists', {})
|
lists = config["vars"].get("lists", {})
|
||||||
names = list(lists.keys())
|
names = list(lists.keys())
|
||||||
values = list(lists.values())
|
values = list(lists.values())
|
||||||
combs = list(product(*values))
|
combs = list(product(*values))
|
||||||
|
|
||||||
allnames = names + problem['names']
|
allnames = names + problem["names"]
|
||||||
allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)]
|
allvalues = [(list(i[0]) + list(i[1])) for i in product(combs, samples)]
|
||||||
params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
|
params = list(map(lambda x: dict(zip(allnames, x)), allvalues))
|
||||||
return params
|
return params
|
||||||
|
|
||||||
@ -100,22 +102,24 @@ def load_config(cfg):
|
|||||||
yield from load_files(cfg)
|
yield from load_files(cfg)
|
||||||
|
|
||||||
|
|
||||||
builtins = importlib.import_module('builtins')
|
builtins = importlib.import_module("builtins")
|
||||||
|
|
||||||
KNOWN_MODULES = ['soil', ]
|
KNOWN_MODULES = [
|
||||||
|
"soil",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def name(value, known_modules=KNOWN_MODULES):
|
def name(value, known_modules=KNOWN_MODULES):
|
||||||
'''Return a name that can be imported, to serialize/deserialize an object'''
|
"""Return a name that can be imported, to serialize/deserialize an object"""
|
||||||
if value is None:
|
if value is None:
|
||||||
return 'None'
|
return "None"
|
||||||
if not isinstance(value, type): # Get the class name first
|
if not isinstance(value, type): # Get the class name first
|
||||||
value = type(value)
|
value = type(value)
|
||||||
tname = value.__name__
|
tname = value.__name__
|
||||||
if hasattr(builtins, tname):
|
if hasattr(builtins, tname):
|
||||||
return tname
|
return tname
|
||||||
modname = value.__module__
|
modname = value.__module__
|
||||||
if modname == '__main__':
|
if modname == "__main__":
|
||||||
return tname
|
return tname
|
||||||
if known_modules and modname in known_modules:
|
if known_modules and modname in known_modules:
|
||||||
return tname
|
return tname
|
||||||
@ -125,17 +129,17 @@ def name(value, known_modules=KNOWN_MODULES):
|
|||||||
module = importlib.import_module(kmod)
|
module = importlib.import_module(kmod)
|
||||||
if hasattr(module, tname):
|
if hasattr(module, tname):
|
||||||
return tname
|
return tname
|
||||||
return '{}.{}'.format(modname, tname)
|
return "{}.{}".format(modname, tname)
|
||||||
|
|
||||||
|
|
||||||
def serializer(type_):
|
def serializer(type_):
|
||||||
if type_ != 'str' and hasattr(builtins, type_):
|
if type_ != "str" and hasattr(builtins, type_):
|
||||||
return repr
|
return repr
|
||||||
return lambda x: x
|
return lambda x: x
|
||||||
|
|
||||||
|
|
||||||
def serialize(v, known_modules=KNOWN_MODULES):
|
def serialize(v, known_modules=KNOWN_MODULES):
|
||||||
'''Get a text representation of an object.'''
|
"""Get a text representation of an object."""
|
||||||
tname = name(v, known_modules=known_modules)
|
tname = name(v, known_modules=known_modules)
|
||||||
func = serializer(tname)
|
func = serializer(tname)
|
||||||
return func(v), tname
|
return func(v), tname
|
||||||
@ -160,9 +164,9 @@ IS_CLASS = re.compile(r"<class '(.*)'>")
|
|||||||
def deserializer(type_, known_modules=KNOWN_MODULES):
|
def deserializer(type_, known_modules=KNOWN_MODULES):
|
||||||
if type(type_) != str: # Already deserialized
|
if type(type_) != str: # Already deserialized
|
||||||
return type_
|
return type_
|
||||||
if type_ == 'str':
|
if type_ == "str":
|
||||||
return lambda x='': x
|
return lambda x="": x
|
||||||
if type_ == 'None':
|
if type_ == "None":
|
||||||
return lambda x=None: None
|
return lambda x=None: None
|
||||||
if hasattr(builtins, type_): # Check if it's a builtin type
|
if hasattr(builtins, type_): # Check if it's a builtin type
|
||||||
cls = getattr(builtins, type_)
|
cls = getattr(builtins, type_)
|
||||||
@ -172,7 +176,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
|||||||
modname, tname = match.group(1).rsplit(".", 1)
|
modname, tname = match.group(1).rsplit(".", 1)
|
||||||
module = importlib.import_module(modname)
|
module = importlib.import_module(modname)
|
||||||
cls = getattr(module, tname)
|
cls = getattr(module, tname)
|
||||||
return getattr(cls, 'deserialize', cls)
|
return getattr(cls, "deserialize", cls)
|
||||||
|
|
||||||
# Otherwise, see if we can find the module and the class
|
# Otherwise, see if we can find the module and the class
|
||||||
options = []
|
options = []
|
||||||
@ -181,7 +185,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
|||||||
if mod:
|
if mod:
|
||||||
options.append((mod, type_))
|
options.append((mod, type_))
|
||||||
|
|
||||||
if '.' in type_: # Fully qualified module
|
if "." in type_: # Fully qualified module
|
||||||
module, type_ = type_.rsplit(".", 1)
|
module, type_ = type_.rsplit(".", 1)
|
||||||
options.append((module, type_))
|
options.append((module, type_))
|
||||||
|
|
||||||
@ -190,32 +194,31 @@ def deserializer(type_, known_modules=KNOWN_MODULES):
|
|||||||
try:
|
try:
|
||||||
module = importlib.import_module(modname)
|
module = importlib.import_module(modname)
|
||||||
cls = getattr(module, tname)
|
cls = getattr(module, tname)
|
||||||
return getattr(cls, 'deserialize', cls)
|
return getattr(cls, "deserialize", cls)
|
||||||
except (ImportError, AttributeError) as ex:
|
except (ImportError, AttributeError) as ex:
|
||||||
errors.append((modname, tname, ex))
|
errors.append((modname, tname, ex))
|
||||||
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors))
|
raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors))
|
||||||
|
|
||||||
|
|
||||||
def deserialize(type_, value=None, globs=None, **kwargs):
|
def deserialize(type_, value=None, globs=None, **kwargs):
|
||||||
'''Get an object from a text representation'''
|
"""Get an object from a text representation"""
|
||||||
if not isinstance(type_, str):
|
if not isinstance(type_, str):
|
||||||
return type_
|
return type_
|
||||||
if globs and type_ in globs:
|
if globs and type_ in globs:
|
||||||
des = globs[type_]
|
des = globs[type_]
|
||||||
else:
|
else:
|
||||||
des = deserializer(type_, **kwargs)
|
des = deserializer(type_, **kwargs)
|
||||||
if value is None:
|
if value is None:
|
||||||
return des
|
return des
|
||||||
return des(value)
|
return des(value)
|
||||||
|
|
||||||
|
|
||||||
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
|
def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs):
|
||||||
'''Return the list of deserialized objects'''
|
"""Return the list of deserialized objects"""
|
||||||
#TODO: remove
|
# TODO: remove
|
||||||
print('SERIALIZATION', kwargs)
|
print("SERIALIZATION", kwargs)
|
||||||
objects = []
|
objects = []
|
||||||
for name in names:
|
for name in names:
|
||||||
mod = deserialize(name, known_modules=known_modules)
|
mod = deserialize(name, known_modules=known_modules)
|
||||||
objects.append(mod(*args, **kwargs))
|
objects.append(mod(*args, **kwargs))
|
||||||
return objects
|
return objects
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ from .time import INFINITY
|
|||||||
from .config import Config, convert_old
|
from .config import Config, convert_old
|
||||||
|
|
||||||
|
|
||||||
#TODO: change documentation for simulation
|
# TODO: change documentation for simulation
|
||||||
@dataclass
|
@dataclass
|
||||||
class Simulation:
|
class Simulation:
|
||||||
"""
|
"""
|
||||||
@ -36,15 +36,16 @@ class Simulation:
|
|||||||
|
|
||||||
kwargs: parameters to use to initialize a new configuration, if one not been provided.
|
kwargs: parameters to use to initialize a new configuration, if one not been provided.
|
||||||
"""
|
"""
|
||||||
version: str = '2'
|
|
||||||
name: str = 'Unnamed simulation'
|
version: str = "2"
|
||||||
description: Optional[str] = ''
|
name: str = "Unnamed simulation"
|
||||||
|
description: Optional[str] = ""
|
||||||
group: str = None
|
group: str = None
|
||||||
model_class: Union[str, type] = 'soil.Environment'
|
model_class: Union[str, type] = "soil.Environment"
|
||||||
model_params: dict = field(default_factory=dict)
|
model_params: dict = field(default_factory=dict)
|
||||||
seed: str = field(default_factory=lambda: current_time())
|
seed: str = field(default_factory=lambda: current_time())
|
||||||
dir_path: str = field(default_factory=lambda: os.getcwd())
|
dir_path: str = field(default_factory=lambda: os.getcwd())
|
||||||
max_time: float = float('inf')
|
max_time: float = float("inf")
|
||||||
max_steps: int = -1
|
max_steps: int = -1
|
||||||
interval: int = 1
|
interval: int = 1
|
||||||
num_trials: int = 3
|
num_trials: int = 3
|
||||||
@ -58,12 +59,13 @@ class Simulation:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, env, **kwargs):
|
def from_dict(cls, env, **kwargs):
|
||||||
|
|
||||||
ignored = {k: v for k, v in env.items()
|
ignored = {
|
||||||
if k not in inspect.signature(cls).parameters}
|
k: v for k, v in env.items() if k not in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
|
||||||
d = {k:v for k, v in env.items() if k not in ignored}
|
d = {k: v for k, v in env.items() if k not in ignored}
|
||||||
if ignored:
|
if ignored:
|
||||||
d.setdefault('extra', {}).update(ignored)
|
d.setdefault("extra", {}).update(ignored)
|
||||||
if ignored:
|
if ignored:
|
||||||
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
|
print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }')
|
||||||
d.update(kwargs)
|
d.update(kwargs)
|
||||||
@ -74,24 +76,34 @@ class Simulation:
|
|||||||
return self.run(*args, **kwargs)
|
return self.run(*args, **kwargs)
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
'''Run the simulation and return the list of resulting environments'''
|
"""Run the simulation and return the list of resulting environments"""
|
||||||
logger.info(dedent('''
|
logger.info(
|
||||||
|
dedent(
|
||||||
|
"""
|
||||||
Simulation:
|
Simulation:
|
||||||
---
|
---
|
||||||
''') +
|
"""
|
||||||
self.to_yaml())
|
)
|
||||||
|
+ self.to_yaml()
|
||||||
|
)
|
||||||
return list(self.run_gen(*args, **kwargs))
|
return list(self.run_gen(*args, **kwargs))
|
||||||
|
|
||||||
def run_gen(self, parallel=False, dry_run=None,
|
def run_gen(
|
||||||
exporters=None, outdir=None, exporter_params={},
|
self,
|
||||||
log_level=None,
|
parallel=False,
|
||||||
**kwargs):
|
dry_run=None,
|
||||||
'''Run the simulation and yield the resulting environments.'''
|
exporters=None,
|
||||||
|
outdir=None,
|
||||||
|
exporter_params={},
|
||||||
|
log_level=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Run the simulation and yield the resulting environments."""
|
||||||
if log_level:
|
if log_level:
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
outdir = outdir or self.outdir
|
outdir = outdir or self.outdir
|
||||||
logger.info('Using exporters: %s', exporters or [])
|
logger.info("Using exporters: %s", exporters or [])
|
||||||
logger.info('Output directory: %s', outdir)
|
logger.info("Output directory: %s", outdir)
|
||||||
if dry_run is None:
|
if dry_run is None:
|
||||||
dry_run = self.dry_run
|
dry_run = self.dry_run
|
||||||
if exporters is None:
|
if exporters is None:
|
||||||
@ -99,22 +111,28 @@ class Simulation:
|
|||||||
if not exporter_params:
|
if not exporter_params:
|
||||||
exporter_params = self.exporter_params
|
exporter_params = self.exporter_params
|
||||||
|
|
||||||
exporters = serialization.deserialize_all(exporters,
|
exporters = serialization.deserialize_all(
|
||||||
simulation=self,
|
exporters,
|
||||||
known_modules=['soil.exporters', ],
|
simulation=self,
|
||||||
dry_run=dry_run,
|
known_modules=[
|
||||||
outdir=outdir,
|
"soil.exporters",
|
||||||
**exporter_params)
|
],
|
||||||
|
dry_run=dry_run,
|
||||||
|
outdir=outdir,
|
||||||
|
**exporter_params,
|
||||||
|
)
|
||||||
|
|
||||||
with utils.timer('simulation {}'.format(self.name)):
|
with utils.timer("simulation {}".format(self.name)):
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.sim_start()
|
exporter.sim_start()
|
||||||
|
|
||||||
for env in utils.run_parallel(func=self.run_trial,
|
for env in utils.run_parallel(
|
||||||
iterable=range(int(self.num_trials)),
|
func=self.run_trial,
|
||||||
parallel=parallel,
|
iterable=range(int(self.num_trials)),
|
||||||
log_level=log_level,
|
parallel=parallel,
|
||||||
**kwargs):
|
log_level=log_level,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.trial_start(env)
|
exporter.trial_start(env)
|
||||||
@ -128,11 +146,12 @@ class Simulation:
|
|||||||
exporter.sim_end()
|
exporter.sim_end()
|
||||||
|
|
||||||
def get_env(self, trial_id=0, model_params=None, **kwargs):
|
def get_env(self, trial_id=0, model_params=None, **kwargs):
|
||||||
'''Create an environment for a trial of the simulation'''
|
"""Create an environment for a trial of the simulation"""
|
||||||
|
|
||||||
def deserialize_reporters(reporters):
|
def deserialize_reporters(reporters):
|
||||||
for (k, v) in reporters.items():
|
for (k, v) in reporters.items():
|
||||||
if isinstance(v, str) and v.startswith('py:'):
|
if isinstance(v, str) and v.startswith("py:"):
|
||||||
reporters[k] = serialization.deserialize(value.lsplit(':', 1)[1])
|
reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1])
|
||||||
return reporters
|
return reporters
|
||||||
|
|
||||||
params = self.model_params.copy()
|
params = self.model_params.copy()
|
||||||
@ -140,18 +159,22 @@ class Simulation:
|
|||||||
params.update(model_params)
|
params.update(model_params)
|
||||||
params.update(kwargs)
|
params.update(kwargs)
|
||||||
|
|
||||||
agent_reporters = deserialize_reporters(params.pop('agent_reporters', {}))
|
agent_reporters = deserialize_reporters(params.pop("agent_reporters", {}))
|
||||||
model_reporters = deserialize_reporters(params.pop('model_reporters', {}))
|
model_reporters = deserialize_reporters(params.pop("model_reporters", {}))
|
||||||
|
|
||||||
env = serialization.deserialize(self.model_class)
|
env = serialization.deserialize(self.model_class)
|
||||||
return env(id=f'{self.name}_trial_{trial_id}',
|
return env(
|
||||||
seed=f'{self.seed}_trial_{trial_id}',
|
id=f"{self.name}_trial_{trial_id}",
|
||||||
dir_path=self.dir_path,
|
seed=f"{self.seed}_trial_{trial_id}",
|
||||||
agent_reporters=agent_reporters,
|
dir_path=self.dir_path,
|
||||||
model_reporters=model_reporters,
|
agent_reporters=agent_reporters,
|
||||||
**params)
|
model_reporters=model_reporters,
|
||||||
|
**params,
|
||||||
|
)
|
||||||
|
|
||||||
def run_trial(self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts):
|
def run_trial(
|
||||||
|
self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Run a single trial of the simulation
|
Run a single trial of the simulation
|
||||||
|
|
||||||
@ -160,50 +183,58 @@ class Simulation:
|
|||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
model = self.get_env(trial_id, **opts)
|
model = self.get_env(trial_id, **opts)
|
||||||
trial_id = trial_id if trial_id is not None else current_time()
|
trial_id = trial_id if trial_id is not None else current_time()
|
||||||
with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)):
|
with utils.timer("Simulation {} trial {}".format(self.name, trial_id)):
|
||||||
return self.run_model(model=model, trial_id=trial_id, until=until, log_level=log_level)
|
return self.run_model(
|
||||||
|
model=model, trial_id=trial_id, until=until, log_level=log_level
|
||||||
|
)
|
||||||
|
|
||||||
def run_model(self, model, until=None, **opts):
|
def run_model(self, model, until=None, **opts):
|
||||||
# Set-up trial environment and graph
|
# Set-up trial environment and graph
|
||||||
until = float(until or self.max_time or 'inf')
|
until = float(until or self.max_time or "inf")
|
||||||
|
|
||||||
# Set up agents on nodes
|
# Set up agents on nodes
|
||||||
def is_done():
|
def is_done():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if until and hasattr(model.schedule, 'time'):
|
if until and hasattr(model.schedule, "time"):
|
||||||
prev = is_done
|
prev = is_done
|
||||||
|
|
||||||
def is_done():
|
def is_done():
|
||||||
return prev() or model.schedule.time >= until
|
return prev() or model.schedule.time >= until
|
||||||
|
|
||||||
if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, 'steps'):
|
if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, "steps"):
|
||||||
prev_steps = is_done
|
prev_steps = is_done
|
||||||
|
|
||||||
def is_done():
|
def is_done():
|
||||||
return prev_steps() or model.schedule.steps >= self.max_steps
|
return prev_steps() or model.schedule.steps >= self.max_steps
|
||||||
|
|
||||||
newline = '\n'
|
newline = "\n"
|
||||||
logger.info(dedent(f'''
|
logger.info(
|
||||||
|
dedent(
|
||||||
|
f"""
|
||||||
Model stats:
|
Model stats:
|
||||||
Agents (total: { model.schedule.get_agent_count() }):
|
Agents (total: { model.schedule.get_agent_count() }):
|
||||||
- { (newline + ' - ').join(str(a) for a in model.schedule.agents) }
|
- { (newline + ' - ').join(str(a) for a in model.schedule.agents) }
|
||||||
|
|
||||||
Topology size: { len(model.G) if hasattr(model, "G") else 0 }
|
Topology size: { len(model.G) if hasattr(model, "G") else 0 }
|
||||||
'''))
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
while not is_done():
|
while not is_done():
|
||||||
utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}')
|
utils.logger.debug(
|
||||||
|
f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}'
|
||||||
|
)
|
||||||
model.step()
|
model.step()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
d = asdict(self)
|
d = asdict(self)
|
||||||
if not isinstance(d['model_class'], str):
|
if not isinstance(d["model_class"], str):
|
||||||
d['model_class'] = serialization.name(d['model_class'])
|
d["model_class"] = serialization.name(d["model_class"])
|
||||||
d['model_params'] = serialization.serialize_dict(d['model_params'])
|
d["model_params"] = serialization.serialize_dict(d["model_params"])
|
||||||
d['dir_path'] = str(d['dir_path'])
|
d["dir_path"] = str(d["dir_path"])
|
||||||
d['version'] = '2'
|
d["version"] = "2"
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def to_yaml(self):
|
def to_yaml(self):
|
||||||
@ -215,15 +246,15 @@ def iter_from_config(*cfgs, **kwargs):
|
|||||||
configs = list(serialization.load_config(config))
|
configs = list(serialization.load_config(config))
|
||||||
for config, path in configs:
|
for config, path in configs:
|
||||||
d = dict(config)
|
d = dict(config)
|
||||||
if 'dir_path' not in d:
|
if "dir_path" not in d:
|
||||||
d['dir_path'] = os.path.dirname(path)
|
d["dir_path"] = os.path.dirname(path)
|
||||||
yield Simulation.from_dict(d, **kwargs)
|
yield Simulation.from_dict(d, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def from_config(conf_or_path):
|
def from_config(conf_or_path):
|
||||||
lst = list(iter_from_config(conf_or_path))
|
lst = list(iter_from_config(conf_or_path))
|
||||||
if len(lst) > 1:
|
if len(lst) > 1:
|
||||||
raise AttributeError('Provide only one configuration')
|
raise AttributeError("Provide only one configuration")
|
||||||
return lst[0]
|
return lst[0]
|
||||||
|
|
||||||
|
|
||||||
|
19
soil/time.py
19
soil/time.py
@ -6,7 +6,8 @@ from .utils import logger
|
|||||||
from mesa import Agent as MesaAgent
|
from mesa import Agent as MesaAgent
|
||||||
|
|
||||||
|
|
||||||
INFINITY = float('inf')
|
INFINITY = float("inf")
|
||||||
|
|
||||||
|
|
||||||
class When:
|
class When:
|
||||||
def __init__(self, time):
|
def __init__(self, time):
|
||||||
@ -42,7 +43,7 @@ class TimedActivation(BaseScheduler):
|
|||||||
self._next = {}
|
self._next = {}
|
||||||
self._queue = []
|
self._queue = []
|
||||||
self.next_time = 0
|
self.next_time = 0
|
||||||
self.logger = logger.getChild(f'time_{ self.model }')
|
self.logger = logger.getChild(f"time_{ self.model }")
|
||||||
|
|
||||||
def add(self, agent: MesaAgent, when=None):
|
def add(self, agent: MesaAgent, when=None):
|
||||||
if when is None:
|
if when is None:
|
||||||
@ -62,7 +63,7 @@ class TimedActivation(BaseScheduler):
|
|||||||
an agent will signal when it wants to be scheduled next.
|
an agent will signal when it wants to be scheduled next.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.logger.debug(f'Simulation step {self.next_time}')
|
self.logger.debug(f"Simulation step {self.next_time}")
|
||||||
if not self.model.running:
|
if not self.model.running:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -71,18 +72,22 @@ class TimedActivation(BaseScheduler):
|
|||||||
|
|
||||||
while self._queue and self._queue[0][0] == self.time:
|
while self._queue and self._queue[0][0] == self.time:
|
||||||
(when, agent_id) = heappop(self._queue)
|
(when, agent_id) = heappop(self._queue)
|
||||||
self.logger.debug(f'Stepping agent {agent_id}')
|
self.logger.debug(f"Stepping agent {agent_id}")
|
||||||
|
|
||||||
agent = self._agents[agent_id]
|
agent = self._agents[agent_id]
|
||||||
returned = agent.step()
|
returned = agent.step()
|
||||||
|
|
||||||
if not getattr(agent, 'alive', True):
|
if not getattr(agent, "alive", True):
|
||||||
self.remove(agent)
|
self.remove(agent)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
when = (returned or Delta(1)).abs(self.time)
|
when = (returned or Delta(1)).abs(self.time)
|
||||||
if when < self.time:
|
if when < self.time:
|
||||||
raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time))
|
raise Exception(
|
||||||
|
"Cannot schedule an agent for a time in the past ({} < {})".format(
|
||||||
|
when, self.time
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self._next[agent_id] = when
|
self._next[agent_id] = when
|
||||||
heappush(self._queue, (when, agent_id))
|
heappush(self._queue, (when, agent_id))
|
||||||
@ -96,4 +101,4 @@ class TimedActivation(BaseScheduler):
|
|||||||
return self.time
|
return self.time
|
||||||
|
|
||||||
self.next_time = self._queue[0][0]
|
self.next_time = self._queue[0][0]
|
||||||
self.logger.debug(f'Next step: {self.next_time}')
|
self.logger.debug(f"Next step: {self.next_time}")
|
||||||
|
@ -9,12 +9,12 @@ from multiprocessing import Pool
|
|||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
logger = logging.getLogger('soil')
|
logger = logging.getLogger("soil")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
timeformat = "%H:%M:%S"
|
timeformat = "%H:%M:%S"
|
||||||
|
|
||||||
if os.environ.get('SOIL_VERBOSE', ''):
|
if os.environ.get("SOIL_VERBOSE", ""):
|
||||||
logformat = "[%(levelname)-5.5s][%(asctime)s][%(name)s]: %(message)s"
|
logformat = "[%(levelname)-5.5s][%(asctime)s][%(name)s]: %(message)s"
|
||||||
else:
|
else:
|
||||||
logformat = "[%(levelname)-5.5s][%(asctime)s] %(message)s"
|
logformat = "[%(levelname)-5.5s][%(asctime)s] %(message)s"
|
||||||
@ -23,38 +23,44 @@ logFormatter = logging.Formatter(logformat, timeformat)
|
|||||||
consoleHandler = logging.StreamHandler()
|
consoleHandler = logging.StreamHandler()
|
||||||
consoleHandler.setFormatter(logFormatter)
|
consoleHandler.setFormatter(logFormatter)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO,
|
logging.basicConfig(
|
||||||
handlers=[consoleHandler,])
|
level=logging.INFO,
|
||||||
|
handlers=[
|
||||||
|
consoleHandler,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def timer(name='task', pre="", function=logger.info, to_object=None):
|
def timer(name="task", pre="", function=logger.info, to_object=None):
|
||||||
start = current_time()
|
start = current_time()
|
||||||
function('{}Starting {} at {}.'.format(pre, name,
|
function("{}Starting {} at {}.".format(pre, name, strftime("%X", gmtime(start))))
|
||||||
strftime("%X", gmtime(start))))
|
|
||||||
yield start
|
yield start
|
||||||
end = current_time()
|
end = current_time()
|
||||||
function('{}Finished {} at {} in {} seconds'.format(pre, name,
|
function(
|
||||||
strftime("%X", gmtime(end)),
|
"{}Finished {} at {} in {} seconds".format(
|
||||||
str(end-start)))
|
pre, name, strftime("%X", gmtime(end)), str(end - start)
|
||||||
|
)
|
||||||
|
)
|
||||||
if to_object:
|
if to_object:
|
||||||
to_object.start = start
|
to_object.start = start
|
||||||
to_object.end = end
|
to_object.end = end
|
||||||
|
|
||||||
|
|
||||||
def safe_open(path, mode='r', backup=True, **kwargs):
|
def safe_open(path, mode="r", backup=True, **kwargs):
|
||||||
outdir = os.path.dirname(path)
|
outdir = os.path.dirname(path)
|
||||||
if outdir and not os.path.exists(outdir):
|
if outdir and not os.path.exists(outdir):
|
||||||
os.makedirs(outdir)
|
os.makedirs(outdir)
|
||||||
if backup and 'w' in mode and os.path.exists(path):
|
if backup and "w" in mode and os.path.exists(path):
|
||||||
creation = os.path.getctime(path)
|
creation = os.path.getctime(path)
|
||||||
stamp = strftime('%Y-%m-%d_%H.%M.%S', localtime(creation))
|
stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation))
|
||||||
|
|
||||||
backup_dir = os.path.join(outdir, 'backup')
|
backup_dir = os.path.join(outdir, "backup")
|
||||||
if not os.path.exists(backup_dir):
|
if not os.path.exists(backup_dir):
|
||||||
os.makedirs(backup_dir)
|
os.makedirs(backup_dir)
|
||||||
newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path),
|
newpath = os.path.join(
|
||||||
stamp))
|
backup_dir, "{}@{}".format(os.path.basename(path), stamp)
|
||||||
|
)
|
||||||
copyfile(path, newpath)
|
copyfile(path, newpath)
|
||||||
return open(path, mode=mode, **kwargs)
|
return open(path, mode=mode, **kwargs)
|
||||||
|
|
||||||
@ -67,21 +73,23 @@ def open_or_reuse(f, *args, **kwargs):
|
|||||||
except (AttributeError, TypeError):
|
except (AttributeError, TypeError):
|
||||||
yield f
|
yield f
|
||||||
|
|
||||||
|
|
||||||
def flatten_dict(d):
|
def flatten_dict(d):
|
||||||
if not isinstance(d, dict):
|
if not isinstance(d, dict):
|
||||||
return d
|
return d
|
||||||
return dict(_flatten_dict(d))
|
return dict(_flatten_dict(d))
|
||||||
|
|
||||||
def _flatten_dict(d, prefix=''):
|
|
||||||
|
def _flatten_dict(d, prefix=""):
|
||||||
if not isinstance(d, dict):
|
if not isinstance(d, dict):
|
||||||
# print('END:', prefix, d)
|
# print('END:', prefix, d)
|
||||||
yield prefix, d
|
yield prefix, d
|
||||||
return
|
return
|
||||||
if prefix:
|
if prefix:
|
||||||
prefix = prefix + '.'
|
prefix = prefix + "."
|
||||||
for k, v in d.items():
|
for k, v in d.items():
|
||||||
# print(k, v)
|
# print(k, v)
|
||||||
res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k)))
|
res = list(_flatten_dict(v, prefix="{}{}".format(prefix, k)))
|
||||||
# print('RES:', res)
|
# print('RES:', res)
|
||||||
yield from res
|
yield from res
|
||||||
|
|
||||||
@ -93,7 +101,7 @@ def unflatten_dict(d):
|
|||||||
if not isinstance(k, str):
|
if not isinstance(k, str):
|
||||||
target[k] = v
|
target[k] = v
|
||||||
continue
|
continue
|
||||||
tokens = k.split('.')
|
tokens = k.split(".")
|
||||||
if len(tokens) < 2:
|
if len(tokens) < 2:
|
||||||
target[k] = v
|
target[k] = v
|
||||||
continue
|
continue
|
||||||
@ -106,27 +114,28 @@ def unflatten_dict(d):
|
|||||||
|
|
||||||
|
|
||||||
def run_and_return_exceptions(func, *args, **kwargs):
|
def run_and_return_exceptions(func, *args, **kwargs):
|
||||||
'''
|
"""
|
||||||
A wrapper for run_trial that catches exceptions and returns them.
|
A wrapper for run_trial that catches exceptions and returns them.
|
||||||
It is meant for async simulations.
|
It is meant for async simulations.
|
||||||
'''
|
"""
|
||||||
try:
|
try:
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
if ex.__cause__ is not None:
|
if ex.__cause__ is not None:
|
||||||
ex = ex.__cause__
|
ex = ex.__cause__
|
||||||
ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:])
|
ex.message = "".join(
|
||||||
|
traceback.format_exception(type(ex), ex, ex.__traceback__)[:]
|
||||||
|
)
|
||||||
return ex
|
return ex
|
||||||
|
|
||||||
|
|
||||||
def run_parallel(func, iterable, parallel=False, **kwargs):
|
def run_parallel(func, iterable, parallel=False, **kwargs):
|
||||||
if parallel and not os.environ.get('SOIL_DEBUG', None):
|
if parallel and not os.environ.get("SOIL_DEBUG", None):
|
||||||
p = Pool()
|
p = Pool()
|
||||||
wrapped_func = partial(run_and_return_exceptions,
|
wrapped_func = partial(run_and_return_exceptions, func, **kwargs)
|
||||||
func, **kwargs)
|
|
||||||
for i in p.imap_unordered(wrapped_func, iterable):
|
for i in p.imap_unordered(wrapped_func, iterable):
|
||||||
if isinstance(i, Exception):
|
if isinstance(i, Exception):
|
||||||
logger.error('Trial failed:\n\t%s', i.message)
|
logger.error("Trial failed:\n\t%s", i.message)
|
||||||
continue
|
continue
|
||||||
yield i
|
yield i
|
||||||
else:
|
else:
|
||||||
|
@ -4,7 +4,7 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ROOT = os.path.dirname(__file__)
|
ROOT = os.path.dirname(__file__)
|
||||||
DEFAULT_FILE = os.path.join(ROOT, 'VERSION')
|
DEFAULT_FILE = os.path.join(ROOT, "VERSION")
|
||||||
|
|
||||||
|
|
||||||
def read_version(versionfile=DEFAULT_FILE):
|
def read_version(versionfile=DEFAULT_FILE):
|
||||||
@ -12,9 +12,10 @@ def read_version(versionfile=DEFAULT_FILE):
|
|||||||
with open(versionfile) as f:
|
with open(versionfile) as f:
|
||||||
return f.read().strip()
|
return f.read().strip()
|
||||||
except IOError: # pragma: no cover
|
except IOError: # pragma: no cover
|
||||||
logger.error(('Running an unknown version of {}.'
|
logger.error(
|
||||||
'Be careful!.').format(__name__))
|
("Running an unknown version of {}." "Be careful!.").format(__name__)
|
||||||
return '0.0'
|
)
|
||||||
|
return "0.0"
|
||||||
|
|
||||||
|
|
||||||
__version__ = read_version()
|
__version__ = read_version()
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from mesa.visualization.UserParam import UserSettableParameter
|
from mesa.visualization.UserParam import UserSettableParameter
|
||||||
|
|
||||||
|
|
||||||
class UserSettableParameter(UserSettableParameter):
|
class UserSettableParameter(UserSettableParameter):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.value
|
return self.value
|
||||||
|
@ -20,6 +20,7 @@ from tornado.concurrent import run_on_executor
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
from ..simulation import Simulation
|
from ..simulation import Simulation
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
@ -31,140 +32,183 @@ LOGGING_INTERVAL = 0.5
|
|||||||
# Workaround to let Soil load the required modules
|
# Workaround to let Soil load the required modules
|
||||||
sys.path.append(ROOT)
|
sys.path.append(ROOT)
|
||||||
|
|
||||||
|
|
||||||
class PageHandler(tornado.web.RequestHandler):
|
class PageHandler(tornado.web.RequestHandler):
|
||||||
""" Handler for the HTML template which holds the visualization. """
|
"""Handler for the HTML template which holds the visualization."""
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
self.render('index.html', port=self.application.port,
|
self.render(
|
||||||
name=self.application.name)
|
"index.html", port=self.application.port, name=self.application.name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SocketHandler(tornado.websocket.WebSocketHandler):
|
class SocketHandler(tornado.websocket.WebSocketHandler):
|
||||||
""" Handler for websocket. """
|
"""Handler for websocket."""
|
||||||
|
|
||||||
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
|
||||||
|
|
||||||
def open(self):
|
def open(self):
|
||||||
if self.application.verbose:
|
if self.application.verbose:
|
||||||
logger.info('Socket opened!')
|
logger.info("Socket opened!")
|
||||||
|
|
||||||
def check_origin(self, origin):
|
def check_origin(self, origin):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def on_message(self, message):
|
def on_message(self, message):
|
||||||
""" Receiving a message from the websocket, parse, and act accordingly. """
|
"""Receiving a message from the websocket, parse, and act accordingly."""
|
||||||
|
|
||||||
msg = tornado.escape.json_decode(message)
|
msg = tornado.escape.json_decode(message)
|
||||||
|
|
||||||
if msg['type'] == 'config_file':
|
if msg["type"] == "config_file":
|
||||||
|
|
||||||
if self.application.verbose:
|
if self.application.verbose:
|
||||||
print(msg['data'])
|
print(msg["data"])
|
||||||
|
|
||||||
self.config = list(yaml.load_all(msg['data']))
|
self.config = list(yaml.load_all(msg["data"]))
|
||||||
|
|
||||||
if len(self.config) > 1:
|
if len(self.config) > 1:
|
||||||
error = 'Please, provide only one configuration.'
|
error = "Please, provide only one configuration."
|
||||||
if self.application.verbose:
|
if self.application.verbose:
|
||||||
logger.error(error)
|
logger.error(error)
|
||||||
self.write_message({'type': 'error',
|
self.write_message({"type": "error", "error": error})
|
||||||
'error': error})
|
|
||||||
return
|
return
|
||||||
|
|
||||||
self.config = self.config[0]
|
self.config = self.config[0]
|
||||||
self.send_log('INFO.' + self.simulation_name,
|
self.send_log(
|
||||||
'Using config: {name}'.format(name=self.config['name']))
|
"INFO." + self.simulation_name,
|
||||||
|
"Using config: {name}".format(name=self.config["name"]),
|
||||||
|
)
|
||||||
|
|
||||||
if 'visualization_params' in self.config:
|
if "visualization_params" in self.config:
|
||||||
self.write_message({'type': 'visualization_params',
|
self.write_message(
|
||||||
'data': self.config['visualization_params']})
|
{
|
||||||
self.name = self.config['name']
|
"type": "visualization_params",
|
||||||
|
"data": self.config["visualization_params"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.name = self.config["name"]
|
||||||
self.run_simulation()
|
self.run_simulation()
|
||||||
|
|
||||||
settings = []
|
settings = []
|
||||||
for key in self.config['environment_params']:
|
for key in self.config["environment_params"]:
|
||||||
if type(self.config['environment_params'][key]) == float or type(self.config['environment_params'][key]) == int:
|
if (
|
||||||
if self.config['environment_params'][key] <= 1:
|
type(self.config["environment_params"][key]) == float
|
||||||
setting_type = 'number'
|
or type(self.config["environment_params"][key]) == int
|
||||||
|
):
|
||||||
|
if self.config["environment_params"][key] <= 1:
|
||||||
|
setting_type = "number"
|
||||||
else:
|
else:
|
||||||
setting_type = 'great_number'
|
setting_type = "great_number"
|
||||||
elif type(self.config['environment_params'][key]) == bool:
|
elif type(self.config["environment_params"][key]) == bool:
|
||||||
setting_type = 'boolean'
|
setting_type = "boolean"
|
||||||
else:
|
else:
|
||||||
setting_type = 'undefined'
|
setting_type = "undefined"
|
||||||
|
|
||||||
settings.append({
|
settings.append(
|
||||||
'label': key,
|
{
|
||||||
'type': setting_type,
|
"label": key,
|
||||||
'value': self.config['environment_params'][key]
|
"type": setting_type,
|
||||||
})
|
"value": self.config["environment_params"][key],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
self.write_message({'type': 'settings',
|
self.write_message({"type": "settings", "data": settings})
|
||||||
'data': settings})
|
|
||||||
|
|
||||||
elif msg['type'] == 'get_trial':
|
elif msg["type"] == "get_trial":
|
||||||
if self.application.verbose:
|
if self.application.verbose:
|
||||||
logger.info('Trial {} requested!'.format(msg['data']))
|
logger.info("Trial {} requested!".format(msg["data"]))
|
||||||
self.send_log('INFO.' + __name__, 'Trial {} requested!'.format(msg['data']))
|
self.send_log("INFO." + __name__, "Trial {} requested!".format(msg["data"]))
|
||||||
self.write_message({'type': 'get_trial',
|
self.write_message(
|
||||||
'data': self.get_trial(int(msg['data']))})
|
{"type": "get_trial", "data": self.get_trial(int(msg["data"]))}
|
||||||
|
)
|
||||||
|
|
||||||
elif msg['type'] == 'run_simulation':
|
elif msg["type"] == "run_simulation":
|
||||||
if self.application.verbose:
|
if self.application.verbose:
|
||||||
logger.info('Running new simulation for {name}'.format(name=self.config['name']))
|
logger.info(
|
||||||
self.send_log('INFO.' + self.simulation_name, 'Running new simulation for {name}'.format(name=self.config['name']))
|
"Running new simulation for {name}".format(name=self.config["name"])
|
||||||
self.config['environment_params'] = msg['data']
|
)
|
||||||
|
self.send_log(
|
||||||
|
"INFO." + self.simulation_name,
|
||||||
|
"Running new simulation for {name}".format(name=self.config["name"]),
|
||||||
|
)
|
||||||
|
self.config["environment_params"] = msg["data"]
|
||||||
self.run_simulation()
|
self.run_simulation()
|
||||||
|
|
||||||
elif msg['type'] == 'download_gexf':
|
elif msg["type"] == "download_gexf":
|
||||||
G = self.trials[ int(msg['data']) ].history_to_graph()
|
G = self.trials[int(msg["data"])].history_to_graph()
|
||||||
for node in G.nodes():
|
for node in G.nodes():
|
||||||
if 'pos' in G.nodes[node]:
|
if "pos" in G.nodes[node]:
|
||||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
G.nodes[node]["viz"] = {
|
||||||
del (G.nodes[node]['pos'])
|
"position": {
|
||||||
writer = nx.readwrite.gexf.GEXFWriter(version='1.2draft')
|
"x": G.nodes[node]["pos"][0],
|
||||||
|
"y": G.nodes[node]["pos"][1],
|
||||||
|
"z": 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
del G.nodes[node]["pos"]
|
||||||
|
writer = nx.readwrite.gexf.GEXFWriter(version="1.2draft")
|
||||||
writer.add_graph(G)
|
writer.add_graph(G)
|
||||||
self.write_message({'type': 'download_gexf',
|
self.write_message(
|
||||||
'filename': self.config['name'] + '_trial_' + str(msg['data']),
|
{
|
||||||
'data': tostring(writer.xml).decode(writer.encoding) })
|
"type": "download_gexf",
|
||||||
|
"filename": self.config["name"] + "_trial_" + str(msg["data"]),
|
||||||
|
"data": tostring(writer.xml).decode(writer.encoding),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
elif msg['type'] == 'download_json':
|
elif msg["type"] == "download_json":
|
||||||
G = self.trials[ int(msg['data']) ].history_to_graph()
|
G = self.trials[int(msg["data"])].history_to_graph()
|
||||||
for node in G.nodes():
|
for node in G.nodes():
|
||||||
if 'pos' in G.nodes[node]:
|
if "pos" in G.nodes[node]:
|
||||||
G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}}
|
G.nodes[node]["viz"] = {
|
||||||
del (G.nodes[node]['pos'])
|
"position": {
|
||||||
self.write_message({'type': 'download_json',
|
"x": G.nodes[node]["pos"][0],
|
||||||
'filename': self.config['name'] + '_trial_' + str(msg['data']),
|
"y": G.nodes[node]["pos"][1],
|
||||||
'data': nx.node_link_data(G) })
|
"z": 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
del G.nodes[node]["pos"]
|
||||||
|
self.write_message(
|
||||||
|
{
|
||||||
|
"type": "download_json",
|
||||||
|
"filename": self.config["name"] + "_trial_" + str(msg["data"]),
|
||||||
|
"data": nx.node_link_data(G),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.application.verbose:
|
if self.application.verbose:
|
||||||
logger.info('Unexpected message!')
|
logger.info("Unexpected message!")
|
||||||
|
|
||||||
def update_logging(self):
|
def update_logging(self):
|
||||||
try:
|
try:
|
||||||
if (not self.log_capture_string.closed and self.log_capture_string.getvalue()):
|
if (
|
||||||
for i in range(len(self.log_capture_string.getvalue().split('\n')) - 1):
|
not self.log_capture_string.closed
|
||||||
self.send_log('INFO.' + self.simulation_name, self.log_capture_string.getvalue().split('\n')[i])
|
and self.log_capture_string.getvalue()
|
||||||
|
):
|
||||||
|
for i in range(len(self.log_capture_string.getvalue().split("\n")) - 1):
|
||||||
|
self.send_log(
|
||||||
|
"INFO." + self.simulation_name,
|
||||||
|
self.log_capture_string.getvalue().split("\n")[i],
|
||||||
|
)
|
||||||
self.log_capture_string.truncate(0)
|
self.log_capture_string.truncate(0)
|
||||||
self.log_capture_string.seek(0)
|
self.log_capture_string.seek(0)
|
||||||
finally:
|
finally:
|
||||||
if self.capture_logging:
|
if self.capture_logging:
|
||||||
tornado.ioloop.IOLoop.current().call_later(LOGGING_INTERVAL, self.update_logging)
|
tornado.ioloop.IOLoop.current().call_later(
|
||||||
|
LOGGING_INTERVAL, self.update_logging
|
||||||
|
)
|
||||||
|
|
||||||
def on_close(self):
|
def on_close(self):
|
||||||
if self.application.verbose:
|
if self.application.verbose:
|
||||||
logger.info('Socket closed!')
|
logger.info("Socket closed!")
|
||||||
|
|
||||||
def send_log(self, logger, logging):
|
def send_log(self, logger, logging):
|
||||||
self.write_message({'type': 'log',
|
self.write_message({"type": "log", "logger": logger, "logging": logging})
|
||||||
'logger': logger,
|
|
||||||
'logging': logging})
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def simulation_name(self):
|
def simulation_name(self):
|
||||||
return self.config.get('name', 'NoSimulationRunning')
|
return self.config.get("name", "NoSimulationRunning")
|
||||||
|
|
||||||
@run_on_executor
|
@run_on_executor
|
||||||
def nonblocking(self, config):
|
def nonblocking(self, config):
|
||||||
@ -174,28 +218,31 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
|||||||
@tornado.gen.coroutine
|
@tornado.gen.coroutine
|
||||||
def run_simulation(self):
|
def run_simulation(self):
|
||||||
# Run simulation and capture logs
|
# Run simulation and capture logs
|
||||||
logger.info('Running simulation!')
|
logger.info("Running simulation!")
|
||||||
if 'visualization_params' in self.config:
|
if "visualization_params" in self.config:
|
||||||
del self.config['visualization_params']
|
del self.config["visualization_params"]
|
||||||
with self.logging(self.simulation_name):
|
with self.logging(self.simulation_name):
|
||||||
try:
|
try:
|
||||||
config = dict(**self.config)
|
config = dict(**self.config)
|
||||||
config['outdir'] = os.path.join(self.application.outdir, config['name'])
|
config["outdir"] = os.path.join(self.application.outdir, config["name"])
|
||||||
config['dump'] = self.application.dump
|
config["dump"] = self.application.dump
|
||||||
self.trials = yield self.nonblocking(config)
|
self.trials = yield self.nonblocking(config)
|
||||||
|
|
||||||
self.write_message({'type': 'trials',
|
self.write_message(
|
||||||
'data': list(trial.name for trial in self.trials) })
|
{
|
||||||
|
"type": "trials",
|
||||||
|
"data": list(trial.name for trial in self.trials),
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
error = 'Something went wrong:\n\t{}'.format(ex)
|
error = "Something went wrong:\n\t{}".format(ex)
|
||||||
logging.info(error)
|
logging.info(error)
|
||||||
self.write_message({'type': 'error',
|
self.write_message({"type": "error", "error": error})
|
||||||
'error': error})
|
self.send_log("ERROR." + self.simulation_name, error)
|
||||||
self.send_log('ERROR.' + self.simulation_name, error)
|
|
||||||
|
|
||||||
def get_trial(self, trial):
|
def get_trial(self, trial):
|
||||||
logger.info('Available trials: %s ' % len(self.trials))
|
logger.info("Available trials: %s " % len(self.trials))
|
||||||
logger.info('Ask for : %s' % trial)
|
logger.info("Ask for : %s" % trial)
|
||||||
trial = self.trials[trial]
|
trial = self.trials[trial]
|
||||||
G = trial.history_to_graph()
|
G = trial.history_to_graph()
|
||||||
return nx.node_link_data(G)
|
return nx.node_link_data(G)
|
||||||
@ -218,21 +265,24 @@ class SocketHandler(tornado.websocket.WebSocketHandler):
|
|||||||
|
|
||||||
|
|
||||||
class ModularServer(tornado.web.Application):
|
class ModularServer(tornado.web.Application):
|
||||||
""" Main visualization application. """
|
"""Main visualization application."""
|
||||||
|
|
||||||
port = 8001
|
port = 8001
|
||||||
page_handler = (r'/', PageHandler)
|
page_handler = (r"/", PageHandler)
|
||||||
socket_handler = (r'/ws', SocketHandler)
|
socket_handler = (r"/ws", SocketHandler)
|
||||||
static_handler = (r'/(.*)', tornado.web.StaticFileHandler,
|
static_handler = (
|
||||||
{'path': os.path.join(ROOT, 'static')})
|
r"/(.*)",
|
||||||
local_handler = (r'/local/(.*)', tornado.web.StaticFileHandler,
|
tornado.web.StaticFileHandler,
|
||||||
{'path': ''})
|
{"path": os.path.join(ROOT, "static")},
|
||||||
|
)
|
||||||
|
local_handler = (r"/local/(.*)", tornado.web.StaticFileHandler, {"path": ""})
|
||||||
|
|
||||||
handlers = [page_handler, socket_handler, static_handler, local_handler]
|
handlers = [page_handler, socket_handler, static_handler, local_handler]
|
||||||
settings = {'debug': True,
|
settings = {"debug": True, "template_path": ROOT + "/templates"}
|
||||||
'template_path': ROOT + '/templates'}
|
|
||||||
|
|
||||||
def __init__(self, dump=False, outdir='output', name='SOIL', verbose=True, *args, **kwargs):
|
def __init__(
|
||||||
|
self, dump=False, outdir="output", name="SOIL", verbose=True, *args, **kwargs
|
||||||
|
):
|
||||||
|
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.name = name
|
self.name = name
|
||||||
@ -243,12 +293,12 @@ class ModularServer(tornado.web.Application):
|
|||||||
super().__init__(self.handlers, **self.settings)
|
super().__init__(self.handlers, **self.settings)
|
||||||
|
|
||||||
def launch(self, port=None):
|
def launch(self, port=None):
|
||||||
""" Run the app. """
|
"""Run the app."""
|
||||||
|
|
||||||
if port is not None:
|
if port is not None:
|
||||||
self.port = port
|
self.port = port
|
||||||
url = 'http://127.0.0.1:{PORT}'.format(PORT=self.port)
|
url = "http://127.0.0.1:{PORT}".format(PORT=self.port)
|
||||||
print('Interface starting at {url}'.format(url=url))
|
print("Interface starting at {url}".format(url=url))
|
||||||
self.listen(self.port)
|
self.listen(self.port)
|
||||||
# webbrowser.open(url)
|
# webbrowser.open(url)
|
||||||
tornado.ioloop.IOLoop.instance().start()
|
tornado.ioloop.IOLoop.instance().start()
|
||||||
@ -263,12 +313,22 @@ def run(*args, **kwargs):
|
|||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Visualization of a Graph Model')
|
parser = argparse.ArgumentParser(description="Visualization of a Graph Model")
|
||||||
|
|
||||||
parser.add_argument('--name', '-n', nargs=1, default='SOIL', help='name of the simulation')
|
parser.add_argument(
|
||||||
parser.add_argument('--dump', '-d', help='dumping results in folder output', action='store_true')
|
"--name", "-n", nargs=1, default="SOIL", help="name of the simulation"
|
||||||
parser.add_argument('--port', '-p', nargs=1, default=8001, help='port for launching the server')
|
)
|
||||||
parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true')
|
parser.add_argument(
|
||||||
|
"--dump", "-d", help="dumping results in folder output", action="store_true"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port", "-p", nargs=1, default=8001, help="port for launching the server"
|
||||||
|
)
|
||||||
|
parser.add_argument("--verbose", "-v", help="verbose mode", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
run(name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)
|
run(
|
||||||
|
name=args.name,
|
||||||
|
port=(args.port[0] if isinstance(args.port, list) else args.port),
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
@ -4,20 +4,33 @@ from simulator import Simulator
|
|||||||
|
|
||||||
|
|
||||||
def run(simulator, name="SOIL", port=8001, verbose=False):
|
def run(simulator, name="SOIL", port=8001, verbose=False):
|
||||||
server = ModularServer(simulator, name=(name[0] if isinstance(name, list) else name), verbose=verbose)
|
server = ModularServer(
|
||||||
|
simulator, name=(name[0] if isinstance(name, list) else name), verbose=verbose
|
||||||
|
)
|
||||||
server.port = port
|
server.port = port
|
||||||
server.launch()
|
server.launch()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Visualization of a Graph Model')
|
parser = argparse.ArgumentParser(description="Visualization of a Graph Model")
|
||||||
|
|
||||||
parser.add_argument('--name', '-n', nargs=1, default='SOIL', help='name of the simulation')
|
parser.add_argument(
|
||||||
parser.add_argument('--dump', '-d', help='dumping results in folder output', action='store_true')
|
"--name", "-n", nargs=1, default="SOIL", help="name of the simulation"
|
||||||
parser.add_argument('--port', '-p', nargs=1, default=8001, help='port for launching the server')
|
)
|
||||||
parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true')
|
parser.add_argument(
|
||||||
|
"--dump", "-d", help="dumping results in folder output", action="store_true"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port", "-p", nargs=1, default=8001, help="port for launching the server"
|
||||||
|
)
|
||||||
|
parser.add_argument("--verbose", "-v", help="verbose mode", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
soil = Simulator(dump=args.dump)
|
soil = Simulator(dump=args.dump)
|
||||||
run(soil, name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose)
|
run(
|
||||||
|
soil,
|
||||||
|
name=args.name,
|
||||||
|
port=(args.port[0] if isinstance(args.port, list) else args.port),
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user