I am using scikit-learn for classification of text documents(22000) to 100 classes. I use scikit-learn's confusion matrix method for computing the confusion matrix.
model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm=metrics.confusion_matrix(test_labels,pred)
print(cm)
plt.imshow(cm, cmap='binary')
This is how my confusion matrix looks like:
[[3962 325 0 ..., 0 0 0]
[ 250 2765 0 ..., 0 0 0]
[ 2 8 17 ..., 0 0 0]
...,
[ 1 6 0 ..., 5 0 0]
[ 1 1 0 ..., 0 0 0]
[ 9 0 0 ..., 0 0 9]]
However, I do not receive a clear or legible plot. Is there a better way to do this?
you can use plt.matshow()
instead of plt.imshow()
or you can use seaborn module's heatmap
(see documentation) to plot the confusion matrix
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
array = [[33,2,0,0,0,0,0,0,0,1,3],
[3,31,0,0,0,0,0,0,0,0,0],
[0,4,41,0,0,0,0,0,0,0,1],
[0,1,0,30,0,6,0,0,0,0,1],
[0,0,0,0,38,10,0,0,0,0,0],
[0,0,0,3,1,39,0,0,0,0,4],
[0,2,2,0,4,1,31,0,0,0,2],
[0,1,0,0,0,0,0,36,0,2,0],
[0,0,0,0,0,0,1,5,37,5,1],
[3,0,0,0,0,0,0,0,0,39,0],
[0,0,0,0,0,0,0,0,0,0,38]]
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"],
columns = [i for i in "ABCDEFGHIJK"])
plt.figure(figsize = (10,7))
sn.heatmap(df_cm, annot=True)