Merge pull request #2401 from HydrogenSulfate/fix_dali_static

Fix tensor conversion in static mode with dali loader
pull/2405/head
HydrogenSulfate 2022-10-21 12:08:26 +08:00 committed by GitHub
commit 184b684fd8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 11 deletions

View File

@ -142,13 +142,18 @@ class HybridValPipe(Pipeline):
class DALIImageNetIterator(DALIGenericIterator):
def __init__(self, *kargs, **kwargs):
super(DALIImageNetIterator, self).__init__(*kargs, **kwargs)
self.in_dynamic_mode = paddle.in_dynamic_mode()
def __next__(self) -> List[paddle.Tensor]:
data_batch = super(DALIImageNetIterator,
self).__next__() # List[Dict[str, Tensor], ...]
# reformat to List[Tensor1, Tensor2, ...]
data_batch = [
paddle.to_tensor(data_batch[0][key]) for key in self.output_map
paddle.to_tensor(data_batch[0][key])
if self.in_dynamic_mode else data_batch[0][key]
for key in self.output_map
]
return data_batch

View File

@ -386,15 +386,11 @@ def run(dataloader,
profiler.add_profiler_step(profiler_options)
if use_dali:
batch_size = batch[0]["data"].shape()[0]
feed_dict = batch[0]
else:
batch_size = batch[0].shape()[0]
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
batch_size = batch[0].shape()[0]
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
metrics = exe.run(program=program,
feed=feed_dict,