mirror of https://github.com/open-mmlab/mmocr.git
Fix RandomCrop
parent
8ac235677e
commit
a4952a6dd6
|
@ -753,27 +753,31 @@ class RandomCrop(BaseTransform):
|
|||
self.min_side_ratio = min_side_ratio
|
||||
|
||||
def _sample_valid_start_end(self, valid_array: np.ndarray, min_len: int,
|
||||
max_start: int,
|
||||
min_end: int) -> Tuple[int, int]:
|
||||
"""Sample a start and end point on a given axis that contains at least
|
||||
one polygon.
|
||||
max_start_idx: int,
|
||||
min_end_idx: int) -> Tuple[int, int]:
|
||||
"""Sample a start and end idx on a given axis that contains at least
|
||||
one polygon. There should be at least one intact polygon bounded by
|
||||
max_start_idx and min_end_idx.
|
||||
|
||||
Args:
|
||||
valid_array (ndarray): Valid area, where 0 means no text area,
|
||||
where 1 means text area.
|
||||
min_len (int): Minimum distance between the two sampling points.
|
||||
max_start (int): Start sampling point maximum start position.
|
||||
min_end (int): End sampling point minimum end position.
|
||||
valid_array (ndarray): A 0-1 mask 1D array indicating valid regions
|
||||
on the axis. 0 indicates text regions which are not allowed to
|
||||
be sampled from.
|
||||
min_len (int): Minimum distance between two start and end points.
|
||||
max_start_idx (int): The maximum start index.
|
||||
min_end_idx (int): The minimum end index.
|
||||
|
||||
Returns:
|
||||
tuple(int, int): Start and end point on a given axis.
|
||||
tuple(int, int): Start and end index on a given axis, where
|
||||
0 <= start < max_start_idx and
|
||||
min_end_idx <= end < len(valid_array).
|
||||
"""
|
||||
assert isinstance(min_len, int)
|
||||
assert len(valid_array) > min_len
|
||||
|
||||
start_array = valid_array.copy()
|
||||
max_start = min(len(start_array) - min_len, max_start)
|
||||
start_array[max_start:] = 0
|
||||
max_start_idx = min(len(start_array) - min_len, max_start_idx)
|
||||
start_array[max_start_idx:] = 0
|
||||
start_array[0] = 1
|
||||
diff_array = np.hstack([0, start_array]) - np.hstack([start_array, 0])
|
||||
region_starts = np.where(diff_array < 0)[0]
|
||||
|
@ -783,23 +787,26 @@ class RandomCrop(BaseTransform):
|
|||
region_ends[region_ind])
|
||||
|
||||
end_array = valid_array.copy()
|
||||
min_end = max(start + min_len, min_end)
|
||||
end_array[:min_end] = 0
|
||||
min_end_idx = max(start + min_len, min_end_idx)
|
||||
end_array[:min_end_idx] = 0
|
||||
end_array[-1] = 1
|
||||
diff_array = np.hstack([0, end_array]) - np.hstack([end_array, 0])
|
||||
region_starts = np.where(diff_array < 0)[0]
|
||||
region_ends = np.where(diff_array > 0)[0]
|
||||
region_ind = np.random.randint(0, len(region_starts))
|
||||
# Note that end index will never be region_ends[region_ind]
|
||||
# and therefore end index is always in range [0, w+1]
|
||||
end = np.random.randint(region_starts[region_ind],
|
||||
region_ends[region_ind])
|
||||
return start, end
|
||||
|
||||
def _sample_crop_box(self, img_size: Tuple[int, int],
|
||||
results: Dict) -> np.ndarray:
|
||||
"""Generate crop box and make sure not to crop the polygon instances.
|
||||
"""Generate crop box which only contains intact polygon instances with
|
||||
the number >= 1.
|
||||
|
||||
Args:
|
||||
img_size (tuple(int)): The image size (h, w).
|
||||
img_size (tuple(int, int)): The image size (h, w).
|
||||
results (dict): The results dict.
|
||||
|
||||
Returns:
|
||||
|
@ -808,27 +815,34 @@ class RandomCrop(BaseTransform):
|
|||
assert isinstance(img_size, tuple)
|
||||
h, w = img_size[:2]
|
||||
|
||||
# Crop box can be represented by any integer numbers in
|
||||
# range [0, w] and [0, h]
|
||||
x_valid_array = np.ones(w + 1, dtype=np.int32)
|
||||
y_valid_array = np.ones(h + 1, dtype=np.int32)
|
||||
|
||||
polygons = results['gt_polygons']
|
||||
x_valid_array = np.ones(w, dtype=np.int32)
|
||||
y_valid_array = np.ones(h, dtype=np.int32)
|
||||
selected_poly = polygons[np.random.randint(0, len(polygons))]
|
||||
selected_poly = selected_poly.reshape((-1, 2)).astype(np.int32)
|
||||
max_x_start = max(np.min(selected_poly[:, 0]), 0)
|
||||
min_x_end = min(np.max(selected_poly[:, 0]), w - 1)
|
||||
max_y_start = max(np.min(selected_poly[:, 1]), 0)
|
||||
min_y_end = min(np.max(selected_poly[:, 1]), h - 1)
|
||||
|
||||
if len(results['gt_polygons']) > 0:
|
||||
polygons = results['gt_polygons']
|
||||
for polygon in polygons:
|
||||
polygon = polygon.reshape((-1, 2)).astype(np.int32)
|
||||
clip_x = np.clip(polygon[:, 0], 0, w - 1)
|
||||
clip_y = np.clip(polygon[:, 1], 0, h - 1)
|
||||
min_x, max_x = np.min(clip_x), np.max(clip_x)
|
||||
min_y, max_y = np.min(clip_y), np.max(clip_y)
|
||||
# Randomly select a polygon that must be inside
|
||||
# the cropped region
|
||||
kept_poly_idx = np.random.randint(0, len(polygons))
|
||||
for i, polygon in enumerate(polygons):
|
||||
polygon = polygon.reshape((-1, 2))
|
||||
|
||||
x_valid_array[min_x:max_x] = 0
|
||||
y_valid_array[min_y:max_y] = 0
|
||||
clip_x = np.clip(polygon[:, 0], 0, w)
|
||||
clip_y = np.clip(polygon[:, 1], 0, h)
|
||||
min_x = np.floor(np.min(clip_x)).astype(np.int32)
|
||||
min_y = np.floor(np.min(clip_y)).astype(np.int32)
|
||||
max_x = np.ceil(np.max(clip_x)).astype(np.int32)
|
||||
max_y = np.ceil(np.max(clip_y)).astype(np.int32)
|
||||
|
||||
x_valid_array[min_x:max_x] = 0
|
||||
y_valid_array[min_y:max_y] = 0
|
||||
|
||||
if i == kept_poly_idx:
|
||||
max_x_start = min_x
|
||||
min_x_end = max_x
|
||||
max_y_start = min_y
|
||||
min_y_end = max_y
|
||||
|
||||
min_w = int(w * self.min_side_ratio)
|
||||
min_h = int(h * self.min_side_ratio)
|
||||
|
@ -883,14 +897,13 @@ class RandomCrop(BaseTransform):
|
|||
valid_texts = []
|
||||
texts = results['gt_texts']
|
||||
|
||||
# for polygons
|
||||
polys = results['gt_polygons']
|
||||
valid_polys = []
|
||||
for idx, poly in enumerate(polys):
|
||||
poly = poly.reshape(-1, 2)
|
||||
poly = poly - (crop_x, crop_y)
|
||||
if is_poly_inside_rect(poly.flatten(), [0, 0, crop_w, crop_h]):
|
||||
valid_polys.append(poly.flatten())
|
||||
poly = (poly - (crop_x, crop_y)).flatten()
|
||||
if is_poly_inside_rect(poly, [0, 0, crop_w, crop_h]):
|
||||
valid_polys.append(poly)
|
||||
valid_labels.append(labels[idx])
|
||||
valid_ignored.append(ignored[idx])
|
||||
if 'gt_texts' in results:
|
||||
|
@ -900,16 +913,7 @@ class RandomCrop(BaseTransform):
|
|||
results['gt_ignored'] = np.array(valid_ignored, dtype=bool)
|
||||
if 'gt_texts' in results:
|
||||
results['gt_texts'] = valid_texts
|
||||
# for bboxes
|
||||
bboxes = results['gt_bboxes']
|
||||
valid_bboxes = []
|
||||
for bbox in bboxes:
|
||||
bbox = bbox.reshape(-1, 2)
|
||||
bbox = bbox - (crop_x, crop_y)
|
||||
if is_poly_inside_rect(
|
||||
bbox2poly(bbox.flatten()), [0, 0, crop_w, crop_h]):
|
||||
valid_bboxes.append(bbox.flatten())
|
||||
assert (len(valid_bboxes) == len(valid_polys))
|
||||
valid_bboxes = [poly2bbox(poly) for poly in results['gt_polygons']]
|
||||
results['gt_bboxes'] = np.array(valid_bboxes).astype(np.float32)
|
||||
|
||||
return results
|
||||
|
|
|
@ -273,8 +273,7 @@ def is_poly_inside_rect(poly: ArrayLike, rect: np.ndarray) -> bool:
|
|||
|
||||
poly = poly2shapely(poly)
|
||||
rect = poly2shapely(bbox2poly(rect))
|
||||
inter = poly.intersection(rect)
|
||||
return inter.area == poly.area
|
||||
return rect.contains(poly)
|
||||
|
||||
|
||||
def offset_polygon(poly: ArrayLike, distance: float) -> ArrayLike:
|
||||
|
|
|
@ -211,25 +211,60 @@ class TestRandomCrop(unittest.TestCase):
|
|||
|
||||
@mock.patch('mmocr.datasets.pipelines.processing.np.random.randint')
|
||||
def test_sample_crop_box(self, mock_randint):
|
||||
|
||||
def rand_min(low, high):
|
||||
return low
|
||||
|
||||
trans = RandomCrop(min_side_ratio=0.3)
|
||||
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15]
|
||||
mock_randint.side_effect = rand_min
|
||||
crop_box = trans._sample_crop_box((30, 30), self.data_info.copy())
|
||||
assert np.allclose(np.array(crop_box), np.array([0, 0, 30, 15]))
|
||||
assert np.allclose(np.array(crop_box), np.array([0, 0, 25, 10]))
|
||||
|
||||
def rand_max(low, high):
|
||||
return high - 1
|
||||
|
||||
mock_randint.side_effect = rand_max
|
||||
crop_box = trans._sample_crop_box((30, 30), self.data_info.copy())
|
||||
assert np.allclose(np.array(crop_box), np.array([4, 19, 30, 30]))
|
||||
|
||||
@mock.patch('mmocr.datasets.pipelines.processing.np.random.randint')
|
||||
def test_transform(self, mock_randint):
|
||||
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15]
|
||||
|
||||
def rand_min(low, high):
|
||||
return low
|
||||
|
||||
# mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15]
|
||||
mock_randint.side_effect = rand_min
|
||||
trans = RandomCrop(min_side_ratio=0.3)
|
||||
polygon_target = np.array([5., 5., 25., 5., 25., 10., 5., 10.])
|
||||
bbox_target = np.array([[5., 5., 25., 10.]])
|
||||
results = trans(self.data_info)
|
||||
|
||||
self.assertEqual(results['img'].shape, (15, 30, 3))
|
||||
self.assertEqual(results['img_shape'], (15, 30))
|
||||
self.assertEqual(results['gt_bboxes'].all(), bbox_target.all())
|
||||
self.assertEqual(results['img'].shape, (10, 25, 3))
|
||||
self.assertEqual(results['img_shape'], (10, 25))
|
||||
self.assertTrue(np.allclose(results['gt_bboxes'], bbox_target))
|
||||
self.assertEqual(results['gt_bboxes'].shape, (1, 4))
|
||||
self.assertTrue(len(results['gt_polygons']) == 1)
|
||||
self.assertEqual(results['gt_polygons'][0].all(), polygon_target.all())
|
||||
self.assertEqual(len(results['gt_polygons']), 1)
|
||||
self.assertTrue(np.allclose(results['gt_polygons'][0], polygon_target))
|
||||
self.assertEqual(results['gt_bboxes_labels'][0], 0)
|
||||
self.assertEqual(results['gt_ignored'][0], True)
|
||||
self.assertEqual(results['gt_texts'][0], 'text1')
|
||||
|
||||
def rand_max(low, high):
|
||||
return high - 1
|
||||
|
||||
mock_randint.side_effect = rand_max
|
||||
trans = RandomCrop(min_side_ratio=0.3)
|
||||
polygon_target = np.array([1, 1, 21, 1, 21, 6, 1, 6])
|
||||
bbox_target = np.array([[1, 1, 21, 6]])
|
||||
results = trans(self.data_info)
|
||||
|
||||
self.assertEqual(results['img'].shape, (6, 21, 3))
|
||||
self.assertEqual(results['img_shape'], (6, 21))
|
||||
self.assertTrue(np.allclose(results['gt_bboxes'], bbox_target))
|
||||
self.assertEqual(results['gt_bboxes'].shape, (1, 4))
|
||||
self.assertEqual(len(results['gt_polygons']), 1)
|
||||
self.assertTrue(np.allclose(results['gt_polygons'][0], polygon_target))
|
||||
self.assertEqual(results['gt_bboxes_labels'][0], 0)
|
||||
self.assertEqual(results['gt_ignored'][0], True)
|
||||
self.assertEqual(results['gt_texts'][0], 'text1')
|
||||
|
|
Loading…
Reference in New Issue