# 如何根據PyTorch的Model預測的output繪製出混淆矩陣(Confusion Matrix)並取得每個class的accuracy？

---

# Run the model over the whole test set, collect predicted and true labels,
# and build the confusion matrix from them.
# Requires torch, model, test_loader, criterion and
# sklearn.metrics.confusion_matrix to be defined/imported beforehand.
# NOTE(review): images/target are passed to the model unchanged; if the model
# lives on a GPU they must be moved to the same device first — TODO confirm.
y_pred = []
y_true = []

model.eval()  # put the model in evaluation mode before predicting
with torch.no_grad():  # no autograd graph is needed for evaluation
    for images, target in test_loader:
        output = model(images)
        # preds: index of the highest score for each sample (the predicted class)
        _, preds = torch.max(output, 1)
        # NOTE: loss is not used below; kept from the original snippet so it can
        # be logged/accumulated if desired.
        loss = criterion(output, target)
        # Inside no_grad() these tensors carry no graph, so .detach() is not needed.
        y_pred.extend(preds.view(-1).cpu().numpy())
        y_true.extend(target.view(-1).cpu().numpy())

cf_matrix = confusion_matrix(y_true, y_pred)

import numpy as np
from sklearn.metrics import confusion_matrix

# Toy example with three classes (0, 1, 2).
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
matrix = confusion_matrix(y_true, y_pred)

# sklearn's confusion_matrix puts TRUE labels on the rows and PREDICTED labels
# on the columns: matrix[i, j] = number of samples of class i predicted as j.
# Dividing the diagonal by the row sums (axis=1) therefore gives, for each true
# class, the fraction of its samples classified correctly (per-class
# accuracy, a.k.a. recall). Dividing by the column sums (axis=0) would give
# per-class precision instead, which is why axis=1 is the right choice here.
row_totals = matrix.sum(axis=1)
# A class that never occurs in y_true has an all-zero row; report 0.0 for it
# instead of producing nan with a divide-by-zero RuntimeWarning.
per_class_accuracy = np.divide(
    matrix.diagonal(),
    row_totals,
    out=np.zeros(len(row_totals), dtype=float),
    where=row_totals != 0,
)
print(per_class_accuracy)  # -> [1.0, 0.0, 0.6667] for the toy example

# Draw the confusion matrix as an annotated heatmap and save it to
# "confusion_matrix.png".
# cf_matrix comes from the evaluation loop; final_class_name is assumed to be
# the list of class names of the dataset being predicted, defined elsewhere —
# TODO confirm. pd / plt / sns must already be imported.

# Number of classes in the predicted dataset (name kept from the original).
clss_num = len(final_class_name)

if clss_num == 1:
    # A single class gives a trivial matrix, so nothing is plotted. This replaces
    # the original module-level `return`, which is a SyntaxError outside a
    # function.
    print("Only 1 class exists")
    print("No confusion matrix")
else:
    n_rows = cf_matrix.shape[0]
    # Use the dataset's class names when they match the matrix size, otherwise
    # fall back to integer labels 0..n-1. This replaces the hard-coded
    # [0, 1, 2, 3, 4, 5], which only fit a 6x6 matrix and made pd.DataFrame
    # raise for any other size.
    if len(final_class_name) == n_rows:
        class_names = list(final_class_name)
    else:
        class_names = list(range(n_rows))
    df_cm = pd.DataFrame(cf_matrix, class_names, class_names)

    plt.figure(figsize=(9, 6))
    sns.heatmap(df_cm, annot=True, fmt="d", cmap='BuGn')
    plt.xlabel("prediction")
    plt.ylabel("label (ground truth)")
    plt.savefig("confusion_matrix.png")

2021/09/04 更新：新增 show_confusion_matrix，混淆矩陣的每一格除了顯示樣本數（分類正確與錯誤皆有）外，也顯示依真實標籤（列）正規化後的比例；對角線上的比例即為各 class 的判斷準確率（recall）。

def show_confusion_matrix(
    confusion_matrix,
    class_names,
    xlabel="Predicted Sign",
    ylabel="True Sign",
):
    """Draw a confusion-matrix heatmap annotated with counts and row fractions.

    Each cell shows its raw count on the first line and that count divided by
    its row total (formatted with two decimals) on the second line. Cells are
    coloured by the row-normalised value, so the diagonal shows, for every true
    class, the fraction of its samples that were predicted correctly.

    Args:
        confusion_matrix: square matrix of counts with true classes on the rows
            and predicted classes on the columns (the layout produced by
            sklearn.metrics.confusion_matrix). Any array-like is accepted; the
            caller's object is never modified.
        class_names: tick labels, one per row/column.
        xlabel: x-axis label (defaults to the original "Predicted Sign").
        ylabel: y-axis label (defaults to the original "True Sign").

    Returns:
        The matplotlib Axes returned by seaborn.heatmap.
    """
    # np.array always copies, so the input is left untouched and lists also work.
    cm = np.array(confusion_matrix)
    cell_counts = cm.flatten()

    # Row-normalise. A class with no true samples has an all-zero row; show
    # 0.00 for it instead of nan (and a divide-by-zero RuntimeWarning).
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_row_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )
    row_percentages = ["{0:.2f}".format(value) for value in cm_row_norm.flatten()]
    cell_labels = np.asarray(
        [f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)]
    ).reshape(cm.shape)

    df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names)
    hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    return hmap