From 7f4eccbecfaf8cddf06bcec3967a740907129c81 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
Date: Sat, 6 May 2023 18:04:34 +0800
Subject: [PATCH] [Fix] Fix multi-task-head loss potential bug (#1530)

* fix bug

* add comments
---
 mmpretrain/models/heads/multi_task_head.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mmpretrain/models/heads/multi_task_head.py b/mmpretrain/models/heads/multi_task_head.py
index ac949413..8b4645a7 100644
--- a/mmpretrain/models/heads/multi_task_head.py
+++ b/mmpretrain/models/heads/multi_task_head.py
@@ -22,7 +22,10 @@ def loss_convertor(loss_func, task_name):
             task_data_samples.append(data_sample.get(task_name))
 
         if len(task_data_samples) == 0:
-            return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)}
+            # This makes it possible to perform loss.backward when a
+            # task does not have gt_labels within a batch.
+            loss = (inputs[0] * 0).sum()
+            return {'loss': loss, 'mask_size': torch.tensor(0.)}
 
         # Mask the inputs of the task
         def mask_inputs(inputs, mask):
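
For context, a minimal standalone sketch (not part of the patch) of why the change matters: a plain torch.tensor(0.) is detached from the autograd graph, so backward() on it fails, while multiplying the inputs by zero keeps the loss connected to the graph and yields zero gradients. The tensor name `feats` is a hypothetical stand-in for the features the head receives.

import torch

# Hypothetical stand-in for the backbone features passed to the head.
feats = torch.randn(4, 8, requires_grad=True)

# Before the fix: a constant zero loss has no grad_fn, so backward() raises
# "element 0 of tensors does not require grad and does not have a grad_fn".
detached_loss = torch.tensor(0.)
print(detached_loss.requires_grad)  # False

# After the fix: the loss stays on the computation graph but contributes
# nothing to the gradients, so backward() succeeds with zero gradients.
connected_loss = (feats * 0).sum()
print(connected_loss.requires_grad)  # True
connected_loss.backward()
print(feats.grad.abs().sum())  # tensor(0.)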