194 lines
6.5 KiB
Python
194 lines
6.5 KiB
Python
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
import shapely
|
|
from shapely.geometry import Polygon
|
|
import numpy as np
|
|
from collections import defaultdict
|
|
import operator
|
|
import editdistance
|
|
|
|
|
|
def strQ2B(ustring):
    """Convert full-width characters in ``ustring`` to their half-width form.

    The ideographic space (U+3000) becomes an ASCII space, and code points in
    the range U+FF01..U+FF5E are shifted down by 0xFEE0. Every other character
    is copied through unchanged.
    """
    converted = []
    for ch in ustring:
        code = ord(ch)
        if code == 0x3000:
            # Full-width (ideographic) space -> ASCII space.
            code = 0x20
        elif 0xFF01 <= code <= 0xFF5E:
            # Full-width printable range -> the matching ASCII character.
            code -= 0xFEE0
        converted.append(chr(code))
    return "".join(converted)
|
|
|
|
|
|
def polygon_from_str(polygon_points):
    """
    Build a shapely geometry from the coordinates of a gt or dt line.

    The flat sequence of eight values is viewed as four (x, y) vertices, a
    polygon is created from them, and the convex hull of that polygon is
    returned.
    """
    vertices = np.array(polygon_points).reshape((4, 2))
    return Polygon(vertices).convex_hull
|
|
|
|
|
|
def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.

    Args:
        poly1: First polygon; it must provide ``intersects``, ``intersection``
            and ``area`` the way shapely geometries do.
        poly2: Second polygon, same requirements.

    Returns:
        The IoU as a float. 0 is returned when the polygons do not intersect,
        when their union has no area, or when shapely raises a
        ``TopologicalError`` while computing the intersection.
    """
    # This predicate is fast and lets us skip the costly intersection
    # computation for disjoint pairs.
    if not poly1.intersects(poly2):
        return 0

    try:
        inter_area = poly1.intersection(poly2).area
        union_area = poly1.area + poly2.area - inter_area
    except shapely.errors.TopologicalError:
        # shapely.errors is the canonical home of this exception; the old
        # shapely.geos alias is not available in every shapely release.
        print('shapely.geos.TopologicalError occurred, iou set to 0')
        return 0

    # Degenerate (zero-area) shapes can still intersect each other; report no
    # overlap for them instead of dividing by zero.
    if union_area <= 0:
        return 0
    return float(inter_area) / union_area
|
|
|
|
|
|
def ed(str1, str2):
    """Return the edit distance between ``str1`` and ``str2``.

    Thin wrapper around ``editdistance.eval``.
    """
    distance = editdistance.eval(str1, str2)
    return distance
|
|
|
|
|
|
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    """Evaluate end-to-end OCR results against ground truth and print metrics.

    Every file in ``gt_dir`` is paired with the file of the same name in
    ``res_dir``; a missing result file is treated as "no detections". Lines
    are tab-separated:

    * ground truth: 8 polygon coordinates, an ignore flag, optional text;
    * result: 8 polygon coordinates, optional text.

    Lines with too few fields are skipped on both sides. Boxes are matched
    greedily in descending IoU order (IoU >= 0.5, each box used at most once).
    Matched transcriptions are compared after full-width to half-width
    normalisation. A ground-truth box whose ignore flag is not ``'0'`` still
    consumes the detection matched to it, but neither of them is scored.

    Args:
        gt_dir: Directory holding the ground-truth files.
        res_dir: Directory holding the result files.
        ignore_blank: If True, spaces are removed from matched transcriptions
            before they are compared.

    Returns:
        None. Character accuracy, average edit distances, precision, recall
        and F-measure are printed to stdout.
    """
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for val_name in val_names:
        # ---- ground truth -------------------------------------------------
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            assert (len(parts) < 11)
            # Field 8 is the ignore flag; the transcription, when present,
            # is the last field.
            text = '' if len(parts) == 9 else parts[-1]
            gts.append(parts[:8] + [text])
            ignore_masks.append(parts[8])

        # ---- detections ---------------------------------------------------
        val_path = os.path.join(res_dir, val_name)
        dt_lines = []
        if os.path.exists(val_path):
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f]
        dts = []
        for line in dt_lines:
            parts = line.split("\t")
            # Skip blank or truncated lines, mirroring the ground-truth rule
            # above, instead of crashing while parsing the coordinates.
            if len(parts) < 8:
                continue
            assert (len(parts) < 10), "line error: {}".format(line)
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)

        # ---- IoU for every (gt, dt) pair ----------------------------------
        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)
        # Detection polygons do not depend on the ground-truth box, so build
        # them once per file (and only when there is something to match).
        if gts:
            dt_polys = [
                polygon_from_str([float(coor) for coor in dt[0:8]])
                for dt in dts
            ]
        else:
            dt_polys = []
        all_ious = {}
        for index_gt, gt in enumerate(gts):
            gt_poly = polygon_from_str([float(coor) for coor in gt[0:8]])
            for index_dt, dt_poly in enumerate(dt_polys):
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        # Best overlaps first; the sort is stable, so ties keep scan order.
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)

        # ---- matched gt and dt --------------------------------------------
        for (index_gt, index_dt), _ in sorted_ious:
            if gt_match[index_gt] or dt_match[index_dt]:
                continue
            gt_match[index_gt] = True
            dt_match[index_dt] = True
            if ignore_masks[index_gt] != '0':
                # Ignored region: the pair is consumed but not scored.
                continue
            gt_str = strQ2B(gts[index_gt][8])
            dt_str = strQ2B(dts[index_dt][8])
            if ignore_blank:
                gt_str = gt_str.replace(" ", "")
                dt_str = dt_str.replace(" ", "")
            ed_sum += ed(gt_str, dt_str)
            num_gt_chars += len(gt_str)
            if gt_str == dt_str:
                hit += 1
            gt_count += 1
            dt_count += 1

        # NOTE(review): unmatched boxes below are scored on their raw text
        # (no strQ2B / ignore_blank normalisation), unlike matched pairs.
        # Kept as-is so reported numbers do not change; confirm the intent.

        # ---- unmatched dt -------------------------------------------------
        for tindex, matched in enumerate(dt_match):
            if not matched:
                ed_sum += ed(dts[tindex][8], '')
                dt_count += 1

        # ---- unmatched gt -------------------------------------------------
        for tindex, matched in enumerate(gt_match):
            if not matched and ignore_masks[tindex] == '0':
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, '')
                num_gt_chars += len(gt_str)
                gt_count += 1

    # eps keeps the ratios defined when a count is zero.
    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    # An empty gt_dir has no images to average over; report 0 rather than
    # raising ZeroDivisionError.
    avg_edit_dist_img = ed_sum / len(val_names) if val_names else 0.0
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")
|
|
|
|
|
|
if __name__ == '__main__':
    # Both directories are required. Exit with a usage message (non-zero
    # status) instead of an IndexError when they are missing; extra
    # arguments are still tolerated.
    if len(sys.argv) < 3:
        sys.exit("usage: python3 {} gt_dir res_dir".format(sys.argv[0]))
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)
|