mirror of
https://github.com/gsi-upm/sitc
synced 2026-03-01 09:18:16 +00:00
Add files via upload
Updated visualize decision tree functions
This commit is contained in:
committed by
GitHub
parent
8e177963af
commit
5d01f26e72
@@ -53,10 +53,10 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.tree import DecisionTreeClassifier
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
|
from sklearn.inspection import DecisionBoundaryDisplay
|
||||||
|
|
||||||
def plot_tree_iris():
|
def plot_tree_iris():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Taken fromhttp://scikit-learn.org/stable/auto_examples/tree/plot_iris.html
|
Taken fromhttp://scikit-learn.org/stable/auto_examples/tree/plot_iris.html
|
||||||
"""
|
"""
|
||||||
# Parameters
|
# Parameters
|
||||||
@@ -67,11 +67,11 @@ def plot_tree_iris():
|
|||||||
# Load data
|
# Load data
|
||||||
iris = load_iris()
|
iris = load_iris()
|
||||||
|
|
||||||
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],
|
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
|
||||||
[1, 2], [1, 3], [2, 3]]):
|
|
||||||
# We only take the two corresponding features
|
# We only take the two corresponding features
|
||||||
X = iris.data[:, pair]
|
X = iris.data[:, pair]
|
||||||
y = iris.target
|
y = iris.target
|
||||||
|
'''
|
||||||
|
|
||||||
# Shuffle
|
# Shuffle
|
||||||
idx = np.arange(X.shape[0])
|
idx = np.arange(X.shape[0])
|
||||||
@@ -84,34 +84,38 @@ def plot_tree_iris():
|
|||||||
mean = X.mean(axis=0)
|
mean = X.mean(axis=0)
|
||||||
std = X.std(axis=0)
|
std = X.std(axis=0)
|
||||||
X = (X - mean) / std
|
X = (X - mean) / std
|
||||||
|
'''
|
||||||
# Train
|
# Train
|
||||||
model = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X, y)
|
clf = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X, y)
|
||||||
|
|
||||||
# Plot the decision boundary
|
# Plot the decision boundary
|
||||||
plt.subplot(2, 3, pairidx + 1)
|
# Taken from https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html
|
||||||
|
# Plot the decision boundary
|
||||||
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
|
ax = plt.subplot(2, 3, pairidx + 1)
|
||||||
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
|
plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
|
||||||
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
|
DecisionBoundaryDisplay.from_estimator(
|
||||||
np.arange(y_min, y_max, plot_step))
|
clf,
|
||||||
|
X,
|
||||||
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
|
cmap=plt.cm.RdYlBu,
|
||||||
Z = Z.reshape(xx.shape)
|
response_method="predict",
|
||||||
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)
|
ax=ax,
|
||||||
|
xlabel=iris.feature_names[pair[0]],
|
||||||
plt.xlabel(iris.feature_names[pair[0]])
|
ylabel=iris.feature_names[pair[1]],
|
||||||
plt.ylabel(iris.feature_names[pair[1]])
|
)
|
||||||
plt.axis("tight")
|
|
||||||
|
|
||||||
# Plot the training points
|
# Plot the training points
|
||||||
for i, color in zip(range(n_classes), plot_colors):
|
for i, color in zip(range(n_classes), plot_colors):
|
||||||
idx = np.where(y == i)
|
idx = np.asarray(y == i).nonzero()
|
||||||
plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i],
|
plt.scatter(
|
||||||
cmap=plt.cm.Paired)
|
X[idx, 0],
|
||||||
|
X[idx, 1],
|
||||||
plt.axis("tight")
|
c=color,
|
||||||
|
label=iris.target_names[i],
|
||||||
|
edgecolor="black",
|
||||||
|
s=15
|
||||||
|
)
|
||||||
|
plt.axis("tight")
|
||||||
plt.suptitle("Decision surface of a decision tree using paired features")
|
plt.suptitle("Decision surface of a decision tree using paired features")
|
||||||
plt.legend()
|
#plt.legend()
|
||||||
|
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user