如何修改PyTorch的Prediction結果?

假設今天我們的模型在一個擁有6種不同Label的資料集上進行訓練,模型用來預測的時候,會輸出0, 1, 2, 3, 4, 5這6種不同的數字。

然而,如果我們改變Testing方式,打算將第1, 2, 3, 4 ,5這5個數字,都視為同一個Label,也就是只剩下兩種Label的情況(0和1)

我們可以怎麼做呢?
在PyTorch程式碼進行Forward的過程中,會有類似以下的寫法。我們的方法就是直接將輸出大於1的結果都轉換成1,這樣就能得到只有2種Label的輸出結果了。

output = model(images)             #模型輸出得到Output
_, preds = torch.max(output, 1) #轉換成實際的label(preds)
# 將preds搭配lambda function,只要preds中的預測結果有大於1的數值,則全轉換成1
preds = torch.tensor([(lambda i: 1 if i > 1 else i)(i) for i in preds]).cuda()
target = torch.tensor([(lambda i: 1 if i > 1 else i)(i) for i in target]).cuda()# 如此一來,我們的模型就能以2個label的方式進行輸出