Merge pull request #2374 from HydrogenSulfate/fix_dali

add batch Tensor collate to simplify dali code in train/eval/retrieval…
pull/2381/head
HydrogenSulfate 2022-10-17 10:40:36 +08:00 committed by GitHub
commit 45742397a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 20 deletions

View File

@ -14,14 +14,12 @@
from __future__ import division
import copy
import os
import numpy as np
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import paddle
from nvidia.dali import fn
from typing import List
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.paddle import DALIGenericIterator
@ -143,6 +141,18 @@ class HybridValPipe(Pipeline):
return self.epoch_size("Reader")
class DALIImageNetIterator(DALIGenericIterator):
    def __next__(self) -> List[paddle.Tensor]:
        """Yield the next batch as a flat list of paddle Tensors.

        The parent iterator produces ``List[Dict[str, Tensor], ...]``;
        this takes the first element and reorders its values according
        to ``self.output_map`` so callers get ``[Tensor1, Tensor2, ...]``
        directly instead of indexing into a dict.
        """
        raw_batch = super(DALIImageNetIterator, self).__next__()
        first_sample = raw_batch[0]
        tensors = []
        for field in self.output_map:
            tensors.append(paddle.to_tensor(first_sample[field]))
        return tensors
def dali_dataloader(config, mode, device, num_threads=4, seed=None):
assert "gpu" in device, "gpu training is required for DALI"
device_id = int(device.split(':')[1])
@ -278,7 +288,7 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
pipe.build()
pipelines = [pipe]
# sample_per_shard = len(pipelines[0])
return DALIGenericIterator(
return DALIImageNetIterator(
pipelines, ['data', 'label'], reader_name='Reader')
else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
@ -318,5 +328,5 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
return DALIGenericIterator(
return DALIImageNetIterator(
[pipe], ['data', 'label'], reader_name="Reader")

View File

@ -47,11 +47,7 @@ def classification_eval(engine, epoch_id=0):
if iter_id == 5:
for key in time_info:
time_info[key].reset()
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0])

View File

@ -155,11 +155,7 @@ def cal_feature(engine, name='gallery'):
logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
)
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) == 3:

View File

@ -29,11 +29,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
for key in engine.time_info:
engine.time_info[key].reset()
engine.time_info["reader_cost"].update(time.time() - tic)
if engine.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch_size = batch[0].shape[0]
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([batch_size, -1])