[Enhance] Allow users to customize worker_init_fn of Dataloader (#1038)

* customize worker init fn function

* add assert

* narrow worker_init_fn type
This commit is contained in:
shufan wu 2023-04-10 17:32:36 +08:00 committed by GitHub
parent eea2c278f4
commit 5e1ed7aaf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 15 deletions

View File

@ -19,7 +19,7 @@ from torch.utils.data import DataLoader
import mmengine import mmengine
from mmengine.config import Config, ConfigDict from mmengine.config import Config, ConfigDict
from mmengine.dataset import worker_init_fn from mmengine.dataset import worker_init_fn as default_worker_init_fn
from mmengine.device import get_device from mmengine.device import get_device
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
is_distributed, master_only) is_distributed, master_only)
@ -1381,21 +1381,28 @@ class Runner:
# build dataloader # build dataloader
init_fn: Optional[partial] init_fn: Optional[partial]
if seed is not None: if 'worker_init_fn' in dataloader_cfg:
disable_subprocess_warning = dataloader_cfg.pop( worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
'disable_subprocess_warning', False) worker_init_fn_type = worker_init_fn_cfg.pop('type')
assert isinstance( worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
disable_subprocess_warning, assert callable(worker_init_fn)
bool), ('disable_subprocess_warning should be a bool, but got ' init_fn = partial(worker_init_fn,
f'{type(disable_subprocess_warning)}') **worker_init_fn_cfg) # type: ignore
init_fn = partial(
worker_init_fn,
num_workers=dataloader_cfg.get('num_workers'),
rank=get_rank(),
seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
else: else:
init_fn = None if seed is not None:
disable_subprocess_warning = dataloader_cfg.pop(
'disable_subprocess_warning', False)
assert isinstance(disable_subprocess_warning, bool), (
'disable_subprocess_warning should be a bool, but got '
f'{type(disable_subprocess_warning)}')
init_fn = partial(
default_worker_init_fn,
num_workers=dataloader_cfg.get('num_workers'),
rank=get_rank(),
seed=seed,
disable_subprocess_warning=disable_subprocess_warning)
else:
init_fn = None
# `persistent_workers` requires pytorch version >= 1.7 # `persistent_workers` requires pytorch version >= 1.7
if ('persistent_workers' in dataloader_cfg if ('persistent_workers' in dataloader_cfg

View File

@ -3,6 +3,7 @@ import copy
import logging import logging
import os import os
import os.path as osp import os.path as osp
import random
import shutil import shutil
import tempfile import tempfile
from unittest import TestCase, skipIf from unittest import TestCase, skipIf
@ -352,6 +353,11 @@ def custom_collate(data_batch, pad_value):
return pseudo_collate(data_batch) return pseudo_collate(data_batch)
def custom_worker_init(worker_id):
    """Seed the global NumPy and stdlib random generators with ``worker_id``.

    Args:
        worker_id: Value handed unchanged to both ``np.random.seed`` and
            ``random.seed``.
    """
    # NumPy is reseeded first, then the stdlib generator.
    for reseed in (np.random.seed, random.seed):
        reseed(worker_id)
class TestRunner(TestCase): class TestRunner(TestCase):
def setUp(self): def setUp(self):
@ -376,6 +382,7 @@ class TestRunner(TestCase):
RUNNERS.register_module(module=CustomRunner, force=True) RUNNERS.register_module(module=CustomRunner, force=True)
EVALUATOR.register_module(module=ToyEvaluator, force=True) EVALUATOR.register_module(module=ToyEvaluator, force=True)
FUNCTIONS.register_module(module=custom_collate, force=True) FUNCTIONS.register_module(module=custom_collate, force=True)
FUNCTIONS.register_module(module=custom_worker_init, force=True)
self.temp_dir = tempfile.mkdtemp() self.temp_dir = tempfile.mkdtemp()
epoch_based_cfg = dict( epoch_based_cfg = dict(
@ -459,6 +466,7 @@ class TestRunner(TestCase):
RUNNERS.module_dict.pop('CustomRunner') RUNNERS.module_dict.pop('CustomRunner')
EVALUATOR.module_dict.pop('ToyEvaluator') EVALUATOR.module_dict.pop('ToyEvaluator')
FUNCTIONS.module_dict.pop('custom_collate') FUNCTIONS.module_dict.pop('custom_collate')
FUNCTIONS.module_dict.pop('custom_worker_init')
logging.shutdown() logging.shutdown()
MMLogger._instance_dict.clear() MMLogger._instance_dict.clear()
@ -1245,6 +1253,16 @@ class TestRunner(TestCase):
cfg, seed=seed, diff_rank_seed=True) cfg, seed=seed, diff_rank_seed=True)
self.assertNotEqual(dataloader.sampler.seed, seed) self.assertNotEqual(dataloader.sampler.seed, seed)
# custom worker_init_fn
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
worker_init_fn=dict(type='custom_worker_init'),
batch_size=1,
num_workers=2)
dataloader = runner.build_dataloader(cfg)
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
def test_build_train_loop(self): def test_build_train_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop' cfg.experiment_name = 'test_build_train_loop'
@ -1689,6 +1707,14 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
# 10.3 Test build dataloader with custom worker_init function
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train10.3'
cfg.train_dataloader.update(
worker_init_fn=dict(type='custom_worker_init'))
runner = Runner.from_cfg(cfg)
runner.train()
# 11 test build dataloader without default arguments of collate # 11 test build dataloader without default arguments of collate
# function. # function.
with self.assertRaises(TypeError): with self.assertRaises(TypeError):