mmsegmentation/mmseg/datasets/transforms/transforms.py

1150 lines
37 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Sequence, Tuple, Union
import mmcv
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmengine.utils import is_tuple_of
from numpy import random
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS
@TRANSFORMS.register_module()
class ResizeToMultiple(BaseTransform):
    """Resize the image and the segmentation maps so that their spatial size
    is a multiple of ``size_divisor``.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - img_shape
    - pad_shape

    Args:
        size_divisor (int): images and gt seg maps need to resize to multiple
            of size_divisor. Default: 32.
        interpolation (str, optional): The interpolation mode of image resize.
            When falsy, 'bilinear' is used. Default: None
    """

    def __init__(self, size_divisor=32, interpolation=None):
        self.size_divisor = size_divisor
        self.interpolation = interpolation

    def transform(self, results: dict) -> dict:
        """Align the image and every map listed in ``seg_fields`` to a
        multiple of ``size_divisor``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results, 'img_shape', 'pad_shape' keys are updated.
        """
        divisor = self.size_divisor
        img_mode = self.interpolation if self.interpolation else 'bilinear'

        # The image uses the configured interpolation (bilinear fallback).
        resized_img = mmcv.imresize_to_multiple(
            results['img'], divisor, scale_factor=1, interpolation=img_mode)
        results['img'] = resized_img
        spatial_shape = resized_img.shape[:2]
        results['img_shape'] = spatial_shape
        results['pad_shape'] = spatial_shape

        # Label maps are resized with nearest-neighbour so ids stay intact.
        for seg_key in results.get('seg_fields', []):
            results[seg_key] = mmcv.imresize_to_multiple(
                results[seg_key],
                divisor,
                scale_factor=1,
                interpolation='nearest')
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(size_divisor={self.size_divisor}, '
                f'interpolation={self.interpolation})')
@TRANSFORMS.register_module()
class Rerange(BaseTransform):
    """Linearly stretch the image pixel values into
    ``[min_value, max_value]``.

    The smallest pixel of the input is mapped to ``min_value`` and the
    largest one to ``max_value``.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        min_value (float or int): Minimum value of the reranged image.
            Default: 0.
        max_value (float or int): Maximum value of the reranged image.
            Default: 255.
    """

    def __init__(self, min_value=0, max_value=255):
        assert isinstance(min_value, (float, int))
        assert isinstance(max_value, (float, int))
        assert min_value < max_value
        self.min_value = min_value
        self.max_value = max_value

    def transform(self, results: dict) -> dict:
        """Stretch ``results['img']`` into the configured value range.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Reranged results.

        Raises:
            AssertionError: If the image is constant (min == max).
        """
        img = results['img']
        lowest, highest = np.min(img), np.max(img)
        # A constant image cannot be stretched (the denominator would be 0).
        assert lowest < highest
        # First normalise to [0, 1], then scale/shift to the target range.
        unit_img = (img - lowest) / (highest - lowest)
        target_span = self.max_value - self.min_value
        results['img'] = unit_img * target_span + self.min_value
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(min_value={self.min_value}, max_value={self.max_value})')
@TRANSFORMS.register_module()
class CLAHE(BaseTransform):
    """Apply Contrast Limited Adaptive Histogram Equalization to every image
    channel independently.

    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
    Graphics Gems, 1994:474-485.` for more information.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
            Input image will be divided into equally sized rectangular tiles.
            It defines the number of tiles in row and column. Default: (8, 8).
    """

    def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
        assert isinstance(clip_limit, (float, int))
        assert is_tuple_of(tile_grid_size, int)
        assert len(tile_grid_size) == 2
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def transform(self, results: dict) -> dict:
        """Run CLAHE on each channel of ``results['img']`` in place.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """
        img = results['img']
        for channel in range(img.shape[2]):
            # Each channel is equalised as a uint8 copy; the result is
            # written back into the original array.
            plane_u8 = np.array(img[:, :, channel], dtype=np.uint8)
            img[:, :, channel] = mmcv.clahe(plane_u8, self.clip_limit,
                                            self.tile_grid_size)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(clip_limit={self.clip_limit}, '
                f'tile_grid_size={self.tile_grid_size})')
@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
    """Random crop the image & seg.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map

    Args:
        crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
            with the format of (h, w). If set to an integer, then cropping
            width and height are equal to this integer.
        cat_max_ratio (float): The maximum ratio that single category could
            occupy. Values >= 1 disable the check. Default: 1.0.
        ignore_index (int): The label index to be ignored. Default: 255
    """

    def __init__(self,
                 crop_size: Union[int, Tuple[int, int]],
                 cat_max_ratio: float = 1.,
                 ignore_index: int = 255):
        super().__init__()
        # The message is parenthesised so both literals are concatenated into
        # one assertion message; without the parentheses the second literal
        # would be a separate, discarded expression statement.
        assert isinstance(crop_size, int) or (
            isinstance(crop_size, tuple) and len(crop_size) == 2
        ), ('The expected crop_size is an integer, or a tuple containing two '
            'integers')
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index

    @cache_randomness
    def crop_bbox(self, results: dict) -> tuple:
        """Get a random crop bounding box.

        When ``cat_max_ratio < 1`` the box is checked against
        ``results['gt_seg_map']``. A box is accepted if the crop contains more
        than one non-ignored category and the most frequent one covers less
        than ``cat_max_ratio`` of the non-ignored pixels. Otherwise a new box
        is sampled. After 10 failed checks the last sampled (unchecked) box is
        returned.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            tuple: Coordinates ``(y1, y2, x1, x2)`` of the cropped region.
        """

        def generate_crop_bbox(img: np.ndarray) -> tuple:
            """Randomly sample a crop box inside ``img``.

            If ``img`` is smaller than ``crop_size`` along an axis, the offset
            on that axis is 0 and the box extends past the image border.
            Slicing in :meth:`crop` then just truncates it.

            Args:
                img (np.ndarray): Original input image.

            Returns:
                tuple: Coordinates of the cropped image.
            """
            margin_h = max(img.shape[0] - self.crop_size[0], 0)
            margin_w = max(img.shape[1] - self.crop_size[1], 0)
            offset_h = np.random.randint(0, margin_h + 1)
            offset_w = np.random.randint(0, margin_w + 1)
            crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
            crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
            return crop_y1, crop_y2, crop_x1, crop_x2

        img = results['img']
        crop_bbox = generate_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Repeat 10 times
            for _ in range(10):
                seg_temp = self.crop(results['gt_seg_map'], crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and np.max(cnt) / np.sum(
                        cnt) < self.cat_max_ratio:
                    break
                crop_bbox = generate_crop_bbox(img)
        return crop_bbox

    def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
        """Crop from ``img``.

        Args:
            img (np.ndarray): Original input image.
            crop_bbox (tuple): Coordinates ``(y1, y2, x1, x2)`` of the crop.

        Returns:
            np.ndarray: The cropped image (a view of ``img``).
        """
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
        return img

    def transform(self, results: dict) -> dict:
        """Transform function to randomly crop images, semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated to the (h, w) of the crop.
        """
        img = results['img']
        crop_bbox = self.crop_bbox(results)

        # crop the image
        img = self.crop(img, crop_bbox)

        # crop semantic seg with the same box
        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)

        results['img'] = img
        # Store (h, w) only, without the channel dimension.
        results['img_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'
@TRANSFORMS.register_module()
class RandomRotate(BaseTransform):
    """With probability ``prob``, rotate the image and segmentation maps by an
    angle drawn uniformly from the ``degree`` range.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - gt_seg_map

    Args:
        prob (float): The rotation probability.
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``)
        pad_val (float, optional): Padding value of image. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        center (tuple[float], optional): Center point (w, h) of the rotation in
            the source image. If not specified, the center of the image will be
            used. Default: None.
        auto_bound (bool): Whether to adjust the image size to cover the whole
            rotated image. Default: False
    """

    def __init__(self,
                 prob,
                 degree,
                 pad_val=0,
                 seg_pad_val=255,
                 center=None,
                 auto_bound=False):
        self.prob = prob
        assert 0 <= prob <= 1
        if isinstance(degree, (float, int)):
            assert degree > 0, f'degree {degree} should be positive'
            degree = (-degree, degree)
        self.degree = degree
        assert len(self.degree) == 2, \
            f'degree {self.degree} should be a tuple of (min, max)'
        # The image padding value is stored as ``pal_val`` (historical
        # spelling kept for attribute compatibility).
        self.pal_val = pad_val
        self.seg_pad_val = seg_pad_val
        self.center = center
        self.auto_bound = auto_bound

    @cache_randomness
    def generate_degree(self):
        """Sample ``(do_rotate, angle)`` for one call of :meth:`transform`."""
        do_rotate = np.random.rand() < self.prob
        angle = np.random.uniform(min(*self.degree), max(*self.degree))
        return do_rotate, angle

    def transform(self, results: dict) -> dict:
        """Rotate ``results['img']`` and every map in ``seg_fields`` by the
        same randomly sampled angle.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """
        rotate, degree = self.generate_degree()
        if not rotate:
            return results

        common = dict(
            angle=degree, center=self.center, auto_bound=self.auto_bound)
        results['img'] = mmcv.imrotate(
            results['img'], border_value=self.pal_val, **common)
        # Label maps use nearest-neighbour sampling and their own pad value.
        for key in results.get('seg_fields', []):
            results[key] = mmcv.imrotate(
                results[key],
                border_value=self.seg_pad_val,
                interpolation='nearest',
                **common)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(prob={self.prob}, '
                f'degree={self.degree}, '
                f'pad_val={self.pal_val}, '
                f'seg_pad_val={self.seg_pad_val}, '
                f'center={self.center}, '
                f'auto_bound={self.auto_bound})')
@TRANSFORMS.register_module()
class RGB2Gray(BaseTransform):
    """Convert RGB image to grayscale image.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_shape

    The output is the weighted sum of the input channels (using ``weights``),
    replicated along the channel axis. It has ``out_channels`` channels, or as
    many channels as the input when ``out_channels`` is None.

    Args:
        out_channels (int): Expected number of output channels after
            transforming. Default: None.
        weights (tuple[float]): The weights to calculate the weighted mean.
            Default: (0.299, 0.587, 0.114).
    """

    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
        assert out_channels is None or out_channels > 0
        self.out_channels = out_channels
        assert isinstance(weights, tuple)
        assert all(isinstance(w, (float, int)) for w in weights)
        self.weights = weights

    def transform(self, results: dict) -> dict:
        """Replace ``results['img']`` with its weighted grayscale version.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with grayscale image.
        """
        img = results['img']
        assert len(img.shape) == 3
        assert img.shape[2] == len(self.weights)

        channel_weights = np.array(self.weights).reshape((1, 1, -1))
        gray = (img * channel_weights).sum(2, keepdims=True)
        if self.out_channels is None:
            n_out = channel_weights.shape[2]
        else:
            n_out = self.out_channels
        gray = gray.repeat(n_out, axis=2)

        results['img'] = gray
        results['img_shape'] = gray.shape
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(out_channels={self.out_channels}, '
                f'weights={self.weights})')
@TRANSFORMS.register_module()
class AdjustGamma(BaseTransform):
    """Apply gamma correction to the image through a 256-entry lookup table.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        gamma (float or int): Gamma value used in gamma correction.
            Default: 1.0.
    """

    def __init__(self, gamma=1.0):
        assert isinstance(gamma, (float, int))
        assert gamma > 0
        self.gamma = gamma
        exponent = 1.0 / gamma
        # Entry v holds 255 * (v / 255) ** (1 / gamma), truncated to uint8.
        levels = [(v / 255.0)**exponent * 255 for v in np.arange(256)]
        self.table = np.array(levels).astype('uint8')

    def transform(self, results: dict) -> dict:
        """Map every pixel of ``results['img']`` (cast to uint8) through the
        gamma lookup table.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """
        img_u8 = np.array(results['img'], dtype=np.uint8)
        results['img'] = mmcv.lut_transform(img_u8, self.table)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(gamma={self.gamma})'
@TRANSFORMS.register_module()
class SegRescale(BaseTransform):
    """Rescale every segmentation map listed in ``seg_fields`` by a constant
    factor using nearest-neighbour interpolation.

    Required Keys:

    - gt_seg_map

    Modified Keys:

    - gt_seg_map

    Args:
        scale_factor (float): The scale factor of the final output.
    """

    def __init__(self, scale_factor=1):
        self.scale_factor = scale_factor

    def transform(self, results: dict) -> dict:
        """Rescale the segmentation maps; a factor of 1 is a no-op.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with semantic segmentation map scaled.
        """
        if self.scale_factor == 1:
            return results
        for seg_key in results.get('seg_fields', []):
            results[seg_key] = mmcv.imrescale(
                results[seg_key], self.scale_factor, interpolation='nearest')
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(scale_factor={self.scale_factor})'
@TRANSFORMS.register_module()
class PhotoMetricDistortion(BaseTransform):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The random contrast step runs either
    right after brightness or as the very last step, each with probability
    0.5.

    1. random brightness
    2. random contrast (mode 1)
    3. random saturation (converts BGR -> HSV -> BGR internally)
    4. random hue (converts BGR -> HSV -> BGR internally)
    5. random contrast (mode 0)

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta: int = 32,
                 contrast_range: Sequence[float] = (0.5, 1.5),
                 saturation_range: Sequence[float] = (0.5, 1.5),
                 hue_delta: int = 18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self,
                img: np.ndarray,
                alpha: float = 1,
                beta: float = 0) -> np.ndarray:
        """Multiply with ``alpha``, add ``beta`` and clip to [0, 255].

        Args:
            img (np.ndarray): The input image.
            alpha (float): Image weights, change the contrast/saturation
                of the image. Default: 1
            beta (float): Image bias, change the brightness of the
                image. Default: 0

        Returns:
            np.ndarray: The transformed image, cast to ``uint8``.
        """
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img: np.ndarray) -> np.ndarray:
        """Brightness distortion.

        With probability 0.5, add an offset drawn uniformly from
        ``[-brightness_delta, brightness_delta)``. Otherwise return ``img``
        unchanged.

        Args:
            img (np.ndarray): The input image.

        Returns:
            np.ndarray: Image after brightness change.
        """
        if random.randint(2):
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img: np.ndarray) -> np.ndarray:
        """Contrast distortion.

        With probability 0.5, scale pixel values by a factor drawn uniformly
        from the contrast range. Otherwise return ``img`` unchanged.

        Args:
            img (np.ndarray): The input image.

        Returns:
            np.ndarray: Image after contrast change.
        """
        if random.randint(2):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower, self.contrast_upper))
        return img

    def saturation(self, img: np.ndarray) -> np.ndarray:
        """Saturation distortion.

        With probability 0.5, scale the S channel in HSV space by a factor
        drawn from the saturation range.

        Args:
            img (np.ndarray): The input image.

        Returns:
            np.ndarray: Image after saturation change.
        """
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img: np.ndarray) -> np.ndarray:
        """Hue distortion.

        With probability 0.5, shift the H channel in HSV space by a random
        integer, wrapping modulo 180.

        Args:
            img (np.ndarray): The input image.

        Returns:
            np.ndarray: Image after hue change.
        """
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            # NOTE(review): ``randint``'s upper bound is exclusive, so the
            # shift is drawn from [-hue_delta, hue_delta - 1].
            img[:, :,
                0] = (img[:, :, 0].astype(int) +
                      random.randint(-self.hue_delta, self.hue_delta)) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def transform(self, results: dict) -> dict:
        """Transform function to perform photometric distortion on images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images distorted.
        """
        img = results['img']
        # random brightness
        img = self.brightness(img)

        # mode == 1 --> do random contrast right after brightness
        # mode == 0 --> do random contrast as the last step
        mode = random.randint(2)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        results['img'] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(brightness_delta={self.brightness_delta}, '
                     f'contrast_range=({self.contrast_lower}, '
                     f'{self.contrast_upper}), '
                     f'saturation_range=({self.saturation_lower}, '
                     f'{self.saturation_upper}), '
                     f'hue_delta={self.hue_delta})')
        return repr_str
@TRANSFORMS.register_module()
class RandomCutOut(BaseTransform):
    """CutOut operation.

    Randomly drop some regions of image used in
    `Cutout <https://arxiv.org/abs/1708.04552>`_.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - gt_seg_map

    Args:
        prob (float): cutout probability.
        n_holes (int | tuple[int, int] | list[int]): Number of regions to be
            dropped. If it is given as a tuple or a list of two ints, the
            number of holes will be randomly selected from the closed interval
            [`n_holes[0]`, `n_holes[1]`].
        cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
            shape (w, h) of dropped regions. It can be `tuple[int, int]` to use
            a fixed cutout shape, or `list[tuple[int, int]]` to randomly choose
            shape from the list.
        cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
            candidate ratio (w_ratio, h_ratio) of dropped regions, relative to
            the image width and height. It can be `tuple[float, float]` to use
            a fixed ratio or `list[tuple[float, float]]` to randomly choose
            ratio from the list. Please note that `cutout_shape` and
            `cutout_ratio` cannot be both given at the same time.
        fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
            of pixel to fill in the dropped regions. Default: (0, 0, 0).
        seg_fill_in (int): The labels of pixel to fill in the dropped regions.
            If seg_fill_in is None, skip. Default: None.
    """

    def __init__(self,
                 prob,
                 n_holes,
                 cutout_shape=None,
                 cutout_ratio=None,
                 fill_in=(0, 0, 0),
                 seg_fill_in=None):

        assert 0 <= prob and prob <= 1
        assert (cutout_shape is None) ^ (cutout_ratio is None), \
            'Either cutout_shape or cutout_ratio should be specified.'
        assert (isinstance(cutout_shape, (list, tuple))
                or isinstance(cutout_ratio, (list, tuple)))
        if isinstance(n_holes, (list, tuple)):
            assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
            # Normalise to a tuple so lists behave exactly like tuples.
            n_holes = tuple(n_holes)
        else:
            n_holes = (n_holes, n_holes)
        if seg_fill_in is not None:
            assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in
                    and seg_fill_in <= 255)
        self.prob = prob
        self.n_holes = n_holes
        self.fill_in = fill_in
        self.seg_fill_in = seg_fill_in
        self.with_ratio = cutout_ratio is not None
        self.candidates = cutout_ratio if self.with_ratio else cutout_shape
        if not isinstance(self.candidates, list):
            self.candidates = [self.candidates]

    @cache_randomness
    def do_cutout(self):
        """Return True with probability ``prob``."""
        return np.random.rand() < self.prob

    @cache_randomness
    def generate_patches(self, results):
        """Sample the number of holes, their top-left corners and candidate
        indices.

        Returns:
            tuple: ``(cutout, n_holes, x1_lst, y1_lst, index_lst)``. When
                ``cutout`` is False, ``n_holes`` is 0 and the lists are empty.
        """
        cutout = self.do_cutout()

        h, w, _ = results['img'].shape
        if cutout:
            n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
        else:
            n_holes = 0
        x1_lst = []
        y1_lst = []
        index_lst = []
        for _ in range(n_holes):
            x1_lst.append(np.random.randint(0, w))
            y1_lst.append(np.random.randint(0, h))
            index_lst.append(np.random.randint(0, len(self.candidates)))
        return cutout, n_holes, x1_lst, y1_lst, index_lst

    def transform(self, results: dict) -> dict:
        """Call function to drop some regions of image.

        Each hole starts at its sampled top-left corner and is clipped at the
        image border. Image pixels are set to ``fill_in``. When
        ``seg_fill_in`` is not None, the maps in ``seg_fields`` are set to it
        as well.
        """
        cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches(
            results)
        if cutout:
            h, w = results['img'].shape[:2]
            for i in range(n_holes):
                x1 = x1_lst[i]
                y1 = y1_lst[i]
                index = index_lst[i]
                if not self.with_ratio:
                    cutout_w, cutout_h = self.candidates[index]
                else:
                    cutout_w = int(self.candidates[index][0] * w)
                    cutout_h = int(self.candidates[index][1] * h)

                x2 = np.clip(x1 + cutout_w, 0, w)
                y2 = np.clip(y1 + cutout_h, 0, h)
                results['img'][y1:y2, x1:x2, :] = self.fill_in

                if self.seg_fill_in is not None:
                    for key in results.get('seg_fields', []):
                        results[key][y1:y2, x1:x2] = self.seg_fill_in

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'n_holes={self.n_holes}, '
        repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
                     else f'cutout_shape={self.candidates}, ')
        repr_str += f'fill_in={self.fill_in}, '
        repr_str += f'seg_fill_in={self.seg_fill_in})'
        return repr_str
@TRANSFORMS.register_module()
class RandomMosaic(BaseTransform):
    """Mosaic augmentation. Given 4 images, mosaic transform combines them into
    one output image. The output image is composed of the parts from each sub-
    image.

    .. code:: text

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |  pad      |
                |      +-----------+           |
                |      |           |           |
                |      |  image1   |--------+  |
                |      |           |        |  |
                |      |           | image2 |  |
     center_y   |----+-------------+-----------|
                |    |   cropped   |           |
                |pad |   image3    |  image4   |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

     The mosaic transform steps are as follows:
         1. Choose the mosaic center as the intersections of 4 images
         2. Get the left top image according to the index, and randomly
            sample another 3 images from the custom dataset.
         3. Sub image will be cropped if image is larger than mosaic patch

    Required Keys:

    - img
    - gt_seg_map
    - mix_results

    Modified Keys:

    - img
    - img_shape
    - ori_shape
    - gt_seg_map

    Args:
        prob (float): mosaic probability.
        img_scale (Sequence[int]): Image size (h, w) after mosaic pipeline of
            a single image. The output canvas is (2 * h, 2 * w), i.e. four
            times the area of a single image. Default: (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. Default: (0.5, 1.5).
        pad_val (int): Pad value. Default: 0.
        seg_pad_val (int): Pad value of segmentation map. Default: 255.
    """

    def __init__(self,
                 prob,
                 img_scale=(640, 640),
                 center_ratio_range=(0.5, 1.5),
                 pad_val=0,
                 seg_pad_val=255):
        assert 0 <= prob and prob <= 1
        assert isinstance(img_scale, tuple)
        self.prob = prob
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    @cache_randomness
    def do_mosaic(self):
        """Return True with probability ``prob``."""
        return np.random.rand() < self.prob

    def transform(self, results: dict) -> dict:
        """Call function to make a mosaic of image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with mosaic transformed.
        """
        mosaic = self.do_mosaic()
        if mosaic:
            # The image pass samples the mosaic center and stores it on
            # ``self``; the segmentation pass reuses it.
            results = self._mosaic_transform_img(results)
            results = self._mosaic_transform_seg(results)
        return results

    def get_indices(self, dataset: MultiImageMixDataset) -> list:
        """Call function to collect indexes.

        Args:
            dataset (:obj:`MultiImageMixDataset`): The dataset.

        Returns:
            list: 3 random indexes in ``[0, len(dataset))`` (numpy's
                ``randint`` upper bound is exclusive).
        """
        indexes = [random.randint(0, len(dataset)) for _ in range(3)]
        return indexes

    @cache_randomness
    def generate_mosaic_center(self):
        """Sample the mosaic center ``(x, y)`` in canvas pixel coordinates."""
        center_x = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[1])
        center_y = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[0])
        return center_x, center_y

    def _mosaic_transform_img(self, results: dict) -> dict:
        """Mosaic transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        canvas_hw = (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2))
        if len(results['img'].shape) == 3:
            canvas_shape = canvas_hw + (3, )
        else:
            canvas_shape = canvas_hw
        mosaic_img = np.full(
            canvas_shape, self.pad_val, dtype=results['img'].dtype)

        # mosaic center x, y (kept on self for _mosaic_transform_seg)
        self.center_x, self.center_y = self.generate_mosaic_center()
        center_position = (self.center_x, self.center_y)

        loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        for i, loc in enumerate(loc_strs):
            # Source images are only read here (resized, then sliced into the
            # canvas), so the result dicts do not need to be copied.
            if loc == 'top_left':
                result_patch = results
            else:
                result_patch = results['mix_results'][i - 1]

            img_i = result_patch['img']
            h_i, w_i = img_i.shape[:2]
            # keep_ratio resize
            scale_ratio_i = min(self.img_scale[0] / h_i,
                                self.img_scale[1] / w_i)
            img_i = mmcv.imresize(
                img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, img_i.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste image
            mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['ori_shape'] = mosaic_img.shape

        return results

    def _mosaic_transform_seg(self, results: dict) -> dict:
        """Mosaic transform function for label annotations.

        Must run after :meth:`_mosaic_transform_img`, which sets
        ``self.center_x`` / ``self.center_y``.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        for key in results.get('seg_fields', []):
            mosaic_seg = np.full(
                (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
                self.seg_pad_val,
                dtype=results[key].dtype)

            # mosaic center x, y
            center_position = (self.center_x, self.center_y)

            loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
            for i, loc in enumerate(loc_strs):
                # Source maps are only read, so no copy is needed.
                if loc == 'top_left':
                    result_patch = results
                else:
                    result_patch = results['mix_results'][i - 1]

                gt_seg_i = result_patch[key]
                h_i, w_i = gt_seg_i.shape[:2]
                # keep_ratio resize
                scale_ratio_i = min(self.img_scale[0] / h_i,
                                    self.img_scale[1] / w_i)
                gt_seg_i = mmcv.imresize(
                    gt_seg_i,
                    (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)),
                    interpolation='nearest')

                # compute the combine parameters
                paste_coord, crop_coord = self._mosaic_combine(
                    loc, center_position, gt_seg_i.shape[:2][::-1])
                x1_p, y1_p, x2_p, y2_p = paste_coord
                x1_c, y1_c, x2_c, y2_c = crop_coord

                # crop and paste label map
                mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c,
                                                            x1_c:x2_c]

            results[key] = mosaic_seg

        return results

    def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float],
                        img_shape_wh: Sequence[int]) -> tuple:
        """Calculate global coordinate of mosaic image and local coordinate of
        cropped sub-image.

        Args:
            loc (str): Index for the sub-image, loc in ('top_left',
              'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for 4 images,
                (x, y).
            img_shape_wh (Sequence[int]): Width and height of sub-image

        Returns:
            tuple[tuple[float]]: Corresponding coordinate of pasting and
                cropping
                - paste_coord (tuple): paste corner coordinate (x1, y1, x2, y2)
                  in mosaic image.
                - crop_coord (tuple): crop corner coordinate (x1, y1, x2, y2)
                  in the sub-image.
        """
        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        if loc == 'top_left':
            # index0 to top left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             center_position_xy[0], \
                             center_position_xy[1]
            crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
                y2 - y1), img_shape_wh[0], img_shape_wh[1]

        elif loc == 'top_right':
            # index1 to top right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[1] * 2), \
                             center_position_xy[1]
            crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
                img_shape_wh[0], x2 - x1), img_shape_wh[1]

        elif loc == 'bottom_left':
            # index2 to bottom left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             center_position_xy[1], \
                             center_position_xy[0], \
                             min(self.img_scale[0] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
                y2 - y1, img_shape_wh[1])

        else:
            # index3 to bottom right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             center_position_xy[1], \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[1] * 2), \
                             min(self.img_scale[0] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = 0, 0, min(img_shape_wh[0],
                                   x2 - x1), min(y2 - y1, img_shape_wh[1])

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'seg_pad_val={self.seg_pad_val})'
        return repr_str