cherry pick PRs from community (#7273)

* Merge pull request #6824 from ChenNima/release/2.5-kie-save-res
  [kie] add write_kie_result to kie infer tool
* Merge pull request #6677 from TonyJiangWJ/release/2.5
  Fix memory leak
* Update native.cpp (#6650)
  Fix issue 6640
* Merge pull request #6625 from ynjang/ynjang
  Update sorted_boxes
* Fix DeprecationWarning (#6604)
  DeprecationWarning: The binary mode of fromstring is deprecated, as it behaves surprisingly on unicode inputs. Use frombuffer instead.
* Merge pull request #6585 from maxbachmann/release/2.5
  Replace GPL licensed components
* Merge pull request #6575 from Eling486/release/2.5
  Update Windows doc
* Merge pull request #6477 from MikoyChinese/fix-copy-paste
  Fix copy_paste not augmenting texts.
* Merge pull request #6361 from mohamadmansourX/patch-9
  Update README_en.md

Co-authored-by: Double_V <liuvv0203@163.com>
Co-authored-by: shawn <1021362695@qq.com>
Co-authored-by: paopjian <672034519@qq.com>
parent dd063fc98b
commit 7e4e87dd6d

Changed paths:
- deploy/android_demo/app/src/main/cpp
- deploy/android_demo/app/src/main/java/com/baidu/paddle/lite/demo/ocr
- deploy/cpp_infer/docs
- deploy/slim/quantization
- ppocr/data/imaug
- ppocr/metrics
- ppstructure/kie/tools
- ppstructure/table/table_metric
- tools
android_demo native.cpp (deploy/android_demo/app/src/main/cpp):

```diff
@@ -47,7 +47,7 @@ str_to_cpu_mode(const std::string &cpu_mode) {
   std::string upper_key;
   std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(),
                  ::toupper);
-  auto index = cpu_mode_map.find(upper_key);
+  auto index = cpu_mode_map.find(upper_key.c_str());
   if (index == cpu_mode_map.end()) {
     LOGE("cpu_mode not found %s", upper_key.c_str());
     return paddle::lite_api::LITE_POWER_HIGH;
```
```diff
@@ -116,4 +116,4 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(
   ppredictor::OCR_PPredictor *ppredictor =
       (ppredictor::OCR_PPredictor *)java_pointer;
   delete ppredictor;
-}
+}
```
OCRPredictorNative.java:

```diff
@@ -54,7 +54,7 @@ public class OCRPredictorNative {
     }
 
     public void destory() {
-        if (nativePointer > 0) {
+        if (nativePointer != 0) {
             release(nativePointer);
             nativePointer = 0;
         }
```
Windows build doc (deploy/cpp_infer/docs):

```diff
@@ -109,8 +109,10 @@ CUDA_LIB、CUDNN_LIB、TENSORRT_DIR、WITH_GPU、WITH_TENSORRT
 
 Before running, copy the following files to the `build/Release/` folder:
 1. `paddle_inference/paddle/lib/paddle_inference.dll`
-2. `opencv/build/x64/vc15/bin/opencv_world455.dll`
-3. If the openblas version of the inference library is used, also copy `paddle_inference/third_party/install/openblas/lib/openblas.dll`
+2. `paddle_inference/third_party/install/onnxruntime/lib/onnxruntime.dll`
+3. `paddle_inference/third_party/install/paddle2onnx/lib/paddle2onnx.dll`
+4. `opencv/build/x64/vc15/bin/opencv_world455.dll`
+5. If the openblas version of the inference library is used, also copy `paddle_inference/third_party/install/openblas/lib/openblas.dll`
 
 ### Step4: Prediction
 
```
Quantization README (deploy/slim/quantization):

```diff
@@ -73,4 +73,4 @@ python deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_
 The numerical range of the quantized model parameters derived from the above steps is still FP32, but the numerical range of the parameters is int8.
 The derived model can be converted through the `opt tool` of PaddleLite.
 
-For quantitative model deployment, please refer to [Mobile terminal model deployment](../../lite/readme_en.md)
+For quantitative model deployment, please refer to [Mobile terminal model deployment](../../lite/readme.md)
```
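As a rough illustration of the PaddleLite `opt` conversion step mentioned above, the sketch below uses Paddle-Lite's Python `Opt` API. The method names, paths, and target string are assumptions to verify against the Paddle-Lite documentation, not part of this commit:

```python
# Hypothetical sketch of converting the exported model with Paddle-Lite's opt
# tool; paths and the "arm" target are placeholders, and the Opt API should be
# checked against the installed paddlelite version.
from paddlelite.lite import Opt

opt = Opt()
opt.set_model_file("quant_inference_model/inference.pdmodel")    # placeholder path
opt.set_param_file("quant_inference_model/inference.pdiparams")  # placeholder path
opt.set_valid_places("arm")                                      # assumed target string
opt.set_optimize_out("ocr_det_quant_opt")                        # output prefix for the .nb model
opt.run()
```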
CopyPaste augmentation (ppocr/data/imaug):

```diff
@@ -35,10 +35,12 @@ class CopyPaste(object):
         point_num = data['polys'].shape[1]
         src_img = data['image']
         src_polys = data['polys'].tolist()
+        src_texts = data['texts']
         src_ignores = data['ignore_tags'].tolist()
         ext_data = data['ext_data'][0]
         ext_image = ext_data['image']
         ext_polys = ext_data['polys']
+        ext_texts = ext_data['texts']
         ext_ignores = ext_data['ignore_tags']
 
         indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
```
```diff
@@ -53,7 +55,7 @@ class CopyPaste(object):
         src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
         ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
         src_img = Image.fromarray(src_img).convert('RGBA')
-        for poly, tag in zip(select_polys, select_ignores):
+        for idx, poly, tag in zip(select_idxs, select_polys, select_ignores):
             box_img = get_rotate_crop_image(ext_image, poly)
 
             src_img, box = self.paste_img(src_img, box_img, src_polys)
```
```diff
@@ -62,6 +64,7 @@ class CopyPaste(object):
             for _ in range(len(box), point_num):
                 box.append(box[-1])
             src_polys.append(box)
+            src_texts.append(ext_texts[idx])
             src_ignores.append(tag)
         src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
         h, w = src_img.shape[:2]
```
```diff
@@ -70,6 +73,7 @@ class CopyPaste(object):
         src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
         data['image'] = src_img
         data['polys'] = src_polys
+        data['texts'] = src_texts
         data['ignore_tags'] = np.array(src_ignores)
         return data
```
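Taken together, the CopyPaste hunks above make every pasted text instance carry its transcription along with its polygon and ignore flag, so the three per-sample lists stay index-aligned. A minimal plain-Python sketch of that invariant on toy data (not the actual CopyPaste class):

```python
# Toy sample plus one extra sample to paste from; keys mirror the ones used above.
src = {
    "polys": [[[0, 0], [10, 0], [10, 10], [0, 10]]],
    "texts": ["hello"],
    "ignore_tags": [False],
}
ext = {
    "polys": [[[20, 20], [40, 20], [40, 30], [20, 30]], [[0, 40], [5, 40], [5, 45], [0, 45]]],
    "texts": ["world", "###"],
    "ignore_tags": [False, True],
}

# Keep only non-ignored extra instances, and append every field together so
# index i always refers to the same text instance in all three lists.
keep = [i for i, ig in enumerate(ext["ignore_tags"]) if not ig]
for i in keep:
    src["polys"].append(ext["polys"][i])
    src["texts"].append(ext["texts"][i])          # texts appended alongside polys, as in the diff above
    src["ignore_tags"].append(ext["ignore_tags"][i])

assert len(src["polys"]) == len(src["texts"]) == len(src["ignore_tags"])
print(src["texts"])  # ['hello', 'world']
```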
RecMetric (ppocr/metrics):

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import Levenshtein
+from rapidfuzz.distance import Levenshtein
 import string
```
```diff
@@ -46,8 +46,7 @@ class RecMetric(object):
             if self.is_filter:
                 pred = self._normalize_text(pred)
                 target = self._normalize_text(target)
-            norm_edit_dis += Levenshtein.distance(pred, target) / max(
-                len(pred), len(target), 1)
+            norm_edit_dis += Levenshtein.normalized_distance(pred, target)
             if pred == target:
                 correct_num += 1
             all_num += 1
```
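For standard unit-weight Levenshtein, `rapidfuzz.distance.Levenshtein.normalized_distance` divides the edit distance by the longer of the two lengths and returns 0 for two empty strings, so it should match the removed hand-rolled expression. A quick sanity check, assuming `rapidfuzz` is installed:

```python
from rapidfuzz.distance import Levenshtein

pred, target = "hel1o", "hello"

manual = Levenshtein.distance(pred, target) / max(len(pred), len(target), 1)
builtin = Levenshtein.normalized_distance(pred, target)

print(manual, builtin)             # 0.2 0.2
assert abs(manual - builtin) < 1e-9
```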
End-to-end eval tool (ppstructure/kie/tools):

```diff
@@ -20,7 +20,7 @@ from shapely.geometry import Polygon
 import numpy as np
 from collections import defaultdict
 import operator
-import Levenshtein
+from rapidfuzz.distance import Levenshtein
 import argparse
 import json
 import copy
```
TEDS table metric (ppstructure/table/table_metric):

```diff
@@ -9,7 +9,7 @@
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # Apache 2.0 License for more details.
 
-import distance
+from rapidfuzz.distance import Levenshtein
 from apted import APTED, Config
 from apted.helpers import Tree
 from lxml import etree, html
```
```diff
@@ -39,17 +39,6 @@ class TableTree(Tree):
 
 
 class CustomConfig(Config):
-    @staticmethod
-    def maximum(*sequences):
-        """Get maximum possible value
-        """
-        return max(map(len, sequences))
-
-    def normalized_distance(self, *sequences):
-        """Get distance from 0 to 1
-        """
-        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
-
     def rename(self, node1, node2):
         """Compares attributes of trees"""
         #print(node1.tag)
```
```diff
@@ -58,23 +47,12 @@ class CustomConfig(Config):
         if node1.tag == 'td':
             if node1.content or node2.content:
                 #print(node1.content, )
-                return self.normalized_distance(node1.content, node2.content)
+                return Levenshtein.normalized_distance(node1.content, node2.content)
         return 0.
 
 
 
 class CustomConfig_del_short(Config):
-    @staticmethod
-    def maximum(*sequences):
-        """Get maximum possible value
-        """
-        return max(map(len, sequences))
-
-    def normalized_distance(self, *sequences):
-        """Get distance from 0 to 1
-        """
-        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
-
     def rename(self, node1, node2):
         """Compares attributes of trees"""
         if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
```
```diff
@@ -90,21 +68,10 @@ class CustomConfig_del_short(Config):
                 node1_content = ['####']
             if len(node2_content) < 3:
                 node2_content = ['####']
-            return self.normalized_distance(node1_content, node2_content)
+            return Levenshtein.normalized_distance(node1_content, node2_content)
         return 0.
 
 class CustomConfig_del_block(Config):
-    @staticmethod
-    def maximum(*sequences):
-        """Get maximum possible value
-        """
-        return max(map(len, sequences))
-
-    def normalized_distance(self, *sequences):
-        """Get distance from 0 to 1
-        """
-        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
-
     def rename(self, node1, node2):
         """Compares attributes of trees"""
         if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
```
```diff
@@ -120,7 +87,7 @@ class CustomConfig_del_block(Config):
             while ' ' in node2_content:
                 print(node2_content.index(' '))
                 node2_content.pop(node2_content.index(' '))
-            return self.normalized_distance(node1_content, node2_content)
+            return Levenshtein.normalized_distance(node1_content, node2_content)
         return 0.
 
 class TEDS(object):
```
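In the table metric the operands passed to `rename` are token lists rather than strings, which works because `rapidfuzz.distance.Levenshtein` accepts arbitrary sequences of hashable items, the same interface the removed `distance.levenshtein` helper provided. A small illustration, assuming `rapidfuzz` is installed:

```python
from rapidfuzz.distance import Levenshtein

# Token-level comparison, as done for <td> cell contents in the TEDS metric.
cell_a = ["State", "of", "the", "art"]
cell_b = ["State", "of", "art"]

print(Levenshtein.distance(cell_a, cell_b))             # 1 (one token removed)
print(Levenshtein.normalized_distance(cell_a, cell_b))  # 0.25, i.e. 1 / max(4, 3)
```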
requirements:

```diff
@@ -6,7 +6,7 @@ lmdb
 tqdm
 numpy
 visualdl
-python-Levenshtein
+rapidfuzz
 opencv-contrib-python==4.4.0.46
 cython
 lxml
```
sorted_boxes (tools):

```diff
@@ -120,11 +120,14 @@ def sorted_boxes(dt_boxes):
     _boxes = list(sorted_boxes)
 
     for i in range(num_boxes - 1):
-        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
-                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
-            tmp = _boxes[i]
-            _boxes[i] = _boxes[i + 1]
-            _boxes[i + 1] = tmp
+        for j in range(i, 0, -1):
+            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
+                    (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+                tmp = _boxes[j]
+                _boxes[j] = _boxes[j + 1]
+                _boxes[j + 1] = tmp
+            else:
+                break
     return _boxes
```
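The old single comparison could only swap a box with its immediate neighbor, so a box that belonged several positions earlier on the same text line stayed out of order; the inner loop now bubbles it back until the same-line condition fails. A standalone sketch of the idea on toy boxes (plain Python lists in place of the usual (N, 4, 2) numpy array; note that this sketch lets the inner loop run all the way down to index 0):

```python
def sort_reading_order(boxes, y_tol=10):
    """Sort quadrilateral boxes top-to-bottom, then left-to-right within a line.

    Each box is four (x, y) points and box[0] is its top-left corner.
    """
    _boxes = sorted(boxes, key=lambda b: (b[0][1], b[0][0]))
    for i in range(len(_boxes) - 1):
        # Bubble box i+1 leftwards while it sits on the same line as the box
        # before it (|dy| < y_tol) but starts further to the left.
        for j in range(i, -1, -1):
            if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < y_tol and \
                    _boxes[j + 1][0][0] < _boxes[j][0][0]:
                _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
            else:
                break
    return _boxes


# Three boxes on one visual line, detected out of reading order.
boxes = [
    [[30, 2], [60, 2], [60, 12], [30, 12]],   # middle word
    [[65, 0], [95, 0], [95, 10], [65, 10]],   # rightmost word
    [[0, 4], [25, 4], [25, 14], [0, 14]],     # leftmost word
]
print([b[0][0] for b in sort_reading_order(boxes)])  # [0, 30, 65]
```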
base64_to_cv2 (tools):

```diff
@@ -549,7 +549,7 @@ def text_visual(texts,
 def base64_to_cv2(b64str):
     import base64
     data = base64.b64decode(b64str.encode('utf8'))
-    data = np.fromstring(data, np.uint8)
+    data = np.frombuffer(data, np.uint8)
     data = cv2.imdecode(data, cv2.IMREAD_COLOR)
     return data
```
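`np.fromstring` is deprecated for binary input; `np.frombuffer` yields the same `uint8` array without the DeprecationWarning (it returns a read-only view, which is fine here because `cv2.imdecode` only reads the buffer). A small round-trip check, assuming `numpy` and `opencv-python` are available:

```python
import base64

import cv2
import numpy as np

# Encode a tiny dummy image to PNG bytes and base64, as a client would send it.
img = np.zeros((8, 8, 3), dtype=np.uint8)
ok, png = cv2.imencode(".png", img)
b64str = base64.b64encode(png.tobytes()).decode("utf8")

# Decode it back the way the patched helper does.
raw = base64.b64decode(b64str.encode("utf8"))
buf = np.frombuffer(raw, np.uint8)     # no DeprecationWarning, read-only view
decoded = cv2.imdecode(buf, cv2.IMREAD_COLOR)
print(decoded.shape)                   # (8, 8, 3)
```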
KIE infer tool (tools):

```diff
@@ -88,6 +88,29 @@ def draw_kie_result(batch, node, idx_to_cls, count):
     cv2.imwrite(save_path, vis_img)
     logger.info("The Kie Image saved in {}".format(save_path))
 
+
+def write_kie_result(fout, node, data):
+    """
+    Write infer result to output file, sorted by the predict label of each line.
+    The format keeps the same as the input with additional score attribute.
+    """
+    import json
+    label = data['label']
+    annotations = json.loads(label)
+    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
+    node_pred_label = max_idx.numpy().tolist()
+    node_pred_score = max_value.numpy().tolist()
+    res = []
+    for i, label in enumerate(node_pred_label):
+        pred_score = '{:.2f}'.format(node_pred_score[i])
+        pred_res = {
+            'label': label,
+            'transcription': annotations[i]['transcription'],
+            'score': pred_score,
+            'points': annotations[i]['points'],
+        }
+        res.append(pred_res)
+    res.sort(key=lambda x: x['label'])
+    fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
+
+
 def main():
     global_config = config['Global']
```
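Each processed image ends up as one JSON line pairing the predicted label id and its softmax score with the input transcription and points, sorted by label. A rough sketch of what such a line looks like, mimicking the `paddle.max`/`paddle.argmax` step with numpy on made-up scores:

```python
import json

import numpy as np

# Made-up per-line class probabilities for two text lines and three classes.
node = np.array([[0.10, 0.85, 0.05],
                 [0.70, 0.20, 0.10]])
annotations = [
    {"transcription": "Invoice No: 123", "points": [[0, 0], [50, 0], [50, 10], [0, 10]]},
    {"transcription": "Total: 42.00", "points": [[0, 20], [50, 20], [50, 30], [0, 30]]},
]

res = []
for i, (label, score) in enumerate(zip(node.argmax(-1), node.max(-1))):
    res.append({
        "label": int(label),
        "transcription": annotations[i]["transcription"],
        "score": "{:.2f}".format(float(score)),
        "points": annotations[i]["points"],
    })
res.sort(key=lambda x: x["label"])
print(json.dumps(res, ensure_ascii=False))  # one such line is written per image
```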
```diff
@@ -114,7 +137,7 @@ def main():
 
     warmup_times = 0
     count_t = []
-    with open(save_res_path, "wb") as fout:
+    with open(save_res_path, "w") as fout:
         with open(config['Global']['infer_img'], "rb") as f:
             lines = f.readlines()
             for index, data_line in enumerate(lines):
```
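The result file is opened in text mode because `write_kie_result` writes the `str` returned by `json.dumps`; writing a `str` to a file opened with `"wb"` raises a `TypeError`. A minimal illustration (toy path):

```python
import json

line = json.dumps({"label": 1}, ensure_ascii=False) + "\n"

with open("kie_results.txt", "w", encoding="utf-8") as fout:  # text mode accepts str
    fout.write(line)

# with open("kie_results.txt", "wb") as fout:
#     fout.write(line)  # TypeError: a bytes-like object is required, not 'str'
```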
```diff
@@ -139,6 +162,8 @@ def main():
                 node = F.softmax(node, -1)
                 count_t.append(time.time() - st)
                 draw_kie_result(batch, node, idx_to_cls, index)
+                write_kie_result(fout, node, data)
+    fout.close()
     logger.info("success!")
     logger.info("It took {} s for predict {} images.".format(
         np.sum(count_t), len(count_t)))
```