I don't know if this is the right place to ask, but I will ask anyway. If it is not allowed, please let me know.
I have used GridSearchCV to tune parameters to find the best accuracy. This is what I have done:
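import numpy as np
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20
from sklearn.tree import DecisionTreeClassifier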
parameters = {'min_samples_split':np.arange(2, 80), 'max_depth': np.arange(2,10), 'criterion':['gini', 'entropy']}
clfr = DecisionTreeClassifier()
grid = GridSearchCV(clfr, parameters,scoring='accuracy', cv=8)
grid.fit(X_train,y_train)
print('The parameters combination that would give best accuracy is : ')
print(grid.best_params_)
print('The best accuracy achieved after parameter tuning via grid search is : ', grid.best_score_)
This gives me the following result:
The parameters combination that would give best accuracy is :
{'max_depth': 5, 'criterion': 'entropy', 'min_samples_split': 2}
The best accuracy achieved after parameter tuning via grid search is : 0.8147086914995224
Now I want to use these parameters when calling a function that visualizes a decision tree.
The function looks something like this:
def visualize_decision_tree(decision_tree, feature, target):
    dot_data = export_graphviz(decision_tree, out_file=None,
                               feature_names=feature,
                               class_names=target,
                               filled=True, rounded=True,
                               special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    return Image(graph.create_png())
Right now I am trying to use the best parameters provided by GridSearchCV to call the function in the following way:
dtBestScore = DecisionTreeClassifier(parameters = grid.best_params_)
dtBestScore = dtBestScore.fit(X=dfWithTrainFeatures, y= dfWithTestFeature)
visualize_decision_tree(dtBestScore, list(dfCopy.columns.delete(0).values), 'survived')
I am getting an error on the first line of this code, which says:
TypeError: __init__() got an unexpected keyword argument 'parameters'
Is there some way to pass the best parameters found by the grid search automatically, rather than reading the result and setting each parameter by hand?
Try Python kwargs (note the trailing underscore in best_params_):
DecisionTreeClassifier(**grid.best_params_)
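A slightly fuller sketch of that call, in case it helps. The iris data, the small parameter grid, and the names dt_best and dt_refit are stand-ins chosen for illustration, not taken from the question; the attribute on the fitted search is spelled best_params_, with a trailing underscore.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Stand-in data; the question's X_train / y_train would go here instead.
X, y = load_iris(return_X_y=True)

# A deliberately small grid so the sketch runs quickly.
parameters = {'max_depth': [2, 3, 4], 'criterion': ['gini', 'entropy']}
grid = GridSearchCV(DecisionTreeClassifier(random_state=0), parameters,
                    scoring='accuracy', cv=5)
grid.fit(X, y)

# ** unpacks the dict, so each key becomes a keyword argument,
# e.g. DecisionTreeClassifier(max_depth=..., criterion=...).
dt_best = DecisionTreeClassifier(random_state=0, **grid.best_params_)
dt_best.fit(X, y)

# With the default refit=True, the search also keeps an estimator that was
# already refitted on the full data using the best parameters.
dt_refit = grid.best_estimator_
```

If the tree only needs to be visualized, grid.best_estimator_ can probably be passed to the plotting function directly, since it is already fitted.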
See http://pythontips.com/2013/08/04/args-and-kwargs-in-python-explained for more on kwargs.
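To see why parameters=... fails while ** works, here is a minimal pure-Python sketch. The Classifier class is a hypothetical stand-in for an estimator whose constructor takes keyword parameters; it is not scikit-learn's implementation.

```python
class Classifier:
    """Hypothetical stand-in for an estimator with keyword parameters."""

    def __init__(self, criterion='gini', max_depth=None, min_samples_split=2):
        self.criterion = criterion
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split


best_params = {'max_depth': 5, 'criterion': 'entropy', 'min_samples_split': 2}

# Passing the whole dict under a keyword the constructor does not define
# raises TypeError, as in the question.
try:
    Classifier(parameters=best_params)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# ** unpacks the dict, so this is equivalent to
# Classifier(max_depth=5, criterion='entropy', min_samples_split=2).
clf = Classifier(**best_params)
```

Every key in the dict has to match a parameter name the constructor accepts; an unknown key would raise the same TypeError.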