diff --git a/deploy/android_demo/app/src/main/cpp/native.cpp b/deploy/android_demo/app/src/main/cpp/native.cpp
index ced932556..4961e5ecf 100644
--- a/deploy/android_demo/app/src/main/cpp/native.cpp
+++ b/deploy/android_demo/app/src/main/cpp/native.cpp
@@ -47,7 +47,7 @@ str_to_cpu_mode(const std::string &cpu_mode) {
   std::string upper_key;
   std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(),
                  ::toupper);
-  auto index = cpu_mode_map.find(upper_key);
+  auto index = cpu_mode_map.find(upper_key.c_str());
   if (index == cpu_mode_map.end()) {
     LOGE("cpu_mode not found %s", upper_key.c_str());
     return paddle::lite_api::LITE_POWER_HIGH;
@@ -116,4 +116,4 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(
   ppredictor::OCR_PPredictor *ppredictor =
       (ppredictor::OCR_PPredictor *)java_pointer;
   delete ppredictor;
-}
\ No newline at end of file
+}
diff --git a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
index 622da2a3f..41fa183de 100644
--- a/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
+++ b/deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr/OCRPredictorNative.java
@@ -54,7 +54,7 @@ public class OCRPredictorNative {
     }

     public void destory() {
-        if (nativePointer > 0) {
+        if (nativePointer != 0) {
             release(nativePointer);
             nativePointer = 0;
         }
diff --git a/deploy/cpp_infer/docs/windows_vs2019_build.md b/deploy/cpp_infer/docs/windows_vs2019_build.md
index 4f391d925..bcaefa46f 100644
--- a/deploy/cpp_infer/docs/windows_vs2019_build.md
+++ b/deploy/cpp_infer/docs/windows_vs2019_build.md
@@ -109,8 +109,10 @@ CUDA_LIB, CUDNN_LIB, TENSORRT_DIR, WITH_GPU, WITH_TENSORRT

 Before running, copy the following files into the `build/Release/` folder:
 1. `paddle_inference/paddle/lib/paddle_inference.dll`
-2. `opencv/build/x64/vc15/bin/opencv_world455.dll`
-3. If you use the openblas version of the inference library, you also need to copy `paddle_inference/third_party/install/openblas/lib/openblas.dll`
+2. `paddle_inference/third_party/install/onnxruntime/lib/onnxruntime.dll`
+3. `paddle_inference/third_party/install/paddle2onnx/lib/paddle2onnx.dll`
+4. `opencv/build/x64/vc15/bin/opencv_world455.dll`
+5. If you use the openblas version of the inference library, you also need to copy `paddle_inference/third_party/install/openblas/lib/openblas.dll`

 ### Step4: Prediction
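Note on the Step3 doc change above: `onnxruntime.dll` and `paddle2onnx.dll` are now required next to the executable. A minimal Python sketch of the copy step, assuming the repository-relative paths listed above (adjust to your local layout):

```python
# Sketch only: copy the runtime DLLs listed above into build/Release/.
# The source paths are the ones named in the doc, relative to your workspace.
import shutil
from pathlib import Path

dlls = [
    "paddle_inference/paddle/lib/paddle_inference.dll",
    "paddle_inference/third_party/install/onnxruntime/lib/onnxruntime.dll",
    "paddle_inference/third_party/install/paddle2onnx/lib/paddle2onnx.dll",
    "opencv/build/x64/vc15/bin/opencv_world455.dll",
]
release_dir = Path("build/Release")
release_dir.mkdir(parents=True, exist_ok=True)
for dll in dlls:
    shutil.copy(dll, release_dir / Path(dll).name)
```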
diff --git a/deploy/slim/quantization/README_en.md b/deploy/slim/quantization/README_en.md
index 33b2c4784..c6796ae9d 100644
--- a/deploy/slim/quantization/README_en.md
+++ b/deploy/slim/quantization/README_en.md
@@ -73,4 +73,4 @@ python deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_

 The quantized model parameters exported by the above steps are still stored as FP32, but their numerical range is int8. The exported model can be converted with PaddleLite's `opt` tool.

-For quantized model deployment, please refer to [Mobile terminal model deployment](../../lite/readme_en.md)
+For quantized model deployment, please refer to [Mobile terminal model deployment](../../lite/readme.md)
diff --git a/ppocr/data/imaug/copy_paste.py b/ppocr/data/imaug/copy_paste.py
index 0b3386c89..79343da60 100644
--- a/ppocr/data/imaug/copy_paste.py
+++ b/ppocr/data/imaug/copy_paste.py
@@ -35,10 +35,12 @@ class CopyPaste(object):
         point_num = data['polys'].shape[1]
         src_img = data['image']
         src_polys = data['polys'].tolist()
+        src_texts = data['texts']
         src_ignores = data['ignore_tags'].tolist()
         ext_data = data['ext_data'][0]
         ext_image = ext_data['image']
         ext_polys = ext_data['polys']
+        ext_texts = ext_data['texts']
         ext_ignores = ext_data['ignore_tags']
         indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
@@ -53,7 +55,7 @@ class CopyPaste(object):
         src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
         ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
         src_img = Image.fromarray(src_img).convert('RGBA')
-        for poly, tag in zip(select_polys, select_ignores):
+        for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
             box_img = get_rotate_crop_image(ext_image, poly)

             src_img, box = self.paste_img(src_img, box_img, src_polys)
@@ -62,6 +64,7 @@ class CopyPaste(object):
             for _ in range(len(box), point_num):
                 box.append(box[-1])
             src_polys.append(box)
+            src_texts.append(ext_texts[idx])
             src_ignores.append(tag)
         src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
         h, w = src_img.shape[:2]
@@ -70,6 +73,7 @@ class CopyPaste(object):
         src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
         data['image'] = src_img
         data['polys'] = src_polys
+        data['texts'] = src_texts
         data['ignore_tags'] = np.array(src_ignores)
         return data
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index d858ae28e..986397811 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import Levenshtein
+from rapidfuzz.distance import Levenshtein
 import string

@@ -46,8 +46,7 @@ class RecMetric(object):
             if self.is_filter:
                 pred = self._normalize_text(pred)
                 target = self._normalize_text(target)
-            norm_edit_dis += Levenshtein.distance(pred, target) / max(
-                len(pred), len(target), 1)
+            norm_edit_dis += Levenshtein.normalized_distance(pred, target)
             if pred == target:
                 correct_num += 1
             all_num += 1
diff --git a/ppstructure/kie/tools/eval_with_label_end2end.py b/ppstructure/kie/tools/eval_with_label_end2end.py
index b13ffb568..b0fd84363 100644
--- a/ppstructure/kie/tools/eval_with_label_end2end.py
+++ b/ppstructure/kie/tools/eval_with_label_end2end.py
@@ -20,7 +20,7 @@ from shapely.geometry import Polygon
 import numpy as np
 from collections import defaultdict
 import operator
-import Levenshtein
+from rapidfuzz.distance import Levenshtein
 import argparse
 import json
 import copy
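Note on the `python-Levenshtein` to `rapidfuzz` migration above: in rapidfuzz, `Levenshtein.normalized_distance(a, b)` divides the edit distance by the longer input's length and returns 0.0 when both inputs are empty, which matches the manual normalization the old code computed by hand. A minimal sketch of the equivalence, assuming the rapidfuzz 2.x API:

```python
# Sketch: the rapidfuzz one-liner matches the removed manual normalization.
from rapidfuzz.distance import Levenshtein

pred, target = "hello", "helo"
manual = Levenshtein.distance(pred, target) / max(len(pred), len(target), 1)
builtin = Levenshtein.normalized_distance(pred, target)
assert abs(manual - builtin) < 1e-9  # both are 0.2
```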
diff --git a/ppstructure/table/table_metric/table_metric.py b/ppstructure/table/table_metric/table_metric.py
index 9aca98ad7..923a9c007 100755
--- a/ppstructure/table/table_metric/table_metric.py
+++ b/ppstructure/table/table_metric/table_metric.py
@@ -9,7 +9,7 @@
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # Apache 2.0 License for more details.

-import distance
+from rapidfuzz.distance import Levenshtein
 from apted import APTED, Config
 from apted.helpers import Tree
 from lxml import etree, html
@@ -39,17 +39,6 @@ class TableTree(Tree):


 class CustomConfig(Config):
-    @staticmethod
-    def maximum(*sequences):
-        """Get maximum possible value
-        """
-        return max(map(len, sequences))
-
-    def normalized_distance(self, *sequences):
-        """Get distance from 0 to 1
-        """
-        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
-
     def rename(self, node1, node2):
         """Compares attributes of trees"""
         #print(node1.tag)
@@ -58,23 +47,12 @@ class CustomConfig(Config):
         if node1.tag == 'td':
             if node1.content or node2.content:
                 #print(node1.content, )
-                return self.normalized_distance(node1.content, node2.content)
+                return Levenshtein.normalized_distance(node1.content, node2.content)
         return 0.


 class CustomConfig_del_short(Config):
-    @staticmethod
-    def maximum(*sequences):
-        """Get maximum possible value
-        """
-        return max(map(len, sequences))
-
-    def normalized_distance(self, *sequences):
-        """Get distance from 0 to 1
-        """
-        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
-
     def rename(self, node1, node2):
         """Compares attributes of trees"""
         if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
@@ -90,21 +68,10 @@ class CustomConfig_del_short(Config):
             return 1.
         if node1.tag == 'td':
             node1_content = node1.content
             node2_content = node2.content
             if len(node1_content) < 3:
                 node1_content = ['####']
             if len(node2_content) < 3:
                 node2_content = ['####']
-            return self.normalized_distance(node1_content, node2_content)
+            return Levenshtein.normalized_distance(node1_content, node2_content)
         return 0.


 class CustomConfig_del_block(Config):
-    @staticmethod
-    def maximum(*sequences):
-        """Get maximum possible value
-        """
-        return max(map(len, sequences))
-
-    def normalized_distance(self, *sequences):
-        """Get distance from 0 to 1
-        """
-        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
-
     def rename(self, node1, node2):
         """Compares attributes of trees"""
         if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
@@ -120,7 +87,7 @@ class CustomConfig_del_block(Config):
             while ' ' in node2_content:
                 print(node2_content.index(' '))
                 node2_content.pop(node2_content.index(' '))
-            return self.normalized_distance(node1_content, node2_content)
+            return Levenshtein.normalized_distance(node1_content, node2_content)
         return 0.


 class TEDS(object):
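Note: the three removed `maximum`/`normalized_distance` helpers all wrapped `distance.levenshtein` over token lists, so the replacement must handle sequences, not just strings. `rapidfuzz.distance.Levenshtein.normalized_distance` accepts arbitrary sequences of hashable items, so the token lists these configs pass (e.g. `node1_content`) work directly. A small sketch with made-up cell tokens:

```python
# Sketch: rapidfuzz edit distance over token lists, as the configs above use.
from rapidfuzz.distance import Levenshtein

a = ["2019", "Annual", "Report"]  # hypothetical cell tokens
b = ["2019", "Report"]
print(Levenshtein.normalized_distance(a, b))  # one deletion / max(3, 2) = 0.333...
```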
diff --git a/requirements.txt b/requirements.txt
index b15176db3..976d29192 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ lmdb
 tqdm
 numpy
 visualdl
-python-Levenshtein
+rapidfuzz
 opencv-contrib-python==4.4.0.46
 cython
 lxml
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 252ed1aaf..e0f2c41fa 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -120,11 +120,14 @@ def sorted_boxes(dt_boxes):
     _boxes = list(sorted_boxes)

     for i in range(num_boxes - 1):
-        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
-                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
-            tmp = _boxes[i]
-            _boxes[i] = _boxes[i + 1]
-            _boxes[i + 1] = tmp
+        for j in range(i, -1, -1):
+            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
+                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+                tmp = _boxes[j]
+                _boxes[j] = _boxes[j + 1]
+                _boxes[j + 1] = tmp
+            else:
+                break

     return _boxes
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 1355ca62e..a547bbdba 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -549,7 +549,7 @@ def text_visual(texts,
 def base64_to_cv2(b64str):
     import base64
     data = base64.b64decode(b64str.encode('utf8'))
-    data = np.fromstring(data, np.uint8)
+    data = np.frombuffer(data, np.uint8)
     data = cv2.imdecode(data, cv2.IMREAD_COLOR)
     return data
diff --git a/tools/infer_kie.py b/tools/infer_kie.py
index 346e2e0ae..9375434cc 100755
--- a/tools/infer_kie.py
+++ b/tools/infer_kie.py
@@ -88,6 +88,29 @@ def draw_kie_result(batch, node, idx_to_cls, count):
     cv2.imwrite(save_path, vis_img)
     logger.info("The Kie Image saved in {}".format(save_path))

+
+def write_kie_result(fout, node, data):
+    """
+    Write inference results to the output file, sorted by the predicted label
+    of each line. The output format matches the input, with an additional
+    score attribute.
+    """
+    import json
+    label = data['label']
+    annotations = json.loads(label)
+    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
+    node_pred_label = max_idx.numpy().tolist()
+    node_pred_score = max_value.numpy().tolist()
+    res = []
+    for i, label in enumerate(node_pred_label):
+        pred_score = '{:.2f}'.format(node_pred_score[i])
+        pred_res = {
+            'label': label,
+            'transcription': annotations[i]['transcription'],
+            'score': pred_score,
+            'points': annotations[i]['points'],
+        }
+        res.append(pred_res)
+    res.sort(key=lambda x: x['label'])
+    fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
+
+
 def main():
     global_config = config['Global']
@@ -114,7 +137,7 @@ def main():

     warmup_times = 0
     count_t = []
-    with open(save_res_path, "wb") as fout:
+    with open(save_res_path, "w") as fout:
         with open(config['Global']['infer_img'], "rb") as f:
             lines = f.readlines()
             for index, data_line in enumerate(lines):
@@ -139,6 +162,8 @@ def main():
                 node = F.softmax(node, -1)
                 count_t.append(time.time() - st)
                 draw_kie_result(batch, node, idx_to_cls, index)
+                write_kie_result(fout, node, data)
+    fout.close()
     logger.info("success!")
     logger.info("It took {} s for predict {} images.".format(
         np.sum(count_t), len(count_t)))
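Note on the infer_kie.py change above: the result file is now opened in text mode, and each input image contributes one JSON line containing the per-text-line predictions sorted by label. A minimal sketch of reading the results back; the path is a placeholder for whatever `save_res_path` resolves to in your config:

```python
# Sketch: parse the per-image JSON lines written by write_kie_result.
# "./output/infer_results.txt" is a hypothetical path; use your save_res_path.
import json

with open("./output/infer_results.txt", "r", encoding="utf-8") as f:
    for line in f:
        for item in json.loads(line):  # one list of predictions per image
            print(item["label"], item["score"], item["transcription"])
```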