Member-only story
如何根據PyTorch的Model預測的output繪製出混淆矩陣(Confusion Matrix)並取得每個class的accuracy?
5 min readJun 14, 2021
參考資料
完整程式:計算並保存混淆矩陣圖、列印出每個class的Accuracy
取得預測結果並計算混淆矩陣
y_pred = []
y_true = []model.eval()
with torch.no_grad():
for i, (images, target) in enumerate(test_loader):
output = model(images)
_, preds = torch.max(output, 1) #preds是預測結果
loss = criterion(output, target)…