[Enchance] Set a random seed when the user does not set a seed. (#554)

This commit is contained in:
Ma Zerun 2021-12-02 18:09:55 +08:00 committed by GitHub
parent 78d6d8503f
commit d25a78d547
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 10 deletions

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model, show_result_pyplot from .inference import inference_model, init_model, show_result_pyplot
from .test import multi_gpu_test, single_gpu_test from .test import multi_gpu_test, single_gpu_test
from .train import set_random_seed, train_model from .train import init_random_seed, set_random_seed, train_model
__all__ = [ __all__ = [
'set_random_seed', 'train_model', 'init_model', 'inference_model', 'set_random_seed', 'train_model', 'init_model', 'inference_model',
'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot' 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot',
'init_random_seed'
] ]

View File

@ -4,8 +4,10 @@ import warnings
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner from mmcv.runner import (DistSamplerSeedHook, build_optimizer, build_runner,
get_dist_info)
from mmcls.core import DistOptimizerHook from mmcls.core import DistOptimizerHook
from mmcls.datasets import build_dataloader, build_dataset from mmcls.datasets import build_dataloader, build_dataset
@ -29,6 +31,39 @@ except ImportError:
from mmcls.core import Fp16OptimizerHook from mmcls.core import Fp16OptimizerHook
def init_random_seed(seed=None, device='cuda'):
"""Initialize random seed.
If the seed is not set, the seed will be automatically randomized,
and then broadcast to all processes to prevent some potential bugs.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is not None:
return seed
# Make sure all ranks share the same random seed to prevent
# some potential bugs. Please refer to
# https://github.com/open-mmlab/mmdetection/issues/6339
rank, world_size = get_dist_info()
seed = np.random.randint(2**31)
if world_size == 1:
return seed
if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
def set_random_seed(seed, deterministic=False): def set_random_seed(seed, deterministic=False):
"""Set random seed. """Set random seed.

View File

@ -12,7 +12,7 @@ from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist from mmcv.runner import get_dist_info, init_dist
from mmcls import __version__ from mmcls import __version__
from mmcls.apis import set_random_seed, train_model from mmcls.apis import init_random_seed, set_random_seed, train_model
from mmcls.datasets import build_dataset from mmcls.datasets import build_dataset
from mmcls.models import build_classifier from mmcls.models import build_classifier
from mmcls.utils import collect_env, get_root_logger from mmcls.utils import collect_env, get_root_logger
@ -143,12 +143,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}') logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds # set random seeds
if args.seed is not None: seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {args.seed}, ' logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}') f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic) set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = args.seed cfg.seed = seed
meta['seed'] = args.seed meta['seed'] = seed
model = build_classifier(cfg.model) model = build_classifier(cfg.model)
model.init_weights() model.init_weights()