From a14df4ac527d7d25fcf938cff0ce0f5d071bd5fe Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 20 Oct 2022 12:03:03 +0000 Subject: [PATCH] fix tensor conversion in static mode with dali loader --- ppcls/data/dataloader/dali.py | 9 +++++++-- ppcls/static/program.py | 14 +++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ppcls/data/dataloader/dali.py b/ppcls/data/dataloader/dali.py index 5d8255504..cb8f2e1a1 100644 --- a/ppcls/data/dataloader/dali.py +++ b/ppcls/data/dataloader/dali.py @@ -142,13 +142,18 @@ class HybridValPipe(Pipeline): class DALIImageNetIterator(DALIGenericIterator): + def __init__(self, *kargs, **kwargs): + super(DALIImageNetIterator, self).__init__(*kargs, **kwargs) + self.in_dynamic_mode = paddle.in_dynamic_mode() + def __next__(self) -> List[paddle.Tensor]: data_batch = super(DALIImageNetIterator, self).__next__() # List[Dict[str, Tensor], ...] - # reformat to List[Tensor1, Tensor2, ...] data_batch = [ - paddle.to_tensor(data_batch[0][key]) for key in self.output_map + paddle.to_tensor(data_batch[0][key]) + if self.in_dynamic_mode else data_batch[0][key] + for key in self.output_map ] return data_batch diff --git a/ppcls/static/program.py b/ppcls/static/program.py index a6a80f13e..5c28af0a7 100644 --- a/ppcls/static/program.py +++ b/ppcls/static/program.py @@ -386,15 +386,11 @@ def run(dataloader, profiler.add_profiler_step(profiler_options) - if use_dali: - batch_size = batch[0]["data"].shape()[0] - feed_dict = batch[0] - else: - batch_size = batch[0].shape()[0] - feed_dict = { - key.name: batch[idx] - for idx, key in enumerate(feeds.values()) - } + batch_size = batch[0].shape()[0] + feed_dict = { + key.name: batch[idx] + for idx, key in enumerate(feeds.values()) + } metrics = exe.run(program=program, feed=feed_dict,