mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add worker_init_fn (#1788)
* add worker_init_fn * "Fix as comment" * Fix format Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/1802/head
parent
e8cf961324
commit
2eb0a10d5a
|
@ -50,6 +50,7 @@ else:
|
|||
is_rocm_pytorch)
|
||||
# yapf: enable
|
||||
from .registry import Registry, build_from_cfg
|
||||
from .seed import worker_init_fn
|
||||
from .trace import is_jit_tracing
|
||||
__all__ = [
|
||||
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
|
||||
|
@ -70,5 +71,5 @@ else:
|
|||
'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
|
||||
'assert_params_all_zeros', 'check_python_script',
|
||||
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
|
||||
'_get_cuda_home', 'load_url', 'has_method'
|
||||
'_get_cuda_home', 'load_url', 'has_method', 'worker_init_fn'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
|
||||
"""Function to initialize each worker.
|
||||
|
||||
The seed of each worker equals to
|
||||
``num_worker * rank + worker_id + user_seed``.
|
||||
|
||||
Args:
|
||||
worker_id (int): Id for each worker.
|
||||
num_workers (int): Number of workers.
|
||||
rank (int): Rank in distributed training.
|
||||
seed (int): Random seed.
|
||||
"""
|
||||
worker_seed = num_workers * rank + worker_id + seed
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
Loading…
Reference in New Issue