195 lines
6.9 KiB
Python
195 lines
6.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os
|
|
import random
|
|
import warnings
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
import torch
|
|
import torch.distributed as dist
|
|
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
|
|
build_runner, get_dist_info)
|
|
from mmcv.utils import build_from_cfg
|
|
|
|
from mmseg import digit_version
|
|
from mmseg.core import DistEvalHook, EvalHook, build_optimizer
|
|
from mmseg.datasets import build_dataloader, build_dataset
|
|
from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint,
|
|
get_root_logger)
|
|
|
|
|
|
def init_random_seed(seed=None, device='cuda'):
|
|
"""Initialize random seed.
|
|
|
|
If the seed is not set, the seed will be automatically randomized,
|
|
and then broadcast to all processes to prevent some potential bugs.
|
|
Args:
|
|
seed (int, Optional): The seed. Default to None.
|
|
device (str): The device where the seed will be put on.
|
|
Default to 'cuda'.
|
|
Returns:
|
|
int: Seed to be used.
|
|
"""
|
|
if seed is not None:
|
|
return seed
|
|
|
|
# Make sure all ranks share the same random seed to prevent
|
|
# some potential bugs. Please refer to
|
|
# https://github.com/open-mmlab/mmdetection/issues/6339
|
|
rank, world_size = get_dist_info()
|
|
seed = np.random.randint(2**31)
|
|
if world_size == 1:
|
|
return seed
|
|
|
|
if rank == 0:
|
|
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
|
|
else:
|
|
random_num = torch.tensor(0, dtype=torch.int32, device=device)
|
|
dist.broadcast(random_num, src=0)
|
|
return random_num.item()
|
|
|
|
|
|
def set_random_seed(seed, deterministic=False):
|
|
"""Set random seed.
|
|
|
|
Args:
|
|
seed (int): Seed to be used.
|
|
deterministic (bool): Whether to set the deterministic option for
|
|
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
|
|
to True and `torch.backends.cudnn.benchmark` to False.
|
|
Default: False.
|
|
"""
|
|
random.seed(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|
|
if deterministic:
|
|
torch.backends.cudnn.deterministic = True
|
|
torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def train_segmentor(model,
|
|
dataset,
|
|
cfg,
|
|
distributed=False,
|
|
validate=False,
|
|
timestamp=None,
|
|
meta=None):
|
|
"""Launch segmentor training."""
|
|
logger = get_root_logger(cfg.log_level)
|
|
|
|
# prepare data loaders
|
|
dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
|
|
# The default loader config
|
|
loader_cfg = dict(
|
|
# cfg.gpus will be ignored if distributed
|
|
num_gpus=len(cfg.gpu_ids),
|
|
dist=distributed,
|
|
seed=cfg.seed,
|
|
drop_last=True)
|
|
# The overall dataloader settings
|
|
loader_cfg.update({
|
|
k: v
|
|
for k, v in cfg.data.items() if k not in [
|
|
'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
|
|
'test_dataloader'
|
|
]
|
|
})
|
|
|
|
# The specific dataloader settings
|
|
train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
|
|
data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]
|
|
|
|
# put model on devices
|
|
if distributed:
|
|
find_unused_parameters = cfg.get('find_unused_parameters', False)
|
|
# Sets the `find_unused_parameters` parameter in
|
|
# DDP wrapper
|
|
model = build_ddp(
|
|
model,
|
|
cfg.device,
|
|
device_ids=[int(os.environ['LOCAL_RANK'])],
|
|
broadcast_buffers=False,
|
|
find_unused_parameters=find_unused_parameters)
|
|
else:
|
|
if not torch.cuda.is_available():
|
|
assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
|
|
'Please use MMCV >= 1.4.4 for CPU training!'
|
|
model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
|
|
|
|
# build runner
|
|
optimizer = build_optimizer(model, cfg.optimizer)
|
|
|
|
if cfg.get('runner') is None:
|
|
cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
|
|
warnings.warn(
|
|
'config is now expected to have a `runner` section, '
|
|
'please set `runner` in your config.', UserWarning)
|
|
|
|
runner = build_runner(
|
|
cfg.runner,
|
|
default_args=dict(
|
|
model=model,
|
|
batch_processor=None,
|
|
optimizer=optimizer,
|
|
work_dir=cfg.work_dir,
|
|
logger=logger,
|
|
meta=meta))
|
|
|
|
# register hooks
|
|
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
|
|
cfg.checkpoint_config, cfg.log_config,
|
|
cfg.get('momentum_config', None))
|
|
if distributed:
|
|
# when distributed training by epoch, using`DistSamplerSeedHook` to set
|
|
# the different seed to distributed sampler for each epoch, it will
|
|
# shuffle dataset at each epoch and avoid overfitting.
|
|
if isinstance(runner, EpochBasedRunner):
|
|
runner.register_hook(DistSamplerSeedHook())
|
|
|
|
# an ugly walkaround to make the .log and .log.json filenames the same
|
|
runner.timestamp = timestamp
|
|
|
|
# register eval hooks
|
|
if validate:
|
|
val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
|
|
# The specific dataloader settings
|
|
val_loader_cfg = {
|
|
**loader_cfg,
|
|
'samples_per_gpu': 1,
|
|
'shuffle': False, # Not shuffle by default
|
|
**cfg.data.get('val_dataloader', {}),
|
|
}
|
|
val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
|
|
eval_cfg = cfg.get('evaluation', {})
|
|
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
|
eval_hook = DistEvalHook if distributed else EvalHook
|
|
# In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
|
|
# priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
|
|
runner.register_hook(
|
|
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
|
|
|
|
# user-defined hooks
|
|
if cfg.get('custom_hooks', None):
|
|
custom_hooks = cfg.custom_hooks
|
|
assert isinstance(custom_hooks, list), \
|
|
f'custom_hooks expect list type, but got {type(custom_hooks)}'
|
|
for hook_cfg in cfg.custom_hooks:
|
|
assert isinstance(hook_cfg, dict), \
|
|
'Each item in custom_hooks expects dict type, but got ' \
|
|
f'{type(hook_cfg)}'
|
|
hook_cfg = hook_cfg.copy()
|
|
priority = hook_cfg.pop('priority', 'NORMAL')
|
|
hook = build_from_cfg(hook_cfg, HOOKS)
|
|
runner.register_hook(hook, priority=priority)
|
|
|
|
if cfg.resume_from is None and cfg.get('auto_resume'):
|
|
resume_from = find_latest_checkpoint(cfg.work_dir)
|
|
if resume_from is not None:
|
|
cfg.resume_from = resume_from
|
|
if cfg.resume_from:
|
|
runner.resume(cfg.resume_from)
|
|
elif cfg.load_from:
|
|
runner.load_checkpoint(cfg.load_from)
|
|
runner.run(data_loaders, cfg.workflow)
|