Merge 3f25e6b59e
into 64c446d507
commit
5cedc5527a
|
@ -119,12 +119,28 @@ class MultiTaskHead(BaseModule):
|
|||
predictions_dict = dict()
|
||||
|
||||
for task_name, head in self.task_heads.items():
|
||||
task_samples = head.predict(feats)
|
||||
task_samples = None
|
||||
if data_samples is not None:
|
||||
task_samples = []
|
||||
for data_sample in data_samples:
|
||||
if data_sample is None:
|
||||
task_samples.append(None)
|
||||
elif task_name in data_sample.tasks:
|
||||
task_samples.append(data_sample.get(task_name))
|
||||
else:
|
||||
task_samples.append(None)
|
||||
|
||||
task_samples = head.predict(feats, task_samples)
|
||||
batch_size = len(task_samples)
|
||||
predictions_dict[task_name] = task_samples
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [MultiTaskDataSample() for _ in range(batch_size)]
|
||||
else:
|
||||
data_samples = [
|
||||
MultiTaskDataSample() if data_sample is None else data_sample
|
||||
for data_sample in data_samples
|
||||
]
|
||||
|
||||
for task_name, task_samples in predictions_dict.items():
|
||||
for data_sample, task_sample in zip(data_samples, task_samples):
|
||||
|
|
|
@ -549,7 +549,7 @@ class TestMultiTaskHead(TestCase):
|
|||
data_sample.set_field(task_sample, task_name)
|
||||
data_samples.append(data_sample)
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
# with without data_samples
|
||||
# without data_samples
|
||||
predictions = head.predict(feats)
|
||||
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
|
||||
for pred in predictions:
|
||||
|
@ -564,6 +564,32 @@ class TestMultiTaskHead(TestCase):
|
|||
self.assertIs(sample, pred)
|
||||
self.assertIn('task0', pred)
|
||||
|
||||
# with data samples and nested
|
||||
head_nested = MODELS.build(self.DEFAULT_ARGS2)
|
||||
# adding a None data sample at the beginning
|
||||
data_samples_nested = [None]
|
||||
for _ in range(3):
|
||||
data_sample_nested = MultiTaskDataSample()
|
||||
data_sample_nested0 = MultiTaskDataSample()
|
||||
data_sample_nested0.set_field(DataSample().set_gt_label(1),
|
||||
'task00')
|
||||
data_sample_nested0.set_field(DataSample().set_gt_label(1),
|
||||
'task01')
|
||||
data_sample_nested.set_field(data_sample_nested0, 'task0')
|
||||
data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
|
||||
data_samples_nested.append(data_sample_nested)
|
||||
|
||||
predictions = head_nested.predict(feats, data_samples_nested)
|
||||
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
|
||||
for i in range(3):
|
||||
sample = data_samples_nested[i + 1]
|
||||
pred = predictions[i + 1]
|
||||
self.assertIn('task0', pred)
|
||||
self.assertIn('task1', pred)
|
||||
self.assertIn('task01', pred.get('task0'))
|
||||
self.assertEqual(
|
||||
sample.get('task0').get('task01').gt_label.numpy()[0], 1)
|
||||
|
||||
def test_loss_empty_data_sample(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = []
|
||||
|
|
Loading…
Reference in New Issue