diff --git a/ml1/util_knn.py b/ml1/util_knn.py index 1f71f94..ba1c382 100644 --- a/ml1/util_knn.py +++ b/ml1/util_knn.py @@ -2,6 +2,7 @@ import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn import neighbors, datasets +import seaborn as sns from sklearn.neighbors import KNeighborsClassifier # Taken from http://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html @@ -19,9 +20,9 @@ def plot_classification_iris(): h = .02 # step size in the mesh n_neighbors = 15 - # Create color maps - cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) - cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) + # Create color maps + cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue']) + cmap_bold = ['darkorange', 'c', 'darkblue'] for weights in ['uniform', 'distance']: # we create an instance of Neighbours Classifier and fit the data. @@ -29,7 +30,7 @@ def plot_classification_iris(): clf.fit(X, y) # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, m_max]x[y_min, y_max]. + # point in the mesh [x_min, x_max]x[y_min, y_max]. x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), @@ -38,14 +39,17 @@ def plot_classification_iris(): # Put the result into a color plot Z = Z.reshape(xx.shape) - plt.figure() - plt.pcolormesh(xx, yy, Z, cmap=cmap_light) + plt.figure(figsize=(8, 6)) + plt.contourf(xx, yy, Z, cmap=cmap_light) # Plot also the training points - plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold) + sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=iris.target_names[y], + palette=cmap_bold, alpha=1.0, edgecolor="black") plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.title("3-Class classification (k = %i, weights = '%s')" - % (n_neighbors, weights)) + % (n_neighbors, weights)) + plt.xlabel(iris.feature_names[0]) + plt.ylabel(iris.feature_names[1]) - plt.show() \ No newline at end of file +plt.show() \ No newline at end of file