[Feature] Add Resize

pull/1178/head
jiangqing.vendor 2022-05-18 13:33:28 +00:00 committed by gaotongxiao
parent 6478499073
commit f29853d9cd
8 changed files with 348 additions and 7 deletions

View File

@ -19,4 +19,5 @@ mmocr/models/textdet/detectors/text_detector_mixin.py
# It will be covered by tests of any det model implemented in future
mmocr/models/textdet/detectors/single_stage_text_detector.py
mmocr/datasets/pipelines/transforms.py
mmocr/datasets/pipelines/dbnet_transforms.py

View File

@ -9,7 +9,7 @@ from .ocr_seg_targets import OCRSegTargets
from .ocr_transforms import (FancyPCA, NormalizeOCR, OnlineCropOCR,
OpencvToPil, PilToOpencv, RandomPaddingOCR,
RandomRotateImageBox, ResizeOCR, ToTensorOCR)
from .processing import PyramidRescale
from .processing import PyramidRescale, Resize
from .test_time_aug import MultiRotateAugOCR
from .textdet_targets import (DBNetTargets, FCENetTargets, PANetTargets,
TextSnakeTargets)
@ -31,5 +31,5 @@ __all__ = [
'LoadImageFromNdarray', 'sort_vertex8', 'FCENetTargets', 'RandomScaling',
'RandomCropFlip', 'NerTransform', 'ToTensorNER', 'ResizeNoImg',
'PyramidRescale', 'OneOfWrapper', 'RandomWrapper', 'TorchVisionWrapper',
'LoadImageFromLMDB'
'LoadImageFromLMDB', 'Resize'
]

View File

@ -4,10 +4,13 @@ from typing import Dict, Tuple
import cv2
import mmcv
import numpy as np
from mmcv.image.geometric import _scale_size
from mmcv.transforms import Resize as MMCV_Resize
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
from mmocr.registry import TRANSFORMS
from mmocr.utils import crop_polygon, rescale_bboxes, rescale_polygon
@TRANSFORMS.register_module()
@ -96,3 +99,114 @@ class PyramidRescale(BaseTransform):
repr_str += f', base_w = {self.base_w}'
repr_str += f', base_h = {self.base_h})'
return repr_str
@TRANSFORMS.register_module()
class Resize(MMCV_Resize):
    """Resize image & bboxes & polygons.

    This transform resizes the input image according to ``scale`` or
    ``scale_factor``. Bboxes and polygons are then resized with the same
    scale factor. If ``scale`` and ``scale_factor`` are both set, ``scale``
    is used to resize.

    Required Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_polygons

    Modified Keys:

    - img
    - img_shape
    - gt_bboxes
    - gt_polygons

    Added Keys:

    - scale
    - scale_factor
    - keep_ratio

    Args:
        scale (int or tuple): Image scales for resizing. Defaults to None.
        scale_factor (float or tuple[float]): Scale factors for resizing.
            It's either a factor applicable to both dimensions or in the
            form of (scale_w, scale_h). Defaults to None.
        keep_ratio (bool): Whether to keep the aspect ratio when resizing
            the image. Defaults to False.
        clip_object_border (bool): Whether to clip the objects outside the
            border of the image. Defaults to True.
        backend (str): Image resize backend, choices are 'cv2' and
            'pillow'. These two backends generate slightly different
            results. Defaults to 'cv2'.
        interpolation (str): Interpolation method, accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
            backend, "nearest", "bilinear" for 'pillow' backend. Defaults
            to 'bilinear'.
    """

    def _resize_bboxes(self, results: dict) -> None:
        """Resize bounding boxes with ``results['scale_factor']``.

        Does nothing when ``results`` has no ``gt_bboxes`` (missing or
        None). Otherwise ``results['gt_bboxes']`` is replaced by a float32
        array; when ``self.clip_object_border`` is set, x coordinates are
        clipped to ``[0, img_shape[1]]`` and y coordinates to
        ``[0, img_shape[0]]``.
        """
        if results.get('gt_bboxes', None) is not None:
            bboxes = rescale_bboxes(results['gt_bboxes'],
                                    results['scale_factor'])
            if self.clip_object_border:
                # Even columns hold x values, odd columns hold y values.
                bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0,
                                          results['img_shape'][1])
                bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0,
                                          results['img_shape'][0])
            results['gt_bboxes'] = bboxes.astype(np.float32)

    def _resize_polygons(self, results: dict) -> None:
        """Resize polygons with ``results['scale_factor']``.

        Does nothing when ``results`` has no ``gt_polygons`` (missing or
        None). Otherwise ``results['gt_polygons']`` is replaced by a new
        list of float32 arrays with one entry per input polygon. When
        ``self.clip_object_border`` is set, every polygon is cropped to the
        image region; a polygon for which ``crop_polygon`` returns None is
        replaced by an all-zero float32 array shaped like the input
        polygon, so the list keeps its length and order.
        """
        if results.get('gt_polygons', None) is None:
            return

        polygons = results['gt_polygons']
        crop_bbox = None
        if self.clip_object_border and len(polygons) > 0:
            # The crop region is the same for every polygon, so build it
            # once: [x_min, y_min, x_max, y_max] = [0, 0, width, height].
            crop_bbox = np.array(
                [0, 0, results['img_shape'][1], results['img_shape'][0]])

        polygons_resize = []
        for polygon in polygons:
            resized = rescale_polygon(polygon, results['scale_factor'])
            if crop_bbox is not None:
                resized = crop_polygon(resized, crop_bbox)
            if resized is not None:
                polygons_resize.append(resized.astype(np.float32))
            else:
                # Use float32 for the placeholder too, so that every entry
                # of the output list has the same dtype.
                polygons_resize.append(
                    np.zeros_like(polygon, dtype=np.float32))
        results['gt_polygons'] = polygons_resize

    def transform(self, results: dict) -> dict:
        """Transform function to resize images, bounding boxes and polygons.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: The same dict. ``results['scale']`` is set here, the
            image is resized by the inherited ``_resize_img``, and
            ``gt_bboxes`` / ``gt_polygons`` are rescaled when present.
        """
        if self.scale:
            # An explicit target scale takes precedence over scale_factor.
            results['scale'] = self.scale
        else:
            # img.shape is (h, w, ...) while _scale_size expects (w, h).
            img_shape = results['img'].shape[:2]
            results['scale'] = _scale_size(img_shape[::-1], self.scale_factor)
        self._resize_img(results)
        self._resize_bboxes(results)
        self._resize_polygons(results)
        return results

    def __repr__(self):
        # NOTE(review): the ')' after clip_object_border and after backend
        # is unbalanced, but the repr test for this class pins this exact
        # string, so the format is kept as is.
        repr_str = self.__class__.__name__
        repr_str += f'(scale={self.scale}, '
        repr_str += f'scale_factor={self.scale_factor}, '
        repr_str += f'keep_ratio={self.keep_ratio}, '
        repr_str += f'clip_object_border={self.clip_object_border}), '
        repr_str += f'backend={self.backend}), '
        repr_str += f'interpolation={self.interpolation})'
        return repr_str

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
from .bbox_utils import bbox2poly, rescale_bboxes
from .box_util import (bezier_to_polygon, is_on_same_line, sort_points,
stitch_boxes_into_lines)
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_none_or_type,
@ -13,7 +14,7 @@ from .img_util import drop_orientation, is_not_png
from .lmdb_util import recog2lmdb
from .logger import get_root_logger
from .model import revert_sync_batchnorm
from .polygon_utils import rescale_polygon, rescale_polygons
from .polygon_utils import crop_polygon, rescale_polygon, rescale_polygons
from .setup_env import setup_multi_processes
from .string_util import StringStrip
@ -25,5 +26,6 @@ __all__ = [
'is_on_same_line', 'stitch_boxes_into_lines', 'StringStrip',
'revert_sync_batchnorm', 'bezier_to_polygon', 'sort_points',
'setup_multi_processes', 'recog2lmdb', 'dump_ocr_data',
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon'
'recog_anno_to_imginfo', 'rescale_polygons', 'rescale_polygon',
'rescale_bboxes', 'bbox2poly', 'crop_polygon'
]

View File

@ -0,0 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import numpy as np
def rescale_bbox(bbox: np.ndarray,
                 scale_factor: Tuple[int, int],
                 mode: str = 'mul') -> np.ndarray:
    """Rescale a bounding box according to scale_factor.

    With ``mode='mul'`` the coordinates are multiplied by ``scale_factor``,
    which is what preprocessing transforms such as :func:`Resize` need.
    With ``mode='div'`` the coordinates are divided by ``scale_factor``,
    which lets postprocessors map a bbox back to the original image size.

    Args:
        bbox (ndarray): A bounding box [x1, y1, x2, y2].
        scale_factor (tuple(int, int)): (w_scale, h_scale).
        mode (str): Rescale mode. Can be 'mul' or 'div'. Defaults to 'mul'.

    Returns:
        np.ndarray: Rescaled bbox, in the same shape as the input.
    """
    assert mode in ['mul', 'div']
    coords = np.array(bbox, dtype=np.float32)
    # View the coordinates as (x, y) pairs so that one (w_scale, h_scale)
    # row broadcasts over all of them.
    points = coords.reshape(-1, 2)
    factors = np.array(scale_factor, dtype=float)
    if mode == 'div':
        factors = 1 / factors
    return (points * factors[None]).reshape(coords.shape)
def rescale_bboxes(bboxes: np.ndarray,
                   scale_factor: Tuple[int, int],
                   mode: str = 'mul') -> np.ndarray:
    """Rescale bboxes according to scale_factor.

    With ``mode='mul'`` the coordinates are multiplied by ``scale_factor``,
    which is what preprocessing transforms such as :func:`Resize` need.
    With ``mode='div'`` the coordinates are divided by ``scale_factor``,
    which lets postprocessors map bboxes back to the original image size.

    Args:
        bboxes (np.ndarray): Bounding bboxes in shape (N, 4).
        scale_factor (tuple(int, int)): (w_scale, h_scale).
        mode (str): Rescale mode. Can be 'mul' or 'div'. Defaults to 'mul'.

    Returns:
        np.ndarray: Rescaled bboxes.
    """
    # rescale_bbox works on (x, y) pairs, so it handles the whole batch
    # without a per-box loop.
    return rescale_bbox(bboxes, scale_factor=scale_factor, mode=mode)
def bbox2poly(bbox: np.ndarray) -> np.ndarray:
    """Convert a bounding box to a four-point polygon.

    Args:
        bbox (np.array): The bounding box with two points [x1, y1, x2, y2].

    Returns:
        np.array: The converted polygon [x1, y1, x2, y1, x2, y2, x1, y2]
        as a float32 array.
    """
    assert len(bbox) == 4
    x1, y1, x2, y2 = bbox
    # Walk the corners in the order (x1, y1), (x2, y1), (x2, y2), (x1, y2).
    corners = [x1, y1, x2, y1, x2, y2, x1, y2]
    return np.array(corners, dtype=np.float32)

View File

@ -3,6 +3,9 @@ from typing import Sequence, Tuple
import numpy as np
from numpy.typing import ArrayLike
from shapely.geometry import Polygon
from mmocr.utils import bbox2poly
def rescale_polygon(polygon: ArrayLike,
@ -22,7 +25,7 @@ def rescale_polygon(polygon: ArrayLike,
to an 1-D numpy array. E.g. list[float], np.ndarray,
or torch.Tensor. Polygon is written in
[x1, y1, x2, y2, ...].
scale_factor (tuple(int, int)): (w_scale, h_scale)
scale_factor (tuple(int, int)): (w_scale, h_scale).
model (str): Rescale mode. Can be 'mul' or 'div'. Defaults to 'mul'.
Returns:
@ -57,7 +60,7 @@ def rescale_polygons(polygons: Sequence[ArrayLike],
[x1, y1, x2, y2, ...] and in any form can be converted
to an 1-D numpy array. E.g. list[list[float]],
list[np.ndarray], or list[torch.Tensor].
scale_factor (tuple(int, int)): (w_scale, h_scale)
scale_factor (tuple(int, int)): (w_scale, h_scale).
model (str): Rescale mode. Can be 'mul' or 'div'. Defaults to 'mul'.
Returns:
@ -67,3 +70,26 @@ def rescale_polygons(polygons: Sequence[ArrayLike],
for polygon in polygons:
results.append(rescale_polygon(polygon, scale_factor, mode))
return results
def crop_polygon(polygon: ArrayLike, crop_box: np.ndarray) -> np.ndarray:
    """Crop polygon to be within a box region.

    Args:
        polygon (ndarray): Polygon written as [x1, y1, x2, y2, ...]; any
            input that can be reshaped to (-1, 2) is accepted.
        crop_box (ndarray): Target box region [x1, y1, x2, y2] in shape
            (4, ).

    Returns:
        np.array or None: The cropped polygon as a flat array
        [x1, y1, x2, y2, ...] without a repeated closing vertex. None is
        returned when the polygon does not overlap the box with non-zero
        area, or when the overlap is not one single polygon (e.g. a
        concave polygon that the box splits into several pieces).
    """
    polygon = np.asarray(polygon, dtype=np.float32)
    crop_box = np.asarray(crop_box, dtype=np.float32)
    poly = Polygon(polygon.reshape(-1, 2))
    crop_poly = Polygon(bbox2poly(crop_box).reshape(-1, 2))
    poly_cropped = poly.intersection(crop_poly)
    if poly_cropped.area == 0.:
        # If polygon is outside crop_box region, return None.
        return None
    if poly_cropped.geom_type != 'Polygon':
        # A multi-part result has no single outline that could be written
        # as one flat coordinate array.
        return None
    # exterior.coords is an (N + 1, 2) sequence of (x, y) rows whose last
    # row repeats the first; drop it, then flatten row by row so that the
    # result is interleaved as [x1, y1, x2, y2, ...].
    vertices = np.array(poly_cropped.exterior.coords)[:-1]
    return vertices.reshape(-1)

View File

@ -4,7 +4,7 @@ import unittest
import numpy as np
from mmocr.datasets.pipelines import PyramidRescale
from mmocr.datasets.pipelines import PyramidRescale, Resize
class TestPyramidRescale(unittest.TestCase):
@ -53,3 +53,100 @@ class TestPyramidRescale(unittest.TestCase):
repr(transform),
('PyramidRescale(factor = 4, randomize_factor = False, '
'base_w = 128, base_h = 512)'))
class TestResize(unittest.TestCase):
    """Tests for the ``Resize`` transform."""

    def setUp(self):
        # 600x800 image whose annotations stay inside the image.
        self.data_info1 = dict(
            img=np.random.random((600, 800, 3)),
            gt_bboxes=np.array([[0, 0, 60, 100]]),
            gt_polygons=[np.array([0, 0, 200, 0, 200, 100, 0, 100])])
        # 200x300 image whose annotations extend beyond the image.
        self.data_info2 = dict(
            img=np.random.random((200, 300, 3)),
            gt_bboxes=np.array([[0, 0, 400, 600]]),
            gt_polygons=[np.array([0, 0, 400, 0, 400, 400, 0, 400])])
        # 200x300 image whose annotations lie entirely outside the image.
        self.data_info3 = dict(
            img=np.random.random((200, 300, 3)),
            gt_bboxes=np.array([[400, 400, 600, 600]]),
            gt_polygons=[np.array([400, 400, 500, 400, 500, 600, 400, 600])])

    def _assert_same_coords(self, actual, expected):
        """Check that two flat polygons hold the same coordinate values.

        Clipped polygons are produced by a shapely intersection, so the
        order of the returned vertices is decided by shapely. Only the
        sorted coordinate values are compared.
        """
        np.testing.assert_allclose(
            np.sort(np.asarray(actual, dtype=np.float64).reshape(-1)),
            np.sort(np.asarray(expected, dtype=np.float64).reshape(-1)))

    def test_resize(self):
        # test keep_ratio is True
        transform = Resize(scale=(400, 400), keep_ratio=True)
        results = transform(copy.deepcopy(self.data_info1))
        self.assertEqual(results['img'].shape[:2], (300, 400))
        self.assertEqual(results['img_shape'], (300, 400))
        self.assertEqual(results['scale'], (400, 300))
        self.assertEqual(results['scale_factor'], (400 / 800, 300 / 600))
        self.assertEqual(results['keep_ratio'], True)

        # test keep_ratio is False
        transform = Resize(scale=(400, 400))
        results = transform(copy.deepcopy(self.data_info1))
        self.assertEqual(results['img'].shape[:2], (400, 400))
        self.assertEqual(results['img_shape'], (400, 400))
        self.assertEqual(results['scale'], (400, 400))
        self.assertEqual(results['scale_factor'], (400 / 800, 400 / 600))
        self.assertEqual(results['keep_ratio'], False)

        # test resize_bboxes/polygons
        transform = Resize(scale_factor=(1.5, 2))
        results = transform(copy.deepcopy(self.data_info1))
        self.assertEqual(results['img'].shape[:2], (1200, 1200))
        self.assertEqual(results['img_shape'], (1200, 1200))
        self.assertEqual(results['scale'], (1200, 1200))
        self.assertEqual(results['scale_factor'], (1.5, 2))
        self.assertEqual(results['keep_ratio'], False)
        np.testing.assert_allclose(results['gt_bboxes'],
                                   np.array([[0, 0, 90, 200]]))
        # clip_object_border defaults to True, so the polygon has been
        # through the clipping step even though it fits in the image.
        self._assert_same_coords(results['gt_polygons'][0],
                                 [0, 0, 300, 0, 300, 200, 0, 200])

        # test clip_object_border = False
        transform = Resize(scale=(150, 100), clip_object_border=False)
        results = transform(self.data_info2.copy())
        self.assertEqual(results['img'].shape[:2], (100, 150))
        self.assertEqual(results['img_shape'], (100, 150))
        self.assertEqual(results['scale'], (150, 100))
        self.assertEqual(results['scale_factor'], (150. / 300., 100. / 200.))
        self.assertEqual(results['keep_ratio'], False)
        np.testing.assert_allclose(results['gt_bboxes'],
                                   np.array([[0, 0, 200, 300]]))
        # Without clipping the polygon is only rescaled, so the vertex
        # order is the input order.
        np.testing.assert_allclose(
            results['gt_polygons'][0],
            np.array([0, 0, 200, 0, 200, 200, 0, 200]))

        # test clip_object_border = True
        transform = Resize(scale=(150, 100), clip_object_border=True)
        results = transform(self.data_info2.copy())
        self.assertEqual(results['img'].shape[:2], (100, 150))
        self.assertEqual(results['img_shape'], (100, 150))
        self.assertEqual(results['scale'], (150, 100))
        self.assertEqual(results['scale_factor'], (150. / 300., 100. / 200.))
        self.assertEqual(results['keep_ratio'], False)
        np.testing.assert_allclose(results['gt_bboxes'],
                                   np.array([[0, 0, 150, 100]]))
        self.assertTrue(results['gt_polygons'][0].shape == (8, ))
        self._assert_same_coords(results['gt_polygons'][0],
                                 [0, 0, 150, 0, 150, 100, 0, 100])

        # test clip_object_border = True and polygon outside image
        transform = Resize(scale=(150, 100), clip_object_border=True)
        results = transform(self.data_info3)
        self.assertEqual(results['img'].shape[:2], (100, 150))
        self.assertEqual(results['img_shape'], (100, 150))
        self.assertEqual(results['scale'], (150, 100))
        self.assertEqual(results['scale_factor'], (150. / 300., 100. / 200.))
        self.assertEqual(results['keep_ratio'], False)
        # A polygon that is clipped away is replaced by zeros.
        np.testing.assert_allclose(results['gt_polygons'][0], np.zeros(8))
        # The bbox collapses onto the bottom-right image corner.
        np.testing.assert_allclose(results['gt_bboxes'],
                                   np.array([[150., 100., 150., 100.]]))

    def test_repr(self):
        transform = Resize(scale=(2000, 2000), keep_ratio=True)
        self.assertEqual(
            repr(transform), ('Resize(scale=(2000, 2000), '
                              'scale_factor=None, keep_ratio=True, '
                              'clip_object_border=True), backend=cv2), '
                              'interpolation=bilinear)'))

View File

@ -5,6 +5,33 @@ import numpy as np
import torch
from mmocr.utils import rescale_polygon, rescale_polygons
from mmocr.utils.polygon_utils import crop_polygon
class TestCropPolygon(unittest.TestCase):
    """Tests for ``crop_polygon``."""

    def _assert_same_coords(self, actual, expected):
        """Check that two polygons hold the same coordinate values.

        The cropped polygon is produced by a shapely intersection, so the
        order of the returned values is not assumed here. Only the sorted
        coordinate values are compared.
        """
        np.testing.assert_allclose(
            np.sort(np.asarray(actual, dtype=np.float64).reshape(-1)),
            np.sort(np.asarray(expected, dtype=np.float64).reshape(-1)),
            atol=1e-5)

    def test_crop_polygon(self):
        # polygon cross box
        polygon = np.array([20., -10., 40., 10., 10., 40., -10., 20.])
        crop_box = np.array([0., 0., 60., 60.])
        # The box cuts two corners off the diamond, which leaves six
        # vertices; the closing vertex is not repeated in the output.
        target_poly_cropped = np.array(
            [30., 0., 40., 10., 10., 40., 0., 30., 0., 10., 10., 0.])
        poly_cropped = crop_polygon(polygon, crop_box)
        self._assert_same_coords(poly_cropped, target_poly_cropped)

        # polygon inside box
        polygon = np.array([0., 0., 30., 0., 30., 30., 0., 30.]).reshape(-1, 2)
        crop_box = np.array([0., 0., 60., 60.])
        target_poly_cropped = polygon
        poly_cropped = crop_polygon(polygon, crop_box)
        self._assert_same_coords(poly_cropped, target_poly_cropped)

        # polygon outside box
        polygon = np.array([0., 0., 30., 0., 30., 30., 0., 30.]).reshape(-1, 2)
        crop_box = np.array([80., 80., 90., 90.])
        poly_cropped = crop_polygon(polygon, crop_box)
        self.assertIsNone(poly_cropped)
class TestPolygonUtils(unittest.TestCase):