diff --git a/tools/train.py b/tools/train.py
index 43196ba56..878d78c31 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import logging
 import os
 import os.path as osp
 
 from mmengine.config import Config, DictAction
+from mmengine.logging import print_log
 from mmengine.runner import Runner
 
 from mmseg.utils import register_all_modules
@@ -13,6 +15,11 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Train a segmentor')
     parser.add_argument('config', help='train config file path')
     parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument(
+        '--amp',
+        action='store_true',
+        default=False,
+        help='enable automatic-mixed-precision training')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -58,6 +65,21 @@ def main():
         cfg.work_dir = osp.join('./work_dirs',
                                 osp.splitext(osp.basename(args.config))[0])
 
+    # enable automatic-mixed-precision training
+    if args.amp is True:
+        optim_wrapper = cfg.optim_wrapper.type
+        if optim_wrapper == 'AmpOptimWrapper':
+            print_log(
+                'AMP training is already enabled in your config.',
+                logger='current',
+                level=logging.WARNING)
+        else:
+            assert optim_wrapper == 'OptimWrapper', (
+                '`--amp` is only supported when the optimizer wrapper type is '
+                f'`OptimWrapper` but got {optim_wrapper}.')
+            cfg.optim_wrapper.type = 'AmpOptimWrapper'
+            cfg.optim_wrapper.loss_scale = 'dynamic'
+
     # build the runner from config
     runner = Runner.from_cfg(cfg)
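
Usage sketch for the new flag (the config path is a placeholder, not part of this diff; any
config whose optim_wrapper type is `OptimWrapper` should work):

    python tools/train.py path/to/config.py --amp

With `--amp`, the script rewrites `cfg.optim_wrapper.type` to `AmpOptimWrapper` and sets
`loss_scale` to `'dynamic'` before building the Runner; if the config already uses
`AmpOptimWrapper`, it only emits a warning and leaves the config unchanged.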