diff --git a/mmpretrain/models/heads/multi_task_head.py b/mmpretrain/models/heads/multi_task_head.py index 9cc71d35..3515a2b1 100644 --- a/mmpretrain/models/heads/multi_task_head.py +++ b/mmpretrain/models/heads/multi_task_head.py @@ -121,14 +121,10 @@ class MultiTaskHead(BaseModule): for task_name, head in self.task_heads.items(): task_samples = None if data_samples is not None: - task_samples = [] - for data_sample in data_samples: - if data_sample is None: - task_samples.append(None) - elif task_name in data_sample.tasks: - task_samples.append(data_sample.get(task_name)) - else: - task_samples.append(None) + task_samples = [ + data_sample.get(task_name, None) if data_sample else None + for data_sample in data_samples + ] task_samples = head.predict(feats, task_samples) batch_size = len(task_samples)