pull/1554/merge
Colle 2023-07-24 16:08:49 +08:00 committed by GitHub
commit 5cedc5527a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 44 additions and 2 deletions

View File

@ -119,12 +119,28 @@ class MultiTaskHead(BaseModule):
predictions_dict = dict()
for task_name, head in self.task_heads.items():
task_samples = head.predict(feats)
task_samples = None
if data_samples is not None:
task_samples = []
for data_sample in data_samples:
if data_sample is None:
task_samples.append(None)
elif task_name in data_sample.tasks:
task_samples.append(data_sample.get(task_name))
else:
task_samples.append(None)
task_samples = head.predict(feats, task_samples)
batch_size = len(task_samples)
predictions_dict[task_name] = task_samples
if data_samples is None:
data_samples = [MultiTaskDataSample() for _ in range(batch_size)]
else:
data_samples = [
MultiTaskDataSample() if data_sample is None else data_sample
for data_sample in data_samples
]
for task_name, task_samples in predictions_dict.items():
for data_sample, task_sample in zip(data_samples, task_samples):

View File

@ -549,7 +549,7 @@ class TestMultiTaskHead(TestCase):
data_sample.set_field(task_sample, task_name)
data_samples.append(data_sample)
head = MODELS.build(self.DEFAULT_ARGS)
# with without data_samples
# without data_samples
predictions = head.predict(feats)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for pred in predictions:
@ -564,6 +564,32 @@ class TestMultiTaskHead(TestCase):
self.assertIs(sample, pred)
self.assertIn('task0', pred)
# with data samples and nested
head_nested = MODELS.build(self.DEFAULT_ARGS2)
# adding a None data sample at the beginning
data_samples_nested = [None]
for _ in range(3):
data_sample_nested = MultiTaskDataSample()
data_sample_nested0 = MultiTaskDataSample()
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task00')
data_sample_nested0.set_field(DataSample().set_gt_label(1),
'task01')
data_sample_nested.set_field(data_sample_nested0, 'task0')
data_sample_nested.set_field(DataSample().set_gt_label(1), 'task1')
data_samples_nested.append(data_sample_nested)
predictions = head_nested.predict(feats, data_samples_nested)
self.assertTrue(is_seq_of(predictions, MultiTaskDataSample))
for i in range(3):
sample = data_samples_nested[i + 1]
pred = predictions[i + 1]
self.assertIn('task0', pred)
self.assertIn('task1', pred)
self.assertIn('task01', pred.get('task0'))
self.assertEqual(
sample.get('task0').get('task01').gt_label.numpy()[0], 1)
def test_loss_empty_data_sample(self):
feats = (torch.rand(4, 10), )
data_samples = []