[Fix] Fix nested predict for multi-task prediction. (#1716)
* fix: multi task predict * change the loop --------- Co-authored-by: Pierre Colle <piercus@gmail.com>pull/1670/head
parent
64c446d507
commit
e7fc25cf64
|
@ -119,12 +119,24 @@ class MultiTaskHead(BaseModule):
|
|||
predictions_dict = dict()
|
||||
|
||||
for task_name, head in self.task_heads.items():
|
||||
task_samples = head.predict(feats)
|
||||
task_samples = None
|
||||
if data_samples is not None:
|
||||
task_samples = [
|
||||
data_sample.get(task_name, None) if data_sample else None
|
||||
for data_sample in data_samples
|
||||
]
|
||||
|
||||
task_samples = head.predict(feats, task_samples)
|
||||
batch_size = len(task_samples)
|
||||
predictions_dict[task_name] = task_samples
|
||||
|
||||
if data_samples is None:
|
||||
data_samples = [MultiTaskDataSample() for _ in range(batch_size)]
|
||||
else:
|
||||
data_samples = [
|
||||
MultiTaskDataSample() if data_sample is None else data_sample
|
||||
for data_sample in data_samples
|
||||
]
|
||||
|
||||
for task_name, task_samples in predictions_dict.items():
|
||||
for data_sample, task_sample in zip(data_samples, task_samples):
|
||||
|
|
|
@ -549,7 +549,7 @@ class TestMultiTaskHead(TestCase):
|
|||
data_sample.set_field(task_sample, task_name)
|
||||
data_samples.append(data_sample)
|
||||
head = MODELS.build(self.DEFAULT_ARGS)
|
||||
# with without data_samples
|
||||
# without data_samples
|
||||
predictions = head.predict(feats)
|
||||
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
|
||||
for pred in predictions:
|
||||
|
@ -564,6 +564,32 @@ class TestMultiTaskHead(TestCase):
|
|||
self.assertIs(sample, pred)
|
||||
self.assertIn('task0', pred)
|
||||
|
||||
# with data samples and nested
|
||||
head_nested = MODELS.build(self.DEFAULT_ARGS2)
|
||||
# adding a None data sample at the beginning
|
||||
data_samples_nested = [None]
|
||||
for _ in range(3):
|
||||
data_sample_nested = MultiTaskDataSample()
|
||||
data_sample_nested0 = MultiTaskDataSample()
|
||||
data_sample_nested0.set_field(DataSample().set_gt_label(1),
|
||||
'task00')
|
||||
data_sample_nested0.set_field(DataSample().set_gt_label(1),
|
||||
'task01')
|
||||
data_sample_nested.set_field(data_sample_nested0, 'task0')
|
||||
data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
|
||||
data_samples_nested.append(data_sample_nested)
|
||||
|
||||
predictions = head_nested.predict(feats, data_samples_nested)
|
||||
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
|
||||
for i in range(3):
|
||||
sample = data_samples_nested[i + 1]
|
||||
pred = predictions[i + 1]
|
||||
self.assertIn('task0', pred)
|
||||
self.assertIn('task1', pred)
|
||||
self.assertIn('task01', pred.get('task0'))
|
||||
self.assertEqual(
|
||||
sample.get('task0').get('task01').gt_label.numpy()[0], 1)
|
||||
|
||||
def test_loss_empty_data_sample(self):
|
||||
feats = (torch.rand(4, 10), )
|
||||
data_samples = []
|
||||
|
|
Loading…
Reference in New Issue