support '--amp' option

This commit is contained in:
xiexinch 2022-07-14 14:21:32 +08:00
parent b2174812bb
commit 741190a864

View File

@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import argparse import argparse
import logging
import os import os
import os.path as osp import os.path as osp
from mmengine.config import Config, DictAction from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.runner import Runner from mmengine.runner import Runner
from mmseg.utils import register_all_modules from mmseg.utils import register_all_modules
@ -13,6 +15,11 @@ def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor') parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('config', help='train config file path') parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models') parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--amp',
action='store_true',
default=False,
help='enable automatic-mixed-precision training')
parser.add_argument( parser.add_argument(
'--cfg-options', '--cfg-options',
nargs='+', nargs='+',
@ -58,6 +65,21 @@ def main():
cfg.work_dir = osp.join('./work_dirs', cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0]) osp.splitext(osp.basename(args.config))[0])
# enable automatic-mixed-precision training
if args.amp is True:
optim_wrapper = cfg.optim_wrapper.type
if optim_wrapper == 'AmpOptimWrapper':
print_log(
'AMP training is already enabled in your config.',
logger='current',
level=logging.WARNING)
else:
assert optim_wrapper == 'OptimWrapper', (
'`--amp` is only supported when the optimizer wrapper type is '
f'`OptimWrapper` but got {optim_wrapper}.')
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
# build the runner from config # build the runner from config
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)