From cb2d8fe085427e904e360c3497dfd1c1927666e9 Mon Sep 17 00:00:00 2001 From: Junhwa Song Date: Wed, 15 Mar 2023 20:36:47 +0900 Subject: [PATCH] [Enhance] Support multi-band image for Mosaic (#2748) ## Modification I changed the hardcoded 3 channel length to dynamic channel length in `np.full` function arguments. This modification enables `RandomMosaic` transform to support multispectral image (e.g. RGB image with NIR band) or bi-temporal image pairs for change detection task. ## Checklist 1. Pre-commit or other linting tools are used to fix the potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness. 3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstring or example tutorials. --- mmseg/datasets/transforms/transforms.py | 3 ++- tests/test_datasets/test_transform.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mmseg/datasets/transforms/transforms.py b/mmseg/datasets/transforms/transforms.py index 4f5316026..fb7e2a0e6 100644 --- a/mmseg/datasets/transforms/transforms.py +++ b/mmseg/datasets/transforms/transforms.py @@ -1062,8 +1062,9 @@ class RandomMosaic(BaseTransform): assert 'mix_results' in results if len(results['img'].shape) == 3: + c = results['img'].shape[2] mosaic_img = np.full( - (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3), + (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), c), self.pad_val, dtype=results['img'].dtype) else: diff --git a/tests/test_datasets/test_transform.py b/tests/test_datasets/test_transform.py index a9136bebc..92d6c6106 100644 --- a/tests/test_datasets/test_transform.py +++ b/tests/test_datasets/test_transform.py @@ -639,6 +639,17 @@ def test_mosaic(): results = mosaic_module(results) assert results['img'].shape[:2] == (20, 24) + results = dict() + results['img'] = np.concatenate((img, img), axis=2) + results['gt_semantic_seg'] = seg + results['seg_fields'] = ['gt_semantic_seg'] + + transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12)) + mosaic_module = TRANSFORMS.build(transform) + results['mix_results'] = [copy.deepcopy(results)] * 3 + results = mosaic_module(results) + assert results['img'].shape[2] == 6 + def test_cutout(): # test prob