[Fix] Fix nested predict for multi-task prediction. (#1716)

* fix: multi task predict

* change the loop

---------

Co-authored-by: Pierre Colle <piercus@gmail.com>
pull/1670/head
marouane amzil 2023-07-28 07:44:12 +02:00 committed by GitHub
parent 64c446d507
commit e7fc25cf64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 2 deletions

View File

@ -119,12 +119,24 @@ class MultiTaskHead(BaseModule):
predictions_dict = dict()
for task_name, head in self.task_heads.items():
task_samples = head.predict(feats)
task_samples = None
if data_samples is not None:
task_samples = [
data_sample.get(task_name, None) if data_sample else None
for data_sample in data_samples
]
task_samples = head.predict(feats, task_samples)
batch_size = len(task_samples)
predictions_dict[task_name] = task_samples
if data_samples is None:
data_samples = [MultiTaskDataSample() for _ in range(batch_size)]
else:
data_samples = [
MultiTaskDataSample() if data_sample is None else data_sample
for data_sample in data_samples
]
for task_name, task_samples in predictions_dict.items():
for data_sample, task_sample in zip(data_samples, task_samples):

View File

@ -549,7 +549,7 @@ class TestMultiTaskHead(TestCase):
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
head = MODELS.build(self.DEFAULT_ARGS)
# with without data_samples
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for pred in predictions:
@ -564,6 +564,32 @@ class TestMultiTaskHead(TestCase):
self.assertIs(sample, pred)
self.assertIn('task0', pred)
# with data samples and nested
head_nested = MODELS.build(self.DEFAULT_ARGS2)
# adding a None data sample at the beginning
data_samples_nested = [None]
for _ in range(3):
data_sample_nested = MultiTaskDataSample()
data_sample_nested0 = MultiTaskDataSample()
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task00')
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task01')
data_sample_nested.set_field(data_sample_nested0, 'task0')
data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
data_samples_nested.append(data_sample_nested)
predictions = head_nested.predict(feats, data_samples_nested)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for i in range(3):
sample = data_samples_nested[i + 1]
pred = predictions[i + 1]
self.assertIn('task0', pred)
self.assertIn('task1', pred)
self.assertIn('task01', pred.get('task0'))
self.assertEqual(
sample.get('task0').get('task01').gt_label.numpy()[0], 1)
def test_loss_empty_data_sample(self):
feats = (torch.rand(4, 10), )
data_samples = []