parent 9cf37b315c
commit 7f4eccbecf
@@ -22,7 +22,10 @@ def loss_convertor(loss_func, task_name):
                 task_data_samples.append(data_sample.get(task_name))
 
         if len(task_data_samples) == 0:
-            return {'loss': torch.tensor(0.), 'mask_size': torch.tensor(0.)}
+            # This makes it possible to perform loss.backward when a
+            # task does not have gt_labels within a batch.
+            loss = (inputs[0] * 0).sum()
+            return {'loss': loss, 'mask_size': torch.tensor(0.)}
 
         # Mask the inputs of the task
         def mask_inputs(inputs, mask):
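The diff swaps a detached zero tensor for a zero loss built from the task inputs. A minimal sketch of why the connected version allows `loss.backward()` when a task has no ground-truth labels in the batch (the tensor shapes and names below are illustrative, not taken from the repository):

```python
import torch

# Illustrative stand-in for the per-task inputs (assumed shape, not repo code).
inputs = [torch.randn(4, 8, requires_grad=True)]

# A detached constant has no grad_fn, so if it were the only loss term,
# calling backward() on it would raise a RuntimeError.
detached_loss = torch.tensor(0.)
print(detached_loss.requires_grad)  # False

# Multiplying the inputs by zero keeps the loss attached to the autograd
# graph while contributing nothing numerically, so backward() runs and
# produces all-zero gradients for this task.
connected_loss = (inputs[0] * 0).sum()
connected_loss.backward()
print(inputs[0].grad.abs().sum())  # tensor(0.)
```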