Is it possible to print the decision tree in scikit-learn?

There is a method to export the tree in Graphviz format: http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html

So from the online docs:

>>> from sklearn.datasets import load_iris
>>> from sklearn import tree
>>>
>>> clf = tree.DecisionTreeClassifier()
>>> iris = load_iris()
>>>
>>> clf = clf.fit(iris.data, iris.target)
>>> tree.export_graphviz(clf,
...     out_file='tree.dot')    
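
As a hedged variation on the call above: export_graphviz also accepts labelling options, and with out_file=None it returns the DOT source as a string instead of writing a file. The sketch below assumes a reasonably recent scikit-learn; the random_state and the choice of options are mine, not from the docs example.

```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the DOT source as a string
# instead of writing it to disk.
dot_source = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=list(iris.feature_names),  # label splits with column names
    class_names=list(iris.target_names),     # label nodes with species names
    filled=True,                             # colour nodes by majority class
    rounded=True,                            # draw rounded node boxes
)

print(dot_source[:200])  # preview the beginning of the DOT source
```

The returned string can be written to a .dot file or handed to any tool that reads DOT.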

Then you can render tree.dot with Graphviz (for example, dot -Tpdf tree.dot -o tree.pdf), or, if you have pydot installed, do it more directly from Python: http://scikit-learn.org/stable/modules/tree.html

>>> from io import StringIO  # sklearn.externals.six was removed in scikit-learn 0.23
>>> import pydot 
>>> dot_data = StringIO() 
>>> tree.export_graphviz(clf, out_file=dot_data) 
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue()) 
>>> graph.write_pdf("iris.pdf") 

This writes the tree to iris.pdf. I can't display it here, so follow the link to see a rendered example: http://scikit-learn.org/stable/_images/iris.svg

Update

The behaviour seems to have changed since I first answered this question: pydot.graph_from_dot_data now returns a list, so you get this error:

AttributeError: 'list' object has no attribute 'write_pdf'

When you see this, it's worth printing the object and inspecting it first. Most likely what you want is its first element:

graph[0].write_pdf("iris.pdf")
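
If the same code has to run against both older and newer pydot versions, one option is a small wrapper that accepts either return shape. This is only a sketch; first_graph is a hypothetical helper name, not part of pydot.

```python
def first_graph(result):
    """Return a single graph from whatever graph_from_dot_data returned.

    Newer pydot versions return a list of graphs; older ones returned
    the graph object itself.
    """
    if isinstance(result, (list, tuple)):
        if not result:
            raise ValueError("no graphs were parsed from the DOT data")
        return result[0]
    return result


# Hypothetical usage with the objects from the answer above:
#   graph = first_graph(pydot.graph_from_dot_data(dot_data.getvalue()))
#   graph.write_pdf("iris.pdf")
```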

Thanks to @NickBraunagel for the comment


I'm aware of four ways to visualize a scikit-learn decision tree:

  • print the text representation of the tree with sklearn.tree.export_text method
  • plot with sklearn.tree.plot_tree method (matplotlib needed)
  • plot with sklearn.tree.export_graphviz method (graphviz needed)
  • plot with dtreeviz package (dtreeviz and graphviz needed)

The simplest is to export to the text representation. The example decision tree will look like:

|--- feature_2 <= 2.45
|   |--- class: 0
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 4.95
|   |   |   |--- feature_3 <= 1.65
|   |   |   |   |--- class: 1
|   |   |   |--- feature_3 >  1.65
|   |   |   |   |--- class: 2
|   |   |--- feature_2 >  4.95
|   |   |   |--- feature_3 <= 1.55
|   |   |   |   |--- class: 2
|   |   |   |--- feature_3 >  1.55
|   |   |   |   |--- feature_0 <= 6.95
|   |   |   |   |   |--- class: 1
|   |   |   |   |--- feature_0 >  6.95
|   |   |   |   |   |--- class: 2
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- feature_1 <= 3.10
|   |   |   |   |--- class: 2
|   |   |   |--- feature_1 >  3.10
|   |   |   |   |--- class: 1
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: 2
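
A minimal sketch of how a text report like this can be produced, assuming scikit-learn 0.21 or later (where export_text was added). Passing feature_names is optional and replaces the generic feature_N labels with column names; the random_state is my own addition for repeatability.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# Build the indented text report; feature_names replaces feature_0, feature_1, ...
text_report = export_text(clf, feature_names=list(iris.feature_names))
print(text_report)
```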

Then if you have matplotlib installed, you can plot with sklearn.tree.plot_tree:

tree.plot_tree(clf) # the clf is your decision tree model
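
A slightly fuller sketch of the same call, assuming matplotlib is installed. The figure size, the output file name iris_tree.png, and the Agg backend are my own choices so that the example also runs on a machine without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = tree.DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(12, 8))
# plot_tree draws onto the given axes and returns the node annotations
annotations = tree.plot_tree(
    clf,
    feature_names=list(iris.feature_names),
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)
fig.savefig("iris_tree.png", dpi=150)  # write the figure to disk
plt.close(fig)
```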

The example output is similar to what you will get with export_graphviz:

[image: sklearn decision tree visualization]

You can also try the dtreeviz package. It will give you much more information. An example:

[image: dtreeviz example decision tree]

You can find a comparison of the different sklearn decision tree visualizations, with code snippets, in the linked blog post.


Although I'm late to the game, the step-by-step instructions below may be useful to others who want to display decision tree output:

Install necessary modules:

  1. Install graphviz. I used conda's package (recommended over pip install graphviz, because the pip package doesn't include the actual Graphviz executables).
  2. Install pydot via pip (pip install pydot).
  3. Add the graphviz directory containing the .exe files (e.g. dot.exe) to your PATH environment variable.
  4. Run EdChum's code above (NOTE: graph is a list containing the pydot.Dot object):

from sklearn.datasets import load_iris
from sklearn import tree
from io import StringIO  # sklearn.externals.six was removed in scikit-learn 0.23
import pydot 

clf = tree.DecisionTreeClassifier()
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = StringIO() 
tree.export_graphviz(clf, out_file=dot_data) 
graph = pydot.graph_from_dot_data(dot_data.getvalue()) 

graph[0].write_pdf("iris.pdf")  # must access graph's first element

You'll now find iris.pdf in your current working directory.
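
If write_pdf fails, the PATH setting from step 3 is a likely cause. A quick way to check it from Python is sketched below using only the standard library; graphviz_available is a hypothetical helper name, and it only tests whether the dot executable can be found, not whether it works.

```python
import shutil


def graphviz_available(executable="dot"):
    """Return True if the given Graphviz executable can be found on PATH."""
    return shutil.which(executable) is not None


if not graphviz_available():
    print("Graphviz 'dot' was not found on PATH; write_pdf will likely fail.")
```

On Windows, shutil.which also honours PATHEXT, so looking up "dot" should find dot.exe.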