diff --git a/mmcls/apis/__init__.py b/mmcls/apis/__init__.py index 7dc58a97..b632f2a9 100644 --- a/mmcls/apis/__init__.py +++ b/mmcls/apis/__init__.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .inference import inference_model, init_model, show_result_pyplot from .test import multi_gpu_test, single_gpu_test -from .train import set_random_seed, train_model +from .train import init_random_seed, set_random_seed, train_model __all__ = [ 'set_random_seed', 'train_model', 'init_model', 'inference_model', - 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot' + 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot', + 'init_random_seed' ] diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py index b3546c04..603a39de 100644 --- a/mmcls/apis/train.py +++ b/mmcls/apis/train.py @@ -4,8 +4,10 @@ import warnings import numpy as np import torch +import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner +from mmcv.runner import (DistSamplerSeedHook, build_optimizer, build_runner, + get_dist_info) from mmcls.core import DistOptimizerHook from mmcls.datasets import build_dataloader, build_dataset @@ -29,6 +31,39 @@ except ImportError: from mmcls.core import Fp16OptimizerHook +def init_random_seed(seed=None, device='cuda'): + """Initialize random seed. + + If the seed is not set, the seed will be automatically randomized, + and then broadcast to all processes to prevent some potential bugs. + + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + + Returns: + int: Seed to be used. + """ + if seed is not None: + return seed + + # Make sure all ranks share the same random seed to prevent + # some potential bugs. Please refer to + # https://github.com/open-mmlab/mmdetection/issues/6339 + rank, world_size = get_dist_info() + seed = np.random.randint(2**31) + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + def set_random_seed(seed, deterministic=False): """Set random seed. diff --git a/tools/train.py b/tools/train.py index adff63d9..c9049a49 100644 --- a/tools/train.py +++ b/tools/train.py @@ -12,7 +12,7 @@ from mmcv import Config, DictAction from mmcv.runner import get_dist_info, init_dist from mmcls import __version__ -from mmcls.apis import set_random_seed, train_model +from mmcls.apis import init_random_seed, set_random_seed, train_model from mmcls.datasets import build_dataset from mmcls.models import build_classifier from mmcls.utils import collect_env, get_root_logger @@ -143,12 +143,12 @@ def main(): logger.info(f'Config:\n{cfg.pretty_text}') # set random seeds - if args.seed is not None: - logger.info(f'Set random seed to {args.seed}, ' - f'deterministic: {args.deterministic}') - set_random_seed(args.seed, deterministic=args.deterministic) - cfg.seed = args.seed - meta['seed'] = args.seed + seed = init_random_seed(args.seed) + logger.info(f'Set random seed to {seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(seed, deterministic=args.deterministic) + cfg.seed = seed + meta['seed'] = seed model = build_classifier(cfg.model) model.init_weights()