如何根據PyTorch的Model預測的output繪製出混淆矩陣(Confusion Matrix)並取得每個class的accuracy?

參考資料

完整程式:計算並保存混淆矩陣圖、列印出每個class的Accuracy

取得預測結果並計算混淆矩陣

y_pred = []
y_true = []

每個class的accuracy

from sklearn.metrics import confusion_matrix
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
matrix = confusion_matrix(y_true, y_pred)
matrix.diagonal()/matrix.sum(axis=1)

繪製並保存混淆矩陣

# 開始繪製混淆矩陣並存檔
class_names = [0,1,2,3,4,5]

2021/09/04更新:混淆矩陣除了包含分類錯誤的數字外,也包含其判斷準確率

def show_confusion_matrix(confusion_matrix, class_names):
cm = confusion_matrix.copy()
cell_counts = cm.flatten()
cm_row_norm = cm / cm.sum(axis=1)[:, np.newaxis]
row_percentages = ["{0:.2f}".format(value) for value in cm_row_norm.flatten()]
cell_labels = [f"{cnt}\n{per}" for cnt, per in zip(cell_counts, row_percentages)]
cell_labels = np.asarray(cell_labels).reshape(cm.shape[0], cm.shape[1])
df_cm = pd.DataFrame(cm_row_norm, index=class_names, columns=class_names)
hmap = sns.heatmap(df_cm, annot=cell_labels, fmt="", cmap="Blues")
hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
plt.ylabel('True Sign')
plt.xlabel('Predicted Sign');

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.