mirror of https://github.com/open-mmlab/mmocr.git
Fix issue 122 (#130)
* fix #122: textsnake targets adaptation
* fix #122: textsnake targets adaptation
* add unittest
* fix format
* fix textsnake unittest on cpu
* fix unit test coverage
* add unit test
pull/141/head
parent
e7847bd8b9
commit
e29b5b8abe
|
@ -80,19 +80,43 @@ class TextSnakeTargets(BaseTextDetTargets):
|
|||
edge_vec = pad_points[1:] - pad_points[:-1]
|
||||
|
||||
theta_sum = []
|
||||
|
||||
adjacent_vec_theta = []
|
||||
for i, edge_vec1 in enumerate(edge_vec):
|
||||
adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
|
||||
adjacent_edge_vec = edge_vec[adjacent_ind]
|
||||
temp_theta_sum = np.sum(
|
||||
self.vector_angle(edge_vec1, adjacent_edge_vec))
|
||||
temp_adjacent_theta = self.vector_angle(
|
||||
adjacent_edge_vec[0], adjacent_edge_vec[1])
|
||||
theta_sum.append(temp_theta_sum)
|
||||
theta_sum = np.array(theta_sum)
|
||||
head_start, tail_start = np.argsort(theta_sum)[::-1][0:2]
|
||||
adjacent_vec_theta.append(temp_adjacent_theta)
|
||||
theta_sum_score = np.array(theta_sum) / np.pi
|
||||
adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
|
||||
poly_center = np.mean(points, axis=0)
|
||||
edge_dist = np.maximum(
|
||||
norm(pad_points[1:] - poly_center, axis=-1),
|
||||
norm(pad_points[:-1] - poly_center, axis=-1))
|
||||
dist_score = edge_dist / np.max(edge_dist)
|
||||
position_score = np.zeros(len(edge_vec))
|
||||
score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
|
||||
score += 0.35 * dist_score
|
||||
if len(points) % 2 == 0:
|
||||
position_score[(len(score) // 2 - 1)] += 1
|
||||
position_score[-1] += 1
|
||||
score += 0.1 * position_score
|
||||
pad_score = np.concatenate([score, score])
|
||||
score_matrix = np.zeros((len(score), len(score) - 3))
|
||||
x = np.arange(len(score) - 3) / float(len(score) - 4)
|
||||
gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
|
||||
(x - 0.5) / 0.5, 2.) / 2)
|
||||
gaussian = gaussian / np.max(gaussian)
|
||||
for i in range(len(score)):
|
||||
score_matrix[i, :] = score[i] + pad_score[
|
||||
(i + 2):(i + len(score) - 1)] * gaussian * 0.3
|
||||
|
||||
if (abs(head_start - tail_start) < 2
|
||||
or abs(head_start - tail_start) > 12):
|
||||
tail_start = (head_start + len(points) // 2) % len(points)
|
||||
head_start, tail_increment = np.unravel_index(
|
||||
score_matrix.argmax(), score_matrix.shape)
|
||||
tail_start = (head_start + tail_increment + 2) % len(points)
|
||||
head_end = (head_start + 1) % len(points)
|
||||
tail_end = (tail_start + 1) % len(points)
|
||||
|
||||
|
@ -297,16 +321,15 @@ class TextSnakeTargets(BaseTextDetTargets):
|
|||
sin_theta = self.vector_sin(text_direction)
|
||||
cos_theta = self.vector_cos(text_direction)
|
||||
|
||||
pnt_tl = center_line[i] + (top_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
pnt_tr = center_line[i + 1] + (
|
||||
tl = center_line[i] + (top_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
tr = center_line[i + 1] + (
|
||||
top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
pnt_br = center_line[i + 1] + (
|
||||
br = center_line[i + 1] + (
|
||||
bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
|
||||
pnt_bl = center_line[i] + (bot_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
current_center_box = np.vstack([pnt_tl, pnt_tr, pnt_br,
|
||||
pnt_bl]).astype(np.int32)
|
||||
bl = center_line[i] + (bot_line[i] -
|
||||
center_line[i]) * region_shrink_ratio
|
||||
current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)
|
||||
|
||||
cv2.fillPoly(center_region_mask, [current_center_box], color=1)
|
||||
cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
|
||||
|
@ -344,8 +367,15 @@ class TextSnakeTargets(BaseTextDetTargets):
|
|||
assert len(poly) == 1
|
||||
text_instance = [[poly[0][i], poly[0][i + 1]]
|
||||
for i in range(0, len(poly[0]), 2)]
|
||||
polygon_points = np.array(
|
||||
text_instance, dtype=np.int32).reshape(-1, 2)
|
||||
polygon_points = np.array(text_instance).reshape(-1, 2)
|
||||
|
||||
n = len(polygon_points)
|
||||
keep_inds = []
|
||||
for i in range(n):
|
||||
if norm(polygon_points[i] -
|
||||
polygon_points[(i + 1) % n]) > 1e-5:
|
||||
keep_inds.append(i)
|
||||
polygon_points = polygon_points[keep_inds]
|
||||
|
||||
_, _, top_line, bot_line = self.reorder_poly_edge(polygon_points)
|
||||
resampled_top_line, resampled_bot_line = self.resample_sidelines(
|
||||
|
|
|
@ -408,37 +408,42 @@ class RandomCropPolyInstances:
|
|||
region_ends[region_ind])
|
||||
return start, end
|
||||
|
||||
def sample_crop_box(self, img_size, masks):
|
||||
def sample_crop_box(self, img_size, results):
|
||||
"""Generate crop box and make sure not to crop the polygon instances.
|
||||
|
||||
Args:
|
||||
img_size (tuple(int)): The image size.
|
||||
masks (list[list[ndarray]]): The polygon masks.
|
||||
img_size (tuple(int)): The image size (h, w).
|
||||
results (dict): The results dict.
|
||||
"""
|
||||
|
||||
assert isinstance(img_size, tuple)
|
||||
h, w = img_size[:2]
|
||||
|
||||
key_masks = results[self.instance_key].masks
|
||||
x_valid_array = np.ones(w, dtype=np.int32)
|
||||
y_valid_array = np.ones(h, dtype=np.int32)
|
||||
|
||||
selected_mask = masks[np.random.randint(0, len(masks))]
|
||||
selected_mask = key_masks[np.random.randint(0, len(key_masks))]
|
||||
selected_mask = selected_mask[0].reshape((-1, 2)).astype(np.int32)
|
||||
max_x_start = max(np.min(selected_mask[:, 0]) - 2, 0)
|
||||
min_x_end = min(np.max(selected_mask[:, 0]) + 3, w - 1)
|
||||
max_y_start = max(np.min(selected_mask[:, 1]) - 2, 0)
|
||||
min_y_end = min(np.max(selected_mask[:, 1]) + 3, h - 1)
|
||||
|
||||
for mask in masks:
|
||||
assert len(mask) == 1
|
||||
mask = mask[0].reshape((-1, 2)).astype(np.int32)
|
||||
clip_x = np.clip(mask[:, 0], 0, w - 1)
|
||||
clip_y = np.clip(mask[:, 1], 0, h - 1)
|
||||
min_x, max_x = np.min(clip_x), np.max(clip_x)
|
||||
min_y, max_y = np.min(clip_y), np.max(clip_y)
|
||||
for key in results.get('mask_fields', []):
|
||||
if len(results[key].masks) == 0:
|
||||
continue
|
||||
masks = results[key].masks
|
||||
for mask in masks:
|
||||
assert len(mask) == 1
|
||||
mask = mask[0].reshape((-1, 2)).astype(np.int32)
|
||||
clip_x = np.clip(mask[:, 0], 0, w - 1)
|
||||
clip_y = np.clip(mask[:, 1], 0, h - 1)
|
||||
min_x, max_x = np.min(clip_x), np.max(clip_x)
|
||||
min_y, max_y = np.min(clip_y), np.max(clip_y)
|
||||
|
||||
x_valid_array[min_x - 2:max_x + 3] = 0
|
||||
y_valid_array[min_y - 2:max_y + 3] = 0
|
||||
x_valid_array[min_x - 2:max_x + 3] = 0
|
||||
y_valid_array[min_y - 2:max_y + 3] = 0
|
||||
|
||||
min_w = int(w * self.min_side_ratio)
|
||||
min_h = int(h * self.min_side_ratio)
|
||||
|
@ -458,9 +463,10 @@ class RandomCropPolyInstances:
|
|||
return img[bbox[1]:bbox[3], bbox[0]:bbox[2]]
|
||||
|
||||
def __call__(self, results):
|
||||
if len(results[self.instance_key].masks) < 1:
|
||||
return results
|
||||
if np.random.random_sample() < self.crop_ratio:
|
||||
crop_box = self.sample_crop_box(results['img'].shape,
|
||||
results[self.instance_key].masks)
|
||||
crop_box = self.sample_crop_box(results['img'].shape, results)
|
||||
results['crop_region'] = crop_box
|
||||
img = self.crop_img(results['img'], crop_box)
|
||||
results['img'] = img
|
||||
|
|
|
@ -145,12 +145,20 @@ def test_gen_textsnake_targets(mock_show_feature):
|
|||
assert np.allclose(target_generator.resample_step, 4.0)
|
||||
assert np.allclose(target_generator.center_region_shrink_ratio, 0.3)
|
||||
|
||||
# test find_head_tail
|
||||
# test find_head_tail for quadrangle
|
||||
polygon = np.array([[1.0, 1.0], [5.0, 1.0], [5.0, 3.0], [1.0, 3.0]])
|
||||
head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0)
|
||||
assert np.allclose(head_inds, [3, 0])
|
||||
assert np.allclose(tail_inds, [1, 2])
|
||||
|
||||
# test find_head_tail for polygon
|
||||
polygon = np.array([[0., 10.], [3., 3.], [10., 0.], [17., 3.], [20., 10.],
|
||||
[15., 10.], [13.5, 6.5], [10., 5.], [6.5, 6.5],
|
||||
[5., 10.]])
|
||||
head_inds, tail_inds = target_generator.find_head_tail(polygon, 2.0)
|
||||
assert np.allclose(head_inds, [9, 0])
|
||||
assert np.allclose(tail_inds, [4, 5])
|
||||
|
||||
# test generate_text_region_mask
|
||||
img_size = (3, 10)
|
||||
text_polys = [[np.array([0, 0, 1, 0, 1, 1, 0, 1])],
|
||||
|
|
|
@ -5,7 +5,7 @@ import torchvision.transforms as TF
|
|||
from PIL import Image
|
||||
|
||||
import mmocr.datasets.pipelines.transforms as transforms
|
||||
from mmdet.core import BitmapMasks
|
||||
from mmdet.core import BitmapMasks, PolygonMasks
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
|
||||
|
@ -164,3 +164,97 @@ def test_affine_jitter():
|
|||
output2 = affine_jitter(results)
|
||||
|
||||
assert np.allclose(np.array(output1), output2['img'])
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
|
||||
@mock.patch('%s.transforms.np.random.randint' % __name__)
|
||||
def test_random_crop_poly_instances(mock_randint, mock_sample):
|
||||
results = {}
|
||||
img = np.zeros((30, 30, 3))
|
||||
poly_masks = PolygonMasks([[
|
||||
np.array([5., 5., 25., 5., 25., 10., 5., 10.])
|
||||
], [np.array([5., 20., 25., 20., 25., 25., 5., 25.])]], 30, 30)
|
||||
results['img'] = img
|
||||
results['gt_masks'] = poly_masks
|
||||
results['gt_masks_ignore'] = PolygonMasks([], 30, 30)
|
||||
results['mask_fields'] = ['gt_masks', 'gt_masks_ignore']
|
||||
results['gt_labels'] = [1, 1]
|
||||
rcpi = transforms.RandomCropPolyInstances(
|
||||
instance_key='gt_masks', crop_ratio=1.0, min_side_ratio=0.3)
|
||||
|
||||
# test sample_crop_box(img_size, results)
|
||||
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 0, 0, 15]
|
||||
crop_box = rcpi.sample_crop_box((30, 30), results)
|
||||
assert np.allclose(np.array(crop_box), np.array([0, 0, 30, 15]))
|
||||
|
||||
# test __call__
|
||||
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 15, 0, 30]
|
||||
mock_sample.side_effect = [0.1]
|
||||
output = rcpi(results)
|
||||
target = np.array([5., 5., 25., 5., 25., 10., 5., 10.])
|
||||
assert len(output['gt_masks']) == 1
|
||||
assert len(output['gt_masks_ignore']) == 0
|
||||
assert np.allclose(output['gt_masks'].masks[0][0], target)
|
||||
assert output['img'].shape == (15, 30, 3)
|
||||
|
||||
# test __call__ with blank instance_key masks
|
||||
mock_randint.side_effect = [0, 0, 0, 0, 30, 0, 15, 0, 30]
|
||||
mock_sample.side_effect = [0.1]
|
||||
rcpi = transforms.RandomCropPolyInstances(
|
||||
instance_key='gt_masks_ignore', crop_ratio=1.0, min_side_ratio=0.3)
|
||||
results['img'] = img
|
||||
results['gt_masks'] = poly_masks
|
||||
output = rcpi(results)
|
||||
assert len(output['gt_masks']) == 2
|
||||
assert np.allclose(output['gt_masks'].masks[0][0], poly_masks.masks[0][0])
|
||||
assert np.allclose(output['gt_masks'].masks[1][0], poly_masks.masks[1][0])
|
||||
assert output['img'].shape == (30, 30, 3)
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
|
||||
def test_random_rotate_poly_instances(mock_sample):
|
||||
results = {}
|
||||
img = np.zeros((30, 30, 3))
|
||||
poly_masks = PolygonMasks(
|
||||
[[np.array([10., 10., 20., 10., 20., 20., 10., 20.])]], 30, 30)
|
||||
results['img'] = img
|
||||
results['gt_masks'] = poly_masks
|
||||
results['mask_fields'] = ['gt_masks']
|
||||
rrpi = transforms.RandomRotatePolyInstances(rotate_ratio=1.0, max_angle=90)
|
||||
|
||||
mock_sample.side_effect = [0., 1.]
|
||||
output = rrpi(results)
|
||||
assert np.allclose(output['gt_masks'].masks[0][0],
|
||||
np.array([10., 20., 10., 10., 20., 10., 20., 20.]))
|
||||
assert output['img'].shape == (30, 30, 3)
|
||||
|
||||
|
||||
@mock.patch('%s.transforms.np.random.random_sample' % __name__)
|
||||
def test_square_resize_pad(mock_sample):
|
||||
results = {}
|
||||
img = np.zeros((15, 30, 3))
|
||||
polygon = np.array([10., 5., 20., 5., 20., 10., 10., 10.])
|
||||
poly_masks = PolygonMasks([[polygon]], 15, 30)
|
||||
results['img'] = img
|
||||
results['gt_masks'] = poly_masks
|
||||
results['mask_fields'] = ['gt_masks']
|
||||
srp = transforms.SquareResizePad(target_size=40, pad_ratio=0.5)
|
||||
|
||||
# test resize with padding
|
||||
mock_sample.side_effect = [0.]
|
||||
output = srp(results)
|
||||
target = 4. / 3 * polygon
|
||||
target[1::2] += 10.
|
||||
assert np.allclose(output['gt_masks'].masks[0][0], target)
|
||||
assert output['img'].shape == (40, 40, 3)
|
||||
|
||||
# test resize to square without padding
|
||||
results['img'] = img
|
||||
results['gt_masks'] = poly_masks
|
||||
mock_sample.side_effect = [1.]
|
||||
output = srp(results)
|
||||
target = polygon.copy()
|
||||
target[::2] *= 4. / 3
|
||||
target[1::2] *= 8. / 3
|
||||
assert np.allclose(output['gt_masks'].masks[0][0], target)
|
||||
assert output['img'].shape == (40, 40, 3)
|
||||
|
|
|
@ -329,7 +329,7 @@ def test_textsnake(cfg_file):
|
|||
from mmocr.models import build_detector
|
||||
detector = build_detector(model)
|
||||
detector = detector.cuda()
|
||||
input_shape = (1, 3, 64, 64)
|
||||
input_shape = (1, 3, 224, 224)
|
||||
num_kernels = 1
|
||||
mm_inputs = _demo_mm_inputs(num_kernels, input_shape)
|
||||
|
||||
|
@ -355,14 +355,18 @@ def test_textsnake(cfg_file):
|
|||
gt_cos_map=gt_cos_map)
|
||||
assert isinstance(losses, dict)
|
||||
|
||||
# Test forward test
|
||||
# with torch.no_grad():
|
||||
# img_list = [g[None, :] for g in imgs]
|
||||
# batch_results = []
|
||||
# for one_img, one_meta in zip(img_list, img_metas):
|
||||
# result = detector.forward([one_img], [[one_meta]],
|
||||
# return_loss=False)
|
||||
# batch_results.append(result)
|
||||
# Test forward test get_boundary
|
||||
maps = torch.zeros((1, 5, 224, 224), dtype=torch.float)
|
||||
maps[:, 0:2, :, :] = -10.
|
||||
maps[:, 0, 60:100, 12:212] = 10.
|
||||
maps[:, 1, 70:90, 22:202] = 10.
|
||||
maps[:, 2, 70:90, 22:202] = 0.
|
||||
maps[:, 3, 70:90, 22:202] = 1.
|
||||
maps[:, 4, 70:90, 22:202] = 10.
|
||||
|
||||
one_meta = img_metas[0]
|
||||
result = detector.bbox_head.get_boundary(maps, [one_meta], False)
|
||||
assert len(result) == 1
|
||||
|
||||
# Test show result
|
||||
results = {'boundary_result': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}
|
||||
|
|
Loading…
Reference in New Issue