[Fix] Fix multi-task-head loss potential bug ()

* fix bug

* add comments
pull/1503/head
Ezra-Yu 2023-05-06 18:04:34 +08:00 committed by GitHub
parent 9cf37b315c
commit 7f4eccbecf
1 changed file with 4 additions and 1 deletion
mmpretrain/models/heads


@@ -22,7 +22,10 @@ def loss_convertor(loss_func, task_name):
                 task_data_samples.append(data_sample.get(task_name))
 
         if len(task_data_samples) == 0:
-            return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)}
+            # This makes it possible to perform loss.backward when a
+            # task does not have gt_labels within a batch.
+            loss = (inputs[0] * 0).sum()
+            return {'loss': loss, 'mask_size': torch.tensor(0.)}
 
         # Mask the inputs of the task
         def mask_inputs(inputs, mask):
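
For context (not part of the commit): the old code returned a fresh torch.tensor(0.), which has no grad_fn, so loss.backward() would fail whenever a batch contained no labels for a task. Multiplying the head's input by zero instead produces a zero-valued loss that is still attached to the computation graph. A minimal sketch of the difference follows; the input shape is an arbitrary assumption for illustration.

import torch

# Stand-in for the features passed to a task head; the shape is an
# arbitrary assumption for this sketch.
inputs = (torch.randn(4, 16, requires_grad=True),)

# Old behaviour: a constant tensor detached from the graph.
old_loss = torch.tensor(0.)
print(old_loss.requires_grad)       # False -> old_loss.backward() raises RuntimeError

# New behaviour: a zero loss that still carries a grad_fn.
new_loss = (inputs[0] * 0).sum()
print(new_loss.requires_grad)       # True
new_loss.backward()                 # succeeds; all gradients are zero
print(inputs[0].grad.abs().sum())   # tensor(0.)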