mirror of https://github.com/gsi-upm/sitc synced 2024-11-22 14:32:28 +00:00
sitc/ml1/2_5_1_kNN_Model.ipynb


2016-03-15 12:55:14 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](files/images/EscUpmPolit_p.gif \"UPM\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Course Notes for Learning Intelligent Systems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © 2016 Carlos A. Iglesias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Introduction to Machine Learning](2_0_0_Intro_ML.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table of Contents\n",
"* [kNN Model](#kNN-Model)\n",
"* [Load data and preprocessing](#Load-data-and-preprocessing)\n",
"* [Train classifier](#Train-classifier)\n",
"* [Evaluating the algorithm](#Evaluating-the-algorithm)\n",
"\t* [Precision, recall and f-score](#Precision,-recall-and-f-score)\n",
"\t* [Confusion matrix](#Confusion-matrix)\n",
"\t* [K-Fold validation](#K-Fold-validation)\n",
"* [Tuning the algorithm](#Tuning-the-algorithm)\n",
"* [References](#References)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# kNN Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal of this notebook is to learn how to train a model, make predictions with it, and evaluate those predictions.\n",
"\n",
"The notebook uses the [kNN (k nearest neighbors) algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)."
]
},
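{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the idea behind kNN, the sketch below classifies a point by majority vote among its k closest training samples, using Euclidean distance. The function `knn_predict` and the toy data are assumptions made up for this example; they are not part of scikit-learn, whose `KNeighborsClassifier` is considerably more elaborate."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"from collections import Counter\n",
"\n",
"def knn_predict(samples, labels, query, k=3):\n",
"    # Euclidean distance from the query point to every training sample\n",
"    distances = np.linalg.norm(samples - query, axis=1)\n",
"    # indices of the k closest training samples\n",
"    nearest = np.argsort(distances)[:k]\n",
"    # majority vote among the labels of those neighbours\n",
"    votes = Counter(labels[nearest].tolist())\n",
"    return votes.most_common(1)[0][0]\n",
"\n",
"# toy data (illustrative values): two well separated groups in 2D\n",
"toy_x = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1],\n",
"                  [1.0, 1.0], [0.9, 1.1], [1.1, 0.9]])\n",
"toy_y = np.array([0, 0, 0, 1, 1, 1])\n",
"\n",
"print(knn_predict(toy_x, toy_y, np.array([0.15, 0.15]), k=3))\n",
"print(knn_predict(toy_x, toy_y, np.array([0.95, 1.0]), k=3))"
]
},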
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data and preprocessing\n",
"\n",
"The first step is loading and preprocessing the data as explained in the previous notebooks."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# library for displaying plots\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# display plots in the notebook \n",
"%matplotlib inline\n",
"\n",
"## First, we repeat the load and preprocessing steps\n",
"\n",
"# Load data\n",
"from sklearn import datasets\n",
"iris = datasets.load_iris()\n",
"\n",
"# Training and test splitting\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"x_iris, y_iris = iris.data, iris.target\n",
"\n",
"# The test set is a random 25% of the samples\n",
"x_train, x_test, y_train, y_test = train_test_split(x_iris, y_iris, test_size=0.25, random_state=33)\n",
"\n",
"# Preprocess: normalize\n",
"from sklearn import preprocessing\n",
"scaler = preprocessing.StandardScaler().fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_test = scaler.transform(x_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Train classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The usual steps for creating a classifier are:\n",
"1. Create classifier object\n",
"2. Call *fit* to train the classifier\n",
"3. Call *predict* to obtain predictions\n",
"\n",
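"Once the model is created, the most relevant methods are:\n",
"* model.fit(x_train, y_train): train the model on the training data\n",
"* model.predict(x): predict the class labels of the samples in x\n",
"* model.score(x, y): return the mean accuracy of the predictions for x with respect to the true labels y"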
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',\n",
" metric_params=None, n_jobs=1, n_neighbors=15, p=2,\n",
" weights='uniform')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"import numpy as np\n",
"\n",
"# Create kNN model\n",
"model = KNeighborsClassifier(n_neighbors=15)\n",
"\n",
"# Train the model using the training sets\n",
"model.fit(x_train, y_train) "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 1 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 2 1 1 0 1 1 0 2 1 2 1 2 0 2 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n",
"Expected [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 2 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 1 1 1 0 1 1 0 2 2 2 1 2 0 1 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n"
]
}
],
"source": [
"print(\"Prediction \", model.predict(x_train))\n",
"print(\"Expected \", y_train)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in training 0.964285714286\n"
]
}
],
"source": [
"# Evaluate Accuracy in training\n",
"\n",
"from sklearn import metrics\n",
"y_train_pred = model.predict(x_train)\n",
"print(\"Accuracy in training\", metrics.accuracy_score(y_train, y_train_pred))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in testing 0.921052631579\n"
]
}
],
"source": [
"# Now we evaluate accuracy on the test set\n",
"y_test_pred = model.predict(x_test)\n",
"print(\"Accuracy in testing \", metrics.accuracy_score(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are going to visualize the Nearest Neighbors classification. It will plot the decision boundaries for each class.\n",
"\n",
"We are going to import a function defined in the file [util_knn.py](files/util_knn.py) using the *magic command* **%run**."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f21a42a1ef0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f21a42a1c88>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%run util_knn.py\n",
"plot_classification_iris()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision, recall and f-score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For evaluating classification algorithms, we usually calculate three metrics: precision, recall and F1-score\n",
"\n",
"* **Precision**: This computes the proportion of instances predicted as positives that were correctly evaluated (it measures how right our classifier is when it says that an instance is positive).\n",
"* **Recall**: This counts the proportion of positive instances that were correctly evaluated (measuring how right our classifier is when faced with a positive instance).\n",
"* **F1-score**: This is the harmonic mean of precision and recall, and tries to combine both in a single number."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" setosa 1.00 1.00 1.00 8\n",
" versicolor 0.79 1.00 0.88 11\n",
" virginica 1.00 0.84 0.91 19\n",
"\n",
"avg / total 0.94 0.92 0.92 38\n",
"\n"
]
}
],
"source": [
"print(metrics.classification_report(y_test, y_test_pred,target_names=iris.target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
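"Another useful tool is the confusion matrix. Each row corresponds to a true class and each column to a predicted class, so correct predictions lie on the diagonal."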
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8 0 0]\n",
" [ 0 11 0]\n",
" [ 0 3 16]]\n"
]
}
],
"source": [
"print(metrics.confusion_matrix(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"All the 'setosa' and 'versicolor' samples are classified correctly. The only errors are 3 'virginica' samples that are predicted as 'versicolor'."
]
},
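{
"cell_type": "markdown",
"metadata": {},
"source": [
"Precision, recall and F1-score can be derived from the confusion matrix. The sketch below copies by hand the matrix printed above (rows are true classes, columns are predicted classes) and recomputes the per-class figures; the variable names are chosen only for this example. The values should agree with the classification report shown earlier, up to rounding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# confusion matrix copied from the output above\n",
"cm_example = np.array([[8, 0, 0],\n",
"                       [0, 11, 0],\n",
"                       [0, 3, 16]])\n",
"\n",
"true_positives = np.diag(cm_example)\n",
"# column sums: how many samples were predicted as each class\n",
"precision = true_positives / cm_example.sum(axis=0)\n",
"# row sums: how many samples actually belong to each class\n",
"recall = true_positives / cm_example.sum(axis=1)\n",
"# harmonic mean of precision and recall\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"for name, p, r, f in zip(['setosa', 'versicolor', 'virginica'], precision, recall, f1):\n",
"    print('{0:>10}  precision={1:.2f}  recall={2:.2f}  f1={3:.2f}'.format(name, p, r, f))"
]
},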
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### K-Fold validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
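"A single train/test partition can give a biased estimate of performance. To reduce the dependence on one particular split, it is recommended to use **k-fold cross-validation**: the data is divided into k folds, and each fold is used once for testing while the remaining k-1 folds are used for training."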
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0.93333333 0.8 1. 0.93333333 0.93333333 0.93333333\n",
" 1. 1. 0.86666667 1. ]\n"
]
}
],
"source": [
"from sklearn.model_selection import cross_val_score, KFold\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# create a composite estimator made by a pipeline of preprocessing and the KNN model\n",
"model = Pipeline([\n",
"    ('scaler', StandardScaler()),\n",
"    ('kNN', KNeighborsClassifier())\n",
"])\n",
"\n",
"# create a k-fold cross validation iterator of k=10 folds\n",
"\n",
"cv = KFold(n_splits=10, shuffle=True, random_state=33)\n",
"\n",
"# by default the score used is the one returned by score method of the estimator (accuracy)\n",
"scores = cross_val_score(model, x_iris, y_iris, cv=cv)\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"We get an array of k scores, one per fold. We can summarize them with the mean and the standard error of the mean."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean score: 0.940 (+/- 0.021)\n"
]
}
],
"source": [
"from scipy.stats import sem\n",
"def mean_score(scores):\n",
" return (\"Mean score: {0:.3f} (+/- {1:.3f})\").format(np.mean(scores), sem(scores))\n",
"print(mean_score(scores))"
]
},
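{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the summary above less of a black box, the following sketch recomputes the mean and the standard error by hand. It assumes the ten fold accuracies printed earlier, copied manually into `fold_scores` (a name used only here), and the usual definition of the standard error as the sample standard deviation divided by the square root of the number of scores."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# the ten fold accuracies printed above, copied by hand\n",
"fold_scores = np.array([0.93333333, 0.8, 1.0, 0.93333333, 0.93333333,\n",
"                        0.93333333, 1.0, 1.0, 0.86666667, 1.0])\n",
"\n",
"mean_accuracy = fold_scores.mean()\n",
"# sample standard deviation (ddof=1), the same default used by scipy.stats.sem\n",
"sample_std = fold_scores.std(ddof=1)\n",
"standard_error = sample_std / np.sqrt(len(fold_scores))\n",
"\n",
"print('Mean score: {0:.3f} (+/- {1:.3f})'.format(mean_accuracy, standard_error))"
]
},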
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, we get an average accuracy of 0.940."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tuning the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are going to tune the algorithm, and calculate which is the best value for the k parameter."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.text.Text at 0x7f21711a9cc0>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f21711cc6d8>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"k_range = range(1, 21)\n",
"accuracy = []\n",
"for k in k_range:\n",
" m = KNeighborsClassifier(k)\n",
" m.fit(x_train, y_train)\n",
" y_test_pred = m.predict(x_test)\n",
" accuracy.append(metrics.accuracy_score(y_test, y_test_pred))\n",
"plt.plot(k_range, accuracy)\n",
"plt.xlabel('k value')\n",
"plt.ylabel('Accuracy')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The result depends heavily on the input data. Run train_test_split again (for example, with a different random_state) and check how the accuracy changes with k."
]
},
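{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible way to make the choice of k less dependent on a single split is to evaluate each candidate with cross-validation. The sketch below is only a suggestion, not part of the original exercise: it assumes scikit-learn 0.18 or later (for `sklearn.model_selection`), reuses the pipeline step name `kNN`, and picks the same range of k, number of folds and `random_state` as above purely for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"from sklearn.model_selection import GridSearchCV, KFold\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"iris_data = datasets.load_iris()\n",
"\n",
"# scaling is fitted inside each fold, so no test information leaks into training\n",
"knn_pipeline = Pipeline([\n",
"    ('scaler', StandardScaler()),\n",
"    ('kNN', KNeighborsClassifier())\n",
"])\n",
"\n",
"# '<step name>__<parameter>' addresses a parameter of a pipeline step\n",
"param_grid = {'kNN__n_neighbors': list(range(1, 21))}\n",
"\n",
"folds = KFold(n_splits=10, shuffle=True, random_state=33)\n",
"search = GridSearchCV(knn_pipeline, param_grid, cv=folds, scoring='accuracy')\n",
"search.fit(iris_data.data, iris_data.target)\n",
"\n",
"best_k = search.best_params_['kNN__n_neighbors']\n",
"print('Best k:', best_k)\n",
"print('Cross-validated accuracy: {0:.3f}'.format(search.best_score_))"
]
},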
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [KNeighborsClassifier API scikit-learn](http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html)\n",
"* [Learning scikit-learn: Machine Learning in Python](http://proquest.safaribooksonline.com/book/programming/python/9781783281930/1dot-machine-learning-a-gentle-introduction/ch01s02_html), Raúl Garreta; Guillermo Moncecchi, Packt Publishing, 2013.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Licence\n",
"The notebook is freely licensed under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/). \n",
"\n",
"© 2016 Carlos A. Iglesias, Universidad Politécnica de Madrid."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}