1
0
mirror of https://github.com/gsi-upm/sitc synced 2026-03-01 09:18:16 +00:00

Add files via upload

Updated visualize decision tree functions
This commit is contained in:
Carlos A. Iglesias
2026-02-25 17:19:08 +01:00
committed by GitHub
parent 8e177963af
commit 5d01f26e72

View File

@@ -53,10 +53,10 @@ import matplotlib.pyplot as plt
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.inspection import DecisionBoundaryDisplay
def plot_tree_iris(): def plot_tree_iris():
""" """
Taken fromhttp://scikit-learn.org/stable/auto_examples/tree/plot_iris.html Taken fromhttp://scikit-learn.org/stable/auto_examples/tree/plot_iris.html
""" """
# Parameters # Parameters
@@ -67,11 +67,11 @@ def plot_tree_iris():
# Load data # Load data
iris = load_iris() iris = load_iris()
for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
[1, 2], [1, 3], [2, 3]]):
# We only take the two corresponding features # We only take the two corresponding features
X = iris.data[:, pair] X = iris.data[:, pair]
y = iris.target y = iris.target
'''
# Shuffle # Shuffle
idx = np.arange(X.shape[0]) idx = np.arange(X.shape[0])
@@ -84,34 +84,38 @@ def plot_tree_iris():
mean = X.mean(axis=0) mean = X.mean(axis=0)
std = X.std(axis=0) std = X.std(axis=0)
X = (X - mean) / std X = (X - mean) / std
'''
# Train # Train
model = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X, y) clf = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X, y)
# Plot the decision boundary # Plot the decision boundary
plt.subplot(2, 3, pairidx + 1) # Taken from https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html
# Plot the decision boundary
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 ax = plt.subplot(2, 3, pairidx + 1)
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), DecisionBoundaryDisplay.from_estimator(
np.arange(y_min, y_max, plot_step)) clf,
X,
Z = model.predict(np.c_[xx.ravel(), yy.ravel()]) cmap=plt.cm.RdYlBu,
Z = Z.reshape(xx.shape) response_method="predict",
cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) ax=ax,
xlabel=iris.feature_names[pair[0]],
plt.xlabel(iris.feature_names[pair[0]]) ylabel=iris.feature_names[pair[1]],
plt.ylabel(iris.feature_names[pair[1]]) )
plt.axis("tight")
# Plot the training points # Plot the training points
for i, color in zip(range(n_classes), plot_colors): for i, color in zip(range(n_classes), plot_colors):
idx = np.where(y == i) idx = np.asarray(y == i).nonzero()
plt.scatter(X[idx, 0], X[idx, 1], c=color, label=iris.target_names[i], plt.scatter(
cmap=plt.cm.Paired) X[idx, 0],
X[idx, 1],
plt.axis("tight") c=color,
label=iris.target_names[i],
edgecolor="black",
s=15
)
plt.axis("tight")
plt.suptitle("Decision surface of a decision tree using paired features") plt.suptitle("Decision surface of a decision tree using paired features")
plt.legend() #plt.legend()
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
plt.show() plt.show()