Merge pull request #2339 from xiexinch/resize-shortest-edge

[Feature] Add ResizeShortestEdge transform
This commit is contained in:
Miao Zheng 2022-12-01 16:06:23 +08:00 committed by GitHub
commit 0cdab7297e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 124 additions and 6 deletions

View File

@ -7,7 +7,7 @@ from packaging.version import parse
from .version import __version__, version_info
MMCV_MIN = '2.0.0rc1'
MMCV_MIN = '2.0.0rc3'
MMCV_MAX = '2.1.0'
MMENGINE_MIN = '0.1.0'
MMENGINE_MAX = '1.0.0'

View File

@ -22,7 +22,8 @@ from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeToMultiple, RGB2Gray, SegRescale)
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
from .voc import PascalVOCDataset
__all__ = [
@ -36,5 +37,5 @@ __all__ = [
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset'
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
]

View File

@ -5,13 +5,15 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadImageFromNDArray)
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
RGB2Gray, SegRescale)
RandomMosaic, RandomRotate, Rerange,
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
__all__ = [
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge'
]

View File

@ -1226,3 +1226,87 @@ class GenerateEdge(BaseTransform):
repr_str += f'edge_width={self.edge_width}, '
repr_str += f'ignore_index={self.ignore_index})'
return repr_str
@TRANSFORMS.register_module()
class ResizeShortestEdge(BaseTransform):
    """Resize the image and mask while keeping the aspect ratio unchanged.

    Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
    Copyright (c) Facebook, Inc. and its affiliates.
    Licensed under the Apache-2.0 License

    This transform attempts to scale the shorter edge to the given
    `scale`, as long as the longer edge does not exceed `max_size`.
    If `max_size` is reached, then downscale so that the longer
    edge does not exceed `max_size`.

    Required Keys:

    - img
    - gt_seg_map (optional)

    Modified Keys:

    - img
    - img_shape
    - gt_seg_map (optional)

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio

    Args:
        scale (Union[int, Tuple[int, int]]): The target short edge length.
            If it's a tuple (or list), the min value is selected as the
            short edge length.
        max_size (int): The maximum allowed longest edge length.
    """

    def __init__(self, scale: Union[int, Tuple[int, int]],
                 max_size: int) -> None:
        super().__init__()
        self.scale = scale
        self.max_size = max_size

        # Build a placeholder Resize transform. Its ``scale`` is overwritten
        # for every sample in ``transform`` once the target shape is known.
        self.resize = TRANSFORMS.build({
            'type': 'Resize',
            'scale': 0,
            'keep_ratio': True
        })

    def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
        """Compute the target image shape with the given `short_edge_length`.

        Args:
            img (np.ndarray): The input image.
            short_edge_length (Union[int, Tuple[int, int]]): The target short
                edge length. If it's a tuple (or list), the min value is
                selected as the short edge length.

        Returns:
            Tuple[int, int]: The target shape as ``(width, height)``.

        Raises:
            TypeError: If `short_edge_length` is not an int, tuple or list.
        """
        h, w = img.shape[:2]
        if isinstance(short_edge_length, int):
            size = short_edge_length * 1.0
        elif isinstance(short_edge_length, (tuple, list)):
            # Configs loaded from JSON/YAML yield lists rather than tuples.
            size = min(short_edge_length) * 1.0
        else:
            # Fail loudly instead of hitting an unbound `size` below.
            raise TypeError(
                'short_edge_length must be an int or a tuple/list of ints, '
                f'but got {type(short_edge_length)}')

        # Scale so that the shorter edge becomes exactly `size`.
        scale = size / min(h, w)
        if h < w:
            new_h, new_w = size, scale * w
        else:
            new_h, new_w = scale * h, size

        # Shrink again if the longer edge would exceed `max_size`.
        if max(new_h, new_w) > self.max_size:
            scale = self.max_size * 1.0 / max(new_h, new_w)
            new_h *= scale
            new_w *= scale

        # Round half up to the nearest integer pixel count.
        new_h = int(new_h + 0.5)
        new_w = int(new_w + 0.5)
        return (new_w, new_h)

    def transform(self, results: Dict) -> Dict:
        """Resize `results` so the image's shorter edge matches `scale`.

        Args:
            results (Dict): Result dict; must contain the key ``img``.

        Returns:
            Dict: The output of the wrapped ``Resize`` transform.
        """
        self.resize.scale = self._get_output_shape(results['img'], self.scale)
        return self.resize(results)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(scale={self.scale}, '
        repr_str += f'max_size={self.max_size})'
        return repr_str

View File

@ -10,6 +10,9 @@ from PIL import Image
from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules
register_all_modules()
def test_resize():
@ -71,6 +74,34 @@ def test_resize():
resized_results = resize_module(results.copy())
assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1
# test RandomChoiceResize whose `resize_type` is `ResizeShortestEdge`
transform = dict(
type='RandomChoiceResize',
scales=[128, 256, 512],
resize_type='ResizeShortestEdge',
max_size=1333)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][0] in [128, 256, 512]
transform = dict(
type='RandomChoiceResize',
scales=[512],
resize_type='ResizeShortestEdge',
max_size=512)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][1] == 512
transform = dict(
type='RandomChoiceResize',
scales=[(128, 256), (256, 512), (512, 1024)],
resize_type='ResizeShortestEdge',
max_size=1333)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][0] in [128, 256, 512]
# test scale=None and scale_factor is tuple.
# img shape: (288, 512, 3)
transform = dict(