PyTorch如何列出分類錯誤之原始圖片路徑?

參考資料

前提

本文假設batch size使用1的情況下,將分類錯誤的原始圖片路徑、實際標籤、預測標籤寫入到missclassified_file_path.txt檔案當中。

因為每張圖片都必須要有路徑,PyTorch的DataLoader才能順利載入圖片,也因此,如果希望獲得圖片的路徑,修改DataLoader讓它能夠return圖片原始路徑即可。如程式碼的第10至14行

關鍵在程式碼的第51~58行

if args.batch_size==1 and pred_index.cpu() != target.cpu():
with open("missclassified_file_path.txt", "a") as f:
true_label = str(target.cpu().tolist()[0])
pred_label = str(pred_index.cpu().tolist()[0])
write_content = ' '.join(path)
write_content = write_content + ' ' + true_label + ' ' + pred_label
f.write(write_content+"\n")

程式碼

Machine Learning | Deep Learning | https://linktr.ee/yanwei