[Enhancement] Enhance CenterCrop (#1765)

* enhance centercrop and adjust crop size to (w, h)

* fix comments

* update required keys and docstring
Yifei Yang 2022-03-07 11:01:16 +08:00 committed by zhouzaida
parent 2619aa9c8e
commit 2f85d78149
2 changed files with 129 additions and 14 deletions


@@ -415,14 +415,16 @@ class Pad(BaseTransform):
@TRANSFORMS.register_module()
class CenterCrop(BaseTransform):
"""Crop the center of the image and segmentation masks. If the crop area
exceeds the original image and ``pad_mode`` is not None, the original image
will be padded before cropping.
"""Crop the center of the image, segmentation masks, bounding boxes and key
points. If the crop area exceeds the original image and ``pad_mode`` is not
None, the original image will be padded before cropping.
Required Keys:
- img
- gt_semantic_seg (optional)
- gt_bboxes (optional)
- gt_keypoints (optional)
Modified Keys:
@@ -430,6 +432,8 @@ class CenterCrop(BaseTransform):
- height
- width
- gt_semantic_seg (optional)
- gt_bboxes (optional)
- gt_keypoints (optional)
Added Key:
@@ -438,8 +442,8 @@ class CenterCrop(BaseTransform):
Args:
crop_size (Union[int, Tuple[int, int]]): Expected size after cropping
- with the format of (h, w). If set to an integer, then cropping
- height and width are equal to this integer.
+ with the format of (w, h). If set to an integer, then cropping
+ width and height are equal to this integer.
pad_val (Union[Number, Dict[str, Number]]): The padding value(s).
To specify how to set this argument, please see the docstring
of class ``Pad``. Defaults to
@@ -449,6 +453,11 @@ class CenterCrop(BaseTransform):
docstring of class ``Pad``. Defaults to 'constant'.
pad_cfg (dict): Base config for padding. Defaults to
``dict(type='Pad')``.
clip_object_border (bool): Whether to clip objects that lie
outside the border of the image. In some datasets, such as
MOT17, the gt bboxes are allowed to cross image borders, so
they don't need to be clipped in those cases.
Defaults to True.
"""
def __init__(
@@ -456,7 +465,8 @@ class CenterCrop(BaseTransform):
crop_size: Union[int, Tuple[int, int]],
pad_val: Union[Number, Dict[str, Number]] = dict(img=0, seg=255),
pad_mode: Optional[str] = None,
- pad_cfg: dict = dict(type='Pad')
+ pad_cfg: dict = dict(type='Pad'),
+ clip_object_border: bool = True,
) -> None: # flake8: noqa
super().__init__()
assert isinstance(crop_size, int) or (
@@ -471,6 +481,7 @@ class CenterCrop(BaseTransform):
self.pad_val = pad_val
self.pad_mode = pad_mode
self.pad_cfg = pad_cfg
self.clip_object_border = clip_object_border
def _crop_img(self, results: dict, bboxes: np.ndarray) -> None:
"""Crop image.
@@ -498,6 +509,47 @@ class CenterCrop(BaseTransform):
img = mmcv.imcrop(results['gt_semantic_seg'], bboxes=bboxes)
results['gt_semantic_seg'] = img
def _crop_bboxes(self, results: dict, bboxes: np.ndarray) -> None:
"""Update bounding boxes according to CenterCrop.
Args:
results (dict): Result dict containing the data to transform.
bboxes (np.ndarray): Shape (4, ), the (x1, y1, x2, y2)
coordinates of the crop region.
"""
if 'gt_bboxes' in results:
offset_w = bboxes[0]
offset_h = bboxes[1]
bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h])
# gt_bboxes has shape (num_gts, 4) in (tl_x, tl_y, br_x, br_y)
# order.
gt_bboxes = results['gt_bboxes'] - bbox_offset
if self.clip_object_border:
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0,
results['img'].shape[1])
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0,
results['img'].shape[0])
results['gt_bboxes'] = gt_bboxes
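The update above is plain numpy arithmetic, so it can be checked standalone. A sketch (not part of the commit) that reproduces the values asserted in the tests below, assuming the 224 x 224 center crop of the 300 x 400 test image (crop origin at (88, 38)):

import numpy as np

offset_w, offset_h = 88, 38  # crop origin for a 224x224 crop of a 400x300 image
gt_bboxes = np.array([[0, 0, 210, 160], [200, 150, 400, 300]])
shifted = gt_bboxes - np.array([offset_w, offset_h, offset_w, offset_h])
# With clip_object_border=True, clamp to the cropped image extent.
shifted[:, 0::2] = np.clip(shifted[:, 0::2], 0, 224)
shifted[:, 1::2] = np.clip(shifted[:, 1::2], 0, 224)
print(shifted)  # [[  0   0 122 122]
                #  [112 112 224 224]]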
def _crop_keypoints(self, results: dict, bboxes: np.ndarray) -> None:
"""Update key points according to CenterCrop.
Args:
results (dict): Result dict containing the data to transform.
bboxes (np.ndarray): Shape (4, ), the (x1, y1, x2, y2)
coordinates of the crop region.
"""
if 'gt_keypoints' in results:
offset_w = bboxes[0]
offset_h = bboxes[1]
keypoints_offset = np.array([offset_w, offset_h, 0])
# gt_keypoints has shape (N, NK, 3) in (x, y, visibility) order,
# NK = number of points per object
gt_keypoints = results['gt_keypoints'] - keypoints_offset
gt_keypoints[:, :, 0] = np.clip(gt_keypoints[:, :, 0], 0,
results['img'].shape[1])
gt_keypoints[:, :, 1] = np.clip(gt_keypoints[:, :, 1], 0,
results['img'].shape[0])
results['gt_keypoints'] = gt_keypoints
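Note that, unlike `_crop_bboxes`, the keypoints are clipped unconditionally rather than gated on `clip_object_border`. A matching standalone sketch for the same 224 x 224 crop (illustrative values only):

import numpy as np

gt_keypoints = np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])
shifted = gt_keypoints - np.array([88, 38, 0])        # visibility flag unchanged
shifted[:, :, 0] = np.clip(shifted[:, :, 0], 0, 224)  # x into [0, crop_w]
shifted[:, :, 1] = np.clip(shifted[:, :, 1], 0, 224)  # y into [0, crop_h]
print(shifted)  # [[[  0  12   1]], [[112 112   1]], [[212 187   1]]]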
def transform(self, results: dict) -> dict:
"""Apply center crop on results.
@@ -508,7 +560,7 @@ class CenterCrop(BaseTransform):
dict: Results with center-cropped image, semantic
segmentation map, bounding boxes and keypoints.
"""
- crop_height, crop_width = self.crop_size[0], self.crop_size[1]
+ crop_width, crop_height = self.crop_size[0], self.crop_size[1]
assert 'img' in results, '`img` is not found in results'
img = results['img']
@@ -543,6 +595,10 @@ class CenterCrop(BaseTransform):
self._crop_img(results, bboxes)
# crop the gt_semantic_seg
self._crop_seg_map(results, bboxes)
# crop the bounding boxes
self._crop_bboxes(results, bboxes)
# crop the keypoints
self._crop_keypoints(results, bboxes)
return results
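The `bboxes` window handed to these helpers is centered on the (possibly padded) image. Its exact computation falls outside the hunks shown here, so the following is an inferred sketch based on the test expectations below; note that `mmcv.imcrop` treats (x1, y1, x2, y2) as inclusive, hence the `- 1`:

# Inferred sketch: center a (crop_w, crop_h) window on an img_h x img_w image.
img_h, img_w = 300, 400             # the test image below
crop_w, crop_h = 224, 224
x1 = max(0, (img_w - crop_w) // 2)  # 88
y1 = max(0, (img_h - crop_h) // 2)  # 38
bboxes = (x1, y1, x1 + crop_w - 1, y1 + crop_h - 1)  # -> rows 38:262, cols 88:312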
def __repr__(self) -> str:
@@ -550,6 +606,7 @@ class CenterCrop(BaseTransform):
repr_str += f', crop_size = {self.crop_size}'
repr_str += f', pad_val = {self.pad_val}'
repr_str += f', pad_mode = {self.pad_mode}'
repr_str += f', clip_object_border = {self.clip_object_border}'
return repr_str


@@ -248,6 +248,10 @@ class TestCenterCrop:
def reset_results(results, original_img, gt_semantic_map):
results['img'] = copy.deepcopy(original_img)
results['gt_semantic_seg'] = copy.deepcopy(gt_semantic_map)
results['gt_bboxes'] = np.array([[0, 0, 210, 160],
[200, 150, 400, 300]])
results['gt_keypoints'] = np.array([[[20, 50, 1]], [[200, 150, 1]],
[[300, 225, 1]]])
return results
@pytest.mark.skipif(
@@ -293,6 +297,12 @@ class TestCenterCrop:
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 122], [112, 112, 224,
224]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()
# test CenterCrop when size is tuple
transform = dict(type='CenterCrop', crop_size=(224, 224))
@@ -306,9 +316,15 @@ class TestCenterCrop:
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[38:262,
88:312]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 122], [112, 112, 224,
224]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 12, 1]], [[112, 112, 1]], [[212, 187, 1]]])).all()
# test CenterCrop when crop_height != crop_width
- transform = dict(type='CenterCrop', crop_size=(256, 224))
+ transform = dict(type='CenterCrop', crop_size=(224, 256))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
@@ -319,10 +335,16 @@ class TestCenterCrop:
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[22:278,
88:312]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 122, 138], [112, 128, 224,
256]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 28, 1]], [[112, 128, 1]], [[212, 203, 1]]])).all()
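The non-square case follows the same arithmetic applied per axis; with crop_size=(224, 256), i.e. width 224 and height 256 on the 300 x 400 test image:

x1 = (400 - 224) // 2  # 88 -> columns 88:312
y1 = (300 - 256) // 2  # 22 -> rows 22:278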
# test CenterCrop when crop_size is equal to img.shape
img_height, img_width, _ = self.original_img.shape
- transform = dict(type='CenterCrop', crop_size=(img_height, img_width))
+ transform = dict(type='CenterCrop', crop_size=(img_width, img_height))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
@@ -331,10 +353,16 @@ class TestCenterCrop:
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 160], [200, 150, 400,
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])).all()
# test CenterCrop when crop_size is larger than img.shape
transform = dict(
- type='CenterCrop', crop_size=(img_height * 2, img_width * 2))
+ type='CenterCrop', crop_size=(img_width * 2, img_height * 2))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
@@ -343,11 +371,17 @@ class TestCenterCrop:
assert results['width'] == 400
assert (results['img'] == self.original_img).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 160], [200, 150, 400,
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[20, 50, 1]], [[200, 150, 1]], [[300, 225, 1]]])).all()
# test with padding
transform = dict(
type='CenterCrop',
- crop_size=(img_height * 2, img_width // 2),
+ crop_size=(img_width // 2, img_height * 2),
pad_mode='constant',
pad_val=12)
center_crop_module = TRANSFORMS.build(transform)
@@ -359,10 +393,16 @@ class TestCenterCrop:
assert results['img'].shape[:2] == results['gt_semantic_seg'].shape
assert (results['img'][300:600, 100:300, ...] == 12).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 255).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
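In the two padded cases the requested crop height (600) exceeds the image height (300), so the vertical offset clamps to zero, the missing rows are filled with pad_val, and only the x coordinates of boxes and keypoints shift:

x1 = (400 - 200) // 2          # 100 -> x coords shift by -100, then clip to 200
y1 = max(0, (300 - 600) // 2)  # 0   -> y coords unchanged; rows 300:600 are padding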
transform = dict(
type='CenterCrop',
- crop_size=(img_height * 2, img_width // 2),
+ crop_size=(img_width // 2, img_height * 2),
pad_mode='constant',
pad_val=dict(img=13, seg=33))
center_crop_module = TRANSFORMS.build(transform)
@@ -373,10 +413,16 @@ class TestCenterCrop:
assert results['width'] == 200
assert (results['img'][300:600, 100:300, ...] == 13).all()
assert (results['gt_semantic_seg'][300:600, 100:300] == 33).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
# test CenterCrop when crop_width is smaller than img_width
transform = dict(
- type='CenterCrop', crop_size=(img_height, img_width // 2))
+ type='CenterCrop', crop_size=(img_width // 2, img_height))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
@@ -387,10 +433,16 @@ class TestCenterCrop:
assert (
results['gt_semantic_seg'] == self.gt_semantic_map[:,
100:300]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 110, 160], [100, 150, 200,
300]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[0, 50, 1]], [[100, 150, 1]], [[200, 225, 1]]])).all()
# test CenterCrop when crop_height is smaller than img_height
transform = dict(
- type='CenterCrop', crop_size=(img_height // 2, img_width))
+ type='CenterCrop', crop_size=(img_width, img_height // 2))
center_crop_module = TRANSFORMS.build(transform)
results = self.reset_results(results, self.original_img,
self.gt_semantic_map)
@@ -400,6 +452,12 @@ class TestCenterCrop:
assert (results['img'] == self.original_img[75:225, ...]).all()
assert (results['gt_semantic_seg'] == self.gt_semantic_map[75:225,
...]).all()
assert np.equal(results['gt_bboxes'],
np.array([[0, 0, 210, 85], [200, 75, 400,
150]])).all()
assert np.equal(
results['gt_keypoints'],
np.array([[[20, 0, 1]], [[200, 75, 1]], [[300, 150, 1]]])).all()
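For this last half-height case the vertical offset is (300 - 150) // 2 = 75; a standalone check of the expected boxes (illustrative only, mirroring the assertion above):

import numpy as np

gt = np.array([[0, 0, 210, 160], [200, 150, 400, 300]])
gt[:, 1::2] = np.clip(gt[:, 1::2] - 75, 0, 150)  # shift y by -75, clip to new height
print(gt)  # [[  0   0 210  85]
            #  [200  75 400 150]]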
@pytest.mark.skipif(
condition=torch is None, reason='No torch in current env')