[Fix] collate_fn does not support passing a function object (#1093)

This commit is contained in:
Zaida Zhou 2023-04-24 20:42:54 +08:00 committed by GitHub
parent 2aef53d3fa
commit cdec4cbd4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 3 deletions

View File

@ -1421,9 +1421,17 @@ class Runner:
# samples into a dict without stacking the batch tensor.
collate_fn_cfg = dataloader_cfg.pop('collate_fn',
dict(type='pseudo_collate'))
collate_fn_type = collate_fn_cfg.pop('type')
collate_fn = FUNCTIONS.get(collate_fn_type)
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
if isinstance(collate_fn_cfg, dict):
collate_fn_type = collate_fn_cfg.pop('type')
collate_fn = FUNCTIONS.get(collate_fn_type)
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
elif callable(collate_fn_cfg):
collate_fn = collate_fn_cfg
else:
raise TypeError(
'collate_fn should be a dict or callable object, but got '
f'{collate_fn_cfg}')
data_loader = DataLoader(
dataset=dataset,
sampler=sampler if batch_sampler is None else None,

View File

@ -6,6 +6,7 @@ import os.path as osp
import random
import shutil
import tempfile
from functools import partial
from unittest import TestCase, skipIf
import numpy as np
@ -1263,6 +1264,43 @@ class TestRunner(TestCase):
dataloader = runner.build_dataloader(cfg)
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
# collate_fn is a dict
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
worker_init_fn=dict(type='custom_worker_init'),
batch_size=1,
num_workers=2,
collate_fn=dict(type='pseudo_collate'))
dataloader = runner.build_dataloader(cfg)
self.assertIsInstance(dataloader.collate_fn, partial)
# collate_fn is a callable object
def custom_collate(data_batch):
return data_batch
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
worker_init_fn=dict(type='custom_worker_init'),
batch_size=1,
num_workers=2,
collate_fn=custom_collate)
dataloader = runner.build_dataloader(cfg)
self.assertIs(dataloader.collate_fn, custom_collate)
# collate_fn is an invalid value
with self.assertRaisesRegex(
TypeError, 'collate_fn should be a dict or callable object'):
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
worker_init_fn=dict(type='custom_worker_init'),
batch_size=1,
num_workers=2,
collate_fn='collate_fn')
dataloader = runner.build_dataloader(cfg)
self.assertIsInstance(dataloader.collate_fn, partial)
def test_build_train_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop'