mirror of https://github.com/gsi-upm/sitc synced 2024-11-22 22:42:28 +00:00
sitc/ml1/2_5_2_Decision_Tree_Model.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](files/images/EscUpmPolit_p.gif \"UPM\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Course Notes for Learning Intelligent Systems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © 2016 Carlos A. Iglesias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Introduction to Machine Learning](2_0_0_Intro_ML.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table of Contents\n",
"* [Decision Tree Learning](#Decision-Tree-Learning)\n",
"* [Load data and preprocessing](#Load-data-and-preprocessing)\n",
"* [Train classifier](#Train-classifier)\n",
"* [Evaluating the algorithm](#Evaluating-the-algorithm)\n",
"\t* [Precision, recall and f-score](#Precision,-recall-and-f-score)\n",
"\t* [Confusion matrix](#Confusion-matrix)\n",
"\t* [K-Fold cross validation](#K-Fold-cross-validation)\n",
"* [References](#References)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Decision Tree Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The goal of this notebook is to learn how to create a classification object using a [decision tree learning algorithm](https://en.wikipedia.org/wiki/Decision_tree_learning). \n",
"\n",
"There are a number of well-known machine learning algorithms for decision tree learning, such as ID3, C4.5, C5.0 and CART. The scikit-learn library uses an optimised version of the [CART (Classification and Regression Trees) algorithm](https://en.wikipedia.org/wiki/Predictive_analytics#Classification_and_regression_trees).\n",
"\n",
"This notebook follows the same steps as the previous notebook, which used the [kNN Model](2_5_1_kNN_Model.ipynb), and details some peculiarities of decision tree algorithms."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data and preprocessing\n",
"\n",
"Here we repeat the same data loading and preprocessing operations as in the previous notebooks."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# library for displaying plots\n",
"import matplotlib.pyplot as plt\n",
"# display plots in the notebook \n",
"%matplotlib inline\n",
"\n",
"## First, we repeat the load and preprocessing steps\n",
"\n",
"# Load data\n",
"from sklearn import datasets\n",
"iris = datasets.load_iris()\n",
"\n",
"# Training and test splitting (train_test_split lives in sklearn.model_selection since scikit-learn 0.18)\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"x_iris, y_iris = iris.data, iris.target\n",
"# Test set will be the 25% taken randomly\n",
"x_train, x_test, y_train, y_test = train_test_split(x_iris, y_iris, test_size=0.25, random_state=33)\n",
"\n",
"# Preprocess: normalize\n",
"from sklearn import preprocessing\n",
"scaler = preprocessing.StandardScaler().fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_test = scaler.transform(x_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Train classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The usual steps for creating a classifier are:\n",
"1. Create classifier object\n",
"2. Call *fit* to train the classifier\n",
"3. Call *predict* to obtain predictions\n",
"\n",
"*DecisionTreeClassifier* is capable of both binary (where the labels are [-1, 1]) classification and multiclass (where the labels are [0, ..., K-1]) classification."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,\n",
" max_features=None, max_leaf_nodes=None, min_samples_leaf=1,\n",
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
" presort=False, random_state=1, splitter='best')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"import numpy as np\n",
"\n",
"from sklearn import tree\n",
"\n",
"max_depth=3\n",
"random_state=1\n",
"\n",
"# Create decision tree model\n",
"model = tree.DecisionTreeClassifier(max_depth=max_depth, random_state=random_state)\n",
"\n",
"# Train the model using the training sets\n",
"model.fit(x_train, y_train) "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 2 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 2 1 1 0 1 1 0 2 1 2 1 2 0 1 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n",
"Expected [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 2 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 1 1 1 0 1 1 0 2 2 2 1 2 0 1 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n"
]
}
],
"source": [
"print(\"Prediction \", model.predict(x_train))\n",
"print(\"Expected \", y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, the probability of each class can be predicted, which is the fraction of training samples of the same class in a leaf:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted probabilities [[ 0. 0.97368421 0.02631579]\n",
" [ 1. 0. 0. ]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 1. 0. 0. ]\n",
" [ 1. 0. 0. ]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 1. 0. 0. ]\n",
" [ 0. 0. 1. ]]\n"
]
}
],
"source": [
"# Print the predicted class probabilities for the first 10 training samples\n",
"print(\"Predicted probabilities\", model.predict_proba(x_train[:10]))"
]
},
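{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of where these probabilities come from, the sketch below recomputes the class fractions for a single leaf. It assumes a leaf holding 0 setosa, 37 versicolor and 1 virginica training samples; these counts are taken from the pseudocode exported later in this notebook, and the variable names are only illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Assumed class counts in one leaf: [setosa, versicolor, virginica]\n",
"leaf_counts = np.array([0, 37, 1])\n",
"\n",
"# The predicted probability of each class is its fraction of the samples in the leaf\n",
"leaf_probabilities = leaf_counts / leaf_counts.sum()\n",
"print(leaf_probabilities)"
]
},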
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in training 0.982142857143\n"
]
}
],
"source": [
"# Evaluate Accuracy in training\n",
"\n",
"from sklearn import metrics\n",
"y_train_pred = model.predict(x_train)\n",
"print(\"Accuracy in training\", metrics.accuracy_score(y_train, y_train_pred))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in testing 0.921052631579\n"
]
}
],
"source": [
"# Now we evaluate accuracy on the test set\n",
"y_test_pred = model.predict(x_test)\n",
"print(\"Accuracy in testing \", metrics.accuracy_score(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are going to visualize the fitted decision tree. The plot shows the split rule learnt at each node and the class distribution of the samples that reach it.\n",
"\n",
"To obtain an image, you need to install pydotplus (`pip install pydotplus`) and Graphviz (`conda install graphviz`). The original pydot package did not work well with Python 3 when this notebook was written.\n",
"\n",
"You can skip this example. Since it may require installing additional packages, we include the result here.\n",
"![Decision Tree](files/images/cart.png)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Not built with libexpat. Table formatting is not available.\n",
"in label of node 0\n",
"in label of node 1\n",
"in label of node 2\n",
"in label of node 3\n",
"in label of node 4\n",
"in label of node 5\n",
"in label of node 6\n",
"in label of node 7\n",
"in label of node 8\n",
"\n",
"Warning: Not built with libexpat. Table formatting is not available.\n",
"in label of node 0\n",
"in label of node 1\n",
"in label of node 2\n",
"in label of node 3\n",
"in label of node 4\n",
"in label of node 5\n",
"in label of node 6\n",
"in label of node 7\n",
"in label of node 8\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from IPython.display import Image \n",
"from io import StringIO\n",
"import pydotplus as pydot\n",
"\n",
"dot_data = StringIO() \n",
"tree.export_graphviz(model, out_file=dot_data, \n",
" feature_names=iris.feature_names, \n",
" class_names=iris.target_names, \n",
" filled=True, rounded=True, \n",
" special_characters=True) \n",
"\n",
"\n",
"graph = pydot.graph_from_dot_data(dot_data.getvalue()) \n",
"graph.write_png('iris-tree.png')\n",
"Image(graph.create_png()) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we show a graph of the decision tree boundaries. For each pair of iris features, the decision tree learns decision boundaries made of combinations of simple thresholding rules inferred from the training samples.\n",
"\n",
"We are going to import a function defined in the file [util_ds.py](files/util_ds.py) using the *magic command* **%run**."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f0a2d296a20>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%run util_ds\n",
"\n",
"# display plots in the notebook \n",
"%matplotlib inline\n",
"plot_tree_iris()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we are going to export the pseudocode of the learnt decision tree."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"if ( petal width (cm) <= -0.415243178606 ) {\n",
" return setosa ( 42 examples )\n",
"}\n",
"else {\n",
" if ( petal width (cm) <= 0.861689627171 ) {\n",
" if ( petal length (cm) <= 0.818572163582 ) {\n",
" return versicolor ( 37 examples )\n",
" return virginica ( 1 examples )\n",
" }\n",
" else {\n",
" return versicolor ( 1 examples )\n",
" return virginica ( 2 examples )\n",
" }\n",
" }\n",
" else {\n",
" if ( petal length (cm) <= 0.707377433777 ) {\n",
" return versicolor ( 1 examples )\n",
" }\n",
" else {\n",
" return virginica ( 28 examples )\n",
" }\n",
" }\n",
"}\n"
]
}
],
"source": [
"%run util_ds\n",
"get_code(model, iris.feature_names, iris.target_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also obtain the feature importance of the fitted model as follows."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n",
"[ 0. 0. 0.05947455 0.94052545]\n"
]
}
],
"source": [
"print(iris.feature_names)\n",
"print(model.feature_importances_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that the most important feature for this classifier is `petal width`."
]
},
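{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the ranking explicit, the following sketch pairs each feature name with its importance and sorts them. The values are copied from the output above so that the cell runs on its own; inside this notebook you could pass `iris.feature_names` and `model.feature_importances_` instead. The variable names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Values copied from the output of the previous code cell\n",
"feature_names = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n",
"importances = [0.0, 0.0, 0.05947455, 0.94052545]\n",
"\n",
"# Sort the (name, importance) pairs from most to least important\n",
"ranking = sorted(zip(feature_names, importances), key=lambda pair: pair[1], reverse=True)\n",
"for name, importance in ranking:\n",
"    print('{0}: {1:.3f}'.format(name, importance))"
]
},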
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision, recall and f-score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For evaluating classification algorithms, we usually calculate three metrics: precision, recall and F1-score.\n",
"\n",
"* **Precision**: This computes the proportion of instances predicted as positives that were correctly evaluated (it measures how right our classifier is when it says that an instance is positive).\n",
"* **Recall**: This counts the proportion of positive instances that were correctly evaluated (measuring how right our classifier is when faced with a positive instance).\n",
"* **F1-score**: This is the harmonic mean of precision and recall, and tries to combine both in a single number."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" setosa 1.00 1.00 1.00 8\n",
" versicolor 0.79 1.00 0.88 11\n",
" virginica 1.00 0.84 0.91 19\n",
"\n",
"avg / total 0.94 0.92 0.92 38\n",
"\n"
]
}
],
"source": [
"print(metrics.classification_report(y_test, y_test_pred,target_names=iris.target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
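"Another useful tool is the confusion matrix. Each row corresponds to a true class and each column to a predicted class, so the diagonal counts the correctly classified samples."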
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 8 0 0]\n",
" [ 0 11 0]\n",
" [ 0 3 16]]\n"
]
}
],
"source": [
"print(metrics.confusion_matrix(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that all the 'setosa' and 'versicolor' samples are classified correctly, while 3 of the 19 'virginica' samples are misclassified as 'versicolor'."
]
},
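{
"cell_type": "markdown",
"metadata": {},
"source": [
"The precision, recall and F1-score reported earlier can be recomputed by hand from this confusion matrix. The sketch below is a minimal illustration with the matrix values copied from the output above; it assumes the scikit-learn convention that rows are true classes and columns are predicted classes, and it should reproduce the per-class values of the classification report."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Confusion matrix copied from the output above (rows: true class, columns: predicted class)\n",
"cm = np.array([[8, 0, 0],\n",
"               [0, 11, 0],\n",
"               [0, 3, 16]])\n",
"\n",
"true_positives = np.diag(cm)\n",
"# Precision: correct predictions of a class divided by all predictions of that class (column sums)\n",
"precision = true_positives / cm.sum(axis=0)\n",
"# Recall: correct predictions of a class divided by all true samples of that class (row sums)\n",
"recall = true_positives / cm.sum(axis=1)\n",
"# F1-score: harmonic mean of precision and recall\n",
"f1 = 2 * precision * recall / (precision + recall)\n",
"\n",
"for name, p, r, f in zip(['setosa', 'versicolor', 'virginica'], precision, recall, f1):\n",
"    print('{0:<12} precision={1:.2f} recall={2:.2f} f1={3:.2f}'.format(name, p, r, f))"
]
},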
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### K-Fold cross validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to avoid bias in the training and testing dataset partition, it is recommended to use **k-fold validation**.\n",
"\n",
"Sklearn comes with other strategies for [cross validation](http://scikit-learn.org/stable/modules/cross_validation.html#cross-validation), such as stratified K-fold, label k-fold, Leave-One-Out, Leave-P-Out, Leave-One-Label-Out, Leave-P-Label-Out or Shuffle & Split."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1. 0.8 1. 0.93333333 0.93333333 1. 1.\n",
" 1. 0.86666667 0.93333333]\n"
]
}
],
"source": [
"from sklearn.model_selection import cross_val_score, KFold\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# create a composite estimator made by a pipeline of preprocessing and the decision tree model\n",
"model = Pipeline([\n",
"    ('scaler', StandardScaler()),\n",
"    ('DecisionTree', DecisionTreeClassifier())\n",
"])\n",
"\n",
"# create a k-fold cross validation iterator of k=10 folds\n",
"\n",
"cv = KFold(n_splits=10, shuffle=True, random_state=33)\n",
"\n",
"# by default the score used is the one returned by score method of the estimator (accuracy)\n",
"scores = cross_val_score(model, x_iris, y_iris, cv=cv)\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"We get an array of k scores. We can calculate the mean and the standard error to obtain a final figure."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean score: 0.947 (+/- 0.022)\n"
]
}
],
"source": [
"from scipy.stats import sem\n",
"def mean_score(scores):\n",
" return (\"Mean score: {0:.3f} (+/- {1:.3f})\").format(np.mean(scores), sem(scores))\n",
"print(mean_score(scores))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, we get an average accuracy of 0.947."
]
},
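{
"cell_type": "markdown",
"metadata": {},
"source": [
"Of the other strategies mentioned above, stratified k-fold keeps the class proportions roughly equal in every fold. The following is a minimal, self-contained sketch rather than part of the original evaluation: it assumes scikit-learn 0.18 or later (where these utilities live in `sklearn.model_selection`), and the names `pipeline`, `stratified_cv` and `stratified_scores` as well as the `random_state` values are illustrative choices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn import datasets\n",
"from sklearn.model_selection import StratifiedKFold, cross_val_score\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"iris = datasets.load_iris()\n",
"\n",
"# Same kind of pipeline as above: scaling followed by a decision tree (random_state is an assumption)\n",
"pipeline = Pipeline([\n",
"    ('scaler', StandardScaler()),\n",
"    ('DecisionTree', DecisionTreeClassifier(random_state=1))\n",
"])\n",
"\n",
"# Stratified 10-fold iterator: each fold preserves the class proportions of the full dataset\n",
"stratified_cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=33)\n",
"\n",
"stratified_scores = cross_val_score(pipeline, iris.data, iris.target, cv=stratified_cv)\n",
"print(stratified_scores)\n",
"print('Mean accuracy: {0:.3f}'.format(stratified_scores.mean()))"
]
},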
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [Plot the decision surface of a decision tree on the iris dataset](http://scikit-learn.org/stable/auto_examples/tree/plot_iris.html)\n",
"* [Learning scikit-learn: Machine Learning in Python](http://proquest.safaribooksonline.com/book/programming/python/9781783281930/1dot-machine-learning-a-gentle-introduction/ch01s02_html), Raúl Garreta; Guillermo Moncecchi, Packt Publishing, 2013.\n",
"* [Python Machine Learning](http://proquest.safaribooksonline.com/book/programming/python/9781783555130), Sebastian Raschka, Packt Publishing, 2015.\n",
"* [Parameter estimation using grid search with cross-validation](http://scikit-learn.org/stable/auto_examples/model_selection/grid_search_digits.html)\n",
"* [Decision trees in python with scikit-learn and pandas](http://chrisstrelioff.ws/sandbox/2015/06/08/decision_trees_in_python_with_scikit_learn_and_pandas.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Licence\n",
"The notebook is freely licensed under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/). \n",
"\n",
"© 2016 Carlos A. Iglesias, Universidad Politécnica de Madrid."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}