mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] collate_fn does not support passing a function object (#1093)
This commit is contained in:
parent
2aef53d3fa
commit
cdec4cbd4a
@ -1421,9 +1421,17 @@ class Runner:
|
||||
# samples into a dict without stacking the batch tensor.
|
||||
collate_fn_cfg = dataloader_cfg.pop('collate_fn',
|
||||
dict(type='pseudo_collate'))
|
||||
collate_fn_type = collate_fn_cfg.pop('type')
|
||||
collate_fn = FUNCTIONS.get(collate_fn_type)
|
||||
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
|
||||
if isinstance(collate_fn_cfg, dict):
|
||||
collate_fn_type = collate_fn_cfg.pop('type')
|
||||
collate_fn = FUNCTIONS.get(collate_fn_type)
|
||||
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
|
||||
elif callable(collate_fn_cfg):
|
||||
collate_fn = collate_fn_cfg
|
||||
else:
|
||||
raise TypeError(
|
||||
'collate_fn should be a dict or callable object, but got '
|
||||
f'{collate_fn_cfg}')
|
||||
|
||||
data_loader = DataLoader(
|
||||
dataset=dataset,
|
||||
sampler=sampler if batch_sampler is None else None,
|
||||
|
@ -6,6 +6,7 @@ import os.path as osp
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from unittest import TestCase, skipIf
|
||||
|
||||
import numpy as np
|
||||
@ -1263,6 +1264,43 @@ class TestRunner(TestCase):
|
||||
dataloader = runner.build_dataloader(cfg)
|
||||
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
|
||||
|
||||
# collate_fn is a dict
|
||||
cfg = dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
worker_init_fn=dict(type='custom_worker_init'),
|
||||
batch_size=1,
|
||||
num_workers=2,
|
||||
collate_fn=dict(type='pseudo_collate'))
|
||||
dataloader = runner.build_dataloader(cfg)
|
||||
self.assertIsInstance(dataloader.collate_fn, partial)
|
||||
|
||||
# collate_fn is a callable object
|
||||
def custom_collate(data_batch):
|
||||
return data_batch
|
||||
|
||||
cfg = dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
worker_init_fn=dict(type='custom_worker_init'),
|
||||
batch_size=1,
|
||||
num_workers=2,
|
||||
collate_fn=custom_collate)
|
||||
dataloader = runner.build_dataloader(cfg)
|
||||
self.assertIs(dataloader.collate_fn, custom_collate)
|
||||
# collate_fn is a invalid value
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'collate_fn should be a dict or callable object'):
|
||||
cfg = dict(
|
||||
dataset=dict(type='ToyDataset'),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
worker_init_fn=dict(type='custom_worker_init'),
|
||||
batch_size=1,
|
||||
num_workers=2,
|
||||
collate_fn='collate_fn')
|
||||
dataloader = runner.build_dataloader(cfg)
|
||||
self.assertIsInstance(dataloader.collate_fn, partial)
|
||||
|
||||
def test_build_train_loop(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_build_train_loop'
|
||||
|
Loading…
x
Reference in New Issue
Block a user