PyTorch訓練時同時使用兩個不同的Dataloader

from itertools import cyclefor i, data in enumerate(zip(cycle(dataloaders1), dataloaders2)):
img_0, label_0 = data[0]
img_1, label_1 = data[1]
.
.
.
.
.
.
# 假設dataloaders1只有100張圖片;而dataloaders2可能有1000張圖片。
# 透過cycle(dataloaders1)的方式可以讓dataloaders1中的這100張圖片不斷地被循環抽出,直到dataloaders2的圖片被完整使用過一次,才會完成一整個epoch的訓練。
# 特別注意的是,因為batch size設定的原因,在訓練的最後一個mini-batch中,有可能會出現dataloaders1的數量(100張)比dataloaders2的圖片數量多的情形(假設dataloaders2只剩下90張)。在計算loss的時候,可能會出現shape不同的錯誤,因此要給dataloaders2設定drop_last=True的參數,才會避免這種情況發生。
for i, data in enumerate(zip(dataloaders1, dataloaders2)):
img_0, label_0 = data[0]
img_1, label_1 = data[1]
.
.
.
.
.
.
# 訓練過程中,將會受限於dataloaders1的100張圖片,dataloaders2的1000張圖片將不會被完整使用,每次利用完dataloaders1的圖片後,就會切換到下一個epoch

--

--

--

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Yanwei Liu

Yanwei Liu

Machine Learning | Deep Learning | https://linktr.ee/yanwei

More from Medium

Invasive Species Monitoring: Using a Convolutional Neural Network to identify hydrangeas

Introduction to Neural Networks

SIIM FISABIO RSNA COVID-19 DETECTION

Parameters in Convolutional Network