diff --git a/mmyolo/datasets/transforms/__init__.py b/mmyolo/datasets/transforms/__init__.py
new file mode 100644
index 00000000..b6d6ebaf
--- /dev/null
+++ b/mmyolo/datasets/transforms/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .transforms import (LetterResize, LoadAnnotations, YOLOv5HSVRandomAug,
+                         YOLOv5KeepRatioResize, YOLOv5RandomAffine)
+
+__all__ = [
+    'YOLOv5KeepRatioResize', 'LetterResize', 'YOLOv5HSVRandomAug',
+    'LoadAnnotations', 'YOLOv5RandomAffine'
+]
diff --git a/mmyolo/datasets/transforms/transforms.py b/mmyolo/datasets/transforms/transforms.py
new file mode 100644
index 00000000..4d9d01b9
--- /dev/null
+++ b/mmyolo/datasets/transforms/transforms.py
@@ -0,0 +1,597 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Tuple, Union
+
+import cv2
+import mmcv
+import numpy as np
+import torch
+from mmcv.transforms import BaseTransform
+from mmcv.transforms.utils import cache_randomness
+from numpy import random
+
+from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations
+from mmdet.datasets.transforms import Resize as MMDET_Resize
+from mmdet.structures.bbox import autocast_box_type, get_box_type
+from mmyolo.registry import TRANSFORMS
+
+
+@TRANSFORMS.register_module()
+class YOLOv5KeepRatioResize(MMDET_Resize):
+    """Resize images & bboxes (if they exist).
+
+    This transform resizes the input image according to ``scale``.
+    Bboxes (if they exist) are then resized with the same scale factor.
+
+    Required Keys:
+
+    - img (np.uint8)
+    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
+
+    Modified Keys:
+
+    - img (np.uint8)
+    - img_shape (tuple)
+    - gt_bboxes (optional)
+    - scale (float)
+
+    Added Keys:
+
+    - scale_factor (np.float32)
+
+    Args:
+        scale (Union[int, Tuple[int, int]]): Image scale for resizing.
+    """
+
+    def __init__(self,
+                 scale: Union[int, Tuple[int, int]],
+                 keep_ratio: bool = True,
+                 **kwargs) -> None:
+        assert keep_ratio is True
+        super().__init__(scale=scale, keep_ratio=True, **kwargs)
+
+    @staticmethod
+    def _get_rescale_ratio(old_size: Tuple[int, int],
+                           scale: Union[float, Tuple[int]]) -> float:
+        """Calculate the rescale ratio.
+
+        Args:
+            old_size (tuple[int]): The old size (w, h) of image.
+            scale (float | tuple[int]): The scaling factor or maximum size.
+                If it is a float number, the image will be rescaled by this
+                factor; if it is a tuple of 2 integers, the image will be
+                rescaled as large as possible within the scale.
+
+        Returns:
+            float: The resize ratio.
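+
+        Examples:
+            A doctest-style sketch with illustrative values: a tuple scale
+            is interpreted as ``(max_long_edge, max_short_edge)``.
+
+            >>> YOLOv5KeepRatioResize._get_rescale_ratio((640, 480),
+            ...                                          (1333, 800))
+            1.6666666666666667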
+ """ + w, h = old_size + if isinstance(scale, (float, int)): + if scale <= 0: + raise ValueError(f'Invalid scale {scale}, must be positive.') + scale_factor = scale + elif isinstance(scale, tuple): + max_long_edge = max(scale) + max_short_edge = min(scale) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) + else: + raise TypeError('Scale must be a number or tuple of int, ' + f'but got {type(scale)}') + + return scale_factor + + def _resize_img(self, results: dict) -> None: + """Resize images with ``results['scale']``.""" + assert self.keep_ratio is True + + if results.get('img', None) is not None: + image = results['img'] + original_h, original_w = image.shape[:2] + ratio = self._get_rescale_ratio((original_h, original_w), + self.scale) + + if ratio != 1: + # resize image according to the ratio + image = mmcv.imrescale( + img=image, + scale=ratio, + interpolation='area' if ratio < 1 else 'bilinear', + backend=self.backend) + + resized_h, resized_w = image.shape[:2] + scale_ratio = resized_h / original_h + + scale_factor = np.array([scale_ratio, scale_ratio], + dtype=np.float32) + + results['img'] = image + results['img_shape'] = image.shape[:2] + results['scale_factor'] = scale_factor + + +@TRANSFORMS.register_module() +class LetterResize(MMDET_Resize): + """Resize and pad image while meeting stride-multiple constraints. + + Required Keys: + + - img (np.uint8) + - batch_shape (np.int64) (optional) + + Modified Keys: + + - img (np.uint8) + - img_shape (tuple) + - gt_bboxes (optional) + + Added Keys: + - pad_param (np.float32) + + Args: + scale (Union[int, Tuple[int, int]]): Images scales for resizing. + pad_val (dict): Padding value. Defaults to dict(img=0, seg=255). + use_mini_pad (bool): Whether using minimum rectangle padding. + Defaults to True + stretch_only (bool): Whether stretch to the specified size directly. + Defaults to False + allow_scale_up (bool): Allow scale up when ratio > 1. 
+    """
+
+    def __init__(self,
+                 scale: Union[int, Tuple[int, int]],
+                 pad_val: dict = dict(img=0, mask=0, seg=255),
+                 use_mini_pad: bool = False,
+                 stretch_only: bool = False,
+                 allow_scale_up: bool = True,
+                 **kwargs) -> None:
+        super().__init__(scale=scale, keep_ratio=True, **kwargs)
+
+        # normalize pad_val before storing it, so that ``self.pad_val``
+        # is always a dict
+        if isinstance(pad_val, (int, float)):
+            pad_val = dict(img=pad_val, seg=255)
+        assert isinstance(
+            pad_val, dict), f'pad_val must be dict, but got {type(pad_val)}'
+        self.pad_val = pad_val
+
+        self.use_mini_pad = use_mini_pad
+        self.stretch_only = stretch_only
+        self.allow_scale_up = allow_scale_up
+
+    def _resize_img(self, results: dict) -> None:
+        """Resize images with ``results['scale']``."""
+        image = results.get('img', None)
+        if image is None:
+            return
+
+        # Use batch_shape if a batch_shape policy is configured
+        if 'batch_shape' in results:
+            self.scale = tuple(results['batch_shape'])
+
+        image_shape = image.shape[:2]  # height, width
+
+        # Scale ratio (new / old)
+        ratio = min(self.scale[0] / image_shape[0],
+                    self.scale[1] / image_shape[1])
+
+        # only scale down, do not scale up (for better test mAP)
+        if not self.allow_scale_up:
+            ratio = min(ratio, 1.0)
+
+        ratio = [ratio, ratio]  # float -> (float, float) for (height, width)
+
+        # compute the best size of the image
+        no_pad_shape = (int(round(image_shape[0] * ratio[0])),
+                        int(round(image_shape[1] * ratio[1])))
+
+        # padding height & width
+        padding_h, padding_w = [
+            self.scale[0] - no_pad_shape[0], self.scale[1] - no_pad_shape[1]
+        ]
+        if self.use_mini_pad:
+            # minimum rectangle padding
+            padding_w, padding_h = np.mod(padding_w, 32), np.mod(padding_h, 32)
+
+        elif self.stretch_only:
+            # stretch to the specified size directly
+            padding_h, padding_w = 0.0, 0.0
+            no_pad_shape = (self.scale[0], self.scale[1])
+            ratio = [
+                self.scale[0] / image_shape[0], self.scale[1] / image_shape[1]
+            ]  # height, width ratios
+
+        # divide padding into 2 sides
+        padding_h /= 2
+        padding_w /= 2
+
+        if image_shape != no_pad_shape:
+            # resize to the unpadded target size;
+            # ``mmcv.imresize`` expects (w, h)
+            image = mmcv.imresize(
+                image, (no_pad_shape[1], no_pad_shape[0]),
+                interpolation=self.interpolation,
+                backend=self.backend)
+
+        scale_factor = np.array([ratio[0], ratio[1]], dtype=np.float32)
+
+        if 'scale_factor' in results:
+            results['scale_factor'] = results['scale_factor'] * scale_factor
+        else:
+            results['scale_factor'] = scale_factor
+
+        # padding
+        top_padding, bottom_padding = int(round(padding_h - 0.1)), int(
+            round(padding_h + 0.1))
+        left_padding, right_padding = int(round(padding_w - 0.1)), int(
+            round(padding_w + 0.1))
+
+        padding_list = [
+            top_padding, bottom_padding, left_padding, right_padding
+        ]
+        if top_padding != 0 or bottom_padding != 0 or \
+                left_padding != 0 or right_padding != 0:
+
+            pad_val = self.pad_val.get('img', 0)
+            if isinstance(pad_val, int) and image.ndim == 3:
+                # broadcast a scalar pad value to all channels
+                pad_val = tuple(pad_val for _ in range(image.shape[2]))
+
+            image = mmcv.impad(
+                img=image,
+                padding=(padding_list[2], padding_list[0], padding_list[3],
+                         padding_list[1]),
+                pad_val=pad_val,
+                padding_mode='constant')
+
+        results['img'] = image
+        results['img_shape'] = image.shape
+        results['pad_param'] = np.array(padding_list, dtype=np.float32)
+
+    def _resize_masks(self, results: dict) -> None:
+        """Resize masks with ``results['scale']``."""
+        if results.get('gt_masks', None) is None:
+            return
+
+        # resize the gt_masks; ``resize`` expects an integer (h, w) shape
+        gt_mask_height = results['gt_masks'].height * \
+            results['scale_factor'][0]
+        gt_mask_width = results['gt_masks'].width * \
+            results['scale_factor'][1]
+        gt_masks = results['gt_masks'].resize(
+            (int(round(gt_mask_height)), int(round(gt_mask_width))))
+
+        # padding the gt_masks
+        if len(gt_masks) == 0:
+            padded_masks = np.empty((0, *results['img_shape'][:2]),
+                                    dtype=np.uint8)
+        else:
+            # TODO: The function is still incorrect, because the mask
+            #  may not be paddable.
+            padded_masks = np.stack([
+                mmcv.impad(
+                    mask,
+                    padding=(int(results['pad_param'][2]),
+                             int(results['pad_param'][0]),
+                             int(results['pad_param'][3]),
+                             int(results['pad_param'][1])),
+                    pad_val=self.pad_val.get('mask', 0)) for mask in gt_masks
+            ])
+        results['gt_masks'] = type(results['gt_masks'])(
+            padded_masks, *results['img_shape'][:2])
+
+    def _resize_bboxes(self, results: dict) -> None:
+        """Resize bounding boxes with ``results['scale_factor']``."""
+        if results.get('gt_bboxes', None) is None:
+            return
+        results['gt_bboxes'].rescale_(results['scale_factor'])
+
+        if len(results['pad_param']) != 4:
+            return
+        # pad_param is (top, bottom, left, right), so translate by
+        # (left, top)
+        results['gt_bboxes'].translate_(
+            (results['pad_param'][2], results['pad_param'][0]))
+
+        if self.clip_object_border:
+            results['gt_bboxes'].clip_(results['img_shape'])
+
+
+# TODO: Check if it can be merged with mmdet.YOLOXHSVRandomAug
+@TRANSFORMS.register_module()
+class YOLOv5HSVRandomAug(BaseTransform):
+    """Apply HSV augmentation to image sequentially.
+
+    Required Keys:
+
+    - img
+
+    Modified Keys:
+
+    - img
+
+    Args:
+        hue_delta (int or float): Delta of hue. Defaults to 0.015.
+        saturation_delta (int or float): Delta of saturation.
+            Defaults to 0.7.
+        value_delta (int or float): Delta of value. Defaults to 0.4.
+    """
+
+    def __init__(self,
+                 hue_delta: Union[int, float] = 0.015,
+                 saturation_delta: Union[int, float] = 0.7,
+                 value_delta: Union[int, float] = 0.4):
+        self.hue_delta = hue_delta
+        self.saturation_delta = saturation_delta
+        self.value_delta = value_delta
+
+    def transform(self, results: dict) -> dict:
+        """The HSV augmentation transform function."""
+        hsv_gains = \
+            random.uniform(-1, 1, 3) * \
+            [self.hue_delta, self.saturation_delta, self.value_delta] + 1
+        hue, sat, val = cv2.split(
+            cv2.cvtColor(results['img'], cv2.COLOR_BGR2HSV))
+
+        table_list = np.arange(0, 256, dtype=hsv_gains.dtype)
+        lut_hue = ((table_list * hsv_gains[0]) % 180).astype(np.uint8)
+        lut_sat = np.clip(table_list * hsv_gains[1], 0, 255).astype(np.uint8)
+        lut_val = np.clip(table_list * hsv_gains[2], 0, 255).astype(np.uint8)
+
+        im_hsv = cv2.merge(
+            (cv2.LUT(hue, lut_hue), cv2.LUT(sat,
+                                            lut_sat), cv2.LUT(val, lut_val)))
+        results['img'] = cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR)
+        return results
+
+
+# TODO: can be accelerated
+@TRANSFORMS.register_module()
+class LoadAnnotations(MMDET_LoadAnnotations):
+    """The YOLO series does not need to consider ignored bboxes for the
+    time being, so they are excluded in advance to speed up the pipeline."""
+
+    def _load_bboxes(self, results: dict) -> None:
+        """Private function to load bounding box annotations.
+
+        Note: BBoxes with ``ignore_flag`` of 1 are not loaded.
+
+        Args:
+            results (dict): Result dict from :obj:``mmengine.BaseDataset``.
+
+        Returns:
+            dict: The dict contains loaded bounding box annotations.
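+
+        Examples:
+            A minimal sketch of the expected ``results['instances']``
+            format (values are illustrative); the second instance is
+            dropped because its ``ignore_flag`` is 1.
+
+            >>> results = dict(instances=[
+            ...     dict(bbox=[0., 0., 10., 10.], bbox_label=0,
+            ...          ignore_flag=0),
+            ...     dict(bbox=[5., 5., 20., 20.], bbox_label=1,
+            ...          ignore_flag=1)])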
+ """ + gt_bboxes = [] + gt_ignore_flags = [] + for instance in results['instances']: + if instance['ignore_flag'] == 0: + gt_bboxes.append(instance['bbox']) + gt_ignore_flags.append(instance['ignore_flag']) + results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool) + + if self.box_type is None: + results['gt_bboxes'] = np.array( + gt_bboxes, dtype=np.float32).reshape((-1, 4)) + else: + _, box_type_cls = get_box_type(self.box_type) + results['gt_bboxes'] = box_type_cls(gt_bboxes, dtype=torch.float32) + + def _load_labels(self, results: dict) -> None: + """Private function to load label annotations. + + Note: BBoxes with ignore_flag of 1 is not considered. + + Args: + results (dict): Result dict from :obj:``mmengine.BaseDataset``. + + Returns: + dict: The dict contains loaded label annotations. + """ + gt_bboxes_labels = [] + for instance in results['instances']: + if instance['ignore_flag'] == 0: + gt_bboxes_labels.append(instance['bbox_label']) + results['gt_bboxes_labels'] = np.array( + gt_bboxes_labels, dtype=np.int64) + + +@TRANSFORMS.register_module() +class YOLOv5RandomAffine(BaseTransform): + """Random affine transform data augmentation in YOLOv5. It is different + from the implementation in YOLOX. + + This operation randomly generates affine transform matrix which including + rotation, translation, shear and scaling transforms. + + Required Keys: + + - img + - gt_bboxes (BaseBoxes[torch.float32]) (optional) + - gt_bboxes_labels (np.int64) (optional) + - gt_ignore_flags (np.bool) (optional) + + Modified Keys: + + - img + - img_shape + - gt_bboxes (optional) + - gt_bboxes_labels (optional) + - gt_ignore_flags (optional) + + Args: + max_rotate_degree (float): Maximum degrees of rotation transform. + Defaults to 10. + max_translate_ratio (float): Maximum ratio of translation. + Defaults to 0.1. + scaling_ratio_range (tuple[float]): Min and max ratio of + scaling transform. Defaults to (0.5, 1.5). + max_shear_degree (float): Maximum degrees of shear + transform. Defaults to 2. + border (tuple[int]): Distance from height and width sides of input + image to adjust output shape. Only used in mosaic dataset. + Defaults to (0, 0). + border_val (tuple[int]): Border padding values of 3 channels. + Defaults to (114, 114, 114). + bbox_clip_border (bool, optional): Whether to clip the objects outside + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. 
+ """ + + def __init__(self, + max_rotate_degree: float = 10.0, + max_translate_ratio: float = 0.1, + scaling_ratio_range: Tuple[float, float] = (0.5, 1.5), + max_shear_degree: float = 2.0, + border: Tuple[int, int] = (0, 0), + border_val: Tuple[int, int, int] = (114, 114, 114), + bbox_clip_border: bool = True, + min_bbox_size=2, + min_area_ratio=0.1, + max_aspect_ratio=20) -> None: + assert 0 <= max_translate_ratio <= 1 + assert scaling_ratio_range[0] <= scaling_ratio_range[1] + assert scaling_ratio_range[0] > 0 + self.max_rotate_degree = max_rotate_degree + self.max_translate_ratio = max_translate_ratio + self.scaling_ratio_range = scaling_ratio_range + self.max_shear_degree = max_shear_degree + self.border = border + self.border_val = border_val + self.bbox_clip_border = bbox_clip_border + self.min_bbox_size = min_bbox_size + self.min_bbox_size = min_bbox_size + self.min_area_ratio = min_area_ratio + self.max_aspect_ratio = max_aspect_ratio + + @cache_randomness + def _get_random_homography_matrix(self, height, width): + # Rotation + rotation_degree = random.uniform(-self.max_rotate_degree, + self.max_rotate_degree) + rotation_matrix = self._get_rotation_matrix(rotation_degree) + + # Scaling + scaling_ratio = random.uniform(self.scaling_ratio_range[0], + self.scaling_ratio_range[1]) + scaling_matrix = self._get_scaling_matrix(scaling_ratio) + + # Shear + x_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + y_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + shear_matrix = self._get_shear_matrix(x_degree, y_degree) + + # Translation + trans_x = random.uniform(0.5 - self.max_translate_ratio, + 0.5 + self.max_translate_ratio) * width + trans_y = random.uniform(0.5 - self.max_translate_ratio, + 0.5 + self.max_translate_ratio) * height + translate_matrix = self._get_translation_matrix(trans_x, trans_y) + warp_matrix = ( + translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix) + return warp_matrix, scaling_ratio + + @autocast_box_type() + def transform(self, results: dict) -> dict: + img = results['img'] + height = img.shape[0] + self.border[0] * 2 + width = img.shape[1] + self.border[1] * 2 + + # Note: Different from YOLOX + center_matrix = np.eye(3, dtype=np.float32) + center_matrix[0, 2] = -img.shape[1] / 2 + center_matrix[1, 2] = -img.shape[0] / 2 + + warp_matrix, scaling_ratio = self._get_random_homography_matrix( + height, width) + warp_matrix = warp_matrix @ center_matrix + + img = cv2.warpPerspective( + img, + warp_matrix, + dsize=(width, height), + borderValue=self.border_val) + results['img'] = img + results['img_shape'] = img.shape + + bboxes = results['gt_bboxes'] + num_bboxes = len(bboxes) + if num_bboxes: + orig_bboxes = bboxes.clone() + + bboxes.project_(warp_matrix) + if self.bbox_clip_border: + bboxes.clip_([height, width]) + + # filter bboxes + orig_bboxes.rescale_([scaling_ratio, scaling_ratio]) + + # Be careful: valid_index must convert to numpy, + # otherwise it will raise out of bounds when len(valid_index)=1 + valid_index = self.filter_gt_bboxes(orig_bboxes, bboxes).numpy() + results['gt_bboxes'] = bboxes[valid_index] + results['gt_bboxes_labels'] = results['gt_bboxes_labels'][ + valid_index] + results['gt_ignore_flags'] = results['gt_ignore_flags'][ + valid_index] + + if 'gt_masks' in results: + raise NotImplementedError('RandomAffine only supports bbox.') + return results + + def filter_gt_bboxes(self, origin_bboxes, wrapped_bboxes): + origin_w = origin_bboxes.widths + origin_h = origin_bboxes.heights + 
+
+    def filter_gt_bboxes(self, origin_bboxes, wrapped_bboxes):
+        """Filter gt bboxes by size, area ratio and aspect ratio.
+
+        Args:
+            origin_bboxes (BaseBoxes): The boxes before warping, rescaled
+                by the sampled scaling ratio.
+            wrapped_bboxes (BaseBoxes): The boxes after warping.
+
+        Returns:
+            A boolean mask of the boxes to keep.
+        """
+        origin_w = origin_bboxes.widths
+        origin_h = origin_bboxes.heights
+        wrapped_w = wrapped_bboxes.widths
+        wrapped_h = wrapped_bboxes.heights
+        aspect_ratio = np.maximum(wrapped_w / (wrapped_h + 1e-16),
+                                  wrapped_h / (wrapped_w + 1e-16))
+
+        wh_valid_idx = (wrapped_w > self.min_bbox_size) & \
+                       (wrapped_h > self.min_bbox_size)
+        area_valid_idx = wrapped_w * wrapped_h / (origin_w * origin_h +
+                                                  1e-16) > self.min_area_ratio
+        aspect_ratio_valid_idx = aspect_ratio < self.max_aspect_ratio
+        return wh_valid_idx & area_valid_idx & aspect_ratio_valid_idx
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(max_rotate_degree={self.max_rotate_degree}, '
+        repr_str += f'max_translate_ratio={self.max_translate_ratio}, '
+        repr_str += f'scaling_ratio_range={self.scaling_ratio_range}, '
+        repr_str += f'max_shear_degree={self.max_shear_degree}, '
+        repr_str += f'border={self.border}, '
+        repr_str += f'border_val={self.border_val}, '
+        repr_str += f'bbox_clip_border={self.bbox_clip_border}, '
+        repr_str += f'min_bbox_size={self.min_bbox_size}, '
+        repr_str += f'min_area_ratio={self.min_area_ratio}, '
+        repr_str += f'max_aspect_ratio={self.max_aspect_ratio})'
+        return repr_str
+
+    @staticmethod
+    def _get_rotation_matrix(rotate_degrees: float) -> np.ndarray:
+        """Get a 3x3 rotation matrix from degrees."""
+        radian = math.radians(rotate_degrees)
+        rotation_matrix = np.array(
+            [[np.cos(radian), -np.sin(radian), 0.],
+             [np.sin(radian), np.cos(radian), 0.], [0., 0., 1.]],
+            dtype=np.float32)
+        return rotation_matrix
+
+    @staticmethod
+    def _get_scaling_matrix(scale_ratio: float) -> np.ndarray:
+        """Get a 3x3 isotropic scaling matrix."""
+        scaling_matrix = np.array(
+            [[scale_ratio, 0., 0.], [0., scale_ratio, 0.], [0., 0., 1.]],
+            dtype=np.float32)
+        return scaling_matrix
+
+    @staticmethod
+    def _get_shear_matrix(x_shear_degrees: float,
+                          y_shear_degrees: float) -> np.ndarray:
+        """Get a 3x3 shear matrix from x/y shear degrees."""
+        x_radian = math.radians(x_shear_degrees)
+        y_radian = math.radians(y_shear_degrees)
+        shear_matrix = np.array([[1, np.tan(x_radian), 0.],
+                                 [np.tan(y_radian), 1, 0.], [0., 0., 1.]],
+                                dtype=np.float32)
+        return shear_matrix
+
+    @staticmethod
+    def _get_translation_matrix(x: float, y: float) -> np.ndarray:
+        """Get a 3x3 translation matrix for offset (x, y)."""
+        translation_matrix = np.array([[1, 0., x], [0., 1, y], [0., 0., 1.]],
+                                      dtype=np.float32)
+        return translation_matrix