From e29b5b8abe2a71466ee196d16cc106aec77d588b Mon Sep 17 00:00:00 2001 From: Jianyong Chen <46100303+HolyCrap96@users.noreply.github.com> Date: Wed, 28 Apr 2021 22:02:57 -0500 Subject: [PATCH] Fix issue 122 (#130) * fix #122: textsnake targets adaptation * fix #122: textsnake targets adaptation * add unittest * fix format * fix textsnake unittest on cpu * fix unit test coverage * add unit test --- .../textdet_targets/textsnake_targets.py | 62 ++++++++---- mmocr/datasets/pipelines/transforms.py | 36 ++++--- tests/test_dataset/test_textdet_targets.py | 10 +- tests/test_dataset/test_transforms.py | 96 ++++++++++++++++++- tests/test_models/test_detector.py | 22 +++-- 5 files changed, 184 insertions(+), 42 deletions(-) diff --git a/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py index 73f4597c..8d10216e 100644 --- a/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py +++ b/mmocr/datasets/pipelines/textdet_targets/textsnake_targets.py @@ -80,19 +80,43 @@ class TextSnakeTargets(BaseTextDetTargets): edge_vec = pad_points[1:] - pad_points[:-1] theta_sum = [] - + adjacent_vec_theta = [] for i, edge_vec1 in enumerate(edge_vec): adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]] adjacent_edge_vec = edge_vec[adjacent_ind] temp_theta_sum = np.sum( self.vector_angle(edge_vec1, adjacent_edge_vec)) + temp_adjacent_theta = self.vector_angle( + adjacent_edge_vec[0], adjacent_edge_vec[1]) theta_sum.append(temp_theta_sum) - theta_sum = np.array(theta_sum) - head_start, tail_start = np.argsort(theta_sum)[::-1][0:2] + adjacent_vec_theta.append(temp_adjacent_theta) + theta_sum_score = np.array(theta_sum) / np.pi + adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi + poly_center = np.mean(points, axis=0) + edge_dist = np.maximum( + norm(pad_points[1:] - poly_center, axis=-1), + norm(pad_points[:-1] - poly_center, axis=-1)) + dist_score = edge_dist / np.max(edge_dist) + position_score = np.zeros(len(edge_vec)) + score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score + score += 0.35 * dist_score + if len(points) % 2 == 0: + position_score[(len(score) // 2 - 1)] += 1 + position_score[-1] += 1 + score += 0.1 * position_score + pad_score = np.concatenate([score, score]) + score_matrix = np.zeros((len(score), len(score) - 3)) + x = np.arange(len(score) - 3) / float(len(score) - 4) + gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power( + (x - 0.5) / 0.5, 2.) / 2) + gaussian = gaussian / np.max(gaussian) + for i in range(len(score)): + score_matrix[i, :] = score[i] + pad_score[ + (i + 2):(i + len(score) - 1)] * gaussian * 0.3 - if (abs(head_start - tail_start) < 2 - or abs(head_start - tail_start) > 12): - tail_start = (head_start + len(points) // 2) % len(points) + head_start, tail_increment = np.unravel_index( + score_matrix.argmax(), score_matrix.shape) + tail_start = (head_start + tail_increment + 2) % len(points) head_end = (head_start + 1) % len(points) tail_end = (tail_start + 1) % len(points) @@ -297,16 +321,15 @@ class TextSnakeTargets(BaseTextDetTargets): sin_theta = self.vector_sin(text_direction) cos_theta = self.vector_cos(text_direction) - pnt_tl = center_line[i] + (top_line[i] - - center_line[i]) * region_shrink_ratio - pnt_tr = center_line[i + 1] + ( + tl = center_line[i] + (top_line[i] - + center_line[i]) * region_shrink_ratio + tr = center_line[i + 1] + ( top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio - pnt_br = center_line[i + 1] + ( + br = center_line[i + 1] + ( bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio - pnt_bl = center_line[i] + (bot_line[i] - - center_line[i]) * region_shrink_ratio - current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br, - pnt_bl]).astype(np.int32) + bl = center_line[i] + (bot_line[i] - + center_line[i]) * region_shrink_ratio + current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32) cv2.fillPoly(center_region_mask, [current_center_box], color=1) cv2.fillPoly(sin_map, [current_center_box], color=sin_theta) @@ -344,8 +367,15 @@ class TextSnakeTargets(BaseTextDetTargets): assert len(poly) == 1 text_instance = [[poly[0][i], poly[0][i + 1]] for i in range(0, len(poly[0]), 2)] - polygon_points = np.array( - text_instance, dtype=np.int32).reshape(-1, 2) + polygon_points = np.array(text_instance).reshape(-1, 2) + + n = len(polygon_points) + keep_inds = [] + for i in range(n): + if norm(polygon_points[i] - + polygon_points[(i + 1) % n]) > 1e-5: + keep_inds.append(i) + polygon_points = polygon_points[keep_inds] _, _, top_line, bot_line = self.reorder_poly_edge(polygon_points) resampled_top_line, resampled_bot_line = self.resample_sidelines( diff --git a/mmocr/datasets/pipelines/transforms.py b/mmocr/datasets/pipelines/transforms.py index 537c107e..be01d659 100644 --- a/mmocr/datasets/pipelines/transforms.py +++ b/mmocr/datasets/pipelines/transforms.py @@ -408,37 +408,42 @@ class RandomCropPolyInstances: region_ends[region_ind]) return start, end - def sample_crop_box(self, img_size, masks): + def sample_crop_box(self, img_size, results): """Generate crop box and make sure not to crop the polygon instances. Args: - img_size (tuple(int)): The image size. - masks (list[list[ndarray]]): The polygon masks. + img_size (tuple(int)): The image size (h, w). + results (dict): The results dict. """ assert isinstance(img_size, tuple) h, w = img_size[:2] + key_masks = results[self.instance_key].masks x_valid_array = np.ones(w, dtype=np.int32) y_valid_array = np.ones(h, dtype=np.int32) - selected_mask = masks[np.random.randint(0, len(masks))] + selected_mask = key_masks[np.random.randint(0, len(key_masks))] selected_mask = selected_mask[0].reshape((-1, 2)).astype(np.int32) max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0) min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1) max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0) min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1) - for mask in masks: - assert len(mask) == 1 - mask = mask[0].reshape((-1, 2)).astype(np.int32) - clip_x = np.clip(mask[:, 0], 0, w - 1) - clip_y = np.clip(mask[:, 1], 0, h - 1) - min_x, max_x = np.min(clip_x), np.max(clip_x) - min_y, max_y = np.min(clip_y), np.max(clip_y) + for key in results.get('mask_fields', []): + if len(results[key].masks) == 0: + continue + masks = results[key].masks + for mask in masks: + assert len(mask) == 1 + mask = mask[0].reshape((-1, 2)).astype(np.int32) + clip_x = np.clip(mask[:, 0], 0, w - 1) + clip_y = np.clip(mask[:, 1], 0, h - 1) + min_x, max_x = np.min(clip_x), np.max(clip_x) + min_y, max_y = np.min(clip_y), np.max(clip_y) - x_valid_array[min_x - 2:max_x + 3] = 0 - y_valid_array[min_y - 2:max_y + 3] = 0 + x_valid_array[min_x - 2:max_x + 3] = 0 + y_valid_array[min_y - 2:max_y + 3] = 0 min_w = int(w * self.min_side_ratio) min_h = int(h * self.min_side_ratio) @@ -458,9 +463,10 @@ class RandomCropPolyInstances: return img[bbox[1]:bbox[3], bbox[0]:bbox[2]] def __call__(self, results): + if len(results[self.instance_key].masks) < 1: + return results if np.random.random_sample() < self.crop_ratio: - crop_box = self.sample_crop_box(results['img'].shape, - results[self.instance_key].masks) + crop_box = self.sample_crop_box(results['img'].shape, results) results['crop_region'] = crop_box img = self.crop_img(results['img'], crop_box) results['img'] = img diff --git a/tests/test_dataset/test_textdet_targets.py b/tests/test_dataset/test_textdet_targets.py index e06353a6..6fa53984 100644 --- a/tests/test_dataset/test_textdet_targets.py +++ b/tests/test_dataset/test_textdet_targets.py @@ -145,12 +145,20 @@ def test_gen_textsnake_targets(mock_show_feature): assert np.allclose(target_generator.resample_step, 4.0) assert np.allclose(target_generator.center_region_shrink_ratio, 0.3) - # test find_head_tail + # test find_head_tail for quadrangle polygon = np.array([[1.0, 1.0], [5.0, 1.0], [5.0, 3.0], [1.0, 3.0]]) head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0) assert np.allclose(head_inds, [3, 0]) assert np.allclose(tail_inds, [1, 2]) + # test find_head_tail for polygon + polygon = np.array([[0., 10.], [3., 3.], [10., 0.], [17., 3.], [20., 10.], + [15., 10.], [13.5, 6.5], [10., 5.], [6.5, 6.5], + [5., 10.]]) + head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0) + assert np.allclose(head_inds, [9, 0]) + assert np.allclose(tail_inds, [4, 5]) + # test generate_text_region_mask img_size = (3, 10) text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])], diff --git a/tests/test_dataset/test_transforms.py b/tests/test_dataset/test_transforms.py index a50b4d0a..a25f308a 100644 --- a/tests/test_dataset/test_transforms.py +++ b/tests/test_dataset/test_transforms.py @@ -5,7 +5,7 @@ import torchvision.transforms as TF from PIL import Image import mmocr.datasets.pipelines.transforms as transforms -from mmdet.core import BitmapMasks +from mmdet.core import BitmapMasks, PolygonMasks @mock.patch('%s.transforms.np.random.random_sample' % __name__) @@ -164,3 +164,97 @@ def test_affine_jitter(): output2 = affine_jitter(results) assert np.allclose(np.array(output1), output2['img']) + + +@mock.patch('%s.transforms.np.random.random_sample' % __name__) +@mock.patch('%s.transforms.np.random.randint' % __name__) +def test_random_crop_poly_instances(mock_randint, mock_sample): + results = {} + img = np.zeros((30, 30, 3)) + poly_masks = PolygonMasks([[ + np.array([5., 5., 25., 5., 25., 10., 5., 10.]) + ], [np.array([5., 20., 25., 20., 25., 25., 5., 25.])]], 30, 30) + results['img'] = img + results['gt_masks'] = poly_masks + results['gt_masks_ignore'] = PolygonMasks([], 30, 30) + results['mask_fields'] = ['gt_masks', 'gt_masks_ignore'] + results['gt_labels'] = [1, 1] + rcpi = transforms.RandomCropPolyInstances( + instance_key='gt_masks', crop_ratio=1.0, min_side_ratio=0.3) + + # test sample_crop_box(img_size, results) + mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15] + crop_box = rcpi.sample_crop_box((30, 30), results) + assert np.allclose(np.array(crop_box), np.array([0, 0, 30, 15])) + + # test __call__ + mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 15, 0, 30] + mock_sample.side_effect = [0.1] + output = rcpi(results) + target = np.array([5., 5., 25., 5., 25., 10., 5., 10.]) + assert len(output['gt_masks']) == 1 + assert len(output['gt_masks_ignore']) == 0 + assert np.allclose(output['gt_masks'].masks[0][0], target) + assert output['img'].shape == (15, 30, 3) + + # test __call__ with blank instace_key masks + mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 15, 0, 30] + mock_sample.side_effect = [0.1] + rcpi = transforms.RandomCropPolyInstances( + instance_key='gt_masks_ignore', crop_ratio=1.0, min_side_ratio=0.3) + results['img'] = img + results['gt_masks'] = poly_masks + output = rcpi(results) + assert len(output['gt_masks']) == 2 + assert np.allclose(output['gt_masks'].masks[0][0], poly_masks.masks[0][0]) + assert np.allclose(output['gt_masks'].masks[1][0], poly_masks.masks[1][0]) + assert output['img'].shape == (30, 30, 3) + + +@mock.patch('%s.transforms.np.random.random_sample' % __name__) +def test_random_rotate_poly_instances(mock_sample): + results = {} + img = np.zeros((30, 30, 3)) + poly_masks = PolygonMasks( + [[np.array([10., 10., 20., 10., 20., 20., 10., 20.])]], 30, 30) + results['img'] = img + results['gt_masks'] = poly_masks + results['mask_fields'] = ['gt_masks'] + rrpi = transforms.RandomRotatePolyInstances(rotate_ratio=1.0, max_angle=90) + + mock_sample.side_effect = [0., 1.] + output = rrpi(results) + assert np.allclose(output['gt_masks'].masks[0][0], + np.array([10., 20., 10., 10., 20., 10., 20., 20.])) + assert output['img'].shape == (30, 30, 3) + + +@mock.patch('%s.transforms.np.random.random_sample' % __name__) +def test_square_resize_pad(mock_sample): + results = {} + img = np.zeros((15, 30, 3)) + polygon = np.array([10., 5., 20., 5., 20., 10., 10., 10.]) + poly_masks = PolygonMasks([[polygon]], 15, 30) + results['img'] = img + results['gt_masks'] = poly_masks + results['mask_fields'] = ['gt_masks'] + srp = transforms.SquareResizePad(target_size=40, pad_ratio=0.5) + + # test resize with padding + mock_sample.side_effect = [0.] + output = srp(results) + target = 4. / 3 * polygon + target[1::2] += 10. + assert np.allclose(output['gt_masks'].masks[0][0], target) + assert output['img'].shape == (40, 40, 3) + + # test resize to square without padding + results['img'] = img + results['gt_masks'] = poly_masks + mock_sample.side_effect = [1.] + output = srp(results) + target = polygon.copy() + target[::2] *= 4. / 3 + target[1::2] *= 8. / 3 + assert np.allclose(output['gt_masks'].masks[0][0], target) + assert output['img'].shape == (40, 40, 3) diff --git a/tests/test_models/test_detector.py b/tests/test_models/test_detector.py index a5a255c1..87d9d0b2 100644 --- a/tests/test_models/test_detector.py +++ b/tests/test_models/test_detector.py @@ -329,7 +329,7 @@ def test_textsnake(cfg_file): from mmocr.models import build_detector detector = build_detector(model) detector = detector.cuda() - input_shape = (1, 3, 64, 64) + input_shape = (1, 3, 224, 224) num_kernels = 1 mm_inputs = _demo_mm_inputs(num_kernels, input_shape) @@ -355,14 +355,18 @@ def test_textsnake(cfg_file): gt_cos_map=gt_cos_map) assert isinstance(losses, dict) - # Test forward test - # with torch.no_grad(): - # img_list = [g[None, :] for g in imgs] - # batch_results = [] - # for one_img, one_meta in zip(img_list, img_metas): - # result = detector.forward([one_img], [[one_meta]], - # return_loss=False) - # batch_results.append(result) + # Test forward test get_boundary + maps = torch.zeros((1, 5, 224, 224), dtype=torch.float) + maps[:, 0:2, :, :] = -10. + maps[:, 0, 60:100, 12:212] = 10. + maps[:, 1, 70:90, 22:202] = 10. + maps[:, 2, 70:90, 22:202] = 0. + maps[:, 3, 70:90, 22:202] = 1. + maps[:, 4, 70:90, 22:202] = 10. + + one_meta = img_metas[0] + result = detector.bbox_head.get_boundary(maps, [one_meta], False) + assert len(result) == 1 # Test show result results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}