mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Disable warning of subprocess launched by dataloader (#870)
* Disable warning of subprocess launched by dataloader * Add type hint
This commit is contained in:
parent
0b59a90a21
commit
ad590e45a2
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import random
|
import random
|
||||||
|
import warnings
|
||||||
from typing import Any, Mapping, Sequence
|
from typing import Any, Mapping, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -13,8 +14,11 @@ from mmengine.structures import BaseDataElement
|
|||||||
COLLATE_FUNCTIONS = Registry('Collate Functions')
|
COLLATE_FUNCTIONS = Registry('Collate Functions')
|
||||||
|
|
||||||
|
|
||||||
def worker_init_fn(worker_id: int, num_workers: int, rank: int,
|
def worker_init_fn(worker_id: int,
|
||||||
seed: int) -> None:
|
num_workers: int,
|
||||||
|
rank: int,
|
||||||
|
seed: int,
|
||||||
|
disable_subprocess_warning: bool = False) -> None:
|
||||||
"""This function will be called on each worker subprocess after seeding and
|
"""This function will be called on each worker subprocess after seeding and
|
||||||
before data loading.
|
before data loading.
|
||||||
|
|
||||||
@ -31,6 +35,8 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int,
|
|||||||
np.random.seed(worker_seed)
|
np.random.seed(worker_seed)
|
||||||
random.seed(worker_seed)
|
random.seed(worker_seed)
|
||||||
torch.manual_seed(worker_seed)
|
torch.manual_seed(worker_seed)
|
||||||
|
if disable_subprocess_warning and worker_id != 0:
|
||||||
|
warnings.simplefilter('ignore')
|
||||||
|
|
||||||
|
|
||||||
@COLLATE_FUNCTIONS.register_module()
|
@COLLATE_FUNCTIONS.register_module()
|
||||||
|
@ -1367,12 +1367,20 @@ class Runner:
|
|||||||
|
|
||||||
# build dataloader
|
# build dataloader
|
||||||
init_fn: Optional[partial]
|
init_fn: Optional[partial]
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
|
disable_subprocess_warning = dataloader_cfg.pop(
|
||||||
|
'disable_subprocess_warning', False)
|
||||||
|
assert isinstance(
|
||||||
|
disable_subprocess_warning,
|
||||||
|
bool), ('disable_subprocess_warning should be a bool, but got '
|
||||||
|
f'{type(disable_subprocess_warning)}')
|
||||||
init_fn = partial(
|
init_fn = partial(
|
||||||
worker_init_fn,
|
worker_init_fn,
|
||||||
num_workers=dataloader_cfg.get('num_workers'),
|
num_workers=dataloader_cfg.get('num_workers'),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
seed=seed)
|
seed=seed,
|
||||||
|
disable_subprocess_warning=disable_subprocess_warning)
|
||||||
else:
|
else:
|
||||||
init_fn = None
|
init_fn = None
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user