mirror of https://github.com/gsi-upm/sitc synced 2024-11-05 15:31:42 +00:00
sitc/ml3/2_4_1_Exercise.ipynb


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![](images/EscUpmPolit_p.gif \"UPM\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Course Notes for Learning Intelligent Systems"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Department of Telematic Engineering Systems, Universidad Politécnica de Madrid, © 2018 Óscar Araque"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [Introduction to Machine Learning III](2_4_0_Intro_NN.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MultiLayer Perceptron (MLP) Introduction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Multilayer perceptrons, also called feedforward neural networks or deep feedforward networks, are the most basic deep learning models."
]
},
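{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this definition concrete, the sketch below implements the forward pass of an MLP with one hidden layer (ReLU activation) and a softmax output layer in NumPy. This is the same kind of network that scikit-learn's `MLPClassifier` trains with its default settings (one hidden layer of 100 ReLU units). The random initialization and the helper names (`relu`, `softmax`, `mlp_forward`) are illustrative assumptions, not code from this course.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def relu(z):\n",
"    # Element-wise rectified linear unit\n",
"    return np.maximum(0, z)\n",
"\n",
"def softmax(z):\n",
"    # Subtract the row-wise maximum for numerical stability\n",
"    z = z - z.max(axis=1, keepdims=True)\n",
"    e = np.exp(z)\n",
"    return e / e.sum(axis=1, keepdims=True)\n",
"\n",
"def mlp_forward(X, W1, b1, W2, b2):\n",
"    # Hidden layer: affine transformation followed by a non-linearity\n",
"    h = relu(X @ W1 + b1)\n",
"    # Output layer: one score per class, turned into class probabilities\n",
"    return softmax(h @ W2 + b2)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_features, n_hidden, n_classes = 2, 100, 5  # 2D points, 5 spiral arms\n",
"W1 = rng.normal(scale=0.1, size=(n_features, n_hidden))\n",
"b1 = np.zeros(n_hidden)\n",
"W2 = rng.normal(scale=0.1, size=(n_hidden, n_classes))\n",
"b2 = np.zeros(n_classes)\n",
"\n",
"X_demo = rng.normal(size=(4, n_features))  # four random 2D points\n",
"probs = mlp_forward(X_demo, W1, b1, W2, b2)\n",
"print(probs.shape)  # (4, 5): one probability per class for each point\n",
"```\n",
"\n",
"Training (for example with the Adam optimizer, the default solver of `MLPClassifier`) only adjusts `W1`, `b1`, `W2` and `b2`; the forward pass itself stays the same."
]
},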
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-15T12:33:49.116461",
"start_time": "2017-03-15T12:33:49.111870"
}
},
"source": [
"<img src=\"images/multilayerperceptron_network.png\" alt=\"Drawing\" style=\"width: 400px;\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we are going to apply different algorithms to the spiral dataset. In particular, we are going to focus our attention on the MLP classifier.\n",
"\n",
"Answer directly in your copy of the exercise and submit it as a Moodle task."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load dataset"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:16.146770",
"start_time": "2017-03-16T18:10:15.825583"
},
"collapsed": true
},
"outputs": [],
"source": [
"# Show plots in the notebooks\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:16.200490",
"start_time": "2017-03-16T18:10:16.149330"
}
},
"outputs": [],
"source": [
"# Load the utilities\n",
"from spiral import *"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:16.662881",
"start_time": "2017-03-16T18:10:16.203138"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"sns.set(color_codes=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:17.018804",
"start_time": "2017-03-16T18:10:16.665477"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of classes: 5\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f5f80361748>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# load and plot the spiral dataset\n",
"n_classes = 5\n",
"X, y = load_spiral_dataset(n_classes=n_classes)\n",
"\n",
"plt.figure(figsize=(10,7))\n",
"plot_dataset(X, y)\n",
"\n",
"print('Number of classes: {}'.format(n_classes))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:17.083617",
"start_time": "2017-03-16T18:10:17.021488"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1050, 2)\n",
"(450, 2)\n",
"(1050,)\n",
"(450,)\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# split the dataset in train and test\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)\n",
"\n",
"# check the dimensions\n",
"print(X_train.shape)\n",
"print(X_test.shape)\n",
"print(y_train.shape)\n",
"print(y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The features are simply the position of each point in the two-dimensional plane.\n",
"\n",
"In other words, a point $\\mathbf{x}$ is represented by its values $x_1$ and $x_2$:\n",
"\n",
"$\\mathbf{x} = [x_1, x_2] $"
]
},
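{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about how such a dataset is laid out, the sketch below builds a spiral dataset with NumPy, similar to the spiral example in the Stanford CS231n course notes. It is a hypothetical stand-in: the function name `make_spiral`, the noise level and the 300 points per class are assumptions (300 points for each of the 5 classes gives 1500 points, the same total as the 1050 + 450 split printed above), and the course's `load_spiral_dataset` may generate its data differently. Each row of `X_demo` is one point [x1, x2] and `y_demo` holds its class, i.e. its spiral arm.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def make_spiral(n_classes=5, n_points=300, noise=0.2, seed=0):\n",
"    # Hypothetical generator: one spiral arm per class\n",
"    rng = np.random.default_rng(seed)\n",
"    X = np.zeros((n_classes * n_points, 2))\n",
"    y = np.zeros(n_classes * n_points, dtype=int)\n",
"    for k in range(n_classes):\n",
"        idx = slice(k * n_points, (k + 1) * n_points)\n",
"        r = np.linspace(0.0, 1.0, n_points)  # radius grows along the arm\n",
"        # Angle: each arm covers a different range, plus Gaussian noise\n",
"        t = np.linspace(k * 4, (k + 1) * 4, n_points) + rng.normal(scale=noise, size=n_points)\n",
"        X[idx] = np.column_stack([r * np.sin(t), r * np.cos(t)])\n",
"        y[idx] = k\n",
"    return X, y\n",
"\n",
"X_demo, y_demo = make_spiral()\n",
"print(X_demo.shape, y_demo.shape)  # (1500, 2) (1500,)\n",
"print(X_demo[0])  # the first point: its two features x1 and x2\n",
"```"
]
},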
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Perform the classification task with several classifiers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below, the spiral dataset is classified with several classifiers. For each one, we can see the performance on each class (each spiral arm) and its decision surface."
]
},
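{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `plot_decision_surface` helper comes from the course's `spiral` module, whose code is not shown here. A common way to draw a decision surface, sketched below as an assumption about what such a helper does, is to evaluate the classifier on a dense grid covering the plane and colour each grid cell by its predicted class. The function name `decision_grid`, the grid step, the margin and the toy `QuadrantModel` are hypothetical; any fitted scikit-learn classifier provides the `predict` method used here.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def decision_grid(model, X, step=0.02, margin=0.5):\n",
"    # Bounding box of the data, enlarged by a small margin\n",
"    x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin\n",
"    y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin\n",
"    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),\n",
"                         np.arange(y_min, y_max, step))\n",
"    # Classify every grid point, then restore the grid shape\n",
"    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])\n",
"    return xx, yy, Z.reshape(xx.shape)\n",
"\n",
"class QuadrantModel:\n",
"    # Hypothetical toy classifier: the class is the quadrant of the point\n",
"    def predict(self, X):\n",
"        return (X[:, 0] > 0).astype(int) + 2 * (X[:, 1] > 0).astype(int)\n",
"\n",
"X_demo = np.array([[-1.0, -1.0], [1.0, 1.0]])\n",
"xx, yy, Z = decision_grid(QuadrantModel(), X_demo, step=0.1)\n",
"print(Z.shape == xx.shape)  # True\n",
"print(np.unique(Z))  # [0 1 2 3]: all four quadrants appear on the grid\n",
"```\n",
"\n",
"The grid can then be drawn with, for example, matplotlib's `plt.contourf(xx, yy, Z, alpha=0.3)` before scattering the data points on top."
]
},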
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Logistic Regression"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:18.028423",
"start_time": "2017-03-16T18:10:17.089748"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR\n",
" precision recall f1-score support\n",
"\n",
" 0 0.29 0.35 0.32 83\n",
" 1 0.31 0.28 0.29 90\n",
" 2 0.20 0.23 0.21 79\n",
" 3 0.34 0.29 0.31 109\n",
" 4 0.25 0.25 0.25 89\n",
"\n",
"avg / total 0.28 0.28 0.28 450\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f5f87e3a978>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"lr = LogisticRegression(n_jobs=-1)\n",
"lr.fit(X_train, y_train)\n",
"\n",
"lr_preds = lr.predict(X_test)\n",
"\n",
"print('LR')\n",
"print(classification_report(y_test, lr_preds))\n",
"\n",
"plt.figure(figsize=(10,7))\n",
"plot_decision_surface(X, y, lr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### k-NN"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:23.248268",
"start_time": "2017-03-16T18:10:18.031503"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"K-NN\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 83\n",
" 1 1.00 1.00 1.00 90\n",
" 2 1.00 1.00 1.00 79\n",
" 3 1.00 1.00 1.00 109\n",
" 4 1.00 1.00 1.00 89\n",
"\n",
"avg / total 1.00 1.00 1.00 450\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f5f4c074780>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"knn = KNeighborsClassifier()\n",
"knn.fit(X_train, y_train)\n",
"\n",
"knn_preds = knn.predict(X_test)\n",
"\n",
"print('K-NN')\n",
"print(classification_report(y_test, knn_preds))\n",
"\n",
"plt.figure(figsize=(10,7))\n",
"plot_decision_surface(X, y, knn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gaussian Naive Bayes"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:10:24.946571",
"start_time": "2017-03-16T18:10:23.251722"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GaussianNB\n",
" precision recall f1-score support\n",
"\n",
" 0 0.24 0.36 0.29 83\n",
" 1 0.22 0.16 0.18 90\n",
" 2 0.22 0.29 0.25 79\n",
" 3 0.33 0.27 0.30 109\n",
" 4 0.29 0.22 0.25 89\n",
"\n",
"avg / total 0.26 0.26 0.26 450\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f5f44666c88>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.naive_bayes import GaussianNB\n",
"\n",
"gnb = GaussianNB()\n",
"gnb.fit(X_train, y_train)\n",
"\n",
"gnb_preds = gnb.predict(X_test)\n",
"\n",
"print('GaussianNB')\n",
"print(classification_report(y_test, gnb_preds))\n",
"\n",
"plt.figure(figsize=(10,7))\n",
"plot_decision_surface(X, y, gnb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SVM"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:11:56.862050",
"start_time": "2017-03-16T18:10:24.949917"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"SVM\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 83\n",
" 1 1.00 1.00 1.00 90\n",
" 2 1.00 1.00 1.00 79\n",
" 3 1.00 1.00 1.00 109\n",
" 4 1.00 1.00 1.00 89\n",
"\n",
"avg / total 1.00 1.00 1.00 450\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f5f446048d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.svm import SVC\n",
"\n",
"svc = SVC()\n",
"svc.fit(X_train, y_train)\n",
"\n",
"svc_preds = svc.predict(X_test)\n",
"\n",
"print('SVM')\n",
"print(classification_report(y_test, svc_preds))\n",
"\n",
"plt.figure(figsize=(10,7))\n",
"plot_decision_surface(X, y, svc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MLP"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-16T18:12:01.663538",
"start_time": "2017-03-16T18:11:56.864322"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.5/dist-packages/sklearn/neural_network/multilayer_perceptron.py:563: ConvergenceWarning: Stochastic Optimizer: Maximum iterations reached and the optimization hasn't converged yet.\n",
" % (), ConvergenceWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MLP\n",
" precision recall f1-score support\n",
"\n",
" 0 0.28 0.42 0.34 83\n",
" 1 0.46 0.26 0.33 90\n",
" 2 0.32 0.49 0.39 79\n",
" 3 0.67 0.36 0.47 109\n",
" 4 0.34 0.37 0.36 89\n",
"\n",
"avg / total 0.43 0.38 0.38 450\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x7f5f44580c88>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"mlp = MLPClassifier()\n",
"mlp.fit(X_train, y_train)\n",
"\n",
"mlp_preds = mlp.predict(X_test)\n",
"\n",
"print('MLP')\n",
"print(classification_report(y_test, mlp_preds))\n",
"\n",
"plt.figure(figsize=(10,7))\n",
"plot_decision_surface(X, y, mlp)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"We see that some classifiers (kNN, SVM) successfully learn the spiral problem: they classify points correctly in every region of the plane.\n",
"\n",
"However, other classifiers (Logistic Regression, Gaussian Naive Bayes) are not able to learn the spiral pattern with their default configurations.\n",
"\n",
"In particular, the MLP performs very poorly: with its default configuration it does not learn the spiral function, even though it should be capable of doing so."
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"**QUESTION: Why do you think that MLP does not learn the spiral pattern?**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Answer here:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Try to make the MLP learn the spiral!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Your task is to make the MLP classifier learn the spiral.\n",
"\n",
"Write the necessary code in the following cells.\n",
"\n",
"Try changing the configuration of the MLPClassifier and its inputs. Some aspects you can change are:\n",
"- the complexity of the network\n",
"- the regularization of the network\n",
"- new features that are passed to the network\n",
"\n",
"You can find inspiration in [this playground](http://playground.tensorflow.org)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2017-03-27T11:56:47.325519",
"start_time": "2017-03-27T11:56:47.316384"
},
"collapsed": true
},
"outputs": [],
"source": [
"# write your code here"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* [MLP documentation](http://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Licence"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The notebook is freely licensed under the [Creative Commons Attribution Share-Alike license](https://creativecommons.org/licenses/by/2.0/). \n",
"\n",
"© 2018 Óscar Araque, Universidad Politécnica de Madrid."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.5"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
}
},
"nbformat": 4,
"nbformat_minor": 1
}