[Feat] support fp16 auto resume and auto scale lr

pull/1178/head
liukuikun 2022-07-15 11:56:27 +00:00 committed by gaotongxiao
parent dc180443b8
commit 1b33ff5d76
1 changed file with 39 additions and 1 deletion

@@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner

from mmocr.utils import register_all_modules
@@ -13,6 +15,18 @@ def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='Train config file path')
    parser.add_argument('--work-dir', help='The dir to save logs and models')
    parser.add_argument(
        '--resume', action='store_true', help='Whether to resume checkpoint.')
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='Enable automatic-mixed-precision training')
    parser.add_argument(
        '--auto-scale-lr',
        action='store_true',
        help='Whether to scale the learning rate automatically. It requires '
        '`auto_scale_lr` in config, and `base_batch_size` in `auto_scale_lr`')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
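
The three new switches are plain `store_true` flags, so they are simply appended to the usual training command, e.g. `python <path to this script> <config> --amp --resume --auto-scale-lr`. As a reference for the logic added to `main()` below, here is a minimal sketch of the config entries the flags act on; the optimizer settings and the `base_batch_size` value are illustrative assumptions, not part of this commit.

# Sketch of the config fields the new flags look at (values below are assumed).
optim_wrapper = dict(
    type='OptimWrapper',  # `--amp` asserts this type, then swaps in 'AmpOptimWrapper'
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))
auto_scale_lr = dict(base_batch_size=16)  # `--auto-scale-lr` warns if this entry is missing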
@@ -57,7 +71,31 @@ def main():
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])

    # enable automatic-mixed-precision training
    if args.amp:
        optim_wrapper = cfg.optim_wrapper.type
        if optim_wrapper == 'AmpOptimWrapper':
            print_log(
                'AMP training is already enabled in your config.',
                logger='current',
                level=logging.WARNING)
        else:
            assert optim_wrapper == 'OptimWrapper', (
                '`--amp` is only supported when the optimizer wrapper type is '
                f'`OptimWrapper` but got {optim_wrapper}.')
            cfg.optim_wrapper.type = 'AmpOptimWrapper'
            cfg.optim_wrapper.loss_scale = 'dynamic'

    if args.resume:
        cfg.resume = True

    if args.auto_scale_lr:
        if cfg.get('auto_scale_lr'):
            cfg.auto_scale_lr = True
        else:
            print_log(
                'auto_scale_lr does not exist in your config, '
                'please set `auto_scale_lr = dict(base_batch_size=xx)`',
                logger='current',
                level=logging.WARNING)

    # build the runner from config
    runner = Runner.from_cfg(cfg)
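
For reference, once the `--amp` branch rewrites the wrapper, the effective optimizer section of the config is equivalent to the sketch below. The `AmpOptimWrapper` type and `loss_scale='dynamic'` come straight from the branch above; the inner optimizer is the assumed one from the earlier sketch.

# Effective optim_wrapper after `--amp` (inner optimizer is an assumption)
optim_wrapper = dict(
    type='AmpOptimWrapper',  # swapped in from 'OptimWrapper' by the --amp branch
    loss_scale='dynamic',    # dynamic loss scaling for fp16 training
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001))

Dynamic loss scaling avoids hand-picking a fixed scale factor and is the common default for mixed-precision training.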