[Enchance] Set a random seed when the user does not set a seed. (#554)

pull/571/head^2
Ma Zerun 2021-12-02 18:09:55 +08:00 committed by GitHub
parent 78d6d8503f
commit d25a78d547
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 10 deletions

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model
from .train import init_random_seed, set_random_seed, train_model
__all__ = [
'set_random_seed', 'train_model', 'init_model', 'inference_model',
'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot'
'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot',
'init_random_seed'
]

View File

@ -4,8 +4,10 @@ import warnings
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner
from mmcv.runner import (DistSamplerSeedHook, build_optimizer, build_runner,
get_dist_info)
from mmcls.core import DistOptimizerHook
from mmcls.datasets import build_dataloader, build_dataset
@ -29,6 +31,39 @@ except ImportError:
from mmcls.core import Fp16OptimizerHook
def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
def set_random_seed(seed, deterministic=False):
"""Set random seed.

View File

@ -12,7 +12,7 @@ from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcls import __version__
from mmcls.apis import set_random_seed, train_model
from mmcls.apis import init_random_seed, set_random_seed, train_model
from mmcls.datasets import build_dataset
from mmcls.models import build_classifier
from mmcls.utils import collect_env, get_root_logger
@ -143,12 +143,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
model = build_classifier(cfg.model)
model.init_weights()