mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] collate_fn does not support passing a function object (#1093)
This commit is contained in:
parent
2aef53d3fa
commit
cdec4cbd4a
@ -1421,9 +1421,17 @@ class Runner:
|
|||||||
# samples into a dict without stacking the batch tensor.
|
# samples into a dict without stacking the batch tensor.
|
||||||
collate_fn_cfg = dataloader_cfg.pop('collate_fn',
|
collate_fn_cfg = dataloader_cfg.pop('collate_fn',
|
||||||
dict(type='pseudo_collate'))
|
dict(type='pseudo_collate'))
|
||||||
collate_fn_type = collate_fn_cfg.pop('type')
|
if isinstance(collate_fn_cfg, dict):
|
||||||
collate_fn = FUNCTIONS.get(collate_fn_type)
|
collate_fn_type = collate_fn_cfg.pop('type')
|
||||||
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
|
collate_fn = FUNCTIONS.get(collate_fn_type)
|
||||||
|
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
|
||||||
|
elif callable(collate_fn_cfg):
|
||||||
|
collate_fn = collate_fn_cfg
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'collate_fn should be a dict or callable object, but got '
|
||||||
|
f'{collate_fn_cfg}')
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
sampler=sampler if batch_sampler is None else None,
|
sampler=sampler if batch_sampler is None else None,
|
||||||
|
@ -6,6 +6,7 @@ import os.path as osp
|
|||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from functools import partial
|
||||||
from unittest import TestCase, skipIf
|
from unittest import TestCase, skipIf
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -1263,6 +1264,43 @@ class TestRunner(TestCase):
|
|||||||
dataloader = runner.build_dataloader(cfg)
|
dataloader = runner.build_dataloader(cfg)
|
||||||
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
|
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
|
||||||
|
|
||||||
|
# collate_fn is a dict
|
||||||
|
cfg = dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
worker_init_fn=dict(type='custom_worker_init'),
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=2,
|
||||||
|
collate_fn=dict(type='pseudo_collate'))
|
||||||
|
dataloader = runner.build_dataloader(cfg)
|
||||||
|
self.assertIsInstance(dataloader.collate_fn, partial)
|
||||||
|
|
||||||
|
# collate_fn is a callable object
|
||||||
|
def custom_collate(data_batch):
|
||||||
|
return data_batch
|
||||||
|
|
||||||
|
cfg = dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
worker_init_fn=dict(type='custom_worker_init'),
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=2,
|
||||||
|
collate_fn=custom_collate)
|
||||||
|
dataloader = runner.build_dataloader(cfg)
|
||||||
|
self.assertIs(dataloader.collate_fn, custom_collate)
|
||||||
|
# collate_fn is a invalid value
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
TypeError, 'collate_fn should be a dict or callable object'):
|
||||||
|
cfg = dict(
|
||||||
|
dataset=dict(type='ToyDataset'),
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
|
worker_init_fn=dict(type='custom_worker_init'),
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=2,
|
||||||
|
collate_fn='collate_fn')
|
||||||
|
dataloader = runner.build_dataloader(cfg)
|
||||||
|
self.assertIsInstance(dataloader.collate_fn, partial)
|
||||||
|
|
||||||
def test_build_train_loop(self):
|
def test_build_train_loop(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
cfg.experiment_name = 'test_build_train_loop'
|
cfg.experiment_name = 'test_build_train_loop'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user