mirror of https://github.com/open-mmlab/mmocr.git
[Feat] support fp16 auto resume and auto scale lr
parent
dc180443b8
commit
1b33ff5d76
|
@ -1,9 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmocr.utils import register_all_modules
|
||||
|
@ -13,6 +15,18 @@ def parse_args():
|
|||
parser = argparse.ArgumentParser(description='Train a model')
|
||||
parser.add_argument('config', help='Train config file path')
|
||||
parser.add_argument('--work-dir', help='The dir to save logs and models')
|
||||
parser.add_argument(
|
||||
'--resume', action='store_true', help='Whether to resume checkpoint.')
|
||||
parser.add_argument(
|
||||
'--amp',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Enable automatic-mixed-precision training')
|
||||
parser.add_argument(
|
||||
'--auto-scale-lr',
|
||||
action='store_true',
|
||||
help='Whether to scale the learning rate automatically. It requires '
|
||||
'`auto_scale_lr` in config, and `base_batch_size` in `auto_scale_lr`')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
|
@ -57,7 +71,31 @@ def main():
|
|||
# use config filename as default work_dir if cfg.work_dir is None
|
||||
cfg.work_dir = osp.join('./work_dirs',
|
||||
osp.splitext(osp.basename(args.config))[0])
|
||||
|
||||
# enable automatic-mixed-precision training
|
||||
if args.amp:
|
||||
optim_wrapper = cfg.optim_wrapper.type
|
||||
if optim_wrapper == 'AmpOptimWrapper':
|
||||
print_log(
|
||||
'AMP training is already enabled in your config.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
else:
|
||||
assert optim_wrapper == 'OptimWrapper', (
|
||||
'`--amp` is only supported when the optimizer wrapper type is '
|
||||
f'`OptimWrapper` but got {optim_wrapper}.')
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.loss_scale = 'dynamic'
|
||||
if args.resume:
|
||||
cfg.resume = True
|
||||
if args.auto_scale_lr:
|
||||
if cfg.get('auto_scale_lr'):
|
||||
cfg.auto_scale_lr = True
|
||||
else:
|
||||
print_log(
|
||||
'auto_scale_lr does not exist in your config, '
|
||||
'please set `auto_scale_lr = dict(base_batch_size=xx)',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
# build the runner from config
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
|
|
Loading…
Reference in New Issue