How to Make Every Class Have the Same Number of Samples in Each Batch When Training in PyTorch?

Yanwei Liu
Nov 7, 2021


For the past few days I have been testing the pytorch-metric-learning library, and in the process I found that it ships with a built-in sampler called MPerClassSampler. As the name suggests, this sampler ensures that every class contributes exactly m images to each training batch, which alleviates the class-imbalance problem.

For example, suppose the DataLoader's batch size is 100 and the dataset contains 20 different classes. Then 100 / 20 = 5, which means each class contributes exactly 5 images (i.e. m images) to every training batch, balancing the classes within each batch.

The following is taken from the official documentation:

samplers.MPerClassSampler(labels, m, batch_size=None, length_before_new_iter=100000)

- labels: The list of labels for your dataset, i.e. labels[x] should be the label of the xth element in your dataset.
- m: The number of samples per class to fetch at every iteration. If a class has less than m samples, then there will be duplicates in the returned batch.
- batch_size: Optional (not used in this post). If specified, then every batch is guaranteed to have m samples per class. There are a few restrictions on this value:
  - batch_size must be a multiple of m
  - length_before_new_iter >= batch_size must be true
  - m * (number of unique labels) >= batch_size must be true
- length_before_new_iter: How many iterations will pass before a new iterable is created.
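To make the batch_size restrictions concrete, here is a minimal sketch with toy labels (the numbers are assumptions for illustration, not from this post): with m = 2 and 4 unique classes, batch_size = 8 satisfies all three rules (8 is a multiple of 2, length_before_new_iter = 32 >= 8, and 2 * 4 = 8 >= 8).

from pytorch_metric_learning import samplers

labels = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]  # 4 classes, 3 samples each
sampler = samplers.MPerClassSampler(labels, m=2, batch_size=8,
                                    length_before_new_iter=32)
# Every consecutive chunk of 8 indices now contains exactly 2 per class
print(list(iter(sampler))[:8])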
1. Usage is quite simple. First, install pytorch-metric-learning with the following command:
pip install pytorch-metric-learning 

2. The main changes to the code are as follows:

import torch
from pytorch_metric_learning import samplers

# Build the dataset
Train_set = CustomDataset(train_df, transform=transform_train)

# Build the sampler: m = batch size divided by the number of classes
per_cls_num = args.batch_size // len(Train_set.dataframe['class'].value_counts().index)
train_sampler = samplers.MPerClassSampler(Train_set.dataframe['class'], per_cls_num,
                                          batch_size=None,
                                          length_before_new_iter=len(Train_set))

# Pass the dataset and sampler to the DataLoader.
# Note: shuffle and sampler are mutually exclusive; use one or the other.
training_loader = torch.utils.data.DataLoader(Train_set, batch_size=args.batch_size,
                                              shuffle=(train_sampler is None),
                                              num_workers=args.workers, pin_memory=True,
                                              sampler=train_sampler)
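As a quick sanity check, you can count the labels in one batch. Assuming the dataset returns (image, label) pairs with integer labels, and that args.batch_size is an exact multiple of the number of classes, every class should appear per_cls_num times:

from collections import Counter

images, batch_labels = next(iter(training_loader))
print(Counter(batch_labels.tolist()))  # expect per_cls_num samples per class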

3. The rest of the code follows the usual PyTorch training flow and needs no modification.
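For completeness, this is roughly what that standard loop looks like; model, criterion, optimizer, and args.epochs are placeholders assumed here, not definitions from this post:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

for epoch in range(args.epochs):
    model.train()
    for images, targets in training_loader:  # each batch is class-balanced
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()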
