Merge pull request #2339 from xiexinch/resize-shortest-edge

[Feature] Add ResizeShortestEdge transform
This commit is contained in:
Miao Zheng 2022-12-01 16:06:23 +08:00 committed by GitHub
commit 0cdab7297e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 124 additions and 6 deletions

View File

@ -7,7 +7,7 @@ from packaging.version import parse
from .version import __version__, version_info from .version import __version__, version_info
MMCV_MIN = '2.0.0rc1' MMCV_MIN = '2.0.0rc3'
MMCV_MAX = '2.1.0' MMCV_MAX = '2.1.0'
MMENGINE_MIN = '0.1.0' MMENGINE_MIN = '0.1.0'
MMENGINE_MAX = '1.0.0' MMENGINE_MAX = '1.0.0'

View File

@ -22,7 +22,8 @@ from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
LoadBiomedicalImageFromFile, LoadImageFromNDArray, LoadBiomedicalImageFromFile, LoadImageFromNDArray,
PackSegInputs, PhotoMetricDistortion, RandomCrop, PackSegInputs, PhotoMetricDistortion, RandomCrop,
RandomCutOut, RandomMosaic, RandomRotate, Rerange, RandomCutOut, RandomMosaic, RandomRotate, Rerange,
ResizeToMultiple, RGB2Gray, SegRescale) ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
from .voc import PascalVOCDataset from .voc import PascalVOCDataset
__all__ = [ __all__ = [
@ -36,5 +37,5 @@ __all__ = [
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge', 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'DecathlonDataset', 'LIPDataset' 'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
] ]

View File

@ -5,13 +5,15 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
LoadImageFromNDArray) LoadImageFromNDArray)
from .transforms import (CLAHE, AdjustGamma, GenerateEdge, from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
PhotoMetricDistortion, RandomCrop, RandomCutOut, PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple, RandomMosaic, RandomRotate, Rerange,
RGB2Gray, SegRescale) ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
SegRescale)
__all__ = [ __all__ = [
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion', 'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple', 'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile', 'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge' 'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge'
] ]

View File

@ -1226,3 +1226,87 @@ class GenerateEdge(BaseTransform):
repr_str += f'edge_width={self.edge_width}, ' repr_str += f'edge_width={self.edge_width}, '
repr_str += f'ignore_index={self.ignore_index})' repr_str += f'ignore_index={self.ignore_index})'
return repr_str return repr_str
@TRANSFORMS.register_module()
class ResizeShortestEdge(BaseTransform):
"""Resize the image and mask while keeping the aspect ratio unchanged.
Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
Copyright (c) Facebook, Inc. and its affiliates.
Licensed under the Apache-2.0 License
This transform attempts to scale the shorter edge to the given
`scale`, as long as the longer edge does not exceed `max_size`.
If `max_size` is reached, then downscale so that the longer
edge does not exceed `max_size`.
Required Keys:
- img
- gt_seg_map (optional)
Modified Keys:
- img
- img_shape
- gt_seg_map (optional))
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (Union[int, Tuple[int, int]]): The target short edge length.
If it's tuple, will select the min value as the short edge length.
max_size (int): The maximum allowed longest edge length.
"""
def __init__(self, scale: Union[int, Tuple[int, int]],
max_size: int) -> None:
super().__init__()
self.scale = scale
self.max_size = max_size
# Create a empty Resize object
self.resize = TRANSFORMS.build({
'type': 'Resize',
'scale': 0,
'keep_ratio': True
})
def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
"""Compute the target image shape with the given `short_edge_length`.
Args:
img (np.ndarray): The input image.
short_edge_length (Union[int, Tuple[int, int]]): The target short
edge length. If it's tuple, will select the min value as the
short edge length.
"""
h, w = img.shape[:2]
if isinstance(short_edge_length, int):
size = short_edge_length * 1.0
elif isinstance(short_edge_length, tuple):
size = min(short_edge_length) * 1.0
scale = size / min(h, w)
if h < w:
new_h, new_w = size, scale * w
else:
new_h, new_w = scale * h, size
if max(new_h, new_w) > self.max_size:
scale = self.max_size * 1.0 / max(new_h, new_w)
new_h *= scale
new_w *= scale
new_h = int(new_h + 0.5)
new_w = int(new_w + 0.5)
return (new_w, new_h)
def transform(self, results: Dict) -> Dict:
self.resize.scale = self._get_output_shape(results['img'], self.scale)
return self.resize(results)

View File

@ -10,6 +10,9 @@ from PIL import Image
from mmseg.datasets.transforms import * # noqa from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
from mmseg.registry import TRANSFORMS from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules
register_all_modules()
def test_resize(): def test_resize():
@ -71,6 +74,34 @@ def test_resize():
resized_results = resize_module(results.copy()) resized_results = resize_module(results.copy())
assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1 assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1
# test RandomChoiceResize, which `resize_type` is `ResizeShortestEdge`
transform = dict(
type='RandomChoiceResize',
scales=[128, 256, 512],
resize_type='ResizeShortestEdge',
max_size=1333)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][0] in [128, 256, 512]
transform = dict(
type='RandomChoiceResize',
scales=[512],
resize_type='ResizeShortestEdge',
max_size=512)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][1] == 512
transform = dict(
type='RandomChoiceResize',
scales=[(128, 256), (256, 512), (512, 1024)],
resize_type='ResizeShortestEdge',
max_size=1333)
resize_module = TRANSFORMS.build(transform)
resized_results = resize_module(results.copy())
assert resized_results['img_shape'][0] in [128, 256, 512]
# test scale=None and scale_factor is tuple. # test scale=None and scale_factor is tuple.
# img shape: (288, 512, 3) # img shape: (288, 512, 3)
transform = dict( transform = dict(