mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Merge pull request #2339 from xiexinch/resize-shortest-edge
[Feature] Add ResizeShortestEdge transform
This commit is contained in:
commit
0cdab7297e
@ -7,7 +7,7 @@ from packaging.version import parse
|
||||
|
||||
from .version import __version__, version_info
|
||||
|
||||
MMCV_MIN = '2.0.0rc1'
|
||||
MMCV_MIN = '2.0.0rc3'
|
||||
MMCV_MAX = '2.1.0'
|
||||
MMENGINE_MIN = '0.1.0'
|
||||
MMENGINE_MAX = '1.0.0'
|
||||
|
@ -22,7 +22,8 @@ from .transforms import (CLAHE, AdjustGamma, GenerateEdge, LoadAnnotations,
|
||||
LoadBiomedicalImageFromFile, LoadImageFromNDArray,
|
||||
PackSegInputs, PhotoMetricDistortion, RandomCrop,
|
||||
RandomCutOut, RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeToMultiple, RGB2Gray, SegRescale)
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
from .voc import PascalVOCDataset
|
||||
|
||||
__all__ = [
|
||||
@ -36,5 +37,5 @@ __all__ = [
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'DecathlonDataset', 'LIPDataset'
|
||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge'
|
||||
]
|
||||
|
@ -5,13 +5,15 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadImageFromNDArray)
|
||||
from .transforms import (CLAHE, AdjustGamma, GenerateEdge,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomMosaic, RandomRotate, Rerange, ResizeToMultiple,
|
||||
RGB2Gray, SegRescale)
|
||||
RandomMosaic, RandomRotate, Rerange,
|
||||
ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
|
||||
SegRescale)
|
||||
|
||||
__all__ = [
|
||||
'LoadAnnotations', 'RandomCrop', 'SegRescale', 'PhotoMetricDistortion',
|
||||
'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray',
|
||||
'RandomCutOut', 'RandomMosaic', 'PackSegInputs', 'ResizeToMultiple',
|
||||
'LoadImageFromNDArray', 'LoadBiomedicalImageFromFile',
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge'
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge'
|
||||
]
|
||||
|
@ -1226,3 +1226,87 @@ class GenerateEdge(BaseTransform):
|
||||
repr_str += f'edge_width={self.edge_width}, '
|
||||
repr_str += f'ignore_index={self.ignore_index})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeShortestEdge(BaseTransform):
|
||||
"""Resize the image and mask while keeping the aspect ratio unchanged.
|
||||
|
||||
Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
|
||||
Copyright (c) Facebook, Inc. and its affiliates.
|
||||
Licensed under the Apache-2.0 License
|
||||
|
||||
This transform attempts to scale the shorter edge to the given
|
||||
`scale`, as long as the longer edge does not exceed `max_size`.
|
||||
If `max_size` is reached, then downscale so that the longer
|
||||
edge does not exceed `max_size`.
|
||||
|
||||
Required Keys:
|
||||
|
||||
- img
|
||||
- gt_seg_map (optional)
|
||||
|
||||
Modified Keys:
|
||||
|
||||
- img
|
||||
- img_shape
|
||||
- gt_seg_map (optional))
|
||||
|
||||
Added Keys:
|
||||
|
||||
- scale
|
||||
- scale_factor
|
||||
- keep_ratio
|
||||
|
||||
|
||||
Args:
|
||||
scale (Union[int, Tuple[int, int]]): The target short edge length.
|
||||
If it's tuple, will select the min value as the short edge length.
|
||||
max_size (int): The maximum allowed longest edge length.
|
||||
"""
|
||||
|
||||
def __init__(self, scale: Union[int, Tuple[int, int]],
|
||||
max_size: int) -> None:
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.max_size = max_size
|
||||
|
||||
# Create a empty Resize object
|
||||
self.resize = TRANSFORMS.build({
|
||||
'type': 'Resize',
|
||||
'scale': 0,
|
||||
'keep_ratio': True
|
||||
})
|
||||
|
||||
def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
|
||||
"""Compute the target image shape with the given `short_edge_length`.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): The input image.
|
||||
short_edge_length (Union[int, Tuple[int, int]]): The target short
|
||||
edge length. If it's tuple, will select the min value as the
|
||||
short edge length.
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
if isinstance(short_edge_length, int):
|
||||
size = short_edge_length * 1.0
|
||||
elif isinstance(short_edge_length, tuple):
|
||||
size = min(short_edge_length) * 1.0
|
||||
scale = size / min(h, w)
|
||||
if h < w:
|
||||
new_h, new_w = size, scale * w
|
||||
else:
|
||||
new_h, new_w = scale * h, size
|
||||
|
||||
if max(new_h, new_w) > self.max_size:
|
||||
scale = self.max_size * 1.0 / max(new_h, new_w)
|
||||
new_h *= scale
|
||||
new_w *= scale
|
||||
|
||||
new_h = int(new_h + 0.5)
|
||||
new_w = int(new_w + 0.5)
|
||||
return (new_w, new_h)
|
||||
|
||||
def transform(self, results: Dict) -> Dict:
|
||||
self.resize.scale = self._get_output_shape(results['img'], self.scale)
|
||||
return self.resize(results)
|
||||
|
@ -10,6 +10,9 @@ from PIL import Image
|
||||
from mmseg.datasets.transforms import * # noqa
|
||||
from mmseg.datasets.transforms import PhotoMetricDistortion, RandomCrop
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
|
||||
|
||||
def test_resize():
|
||||
@ -71,6 +74,34 @@ def test_resize():
|
||||
resized_results = resize_module(results.copy())
|
||||
assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1
|
||||
|
||||
# test RandomChoiceResize, which `resize_type` is `ResizeShortestEdge`
|
||||
transform = dict(
|
||||
type='RandomChoiceResize',
|
||||
scales=[128, 256, 512],
|
||||
resize_type='ResizeShortestEdge',
|
||||
max_size=1333)
|
||||
resize_module = TRANSFORMS.build(transform)
|
||||
resized_results = resize_module(results.copy())
|
||||
assert resized_results['img_shape'][0] in [128, 256, 512]
|
||||
|
||||
transform = dict(
|
||||
type='RandomChoiceResize',
|
||||
scales=[512],
|
||||
resize_type='ResizeShortestEdge',
|
||||
max_size=512)
|
||||
resize_module = TRANSFORMS.build(transform)
|
||||
resized_results = resize_module(results.copy())
|
||||
assert resized_results['img_shape'][1] == 512
|
||||
|
||||
transform = dict(
|
||||
type='RandomChoiceResize',
|
||||
scales=[(128, 256), (256, 512), (512, 1024)],
|
||||
resize_type='ResizeShortestEdge',
|
||||
max_size=1333)
|
||||
resize_module = TRANSFORMS.build(transform)
|
||||
resized_results = resize_module(results.copy())
|
||||
assert resized_results['img_shape'][0] in [128, 256, 512]
|
||||
|
||||
# test scale=None and scale_factor is tuple.
|
||||
# img shape: (288, 512, 3)
|
||||
transform = dict(
|
||||
|
Loading…
x
Reference in New Issue
Block a user