How do you make every class contribute the same number of samples to each batch when training with PyTorch?

The following description comes from the official documentation:

samplers.MPerClassSampler(labels, m, batch_size=None, length_before_new_iter=100000)

- labels: The list of labels for your dataset, i.e. labels[x] should be the label of the x-th element in your dataset.
- m: The number of samples per class to fetch at every iteration. If a class has fewer than m samples, there will be duplicates in the returned batch.
- batch_size: Optional (not used in this article). If specified, every batch is guaranteed to have m samples per class. There are a few restrictions on this value, sketched below:
  - batch_size must be a multiple of m
  - length_before_new_iter >= batch_size must be true
  - m * (number of unique labels) >= batch_size must be true
- length_before_new_iter: How many iterations will pass before a new iterable is created.
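For reference, here is a minimal sketch of how those batch_size restrictions play out. The toy label list and the values m=4 and batch_size=32 are assumptions for illustration only, not part of the original example:

from pytorch_metric_learning import samplers

# Toy data: 10 classes with 50 samples each (assumed values for illustration).
labels = [i for i in range(10) for _ in range(50)]
m, batch_size, length_before_new_iter = 4, 32, 100000

# The three restrictions on batch_size from the documentation above:
assert batch_size % m == 0                     # batch_size must be a multiple of m
assert length_before_new_iter >= batch_size    # length_before_new_iter >= batch_size
assert m * len(set(labels)) >= batch_size      # m * (number of unique labels) >= batch_size

sampler = samplers.MPerClassSampler(labels, m=m, batch_size=batch_size,
                                    length_before_new_iter=length_before_new_iter)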
  1. Using it is quite simple. First, install pytorch-metric-learning with the following command:
pip install pytorch-metric-learning 
import torch
from pytorch_metric_learning import samplers

# Build the dataset
Train_set = CustomDataset(train_df, transform=transform_train)

# Build the sampler: m = number of samples per class in each batch
per_cls_num = args.batch_size // len(Train_set.dataframe['class'].value_counts().index)
train_sampler = samplers.MPerClassSampler(Train_set.dataframe['class'], per_cls_num, batch_size=None, length_before_new_iter=len(Train_set))

# Pass the dataset and sampler to the DataLoader
# Note: shuffle and sampler cannot be used at the same time; pick one of them
training_loader = torch.utils.data.DataLoader(Train_set, batch_size=args.batch_size, shuffle=(train_sampler is None),
                                              num_workers=args.workers, pin_memory=True, sampler=train_sampler)
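As a quick sanity check, you can pull one batch and count how many samples each class contributes; every class should appear per_cls_num times. This assumes CustomDataset returns (image, label) pairs and that the labels collate into a tensor, which is not shown in the original snippet:

from collections import Counter

# Grab a single batch and count samples per class.
images, targets = next(iter(training_loader))
print(Counter(targets.tolist()))  # each class should appear per_cls_num times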
