[Feature] Add Mosaic transform (#1093)

* Fix typo in usage example

* original mosaic code in mmdet

* Adapt mosaic to semantic segmentation

* Remove bbox test in test_mosaic

* Add unittests

* Fix resize mode for seg_fields

* Fix repr error

* modify Mosaic docs

* Rename Mosaic to RandomMosaic

* Add docstring

* modify Mosaic docstring

* [Docs] Add a blank line before Returns:

* add blank lines

Co-authored-by: MeowZheng <meowzheng@outlook.com>
This commit is contained in:
Kyungmin Lee 2022-01-11 17:17:36 +09:00 committed by GitHub
parent 5b360d9ab1
commit e9dd32b6e8
3 changed files with 322 additions and 3 deletions

View File

@ -6,13 +6,14 @@ from .loading import LoadAnnotations, LoadImageFromFile
from .test_time_aug import MultiScaleFlipAug from .test_time_aug import MultiScaleFlipAug
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomCutOut, PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray, RandomFlip, RandomMosaic, RandomRotate, Rerange,
SegRescale) Resize, RGB2Gray, SegRescale)
__all__ = [ __all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut' 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
'RandomMosaic'
] ]

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy
import mmcv import mmcv
import numpy as np import numpy as np
from mmcv.utils import deprecated_api_warning, is_tuple_of from mmcv.utils import deprecated_api_warning, is_tuple_of
@ -1040,3 +1042,270 @@ class RandomCutOut(object):
repr_str += f'fill_in={self.fill_in}, ' repr_str += f'fill_in={self.fill_in}, '
repr_str += f'seg_fill_in={self.seg_fill_in})' repr_str += f'seg_fill_in={self.seg_fill_in})'
return repr_str return repr_str
@PIPELINES.register_module()
class RandomMosaic(object):
    """Mosaic augmentation.

    Given 4 images, mosaic transform combines them into one output image.
    The output image is composed of the parts from each sub-image.

    .. code:: text

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |  pad      |
                |      +-----------+           |
                |      |           |           |
                |      |  image1   |--------+  |
                |      |           |        |  |
                |      |           | image2 |  |
     center_y   |----+-------------+-----------|
                |    |   cropped   |           |
                |pad |   image3    |  image4   |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

    The mosaic transform steps are as follows:

    1. Choose the mosaic center as the intersections of 4 images
    2. Get the left top image according to the index, and randomly
       sample another 3 images from the custom dataset.
    3. Sub image will be cropped if image is larger than mosaic patch

    Args:
        prob (float): mosaic probability. Must lie in [0, 1].
        img_scale (Sequence[int]): Image size after mosaic pipeline of
            a single image, indexed as (height, width). The output canvas
            is ``(2 * img_scale[0], 2 * img_scale[1])``, i.e. four times
            the area of a single image. Default: (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. The sampled ratio is multiplied by ``img_scale`` to
            obtain the mosaic center. Default: (0.5, 1.5).
        pad_val (int): Pad value. Default: 0.
        seg_pad_val (int): Pad value of segmentation map. Default: 255.
    """

    # Order matters: index 0 is the current sample, indices 1..3 map to
    # ``results['mix_results'][0..2]``.
    _LOC_STRS = ('top_left', 'top_right', 'bottom_left', 'bottom_right')

    def __init__(self,
                 prob,
                 img_scale=(640, 640),
                 center_ratio_range=(0.5, 1.5),
                 pad_val=0,
                 seg_pad_val=255):
        # Kept as assertions so that invalid configs keep raising
        # AssertionError, which existing callers may rely on.
        assert 0 <= prob <= 1, 'prob must be in [0, 1]'
        assert isinstance(img_scale, tuple), 'img_scale must be a tuple'
        self.prob = prob
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def __call__(self, results):
        """Call function to make a mosaic of image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with mosaic transformed.
        """
        if np.random.rand() < self.prob:
            # The image transform samples the mosaic center that the
            # segmentation transform reuses, so it has to run first.
            results = self._mosaic_transform_img(results)
            results = self._mosaic_transform_seg(results)
        return results

    def get_indexes(self, dataset):
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            list: indexes.
        """
        # ``np.random.randint`` excludes the upper bound, so every sampled
        # index is a valid position in ``dataset``.
        return [np.random.randint(0, len(dataset)) for _ in range(3)]

    def _paste_patches(self, results, key, canvas, interpolation=None):
        """Resize the four sub-images stored under ``key`` and paste them
        into ``canvas`` around the current mosaic center.

        Args:
            results (dict): Result dict holding ``key`` and 'mix_results'.
            key (str): Field to read from each of the four result dicts.
            canvas (np.ndarray): Pre-filled output array, modified in place.
            interpolation (str | None): Interpolation mode forwarded to
                ``mmcv.imresize``. ``None`` keeps mmcv's default.

        Returns:
            np.ndarray: ``canvas`` with the four patches pasted.
        """
        resize_kwargs = {}
        if interpolation is not None:
            resize_kwargs['interpolation'] = interpolation
        center_position = (self.center_x, self.center_y)

        for i, loc in enumerate(self._LOC_STRS):
            # The sources are only read; the resize below produces a new
            # array, so no defensive copy of the result dict is needed.
            source = results if i == 0 else results['mix_results'][i - 1]
            patch = source[key]
            h_i, w_i = patch.shape[:2]
            # keep_ratio resize
            scale_ratio_i = min(self.img_scale[0] / h_i,
                                self.img_scale[1] / w_i)
            patch = mmcv.imresize(
                patch, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)),
                **resize_kwargs)

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, patch.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste
            canvas[y1_p:y2_p, x1_p:x2_p] = patch[y1_c:y2_c, x1_c:x2_c]
        return canvas

    def _mosaic_transform_img(self, results):
        """Mosaic transform function.

        Samples a new mosaic center (stored on ``self.center_x`` and
        ``self.center_y``) and replaces ``results['img']`` with the mosaic.
        'img_shape' and 'ori_shape' are set to the mosaic shape.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        canvas_shape = (int(self.img_scale[0] * 2),
                        int(self.img_scale[1] * 2))
        if len(results['img'].shape) == 3:
            canvas_shape = canvas_shape + (3, )
        mosaic_img = np.full(
            canvas_shape, self.pad_val, dtype=results['img'].dtype)

        # mosaic center x, y
        self.center_x = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[1])
        self.center_y = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[0])

        mosaic_img = self._paste_patches(results, 'img', mosaic_img)

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['ori_shape'] = mosaic_img.shape
        return results

    def _mosaic_transform_seg(self, results):
        """Mosaic transform function for label annotations.

        Reuses the center sampled by :meth:`_mosaic_transform_img`, which
        therefore has to be called before this method.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        for key in results.get('seg_fields', []):
            mosaic_seg = np.full(
                (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
                self.seg_pad_val,
                dtype=results[key].dtype)
            # Nearest-neighbour keeps label values intact when resizing.
            results[key] = self._paste_patches(
                results, key, mosaic_seg, interpolation='nearest')
        return results

    def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
        """Calculate global coordinate of mosaic image and local coordinate of
        cropped sub-image.

        Args:
            loc (str): Index for the sub-image, loc in ('top_left',
                'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for 4 images,
                (x, y).
            img_shape_wh (Sequence[int]): Width and height of sub-image

        Returns:
            tuple[tuple[float]]: Corresponding coordinate of pasting and
                cropping

                - paste_coord (tuple): paste corner coordinate
                  (x1, y1, x2, y2) in mosaic image.
                - crop_coord (tuple): crop corner coordinate
                  (x1, y1, x2, y2) in the sub-image.
        """
        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        center_x, center_y = center_position_xy[0], center_position_xy[1]
        img_w, img_h = img_shape_wh[0], img_shape_wh[1]
        canvas_w = self.img_scale[1] * 2
        canvas_h = self.img_scale[0] * 2

        if loc == 'top_left':
            # index0 to top left part of image
            x1, y1, x2, y2 = (max(center_x - img_w, 0),
                              max(center_y - img_h, 0),
                              center_x,
                              center_y)
            crop_coord = (img_w - (x2 - x1), img_h - (y2 - y1),
                          img_w, img_h)
        elif loc == 'top_right':
            # index1 to top right part of image
            x1, y1, x2, y2 = (center_x,
                              max(center_y - img_h, 0),
                              min(center_x + img_w, canvas_w),
                              center_y)
            crop_coord = (0, img_h - (y2 - y1),
                          min(img_w, x2 - x1), img_h)
        elif loc == 'bottom_left':
            # index2 to bottom left part of image
            x1, y1, x2, y2 = (max(center_x - img_w, 0),
                              center_y,
                              center_x,
                              min(canvas_h, center_y + img_h))
            crop_coord = (img_w - (x2 - x1), 0,
                          img_w, min(y2 - y1, img_h))
        else:
            # index3 to bottom right part of image
            x1, y1, x2, y2 = (center_x,
                              center_y,
                              min(center_x + img_w, canvas_w),
                              min(canvas_h, center_y + img_h))
            crop_coord = (0, 0,
                          min(img_w, x2 - x1), min(y2 - y1, img_h))

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'seg_pad_val={self.seg_pad_val})'
        return repr_str

View File

@ -614,3 +614,52 @@ def test_cutout():
cutout_result = cutout_module(copy.deepcopy(results)) cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() > img.sum() assert cutout_result['img'].sum() > img.sum()
assert cutout_result['gt_semantic_seg'].sum() > seg.sum() assert cutout_result['gt_semantic_seg'].sum() > seg.sum()
def test_mosaic():
    """Check RandomMosaic config validation and output shapes."""

    def build_mosaic(**cfg):
        # Build the transform through the registry, as a pipeline would.
        return build_from_cfg(dict(type='RandomMosaic', **cfg), PIPELINES)

    # an out-of-range prob and a non-tuple img_scale are both rejected
    for invalid_cfg in (dict(prob=1.5), dict(prob=1, img_scale=640)):
        with pytest.raises(AssertionError):
            build_mosaic(**invalid_cfg)

    here = osp.dirname(__file__)
    img = mmcv.imread(osp.join(here, '../data/color.jpg'), 'color')
    seg = np.array(Image.open(osp.join(here, '../data/seg.png')))

    def make_results(image):
        return {
            'img': image,
            'gt_semantic_seg': seg,
            'seg_fields': ['gt_semantic_seg'],
        }

    # colour input, mosaic always applied
    sample = make_results(img)
    always_mosaic = build_mosaic(prob=1, img_scale=(10, 12))
    assert 'Mosaic' in repr(always_mosaic)

    # test assertion for invalid mix_results
    with pytest.raises(AssertionError):
        always_mosaic(sample)

    sample['mix_results'] = [copy.deepcopy(sample)] * 3
    sample = always_mosaic(sample)
    assert sample['img'].shape[:2] == (20, 24)

    # single-channel input, mosaic never applied: image is untouched
    sample = make_results(img[:, :, 0])
    never_mosaic = build_mosaic(prob=0, img_scale=(10, 12))
    sample['mix_results'] = [copy.deepcopy(sample)] * 3
    sample = never_mosaic(sample)
    assert sample['img'].shape[:2] == img.shape[:2]

    # same single-channel sample, mosaic always applied
    always_mosaic = build_mosaic(prob=1, img_scale=(10, 12))
    sample = always_mosaic(sample)
    assert sample['img'].shape[:2] == (20, 24)