diff --git a/mmseg/apis/__init__.py b/mmseg/apis/__init__.py index ba5ab7736..c68818053 100644 --- a/mmseg/apis/__init__.py +++ b/mmseg/apis/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import inference_segmentor, init_segmentor, show_result_pyplot from .test import multi_gpu_test, single_gpu_test -from .train import get_root_logger, set_random_seed, train_segmentor +from .train import (get_root_logger, init_random_seed, set_random_seed, + train_segmentor) __all__ = [ 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', - 'show_result_pyplot' + 'show_result_pyplot', 'init_random_seed' ] diff --git a/mmseg/apis/train.py b/mmseg/apis/train.py index 811c98310..7e1096bce 100644 --- a/mmseg/apis/train.py +++ b/mmseg/apis/train.py @@ -4,8 +4,9 @@ import warnings import numpy as np import torch +import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import HOOKS, build_optimizer, build_runner +from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info from mmcv.utils import build_from_cfg from mmseg.core import DistEvalHook, EvalHook @@ -13,6 +14,37 @@ from mmseg.datasets import build_dataloader, build_dataset from mmseg.utils import get_root_logger +def init_random_seed(seed=None, device='cuda'): + """Initialize random seed. + + If the seed is not set, the seed will be automatically randomized, + and then broadcast to all processes to prevent some potential bugs. + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to prevent + # some potential bugs. Please refer to + # https://github.com/open-mmlab/mmdetection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2**31) + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + def set_random_seed(seed, deterministic=False): """Set random seed. diff --git a/tools/train.py b/tools/train.py index 208ca5ee1..29ea15d5d 100644 --- a/tools/train.py +++ b/tools/train.py @@ -13,7 +13,7 @@ from mmcv.runner import get_dist_info, init_dist from mmcv.utils import Config, DictAction, get_git_hash from mmseg import __version__ -from mmseg.apis import set_random_seed, train_segmentor +from mmseg.apis import init_random_seed, set_random_seed, train_segmentor from mmseg.datasets import build_dataset from mmseg.models import build_segmentor from mmseg.utils import collect_env, get_root_logger @@ -125,12 +125,12 @@ def main(): logger.info(f'Config:\n{cfg.pretty_text}') # set random seeds - if args.seed is not None: - logger.info(f'Set random seed to {args.seed}, deterministic: ' - f'{args.deterministic}') - set_random_seed(args.seed, deterministic=args.deterministic) - cfg.seed = args.seed - meta['seed'] = args.seed + seed = init_random_seed(args.seed) + logger.info(f'Set random seed to {seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(seed, deterministic=args.deterministic) + cfg.seed = seed + meta['seed'] = seed meta['exp_name'] = osp.basename(args.config) model = build_segmentor(