mirror of
https://github.com/gsi-upm/sitc
synced 2024-11-24 23:42:29 +00:00
117 lines
3.8 KiB
Python
117 lines
3.8 KiB
Python
|
import numpy as np
|
||
|
|
||
|
# Taken from http://chrisstrelioff.ws/sandbox/2015/06/25/decision_trees_in_python_again_cross_validation.html
|
||
|
|
||
|
def get_code(tree, feature_names, target_names,
             spacer_base=" "):
    """Print pseudo-code for a fitted decision tree.

    Args
    ----
    tree -- scikit-learn decision tree; only its ``tree_`` attribute is read
            (children_left, children_right, threshold, feature, value).
    feature_names -- list of feature names, indexed by feature id.
    target_names -- list of target (class) names, indexed by class id.
    spacer_base -- string repeated once per depth level to indent the
                   printed code (default: " ").

    Returns
    -------
    None. The pseudo-code is written to standard output with ``print``.

    Notes
    -----
    based on http://stackoverflow.com/a/30104792.
    """
    left = tree.tree_.children_left
    right = tree.tree_.children_right
    threshold = tree.tree_.threshold
    feature = tree.tree_.feature
    value = tree.tree_.value

    def recurse(node, depth):
        """Print the subtree rooted at ``node``, indented ``depth`` levels."""
        spacer = spacer_base * depth

        # A leaf has no children: both child ids hold the same sentinel.
        # Testing the child ids (instead of comparing the float threshold
        # against -2) keeps a genuine split at threshold -2.0 from being
        # mistaken for a leaf.
        if left[node] == right[node]:
            target = value[node]
            # target is indexed as (output, class); the second nonzero
            # index array therefore holds the class ids.
            nonzero = np.nonzero(target)
            # NOTE(review): int(v) assumes tree_.value holds sample counts;
            # newer scikit-learn releases may store class fractions instead
            # -- confirm against the installed version.
            for class_id, v in zip(nonzero[1], target[nonzero]):
                print("{}return {} ( {} examples )".format(
                    spacer, target_names[class_id], int(v)))
            return

        # The feature name is resolved only for split nodes. Leaves carry a
        # negative sentinel feature id, which must never be used to index
        # feature_names (it would wrap around or raise IndexError).
        print("{}if ( {} <= {} ) {{".format(
            spacer, feature_names[feature[node]], threshold[node]))
        if left[node] != -1:
            recurse(left[node], depth + 1)
        print(spacer + "}\n" + spacer + "else {")
        if right[node] != -1:
            recurse(right[node], depth + 1)
        print(spacer + "}")

    recurse(0, 0)
|
||
|
|
||
|
# Taken from http://scikit-learn.org/stable/auto_examples/tree/plot_iris.html#example-tree-plot-iris-py
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
from sklearn.datasets import load_iris
|
||
|
from sklearn.tree import DecisionTreeClassifier
|
||
|
|
||
|
def plot_tree_iris():
    """Plot decision surfaces of decision trees trained on iris feature pairs.

    For each of the six pairs of iris features, the two selected columns are
    shuffled and standardized, a ``DecisionTreeClassifier(max_depth=3)`` is
    fitted on them, and its predictions over a regular grid are drawn as a
    filled contour in one cell of a 2x3 subplot grid, with the training
    points overlaid. The figure is displayed with ``plt.show()``; nothing is
    returned.

    Taken from http://scikit-learn.org/stable/auto_examples/tree/plot_iris.html
    """
    # Parameters
    n_classes = 3
    plot_colors = "bry"
    plot_step = 0.02

    # Load data
    iris = load_iris()

    for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3],
                                    [1, 2], [1, 3], [2, 3]]):
        # We only take the two corresponding features
        X = iris.data[:, pair]
        y = iris.target

        # Shuffle (fixed seed so every run yields the same permutation)
        order = np.arange(X.shape[0])
        np.random.seed(13)
        np.random.shuffle(order)
        X = X[order]
        y = y[order]

        # Standardize each of the two columns
        mean = X.mean(axis=0)
        std = X.std(axis=0)
        X = (X - mean) / std

        # Train
        model = DecisionTreeClassifier(max_depth=3, random_state=1).fit(X, y)

        # Plot the decision boundary
        plt.subplot(2, 3, pairidx + 1)

        x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
        y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
        xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                             np.arange(y_min, y_max, plot_step))

        # Predict the class of every grid point, then restore the grid shape
        Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
        Z = Z.reshape(xx.shape)
        plt.contourf(xx, yy, Z, cmap=plt.cm.Paired)

        plt.xlabel(iris.feature_names[pair[0]])
        plt.ylabel(iris.feature_names[pair[1]])
        plt.axis("tight")

        # Plot the training points
        for i, color in zip(range(n_classes), plot_colors):
            # A separate name keeps the shuffle permutation from being
            # shadowed. No cmap is passed: the colour is given literally
            # through ``c``, so a colormap would be ignored.
            class_mask = y == i
            plt.scatter(X[class_mask, 0], X[class_mask, 1], c=color,
                        label=iris.target_names[i])

        plt.axis("tight")

    plt.suptitle("Decision surface of a decision tree using paired features")
    plt.legend()
    plt.show()
|