diff --git a/tools/train.py b/tools/train.py index 4e5e7f85..76488794 100755 --- a/tools/train.py +++ b/tools/train.py @@ -1,9 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import logging import os import os.path as osp from mmengine.config import Config, DictAction +from mmengine.logging import print_log from mmengine.runner import Runner from mmocr.utils import register_all_modules @@ -13,6 +15,18 @@ def parse_args(): parser = argparse.ArgumentParser(description='Train a model') parser.add_argument('config', help='Train config file path') parser.add_argument('--work-dir', help='The dir to save logs and models') + parser.add_argument( + '--resume', action='store_true', help='Whether to resume checkpoint.') + parser.add_argument( + '--amp', + action='store_true', + default=False, + help='Enable automatic-mixed-precision training') + parser.add_argument( + '--auto-scale-lr', + action='store_true', + help='Whether to scale the learning rate automatically. It requires ' + '`auto_scale_lr` in config, and `base_batch_size` in `auto_scale_lr`') parser.add_argument( '--cfg-options', nargs='+', @@ -57,7 +71,31 @@ def main(): # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) - + # enable automatic-mixed-precision training + if args.amp: + optim_wrapper = cfg.optim_wrapper.type + if optim_wrapper == 'AmpOptimWrapper': + print_log( + 'AMP training is already enabled in your config.', + logger='current', + level=logging.WARNING) + else: + assert optim_wrapper == 'OptimWrapper', ( + '`--amp` is only supported when the optimizer wrapper type is ' + f'`OptimWrapper` but got {optim_wrapper}.') + cfg.optim_wrapper.type = 'AmpOptimWrapper' + cfg.optim_wrapper.loss_scale = 'dynamic' + if args.resume: + cfg.resume = True + if args.auto_scale_lr: + if cfg.get('auto_scale_lr'): + cfg.auto_scale_lr = True + else: + print_log( + 'auto_scale_lr does not exist in your config, ' + 'please set `auto_scale_lr = dict(base_batch_size=xx)', + logger='current', + level=logging.WARNING) # build the runner from config runner = Runner.from_cfg(cfg)