mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
Support DALI (#442)
This commit is contained in:
parent
c0b73558b1
commit
2b77c71459
11
tools/run_dali.sh
Normal file
11
tools/run_dali.sh
Normal file
@ -0,0 +1,11 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
export FLAGS_fraction_of_gpu_memory_to_use=0.80
|
||||
|
||||
python -m paddle.distributed.launch \
|
||||
--selected_gpus="0,1,2,3,4,5,6,7" \
|
||||
tools/train.py \
|
||||
-c ./configs/ResNet/ResNet50.yaml \
|
||||
-o print_interval=10 \
|
||||
-o use_dali=true
|
340
tools/static/dali.py
Normal file
340
tools/static/dali.py
Normal file
@ -0,0 +1,340 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from nvidia.dali.pipeline import Pipeline
|
||||
import nvidia.dali.ops as ops
|
||||
import nvidia.dali.types as types
|
||||
from nvidia.dali.plugin.paddle import DALIGenericIterator
|
||||
|
||||
import paddle
|
||||
from paddle import fluid
|
||||
|
||||
|
||||
class HybridTrainPipe(Pipeline):
|
||||
def __init__(self,
|
||||
file_root,
|
||||
file_list,
|
||||
batch_size,
|
||||
resize_shorter,
|
||||
crop,
|
||||
min_area,
|
||||
lower,
|
||||
upper,
|
||||
interp,
|
||||
mean,
|
||||
std,
|
||||
device_id,
|
||||
shard_id=0,
|
||||
num_shards=1,
|
||||
random_shuffle=True,
|
||||
num_threads=4,
|
||||
seed=42):
|
||||
super(HybridTrainPipe, self).__init__(
|
||||
batch_size, num_threads, device_id, seed=seed)
|
||||
self.input = ops.FileReader(
|
||||
file_root=file_root,
|
||||
file_list=file_list,
|
||||
shard_id=shard_id,
|
||||
num_shards=num_shards,
|
||||
random_shuffle=random_shuffle)
|
||||
# set internal nvJPEG buffers size to handle full-sized ImageNet images
|
||||
# without additional reallocations
|
||||
device_memory_padding = 211025920
|
||||
host_memory_padding = 140544512
|
||||
self.decode = ops.ImageDecoderRandomCrop(
|
||||
device='mixed',
|
||||
output_type=types.RGB,
|
||||
device_memory_padding=device_memory_padding,
|
||||
host_memory_padding=host_memory_padding,
|
||||
random_aspect_ratio=[lower, upper],
|
||||
random_area=[min_area, 1.0],
|
||||
num_attempts=100)
|
||||
self.res = ops.Resize(
|
||||
device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)
|
||||
self.cmnp = ops.CropMirrorNormalize(
|
||||
device="gpu",
|
||||
output_dtype=types.FLOAT,
|
||||
output_layout=types.NCHW,
|
||||
crop=(crop, crop),
|
||||
image_type=types.RGB,
|
||||
mean=mean,
|
||||
std=std)
|
||||
self.coin = ops.CoinFlip(probability=0.5)
|
||||
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
|
||||
|
||||
def define_graph(self):
|
||||
rng = self.coin()
|
||||
jpegs, labels = self.input(name="Reader")
|
||||
images = self.decode(jpegs)
|
||||
images = self.res(images)
|
||||
output = self.cmnp(images.gpu(), mirror=rng)
|
||||
return [output, self.to_int64(labels.gpu())]
|
||||
|
||||
def __len__(self):
|
||||
return self.epoch_size("Reader")
|
||||
|
||||
|
||||
class HybridValPipe(Pipeline):
|
||||
def __init__(self,
|
||||
file_root,
|
||||
file_list,
|
||||
batch_size,
|
||||
resize_shorter,
|
||||
crop,
|
||||
interp,
|
||||
mean,
|
||||
std,
|
||||
device_id,
|
||||
shard_id=0,
|
||||
num_shards=1,
|
||||
random_shuffle=False,
|
||||
num_threads=4,
|
||||
seed=42):
|
||||
super(HybridValPipe, self).__init__(
|
||||
batch_size, num_threads, device_id, seed=seed)
|
||||
self.input = ops.FileReader(
|
||||
file_root=file_root,
|
||||
file_list=file_list,
|
||||
shard_id=shard_id,
|
||||
num_shards=num_shards,
|
||||
random_shuffle=random_shuffle)
|
||||
self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
|
||||
self.res = ops.Resize(
|
||||
device="gpu", resize_shorter=resize_shorter, interp_type=interp)
|
||||
self.cmnp = ops.CropMirrorNormalize(
|
||||
device="gpu",
|
||||
output_dtype=types.FLOAT,
|
||||
output_layout=types.NCHW,
|
||||
crop=(crop, crop),
|
||||
image_type=types.RGB,
|
||||
mean=mean,
|
||||
std=std)
|
||||
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
|
||||
|
||||
def define_graph(self):
|
||||
jpegs, labels = self.input(name="Reader")
|
||||
images = self.decode(jpegs)
|
||||
images = self.res(images)
|
||||
output = self.cmnp(images)
|
||||
return [output, self.to_int64(labels.gpu())]
|
||||
|
||||
def __len__(self):
|
||||
return self.epoch_size("Reader")
|
||||
|
||||
|
||||
def build(config, mode='train'):
|
||||
env = os.environ
|
||||
assert config.get('use_gpu',
|
||||
True) == True, "gpu training is required for DALI"
|
||||
assert not config.get(
|
||||
'use_aa'), "auto augment is not supported by DALI reader"
|
||||
assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \
|
||||
"Please leave enough GPU memory for DALI workspace, e.g., by setting" \
|
||||
" `export FLAGS_fraction_of_gpu_memory_to_use=0.8`"
|
||||
|
||||
dataset_config = config[mode.upper()]
|
||||
|
||||
gpu_num = paddle.fluid.core.get_cuda_device_count() if (
|
||||
'PADDLE_TRAINERS_NUM') and (
|
||||
'PADDLE_TRAINER_ID'
|
||||
) not in env else int(env.get('PADDLE_TRAINERS_NUM', 0))
|
||||
|
||||
batch_size = dataset_config.batch_size
|
||||
assert batch_size % gpu_num == 0, \
|
||||
"batch size must be multiple of number of devices"
|
||||
batch_size = batch_size // gpu_num
|
||||
|
||||
file_root = dataset_config.data_dir
|
||||
file_list = dataset_config.file_list
|
||||
|
||||
interp = 1 # settings.interpolation or 1 # default to linear
|
||||
interp_map = {
|
||||
0: types.INTERP_NN, # cv2.INTER_NEAREST
|
||||
1: types.INTERP_LINEAR, # cv2.INTER_LINEAR
|
||||
2: types.INTERP_CUBIC, # cv2.INTER_CUBIC
|
||||
4: types.INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
|
||||
}
|
||||
assert interp in interp_map, "interpolation method not supported by DALI"
|
||||
interp = interp_map[interp]
|
||||
|
||||
transforms = {
|
||||
k: v
|
||||
for d in dataset_config["transforms"] for k, v in d.items()
|
||||
}
|
||||
|
||||
scale = transforms["NormalizeImage"].get("scale", 1.0 / 255)
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
mean = transforms["NormalizeImage"].get("mean", [0.485, 0.456, 0.406])
|
||||
std = transforms["NormalizeImage"].get("std", [0.229, 0.224, 0.225])
|
||||
mean = [v / scale for v in mean]
|
||||
std = [v / scale for v in std]
|
||||
|
||||
if mode == "train":
|
||||
resize_shorter = 256
|
||||
crop = transforms["RandCropImage"]["size"]
|
||||
scale = transforms["RandCropImage"].get("scale", [0.08, 1.])
|
||||
ratio = transforms["RandCropImage"].get("ratio", [3.0 / 4, 4.0 / 3])
|
||||
min_area = scale[0]
|
||||
lower = ratio[0]
|
||||
upper = ratio[1]
|
||||
|
||||
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
|
||||
shard_id = int(env['PADDLE_TRAINER_ID'])
|
||||
num_shards = int(env['PADDLE_TRAINERS_NUM'])
|
||||
device_id = int(env['FLAGS_selected_gpus'])
|
||||
pipe = HybridTrainPipe(
|
||||
file_root,
|
||||
file_list,
|
||||
batch_size,
|
||||
resize_shorter,
|
||||
crop,
|
||||
min_area,
|
||||
lower,
|
||||
upper,
|
||||
interp,
|
||||
mean,
|
||||
std,
|
||||
device_id,
|
||||
shard_id,
|
||||
num_shards,
|
||||
seed=42 + shard_id)
|
||||
pipe.build()
|
||||
pipelines = [pipe]
|
||||
sample_per_shard = len(pipe) // num_shards
|
||||
else:
|
||||
pipelines = []
|
||||
places = fluid.framework.cuda_places()
|
||||
num_shards = len(places)
|
||||
for idx, p in enumerate(places):
|
||||
place = fluid.core.Place()
|
||||
place.set_place(p)
|
||||
device_id = place.gpu_device_id()
|
||||
pipe = HybridTrainPipe(
|
||||
file_root,
|
||||
file_list,
|
||||
batch_size,
|
||||
resize_shorter,
|
||||
crop,
|
||||
min_area,
|
||||
lower,
|
||||
upper,
|
||||
interp,
|
||||
mean,
|
||||
std,
|
||||
device_id,
|
||||
idx,
|
||||
num_shards,
|
||||
seed=42 + idx)
|
||||
pipe.build()
|
||||
pipelines.append(pipe)
|
||||
sample_per_shard = len(pipelines[0])
|
||||
return DALIGenericIterator(
|
||||
pipelines, ['feed_image', 'feed_label'], size=sample_per_shard)
|
||||
else:
|
||||
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
|
||||
crop = transforms["CropImage"]["size"]
|
||||
|
||||
p = fluid.framework.cuda_places()[0]
|
||||
place = fluid.core.Place()
|
||||
place.set_place(p)
|
||||
device_id = place.gpu_device_id()
|
||||
pipe = HybridValPipe(
|
||||
file_root,
|
||||
file_list,
|
||||
batch_size,
|
||||
resize_shorter,
|
||||
crop,
|
||||
interp,
|
||||
mean,
|
||||
std,
|
||||
device_id=device_id)
|
||||
pipe.build()
|
||||
return DALIGenericIterator(
|
||||
pipe, ['feed_image', 'feed_label'],
|
||||
size=len(pipe),
|
||||
dynamic_shape=True,
|
||||
fill_last_batch=True,
|
||||
last_batch_padded=True)
|
||||
|
||||
|
||||
def train(config):
|
||||
return build(config, 'train')
|
||||
|
||||
|
||||
def val(config):
|
||||
return build(config, 'valid')
|
||||
|
||||
|
||||
def _to_Tensor(lod_tensor, dtype):
|
||||
data_tensor = fluid.layers.create_tensor(dtype=dtype)
|
||||
data = np.array(lod_tensor).astype(dtype)
|
||||
fluid.layers.assign(data, data_tensor)
|
||||
return data_tensor
|
||||
|
||||
|
||||
def normalize(feeds, config):
|
||||
image, label = feeds['image'], feeds['label']
|
||||
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
|
||||
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
|
||||
image = fluid.layers.cast(image, 'float32')
|
||||
costant = fluid.layers.fill_constant(
|
||||
shape=[1], value=255.0, dtype='float32')
|
||||
image = fluid.layers.elementwise_div(image, costant)
|
||||
|
||||
mean = fluid.layers.create_tensor(dtype="float32")
|
||||
fluid.layers.assign(input=img_mean.astype("float32"), output=mean)
|
||||
std = fluid.layers.create_tensor(dtype="float32")
|
||||
fluid.layers.assign(input=img_std.astype("float32"), output=std)
|
||||
|
||||
image = fluid.layers.elementwise_sub(image, mean)
|
||||
image = fluid.layers.elementwise_div(image, std)
|
||||
|
||||
image.stop_gradient = True
|
||||
feeds['image'] = image
|
||||
|
||||
return feeds
|
||||
|
||||
|
||||
def mix(feeds, config, is_train=True):
|
||||
env = os.environ
|
||||
gpu_num = paddle.fluid.core.get_cuda_device_count() if (
|
||||
'PADDLE_TRAINERS_NUM') and (
|
||||
'PADDLE_TRAINER_ID'
|
||||
) not in env else int(env.get('PADDLE_TRAINERS_NUM', 0))
|
||||
|
||||
batch_size = config.TRAIN.batch_size // gpu_num
|
||||
|
||||
images = feeds['image']
|
||||
label = feeds['label']
|
||||
# TODO: hard code here, should be fixed!
|
||||
alpha = 0.2
|
||||
idx = _to_Tensor(np.random.permutation(batch_size), 'int32')
|
||||
lam = np.random.beta(alpha, alpha)
|
||||
|
||||
images = lam * images + (1 - lam) * paddle.fluid.layers.gather(images, idx)
|
||||
|
||||
feed = {
|
||||
'image': images,
|
||||
'feed_y_a': label,
|
||||
'feed_y_b': paddle.fluid.layers.gather(label, idx),
|
||||
'feed_lam': _to_Tensor([lam] * batch_size, 'float32')
|
||||
}
|
||||
|
||||
return feed if is_train else feeds
|
@ -66,7 +66,7 @@ def save_model(program, model_path, epoch_id, prefix='ppcls'):
|
||||
logger.info("Already save model in {}".format(model_path))
|
||||
|
||||
|
||||
def create_feeds(image_shape, use_mix=None):
|
||||
def create_feeds(image_shape, use_mix=None, use_dali=None):
|
||||
"""
|
||||
Create feeds as model input
|
||||
|
||||
@ -80,7 +80,7 @@ def create_feeds(image_shape, use_mix=None):
|
||||
feeds = OrderedDict()
|
||||
feeds['image'] = paddle.static.data(
|
||||
name="feed_image", shape=[None] + image_shape, dtype="float32")
|
||||
if use_mix:
|
||||
if use_mix and not use_dali:
|
||||
feeds['feed_y_a'] = paddle.static.data(
|
||||
name="feed_y_a", shape=[None, 1], dtype="int64")
|
||||
feeds['feed_y_b'] = paddle.static.data(
|
||||
@ -345,8 +345,13 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
|
||||
with paddle.static.program_guard(main_prog, startup_prog):
|
||||
with paddle.utils.unique_name.guard():
|
||||
use_mix = config.get('use_mix') and is_train
|
||||
use_dali = config.get('use_dali', False)
|
||||
use_distillation = config.get('use_distillation')
|
||||
feeds = create_feeds(config.image_shape, use_mix=use_mix)
|
||||
feeds = create_feeds(
|
||||
config.image_shape, use_mix=use_mix, use_dali=use_dali)
|
||||
if use_dali and use_mix:
|
||||
import dali
|
||||
feeds = dali.mix(feeds, config, is_train)
|
||||
out = create_model(config.ARCHITECTURE, feeds['image'],
|
||||
config.classes_num, is_train)
|
||||
fetchs = create_fetchs(
|
||||
@ -431,8 +436,10 @@ def run(dataloader,
|
||||
for m in metric_list:
|
||||
m.reset()
|
||||
batch_time = AverageMeter('elapse', '.3f')
|
||||
use_dali = config.get('use_dali', False)
|
||||
dataloader = dataloader if use_dali else dataloader()
|
||||
tic = time.time()
|
||||
for idx, batch in enumerate(dataloader()):
|
||||
for idx, batch in enumerate(dataloader):
|
||||
# ignore the warmup iters
|
||||
if idx == 5:
|
||||
batch_time.reset()
|
||||
@ -497,6 +504,8 @@ def run(dataloader,
|
||||
end_epoch_str = "END epoch:{:<3d}".format(epoch)
|
||||
logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
|
||||
ips_info))
|
||||
if use_dali:
|
||||
dataloader.reset()
|
||||
|
||||
# return top1_acc in order to save the best model
|
||||
if mode == 'valid':
|
||||
|
@ -63,10 +63,13 @@ def main(args):
|
||||
|
||||
config = get_config(args.config, overrides=args.override, show=True)
|
||||
# assign the place
|
||||
use_gpu = config.get("use_gpu", False)
|
||||
use_gpu = config.get("use_gpu", True)
|
||||
use_xpu = config.get("use_xpu", False)
|
||||
assert (use_gpu or use_xpu) is True, "gpu or xpu must be true in static mode!"
|
||||
assert (use_gpu and use_xpu) is not True, "gpu and xpu can not be true in the same time in static mode!"
|
||||
assert (use_gpu or use_xpu
|
||||
) is True, "gpu or xpu must be true in static mode!"
|
||||
assert (
|
||||
use_gpu and use_xpu
|
||||
) is not True, "gpu and xpu can not be true in the same time in static mode!"
|
||||
|
||||
place = paddle.set_device('gpu' if use_gpu else 'xpu')
|
||||
|
||||
@ -78,12 +81,20 @@ def main(args):
|
||||
best_top1_acc = 0.0 # best top1 acc record
|
||||
|
||||
train_fetchs, lr_scheduler, train_feeds = program.build(
|
||||
config, train_prog, startup_prog, is_train=True, is_distributed=config.get("is_distributed", True))
|
||||
config,
|
||||
train_prog,
|
||||
startup_prog,
|
||||
is_train=True,
|
||||
is_distributed=config.get("is_distributed", True))
|
||||
|
||||
if config.validate:
|
||||
valid_prog = paddle.static.Program()
|
||||
valid_fetchs, _, valid_feeds = program.build(
|
||||
config, valid_prog, startup_prog, is_train=False, is_distributed=config.get("is_distributed", True))
|
||||
config,
|
||||
valid_prog,
|
||||
startup_prog,
|
||||
is_train=False,
|
||||
is_distributed=config.get("is_distributed", True))
|
||||
# clone to prune some content which is irrelevant in valid_prog
|
||||
valid_prog = valid_prog.clone(for_test=True)
|
||||
|
||||
@ -92,14 +103,20 @@ def main(args):
|
||||
# Parameter initialization
|
||||
exe.run(startup_prog)
|
||||
|
||||
# load model from 1. checkpoint to resume training, 2. pretrained model to finetune
|
||||
train_dataloader = Reader(config, 'train', places=place)()
|
||||
|
||||
if config.validate and paddle.distributed.get_rank() == 0:
|
||||
valid_dataloader = Reader(config, 'valid', places=place)()
|
||||
if use_xpu:
|
||||
compiled_valid_prog = valid_prog
|
||||
else:
|
||||
if not config.get('use_dali', False):
|
||||
train_dataloader = Reader(config, 'train', places=place)()
|
||||
if config.validate and paddle.distributed.get_rank() == 0:
|
||||
valid_dataloader = Reader(config, 'valid', places=place)()
|
||||
if use_xpu:
|
||||
compiled_valid_prog = valid_prog
|
||||
else:
|
||||
compiled_valid_prog = program.compile(config, valid_prog)
|
||||
else:
|
||||
assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!"
|
||||
import dali
|
||||
train_dataloader = dali.train(config)
|
||||
if config.validate and paddle.distributed.get_rank() == 0:
|
||||
valid_dataloader = dali.val(config)
|
||||
compiled_valid_prog = program.compile(config, valid_prog)
|
||||
|
||||
vdl_writer = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user