NameError: global name 'export_graphviz' is not defined
On OSX high sierra I'm trying to implement my first decision tree on Spotify data following a YT tutorial. I'm trying to build the png of the tree using export_graphviz method, but the terminal in Jupyter is returning an error:
NameError Traceback (most recent call last)
<ipython-input-16-a4130fba6e1d> in <module>()
----> 1 show_tree(dt, features, 'dec_tree_01.png')
<ipython-input-15-9051b2c19b1b> in show_tree(tree, features, path)
1 def show_tree(tree, features, path):
2 f = io.StringIO()
----> 3 export_graphviz(tree, out_file = f, feature_names=features)
4 pydotplus.graph_from_dot_data(f.getvalue()).write_png(path)
5 img = misc.imread(path)
NameError: global name 'export_graphviz' is not defined
Here's the source code:
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import seaborn as sns
import graphviz
import pydotplus
import io
from scipy import misc
data = pd.read_csv('Spotify.csv')
train, test = train_test_split(data, test_size = 0.15)
pos_tempo = data[data['target']==1]['tempo']
neg_tempo = data[data['target']==0]['tempo']
fig = plt.figure(figsize=(12, 8))
plt.title('Song Tempo Like / Dislike Distribution')
pos_tempo.hist(alpha = 0.7, bins = 30, label = 'positive')
neg_tempo.hist(alpha = 0.7, bins = 30, label = 'negative')
plt.legend(loc = 'upper right')
c = DecisionTreeClassifier(min_samples_split=100)
features = ['acousticness', 'danceability', 'duration_ms', 'energy', 'instrumentalness', 'key', 'liveness', 'loudness', 'mode', 'speechiness', 'tempo', 'target', 'time_signature', 'valence']
X_train = train[features]
Y_train = train['target']
X_test = test[features]
Y_test = test['target']
dt = c.fit(X_train, Y_train)
def show_tree(tree, features, path):
f = io.StringIO()
export_graphviz(tree, out_file = f, feature_names=features)
pydotplus.graph_from_dot_data(f.getvalue()).write_png(path)
img = misc.imread(path)
plt.rcParams['figure.figsize'] = (20, 20)
plt.imshow(img)
show_tree(dt, features, 'dec_tree_01.png')
Could someone please help me?
Edited by Davide Montanari