PyTorch如何在訓練過程中對多種不同損失函數的損失值進行反向傳播?
假設output為模型的輸出;target為ground truth label;而criterion_a、criterion_b、criterion_c分別為三種不同的損失函數,其形式可能是其他寫法,但本文為求簡潔,故簡略寫成以下之形式:
loss_a = criterion_a(output, target)
loss_b = criterion_b(output, target)
loss_c = criterion_c(output, target)
此時可有兩種作法:
1. 相加後直接進行反向傳播
通常這一種作法不太會有錯誤產生,接下來介紹的第二種可能就會有錯誤出現。
loss = loss_a + loss_b + loss_c
loss.backward()
2.各別進行反向傳播
若按照下面的寫法可能會出現RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.的錯誤
loss_a.backward()
loss_b.backward()
loss_c.backward()
可以將第一個進行backward的loss添加retain_graph=True,如下所示:
loss_a.backward(retain_graph=True)
loss_b.backward(retain_graph=True) # 2022/04/24修正
loss_c.backward()
參考資料
pytorch程序中错误的集锦 — 知乎 (zhihu.com)
RuntimeError: Trying to backward through the graph a second time…_Huiyu Blog-CSDN博客