How to Use a Mask to Select Specific Elements of a Tensor in PyTorch?

import torch

# model, criterion_A, criterion_B and train_loader are defined elsewhere.
for i, (images, target, target2) in enumerate(train_loader):
    images = images.cuda()
    target = target.cuda()
    target2 = target2.cuda()
    output, output2 = model(images)
    # select the samples whose target2 label equals 0.
    mask = target2.eq(0)
    # get the indices where the mask is True (an int64 tensor, even when empty).
    selected_idx = mask.nonzero(as_tuple=True)[0]
    # keep only the selected rows of output2 and target2.
    output2 = torch.index_select(output2, 0, selected_idx)
    target2 = torch.index_select(target2, 0, selected_idx)
    loss_A = criterion_A(output, target)
    loss_B = criterion_B(output2, target2)
    loss = loss_A + loss_B
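The training-loop snippet relies on a model, two loss functions and a data loader that are not shown. Here is a minimal, self-contained sketch that runs the same `target2.eq(0)` selection on one random batch. It assumes a hypothetical two-head network (`TwoHeadNet`), uses `nn.CrossEntropyLoss` for both criteria, and wraps the selection in a hypothetical helper, `masked_losses`. None of these names come from the original code. The helper also falls back to a zero loss for head B when no sample in the batch matches, because a mean-reduced loss over an empty selection can come out as NaN. That fallback is an assumed policy, not part of the original snippet.

```python
import torch
import torch.nn as nn


class TwoHeadNet(nn.Module):
    """Hypothetical model with two classification heads (illustration only)."""

    def __init__(self, in_features: int, classes_a: int, classes_b: int):
        super().__init__()
        self.backbone = nn.Linear(in_features, 16)
        self.head_a = nn.Linear(16, classes_a)
        self.head_b = nn.Linear(16, classes_b)

    def forward(self, x: torch.Tensor):
        h = torch.relu(self.backbone(x))
        return self.head_a(h), self.head_b(h)


def masked_losses(output, output2, target, target2, criterion_A, criterion_B):
    """Hypothetical helper: loss A on every sample, loss B only where target2 == 0."""
    loss_A = criterion_A(output, target)
    mask = target2.eq(0)                           # bool tensor, shape (N,)
    selected_idx = mask.nonzero(as_tuple=True)[0]  # int64 indices, possibly empty
    if selected_idx.numel() == 0:
        # Assumed policy: a mean over zero samples can be NaN, so add a zero
        # that stays connected to the autograd graph instead.
        return loss_A + output2.sum() * 0.0, selected_idx
    output2_sel = torch.index_select(output2, 0, selected_idx)
    target2_sel = torch.index_select(target2, 0, selected_idx)
    loss_B = criterion_B(output2_sel, target2_sel)
    return loss_A + loss_B, selected_idx


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

model = TwoHeadNet(in_features=8, classes_a=3, classes_b=3).to(device)
criterion_A = nn.CrossEntropyLoss()
criterion_B = nn.CrossEntropyLoss()

# Random stand-in batch of 6 samples; target2 has zeros at positions 0, 2 and 5.
images = torch.randn(6, 8, device=device)
target = torch.randint(0, 3, (6,), device=device)
target2 = torch.tensor([0, 1, 0, 2, 1, 0], device=device)

output, output2 = model(images)
loss, selected_idx = masked_losses(output, output2, target, target2,
                                   criterion_A, criterion_B)
loss.backward()

# Boolean-mask indexing selects the same rows as nonzero + index_select.
mask = target2.eq(0)
assert torch.equal(output2[mask], torch.index_select(output2, 0, selected_idx))
print(selected_idx.tolist())  # [0, 2, 5]
```

Boolean indexing (`output2[mask]`) is usually the shorter way to write the selection. The `nonzero` plus `index_select` route is handy when you also want the indices themselves, for example to log which samples contributed to `loss_B`.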
