1
0
mirror of https://github.com/gsi-upm/sitc synced 2024-11-24 15:32:29 +00:00

updated ml1/2_6: using scorer to avoid traning warnings

This commit is contained in:
Oscar Araque 2021-03-11 16:28:14 +01:00
parent 44aa3d24fb
commit 3d6d96dd8a

View File

@ -416,7 +416,7 @@
"source": [ "source": [
"# Set the parameters by cross-validation\n", "# Set the parameters by cross-validation\n",
"\n", "\n",
"from sklearn.metrics import classification_report\n", "from sklearn.metrics import classification_report, recall_score, precision_score, make_scorer\n",
"\n", "\n",
"# set of parameters to test\n", "# set of parameters to test\n",
"tuned_parameters = [{'max_depth': np.arange(3, 10),\n", "tuned_parameters = [{'max_depth': np.arange(3, 10),\n",
@ -434,8 +434,13 @@
" print(\"# Tuning hyper-parameters for %s\" % score)\n", " print(\"# Tuning hyper-parameters for %s\" % score)\n",
" print()\n", " print()\n",
"\n", "\n",
" if score == 'precision':\n",
" scorer = make_scorer(precision_score, average='weighted', zero_division=0)\n",
" elif score == 'recall':\n",
" scorer = make_scorer(recall_score, average='weighted', zero_division=0)\n",
" \n",
" # cv = the fold of the cross-validation cv, defaulted to 5\n", " # cv = the fold of the cross-validation cv, defaulted to 5\n",
" gs = GridSearchCV(DecisionTreeClassifier(), tuned_parameters, cv=10, scoring='%s_weighted' % score)\n", " gs = GridSearchCV(DecisionTreeClassifier(), tuned_parameters, cv=10, scoring=scorer)\n",
" gs.fit(x_train, y_train)\n", " gs.fit(x_train, y_train)\n",
"\n", "\n",
" print(\"Best parameters set found on development set:\")\n", " print(\"Best parameters set found on development set:\")\n",
@ -552,7 +557,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.7" "version": "3.8.6"
}, },
"latex_envs": { "latex_envs": {
"LaTeX_envs_menu_present": true, "LaTeX_envs_menu_present": true,