[Fix] Always broadcast a random seed to all the processes (#600)

This commit is contained in:
Tong Gao 2021-11-18 22:26:21 +08:00 committed by GitHub
parent 284b5acc12
commit e267d06281
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 10 deletions

View File

@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import init_detector, model_inference
from .train import train_detector
from .train import init_random_seed, train_detector
__all__ = ['model_inference', 'train_detector', 'init_detector']
__all__ = [
'model_inference', 'train_detector', 'init_detector', 'init_random_seed'
]

View File

@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
Fp16OptimizerHook, OptimizerHook, build_optimizer,
build_runner)
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook
from mmdet.datasets import (build_dataloader, build_dataset,
@ -161,3 +163,33 @@ def train_detector(model,
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow)
def init_random_seed(seed=None, device='cuda'):
    """Choose the random seed and keep it identical across processes.

    A seed that is explicitly given is returned untouched. Otherwise a
    seed is drawn from NumPy's global random generator; when
    ``get_dist_info`` reports more than one process, the value drawn on
    rank 0 is broadcast and returned on every rank.

    Args:
        seed (int, Optional): The seed. Defaults to None, meaning a random
            one is generated.
        device (str): The device on which the tensor carrying the seed
            through the broadcast is created.

    Returns:
        int: Seed to be used.
    """
    # An explicit seed needs no synchronisation.
    if seed is not None:
        return seed

    # All ranks must end up with the same seed to prevent some potential
    # bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    drawn = np.random.randint(2**31)
    if world_size == 1:
        return drawn

    # Only rank 0's draw matters; the other ranks pass a placeholder that
    # the broadcast overwrites.
    payload = drawn if rank == 0 else 0
    seed_tensor = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(seed_tensor, src=0)
    return seed_tensor.item()

View File

@ -14,7 +14,7 @@ from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash
from mmocr import __version__
from mmocr.apis import train_detector
from mmocr.apis import init_random_seed, train_detector
from mmocr.datasets import build_dataset
from mmocr.models import build_detector
from mmocr.utils import collect_env, get_root_logger
@ -171,12 +171,12 @@ def main():
logger.info(f'Config:\n{cfg.pretty_text}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(args.seed, deterministic=args.deterministic)
cfg.seed = args.seed
meta['seed'] = args.seed
seed = init_random_seed(args.seed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
cfg.seed = seed
meta['seed'] = seed
meta['exp_name'] = osp.basename(args.config)
model = build_detector(