Fix issue 122 (#130)

* fix #122: textsnake targets adaptation

* fix #122: textsnake targets adaptation

* add unittest

* fix format

* fix textsnake unittest on cpu

* fix unit test coverage

* add unit test
pull/141/head
Jianyong Chen 2021-04-28 22:02:57 -05:00 committed by GitHub
parent e7847bd8b9
commit e29b5b8abe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 184 additions and 42 deletions

View File

@ -80,19 +80,43 @@ class TextSnakeTargets(BaseTextDetTargets):
edge_vec = pad_points[1:] - pad_points[:-1]
theta_sum = []
adjacent_vec_theta = []
for i, edge_vec1 in enumerate(edge_vec):
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
adjacent_edge_vec = edge_vec[adjacent_ind]
temp_theta_sum = np.sum(
self.vector_angle(edge_vec1, adjacent_edge_vec))
temp_adjacent_theta = self.vector_angle(
adjacent_edge_vec[0], adjacent_edge_vec[1])
theta_sum.append(temp_theta_sum)
theta_sum = np.array(theta_sum)
head_start, tail_start = np.argsort(theta_sum)[::-1][0:2]
adjacent_vec_theta.append(temp_adjacent_theta)
theta_sum_score = np.array(theta_sum) / np.pi
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
poly_center = np.mean(points, axis=0)
edge_dist = np.maximum(
norm(pad_points[1:] - poly_center, axis=-1),
norm(pad_points[:-1] - poly_center, axis=-1))
dist_score = edge_dist / np.max(edge_dist)
position_score = np.zeros(len(edge_vec))
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
score += 0.35 * dist_score
if len(points) % 2 == 0:
position_score[(len(score) // 2 - 1)] += 1
position_score[-1] += 1
score += 0.1 * position_score
pad_score = np.concatenate([score, score])
score_matrix = np.zeros((len(score), len(score) - 3))
x = np.arange(len(score) - 3) / float(len(score) - 4)
gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
(x - 0.5) / 0.5, 2.) / 2)
gaussian = gaussian / np.max(gaussian)
for i in range(len(score)):
score_matrix[i, :] = score[i] + pad_score[
(i + 2):(i + len(score) - 1)] * gaussian * 0.3
if (abs(head_start - tail_start) < 2
or abs(head_start - tail_start) > 12):
tail_start = (head_start + len(points) // 2) % len(points)
head_start, tail_increment = np.unravel_index(
score_matrix.argmax(), score_matrix.shape)
tail_start = (head_start + tail_increment + 2) % len(points)
head_end = (head_start + 1) % len(points)
tail_end = (tail_start + 1) % len(points)
@ -297,16 +321,15 @@ class TextSnakeTargets(BaseTextDetTargets):
sin_theta = self.vector_sin(text_direction)
cos_theta = self.vector_cos(text_direction)
pnt_tl = center_line[i] + (top_line[i] -
center_line[i]) * region_shrink_ratio
pnt_tr = center_line[i + 1] + (
tl = center_line[i] + (top_line[i] -
center_line[i]) * region_shrink_ratio
tr = center_line[i + 1] + (
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
pnt_br = center_line[i + 1] + (
br = center_line[i + 1] + (
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
pnt_bl = center_line[i] + (bot_line[i] -
center_line[i]) * region_shrink_ratio
current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br,
pnt_bl]).astype(np.int32)
bl = center_line[i] + (bot_line[i] -
center_line[i]) * region_shrink_ratio
current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
@ -344,8 +367,15 @@ class TextSnakeTargets(BaseTextDetTargets):
assert len(poly) == 1
text_instance = [[poly[0][i], poly[0][i + 1]]
for i in range(0, len(poly[0]), 2)]
polygon_points = np.array(
text_instance, dtype=np.int32).reshape(-1, 2)
polygon_points = np.array(text_instance).reshape(-1, 2)
n = len(polygon_points)
keep_inds = []
for i in range(n):
if norm(polygon_points[i] -
polygon_points[(i + 1) % n]) > 1e-5:
keep_inds.append(i)
polygon_points = polygon_points[keep_inds]
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
resampled_top_line, resampled_bot_line = self.resample_sidelines(

View File

@ -408,37 +408,42 @@ class RandomCropPolyInstances:
region_ends[region_ind])
return start, end
def sample_crop_box(self, img_size, masks):
def sample_crop_box(self, img_size, results):
"""Generate crop box and make sure not to crop the polygon instances.
Args:
img_size (tuple(int)): The image size.
masks (list[list[ndarray]]): The polygon masks.
img_size (tuple(int)): The image size (h, w).
results (dict): The results dict.
"""
assert isinstance(img_size, tuple)
h, w = img_size[:2]
key_masks = results[self.instance_key].masks
x_valid_array = np.ones(w, dtype=np.int32)
y_valid_array = np.ones(h, dtype=np.int32)
selected_mask = masks[np.random.randint(0, len(masks))]
selected_mask = key_masks[np.random.randint(0, len(key_masks))]
selected_mask = selected_mask[0].reshape((-1, 2)).astype(np.int32)
max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
for mask in masks:
assert len(mask) == 1
mask = mask[0].reshape((-1, 2)).astype(np.int32)
clip_x = np.clip(mask[:, 0], 0, w - 1)
clip_y = np.clip(mask[:, 1], 0, h - 1)
min_x, max_x = np.min(clip_x), np.max(clip_x)
min_y, max_y = np.min(clip_y), np.max(clip_y)
for key in results.get('mask_fields', []):
if len(results[key].masks) == 0:
continue
masks = results[key].masks
for mask in masks:
assert len(mask) == 1
mask = mask[0].reshape((-1, 2)).astype(np.int32)
clip_x = np.clip(mask[:, 0], 0, w - 1)
clip_y = np.clip(mask[:, 1], 0, h - 1)
min_x, max_x = np.min(clip_x), np.max(clip_x)
min_y, max_y = np.min(clip_y), np.max(clip_y)
x_valid_array[min_x - 2:max_x + 3] = 0
y_valid_array[min_y - 2:max_y + 3] = 0
x_valid_array[min_x - 2:max_x + 3] = 0
y_valid_array[min_y - 2:max_y + 3] = 0
min_w = int(w * self.min_side_ratio)
min_h = int(h * self.min_side_ratio)
@ -458,9 +463,10 @@ class RandomCropPolyInstances:
return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
def __call__(self, results):
if len(results[self.instance_key].masks) < 1:
return results
if np.random.random_sample() < self.crop_ratio:
crop_box = self.sample_crop_box(results['img'].shape,
results[self.instance_key].masks)
crop_box = self.sample_crop_box(results['img'].shape, results)
results['crop_region'] = crop_box
img = self.crop_img(results['img'], crop_box)
results['img'] = img

View File

@ -145,12 +145,20 @@ def test_gen_textsnake_targets(mock_show_feature):
assert np.allclose(target_generator.resample_step, 4.0)
assert np.allclose(target_generator.center_region_shrink_ratio, 0.3)
# test find_head_tail
# test find_head_tail for quadrangle
polygon = np.array([[1.0, 1.0], [5.0, 1.0], [5.0, 3.0], [1.0, 3.0]])
head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0)
assert np.allclose(head_inds, [3, 0])
assert np.allclose(tail_inds, [1, 2])
# test find_head_tail for polygon
polygon = np.array([[0., 10.], [3., 3.], [10., 0.], [17., 3.], [20., 10.],
[15., 10.], [13.5, 6.5], [10., 5.], [6.5, 6.5],
[5., 10.]])
head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0)
assert np.allclose(head_inds, [9, 0])
assert np.allclose(tail_inds, [4, 5])
# test generate_text_region_mask
img_size = (3, 10)
text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])],

View File

@ -5,7 +5,7 @@ import torchvision.transforms as TF
from PIL import Image
import mmocr.datasets.pipelines.transforms as transforms
from mmdet.core import BitmapMasks
from mmdet.core import BitmapMasks, PolygonMasks
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
@ -164,3 +164,97 @@ def test_affine_jitter():
output2 = affine_jitter(results)
assert np.allclose(np.array(output1), output2['img'])
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
@mock.patch('%s.transforms.np.random.randint' % __name__)
def test_random_crop_poly_instances(mock_randint, mock_sample):
results = {}
img = np.zeros((30, 30, 3))
poly_masks = PolygonMasks([[
np.array([5., 5., 25., 5., 25., 10., 5., 10.])
], [np.array([5., 20., 25., 20., 25., 25., 5., 25.])]], 30, 30)
results['img'] = img
results['gt_masks'] = poly_masks
results['gt_masks_ignore'] = PolygonMasks([], 30, 30)
results['mask_fields'] = ['gt_masks', 'gt_masks_ignore']
results['gt_labels'] = [1, 1]
rcpi = transforms.RandomCropPolyInstances(
instance_key='gt_masks', crop_ratio=1.0, min_side_ratio=0.3)
# test sample_crop_box(img_size, results)
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15]
crop_box = rcpi.sample_crop_box((30, 30), results)
assert np.allclose(np.array(crop_box), np.array([0, 0, 30, 15]))
# test __call__
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 15, 0, 30]
mock_sample.side_effect = [0.1]
output = rcpi(results)
target = np.array([5., 5., 25., 5., 25., 10., 5., 10.])
assert len(output['gt_masks']) == 1
assert len(output['gt_masks_ignore']) == 0
assert np.allclose(output['gt_masks'].masks[0][0], target)
assert output['img'].shape == (15, 30, 3)
# test __call__ with blank instace_key masks
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 15, 0, 30]
mock_sample.side_effect = [0.1]
rcpi = transforms.RandomCropPolyInstances(
instance_key='gt_masks_ignore', crop_ratio=1.0, min_side_ratio=0.3)
results['img'] = img
results['gt_masks'] = poly_masks
output = rcpi(results)
assert len(output['gt_masks']) == 2
assert np.allclose(output['gt_masks'].masks[0][0], poly_masks.masks[0][0])
assert np.allclose(output['gt_masks'].masks[1][0], poly_masks.masks[1][0])
assert output['img'].shape == (30, 30, 3)
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
def test_random_rotate_poly_instances(mock_sample):
results = {}
img = np.zeros((30, 30, 3))
poly_masks = PolygonMasks(
[[np.array([10., 10., 20., 10., 20., 20., 10., 20.])]], 30, 30)
results['img'] = img
results['gt_masks'] = poly_masks
results['mask_fields'] = ['gt_masks']
rrpi = transforms.RandomRotatePolyInstances(rotate_ratio=1.0, max_angle=90)
mock_sample.side_effect = [0., 1.]
output = rrpi(results)
assert np.allclose(output['gt_masks'].masks[0][0],
np.array([10., 20., 10., 10., 20., 10., 20., 20.]))
assert output['img'].shape == (30, 30, 3)
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
def test_square_resize_pad(mock_sample):
results = {}
img = np.zeros((15, 30, 3))
polygon = np.array([10., 5., 20., 5., 20., 10., 10., 10.])
poly_masks = PolygonMasks([[polygon]], 15, 30)
results['img'] = img
results['gt_masks'] = poly_masks
results['mask_fields'] = ['gt_masks']
srp = transforms.SquareResizePad(target_size=40, pad_ratio=0.5)
# test resize with padding
mock_sample.side_effect = [0.]
output = srp(results)
target = 4. / 3 * polygon
target[1::2] += 10.
assert np.allclose(output['gt_masks'].masks[0][0], target)
assert output['img'].shape == (40, 40, 3)
# test resize to square without padding
results['img'] = img
results['gt_masks'] = poly_masks
mock_sample.side_effect = [1.]
output = srp(results)
target = polygon.copy()
target[::2] *= 4. / 3
target[1::2] *= 8. / 3
assert np.allclose(output['gt_masks'].masks[0][0], target)
assert output['img'].shape == (40, 40, 3)

View File

@ -329,7 +329,7 @@ def test_textsnake(cfg_file):
from mmocr.models import build_detector
detector = build_detector(model)
detector = detector.cuda()
input_shape = (1, 3, 64, 64)
input_shape = (1, 3, 224, 224)
num_kernels = 1
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
@ -355,14 +355,18 @@ def test_textsnake(cfg_file):
gt_cos_map=gt_cos_map)
assert isinstance(losses, dict)
# Test forward test
# with torch.no_grad():
# img_list = [g[None, :] for g in imgs]
# batch_results = []
# for one_img, one_meta in zip(img_list, img_metas):
# result = detector.forward([one_img], [[one_meta]],
# return_loss=False)
# batch_results.append(result)
# Test forward test get_boundary
maps = torch.zeros((1, 5, 224, 224), dtype=torch.float)
maps[:, 0:2, :, :] = -10.
maps[:, 0, 60:100, 12:212] = 10.
maps[:, 1, 70:90, 22:202] = 10.
maps[:, 2, 70:90, 22:202] = 0.
maps[:, 3, 70:90, 22:202] = 1.
maps[:, 4, 70:90, 22:202] = 10.
one_meta = img_metas[0]
result = detector.bbox_head.get_boundary(maps, [one_meta], False)
assert len(result) == 1
# Test show result
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}