mirror of https://github.com/gsi-upm/sitc synced 2024-11-17 12:02:28 +00:00
sitc/ml5/2_6_1_Q-Learning_Basic.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](images/EscUpmPolit_p.gif \"UPM\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Course Notes for Learning Intelligent Systems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © Carlos Á. Iglesias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Introduction to Machine Learning V](2_6_0_Intro_RL.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction\n",
"The purpose of this practice is to understand better Reinforcement Learning (RL) and, in particular, Q-Learning.\n",
"\n",
"We are going to use [Gymnasium](https://gymnasium.farama.org/). Gymnasium is toolkit for developing and comparing RL algorithms. It is a fork of [Open AI Gym](https://www.gymlibrary.dev/). Take a loot at ther [website](https://github.com/Farama-Foundation/Gymnasium).\n",
"\n",
"It implements [algorithm imitation](http://gym.openai.com/envs/#algorithmic), [classic control problems](https://gymnasium.farama.org/environments/classic_control/), [Atari games](https://gymnasium.farama.org/environments/atari/), [Box2D continuous control](https://gymnasium.farama.org/environments/box2d/), [robotics with MuJoCo, Multi-Joint dynamics with Contact](https://gymnasium.farama.org/environments/mujoco/), [simple text based environments](https://gymnasium.farama.org/environments/toy_text/), and [other problems](https://gymnasium.farama.org/environments/third_party_environments/).\n",
"\n",
"This notebook is based on [Diving deeper into Reinforcement Learning with Q-Learning](https://medium.freecodecamp.org/diving-deeper-into-reinforcement-learning-with-q-learning-c18d0db58efe) and [Introduction to Q-Learning](https://huggingface.co/deep-rl-course/unit2/hands-on?fw=pt).\n",
"\n",
"First of all, install the Gymnasium library, which is a fork of the OpenAI Gym library:\n",
"\n",
"```console\n",
"foo@bar:~$ conda install gymnasium\n",
"```\n",
"\n",
"If you get an error 'No module named 'Box2D', install 'pybox2d'.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting started with Gymnasium\n",
"\n",
"First of all, read the [introduction](https://gymnasium.farama.org/content/basic_usage/) of Gymnasium."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Environments\n",
"OpenGym provides a number of problems called *environments*. \n",
"\n",
"Try 'LunarLander-v2' (or 'CartPole-v1', 'MountainCar-v0', etc.)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from warnings import filterwarnings\n",
"filterwarnings(action='ignore', category=DeprecationWarning)\n",
"\n",
"\n",
"import gymnasium as gym\n",
"#env = gym.make(\"LunarLander-v2\", render_mode=\"human\")\n",
"#env = gym.make(\"CartPole-v1\", render_mode=\"human\")\n",
"env = gym.make(\"MountainCar-v0\", render_mode=\"human\")\n",
"observation, info = env.reset()\n",
"\n",
"for _ in range(100):\n",
" action = env.action_space.sample() # agent policy that uses the observation and info\n",
" observation, reward, terminated, truncated, info = env.step(action)\n",
"\n",
" if terminated or truncated:\n",
" observation, info = env.reset()\n",
"\n",
"env.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This will launch an external window with the game. If you cannot close that window, just execute in a code cell:\n",
"\n",
"```python\n",
"env.close()\n",
"```\n",
"\n",
"The full list of available environments can be found printing the environment registry as follows."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['CartPole-v0', 'CartPole-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Pendulum-v1', 'Acrobot-v1', 'CartPoleJax-v0', 'CartPoleJax-v1', 'PendulumJax-v0', 'LunarLander-v2', 'LunarLanderContinuous-v2', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3', 'CarRacing-v2', 'Blackjack-v1', 'FrozenLake-v1', 'FrozenLake8x8-v1', 'CliffWalking-v0', 'Taxi-v3', 'Reacher-v2', 'Reacher-v4', 'Pusher-v2', 'Pusher-v4', 'InvertedPendulum-v2', 'InvertedPendulum-v4', 'InvertedDoublePendulum-v2', 'InvertedDoublePendulum-v4', 'HalfCheetah-v2', 'HalfCheetah-v3', 'HalfCheetah-v4', 'Hopper-v2', 'Hopper-v3', 'Hopper-v4', 'Swimmer-v2', 'Swimmer-v3', 'Swimmer-v4', 'Walker2d-v2', 'Walker2d-v3', 'Walker2d-v4', 'Ant-v2', 'Ant-v3', 'Ant-v4', 'Humanoid-v2', 'Humanoid-v3', 'Humanoid-v4', 'HumanoidStandup-v2', 'HumanoidStandup-v4', 'GymV22Environment-v0', 'GymV26Environment-v0'])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gym.envs.registry.keys()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can check the environment specification with spec."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"EnvSpec(id='MountainCar-v0', entry_point='gymnasium.envs.classic_control.mountain_car:MountainCarEnv', reward_threshold=-110.0, nondeterministic=False, max_episode_steps=200, order_enforce=True, autoreset=False, disable_env_checker=False, apply_api_compatibility=False, kwargs={'render_mode': 'human'}, namespace=None, name='MountainCar', version=0)\n"
]
}
],
"source": [
"print(env.spec)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The environments **step** function returns five values. These are:\n",
"\n",
"* **observation (object):** an environment-specific object representing your observation of the environment. For example, pixel data from a camera, joint angles and joint velocities of a robot, or the board state in a board game.\n",
"* **reward (float):** amount of reward achieved by the previous action. The scale varies between environments, but the goal is always to increase your total reward.\n",
"* **terminated (boolean):** whether the agent reaches the terminal state, which can be positive (e.g., reaching the goal state) or negative (e.g., you lost your last life). If true, the user needs to call *reset()*.\n",
"* **truncated (boolean):** when the truncated condition is satisfied (e.g., timelimit or mechanical problem in a robot). It can be used to end the episode prematurely before a terminal state is reached . If true, the user needs to call *reset()*.\n",
"* **info (dict):** diagnostic information useful for debugging. It can sometimes be useful for learning (for example, it might contain the raw probabilities behind the environments last state change). However, official evaluations of your agent are not allowed to use this for learning.\n",
"\n",
"The typical agent loop consists in first calling the method *reset* which provides an initial observation. Then the agent executes an action, and receives the reward, the new observation, and if the episode has finished (terminated or truncated are true). \n",
"\n",
"For example, analyze this sample of agent loop for 100 ms. The details of the previous variables for this game as described [here](https://gymnasium.farama.org/environments/classic_control/cart_pole/) are:\n",
"* **observation**: Cart Position, Cart Velocity, Pole Angle, Pole Velocity.\n",
"* **action**: 0\t(Push cart to the left), 1\t(Push cart to the right).\n",
"* **reward**: 1 for every step taken, including the termination step."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(array([0.02466699, 0.02517318, 0.04970684, 0.01225195], dtype=float32), {})\n",
"Action 0\n",
"Observation [ 0.02517045 -0.17062509 0.04995188 0.3201944 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02517045 -0.17062509 0.04995188 0.3201944 ]\n",
"Action 0\n",
"Observation [ 0.02175795 -0.36642152 0.05635576 0.62820244] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02175795 -0.36642152 0.05635576 0.62820244]\n",
"Action 0\n",
"Observation [ 0.01442952 -0.56228286 0.06891982 0.9380878 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01442952 -0.56228286 0.06891982 0.9380878 ]\n",
"Action 1\n",
"Observation [ 0.00318386 -0.3681544 0.08768157 0.66783285] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00318386 -0.3681544 0.08768157 0.66783285]\n",
"Action 1\n",
"Observation [-0.00417923 -0.17435414 0.10103822 0.4039946 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.00417923 -0.17435414 0.10103822 0.4039946 ]\n",
"Action 0\n",
"Observation [-0.00766631 -0.37075302 0.10911812 0.7267452 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.00766631 -0.37075302 0.10911812 0.7267452 ]\n",
"Action 1\n",
"Observation [-0.01508137 -0.17729534 0.12365302 0.47030163] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.01508137 -0.17729534 0.12365302 0.47030163]\n",
"Action 0\n",
"Observation [-0.01862728 -0.37392715 0.13305905 0.79925877] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.01862728 -0.37392715 0.13305905 0.79925877]\n",
"Action 1\n",
"Observation [-0.02610582 -0.18085699 0.14904423 0.55121744] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.02610582 -0.18085699 0.14904423 0.55121744]\n",
"Action 0\n",
"Observation [-0.02972296 -0.37772328 0.16006857 0.8869 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([-0.0286674 , 0.02992311, -0.01944724, -0.04998275], dtype=float32), {})\n",
"Action 1\n",
"Observation [-0.02806894 0.22531843 -0.0204469 -0.34873745] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.02806894 0.22531843 -0.0204469 -0.34873745]\n",
"Action 1\n",
"Observation [-0.02356257 0.42072514 -0.02742165 -0.6477972 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.02356257 0.42072514 -0.02742165 -0.6477972 ]\n",
"Action 0\n",
"Observation [-0.01514807 0.22599575 -0.04037759 -0.3638739 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.01514807 0.22599575 -0.04037759 -0.3638739 ]\n",
"Action 0\n",
"Observation [-0.01062815 0.03147022 -0.04765507 -0.0841912 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.01062815 0.03147022 -0.04765507 -0.0841912 ]\n",
"Action 1\n",
"Observation [-0.00999875 0.22724174 -0.0493389 -0.39152038] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.00999875 0.22724174 -0.0493389 -0.39152038]\n",
"Action 1\n",
"Observation [-0.00545391 0.4230279 -0.0571693 -0.699342 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.00545391 0.4230279 -0.0571693 -0.699342 ]\n",
"Action 1\n",
"Observation [ 0.00300664 0.6188939 -0.07115614 -1.0094596 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00300664 0.6188939 -0.07115614 -1.0094596 ]\n",
"Action 1\n",
"Observation [ 0.01538452 0.8148897 -0.09134534 -1.3236117 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01538452 0.8148897 -0.09134534 -1.3236117 ]\n",
"Action 1\n",
"Observation [ 0.03168232 1.0110391 -0.11781757 -1.6434273 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.03168232 1.0110391 -0.11781757 -1.6434273 ]\n",
"Action 1\n",
"Observation [ 0.0519031 1.207327 -0.15068612 -1.9703763 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([ 0.04422636, 0.03184601, -0.00627003, 0.03030435], dtype=float32), {})\n",
"Action 1\n",
"Observation [ 0.04486328 0.22705732 -0.00566394 -0.26435024] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04486328 0.22705732 -0.00566394 -0.26435024]\n",
"Action 0\n",
"Observation [ 0.04940443 0.03201666 -0.01095094 0.02654087] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04940443 0.03201666 -0.01095094 0.02654087]\n",
"Action 0\n",
"Observation [ 0.05004476 -0.16294654 -0.01042013 0.31574863] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.05004476 -0.16294654 -0.01042013 0.31574863]\n",
"Action 1\n",
"Observation [ 0.04678584 0.03232227 -0.00410515 0.01979785] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04678584 0.03232227 -0.00410515 0.01979785]\n",
"Action 1\n",
"Observation [ 0.04743228 0.22750285 -0.0037092 -0.27417746] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04743228 0.22750285 -0.0037092 -0.27417746]\n",
"Action 0\n",
"Observation [ 0.05198234 0.03243402 -0.00919275 0.01733326] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.05198234 0.03243402 -0.00919275 0.01733326]\n",
"Action 1\n",
"Observation [ 0.05263102 0.2276866 -0.00884608 -0.27823585] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.05263102 0.2276866 -0.00884608 -0.27823585]\n",
"Action 1\n",
"Observation [ 0.05718475 0.4229336 -0.0144108 -0.57369566] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.05718475 0.4229336 -0.0144108 -0.57369566]\n",
"Action 1\n",
"Observation [ 0.06564342 0.6182546 -0.02588471 -0.87088335] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.06564342 0.6182546 -0.02588471 -0.87088335]\n",
"Action 0\n",
"Observation [ 0.07800851 0.42349413 -0.04330238 -0.58644974] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([ 0.00689714, -0.004671 , -0.03079236, 0.01346231], dtype=float32), {})\n",
"Action 0\n",
"Observation [ 0.00680372 -0.19933812 -0.03052311 0.29627305] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00680372 -0.19933812 -0.03052311 0.29627305]\n",
"Action 1\n",
"Observation [ 0.00281696 -0.00379464 -0.02459765 -0.00587795] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00281696 -0.00379464 -0.02459765 -0.00587795]\n",
"Action 0\n",
"Observation [ 0.00274106 -0.19855535 -0.02471521 0.27894375] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00274106 -0.19855535 -0.02471521 0.27894375]\n",
"Action 0\n",
"Observation [-0.00123004 -0.39331618 -0.01913633 0.56373024] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.00123004 -0.39331618 -0.01913633 0.56373024]\n",
"Action 1\n",
"Observation [-0.00909637 -0.197931 -0.00786173 0.26508042] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.00909637 -0.197931 -0.00786173 0.26508042]\n",
"Action 0\n",
"Observation [-0.01305499 -0.39293987 -0.00256012 0.55527335] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.01305499 -0.39293987 -0.00256012 0.55527335]\n",
"Action 0\n",
"Observation [-0.02091378 -0.5880258 0.00854535 0.8471486 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.02091378 -0.5880258 0.00854535 0.8471486 ]\n",
"Action 1\n",
"Observation [-0.0326743 -0.39302143 0.02548832 0.557165 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.0326743 -0.39302143 0.02548832 0.557165 ]\n",
"Action 0\n",
"Observation [-0.04053473 -0.58849174 0.03663162 0.85776806] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.04053473 -0.58849174 0.03663162 0.85776806]\n",
"Action 0\n",
"Observation [-0.05230456 -0.7840931 0.05378698 1.1617405 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([ 0.01617916, -0.01409287, -0.04266112, 0.00518782], dtype=float32), {})\n",
"Action 1\n",
"Observation [ 0.01589731 0.18161412 -0.04255737 -0.30064413] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01589731 0.18161412 -0.04255737 -0.30064413]\n",
"Action 0\n",
"Observation [ 0.01952959 -0.01287623 -0.04857025 -0.02168084] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01952959 -0.01287623 -0.04857025 -0.02168084]\n",
"Action 1\n",
"Observation [ 0.01927206 0.1829074 -0.04900387 -0.329284 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01927206 0.1829074 -0.04900387 -0.329284 ]\n",
"Action 0\n",
"Observation [ 0.02293021 -0.01148394 -0.05558955 -0.05244839] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02293021 -0.01148394 -0.05558955 -0.05244839]\n",
"Action 0\n",
"Observation [ 0.02270053 -0.20576656 -0.05663851 0.22219047] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02270053 -0.20576656 -0.05663851 0.22219047]\n",
"Action 1\n",
"Observation [ 0.0185852 -0.00988272 -0.0521947 -0.08780695] , reward 1.0 , terminated False , truncated False , info {}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.0185852 -0.00988272 -0.0521947 -0.08780695]\n",
"Action 0\n",
"Observation [ 0.01838755 -0.20421918 -0.05395084 0.18796247] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01838755 -0.20421918 -0.05395084 0.18796247]\n",
"Action 0\n",
"Observation [ 0.01430316 -0.3985294 -0.05019159 0.46314988] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01430316 -0.3985294 -0.05019159 0.46314988]\n",
"Action 1\n",
"Observation [ 0.00633258 -0.2027354 -0.04092859 0.1550786 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00633258 -0.2027354 -0.04092859 0.1550786 ]\n",
"Action 0\n",
"Observation [ 0.00227787 -0.39724815 -0.03782702 0.43457374] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([-0.01675624, 0.00657571, -0.01191216, -0.03199182], dtype=float32), {})\n",
"Action 1\n",
"Observation [-0.01662473 0.20186645 -0.012552 -0.32840922] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.01662473 0.20186645 -0.012552 -0.32840922]\n",
"Action 1\n",
"Observation [-0.0125874 0.39716482 -0.01912018 -0.6250239 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.0125874 0.39716482 -0.01912018 -0.6250239 ]\n",
"Action 1\n",
"Observation [-0.0046441 0.59254843 -0.03162066 -0.9236667 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.0046441 0.59254843 -0.03162066 -0.9236667 ]\n",
"Action 1\n",
"Observation [ 0.00720687 0.7880829 -0.050094 -1.2261168 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00720687 0.7880829 -0.050094 -1.2261168 ]\n",
"Action 1\n",
"Observation [ 0.02296852 0.98381275 -0.07461634 -1.5340647 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02296852 0.98381275 -0.07461634 -1.5340647 ]\n",
"Action 1\n",
"Observation [ 0.04264478 1.17975 -0.10529763 -1.8490696 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04264478 1.17975 -0.10529763 -1.8490696 ]\n",
"Action 1\n",
"Observation [ 0.06623978 1.3758619 -0.14227901 -2.1725085 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.06623978 1.3758619 -0.14227901 -2.1725085 ]\n",
"Action 1\n",
"Observation [ 0.09375702 1.5720552 -0.18572919 -2.505514 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.09375702 1.5720552 -0.18572919 -2.505514 ]\n",
"Action 0\n",
"Observation [ 0.12519813 1.3788872 -0.23583947 -2.2750359 ] , reward 1.0 , terminated True , truncated False , info {}\n",
"Episode finished after 9 timesteps\n",
"(array([ 0.01553306, -0.03829413, 0.01700553, 0.01151424], dtype=float32), {})\n",
"Action 0\n",
"Observation [ 0.01476718 -0.23365578 0.01723581 0.30951375] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01476718 -0.23365578 0.01723581 0.30951375]\n",
"Action 1\n",
"Observation [ 0.01009406 -0.03878359 0.02342609 0.02231595] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01009406 -0.03878359 0.02342609 0.02231595]\n",
"Action 0\n",
"Observation [ 0.00931839 -0.23423353 0.02387241 0.32229704] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00931839 -0.23423353 0.02387241 0.32229704]\n",
"Action 1\n",
"Observation [ 0.00463372 -0.03945951 0.03031835 0.03723709] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00463372 -0.03945951 0.03031835 0.03723709]\n",
"Action 1\n",
"Observation [ 0.00384453 0.15521485 0.03106309 -0.24572802] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00384453 0.15521485 0.03106309 -0.24572802]\n",
"Action 1\n",
"Observation [ 0.00694883 0.34987968 0.02614853 -0.52845335] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.00694883 0.34987968 0.02614853 -0.52845335]\n",
"Action 0\n",
"Observation [ 0.01394642 0.1543998 0.01557946 -0.22764695] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01394642 0.1543998 0.01557946 -0.22764695]\n",
"Action 1\n",
"Observation [ 0.01703442 0.34929568 0.01102652 -0.51537514] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.01703442 0.34929568 0.01102652 -0.51537514]\n",
"Action 1\n",
"Observation [ 2.4020329e-02 5.4426062e-01 7.1901991e-04 -8.0456305e-01] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 2.4020329e-02 5.4426062e-01 7.1901991e-04 -8.0456305e-01]\n",
"Action 1\n",
"Observation [ 0.03490554 0.73937273 -0.01537224 -1.0970197 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([ 0.04527323, 0.02190002, -0.03927047, 0.01146642], dtype=float32), {})\n",
"Action 1\n",
"Observation [ 0.04571123 0.21756251 -0.03904114 -0.2933436 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04571123 0.21756251 -0.03904114 -0.2933436 ]\n",
"Action 1\n",
"Observation [ 0.05006249 0.4132187 -0.04490801 -0.59807944] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.05006249 0.4132187 -0.04490801 -0.59807944]\n",
"Action 0\n",
"Observation [ 0.05832686 0.21875295 -0.0568696 -0.31987342] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.05832686 0.21875295 -0.0568696 -0.31987342]\n",
"Action 1\n",
"Observation [ 0.06270192 0.41463676 -0.06326707 -0.6299348 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.06270192 0.41463676 -0.06326707 -0.6299348 ]\n",
"Action 0\n",
"Observation [ 0.07099465 0.22045206 -0.07586577 -0.3578286 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.07099465 0.22045206 -0.07586577 -0.3578286 ]\n",
"Action 1\n",
"Observation [ 0.0754037 0.41656595 -0.08302233 -0.6734364 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.0754037 0.41656595 -0.08302233 -0.6734364 ]\n",
"Action 1\n",
"Observation [ 0.08373501 0.6127377 -0.09649107 -0.99106103] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.08373501 0.6127377 -0.09649107 -0.99106103]\n",
"Action 0\n",
"Observation [ 0.09598977 0.41903025 -0.11631229 -0.7301758 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.09598977 0.41903025 -0.11631229 -0.7301758 ]\n",
"Action 1\n",
"Observation [ 0.10437037 0.61555123 -0.1309158 -1.0570843 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.10437037 0.61555123 -0.1309158 -1.0570843 ]\n",
"Action 1\n",
"Observation [ 0.1166814 0.8121419 -0.15205748 -1.3878263 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([ 0.0463822 , 0.03024504, 0.03519755, -0.02215011], dtype=float32), {})\n",
"Action 0\n",
"Observation [ 0.0469871 -0.16536354 0.03475455 0.28142697] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.0469871 -0.16536354 0.03475455 0.28142697]\n",
"Action 0\n",
"Observation [ 0.04367983 -0.36096355 0.04038309 0.58486557] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04367983 -0.36096355 0.04038309 0.58486557]\n",
"Action 0\n",
"Observation [ 0.03646055 -0.5566272 0.0520804 0.8899912 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.03646055 -0.5566272 0.0520804 0.8899912 ]\n",
"Action 0\n",
"Observation [ 0.02532801 -0.75241566 0.06988022 1.1985804 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02532801 -0.75241566 0.06988022 1.1985804 ]\n",
"Action 1\n",
"Observation [ 0.0102797 -0.5582641 0.09385183 0.92859185] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.0102797 -0.5582641 0.09385183 0.92859185]\n",
"Action 1\n",
"Observation [-0.00088559 -0.36452585 0.11242367 0.6668154 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.00088559 -0.36452585 0.11242367 0.6668154 ]\n",
"Action 0\n",
"Observation [-0.0081761 -0.561017 0.12575997 0.99267435] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.0081761 -0.561017 0.12575997 0.99267435]\n",
"Action 1\n",
"Observation [-0.01939644 -0.3677815 0.14561346 0.7419863 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.01939644 -0.3677815 0.14561346 0.7419863 ]\n",
"Action 1\n",
"Observation [-0.02675207 -0.17493759 0.16045319 0.49844098] , reward 1.0 , terminated False , truncated False , info {}\n",
"[-0.02675207 -0.17493759 0.16045319 0.49844098]\n",
"Action 1\n",
"Observation [-0.03025082 0.01760164 0.170422 0.26031297] , reward 1.0 , terminated False , truncated False , info {}\n",
"(array([ 0.02045374, 0.02493249, -0.02570103, 0.03014062], dtype=float32), {})\n",
"Action 1\n",
"Observation [ 0.02095238 0.2204134 -0.02509822 -0.27053916] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02095238 0.2204134 -0.02509822 -0.27053916]\n",
"Action 1\n",
"Observation [ 0.02536065 0.41588435 -0.030509 -0.57103133] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.02536065 0.41588435 -0.030509 -0.57103133]\n",
"Action 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Observation [ 0.03367834 0.22120321 -0.04192962 -0.2881138 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.03367834 0.22120321 -0.04192962 -0.2881138 ]\n",
"Action 0\n",
"Observation [ 0.0381024 0.0267035 -0.0476919 -0.00894437] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.0381024 0.0267035 -0.0476919 -0.00894437]\n",
"Action 0\n",
"Observation [ 0.03863648 -0.16770318 -0.04787079 0.268318 ] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.03863648 -0.16770318 -0.04787079 0.268318 ]\n",
"Action 1\n",
"Observation [ 0.03528241 0.02806809 -0.04250443 -0.03907115] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.03528241 0.02806809 -0.04250443 -0.03907115]\n",
"Action 1\n",
"Observation [ 0.03584377 0.22377296 -0.04328585 -0.34485587] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.03584377 0.22377296 -0.04328585 -0.34485587]\n",
"Action 0\n",
"Observation [ 0.04031923 0.02929264 -0.05018297 -0.06613071] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04031923 0.02929264 -0.05018297 -0.06613071]\n",
"Action 0\n",
"Observation [ 0.04090508 -0.16507524 -0.05150558 0.21030648] , reward 1.0 , terminated False , truncated False , info {}\n",
"[ 0.04090508 -0.16507524 -0.05150558 0.21030648]\n",
"Action 0\n",
"Observation [ 0.03760358 -0.35942435 -0.04729946 0.48630762] , reward 1.0 , terminated False , truncated False , info {}\n"
]
}
],
"source": [
"import gymnasium as gym\n",
"\n",
"env = gym.make('CartPole-v1', render_mode='human')\n",
"for i_episode in range(10):\n",
" \n",
" \n",
" observation = env.reset()\n",
" for t in range(10):\n",
" env.render()\n",
" print(observation)\n",
" action = env.action_space.sample()\n",
" print(\"Action \", action)\n",
" observation, reward, terminated, truncated, info = env.step(action)\n",
" print(\"Observation \", observation, \", reward \", reward, \", terminated \", terminated,\n",
" \", truncated\", truncated, \", info \" , info)\n",
" if terminated or truncated:\n",
" print(\"Episode finished after {} timesteps\".format(t+1))\n",
" break"
]
},
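{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same *reset*/*step* contract can be tried out without opening a window. The sketch below uses a hypothetical toy environment, `LineWalkEnv` (made up for illustration, not part of Gymnasium), whose `reset` and `step` methods return values in the same shapes as Gymnasium's. The agent must move right until it reaches `goal`, which sets *terminated*. If `max_steps` steps pass first, the episode is cut off with *truncated*, as a time limit would do.\n",
"\n",
"```python\n",
"class LineWalkEnv:\n",
"    # Hypothetical toy environment mimicking Gymnasium's reset/step return values\n",
"    def __init__(self, goal=3, max_steps=5):\n",
"        self.goal = goal\n",
"        self.max_steps = max_steps\n",
"\n",
"    def reset(self):\n",
"        self.position = 0\n",
"        self.steps = 0\n",
"        return self.position, {}  # (observation, info)\n",
"\n",
"    def step(self, action):\n",
"        # action 1 moves one cell to the right, action 0 stays in place\n",
"        self.position += action\n",
"        self.steps += 1\n",
"        terminated = self.position >= self.goal    # terminal state reached\n",
"        truncated = self.steps >= self.max_steps   # time limit reached\n",
"        reward = 1.0 if terminated else 0.0\n",
"        return self.position, reward, terminated, truncated, {'steps': self.steps}\n",
"\n",
"\n",
"def run_episode(env, policy):\n",
"    # Typical agent loop: reset once, then act until terminated or truncated\n",
"    observation, info = env.reset()\n",
"    total_reward = 0.0\n",
"    while True:\n",
"        action = policy(observation)\n",
"        observation, reward, terminated, truncated, info = env.step(action)\n",
"        total_reward += reward\n",
"        if terminated or truncated:\n",
"            return total_reward, terminated, truncated, info['steps']\n",
"\n",
"\n",
"toy_env = LineWalkEnv()\n",
"print(run_episode(toy_env, lambda obs: 1))  # (1.0, True, False, 3)\n",
"print(run_episode(toy_env, lambda obs: 0))  # (0.0, False, True, 5)\n",
"```\n",
"\n",
"A policy that always moves right reaches the goal and terminates after 3 steps. A policy that never moves is truncated by the time limit after 5 steps and earns no reward."
]
},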
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The Frozen Lake scenario\n",
"We are going to play to the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) game.\n",
"\n",
"The problem is a grid where you should go from the 'start' (S) position to the 'goal position (G) (the pizza!). You can only walk through the 'frozen tiles' (F). Unfortunately, you can fall in a 'hole' (H).\n",
"![](images/frozenlake-problem.png \"Frozen lake problem\")\n",
"\n",
"The episode ends when you reach the goal or fall in a hole. You receive a reward of 1 if you reach the goal, and zero otherwise. The possible actions are going left, right, up or down. However, the ice is slippery, so you won't always move in the direction you intend.\n",
"\n",
"![](images/frozenlake-world.png \"Frozen lake world\")\n",
"\n",
"\n",
"Here you can see several episodes. A full recording is available at [Frozen World](http://gym.openai.com/envs/FrozenLake-v0/).\n",
"\n",
"![](images/recording.gif \"Example running\")\n"
]
},
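{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 4x4 map that the code below loads with `map_name=\"4x4\"` can be written as four strings, one per row. According to the environment's documentation, the observation is the agent's position encoded as `row * ncols + col` (counting from 0), so the 16 tiles are numbered row by row. The actions are 0 (left), 1 (down), 2 (right) and 3 (up). The following sketch is a hypothetical re-implementation of the **non-slippery** dynamics. `frozen_step` is our own helper, not Gymnasium code, and is meant only to make the state numbering and the reward concrete. Moving into a wall leaves the agent where it is.\n",
"\n",
"```python\n",
"MAP_4X4 = ['SFFF',\n",
"           'FHFH',\n",
"           'FFFH',\n",
"           'HFFG']\n",
"NROWS, NCOLS = len(MAP_4X4), len(MAP_4X4[0])\n",
"LEFT, DOWN, RIGHT, UP = 0, 1, 2, 3  # Gymnasium's action numbering\n",
"\n",
"\n",
"def to_state(row, col):\n",
"    return row * NCOLS + col\n",
"\n",
"\n",
"def frozen_step(state, action):\n",
"    # Hypothetical non-slippery move; walls clamp the position to the grid\n",
"    row, col = divmod(state, NCOLS)\n",
"    if action == LEFT:\n",
"        col = max(col - 1, 0)\n",
"    elif action == DOWN:\n",
"        row = min(row + 1, NROWS - 1)\n",
"    elif action == RIGHT:\n",
"        col = min(col + 1, NCOLS - 1)\n",
"    elif action == UP:\n",
"        row = max(row - 1, 0)\n",
"    tile = MAP_4X4[row][col]\n",
"    terminated = tile in 'GH'             # goal or hole ends the episode\n",
"    reward = 1.0 if tile == 'G' else 0.0  # only reaching the goal is rewarded\n",
"    return to_state(row, col), reward, terminated\n",
"\n",
"\n",
"holes = [s for s in range(NROWS * NCOLS) if MAP_4X4[s // NCOLS][s % NCOLS] == 'H']\n",
"print(holes)                    # [5, 7, 11, 12]\n",
"print(frozen_step(14, RIGHT))   # (15, 1.0, True)\n",
"print(frozen_step(0, LEFT))     # (0, 0.0, False)\n",
"```\n",
"\n",
"With this numbering, the holes are states 5, 7, 11 and 12 and the goal is state 15. These are exactly the Q-table rows that stay at zero after training."
]
},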
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning with the Frozen Lake scenario\n",
"We are now going to apply Q-Learning for the Frozen Lake scenario. This part of the notebook is taken from [here](https://github.com/simoninithomas/Deep_reinforcement_learning_Course/blob/master/Q%20learning/Q%20Learning%20with%20FrozenLake.ipynb). You can get more details about this scenario [here](https://gymnasium.farama.org/environments/toy_text/frozen_lake/).\n",
"\n",
"First we create the environment and a Q-table inizializated with zeros to store the value of each action in a given state. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 16 possible states\n",
"There are 4 possible actions\n",
"QTable\n",
"[[0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]\n",
" [0. 0. 0. 0.]]\n"
]
}
],
"source": [
"import numpy as np\n",
"import gymnasium as gym\n",
"import random\n",
"\n",
"env = gym.make(\"FrozenLake-v1\", desc=None, map_name=\"4x4\", is_slippery=False) #no render so training is faster\n",
"#env = gym.make(\"FrozenLake-v1\", desc=None, map_name=\"4x4\", is_slippery=False, render_mode='human')\n",
"#env = gym.make(\"FrozenLake-v1\", desc=None, map_name=\"4x4\", is_slippery=False, render_mode='rgb_array')\n",
"\n",
"\n",
"action_space = env.action_space.n\n",
"state_space = env.observation_space.n\n",
"\n",
"print(\"There are \", state_space, \" possible states\")\n",
"print(\"There are \", action_space, \" possible actions\")\n",
"\n",
"# Step 1. Initialize QTable\n",
"qtable = np.zeros((state_space, action_space))\n",
"print(\"QTable\")\n",
"print(qtable)"
]
},
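{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each row of the Q-table corresponds to a state and each column to an action. Once training has filled in the table, the greedy policy picks the column with the largest value in each row. The sketch below extracts that policy and lays it out on the 4x4 grid, one letter per tile: L, D, R or U for actions 0 to 3. The helpers `greedy_policy` and `policy_grid` are hypothetical names. Like `np.argmax`, the sketch returns the first maximal index, so all-zero rows (holes, the goal, or states never visited) show up as L.\n",
"\n",
"```python\n",
"ACTION_LETTERS = 'LDRU'  # 0=left, 1=down, 2=right, 3=up\n",
"\n",
"\n",
"def greedy_policy(qtable):\n",
"    # For each state (row), pick the first action with the largest Q-value\n",
"    return [max(range(len(row)), key=lambda a: row[a]) for row in qtable]\n",
"\n",
"\n",
"def policy_grid(policy, ncols=4):\n",
"    # Lay the policy out as one string of action letters per grid row\n",
"    letters = ''.join(ACTION_LETTERS[a] for a in policy)\n",
"    return [letters[i:i + ncols] for i in range(0, len(letters), ncols)]\n",
"\n",
"\n",
"example_q = [[0.0, 0.2, 0.1, 0.0]] * 4 + [[0.0] * 4] * 12\n",
"print(policy_grid(greedy_policy(example_q)))  # ['DDDD', 'LLLL', 'LLLL', 'LLLL']\n",
"```\n",
"\n",
"After training, `for row in policy_grid(greedy_policy(qtable)): print(row)` shows whether the letters lead from S to G while avoiding the holes. The helpers accept a NumPy array as well as a list of lists."
]
},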
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define the hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Q-Learning hyperparameters\n",
"total_episodes = 10000 # Total episodes\n",
"learning_rate = 0.8 # Learning rate\n",
"max_steps = 99 # Max steps per episode\n",
"gamma = 0.95 # Discounting rate\n",
"\n",
"# Exploration hyperparameters\n",
"epsilon = 1.0 # Exploration rate\n",
"max_epsilon = 1.0 # Exploration probability at start\n",
"min_epsilon = 0.01 # Minimum exploration probability \n",
"decay_rate = 0.01 # Exponential decay rate for exploration prob"
]
},
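{
"cell_type": "markdown",
"metadata": {},
"source": [
"`epsilon` starts at `max_epsilon` and should shrink towards `min_epsilon` as the agent learns, so that the agent explores a lot at first and exploits its Q-table later. A common way to use `decay_rate` is an exponential schedule, sketched below. This formula is an assumption, since this part of the notebook does not show how the training loop updates `epsilon`. `decayed_epsilon` is a hypothetical helper whose defaults mirror the hyperparameters above.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def decayed_epsilon(episode, min_epsilon=0.01, max_epsilon=1.0, decay_rate=0.01):\n",
"    # Exponential decay from max_epsilon at episode 0 towards min_epsilon\n",
"    return min_epsilon + (max_epsilon - min_epsilon) * math.exp(-decay_rate * episode)\n",
"\n",
"\n",
"for n in (0, 100, 500, 1000):\n",
"    print(n, round(decayed_epsilon(n), 4))\n",
"```\n",
"\n",
"With these values, epsilon falls to about 0.37 after 100 episodes and below 0.02 after 500. As a result, the agent spends most of its 10,000 training episodes exploiting the Q-table. A smaller `decay_rate` keeps the agent exploring for longer."
]
},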
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we implement the Q-Learning algorithm.\n",
"\n",
"![](images/qlearning-algo.png \"Q-Learning algorithm\")"
]
},
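{
"cell_type": "markdown",
"metadata": {},
"source": [
"The core of the algorithm is the update of a single Q-value. After taking action $a$ in state $s$ and observing the reward $r$ and the new state $s'$, Q-Learning moves $Q(s,a)$ a fraction $\\alpha$ (the `learning_rate`) of the way towards the target $r + \\gamma \\max_{a'} Q(s',a')$:\n",
"\n",
"$$Q(s,a) \\leftarrow Q(s,a) + \\alpha \\left[ r + \\gamma \\max_{a'} Q(s',a') - Q(s,a) \\right]$$\n",
"\n",
"Here is a minimal sketch of this update on a table stored as a list of lists; it also works on a NumPy array. `q_update` is a hypothetical helper. Dropping the future term when $s'$ is terminal is an assumption of this sketch. In this notebook it makes no difference, because the rows of terminal states stay at zero, as the learned table printed after training confirms.\n",
"\n",
"```python\n",
"def q_update(qtable, state, action, reward, new_state, terminated,\n",
"             learning_rate=0.8, gamma=0.95):\n",
"    # Target: immediate reward plus the discounted best value of the next state\n",
"    future = 0.0 if terminated else gamma * max(qtable[new_state])\n",
"    td_error = reward + future - qtable[state][action]\n",
"    qtable[state][action] += learning_rate * td_error\n",
"    return qtable[state][action]\n",
"\n",
"\n",
"q = [[0.0] * 4 for _ in range(16)]\n",
"print(round(q_update(q, 14, 2, 1.0, 15, True), 4))   # 0.8\n",
"print(round(q_update(q, 14, 2, 1.0, 15, True), 4))   # 0.96\n",
"print(round(q_update(q, 13, 2, 0.0, 14, False), 4))  # 0.7296\n",
"```\n",
"\n",
"Each update moves a value closer to its target. In this deterministic map, the targets settle at powers of $\\gamma$. In the table printed after training, $Q(14, \\text{right}) = 1$, $Q(13, \\text{right}) = 0.95 = \\gamma$ and $Q(9, \\text{down}) = 0.9025 = \\gamma^2$: each extra step to the goal adds one more discount."
]
},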
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"state, info = env.reset() #reset returns observation, info\n",
"step = 0\n",
"done = False\n",
"total_rewards = 0"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/10000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score over time: 0.9765\n",
"[[0.73509189 0.77378094 0.6983373 0.73509189]\n",
" [0.73509189 0. 0. 0. ]\n",
" [0.5586684 0. 0. 0. ]\n",
" [0. 0. 0. 0. ]\n",
" [0.77378094 0.81450625 0. 0.73509189]\n",
" [0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. ]\n",
" [0.81450625 0. 0.857375 0.77378094]\n",
" [0.81450625 0.9025 0.81450625 0. ]\n",
" [0.857375 0. 0. 0. ]\n",
" [0. 0. 0. 0. ]\n",
" [0. 0. 0. 0. ]\n",
" [0. 0.9025 0.95 0.857375 ]\n",
" [0.9025 0.95 1. 0.81450625]\n",
" [0. 0. 0. 0. ]]\n"
]
}
],
"source": [
"# Training: learn the Q-table\n",
"from tqdm.notebook import tqdm\n",
"# List of rewards\n",
"rewards = []\n",
"\n",
"# Step 2. For life or until learning is stopped\n",
"for episode in tqdm(range(total_episodes)):\n",
" # Reset the environment\n",
" state, info = env.reset() #reset returns observation, info\n",
" step = 0\n",
" done = False\n",
" total_rewards = 0\n",
" \n",
" for step in range(max_steps):\n",
" # 3. Choose an action a in the current world state (s)\n",
" ## First we randomize a number\n",
" exp_exp_tradeoff = random.uniform(0, 1)\n",
" \n",
" ## If this number is greater than epsilon --> exploitation (take the action with the biggest Q value for this state)\n",
" if exp_exp_tradeoff > epsilon:\n",
" action = np.argmax(qtable[state][:])\n",
"\n",
" # Otherwise, make a random choice --> exploration\n",
" else:\n",
" action = env.action_space.sample()\n",
"\n",
" # Take the action (a) and observe the outcome state(s') and reward (r)\n",
" new_state, reward, terminated, truncated, info = env.step(action)\n",
"\n",
" # Update Q(s,a):= Q(s,a) + lr [R(s,a) + gamma * max Q(s',a') - Q(s,a)]\n",
" # qtable[new_state][:] : all the actions we can take from new state\n",
" qtable[state][action] = qtable[state][action] + learning_rate * (reward + gamma * np.max(qtable[new_state]) - qtable[state][action])\n",
" \n",
" total_rewards += reward\n",
" \n",
" # The new state becomes the current state\n",
" state = new_state\n",
" \n",
" done = terminated or truncated\n",
" # If done (goal reached, fell into a hole, or episode truncated): finish the episode\n",
" if done:\n",
" break\n",
" \n",
" # Reduce epsilon (because we need less and less exploration); episode + 1 episodes have been completed so far\n",
" epsilon = min_epsilon + (max_epsilon - min_epsilon)*np.exp(-decay_rate*(episode + 1))\n",
" rewards.append(total_rewards)\n",
"\n",
"print (\"Score over time: \" + str(sum(rewards)/total_episodes))\n",
"print(qtable)"
]
},
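{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learnt Q-table is easier to read once it is turned into a greedy policy (the argmax over actions in each state) and drawn on the map. This sketch assumes the default 4x4 map and the Gymnasium action encoding (0 = LEFT, 1 = DOWN, 2 = RIGHT, 3 = UP). The names `greedy_policy_grid`, `ARROWS` and `LAKE_4X4` are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"ARROWS = ['<', 'v', '>', '^']                 # LEFT, DOWN, RIGHT, UP\n",
"LAKE_4X4 = ['SFFF', 'FHFH', 'FFFH', 'HFFG']  # assumed default 4x4 map\n",
"\n",
"def greedy_policy_grid(qtable, desc=LAKE_4X4):\n",
"    # Pick argmax_a Q(s, a) for each state and draw it on the map;\n",
"    # holes (H) and the goal (G) are terminal, so they keep their letter\n",
"    n_cols = len(desc[0])\n",
"    rows = []\n",
"    for r, line in enumerate(desc):\n",
"        cells = []\n",
"        for c, tile in enumerate(line):\n",
"            state = r * n_cols + c\n",
"            cells.append(tile if tile in 'HG' else ARROWS[int(np.argmax(qtable[state]))])\n",
"        rows.append(''.join(cells))\n",
"    return rows\n",
"```\n",
"\n",
"For the Q-table printed above, `greedy_policy_grid(qtable)` should give the rows `v<<<`, `vH<H`, `>v<H` and `H>>G`. Following the arrows from the start leads to the goal in six moves. States whose Q-values are all zero (states 3 and 6) fall back to LEFT because `np.argmax` returns the first maximum."
]
},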
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we use the learnt Q-table to play the Frozen Lake game."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"****************************************************\n",
"EPISODE 0\n",
"****************************************************\n",
"Action 1\n",
"Action 1\n",
"Action 2\n",
"Action 1\n",
"Action 2\n",
"Action 2\n",
"****************************************************\n",
"EPISODE 1\n",
"****************************************************\n",
"Action 1\n",
"Action 1\n",
"Action 2\n",
"Action 1\n",
"Action 2\n",
"Action 2\n",
"****************************************************\n",
"EPISODE 2\n",
"****************************************************\n",
"Action 1\n",
"Action 1\n",
"Action 2\n",
"Action 1\n",
"Action 2\n",
"Action 2\n",
"****************************************************\n",
"EPISODE 3\n",
"****************************************************\n",
"Action 1\n",
"Action 1\n",
"Action 2\n",
"Action 1\n",
"Action 2\n",
"Action 2\n",
"****************************************************\n",
"EPISODE 4\n",
"****************************************************\n",
"Action 1\n",
"Action 1\n",
"Action 2\n",
"Action 1\n",
"Action 2\n",
"Action 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cif/anaconda3/lib/python3.10/site-packages/gymnasium/envs/toy_text/frozen_lake.py:328: UserWarning: \u001b[33mWARN: You are calling render method without specifying any render mode. You can specify the render_mode at initialization, e.g. gym.make(\"FrozenLake-v1\", render_mode=\"rgb_array\")\u001b[0m\n",
" gym.logger.warn(\n"
]
}
],
"source": [
"env.reset()\n",
"\n",
"for episode in range(5):\n",
" state, info = env.reset()\n",
" step = 0\n",
" done = False\n",
" print(\"****************************************************\")\n",
" print(\"EPISODE \", episode)\n",
" print(\"****************************************************\")\n",
"\n",
" for step in range(max_steps):\n",
" env.render() # renders according to the render_mode given in gym.make\n",
" # Take the action (index) that has the maximum expected future reward in this state\n",
" action = np.argmax(qtable[state][:])\n",
" print(\"Action \", action)\n",
" \n",
" new_state, reward, terminated, truncated, info = env.step(action)\n",
" \n",
" # Stop when the agent reaches the goal, falls into a hole, or the episode is truncated\n",
" done = terminated or truncated\n",
" if done:\n",
" break\n",
" state = new_state\n",
"env.close()"
]
},
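{
"cell_type": "markdown",
"metadata": {},
"source": [
"When following the learnt policy, an episode ends as soon as the agent reaches a hole or the goal. The following sketch traces the greedy path without Gymnasium by re-implementing the deterministic (non-slippery) moves by hand. It assumes that the agent starts in the top-left cell and that a move off the grid leaves the agent in place. `greedy_rollout`, `MOVES` and `LAKE_4X4` are hypothetical names, and this code is not the library's environment:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"LAKE_4X4 = ['SFFF', 'FHFH', 'FFFH', 'HFFG']\n",
"MOVES = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}  # LEFT, DOWN, RIGHT, UP as (d_row, d_col)\n",
"\n",
"def greedy_rollout(qtable, desc=LAKE_4X4, max_steps=99):\n",
"    # Follow argmax_a Q(s, a) on a NON-slippery lake and stop at a hole or the goal.\n",
"    # Hand-written deterministic moves, not the Gymnasium environment.\n",
"    n_rows, n_cols = len(desc), len(desc[0])\n",
"    row, col = 0, 0  # assumed start cell 'S'\n",
"    actions = []\n",
"    for _ in range(max_steps):\n",
"        action = int(np.argmax(qtable[row * n_cols + col]))\n",
"        actions.append(action)\n",
"        d_row, d_col = MOVES[action]\n",
"        # Moving off the grid leaves the agent where it is\n",
"        row = min(max(row + d_row, 0), n_rows - 1)\n",
"        col = min(max(col + d_col, 0), n_cols - 1)\n",
"        if desc[row][col] in 'HG':\n",
"            return actions, desc[row][col] == 'G'  # episode over: hole or goal\n",
"    return actions, False  # ran out of steps\n",
"```\n",
"\n",
"With the Q-table learnt above it should return `([1, 1, 2, 1, 2, 2], True)`, i.e. DOWN, DOWN, RIGHT, DOWN, RIGHT, RIGHT. These are the same actions printed by the play loop before the goal is reached. With an all-zero Q-table it returns 99 LEFT moves and `False`, because the agent keeps bumping into the left wall."
]
},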
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References\n",
"* [Gymnasium documentation](https://gymnasium.farama.org/).\n",
"* [Diving deeper into Reinforcement Learning with Q-Learning, Thomas Simonini](https://medium.freecodecamp.org/diving-deeper-into-reinforcement-learning-with-q-learning-c18d0db58efe).\n",
"* Illustrations by [Thomas Simonini](https://github.com/simoninithomas/Deep_reinforcement_learning_Course) and [Sung Kim](https://www.youtube.com/watch?v=xgoO54qN4lY).\n",
"* [Frozen Lake solution with TensorFlow](https://analyticsindiamag.com/openai-gym-frozen-lake-beginners-guide-reinforcement-learning/)\n",
"* [Deep Q-Learning for Doom](https://medium.freecodecamp.org/an-introduction-to-deep-q-learning-lets-play-doom-54d02d8017d8)\n",
"* [Intro OpenAI Gym with Random Search and the Cart Pole scenario](http://www.pinchofintelligence.com/getting-started-openai-gym/)\n",
"* [Q-Learning for the Taxi scenario](https://www.oreilly.com/learning/introduction-to-reinforcement-learning-and-openai-gym)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Licence"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The notebook is freely licensed under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/). \n",
"\n",
"© Carlos Á. Iglesias, Universidad Politécnica de Madrid."
]
}
],
"metadata": {
"datacleaner": {
"position": {
"top": "50px"
},
"python": {
"varRefreshCmd": "try:\n print(_datacleaner.dataframe_metadata())\nexcept:\n print([])"
},
"window_display": false
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}