parent
9cf37b315c
commit
7f4eccbecf
mmpretrain/models/heads
|
@ -22,7 +22,10 @@ def loss_convertor(loss_func, task_name):
|
|||
task_data_samples.append(data_sample.get(task_name))
|
||||
|
||||
if len(task_data_samples) == 0:
|
||||
return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)}
|
||||
# This makes it possible to perform loss.backward when a
|
||||
# task does not have gt_labels within a batch.
|
||||
loss = (inputs[0] * 0).sum()
|
||||
return {'loss': loss, 'mask_size': torch.tensor(0.)}
|
||||
|
||||
# Mask the inputs of the task
|
||||
def mask_inputs(inputs, mask):
|
||||
|
|
Loading…
Reference in New Issue