From 3776c4e5c5958a7389a3ffe65087c28d0d5d533b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2E=20Fernando=20S=C3=A1nchez?= Date: Mon, 17 Oct 2022 21:36:21 +0200 Subject: [PATCH] Refactor * Removed references to `set_state` * Split some functionality from `agents` into separate files (`fsm` and `network_agents`) * Rename `neighboring_agents` to `neighbors` * Delete some spurious functions --- examples/mesa/social_wealth.py | 2 +- examples/newsspread/newsspread.py | 30 ++- examples/pubcrawl/pubcrawl.py | 2 +- examples/terrorism/TerroristNetworkModel.py | 12 +- soil/agents/BassModel.py | 2 +- soil/agents/BigMarketModel.py | 8 +- soil/agents/CounterModel.py | 4 +- soil/agents/IndependentCascadeModel.py | 2 +- soil/agents/ModelM2.py | 38 +-- soil/agents/SISaModel.py | 10 +- soil/agents/SentimentCorrelationModel.py | 8 +- soil/agents/__init__.py | 271 +------------------- soil/agents/fsm.py | 133 ++++++++++ soil/agents/network_agents.py | 82 ++++++ tests/test_agents.py | 14 + tests/test_main.py | 24 +- 16 files changed, 295 insertions(+), 347 deletions(-) create mode 100644 soil/agents/fsm.py create mode 100644 soil/agents/network_agents.py diff --git a/examples/mesa/social_wealth.py b/examples/mesa/social_wealth.py index b4ae99f..8085543 100644 --- a/examples/mesa/social_wealth.py +++ b/examples/mesa/social_wealth.py @@ -58,7 +58,7 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent): def give_money(self): cellmates = set(self.model.grid.get_cell_list_contents([self.pos])) - friends = set(self.get_neighboring_agents()) + friends = set(self.get_neighbors()) self.info("Trying to give money") self.info("Cellmates: ", cellmates) self.info("Friends: ", friends) diff --git a/examples/newsspread/newsspread.py b/examples/newsspread/newsspread.py index f747f8e..bfcdbc9 100644 --- a/examples/newsspread/newsspread.py +++ b/examples/newsspread/newsspread.py @@ -8,10 +8,9 @@ class DumbViewer(FSM, NetworkAgent): its neighbors once it's infected. """ - defaults = { - "prob_neighbor_spread": 0.5, - "prob_tv_spread": 0.1, - } + prob_neighbor_spread = 0.5 + prob_tv_spread = 0.1 + has_been_infected = False @default_state @state @@ -19,10 +18,12 @@ class DumbViewer(FSM, NetworkAgent): if self["has_tv"]: if self.prob(self.model["prob_tv_spread"]): return self.infected + if self.has_been_infected: + return self.infected @state def infected(self): - for neighbor in self.get_neighboring_agents(state_id=self.neutral.id): + for neighbor in self.get_neighbors(state_id=self.neutral.id): if self.prob(self.model["prob_neighbor_spread"]): neighbor.infect() @@ -33,7 +34,7 @@ class DumbViewer(FSM, NetworkAgent): HerdViewer might not become infected right away """ - self.set_state(self.infected) + self.has_been_infected = True class HerdViewer(DumbViewer): @@ -43,12 +44,12 @@ class HerdViewer(DumbViewer): def infect(self): """Notice again that this is NOT a state. See DumbViewer.infect for reference""" - infected = self.count_neighboring_agents(state_id=self.infected.id) - total = self.count_neighboring_agents() + infected = self.count_neighbors(state_id=self.infected.id) + total = self.count_neighbors() prob_infect = self.model["prob_neighbor_spread"] * infected / total self.debug("prob_infect", prob_infect) if self.prob(prob_infect): - self.set_state(self.infected) + self.has_been_infected = True class WiseViewer(HerdViewer): @@ -65,7 +66,7 @@ class WiseViewer(HerdViewer): @state def cured(self): prob_cure = self.model["prob_neighbor_cure"] - for neighbor in self.get_neighboring_agents(state_id=self.infected.id): + for neighbor in self.get_neighbors(state_id=self.infected.id): if self.prob(prob_cure): try: neighbor.cure() @@ -73,13 +74,14 @@ class WiseViewer(HerdViewer): self.debug("Viewer {} cannot be cured".format(neighbor.id)) def cure(self): - self.set_state(self.cured.id) + self.has_been_cured = True @state def infected(self): - cured = max(self.count_neighboring_agents(self.cured.id), 1.0) - infected = max(self.count_neighboring_agents(self.infected.id), 1.0) + if self.has_been_cured: + return self.cured + cured = max(self.count_neighbors(self.cured.id), 1.0) + infected = max(self.count_neighbors(self.infected.id), 1.0) prob_cure = self.model["prob_neighbor_cure"] * (cured / infected) if self.prob(prob_cure): return self.cured - return self.set_state(super().infected) diff --git a/examples/pubcrawl/pubcrawl.py b/examples/pubcrawl/pubcrawl.py index 110a44c..be8a2b4 100644 --- a/examples/pubcrawl/pubcrawl.py +++ b/examples/pubcrawl/pubcrawl.py @@ -89,7 +89,7 @@ class Patron(FSM, NetworkAgent): if self["pub"] != None: return self.sober_in_pub self.debug("I am looking for a pub") - group = list(self.get_neighboring_agents()) + group = list(self.get_neighbors()) for pub in self.model.available_pubs(): self.debug("We're trying to get into {}: total: {}".format(pub, len(group))) if self.model.enter(pub, self, *group): diff --git a/examples/terrorism/TerroristNetworkModel.py b/examples/terrorism/TerroristNetworkModel.py index 8fa6563..fe3034f 100644 --- a/examples/terrorism/TerroristNetworkModel.py +++ b/examples/terrorism/TerroristNetworkModel.py @@ -49,7 +49,7 @@ class TerroristSpreadModel(FSM, Geo): @state def civilian(self): - neighbours = list(self.get_neighboring_agents(agent_class=TerroristSpreadModel)) + neighbours = list(self.get_neighbors(agent_class=TerroristSpreadModel)) if len(neighbours) > 0: # Only interact with some of the neighbors interactions = list( @@ -73,7 +73,7 @@ class TerroristSpreadModel(FSM, Geo): @state def leader(self): self.mean_belief = self.mean_belief ** (1 - self.terrorist_additional_influence) - for neighbour in self.get_neighboring_agents( + for neighbour in self.get_neighbors( state_id=[self.terrorist.id, self.leader.id] ): if self.betweenness(neighbour) > self.betweenness(self): @@ -158,7 +158,7 @@ class TrainingAreaModel(FSM, Geo): @default_state @state def terrorist(self): - for neighbour in self.get_neighboring_agents(agent_class=TerroristSpreadModel): + for neighbour in self.get_neighbors(agent_class=TerroristSpreadModel): if neighbour.vulnerability > self.min_vulnerability: neighbour.vulnerability = neighbour.vulnerability ** ( 1 - self.training_influence @@ -187,7 +187,7 @@ class HavenModel(FSM, Geo): self.max_vulnerability = model.environment_params["max_vulnerability"] def get_occupants(self, **kwargs): - return self.get_neighboring_agents(agent_class=TerroristSpreadModel, **kwargs) + return self.get_neighbors(agent_class=TerroristSpreadModel, **kwargs) @state def civilian(self): @@ -243,7 +243,7 @@ class TerroristNetworkModel(TerroristSpreadModel): return super().leader() def update_relationships(self): - if self.count_neighboring_agents(state_id=self.civilian.id) == 0: + if self.count_neighbors(state_id=self.civilian.id) == 0: close_ups = set( self.geo_search( radius=self.vision_range, agent_class=TerroristNetworkModel @@ -258,7 +258,7 @@ class TerroristNetworkModel(TerroristSpreadModel): ) neighbours = set( agent.id - for agent in self.get_neighboring_agents( + for agent in self.get_neighbors( agent_class=TerroristNetworkModel ) ) diff --git a/soil/agents/BassModel.py b/soil/agents/BassModel.py index 416063d..4410d82 100644 --- a/soil/agents/BassModel.py +++ b/soil/agents/BassModel.py @@ -20,7 +20,7 @@ class BassModel(FSM): self.sentimentCorrelation = 1 return self.aware else: - aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id) + aware_neighbors = self.get_neighbors(state_id=self.aware.id) num_neighbors_aware = len(aware_neighbors) if self.prob((self["imitation_prob"] * num_neighbors_aware)): self.sentimentCorrelation = 1 diff --git a/soil/agents/BigMarketModel.py b/soil/agents/BigMarketModel.py index 5a93b23..e606e0a 100644 --- a/soil/agents/BigMarketModel.py +++ b/soil/agents/BigMarketModel.py @@ -24,14 +24,14 @@ class BigMarketModel(FSM): self.type = "" if self.id < len(self.enterprises): # Enterprises - self.set_state(self.enterprise.id) + self._set_state(self.enterprise.id) self.type = "Enterprise" self.tweet_probability = environment.environment_params[ "tweet_probability_enterprises" ][self.id] else: # normal users self.type = "User" - self.set_state(self.user.id) + self._set_state(self.user.id) self.tweet_probability = environment.environment_params[ "tweet_probability_users" ] @@ -49,7 +49,7 @@ class BigMarketModel(FSM): def enterprise(self): if self.random.random() < self.tweet_probability: # Tweets - aware_neighbors = self.get_neighboring_agents( + aware_neighbors = self.get_neighbors( state_id=self.number_of_enterprises ) # Nodes neighbour users for x in aware_neighbors: @@ -96,7 +96,7 @@ class BigMarketModel(FSM): ] = self.sentiment_about[i] def userTweets(self, sentiment, enterprise): - aware_neighbors = self.get_neighboring_agents( + aware_neighbors = self.get_neighbors( state_id=self.number_of_enterprises ) # Nodes neighbours users for x in aware_neighbors: diff --git a/soil/agents/CounterModel.py b/soil/agents/CounterModel.py index 731c61d..6cd41fb 100644 --- a/soil/agents/CounterModel.py +++ b/soil/agents/CounterModel.py @@ -14,7 +14,7 @@ class CounterModel(NetworkAgent): def step(self): # Outside effects total = len(list(self.model.schedule._agents)) - neighbors = len(list(self.get_neighboring_agents())) + neighbors = len(list(self.get_neighbors())) self["times"] = self.get("times", 0) + 1 self["neighbors"] = neighbors self["total"] = total @@ -33,7 +33,7 @@ class AggregatedCounter(NetworkAgent): def step(self): # Outside effects self["times"] += 1 - neighbors = len(list(self.get_neighboring_agents())) + neighbors = len(list(self.get_neighbors())) self["neighbors"] += neighbors total = len(list(self.model.schedule.agents)) self["total"] += total diff --git a/soil/agents/IndependentCascadeModel.py b/soil/agents/IndependentCascadeModel.py index d3280e0..e332b07 100644 --- a/soil/agents/IndependentCascadeModel.py +++ b/soil/agents/IndependentCascadeModel.py @@ -36,7 +36,7 @@ class IndependentCascadeModel(BaseAgent): # Imitation effects if self.state["id"] == 0: - aware_neighbors = self.get_neighboring_agents(state_id=1) + aware_neighbors = self.get_neighbors(state_id=1) for x in aware_neighbors: if x.state["time_awareness"] == (self.env.now - 1): aware_neighbors_1_time_step.append(x) diff --git a/soil/agents/ModelM2.py b/soil/agents/ModelM2.py index b22cafa..4fac2b8 100644 --- a/soil/agents/ModelM2.py +++ b/soil/agents/ModelM2.py @@ -71,7 +71,7 @@ class SpreadModelM2(BaseAgent): def neutral_behaviour(self): # Infected - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) if len(infected_neighbors) > 0: if self.prob(self.prob_neutral_making_denier): self.state["id"] = 3 # Vaccinated making denier @@ -79,7 +79,7 @@ class SpreadModelM2(BaseAgent): def infected_behaviour(self): # Neutral - neutral_neighbors = self.get_neighboring_agents(state_id=0) + neutral_neighbors = self.get_neighbors(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_infect): neighbor.state["id"] = 1 # Infected @@ -87,13 +87,13 @@ class SpreadModelM2(BaseAgent): def cured_behaviour(self): # Vaccinate - neutral_neighbors = self.get_neighboring_agents(state_id=0) + neutral_neighbors = self.get_neighbors(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): neighbor.state["id"] = 3 # Vaccinated # Cure - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): neighbor.state["id"] = 2 # Cured @@ -101,19 +101,19 @@ class SpreadModelM2(BaseAgent): def vaccinated_behaviour(self): # Cure - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): neighbor.state["id"] = 2 # Cured # Vaccinate - neutral_neighbors = self.get_neighboring_agents(state_id=0) + neutral_neighbors = self.get_neighbors(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): neighbor.state["id"] = 3 # Vaccinated # Generate anti-rumor - infected_neighbors_2 = self.get_neighboring_agents(state_id=1) + infected_neighbors_2 = self.get_neighbors(state_id=1) for neighbor in infected_neighbors_2: if self.prob(self.prob_generate_anti_rumor): neighbor.state["id"] = 2 # Cured @@ -191,7 +191,7 @@ class ControlModelM2(BaseAgent): self.state["visible"] = False # Infected - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) if len(infected_neighbors) > 0: if self.random(self.prob_neutral_making_denier): self.state["id"] = 3 # Vaccinated making denier @@ -199,7 +199,7 @@ class ControlModelM2(BaseAgent): def infected_behaviour(self): # Neutral - neutral_neighbors = self.get_neighboring_agents(state_id=0) + neutral_neighbors = self.get_neighbors(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_infect): neighbor.state["id"] = 1 # Infected @@ -209,13 +209,13 @@ class ControlModelM2(BaseAgent): self.state["visible"] = True # Vaccinate - neutral_neighbors = self.get_neighboring_agents(state_id=0) + neutral_neighbors = self.get_neighbors(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): neighbor.state["id"] = 3 # Vaccinated # Cure - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): neighbor.state["id"] = 2 # Cured @@ -224,47 +224,47 @@ class ControlModelM2(BaseAgent): self.state["visible"] = True # Cure - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_cured_healing_infected): neighbor.state["id"] = 2 # Cured # Vaccinate - neutral_neighbors = self.get_neighboring_agents(state_id=0) + neutral_neighbors = self.get_neighbors(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): neighbor.state["id"] = 3 # Vaccinated # Generate anti-rumor - infected_neighbors_2 = self.get_neighboring_agents(state_id=1) + infected_neighbors_2 = self.get_neighbors(state_id=1) for neighbor in infected_neighbors_2: if self.prob(self.prob_generate_anti_rumor): neighbor.state["id"] = 2 # Cured def beacon_off_behaviour(self): self.state["visible"] = False - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) if len(infected_neighbors) > 0: self.state["id"] == 5 # Beacon on def beacon_on_behaviour(self): self.state["visible"] = False # Cure (M2 feature added) - infected_neighbors = self.get_neighboring_agents(state_id=1) + infected_neighbors = self.get_neighbors(state_id=1) for neighbor in infected_neighbors: if self.prob(self.prob_generate_anti_rumor): neighbor.state["id"] = 2 # Cured - neutral_neighbors_infected = neighbor.get_neighboring_agents(state_id=0) + neutral_neighbors_infected = neighbor.get_neighbors(state_id=0) for neighbor in neutral_neighbors_infected: if self.prob(self.prob_generate_anti_rumor): neighbor.state["id"] = 3 # Vaccinated - infected_neighbors_infected = neighbor.get_neighboring_agents(state_id=1) + infected_neighbors_infected = neighbor.get_neighbors(state_id=1) for neighbor in infected_neighbors_infected: if self.prob(self.prob_generate_anti_rumor): neighbor.state["id"] = 2 # Cured # Vaccinate - neutral_neighbors = self.get_neighboring_agents(state_id=0) + neutral_neighbors = self.get_neighbors(state_id=0) for neighbor in neutral_neighbors: if self.prob(self.prob_cured_vaccinate_neutral): neighbor.state["id"] = 3 # Vaccinated diff --git a/soil/agents/SISaModel.py b/soil/agents/SISaModel.py index e298e8a..45d9328 100644 --- a/soil/agents/SISaModel.py +++ b/soil/agents/SISaModel.py @@ -69,10 +69,10 @@ class SISaModel(FSM): return self.content # Infected - discontent_neighbors = self.count_neighboring_agents(state_id=self.discontent) + discontent_neighbors = self.count_neighbors(state_id=self.discontent) if self.prob(scontent_neighbors * self.neutral_discontent_infected_prob): return self.discontent - content_neighbors = self.count_neighboring_agents(state_id=self.content.id) + content_neighbors = self.count_neighbors(state_id=self.content.id) if self.prob(s * self.neutral_content_infected_prob): return self.content return self.neutral @@ -84,7 +84,7 @@ class SISaModel(FSM): return self.neutral # Superinfected - content_neighbors = self.count_neighboring_agents(state_id=self.content.id) + content_neighbors = self.count_neighbors(state_id=self.content.id) if self.prob(s * self.discontent_content): return self.content return self.discontent @@ -96,9 +96,7 @@ class SISaModel(FSM): return self.neutral # Superinfected - discontent_neighbors = self.count_neighboring_agents( - state_id=self.discontent.id - ) + discontent_neighbors = self.count_neighbors(state_id=self.discontent.id) if self.prob(scontent_neighbors * self.content_discontent): self.discontent return self.content diff --git a/soil/agents/SentimentCorrelationModel.py b/soil/agents/SentimentCorrelationModel.py index 721d026..751a59a 100644 --- a/soil/agents/SentimentCorrelationModel.py +++ b/soil/agents/SentimentCorrelationModel.py @@ -41,25 +41,25 @@ class SentimentCorrelationModel(BaseAgent): sad_neighbors_1_time_step = [] disgusted_neighbors_1_time_step = [] - angry_neighbors = self.get_neighboring_agents(state_id=1) + angry_neighbors = self.get_neighbors(state_id=1) for x in angry_neighbors: if x.state["time_awareness"][0] > (self.env.now - 500): angry_neighbors_1_time_step.append(x) num_neighbors_angry = len(angry_neighbors_1_time_step) - joyful_neighbors = self.get_neighboring_agents(state_id=2) + joyful_neighbors = self.get_neighbors(state_id=2) for x in joyful_neighbors: if x.state["time_awareness"][1] > (self.env.now - 500): joyful_neighbors_1_time_step.append(x) num_neighbors_joyful = len(joyful_neighbors_1_time_step) - sad_neighbors = self.get_neighboring_agents(state_id=3) + sad_neighbors = self.get_neighbors(state_id=3) for x in sad_neighbors: if x.state["time_awareness"][2] > (self.env.now - 500): sad_neighbors_1_time_step.append(x) num_neighbors_sad = len(sad_neighbors_1_time_step) - disgusted_neighbors = self.get_neighboring_agents(state_id=4) + disgusted_neighbors = self.get_neighbors(state_id=4) for x in disgusted_neighbors: if x.state["time_awareness"][3] > (self.env.now - 500): disgusted_neighbors_1_time_step.append(x) diff --git a/soil/agents/__init__.py b/soil/agents/__init__.py index 714b15e..9b5736b 100644 --- a/soil/agents/__init__.py +++ b/soil/agents/__init__.py @@ -243,223 +243,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent): return f"{self.__class__.__name__}({self.unique_id})" -class NetworkAgent(BaseAgent): - def __init__(self, *args, topology, node_id, **kwargs): - super().__init__(*args, **kwargs) - - assert topology is not None - assert node_id is not None - self.G = topology - assert self.G - self.node_id = node_id - - def count_neighboring_agents(self, state_id=None, **kwargs): - return len(self.get_neighboring_agents(state_id=state_id, **kwargs)) - - def get_neighboring_agents(self, **kwargs): - return list(self.iter_agents(limit_neighbors=True, **kwargs)) - - def add_edge(self, other): - assert self.node_id - assert other.node_id - assert self.node_id in self.G.nodes - assert other.node_id in self.G.nodes - self.topology.add_edge(self.node_id, other.node_id) - - @property - def node(self): - return self.topology.nodes[self.node_id] - - def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs): - unique_ids = None - if isinstance(unique_id, list): - unique_ids = set(unique_id) - elif unique_id is not None: - unique_ids = set( - [ - unique_id, - ] - ) - - if limit_neighbors: - neighbor_ids = set() - for node_id in self.G.neighbors(self.node_id): - if self.G.nodes[node_id].get("agent") is not None: - neighbor_ids.add(node_id) - if unique_ids: - unique_ids = unique_ids & neighbor_ids - else: - unique_ids = neighbor_ids - if not unique_ids: - return - unique_ids = list(unique_ids) - yield from super().iter_agents(unique_id=unique_ids, **kwargs) - - def subgraph(self, center=True, **kwargs): - include = [self] if center else [] - G = self.G.subgraph( - n.node_id for n in list(self.get_agents(**kwargs) + include) - ) - return G - - def remove_node(self): - print(f"Removing node for {self.unique_id}: {self.node_id}") - self.G.remove_node(self.node_id) - self.node_id = None - - def add_edge(self, other, edge_attr_dict=None, *edge_attrs): - if self.node_id not in self.G.nodes(data=False): - raise ValueError( - "{} not in list of existing agents in the network".format( - self.unique_id - ) - ) - if other.node_id not in self.G.nodes(data=False): - raise ValueError( - "{} not in list of existing agents in the network".format(other) - ) - - self.G.add_edge( - self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs - ) - - def die(self, remove=True): - if not self.alive: - return - if remove: - self.remove_node() - return super().die() - - -def state(name=None): - def decorator(func, name=None): - """ - A state function should return either a state id, or a tuple (state_id, when) - The default value for state_id is the current state id. - The default value for when is the interval defined in the environment. - """ - if inspect.isgeneratorfunction(func): - orig_func = func - - @wraps(func) - def func(self): - while True: - if not self._coroutine: - self._coroutine = orig_func(self) - try: - n = next(self._coroutine) - if n: - return None, n - return - except StopIteration as ex: - self._coroutine = None - next_state = ex.value - if next_state is not None: - self._set_state(next_state) - return next_state - - func.id = name or func.__name__ - func.is_default = False - return func - - if callable(name): - return decorator(name) - else: - return partial(decorator, name=name) - - -def default_state(func): - func.is_default = True - return func - - -class MetaFSM(MetaAgent): - def __new__(mcls, name, bases, namespace): - states = {} - # Re-use states from inherited classes - default_state = None - for i in bases: - if isinstance(i, MetaFSM): - for state_id, state in i._states.items(): - if state.is_default: - default_state = state - states[state_id] = state - - # Add new states - for attr, func in namespace.items(): - if hasattr(func, "id"): - if func.is_default: - default_state = func - states[func.id] = func - - namespace.update( - { - "_default_state": default_state, - "_states": states, - } - ) - - return super(MetaFSM, mcls).__new__( - mcls=mcls, name=name, bases=bases, namespace=namespace - ) - - -class FSM(BaseAgent, metaclass=MetaFSM): - def __init__(self, **kwargs): - super(FSM, self).__init__(**kwargs) - if not hasattr(self, "state_id"): - if not self._default_state: - raise ValueError( - "No default state specified for {}".format(self.unique_id) - ) - self.state_id = self._default_state.id - - self._coroutine = None - self._set_state(self.state_id) - - def step(self): - self.debug(f"Agent {self.unique_id} @ state {self.state_id}") - default_interval = super().step() - - next_state = self._states[self.state_id](self) - - when = None - try: - next_state, *when = next_state - if not when: - when = None - elif len(when) == 1: - when = when[0] - else: - raise ValueError( - "Too many values returned. Only state (and time) allowed" - ) - except TypeError: - pass - - if next_state is not None: - self._set_state(next_state) - - return when or default_interval - - def _set_state(self, state, when=None): - if hasattr(state, "id"): - state = state.id - if state not in self._states: - raise ValueError("{} is not a valid state".format(state)) - self.state_id = state - if when is not None: - self.model.schedule.add(self, when=when) - return state - - def die(self): - return self.dead, super().die() - - @state - def dead(self): - return self.die() - - def prob(prob, random): """ A true/False uniform distribution with a given probability. @@ -525,7 +308,7 @@ def calculate_distribution(network_agents=None, agent_class=None): return network_agents -def serialize_type(agent_class, known_modules=[], **kwargs): +def _serialize_type(agent_class, known_modules=[], **kwargs): if isinstance(agent_class, str): return agent_class known_modules += ["soil.agents"] @@ -534,20 +317,7 @@ def serialize_type(agent_class, known_modules=[], **kwargs): ] # Get the name of the class -def serialize_definition(network_agents, known_modules=[]): - """ - When serializing an agent distribution, remove the thresholds, in order - to avoid cluttering the YAML definition file. - """ - d = deepcopy(list(network_agents)) - for v in d: - if "threshold" in v: - del v["threshold"] - v["agent_class"] = serialize_type(v["agent_class"], known_modules=known_modules) - return d - - -def deserialize_type(agent_class, known_modules=[]): +def _deserialize_type(agent_class, known_modules=[]): if not isinstance(agent_class, str): return agent_class known = known_modules + ["soil.agents", "soil.agents.custom"] @@ -555,31 +325,6 @@ def deserialize_type(agent_class, known_modules=[]): return agent_class -def deserialize_definition(ind, **kwargs): - d = deepcopy(ind) - for v in d: - v["agent_class"] = deserialize_type(v["agent_class"], **kwargs) - return d - - -def _validate_states(states, topology): - """Validate states to avoid ignoring states during initialization""" - states = states or [] - if isinstance(states, dict): - for x in states: - assert x in topology.nodes - else: - assert len(states) <= len(topology) - return states - - -def _convert_agent_classs(ind, to_string=False, **kwargs): - """Convenience method to allow specifying agents by class or class name.""" - if to_string: - return serialize_definition(ind, **kwargs) - return deserialize_definition(ind, **kwargs) - - class AgentView(Mapping, Set): """A lazy-loaded list of agents.""" @@ -663,7 +408,7 @@ def filter_agents( state_id = tuple([state_id]) if agent_class is not None: - agent_class = deserialize_type(agent_class) + agent_class = _deserialize_type(agent_class) try: agent_class = tuple(agent_class) except TypeError: @@ -703,14 +448,6 @@ def from_config( default = cfg or config.AgentConfig() if not isinstance(cfg, config.AgentConfig): cfg = config.AgentConfig(**cfg) - return _agents_from_config(cfg, topology=topology, random=random) - - -def _agents_from_config( - cfg: config.AgentConfig, topology: nx.Graph, random -) -> List[Dict[str, Any]]: - if cfg and not isinstance(cfg, config.AgentConfig): - cfg = config.AgentConfig(**cfg) agents = [] @@ -878,6 +615,8 @@ def _from_distro( return agents +from .network_agents import * +from .fsm import * from .BassModel import * from .BigMarketModel import * from .IndependentCascadeModel import * diff --git a/soil/agents/fsm.py b/soil/agents/fsm.py new file mode 100644 index 0000000..729313d --- /dev/null +++ b/soil/agents/fsm.py @@ -0,0 +1,133 @@ +from . import MetaAgent, BaseAgent + +from functools import partial +import inspect + + +def state(name=None): + def decorator(func, name=None): + """ + A state function should return either a state id, or a tuple (state_id, when) + The default value for state_id is the current state id. + The default value for when is the interval defined in the environment. + """ + if inspect.isgeneratorfunction(func): + orig_func = func + + @wraps(func) + def func(self): + while True: + if not self._coroutine: + self._coroutine = orig_func(self) + try: + n = next(self._coroutine) + if n: + return None, n + return + except StopIteration as ex: + self._coroutine = None + next_state = ex.value + if next_state is not None: + self._set_state(next_state) + return next_state + + func.id = name or func.__name__ + func.is_default = False + return func + + if callable(name): + return decorator(name) + else: + return partial(decorator, name=name) + + +def default_state(func): + func.is_default = True + return func + + +class MetaFSM(MetaAgent): + def __new__(mcls, name, bases, namespace): + states = {} + # Re-use states from inherited classes + default_state = None + for i in bases: + if isinstance(i, MetaFSM): + for state_id, state in i._states.items(): + if state.is_default: + default_state = state + states[state_id] = state + + # Add new states + for attr, func in namespace.items(): + if hasattr(func, "id"): + if func.is_default: + default_state = func + states[func.id] = func + + namespace.update( + { + "_default_state": default_state, + "_states": states, + } + ) + + return super(MetaFSM, mcls).__new__( + mcls=mcls, name=name, bases=bases, namespace=namespace + ) + + +class FSM(BaseAgent, metaclass=MetaFSM): + def __init__(self, **kwargs): + super(FSM, self).__init__(**kwargs) + if not hasattr(self, "state_id"): + if not self._default_state: + raise ValueError( + "No default state specified for {}".format(self.unique_id) + ) + self.state_id = self._default_state.id + + self._coroutine = None + self._set_state(self.state_id) + + def step(self): + self.debug(f"Agent {self.unique_id} @ state {self.state_id}") + default_interval = super().step() + + next_state = self._states[self.state_id](self) + + when = None + try: + next_state, *when = next_state + if not when: + when = None + elif len(when) == 1: + when = when[0] + else: + raise ValueError( + "Too many values returned. Only state (and time) allowed" + ) + except TypeError: + pass + + if next_state is not None: + self._set_state(next_state) + + return when or default_interval + + def _set_state(self, state, when=None): + if hasattr(state, "id"): + state = state.id + if state not in self._states: + raise ValueError("{} is not a valid state".format(state)) + self.state_id = state + if when is not None: + self.model.schedule.add(self, when=when) + return state + + def die(self): + return self.dead, super().die() + + @state + def dead(self): + return self.die() diff --git a/soil/agents/network_agents.py b/soil/agents/network_agents.py new file mode 100644 index 0000000..d9950cf --- /dev/null +++ b/soil/agents/network_agents.py @@ -0,0 +1,82 @@ +from . import BaseAgent + + +class NetworkAgent(BaseAgent): + def __init__(self, *args, topology, node_id, **kwargs): + super().__init__(*args, **kwargs) + + assert topology is not None + assert node_id is not None + self.G = topology + assert self.G + self.node_id = node_id + + def count_neighbors(self, state_id=None, **kwargs): + return len(self.get_neighbors(state_id=state_id, **kwargs)) + + def get_neighbors(self, **kwargs): + return list(self.iter_agents(limit_neighbors=True, **kwargs)) + + @property + def node(self): + return self.G.nodes[self.node_id] + + def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs): + unique_ids = None + if isinstance(unique_id, list): + unique_ids = set(unique_id) + elif unique_id is not None: + unique_ids = set( + [ + unique_id, + ] + ) + + if limit_neighbors: + neighbor_ids = set() + for node_id in self.G.neighbors(self.node_id): + if self.G.nodes[node_id].get("agent") is not None: + neighbor_ids.add(node_id) + if unique_ids: + unique_ids = unique_ids & neighbor_ids + else: + unique_ids = neighbor_ids + if not unique_ids: + return + unique_ids = list(unique_ids) + yield from super().iter_agents(unique_id=unique_ids, **kwargs) + + def subgraph(self, center=True, **kwargs): + include = [self] if center else [] + G = self.G.subgraph( + n.node_id for n in list(self.get_agents(**kwargs) + include) + ) + return G + + def remove_node(self): + print(f"Removing node for {self.unique_id}: {self.node_id}") + self.G.remove_node(self.node_id) + self.node_id = None + + def add_edge(self, other, edge_attr_dict=None, *edge_attrs): + if self.node_id not in self.G.nodes(data=False): + raise ValueError( + "{} not in list of existing agents in the network".format( + self.unique_id + ) + ) + if other.node_id not in self.G.nodes(data=False): + raise ValueError( + "{} not in list of existing agents in the network".format(other) + ) + + self.G.add_edge( + self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs + ) + + def die(self, remove=True): + if not self.alive: + return None + if remove: + self.remove_node() + return super().die() diff --git a/tests/test_agents.py b/tests/test_agents.py index 4006e9d..35526e3 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -45,3 +45,17 @@ class TestMain(TestCase): for i in range(5): t = g.step() assert t == i + + def test_state_decorator(self): + class MyAgent(agents.FSM): + run = 0 + @agents.default_state + @agents.state('original') + def root(self): + self.run += 1 + e = environment.Environment() + a = MyAgent(model=e, unique_id=e.next_id()) + a.step() + assert a.run == 1 + a.step() + assert a.run == 2 diff --git a/tests/test_main.py b/tests/test_main.py index f2004ad..8f4f97c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -160,32 +160,12 @@ class TestMain(TestCase): def test_serialize_agent_class(self): """A class from soil.agents should be serialized without the module part""" - ser = agents.serialize_type(CustomAgent) + ser = agents._serialize_type(CustomAgent) assert ser == "test_main.CustomAgent" - ser = agents.serialize_type(agents.BaseAgent) + ser = agents._serialize_type(agents.BaseAgent) assert ser == "BaseAgent" pickle.dumps(ser) - def test_deserialize_agent_distribution(self): - agent_distro = [ - {"agent_class": "CounterModel", "weight": 1}, - {"agent_class": "test_main.CustomAgent", "weight": 2}, - ] - converted = agents.deserialize_definition(agent_distro) - assert converted[0]["agent_class"] == agents.CounterModel - assert converted[1]["agent_class"] == CustomAgent - pickle.dumps(converted) - - def test_serialize_agent_distribution(self): - agent_distro = [ - {"agent_class": agents.CounterModel, "weight": 1}, - {"agent_class": CustomAgent, "weight": 2}, - ] - converted = agents.serialize_definition(agent_distro) - assert converted[0]["agent_class"] == "CounterModel" - assert converted[1]["agent_class"] == "test_main.CustomAgent" - pickle.dumps(converted) - def test_templates(self): """Loading a template should result in several configs""" configs = serialization.load_file(join(EXAMPLES, "template.yml"))