mmpretrain/mmcls/datasets/pipelines/transforms.py

460 lines
17 KiB
Python

import math
import random
import mmcv
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class RandomCrop(object):
    """Crop the given Image at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. If a sequence of length 4 is provided, it is used to
            pad left, top, right, bottom borders respectively. If a sequence
            of length 2 is provided, it is used to pad left/right, top/bottom
            borders, respectively. Default: None, which means no padding.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception. Since cropping is done
            after padding, the padding seems to be done at a random offset.
            Default: False.
        pad_val (Number | Sequence[Number]): Pixel pad_val value for constant
            fill. If a tuple of length 3, it is used to pad_val R, G, B
            channels respectively. Default: 0.
        padding_mode (str): Type of padding. Should be: constant, edge,
            reflect or symmetric. Default: constant.

            - constant: Pads with a constant value, this value is specified
              with pad_val.
            - edge: pads with the last value at the edge of the image.
            - reflect: Pads with reflection of image without repeating the
              last value on the edge. For example, padding [1, 2, 3, 4]
              with 2 elements on both sides in reflect mode will result
              in [3, 2, 1, 2, 3, 4, 3, 2].
            - symmetric: Pads with reflection of image repeating the last
              value on the edge. For example, padding [1, 2, 3, 4] with
              2 elements on both sides in symmetric mode will result in
              [2, 1, 1, 2, 3, 4, 4, 3].
    """

    def __init__(self,
                 size,
                 padding=None,
                 pad_if_needed=False,
                 pad_val=0,
                 padding_mode='constant'):
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        # check padding mode
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.pad_val = pad_val
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (ndarray): Image to be cropped.
            output_size (tuple): Expected output size (h, w) of the crop.

        Returns:
            tuple: Params (xmin, ymin, target_height, target_width) to be
                passed to ``crop`` for random crop. Note that, despite the
                names, ``xmin`` is the offset along the height axis (row)
                and ``ymin`` is the offset along the width axis (column).

        Raises:
            ValueError: If the image is smaller than ``output_size`` in
                either dimension.
        """
        height, width = img.shape[:2]
        target_height, target_width = output_size
        if width == target_width and height == target_height:
            return 0, 0, height, width
        # ``random.randint`` would fail with an opaque "empty range" error
        # here; raise the same exception type with an actionable message.
        if height < target_height or width < target_width:
            raise ValueError(
                f'Crop size {(target_height, target_width)} is larger than '
                f'the image size {(height, width)}; pad the image first '
                '(e.g. pad_if_needed=True).')

        xmin = random.randint(0, height - target_height)
        ymin = random.randint(0, width - target_width)
        return xmin, ymin, target_height, target_width

    def __call__(self, results):
        """Pad (if configured) and randomly crop every image in ``results``.

        Args:
            results (dict): Pipeline dict. Images are read from the keys
                listed in ``results['img_fields']`` (``['img']`` if absent).

        Returns:
            dict: The same dict with each image replaced by its crop.
        """
        for key in results.get('img_fields', ['img']):
            img = results[key]
            if self.padding is not None:
                # Honour the configured padding mode for explicit padding
                # as well, not only for the ``pad_if_needed`` branches.
                img = mmcv.impad(
                    img,
                    padding=self.padding,
                    pad_val=self.pad_val,
                    padding_mode=self.padding_mode)

            # pad the height if needed (same amount on top and bottom)
            if self.pad_if_needed and img.shape[0] < self.size[0]:
                img = mmcv.impad(
                    img,
                    padding=(0, self.size[0] - img.shape[0], 0,
                             self.size[0] - img.shape[0]),
                    pad_val=self.pad_val,
                    padding_mode=self.padding_mode)

            # pad the width if needed (same amount on left and right)
            if self.pad_if_needed and img.shape[1] < self.size[1]:
                img = mmcv.impad(
                    img,
                    padding=(self.size[1] - img.shape[1], 0,
                             self.size[1] - img.shape[1], 0),
                    pad_val=self.pad_val,
                    padding_mode=self.padding_mode)

            xmin, ymin, height, width = self.get_params(img, self.size)
            # The crop box is passed as (x1, y1, x2, y2) with inclusive
            # corners; ``ymin`` is the column offset and ``xmin`` the row
            # offset (see ``get_params``).
            results[key] = mmcv.imcrop(
                img,
                np.array([ymin, xmin, ymin + width - 1, xmin + height - 1]))
        return results

    def __repr__(self):
        return (self.__class__.__name__ +
                f'(size={self.size}, padding={self.padding})')
@PIPELINES.register_module()
class RandomResizedCrop(object):
    """Crop the given image to random size and aspect ratio.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a
    random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
    is made. This crop is finally resized to given size.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        scale (tuple): Range of the random size of the cropped image compared
            to the original image. Default: (0.08, 1.0).
        ratio (tuple): Range of the random aspect ratio of the cropped image
            compared to the original image. Default: (3. / 4., 4. / 3.).
        interpolation (str): Interpolation method, accepted values are
            'nearest', 'bilinear', 'bicubic', 'area', 'lanczos'. Default:
            'bilinear'.
        backend (str): The image resize backend type, accepted values are
            `cv2` and `pillow`. Default: `cv2`.

    Raises:
        ValueError: If ``scale`` or ``ratio`` is not ordered as (min, max),
            or if ``backend`` is not one of the supported backends.
    """

    def __init__(self,
                 size,
                 scale=(0.08, 1.0),
                 ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear',
                 backend='cv2'):
        if isinstance(size, (tuple, list)):
            self.size = size
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            # Report both ranges: either of them may be the offending one.
            raise ValueError('range should be of kind (min, max). '
                             f'But received scale {scale} and ratio {ratio}.')
        if backend not in ['cv2', 'pillow']:
            raise ValueError(f'backend: {backend} is not supported for resize.'
                             'Supported backends are "cv2", "pillow"')

        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio
        self.backend = backend

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Up to 10 random (area, aspect ratio) candidates are drawn; the first
        one that fits inside the image is placed at a random offset. If none
        fits, a central crop whose aspect ratio is clamped to ``ratio`` is
        returned instead.

        Args:
            img (ndarray): Image to be cropped.
            scale (tuple): Range of the random size of the cropped image
                compared to the original image size.
            ratio (tuple): Range of the random aspect ratio of the cropped
                image compared to the original image area.

        Returns:
            tuple: Params (xmin, ymin, target_height, target_width) to be
                passed to ``crop`` for a random sized crop. Note that,
                despite the names, ``xmin`` is the offset along the height
                axis (row) and ``ymin`` is the offset along the width axis
                (column).
        """
        height = img.shape[0]
        width = img.shape[1]
        area = height * width
        # Sample the aspect ratio uniformly in log space so that a ratio r
        # and its inverse 1/r are equally likely. Loop-invariant.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            target_width = int(round(math.sqrt(target_area * aspect_ratio)))
            target_height = int(round(math.sqrt(target_area / aspect_ratio)))

            if 0 < target_width <= width and 0 < target_height <= height:
                xmin = random.randint(0, height - target_height)
                ymin = random.randint(0, width - target_width)
                return xmin, ymin, target_height, target_width

        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            # Image is too narrow: keep the full width, shrink the height.
            target_width = width
            target_height = int(round(target_width / min(ratio)))
        elif in_ratio > max(ratio):
            # Image is too wide: keep the full height, shrink the width.
            target_height = height
            target_width = int(round(target_height * max(ratio)))
        else:  # whole image
            target_width = width
            target_height = height
        xmin = (height - target_height) // 2
        ymin = (width - target_width) // 2
        return xmin, ymin, target_height, target_width

    def __call__(self, results):
        """Randomly crop and resize every image in ``results``.

        Args:
            results (dict): Pipeline dict. Images are read from the keys
                listed in ``results['img_fields']`` (``['img']`` if absent).

        Returns:
            dict: The same dict with each image replaced by its randomly
                cropped and resized version.
        """
        for key in results.get('img_fields', ['img']):
            img = results[key]
            xmin, ymin, target_height, target_width = self.get_params(
                img, self.scale, self.ratio)
            # Crop box is (x1, y1, x2, y2) with inclusive corners; ``ymin``
            # is the column offset and ``xmin`` the row offset.
            img = mmcv.imcrop(
                img,
                np.array([
                    ymin, xmin, ymin + target_width - 1,
                    xmin + target_height - 1
                ]))
            # ``self.size`` is (h, w); it is reversed to (w, h) for the
            # resize call.
            results[key] = mmcv.imresize(
                img,
                tuple(self.size[::-1]),
                interpolation=self.interpolation,
                backend=self.backend)
        return results

    def __repr__(self):
        format_string = self.__class__.__name__ + f'(size={self.size}'
        format_string += f', scale={tuple(round(s, 4) for s in self.scale)}'
        format_string += f', ratio={tuple(round(r, 4) for r in self.ratio)}'
        format_string += f', interpolation={self.interpolation})'
        return format_string
@PIPELINES.register_module()
class RandomGrayscale(object):
    """Randomly convert image to grayscale with a probability of gray_prob.

    The random decision is drawn once per call and applied to every image
    listed in ``results['img_fields']`` so that all fields stay consistent.

    Args:
        gray_prob (float): Probability that image should be converted to
            grayscale. Default: 0.1.

    Returns:
        ndarray: Grayscale version of the input image with probability
            gray_prob and unchanged with probability (1-gray_prob).

            - If input image is 1 channel: grayscale version is 1 channel.
            - If input image is 3 channel: grayscale version is 3 channel
              with r == g == b.
    """

    def __init__(self, gray_prob=0.1):
        self.gray_prob = gray_prob

    def __call__(self, results):
        """Convert the images in ``results`` to grayscale at random.

        Args:
            results (dict): Pipeline dict. Images are read from the keys
                listed in ``results['img_fields']`` (``['img']`` if absent).

        Returns:
            dict: The same dict; with probability ``gray_prob`` every
                multi-channel image is replaced by its grayscale version
                replicated across the original number of channels.
        """
        if random.random() >= self.gray_prob:
            return results

        for key in results.get('img_fields', ['img']):
            img = results[key]
            # 2-D arrays and single-channel images are already grayscale.
            if img.ndim < 3 or img.shape[2] <= 1:
                continue
            num_output_channels = img.shape[2]
            # NOTE(review): ``mmcv.rgb2gray`` presumably expects RGB channel
            # order - confirm the images are RGB at this pipeline stage.
            gray = mmcv.rgb2gray(img)[:, :, None]
            # Replicate the single gray channel to keep the channel count.
            results[key] = np.repeat(gray, num_output_channels, axis=2)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(gray_prob={self.gray_prob})'
@PIPELINES.register_module()
class RandomFlip(object):
    """Randomly flip images along a fixed direction.

    A single random draw per call decides whether every image listed in
    ``results['img_fields']`` is flipped.

    Args:
        flip_prob (float): Probability of flipping, must lie in [0, 1].
            Default: 0.5
        direction (str, optional): Flip direction, either 'horizontal' or
            'vertical'. Default: 'horizontal'.
    """

    def __init__(self, flip_prob=0.5, direction='horizontal'):
        assert 0 <= flip_prob <= 1
        assert direction in ('horizontal', 'vertical')
        self.flip_prob = flip_prob
        self.direction = direction

    def __call__(self, results):
        """Flip the images in ``results`` with probability ``flip_prob``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: The same dict with the keys 'flip' (bool, whether a flip
                was applied) and 'flip_direction' added; images are flipped
                in place of the originals when 'flip' is True.
        """
        do_flip = bool(np.random.rand() < self.flip_prob)
        results['flip'] = do_flip
        results['flip_direction'] = self.direction

        # Nothing else to do when the draw says "no flip".
        if not do_flip:
            return results

        for field in results.get('img_fields', ['img']):
            flipped = mmcv.imflip(
                results[field], direction=results['flip_direction'])
            results[field] = flipped
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}(flip_prob={self.flip_prob})'
@PIPELINES.register_module()
class Resize(object):
    """Resize every image to a fixed size.

    Args:
        size (int | tuple): Target size as (h, w); a single int is expanded
            to a square (size, size).
        interpolation (str): Interpolation method, one of "nearest",
            "bilinear", "bicubic", "area", "lanczos".
            More details can be found in `mmcv.image.geometric`.
        backend (str): The image resize backend type, accepted values are
            `cv2` and `pillow`. Default: `cv2`.

    Raises:
        ValueError: If ``backend`` is not one of the supported backends.
    """

    def __init__(self, size, interpolation='bilinear', backend='cv2'):
        is_pair = isinstance(size, tuple) and len(size) == 2
        assert isinstance(size, int) or is_pair
        if isinstance(size, int):
            size = (size, size)
        target_h, target_w = size
        assert target_h > 0 and target_w > 0
        assert interpolation in ("nearest", "bilinear", "bicubic", "area",
                                 "lanczos")
        if backend not in ('cv2', 'pillow'):
            raise ValueError(f'backend: {backend} is not supported for resize.'
                             'Supported backends are "cv2", "pillow"')

        self.size = size
        self.height = target_h
        self.width = target_w
        self.interpolation = interpolation
        self.backend = backend

    def _resize_img(self, results):
        """Resize each image field in place and record the new shape."""
        # The resize call takes (w, h), the reverse of ``self.size``.
        target = (self.width, self.height)
        for field in results.get('img_fields', ['img']):
            resized = mmcv.imresize(
                results[field],
                size=target,
                interpolation=self.interpolation,
                return_scale=False,
                backend=self.backend)
            results[field] = resized
            results['img_shape'] = resized.shape

    def __call__(self, results):
        """Resize the images in ``results`` and return the same dict."""
        self._resize_img(results)
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(size={self.size}, '
                f'interpolation={self.interpolation})')
@PIPELINES.register_module()
class CenterCrop(object):
    """Center crop the image.

    Args:
        crop_size (int | tuple): Expected size after cropping, (h, w).

    Notes:
        If the image is smaller than the crop size along a dimension, the
        full extent of the image is kept along that dimension.
    """

    def __init__(self, crop_size):
        assert isinstance(crop_size, int) or (isinstance(crop_size, tuple)
                                              and len(crop_size) == 2)
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size

    def __call__(self, results):
        """Center crop every image in ``results``.

        Args:
            results (dict): Pipeline dict. Images are read from the keys
                listed in ``results['img_fields']`` (``['img']`` if absent).

        Returns:
            dict: The same dict with each image replaced by its center crop
                and ``results['img_shape']`` set to the cropped shape.
        """
        crop_height, crop_width = self.crop_size[0], self.crop_size[1]
        for key in results.get('img_fields', ['img']):
            img = results[key]
            # img.shape has length 2 for grayscale and 3 for colour images;
            # only the leading two entries (h, w) are needed.
            img_height, img_width = img.shape[:2]

            # Clamp to the image so that an oversized crop keeps the whole
            # image along that dimension. Corners are inclusive.
            y1 = max(0, int(round((img_height - crop_height) / 2.)))
            x1 = max(0, int(round((img_width - crop_width) / 2.)))
            y2 = min(img_height, y1 + crop_height) - 1
            x2 = min(img_width, x1 + crop_width) - 1

            # crop the image
            img = mmcv.imcrop(img, bboxes=np.array([x1, y1, x2, y2]))
            results[key] = img
            # Recorded per image so an empty ``img_fields`` list cannot hit
            # an unbound local; the last image's shape is what remains.
            results['img_shape'] = img.shape
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'
@PIPELINES.register_module()
class Normalize(object):
    """Normalize images with a per-channel mean and standard deviation.

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        # Stored as float32 arrays, which is what gets handed to mmcv.
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize every image field and record the normalization config.

        Args:
            results (dict): Pipeline dict. Images are read from the keys
                listed in ``results['img_fields']`` (``['img']`` if absent).

        Returns:
            dict: The same dict with normalized images and an added
                ``'img_norm_cfg'`` entry holding mean, std and to_rgb.
        """
        fields = results.get('img_fields', ['img'])
        for field in fields:
            normalized = mmcv.imnormalize(results[field], self.mean,
                                          self.std, self.to_rgb)
            results[field] = normalized
        results['img_norm_cfg'] = {
            'mean': self.mean,
            'std': self.std,
            'to_rgb': self.to_rgb,
        }
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(mean={list(self.mean)}, '
                f'std={list(self.std)}, '
                f'to_rgb={self.to_rgb})')