[Enhance] Support multi-band image for Mosaic (#2748)

## Modification

I changed the hardcoded 3 channel length to dynamic channel length in
`np.full` function arguments.
This modification enables `RandomMosaic` transform to support
multispectral image (e.g. RGB image with NIR band) or bi-temporal image
pairs for change detection task.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.
This commit is contained in:
Junhwa Song 2023-03-15 20:36:47 +09:00 committed by GitHub
parent 1f1f2666b5
commit cb2d8fe085
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 1 deletions

View File

@ -1062,8 +1062,9 @@ class RandomMosaic(BaseTransform):
assert 'mix_results' in results
if len(results['img'].shape) == 3:
c = results['img'].shape[2]
mosaic_img = np.full(
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3),
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), c),
self.pad_val,
dtype=results['img'].dtype)
else:

View File

@ -639,6 +639,17 @@ def test_mosaic():
results = mosaic_module(results)
assert results['img'].shape[:2] == (20, 24)
results = dict()
results['img'] = np.concatenate((img, img), axis=2)
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
mosaic_module = TRANSFORMS.build(transform)
results['mix_results'] = [copy.deepcopy(results)] * 3
results = mosaic_module(results)
assert results['img'].shape[2] == 6
def test_cutout():
# test prob