mirror of
https://github.com/gsi-upm/sitc
synced 2025-12-15 09:38:16 +00:00
Not done reviewing ml2 yet
This commit is contained in:
109
ml2/plot_learning_curve.py
Normal file
109
ml2/plot_learning_curve.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Taken from http://scikit-learn.org/stable/auto_examples/model_selection/plot_learning_curve.html
|
||||
|
||||
========================
|
||||
Plotting Learning Curves
|
||||
========================
|
||||
|
||||
On the left side the learning curve of a naive Bayes classifier is shown for
|
||||
the digits dataset. Note that the training score and the cross-validation score
|
||||
are both not very good at the end. However, the shape of the curve can be found
|
||||
in more complex datasets very often: the training score is very high at the
|
||||
beginning and decreases and the cross-validation score is very low at the
|
||||
beginning and increases. On the right side we see the learning curve of an SVM
|
||||
with RBF kernel. We can see clearly that the training score is still around
|
||||
the maximum and the validation score could be increased with more training
|
||||
samples.
|
||||
"""
|
||||
#print(__doc__)
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn import cross_validation
|
||||
from sklearn.naive_bayes import GaussianNB
|
||||
from sklearn.svm import SVC
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.learning_curve import learning_curve
|
||||
|
||||
|
||||
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,
|
||||
n_jobs=1, train_sizes=np.linspace(.1, 1.0, 5)):
|
||||
"""
|
||||
Generate a simple plot of the test and traning learning curve.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimator : object type that implements the "fit" and "predict" methods
|
||||
An object of that type which is cloned for each validation.
|
||||
|
||||
title : string
|
||||
Title for the chart.
|
||||
|
||||
X : array-like, shape (n_samples, n_features)
|
||||
Training vector, where n_samples is the number of samples and
|
||||
n_features is the number of features.
|
||||
|
||||
y : array-like, shape (n_samples) or (n_samples, n_features), optional
|
||||
Target relative to X for classification or regression;
|
||||
None for unsupervised learning.
|
||||
|
||||
ylim : tuple, shape (ymin, ymax), optional
|
||||
Defines minimum and maximum yvalues plotted.
|
||||
|
||||
cv : integer, cross-validation generator, optional
|
||||
If an integer is passed, it is the number of folds (defaults to 3).
|
||||
Specific cross-validation objects can be passed, see
|
||||
sklearn.cross_validation module for the list of possible objects
|
||||
|
||||
n_jobs : integer, optional
|
||||
Number of jobs to run in parallel (default 1).
|
||||
"""
|
||||
plt.figure()
|
||||
plt.title(title)
|
||||
if ylim is not None:
|
||||
plt.ylim(*ylim)
|
||||
plt.xlabel("Training examples")
|
||||
plt.ylabel("Score")
|
||||
train_sizes, train_scores, test_scores = learning_curve(
|
||||
estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes)
|
||||
train_scores_mean = np.mean(train_scores, axis=1)
|
||||
train_scores_std = np.std(train_scores, axis=1)
|
||||
test_scores_mean = np.mean(test_scores, axis=1)
|
||||
test_scores_std = np.std(test_scores, axis=1)
|
||||
plt.grid()
|
||||
|
||||
plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
|
||||
train_scores_mean + train_scores_std, alpha=0.1,
|
||||
color="r")
|
||||
plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
|
||||
test_scores_mean + test_scores_std, alpha=0.1, color="g")
|
||||
plt.plot(train_sizes, train_scores_mean, 'o-', color="r",
|
||||
label="Training score")
|
||||
plt.plot(train_sizes, test_scores_mean, 'o-', color="g",
|
||||
label="Cross-validation score")
|
||||
|
||||
plt.legend(loc="best")
|
||||
return plt
|
||||
|
||||
|
||||
#digits = load_digits()
|
||||
#X, y = digits.data, digits.target
|
||||
|
||||
|
||||
#title = "Learning Curves (Naive Bayes)"
|
||||
# Cross validation with 100 iterations to get smoother mean test and train
|
||||
# score curves, each time with 20% data randomly selected as a validation set.
|
||||
#cv = cross_validation.ShuffleSplit(digits.data.shape[0], n_iter=100,
|
||||
# test_size=0.2, random_state=0)
|
||||
|
||||
#estimator = GaussianNB()
|
||||
#plot_learning_curve(estimator, title, X, y, ylim=(0.7, 1.01), cv=cv, n_jobs=4)
|
||||
|
||||
#title = "Learning Curves (SVM, RBF kernel, $\gamma=0.001$)"
|
||||
# SVC is more expensive so we do a lower number of CV iterations:
|
||||
#cv = cross_validation.ShuffleSplit(digits.data.shape[0], n_iter=10,
|
||||
# test_size=0.2, random_state=0)
|
||||
#estimator = SVC(gamma=0.001)
|
||||
#plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)
|
||||
|
||||
#plt.show()
|
||||
Reference in New Issue
Block a user