From e7fc25cf64156170f1b5108fdc09a975e96da9b0 Mon Sep 17 00:00:00 2001 From: marouane amzil <53240092+marouaneamz@users.noreply.github.com> Date: Fri, 28 Jul 2023 07:44:12 +0200 Subject: [PATCH] [Fix] Fix nested predict for multi-task prediction. (#1716) * fix: multi task predict * change the loop --------- Co-authored-by: Pierre Colle --- mmpretrain/models/heads/multi_task_head.py | 14 ++++++++++- tests/test_models/test_heads.py | 28 +++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/mmpretrain/models/heads/multi_task_head.py b/mmpretrain/models/heads/multi_task_head.py index 8b4645a7..3515a2b1 100644 --- a/mmpretrain/models/heads/multi_task_head.py +++ b/mmpretrain/models/heads/multi_task_head.py @@ -119,12 +119,24 @@ class MultiTaskHead(BaseModule): predictions_dict = dict() for task_name, head in self.task_heads.items(): - task_samples = head.predict(feats) + task_samples = None + if data_samples is not None: + task_samples = [ + data_sample.get(task_name, None) if data_sample else None + for data_sample in data_samples + ] + + task_samples = head.predict(feats, task_samples) batch_size = len(task_samples) predictions_dict[task_name] = task_samples if data_samples is None: data_samples = [MultiTaskDataSample() for _ in range(batch_size)] + else: + data_samples = [ + MultiTaskDataSample() if data_sample is None else data_sample + for data_sample in data_samples + ] for task_name, task_samples in predictions_dict.items(): for data_sample, task_sample in zip(data_samples, task_samples): diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index fcaa8f67..a4ddf495 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -549,7 +549,7 @@ class TestMultiTaskHead(TestCase): data_sample.set_field(task_sample, task_name) data_samples.append(data_sample) head = MODELS.build(self.DEFAULT_ARGS) - # with without data_samples + # without data_samples predictions = head.predict(feats) self.assertTrue(is_seq_of(predictions, MultiTaskDataSample)) for pred in predictions: @@ -564,6 +564,32 @@ class TestMultiTaskHead(TestCase): self.assertIs(sample, pred) self.assertIn('task0', pred) + # with data samples and nested + head_nested = MODELS.build(self.DEFAULT_ARGS2) + # adding a None data sample at the beginning + data_samples_nested = [None] + for _ in range(3): + data_sample_nested = MultiTaskDataSample() + data_sample_nested0 = MultiTaskDataSample() + data_sample_nested0.set_field(DataSample().set_gt_label(1), + 'task00') + data_sample_nested0.set_field(DataSample().set_gt_label(1), + 'task01') + data_sample_nested.set_field(data_sample_nested0, 'task0') + data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1') + data_samples_nested.append(data_sample_nested) + + predictions = head_nested.predict(feats, data_samples_nested) + self.assertTrue(is_seq_of(predictions, MultiTaskDataSample)) + for i in range(3): + sample = data_samples_nested[i + 1] + pred = predictions[i + 1] + self.assertIn('task0', pred) + self.assertIn('task1', pred) + self.assertIn('task01', pred.get('task0')) + self.assertEqual( + sample.get('task0').get('task01').gt_label.numpy()[0], 1) + def test_loss_empty_data_sample(self): feats = (torch.rand(4, 10), ) data_samples = []