[Enhance] Allow users to customize worker_init_fn of Dataloader (#1038)

* customize worker init fn function

* add assert

* narrow worker_init_fn type
shufan wu 2023-04-10 17:32:36 +08:00 committed by GitHub
parent eea2c278f4
commit 5e1ed7aaf0
2 changed files with 48 additions and 15 deletions
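In short, the dataloader config gains an optional worker_init_fn field: its type key is looked up in the FUNCTIONS registry and every remaining key is bound to the resolved function with functools.partial. A minimal usage sketch, assuming a hypothetical user function (my_worker_init and base_seed are illustrative names, not part of this commit; ToyDataset and DefaultSampler come from the test suite):

    import random

    import numpy as np
    from mmengine.registry import FUNCTIONS

    @FUNCTIONS.register_module()
    def my_worker_init(worker_id, base_seed=0):
        # PyTorch calls this once per worker process with the worker id;
        # `base_seed` is injected from the config via functools.partial.
        np.random.seed(base_seed + worker_id)
        random.seed(base_seed + worker_id)

    train_dataloader = dict(
        dataset=dict(type='ToyDataset'),
        sampler=dict(type='DefaultSampler', shuffle=True),
        worker_init_fn=dict(type='my_worker_init', base_seed=2023),
        batch_size=1,
        num_workers=2)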

mmengine/runner/runner.py

@@ -19,7 +19,7 @@ from torch.utils.data import DataLoader
 import mmengine
 from mmengine.config import Config, ConfigDict
-from mmengine.dataset import worker_init_fn
+from mmengine.dataset import worker_init_fn as default_worker_init_fn
 from mmengine.device import get_device
 from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
                            is_distributed, master_only)
@@ -1381,21 +1381,28 @@ class Runner:
         # build dataloader
         init_fn: Optional[partial]
 
-        if seed is not None:
-            disable_subprocess_warning = dataloader_cfg.pop(
-                'disable_subprocess_warning', False)
-            assert isinstance(
-                disable_subprocess_warning,
-                bool), ('disable_subprocess_warning should be a bool, but got '
-                        f'{type(disable_subprocess_warning)}')
-            init_fn = partial(
-                worker_init_fn,
-                num_workers=dataloader_cfg.get('num_workers'),
-                rank=get_rank(),
-                seed=seed,
-                disable_subprocess_warning=disable_subprocess_warning)
+        if 'worker_init_fn' in dataloader_cfg:
+            worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
+            worker_init_fn_type = worker_init_fn_cfg.pop('type')
+            worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
+            assert callable(worker_init_fn)
+            init_fn = partial(worker_init_fn,
+                              **worker_init_fn_cfg)  # type: ignore
         else:
-            init_fn = None
+            if seed is not None:
+                disable_subprocess_warning = dataloader_cfg.pop(
+                    'disable_subprocess_warning', False)
+                assert isinstance(disable_subprocess_warning, bool), (
+                    'disable_subprocess_warning should be a bool, but got '
+                    f'{type(disable_subprocess_warning)}')
+                init_fn = partial(
+                    default_worker_init_fn,
+                    num_workers=dataloader_cfg.get('num_workers'),
+                    rank=get_rank(),
+                    seed=seed,
+                    disable_subprocess_warning=disable_subprocess_warning)
+            else:
+                init_fn = None
 
         # `persistent_workers` requires pytorch version >= 1.7
         if ('persistent_workers' in dataloader_cfg
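One consequence of the branch above: a user-supplied worker_init_fn bypasses the seed-based default entirely, and unlike the default branch, num_workers, rank, and seed are not injected automatically; everything beyond the type key must be spelled out in the config dict. A hedged sketch of a custom function that keeps the library's per-worker seeding (the name seeded_worker_init and the delegation are illustrative, not part of this commit):

    import random

    from mmengine.dataset import worker_init_fn as default_worker_init_fn
    from mmengine.dist import get_rank
    from mmengine.registry import FUNCTIONS

    @FUNCTIONS.register_module()
    def seeded_worker_init(worker_id, num_workers, seed):
        # Re-use the default per-worker seeding, then layer extra setup on top.
        default_worker_init_fn(
            worker_id, num_workers=num_workers, rank=get_rank(), seed=seed)
        random.seed(seed + worker_id)

    # num_workers and seed must come from the config, e.g.:
    # worker_init_fn=dict(type='seeded_worker_init', num_workers=2, seed=2023)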

tests/test_runner/test_runner.py

@@ -3,6 +3,7 @@ import copy
 import logging
 import os
 import os.path as osp
+import random
 import shutil
 import tempfile
 from unittest import TestCase, skipIf
@@ -352,6 +353,11 @@ def custom_collate(data_batch, pad_value):
     return pseudo_collate(data_batch)
 
 
+def custom_worker_init(worker_id):
+    np.random.seed(worker_id)
+    random.seed(worker_id)
+
+
 class TestRunner(TestCase):
 
     def setUp(self):
@@ -376,6 +382,7 @@ class TestRunner(TestCase):
         RUNNERS.register_module(module=CustomRunner, force=True)
         EVALUATOR.register_module(module=ToyEvaluator, force=True)
         FUNCTIONS.register_module(module=custom_collate, force=True)
+        FUNCTIONS.register_module(module=custom_worker_init, force=True)
         self.temp_dir = tempfile.mkdtemp()
 
         epoch_based_cfg = dict(
@@ -459,6 +466,7 @@ class TestRunner(TestCase):
         RUNNERS.module_dict.pop('CustomRunner')
         EVALUATOR.module_dict.pop('ToyEvaluator')
         FUNCTIONS.module_dict.pop('custom_collate')
+        FUNCTIONS.module_dict.pop('custom_worker_init')
 
         logging.shutdown()
         MMLogger._instance_dict.clear()
@@ -1245,6 +1253,16 @@ class TestRunner(TestCase):
             cfg, seed=seed, diff_rank_seed=True)
         self.assertNotEqual(dataloader.sampler.seed, seed)
 
+        # custom worker_init_fn
+        cfg = dict(
+            dataset=dict(type='ToyDataset'),
+            sampler=dict(type='DefaultSampler', shuffle=True),
+            worker_init_fn=dict(type='custom_worker_init'),
+            batch_size=1,
+            num_workers=2)
+        dataloader = runner.build_dataloader(cfg)
+        self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
+
     def test_build_train_loop(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_train_loop'
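The assertIs check in the hunk above works because build_dataloader wraps the resolved function in functools.partial, and a partial object exposes the wrapped callable as its .func attribute:

    from functools import partial

    def f(worker_id):
        pass

    p = partial(f)
    assert p.func is f          # .func recovers the original callable
    assert p.keywords == {}     # config kwargs would show up here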
@@ -1689,6 +1707,14 @@ class TestRunner(TestCase):
         runner = Runner.from_cfg(cfg)
         runner.train()
 
+        # 10.3 Test build dataloader with custom worker_init function
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.experiment_name = 'test_train10.3'
+        cfg.train_dataloader.update(
+            worker_init_fn=dict(type='custom_worker_init'))
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+
         # 11 test build dataloader without default arguments of collate
         # function.
         with self.assertRaises(TypeError):