mirror of https://github.com/PaddlePaddle/PaddleOCR.git (synced 2025-06-03 21:53:39 +08:00)

add amp eval

This commit is contained in:
parent 0a247f02d4
commit c3924a959b
@@ -23,6 +23,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.insert(0, __dir__)
 sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
 
+import paddle
 from ppocr.data import build_dataloader
 from ppocr.modeling.architectures import build_model
 from ppocr.postprocess import build_post_process
@@ -86,6 +87,30 @@ def main():
     else:
         model_type = None
 
+    # build metric
+    eval_class = build_metric(config['Metric'])
+    # amp
+    use_amp = config["Global"].get("use_amp", False)
+    amp_level = config["Global"].get("amp_level", 'O2')
+    amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
+    if use_amp:
+        AMP_RELATED_FLAGS_SETTING = {
+            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+            'FLAGS_max_inplace_grad_add': 8,
+        }
+        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+        scale_loss = config["Global"].get("scale_loss", 1.0)
+        use_dynamic_loss_scaling = config["Global"].get(
+            "use_dynamic_loss_scaling", False)
+        scaler = paddle.amp.GradScaler(
+            init_loss_scaling=scale_loss,
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+        if amp_level == "O2":
+            model = paddle.amp.decorate(
+                models=model, level=amp_level, master_weight=True)
+    else:
+        scaler = None
+
     best_model_dict = load_model(
         config, model, model_type=config['Architecture']["model_type"])
     if len(best_model_dict):
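The hunk above reads several new AMP-related options from the Global section of the config. As a quick reference, here is an illustrative sketch of those keys with the fallback values used by the .get(...) calls, written as a Python dict; this is a summary of the diff, not part of the commit:

amp_eval_options = {
    "use_amp": False,                   # enable mixed-precision evaluation
    "amp_level": "O2",                  # AMP level passed to auto_cast/decorate ('O1' or 'O2')
    "amp_custom_black_list": [],        # op names that must stay in float32
    "scale_loss": 1.0,                  # initial loss scaling for GradScaler
    "use_dynamic_loss_scaling": False,  # let GradScaler adjust the scale automatically
}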
@@ -93,11 +118,9 @@ def main():
     for k, v in best_model_dict.items():
         logger.info('{}:{}'.format(k, v))
 
-    # build metric
-    eval_class = build_metric(config['Metric'])
     # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
-                          eval_class, model_type, extra_input)
+                          eval_class, model_type, extra_input, scaler, amp_level, amp_custom_black_list)
     logger.info('metric eval ***************')
     for k, v in metric.items():
         logger.info('{}:{}'.format(k, v))
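Taken together, the two hunks above make the eval entry point build a GradScaler, optionally decorate the model for O2, and pass both the scaler and the AMP settings into program.eval(). A minimal, self-contained sketch of that inference pattern, assuming a toy paddle.nn.Linear stand-in for the real OCR model:

import paddle

model = paddle.nn.Linear(32, 10)          # stand-in for the model from build_model(...)
model.eval()

use_amp = paddle.is_compiled_with_cuda()  # AMP only takes effect on GPU builds
amp_level = 'O2'
amp_custom_black_list = []                # op names kept in float32

if use_amp and amp_level == 'O2':
    # O2 runs compute in float16 while keeping float32 master weights
    model = paddle.amp.decorate(models=model, level=amp_level, master_weight=True)

x = paddle.rand([4, 32])
with paddle.no_grad():
    if use_amp:
        with paddle.amp.auto_cast(
                level=amp_level, custom_black_list=amp_custom_black_list):
            preds = model(x)
    else:
        preds = model(x)
print(preds.shape)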
@@ -191,7 +191,8 @@ def train(config,
           logger,
           log_writer=None,
           scaler=None,
-          amp_level='O2'):
+          amp_level='O2',
+          amp_custom_black_list=[]):
     cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                    False)
     calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
@@ -277,8 +278,7 @@ def train(config,
                 model_average = True
             # use amp
             if scaler:
-                custom_black_list = config['Global'].get('amp_custom_black_list',[])
-                with paddle.amp.auto_cast(level=amp_level, custom_black_list=custom_black_list):
+                with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list):
                     if model_type == 'table' or extra_input:
                         preds = model(images, data=batch[1:])
                     elif model_type in ["kie", 'vqa']:
@@ -383,7 +383,9 @@ def train(config,
                     eval_class,
                     model_type,
                     extra_input=extra_input,
-                    scaler=scaler)
+                    scaler=scaler,
+                    amp_level=amp_level,
+                    amp_custom_black_list=amp_custom_black_list)
                 cur_metric_str = 'cur metric, {}'.format(', '.join(
                     ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                 logger.info(cur_metric_str)
@@ -474,7 +476,9 @@ def eval(model,
          eval_class,
          model_type=None,
          extra_input=False,
-         scaler=None):
+         scaler=None,
+         amp_level='O2',
+         amp_custom_black_list = []):
     model.eval()
     with paddle.no_grad():
         total_frame = 0.0
@@ -495,7 +499,7 @@ def eval(model,
 
             # use amp
             if scaler:
-                with paddle.amp.auto_cast(level='O2'):
+                with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list):
                     if model_type == 'table' or extra_input:
                         preds = model(images, data=batch[1:])
                     elif model_type in ["kie", 'vqa']:
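Note how eval() uses the scaler purely as an on/off signal: no gradient scaling is needed at inference time, and the previously hard-coded auto_cast(level='O2') now honours the configured amp_level and custom black list. A small hypothetical helper, not part of the commit, that captures this pattern:

import paddle

def forward_maybe_amp(model, images, scaler=None, amp_level='O2',
                      amp_custom_black_list=None):
    # The scaler object only signals "AMP is enabled"; no loss scaling happens here.
    if scaler:
        with paddle.amp.auto_cast(level=amp_level,
                                  custom_black_list=amp_custom_black_list or []):
            return model(images)
    return model(images)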
@@ -138,9 +138,7 @@ def main(config, device, logger, vdl_writer):
 
     # build metric
     eval_class = build_metric(config['Metric'])
-    # load pretrain model
-    pre_best_model_dict = load_model(config, model, optimizer,
-                                     config['Architecture']["model_type"])
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
         logger.info('valid dataloader has {} iters'.format(
@@ -148,6 +146,7 @@ def main(config, device, logger, vdl_writer):
 
     use_amp = config["Global"].get("use_amp", False)
     amp_level = config["Global"].get("amp_level", 'O2')
+    amp_custom_black_list = config['Global'].get('amp_custom_black_list',[])
     if use_amp:
         AMP_RELATED_FLAGS_SETTING = {
             'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
@@ -166,12 +165,16 @@ def main(config, device, logger, vdl_writer):
     else:
         scaler = None
 
+    # load pretrain model
+    pre_best_model_dict = load_model(config, model, optimizer,
+                                     config['Architecture']["model_type"])
+
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
-                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level)
+                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list)
 
 
 def test_reader(config, device, logger):