PyTorch如何列出分類錯誤之原始圖片路徑?

參考資料

前提

本文假設batch size使用1的情況下,將分類錯誤的原始圖片路徑、實際標籤、預測標籤寫入到missclassified_file_path.txt檔案當中。

因為每張圖片都必須要有路徑,PyTorch的DataLoader才能順利載入圖片,也因此,如果希望獲得圖片的路徑,修改DataLoader讓它能夠return圖片原始路徑即可。如程式碼的第10至14行

關鍵在程式碼的第51~58行

if args.batch_size==1 and pred_index.cpu() != target.cpu():
with open("missclassified_file_path.txt", "a") as f:
true_label = str(target.cpu().tolist()[0])
pred_label = str(pred_index.cpu().tolist()[0])
write_content = ' '.join(path)
write_content = write_content + ' ' + true_label + ' ' + pred_label
f.write(write_content+"\n")

程式碼

根據txt文件中路徑,顯示圖片的預測 label及實際 label

import pandas as pd
import matplotlib.pyplot as plt
from pylab import *
df = pd.read_csv('../missclassified_file_path.txt', sep=" ", header=None)df.columns = ["path", "true", "prediction", "component_name"]df.loc[df['true'] == 0 , 'true'] = 'Good'
df.loc[df['true'] == 1 , 'true'] = 'Bad'
df.loc[df['prediction'] == 0 , 'prediction'] = 'Good'
df.loc[df['prediction'] == 1 , 'prediction'] = 'Bad'
data = df.loc[df['component_name'] == 'XXXXXX'][0::]
data_list = [path for path in data['path']]

predict_class = data['prediction'].tolist()
true_class = data['true'].tolist()
c_file_list = data['path'].tolist()
iclass = iter(true_class)
prediction = iter(predict_class)
def showImagesHorizontally(list_of_files):
fig = plt.figure(figsize=(20, 10), dpi=80)
number_of_files = len(list_of_files)
for i in range(number_of_files):
a=fig.add_subplot(5,9,i+1)
plt.title(f"True: {next(iclass)}\n Prediction: {next(prediction)}")
plt.tight_layout()
image = imread(list_of_files[i])
imshow(image,cmap='Greys_r')
axis('off')

showImagesHorizontally(data_list)

--

--

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store