如何根據PyTorch的Model預測的output繪製出混淆矩陣(Confusion Matrix)並取得每個class的accuracy?
5 min read · Jun 14, 2021
--
參考資料
完整程式:計算並保存混淆矩陣圖、列印出每個class的Accuracy
取得預測結果並計算混淆矩陣
# Collect the predicted and ground-truth labels over the whole test set,
# then build the confusion matrix from them.
y_pred = []
y_true = []

model.eval()  # put the model in evaluation mode before predicting
with torch.no_grad():  # gradients are not needed while evaluating
    for images, target in test_loader:
        output = model(images)
        # torch.max over dim 1 returns (values, indices); the indices are the
        # predicted results.
        _, preds = torch.max(output, 1)
        y_pred.extend(preds.view(-1).detach().cpu().numpy())
        y_true.extend(target.view(-1).detach().cpu().numpy())

# NOTE(review): confusion_matrix is presumably sklearn.metrics.confusion_matrix;
# it has to be imported before this snippet runs.
cf_matrix = confusion_matrix(y_true, y_pred)
每個class的accuracy
from sklearn.metrics import confusion_matrix
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
matrix = confusion_matrix(y_true, y_pred)

# Rows of sklearn's confusion matrix are the true labels and columns are the
# predicted labels, so the diagonal divided by the row sums (axis=1) is, for
# each class, the share of its true samples that were predicted correctly.
# People on Stack Overflow argue about whether axis=0 or axis=1 is the correct
# one; in my own experiments axis=1 gives the correct result.
# To understand the difference between axis=0 and axis=1, see the figure in the
# linked article.
per_class_accuracy = matrix.diagonal() / matrix.sum(axis=1)
print(per_class_accuracy)  # classes 0, 1, 2 -> 1.0, 0.0, 0.667
繪製並保存混淆矩陣
# 開始繪製混淆矩陣並存檔
class_names = [0,1,2,3,4,5]df_cm = pd.DataFrame(cf_matrix, class_names, class_names)
plt.figure(figsize = (9,6))
sns.heatmap(df_cm, annot=True, fmt="d", cmap='BuGn')
plt.xlabel("prediction")
plt.ylabel("label (ground truth)")
plt.savefig("confusion_matrix.png")# 注意,如果預測dataset的class數量與下方定義的class_names不同的話,請以final_class_name為主class_names = [0,1,2,3,4,5]clss_num = len(final_class_name)if clss_num==6 or clss_num==5 or clss_num==3 or clss_num==2:
df_cm = pd.DataFrame(cf_matrix, final_class_name, final_class_name)if clss_num==4: #4的話則依舊以class_names為主
df_cm = pd.DataFrame(cf_matrix, class_names, class_names)if clss_num==1:
print("Only 1 class exists")
print("No confusion matrix")
return
2021/09/04更新:混淆矩陣除了包含分類錯誤的數字外,也包含其判斷準確率
def show_confusion_matrix(confusion_matrix, class_names):
    """Draw a heatmap of a confusion matrix with counts and row ratios.

    Every cell is coloured by its row-normalised value and annotated with two
    lines: the raw count, then the ratio of that count to its row total
    formatted with two decimals. The heatmap is drawn through seaborn and
    nothing is returned.

    Args:
        confusion_matrix: 2-D array-like of counts. The y axis is labelled
            'True Sign' and the x axis 'Predicted Sign', so rows are expected
            to be true labels and columns predictions.
        class_names: Labels used for both the rows and the columns of the
            heatmap.
    """
    cm = np.asarray(confusion_matrix)
    cell_counts = cm.flatten()

    # Normalise each row by its total. A row that sums to zero (a class with no
    # samples) is left at 0 instead of producing NaN and a RuntimeWarning.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    cm_row_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    row_percentages = [f"{value:.2f}" for value in cm_row_norm.flatten()]
    cell_labels = [
        f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)
    ]
    cell_labels = np.asarray(cell_labels).reshape(cm.shape[0], cm.shape[1])

    df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names)

    # fmt="" because the annotations are already formatted strings.
    hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues")
    hmap.yaxis.set_ticklabels(
        hmap.yaxis.get_ticklabels(), rotation=0, ha='right'
    )
    hmap.xaxis.set_ticklabels(
        hmap.xaxis.get_ticklabels(), rotation=30, ha='right'
    )
    plt.ylabel('True Sign')
    plt.xlabel('Predicted Sign')