mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Enchance] Set a random seed when the user does not set a seed. (#554)
This commit is contained in:
parent
78d6d8503f
commit
d25a78d547
@ -1,9 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .inference import inference_model, init_model, show_result_pyplot
|
from .inference import inference_model, init_model, show_result_pyplot
|
||||||
from .test import multi_gpu_test, single_gpu_test
|
from .test import multi_gpu_test, single_gpu_test
|
||||||
from .train import set_random_seed, train_model
|
from .train import init_random_seed, set_random_seed, train_model
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'set_random_seed', 'train_model', 'init_model', 'inference_model',
|
'set_random_seed', 'train_model', 'init_model', 'inference_model',
|
||||||
'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot'
|
'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot',
|
||||||
|
'init_random_seed'
|
||||||
]
|
]
|
||||||
|
@ -4,8 +4,10 @@ import warnings
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||||
from mmcv.runner import DistSamplerSeedHook, build_optimizer, build_runner
|
from mmcv.runner import (DistSamplerSeedHook, build_optimizer, build_runner,
|
||||||
|
get_dist_info)
|
||||||
|
|
||||||
from mmcls.core import DistOptimizerHook
|
from mmcls.core import DistOptimizerHook
|
||||||
from mmcls.datasets import build_dataloader, build_dataset
|
from mmcls.datasets import build_dataloader, build_dataset
|
||||||
@ -29,6 +31,39 @@ except ImportError:
|
|||||||
from mmcls.core import Fp16OptimizerHook
|
from mmcls.core import Fp16OptimizerHook
|
||||||
|
|
||||||
|
|
||||||
|
def init_random_seed(seed=None, device='cuda'):
|
||||||
|
"""Initialize random seed.
|
||||||
|
|
||||||
|
If the seed is not set, the seed will be automatically randomized,
|
||||||
|
and then broadcast to all processes to prevent some potential bugs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int, Optional): The seed. Default to None.
|
||||||
|
device (str): The device where the seed will be put on.
|
||||||
|
Default to 'cuda'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Seed to be used.
|
||||||
|
"""
|
||||||
|
if seed is not None:
|
||||||
|
return seed
|
||||||
|
|
||||||
|
# Make sure all ranks share the same random seed to prevent
|
||||||
|
# some potential bugs. Please refer to
|
||||||
|
# https://github.com/open-mmlab/mmdetection/issues/6339
|
||||||
|
rank, world_size = get_dist_info()
|
||||||
|
seed = np.random.randint(2**31)
|
||||||
|
if world_size == 1:
|
||||||
|
return seed
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
|
||||||
|
else:
|
||||||
|
random_num = torch.tensor(0, dtype=torch.int32, device=device)
|
||||||
|
dist.broadcast(random_num, src=0)
|
||||||
|
return random_num.item()
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed, deterministic=False):
|
def set_random_seed(seed, deterministic=False):
|
||||||
"""Set random seed.
|
"""Set random seed.
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from mmcv import Config, DictAction
|
|||||||
from mmcv.runner import get_dist_info, init_dist
|
from mmcv.runner import get_dist_info, init_dist
|
||||||
|
|
||||||
from mmcls import __version__
|
from mmcls import __version__
|
||||||
from mmcls.apis import set_random_seed, train_model
|
from mmcls.apis import init_random_seed, set_random_seed, train_model
|
||||||
from mmcls.datasets import build_dataset
|
from mmcls.datasets import build_dataset
|
||||||
from mmcls.models import build_classifier
|
from mmcls.models import build_classifier
|
||||||
from mmcls.utils import collect_env, get_root_logger
|
from mmcls.utils import collect_env, get_root_logger
|
||||||
@ -143,12 +143,12 @@ def main():
|
|||||||
logger.info(f'Config:\n{cfg.pretty_text}')
|
logger.info(f'Config:\n{cfg.pretty_text}')
|
||||||
|
|
||||||
# set random seeds
|
# set random seeds
|
||||||
if args.seed is not None:
|
seed = init_random_seed(args.seed)
|
||||||
logger.info(f'Set random seed to {args.seed}, '
|
logger.info(f'Set random seed to {seed}, '
|
||||||
f'deterministic: {args.deterministic}')
|
f'deterministic: {args.deterministic}')
|
||||||
set_random_seed(args.seed, deterministic=args.deterministic)
|
set_random_seed(seed, deterministic=args.deterministic)
|
||||||
cfg.seed = args.seed
|
cfg.seed = seed
|
||||||
meta['seed'] = args.seed
|
meta['seed'] = seed
|
||||||
|
|
||||||
model = build_classifier(cfg.model)
|
model = build_classifier(cfg.model)
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user