change the loop

pull/1716/head
Marouane Amzil (Student at CentraleSupelec) 2023-07-25 22:08:13 +02:00
parent bc59dd06e4
commit 132aa33f16
1 changed files with 4 additions and 8 deletions

View File

@ -121,14 +121,10 @@ class MultiTaskHead(BaseModule):
for task_name, head in self.task_heads.items():
task_samples = None
if data_samples is not None:
task_samples = []
for data_sample in data_samples:
if data_sample is None:
task_samples.append(None)
elif task_name in data_sample.tasks:
task_samples.append(data_sample.get(task_name))
else:
task_samples.append(None)
task_samples = [
data_sample.get(task_name, None) if data_sample else None
for data_sample in data_samples
]
task_samples = head.predict(feats, task_samples)
batch_size = len(task_samples)