[Refactor] Refactor train test interface.
parent
d9096743ef
commit
3c54de06bb
290
tools/test.py
290
tools/test.py
|
@ -2,77 +2,23 @@
|
|||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
|
||||
wrap_fp16_model)
|
||||
from mmcv.utils import DictAction
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmseg import digit_version
|
||||
from mmseg.apis import multi_gpu_test, single_gpu_test
|
||||
from mmseg.datasets import build_dataloader, build_dataset
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.utils import setup_multi_processes
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
# TODO: support fuse_conv_bn, visualization, and format_only
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='mmseg test (and eval) a model')
|
||||
parser.add_argument('config', help='test config file path')
|
||||
description='MMSeg test (and eval) a model')
|
||||
parser.add_argument('config', help='train config file path')
|
||||
parser.add_argument('checkpoint', help='checkpoint file')
|
||||
parser.add_argument(
|
||||
'--work-dir',
|
||||
help=('if specified, the evaluation metric results will be dumped'
|
||||
'into the directory as json'))
|
||||
parser.add_argument(
|
||||
'--aug-test', action='store_true', help='Use Flip and Multi scale aug')
|
||||
parser.add_argument('--out', help='output result file in pickle format')
|
||||
parser.add_argument(
|
||||
'--format-only',
|
||||
action='store_true',
|
||||
help='Format the output results without perform evaluation. It is'
|
||||
'useful when you want to format the result to a specific format and '
|
||||
'submit it to the test server')
|
||||
parser.add_argument(
|
||||
'--eval',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
|
||||
' for generic datasets, and "cityscapes" for Cityscapes')
|
||||
parser.add_argument('--show', action='store_true', help='show results')
|
||||
parser.add_argument(
|
||||
'--show-dir', help='directory where painted images will be saved')
|
||||
parser.add_argument(
|
||||
'--gpu-collect',
|
||||
action='store_true',
|
||||
help='whether to use gpu to collect results.')
|
||||
parser.add_argument(
|
||||
'--gpu-id',
|
||||
type=int,
|
||||
default=0,
|
||||
help='id of gpu to use '
|
||||
'(only applicable to non-distributed testing)')
|
||||
parser.add_argument(
|
||||
'--tmpdir',
|
||||
help='tmp directory used for collecting results from multiple '
|
||||
'workers, available when gpu_collect is not specified')
|
||||
parser.add_argument(
|
||||
'--options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help="--options is deprecated in favor of --cfg_options' and it will "
|
||||
'not be supported in version v0.22.0. Override some settings in the '
|
||||
'used config, the key-value pair in xxx=yyy format will be merged '
|
||||
'into config file. If the value to be overwritten is a list, it '
|
||||
'should be like key="[a,b]" or key=a,b It also allows nested '
|
||||
'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
|
||||
'marks are necessary and that no white space is allowed.')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
|
@ -83,236 +29,48 @@ def parse_args():
|
|||
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
||||
'Note that the quotation marks are necessary and that no white space '
|
||||
'is allowed.')
|
||||
parser.add_argument(
|
||||
'--eval-options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help='custom options for evaluation')
|
||||
parser.add_argument(
|
||||
'--launcher',
|
||||
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
||||
default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument(
|
||||
'--opacity',
|
||||
type=float,
|
||||
default=0.5,
|
||||
help='Opacity of painted segmentation map. In (0, 1] range.')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
if args.options and args.cfg_options:
|
||||
raise ValueError(
|
||||
'--options and --cfg-options cannot be both '
|
||||
'specified, --options is deprecated in favor of --cfg-options. '
|
||||
'--options will not be supported in version v0.22.0.')
|
||||
if args.options:
|
||||
warnings.warn('--options is deprecated in favor of --cfg-options. '
|
||||
'--options will not be supported in version v0.22.0.')
|
||||
args.cfg_options = args.options
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
assert args.out or args.eval or args.format_only or args.show \
|
||||
or args.show_dir, \
|
||||
('Please specify at least one operation (save/eval/format/show the '
|
||||
'results / save the results) with the argument "--out", "--eval"'
|
||||
', "--format-only", "--show" or "--show-dir"')
|
||||
|
||||
if args.eval and args.format_only:
|
||||
raise ValueError('--eval and --format_only cannot be both specified')
|
||||
# register all modules in mmseg into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
|
||||
raise ValueError('The output file must be a pkl file.')
|
||||
|
||||
cfg = mmcv.Config.fromfile(args.config)
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# set multi-process settings
|
||||
setup_multi_processes(cfg)
|
||||
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
if args.aug_test:
|
||||
# hard code index
|
||||
cfg.data.test.pipeline[1].img_ratios = [
|
||||
0.5, 0.75, 1.0, 1.25, 1.5, 1.75
|
||||
]
|
||||
cfg.data.test.pipeline[1].flip = True
|
||||
cfg.model.pretrained = None
|
||||
cfg.data.test.test_mode = True
|
||||
|
||||
if args.gpu_id is not None:
|
||||
cfg.gpu_ids = [args.gpu_id]
|
||||
|
||||
# init distributed env first, since logger depends on the dist info.
|
||||
if args.launcher == 'none':
|
||||
cfg.gpu_ids = [args.gpu_id]
|
||||
distributed = False
|
||||
if len(cfg.gpu_ids) > 1:
|
||||
warnings.warn(f'The gpu-ids is reset from {cfg.gpu_ids} to '
|
||||
f'{cfg.gpu_ids[0:1]} to avoid potential error in '
|
||||
'non-distribute testing time.')
|
||||
cfg.gpu_ids = cfg.gpu_ids[0:1]
|
||||
else:
|
||||
distributed = True
|
||||
init_dist(args.launcher, **cfg.dist_params)
|
||||
|
||||
rank, _ = get_dist_info()
|
||||
# allows not to create
|
||||
if args.work_dir is not None and rank == 0:
|
||||
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
if args.aug_test:
|
||||
json_file = osp.join(args.work_dir,
|
||||
f'eval_multi_scale_{timestamp}.json')
|
||||
else:
|
||||
json_file = osp.join(args.work_dir,
|
||||
f'eval_single_scale_{timestamp}.json')
|
||||
elif rank == 0:
|
||||
work_dir = osp.join('./work_dirs',
|
||||
# work_dir is determined in this priority: CLI > segment in file > filename
|
||||
if args.work_dir is not None:
|
||||
# update configs according to CLI args if args.work_dir is not None
|
||||
cfg.work_dir = args.work_dir
|
||||
elif cfg.get('work_dir', None) is None:
|
||||
# use config filename as default work_dir if cfg.work_dir is None
|
||||
cfg.work_dir = osp.join('./work_dirs',
|
||||
osp.splitext(osp.basename(args.config))[0])
|
||||
mmcv.mkdir_or_exist(osp.abspath(work_dir))
|
||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
if args.aug_test:
|
||||
json_file = osp.join(work_dir,
|
||||
f'eval_multi_scale_{timestamp}.json')
|
||||
else:
|
||||
json_file = osp.join(work_dir,
|
||||
f'eval_single_scale_{timestamp}.json')
|
||||
|
||||
# build the dataloader
|
||||
# TODO: support multiple images per gpu (only minor changes are needed)
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
# The default loader config
|
||||
loader_cfg = dict(
|
||||
# cfg.gpus will be ignored if distributed
|
||||
num_gpus=len(cfg.gpu_ids),
|
||||
dist=distributed,
|
||||
shuffle=False)
|
||||
# The overall dataloader settings
|
||||
loader_cfg.update({
|
||||
k: v
|
||||
for k, v in cfg.data.items() if k not in [
|
||||
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
||||
'test_dataloader'
|
||||
]
|
||||
})
|
||||
test_loader_cfg = {
|
||||
**loader_cfg,
|
||||
'samples_per_gpu': 1,
|
||||
'shuffle': False, # Not shuffle by default
|
||||
**cfg.data.get('test_dataloader', {})
|
||||
}
|
||||
# build the dataloader
|
||||
data_loader = build_dataloader(dataset, **test_loader_cfg)
|
||||
cfg.load_from = args.checkpoint
|
||||
|
||||
# build the model and load checkpoint
|
||||
cfg.model.train_cfg = None
|
||||
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
|
||||
fp16_cfg = cfg.get('fp16', None)
|
||||
if fp16_cfg is not None:
|
||||
wrap_fp16_model(model)
|
||||
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
|
||||
if 'CLASSES' in checkpoint.get('meta', {}):
|
||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
else:
|
||||
print('"CLASSES" not found in meta, use dataset.CLASSES instead')
|
||||
model.CLASSES = dataset.CLASSES
|
||||
if 'PALETTE' in checkpoint.get('meta', {}):
|
||||
model.PALETTE = checkpoint['meta']['PALETTE']
|
||||
else:
|
||||
print('"PALETTE" not found in meta, use dataset.PALETTE instead')
|
||||
model.PALETTE = dataset.PALETTE
|
||||
# build the runner from config
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
# clean gpu memory when starting a new evaluation.
|
||||
torch.cuda.empty_cache()
|
||||
eval_kwargs = {} if args.eval_options is None else args.eval_options
|
||||
|
||||
# Deprecated
|
||||
efficient_test = eval_kwargs.get('efficient_test', False)
|
||||
if efficient_test:
|
||||
warnings.warn(
|
||||
'``efficient_test=True`` does not have effect in tools/test.py, '
|
||||
'the evaluation and format results are CPU memory efficient by '
|
||||
'default')
|
||||
|
||||
eval_on_format_results = (
|
||||
args.eval is not None and 'cityscapes' in args.eval)
|
||||
if eval_on_format_results:
|
||||
assert len(args.eval) == 1, 'eval on format results is not ' \
|
||||
'applicable for metrics other than ' \
|
||||
'cityscapes'
|
||||
if args.format_only or eval_on_format_results:
|
||||
if 'imgfile_prefix' in eval_kwargs:
|
||||
tmpdir = eval_kwargs['imgfile_prefix']
|
||||
else:
|
||||
tmpdir = '.format_cityscapes'
|
||||
eval_kwargs.setdefault('imgfile_prefix', tmpdir)
|
||||
mmcv.mkdir_or_exist(tmpdir)
|
||||
else:
|
||||
tmpdir = None
|
||||
|
||||
if not distributed:
|
||||
warnings.warn(
|
||||
'SyncBN is only supported with DDP. To be compatible with DP, '
|
||||
'we convert SyncBN to BN. Please use dist_train.sh which can '
|
||||
'avoid this error.')
|
||||
if not torch.cuda.is_available():
|
||||
assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
|
||||
'Please use MMCV >= 1.4.4 for CPU training!'
|
||||
model = revert_sync_batchnorm(model)
|
||||
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
|
||||
results = single_gpu_test(
|
||||
model,
|
||||
data_loader,
|
||||
args.show,
|
||||
args.show_dir,
|
||||
False,
|
||||
args.opacity,
|
||||
pre_eval=args.eval is not None and not eval_on_format_results,
|
||||
format_only=args.format_only or eval_on_format_results,
|
||||
format_args=eval_kwargs)
|
||||
else:
|
||||
model = MMDistributedDataParallel(
|
||||
model.cuda(),
|
||||
device_ids=[torch.cuda.current_device()],
|
||||
broadcast_buffers=False)
|
||||
results = multi_gpu_test(
|
||||
model,
|
||||
data_loader,
|
||||
args.tmpdir,
|
||||
args.gpu_collect,
|
||||
False,
|
||||
pre_eval=args.eval is not None and not eval_on_format_results,
|
||||
format_only=args.format_only or eval_on_format_results,
|
||||
format_args=eval_kwargs)
|
||||
|
||||
rank, _ = get_dist_info()
|
||||
if rank == 0:
|
||||
if args.out:
|
||||
warnings.warn(
|
||||
'The behavior of ``args.out`` has been changed since MMSeg '
|
||||
'v0.16, the pickled outputs could be seg map as type of '
|
||||
'np.array, pre-eval results or file paths for '
|
||||
'``dataset.format_results()``.')
|
||||
print(f'\nwriting results to {args.out}')
|
||||
mmcv.dump(results, args.out)
|
||||
if args.eval:
|
||||
eval_kwargs.update(metric=args.eval)
|
||||
metric = dataset.evaluate(results, **eval_kwargs)
|
||||
metric_dict = dict(config=args.config, metric=metric)
|
||||
mmcv.dump(metric_dict, json_file, indent=4)
|
||||
if tmpdir is not None and eval_on_format_results:
|
||||
# remove tmp dir when cityscapes evaluation
|
||||
shutil.rmtree(tmpdir)
|
||||
# start testing
|
||||
runner.test()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
197
tools/train.py
197
tools/train.py
|
@ -1,75 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||
from mmcv.runner import get_dist_info, init_dist
|
||||
from mmcv.utils import Config, DictAction, get_git_hash
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmseg import __version__
|
||||
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
|
||||
from mmseg.datasets import build_dataset
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Train a segmentor')
|
||||
parser.add_argument('config', help='train config file path')
|
||||
parser.add_argument('--work-dir', help='the dir to save logs and models')
|
||||
parser.add_argument(
|
||||
'--load-from', help='the checkpoint file to load weights from')
|
||||
parser.add_argument(
|
||||
'--resume-from', help='the checkpoint file to resume from')
|
||||
parser.add_argument(
|
||||
'--no-validate',
|
||||
action='store_true',
|
||||
help='whether not to evaluate the checkpoint during training')
|
||||
group_gpus = parser.add_mutually_exclusive_group()
|
||||
group_gpus.add_argument(
|
||||
'--gpus',
|
||||
type=int,
|
||||
help='(Deprecated, please use --gpu-id) number of gpus to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
group_gpus.add_argument(
|
||||
'--gpu-ids',
|
||||
type=int,
|
||||
nargs='+',
|
||||
help='(Deprecated, please use --gpu-id) ids of gpus to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
group_gpus.add_argument(
|
||||
'--gpu-id',
|
||||
type=int,
|
||||
default=0,
|
||||
help='id of gpu to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
||||
parser.add_argument(
|
||||
'--diff_seed',
|
||||
action='store_true',
|
||||
help='Whether or not set different seeds for different ranks')
|
||||
parser.add_argument(
|
||||
'--deterministic',
|
||||
action='store_true',
|
||||
help='whether to set deterministic options for CUDNN backend.')
|
||||
parser.add_argument(
|
||||
'--options',
|
||||
nargs='+',
|
||||
action=DictAction,
|
||||
help="--options is deprecated in favor of --cfg_options' and it will "
|
||||
'not be supported in version v0.22.0. Override some settings in the '
|
||||
'used config, the key-value pair in xxx=yyy format will be merged '
|
||||
'into config file. If the value to be overwritten is a list, it '
|
||||
'should be like key="[a,b]" or key=a,b It also allows nested '
|
||||
'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
|
||||
'marks are necessary and that no white space is allowed.')
|
||||
parser.add_argument(
|
||||
'--cfg-options',
|
||||
nargs='+',
|
||||
|
@ -86,38 +29,26 @@ def parse_args():
|
|||
default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
parser.add_argument(
|
||||
'--auto-resume',
|
||||
action='store_true',
|
||||
help='resume from the latest checkpoint automatically.')
|
||||
args = parser.parse_args()
|
||||
if 'LOCAL_RANK' not in os.environ:
|
||||
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
||||
|
||||
if args.options and args.cfg_options:
|
||||
raise ValueError(
|
||||
'--options and --cfg-options cannot be both '
|
||||
'specified, --options is deprecated in favor of --cfg-options. '
|
||||
'--options will not be supported in version v0.22.0.')
|
||||
if args.options:
|
||||
warnings.warn('--options is deprecated in favor of --cfg-options. '
|
||||
'--options will not be supported in version v0.22.0.')
|
||||
args.cfg_options = args.options
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmseg into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
if args.cfg_options is not None:
|
||||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# set cudnn_benchmark
|
||||
if cfg.get('cudnn_benchmark', False):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# work_dir is determined in this priority: CLI > segment in file > filename
|
||||
if args.work_dir is not None:
|
||||
# update configs according to CLI args if args.work_dir is not None
|
||||
|
@ -126,114 +57,12 @@ def main():
|
|||
# use config filename as default work_dir if cfg.work_dir is None
|
||||
cfg.work_dir = osp.join('./work_dirs',
|
||||
osp.splitext(osp.basename(args.config))[0])
|
||||
if args.load_from is not None:
|
||||
cfg.load_from = args.load_from
|
||||
if args.resume_from is not None:
|
||||
cfg.resume_from = args.resume_from
|
||||
if args.gpus is not None:
|
||||
cfg.gpu_ids = range(1)
|
||||
warnings.warn('`--gpus` is deprecated because we only support '
|
||||
'single GPU mode in non-distributed training. '
|
||||
'Use `gpus=1` now.')
|
||||
if args.gpu_ids is not None:
|
||||
cfg.gpu_ids = args.gpu_ids[0:1]
|
||||
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
|
||||
'Because we only support single GPU mode in '
|
||||
'non-distributed training. Use the first GPU '
|
||||
'in `gpu_ids` now.')
|
||||
if args.gpus is None and args.gpu_ids is None:
|
||||
cfg.gpu_ids = [args.gpu_id]
|
||||
|
||||
cfg.auto_resume = args.auto_resume
|
||||
# build the runner from config
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
# init distributed env first, since logger depends on the dist info.
|
||||
if args.launcher == 'none':
|
||||
distributed = False
|
||||
else:
|
||||
distributed = True
|
||||
init_dist(args.launcher, **cfg.dist_params)
|
||||
# gpu_ids is used to calculate iter when resuming checkpoint
|
||||
_, world_size = get_dist_info()
|
||||
cfg.gpu_ids = range(world_size)
|
||||
|
||||
# create work_dir
|
||||
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
|
||||
# dump config
|
||||
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
|
||||
# init the logger before other steps
|
||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
|
||||
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
|
||||
|
||||
# set multi-process settings
|
||||
setup_multi_processes(cfg)
|
||||
|
||||
# init the meta dict to record some important information such as
|
||||
# environment info and seed, which will be logged
|
||||
meta = dict()
|
||||
# log env info
|
||||
env_info_dict = collect_env()
|
||||
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
|
||||
dash_line = '-' * 60 + '\n'
|
||||
logger.info('Environment info:\n' + dash_line + env_info + '\n' +
|
||||
dash_line)
|
||||
meta['env_info'] = env_info
|
||||
|
||||
# log some basic info
|
||||
logger.info(f'Distributed training: {distributed}')
|
||||
logger.info(f'Config:\n{cfg.pretty_text}')
|
||||
|
||||
# set random seeds
|
||||
seed = init_random_seed(args.seed)
|
||||
seed = seed + dist.get_rank() if args.diff_seed else seed
|
||||
logger.info(f'Set random seed to {seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(seed, deterministic=args.deterministic)
|
||||
cfg.seed = seed
|
||||
meta['seed'] = seed
|
||||
meta['exp_name'] = osp.basename(args.config)
|
||||
|
||||
model = build_segmentor(
|
||||
cfg.model,
|
||||
train_cfg=cfg.get('train_cfg'),
|
||||
test_cfg=cfg.get('test_cfg'))
|
||||
model.init_weights()
|
||||
|
||||
# SyncBN is not support for DP
|
||||
if not distributed:
|
||||
warnings.warn(
|
||||
'SyncBN is only supported with DDP. To be compatible with DP, '
|
||||
'we convert SyncBN to BN. Please use dist_train.sh which can '
|
||||
'avoid this error.')
|
||||
model = revert_sync_batchnorm(model)
|
||||
|
||||
logger.info(model)
|
||||
|
||||
datasets = [build_dataset(cfg.data.train)]
|
||||
if len(cfg.workflow) == 2:
|
||||
val_dataset = copy.deepcopy(cfg.data.val)
|
||||
val_dataset.pipeline = cfg.data.train.pipeline
|
||||
datasets.append(build_dataset(val_dataset))
|
||||
if cfg.checkpoint_config is not None:
|
||||
# save mmseg version, config file content and class names in
|
||||
# checkpoints as meta data
|
||||
cfg.checkpoint_config.meta = dict(
|
||||
mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
|
||||
config=cfg.pretty_text,
|
||||
CLASSES=datasets[0].CLASSES,
|
||||
PALETTE=datasets[0].PALETTE)
|
||||
# add an attribute for visualization convenience
|
||||
model.CLASSES = datasets[0].CLASSES
|
||||
# passing checkpoint meta for saving best checkpoint
|
||||
meta.update(cfg.checkpoint_config.meta)
|
||||
train_segmentor(
|
||||
model,
|
||||
datasets,
|
||||
cfg,
|
||||
distributed=distributed,
|
||||
validate=(not args.no_validate),
|
||||
timestamp=timestamp,
|
||||
meta=meta)
|
||||
# start training
|
||||
runner.train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue