[Fix] collate_fn does not support passing a function object (#1093)

This commit is contained in:
Zaida Zhou 2023-04-24 20:42:54 +08:00 committed by GitHub
parent 2aef53d3fa
commit cdec4cbd4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 3 deletions

View File

@@ -1421,9 +1421,17 @@ class Runner:
# samples into a dict without stacking the batch tensor. # samples into a dict without stacking the batch tensor.
collate_fn_cfg = dataloader_cfg.pop('collate_fn', collate_fn_cfg = dataloader_cfg.pop('collate_fn',
dict(type='pseudo_collate')) dict(type='pseudo_collate'))
collate_fn_type = collate_fn_cfg.pop('type') if isinstance(collate_fn_cfg, dict):
collate_fn = FUNCTIONS.get(collate_fn_type) collate_fn_type = collate_fn_cfg.pop('type')
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore collate_fn = FUNCTIONS.get(collate_fn_type)
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
elif callable(collate_fn_cfg):
collate_fn = collate_fn_cfg
else:
raise TypeError(
'collate_fn should be a dict or callable object, but got '
f'{collate_fn_cfg}')
data_loader = DataLoader( data_loader = DataLoader(
dataset=dataset, dataset=dataset,
sampler=sampler if batch_sampler is None else None, sampler=sampler if batch_sampler is None else None,

View File

@@ -6,6 +6,7 @@ import os.path as osp
import random import random
import shutil import shutil
import tempfile import tempfile
from functools import partial
from unittest import TestCase, skipIf from unittest import TestCase, skipIf
import numpy as np import numpy as np
@@ -1263,6 +1264,43 @@ class TestRunner(TestCase):
dataloader = runner.build_dataloader(cfg) dataloader = runner.build_dataloader(cfg)
self.assertIs(dataloader.worker_init_fn.func, custom_worker_init) self.assertIs(dataloader.worker_init_fn.func, custom_worker_init)
# collate_fn is a dict
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
worker_init_fn=dict(type='custom_worker_init'),
batch_size=1,
num_workers=2,
collate_fn=dict(type='pseudo_collate'))
dataloader = runner.build_dataloader(cfg)
self.assertIsInstance(dataloader.collate_fn, partial)
# collate_fn is a callable object
def custom_collate(data_batch):
return data_batch
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
worker_init_fn=dict(type='custom_worker_init'),
batch_size=1,
num_workers=2,
collate_fn=custom_collate)
dataloader = runner.build_dataloader(cfg)
self.assertIs(dataloader.collate_fn, custom_collate)
# collate_fn is an invalid value
with self.assertRaisesRegex(
TypeError, 'collate_fn should be a dict or callable object'):
cfg = dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
worker_init_fn=dict(type='custom_worker_init'),
batch_size=1,
num_workers=2,
collate_fn='collate_fn')
dataloader = runner.build_dataloader(cfg)
self.assertIsInstance(dataloader.collate_fn, partial)
def test_build_train_loop(self): def test_build_train_loop(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_train_loop' cfg.experiment_name = 'test_build_train_loop'