displaying scikit decision tree figure in jupyter notebook

As of scikit-learn version 21.0 (roughly May 2019), Decision Trees can now be plotted with matplotlib using scikit-learn’s tree.plot_tree without relying on graphviz.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

X, y = load_iris(return_X_y=True)

# Make an instance of the Model
clf = DecisionTreeClassifier(max_depth = 5)

# Train the model on the data
clf.fit(X, y)

fn=['sepal length (cm)','sepal width (cm)','petal length (cm)','petal width (cm)']
cn=['setosa', 'versicolor', 'virginica']

# Setting dpi = 300 to make image clearer than default
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=300)

tree.plot_tree(clf,
           feature_names = fn, 
           class_names=cn,
           filled = True);

# You can save your plot if you want
#fig.savefig('imagename.png')

Something similar to what is below will output in your jupyter notebook.
enter image description here

The code was adapted from this post.


There is a simple library called graphviz which you can use to view your decision tree. In this you don't have to export the graphic, it'll directly open the graphic of tree for you and you can later decide if you want to save it or not. You can use it as following -

import graphviz
from sklearn.tree import DecisionTreeClassifier()
from sklearn import tree

clf = DecisionTreeClassifier()
clf.fit(trainX,trainY)
columns=list(trainX.columns)
dot_data = tree.export_graphviz(clf,out_file=None,feature_names=columns,class_names=True)
graph = graphviz.Source(dot_data)
graph.render("image",view=True)
f = open("classifiers/classifier.txt","w+")
f.write(dot_data)
f.close()

because of view = True your graphs will open up as soon as they're rendered but if you don't want that and just want to save graphs, you can use view = False

Hope this helps


You can show the tree directly using IPython.display:

import graphviz
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier,export_graphviz
from sklearn.datasets import make_regression

# Generate a simple dataset
X, y = make_regression(n_features=2, n_informative=2, random_state=0)
clf = DecisionTreeRegressor(random_state=0, max_depth=2)
clf.fit(X, y)
# Visualize the tree
from IPython.display import display
display(graphviz.Source(export_graphviz(clf)))