@PIPELINES.register_module()
class RandomMosaic(object):
    """Mosaic augmentation.

    Given 4 images, the mosaic transform combines them into one output
    image. The output image is composed of parts from each sub-image.

    .. code:: text

                           mosaic transform
                              center_x
                   +------------------------------+
                   |       pad        |  pad      |
                   |      +-----------+           |
                   |      |           |           |
                   |      |  image1   |--------+  |
                   |      |           |        |  |
                   |      |           | image2 |  |
        center_y   |----+-------------+-----------|
                   |    |   cropped   |           |
                   |pad |   image3    |  image4   |
                   |    |             |           |
                   +----|-------------+-----------+
                        |             |
                        +-------------+

    The mosaic transform steps are as follows:

        1. Choose the mosaic center as the intersection of the 4 images.
        2. Take the current sample as the top-left image and use the three
           samples stored in ``results['mix_results']`` for the remaining
           positions.
        3. Each sub-image is resized (keeping its aspect ratio) to fit
           ``img_scale`` and cropped if it exceeds its mosaic quadrant.

    Args:
        prob (float): Probability of applying the mosaic, in ``[0, 1]``.
        img_scale (tuple[int]): Size ``(h, w)`` that each sub-image is
            fitted into. The output canvas is ``(2 * h, 2 * w)``.
            Default: (640, 640).
        center_ratio_range (Sequence[float]): Range from which the mosaic
            center is sampled, expressed as a ratio of ``img_scale``.
            Default: (0.5, 1.5).
        pad_val (int): Fill value for uncovered image pixels. Default: 0.
        seg_pad_val (int): Fill value for uncovered segmentation-map pixels.
            Default: 255.
    """

    # Paste order; index 0 is the current sample, 1..3 map to mix_results.
    _LOCATIONS = ('top_left', 'top_right', 'bottom_left', 'bottom_right')

    def __init__(self,
                 prob,
                 img_scale=(640, 640),
                 center_ratio_range=(0.5, 1.5),
                 pad_val=0,
                 seg_pad_val=255):
        # Kept as assertions: callers (and the unit tests) expect
        # AssertionError on invalid configuration.
        assert 0 <= prob <= 1
        assert isinstance(img_scale, tuple)
        self.prob = prob
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def __call__(self, results):
        """Apply the mosaic with probability ``self.prob``.

        Args:
            results (dict): Result dict. Must contain ``'mix_results'`` when
                the mosaic is actually applied.

        Returns:
            dict: Result dict, mosaic-transformed or untouched.
        """
        if np.random.rand() < self.prob:
            # Image first: it samples the mosaic center that the
            # segmentation pass reuses.
            results = self._mosaic_transform_img(results)
            results = self._mosaic_transform_seg(results)
        return results

    def get_indexes(self, dataset):
        """Sample the indexes of the three extra images.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            list: Three indexes in ``[0, len(dataset))``.
        """
        # np.random.randint excludes the upper bound, so every index is a
        # valid position in ``dataset``.
        return [np.random.randint(0, len(dataset)) for _ in range(3)]

    def _patch_sources(self, results):
        """Return ``(location, result_dict)`` pairs for the four patches."""
        mix_results = results['mix_results']
        # Explicit indexing so that fewer than 3 mix results still raises.
        sources = [results] + [mix_results[i] for i in range(3)]
        return list(zip(self._LOCATIONS, sources))

    def _resize_and_paste(self, canvas, patch, loc, center_position,
                          **resize_kwargs):
        """Fit ``patch`` into ``img_scale`` and paste it onto ``canvas``.

        Args:
            canvas (ndarray): Mosaic canvas, modified in place.
            patch (ndarray): Sub-image or sub-segmentation map. It is only
                read; ``mmcv.imresize`` produces a new array.
            loc (str): One of ``_LOCATIONS``.
            center_position (tuple[int]): Mosaic center ``(x, y)``.
            **resize_kwargs: Extra keyword arguments for ``mmcv.imresize``
                (e.g. ``interpolation='nearest'`` for label maps).
        """
        h_i, w_i = patch.shape[:2]
        # keep_ratio resize
        scale_ratio = min(self.img_scale[0] / h_i, self.img_scale[1] / w_i)
        patch = mmcv.imresize(
            patch, (int(w_i * scale_ratio), int(h_i * scale_ratio)),
            **resize_kwargs)

        # compute the combine parameters
        paste_coord, crop_coord = self._mosaic_combine(
            loc, center_position, patch.shape[:2][::-1])
        x1_p, y1_p, x2_p, y2_p = paste_coord
        x1_c, y1_c, x2_c, y2_c = crop_coord

        # crop and paste
        canvas[y1_p:y2_p, x1_p:x2_p] = patch[y1_c:y2_c, x1_c:x2_c]

    def _mosaic_transform_img(self, results):
        """Build the mosaic image.

        Samples the mosaic center and stores it on ``self.center_x`` /
        ``self.center_y`` so that :meth:`_mosaic_transform_seg` can align the
        segmentation maps with the image.

        Args:
            results (dict): Result dict containing ``'img'`` and
                ``'mix_results'``.

        Returns:
            dict: Result dict with ``'img'``, ``'img_shape'`` and
            ``'ori_shape'`` updated.
        """
        assert 'mix_results' in results
        img = results['img']
        canvas_hw = (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2))
        # Append the channel axis (if any) of the input instead of assuming
        # exactly 3 channels; 2-D inputs yield a 2-D canvas.
        mosaic_img = np.full(
            canvas_hw + tuple(img.shape[2:]), self.pad_val, dtype=img.dtype)

        # mosaic center x, y (x is drawn first, then y)
        self.center_x = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[1])
        self.center_y = int(
            np.random.uniform(*self.center_ratio_range) * self.img_scale[0])
        center_position = (self.center_x, self.center_y)

        for loc, source in self._patch_sources(results):
            self._resize_and_paste(mosaic_img, source['img'], loc,
                                   center_position)

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['ori_shape'] = mosaic_img.shape

        return results

    def _mosaic_transform_seg(self, results):
        """Build the mosaic for every segmentation map in ``seg_fields``.

        Must run after :meth:`_mosaic_transform_img`, which sets the mosaic
        center read here.

        Args:
            results (dict): Result dict containing ``'mix_results'``.

        Returns:
            dict: Result dict with each ``seg_fields`` entry replaced by its
            mosaic.
        """
        assert 'mix_results' in results
        # Reuse the center sampled for the image so labels stay aligned.
        center_position = (self.center_x, self.center_y)
        canvas_hw = (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2))

        for key in results.get('seg_fields', []):
            mosaic_seg = np.full(
                canvas_hw, self.seg_pad_val, dtype=results[key].dtype)

            for loc, source in self._patch_sources(results):
                # Nearest-neighbour keeps label ids intact when resizing.
                self._resize_and_paste(
                    mosaic_seg,
                    source[key],
                    loc,
                    center_position,
                    interpolation='nearest')

            results[key] = mosaic_seg

        return results

    def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
        """Calculate where a sub-image is pasted and which part is kept.

        Args:
            loc (str): Position of the sub-image, one of ('top_left',
                'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for the 4
                images, (x, y).
            img_shape_wh (Sequence[int]): Width and height of the sub-image.

        Returns:
            tuple[tuple[float]]: Corresponding coordinates of pasting and
            cropping.

            - paste_coord (tuple): ``(x1, y1, x2, y2)`` corners in the
              mosaic image.
            - crop_coord (tuple): ``(x1, y1, x2, y2)`` corners in the
              sub-image.
        """
        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        center_x, center_y = center_position_xy[0], center_position_xy[1]
        width, height = img_shape_wh[0], img_shape_wh[1]
        canvas_w = self.img_scale[1] * 2
        canvas_h = self.img_scale[0] * 2

        if loc == 'top_left':
            # index0: bottom-right corner of the patch sits on the center
            x1, y1 = max(center_x - width, 0), max(center_y - height, 0)
            x2, y2 = center_x, center_y
            crop_coord = (width - (x2 - x1), height - (y2 - y1), width,
                          height)

        elif loc == 'top_right':
            # index1: bottom-left corner of the patch sits on the center
            x1, y1 = center_x, max(center_y - height, 0)
            x2, y2 = min(center_x + width, canvas_w), center_y
            crop_coord = (0, height - (y2 - y1), min(width, x2 - x1), height)

        elif loc == 'bottom_left':
            # index2: top-right corner of the patch sits on the center
            x1, y1 = max(center_x - width, 0), center_y
            x2, y2 = center_x, min(canvas_h, center_y + height)
            crop_coord = (width - (x2 - x1), 0, width, min(y2 - y1, height))

        else:
            # index3: top-left corner of the patch sits on the center
            x1, y1 = center_x, center_y
            x2 = min(center_x + width, canvas_w)
            y2 = min(canvas_h, center_y + height)
            crop_coord = (0, 0, min(width, x2 - x1), min(y2 - y1, height))

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'seg_pad_val={self.seg_pad_val})'
        return repr_str
def test_mosaic():
    """Check RandomMosaic argument validation and output shapes."""
    # prob outside [0, 1] is rejected
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1.5)
        build_from_cfg(transform, PIPELINES)
    # img_scale must be a tuple
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1, img_scale=640)
        build_from_cfg(transform, PIPELINES)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))

    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']

    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    assert 'Mosaic' in repr(mosaic_module)

    # applying the mosaic without 'mix_results' must fail
    with pytest.raises(AssertionError):
        mosaic_module(results)

    # colour image: canvas is twice img_scale in each dimension
    results['mix_results'] = [copy.deepcopy(results) for _ in range(3)]
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)
    assert results['img'].dtype == img.dtype
    # the segmentation map goes through the same mosaic
    assert results['gt_semantic_seg'].shape == (20, 24)
    assert results['gt_semantic_seg'].dtype == seg.dtype

    # single-channel image
    results = dict()
    results['img'] = img[:, :, 0]
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']

    # prob=0 leaves the sample untouched
    transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    results['mix_results'] = [copy.deepcopy(results) for _ in range(3)]
    results = mosaic_module(results)
    assert results['img'].shape[:2] == img.shape[:2]
    assert results['gt_semantic_seg'].shape == seg.shape

    # prob=1 on the single-channel sample
    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)
    assert results['gt_semantic_seg'].shape == (20, 24)