cherry pick PRs from community (#7273)

* Merge pull request #6824 from ChenNima/release/2.5-kie-save-res

[kie]add write_kie_result to kie infer tool

* Merge pull request #6677 from TonyJiangWJ/release/2.5

修复内存泄露问题

* Update native.cpp (#6650)

fix issue 6640

* Merge pull request #6625 from ynjang/ynjang

update sorted_boxes

* fix DeprecationWarning, (#6604)

DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead

* Merge pull request #6585 from maxbachmann/release/2.5

replace GPL licensed components

* Merge pull request #6575 from Eling486/release/2.5

update win doc

* Merge pull request #6477 from MikoyChinese/fix-copy-paste

Fix copy_paste no texts augment.

* Merge pull request #6361 from mohamadmansourX/patch-9

Update README_en.md

Co-authored-by: Double_V <liuvv0203@163.com>
Co-authored-by: shawn <1021362695@qq.com>
Co-authored-by: paopjian <672034519@qq.com>
pull/7278/head
Evezerest 2022-08-21 18:03:57 +08:00 committed by GitHub
parent dd063fc98b
commit 7e4e87dd6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 56 deletions

View File

@ -47,7 +47,7 @@ str_to_cpu_mode(const std::string &cpu_mode) {
std::string upper_key; std::string upper_key;
std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(), std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(),
::toupper); ::toupper);
auto index = cpu_mode_map.find(upper_key); auto index = cpu_mode_map.find(upper_key.c_str());
if (index == cpu_mode_map.end()) { if (index == cpu_mode_map.end()) {
LOGE("cpu_mode not found %s", upper_key.c_str()); LOGE("cpu_mode not found %s", upper_key.c_str());
return paddle::lite_api::LITE_POWER_HIGH; return paddle::lite_api::LITE_POWER_HIGH;

View File

@ -54,7 +54,7 @@ public class OCRPredictorNative {
} }
public void destory() { public void destory() {
if (nativePointer > 0) { if (nativePointer != 0) {
release(nativePointer); release(nativePointer);
nativePointer = 0; nativePointer = 0;
} }

View File

@ -109,8 +109,10 @@ CUDA_LIB、CUDNN_LIB、TENSORRT_DIR、WITH_GPU、WITH_TENSORRT
运行之前,将下面文件拷贝到`build/Release/`文件夹下 运行之前,将下面文件拷贝到`build/Release/`文件夹下
1. `paddle_inference/paddle/lib/paddle_inference.dll` 1. `paddle_inference/paddle/lib/paddle_inference.dll`
2. `opencv/build/x64/vc15/bin/opencv_world455.dll` 2. `paddle_inference/third_party/install/onnxruntime/lib/onnxruntime.dll`
3. 如果使用openblas版本的预测库还需要拷贝 `paddle_inference/third_party/install/openblas/lib/openblas.dll` 3. `paddle_inference/third_party/install/paddle2onnx/lib/paddle2onnx.dll`
4. `opencv/build/x64/vc15/bin/opencv_world455.dll`
5. 如果使用openblas版本的预测库还需要拷贝 `paddle_inference/third_party/install/openblas/lib/openblas.dll`
### Step4: 预测 ### Step4: 预测

View File

@ -73,4 +73,4 @@ python deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_
The numerical range of the quantized model parameters derived from the above steps is still FP32, but the numerical range of the parameters is int8. The numerical range of the quantized model parameters derived from the above steps is still FP32, but the numerical range of the parameters is int8.
The derived model can be converted through the `opt tool` of PaddleLite. The derived model can be converted through the `opt tool` of PaddleLite.
For quantitative model deployment, please refer to [Mobile terminal model deployment](../../lite/readme_en.md) For quantitative model deployment, please refer to [Mobile terminal model deployment](../../lite/readme.md)

View File

@ -35,10 +35,12 @@ class CopyPaste(object):
point_num = data['polys'].shape[1] point_num = data['polys'].shape[1]
src_img = data['image'] src_img = data['image']
src_polys = data['polys'].tolist() src_polys = data['polys'].tolist()
src_texts = data['texts']
src_ignores = data['ignore_tags'].tolist() src_ignores = data['ignore_tags'].tolist()
ext_data = data['ext_data'][0] ext_data = data['ext_data'][0]
ext_image = ext_data['image'] ext_image = ext_data['image']
ext_polys = ext_data['polys'] ext_polys = ext_data['polys']
ext_texts = ext_data['texts']
ext_ignores = ext_data['ignore_tags'] ext_ignores = ext_data['ignore_tags']
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]] indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
@ -53,7 +55,7 @@ class CopyPaste(object):
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB) src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB) ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
src_img = Image.fromarray(src_img).convert('RGBA') src_img = Image.fromarray(src_img).convert('RGBA')
for poly, tag in zip(select_polys, select_ignores): for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
box_img = get_rotate_crop_image(ext_image, poly) box_img = get_rotate_crop_image(ext_image, poly)
src_img, box = self.paste_img(src_img, box_img, src_polys) src_img, box = self.paste_img(src_img, box_img, src_polys)
@ -62,6 +64,7 @@ class CopyPaste(object):
for _ in range(len(box), point_num): for _ in range(len(box), point_num):
box.append(box[-1]) box.append(box[-1])
src_polys.append(box) src_polys.append(box)
src_texts.append(ext_texts[idx])
src_ignores.append(tag) src_ignores.append(tag)
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR) src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
h, w = src_img.shape[:2] h, w = src_img.shape[:2]
@ -70,6 +73,7 @@ class CopyPaste(object):
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h) src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
data['image'] = src_img data['image'] = src_img
data['polys'] = src_polys data['polys'] = src_polys
data['texts'] = src_texts
data['ignore_tags'] = np.array(src_ignores) data['ignore_tags'] = np.array(src_ignores)
return data return data

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import Levenshtein from rapidfuzz.distance import Levenshtein
import string import string
@ -46,8 +46,7 @@ class RecMetric(object):
if self.is_filter: if self.is_filter:
pred = self._normalize_text(pred) pred = self._normalize_text(pred)
target = self._normalize_text(target) target = self._normalize_text(target)
norm_edit_dis += Levenshtein.distance(pred, target) / max( norm_edit_dis += Levenshtein.normalized_distance(pred, target)
len(pred), len(target), 1)
if pred == target: if pred == target:
correct_num += 1 correct_num += 1
all_num += 1 all_num += 1

View File

@ -20,7 +20,7 @@ from shapely.geometry import Polygon
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
import operator import operator
import Levenshtein from rapidfuzz.distance import Levenshtein
import argparse import argparse
import json import json
import copy import copy

View File

@ -9,7 +9,7 @@
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details. # Apache 2.0 License for more details.
import distance from rapidfuzz.distance import Levenshtein
from apted import APTED, Config from apted import APTED, Config
from apted.helpers import Tree from apted.helpers import Tree
from lxml import etree, html from lxml import etree, html
@ -39,17 +39,6 @@ class TableTree(Tree):
class CustomConfig(Config): class CustomConfig(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2): def rename(self, node1, node2):
"""Compares attributes of trees""" """Compares attributes of trees"""
#print(node1.tag) #print(node1.tag)
@ -58,23 +47,12 @@ class CustomConfig(Config):
if node1.tag == 'td': if node1.tag == 'td':
if node1.content or node2.content: if node1.content or node2.content:
#print(node1.content, ) #print(node1.content, )
return self.normalized_distance(node1.content, node2.content) return Levenshtein.normalized_distance(node1.content, node2.content)
return 0. return 0.
class CustomConfig_del_short(Config): class CustomConfig_del_short(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2): def rename(self, node1, node2):
"""Compares attributes of trees""" """Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
@ -90,21 +68,10 @@ class CustomConfig_del_short(Config):
node1_content = ['####'] node1_content = ['####']
if len(node2_content) < 3: if len(node2_content) < 3:
node2_content = ['####'] node2_content = ['####']
return self.normalized_distance(node1_content, node2_content) return Levenshtein.normalized_distance(node1_content, node2_content)
return 0. return 0.
class CustomConfig_del_block(Config): class CustomConfig_del_block(Config):
@staticmethod
def maximum(*sequences):
"""Get maximum possible value
"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""Get distance from 0 to 1
"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2): def rename(self, node1, node2):
"""Compares attributes of trees""" """Compares attributes of trees"""
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
@ -120,7 +87,7 @@ class CustomConfig_del_block(Config):
while ' ' in node2_content: while ' ' in node2_content:
print(node2_content.index(' ')) print(node2_content.index(' '))
node2_content.pop(node2_content.index(' ')) node2_content.pop(node2_content.index(' '))
return self.normalized_distance(node1_content, node2_content) return Levenshtein.normalized_distance(node1_content, node2_content)
return 0. return 0.
class TEDS(object): class TEDS(object):

View File

@ -6,7 +6,7 @@ lmdb
tqdm tqdm
numpy numpy
visualdl visualdl
python-Levenshtein rapidfuzz
opencv-contrib-python==4.4.0.46 opencv-contrib-python==4.4.0.46
cython cython
lxml lxml

View File

@ -120,11 +120,14 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes) _boxes = list(sorted_boxes)
for i in range(num_boxes - 1): for i in range(num_boxes - 1):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ for j in range(i, 0, -1):
(_boxes[i + 1][0][0] < _boxes[i][0][0]): if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
tmp = _boxes[i] (_boxes[j + 1][0][0] < _boxes[j][0][0]):
_boxes[i] = _boxes[i + 1] tmp = _boxes[j]
_boxes[i + 1] = tmp _boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes return _boxes

View File

@ -549,7 +549,7 @@ def text_visual(texts,
def base64_to_cv2(b64str): def base64_to_cv2(b64str):
import base64 import base64
data = base64.b64decode(b64str.encode('utf8')) data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8) data = np.frombuffer(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR) data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data return data

View File

@ -88,6 +88,29 @@ def draw_kie_result(batch, node, idx_to_cls, count):
cv2.imwrite(save_path, vis_img) cv2.imwrite(save_path, vis_img)
logger.info("The Kie Image saved in {}".format(save_path)) logger.info("The Kie Image saved in {}".format(save_path))
def write_kie_result(fout, node, data):
"""
Write infer result to output file, sorted by the predict label of each line.
The format keeps the same as the input with additional score attribute.
"""
import json
label = data['label']
annotations = json.loads(label)
max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
node_pred_label = max_idx.numpy().tolist()
node_pred_score = max_value.numpy().tolist()
res = []
for i, label in enumerate(node_pred_label):
pred_score = '{:.2f}'.format(node_pred_score[i])
pred_res = {
'label': label,
'transcription': annotations[i]['transcription'],
'score': pred_score,
'points': annotations[i]['points'],
}
res.append(pred_res)
res.sort(key=lambda x: x['label'])
fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
def main(): def main():
global_config = config['Global'] global_config = config['Global']
@ -114,7 +137,7 @@ def main():
warmup_times = 0 warmup_times = 0
count_t = [] count_t = []
with open(save_res_path, "wb") as fout: with open(save_res_path, "w") as fout:
with open(config['Global']['infer_img'], "rb") as f: with open(config['Global']['infer_img'], "rb") as f:
lines = f.readlines() lines = f.readlines()
for index, data_line in enumerate(lines): for index, data_line in enumerate(lines):
@ -139,6 +162,8 @@ def main():
node = F.softmax(node, -1) node = F.softmax(node, -1)
count_t.append(time.time() - st) count_t.append(time.time() - st)
draw_kie_result(batch, node, idx_to_cls, index) draw_kie_result(batch, node, idx_to_cls, index)
write_kie_result(fout, node, data)
fout.close()
logger.info("success!") logger.info("success!")
logger.info("It took {} s for predict {} images.".format( logger.info("It took {} s for predict {} images.".format(
np.sum(count_t), len(count_t))) np.sum(count_t), len(count_t)))