mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add Mosaic transform (#1093)
* Fix typo in usage example * original mosaic code in mmdet * Adjust mosaic to the semantic segmentation * Remove bbox test in test_mosaic * Add unittests * Fix resize mode for seg_fields * Fix repr error * modify Mosaic docs * modify from Mosaic to RandomMosaic * Add docstring * modify Mosaic docstring * [Docs] Add a blank line before Returns: * add blank lines Co-authored-by: MeowZheng <meowzheng@outlook.com>
This commit is contained in:
parent
5b360d9ab1
commit
e9dd32b6e8
@ -6,13 +6,14 @@ from .loading import LoadAnnotations, LoadImageFromFile
|
||||
from .test_time_aug import MultiScaleFlipAug
|
||||
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
|
||||
SegRescale)
|
||||
RandomFlip, RandomMosaic, RandomRotate, Rerange,
|
||||
Resize, RGB2Gray, SegRescale)
|
||||
|
||||
__all__ = [
|
||||
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
|
||||
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
|
||||
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
||||
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut'
|
||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
|
||||
'RandomMosaic'
|
||||
]
|
||||
|
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils import deprecated_api_warning, is_tuple_of
|
||||
@ -1040,3 +1042,270 @@ class RandomCutOut(object):
|
||||
repr_str += f'fill_in={self.fill_in}, '
|
||||
repr_str += f'seg_fill_in={self.seg_fill_in})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
class RandomMosaic(object):
    """Mosaic augmentation.

    Given 4 images, the mosaic transform combines them into one output
    image. The output image is composed of parts of each sub-image.

    .. code:: text

                            mosaic transform
                               center_x
                    +------------------------------+
                    |       pad        |  pad      |
                    |      +-----------+           |
                    |      |           |           |
                    |      |  image1   |--------+  |
                    |      |           |        |  |
                    |      |           | image2 |  |
         center_y   |----+-------------+-----------|
                    |    |   cropped   |           |
                    |pad |   image3    |  image4   |
                    |    |             |           |
                    +----|-------------+-----------+
                         |             |
                         +-------------+

    The mosaic transform steps are as follows:

    1. Choose the mosaic center as the intersection of the 4 images.
    2. Use the current sample as the top-left image and take the other
       3 images from ``results['mix_results']`` (the indexes of those
       samples are drawn by :meth:`get_indexes`).
    3. Each sub-image is resized with its aspect ratio kept so that it fits
       into ``img_scale`` and is cropped if it is larger than its mosaic
       patch.

    Args:
        prob (float): Probability of applying the mosaic, in ``[0, 1]``.
        img_scale (tuple[int]): Size ``(height, width)`` of a single
            sub-image after resizing. The output image has shape
            ``(2 * height, 2 * width)``, i.e. four times the area of a
            single image. Default: (640, 640).
        center_ratio_range (Sequence[float]): Range from which the mosaic
            center is sampled, expressed as a ratio of ``img_scale``.
            Default: (0.5, 1.5).
        pad_val (int): Pad value of the image. Default: 0.
        seg_pad_val (int): Pad value of the segmentation map. Default: 255.
    """

    # Placement order of the sub-images: the current sample first, then
    # ``mix_results[0]``, ``mix_results[1]`` and ``mix_results[2]``.
    _LOCATIONS = ('top_left', 'top_right', 'bottom_left', 'bottom_right')

    def __init__(self,
                 prob,
                 img_scale=(640, 640),
                 center_ratio_range=(0.5, 1.5),
                 pad_val=0,
                 seg_pad_val=255):
        # Assertions are kept (rather than raising ValueError/TypeError)
        # because callers and the unit tests expect ``AssertionError``.
        assert 0 <= prob <= 1
        assert isinstance(img_scale, tuple)
        self.prob = prob
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def __call__(self, results):
        """Call function to make a mosaic of image.

        Args:
            results (dict): Result dict. Must contain ``'mix_results'`` when
                the mosaic is applied.

        Returns:
            dict: Result dict with mosaic transformed.
        """
        # One random draw per call decides whether the sample is mosaicked.
        if np.random.rand() < self.prob:
            # The image pass samples the mosaic center and stores it on
            # ``self``; the segmentation pass reuses it, so the order of
            # these two calls matters.
            results = self._mosaic_transform_img(results)
            results = self._mosaic_transform_seg(results)
        return results

    def get_indexes(self, dataset):
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            list: Three indexes drawn uniformly from
            ``[0, len(dataset))``.
        """
        # ``np.random.randint`` excludes the upper bound, so every index is
        # valid for ``dataset``. The NumPy generator is used explicitly to
        # match ``__call__`` and to avoid depending on which ``random``
        # module is imported at file level (``random.randint`` from the
        # standard library would include ``len(dataset)``).
        return [np.random.randint(0, len(dataset)) for _ in range(3)]

    def _mosaic_transform_img(self, results):
        """Mosaic transform function for the image.

        Samples a new mosaic center (stored as ``self.center_x`` and
        ``self.center_y``) and replaces ``results['img']`` by the mosaic.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        img = results['img']
        mix_results = results['mix_results']

        # The canvas keeps whatever trailing dimensions the input image has
        # (none for gray images, the channel count for color images).
        canvas_shape = (int(self.img_scale[0] * 2),
                        int(self.img_scale[1] * 2)) + tuple(img.shape[2:])
        mosaic_img = np.full(canvas_shape, self.pad_val, dtype=img.dtype)

        # mosaic center x, y (x is sampled first, then y)
        self.center_x = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[1])
        self.center_y = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[0])

        patches = [img] + [mix_results[i]['img'] for i in range(3)]
        mosaic_img = self._paste_patches(mosaic_img, patches)

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['ori_shape'] = mosaic_img.shape

        return results

    def _mosaic_transform_seg(self, results):
        """Mosaic transform function for label annotations.

        Uses the mosaic center stored by :meth:`_mosaic_transform_img`, so
        that method has to be called first.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        mix_results = results['mix_results']

        for key in results.get('seg_fields', []):
            mosaic_seg = np.full(
                (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
                self.seg_pad_val,
                dtype=results[key].dtype)

            patches = [results[key]] + [mix_results[i][key] for i in range(3)]
            # Nearest-neighbour resizing keeps the label values intact.
            results[key] = self._paste_patches(
                mosaic_seg, patches, interpolation='nearest')

        return results

    def _paste_patches(self, canvas, patches, **resize_kwargs):
        """Resize four arrays and paste them around the mosaic center.

        The input arrays are only read (``mmcv.imresize`` returns a new
        array), so no defensive copy of the result dicts is needed.

        Args:
            canvas (np.ndarray): Pre-filled mosaic array, modified in place.
            patches (Sequence[np.ndarray]): Arrays for the top-left,
                top-right, bottom-left and bottom-right positions.
            **resize_kwargs: Extra keyword arguments forwarded to
                ``mmcv.imresize`` (e.g. ``interpolation``).

        Returns:
            np.ndarray: ``canvas`` with the patches pasted in.
        """
        center_position = (self.center_x, self.center_y)

        for loc, patch in zip(self._LOCATIONS, patches):
            h_i, w_i = patch.shape[:2]
            # keep_ratio resize: the patch must fit inside ``img_scale``
            scale_ratio_i = min(self.img_scale[0] / h_i,
                                self.img_scale[1] / w_i)
            patch = mmcv.imresize(
                patch, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)),
                **resize_kwargs)

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, patch.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste
            canvas[y1_p:y2_p, x1_p:x2_p] = patch[y1_c:y2_c, x1_c:x2_c]

        return canvas

    def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
        """Calculate global coordinate of mosaic image and local coordinate of
        cropped sub-image.

        Args:
            loc (str): Index for the sub-image, loc in ('top_left',
                'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for 4 images,
                (x, y).
            img_shape_wh (Sequence[int]): Width and height of sub-image

        Returns:
            tuple[tuple[float]]: Corresponding coordinate of pasting and
            cropping

            - paste_coord (tuple): paste corner coordinate in mosaic image,
              ``(x1, y1, x2, y2)``.
            - crop_coord (tuple): crop corner coordinate in the sub-image,
              ``(x1, y1, x2, y2)``.
        """
        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')

        center_x, center_y = center_position_xy[0], center_position_xy[1]
        img_w, img_h = img_shape_wh[0], img_shape_wh[1]
        mosaic_h = self.img_scale[0] * 2
        mosaic_w = self.img_scale[1] * 2

        if loc == 'top_left':
            # index0 to top left part of image: the sub-image's bottom-right
            # corner sits on the center, the rest is clipped at the border
            x1 = max(center_x - img_w, 0)
            y1 = max(center_y - img_h, 0)
            x2 = center_x
            y2 = center_y
            crop_coord = (img_w - (x2 - x1), img_h - (y2 - y1), img_w, img_h)

        elif loc == 'top_right':
            # index1 to top right part of image: bottom-left corner on the
            # center
            x1 = center_x
            y1 = max(center_y - img_h, 0)
            x2 = min(center_x + img_w, mosaic_w)
            y2 = center_y
            crop_coord = (0, img_h - (y2 - y1), min(img_w, x2 - x1), img_h)

        elif loc == 'bottom_left':
            # index2 to bottom left part of image: top-right corner on the
            # center
            x1 = max(center_x - img_w, 0)
            y1 = center_y
            x2 = center_x
            y2 = min(mosaic_h, center_y + img_h)
            crop_coord = (img_w - (x2 - x1), 0, img_w, min(y2 - y1, img_h))

        else:
            # index3 to bottom right part of image: top-left corner on the
            # center
            x1 = center_x
            y1 = center_y
            x2 = min(center_x + img_w, mosaic_w)
            y2 = min(mosaic_h, center_y + img_h)
            crop_coord = (0, 0, min(img_w, x2 - x1), min(y2 - y1, img_h))

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        # Report the segmentation pad value, not the image pad value.
        repr_str += f'seg_pad_val={self.seg_pad_val})'
        return repr_str
|
||||
|
@ -614,3 +614,52 @@ def test_cutout():
|
||||
cutout_result = cutout_module(copy.deepcopy(results))
|
||||
assert cutout_result['img'].sum() > img.sum()
|
||||
assert cutout_result['gt_semantic_seg'].sum() > seg.sum()
|
||||
|
||||
|
||||
def test_mosaic():
    """Check argument validation and output shapes of ``RandomMosaic``."""

    def build_mosaic(**kwargs):
        # Build the transform through the registry, as a config would.
        return build_from_cfg(dict(type='RandomMosaic', **kwargs), PIPELINES)

    # test prob
    with pytest.raises(AssertionError):
        build_mosaic(prob=1.5)
    # test assertion for invalid img_scale
    with pytest.raises(AssertionError):
        build_mosaic(prob=1, img_scale=640)

    test_dir = osp.dirname(__file__)
    img = mmcv.imread(osp.join(test_dir, '../data/color.jpg'), 'color')
    seg = np.array(Image.open(osp.join(test_dir, '../data/seg.png')))

    # Color input: the mosaic is always applied (prob=1).
    color_sample = dict(
        img=img, gt_semantic_seg=seg, seg_fields=['gt_semantic_seg'])
    always_mosaic = build_mosaic(prob=1, img_scale=(10, 12))
    assert 'Mosaic' in repr(always_mosaic)

    # test assertion for invalid mix_results
    with pytest.raises(AssertionError):
        always_mosaic(color_sample)

    color_sample['mix_results'] = [copy.deepcopy(color_sample)] * 3
    color_sample = always_mosaic(color_sample)
    # Output is twice img_scale along each axis.
    assert color_sample['img'].shape[:2] == (20, 24)

    # Gray (2-D) input: first with the mosaic disabled (prob=0) ...
    gray_sample = dict(
        img=img[:, :, 0],
        gt_semantic_seg=seg,
        seg_fields=['gt_semantic_seg'])
    never_mosaic = build_mosaic(prob=0, img_scale=(10, 12))
    gray_sample['mix_results'] = [copy.deepcopy(gray_sample)] * 3
    gray_sample = never_mosaic(gray_sample)
    assert gray_sample['img'].shape[:2] == img.shape[:2]

    # ... then with the mosaic always applied.
    always_mosaic = build_mosaic(prob=1, img_scale=(10, 12))
    gray_sample = always_mosaic(gray_sample)
    assert gray_sample['img'].shape[:2] == (20, 24)
|
||||
|
Loading…
x
Reference in New Issue
Block a user