如何在訓練PyTorch時的每個Batch中使各個類別擁有相同數量的樣本?

最近這幾天在測試pytorch-metric-learning這個library,使用的過程中發現內建了一個名為MPerClassSampler。顧名思義,這個Sampler的用途就是讓每個class在訓練時都剛好擁有M張圖片,使得類別不平衡的問題得以被解決。

舉例來說,假設dataloader的batch size是100,且dataset包含了20種不同的class,用100/20會得到5。這個5就是每個class在單一batch訓練中,都剛好包含了5張圖片(即M張圖片),藉此達到平衡類別資料的功能。

本段落內容來自官方文檔案說明
  1. 使用上相當簡單,首先透過以下指令安裝pytorch-metric-learning
pip install pytorch-metric-learning 

2. 程式碼主要修改的部分如下:

# 建立dataset
Train_set = CustomDataset(train_df, transform = transform_train)

3. 其餘的程式碼,按照一般PyTorch訓練流程運作即可,不需要修改

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Machine Learning | Deep Learning | https://linktr.ee/yanwei