1
0
mirror of https://github.com/gsi-upm/sitc synced 2024-11-24 15:32:29 +00:00

Updated util_knn.py to new version of scikit

This commit is contained in:
cif2cif 2021-02-27 20:11:17 +01:00
parent 5144b7f228
commit 2f7cbe9e45

View File

@ -2,6 +2,7 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets from sklearn import neighbors, datasets
import seaborn as sns
from sklearn.neighbors import KNeighborsClassifier from sklearn.neighbors import KNeighborsClassifier
# Taken from http://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html # Taken from http://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
@ -20,8 +21,8 @@ def plot_classification_iris():
n_neighbors = 15 n_neighbors = 15
# Create color maps # Create color maps
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) cmap_bold = ['darkorange', 'c', 'darkblue']
for weights in ['uniform', 'distance']: for weights in ['uniform', 'distance']:
# we create an instance of Neighbours Classifier and fit the data. # we create an instance of Neighbours Classifier and fit the data.
@ -29,7 +30,7 @@ def plot_classification_iris():
clf.fit(X, y) clf.fit(X, y)
# Plot the decision boundary. For that, we will assign a color to each # Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, m_max]x[y_min, y_max]. # point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
@ -38,14 +39,17 @@ def plot_classification_iris():
# Put the result into a color plot # Put the result into a color plot
Z = Z.reshape(xx.shape) Z = Z.reshape(xx.shape)
plt.figure() plt.figure(figsize=(8, 6))
plt.pcolormesh(xx, yy, Z, cmap=cmap_light) plt.contourf(xx, yy, Z, cmap=cmap_light)
# Plot also the training points # Plot also the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold) sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=iris.target_names[y],
palette=cmap_bold, alpha=1.0, edgecolor="black")
plt.xlim(xx.min(), xx.max()) plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max()) plt.ylim(yy.min(), yy.max())
plt.title("3-Class classification (k = %i, weights = '%s')" plt.title("3-Class classification (k = %i, weights = '%s')"
% (n_neighbors, weights)) % (n_neighbors, weights))
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.show() plt.show()