[Enchance] Set a random seed when the user does not set a seed. (#1039)

pull/1801/head
Junjun2016 2021-11-26 10:55:15 +08:00 committed by GitHub
parent fdc054614c
commit 341ecc7633
3 changed files with 43 additions and 10 deletions

View File

@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_segmentor, init_segmentor, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test
from .train import get_root_logger, set_random_seed, train_segmentor
from .train import (get_root_logger, init_random_seed, set_random_seed,
train_segmentor)
__all__ = [
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
'show_result_pyplot'
'show_result_pyplot', 'init_random_seed'
]

View File

@ -4,8 +4,9 @@ import warnings
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import HOOKS, build_optimizer, build_runner
from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
from mmcv.utils import build_from_cfg
from mmseg.core import DistEvalHook, EvalHook
@ -13,6 +14,37 @@ from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger
def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
def set_random_seed(seed, deterministic=False):
"""Set random seed.

View File

@ -13,7 +13,7 @@ from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash
from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger
@ -125,12 +125,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, deterministic: '
f'{args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)
model = build_segmentor(