[Feature] Add worker_init_fn (#1788)

* add worker_init_fn

* "Fix as comment"

* Fix format

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/1802/head
Mashiro 2022-03-15 11:14:06 +08:00 committed by GitHub
parent e8cf961324
commit 2eb0a10d5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 1 deletions

View File

@ -50,6 +50,7 @@ else:
is_rocm_pytorch)
# yapf: enable
from .registry import Registry, build_from_cfg
from .seed import worker_init_fn
from .trace import is_jit_tracing
__all__ = [
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
@ -70,5 +71,5 @@ else:
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
'assert_params_all_zeros', 'check_python_script',
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
'_get_cuda_home', 'load_url', 'has_method'
'_get_cuda_home', 'load_url', 'has_method', 'worker_init_fn'
]

23
mmcv/utils/seed.py 100644
View File

@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
import numpy as np
import torch
def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
"""Function to initialize each worker.
The seed of each worker equals to
``num_worker * rank + worker_id + user_seed``.
Args:
worker_id (int): Id for each worker.
num_workers (int): Number of workers.
rank (int): Rank in distributed training.
seed (int): Random seed.
"""
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)