add eval end2end
parent
00a8735073
commit
84b72584f2
|
@ -0,0 +1,94 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def poly_to_string(poly):
|
||||
if len(poly.shape) > 1:
|
||||
poly = np.array(poly).flatten()
|
||||
|
||||
string = "\t".join(str(i) for i in poly)
|
||||
return string
|
||||
|
||||
|
||||
def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
|
||||
if not os.path.exists(label_dir):
|
||||
raise ValueError(f"The file {label_dir} does not exist!")
|
||||
|
||||
assert label_dir != save_dir, "hahahhaha"
|
||||
|
||||
label_file = open(label_dir, 'r')
|
||||
data = label_file.readlines()
|
||||
|
||||
gt_dict = {}
|
||||
|
||||
for line in data:
|
||||
try:
|
||||
tmp = line.split('\t')
|
||||
assert len(tmp) == 2, ""
|
||||
except:
|
||||
tmp = line.strip().split(' ')
|
||||
|
||||
gt_lists = []
|
||||
|
||||
if tmp[0].split('/')[0] is not None:
|
||||
img_path = tmp[0]
|
||||
anno = json.loads(tmp[1])
|
||||
gt_collect = []
|
||||
for dic in anno:
|
||||
#txt = dic['transcription'].replace(' ', '') # ignore blank
|
||||
txt = dic['transcription']
|
||||
if 'score' in dic and float(dic['score']) < 0.5:
|
||||
continue
|
||||
if u'\u3000' in txt: txt = txt.replace(u'\u3000', u' ')
|
||||
#while ' ' in txt:
|
||||
# txt = txt.replace(' ', '')
|
||||
poly = np.array(dic['points']).flatten()
|
||||
if txt == "###":
|
||||
txt_tag = 1 ## ignore 1
|
||||
else:
|
||||
txt_tag = 0
|
||||
if mode == "gt":
|
||||
gt_label = poly_to_string(poly) + "\t" + str(
|
||||
txt_tag) + "\t" + txt + "\n"
|
||||
else:
|
||||
gt_label = poly_to_string(poly) + "\t" + txt + "\n"
|
||||
|
||||
gt_lists.append(gt_label)
|
||||
|
||||
gt_dict[img_path] = gt_lists
|
||||
else:
|
||||
continue
|
||||
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
for img_name in gt_dict.keys():
|
||||
save_name = img_name.split("/")[-1]
|
||||
save_file = os.path.join(save_dir, save_name + ".txt")
|
||||
with open(save_file, "w") as f:
|
||||
f.writelines(gt_dict[img_name])
|
||||
|
||||
print("The convert label saved in {}".format(save_dir))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
ppocr_label_gt = "/paddle/Datasets/chinese/test_set/Label_refine_310_V2.txt"
|
||||
convert_label(ppocr_label_gt, "gt", "./save_gt_310_V2/")
|
||||
|
||||
ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt"
|
||||
convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/")
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def draw_debug_img(html_path):
|
||||
|
||||
err_cnt = 0
|
||||
with open(html_path, 'w') as html:
|
||||
html.write('<html>\n<body>\n')
|
||||
html.write('<table border="1">\n')
|
||||
html.write(
|
||||
"<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
|
||||
)
|
||||
image_list = []
|
||||
path = "./det_results/310_gt/"
|
||||
for i, filename in enumerate(sorted(os.listdir(path))):
|
||||
if filename.endswith("txt"): continue
|
||||
print(filename)
|
||||
# The image path
|
||||
base = "{}/{}".format(path, filename)
|
||||
base_2 = "../PaddleOCR/det_results/ch_PPOCRV2_infer/{}".format(
|
||||
filename)
|
||||
base_3 = "../PaddleOCR/det_results/ch_ppocr_mobile_infer/{}".format(
|
||||
filename)
|
||||
|
||||
html.write("<tr>\n")
|
||||
html.write(f'<td> {filename}\n GT')
|
||||
html.write('<td>GT\n<img src="%s" width=640></td>' % (base))
|
||||
html.write('<td>PPOCRV2\n<img src="%s" width=640></td>' % (base_2))
|
||||
html.write('<td>ppocr_mobile\n<img src="%s" width=640></td>' %
|
||||
(base_3))
|
||||
|
||||
html.write("</tr>\n")
|
||||
html.write('<style>\n')
|
||||
html.write('span {\n')
|
||||
html.write(' color: red;\n')
|
||||
html.write('}\n')
|
||||
html.write('</style>\n')
|
||||
html.write('</table>\n')
|
||||
html.write('</html>\n</body>\n')
|
||||
print("ok")
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
html_path = "sys_visual_iou_310.html"
|
||||
|
||||
draw_debug_img()
|
|
@ -0,0 +1,193 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import shapely
|
||||
from shapely.geometry import Polygon
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
import operator
|
||||
import editdistance
|
||||
|
||||
|
||||
def strQ2B(ustring):
|
||||
rstring = ""
|
||||
for uchar in ustring:
|
||||
inside_code = ord(uchar)
|
||||
if inside_code == 12288:
|
||||
inside_code = 32
|
||||
elif (inside_code >= 65281 and inside_code <= 65374):
|
||||
inside_code -= 65248
|
||||
rstring += chr(inside_code)
|
||||
return rstring
|
||||
|
||||
|
||||
def polygon_from_str(polygon_points):
|
||||
"""
|
||||
Create a shapely polygon object from gt or dt line.
|
||||
"""
|
||||
polygon_points = np.array(polygon_points).reshape(4, 2)
|
||||
polygon = Polygon(polygon_points).convex_hull
|
||||
return polygon
|
||||
|
||||
|
||||
def polygon_iou(poly1, poly2):
|
||||
"""
|
||||
Intersection over union between two shapely polygons.
|
||||
"""
|
||||
if not poly1.intersects(
|
||||
poly2): # this test is fast and can accelerate calculation
|
||||
iou = 0
|
||||
else:
|
||||
try:
|
||||
inter_area = poly1.intersection(poly2).area
|
||||
union_area = poly1.area + poly2.area - inter_area
|
||||
iou = float(inter_area) / union_area
|
||||
except shapely.geos.TopologicalError:
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
print('shapely.geos.TopologicalError occured, iou set to 0')
|
||||
iou = 0
|
||||
return iou
|
||||
|
||||
|
||||
def ed(str1, str2):
|
||||
return editdistance.eval(str1, str2)
|
||||
|
||||
|
||||
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
|
||||
print('start testing...')
|
||||
iou_thresh = 0.5
|
||||
val_names = os.listdir(gt_dir)
|
||||
num_gt_chars = 0
|
||||
gt_count = 0
|
||||
dt_count = 0
|
||||
hit = 0
|
||||
ed_sum = 0
|
||||
|
||||
for i, val_name in enumerate(val_names):
|
||||
with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
|
||||
gt_lines = [o.strip() for o in f.readlines()]
|
||||
gts = []
|
||||
ignore_masks = []
|
||||
for line in gt_lines:
|
||||
parts = line.strip().split('\t')
|
||||
# ignore illegal data
|
||||
if len(parts) < 9:
|
||||
continue
|
||||
assert (len(parts) < 11)
|
||||
if len(parts) == 9:
|
||||
gts.append(parts[:8] + [''])
|
||||
else:
|
||||
gts.append(parts[:8] + [parts[-1]])
|
||||
|
||||
ignore_masks.append(parts[8])
|
||||
|
||||
val_path = os.path.join(res_dir, val_name)
|
||||
if not os.path.exists(val_path):
|
||||
dt_lines = []
|
||||
else:
|
||||
with open(val_path, encoding='utf-8') as f:
|
||||
dt_lines = [o.strip() for o in f.readlines()]
|
||||
dts = []
|
||||
for line in dt_lines:
|
||||
# print(line)
|
||||
parts = line.strip().split("\t")
|
||||
assert (len(parts) < 10), "line error: {}".format(line)
|
||||
if len(parts) == 8:
|
||||
dts.append(parts + [''])
|
||||
else:
|
||||
dts.append(parts)
|
||||
|
||||
dt_match = [False] * len(dts)
|
||||
gt_match = [False] * len(gts)
|
||||
all_ious = defaultdict(tuple)
|
||||
for index_gt, gt in enumerate(gts):
|
||||
gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
|
||||
gt_poly = polygon_from_str(gt_coors)
|
||||
for index_dt, dt in enumerate(dts):
|
||||
dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
|
||||
dt_poly = polygon_from_str(dt_coors)
|
||||
iou = polygon_iou(dt_poly, gt_poly)
|
||||
if iou >= iou_thresh:
|
||||
all_ious[(index_gt, index_dt)] = iou
|
||||
sorted_ious = sorted(
|
||||
all_ious.items(), key=operator.itemgetter(1), reverse=True)
|
||||
sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
|
||||
|
||||
# matched gt and dt
|
||||
for gt_dt_pair in sorted_gt_dt_pairs:
|
||||
index_gt, index_dt = gt_dt_pair
|
||||
if gt_match[index_gt] == False and dt_match[index_dt] == False:
|
||||
gt_match[index_gt] = True
|
||||
dt_match[index_dt] = True
|
||||
if ignore_blank:
|
||||
gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
|
||||
dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
|
||||
else:
|
||||
gt_str = strQ2B(gts[index_gt][8])
|
||||
dt_str = strQ2B(dts[index_dt][8])
|
||||
if ignore_masks[index_gt] == '0':
|
||||
ed_sum += ed(gt_str, dt_str)
|
||||
num_gt_chars += len(gt_str)
|
||||
if gt_str == dt_str:
|
||||
hit += 1
|
||||
gt_count += 1
|
||||
dt_count += 1
|
||||
|
||||
# unmatched dt
|
||||
for tindex, dt_match_flag in enumerate(dt_match):
|
||||
if dt_match_flag == False:
|
||||
dt_str = dts[tindex][8]
|
||||
gt_str = ''
|
||||
ed_sum += ed(dt_str, gt_str)
|
||||
dt_count += 1
|
||||
|
||||
# unmatched gt
|
||||
for tindex, gt_match_flag in enumerate(gt_match):
|
||||
if gt_match_flag == False and ignore_masks[tindex] == '0':
|
||||
dt_str = ''
|
||||
gt_str = gts[tindex][8]
|
||||
ed_sum += ed(gt_str, dt_str)
|
||||
num_gt_chars += len(gt_str)
|
||||
gt_count += 1
|
||||
|
||||
eps = 1e-9
|
||||
print('hit, dt_count, gt_count', hit, dt_count, gt_count)
|
||||
precision = hit / (dt_count + eps)
|
||||
recall = hit / (gt_count + eps)
|
||||
fmeasure = 2.0 * precision * recall / (precision + recall + eps)
|
||||
avg_edit_dist_img = ed_sum / len(val_names)
|
||||
avg_edit_dist_field = ed_sum / (gt_count + eps)
|
||||
character_acc = 1 - ed_sum / (num_gt_chars + eps)
|
||||
|
||||
print('character_acc: %.2f' % (character_acc * 100) + "%")
|
||||
print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
|
||||
print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
|
||||
print('precision: %.2f' % (precision * 100) + "%")
|
||||
print('recall: %.2f' % (recall * 100) + "%")
|
||||
print('fmeasure: %.2f' % (fmeasure * 100) + "%")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# if len(sys.argv) != 3:
|
||||
# print("python3 ocr_e2e_eval.py gt_dir res_dir")
|
||||
# exit(-1)
|
||||
# gt_folder = sys.argv[1]
|
||||
# pred_folder = sys.argv[2]
|
||||
gt_folder = sys.argv[1]
|
||||
pred_folder = sys.argv[2]
|
||||
e2e_eval(gt_folder, pred_folder)
|
|
@ -0,0 +1,69 @@
|
|||
|
||||
# 简介
|
||||
|
||||
`tools/end2end`目录下存放了文本检测+文本识别pipeline串联预测的指标评测代码以及可视化工具。本节介绍文本检测+文本识别的端对端指标评估方式。
|
||||
|
||||
|
||||
## 端对端评测步骤
|
||||
|
||||
**步骤一:**
|
||||
|
||||
运行`tools/infer/predict_system.py`,得到保存的结果:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True
|
||||
```
|
||||
|
||||
文本检测识别可视化图默认保存在`./ch_PP-OCRv2_results/`目录下,预测结果默认保存在`./ch_PP-OCRv2_results/system_results.txt`中,格式如下:
|
||||
```
|
||||
all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}]
|
||||
```
|
||||
|
||||
|
||||
**步骤二:**
|
||||
|
||||
将步骤一保存的数据转换为端对端评测需要的数据格式:
|
||||
修改 `tools/convert_ppocr_label.py`中的代码,convert_label函数中设置输入标签路径,Mode,保存标签路径等,对预测数据的GTlabel和预测结果的label格式进行转换。
|
||||
|
||||
```
|
||||
ppocr_label_gt = "gt_label.txt"
|
||||
convert_label(ppocr_label_gt, "gt", "./save_gt_label/")
|
||||
|
||||
ppocr_label_gt = "./ch_PP-OCRv2_results/system_results.txt"
|
||||
convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/")
|
||||
```
|
||||
|
||||
运行`convert_ppocr_label.py`:
|
||||
```
|
||||
python3 tools/convert_ppocr_label.py
|
||||
```
|
||||
|
||||
得到如下结果:
|
||||
```
|
||||
├── ./save_gt_label/
|
||||
├── ./save_PPOCRV2_infer/
|
||||
```
|
||||
|
||||
**步骤三:**
|
||||
|
||||
执行端对端评测,运行`tools/eval_end2end.py`计算端对端指标,运行方式如下:
|
||||
|
||||
```
|
||||
python3 tools/eval_end2end.py "gt_label_dir" "predict_label_dir"
|
||||
```
|
||||
|
||||
比如:
|
||||
|
||||
```
|
||||
python3 tools/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/
|
||||
```
|
||||
将得到如下结果,fmeasure为主要关注的指标:
|
||||
```
|
||||
hit, dt_count, gt_count 1557 2693 3283
|
||||
character_acc: 61.77%
|
||||
avg_edit_dist_field: 3.08
|
||||
avg_edit_dist_img: 51.82
|
||||
precision: 57.82%
|
||||
recall: 47.43%
|
||||
fmeasure: 52.11%
|
||||
```
|
Loading…
Reference in New Issue