mirror of
https://github.com/gsi-upm/sitc
synced 2024-11-17 20:12:28 +00:00
369 lines
500 KiB
Plaintext
369 lines
500 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"![](images/EscUpmPolit_p.gif \"UPM\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Course Notes for Learning Intelligent Systems"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © Carlos Á. Iglesias"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## [Introduction to Machine Learning V](2_6_0_Intro_RL.ipynb)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Visualization\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"In this section we are going to visualize Q-Learning based on this [link](https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/#sphx-glr-tutorials-training-agents-frozenlake-tuto-py). The code has been ported to the last version of Gymnasium.\n",
|
|||
|
"\n",
|
|||
|
"First, we are going to define a class *Params* for the Q-Learning parameters and the environment based on these values."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from qlearning import *\n",
|
|||
|
"sns.set_theme()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"params = Params(\n",
|
|||
|
" total_episodes=2000,\n",
|
|||
|
" learning_rate=0.8,\n",
|
|||
|
" gamma=0.95,\n",
|
|||
|
" epsilon=0.1,\n",
|
|||
|
" map_size=5,\n",
|
|||
|
" seed=123,\n",
|
|||
|
" is_slippery=False,\n",
|
|||
|
" n_runs=20,\n",
|
|||
|
" action_size=None,\n",
|
|||
|
" state_size=None,\n",
|
|||
|
" proba_frozen=0.9,\n",
|
|||
|
" savefig_folder=Path(\"./\"),\n",
|
|||
|
")\n",
|
|||
|
"params\n",
|
|||
|
"\n",
|
|||
|
"# Set the seed\n",
|
|||
|
"rng = np.random.default_rng(params.seed)\n",
|
|||
|
"\n",
|
|||
|
"# Create the figure folder if it doesn't exists\n",
|
|||
|
"params.savefig_folder.mkdir(parents=True, exist_ok=True)\n",
|
|||
|
"\n",
|
|||
|
"# Environment\n",
|
|||
|
"env = gym.make(\n",
|
|||
|
" \"FrozenLake-v1\",\n",
|
|||
|
" is_slippery=params.is_slippery,\n",
|
|||
|
" render_mode=\"rgb_array\",\n",
|
|||
|
" desc=generate_random_map(\n",
|
|||
|
" size=params.map_size, p=params.proba_frozen, seed=params.seed\n",
|
|||
|
" ),\n",
|
|||
|
")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The Q-Learning algorithm has been defined with two clases in the file *qlearning.py*:\n",
|
|||
|
"- *Qlearning* for learning the q-table\n",
|
|||
|
"- *EpsilonGreedy* for implementing the epsilon-greedy policy \n",
|
|||
|
"\n",
|
|||
|
"First, we check the environment."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Action size: 4\n",
|
|||
|
"State size: 25\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"params = params._replace(action_size=env.action_space.n)\n",
|
|||
|
"params = params._replace(state_size=env.observation_space.n)\n",
|
|||
|
"print(f\"Action size: {params.action_size}\")\n",
|
|||
|
"print(f\"State size: {params.state_size}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Running the environment"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"learner = Qlearning(\n",
|
|||
|
" learning_rate=params.learning_rate,\n",
|
|||
|
" gamma=params.gamma,\n",
|
|||
|
" state_size=params.state_size,\n",
|
|||
|
" action_size=params.action_size,\n",
|
|||
|
")\n",
|
|||
|
"explorer = EpsilonGreedy(\n",
|
|||
|
" epsilon=params.epsilon,\n",
|
|||
|
" rng = rng\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"This will be our main function to run our environment until the maximum number of episodes *params.total_episodes*. To account for stochasticity, we will also run our environment a few times."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"We want to plot the policy the agent has learned in the end. To do that the function *qtable_directions_map* perform these actions: 1. extract the best Q-values from the Q-table for each state, 2. get the corresponding best action for those Q-values, 3. map each action to an arrow so we can visualize it."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The function *plot_q_values_map* plots on the left the last frame of the simulation. If the agent learned a good policy to solve the task, we expect to see it on the tile of the treasure in the last frame of the video. On the right we’ll plot the policy the agent has learned. Each arrow will represent the best action to choose for each tile/state."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"As a sanity check, the function *plot_states_actons_distribution* plots the distributions of states and actions."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Now we’ll be running our agent on a few increasing maps sizes: \n",
|
|||
|
"- 4x4\n",
|
|||
|
"- 7x7\n",
|
|||
|
"- 9x9\n",
|
|||
|
"- 11x11\n",
|
|||
|
"\n",
|
|||
|
"Putting it all together:"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"Params(total_episodes=2000, learning_rate=0.8, gamma=0.95, epsilon=0.1, map_size=5, seed=123, is_slippery=False, n_runs=20, action_size=4, state_size=25, proba_frozen=0.9, savefig_folder=PosixPath('.'))"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"params"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Map size: 11x11\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" \r"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABcwAAAHkCAYAAAAD/WxfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACjK0lEQVR4nOzdeXyU9b33//c1M5nJOtnIxo4gBFDASjS2Uige2rva9pyU3rfntPFUUKRVwdIjWAs90lOo9j5UBC2ntxCKv2MtoKHW5XShtC5VikDrBkS2EJasZJusM5nl98dkRsIiWSYzE/J6Ph55QOa6ru/1vfKdgSvv+c7na/h8Pp8AAAAAAAAAABjkTJHuAAAAAAAAAAAA0YDAHAAAAAAAAAAAEZgDAAAAAAAAACCJwBwAAAAAAAAAAEkE5gAAAAAAAAAASCIwBwAAAAAAAABAEoE5AAAAAAAAAACSCMwBAAAAAAAAAJBEYA4ACAGfzxfW4wAAAAB0D/fcANAzBOYAgEs6fPiwlixZos985jO65pprdPPNN+s73/mODh48GNxn//79WrhwYY/b3rVrlx566KFQdhcAAAC4oixbtkwTJkzQ008/3eNjKysrtXDhQp05cyb42OzZs/W9730vlF0EgCsOgTkA4KKOHDmi22+/XXV1dVq+fLk2b96sZcuWqby8XLfffrveffddSdLzzz+vo0eP9rj9LVu2qKKiIsS9BgAAAK4Mzc3N+sMf/qDx48dr+/btPZ4p/vbbb+u1117r8thTTz2le++9N4S9BIArD4E5AOCifvGLXyglJUWbNm3SrbfeqhtuuEFf+cpXtGXLFqWlpWnDhg2R7iIAAABwxXr11Vfl8Xi0YsUKnTp1Sn/5y1/63OakSZM0cuTIEPQOAK5cBOYAgIs6e/aspAtrHsbHx+vhhx/WF7/4RX3ve9/Tr3/9a505c0YTJkzQjh07JEmnT5/WsmXLdPPNN2vy5Mm66aabtGzZMtXX10uS7rjjDr3zzjt65513NGHCBO3Zs0eS1NDQoH//93/Xpz/9aV177bX6P//n/2j37t1dzv/222/r9ttv13XXXae8vDzde++9On78eH//OAAAAICwKi4u1o033qgbb7xRY8aM0datWy/Y59VXX9VXv/pVTZ06VbNmzdJ//ud/yuVyaceOHXr44YclSbfcckuwDMv5JVmampr06KOP6h/+4R907bXX6ktf+pJeeOGFLueYPXu21q9fr5/85Cf69Kc/rSlTpuiuu+5SaWlpcJ+6ujo9+OCD+sxnPqNrr71W//iP/6gXX3yxH34qAND/CMwBABc1a9YslZeX65//+Z/1y1/+UseOHQuG5//rf/0vFRQU6N5779XMmTOVkZGhbdu2adasWWpra9O//uu/6tixY3rkkUdUVFSkwsJCvfLKK3r88cclSY888ogmTZqkSZMmadu2bZo8ebKcTqe++c1vateuXVqyZImeeuopZWdn6+677w6G5qdOndK3v/1tTZ48Wf/1X/+lVatW6fjx47rnnnvk9Xoj9rMCAAAAQunYsWN67733VFBQIEn66le/qj//+c+qqqoK7rN161Z997vf1cSJE/XUU09p4cKFeu6557Ry5UrNmjVL3/72tyVdugxLe3u7vv71r+ull17S/PnztWHDBl1//fVavny5fv7zn3fZ9//7//4/HT9+XI8++qhWrVqlDz/8sEvwvnTpUh09elQ//OEP9fTTT2vSpEl66KGHghNjAGAgsUS6AwCA6PT1r39dNTU1Kioq0n/8x39IklJTU3XzzTfrjjvu0NSpUzVy5EilpaXJarVq2rRpkqRDhw4pOztbjz32WPDjnvn5+frggw/0zjvvSJLGjRunxMRESQoet337dpWUlGj79u2aOnWqJOmzn/2s7rjjDq1Zs0bFxcV6//331d7eroULFyorK0uSlJOTo127dqm1tTXYJgAAADCQvfDCC7Lb7fqHf/gHSdI//dM/6YknntDzzz+v+++/X16vV08++aTmzJmj1atXB49zOp369a9/rcTExOC9+MSJEzV8+PALzrFjxw4dPnxYzz33nK6//npJ0owZM+R2u7Vhwwb98z//s1JSUiRJdrtdGzZskNlsliSdPHlSTz75pOrr65Wamqp33nlH9957b7C/N954o1JSUoL7A8BAwgxzAMAlPfDAA3rzzTf105/+VF/72teUmJiol19+WbfffrueeeaZix4zceJEPffccxo+fLhOnTqlN998U5s3b9bx48fV0dFxyXPt3r1bGRkZmjx5stxut9xutzwejz73uc/pww8/VGNjo6ZOnSqbzaavfe1revTRR/X2228rNzdXS5YsISwHAADAFcHtduull17SP/zDP8jpdMrhcCg2NlY33nijnn/+eXk8HpWWlurs2bPBgDrgzjvv1G9+8xtZrdbLnuedd97RsGHDgmF5wFe+8hU5nU699957wceuvfbaLuF3dna2JKmtrU2SPyB/8skn9cADD2jHjh2qq6vTQw89pOnTp/f65wAAkcIMcwDAJ0pOTtaXvvQlfelLX5IkHTx4UMuWLdOaNWv0la985aLH/OIXv9D/+3//T/X19RoyZIgmT56suLg4NTU1XfI8DQ0Nqqmp0eTJky+6vaamRuPGjdOzzz6rp59+Wtu3b9eWLVtkt9v19a9/XQ888IBMJt4HBgAAwMD22muv6ezZs9qxY0dwjaBz/fnPf1ZqaqokKT09vdfnaWxs1JAhQy54PPCYw+EIPhYXF9dln8B9d6As4tq1a/Xzn/9cv/3tb/W73/1OJpNJn/70p7Vy5UqNGDGi130EgEggMAcAXKCqqkpz587VAw88oP/9v/93l22TJk3Sd77zHd133306derUBce+/PLLeuyxx/Rv//Zv+trXvqa0tDRJ/tnqH3zwwSXPmZSUpNGjR2vNmjUX3R74GOmUKVP01FNPyeVyaf/+/dq2bZt+/vOfa8KECbr11lt7e8kAAABAVHjhhRc0bNgwPfrooxdsW7x4sbZu3aqHHnpIkn+xzXM1NDTowIEDwbKHnyQ5OVllZWUXPF5TUyNJwVC+O5KSkrR06VItXbpUx48f165du7Rhwwb98Ic/1KZNm7rdDgBEA6biAQAuMGTIEFksFj333HNyOp0XbD9+/LhsNptGjRp1wazu/fv3KykpSffcc08wLG9padH+/fu7LMx5/nE33HCDKioqlJ6ermuvvTb4tXv3bm3atElms1lbtmzR7Nmz5XK5ZLVaddNNN+lHP/qRJKmioiLUPwYAAAAgrM6ePas333xTt912m2688cYLvm699Va99dZbstlsSk1N1a5du7oc//LLL2vBggVyOp2X/fRlXl6ezpw5o/3793d5/KWXXlJMTIymTJnSrT6fOXNGM2fO1O9+9ztJ0lVXXaUFCxbo05/+tCorK3tw9QAQHQjMAQAXMJvNWrlypQ4fPqy5c+fqV7/6ld555x29/vrr+vGPf6x169bp/vvvV3Jysux2u86ePavXX39d1dXVmjJlipqamvTYY49pz549evnll/WNb3xDZ8+eDdY4lPwLB5WWlmr37t1qbGzUV7/6VQ0dOlTz5s3Tr3/9a/31r3/V448/rrVr1yozM1MxMTHKz89XdXW17rvvPr3++uv6y1/+oocfflhWq1Wf+9znIvgTAwAAAPru17/+tdxut2677baLbi8oKJDX69Xzzz+vRYsW6fe//71Wrlypt956S7/85S/1xBNP6F/+5V+UlpYmu90uSdq5c6eOHTt2QVtf/epXNW7cON1///361a9+pb/85S/6j//4DxUXF2vhwoXB4y9n2LBhys7O1qpVq/TCCy/onXfe0ebNm/X666/rC1/4Qu9/GAAQIYbP5/NFuhMAgOh04MABFRUVaf/+/aqrq5PVatWkSZN0xx136POf/7wk6fDhw3rggQd06tQpLV68WAsWLNCTTz6p4uJi1dfXKysrSzNnztT48eP1gx/8QK+++qrGjRunv/71r3r44YdVU1OjRx99VF/+8pdVW1urn/70p3rttdfU1NSkYcOG6Wtf+5rmz58fnCHzl7/8RT/
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1500x500 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABD0AAAHDCAYAAAAnVGG2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOy9ebwcVZn//z7nVPV+l2wkgbBDAoRdCIJhHTdgQER0HAFFUX6gjhsjIn5VHGGUEUaUCKOI6CigMOCAKyPDqoCAMgoiCrIlkD25a29V55zfH6equvt23+SCaCCe9+sVTe6trj711Keafp7zLMJaa/F4PB6Px+PxeDwej8fj2cyQm3oBHo/H4/F4PB6Px+PxeDx/CXzQw+PxeDwej8fj8Xg8Hs9miQ96eDwej8fj8Xg8Ho/H49ks8UEPj8fj8Xg8Ho/H4/F4PJslPujh8Xg8Ho/H4/F4PB6PZ7PEBz08Ho/H4/F4PB6Px+PxbJb4oIfH4/F4PB6Px+PxeDyezRIf9PB4PB6Px+PxeDwej8ezWeKDHh6Px+PxeDwej8fj8Xg2S3zQw+PxeDwej8fj8Xg8Hs9miQ96eDwej8fj8Xg8U+Dkk0/m5JNP3tTL+ItyySWXsGDBgikde+edd3L66aezePFi9tprL173utfxuc99juXLl/+FV9nJEUccwdlnn/1XfU+Px/PywQc9PB6Px+PxeDwez/PivPPO47TTTqNcLvPpT3+ar371q5x88sncfvvtvOENb+Dee+/d1Ev0eDweAIJNvQCPx+PxeDwej8fz8uHaa6/l29/+Nueddx5vfvObs5+/8pWv5LjjjuM973kPH/rQh/jhD3/IzJkzN+FKPR6Px2d6eDwej8fj8Xg8LyoPPPAAJ510EnvttReLFi3iYx/7GOvWres45v777+fUU09l//33Z/fdd+eII47gkksuwRgDwLJly1iwYAFXXnklRx55JIsWLeKGG27gkksu4TWveQ233347xxxzDLvvvjuve93r+P73v99x/qGhIT71qU9x0EEHsccee/CWt7yFe+65p+OYRqPB5z73OV71qlexzz778PGPf5xGo7HBa7PWcumll7J48eKOgEdKpVLhvPPOY/369Vx11VWTnufXv/41CxYs4JZbbun4+Z/+9CcWLFjAT37yk8wOZ511FosXL2bhwoUceOCBnHXWWaxfv77neVO73XDDDR0/P/vsszniiCM6fnbLLbdw/PHHs8cee/CqV72K8847j2q12mGfz3zmMxxyyCHsvvvuvP71r+cb3/jGBu3j8Xheevigh8fj8Xg8Ho/H8yJx//33c8opp1AoFLj44os555xzuO+++3j7299OvV4H4NFHH+WUU05hcHCQL37xi1x22WXsu+++LFmyhB/96Ecd5/viF7/IqaeeynnnnccrX/lKAFavXs2//Mu/8Pa3v52vfe1rzJs3j7PPPps//elPgHPW3/GOd/C///u/fPjDH2bJkiXMmTOHd7/73R2Bj49+9KN873vf4z3veQ8XX3wxw8PDfPOb39zg9T3yyCMsX76cV7/61ZMes+OOO7LLLrt0BTTa2Xfffdl222358Y9/3PHzH/zgB/T19XHEEUdQq9V4+9vfzp/+9Cc+/elPc8UVV3DSSSfxwx/+kH//93/f4Do3xg9+8APe9773scMOO/CVr3yF97///dx00028973vxVoLwPnnn88dd9zBxz72Ma644gr+7u/+jgsuuKAroOLxeF7a+PIWj8fj8Xg8Ho/nReKiiy5i++2356tf/SpKKQD22msvjj76aK6//npOPPFEHn30UQ466CC+8IUvIKXbg3zVq17F7bffzv33388xxxyTne+1r30tJ5xwQsd71Go1zj//fA488EAAtttuOw4//HDuuOMOdtxxR2688UYeffRRrr32Wvbaay8ADjnkEE4++WQuvPBCrr/+eh577DFuvvlmPvWpT3HiiScCcPDBB3PMMcfw+OOPT3p9zz77LABbbbXVBu2w7bbb8vOf/3yDxxx77LFcccUV1Go1isUiAD/60Y94/etfTz6f5/e//z1z5szh85//PNtssw3gSmgeeugh7rvvvg2ee0NYa7nwwgs5+OCDufDCC7Ofb7fddpxyyinccccdHHbYYdx3330cdNBBHH300QAccMABlEolpk2b9oLf2+Px/PXxmR4ej8fj8Xg8Hs+LQK1W4ze/+Q2HHnoo1lriOCaOY7beemt23HFHfvGLXwBw3HHHcfnllxNFEY899hi33HILl1xyCVproijqOOf8+fN7vtfee++d/X3OnDkAWWnGPffcw6xZs1i4cGG2Bq01hx9+OA8//DDDw8M88MADAPzd3/1ddh4pJa973es2eI1pFoQQYoPHCSGyY40x2TrStQC84Q1voFqtcttttwHw29/+lmeeeYY3vOENAOy6665cffXVzJs3j6VLl3LXXXfxjW98gyeeeKLLTs+HJ554ghUrVnDEEUd0rGv//fenUqlk9+mAAw7guuuu4z3veQ9XX301zz77LO973/s4/PDDX/B7ezyevz4+08Pj8Xg8Ho/H43kRGBkZwRjD5ZdfzuWXX971+3w+D0C9Xuezn/0sN954I3EcM2/ePPbZZx+CIMgCBSmTNQJNMyOALFskfe3Q0BCrV69m4cKFPV+7evVqhoeHAZg+fXrH72bNmrXBa0wzPJYuXbrB45YuXcrcuXMBOOecczp6jmy11VbceuutbL311uy777786Ec/4qijjuIHP/gBW221Ffvtt1927JVXXslXv/pV1q9fz8yZM1m4cCHFYpHR0dENvv+GGBoaAuAzn/kMn/nMZ7p+v2rVKgA+8YlPMGfOHG666absuH322YdPfepT7Lbbbi/4/T0ez18XH/TweDwej8fj8XheBMrlMkIITjnllKwkop00UHH++edz8803c/HFF3PQQQdRKpUAsnKVP5e+vj622267jtKNdubNm5eVaKxZs4Ytt9wy+10aEJiMhQsXMnfuXP7nf/6Ht73tbdnPV65ciZSSWbNmsXTpUh599FHe/va3A/D+978/K6EByOVy2d/f8IY3cP755zM6OspPfvIT3vSmN2VZJD/4wQ/4/Oc/z5lnnskJJ5yQBWg++MEP8tBDD/VcX/raNJskpb1BaX9/PwBnnXUWixYt6jrHwMBAts4zzjiDM844g+eee47bbruNSy+9lDPPPDNrtOrxeF76+PIWj8fj8Xg8Ho/nRaBSqbDbbrvxxBNPsMcee2R/dt55Z5YsWcIvf/lLAH71q19xwAEH8OpXvzoLeDz88MOsW7cum97y57Bo0SKWL1/OjBkzOtZxzz338PWvfx2lVNYU9ac//WnHa9NSk8kQQvC+972Pe+65h2uvvTb7+Y033shhhx3GBRdcwDnnnEOhUOCd73wn4IIs7etYsGBB9rojjzwSgC996UusXr2aY489Nvvdr371K/r6+jjttNOygMf4+Di/+tWvJrVTpVIBYMWKFdnPoijit7/9bfbvHXbYgRkzZrBs2bKOdc2ZM4eLLrqIRx55hHq9zute97psWsuWW27JiSeeyNFHH91xbo/H89LHZ3p4PB6Px+PxeDxTZMWKFT0nnOy0004sXryYj3zkI5x22mmceeaZHHvssWit+cY3vsFvfvMbzjjjDAD23HNPfvKTn3DNNdew44478uijj3LZZZchhKBWq/3Zazz++OP5zne+wzvf+U5OP/105s6dy913383ll1/OSSedRBiGbLvttvzDP/wDX/ziF4njmF133ZUbb7yRP/zhDxs9/5vf/GYee+wxPvWpT/HLX/6SI488kn322YcjjzwyCxL80z/9E7Nnz97ouQYGBjj88MO5+uqr2WOPPdhxxx2z3+25555cc801fP7zn+fwww9n1apVXHHFFaxZsybLxuh1vn322YfvfOc7bLvttkybNo1vf/vb1Ov1LMCklOLDH/4wn/rUp1BKcfjhhzMyMsKll17KypUrWbhwIYVCgYULF7JkyRLCMGTBggU8+eSTfP/7399o3xOPx/PSwgc9PB6Px+PxeDyeKfLMM8/wuc99ruvnb3zjG1m
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1500x500 with 3 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"map_sizes = [4, 7, 9, 11]\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"res_all, st_all = run_frozen_maps(map_sizes, params, rng)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The DOWN and RIGHT actions get chosen more often, which makes sense as the agent starts at the top left of the map and needs to find its way down to the bottom right. Also the bigger the map, the less states/tiles further away from the starting state get visited.\n",
|
|||
|
"\n",
|
|||
|
"To check if our agent is learning, we want to plot the cumulated sum of rewards, as well as the number of steps needed until the end of the episode. If our agent is learning, we expect to see the cumulated sum of rewards to increase and the number of steps to solve the task to decrease. "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABcwAAAHjCAYAAAAe+FznAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3wb9f0/8NfdaXnbcfYOGQQKSYAEAiQkhNECpW1+KRRaKA0UKLPsMNovScvoAMIIO2GUGSBsSoGEEUbIDhCy97SdeNtad/f5/P6QLVu2ZJ9kyRp+PfugsU+nu7c+Gj69733vjyKllCAiIiIiIiIiIiIi6uLUZAdARERERERERERERJQKmDAnIiIiIiIiIiIiIgIT5kREREREREREREREAJgwJyIiIiIiIiIiIiICwIQ5EREREREREREREREAJsyJiIiIiIiIiIiIiAAwYU5EREREREREREREBIAJcyIiIiIiIiIiIiIiAEyYExEREREREREREREBSLGE+WOPPYYLL7wwZNn69etxwQUXYMyYMZg8eTLmzZsXcrsQAg8//DAmTpyI0aNH4+KLL8bOnTuj2gYRERERERERERERUcokzJ977jk8/PDDIcsqKysxffp0DB48GAsWLMA111yDhx56CAsWLAiu89hjj+HVV1/FXXfdhfnz50NRFFx66aXw+/2Wt0FEREREREREREREZEt2AKWlpbjjjjuwcuVKDBkyJOS21157DQ6HAzNnzoTNZsPQoUOxc+dOPP3005g2bRr8fj+eeeYZ3HzzzZg0aRIAYPbs2Zg4cSI++eQTnHXWWe1ug4iIiIiIiIiIiIgISIEK8x9//BEFBQV49913MXr06JDbVqxYgXHjxsFma8rrjx8/Htu3b0d5eTk2bNiA+vp6jB8/Pnh7fn4+Dj/8cCxfvtzSNoiIiIiIiIiIiIiIgBSoMJ8yZQqmTJkS9raSkhKMGDEiZFnPnj0BAPv27UNJSQkAoE+fPq3W2b9/v6VtFBcXxxS3lBJCyJju2xGqqiRlv+mIYxUdjpd1HKvocLys41hFh+NlHccqOp09XqqqQFGUTttfuuDxdnrgeFnHsYoOx8s6jlV0OF7Wcayiw/GyLhljFc0xd9IT5m3xer1wOBwhy5xOJwDA5/PB4/EAQNh1qqurLW2jIzQtOQX6msYvVFZxrKLD8bKOYxUdjpd1HKvocLys41hFh+OVfEJIVFTUd+o+bTYVRUU5qKlxwzBEp+47HXG8rONYRYfjZR3HKjocL+s4VtHheFmXrLHq1i3H8jF+SifMXS5XcPLORo1J7uzsbLhcLgCA3+8P/ty4TlZWlqVtxEoIiZoad8z3j4WmqcjPz0JNjQemyTdfWzhW0eF4Wcexig7HyzqOVXQ4XtZxrKKTjPHKz89KWiEGERERERGFSumEee/evVFWVhayrPH3Xr16wTCM4LKBAweGrDNy5EhL2+iIZJ0xMk3Bs1UWcayiw/GyjmMVHY6XdRyr6HC8rONYRYfjRURERETUNaV0Kcu4ceOwcuVKmKYZXLZkyRIMGTIExcXFGDlyJHJzc7F06dLg7TU1NVi3bh3Gjh1raRtERERERERERERERECKJ8ynTZuGuro63HHHHdiyZQvefPNNPP/887j88ssBBHqXX3DBBbjvvvuwaNEibNiwAddffz169+6N0047zdI2iIiIiIiIiIiIiIiAFG/JUlxcjLlz5+Luu+/G1KlT0aNHD9xyyy2YOnVqcJ1rr70WhmHgL3/5C7xeL8aNG4d58+YFJ/q0sg0iIiIiIiIiIiIiopRKmP/jH/9otWzUqFGYP39+xPtomoabb74ZN998c8R12tsGERERUVcjpYQQAkKY7a/chQihwOvV4Pf7YJoyLtvUNBtUNaUv7CQiIiKiBBBCwDSNZIeRUtLheDulEuZERERElFhSSng8dairq2ayPIKDB1UIEd8JP7OycpGf3w2KosR1u0RERESUeqSUqKmpgMdTl+xQUlKqH28zYU5ERETUhTQeuLtcOXC5sqGqGpO4LWiaErdqFykl/H4f6uoqAQAFBZx0noiIiCjTNR5z5+YWweFw8ni7hVQ/3mbCnIiIiKiLEMKEx1OP3NxC5OYWJDuclGWzqTCM+FW8OBxOAEBdXSXy8orYnoWIiIgogwWOuQPJ8tzc/GSHk5JS/XibR+tEREREXYRpmgAknE5XskPpchoP4tnDkoiIiCizBY65m47/qHPE83ibCXMiIiKiLoeXhHY2XoZLRERE1LXw+K9zxXO8mTAnIiIiIiIiIiIiIgIT5kREREREREREREREAJgwJyIiIqIMNm/ek5gwYWyywyAiIiIiyliZdsxtS3YARERERESJcvbZv8Jxx52Q7DCIiIiIiDJWph1zM2FORERERBmrZ89e6NmzV7LDICIiIiLKWJl2zM2EORERERFZ8utfn40zzzwb9fV1+N//PoDfr2PChJNw88234803X8OCBa/B7a7H2LHH4pZb7kBBQSEAwOfz4tln5+LzzxehtLQEdrsDhx/+E1x11Z8xfPihAIC7756J/fv34Wc/OwvPPvs0amqqcdhhP8HVV1+HESNGRoxp7949eOSRB/D999/B5/Ni2LAR+MMf/ojjjz8RQODy0GeffRpffbUCq1atwLXX/insdnr37oM33ngPAFBSsh+PPPIQli37Fn6/D0ccMQpXXfXnNuOgzLN2Wzle/HgTpp85EocOLEp2OERERNRFdJ1j7hI8/vjDKXnMzYQ5EREREVk2f/5LGDv2WMyceQ82bFiHJ598FBs3rkePHj1xyy23Y9eunXjssYfRrVt33HjjDADA3/9+J9asWYU//elq9OvXH7t378LcuU/gzjtvx0svvQFFUQAAW7ZswlNP7cTll1+FvLx8zJv3JK655nK8+OLr6NGjZ6tYhBCYMeN6FBd3x1//Ogs2mw2vv/4qbr31Brz00hvo339AyPqHHjoSTzzxbMiy5cu/xbx5T+IXv5gKwxCoqKrE5ZdOh9PpwvXX34KsLBdee+0VXHXVZXj66ecxePCQBI0spZoHXvsOAPDPl1fjmVunJDkaIiIi6koy+ZgbAKqqKnHFFRen7DE3E+ZEREREZFl2dg5mzboXNpsN48Ydhw8/fB8HDx7EU089j9zcXBx//ASsXLkCP/wQSDbqug63243rr78Zp5xyOgDgqKOOgdtdjzlzHkR5eTm6d+8OAKirq8OcOQ9gzJijAQCHH34Ezj33lw0Hz39uFUtlZQV27NiO3//+Yhx//AQAwGGHHYFnn30Kfr+v1fo5Obk44ogjg7/v2rUT8+e/jClTTsPvf38xfLqJN15/BTU11XjppXno3bsPAGD8+BPxu9/9GnPnPoG77vpnHEeTiIiIiKi1TD7mBoBXXnkJ1dXVePnl1DzmZsKciIiIiCw77LCfwGZrOoTs1q0YOTk5yM3NDS4rKCjAtm1bAAB2ux0PPPAIAODgwYPYs2cXdu7cgW+++QoAYBh68H69evUOHrgDQPfu3XHkkaPw3Xerw8bSrVsxBg8+BP/8511Yvnwpxo8/AcceezyuueaGdh9HbW0tbr31BvTt2xe3335ncPmqVcsxfPgIdO/eA4ZhAAAURcH48Sfg448/bHe7REREREQdlenH3CtWLEvpY24mzImIiIjIspycnFbLnE5Xm/dZunQJHn74fuzcuQPZ2TkYOnQYsrMD25FSBtfr3r1Hq/sWFhahrKw07HYVRcGDDz6K556bh8WLP8OHH74Pm82Gk046GTfddCvy8wvC3s80Tdx55+2ora3FAw88CperKf6ammrs27sHkyePD3tfr9cbsj4RERERUbxl+jF3dXU19uzZnbLH3EyYExEREVHC7N27B7fddhMmTjwJ//rXg+jXrz8A4M03X8fSpd+ErFtdXd3q/pWVFSgq6hZx+92798BNN92KG2+cgS1bNuGzzxbhpZeeR35+Pm666baQdaWUEFLi0UcfxKpVy/HQQ0+gd+/eEFICDd8hcnPycNRRx4S9HBUIVO8QEREREaWSVDrmbtTymLu5vLw8jBlzNK6++rqw9032Mbea1L0
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1500x500 with 2 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"plot_steps_and_rewards(res_all, st_all,params)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## References\n",
|
|||
|
"* [Gymnasium documentation](https://gymnasium.farama.org/).\n",
|
|||
|
"* [Diving deeper into Reinforcement Learning with Q-Learning, Thomas Simonini](https://medium.freecodecamp.org/diving-deeper-into-reinforcement-learning-with-q-learning-c18d0db58efe).\n",
|
|||
|
"* Illustrations by [Thomas Simonini](https://github.com/simoninithomas/Deep_reinforcement_learning_Course) and [Sung Kim](https://www.youtube.com/watch?v=xgoO54qN4lY).\n",
|
|||
|
"* [Frozen Lake solution with TensorFlow](https://analyticsindiamag.com/openai-gym-frozen-lake-beginners-guide-reinforcement-learning/)\n",
|
|||
|
"* [Deep Q-Learning for Doom](https://medium.freecodecamp.org/an-introduction-to-deep-q-learning-lets-play-doom-54d02d8017d8)\n",
|
|||
|
"* [Intro OpenAI Gym with Random Search and the Cart Pole scenario](http://www.pinchofintelligence.com/getting-started-openai-gym/)\n",
|
|||
|
"* [Q-Learning for the Taxi scenario](https://www.oreilly.com/learning/introduction-to-reinforcement-learning-and-openai-gym)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Licence"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"The notebook is freely licensed under under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/). \n",
|
|||
|
"\n",
|
|||
|
"© Carlos Á. Iglesias, Universidad Politécnica de Madrid."
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"datacleaner": {
|
|||
|
"position": {
|
|||
|
"top": "50px"
|
|||
|
},
|
|||
|
"python": {
|
|||
|
"varRefreshCmd": "try:\n print(_datacleaner.dataframe_metadata())\nexcept:\n print([])"
|
|||
|
},
|
|||
|
"window_display": false
|
|||
|
},
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3 (ipykernel)",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.10.10"
|
|||
|
},
|
|||
|
"latex_envs": {
|
|||
|
"LaTeX_envs_menu_present": true,
|
|||
|
"autocomplete": true,
|
|||
|
"bibliofile": "biblio.bib",
|
|||
|
"cite_by": "apalike",
|
|||
|
"current_citInitial": 1,
|
|||
|
"eqLabelWithNumbers": true,
|
|||
|
"eqNumInitial": 1,
|
|||
|
"hotkeys": {
|
|||
|
"equation": "Ctrl-E",
|
|||
|
"itemize": "Ctrl-I"
|
|||
|
},
|
|||
|
"labels_anchors": false,
|
|||
|
"latex_user_defs": false,
|
|||
|
"report_style_numbering": false,
|
|||
|
"user_envs_cfg": false
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 1
|
|||
|
}
|