mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
[Fix] Always broadcast a random seed to all the processes (#600)
This commit is contained in:
parent
284b5acc12
commit
e267d06281
@ -1,5 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .inference import init_detector, model_inference
|
||||
from .train import train_detector
|
||||
from .train import init_random_seed, train_detector
|
||||
|
||||
__all__ = ['model_inference', 'train_detector', 'init_detector']
|
||||
__all__ = [
|
||||
'model_inference', 'train_detector', 'init_detector', 'init_random_seed'
|
||||
]
|
||||
|
@ -1,11 +1,13 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
|
||||
Fp16OptimizerHook, OptimizerHook, build_optimizer,
|
||||
build_runner)
|
||||
build_runner, get_dist_info)
|
||||
from mmcv.utils import build_from_cfg
|
||||
from mmdet.core import DistEvalHook, EvalHook
|
||||
from mmdet.datasets import (build_dataloader, build_dataset,
|
||||
@ -161,3 +163,33 @@ def train_detector(model,
|
||||
elif cfg.load_from:
|
||||
runner.load_checkpoint(cfg.load_from)
|
||||
runner.run(data_loaders, cfg.workflow)
|
||||
|
||||
|
||||
def init_random_seed(seed=None, device='cuda'):
|
||||
"""Initialize random seed. If the seed is None, it will be replaced by a
|
||||
random number, and then broadcasted to all processes.
|
||||
|
||||
Args:
|
||||
seed (int, Optional): The seed.
|
||||
device (str): The device where the seed will be put on.
|
||||
|
||||
Returns:
|
||||
int: Seed to be used.
|
||||
"""
|
||||
if seed is not None:
|
||||
return seed
|
||||
|
||||
# Make sure all ranks share the same random seed to prevent
|
||||
# some potential bugs. Please refer to
|
||||
# https://github.com/open-mmlab/mmdetection/issues/6339
|
||||
rank, world_size = get_dist_info()
|
||||
seed = np.random.randint(2**31)
|
||||
if world_size == 1:
|
||||
return seed
|
||||
|
||||
if rank == 0:
|
||||
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
|
||||
else:
|
||||
random_num = torch.tensor(0, dtype=torch.int32, device=device)
|
||||
dist.broadcast(random_num, src=0)
|
||||
return random_num.item()
|
||||
|
@ -14,7 +14,7 @@ from mmcv.runner import get_dist_info, init_dist, set_random_seed
|
||||
from mmcv.utils import get_git_hash
|
||||
|
||||
from mmocr import __version__
|
||||
from mmocr.apis import train_detector
|
||||
from mmocr.apis import init_random_seed, train_detector
|
||||
from mmocr.datasets import build_dataset
|
||||
from mmocr.models import build_detector
|
||||
from mmocr.utils import collect_env, get_root_logger
|
||||
@ -171,12 +171,12 @@ def main():
|
||||
logger.info(f'Config:\n{cfg.pretty_text}')
|
||||
|
||||
# set random seeds
|
||||
if args.seed is not None:
|
||||
logger.info(f'Set random seed to {args.seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(args.seed, deterministic=args.deterministic)
|
||||
cfg.seed = args.seed
|
||||
meta['seed'] = args.seed
|
||||
seed = init_random_seed(args.seed)
|
||||
logger.info(f'Set random seed to {seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(seed, deterministic=args.deterministic)
|
||||
cfg.seed = seed
|
||||
meta['seed'] = seed
|
||||
meta['exp_name'] = osp.basename(args.config)
|
||||
|
||||
model = build_detector(
|
||||
|
Loading…
x
Reference in New Issue
Block a user