mirror of https://github.com/open-mmlab/mmocr.git
Support ocr box stitching (#290)
* ocr box stitching * rename a variable * add a demo * move functions to box_util * add test
pull/304/head
parent
a86ab115e4
commit
9582f93c70
|
@ -22,3 +22,4 @@ python demo/ocr_image_demo.py demo/demo_text_det.jpg demo/output.jpg
|
|||
|
||||
1. If `--imshow` is specified, the demo will also show the image with OpenCV.
|
||||
2. The `ocr_image_demo.py` script only supports GPU and so the `--device` parameter cannot take cpu as an argument.
|
||||
3. (Experimental) By specifying `--ocr-in-lines`, the ocr results will be grouped and presented in lines.
|
||||
|
|
|
@ -6,6 +6,7 @@ from mmdet.apis import init_detector
|
|||
from mmocr.apis.inference import model_inference
|
||||
from mmocr.core.visualize import det_recog_show_result
|
||||
from mmocr.datasets.pipelines.crop import crop_img
|
||||
from mmocr.utils.img_util import stitch_boxes_into_lines
|
||||
|
||||
|
||||
def det_and_recog_inference(args, det_model, recog_model):
|
||||
|
@ -111,6 +112,10 @@ def main():
|
|||
'--imshow',
|
||||
action='store_true',
|
||||
help='Whether show image with OpenCV.')
|
||||
parser.add_argument(
|
||||
'--ocr-in-lines',
|
||||
action='store_true',
|
||||
help='Whether group ocr results in lines.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.device == 'cpu':
|
||||
|
@ -141,6 +146,16 @@ def main():
|
|||
ensure_ascii=False,
|
||||
indent=4)
|
||||
|
||||
if args.ocr_in_lines:
|
||||
res = det_recog_result['result']
|
||||
res = stitch_boxes_into_lines(res, 10, 0.5)
|
||||
det_recog_result['result'] = res
|
||||
mmcv.dump(
|
||||
det_recog_result,
|
||||
args.out_file + '.line.json',
|
||||
ensure_ascii=False,
|
||||
indent=4)
|
||||
|
||||
img = det_recog_show_result(args.img, det_recog_result)
|
||||
mmcv.imwrite(img, args.out_file)
|
||||
if args.imshow:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
from .box_util import is_on_same_line, stitch_boxes_into_lines
|
||||
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list,
|
||||
is_none_or_type, is_type_list, valid_boundary)
|
||||
from .collect_env import collect_env
|
||||
|
@ -14,5 +15,5 @@ __all__ = [
|
|||
'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type',
|
||||
'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter',
|
||||
'drop_orientation', 'convert_annotations', 'is_not_png', 'list_to_file',
|
||||
'list_from_file'
|
||||
'list_from_file', 'is_on_same_line', 'stitch_boxes_into_lines'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
def is_on_same_line(box_a, box_b, min_y_overlap_ratio=0.8):
    """Tell whether two boxes lie on the same text line.

    Only the y coordinates (odd positions of the flat ``[x1, y1, x2, y2,
    ...]`` layout) are inspected. The boxes are considered to be on the same
    line when their vertical extents intersect and the length of that
    intersection is at least ``min_y_overlap_ratio`` times the height of at
    least one of the two boxes.

    Args:
        box_a (list), box_b (list): The two bounding boxes to compare.
        min_y_overlap_ratio (float | None): Minimum ratio between the
            vertical overlap and a box height. If None, any vertical
            intersection is enough.

    Returns:
        The bool flag indicating if they are on the same line.
    """
    top_a, bottom_a = np.min(box_a[1::2]), np.max(box_a[1::2])
    top_b, bottom_b = np.min(box_b[1::2]), np.max(box_b[1::2])

    # Name the vertical ranges by position: "upper" is the one that starts
    # first along the y axis, whichever argument it came from.
    if top_a > top_b:
        upper_top, upper_bottom = top_b, bottom_b
        lower_top, lower_bottom = top_a, bottom_a
    else:
        upper_top, upper_bottom = top_a, bottom_a
        lower_top, lower_bottom = top_b, bottom_b

    # The lower box starts below the end of the upper one: no intersection.
    if not (lower_top <= upper_bottom):
        return False

    if min_y_overlap_ratio is None:
        return True

    # The intersection starts at the lower box's top and ends at whichever
    # bottom edge comes first.
    overlap = min(lower_bottom, upper_bottom) - lower_top
    upper_threshold = (upper_bottom - upper_top) * min_y_overlap_ratio
    lower_threshold = (lower_bottom - lower_top) * min_y_overlap_ratio
    return overlap >= upper_threshold or overlap >= lower_threshold
|
||||
|
||||
|
||||
def _box_extent(box):
    """Return ``(x_min, y_min, x_max, y_max)`` of a flat ``[x1, y1, ...]``
    box."""
    xs, ys = box[::2], box[1::2]
    return np.min(xs), np.min(ys), np.max(xs), np.max(ys)


def _to_builtin(value):
    """Convert a NumPy scalar into the equivalent built-in Python number."""
    return value.item() if isinstance(value, np.generic) else value


def stitch_boxes_into_lines(boxes, max_x_dist=10, min_y_overlap_ratio=0.8):
    """Stitch fragmented boxes of words into lines.

    Boxes are visited from left to right. Starting from each box that has not
    been consumed yet, every following box that is on the same line as the
    current rightmost box of the chain (see ``is_on_same_line``) is appended
    to the chain. The chain is then cut wherever the horizontal gap between
    two consecutive boxes exceeds ``max_x_dist``, and each resulting group is
    merged into one axis-aligned box whose text is the space-joined texts of
    its members.

    Note: part of its logic is inspired by @Johndirr
    (https://github.com/faustomorales/keras-ocr/issues/22)

    Args:
        boxes (list): List of ocr results to be stitched. Each item is a dict
            with a ``'box'`` entry (flat ``[x1, y1, x2, y2, ...]``
            coordinates) and a ``'text'`` entry.
        max_x_dist (int): The maximum horizontal distance between the closest
            edges of neighboring boxes in the same line
        min_y_overlap_ratio (float): The minimum vertical overlapping ratio
            allowed for any pairs of neighboring boxes in the same line

    Returns:
        merged_boxes(list[dict]): List of merged boxes and texts. Each item
        holds ``'text'`` and ``'box'`` (the four corners of the enclosing
        rectangle, as built-in Python numbers). If ``boxes`` has at most one
        item it is returned unchanged.
    """
    if len(boxes) <= 1:
        return boxes

    # Compute every box's extent once and sort by the left edge, so that each
    # line is discovered starting from its leftmost box. ``sorted`` is
    # stable, hence ties keep their input order.
    extents_and_items = sorted(
        ((_box_extent(item['box']), item) for item in boxes),
        key=lambda pair: pair[0][0])
    extents = [extent for extent, _ in extents_and_items]
    x_sorted_boxes = [item for _, item in extents_and_items]
    num_boxes = len(x_sorted_boxes)

    merged_boxes = []
    # indexes of boxes which are already parts of other lines
    skip_idxs = set()

    for start in range(num_boxes):
        if start in skip_idxs:
            continue

        # Grow the line to the right; each candidate is compared against the
        # rightmost box collected so far, not against the first one.
        rightmost_box_idx = start
        line = [start]
        for cand in range(start + 1, num_boxes):
            if cand in skip_idxs:
                continue
            if is_on_same_line(x_sorted_boxes[rightmost_box_idx]['box'],
                               x_sorted_boxes[cand]['box'],
                               min_y_overlap_ratio):
                line.append(cand)
                skip_idxs.add(cand)
                rightmost_box_idx = cand

        # Split the line wherever the gap between the right edge of a box
        # and the left edge of its successor is greater than max_x_dist.
        groups = [[line[0]]]
        for prev_idx, curr_idx in zip(line, line[1:]):
            gap = extents[curr_idx][0] - extents[prev_idx][2]
            if gap > max_x_dist:
                groups.append([])
            groups[-1].append(curr_idx)

        # Merge every group into a single enclosing rectangle.
        for group in groups:
            x_min = _to_builtin(min(extents[idx][0] for idx in group))
            y_min = _to_builtin(min(extents[idx][1] for idx in group))
            x_max = _to_builtin(max(extents[idx][2] for idx in group))
            y_max = _to_builtin(max(extents[idx][3] for idx in group))
            merged_boxes.append({
                'text':
                ' '.join(x_sorted_boxes[idx]['text'] for idx in group),
                'box':
                [x_min, y_min, x_max, y_min, x_max, y_max, x_min, y_max],
            })

    return merged_boxes
|
|
@ -0,0 +1,47 @@
|
|||
from mmocr.utils import is_on_same_line, stitch_boxes_into_lines
|
||||
|
||||
|
||||
def test_box_on_line():
    """Check ``is_on_same_line`` on axis-aligned and irregular boxes."""
    min_ratio = 0.5

    # Axis-aligned unit-height boxes shifted down by 0.5 and 0.8.
    unit_box = [0, 0, 1, 0, 1, 1, 0, 1]
    shifted_half = [2, 0.5, 3, 0.5, 3, 1.5, 2, 1.5]
    shifted_most = [4, 0.8, 5, 0.8, 5, 1.8, 4, 1.8]

    # An irregular quadrilateral spanning y in [0, 2], compared against
    # unit-height boxes starting at y = 1.5 and y = 1.6.
    irregular = [0, 0, 1, 1, 1, 2, 0, 1]
    overlaps_half = [2, 1.5, 3, 1.5, 3, 2.5, 2, 2.5]
    overlaps_less = [2, 1.6, 3, 1.6, 3, 2.6, 2, 2.6]

    cases = [
        (unit_box, shifted_half, True),
        (unit_box, shifted_most, False),
        (irregular, overlaps_half, True),
        (irregular, overlaps_less, False),
    ]
    for box_a, box_b, expected in cases:
        assert bool(is_on_same_line(box_a, box_b, min_ratio)) == expected
|
||||
|
||||
|
||||
def test_stitch_boxes_into_lines():
    """Check that word boxes are grouped and merged into the right lines."""
    coords = [
        # regular boxes
        [0, 0, 1, 0, 1, 1, 0, 1],
        [2, 0.5, 3, 0.5, 3, 1.5, 2, 1.5],
        [3, 1.2, 4, 1.2, 4, 2.2, 3, 2.2],
        [5, 0.5, 6, 0.5, 6, 1.5, 5, 1.5],
        # irregular box
        [6, 1.5, 7, 1.25, 7, 1.75, 6, 1.75],
    ]
    raw_input = [{
        'box': box,
        'text': str(idx)
    } for idx, box in enumerate(coords)]

    result = stitch_boxes_into_lines(raw_input, 1, 0.5)

    # Expected lines: [0, 1], [2], [3, 4]
    # - boxes 0, 1, 3 and 4 are chained on the same line, but box 3 starts
    #   2 pixels after box 1 ends, which exceeds max_x_dist=1, so the chain
    #   is split there.
    # - boxes 3 and 4 stay together because their vertical overlap is
    #   >= 0.5 * the y-axis length of box 4.
    expected_result = [
        dict(box=[0, 0, 3, 0, 3, 1.5, 0, 1.5], text='0 1'),
        dict(box=[3, 1.2, 4, 1.2, 4, 2.2, 3, 2.2], text='2'),
        dict(box=[5, 0.5, 7, 0.5, 7, 1.75, 5, 1.75], text='3 4'),
    ]

    def left_edge(item):
        return item['box'][0]

    # The order of the returned lines is not part of the contract here, so
    # both lists are compared after ordering by the left edge.
    assert sorted(result, key=left_edge) == sorted(
        expected_result, key=left_edge)
|
Loading…
Reference in New Issue