support '--amp' option
parent
b2174812bb
commit
741190a864
|
@ -1,9 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmseg.utils import register_all_modules
|
||||
|
@ -13,6 +15,11 @@ def parse_args():
|
|||
parser = argparse.ArgumentParser(description='Train a segmentor')
|
||||
parser.add_argument('config', help='train config file path')
|
||||
parser.add_argument('--work-dir', help='the dir to save logs and models')
|
||||
parser.add_argument(
|
||||
'--amp',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='enable automatic-mixed-precision training')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
|
@ -58,6 +65,21 @@ def main():
|
|||
cfg.work_dir = osp.join('./work_dirs',
|
||||
osp.splitext(osp.basename(args.config))[0])
|
||||
|
||||
# enable automatic-mixed-precision training
|
||||
if args.amp is True:
|
||||
optim_wrapper = cfg.optim_wrapper.type
|
||||
if optim_wrapper == 'AmpOptimWrapper':
|
||||
print_log(
|
||||
'AMP training is already enabled in your config.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
else:
|
||||
assert optim_wrapper == 'OptimWrapper', (
|
||||
'`--amp` is only supported when the optimizer wrapper type is '
|
||||
f'`OptimWrapper` but got {optim_wrapper}.')
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.loss_scale = 'dynamic'
|
||||
|
||||
# build the runner from config
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
|
|
Loading…
Reference in New Issue