mirror of
				https://github.com/gsi-upm/soil
				synced 2025-10-23 03:38:24 +00:00 
			
		
		
		
	Compare commits
	
		
			12 Commits
		
	
	
		
			6adc8d36ba
			...
			0.20.2
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | 3fc5ca8c08 | ||
|  | c02e6ea2e8 | ||
|  | 38f8a8d110 | ||
|  | cb72aac980 | ||
|  | 6c4f44b4cb | ||
|  | af9a392a93 | ||
|  | 5d7e57675a | ||
|  | e860bdb922 | ||
|  | d6b684c1c1 | ||
|  | 05f7f49233 | ||
|  | 3b2c6a3db5 | ||
|  | 6118f917ee | 
| @@ -1,4 +1,5 @@ | |||||||
| **/soil_output | **/soil_output | ||||||
| .* | .* | ||||||
|  | **/__pycache__ | ||||||
| __pycache__ | __pycache__ | ||||||
| *.pyc | *.pyc | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -8,3 +8,4 @@ soil_output | |||||||
| docs/_build* | docs/_build* | ||||||
| build/* | build/* | ||||||
| dist/* | dist/* | ||||||
|  | prof | ||||||
| @@ -1,9 +1,10 @@ | |||||||
| stages: | stages: | ||||||
|   - test |   - test | ||||||
|   - build |   - publish | ||||||
|  |   - check_published | ||||||
|  |  | ||||||
| build: | docker: | ||||||
|   stage: build |   stage: publish | ||||||
|   image: |   image: | ||||||
|     name: gcr.io/kaniko-project/executor:debug |     name: gcr.io/kaniko-project/executor:debug | ||||||
|     entrypoint: [""] |     entrypoint: [""] | ||||||
| @@ -16,13 +17,34 @@ build: | |||||||
|   only: |   only: | ||||||
|     - tags |     - tags | ||||||
|  |  | ||||||
|  |  | ||||||
| test: | test: | ||||||
|   except: |  | ||||||
|     - tags  # Avoid running tests for tags, because they are already run for the branch |  | ||||||
|   tags: |   tags: | ||||||
|     - docker |     - docker | ||||||
|   image: python:3.7 |   image: python:3.7 | ||||||
|   stage: test |   stage: test | ||||||
|   script: |   script: | ||||||
|     - python setup.py test |     - pip install -r requirements.txt -r test-requirements.txt | ||||||
|  |     - python setup.py test | ||||||
|  |  | ||||||
|  | pypi: | ||||||
|  |   only: | ||||||
|  |     - tags | ||||||
|  |   tags: | ||||||
|  |     - docker | ||||||
|  |   image: python:3.7 | ||||||
|  |   stage: publish | ||||||
|  |   script: | ||||||
|  |     - echo $CI_COMMIT_TAG > soil/VERSION | ||||||
|  |     - pip install twine | ||||||
|  |     - python setup.py sdist bdist_wheel | ||||||
|  |     - TWINE_PASSWORD=${PYPI_PASSWORD} TWINE_USERNAME={PYPI_USERNAME} python -m twine upload --ignore-existing dist/* | ||||||
|  |  | ||||||
|  | pypi: | ||||||
|  |   only: | ||||||
|  |     - tags | ||||||
|  |   tags: | ||||||
|  |     - docker | ||||||
|  |   image: python:3.7 | ||||||
|  |   stage: check_published | ||||||
|  |   script: | ||||||
|  |     - pip install soil==$CI_COMMIT_TAG | ||||||
|   | |||||||
							
								
								
									
										57
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										57
									
								
								CHANGELOG.md
									
									
									
									
									
								
							| @@ -3,6 +3,63 @@ All notable changes to this project will be documented in this file. | |||||||
|  |  | ||||||
| The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | ||||||
|  |  | ||||||
|  | ## [0.20.2] | ||||||
|  | ### Fixed | ||||||
|  | * CI/CD testing issues | ||||||
|  | ## [0.20.1] | ||||||
|  | ### Fixed | ||||||
|  | * Agents would run another step after dying. | ||||||
|  | ## [0.20.0] | ||||||
|  | ### Added | ||||||
|  | * Integration with MESA | ||||||
|  | * `not_agent_ids` parameter to get sql in history | ||||||
|  | ### Changed | ||||||
|  | * `soil.Environment` now also inherits from `mesa.Model` | ||||||
|  | * `soil.Agent` now also inherits from `mesa.Agent` | ||||||
|  | * `soil.time` to replace `simpy` events, delays, duration, etc. | ||||||
|  | * `agent.id` is not `agent.unique_id` to be compatible with `mesa`. A property `BaseAgent.id` has been added for compatibility. | ||||||
|  | * `agent.environment` is now `agent.model`, for the same reason as above. The parameter name in `BaseAgent.__init__` has also been renamed. | ||||||
|  | ### Removed | ||||||
|  | * `simpy` dependency and compatibility. Each agent used to be a simpy generator, but that made debugging and error handling more complex. That has been replaced by a scheduler within the `soil.Environment` class, similar to how `mesa` does it. | ||||||
|  | * `soil.history` is now a separate package named `tsih`. The keys namedtuple uses `dict_id` instead of `agent_id`. | ||||||
|  | ### Added | ||||||
|  | * An option to choose whether a database should be used for history  | ||||||
|  | ## [0.15.2] | ||||||
|  | ### Fixed | ||||||
|  | * Pass the right known_modules and parameters to stats discovery in simulation | ||||||
|  | * The configuration file must exist when launching through the CLI. If it doesn't, an error will be logged | ||||||
|  | * Minor changes in the documentation of the CLI arguments | ||||||
|  | ### Changed | ||||||
|  | * Stats are now exported by default | ||||||
|  | ## [0.15.1] | ||||||
|  | ### Added | ||||||
|  | * read-only `History` | ||||||
|  | ### Fixed | ||||||
|  | * Serialization problem with the `Environment` on parallel mode. | ||||||
|  | * Analysis functions now work as they should in the tutorial | ||||||
|  | ## [0.15.0] | ||||||
|  | ### Added | ||||||
|  | * Control logging level in CLI and simulation | ||||||
|  | * `Stats` to calculate trial and simulation-wide statistics | ||||||
|  | * Simulation statistics are stored in a separate table in history (see `History.get_stats` and `History.save_stats`, as well as `soil.stats`) | ||||||
|  | * Aliased `NetworkAgent.G` to `NetworkAgent.topology`. | ||||||
|  | ### Changed | ||||||
|  | * Templates in config files can be given as dictionaries in addition to strings | ||||||
|  | * Samplers are used more explicitly | ||||||
|  | * Removed nxsim dependency. We had already made a lot of changes, and nxsim has not been updated in 5 years. | ||||||
|  | * Exporter methods renamed to `trial` and `end`. Added `start`. | ||||||
|  | * `Distribution` exporter now a stats class | ||||||
|  | * `global_topology` renamed to `topology` | ||||||
|  | * Moved topology-related methods to `NetworkAgent` | ||||||
|  | ### Fixed | ||||||
|  | * Temporary files used for history in dry_run mode are not longer left open  | ||||||
|  |  | ||||||
|  | ## [0.14.9] | ||||||
|  | ### Changed | ||||||
|  | * Seed random before environment initialization | ||||||
|  | ## [0.14.8] | ||||||
|  | ### Fixed | ||||||
|  | * Invalid directory names in Windows gsi-upm/soil#5 | ||||||
| ## [0.14.7] | ## [0.14.7] | ||||||
| ### Changed | ### Changed | ||||||
| * Minor change to traceback handling in async simulations | * Minor change to traceback handling in async simulations | ||||||
|   | |||||||
							
								
								
									
										24
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								README.md
									
									
									
									
									
								
							| @@ -5,6 +5,9 @@ Learn how to run your own simulations with our [documentation](http://soilsim.re | |||||||
|  |  | ||||||
| Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models. | Follow our [tutorial](examples/tutorial/soil_tutorial.ipynb) to develop your own agent models. | ||||||
|  |  | ||||||
|  | ## Citation  | ||||||
|  |  | ||||||
|  |  | ||||||
| If you use Soil in your research, don't forget to cite this paper: | If you use Soil in your research, don't forget to cite this paper: | ||||||
|  |  | ||||||
| ```bibtex | ```bibtex | ||||||
| @@ -28,7 +31,24 @@ If you use Soil in your research, don't forget to cite this paper: | |||||||
|  |  | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| @Copyright GSI - Universidad Politécnica de Madrid 2017 | ## Mesa compatibility | ||||||
|  |  | ||||||
| [](https://www.gsi.dit.upm.es) | Soil is in the process of becoming fully compatible with MESA. | ||||||
|  | As of this writing,  | ||||||
|  |  | ||||||
|  | This is a non-exhaustive list of tasks to achieve compatibility: | ||||||
|  |  | ||||||
|  | * Environments.agents and mesa.Agent.agents are not the same. env is a property, and it only takes into account network and environment agents. Might rename environment_agents to other_agents or sth like that | ||||||
|  | - [ ] Integrate `soil.Simulation` with mesa's runners: | ||||||
|  |   - [ ] `soil.Simulation` could mimic/become a `mesa.batchrunner` | ||||||
|  | - [ ] Integrate `soil.Environment` with `mesa.Model`: | ||||||
|  |   - [x] `Soil.Environment` inherits from `mesa.Model` | ||||||
|  |   - [x] `Soil.Environment` includes a Mesa-like Scheduler (see the `soil.time` module. | ||||||
|  | - [ ] Integrate `soil.Agent` with `mesa.Agent`: | ||||||
|  |   - [x] Rename agent.id to unique_id? | ||||||
|  |   - [x] mesa agents can be used in soil simulations (see `examples/mesa`) | ||||||
|  | - [ ] Document the new APIs and usage | ||||||
|  |  | ||||||
|  | @Copyright GSI - Universidad Politécnica de Madrid 2017-2021 | ||||||
|  |  | ||||||
|  | [](https://www.gsi.upm.es) | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ | |||||||
| # Add any Sphinx extension module names here, as strings. They can be | # Add any Sphinx extension module names here, as strings. They can be | ||||||
| # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom | ||||||
| # ones. | # ones. | ||||||
| extensions = [] | extensions = ['IPython.sphinxext.ipython_console_highlighting'] | ||||||
|  |  | ||||||
| # Add any paths that contain templates here, relative to this directory. | # Add any paths that contain templates here, relative to this directory. | ||||||
| templates_path = ['_templates'] | templates_path = ['_templates'] | ||||||
| @@ -69,7 +69,7 @@ language = None | |||||||
| # List of patterns, relative to source directory, that match files and | # List of patterns, relative to source directory, that match files and | ||||||
| # directories to ignore when looking for source files. | # directories to ignore when looking for source files. | ||||||
| # This patterns also effect to html_static_path and html_extra_path | # This patterns also effect to html_static_path and html_extra_path | ||||||
| exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints'] | ||||||
|  |  | ||||||
| # The name of the Pygments (syntax highlighting) style to use. | # The name of the Pygments (syntax highlighting) style to use. | ||||||
| pygments_style = 'sphinx' | pygments_style = 'sphinx' | ||||||
|   | |||||||
| @@ -218,3 +218,24 @@ These agents are programmed in much the same way as network agents, the only dif | |||||||
|  |  | ||||||
| You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance. | You may use environment agents to model events that a normal agent cannot control, such as natural disasters or chance. | ||||||
| They are also useful to add behavior that has little to do with the network and the interactions within that network. | They are also useful to add behavior that has little to do with the network and the interactions within that network. | ||||||
|  |  | ||||||
|  | Templating | ||||||
|  | ========== | ||||||
|  |  | ||||||
|  | Sometimes, it is useful to parameterize a simulation and run it over a range of values in order to compare each run and measure the effect of those parameters in the simulation. | ||||||
|  | For instance, you may want to run a simulation with different agent distributions. | ||||||
|  |  | ||||||
|  | This can be done in Soil using **templates**. | ||||||
|  | A template is a configuration where some of the values are specified with a variable. | ||||||
|  | e.g.,  ``weight: "{{ var1 }}"`` instead of ``weight: 1``. | ||||||
|  | There are two types of variables, depending on how their values are decided: | ||||||
|  |  | ||||||
|  | * Fixed. A list of values is provided, and a new simulation is run for each possible value. If more than a variable is given, a new simulation will be run per combination of values. | ||||||
|  | * Bounded/Sampled. The bounds of the variable are provided, along with a sampler method, which will be used to compute all the configuration combinations. | ||||||
|  |  | ||||||
|  | When fixed and bounded variables are mixed, Soil generates a new configuration per combination of fixed values and bounded values. | ||||||
|  |  | ||||||
|  | Here is an example with a single fixed variable and two bounded variable: | ||||||
|  |  | ||||||
|  | .. literalinclude:: ../examples/template.yml | ||||||
|  |    :language: yaml | ||||||
|   | |||||||
| @@ -14,11 +14,11 @@ Now test that it worked by running the command line tool | |||||||
|  |  | ||||||
|    soil --help |    soil --help | ||||||
|  |  | ||||||
| Or using soil programmatically: | Or, if you're using using soil programmatically: | ||||||
|  |  | ||||||
| .. code:: python | .. code:: python | ||||||
|  |  | ||||||
|    import soil |    import soil | ||||||
|    print(soil.__version__) |    print(soil.__version__) | ||||||
|  |  | ||||||
| The latest version can be installed through `GitLab <https://lab.cluster.gsi.dit.upm.es/soil/soil.git>`_. | The latest version can be installed through `GitLab <https://lab.gsi.upm.es/soil/soil.git>`_ or `GitHub <https://github.com/gsi-upm/soil>`_. | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								docs/requirements.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								docs/requirements.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | ipython==7.23 | ||||||
| @@ -47,12 +47,6 @@ There are three main elements in a soil simulation: | |||||||
| -  The environment. It assigns agents to nodes in the network, and | -  The environment. It assigns agents to nodes in the network, and | ||||||
|    stores the environment parameters (shared state for all agents). |    stores the environment parameters (shared state for all agents). | ||||||
|  |  | ||||||
| Soil is based on ``simpy``, which is an event-based network simulation |  | ||||||
| library. Soil provides several abstractions over events to make |  | ||||||
| developing agents easier. This means you can use events (timeouts, |  | ||||||
| delays) in soil, but for the most part we will assume your models will |  | ||||||
| be step-based. |  | ||||||
|  |  | ||||||
| Modeling behaviour | Modeling behaviour | ||||||
| ------------------ | ------------------ | ||||||
|  |  | ||||||
|   | |||||||
| @@ -500,7 +500,7 @@ | |||||||
|    "name": "python", |    "name": "python", | ||||||
|    "nbconvert_exporter": "python", |    "nbconvert_exporter": "python", | ||||||
|    "pygments_lexer": "ipython3", |    "pygments_lexer": "ipython3", | ||||||
|    "version": "3.6.5" |    "version": "3.8.5" | ||||||
|   }, |   }, | ||||||
|   "toc": { |   "toc": { | ||||||
|    "colors": { |    "colors": { | ||||||
|   | |||||||
| @@ -80800,7 +80800,7 @@ | |||||||
|    "name": "python", |    "name": "python", | ||||||
|    "nbconvert_exporter": "python", |    "nbconvert_exporter": "python", | ||||||
|    "pygments_lexer": "ipython3", |    "pygments_lexer": "ipython3", | ||||||
|    "version": "3.6.5" |    "version": "3.8.6" | ||||||
|   } |   } | ||||||
|  }, |  }, | ||||||
|  "nbformat": 4, |  "nbformat": 4, | ||||||
|   | |||||||
| @@ -13,7 +13,7 @@ network_agents: | |||||||
|   - agent_type: CounterModel |   - agent_type: CounterModel | ||||||
|     weight: 1 |     weight: 1 | ||||||
|     state: |     state: | ||||||
|       id: 0 |       state_id: 0 | ||||||
|   - agent_type: AggregatedCounter |   - agent_type: AggregatedCounter | ||||||
|     weight: 0.2 |     weight: 0.2 | ||||||
| environment_agents: [] | environment_agents: [] | ||||||
|   | |||||||
| @@ -13,4 +13,4 @@ network_agents: | |||||||
|   - agent_type: CounterModel |   - agent_type: CounterModel | ||||||
|     weight: 1 |     weight: 1 | ||||||
|     state: |     state: | ||||||
|       id: 0 |       state_id: 0 | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								examples/mesa/mesa.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								examples/mesa/mesa.yml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | --- | ||||||
|  | name: mesa_sim | ||||||
|  | group: tests | ||||||
|  | dir_path: "/tmp" | ||||||
|  | num_trials: 3 | ||||||
|  | max_time: 100 | ||||||
|  | interval: 1 | ||||||
|  | seed: '1' | ||||||
|  | network_params: | ||||||
|  |   generator: social_wealth.graph_generator | ||||||
|  |   n: 5 | ||||||
|  | network_agents: | ||||||
|  |   - agent_type: social_wealth.SocialMoneyAgent | ||||||
|  |     weight: 1 | ||||||
|  | environment_class: social_wealth.MoneyEnv | ||||||
|  | environment_params: | ||||||
|  |   num_mesa_agents: 5 | ||||||
|  |   mesa_agent_type: social_wealth.MoneyAgent | ||||||
|  |   N: 10 | ||||||
|  |   width: 50 | ||||||
|  |   height: 50 | ||||||
							
								
								
									
										105
									
								
								examples/mesa/server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										105
									
								
								examples/mesa/server.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,105 @@ | |||||||
|  | from mesa.visualization.ModularVisualization import ModularServer | ||||||
|  | from soil.visualization import UserSettableParameter | ||||||
|  | from mesa.visualization.modules import ChartModule, NetworkModule, CanvasGrid | ||||||
|  | from social_wealth import MoneyEnv, graph_generator, SocialMoneyAgent | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MyNetwork(NetworkModule): | ||||||
|  |     def render(self, model): | ||||||
|  |         return self.portrayal_method(model) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def network_portrayal(env): | ||||||
|  |     # The model ensures there is 0 or 1 agent per node | ||||||
|  |  | ||||||
|  |     portrayal = dict() | ||||||
|  |     portrayal["nodes"] = [ | ||||||
|  |         { | ||||||
|  |             "id": agent_id, | ||||||
|  |             "size": env.get_agent(agent_id).wealth, | ||||||
|  |             # "color": "#CC0000" if not agents or agents[0].wealth == 0 else "#007959", | ||||||
|  |             "color": "#CC0000", | ||||||
|  |             "label": f"{agent_id}: {env.get_agent(agent_id).wealth}", | ||||||
|  |         } | ||||||
|  |         for (agent_id) in env.G.nodes | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     portrayal["edges"] = [ | ||||||
|  |         {"id": edge_id, "source": source, "target": target, "color": "#000000"} | ||||||
|  |         for edge_id, (source, target) in enumerate(env.G.edges) | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     return portrayal | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def gridPortrayal(agent): | ||||||
|  |     """ | ||||||
|  |     This function is registered with the visualization server to be called | ||||||
|  |     each tick to indicate how to draw the agent in its current state. | ||||||
|  |     :param agent:  the agent in the simulation | ||||||
|  |     :return: the portrayal dictionary | ||||||
|  |     """ | ||||||
|  |     color = max(10, min(agent.wealth*10, 100)) | ||||||
|  |     return { | ||||||
|  |         "Shape": "rect", | ||||||
|  |         "w": 1, | ||||||
|  |         "h": 1, | ||||||
|  |         "Filled": "true", | ||||||
|  |         "Layer": 0, | ||||||
|  |         "Label": agent.unique_id, | ||||||
|  |         "Text": agent.unique_id, | ||||||
|  |         "x": agent.pos[0], | ||||||
|  |         "y": agent.pos[1], | ||||||
|  |         "Color": f"rgba(31, 10, 255, 0.{color})" | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | grid = MyNetwork(network_portrayal, 500, 500, library="sigma") | ||||||
|  | chart = ChartModule( | ||||||
|  |     [{"Label": "Gini", "Color": "Black"}], data_collector_name="datacollector" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | model_params = { | ||||||
|  |     "N": UserSettableParameter( | ||||||
|  |         "slider", | ||||||
|  |         "N", | ||||||
|  |         5, | ||||||
|  |         1, | ||||||
|  |         10, | ||||||
|  |         1, | ||||||
|  |         description="Choose how many agents to include in the model", | ||||||
|  |     ), | ||||||
|  |     "network_agents": [{"agent_type": SocialMoneyAgent}], | ||||||
|  |     "height": UserSettableParameter( | ||||||
|  |         "slider", | ||||||
|  |         "height", | ||||||
|  |         5, | ||||||
|  |         5, | ||||||
|  |         10, | ||||||
|  |         1, | ||||||
|  |         description="Grid height", | ||||||
|  |         ), | ||||||
|  |     "width": UserSettableParameter( | ||||||
|  |         "slider", | ||||||
|  |         "width", | ||||||
|  |         5, | ||||||
|  |         5, | ||||||
|  |         10, | ||||||
|  |         1, | ||||||
|  |         description="Grid width", | ||||||
|  |         ), | ||||||
|  |     "network_params": { | ||||||
|  |         'generator': graph_generator | ||||||
|  |     }, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | canvas_element = CanvasGrid(gridPortrayal, model_params["width"].value, model_params["height"].value, 500, 500) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | server = ModularServer( | ||||||
|  |     MoneyEnv, [grid, chart, canvas_element], "Money Model", model_params | ||||||
|  | ) | ||||||
|  | server.port = 8521 | ||||||
|  |  | ||||||
|  | server.launch(open_browser=False) | ||||||
							
								
								
									
										120
									
								
								examples/mesa/social_wealth.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								examples/mesa/social_wealth.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,120 @@ | |||||||
|  | ''' | ||||||
|  | This is an example that adds soil agents and environment in a normal | ||||||
|  | mesa workflow. | ||||||
|  | ''' | ||||||
|  | from mesa import Agent as MesaAgent | ||||||
|  | from mesa.space import MultiGrid | ||||||
|  | # from mesa.time import RandomActivation | ||||||
|  | from mesa.datacollection import DataCollector | ||||||
|  | from mesa.batchrunner import BatchRunner | ||||||
|  |  | ||||||
|  | import networkx as nx | ||||||
|  |  | ||||||
|  | from soil import NetworkAgent, Environment | ||||||
|  |  | ||||||
|  | def compute_gini(model): | ||||||
|  |     agent_wealths = [agent.wealth for agent in model.agents] | ||||||
|  |     x = sorted(agent_wealths) | ||||||
|  |     N = len(list(model.agents)) | ||||||
|  |     B = sum( xi * (N-i) for i,xi in enumerate(x) ) / (N*sum(x)) | ||||||
|  |     return (1 + (1/N) - 2*B) | ||||||
|  |  | ||||||
|  | class MoneyAgent(MesaAgent): | ||||||
|  |     """ | ||||||
|  |     A MESA agent with fixed initial wealth. | ||||||
|  |     It will only share wealth with neighbors based on grid proximity | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, unique_id, model): | ||||||
|  |         super().__init__(unique_id=unique_id, model=model) | ||||||
|  |         self.wealth = 1 | ||||||
|  |  | ||||||
|  |     def move(self): | ||||||
|  |         possible_steps = self.model.grid.get_neighborhood( | ||||||
|  |             self.pos, | ||||||
|  |             moore=True, | ||||||
|  |             include_center=False) | ||||||
|  |         new_position = self.random.choice(possible_steps) | ||||||
|  |         self.model.grid.move_agent(self, new_position) | ||||||
|  |  | ||||||
|  |     def give_money(self): | ||||||
|  |         cellmates = self.model.grid.get_cell_list_contents([self.pos]) | ||||||
|  |         if len(cellmates) > 1: | ||||||
|  |             other = self.random.choice(cellmates) | ||||||
|  |             other.wealth += 1 | ||||||
|  |             self.wealth -= 1 | ||||||
|  |  | ||||||
|  |     def step(self): | ||||||
|  |         self.info("Crying wolf", self.pos) | ||||||
|  |         self.move() | ||||||
|  |         if self.wealth > 0: | ||||||
|  |             self.give_money() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SocialMoneyAgent(NetworkAgent, MoneyAgent): | ||||||
|  |     wealth = 1 | ||||||
|  |  | ||||||
|  |     def give_money(self): | ||||||
|  |         cellmates = set(self.model.grid.get_cell_list_contents([self.pos])) | ||||||
|  |         friends = set(self.get_neighboring_agents()) | ||||||
|  |         self.info("Trying to give money") | ||||||
|  |         self.debug("Cellmates: ", cellmates) | ||||||
|  |         self.debug("Friends: ", friends) | ||||||
|  |  | ||||||
|  |         nearby_friends = list(cellmates & friends) | ||||||
|  |  | ||||||
|  |         if len(nearby_friends): | ||||||
|  |             other = self.random.choice(nearby_friends) | ||||||
|  |             other.wealth += 1 | ||||||
|  |             self.wealth -= 1 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MoneyEnv(Environment): | ||||||
|  |     """A model with some number of agents.""" | ||||||
|  |     def __init__(self, N, width, height, *args, network_params, **kwargs): | ||||||
|  |  | ||||||
|  |         network_params['n'] = N | ||||||
|  |         super().__init__(*args, network_params=network_params, **kwargs) | ||||||
|  |         self.grid = MultiGrid(width, height, False) | ||||||
|  |  | ||||||
|  |         # Create agents | ||||||
|  |         for agent in self.agents: | ||||||
|  |             x = self.random.randrange(self.grid.width) | ||||||
|  |             y = self.random.randrange(self.grid.height) | ||||||
|  |             self.grid.place_agent(agent, (x, y)) | ||||||
|  |  | ||||||
|  |         self.datacollector = DataCollector( | ||||||
|  |             model_reporters={"Gini": compute_gini}, | ||||||
|  |             agent_reporters={"Wealth": "wealth"}) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def graph_generator(n=5): | ||||||
|  |     G = nx.Graph() | ||||||
|  |     for ix in range(n): | ||||||
|  |         G.add_edge(0, ix) | ||||||
|  |     return G | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     G = graph_generator() | ||||||
|  |     fixed_params = {"topology": G, | ||||||
|  |                     "width": 10, | ||||||
|  |                     "network_agents": [{"agent_type": SocialMoneyAgent, | ||||||
|  |                                        'weight': 1}], | ||||||
|  |                     "height": 10} | ||||||
|  |  | ||||||
|  |     variable_params = {"N": range(10, 100, 10)} | ||||||
|  |  | ||||||
|  |     batch_run = BatchRunner(MoneyEnv, | ||||||
|  |                             variable_parameters=variable_params, | ||||||
|  |                             fixed_parameters=fixed_params, | ||||||
|  |                             iterations=5, | ||||||
|  |                             max_steps=100, | ||||||
|  |                             model_reporters={"Gini": compute_gini}) | ||||||
|  |     batch_run.run_all() | ||||||
|  |  | ||||||
|  |     run_data = batch_run.get_model_vars_dataframe() | ||||||
|  |     run_data.head() | ||||||
|  |     print(run_data.Gini) | ||||||
|  |  | ||||||
							
								
								
									
										83
									
								
								examples/mesa/wealth.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								examples/mesa/wealth.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,83 @@ | |||||||
|  | from mesa import Agent, Model | ||||||
|  | from mesa.space import MultiGrid | ||||||
|  | from mesa.time import RandomActivation | ||||||
|  | from mesa.datacollection import DataCollector | ||||||
|  | from mesa.batchrunner import BatchRunner | ||||||
|  |  | ||||||
|  | def compute_gini(model): | ||||||
|  |     agent_wealths = [agent.wealth for agent in model.schedule.agents] | ||||||
|  |     x = sorted(agent_wealths) | ||||||
|  |     N = model.num_agents | ||||||
|  |     B = sum( xi * (N-i) for i,xi in enumerate(x) ) / (N*sum(x)) | ||||||
|  |     return (1 + (1/N) - 2*B) | ||||||
|  |  | ||||||
|  | class MoneyAgent(Agent): | ||||||
|  |     """ An agent with fixed initial wealth.""" | ||||||
|  |     def __init__(self, unique_id, model): | ||||||
|  |         super().__init__(unique_id, model) | ||||||
|  |         self.wealth = 1 | ||||||
|  |  | ||||||
|  |     def move(self): | ||||||
|  |         possible_steps = self.model.grid.get_neighborhood( | ||||||
|  |             self.pos, | ||||||
|  |             moore=True, | ||||||
|  |             include_center=False) | ||||||
|  |         new_position = self.random.choice(possible_steps) | ||||||
|  |         self.model.grid.move_agent(self, new_position) | ||||||
|  |  | ||||||
|  |     def give_money(self): | ||||||
|  |         cellmates = self.model.grid.get_cell_list_contents([self.pos]) | ||||||
|  |         if len(cellmates) > 1: | ||||||
|  |             other = self.random.choice(cellmates) | ||||||
|  |             other.wealth += 1 | ||||||
|  |             self.wealth -= 1 | ||||||
|  |  | ||||||
|  |     def step(self): | ||||||
|  |         self.move() | ||||||
|  |         if self.wealth > 0: | ||||||
|  |             self.give_money() | ||||||
|  |  | ||||||
|  | class MoneyModel(Model): | ||||||
|  |     """A model with some number of agents.""" | ||||||
|  |     def __init__(self, N, width, height): | ||||||
|  |         self.num_agents = N | ||||||
|  |         self.grid = MultiGrid(width, height, True) | ||||||
|  |         self.schedule = RandomActivation(self) | ||||||
|  |         self.running = True | ||||||
|  |  | ||||||
|  |         # Create agents | ||||||
|  |         for i in range(self.num_agents): | ||||||
|  |             a = MoneyAgent(i, self) | ||||||
|  |             self.schedule.add(a) | ||||||
|  |             # Add the agent to a random grid cell | ||||||
|  |             x = self.random.randrange(self.grid.width) | ||||||
|  |             y = self.random.randrange(self.grid.height) | ||||||
|  |             self.grid.place_agent(a, (x, y)) | ||||||
|  |  | ||||||
|  |         self.datacollector = DataCollector( | ||||||
|  |             model_reporters={"Gini": compute_gini}, | ||||||
|  |             agent_reporters={"Wealth": "wealth"}) | ||||||
|  |  | ||||||
|  |     def step(self): | ||||||
|  |         self.datacollector.collect(self) | ||||||
|  |         self.schedule.step() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |  | ||||||
|  |     fixed_params = {"width": 10, | ||||||
|  |                     "height": 10} | ||||||
|  |     variable_params = {"N": range(10, 500, 10)} | ||||||
|  |  | ||||||
|  |     batch_run = BatchRunner(MoneyModel, | ||||||
|  |                             variable_params, | ||||||
|  |                             fixed_params, | ||||||
|  |                             iterations=5, | ||||||
|  |                             max_steps=100, | ||||||
|  |                             model_reporters={"Gini": compute_gini}) | ||||||
|  |     batch_run.run_all() | ||||||
|  |  | ||||||
|  |     run_data = batch_run.get_model_vars_dataframe() | ||||||
|  |     run_data.head() | ||||||
|  |     print(run_data.Gini) | ||||||
|  |  | ||||||
| @@ -68,12 +68,12 @@ network_agents: | |||||||
| - agent_type: HerdViewer | - agent_type: HerdViewer | ||||||
|   state: |   state: | ||||||
|     has_tv: true |     has_tv: true | ||||||
|     id: neutral |     state_id: neutral | ||||||
|   weight: 1 |   weight: 1 | ||||||
| - agent_type: HerdViewer | - agent_type: HerdViewer | ||||||
|   state: |   state: | ||||||
|     has_tv: true |     has_tv: true | ||||||
|     id: neutral |     state_id: neutral | ||||||
|   weight: 1 |   weight: 1 | ||||||
| network_params: | network_params: | ||||||
|   generator: barabasi_albert_graph |   generator: barabasi_albert_graph | ||||||
| @@ -95,7 +95,7 @@ network_agents: | |||||||
| - agent_type: HerdViewer | - agent_type: HerdViewer | ||||||
|   state: |   state: | ||||||
|     has_tv: true |     has_tv: true | ||||||
|     id: neutral |     state_id: neutral | ||||||
|   weight: 1 |   weight: 1 | ||||||
| - agent_type: WiseViewer | - agent_type: WiseViewer | ||||||
|   state: |   state: | ||||||
| @@ -121,7 +121,7 @@ network_agents: | |||||||
| - agent_type: WiseViewer | - agent_type: WiseViewer | ||||||
|   state: |   state: | ||||||
|     has_tv: true |     has_tv: true | ||||||
|     id: neutral |     state_id: neutral | ||||||
|   weight: 1 |   weight: 1 | ||||||
| - agent_type: WiseViewer | - agent_type: WiseViewer | ||||||
|   state: |   state: | ||||||
|   | |||||||
| @@ -34,8 +34,6 @@ class HerdViewer(DumbViewer): | |||||||
|     A viewer whose probability of infection depends on the state of its neighbors. |     A viewer whose probability of infection depends on the state of its neighbors. | ||||||
|     ''' |     ''' | ||||||
|  |  | ||||||
|     level = logging.DEBUG |  | ||||||
|  |  | ||||||
|     def infect(self): |     def infect(self): | ||||||
|         infected = self.count_neighboring_agents(state_id=self.infected.id) |         infected = self.count_neighboring_agents(state_id=self.infected.id) | ||||||
|         total = self.count_neighboring_agents() |         total = self.count_neighboring_agents() | ||||||
|   | |||||||
| @@ -1,7 +1,6 @@ | |||||||
| from soil.agents import FSM, state, default_state, BaseAgent | from soil.agents import FSM, state, default_state, BaseAgent, NetworkAgent | ||||||
| from enum import Enum | from enum import Enum | ||||||
| from random import random, choice | from random import random, choice | ||||||
| from itertools import islice |  | ||||||
| import logging | import logging | ||||||
| import math | import math | ||||||
|  |  | ||||||
| @@ -22,7 +21,7 @@ class RabbitModel(FSM): | |||||||
|         'offspring': 0, |         'offspring': 0, | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     sexual_maturity = 4*30 |     sexual_maturity = 3 #4*30 | ||||||
|     life_expectancy = 365 * 3 |     life_expectancy = 365 * 3 | ||||||
|     gestation = 33 |     gestation = 33 | ||||||
|     pregnancy = -1 |     pregnancy = -1 | ||||||
| @@ -31,9 +30,11 @@ class RabbitModel(FSM): | |||||||
|     @default_state |     @default_state | ||||||
|     @state |     @state | ||||||
|     def newborn(self): |     def newborn(self): | ||||||
|  |         self.debug(f'I am a newborn at age {self["age"]}') | ||||||
|         self['age'] += 1 |         self['age'] += 1 | ||||||
|  |  | ||||||
|         if self['age'] >= self.sexual_maturity: |         if self['age'] >= self.sexual_maturity: | ||||||
|  |             self.debug('I am fertile!') | ||||||
|             return self.fertile |             return self.fertile | ||||||
|  |  | ||||||
|     @state |     @state | ||||||
| @@ -46,8 +47,7 @@ class RabbitModel(FSM): | |||||||
|             return |             return | ||||||
|  |  | ||||||
|         # Males try to mate |         # Males try to mate | ||||||
|         females = self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False) |         for f in self.get_agents(state_id=self.fertile.id, gender=Genders.female.value, limit_neighbors=False, limit=self.max_females): | ||||||
|         for f in islice(females, self.max_females): |  | ||||||
|             r = random() |             r = random() | ||||||
|             if r < self['mating_prob']: |             if r < self['mating_prob']: | ||||||
|                 self.impregnate(f) |                 self.impregnate(f) | ||||||
| @@ -80,7 +80,7 @@ class RabbitModel(FSM): | |||||||
|                 self.env.add_edge(self['mate'], child.id) |                 self.env.add_edge(self['mate'], child.id) | ||||||
|                 # self.add_edge() |                 # self.add_edge() | ||||||
|                 self.debug('A BABY IS COMING TO LIFE') |                 self.debug('A BABY IS COMING TO LIFE') | ||||||
|                 self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.global_topology.number_of_nodes())+1 |                 self.env['rabbits_alive'] = self.env.get('rabbits_alive', self.topology.number_of_nodes())+1 | ||||||
|                 self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive'])) |                 self.debug('Rabbits alive: {}'.format(self.env['rabbits_alive'])) | ||||||
|                 self['offspring'] += 1 |                 self['offspring'] += 1 | ||||||
|                 self.env.get_agent(self['mate'])['offspring'] += 1 |                 self.env.get_agent(self['mate'])['offspring'] += 1 | ||||||
| @@ -97,12 +97,14 @@ class RabbitModel(FSM): | |||||||
|         return |         return | ||||||
|  |  | ||||||
|  |  | ||||||
| class RandomAccident(BaseAgent): | class RandomAccident(NetworkAgent): | ||||||
|  |  | ||||||
|     level = logging.DEBUG |     level = logging.DEBUG | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         rabbits_total = self.global_topology.number_of_nodes() |         rabbits_total = self.topology.number_of_nodes() | ||||||
|  |         if 'rabbits_alive' not in self.env: | ||||||
|  |             self.env['rabbits_alive'] = 0 | ||||||
|         rabbits_alive = self.env.get('rabbits_alive', rabbits_total) |         rabbits_alive = self.env.get('rabbits_alive', rabbits_total) | ||||||
|         prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) |         prob_death = self.env.get('prob_death', 1e-100)*math.floor(math.log10(max(1, rabbits_alive))) | ||||||
|         self.debug('Killing some rabbits with prob={}!'.format(prob_death)) |         self.debug('Killing some rabbits with prob={}!'.format(prob_death)) | ||||||
| @@ -116,5 +118,5 @@ class RandomAccident(BaseAgent): | |||||||
|                 self.log('Rabbits alive: {}'.format(self.env['rabbits_alive'])) |                 self.log('Rabbits alive: {}'.format(self.env['rabbits_alive'])) | ||||||
|                 i.set_state(i.dead) |                 i.set_state(i.dead) | ||||||
|         self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) |         self.log('Rabbits alive: {}/{}'.format(rabbits_alive, rabbits_total)) | ||||||
|         if self.count_agents(state_id=RabbitModel.dead.id) == self.global_topology.number_of_nodes(): |         if self.count_agents(state_id=RabbitModel.dead.id) == self.topology.number_of_nodes(): | ||||||
|             self.die() |             self.die() | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| --- | --- | ||||||
| load_module: rabbit_agents | load_module: rabbit_agents | ||||||
| name: rabbits_example | name: rabbits_example | ||||||
| max_time: 500 | max_time: 150 | ||||||
| interval: 1 | interval: 1 | ||||||
| seed: MySeed | seed: MySeed | ||||||
| agent_type: RabbitModel | agent_type: RabbitModel | ||||||
|   | |||||||
							
								
								
									
										45
									
								
								examples/random_delays/random_delays.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								examples/random_delays/random_delays.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | |||||||
|  | ''' | ||||||
|  | Example of setting a  | ||||||
|  | Example of a fully programmatic simulation, without definition files. | ||||||
|  | ''' | ||||||
|  | from soil import Simulation, agents | ||||||
|  | from soil.time import Delta | ||||||
|  | from random import expovariate | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MyAgent(agents.FSM): | ||||||
|  |     ''' | ||||||
|  |     An agent that first does a ping | ||||||
|  |     ''' | ||||||
|  |  | ||||||
|  |     defaults = {'pong_counts': 2} | ||||||
|  |  | ||||||
|  |     @agents.default_state | ||||||
|  |     @agents.state | ||||||
|  |     def ping(self): | ||||||
|  |         self.info('Ping') | ||||||
|  |         return self.pong, Delta(expovariate(1/16)) | ||||||
|  |  | ||||||
|  |     @agents.state | ||||||
|  |     def pong(self): | ||||||
|  |         self.info('Pong') | ||||||
|  |         self.pong_counts -= 1 | ||||||
|  |         self.info(str(self.pong_counts)) | ||||||
|  |         if self.pong_counts < 1: | ||||||
|  |             return self.die() | ||||||
|  |         return None, Delta(expovariate(1/16)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | s = Simulation(name='Programmatic', | ||||||
|  |                network_agents=[{'agent_type': MyAgent, 'id': 0}], | ||||||
|  |                topology={'nodes': [{'id': 0}], 'links': []}, | ||||||
|  |                num_trials=1, | ||||||
|  |                max_time=100, | ||||||
|  |                agent_type=MyAgent, | ||||||
|  |                dry_run=True) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | logging.basicConfig(level=logging.INFO) | ||||||
|  | envs = s.run() | ||||||
| @@ -1,13 +1,8 @@ | |||||||
| --- | --- | ||||||
| vars: | sampler: | ||||||
|   bounds: |   method: "SALib.sample.morris.sample" | ||||||
|     x1: [0, 1] |   N: 10 | ||||||
|     x2: [1, 2] | template: | ||||||
|   fixed: |  | ||||||
|     x3: ["a", "b", "c"] |  | ||||||
| sampler: "SALib.sample.morris.sample" |  | ||||||
| samples: 10 |  | ||||||
| template: | |  | ||||||
|   group: simple |   group: simple | ||||||
|   num_trials: 1 |   num_trials: 1 | ||||||
|   interval: 1 |   interval: 1 | ||||||
| @@ -19,11 +14,17 @@ template: | | |||||||
|     n: 10 |     n: 10 | ||||||
|   network_agents: |   network_agents: | ||||||
|     - agent_type: CounterModel |     - agent_type: CounterModel | ||||||
|       weight: {{ x1 }} |       weight: "{{ x1 }}" | ||||||
|       state: |       state: | ||||||
|         id: 0 |         state_id: 0 | ||||||
|     - agent_type: AggregatedCounter |     - agent_type: AggregatedCounter | ||||||
|       weight: {{ 1 - x1 }} |       weight: "{{ 1 - x1 }}" | ||||||
|   environment_params: |   environment_params: | ||||||
|     name: {{ x3 }} |     name: "{{ x3 }}" | ||||||
|   skip_test: true |   skip_test: true | ||||||
|  | vars: | ||||||
|  |   bounds: | ||||||
|  |     x1: [0, 1] | ||||||
|  |     x2: [1, 2] | ||||||
|  |   fixed: | ||||||
|  |     x3: ["a", "b", "c"] | ||||||
|   | |||||||
| @@ -18,12 +18,12 @@ class TerroristSpreadModel(FSM, Geo): | |||||||
|         prob_interaction |         prob_interaction | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, model=None, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=model, unique_id=unique_id, state=state) | ||||||
|  |  | ||||||
|         self.information_spread_intensity = environment.environment_params['information_spread_intensity'] |         self.information_spread_intensity = model.environment_params['information_spread_intensity'] | ||||||
|         self.terrorist_additional_influence = environment.environment_params['terrorist_additional_influence'] |         self.terrorist_additional_influence = model.environment_params['terrorist_additional_influence'] | ||||||
|         self.prob_interaction = environment.environment_params['prob_interaction'] |         self.prob_interaction = model.environment_params['prob_interaction'] | ||||||
|  |  | ||||||
|         if self['id'] == self.civilian.id:       # Civilian |         if self['id'] == self.civilian.id:       # Civilian | ||||||
|             self.mean_belief = random.uniform(0.00, 0.5) |             self.mean_belief = random.uniform(0.00, 0.5) | ||||||
| @@ -34,10 +34,10 @@ class TerroristSpreadModel(FSM, Geo): | |||||||
|         else: |         else: | ||||||
|             raise Exception('Invalid state id: {}'.format(self['id'])) |             raise Exception('Invalid state id: {}'.format(self['id'])) | ||||||
|  |  | ||||||
|         if 'min_vulnerability' in environment.environment_params: |         if 'min_vulnerability' in model.environment_params: | ||||||
|             self.vulnerability = random.uniform( environment.environment_params['min_vulnerability'], environment.environment_params['max_vulnerability'] ) |             self.vulnerability = random.uniform( model.environment_params['min_vulnerability'], model.environment_params['max_vulnerability'] ) | ||||||
|         else : |         else : | ||||||
|             self.vulnerability = random.uniform( 0, environment.environment_params['max_vulnerability'] ) |             self.vulnerability = random.uniform( 0, model.environment_params['max_vulnerability'] ) | ||||||
|  |  | ||||||
|  |  | ||||||
|     @state |     @state | ||||||
| @@ -93,11 +93,11 @@ class TrainingAreaModel(FSM, Geo): | |||||||
|     Requires TerroristSpreadModel. |     Requires TerroristSpreadModel. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, model=None, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=model, unique_id=unique_id, state=state) | ||||||
|         self.training_influence = environment.environment_params['training_influence'] |         self.training_influence = model.environment_params['training_influence'] | ||||||
|         if 'min_vulnerability' in environment.environment_params: |         if 'min_vulnerability' in model.environment_params: | ||||||
|             self.min_vulnerability = environment.environment_params['min_vulnerability'] |             self.min_vulnerability = model.environment_params['min_vulnerability'] | ||||||
|         else: self.min_vulnerability = 0 |         else: self.min_vulnerability = 0 | ||||||
|  |  | ||||||
|     @default_state |     @default_state | ||||||
| @@ -120,13 +120,13 @@ class HavenModel(FSM, Geo): | |||||||
|     Requires TerroristSpreadModel. |     Requires TerroristSpreadModel. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, model=None, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=model, unique_id=unique_id, state=state) | ||||||
|         self.haven_influence = environment.environment_params['haven_influence'] |         self.haven_influence = model.environment_params['haven_influence'] | ||||||
|         if 'min_vulnerability' in environment.environment_params: |         if 'min_vulnerability' in model.environment_params: | ||||||
|             self.min_vulnerability = environment.environment_params['min_vulnerability'] |             self.min_vulnerability = model.environment_params['min_vulnerability'] | ||||||
|         else: self.min_vulnerability = 0 |         else: self.min_vulnerability = 0 | ||||||
|         self.max_vulnerability = environment.environment_params['max_vulnerability'] |         self.max_vulnerability = model.environment_params['max_vulnerability'] | ||||||
|  |  | ||||||
|     def get_occupants(self, **kwargs): |     def get_occupants(self, **kwargs): | ||||||
|         return self.get_neighboring_agents(agent_type=TerroristSpreadModel, **kwargs) |         return self.get_neighboring_agents(agent_type=TerroristSpreadModel, **kwargs) | ||||||
| @@ -162,13 +162,13 @@ class TerroristNetworkModel(TerroristSpreadModel): | |||||||
|         weight_link_distance |         weight_link_distance | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, model=None, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=model, unique_id=unique_id, state=state) | ||||||
|  |  | ||||||
|         self.vision_range = environment.environment_params['vision_range'] |         self.vision_range = model.environment_params['vision_range'] | ||||||
|         self.sphere_influence = environment.environment_params['sphere_influence'] |         self.sphere_influence = model.environment_params['sphere_influence'] | ||||||
|         self.weight_social_distance = environment.environment_params['weight_social_distance'] |         self.weight_social_distance = model.environment_params['weight_social_distance'] | ||||||
|         self.weight_link_distance = environment.environment_params['weight_link_distance'] |         self.weight_link_distance = model.environment_params['weight_link_distance'] | ||||||
|  |  | ||||||
|     @state |     @state | ||||||
|     def terrorist(self): |     def terrorist(self): | ||||||
| @@ -195,14 +195,14 @@ class TerroristNetworkModel(TerroristSpreadModel): | |||||||
|                     break |                     break | ||||||
|  |  | ||||||
|     def get_distance(self, target): |     def get_distance(self, target): | ||||||
|         source_x, source_y = nx.get_node_attributes(self.global_topology, 'pos')[self.id] |         source_x, source_y = nx.get_node_attributes(self.topology, 'pos')[self.id] | ||||||
|         target_x, target_y = nx.get_node_attributes(self.global_topology, 'pos')[target] |         target_x, target_y = nx.get_node_attributes(self.topology, 'pos')[target] | ||||||
|         dx = abs( source_x - target_x ) |         dx = abs( source_x - target_x ) | ||||||
|         dy = abs( source_y - target_y ) |         dy = abs( source_y - target_y ) | ||||||
|         return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 ) |         return ( dx ** 2 + dy ** 2 ) ** ( 1 / 2 ) | ||||||
|  |  | ||||||
|     def shortest_path_length(self, target): |     def shortest_path_length(self, target): | ||||||
|         try: |         try: | ||||||
|             return nx.shortest_path_length(self.global_topology, self.id, target) |             return nx.shortest_path_length(self.topology, self.id, target) | ||||||
|         except nx.NetworkXNoPath: |         except nx.NetworkXNoPath: | ||||||
|             return float('inf') |             return float('inf') | ||||||
|   | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -1,10 +1,9 @@ | |||||||
| nxsim>=0.1.2 | networkx>=2.5 | ||||||
| simpy |  | ||||||
| networkx>=2.0,<2.4 |  | ||||||
| numpy | numpy | ||||||
| matplotlib | matplotlib | ||||||
| pyyaml>=5.1 | pyyaml>=5.1 | ||||||
| pandas>=0.23 | pandas>=0.23 | ||||||
| scipy==1.2.1 # scipy 1.3.0rc1 is not compatible with salib |  | ||||||
| SALib>=1.3 | SALib>=1.3 | ||||||
| Jinja2 | Jinja2 | ||||||
|  | Mesa>=0.8 | ||||||
|  | tsih>=0.1.5 | ||||||
|   | |||||||
							
								
								
									
										11
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								setup.py
									
									
									
									
									
								
							| @@ -16,6 +16,12 @@ def parse_requirements(filename): | |||||||
|  |  | ||||||
| install_reqs = parse_requirements("requirements.txt") | install_reqs = parse_requirements("requirements.txt") | ||||||
| test_reqs = parse_requirements("test-requirements.txt") | test_reqs = parse_requirements("test-requirements.txt") | ||||||
|  | extras_require={ | ||||||
|  |     'mesa': ['mesa>=0.8.9'], | ||||||
|  |     'geo': ['scipy>=1.3'], | ||||||
|  |     'web': ['tornado'] | ||||||
|  | } | ||||||
|  | extras_require['all'] = [dep for package in extras_require.values() for dep in package] | ||||||
|  |  | ||||||
|  |  | ||||||
| setup( | setup( | ||||||
| @@ -40,10 +46,7 @@ setup( | |||||||
|         'Operating System :: POSIX', |         'Operating System :: POSIX', | ||||||
|         'Programming Language :: Python :: 3'], |         'Programming Language :: Python :: 3'], | ||||||
|     install_requires=install_reqs, |     install_requires=install_reqs, | ||||||
|     extras_require={ |     extras_require=extras_require, | ||||||
|         'web': ['tornado'] |  | ||||||
|  |  | ||||||
|     }, |  | ||||||
|     tests_require=test_reqs, |     tests_require=test_reqs, | ||||||
|     setup_requires=['pytest-runner', ], |     setup_requires=['pytest-runner', ], | ||||||
|     include_package_data=True, |     include_package_data=True, | ||||||
|   | |||||||
| @@ -1 +1 @@ | |||||||
| 0.14.7 | 0.20.1 | ||||||
| @@ -11,25 +11,28 @@ try: | |||||||
| except NameError: | except NameError: | ||||||
|     basestring = str |     basestring = str | ||||||
|  |  | ||||||
|  | from .agents import * | ||||||
| from . import agents | from . import agents | ||||||
| from .simulation import * | from .simulation import * | ||||||
| from .environment import Environment | from .environment import Environment | ||||||
| from .history import History |  | ||||||
| from . import serialization | from . import serialization | ||||||
| from . import analysis | from . import analysis | ||||||
|  | from .utils import logger | ||||||
|  | from .time import * | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|     import argparse |     import argparse | ||||||
|     from . import simulation |     from . import simulation | ||||||
|  |  | ||||||
|     logging.basicConfig(level=logging.INFO) |     logger.info('Running SOIL version: {}'.format(__version__)) | ||||||
|     logging.info('Running SOIL version: {}'.format(__version__)) |  | ||||||
|  |  | ||||||
|     parser = argparse.ArgumentParser(description='Run a SOIL simulation') |     parser = argparse.ArgumentParser(description='Run a SOIL simulation') | ||||||
|     parser.add_argument('file', type=str, |     parser.add_argument('file', type=str, | ||||||
|                         nargs="?", |                         nargs="?", | ||||||
|                         default='simulation.yml', |                         default='simulation.yml', | ||||||
|                         help='python module containing the simulation configuration.') |                         help='Configuration file for the simulation (e.g., YAML or JSON)') | ||||||
|  |     parser.add_argument('--version', action='store_true', | ||||||
|  |                         help='Show version info and exit') | ||||||
|     parser.add_argument('--module', '-m', type=str, |     parser.add_argument('--module', '-m', type=str, | ||||||
|                         help='file containing the code of any custom agents.') |                         help='file containing the code of any custom agents.') | ||||||
|     parser.add_argument('--dry-run', '--dry', action='store_true', |     parser.add_argument('--dry-run', '--dry', action='store_true', | ||||||
| @@ -40,6 +43,8 @@ def main(): | |||||||
|                         help='Dump GEXF graph. Defaults to false.') |                         help='Dump GEXF graph. Defaults to false.') | ||||||
|     parser.add_argument('--csv', action='store_true', |     parser.add_argument('--csv', action='store_true', | ||||||
|                         help='Dump history in CSV format. Defaults to false.') |                         help='Dump history in CSV format. Defaults to false.') | ||||||
|  |     parser.add_argument('--level', type=str, | ||||||
|  |                         help='Logging level') | ||||||
|     parser.add_argument('--output', '-o', type=str, default="soil_output", |     parser.add_argument('--output', '-o', type=str, default="soil_output", | ||||||
|                         help='folder to write results to. It defaults to the current directory.') |                         help='folder to write results to. It defaults to the current directory.') | ||||||
|     parser.add_argument('--synchronous', action='store_true', |     parser.add_argument('--synchronous', action='store_true', | ||||||
| @@ -48,13 +53,17 @@ def main(): | |||||||
|                         help='Export environment and/or simulations using this exporter') |                         help='Export environment and/or simulations using this exporter') | ||||||
|  |  | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |     logging.basicConfig(level=getattr(logging, (args.level or 'INFO').upper())) | ||||||
|  |  | ||||||
|  |     if args.version: | ||||||
|  |         return | ||||||
|  |  | ||||||
|     if os.getcwd() not in sys.path: |     if os.getcwd() not in sys.path: | ||||||
|         sys.path.append(os.getcwd()) |         sys.path.append(os.getcwd()) | ||||||
|     if args.module: |     if args.module: | ||||||
|         importlib.import_module(args.module) |         importlib.import_module(args.module) | ||||||
|  |  | ||||||
|     logging.info('Loading config file: {}'.format(args.file)) |     logger.info('Loading config file: {}'.format(args.file)) | ||||||
|  |  | ||||||
|     try: |     try: | ||||||
|         exporters = list(args.exporter or ['default', ]) |         exporters = list(args.exporter or ['default', ]) | ||||||
| @@ -65,6 +74,10 @@ def main(): | |||||||
|         exp_params = {} |         exp_params = {} | ||||||
|         if args.dry_run: |         if args.dry_run: | ||||||
|             exp_params['copy_to'] = sys.stdout |             exp_params['copy_to'] = sys.stdout | ||||||
|  |  | ||||||
|  |         if not os.path.exists(args.file): | ||||||
|  |             logger.error('Please, input a valid file') | ||||||
|  |             return | ||||||
|         simulation.run_from_config(args.file, |         simulation.run_from_config(args.file, | ||||||
|                                    dry_run=args.dry_run, |                                    dry_run=args.dry_run, | ||||||
|                                    exporters=exporters, |                                    exporters=exporters, | ||||||
|   | |||||||
| @@ -1,40 +1,31 @@ | |||||||
| import random | import random | ||||||
| from . import BaseAgent | from . import FSM, state, default_state | ||||||
|  |  | ||||||
|  |  | ||||||
| class BassModel(BaseAgent): | class BassModel(FSM): | ||||||
|     """ |     """ | ||||||
|     Settings: |     Settings: | ||||||
|         innovation_prob |         innovation_prob | ||||||
|         imitation_prob |         imitation_prob | ||||||
|     """ |     """ | ||||||
|  |     sentimentCorrelation = 0 | ||||||
|     def __init__(self, environment, agent_id, state): |  | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |  | ||||||
|         env_params = environment.environment_params |  | ||||||
|         self.state['sentimentCorrelation'] = 0 |  | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         self.behaviour() |         self.behaviour() | ||||||
|  |  | ||||||
|     def behaviour(self): |     @default_state | ||||||
|         # Outside effects |     @state | ||||||
|         if random.random() < self.state_params['innovation_prob']: |     def innovation(self): | ||||||
|             if self.state['id'] == 0: |         if random.random() < self.innovation_prob: | ||||||
|                 self.state['id'] = 1 |             self.sentimentCorrelation = 1 | ||||||
|                 self.state['sentimentCorrelation'] = 1 |             return self.aware | ||||||
|             else: |         else: | ||||||
|                 pass |             aware_neighbors = self.get_neighboring_agents(state_id=self.aware.id) | ||||||
|  |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         # Imitation effects |  | ||||||
|         if self.state['id'] == 0: |  | ||||||
|             aware_neighbors = self.get_neighboring_agents(state_id=1) |  | ||||||
|             num_neighbors_aware = len(aware_neighbors) |             num_neighbors_aware = len(aware_neighbors) | ||||||
|             if random.random() < (self.state_params['imitation_prob']*num_neighbors_aware): |             if random.random() < (self['imitation_prob']*num_neighbors_aware): | ||||||
|                 self.state['id'] = 1 |                 self.sentimentCorrelation = 1 | ||||||
|                 self.state['sentimentCorrelation'] = 1 |                 return self.aware | ||||||
|  |  | ||||||
|             else: |     @state | ||||||
|                 pass |     def aware(self): | ||||||
|  |         self.die() | ||||||
|   | |||||||
| @@ -1,8 +1,8 @@ | |||||||
| import random | import random | ||||||
| from . import BaseAgent | from . import FSM, state, default_state | ||||||
|  |  | ||||||
|  |  | ||||||
| class BigMarketModel(BaseAgent): | class BigMarketModel(FSM): | ||||||
|     """ |     """ | ||||||
|     Settings: |     Settings: | ||||||
|         Names: |         Names: | ||||||
| @@ -19,34 +19,25 @@ class BigMarketModel(BaseAgent): | |||||||
|             sentiment_about [Array] |             sentiment_about [Array] | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(*args, **kwargs) | ||||||
|         self.enterprises = environment.environment_params['enterprises'] |         self.enterprises = self.env.environment_params['enterprises'] | ||||||
|         self.type = "" |         self.type = "" | ||||||
|         self.number_of_enterprises = len(environment.environment_params['enterprises']) |  | ||||||
|  |  | ||||||
|         if self.id < self.number_of_enterprises:  # Enterprises |         if self.id < len(self.enterprises):  # Enterprises | ||||||
|             self.state['id'] = self.id |             self.set_state(self.enterprise.id) | ||||||
|             self.type = "Enterprise" |             self.type = "Enterprise" | ||||||
|             self.tweet_probability = environment.environment_params['tweet_probability_enterprises'][self.id] |             self.tweet_probability = environment.environment_params['tweet_probability_enterprises'][self.id] | ||||||
|         else:  # normal users |         else:  # normal users | ||||||
|             self.state['id'] = self.number_of_enterprises |  | ||||||
|             self.type = "User" |             self.type = "User" | ||||||
|  |             self.set_state(self.user.id) | ||||||
|             self.tweet_probability = environment.environment_params['tweet_probability_users'] |             self.tweet_probability = environment.environment_params['tweet_probability_users'] | ||||||
|             self.tweet_relevant_probability = environment.environment_params['tweet_relevant_probability'] |             self.tweet_relevant_probability = environment.environment_params['tweet_relevant_probability'] | ||||||
|             self.tweet_probability_about = environment.environment_params['tweet_probability_about']  # List |             self.tweet_probability_about = environment.environment_params['tweet_probability_about']  # List | ||||||
|             self.sentiment_about = environment.environment_params['sentiment_about']  # List |             self.sentiment_about = environment.environment_params['sentiment_about']  # List | ||||||
|  |  | ||||||
|     def step(self): |     @state | ||||||
|  |     def enterprise(self): | ||||||
|         if self.id < self.number_of_enterprises:  # Enterprise |  | ||||||
|             self.enterpriseBehaviour() |  | ||||||
|         else:  # Usuario |  | ||||||
|             self.userBehaviour() |  | ||||||
|             for i in range(self.number_of_enterprises):  # So that it never is set to 0 if there are not changes (logs) |  | ||||||
|                 self.attrs['sentiment_enterprise_%s'% self.enterprises[i]] = self.sentiment_about[i] |  | ||||||
|  |  | ||||||
|     def enterpriseBehaviour(self): |  | ||||||
|  |  | ||||||
|         if random.random() < self.tweet_probability:  # Tweets |         if random.random() < self.tweet_probability:  # Tweets | ||||||
|             aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises)  # Nodes neighbour users |             aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises)  # Nodes neighbour users | ||||||
| @@ -64,12 +55,12 @@ class BigMarketModel(BaseAgent): | |||||||
|  |  | ||||||
|                 x.attrs['sentiment_enterprise_%s'% self.enterprises[self.id]] = x.sentiment_about[self.id] |                 x.attrs['sentiment_enterprise_%s'% self.enterprises[self.id]] = x.sentiment_about[self.id] | ||||||
|  |  | ||||||
|     def userBehaviour(self): |     @state | ||||||
|  |     def user(self): | ||||||
|         if random.random() < self.tweet_probability:  # Tweets |         if random.random() < self.tweet_probability:  # Tweets | ||||||
|             if random.random() < self.tweet_relevant_probability:  # Tweets something relevant |             if random.random() < self.tweet_relevant_probability:  # Tweets something relevant | ||||||
|                 # Tweet probability per enterprise |                 # Tweet probability per enterprise | ||||||
|                 for i in range(self.number_of_enterprises): |                 for i in range(len(self.enterprises)): | ||||||
|                     random_num = random.random() |                     random_num = random.random() | ||||||
|                     if random_num < self.tweet_probability_about[i]: |                     if random_num < self.tweet_probability_about[i]: | ||||||
|                         # The condition is fulfilled, sentiments are evaluated towards that enterprise |                         # The condition is fulfilled, sentiments are evaluated towards that enterprise | ||||||
| @@ -82,8 +73,10 @@ class BigMarketModel(BaseAgent): | |||||||
|                         else: |                         else: | ||||||
|                             # POSITIVO |                             # POSITIVO | ||||||
|                             self.userTweets("positive",i) |                             self.userTweets("positive",i) | ||||||
|  |         for i in range(len(self.enterprises)):  # So that it never is set to 0 if there are not changes (logs) | ||||||
|  |             self.attrs['sentiment_enterprise_%s'% self.enterprises[i]] = self.sentiment_about[i] | ||||||
|  |  | ||||||
|     def userTweets(self,sentiment,enterprise): |     def userTweets(self, sentiment,enterprise): | ||||||
|         aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises)  # Nodes neighbours users |         aware_neighbors = self.get_neighboring_agents(state_id=self.number_of_enterprises)  # Nodes neighbours users | ||||||
|         for x in aware_neighbors: |         for x in aware_neighbors: | ||||||
|             if sentiment == "positive": |             if sentiment == "positive": | ||||||
|   | |||||||
| @@ -1,7 +1,7 @@ | |||||||
| from . import BaseAgent | from . import NetworkAgent | ||||||
|  |  | ||||||
|  |  | ||||||
| class CounterModel(BaseAgent): | class CounterModel(NetworkAgent): | ||||||
|     """ |     """ | ||||||
|     Dummy behaviour. It counts the number of nodes in the network and neighbors |     Dummy behaviour. It counts the number of nodes in the network and neighbors | ||||||
|     in each step and adds it to its state. |     in each step and adds it to its state. | ||||||
| @@ -9,14 +9,14 @@ class CounterModel(BaseAgent): | |||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         # Outside effects |         # Outside effects | ||||||
|         total = len(list(self.get_all_agents())) |         total = len(list(self.get_agents())) | ||||||
|         neighbors = len(list(self.get_neighboring_agents())) |         neighbors = len(list(self.get_neighboring_agents())) | ||||||
|         self['times'] = self.get('times', 0) + 1 |         self['times'] = self.get('times', 0) + 1 | ||||||
|         self['neighbors'] = neighbors |         self['neighbors'] = neighbors | ||||||
|         self['total'] = total |         self['total'] = total | ||||||
|  |  | ||||||
|  |  | ||||||
| class AggregatedCounter(BaseAgent): | class AggregatedCounter(NetworkAgent): | ||||||
|     """ |     """ | ||||||
|     Dummy behaviour. It counts the number of nodes in the network and neighbors |     Dummy behaviour. It counts the number of nodes in the network and neighbors | ||||||
|     in each step and adds it to its state. |     in each step and adds it to its state. | ||||||
| @@ -33,6 +33,6 @@ class AggregatedCounter(BaseAgent): | |||||||
|         self['times'] += 1 |         self['times'] += 1 | ||||||
|         neighbors = len(list(self.get_neighboring_agents())) |         neighbors = len(list(self.get_neighboring_agents())) | ||||||
|         self['neighbors'] += neighbors |         self['neighbors'] += neighbors | ||||||
|         total = len(list(self.get_all_agents())) |         total = len(list(self.get_agents())) | ||||||
|         self['total'] += total |         self['total'] += total | ||||||
|         self.debug('Running for step: {}. Total: {}'.format(self.now, total)) |         self.debug('Running for step: {}. Total: {}'.format(self.now, total)) | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								soil/agents/Geo.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								soil/agents/Geo.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | from scipy.spatial import cKDTree as KDTree | ||||||
|  | import networkx as nx | ||||||
|  | from . import NetworkAgent, as_node | ||||||
|  |  | ||||||
|  | class Geo(NetworkAgent): | ||||||
|  |     '''In this type of network, nodes have a "pos" attribute.''' | ||||||
|  |  | ||||||
|  |     def geo_search(self, radius, node=None, center=False, **kwargs): | ||||||
|  |         '''Get a list of nodes whose coordinates are closer than *radius* to *node*.''' | ||||||
|  |         node = as_node(node if node is not None else self) | ||||||
|  |  | ||||||
|  |         G = self.subgraph(**kwargs) | ||||||
|  |  | ||||||
|  |         pos = nx.get_node_attributes(G, 'pos') | ||||||
|  |         if not pos: | ||||||
|  |             return [] | ||||||
|  |         nodes, coords = list(zip(*pos.items())) | ||||||
|  |         kdtree = KDTree(coords)  # Cannot provide generator. | ||||||
|  |         indices = kdtree.query_ball_point(pos[node], radius) | ||||||
|  |         return [nodes[i] for i in indices if center or (nodes[i] != node)] | ||||||
|  |  | ||||||
| @@ -10,10 +10,10 @@ class IndependentCascadeModel(BaseAgent): | |||||||
|         imitation_prob |         imitation_prob | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(*args, **kwargs) | ||||||
|         self.innovation_prob = environment.environment_params['innovation_prob'] |         self.innovation_prob = self.env.environment_params['innovation_prob'] | ||||||
|         self.imitation_prob = environment.environment_params['imitation_prob'] |         self.imitation_prob = self.env.environment_params['imitation_prob'] | ||||||
|         self.state['time_awareness'] = 0 |         self.state['time_awareness'] = 0 | ||||||
|         self.state['sentimentCorrelation'] = 0 |         self.state['sentimentCorrelation'] = 0 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -21,8 +21,8 @@ class SpreadModelM2(BaseAgent): | |||||||
|         prob_generate_anti_rumor |         prob_generate_anti_rumor | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, model=None, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=environment, unique_id=unique_id, state=state) | ||||||
|  |  | ||||||
|         self.prob_neutral_making_denier = np.random.normal(environment.environment_params['prob_neutral_making_denier'], |         self.prob_neutral_making_denier = np.random.normal(environment.environment_params['prob_neutral_making_denier'], | ||||||
|                                                            environment.environment_params['standard_variance']) |                                                            environment.environment_params['standard_variance']) | ||||||
| @@ -123,8 +123,8 @@ class ControlModelM2(BaseAgent): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|  |  | ||||||
|     def __init__(self, environment=None, agent_id=0, state=()): |     def __init__(self, model=None, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=environment, unique_id=unique_id, state=state) | ||||||
|  |  | ||||||
|         self.prob_neutral_making_denier = np.random.normal(environment.environment_params['prob_neutral_making_denier'], |         self.prob_neutral_making_denier = np.random.normal(environment.environment_params['prob_neutral_making_denier'], | ||||||
|                                                            environment.environment_params['standard_variance']) |                                                            environment.environment_params['standard_variance']) | ||||||
|   | |||||||
| @@ -29,8 +29,8 @@ class SISaModel(FSM): | |||||||
|         standard_variance |         standard_variance | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment, agent_id=0, state=()): |     def __init__(self, environment, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=environment, unique_id=unique_id, state=state) | ||||||
|  |  | ||||||
|         self.neutral_discontent_spon_prob = np.random.normal(self.env['neutral_discontent_spon_prob'], |         self.neutral_discontent_spon_prob = np.random.normal(self.env['neutral_discontent_spon_prob'], | ||||||
|                                                              self.env['standard_variance']) |                                                              self.env['standard_variance']) | ||||||
|   | |||||||
| @@ -16,8 +16,8 @@ class SentimentCorrelationModel(BaseAgent): | |||||||
|         disgust_prob |         disgust_prob | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def __init__(self, environment, agent_id=0, state=()): |     def __init__(self, environment, unique_id=0, state=()): | ||||||
|         super().__init__(environment=environment, agent_id=agent_id, state=state) |         super().__init__(model=environment, unique_id=unique_id, state=state) | ||||||
|         self.outside_effects_prob = environment.environment_params['outside_effects_prob'] |         self.outside_effects_prob = environment.environment_params['outside_effects_prob'] | ||||||
|         self.anger_prob = environment.environment_params['anger_prob'] |         self.anger_prob = environment.environment_params['anger_prob'] | ||||||
|         self.joy_prob = environment.environment_params['joy_prob'] |         self.joy_prob = environment.environment_params['joy_prob'] | ||||||
|   | |||||||
| @@ -1,21 +1,16 @@ | |||||||
| # networkStatus = {}  # Dict that will contain the status of every agent in the network |  | ||||||
| # sentimentCorrelationNodeArray = [] |  | ||||||
| # for x in range(0, settings.network_params["number_of_nodes"]): |  | ||||||
| #     sentimentCorrelationNodeArray.append({'id': x}) |  | ||||||
| # Initialize agent states. Let's assume everyone is normal. |  | ||||||
|      |  | ||||||
|  |  | ||||||
| import nxsim |  | ||||||
| import logging | import logging | ||||||
| from collections import OrderedDict | from collections import OrderedDict, defaultdict | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from functools import partial | from functools import partial, wraps | ||||||
| from scipy.spatial import cKDTree as KDTree | from itertools import islice | ||||||
| import json | import json | ||||||
|  | import networkx as nx | ||||||
|  |  | ||||||
| from functools import wraps | from .. import serialization, utils, time | ||||||
|  |  | ||||||
| from .. import serialization, history | from tsih import Key | ||||||
|  |  | ||||||
|  | from mesa import Agent | ||||||
|  |  | ||||||
|  |  | ||||||
| def as_node(agent): | def as_node(agent): | ||||||
| @@ -23,41 +18,54 @@ def as_node(agent): | |||||||
|         return agent.id |         return agent.id | ||||||
|     return agent |     return agent | ||||||
|  |  | ||||||
|  | IGNORED_FIELDS = ('model', 'logger') | ||||||
|  |  | ||||||
| class BaseAgent(nxsim.BaseAgent): |  | ||||||
|  | class DeadAgent(Exception): | ||||||
|  |     pass | ||||||
|  |  | ||||||
|  | class BaseAgent(Agent): | ||||||
|     """ |     """ | ||||||
|     A special simpy BaseAgent that keeps track of its state history. |     A special Agent that keeps track of its state history. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     defaults = {} |     defaults = {} | ||||||
|  |  | ||||||
|     def __init__(self, environment, agent_id, state=None, |     def __init__(self, | ||||||
|                  name=None, interval=None, **state_params): |                  unique_id, | ||||||
|  |                  model, | ||||||
|  |                  name=None, | ||||||
|  |                  interval=None): | ||||||
|         # Check for REQUIRED arguments |         # Check for REQUIRED arguments | ||||||
|         assert environment is not None, TypeError('__init__ missing 1 required keyword argument: \'environment\'. ' |  | ||||||
|                                                   'Cannot be NoneType.') |  | ||||||
|         # Initialize agent parameters |         # Initialize agent parameters | ||||||
|         self.id = agent_id |         if isinstance(unique_id, Agent): | ||||||
|         self.name = name or '{}[{}]'.format(type(self).__name__, self.id) |             raise Exception() | ||||||
|         self.state_params = state_params |         self._saved = set() | ||||||
|  |         super().__init__(unique_id=unique_id, model=model) | ||||||
|         # Register agent to environment |         self.name = name or '{}[{}]'.format(type(self).__name__, self.unique_id) | ||||||
|         self.env = environment |  | ||||||
|  |  | ||||||
|         self._neighbors = None |         self._neighbors = None | ||||||
|         self.alive = True |         self.alive = True | ||||||
|         real_state = deepcopy(self.defaults) |  | ||||||
|         real_state.update(state or {}) |  | ||||||
|         self.state = real_state |  | ||||||
|         self.interval = interval |  | ||||||
|  |  | ||||||
|         if not hasattr(self, 'level'): |         self.interval = interval or self.get('interval', 1) | ||||||
|             self.level = logging.DEBUG |         self.logger = logging.getLogger(self.model.name).getChild(self.name) | ||||||
|         self.logger = logging.getLogger(self.env.name) |  | ||||||
|         self.logger.setLevel(self.level) |  | ||||||
|  |  | ||||||
|         # initialize every time an instance of the agent is created |         if hasattr(self, 'level'): | ||||||
|         self.action = self.env.process(self.run()) |             self.logger.setLevel(self.level) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     # TODO: refactor to clean up mesa compatibility | ||||||
|  |     @property | ||||||
|  |     def id(self): | ||||||
|  |         return self.unique_id | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def env(self): | ||||||
|  |         return self.model | ||||||
|  |  | ||||||
|  |     @env.setter | ||||||
|  |     def env(self, model): | ||||||
|  |         self.model = model | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def state(self): |     def state(self): | ||||||
| @@ -71,44 +79,47 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|  |  | ||||||
|     @state.setter |     @state.setter | ||||||
|     def state(self, value): |     def state(self, value): | ||||||
|         self._state = {} |  | ||||||
|         for k, v in value.items(): |         for k, v in value.items(): | ||||||
|             self[k] = v |             self[k] = v | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def global_topology(self): |  | ||||||
|         return self.env.G |  | ||||||
|      |  | ||||||
|     @property |     @property | ||||||
|     def environment_params(self): |     def environment_params(self): | ||||||
|         return self.env.environment_params |         return self.model.environment_params | ||||||
|      |  | ||||||
|     @environment_params.setter |     @environment_params.setter | ||||||
|     def environment_params(self, value): |     def environment_params(self, value): | ||||||
|         self.env.environment_params = value |         self.model.environment_params = value | ||||||
|  |  | ||||||
|  |     def __setattr__(self, key, value): | ||||||
|  |         if not key.startswith('_') and key not in IGNORED_FIELDS: | ||||||
|  |             try: | ||||||
|  |                 k = Key(t_step=self.now, | ||||||
|  |                         dict_id=self.unique_id, | ||||||
|  |                         key=key) | ||||||
|  |                 self._saved.add(key) | ||||||
|  |                 self.model[k] = value | ||||||
|  |             except AttributeError: | ||||||
|  |                 pass | ||||||
|  |         super().__setattr__(key, value) | ||||||
|  |  | ||||||
|     def __getitem__(self, key): |     def __getitem__(self, key): | ||||||
|         if isinstance(key, tuple): |         if isinstance(key, tuple): | ||||||
|             key, t_step = key |             key, t_step = key | ||||||
|             k = history.Key(key=key, t_step=t_step, agent_id=self.id) |             k = Key(key=key, t_step=t_step, dict_id=self.unique_id) | ||||||
|             return self.env[k] |             return self.model[k] | ||||||
|         return self._state.get(key, None) |         return getattr(self, key) | ||||||
|  |  | ||||||
|     def __delitem__(self, key): |     def __delitem__(self, key): | ||||||
|         self._state[key] = None |         return delattr(self, key) | ||||||
|  |  | ||||||
|     def __contains__(self, key): |     def __contains__(self, key): | ||||||
|         return key in self._state |         return hasattr(self, key) | ||||||
|  |  | ||||||
|     def __setitem__(self, key, value): |     def __setitem__(self, key, value): | ||||||
|         self._state[key] = value |         setattr(self, key, value) | ||||||
|         k = history.Key(t_step=self.now, |  | ||||||
|                         agent_id=self.id, |  | ||||||
|                         key=key) |  | ||||||
|         self.env[k] = value |  | ||||||
|  |  | ||||||
|     def items(self): |     def items(self): | ||||||
|         return self._state.items() |         return ((k, getattr(self, k)) for k in self._saved) | ||||||
|  |  | ||||||
|     def get(self, key, default=None): |     def get(self, key, default=None): | ||||||
|         return self[key] if key in self else default |         return self[key] if key in self else default | ||||||
| @@ -116,54 +127,33 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|     @property |     @property | ||||||
|     def now(self): |     def now(self): | ||||||
|         try: |         try: | ||||||
|             return self.env.now |             return self.model.now | ||||||
|         except AttributeError: |         except AttributeError: | ||||||
|             # No environment |             # No environment | ||||||
|             return None |             return None | ||||||
|  |  | ||||||
|     def run(self): |  | ||||||
|         if self.interval is not None: |  | ||||||
|             interval = self.interval |  | ||||||
|         elif 'interval' in self: |  | ||||||
|             interval = self['interval'] |  | ||||||
|         else: |  | ||||||
|             interval = self.env.interval |  | ||||||
|         while self.alive: |  | ||||||
|             res = self.step() |  | ||||||
|             yield res or self.env.timeout(interval) |  | ||||||
|  |  | ||||||
|     def die(self, remove=False): |     def die(self, remove=False): | ||||||
|  |         self.info(f'agent {self.unique_id}  is dying') | ||||||
|         self.alive = False |         self.alive = False | ||||||
|         if remove: |         if remove: | ||||||
|             super().die() |             self.remove_node(self.id) | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         pass |         if not self.alive: | ||||||
|  |             raise DeadAgent(self.unique_id) | ||||||
|     def count_agents(self, **kwargs): |         return super().step() or time.Delta(self.interval) | ||||||
|         return len(list(self.get_agents(**kwargs))) |  | ||||||
|  |  | ||||||
|     def count_neighboring_agents(self, state_id=None, **kwargs): |  | ||||||
|         return len(super().get_neighboring_agents(state_id=state_id, **kwargs)) |  | ||||||
|  |  | ||||||
|     def get_neighboring_agents(self, state_id=None, **kwargs): |  | ||||||
|         return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs) |  | ||||||
|  |  | ||||||
|     def get_agents(self, agents=None, limit_neighbors=False, **kwargs): |  | ||||||
|         if limit_neighbors: |  | ||||||
|             agents = super().get_agents(limit_neighbors=limit_neighbors) |  | ||||||
|         else: |  | ||||||
|             agents = self.env.get_agents(agents) |  | ||||||
|         return select(agents, **kwargs) |  | ||||||
|  |  | ||||||
|     def log(self, message, *args, level=logging.INFO, **kwargs): |     def log(self, message, *args, level=logging.INFO, **kwargs): | ||||||
|  |         if not self.logger.isEnabledFor(level): | ||||||
|  |             return | ||||||
|         message = message + " ".join(str(i) for i in args) |         message = message + " ".join(str(i) for i in args) | ||||||
|         message = "\t{:10}@{:>5}:\t{}".format(self.name, self.now, message) |         message = " @{:>3}: {}".format(self.now, message) | ||||||
|         for k, v in kwargs: |         for k, v in kwargs: | ||||||
|             message += " {k}={v} ".format(k, v) |             message += " {k}={v} ".format(k, v) | ||||||
|         extra = {} |         extra = {} | ||||||
|         extra['now'] = self.now |         extra['now'] = self.now | ||||||
|         extra['id'] = self.id |         extra['unique_id'] = self.unique_id | ||||||
|  |         extra['agent_name'] = self.name | ||||||
|         return self.logger.log(level, message, extra=extra) |         return self.logger.log(level, message, extra=extra) | ||||||
|  |  | ||||||
|     def debug(self, *args, **kwargs): |     def debug(self, *args, **kwargs): | ||||||
| @@ -172,44 +162,55 @@ class BaseAgent(nxsim.BaseAgent): | |||||||
|     def info(self, *args, **kwargs): |     def info(self, *args, **kwargs): | ||||||
|         return self.log(*args, level=logging.INFO, **kwargs) |         return self.log(*args, level=logging.INFO, **kwargs) | ||||||
|  |  | ||||||
|     def __getstate__(self): |  | ||||||
|         ''' |  | ||||||
|         Serializing an agent will lose all its running information (you cannot |  | ||||||
|         serialize an iterator), but it keeps the state and link to the environment, |  | ||||||
|         so it can be used for inspection and dumping to a file |  | ||||||
|         ''' |  | ||||||
|         state = {} |  | ||||||
|         state['id'] = self.id |  | ||||||
|         state['environment'] = self.env |  | ||||||
|         state['_state'] = self._state |  | ||||||
|         return state |  | ||||||
|  |  | ||||||
|     def __setstate__(self, state): |  | ||||||
|         ''' |  | ||||||
|         Get back a serialized agent and try to re-compose it |  | ||||||
|         ''' |  | ||||||
|         self.id = state['id'] |  | ||||||
|         self._state = state['_state'] |  | ||||||
|         self.env = state['environment'] |  | ||||||
|  |  | ||||||
|     def add_edge(self, node1, node2, **attrs): |  | ||||||
|         node1 = as_node(node1) |  | ||||||
|         node2 = as_node(node2) |  | ||||||
|  |  | ||||||
|         for n in [node1, node2]: |  | ||||||
|             if n not in self.global_topology.nodes(data=False): |  | ||||||
|                 raise ValueError('"{}" not in the graph'.format(n)) |  | ||||||
|         return self.global_topology.add_edge(node1, node2, **attrs) |  | ||||||
|  |  | ||||||
|     def subgraph(self, center=True, **kwargs): |  | ||||||
|         include = [self] if center else [] |  | ||||||
|         return self.global_topology.subgraph(n.id for n in self.get_agents(**kwargs)+include) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class NetworkAgent(BaseAgent): | class NetworkAgent(BaseAgent): | ||||||
|  |  | ||||||
|     def add_edge(self, other, **kwargs): |     @property | ||||||
|         return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs) |     def topology(self): | ||||||
|  |         return self.model.G | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def G(self): | ||||||
|  |         return self.model.G | ||||||
|  |  | ||||||
|  |     def count_agents(self, **kwargs): | ||||||
|  |         return len(list(self.get_agents(**kwargs))) | ||||||
|  |  | ||||||
|  |     def count_neighboring_agents(self, state_id=None, **kwargs): | ||||||
|  |         return len(self.get_neighboring_agents(state_id=state_id, **kwargs)) | ||||||
|  |  | ||||||
|  |     def get_neighboring_agents(self, state_id=None, **kwargs): | ||||||
|  |         return self.get_agents(limit_neighbors=True, state_id=state_id, **kwargs) | ||||||
|  |  | ||||||
|  |     def get_agents(self, *args, limit=None, **kwargs): | ||||||
|  |         it = self.iter_agents(*args, **kwargs) | ||||||
|  |         if limit is not None: | ||||||
|  |             it = islice(it, limit) | ||||||
|  |         return list(it) | ||||||
|  |  | ||||||
|  |     def iter_agents(self, agents=None, limit_neighbors=False, **kwargs): | ||||||
|  |         if limit_neighbors: | ||||||
|  |             agents = self.topology.neighbors(self.unique_id) | ||||||
|  |  | ||||||
|  |         agents = self.model.get_agents(agents) | ||||||
|  |         return select(agents, **kwargs) | ||||||
|  |  | ||||||
|  |     def subgraph(self, center=True, **kwargs): | ||||||
|  |         include = [self] if center else [] | ||||||
|  |         return self.topology.subgraph(n.unique_id for n in list(self.get_agents(**kwargs))+include) | ||||||
|  |  | ||||||
|  |     def remove_node(self, unique_id): | ||||||
|  |         self.topology.remove_node(unique_id) | ||||||
|  |  | ||||||
|  |     def add_edge(self, other, edge_attr_dict=None, *edge_attrs): | ||||||
|  |         # return super(NetworkAgent, self).add_edge(node1=self.id, node2=other, **kwargs) | ||||||
|  |         if self.unique_id not in self.topology.nodes(data=False): | ||||||
|  |             raise ValueError('{} not in list of existing agents in the network'.format(self.unique_id)) | ||||||
|  |         if other.unique_id not in self.topology.nodes(data=False): | ||||||
|  |             raise ValueError('{} not in list of existing agents in the network'.format(other)) | ||||||
|  |  | ||||||
|  |         self.topology.add_edge(self.unique_id, other.unique_id, edge_attr_dict=edge_attr_dict, *edge_attrs) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def ego_search(self, steps=1, center=False, node=None, **kwargs): |     def ego_search(self, steps=1, center=False, node=None, **kwargs): | ||||||
|         '''Get a list of nodes in the ego network of *node* of radius *steps*''' |         '''Get a list of nodes in the ego network of *node* of radius *steps*''' | ||||||
| @@ -219,17 +220,17 @@ class NetworkAgent(BaseAgent): | |||||||
|  |  | ||||||
|     def degree(self, node, force=False): |     def degree(self, node, force=False): | ||||||
|         node = as_node(node) |         node = as_node(node) | ||||||
|         if force or (not hasattr(self.env, '_degree')) or getattr(self.env, '_last_step', 0) < self.now: |         if force or (not hasattr(self.model, '_degree')) or getattr(self.model, '_last_step', 0) < self.now: | ||||||
|             self.env._degree = nx.degree_centrality(self.global_topology) |             self.model._degree = nx.degree_centrality(self.topology) | ||||||
|             self.env._last_step = self.now |             self.model._last_step = self.now | ||||||
|         return self.env._degree[node] |         return self.model._degree[node] | ||||||
|  |  | ||||||
|     def betweenness(self, node, force=False): |     def betweenness(self, node, force=False): | ||||||
|         node = as_node(node) |         node = as_node(node) | ||||||
|         if force or (not hasattr(self.env, '_betweenness')) or getattr(self.env, '_last_step', 0) < self.now: |         if force or (not hasattr(self.model, '_betweenness')) or getattr(self.model, '_last_step', 0) < self.now: | ||||||
|             self.env._betweenness = nx.betweenness_centrality(self.global_topology) |             self.model._betweenness = nx.betweenness_centrality(self.topology) | ||||||
|             self.env._last_step = self.now |             self.model._last_step = self.now | ||||||
|         return self.env._betweenness[node] |         return self.model._betweenness[node] | ||||||
|  |  | ||||||
|  |  | ||||||
| def state(name=None): | def state(name=None): | ||||||
| @@ -292,31 +293,37 @@ class MetaFSM(type): | |||||||
|         cls.states = states |         cls.states = states | ||||||
|  |  | ||||||
|  |  | ||||||
| class FSM(BaseAgent, metaclass=MetaFSM): | class FSM(NetworkAgent, metaclass=MetaFSM): | ||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super(FSM, self).__init__(*args, **kwargs) |         super(FSM, self).__init__(*args, **kwargs) | ||||||
|         if 'id' not in self.state: |         if not hasattr(self, 'state_id'): | ||||||
|             if not self.default_state: |             if not self.default_state: | ||||||
|                 raise ValueError('No default state specified for {}'.format(self.id)) |                 raise ValueError('No default state specified for {}'.format(self.unique_id)) | ||||||
|             self['id'] = self.default_state.id |             self.state_id = self.default_state.id | ||||||
|  |  | ||||||
|  |         self.set_state(self.state_id) | ||||||
|  |  | ||||||
|     def step(self): |     def step(self): | ||||||
|         if 'id' in self.state: |         self.debug(f'Agent {self.unique_id} @ state {self.state_id}') | ||||||
|             next_state = self['id'] |         try: | ||||||
|         elif self.default_state: |             interval = super().step() | ||||||
|             next_state = self.default_state.id |         except DeadAgent: | ||||||
|         else: |             return time.When('inf') | ||||||
|             raise Exception('{} has no valid state id or default state'.format(self)) |         if 'id' not in self.state: | ||||||
|         if next_state not in self.states: |             # if 'id' in self.state: | ||||||
|             raise Exception('{} is not a valid id for {}'.format(next_state, self)) |             #     self.set_state(self.state['id']) | ||||||
|         return self.states[next_state](self) |             if self.default_state: | ||||||
|  |                 self.set_state(self.default_state.id) | ||||||
|  |             else: | ||||||
|  |                 raise Exception('{} has no valid state id or default state'.format(self)) | ||||||
|  |         return self.states[self.state_id](self) or interval | ||||||
|  |  | ||||||
|     def set_state(self, state): |     def set_state(self, state): | ||||||
|         if hasattr(state, 'id'): |         if hasattr(state, 'id'): | ||||||
|             state = state.id |             state = state.id | ||||||
|         if state not in self.states: |         if state not in self.states: | ||||||
|             raise ValueError('{} is not a valid state'.format(state)) |             raise ValueError('{} is not a valid state'.format(state)) | ||||||
|         self['id'] = state |         self.state_id = state | ||||||
|         return state |         return state | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -335,9 +342,6 @@ def prob(prob=1): | |||||||
|     return r < prob |     return r < prob | ||||||
|  |  | ||||||
|  |  | ||||||
| STATIC_THRESHOLD = (-1, -1) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def calculate_distribution(network_agents=None, | def calculate_distribution(network_agents=None, | ||||||
|                            agent_type=None): |                            agent_type=None): | ||||||
|     ''' |     ''' | ||||||
| @@ -365,20 +369,23 @@ def calculate_distribution(network_agents=None, | |||||||
|     'agent_type_1'. |     'agent_type_1'. | ||||||
|     ''' |     ''' | ||||||
|     if network_agents: |     if network_agents: | ||||||
|         network_agents = deepcopy(network_agents) |         network_agents = [deepcopy(agent) for agent in network_agents if not hasattr(agent, 'id')] | ||||||
|     elif agent_type: |     elif agent_type: | ||||||
|         network_agents = [{'agent_type': agent_type}] |         network_agents = [{'agent_type': agent_type}] | ||||||
|     else: |     else: | ||||||
|         raise ValueError('Specify a distribution or a default agent type') |         raise ValueError('Specify a distribution or a default agent type') | ||||||
|  |  | ||||||
|  |     # Fix missing weights and incompatible types | ||||||
|  |     for x in network_agents: | ||||||
|  |         x['weight'] = float(x.get('weight', 1)) | ||||||
|  |  | ||||||
|     # Calculate the thresholds |     # Calculate the thresholds | ||||||
|     total = sum(x.get('weight', 1) for x in network_agents) |     total = sum(x['weight'] for x in network_agents) | ||||||
|     acc = 0 |     acc = 0 | ||||||
|     for v in network_agents: |     for v in network_agents: | ||||||
|         if 'ids' in v: |         if 'ids' in v: | ||||||
|             v['threshold'] = STATIC_THRESHOLD |  | ||||||
|             continue |             continue | ||||||
|         upper = acc + (v.get('weight', 1)/total) |         upper = acc + (v['weight']/total) | ||||||
|         v['threshold'] = [acc, upper] |         v['threshold'] = [acc, upper] | ||||||
|         acc = upper |         acc = upper | ||||||
|     return network_agents |     return network_agents | ||||||
| @@ -391,7 +398,7 @@ def serialize_type(agent_type, known_modules=[], **kwargs): | |||||||
|     return serialization.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class |     return serialization.serialize(agent_type, known_modules=known_modules, **kwargs)[1] # Get the name of the class | ||||||
|  |  | ||||||
|  |  | ||||||
| def serialize_distribution(network_agents, known_modules=[]): | def serialize_definition(network_agents, known_modules=[]): | ||||||
|     ''' |     ''' | ||||||
|     When serializing an agent distribution, remove the thresholds, in order |     When serializing an agent distribution, remove the thresholds, in order | ||||||
|     to avoid cluttering the YAML definition file. |     to avoid cluttering the YAML definition file. | ||||||
| @@ -413,7 +420,7 @@ def deserialize_type(agent_type, known_modules=[]): | |||||||
|     return agent_type |     return agent_type | ||||||
|  |  | ||||||
|  |  | ||||||
| def deserialize_distribution(ind, **kwargs): | def deserialize_definition(ind, **kwargs): | ||||||
|     d = deepcopy(ind) |     d = deepcopy(ind) | ||||||
|     for v in d: |     for v in d: | ||||||
|         v['agent_type'] = deserialize_type(v['agent_type'], **kwargs) |         v['agent_type'] = deserialize_type(v['agent_type'], **kwargs) | ||||||
| @@ -425,7 +432,7 @@ def _validate_states(states, topology): | |||||||
|     states = states or [] |     states = states or [] | ||||||
|     if isinstance(states, dict): |     if isinstance(states, dict): | ||||||
|         for x in states: |         for x in states: | ||||||
|             assert x in topology.node |             assert x in topology.nodes | ||||||
|     else: |     else: | ||||||
|         assert len(states) <= len(topology) |         assert len(states) <= len(topology) | ||||||
|     return states |     return states | ||||||
| @@ -434,44 +441,84 @@ def _validate_states(states, topology): | |||||||
| def _convert_agent_types(ind, to_string=False, **kwargs): | def _convert_agent_types(ind, to_string=False, **kwargs): | ||||||
|     '''Convenience method to allow specifying agents by class or class name.''' |     '''Convenience method to allow specifying agents by class or class name.''' | ||||||
|     if to_string: |     if to_string: | ||||||
|         return serialize_distribution(ind, **kwargs) |         return serialize_definition(ind, **kwargs) | ||||||
|     return deserialize_distribution(ind, **kwargs) |     return deserialize_definition(ind, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| def _agent_from_distribution(distribution, value=-1, agent_id=None): | def _agent_from_definition(definition, value=-1, unique_id=None): | ||||||
|     """Used in the initialization of agents given an agent distribution.""" |     """Used in the initialization of agents given an agent distribution.""" | ||||||
|     if value < 0: |     if value < 0: | ||||||
|         value = random.random() |         value = random.random() | ||||||
|     for d in sorted(distribution, key=lambda x: x['threshold']): |     for d in sorted(definition, key=lambda x: x.get('threshold')): | ||||||
|         threshold = d['threshold'] |         threshold = d.get('threshold', (-1, -1)) | ||||||
|         # Check if the definition matches by id (first) or by threshold |         # Check if the definition matches by id (first) or by threshold | ||||||
|         if not ((agent_id is not None and threshold == STATIC_THRESHOLD and agent_id in d['ids']) or \ |         if (unique_id is not None and unique_id in d.get('ids', [])) or \ | ||||||
|                 (value >= threshold[0] and value < threshold[1])): |            (value >= threshold[0] and value < threshold[1]): | ||||||
|             continue |             state = {} | ||||||
|         state = {} |             if 'state' in d: | ||||||
|         if 'state' in d: |                 state = deepcopy(d['state']) | ||||||
|             state = deepcopy(d['state']) |             return d['agent_type'], state | ||||||
|         return d['agent_type'], state |  | ||||||
|  |  | ||||||
|     raise Exception('Distribution for value {} not found in: {}'.format(value, distribution)) |     raise Exception('Definition for value {} not found in: {}'.format(value, definition)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Geo(NetworkAgent): | def _definition_to_dict(definition, size=None, default_state=None): | ||||||
|     '''In this type of network, nodes have a "pos" attribute.''' |     state = default_state or {} | ||||||
|  |     agents = {} | ||||||
|  |     remaining = {} | ||||||
|  |     if size: | ||||||
|  |         for ix in range(size): | ||||||
|  |             remaining[ix] = copy(state) | ||||||
|  |     else: | ||||||
|  |         remaining = defaultdict(lambda x: copy(state)) | ||||||
|  |  | ||||||
|     def geo_search(self, radius, node=None, center=False, **kwargs): |     distro = sorted([item for item in definition if 'weight' in item]) | ||||||
|         '''Get a list of nodes whose coordinates are closer than *radius* to *node*.''' |  | ||||||
|         node = as_node(node if node is not None else self) |  | ||||||
|  |  | ||||||
|         G = self.subgraph(**kwargs) |     ix = 0 | ||||||
|  |     def init_agent(item, id=ix): | ||||||
|  |         while id in agents: | ||||||
|  |             id += 1 | ||||||
|  |  | ||||||
|         pos = nx.get_node_attributes(G, 'pos') |         agent = remaining[id] | ||||||
|         if not pos: |         agent['state'].update(copy(item.get('state', {}))) | ||||||
|             return [] |         agents[id] = agent | ||||||
|         nodes, coords = list(zip(*pos.items())) |         del remaining[id] | ||||||
|         kdtree = KDTree(coords)  # Cannot provide generator. |         return agent | ||||||
|         indices = kdtree.query_ball_point(pos[node], radius) |  | ||||||
|         return [nodes[i] for i in indices if center or (nodes[i] != node)] |     for item in definition: | ||||||
|  |         if 'ids' in item: | ||||||
|  |             ids = item['ids'] | ||||||
|  |             del item['ids'] | ||||||
|  |             for id in ids: | ||||||
|  |                 agent = init_agent(item, id) | ||||||
|  |  | ||||||
|  |     for item in definition: | ||||||
|  |         if 'number' in item: | ||||||
|  |             times = item['number'] | ||||||
|  |             del item['number'] | ||||||
|  |             for times in range(times): | ||||||
|  |                 if size: | ||||||
|  |                     ix = random.choice(remaining.keys()) | ||||||
|  |                     agent = init_agent(item, id) | ||||||
|  |                 else: | ||||||
|  |                     agent = init_agent(item) | ||||||
|  |     if not size: | ||||||
|  |         return agents | ||||||
|  |  | ||||||
|  |     if len(remaining) < 0: | ||||||
|  |         raise Exception('Invalid definition. Too many agents to add') | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     total_weight = float(sum(s['weight'] for s in distro)) | ||||||
|  |     unit = size / total_weight | ||||||
|  |  | ||||||
|  |     for item in distro: | ||||||
|  |         times = unit * item['weight'] | ||||||
|  |         del item['weight'] | ||||||
|  |         for times in range(times): | ||||||
|  |             ix = random.choice(remaining.keys()) | ||||||
|  |             agent = init_agent(item, id) | ||||||
|  |     return agents | ||||||
|  |  | ||||||
|  |  | ||||||
| def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs): | def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, **kwargs): | ||||||
| @@ -484,25 +531,22 @@ def select(agents, state_id=None, agent_type=None, ignore=None, iterator=False, | |||||||
|         except TypeError: |         except TypeError: | ||||||
|             agent_type = tuple([agent_type]) |             agent_type = tuple([agent_type]) | ||||||
|  |  | ||||||
|     def matches_all(agent): |     f = agents | ||||||
|         if state_id is not None: |  | ||||||
|             if agent.state.get('id', None) not in state_id: |  | ||||||
|                 return False |  | ||||||
|         if agent_type is not None: |  | ||||||
|             if not isinstance(agent, agent_type): |  | ||||||
|                 return False |  | ||||||
|         state = agent.state |  | ||||||
|         for k, v in kwargs.items(): |  | ||||||
|             if state.get(k, None) != v: |  | ||||||
|                 return False |  | ||||||
|         return True |  | ||||||
|  |  | ||||||
|     f = filter(matches_all, agents) |  | ||||||
|     if ignore: |     if ignore: | ||||||
|         f = filter(lambda x: x not in ignore, f) |         f = filter(lambda x: x not in ignore, f) | ||||||
|  |  | ||||||
|  |     if state_id is not None: | ||||||
|  |         f = filter(lambda agent: agent.get('state_id', None) in state_id, f) | ||||||
|  |  | ||||||
|  |     if agent_type is not None: | ||||||
|  |         f = filter(lambda agent: isinstance(agent, agent_type), f) | ||||||
|  |     for k, v in kwargs.items(): | ||||||
|  |         f = filter(lambda agent: agent.state.get(k, None) == v, f) | ||||||
|  |  | ||||||
|     if iterator: |     if iterator: | ||||||
|         return f |         return f | ||||||
|     return list(f) |     return f | ||||||
|  |  | ||||||
|  |  | ||||||
| from .BassModel import * | from .BassModel import * | ||||||
| @@ -512,3 +556,10 @@ from .ModelM2 import * | |||||||
| from .SentimentCorrelationModel import * | from .SentimentCorrelationModel import * | ||||||
| from .SISaModel import * | from .SISaModel import * | ||||||
| from .CounterModel import * | from .CounterModel import * | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     import scipy | ||||||
|  |     from .Geo import Geo | ||||||
|  | except ImportError: | ||||||
|  |     import sys | ||||||
|  |     print('Could not load the Geo Agent, scipy is not installed', file=sys.stderr) | ||||||
|   | |||||||
| @@ -4,7 +4,8 @@ import glob | |||||||
| import yaml | import yaml | ||||||
| from os.path import join | from os.path import join | ||||||
|  |  | ||||||
| from . import serialization, history | from . import serialization | ||||||
|  | from tsih import History | ||||||
|  |  | ||||||
|  |  | ||||||
| def read_data(*args, group=False, **kwargs): | def read_data(*args, group=False, **kwargs): | ||||||
| @@ -28,13 +29,13 @@ def _read_data(pattern, *args, from_csv=False, process_args=None, **kwargs): | |||||||
|                 df = read_csv(trial_data, **kwargs) |                 df = read_csv(trial_data, **kwargs) | ||||||
|                 yield config_file, df, config |                 yield config_file, df, config | ||||||
|         else: |         else: | ||||||
|             for trial_data in sorted(glob.glob(join(folder, '*.db.sqlite'))): |             for trial_data in sorted(glob.glob(join(folder, '*.sqlite'))): | ||||||
|                 df = read_sql(trial_data, **kwargs) |                 df = read_sql(trial_data, **kwargs) | ||||||
|                 yield config_file, df, config |                 yield config_file, df, config | ||||||
|  |  | ||||||
|  |  | ||||||
| def read_sql(db, *args, **kwargs): | def read_sql(db, *args, **kwargs): | ||||||
|     h = history.History(db_path=db, backup=False) |     h = History(db_path=db, backup=False, readonly=True) | ||||||
|     df = h.read_sql(*args, **kwargs) |     df = h.read_sql(*args, **kwargs) | ||||||
|     return df |     return df | ||||||
|  |  | ||||||
| @@ -61,7 +62,12 @@ def convert_row(row): | |||||||
|  |  | ||||||
|  |  | ||||||
| def convert_types_slow(df): | def convert_types_slow(df): | ||||||
|     '''This is a slow operation.''' |     ''' | ||||||
|  |     Go over every column in a dataframe and convert it to the type determined by the `get_types` | ||||||
|  |     function. | ||||||
|  |  | ||||||
|  |     This is a slow operation. | ||||||
|  |     ''' | ||||||
|     dtypes = get_types(df) |     dtypes = get_types(df) | ||||||
|     for k, v in dtypes.items(): |     for k, v in dtypes.items(): | ||||||
|         t = df[df['key']==k] |         t = df[df['key']==k] | ||||||
| @@ -69,6 +75,13 @@ def convert_types_slow(df): | |||||||
|     df = df.apply(convert_row, axis=1) |     df = df.apply(convert_row, axis=1) | ||||||
|     return df |     return df | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def split_processed(df): | ||||||
|  |     env = df.loc[:, df.columns.get_level_values(1).isin(['env', 'stats'])] | ||||||
|  |     agents = df.loc[:, ~df.columns.get_level_values(1).isin(['env', 'stats'])] | ||||||
|  |     return env, agents | ||||||
|  |  | ||||||
|  |  | ||||||
| def split_df(df): | def split_df(df): | ||||||
|     ''' |     ''' | ||||||
|     Split a dataframe in two dataframes: one with the history of agents, |     Split a dataframe in two dataframes: one with the history of agents, | ||||||
| @@ -95,6 +108,9 @@ def process(df, **kwargs): | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_types(df): | def get_types(df): | ||||||
|  |     ''' | ||||||
|  |     Get the value type for every key stored in a raw history dataframe. | ||||||
|  |     ''' | ||||||
|     dtypes = df.groupby(by=['key'])['value_type'].unique() |     dtypes = df.groupby(by=['key'])['value_type'].unique() | ||||||
|     return {k:v[0] for k,v in dtypes.iteritems()} |     return {k:v[0] for k,v in dtypes.iteritems()} | ||||||
|  |  | ||||||
| @@ -119,8 +135,14 @@ def process_one(df, *keys, columns=['key', 'agent_id'], values='value', | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_count(df, *keys): | def get_count(df, *keys): | ||||||
|  |     ''' | ||||||
|  |     For every t_step and key, get the value count. | ||||||
|  |  | ||||||
|  |     The result is a dataframe with `t_step` as index, an a multiindex column based on `key` and the values found for each `key`. | ||||||
|  |     ''' | ||||||
|     if keys: |     if keys: | ||||||
|         df = df[list(keys)] |         df = df[list(keys)] | ||||||
|  |         df.columns = df.columns.remove_unused_levels() | ||||||
|     counts = pd.DataFrame() |     counts = pd.DataFrame() | ||||||
|     for key in df.columns.levels[0]: |     for key in df.columns.levels[0]: | ||||||
|         g = df[[key]].apply(pd.Series.value_counts, axis=1).fillna(0) |         g = df[[key]].apply(pd.Series.value_counts, axis=1).fillna(0) | ||||||
| @@ -130,13 +152,28 @@ def get_count(df, *keys): | |||||||
|     return counts |     return counts | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_majority(df, *keys): | ||||||
|  |     ''' | ||||||
|  |     For every t_step and key, get the value of the majority of agents | ||||||
|  |  | ||||||
|  |     The result is a dataframe with `t_step` as index, and columns based on `key`. | ||||||
|  |     ''' | ||||||
|  |     df = get_count(df, *keys) | ||||||
|  |     return df.stack(level=0).idxmax(axis=1).unstack() | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_value(df, *keys, aggfunc='sum'): | def get_value(df, *keys, aggfunc='sum'): | ||||||
|  |     ''' | ||||||
|  |     For every t_step and key, get the value of *numeric columns*, aggregated using a specific function. | ||||||
|  |     ''' | ||||||
|     if keys: |     if keys: | ||||||
|         df = df[list(keys)] |         df = df[list(keys)] | ||||||
|     return df.groupby(axis=1, level=0).agg(aggfunc) |         df.columns = df.columns.remove_unused_levels() | ||||||
|  |     df = df.select_dtypes('number') | ||||||
|  |     return df.groupby(level='key', axis=1).agg(aggfunc) | ||||||
|  |  | ||||||
|  |  | ||||||
| def plot_all(*args, **kwargs): | def plot_all(*args, plot_args={}, **kwargs): | ||||||
|     ''' |     ''' | ||||||
|     Read all the trial data and plot the result of applying a function on them. |     Read all the trial data and plot the result of applying a function on them. | ||||||
|     ''' |     ''' | ||||||
| @@ -144,14 +181,17 @@ def plot_all(*args, **kwargs): | |||||||
|     ps = [] |     ps = [] | ||||||
|     for line in dfs: |     for line in dfs: | ||||||
|         f, df, config = line |         f, df, config = line | ||||||
|         df.plot(title=config['name']) |         if len(df) < 1: | ||||||
|  |             continue | ||||||
|  |         df.plot(title=config['name'], **plot_args) | ||||||
|         ps.append(df) |         ps.append(df) | ||||||
|     return ps |     return ps | ||||||
|  |  | ||||||
| def do_all(pattern, func, *keys, include_env=False, **kwargs): | def do_all(pattern, func, *keys, include_env=False, **kwargs): | ||||||
|     for config_file, df, config in read_data(pattern, keys=keys): |     for config_file, df, config in read_data(pattern, keys=keys): | ||||||
|  |         if len(df) < 1: | ||||||
|  |             continue | ||||||
|         p = func(df, *keys, **kwargs) |         p = func(df, *keys, **kwargs) | ||||||
|         p.plot(title=config['name']) |  | ||||||
|         yield config_file, p, config |         yield config_file, p, config | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										26
									
								
								soil/datacollection.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								soil/datacollection.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | from mesa import DataCollector as MDC | ||||||
|  |  | ||||||
|  | class SoilDataCollector(MDC): | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def __init__(self, environment, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         # Populate model and env reporters so they have a key per  | ||||||
|  |         # So they can be shown in the web interface | ||||||
|  |         self.environment = environment | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def model_vars(self): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     @model_vars.setter | ||||||
|  |     def model_vars(self, value): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def agent_reporters(self): | ||||||
|  |         self.model._history._ | ||||||
|  |  | ||||||
|  |         pass | ||||||
|  |  | ||||||
| @@ -1,29 +1,31 @@ | |||||||
| import os | import os | ||||||
| import sqlite3 | import sqlite3 | ||||||
| import time |  | ||||||
| import csv | import csv | ||||||
|  | import math | ||||||
| import random | import random | ||||||
| import simpy |  | ||||||
| import yaml | import yaml | ||||||
| import tempfile | import tempfile | ||||||
| import pandas as pd | import pandas as pd | ||||||
|  | from time import time as current_time | ||||||
| from copy import deepcopy | from copy import deepcopy | ||||||
| from collections import Counter |  | ||||||
| from networkx.readwrite import json_graph | from networkx.readwrite import json_graph | ||||||
|  |  | ||||||
| import networkx as nx | import networkx as nx | ||||||
| import nxsim |  | ||||||
|  |  | ||||||
| from . import serialization, agents, analysis, history, utils | from tsih import History, Record, Key, NoHistory | ||||||
|  |  | ||||||
|  | from mesa import Model | ||||||
|  |  | ||||||
|  | from . import serialization, agents, analysis, utils, time | ||||||
|  |  | ||||||
| # These properties will be copied when pickling/unpickling the environment | # These properties will be copied when pickling/unpickling the environment | ||||||
| _CONFIG_PROPS = [ 'name', | _CONFIG_PROPS = [ 'name', | ||||||
|                  'states', |                   'states', | ||||||
|                  'default_state', |                   'default_state', | ||||||
|                  'interval', |                   'interval', | ||||||
|                  ] |                  ] | ||||||
|  |  | ||||||
| class Environment(nxsim.NetworkEnvironment): | class Environment(Model): | ||||||
|     """ |     """ | ||||||
|     The environment is key in a simulation. It contains the network topology, |     The environment is key in a simulation. It contains the network topology, | ||||||
|     a reference to network and environment agents, as well as the environment |     a reference to network and environment agents, as well as the environment | ||||||
| @@ -40,27 +42,70 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|                  states=None, |                  states=None, | ||||||
|                  default_state=None, |                  default_state=None, | ||||||
|                  interval=1, |                  interval=1, | ||||||
|  |                  network_params=None, | ||||||
|                  seed=None, |                  seed=None, | ||||||
|                  topology=None, |                  topology=None, | ||||||
|                  *args, **kwargs): |                  schedule=None, | ||||||
|  |                  initial_time=0, | ||||||
|  |                  environment_params=None, | ||||||
|  |                  history=True, | ||||||
|  |                  dir_path=None, | ||||||
|  |                  **kwargs): | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         super().__init__() | ||||||
|  |  | ||||||
|  |         self.schedule = schedule | ||||||
|  |         if schedule is None: | ||||||
|  |             self.schedule = time.TimedActivation() | ||||||
|  |  | ||||||
|         self.name = name or 'UnnamedEnvironment' |         self.name = name or 'UnnamedEnvironment' | ||||||
|  |         seed = seed or current_time() | ||||||
|  |         random.seed(seed) | ||||||
|         if isinstance(states, list): |         if isinstance(states, list): | ||||||
|             states = dict(enumerate(states)) |             states = dict(enumerate(states)) | ||||||
|         self.states = deepcopy(states) if states else {} |         self.states = deepcopy(states) if states else {} | ||||||
|         self.default_state = deepcopy(default_state) or {} |         self.default_state = deepcopy(default_state) or {} | ||||||
|  |  | ||||||
|  |         if topology is None: | ||||||
|  |             network_params = network_params or {} | ||||||
|  |             topology = serialization.load_network(network_params, | ||||||
|  |                                                   dir_path=dir_path) | ||||||
|         if not topology: |         if not topology: | ||||||
|             topology = nx.Graph() |             topology = nx.Graph() | ||||||
|         super().__init__(*args, topology=topology, **kwargs) |         self.G = nx.Graph(topology)  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         self.environment_params = environment_params or {} | ||||||
|  |         self.environment_params.update(kwargs) | ||||||
|  |  | ||||||
|         self._env_agents = {} |         self._env_agents = {} | ||||||
|         self.interval = interval |         self.interval = interval | ||||||
|         self._history = history.History(name=self.name, |         if history: | ||||||
|                                         backup=True) |             history = History | ||||||
|         # Add environment agents first, so their events get |         else: | ||||||
|         # executed before network agents |             history = NoHistory | ||||||
|         self.environment_agents = environment_agents or [] |         self._history = history(name=self.name, | ||||||
|         self.network_agents = network_agents or [] |                                 backup=True) | ||||||
|         self['SEED'] = seed or time.time() |         self['SEED'] = seed | ||||||
|         random.seed(self['SEED']) |  | ||||||
|  |         if network_agents: | ||||||
|  |             distro = agents.calculate_distribution(network_agents) | ||||||
|  |             self.network_agents = agents._convert_agent_types(distro) | ||||||
|  |         else: | ||||||
|  |             self.network_agents = [] | ||||||
|  |  | ||||||
|  |         environment_agents = environment_agents or [] | ||||||
|  |         if environment_agents: | ||||||
|  |             distro = agents.calculate_distribution(environment_agents) | ||||||
|  |             environment_agents = agents._convert_agent_types(distro) | ||||||
|  |         self.environment_agents = environment_agents | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def now(self): | ||||||
|  |         if self.schedule: | ||||||
|  |             return self.schedule.time | ||||||
|  |         raise Exception('The environment has not been scheduled, so it has no sense of time') | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def agents(self): |     def agents(self): | ||||||
| @@ -74,15 +119,9 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|  |  | ||||||
|     @environment_agents.setter |     @environment_agents.setter | ||||||
|     def environment_agents(self, environment_agents): |     def environment_agents(self, environment_agents): | ||||||
|         # Set up environmental agent |         self._environment_agents = environment_agents | ||||||
|         self._env_agents = {} |  | ||||||
|         for item in environment_agents: |         self._env_agents = agents._definition_to_dict(definition=environment_agents) | ||||||
|             kwargs = deepcopy(item) |  | ||||||
|             atype = kwargs.pop('agent_type') |  | ||||||
|             kwargs['agent_id'] = kwargs.get('agent_id', atype.__name__) |  | ||||||
|             kwargs['state'] = kwargs.get('state', {}) |  | ||||||
|             a = atype(environment=self, **kwargs) |  | ||||||
|             self._env_agents[a.id] = a |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def network_agents(self): |     def network_agents(self): | ||||||
| @@ -95,9 +134,9 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|     def network_agents(self, network_agents): |     def network_agents(self, network_agents): | ||||||
|         self._network_agents = network_agents |         self._network_agents = network_agents | ||||||
|         for ix in self.G.nodes(): |         for ix in self.G.nodes(): | ||||||
|             self.init_agent(ix, agent_distribution=network_agents) |             self.init_agent(ix, agent_definitions=network_agents) | ||||||
|  |  | ||||||
|     def init_agent(self, agent_id, agent_distribution): |     def init_agent(self, agent_id, agent_definitions): | ||||||
|         node = self.G.nodes[agent_id] |         node = self.G.nodes[agent_id] | ||||||
|         init = False |         init = False | ||||||
|         state = dict(node) |         state = dict(node) | ||||||
| @@ -112,8 +151,8 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|  |  | ||||||
|         if agent_type: |         if agent_type: | ||||||
|             agent_type = agents.deserialize_type(agent_type) |             agent_type = agents.deserialize_type(agent_type) | ||||||
|         elif agent_distribution: |         elif agent_definitions: | ||||||
|             agent_type, state = agents._agent_from_distribution(agent_distribution, agent_id=agent_id) |             agent_type, state = agents._agent_from_definition(agent_definitions, unique_id=agent_id) | ||||||
|         else: |         else: | ||||||
|             serialization.logger.debug('Skipping node {}'.format(agent_id)) |             serialization.logger.debug('Skipping node {}'.format(agent_id)) | ||||||
|             return |             return | ||||||
| @@ -129,10 +168,18 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         a = None |         a = None | ||||||
|         if agent_type: |         if agent_type: | ||||||
|             state = defstate |             state = defstate | ||||||
|             a = agent_type(environment=self, |             a = agent_type(model=self, | ||||||
|                            agent_id=agent_id, |                            unique_id=agent_id) | ||||||
|                            state=state) |  | ||||||
|  |         for (k, v) in getattr(a, 'defaults', {}).items(): | ||||||
|  |             if not hasattr(a, k) or getattr(a, k) is None: | ||||||
|  |                 setattr(a, k, v) | ||||||
|  |  | ||||||
|  |         for (k, v) in state.items(): | ||||||
|  |             setattr(a, k, v) | ||||||
|  |  | ||||||
|         node['agent'] = a |         node['agent'] = a | ||||||
|  |         self.schedule.add(a) | ||||||
|         return a |         return a | ||||||
|  |  | ||||||
|     def add_node(self, agent_type, state=None): |     def add_node(self, agent_type, state=None): | ||||||
| @@ -150,34 +197,23 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         start = start or self.now |         start = start or self.now | ||||||
|         return self.G.add_edge(agent1, agent2, **attrs) |         return self.G.add_edge(agent1, agent2, **attrs) | ||||||
|  |  | ||||||
|     def run(self, *args, **kwargs): |     def step(self): | ||||||
|  |         super().step() | ||||||
|  |         self.datacollector.collect(self) | ||||||
|  |         self.schedule.step() | ||||||
|  |  | ||||||
|  |     def run(self, until, *args, **kwargs): | ||||||
|         self._save_state() |         self._save_state() | ||||||
|         self.log_stats() |  | ||||||
|         super().run(*args, **kwargs) |         while self.schedule.next_time <= until and not math.isinf(self.schedule.next_time): | ||||||
|  |             self.schedule.step(until=until) | ||||||
|  |             utils.logger.debug(f'Simulation step {self.schedule.time}/{until}. Next: {self.schedule.next_time}') | ||||||
|         self._history.flush_cache() |         self._history.flush_cache() | ||||||
|         self.log_stats() |  | ||||||
|  |  | ||||||
|     def _save_state(self, now=None): |     def _save_state(self, now=None): | ||||||
|         serialization.logger.debug('Saving state @{}'.format(self.now)) |         serialization.logger.debug('Saving state @{}'.format(self.now)) | ||||||
|         self._history.save_records(self.state_to_tuples(now=now)) |         self._history.save_records(self.state_to_tuples(now=now)) | ||||||
|  |  | ||||||
|     def save_state(self): |  | ||||||
|         ''' |  | ||||||
|         :DEPRECATED: |  | ||||||
|         Periodically save the state of the environment and the agents. |  | ||||||
|         ''' |  | ||||||
|         self._save_state() |  | ||||||
|         while self.peek() != simpy.core.Infinity: |  | ||||||
|             delay = max(self.peek() - self.now, self.interval) |  | ||||||
|             serialization.logger.debug('Step: {}'.format(self.now)) |  | ||||||
|             ev = self.event() |  | ||||||
|             ev._ok = True |  | ||||||
|             # Schedule the event with minimum priority so |  | ||||||
|             # that it executes before all agents |  | ||||||
|             self.schedule(ev, -999, delay) |  | ||||||
|             yield ev |  | ||||||
|             self._save_state() |  | ||||||
|  |  | ||||||
|     def __getitem__(self, key): |     def __getitem__(self, key): | ||||||
|         if isinstance(key, tuple): |         if isinstance(key, tuple): | ||||||
|             self._history.flush_cache() |             self._history.flush_cache() | ||||||
| @@ -187,12 +223,12 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|  |  | ||||||
|     def __setitem__(self, key, value): |     def __setitem__(self, key, value): | ||||||
|         if isinstance(key, tuple): |         if isinstance(key, tuple): | ||||||
|             k = history.Key(*key) |             k = Key(*key) | ||||||
|             self._history.save_record(*k, |             self._history.save_record(*k, | ||||||
|                                       value=value) |                                       value=value) | ||||||
|             return |             return | ||||||
|         self.environment_params[key] = value |         self.environment_params[key] = value | ||||||
|         self._history.save_record(agent_id='env', |         self._history.save_record(dict_id='env', | ||||||
|                                   t_step=self.now, |                                   t_step=self.now, | ||||||
|                                   key=key, |                                   key=key, | ||||||
|                                   value=value) |                                   value=value) | ||||||
| @@ -216,8 +252,8 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|  |  | ||||||
|     def get_agents(self, nodes=None): |     def get_agents(self, nodes=None): | ||||||
|         if nodes is None: |         if nodes is None: | ||||||
|             return list(self.agents) |             return self.agents | ||||||
|         return [self.G.nodes[i]['agent'] for i in nodes] |         return (self.G.nodes[i]['agent'] for i in nodes) | ||||||
|  |  | ||||||
|     def dump_csv(self, f): |     def dump_csv(self, f): | ||||||
|         with utils.open_or_reuse(f, 'w') as f: |         with utils.open_or_reuse(f, 'w') as f: | ||||||
| @@ -257,16 +293,16 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         if now is None: |         if now is None: | ||||||
|             now = self.now |             now = self.now | ||||||
|         for k, v in self.environment_params.items(): |         for k, v in self.environment_params.items(): | ||||||
|             yield history.Record(agent_id='env', |             yield Record(dict_id='env', | ||||||
|                                  t_step=now, |                          t_step=now, | ||||||
|                                  key=k, |                          key=k, | ||||||
|                                  value=v) |                          value=v) | ||||||
|         for agent in self.agents: |         for agent in self.agents: | ||||||
|             for k, v in agent.state.items(): |             for k, v in agent.state.items(): | ||||||
|                 yield history.Record(agent_id=agent.id, |                 yield Record(dict_id=agent.id, | ||||||
|                                      t_step=now, |                              t_step=now, | ||||||
|                                      key=k, |                              key=k, | ||||||
|                                      value=v) |                              value=v) | ||||||
|  |  | ||||||
|     def history_to_tuples(self): |     def history_to_tuples(self): | ||||||
|         return self._history.to_tuples() |         return self._history.to_tuples() | ||||||
| @@ -317,25 +353,6 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|  |  | ||||||
|         return G |         return G | ||||||
|  |  | ||||||
|     def stats(self): |  | ||||||
|         stats = {} |  | ||||||
|         stats['network'] = {} |  | ||||||
|         stats['network']['n_nodes'] = self.G.number_of_nodes() |  | ||||||
|         stats['network']['n_edges'] = self.G.number_of_edges() |  | ||||||
|         c = Counter() |  | ||||||
|         c.update(a.__class__.__name__ for a in self.network_agents) |  | ||||||
|         stats['agents'] = {} |  | ||||||
|         stats['agents']['model_count'] = dict(c) |  | ||||||
|         c2 = Counter() |  | ||||||
|         c2.update(a['id'] for a in self.network_agents) |  | ||||||
|         stats['agents']['state_count'] = dict(c2) |  | ||||||
|         stats['params'] = self.environment_params |  | ||||||
|         return stats |  | ||||||
|  |  | ||||||
|     def log_stats(self): |  | ||||||
|         stats = self.stats() |  | ||||||
|         serialization.logger.info('Environment stats: \n{}'.format(yaml.dump(stats, default_flow_style=False))) |  | ||||||
|      |  | ||||||
|     def __getstate__(self): |     def __getstate__(self): | ||||||
|         state = {} |         state = {} | ||||||
|         for prop in _CONFIG_PROPS: |         for prop in _CONFIG_PROPS: | ||||||
| @@ -343,6 +360,7 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         state['G'] = json_graph.node_link_data(self.G) |         state['G'] = json_graph.node_link_data(self.G) | ||||||
|         state['environment_agents'] = self._env_agents |         state['environment_agents'] = self._env_agents | ||||||
|         state['history'] = self._history |         state['history'] = self._history | ||||||
|  |         state['schedule'] = self.schedule | ||||||
|         return state |         return state | ||||||
|  |  | ||||||
|     def __setstate__(self, state): |     def __setstate__(self, state): | ||||||
| @@ -351,6 +369,9 @@ class Environment(nxsim.NetworkEnvironment): | |||||||
|         self._env_agents = state['environment_agents'] |         self._env_agents = state['environment_agents'] | ||||||
|         self.G = json_graph.node_link_graph(state['G']) |         self.G = json_graph.node_link_graph(state['G']) | ||||||
|         self._history = state['history'] |         self._history = state['history'] | ||||||
|  |         # self._env = None | ||||||
|  |         self.schedule = state['schedule'] | ||||||
|  |         self._queue = [] | ||||||
|  |  | ||||||
|  |  | ||||||
| SoilEnvironment = Environment | SoilEnvironment = Environment | ||||||
|   | |||||||
| @@ -1,10 +1,11 @@ | |||||||
| import os | import os | ||||||
|  | import csv as csvlib | ||||||
| import time | import time | ||||||
| from io import BytesIO | from io import BytesIO | ||||||
|  |  | ||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
| import networkx as nx | import networkx as nx | ||||||
| import pandas as pd |  | ||||||
|  |  | ||||||
| from .serialization import deserialize | from .serialization import deserialize | ||||||
| from .utils import open_or_reuse, logger, timer | from .utils import open_or_reuse, logger, timer | ||||||
| @@ -13,15 +14,6 @@ from .utils import open_or_reuse, logger, timer | |||||||
| from . import utils | from . import utils | ||||||
|  |  | ||||||
|  |  | ||||||
| def for_sim(simulation, names, *args, **kwargs): |  | ||||||
|     '''Return the set of exporters for a simulation, given the exporter names''' |  | ||||||
|     exporters = [] |  | ||||||
|     for name in names: |  | ||||||
|         mod = deserialize(name, known_modules=['soil.exporters']) |  | ||||||
|         exporters.append(mod(simulation, *args, **kwargs)) |  | ||||||
|     return exporters |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class DryRunner(BytesIO): | class DryRunner(BytesIO): | ||||||
|     def __init__(self, fname, *args, copy_to=None, **kwargs): |     def __init__(self, fname, *args, copy_to=None, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
| @@ -37,8 +29,12 @@ class DryRunner(BytesIO): | |||||||
|             super().write(bytes(txt, 'utf-8')) |             super().write(bytes(txt, 'utf-8')) | ||||||
|  |  | ||||||
|     def close(self): |     def close(self): | ||||||
|         logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname, |         content = '(binary data not shown)' | ||||||
|                                                                        self.getvalue().decode())) |         try: | ||||||
|  |             content = self.getvalue().decode() | ||||||
|  |         except UnicodeDecodeError: | ||||||
|  |             pass | ||||||
|  |         logger.info('**Not** written to {} (dry run mode):\n\n{}\n\n'.format(self.__fname, content)) | ||||||
|         super().close() |         super().close() | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -49,7 +45,7 @@ class Exporter: | |||||||
|     ''' |     ''' | ||||||
|  |  | ||||||
|     def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None): |     def __init__(self, simulation, outdir=None, dry_run=None, copy_to=None): | ||||||
|         self.sim = simulation |         self.simulation = simulation | ||||||
|         outdir = outdir or os.path.join(os.getcwd(), 'soil_output') |         outdir = outdir or os.path.join(os.getcwd(), 'soil_output') | ||||||
|         self.outdir = os.path.join(outdir, |         self.outdir = os.path.join(outdir, | ||||||
|                                    simulation.group or '', |                                    simulation.group or '', | ||||||
| @@ -59,12 +55,15 @@ class Exporter: | |||||||
|  |  | ||||||
|     def start(self): |     def start(self): | ||||||
|         '''Method to call when the simulation starts''' |         '''Method to call when the simulation starts''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     def end(self): |     def end(self, stats): | ||||||
|         '''Method to call when the simulation ends''' |         '''Method to call when the simulation ends''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         '''Method to call when a trial ends''' |         '''Method to call when a trial ends''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     def output(self, f, mode='w', **kwargs): |     def output(self, f, mode='w', **kwargs): | ||||||
|         if self.dry_run: |         if self.dry_run: | ||||||
| @@ -84,35 +83,47 @@ class default(Exporter): | |||||||
|     def start(self): |     def start(self): | ||||||
|         if not self.dry_run: |         if not self.dry_run: | ||||||
|             logger.info('Dumping results to %s', self.outdir) |             logger.info('Dumping results to %s', self.outdir) | ||||||
|             self.sim.dump_yaml(outdir=self.outdir) |             self.simulation.dump_yaml(outdir=self.outdir) | ||||||
|         else: |         else: | ||||||
|             logger.info('NOT dumping results') |             logger.info('NOT dumping results') | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         if not self.dry_run: |         if not self.dry_run: | ||||||
|             with timer('Dumping simulation {} trial {}'.format(self.sim.name, |             with timer('Dumping simulation {} trial {}'.format(self.simulation.name, | ||||||
|                                                                env.name)): |                                                                env.name)): | ||||||
|                 with self.output('{}.sqlite'.format(env.name), mode='wb') as f: |                 with self.output('{}.sqlite'.format(env.name), mode='wb') as f: | ||||||
|                     env.dump_sqlite(f) |                     env.dump_sqlite(f) | ||||||
|  |  | ||||||
|  |     def end(self, stats): | ||||||
|  |           with timer('Dumping simulation {}\'s stats'.format(self.simulation.name)): | ||||||
|  |               with self.output('{}.sqlite'.format(self.simulation.name), mode='wb') as f: | ||||||
|  |                   self.simulation.dump_sqlite(f) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class csv(Exporter): | class csv(Exporter): | ||||||
|     '''Export the state of each environment (and its agents) in a separate CSV file''' |     '''Export the state of each environment (and its agents) in a separate CSV file''' | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.sim.name, |         with timer('[CSV] Dumping simulation {} trial {} @ dir {}'.format(self.simulation.name, | ||||||
|                                                                           env.name, |                                                                           env.name, | ||||||
|                                                                           self.outdir)): |                                                                           self.outdir)): | ||||||
|             with self.output('{}.csv'.format(env.name)) as f: |             with self.output('{}.csv'.format(env.name)) as f: | ||||||
|                 env.dump_csv(f) |                 env.dump_csv(f) | ||||||
|  |  | ||||||
|  |             with self.output('{}.stats.csv'.format(env.name)) as f: | ||||||
|  |                 statwriter = csvlib.writer(f, delimiter='\t', quotechar='"', quoting=csvlib.QUOTE_ALL) | ||||||
|  |  | ||||||
|  |                 for stat in stats: | ||||||
|  |                     statwriter.writerow(stat) | ||||||
|  |  | ||||||
|  |  | ||||||
| class gexf(Exporter): | class gexf(Exporter): | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         if self.dry_run: |         if self.dry_run: | ||||||
|             logger.info('Not dumping GEXF in dry_run mode') |             logger.info('Not dumping GEXF in dry_run mode') | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         with timer('[GEXF] Dumping simulation {} trial {}'.format(self.sim.name, |         with timer('[GEXF] Dumping simulation {} trial {}'.format(self.simulation.name, | ||||||
|                                                                   env.name)): |                                                                   env.name)): | ||||||
|             with self.output('{}.gexf'.format(env.name), mode='wb') as f: |             with self.output('{}.gexf'.format(env.name), mode='wb') as f: | ||||||
|                 env.dump_gexf(f) |                 env.dump_gexf(f) | ||||||
| @@ -124,56 +135,24 @@ class dummy(Exporter): | |||||||
|         with self.output('dummy', 'w') as f: |         with self.output('dummy', 'w') as f: | ||||||
|             f.write('simulation started @ {}\n'.format(time.time())) |             f.write('simulation started @ {}\n'.format(time.time())) | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         with self.output('dummy', 'w') as f: |         with self.output('dummy', 'w') as f: | ||||||
|             for i in env.history_to_tuples(): |             for i in env.history_to_tuples(): | ||||||
|                 f.write(','.join(map(str, i))) |                 f.write(','.join(map(str, i))) | ||||||
|                 f.write('\n') |                 f.write('\n') | ||||||
|  |  | ||||||
|     def end(self): |     def sim(self, stats): | ||||||
|         with self.output('dummy', 'a') as f: |         with self.output('dummy', 'a') as f: | ||||||
|             f.write('simulation ended @ {}\n'.format(time.time())) |             f.write('simulation ended @ {}\n'.format(time.time())) | ||||||
|  |  | ||||||
|  |  | ||||||
| class distribution(Exporter): |  | ||||||
|     ''' |  | ||||||
|     Write the distribution of agent states at the end of each trial, |  | ||||||
|     the mean value, and its deviation. |  | ||||||
|     ''' |  | ||||||
|  |  | ||||||
|     def start(self): |  | ||||||
|         self.means = [] |  | ||||||
|         self.counts = [] |  | ||||||
|  |  | ||||||
|     def trial_end(self, env): |  | ||||||
|         df = env[None, None, None].df() |  | ||||||
|         ix = df.index[-1] |  | ||||||
|         attrs = df.columns.levels[0] |  | ||||||
|         vc = {} |  | ||||||
|         stats = {} |  | ||||||
|         for a in attrs: |  | ||||||
|             t = df.loc[(ix, a)] |  | ||||||
|             try: |  | ||||||
|                 self.means.append(('mean', a, t.mean())) |  | ||||||
|             except TypeError: |  | ||||||
|                 for name, count in t.value_counts().iteritems(): |  | ||||||
|                     self.counts.append(('count', a, name, count)) |  | ||||||
|  |  | ||||||
|     def end(self): |  | ||||||
|         dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value']) |  | ||||||
|         dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count']) |  | ||||||
|         dfm = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) |  | ||||||
|         dfc = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) |  | ||||||
|         with self.output('counts.csv') as f: |  | ||||||
|             dfc.to_csv(f) |  | ||||||
|         with self.output('metrics.csv') as f: |  | ||||||
|             dfm.to_csv(f) |  | ||||||
|  |  | ||||||
| class graphdrawing(Exporter): | class graphdrawing(Exporter): | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         # Outside effects |         # Outside effects | ||||||
|         f = plt.figure() |         f = plt.figure() | ||||||
|         nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111)) |         nx.draw(env.G, node_size=10, width=0.2, pos=nx.spring_layout(env.G, scale=100), ax=f.add_subplot(111)) | ||||||
|         with open('graph-{}.png'.format(env.name)) as f: |         with open('graph-{}.png'.format(env.name)) as f: | ||||||
|             f.savefig(f) |             f.savefig(f) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										315
									
								
								soil/history.py
									
									
									
									
									
								
							
							
						
						
									
										315
									
								
								soil/history.py
									
									
									
									
									
								
							| @@ -1,315 +0,0 @@ | |||||||
| import time |  | ||||||
| import os |  | ||||||
| import pandas as pd |  | ||||||
| import sqlite3 |  | ||||||
| import copy |  | ||||||
| import logging |  | ||||||
| import tempfile |  | ||||||
|  |  | ||||||
| logger = logging.getLogger(__name__) |  | ||||||
|  |  | ||||||
| from collections import UserDict, namedtuple |  | ||||||
|  |  | ||||||
| from . import serialization |  | ||||||
| from .utils import open_or_reuse |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class History: |  | ||||||
|     """ |  | ||||||
|     Store and retrieve values from a sqlite database. |  | ||||||
|     """ |  | ||||||
|  |  | ||||||
|     def __init__(self, name=None, db_path=None, backup=False): |  | ||||||
|         self._db = None |  | ||||||
|  |  | ||||||
|         if db_path is None: |  | ||||||
|             if not name: |  | ||||||
|                 name = time.time() |  | ||||||
|             _, db_path = tempfile.mkstemp(suffix='{}.sqlite'.format(name)) |  | ||||||
|  |  | ||||||
|         if backup and os.path.exists(db_path): |  | ||||||
|             newname = db_path + '.backup{}.sqlite'.format(time.time()) |  | ||||||
|             os.rename(db_path, newname) |  | ||||||
|  |  | ||||||
|         self.db_path = db_path |  | ||||||
|  |  | ||||||
|         self.db = db_path |  | ||||||
|  |  | ||||||
|         with self.db: |  | ||||||
|             logger.debug('Creating database {}'.format(self.db_path)) |  | ||||||
|             self.db.execute('''CREATE TABLE IF NOT EXISTS history (agent_id text, t_step int, key text, value text text)''') |  | ||||||
|             self.db.execute('''CREATE TABLE IF NOT EXISTS value_types (key text, value_type text)''') |  | ||||||
|             self.db.execute('''CREATE UNIQUE INDEX IF NOT EXISTS idx_history ON history (agent_id, t_step, key);''') |  | ||||||
|         self._dtypes = {} |  | ||||||
|         self._tups = [] |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def db(self): |  | ||||||
|         try: |  | ||||||
|             self._db.cursor() |  | ||||||
|         except (sqlite3.ProgrammingError, AttributeError): |  | ||||||
|             self.db = None  # Reset the database |  | ||||||
|         return self._db |  | ||||||
|  |  | ||||||
|     @db.setter |  | ||||||
|     def db(self, db_path=None): |  | ||||||
|         self._close() |  | ||||||
|         db_path = db_path or self.db_path |  | ||||||
|         if isinstance(db_path, str): |  | ||||||
|             logger.debug('Connecting to database {}'.format(db_path)) |  | ||||||
|             self._db = sqlite3.connect(db_path) |  | ||||||
|         else: |  | ||||||
|             self._db = db_path |  | ||||||
|  |  | ||||||
|     def _close(self): |  | ||||||
|         if self._db is None: |  | ||||||
|             return |  | ||||||
|         self.flush_cache() |  | ||||||
|         self._db.close() |  | ||||||
|         self._db = None |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def dtypes(self): |  | ||||||
|         self.read_types() |  | ||||||
|         return {k:v[0] for k, v in self._dtypes.items()} |  | ||||||
|  |  | ||||||
|     def save_tuples(self, tuples): |  | ||||||
|         ''' |  | ||||||
|         Save a series of tuples, converting them to records if necessary |  | ||||||
|         ''' |  | ||||||
|         self.save_records(Record(*tup) for tup in tuples) |  | ||||||
|  |  | ||||||
|     def save_records(self, records): |  | ||||||
|         ''' |  | ||||||
|         Save a collection of records |  | ||||||
|         ''' |  | ||||||
|         for record in records: |  | ||||||
|             if not isinstance(record, Record): |  | ||||||
|                 record = Record(*record) |  | ||||||
|             self.save_record(*record) |  | ||||||
|  |  | ||||||
|     def save_record(self, agent_id, t_step, key, value): |  | ||||||
|         ''' |  | ||||||
|         Save a collection of records to the database. |  | ||||||
|         Database writes are cached. |  | ||||||
|         ''' |  | ||||||
|         value = self.convert(key, value) |  | ||||||
|         self._tups.append(Record(agent_id=agent_id, |  | ||||||
|                                  t_step=t_step, |  | ||||||
|                                  key=key, |  | ||||||
|                                  value=value)) |  | ||||||
|         if len(self._tups) > 100: |  | ||||||
|             self.flush_cache() |  | ||||||
|  |  | ||||||
|     def convert(self, key, value): |  | ||||||
|         """Get the serialized value for a given key.""" |  | ||||||
|         if key not in self._dtypes: |  | ||||||
|             self.read_types() |  | ||||||
|             if key not in self._dtypes: |  | ||||||
|                 name = serialization.name(value) |  | ||||||
|                 serializer = serialization.serializer(name) |  | ||||||
|                 deserializer = serialization.deserializer(name) |  | ||||||
|                 self._dtypes[key] = (name, serializer, deserializer) |  | ||||||
|                 with self.db: |  | ||||||
|                     self.db.execute("replace into value_types (key, value_type) values (?, ?)", (key, name)) |  | ||||||
|         return self._dtypes[key][1](value) |  | ||||||
|  |  | ||||||
|     def recover(self, key, value): |  | ||||||
|         """Get the deserialized value for a given key, and the serialized version.""" |  | ||||||
|         if key not in self._dtypes: |  | ||||||
|             self.read_types() |  | ||||||
|         if key not in self._dtypes: |  | ||||||
|             raise ValueError("Unknown datatype for {} and {}".format(key, value)) |  | ||||||
|         return self._dtypes[key][2](value) |  | ||||||
|  |  | ||||||
|     def flush_cache(self): |  | ||||||
|         ''' |  | ||||||
|         Use a cache to save state changes to avoid opening a session for every change. |  | ||||||
|         The cache will be flushed at the end of the simulation, and when history is accessed. |  | ||||||
|         ''' |  | ||||||
|         logger.debug('Flushing cache {}'.format(self.db_path)) |  | ||||||
|         with self.db: |  | ||||||
|             for rec in self._tups: |  | ||||||
|                 self.db.execute("replace into history(agent_id, t_step, key, value) values (?, ?, ?, ?)", (rec.agent_id, rec.t_step, rec.key, rec.value)) |  | ||||||
|         self._tups = list() |  | ||||||
|  |  | ||||||
|     def to_tuples(self): |  | ||||||
|         self.flush_cache() |  | ||||||
|         with self.db: |  | ||||||
|             res = self.db.execute("select agent_id, t_step, key, value from history ").fetchall() |  | ||||||
|         for r in res: |  | ||||||
|             agent_id, t_step, key, value = r |  | ||||||
|             value = self.recover(key, value) |  | ||||||
|             yield agent_id, t_step, key, value |  | ||||||
|  |  | ||||||
|     def read_types(self): |  | ||||||
|         with self.db: |  | ||||||
|             res = self.db.execute("select key, value_type from value_types ").fetchall() |  | ||||||
|         for k, v in res: |  | ||||||
|             serializer = serialization.serializer(v) |  | ||||||
|             deserializer = serialization.deserializer(v) |  | ||||||
|             self._dtypes[k] = (v, serializer, deserializer) |  | ||||||
|  |  | ||||||
|     def __getitem__(self, key): |  | ||||||
|         self.flush_cache() |  | ||||||
|         key = Key(*key) |  | ||||||
|         agent_ids = [key.agent_id] if key.agent_id is not None else [] |  | ||||||
|         t_steps = [key.t_step] if key.t_step is not None else [] |  | ||||||
|         keys = [key.key] if key.key is not None else [] |  | ||||||
|  |  | ||||||
|         df = self.read_sql(agent_ids=agent_ids, |  | ||||||
|                            t_steps=t_steps, |  | ||||||
|                            keys=keys) |  | ||||||
|         r = Records(df, filter=key, dtypes=self._dtypes) |  | ||||||
|         if r.resolved: |  | ||||||
|             return r.value() |  | ||||||
|         return r |  | ||||||
|  |  | ||||||
|     def read_sql(self, keys=None, agent_ids=None, t_steps=None, convert_types=False, limit=-1): |  | ||||||
|  |  | ||||||
|         self.read_types() |  | ||||||
|  |  | ||||||
|         def escape_and_join(v): |  | ||||||
|             if v is None: |  | ||||||
|                 return |  | ||||||
|             return ",".join(map(lambda x: "\'{}\'".format(x), v)) |  | ||||||
|  |  | ||||||
|         filters = [("key in ({})".format(escape_and_join(keys)), keys), |  | ||||||
|                    ("agent_id in ({})".format(escape_and_join(agent_ids)), agent_ids) |  | ||||||
|         ] |  | ||||||
|         filters = list(k[0] for k in filters if k[1]) |  | ||||||
|  |  | ||||||
|         last_df = None |  | ||||||
|         if t_steps: |  | ||||||
|             # Look for the last value before the minimum step in the query |  | ||||||
|             min_step = min(t_steps) |  | ||||||
|             last_filters = ['t_step < {}'.format(min_step),] |  | ||||||
|             last_filters = last_filters + filters |  | ||||||
|             condition = ' and '.join(last_filters) |  | ||||||
|  |  | ||||||
|             last_query = ''' |  | ||||||
|             select h1.* |  | ||||||
|             from history h1 |  | ||||||
|             inner join ( |  | ||||||
|             select agent_id, key, max(t_step) as t_step |  | ||||||
|             from history |  | ||||||
|             where {condition} |  | ||||||
|             group by agent_id, key |  | ||||||
|             ) h2 |  | ||||||
|             on h1.agent_id = h2.agent_id  and |  | ||||||
|                h1.key      = h2.key       and |  | ||||||
|                h1.t_step   = h2.t_step |  | ||||||
|             '''.format(condition=condition) |  | ||||||
|             last_df = pd.read_sql_query(last_query, self.db) |  | ||||||
|  |  | ||||||
|             filters.append("t_step >= '{}' and t_step <= '{}'".format(min_step, max(t_steps))) |  | ||||||
|  |  | ||||||
|         condition = '' |  | ||||||
|         if filters: |  | ||||||
|             condition = 'where {} '.format(' and '.join(filters)) |  | ||||||
|         query = 'select * from history {} limit {}'.format(condition, limit) |  | ||||||
|         df = pd.read_sql_query(query, self.db) |  | ||||||
|         if last_df is not None: |  | ||||||
|             df = pd.concat([df, last_df]) |  | ||||||
|  |  | ||||||
|         df_p = df.pivot_table(values='value', index=['t_step'], |  | ||||||
|                               columns=['key', 'agent_id'], |  | ||||||
|                               aggfunc='first') |  | ||||||
|  |  | ||||||
|         for k, v in self._dtypes.items(): |  | ||||||
|             if k in df_p: |  | ||||||
|                 dtype, _, deserial = v |  | ||||||
|                 df_p[k] = df_p[k].fillna(method='ffill').astype(dtype) |  | ||||||
|         if t_steps: |  | ||||||
|             df_p = df_p.reindex(t_steps, method='ffill') |  | ||||||
|         return df_p.ffill() |  | ||||||
|  |  | ||||||
|     def __getstate__(self): |  | ||||||
|         state = dict(**self.__dict__) |  | ||||||
|         del state['_db'] |  | ||||||
|         del state['_dtypes'] |  | ||||||
|         return state |  | ||||||
|  |  | ||||||
|     def __setstate__(self, state): |  | ||||||
|         self.__dict__ = state |  | ||||||
|         self._dtypes = {} |  | ||||||
|         self._db = None |  | ||||||
|  |  | ||||||
|     def dump(self, f): |  | ||||||
|         self._close() |  | ||||||
|         for line in open_or_reuse(self.db_path, 'rb'): |  | ||||||
|             f.write(line) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class Records(): |  | ||||||
|  |  | ||||||
|     def __init__(self, df, filter=None, dtypes=None): |  | ||||||
|         if not filter: |  | ||||||
|             filter = Key(agent_id=None, |  | ||||||
|                          t_step=None, |  | ||||||
|                          key=None) |  | ||||||
|         self._df = df |  | ||||||
|         self._filter = filter |  | ||||||
|         self.dtypes = dtypes or {} |  | ||||||
|         super().__init__() |  | ||||||
|  |  | ||||||
|     def mask(self, tup): |  | ||||||
|         res = () |  | ||||||
|         for i, k in zip(tup[:-1], self._filter): |  | ||||||
|             if k is None: |  | ||||||
|                 res = res + (i,) |  | ||||||
|         res = res + (tup[-1],) |  | ||||||
|         return res |  | ||||||
|  |  | ||||||
|     def filter(self, newKey): |  | ||||||
|         f = list(self._filter) |  | ||||||
|         for ix, i in enumerate(f): |  | ||||||
|             if i is None: |  | ||||||
|                 f[ix] = newKey |  | ||||||
|         self._filter = Key(*f) |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def resolved(self): |  | ||||||
|         return sum(1 for i in self._filter if i is not None) == 3 |  | ||||||
|  |  | ||||||
|     def __iter__(self): |  | ||||||
|         for column, series in self._df.iteritems(): |  | ||||||
|             key, agent_id = column |  | ||||||
|             for t_step, value in series.iteritems(): |  | ||||||
|                 r = Record(t_step=t_step, |  | ||||||
|                            agent_id=agent_id, |  | ||||||
|                            key=key, |  | ||||||
|                            value=value) |  | ||||||
|                 yield self.mask(r) |  | ||||||
|  |  | ||||||
|     def value(self): |  | ||||||
|         if self.resolved: |  | ||||||
|             f = self._filter |  | ||||||
|             try: |  | ||||||
|                 i = self._df[f.key][str(f.agent_id)] |  | ||||||
|                 ix = i.index.get_loc(f.t_step, method='ffill') |  | ||||||
|                 return i.iloc[ix] |  | ||||||
|             except KeyError as ex: |  | ||||||
|                 return self.dtypes[f.key][2]() |  | ||||||
|         return list(self) |  | ||||||
|  |  | ||||||
|     def df(self): |  | ||||||
|         return self._df |  | ||||||
|  |  | ||||||
|     def __getitem__(self, k): |  | ||||||
|         n = copy.copy(self) |  | ||||||
|         n.filter(k) |  | ||||||
|         if n.resolved: |  | ||||||
|             return n.value() |  | ||||||
|         return n |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         return len(self._df) |  | ||||||
|  |  | ||||||
|     def __str__(self): |  | ||||||
|         if self.resolved: |  | ||||||
|             return str(self.value()) |  | ||||||
|         return '<Records for [{}]>'.format(self._filter) |  | ||||||
|  |  | ||||||
| Key = namedtuple('Key', ['agent_id', 't_step', 'key']) |  | ||||||
| Record = namedtuple('Record', 'agent_id t_step key value') |  | ||||||
| @@ -13,14 +13,13 @@ from jinja2 import Template | |||||||
|  |  | ||||||
|  |  | ||||||
| logger = logging.getLogger('soil') | logger = logging.getLogger('soil') | ||||||
| logger.setLevel(logging.INFO) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_network(network_params, dir_path=None): | def load_network(network_params, dir_path=None): | ||||||
|     if network_params is None: |     G = nx.Graph() | ||||||
|         return nx.Graph() |  | ||||||
|     path = network_params.get('path', None) |     if 'path' in network_params: | ||||||
|     if path: |         path = network_params['path'] | ||||||
|         if dir_path and not os.path.isabs(path): |         if dir_path and not os.path.isabs(path): | ||||||
|             path = os.path.join(dir_path, path) |             path = os.path.join(dir_path, path) | ||||||
|         extension = os.path.splitext(path)[1][1:] |         extension = os.path.splitext(path)[1][1:] | ||||||
| @@ -32,24 +31,28 @@ def load_network(network_params, dir_path=None): | |||||||
|             method = getattr(nx.readwrite, 'read_' + extension) |             method = getattr(nx.readwrite, 'read_' + extension) | ||||||
|         except AttributeError: |         except AttributeError: | ||||||
|             raise AttributeError('Unknown format') |             raise AttributeError('Unknown format') | ||||||
|         return method(path, **kwargs) |         G = method(path, **kwargs) | ||||||
|  |  | ||||||
|     net_args = network_params.copy() |     elif 'generator' in network_params: | ||||||
|     if 'generator' not in net_args: |         net_args = network_params.copy() | ||||||
|         return nx.Graph() |         net_gen = net_args.pop('generator') | ||||||
|  |  | ||||||
|     net_gen = net_args.pop('generator') |         if dir_path not in sys.path: | ||||||
|  |             sys.path.append(dir_path) | ||||||
|  |  | ||||||
|     if dir_path not in sys.path: |         method = deserializer(net_gen, | ||||||
|         sys.path.append(dir_path) |                               known_modules=['networkx.generators',]) | ||||||
|  |         G = method(**net_args) | ||||||
|  |  | ||||||
|  |     return G | ||||||
|  |  | ||||||
|     method = deserializer(net_gen, |  | ||||||
|                           known_modules=['networkx.generators',]) |  | ||||||
|  |  | ||||||
|     return method(**net_args) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_file(infile): | def load_file(infile): | ||||||
|  |     folder = os.path.dirname(infile) | ||||||
|  |     if folder not in sys.path: | ||||||
|  |         sys.path.append(folder) | ||||||
|     with open(infile, 'r') as f: |     with open(infile, 'r') as f: | ||||||
|         return list(chain.from_iterable(map(expand_template, load_string(f)))) |         return list(chain.from_iterable(map(expand_template, load_string(f)))) | ||||||
|  |  | ||||||
| @@ -66,11 +69,32 @@ def expand_template(config): | |||||||
|         raise ValueError(('You must provide a definition of variables' |         raise ValueError(('You must provide a definition of variables' | ||||||
|                           ' for the template.')) |                           ' for the template.')) | ||||||
|  |  | ||||||
|     template = Template(config['template']) |     template = config['template'] | ||||||
|  |  | ||||||
|     sampler_name = config.get('sampler', 'SALib.sample.morris.sample') |     if not isinstance(template, str): | ||||||
|     n_samples = int(config.get('samples', 100)) |         template = yaml.dump(template) | ||||||
|     sampler = deserializer(sampler_name) |  | ||||||
|  |     template = Template(template) | ||||||
|  |  | ||||||
|  |     params = params_for_template(config) | ||||||
|  |  | ||||||
|  |     blank_str = template.render({k: 0 for k in params[0].keys()}) | ||||||
|  |     blank = list(load_string(blank_str)) | ||||||
|  |     if len(blank) > 1: | ||||||
|  |         raise ValueError('Templates must not return more than one configuration') | ||||||
|  |     if 'name' in blank[0]: | ||||||
|  |         raise ValueError('Templates cannot be named, use group instead') | ||||||
|  |  | ||||||
|  |     for ps in params: | ||||||
|  |         string = template.render(ps) | ||||||
|  |         for c in load_string(string): | ||||||
|  |             yield c | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def params_for_template(config): | ||||||
|  |     sampler_config = config.get('sampler', {'N': 100}) | ||||||
|  |     sampler = sampler_config.pop('method', 'SALib.sample.morris.sample') | ||||||
|  |     sampler = deserializer(sampler) | ||||||
|     bounds = config['vars']['bounds'] |     bounds = config['vars']['bounds'] | ||||||
|  |  | ||||||
|     problem = { |     problem = { | ||||||
| @@ -78,7 +102,7 @@ def expand_template(config): | |||||||
|         'names': list(bounds.keys()), |         'names': list(bounds.keys()), | ||||||
|         'bounds': list(v for v in bounds.values()) |         'bounds': list(v for v in bounds.values()) | ||||||
|     } |     } | ||||||
|     samples = sampler(problem, n_samples) |     samples = sampler(problem, **sampler_config) | ||||||
|  |  | ||||||
|     lists = config['vars'].get('lists', {}) |     lists = config['vars'].get('lists', {}) | ||||||
|     names = list(lists.keys()) |     names = list(lists.keys()) | ||||||
| @@ -88,20 +112,7 @@ def expand_template(config): | |||||||
|     allnames = names + problem['names'] |     allnames = names + problem['names'] | ||||||
|     allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)] |     allvalues = [(list(i[0])+list(i[1])) for i in product(combs, samples)] | ||||||
|     params = list(map(lambda x: dict(zip(allnames, x)), allvalues)) |     params = list(map(lambda x: dict(zip(allnames, x)), allvalues)) | ||||||
|  |     return params | ||||||
|  |  | ||||||
|     blank_str = template.render({k: 0 for k in allnames}) |  | ||||||
|     blank = list(load_string(blank_str)) |  | ||||||
|     if len(blank) > 1: |  | ||||||
|         raise ValueError('Templates must not return more than one configuration') |  | ||||||
|     if 'name' in blank[0]: |  | ||||||
|         raise ValueError('Templates cannot be named, use group instead') |  | ||||||
|  |  | ||||||
|     confs = [] |  | ||||||
|     for ps in params: |  | ||||||
|         string = template.render(ps) |  | ||||||
|         for c in load_string(string): |  | ||||||
|             yield c |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_files(*patterns, **kwargs): | def load_files(*patterns, **kwargs): | ||||||
| @@ -116,7 +127,7 @@ def load_files(*patterns, **kwargs): | |||||||
|  |  | ||||||
| def load_config(config): | def load_config(config): | ||||||
|     if isinstance(config, dict): |     if isinstance(config, dict): | ||||||
|         yield config, None |         yield config, os.getcwd() | ||||||
|     else: |     else: | ||||||
|         yield from load_files(config) |         yield from load_files(config) | ||||||
|  |  | ||||||
| @@ -199,3 +210,13 @@ def deserialize(type_, value=None, **kwargs): | |||||||
|     if value is None: |     if value is None: | ||||||
|         return des |         return des | ||||||
|     return des(value) |     return des(value) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def deserialize_all(names, *args, known_modules=['soil'], **kwargs): | ||||||
|  |     '''Return the set of exporters for a simulation, given the exporter names''' | ||||||
|  |     exporters = [] | ||||||
|  |     for name in names: | ||||||
|  |         mod = deserialize(name, known_modules=known_modules) | ||||||
|  |         exporters.append(mod(*args, **kwargs)) | ||||||
|  |     return exporters | ||||||
|  |  | ||||||
|   | |||||||
| @@ -4,24 +4,27 @@ import importlib | |||||||
| import sys | import sys | ||||||
| import yaml | import yaml | ||||||
| import traceback | import traceback | ||||||
|  | import logging | ||||||
| import networkx as nx | import networkx as nx | ||||||
| from networkx.readwrite import json_graph | from networkx.readwrite import json_graph | ||||||
| from multiprocessing import Pool | from multiprocessing import Pool | ||||||
| from functools import partial | from functools import partial | ||||||
|  | from tsih import History | ||||||
|  |  | ||||||
| import pickle | import pickle | ||||||
|  |  | ||||||
| from nxsim import NetworkSimulation |  | ||||||
|  |  | ||||||
| from . import serialization, utils, basestring, agents | from . import serialization, utils, basestring, agents | ||||||
| from .environment import Environment | from .environment import Environment | ||||||
| from .utils import logger | from .utils import logger | ||||||
| from .exporters import for_sim as exporters_for_sim | from .exporters import default | ||||||
|  | from .stats import defaultStats | ||||||
|  |  | ||||||
|  |  | ||||||
| class Simulation(NetworkSimulation): | #TODO: change documentation for simulation | ||||||
|  |  | ||||||
|  | class Simulation: | ||||||
|     """ |     """ | ||||||
|     Subclass of nsim.NetworkSimulation with three main differences: |     Similar to nsim.NetworkSimulation with three main differences: | ||||||
|         1) agent type can be specified by name or by class. |         1) agent type can be specified by name or by class. | ||||||
|         2) instead of just one type, a network agents distribution can be used. |         2) instead of just one type, a network agents distribution can be used. | ||||||
|            The distribution specifies the weight (or probability) of each |            The distribution specifies the weight (or probability) of each | ||||||
| @@ -91,11 +94,12 @@ class Simulation(NetworkSimulation): | |||||||
|                  environment_params=None, environment_class=None, |                  environment_params=None, environment_class=None, | ||||||
|                  **kwargs): |                  **kwargs): | ||||||
|  |  | ||||||
|         self.seed = str(seed) or str(time.time()) |  | ||||||
|         self.load_module = load_module |         self.load_module = load_module | ||||||
|         self.network_params = network_params |         self.network_params = network_params | ||||||
|         self.name = name or 'Unnamed_' + time.strftime("%Y-%m-%d_%H:%M:%S") |         self.name = name or 'Unnamed' | ||||||
|         self.group = group or None |         self.seed = str(seed or name) | ||||||
|  |         self._id = '{}_{}'.format(self.name, time.strftime("%Y-%m-%d_%H.%M.%S")) | ||||||
|  |         self.group = group or '' | ||||||
|         self.num_trials = num_trials |         self.num_trials = num_trials | ||||||
|         self.max_time = max_time |         self.max_time = max_time | ||||||
|         self.default_state = default_state or {} |         self.default_state = default_state or {} | ||||||
| @@ -128,15 +132,18 @@ class Simulation(NetworkSimulation): | |||||||
|         self.states = agents._validate_states(states, |         self.states = agents._validate_states(states, | ||||||
|                                               self.topology) |                                               self.topology) | ||||||
|  |  | ||||||
|  |         self._history = History(name=self.name, | ||||||
|  |                                 backup=False) | ||||||
|  |  | ||||||
|     def run_simulation(self, *args, **kwargs): |     def run_simulation(self, *args, **kwargs): | ||||||
|         return self.run(*args, **kwargs) |         return self.run(*args, **kwargs) | ||||||
|  |  | ||||||
|     def run(self, *args, **kwargs): |     def run(self, *args, **kwargs): | ||||||
|         '''Run the simulation and return the list of resulting environments''' |         '''Run the simulation and return the list of resulting environments''' | ||||||
|         return list(self._run_simulation_gen(*args, **kwargs)) |         return list(self.run_gen(*args, **kwargs)) | ||||||
|  |  | ||||||
|     def _run_sync_or_async(self, parallel=False, *args, **kwargs): |     def _run_sync_or_async(self, parallel=False, *args, **kwargs): | ||||||
|         if parallel: |         if parallel and not os.environ.get('SENPY_DEBUG', None): | ||||||
|             p = Pool() |             p = Pool() | ||||||
|             func = partial(self.run_trial_exceptions, |             func = partial(self.run_trial_exceptions, | ||||||
|                            *args, |                            *args, | ||||||
| @@ -148,45 +155,85 @@ class Simulation(NetworkSimulation): | |||||||
|                 yield i |                 yield i | ||||||
|         else: |         else: | ||||||
|             for i in range(self.num_trials): |             for i in range(self.num_trials): | ||||||
|                 yield self.run_trial(i, |                 yield self.run_trial(*args, | ||||||
|                                      *args, |  | ||||||
|                                      **kwargs) |                                      **kwargs) | ||||||
|  |  | ||||||
|     def _run_simulation_gen(self, *args, parallel=False, dry_run=False, |     def run_gen(self, *args, parallel=False, dry_run=False, | ||||||
|                             exporters=['default', ], outdir=None, exporter_params={}, **kwargs): |                 exporters=[default, ], stats=[], outdir=None, exporter_params={}, | ||||||
|  |                 stats_params={}, log_level=None, | ||||||
|  |                 **kwargs): | ||||||
|  |         '''Run the simulation and yield the resulting environments.''' | ||||||
|  |         if log_level: | ||||||
|  |             logger.setLevel(log_level) | ||||||
|         logger.info('Using exporters: %s', exporters or []) |         logger.info('Using exporters: %s', exporters or []) | ||||||
|         logger.info('Output directory: %s', outdir) |         logger.info('Output directory: %s', outdir) | ||||||
|         exporters = exporters_for_sim(self, |         exporters = serialization.deserialize_all(exporters, | ||||||
|                                       exporters, |                                                   simulation=self, | ||||||
|                                       dry_run=dry_run, |                                                   known_modules=['soil.exporters',], | ||||||
|                                       outdir=outdir, |                                                   dry_run=dry_run, | ||||||
|                                       **exporter_params) |                                                   outdir=outdir, | ||||||
|  |                                                   **exporter_params) | ||||||
|  |         stats = serialization.deserialize_all(simulation=self, | ||||||
|  |                                               names=stats, | ||||||
|  |                                               known_modules=['soil.stats',], | ||||||
|  |                                               **stats_params) | ||||||
|  |  | ||||||
|         with utils.timer('simulation {}'.format(self.name)): |         with utils.timer('simulation {}'.format(self.name)): | ||||||
|  |             for stat in stats: | ||||||
|  |                 stat.start() | ||||||
|  |  | ||||||
|             for exporter in exporters: |             for exporter in exporters: | ||||||
|                 exporter.start() |                 exporter.start() | ||||||
|  |             for env in self._run_sync_or_async(*args, | ||||||
|             for env in self._run_sync_or_async(*args, parallel=parallel, |                                                parallel=parallel, | ||||||
|  |                                                log_level=log_level, | ||||||
|                                                **kwargs): |                                                **kwargs): | ||||||
|  |  | ||||||
|  |                 collected = list(stat.trial(env) for stat in stats) | ||||||
|  |  | ||||||
|  |                 saved = self.save_stats(collected, t_step=env.now, trial_id=env.name) | ||||||
|  |  | ||||||
|                 for exporter in exporters: |                 for exporter in exporters: | ||||||
|                     exporter.trial_end(env) |                     exporter.trial(env, saved) | ||||||
|  |  | ||||||
|                 yield env |                 yield env | ||||||
|  |  | ||||||
|             for exporter in exporters: |  | ||||||
|                 exporter.end() |  | ||||||
|  |  | ||||||
|     def get_env(self, trial_id = 0, **kwargs): |             collected = list(stat.end() for stat in stats) | ||||||
|  |             saved = self.save_stats(collected) | ||||||
|  |  | ||||||
|  |             for exporter in exporters: | ||||||
|  |                 exporter.end(saved) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def save_stats(self, collection, **kwargs): | ||||||
|  |         stats = dict(kwargs) | ||||||
|  |         for stat in collection: | ||||||
|  |             stats.update(stat) | ||||||
|  |         self._history.save_stats(utils.flatten_dict(stats)) | ||||||
|  |         return stats | ||||||
|  |  | ||||||
|  |     def get_stats(self, **kwargs): | ||||||
|  |         return self._history.get_stats(**kwargs) | ||||||
|  |  | ||||||
|  |     def log_stats(self, stats): | ||||||
|  |         logger.info('Stats: \n{}'.format(yaml.dump(stats, default_flow_style=False))) | ||||||
|  |      | ||||||
|  |  | ||||||
|  |     def get_env(self, trial_id=0, **kwargs): | ||||||
|         '''Create an environment for a trial of the simulation''' |         '''Create an environment for a trial of the simulation''' | ||||||
|         opts = self.environment_params.copy() |         opts = self.environment_params.copy() | ||||||
|         env_name = '{}_trial_{}'.format(self.name, trial_id) |  | ||||||
|         opts.update({ |         opts.update({ | ||||||
|             'name': env_name, |             'name': trial_id, | ||||||
|             'topology': self.topology.copy(), |             'topology': self.topology.copy(), | ||||||
|             'seed': self.seed+env_name, |             'network_params': self.network_params, | ||||||
|  |             'seed': '{}_trial_{}'.format(self.seed, trial_id), | ||||||
|             'initial_time': 0, |             'initial_time': 0, | ||||||
|             'interval': self.interval, |             'interval': self.interval, | ||||||
|             'network_agents': self.network_agents, |             'network_agents': self.network_agents, | ||||||
|  |             'initial_time': 0, | ||||||
|             'states': self.states, |             'states': self.states, | ||||||
|  |             'dir_path': self.dir_path, | ||||||
|             'default_state': self.default_state, |             'default_state': self.default_state, | ||||||
|             'environment_agents': self.environment_agents, |             'environment_agents': self.environment_agents, | ||||||
|         }) |         }) | ||||||
| @@ -194,20 +241,22 @@ class Simulation(NetworkSimulation): | |||||||
|         env = self.environment_class(**opts) |         env = self.environment_class(**opts) | ||||||
|         return env |         return env | ||||||
|  |  | ||||||
|     def run_trial(self, trial_id=0, until=None, **opts): |     def run_trial(self, until=None, log_level=logging.INFO, **opts): | ||||||
|         """Run a single trial of the simulation |  | ||||||
|  |  | ||||||
|         Parameters |  | ||||||
|         ---------- |  | ||||||
|         trial_id : int |  | ||||||
|         """ |         """ | ||||||
|  |         Run a single trial of the simulation | ||||||
|  |  | ||||||
|  |         """ | ||||||
|  |         trial_id = '{}_trial_{}'.format(self.name, time.time()).replace('.', '-') | ||||||
|  |         if log_level: | ||||||
|  |             logger.setLevel(log_level) | ||||||
|         # Set-up trial environment and graph |         # Set-up trial environment and graph | ||||||
|         until = until or self.max_time |         until = until or self.max_time | ||||||
|         env = self.get_env(trial_id = trial_id, **opts) |         env = self.get_env(trial_id=trial_id, **opts) | ||||||
|         # Set up agents on nodes |         # Set up agents on nodes | ||||||
|         with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)): |         with utils.timer('Simulation {} trial {}'.format(self.name, trial_id)): | ||||||
|             env.run(until) |             env.run(until) | ||||||
|         return env |         return env | ||||||
|  |  | ||||||
|     def run_trial_exceptions(self, *args, **kwargs): |     def run_trial_exceptions(self, *args, **kwargs): | ||||||
|         ''' |         ''' | ||||||
|         A wrapper for run_trial that catches exceptions and returns them. |         A wrapper for run_trial that catches exceptions and returns them. | ||||||
| @@ -248,16 +297,19 @@ class Simulation(NetworkSimulation): | |||||||
|         with utils.open_or_reuse(f, 'wb') as f: |         with utils.open_or_reuse(f, 'wb') as f: | ||||||
|             pickle.dump(self, f) |             pickle.dump(self, f) | ||||||
|  |  | ||||||
|  |     def dump_sqlite(self, f): | ||||||
|  |         return self._history.dump(f) | ||||||
|  |  | ||||||
|     def __getstate__(self): |     def __getstate__(self): | ||||||
|         state={} |         state={} | ||||||
|         for k, v in self.__dict__.items(): |         for k, v in self.__dict__.items(): | ||||||
|             if k[0] != '_': |             if k[0] != '_': | ||||||
|                 state[k] = v |                 state[k] = v | ||||||
|                 state['topology'] = json_graph.node_link_data(self.topology) |                 state['topology'] = json_graph.node_link_data(self.topology) | ||||||
|                 state['network_agents'] = agents.serialize_distribution(self.network_agents, |                 state['network_agents'] = agents.serialize_definition(self.network_agents, | ||||||
|                                                                         known_modules = []) |                                                                       known_modules = []) | ||||||
|                 state['environment_agents'] = agents.serialize_distribution(self.environment_agents, |                 state['environment_agents'] = agents.serialize_definition(self.environment_agents, | ||||||
|                                                                             known_modules = []) |                                                                           known_modules = []) | ||||||
|                 state['environment_class'] = serialization.serialize(self.environment_class, |                 state['environment_class'] = serialization.serialize(self.environment_class, | ||||||
|                                                                      known_modules=['soil.environment'])[1]  # func, name |                                                                      known_modules=['soil.environment'])[1]  # func, name | ||||||
|         if state['load_module'] is None: |         if state['load_module'] is None: | ||||||
| @@ -275,7 +327,6 @@ class Simulation(NetworkSimulation): | |||||||
|                                                               known_modules=[self.load_module]) |                                                               known_modules=[self.load_module]) | ||||||
|         self.environment_class = serialization.deserialize(self.environment_class, |         self.environment_class = serialization.deserialize(self.environment_class, | ||||||
|                                                    known_modules=[self.load_module, 'soil.environment', ])  # func, name |                                                    known_modules=[self.load_module, 'soil.environment', ])  # func, name | ||||||
|         return state |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def all_from_config(config): | def all_from_config(config): | ||||||
|   | |||||||
							
								
								
									
										106
									
								
								soil/stats.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								soil/stats.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | |||||||
|  | import pandas as pd | ||||||
|  |  | ||||||
|  | from collections import Counter | ||||||
|  |  | ||||||
|  | class Stats: | ||||||
|  |     ''' | ||||||
|  |     Interface for all stats. It is not necessary, but it is useful | ||||||
|  |     if you don't plan to implement all the methods. | ||||||
|  |     ''' | ||||||
|  |  | ||||||
|  |     def __init__(self, simulation): | ||||||
|  |         self.simulation = simulation | ||||||
|  |  | ||||||
|  |     def start(self): | ||||||
|  |         '''Method to call when the simulation starts''' | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     def end(self): | ||||||
|  |         '''Method to call when the simulation ends''' | ||||||
|  |         return {} | ||||||
|  |  | ||||||
|  |     def trial(self, env): | ||||||
|  |         '''Method to call when a trial ends''' | ||||||
|  |         return {} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class distribution(Stats): | ||||||
|  |     ''' | ||||||
|  |     Calculate the distribution of agent states at the end of each trial, | ||||||
|  |     the mean value, and its deviation. | ||||||
|  |     ''' | ||||||
|  |  | ||||||
|  |     def start(self): | ||||||
|  |         self.means = [] | ||||||
|  |         self.counts = [] | ||||||
|  |  | ||||||
|  |     def trial(self, env): | ||||||
|  |         df = env[None, None, None].df() | ||||||
|  |         df = df.drop('SEED', axis=1) | ||||||
|  |         ix = df.index[-1] | ||||||
|  |         attrs = df.columns.get_level_values(0) | ||||||
|  |         vc = {} | ||||||
|  |         stats = { | ||||||
|  |             'mean': {}, | ||||||
|  |             'count': {}, | ||||||
|  |         } | ||||||
|  |         for a in attrs: | ||||||
|  |             t = df.loc[(ix, a)] | ||||||
|  |             try: | ||||||
|  |                 stats['mean'][a] = t.mean() | ||||||
|  |                 self.means.append(('mean', a, t.mean())) | ||||||
|  |             except TypeError: | ||||||
|  |                 pass | ||||||
|  |  | ||||||
|  |             for name, count in t.value_counts().iteritems(): | ||||||
|  |                 if a not in stats['count']: | ||||||
|  |                     stats['count'][a] = {} | ||||||
|  |                 stats['count'][a][name] = count | ||||||
|  |                 self.counts.append(('count', a, name, count)) | ||||||
|  |  | ||||||
|  |         return stats | ||||||
|  |  | ||||||
|  |     def end(self): | ||||||
|  |         dfm = pd.DataFrame(self.means, columns=['metric', 'key', 'value']) | ||||||
|  |         dfc = pd.DataFrame(self.counts, columns=['metric', 'key', 'value', 'count']) | ||||||
|  |  | ||||||
|  |         count = {} | ||||||
|  |         mean = {} | ||||||
|  |  | ||||||
|  |         if self.means: | ||||||
|  |             res = dfm.groupby(by=['key']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) | ||||||
|  |             mean = res['value'].to_dict() | ||||||
|  |         if self.counts: | ||||||
|  |             res = dfc.groupby(by=['key', 'value']).agg(['mean', 'std', 'count', 'median', 'max', 'min']) | ||||||
|  |             for k,v in res['count'].to_dict().items(): | ||||||
|  |                 if k not in count: | ||||||
|  |                     count[k] = {} | ||||||
|  |                 for tup, times in v.items(): | ||||||
|  |                     subkey, subcount = tup | ||||||
|  |                     if subkey not in count[k]: | ||||||
|  |                         count[k][subkey] = {} | ||||||
|  |                     count[k][subkey][subcount] = times | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         return {'count': count, 'mean': mean} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class defaultStats(Stats): | ||||||
|  |  | ||||||
|  |     def trial(self, env): | ||||||
|  |         c = Counter() | ||||||
|  |         c.update(a.__class__.__name__ for a in env.network_agents) | ||||||
|  |  | ||||||
|  |         c2 = Counter() | ||||||
|  |         c2.update(a['id'] for a in env.network_agents) | ||||||
|  |  | ||||||
|  |         return { | ||||||
|  |             'network ': { | ||||||
|  |                 'n_nodes': env.G.number_of_nodes(), | ||||||
|  |                 'n_edges': env.G.number_of_edges(), | ||||||
|  |             }, | ||||||
|  |             'agents': { | ||||||
|  |                 'model_count': dict(c), | ||||||
|  |                 'state_count': dict(c2), | ||||||
|  |             } | ||||||
|  |         } | ||||||
							
								
								
									
										87
									
								
								soil/time.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								soil/time.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,87 @@ | |||||||
|  | from mesa.time import BaseScheduler | ||||||
|  | from queue import Empty | ||||||
|  | from heapq import heappush, heappop | ||||||
|  | import math | ||||||
|  | from .utils import logger | ||||||
|  | from mesa import Agent | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class When: | ||||||
|  |     def __init__(self, time): | ||||||
|  |         self._time = float(time) | ||||||
|  |  | ||||||
|  |     def abs(self, time): | ||||||
|  |         return self._time | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class Delta: | ||||||
|  |     def __init__(self, delta): | ||||||
|  |         self._delta = delta | ||||||
|  |  | ||||||
|  |     def __eq__(self, other): | ||||||
|  |         return self._delta == other._delta | ||||||
|  |  | ||||||
|  |     def abs(self, time): | ||||||
|  |         return time + self._delta | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TimedActivation(BaseScheduler): | ||||||
|  |     """A scheduler which activates each agent when the agent requests. | ||||||
|  |     In each activation, each agent will update its 'next_time'. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(self) | ||||||
|  |         self._queue = [] | ||||||
|  |         self.next_time = 0 | ||||||
|  |  | ||||||
|  |     def add(self, agent: Agent): | ||||||
|  |         if agent.unique_id not in self._agents: | ||||||
|  |             heappush(self._queue, (self.time, agent.unique_id)) | ||||||
|  |             super().add(agent) | ||||||
|  |  | ||||||
|  |     def step(self, until: float =float('inf')) -> None: | ||||||
|  |         """ | ||||||
|  |         Executes agents in order, one at a time. After each step, | ||||||
|  |         an agent will signal when it wants to be scheduled next. | ||||||
|  |         """ | ||||||
|  |  | ||||||
|  |         when = None | ||||||
|  |         agent_id = None | ||||||
|  |         unsched = [] | ||||||
|  |         until = until or float('inf') | ||||||
|  |  | ||||||
|  |         if not self._queue: | ||||||
|  |             self.time = until | ||||||
|  |             self.next_time = float('inf') | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         (when, agent_id) = self._queue[0] | ||||||
|  |  | ||||||
|  |         if until and when > until: | ||||||
|  |             self.time = until | ||||||
|  |             self.next_time = when | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         self.time = when | ||||||
|  |         next_time = float("inf") | ||||||
|  |  | ||||||
|  |         while when == self.time: | ||||||
|  |             heappop(self._queue) | ||||||
|  |             logger.debug(f'Stepping agent {agent_id}') | ||||||
|  |             when = (self._agents[agent_id].step() or Delta(1)).abs(self.time) | ||||||
|  |             heappush(self._queue, (when, agent_id)) | ||||||
|  |             if when < next_time: | ||||||
|  |                 next_time = when | ||||||
|  |  | ||||||
|  |             if not self._queue or self._queue[0][0] > self.time: | ||||||
|  |                 agent_id = None | ||||||
|  |                 break | ||||||
|  |             else: | ||||||
|  |                 (when, agent_id) = self._queue[0] | ||||||
|  |  | ||||||
|  |         if when and when < self.time: | ||||||
|  |             raise Exception("Invalid scheduling time") | ||||||
|  |  | ||||||
|  |         self.next_time = next_time | ||||||
|  |         self.steps += 1 | ||||||
| @@ -7,7 +7,8 @@ from shutil import copyfile | |||||||
| from contextlib import contextmanager | from contextlib import contextmanager | ||||||
|  |  | ||||||
| logger = logging.getLogger('soil') | logger = logging.getLogger('soil') | ||||||
| logger.setLevel(logging.INFO) | # logging.basicConfig() | ||||||
|  | # logger.setLevel(logging.INFO) | ||||||
|  |  | ||||||
|  |  | ||||||
| @contextmanager | @contextmanager | ||||||
| @@ -25,20 +26,21 @@ def timer(name='task', pre="", function=logger.info, to_object=None): | |||||||
|         to_object.end = end |         to_object.end = end | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def safe_open(path, mode='r', backup=True, **kwargs): | def safe_open(path, mode='r', backup=True, **kwargs): | ||||||
|     outdir = os.path.dirname(path) |     outdir = os.path.dirname(path) | ||||||
|     if outdir and not os.path.exists(outdir): |     if outdir and not os.path.exists(outdir): | ||||||
|         os.makedirs(outdir) |         os.makedirs(outdir) | ||||||
|     if backup and 'w' in mode and os.path.exists(path): |     if backup and 'w' in mode and os.path.exists(path): | ||||||
|         creation = os.path.getctime(path) |         creation = os.path.getctime(path) | ||||||
|         stamp = time.strftime('%Y-%m-%d_%H:%M', time.localtime(creation)) |         stamp = time.strftime('%Y-%m-%d_%H.%M.%S', time.localtime(creation)) | ||||||
|  |  | ||||||
|         backup_dir = os.path.join(outdir, stamp) |         backup_dir = os.path.join(outdir, 'backup') | ||||||
|         if not os.path.exists(backup_dir): |         if not os.path.exists(backup_dir): | ||||||
|             os.makedirs(backup_dir) |             os.makedirs(backup_dir) | ||||||
|         newpath = os.path.join(backup_dir, os.path.basename(path)) |         newpath = os.path.join(backup_dir, '{}@{}'.format(os.path.basename(path), | ||||||
|         if os.path.exists(newpath): |                                                                stamp)) | ||||||
|             newpath = '{}@{}'.format(newpath, time.time()) |  | ||||||
|         copyfile(path, newpath) |         copyfile(path, newpath) | ||||||
|     return open(path, mode=mode, **kwargs) |     return open(path, mode=mode, **kwargs) | ||||||
|  |  | ||||||
| @@ -48,3 +50,40 @@ def open_or_reuse(f, *args, **kwargs): | |||||||
|         return safe_open(f, *args, **kwargs) |         return safe_open(f, *args, **kwargs) | ||||||
|     except (AttributeError, TypeError): |     except (AttributeError, TypeError): | ||||||
|         return f |         return f | ||||||
|  |  | ||||||
|  | def flatten_dict(d): | ||||||
|  |     if not isinstance(d, dict): | ||||||
|  |         return d | ||||||
|  |     return dict(_flatten_dict(d)) | ||||||
|  |  | ||||||
|  | def _flatten_dict(d, prefix=''): | ||||||
|  |     if not isinstance(d, dict): | ||||||
|  |         # print('END:', prefix, d) | ||||||
|  |         yield prefix, d | ||||||
|  |         return | ||||||
|  |     if prefix: | ||||||
|  |         prefix = prefix + '.' | ||||||
|  |     for k, v in d.items(): | ||||||
|  |         # print(k, v) | ||||||
|  |         res = list(_flatten_dict(v, prefix='{}{}'.format(prefix, k))) | ||||||
|  |         # print('RES:', res) | ||||||
|  |         yield from res | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def unflatten_dict(d): | ||||||
|  |     out = {} | ||||||
|  |     for k, v in d.items(): | ||||||
|  |         target = out | ||||||
|  |         if not isinstance(k, str): | ||||||
|  |             target[k] = v | ||||||
|  |             continue | ||||||
|  |         tokens = k.split('.') | ||||||
|  |         if len(tokens) < 2: | ||||||
|  |             target[k] = v | ||||||
|  |             continue | ||||||
|  |         for token in tokens[:-1]: | ||||||
|  |             if token not in target: | ||||||
|  |                 target[token] = {} | ||||||
|  |             target = target[token] | ||||||
|  |         target[tokens[-1]] = v | ||||||
|  |     return out | ||||||
|   | |||||||
							
								
								
									
										5
									
								
								soil/visualization.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								soil/visualization.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | from mesa.visualization.UserParam import UserSettableParameter | ||||||
|  |  | ||||||
|  | class UserSettableParameter(UserSettableParameter): | ||||||
|  |     def __str__(self): | ||||||
|  |         return self.value | ||||||
| @@ -1 +1,4 @@ | |||||||
| pytest | pytest | ||||||
|  | mesa>=0.8.9 | ||||||
|  | scipy>=1.3 | ||||||
|  | tornado | ||||||
|   | |||||||
| @@ -21,11 +21,13 @@ class Ping(agents.FSM): | |||||||
|     @agents.default_state |     @agents.default_state | ||||||
|     @agents.state |     @agents.state | ||||||
|     def even(self): |     def even(self): | ||||||
|  |         self.debug(f'Even {self["count"]}') | ||||||
|         self['count'] += 1 |         self['count'] += 1 | ||||||
|         return self.odd |         return self.odd | ||||||
|  |  | ||||||
|     @agents.state |     @agents.state | ||||||
|     def odd(self): |     def odd(self): | ||||||
|  |         self.debug(f'Odd {self["count"]}') | ||||||
|         self['count'] += 1 |         self['count'] += 1 | ||||||
|         return self.even |         return self.even | ||||||
|  |  | ||||||
| @@ -65,25 +67,24 @@ class TestAnalysis(TestCase): | |||||||
|     def test_count(self): |     def test_count(self): | ||||||
|         env = self.env |         env = self.env | ||||||
|         df = analysis.read_sql(env._history.db_path) |         df = analysis.read_sql(env._history.db_path) | ||||||
|         res = analysis.get_count(df, 'SEED', 'id') |         res = analysis.get_count(df, 'SEED', 'state_id') | ||||||
|         assert res['SEED']['seedanalysis_trial_0'].iloc[0] == 1 |         assert res['SEED'][self.env['SEED']].iloc[0] == 1 | ||||||
|         assert res['SEED']['seedanalysis_trial_0'].iloc[-1] == 1 |         assert res['SEED'][self.env['SEED']].iloc[-1] == 1 | ||||||
|         assert res['id']['odd'].iloc[0] == 2 |         assert res['state_id']['odd'].iloc[0] == 2 | ||||||
|         assert res['id']['even'].iloc[0] == 0 |         assert res['state_id']['even'].iloc[0] == 0 | ||||||
|         assert res['id']['odd'].iloc[-1] == 1 |         assert res['state_id']['odd'].iloc[-1] == 1 | ||||||
|         assert res['id']['even'].iloc[-1] == 1 |         assert res['state_id']['even'].iloc[-1] == 1 | ||||||
|  |  | ||||||
|     def test_value(self): |     def test_value(self): | ||||||
|         env = self.env |         env = self.env | ||||||
|         df = analysis.read_sql(env._history._db) |         df = analysis.read_sql(env._history.db_path) | ||||||
|         res_sum = analysis.get_value(df, 'count') |         res_sum = analysis.get_value(df, 'count') | ||||||
|  |  | ||||||
|         assert res_sum['count'].iloc[0] == 2 |         assert res_sum['count'].iloc[0] == 2 | ||||||
|  |  | ||||||
|         import numpy as np |         import numpy as np | ||||||
|         res_mean = analysis.get_value(df, 'count', aggfunc=np.mean) |         res_mean = analysis.get_value(df, 'count', aggfunc=np.mean) | ||||||
|         assert res_mean['count'].iloc[0] == 1 |         assert res_mean['count'].iloc[15] == (16+8)/2 | ||||||
|  |  | ||||||
|         res_total = analysis.get_value(df) |         res_total = analysis.get_majority(df) | ||||||
|  |         res_total['SEED'].iloc[0] == self.env['SEED'] | ||||||
|         res_total['SEED'].iloc[0] == 'seedanalysis_trial_0' |  | ||||||
|   | |||||||
| @@ -31,7 +31,7 @@ def make_example_test(path, config): | |||||||
|                 try: |                 try: | ||||||
|                     n = config['network_params']['n'] |                     n = config['network_params']['n'] | ||||||
|                     assert len(list(env.network_agents)) == n |                     assert len(list(env.network_agents)) == n | ||||||
|                     assert env.now > 2  # It has run |                     assert env.now > 0  # It has run | ||||||
|                     assert env.now <= config['max_time']  # But not further than allowed |                     assert env.now <= config['max_time']  # But not further than allowed | ||||||
|                 except KeyError: |                 except KeyError: | ||||||
|                     pass |                     pass | ||||||
|   | |||||||
| @@ -6,26 +6,32 @@ from time import time | |||||||
|  |  | ||||||
| from unittest import TestCase | from unittest import TestCase | ||||||
| from soil import exporters | from soil import exporters | ||||||
| from soil.utils import safe_open |  | ||||||
| from soil import simulation | from soil import simulation | ||||||
|  |  | ||||||
|  | from soil.stats import distribution | ||||||
|  |  | ||||||
| class Dummy(exporters.Exporter): | class Dummy(exporters.Exporter): | ||||||
|     started = False |     started = False | ||||||
|     trials = 0 |     trials = 0 | ||||||
|     ended = False |     ended = False | ||||||
|     total_time = 0 |     total_time = 0 | ||||||
|  |     called_start = 0 | ||||||
|  |     called_trial = 0 | ||||||
|  |     called_end = 0 | ||||||
|  |  | ||||||
|     def start(self): |     def start(self): | ||||||
|  |         self.__class__.called_start += 1 | ||||||
|         self.__class__.started = True |         self.__class__.started = True | ||||||
|  |  | ||||||
|     def trial_end(self, env): |     def trial(self, env, stats): | ||||||
|         assert env |         assert env | ||||||
|         self.__class__.trials += 1 |         self.__class__.trials += 1 | ||||||
|         self.__class__.total_time += env.now |         self.__class__.total_time += env.now | ||||||
|  |         self.__class__.called_trial += 1 | ||||||
|  |  | ||||||
|     def end(self): |     def end(self, stats): | ||||||
|         self.__class__.ended = True |         self.__class__.ended = True | ||||||
|  |         self.__class__.called_end += 1 | ||||||
|  |  | ||||||
|  |  | ||||||
| class Exporters(TestCase): | class Exporters(TestCase): | ||||||
| @@ -39,32 +45,17 @@ class Exporters(TestCase): | |||||||
|             'environment_params': {} |             'environment_params': {} | ||||||
|         } |         } | ||||||
|         s = simulation.from_config(config) |         s = simulation.from_config(config) | ||||||
|         s.run_simulation(exporters=[Dummy], dry_run=True) |         for env in s.run_simulation(exporters=[Dummy], dry_run=True): | ||||||
|  |             assert env.now <= 2 | ||||||
|  |  | ||||||
|         assert Dummy.started |         assert Dummy.started | ||||||
|         assert Dummy.ended |         assert Dummy.ended | ||||||
|  |         assert Dummy.called_start == 1 | ||||||
|  |         assert Dummy.called_end == 1 | ||||||
|  |         assert Dummy.called_trial == 5 | ||||||
|         assert Dummy.trials == 5 |         assert Dummy.trials == 5 | ||||||
|         assert Dummy.total_time == 2*5 |         assert Dummy.total_time == 2*5 | ||||||
|  |  | ||||||
|     def test_distribution(self): |  | ||||||
|         '''The distribution exporter should write the number of agents in each state''' |  | ||||||
|         config = { |  | ||||||
|             'name': 'exporter_sim', |  | ||||||
|             'network_params': { |  | ||||||
|                 'generator': 'complete_graph', |  | ||||||
|                 'n': 4 |  | ||||||
|             }, |  | ||||||
|             'agent_type': 'CounterModel', |  | ||||||
|             'max_time': 2, |  | ||||||
|             'num_trials': 5, |  | ||||||
|             'environment_params': {} |  | ||||||
|         } |  | ||||||
|         output = io.StringIO() |  | ||||||
|         s = simulation.from_config(config) |  | ||||||
|         s.run_simulation(exporters=[exporters.distribution], dry_run=True, exporter_params={'copy_to': output}) |  | ||||||
|         result = output.getvalue() |  | ||||||
|         assert 'count' in result |  | ||||||
|         assert 'SEED,Noneexporter_sim_trial_3,1,,1,1,1,1' in result |  | ||||||
|  |  | ||||||
|     def test_writing(self): |     def test_writing(self): | ||||||
|         '''Try to write CSV, GEXF, sqlite and YAML (without dry_run)''' |         '''Try to write CSV, GEXF, sqlite and YAML (without dry_run)''' | ||||||
|         n_trials = 5 |         n_trials = 5 | ||||||
| @@ -83,11 +74,11 @@ class Exporters(TestCase): | |||||||
|         s = simulation.from_config(config) |         s = simulation.from_config(config) | ||||||
|         tmpdir = tempfile.mkdtemp() |         tmpdir = tempfile.mkdtemp() | ||||||
|         envs = s.run_simulation(exporters=[ |         envs = s.run_simulation(exporters=[ | ||||||
|             exporters.default, |                                     exporters.default, | ||||||
|             exporters.csv, |                                     exporters.csv, | ||||||
|             exporters.gexf, |                                     exporters.gexf, | ||||||
|             exporters.distribution, |                                 ], | ||||||
|         ], |                                 stats=[distribution,], | ||||||
|                                 outdir=tmpdir, |                                 outdir=tmpdir, | ||||||
|                                 exporter_params={'copy_to': output}) |                                 exporter_params={'copy_to': output}) | ||||||
|         result = output.getvalue() |         result = output.getvalue() | ||||||
|   | |||||||
| @@ -1,156 +0,0 @@ | |||||||
| from unittest import TestCase |  | ||||||
|  |  | ||||||
| import os |  | ||||||
| import shutil |  | ||||||
| from glob import glob |  | ||||||
|  |  | ||||||
| from soil import history |  | ||||||
|  |  | ||||||
|  |  | ||||||
| ROOT = os.path.abspath(os.path.dirname(__file__)) |  | ||||||
| DBROOT = os.path.join(ROOT, 'testdb') |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestHistory(TestCase): |  | ||||||
|  |  | ||||||
|     def setUp(self): |  | ||||||
|         if not os.path.exists(DBROOT): |  | ||||||
|             os.makedirs(DBROOT) |  | ||||||
|  |  | ||||||
|     def tearDown(self): |  | ||||||
|         if os.path.exists(DBROOT): |  | ||||||
|             shutil.rmtree(DBROOT) |  | ||||||
|  |  | ||||||
|     def test_history(self): |  | ||||||
|         """ |  | ||||||
|         """ |  | ||||||
|         tuples = ( |  | ||||||
|             ('a_0', 0, 'id', 'h'), |  | ||||||
|             ('a_0', 1, 'id', 'e'), |  | ||||||
|             ('a_0', 2, 'id', 'l'), |  | ||||||
|             ('a_0', 3, 'id', 'l'), |  | ||||||
|             ('a_0', 4, 'id', 'o'), |  | ||||||
|             ('a_1', 0, 'id', 'v'), |  | ||||||
|             ('a_1', 1, 'id', 'a'), |  | ||||||
|             ('a_1', 2, 'id', 'l'), |  | ||||||
|             ('a_1', 3, 'id', 'u'), |  | ||||||
|             ('a_1', 4, 'id', 'e'), |  | ||||||
|             ('env', 1, 'prob', 1), |  | ||||||
|             ('env', 3, 'prob', 2), |  | ||||||
|             ('env', 5, 'prob', 3), |  | ||||||
|             ('a_2', 7, 'finished', True), |  | ||||||
|         ) |  | ||||||
|         h = history.History() |  | ||||||
|         h.save_tuples(tuples) |  | ||||||
|         # assert h['env', 0, 'prob'] == 0 |  | ||||||
|         for i in range(1, 7): |  | ||||||
|             assert h['env', i, 'prob'] == ((i-1)//2)+1 |  | ||||||
|  |  | ||||||
|  |  | ||||||
|         for i, k in zip(range(5), 'hello'): |  | ||||||
|             assert h['a_0', i, 'id'] == k |  | ||||||
|         for record, value in zip(h['a_0', None, 'id'], 'hello'): |  | ||||||
|             t_step, val = record |  | ||||||
|             assert val == value |  | ||||||
|  |  | ||||||
|         for i, k in zip(range(5), 'value'): |  | ||||||
|             assert h['a_1', i, 'id'] == k |  | ||||||
|         for i in range(5, 8): |  | ||||||
|             assert h['a_1', i, 'id'] == 'e' |  | ||||||
|         for i in range(7): |  | ||||||
|             assert h['a_2', i, 'finished'] == False |  | ||||||
|         assert h['a_2', 7, 'finished'] |  | ||||||
|  |  | ||||||
|     def test_history_gen(self): |  | ||||||
|         """ |  | ||||||
|         """ |  | ||||||
|         tuples = ( |  | ||||||
|             ('a_1', 0, 'id', 'v'), |  | ||||||
|             ('a_1', 1, 'id', 'a'), |  | ||||||
|             ('a_1', 2, 'id', 'l'), |  | ||||||
|             ('a_1', 3, 'id', 'u'), |  | ||||||
|             ('a_1', 4, 'id', 'e'), |  | ||||||
|             ('env', 1, 'prob', 1), |  | ||||||
|             ('env', 2, 'prob', 2), |  | ||||||
|             ('env', 3, 'prob', 3), |  | ||||||
|             ('a_2', 7, 'finished', True), |  | ||||||
|         ) |  | ||||||
|         h = history.History() |  | ||||||
|         h.save_tuples(tuples) |  | ||||||
|         for t_step, key, value in h['env', None, None]: |  | ||||||
|             assert t_step == value |  | ||||||
|             assert key == 'prob' |  | ||||||
|  |  | ||||||
|         records = list(h[None, 7, None]) |  | ||||||
|         assert len(records) == 3 |  | ||||||
|         for i in records: |  | ||||||
|             agent_id, key, value = i |  | ||||||
|             if agent_id == 'a_1': |  | ||||||
|                 assert key == 'id' |  | ||||||
|                 assert value == 'e' |  | ||||||
|             elif agent_id == 'a_2': |  | ||||||
|                 assert key == 'finished' |  | ||||||
|                 assert value |  | ||||||
|             else: |  | ||||||
|                 assert key == 'prob' |  | ||||||
|                 assert value == 3 |  | ||||||
|  |  | ||||||
|         records = h['a_1', 7, None] |  | ||||||
|         assert records['id'] == 'e' |  | ||||||
|  |  | ||||||
|     def test_history_file(self): |  | ||||||
|         """ |  | ||||||
|         History should be saved to a file |  | ||||||
|         """ |  | ||||||
|         tuples = ( |  | ||||||
|             ('a_1', 0, 'id', 'v'), |  | ||||||
|             ('a_1', 1, 'id', 'a'), |  | ||||||
|             ('a_1', 2, 'id', 'l'), |  | ||||||
|             ('a_1', 3, 'id', 'u'), |  | ||||||
|             ('a_1', 4, 'id', 'e'), |  | ||||||
|             ('env', 1, 'prob', 1), |  | ||||||
|             ('env', 2, 'prob', 2), |  | ||||||
|             ('env', 3, 'prob', 3), |  | ||||||
|             ('a_2', 7, 'finished', True), |  | ||||||
|         ) |  | ||||||
|         db_path = os.path.join(DBROOT, 'test') |  | ||||||
|         h = history.History(db_path=db_path) |  | ||||||
|         h.save_tuples(tuples) |  | ||||||
|         h.flush_cache() |  | ||||||
|         assert os.path.exists(db_path) |  | ||||||
|  |  | ||||||
|         # Recover the data |  | ||||||
|         recovered = history.History(db_path=db_path) |  | ||||||
|         assert recovered['a_1', 0, 'id'] == 'v' |  | ||||||
|         assert recovered['a_1', 4, 'id'] == 'e' |  | ||||||
|  |  | ||||||
|         # Using backup=True should create a backup copy, and initialize an empty history |  | ||||||
|         newhistory = history.History(db_path=db_path, backup=True) |  | ||||||
|         backuppaths = glob(db_path + '.backup*.sqlite') |  | ||||||
|         assert len(backuppaths) == 1 |  | ||||||
|         backuppath = backuppaths[0] |  | ||||||
|         assert newhistory.db_path == h.db_path |  | ||||||
|         assert os.path.exists(backuppath) |  | ||||||
|         assert len(newhistory[None, None, None]) == 0 |  | ||||||
|  |  | ||||||
|     def test_history_tuples(self): |  | ||||||
|         """ |  | ||||||
|         The data recovered should be equal to the one recorded. |  | ||||||
|         """ |  | ||||||
|         tuples = ( |  | ||||||
|             ('a_1', 0, 'id', 'v'), |  | ||||||
|             ('a_1', 1, 'id', 'a'), |  | ||||||
|             ('a_1', 2, 'id', 'l'), |  | ||||||
|             ('a_1', 3, 'id', 'u'), |  | ||||||
|             ('a_1', 4, 'id', 'e'), |  | ||||||
|             ('env', 1, 'prob', 1), |  | ||||||
|             ('env', 2, 'prob', 2), |  | ||||||
|             ('env', 3, 'prob', 3), |  | ||||||
|             ('a_2', 7, 'finished', True), |  | ||||||
|         ) |  | ||||||
|         h = history.History() |  | ||||||
|         h.save_tuples(tuples) |  | ||||||
|         recovered = list(h.to_tuples()) |  | ||||||
|         assert recovered |  | ||||||
|         for i in recovered: |  | ||||||
|             assert i in tuples |  | ||||||
| @@ -9,7 +9,8 @@ from functools import partial | |||||||
|  |  | ||||||
| from os.path import join | from os.path import join | ||||||
| from soil import (simulation, Environment, agents, serialization, | from soil import (simulation, Environment, agents, serialization, | ||||||
|                   history, utils) |                   utils) | ||||||
|  | from soil.time import Delta | ||||||
|  |  | ||||||
|  |  | ||||||
| ROOT = os.path.abspath(os.path.dirname(__file__)) | ROOT = os.path.abspath(os.path.dirname(__file__)) | ||||||
| @@ -20,8 +21,8 @@ class CustomAgent(agents.FSM): | |||||||
|     @agents.default_state |     @agents.default_state | ||||||
|     @agents.state |     @agents.state | ||||||
|     def normal(self): |     def normal(self): | ||||||
|         self.state['neighbors'] = self.count_agents(state_id='normal', |         self.neighbors = self.count_agents(state_id='normal', | ||||||
|                                                     limit_neighbors=True) |                                            limit_neighbors=True) | ||||||
|     @agents.state |     @agents.state | ||||||
|     def unreachable(self): |     def unreachable(self): | ||||||
|         return |         return | ||||||
| @@ -115,7 +116,7 @@ class TestMain(TestCase): | |||||||
|             'network_agents': [{ |             'network_agents': [{ | ||||||
|                 'agent_type': 'AggregatedCounter', |                 'agent_type': 'AggregatedCounter', | ||||||
|                 'weight': 1, |                 'weight': 1, | ||||||
|                 'state': {'id': 0} |                 'state': {'state_id': 0} | ||||||
|  |  | ||||||
|             }], |             }], | ||||||
|             'max_time': 10, |             'max_time': 10, | ||||||
| @@ -126,7 +127,7 @@ class TestMain(TestCase): | |||||||
|         env = s.run_simulation(dry_run=True)[0] |         env = s.run_simulation(dry_run=True)[0] | ||||||
|         for agent in env.network_agents: |         for agent in env.network_agents: | ||||||
|             last = 0 |             last = 0 | ||||||
|             assert len(agent[None, None]) == 10 |             assert len(agent[None, None]) == 11 | ||||||
|             for step, total in sorted(agent['total', None]): |             for step, total in sorted(agent['total', None]): | ||||||
|                 assert total == last + 2 |                 assert total == last + 2 | ||||||
|                 last = total |                 last = total | ||||||
| @@ -148,10 +149,9 @@ class TestMain(TestCase): | |||||||
|         } |         } | ||||||
|         s = simulation.from_config(config) |         s = simulation.from_config(config) | ||||||
|         env = s.run_simulation(dry_run=True)[0] |         env = s.run_simulation(dry_run=True)[0] | ||||||
|         assert env.get_agent(0).state['neighbors'] == 1 |  | ||||||
|         assert env.get_agent(0).state['neighbors'] == 1 |  | ||||||
|         assert env.get_agent(1).count_agents(state_id='normal') == 2 |         assert env.get_agent(1).count_agents(state_id='normal') == 2 | ||||||
|         assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1 |         assert env.get_agent(1).count_agents(state_id='normal', limit_neighbors=True) == 1 | ||||||
|  |         assert env.get_agent(0).neighbors == 1 | ||||||
|  |  | ||||||
|     def test_torvalds_example(self): |     def test_torvalds_example(self): | ||||||
|         """A complete example from a documentation should work.""" |         """A complete example from a documentation should work.""" | ||||||
| @@ -198,11 +198,11 @@ class TestMain(TestCase): | |||||||
|         """ |         """ | ||||||
|         config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0] |         config = serialization.load_file(join(EXAMPLES, 'complete.yml'))[0] | ||||||
|         s = simulation.from_config(config) |         s = simulation.from_config(config) | ||||||
|         for i in range(5): |  | ||||||
|             s.run_simulation(dry_run=True) |         s.run_simulation(dry_run=True) | ||||||
|             nconfig = s.to_dict() |         nconfig = s.to_dict() | ||||||
|             del nconfig['topology'] |         del nconfig['topology'] | ||||||
|             assert config == nconfig |         assert config == nconfig | ||||||
|  |  | ||||||
|     def test_row_conversion(self): |     def test_row_conversion(self): | ||||||
|         env = Environment() |         env = Environment() | ||||||
| @@ -211,7 +211,7 @@ class TestMain(TestCase): | |||||||
|         res = list(env.history_to_tuples()) |         res = list(env.history_to_tuples()) | ||||||
|         assert len(res) == len(env.environment_params) |         assert len(res) == len(env.environment_params) | ||||||
|  |  | ||||||
|         env._now = 1 |         env.schedule.time = 1 | ||||||
|         env['test'] = 'second_value' |         env['test'] = 'second_value' | ||||||
|         res = list(env.history_to_tuples()) |         res = list(env.history_to_tuples()) | ||||||
|  |  | ||||||
| @@ -281,7 +281,7 @@ class TestMain(TestCase): | |||||||
|                 'weight': 2 |                 'weight': 2 | ||||||
|             }, |             }, | ||||||
|         ] |         ] | ||||||
|         converted = agents.deserialize_distribution(agent_distro) |         converted = agents.deserialize_definition(agent_distro) | ||||||
|         assert converted[0]['agent_type'] == agents.CounterModel |         assert converted[0]['agent_type'] == agents.CounterModel | ||||||
|         assert converted[1]['agent_type'] == CustomAgent |         assert converted[1]['agent_type'] == CustomAgent | ||||||
|         pickle.dumps(converted) |         pickle.dumps(converted) | ||||||
| @@ -297,14 +297,14 @@ class TestMain(TestCase): | |||||||
|                 'weight': 2 |                 'weight': 2 | ||||||
|             }, |             }, | ||||||
|         ] |         ] | ||||||
|         converted = agents.serialize_distribution(agent_distro) |         converted = agents.serialize_definition(agent_distro) | ||||||
|         assert converted[0]['agent_type'] == 'CounterModel' |         assert converted[0]['agent_type'] == 'CounterModel' | ||||||
|         assert converted[1]['agent_type'] == 'test_main.CustomAgent' |         assert converted[1]['agent_type'] == 'test_main.CustomAgent' | ||||||
|         pickle.dumps(converted) |         pickle.dumps(converted) | ||||||
|  |  | ||||||
|     def test_pickle_agent_environment(self): |     def test_pickle_agent_environment(self): | ||||||
|         env = Environment(name='Test') |         env = Environment(name='Test') | ||||||
|         a = agents.BaseAgent(environment=env, agent_id=25) |         a = agents.BaseAgent(model=env, unique_id=25) | ||||||
|  |  | ||||||
|         a['key'] = 'test' |         a['key'] = 'test' | ||||||
|  |  | ||||||
| @@ -316,12 +316,6 @@ class TestMain(TestCase): | |||||||
|         assert recovered['key', 0] == 'test' |         assert recovered['key', 0] == 'test' | ||||||
|         assert recovered['key'] == 'test' |         assert recovered['key'] == 'test' | ||||||
|  |  | ||||||
|     def test_history(self): |  | ||||||
|         '''Test storing in and retrieving from history (sqlite)''' |  | ||||||
|         h = history.History() |  | ||||||
|         h.save_record(agent_id=0, t_step=0, key="test", value="hello") |  | ||||||
|         assert h[0, 0, "test"] == "hello" |  | ||||||
|  |  | ||||||
|     def test_subgraph(self): |     def test_subgraph(self): | ||||||
|         '''An agent should be able to subgraph the global topology''' |         '''An agent should be able to subgraph the global topology''' | ||||||
|         G = nx.Graph() |         G = nx.Graph() | ||||||
| @@ -343,4 +337,55 @@ class TestMain(TestCase): | |||||||
|         configs = serialization.load_file(join(EXAMPLES, 'template.yml')) |         configs = serialization.load_file(join(EXAMPLES, 'template.yml')) | ||||||
|         assert len(configs) > 0 |         assert len(configs) > 0 | ||||||
|  |  | ||||||
|  |     def test_until(self): | ||||||
|  |         config = { | ||||||
|  |             'name': 'until_sim', | ||||||
|  |             'network_params': {}, | ||||||
|  |             'agent_type': 'CounterModel', | ||||||
|  |             'max_time': 2, | ||||||
|  |             'num_trials': 50, | ||||||
|  |             'environment_params': {} | ||||||
|  |         } | ||||||
|  |         s = simulation.from_config(config) | ||||||
|  |         runs = list(s.run_simulation(dry_run=True)) | ||||||
|  |         over = list(x.now for x in runs if x.now>2) | ||||||
|  |         assert len(runs) == config['num_trials'] | ||||||
|  |         assert len(over) == 0 | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def test_fsm(self): | ||||||
|  |         '''Basic state change''' | ||||||
|  |         class ToggleAgent(agents.FSM): | ||||||
|  |             @agents.default_state | ||||||
|  |             @agents.state | ||||||
|  |             def ping(self): | ||||||
|  |                 return self.pong | ||||||
|  |  | ||||||
|  |             @agents.state | ||||||
|  |             def pong(self): | ||||||
|  |                 return self.ping | ||||||
|  |  | ||||||
|  |         a = ToggleAgent(unique_id=1, model=Environment()) | ||||||
|  |         assert a.state_id == a.ping.id | ||||||
|  |         a.step() | ||||||
|  |         assert a.state_id == a.pong.id | ||||||
|  |         a.step() | ||||||
|  |         assert a.state_id == a.ping.id | ||||||
|  |  | ||||||
|  |     def test_fsm_when(self): | ||||||
|  |         '''Basic state change''' | ||||||
|  |         class ToggleAgent(agents.FSM): | ||||||
|  |             @agents.default_state | ||||||
|  |             @agents.state | ||||||
|  |             def ping(self): | ||||||
|  |                 return self.pong, 2 | ||||||
|  |  | ||||||
|  |             @agents.state | ||||||
|  |             def pong(self): | ||||||
|  |                 return self.ping | ||||||
|  |  | ||||||
|  |         a = ToggleAgent(unique_id=1, model=Environment()) | ||||||
|  |         when = a.step() | ||||||
|  |         assert when == 2 | ||||||
|  |         when = a.step() | ||||||
|  |         assert when == Delta(a.interval) | ||||||
|   | |||||||
							
								
								
									
										69
									
								
								tests/test_mesa.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								tests/test_mesa.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | |||||||
|  | ''' | ||||||
|  | Mesa-SOIL integration tests | ||||||
|  |  | ||||||
|  | We have to test that: | ||||||
|  | - Mesa agents can be used in SOIL | ||||||
|  | - Simplified soil agents can be used in mesa simulations | ||||||
|  | - Mesa and soil agents can interact in a simulation | ||||||
|  |  | ||||||
|  | - Mesa visualizations work with SOIL simulations | ||||||
|  |  | ||||||
|  | ''' | ||||||
|  | from mesa import Agent, Model | ||||||
|  | from mesa.time import RandomActivation | ||||||
|  | from mesa.space import MultiGrid | ||||||
|  |  | ||||||
|  | class MoneyAgent(Agent): | ||||||
|  |     """ An agent with fixed initial wealth.""" | ||||||
|  |     def __init__(self, unique_id, model): | ||||||
|  |         super().__init__(unique_id, model) | ||||||
|  |         self.wealth = 1 | ||||||
|  |  | ||||||
|  |     def step(self): | ||||||
|  |         self.move() | ||||||
|  |         if self.wealth > 0: | ||||||
|  |             self.give_money() | ||||||
|  |  | ||||||
|  |     def give_money(self): | ||||||
|  |         cellmates = self.model.grid.get_cell_list_contents([self.pos]) | ||||||
|  |         if len(cellmates) > 1: | ||||||
|  |             other = self.random.choice(cellmates) | ||||||
|  |             other.wealth += 1 | ||||||
|  |             self.wealth -= 1 | ||||||
|  |  | ||||||
|  |     def move(self): | ||||||
|  |         possible_steps = self.model.grid.get_neighborhood( | ||||||
|  |             self.pos, | ||||||
|  |             moore=True, | ||||||
|  |             include_center=False) | ||||||
|  |         new_position = self.random.choice(possible_steps) | ||||||
|  |         self.model.grid.move_agent(self, new_position) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MoneyModel(Model): | ||||||
|  |     """A model with some number of agents.""" | ||||||
|  |     def __init__(self, N, width, height): | ||||||
|  |         self.num_agents = N | ||||||
|  |         self.grid = MultiGrid(width, height, True) | ||||||
|  |         self.schedule = RandomActivation(self) | ||||||
|  |  | ||||||
|  |         # Create agents | ||||||
|  |         for i in range(self.num_agents): | ||||||
|  |             a = MoneyAgent(i, self) | ||||||
|  |             self.schedule.add(a) | ||||||
|  |  | ||||||
|  |             # Add the agent to a random grid cell | ||||||
|  |             x = self.random.randrange(self.grid.width) | ||||||
|  |             y = self.random.randrange(self.grid.height) | ||||||
|  |             self.grid.place_agent(a, (x, y)) | ||||||
|  |  | ||||||
|  |     def step(self): | ||||||
|  |         '''Advance the model by one step.''' | ||||||
|  |         self.schedule.step() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # model = MoneyModel(10) | ||||||
|  | # for i in range(10): | ||||||
|  | #     model.step() | ||||||
|  |  | ||||||
|  | # agent_wealth = [a.wealth for a in model.schedule.agents] | ||||||
							
								
								
									
										34
									
								
								tests/test_stats.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								tests/test_stats.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,34 @@ | |||||||
|  | from unittest import TestCase | ||||||
|  |  | ||||||
|  | from soil import simulation, stats | ||||||
|  | from soil.utils import unflatten_dict | ||||||
|  |  | ||||||
|  | class Stats(TestCase): | ||||||
|  |  | ||||||
|  |     def test_distribution(self): | ||||||
|  |         '''The distribution exporter should write the number of agents in each state''' | ||||||
|  |         config = { | ||||||
|  |             'name': 'exporter_sim', | ||||||
|  |             'network_params': { | ||||||
|  |                 'generator': 'complete_graph', | ||||||
|  |                 'n': 4 | ||||||
|  |             }, | ||||||
|  |             'agent_type': 'CounterModel', | ||||||
|  |             'max_time': 2, | ||||||
|  |             'num_trials': 5, | ||||||
|  |             'environment_params': {} | ||||||
|  |         } | ||||||
|  |         s = simulation.from_config(config) | ||||||
|  |         for env in s.run_simulation(stats=[stats.distribution]): | ||||||
|  |             pass | ||||||
|  |             # stats_res = unflatten_dict(dict(env._history['stats', -1, None])) | ||||||
|  |         allstats = s.get_stats() | ||||||
|  |         for stat in allstats: | ||||||
|  |             assert 'count' in stat | ||||||
|  |             assert 'mean' in stat | ||||||
|  |             if 'trial_id' in stat: | ||||||
|  |                 assert stat['mean']['neighbors'] == 3 | ||||||
|  |                 assert stat['count']['total']['4'] == 4 | ||||||
|  |             else: | ||||||
|  |                 assert stat['count']['count']['neighbors']['3'] == 20 | ||||||
|  |                 assert stat['mean']['min']['neighbors'] == stat['mean']['max']['neighbors'] | ||||||
		Reference in New Issue
	
	Block a user