mirror of https://github.com/gsi-upm/sitc synced 2024-11-14 02:32:27 +00:00
sitc/ml1/2_5_1_kNN_Model.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](files/images/EscUpmPolit_p.gif \"UPM\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Course Notes for Learning Intelligent Systems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © 2016 Carlos A. Iglesias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Introduction to Machine Learning](2_0_0_Intro_ML.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table of Contents\n",
"* [kNN Model](#kNN-Model)\n",
"* [Loading data and preprocessing](#Loading-data-and-preprocessing)\n",
"* [Train classifier](#Train-classifier)\n",
"* [Evaluating the algorithm](#Evaluating-the-algorithm)\n",
"\t* [Precision, recall and f-score](#Precision,-recall-and-f-score)\n",
"\t* [Confusion matrix](#Confusion-matrix)\n",
"\t* [K-Fold validation](#K-Fold-validation)\n",
"* [Tuning the algorithm](#Tuning-the-algorithm)\n",
"* [References](#References)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# kNN Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal of this notebook is to learn how to train a model, make predictions with that model and evaluate these predictions.\n",
"\n",
"The notebook uses the [kNN (k nearest neighbors) algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)."
]
},
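{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the idea behind kNN, a classifier can simply store the training samples and label each new sample with the majority class among its k closest training samples. The cell below is a hypothetical helper (`knn_predict` is not part of scikit-learn). It assumes Euclidean distance and plain majority voting on a made-up 2D dataset. scikit-learn's `KNeighborsClassifier` follows the same principle, but it can also use tree-based search structures to find neighbors faster."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def knn_predict(train_x, train_y, new_x, k=3):\n",
"    # Hypothetical helper: plain kNN with Euclidean distance and majority vote\n",
"    predictions = []\n",
"    for x in new_x:\n",
"        # Euclidean distance from x to every training sample\n",
"        distances = np.sqrt(((train_x - x) ** 2).sum(axis=1))\n",
"        # Indices of the k closest training samples\n",
"        nearest = np.argsort(distances)[:k]\n",
"        # Majority vote among their labels (ties go to the smallest label)\n",
"        votes = np.bincount(train_y[nearest])\n",
"        predictions.append(np.argmax(votes))\n",
"    return np.array(predictions)\n",
"\n",
"# Made-up data: two well-separated groups in 2D\n",
"x_toy = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],\n",
"                  [5.0, 5.0], [5.1, 4.9], [4.8, 5.2]])\n",
"y_toy = np.array([0, 0, 0, 1, 1, 1])\n",
"print(knn_predict(x_toy, y_toy, np.array([[0.3, 0.3], [4.9, 5.0]]), k=3))  # prints [0 1]"
]
},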
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading data and preprocessing\n",
"\n",
"The first step is loading and preprocessing the data as explained in the previous notebooks."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# library for displaying plots\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# display plots in the notebook \n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"## First, we repeat the load and preprocessing steps\n",
"\n",
"# Load data\n",
"from sklearn import datasets\n",
"iris = datasets.load_iris()\n",
"\n",
"# Training and test splitting\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"x_iris, y_iris = iris.data, iris.target\n",
"\n",
"# The test set is a random 25% of the samples\n",
"x_train, x_test, y_train, y_test = train_test_split(x_iris, y_iris, test_size=0.25, random_state=33)\n",
"\n",
"# Preprocess: standardize each feature (zero mean, unit variance), fitting the scaler on the training set only\n",
"from sklearn import preprocessing\n",
"scaler = preprocessing.StandardScaler().fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_test = scaler.transform(x_test)"
]
},
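{
"cell_type": "markdown",
"metadata": {},
"source": [
"`StandardScaler` standardizes each feature by subtracting its mean and dividing by its standard deviation. Both statistics are computed on the training set only, and the test set is transformed with those same training statistics. The cell below is a minimal check of this on the same split. The variable names are new so that the notebook's `x_train` and `x_test` are not overwritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets, preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"data = datasets.load_iris()\n",
"x_tr, x_te, y_tr, y_te = train_test_split(data.data, data.target, test_size=0.25, random_state=33)\n",
"demo_scaler = preprocessing.StandardScaler().fit(x_tr)\n",
"\n",
"# Statistics computed on the training set only\n",
"train_mean = x_tr.mean(axis=0)\n",
"train_std = x_tr.std(axis=0)  # population std (ddof=0), as StandardScaler uses\n",
"\n",
"# Standardize the test set by hand with the training statistics\n",
"x_te_manual = (x_te - train_mean) / train_std\n",
"print(np.allclose(x_te_manual, demo_scaler.transform(x_te)))  # prints True"
]
},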
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Train classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The usual steps for creating a classifier are:\n",
"1. Create the classifier object\n",
"2. Call *fit* to train the classifier\n",
"3. Call *predict* to obtain predictions\n",
"\n",
"Once the model is created, the most relevant methods are:\n",
"* model.fit(x_train, y_train): train the model\n",
"* model.predict(x): predict the class of each sample in x\n",
"* model.score(x, y): return the mean accuracy of model.predict(x) with respect to y"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=1, n_neighbors=15, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"import numpy as np\n",
"\n",
"# Create kNN model\n",
"model = KNeighborsClassifier(n_neighbors=15)\n",
"\n",
"# Train the model using the training sets\n",
"model.fit(x_train, y_train) "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 1 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 2 1 1 0 1 1 0 2 1 2 1 2 0 2 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n",
"Expected [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 2 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 1 1 1 0 1 1 0 2 2 2 1 2 0 1 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n"
]
}
],
"source": [
"print(\"Prediction \", model.predict(x_train))\n",
"print(\"Expected \", y_train)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in training 0.964285714286\n"
]
}
],
"source": [
"# Evaluate accuracy on the training set\n",
"\n",
"from sklearn import metrics\n",
"y_train_pred = model.predict(x_train)\n",
"print(\"Accuracy in training\", metrics.accuracy_score(y_train, y_train_pred))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in testing 0.921052631579\n"
]
}
],
"source": [
"# Now we evaluate accuracy on the test set\n",
"y_test_pred = model.predict(x_test)\n",
"print(\"Accuracy in testing \", metrics.accuracy_score(y_test, y_test_pred))"
]
},
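{
"cell_type": "markdown",
"metadata": {},
"source": [
"Accuracy is the fraction of samples whose predicted label equals the true label. For scikit-learn classifiers, `score(x, y)` returns this same mean accuracy. The cell below is a self-contained check that assumes the same split, scaling and k=15 as above. It uses new variable names so the notebook's variables are not overwritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets, metrics, preprocessing\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"data = datasets.load_iris()\n",
"x_tr, x_te, y_tr, y_te = train_test_split(data.data, data.target, test_size=0.25, random_state=33)\n",
"sc = preprocessing.StandardScaler().fit(x_tr)\n",
"x_tr, x_te = sc.transform(x_tr), sc.transform(x_te)\n",
"\n",
"knn = KNeighborsClassifier(n_neighbors=15).fit(x_tr, y_tr)\n",
"y_pred = knn.predict(x_te)\n",
"\n",
"acc_manual = np.mean(y_pred == y_te)               # fraction of correct predictions\n",
"acc_metric = metrics.accuracy_score(y_te, y_pred)  # scikit-learn metric\n",
"acc_score = knn.score(x_te, y_te)                  # mean accuracy of the classifier\n",
"print(acc_manual, acc_metric, acc_score)"
]
},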
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are going to visualize the nearest neighbors classification by plotting the decision boundaries for each class.\n",
"\n",
"We load a function defined in the file [util_knn.py](files/util_knn.py) by executing that file with the *magic command* **%run**."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc7cb622908>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc7cb632518>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%run util_knn.py\n",
"plot_classification_iris()"
]
},
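{
"cell_type": "markdown",
"metadata": {},
"source": [
"The contents of util_knn.py are not shown in this notebook. The cell below is therefore only a hypothetical sketch of how a decision-boundary plot can be produced, not the code of `plot_classification_iris`. It assumes the first two iris features (sepal length and width) so that the boundaries can be drawn in 2D, and it uses k=15 as above. The classifier is evaluated on a dense grid, and each grid point is coloured by its predicted class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn import datasets\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"data = datasets.load_iris()\n",
"x2 = data.data[:, :2]  # sepal length and sepal width only\n",
"y2 = data.target\n",
"clf = KNeighborsClassifier(n_neighbors=15).fit(x2, y2)\n",
"\n",
"# Dense grid covering the feature range, with a margin of 1 on each side\n",
"x_min, x_max = x2[:, 0].min() - 1, x2[:, 0].max() + 1\n",
"y_min, y_max = x2[:, 1].min() - 1, x2[:, 1].max() + 1\n",
"xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),\n",
"                     np.arange(y_min, y_max, 0.02))\n",
"\n",
"# Predict the class of every grid point and reshape back to the grid\n",
"zz = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)\n",
"\n",
"# Coloured regions are the decision areas; dots are the training samples\n",
"plt.contourf(xx, yy, zz, alpha=0.3)\n",
"plt.scatter(x2[:, 0], x2[:, 1], c=y2, edgecolor='k')\n",
"plt.xlabel(data.feature_names[0])\n",
"plt.ylabel(data.feature_names[1])"
]
},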
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision, recall and f-score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For evaluating classification algorithms, we usually calculate three metrics: precision, recall and F1-score.\n",
"\n",
"* **Precision**: The proportion of instances predicted as positive that are actually positive (it measures how often the classifier is right when it says that an instance is positive).\n",
"* **Recall**: The proportion of actually positive instances that the classifier predicts as positive (it measures how many of the positive instances the classifier finds).\n",
"* **F1-score**: The harmonic mean of precision and recall, which combines both in a single number."
]
},
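{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a given class, let TP be the true positives, FP the false positives and FN the false negatives. Then precision = TP / (TP + FP), recall = TP / (TP + FN) and F1 = 2 * precision * recall / (precision + recall). As a sketch, these can be computed from a confusion matrix whose rows are the true classes and whose columns are the predicted classes (scikit-learn's convention). The matrix below is copied from the test-set confusion matrix printed in the Confusion matrix section of this notebook, so the results should agree with the classification report."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Test-set confusion matrix (rows: true class, columns: predicted class)\n",
"cm = np.array([[8, 0, 0],\n",
"               [0, 11, 0],\n",
"               [0, 3, 16]])\n",
"\n",
"tp = np.diag(cm)                 # correctly classified samples per class\n",
"precision = tp / cm.sum(axis=0)  # TP / (TP + FP): divide by column totals\n",
"recall = tp / cm.sum(axis=1)     # TP / (TP + FN): divide by row totals\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"for name, p, r, f in zip(['setosa', 'versicolor', 'virginica'], precision, recall, f1):\n",
"    print('{:>10}  precision={:.2f}  recall={:.2f}  f1={:.2f}'.format(name, p, r, f))"
]
},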
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" setosa 1.00 1.00 1.00 8\n",
" versicolor 0.79 1.00 0.88 11\n",
" virginica 1.00 0.84 0.91 19\n",
"\n",
"avg / total 0.94 0.92 0.92 38\n",
"\n"
]
}
],
"source": [
"print(metrics.classification_report(y_test, y_test_pred, target_names=iris.target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another useful evaluation tool is the confusion matrix: entry (i, j) counts the test samples of true class i that were predicted as class j."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8 0 0]\n",
" [ 0 11 0]\n",
" [ 0 3 16]]\n"
]
}
],
"source": [
"print(metrics.confusion_matrix(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All the 'setosa' and 'versicolor' test samples are classified correctly, while 3 of the 19 'virginica' samples are misclassified as 'versicolor'."
]
},
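{
"cell_type": "markdown",
"metadata": {},
"source": [
"Entry (i, j) of the matrix counts the samples of true class i that were predicted as class j, so correct predictions lie on the diagonal. The cell below is a minimal sketch that builds the matrix by counting, using a few made-up labels, and compares the result with `metrics.confusion_matrix`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Made-up true and predicted labels for three classes\n",
"y_true = np.array([0, 0, 1, 1, 2, 2, 2])\n",
"y_pred = np.array([0, 0, 1, 1, 1, 2, 2])\n",
"\n",
"n_classes = 3\n",
"cm_manual = np.zeros((n_classes, n_classes), dtype=int)\n",
"for t, p in zip(y_true, y_pred):\n",
"    cm_manual[t, p] += 1  # row: true class, column: predicted class\n",
"\n",
"print(cm_manual)\n",
"print(np.array_equal(cm_manual, metrics.confusion_matrix(y_true, y_pred)))  # prints True"
]
},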
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### K-Fold validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
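"A single train/test split can give an optimistic or pessimistic estimate, depending on which samples end up in each partition. To reduce this dependence on one particular partition, it is recommended to use **k-fold cross-validation**. The data is split into k folds, and the model is trained k times, each time using a different fold for testing and the remaining k-1 folds for training."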
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.93333333 0.8 1. 0.93333333 0.93333333 0.93333333\n",
" 1. 1. 0.86666667 1. ]\n"
]
}
],
"source": [
"from sklearn.model_selection import cross_val_score, KFold\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# create a composite estimator made by a pipeline of preprocessing and the KNN model\n",
"model = Pipeline([\n",
" ('scaler', StandardScaler()),\n",
" ('kNN', KNeighborsClassifier())  # default n_neighbors=5\n",
"])\n",
"\n",
"# create a k-fold cross validation iterator of k=10 folds\n",
"cv = KFold(10, shuffle=True, random_state=33)\n",
"\n",
"# by default the score used is the one returned by score method of the estimator (accuracy)\n",
"scores = cross_val_score(model, x_iris, y_iris, cv=cv)\n",
"print(scores)"
]
},
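{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, `cross_val_score` loops over the folds. For each fold, it fits a fresh copy of the whole pipeline on the other k-1 folds and scores it on the held-out fold. Because the scaler is part of the pipeline, the scaling statistics never see the held-out fold. The cell below sketches that loop, assuming the same pipeline and `KFold` settings as above. With a fixed `random_state`, it should reproduce the same ten scores."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.base import clone\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"data = datasets.load_iris()\n",
"X, y = data.data, data.target\n",
"pipe = Pipeline([('scaler', StandardScaler()), ('kNN', KNeighborsClassifier())])\n",
"kfold = KFold(10, shuffle=True, random_state=33)\n",
"\n",
"fold_scores = []\n",
"for train_idx, test_idx in kfold.split(X):\n",
"    # Fresh, unfitted copy of the pipeline for each fold\n",
"    fold_model = clone(pipe).fit(X[train_idx], y[train_idx])\n",
"    fold_scores.append(fold_model.score(X[test_idx], y[test_idx]))\n",
"\n",
"print(np.round(fold_scores, 3))\n",
"print(np.allclose(fold_scores, cross_val_score(pipe, X, y, cv=kfold)))  # prints True"
]
},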
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"We get an array of k scores. We can calculate their mean and standard error to obtain a single final figure."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean score: 0.940 (+/- 0.021)\n"
]
}
],
"source": [
"from scipy.stats import sem\n",
"def mean_score(scores):\n",
" return (\"Mean score: {0:.3f} (+/- {1:.3f})\").format(np.mean(scores), sem(scores))\n",
"print(mean_score(scores))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, we get an average accuracy of 0.940."
]
},
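{
"cell_type": "markdown",
"metadata": {},
"source": [
"The standard error of the mean is the sample standard deviation of the k scores divided by sqrt(k). `scipy.stats.sem` uses one degree of freedom (ddof=1) by default. The cell below is a quick check using the ten fold scores printed above, copied by hand and therefore slightly rounded."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.stats import sem\n",
"\n",
"fold_acc = np.array([0.93333333, 0.8, 1.0, 0.93333333, 0.93333333,\n",
"                     0.93333333, 1.0, 1.0, 0.86666667, 1.0])\n",
"\n",
"# Sample standard deviation (ddof=1) divided by the square root of k\n",
"manual_sem = fold_acc.std(ddof=1) / np.sqrt(len(fold_acc))\n",
"print('{:.3f} {:.3f}'.format(manual_sem, sem(fold_acc)))  # prints 0.021 0.021"
]
},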
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tuning the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are going to tune the algorithm by finding the value of the k parameter (n_neighbors) that gives the best accuracy on the test set."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.text.Text at 0x7fc7cb526160>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc7cb726ac8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"k_range = range(1, 21)\n",
"accuracy = []\n",
"for k in k_range:\n",
" m = KNeighborsClassifier(n_neighbors=k)\n",
" m.fit(x_train, y_train)\n",
" y_test_pred = m.predict(x_test)\n",
" accuracy.append(metrics.accuracy_score(y_test, y_test_pred))\n",
"plt.plot(k_range, accuracy)\n",
"plt.xlabel('k value')\n",
"plt.ylabel('Accuracy')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result depends strongly on the input data. Run train_test_split again (for example, with a different random_state) and check how the accuracy changes with k."
]
},
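{
"cell_type": "markdown",
"metadata": {},
"source": [
"Choosing k by its accuracy on the test set tends to make that test accuracy an optimistic estimate, and a single split makes the curve noisy. One common alternative is to pick the k with the best mean cross-validation accuracy. The cell below sketches this, assuming the same scaler + kNN pipeline and 10-fold `KFold` used earlier. scikit-learn's `GridSearchCV` automates this kind of search."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import KFold, cross_val_score\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"data = datasets.load_iris()\n",
"kfold = KFold(10, shuffle=True, random_state=33)\n",
"\n",
"k_values = list(range(1, 21))\n",
"cv_means = []\n",
"for k in k_values:\n",
"    pipe = Pipeline([('scaler', StandardScaler()),\n",
"                     ('kNN', KNeighborsClassifier(n_neighbors=k))])\n",
"    # Mean accuracy over the 10 folds for this value of k\n",
"    cv_means.append(cross_val_score(pipe, data.data, data.target, cv=kfold).mean())\n",
"\n",
"best_k = k_values[int(np.argmax(cv_means))]\n",
"print('Best k:', best_k, 'mean CV accuracy: {:.3f}'.format(max(cv_means)))"
]
},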
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [KNeighborsClassifier API scikit-learn](http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html)\n",
"* [Learning scikit-learn: Machine Learning in Python](http://proquest.safaribooksonline.com/book/programming/python/9781783281930/1dot-machine-learning-a-gentle-introduction/ch01s02_html), Raúl Garreta; Guillermo Moncecchi, Packt Publishing, 2013.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Licence\n",
"The notebook is freely licensed under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/). \n",
"\n",
"© 2016 Carlos A. Iglesias, Universidad Politécnica de Madrid."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}