mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature]: Add diff seeds to diff ranks and set torch seed in worker_init_fn (#1362)
This commit is contained in:
parent
98b8ed37e4
commit
da6bb2c8c5
@ -186,3 +186,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
|
|||||||
worker_seed = num_workers * rank + worker_id + seed
|
worker_seed = num_workers * rank + worker_id + seed
|
||||||
np.random.seed(worker_seed)
|
np.random.seed(worker_seed)
|
||||||
random.seed(worker_seed)
|
random.seed(worker_seed)
|
||||||
|
torch.manual_seed(worker_seed)
|
||||||
|
@ -8,6 +8,7 @@ import warnings
|
|||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from mmcv.cnn.utils import revert_sync_batchnorm
|
from mmcv.cnn.utils import revert_sync_batchnorm
|
||||||
from mmcv.runner import get_dist_info, init_dist
|
from mmcv.runner import get_dist_info, init_dist
|
||||||
from mmcv.utils import Config, DictAction, get_git_hash
|
from mmcv.utils import Config, DictAction, get_git_hash
|
||||||
@ -50,6 +51,10 @@ def parse_args():
|
|||||||
help='id of gpu to use '
|
help='id of gpu to use '
|
||||||
'(only applicable to non-distributed training)')
|
'(only applicable to non-distributed training)')
|
||||||
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
parser.add_argument('--seed', type=int, default=None, help='random seed')
|
||||||
|
parser.add_argument(
|
||||||
|
'--diff_seed',
|
||||||
|
action='store_true',
|
||||||
|
help='Whether or not set different seeds for different ranks')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--deterministic',
|
'--deterministic',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@ -180,6 +185,7 @@ def main():
|
|||||||
|
|
||||||
# set random seeds
|
# set random seeds
|
||||||
seed = init_random_seed(args.seed)
|
seed = init_random_seed(args.seed)
|
||||||
|
seed = seed + dist.get_rank() if args.diff_seed else seed
|
||||||
logger.info(f'Set random seed to {seed}, '
|
logger.info(f'Set random seed to {seed}, '
|
||||||
f'deterministic: {args.deterministic}')
|
f'deterministic: {args.deterministic}')
|
||||||
set_random_seed(seed, deterministic=args.deterministic)
|
set_random_seed(seed, deterministic=args.deterministic)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user