mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
## Modification I changed the hard-coded 3-channel length to a dynamic channel length in the `np.full` function arguments. This modification enables the `RandomMosaic` transform to support multispectral images (e.g. an RGB image with an NIR band) or bi-temporal image pairs for the change detection task. ## Checklist 1. Pre-commit or other linting tools are used to fix potential lint issues. 2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness. 3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMDet3D. 4. The documentation has been modified accordingly, like docstrings or example tutorials.
2138 lines
72 KiB
Python
2138 lines
72 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import warnings
|
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|
|
|
import cv2
|
|
import mmcv
|
|
import numpy as np
|
|
from mmcv.transforms.base import BaseTransform
|
|
from mmcv.transforms.utils import cache_randomness
|
|
from mmengine.utils import is_tuple_of
|
|
from numpy import random
|
|
from scipy.ndimage import gaussian_filter
|
|
|
|
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
|
from mmseg.registry import TRANSFORMS
|
|
|
|
|
|
@TRANSFORMS.register_module()
class ResizeToMultiple(BaseTransform):
    """Resize the image and segmentation maps so that their height and width
    become multiples of ``size_divisor``.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - img_shape
    - pad_shape

    Args:
        size_divisor (int): Height and width of the image and segmentation
            maps are resized to a multiple of this value. Default: 32.
        interpolation (str, optional): Interpolation mode used for the image.
            When falsy, ``'bilinear'`` is used. Segmentation maps are always
            resized with ``'nearest'``. Default: None.
    """

    def __init__(self, size_divisor=32, interpolation=None):
        self.size_divisor = size_divisor
        self.interpolation = interpolation

    def transform(self, results: dict) -> dict:
        """Resize ``results['img']`` and every key listed in
        ``results['seg_fields']`` to a multiple of ``size_divisor``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: The same dict with resized arrays; 'img_shape' and
            'pad_shape' are set to the new image height and width.
        """
        img_interp = self.interpolation or 'bilinear'
        resized = mmcv.imresize_to_multiple(
            results['img'],
            self.size_divisor,
            scale_factor=1,
            interpolation=img_interp)
        results['img'] = resized
        results['img_shape'] = resized.shape[:2]
        results['pad_shape'] = resized.shape[:2]

        # Nearest-neighbour keeps label values unmixed in segmentation maps.
        for seg_key in results.get('seg_fields', []):
            results[seg_key] = mmcv.imresize_to_multiple(
                results[seg_key],
                self.size_divisor,
                scale_factor=1,
                interpolation='nearest')
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(size_divisor={self.size_divisor}, '
                f'interpolation={self.interpolation})')
|
|
|
|
|
|
@TRANSFORMS.register_module()
class Rerange(BaseTransform):
    """Linearly rescale the image's pixel values into a target range.

    The current minimum of the image is mapped to ``min_value`` and the
    current maximum to ``max_value``; values in between are scaled linearly.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        min_value (float or int): Lower bound of the output range.
            Default: 0.
        max_value (float or int): Upper bound of the output range.
            Default: 255.
    """

    def __init__(self, min_value=0, max_value=255):
        assert isinstance(min_value, (float, int))
        assert isinstance(max_value, (float, int))
        assert min_value < max_value
        self.min_value = min_value
        self.max_value = max_value

    def transform(self, results: dict) -> dict:
        """Rescale ``results['img']`` into ``[min_value, max_value]``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with the rescaled image. An AssertionError is
            raised if the image is constant (its min equals its max).
        """
        img = results['img']
        lo = np.min(img)
        hi = np.max(img)
        assert lo < hi

        # Normalise to [0, 1] first, then stretch into the target range.
        unit = (img - lo) / (hi - lo)
        span = self.max_value - self.min_value
        results['img'] = unit * span + self.min_value
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(min_value={self.min_value}, max_value={self.max_value})')
|
|
|
|
|
|
@TRANSFORMS.register_module()
class CLAHE(BaseTransform):
    """Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to every
    channel of the image independently.

    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
    Graphics Gems, 1994:474-485.` for more information.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
            Input image will be divided into equally sized rectangular tiles.
            It defines the number of tiles in row and column. Default: (8, 8).
    """

    def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
        assert isinstance(clip_limit, (float, int))
        self.clip_limit = clip_limit
        assert is_tuple_of(tile_grid_size, int)
        assert len(tile_grid_size) == 2
        self.tile_grid_size = tile_grid_size

    def transform(self, results: dict) -> dict:
        """Equalize each channel of ``results['img']`` in place.

        The image is indexed as ``img[:, :, c]``, so a channel axis is
        required. Each channel is cast to uint8 before equalization and the
        result is written back into the same array.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """
        img = results['img']
        num_channels = img.shape[2]
        for ch in range(num_channels):
            plane = np.array(img[:, :, ch], dtype=np.uint8)
            img[:, :, ch] = mmcv.clahe(plane, self.clip_limit,
                                       self.tile_grid_size)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(clip_limit={self.clip_limit}, '
                f'tile_grid_size={self.tile_grid_size})')
|
|
|
|
|
|
@TRANSFORMS.register_module()
class RandomCrop(BaseTransform):
    """Randomly crop the image and segmentation maps.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map

    Args:
        crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
            with the format of (h, w). If set to an integer, then cropping
            width and height are equal to this integer.
        cat_max_ratio (float): The maximum ratio that a single category may
            occupy in the crop. When smaller than 1, the crop box is re-drawn
            up to 10 times; if none satisfies the ratio, the most recently
            drawn box is used unchecked. Default: 1.
        ignore_index (int): Label index excluded when computing the category
            ratio. Default: 255.
    """

    def __init__(self,
                 crop_size: Union[int, Tuple[int, int]],
                 cat_max_ratio: float = 1.,
                 ignore_index: int = 255):
        super().__init__()
        # The message is parenthesised so both literals are concatenated into
        # the assertion message (previously the second literal was a separate
        # no-op statement and the message was truncated).
        assert isinstance(crop_size, int) or (
            isinstance(crop_size, tuple) and len(crop_size) == 2
        ), ('The expected crop_size is an integer, or a tuple containing two '
            'integers')

        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index

    @cache_randomness
    def crop_bbox(self, results: dict) -> tuple:
        """Draw a crop box, optionally re-drawing to respect ``cat_max_ratio``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            tuple: ``(crop_y1, crop_y2, crop_x1, crop_x2)`` of the crop.
        """

        def generate_crop_bbox(img: np.ndarray) -> tuple:
            """Randomly place a ``crop_size`` box inside ``img``.

            When the image is smaller than the crop along an axis, the offset
            on that axis is 0 and the box extends past the border; slicing in
            :meth:`crop` then truncates it to the image.

            Args:
                img (np.ndarray): Original input image.

            Returns:
                tuple: Coordinates of the cropped image.
            """

            margin_h = max(img.shape[0] - self.crop_size[0], 0)
            margin_w = max(img.shape[1] - self.crop_size[1], 0)
            offset_h = np.random.randint(0, margin_h + 1)
            offset_w = np.random.randint(0, margin_w + 1)
            crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
            crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]

            return crop_y1, crop_y2, crop_x1, crop_x2

        img = results['img']
        crop_bbox = generate_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Re-draw up to 10 times until the crop contains more than one
            # non-ignored category and none of them dominates.
            for _ in range(10):
                seg_temp = self.crop(results['gt_seg_map'], crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and np.max(cnt) / np.sum(
                        cnt) < self.cat_max_ratio:
                    break
                crop_bbox = generate_crop_bbox(img)

        return crop_bbox

    def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
        """Crop ``img`` along its first two axes.

        Args:
            img (np.ndarray): Original input image.
            crop_bbox (tuple): ``(crop_y1, crop_y2, crop_x1, crop_x2)``.

        Returns:
            np.ndarray: The cropped image (a slice of ``img``).
        """

        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
        return img

    def transform(self, results: dict) -> dict:
        """Transform function to randomly crop images, semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size.
        """

        img = results['img']
        crop_bbox = self.crop_bbox(results)

        # crop the image
        img = self.crop(img, crop_bbox)

        # crop semantic seg with the same box
        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'
|
|
|
|
|
|
@TRANSFORMS.register_module()
class RandomRotate(BaseTransform):
    """Rotate the image & seg.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - img_shape (only when ``auto_bound`` is True and a rotation happens)
    - gt_seg_map

    Args:
        prob (float): The rotation probability.
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``)
        pad_val (float, optional): Padding value of image. Default: 0.
        seg_pad_val (float, optional): Padding value of segmentation map.
            Default: 255.
        center (tuple[float], optional): Center point (w, h) of the rotation in
            the source image. If not specified, the center of the image will be
            used. Default: None.
        auto_bound (bool): Whether to adjust the image size to cover the whole
            rotated image. Default: False
    """

    def __init__(self,
                 prob,
                 degree,
                 pad_val=0,
                 seg_pad_val=255,
                 center=None,
                 auto_bound=False):
        self.prob = prob
        assert prob >= 0 and prob <= 1
        if isinstance(degree, (float, int)):
            assert degree > 0, f'degree {degree} should be positive'
            self.degree = (-degree, degree)
        else:
            self.degree = degree
        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
                                      f'tuple of (min, max)'
        # NOTE: the attribute keeps its historical (misspelled) name
        # ``pal_val`` since other code may read it.
        self.pal_val = pad_val
        self.seg_pad_val = seg_pad_val
        self.center = center
        self.auto_bound = auto_bound

    @cache_randomness
    def generate_degree(self):
        """Return ``(do_rotate, angle)`` drawn from ``prob`` and ``degree``."""
        return np.random.rand() < self.prob, np.random.uniform(
            min(*self.degree), max(*self.degree))

    def transform(self, results: dict) -> dict:
        """Call function to rotate image, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated results.
        """

        rotate, degree = self.generate_degree()
        if rotate:
            # rotate image
            results['img'] = mmcv.imrotate(
                results['img'],
                angle=degree,
                border_value=self.pal_val,
                center=self.center,
                auto_bound=self.auto_bound)

            # rotate segs; nearest keeps label values unmixed
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imrotate(
                    results[key],
                    angle=degree,
                    border_value=self.seg_pad_val,
                    center=self.center,
                    auto_bound=self.auto_bound,
                    interpolation='nearest')

            if self.auto_bound:
                # With auto_bound the output is enlarged to cover the whole
                # rotated image, so the recorded shape must follow it.
                results['img_shape'] = results['img'].shape[:2]
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, ' \
                    f'degree={self.degree}, ' \
                    f'pad_val={self.pal_val}, ' \
                    f'seg_pad_val={self.seg_pad_val}, ' \
                    f'center={self.center}, ' \
                    f'auto_bound={self.auto_bound})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class RGB2Gray(BaseTransform):
    """Convert an RGB image to a grayscale image.

    Required Keys:

    - img

    Modified Keys:

    - img
    - img_shape

    The weighted mean of the input channels is computed with ``weights`` and
    then repeated along the channel axis ``out_channels`` times. When
    ``out_channels`` is None, the number of output channels equals the
    number of weights (which must equal the input channel count).

    Args:
        out_channels (int): Expected number of output channels after
            transforming. Default: None.
        weights (tuple[float]): The weights to calculate the weighted mean.
            Default: (0.299, 0.587, 0.114).
    """

    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
        assert out_channels is None or out_channels > 0
        self.out_channels = out_channels
        assert isinstance(weights, tuple)
        assert all(isinstance(w, (float, int)) for w in weights)
        self.weights = weights

    def transform(self, results: dict) -> dict:
        """Replace ``results['img']`` with its weighted channel mean.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with grayscale image; 'img_shape' is set to the
            full shape of the new image (including channels).
        """
        img = results['img']
        assert len(img.shape) == 3
        assert img.shape[2] == len(self.weights)

        channel_weights = np.array(self.weights).reshape((1, 1, -1))
        gray = (img * channel_weights).sum(2, keepdims=True)
        if self.out_channels is None:
            repeats = channel_weights.shape[2]
        else:
            repeats = self.out_channels
        gray = gray.repeat(repeats, axis=2)

        results['img'] = gray
        results['img_shape'] = gray.shape
        return results

    def __repr__(self):
        cls_name = self.__class__.__name__
        return (f'{cls_name}(out_channels={self.out_channels}, '
                f'weights={self.weights})')
|
|
|
|
|
|
@TRANSFORMS.register_module()
class AdjustGamma(BaseTransform):
    """Apply gamma correction to the image via a 256-entry lookup table.

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        gamma (float or int): Gamma value used in gamma correction; must be
            positive. Default: 1.0.
    """

    def __init__(self, gamma=1.0):
        assert isinstance(gamma, (float, int))
        assert gamma > 0
        self.gamma = gamma
        exponent = 1.0 / gamma
        # Table entry v holds the corrected value for input intensity v.
        lut = [(v / 255.0)**exponent * 255 for v in np.arange(256)]
        self.table = np.array(lut).astype('uint8')

    def transform(self, results: dict) -> dict:
        """Map every pixel of ``results['img']`` through the gamma table.

        The image is converted to uint8 before the lookup.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Processed results.
        """
        as_uint8 = np.array(results['img'], dtype=np.uint8)
        results['img'] = mmcv.lut_transform(as_uint8, self.table)
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(gamma={self.gamma})'
|
|
|
|
|
|
@TRANSFORMS.register_module()
class SegRescale(BaseTransform):
    """Rescale semantic segmentation maps by a constant factor.

    Required Keys:

    - gt_seg_map

    Modified Keys:

    - gt_seg_map

    Args:
        scale_factor (float): The scale factor of the final output. A factor
            of 1 leaves the maps untouched. Default: 1.
    """

    def __init__(self, scale_factor=1):
        self.scale_factor = scale_factor

    def transform(self, results: dict) -> dict:
        """Rescale every key listed in ``results['seg_fields']``.

        Nearest-neighbour interpolation is used so label values stay intact.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with semantic segmentation map scaled.
        """
        if self.scale_factor == 1:
            return results
        for seg_key in results.get('seg_fields', []):
            results[seg_key] = mmcv.imrescale(
                results[seg_key], self.scale_factor, interpolation='nearest')
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(scale_factor={self.scale_factor})'
|
|
|
|
|
|
@TRANSFORMS.register_module()
class PhotoMetricDistortion(BaseTransform):
    """Apply photometric distortion to the image sequentially; each
    transformation is applied with a probability of 0.5. Random contrast is
    applied either second (mode 1) or last (mode 0).

    1. random brightness
    2. random contrast (mode 1)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 0)

    Required Keys:

    - img

    Modified Keys:

    - img

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta: int = 32,
                 contrast_range: Sequence[float] = (0.5, 1.5),
                 saturation_range: Sequence[float] = (0.5, 1.5),
                 hue_delta: int = 18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self,
                img: np.ndarray,
                alpha: int = 1,
                beta: int = 0) -> np.ndarray:
        """Compute ``img * alpha + beta`` in float32, clip to [0, 255] and
        return it as uint8.

        Args:
            img (np.ndarray): The input image.
            alpha (int): Multiplicative factor (callers pass floats); changes
                contrast/saturation. Default: 1
            beta (int): Additive offset (callers pass floats); changes
                brightness. Default: 0

        Returns:
            np.ndarray: The transformed image.
        """
        scaled = img.astype(np.float32) * alpha + beta
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def brightness(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, add a random offset in
        ``[-brightness_delta, brightness_delta]``.

        Args:
            img (np.ndarray): The input image.
        Returns:
            np.ndarray: Image after brightness change.
        """
        if not random.randint(2):
            return img
        delta = random.uniform(-self.brightness_delta, self.brightness_delta)
        return self.convert(img, beta=delta)

    def contrast(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, scale pixels by a random factor in
        ``[contrast_lower, contrast_upper]``.

        Args:
            img (np.ndarray): The input image.
        Returns:
            np.ndarray: Image after contrast change.
        """
        if not random.randint(2):
            return img
        factor = random.uniform(self.contrast_lower, self.contrast_upper)
        return self.convert(img, alpha=factor)

    def saturation(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, scale the HSV saturation channel by a random
        factor in ``[saturation_lower, saturation_upper]``.

        Args:
            img (np.ndarray): The input image (BGR).
        Returns:
            np.ndarray: Image after saturation change.
        """
        if not random.randint(2):
            return img
        hsv = mmcv.bgr2hsv(img)
        factor = random.uniform(self.saturation_lower, self.saturation_upper)
        hsv[:, :, 1] = self.convert(hsv[:, :, 1], alpha=factor)
        return mmcv.hsv2bgr(hsv)

    def hue(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, shift the HSV hue channel by a random integer
        in ``[-hue_delta, hue_delta)`` modulo 180.

        Args:
            img (np.ndarray): The input image (BGR).
        Returns:
            np.ndarray: Image after hue change.
        """
        if not random.randint(2):
            return img
        hsv = mmcv.bgr2hsv(img)
        shift = random.randint(-self.hue_delta, self.hue_delta)
        hsv[:, :, 0] = (hsv[:, :, 0].astype(int) + shift) % 180
        return mmcv.hsv2bgr(hsv)

    def transform(self, results: dict) -> dict:
        """Transform function to perform photometric distortion on images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images distorted.
        """
        img = self.brightness(results['img'])

        # mode 1 -> contrast right after brightness; mode 0 -> contrast last
        contrast_first = random.randint(2) == 1
        if contrast_first:
            img = self.contrast(img)

        img = self.saturation(img)
        img = self.hue(img)

        if not contrast_first:
            img = self.contrast(img)

        results['img'] = img
        return results

    def __repr__(self):
        contrast = f'({self.contrast_lower}, {self.contrast_upper})'
        saturation = f'({self.saturation_lower}, {self.saturation_upper})'
        return (f'{self.__class__.__name__}'
                f'(brightness_delta={self.brightness_delta}, '
                f'contrast_range={contrast}, '
                f'saturation_range={saturation}, '
                f'hue_delta={self.hue_delta})')
|
|
|
|
|
|
@TRANSFORMS.register_module()
class RandomCutOut(BaseTransform):
    """CutOut operation.

    Randomly drop some regions of image used in
    `Cutout <https://arxiv.org/abs/1708.04552>`_.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - gt_seg_map

    Args:
        prob (float): cutout probability.
        n_holes (int | tuple[int, int] | list[int]): Number of regions to be
            dropped. If it is given as a tuple or list of two ints, the
            number of holes will be randomly selected from the closed
            interval [`n_holes[0]`, `n_holes[1]`].
        cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
            shape (w, h) of dropped regions. It can be `tuple[int, int]` to
            use a fixed cutout shape, or `list[tuple[int, int]]` to randomly
            choose shape from the list.
        cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
            candidate ratio (w, h) of dropped regions. It can be
            `tuple[float, float]` to use a fixed ratio or
            `list[tuple[float, float]]` to randomly choose ratio from the
            list. Please note that `cutout_shape` and `cutout_ratio` cannot
            be both given at the same time.
        fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
            of pixel to fill in the dropped regions; it is broadcast over the
            trailing axes of the image, so a scalar is needed for 2-D images.
            Default: (0, 0, 0).
        seg_fill_in (int): The labels of pixel to fill in the dropped regions.
            If seg_fill_in is None, skip. Default: None.
    """

    def __init__(self,
                 prob,
                 n_holes,
                 cutout_shape=None,
                 cutout_ratio=None,
                 fill_in=(0, 0, 0),
                 seg_fill_in=None):

        assert 0 <= prob and prob <= 1
        assert (cutout_shape is None) ^ (cutout_ratio is None), \
            'Either cutout_shape or cutout_ratio should be specified.'
        assert (isinstance(cutout_shape, (list, tuple))
                or isinstance(cutout_ratio, (list, tuple)))
        # Accept a list as documented, normalising it to a tuple.
        if isinstance(n_holes, (tuple, list)):
            assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
            n_holes = tuple(n_holes)
        else:
            n_holes = (n_holes, n_holes)
        if seg_fill_in is not None:
            assert (isinstance(seg_fill_in, int) and 0 <= seg_fill_in
                    and seg_fill_in <= 255)
        self.prob = prob
        self.n_holes = n_holes
        self.fill_in = fill_in
        self.seg_fill_in = seg_fill_in
        self.with_ratio = cutout_ratio is not None
        self.candidates = cutout_ratio if self.with_ratio else cutout_shape
        if not isinstance(self.candidates, list):
            self.candidates = [self.candidates]

    @cache_randomness
    def do_cutout(self):
        """Return True with probability ``prob``."""
        return np.random.rand() < self.prob

    @cache_randomness
    def generate_patches(self, results):
        """Draw the cutout decision and the top-left corner / candidate index
        of every hole.

        Returns:
            tuple: ``(cutout, n_holes, x1_lst, y1_lst, index_lst)``.
        """
        cutout = self.do_cutout()

        # Only the spatial size is needed here.
        h, w = results['img'].shape[:2]
        if cutout:
            n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
        else:
            n_holes = 0
        x1_lst = []
        y1_lst = []
        index_lst = []
        for _ in range(n_holes):
            x1_lst.append(np.random.randint(0, w))
            y1_lst.append(np.random.randint(0, h))
            index_lst.append(np.random.randint(0, len(self.candidates)))
        return cutout, n_holes, x1_lst, y1_lst, index_lst

    def transform(self, results: dict) -> dict:
        """Call function to drop some regions of image."""
        cutout, n_holes, x1_lst, y1_lst, index_lst = self.generate_patches(
            results)
        if cutout:
            h, w = results['img'].shape[:2]
            for i in range(n_holes):
                x1 = x1_lst[i]
                y1 = y1_lst[i]
                index = index_lst[i]
                if not self.with_ratio:
                    cutout_w, cutout_h = self.candidates[index]
                else:
                    cutout_w = int(self.candidates[index][0] * w)
                    cutout_h = int(self.candidates[index][1] * h)

                # Holes near the border are clipped to the image.
                x2 = np.clip(x1 + cutout_w, 0, w)
                y2 = np.clip(y1 + cutout_h, 0, h)
                results['img'][y1:y2, x1:x2, ...] = self.fill_in

                if self.seg_fill_in is not None:
                    for key in results.get('seg_fields', []):
                        results[key][y1:y2, x1:x2] = self.seg_fill_in

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'n_holes={self.n_holes}, '
        repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
                     else f'cutout_shape={self.candidates}, ')
        repr_str += f'fill_in={self.fill_in}, '
        repr_str += f'seg_fill_in={self.seg_fill_in})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class RandomRotFlip(BaseTransform):
    """Rotate and flip the image & seg or just rotate the image & seg.

    Required Keys:

    - img
    - gt_seg_map

    Modified Keys:

    - img
    - gt_seg_map

    Args:
        rotate_prob (float): The probability of an arbitrary-angle rotation.
        flip_prob (float): The probability of a 90-degree rotation plus flip;
            it is only applied when the arbitrary-angle rotation did not
            happen.
        degree (float, tuple[float]): Range of degrees to select from. If
            degree is a number instead of tuple like (min, max),
            the range of degree will be (``-degree``, ``+degree``)
    """

    def __init__(self, rotate_prob=0.5, flip_prob=0.5, degree=(-20, 20)):
        self.rotate_prob = rotate_prob
        self.flip_prob = flip_prob
        assert 0 <= rotate_prob <= 1 and 0 <= flip_prob <= 1
        if isinstance(degree, (float, int)):
            assert degree > 0, f'degree {degree} should be positive'
            self.degree = (-degree, degree)
        else:
            self.degree = degree
        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
                                      f'tuple of (min, max)'

    def random_rot_flip(self, results: dict) -> dict:
        """Rotate by a random multiple of 90 degrees, then flip along a random
        axis (0 or 1); the same k/axis is used for the image and all segs."""
        k = np.random.randint(0, 4)
        results['img'] = np.rot90(results['img'], k)
        for key in results.get('seg_fields', []):
            results[key] = np.rot90(results[key], k)
        axis = np.random.randint(0, 2)
        results['img'] = np.flip(results['img'], axis=axis).copy()
        for key in results.get('seg_fields', []):
            results[key] = np.flip(results[key], axis=axis).copy()
        return results

    def random_rotate(self, results: dict) -> dict:
        """Rotate the image and all seg maps by one random angle.

        The image keeps ``mmcv.imrotate``'s default interpolation, while
        segmentation maps use nearest-neighbour so no blended (invalid) label
        values are produced at class boundaries.
        """
        angle = np.random.uniform(min(*self.degree), max(*self.degree))
        results['img'] = mmcv.imrotate(results['img'], angle=angle)
        for key in results.get('seg_fields', []):
            results[key] = mmcv.imrotate(
                results[key], angle=angle, interpolation='nearest')
        return results

    def transform(self, results: dict) -> dict:
        """Call function to rotate or rotate & flip image, semantic
        segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Rotated or rotated & flipped results.
        """
        rotate_flag = 0
        if random.random() < self.rotate_prob:
            results = self.random_rotate(results)
            rotate_flag = 1
        # The flip draw is always consumed, even when a rotation happened.
        if random.random() < self.flip_prob and rotate_flag == 0:
            results = self.random_rot_flip(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(rotate_prob={self.rotate_prob}, ' \
                    f'flip_prob={self.flip_prob}, ' \
                    f'degree={self.degree})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
|
|
class RandomMosaic(BaseTransform):
|
|
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
|
|
one output image. The output image is composed of the parts from each sub-
|
|
image.
|
|
|
|
.. code:: text
|
|
|
|
mosaic transform
|
|
center_x
|
|
+------------------------------+
|
|
| pad | pad |
|
|
| +-----------+ |
|
|
| | | |
|
|
| | image1 |--------+ |
|
|
| | | | |
|
|
| | | image2 | |
|
|
center_y |----+-------------+-----------|
|
|
| | cropped | |
|
|
|pad | image3 | image4 |
|
|
| | | |
|
|
+----|-------------+-----------+
|
|
| |
|
|
+-------------+
|
|
|
|
The mosaic transform steps are as follows:
|
|
1. Choose the mosaic center as the intersections of 4 images
|
|
2. Get the left top image according to the index, and randomly
|
|
sample another 3 images from the custom dataset.
|
|
3. Sub image will be cropped if image is larger than mosaic patch
|
|
|
|
Required Keys:
|
|
|
|
- img
|
|
- gt_seg_map
|
|
- mix_results
|
|
|
|
Modified Keys:
|
|
|
|
- img
|
|
- img_shape
|
|
- ori_shape
|
|
- gt_seg_map
|
|
|
|
Args:
|
|
prob (float): mosaic probability.
|
|
img_scale (Sequence[int]): Image size after mosaic pipeline of
|
|
a single image. The size of the output image is four times
|
|
that of a single image. The output image comprises 4 single images.
|
|
Default: (640, 640).
|
|
center_ratio_range (Sequence[float]): Center ratio range of mosaic
|
|
output. Default: (0.5, 1.5).
|
|
pad_val (int): Pad value. Default: 0.
|
|
seg_pad_val (int): Pad value of segmentation map. Default: 255.
|
|
"""
|
|
|
|
def __init__(self,
|
|
prob,
|
|
img_scale=(640, 640),
|
|
center_ratio_range=(0.5, 1.5),
|
|
pad_val=0,
|
|
seg_pad_val=255):
|
|
assert 0 <= prob and prob <= 1
|
|
assert isinstance(img_scale, tuple)
|
|
self.prob = prob
|
|
self.img_scale = img_scale
|
|
self.center_ratio_range = center_ratio_range
|
|
self.pad_val = pad_val
|
|
self.seg_pad_val = seg_pad_val
|
|
|
|
@cache_randomness
|
|
def do_mosaic(self):
|
|
return np.random.rand() < self.prob
|
|
|
|
def transform(self, results: dict) -> dict:
|
|
"""Call function to make a mosaic of image.
|
|
|
|
Args:
|
|
results (dict): Result dict.
|
|
|
|
Returns:
|
|
dict: Result dict with mosaic transformed.
|
|
"""
|
|
mosaic = self.do_mosaic()
|
|
if mosaic:
|
|
results = self._mosaic_transform_img(results)
|
|
results = self._mosaic_transform_seg(results)
|
|
return results
|
|
|
|
def get_indices(self, dataset: MultiImageMixDataset) -> list:
|
|
"""Call function to collect indices.
|
|
|
|
Args:
|
|
dataset (:obj:`MultiImageMixDataset`): The dataset.
|
|
|
|
Returns:
|
|
list: indices.
|
|
"""
|
|
|
|
indices = [random.randint(0, len(dataset)) for _ in range(3)]
|
|
return indices
|
|
|
|
@cache_randomness
|
|
def generate_mosaic_center(self):
|
|
# mosaic center x, y
|
|
center_x = int(
|
|
random.uniform(*self.center_ratio_range) * self.img_scale[1])
|
|
center_y = int(
|
|
random.uniform(*self.center_ratio_range) * self.img_scale[0])
|
|
return center_x, center_y
|
|
|
|
def _mosaic_transform_img(self, results: dict) -> dict:
|
|
"""Mosaic transform function.
|
|
|
|
Args:
|
|
results (dict): Result dict.
|
|
|
|
Returns:
|
|
dict: Updated result dict.
|
|
"""
|
|
|
|
assert 'mix_results' in results
|
|
if len(results['img'].shape) == 3:
|
|
c = results['img'].shape[2]
|
|
mosaic_img = np.full(
|
|
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), c),
|
|
self.pad_val,
|
|
dtype=results['img'].dtype)
|
|
else:
|
|
mosaic_img = np.full(
|
|
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
|
|
self.pad_val,
|
|
dtype=results['img'].dtype)
|
|
|
|
# mosaic center x, y
|
|
self.center_x, self.center_y = self.generate_mosaic_center()
|
|
center_position = (self.center_x, self.center_y)
|
|
|
|
loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
|
|
for i, loc in enumerate(loc_strs):
|
|
if loc == 'top_left':
|
|
result_patch = copy.deepcopy(results)
|
|
else:
|
|
result_patch = copy.deepcopy(results['mix_results'][i - 1])
|
|
|
|
img_i = result_patch['img']
|
|
h_i, w_i = img_i.shape[:2]
|
|
# keep_ratio resize
|
|
scale_ratio_i = min(self.img_scale[0] / h_i,
|
|
self.img_scale[1] / w_i)
|
|
img_i = mmcv.imresize(
|
|
img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))
|
|
|
|
# compute the combine parameters
|
|
paste_coord, crop_coord = self._mosaic_combine(
|
|
loc, center_position, img_i.shape[:2][::-1])
|
|
x1_p, y1_p, x2_p, y2_p = paste_coord
|
|
x1_c, y1_c, x2_c, y2_c = crop_coord
|
|
|
|
# crop and paste image
|
|
mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]
|
|
|
|
results['img'] = mosaic_img
|
|
results['img_shape'] = mosaic_img.shape
|
|
results['ori_shape'] = mosaic_img.shape
|
|
|
|
return results
|
|
|
|
def _mosaic_transform_seg(self, results: dict) -> dict:
    """Build the mosaic for every segmentation map listed in ``seg_fields``.

    Reuses the mosaic center stored on ``self.center_x`` / ``self.center_y``
    (set by ``_mosaic_transform_img``), so the label mosaic lines up with
    the image mosaic. Label maps are resized with nearest-neighbour
    interpolation and pasted onto a 2D canvas pre-filled with
    ``self.seg_pad_val``.

    Args:
        results (dict): Result dict. Must contain ``'mix_results'``; every
            key in ``results.get('seg_fields', [])`` must be present in
            ``results`` and in each of the first three mixed results.

    Returns:
        dict: Updated result dict with each seg field replaced by its
        mosaic.
    """
    assert 'mix_results' in results
    loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
    for key in results.get('seg_fields', []):
        mosaic_seg = np.full(
            (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
            self.seg_pad_val,
            dtype=results[key].dtype)

        # mosaic center x, y, as drawn by ``_mosaic_transform_img``
        center_position = (self.center_x, self.center_y)

        for i, loc in enumerate(loc_strs):
            # The source maps are only read here (imresize returns a new
            # array), so deep-copying the whole result dicts is unnecessary.
            if loc == 'top_left':
                gt_seg_i = results[key]
            else:
                gt_seg_i = results['mix_results'][i - 1][key]

            h_i, w_i = gt_seg_i.shape[:2]
            # keep_ratio resize; nearest interpolation keeps label values
            scale_ratio_i = min(self.img_scale[0] / h_i,
                                self.img_scale[1] / w_i)
            gt_seg_i = mmcv.imresize(
                gt_seg_i,
                (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)),
                interpolation='nearest')

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, gt_seg_i.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste the seg map
            mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c,
                                                        x1_c:x2_c]

        results[key] = mosaic_seg

    return results
|
|
|
|
def _mosaic_combine(self, loc: str, center_position_xy: Sequence[float],
                    img_shape_wh: Sequence[int]) -> tuple:
    """Work out where a sub-image lands on the mosaic canvas.

    The canvas spans ``self.img_scale[1] * 2`` columns and
    ``self.img_scale[0] * 2`` rows. The sub-image touches the mosaic center
    with the corner that faces it and is clipped to the canvas. The crop box
    selects the part of the sub-image that remains visible.

    Args:
        loc (str): Quadrant of the sub-image, one of ``'top_left'``,
            ``'top_right'``, ``'bottom_left'`` or ``'bottom_right'``.
        center_position_xy (Sequence[float]): Mosaic center ``(x, y)``.
        img_shape_wh (Sequence[int]): Sub-image size ``(width, height)``.

    Returns:
        tuple[tuple, tuple]: ``(paste_coord, crop_coord)``, each an
        ``(x1, y1, x2, y2)`` box. ``paste_coord`` is in mosaic-canvas
        coordinates; ``crop_coord`` is in sub-image coordinates.
    """
    assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
    cx, cy = center_position_xy[0], center_position_xy[1]
    w, h = img_shape_wh[0], img_shape_wh[1]

    if loc == 'top_left':
        # bottom-right corner on the center; clip against top/left borders
        x1, y1 = max(cx - w, 0), max(cy - h, 0)
        x2, y2 = cx, cy
        crop_coord = w - (x2 - x1), h - (y2 - y1), w, h
    elif loc == 'top_right':
        # bottom-left corner on the center; clip against top/right borders
        x1, y1 = cx, max(cy - h, 0)
        x2, y2 = min(cx + w, self.img_scale[1] * 2), cy
        crop_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
    elif loc == 'bottom_left':
        # top-right corner on the center; clip against bottom/left borders
        x1, y1 = max(cx - w, 0), cy
        x2, y2 = cx, min(self.img_scale[0] * 2, cy + h)
        crop_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
    else:
        # top-left corner on the center; clip against bottom/right borders
        x1, y1 = cx, cy
        x2 = min(cx + w, self.img_scale[1] * 2)
        y2 = min(self.img_scale[0] * 2, cy + h)
        crop_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)

    return (x1, y1, x2, y2), crop_coord
|
|
|
|
def __repr__(self):
    """Return a string describing the transform's configuration."""
    repr_str = self.__class__.__name__
    repr_str += f'(prob={self.prob}, '
    repr_str += f'img_scale={self.img_scale}, '
    repr_str += f'center_ratio_range={self.center_ratio_range}, '
    repr_str += f'pad_val={self.pad_val}, '
    # report the seg padding value, not the image one
    repr_str += f'seg_pad_val={self.seg_pad_val})'
    return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class GenerateEdge(BaseTransform):
    """Generate Edge for CE2P approach.

    Edge will be used to calculate loss of
    `CE2P <https://arxiv.org/abs/1809.05996>`_.

    Modified from https://github.com/liutinglt/CE2P/blob/master/dataset/target_generation.py # noqa:E501

    Required Keys:

        - img_shape
        - gt_seg_map

    Added Keys:
        - gt_edge_map (np.ndarray, uint8): The edge annotation generated from the
            seg map by extracting border between different semantics.
        - edge_width (int): The ``edge_width`` used to dilate the edge map.

    Args:
        edge_width (int): The width of edge. Default to 3.
        ignore_index (int): Index that will be ignored. Default to 255.
    """

    def __init__(self, edge_width: int = 3, ignore_index: int = 255) -> None:
        super().__init__()
        self.edge_width = edge_width
        self.ignore_index = ignore_index

    def transform(self, results: Dict) -> Dict:
        """Call function to generate edge from segmentation map.

        A pixel is marked as edge when it differs from one of its
        down / left / up-left / up-right neighbours and neither pixel is
        ``ignore_index``. The binary map is then dilated with a
        ``edge_width x edge_width`` rectangular kernel.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with edge mask.
        """
        # Only the spatial size is needed; some transforms (e.g. RandomMosaic)
        # store the full array shape, including channels, in ``img_shape``.
        h, w = results['img_shape'][:2]
        edge = np.zeros((h, w), dtype=np.uint8)
        seg_map = results['gt_seg_map']

        # The ``edge_*`` slices below are views, so writing to them fills
        # ``edge`` in place.
        # down
        edge_down = edge[1:h, :]
        edge_down[(seg_map[1:h, :] != seg_map[:h - 1, :])
                  & (seg_map[1:h, :] != self.ignore_index) &
                  (seg_map[:h - 1, :] != self.ignore_index)] = 1
        # left
        edge_left = edge[:, :w - 1]
        edge_left[(seg_map[:, :w - 1] != seg_map[:, 1:w])
                  & (seg_map[:, :w - 1] != self.ignore_index) &
                  (seg_map[:, 1:w] != self.ignore_index)] = 1
        # up_left
        edge_upleft = edge[:h - 1, :w - 1]
        edge_upleft[(seg_map[:h - 1, :w - 1] != seg_map[1:h, 1:w])
                    & (seg_map[:h - 1, :w - 1] != self.ignore_index) &
                    (seg_map[1:h, 1:w] != self.ignore_index)] = 1
        # up_right
        edge_upright = edge[:h - 1, 1:w]
        edge_upright[(seg_map[:h - 1, 1:w] != seg_map[1:h, :w - 1])
                     & (seg_map[:h - 1, 1:w] != self.ignore_index) &
                     (seg_map[1:h, :w - 1] != self.ignore_index)] = 1

        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,
                                           (self.edge_width, self.edge_width))
        edge = cv2.dilate(edge, kernel)

        results['gt_edge_map'] = edge
        results['edge_width'] = self.edge_width

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(edge_width={self.edge_width}, '
        repr_str += f'ignore_index={self.ignore_index})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class ResizeShortestEdge(BaseTransform):
    """Resize the image and mask while keeping the aspect ratio unchanged.

    Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
    Copyright (c) Facebook, Inc. and its affiliates.
    Licensed under the Apache-2.0 License

    This transform attempts to scale the shorter edge to the given
    `scale`, as long as the longer edge does not exceed `max_size`.
    If `max_size` is reached, then downscale so that the longer
    edge does not exceed `max_size`.

    Required Keys:

    - img
    - gt_seg_map (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map (optional)

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio


    Args:
        scale (Union[int, Tuple[int, int]]): The target short edge length.
            If it's a tuple (or list), the min value is used as the short
            edge length.
        max_size (int): The maximum allowed longest edge length.
    """

    def __init__(self, scale: Union[int, Tuple[int, int]],
                 max_size: int) -> None:
        super().__init__()
        self.scale = scale
        self.max_size = max_size

        # Create an empty Resize object; its ``scale`` is overwritten for
        # every sample in ``transform``.
        self.resize = TRANSFORMS.build({
            'type': 'Resize',
            'scale': 0,
            'keep_ratio': True
        })

    def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
        """Compute the target image shape with the given `short_edge_length`.

        Args:
            img (np.ndarray): The input image.
            short_edge_length (Union[int, Tuple[int, int]]): The target short
                edge length. If it's a tuple or list, the min value is used
                as the short edge length.

        Returns:
            Tuple[int, int]: The target size as ``(new_w, new_h)``, each
            rounded to the nearest integer.

        Raises:
            TypeError: If ``short_edge_length`` is neither an integer nor a
                tuple/list of integers.
        """
        h, w = img.shape[:2]
        if isinstance(short_edge_length, (int, np.integer)):
            size = short_edge_length * 1.0
        elif isinstance(short_edge_length, (tuple, list)):
            size = min(short_edge_length) * 1.0
        else:
            # Fail loudly instead of hitting an unbound ``size`` below.
            raise TypeError(
                '`short_edge_length` should be an int or a tuple/list of '
                f'ints, but got {type(short_edge_length)}')
        scale = size / min(h, w)
        if h < w:
            new_h, new_w = size, scale * w
        else:
            new_h, new_w = scale * h, size

        # Shrink further if the long edge would exceed ``max_size``.
        if max(new_h, new_w) > self.max_size:
            scale = self.max_size * 1.0 / max(new_h, new_w)
            new_h *= scale
            new_w *= scale

        # round half up to the nearest integer
        new_h = int(new_h + 0.5)
        new_w = int(new_w + 0.5)
        return (new_w, new_h)

    def transform(self, results: Dict) -> Dict:
        """Resize ``results`` so the short edge matches ``self.scale``.

        Args:
            results (dict): Result dict containing at least ``'img'``.

        Returns:
            dict: The result dict returned by the wrapped ``Resize``
            transform.
        """
        self.resize.scale = self._get_output_shape(results['img'], self.scale)
        return self.resize(results)
|
|
|
|
|
|
@TRANSFORMS.register_module()
class BioMedical3DRandomCrop(BaseTransform):
    """Crop the input patch for medical image & segmentation mask.

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
        N is the number of modalities, and data type is float32.
    - gt_seg_map (np.ndarray, optional): Biomedical semantic segmentation mask
        with shape (Z, Y, X).

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map (optional)

    Args:
        crop_shape (Union[int, Tuple[int, int, int]]): Expected size after
            cropping with the format of (z, y, x). If set to an integer,
            all three crop sizes are equal to this integer.
        keep_foreground (bool): If True and ``gt_seg_map`` is available, a
            voxel of a foreground class is sampled randomly and used to
            bound the random crop offset (see ``generate_margin``).
            Default to True.
    """

    def __init__(self,
                 crop_shape: Union[int, Tuple[int, int, int]],
                 keep_foreground: bool = True):
        super().__init__()
        assert isinstance(crop_shape, int) or (
            isinstance(crop_shape, tuple) and len(crop_shape) == 3
        ), 'The expected crop_shape is an integer, or a tuple containing '
        'three integers'

        if isinstance(crop_shape, int):
            crop_shape = (crop_shape, crop_shape, crop_shape)
        assert crop_shape[0] > 0 and crop_shape[1] > 0 and crop_shape[2] > 0
        self.crop_shape = crop_shape
        self.keep_foreground = keep_foreground

    def random_sample_location(self,
                               seg_map: np.ndarray) -> Optional[np.ndarray]:
        """Sample a foreground voxel when keep_foreground is True.

        For every non-zero class, up to 10000 voxel coordinates are sampled
        without replacement. At least 1% of that class's voxels are
        sampled. A class is then picked uniformly, and one of its sampled
        voxels is returned.

        Args:
            seg_map (np.ndarray): gt seg map.

        Returns:
            np.ndarray | None: Coordinates (one entry per ``seg_map`` axis)
            of the selected foreground voxel, or None if ``seg_map`` only
            contains background (class 0).
        """
        num_samples = 10000
        # at least 1% of the class voxels need to be selected,
        # otherwise it may be too sparse
        min_percent_coverage = 0.01
        class_locs = {}
        foreground_classes = []
        all_classes = np.unique(seg_map)
        for c in all_classes:
            if c == 0:
                # to avoid the segmentation mask full of background 0
                # and the class_locs is just void dictionary {} when it return
                # there add a void list for background 0.
                class_locs[c] = []
            else:
                all_locs = np.argwhere(seg_map == c)
                target_num_samples = min(num_samples, len(all_locs))
                target_num_samples = max(
                    target_num_samples,
                    int(np.ceil(len(all_locs) * min_percent_coverage)))

                selected = all_locs[np.random.choice(
                    len(all_locs), target_num_samples, replace=False)]
                class_locs[c] = selected
                foreground_classes.append(c)

        selected_voxel = None
        if len(foreground_classes) > 0:
            selected_class = np.random.choice(foreground_classes)
            voxels_of_that_class = class_locs[selected_class]
            selected_voxel = voxels_of_that_class[np.random.choice(
                len(voxels_of_that_class))]

        return selected_voxel

    def random_generate_crop_bbox(self, margin_z: int, margin_y: int,
                                  margin_x: int) -> tuple:
        """Randomly get a crop bounding box.

        Each offset is drawn uniformly from ``[0, margin]`` (inclusive) for
        its axis; the box then spans ``self.crop_shape`` from that offset.

        Args:
            margin_z (int): Largest allowed offset along Z.
            margin_y (int): Largest allowed offset along Y.
            margin_x (int): Largest allowed offset along X.

        Returns:
            tuple: Coordinates of the cropped image as
            ``(z1, z2, y1, y2, x1, x2)``.
        """
        offset_z = np.random.randint(0, margin_z + 1)
        offset_y = np.random.randint(0, margin_y + 1)
        offset_x = np.random.randint(0, margin_x + 1)
        crop_z1, crop_z2 = offset_z, offset_z + self.crop_shape[0]
        crop_y1, crop_y2 = offset_y, offset_y + self.crop_shape[1]
        crop_x1, crop_x2 = offset_x, offset_x + self.crop_shape[2]

        return crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2

    def generate_margin(self, results: dict) -> tuple:
        """Generate the upper bounds for the random crop offsets.

        By default, the margin of each (Z, Y, X) axis is the largest valid
        offset, ``max(size - crop, 0)``.

        If ``keep_foreground`` is True and a foreground voxel can be
        sampled from ``gt_seg_map``, the margin is instead
        ``voxel - crop // 2``, clipped to ``[0, size - crop]``. The random
        offset drawn in ``random_generate_crop_bbox`` therefore never
        exceeds the offset that would center the crop on that voxel.

        Without ``gt_seg_map``, the image's spatial shape (``img.shape[1:]``)
        is used.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            tuple: The margin for 3 dimensions of crop bounding-box and image.
        """
        seg_map = results.get('gt_seg_map', None)
        # (Z, Y, X) extent of the volume; ``img`` is (N, Z, Y, X).
        spatial_shape = (seg_map.shape if seg_map is not None else
                         results['img'].shape[1:])

        selected_voxel = None
        if self.keep_foreground and seg_map is not None:
            selected_voxel = self.random_sample_location(seg_map)
            if selected_voxel is None:
                # this only happens if some image does not contain
                # foreground voxels at all
                warnings.warn(f'case does not contain any foreground classes'
                              f': {results.get("img_path")}')

        margins = []
        for axis in range(3):
            margin = max(spatial_shape[axis] - self.crop_shape[axis], 0)
            if selected_voxel is not None:
                # offset that centers the crop on the voxel, kept in range
                margin = max(
                    0,
                    min(margin,
                        selected_voxel[axis] - self.crop_shape[axis] // 2))
            margins.append(margin)
        return tuple(margins)

    def crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
        """Crop from ``img``

        Args:
            img (np.ndarray): Original input image; a 3D array is treated
                as a seg map (Z, Y, X), a 4D array as an image (N, Z, Y, X).
            crop_bbox (tuple): Coordinates of the cropped image.

        Returns:
            np.ndarray: The cropped image.
        """
        crop_z1, crop_z2, crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        if len(img.shape) == 3:
            # crop seg map
            img = img[crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
        else:
            # crop image
            assert len(img.shape) == 4
            img = img[:, crop_z1:crop_z2, crop_y1:crop_y2, crop_x1:crop_x2]
        return img

    def transform(self, results: dict) -> dict:
        """Transform function to randomly crop images, semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Randomly cropped results, 'img_shape' key in result dict is
                updated according to crop size.
        """
        margin = self.generate_margin(results)
        crop_bbox = self.random_generate_crop_bbox(*margin)

        # crop the image
        img = results['img']
        results['img'] = self.crop(img, crop_bbox)
        results['img_shape'] = results['img'].shape[1:]

        # crop semantic seg, if present
        seg_map = results.get('gt_seg_map', None)
        if seg_map is not None:
            results['gt_seg_map'] = self.crop(seg_map, crop_bbox)

        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_shape={self.crop_shape})'
|
|
|
|
|
|
@TRANSFORMS.register_module()
class BioMedicalGaussianNoise(BaseTransform):
    """Randomly perturb a biomedical image with additive Gaussian noise.

    Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L53 # noqa:E501

    Copyright (c) German Cancer Research Center (DKFZ)
    Licensed under the Apache License, Version 2.0

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
            N is the number of modalities, and data type is float32.

    Modified Keys:

    - img

    Args:
        prob (float): Chance, per sample, that noise is added.
            Default to 0.1.
        mean (float): Mean ("centre") of the noise distribution.
            Default to 0.0.
        std (float): Upper bound of the noise standard deviation; the std
            actually used is drawn uniformly from ``[0, std)`` for each
            sample. Default to 0.1.
    """

    def __init__(self,
                 prob: float = 0.1,
                 mean: float = 0.0,
                 std: float = 0.1) -> None:
        super().__init__()
        assert 0.0 <= prob <= 1.0 and std >= 0.0
        self.prob, self.mean, self.std = prob, mean, std

    def transform(self, results: Dict) -> Dict:
        """Add noise to ``results['img']`` with probability ``self.prob``.

        Args:
            results (dict): Result dict.

        Returns:
            dict: The same dict, with ``'img'`` replaced by the noisy image
            when noise was applied.
        """
        if not np.random.rand() < self.prob:
            return results

        image = results['img']
        sampled_std = np.random.uniform(0, self.std)
        # np.random.normal yields float64; match the image dtype first.
        noise = np.random.normal(
            self.mean, sampled_std, size=image.shape).astype(image.dtype)
        results['img'] = image + noise
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(prob={self.prob}, '
                f'mean={self.mean}, std={self.std})')
|
|
|
|
|
|
@TRANSFORMS.register_module()
class BioMedicalGaussianBlur(BaseTransform):
    """Add Gaussian blur with random sigma to image.

    Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/noise_transforms.py#L81 # noqa:E501

    Copyright (c) German Cancer Research Center (DKFZ)
    Licensed under the Apache License, Version 2.0

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
            N is the number of modalities, and data type is float32.

    Modified Keys:

    - img

    Args:
        sigma_range (Tuple[float, float]|float): range to randomly
            select sigma value. A single number is used as a fixed sigma.
            Default to (0.5, 1.0).
        prob (float): Probability to apply Gaussian blur
            for each sample. Default to 0.2.
        prob_per_channel (float): Probability to apply Gaussian blur
            for each channel (axis N of the image). Default to 0.5.
        different_sigma_per_channel (bool): whether to use different
            sigma for each channel (axis N of the image). Default to True.
        different_sigma_per_axis (bool): whether to use different
            sigma for axis Z, X and Y of the image. Default to True.
    """

    def __init__(self,
                 sigma_range: Union[Tuple[float, float], float] = (0.5, 1.0),
                 prob: float = 0.2,
                 prob_per_channel: float = 0.5,
                 different_sigma_per_channel: bool = True,
                 different_sigma_per_axis: bool = True) -> None:
        super().__init__()
        assert 0.0 <= prob <= 1.0
        assert 0.0 <= prob_per_channel <= 1.0
        # Either a fixed sigma or a (low, high) sampling range.
        assert isinstance(sigma_range, (int, float)) or (
            isinstance(sigma_range, Sequence) and len(sigma_range) == 2)
        self.sigma_range = sigma_range
        self.prob = prob
        self.prob_per_channel = prob_per_channel
        self.different_sigma_per_channel = different_sigma_per_channel
        self.different_sigma_per_axis = different_sigma_per_axis

    def _get_valid_sigma(self, value_range) -> Union[int, float]:
        """Ensure the `value_range` to be either a single value or a sequence
        of two values. If the `value_range` is a sequence, generate a random
        value with `[value_range[0], value_range[1]]` based on uniform
        sampling.

        Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/augmentations/utils.py#L625 # noqa:E501

        Args:
            value_range (tuple|list|float|int): the input value range

        Returns:
            int | float: A single sigma value. A scalar input is returned
            unchanged; a range with equal bounds returns that bound.
        """
        if not isinstance(value_range, Sequence):
            # a single number: use it as a fixed sigma
            return value_range
        if value_range[0] == value_range[1]:
            return value_range[0]
        # Cast back to the type of the lower bound; note that an int lower
        # bound therefore truncates the sampled value.
        orig_type = type(value_range[0])
        value = np.random.uniform(value_range[0], value_range[1])
        return orig_type(value)

    def _gaussian_blur(self, data_sample: np.ndarray) -> np.ndarray:
        """Random generate sigma and apply Gaussian Blur to the data

        Channels are blurred in place, each with probability
        ``self.prob_per_channel``.

        Args:
            data_sample (np.ndarray): data sample with multiple modalities,
                the data shape is (N, Z, Y, X)

        Returns:
            np.ndarray: ``data_sample`` (modified in place).
        """
        sigma = None
        for c in range(data_sample.shape[0]):
            if np.random.rand() < self.prob_per_channel:
                # if no `sigma` is generated, generate one
                # if `self.different_sigma_per_channel` is True,
                # re-generate random sigma for each channel
                if (sigma is None or self.different_sigma_per_channel):
                    if (not self.different_sigma_per_axis):
                        sigma = self._get_valid_sigma(self.sigma_range)
                    else:
                        # one sigma per spatial axis of the channel
                        sigma = [
                            self._get_valid_sigma(self.sigma_range)
                            for _ in data_sample.shape[1:]
                        ]
                # apply gaussian filter with `sigma`
                data_sample[c] = gaussian_filter(
                    data_sample[c], sigma, order=0)
        return data_sample

    def transform(self, results: Dict) -> Dict:
        """Call function to add random Gaussian blur to image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with random Gaussian blur.
        """
        if np.random.rand() < self.prob:
            results['img'] = self._gaussian_blur(results['img'])
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'prob_per_channel={self.prob_per_channel}, '
        repr_str += f'sigma_range={self.sigma_range}, '
        repr_str += 'different_sigma_per_channel='\
                    f'{self.different_sigma_per_channel}, '
        repr_str += 'different_sigma_per_axis='\
                    f'{self.different_sigma_per_axis})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class BioMedicalRandomGamma(BaseTransform):
    """Using random gamma correction to process the biomedical image.

    Modified from
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/color_transforms.py#L132 # noqa:E501
    With licence: Apache 2.0

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X),
        N is the number of modalities, and data type is float32.

    Modified Keys:
    - img

    Args:
        prob (float): The probability to perform this transform. Default: 0.5.
        gamma_range (Tuple[float, float]): Range of gamma values.
            Default: (0.5, 2).
        invert_image (bool): Whether invert the image before applying gamma
            augmentation. Default: False.
        per_channel (bool): Whether perform the transform each channel
            individually. Default: False
        retain_stats (bool): Gamma transformation will alter the mean and std
            of the data in the patch. If retain_stats=True, the data will be
            transformed to match the mean and standard deviation before gamma
            augmentation. Default: False.
    """

    def __init__(self,
                 prob: float = 0.5,
                 gamma_range: Tuple[float, float] = (0.5, 2),
                 invert_image: bool = False,
                 per_channel: bool = False,
                 retain_stats: bool = False):
        super().__init__()
        assert 0 <= prob and prob <= 1
        assert isinstance(gamma_range, tuple) and len(gamma_range) == 2
        assert isinstance(invert_image, bool)
        assert isinstance(per_channel, bool)
        assert isinstance(retain_stats, bool)
        self.prob = prob
        self.gamma_range = gamma_range
        self.invert_image = invert_image
        self.per_channel = per_channel
        self.retain_stats = retain_stats

    @cache_randomness
    def _do_gamma(self):
        """Whether do adjust gamma for image."""
        return np.random.rand() < self.prob

    def _adjust_gamma(self, img: np.ndarray):
        """Gamma adjustment for image.

        Args:
            img (np.ndarray): Input image before gamma adjust.

        Returns:
            np.ndarray: Image after gamma adjust.
        """

        if self.invert_image:
            img = -img

        # Bound before branching: ``_do_adjust`` reads it on both the
        # whole-image and the per-channel path.
        retain_stats_here = self.retain_stats

        def _do_adjust(img):
            if retain_stats_here:
                img_mean = img.mean()
                img_std = img.std()
            # If gamma_range[0] < 1, sample gamma from [low, 1) or
            # [1, high) with equal odds; otherwise always from the upper part.
            if np.random.random() < 0.5 and self.gamma_range[0] < 1:
                gamma = np.random.uniform(self.gamma_range[0], 1)
            else:
                gamma = np.random.uniform(
                    max(self.gamma_range[0], 1), self.gamma_range[1])
            # Rescale to roughly [0, 1], apply gamma, and map back.
            img_min = img.min()
            img_range = img.max() - img_min  # range
            img = np.power(((img - img_min) / float(img_range + 1e-7)),
                           gamma) * img_range + img_min
            if retain_stats_here:
                img = img - img.mean()
                img = img / (img.std() + 1e-8) * img_std
                img = img + img_mean
            return img

        if not self.per_channel:
            img = _do_adjust(img)
        else:
            for c in range(img.shape[0]):
                img[c] = _do_adjust(img[c])
        if self.invert_image:
            img = -img
        return img

    def transform(self, results: dict) -> dict:
        """Call function to perform random gamma correction
        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with random gamma correction performed.
        """
        if self._do_gamma():
            results['img'] = self._adjust_gamma(results['img'])
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(prob={self.prob}, '
        repr_str += f'gamma_range={self.gamma_range}, '
        repr_str += f'invert_image={self.invert_image}, '
        repr_str += f'per_channel={self.per_channel}, '
        repr_str += f'retain_stats={self.retain_stats})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class BioMedical3DPad(BaseTransform):
    """Pad the biomedical 3d image & biomedical 3d semantic segmentation maps.

    Each spatial axis smaller than ``pad_shape`` is padded up to it. The
    padding is split between both sides, with the extra voxel going to the
    end. Axes that are already at least as large are left unchanged.

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
        N is the number of modalities.
    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
        (Z, Y, X) by default.

    Modified Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
        N is the number of modalities.
    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
        (Z, Y, X) by default.

    Added Keys:

    - pad_shape (Tuple[int, int, int]): The padded shape.

    Args:
        pad_shape (Tuple[int, int, int]): Fixed padding size.
            Expected padding shape (Z, Y, X).
        pad_val (float): Padding value for biomedical image.
            The padding mode is set to "constant". The value
            to be filled in padding area. Default: 0.
        seg_pad_val (int): Padding value for biomedical 3d semantic
            segmentation maps. The padding mode is set to "constant".
            The value to be filled in padding area. Default: 0.
    """

    def __init__(self,
                 pad_shape: Tuple[int, int, int],
                 pad_val: float = 0.,
                 seg_pad_val: int = 0) -> None:
        super().__init__()
        # check pad_shape: exactly one target size per (Z, Y, X) axis
        assert pad_shape is not None
        assert len(pad_shape) == 3

        self.pad_shape = pad_shape
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def _pad_img(self, results: dict) -> None:
        """Pad images according to ``self.pad_shape``

        Args:
            results (dict): Result dict from loading pipeline. Updated in
                place with the padded ``'img'`` and its spatial
                ``'pad_shape'``.
        """
        padded_img = self._to_pad(
            results['img'], pad_shape=self.pad_shape, pad_val=self.pad_val)

        results['img'] = padded_img
        results['pad_shape'] = padded_img.shape[1:]

    def _pad_seg(self, results: dict) -> None:
        """Pad semantic segmentation map according to ``results['pad_shape']``
        if ``gt_seg_map`` is not None in results dict.

        Must run after ``_pad_img``, which sets ``results['pad_shape']``, so
        the seg map ends up with the same spatial shape as the image.

        Args:
            results (dict): Result dict from loading pipeline. Updated in
                place with the padded ``'gt_seg_map'``.
        """
        if results.get('gt_seg_map', None) is not None:
            # add a leading axis so ``_to_pad`` sees a (1, Z, Y, X) array
            pad_gt_seg = self._to_pad(
                results['gt_seg_map'][None, ...],
                pad_shape=results['pad_shape'],
                pad_val=self.seg_pad_val)
            # drop the helper axis again -> (Z, Y, X)
            results['gt_seg_map'] = pad_gt_seg[0]

    @staticmethod
    def _to_pad(img: np.ndarray,
                pad_shape: Tuple[int, int, int],
                pad_val: Union[int, float] = 0) -> np.ndarray:
        """Pad the given 3d image to a certain shape with specified padding
        value.

        Args:
            img (ndarray): Biomedical image with shape (N, Z, Y, X)
                to be padded. N is the number of modalities.
            pad_shape (Tuple[int,int,int]): Expected padding shape (Z, Y, X).
            pad_val (float, int): Values to be filled in padding areas
                and the padding_mode is set to 'constant'. Default: 0.

        Returns:
            ndarray: The padded image.
        """
        # compute pad width for the Z (axis 1), Y (axis 2), X (axis 3) axes
        d = max(pad_shape[0] - img.shape[1], 0)
        pad_d = (d // 2, d - d // 2)
        h = max(pad_shape[1] - img.shape[2], 0)
        pad_h = (h // 2, h - h // 2)
        w = max(pad_shape[2] - img.shape[3], 0)
        pad_w = (w // 2, w - w // 2)

        # the modality axis N is never padded
        pad_list = [(0, 0), pad_d, pad_h, pad_w]

        img = np.pad(img, pad_list, mode='constant', constant_values=pad_val)
        return img

    def transform(self, results: dict) -> dict:
        """Call function to pad images, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        self._pad_img(results)
        self._pad_seg(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(pad_shape={self.pad_shape}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'seg_pad_val={self.seg_pad_val})'
        return repr_str
|
|
|
|
|
|
@TRANSFORMS.register_module()
class BioMedical3DRandomFlip(BaseTransform):
    """Flip biomedical 3D images and segmentations.

    Modified from https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/transforms/spatial_transforms.py # noqa:E501

    Copyright 2021 Division of
    Medical Image Computing, German Cancer Research Center (DKFZ) and Applied
    Computer Vision Lab, Helmholtz Imaging Platform.
    Licensed under the Apache-2.0 License.

    Required Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
      N is the number of modalities.
    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
      (Z, Y, X) by default. The key may be absent or ``None``.

    Modified Keys:

    - img (np.ndarray): Biomedical image with shape (N, Z, Y, X) by default,
      N is the number of modalities.
    - gt_seg_map (np.ndarray, optional): Biomedical seg map with shape
      (Z, Y, X) by default.

    Added Keys:

    - do_flip
    - flip_axes

    Args:
        prob (float): Flipping probability, in [0, 1]. Each selected axis
            is sampled independently.
        axes (Tuple[int, ...], optional): Flipping axes with order 'ZXY'.
            Each index refers to a dimension after the leading channel
            dimension: 0 -> array axis 1, 1 -> array axis 2,
            2 -> array axis 3 (axis 2 is only used for 4-D images).
            ``None`` or an empty tuple disables flipping.
        swap_label_pairs (Optional[List[Tuple[int, int]]]):
            The segmentation label pairs that are swapped when flipping.
            Defaults to None.
    """

    def __init__(self,
                 prob: float,
                 axes: Optional[Tuple[int, ...]],
                 swap_label_pairs: Optional[List[Tuple[int, int]]] = None):
        assert 0 <= prob <= 1
        # ``max`` of an empty sequence raises ValueError, so only validate
        # non-empty axes; empty axes simply never flip.
        if axes:
            assert max(axes) <= 2
        self.prob = prob
        self.axes = axes
        self.swap_label_pairs = swap_label_pairs

    @staticmethod
    def _flip(img: np.ndarray,
              direction: Tuple[bool, bool, bool]) -> np.ndarray:
        """Flip ``img`` in place along the dimensions after the first one.

        Args:
            img (np.ndarray): Array whose first dimension is treated as the
                channel dimension. It is modified in place.
            direction (Tuple[bool, bool, bool]): Whether to reverse array
                axes 1, 2 and 3 respectively.

        Returns:
            np.ndarray: The same array object, flipped.
        """
        # Assigning a reversed view back into its own base is safe: NumPy
        # buffers the copy when source and destination memory overlap.
        if direction[0]:
            img[:, :] = img[:, ::-1]
        if direction[1]:
            img[:, :, :] = img[:, :, ::-1]
        if direction[2]:
            img[:, :, :, :] = img[:, :, :, ::-1]
        return img

    def _do_flip(self, img: np.ndarray) -> Tuple[bool, bool, bool]:
        """Randomly decide along which axes to flip.

        Args:
            img (np.ndarray): Image array; only its number of dimensions
                is used.

        Returns:
            tuple: Flip action for array axes 1, 2 and 3 (see ``_flip``).
            The third entry can only be True for 4-D inputs.
        """
        flip_c, flip_x, flip_y = False, False, False
        if self.axes is not None:
            # ``and`` short-circuits, so a random number is drawn only for
            # axes that are actually enabled.
            flip_c = 0 in self.axes and np.random.rand() < self.prob
            flip_x = 1 in self.axes and np.random.rand() < self.prob
            if img.ndim == 4:
                flip_y = 2 in self.axes and np.random.rand() < self.prob
        return flip_c, flip_x, flip_y

    def _swap_label(self, seg: np.ndarray) -> np.ndarray:
        """Exchange the label values of every pair in ``swap_label_pairs``.

        Masks are always computed on the unmodified input ``seg`` and writes
        go to a copy, so each pair is exchanged rather than overwritten.

        Args:
            seg (np.ndarray): Segmentation map.

        Returns:
            np.ndarray: A new array with the label pairs swapped.
        """
        out = seg.copy()
        for first, second in self.swap_label_pairs:
            first_area = seg == first
            second_area = seg == second
            out[first_area] = second
            out[second_area] = first
        return out

    def transform(self, results: Dict) -> Dict:
        """Call function to flip and swap pair labels.

        Args:
            results (dict): Result dict. A pre-existing ``do_flip`` or
                ``flip_axes`` entry is kept and not resampled.

        Returns:
            dict: Flipped results, 'do_flip', 'flip_axes' keys are added into
            result dict.
        """
        # Sample the flip decision once; a value already present in
        # ``results`` takes precedence.
        if 'do_flip' not in results:
            results['do_flip'] = self._do_flip(results['img'])
        if 'flip_axes' not in results:
            results['flip_axes'] = self.axes
        direction = results['do_flip']

        img = self._flip(results['img'], direction=direction)
        results['img'] = img

        # ``gt_seg_map`` is optional: a missing key is treated like ``None``.
        seg = results.get('gt_seg_map')
        if seg is not None:
            # The seg map normally lacks the channel dimension; prepend one
            # so the same array axes are flipped as for the image. Compare
            # ranks, not shapes: a (1, Z, Y, X) map must not gain a 5th dim
            # just because the image has several modalities.
            if seg.ndim < img.ndim:
                seg = seg[None]
            seg = self._flip(seg, direction=direction)
            # Drop only the leading singleton channel dimension; a blanket
            # ``squeeze()`` would also collapse spatial dims of size 1.
            if seg.shape[0] == 1:
                seg = seg[0]
            # Labels are swapped only when a flip actually happened.
            if self.swap_label_pairs is not None and any(direction):
                seg = self._swap_label(seg)
            results['gt_seg_map'] = seg
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += (f'(prob={self.prob}, axes={self.axes}, '
                     f'swap_label_pairs={self.swap_label_pairs})')
        return repr_str
|