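How to Use a Mask to Select Specific Elements of a Tensor in PyTorch?

In a model with two output heads, you may want the second loss to count only the samples whose second label has a particular value (here, 0). The loop below builds a boolean mask from `target2`, converts it to indices, and uses `torch.index_select` to keep only the matching rows of `output2` and `target2` before computing `loss_B`.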

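```python
import torch

for i, (images, target, target2) in enumerate(train_loader):
    images = images.cuda()
    target = target.cuda()
    target2 = target2.cuda()
    output, output2 = model(images)

    # select the samples whose label equals 0
    mask = target2.eq(0)
    # get the indices where the mask is True; torch.nonzero returns an int64
    # tensor on the same device, so the indices never go through a Python list
    selected_idx = torch.nonzero(mask, as_tuple=True)[0]
    # keep only the selected rows of output2 and target2
    output2 = torch.index_select(output2, 0, selected_idx)
    target2 = torch.index_select(target2, 0, selected_idx)

    loss_A = criterion_A(output, target)
    # a mean-reduced loss over zero samples is typically NaN, so use a zero
    # loss for head B when no sample in this batch has label 0
    if selected_idx.numel() > 0:
        loss_B = criterion_B(output2, target2)
    else:
        loss_B = output2.new_zeros(())
    loss = loss_A + loss_B
```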
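The selection step can be checked in isolation on the CPU. The sketch below uses made-up toy tensors, and the helper name `select_by_label` is hypothetical, not part of the original code. It builds the mask with `eq`, converts it to indices with `torch.nonzero`, and compares `index_select` against plain boolean-mask indexing, which PyTorch also supports and which returns the same rows.

```python
import torch


def select_by_label(output2, target2, label=0):
    """Hypothetical helper: keep the rows of output2/target2 whose label equals `label`."""
    mask = target2.eq(label)                     # bool tensor, shape (N,)
    idx = torch.nonzero(mask, as_tuple=True)[0]  # int64 indices of the True entries
    return (torch.index_select(output2, 0, idx),
            torch.index_select(target2, 0, idx),
            mask)


# Toy batch: 5 samples, 3 classes for the second head.
output2 = torch.arange(15, dtype=torch.float32).reshape(5, 3)
target2 = torch.tensor([0, 2, 0, 1, 0])

sel_out, sel_tgt, mask = select_by_label(output2, target2)
print(sel_out)  # rows 0, 2 and 4 of output2
print(sel_tgt)  # tensor([0, 0, 0])

# Boolean-mask indexing selects the same rows without building indices explicitly.
assert torch.equal(sel_out, output2[mask])
assert torch.equal(sel_tgt, target2[mask])

# No matching label: the result is an empty (0, 3) tensor rather than an error.
empty_out, _, _ = select_by_label(output2, target2, label=5)
print(empty_out.shape)  # torch.Size([0, 3])
```

Because `output2[mask]` gives the same result, it is a shorter option when the indices themselves are not needed.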

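A different way to get the same effect, offered here as an alternative rather than taken from the original post, is to compute the per-sample loss for the whole batch with `reduction="none"` and average it over the masked samples only. This sketch assumes `criterion_B` is a plain cross-entropy loss. The helper name `masked_ce_loss` and the toy shapes are hypothetical.

```python
import torch
import torch.nn.functional as F


def masked_ce_loss(output2, target2, label=0):
    """Hypothetical helper: mean cross-entropy over samples whose target2 == label."""
    mask = target2.eq(label)
    # Per-sample losses for the full batch, shape (N,).
    per_sample = F.cross_entropy(output2, target2, reduction="none")
    # Zero out the unselected samples and divide by the number selected;
    # clamp(min=1) avoids 0/0 when no sample in the batch matches.
    return (per_sample * mask).sum() / mask.sum().clamp(min=1)


torch.manual_seed(0)
output2 = torch.randn(5, 3, requires_grad=True)
target2 = torch.tensor([0, 2, 0, 1, 0])

loss_masked = masked_ce_loss(output2, target2)
mask = target2.eq(0)
loss_selected = F.cross_entropy(output2[mask], target2[mask])
print(torch.allclose(loss_masked, loss_selected))  # True

# Unselected rows (1 and 3) receive zero gradient.
loss_masked.backward()
```

This evaluates the loss for every sample and discards the unselected ones, trading a little extra computation for a single code path with no empty-batch special case.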



Machine Learning | Deep Learning |

Yanwei Liu
