# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
import math
import os.path as osp

import cv2
import mmcv
import numpy as np
from numpy import random
from torchvision.transforms import functional as F

from easycv.datasets.registry import PIPELINES
from easycv.datasets.shared.pipelines.transforms import Compose

@PIPELINES.register_module()
class MMToTensor:
    """Transform image to Tensor.

    Required key: 'img'. Modifies key: 'img'.

    Args:
        results (dict): contains all information about training.
    """

    def __call__(self, results):
        results['img'] = F.to_tensor(results['img'])
        return results

@PIPELINES.register_module()
class NormalizeTensor:
    """Normalize the Tensor image (CxHxW), with mean and std.

    Required key: 'img'. Modifies key: 'img'.

    Args:
        mean (list[float]): Mean values of 3 channels.
        std (list[float]): Std values of 3 channels.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, results):
        results['img'] = F.normalize(
            results['img'], mean=self.mean, std=self.std)
        return results

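# A minimal usage sketch for the two tensor transforms above (illustrative
# values; the ImageNet statistics are not defaults of these classes):
#
#   img = np.random.rand(224, 224, 3).astype(np.float32)
#   results = {'img': img}
#   results = MMToTensor()(results)  # HWC ndarray -> CHW torch.Tensor
#   results = NormalizeTensor(
#       mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(results)
#   assert results['img'].shape == (3, 224, 224)
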
@PIPELINES.register_module()
class MMMosaic(object):
    """Mosaic augmentation.

    Given 4 images, mosaic transform combines them into
    one output image. The output image is composed of the parts from each sub-
    image.

    .. code:: text

                        mosaic transform
                           center_x
                +------------------------------+
                |       pad        |  pad      |
                |      +-----------+           |
                |      |           |           |
                |      |  image1   |--------+  |
                |      |           |        |  |
                |      |           | image2 |  |
     center_y   |----+-------------+-----------|
                |    |   cropped   |           |
                |pad |   image3    |  image4   |
                |    |             |           |
                +----|-------------+-----------+
                     |             |
                     +-------------+

    The mosaic transform steps are as follows:

        1. Choose the mosaic center as the intersection of the 4 images.
        2. Get the top-left image according to the index, and randomly
           sample another 3 images from the custom dataset.
        3. The sub-image will be cropped if it is larger than the mosaic
           patch.

    Args:
        img_scale (Sequence[int]): Image size after mosaic pipeline of single
            image. Default to (640, 640).
        center_ratio_range (Sequence[float]): Center ratio range of mosaic
            output. Default to (0.5, 1.5).
        pad_val (int): Pad value. Default to 114.
    """

    def __init__(self,
                 img_scale=(640, 640),
                 center_ratio_range=(0.5, 1.5),
                 pad_val=114):
        assert isinstance(img_scale, tuple)
        self.img_scale = img_scale
        self.center_ratio_range = center_ratio_range
        self.pad_val = pad_val

    def __call__(self, results):
        """Call function to make a mosaic of image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with mosaic transformed.
        """
        results = self._mosaic_transform(results)
        return results

    def get_indexes(self, dataset):
        """Call function to collect indexes.

        Args:
            dataset (:obj:`DetImagesMixDataset`): The dataset.

        Returns:
            list: indexes.
        """
        indexes = [random.randint(0, len(dataset)) for _ in range(3)]
        return indexes

    def _mosaic_transform(self, results):
        """Mosaic transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        mosaic_labels = []
        mosaic_bboxes = []
        if len(results['img'].shape) == 3:
            mosaic_img = np.full(
                (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3),
                self.pad_val,
                dtype=results['img'].dtype)
        else:
            mosaic_img = np.full(
                (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
                self.pad_val,
                dtype=results['img'].dtype)

        # mosaic center x, y
        center_x = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[1])
        center_y = int(
            random.uniform(*self.center_ratio_range) * self.img_scale[0])
        center_position = (center_x, center_y)

        loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        for i, loc in enumerate(loc_strs):
            if loc == 'top_left':
                results_patch = copy.deepcopy(results)
            else:
                results_patch = copy.deepcopy(results['mix_results'][i - 1])

            img_i = results_patch['img']
            h_i, w_i = img_i.shape[:2]
            # keep_ratio resize
            scale_ratio_i = min(self.img_scale[0] / h_i,
                                self.img_scale[1] / w_i)
            img_i = mmcv.imresize(
                img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))

            # compute the combine parameters
            paste_coord, crop_coord = self._mosaic_combine(
                loc, center_position, img_i.shape[:2][::-1])
            x1_p, y1_p, x2_p, y2_p = paste_coord
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste image
            mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]

            # adjust coordinate
            gt_bboxes_i = results_patch['gt_bboxes']
            gt_labels_i = results_patch['gt_labels']

            if gt_bboxes_i.shape[0] > 0:
                padw = x1_p - x1_c
                padh = y1_p - y1_c
                gt_bboxes_i[:, 0::2] = \
                    scale_ratio_i * gt_bboxes_i[:, 0::2] + padw
                gt_bboxes_i[:, 1::2] = \
                    scale_ratio_i * gt_bboxes_i[:, 1::2] + padh

            mosaic_bboxes.append(gt_bboxes_i)
            mosaic_labels.append(gt_labels_i)

        if len(mosaic_labels) > 0:
            mosaic_bboxes = np.concatenate(mosaic_bboxes, 0)
            mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0,
                                             2 * self.img_scale[1])
            mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0,
                                             2 * self.img_scale[0])
            mosaic_labels = np.concatenate(mosaic_labels, 0)

        results['img'] = mosaic_img
        results['img_shape'] = mosaic_img.shape
        results['ori_shape'] = mosaic_img.shape
        results['gt_bboxes'] = mosaic_bboxes
        results['gt_labels'] = mosaic_labels

        return results

    def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
        """Calculate global coordinate of mosaic image and local coordinate of
        cropped sub-image.

        Args:
            loc (str): Index for the sub-image, loc in ('top_left',
                'top_right', 'bottom_left', 'bottom_right').
            center_position_xy (Sequence[float]): Mixing center for 4 images,
                (x, y).
            img_shape_wh (Sequence[int]): Width and height of sub-image.

        Returns:
            tuple[tuple[float]]: Corresponding coordinate of pasting and
                cropping

                - paste_coord (tuple): paste corner coordinate in mosaic image.
                - crop_coord (tuple): crop corner coordinate in mosaic image.
        """
        assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
        if loc == 'top_left':
            # index0 to top left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             center_position_xy[0], \
                             center_position_xy[1]
            crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
                y2 - y1), img_shape_wh[0], img_shape_wh[1]

        elif loc == 'top_right':
            # index1 to top right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             max(center_position_xy[1] - img_shape_wh[1], 0), \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[1] * 2), \
                             center_position_xy[1]
            crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
                img_shape_wh[0], x2 - x1), img_shape_wh[1]

        elif loc == 'bottom_left':
            # index2 to bottom left part of image
            x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
                             center_position_xy[1], \
                             center_position_xy[0], \
                             min(self.img_scale[1] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
                y2 - y1, img_shape_wh[1])

        else:
            # index3 to bottom right part of image
            x1, y1, x2, y2 = center_position_xy[0], \
                             center_position_xy[1], \
                             min(center_position_xy[0] + img_shape_wh[0],
                                 self.img_scale[1] * 2), \
                             min(self.img_scale[0] * 2, center_position_xy[1] +
                                 img_shape_wh[1])
            crop_coord = 0, 0, min(img_shape_wh[0],
                                   x2 - x1), min(y2 - y1, img_shape_wh[1])

        paste_coord = x1, y1, x2, y2
        return paste_coord, crop_coord

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'center_ratio_range={self.center_ratio_range}, '
        repr_str += f'pad_val={self.pad_val})'
        return repr_str

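# A hedged sketch of how MMMosaic consumes its input: a mix-style dataset is
# expected to fill `results['mix_results']` with the three extra samples
# returned by `get_indexes`. Shapes and boxes below are synthetic:
#
#   def _fake_sample():
#       return dict(
#           img=np.full((480, 640, 3), 128, dtype=np.uint8),
#           gt_bboxes=np.array([[10., 10., 100., 100.]], dtype=np.float32),
#           gt_labels=np.array([0], dtype=np.int64))
#
#   results = _fake_sample()
#   results['mix_results'] = [_fake_sample() for _ in range(3)]
#   results = MMMosaic(img_scale=(640, 640))(results)
#   assert results['img'].shape == (1280, 1280, 3)  # 2 x img_scale
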
@PIPELINES.register_module()
class MMMixUp:
    """MixUp data augmentation.

    .. code:: text

                        mixup transform
               +------------------------------+
               | mixup image   |              |
               |      +--------|--------+     |
               |      |        |        |     |
               |---------------+        |     |
               |      |                 |     |
               |      |      image      |     |
               |      |                 |     |
               |      |                 |     |
               |      |-----------------+     |
               |             pad              |
               +------------------------------+

    The mixup transform steps are as follows:

        1. Another random image is picked from the dataset and embedded in
           the top left patch (after padding and resizing).
        2. The target of mixup transform is the weighted average of mixup
           image and origin image.

    Args:
        img_scale (Sequence[int]): Image output size after mixup pipeline.
            Default: (640, 640).
        ratio_range (Sequence[float]): Scale ratio of mixup image.
            Default: (0.5, 1.5).
        flip_ratio (float): Horizontal flip ratio of mixup image.
            Default: 0.5.
        pad_val (int): Pad value. Default: 114.
        max_iters (int): The maximum number of iterations. If the number of
            iterations is greater than `max_iters`, but gt_bbox is still
            empty, then the iteration is terminated. Default: 15.
        min_bbox_size (float): Width and height threshold to filter bboxes.
            If the height or width of a box is smaller than this value, it
            will be removed. Default: 5.
        min_area_ratio (float): Threshold of area ratio between
            original bboxes and wrapped bboxes. If smaller than this value,
            the box will be removed. Default: 0.2.
        max_aspect_ratio (float): Aspect ratio of width and height
            threshold to filter bboxes. If max(h/w, w/h) larger than this
            value, the box will be removed. Default: 20.
    """

    def __init__(self,
                 img_scale=(640, 640),
                 ratio_range=(0.5, 1.5),
                 flip_ratio=0.5,
                 pad_val=114,
                 max_iters=15,
                 min_bbox_size=5,
                 min_area_ratio=0.2,
                 max_aspect_ratio=20):
        assert isinstance(img_scale, tuple)
        self.dynamic_scale = img_scale
        self.ratio_range = ratio_range
        self.flip_ratio = flip_ratio
        self.pad_val = pad_val
        self.max_iters = max_iters
        self.min_bbox_size = min_bbox_size
        self.min_area_ratio = min_area_ratio
        self.max_aspect_ratio = max_aspect_ratio

    def __call__(self, results):
        """Call function to make a mixup of image.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Result dict with mixup transformed.
        """
        results = self._mixup_transform(results)
        return results

    def get_indexes(self, dataset):
        """Call function to collect indexes.

        Args:
            dataset (:obj:`DetImagesMixDataset`): The dataset.

        Returns:
            list: indexes.
        """
        for i in range(self.max_iters):
            index = random.randint(0, len(dataset))
            gt_bboxes_i = dataset.get_ann_info(index)['bboxes']
            if len(gt_bboxes_i) != 0:
                break

        return index

    def _mixup_transform(self, results):
        """MixUp transform function.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'mix_results' in results
        assert len(
            results['mix_results']) == 1, 'MixUp only supports 2 images now!'

        if results['mix_results'][0]['gt_bboxes'].shape[0] == 0:
            # empty bbox
            return results

        if 'scale' in results:
            self.dynamic_scale = results['scale']

        retrieve_results = results['mix_results'][0]
        retrieve_img = retrieve_results['img']

        jit_factor = random.uniform(*self.ratio_range)
        is_flip = random.uniform(0, 1) > self.flip_ratio

        if len(retrieve_img.shape) == 3:
            out_img = np.ones(
                (self.dynamic_scale[0], self.dynamic_scale[1], 3),
                dtype=retrieve_img.dtype) * self.pad_val
        else:
            out_img = np.ones(
                self.dynamic_scale, dtype=retrieve_img.dtype) * self.pad_val

        # 1. keep_ratio resize
        scale_ratio = min(self.dynamic_scale[0] / retrieve_img.shape[0],
                          self.dynamic_scale[1] / retrieve_img.shape[1])
        retrieve_img = mmcv.imresize(
            retrieve_img, (int(retrieve_img.shape[1] * scale_ratio),
                           int(retrieve_img.shape[0] * scale_ratio)))

        # 2. paste
        out_img[:retrieve_img.shape[0], :retrieve_img.shape[1]] = retrieve_img

        # 3. scale jit
        scale_ratio *= jit_factor
        out_img = mmcv.imresize(out_img, (int(out_img.shape[1] * jit_factor),
                                          int(out_img.shape[0] * jit_factor)))

        # 4. flip
        if is_flip:
            out_img = out_img[:, ::-1, :]

        # 5. random crop
        ori_img = results['img']
        origin_h, origin_w = out_img.shape[:2]
        target_h, target_w = ori_img.shape[:2]
        padded_img = np.zeros(
            (max(origin_h, target_h), max(origin_w,
                                          target_w), 3)).astype(np.uint8)
        padded_img[:origin_h, :origin_w] = out_img

        x_offset, y_offset = 0, 0
        if padded_img.shape[0] > target_h:
            y_offset = random.randint(0, padded_img.shape[0] - target_h)
        if padded_img.shape[1] > target_w:
            x_offset = random.randint(0, padded_img.shape[1] - target_w)
        padded_cropped_img = padded_img[y_offset:y_offset + target_h,
                                        x_offset:x_offset + target_w]

        # 6. adjust bbox
        retrieve_gt_bboxes = retrieve_results['gt_bboxes']
        retrieve_gt_bboxes[:, 0::2] = np.clip(
            retrieve_gt_bboxes[:, 0::2] * scale_ratio, 0, origin_w)
        retrieve_gt_bboxes[:, 1::2] = np.clip(
            retrieve_gt_bboxes[:, 1::2] * scale_ratio, 0, origin_h)

        if is_flip:
            retrieve_gt_bboxes[:, 0::2] = (
                origin_w - retrieve_gt_bboxes[:, 0::2][:, ::-1])

        # 7. filter
        cp_retrieve_gt_bboxes = retrieve_gt_bboxes.copy()
        cp_retrieve_gt_bboxes[:, 0::2] = np.clip(
            cp_retrieve_gt_bboxes[:, 0::2] - x_offset, 0, target_w)
        cp_retrieve_gt_bboxes[:, 1::2] = np.clip(
            cp_retrieve_gt_bboxes[:, 1::2] - y_offset, 0, target_h)
        keep_list = self._filter_box_candidates(retrieve_gt_bboxes.T,
                                                cp_retrieve_gt_bboxes.T)

        # 8. mix up
        if keep_list.sum() >= 1.0:
            ori_img = ori_img.astype(np.float32)
            mixup_img = 0.5 * ori_img + 0.5 * padded_cropped_img.astype(
                np.float32)

            retrieve_gt_labels = retrieve_results['gt_labels'][keep_list]
            retrieve_gt_bboxes = cp_retrieve_gt_bboxes[keep_list]
            mixup_gt_bboxes = np.concatenate(
                (results['gt_bboxes'], retrieve_gt_bboxes), axis=0)
            mixup_gt_labels = np.concatenate(
                (results['gt_labels'], retrieve_gt_labels), axis=0)

            results['img'] = mixup_img
            results['img_shape'] = mixup_img.shape
            results['gt_bboxes'] = mixup_gt_bboxes
            results['gt_labels'] = mixup_gt_labels

        return results

    def _filter_box_candidates(self, bbox1, bbox2):
        """Compute which candidate boxes to keep, based on bbox1 (before
        augmentation), bbox2 (after augmentation), min_bbox_size (pixels),
        min_area_ratio and max_aspect_ratio.
        """
        w1, h1 = bbox1[2] - bbox1[0], bbox1[3] - bbox1[1]
        w2, h2 = bbox2[2] - bbox2[0], bbox2[3] - bbox2[1]
        ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))
        return ((w2 > self.min_bbox_size)
                & (h2 > self.min_bbox_size)
                & (w2 * h2 / (w1 * h1 + 1e-16) > self.min_area_ratio)
                & (ar < self.max_aspect_ratio))

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(dynamic_scale={self.dynamic_scale}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'flip_ratio={self.flip_ratio}, '
        repr_str += f'pad_val={self.pad_val}, '
        repr_str += f'max_iters={self.max_iters}, '
        repr_str += f'min_bbox_size={self.min_bbox_size}, '
        repr_str += f'min_area_ratio={self.min_area_ratio}, '
        repr_str += f'max_aspect_ratio={self.max_aspect_ratio})'
        return repr_str

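# A hedged MMMixUp sketch; as with MMMosaic, a mix-style dataset normally
# provides `results['mix_results']` (exactly one extra sample here). All
# values are synthetic:
#
#   base = dict(
#       img=np.zeros((640, 640, 3), dtype=np.uint8),
#       gt_bboxes=np.array([[0., 0., 50., 50.]], dtype=np.float32),
#       gt_labels=np.array([1], dtype=np.int64))
#   other = copy.deepcopy(base)
#   base['mix_results'] = [other]
#   base = MMMixUp(img_scale=(640, 640))(base)
#   # base['img'] becomes the 0.5/0.5 blend whenever at least one mixed
#   # box survives the filtering in step 8; otherwise it is left unchanged.
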
@PIPELINES.register_module()
class MMRandomAffine:
    """Random affine transform data augmentation for YOLOX.

    This operation randomly generates an affine transform matrix which
    includes rotation, translation, shear and scaling transforms.

    Args:
        max_rotate_degree (float): Maximum degrees of rotation transform.
            Default: 10.
        max_translate_ratio (float): Maximum ratio of translation.
            Default: 0.1.
        scaling_ratio_range (tuple[float]): Min and max ratio of
            scaling transform. Default: (0.5, 1.5).
        max_shear_degree (float): Maximum degrees of shear
            transform. Default: 2.
        border (tuple[int]): Distance from height and width sides of input
            image to adjust output shape. Only used in mosaic dataset.
            Default: (0, 0).
        border_val (tuple[int]): Border padding values of 3 channels.
            Default: (114, 114, 114).
        min_bbox_size (float): Width and height threshold to filter bboxes.
            If the height or width of a box is smaller than this value, it
            will be removed. Default: 2.
        min_area_ratio (float): Threshold of area ratio between
            original bboxes and wrapped bboxes. If smaller than this value,
            the box will be removed. Default: 0.2.
        max_aspect_ratio (float): Aspect ratio of width and height
            threshold to filter bboxes. If max(h/w, w/h) larger than this
            value, the box will be removed. Default: 20.
    """

    def __init__(self,
                 max_rotate_degree=10.0,
                 max_translate_ratio=0.1,
                 scaling_ratio_range=(0.5, 1.5),
                 max_shear_degree=2.0,
                 border=(0, 0),
                 border_val=(114, 114, 114),
                 min_bbox_size=2,
                 min_area_ratio=0.2,
                 max_aspect_ratio=20):
        assert 0 <= max_translate_ratio <= 1
        assert scaling_ratio_range[0] <= scaling_ratio_range[1]
        assert scaling_ratio_range[0] > 0
        self.max_rotate_degree = max_rotate_degree
        self.max_translate_ratio = max_translate_ratio
        self.scaling_ratio_range = scaling_ratio_range
        self.max_shear_degree = max_shear_degree
        self.border = border
        self.border_val = border_val
        self.min_bbox_size = min_bbox_size
        self.min_area_ratio = min_area_ratio
        self.max_aspect_ratio = max_aspect_ratio

    def __call__(self, results):
        img = results['img']
        height = img.shape[0] + self.border[0] * 2
        width = img.shape[1] + self.border[1] * 2

        # Center
        center_matrix = np.eye(3, dtype=np.float32)
        center_matrix[0, 2] = -img.shape[1] / 2  # x translation (pixels)
        center_matrix[1, 2] = -img.shape[0] / 2  # y translation (pixels)

        # Rotation
        rotation_degree = random.uniform(-self.max_rotate_degree,
                                         self.max_rotate_degree)
        rotation_matrix = self._get_rotation_matrix(rotation_degree)

        # Scaling
        scaling_ratio = random.uniform(self.scaling_ratio_range[0],
                                       self.scaling_ratio_range[1])
        scaling_matrix = self._get_scaling_matrix(scaling_ratio)

        # Shear
        x_degree = random.uniform(-self.max_shear_degree,
                                  self.max_shear_degree)
        y_degree = random.uniform(-self.max_shear_degree,
                                  self.max_shear_degree)
        shear_matrix = self._get_shear_matrix(x_degree, y_degree)

        # Translation
        trans_x = random.uniform(0.5 - self.max_translate_ratio,
                                 0.5 + self.max_translate_ratio) * width
        trans_y = random.uniform(0.5 - self.max_translate_ratio,
                                 0.5 + self.max_translate_ratio) * height
        translate_matrix = self._get_translation_matrix(trans_x, trans_y)

        warp_matrix = (
            translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix
            @ center_matrix)

        img = cv2.warpPerspective(
            img,
            warp_matrix,
            dsize=(width, height),
            borderValue=self.border_val)
        results['img'] = img
        results['img_shape'] = img.shape

        for key in results.get('bbox_fields', []):
            bboxes = results[key]
            num_bboxes = len(bboxes)
            if num_bboxes:
                # homogeneous coordinates
                xs = bboxes[:, [0, 2, 2, 0]].reshape(num_bboxes * 4)
                ys = bboxes[:, [1, 3, 3, 1]].reshape(num_bboxes * 4)
                ones = np.ones_like(xs)
                points = np.vstack([xs, ys, ones])

                warp_points = warp_matrix @ points
                warp_points = warp_points[:2] / warp_points[2]
                xs = warp_points[0].reshape(num_bboxes, 4)
                ys = warp_points[1].reshape(num_bboxes, 4)

                warp_bboxes = np.vstack(
                    (xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T

                warp_bboxes[:, [0, 2]] = warp_bboxes[:, [0, 2]].clip(0, width)
                warp_bboxes[:, [1, 3]] = warp_bboxes[:, [1, 3]].clip(0, height)

                # filter bboxes
                valid_index = self.filter_gt_bboxes(bboxes * scaling_ratio,
                                                    warp_bboxes)
                results[key] = warp_bboxes[valid_index]
                if key in ['gt_bboxes']:
                    if 'gt_labels' in results:
                        results['gt_labels'] = results['gt_labels'][
                            valid_index]
                if 'gt_masks' in results:
                    raise NotImplementedError(
                        'RandomAffine only supports bbox.')
        return results

    def filter_gt_bboxes(self, origin_bboxes, wrapped_bboxes):
        origin_w = origin_bboxes[:, 2] - origin_bboxes[:, 0]
        origin_h = origin_bboxes[:, 3] - origin_bboxes[:, 1]
        wrapped_w = wrapped_bboxes[:, 2] - wrapped_bboxes[:, 0]
        wrapped_h = wrapped_bboxes[:, 3] - wrapped_bboxes[:, 1]
        aspect_ratio = np.maximum(wrapped_w / (wrapped_h + 1e-16),
                                  wrapped_h / (wrapped_w + 1e-16))

        wh_valid_idx = (wrapped_w > self.min_bbox_size) & \
                       (wrapped_h > self.min_bbox_size)
        area_valid_idx = wrapped_w * wrapped_h / (origin_w * origin_h +
                                                  1e-16) > self.min_area_ratio
        aspect_ratio_valid_idx = aspect_ratio < self.max_aspect_ratio
        return wh_valid_idx & area_valid_idx & aspect_ratio_valid_idx

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(max_rotate_degree={self.max_rotate_degree}, '
        repr_str += f'max_translate_ratio={self.max_translate_ratio}, '
        repr_str += f'scaling_ratio={self.scaling_ratio_range}, '
        repr_str += f'max_shear_degree={self.max_shear_degree}, '
        repr_str += f'border={self.border}, '
        repr_str += f'border_val={self.border_val}, '
        repr_str += f'min_bbox_size={self.min_bbox_size}, '
        repr_str += f'min_area_ratio={self.min_area_ratio}, '
        repr_str += f'max_aspect_ratio={self.max_aspect_ratio})'
        return repr_str

    @staticmethod
    def _get_rotation_matrix(rotate_degrees):
        radian = math.radians(rotate_degrees)
        rotation_matrix = np.array(
            [[np.cos(radian), -np.sin(radian), 0.],
             [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]],
            dtype=np.float32)
        return rotation_matrix

    @staticmethod
    def _get_scaling_matrix(scale_ratio):
        scaling_matrix = np.array(
            [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
            dtype=np.float32)
        return scaling_matrix

    @staticmethod
    def _get_shear_matrix(x_shear_degrees, y_shear_degrees):
        x_radian = math.radians(x_shear_degrees)
        y_radian = math.radians(y_shear_degrees)
        shear_matrix = np.array([[1, np.tan(x_radian), 0.],
                                 [np.tan(y_radian), 1, 0.], [0., 0., 1.]],
                                dtype=np.float32)
        return shear_matrix

    @staticmethod
    def _get_translation_matrix(x, y):
        translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
                                      dtype=np.float32)
        return translation_matrix

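# A hedged MMRandomAffine sketch. Boxes to be warped must be listed in
# `results['bbox_fields']`; values below are synthetic:
#
#   results = dict(
#       img=np.zeros((416, 416, 3), dtype=np.uint8),
#       gt_bboxes=np.array([[100., 100., 200., 200.]], dtype=np.float32),
#       gt_labels=np.array([0], dtype=np.int64),
#       bbox_fields=['gt_bboxes'])
#   results = MMRandomAffine(
#       max_rotate_degree=5.0, scaling_ratio_range=(0.8, 1.2))(results)
#   # boxes failing the min_bbox_size/min_area_ratio/max_aspect_ratio checks
#   # after the warp are dropped together with their labels
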
@PIPELINES.register_module()
class MMPhotoMetricDistortion:
    """Apply photometric distortion to an image sequentially; every
    transformation is applied with a probability of 0.5. Random contrast is
    applied either second or second to last.

    1. random brightness
    2. random contrast (mode 1)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 0)
    8. randomly swap channels

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, results):
        """Call function to perform photometric distortion on images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Result dict with images distorted.
        """
        if 'img_fields' in results:
            assert results['img_fields'] == ['img'], \
                'Only single img_fields is allowed'
        img = results['img']
        assert img.dtype == np.float32, \
            'PhotoMetricDistortion needs the input image of dtype ' \
            'np.float32, please set "to_float32=True" in ' \
            '"LoadImageFromFile" pipeline'
        # random brightness
        if random.randint(2):
            delta = random.uniform(-self.brightness_delta,
                                   self.brightness_delta)
            img += delta

        # mode == 1 --> do random contrast first
        # mode == 0 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower,
                                          self.saturation_upper)

        # random hue
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast
        if mode == 0:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        results['img'] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
        repr_str += 'contrast_range='
        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
        repr_str += 'saturation_range='
        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
        repr_str += f'hue_delta={self.hue_delta})'
        return repr_str

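# A hedged MMPhotoMetricDistortion sketch. The transform asserts a float32
# input, so either load with to_float32=True or convert first:
#
#   img = np.random.randint(0, 255, (300, 400, 3)).astype(np.float32)
#   results = dict(img=img, img_fields=['img'])
#   results = MMPhotoMetricDistortion(
#       brightness_delta=32, contrast_range=(0.5, 1.5))(results)
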
@PIPELINES.register_module()
class MMResize:
    """Resize images & bbox & mask.

    This transform resizes the input image to some scale. Bboxes and masks are
    then resized with the same scale factor. If the input dict contains the key
    "scale", then the scale in the input dict is used, otherwise the specified
    scale in the init method is used. If the input dict contains the key
    "scale_factor" (if MultiScaleFlipAug does not give img_scale but
    scale_factor), the actual scale will be computed by image shape and
    scale_factor.

    `img_scale` can either be a tuple (single-scale) or a list of tuple
    (multi-scale). There are 3 multiscale modes:

    - ``ratio_range is not None``: randomly sample a ratio from the ratio \
      range and multiply it with the image scale.
    - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
      sample a scale from the multiscale range.
    - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
      sample a scale from multiple scales.

    Args:
        img_scale (tuple or list[tuple]): Image scales for resizing.
        multiscale_mode (str): Either "range" or "value".
        ratio_range (tuple[float]): (min_ratio, max_ratio)
        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
            image.
        bbox_clip_border (bool, optional): Whether to clip the objects outside
            the border of the image. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
            These two backends generate slightly different results. Defaults
            to 'cv2'.
        override (bool, optional): Whether to override `scale` and
            `scale_factor` so as to call resize twice. Default False. If True,
            after the first resizing, the existing `scale` and `scale_factor`
            will be ignored so the second resizing can be allowed.
            This option is a work-around for multiple times of resize in DETR.
            Defaults to False.
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 bbox_clip_border=True,
                 backend='cv2',
                 override=False):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list) and isinstance(img_scale[0], tuple):
                self.img_scale = img_scale
            elif isinstance(img_scale, list):
                self.img_scale = [tuple(img_scale)]
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given a scale and a range of image ratio
            assert len(self.img_scale) == 1
        else:
            # mode 2: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        self.backend = backend
        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        # TODO: refactor the override option in Resize
        self.override = override
        self.bbox_clip_border = bbox_clip_border

    @staticmethod
    def random_select(img_scales):
        """Randomly select an img_scale from given candidates.

        Args:
            img_scales (list[tuple]): Image scales for selection.

        Returns:
            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
                where ``img_scale`` is the selected image scale and \
                ``scale_idx`` is the selected index in the given candidates.
        """
        assert mmcv.is_list_of(img_scales, tuple)
        scale_idx = np.random.randint(len(img_scales))
        img_scale = img_scales[scale_idx]
        return img_scale, scale_idx

    @staticmethod
    def random_sample(img_scales):
        """Randomly sample an img_scale when ``multiscale_mode=='range'``.

        Args:
            img_scales (list[tuple]): Image scale range for sampling.
                There must be two tuples in img_scales, which specify the lower
                and upper bound of image scales.

        Returns:
            (tuple, None): Returns a tuple ``(img_scale, None)``, where \
                ``img_scale`` is sampled scale and None is just a placeholder \
                to be consistent with :func:`random_select`.
        """
        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
        img_scale_long = [max(s) for s in img_scales]
        img_scale_short = [min(s) for s in img_scales]
        long_edge = np.random.randint(
            min(img_scale_long),
            max(img_scale_long) + 1)
        short_edge = np.random.randint(
            min(img_scale_short),
            max(img_scale_short) + 1)
        img_scale = (long_edge, short_edge)
        return img_scale, None

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Randomly sample an img_scale when ``ratio_range`` is specified.

        A ratio will be randomly sampled from the range specified by
        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
        generate sampled scale.

        Args:
            img_scale (tuple): Image scale base to multiply with ratio.
            ratio_range (tuple[float]): The minimum and maximum ratio to scale
                the ``img_scale``.

        Returns:
            (tuple, None): Returns a tuple ``(scale, None)``, where \
                ``scale`` is sampled ratio multiplied with ``img_scale`` and \
                None is just a placeholder to be consistent with \
                :func:`random_select`.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None

    def _random_scale(self, results):
        """Randomly sample an img_scale according to ``ratio_range`` and
        ``multiscale_mode``.

        If ``ratio_range`` is specified, a ratio will be sampled and be
        multiplied with ``img_scale``.
        If multiple scales are specified by ``img_scale``, a scale will be
        sampled according to ``multiscale_mode``.
        Otherwise, single scale will be used.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: Two new keys ``scale`` and ``scale_idx`` are added into \
                ``results``, which would be used by subsequent pipelines.
        """
        if self.ratio_range is not None:
            scale, scale_idx = self.random_sample_ratio(
                self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError

        results['scale'] = scale
        results['scale_idx'] = scale_idx

    def _resize_img(self, results):
        """Resize images with ``results['scale']``."""
        for key in results.get('img_fields', ['img']):
            if self.keep_ratio:
                img, scale_factor = mmcv.imrescale(
                    results[key],
                    results['scale'],
                    return_scale=True,
                    backend=self.backend)
                # the w_scale and h_scale has minor difference
                # a real fix should be done in the mmcv.imrescale in the future
                new_h, new_w = img.shape[:2]
                h, w = results[key].shape[:2]
                w_scale = new_w / w
                h_scale = new_h / h
            else:
                img, w_scale, h_scale = mmcv.imresize(
                    results[key],
                    results['scale'],
                    return_scale=True,
                    backend=self.backend)
            results[key] = img

            scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                    dtype=np.float32)
            results['img_shape'] = img.shape
            # in case that there is no padding
            results['pad_shape'] = img.shape
            results['scale_factor'] = scale_factor
            results['keep_ratio'] = self.keep_ratio

    def _resize_bboxes(self, results):
        """Resize bounding boxes with ``results['scale_factor']``."""
        for key in results.get('bbox_fields', []):
            bboxes = results[key] * results['scale_factor']
            if self.bbox_clip_border:
                img_shape = results['img_shape']
                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
            results[key] = bboxes

    def _resize_masks(self, results):
        """Resize masks with ``results['scale']``."""
        for key in results.get('mask_fields', []):
            if results[key] is None:
                continue
            if self.keep_ratio:
                results[key] = results[key].rescale(results['scale'])
            else:
                results[key] = results[key].resize(results['img_shape'][:2])

    def _resize_seg(self, results):
        """Resize semantic segmentation map with ``results['scale']``."""
        for key in results.get('seg_fields', []):
            if self.keep_ratio:
                gt_seg = mmcv.imrescale(
                    results[key],
                    results['scale'],
                    interpolation='nearest',
                    backend=self.backend)
            else:
                gt_seg = mmcv.imresize(
                    results[key],
                    results['scale'],
                    interpolation='nearest',
                    backend=self.backend)
            results[key] = gt_seg

    def __call__(self, results):
        """Call function to resize images, bounding boxes, masks, semantic
        segmentation map.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Resized results; 'img_shape', 'pad_shape', 'scale_factor' \
                and 'keep_ratio' keys are added into result dict.
        """
        if 'scale' not in results:
            if 'scale_factor' in results:
                img_shape = results['img'].shape[:2]
                scale_factor = results['scale_factor']
                assert isinstance(scale_factor, float)
                results['scale'] = tuple(
                    [int(x * scale_factor) for x in img_shape][::-1])
            else:
                self._random_scale(results)
        else:
            if not self.override:
                assert 'scale_factor' not in results, (
                    'scale and scale_factor cannot be both set.')
            else:
                results.pop('scale')
                if 'scale_factor' in results:
                    results.pop('scale_factor')
                self._random_scale(results)

        self._resize_img(results)
        self._resize_bboxes(results)
        self._resize_masks(results)
        self._resize_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(img_scale={self.img_scale}, '
        repr_str += f'multiscale_mode={self.multiscale_mode}, '
        repr_str += f'ratio_range={self.ratio_range}, '
        repr_str += f'keep_ratio={self.keep_ratio}, '
        repr_str += f'bbox_clip_border={self.bbox_clip_border})'
        return repr_str

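# A hedged MMResize sketch covering the two common modes; values are
# synthetic and the field lists are left empty for brevity:
#
#   results = dict(
#       img=np.zeros((480, 640, 3), dtype=np.uint8),
#       img_fields=['img'], bbox_fields=[], mask_fields=[], seg_fields=[])
#
#   # single fixed scale, aspect ratio kept
#   out = MMResize(img_scale=(1333, 800), keep_ratio=True)(dict(results))
#
#   # multi-scale "range" mode: a scale is sampled between the two bounds
#   out = MMResize(img_scale=[(1333, 640), (1333, 800)],
#                  multiscale_mode='range')(dict(results))
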
@PIPELINES.register_module()
class MMRandomFlip:
    """Flip the image & bbox & mask.

    If the input dict contains the key "flip", then the flag will be used,
    otherwise it will be randomly decided by a ratio specified in the init
    method.

    When random flip is enabled, ``flip_ratio``/``direction`` can either be a
    float/string or tuple of float/string. There are 3 flip modes:

    - ``flip_ratio`` is float, ``direction`` is string: the image will be
      ``direction``ly flipped with probability of ``flip_ratio``.
      E.g., ``flip_ratio=0.5``, ``direction='horizontal'``,
      then image will be horizontally flipped with probability of 0.5.
    - ``flip_ratio`` is float, ``direction`` is list of string: the image will
      be ``direction[i]``ly flipped with probability of
      ``flip_ratio/len(direction)``.
      E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``,
      then image will be horizontally flipped with probability of 0.25,
      vertically with probability of 0.25.
    - ``flip_ratio`` is list of float, ``direction`` is list of string:
      given ``len(flip_ratio) == len(direction)``, the image will
      be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``.
      E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
      'vertical']``, then image will be horizontally flipped with probability
      of 0.3, vertically with probability of 0.5.

    Args:
        flip_ratio (float | list[float], optional): The flipping probability.
            Default: None.
        direction(str | list[str], optional): The flipping direction. Options
            are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'.
            If input is a list, the length must equal ``flip_ratio``. Each
            element in ``flip_ratio`` indicates the flip probability of
            corresponding direction.
    """

    def __init__(self, flip_ratio=None, direction='horizontal'):
        if isinstance(flip_ratio, list):
            assert mmcv.is_list_of(flip_ratio, float)
            assert 0 <= sum(flip_ratio) <= 1
        elif isinstance(flip_ratio, float):
            assert 0 <= flip_ratio <= 1
        elif flip_ratio is None:
            pass
        else:
            raise ValueError('flip_ratio must be None, float, '
                             'or list of float')
        self.flip_ratio = flip_ratio

        valid_directions = ['horizontal', 'vertical', 'diagonal']
        if isinstance(direction, str):
            assert direction in valid_directions
        elif isinstance(direction, list):
            assert mmcv.is_list_of(direction, str)
            assert set(direction).issubset(set(valid_directions))
        else:
            raise ValueError('direction must be either str or list of str')
        self.direction = direction

        if isinstance(flip_ratio, list):
            assert len(self.flip_ratio) == len(self.direction)

    def bbox_flip(self, bboxes, img_shape, direction):
        """Flip bboxes.

        Args:
            bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
            img_shape (tuple[int]): Image shape (height, width)
            direction (str): Flip direction. Options are 'horizontal',
                'vertical', 'diagonal'.

        Returns:
            numpy.ndarray: Flipped bounding boxes.
        """
        assert bboxes.shape[-1] % 4 == 0
        flipped = bboxes.copy()
        if direction == 'horizontal':
            w = img_shape[1]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
        elif direction == 'vertical':
            h = img_shape[0]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        elif direction == 'diagonal':
            w = img_shape[1]
            h = img_shape[0]
            flipped[..., 0::4] = w - bboxes[..., 2::4]
            flipped[..., 1::4] = h - bboxes[..., 3::4]
            flipped[..., 2::4] = w - bboxes[..., 0::4]
            flipped[..., 3::4] = h - bboxes[..., 1::4]
        else:
            raise ValueError(f"Invalid flipping direction '{direction}'")
        return flipped

    def __call__(self, results):
        """Call function to flip bounding boxes, masks, semantic segmentation
        maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Flipped results; 'flip' and 'flip_direction' keys are \
                added into result dict.
        """
        if 'flip' not in results:
            if isinstance(self.direction, list):
                # None means non-flip
                direction_list = self.direction + [None]
            else:
                # None means non-flip
                direction_list = [self.direction, None]

            if isinstance(self.flip_ratio, list):
                non_flip_ratio = 1 - sum(self.flip_ratio)
                flip_ratio_list = self.flip_ratio + [non_flip_ratio]
            else:
                non_flip_ratio = 1 - self.flip_ratio
                # exclude non-flip
                single_ratio = self.flip_ratio / (len(direction_list) - 1)
                flip_ratio_list = [single_ratio] * (len(direction_list) -
                                                    1) + [non_flip_ratio]

            cur_dir = np.random.choice(direction_list, p=flip_ratio_list)

            results['flip'] = cur_dir is not None
        if 'flip_direction' not in results:
            results['flip_direction'] = cur_dir
        if results['flip']:
            # flip image
            for key in results.get('img_fields', ['img']):
                results[key] = mmcv.imflip(
                    results[key], direction=results['flip_direction'])
            # flip bboxes
            for key in results.get('bbox_fields', []):
                results[key] = self.bbox_flip(results[key],
                                              results['img_shape'],
                                              results['flip_direction'])
            # flip masks
            for key in results.get('mask_fields', []):
                results[key] = results[key].flip(results['flip_direction'])

            # flip segs
            for key in results.get('seg_fields', []):
                results[key] = mmcv.imflip(
                    results[key], direction=results['flip_direction'])
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'

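# A hedged MMRandomFlip sketch. With flip_ratio=1.0 the flip always fires,
# which makes the bbox arithmetic easy to check by hand:
#
#   results = dict(
#       img=np.zeros((100, 200, 3), dtype=np.uint8),
#       img_shape=(100, 200, 3),
#       gt_bboxes=np.array([[10., 20., 50., 80.]], dtype=np.float32),
#       img_fields=['img'], bbox_fields=['gt_bboxes'])
#   out = MMRandomFlip(flip_ratio=1.0, direction='horizontal')(results)
#   # x coordinates mirror around the width:
#   # [200 - 50, 20, 200 - 10, 80] -> [[150., 20., 190., 80.]]
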
@PIPELINES.register_module()
class MMPad:
    """Pad the image & mask.

    There are two padding modes: (1) pad to a fixed size and (2) pad to the
    minimum size that is divisible by some number.
    Added keys are "pad_shape", "pad_fixed_size" and "pad_size_divisor".

    Args:
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_to_square (bool): Whether to pad the image into a square.
            Currently only used for YOLOX. Default: False.
        pad_val (float, optional): Padding value, 0 by default.
    """

    def __init__(self,
                 size=None,
                 size_divisor=None,
                 pad_to_square=False,
                 pad_val=0):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = tuple(pad_val) if isinstance(pad_val, list) else pad_val
        self.pad_to_square = pad_to_square

        if pad_to_square:
            assert size is None and size_divisor is None, \
                'The size and size_divisor must be None ' \
                'when pad_to_square is True'
        else:
            assert size is not None or size_divisor is not None, \
                'only one of size and size_divisor should be valid'
            assert size is None or size_divisor is None

    def _pad_img(self, results):
        """Pad images according to ``self.size``."""
        for key in results.get('img_fields', ['img']):
            if self.pad_to_square:
                max_size = max(results[key].shape[:2])
                self.size = (max_size, max_size)
            if self.size is not None:
                padded_img = mmcv.impad(
                    results[key], shape=self.size, pad_val=self.pad_val)
            elif self.size_divisor is not None:
                padded_img = mmcv.impad_to_multiple(
                    results[key], self.size_divisor, pad_val=self.pad_val)
            results[key] = padded_img
        results['pad_shape'] = padded_img.shape
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor

    def _pad_masks(self, results):
        """Pad masks according to ``results['pad_shape']``."""
        pad_shape = results['pad_shape'][:2]
        for key in results.get('mask_fields', []):
            results[key] = results[key].pad(pad_shape, pad_val=self.pad_val)

    def _pad_seg(self, results):
        """Pad semantic segmentation map according to
        ``results['pad_shape']``."""
        for key in results.get('seg_fields', []):
            results[key] = mmcv.impad(
                results[key], shape=results['pad_shape'][:2])

    def __call__(self, results):
        """Call function to pad images, masks, semantic segmentation maps.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        self._pad_img(results)
        self._pad_masks(results)
        self._pad_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, '
        repr_str += f'size_divisor={self.size_divisor}, '
        repr_str += f'pad_to_square={self.pad_to_square}, '
        repr_str += f'pad_val={self.pad_val})'
        return repr_str

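# A hedged MMPad sketch showing the two exclusive modes plus pad_to_square:
#
#   img = np.zeros((100, 150, 3), dtype=np.uint8)
#
#   out = MMPad(size=(128, 160), pad_val=114)(dict(img=img))
#   assert out['pad_shape'][:2] == (128, 160)
#
#   out = MMPad(size_divisor=32)(dict(img=img))
#   assert out['pad_shape'][:2] == (128, 160)  # rounded up to multiples of 32
#
#   out = MMPad(pad_to_square=True, pad_val=114)(dict(img=img))
#   assert out['pad_shape'][:2] == (150, 150)
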
@PIPELINES.register_module()
class MMNormalize:
    """Normalize the image.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB.
            Default: True.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Call function to normalize images.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Normalized results, 'img_norm_cfg' key is added into
                result dict.
        """
        for key in results.get('img_fields', ['img']):
            results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
                                            self.to_rgb)
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
        return repr_str

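# A hedged MMNormalize sketch; the mean/std below are the ImageNet values
# commonly used with BGR-loaded pipelines, not defaults of this class:
#
#   img = np.random.randint(0, 255, (64, 64, 3)).astype(np.float32)
#   out = MMNormalize(
#       mean=[123.675, 116.28, 103.53],
#       std=[58.395, 57.12, 57.375],
#       to_rgb=True)(dict(img=img))
#   # out['img_norm_cfg'] records the statistics for later de-normalization
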
@PIPELINES.register_module()
class LoadImageFromFile:
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk')):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def __call__(self, results):
        """Call functions to load image and get image meta information.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results['img_prefix'] is not None:
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
        else:
            filename = results['img_info']['filename']

        img_bytes = self.file_client.get(filename)
        img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        results['ori_img_shape'] = img.shape
        results['img_fields'] = ['img']
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f'file_client_args={self.file_client_args})')
        return repr_str

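# A hedged LoadImageFromFile sketch; the COCO-style path below is
# hypothetical:
#
#   loader = LoadImageFromFile(to_float32=True)
#   results = dict(
#       img_prefix='data/coco/train2017',
#       img_info=dict(filename='000000000009.jpg'))
#   results = loader(results)
#   # results['img'] is a float32 BGR array; 'filename', 'img_shape',
#   # 'ori_shape' and 'img_fields' are filled in as documented above
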
@PIPELINES.register_module()
class LoadMultiChannelImageFromFiles:
    """Load multi-channel images from a list of separate channel files.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename", which is expected to be a list of filenames).
    Added or updated keys are "filename", "img", "img_shape",
    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'unchanged'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='unchanged',
                 file_client_args=dict(backend='disk')):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def __call__(self, results):
        """Call functions to load multiple images and get images meta
        information.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded images and meta information.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results['img_prefix'] is not None:
            filename = [
                osp.join(results['img_prefix'], fname)
                for fname in results['img_info']['filename']
            ]
        else:
            filename = results['img_info']['filename']

        img = []
        for name in filename:
            img_bytes = self.file_client.get(name)
            img.append(mmcv.imfrombytes(img_bytes, flag=self.color_type))
        img = np.stack(img, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Set initial values for default meta_keys
        results['pad_shape'] = img.shape
        results['scale_factor'] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f'file_client_args={self.file_client_args})')
        return repr_str

@PIPELINES.register_module()
class LoadAnnotations:
    """Load multiple types of annotations.

    Args:
        with_bbox (bool): Whether to parse and load the bbox annotation.
            Default: True.
        with_label (bool): Whether to parse and load the label annotation.
            Default: True.
        with_mask (bool): Whether to parse and load the mask annotation.
            Default: False.
        with_seg (bool): Whether to parse and load the semantic segmentation
            annotation. Default: False.
        poly2mask (bool): Whether to convert the instance masks from polygons
            to bitmaps. Default: True.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 with_bbox=True,
                 with_label=True,
                 with_mask=False,
                 with_seg=False,
                 poly2mask=True,
                 file_client_args=dict(backend='disk')):
        self.with_bbox = with_bbox
        self.with_label = with_label
        self.with_mask = with_mask
        self.with_seg = with_seg
        self.poly2mask = poly2mask
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def _load_bboxes(self, results):
        """Private function to load bounding box annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box annotations.
        """
        ann_info = results['ann_info']
        results['gt_bboxes'] = ann_info['bboxes'].copy()

        gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
        if gt_bboxes_ignore is not None:
            results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
            results['bbox_fields'].append('gt_bboxes_ignore')
        results['bbox_fields'].append('gt_bboxes')
        return results

    def _load_labels(self, results):
        """Private function to load label annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded label annotations.
        """
        results['gt_labels'] = results['ann_info']['labels'].copy()
        return results

    def _poly2mask(self, mask_ann, img_h, img_w):
        """Private function to convert masks represented with polygon to
        bitmaps.

        Args:
            mask_ann (list | dict): Polygon mask annotation input.
            img_h (int): The height of output mask.
            img_w (int): The width of output mask.

        Returns:
            numpy.ndarray: The decoded bitmap mask of shape (img_h, img_w).
        """
        import xtcocotools.mask as maskUtils

        if isinstance(mask_ann, list):
            # polygon -- a single object might consist of multiple parts
            # we merge all parts into one mask rle code
            rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
            rle = maskUtils.merge(rles)
        elif isinstance(mask_ann['counts'], list):
            # uncompressed RLE
            rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
        else:
            # rle
            rle = mask_ann
        mask = maskUtils.decode(rle)
        return mask

    def process_polygons(self, polygons):
        """Convert polygons to list of ndarray and filter invalid polygons.

        Args:
            polygons (list[list]): Polygons of one instance.

        Returns:
            list[numpy.ndarray]: Processed polygons.
        """
        polygons = [np.array(p) for p in polygons]
        valid_polygons = []
        for polygon in polygons:
            if len(polygon) % 2 == 0 and len(polygon) >= 6:
                valid_polygons.append(polygon)
        return valid_polygons

    def _load_masks(self, results):
        """Private function to load mask annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded mask annotations.
                If ``self.poly2mask`` is set ``True``, ``gt_masks`` will
                contain :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks`
                is used.
        """
        from mmdet.core import BitmapMasks, PolygonMasks

        h, w = results['img_info']['height'], results['img_info']['width']
        gt_masks = results['ann_info']['masks']
        if self.poly2mask:
            gt_masks = BitmapMasks(
                [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
        else:
            gt_masks = PolygonMasks(
                [self.process_polygons(polygons) for polygons in gt_masks], h,
                w)
        results['gt_masks'] = gt_masks
        results['mask_fields'].append('gt_masks')
        return results

    def _load_semantic_seg(self, results):
        """Private function to load semantic segmentation annotations.

        Args:
            results (dict): Result dict from :obj:`dataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        filename = osp.join(results['seg_prefix'],
                            results['ann_info']['seg_map'])
        img_bytes = self.file_client.get(filename)
        results['gt_semantic_seg'] = mmcv.imfrombytes(
            img_bytes, flag='unchanged').squeeze()
        results['seg_fields'].append('gt_semantic_seg')
        return results

    def __call__(self, results):
        """Call function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
                semantic segmentation annotations.
        """
        if self.with_bbox:
            results = self._load_bboxes(results)
            if results is None:
                return None
        if self.with_label:
            results = self._load_labels(results)
        if self.with_mask:
            results = self._load_masks(results)
        if self.with_seg:
            results = self._load_semantic_seg(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(with_bbox={self.with_bbox}, '
        repr_str += f'with_label={self.with_label}, '
        repr_str += f'with_mask={self.with_mask}, '
        repr_str += f'with_seg={self.with_seg}, '
        repr_str += f'poly2mask={self.poly2mask}, '
        repr_str += f'file_client_args={self.file_client_args})'
        return repr_str

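# A hedged LoadAnnotations sketch; `ann_info` mirrors what a COCO-style
# dataset would provide (values synthetic):
#
#   results = dict(
#       ann_info=dict(
#           bboxes=np.array([[10., 10., 60., 60.]], dtype=np.float32),
#           labels=np.array([2], dtype=np.int64)),
#       bbox_fields=[], mask_fields=[], seg_fields=[])
#   results = LoadAnnotations(with_bbox=True, with_label=True)(results)
#   # adds 'gt_bboxes'/'gt_labels' and registers 'gt_bboxes' in bbox_fields
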
@PIPELINES.register_module()
class MMMultiScaleFlipAug:
    """Test-time augmentation with multiple scales and flipping.

    An example configuration is as follows:

    .. code-block::

        img_scale=[(1333, 400), (1333, 800)],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ]

    After MMMultiScaleFlipAug with the above configuration, the results are
    wrapped into lists of the same length as follows:

    .. code-block::

        dict(
            img=[...],
            img_shape=[...],
            scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
            flip=[False, True, False, True]
            ...
        )

    Args:
        transforms (list[dict]): Transforms to apply in each augmentation.
        img_scale (tuple | list[tuple] | None): Image scales for resizing.
        scale_factor (float | list[float] | None): Scale factors for resizing.
        flip (bool): Whether to apply flip augmentation. Default: False.
        flip_direction (str | list[str]): Flip augmentation directions,
            options are "horizontal", "vertical" and "diagonal". If
            flip_direction is a list, multiple flip augmentations will be
            applied. It has no effect when flip == False. Default:
            "horizontal".
    """

    def __init__(self,
                 transforms,
                 img_scale=None,
                 scale_factor=None,
                 flip=False,
                 flip_direction='horizontal'):
        self.transforms = Compose(transforms)
        assert (img_scale is None) ^ (scale_factor is None), (
            'Exactly one of img_scale and scale_factor must be set')
        if img_scale is not None:
            self.img_scale = img_scale if isinstance(img_scale,
                                                     list) else [img_scale]
            self.scale_key = 'scale'
            assert mmcv.is_list_of(self.img_scale, tuple)
        else:
            self.img_scale = scale_factor if isinstance(
                scale_factor, list) else [scale_factor]
            self.scale_key = 'scale_factor'

        self.flip = flip
        self.flip_direction = flip_direction if isinstance(
            flip_direction, list) else [flip_direction]
        assert mmcv.is_list_of(self.flip_direction, str)
        if not self.flip and self.flip_direction != ['horizontal']:
            logging.warning(
                'flip_direction has no effect when flip is set to False')
        if (self.flip and not any(
                t['type'] in ('RandomFlip', 'MMRandomFlip')
                for t in transforms)):
            logging.warning(
                'flip has no effect when RandomFlip is not in transforms')

    def __call__(self, results):
        """Call function to apply test time augment transforms on results.

        Args:
            results (dict): Result dict contains the data to transform.

        Returns:
            dict[str: list]: The augmented data, where each value is wrapped
                into a list.
        """
        aug_data = []
        flip_args = [(False, None)]
        if self.flip:
            flip_args += [(True, direction)
                          for direction in self.flip_direction]
        for scale in self.img_scale:
            for flip, direction in flip_args:
                _results = results.copy()
                _results[self.scale_key] = scale
                _results['flip'] = flip
                _results['flip_direction'] = direction
                data = self.transforms(_results)
                aug_data.append(data)
        # list of dict to dict of list
        aug_data_dict = {key: [] for key in aug_data[0]}
        for data in aug_data:
            for key, val in data.items():
                aug_data_dict[key].append(val)
        return aug_data_dict

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(transforms={self.transforms}, '
        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
        repr_str += f'flip_direction={self.flip_direction})'
        return repr_str

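# A hedged config-style sketch of test-time augmentation with
# MMMultiScaleFlipAug; the inner transform names follow the registry entries
# defined in this file, and the normalization values are illustrative:
#
#   tta = MMMultiScaleFlipAug(
#       img_scale=[(1333, 400), (1333, 800)],
#       flip=True,
#       transforms=[
#           dict(type='MMResize', keep_ratio=True),
#           dict(type='MMRandomFlip'),
#           dict(type='MMNormalize',
#                mean=[123.675, 116.28, 103.53],
#                std=[58.395, 57.12, 57.375]),
#           dict(type='MMPad', size_divisor=32),
#       ])
#   # tta(results) yields 2 scales x (no-flip + horizontal flip) = 4
#   # augmented copies, collated into a dict of lists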