mirror of
https://github.com/gsi-upm/soil
synced 2025-09-14 12:12:21 +00:00
Compare commits
8 Commits
0.20.2.pos
...
0.20.7
Author | SHA1 | Date | |
---|---|---|---|
|
a40aa55b6a | ||
|
50cba751a6 | ||
|
dfb6d13649 | ||
|
5559d37e57 | ||
|
2116fe6f38 | ||
|
affeeb9643 | ||
|
42ddc02318 | ||
|
cab9a3440b |
@@ -37,7 +37,7 @@ push_pypi:
|
|||||||
- echo $CI_COMMIT_TAG > soil/VERSION
|
- echo $CI_COMMIT_TAG > soil/VERSION
|
||||||
- pip install twine
|
- pip install twine
|
||||||
- python setup.py sdist bdist_wheel
|
- python setup.py sdist bdist_wheel
|
||||||
- TWINE_PASSWORD=${PYPI_PASSWORD} TWINE_USERNAME={PYPI_USERNAME} python -m twine upload dist/*
|
- TWINE_PASSWORD=$PYPI_PASSWORD TWINE_USERNAME=$PYPI_USERNAME python -m twine upload dist/*
|
||||||
|
|
||||||
check_pypi:
|
check_pypi:
|
||||||
only:
|
only:
|
||||||
@@ -48,3 +48,6 @@ check_pypi:
|
|||||||
stage: check_published
|
stage: check_published
|
||||||
script:
|
script:
|
||||||
- pip install soil==$CI_COMMIT_TAG
|
- pip install soil==$CI_COMMIT_TAG
|
||||||
|
# Allow PYPI to update its index before we try to install
|
||||||
|
when: delayed
|
||||||
|
start_in: 2 minutes
|
||||||
|
33
CHANGELOG.md
33
CHANGELOG.md
@@ -3,6 +3,39 @@ All notable changes to this project will be documented in this file.
|
|||||||
|
|
||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [UNRELEASED]
|
||||||
|
## [0.20.7]
|
||||||
|
### Changed
|
||||||
|
* Creating a `time.When` from another `time.When` does not nest them anymore (it returns the argument)
|
||||||
|
### Fixed
|
||||||
|
* Bug with time.NEVER/time.INFINITY
|
||||||
|
## [0.20.6]
|
||||||
|
### Fixed
|
||||||
|
* Agents now return `time.INFINITY` when dead, instead of 'inf'
|
||||||
|
* `soil.__init__` does not re-export built-in time (change in `soil.simulation`. It used to create subtle import conflicts when importing soil.time.
|
||||||
|
* Parallel simulations were broken because lambdas cannot be pickled properly, which is needed for multiprocessing.
|
||||||
|
### Changed
|
||||||
|
* Some internal simulation methods do not accept `*args` anymore, to avoid ambiguity and bugs.
|
||||||
|
## [0.20.5]
|
||||||
|
### Changed
|
||||||
|
* Defaults are now set in the agent __init__, not in the environment. This decouples both classes a bit more, and it is more intuitive
|
||||||
|
## [0.20.4]
|
||||||
|
### Added
|
||||||
|
* Agents can now be given any kwargs, which will be used to set their state
|
||||||
|
* Environments have a default logger `self.logger` and a log method, just like agents
|
||||||
|
## [0.20.3]
|
||||||
|
### Fixed
|
||||||
|
* Default state values are now deepcopied again.
|
||||||
|
* Seeds for environments only concatenate the trial id (i.e., a number), to provide repeatable results.
|
||||||
|
* `Environment.run` now calls `Environment.step`, to allow for easy overloading of the environment step
|
||||||
|
### Removed
|
||||||
|
* Datacollectors are not being used for now.
|
||||||
|
* `time.TimedActivation.step` does not use an `until` parameter anymore.
|
||||||
|
### Changed
|
||||||
|
* Simulations now run right up to `until` (open interval)
|
||||||
|
* Time instants (`time.When`) don't need to be floats anymore. Now we can avoid precision issues with big numbers by using ints.
|
||||||
|
* Rabbits simulation is more idiomatic (using subclasses)
|
||||||
|
|
||||||
## [0.20.2]
|
## [0.20.2]
|
||||||
### Fixed
|
### Fixed
|
||||||
* CI/CD testing issues
|
* CI/CD testing issues
|
||||||
|
@@ -1 +1 @@
|
|||||||
ipython==7.23
|
ipython==7.31.1
|
||||||
|
@@ -17,7 +17,7 @@ class DumbViewer(FSM):
|
|||||||
def neutral(self):
|
def neutral(self):
|
||||||
if self['has_tv']:
|
if self['has_tv']:
|
||||||
if prob(self.env['prob_tv_spread']):
|
if prob(self.env['prob_tv_spread']):
|
||||||
self.set_state(self.infected)
|
return self.infected
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def infected(self):
|
def infected(self):
|
||||||
@@ -26,6 +26,12 @@ class DumbViewer(FSM):
|
|||||||
neighbor.infect()
|
neighbor.infect()
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
|
'''
|
||||||
|
This is not a state. It is a function that other agents can use to try to
|
||||||
|
infect this agent. DumbViewer always gets infected, but other agents like
|
||||||
|
HerdViewer might not become infected right away
|
||||||
|
'''
|
||||||
|
|
||||||
self.set_state(self.infected)
|
self.set_state(self.infected)
|
||||||
|
|
||||||
|
|
||||||
@@ -35,12 +41,13 @@ class HerdViewer(DumbViewer):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
def infect(self):
|
def infect(self):
|
||||||
|
'''Notice again that this is NOT a state. See DumbViewer.infect for reference'''
|
||||||
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
infected = self.count_neighboring_agents(state_id=self.infected.id)
|
||||||
total = self.count_neighboring_agents()
|
total = self.count_neighboring_agents()
|
||||||
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
prob_infect = self.env['prob_neighbor_spread'] * infected/total
|
||||||
self.debug('prob_infect', prob_infect)
|
self.debug('prob_infect', prob_infect)
|
||||||
if prob(prob_infect):
|
if prob(prob_infect):
|
||||||
self.set_state(self.infected.id)
|
self.set_state(self.infected)
|
||||||
|
|
||||||
|
|
||||||
class WiseViewer(HerdViewer):
|
class WiseViewer(HerdViewer):
|
||||||
@@ -75,5 +82,5 @@ class WiseViewer(HerdViewer):
|
|||||||
1.0)
|
1.0)
|
||||||
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
prob_cure = self.env['prob_neighbor_cure'] * (cured/infected)
|
||||||
if prob(prob_cure):
|
if prob(prob_cure):
|
||||||
return self.cure()
|
return self.cured
|
||||||
return self.set_state(super().infected)
|
return self.set_state(super().infected)
|
||||||
|
@@ -18,7 +18,9 @@ class MyAgent(agents.FSM):
|
|||||||
@agents.default_state
|
@agents.default_state
|
||||||
@agents.state
|
@agents.state
|
||||||
def neutral(self):
|
def neutral(self):
|
||||||
self.info('I am running')
|
self.debug('I am running')
|
||||||
|
if agents.prob(0.2):
|
||||||
|
self.info('This runs 2/10 times on average')
|
||||||
|
|
||||||
|
|
||||||
s = Simulation(name='Programmatic',
|
s = Simulation(name='Programmatic',
|
||||||
@@ -29,10 +31,10 @@ s = Simulation(name='Programmatic',
|
|||||||
dry_run=True)
|
dry_run=True)
|
||||||
|
|
||||||
|
|
||||||
|
# By default, logging will only print WARNING logs (and above).
|
||||||
|
# You need to choose a lower logging level to get INFO/DEBUG traces
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
envs = s.run()
|
envs = s.run()
|
||||||
|
|
||||||
s.dump_yaml()
|
# Uncomment this to output the simulation to a YAML file
|
||||||
|
# s.dump_yaml('simulation.yaml')
|
||||||
for env in envs:
|
|
||||||
env.dump_csv()
|
|
||||||
|
@@ -12,8 +12,6 @@ class Genders(Enum):
|
|||||||
|
|
||||||
class RabbitModel(FSM):
|
class RabbitModel(FSM):
|
||||||
|
|
||||||
level = logging.INFO
|
|
||||||
|
|
||||||
defaults = {
|
defaults = {
|
||||||
'age': 0,
|
'age': 0,
|
||||||
'gender': Genders.male.value,
|
'gender': Genders.male.value,
|
||||||
@@ -36,6 +34,17 @@ class RabbitModel(FSM):
|
|||||||
if self['age'] >= self.sexual_maturity:
|
if self['age'] >= self.sexual_maturity:
|
||||||
self.debug('I am fertile!')
|
self.debug('I am fertile!')
|
||||||
return self.fertile
|
return self.fertile
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
raise Exception("Each subclass should define its fertile state")
|
||||||
|
|
||||||
|
@state
|
||||||
|
def dead(self):
|
||||||
|
self.info('Agent {} is dying'.format(self.id))
|
||||||
|
self.die()
|
||||||
|
|
||||||
|
|
||||||
|
class Male(RabbitModel):
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def fertile(self):
|
def fertile(self):
|
||||||
@@ -47,20 +56,26 @@ class RabbitModel(FSM):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Males try to mate
|
# Males try to mate
|
||||||
for f in self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False, limit=self.max_females):
|
for f in self.get_agents(state_id=Female.fertile.id,
|
||||||
|
agent_type=Female,
|
||||||
|
limit_neighbors=False,
|
||||||
|
limit=self.max_females):
|
||||||
r = random()
|
r = random()
|
||||||
if r < self['mating_prob']:
|
if r < self['mating_prob']:
|
||||||
self.impregnate(f)
|
self.impregnate(f)
|
||||||
break # Take a break
|
break # Take a break
|
||||||
|
|
||||||
def impregnate(self, whom):
|
def impregnate(self, whom):
|
||||||
if self['gender'] == Genders.female.value:
|
|
||||||
raise NotImplementedError('Females cannot impregnate')
|
|
||||||
whom['pregnancy'] = 0
|
whom['pregnancy'] = 0
|
||||||
whom['mate'] = self.id
|
whom['mate'] = self.id
|
||||||
whom.set_state(whom.pregnant)
|
whom.set_state(whom.pregnant)
|
||||||
self.debug('{} impregnating: {}. {}'.format(self.id, whom.id, whom.state))
|
self.debug('{} impregnating: {}. {}'.format(self.id, whom.id, whom.state))
|
||||||
|
|
||||||
|
class Female(RabbitModel):
|
||||||
|
@state
|
||||||
|
def fertile(self):
|
||||||
|
# Just wait for a Male
|
||||||
|
pass
|
||||||
|
|
||||||
@state
|
@state
|
||||||
def pregnant(self):
|
def pregnant(self):
|
||||||
self['age'] += 1
|
self['age'] += 1
|
||||||
@@ -90,11 +105,9 @@ class RabbitModel(FSM):
|
|||||||
|
|
||||||
@state
|
@state
|
||||||
def dead(self):
|
def dead(self):
|
||||||
self.info('Agent {} is dying'.format(self.id))
|
super().dead()
|
||||||
if 'pregnancy' in self and self['pregnancy'] > -1:
|
if 'pregnancy' in self and self['pregnancy'] > -1:
|
||||||
self.info('A mother has died carrying a baby!!')
|
self.info('A mother has died carrying a baby!!')
|
||||||
self.die()
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
class RandomAccident(NetworkAgent):
|
class RandomAccident(NetworkAgent):
|
||||||
|
@@ -1,23 +1,21 @@
|
|||||||
---
|
---
|
||||||
load_module: rabbit_agents
|
load_module: rabbit_agents
|
||||||
name: rabbits_example
|
name: rabbits_example
|
||||||
max_time: 150
|
max_time: 100
|
||||||
interval: 1
|
interval: 1
|
||||||
seed: MySeed
|
seed: MySeed
|
||||||
agent_type: RabbitModel
|
agent_type: rabbit_agents.RabbitModel
|
||||||
environment_agents:
|
environment_agents:
|
||||||
- agent_type: RandomAccident
|
- agent_type: rabbit_agents.RandomAccident
|
||||||
environment_params:
|
environment_params:
|
||||||
prob_death: 0.001
|
prob_death: 0.001
|
||||||
default_state:
|
default_state:
|
||||||
mating_prob: 0.01
|
mating_prob: 0.1
|
||||||
topology:
|
topology:
|
||||||
nodes:
|
nodes:
|
||||||
- id: 1
|
- id: 1
|
||||||
state:
|
agent_type: rabbit_agents.Male
|
||||||
gender: female
|
|
||||||
- id: 0
|
- id: 0
|
||||||
state:
|
agent_type: rabbit_agents.Female
|
||||||
gender: male
|
|
||||||
directed: true
|
directed: true
|
||||||
links: []
|
links: []
|
||||||
|
@@ -1 +1 @@
|
|||||||
0.20.1
|
0.20.7
|
@@ -65,6 +65,10 @@ def main():
|
|||||||
|
|
||||||
logger.info('Loading config file: {}'.format(args.file))
|
logger.info('Loading config file: {}'.format(args.file))
|
||||||
|
|
||||||
|
if args.pdb:
|
||||||
|
args.synchronous = True
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
exporters = list(args.exporter or ['default', ])
|
exporters = list(args.exporter or ['default', ])
|
||||||
if args.csv:
|
if args.csv:
|
||||||
|
@@ -35,7 +35,9 @@ class BaseAgent(Agent):
|
|||||||
unique_id,
|
unique_id,
|
||||||
model,
|
model,
|
||||||
name=None,
|
name=None,
|
||||||
interval=None):
|
interval=None,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
# Check for REQUIRED arguments
|
# Check for REQUIRED arguments
|
||||||
# Initialize agent parameters
|
# Initialize agent parameters
|
||||||
if isinstance(unique_id, Agent):
|
if isinstance(unique_id, Agent):
|
||||||
@@ -52,6 +54,12 @@ class BaseAgent(Agent):
|
|||||||
|
|
||||||
if hasattr(self, 'level'):
|
if hasattr(self, 'level'):
|
||||||
self.logger.setLevel(self.level)
|
self.logger.setLevel(self.level)
|
||||||
|
for (k, v) in self.defaults.items():
|
||||||
|
if not hasattr(self, k) or getattr(self, k) is None:
|
||||||
|
setattr(self, k, deepcopy(v))
|
||||||
|
|
||||||
|
for (k, v) in kwargs.items():
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
|
|
||||||
# TODO: refactor to clean up mesa compatibility
|
# TODO: refactor to clean up mesa compatibility
|
||||||
@@ -137,6 +145,7 @@ class BaseAgent(Agent):
|
|||||||
self.alive = False
|
self.alive = False
|
||||||
if remove:
|
if remove:
|
||||||
self.remove_node(self.id)
|
self.remove_node(self.id)
|
||||||
|
return time.NEVER
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
if not self.alive:
|
if not self.alive:
|
||||||
@@ -305,18 +314,16 @@ class FSM(NetworkAgent, metaclass=MetaFSM):
|
|||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
self.debug(f'Agent {self.unique_id} @ state {self.state_id}')
|
||||||
try:
|
interval = super().step()
|
||||||
interval = super().step()
|
|
||||||
except DeadAgent:
|
|
||||||
return time.When('inf')
|
|
||||||
if 'id' not in self.state:
|
if 'id' not in self.state:
|
||||||
# if 'id' in self.state:
|
|
||||||
# self.set_state(self.state['id'])
|
|
||||||
if self.default_state:
|
if self.default_state:
|
||||||
self.set_state(self.default_state.id)
|
self.set_state(self.default_state.id)
|
||||||
else:
|
else:
|
||||||
raise Exception('{} has no valid state id or default state'.format(self))
|
raise Exception('{} has no valid state id or default state'.format(self))
|
||||||
return self.states[self.state_id](self) or interval
|
interval = self.states[self.state_id](self) or interval
|
||||||
|
if not self.alive:
|
||||||
|
return time.NEVER
|
||||||
|
return interval
|
||||||
|
|
||||||
def set_state(self, state):
|
def set_state(self, state):
|
||||||
if hasattr(state, 'id'):
|
if hasattr(state, 'id'):
|
||||||
|
@@ -5,6 +5,7 @@ import math
|
|||||||
import random
|
import random
|
||||||
import yaml
|
import yaml
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import logging
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from time import time as current_time
|
from time import time as current_time
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@@ -101,6 +102,8 @@ class Environment(Model):
|
|||||||
environment_agents = agents._convert_agent_types(distro)
|
environment_agents = agents._convert_agent_types(distro)
|
||||||
self.environment_agents = environment_agents
|
self.environment_agents = environment_agents
|
||||||
|
|
||||||
|
self.logger = utils.logger.getChild(self.name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def now(self):
|
def now(self):
|
||||||
if self.schedule:
|
if self.schedule:
|
||||||
@@ -169,11 +172,8 @@ class Environment(Model):
|
|||||||
if agent_type:
|
if agent_type:
|
||||||
state = defstate
|
state = defstate
|
||||||
a = agent_type(model=self,
|
a = agent_type(model=self,
|
||||||
unique_id=agent_id)
|
unique_id=agent_id
|
||||||
|
)
|
||||||
for (k, v) in getattr(a, 'defaults', {}).items():
|
|
||||||
if not hasattr(a, k) or getattr(a, k) is None:
|
|
||||||
setattr(a, k, v)
|
|
||||||
|
|
||||||
for (k, v) in state.items():
|
for (k, v) in state.items():
|
||||||
setattr(a, k, v)
|
setattr(a, k, v)
|
||||||
@@ -197,17 +197,29 @@ class Environment(Model):
|
|||||||
start = start or self.now
|
start = start or self.now
|
||||||
return self.G.add_edge(agent1, agent2, **attrs)
|
return self.G.add_edge(agent1, agent2, **attrs)
|
||||||
|
|
||||||
|
def log(self, message, *args, level=logging.INFO, **kwargs):
|
||||||
|
if not self.logger.isEnabledFor(level):
|
||||||
|
return
|
||||||
|
message = message + " ".join(str(i) for i in args)
|
||||||
|
message = " @{:>3}: {}".format(self.now, message)
|
||||||
|
for k, v in kwargs:
|
||||||
|
message += " {k}={v} ".format(k, v)
|
||||||
|
extra = {}
|
||||||
|
extra['now'] = self.now
|
||||||
|
extra['unique_id'] = self.name
|
||||||
|
return self.logger.log(level, message, extra=extra)
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
super().step()
|
super().step()
|
||||||
self.datacollector.collect(self)
|
|
||||||
self.schedule.step()
|
self.schedule.step()
|
||||||
|
|
||||||
def run(self, until, *args, **kwargs):
|
def run(self, until, *args, **kwargs):
|
||||||
self._save_state()
|
self._save_state()
|
||||||
|
|
||||||
while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time):
|
while self.schedule.next_time < until:
|
||||||
self.schedule.step(until=until)
|
self.step()
|
||||||
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}')
|
||||||
|
self.schedule.time = until
|
||||||
self._history.flush_cache()
|
self._history.flush_cache()
|
||||||
|
|
||||||
def _save_state(self, now=None):
|
def _save_state(self, now=None):
|
||||||
|
@@ -1,11 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import importlib
|
import importlib
|
||||||
import sys
|
import sys
|
||||||
import yaml
|
import yaml
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
|
from time import strftime
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -98,7 +99,7 @@ class Simulation:
|
|||||||
self.network_params = network_params
|
self.network_params = network_params
|
||||||
self.name = name or 'Unnamed'
|
self.name = name or 'Unnamed'
|
||||||
self.seed = str(seed or name)
|
self.seed = str(seed or name)
|
||||||
self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S"))
|
self._id = '{}_{}'.format(self.name, strftime("%Y-%m-%d_%H.%M.%S"))
|
||||||
self.group = group or ''
|
self.group = group or ''
|
||||||
self.num_trials = num_trials
|
self.num_trials = num_trials
|
||||||
self.max_time = max_time
|
self.max_time = max_time
|
||||||
@@ -142,12 +143,10 @@ class Simulation:
|
|||||||
'''Run the simulation and return the list of resulting environments'''
|
'''Run the simulation and return the list of resulting environments'''
|
||||||
return list(self.run_gen(*args, **kwargs))
|
return list(self.run_gen(*args, **kwargs))
|
||||||
|
|
||||||
def _run_sync_or_async(self, parallel=False, *args, **kwargs):
|
def _run_sync_or_async(self, parallel=False, **kwargs):
|
||||||
if parallel and not os.environ.get('SENPY_DEBUG', None):
|
if parallel and not os.environ.get('SENPY_DEBUG', None):
|
||||||
p = Pool()
|
p = Pool()
|
||||||
func = partial(self.run_trial_exceptions,
|
func = partial(self.run_trial_exceptions, **kwargs)
|
||||||
*args,
|
|
||||||
**kwargs)
|
|
||||||
for i in p.imap_unordered(func, range(self.num_trials)):
|
for i in p.imap_unordered(func, range(self.num_trials)):
|
||||||
if isinstance(i, Exception):
|
if isinstance(i, Exception):
|
||||||
logger.error('Trial failed:\n\t%s', i.message)
|
logger.error('Trial failed:\n\t%s', i.message)
|
||||||
@@ -155,10 +154,10 @@ class Simulation:
|
|||||||
yield i
|
yield i
|
||||||
else:
|
else:
|
||||||
for i in range(self.num_trials):
|
for i in range(self.num_trials):
|
||||||
yield self.run_trial(*args,
|
yield self.run_trial(trial_id=i,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
def run_gen(self, *args, parallel=False, dry_run=False,
|
def run_gen(self, parallel=False, dry_run=False,
|
||||||
exporters=[default, ], stats=[], outdir=None, exporter_params={},
|
exporters=[default, ], stats=[], outdir=None, exporter_params={},
|
||||||
stats_params={}, log_level=None,
|
stats_params={}, log_level=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
@@ -184,8 +183,7 @@ class Simulation:
|
|||||||
|
|
||||||
for exporter in exporters:
|
for exporter in exporters:
|
||||||
exporter.start()
|
exporter.start()
|
||||||
for env in self._run_sync_or_async(*args,
|
for env in self._run_sync_or_async(parallel=parallel,
|
||||||
parallel=parallel,
|
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
@@ -224,7 +222,7 @@ class Simulation:
|
|||||||
'''Create an environment for a trial of the simulation'''
|
'''Create an environment for a trial of the simulation'''
|
||||||
opts = self.environment_params.copy()
|
opts = self.environment_params.copy()
|
||||||
opts.update({
|
opts.update({
|
||||||
'name': trial_id,
|
'name': '{}_trial_{}'.format(self.name, trial_id),
|
||||||
'topology': self.topology.copy(),
|
'topology': self.topology.copy(),
|
||||||
'network_params': self.network_params,
|
'network_params': self.network_params,
|
||||||
'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
'seed': '{}_trial_{}'.format(self.seed, trial_id),
|
||||||
@@ -241,12 +239,11 @@ class Simulation:
|
|||||||
env = self.environment_class(**opts)
|
env = self.environment_class(**opts)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def run_trial(self, until=None, log_level=logging.INFO, **opts):
|
def run_trial(self, trial_id=0, until=None, log_level=logging.INFO, **opts):
|
||||||
"""
|
"""
|
||||||
Run a single trial of the simulation
|
Run a single trial of the simulation
|
||||||
|
|
||||||
"""
|
"""
|
||||||
trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-')
|
|
||||||
if log_level:
|
if log_level:
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
# Set-up trial environment and graph
|
# Set-up trial environment and graph
|
||||||
|
67
soil/time.py
67
soil/time.py
@@ -6,15 +6,21 @@ from .utils import logger
|
|||||||
from mesa import Agent
|
from mesa import Agent
|
||||||
|
|
||||||
|
|
||||||
|
INFINITY = float('inf')
|
||||||
|
|
||||||
class When:
|
class When:
|
||||||
def __init__(self, time):
|
def __init__(self, time):
|
||||||
self._time = float(time)
|
if isinstance(time, When):
|
||||||
|
return time
|
||||||
|
self._time = time
|
||||||
|
|
||||||
def abs(self, time):
|
def abs(self, time):
|
||||||
return self._time
|
return self._time
|
||||||
|
|
||||||
|
NEVER = When(INFINITY)
|
||||||
|
|
||||||
class Delta:
|
|
||||||
|
class Delta(When):
|
||||||
def __init__(self, delta):
|
def __init__(self, delta):
|
||||||
self._delta = delta
|
self._delta = delta
|
||||||
|
|
||||||
@@ -40,48 +46,35 @@ class TimedActivation(BaseScheduler):
|
|||||||
heappush(self._queue, (self.time, agent.unique_id))
|
heappush(self._queue, (self.time, agent.unique_id))
|
||||||
super().add(agent)
|
super().add(agent)
|
||||||
|
|
||||||
def step(self, until: float =float('inf')) -> None:
|
def step(self) -> None:
|
||||||
"""
|
"""
|
||||||
Executes agents in order, one at a time. After each step,
|
Executes agents in order, one at a time. After each step,
|
||||||
an agent will signal when it wants to be scheduled next.
|
an agent will signal when it wants to be scheduled next.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
when = None
|
if self.next_time == INFINITY:
|
||||||
agent_id = None
|
return
|
||||||
unsched = []
|
|
||||||
until = until or float('inf')
|
self.time = self.next_time
|
||||||
|
when = self.time
|
||||||
|
|
||||||
|
while self._queue and self._queue[0][0] == self.time:
|
||||||
|
(when, agent_id) = heappop(self._queue)
|
||||||
|
logger.debug(f'Stepping agent {agent_id}')
|
||||||
|
|
||||||
|
returned = self._agents[agent_id].step()
|
||||||
|
when = (returned or Delta(1)).abs(self.time)
|
||||||
|
if when < self.time:
|
||||||
|
raise Exception("Cannot schedule an agent for a time in the past ({} < {})".format(when, self.time))
|
||||||
|
|
||||||
|
heappush(self._queue, (when, agent_id))
|
||||||
|
|
||||||
|
self.steps += 1
|
||||||
|
|
||||||
if not self._queue:
|
if not self._queue:
|
||||||
self.time = until
|
self.time = INFINITY
|
||||||
self.next_time = float('inf')
|
self.next_time = INFINITY
|
||||||
return
|
return
|
||||||
|
|
||||||
(when, agent_id) = self._queue[0]
|
self.next_time = self._queue[0][0]
|
||||||
|
|
||||||
if until and when > until:
|
|
||||||
self.time = until
|
|
||||||
self.next_time = when
|
|
||||||
return
|
|
||||||
|
|
||||||
self.time = when
|
|
||||||
next_time = float("inf")
|
|
||||||
|
|
||||||
while when == self.time:
|
|
||||||
heappop(self._queue)
|
|
||||||
logger.debug(f'Stepping agent {agent_id}')
|
|
||||||
when = (self._agents[agent_id].step() or Delta(1)).abs(self.time)
|
|
||||||
heappush(self._queue, (when, agent_id))
|
|
||||||
if when < next_time:
|
|
||||||
next_time = when
|
|
||||||
|
|
||||||
if not self._queue or self._queue[0][0] > self.time:
|
|
||||||
agent_id = None
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
(when, agent_id) = self._queue[0]
|
|
||||||
|
|
||||||
if when and when < self.time:
|
|
||||||
raise Exception("Invalid scheduling time")
|
|
||||||
|
|
||||||
self.next_time = next_time
|
|
||||||
self.steps += 1
|
|
||||||
|
22
tests/test_agents.py
Normal file
22
tests/test_agents.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from unittest import TestCase
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from soil import agents, environment
|
||||||
|
from soil import time as stime
|
||||||
|
|
||||||
|
class Dead(agents.FSM):
|
||||||
|
@agents.default_state
|
||||||
|
@agents.state
|
||||||
|
def only(self):
|
||||||
|
self.die()
|
||||||
|
|
||||||
|
class TestMain(TestCase):
|
||||||
|
def test_die_raises_exception(self):
|
||||||
|
d = Dead(unique_id=0, model=environment.Environment())
|
||||||
|
d.step()
|
||||||
|
with pytest.raises(agents.DeadAgent):
|
||||||
|
d.step()
|
||||||
|
|
||||||
|
def test_die_returns_infinity(self):
|
||||||
|
d = Dead(unique_id=0, model=environment.Environment())
|
||||||
|
assert d.step().abs(0) == stime.INFINITY
|
@@ -127,7 +127,7 @@ class TestMain(TestCase):
|
|||||||
env = s.run_simulation(dry_run=True)[0]
|
env = s.run_simulation(dry_run=True)[0]
|
||||||
for agent in env.network_agents:
|
for agent in env.network_agents:
|
||||||
last = 0
|
last = 0
|
||||||
assert len(agent[None, None]) == 11
|
assert len(agent[None, None]) == 10
|
||||||
for step, total in sorted(agent['total', None]):
|
for step, total in sorted(agent['total', None]):
|
||||||
assert total == last + 2
|
assert total == last + 2
|
||||||
last = total
|
last = total
|
||||||
|
Reference in New Issue
Block a user