mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Allow users to customize worker_init_fn of Dataloader (#1038)
* Customize the worker init function
* Add assert
* Narrow the worker_init_fn type
This commit is contained in:
parent
eea2c278f4
commit
5e1ed7aaf0
@ -19,7 +19,7 @@ from torch.utils.data import DataLoader
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.dataset import worker_init_fn
|
||||
from mmengine.dataset import worker_init_fn as default_worker_init_fn
|
||||
from mmengine.device import get_device
|
||||
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
||||
is_distributed, master_only)
|
||||
@ -1381,21 +1381,28 @@ class Runner:
|
||||
# build dataloader
|
||||
init_fn: Optional[partial]
|
||||
|
||||
if seed is not None:
|
||||
disable_subprocess_warning = dataloader_cfg.pop(
|
||||
'disable_subprocess_warning', False)
|
||||
assert isinstance(
|
||||
disable_subprocess_warning,
|
||||
bool), ('disable_subprocess_warning should be a bool, but got '
|
||||
f'{type(disable_subprocess_warning)}')
|
||||
init_fn = partial(
|
||||
worker_init_fn,
|
||||
num_workers=dataloader_cfg.get('num_workers'),
|
||||
rank=get_rank(),
|
||||
seed=seed,
|
||||
disable_subprocess_warning=disable_subprocess_warning)
|
||||
if 'worker_init_fn' in dataloader_cfg:
|
||||
worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
|
||||
worker_init_fn_type = worker_init_fn_cfg.pop('type')
|
||||
worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
|
||||
assert callable(worker_init_fn)
|
||||
init_fn = partial(worker_init_fn,
|
||||
**worker_init_fn_cfg) # type: ignore
|
||||
else:
|
||||
init_fn = None
|
||||
if seed is not None:
|
||||
disable_subprocess_warning = dataloader_cfg.pop(
|
||||
'disable_subprocess_warning', False)
|
||||
assert isinstance(disable_subprocess_warning, bool), (
|
||||
'disable_subprocess_warning should be a bool, but got '
|
||||
f'{type(disable_subprocess_warning)}')
|
||||
init_fn = partial(
|
||||
default_worker_init_fn,
|
||||
num_workers=dataloader_cfg.get('num_workers'),
|
||||
rank=get_rank(),
|
||||
seed=seed,
|
||||
disable_subprocess_warning=disable_subprocess_warning)
|
||||
else:
|
||||
init_fn = None
|
||||
|
||||
# `persistent_workers` requires pytorch version >= 1.7
|
||||
if ('persistent_workers' in dataloader_cfg
|
||||
|
@ -3,6 +3,7 @@ import copy
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
from unittest import TestCase, skipIf
|
||||
@ -352,6 +353,11 @@ def custom_collate(data_batch, pad_value):
|
||||
return pseudo_collate(data_batch)
|
||||
|
||||
|
||||
def custom_worker_init(worker_id):
    """Seed NumPy's and Python's global RNGs with ``worker_id``.

    Args:
        worker_id: Value passed unchanged to ``np.random.seed`` and then
            to ``random.seed``.
    """
    # NumPy is seeded first, then the stdlib ``random`` module.
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(worker_id)
|
||||
|
||||
|
||||
class TestRunner(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -376,6 +382,7 @@ class TestRunner(TestCase):
|
||||
RUNNERS.register_module(module=CustomRunner, force=True)
|
||||
EVALUATOR.register_module(module=ToyEvaluator, force=True)
|
||||
FUNCTIONS.register_module(module=custom_collate, force=True)
|
||||
FUNCTIONS.register_module(module=custom_worker_init, force=True)
|
||||
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
epoch_based_cfg = dict(
|
||||
@ -459,6 +466,7 @@ class TestRunner(TestCase):
|
||||
RUNNERS.module_dict.pop('CustomRunner')
|
||||
EVALUATOR.module_dict.pop('ToyEvaluator')
|
||||
FUNCTIONS.module_dict.pop('custom_collate')
|
||||
FUNCTIONS.module_dict.pop('custom_worker_init')
|
||||
|
||||
logging.shutdown()
|
||||
MMLogger._instance_dict.clear()
|
||||
@ -1245,6 +1253,16 @@ class TestRunner(TestCase):
|
||||
cfg, seed=seed, diff_rank_seed=True)
|
||||
self.assertNotEqual(dataloader.sampler.seed, seed)
|
||||
|
||||
# custom worker_init_fn
|
||||
cfg = dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
worker_init_fn=dict(type='custom_worker_init'),
|
||||
batch_size=1,
|
||||
num_workers=2)
|
||||
dataloader = runner.build_dataloader(cfg)
|
||||
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
|
||||
|
||||
def test_build_train_loop(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_build_train_loop'
|
||||
@ -1689,6 +1707,14 @@ class TestRunner(TestCase):
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
# 10.3 Test build dataloader with custom worker_init function
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_train10.3'
|
||||
cfg.train_dataloader.update(
|
||||
worker_init_fn=dict(type='custom_worker_init'))
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
# 11 test build dataloader without default arguments of collate
|
||||
# function.
|
||||
with self.assertRaises(TypeError):
|
||||
|
Loading…
x
Reference in New Issue
Block a user