Is there a simple way to get the precision and recall values for each class of my classifier?
For context, I implemented a 20-class CNN text classifier in TensorFlow with the help of Denny Britz's code: https://github.com/dennybritz/cnn-text-classification-tf
At the end of text_cnn.py, he computes the overall accuracy with a simple snippet:
# Accuracy
with tf.name_scope("accuracy"):
    correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_y, 1))
    self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, "float"), name="accuracy")
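In that snippet, `self.input_y` holds one label row per example with one column per class, so `tf.argmax(self.input_y, 1)` recovers each example's class index, and the accuracy is the mean of the element-wise matches against `self.predictions`. A plain-Python sketch of the same computation, for illustration only (the function name and toy data are hypothetical, not part of the original code):

```python
def accuracy_from_one_hot(predictions, one_hot_labels):
    """Mirror of the TF snippet: argmax over each label row, then mean of matches."""
    # argmax along axis 1: index of the largest entry in each row
    true_classes = [max(range(len(row)), key=row.__getitem__) for row in one_hot_labels]
    # 1 where the predicted class equals the true class, else 0
    correct = [int(p == t) for p, t in zip(predictions, true_classes)]
    return sum(correct) / len(correct)


labels = [[0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]]  # true classes: 2, 0, 1, 2
preds = [2, 0, 2, 1]
print(accuracy_from_one_hot(preds, labels))  # 0.5
```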
Any ideas on how I could do something similar to get the precision and recall values for the different categories?
My question may sound dumb, but to be honest I'm a bit lost here. Thanks for the help.
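To pin down what per-class values mean here: treating each class k one-vs-rest, precision is TP_k divided by the number of examples predicted as k, and recall is TP_k divided by the number of examples whose true class is k. A minimal plain-Python sketch, assuming the true and predicted classes are available as integer indices (for example, the evaluated values of tf.argmax(self.input_y, 1) and self.predictions); the function name and the "0.0 when the denominator is zero" convention are my own choices:

```python
from collections import Counter


def per_class_precision_recall(y_true, y_pred, num_classes):
    """Return {class: (precision, recall)}, treating each class one-vs-rest."""
    tp = Counter()          # predicted k and truly k
    pred_count = Counter()  # predicted k (TP + FP)
    true_count = Counter()  # truly k (TP + FN)
    for t, p in zip(y_true, y_pred):
        pred_count[p] += 1
        true_count[t] += 1
        if t == p:
            tp[t] += 1
    result = {}
    for k in range(num_classes):
        # Assumed convention: report 0.0 when a class is never predicted / never present.
        precision = tp[k] / pred_count[k] if pred_count[k] else 0.0
        recall = tp[k] / true_count[k] if true_count[k] else 0.0
        result[k] = (precision, recall)
    return result


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
for k, (p, r) in per_class_precision_recall(y_true, y_pred, 3).items():
    print(f"class {k}: precision={p:.2f} recall={r:.2f}")
# class 0: precision=0.50 recall=0.50
# class 1: precision=0.67 recall=1.00
# class 2: precision=1.00 recall=0.50
```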
Using tf.metrics did the trick for me:
import numpy as np
import tensorflow as tf
# Define the metrics: x holds the true class ids, y the predicted class ids
x = tf.placeholder(tf.int32)
y = tf.placeholder(tf.int32)
acc, acc_op = tf.metrics.accuracy(labels=x, predictions=y)
# Caution: recall and precision cast labels/predictions to bool, so with integer
# class ids they measure "class != 0" vs "class == 0", not per-class values
rec, rec_op = tf.metrics.recall(labels=x, predictions=y)
pre, pre_op = tf.metrics.precision(labels=x, predictions=y)
# Predict the classes with your trained classifier
# (DNNClassifier is your estimator instance; input_fn and testing_set are your own input pipeline)
scorednn = list(DNNClassifier.predict_classes(input_fn=lambda: input_fn(testing_set)))
scoreArr = np.array(scorednn).astype(int)
# Run the session to compare the labels with the predictions
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())  # the metric counters are local variables
feed = {x: testing_set["target"], y: scoreArr}
v = sess.run(acc_op, feed_dict=feed)  # accuracy
r = sess.run(rec_op, feed_dict=feed)  # recall
p = sess.run(pre_op, feed_dict=feed)  # precision
print("accuracy %f" % v)
print("recall %f" % r)
print("precision %f" % p)
Result:
accuracy 0.686667
recall 0.978723
precision 0.824373
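Keep in mind that tf.metrics.recall and tf.metrics.precision are binary metrics: they cast labels and predictions to bool, so with integer class ids every nonzero class counts as "positive" and class 0 as "negative". The recall and precision above therefore describe "class != 0" rather than any individual class, which is how recall can sit far above accuracy. A plain-Python sketch that mimics that casting behavior (the function name and toy data are hypothetical; returning 0.0 for an empty denominator is an assumption):

```python
def binary_precision_recall(labels, predictions):
    """Mimic tf.metrics precision/recall: values cast to bool, nonzero = positive."""
    pairs = [(bool(t), bool(p)) for t, p in zip(labels, predictions)]
    tp = sum(1 for t, p in pairs if t and p)
    fp = sum(1 for t, p in pairs if not t and p)
    fn = sum(1 for t, p in pairs if t and not p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


labels = [1, 2, 3, 0]
predictions = [2, 3, 1, 0]  # only the last prediction is the right class (accuracy 0.25)
print(binary_precision_recall(labels, predictions))  # (1.0, 1.0)
```

For genuine per-class numbers, compare each class one-vs-rest (class k vs. everything else) instead of feeding the raw class ids.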
Note: for accuracy I would use:
accuracy_score = DNNClassifier.evaluate(input_fn=lambda: input_fn(testing_set), steps=1)["accuracy"]
as it is simpler and already computed by evaluate.
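Also, the metric update ops accumulate their counts across sess.run calls, so re-run an initializer on the metrics' local variables (e.g. sess.run(tf.local_variables_initializer()), or tf.variables_initializer on just those variables) if you don't want cumulative results.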
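Why the results become cumulative: each tf.metrics function keeps running totals in local variables (for accuracy, a total and a count), and every run of the update op adds to them before returning the running value. A plain-Python analogy of that behavior, not TensorFlow code (the class name is hypothetical):

```python
class StreamingAccuracy:
    """Analogy for tf.metrics.accuracy: a running total/count updated per batch."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Plays the role of re-running the local variables initializer.
        self.total = 0
        self.count = 0

    def update(self, labels, predictions):
        # Plays the role of running the update op: accumulate, then return the running value.
        self.total += sum(int(t == p) for t, p in zip(labels, predictions))
        self.count += len(labels)
        return self.total / self.count


m = StreamingAccuracy()
print(m.update([1, 2], [1, 2]))  # 1.0  (first batch: 2/2)
print(m.update([1, 2], [1, 0]))  # 0.75 (cumulative 3/4, not this batch's 0.5)
m.reset()
print(m.update([1, 2], [1, 0]))  # 0.5  (fresh start after reset)
```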