mirror of
https://github.com/gsi-upm/soil
synced 2024-11-22 03:02:28 +00:00
Refactor
* Removed references to `set_state` * Split some functionality from `agents` into separate files (`fsm` and `network_agents`) * Rename `neighboring_agents` to `neighbors` * Delete some spurious functions
This commit is contained in:
parent
880a9f2a1c
commit
3776c4e5c5
@ -58,7 +58,7 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
|
|||||||
|
|
||||||
def give_money(self):
|
def give_money(self):
|
||||||
cellmates = set(self.model.grid.get_cell_list_contents([self.pos]))
|
cellmates = set(self.model.grid.get_cell_list_contents([self.pos]))
|
||||||
friends = set(self.get_neighboring_agents())
|
friends = set(self.get_neighbors())
|
||||||
self.info("Trying to give money")
|
self.info("Trying to give money")
|
||||||
self.info("Cellmates: ", cellmates)
|
self.info("Cellmates: ", cellmates)
|
||||||
self.info("Friends: ", friends)
|
self.info("Friends: ", friends)
|
||||||
|
@ -8,10 +8,9 @@ class DumbViewer(FSM, NetworkAgent):
|
|||||||
its neighbors once it's infected.
|
its neighbors once it's infected.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
defaults = {
|
prob_neighbor_spread = 0.5
|
||||||
"prob_neighbor_spread": 0.5,
|
prob_tv_spread = 0.1
|
||||||
"prob_tv_spread": 0.1,
|
has_been_infected = False
|
||||||
}
|
|
||||||
|
|
||||||
@default_state
|
@default_state
|
||||||
@state
|
@state
|
||||||
@ -19,10 +18,12 @@ class DumbViewer(FSM, NetworkAgent):
|
|||||||
if self["has_tv"]:
|
if self["has_tv"]:
|
||||||
if self.prob(self.model["prob_tv_spread"]):
|
if self.prob(self.model["prob_tv_spread"]):
|
||||||
return self.infected
|
return self.infected
|
||||||
|
if self.has_been_infected:
|
||||||
|
return self.infected
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
for neighbor in self.get_neighbors(state_id=self.neutral.id):
|
||||||
if self.prob(self.model["prob_neighbor_spread"]):
|
if self.prob(self.model["prob_neighbor_spread"]):
|
||||||
neighbor.infect()
|
neighbor.infect()
|
||||||
|
|
||||||
@ -33,7 +34,7 @@ class DumbViewer(FSM, NetworkAgent):
|
|||||||
HerdViewer might not become infected right away
|
HerdViewer might not become infected right away
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.set_state(self.infected)
|
self.has_been_infected = True
|
||||||
|
|
||||||
|
|
||||||
class HerdViewer(DumbViewer):
|
class HerdViewer(DumbViewer):
|
||||||
@ -43,12 +44,12 @@ class HerdViewer(DumbViewer):
|
|||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
"""Notice again that this is NOT a state. See DumbViewer.infect for reference"""
|
"""Notice again that this is NOT a state. See DumbViewer.infect for reference"""
|
||||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
infected = self.count_neighbors(state_id=self.infected.id)
|
||||||
total = self.count_neighboring_agents()
|
total = self.count_neighbors()
|
||||||
prob_infect = self.model["prob_neighbor_spread"] * infected / total
|
prob_infect = self.model["prob_neighbor_spread"] * infected / total
|
||||||
self.debug("prob_infect", prob_infect)
|
self.debug("prob_infect", prob_infect)
|
||||||
if self.prob(prob_infect):
|
if self.prob(prob_infect):
|
||||||
self.set_state(self.infected)
|
self.has_been_infected = True
|
||||||
|
|
||||||
|
|
||||||
class WiseViewer(HerdViewer):
|
class WiseViewer(HerdViewer):
|
||||||
@ -65,7 +66,7 @@ class WiseViewer(HerdViewer):
|
|||||||
@state
|
@state
|
||||||
def cured(self):
|
def cured(self):
|
||||||
prob_cure = self.model["prob_neighbor_cure"]
|
prob_cure = self.model["prob_neighbor_cure"]
|
||||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
for neighbor in self.get_neighbors(state_id=self.infected.id):
|
||||||
if self.prob(prob_cure):
|
if self.prob(prob_cure):
|
||||||
try:
|
try:
|
||||||
neighbor.cure()
|
neighbor.cure()
|
||||||
@ -73,13 +74,14 @@ class WiseViewer(HerdViewer):
|
|||||||
self.debug("Viewer {} cannot be cured".format(neighbor.id))
|
self.debug("Viewer {} cannot be cured".format(neighbor.id))
|
||||||
|
|
||||||
def cure(self):
|
def cure(self):
|
||||||
self.set_state(self.cured.id)
|
self.has_been_cured = True
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
cured = max(self.count_neighboring_agents(self.cured.id), 1.0)
|
if self.has_been_cured:
|
||||||
infected = max(self.count_neighboring_agents(self.infected.id), 1.0)
|
return self.cured
|
||||||
|
cured = max(self.count_neighbors(self.cured.id), 1.0)
|
||||||
|
infected = max(self.count_neighbors(self.infected.id), 1.0)
|
||||||
prob_cure = self.model["prob_neighbor_cure"] * (cured / infected)
|
prob_cure = self.model["prob_neighbor_cure"] * (cured / infected)
|
||||||
if self.prob(prob_cure):
|
if self.prob(prob_cure):
|
||||||
return self.cured
|
return self.cured
|
||||||
return self.set_state(super().infected)
|
|
||||||
|
@ -89,7 +89,7 @@ class Patron(FSM, NetworkAgent):
|
|||||||
if self["pub"] != None:
|
if self["pub"] != None:
|
||||||
return self.sober_in_pub
|
return self.sober_in_pub
|
||||||
self.debug("I am looking for a pub")
|
self.debug("I am looking for a pub")
|
||||||
group = list(self.get_neighboring_agents())
|
group = list(self.get_neighbors())
|
||||||
for pub in self.model.available_pubs():
|
for pub in self.model.available_pubs():
|
||||||
self.debug("We're trying to get into {}: total: {}".format(pub, len(group)))
|
self.debug("We're trying to get into {}: total: {}".format(pub, len(group)))
|
||||||
if self.model.enter(pub, self, *group):
|
if self.model.enter(pub, self, *group):
|
||||||
|
@ -49,7 +49,7 @@ class TerroristSpreadModel(FSM, Geo):
|
|||||||
|
|
||||||
@state
|
@state
|
||||||
def civilian(self):
|
def civilian(self):
|
||||||
neighbours = list(self.get_neighboring_agents(agent_class=TerroristSpreadModel))
|
neighbours = list(self.get_neighbors(agent_class=TerroristSpreadModel))
|
||||||
if len(neighbours) > 0:
|
if len(neighbours) > 0:
|
||||||
# Only interact with some of the neighbors
|
# Only interact with some of the neighbors
|
||||||
interactions = list(
|
interactions = list(
|
||||||
@ -73,7 +73,7 @@ class TerroristSpreadModel(FSM, Geo):
|
|||||||
@state
|
@state
|
||||||
def leader(self):
|
def leader(self):
|
||||||
self.mean_belief = self.mean_belief ** (1 - self.terrorist_additional_influence)
|
self.mean_belief = self.mean_belief ** (1 - self.terrorist_additional_influence)
|
||||||
for neighbour in self.get_neighboring_agents(
|
for neighbour in self.get_neighbors(
|
||||||
state_id=[self.terrorist.id, self.leader.id]
|
state_id=[self.terrorist.id, self.leader.id]
|
||||||
):
|
):
|
||||||
if self.betweenness(neighbour) > self.betweenness(self):
|
if self.betweenness(neighbour) > self.betweenness(self):
|
||||||
@ -158,7 +158,7 @@ class TrainingAreaModel(FSM, Geo):
|
|||||||
@default_state
|
@default_state
|
||||||
@state
|
@state
|
||||||
def terrorist(self):
|
def terrorist(self):
|
||||||
for neighbour in self.get_neighboring_agents(agent_class=TerroristSpreadModel):
|
for neighbour in self.get_neighbors(agent_class=TerroristSpreadModel):
|
||||||
if neighbour.vulnerability > self.min_vulnerability:
|
if neighbour.vulnerability > self.min_vulnerability:
|
||||||
neighbour.vulnerability = neighbour.vulnerability ** (
|
neighbour.vulnerability = neighbour.vulnerability ** (
|
||||||
1 - self.training_influence
|
1 - self.training_influence
|
||||||
@ -187,7 +187,7 @@ class HavenModel(FSM, Geo):
|
|||||||
self.max_vulnerability = model.environment_params["max_vulnerability"]
|
self.max_vulnerability = model.environment_params["max_vulnerability"]
|
||||||
|
|
||||||
def get_occupants(self, **kwargs):
|
def get_occupants(self, **kwargs):
|
||||||
return self.get_neighboring_agents(agent_class=TerroristSpreadModel, **kwargs)
|
return self.get_neighbors(agent_class=TerroristSpreadModel, **kwargs)
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def civilian(self):
|
def civilian(self):
|
||||||
@ -243,7 +243,7 @@ class TerroristNetworkModel(TerroristSpreadModel):
|
|||||||
return super().leader()
|
return super().leader()
|
||||||
|
|
||||||
def update_relationships(self):
|
def update_relationships(self):
|
||||||
if self.count_neighboring_agents(state_id=self.civilian.id) == 0:
|
if self.count_neighbors(state_id=self.civilian.id) == 0:
|
||||||
close_ups = set(
|
close_ups = set(
|
||||||
self.geo_search(
|
self.geo_search(
|
||||||
radius=self.vision_range, agent_class=TerroristNetworkModel
|
radius=self.vision_range, agent_class=TerroristNetworkModel
|
||||||
@ -258,7 +258,7 @@ class TerroristNetworkModel(TerroristSpreadModel):
|
|||||||
)
|
)
|
||||||
neighbours = set(
|
neighbours = set(
|
||||||
agent.id
|
agent.id
|
||||||
for agent in self.get_neighboring_agents(
|
for agent in self.get_neighbors(
|
||||||
agent_class=TerroristNetworkModel
|
agent_class=TerroristNetworkModel
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,7 @@ class BassModel(FSM):
|
|||||||
self.sentimentCorrelation = 1
|
self.sentimentCorrelation = 1
|
||||||
return self.aware
|
return self.aware
|
||||||
else:
|
else:
|
||||||
aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id)
|
aware_neighbors = self.get_neighbors(state_id=self.aware.id)
|
||||||
num_neighbors_aware = len(aware_neighbors)
|
num_neighbors_aware = len(aware_neighbors)
|
||||||
if self.prob((self["imitation_prob"] * num_neighbors_aware)):
|
if self.prob((self["imitation_prob"] * num_neighbors_aware)):
|
||||||
self.sentimentCorrelation = 1
|
self.sentimentCorrelation = 1
|
||||||
|
@ -24,14 +24,14 @@ class BigMarketModel(FSM):
|
|||||||
self.type = ""
|
self.type = ""
|
||||||
|
|
||||||
if self.id < len(self.enterprises): # Enterprises
|
if self.id < len(self.enterprises): # Enterprises
|
||||||
self.set_state(self.enterprise.id)
|
self._set_state(self.enterprise.id)
|
||||||
self.type = "Enterprise"
|
self.type = "Enterprise"
|
||||||
self.tweet_probability = environment.environment_params[
|
self.tweet_probability = environment.environment_params[
|
||||||
"tweet_probability_enterprises"
|
"tweet_probability_enterprises"
|
||||||
][self.id]
|
][self.id]
|
||||||
else: # normal users
|
else: # normal users
|
||||||
self.type = "User"
|
self.type = "User"
|
||||||
self.set_state(self.user.id)
|
self._set_state(self.user.id)
|
||||||
self.tweet_probability = environment.environment_params[
|
self.tweet_probability = environment.environment_params[
|
||||||
"tweet_probability_users"
|
"tweet_probability_users"
|
||||||
]
|
]
|
||||||
@ -49,7 +49,7 @@ class BigMarketModel(FSM):
|
|||||||
def enterprise(self):
|
def enterprise(self):
|
||||||
|
|
||||||
if self.random.random() < self.tweet_probability: # Tweets
|
if self.random.random() < self.tweet_probability: # Tweets
|
||||||
aware_neighbors = self.get_neighboring_agents(
|
aware_neighbors = self.get_neighbors(
|
||||||
state_id=self.number_of_enterprises
|
state_id=self.number_of_enterprises
|
||||||
) # Nodes neighbour users
|
) # Nodes neighbour users
|
||||||
for x in aware_neighbors:
|
for x in aware_neighbors:
|
||||||
@ -96,7 +96,7 @@ class BigMarketModel(FSM):
|
|||||||
] = self.sentiment_about[i]
|
] = self.sentiment_about[i]
|
||||||
|
|
||||||
def userTweets(self, sentiment, enterprise):
|
def userTweets(self, sentiment, enterprise):
|
||||||
aware_neighbors = self.get_neighboring_agents(
|
aware_neighbors = self.get_neighbors(
|
||||||
state_id=self.number_of_enterprises
|
state_id=self.number_of_enterprises
|
||||||
) # Nodes neighbours users
|
) # Nodes neighbours users
|
||||||
for x in aware_neighbors:
|
for x in aware_neighbors:
|
||||||
|
@ -14,7 +14,7 @@ class CounterModel(NetworkAgent):
|
|||||||
def step(self):
|
def step(self):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
total = len(list(self.model.schedule._agents))
|
total = len(list(self.model.schedule._agents))
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighbors()))
|
||||||
self["times"] = self.get("times", 0) + 1
|
self["times"] = self.get("times", 0) + 1
|
||||||
self["neighbors"] = neighbors
|
self["neighbors"] = neighbors
|
||||||
self["total"] = total
|
self["total"] = total
|
||||||
@ -33,7 +33,7 @@ class AggregatedCounter(NetworkAgent):
|
|||||||
def step(self):
|
def step(self):
|
||||||
# Outside effects
|
# Outside effects
|
||||||
self["times"] += 1
|
self["times"] += 1
|
||||||
neighbors = len(list(self.get_neighboring_agents()))
|
neighbors = len(list(self.get_neighbors()))
|
||||||
self["neighbors"] += neighbors
|
self["neighbors"] += neighbors
|
||||||
total = len(list(self.model.schedule.agents))
|
total = len(list(self.model.schedule.agents))
|
||||||
self["total"] += total
|
self["total"] += total
|
||||||
|
@ -36,7 +36,7 @@ class IndependentCascadeModel(BaseAgent):
|
|||||||
|
|
||||||
# Imitation effects
|
# Imitation effects
|
||||||
if self.state["id"] == 0:
|
if self.state["id"] == 0:
|
||||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
aware_neighbors = self.get_neighbors(state_id=1)
|
||||||
for x in aware_neighbors:
|
for x in aware_neighbors:
|
||||||
if x.state["time_awareness"] == (self.env.now - 1):
|
if x.state["time_awareness"] == (self.env.now - 1):
|
||||||
aware_neighbors_1_time_step.append(x)
|
aware_neighbors_1_time_step.append(x)
|
||||||
|
@ -71,7 +71,7 @@ class SpreadModelM2(BaseAgent):
|
|||||||
def neutral_behaviour(self):
|
def neutral_behaviour(self):
|
||||||
|
|
||||||
# Infected
|
# Infected
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
if len(infected_neighbors) > 0:
|
if len(infected_neighbors) > 0:
|
||||||
if self.prob(self.prob_neutral_making_denier):
|
if self.prob(self.prob_neutral_making_denier):
|
||||||
self.state["id"] = 3 # Vaccinated making denier
|
self.state["id"] = 3 # Vaccinated making denier
|
||||||
@ -79,7 +79,7 @@ class SpreadModelM2(BaseAgent):
|
|||||||
def infected_behaviour(self):
|
def infected_behaviour(self):
|
||||||
|
|
||||||
# Neutral
|
# Neutral
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_infect):
|
if self.prob(self.prob_infect):
|
||||||
neighbor.state["id"] = 1 # Infected
|
neighbor.state["id"] = 1 # Infected
|
||||||
@ -87,13 +87,13 @@ class SpreadModelM2(BaseAgent):
|
|||||||
def cured_behaviour(self):
|
def cured_behaviour(self):
|
||||||
|
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state["id"] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Cure
|
# Cure
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
@ -101,19 +101,19 @@ class SpreadModelM2(BaseAgent):
|
|||||||
def vaccinated_behaviour(self):
|
def vaccinated_behaviour(self):
|
||||||
|
|
||||||
# Cure
|
# Cure
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state["id"] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Generate anti-rumor
|
# Generate anti-rumor
|
||||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
infected_neighbors_2 = self.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors_2:
|
for neighbor in infected_neighbors_2:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
@ -191,7 +191,7 @@ class ControlModelM2(BaseAgent):
|
|||||||
self.state["visible"] = False
|
self.state["visible"] = False
|
||||||
|
|
||||||
# Infected
|
# Infected
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
if len(infected_neighbors) > 0:
|
if len(infected_neighbors) > 0:
|
||||||
if self.random(self.prob_neutral_making_denier):
|
if self.random(self.prob_neutral_making_denier):
|
||||||
self.state["id"] = 3 # Vaccinated making denier
|
self.state["id"] = 3 # Vaccinated making denier
|
||||||
@ -199,7 +199,7 @@ class ControlModelM2(BaseAgent):
|
|||||||
def infected_behaviour(self):
|
def infected_behaviour(self):
|
||||||
|
|
||||||
# Neutral
|
# Neutral
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_infect):
|
if self.prob(self.prob_infect):
|
||||||
neighbor.state["id"] = 1 # Infected
|
neighbor.state["id"] = 1 # Infected
|
||||||
@ -209,13 +209,13 @@ class ControlModelM2(BaseAgent):
|
|||||||
|
|
||||||
self.state["visible"] = True
|
self.state["visible"] = True
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state["id"] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Cure
|
# Cure
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
@ -224,47 +224,47 @@ class ControlModelM2(BaseAgent):
|
|||||||
self.state["visible"] = True
|
self.state["visible"] = True
|
||||||
|
|
||||||
# Cure
|
# Cure
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_cured_healing_infected):
|
if self.prob(self.prob_cured_healing_infected):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state["id"] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
|
||||||
# Generate anti-rumor
|
# Generate anti-rumor
|
||||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
infected_neighbors_2 = self.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors_2:
|
for neighbor in infected_neighbors_2:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
def beacon_off_behaviour(self):
|
def beacon_off_behaviour(self):
|
||||||
self.state["visible"] = False
|
self.state["visible"] = False
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
if len(infected_neighbors) > 0:
|
if len(infected_neighbors) > 0:
|
||||||
self.state["id"] == 5 # Beacon on
|
self.state["id"] == 5 # Beacon on
|
||||||
|
|
||||||
def beacon_on_behaviour(self):
|
def beacon_on_behaviour(self):
|
||||||
self.state["visible"] = False
|
self.state["visible"] = False
|
||||||
# Cure (M2 feature added)
|
# Cure (M2 feature added)
|
||||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
infected_neighbors = self.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors:
|
for neighbor in infected_neighbors:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
neutral_neighbors_infected = neighbor.get_neighboring_agents(state_id=0)
|
neutral_neighbors_infected = neighbor.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors_infected:
|
for neighbor in neutral_neighbors_infected:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state["id"] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
infected_neighbors_infected = neighbor.get_neighboring_agents(state_id=1)
|
infected_neighbors_infected = neighbor.get_neighbors(state_id=1)
|
||||||
for neighbor in infected_neighbors_infected:
|
for neighbor in infected_neighbors_infected:
|
||||||
if self.prob(self.prob_generate_anti_rumor):
|
if self.prob(self.prob_generate_anti_rumor):
|
||||||
neighbor.state["id"] = 2 # Cured
|
neighbor.state["id"] = 2 # Cured
|
||||||
|
|
||||||
# Vaccinate
|
# Vaccinate
|
||||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||||
for neighbor in neutral_neighbors:
|
for neighbor in neutral_neighbors:
|
||||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||||
neighbor.state["id"] = 3 # Vaccinated
|
neighbor.state["id"] = 3 # Vaccinated
|
||||||
|
@ -69,10 +69,10 @@ class SISaModel(FSM):
|
|||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
# Infected
|
# Infected
|
||||||
discontent_neighbors = self.count_neighboring_agents(state_id=self.discontent)
|
discontent_neighbors = self.count_neighbors(state_id=self.discontent)
|
||||||
if self.prob(scontent_neighbors * self.neutral_discontent_infected_prob):
|
if self.prob(scontent_neighbors * self.neutral_discontent_infected_prob):
|
||||||
return self.discontent
|
return self.discontent
|
||||||
content_neighbors = self.count_neighboring_agents(state_id=self.content.id)
|
content_neighbors = self.count_neighbors(state_id=self.content.id)
|
||||||
if self.prob(s * self.neutral_content_infected_prob):
|
if self.prob(s * self.neutral_content_infected_prob):
|
||||||
return self.content
|
return self.content
|
||||||
return self.neutral
|
return self.neutral
|
||||||
@ -84,7 +84,7 @@ class SISaModel(FSM):
|
|||||||
return self.neutral
|
return self.neutral
|
||||||
|
|
||||||
# Superinfected
|
# Superinfected
|
||||||
content_neighbors = self.count_neighboring_agents(state_id=self.content.id)
|
content_neighbors = self.count_neighbors(state_id=self.content.id)
|
||||||
if self.prob(s * self.discontent_content):
|
if self.prob(s * self.discontent_content):
|
||||||
return self.content
|
return self.content
|
||||||
return self.discontent
|
return self.discontent
|
||||||
@ -96,9 +96,7 @@ class SISaModel(FSM):
|
|||||||
return self.neutral
|
return self.neutral
|
||||||
|
|
||||||
# Superinfected
|
# Superinfected
|
||||||
discontent_neighbors = self.count_neighboring_agents(
|
discontent_neighbors = self.count_neighbors(state_id=self.discontent.id)
|
||||||
state_id=self.discontent.id
|
|
||||||
)
|
|
||||||
if self.prob(scontent_neighbors * self.content_discontent):
|
if self.prob(scontent_neighbors * self.content_discontent):
|
||||||
self.discontent
|
self.discontent
|
||||||
return self.content
|
return self.content
|
||||||
|
@ -41,25 +41,25 @@ class SentimentCorrelationModel(BaseAgent):
|
|||||||
sad_neighbors_1_time_step = []
|
sad_neighbors_1_time_step = []
|
||||||
disgusted_neighbors_1_time_step = []
|
disgusted_neighbors_1_time_step = []
|
||||||
|
|
||||||
angry_neighbors = self.get_neighboring_agents(state_id=1)
|
angry_neighbors = self.get_neighbors(state_id=1)
|
||||||
for x in angry_neighbors:
|
for x in angry_neighbors:
|
||||||
if x.state["time_awareness"][0] > (self.env.now - 500):
|
if x.state["time_awareness"][0] > (self.env.now - 500):
|
||||||
angry_neighbors_1_time_step.append(x)
|
angry_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_angry = len(angry_neighbors_1_time_step)
|
num_neighbors_angry = len(angry_neighbors_1_time_step)
|
||||||
|
|
||||||
joyful_neighbors = self.get_neighboring_agents(state_id=2)
|
joyful_neighbors = self.get_neighbors(state_id=2)
|
||||||
for x in joyful_neighbors:
|
for x in joyful_neighbors:
|
||||||
if x.state["time_awareness"][1] > (self.env.now - 500):
|
if x.state["time_awareness"][1] > (self.env.now - 500):
|
||||||
joyful_neighbors_1_time_step.append(x)
|
joyful_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_joyful = len(joyful_neighbors_1_time_step)
|
num_neighbors_joyful = len(joyful_neighbors_1_time_step)
|
||||||
|
|
||||||
sad_neighbors = self.get_neighboring_agents(state_id=3)
|
sad_neighbors = self.get_neighbors(state_id=3)
|
||||||
for x in sad_neighbors:
|
for x in sad_neighbors:
|
||||||
if x.state["time_awareness"][2] > (self.env.now - 500):
|
if x.state["time_awareness"][2] > (self.env.now - 500):
|
||||||
sad_neighbors_1_time_step.append(x)
|
sad_neighbors_1_time_step.append(x)
|
||||||
num_neighbors_sad = len(sad_neighbors_1_time_step)
|
num_neighbors_sad = len(sad_neighbors_1_time_step)
|
||||||
|
|
||||||
disgusted_neighbors = self.get_neighboring_agents(state_id=4)
|
disgusted_neighbors = self.get_neighbors(state_id=4)
|
||||||
for x in disgusted_neighbors:
|
for x in disgusted_neighbors:
|
||||||
if x.state["time_awareness"][3] > (self.env.now - 500):
|
if x.state["time_awareness"][3] > (self.env.now - 500):
|
||||||
disgusted_neighbors_1_time_step.append(x)
|
disgusted_neighbors_1_time_step.append(x)
|
||||||
|
@ -243,223 +243,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
|||||||
return f"{self.__class__.__name__}({self.unique_id})"
|
return f"{self.__class__.__name__}({self.unique_id})"
|
||||||
|
|
||||||
|
|
||||||
class NetworkAgent(BaseAgent):
|
|
||||||
def __init__(self, *args, topology, node_id, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
assert topology is not None
|
|
||||||
assert node_id is not None
|
|
||||||
self.G = topology
|
|
||||||
assert self.G
|
|
||||||
self.node_id = node_id
|
|
||||||
|
|
||||||
def count_neighboring_agents(self, state_id=None, **kwargs):
|
|
||||||
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
|
|
||||||
|
|
||||||
def get_neighboring_agents(self, **kwargs):
|
|
||||||
return list(self.iter_agents(limit_neighbors=True, **kwargs))
|
|
||||||
|
|
||||||
def add_edge(self, other):
|
|
||||||
assert self.node_id
|
|
||||||
assert other.node_id
|
|
||||||
assert self.node_id in self.G.nodes
|
|
||||||
assert other.node_id in self.G.nodes
|
|
||||||
self.topology.add_edge(self.node_id, other.node_id)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def node(self):
|
|
||||||
return self.topology.nodes[self.node_id]
|
|
||||||
|
|
||||||
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
|
||||||
unique_ids = None
|
|
||||||
if isinstance(unique_id, list):
|
|
||||||
unique_ids = set(unique_id)
|
|
||||||
elif unique_id is not None:
|
|
||||||
unique_ids = set(
|
|
||||||
[
|
|
||||||
unique_id,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
if limit_neighbors:
|
|
||||||
neighbor_ids = set()
|
|
||||||
for node_id in self.G.neighbors(self.node_id):
|
|
||||||
if self.G.nodes[node_id].get("agent") is not None:
|
|
||||||
neighbor_ids.add(node_id)
|
|
||||||
if unique_ids:
|
|
||||||
unique_ids = unique_ids & neighbor_ids
|
|
||||||
else:
|
|
||||||
unique_ids = neighbor_ids
|
|
||||||
if not unique_ids:
|
|
||||||
return
|
|
||||||
unique_ids = list(unique_ids)
|
|
||||||
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
|
||||||
|
|
||||||
def subgraph(self, center=True, **kwargs):
|
|
||||||
include = [self] if center else []
|
|
||||||
G = self.G.subgraph(
|
|
||||||
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
|
||||||
)
|
|
||||||
return G
|
|
||||||
|
|
||||||
def remove_node(self):
|
|
||||||
print(f"Removing node for {self.unique_id}: {self.node_id}")
|
|
||||||
self.G.remove_node(self.node_id)
|
|
||||||
self.node_id = None
|
|
||||||
|
|
||||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
|
||||||
if self.node_id not in self.G.nodes(data=False):
|
|
||||||
raise ValueError(
|
|
||||||
"{} not in list of existing agents in the network".format(
|
|
||||||
self.unique_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if other.node_id not in self.G.nodes(data=False):
|
|
||||||
raise ValueError(
|
|
||||||
"{} not in list of existing agents in the network".format(other)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.G.add_edge(
|
|
||||||
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
|
||||||
)
|
|
||||||
|
|
||||||
def die(self, remove=True):
|
|
||||||
if not self.alive:
|
|
||||||
return
|
|
||||||
if remove:
|
|
||||||
self.remove_node()
|
|
||||||
return super().die()
|
|
||||||
|
|
||||||
|
|
||||||
def state(name=None):
|
|
||||||
def decorator(func, name=None):
|
|
||||||
"""
|
|
||||||
A state function should return either a state id, or a tuple (state_id, when)
|
|
||||||
The default value for state_id is the current state id.
|
|
||||||
The default value for when is the interval defined in the environment.
|
|
||||||
"""
|
|
||||||
if inspect.isgeneratorfunction(func):
|
|
||||||
orig_func = func
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def func(self):
|
|
||||||
while True:
|
|
||||||
if not self._coroutine:
|
|
||||||
self._coroutine = orig_func(self)
|
|
||||||
try:
|
|
||||||
n = next(self._coroutine)
|
|
||||||
if n:
|
|
||||||
return None, n
|
|
||||||
return
|
|
||||||
except StopIteration as ex:
|
|
||||||
self._coroutine = None
|
|
||||||
next_state = ex.value
|
|
||||||
if next_state is not None:
|
|
||||||
self._set_state(next_state)
|
|
||||||
return next_state
|
|
||||||
|
|
||||||
func.id = name or func.__name__
|
|
||||||
func.is_default = False
|
|
||||||
return func
|
|
||||||
|
|
||||||
if callable(name):
|
|
||||||
return decorator(name)
|
|
||||||
else:
|
|
||||||
return partial(decorator, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
def default_state(func):
|
|
||||||
func.is_default = True
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
class MetaFSM(MetaAgent):
|
|
||||||
def __new__(mcls, name, bases, namespace):
|
|
||||||
states = {}
|
|
||||||
# Re-use states from inherited classes
|
|
||||||
default_state = None
|
|
||||||
for i in bases:
|
|
||||||
if isinstance(i, MetaFSM):
|
|
||||||
for state_id, state in i._states.items():
|
|
||||||
if state.is_default:
|
|
||||||
default_state = state
|
|
||||||
states[state_id] = state
|
|
||||||
|
|
||||||
# Add new states
|
|
||||||
for attr, func in namespace.items():
|
|
||||||
if hasattr(func, "id"):
|
|
||||||
if func.is_default:
|
|
||||||
default_state = func
|
|
||||||
states[func.id] = func
|
|
||||||
|
|
||||||
namespace.update(
|
|
||||||
{
|
|
||||||
"_default_state": default_state,
|
|
||||||
"_states": states,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return super(MetaFSM, mcls).__new__(
|
|
||||||
mcls=mcls, name=name, bases=bases, namespace=namespace
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super(FSM, self).__init__(**kwargs)
|
|
||||||
if not hasattr(self, "state_id"):
|
|
||||||
if not self._default_state:
|
|
||||||
raise ValueError(
|
|
||||||
"No default state specified for {}".format(self.unique_id)
|
|
||||||
)
|
|
||||||
self.state_id = self._default_state.id
|
|
||||||
|
|
||||||
self._coroutine = None
|
|
||||||
self._set_state(self.state_id)
|
|
||||||
|
|
||||||
def step(self):
|
|
||||||
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
|
||||||
default_interval = super().step()
|
|
||||||
|
|
||||||
next_state = self._states[self.state_id](self)
|
|
||||||
|
|
||||||
when = None
|
|
||||||
try:
|
|
||||||
next_state, *when = next_state
|
|
||||||
if not when:
|
|
||||||
when = None
|
|
||||||
elif len(when) == 1:
|
|
||||||
when = when[0]
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Too many values returned. Only state (and time) allowed"
|
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if next_state is not None:
|
|
||||||
self._set_state(next_state)
|
|
||||||
|
|
||||||
return when or default_interval
|
|
||||||
|
|
||||||
def _set_state(self, state, when=None):
|
|
||||||
if hasattr(state, "id"):
|
|
||||||
state = state.id
|
|
||||||
if state not in self._states:
|
|
||||||
raise ValueError("{} is not a valid state".format(state))
|
|
||||||
self.state_id = state
|
|
||||||
if when is not None:
|
|
||||||
self.model.schedule.add(self, when=when)
|
|
||||||
return state
|
|
||||||
|
|
||||||
def die(self):
|
|
||||||
return self.dead, super().die()
|
|
||||||
|
|
||||||
@state
|
|
||||||
def dead(self):
|
|
||||||
return self.die()
|
|
||||||
|
|
||||||
|
|
||||||
def prob(prob, random):
|
def prob(prob, random):
|
||||||
"""
|
"""
|
||||||
A true/False uniform distribution with a given probability.
|
A true/False uniform distribution with a given probability.
|
||||||
@ -525,7 +308,7 @@ def calculate_distribution(network_agents=None, agent_class=None):
|
|||||||
return network_agents
|
return network_agents
|
||||||
|
|
||||||
|
|
||||||
def serialize_type(agent_class, known_modules=[], **kwargs):
|
def _serialize_type(agent_class, known_modules=[], **kwargs):
|
||||||
if isinstance(agent_class, str):
|
if isinstance(agent_class, str):
|
||||||
return agent_class
|
return agent_class
|
||||||
known_modules += ["soil.agents"]
|
known_modules += ["soil.agents"]
|
||||||
@ -534,20 +317,7 @@ def serialize_type(agent_class, known_modules=[], **kwargs):
|
|||||||
] # Get the name of the class
|
] # Get the name of the class
|
||||||
|
|
||||||
|
|
||||||
def serialize_definition(network_agents, known_modules=[]):
|
def _deserialize_type(agent_class, known_modules=[]):
|
||||||
"""
|
|
||||||
When serializing an agent distribution, remove the thresholds, in order
|
|
||||||
to avoid cluttering the YAML definition file.
|
|
||||||
"""
|
|
||||||
d = deepcopy(list(network_agents))
|
|
||||||
for v in d:
|
|
||||||
if "threshold" in v:
|
|
||||||
del v["threshold"]
|
|
||||||
v["agent_class"] = serialize_type(v["agent_class"], known_modules=known_modules)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def deserialize_type(agent_class, known_modules=[]):
|
|
||||||
if not isinstance(agent_class, str):
|
if not isinstance(agent_class, str):
|
||||||
return agent_class
|
return agent_class
|
||||||
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
||||||
@ -555,31 +325,6 @@ def deserialize_type(agent_class, known_modules=[]):
|
|||||||
return agent_class
|
return agent_class
|
||||||
|
|
||||||
|
|
||||||
def deserialize_definition(ind, **kwargs):
|
|
||||||
d = deepcopy(ind)
|
|
||||||
for v in d:
|
|
||||||
v["agent_class"] = deserialize_type(v["agent_class"], **kwargs)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_states(states, topology):
|
|
||||||
"""Validate states to avoid ignoring states during initialization"""
|
|
||||||
states = states or []
|
|
||||||
if isinstance(states, dict):
|
|
||||||
for x in states:
|
|
||||||
assert x in topology.nodes
|
|
||||||
else:
|
|
||||||
assert len(states) <= len(topology)
|
|
||||||
return states
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_agent_classs(ind, to_string=False, **kwargs):
|
|
||||||
"""Convenience method to allow specifying agents by class or class name."""
|
|
||||||
if to_string:
|
|
||||||
return serialize_definition(ind, **kwargs)
|
|
||||||
return deserialize_definition(ind, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentView(Mapping, Set):
|
class AgentView(Mapping, Set):
|
||||||
"""A lazy-loaded list of agents."""
|
"""A lazy-loaded list of agents."""
|
||||||
|
|
||||||
@ -663,7 +408,7 @@ def filter_agents(
|
|||||||
state_id = tuple([state_id])
|
state_id = tuple([state_id])
|
||||||
|
|
||||||
if agent_class is not None:
|
if agent_class is not None:
|
||||||
agent_class = deserialize_type(agent_class)
|
agent_class = _deserialize_type(agent_class)
|
||||||
try:
|
try:
|
||||||
agent_class = tuple(agent_class)
|
agent_class = tuple(agent_class)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
@ -703,14 +448,6 @@ def from_config(
|
|||||||
default = cfg or config.AgentConfig()
|
default = cfg or config.AgentConfig()
|
||||||
if not isinstance(cfg, config.AgentConfig):
|
if not isinstance(cfg, config.AgentConfig):
|
||||||
cfg = config.AgentConfig(**cfg)
|
cfg = config.AgentConfig(**cfg)
|
||||||
return _agents_from_config(cfg, topology=topology, random=random)
|
|
||||||
|
|
||||||
|
|
||||||
def _agents_from_config(
|
|
||||||
cfg: config.AgentConfig, topology: nx.Graph, random
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
if cfg and not isinstance(cfg, config.AgentConfig):
|
|
||||||
cfg = config.AgentConfig(**cfg)
|
|
||||||
|
|
||||||
agents = []
|
agents = []
|
||||||
|
|
||||||
@ -878,6 +615,8 @@ def _from_distro(
|
|||||||
return agents
|
return agents
|
||||||
|
|
||||||
|
|
||||||
|
from .network_agents import *
|
||||||
|
from .fsm import *
|
||||||
from .BassModel import *
|
from .BassModel import *
|
||||||
from .BigMarketModel import *
|
from .BigMarketModel import *
|
||||||
from .IndependentCascadeModel import *
|
from .IndependentCascadeModel import *
|
||||||
|
133
soil/agents/fsm.py
Normal file
133
soil/agents/fsm.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
from . import MetaAgent, BaseAgent
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
|
def state(name=None):
|
||||||
|
def decorator(func, name=None):
|
||||||
|
"""
|
||||||
|
A state function should return either a state id, or a tuple (state_id, when)
|
||||||
|
The default value for state_id is the current state id.
|
||||||
|
The default value for when is the interval defined in the environment.
|
||||||
|
"""
|
||||||
|
if inspect.isgeneratorfunction(func):
|
||||||
|
orig_func = func
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def func(self):
|
||||||
|
while True:
|
||||||
|
if not self._coroutine:
|
||||||
|
self._coroutine = orig_func(self)
|
||||||
|
try:
|
||||||
|
n = next(self._coroutine)
|
||||||
|
if n:
|
||||||
|
return None, n
|
||||||
|
return
|
||||||
|
except StopIteration as ex:
|
||||||
|
self._coroutine = None
|
||||||
|
next_state = ex.value
|
||||||
|
if next_state is not None:
|
||||||
|
self._set_state(next_state)
|
||||||
|
return next_state
|
||||||
|
|
||||||
|
func.id = name or func.__name__
|
||||||
|
func.is_default = False
|
||||||
|
return func
|
||||||
|
|
||||||
|
if callable(name):
|
||||||
|
return decorator(name)
|
||||||
|
else:
|
||||||
|
return partial(decorator, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def default_state(func):
|
||||||
|
func.is_default = True
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class MetaFSM(MetaAgent):
|
||||||
|
def __new__(mcls, name, bases, namespace):
|
||||||
|
states = {}
|
||||||
|
# Re-use states from inherited classes
|
||||||
|
default_state = None
|
||||||
|
for i in bases:
|
||||||
|
if isinstance(i, MetaFSM):
|
||||||
|
for state_id, state in i._states.items():
|
||||||
|
if state.is_default:
|
||||||
|
default_state = state
|
||||||
|
states[state_id] = state
|
||||||
|
|
||||||
|
# Add new states
|
||||||
|
for attr, func in namespace.items():
|
||||||
|
if hasattr(func, "id"):
|
||||||
|
if func.is_default:
|
||||||
|
default_state = func
|
||||||
|
states[func.id] = func
|
||||||
|
|
||||||
|
namespace.update(
|
||||||
|
{
|
||||||
|
"_default_state": default_state,
|
||||||
|
"_states": states,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return super(MetaFSM, mcls).__new__(
|
||||||
|
mcls=mcls, name=name, bases=bases, namespace=namespace
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(FSM, self).__init__(**kwargs)
|
||||||
|
if not hasattr(self, "state_id"):
|
||||||
|
if not self._default_state:
|
||||||
|
raise ValueError(
|
||||||
|
"No default state specified for {}".format(self.unique_id)
|
||||||
|
)
|
||||||
|
self.state_id = self._default_state.id
|
||||||
|
|
||||||
|
self._coroutine = None
|
||||||
|
self._set_state(self.state_id)
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
||||||
|
default_interval = super().step()
|
||||||
|
|
||||||
|
next_state = self._states[self.state_id](self)
|
||||||
|
|
||||||
|
when = None
|
||||||
|
try:
|
||||||
|
next_state, *when = next_state
|
||||||
|
if not when:
|
||||||
|
when = None
|
||||||
|
elif len(when) == 1:
|
||||||
|
when = when[0]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Too many values returned. Only state (and time) allowed"
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if next_state is not None:
|
||||||
|
self._set_state(next_state)
|
||||||
|
|
||||||
|
return when or default_interval
|
||||||
|
|
||||||
|
def _set_state(self, state, when=None):
|
||||||
|
if hasattr(state, "id"):
|
||||||
|
state = state.id
|
||||||
|
if state not in self._states:
|
||||||
|
raise ValueError("{} is not a valid state".format(state))
|
||||||
|
self.state_id = state
|
||||||
|
if when is not None:
|
||||||
|
self.model.schedule.add(self, when=when)
|
||||||
|
return state
|
||||||
|
|
||||||
|
def die(self):
|
||||||
|
return self.dead, super().die()
|
||||||
|
|
||||||
|
@state
|
||||||
|
def dead(self):
|
||||||
|
return self.die()
|
82
soil/agents/network_agents.py
Normal file
82
soil/agents/network_agents.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
from . import BaseAgent
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkAgent(BaseAgent):
|
||||||
|
def __init__(self, *args, topology, node_id, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
assert topology is not None
|
||||||
|
assert node_id is not None
|
||||||
|
self.G = topology
|
||||||
|
assert self.G
|
||||||
|
self.node_id = node_id
|
||||||
|
|
||||||
|
def count_neighbors(self, state_id=None, **kwargs):
|
||||||
|
return len(self.get_neighbors(state_id=state_id, **kwargs))
|
||||||
|
|
||||||
|
def get_neighbors(self, **kwargs):
|
||||||
|
return list(self.iter_agents(limit_neighbors=True, **kwargs))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node(self):
|
||||||
|
return self.G.nodes[self.node_id]
|
||||||
|
|
||||||
|
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
||||||
|
unique_ids = None
|
||||||
|
if isinstance(unique_id, list):
|
||||||
|
unique_ids = set(unique_id)
|
||||||
|
elif unique_id is not None:
|
||||||
|
unique_ids = set(
|
||||||
|
[
|
||||||
|
unique_id,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if limit_neighbors:
|
||||||
|
neighbor_ids = set()
|
||||||
|
for node_id in self.G.neighbors(self.node_id):
|
||||||
|
if self.G.nodes[node_id].get("agent") is not None:
|
||||||
|
neighbor_ids.add(node_id)
|
||||||
|
if unique_ids:
|
||||||
|
unique_ids = unique_ids & neighbor_ids
|
||||||
|
else:
|
||||||
|
unique_ids = neighbor_ids
|
||||||
|
if not unique_ids:
|
||||||
|
return
|
||||||
|
unique_ids = list(unique_ids)
|
||||||
|
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
||||||
|
|
||||||
|
def subgraph(self, center=True, **kwargs):
|
||||||
|
include = [self] if center else []
|
||||||
|
G = self.G.subgraph(
|
||||||
|
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
||||||
|
)
|
||||||
|
return G
|
||||||
|
|
||||||
|
def remove_node(self):
|
||||||
|
print(f"Removing node for {self.unique_id}: {self.node_id}")
|
||||||
|
self.G.remove_node(self.node_id)
|
||||||
|
self.node_id = None
|
||||||
|
|
||||||
|
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||||
|
if self.node_id not in self.G.nodes(data=False):
|
||||||
|
raise ValueError(
|
||||||
|
"{} not in list of existing agents in the network".format(
|
||||||
|
self.unique_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if other.node_id not in self.G.nodes(data=False):
|
||||||
|
raise ValueError(
|
||||||
|
"{} not in list of existing agents in the network".format(other)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.G.add_edge(
|
||||||
|
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
||||||
|
)
|
||||||
|
|
||||||
|
def die(self, remove=True):
|
||||||
|
if not self.alive:
|
||||||
|
return None
|
||||||
|
if remove:
|
||||||
|
self.remove_node()
|
||||||
|
return super().die()
|
@ -45,3 +45,17 @@ class TestMain(TestCase):
|
|||||||
for i in range(5):
|
for i in range(5):
|
||||||
t = g.step()
|
t = g.step()
|
||||||
assert t == i
|
assert t == i
|
||||||
|
|
||||||
|
def test_state_decorator(self):
|
||||||
|
class MyAgent(agents.FSM):
|
||||||
|
run = 0
|
||||||
|
@agents.default_state
|
||||||
|
@agents.state('original')
|
||||||
|
def root(self):
|
||||||
|
self.run += 1
|
||||||
|
e = environment.Environment()
|
||||||
|
a = MyAgent(model=e, unique_id=e.next_id())
|
||||||
|
a.step()
|
||||||
|
assert a.run == 1
|
||||||
|
a.step()
|
||||||
|
assert a.run == 2
|
||||||
|
@ -160,32 +160,12 @@ class TestMain(TestCase):
|
|||||||
|
|
||||||
def test_serialize_agent_class(self):
|
def test_serialize_agent_class(self):
|
||||||
"""A class from soil.agents should be serialized without the module part"""
|
"""A class from soil.agents should be serialized without the module part"""
|
||||||
ser = agents.serialize_type(CustomAgent)
|
ser = agents._serialize_type(CustomAgent)
|
||||||
assert ser == "test_main.CustomAgent"
|
assert ser == "test_main.CustomAgent"
|
||||||
ser = agents.serialize_type(agents.BaseAgent)
|
ser = agents._serialize_type(agents.BaseAgent)
|
||||||
assert ser == "BaseAgent"
|
assert ser == "BaseAgent"
|
||||||
pickle.dumps(ser)
|
pickle.dumps(ser)
|
||||||
|
|
||||||
def test_deserialize_agent_distribution(self):
|
|
||||||
agent_distro = [
|
|
||||||
{"agent_class": "CounterModel", "weight": 1},
|
|
||||||
{"agent_class": "test_main.CustomAgent", "weight": 2},
|
|
||||||
]
|
|
||||||
converted = agents.deserialize_definition(agent_distro)
|
|
||||||
assert converted[0]["agent_class"] == agents.CounterModel
|
|
||||||
assert converted[1]["agent_class"] == CustomAgent
|
|
||||||
pickle.dumps(converted)
|
|
||||||
|
|
||||||
def test_serialize_agent_distribution(self):
|
|
||||||
agent_distro = [
|
|
||||||
{"agent_class": agents.CounterModel, "weight": 1},
|
|
||||||
{"agent_class": CustomAgent, "weight": 2},
|
|
||||||
]
|
|
||||||
converted = agents.serialize_definition(agent_distro)
|
|
||||||
assert converted[0]["agent_class"] == "CounterModel"
|
|
||||||
assert converted[1]["agent_class"] == "test_main.CustomAgent"
|
|
||||||
pickle.dumps(converted)
|
|
||||||
|
|
||||||
def test_templates(self):
|
def test_templates(self):
|
||||||
"""Loading a template should result in several configs"""
|
"""Loading a template should result in several configs"""
|
||||||
configs = serialization.load_file(join(EXAMPLES, "template.yml"))
|
configs = serialization.load_file(join(EXAMPLES, "template.yml"))
|
||||||
|
Loading…
Reference in New Issue
Block a user