Merge pull request #2374 from HydrogenSulfate/fix_dali
add batch Tensor collate to simplify dali code in train/eval/retrival…

pull/2381/head

commit 45742397a3
@@ -14,14 +14,12 @@

 from __future__ import division

 import copy
 import os

 import numpy as np
 import nvidia.dali.ops as ops
 import nvidia.dali.types as types
 import paddle
 from nvidia.dali import fn
+from typing import List
 from nvidia.dali.pipeline import Pipeline
 from nvidia.dali.plugin.paddle import DALIGenericIterator
@@ -143,6 +141,18 @@ class HybridValPipe(Pipeline):
         return self.epoch_size("Reader")


+class DALIImageNetIterator(DALIGenericIterator):
+    def __next__(self) -> List[paddle.Tensor]:
+        data_batch = super(DALIImageNetIterator,
+                           self).__next__()  # List[Dict[str, Tensor], ...]
+
+        # reformat to List[Tensor1, Tensor2, ...]
+        data_batch = [
+            paddle.to_tensor(data_batch[0][key]) for key in self.output_map
+        ]
+        return data_batch
+
+
 def dali_dataloader(config, mode, device, num_threads=4, seed=None):
     assert "gpu" in device, "gpu training is required for DALI"
     device_id = int(device.split(':')[1])
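For readers skimming the diff, the collate step that `DALIImageNetIterator.__next__` adds can be reproduced outside of DALI. The sketch below is a minimal approximation, not code from the repository: the `output_map` keys match the ones passed in `dali_dataloader`, but the numpy arrays are fabricated stand-ins for the buffers DALI would actually return.

```python
# Minimal sketch of the collate performed in DALIImageNetIterator.__next__.
# Assumption: DALIGenericIterator.__next__ yields [{'data': ..., 'label': ...}]
# keyed by output_map; the arrays below are made-up stand-ins for DALI output.
import numpy as np
import paddle

output_map = ['data', 'label']  # same keys passed to the iterator in dali_dataloader
fake_dali_batch = [{
    'data': np.zeros((32, 3, 224, 224), dtype='float32'),
    'label': np.zeros((32, 1), dtype='int64'),
}]

# Dict of per-key buffers -> [Tensor1, Tensor2, ...], as in the new __next__
batch = [paddle.to_tensor(fake_dali_batch[0][key]) for key in output_map]
data, label = batch
print(data.shape, label.shape)  # [32, 3, 224, 224] [32, 1]
```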
@@ -278,7 +288,7 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
         pipe.build()
         pipelines = [pipe]
         # sample_per_shard = len(pipelines[0])
-        return DALIGenericIterator(
+        return DALIImageNetIterator(
             pipelines, ['data', 'label'], reader_name='Reader')
     else:
         resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
@@ -318,5 +328,5 @@ def dali_dataloader(config, mode, device, num_threads=4, seed=None):
             pad_output=pad_output,
             output_dtype=output_dtype)
         pipe.build()
-        return DALIGenericIterator(
+        return DALIImageNetIterator(
             [pipe], ['data', 'label'], reader_name="Reader")
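With both branches of `dali_dataloader` now returning `DALIImageNetIterator`, a batch coming out of the DALI path is unpacked exactly like a batch from a regular `paddle.io.DataLoader`, which is what the engine-side hunks below rely on. A rough consumption sketch, assuming `loader` is whatever `dali_dataloader` (or the plain dataloader) returned; the names here are illustrative:

```python
import paddle

def consume(loader):
    """Works for both DALIImageNetIterator and a regular paddle.io.DataLoader:
    each batch arrives as a [data, label] list of paddle.Tensor objects."""
    for batch in loader:
        data, label = batch                 # no batch[0]['data'] indexing needed
        assert isinstance(data, paddle.Tensor)
        # forward pass / loss / metric computation would go here
```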
@@ -47,11 +47,7 @@ def classification_eval(engine, epoch_id=0):
         if iter_id == 5:
             for key in time_info:
                 time_info[key].reset()
-        if engine.use_dali:
-            batch = [
-                paddle.to_tensor(batch[0]['data']),
-                paddle.to_tensor(batch[0]['label'])
-            ]
+
         time_info["reader_cost"].update(time.time() - tic)
         batch_size = batch[0].shape[0]
         batch[0] = paddle.to_tensor(batch[0])
@@ -155,11 +155,7 @@ def cal_feature(engine, name='gallery'):
             logger.info(
                 f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
             )
-        if engine.use_dali:
-            batch = [
-                paddle.to_tensor(batch[0]['data']),
-                paddle.to_tensor(batch[0]['label'])
-            ]
+
         batch = [paddle.to_tensor(x) for x in batch]
         batch[1] = batch[1].reshape([-1, 1]).astype("int64")
         if len(batch) == 3:
@@ -29,11 +29,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         for key in engine.time_info:
             engine.time_info[key].reset()
         engine.time_info["reader_cost"].update(time.time() - tic)
-        if engine.use_dali:
-            batch = [
-                paddle.to_tensor(batch[0]['data']),
-                paddle.to_tensor(batch[0]['label'])
-            ]
+
         batch_size = batch[0].shape[0]
         if not engine.config["Global"].get("use_multilabel", False):
             batch[1] = batch[1].reshape([batch_size, -1])
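One detail the simplification relies on: lines such as `batch[0] = paddle.to_tensor(batch[0])` and `batch = [paddle.to_tensor(x) for x in batch]` remain in place and now also receive inputs that are already tensors (from `DALIImageNetIterator`). That should be safe because `paddle.to_tensor` also accepts an existing `paddle.Tensor`; a quick sanity check of that assumption:

```python
import numpy as np
import paddle

x_np = np.arange(6, dtype="float32").reshape(2, 3)
x = paddle.to_tensor(x_np)   # numpy -> Tensor (the non-DALI path)
y = paddle.to_tensor(x)      # Tensor -> Tensor (the DALI path after this PR)

print(type(x), type(y))      # both paddle.Tensor
print(bool((x == y).all()))  # True: values are preserved either way
```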