mirror of
https://github.com/gsi-upm/soil
synced 2025-01-06 23:01:27 +00:00
Compare commits
5 Commits
880a9f2a1c
...
2f5e5d0a74
Author | SHA1 | Date | |
---|---|---|---|
|
2f5e5d0a74 | ||
|
a2fb25c160 | ||
|
5fcf610108 | ||
|
159c9a9077 | ||
|
3776c4e5c5 |
@ -3,7 +3,7 @@ All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [0.3 UNRELEASED]
|
||||
## [0.30 UNRELEASED]
|
||||
### Added
|
||||
* Simple debugging capabilities in `soil.debugging`, with a custom `pdb.Debugger` subclass that exposes commands to list agents and their status and set breakpoints on states (for FSM agents). Try it with `soil --debug <simulation file>`
|
||||
* Ability to run
|
||||
|
7
examples/events_and_messages/README.md
Normal file
7
examples/events_and_messages/README.md
Normal file
@ -0,0 +1,7 @@
|
||||
This example can be run like with command-line options, like this:
|
||||
|
||||
```bash
|
||||
python cars.py --level DEBUG -e summary --csv
|
||||
```
|
||||
|
||||
This will set the `CSV` (save the agent and model data to a CSV) and `summary` (print the a summary of the data to stdout) exporters, and set the log level to DEBUG.
|
234
examples/events_and_messages/cars.py
Normal file
234
examples/events_and_messages/cars.py
Normal file
@ -0,0 +1,234 @@
|
||||
"""
|
||||
This is an example of a simplified city, where there are Passengers and Drivers that can take those passengers
|
||||
from their location to their desired location.
|
||||
|
||||
An example scenario could play like the following:
|
||||
|
||||
- Drivers start in the `wandering` state, where they wander around the city until they have been assigned a journey
|
||||
- Passenger(1) tells every driver that it wants to request a Journey.
|
||||
- Each driver receives the request.
|
||||
If Driver(2) is interested in providing the Journey, it asks Passenger(1) to confirm that it accepts Driver(2)'s request
|
||||
- When Passenger(1) accepts the request, two things happen:
|
||||
- Passenger(1) changes its state to `driving_home`
|
||||
- Driver(2) starts moving towards the origin of the Journey
|
||||
- Once Driver(2) reaches the origin, it starts moving itself and Passenger(1) to the destination of the Journey
|
||||
- When Driver(2) reaches the destination (carrying Passenger(1) along):
|
||||
- Driver(2) starts wondering again
|
||||
- Passenger(1) dies, and is removed from the simulation
|
||||
- If there are no more passengers available in the simulation, Drivers die
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from soil import *
|
||||
from soil import events
|
||||
from mesa.space import MultiGrid
|
||||
|
||||
|
||||
# More complex scenarios may use more than one type of message between objects.
|
||||
# A common pattern is to use `enum.Enum` to represent state changes in a request.
|
||||
@dataclass
|
||||
class Journey:
|
||||
"""
|
||||
This represents a request for a journey. Passengers and drivers exchange this object.
|
||||
|
||||
A journey may have a driver assigned or not. If the driver has not been assigned, this
|
||||
object is considered a "request for a journey".
|
||||
"""
|
||||
|
||||
origin: (int, int)
|
||||
destination: (int, int)
|
||||
tip: float
|
||||
|
||||
passenger: Passenger
|
||||
driver: Driver = None
|
||||
|
||||
|
||||
class City(EventedEnvironment):
|
||||
"""
|
||||
An environment with a grid where drivers and passengers will be placed.
|
||||
|
||||
The number of drivers and riders is configurable through its parameters:
|
||||
|
||||
:param str n_cars: The total number of drivers to add
|
||||
:param str n_passengers: The number of passengers in the simulation
|
||||
:param list agents: Specific agents to use in the simulation. It overrides the `n_passengers`
|
||||
and `n_cars` params.
|
||||
:param int height: Height of the internal grid
|
||||
:param int width: Width of the internal grid
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
n_cars=1,
|
||||
n_passengers=10,
|
||||
height=100,
|
||||
width=100,
|
||||
agents=None,
|
||||
model_reporters=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.grid = MultiGrid(width=width, height=height, torus=False)
|
||||
if agents is None:
|
||||
agents = []
|
||||
for i in range(n_cars):
|
||||
agents.append({"agent_class": Driver})
|
||||
for i in range(n_passengers):
|
||||
agents.append({"agent_class": Passenger})
|
||||
model_reporters = model_reporters or {
|
||||
"earnings": "total_earnings",
|
||||
"n_passengers": "number_passengers",
|
||||
}
|
||||
print("REPORTERS", model_reporters)
|
||||
super().__init__(
|
||||
*args, agents=agents, model_reporters=model_reporters, **kwargs
|
||||
)
|
||||
for agent in self.agents:
|
||||
self.grid.place_agent(agent, (0, 0))
|
||||
self.grid.move_to_empty(agent)
|
||||
|
||||
@property
|
||||
def total_earnings(self):
|
||||
return sum(d.earnings for d in self.agents(agent_class=Driver))
|
||||
|
||||
@property
|
||||
def number_passengers(self):
|
||||
return self.count_agents(agent_class=Passenger)
|
||||
|
||||
|
||||
class Driver(Evented, FSM):
|
||||
pos = None
|
||||
journey = None
|
||||
earnings = 0
|
||||
|
||||
def on_receive(self, msg, sender):
|
||||
"""This is not a state. It will run (and block) every time check_messages is invoked"""
|
||||
if self.journey is None and isinstance(msg, Journey) and msg.driver is None:
|
||||
msg.driver = self
|
||||
self.journey = msg
|
||||
|
||||
def check_passengers(self):
|
||||
"""If there are no more passengers, stop forever"""
|
||||
c = self.count_agents(agent_class=Passenger)
|
||||
self.info(f"Passengers left {c}")
|
||||
if not c:
|
||||
self.die()
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def wandering(self):
|
||||
"""Move around the city until a journey is accepted"""
|
||||
target = None
|
||||
self.check_passengers()
|
||||
self.journey = None
|
||||
while self.journey is None: # No potential journeys detected (see on_receive)
|
||||
if target is None or not self.move_towards(target):
|
||||
target = self.random.choice(
|
||||
self.model.grid.get_neighborhood(self.pos, moore=False)
|
||||
)
|
||||
|
||||
self.check_passengers()
|
||||
self.check_messages() # This will call on_receive behind the scenes, and the agent's status will be updated
|
||||
yield Delta(30) # Wait at least 30 seconds before checking again
|
||||
|
||||
try:
|
||||
# Re-send the journey to the passenger, to confirm that we have been selected
|
||||
self.journey = yield self.journey.passenger.ask(self.journey, timeout=60)
|
||||
except events.TimedOut:
|
||||
# No journey has been accepted. Try again
|
||||
self.journey = None
|
||||
return
|
||||
|
||||
return self.driving
|
||||
|
||||
@state
|
||||
def driving(self):
|
||||
"""The journey has been accepted. Pick them up and take them to their destination"""
|
||||
while self.move_towards(self.journey.origin):
|
||||
yield
|
||||
while self.move_towards(self.journey.destination, with_passenger=True):
|
||||
yield
|
||||
self.earnings += self.journey.tip
|
||||
self.check_passengers()
|
||||
return self.wandering
|
||||
|
||||
def move_towards(self, target, with_passenger=False):
|
||||
"""Move one cell at a time towards a target"""
|
||||
self.info(f"Moving { self.pos } -> { target }")
|
||||
if target[0] == self.pos[0] and target[1] == self.pos[1]:
|
||||
return False
|
||||
|
||||
next_pos = [self.pos[0], self.pos[1]]
|
||||
for idx in [0, 1]:
|
||||
if self.pos[idx] < target[idx]:
|
||||
next_pos[idx] += 1
|
||||
break
|
||||
if self.pos[idx] > target[idx]:
|
||||
next_pos[idx] -= 1
|
||||
break
|
||||
self.model.grid.move_agent(self, tuple(next_pos))
|
||||
if with_passenger:
|
||||
self.journey.passenger.pos = (
|
||||
self.pos
|
||||
) # This could be communicated through messages
|
||||
return True
|
||||
|
||||
|
||||
class Passenger(Evented, FSM):
|
||||
pos = None
|
||||
|
||||
def on_receive(self, msg, sender):
|
||||
"""This is not a state. It will be run synchronously every time `check_messages` is run"""
|
||||
|
||||
if isinstance(msg, Journey):
|
||||
self.journey = msg
|
||||
return msg
|
||||
|
||||
@default_state
|
||||
@state
|
||||
def asking(self):
|
||||
destination = (
|
||||
self.random.randint(0, self.model.grid.height),
|
||||
self.random.randint(0, self.model.grid.width),
|
||||
)
|
||||
self.journey = None
|
||||
journey = Journey(
|
||||
origin=self.pos,
|
||||
destination=destination,
|
||||
tip=self.random.randint(10, 100),
|
||||
passenger=self,
|
||||
)
|
||||
|
||||
timeout = 60
|
||||
expiration = self.now + timeout
|
||||
self.model.broadcast(journey, ttl=timeout, sender=self, agent_class=Driver)
|
||||
while not self.journey:
|
||||
self.info(f"Passenger at: { self.pos }. Checking for responses.")
|
||||
try:
|
||||
yield self.received(expiration=expiration)
|
||||
except events.TimedOut:
|
||||
self.info(f"Passenger at: { self.pos }. Asking for journey.")
|
||||
self.model.broadcast(
|
||||
journey, ttl=timeout, sender=self, agent_class=Driver
|
||||
)
|
||||
expiration = self.now + timeout
|
||||
self.check_messages()
|
||||
return self.driving_home
|
||||
|
||||
@state
|
||||
def driving_home(self):
|
||||
while (
|
||||
self.pos[0] != self.journey.destination[0]
|
||||
or self.pos[1] != self.journey.destination[1]
|
||||
):
|
||||
yield self.received(timeout=60)
|
||||
self.info("Got home safe!")
|
||||
self.die()
|
||||
|
||||
|
||||
simulation = Simulation(
|
||||
name="RideHailing", model_class=City, model_params={"n_passengers": 2}
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
with easy(simulation) as s:
|
||||
s.run()
|
@ -58,7 +58,7 @@ class SocialMoneyAgent(NetworkAgent, MoneyAgent):
|
||||
|
||||
def give_money(self):
|
||||
cellmates = set(self.model.grid.get_cell_list_contents([self.pos]))
|
||||
friends = set(self.get_neighboring_agents())
|
||||
friends = set(self.get_neighbors())
|
||||
self.info("Trying to give money")
|
||||
self.info("Cellmates: ", cellmates)
|
||||
self.info("Friends: ", friends)
|
||||
|
@ -8,10 +8,9 @@ class DumbViewer(FSM, NetworkAgent):
|
||||
its neighbors once it's infected.
|
||||
"""
|
||||
|
||||
defaults = {
|
||||
"prob_neighbor_spread": 0.5,
|
||||
"prob_tv_spread": 0.1,
|
||||
}
|
||||
prob_neighbor_spread = 0.5
|
||||
prob_tv_spread = 0.1
|
||||
has_been_infected = False
|
||||
|
||||
@default_state
|
||||
@state
|
||||
@ -19,10 +18,12 @@ class DumbViewer(FSM, NetworkAgent):
|
||||
if self["has_tv"]:
|
||||
if self.prob(self.model["prob_tv_spread"]):
|
||||
return self.infected
|
||||
if self.has_been_infected:
|
||||
return self.infected
|
||||
|
||||
@state
|
||||
def infected(self):
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.neutral.id):
|
||||
for neighbor in self.get_neighbors(state_id=self.neutral.id):
|
||||
if self.prob(self.model["prob_neighbor_spread"]):
|
||||
neighbor.infect()
|
||||
|
||||
@ -33,7 +34,7 @@ class DumbViewer(FSM, NetworkAgent):
|
||||
HerdViewer might not become infected right away
|
||||
"""
|
||||
|
||||
self.set_state(self.infected)
|
||||
self.has_been_infected = True
|
||||
|
||||
|
||||
class HerdViewer(DumbViewer):
|
||||
@ -43,12 +44,12 @@ class HerdViewer(DumbViewer):
|
||||
|
||||
def infect(self):
|
||||
"""Notice again that this is NOT a state. See DumbViewer.infect for reference"""
|
||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||
total = self.count_neighboring_agents()
|
||||
infected = self.count_neighbors(state_id=self.infected.id)
|
||||
total = self.count_neighbors()
|
||||
prob_infect = self.model["prob_neighbor_spread"] * infected / total
|
||||
self.debug("prob_infect", prob_infect)
|
||||
if self.prob(prob_infect):
|
||||
self.set_state(self.infected)
|
||||
self.has_been_infected = True
|
||||
|
||||
|
||||
class WiseViewer(HerdViewer):
|
||||
@ -65,7 +66,7 @@ class WiseViewer(HerdViewer):
|
||||
@state
|
||||
def cured(self):
|
||||
prob_cure = self.model["prob_neighbor_cure"]
|
||||
for neighbor in self.get_neighboring_agents(state_id=self.infected.id):
|
||||
for neighbor in self.get_neighbors(state_id=self.infected.id):
|
||||
if self.prob(prob_cure):
|
||||
try:
|
||||
neighbor.cure()
|
||||
@ -73,13 +74,14 @@ class WiseViewer(HerdViewer):
|
||||
self.debug("Viewer {} cannot be cured".format(neighbor.id))
|
||||
|
||||
def cure(self):
|
||||
self.set_state(self.cured.id)
|
||||
self.has_been_cured = True
|
||||
|
||||
@state
|
||||
def infected(self):
|
||||
cured = max(self.count_neighboring_agents(self.cured.id), 1.0)
|
||||
infected = max(self.count_neighboring_agents(self.infected.id), 1.0)
|
||||
if self.has_been_cured:
|
||||
return self.cured
|
||||
cured = max(self.count_neighbors(self.cured.id), 1.0)
|
||||
infected = max(self.count_neighbors(self.infected.id), 1.0)
|
||||
prob_cure = self.model["prob_neighbor_cure"] * (cured / infected)
|
||||
if self.prob(prob_cure):
|
||||
return self.cured
|
||||
return self.set_state(super().infected)
|
||||
|
@ -89,7 +89,7 @@ class Patron(FSM, NetworkAgent):
|
||||
if self["pub"] != None:
|
||||
return self.sober_in_pub
|
||||
self.debug("I am looking for a pub")
|
||||
group = list(self.get_neighboring_agents())
|
||||
group = list(self.get_neighbors())
|
||||
for pub in self.model.available_pubs():
|
||||
self.debug("We're trying to get into {}: total: {}".format(pub, len(group)))
|
||||
if self.model.enter(pub, self, *group):
|
||||
|
@ -49,7 +49,7 @@ class TerroristSpreadModel(FSM, Geo):
|
||||
|
||||
@state
|
||||
def civilian(self):
|
||||
neighbours = list(self.get_neighboring_agents(agent_class=TerroristSpreadModel))
|
||||
neighbours = list(self.get_neighbors(agent_class=TerroristSpreadModel))
|
||||
if len(neighbours) > 0:
|
||||
# Only interact with some of the neighbors
|
||||
interactions = list(
|
||||
@ -73,7 +73,7 @@ class TerroristSpreadModel(FSM, Geo):
|
||||
@state
|
||||
def leader(self):
|
||||
self.mean_belief = self.mean_belief ** (1 - self.terrorist_additional_influence)
|
||||
for neighbour in self.get_neighboring_agents(
|
||||
for neighbour in self.get_neighbors(
|
||||
state_id=[self.terrorist.id, self.leader.id]
|
||||
):
|
||||
if self.betweenness(neighbour) > self.betweenness(self):
|
||||
@ -158,7 +158,7 @@ class TrainingAreaModel(FSM, Geo):
|
||||
@default_state
|
||||
@state
|
||||
def terrorist(self):
|
||||
for neighbour in self.get_neighboring_agents(agent_class=TerroristSpreadModel):
|
||||
for neighbour in self.get_neighbors(agent_class=TerroristSpreadModel):
|
||||
if neighbour.vulnerability > self.min_vulnerability:
|
||||
neighbour.vulnerability = neighbour.vulnerability ** (
|
||||
1 - self.training_influence
|
||||
@ -187,7 +187,7 @@ class HavenModel(FSM, Geo):
|
||||
self.max_vulnerability = model.environment_params["max_vulnerability"]
|
||||
|
||||
def get_occupants(self, **kwargs):
|
||||
return self.get_neighboring_agents(agent_class=TerroristSpreadModel, **kwargs)
|
||||
return self.get_neighbors(agent_class=TerroristSpreadModel, **kwargs)
|
||||
|
||||
@state
|
||||
def civilian(self):
|
||||
@ -243,7 +243,7 @@ class TerroristNetworkModel(TerroristSpreadModel):
|
||||
return super().leader()
|
||||
|
||||
def update_relationships(self):
|
||||
if self.count_neighboring_agents(state_id=self.civilian.id) == 0:
|
||||
if self.count_neighbors(state_id=self.civilian.id) == 0:
|
||||
close_ups = set(
|
||||
self.geo_search(
|
||||
radius=self.vision_range, agent_class=TerroristNetworkModel
|
||||
@ -258,7 +258,7 @@ class TerroristNetworkModel(TerroristSpreadModel):
|
||||
)
|
||||
neighbours = set(
|
||||
agent.id
|
||||
for agent in self.get_neighboring_agents(
|
||||
for agent in self.get_neighbors(
|
||||
agent_class=TerroristNetworkModel
|
||||
)
|
||||
)
|
||||
|
@ -1 +1 @@
|
||||
0.20.7
|
||||
0.30.0rc2
|
@ -17,7 +17,7 @@ except NameError:
|
||||
from .agents import *
|
||||
from . import agents
|
||||
from .simulation import *
|
||||
from .environment import Environment
|
||||
from .environment import Environment, EventedEnvironment
|
||||
from . import serialization
|
||||
from .utils import logger
|
||||
from .time import *
|
||||
@ -34,6 +34,9 @@ def main(
|
||||
pdb=False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if isinstance(cfg, Simulation):
|
||||
sim = cfg
|
||||
import argparse
|
||||
from . import simulation
|
||||
|
||||
@ -44,7 +47,7 @@ def main(
|
||||
"file",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default=cfg,
|
||||
default=cfg if sim is None else '',
|
||||
help="Configuration file for the simulation (e.g., YAML or JSON)",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -150,8 +153,6 @@ def main(
|
||||
if output is None:
|
||||
output = args.output
|
||||
|
||||
logger.info("Loading config file: {}".format(args.file))
|
||||
|
||||
debug = debug or args.debug
|
||||
|
||||
if args.pdb or debug:
|
||||
@ -162,11 +163,20 @@ def main(
|
||||
try:
|
||||
exp_params = {}
|
||||
|
||||
if sim:
|
||||
logger.info("Loading simulation instance")
|
||||
sim.dry_run = args.dry_run
|
||||
sim.exporters = exporters
|
||||
sim.parallel = parallel
|
||||
sim.outdir = output
|
||||
sims = [sim, ]
|
||||
else:
|
||||
logger.info("Loading config file: {}".format(args.file))
|
||||
if not os.path.exists(args.file):
|
||||
logger.error("Please, input a valid file")
|
||||
return
|
||||
|
||||
for sim in simulation.iter_from_config(
|
||||
sims = list(simulation.iter_from_config(
|
||||
args.file,
|
||||
dry_run=args.dry_run,
|
||||
exporters=exporters,
|
||||
@ -174,7 +184,10 @@ def main(
|
||||
outdir=output,
|
||||
exporter_params=exp_params,
|
||||
**kwargs,
|
||||
):
|
||||
))
|
||||
|
||||
for sim in sims:
|
||||
|
||||
if args.set:
|
||||
for s in args.set:
|
||||
k, v = s.split("=", 1)[:2]
|
||||
@ -219,19 +232,15 @@ def main(
|
||||
|
||||
@contextmanager
|
||||
def easy(cfg, pdb=False, debug=False, **kwargs):
|
||||
ex = None
|
||||
try:
|
||||
yield main(cfg, **kwargs)[0]
|
||||
yield main(cfg, debug=debug, pdb=pdb, **kwargs)[0]
|
||||
except Exception as e:
|
||||
if os.environ.get("SOIL_POSTMORTEM"):
|
||||
from .debugging import post_mortem
|
||||
|
||||
print(traceback.format_exc())
|
||||
post_mortem()
|
||||
ex = e
|
||||
finally:
|
||||
if ex:
|
||||
raise ex
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -20,7 +20,7 @@ class BassModel(FSM):
|
||||
self.sentimentCorrelation = 1
|
||||
return self.aware
|
||||
else:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id)
|
||||
aware_neighbors = self.get_neighbors(state_id=self.aware.id)
|
||||
num_neighbors_aware = len(aware_neighbors)
|
||||
if self.prob((self["imitation_prob"] * num_neighbors_aware)):
|
||||
self.sentimentCorrelation = 1
|
||||
|
@ -24,14 +24,14 @@ class BigMarketModel(FSM):
|
||||
self.type = ""
|
||||
|
||||
if self.id < len(self.enterprises): # Enterprises
|
||||
self.set_state(self.enterprise.id)
|
||||
self._set_state(self.enterprise.id)
|
||||
self.type = "Enterprise"
|
||||
self.tweet_probability = environment.environment_params[
|
||||
"tweet_probability_enterprises"
|
||||
][self.id]
|
||||
else: # normal users
|
||||
self.type = "User"
|
||||
self.set_state(self.user.id)
|
||||
self._set_state(self.user.id)
|
||||
self.tweet_probability = environment.environment_params[
|
||||
"tweet_probability_users"
|
||||
]
|
||||
@ -49,7 +49,7 @@ class BigMarketModel(FSM):
|
||||
def enterprise(self):
|
||||
|
||||
if self.random.random() < self.tweet_probability: # Tweets
|
||||
aware_neighbors = self.get_neighboring_agents(
|
||||
aware_neighbors = self.get_neighbors(
|
||||
state_id=self.number_of_enterprises
|
||||
) # Nodes neighbour users
|
||||
for x in aware_neighbors:
|
||||
@ -96,7 +96,7 @@ class BigMarketModel(FSM):
|
||||
] = self.sentiment_about[i]
|
||||
|
||||
def userTweets(self, sentiment, enterprise):
|
||||
aware_neighbors = self.get_neighboring_agents(
|
||||
aware_neighbors = self.get_neighbors(
|
||||
state_id=self.number_of_enterprises
|
||||
) # Nodes neighbours users
|
||||
for x in aware_neighbors:
|
||||
|
@ -14,7 +14,7 @@ class CounterModel(NetworkAgent):
|
||||
def step(self):
|
||||
# Outside effects
|
||||
total = len(list(self.model.schedule._agents))
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
neighbors = len(list(self.get_neighbors()))
|
||||
self["times"] = self.get("times", 0) + 1
|
||||
self["neighbors"] = neighbors
|
||||
self["total"] = total
|
||||
@ -33,7 +33,7 @@ class AggregatedCounter(NetworkAgent):
|
||||
def step(self):
|
||||
# Outside effects
|
||||
self["times"] += 1
|
||||
neighbors = len(list(self.get_neighboring_agents()))
|
||||
neighbors = len(list(self.get_neighbors()))
|
||||
self["neighbors"] += neighbors
|
||||
total = len(list(self.model.schedule.agents))
|
||||
self["total"] += total
|
||||
|
@ -36,7 +36,7 @@ class IndependentCascadeModel(BaseAgent):
|
||||
|
||||
# Imitation effects
|
||||
if self.state["id"] == 0:
|
||||
aware_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
aware_neighbors = self.get_neighbors(state_id=1)
|
||||
for x in aware_neighbors:
|
||||
if x.state["time_awareness"] == (self.env.now - 1):
|
||||
aware_neighbors_1_time_step.append(x)
|
||||
|
@ -71,7 +71,7 @@ class SpreadModelM2(BaseAgent):
|
||||
def neutral_behaviour(self):
|
||||
|
||||
# Infected
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
if len(infected_neighbors) > 0:
|
||||
if self.prob(self.prob_neutral_making_denier):
|
||||
self.state["id"] = 3 # Vaccinated making denier
|
||||
@ -79,7 +79,7 @@ class SpreadModelM2(BaseAgent):
|
||||
def infected_behaviour(self):
|
||||
|
||||
# Neutral
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_infect):
|
||||
neighbor.state["id"] = 1 # Infected
|
||||
@ -87,13 +87,13 @@ class SpreadModelM2(BaseAgent):
|
||||
def cured_behaviour(self):
|
||||
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Cure
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
@ -101,19 +101,19 @@ class SpreadModelM2(BaseAgent):
|
||||
def vaccinated_behaviour(self):
|
||||
|
||||
# Cure
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Generate anti-rumor
|
||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors_2 = self.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors_2:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
@ -191,7 +191,7 @@ class ControlModelM2(BaseAgent):
|
||||
self.state["visible"] = False
|
||||
|
||||
# Infected
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
if len(infected_neighbors) > 0:
|
||||
if self.random(self.prob_neutral_making_denier):
|
||||
self.state["id"] = 3 # Vaccinated making denier
|
||||
@ -199,7 +199,7 @@ class ControlModelM2(BaseAgent):
|
||||
def infected_behaviour(self):
|
||||
|
||||
# Neutral
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_infect):
|
||||
neighbor.state["id"] = 1 # Infected
|
||||
@ -209,13 +209,13 @@ class ControlModelM2(BaseAgent):
|
||||
|
||||
self.state["visible"] = True
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Cure
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
@ -224,47 +224,47 @@ class ControlModelM2(BaseAgent):
|
||||
self.state["visible"] = True
|
||||
|
||||
# Cure
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_cured_healing_infected):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
||||
# Generate anti-rumor
|
||||
infected_neighbors_2 = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors_2 = self.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors_2:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
def beacon_off_behaviour(self):
|
||||
self.state["visible"] = False
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
if len(infected_neighbors) > 0:
|
||||
self.state["id"] == 5 # Beacon on
|
||||
|
||||
def beacon_on_behaviour(self):
|
||||
self.state["visible"] = False
|
||||
# Cure (M2 feature added)
|
||||
infected_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors = self.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
neutral_neighbors_infected = neighbor.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors_infected = neighbor.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors_infected:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
infected_neighbors_infected = neighbor.get_neighboring_agents(state_id=1)
|
||||
infected_neighbors_infected = neighbor.get_neighbors(state_id=1)
|
||||
for neighbor in infected_neighbors_infected:
|
||||
if self.prob(self.prob_generate_anti_rumor):
|
||||
neighbor.state["id"] = 2 # Cured
|
||||
|
||||
# Vaccinate
|
||||
neutral_neighbors = self.get_neighboring_agents(state_id=0)
|
||||
neutral_neighbors = self.get_neighbors(state_id=0)
|
||||
for neighbor in neutral_neighbors:
|
||||
if self.prob(self.prob_cured_vaccinate_neutral):
|
||||
neighbor.state["id"] = 3 # Vaccinated
|
||||
|
@ -69,10 +69,10 @@ class SISaModel(FSM):
|
||||
return self.content
|
||||
|
||||
# Infected
|
||||
discontent_neighbors = self.count_neighboring_agents(state_id=self.discontent)
|
||||
discontent_neighbors = self.count_neighbors(state_id=self.discontent)
|
||||
if self.prob(scontent_neighbors * self.neutral_discontent_infected_prob):
|
||||
return self.discontent
|
||||
content_neighbors = self.count_neighboring_agents(state_id=self.content.id)
|
||||
content_neighbors = self.count_neighbors(state_id=self.content.id)
|
||||
if self.prob(s * self.neutral_content_infected_prob):
|
||||
return self.content
|
||||
return self.neutral
|
||||
@ -84,7 +84,7 @@ class SISaModel(FSM):
|
||||
return self.neutral
|
||||
|
||||
# Superinfected
|
||||
content_neighbors = self.count_neighboring_agents(state_id=self.content.id)
|
||||
content_neighbors = self.count_neighbors(state_id=self.content.id)
|
||||
if self.prob(s * self.discontent_content):
|
||||
return self.content
|
||||
return self.discontent
|
||||
@ -96,9 +96,7 @@ class SISaModel(FSM):
|
||||
return self.neutral
|
||||
|
||||
# Superinfected
|
||||
discontent_neighbors = self.count_neighboring_agents(
|
||||
state_id=self.discontent.id
|
||||
)
|
||||
discontent_neighbors = self.count_neighbors(state_id=self.discontent.id)
|
||||
if self.prob(scontent_neighbors * self.content_discontent):
|
||||
self.discontent
|
||||
return self.content
|
||||
|
@ -41,25 +41,25 @@ class SentimentCorrelationModel(BaseAgent):
|
||||
sad_neighbors_1_time_step = []
|
||||
disgusted_neighbors_1_time_step = []
|
||||
|
||||
angry_neighbors = self.get_neighboring_agents(state_id=1)
|
||||
angry_neighbors = self.get_neighbors(state_id=1)
|
||||
for x in angry_neighbors:
|
||||
if x.state["time_awareness"][0] > (self.env.now - 500):
|
||||
angry_neighbors_1_time_step.append(x)
|
||||
num_neighbors_angry = len(angry_neighbors_1_time_step)
|
||||
|
||||
joyful_neighbors = self.get_neighboring_agents(state_id=2)
|
||||
joyful_neighbors = self.get_neighbors(state_id=2)
|
||||
for x in joyful_neighbors:
|
||||
if x.state["time_awareness"][1] > (self.env.now - 500):
|
||||
joyful_neighbors_1_time_step.append(x)
|
||||
num_neighbors_joyful = len(joyful_neighbors_1_time_step)
|
||||
|
||||
sad_neighbors = self.get_neighboring_agents(state_id=3)
|
||||
sad_neighbors = self.get_neighbors(state_id=3)
|
||||
for x in sad_neighbors:
|
||||
if x.state["time_awareness"][2] > (self.env.now - 500):
|
||||
sad_neighbors_1_time_step.append(x)
|
||||
num_neighbors_sad = len(sad_neighbors_1_time_step)
|
||||
|
||||
disgusted_neighbors = self.get_neighboring_agents(state_id=4)
|
||||
disgusted_neighbors = self.get_neighbors(state_id=4)
|
||||
for x in disgusted_neighbors:
|
||||
if x.state["time_awareness"][3] > (self.env.now - 500):
|
||||
disgusted_neighbors_1_time_step.append(x)
|
||||
|
@ -40,23 +40,31 @@ class MetaAgent(ABCMeta):
|
||||
|
||||
new_nmspc = {
|
||||
"_defaults": defaults,
|
||||
"_last_return": None,
|
||||
"_last_except": None,
|
||||
}
|
||||
|
||||
for attr, func in namespace.items():
|
||||
if attr == "step" and inspect.isgeneratorfunction(func):
|
||||
orig_func = func
|
||||
new_nmspc["_MetaAgent__coroutine"] = None
|
||||
new_nmspc["_coroutine"] = None
|
||||
|
||||
@wraps(func)
|
||||
def func(self):
|
||||
while True:
|
||||
if not self.__coroutine:
|
||||
self.__coroutine = orig_func(self)
|
||||
if not self._coroutine:
|
||||
self._coroutine = orig_func(self)
|
||||
try:
|
||||
return next(self.__coroutine)
|
||||
if self._last_except:
|
||||
return self._coroutine.throw(self._last_except)
|
||||
else:
|
||||
return self._coroutine.send(self._last_return)
|
||||
except StopIteration as ex:
|
||||
self.__coroutine = None
|
||||
self._coroutine = None
|
||||
return ex.value
|
||||
finally:
|
||||
self._last_return = None
|
||||
self._last_except = None
|
||||
|
||||
func.id = name or func.__name__
|
||||
func.is_default = False
|
||||
@ -190,6 +198,10 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
def die(self):
|
||||
self.info(f"agent dying")
|
||||
self.alive = False
|
||||
try:
|
||||
self.model.schedule.remove(self)
|
||||
except KeyError:
|
||||
pass
|
||||
return time.NEVER
|
||||
|
||||
def step(self):
|
||||
@ -243,223 +255,6 @@ class BaseAgent(MesaAgent, MutableMapping, metaclass=MetaAgent):
|
||||
return f"{self.__class__.__name__}({self.unique_id})"
|
||||
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
def __init__(self, *args, topology, node_id, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
assert topology is not None
|
||||
assert node_id is not None
|
||||
self.G = topology
|
||||
assert self.G
|
||||
self.node_id = node_id
|
||||
|
||||
def count_neighboring_agents(self, state_id=None, **kwargs):
|
||||
return len(self.get_neighboring_agents(state_id=state_id, **kwargs))
|
||||
|
||||
def get_neighboring_agents(self, **kwargs):
|
||||
return list(self.iter_agents(limit_neighbors=True, **kwargs))
|
||||
|
||||
def add_edge(self, other):
|
||||
assert self.node_id
|
||||
assert other.node_id
|
||||
assert self.node_id in self.G.nodes
|
||||
assert other.node_id in self.G.nodes
|
||||
self.topology.add_edge(self.node_id, other.node_id)
|
||||
|
||||
@property
|
||||
def node(self):
|
||||
return self.topology.nodes[self.node_id]
|
||||
|
||||
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
||||
unique_ids = None
|
||||
if isinstance(unique_id, list):
|
||||
unique_ids = set(unique_id)
|
||||
elif unique_id is not None:
|
||||
unique_ids = set(
|
||||
[
|
||||
unique_id,
|
||||
]
|
||||
)
|
||||
|
||||
if limit_neighbors:
|
||||
neighbor_ids = set()
|
||||
for node_id in self.G.neighbors(self.node_id):
|
||||
if self.G.nodes[node_id].get("agent") is not None:
|
||||
neighbor_ids.add(node_id)
|
||||
if unique_ids:
|
||||
unique_ids = unique_ids & neighbor_ids
|
||||
else:
|
||||
unique_ids = neighbor_ids
|
||||
if not unique_ids:
|
||||
return
|
||||
unique_ids = list(unique_ids)
|
||||
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
G = self.G.subgraph(
|
||||
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
||||
)
|
||||
return G
|
||||
|
||||
def remove_node(self):
|
||||
print(f"Removing node for {self.unique_id}: {self.node_id}")
|
||||
self.G.remove_node(self.node_id)
|
||||
self.node_id = None
|
||||
|
||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||
if self.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(
|
||||
self.unique_id
|
||||
)
|
||||
)
|
||||
if other.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(other)
|
||||
)
|
||||
|
||||
self.G.add_edge(
|
||||
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
||||
)
|
||||
|
||||
def die(self, remove=True):
|
||||
if not self.alive:
|
||||
return
|
||||
if remove:
|
||||
self.remove_node()
|
||||
return super().die()
|
||||
|
||||
|
||||
def state(name=None):
|
||||
def decorator(func, name=None):
|
||||
"""
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the environment.
|
||||
"""
|
||||
if inspect.isgeneratorfunction(func):
|
||||
orig_func = func
|
||||
|
||||
@wraps(func)
|
||||
def func(self):
|
||||
while True:
|
||||
if not self._coroutine:
|
||||
self._coroutine = orig_func(self)
|
||||
try:
|
||||
n = next(self._coroutine)
|
||||
if n:
|
||||
return None, n
|
||||
return
|
||||
except StopIteration as ex:
|
||||
self._coroutine = None
|
||||
next_state = ex.value
|
||||
if next_state is not None:
|
||||
self._set_state(next_state)
|
||||
return next_state
|
||||
|
||||
func.id = name or func.__name__
|
||||
func.is_default = False
|
||||
return func
|
||||
|
||||
if callable(name):
|
||||
return decorator(name)
|
||||
else:
|
||||
return partial(decorator, name=name)
|
||||
|
||||
|
||||
def default_state(func):
|
||||
func.is_default = True
|
||||
return func
|
||||
|
||||
|
||||
class MetaFSM(MetaAgent):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
states = {}
|
||||
# Re-use states from inherited classes
|
||||
default_state = None
|
||||
for i in bases:
|
||||
if isinstance(i, MetaFSM):
|
||||
for state_id, state in i._states.items():
|
||||
if state.is_default:
|
||||
default_state = state
|
||||
states[state_id] = state
|
||||
|
||||
# Add new states
|
||||
for attr, func in namespace.items():
|
||||
if hasattr(func, "id"):
|
||||
if func.is_default:
|
||||
default_state = func
|
||||
states[func.id] = func
|
||||
|
||||
namespace.update(
|
||||
{
|
||||
"_default_state": default_state,
|
||||
"_states": states,
|
||||
}
|
||||
)
|
||||
|
||||
return super(MetaFSM, mcls).__new__(
|
||||
mcls=mcls, name=name, bases=bases, namespace=namespace
|
||||
)
|
||||
|
||||
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, **kwargs):
|
||||
super(FSM, self).__init__(**kwargs)
|
||||
if not hasattr(self, "state_id"):
|
||||
if not self._default_state:
|
||||
raise ValueError(
|
||||
"No default state specified for {}".format(self.unique_id)
|
||||
)
|
||||
self.state_id = self._default_state.id
|
||||
|
||||
self._coroutine = None
|
||||
self._set_state(self.state_id)
|
||||
|
||||
def step(self):
|
||||
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
||||
default_interval = super().step()
|
||||
|
||||
next_state = self._states[self.state_id](self)
|
||||
|
||||
when = None
|
||||
try:
|
||||
next_state, *when = next_state
|
||||
if not when:
|
||||
when = None
|
||||
elif len(when) == 1:
|
||||
when = when[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Too many values returned. Only state (and time) allowed"
|
||||
)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
if next_state is not None:
|
||||
self._set_state(next_state)
|
||||
|
||||
return when or default_interval
|
||||
|
||||
def _set_state(self, state, when=None):
|
||||
if hasattr(state, "id"):
|
||||
state = state.id
|
||||
if state not in self._states:
|
||||
raise ValueError("{} is not a valid state".format(state))
|
||||
self.state_id = state
|
||||
if when is not None:
|
||||
self.model.schedule.add(self, when=when)
|
||||
return state
|
||||
|
||||
def die(self):
|
||||
return self.dead, super().die()
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
return self.die()
|
||||
|
||||
|
||||
def prob(prob, random):
|
||||
"""
|
||||
A true/False uniform distribution with a given probability.
|
||||
@ -525,7 +320,7 @@ def calculate_distribution(network_agents=None, agent_class=None):
|
||||
return network_agents
|
||||
|
||||
|
||||
def serialize_type(agent_class, known_modules=[], **kwargs):
|
||||
def _serialize_type(agent_class, known_modules=[], **kwargs):
|
||||
if isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known_modules += ["soil.agents"]
|
||||
@ -534,20 +329,7 @@ def serialize_type(agent_class, known_modules=[], **kwargs):
|
||||
] # Get the name of the class
|
||||
|
||||
|
||||
def serialize_definition(network_agents, known_modules=[]):
|
||||
"""
|
||||
When serializing an agent distribution, remove the thresholds, in order
|
||||
to avoid cluttering the YAML definition file.
|
||||
"""
|
||||
d = deepcopy(list(network_agents))
|
||||
for v in d:
|
||||
if "threshold" in v:
|
||||
del v["threshold"]
|
||||
v["agent_class"] = serialize_type(v["agent_class"], known_modules=known_modules)
|
||||
return d
|
||||
|
||||
|
||||
def deserialize_type(agent_class, known_modules=[]):
|
||||
def _deserialize_type(agent_class, known_modules=[]):
|
||||
if not isinstance(agent_class, str):
|
||||
return agent_class
|
||||
known = known_modules + ["soil.agents", "soil.agents.custom"]
|
||||
@ -555,31 +337,6 @@ def deserialize_type(agent_class, known_modules=[]):
|
||||
return agent_class
|
||||
|
||||
|
||||
def deserialize_definition(ind, **kwargs):
|
||||
d = deepcopy(ind)
|
||||
for v in d:
|
||||
v["agent_class"] = deserialize_type(v["agent_class"], **kwargs)
|
||||
return d
|
||||
|
||||
|
||||
def _validate_states(states, topology):
|
||||
"""Validate states to avoid ignoring states during initialization"""
|
||||
states = states or []
|
||||
if isinstance(states, dict):
|
||||
for x in states:
|
||||
assert x in topology.nodes
|
||||
else:
|
||||
assert len(states) <= len(topology)
|
||||
return states
|
||||
|
||||
|
||||
def _convert_agent_classs(ind, to_string=False, **kwargs):
|
||||
"""Convenience method to allow specifying agents by class or class name."""
|
||||
if to_string:
|
||||
return serialize_definition(ind, **kwargs)
|
||||
return deserialize_definition(ind, **kwargs)
|
||||
|
||||
|
||||
class AgentView(Mapping, Set):
|
||||
"""A lazy-loaded list of agents."""
|
||||
|
||||
@ -663,7 +420,7 @@ def filter_agents(
|
||||
state_id = tuple([state_id])
|
||||
|
||||
if agent_class is not None:
|
||||
agent_class = deserialize_type(agent_class)
|
||||
agent_class = _deserialize_type(agent_class)
|
||||
try:
|
||||
agent_class = tuple(agent_class)
|
||||
except TypeError:
|
||||
@ -703,14 +460,6 @@ def from_config(
|
||||
default = cfg or config.AgentConfig()
|
||||
if not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
return _agents_from_config(cfg, topology=topology, random=random)
|
||||
|
||||
|
||||
def _agents_from_config(
|
||||
cfg: config.AgentConfig, topology: nx.Graph, random
|
||||
) -> List[Dict[str, Any]]:
|
||||
if cfg and not isinstance(cfg, config.AgentConfig):
|
||||
cfg = config.AgentConfig(**cfg)
|
||||
|
||||
agents = []
|
||||
|
||||
@ -878,6 +627,9 @@ def _from_distro(
|
||||
return agents
|
||||
|
||||
|
||||
from .network_agents import *
|
||||
from .fsm import *
|
||||
from .evented import *
|
||||
from .BassModel import *
|
||||
from .BigMarketModel import *
|
||||
from .IndependentCascadeModel import *
|
||||
|
57
soil/agents/evented.py
Normal file
57
soil/agents/evented.py
Normal file
@ -0,0 +1,57 @@
|
||||
from . import BaseAgent
|
||||
from ..events import Message, Tell, Ask, Reply, TimedOut
|
||||
from ..time import Cond
|
||||
from functools import partial
|
||||
from collections import deque
|
||||
|
||||
|
||||
class Evented(BaseAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._inbox = deque()
|
||||
self._received = 0
|
||||
self._processed = 0
|
||||
|
||||
|
||||
def on_receive(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def received(self, expiration=None, timeout=None):
|
||||
current = self._received
|
||||
if expiration is None:
|
||||
expiration = float('inf') if timeout is None else self.now + timeout
|
||||
|
||||
if expiration < self.now:
|
||||
raise ValueError("Invalid expiration time")
|
||||
|
||||
def ready(agent):
|
||||
return agent._received > current or agent.now >= expiration
|
||||
|
||||
def value(agent):
|
||||
if agent.now > expiration:
|
||||
raise TimedOut("No message received")
|
||||
|
||||
c = Cond(func=ready, return_func=value)
|
||||
c._checked = True
|
||||
return c
|
||||
|
||||
def tell(self, msg, sender):
|
||||
self._received += 1
|
||||
self._inbox.append(Tell(payload=msg, sender=sender))
|
||||
|
||||
def ask(self, msg, timeout=None):
|
||||
self._received += 1
|
||||
ask = Ask(payload=msg)
|
||||
self._inbox.append(ask)
|
||||
expiration = float('inf') if timeout is None else self.now + timeout
|
||||
return ask.replied(expiration=expiration)
|
||||
|
||||
def check_messages(self):
|
||||
while self._inbox:
|
||||
msg = self._inbox.popleft()
|
||||
self._processed += 1
|
||||
if msg.expired(self.now):
|
||||
continue
|
||||
reply = self.on_receive(msg.payload, sender=msg.sender)
|
||||
if isinstance(msg, Ask):
|
||||
msg.reply = reply
|
142
soil/agents/fsm.py
Normal file
142
soil/agents/fsm.py
Normal file
@ -0,0 +1,142 @@
|
||||
from . import MetaAgent, BaseAgent
|
||||
|
||||
from functools import partial, wraps
|
||||
import inspect
|
||||
|
||||
|
||||
def state(name=None):
|
||||
def decorator(func, name=None):
|
||||
"""
|
||||
A state function should return either a state id, or a tuple (state_id, when)
|
||||
The default value for state_id is the current state id.
|
||||
The default value for when is the interval defined in the environment.
|
||||
"""
|
||||
if inspect.isgeneratorfunction(func):
|
||||
orig_func = func
|
||||
|
||||
@wraps(func)
|
||||
def func(self):
|
||||
while True:
|
||||
if not self._coroutine:
|
||||
self._coroutine = orig_func(self)
|
||||
|
||||
try:
|
||||
if self._last_except:
|
||||
n = self._coroutine.throw(self._last_except)
|
||||
else:
|
||||
n = self._coroutine.send(self._last_return)
|
||||
if n:
|
||||
return None, n
|
||||
return n
|
||||
except StopIteration as ex:
|
||||
self._coroutine = None
|
||||
next_state = ex.value
|
||||
if next_state is not None:
|
||||
self._set_state(next_state)
|
||||
return next_state
|
||||
finally:
|
||||
self._last_return = None
|
||||
self._last_except = None
|
||||
|
||||
|
||||
|
||||
func.id = name or func.__name__
|
||||
func.is_default = False
|
||||
return func
|
||||
|
||||
if callable(name):
|
||||
return decorator(name)
|
||||
else:
|
||||
return partial(decorator, name=name)
|
||||
|
||||
|
||||
def default_state(func):
|
||||
func.is_default = True
|
||||
return func
|
||||
|
||||
|
||||
class MetaFSM(MetaAgent):
|
||||
def __new__(mcls, name, bases, namespace):
|
||||
states = {}
|
||||
# Re-use states from inherited classes
|
||||
default_state = None
|
||||
for i in bases:
|
||||
if isinstance(i, MetaFSM):
|
||||
for state_id, state in i._states.items():
|
||||
if state.is_default:
|
||||
default_state = state
|
||||
states[state_id] = state
|
||||
|
||||
# Add new states
|
||||
for attr, func in namespace.items():
|
||||
if hasattr(func, "id"):
|
||||
if func.is_default:
|
||||
default_state = func
|
||||
states[func.id] = func
|
||||
|
||||
namespace.update(
|
||||
{
|
||||
"_default_state": default_state,
|
||||
"_states": states,
|
||||
}
|
||||
)
|
||||
|
||||
return super(MetaFSM, mcls).__new__(
|
||||
mcls=mcls, name=name, bases=bases, namespace=namespace
|
||||
)
|
||||
|
||||
|
||||
class FSM(BaseAgent, metaclass=MetaFSM):
|
||||
def __init__(self, **kwargs):
|
||||
super(FSM, self).__init__(**kwargs)
|
||||
if not hasattr(self, "state_id"):
|
||||
if not self._default_state:
|
||||
raise ValueError(
|
||||
"No default state specified for {}".format(self.unique_id)
|
||||
)
|
||||
self.state_id = self._default_state.id
|
||||
|
||||
self._coroutine = None
|
||||
self._set_state(self.state_id)
|
||||
|
||||
def step(self):
|
||||
self.debug(f"Agent {self.unique_id} @ state {self.state_id}")
|
||||
default_interval = super().step()
|
||||
|
||||
next_state = self._states[self.state_id](self)
|
||||
|
||||
when = None
|
||||
try:
|
||||
next_state, *when = next_state
|
||||
if not when:
|
||||
when = None
|
||||
elif len(when) == 1:
|
||||
when = when[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Too many values returned. Only state (and time) allowed"
|
||||
)
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
if next_state is not None:
|
||||
self._set_state(next_state)
|
||||
|
||||
return when or default_interval
|
||||
|
||||
def _set_state(self, state, when=None):
|
||||
if hasattr(state, "id"):
|
||||
state = state.id
|
||||
if state not in self._states:
|
||||
raise ValueError("{} is not a valid state".format(state))
|
||||
self.state_id = state
|
||||
if when is not None:
|
||||
self.model.schedule.add(self, when=when)
|
||||
return state
|
||||
|
||||
def die(self):
|
||||
return self.dead, super().die()
|
||||
|
||||
@state
|
||||
def dead(self):
|
||||
return self.die()
|
82
soil/agents/network_agents.py
Normal file
82
soil/agents/network_agents.py
Normal file
@ -0,0 +1,82 @@
|
||||
from . import BaseAgent
|
||||
|
||||
|
||||
class NetworkAgent(BaseAgent):
|
||||
def __init__(self, *args, topology, node_id, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
assert topology is not None
|
||||
assert node_id is not None
|
||||
self.G = topology
|
||||
assert self.G
|
||||
self.node_id = node_id
|
||||
|
||||
def count_neighbors(self, state_id=None, **kwargs):
|
||||
return len(self.get_neighbors(state_id=state_id, **kwargs))
|
||||
|
||||
def get_neighbors(self, **kwargs):
|
||||
return list(self.iter_agents(limit_neighbors=True, **kwargs))
|
||||
|
||||
@property
|
||||
def node(self):
|
||||
return self.G.nodes[self.node_id]
|
||||
|
||||
def iter_agents(self, unique_id=None, *, limit_neighbors=False, **kwargs):
|
||||
unique_ids = None
|
||||
if isinstance(unique_id, list):
|
||||
unique_ids = set(unique_id)
|
||||
elif unique_id is not None:
|
||||
unique_ids = set(
|
||||
[
|
||||
unique_id,
|
||||
]
|
||||
)
|
||||
|
||||
if limit_neighbors:
|
||||
neighbor_ids = set()
|
||||
for node_id in self.G.neighbors(self.node_id):
|
||||
if self.G.nodes[node_id].get("agent") is not None:
|
||||
neighbor_ids.add(node_id)
|
||||
if unique_ids:
|
||||
unique_ids = unique_ids & neighbor_ids
|
||||
else:
|
||||
unique_ids = neighbor_ids
|
||||
if not unique_ids:
|
||||
return
|
||||
unique_ids = list(unique_ids)
|
||||
yield from super().iter_agents(unique_id=unique_ids, **kwargs)
|
||||
|
||||
def subgraph(self, center=True, **kwargs):
|
||||
include = [self] if center else []
|
||||
G = self.G.subgraph(
|
||||
n.node_id for n in list(self.get_agents(**kwargs) + include)
|
||||
)
|
||||
return G
|
||||
|
||||
def remove_node(self):
|
||||
print(f"Removing node for {self.unique_id}: {self.node_id}")
|
||||
self.G.remove_node(self.node_id)
|
||||
self.node_id = None
|
||||
|
||||
def add_edge(self, other, edge_attr_dict=None, *edge_attrs):
|
||||
if self.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(
|
||||
self.unique_id
|
||||
)
|
||||
)
|
||||
if other.node_id not in self.G.nodes(data=False):
|
||||
raise ValueError(
|
||||
"{} not in list of existing agents in the network".format(other)
|
||||
)
|
||||
|
||||
self.G.add_edge(
|
||||
self.node_id, other.node_id, edge_attr_dict=edge_attr_dict, *edge_attrs
|
||||
)
|
||||
|
||||
def die(self, remove=True):
|
||||
if not self.alive:
|
||||
return None
|
||||
if remove:
|
||||
self.remove_node()
|
||||
return super().die()
|
@ -30,9 +30,9 @@ def wrapcmd(func):
|
||||
class Debug(pdb.Pdb):
|
||||
def __init__(self, *args, skip_soil=False, **kwargs):
|
||||
skip = kwargs.get("skip", [])
|
||||
if skip_soil:
|
||||
skip.append("soil")
|
||||
skip.append("contextlib")
|
||||
if skip_soil:
|
||||
skip.append("soil.*")
|
||||
skip.append("mesa.*")
|
||||
super(Debug, self).__init__(*args, skip=skip, **kwargs)
|
||||
|
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import os
|
||||
import sqlite3
|
||||
import math
|
||||
import random
|
||||
import logging
|
||||
import inspect
|
||||
|
||||
@ -19,7 +18,7 @@ import networkx as nx
|
||||
from mesa import Model
|
||||
from mesa.datacollection import DataCollector
|
||||
|
||||
from . import agents as agentmod, config, serialization, utils, time, network
|
||||
from . import agents as agentmod, config, serialization, utils, time, network, events
|
||||
|
||||
|
||||
class BaseEnvironment(Model):
|
||||
@ -294,10 +293,6 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
def add_agent(self, *args, **kwargs):
|
||||
a = super().add_agent(*args, **kwargs)
|
||||
if "node_id" in a:
|
||||
if a.node_id == 24:
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
assert self.G.nodes[a.node_id]["agent"] == a
|
||||
return a
|
||||
|
||||
@ -316,3 +311,14 @@ class NetworkEnvironment(BaseEnvironment):
|
||||
|
||||
|
||||
Environment = NetworkEnvironment
|
||||
|
||||
|
||||
class EventedEnvironment(Environment):
|
||||
def broadcast(self, msg, sender, expiration=None, ttl=None, **kwargs):
|
||||
for agent in self.agents(**kwargs):
|
||||
self.logger.info(f'Telling {repr(agent)}: {msg} ttl={ttl}')
|
||||
try:
|
||||
agent._inbox.append(events.Tell(payload=msg, sender=sender, expiration=expiration if ttl is None else self.now+ttl))
|
||||
except AttributeError:
|
||||
self.info(f'Agent {agent.unique_id} cannot receive events')
|
||||
|
||||
|
43
soil/events.py
Normal file
43
soil/events.py
Normal file
@ -0,0 +1,43 @@
|
||||
from .time import Cond
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
class Event:
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
payload: Any
|
||||
sender: Any = None
|
||||
expiration: float = None
|
||||
id: int = field(default_factory=uuid4)
|
||||
|
||||
def expired(self, when):
|
||||
return self.expiration is not None and self.expiration < when
|
||||
|
||||
class Reply(Message):
|
||||
source: Message
|
||||
|
||||
|
||||
class Ask(Message):
|
||||
reply: Message = None
|
||||
|
||||
def replied(self, expiration=None):
|
||||
def ready(agent):
|
||||
return self.reply is not None or agent.now > expiration
|
||||
|
||||
def value(agent):
|
||||
if agent.now > expiration:
|
||||
raise TimedOut(f'No answer received for {self}')
|
||||
return self.reply
|
||||
|
||||
return Cond(func=ready, return_func=value)
|
||||
|
||||
|
||||
class Tell(Message):
|
||||
pass
|
||||
|
||||
|
||||
class TimedOut(Exception):
|
||||
pass
|
@ -47,7 +47,7 @@ class Simulation:
|
||||
max_time: float = float("inf")
|
||||
max_steps: int = -1
|
||||
interval: int = 1
|
||||
num_trials: int = 3
|
||||
num_trials: int = 1
|
||||
parallel: Optional[bool] = None
|
||||
exporters: Optional[List[str]] = field(default_factory=list)
|
||||
outdir: Optional[str] = None
|
||||
|
20
soil/time.py
20
soil/time.py
@ -45,12 +45,16 @@ class When:
|
||||
def ready(self, agent):
|
||||
return self._time <= agent.model.schedule.time
|
||||
|
||||
def return_value(self, agent):
|
||||
return None
|
||||
|
||||
|
||||
class Cond(When):
|
||||
def __init__(self, func, delta=1):
|
||||
def __init__(self, func, delta=1, return_func=lambda agent: None):
|
||||
self._func = func
|
||||
self._delta = delta
|
||||
self._checked = False
|
||||
self._return_func = return_func
|
||||
|
||||
def next(self, time):
|
||||
if self._checked:
|
||||
@ -64,6 +68,9 @@ class Cond(When):
|
||||
self._checked = True
|
||||
return self._func(agent)
|
||||
|
||||
def return_value(self, agent):
|
||||
return self._return_func(agent)
|
||||
|
||||
def __eq__(self, other):
|
||||
return False
|
||||
|
||||
@ -144,14 +151,21 @@ class TimedActivation(BaseScheduler):
|
||||
|
||||
ix = 0
|
||||
|
||||
self.logger.debug(f"Queue length: {len(self._queue)}")
|
||||
|
||||
while self._queue:
|
||||
(when, agent) = self._queue[0]
|
||||
if when > self.time:
|
||||
break
|
||||
heappop(self._queue)
|
||||
if when.ready(agent):
|
||||
to_process.append(agent)
|
||||
try:
|
||||
agent._last_return = when.return_value(agent)
|
||||
except Exception as ex:
|
||||
agent._last_except = ex
|
||||
|
||||
self._next.pop(agent.unique_id, None)
|
||||
to_process.append(agent)
|
||||
continue
|
||||
|
||||
next_time = min(next_time, when.next(self.time))
|
||||
@ -175,10 +189,10 @@ class TimedActivation(BaseScheduler):
|
||||
continue
|
||||
|
||||
if not getattr(agent, "alive", True):
|
||||
self.remove(agent)
|
||||
continue
|
||||
|
||||
value = returned.next(self.time)
|
||||
agent._last_return = value
|
||||
|
||||
if value < self.time:
|
||||
raise Exception(
|
||||
|
@ -33,15 +33,37 @@ class TestMain(TestCase):
|
||||
The step function of an agent could be a generator. In that case, the state of the
|
||||
agent will be resumed after every call to step.
|
||||
'''
|
||||
a = 0
|
||||
class Gen(agents.BaseAgent):
|
||||
def step(self):
|
||||
a = 0
|
||||
nonlocal a
|
||||
for i in range(5):
|
||||
yield a
|
||||
yield
|
||||
a += 1
|
||||
e = environment.Environment()
|
||||
g = Gen(model=e, unique_id=e.next_id())
|
||||
e.schedule.add(g)
|
||||
|
||||
for i in range(5):
|
||||
t = g.step()
|
||||
assert t == i
|
||||
e.step()
|
||||
assert a == i
|
||||
|
||||
def test_state_decorator(self):
|
||||
class MyAgent(agents.FSM):
|
||||
run = 0
|
||||
@agents.default_state
|
||||
@agents.state('original')
|
||||
def root(self):
|
||||
self.run += 1
|
||||
return self.other
|
||||
|
||||
@agents.state
|
||||
def other(self):
|
||||
self.run += 1
|
||||
|
||||
e = environment.Environment()
|
||||
a = MyAgent(model=e, unique_id=e.next_id())
|
||||
a.step()
|
||||
assert a.run == 1
|
||||
a.step()
|
||||
assert a.run == 2
|
||||
|
@ -160,32 +160,12 @@ class TestMain(TestCase):
|
||||
|
||||
def test_serialize_agent_class(self):
|
||||
"""A class from soil.agents should be serialized without the module part"""
|
||||
ser = agents.serialize_type(CustomAgent)
|
||||
ser = agents._serialize_type(CustomAgent)
|
||||
assert ser == "test_main.CustomAgent"
|
||||
ser = agents.serialize_type(agents.BaseAgent)
|
||||
ser = agents._serialize_type(agents.BaseAgent)
|
||||
assert ser == "BaseAgent"
|
||||
pickle.dumps(ser)
|
||||
|
||||
def test_deserialize_agent_distribution(self):
|
||||
agent_distro = [
|
||||
{"agent_class": "CounterModel", "weight": 1},
|
||||
{"agent_class": "test_main.CustomAgent", "weight": 2},
|
||||
]
|
||||
converted = agents.deserialize_definition(agent_distro)
|
||||
assert converted[0]["agent_class"] == agents.CounterModel
|
||||
assert converted[1]["agent_class"] == CustomAgent
|
||||
pickle.dumps(converted)
|
||||
|
||||
def test_serialize_agent_distribution(self):
|
||||
agent_distro = [
|
||||
{"agent_class": agents.CounterModel, "weight": 1},
|
||||
{"agent_class": CustomAgent, "weight": 2},
|
||||
]
|
||||
converted = agents.serialize_definition(agent_distro)
|
||||
assert converted[0]["agent_class"] == "CounterModel"
|
||||
assert converted[1]["agent_class"] == "test_main.CustomAgent"
|
||||
pickle.dumps(converted)
|
||||
|
||||
def test_templates(self):
|
||||
"""Loading a template should result in several configs"""
|
||||
configs = serialization.load_file(join(EXAMPLES, "template.yml"))
|
||||
|
Loading…
Reference in New Issue
Block a user