mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Enhance CenterCrop (#1765)
* enhance centercrop and adjust crop size to (w, h)
* fix comments
* update required keys and docstring
parent 2619aa9c8e
commit 2f85d78149
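The user-visible change in this commit: `crop_size` is now read as `(w, h)` instead of `(h, w)`, and `gt_bboxes`/`gt_keypoints` are transformed along with the image. A minimal usage sketch (the registry import path is an assumption and may differ across mmcv versions; the values mirror the tests below):

```python
import numpy as np
from mmcv.transforms import TRANSFORMS  # assumed import path

# After this commit, crop_size is (w, h).
center_crop = TRANSFORMS.build(dict(type='CenterCrop', crop_size=(224, 256)))
results = dict(
    img=np.zeros((300, 400, 3), dtype=np.uint8),
    gt_bboxes=np.array([[0, 0, 210, 160]]),
    gt_keypoints=np.array([[[20, 50, 1]]]))
results = center_crop(results)
print(results['img'].shape)  # (256, 224, 3): height 256, width 224
```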
@@ -415,14 +415,16 @@ class Pad(BaseTransform):
 @TRANSFORMS.register_module()
 class CenterCrop(BaseTransform):
-    """Crop the center of the image and segmentation masks. If the crop area
-    exceeds the original image and ``pad_mode`` is not None, the original image
-    will be padded before cropping.
+    """Crop the center of the image, segmentation masks, bounding boxes and key
+    points. If the crop area exceeds the original image and ``pad_mode`` is not
+    None, the original image will be padded before cropping.

     Required Keys:

     - img
     - gt_semantic_seg (optional)
+    - gt_bboxes (optional)
+    - gt_keypoints (optional)

     Modified Keys:

     - img
@@ -430,6 +432,8 @@ class CenterCrop(BaseTransform):
     - height
     - width
     - gt_semantic_seg (optional)
+    - gt_bboxes (optional)
+    - gt_keypoints (optional)

     Added Key:
@@ -438,8 +442,8 @@ class CenterCrop(BaseTransform):

     Args:
         crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
-            with the format of (h, w). If set to an integer, then cropping
-            height and width are equal to this integer.
+            with the format of (w, h). If set to an integer, then cropping
+            width and height are equal to this integer.
         pad_val (Union[Number, Dict[str, Number]]): A dict for
             padding value. To specify how to set this argument, please see
             the docstring of class ``Pad``. Defaults to
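An integer `crop_size` stays shorthand for a square crop, so the (h, w) -> (w, h) change only affects tuple inputs. A sketch of the normalization (`normalize_crop_size` is a hypothetical helper, not part of this diff):

```python
from typing import Tuple, Union

def normalize_crop_size(crop_size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """int -> square (w, h); tuples are already (w, h) after this change."""
    if isinstance(crop_size, int):
        return (crop_size, crop_size)
    return tuple(crop_size)

assert normalize_crop_size(224) == (224, 224)
assert normalize_crop_size((224, 256)) == (224, 256)  # (w, h)
```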
@@ -449,6 +453,11 @@ class CenterCrop(BaseTransform):
             docstring of class ``Pad``. Defaults to 'constant'.
         pad_cfg (str): Base config for padding. Defaults to
             ``dict(type='Pad')``.
+        clip_object_border (bool): Whether to clip the objects outside
+            the border of the image. In some datasets like MOT17, the gt
+            bboxes are allowed to cross the border of images. Therefore,
+            we don't need to clip the gt bboxes in these cases.
+            Defaults to True.
     """

     def __init__(
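`clip_object_border` only matters for boxes that straddle the crop boundary. A sketch of both behaviors with a hypothetical box already shifted to crop coordinates:

```python
import numpy as np

box = np.array([[-20, 10, 250, 120]])  # partially outside a 224x224 crop
clipped = box.copy()
clipped[:, 0::2] = np.clip(clipped[:, 0::2], 0, 224)  # clamp x
clipped[:, 1::2] = np.clip(clipped[:, 1::2], 0, 224)  # clamp y
print(clipped)  # [[  0  10 224 120]]  <- clip_object_border=True
print(box)      # [[-20  10 250 120]]  <- clip_object_border=False keeps it
```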
@@ -456,7 +465,8 @@ class CenterCrop(BaseTransform):
         crop_size: Union[int, Tuple[int, int]],
         pad_val: Union[Number, Dict[str, Number]] = dict(img=0, seg=255),
         pad_mode: Optional[str] = None,
-        pad_cfg: dict = dict(type='Pad')
+        pad_cfg: dict = dict(type='Pad'),
+        clip_object_border: bool = True,
     ) -> None:  # flake8: noqa
         super().__init__()
         assert isinstance(crop_size, int) or (
@@ -471,6 +481,7 @@ class CenterCrop(BaseTransform):
         self.pad_val = pad_val
         self.pad_mode = pad_mode
         self.pad_cfg = pad_cfg
+        self.clip_object_border = clip_object_border

     def _crop_img(self, results: dict, bboxes: np.ndarray) -> None:
         """Crop image.
@@ -498,6 +509,47 @@ class CenterCrop(BaseTransform):
             img = mmcv.imcrop(results['gt_semantic_seg'], bboxes=bboxes)
             results['gt_semantic_seg'] = img

+    def _crop_bboxes(self, results: dict, bboxes: np.ndarray) -> None:
+        """Update bounding boxes according to CenterCrop.
+
+        Args:
+            results (dict): Result dict containing the data to transform.
+            bboxes (np.ndarray): Shape (4, ), location of the crop box.
+        """
+        if 'gt_bboxes' in results:
+            offset_w = bboxes[0]
+            offset_h = bboxes[1]
+            bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h])
+            # gt_bboxes has shape (num_gts, 4) in (tl_x, tl_y, br_x, br_y)
+            # order.
+            gt_bboxes = results['gt_bboxes'] - bbox_offset
+            if self.clip_object_border:
+                gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0,
+                                             results['img'].shape[1])
+                gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0,
+                                             results['img'].shape[0])
+            results['gt_bboxes'] = gt_bboxes
+
+    def _crop_keypoints(self, results: dict, bboxes: np.ndarray) -> None:
+        """Update key points according to CenterCrop.
+
+        Args:
+            results (dict): Result dict containing the data to transform.
+            bboxes (np.ndarray): Shape (4, ), location of the crop box.
+        """
+        if 'gt_keypoints' in results:
+            offset_w = bboxes[0]
+            offset_h = bboxes[1]
+            keypoints_offset = np.array([offset_w, offset_h, 0])
+            # gt_keypoints has shape (N, NK, 3) in (x, y, visibility) order,
+            # NK = number of points per object
+            gt_keypoints = results['gt_keypoints'] - keypoints_offset
+            gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0,
+                                            results['img'].shape[1])
+            gt_keypoints[:, :, 1] = np.clip(gt_keypoints[:, :, 1], 0,
+                                            results['img'].shape[0])
+            results['gt_keypoints'] = gt_keypoints
+
     def transform(self, results: dict) -> dict:
         """Apply center crop on results.
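Both helpers read the crop origin from `bboxes[0]`/`bboxes[1]`. Keypoints subtract `(offset_w, offset_h, 0)`, so the visibility channel passes through untouched; a compact check of that invariant (hypothetical values):

```python
import numpy as np

offset_w, offset_h = 88, 38
kpts = np.array([[[20, 50, 1]], [[300, 225, 0]]])  # (x, y, visibility)
shifted = kpts - np.array([offset_w, offset_h, 0])
assert (shifted[:, :, 2] == kpts[:, :, 2]).all()  # visibility unchanged
```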
@@ -508,7 +560,7 @@ class CenterCrop(BaseTransform):
             dict: Results with CenterCropped image and semantic segmentation
                 map.
         """
-        crop_height, crop_width = self.crop_size[0], self.crop_size[1]
+        crop_width, crop_height = self.crop_size[0], self.crop_size[1]

         assert 'img' in results, '`img` is not found in results'
         img = results['img']
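With the swap, `self.crop_size[0]` is now the crop width. The window itself is centered with offsets clamped at zero (oversized crops go through the padding path); a sketch of the arithmetic in exclusive end coordinates, not the verbatim implementation (`mmcv.imcrop` uses inclusive ends):

```python
def center_crop_window(img_h, img_w, crop_w, crop_h):
    # Offsets are clamped at 0; an oversized crop is handled by padding.
    y1 = max(0, (img_h - crop_h) // 2)
    x1 = max(0, (img_w - crop_w) // 2)
    return x1, y1, x1 + crop_w, y1 + crop_h

# Matches the [38:262, 88:312] slices asserted in the tests below.
print(center_crop_window(300, 400, 224, 224))  # (88, 38, 312, 262)
```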
@@ -543,6 +595,10 @@ class CenterCrop(BaseTransform):
         self._crop_img(results, bboxes)
         # crop the gt_semantic_seg
         self._crop_seg_map(results, bboxes)
+        # crop the bounding box
+        self._crop_bboxes(results, bboxes)
+        # crop the keypoints
+        self._crop_keypoints(results, bboxes)
         return results

     def __repr__(self) -> str:
@@ -550,6 +606,7 @@ class CenterCrop(BaseTransform):
         repr_str += f', crop_size = {self.crop_size}'
         repr_str += f', pad_val = {self.pad_val}'
         repr_str += f', pad_mode = {self.pad_mode}'
+        repr_str += f', clip_object_border = {self.clip_object_border}'
         return repr_str
@@ -248,6 +248,10 @@ class TestCenterCrop:
     def reset_results(results, original_img, gt_semantic_map):
         results['img'] = copy.deepcopy(original_img)
         results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map)
+        results['gt_bboxes'] = np.array([[0, 0, 210, 160],
+                                         [200, 150, 400, 300]])
+        results['gt_keypoints'] = np.array([[[20, 50, 1]], [[200, 150, 1]],
+                                            [[300, 225, 1]]])
         return results

     @pytest.mark.skipif(
@@ -293,6 +297,12 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
                                                                88:312]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 122, 122], [112, 112, 224,
+                                                     224]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()

         # test CenterCrop when size is tuple
         transform = dict(type='CenterCrop', crop_size=(224, 224))
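The expected values follow from the centering offsets: `offset_w = (400 - 224) // 2 = 88` and `offset_h = (300 - 224) // 2 = 38`, hence the `[38:262, 88:312]` slices. Reproducing the keypoint expectations by hand:

```python
import numpy as np

offset = np.array([88, 38, 0])  # x and y shift; visibility untouched
kpts = np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])
shifted = kpts - offset
shifted[:, :, :2] = np.clip(shifted[:, :, :2], 0, 224)  # square crop
print(shifted.tolist())
# [[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]]
```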
@@ -306,9 +316,15 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
                                                                88:312]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 122, 122], [112, 112, 224,
+                                                     224]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()

         # test CenterCrop when crop_height != crop_width
-        transform = dict(type='CenterCrop', crop_size=(256, 224))
+        transform = dict(type='CenterCrop', crop_size=(224, 256))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -319,10 +335,16 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[22:278,
                                                                88:312]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 122, 138], [112, 128, 224,
+                                                     256]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 28, 1]], [[112, 128, 1]], [[212, 203, 1]]])).all()

         # test CenterCrop when crop_size is equal to img.shape
         img_height, img_width, _ = self.original_img.shape
-        transform = dict(type='CenterCrop', crop_size=(img_height, img_width))
+        transform = dict(type='CenterCrop', crop_size=(img_width, img_height))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -331,10 +353,16 @@ class TestCenterCrop:
         assert results['width'] == 400
         assert (results['img'] == self.original_img).all()
         assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 210, 160], [200, 150, 400,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])).all()

         # test CenterCrop when crop_size is larger than img.shape
         transform = dict(
-            type='CenterCrop', crop_size=(img_height * 2, img_width * 2))
+            type='CenterCrop', crop_size=(img_width * 2, img_height * 2))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -343,11 +371,17 @@ class TestCenterCrop:
         assert results['width'] == 400
         assert (results['img'] == self.original_img).all()
         assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 210, 160], [200, 150, 400,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])).all()

         # test with padding
         transform = dict(
             type='CenterCrop',
-            crop_size=(img_height * 2, img_width // 2),
+            crop_size=(img_width // 2, img_height * 2),
             pad_mode='constant',
             pad_val=12)
         center_crop_module = TRANSFORMS.build(transform)
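Here the requested height (600) exceeds the image (300), so with `pad_mode='constant'` the output is padded up to the full requested size while the width is cropped normally: `offset_w = (400 - 200) // 2 = 100`, `offset_h = 0`. A quick check of the extents the next assertions rely on:

```python
img_h, img_w = 300, 400
crop_w, crop_h = img_w // 2, img_h * 2    # (200, 600)
offset_w = max(0, (img_w - crop_w) // 2)  # 100: normal center crop in x
offset_h = max(0, (img_h - crop_h) // 2)  # 0: nothing to crop in y
print((crop_h, crop_w), offset_w, offset_h)  # (600, 200) 100 0
# Rows 300:600 of the padded output hold pad_val.
```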
@@ -359,10 +393,16 @@ class TestCenterCrop:
         assert results['img'].shape[:2] == results['gt_semantic_seg'].shape
         assert (results['img'][300:600, 100:300, ...] == 12).all()
         assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 110, 160], [100, 150, 200,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()

         transform = dict(
             type='CenterCrop',
-            crop_size=(img_height * 2, img_width // 2),
+            crop_size=(img_width // 2, img_height * 2),
             pad_mode='constant',
             pad_val=dict(img=13, seg=33))
         center_crop_module = TRANSFORMS.build(transform)
@@ -373,10 +413,16 @@ class TestCenterCrop:
         assert results['width'] == 200
         assert (results['img'][300:600, 100:300, ...] == 13).all()
         assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 110, 160], [100, 150, 200,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()

         # test CenterCrop when crop_width is smaller than img_width
         transform = dict(
-            type='CenterCrop', crop_size=(img_height, img_width // 2))
+            type='CenterCrop', crop_size=(img_width // 2, img_height))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
@@ -387,10 +433,16 @@ class TestCenterCrop:
         assert (
             results['gt_semantic_seg'] == self.gt_semantic_map[:,
                                                                100:300]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 110, 160], [100, 150, 200,
+                                                     300]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()

         # test CenterCrop when crop_height is smaller than img_height
         transform = dict(
-            type='CenterCrop', crop_size=(img_height // 2, img_width))
+            type='CenterCrop', crop_size=(img_width, img_height // 2))
         center_crop_module = TRANSFORMS.build(transform)
         results = self.reset_results(results, self.original_img,
                                      self.gt_semantic_map)
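For the height-only crop the same arithmetic gives `offset_h = (300 - 150) // 2 = 75`: rows `75:225` survive, y coordinates shift by 75 and are clipped to the new height of 150. Reproducing the second expected bbox:

```python
import numpy as np

offset = np.array([0, 75, 0, 75])  # height-only crop: y shifts by 75
box = np.array([[200, 150, 400, 300]]) - offset
box[:, 1::2] = np.clip(box[:, 1::2], 0, 150)  # clamp y to new height
print(box)  # [[200  75 400 150]]
```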
@@ -400,6 +452,12 @@ class TestCenterCrop:
         assert (results['img'] == self.original_img[75:225, ...]).all()
         assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225,
                                                                    ...]).all()
+        assert np.equal(results['gt_bboxes'],
+                        np.array([[0, 0, 210, 85], [200, 75, 400,
+                                                    150]])).all()
+        assert np.equal(
+            results['gt_keypoints'],
+            np.array([[[20, 0, 1]], [[200, 75, 1]], [[300, 150, 1]]])).all()

     @pytest.mark.skipif(
         condition=torch is None, reason='No torch in current env')