1
0
mirror of https://github.com/gsi-upm/sitc synced 2026-03-01 09:18:16 +00:00
Files
sitc/ml1/util_ds.py
Carlos A. Iglesias 5d01f26e72 Add files via upload
Updated visualize decision tree functions
2026-02-25 17:19:08 +01:00

122 lines
3.9 KiB
Python

import numpy as np
# Taken from http://chrisstrelioff.ws/sandbox/2015/06/25/decision_trees_in_python_again_cross_validation.html
def get_code(tree, feature_names, target_names,
             spacer_base=" "):
    """Print C-like pseudo-code for a fitted decision tree.

    Args
    ----
    tree -- fitted scikit-learn DecisionTreeClassifier; only its ``tree_``
        attribute is read (``children_left``, ``children_right``,
        ``threshold``, ``feature``, ``value``, ``weighted_n_node_samples``).
    feature_names -- list of feature names, indexed by feature id.
    target_names -- list of target (class) names, indexed by class id.
    spacer_base -- string repeated once per depth level to indent the
        output (default: " ").

    Returns
    -------
    None. Each split node is printed as an ``if ( <feature> <= <threshold> ) {``
    / ``else {`` pair, and each leaf as one
    ``return <class> ( <n> examples )`` line per class with a non-zero value.

    Notes
    -----
    Based on http://stackoverflow.com/a/30104792.
    """
    structure = tree.tree_
    left = structure.children_left
    right = structure.children_right
    threshold = structure.threshold
    feature = structure.feature
    value = structure.value
    weighted_n = structure.weighted_n_node_samples

    def leaf_counts(node):
        """Return (class ids, example counts) for the non-zero classes of a leaf."""
        target = np.asarray(value[node], dtype=float)
        # scikit-learn < 1.4 stores weighted class counts in ``tree_.value``;
        # 1.4+ stores per-node class fractions instead. Rescaling each output
        # row to sum to the node's weighted sample count yields counts under
        # both conventions (``int(v)`` on fractions would print 0 or 1).
        totals = target.sum(axis=-1, keepdims=True)
        counts = np.divide(target, totals, out=np.zeros_like(target),
                           where=totals > 0) * weighted_n[node]
        rows, cols = np.nonzero(target)
        return cols, counts[rows, cols]

    def recurse(node, depth):
        """Print the subtree rooted at ``node``, indented ``depth`` levels."""
        spacer = spacer_base * depth
        # A node is a leaf iff both children are TREE_LEAF (-1), i.e. equal.
        # Testing ``threshold != -2`` instead would misread a genuine split
        # whose threshold happens to be -2.0 as a leaf.
        if left[node] != right[node]:
            # Resolve the feature name only for split nodes: leaves carry
            # feature id -2, which would index from the end of
            # ``feature_names`` (IndexError with a single feature).
            name = str(feature_names[feature[node]])
            print(spacer + "if ( " + name + " <= " +
                  str(threshold[node]) + " ) {")
            if left[node] != -1:
                recurse(left[node], depth + 1)
            print(spacer + "}\n" + spacer + "else {")
            if right[node] != -1:
                recurse(right[node], depth + 1)
            print(spacer + "}")
        else:
            for i, count in zip(*leaf_counts(node)):
                print(spacer + "return " + str(target_names[i]) +
                      " ( " + str(int(round(count))) + " examples )")

    recurse(0, 0)
# Taken from https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.inspection import DecisionBoundaryDisplay
def plot_tree_iris():
    """Plot decision surfaces of shallow trees on every pair of iris features.

    For each of the six pairs drawn from the four iris features, fit a
    ``DecisionTreeClassifier(max_depth=3, random_state=1)`` on those two
    columns only. Draw its decision boundary with ``DecisionBoundaryDisplay``
    in a 2x3 subplot grid and overlay the training points coloured by class.
    Then add a figure title and a legend to the right of the last subplot,
    and show the figure.

    Adapted from
    https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html
    """
    # One marker colour per iris class, in class-id order.
    plot_colors = "bry"
    iris = load_iris()
    y = iris.target
    feature_pairs = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    for pairidx, pair in enumerate(feature_pairs):
        # Keep only the two features of this pair.
        X = iris.data[:, pair]
        clf = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X, y)
        # Plot the decision boundary in the next cell of the 2x3 grid.
        ax = plt.subplot(2, 3, pairidx + 1)
        plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
        DecisionBoundaryDisplay.from_estimator(
            clf,
            X,
            cmap=plt.cm.RdYlBu,
            response_method="predict",
            ax=ax,
            xlabel=iris.feature_names[pair[0]],
            ylabel=iris.feature_names[pair[1]],
        )
        # Plot the training points, one scatter call per class.
        for class_idx, (class_name, color) in enumerate(
                zip(iris.target_names, plot_colors)):
            mask = y == class_idx
            plt.scatter(
                X[mask, 0],
                X[mask, 1],
                c=color,
                label=class_name,
                edgecolor="black",
                s=15,
            )
    plt.axis("tight")
    plt.suptitle("Decision surface of a decision tree using paired features")
    # Anchor the legend just outside the right edge of the last subplot.
    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    plt.show()