how to explain the decision tree from scikit-learn
The value line in each box tells you how many samples at that node fall into each category, in order. That's why, in each box, the numbers in value add up to the number shown in samples. For instance, in your red box, 91+212+113=416. So at this node there were 91 data points in category 1, 212 in category 2, and 113 in category 3.
If you were going to predict the outcome for a new data point that reached that leaf in the decision tree, you would predict category 2, because that is the most common category for samples at that node.
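The arithmetic above can be checked in a few lines. This is only a sketch using the three counts quoted from the red box; the variable names are made up for illustration.

```python
# Class counts copied from the "value" line of the red box discussed above.
value = [91, 212, 113]

# The counts add up to the "samples" figure for the node.
samples = sum(value)

# The predicted category is the one with the largest count.
# Categories are numbered from 1 here, so add 1 to the list index.
predicted_category = value.index(max(value)) + 1

print(samples, predicted_category)
```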
According to the book "Learning scikit-learn: Machine Learning in Python", the decision tree represents a series of decisions based on the training data.
!(http://i.imgur.com/vM9fJLy.png)
To classify an instance, we answer the question at each node. For example: is sex<=0.5? (Are we talking about a woman?) If the answer is yes, you go to the left child node in the tree; otherwise you go to the right child node. You keep answering questions (was she in the third class? was she in the first class? was she below 13 years old?) until you reach a leaf. When you are there, the prediction corresponds to the target class that has the most instances.
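The walk described above can be reproduced by hand from a fitted tree. The sketch below is an assumption-laden illustration: it uses the iris data instead of the Titanic data in the figure, `walk` is a hypothetical helper name, and it relies on the low-level `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, `value`) that scikit-learn exposes on a fitted classifier.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_  # low-level arrays describing the fitted tree


def walk(sample):
    """Follow one sample from the root to a leaf and return the class index."""
    node = 0  # the root is always node 0
    while t.children_left[node] != -1:  # -1 marks a leaf
        if sample[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]   # question answered "yes": go left
        else:
            node = t.children_right[node]  # question answered "no": go right
    # At the leaf, pick the class with the largest entry in value.
    return int(np.argmax(t.value[node][0]))


# scikit-learn converts features to float32 internally, so do the same here.
X = iris.data.astype(np.float32)
manual = np.array([walk(row) for row in X])

# Compare the hand-rolled walk with the library's own predictions.
print((manual == clf.predict(iris.data)).all())
```

For iris the class labels are 0, 1 and 2, so the index returned by `walk` can be compared directly with the output of `predict`.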
First question: Yes, your logic is correct. The left node is True and the right node is False. This can be counter-intuitive, because the True branch can hold the smaller share of the samples.
Second question: This is best resolved by visualizing the tree as a graph with pydotplus. The class_names parameter of tree.export_graphviz() adds a class line to each node, naming that node's majority class. The code below was run in an IPython notebook.
from sklearn.datasets import load_iris
from sklearn import tree
iris = load_iris()
clf2 = tree.DecisionTreeClassifier()
clf2 = clf2.fit(iris.data, iris.target)
# Write the tree in DOT format to a file
with open("iris.dot", 'w') as f:
    tree.export_graphviz(clf2, out_file=f)
import os
os.unlink('iris.dot')
import pydotplus
dot_data = tree.export_graphviz(clf2, out_file=None)
graph2 = pydotplus.graph_from_dot_data(dot_data)
graph2.write_pdf("iris.pdf")
from IPython.display import Image
dot_data = tree.export_graphviz(clf2, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,  # leaves_parallel=True,
                                special_characters=True)
graph2 = pydotplus.graph_from_dot_data(dot_data)
# Color each node by its majority class
nodes = graph2.get_node_list()
# One pastel RGB color per class index
colors = {0: [255, 255, 224], 1: [255, 224, 255], 2: [224, 255, 255]}
for node in nodes:
    if node.get_label():
        # Parse the numbers out of "value = [...]" in the node label;
        # float() accepts both integer counts and fractional values
        values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        rgb = colors[values.index(max(values))]
        color = '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2])
        node.set_fillcolor(color)
Image(graph2.create_png() )
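If Graphviz or pydotplus is not installed, the effect of class_names can still be seen in the raw DOT text. A minimal sketch, assuming the same iris classifier as above (refitted here so the snippet stands alone) and passing the class names as a plain list:

```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# Without class_names, node labels carry no class line.
plain = tree.export_graphviz(clf, out_file=None)

# With class_names, each node label gains a "class = <name>" line.
named = tree.export_graphviz(clf, out_file=None,
                             class_names=list(iris.target_names))

# Show whether each DOT string mentions a class.
print("class = " in plain, "class = " in named)
```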
As for determining the class at the leaf: your example doesn't have leaves with a single class, as the iris data set does. This is common, and getting single-class leaves may require over-fitting the model. A discrete distribution of classes at each leaf is the best result for many cross-validated models.
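The difference between pure and mixed leaves can be seen on iris itself. The sketch below is illustrative only: `leaf_impurities` is a hypothetical helper, and max_depth=2 is an arbitrary limit chosen to stop the tree before its leaves become pure.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()


def leaf_impurities(clf):
    """Return the Gini impurity of every leaf in a fitted tree."""
    t = clf.tree_
    is_leaf = t.children_left == -1  # -1 marks a leaf
    return t.impurity[is_leaf]


# Grown without limits: the tree keeps splitting until its leaves are pure.
full = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
# Depth-limited: some leaves still contain a mix of classes.
shallow = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

full_imp = leaf_impurities(full)
shallow_imp = leaf_impurities(shallow)
print(full_imp.max(), shallow_imp.max())
```

An impurity of zero means the leaf holds a single class; anything above zero means the leaf predicts its majority class while still containing samples from other classes.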