mirror of
https://github.com/gsi-upm/sitc
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](files/images/EscUpmPolit_p.gif \"UPM\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Course Notes for Learning Intelligent Systems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © 2016 Carlos A. Iglesias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Introduction to Machine Learning](2_0_0_Intro_ML.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Table of Contents\n",
"* [Decision Tree Learning](#Decision-Tree-Learning)\n",
"* [Load data and preprocessing](#Load-data-and-preprocessing)\n",
"* [Train classifier](#Train-classifier)\n",
"* [Evaluating the algorithm](#Evaluating-the-algorithm)\n",
"\t* [Precision, recall and f-score](#Precision,-recall-and-f-score)\n",
"\t* [Confusion matrix](#Confusion-matrix)\n",
"\t* [K-Fold cross validation](#K-Fold-cross-validation)\n",
"* [References](#References)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Decision Tree Learning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
|
||
|
"The goal of this notebook is to learn how to learn how create a classification object using a [decision tree learning algorithm](https://en.wikipedia.org/wiki/Decision_tree_learning). \n",
|
||
|
"\n",
|
||
|
"There are a number of well known machine learning algorithms for decision tree learning, such as ID3, C4.5, C5.0 and CART. The scikit-learn uses an optimised version of the [CART (Classification and Regression Trees) algorithm](https://en.wikipedia.org/wiki/Predictive_analytics#Classification_and_regression_trees).\n",
|
||
|
"\n",
|
||
|
"This notebook will follow the same steps that the previous notebook for learning using the [kNN Model](2_5_1_kNN_Model.ipynb), and details some pecualiarities of the decision tree algorithms."
|
||
|
]
|
||
|
},
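{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a brief reminder of how CART builds the tree: at each node it evaluates candidate thresholds on single features and keeps the split that most reduces an impurity measure. With the default *gini* criterion used below, the impurity of a node in which a fraction $p_k$ of the samples belongs to class $k$ is\n",
"\n",
"$$G = 1 - \\sum_k p_k^2,$$\n",
"\n",
"which is 0 for a pure node and grows as the classes become more mixed."
]
},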
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load data and preprocessing\n",
"\n",
"Here we repeat the same operations for loading data and preprocessing as in the previous notebooks."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# library for displaying plots\n",
"import matplotlib.pyplot as plt\n",
"# display plots in the notebook \n",
"%matplotlib inline\n",
"\n",
"## First, we repeat the load and preprocessing steps\n",
"\n",
"# Load data\n",
"from sklearn import datasets\n",
"iris = datasets.load_iris()\n",
"\n",
"# Training and test splitting\n",
"# (note: sklearn.cross_validation was replaced by sklearn.model_selection in scikit-learn 0.18+)\n",
"from sklearn.cross_validation import train_test_split\n",
"\n",
"x_iris, y_iris = iris.data, iris.target\n",
"# Test set will be the 25% taken randomly\n",
"x_train, x_test, y_train, y_test = train_test_split(x_iris, y_iris, test_size=0.25, random_state=33)\n",
"\n",
"# Preprocess: normalize\n",
"from sklearn import preprocessing\n",
"scaler = preprocessing.StandardScaler().fit(x_train)\n",
"x_train = scaler.transform(x_train)\n",
"x_test = scaler.transform(x_test)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## Train classifier"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The usual steps for creating a classifier are:\n",
"1. Create classifier object\n",
"2. Call *fit* to train the classifier\n",
"3. Call *predict* to obtain predictions\n",
"\n",
"*DecisionTreeClassifier* is capable of both binary (where the labels are [-1, 1]) classification and multiclass (where the labels are [0, ..., K-1]) classification."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,\n",
" max_features=None, max_leaf_nodes=None, min_samples_leaf=1,\n",
" min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
" presort=False, random_state=1, splitter='best')"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.tree import DecisionTreeClassifier\n",
"import numpy as np\n",
"\n",
"from sklearn import tree\n",
"\n",
"max_depth=3\n",
"random_state=1\n",
"\n",
"# Create decision tree model\n",
"model = tree.DecisionTreeClassifier(max_depth=max_depth, random_state=random_state)\n",
"\n",
"# Train the model using the training sets\n",
"model.fit(x_train, y_train) "
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 2 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 2 1 1 0 1 1 0 2 1 2 1 2 0 1 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n",
"Expected [1 0 1 1 1 0 0 1 0 2 0 0 1 2 0 1 2 2 1 1 0 0 2 0 0 2 1 1 2 2 2 2 0 0 1 1 0\n",
" 1 2 1 2 0 2 0 1 0 2 1 0 2 2 0 0 2 0 0 0 2 2 0 1 0 1 0 1 1 1 1 1 0 1 0 1 2\n",
" 0 0 0 0 2 2 0 1 1 2 1 0 0 1 1 1 0 1 1 0 2 2 2 1 2 0 1 0 0 0 2 1 2 1 2 1 2\n",
" 0]\n"
]
}
],
"source": [
"print(\"Prediction \", model.predict(x_train))\n",
"print(\"Expected \", y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, the probability of each class can be predicted, which is the fraction of training samples of each class in the leaf the sample falls into. For example, the value 0.97368421 in the output below corresponds to a leaf that contains 37 samples of that class out of 38 (37/38 ≈ 0.974):"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted probabilities [[ 0. 0.97368421 0.02631579]\n",
" [ 1. 0. 0. ]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 1. 0. 0. ]\n",
" [ 1. 0. 0. ]\n",
" [ 0. 0.97368421 0.02631579]\n",
" [ 1. 0. 0. ]\n",
" [ 0. 0. 1. ]]\n"
]
}
],
"source": [
"# Print the predicted class probabilities for the first 10 training samples\n",
"print(\"Predicted probabilities\", model.predict_proba(x_train[:10]))"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in training 0.982142857143\n"
]
}
],
"source": [
"# Evaluate Accuracy in training\n",
"\n",
"from sklearn import metrics\n",
"y_train_pred = model.predict(x_train)\n",
"print(\"Accuracy in training\", metrics.accuracy_score(y_train, y_train_pred))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy in testing 0.921052631579\n"
]
}
],
"source": [
"# Now we evaluate accuracy in testing\n",
"y_test_pred = model.predict(x_test)\n",
"print(\"Accuracy in testing \", metrics.accuracy_score(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we are going to visualize the fitted decision tree. The diagram shows the splitting rule applied at each node of the tree.\n",
"\n",
"The current version of pydot does not work well in Python 3.\n",
"To obtain the image, you need to install `pydotplus` (`pip install pydotplus`) and `graphviz` (e.g. `conda install graphviz`).\n",
"\n",
"You can skip this example, since it requires installing additional packages; we include the resulting image here.\n",
"![Decision Tree](files/images/cart.png)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Warning: Not built with libexpat. Table formatting is not available.\n",
"in label of node 0\n",
"in label of node 1\n",
"in label of node 2\n",
"in label of node 3\n",
"in label of node 4\n",
"in label of node 5\n",
"in label of node 6\n",
"in label of node 7\n",
"in label of node 8\n",
"\n",
"Warning: Not built with libexpat. Table formatting is not available.\n",
"in label of node 0\n",
"in label of node 1\n",
"in label of node 2\n",
"in label of node 3\n",
"in label of node 4\n",
"in label of node 5\n",
"in label of node 6\n",
"in label of node 7\n",
"in label of node 8\n",
"\n"
]
},
{
"data": {
"text/plain": [
|
||
|
"<IPython.core.display.Image object>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 35,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"from IPython.display import Image \n",
|
||
|
"from sklearn.externals.six import StringIO\n",
|
||
|
"import pydotplus as pydot\n",
|
||
|
"\n",
|
||
|
"dot_data = StringIO() \n",
|
||
|
"tree.export_graphviz(model, out_file=dot_data, \n",
|
||
|
" feature_names=iris.feature_names, \n",
|
||
|
" class_names=iris.target_names, \n",
|
||
|
" filled=True, rounded=True, \n",
|
||
|
" special_characters=True) \n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"graph = pydot.graph_from_dot_data(dot_data.getvalue()) \n",
|
||
|
"graph.write_png('iris-tree.png')\n",
|
||
|
"Image(graph.create_png()) "
|
||
|
]
|
||
|
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we show a graph of the decision tree boundaries. For each pair of iris features, the decision tree learns decision boundaries made of combinations of simple thresholding rules inferred from the training samples.\n",
"\n",
"We are going to import a function defined in the file [util_ds.py](files/util_ds.py) using the *magic command* **%run**."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
|
||
|
"<matplotlib.figure.Figure at 0x7f4823af2048>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"%run util_ds\n",
|
||
|
"\n",
|
||
|
"# display plots in the notebook \n",
|
||
|
"%matplotlib inline\n",
|
||
|
"plot_tree_iris()"
|
||
|
]
|
||
|
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we are going to export the pseudocode of the learnt decision tree."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"if ( petal width (cm) <= -0.415243178606 ) {\n",
" return setosa ( 42 examples )\n",
"}\n",
"else {\n",
" if ( petal width (cm) <= 0.861689627171 ) {\n",
" if ( petal length (cm) <= 0.818572163582 ) {\n",
" return versicolor ( 37 examples )\n",
" return virginica ( 1 examples )\n",
" }\n",
" else {\n",
" return versicolor ( 1 examples )\n",
" return virginica ( 2 examples )\n",
" }\n",
" }\n",
" else {\n",
" if ( petal length (cm) <= 0.707377433777 ) {\n",
" return versicolor ( 1 examples )\n",
" }\n",
" else {\n",
" return virginica ( 28 examples )\n",
" }\n",
" }\n",
"}\n"
]
}
],
"source": [
"%run util_ds\n",
"get_code(model, iris.feature_names, iris.target_names)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also obtain the feature importance of the fitted model as follows."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n",
"[ 0. 0. 0.05947455 0.94052545]\n"
]
}
],
"source": [
"print(iris.feature_names)\n",
"print(model.feature_importances_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that the most important feature for this classifier is `petal width`."
]
},
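{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch, assuming the fitted `model` and the matplotlib import from the cells above, we can also plot the feature importances as a bar chart to make them easier to compare:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Plot the feature importances of the fitted decision tree as a horizontal bar chart\n",
"importances = model.feature_importances_\n",
"positions = np.arange(len(iris.feature_names))\n",
"plt.barh(positions, importances)\n",
"plt.yticks(positions, iris.feature_names)\n",
"plt.xlabel('importance')\n",
"plt.title('Decision tree feature importances')"
]
},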
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluating the algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Precision, recall and f-score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For evaluating classification algorithms, we usually calculate three metrics: precision, recall and F1-score.\n",
"\n",
"* **Precision**: the proportion of instances predicted as positive that really are positive (it measures how right our classifier is when it says that an instance is positive).\n",
"* **Recall**: the proportion of positive instances that are correctly detected as positive (it measures how right our classifier is when faced with a positive instance).\n",
"* **F1-score**: the harmonic mean of precision and recall, which combines both in a single number.\n",
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(metrics.classification_report(y_test, y_test_pred, target_names=iris.target_names))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another useful metric is the confusion matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(metrics.confusion_matrix(y_test, y_test_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that we classify all the 'setosa' and 'versicolor' samples well."
]
},
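{
"cell_type": "markdown",
"metadata": {},
"source": [
"A heat map often makes the confusion matrix easier to read. The following cell is a minimal sketch that reuses the test predictions computed above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from sklearn import metrics\n",
"\n",
"# Display the confusion matrix as a heat map: rows are true classes, columns are predicted classes\n",
"cm = metrics.confusion_matrix(y_test, y_test_pred)\n",
"plt.matshow(cm, cmap=plt.cm.Blues)\n",
"plt.colorbar()\n",
"plt.xticks(range(len(iris.target_names)), iris.target_names)\n",
"plt.yticks(range(len(iris.target_names)), iris.target_names)\n",
"plt.xlabel('Predicted label')\n",
"plt.ylabel('True label')"
]
},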
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### K-Fold cross validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to avoid bias in the training and testing dataset partition, it is recommended to use **k-fold cross-validation**.\n",
"\n",
"scikit-learn comes with other strategies for [cross validation](http://scikit-learn.org/stable/modules/cross_validation.html#cross-validation), such as stratified K-fold, label k-fold, Leave-One-Out, Leave-P-Out, Leave-One-Label-Out, Leave-P-Label-Out or Shuffle & Split."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 1. 0.8 1. 0.93333333 0.93333333 1. 1.\n",
" 1. 0.86666667 0.93333333]\n"
]
}
],
"source": [
"from sklearn.cross_validation import cross_val_score, KFold\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# create a composite estimator made by a pipeline of preprocessing and the decision tree model\n",
"model = Pipeline([\n",
" ('scaler', StandardScaler()),\n",
" ('DecisionTree', DecisionTreeClassifier())\n",
"])\n",
"\n",
"# create a k-fold cross validation iterator of k=10 folds\n",
"\n",
"cv = KFold(x_iris.shape[0], 10, shuffle=True, random_state=33)\n",
"\n",
"# by default the score used is the one returned by score method of the estimator (accuracy)\n",
"scores = cross_val_score(model, x_iris, y_iris, cv=cv)\n",
"print(scores)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"We get an array of k scores. We can calculate the mean and the standard error to obtain a final figure."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean score: 0.947 (+/- 0.022)\n"
]
}
],
"source": [
"from scipy.stats import sem\n",
"def mean_score(scores):\n",
" return (\"Mean score: {0:.3f} (+/- {1:.3f})\").format(np.mean(scores), sem(scores))\n",
"print(mean_score(scores))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, we get an average accuracy of 0.947."
]
},
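{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the classes in the iris dataset are balanced, plain K-fold already works well here. With imbalanced classes, a stratified split is usually preferable, so that every fold keeps roughly the same class proportions as the whole dataset. The cell below is a minimal sketch that uses the same (now deprecated) `sklearn.cross_validation` API as the rest of this notebook; in recent scikit-learn versions the equivalent class lives in `sklearn.model_selection` with a slightly different interface."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from sklearn.cross_validation import StratifiedKFold, cross_val_score\n",
"\n",
"# Stratified 10-fold: each fold keeps roughly the same class proportions as the full dataset\n",
"cv = StratifiedKFold(y_iris, 10, shuffle=True, random_state=33)\n",
"scores = cross_val_score(model, x_iris, y_iris, cv=cv)\n",
"print(mean_score(scores))"
]
},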
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [Plot the decision surface of a decision tree on the iris dataset](http://scikit-learn.org/stable/auto_examples/tree/plot_iris.html)\n",
"* [Learning scikit-learn: Machine Learning in Python](http://proquest.safaribooksonline.com/book/programming/python/9781783281930/1dot-machine-learning-a-gentle-introduction/ch01s02_html), Raúl Garreta; Guillermo Moncecchi, Packt Publishing, 2013.\n",
"* [Python Machine Learning](http://proquest.safaribooksonline.com/book/programming/python/9781783555130), Sebastian Raschka, Packt Publishing, 2015.\n",
"* [Parameter estimation using grid search with cross-validation](http://scikit-learn.org/stable/auto_examples/model_selection/grid_search_digits.html)\n",
"* [Decision trees in python with scikit-learn and pandas](http://chrisstrelioff.ws/sandbox/2015/06/08/decision_trees_in_python_with_scikit_learn_and_pandas.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Licence\n",
"The notebook is freely licensed under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/). \n",
"\n",
"© 2016 Carlos A. Iglesias, Universidad Politécnica de Madrid."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}