Merge pull request #2401 from HydrogenSulfate/fix_dali_static
Fix tensor conversion in static mode with dali loaderpull/2405/head
commit
184b684fd8
|
@ -142,13 +142,18 @@ class HybridValPipe(Pipeline):
|
|||
|
||||
|
||||
class DALIImageNetIterator(DALIGenericIterator):
|
||||
def __init__(self, *kargs, **kwargs):
|
||||
super(DALIImageNetIterator, self).__init__(*kargs, **kwargs)
|
||||
self.in_dynamic_mode = paddle.in_dynamic_mode()
|
||||
|
||||
def __next__(self) -> List[paddle.Tensor]:
|
||||
data_batch = super(DALIImageNetIterator,
|
||||
self).__next__() # List[Dict[str, Tensor], ...]
|
||||
|
||||
# reformat to List[Tensor1, Tensor2, ...]
|
||||
data_batch = [
|
||||
paddle.to_tensor(data_batch[0][key]) for key in self.output_map
|
||||
paddle.to_tensor(data_batch[0][key])
|
||||
if self.in_dynamic_mode else data_batch[0][key]
|
||||
for key in self.output_map
|
||||
]
|
||||
return data_batch
|
||||
|
||||
|
|
|
@ -386,15 +386,11 @@ def run(dataloader,
|
|||
|
||||
profiler.add_profiler_step(profiler_options)
|
||||
|
||||
if use_dali:
|
||||
batch_size = batch[0]["data"].shape()[0]
|
||||
feed_dict = batch[0]
|
||||
else:
|
||||
batch_size = batch[0].shape()[0]
|
||||
feed_dict = {
|
||||
key.name: batch[idx]
|
||||
for idx, key in enumerate(feeds.values())
|
||||
}
|
||||
batch_size = batch[0].shape()[0]
|
||||
feed_dict = {
|
||||
key.name: batch[idx]
|
||||
for idx, key in enumerate(feeds.values())
|
||||
}
|
||||
|
||||
metrics = exe.run(program=program,
|
||||
feed=feed_dict,
|
||||
|
|
Loading…
Reference in New Issue