PyTorch訓練時同時使用兩個不同的Dataloader

Yanwei Liu
Dec 17, 2021

假設模型在訓練過程中,需要同時使用不同的dataset進行訓練;則可以透過以下方式進行:

from itertools import cyclefor i, data in enumerate(zip(cycle(dataloaders1), dataloaders2)):
img_0, label_0 = data[0]
img_1, label_1 = data[1]
.
.
.
.
.
.
# 假設dataloaders1只有100張圖片;而dataloaders2可能有1000張圖片。
# 透過cycle(dataloaders1)的方式可以讓dataloaders1中的這100張圖片不斷地被循環抽出,直到dataloaders2的圖片被完整使用過一次,才會完成一整個epoch的訓練。
# 特別注意的是,因為batch size設定的原因,在訓練的最後一個mini-batch中,有可能會出現dataloaders1的數量(100張)比dataloaders2的圖片數量多的情形(假設dataloaders2只剩下90張)。在計算loss的時候,可能會出現shape不同的錯誤,因此要給dataloaders2設定drop_last=True的參數,才會避免這種情況發生。# 20220617更新:若圖片資料集較多的時候(十萬張圖的數量),使用cycle可能會造成dataLoader worker (pid 61577) is killed by signal: Killed.的錯誤產生,可將cycle去除,避免錯誤。

若是用以下的方式訓練:

for i, data in enumerate(zip(dataloaders1, dataloaders2)):
img_0, label_0 = data[0]
img_1, label_1 = data[1]
.
.
.
.
.
.
# 訓練過程中,將會受限於dataloaders1的100張圖片,dataloaders2的1000張圖片將不會被完整使用,每次利用完dataloaders1的圖片後,就會切換到下一個epoch

參考資料

--

--