improve amp training (#10119)
parent
062e2c5099
commit
6949448558
|
@ -16,7 +16,7 @@ Global:
|
|||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
use_amp: False
|
||||
amp_level: O2
|
||||
amp_custom_black_list: ['exp']
|
||||
amp_dtype: bfloat16
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
|
|
|
@ -189,7 +189,8 @@ def train(config,
|
|||
scaler=None,
|
||||
amp_level='O2',
|
||||
amp_custom_black_list=[],
|
||||
amp_custom_white_list=[]):
|
||||
amp_custom_white_list=[],
|
||||
amp_dtype='float16'):
|
||||
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
|
||||
False)
|
||||
calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
|
||||
|
@ -279,7 +280,8 @@ def train(config,
|
|||
with paddle.amp.auto_cast(
|
||||
level=amp_level,
|
||||
custom_black_list=amp_custom_black_list,
|
||||
custom_white_list=amp_custom_white_list):
|
||||
custom_white_list=amp_custom_white_list,
|
||||
dtype=amp_dtype):
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie"]:
|
||||
|
@ -393,7 +395,9 @@ def train(config,
|
|||
extra_input=extra_input,
|
||||
scaler=scaler,
|
||||
amp_level=amp_level,
|
||||
amp_custom_black_list=amp_custom_black_list)
|
||||
amp_custom_black_list=amp_custom_black_list,
|
||||
amp_custom_white_list=amp_custom_white_list,
|
||||
amp_dtype=amp_dtype)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
logger.info(cur_metric_str)
|
||||
|
@ -486,7 +490,9 @@ def eval(model,
|
|||
extra_input=False,
|
||||
scaler=None,
|
||||
amp_level='O2',
|
||||
amp_custom_black_list=[]):
|
||||
amp_custom_black_list=[],
|
||||
amp_custom_white_list=[],
|
||||
amp_dtype='float16'):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
|
@ -509,7 +515,8 @@ def eval(model,
|
|||
if scaler:
|
||||
with paddle.amp.auto_cast(
|
||||
level=amp_level,
|
||||
custom_black_list=amp_custom_black_list):
|
||||
custom_black_list=amp_custom_black_list,
|
||||
dtype=amp_dtype):
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
elif model_type in ["kie"]:
|
||||
|
|
|
@ -160,6 +160,7 @@ def main(config, device, logger, vdl_writer):
|
|||
|
||||
use_amp = config["Global"].get("use_amp", False)
|
||||
amp_level = config["Global"].get("amp_level", 'O2')
|
||||
amp_dtype = config["Global"].get("amp_dtype", 'float16')
|
||||
amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
|
||||
amp_custom_white_list = config['Global'].get('amp_custom_white_list', [])
|
||||
if use_amp:
|
||||
|
@ -181,7 +182,8 @@ def main(config, device, logger, vdl_writer):
|
|||
models=model,
|
||||
optimizers=optimizer,
|
||||
level=amp_level,
|
||||
master_weight=True)
|
||||
master_weight=True,
|
||||
dtype=amp_dtype)
|
||||
else:
|
||||
scaler = None
|
||||
|
||||
|
@ -195,7 +197,8 @@ def main(config, device, logger, vdl_writer):
|
|||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
|
||||
amp_level, amp_custom_black_list, amp_custom_white_list)
|
||||
amp_level, amp_custom_black_list, amp_custom_white_list,
|
||||
amp_dtype)
|
||||
|
||||
|
||||
def test_reader(config, device, logger):
|
||||
|
|
Loading…
Reference in New Issue