How to Use a Mask to Select Specific Rows of a Tensor in PyTorch

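```python
import torch

for i, (images, target, target2) in enumerate(train_loader):
    images = images.cuda()
    target = target.cuda()
    target2 = target2.cuda()
    output, output2 = model(images)
    # select the samples whose target2 label equals 0
    mask = target2.eq(0)
    # get the indices where the mask is True (a LongTensor on the same device,
    # and still valid when no element matches)
    selected_idx = torch.nonzero(mask, as_tuple=True)[0]
    # keep only the rows of output2 and target2 at those indices
    output2 = torch.index_select(output2, 0, selected_idx)
    target2 = torch.index_select(target2, 0, selected_idx)
    loss_A = criterion_A(output, target)
    # with the default "mean" reduction, a loss over an empty selection is
    # typically NaN, so skip loss_B when no sample in the batch has label 0
    if selected_idx.numel() > 0:
        loss_B = criterion_B(output2, target2)
        loss = loss_A + loss_B
    else:
        loss = loss_A
```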
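You can check the selection step without a GPU, model, or data loader. The following is a minimal CPU-only sketch. The helper `select_by_label` and the tensors `output2` and `target2` are hypothetical stand-ins (a batch of 5 samples with 3 logits each), not the article's real data. The sketch shows that `torch.nonzero` plus `torch.index_select` keeps the same rows as plain boolean-mask indexing (`output2[mask]`). It also shows that an all-False mask yields an empty selection rather than an error.

```python
import torch


def select_by_label(output2: torch.Tensor, target2: torch.Tensor, label: int = 0):
    """Hypothetical helper: keep rows of output2/target2 where target2 == label."""
    mask = target2.eq(label)                              # BoolTensor, shape [N]
    selected_idx = torch.nonzero(mask, as_tuple=True)[0]  # LongTensor of row indices
    return (
        torch.index_select(output2, 0, selected_idx),
        torch.index_select(target2, 0, selected_idx),
    )


# Made-up batch: 5 samples, 3 classes.
output2 = torch.arange(15, dtype=torch.float32).reshape(5, 3)
target2 = torch.tensor([0, 2, 0, 1, 0])

sel_out, sel_tgt = select_by_label(output2, target2)
print(sel_tgt)        # tensor([0, 0, 0])
print(sel_out.shape)  # torch.Size([3, 3])

# Boolean-mask indexing selects the same rows in a single step.
mask = target2.eq(0)
print(torch.equal(sel_out, output2[mask]))  # True

# No sample has label 7: the selection is empty (0 rows), not an error.
empty_out, empty_tgt = select_by_label(output2, target2, label=7)
print(empty_out.shape)  # torch.Size([0, 3])
```

Boolean indexing is usually the shorter choice when you only need the selected rows. An explicit index tensor is convenient when the same indices must be applied to several tensors.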

Machine Learning | Deep Learning | https://linktr.ee/yanwei

Yanwei Liu
