From 3d6d96dd8ae76fea8724f60630190ac56d50901a Mon Sep 17 00:00:00 2001 From: Oscar Araque Date: Thu, 11 Mar 2021 16:28:14 +0100 Subject: [PATCH] updated ml1/2_6: using scorer to avoid traning warnings --- ml1/2_6_Model_Tuning.ipynb | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ml1/2_6_Model_Tuning.ipynb b/ml1/2_6_Model_Tuning.ipynb index f115de0..228d7c8 100644 --- a/ml1/2_6_Model_Tuning.ipynb +++ b/ml1/2_6_Model_Tuning.ipynb @@ -416,7 +416,7 @@ "source": [ "# Set the parameters by cross-validation\n", "\n", - "from sklearn.metrics import classification_report\n", + "from sklearn.metrics import classification_report, recall_score, precision_score, make_scorer\n", "\n", "# set of parameters to test\n", "tuned_parameters = [{'max_depth': np.arange(3, 10),\n", @@ -434,8 +434,13 @@ " print(\"# Tuning hyper-parameters for %s\" % score)\n", " print()\n", "\n", + " if score == 'precision':\n", + " scorer = make_scorer(precision_score, average='weighted', zero_division=0)\n", + " elif score == 'recall':\n", + " scorer = make_scorer(recall_score, average='weighted', zero_division=0)\n", + " \n", " # cv = the fold of the cross-validation cv, defaulted to 5\n", - " gs = GridSearchCV(DecisionTreeClassifier(), tuned_parameters, cv=10, scoring='%s_weighted' % score)\n", + " gs = GridSearchCV(DecisionTreeClassifier(), tuned_parameters, cv=10, scoring=scorer)\n", " gs.fit(x_train, y_train)\n", "\n", " print(\"Best parameters set found on development set:\")\n", @@ -552,7 +557,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.8.6" }, "latex_envs": { "LaTeX_envs_menu_present": true,