Mirror of https://github.com/gsi-upm/sitc (synced 2025-10-31 07:28:17 +00:00)
Updated the lab assignment to Gymnasium and extended it
@@ -48,7 +48,9 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "1. [Q-Learning](2_6_1_Q-Learning.ipynb)"
+    "1. [Q-Learning](2_6_1_Q-Learning_Basic.ipynb)\n",
+    "1. [Visualization](2_6_1_Q-Learning_Visualization.ipynb)\n",
+    "1. [Exercises](2_6_1_Q-Learning_Exercises.ipynb)"
    ]
   },
   {
@@ -64,7 +66,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -78,7 +80,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.7.1"
+   "version": "3.10.10"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
							
								
								
									
ml5/2_6_1_Q-Learning_Basic.ipynb: 1384 lines, new file
(File diff suppressed because it is too large.)
							
								
								
									
ml5/2_6_1_Q-Learning_Exercises.ipynb: 138 lines, new file
@@ -0,0 +1,138 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    ""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Course Notes for Learning Intelligent Systems"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © Carlos Á. Iglesias"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## [Introduction to Machine Learning V](2_6_0_Intro_RL.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
|  |     "# Exercises\n", | ||||||
|  |     "\n", | ||||||
|  |     "\n", | ||||||
|  |     "## Taxi\n", | ||||||
|  |     "Analyze the [Taxi problem](https://gymnasium.farama.org/environments/toy_text/taxi/) and solve it applying Q-Learning. You can find a solution as the one previously presented  [here](https://www.oreilly.com/learning/introduction-to-reinforcement-learning-and-openai-gym), and the notebook is [here](https://github.com/wagonhelm/Reinforcement-Learning-Introduction/blob/master/Reinforcement%20Learning%20Introduction.ipynb). Take into account that Gymnasium has changed, so you will have to adapt the code.\n", | ||||||
|  |     "\n", | ||||||
|  |     "Analyze the impact of not changing the learning rate or changing it in a different way. " | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "markdown", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "source": [ | ||||||
|  |     "# Optional exercises\n", | ||||||
|  |     "Select one of the following exercises.\n", | ||||||
|  |     "\n", | ||||||
|  |     "## Blackjack\n", | ||||||
|  |     "Analyze how to appy Q-Learning for solving Blackjack.\n", | ||||||
|  |     "You can find information in this [article](https://gymnasium.farama.org/tutorials/training_agents/blackjack_tutorial/).\n", | ||||||
|  |     "\n", | ||||||
|  |     "## Doom\n", | ||||||
|  |     "Read this [article](https://medium.freecodecamp.org/an-introduction-to-deep-q-learning-lets-play-doom-54d02d8017d8) and execute the companion [notebook](https://github.com/simoninithomas/Deep_reinforcement_learning_Course/blob/master/Deep%20Q%20Learning/Doom/Deep%20Q%20learning%20with%20Doom.ipynb). Analyze the results and provide conclusions about DQN.\n", | ||||||
|  |     "\n" | ||||||
|  |    ] | ||||||
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## References\n",
+    "* [Gymnasium documentation](https://gymnasium.farama.org/).\n",
+    "* [Diving deeper into Reinforcement Learning with Q-Learning, Thomas Simonini](https://medium.freecodecamp.org/diving-deeper-into-reinforcement-learning-with-q-learning-c18d0db58efe).\n",
+    "* Illustrations by [Thomas Simonini](https://github.com/simoninithomas/Deep_reinforcement_learning_Course) and [Sung Kim](https://www.youtube.com/watch?v=xgoO54qN4lY).\n",
+    "* [Frozen Lake solution with TensorFlow](https://analyticsindiamag.com/openai-gym-frozen-lake-beginners-guide-reinforcement-learning/)\n",
+    "* [Deep Q-Learning for Doom](https://medium.freecodecamp.org/an-introduction-to-deep-q-learning-lets-play-doom-54d02d8017d8)\n",
+    "* [Intro OpenAI Gym with Random Search and the Cart Pole scenario](http://www.pinchofintelligence.com/getting-started-openai-gym/)\n",
+    "* [Q-Learning for the Taxi scenario](https://www.oreilly.com/learning/introduction-to-reinforcement-learning-and-openai-gym)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Licence"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
|  |     "The notebook is freely licensed under under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/).  \n", | ||||||
|  |     "\n", | ||||||
|  |     "© Carlos Á. Iglesias, Universidad Politécnica de Madrid." | ||||||
|  |    ] | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "datacleaner": { | ||||||
|  |    "position": { | ||||||
|  |     "top": "50px" | ||||||
|  |    }, | ||||||
|  |    "python": { | ||||||
|  |     "varRefreshCmd": "try:\n    print(_datacleaner.dataframe_metadata())\nexcept:\n    print([])" | ||||||
|  |    }, | ||||||
|  |    "window_display": false | ||||||
|  |   }, | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": "Python 3 (ipykernel)", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 3 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython3", | ||||||
|  |    "version": "3.10.10" | ||||||
|  |   }, | ||||||
|  |   "latex_envs": { | ||||||
|  |    "LaTeX_envs_menu_present": true, | ||||||
|  |    "autocomplete": true, | ||||||
|  |    "bibliofile": "biblio.bib", | ||||||
|  |    "cite_by": "apalike", | ||||||
|  |    "current_citInitial": 1, | ||||||
|  |    "eqLabelWithNumbers": true, | ||||||
|  |    "eqNumInitial": 1, | ||||||
|  |    "hotkeys": { | ||||||
|  |     "equation": "Ctrl-E", | ||||||
|  |     "itemize": "Ctrl-I" | ||||||
|  |    }, | ||||||
|  |    "labels_anchors": false, | ||||||
|  |    "latex_user_defs": false, | ||||||
|  |    "report_style_numbering": false, | ||||||
|  |    "user_envs_cfg": false | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 1 | ||||||
|  | } | ||||||
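
Editor's note on the Taxi exercise above: the linked O'Reilly article targets the old `gym` API, so its code needs adapting. In Gymnasium, `env.reset()` returns `(observation, info)` and `env.step()` returns `(observation, reward, terminated, truncated, info)`. Below is a minimal sketch of tabular Q-Learning on `Taxi-v3` under those assumptions; the hyperparameter values (alpha, gamma, epsilon, episode count) are illustrative, not the course's reference solution.

    import gymnasium as gym
    import numpy as np

    env = gym.make("Taxi-v3")
    env.action_space.seed(42)                 # reproducible exploration samples
    q_table = np.zeros((env.observation_space.n, env.action_space.n))

    alpha, gamma, epsilon = 0.1, 0.95, 0.1    # illustrative values, tune them for the exercise
    rng = np.random.default_rng(42)

    for episode in range(5000):
        state, info = env.reset()             # Gymnasium: reset() returns (obs, info)
        done = False
        while not done:
            if rng.uniform() < epsilon:       # epsilon-greedy exploration
                action = env.action_space.sample()
            else:
                action = int(np.argmax(q_table[state]))
            # Gymnasium: step() returns (obs, reward, terminated, truncated, info)
            new_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            # Tabular Q-Learning update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
            q_table[state, action] += alpha * (
                reward + gamma * np.max(q_table[new_state]) - q_table[state, action]
            )
            state = new_state

    env.close()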
							
								
								
									
ml5/2_6_1_Q-Learning_Visualization.ipynb: 368 lines, new file
(File diff suppressed because one or more lines are too long.)
							
								
								
									
ml5/qlearning.py: 274 lines, new file

@@ -0,0 +1,274 @@
+# Class definition of QLearning
+
+from pathlib import Path
+from typing import NamedTuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from tqdm import tqdm
+
+import gymnasium as gym
+from gymnasium.envs.toy_text.frozen_lake import generate_random_map
+
+# Params
+
+class Params(NamedTuple):
+    total_episodes: int  # Total episodes
+    learning_rate: float  # Learning rate
+    gamma: float  # Discounting rate
+    epsilon: float  # Exploration probability
+    map_size: int  # Number of tiles on one side of the square environment
+    seed: int  # Define a seed so that we get reproducible results
+    is_slippery: bool  # If true, the player moves in the intended direction with probability 1/3, otherwise in either perpendicular direction with probability 1/3 each
+    n_runs: int  # Number of runs
+    action_size: int  # Number of possible actions
+    state_size: int  # Number of possible states
+    proba_frozen: float  # Probability that a tile is frozen
+    savefig_folder: Path  # Root folder where plots are saved
+
+
+class Qlearning:
+    def __init__(self, learning_rate, gamma, state_size, action_size):
+        self.state_size = state_size
+        self.action_size = action_size
+        self.learning_rate = learning_rate
+        self.gamma = gamma
+        self.reset_qtable()
+
+    def update(self, state, action, reward, new_state):
+        """Update Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)]"""
+        delta = (
+            reward
+            + self.gamma * np.max(self.qtable[new_state][:])
+            - self.qtable[state][action]
+        )
+        q_update = self.qtable[state][action] + self.learning_rate * delta
+        return q_update
+
+    def reset_qtable(self):
+        """Reset the Q-table."""
+        self.qtable = np.zeros((self.state_size, self.action_size))
+
+
+class EpsilonGreedy:
+    def __init__(self, epsilon, rng):
+        self.epsilon = epsilon
+        self.rng = rng
+
+    def choose_action(self, action_space, state, qtable):
+        """Choose an action `a` in the current world state (s)."""
+        # First we randomize a number
+        explor_exploit_tradeoff = self.rng.uniform(0, 1)
+
+        # Exploration
+        if explor_exploit_tradeoff < self.epsilon:
+            action = action_space.sample()
+
+        # Exploitation (taking the biggest Q-value for this state)
+        else:
+            # Break ties randomly:
+            # if all actions have the same value for this state we choose a random one
+            # (otherwise `np.argmax()` would always take the first one)
+            if np.all(qtable[state][:] == qtable[state][0]):
+                action = action_space.sample()
+            else:
+                action = np.argmax(qtable[state][:])
+        return action
+
+
+def run_frozen_maps(maps, params, rng):
+    """Run FrozenLake on each map size and plot the results."""
+    map_sizes = maps
+    res_all = pd.DataFrame()
+    st_all = pd.DataFrame()
+
+    for map_size in map_sizes:
+        env = gym.make(
+            "FrozenLake-v1",
+            is_slippery=params.is_slippery,
+            render_mode="rgb_array",
+            desc=generate_random_map(
+                size=map_size, p=params.proba_frozen, seed=params.seed
+            ),
+        )
+
+        params = params._replace(action_size=env.action_space.n)
+        params = params._replace(state_size=env.observation_space.n)
+        env.action_space.seed(
+            params.seed
+        )  # Set the seed to get reproducible results when sampling the action space
+        learner = Qlearning(
+            learning_rate=params.learning_rate,
+            gamma=params.gamma,
+            state_size=params.state_size,
+            action_size=params.action_size,
+        )
+        explorer = EpsilonGreedy(
+            epsilon=params.epsilon,
+            rng=rng,
+        )
+        print(f"Map size: {map_size}x{map_size}")
+        rewards, steps, episodes, qtables, all_states, all_actions = run_env(
+            env, params, learner, explorer
+        )
+
+        # Save the results in dataframes
+        res, st = postprocess(episodes, params, rewards, steps, map_size)
+        res_all = pd.concat([res_all, res])
+        st_all = pd.concat([st_all, st])
+        qtable = qtables.mean(axis=0)  # Average the Q-table between runs
+
+        plot_states_actions_distribution(
+            states=all_states, actions=all_actions, map_size=map_size, params=params
+        )  # Sanity check
+        plot_q_values_map(qtable, env, map_size, params)
+
+        env.close()
+    return res_all, st_all
+
+def run_env(env, params, learner, explorer):
+    rewards = np.zeros((params.total_episodes, params.n_runs))
+    steps = np.zeros((params.total_episodes, params.n_runs))
+    episodes = np.arange(params.total_episodes)
+    qtables = np.zeros((params.n_runs, params.state_size, params.action_size))
+    all_states = []
+    all_actions = []
+
+    for run in range(params.n_runs):  # Run several times to account for stochasticity
+        learner.reset_qtable()  # Reset the Q-table between runs
+
+        for episode in tqdm(
+            episodes, desc=f"Run {run}/{params.n_runs} - Episodes", leave=False
+        ):
+            state = env.reset(seed=params.seed)[0]  # Reset the environment
+            step = 0
+            done = False
+            total_rewards = 0
+
+            while not done:
+                action = explorer.choose_action(
+                    action_space=env.action_space, state=state, qtable=learner.qtable
+                )
+
+                # Log all states and actions
+                all_states.append(state)
+                all_actions.append(action)
+
+                # Take the action (a) and observe the outcome state (s') and reward (r)
+                new_state, reward, terminated, truncated, info = env.step(action)
+
+                done = terminated or truncated
+
+                learner.qtable[state, action] = learner.update(
+                    state, action, reward, new_state
+                )
+
+                total_rewards += reward
+                step += 1
+
+                # Move on to the new state
+                state = new_state
+
+            # Log all rewards and steps
+            rewards[episode, run] = total_rewards
+            steps[episode, run] = step
+        qtables[run, :, :] = learner.qtable
+
+    return rewards, steps, episodes, qtables, all_states, all_actions
+
+def postprocess(episodes, params, rewards, steps, map_size):
+    """Convert the results of the simulation into dataframes."""
+    res = pd.DataFrame(
+        data={
+            "Episodes": np.tile(episodes, reps=params.n_runs),
+            "Rewards": rewards.flatten(),
+            "Steps": steps.flatten(),
+        }
+    )
+    res["cum_rewards"] = rewards.cumsum(axis=0).flatten(order="F")
+    res["map_size"] = np.repeat(f"{map_size}x{map_size}", res.shape[0])
+
+    st = pd.DataFrame(data={"Episodes": episodes, "Steps": steps.mean(axis=1)})
+    st["map_size"] = np.repeat(f"{map_size}x{map_size}", st.shape[0])
+    return res, st
+
+def qtable_directions_map(qtable, map_size):
+    """Get the best learned action and map it to arrows."""
+    qtable_val_max = qtable.max(axis=1).reshape(map_size, map_size)
+    qtable_best_action = np.argmax(qtable, axis=1).reshape(map_size, map_size)
+    directions = {0: "←", 1: "↓", 2: "→", 3: "↑"}
+    qtable_directions = np.empty(qtable_best_action.flatten().shape, dtype=str)
+    eps = np.finfo(float).eps  # Minimum float number on the machine
+    for idx, val in enumerate(qtable_best_action.flatten()):
+        if qtable_val_max.flatten()[idx] > eps:
+            # Assign an arrow only if a minimal Q-value has been learned as best action;
+            # otherwise, since 0 is a direction, it would also get mapped onto tiles
+            # where nothing was actually learned
+            qtable_directions[idx] = directions[val]
+    qtable_directions = qtable_directions.reshape(map_size, map_size)
+    return qtable_val_max, qtable_directions
+
+def plot_q_values_map(qtable, env, map_size, params):
+    """Plot the last frame of the simulation and the policy learned."""
+    qtable_val_max, qtable_directions = qtable_directions_map(qtable, map_size)
+
+    # Plot the last frame
+    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
+    ax[0].imshow(env.render())
+    ax[0].axis("off")
+    ax[0].set_title("Last frame")
+
+    # Plot the policy
+    sns.heatmap(
+        qtable_val_max,
+        annot=qtable_directions,
+        fmt="",
+        ax=ax[1],
+        cmap=sns.color_palette("Blues", as_cmap=True),
+        linewidths=0.7,
+        linecolor="black",
+        xticklabels=[],
+        yticklabels=[],
+        annot_kws={"fontsize": "xx-large"},
+    ).set(title="Learned Q-values\nArrows represent best action")
+    for _, spine in ax[1].spines.items():
+        spine.set_visible(True)
+        spine.set_linewidth(0.7)
+        spine.set_color("black")
+    img_title = f"frozenlake_q_values_{map_size}x{map_size}.png"
+    fig.savefig(params.savefig_folder / img_title, bbox_inches="tight")
+    plt.show()
+
+def plot_states_actions_distribution(states, actions, map_size, params):
+    """Plot the distributions of states and actions."""
+    labels = {"LEFT": 0, "DOWN": 1, "RIGHT": 2, "UP": 3}
+
+    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
+    sns.histplot(data=states, ax=ax[0], kde=True)
+    ax[0].set_title("States")
+    sns.histplot(data=actions, ax=ax[1])
+    ax[1].set_xticks(list(labels.values()), labels=labels.keys())
+    ax[1].set_title("Actions")
+    fig.tight_layout()
+    img_title = f"frozenlake_states_actions_distrib_{map_size}x{map_size}.png"
+    fig.savefig(params.savefig_folder / img_title, bbox_inches="tight")
+    plt.show()
+
+def plot_steps_and_rewards(rewards_df, steps_df, params):
+    """Plot the steps and rewards from dataframes."""
+    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
+    sns.lineplot(
+        data=rewards_df, x="Episodes", y="cum_rewards", hue="map_size", ax=ax[0]
+    )
+    ax[0].set(ylabel="Cumulated rewards")
+
+    sns.lineplot(data=steps_df, x="Episodes", y="Steps", hue="map_size", ax=ax[1])
+    ax[1].set(ylabel="Averaged steps number")
+
+    for axi in ax:
+        axi.legend(title="map size")
+    fig.tight_layout()
+    img_title = "frozenlake_steps_and_rewards.png"
+    fig.savefig(params.savefig_folder / img_title, bbox_inches="tight")
+    plt.show()
+
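
Editor's note: a minimal usage sketch for this module (not part of the commit), assuming `ml5/qlearning.py` is importable as `qlearning` and that the chosen `savefig_folder` exists before figures are saved; the parameter values are illustrative, not the ones used in the course notebooks.

    from pathlib import Path

    import numpy as np

    import qlearning as ql  # assumes ml5/qlearning.py is on the Python path

    params = ql.Params(
        total_episodes=2000,
        learning_rate=0.8,
        gamma=0.95,
        epsilon=0.1,
        map_size=5,              # the maps passed to run_frozen_maps take precedence
        seed=123,
        is_slippery=False,
        n_runs=10,
        action_size=None,        # filled in by run_frozen_maps from the environment
        state_size=None,         # filled in by run_frozen_maps from the environment
        proba_frozen=0.9,
        savefig_folder=Path("img"),
    )
    params.savefig_folder.mkdir(parents=True, exist_ok=True)  # folder for saved plots
    rng = np.random.default_rng(params.seed)

    # Train on several map sizes, then plot cumulated rewards and average steps
    res_all, st_all = ql.run_frozen_maps([4, 7], params, rng)
    ql.plot_steps_and_rewards(res_all, st_all, params)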