change the loop
parent
bc59dd06e4
commit
132aa33f16
|
@ -121,14 +121,10 @@ class MultiTaskHead(BaseModule):
|
|||
for task_name, head in self.task_heads.items():
|
||||
task_samples = None
|
||||
if data_samples is not None:
|
||||
task_samples = []
|
||||
for data_sample in data_samples:
|
||||
if data_sample is None:
|
||||
task_samples.append(None)
|
||||
elif task_name in data_sample.tasks:
|
||||
task_samples.append(data_sample.get(task_name))
|
||||
else:
|
||||
task_samples.append(None)
|
||||
task_samples = [
|
||||
data_sample.get(task_name, None) if data_sample else None
|
||||
for data_sample in data_samples
|
||||
]
|
||||
|
||||
task_samples = head.predict(feats, task_samples)
|
||||
batch_size = len(task_samples)
|
||||
|
|
Loading…
Reference in New Issue