mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Allow users to customize worker_init_fn of Dataloader (#1038)
* customize worker_init_fn function
* add assert
* narrow worker_init_fn type
This commit is contained in:
parent
eea2c278f4
commit
5e1ed7aaf0
@ -19,7 +19,7 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
import mmengine
|
import mmengine
|
||||||
from mmengine.config import Config, ConfigDict
|
from mmengine.config import Config, ConfigDict
|
||||||
from mmengine.dataset import worker_init_fn
|
from mmengine.dataset import worker_init_fn as default_worker_init_fn
|
||||||
from mmengine.device import get_device
|
from mmengine.device import get_device
|
||||||
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
||||||
is_distributed, master_only)
|
is_distributed, master_only)
|
||||||
@ -1381,21 +1381,28 @@ class Runner:
|
|||||||
# build dataloader
|
# build dataloader
|
||||||
init_fn: Optional[partial]
|
init_fn: Optional[partial]
|
||||||
|
|
||||||
if seed is not None:
|
if 'worker_init_fn' in dataloader_cfg:
|
||||||
disable_subprocess_warning = dataloader_cfg.pop(
|
worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn')
|
||||||
'disable_subprocess_warning', False)
|
worker_init_fn_type = worker_init_fn_cfg.pop('type')
|
||||||
assert isinstance(
|
worker_init_fn = FUNCTIONS.get(worker_init_fn_type)
|
||||||
disable_subprocess_warning,
|
assert callable(worker_init_fn)
|
||||||
bool), ('disable_subprocess_warning should be a bool, but got '
|
init_fn = partial(worker_init_fn,
|
||||||
f'{type(disable_subprocess_warning)}')
|
**worker_init_fn_cfg) # type: ignore
|
||||||
init_fn = partial(
|
|
||||||
worker_init_fn,
|
|
||||||
num_workers=dataloader_cfg.get('num_workers'),
|
|
||||||
rank=get_rank(),
|
|
||||||
seed=seed,
|
|
||||||
disable_subprocess_warning=disable_subprocess_warning)
|
|
||||||
else:
|
else:
|
||||||
init_fn = None
|
if seed is not None:
|
||||||
|
disable_subprocess_warning = dataloader_cfg.pop(
|
||||||
|
'disable_subprocess_warning', False)
|
||||||
|
assert isinstance(disable_subprocess_warning, bool), (
|
||||||
|
'disable_subprocess_warning should be a bool, but got '
|
||||||
|
f'{type(disable_subprocess_warning)}')
|
||||||
|
init_fn = partial(
|
||||||
|
default_worker_init_fn,
|
||||||
|
num_workers=dataloader_cfg.get('num_workers'),
|
||||||
|
rank=get_rank(),
|
||||||
|
seed=seed,
|
||||||
|
disable_subprocess_warning=disable_subprocess_warning)
|
||||||
|
else:
|
||||||
|
init_fn = None
|
||||||
|
|
||||||
# `persistent_workers` requires pytorch version >= 1.7
|
# `persistent_workers` requires pytorch version >= 1.7
|
||||||
if ('persistent_workers' in dataloader_cfg
|
if ('persistent_workers' in dataloader_cfg
|
||||||
|
@ -3,6 +3,7 @@ import copy
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest import TestCase, skipIf
|
from unittest import TestCase, skipIf
|
||||||
@ -352,6 +353,11 @@ def custom_collate(data_batch, pad_value):
|
|||||||
return pseudo_collate(data_batch)
|
return pseudo_collate(data_batch)
|
||||||
|
|
||||||
|
|
||||||
|
def custom_worker_init(worker_id):
|
||||||
|
np.random.seed(worker_id)
|
||||||
|
random.seed(worker_id)
|
||||||
|
|
||||||
|
|
||||||
class TestRunner(TestCase):
|
class TestRunner(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -376,6 +382,7 @@ class TestRunner(TestCase):
|
|||||||
RUNNERS.register_module(module=CustomRunner, force=True)
|
RUNNERS.register_module(module=CustomRunner, force=True)
|
||||||
EVALUATOR.register_module(module=ToyEvaluator, force=True)
|
EVALUATOR.register_module(module=ToyEvaluator, force=True)
|
||||||
FUNCTIONS.register_module(module=custom_collate, force=True)
|
FUNCTIONS.register_module(module=custom_collate, force=True)
|
||||||
|
FUNCTIONS.register_module(module=custom_worker_init, force=True)
|
||||||
|
|
||||||
self.temp_dir = tempfile.mkdtemp()
|
self.temp_dir = tempfile.mkdtemp()
|
||||||
epoch_based_cfg = dict(
|
epoch_based_cfg = dict(
|
||||||
@ -459,6 +466,7 @@ class TestRunner(TestCase):
|
|||||||
RUNNERS.module_dict.pop('CustomRunner')
|
RUNNERS.module_dict.pop('CustomRunner')
|
||||||
EVALUATOR.module_dict.pop('ToyEvaluator')
|
EVALUATOR.module_dict.pop('ToyEvaluator')
|
||||||
FUNCTIONS.module_dict.pop('custom_collate')
|
FUNCTIONS.module_dict.pop('custom_collate')
|
||||||
|
FUNCTIONS.module_dict.pop('custom_worker_init')
|
||||||
|
|
||||||
logging.shutdown()
|
logging.shutdown()
|
||||||
MMLogger._instance_dict.clear()
|
MMLogger._instance_dict.clear()
|
||||||
@ -1245,6 +1253,16 @@ class TestRunner(TestCase):
|
|||||||
cfg, seed=seed, diff_rank_seed=True)
|
cfg, seed=seed, diff_rank_seed=True)
|
||||||
self.assertNotEqual(dataloader.sampler.seed, seed)
|
self.assertNotEqual(dataloader.sampler.seed, seed)
|
||||||
|
|
||||||
|
# custom worker_init_fn
|
||||||
|
cfg = dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
worker_init_fn=dict(type='custom_worker_init'),
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=2)
|
||||||
|
dataloader = runner.build_dataloader(cfg)
|
||||||
|
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
|
||||||
|
|
||||||
def test_build_train_loop(self):
|
def test_build_train_loop(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_build_train_loop'
|
cfg.experiment_name = 'test_build_train_loop'
|
||||||
@ -1689,6 +1707,14 @@ class TestRunner(TestCase):
|
|||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
|
# 10.3 Test build dataloader with custom worker_init function
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_train10.3'
|
||||||
|
cfg.train_dataloader.update(
|
||||||
|
worker_init_fn=dict(type='custom_worker_init'))
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.train()
|
||||||
|
|
||||||
# 11 test build dataloader without default arguments of collate
|
# 11 test build dataloader without default arguments of collate
|
||||||
# function.
|
# function.
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user