How to Use a Mask to Select Specific Tensor Elements in PyTorch?

Yanwei Liu
May 4, 2022

Suppose we have an image dataset with two kinds of labels (target and target2 in the example code below). We want to train a model that learns from both labels, but only part of the target2 labels should be used for training: here, only the samples whose target2 equals 0 contribute to the second loss.
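The selection step can be tried on small CPU tensors first. The sketch below uses made-up values (the shapes and labels are hypothetical) to show what target2.eq(0), the index computation, and torch.index_select produce. Here the indices come from torch.nonzero(..., as_tuple=True) rather than a Python list comprehension, and the sketch checks that plain boolean-mask indexing (output2[mask]) selects the same rows.

```python
import torch

# Hypothetical toy batch: 5 samples, 3-class logits for the second head.
output2 = torch.arange(15, dtype=torch.float32).reshape(5, 3)
target2 = torch.tensor([0, 2, 0, 1, 0])

# Boolean mask: True where the target2 label equals 0.
mask = target2.eq(0)  # tensor([True, False, True, False, True])

# Indices of the True entries, as an int64 tensor on the same device as mask.
selected_idx = torch.nonzero(mask, as_tuple=True)[0]  # tensor([0, 2, 4])

# Keep only the selected rows along dimension 0.
output2_sel = torch.index_select(output2, 0, selected_idx)
target2_sel = torch.index_select(target2, 0, selected_idx)

# Boolean-mask indexing selects the same rows in one step.
assert torch.equal(output2_sel, output2[mask])
assert torch.equal(target2_sel, target2[mask])

print(selected_idx.tolist())  # [0, 2, 4]
print(output2_sel.shape)      # torch.Size([3, 3])
```

Boolean indexing is usually the shorter spelling; index_select with explicit indices is handy when the same indices are reused for several tensors.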

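import torch
# train_loader, model, criterion_A and criterion_B are defined elsewhere in the training script.
for i, (images, target, target2) in enumerate(train_loader):
    images = images.cuda()
    target = target.cuda()
    target2 = target2.cuda()
    output, output2 = model(images)
    # Select the samples whose target2 label equals 0.
    mask = target2.eq(0)
    # Get the indices of the True entries. nonzero() returns an int64 tensor on
    # the same device, so index_select also works when no sample is selected
    # (torch.tensor([]) would be float32 and make index_select fail).
    selected_idx = torch.nonzero(mask, as_tuple=True)[0]
    # Keep only the selected rows of output2 and target2.
    output2 = torch.index_select(output2, 0, selected_idx)
    target2 = torch.index_select(target2, 0, selected_idx)
    loss_A = criterion_A(output, target)
    loss = loss_A
    # Add loss_B only if the batch contains selected samples; a mean-reduced
    # loss over an empty batch is typically NaN.
    if selected_idx.numel() > 0:
        loss_B = criterion_B(output2, target2)
        loss = loss_A + loss_B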
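A different way to use the mask, not the one used in the loop above, is to keep the whole batch, compute one loss per sample with reduction='none', and zero out the unselected samples before averaging. This is a minimal sketch assuming criterion_B is cross-entropy (nn.CrossEntropyLoss); the logits and labels are random or made-up stand-ins. It keeps tensor shapes fixed and gives a finite loss even when no sample in the batch is selected.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for one batch: logits from the second head and labels.
output2 = torch.randn(6, 4, requires_grad=True)  # 6 samples, 4 classes
target2 = torch.tensor([0, 3, 0, 1, 2, 0])

# Assumption: criterion_B is cross-entropy; reduction="none" keeps one loss per sample.
criterion_B = nn.CrossEntropyLoss(reduction="none")

mask = target2.eq(0).float()                # 1.0 where target2 equals 0, else 0.0
per_sample = criterion_B(output2, target2)  # shape (6,)

# Average over selected samples only; clamp avoids 0/0 when nothing is selected.
loss_B = (per_sample * mask).sum() / mask.sum().clamp(min=1.0)

# Same value as selecting the rows first and using a mean-reduced loss.
selected = mask.bool()
reference = nn.CrossEntropyLoss()(output2[selected], target2[selected])
assert torch.allclose(loss_B, reference)

loss_B.backward()
# Sample 1 has target2 == 3, so it is masked out and receives zero gradient.
print(output2.grad[1])
```

Dividing by the number of selected samples (clamped to at least 1) reproduces the mean over the selected subset; with no selected samples the loss is simply 0 instead of NaN.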
