{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](files/images/EscUpmPolit_p.gif \"UPM\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Course Notes for Learning Intelligent Systems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © Carlos A. Iglesias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Introduction to Machine Learning](2_0_0_Intro_ML.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table of Contents\n",
"* [kNN Model](#kNN-Model)\n",
"* [Loading data and preprocessing](#Loading-data-and-preprocessing)\n",
"* [Train classifier](#Train-classifier)\n",
"* [Evaluating the algorithm](#Evaluating-the-algorithm)\n",
"\t* [Precision, recall and f-score](#Precision,-recall-and-f-score)\n",
"\t* [Confusion matrix](#Confusion-matrix)\n",
"\t* [K-Fold validation](#K-Fold-validation)\n",
"* [Tuning the algorithm](#Tuning-the-algorithm)\n",
"* [References](#References)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# kNN Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal of this notebook is to learn how to train a model, make predictions with that model and evaluate these predictions.\n",
"\n",
"The notebook uses the [kNN (k nearest neighbors) algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)."
]
},
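{
"cell_type": "markdown",
"metadata": {},
"source": [
"The core idea of kNN fits in a few lines of NumPy: compute the distance from the query point to every training sample, take the k closest ones and return their majority label. The function `knn_predict` and the toy data below are hypothetical, written only to illustrate the idea; ties in the vote are broken arbitrarily here. scikit-learn's `KNeighborsClassifier`, used in the rest of this notebook, adds efficient neighbour search structures (ball trees, KD-trees), distance weighting and other options.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from collections import Counter\n",
"\n",
"def knn_predict(x_train, y_train, x_query, k=3):\n",
"    # Euclidean distance from the query point to every training sample\n",
"    distances = np.sqrt(((x_train - x_query) ** 2).sum(axis=1))\n",
"    # Indices of the k closest training samples\n",
"    nearest = np.argsort(distances)[:k]\n",
"    # Majority vote among the labels of those k neighbours\n",
"    return Counter(y_train[nearest]).most_common(1)[0][0]\n",
"\n",
"# Hypothetical toy data: two well-separated groups in 2D\n",
"x_toy = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],\n",
"                  [5.0, 5.0], [5.1, 4.9], [4.9, 5.2]])\n",
"y_toy = np.array([0, 0, 0, 1, 1, 1])\n",
"\n",
"print(knn_predict(x_toy, y_toy, np.array([0.3, 0.3])))  # 0\n",
"print(knn_predict(x_toy, y_toy, np.array([4.8, 5.0])))  # 1\n",
"```"
]
},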
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading data and preprocessing\n",
"\n",
"The first step is loading and preprocessing the data as explained in the previous notebooks."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# library for displaying plots\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# display plots in the notebook \n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# First, we repeat the loading and preprocessing steps\n",
"\n",
"# Load data\n",
"from sklearn import datasets\n",
"iris = datasets.load_iris()\n",
"\n",
"# Training and test splitting\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"x_iris, y_iris = iris.data, iris.target\n",
"\n",
"# The test set is a random 25% of the samples (fixed seed for reproducibility)\n",
"x_train, x_test, y_train, y_test = train_test_split(x_iris, y_iris, test_size=0.25, random_state=33)\n",
"\n",
"# Preprocess: standardize the features (zero mean, unit variance);\n",
"# the scaler is fitted on the training set only and then applied to both sets\n",
"from sklearn import preprocessing\n",
"scaler = preprocessing.StandardScaler().fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_test = scaler.transform(x_test)"
]
},
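{
"cell_type": "markdown",
"metadata": {},
"source": [
"`StandardScaler` transforms each feature as z = (x - mean) / std, where the mean and the (population) standard deviation are estimated on the training set only and then reused for the test set. The sketch below checks this formula against the scaler; the `*_manual` arrays are illustrative names, not part of the original notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import datasets, preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"iris = datasets.load_iris()\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    iris.data, iris.target, test_size=0.25, random_state=33)\n",
"\n",
"scaler = preprocessing.StandardScaler().fit(x_train)\n",
"\n",
"# Statistics computed on the training set only (std uses ddof=0, like StandardScaler)\n",
"mean = x_train.mean(axis=0)\n",
"std = x_train.std(axis=0)\n",
"\n",
"# The same training statistics are applied to both sets\n",
"x_train_manual = (x_train - mean) / std\n",
"x_test_manual = (x_test - mean) / std\n",
"\n",
"print(np.allclose(scaler.transform(x_train), x_train_manual))  # True\n",
"print(np.allclose(scaler.transform(x_test), x_test_manual))    # True\n",
"```"
]
},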
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The usual steps for creating a classifier are:\n",
"1. Create the classifier object.\n",
"2. Call *fit* to train the classifier.\n",
"3. Call *predict* to obtain predictions.\n",
"\n",
"Once the model is created, the most relevant methods are:\n",
"* `model.fit(x_train, y_train)`: trains the model.\n",
"* `model.predict(x)`: predicts the class of each sample in `x`.\n",
"* `model.score(x, y)`: evaluates the predictions on `x` against the true labels `y` (for classifiers, it returns the mean accuracy)."
]
},
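{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three steps can be sketched as follows, using the raw (unscaled) iris features and `n_neighbors=5` purely for illustration (`clf` and `y_pred` are illustrative names). The last line checks that, for a classifier, `score` gives the same value as `metrics.accuracy_score` applied to the predictions.\n",
"\n",
"```python\n",
"from sklearn import datasets, metrics\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"iris = datasets.load_iris()\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    iris.data, iris.target, test_size=0.25, random_state=33)\n",
"\n",
"clf = KNeighborsClassifier(n_neighbors=5)  # 1. create the classifier object\n",
"clf.fit(x_train, y_train)                  # 2. train it\n",
"y_pred = clf.predict(x_test)               # 3. obtain predictions\n",
"\n",
"# For classifiers, score() is the mean accuracy on the given data\n",
"print(clf.score(x_test, y_test) == metrics.accuracy_score(y_test, y_pred))  # True\n",
"```"
]
},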
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=None, n_neighbors=15, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"import numpy as np\n",
"\n",
"# Create kNN model\n",
"model = KNeighborsClassifier(n_neighbors=15)\n",
"\n",
"# Train the model using the training sets\n",
"model.fit(x_train, y_train) "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 1 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 2 1 1 0 1 1 0 2 1 2 1 2 0 2 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n",
"Expected [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 2 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 1 1 1 0 1 1 0 2 2 2 1 2 0 1 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n"
]
}
],
"source": [
"print(\"Prediction \", model.predict(x_train))\n",
"print(\"Expected \", y_train)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in training 0.9642857142857143\n"
]
}
],
"source": [
"# Evaluate accuracy on the training set\n",
"\n",
"from sklearn import metrics\n",
"y_train_pred = model.predict(x_train)\n",
"print(\"Accuracy in training\", metrics.accuracy_score(y_train, y_train_pred))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in testing 0.9210526315789473\n"
]
}
],
"source": [
"# Now we evaluate accuracy on the test set\n",
"y_test_pred = model.predict(x_test)\n",
"print(\"Accuracy in testing \", metrics.accuracy_score(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are going to visualize the nearest-neighbours classification by plotting the decision boundary of each class.\n",
"\n",
"The plotting function is defined in the file [util_knn.py](files/util_knn.py). We load it with the *magic command* **%run**, which executes the file in the notebook namespace and makes `plot_classification_iris` available."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"%run util_knn.py\n",
"\n",
"plot_classification_iris()"
]
},
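{
"cell_type": "markdown",
"metadata": {},
"source": [
"The contents of util_knn.py are not reproduced here. As a hypothetical sketch of how decision boundaries like these are usually drawn (not necessarily what `plot_classification_iris` does), one can train the classifier on two features, predict the class of every point of a dense grid and colour the grid by the predicted class. The choice of the two sepal features and of `n_neighbors=15` (mirroring the model above) is an assumption for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn import datasets\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"iris = datasets.load_iris()\n",
"X = iris.data[:, :2]  # sepal length and sepal width only, so the plot is 2D\n",
"y = iris.target\n",
"\n",
"clf = KNeighborsClassifier(n_neighbors=15).fit(X, y)\n",
"\n",
"# Dense grid covering both features, with a margin of 1 cm on each side\n",
"step = 0.02\n",
"xx, yy = np.meshgrid(np.arange(X[:, 0].min() - 1, X[:, 0].max() + 1, step),\n",
"                     np.arange(X[:, 1].min() - 1, X[:, 1].max() + 1, step))\n",
"# Predicted class for every grid point, reshaped back to the grid\n",
"Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)\n",
"\n",
"# Colour each region by its predicted class and overlay the samples\n",
"plt.contourf(xx, yy, Z, alpha=0.3)\n",
"plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k')\n",
"plt.xlabel(iris.feature_names[0])\n",
"plt.ylabel(iris.feature_names[1])\n",
"plt.show()\n",
"```"
]
},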
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision, recall and f-score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For evaluating classification algorithms, we usually calculate three metrics: precision, recall and F1-score. For a multi-class problem such as iris, each metric is computed per class, treating that class as positive and all other classes as negative.\n",
"\n",
"* **Precision**: the proportion of instances predicted as positive that are actually positive (it measures how right our classifier is when it says that an instance is positive).\n",
"* **Recall**: the proportion of actual positive instances that are predicted as positive (it measures how many of the positive instances our classifier finds).\n",
"* **F1-score**: the harmonic mean of precision and recall, which combines both in a single number."
]
},
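{
"cell_type": "markdown",
"metadata": {},
"source": [
"In terms of the true positives (TP), false positives (FP) and false negatives (FN) of a class, `precision = TP / (TP + FP)`, `recall = TP / (TP + FN)` and `F1 = 2 * precision * recall / (precision + recall)`. As a worked example, the sketch below recomputes the per-class values from the confusion matrix obtained for this model in the next section (rows are true classes, columns are predicted classes). The results should match the classification report: for instance, versicolor gets precision 11/14 = 0.79, recall 1.00 and F1 0.88.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Confusion matrix of this model on the test set (rows: true class, columns: predicted class)\n",
"cm = np.array([[8, 0, 0],\n",
"               [0, 11, 0],\n",
"               [0, 3, 16]])\n",
"\n",
"tp = np.diag(cm)                 # correctly classified samples of each class\n",
"precision = tp / cm.sum(axis=0)  # TP / (TP + FP): divide by predicted counts (column sums)\n",
"recall = tp / cm.sum(axis=1)     # TP / (TP + FN): divide by true counts (row sums)\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"print(precision.round(2), recall.round(2), f1.round(2))\n",
"```"
]
},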
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" setosa 1.00 1.00 1.00 8\n",
" versicolor 0.79 1.00 0.88 11\n",
" virginica 1.00 0.84 0.91 19\n",
"\n",
" micro avg 0.92 0.92 0.92 38\n",
" macro avg 0.93 0.95 0.93 38\n",
"weighted avg 0.94 0.92 0.92 38\n",
"\n"
]
}
],
"source": [
"print(metrics.classification_report(y_test, y_test_pred, target_names=iris.target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
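"Another useful tool is the confusion matrix: the entry in row i and column j counts the test samples whose true class is i and whose predicted class is j, so correct predictions lie on the diagonal."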
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8 0 0]\n",
" [ 0 11 0]\n",
" [ 0 3 16]]\n"
]
}
],
"source": [
"print(metrics.confusion_matrix(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All 'setosa' and 'versicolor' test samples are classified correctly, while 3 of the 19 'virginica' samples are misclassified as 'versicolor'."
]
},
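{
"cell_type": "markdown",
"metadata": {},
"source": [
"A confusion matrix is just a table of counts of (true class, predicted class) pairs. The sketch below builds one by hand for small hypothetical label vectors and checks it against `metrics.confusion_matrix`, which uses the same convention (rows are true classes, columns are predicted classes).\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Hypothetical true and predicted labels for three classes (0, 1, 2)\n",
"y_true = np.array([0, 0, 1, 1, 2, 2, 2])\n",
"y_pred = np.array([0, 0, 1, 2, 2, 1, 2])\n",
"\n",
"n_classes = 3\n",
"cm = np.zeros((n_classes, n_classes), dtype=int)\n",
"for t, p in zip(y_true, y_pred):\n",
"    cm[t, p] += 1  # row = true class, column = predicted class\n",
"\n",
"print(cm)\n",
"print((cm == metrics.confusion_matrix(y_true, y_pred)).all())  # True\n",
"```"
]
},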
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### K-Fold validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
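"A single train/test split gives an estimate that depends on which samples happen to fall in the test set. To reduce this dependence, it is recommended to use **k-fold cross-validation**: the data are split into k folds, and the model is trained k times, each time using a different fold as the test set and the remaining k-1 folds as the training set."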
]
},
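{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, `cross_val_score` with a `KFold` iterator runs a loop like the one below: each of the k folds is used once as the test set while the model is fitted on the other k-1 folds. This sketch reproduces the scores of the next cell manually (`manual_scores` is an illustrative name). Because the scaler is part of the `Pipeline`, it is refitted on the training folds only in every iteration, so no statistics leak from the test fold.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"iris = datasets.load_iris()\n",
"x_iris, y_iris = iris.data, iris.target\n",
"\n",
"model = Pipeline([('scaler', StandardScaler()), ('kNN', KNeighborsClassifier())])\n",
"cv = KFold(n_splits=10, shuffle=True, random_state=33)\n",
"\n",
"manual_scores = []\n",
"for train_idx, test_idx in cv.split(x_iris):\n",
"    # Fit on k-1 folds, evaluate accuracy on the held-out fold\n",
"    model.fit(x_iris[train_idx], y_iris[train_idx])\n",
"    manual_scores.append(model.score(x_iris[test_idx], y_iris[test_idx]))\n",
"\n",
"print(np.allclose(manual_scores, cross_val_score(model, x_iris, y_iris, cv=cv)))  # True\n",
"```"
]
},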
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.93333333 0.8 1. 0.93333333 0.93333333 0.93333333\n",
" 1. 1. 0.86666667 0.93333333]\n"
]
}
],
"source": [
"from sklearn.model_selection import cross_val_score, KFold\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# create a composite estimator made by a pipeline of preprocessing and the KNN model\n",
"model = Pipeline([\n",
" ('scaler', StandardScaler()),\n",
" ('kNN', KNeighborsClassifier())\n",
"])\n",
"\n",
"# create a k-fold cross validation iterator of k=10 folds\n",
"cv = KFold(10, shuffle=True, random_state=33)\n",
"\n",
"# by default the score used is the one returned by score method of the estimator (accuracy)\n",
"scores = cross_val_score(model, x_iris, y_iris, cv=cv)\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We get an array of k = 10 scores, one per fold. We summarize them with their mean and the standard error of the mean to obtain a final figure."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean score: 0.933 (+/- 0.020)\n"
]
}
],
"source": [
"from scipy.stats import sem\n",
"def mean_score(scores):\n",
" return (\"Mean score: {0:.3f} (+/- {1:.3f})\").format(np.mean(scores), sem(scores))\n",
"print(mean_score(scores))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, we get an average accuracy of 0.933, with a standard error of 0.020."
]
},
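{
"cell_type": "markdown",
"metadata": {},
"source": [
"The standard error of the mean is the sample standard deviation (with `ddof=1`) divided by the square root of the number of scores, which is what `scipy.stats.sem` computes by default. The sketch below recomputes the figure above by hand; each fold holds 15 of the 150 samples, so the printed fold accuracies are written here as exact fractions of 15.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.stats import sem\n",
"\n",
"# Fold accuracies printed above, as (correct predictions) / 15\n",
"scores = np.array([14, 12, 15, 14, 14, 14, 15, 15, 13, 14]) / 15\n",
"\n",
"mean = scores.mean()\n",
"# Standard error of the mean: s / sqrt(n), with the sample standard deviation s\n",
"std_err = scores.std(ddof=1) / np.sqrt(len(scores))\n",
"\n",
"print('{0:.3f} (+/- {1:.3f})'.format(mean, std_err))  # 0.933 (+/- 0.020)\n",
"print(np.isclose(std_err, sem(scores)))                # True\n",
"```"
]
},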
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tuning the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now tune the algorithm by computing the test accuracy for each value of the parameter k (the number of neighbours) from 1 to 20, in order to find the best value."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"k_range = range(1, 21)\n",
"accuracy = []\n",
"for k in k_range:\n",
"    # Train a kNN model with k neighbours and measure its test accuracy\n",
"    m = KNeighborsClassifier(n_neighbors=k)\n",
"    m.fit(x_train, y_train)\n",
"    y_pred = m.predict(x_test)\n",
"    accuracy.append(metrics.accuracy_score(y_test, y_pred))\n",
"plt.plot(k_range, accuracy)\n",
"plt.xlabel('k value')\n",
"plt.ylabel('Accuracy')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result depends strongly on the train/test partition. Run train_test_split again (for example, with a different random_state) and check how the accuracy curve changes with k."
]
},
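{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since a single split is noisy, a more robust way to choose k is to average the k-fold cross-validation accuracy for each candidate value. A minimal sketch under that approach, reusing the pipeline and the `KFold` settings of the K-Fold validation section (`pipe`, `mean_scores` and `best_k` are illustrative names); scikit-learn's `GridSearchCV` automates the same kind of search.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"iris = datasets.load_iris()\n",
"cv = KFold(n_splits=10, shuffle=True, random_state=33)\n",
"\n",
"k_range = range(1, 21)\n",
"mean_scores = []\n",
"for k in k_range:\n",
"    pipe = Pipeline([('scaler', StandardScaler()),\n",
"                     ('kNN', KNeighborsClassifier(n_neighbors=k))])\n",
"    # Mean accuracy over the 10 folds for this value of k\n",
"    mean_scores.append(cross_val_score(pipe, iris.data, iris.target, cv=cv).mean())\n",
"\n",
"best_k = k_range[int(np.argmax(mean_scores))]\n",
"print('Best k:', best_k, 'mean accuracy: {:.3f}'.format(max(mean_scores)))\n",
"```"
]
},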
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [KNeighborsClassifier API scikit-learn](http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html)\n",
"* [Learning scikit-learn: Machine Learning in Python](http://proquest.safaribooksonline.com/book/programming/python/9781783281930/1dot-machine-learning-a-gentle-introduction/ch01s02_html), Raúl Garreta; Guillermo Moncecchi, Packt Publishing, 2013.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Licence\n",
"The notebook is freely licensed under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/).\n",
"\n",
"© Carlos A. Iglesias, Universidad Politécnica de Madrid."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}