From cdec4cbd4acbabca6cc3bc7bb0866b9a67d8f401 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Mon, 24 Apr 2023 20:42:54 +0800 Subject: [PATCH] [Fix] collate_fn does not support passing a function object (#1093) --- mmengine/runner/runner.py | 14 +++++++++--- tests/test_runner/test_runner.py | 38 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index e0e04fe7..c4e787f4 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1421,9 +1421,17 @@ class Runner: # samples into a dict without stacking the batch tensor. collate_fn_cfg = dataloader_cfg.pop('collate_fn', dict(type='pseudo_collate')) - collate_fn_type = collate_fn_cfg.pop('type') - collate_fn = FUNCTIONS.get(collate_fn_type) - collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore + if isinstance(collate_fn_cfg, dict): + collate_fn_type = collate_fn_cfg.pop('type') + collate_fn = FUNCTIONS.get(collate_fn_type) + collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore + elif callable(collate_fn_cfg): + collate_fn = collate_fn_cfg + else: + raise TypeError( + 'collate_fn should be a dict or callable object, but got ' + f'{collate_fn_cfg}') + data_loader = DataLoader( dataset=dataset, sampler=sampler if batch_sampler is None else None, diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index ae170d8d..e710e9f5 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -6,6 +6,7 @@ import os.path as osp import random import shutil import tempfile +from functools import partial from unittest import TestCase, skipIf import numpy as np @@ -1263,6 +1264,43 @@ class TestRunner(TestCase): dataloader = runner.build_dataloader(cfg) self.assertIs(dataloader.worker_init_fn.func, custom_worker_init) + # collate_fn is a dict + cfg = dict( + dataset=dict(type='ToyDataset'), + 
sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn=dict(type='pseudo_collate')) + dataloader = runner.build_dataloader(cfg) + self.assertIsInstance(dataloader.collate_fn, partial) + + # collate_fn is a callable object + def custom_collate(data_batch): + return data_batch + + cfg = dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn=custom_collate) + dataloader = runner.build_dataloader(cfg) + self.assertIs(dataloader.collate_fn, custom_collate) + # collate_fn is an invalid value + with self.assertRaisesRegex( + TypeError, 'collate_fn should be a dict or callable object'): + cfg = dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn='collate_fn') + dataloader = runner.build_dataloader(cfg) + self.assertIsInstance(dataloader.collate_fn, partial) + + def test_build_train_loop(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_build_train_loop'