[Feature] Add diff seeds to diff ranks and set torch seed in worker_init_fn (#113)
* add init_random_seed * Set diff seed to diff workerspull/115/head
parent
20d1e0b0d8
commit
09800c7d85
|
@ -2,4 +2,6 @@
|
|||
from .mmcls import * # noqa: F401,F403
|
||||
from .mmdet import * # noqa: F401,F403
|
||||
from .mmseg import * # noqa: F401,F403
|
||||
from .utils import set_random_seed # noqa: F401
|
||||
from .utils import init_random_seed, set_random_seed # noqa: F401
|
||||
|
||||
__all__ = ['init_random_seed', 'set_random_seed']
|
||||
|
|
|
@ -3,6 +3,39 @@ import random
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import get_dist_info
|
||||
|
||||
|
||||
def init_random_seed(seed=None, device='cuda'):
|
||||
"""Initialize random seed.
|
||||
|
||||
If the seed is not set, the seed will be automatically randomized,
|
||||
and then broadcast to all processes to prevent some potential bugs.
|
||||
Args:
|
||||
seed (int, Optional): The seed. Default to None.
|
||||
device (str): The device where the seed will be put on.
|
||||
Default to 'cuda'.
|
||||
Returns:
|
||||
int: Seed to be used.
|
||||
"""
|
||||
if seed is not None:
|
||||
return seed
|
||||
|
||||
# Make sure all ranks share the same random seed to prevent
|
||||
# some potential bugs. Please refer to
|
||||
# https://github.com/open-mmlab/mmdetection/issues/6339
|
||||
rank, world_size = get_dist_info()
|
||||
seed = np.random.randint(2**31)
|
||||
if world_size == 1:
|
||||
return seed
|
||||
|
||||
if rank == 0:
|
||||
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
|
||||
else:
|
||||
random_num = torch.tensor(0, dtype=torch.int32, device=device)
|
||||
dist.broadcast(random_num, src=0)
|
||||
return random_num.item()
|
||||
|
||||
|
||||
def set_random_seed(seed: int, deterministic: bool = False) -> None:
|
||||
|
|
|
@ -8,6 +8,7 @@ import warnings
|
|||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcls import __version__
|
||||
from mmcls.datasets import build_dataset
|
||||
from mmcls.utils import collect_env, get_root_logger
|
||||
|
@ -15,7 +16,7 @@ from mmcv import Config, DictAction
|
|||
from mmcv.runner import get_dist_info, init_dist
|
||||
|
||||
# Differences from mmclassification
|
||||
from mmrazor.apis import set_random_seed, train_mmcls_model
|
||||
from mmrazor.apis import init_random_seed, set_random_seed, train_mmcls_model
|
||||
from mmrazor.models import build_algorithm
|
||||
from mmrazor.utils import setup_multi_processes
|
||||
|
||||
|
@ -54,6 +55,10 @@ def parse_args():
|
|||
help='id of gpu to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
||||
parser.add_argument(
|
||||
'--diff_seed',
|
||||
action='store_true',
|
||||
help='Whether or not set different seeds for different ranks')
|
||||
parser.add_argument(
|
||||
'--deterministic',
|
||||
action='store_true',
|
||||
|
@ -154,12 +159,14 @@ def main():
|
|||
logger.info(f'Config:\n{cfg.pretty_text}')
|
||||
|
||||
# set random seeds
|
||||
if args.seed is not None:
|
||||
logger.info(f'Set random seed to {args.seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(args.seed, deterministic=args.deterministic)
|
||||
cfg.seed = args.seed
|
||||
meta['seed'] = args.seed
|
||||
seed = init_random_seed(args.seed)
|
||||
seed = seed + dist.get_rank() if args.diff_seed else seed
|
||||
logger.info(f'Set random seed to {seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(seed, deterministic=args.deterministic)
|
||||
cfg.seed = seed
|
||||
meta['seed'] = seed
|
||||
meta['exp_name'] = osp.basename(args.config)
|
||||
|
||||
# Difference from mmclassification
|
||||
# replace `model` to `algorithm`
|
||||
|
|
|
@ -16,6 +16,7 @@ import warnings
|
|||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv import Config, DictAction
|
||||
from mmcv.runner import get_dist_info, init_dist
|
||||
from mmcv.utils import get_git_hash
|
||||
|
@ -23,7 +24,7 @@ from mmdet import __version__
|
|||
from mmdet.datasets import build_dataset
|
||||
from mmdet.utils import collect_env, get_root_logger
|
||||
|
||||
from mmrazor.apis import set_random_seed, train_mmdet_model
|
||||
from mmrazor.apis import init_random_seed, set_random_seed, train_mmdet_model
|
||||
from mmrazor.models import build_algorithm
|
||||
from mmrazor.utils import setup_multi_processes
|
||||
|
||||
|
@ -61,6 +62,10 @@ def parse_args():
|
|||
help='id of gpu to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
||||
parser.add_argument(
|
||||
'--diff_seed',
|
||||
action='store_true',
|
||||
help='Whether or not set different seeds for different ranks')
|
||||
parser.add_argument(
|
||||
'--deterministic',
|
||||
action='store_true',
|
||||
|
@ -166,12 +171,13 @@ def main():
|
|||
logger.info(f'Config:\n{cfg.pretty_text}')
|
||||
|
||||
# set random seeds
|
||||
if args.seed is not None:
|
||||
logger.info(f'Set random seed to {args.seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(args.seed, deterministic=args.deterministic)
|
||||
cfg.seed = args.seed
|
||||
meta['seed'] = args.seed
|
||||
seed = init_random_seed(args.seed)
|
||||
seed = seed + dist.get_rank() if args.diff_seed else seed
|
||||
logger.info(f'Set random seed to {seed}, '
|
||||
f'deterministic: {args.deterministic}')
|
||||
set_random_seed(seed, deterministic=args.deterministic)
|
||||
cfg.seed = seed
|
||||
meta['seed'] = seed
|
||||
meta['exp_name'] = osp.basename(args.config)
|
||||
|
||||
algorithm = build_algorithm(cfg.algorithm)
|
||||
|
|
|
@ -16,6 +16,7 @@ import warnings
|
|||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||
from mmcv.runner import get_dist_info, init_dist
|
||||
from mmcv.utils import Config, DictAction, get_git_hash
|
||||
|
@ -24,7 +25,7 @@ from mmseg.datasets import build_dataset
|
|||
from mmseg.utils import collect_env, get_root_logger
|
||||
|
||||
# Differences from mmdetection
|
||||
from mmrazor.apis import set_random_seed, train_mmseg_model
|
||||
from mmrazor.apis import init_random_seed, set_random_seed, train_mmseg_model
|
||||
from mmrazor.models.builder import build_algorithm
|
||||
from mmrazor.utils import setup_multi_processes
|
||||
|
||||
|
@ -64,6 +65,10 @@ def parse_args():
|
|||
help='id of gpu to use '
|
||||
'(only applicable to non-distributed training)')
|
||||
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
||||
parser.add_argument(
|
||||
'--diff_seed',
|
||||
action='store_true',
|
||||
help='Whether or not set different seeds for different ranks')
|
||||
parser.add_argument(
|
||||
'--deterministic',
|
||||
action='store_true',
|
||||
|
@ -168,11 +173,12 @@ def main():
|
|||
logger.info(f'Config:\n{cfg.pretty_text}')
|
||||
|
||||
# set random seeds
|
||||
if args.seed is not None:
|
||||
logger.info(f'Set random seed to {args.seed}, deterministic: '
|
||||
f'{args.deterministic}')
|
||||
set_random_seed(args.seed, deterministic=args.deterministic)
|
||||
cfg.seed = args.seed
|
||||
seed = init_random_seed(args.seed)
|
||||
seed = seed + dist.get_rank() if args.diff_seed else seed
|
||||
logger.info(f'Set random seed to {args.seed}, deterministic: '
|
||||
f'{args.deterministic}')
|
||||
set_random_seed(args.seed, deterministic=args.deterministic)
|
||||
cfg.seed = seed
|
||||
meta['seed'] = args.seed
|
||||
meta['exp_name'] = osp.basename(args.config)
|
||||
|
||||
|
|
Loading…
Reference in New Issue