diff --git a/soil/__init__.py b/soil/__init__.py index be53c47..46d56bd 100644 --- a/soil/__init__.py +++ b/soil/__init__.py @@ -22,58 +22,107 @@ from .utils import logger from .time import * -def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_output", *, do_run=False, debug=False, **kwargs): +def main( + cfg="simulation.yml", + exporters=None, + parallel=None, + output="soil_output", + *, + do_run=False, + debug=False, + **kwargs, +): import argparse from . import simulation - logger.info('Running SOIL version: {}'.format(__version__)) + logger.info("Running SOIL version: {}".format(__version__)) - parser = argparse.ArgumentParser(description='Run a SOIL simulation') - parser.add_argument('file', type=str, - nargs="?", - default=cfg, - help='Configuration file for the simulation (e.g., YAML or JSON)') - parser.add_argument('--version', action='store_true', - help='Show version info and exit') - parser.add_argument('--module', '-m', type=str, - help='file containing the code of any custom agents.') - parser.add_argument('--dry-run', '--dry', action='store_true', - help='Do not store the results of the simulation to disk, show in terminal instead.') - parser.add_argument('--pdb', action='store_true', - help='Use a pdb console in case of exception.') - parser.add_argument('--debug', action='store_true', - help='Run a customized version of a pdb console to debug a simulation.') - parser.add_argument('--graph', '-g', action='store_true', - help='Dump each trial\'s network topology as a GEXF graph. Defaults to false.') - parser.add_argument('--csv', action='store_true', - help='Dump all data collected in CSV format. Defaults to false.') - parser.add_argument('--level', type=str, - help='Logging level') - parser.add_argument('--output', '-o', type=str, default=output or "soil_output", - help='folder to write results to. It defaults to the current directory.') + parser = argparse.ArgumentParser(description="Run a SOIL simulation") + parser.add_argument( + "file", + type=str, + nargs="?", + default=cfg, + help="Configuration file for the simulation (e.g., YAML or JSON)", + ) + parser.add_argument( + "--version", action="store_true", help="Show version info and exit" + ) + parser.add_argument( + "--module", + "-m", + type=str, + help="file containing the code of any custom agents.", + ) + parser.add_argument( + "--dry-run", + "--dry", + action="store_true", + help="Do not store the results of the simulation to disk, show in terminal instead.", + ) + parser.add_argument( + "--pdb", action="store_true", help="Use a pdb console in case of exception." + ) + parser.add_argument( + "--debug", + action="store_true", + help="Run a customized version of a pdb console to debug a simulation.", + ) + parser.add_argument( + "--graph", + "-g", + action="store_true", + help="Dump each trial's network topology as a GEXF graph. Defaults to false.", + ) + parser.add_argument( + "--csv", + action="store_true", + help="Dump all data collected in CSV format. Defaults to false.", + ) + parser.add_argument("--level", type=str, help="Logging level") + parser.add_argument( + "--output", + "-o", + type=str, + default=output or "soil_output", + help="folder to write results to. It defaults to the current directory.", + ) if parallel is None: - parser.add_argument('--synchronous', action='store_true', - help='Run trials serially and synchronously instead of in parallel. Defaults to false.') + parser.add_argument( + "--synchronous", + action="store_true", + help="Run trials serially and synchronously instead of in parallel. Defaults to false.", + ) - parser.add_argument('-e', '--exporter', action='append', - default=[], - help='Export environment and/or simulations using this exporter') + parser.add_argument( + "-e", + "--exporter", + action="append", + default=[], + help="Export environment and/or simulations using this exporter", + ) - parser.add_argument('--only-convert', '--convert', action='store_true', - help='Do not run the simulation, only convert the configuration file(s) and output them.') + parser.add_argument( + "--only-convert", + "--convert", + action="store_true", + help="Do not run the simulation, only convert the configuration file(s) and output them.", + ) - parser.add_argument("--set", - metavar="KEY=VALUE", - action='append', - help="Set a number of parameters that will be passed to the simulation." - "(do not put spaces before or after the = sign). " - "If a value contains spaces, you should define " - "it with double quotes: " - 'foo="this is a sentence". Note that ' - "values are always treated as strings.") + parser.add_argument( + "--set", + metavar="KEY=VALUE", + action="append", + help="Set a number of parameters that will be passed to the simulation." + "(do not put spaces before or after the = sign). " + "If a value contains spaces, you should define " + "it with double quotes: " + 'foo="this is a sentence". Note that ' + "values are always treated as strings.", + ) args = parser.parse_args() - logger.setLevel(getattr(logging, (args.level or 'INFO').upper())) + logger.setLevel(getattr(logging, (args.level or "INFO").upper())) if args.version: return @@ -81,14 +130,16 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu if parallel is None: parallel = not args.synchronous - exporters = exporters or ['default', ] + exporters = exporters or [ + "default", + ] for exp in args.exporter: if exp not in exporters: exporters.append(exp) if args.csv: - exporters.append('csv') + exporters.append("csv") if args.graph: - exporters.append('gexf') + exporters.append("gexf") if os.getcwd() not in sys.path: sys.path.append(os.getcwd()) @@ -97,38 +148,38 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu if output is None: output = args.output - - logger.info('Loading config file: {}'.format(args.file)) + logger.info("Loading config file: {}".format(args.file)) debug = debug or args.debug if args.pdb or debug: args.synchronous = True - res = [] try: exp_params = {} if not os.path.exists(args.file): - logger.error('Please, input a valid file') + logger.error("Please, input a valid file") return - for sim in simulation.iter_from_config(args.file, - dry_run=args.dry_run, - exporters=exporters, - parallel=parallel, - outdir=output, - exporter_params=exp_params, - **kwargs): + for sim in simulation.iter_from_config( + args.file, + dry_run=args.dry_run, + exporters=exporters, + parallel=parallel, + outdir=output, + exporter_params=exp_params, + **kwargs, + ): if args.set: for s in args.set: - k, v = s.split('=', 1)[:2] + k, v = s.split("=", 1)[:2] v = eval(v) - tail, *head = k.rsplit('.', 1)[::-1] + tail, *head = k.rsplit(".", 1)[::-1] target = sim if head: - for part in head[0].split('.'): + for part in head[0].split("."): try: target = getattr(target, part) except AttributeError: @@ -144,19 +195,21 @@ def main(cfg='simulation.yml', exporters=None, parallel=None, output="soil_outpu if do_run: res.append(sim.run()) else: - print('not running') + print("not running") res.append(sim) except Exception as ex: if args.pdb: from .debugging import post_mortem + print(traceback.format_exc()) post_mortem() else: raise if debug: from .debugging import set_trace - os.environ['SOIL_DEBUG'] = 'true' + + os.environ["SOIL_DEBUG"] = "true" set_trace() return res @@ -165,5 +218,5 @@ def easy(cfg, debug=False, **kwargs): return main(cfg, **kwargs)[0] -if __name__ == '__main__': +if __name__ == "__main__": main(do_run=True) diff --git a/soil/__main__.py b/soil/__main__.py index 9ad5c4f..0c76791 100644 --- a/soil/__main__.py +++ b/soil/__main__.py @@ -1,7 +1,9 @@ from . import main as init_main + def main(): init_main(do_run=True) -if __name__ == '__main__': + +if __name__ == "__main__": init_main(do_run=True) diff --git a/soil/agents/BassModel.py b/soil/agents/BassModel.py index e3f5015..416063d 100644 --- a/soil/agents/BassModel.py +++ b/soil/agents/BassModel.py @@ -7,6 +7,7 @@ class BassModel(FSM): innovation_prob imitation_prob """ + sentimentCorrelation = 0 def step(self): @@ -21,7 +22,7 @@ class BassModel(FSM): else: aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id) num_neighbors_aware = len(aware_neighbors) - if self.prob((self['imitation_prob']*num_neighbors_aware)): + if self.prob((self["imitation_prob"] * num_neighbors_aware)): self.sentimentCorrelation = 1 return self.aware diff --git a/soil/agents/BigMarketModel.py b/soil/agents/BigMarketModel.py index 7db663d..5a93b23 100644 --- a/soil/agents/BigMarketModel.py +++ b/soil/agents/BigMarketModel.py @@ -6,42 +6,54 @@ class BigMarketModel(FSM): Settings: Names: enterprises [Array] - + tweet_probability_enterprises [Array] Users: tweet_probability_users - + tweet_relevant_probability - + tweet_probability_about [Array] - + sentiment_about [Array] """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.enterprises = self.env.environment_params['enterprises'] + self.enterprises = self.env.environment_params["enterprises"] self.type = "" if self.id < len(self.enterprises): # Enterprises self.set_state(self.enterprise.id) self.type = "Enterprise" - self.tweet_probability = environment.environment_params['tweet_probability_enterprises'][self.id] + self.tweet_probability = environment.environment_params[ + "tweet_probability_enterprises" + ][self.id] else: # normal users self.type = "User" self.set_state(self.user.id) - self.tweet_probability = environment.environment_params['tweet_probability_users'] - self.tweet_relevant_probability = environment.environment_params['tweet_relevant_probability'] - self.tweet_probability_about = environment.environment_params['tweet_probability_about'] # List - self.sentiment_about = environment.environment_params['sentiment_about'] # List + self.tweet_probability = environment.environment_params[ + "tweet_probability_users" + ] + self.tweet_relevant_probability = environment.environment_params[ + "tweet_relevant_probability" + ] + self.tweet_probability_about = environment.environment_params[ + "tweet_probability_about" + ] # List + self.sentiment_about = environment.environment_params[ + "sentiment_about" + ] # List @state def enterprise(self): if self.random.random() < self.tweet_probability: # Tweets - aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises) # Nodes neighbour users + aware_neighbors = self.get_neighboring_agents( + state_id=self.number_of_enterprises + ) # Nodes neighbour users for x in aware_neighbors: - if self.random.uniform(0,10) < 5: + if self.random.uniform(0, 10) < 5: x.sentiment_about[self.id] += 0.1 # Increments for enterprise else: x.sentiment_about[self.id] -= 0.1 # Decrements for enterprise @@ -49,15 +61,19 @@ class BigMarketModel(FSM): # Establecemos limites if x.sentiment_about[self.id] > 1: x.sentiment_about[self.id] = 1 - if x.sentiment_about[self.id]< -1: + if x.sentiment_about[self.id] < -1: x.sentiment_about[self.id] = -1 - x.attrs['sentiment_enterprise_%s'% self.enterprises[self.id]] = x.sentiment_about[self.id] + x.attrs[ + "sentiment_enterprise_%s" % self.enterprises[self.id] + ] = x.sentiment_about[self.id] @state def user(self): if self.random.random() < self.tweet_probability: # Tweets - if self.random.random() < self.tweet_relevant_probability: # Tweets something relevant + if ( + self.random.random() < self.tweet_relevant_probability + ): # Tweets something relevant # Tweet probability per enterprise for i in range(len(self.enterprises)): random_num = self.random.random() @@ -65,23 +81,29 @@ class BigMarketModel(FSM): # The condition is fulfilled, sentiments are evaluated towards that enterprise if self.sentiment_about[i] < 0: # NEGATIVO - self.userTweets("negative",i) + self.userTweets("negative", i) elif self.sentiment_about[i] == 0: # NEUTRO pass else: # POSITIVO - self.userTweets("positive",i) - for i in range(len(self.enterprises)): # So that it never is set to 0 if there are not changes (logs) - self.attrs['sentiment_enterprise_%s'% self.enterprises[i]] = self.sentiment_about[i] + self.userTweets("positive", i) + for i in range( + len(self.enterprises) + ): # So that it never is set to 0 if there are not changes (logs) + self.attrs[ + "sentiment_enterprise_%s" % self.enterprises[i] + ] = self.sentiment_about[i] - def userTweets(self, sentiment,enterprise): - aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises) # Nodes neighbours users + def userTweets(self, sentiment, enterprise): + aware_neighbors = self.get_neighboring_agents( + state_id=self.number_of_enterprises + ) # Nodes neighbours users for x in aware_neighbors: if sentiment == "positive": - x.sentiment_about[enterprise] +=0.003 + x.sentiment_about[enterprise] += 0.003 elif sentiment == "negative": - x.sentiment_about[enterprise] -=0.003 + x.sentiment_about[enterprise] -= 0.003 else: pass @@ -91,4 +113,6 @@ class BigMarketModel(FSM): if x.sentiment_about[enterprise] < -1: x.sentiment_about[enterprise] = -1 - x.attrs['sentiment_enterprise_%s'% self.enterprises[enterprise]] = x.sentiment_about[enterprise] + x.attrs[ + "sentiment_enterprise_%s" % self.enterprises[enterprise] + ] = x.sentiment_about[enterprise] diff --git a/soil/agents/CounterModel.py b/soil/agents/CounterModel.py index 97c7356..731c61d 100644 --- a/soil/agents/CounterModel.py +++ b/soil/agents/CounterModel.py @@ -15,9 +15,9 @@ class CounterModel(NetworkAgent): # Outside effects total = len(list(self.model.schedule._agents)) neighbors = len(list(self.get_neighboring_agents())) - self['times'] = self.get('times', 0) + 1 - self['neighbors'] = neighbors - self['total'] = total + self["times"] = self.get("times", 0) + 1 + self["neighbors"] = neighbors + self["total"] = total class AggregatedCounter(NetworkAgent): @@ -32,9 +32,9 @@ class AggregatedCounter(NetworkAgent): def step(self): # Outside effects - self['times'] += 1 + self["times"] += 1 neighbors = len(list(self.get_neighboring_agents())) - self['neighbors'] += neighbors + self["neighbors"] += neighbors total = len(list(self.model.schedule.agents)) - self['total'] += total - self.debug('Running for step: {}. Total: {}'.format(self.now, total)) + self["total"] += total + self.debug("Running for step: {}. Total: {}".format(self.now, total)) diff --git a/soil/agents/Geo.py b/soil/agents/Geo.py index bf505bf..d61d1ce 100644 --- a/soil/agents/Geo.py +++ b/soil/agents/Geo.py @@ -2,20 +2,20 @@ from scipy.spatial import cKDTree as KDTree import networkx as nx from . import NetworkAgent, as_node + class Geo(NetworkAgent): - '''In this type of network, nodes have a "pos" attribute.''' + """In this type of network, nodes have a "pos" attribute.""" def geo_search(self, radius, node=None, center=False, **kwargs): - '''Get a list of nodes whose coordinates are closer than *radius* to *node*.''' + """Get a list of nodes whose coordinates are closer than *radius* to *node*.""" node = as_node(node if node is not None else self) G = self.subgraph(**kwargs) - pos = nx.get_node_attributes(G, 'pos') + pos = nx.get_node_attributes(G, "pos") if not pos: return [] nodes, coords = list(zip(*pos.items())) kdtree = KDTree(coords) # Cannot provide generator. indices = kdtree.query_ball_point(pos[node], radius) return [nodes[i] for i in indices if center or (nodes[i] != node)] - diff --git a/soil/agents/IndependentCascadeModel.py b/soil/agents/IndependentCascadeModel.py index e927a6f..d3280e0 100644 --- a/soil/agents/IndependentCascadeModel.py +++ b/soil/agents/IndependentCascadeModel.py @@ -11,10 +11,10 @@ class IndependentCascadeModel(BaseAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.innovation_prob = self.env.environment_params['innovation_prob'] - self.imitation_prob = self.env.environment_params['imitation_prob'] - self.state['time_awareness'] = 0 - self.state['sentimentCorrelation'] = 0 + self.innovation_prob = self.env.environment_params["innovation_prob"] + self.imitation_prob = self.env.environment_params["imitation_prob"] + self.state["time_awareness"] = 0 + self.state["sentimentCorrelation"] = 0 def step(self): self.behaviour() @@ -23,25 +23,27 @@ class IndependentCascadeModel(BaseAgent): aware_neighbors_1_time_step = [] # Outside effects if self.prob(self.innovation_prob): - if self.state['id'] == 0: - self.state['id'] = 1 - self.state['sentimentCorrelation'] = 1 - self.state['time_awareness'] = self.env.now # To know when they have been infected + if self.state["id"] == 0: + self.state["id"] = 1 + self.state["sentimentCorrelation"] = 1 + self.state[ + "time_awareness" + ] = self.env.now # To know when they have been infected else: pass return # Imitation effects - if self.state['id'] == 0: + if self.state["id"] == 0: aware_neighbors = self.get_neighboring_agents(state_id=1) for x in aware_neighbors: - if x.state['time_awareness'] == (self.env.now-1): + if x.state["time_awareness"] == (self.env.now - 1): aware_neighbors_1_time_step.append(x) num_neighbors_aware = len(aware_neighbors_1_time_step) - if self.prob(self.imitation_prob*num_neighbors_aware): - self.state['id'] = 1 - self.state['sentimentCorrelation'] = 1 + if self.prob(self.imitation_prob * num_neighbors_aware): + self.state["id"] = 1 + self.state["sentimentCorrelation"] = 1 else: pass diff --git a/soil/agents/ModelM2.py b/soil/agents/ModelM2.py index dec6b97..b22cafa 100644 --- a/soil/agents/ModelM2.py +++ b/soil/agents/ModelM2.py @@ -23,36 +23,49 @@ class SpreadModelM2(BaseAgent): def __init__(self, model=None, unique_id=0, state=()): super().__init__(model=environment, unique_id=unique_id, state=state) - # Use a single generator with the same seed as `self.random` random = np.random.default_rng(seed=self._seed) - self.prob_neutral_making_denier = random.normal(environment.environment_params['prob_neutral_making_denier'], - environment.environment_params['standard_variance']) + self.prob_neutral_making_denier = random.normal( + environment.environment_params["prob_neutral_making_denier"], + environment.environment_params["standard_variance"], + ) - self.prob_infect = random.normal(environment.environment_params['prob_infect'], - environment.environment_params['standard_variance']) + self.prob_infect = random.normal( + environment.environment_params["prob_infect"], + environment.environment_params["standard_variance"], + ) - self.prob_cured_healing_infected = random.normal(environment.environment_params['prob_cured_healing_infected'], - environment.environment_params['standard_variance']) - self.prob_cured_vaccinate_neutral = random.normal(environment.environment_params['prob_cured_vaccinate_neutral'], - environment.environment_params['standard_variance']) + self.prob_cured_healing_infected = random.normal( + environment.environment_params["prob_cured_healing_infected"], + environment.environment_params["standard_variance"], + ) + self.prob_cured_vaccinate_neutral = random.normal( + environment.environment_params["prob_cured_vaccinate_neutral"], + environment.environment_params["standard_variance"], + ) - self.prob_vaccinated_healing_infected = random.normal(environment.environment_params['prob_vaccinated_healing_infected'], - environment.environment_params['standard_variance']) - self.prob_vaccinated_vaccinate_neutral = random.normal(environment.environment_params['prob_vaccinated_vaccinate_neutral'], - environment.environment_params['standard_variance']) - self.prob_generate_anti_rumor = random.normal(environment.environment_params['prob_generate_anti_rumor'], - environment.environment_params['standard_variance']) + self.prob_vaccinated_healing_infected = random.normal( + environment.environment_params["prob_vaccinated_healing_infected"], + environment.environment_params["standard_variance"], + ) + self.prob_vaccinated_vaccinate_neutral = random.normal( + environment.environment_params["prob_vaccinated_vaccinate_neutral"], + environment.environment_params["standard_variance"], + ) + self.prob_generate_anti_rumor = random.normal( + environment.environment_params["prob_generate_anti_rumor"], + environment.environment_params["standard_variance"], + ) def step(self): - if self.state['id'] == 0: # Neutral + if self.state["id"] == 0: # Neutral self.neutral_behaviour() - elif self.state['id'] == 1: # Infected + elif self.state["id"] == 1: # Infected self.infected_behaviour() - elif self.state['id'] == 2: # Cured + elif self.state["id"] == 2: # Cured self.cured_behaviour() - elif self.state['id'] == 3: # Vaccinated + elif self.state["id"] == 3: # Vaccinated self.vaccinated_behaviour() def neutral_behaviour(self): @@ -61,7 +74,7 @@ class SpreadModelM2(BaseAgent): infected_neighbors = self.get_neighboring_agents(state_id=1) if len(infected_neighbors) > 0: if self.prob(self.prob_neutral_making_denier): - self.state['id'] = 3 # Vaccinated making denier + self.state["id"] = 3 # Vaccinated making denier def infected_behaviour(self): @@ -69,7 +82,7 @@ class SpreadModelM2(BaseAgent): neutral_neighbors = self.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_infect): - neighbor.state['id'] = 1 # Infected + neighbor.state["id"] = 1 # Infected def cured_behaviour(self): @@ -77,13 +90,13 @@ class SpreadModelM2(BaseAgent): neutral_neighbors = self.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): - neighbor.state['id'] = 3 # Vaccinated + neighbor.state["id"] = 3 # Vaccinated # Cure infected_neighbors = self.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured def vaccinated_behaviour(self): @@ -91,19 +104,19 @@ class SpreadModelM2(BaseAgent): infected_neighbors = self.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured # Vaccinate neutral_neighbors = self.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): - neighbor.state['id'] = 3 # Vaccinated + neighbor.state["id"] = 3 # Vaccinated # Generate anti-rumor infected_neighbors_2 = self.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors_2: if self.prob(self.prob_generate_anti_rumor): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured class ControlModelM2(BaseAgent): @@ -112,63 +125,76 @@ class ControlModelM2(BaseAgent): prob_neutral_making_denier prob_infect - + prob_cured_healing_infected - + prob_cured_vaccinate_neutral - + prob_vaccinated_healing_infected - + prob_vaccinated_vaccinate_neutral - + prob_generate_anti_rumor """ - def __init__(self, model=None, unique_id=0, state=()): super().__init__(model=environment, unique_id=unique_id, state=state) - self.prob_neutral_making_denier = np.random.normal(environment.environment_params['prob_neutral_making_denier'], - environment.environment_params['standard_variance']) + self.prob_neutral_making_denier = np.random.normal( + environment.environment_params["prob_neutral_making_denier"], + environment.environment_params["standard_variance"], + ) - self.prob_infect = np.random.normal(environment.environment_params['prob_infect'], - environment.environment_params['standard_variance']) + self.prob_infect = np.random.normal( + environment.environment_params["prob_infect"], + environment.environment_params["standard_variance"], + ) - self.prob_cured_healing_infected = np.random.normal(environment.environment_params['prob_cured_healing_infected'], - environment.environment_params['standard_variance']) - self.prob_cured_vaccinate_neutral = np.random.normal(environment.environment_params['prob_cured_vaccinate_neutral'], - environment.environment_params['standard_variance']) + self.prob_cured_healing_infected = np.random.normal( + environment.environment_params["prob_cured_healing_infected"], + environment.environment_params["standard_variance"], + ) + self.prob_cured_vaccinate_neutral = np.random.normal( + environment.environment_params["prob_cured_vaccinate_neutral"], + environment.environment_params["standard_variance"], + ) - self.prob_vaccinated_healing_infected = np.random.normal(environment.environment_params['prob_vaccinated_healing_infected'], - environment.environment_params['standard_variance']) - self.prob_vaccinated_vaccinate_neutral = np.random.normal(environment.environment_params['prob_vaccinated_vaccinate_neutral'], - environment.environment_params['standard_variance']) - self.prob_generate_anti_rumor = np.random.normal(environment.environment_params['prob_generate_anti_rumor'], - environment.environment_params['standard_variance']) + self.prob_vaccinated_healing_infected = np.random.normal( + environment.environment_params["prob_vaccinated_healing_infected"], + environment.environment_params["standard_variance"], + ) + self.prob_vaccinated_vaccinate_neutral = np.random.normal( + environment.environment_params["prob_vaccinated_vaccinate_neutral"], + environment.environment_params["standard_variance"], + ) + self.prob_generate_anti_rumor = np.random.normal( + environment.environment_params["prob_generate_anti_rumor"], + environment.environment_params["standard_variance"], + ) def step(self): - if self.state['id'] == 0: # Neutral + if self.state["id"] == 0: # Neutral self.neutral_behaviour() - elif self.state['id'] == 1: # Infected + elif self.state["id"] == 1: # Infected self.infected_behaviour() - elif self.state['id'] == 2: # Cured + elif self.state["id"] == 2: # Cured self.cured_behaviour() - elif self.state['id'] == 3: # Vaccinated + elif self.state["id"] == 3: # Vaccinated self.vaccinated_behaviour() - elif self.state['id'] == 4: # Beacon-off + elif self.state["id"] == 4: # Beacon-off self.beacon_off_behaviour() - elif self.state['id'] == 5: # Beacon-on + elif self.state["id"] == 5: # Beacon-on self.beacon_on_behaviour() def neutral_behaviour(self): - self.state['visible'] = False + self.state["visible"] = False # Infected infected_neighbors = self.get_neighboring_agents(state_id=1) if len(infected_neighbors) > 0: if self.random(self.prob_neutral_making_denier): - self.state['id'] = 3 # Vaccinated making denier + self.state["id"] = 3 # Vaccinated making denier def infected_behaviour(self): @@ -176,69 +202,69 @@ class ControlModelM2(BaseAgent): neutral_neighbors = self.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_infect): - neighbor.state['id'] = 1 # Infected - self.state['visible'] = False + neighbor.state["id"] = 1 # Infected + self.state["visible"] = False def cured_behaviour(self): - self.state['visible'] = True + self.state["visible"] = True # Vaccinate neutral_neighbors = self.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): - neighbor.state['id'] = 3 # Vaccinated + neighbor.state["id"] = 3 # Vaccinated # Cure infected_neighbors = self.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured def vaccinated_behaviour(self): - self.state['visible'] = True + self.state["visible"] = True # Cure infected_neighbors = self.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured # Vaccinate neutral_neighbors = self.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): - neighbor.state['id'] = 3 # Vaccinated + neighbor.state["id"] = 3 # Vaccinated # Generate anti-rumor infected_neighbors_2 = self.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors_2: if self.prob(self.prob_generate_anti_rumor): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured def beacon_off_behaviour(self): - self.state['visible'] = False + self.state["visible"] = False infected_neighbors = self.get_neighboring_agents(state_id=1) if len(infected_neighbors) > 0: - self.state['id'] == 5 # Beacon on + self.state["id"] == 5 # Beacon on def beacon_on_behaviour(self): - self.state['visible'] = False + self.state["visible"] = False # Cure (M2 feature added) infected_neighbors = self.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_generate_anti_rumor): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured neutral_neighbors_infected = neighbor.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors_infected: if self.prob(self.prob_generate_anti_rumor): - neighbor.state['id'] = 3 # Vaccinated + neighbor.state["id"] = 3 # Vaccinated infected_neighbors_infected = neighbor.get_neighboring_agents(state_id=1) for neighbor in infected_neighbors_infected: if self.prob(self.prob_generate_anti_rumor): - neighbor.state['id'] = 2 # Cured + neighbor.state["id"] = 2 # Cured # Vaccinate neutral_neighbors = self.get_neighboring_agents(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): - neighbor.state['id'] = 3 # Vaccinated + neighbor.state["id"] = 3 # Vaccinated diff --git a/soil/agents/SISaModel.py b/soil/agents/SISaModel.py index fa0d224..e298e8a 100644 --- a/soil/agents/SISaModel.py +++ b/soil/agents/SISaModel.py @@ -6,25 +6,25 @@ class SISaModel(FSM): """ Settings: neutral_discontent_spon_prob - + neutral_discontent_infected_prob - + neutral_content_spon_prob - + neutral_content_infected_prob - + discontent_neutral - + discontent_content - + variance_d_c - + content_discontent - + variance_c_d - + content_neutral - + standard_variance """ @@ -33,24 +33,32 @@ class SISaModel(FSM): random = np.random.default_rng(seed=self._seed) - self.neutral_discontent_spon_prob = random.normal(self.env['neutral_discontent_spon_prob'], - self.env['standard_variance']) - self.neutral_discontent_infected_prob = random.normal(self.env['neutral_discontent_infected_prob'], - self.env['standard_variance']) - self.neutral_content_spon_prob = random.normal(self.env['neutral_content_spon_prob'], - self.env['standard_variance']) - self.neutral_content_infected_prob = random.normal(self.env['neutral_content_infected_prob'], - self.env['standard_variance']) + self.neutral_discontent_spon_prob = random.normal( + self.env["neutral_discontent_spon_prob"], self.env["standard_variance"] + ) + self.neutral_discontent_infected_prob = random.normal( + self.env["neutral_discontent_infected_prob"], self.env["standard_variance"] + ) + self.neutral_content_spon_prob = random.normal( + self.env["neutral_content_spon_prob"], self.env["standard_variance"] + ) + self.neutral_content_infected_prob = random.normal( + self.env["neutral_content_infected_prob"], self.env["standard_variance"] + ) - self.discontent_neutral = random.normal(self.env['discontent_neutral'], - self.env['standard_variance']) - self.discontent_content = random.normal(self.env['discontent_content'], - self.env['variance_d_c']) + self.discontent_neutral = random.normal( + self.env["discontent_neutral"], self.env["standard_variance"] + ) + self.discontent_content = random.normal( + self.env["discontent_content"], self.env["variance_d_c"] + ) - self.content_discontent = random.normal(self.env['content_discontent'], - self.env['variance_c_d']) - self.content_neutral = random.normal(self.env['content_neutral'], - self.env['standard_variance']) + self.content_discontent = random.normal( + self.env["content_discontent"], self.env["variance_c_d"] + ) + self.content_neutral = random.normal( + self.env["content_neutral"], self.env["standard_variance"] + ) @state def neutral(self): @@ -88,7 +96,9 @@ class SISaModel(FSM): return self.neutral # Superinfected - discontent_neighbors = self.count_neighboring_agents(state_id=self.discontent.id) + discontent_neighbors = self.count_neighboring_agents( + state_id=self.discontent.id + ) if self.prob(scontent_neighbors * self.content_discontent): self.discontent return self.content diff --git a/soil/agents/SentimentCorrelationModel.py b/soil/agents/SentimentCorrelationModel.py index 96907aa..721d026 100644 --- a/soil/agents/SentimentCorrelationModel.py +++ b/soil/agents/SentimentCorrelationModel.py @@ -5,27 +5,31 @@ class SentimentCorrelationModel(BaseAgent): """ Settings: outside_effects_prob - + anger_prob - + joy_prob - + sadness_prob - + disgust_prob """ def __init__(self, environment, unique_id=0, state=()): super().__init__(model=environment, unique_id=unique_id, state=state) - self.outside_effects_prob = environment.environment_params['outside_effects_prob'] - self.anger_prob = environment.environment_params['anger_prob'] - self.joy_prob = environment.environment_params['joy_prob'] - self.sadness_prob = environment.environment_params['sadness_prob'] - self.disgust_prob = environment.environment_params['disgust_prob'] - self.state['time_awareness'] = [] + self.outside_effects_prob = environment.environment_params[ + "outside_effects_prob" + ] + self.anger_prob = environment.environment_params["anger_prob"] + self.joy_prob = environment.environment_params["joy_prob"] + self.sadness_prob = environment.environment_params["sadness_prob"] + self.disgust_prob = environment.environment_params["disgust_prob"] + self.state["time_awareness"] = [] for i in range(4): # In this model we have 4 sentiments - self.state['time_awareness'].append(0) # 0-> Anger, 1-> joy, 2->sadness, 3 -> disgust - self.state['sentimentCorrelation'] = 0 + self.state["time_awareness"].append( + 0 + ) # 0-> Anger, 1-> joy, 2->sadness, 3 -> disgust + self.state["sentimentCorrelation"] = 0 def step(self): self.behaviour() @@ -39,63 +43,73 @@ class SentimentCorrelationModel(BaseAgent): angry_neighbors = self.get_neighboring_agents(state_id=1) for x in angry_neighbors: - if x.state['time_awareness'][0] > (self.env.now-500): + if x.state["time_awareness"][0] > (self.env.now - 500): angry_neighbors_1_time_step.append(x) num_neighbors_angry = len(angry_neighbors_1_time_step) joyful_neighbors = self.get_neighboring_agents(state_id=2) for x in joyful_neighbors: - if x.state['time_awareness'][1] > (self.env.now-500): + if x.state["time_awareness"][1] > (self.env.now - 500): joyful_neighbors_1_time_step.append(x) num_neighbors_joyful = len(joyful_neighbors_1_time_step) sad_neighbors = self.get_neighboring_agents(state_id=3) for x in sad_neighbors: - if x.state['time_awareness'][2] > (self.env.now-500): + if x.state["time_awareness"][2] > (self.env.now - 500): sad_neighbors_1_time_step.append(x) num_neighbors_sad = len(sad_neighbors_1_time_step) disgusted_neighbors = self.get_neighboring_agents(state_id=4) for x in disgusted_neighbors: - if x.state['time_awareness'][3] > (self.env.now-500): + if x.state["time_awareness"][3] > (self.env.now - 500): disgusted_neighbors_1_time_step.append(x) num_neighbors_disgusted = len(disgusted_neighbors_1_time_step) - anger_prob = self.anger_prob+(len(angry_neighbors_1_time_step)*self.anger_prob) - joy_prob = self.joy_prob+(len(joyful_neighbors_1_time_step)*self.joy_prob) - sadness_prob = self.sadness_prob+(len(sad_neighbors_1_time_step)*self.sadness_prob) - disgust_prob = self.disgust_prob+(len(disgusted_neighbors_1_time_step)*self.disgust_prob) + anger_prob = self.anger_prob + ( + len(angry_neighbors_1_time_step) * self.anger_prob + ) + joy_prob = self.joy_prob + (len(joyful_neighbors_1_time_step) * self.joy_prob) + sadness_prob = self.sadness_prob + ( + len(sad_neighbors_1_time_step) * self.sadness_prob + ) + disgust_prob = self.disgust_prob + ( + len(disgusted_neighbors_1_time_step) * self.disgust_prob + ) outside_effects_prob = self.outside_effects_prob num = self.random.random() - if num anger_prob: - self.state['id'] = 1 - self.state['sentimentCorrelation'] = 1 - self.state['time_awareness'][self.state['id']-1] = self.env.now - elif (numanger_prob): + self.state["id"] = 2 + self.state["sentimentCorrelation"] = 2 + self.state["time_awareness"][self.state["id"] - 1] = self.env.now + elif num < sadness_prob + anger_prob + joy_prob and num > joy_prob + anger_prob: - self.state['id'] = 2 - self.state['sentimentCorrelation'] = 2 - self.state['time_awareness'][self.state['id']-1] = self.env.now - elif (numjoy_prob+anger_prob): + self.state["id"] = 3 + self.state["sentimentCorrelation"] = 3 + self.state["time_awareness"][self.state["id"] - 1] = self.env.now + elif ( + num < disgust_prob + sadness_prob + anger_prob + joy_prob + and num > sadness_prob + anger_prob + joy_prob + ): - self.state['id'] = 3 - self.state['sentimentCorrelation'] = 3 - self.state['time_awareness'][self.state['id']-1] = self.env.now - elif (numsadness_prob+anger_prob+joy_prob): + self.state["id"] = 4 + self.state["sentimentCorrelation"] = 4 + self.state["time_awareness"][self.state["id"] - 1] = self.env.now - self.state['id'] = 4 - self.state['sentimentCorrelation'] = 4 - self.state['time_awareness'][self.state['id']-1] = self.env.now - - self.state['sentiment'] = self.state['id'] + self.state["sentiment"] = self.state["id"] diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index ad3e4a7..c284604 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -20,13 +20,13 @@ from typing import Dict, List from .. import serialization, utils, time, config - def as_node(agent): if isinstance(agent, BaseAgent): return agent.id return agent -IGNORED_FIELDS = ('model', 'logger') + +IGNORED_FIELDS = ("model", "logger") class DeadAgent(Exception): @@ -43,13 +43,18 @@ class MetaAgent(ABCMeta): defaults.update(i._defaults) new_nmspc = { - '_defaults': defaults, + "_defaults": defaults, } for attr, func in namespace.items(): - if isinstance(func, types.FunctionType) or isinstance(func, property) or isinstance(func, classmethod) or attr[0] == '_': + if ( + isinstance(func, types.FunctionType) + or isinstance(func, property) + or isinstance(func, classmethod) + or attr[0] == "_" + ): new_nmspc[attr] = func - elif attr == 'defaults': + elif attr == "defaults": defaults.update(func) else: defaults[attr] = copy(func) @@ -69,12 +74,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): Any attribute that is not preceded by an underscore (`_`) will also be added to its state. """ - def __init__(self, - unique_id, - model, - name=None, - interval=None, - **kwargs): + def __init__(self, unique_id, model, name=None, interval=None, **kwargs): # Check for REQUIRED arguments # Initialize agent parameters if isinstance(unique_id, MesaAgent): @@ -82,16 +82,19 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): assert isinstance(unique_id, int) super().__init__(unique_id=unique_id, model=model) - self.name = str(name) if name else'{}[{}]'.format(type(self).__name__, self.unique_id) - + self.name = ( + str(name) if name else "{}[{}]".format(type(self).__name__, self.unique_id) + ) self.alive = True - self.interval = interval or self.get('interval', 1) - logger = utils.logger.getChild(getattr(self.model, 'id', self.model)).getChild(self.name) - self.logger = logging.LoggerAdapter(logger, {'agent_name': self.name}) + self.interval = interval or self.get("interval", 1) + logger = utils.logger.getChild(getattr(self.model, "id", self.model)).getChild( + self.name + ) + self.logger = logging.LoggerAdapter(logger, {"agent_name": self.name}) - if hasattr(self, 'level'): + if hasattr(self, "level"): self.logger.setLevel(self.level) for (k, v) in self._defaults.items(): @@ -117,20 +120,22 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): def from_dict(cls, model, attrs, warn_extra=True): ignored = {} args = {} - for k, v in attrs.items(): + for k, v in attrs.items(): if k in inspect.signature(cls).parameters: args[k] = v else: ignored[k] = v if ignored and warn_extra: - utils.logger.info(f'Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }') + utils.logger.info( + f"Ignoring the following arguments for agent class { agent_class.__name__ }: { ignored }" + ) return cls(model=model, **args) def __getitem__(self, key): try: return getattr(self, key) except AttributeError: - raise KeyError(f'key {key} not found in agent') + raise KeyError(f"key {key} not found in agent") def __delitem__(self, key): return delattr(self, key) @@ -148,7 +153,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): return self.items() def keys(self): - return (k for k in self.__dict__ if k[0] != '_' and k not in IGNORED_FIELDS) + return (k for k in self.__dict__ if k[0] != "_" and k not in IGNORED_FIELDS) def items(self, keys=None, skip=None): keys = keys if keys is not None else self.keys() @@ -156,7 +161,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): if skip: return filter(lambda x: x[0] not in skip, it) return it - + def get(self, key, default=None): return self[key] if key in self else default @@ -169,7 +174,7 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): return None def die(self): - self.info(f'agent dying') + self.info(f"agent dying") self.alive = False return time.NEVER @@ -186,9 +191,9 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): for k, v in kwargs: message += " {k}={v} ".format(k, v) extra = {} - extra['now'] = self.now - extra['unique_id'] = self.unique_id - extra['agent_name'] = self.name + extra["now"] = self.now + extra["unique_id"] = self.unique_id + extra["agent_name"] = self.name return self.logger.log(level, message, extra=extra) def debug(self, *args, **kwargs): @@ -214,10 +219,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): content = dict(self.items(keys=keys)) if pretty and content: d = content - content = '\n' + content = "\n" for k, v in d.items(): - content += f'- {k}: {v}\n' - content = textwrap.indent(content, ' ') + content += f"- {k}: {v}\n" + content = textwrap.indent(content, " ") return f"{repr(self)}{content}" def __repr__(self): @@ -225,7 +230,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): class NetworkAgent(BaseAgent): - def __init__(self, *args, topology, node_id, **kwargs): super().__init__(*args, **kwargs) @@ -248,18 +252,21 @@ class NetworkAgent(BaseAgent): def node(self): return self.topology.nodes[self.node_id] - def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs): unique_ids = None if isinstance(unique_id, list): unique_ids = set(unique_id) elif unique_id is not None: - unique_ids = set([unique_id,]) + unique_ids = set( + [ + unique_id, + ] + ) if limit_neighbors: neighbor_ids = set() for node_id in self.G.neighbors(self.node_id): - if self.G.nodes[node_id].get('agent') is not None: + if self.G.nodes[node_id].get("agent") is not None: neighbor_ids.add(node_id) if unique_ids: unique_ids = unique_ids & neighbor_ids @@ -272,7 +279,9 @@ class NetworkAgent(BaseAgent): def subgraph(self, center=True, **kwargs): include = [self] if center else [] - G = self.G.subgraph(n.node_id for n in list(self.get_agents(**kwargs)+include)) + G = self.G.subgraph( + n.node_id for n in list(self.get_agents(**kwargs) + include) + ) return G def remove_node(self): @@ -280,11 +289,19 @@ class NetworkAgent(BaseAgent): def add_edge(self, other, edge_attr_dict=None, *edge_attrs): if self.node_id not in self.G.nodes(data=False): - raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id)) + raise ValueError( + "{} not in list of existing agents in the network".format( + self.unique_id + ) + ) if other.node_id not in self.G.nodes(data=False): - raise ValueError('{} not in list of existing agents in the network'.format(other)) + raise ValueError( + "{} not in list of existing agents in the network".format(other) + ) - self.G.add_edge(self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs) + self.G.add_edge( + self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs + ) def die(self, remove=True): if remove: @@ -294,11 +311,11 @@ class NetworkAgent(BaseAgent): def state(name=None): def decorator(func, name=None): - ''' + """ A state function should return either a state id, or a tuple (state_id, when) The default value for state_id is the current state id. The default value for when is the interval defined in the environment. - ''' + """ if inspect.isgeneratorfunction(func): orig_func = func @@ -348,32 +365,38 @@ class MetaFSM(MetaAgent): # Add new states for attr, func in namespace.items(): - if hasattr(func, 'id'): + if hasattr(func, "id"): if func.is_default: default_state = func states[func.id] = func - namespace.update({ - '_default_state': default_state, - '_states': states, - }) + namespace.update( + { + "_default_state": default_state, + "_states": states, + } + ) - return super(MetaFSM, mcls).__new__(mcls=mcls, name=name, bases=bases, namespace=namespace) + return super(MetaFSM, mcls).__new__( + mcls=mcls, name=name, bases=bases, namespace=namespace + ) class FSM(BaseAgent, metaclass=MetaFSM): def __init__(self, *args, **kwargs): super(FSM, self).__init__(*args, **kwargs) - if not hasattr(self, 'state_id'): + if not hasattr(self, "state_id"): if not self._default_state: - raise ValueError('No default state specified for {}'.format(self.unique_id)) + raise ValueError( + "No default state specified for {}".format(self.unique_id) + ) self.state_id = self._default_state.id self._coroutine = None self.set_state(self.state_id) def step(self): - self.debug(f'Agent {self.unique_id} @ state {self.state_id}') + self.debug(f"Agent {self.unique_id} @ state {self.state_id}") default_interval = super().step() next_state = self._states[self.state_id](self) @@ -386,7 +409,9 @@ class FSM(BaseAgent, metaclass=MetaFSM): elif len(when) == 1: when = when[0] else: - raise ValueError('Too many values returned. Only state (and time) allowed') + raise ValueError( + "Too many values returned. Only state (and time) allowed" + ) except TypeError: pass @@ -396,10 +421,10 @@ class FSM(BaseAgent, metaclass=MetaFSM): return when or default_interval def set_state(self, state, when=None): - if hasattr(state, 'id'): + if hasattr(state, "id"): state = state.id if state not in self._states: - raise ValueError('{} is not a valid state'.format(state)) + raise ValueError("{} is not a valid state".format(state)) self.state_id = state if when is not None: self.model.schedule.add(self, when=when) @@ -414,23 +439,22 @@ class FSM(BaseAgent, metaclass=MetaFSM): def prob(prob, random): - ''' + """ A true/False uniform distribution with a given probability. To be used like this: .. code-block:: python - + if prob(0.3): do_something() - ''' + """ r = random.random() return r < prob -def calculate_distribution(network_agents=None, - agent_class=None): - ''' +def calculate_distribution(network_agents=None, agent_class=None): + """ Calculate the threshold values (thresholds for a uniform distribution) of an agent distribution given the weights of each agent type. @@ -453,26 +477,28 @@ def calculate_distribution(network_agents=None, In this example, 20% of the nodes will be marked as type 'agent_class_1'. - ''' + """ if network_agents: - network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')] + network_agents = [ + deepcopy(agent) for agent in network_agents if not hasattr(agent, "id") + ] elif agent_class: - network_agents = [{'agent_class': agent_class}] + network_agents = [{"agent_class": agent_class}] else: - raise ValueError('Specify a distribution or a default agent type') + raise ValueError("Specify a distribution or a default agent type") # Fix missing weights and incompatible types for x in network_agents: - x['weight'] = float(x.get('weight', 1)) + x["weight"] = float(x.get("weight", 1)) # Calculate the thresholds - total = sum(x['weight'] for x in network_agents) + total = sum(x["weight"] for x in network_agents) acc = 0 for v in network_agents: - if 'ids' in v: + if "ids" in v: continue - upper = acc + (v['weight']/total) - v['threshold'] = [acc, upper] + upper = acc + (v["weight"] / total) + v["threshold"] = [acc, upper] acc = upper return network_agents @@ -480,28 +506,29 @@ def calculate_distribution(network_agents=None, def serialize_type(agent_class, known_modules=[], **kwargs): if isinstance(agent_class, str): return agent_class - known_modules += ['soil.agents'] - return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[1] # Get the name of the class + known_modules += ["soil.agents"] + return serialization.serialize(agent_class, known_modules=known_modules, **kwargs)[ + 1 + ] # Get the name of the class def serialize_definition(network_agents, known_modules=[]): - ''' + """ When serializing an agent distribution, remove the thresholds, in order to avoid cluttering the YAML definition file. - ''' + """ d = deepcopy(list(network_agents)) for v in d: - if 'threshold' in v: - del v['threshold'] - v['agent_class'] = serialize_type(v['agent_class'], - known_modules=known_modules) + if "threshold" in v: + del v["threshold"] + v["agent_class"] = serialize_type(v["agent_class"], known_modules=known_modules) return d def deserialize_type(agent_class, known_modules=[]): if not isinstance(agent_class, str): return agent_class - known = known_modules + ['soil.agents', 'soil.agents.custom' ] + known = known_modules + ["soil.agents", "soil.agents.custom"] agent_class = serialization.deserializer(agent_class, known_modules=known) return agent_class @@ -509,12 +536,12 @@ def deserialize_type(agent_class, known_modules=[]): def deserialize_definition(ind, **kwargs): d = deepcopy(ind) for v in d: - v['agent_class'] = deserialize_type(v['agent_class'], **kwargs) + v["agent_class"] = deserialize_type(v["agent_class"], **kwargs) return d def _validate_states(states, topology): - '''Validate states to avoid ignoring states during initialization''' + """Validate states to avoid ignoring states during initialization""" states = states or [] if isinstance(states, dict): for x in states: @@ -525,7 +552,7 @@ def _validate_states(states, topology): def _convert_agent_classs(ind, to_string=False, **kwargs): - '''Convenience method to allow specifying agents by class or class name.''' + """Convenience method to allow specifying agents by class or class name.""" if to_string: return serialize_definition(ind, **kwargs) return deserialize_definition(ind, **kwargs) @@ -609,12 +636,10 @@ def _convert_agent_classs(ind, to_string=False, **kwargs): class AgentView(Mapping, Set): - """A lazy-loaded list of agents. - """ + """A lazy-loaded list of agents.""" __slots__ = ("_agents",) - def __init__(self, agents): self._agents = agents @@ -657,11 +682,20 @@ class AgentView(Mapping, Set): return f"{self.__class__.__name__}({self})" -def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=None, ignore=None, state=None, - limit=None, **kwargs): - ''' +def filter_agents( + agents, + *id_args, + unique_id=None, + state_id=None, + agent_class=None, + ignore=None, + state=None, + limit=None, + **kwargs, +): + """ Filter agents given as a dict, by the criteria given as arguments (e.g., certain type or state id). - ''' + """ assert isinstance(agents, dict) ids = [] @@ -694,7 +728,7 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N f = filter(lambda x: x not in ignore, f) if state_id is not None: - f = filter(lambda agent: agent.get('state_id', None) in state_id, f) + f = filter(lambda agent: agent.get("state_id", None) in state_id, f) if agent_class is not None: f = filter(lambda agent: isinstance(agent, agent_class), f) @@ -711,23 +745,25 @@ def filter_agents(agents, *id_args, unique_id=None, state_id=None, agent_class=N yield from f -def from_config(cfg: config.AgentConfig, random, topology: nx.Graph = None) -> List[Dict[str, Any]]: - ''' +def from_config( + cfg: config.AgentConfig, random, topology: nx.Graph = None +) -> List[Dict[str, Any]]: + """ This function turns an agentconfig into a list of individual "agent specifications", which are just a dictionary with the parameters that the environment will use to construct each agent. This function does NOT return a list of agents, mostly because some attributes to the agent are not known at the time of calling this function, such as `unique_id`. - ''' + """ default = cfg or config.AgentConfig() if not isinstance(cfg, config.AgentConfig): cfg = config.AgentConfig(**cfg) return _agents_from_config(cfg, topology=topology, random=random) -def _agents_from_config(cfg: config.AgentConfig, - topology: nx.Graph, - random) -> List[Dict[str, Any]]: +def _agents_from_config( + cfg: config.AgentConfig, topology: nx.Graph, random +) -> List[Dict[str, Any]]: if cfg and not isinstance(cfg, config.AgentConfig): cfg = config.AgentConfig(**cfg) @@ -737,7 +773,9 @@ def _agents_from_config(cfg: config.AgentConfig, assigned_network = 0 if cfg.fixed is not None: - agents, assigned_total, assigned_network = _from_fixed(cfg.fixed, topology=cfg.topology, default=cfg) + agents, assigned_total, assigned_network = _from_fixed( + cfg.fixed, topology=cfg.topology, default=cfg + ) n = cfg.n @@ -749,46 +787,56 @@ def _agents_from_config(cfg: config.AgentConfig, for d in cfg.distribution: if d.strategy == config.Strategy.topology: - topo = d.topology if ('topology' in d.__fields_set__) else cfg.topology + topo = d.topology if ("topology" in d.__fields_set__) else cfg.topology if not topo: - raise ValueError('The "topology" strategy only works if the topology parameter is set to True') + raise ValueError( + 'The "topology" strategy only works if the topology parameter is set to True' + ) if not topo_size: - raise ValueError(f'Topology does not have enough free nodes to assign one to the agent') + raise ValueError( + f"Topology does not have enough free nodes to assign one to the agent" + ) networked.append(d) if d.strategy == config.Strategy.total: if not cfg.n: - raise ValueError('Cannot use the "total" strategy without providing the total number of agents') + raise ValueError( + 'Cannot use the "total" strategy without providing the total number of agents' + ) total.append(d) - if networked: - new_agents = _from_distro(networked, - n= topo_size - assigned_network, - topology=topo, - default=cfg, - random=random) + new_agents = _from_distro( + networked, + n=topo_size - assigned_network, + topology=topo, + default=cfg, + random=random, + ) assigned_total += len(new_agents) assigned_network += len(new_agents) agents += new_agents if total: - remaining = n - assigned_total - agents += _from_distro(total, n=remaining, - default=cfg, - random=random) - + remaining = n - assigned_total + agents += _from_distro(total, n=remaining, default=cfg, random=random) if assigned_network < topo_size: - utils.logger.warn(f'The total number of agents does not match the total number of nodes in ' - 'every topology. This may be due to a definition error: assigned: ' - f'{ assigned } total size: { topo_size }') + utils.logger.warn( + f"The total number of agents does not match the total number of nodes in " + "every topology. This may be due to a definition error: assigned: " + f"{ assigned } total size: { topo_size }" + ) return agents -def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: config.SingleAgentConfig) -> List[Dict[str, Any]]: +def _from_fixed( + lst: List[config.FixedAgentConfig], + topology: bool, + default: config.SingleAgentConfig, +) -> List[Dict[str, Any]]: agents = [] counts_total = 0 @@ -799,12 +847,18 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con if default: agent = default.state.copy() agent.update(fixed.state) - cls = serialization.deserialize(fixed.agent_class or (default and default.agent_class)) - agent['agent_class'] = cls - topo = fixed.topology if ('topology' in fixed.__fields_set__) else topology or default.topology + cls = serialization.deserialize( + fixed.agent_class or (default and default.agent_class) + ) + agent["agent_class"] = cls + topo = ( + fixed.topology + if ("topology" in fixed.__fields_set__) + else topology or default.topology + ) if topo: - agent['topology'] = True + agent["topology"] = True counts_network += 1 if not fixed.hidden: counts_total += 1 @@ -813,17 +867,21 @@ def _from_fixed(lst: List[config.FixedAgentConfig], topology: bool, default: con return agents, counts_total, counts_network -def _from_distro(distro: List[config.AgentDistro], - n: int, - topology: str, - default: config.SingleAgentConfig, - random) -> List[Dict[str, Any]]: +def _from_distro( + distro: List[config.AgentDistro], + n: int, + topology: str, + default: config.SingleAgentConfig, + random, +) -> List[Dict[str, Any]]: agents = [] if n is None: if any(lambda dist: dist.n is None, distro): - raise ValueError('You must provide a total number of agents, or the number of each type') + raise ValueError( + "You must provide a total number of agents, or the number of each type" + ) n = sum(dist.n for dist in distro) weights = list(dist.weight if dist.weight is not None else 1 for dist in distro) @@ -836,29 +894,40 @@ def _from_distro(distro: List[config.AgentDistro], # So instead we calculate our own distribution to make sure the actual ratios are close to what we would expect # Calculate how many times each has to appear - indices = list(chain.from_iterable([idx] * int(n*chunk) for (idx, n) in enumerate(norm))) + indices = list( + chain.from_iterable([idx] * int(n * chunk) for (idx, n) in enumerate(norm)) + ) # Complete with random agents following the original weight distribution if len(indices) < n: - indices += random.choices(list(range(len(distro))), weights=[d.weight for d in distro], k=n-len(indices)) + indices += random.choices( + list(range(len(distro))), + weights=[d.weight for d in distro], + k=n - len(indices), + ) # Deserialize classes for efficiency - classes = list(serialization.deserialize(i.agent_class or default.agent_class) for i in distro) + classes = list( + serialization.deserialize(i.agent_class or default.agent_class) for i in distro + ) # Add them in random order random.shuffle(indices) - for idx in indices: d = distro[idx] agent = d.state.copy() cls = classes[idx] - agent['agent_class'] = cls + agent["agent_class"] = cls if default: agent.update(default.state) - topology = d.topology if ('topology' in d.__fields_set__) else topology or default.topology + topology = ( + d.topology + if ("topology" in d.__fields_set__) + else topology or default.topology + ) if topology: - agent['topology'] = topology + agent["topology"] = topology agents.append(agent) return agents @@ -877,4 +946,5 @@ try: from .Geo import Geo except ImportError: import sys - print('Could not load the Geo Agent, scipy is not installed', file=sys.stderr) + + print("Could not load the Geo Agent, scipy is not installed", file=sys.stderr) diff --git a/soil/config.py b/soil/config.py index 7b39154..8dbbffa 100644 --- a/soil/config.py +++ b/soil/config.py @@ -19,6 +19,7 @@ import networkx as nx # Could use TypeAlias in python >= 3.10 nodeId = int + class Node(BaseModel): id: nodeId state: Optional[Dict[str, Any]] = {} @@ -38,7 +39,7 @@ class Topology(BaseModel): class NetParams(BaseModel, extra=Extra.allow): generator: Union[Callable, str] - n: int + n: int class NetConfig(BaseModel): @@ -54,14 +55,15 @@ class NetConfig(BaseModel): return NetConfig(topology=None, params=None) @root_validator - def validate_all(cls, values): - if 'params' not in values and 'topology' not in values: - raise ValueError('You must specify either a topology or the parameters to generate a graph') + def validate_all(cls, values): + if "params" not in values and "topology" not in values: + raise ValueError( + "You must specify either a topology or the parameters to generate a graph" + ) return values class EnvConfig(BaseModel): - @staticmethod def default(): return EnvConfig() @@ -80,9 +82,11 @@ class FixedAgentConfig(SingleAgentConfig): hidden: Optional[bool] = False # Do not count this agent towards total agent count @root_validator - def validate_all(cls, values): - if values.get('unique_id', None) is not None and values.get('n', 1) > 1: - raise ValueError(f"An unique_id can only be provided when there is only one agent ({values.get('n')} given)") + def validate_all(cls, values): + if values.get("unique_id", None) is not None and values.get("n", 1) > 1: + raise ValueError( + f"An unique_id can only be provided when there is only one agent ({values.get('n')} given)" + ) return values @@ -91,8 +95,8 @@ class OverrideAgentConfig(FixedAgentConfig): class Strategy(Enum): - topology = 'topology' - total = 'total' + topology = "topology" + total = "total" class AgentDistro(SingleAgentConfig): @@ -111,16 +115,20 @@ class AgentConfig(SingleAgentConfig): return AgentConfig() @root_validator - def validate_all(cls, values): - if 'distribution' in values and ('n' not in values and 'topology' not in values): - raise ValueError("You need to provide the number of agents or a topology to extract the value from.") + def validate_all(cls, values): + if "distribution" in values and ( + "n" not in values and "topology" not in values + ): + raise ValueError( + "You need to provide the number of agents or a topology to extract the value from." + ) return values class Config(BaseModel, extra=Extra.allow): - version: Optional[str] = '1' + version: Optional[str] = "1" - name: str = 'Unnamed Simulation' + name: str = "Unnamed Simulation" description: Optional[str] = None group: str = None dir_path: Optional[str] = None @@ -140,45 +148,48 @@ class Config(BaseModel, extra=Extra.allow): def from_raw(cls, cfg): if isinstance(cfg, Config): return cfg - if cfg.get('version', '1') == '1' and any(k in cfg for k in ['agents', 'agent_class', 'topology', 'environment_class']): + if cfg.get("version", "1") == "1" and any( + k in cfg for k in ["agents", "agent_class", "topology", "environment_class"] + ): return convert_old(cfg) return Config(**cfg) def convert_old(old, strict=True): - ''' + """ Try to convert old style configs into the new format. This is still a work in progress and might not work in many cases. - ''' + """ - utils.logger.warning('The old configuration format is deprecated. The converted file MAY NOT yield the right results') + utils.logger.warning( + "The old configuration format is deprecated. The converted file MAY NOT yield the right results" + ) new = old.copy() network = {} - if 'topology' in old: - del new['topology'] - network['topology'] = old['topology'] + if "topology" in old: + del new["topology"] + network["topology"] = old["topology"] - if 'network_params' in old and old['network_params']: - del new['network_params'] - for (k, v) in old['network_params'].items(): - if k == 'path': - network['path'] = v + if "network_params" in old and old["network_params"]: + del new["network_params"] + for (k, v) in old["network_params"].items(): + if k == "path": + network["path"] = v else: - network.setdefault('params', {})[k] = v + network.setdefault("params", {})[k] = v topology = None if network: topology = network - - agents = {'fixed': [], 'distribution': []} + agents = {"fixed": [], "distribution": []} def updated_agent(agent): - '''Convert an agent definition''' + """Convert an agent definition""" newagent = dict(agent) return newagent @@ -186,80 +197,74 @@ def convert_old(old, strict=True): fixed = [] override = [] - if 'environment_agents' in new: + if "environment_agents" in new: - for agent in new['environment_agents']: - agent.setdefault('state', {})['group'] = 'environment' - if 'agent_id' in agent: - agent['state']['name'] = agent['agent_id'] - del agent['agent_id'] - agent['hidden'] = True - agent['topology'] = False + for agent in new["environment_agents"]: + agent.setdefault("state", {})["group"] = "environment" + if "agent_id" in agent: + agent["state"]["name"] = agent["agent_id"] + del agent["agent_id"] + agent["hidden"] = True + agent["topology"] = False fixed.append(updated_agent(agent)) - del new['environment_agents'] + del new["environment_agents"] + if "agent_class" in old: + del new["agent_class"] + agents["agent_class"] = old["agent_class"] - if 'agent_class' in old: - del new['agent_class'] - agents['agent_class'] = old['agent_class'] + if "default_state" in old: + del new["default_state"] + agents["state"] = old["default_state"] - if 'default_state' in old: - del new['default_state'] - agents['state'] = old['default_state'] + if "network_agents" in old: + agents["topology"] = True - if 'network_agents' in old: - agents['topology'] = True + agents.setdefault("state", {})["group"] = "network" - agents.setdefault('state', {})['group'] = 'network' - - for agent in new['network_agents']: + for agent in new["network_agents"]: agent = updated_agent(agent) - if 'agent_id' in agent: - agent['state']['name'] = agent['agent_id'] - del agent['agent_id'] + if "agent_id" in agent: + agent["state"]["name"] = agent["agent_id"] + del agent["agent_id"] fixed.append(agent) else: by_weight.append(agent) - del new['network_agents'] + del new["network_agents"] - if 'agent_class' in old and (not fixed and not by_weight): - agents['topology'] = True - by_weight = [{'agent_class': old['agent_class'], 'weight': 1}] + if "agent_class" in old and (not fixed and not by_weight): + agents["topology"] = True + by_weight = [{"agent_class": old["agent_class"], "weight": 1}] - # TODO: translate states properly - if 'states' in old: - del new['states'] - states = old['states'] + if "states" in old: + del new["states"] + states = old["states"] if isinstance(states, dict): states = states.items() else: states = enumerate(states) for (k, v) in states: - override.append({'filter': {'node_id': k}, - 'state': v}) - - agents['override'] = override - agents['fixed'] = fixed - agents['distribution'] = by_weight + override.append({"filter": {"node_id": k}, "state": v}) + agents["override"] = override + agents["fixed"] = fixed + agents["distribution"] = by_weight model_params = {} - if 'environment_params' in new: - del new['environment_params'] - model_params = dict(old['environment_params']) + if "environment_params" in new: + del new["environment_params"] + model_params = dict(old["environment_params"]) - if 'environment_class' in old: - del new['environment_class'] - new['model_class'] = old['environment_class'] + if "environment_class" in old: + del new["environment_class"] + new["model_class"] = old["environment_class"] - if 'dump' in old: - del new['dump'] - new['dry_run'] = not old['dump'] + if "dump" in old: + del new["dump"] + new["dry_run"] = not old["dump"] - model_params['topology'] = topology - model_params['agents'] = agents + model_params["topology"] = topology + model_params["agents"] = agents - return Config(version='2', - model_params=model_params, - **new) + return Config(version="2", model_params=model_params, **new) diff --git a/soil/datacollection.py b/soil/datacollection.py index a889a76..dea9f1d 100644 --- a/soil/datacollection.py +++ b/soil/datacollection.py @@ -1,6 +1,6 @@ from mesa import DataCollector as MDC -class SoilDataCollector(MDC): +class SoilDataCollector(MDC): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/soil/debugging.py b/soil/debugging.py index 863c50a..607996b 100644 --- a/soil/debugging.py +++ b/soil/debugging.py @@ -18,9 +18,9 @@ def wrapcmd(func): known = globals() known.update(self.curframe.f_globals) known.update(self.curframe.f_locals) - known['agent'] = known.get('self', None) - known['model'] = known.get('self', {}).get('model') - known['attrs'] = arg.strip().split() + known["agent"] = known.get("self", None) + known["model"] = known.get("self", {}).get("model") + known["attrs"] = arg.strip().split() exec(func.__code__, known, known) @@ -29,12 +29,12 @@ def wrapcmd(func): class Debug(pdb.Pdb): def __init__(self, *args, skip_soil=False, **kwargs): - skip = kwargs.get('skip', []) - skip.append('soil') + skip = kwargs.get("skip", []) + skip.append("soil") if skip_soil: - skip.append('soil') - skip.append('soil.*') - skip.append('mesa.*') + skip.append("soil") + skip.append("soil.*") + skip.append("mesa.*") super(Debug, self).__init__(*args, skip=skip, **kwargs) self.prompt = "[soil-pdb] " @@ -42,7 +42,7 @@ class Debug(pdb.Pdb): def _soil_agents(model, attrs=None, pretty=True, **kwargs): for agent in model.agents(**kwargs): d = agent - print(' - ' + indent(agent.to_str(keys=attrs, pretty=pretty), ' ')) + print(" - " + indent(agent.to_str(keys=attrs, pretty=pretty), " ")) @wrapcmd def do_soil_agents(): @@ -52,20 +52,20 @@ class Debug(pdb.Pdb): @wrapcmd def do_soil_list(): - return Debug._soil_agents(model, attrs=['state_id'], pretty=False) + return Debug._soil_agents(model, attrs=["state_id"], pretty=False) do_sl = do_soil_list def do_continue_state(self, arg): self.do_break_state(arg, temporary=True) - return self.do_continue('') + return self.do_continue("") do_cs = do_continue_state @wrapcmd def do_soil_agent(): if not agent: - print('No agent available') + print("No agent available") return keys = None @@ -81,9 +81,9 @@ class Debug(pdb.Pdb): do_aa = do_soil_agent def do_break_state(self, arg: str, instances=None, temporary=False): - ''' + """ Break before a specified state is stepped into. - ''' + """ klass = None state = arg @@ -95,39 +95,39 @@ class Debug(pdb.Pdb): if tokens: instances = list(eval(token) for token in tokens) - colon = state.find(':') + colon = state.find(":") if colon > 0: klass = state[:colon].rstrip() - state = state[colon+1:].strip() - + state = state[colon + 1 :].strip() print(klass, state, tokens) - klass = eval(klass, - self.curframe.f_globals, - self.curframe_locals) + klass = eval(klass, self.curframe.f_globals, self.curframe_locals) if klass: klasses = [klass] else: - klasses = [k for k in self.curframe.f_globals.values() if isinstance(k, type) and issubclass(k, FSM)] + klasses = [ + k + for k in self.curframe.f_globals.values() + if isinstance(k, type) and issubclass(k, FSM) + ] if not klasses: - self.error('No agent classes found') - - + self.error("No agent classes found") + for klass in klasses: try: func = getattr(klass, state) except AttributeError: - self.error(f'State {state} not found in class {klass}') + self.error(f"State {state} not found in class {klass}") continue - if hasattr(func, '__func__'): + if hasattr(func, "__func__"): func = func.__func__ code = func.__code__ - #use co_name to identify the bkpt (function names - #could be aliased, but co_name is invariant) + # use co_name to identify the bkpt (function names + # could be aliased, but co_name is invariant) funcname = code.co_name lineno = code.co_firstlineno filename = code.co_filename @@ -135,38 +135,36 @@ class Debug(pdb.Pdb): # Check for reasonable breakpoint line = self.checkline(filename, lineno) if not line: - raise ValueError('no line found') + raise ValueError("no line found") # now set the break point cond = None if instances: - cond = f'self.unique_id in { repr(instances) }' + cond = f"self.unique_id in { repr(instances) }" existing = self.get_breaks(filename, line) if existing: - self.message("Breakpoint already exists at %s:%d" % - (filename, line)) + self.message("Breakpoint already exists at %s:%d" % (filename, line)) continue err = self.set_break(filename, line, temporary, cond, funcname) if err: self.error(err) else: bp = self.get_breaks(filename, line)[-1] - self.message("Breakpoint %d at %s:%d" % - (bp.number, bp.file, bp.line)) + self.message("Breakpoint %d at %s:%d" % (bp.number, bp.file, bp.line)) do_bs = do_break_state def do_break_state_self(self, arg: str, temporary=False): - ''' + """ Break before a specified state is stepped into, for the current agent - ''' - agent = self.curframe.f_locals.get('self') + """ + agent = self.curframe.f_locals.get("self") if not agent: - self.error('No current agent.') - self.error('Try this again when the debugger is stopped inside an agent') + self.error("No current agent.") + self.error("Try this again when the debugger is stopped inside an agent") return - arg = f'{agent.__class__.__name__}:{ arg } {agent.unique_id}' + arg = f"{agent.__class__.__name__}:{ arg } {agent.unique_id}" return self.do_break_state(arg) do_bss = do_break_state_self @@ -174,6 +172,7 @@ class Debug(pdb.Pdb): debugger = None + def set_trace(frame=None, **kwargs): global debugger if debugger is None: diff --git a/soil/environment.py b/soil/environment.py index 8588eaf..d89585e 100644 --- a/soil/environment.py +++ b/soil/environment.py @@ -35,18 +35,20 @@ class BaseEnvironment(Model): :meth:`soil.environment.Environment.get` method. """ - def __init__(self, - id='unnamed_env', - seed='default', - schedule=None, - dir_path=None, - interval=1, - agent_class=None, - agents: [tuple[type, Dict[str, Any]]] = {}, - agent_reporters: Optional[Any] = None, - model_reporters: Optional[Any] = None, - tables: Optional[Any] = None, - **env_params): + def __init__( + self, + id="unnamed_env", + seed="default", + schedule=None, + dir_path=None, + interval=1, + agent_class=None, + agents: [tuple[type, Dict[str, Any]]] = {}, + agent_reporters: Optional[Any] = None, + model_reporters: Optional[Any] = None, + tables: Optional[Any] = None, + **env_params, + ): super().__init__(seed=seed) self.env_params = env_params or {} @@ -75,27 +77,26 @@ class BaseEnvironment(Model): ) def _agent_from_dict(self, agent): - ''' + """ Translate an agent dictionary into an agent - ''' + """ agent = dict(**agent) - cls = agent.pop('agent_class', None) or self.agent_class - unique_id = agent.pop('unique_id', None) + cls = agent.pop("agent_class", None) or self.agent_class + unique_id = agent.pop("unique_id", None) if unique_id is None: unique_id = self.next_id() - return serialization.deserialize(cls)(unique_id=unique_id, - model=self, **agent) + return serialization.deserialize(cls)(unique_id=unique_id, model=self, **agent) def init_agents(self, agents: Union[config.AgentConfig, [Dict[str, Any]]] = {}): - ''' + """ Initialize the agents in the model from either a `soil.config.AgentConfig` or a list of dictionaries that each describes an agent. If given a list of dictionaries, an agent will be created for each dictionary. The agent class can be specified through the `agent_class` key. The rest of the items will be used as parameters to the agent. - ''' + """ if not agents: return @@ -108,11 +109,10 @@ class BaseEnvironment(Model): override = lst.override lst = self._agent_dict_from_config(lst) - #TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore, + # TODO: check override is working again. It cannot (easily) be part of agents.from_config anymore, # because it needs attribute such as unique_id, which are only present after init new_agents = [self._agent_from_dict(agent) for agent in lst] - for a in new_agents: self.schedule.add(a) @@ -122,8 +122,7 @@ class BaseEnvironment(Model): setattr(agent, attr, value) def _agent_dict_from_config(self, cfg): - return agentmod.from_config(cfg, - random=self.random) + return agentmod.from_config(cfg, random=self.random) @property def agents(self): @@ -131,7 +130,7 @@ class BaseEnvironment(Model): def find_one(self, *args, **kwargs): return agentmod.AgentView(self.schedule._agents).one(*args, **kwargs) - + def count_agents(self, *args, **kwargs): return sum(1 for i in self.agents(*args, **kwargs)) @@ -139,18 +138,16 @@ class BaseEnvironment(Model): def now(self): if self.schedule: return self.schedule.time - raise Exception('The environment has not been scheduled, so it has no sense of time') - + raise Exception( + "The environment has not been scheduled, so it has no sense of time" + ) def add_agent(self, agent_class, unique_id=None, **kwargs): a = None if unique_id is None: unique_id = self.next_id() - - a = agent_class(model=self, - unique_id=unique_id, - **args) + a = agent_class(model=self, unique_id=unique_id, **args) self.schedule.add(a) return a @@ -163,16 +160,16 @@ class BaseEnvironment(Model): for k, v in kwargs: message += " {k}={v} ".format(k, v) extra = {} - extra['now'] = self.now - extra['id'] = self.id + extra["now"] = self.now + extra["id"] = self.id return self.logger.log(level, message, extra=extra) def step(self): - ''' + """ Advance one step in the simulation, and update the data collection and scheduler appropriately - ''' + """ super().step() - self.logger.info(f'--- Step {self.now:^5} ---') + self.logger.info(f"--- Step {self.now:^5} ---") self.schedule.step() self.datacollector.collect(self) @@ -180,10 +177,10 @@ class BaseEnvironment(Model): return key in self.env_params def get(self, key, default=None): - ''' + """ Get the value of an environment attribute. Return `default` if the value is not set. - ''' + """ return self.env_params.get(key, default) def __getitem__(self, key): @@ -197,13 +194,15 @@ class BaseEnvironment(Model): class NetworkEnvironment(BaseEnvironment): - ''' + """ The NetworkEnvironment is an environment that includes one or more networkx.Graph intances and methods to associate agents to nodes and vice versa. - ''' + """ - def __init__(self, *args, topology: Union[config.NetConfig, nx.Graph] = None, **kwargs): - agents = kwargs.pop('agents', None) + def __init__( + self, *args, topology: Union[config.NetConfig, nx.Graph] = None, **kwargs + ): + agents = kwargs.pop("agents", None) super().__init__(*args, agents=None, **kwargs) self._set_topology(topology) @@ -211,37 +210,35 @@ class NetworkEnvironment(BaseEnvironment): self.init_agents(agents) def init_agents(self, *args, **kwargs): - '''Initialize the agents from a ''' + """Initialize the agents from a""" super().init_agents(*args, **kwargs) for agent in self.schedule._agents.values(): - if hasattr(agent, 'node_id'): + if hasattr(agent, "node_id"): self._init_node(agent) def _init_node(self, agent): - ''' + """ Make sure the node for a given agent has the proper attributes. - ''' - self.G.nodes[agent.node_id]['agent'] = agent + """ + self.G.nodes[agent.node_id]["agent"] = agent def _agent_dict_from_config(self, cfg): - return agentmod.from_config(cfg, - topology=self.G, - random=self.random) + return agentmod.from_config(cfg, topology=self.G, random=self.random) def _agent_from_dict(self, agent, unique_id=None): agent = dict(agent) - if not agent.get('topology', False): + if not agent.get("topology", False): return super()._agent_from_dict(agent) if unique_id is None: unique_id = self.next_id() - node_id = agent.get('node_id', None) + node_id = agent.get("node_id", None) if node_id is None: node_id = network.find_unassigned(self.G, random=self.random) - agent['node_id'] = node_id - agent['unique_id'] = unique_id - agent['topology'] = self.G + agent["node_id"] = node_id + agent["unique_id"] = unique_id + agent["topology"] = self.G node_attrs = self.G.nodes[node_id] node_attrs.update(agent) agent = node_attrs @@ -269,32 +266,33 @@ class NetworkEnvironment(BaseEnvironment): if unique_id is None: unique_id = self.next_id() if node_id is None: - node_id = network.find_unassigned(G=self.G, - shuffle=True, - random=self.random) - + node_id = network.find_unassigned( + G=self.G, shuffle=True, random=self.random + ) + if node_id in G.nodes: - self.G.nodes[node_id]['agent'] = None # Reserve + self.G.nodes[node_id]["agent"] = None # Reserve else: self.G.add_node(node_id) - a = self.add_agent(unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs) - a['visible'] = True + a = self.add_agent( + unique_id=unique_id, agent_class=agent_class, node_id=node_id, **kwargs + ) + a["visible"] = True return a def agent_for_node_id(self, node_id): - return self.G.nodes[node_id].get('agent') + return self.G.nodes[node_id].get("agent") def populate_network(self, agent_class, weights=None, **agent_params): - if not hasattr(agent_class, 'len'): + if not hasattr(agent_class, "len"): agent_class = [agent_class] weights = None for (node_id, node) in self.G.nodes(data=True): - if 'agent' in node: + if "agent" in node: continue a_class = self.random.choices(agent_class, weights)[0] - self.add_agent(node_id=node_id, - agent_class=a_class, **agent_params) + self.add_agent(node_id=node_id, agent_class=a_class, **agent_params) Environment = NetworkEnvironment diff --git a/soil/exporters.py b/soil/exporters.py index 648ba77..a31921d 100644 --- a/soil/exporters.py +++ b/soil/exporters.py @@ -24,56 +24,58 @@ class DryRunner(BytesIO): def write(self, txt): if self.__copy_to: - self.__copy_to.write('{}:::{}'.format(self.__fname, txt)) + self.__copy_to.write("{}:::{}".format(self.__fname, txt)) try: super().write(txt) except TypeError: - super().write(bytes(txt, 'utf-8')) + super().write(bytes(txt, "utf-8")) def close(self): - content = '(binary data not shown)' + content = "(binary data not shown)" try: content = self.getvalue().decode() except UnicodeDecodeError: pass - logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname, content)) + logger.info( + "**Not** written to {} (dry run mode):\n\n{}\n\n".format( + self.__fname, content + ) + ) super().close() class Exporter: - ''' + """ Interface for all exporters. It is not necessary, but it is useful if you don't plan to implement all the methods. - ''' + """ def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None): self.simulation = simulation - outdir = outdir or os.path.join(os.getcwd(), 'soil_output') - self.outdir = os.path.join(outdir, - simulation.group or '', - simulation.name) + outdir = outdir or os.path.join(os.getcwd(), "soil_output") + self.outdir = os.path.join(outdir, simulation.group or "", simulation.name) self.dry_run = dry_run if copy_to is None and dry_run: copy_to = sys.stdout self.copy_to = copy_to def sim_start(self): - '''Method to call when the simulation starts''' + """Method to call when the simulation starts""" pass def sim_end(self): - '''Method to call when the simulation ends''' + """Method to call when the simulation ends""" pass def trial_start(self, env): - '''Method to call when a trial start''' + """Method to call when a trial start""" pass def trial_end(self, env): - '''Method to call when a trial ends''' + """Method to call when a trial ends""" pass - def output(self, f, mode='w', **kwargs): + def output(self, f, mode="w", **kwargs): if self.dry_run: f = DryRunner(f, copy_to=self.copy_to) else: @@ -86,102 +88,117 @@ class Exporter: class default(Exporter): - '''Default exporter. Writes sqlite results, as well as the simulation YAML''' + """Default exporter. Writes sqlite results, as well as the simulation YAML""" def sim_start(self): if not self.dry_run: - logger.info('Dumping results to %s', self.outdir) - with self.output(self.simulation.name + '.dumped.yml') as f: + logger.info("Dumping results to %s", self.outdir) + with self.output(self.simulation.name + ".dumped.yml") as f: f.write(self.simulation.to_yaml()) else: - logger.info('NOT dumping results') + logger.info("NOT dumping results") def trial_end(self, env): if self.dry_run: - logger.info('Running in DRY_RUN mode, the database will NOT be created') + logger.info("Running in DRY_RUN mode, the database will NOT be created") return - with timer('Dumping simulation {} trial {}'.format(self.simulation.name, - env.id)): + with timer( + "Dumping simulation {} trial {}".format(self.simulation.name, env.id) + ): - fpath = os.path.join(self.outdir, f'{env.id}.sqlite') - engine = create_engine(f'sqlite:///{fpath}', echo=False) + fpath = os.path.join(self.outdir, f"{env.id}.sqlite") + engine = create_engine(f"sqlite:///{fpath}", echo=False) dc = env.datacollector for (t, df) in get_dc_dfs(dc): - df.to_sql(t, con=engine, if_exists='append') + df.to_sql(t, con=engine, if_exists="append") def get_dc_dfs(dc): - dfs = {'env': dc.get_model_vars_dataframe(), - 'agents': dc.get_agent_vars_dataframe() } + dfs = { + "env": dc.get_model_vars_dataframe(), + "agents": dc.get_agent_vars_dataframe(), + } for table_name in dc.tables: dfs[table_name] = dc.get_table_dataframe(table_name) - yield from dfs.items() + yield from dfs.items() class csv(Exporter): - '''Export the state of each environment (and its agents) in a separate CSV file''' + """Export the state of each environment (and its agents) in a separate CSV file""" + def trial_end(self, env): - with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name, - env.id, - self.outdir)): + with timer( + "[CSV] Dumping simulation {} trial {} @ dir {}".format( + self.simulation.name, env.id, self.outdir + ) + ): for (df_name, df) in get_dc_dfs(env.datacollector): - with self.output('{}.{}.csv'.format(env.id, df_name)) as f: + with self.output("{}.{}.csv".format(env.id, df_name)) as f: df.to_csv(f) -#TODO: reimplement GEXF exporting without history +# TODO: reimplement GEXF exporting without history class gexf(Exporter): def trial_end(self, env): if self.dry_run: - logger.info('Not dumping GEXF in dry_run mode') + logger.info("Not dumping GEXF in dry_run mode") return - with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name, - env.id)): - with self.output('{}.gexf'.format(env.id), mode='wb') as f: + with timer( + "[GEXF] Dumping simulation {} trial {}".format(self.simulation.name, env.id) + ): + with self.output("{}.gexf".format(env.id), mode="wb") as f: network.dump_gexf(env.history_to_graph(), f) self.dump_gexf(env, f) class dummy(Exporter): - def sim_start(self): - with self.output('dummy', 'w') as f: - f.write('simulation started @ {}\n'.format(current_time())) + with self.output("dummy", "w") as f: + f.write("simulation started @ {}\n".format(current_time())) def trial_start(self, env): - with self.output('dummy', 'w') as f: - f.write('trial started@ {}\n'.format(current_time())) + with self.output("dummy", "w") as f: + f.write("trial started@ {}\n".format(current_time())) def trial_end(self, env): - with self.output('dummy', 'w') as f: - f.write('trial ended@ {}\n'.format(current_time())) + with self.output("dummy", "w") as f: + f.write("trial ended@ {}\n".format(current_time())) def sim_end(self): - with self.output('dummy', 'a') as f: - f.write('simulation ended @ {}\n'.format(current_time())) + with self.output("dummy", "a") as f: + f.write("simulation ended @ {}\n".format(current_time())) + class graphdrawing(Exporter): - def trial_end(self, env): # Outside effects f = plt.figure() - nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111)) - with open('graph-{}.png'.format(env.id)) as f: + nx.draw( + env.G, + node_size=10, + width=0.2, + pos=nx.spring_layout(env.G, scale=100), + ax=f.add_subplot(111), + ) + with open("graph-{}.png".format(env.id)) as f: f.savefig(f) -''' + +""" Convert an environment into a NetworkX graph -''' +""" + + def env_to_graph(env, history=None): G = nx.Graph(env.G) for agent in env.network_agents: - attributes = {'agent': str(agent.__class__)} + attributes = {"agent": str(agent.__class__)} lastattributes = {} spells = [] lastvisible = False @@ -189,7 +206,7 @@ def env_to_graph(env, history=None): if not history: history = sorted(list(env.state_to_tuples())) for _, t_step, attribute, value in history: - if attribute == 'visible': + if attribute == "visible": nowvisible = value if nowvisible and not lastvisible: laststep = t_step @@ -198,7 +215,7 @@ def env_to_graph(env, history=None): lastvisible = nowvisible continue - key = 'attr_' + attribute + key = "attr_" + attribute if key not in attributes: attributes[key] = list() if key not in lastattributes: diff --git a/soil/network.py b/soil/network.py index bc69716..5c0b005 100644 --- a/soil/network.py +++ b/soil/network.py @@ -9,6 +9,7 @@ import networkx as nx from . import config, serialization, basestring + def from_config(cfg: config.NetConfig, dir_path: str = None): if not isinstance(cfg, config.NetConfig): cfg = config.NetConfig(**cfg) @@ -19,24 +20,28 @@ def from_config(cfg: config.NetConfig, dir_path: str = None): path = os.path.join(dir_path, path) extension = os.path.splitext(path)[1][1:] kwargs = {} - if extension == 'gexf': - kwargs['version'] = '1.2draft' - kwargs['node_type'] = int + if extension == "gexf": + kwargs["version"] = "1.2draft" + kwargs["node_type"] = int try: - method = getattr(nx.readwrite, 'read_' + extension) + method = getattr(nx.readwrite, "read_" + extension) except AttributeError: - raise AttributeError('Unknown format') + raise AttributeError("Unknown format") return method(path, **kwargs) if cfg.params: net_args = cfg.params.dict() - net_gen = net_args.pop('generator') + net_gen = net_args.pop("generator") if dir_path not in sys.path: sys.path.append(dir_path) - method = serialization.deserializer(net_gen, - known_modules=['networkx.generators',]) + method = serialization.deserializer( + net_gen, + known_modules=[ + "networkx.generators", + ], + ) return method(**net_args) if isinstance(cfg.fixed, config.Topology): @@ -49,17 +54,17 @@ def from_config(cfg: config.NetConfig, dir_path: str = None): def find_unassigned(G, shuffle=False, random=random): - ''' + """ Link an agent to a node in a topology. If node_id is None, a node without an agent_id will be found. - ''' - #TODO: test + """ + # TODO: test candidates = list(G.nodes(data=True)) if shuffle: random.shuffle(candidates) for next_id, data in candidates: - if 'agent' not in data: + if "agent" not in data: node_id = next_id break @@ -68,8 +73,14 @@ def find_unassigned(G, shuffle=False, random=random): def dump_gexf(G, f): for node in G.nodes(): - if 'pos' in G.nodes[node]: - G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}} - del (G.nodes[node]['pos']) + if "pos" in G.nodes[node]: + G.nodes[node]["viz"] = { + "position": { + "x": G.nodes[node]["pos"][0], + "y": G.nodes[node]["pos"][1], + "z": 0.0, + } + } + del G.nodes[node]["pos"] nx.write_gexf(G, f, version="1.2draft") diff --git a/soil/serialization.py b/soil/serialization.py index 972ca69..b728983 100644 --- a/soil/serialization.py +++ b/soil/serialization.py @@ -15,13 +15,14 @@ import networkx as nx from jinja2 import Template -logger = logging.getLogger('soil') +logger = logging.getLogger("soil") + def load_file(infile): folder = os.path.dirname(infile) if folder not in sys.path: sys.path.append(folder) - with open(infile, 'r') as f: + with open(infile, "r") as f: return list(chain.from_iterable(map(expand_template, load_string(f)))) @@ -30,14 +31,15 @@ def load_string(string): def expand_template(config): - if 'template' not in config: + if "template" not in config: yield config return - if 'vars' not in config: - raise ValueError(('You must provide a definition of variables' - ' for the template.')) + if "vars" not in config: + raise ValueError( + ("You must provide a definition of variables" " for the template.") + ) - template = config['template'] + template = config["template"] if not isinstance(template, str): template = yaml.dump(template) @@ -49,9 +51,9 @@ def expand_template(config): blank_str = template.render({k: 0 for k in params[0].keys()}) blank = list(load_string(blank_str)) if len(blank) > 1: - raise ValueError('Templates must not return more than one configuration') - if 'name' in blank[0]: - raise ValueError('Templates cannot be named, use group instead') + raise ValueError("Templates must not return more than one configuration") + if "name" in blank[0]: + raise ValueError("Templates cannot be named, use group instead") for ps in params: string = template.render(ps) @@ -60,25 +62,25 @@ def expand_template(config): def params_for_template(config): - sampler_config = config.get('sampler', {'N': 100}) - sampler = sampler_config.pop('method', 'SALib.sample.morris.sample') + sampler_config = config.get("sampler", {"N": 100}) + sampler = sampler_config.pop("method", "SALib.sample.morris.sample") sampler = deserializer(sampler) - bounds = config['vars']['bounds'] + bounds = config["vars"]["bounds"] problem = { - 'num_vars': len(bounds), - 'names': list(bounds.keys()), - 'bounds': list(v for v in bounds.values()) + "num_vars": len(bounds), + "names": list(bounds.keys()), + "bounds": list(v for v in bounds.values()), } samples = sampler(problem, **sampler_config) - lists = config['vars'].get('lists', {}) + lists = config["vars"].get("lists", {}) names = list(lists.keys()) values = list(lists.values()) combs = list(product(*values)) - allnames = names + problem['names'] - allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)] + allnames = names + problem["names"] + allvalues = [(list(i[0]) + list(i[1])) for i in product(combs, samples)] params = list(map(lambda x: dict(zip(allnames, x)), allvalues)) return params @@ -100,22 +102,24 @@ def load_config(cfg): yield from load_files(cfg) -builtins = importlib.import_module('builtins') +builtins = importlib.import_module("builtins") -KNOWN_MODULES = ['soil', ] +KNOWN_MODULES = [ + "soil", +] def name(value, known_modules=KNOWN_MODULES): - '''Return a name that can be imported, to serialize/deserialize an object''' + """Return a name that can be imported, to serialize/deserialize an object""" if value is None: - return 'None' + return "None" if not isinstance(value, type): # Get the class name first value = type(value) tname = value.__name__ if hasattr(builtins, tname): return tname modname = value.__module__ - if modname == '__main__': + if modname == "__main__": return tname if known_modules and modname in known_modules: return tname @@ -125,17 +129,17 @@ def name(value, known_modules=KNOWN_MODULES): module = importlib.import_module(kmod) if hasattr(module, tname): return tname - return '{}.{}'.format(modname, tname) + return "{}.{}".format(modname, tname) def serializer(type_): - if type_ != 'str' and hasattr(builtins, type_): + if type_ != "str" and hasattr(builtins, type_): return repr return lambda x: x def serialize(v, known_modules=KNOWN_MODULES): - '''Get a text representation of an object.''' + """Get a text representation of an object.""" tname = name(v, known_modules=known_modules) func = serializer(tname) return func(v), tname @@ -160,9 +164,9 @@ IS_CLASS = re.compile(r"") def deserializer(type_, known_modules=KNOWN_MODULES): if type(type_) != str: # Already deserialized return type_ - if type_ == 'str': - return lambda x='': x - if type_ == 'None': + if type_ == "str": + return lambda x="": x + if type_ == "None": return lambda x=None: None if hasattr(builtins, type_): # Check if it's a builtin type cls = getattr(builtins, type_) @@ -172,8 +176,8 @@ def deserializer(type_, known_modules=KNOWN_MODULES): modname, tname = match.group(1).rsplit(".", 1) module = importlib.import_module(modname) cls = getattr(module, tname) - return getattr(cls, 'deserialize', cls) - + return getattr(cls, "deserialize", cls) + # Otherwise, see if we can find the module and the class options = [] @@ -181,7 +185,7 @@ def deserializer(type_, known_modules=KNOWN_MODULES): if mod: options.append((mod, type_)) - if '.' in type_: # Fully qualified module + if "." in type_: # Fully qualified module module, type_ = type_.rsplit(".", 1) options.append((module, type_)) @@ -190,32 +194,31 @@ def deserializer(type_, known_modules=KNOWN_MODULES): try: module = importlib.import_module(modname) cls = getattr(module, tname) - return getattr(cls, 'deserialize', cls) + return getattr(cls, "deserialize", cls) except (ImportError, AttributeError) as ex: errors.append((modname, tname, ex)) raise Exception('Could not find type "{}". Tried: {}'.format(type_, errors)) def deserialize(type_, value=None, globs=None, **kwargs): - '''Get an object from a text representation''' + """Get an object from a text representation""" if not isinstance(type_, str): return type_ if globs and type_ in globs: des = globs[type_] else: - des = deserializer(type_, **kwargs) + des = deserializer(type_, **kwargs) if value is None: return des return des(value) def deserialize_all(names, *args, known_modules=KNOWN_MODULES, **kwargs): - '''Return the list of deserialized objects''' - #TODO: remove - print('SERIALIZATION', kwargs) + """Return the list of deserialized objects""" + # TODO: remove + print("SERIALIZATION", kwargs) objects = [] for name in names: mod = deserialize(name, known_modules=known_modules) objects.append(mod(*args, **kwargs)) return objects - diff --git a/soil/simulation.py b/soil/simulation.py index baee50f..7c79d92 100644 --- a/soil/simulation.py +++ b/soil/simulation.py @@ -1,5 +1,5 @@ import os -from time import time as current_time, strftime +from time import time as current_time, strftime import importlib import sys import yaml @@ -25,7 +25,7 @@ from .time import INFINITY from .config import Config, convert_old -#TODO: change documentation for simulation +# TODO: change documentation for simulation @dataclass class Simulation: """ @@ -36,15 +36,16 @@ class Simulation: kwargs: parameters to use to initialize a new configuration, if one not been provided. """ - version: str = '2' - name: str = 'Unnamed simulation' - description: Optional[str] = '' + + version: str = "2" + name: str = "Unnamed simulation" + description: Optional[str] = "" group: str = None - model_class: Union[str, type] = 'soil.Environment' + model_class: Union[str, type] = "soil.Environment" model_params: dict = field(default_factory=dict) seed: str = field(default_factory=lambda: current_time()) dir_path: str = field(default_factory=lambda: os.getcwd()) - max_time: float = float('inf') + max_time: float = float("inf") max_steps: int = -1 interval: int = 1 num_trials: int = 3 @@ -56,14 +57,15 @@ class Simulation: extra: Dict[str, Any] = field(default_factory=dict) @classmethod - def from_dict(cls, env, **kwargs): + def from_dict(cls, env, **kwargs): - ignored = {k: v for k, v in env.items() - if k not in inspect.signature(cls).parameters} + ignored = { + k: v for k, v in env.items() if k not in inspect.signature(cls).parameters + } - d = {k:v for k, v in env.items() if k not in ignored} + d = {k: v for k, v in env.items() if k not in ignored} if ignored: - d.setdefault('extra', {}).update(ignored) + d.setdefault("extra", {}).update(ignored) if ignored: print(f'Warning: Ignoring these parameters (added to "extra"): { ignored }') d.update(kwargs) @@ -74,24 +76,34 @@ class Simulation: return self.run(*args, **kwargs) def run(self, *args, **kwargs): - '''Run the simulation and return the list of resulting environments''' - logger.info(dedent(''' + """Run the simulation and return the list of resulting environments""" + logger.info( + dedent( + """ Simulation: --- - ''') + - self.to_yaml()) + """ + ) + + self.to_yaml() + ) return list(self.run_gen(*args, **kwargs)) - def run_gen(self, parallel=False, dry_run=None, - exporters=None, outdir=None, exporter_params={}, - log_level=None, - **kwargs): - '''Run the simulation and yield the resulting environments.''' + def run_gen( + self, + parallel=False, + dry_run=None, + exporters=None, + outdir=None, + exporter_params={}, + log_level=None, + **kwargs, + ): + """Run the simulation and yield the resulting environments.""" if log_level: logger.setLevel(log_level) outdir = outdir or self.outdir - logger.info('Using exporters: %s', exporters or []) - logger.info('Output directory: %s', outdir) + logger.info("Using exporters: %s", exporters or []) + logger.info("Output directory: %s", outdir) if dry_run is None: dry_run = self.dry_run if exporters is None: @@ -99,22 +111,28 @@ class Simulation: if not exporter_params: exporter_params = self.exporter_params - exporters = serialization.deserialize_all(exporters, - simulation=self, - known_modules=['soil.exporters', ], - dry_run=dry_run, - outdir=outdir, - **exporter_params) + exporters = serialization.deserialize_all( + exporters, + simulation=self, + known_modules=[ + "soil.exporters", + ], + dry_run=dry_run, + outdir=outdir, + **exporter_params, + ) - with utils.timer('simulation {}'.format(self.name)): + with utils.timer("simulation {}".format(self.name)): for exporter in exporters: exporter.sim_start() - for env in utils.run_parallel(func=self.run_trial, - iterable=range(int(self.num_trials)), - parallel=parallel, - log_level=log_level, - **kwargs): + for env in utils.run_parallel( + func=self.run_trial, + iterable=range(int(self.num_trials)), + parallel=parallel, + log_level=log_level, + **kwargs, + ): for exporter in exporters: exporter.trial_start(env) @@ -128,11 +146,12 @@ class Simulation: exporter.sim_end() def get_env(self, trial_id=0, model_params=None, **kwargs): - '''Create an environment for a trial of the simulation''' + """Create an environment for a trial of the simulation""" + def deserialize_reporters(reporters): for (k, v) in reporters.items(): - if isinstance(v, str) and v.startswith('py:'): - reporters[k] = serialization.deserialize(value.lsplit(':', 1)[1]) + if isinstance(v, str) and v.startswith("py:"): + reporters[k] = serialization.deserialize(value.lsplit(":", 1)[1]) return reporters params = self.model_params.copy() @@ -140,18 +159,22 @@ class Simulation: params.update(model_params) params.update(kwargs) - agent_reporters = deserialize_reporters(params.pop('agent_reporters', {})) - model_reporters = deserialize_reporters(params.pop('model_reporters', {})) + agent_reporters = deserialize_reporters(params.pop("agent_reporters", {})) + model_reporters = deserialize_reporters(params.pop("model_reporters", {})) - env = serialization.deserialize(self.model_class) - return env(id=f'{self.name}_trial_{trial_id}', - seed=f'{self.seed}_trial_{trial_id}', - dir_path=self.dir_path, - agent_reporters=agent_reporters, - model_reporters=model_reporters, - **params) + env = serialization.deserialize(self.model_class) + return env( + id=f"{self.name}_trial_{trial_id}", + seed=f"{self.seed}_trial_{trial_id}", + dir_path=self.dir_path, + agent_reporters=agent_reporters, + model_reporters=model_reporters, + **params, + ) - def run_trial(self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts): + def run_trial( + self, trial_id=None, until=None, log_file=False, log_level=logging.INFO, **opts + ): """ Run a single trial of the simulation @@ -160,50 +183,58 @@ class Simulation: logger.setLevel(log_level) model = self.get_env(trial_id, **opts) trial_id = trial_id if trial_id is not None else current_time() - with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)): - return self.run_model(model=model, trial_id=trial_id, until=until, log_level=log_level) + with utils.timer("Simulation {} trial {}".format(self.name, trial_id)): + return self.run_model( + model=model, trial_id=trial_id, until=until, log_level=log_level + ) def run_model(self, model, until=None, **opts): # Set-up trial environment and graph - until = float(until or self.max_time or 'inf') + until = float(until or self.max_time or "inf") # Set up agents on nodes def is_done(): return False - if until and hasattr(model.schedule, 'time'): + if until and hasattr(model.schedule, "time"): prev = is_done def is_done(): return prev() or model.schedule.time >= until - if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, 'steps'): + if self.max_steps and self.max_steps > 0 and hasattr(model.schedule, "steps"): prev_steps = is_done def is_done(): return prev_steps() or model.schedule.steps >= self.max_steps - newline = '\n' - logger.info(dedent(f''' + newline = "\n" + logger.info( + dedent( + f""" Model stats: Agents (total: { model.schedule.get_agent_count() }): - { (newline + ' - ').join(str(a) for a in model.schedule.agents) } Topology size: { len(model.G) if hasattr(model, "G") else 0 } - ''')) + """ + ) + ) while not is_done(): - utils.logger.debug(f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}') + utils.logger.debug( + f'Simulation time {model.schedule.time}/{until}. Next: {getattr(model.schedule, "next_time", model.schedule.time + self.interval)}' + ) model.step() return model def to_dict(self): d = asdict(self) - if not isinstance(d['model_class'], str): - d['model_class'] = serialization.name(d['model_class']) - d['model_params'] = serialization.serialize_dict(d['model_params']) - d['dir_path'] = str(d['dir_path']) - d['version'] = '2' + if not isinstance(d["model_class"], str): + d["model_class"] = serialization.name(d["model_class"]) + d["model_params"] = serialization.serialize_dict(d["model_params"]) + d["dir_path"] = str(d["dir_path"]) + d["version"] = "2" return d def to_yaml(self): @@ -215,15 +246,15 @@ def iter_from_config(*cfgs, **kwargs): configs = list(serialization.load_config(config)) for config, path in configs: d = dict(config) - if 'dir_path' not in d: - d['dir_path'] = os.path.dirname(path) + if "dir_path" not in d: + d["dir_path"] = os.path.dirname(path) yield Simulation.from_dict(d, **kwargs) def from_config(conf_or_path): lst = list(iter_from_config(conf_or_path)) if len(lst) > 1: - raise AttributeError('Provide only one configuration') + raise AttributeError("Provide only one configuration") return lst[0] diff --git a/soil/time.py b/soil/time.py index 602aa8c..11e3178 100644 --- a/soil/time.py +++ b/soil/time.py @@ -6,7 +6,8 @@ from .utils import logger from mesa import Agent as MesaAgent -INFINITY = float('inf') +INFINITY = float("inf") + class When: def __init__(self, time): @@ -42,7 +43,7 @@ class TimedActivation(BaseScheduler): self._next = {} self._queue = [] self.next_time = 0 - self.logger = logger.getChild(f'time_{ self.model }') + self.logger = logger.getChild(f"time_{ self.model }") def add(self, agent: MesaAgent, when=None): if when is None: @@ -51,7 +52,7 @@ class TimedActivation(BaseScheduler): self._queue.remove((self._next[agent.unique_id], agent.unique_id)) del self._agents[agent.unique_id] heapify(self._queue) - + heappush(self._queue, (when, agent.unique_id)) self._next[agent.unique_id] = when super().add(agent) @@ -62,7 +63,7 @@ class TimedActivation(BaseScheduler): an agent will signal when it wants to be scheduled next. """ - self.logger.debug(f'Simulation step {self.next_time}') + self.logger.debug(f"Simulation step {self.next_time}") if not self.model.running: return @@ -71,18 +72,22 @@ class TimedActivation(BaseScheduler): while self._queue and self._queue[0][0] == self.time: (when, agent_id) = heappop(self._queue) - self.logger.debug(f'Stepping agent {agent_id}') + self.logger.debug(f"Stepping agent {agent_id}") agent = self._agents[agent_id] returned = agent.step() - if not getattr(agent, 'alive', True): + if not getattr(agent, "alive", True): self.remove(agent) continue when = (returned or Delta(1)).abs(self.time) if when < self.time: - raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time)) + raise Exception( + "Cannot schedule an agent for a time in the past ({} < {})".format( + when, self.time + ) + ) self._next[agent_id] = when heappush(self._queue, (when, agent_id)) @@ -96,4 +101,4 @@ class TimedActivation(BaseScheduler): return self.time self.next_time = self._queue[0][0] - self.logger.debug(f'Next step: {self.next_time}') + self.logger.debug(f"Next step: {self.next_time}") diff --git a/soil/utils.py b/soil/utils.py index 6c25dbc..9c4bcc7 100644 --- a/soil/utils.py +++ b/soil/utils.py @@ -9,12 +9,12 @@ from multiprocessing import Pool from contextlib import contextmanager -logger = logging.getLogger('soil') +logger = logging.getLogger("soil") logger.setLevel(logging.INFO) timeformat = "%H:%M:%S" -if os.environ.get('SOIL_VERBOSE', ''): +if os.environ.get("SOIL_VERBOSE", ""): logformat = "[%(levelname)-5.5s][%(asctime)s][%(name)s]: %(message)s" else: logformat = "[%(levelname)-5.5s][%(asctime)s] %(message)s" @@ -23,38 +23,44 @@ logFormatter = logging.Formatter(logformat, timeformat) consoleHandler = logging.StreamHandler() consoleHandler.setFormatter(logFormatter) -logging.basicConfig(level=logging.INFO, - handlers=[consoleHandler,]) +logging.basicConfig( + level=logging.INFO, + handlers=[ + consoleHandler, + ], +) @contextmanager -def timer(name='task', pre="", function=logger.info, to_object=None): +def timer(name="task", pre="", function=logger.info, to_object=None): start = current_time() - function('{}Starting {} at {}.'.format(pre, name, - strftime("%X", gmtime(start)))) + function("{}Starting {} at {}.".format(pre, name, strftime("%X", gmtime(start)))) yield start end = current_time() - function('{}Finished {} at {} in {} seconds'.format(pre, name, - strftime("%X", gmtime(end)), - str(end-start))) + function( + "{}Finished {} at {} in {} seconds".format( + pre, name, strftime("%X", gmtime(end)), str(end - start) + ) + ) if to_object: to_object.start = start to_object.end = end -def safe_open(path, mode='r', backup=True, **kwargs): +def safe_open(path, mode="r", backup=True, **kwargs): outdir = os.path.dirname(path) if outdir and not os.path.exists(outdir): os.makedirs(outdir) - if backup and 'w' in mode and os.path.exists(path): + if backup and "w" in mode and os.path.exists(path): creation = os.path.getctime(path) - stamp = strftime('%Y-%m-%d_%H.%M.%S', localtime(creation)) + stamp = strftime("%Y-%m-%d_%H.%M.%S", localtime(creation)) - backup_dir = os.path.join(outdir, 'backup') + backup_dir = os.path.join(outdir, "backup") if not os.path.exists(backup_dir): os.makedirs(backup_dir) - newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path), - stamp)) + newpath = os.path.join( + backup_dir, "{}@{}".format(os.path.basename(path), stamp) + ) copyfile(path, newpath) return open(path, mode=mode, **kwargs) @@ -67,21 +73,23 @@ def open_or_reuse(f, *args, **kwargs): except (AttributeError, TypeError): yield f + def flatten_dict(d): if not isinstance(d, dict): return d return dict(_flatten_dict(d)) -def _flatten_dict(d, prefix=''): + +def _flatten_dict(d, prefix=""): if not isinstance(d, dict): # print('END:', prefix, d) yield prefix, d return if prefix: - prefix = prefix + '.' + prefix = prefix + "." for k, v in d.items(): # print(k, v) - res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k))) + res = list(_flatten_dict(v, prefix="{}{}".format(prefix, k))) # print('RES:', res) yield from res @@ -93,7 +101,7 @@ def unflatten_dict(d): if not isinstance(k, str): target[k] = v continue - tokens = k.split('.') + tokens = k.split(".") if len(tokens) < 2: target[k] = v continue @@ -106,27 +114,28 @@ def unflatten_dict(d): def run_and_return_exceptions(func, *args, **kwargs): - ''' + """ A wrapper for run_trial that catches exceptions and returns them. It is meant for async simulations. - ''' + """ try: return func(*args, **kwargs) except Exception as ex: if ex.__cause__ is not None: ex = ex.__cause__ - ex.message = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)[:]) + ex.message = "".join( + traceback.format_exception(type(ex), ex, ex.__traceback__)[:] + ) return ex def run_parallel(func, iterable, parallel=False, **kwargs): - if parallel and not os.environ.get('SOIL_DEBUG', None): + if parallel and not os.environ.get("SOIL_DEBUG", None): p = Pool() - wrapped_func = partial(run_and_return_exceptions, - func, **kwargs) + wrapped_func = partial(run_and_return_exceptions, func, **kwargs) for i in p.imap_unordered(wrapped_func, iterable): if isinstance(i, Exception): - logger.error('Trial failed:\n\t%s', i.message) + logger.error("Trial failed:\n\t%s", i.message) continue yield i else: diff --git a/soil/version.py b/soil/version.py index ea5b40a..ae66caa 100644 --- a/soil/version.py +++ b/soil/version.py @@ -4,7 +4,7 @@ import logging logger = logging.getLogger(__name__) ROOT = os.path.dirname(__file__) -DEFAULT_FILE = os.path.join(ROOT, 'VERSION') +DEFAULT_FILE = os.path.join(ROOT, "VERSION") def read_version(versionfile=DEFAULT_FILE): @@ -12,9 +12,10 @@ def read_version(versionfile=DEFAULT_FILE): with open(versionfile) as f: return f.read().strip() except IOError: # pragma: no cover - logger.error(('Running an unknown version of {}.' - 'Be careful!.').format(__name__)) - return '0.0' + logger.error( + ("Running an unknown version of {}." "Be careful!.").format(__name__) + ) + return "0.0" __version__ = read_version() diff --git a/soil/visualization.py b/soil/visualization.py index fe12aca..a1cb7b8 100644 --- a/soil/visualization.py +++ b/soil/visualization.py @@ -1,5 +1,6 @@ from mesa.visualization.UserParam import UserSettableParameter + class UserSettableParameter(UserSettableParameter): def __str__(self): return self.value diff --git a/soil/web/__init__.py b/soil/web/__init__.py index 2339288..5327703 100644 --- a/soil/web/__init__.py +++ b/soil/web/__init__.py @@ -20,6 +20,7 @@ from tornado.concurrent import run_on_executor from concurrent.futures import ThreadPoolExecutor from ..simulation import Simulation + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -31,140 +32,183 @@ LOGGING_INTERVAL = 0.5 # Workaround to let Soil load the required modules sys.path.append(ROOT) + class PageHandler(tornado.web.RequestHandler): - """ Handler for the HTML template which holds the visualization. """ + """Handler for the HTML template which holds the visualization.""" def get(self): - self.render('index.html', port=self.application.port, - name=self.application.name) + self.render( + "index.html", port=self.application.port, name=self.application.name + ) class SocketHandler(tornado.websocket.WebSocketHandler): - """ Handler for websocket. """ + """Handler for websocket.""" + executor = ThreadPoolExecutor(max_workers=MAX_WORKERS) def open(self): if self.application.verbose: - logger.info('Socket opened!') + logger.info("Socket opened!") def check_origin(self, origin): return True def on_message(self, message): - """ Receiving a message from the websocket, parse, and act accordingly. """ + """Receiving a message from the websocket, parse, and act accordingly.""" msg = tornado.escape.json_decode(message) - if msg['type'] == 'config_file': + if msg["type"] == "config_file": if self.application.verbose: - print(msg['data']) + print(msg["data"]) - self.config = list(yaml.load_all(msg['data'])) + self.config = list(yaml.load_all(msg["data"])) if len(self.config) > 1: - error = 'Please, provide only one configuration.' + error = "Please, provide only one configuration." if self.application.verbose: logger.error(error) - self.write_message({'type': 'error', - 'error': error}) + self.write_message({"type": "error", "error": error}) return self.config = self.config[0] - self.send_log('INFO.' + self.simulation_name, - 'Using config: {name}'.format(name=self.config['name'])) + self.send_log( + "INFO." + self.simulation_name, + "Using config: {name}".format(name=self.config["name"]), + ) - if 'visualization_params' in self.config: - self.write_message({'type': 'visualization_params', - 'data': self.config['visualization_params']}) - self.name = self.config['name'] + if "visualization_params" in self.config: + self.write_message( + { + "type": "visualization_params", + "data": self.config["visualization_params"], + } + ) + self.name = self.config["name"] self.run_simulation() settings = [] - for key in self.config['environment_params']: - if type(self.config['environment_params'][key]) == float or type(self.config['environment_params'][key]) == int: - if self.config['environment_params'][key] <= 1: - setting_type = 'number' + for key in self.config["environment_params"]: + if ( + type(self.config["environment_params"][key]) == float + or type(self.config["environment_params"][key]) == int + ): + if self.config["environment_params"][key] <= 1: + setting_type = "number" else: - setting_type = 'great_number' - elif type(self.config['environment_params'][key]) == bool: - setting_type = 'boolean' + setting_type = "great_number" + elif type(self.config["environment_params"][key]) == bool: + setting_type = "boolean" else: - setting_type = 'undefined' + setting_type = "undefined" - settings.append({ - 'label': key, - 'type': setting_type, - 'value': self.config['environment_params'][key] - }) + settings.append( + { + "label": key, + "type": setting_type, + "value": self.config["environment_params"][key], + } + ) - self.write_message({'type': 'settings', - 'data': settings}) + self.write_message({"type": "settings", "data": settings}) - elif msg['type'] == 'get_trial': + elif msg["type"] == "get_trial": if self.application.verbose: - logger.info('Trial {} requested!'.format(msg['data'])) - self.send_log('INFO.' + __name__, 'Trial {} requested!'.format(msg['data'])) - self.write_message({'type': 'get_trial', - 'data': self.get_trial(int(msg['data']))}) + logger.info("Trial {} requested!".format(msg["data"])) + self.send_log("INFO." + __name__, "Trial {} requested!".format(msg["data"])) + self.write_message( + {"type": "get_trial", "data": self.get_trial(int(msg["data"]))} + ) - elif msg['type'] == 'run_simulation': + elif msg["type"] == "run_simulation": if self.application.verbose: - logger.info('Running new simulation for {name}'.format(name=self.config['name'])) - self.send_log('INFO.' + self.simulation_name, 'Running new simulation for {name}'.format(name=self.config['name'])) - self.config['environment_params'] = msg['data'] + logger.info( + "Running new simulation for {name}".format(name=self.config["name"]) + ) + self.send_log( + "INFO." + self.simulation_name, + "Running new simulation for {name}".format(name=self.config["name"]), + ) + self.config["environment_params"] = msg["data"] self.run_simulation() - elif msg['type'] == 'download_gexf': - G = self.trials[ int(msg['data']) ].history_to_graph() + elif msg["type"] == "download_gexf": + G = self.trials[int(msg["data"])].history_to_graph() for node in G.nodes(): - if 'pos' in G.nodes[node]: - G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}} - del (G.nodes[node]['pos']) - writer = nx.readwrite.gexf.GEXFWriter(version='1.2draft') + if "pos" in G.nodes[node]: + G.nodes[node]["viz"] = { + "position": { + "x": G.nodes[node]["pos"][0], + "y": G.nodes[node]["pos"][1], + "z": 0.0, + } + } + del G.nodes[node]["pos"] + writer = nx.readwrite.gexf.GEXFWriter(version="1.2draft") writer.add_graph(G) - self.write_message({'type': 'download_gexf', - 'filename': self.config['name'] + '_trial_' + str(msg['data']), - 'data': tostring(writer.xml).decode(writer.encoding) }) + self.write_message( + { + "type": "download_gexf", + "filename": self.config["name"] + "_trial_" + str(msg["data"]), + "data": tostring(writer.xml).decode(writer.encoding), + } + ) - elif msg['type'] == 'download_json': - G = self.trials[ int(msg['data']) ].history_to_graph() + elif msg["type"] == "download_json": + G = self.trials[int(msg["data"])].history_to_graph() for node in G.nodes(): - if 'pos' in G.nodes[node]: - G.nodes[node]['viz'] = {"position": {"x": G.nodes[node]['pos'][0], "y": G.nodes[node]['pos'][1], "z": 0.0}} - del (G.nodes[node]['pos']) - self.write_message({'type': 'download_json', - 'filename': self.config['name'] + '_trial_' + str(msg['data']), - 'data': nx.node_link_data(G) }) + if "pos" in G.nodes[node]: + G.nodes[node]["viz"] = { + "position": { + "x": G.nodes[node]["pos"][0], + "y": G.nodes[node]["pos"][1], + "z": 0.0, + } + } + del G.nodes[node]["pos"] + self.write_message( + { + "type": "download_json", + "filename": self.config["name"] + "_trial_" + str(msg["data"]), + "data": nx.node_link_data(G), + } + ) else: if self.application.verbose: - logger.info('Unexpected message!') + logger.info("Unexpected message!") def update_logging(self): try: - if (not self.log_capture_string.closed and self.log_capture_string.getvalue()): - for i in range(len(self.log_capture_string.getvalue().split('\n')) - 1): - self.send_log('INFO.' + self.simulation_name, self.log_capture_string.getvalue().split('\n')[i]) + if ( + not self.log_capture_string.closed + and self.log_capture_string.getvalue() + ): + for i in range(len(self.log_capture_string.getvalue().split("\n")) - 1): + self.send_log( + "INFO." + self.simulation_name, + self.log_capture_string.getvalue().split("\n")[i], + ) self.log_capture_string.truncate(0) self.log_capture_string.seek(0) finally: if self.capture_logging: - tornado.ioloop.IOLoop.current().call_later(LOGGING_INTERVAL, self.update_logging) - + tornado.ioloop.IOLoop.current().call_later( + LOGGING_INTERVAL, self.update_logging + ) def on_close(self): if self.application.verbose: - logger.info('Socket closed!') + logger.info("Socket closed!") def send_log(self, logger, logging): - self.write_message({'type': 'log', - 'logger': logger, - 'logging': logging}) + self.write_message({"type": "log", "logger": logger, "logging": logging}) @property def simulation_name(self): - return self.config.get('name', 'NoSimulationRunning') + return self.config.get("name", "NoSimulationRunning") @run_on_executor def nonblocking(self, config): @@ -174,28 +218,31 @@ class SocketHandler(tornado.websocket.WebSocketHandler): @tornado.gen.coroutine def run_simulation(self): # Run simulation and capture logs - logger.info('Running simulation!') - if 'visualization_params' in self.config: - del self.config['visualization_params'] + logger.info("Running simulation!") + if "visualization_params" in self.config: + del self.config["visualization_params"] with self.logging(self.simulation_name): try: config = dict(**self.config) - config['outdir'] = os.path.join(self.application.outdir, config['name']) - config['dump'] = self.application.dump + config["outdir"] = os.path.join(self.application.outdir, config["name"]) + config["dump"] = self.application.dump self.trials = yield self.nonblocking(config) - self.write_message({'type': 'trials', - 'data': list(trial.name for trial in self.trials) }) + self.write_message( + { + "type": "trials", + "data": list(trial.name for trial in self.trials), + } + ) except Exception as ex: - error = 'Something went wrong:\n\t{}'.format(ex) + error = "Something went wrong:\n\t{}".format(ex) logging.info(error) - self.write_message({'type': 'error', - 'error': error}) - self.send_log('ERROR.' + self.simulation_name, error) + self.write_message({"type": "error", "error": error}) + self.send_log("ERROR." + self.simulation_name, error) def get_trial(self, trial): - logger.info('Available trials: %s ' % len(self.trials)) - logger.info('Ask for : %s' % trial) + logger.info("Available trials: %s " % len(self.trials)) + logger.info("Ask for : %s" % trial) trial = self.trials[trial] G = trial.history_to_graph() return nx.node_link_data(G) @@ -215,25 +262,28 @@ class SocketHandler(tornado.websocket.WebSocketHandler): self.logger_application.removeHandler(ch) self.capture_logging = False return self.capture_logging - + class ModularServer(tornado.web.Application): - """ Main visualization application. """ + """Main visualization application.""" port = 8001 - page_handler = (r'/', PageHandler) - socket_handler = (r'/ws', SocketHandler) - static_handler = (r'/(.*)', tornado.web.StaticFileHandler, - {'path': os.path.join(ROOT, 'static')}) - local_handler = (r'/local/(.*)', tornado.web.StaticFileHandler, - {'path': ''}) + page_handler = (r"/", PageHandler) + socket_handler = (r"/ws", SocketHandler) + static_handler = ( + r"/(.*)", + tornado.web.StaticFileHandler, + {"path": os.path.join(ROOT, "static")}, + ) + local_handler = (r"/local/(.*)", tornado.web.StaticFileHandler, {"path": ""}) handlers = [page_handler, socket_handler, static_handler, local_handler] - settings = {'debug': True, - 'template_path': ROOT + '/templates'} + settings = {"debug": True, "template_path": ROOT + "/templates"} + + def __init__( + self, dump=False, outdir="output", name="SOIL", verbose=True, *args, **kwargs + ): - def __init__(self, dump=False, outdir='output', name='SOIL', verbose=True, *args, **kwargs): - self.verbose = verbose self.name = name self.dump = dump @@ -243,12 +293,12 @@ class ModularServer(tornado.web.Application): super().__init__(self.handlers, **self.settings) def launch(self, port=None): - """ Run the app. """ - + """Run the app.""" + if port is not None: self.port = port - url = 'http://127.0.0.1:{PORT}'.format(PORT=self.port) - print('Interface starting at {url}'.format(url=url)) + url = "http://127.0.0.1:{PORT}".format(PORT=self.port) + print("Interface starting at {url}".format(url=url)) self.listen(self.port) # webbrowser.open(url) tornado.ioloop.IOLoop.instance().start() @@ -263,12 +313,22 @@ def run(*args, **kwargs): def main(): import argparse - parser = argparse.ArgumentParser(description='Visualization of a Graph Model') + parser = argparse.ArgumentParser(description="Visualization of a Graph Model") - parser.add_argument('--name', '-n', nargs=1, default='SOIL', help='name of the simulation') - parser.add_argument('--dump', '-d', help='dumping results in folder output', action='store_true') - parser.add_argument('--port', '-p', nargs=1, default=8001, help='port for launching the server') - parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true') + parser.add_argument( + "--name", "-n", nargs=1, default="SOIL", help="name of the simulation" + ) + parser.add_argument( + "--dump", "-d", help="dumping results in folder output", action="store_true" + ) + parser.add_argument( + "--port", "-p", nargs=1, default=8001, help="port for launching the server" + ) + parser.add_argument("--verbose", "-v", help="verbose mode", action="store_true") args = parser.parse_args() - run(name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose) + run( + name=args.name, + port=(args.port[0] if isinstance(args.port, list) else args.port), + verbose=args.verbose, + ) diff --git a/soil/web/__main__.py b/soil/web/__main__.py index 5c211a8..29c2e0a 100644 --- a/soil/web/__main__.py +++ b/soil/web/__main__.py @@ -2,4 +2,4 @@ from . import main if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/soil/web/run.py b/soil/web/run.py index a0b1416..b13ca56 100644 --- a/soil/web/run.py +++ b/soil/web/run.py @@ -4,20 +4,33 @@ from simulator import Simulator def run(simulator, name="SOIL", port=8001, verbose=False): - server = ModularServer(simulator, name=(name[0] if isinstance(name, list) else name), verbose=verbose) + server = ModularServer( + simulator, name=(name[0] if isinstance(name, list) else name), verbose=verbose + ) server.port = port server.launch() if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Visualization of a Graph Model') + parser = argparse.ArgumentParser(description="Visualization of a Graph Model") - parser.add_argument('--name', '-n', nargs=1, default='SOIL', help='name of the simulation') - parser.add_argument('--dump', '-d', help='dumping results in folder output', action='store_true') - parser.add_argument('--port', '-p', nargs=1, default=8001, help='port for launching the server') - parser.add_argument('--verbose', '-v', help='verbose mode', action='store_true') + parser.add_argument( + "--name", "-n", nargs=1, default="SOIL", help="name of the simulation" + ) + parser.add_argument( + "--dump", "-d", help="dumping results in folder output", action="store_true" + ) + parser.add_argument( + "--port", "-p", nargs=1, default=8001, help="port for launching the server" + ) + parser.add_argument("--verbose", "-v", help="verbose mode", action="store_true") args = parser.parse_args() soil = Simulator(dump=args.dump) - run(soil, name=args.name, port=(args.port[0] if isinstance(args.port, list) else args.port), verbose=args.verbose) + run( + soil, + name=args.name, + port=(args.port[0] if isinstance(args.port, list) else args.port), + verbose=args.verbose, + )