From 2eb0a10d5a32f56174369d1a4c70ee18f2136345 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Tue, 15 Mar 2022 11:14:06 +0800 Subject: [PATCH] [Feature] Add worker_init_fn (#1788) * add worker_init_fn * "Fix as comment" * Fix format Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- mmcv/utils/__init__.py | 3 ++- mmcv/utils/seed.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 mmcv/utils/seed.py diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 019e77f56..8159c6a1a 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -50,6 +50,7 @@ else: is_rocm_pytorch) # yapf: enable from .registry import Registry, build_from_cfg + from .seed import worker_init_fn from .trace import is_jit_tracing __all__ = [ 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', @@ -70,5 +71,5 @@ else: 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', 'assert_params_all_zeros', 'check_python_script', 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'load_url', 'has_method' + '_get_cuda_home', 'load_url', 'has_method', 'worker_init_fn' ] diff --git a/mmcv/utils/seed.py b/mmcv/utils/seed.py new file mode 100644 index 000000000..003f92367 --- /dev/null +++ b/mmcv/utils/seed.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random + +import numpy as np +import torch + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): + """Function to initialize each worker. + + The seed of each worker equals to + ``num_worker * rank + worker_id + user_seed``. + + Args: + worker_id (int): Id for each worker. + num_workers (int): Number of workers. + rank (int): Rank in distributed training. + seed (int): Random seed. + """ + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) + torch.manual_seed(worker_seed)