[Enchance] Set a random seed when the user does not set a seed. (#1039)

pull/1801/head
Junjun2016 2021-11-26 10:55:15 +08:00 committed by GitHub
parent fdc054614c
commit 341ecc7633
3 changed files with 43 additions and 10 deletions

View File

@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_segmentor, init_segmentor, show_result_pyplot from .inference import inference_segmentor, init_segmentor, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test from .test import multi_gpu_test, single_gpu_test
from .train import get_root_logger, set_random_seed, train_segmentor from .train import (get_root_logger, init_random_seed, set_random_seed,
train_segmentor)
__all__ = [ __all__ = [
'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
'show_result_pyplot' 'show_result_pyplot', 'init_random_seed'
] ]

View File

@ -4,8 +4,9 @@ import warnings
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import HOOKS, build_optimizer, build_runner from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
from mmcv.utils import build_from_cfg from mmcv.utils import build_from_cfg
from mmseg.core import DistEvalHook, EvalHook from mmseg.core import DistEvalHook, EvalHook
@ -13,6 +14,37 @@ from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
def set_random_seed(seed, deterministic=False): def set_random_seed(seed, deterministic=False):
"""Set random seed. """Set random seed.

View File

@ -13,7 +13,7 @@ from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash from mmcv.utils import Config, DictAction, get_git_hash
from mmseg import __version__ from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger from mmseg.utils import collect_env, get_root_logger
@ -125,12 +125,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}') logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds # set random seeds
if args.seed is not None: seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {args.seed}, deterministic: ' logger.info(f'Set random seed to {seed}, '
f'{args.deterministic}') f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic) set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = args.seed cfg.seed = seed
meta['seed'] = args.seed meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config) meta['exp_name'] = osp.basename(args.config)
model = build_segmentor( model = build_segmentor(