I'm following the decision tree tutorial in the scikit-learn documentation. I have pydotplus 2.0.2 installed, but I get an error saying that an object has no write method (traceback below). I've been struggling with this for a while now. Any ideas, please? Many thanks!
from sklearn import tree
from sklearn.datasets import load_iris
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
from IPython.display import Image
dot_data = tree.export_graphviz(clf, out_file=None)
import pydotplus
graph = pydotplus.graphviz.graph_from_dot_data(dot_data)
Image(graph.create_png())
And my error is:
/Users/air/anaconda/bin/python /Users/air/PycharmProjects/kiwi/hemr.py
Traceback (most recent call last):
  File "/Users/air/PycharmProjects/kiwi/hemr.py", line 10, in <module>
    dot_data = tree.export_graphviz(clf, out_file=None)
  File "/Users/air/anaconda/lib/python2.7/site-packages/sklearn/tree/export.py", line 375, in export_graphviz
    out_file.write('digraph Tree {\n')
AttributeError: 'NoneType' object has no attribute 'write'
Process finished with exit code 1
----- UPDATE -----
Using the fix with out_file, it throws another error:
Traceback (most recent call last):
  File "/Users/air/PycharmProjects/kiwi/hemr.py", line 13, in <module>
    graph = pydotplus.graphviz.graph_from_dot_data(dot_data)
  File "/Users/air/anaconda/lib/python2.7/site-packages/pydotplus/graphviz.py", line 302, in graph_from_dot_data
    return parser.parse_dot_data(data)
  File "/Users/air/anaconda/lib/python2.7/site-packages/pydotplus/parser.py", line 548, in parse_dot_data
    if data.startswith(codecs.BOM_UTF8):
AttributeError: 'NoneType' object has no attribute 'startswith'
---- UPDATE 2 -----
Also, see my own answer below, which solves another problem.
The problem is that you are setting the parameter out_file to None.
According to the documentation, when out_file is set to None the function returns the result directly as a string and does not create a file. And of course a string does not have a write method.
Therefore, do as follows:
dot_data = tree.export_graphviz(clf)
graph = pydotplus.graphviz.graph_from_dot_data(dot_data)
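To make the mechanism behind the first traceback concrete, here is a minimal sketch in plain Python. The function export_dot_sketch is a hypothetical stand-in written for illustration, not scikit-learn's implementation; it only mirrors the failing line in the traceback, where write is called on whatever was passed as out_file. The sketch shows that None fails at that call, while an in-memory buffer gives the exporter something to write to and leaves the DOT text available as a string.

```python
import io


def export_dot_sketch(out_file):
    """Hypothetical stand-in for an exporter that writes DOT text to out_file.

    This is NOT scikit-learn's code. It only mirrors the call shown in the
    traceback, which invokes out_file.write(...) unconditionally.
    """
    out_file.write(u"digraph Tree {\n")
    out_file.write(u"0 -> 1 ;\n")
    out_file.write(u"}")


# Passing None reproduces the kind of error shown in the question:
# None has no write method, so the first write call raises AttributeError.
try:
    export_dot_sketch(None)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

# Passing an in-memory text buffer works, and getvalue() returns the
# DOT source as a plain string that a parser could consume.
buffer = io.StringIO()
export_dot_sketch(buffer)
dot_data = buffer.getvalue()
```

With scikit-learn, the analogous move would be to pass such a buffer as out_file to tree.export_graphviz and then hand buffer.getvalue() to pydotplus, assuming the installed version accepts a file-like object for that parameter. On Python 2, a buffer that accepts byte strings (for example StringIO.StringIO) may be needed instead of io.StringIO.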