Merge pull request #4348 from Intsigstephon/feature_amp_train
support amp train; add example yaml
pull/4442/head
commit
cc01a59b82
|
@ -1,9 +1,9 @@
|
|||
===========================train_params===========================
|
||||
model_name:ocr_det
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
Global.use_gpu:True|True
|
||||
Global.auto_cast:null
|
||||
gpu_list:0|0,1|10.21.226.181,10.21.226.133;0,1
|
||||
Global.use_gpu:True|True|True
|
||||
Global.auto_cast:fp32|amp
|
||||
Global.epoch_num:lite_train_infer=1|whole_train_infer=300
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4
|
||||
|
|
|
@ -245,6 +245,7 @@ else
|
|||
for gpu in ${gpu_list[*]}; do
|
||||
use_gpu=${USE_GPU_KEY[Count]}
|
||||
Count=$(($Count + 1))
|
||||
ips=""
|
||||
if [ ${gpu} = "-1" ];then
|
||||
env=""
|
||||
elif [ ${#gpu} -le 1 ];then
|
||||
|
@ -264,6 +265,11 @@ else
|
|||
env=" "
|
||||
fi
|
||||
for autocast in ${autocast_list[*]}; do
|
||||
if [ ${autocast} = "amp" ]; then
|
||||
set_amp_config="Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True"
|
||||
else
|
||||
set_amp_config=" "
|
||||
fi
|
||||
for trainer in ${trainer_list[*]}; do
|
||||
flag_quant=False
|
||||
if [ ${trainer} = ${pact_key} ]; then
|
||||
|
@ -290,7 +296,6 @@ else
|
|||
if [ ${run_train} = "null" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
|
||||
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
|
||||
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
|
||||
|
@ -306,11 +311,11 @@ else
|
|||
|
||||
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
|
||||
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
|
||||
cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} "
|
||||
elif [ ${#gpu} -le 15 ];then # train with multi-gpu
|
||||
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1}"
|
||||
cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
|
||||
elif [ ${#ips} -le 26 ];then # train with multi-gpu
|
||||
cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
|
||||
else # train with multi-machine
|
||||
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1}"
|
||||
cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${set_use_gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
|
||||
fi
|
||||
# run train
|
||||
eval "unset CUDA_VISIBLE_DEVICES"
|
||||
|
|
|
@ -159,7 +159,8 @@ def train(config,
|
|||
eval_class,
|
||||
pre_best_model_dict,
|
||||
logger,
|
||||
vdl_writer=None):
|
||||
vdl_writer=None,
|
||||
scaler=None):
|
||||
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
||||
False)
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
|
@ -226,14 +227,29 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
|
||||
# use amp
|
||||
if scaler:
|
||||
with paddle.amp.auto_cast():
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
else:
|
||||
preds = model(images)
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
avg_loss = loss['loss']
|
||||
avg_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if scaler:
|
||||
scaled_avg_loss = scaler.scale(avg_loss)
|
||||
scaled_avg_loss.backward()
|
||||
scaler.minimize(optimizer, scaled_avg_loss)
|
||||
else:
|
||||
avg_loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
|
||||
train_batch_cost += time.time() - batch_start
|
||||
|
|
|
@ -102,10 +102,27 @@ def main(config, device, logger, vdl_writer):
|
|||
if valid_dataloader is not None:
|
||||
logger.info('valid dataloader has {} iters'.format(
|
||||
len(valid_dataloader)))
|
||||
|
||||
use_amp = config["Global"].get("use_amp", False)
|
||||
if use_amp:
|
||||
AMP_RELATED_FLAGS_SETTING = {
|
||||
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
|
||||
'FLAGS_max_inplace_grad_add': 8,
|
||||
}
|
||||
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
|
||||
scale_loss = config["Global"].get("scale_loss", 1.0)
|
||||
use_dynamic_loss_scaling = config["Global"].get(
|
||||
"use_dynamic_loss_scaling", False)
|
||||
scaler = paddle.amp.GradScaler(
|
||||
init_loss_scaling=scale_loss,
|
||||
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
|
||||
else:
|
||||
scaler = None
|
||||
|
||||
# start train
|
||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer)
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
|
||||
|
||||
|
||||
def test_reader(config, device, logger):
|
||||
|
|
Loading…
Reference in New Issue