[Enhance] Disable warning of subprocess launched by dataloader (#870)

* Disable warning of subprocess launched by dataloader

* Add type hint
This commit is contained in:
Mashiro 2023-01-16 14:09:47 +08:00 committed by GitHub
parent 0b59a90a21
commit ad590e45a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 3 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings
from typing import Any, Mapping, Sequence from typing import Any, Mapping, Sequence
import numpy as np import numpy as np
@ -13,8 +14,11 @@ from mmengine.structures import BaseDataElement
COLLATE_FUNCTIONS = Registry('Collate Functions') COLLATE_FUNCTIONS = Registry('Collate Functions')
def worker_init_fn(worker_id: int, num_workers: int, rank: int, def worker_init_fn(worker_id: int,
seed: int) -> None: num_workers: int,
rank: int,
seed: int,
disable_subprocess_warning: bool = False) -> None:
"""This function will be called on each worker subprocess after seeding and """This function will be called on each worker subprocess after seeding and
before data loading. before data loading.
@ -31,6 +35,8 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int,
np.random.seed(worker_seed) np.random.seed(worker_seed)
random.seed(worker_seed) random.seed(worker_seed)
torch.manual_seed(worker_seed) torch.manual_seed(worker_seed)
if disable_subprocess_warning and worker_id != 0:
warnings.simplefilter('ignore')
@COLLATE_FUNCTIONS.register_module() @COLLATE_FUNCTIONS.register_module()

View File

@ -1367,12 +1367,20 @@ class Runner:
# build dataloader # build dataloader
init_fn: Optional[partial] init_fn: Optional[partial]
if seed is not None: if seed is not None:
disable_subprocess_warning = dataloader_cfg.pop(
'disable_subprocess_warning', False)
assert isinstance(
disable_subprocess_warning,
bool), ('disable_subprocess_warning should be a bool, but got '
f'{type(disable_subprocess_warning)}')
init_fn = partial( init_fn = partial(
worker_init_fn, worker_init_fn,
num_workers=dataloader_cfg.get('num_workers'), num_workers=dataloader_cfg.get('num_workers'),
rank=get_rank(), rank=get_rank(),
seed=seed) seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
else: else:
init_fn = None init_fn = None