Merge remote-tracking branch 'origin/dygraph' into dygraph
|
@ -0,0 +1,273 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--filename", type=str, help="The name of log which need to analysis.")
|
||||
parser.add_argument(
|
||||
"--log_with_profiler", type=str, help="The path of train log with profiler")
|
||||
parser.add_argument(
|
||||
"--profiler_path", type=str, help="The path of profiler timeline log.")
|
||||
parser.add_argument(
|
||||
"--keyword", type=str, help="Keyword to specify analysis data")
|
||||
parser.add_argument(
|
||||
"--separator", type=str, default=None, help="Separator of different field in log")
|
||||
parser.add_argument(
|
||||
'--position', type=int, default=None, help='The position of data field')
|
||||
parser.add_argument(
|
||||
'--range', type=str, default="", help='The range of data field to intercept')
|
||||
parser.add_argument(
|
||||
'--base_batch_size', type=int, help='base_batch size on gpu')
|
||||
parser.add_argument(
|
||||
'--skip_steps', type=int, default=0, help='The number of steps to be skipped')
|
||||
parser.add_argument(
|
||||
'--model_mode', type=int, default=-1, help='Analysis mode, default value is -1')
|
||||
parser.add_argument(
|
||||
'--ips_unit', type=str, default=None, help='IPS unit')
|
||||
parser.add_argument(
|
||||
'--model_name', type=str, default=0, help='training model_name, transformer_base')
|
||||
parser.add_argument(
|
||||
'--mission_name', type=str, default=0, help='training mission name')
|
||||
parser.add_argument(
|
||||
'--direction_id', type=int, default=0, help='training direction_id')
|
||||
parser.add_argument(
|
||||
'--run_mode', type=str, default="sp", help='multi process or single process')
|
||||
parser.add_argument(
|
||||
'--index', type=int, default=1, help='{1: speed, 2:mem, 3:profiler, 6:max_batch_size}')
|
||||
parser.add_argument(
|
||||
'--gpu_num', type=int, default=1, help='nums of training gpus')
|
||||
args = parser.parse_args()
|
||||
args.separator = None if args.separator == "None" else args.separator
|
||||
return args
|
||||
|
||||
|
||||
def _is_number(num):
|
||||
pattern = re.compile(r'^[-+]?[-0-9]\d*\.\d*|[-+]?\.?[0-9]\d*$')
|
||||
result = pattern.match(num)
|
||||
if result:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class TimeAnalyzer(object):
|
||||
def __init__(self, filename, keyword=None, separator=None, position=None, range="-1"):
|
||||
if filename is None:
|
||||
raise Exception("Please specify the filename!")
|
||||
|
||||
if keyword is None:
|
||||
raise Exception("Please specify the keyword!")
|
||||
|
||||
self.filename = filename
|
||||
self.keyword = keyword
|
||||
self.separator = separator
|
||||
self.position = position
|
||||
self.range = range
|
||||
self.records = None
|
||||
self._distil()
|
||||
|
||||
def _distil(self):
|
||||
self.records = []
|
||||
with open(self.filename, "r") as f_object:
|
||||
lines = f_object.readlines()
|
||||
for line in lines:
|
||||
if self.keyword not in line:
|
||||
continue
|
||||
try:
|
||||
result = None
|
||||
|
||||
# Distil the string from a line.
|
||||
line = line.strip()
|
||||
line_words = line.split(self.separator) if self.separator else line.split()
|
||||
if args.position:
|
||||
result = line_words[self.position]
|
||||
else:
|
||||
# Distil the string following the keyword.
|
||||
for i in range(len(line_words) - 1):
|
||||
if line_words[i] == self.keyword:
|
||||
result = line_words[i + 1]
|
||||
break
|
||||
|
||||
# Distil the result from the picked string.
|
||||
if not self.range:
|
||||
result = result[0:]
|
||||
elif _is_number(self.range):
|
||||
result = result[0: int(self.range)]
|
||||
else:
|
||||
result = result[int(self.range.split(":")[0]): int(self.range.split(":")[1])]
|
||||
self.records.append(float(result))
|
||||
except Exception as exc:
|
||||
print("line is: {}; separator={}; position={}".format(line, self.separator, self.position))
|
||||
|
||||
print("Extract {} records: separator={}; position={}".format(len(self.records), self.separator, self.position))
|
||||
|
||||
def _get_fps(self, mode, batch_size, gpu_num, avg_of_records, run_mode, unit=None):
|
||||
if mode == -1 and run_mode == 'sp':
|
||||
assert unit, "Please set the unit when mode is -1."
|
||||
fps = gpu_num * avg_of_records
|
||||
elif mode == -1 and run_mode == 'mp':
|
||||
assert unit, "Please set the unit when mode is -1."
|
||||
fps = gpu_num * avg_of_records #temporarily, not used now
|
||||
print("------------this is mp")
|
||||
elif mode == 0:
|
||||
# s/step -> samples/s
|
||||
fps = (batch_size * gpu_num) / avg_of_records
|
||||
unit = "samples/s"
|
||||
elif mode == 1:
|
||||
# steps/s -> steps/s
|
||||
fps = avg_of_records
|
||||
unit = "steps/s"
|
||||
elif mode == 2:
|
||||
# s/step -> steps/s
|
||||
fps = 1 / avg_of_records
|
||||
unit = "steps/s"
|
||||
elif mode == 3:
|
||||
# steps/s -> samples/s
|
||||
fps = batch_size * gpu_num * avg_of_records
|
||||
unit = "samples/s"
|
||||
elif mode == 4:
|
||||
# s/epoch -> s/epoch
|
||||
fps = avg_of_records
|
||||
unit = "s/epoch"
|
||||
else:
|
||||
ValueError("Unsupported analysis mode.")
|
||||
|
||||
return fps, unit
|
||||
|
||||
def analysis(self, batch_size, gpu_num=1, skip_steps=0, mode=-1, run_mode='sp', unit=None):
|
||||
if batch_size <= 0:
|
||||
print("base_batch_size should larger than 0.")
|
||||
return 0, ''
|
||||
|
||||
if len(self.records) <= skip_steps: # to address the condition which item of log equals to skip_steps
|
||||
print("no records")
|
||||
return 0, ''
|
||||
|
||||
sum_of_records = 0
|
||||
sum_of_records_skipped = 0
|
||||
skip_min = self.records[skip_steps]
|
||||
skip_max = self.records[skip_steps]
|
||||
|
||||
count = len(self.records)
|
||||
for i in range(count):
|
||||
sum_of_records += self.records[i]
|
||||
if i >= skip_steps:
|
||||
sum_of_records_skipped += self.records[i]
|
||||
if self.records[i] < skip_min:
|
||||
skip_min = self.records[i]
|
||||
if self.records[i] > skip_max:
|
||||
skip_max = self.records[i]
|
||||
|
||||
avg_of_records = sum_of_records / float(count)
|
||||
avg_of_records_skipped = sum_of_records_skipped / float(count - skip_steps)
|
||||
|
||||
fps, fps_unit = self._get_fps(mode, batch_size, gpu_num, avg_of_records, run_mode, unit)
|
||||
fps_skipped, _ = self._get_fps(mode, batch_size, gpu_num, avg_of_records_skipped, run_mode, unit)
|
||||
if mode == -1:
|
||||
print("average ips of %d steps, skip 0 step:" % count)
|
||||
print("\tAvg: %.3f %s" % (avg_of_records, fps_unit))
|
||||
print("\tFPS: %.3f %s" % (fps, fps_unit))
|
||||
if skip_steps > 0:
|
||||
print("average ips of %d steps, skip %d steps:" % (count, skip_steps))
|
||||
print("\tAvg: %.3f %s" % (avg_of_records_skipped, fps_unit))
|
||||
print("\tMin: %.3f %s" % (skip_min, fps_unit))
|
||||
print("\tMax: %.3f %s" % (skip_max, fps_unit))
|
||||
print("\tFPS: %.3f %s" % (fps_skipped, fps_unit))
|
||||
elif mode == 1 or mode == 3:
|
||||
print("average latency of %d steps, skip 0 step:" % count)
|
||||
print("\tAvg: %.3f steps/s" % avg_of_records)
|
||||
print("\tFPS: %.3f %s" % (fps, fps_unit))
|
||||
if skip_steps > 0:
|
||||
print("average latency of %d steps, skip %d steps:" % (count, skip_steps))
|
||||
print("\tAvg: %.3f steps/s" % avg_of_records_skipped)
|
||||
print("\tMin: %.3f steps/s" % skip_min)
|
||||
print("\tMax: %.3f steps/s" % skip_max)
|
||||
print("\tFPS: %.3f %s" % (fps_skipped, fps_unit))
|
||||
elif mode == 0 or mode == 2:
|
||||
print("average latency of %d steps, skip 0 step:" % count)
|
||||
print("\tAvg: %.3f s/step" % avg_of_records)
|
||||
print("\tFPS: %.3f %s" % (fps, fps_unit))
|
||||
if skip_steps > 0:
|
||||
print("average latency of %d steps, skip %d steps:" % (count, skip_steps))
|
||||
print("\tAvg: %.3f s/step" % avg_of_records_skipped)
|
||||
print("\tMin: %.3f s/step" % skip_min)
|
||||
print("\tMax: %.3f s/step" % skip_max)
|
||||
print("\tFPS: %.3f %s" % (fps_skipped, fps_unit))
|
||||
|
||||
return round(fps_skipped, 3), fps_unit
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
run_info = dict()
|
||||
run_info["log_file"] = args.filename
|
||||
run_info["model_name"] = args.model_name
|
||||
run_info["mission_name"] = args.mission_name
|
||||
run_info["direction_id"] = args.direction_id
|
||||
run_info["run_mode"] = args.run_mode
|
||||
run_info["index"] = args.index
|
||||
run_info["gpu_num"] = args.gpu_num
|
||||
run_info["FINAL_RESULT"] = 0
|
||||
run_info["JOB_FAIL_FLAG"] = 0
|
||||
|
||||
try:
|
||||
if args.index == 1:
|
||||
if args.gpu_num == 1:
|
||||
run_info["log_with_profiler"] = args.log_with_profiler
|
||||
run_info["profiler_path"] = args.profiler_path
|
||||
analyzer = TimeAnalyzer(args.filename, args.keyword, args.separator, args.position, args.range)
|
||||
run_info["FINAL_RESULT"], run_info["UNIT"] = analyzer.analysis(
|
||||
batch_size=args.base_batch_size,
|
||||
gpu_num=args.gpu_num,
|
||||
skip_steps=args.skip_steps,
|
||||
mode=args.model_mode,
|
||||
run_mode=args.run_mode,
|
||||
unit=args.ips_unit)
|
||||
try:
|
||||
if int(os.getenv('job_fail_flag')) == 1 or int(run_info["FINAL_RESULT"]) == 0:
|
||||
run_info["JOB_FAIL_FLAG"] = 1
|
||||
except:
|
||||
pass
|
||||
elif args.index == 3:
|
||||
run_info["FINAL_RESULT"] = {}
|
||||
records_fo_total = TimeAnalyzer(args.filename, 'Framework overhead', None, 3, '').records
|
||||
records_fo_ratio = TimeAnalyzer(args.filename, 'Framework overhead', None, 5).records
|
||||
records_ct_total = TimeAnalyzer(args.filename, 'Computation time', None, 3, '').records
|
||||
records_gm_total = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 4, '').records
|
||||
records_gm_ratio = TimeAnalyzer(args.filename, 'GpuMemcpy Calls', None, 6).records
|
||||
records_gmas_total = TimeAnalyzer(args.filename, 'GpuMemcpyAsync Calls', None, 4, '').records
|
||||
records_gms_total = TimeAnalyzer(args.filename, 'GpuMemcpySync Calls', None, 4, '').records
|
||||
run_info["FINAL_RESULT"]["Framework_Total"] = records_fo_total[0] if records_fo_total else 0
|
||||
run_info["FINAL_RESULT"]["Framework_Ratio"] = records_fo_ratio[0] if records_fo_ratio else 0
|
||||
run_info["FINAL_RESULT"]["ComputationTime_Total"] = records_ct_total[0] if records_ct_total else 0
|
||||
run_info["FINAL_RESULT"]["GpuMemcpy_Total"] = records_gm_total[0] if records_gm_total else 0
|
||||
run_info["FINAL_RESULT"]["GpuMemcpy_Ratio"] = records_gm_ratio[0] if records_gm_ratio else 0
|
||||
run_info["FINAL_RESULT"]["GpuMemcpyAsync_Total"] = records_gmas_total[0] if records_gmas_total else 0
|
||||
run_info["FINAL_RESULT"]["GpuMemcpySync_Total"] = records_gms_total[0] if records_gms_total else 0
|
||||
else:
|
||||
print("Not support!")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print("{}".format(json.dumps(run_info))) # it's required, for the log file path insert to the database
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
|
||||
# PaddleOCR DB/EAST 算法训练benchmark测试
|
||||
|
||||
PaddleOCR/benchmark目录下的文件用于获取并分析训练日志。
|
||||
训练采用icdar2015数据集,包括1000张训练图像和500张测试图像。模型配置采用resnet18_vd作为backbone,分别训练batch_size=8和batch_size=16的情况。
|
||||
|
||||
## 运行训练benchmark
|
||||
|
||||
benchmark/run_det.sh 中包含了三个过程:
|
||||
- 安装依赖
|
||||
- 下载数据
|
||||
- 执行训练
|
||||
- 日志分析获取IPS
|
||||
|
||||
在执行训练部分,会执行单机单卡(默认0号卡)单机多卡训练,并分别执行batch_size=8和batch_size=16的情况。所以执行完后,每种模型会得到4个日志文件。
|
||||
|
||||
run_det.sh 执行方式如下:
|
||||
|
||||
```
|
||||
# cd PaddleOCR/
|
||||
bash benchmark/run_det.sh
|
||||
```
|
||||
|
||||
以DB为例,将得到四个日志文件,如下:
|
||||
```
|
||||
det_res18_db_v2.0_sp_bs16_fp32_1
|
||||
det_res18_db_v2.0_sp_bs8_fp32_1
|
||||
det_res18_db_v2.0_mp_bs16_fp32_1
|
||||
det_res18_db_v2.0_mp_bs8_fp32_1
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env bash
|
||||
set -xe
|
||||
# 运行示例:CUDA_VISIBLE_DEVICES=0 bash run_benchmark.sh ${run_mode} ${bs_item} ${fp_item} 500 ${model_mode}
|
||||
# 参数说明
|
||||
function _set_params(){
|
||||
run_mode=${1:-"sp"} # 单卡sp|多卡mp
|
||||
batch_size=${2:-"64"}
|
||||
fp_item=${3:-"fp32"} # fp32|fp16
|
||||
max_iter=${4:-"500"} # 可选,如果需要修改代码提前中断
|
||||
model_name=${5:-"model_name"}
|
||||
run_log_path=${TRAIN_LOG_DIR:-$(pwd)} # TRAIN_LOG_DIR 后续QA设置该参数
|
||||
|
||||
# 以下不用修改
|
||||
device=${CUDA_VISIBLE_DEVICES//,/ }
|
||||
arr=(${device})
|
||||
num_gpu_devices=${#arr[*]}
|
||||
log_file=${run_log_path}/${model_name}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
|
||||
}
|
||||
function _train(){
|
||||
echo "Train on ${num_gpu_devices} GPUs"
|
||||
echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size"
|
||||
|
||||
train_cmd="-c configs/det/${model_name}.yml -o Train.loader.batch_size_per_card=${batch_size} Global.epoch_num=${max_iter} "
|
||||
case ${run_mode} in
|
||||
sp)
|
||||
train_cmd="python3.7 tools/train.py "${train_cmd}""
|
||||
;;
|
||||
mp)
|
||||
train_cmd="python3.7 -m paddle.distributed.launch --log_dir=./mylog --gpus=$CUDA_VISIBLE_DEVICES tools/train.py ${train_cmd}"
|
||||
;;
|
||||
*) echo "choose run_mode(sp or mp)"; exit 1;
|
||||
esac
|
||||
# 以下不用修改
|
||||
timeout 15m ${train_cmd} > ${log_file} 2>&1
|
||||
if [ $? -ne 0 ];then
|
||||
echo -e "${model_name}, FAIL"
|
||||
export job_fail_flag=1
|
||||
else
|
||||
echo -e "${model_name}, SUCCESS"
|
||||
export job_fail_flag=0
|
||||
fi
|
||||
kill -9 `ps -ef|grep 'python3.7'|awk '{print $2}'`
|
||||
|
||||
if [ $run_mode = "mp" -a -d mylog ]; then
|
||||
rm ${log_file}
|
||||
cp mylog/workerlog.0 ${log_file}
|
||||
fi
|
||||
|
||||
# run log analysis
|
||||
analysis_cmd="python3.7 benchmark/analysis.py --filename ${log_file} --mission_name ${model_name} --run_mode ${mode} --direction_id 0 --keyword 'ips:' --base_batch_size ${batch_szie} --skip_steps 1 --gpu_num ${num_gpu_devices} --index 1 --model_mode=-1 --ips_unit=samples/sec"
|
||||
eval $analysis_cmd
|
||||
}
|
||||
|
||||
_set_params $@
|
||||
_train
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# 提供可稳定复现性能的脚本,默认在标准docker环境内py37执行: paddlepaddle/paddle:latest-gpu-cuda10.1-cudnn7 paddle=2.1.2 py=37
|
||||
# 执行目录: ./PaddleOCR
|
||||
# 1 安装该模型需要的依赖 (如需开启优化策略请注明)
|
||||
python3.7 -m pip install -r requirements.txt
|
||||
# 2 拷贝该模型需要数据、预训练模型
|
||||
wget -c -p ./tain_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar && cd train_data && tar xf icdar2015.tar && cd ../
|
||||
wget -c -p ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams
|
||||
# 3 批量运行(如不方便批量,1,2需放到单个模型中)
|
||||
|
||||
model_mode_list=(det_res18_db_v2.0 det_r50_vd_east)
|
||||
fp_item_list=(fp32)
|
||||
bs_list=(8 16)
|
||||
for model_mode in ${model_mode_list[@]}; do
|
||||
for fp_item in ${fp_item_list[@]}; do
|
||||
for bs_item in ${bs_list[@]}; do
|
||||
echo "index is speed, 1gpus, begin, ${model_name}"
|
||||
run_mode=sp
|
||||
CUDA_VISIBLE_DEVICES=0 bash benchmark/run_benchmark_det.sh ${run_mode} ${bs_item} ${fp_item} 10 ${model_mode} # (5min)
|
||||
sleep 60
|
||||
echo "index is speed, 8gpus, run_mode is multi_process, begin, ${model_name}"
|
||||
run_mode=mp
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash benchmark/run_benchmark_det.sh ${run_mode} ${bs_item} ${fp_item} 10 ${model_mode}
|
||||
sleep 60
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
|
@ -141,6 +141,7 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- CopyPaste:
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
|
|
|
@ -68,8 +68,7 @@ Loss:
|
|||
ohem_ratio: 3
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
# key: maps
|
||||
model_name_list: ["Student"]
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
|
@ -116,6 +115,7 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- CopyPaste:
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
|
|
|
@ -118,6 +118,7 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- CopyPaste:
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
|
|
|
@ -8,7 +8,7 @@ Global:
|
|||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [4000, 5000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_pretrained/
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_res18/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet18_vd_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -94,7 +94,7 @@ Eval:
|
|||
label_file_list: [./train_data/total_text/test/test.txt]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- E2ELabelEncodeTest:
|
||||
- E2EResizeForTest:
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 800
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_mobile_pp-OCRv2_enhanced_ctc_loss
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
character_type: ch
|
||||
max_text_length: 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_mobile_pp-OCRv2_enhanced_ctc_loss.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs : [700, 800]
|
||||
values : [0.001, 0.0001]
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 2.0e-05
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
return_feats: true
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- CTCLoss:
|
||||
use_focal_loss: false
|
||||
weight: 1.0
|
||||
- CenterLoss:
|
||||
weight: 0.05
|
||||
num_classes: 6625
|
||||
feat_dim: 96
|
||||
init_center: false
|
||||
center_file_path: "./train_center.pkl"
|
||||
# you can also try to add ace loss on your own dataset
|
||||
# - ACELoss:
|
||||
# weight: 0.1
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
- label_ace
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: 128
|
||||
drop_last: true
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 8
|
|
@ -0,0 +1,109 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 400
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/seed
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: EN_symbol
|
||||
max_text_length: 100
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_seed.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adadelta
|
||||
weight_deacy: 0.0
|
||||
momentum: 0.9
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [4,5,8]
|
||||
values: [1.0, 0.1, 0.01]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 2.0e-05
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SEED
|
||||
Transform:
|
||||
name: STN_ON
|
||||
tps_inputsize: [32, 64]
|
||||
tps_outputsize: [32, 100]
|
||||
num_control_points: 20
|
||||
tps_margins: [0.05,0.05]
|
||||
stn_activation: none
|
||||
Backbone:
|
||||
name: ResNet_ASTER
|
||||
Head:
|
||||
name: AsterHead # AttentionHead
|
||||
sDim: 512
|
||||
attDim: 512
|
||||
max_len_labels: 100
|
||||
|
||||
Loss:
|
||||
name: AsterLoss
|
||||
|
||||
PostProcess:
|
||||
name: SEEDLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
is_filter: True
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- Fasttext:
|
||||
path: "./cc.en.300.bin"
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
character_type: en
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 6
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/evaluation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
character_type: en
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: True
|
||||
batch_size_per_card: 256
|
||||
num_workers: 4
|
|
@ -112,12 +112,16 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||
1 << 20, 10, 3,
|
||||
precision,
|
||||
false, false);
|
||||
|
||||
std::map<std::string, std::vector<int>> min_input_shape = {
|
||||
{"x", {1, 3, 32, 10}}};
|
||||
{"x", {1, 3, 32, 10}},
|
||||
{"lstm_0.tmp_0", {10, 1, 96}}};
|
||||
std::map<std::string, std::vector<int>> max_input_shape = {
|
||||
{"x", {1, 3, 32, 2000}}};
|
||||
{"x", {1, 3, 32, 2000}},
|
||||
{"lstm_0.tmp_0", {1000, 1, 96}}};
|
||||
std::map<std::string, std::vector<int>> opt_input_shape = {
|
||||
{"x", {1, 3, 32, 320}}};
|
||||
{"x", {1, 3, 32, 320}},
|
||||
{"lstm_0.tmp_0", {25, 1, 96}}};
|
||||
|
||||
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
|
||||
opt_input_shape);
|
||||
|
@ -139,7 +143,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||
config.SwitchIrOptim(true);
|
||||
|
||||
config.EnableMemoryOptim();
|
||||
config.DisableGlogInfo();
|
||||
// config.DisableGlogInfo();
|
||||
|
||||
this->predictor_ = CreatePredictor(config);
|
||||
}
|
||||
|
|
|
@ -110,25 +110,42 @@ def main(config, device, logger, vdl_writer):
|
|||
logger.info("metric['hmean']: {}".format(metric['hmean']))
|
||||
return metric['hmean']
|
||||
|
||||
run_sensitive_analysis = False
|
||||
"""
|
||||
run_sensitive_analysis=True:
|
||||
Automatically compute the sensitivities of convolutions in a model.
|
||||
The sensitivity of a convolution is the losses of accuracy on test dataset in
|
||||
differenct pruned ratios. The sensitivities can be used to get a group of best
|
||||
ratios with some condition.
|
||||
|
||||
run_sensitive_analysis=False:
|
||||
Set prune trim ratio to a fixed value, such as 10%. The larger the value,
|
||||
the more convolution weights will be cropped.
|
||||
|
||||
"""
|
||||
|
||||
if run_sensitive_analysis:
|
||||
params_sensitive = pruner.sensitive(
|
||||
eval_func=eval_fn,
|
||||
sen_file="./sen.pickle",
|
||||
sen_file="./deploy/slim/prune/sen.pickle",
|
||||
skip_vars=[
|
||||
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
|
||||
"conv2d_57.w_0", "conv2d_transpose_2.w_0",
|
||||
"conv2d_transpose_3.w_0"
|
||||
])
|
||||
|
||||
logger.info(
|
||||
"The sensitivity analysis results of model parameters saved in sen.pickle"
|
||||
)
|
||||
# calculate pruned params's ratio
|
||||
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
|
||||
params_sensitive = pruner._get_ratios_by_loss(
|
||||
params_sensitive, loss=0.02)
|
||||
for key in params_sensitive.keys():
|
||||
logger.info("{}, {}".format(key, params_sensitive[key]))
|
||||
|
||||
#params_sensitive = {}
|
||||
#for param in model.parameters():
|
||||
# if 'transpose' not in param.name and 'linear' not in param.name:
|
||||
# params_sensitive[param.name] = 0.1
|
||||
else:
|
||||
params_sensitive = {}
|
||||
for param in model.parameters():
|
||||
if 'transpose' not in param.name and 'linear' not in param.name:
|
||||
# set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
|
||||
params_sensitive[param.name] = 0.1
|
||||
|
||||
plan = pruner.prune_vars(params_sensitive, [0])
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))
|
||||
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
|
||||
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
|
||||
- [x] SEED([paper](https://arxiv.org/pdf/2005.10977.pdf))
|
||||
|
||||
参考[DTRB](https://arxiv.org/abs/1904.01906) 文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -66,5 +67,5 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||
|SAR|Resnet31| 87.2% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|
||||
|SEED| Aster_Resnet | 85.2% | rec_resnet_stn_bilstm_att | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_resnet_stn_bilstm_att.tar)|
|
||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# Enhanced CTC Loss
|
||||
|
||||
在OCR识别中, CRNN是一种在工业界广泛使用的文字识别算法。 在训练阶段,其采用CTCLoss来计算网络损失; 在推理阶段,其采用CTCDecode来获得解码结果。虽然CRNN算法在实际业务中被证明能够获得很好的识别效果, 然而用户对识别准确率的要求却是无止境的,如何进一步提升文字识别的准确率呢? 本文以CTCLoss为切人点,分别从难例挖掘、 多任务学习、 Metric Learning 3个不同的角度探索了CTCLoss的改进融合方案,提出了EnhancedCTCLoss,其包括如下3个组成部分: Focal-CTC Loss,A-CTC Loss, C-CTC Loss。
|
||||
|
||||
## 1. Focal-CTC Loss
|
||||
Focal Loss 出自论文《Focal Loss for Dense Object Detection》, 该loss最先提出的时候主要是为了解决one-stage目标检测中正负样本比例严重失衡的问题。该损失函数降低了大量简单负样本在训练中所占的权重,也可理解为一种困难样本挖掘。
|
||||
其损失函数形式如下:
|
||||
<div align="center">
|
||||
<img src="./focal_loss_formula.png" width = "600" />
|
||||
</div>
|
||||
|
||||
其中, y' 是经过激活函数的输出,取值在0-1之间。其在原始的交叉熵损失的基础上加了一个调制系数(1 – y’)^ γ和平衡因子α。 当α = 1,y=1时,其损失函数与交叉熵损失的对比如下图所示:
|
||||
<div align="center">
|
||||
<img src="./focal_loss_image.png" width = "600" />
|
||||
</div>
|
||||
|
||||
从上图可以看到, 当γ> 0时,调整系数(1-y’)^γ 赋予易分类样本损失一个更小的权重,使得网络更关注于困难的、错分的样本。 调整因子γ用于调节简单样本权重降低的速率,当γ为0时即为交叉熵损失函数,当γ增加时,调整因子的影响也会随之增大。实验发现γ为2是最优。平衡因子α用来平衡正负样本本身的比例不均,文中α取0.25。
|
||||
|
||||
对于经典的CTC算法,假设某个特征序列(f<sub>1</sub>, f<sub>2</sub>, ......f<sub>t</sub>), 经过CTC解码之后结果等于label的概率为y’, 则CTC解码结果不为label的概率即为(1-y’);不难发现 CTCLoss值和y’有如下关系:
|
||||
<div align="center">
|
||||
<img src="./equation_ctcloss.png" width = "250" />
|
||||
</div>
|
||||
|
||||
结合Focal Loss的思想,赋予困难样本较大的权重,简单样本较小的权重,可以使网络更加聚焦于对困难样本的挖掘,进一步提升识别的准确率,由此我们提出了Focal-CTC Loss; 其定义如下所示:
|
||||
<div align="center">
|
||||
<img src="./equation_focal_ctc.png" width = "500" />
|
||||
</div>
|
||||
|
||||
实验中,γ取值为2, α= 1, 具体实现见: [rec_ctc_loss.py](../../ppocr/losses/rec_ctc_loss.py)
|
||||
|
||||
## 2. A-CTC Loss
|
||||
A-CTC Loss是CTC Loss + ACE Loss的简称。 其中ACE Loss出自论文< Aggregation Cross-Entropy for Sequence Recognition>. ACE Loss相比于CTCLoss,主要有如下两点优势:
|
||||
+ ACE Loss能够解决2-D文本的识别问题; CTCLoss只能够处理1-D文本
|
||||
+ ACE Loss 在时间复杂度和空间复杂度上优于CTC loss
|
||||
|
||||
前人总结的OCR识别算法的优劣如下图所示:
|
||||
<div align="center">
|
||||
<img src="./rec_algo_compare.png" width = "1000" />
|
||||
</div>
|
||||
|
||||
虽然ACELoss确实如上图所说,可以处理2D预测,在内存占用及推理速度方面具备优势,但在实践过程中,我们发现单独使用ACE Loss, 识别效果并不如CTCLoss. 因此,我们尝试将CTCLoss和ACELoss进行组合,同时以CTCLoss为主,将ACELoss 定位为一个辅助监督loss。 这一尝试收到了效果,在我们内部的实验数据集上,相比单独使用CTCLoss,识别准确率可以提升1%左右。
|
||||
A_CTC Loss定义如下:
|
||||
<div align="center">
|
||||
<img src="./equation_a_ctc.png" width = "300" />
|
||||
</div>
|
||||
|
||||
实验中,λ = 0.1. ACE loss实现代码见: [ace_loss.py](../../ppocr/losses/ace_loss.py)
|
||||
|
||||
## 3. C-CTC Loss
|
||||
C-CTC Loss是CTC Loss + Center Loss的简称。 其中Center Loss出自论文 < A Discriminative Feature Learning Approach for Deep Face Recognition>. 最早用于人脸识别任务,用于增大累间距离,减小类内距离, 是Metric Learning领域一种较早的、也比较常用的一种算法。
|
||||
在中文OCR识别任务中,通过对badcase分析, 我们发现中文识别的一大难点是相似字符多,容易误识。 由此我们想到是否可以借鉴Metric Learing的想法, 增大相似字符的类间距,从而提高识别准确率。然而,MetricLearning主要用于图像识别领域,训练数据的标签为一个固定的值;而对于OCR识别来说,其本质上是一个序列识别任务,特征和label之间并不具有显式的对齐关系,因此两者如何结合依然是一个值得探索的方向。
|
||||
通过尝试Arcmargin, Cosmargin等方法, 我们最终发现Centerloss 有助于进一步提升识别的准确率。C_CTC Loss定义如下:
|
||||
<div align="center">
|
||||
<img src="./equation_c_ctc.png" width = "300" />
|
||||
</div>
|
||||
|
||||
实验中,我们设置λ=0.25. center_loss实现代码见: [center_loss.py](../../ppocr/losses/center_loss.py)
|
||||
|
||||
值得一提的是, 在C-CTC Loss中,选择随机初始化Center并不能够带来明显的提升. 我们的Center初始化方法如下:
|
||||
+ 基于原始的CTCLoss, 训练得到一个网络N
|
||||
+ 挑选出训练集中,识别完全正确的部分, 组成集合G
|
||||
+ 将G中的每个样本送入网络,进行前向计算, 提取最后一个FC层的输入(即feature)及其经过argmax计算的结果(即index)之间的对应关系
|
||||
+ 将相同index的feature进行聚合,计算平均值,得到各自字符的初始center.
|
||||
|
||||
以配置文件`configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml`为例, center提取命令如下所示:
|
||||
```
|
||||
python tools/export_center.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml -o Global.pretrained_model: "./output/rec_mobile_pp-OCRv2/best_accuracy"
|
||||
```
|
||||
运行完后,会在PaddleOCR主目录下生成`train_center.pkl`.
|
||||
|
||||
## 4. 实验
|
||||
对于上述的三种方案,我们基于百度内部数据集进行了训练、评测,实验情况如下表所示:
|
||||
|algorithm| Focal_CTC | A_CTC | C-CTC |
|
||||
|:------| :------| ------: | :------: |
|
||||
|gain| +0.3% | +0.7% | +1.7% |
|
||||
|
||||
基于上述实验结论,我们在PP-OCRv2中,采用了C-CTC的策略。 值得一提的是,由于PP-OCRv2 处理的是6625个中文字符的识别任务,字符集比较大,形似字较多,所以在该任务上C-CTC 方案带来的提升较大。 但如果换做其他OCR识别任务,结论可能会有所不同。大家可以尝试Focal-CTC,A-CTC, C-CTC以及组合方案EnhancedCTC,相信会带来不同程度的提升效果。
|
||||
统一的融合方案见如下文件: [rec_enhanced_ctc_loss.py](../../ppocr/losses/rec_enhanced_ctc_loss.py)
|
After Width: | Height: | Size: 10 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 9.3 KiB |
After Width: | Height: | Size: 14 KiB |
After Width: | Height: | Size: 23 KiB |
After Width: | Height: | Size: 125 KiB |
After Width: | Height: | Size: 224 KiB |
|
@ -234,6 +234,9 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
|||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
|
||||
| rec_resnet_stn_bilstm_att.yml | SEED | Aster_Resnet | STN | BiLSTM | att |
|
||||
|
||||
*其中SEED模型需要额外加载FastText训练好的[语言模型](https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz)
|
||||
|
||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||
|
||||
|
@ -460,5 +463,3 @@ python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_trai
|
|||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -215,6 +215,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
|
|||
data['length'] = np.array(len(text))
|
||||
text = text + [0] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
|
||||
label = [0] * len(self.character)
|
||||
for x in text:
|
||||
label[x] += 1
|
||||
data['label_ace'] = np.array(label)
|
||||
return data
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
|
@ -342,6 +347,38 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
return idx
|
||||
|
||||
|
||||
class SEEDLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SEEDLabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text)) + 1 # conclude eos
|
||||
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
|
||||
)
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
|
||||
class SRNLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -421,7 +458,6 @@ class TableLabelEncode(object):
|
|||
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
|
||||
for cno in range(1, 1 + character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\r\n")
|
||||
list_character.append(character)
|
||||
|
|
|
@ -23,6 +23,7 @@ import sys
|
|||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import fasttext
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
|
@ -83,12 +84,13 @@ class NRTRDecodeImage(object):
|
|||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
@ -133,6 +135,17 @@ class ToCHWImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class Fasttext(object):
|
||||
def __init__(self, path="None", **kwargs):
|
||||
self.fast_model = fasttext.load_model(path)
|
||||
|
||||
def __call__(self, data):
|
||||
label = data['label']
|
||||
fast_label = self.fast_model[label]
|
||||
data['fast_label'] = fast_label
|
||||
return data
|
||||
|
||||
|
||||
class KeepKeys(object):
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
|
|
@ -88,17 +88,19 @@ class RecResizeImg(object):
|
|||
image_shape,
|
||||
infer_mode=False,
|
||||
character_type='ch',
|
||||
padding=True,
|
||||
**kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.infer_mode = infer_mode
|
||||
self.character_type = character_type
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if self.infer_mode and self.character_type == "ch":
|
||||
norm_img = resize_norm_img_chinese(img, self.image_shape)
|
||||
else:
|
||||
norm_img = resize_norm_img(img, self.image_shape)
|
||||
norm_img = resize_norm_img(img, self.image_shape, self.padding)
|
||||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
|
@ -174,10 +176,15 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
|||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
|
||||
def resize_norm_img(img, image_shape):
|
||||
def resize_norm_img(img, image_shape, padding=True):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
if not padding:
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_w = imgW
|
||||
else:
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
|
|
|
@ -28,6 +28,8 @@ from .rec_att_loss import AttentionLoss
|
|||
from .rec_srn_loss import SRNLoss
|
||||
from .rec_nrtr_loss import NRTRLoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
from .rec_aster_loss import AsterLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
|
@ -48,9 +50,8 @@ def build_loss(config):
|
|||
support_dict = [
|
||||
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
|
||||
'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
|
||||
'TableAttentionLoss', 'SARLoss'
|
||||
'TableAttentionLoss', 'SARLoss', 'AsterLoss'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
|
||||
class ACELoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(
|
||||
weight=None,
|
||||
ignore_index=0,
|
||||
reduction='none',
|
||||
soft_label=True,
|
||||
axis=-1)
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
if isinstance(predicts, (list, tuple)):
|
||||
predicts = predicts[-1]
|
||||
B, N = predicts.shape[:2]
|
||||
div = paddle.to_tensor([N]).astype('float32')
|
||||
|
||||
predicts = nn.functional.softmax(predicts, axis=-1)
|
||||
aggregation_preds = paddle.sum(predicts, axis=1)
|
||||
aggregation_preds = paddle.divide(aggregation_preds, div)
|
||||
|
||||
length = batch[2].astype("float32")
|
||||
batch = batch[3].astype("float32")
|
||||
batch[:, 0] = paddle.subtract(div, length)
|
||||
|
||||
batch = paddle.divide(batch, div)
|
||||
|
||||
loss = self.loss_func(aggregation_preds, batch)
|
||||
|
||||
return {"loss_ace": loss}
|
|
@ -0,0 +1,89 @@
|
|||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class CenterLoss(nn.Layer):
|
||||
"""
|
||||
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_classes=6625,
|
||||
feat_dim=96,
|
||||
init_center=False,
|
||||
center_file_path=None):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.feat_dim = feat_dim
|
||||
self.centers = paddle.randn(
|
||||
shape=[self.num_classes, self.feat_dim]).astype(
|
||||
"float64") #random center
|
||||
|
||||
if init_center:
|
||||
assert os.path.exists(
|
||||
center_file_path
|
||||
), f"center path({center_file_path}) must exist when init_center is set as True."
|
||||
with open(center_file_path, 'rb') as f:
|
||||
char_dict = pickle.load(f)
|
||||
for key in char_dict.keys():
|
||||
self.centers[key] = paddle.to_tensor(char_dict[key])
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
assert isinstance(predicts, (list, tuple))
|
||||
features, predicts = predicts
|
||||
|
||||
feats_reshape = paddle.reshape(
|
||||
features, [-1, features.shape[-1]]).astype("float64")
|
||||
label = paddle.argmax(predicts, axis=2)
|
||||
label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
|
||||
|
||||
batch_size = feats_reshape.shape[0]
|
||||
|
||||
#calc feat * feat
|
||||
dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
|
||||
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
|
||||
|
||||
#dist2 of centers
|
||||
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
|
||||
keepdim=True) #num_classes
|
||||
dist2 = paddle.expand(dist2,
|
||||
[self.num_classes, batch_size]).astype("float64")
|
||||
dist2 = paddle.transpose(dist2, [1, 0])
|
||||
|
||||
#first x * x + y * y
|
||||
distmat = paddle.add(dist1, dist2)
|
||||
tmp = paddle.matmul(feats_reshape,
|
||||
paddle.transpose(self.centers, [1, 0]))
|
||||
distmat = distmat - 2.0 * tmp
|
||||
|
||||
#generate the mask
|
||||
classes = paddle.arange(self.num_classes).astype("int64")
|
||||
label = paddle.expand(
|
||||
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
|
||||
mask = paddle.equal(
|
||||
paddle.expand(classes, [batch_size, self.num_classes]),
|
||||
label).astype("float64") #get mask
|
||||
dist = paddle.multiply(distmat, mask)
|
||||
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
|
||||
return {'loss_center': loss}
|
|
@ -15,6 +15,10 @@
|
|||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .center_loss import CenterLoss
|
||||
from .ace_loss import ACELoss
|
||||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||
|
|
|
@ -112,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
|
|||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
0], pair[1], self.maps_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
|
||||
_c], idx)] = loss
|
||||
|
|
|
@ -0,0 +1,99 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class CosineEmbeddingLoss(nn.Layer):
|
||||
def __init__(self, margin=0.):
|
||||
super(CosineEmbeddingLoss, self).__init__()
|
||||
self.margin = margin
|
||||
self.epsilon = 1e-12
|
||||
|
||||
def forward(self, x1, x2, target):
|
||||
similarity = paddle.fluid.layers.reduce_sum(
|
||||
x1 * x2, dim=-1) / (paddle.norm(
|
||||
x1, axis=-1) * paddle.norm(
|
||||
x2, axis=-1) + self.epsilon)
|
||||
one_list = paddle.full_like(target, fill_value=1)
|
||||
out = paddle.fluid.layers.reduce_mean(
|
||||
paddle.where(
|
||||
paddle.equal(target, one_list), 1. - similarity,
|
||||
paddle.maximum(
|
||||
paddle.zeros_like(similarity), similarity - self.margin)))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class AsterLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
weight=None,
|
||||
size_average=True,
|
||||
ignore_index=-100,
|
||||
sequence_normalize=False,
|
||||
sample_normalize=True,
|
||||
**kwargs):
|
||||
super(AsterLoss, self).__init__()
|
||||
self.weight = weight
|
||||
self.size_average = size_average
|
||||
self.ignore_index = ignore_index
|
||||
self.sequence_normalize = sequence_normalize
|
||||
self.sample_normalize = sample_normalize
|
||||
self.loss_sem = CosineEmbeddingLoss()
|
||||
self.is_cosin_loss = True
|
||||
self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
targets = batch[1].astype("int64")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
sem_target = batch[3].astype('float32')
|
||||
embedding_vectors = predicts['embedding_vectors']
|
||||
rec_pred = predicts['rec_pred']
|
||||
|
||||
if not self.is_cosin_loss:
|
||||
sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
|
||||
else:
|
||||
label_target = paddle.ones([embedding_vectors.shape[0]])
|
||||
sem_loss = paddle.sum(
|
||||
self.loss_sem(embedding_vectors, sem_target, label_target))
|
||||
|
||||
# rec loss
|
||||
batch_size, def_max_length = targets.shape[0], targets.shape[1]
|
||||
|
||||
mask = paddle.zeros([batch_size, def_max_length])
|
||||
for i in range(batch_size):
|
||||
mask[i, :label_lengths[i]] = 1
|
||||
mask = paddle.cast(mask, "float32")
|
||||
max_length = max(label_lengths)
|
||||
assert max_length == rec_pred.shape[1]
|
||||
targets = targets[:, :max_length]
|
||||
mask = mask[:, :max_length]
|
||||
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
|
||||
input = nn.functional.log_softmax(rec_pred, axis=1)
|
||||
targets = paddle.reshape(targets, [-1, 1])
|
||||
mask = paddle.reshape(mask, [-1, 1])
|
||||
output = -paddle.index_sample(input, index=targets) * mask
|
||||
output = paddle.sum(output)
|
||||
if self.sequence_normalize:
|
||||
output = output / paddle.sum(mask)
|
||||
if self.sample_normalize:
|
||||
output = output / batch_size
|
||||
|
||||
loss = output + sem_loss * 0.1
|
||||
return {'loss': loss}
|
|
@ -21,16 +21,24 @@ from paddle import nn
|
|||
|
||||
|
||||
class CTCLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, use_focal_loss=False, **kwargs):
|
||||
super(CTCLoss, self).__init__()
|
||||
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
|
||||
self.use_focal_loss = use_focal_loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
if isinstance(predicts, (list, tuple)):
|
||||
predicts = predicts[-1]
|
||||
predicts = predicts.transpose((1, 0, 2))
|
||||
N, B, _ = predicts.shape
|
||||
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
||||
labels = batch[1].astype("int32")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
|
||||
loss = loss.mean() # sum
|
||||
if self.use_focal_loss:
|
||||
weight = paddle.exp(-loss)
|
||||
weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
|
||||
weight = paddle.square(weight)
|
||||
loss = paddle.multiply(loss, weight)
|
||||
loss = loss.mean()
|
||||
return {'loss': loss}
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from .ace_loss import ACELoss
|
||||
from .center_loss import CenterLoss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
|
||||
|
||||
class EnhancedCTCLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
use_focal_loss=False,
|
||||
use_ace_loss=False,
|
||||
ace_loss_weight=0.1,
|
||||
use_center_loss=False,
|
||||
center_loss_weight=0.05,
|
||||
num_classes=6625,
|
||||
feat_dim=96,
|
||||
init_center=False,
|
||||
center_file_path=None,
|
||||
**kwargs):
|
||||
super(EnhancedCTCLoss, self).__init__()
|
||||
self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
|
||||
|
||||
self.use_ace_loss = False
|
||||
if use_ace_loss:
|
||||
self.use_ace_loss = use_ace_loss
|
||||
self.ace_loss_func = ACELoss()
|
||||
self.ace_loss_weight = ace_loss_weight
|
||||
|
||||
self.use_center_loss = False
|
||||
if use_center_loss:
|
||||
self.use_center_loss = use_center_loss
|
||||
self.center_loss_func = CenterLoss(
|
||||
num_classes=num_classes,
|
||||
feat_dim=feat_dim,
|
||||
init_center=init_center,
|
||||
center_file_path=center_file_path)
|
||||
self.center_loss_weight = center_loss_weight
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
loss = self.ctc_loss_func(predicts, batch)["loss"]
|
||||
|
||||
if self.use_center_loss:
|
||||
center_loss = self.center_loss_func(
|
||||
predicts, batch)["loss_center"] * self.center_loss_weight
|
||||
loss = loss + center_loss
|
||||
|
||||
if self.use_ace_loss:
|
||||
ace_loss = self.ace_loss_func(
|
||||
predicts, batch)["loss_ace"] * self.ace_loss_weight
|
||||
loss = loss + ace_loss
|
||||
|
||||
return {'enhanced_ctc_loss': loss}
|
|
@ -9,11 +9,14 @@ from paddle import nn
|
|||
class SARLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(SARLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(
|
||||
reduction="mean", ignore_index=92)
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||
label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation
|
||||
predict = predicts[:, :
|
||||
-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||
label = batch[1].astype(
|
||||
"int64")[:, 1:] # ignore first index of target in loss calculation
|
||||
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
|
||||
1], predict.shape[2]
|
||||
assert len(label.shape) == len(list(predict.shape)) - 1, \
|
||||
|
|
|
@ -13,13 +13,20 @@
|
|||
# limitations under the License.
|
||||
|
||||
import Levenshtein
|
||||
import string
|
||||
|
||||
|
||||
class RecMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.is_filter = is_filter
|
||||
self.reset()
|
||||
|
||||
def _normalize_text(self, text):
|
||||
text = ''.join(
|
||||
filter(lambda x: x in (string.digits + string.ascii_letters), text))
|
||||
return text.lower()
|
||||
|
||||
def __call__(self, pred_label, *args, **kwargs):
|
||||
preds, labels = pred_label
|
||||
correct_num = 0
|
||||
|
@ -28,6 +35,9 @@ class RecMetric(object):
|
|||
for (pred, pred_conf), (target, _) in zip(preds, labels):
|
||||
pred = pred.replace(" ", "")
|
||||
target = target.replace(" ", "")
|
||||
if self.is_filter:
|
||||
pred = self._normalize_text(pred)
|
||||
target = self._normalize_text(target)
|
||||
norm_edit_dis += Levenshtein.distance(pred, target) / max(
|
||||
len(pred), len(target), 1)
|
||||
if pred == target:
|
||||
|
@ -57,4 +67,3 @@ class RecMetric(object):
|
|||
self.correct_num = 0
|
||||
self.all_num = 0
|
||||
self.norm_edit_dis = 0
|
||||
|
||||
|
|
|
@ -28,8 +28,10 @@ def build_backbone(config, model_type):
|
|||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
from .rec_nrtr_mtb import MTB
|
||||
from .rec_resnet_31 import ResNet31
|
||||
from .rec_resnet_aster import ResNet_ASTER
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31"
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
|
||||
"ResNet31", "ResNet_ASTER"
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
import sys
|
||||
import math
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2D(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias_attr=False)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2D(
|
||||
in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
|
||||
|
||||
|
||||
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
|
||||
# [n_position]
|
||||
positions = paddle.arange(0, n_position)
|
||||
# [feat_dim]
|
||||
dim_range = paddle.arange(0, feat_dim)
|
||||
dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
|
||||
# [n_position, feat_dim]
|
||||
angles = paddle.unsqueeze(
|
||||
positions, axis=1) / paddle.unsqueeze(
|
||||
dim_range, axis=0)
|
||||
angles = paddle.cast(angles, "float32")
|
||||
angles[:, 0::2] = paddle.sin(angles[:, 0::2])
|
||||
angles[:, 1::2] = paddle.cos(angles[:, 1::2])
|
||||
return angles
|
||||
|
||||
|
||||
class AsterBlock(nn.Layer):
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(AsterBlock, self).__init__()
|
||||
self.conv1 = conv1x1(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2D(planes)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2D(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResNet_ASTER(nn.Layer):
|
||||
"""For aster or crnn"""
|
||||
|
||||
def __init__(self, with_lstm=True, n_group=1, in_channels=3):
|
||||
super(ResNet_ASTER, self).__init__()
|
||||
self.with_lstm = with_lstm
|
||||
self.n_group = n_group
|
||||
|
||||
self.layer0 = nn.Sequential(
|
||||
nn.Conv2D(
|
||||
in_channels,
|
||||
32,
|
||||
kernel_size=(3, 3),
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias_attr=False),
|
||||
nn.BatchNorm2D(32),
|
||||
nn.ReLU())
|
||||
|
||||
self.inplanes = 32
|
||||
self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
|
||||
self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
|
||||
self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
|
||||
self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
|
||||
self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
|
||||
|
||||
if with_lstm:
|
||||
self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
|
||||
self.out_channels = 2 * 256
|
||||
else:
|
||||
self.out_channels = 512
|
||||
|
||||
def _make_layer(self, planes, blocks, stride):
|
||||
downsample = None
|
||||
if stride != [1, 1] or self.inplanes != planes:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
|
||||
|
||||
layers = []
|
||||
layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes
|
||||
for _ in range(1, blocks):
|
||||
layers.append(AsterBlock(self.inplanes, planes))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.layer0(x)
|
||||
x1 = self.layer1(x0)
|
||||
x2 = self.layer2(x1)
|
||||
x3 = self.layer3(x2)
|
||||
x4 = self.layer4(x3)
|
||||
x5 = self.layer5(x4)
|
||||
|
||||
cnn_feat = x5.squeeze(2) # [N, c, w]
|
||||
cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
|
||||
if self.with_lstm:
|
||||
rnn_feat, _ = self.rnn(cnn_feat)
|
||||
return rnn_feat
|
||||
else:
|
||||
return cnn_feat
|
|
@ -29,13 +29,14 @@ def build_head(config):
|
|||
from .rec_srn_head import SRNHead
|
||||
from .rec_nrtr_head import Transformer
|
||||
from .rec_sar_head import SARHead
|
||||
from .rec_aster_head import AsterHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
|
||||
'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
|
||||
'TableAttentionHead', 'SARHead'
|
||||
'TableAttentionHead', 'SARHead', 'AsterHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -0,0 +1,389 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
class AsterHead(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
sDim,
|
||||
attDim,
|
||||
max_len_labels,
|
||||
time_step=25,
|
||||
beam_width=5,
|
||||
**kwargs):
|
||||
super(AsterHead, self).__init__()
|
||||
self.num_classes = out_channels
|
||||
self.in_planes = in_channels
|
||||
self.sDim = sDim
|
||||
self.attDim = attDim
|
||||
self.max_len_labels = max_len_labels
|
||||
self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
|
||||
attDim, max_len_labels)
|
||||
self.time_step = time_step
|
||||
self.embeder = Embedding(self.time_step, in_channels)
|
||||
self.beam_width = beam_width
|
||||
self.eos = self.num_classes - 1
|
||||
|
||||
def forward(self, x, targets=None, embed=None):
|
||||
return_dict = {}
|
||||
embedding_vectors = self.embeder(x)
|
||||
|
||||
if self.training:
|
||||
rec_targets, rec_lengths, _ = targets
|
||||
rec_pred = self.decoder([x, rec_targets, rec_lengths],
|
||||
embedding_vectors)
|
||||
return_dict['rec_pred'] = rec_pred
|
||||
return_dict['embedding_vectors'] = embedding_vectors
|
||||
else:
|
||||
rec_pred, rec_pred_scores = self.decoder.beam_search(
|
||||
x, self.beam_width, self.eos, embedding_vectors)
|
||||
return_dict['rec_pred'] = rec_pred
|
||||
return_dict['rec_pred_scores'] = rec_pred_scores
|
||||
return_dict['embedding_vectors'] = embedding_vectors
|
||||
|
||||
return return_dict
|
||||
|
||||
|
||||
class Embedding(nn.Layer):
|
||||
def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
|
||||
super(Embedding, self).__init__()
|
||||
self.in_timestep = in_timestep
|
||||
self.in_planes = in_planes
|
||||
self.embed_dim = embed_dim
|
||||
self.mid_dim = mid_dim
|
||||
self.eEmbed = nn.Linear(
|
||||
in_timestep * in_planes,
|
||||
self.embed_dim) # Embed encoder output to a word-embedding like
|
||||
|
||||
def forward(self, x):
|
||||
x = paddle.reshape(x, [paddle.shape(x)[0], -1])
|
||||
x = self.eEmbed(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionRecognitionHead(nn.Layer):
|
||||
"""
|
||||
input: [b x 16 x 64 x in_planes]
|
||||
output: probability sequence: [b x T x num_classes]
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
|
||||
super(AttentionRecognitionHead, self).__init__()
|
||||
self.num_classes = out_channels # this is the output classes. So it includes the <EOS>.
|
||||
self.in_planes = in_channels
|
||||
self.sDim = sDim
|
||||
self.attDim = attDim
|
||||
self.max_len_labels = max_len_labels
|
||||
|
||||
self.decoder = DecoderUnit(
|
||||
sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
|
||||
|
||||
def forward(self, x, embed):
|
||||
x, targets, lengths = x
|
||||
batch_size = paddle.shape(x)[0]
|
||||
# Decoder
|
||||
state = self.decoder.get_initial_state(embed)
|
||||
outputs = []
|
||||
for i in range(max(lengths)):
|
||||
if i == 0:
|
||||
y_prev = paddle.full(
|
||||
shape=[batch_size], fill_value=self.num_classes)
|
||||
else:
|
||||
y_prev = targets[:, i - 1]
|
||||
output, state = self.decoder(x, state, y_prev)
|
||||
outputs.append(output)
|
||||
outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
|
||||
return outputs
|
||||
|
||||
# inference stage.
|
||||
def sample(self, x):
|
||||
x, _, _ = x
|
||||
batch_size = x.size(0)
|
||||
# Decoder
|
||||
state = paddle.zeros([1, batch_size, self.sDim])
|
||||
|
||||
predicted_ids, predicted_scores = [], []
|
||||
for i in range(self.max_len_labels):
|
||||
if i == 0:
|
||||
y_prev = paddle.full(
|
||||
shape=[batch_size], fill_value=self.num_classes)
|
||||
else:
|
||||
y_prev = predicted
|
||||
|
||||
output, state = self.decoder(x, state, y_prev)
|
||||
output = F.softmax(output, axis=1)
|
||||
score, predicted = output.max(1)
|
||||
predicted_ids.append(predicted.unsqueeze(1))
|
||||
predicted_scores.append(score.unsqueeze(1))
|
||||
predicted_ids = paddle.concat([predicted_ids, 1])
|
||||
predicted_scores = paddle.concat([predicted_scores, 1])
|
||||
# return predicted_ids.squeeze(), predicted_scores.squeeze()
|
||||
return predicted_ids, predicted_scores
|
||||
|
||||
def beam_search(self, x, beam_width, eos, embed):
|
||||
def _inflate(tensor, times, dim):
|
||||
repeat_dims = [1] * tensor.dim()
|
||||
repeat_dims[dim] = times
|
||||
output = paddle.tile(tensor, repeat_dims)
|
||||
return output
|
||||
|
||||
# https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
|
||||
batch_size, l, d = x.shape
|
||||
x = paddle.tile(
|
||||
paddle.transpose(
|
||||
x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
|
||||
inflated_encoder_feats = paddle.reshape(
|
||||
paddle.transpose(
|
||||
x, perm=[1, 0, 2, 3]), [-1, l, d])
|
||||
|
||||
# Initialize the decoder
|
||||
state = self.decoder.get_initial_state(embed, tile_times=beam_width)
|
||||
|
||||
pos_index = paddle.reshape(
|
||||
paddle.arange(batch_size) * beam_width, shape=[-1, 1])
|
||||
|
||||
# Initialize the scores
|
||||
sequence_scores = paddle.full(
|
||||
shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
|
||||
index = [i * beam_width for i in range(0, batch_size)]
|
||||
sequence_scores[index] = 0.0
|
||||
|
||||
# Initialize the input vector
|
||||
y_prev = paddle.full(
|
||||
shape=[batch_size * beam_width], fill_value=self.num_classes)
|
||||
|
||||
# Store decisions for backtracking
|
||||
stored_scores = list()
|
||||
stored_predecessors = list()
|
||||
stored_emitted_symbols = list()
|
||||
|
||||
for i in range(self.max_len_labels):
|
||||
output, state = self.decoder(inflated_encoder_feats, state, y_prev)
|
||||
state = paddle.unsqueeze(state, axis=0)
|
||||
log_softmax_output = paddle.nn.functional.log_softmax(
|
||||
output, axis=1)
|
||||
|
||||
sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
|
||||
sequence_scores += log_softmax_output
|
||||
scores, candidates = paddle.topk(
|
||||
paddle.reshape(sequence_scores, [batch_size, -1]),
|
||||
beam_width,
|
||||
axis=1)
|
||||
|
||||
# Reshape input = (bk, 1) and sequence_scores = (bk, 1)
|
||||
y_prev = paddle.reshape(
|
||||
candidates % self.num_classes, shape=[batch_size * beam_width])
|
||||
sequence_scores = paddle.reshape(
|
||||
scores, shape=[batch_size * beam_width, 1])
|
||||
|
||||
# Update fields for next timestep
|
||||
pos_index = paddle.expand_as(pos_index, candidates)
|
||||
predecessors = paddle.cast(
|
||||
candidates / self.num_classes + pos_index, dtype='int64')
|
||||
predecessors = paddle.reshape(
|
||||
predecessors, shape=[batch_size * beam_width, 1])
|
||||
state = paddle.index_select(
|
||||
state, index=predecessors.squeeze(), axis=1)
|
||||
|
||||
# Update sequence socres and erase scores for <eos> symbol so that they aren't expanded
|
||||
stored_scores.append(sequence_scores.clone())
|
||||
y_prev = paddle.reshape(y_prev, shape=[-1, 1])
|
||||
eos_prev = paddle.full_like(y_prev, fill_value=eos)
|
||||
mask = eos_prev == y_prev
|
||||
mask = paddle.nonzero(mask)
|
||||
if mask.dim() > 0:
|
||||
sequence_scores = sequence_scores.numpy()
|
||||
mask = mask.numpy()
|
||||
sequence_scores[mask] = -float('inf')
|
||||
sequence_scores = paddle.to_tensor(sequence_scores)
|
||||
|
||||
# Cache results for backtracking
|
||||
stored_predecessors.append(predecessors)
|
||||
y_prev = paddle.squeeze(y_prev)
|
||||
stored_emitted_symbols.append(y_prev)
|
||||
|
||||
# Do backtracking to return the optimal values
|
||||
#====== backtrak ======#
|
||||
# Initialize return variables given different types
|
||||
p = list()
|
||||
l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
|
||||
] # Placeholder for lengths of top-k sequences
|
||||
|
||||
# the last step output of the beams are not sorted
|
||||
# thus they are sorted here
|
||||
sorted_score, sorted_idx = paddle.topk(
|
||||
paddle.reshape(
|
||||
stored_scores[-1], shape=[batch_size, beam_width]),
|
||||
beam_width)
|
||||
|
||||
# initialize the sequence scores with the sorted last step beam scores
|
||||
s = sorted_score.clone()
|
||||
|
||||
batch_eos_found = [0] * batch_size # the number of EOS found
|
||||
# in the backward loop below for each batch
|
||||
t = self.max_len_labels - 1
|
||||
# initialize the back pointer with the sorted order of the last step beams.
|
||||
# add pos_index for indexing variable with b*k as the first dimension.
|
||||
t_predecessors = paddle.reshape(
|
||||
sorted_idx + pos_index.expand_as(sorted_idx),
|
||||
shape=[batch_size * beam_width])
|
||||
while t >= 0:
|
||||
# Re-order the variables with the back pointer
|
||||
current_symbol = paddle.index_select(
|
||||
stored_emitted_symbols[t], index=t_predecessors, axis=0)
|
||||
t_predecessors = paddle.index_select(
|
||||
stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
|
||||
eos_indices = stored_emitted_symbols[t] == eos
|
||||
eos_indices = paddle.nonzero(eos_indices)
|
||||
|
||||
if eos_indices.dim() > 0:
|
||||
for i in range(eos_indices.shape[0] - 1, -1, -1):
|
||||
# Indices of the EOS symbol for both variables
|
||||
# with b*k as the first dimension, and b, k for
|
||||
# the first two dimensions
|
||||
idx = eos_indices[i]
|
||||
b_idx = int(idx[0] / beam_width)
|
||||
# The indices of the replacing position
|
||||
# according to the replacement strategy noted above
|
||||
res_k_idx = beam_width - (batch_eos_found[b_idx] %
|
||||
beam_width) - 1
|
||||
batch_eos_found[b_idx] += 1
|
||||
res_idx = b_idx * beam_width + res_k_idx
|
||||
|
||||
# Replace the old information in return variables
|
||||
# with the new ended sequence information
|
||||
t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
|
||||
current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
|
||||
s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
|
||||
l[b_idx][res_k_idx] = t + 1
|
||||
|
||||
# record the back tracked results
|
||||
p.append(current_symbol)
|
||||
t -= 1
|
||||
|
||||
# Sort and re-order again as the added ended sequences may change
|
||||
# the order (very unlikely)
|
||||
s, re_sorted_idx = s.topk(beam_width)
|
||||
for b_idx in range(batch_size):
|
||||
l[b_idx] = [
|
||||
l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
|
||||
]
|
||||
|
||||
re_sorted_idx = paddle.reshape(
|
||||
re_sorted_idx + pos_index.expand_as(re_sorted_idx),
|
||||
[batch_size * beam_width])
|
||||
|
||||
# Reverse the sequences and re-order at the same time
|
||||
# It is reversed because the backtracking happens in reverse time order
|
||||
p = [
|
||||
paddle.reshape(
|
||||
paddle.index_select(step, re_sorted_idx, 0),
|
||||
shape=[batch_size, beam_width, -1]) for step in reversed(p)
|
||||
]
|
||||
p = paddle.concat(p, -1)[:, 0, :]
|
||||
return p, paddle.ones_like(p)
|
||||
|
||||
|
||||
class AttentionUnit(nn.Layer):
|
||||
def __init__(self, sDim, xDim, attDim):
|
||||
super(AttentionUnit, self).__init__()
|
||||
|
||||
self.sDim = sDim
|
||||
self.xDim = xDim
|
||||
self.attDim = attDim
|
||||
|
||||
self.sEmbed = nn.Linear(sDim, attDim)
|
||||
self.xEmbed = nn.Linear(xDim, attDim)
|
||||
self.wEmbed = nn.Linear(attDim, 1)
|
||||
|
||||
def forward(self, x, sPrev):
|
||||
batch_size, T, _ = x.shape # [b x T x xDim]
|
||||
x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim]
|
||||
xProj = self.xEmbed(x) # [(b x T) x attDim]
|
||||
xProj = paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim]
|
||||
|
||||
sPrev = sPrev.squeeze(0)
|
||||
sProj = self.sEmbed(sPrev) # [b x attDim]
|
||||
sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim]
|
||||
sProj = paddle.expand(sProj,
|
||||
[batch_size, T, self.attDim]) # [b x T x attDim]
|
||||
|
||||
sumTanh = paddle.tanh(sProj + xProj)
|
||||
sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
|
||||
|
||||
vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
|
||||
vProj = paddle.reshape(vProj, [batch_size, T])
|
||||
alpha = F.softmax(
|
||||
vProj, axis=1) # attention weights for each sample in the minibatch
|
||||
return alpha
|
||||
|
||||
|
||||
class DecoderUnit(nn.Layer):
|
||||
def __init__(self, sDim, xDim, yDim, attDim):
|
||||
super(DecoderUnit, self).__init__()
|
||||
self.sDim = sDim
|
||||
self.xDim = xDim
|
||||
self.yDim = yDim
|
||||
self.attDim = attDim
|
||||
self.emdDim = attDim
|
||||
|
||||
self.attention_unit = AttentionUnit(sDim, xDim, attDim)
|
||||
self.tgt_embedding = nn.Embedding(
|
||||
yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
|
||||
std=0.01)) # the last is used for <BOS>
|
||||
self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
|
||||
self.fc = nn.Linear(
|
||||
sDim,
|
||||
yDim,
|
||||
weight_attr=nn.initializer.Normal(std=0.01),
|
||||
bias_attr=nn.initializer.Constant(value=0))
|
||||
self.embed_fc = nn.Linear(300, self.sDim)
|
||||
|
||||
def get_initial_state(self, embed, tile_times=1):
|
||||
assert embed.shape[1] == 300
|
||||
state = self.embed_fc(embed) # N * sDim
|
||||
if tile_times != 1:
|
||||
state = state.unsqueeze(1)
|
||||
trans_state = paddle.transpose(state, perm=[1, 0, 2])
|
||||
state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
|
||||
trans_state = paddle.transpose(state, perm=[1, 0, 2])
|
||||
state = paddle.reshape(trans_state, shape=[-1, self.sDim])
|
||||
state = state.unsqueeze(0) # 1 * N * sDim
|
||||
return state
|
||||
|
||||
def forward(self, x, sPrev, yPrev):
|
||||
# x: feature sequence from the image decoder.
|
||||
batch_size, T, _ = x.shape
|
||||
alpha = self.attention_unit(x, sPrev)
|
||||
context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
|
||||
yPrev = paddle.cast(yPrev, dtype="int64")
|
||||
yProj = self.tgt_embedding(yPrev)
|
||||
|
||||
concat_context = paddle.concat([yProj, context], 1)
|
||||
concat_context = paddle.squeeze(concat_context, 1)
|
||||
sPrev = paddle.squeeze(sPrev, 0)
|
||||
output, state = self.gru(concat_context, sPrev)
|
||||
output = paddle.squeeze(output, axis=1)
|
||||
output = self.fc(output)
|
||||
return output, state
|
|
@ -38,6 +38,7 @@ class CTCHead(nn.Layer):
|
|||
out_channels,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
return_feats=False,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
|
@ -66,14 +67,22 @@ class CTCHead(nn.Layer):
|
|||
bias_attr=bias_attr2)
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.return_feats = return_feats
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
predicts = self.fc1(x)
|
||||
predicts = self.fc2(predicts)
|
||||
x = self.fc1(x)
|
||||
predicts = self.fc2(x)
|
||||
|
||||
if self.return_feats:
|
||||
result = (x, predicts)
|
||||
else:
|
||||
result = predicts
|
||||
|
||||
if not self.training:
|
||||
predicts = F.softmax(predicts, axis=2)
|
||||
return predicts
|
||||
result = predicts
|
||||
|
||||
return result
|
||||
|
|
|
@ -51,7 +51,7 @@ class EncoderWithFC(nn.Layer):
|
|||
super(EncoderWithFC, self).__init__()
|
||||
self.out_channels = hidden_size
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea')
|
||||
l2_decay=0.00001, k=in_channels)
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
hidden_size,
|
||||
|
|
|
@ -17,8 +17,9 @@ __all__ = ['build_transform']
|
|||
|
||||
def build_transform(config):
|
||||
from .tps import TPS
|
||||
from .stn import STN_ON
|
||||
|
||||
support_dict = ['TPS']
|
||||
support_dict = ['TPS', 'STN_ON']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception(
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from .tps_spatial_transformer import TPSSpatialTransformer
|
||||
|
||||
|
||||
def conv3x3_block(in_channels, out_channels, stride=1):
|
||||
n = 3 * 3 * out_channels
|
||||
w = math.sqrt(2. / n)
|
||||
conv_layer = nn.Conv2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
weight_attr=nn.initializer.Normal(
|
||||
mean=0.0, std=w),
|
||||
bias_attr=nn.initializer.Constant(0))
|
||||
block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
|
||||
return block
|
||||
|
||||
|
||||
class STN(nn.Layer):
|
||||
def __init__(self, in_channels, num_ctrlpoints, activation='none'):
|
||||
super(STN, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.num_ctrlpoints = num_ctrlpoints
|
||||
self.activation = activation
|
||||
self.stn_convnet = nn.Sequential(
|
||||
conv3x3_block(in_channels, 32), #32x64
|
||||
nn.MaxPool2D(
|
||||
kernel_size=2, stride=2),
|
||||
conv3x3_block(32, 64), #16x32
|
||||
nn.MaxPool2D(
|
||||
kernel_size=2, stride=2),
|
||||
conv3x3_block(64, 128), # 8*16
|
||||
nn.MaxPool2D(
|
||||
kernel_size=2, stride=2),
|
||||
conv3x3_block(128, 256), # 4*8
|
||||
nn.MaxPool2D(
|
||||
kernel_size=2, stride=2),
|
||||
conv3x3_block(256, 256), # 2*4,
|
||||
nn.MaxPool2D(
|
||||
kernel_size=2, stride=2),
|
||||
conv3x3_block(256, 256)) # 1*2
|
||||
self.stn_fc1 = nn.Sequential(
|
||||
nn.Linear(
|
||||
2 * 256,
|
||||
512,
|
||||
weight_attr=nn.initializer.Normal(0, 0.001),
|
||||
bias_attr=nn.initializer.Constant(0)),
|
||||
nn.BatchNorm1D(512),
|
||||
nn.ReLU())
|
||||
fc2_bias = self.init_stn()
|
||||
self.stn_fc2 = nn.Linear(
|
||||
512,
|
||||
num_ctrlpoints * 2,
|
||||
weight_attr=nn.initializer.Constant(0.0),
|
||||
bias_attr=nn.initializer.Assign(fc2_bias))
|
||||
|
||||
def init_stn(self):
|
||||
margin = 0.01
|
||||
sampling_num_per_side = int(self.num_ctrlpoints / 2)
|
||||
ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
|
||||
ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
|
||||
ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
|
||||
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
|
||||
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
|
||||
ctrl_points = np.concatenate(
|
||||
[ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
|
||||
if self.activation == 'none':
|
||||
pass
|
||||
elif self.activation == 'sigmoid':
|
||||
ctrl_points = -np.log(1. / ctrl_points - 1.)
|
||||
ctrl_points = paddle.to_tensor(ctrl_points)
|
||||
fc2_bias = paddle.reshape(
|
||||
ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
|
||||
return fc2_bias
|
||||
|
||||
def forward(self, x):
|
||||
x = self.stn_convnet(x)
|
||||
batch_size, _, h, w = x.shape
|
||||
x = paddle.reshape(x, shape=(batch_size, -1))
|
||||
img_feat = self.stn_fc1(x)
|
||||
x = self.stn_fc2(0.1 * img_feat)
|
||||
if self.activation == 'sigmoid':
|
||||
x = F.sigmoid(x)
|
||||
x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
|
||||
return img_feat, x
|
||||
|
||||
|
||||
class STN_ON(nn.Layer):
|
||||
def __init__(self, in_channels, tps_inputsize, tps_outputsize,
|
||||
num_control_points, tps_margins, stn_activation):
|
||||
super(STN_ON, self).__init__()
|
||||
self.tps = TPSSpatialTransformer(
|
||||
output_image_size=tuple(tps_outputsize),
|
||||
num_control_points=num_control_points,
|
||||
margins=tuple(tps_margins))
|
||||
self.stn_head = STN(in_channels=in_channels,
|
||||
num_ctrlpoints=num_control_points,
|
||||
activation=stn_activation)
|
||||
self.tps_inputsize = tps_inputsize
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, image):
|
||||
stn_input = paddle.nn.functional.interpolate(
|
||||
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
||||
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
||||
x, _ = self.tps(image, ctrl_points)
|
||||
return x
|
|
@ -231,7 +231,8 @@ class GridGenerator(nn.Layer):
|
|||
""" Return inv_delta_C which is needed to calculate T """
|
||||
F = self.F
|
||||
hat_eye = paddle.eye(F, dtype='float64') # F x F
|
||||
hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = paddle.norm(
|
||||
C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = (hat_C**2) * paddle.log(hat_C)
|
||||
delta_C = paddle.concat( # F+3 x F+3
|
||||
[
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
|
||||
def grid_sample(input, grid, canvas=None):
|
||||
input.stop_gradient = False
|
||||
output = F.grid_sample(input, grid)
|
||||
if canvas is None:
|
||||
return output
|
||||
else:
|
||||
input_mask = paddle.ones(shape=input.shape)
|
||||
output_mask = F.grid_sample(input_mask, grid)
|
||||
padded_output = output * output_mask + canvas * (1 - output_mask)
|
||||
return padded_output
|
||||
|
||||
|
||||
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
|
||||
def compute_partial_repr(input_points, control_points):
|
||||
N = input_points.shape[0]
|
||||
M = control_points.shape[0]
|
||||
pairwise_diff = paddle.reshape(
|
||||
input_points, shape=[N, 1, 2]) - paddle.reshape(
|
||||
control_points, shape=[1, M, 2])
|
||||
# original implementation, very slow
|
||||
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
|
||||
pairwise_diff_square = pairwise_diff * pairwise_diff
|
||||
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
|
||||
1]
|
||||
repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
|
||||
# fix numerical error for 0 * log(0), substitute all nan with 0
|
||||
mask = repr_matrix != repr_matrix
|
||||
repr_matrix[mask] = 0
|
||||
return repr_matrix
|
||||
|
||||
|
||||
# output_ctrl_pts are specified, according to our task.
|
||||
def build_output_control_points(num_control_points, margins):
|
||||
margin_x, margin_y = margins
|
||||
num_ctrl_pts_per_side = num_control_points // 2
|
||||
ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
|
||||
ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
|
||||
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
|
||||
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
|
||||
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
|
||||
output_ctrl_pts_arr = np.concatenate(
|
||||
[ctrl_pts_top, ctrl_pts_bottom], axis=0)
|
||||
output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
|
||||
return output_ctrl_pts
|
||||
|
||||
|
||||
class TPSSpatialTransformer(nn.Layer):
|
||||
def __init__(self,
|
||||
output_image_size=None,
|
||||
num_control_points=None,
|
||||
margins=None):
|
||||
super(TPSSpatialTransformer, self).__init__()
|
||||
self.output_image_size = output_image_size
|
||||
self.num_control_points = num_control_points
|
||||
self.margins = margins
|
||||
|
||||
self.target_height, self.target_width = output_image_size
|
||||
target_control_points = build_output_control_points(num_control_points,
|
||||
margins)
|
||||
N = num_control_points
|
||||
|
||||
# create padded kernel matrix
|
||||
forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
|
||||
target_control_partial_repr = compute_partial_repr(
|
||||
target_control_points, target_control_points)
|
||||
target_control_partial_repr = paddle.cast(target_control_partial_repr,
|
||||
forward_kernel.dtype)
|
||||
forward_kernel[:N, :N] = target_control_partial_repr
|
||||
forward_kernel[:N, -3] = 1
|
||||
forward_kernel[-3, :N] = 1
|
||||
target_control_points = paddle.cast(target_control_points,
|
||||
forward_kernel.dtype)
|
||||
forward_kernel[:N, -2:] = target_control_points
|
||||
forward_kernel[-2:, :N] = paddle.transpose(
|
||||
target_control_points, perm=[1, 0])
|
||||
# compute inverse matrix
|
||||
inverse_kernel = paddle.inverse(forward_kernel)
|
||||
|
||||
# create target cordinate matrix
|
||||
HW = self.target_height * self.target_width
|
||||
target_coordinate = list(
|
||||
itertools.product(
|
||||
range(self.target_height), range(self.target_width)))
|
||||
target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2
|
||||
Y, X = paddle.split(
|
||||
target_coordinate, target_coordinate.shape[1], axis=1)
|
||||
Y = Y / (self.target_height - 1)
|
||||
X = X / (self.target_width - 1)
|
||||
target_coordinate = paddle.concat(
|
||||
[X, Y], axis=1) # convert from (y, x) to (x, y)
|
||||
target_coordinate_partial_repr = compute_partial_repr(
|
||||
target_coordinate, target_control_points)
|
||||
target_coordinate_repr = paddle.concat(
|
||||
[
|
||||
target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
|
||||
target_coordinate
|
||||
],
|
||||
axis=1)
|
||||
|
||||
# register precomputed matrices
|
||||
self.inverse_kernel = inverse_kernel
|
||||
self.padding_matrix = paddle.zeros(shape=[3, 2])
|
||||
self.target_coordinate_repr = target_coordinate_repr
|
||||
self.target_control_points = target_control_points
|
||||
|
||||
def forward(self, input, source_control_points):
|
||||
assert source_control_points.ndimension() == 3
|
||||
assert source_control_points.shape[1] == self.num_control_points
|
||||
assert source_control_points.shape[2] == 2
|
||||
batch_size = paddle.shape(source_control_points)[0]
|
||||
|
||||
self.padding_matrix = paddle.expand(
|
||||
self.padding_matrix, shape=[batch_size, 3, 2])
|
||||
Y = paddle.concat([source_control_points, self.padding_matrix], 1)
|
||||
mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
|
||||
source_coordinate = paddle.matmul(self.target_coordinate_repr,
|
||||
mapping_matrix)
|
||||
|
||||
grid = paddle.reshape(
|
||||
source_coordinate,
|
||||
shape=[-1, self.target_height, self.target_width, 2])
|
||||
grid = paddle.clip(grid, 0,
|
||||
1) # the source_control_points may be out of [0, 1].
|
||||
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
|
||||
grid = 2.0 * grid - 1.0
|
||||
output_maps = grid_sample(input, grid, canvas=None)
|
||||
return output_maps, source_coordinate
|
|
@ -127,3 +127,34 @@ class RMSProp(object):
|
|||
grad_clip=self.grad_clip,
|
||||
parameters=parameters)
|
||||
return opt
|
||||
|
||||
|
||||
class Adadelta(object):
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
epsilon=1e-08,
|
||||
rho=0.95,
|
||||
parameter_list=None,
|
||||
weight_decay=None,
|
||||
grad_clip=None,
|
||||
name=None,
|
||||
**kwargs):
|
||||
self.learning_rate = learning_rate
|
||||
self.epsilon = epsilon
|
||||
self.rho = rho
|
||||
self.parameter_list = parameter_list
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay = weight_decay
|
||||
self.grad_clip = grad_clip
|
||||
self.name = name
|
||||
|
||||
def __call__(self, parameters):
|
||||
opt = optim.Adadelta(
|
||||
learning_rate=self.learning_rate,
|
||||
epsilon=self.epsilon,
|
||||
rho=self.rho,
|
||||
weight_decay=self.weight_decay,
|
||||
grad_clip=self.grad_clip,
|
||||
name=self.name,
|
||||
parameters=parameters)
|
||||
return opt
|
||||
|
|
|
@ -18,17 +18,21 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import copy
|
||||
import platform
|
||||
|
||||
__all__ = ['build_post_process']
|
||||
|
||||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
|
||||
TableLabelDecode, SARLabelDecode
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode, NRTRLabelDecode, SARLabelDecode , SEEDLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .pse_postprocess import PSEPostProcess
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pse is not support in Windows
|
||||
from .pse_postprocess import PSEPostProcess
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
|
@ -36,7 +40,8 @@ def build_post_process(config, global_config=None):
|
|||
'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess',
|
||||
'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
|
||||
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode'
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -111,6 +111,8 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[-1]
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
|
@ -308,6 +310,87 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
return idx
|
||||
|
||||
|
||||
class SEEDLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SEEDLabelDecode, self).__init__(character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
end_idx = self.get_beg_end_flag_idx("eos")
|
||||
return [end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
if beg_or_end == "sos":
|
||||
idx = np.array(self.dict[self.beg_str])
|
||||
elif beg_or_end == "eos":
|
||||
idx = np.array(self.dict[self.end_str])
|
||||
else:
|
||||
assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
|
||||
return idx
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
[end_idx] = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||
break
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text, np.mean(conf_list)))
|
||||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
"""
|
||||
text = self.decode(text)
|
||||
if label is None:
|
||||
return text
|
||||
else:
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
"""
|
||||
preds_idx = preds["rec_pred"]
|
||||
if isinstance(preds_idx, paddle.Tensor):
|
||||
preds_idx = preds_idx.numpy()
|
||||
if "rec_pred_scores" in preds:
|
||||
preds_idx = preds["rec_pred"]
|
||||
preds_prob = preds["rec_pred_scores"]
|
||||
else:
|
||||
preds_idx = preds["rec_pred"].argmax(axis=2)
|
||||
preds_prob = preds["rec_pred"].max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
|
||||
class SRNLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
[
|
||||
\
|
||||
]
|
||||
_
|
||||
`
|
||||
~
|
|
@ -0,0 +1,110 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import paddle
|
||||
|
||||
# A global variable to record the number of calling times for profiler
|
||||
# functions. It is used to specify the tracing range of training steps.
|
||||
_profiler_step_id = 0
|
||||
|
||||
# A global variable to avoid parsing from string every time.
|
||||
_profiler_options = None
|
||||
|
||||
|
||||
class ProfilerOptions(object):
|
||||
'''
|
||||
Use a string to initialize a ProfilerOptions.
|
||||
The string should be in the format: "key1=value1;key2=value;key3=value3".
|
||||
For example:
|
||||
"profile_path=model.profile"
|
||||
"batch_range=[50, 60]; profile_path=model.profile"
|
||||
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
|
||||
ProfilerOptions supports following key-value pair:
|
||||
batch_range - a integer list, e.g. [100, 110].
|
||||
state - a string, the optional values are 'CPU', 'GPU' or 'All'.
|
||||
sorted_key - a string, the optional values are 'calls', 'total',
|
||||
'max', 'min' or 'ave.
|
||||
tracer_option - a string, the optional values are 'Default', 'OpDetail',
|
||||
'AllOpDetail'.
|
||||
profile_path - a string, the path to save the serialized profile data,
|
||||
which can be used to generate a timeline.
|
||||
exit_on_finished - a boolean.
|
||||
'''
|
||||
|
||||
def __init__(self, options_str):
|
||||
assert isinstance(options_str, str)
|
||||
|
||||
self._options = {
|
||||
'batch_range': [10, 20],
|
||||
'state': 'All',
|
||||
'sorted_key': 'total',
|
||||
'tracer_option': 'Default',
|
||||
'profile_path': '/tmp/profile',
|
||||
'exit_on_finished': True
|
||||
}
|
||||
self._parse_from_string(options_str)
|
||||
|
||||
def _parse_from_string(self, options_str):
|
||||
for kv in options_str.replace(' ', '').split(';'):
|
||||
key, value = kv.split('=')
|
||||
if key == 'batch_range':
|
||||
value_list = value.replace('[', '').replace(']', '').split(',')
|
||||
value_list = list(map(int, value_list))
|
||||
if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
|
||||
1] > value_list[0]:
|
||||
self._options[key] = value_list
|
||||
elif key == 'exit_on_finished':
|
||||
self._options[key] = value.lower() in ("yes", "true", "t", "1")
|
||||
elif key in [
|
||||
'state', 'sorted_key', 'tracer_option', 'profile_path'
|
||||
]:
|
||||
self._options[key] = value
|
||||
|
||||
def __getitem__(self, name):
|
||||
if self._options.get(name, None) is None:
|
||||
raise ValueError(
|
||||
"ProfilerOptions does not have an option named %s." % name)
|
||||
return self._options[name]
|
||||
|
||||
|
||||
def add_profiler_step(options_str=None):
|
||||
'''
|
||||
Enable the operator-level timing using PaddlePaddle's profiler.
|
||||
The profiler uses a independent variable to count the profiler steps.
|
||||
One call of this function is treated as a profiler step.
|
||||
|
||||
Args:
|
||||
profiler_options - a string to initialize the ProfilerOptions.
|
||||
Default is None, and the profiler is disabled.
|
||||
'''
|
||||
if options_str is None:
|
||||
return
|
||||
|
||||
global _profiler_step_id
|
||||
global _profiler_options
|
||||
|
||||
if _profiler_options is None:
|
||||
_profiler_options = ProfilerOptions(options_str)
|
||||
|
||||
if _profiler_step_id == _profiler_options['batch_range'][0]:
|
||||
paddle.utils.profiler.start_profiler(
|
||||
_profiler_options['state'], _profiler_options['tracer_option'])
|
||||
elif _profiler_step_id == _profiler_options['batch_range'][1]:
|
||||
paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
|
||||
_profiler_options['profile_path'])
|
||||
if _profiler_options['exit_on_finished']:
|
||||
sys.exit(0)
|
||||
|
||||
_profiler_step_id += 1
|
|
@ -4,9 +4,9 @@
|
|||
|
||||
[1.1 Requirements](#Requirements)
|
||||
|
||||
[1.2 Install PaddleDetection](#Install PaddleDetection)
|
||||
[1.2 Install PaddleDetection](#Install_PaddleDetection)
|
||||
|
||||
[2. Data preparation](#Data preparation)
|
||||
[2. Data preparation](#Data_reparation)
|
||||
|
||||
[3. Configuration](#Configuration)
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
|||
|
||||
[6. Deployment](#Deployment)
|
||||
|
||||
[6.1 Export model](#Export model)
|
||||
[6.1 Export model](#Export_model)
|
||||
|
||||
[6.2 Inference](#Inference)
|
||||
|
||||
|
@ -35,7 +35,7 @@
|
|||
- CUDA >= 10.1
|
||||
- cuDNN >= 7.6
|
||||
|
||||
<a name="Install PaddleDetection"></a>
|
||||
<a name="Install_PaddleDetection"></a>
|
||||
|
||||
### 1.2 Install PaddleDetection
|
||||
|
||||
|
@ -51,7 +51,7 @@ pip install -r requirements.txt
|
|||
|
||||
For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
|
||||
|
||||
<a name="Data preparation"></a>
|
||||
<a name="Data_preparation"></a>
|
||||
|
||||
## 2. Data preparation
|
||||
|
||||
|
@ -165,7 +165,7 @@ python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer
|
|||
|
||||
Use your trained model in Layout Parser
|
||||
|
||||
<a name="Export model"></a>
|
||||
<a name="Export_model"></a>
|
||||
|
||||
### 6.1 Export model
|
||||
|
||||
|
|
|
@ -12,3 +12,4 @@ cython
|
|||
lxml
|
||||
premailer
|
||||
openpyxl
|
||||
fasttext==0.9.1
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
|
||||
function func_parser_key(){
|
||||
strs=$1
|
||||
IFS=":"
|
||||
array=(${strs})
|
||||
tmp=${array[0]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
function func_parser_value(){
|
||||
strs=$1
|
||||
IFS=":"
|
||||
array=(${strs})
|
||||
tmp=${array[1]}
|
||||
echo ${tmp}
|
||||
}
|
||||
|
||||
function func_set_params(){
|
||||
key=$1
|
||||
value=$2
|
||||
if [ ${key}x = "null"x ];then
|
||||
echo " "
|
||||
elif [[ ${value} = "null" ]] || [[ ${value} = " " ]] || [ ${#value} -le 0 ];then
|
||||
echo " "
|
||||
else
|
||||
echo "${key}=${value}"
|
||||
fi
|
||||
}
|
||||
|
||||
function func_parser_params(){
|
||||
strs=$1
|
||||
IFS=":"
|
||||
array=(${strs})
|
||||
key=${array[0]}
|
||||
tmp=${array[1]}
|
||||
IFS="|"
|
||||
res=""
|
||||
for _params in ${tmp[*]}; do
|
||||
IFS="="
|
||||
array=(${_params})
|
||||
mode=${array[0]}
|
||||
value=${array[1]}
|
||||
if [[ ${mode} = ${MODE} ]]; then
|
||||
IFS="|"
|
||||
#echo $(func_set_params "${mode}" "${value}")
|
||||
echo $value
|
||||
break
|
||||
fi
|
||||
IFS="|"
|
||||
done
|
||||
echo ${res}
|
||||
}
|
||||
|
||||
function status_check(){
|
||||
last_status=$1 # the exit code
|
||||
run_command=$2
|
||||
run_log=$3
|
||||
if [ $last_status -eq 0 ]; then
|
||||
echo -e "\033[33m Run successfully with command - ${run_command}! \033[0m" | tee -a ${run_log}
|
||||
else
|
||||
echo -e "\033[33m Run failed with command - ${run_command}! \033[0m" | tee -a ${run_log}
|
||||
fi
|
||||
}
|
||||
|
|
@ -40,13 +40,13 @@ infer_quant:False
|
|||
inference:tools/infer/predict_det.py
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:True|False
|
||||
--cpu_threads:6
|
||||
--cpu_threads:1|6
|
||||
--rec_batch_num:1
|
||||
--use_tensorrt:False|True
|
||||
--precision:fp32|fp16|int8
|
||||
--det_model_dir:
|
||||
--image_dir:./inference/ch_det_data_50/all-sum-510/
|
||||
--save_log_path:null
|
||||
null:null
|
||||
--benchmark:True
|
||||
null:null
|
||||
===========================cpp_infer_params===========================
|
||||
|
@ -79,4 +79,20 @@ op.det.local_service_conf.thread_num:1|6
|
|||
op.det.local_service_conf.use_trt:False|True
|
||||
op.det.local_service_conf.precision:fp32|fp16|int8
|
||||
pipline:pipeline_http_client.py --image_dir=../../doc/imgs
|
||||
|
||||
===========================kl_quant_params===========================
|
||||
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
|
||||
infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
|
||||
infer_quant:False
|
||||
inference:tools/infer/predict_det.py
|
||||
--use_gpu:True|False
|
||||
--enable_mkldnn:True|False
|
||||
--cpu_threads:1|6
|
||||
--rec_batch_num:1
|
||||
--use_tensorrt:False|True
|
||||
--precision:fp32|fp16|int8
|
||||
--det_model_dir:
|
||||
--image_dir:./inference/ch_det_data_50/all-sum-510/
|
||||
null:null
|
||||
--benchmark:True
|
||||
null:null
|
||||
null:null
|
|
@ -12,10 +12,10 @@ train_model_name:latest
|
|||
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train|pact_train
|
||||
norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o Global.pretrained_model=""
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
trainer:norm_train|pact_train|fpgm_export
|
||||
norm_train:tools/train.py -c tests/configs/det_r50_vd_db.yml -o
|
||||
quant_export:deploy/slim/quantization/export_model.py -c tests/configs/det_r50_vd_db.yml -o
|
||||
fpgm_export:deploy/slim/prune/export_prune_model.py -c tests/configs/det_r50_vd_db.yml -o
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
|
@ -34,8 +34,8 @@ distill_export:null
|
|||
export1:null
|
||||
export2:null
|
||||
##
|
||||
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
|
||||
infer_export:null
|
||||
train_model:./inference/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml -o
|
||||
infer_quant:False
|
||||
inference:tools/infer/predict_det.py
|
||||
--use_gpu:True|False
|
|
@ -1,51 +0,0 @@
|
|||
===========================train_params===========================
|
||||
model_name:ocr_system
|
||||
python:python3.7
|
||||
gpu_list:null
|
||||
Global.use_gpu:null
|
||||
Global.auto_cast:null
|
||||
Global.epoch_num:null
|
||||
Global.save_model_dir:./output/
|
||||
Train.loader.batch_size_per_card:null
|
||||
Global.pretrained_model:null
|
||||
train_model_name:null
|
||||
train_infer_img_dir:null
|
||||
null:null
|
||||
##
|
||||
trainer:
|
||||
norm_train:null
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
eval:null
|
||||
null:null
|
||||
##
|
||||
===========================infer_params===========================
|
||||
Global.save_inference_dir:./output/
|
||||
Global.pretrained_model:
|
||||
norm_export:null
|
||||
quant_export:null
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
export1:null
|
||||
export2:null
|
||||
##
|
||||
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
|
||||
kl_quant:deploy/slim/quantization/quant_kl.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
|
||||
infer_quant:True
|
||||
inference:tools/infer/predict_det.py
|
||||
--use_gpu:TrueFalse
|
||||
--enable_mkldnn:True|False
|
||||
--cpu_threads:1|6
|
||||
--rec_batch_num:1
|
||||
--use_tensorrt:False|True
|
||||
--precision:fp32|fp16|int8
|
||||
--det_model_dir:
|
||||
--image_dir:./inference/ch_det_data_50/all-sum-510/
|
||||
--save_log_path:null
|
||||
--benchmark:True
|
||||
null:null
|
|
@ -1,7 +1,9 @@
|
|||
#!/bin/bash
|
||||
FILENAME=$1
|
||||
|
||||
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer', 'serving_infer']
|
||||
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer',
|
||||
# 'cpp_infer', 'serving_infer', 'klquant_infer']
|
||||
|
||||
MODE=$2
|
||||
|
||||
dataline=$(cat ${FILENAME})
|
||||
|
@ -72,9 +74,9 @@ elif [ ${MODE} = "infer" ];then
|
|||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
|
||||
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
|
||||
elif [ ${model_name} = "ocr_server_det" ]; then
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
|
||||
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
|
||||
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
|
||||
elif [ ${model_name} = "ocr_system_mobile" ]; then
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
|
||||
|
@ -98,6 +100,12 @@ elif [ ${MODE} = "infer" ];then
|
|||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
|
||||
cd ./inference && tar xf ${eval_model_name}.tar && tar xf rec_inference.tar && cd ../
|
||||
fi
|
||||
elif [ ${MODE} = "klquant_infer" ];then
|
||||
if [ ${model_name} = "ocr_det" ]; then
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
|
||||
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
|
||||
fi
|
||||
elif [ ${MODE} = "cpp_infer" ];then
|
||||
if [ ${model_name} = "ocr_det" ]; then
|
||||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
|
||||
|
@ -128,77 +136,3 @@ if [ ${MODE} = "serving_infer" ];then
|
|||
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
|
||||
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar cd ../
|
||||
fi
|
||||
|
||||
if [ ${MODE} = "cpp_infer" ];then
|
||||
cd deploy/cpp_infer
|
||||
use_opencv=$(func_parser_value "${lines[52]}")
|
||||
if [ ${use_opencv} = "True" ]; then
|
||||
if [ -d "opencv-3.4.7/opencv3/" ] && [ $(md5sum opencv-3.4.7.tar.gz | awk -F ' ' '{print $1}') = "faa2b5950f8bee3f03118e600c74746a" ];then
|
||||
echo "################### build opencv skipped ###################"
|
||||
else
|
||||
echo "################### build opencv ###################"
|
||||
rm -rf opencv-3.4.7.tar.gz opencv-3.4.7/
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/opencv-3.4.7.tar.gz
|
||||
tar -xf opencv-3.4.7.tar.gz
|
||||
|
||||
cd opencv-3.4.7/
|
||||
install_path=$(pwd)/opencv3
|
||||
|
||||
rm -rf build
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_INSTALL_PREFIX=${install_path} \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DBUILD_SHARED_LIBS=OFF \
|
||||
-DWITH_IPP=OFF \
|
||||
-DBUILD_IPP_IW=OFF \
|
||||
-DWITH_LAPACK=OFF \
|
||||
-DWITH_EIGEN=OFF \
|
||||
-DCMAKE_INSTALL_LIBDIR=lib64 \
|
||||
-DWITH_ZLIB=ON \
|
||||
-DBUILD_ZLIB=ON \
|
||||
-DWITH_JPEG=ON \
|
||||
-DBUILD_JPEG=ON \
|
||||
-DWITH_PNG=ON \
|
||||
-DBUILD_PNG=ON \
|
||||
-DWITH_TIFF=ON \
|
||||
-DBUILD_TIFF=ON
|
||||
|
||||
make -j
|
||||
make install
|
||||
cd ../
|
||||
echo "################### build opencv finished ###################"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
echo "################### build PaddleOCR demo ####################"
|
||||
if [ ${use_opencv} = "True" ]; then
|
||||
OPENCV_DIR=$(pwd)/opencv-3.4.7/opencv3/
|
||||
else
|
||||
OPENCV_DIR=''
|
||||
fi
|
||||
LIB_DIR=$(pwd)/Paddle/build/paddle_inference_install_dir/
|
||||
CUDA_LIB_DIR=$(dirname `find /usr -name libcudart.so`)
|
||||
CUDNN_LIB_DIR=$(dirname `find /usr -name libcudnn.so`)
|
||||
|
||||
BUILD_DIR=build
|
||||
rm -rf ${BUILD_DIR}
|
||||
mkdir ${BUILD_DIR}
|
||||
cd ${BUILD_DIR}
|
||||
cmake .. \
|
||||
-DPADDLE_LIB=${LIB_DIR} \
|
||||
-DWITH_MKL=ON \
|
||||
-DWITH_GPU=OFF \
|
||||
-DWITH_STATIC_LIB=OFF \
|
||||
-DWITH_TENSORRT=OFF \
|
||||
-DOPENCV_DIR=${OPENCV_DIR} \
|
||||
-DCUDNN_LIB=${CUDNN_LIB_DIR} \
|
||||
-DCUDA_LIB=${CUDA_LIB_DIR} \
|
||||
-DTENSORRT_DIR=${TENSORRT_DIR} \
|
||||
|
||||
make -j
|
||||
echo "################### build PaddleOCR demo finished ###################"
|
||||
fi
|
||||
|
|
144
tests/test.sh
|
@ -1,9 +1,16 @@
|
|||
#!/bin/bash
|
||||
FILENAME=$1
|
||||
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer']
|
||||
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'cpp_infer', 'serving_infer', 'klquant_infer']
|
||||
MODE=$2
|
||||
|
||||
dataline=$(cat ${FILENAME})
|
||||
if [ ${MODE} = "cpp_infer" ]; then
|
||||
dataline=$(awk 'NR==67, NR==81{print}' $FILENAME)
|
||||
elif [ ${MODE} = "serving_infer" ]; then
|
||||
dataline=$(awk 'NR==52, NR==66{print}' $FILENAME)
|
||||
elif [ ${MODE} = "klquant_infer" ]; then
|
||||
dataline=$(awk 'NR==82, NR==98{print}' $FILENAME)
|
||||
else
|
||||
dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
|
||||
fi
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
|
@ -144,61 +151,93 @@ benchmark_key=$(func_parser_key "${lines[49]}")
|
|||
benchmark_value=$(func_parser_value "${lines[49]}")
|
||||
infer_key1=$(func_parser_key "${lines[50]}")
|
||||
infer_value1=$(func_parser_value "${lines[50]}")
|
||||
# parser serving
|
||||
trans_model_py=$(func_parser_value "${lines[67]}")
|
||||
infer_model_dir_key=$(func_parser_key "${lines[68]}")
|
||||
infer_model_dir_value=$(func_parser_value "${lines[68]}")
|
||||
model_filename_key=$(func_parser_key "${lines[69]}")
|
||||
model_filename_value=$(func_parser_value "${lines[69]}")
|
||||
params_filename_key=$(func_parser_key "${lines[70]}")
|
||||
params_filename_value=$(func_parser_value "${lines[70]}")
|
||||
serving_server_key=$(func_parser_key "${lines[71]}")
|
||||
serving_server_value=$(func_parser_value "${lines[71]}")
|
||||
serving_client_key=$(func_parser_key "${lines[72]}")
|
||||
serving_client_value=$(func_parser_value "${lines[72]}")
|
||||
serving_dir_value=$(func_parser_value "${lines[73]}")
|
||||
web_service_py=$(func_parser_value "${lines[74]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[75]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[75]}")
|
||||
web_use_mkldnn_key=$(func_parser_key "${lines[76]}")
|
||||
web_use_mkldnn_list=$(func_parser_value "${lines[76]}")
|
||||
web_cpu_threads_key=$(func_parser_key "${lines[77]}")
|
||||
web_cpu_threads_list=$(func_parser_value "${lines[77]}")
|
||||
web_use_trt_key=$(func_parser_key "${lines[78]}")
|
||||
web_use_trt_list=$(func_parser_value "${lines[78]}")
|
||||
web_precision_key=$(func_parser_key "${lines[79]}")
|
||||
web_precision_list=$(func_parser_value "${lines[79]}")
|
||||
pipeline_py=$(func_parser_value "${lines[80]}")
|
||||
|
||||
# parser serving
|
||||
if [ ${MODE} = "klquant_infer" ]; then
|
||||
# parser inference model
|
||||
infer_model_dir_list=$(func_parser_value "${lines[1]}")
|
||||
infer_export_list=$(func_parser_value "${lines[2]}")
|
||||
infer_is_quant=$(func_parser_value "${lines[3]}")
|
||||
# parser inference
|
||||
inference_py=$(func_parser_value "${lines[4]}")
|
||||
use_gpu_key=$(func_parser_key "${lines[5]}")
|
||||
use_gpu_list=$(func_parser_value "${lines[5]}")
|
||||
use_mkldnn_key=$(func_parser_key "${lines[6]}")
|
||||
use_mkldnn_list=$(func_parser_value "${lines[6]}")
|
||||
cpu_threads_key=$(func_parser_key "${lines[7]}")
|
||||
cpu_threads_list=$(func_parser_value "${lines[7]}")
|
||||
batch_size_key=$(func_parser_key "${lines[8]}")
|
||||
batch_size_list=$(func_parser_value "${lines[8]}")
|
||||
use_trt_key=$(func_parser_key "${lines[9]}")
|
||||
use_trt_list=$(func_parser_value "${lines[9]}")
|
||||
precision_key=$(func_parser_key "${lines[10]}")
|
||||
precision_list=$(func_parser_value "${lines[10]}")
|
||||
infer_model_key=$(func_parser_key "${lines[11]}")
|
||||
image_dir_key=$(func_parser_key "${lines[12]}")
|
||||
infer_img_dir=$(func_parser_value "${lines[12]}")
|
||||
save_log_key=$(func_parser_key "${lines[13]}")
|
||||
benchmark_key=$(func_parser_key "${lines[14]}")
|
||||
benchmark_value=$(func_parser_value "${lines[14]}")
|
||||
infer_key1=$(func_parser_key "${lines[15]}")
|
||||
infer_value1=$(func_parser_value "${lines[15]}")
|
||||
fi
|
||||
# parser serving
|
||||
if [ ${MODE} = "server_infer" ]; then
|
||||
trans_model_py=$(func_parser_value "${lines[1]}")
|
||||
infer_model_dir_key=$(func_parser_key "${lines[2]}")
|
||||
infer_model_dir_value=$(func_parser_value "${lines[2]}")
|
||||
model_filename_key=$(func_parser_key "${lines[3]}")
|
||||
model_filename_value=$(func_parser_value "${lines[3]}")
|
||||
params_filename_key=$(func_parser_key "${lines[4]}")
|
||||
params_filename_value=$(func_parser_value "${lines[4]}")
|
||||
serving_server_key=$(func_parser_key "${lines[5]}")
|
||||
serving_server_value=$(func_parser_value "${lines[5]}")
|
||||
serving_client_key=$(func_parser_key "${lines[6]}")
|
||||
serving_client_value=$(func_parser_value "${lines[6]}")
|
||||
serving_dir_value=$(func_parser_value "${lines[7]}")
|
||||
web_service_py=$(func_parser_value "${lines[8]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[9]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[9]}")
|
||||
web_use_mkldnn_key=$(func_parser_key "${lines[10]}")
|
||||
web_use_mkldnn_list=$(func_parser_value "${lines[10]}")
|
||||
web_cpu_threads_key=$(func_parser_key "${lines[11]}")
|
||||
web_cpu_threads_list=$(func_parser_value "${lines[11]}")
|
||||
web_use_trt_key=$(func_parser_key "${lines[12]}")
|
||||
web_use_trt_list=$(func_parser_value "${lines[12]}")
|
||||
web_precision_key=$(func_parser_key "${lines[13]}")
|
||||
web_precision_list=$(func_parser_value "${lines[13]}")
|
||||
pipeline_py=$(func_parser_value "${lines[14]}")
|
||||
fi
|
||||
|
||||
if [ ${MODE} = "cpp_infer" ]; then
|
||||
# parser cpp inference model
|
||||
cpp_infer_model_dir_list=$(func_parser_value "${lines[53]}")
|
||||
cpp_infer_is_quant=$(func_parser_value "${lines[54]}")
|
||||
cpp_infer_model_dir_list=$(func_parser_value "${lines[1]}")
|
||||
cpp_infer_is_quant=$(func_parser_value "${lines[2]}")
|
||||
# parser cpp inference
|
||||
inference_cmd=$(func_parser_value "${lines[55]}")
|
||||
cpp_use_gpu_key=$(func_parser_key "${lines[56]}")
|
||||
cpp_use_gpu_list=$(func_parser_value "${lines[56]}")
|
||||
cpp_use_mkldnn_key=$(func_parser_key "${lines[57]}")
|
||||
cpp_use_mkldnn_list=$(func_parser_value "${lines[57]}")
|
||||
cpp_cpu_threads_key=$(func_parser_key "${lines[58]}")
|
||||
cpp_cpu_threads_list=$(func_parser_value "${lines[58]}")
|
||||
cpp_batch_size_key=$(func_parser_key "${lines[59]}")
|
||||
cpp_batch_size_list=$(func_parser_value "${lines[59]}")
|
||||
cpp_use_trt_key=$(func_parser_key "${lines[60]}")
|
||||
cpp_use_trt_list=$(func_parser_value "${lines[60]}")
|
||||
cpp_precision_key=$(func_parser_key "${lines[61]}")
|
||||
cpp_precision_list=$(func_parser_value "${lines[61]}")
|
||||
cpp_infer_model_key=$(func_parser_key "${lines[62]}")
|
||||
cpp_image_dir_key=$(func_parser_key "${lines[63]}")
|
||||
cpp_infer_img_dir=$(func_parser_value "${lines[63]}")
|
||||
cpp_infer_key1=$(func_parser_key "${lines[64]}")
|
||||
cpp_infer_value1=$(func_parser_value "${lines[64]}")
|
||||
cpp_benchmark_key=$(func_parser_key "${lines[65]}")
|
||||
cpp_benchmark_value=$(func_parser_value "${lines[65]}")
|
||||
inference_cmd=$(func_parser_value "${lines[3]}")
|
||||
cpp_use_gpu_key=$(func_parser_key "${lines[4]}")
|
||||
cpp_use_gpu_list=$(func_parser_value "${lines[4]}")
|
||||
cpp_use_mkldnn_key=$(func_parser_key "${lines[5]}")
|
||||
cpp_use_mkldnn_list=$(func_parser_value "${lines[5]}")
|
||||
cpp_cpu_threads_key=$(func_parser_key "${lines[6]}")
|
||||
cpp_cpu_threads_list=$(func_parser_value "${lines[6]}")
|
||||
cpp_batch_size_key=$(func_parser_key "${lines[7]}")
|
||||
cpp_batch_size_list=$(func_parser_value "${lines[7]}")
|
||||
cpp_use_trt_key=$(func_parser_key "${lines[8]}")
|
||||
cpp_use_trt_list=$(func_parser_value "${lines[8]}")
|
||||
cpp_precision_key=$(func_parser_key "${lines[9]}")
|
||||
cpp_precision_list=$(func_parser_value "${lines[9]}")
|
||||
cpp_infer_model_key=$(func_parser_key "${lines[10]}")
|
||||
cpp_image_dir_key=$(func_parser_key "${lines[11]}")
|
||||
cpp_infer_img_dir=$(func_parser_value "${lines[12]}")
|
||||
cpp_infer_key1=$(func_parser_key "${lines[13]}")
|
||||
cpp_infer_value1=$(func_parser_value "${lines[13]}")
|
||||
cpp_benchmark_key=$(func_parser_key "${lines[14]}")
|
||||
cpp_benchmark_value=$(func_parser_value "${lines[14]}")
|
||||
fi
|
||||
|
||||
|
||||
|
||||
LOG_PATH="./tests/output"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results.log"
|
||||
|
@ -414,7 +453,7 @@ function func_cpp_inference(){
|
|||
done
|
||||
}
|
||||
|
||||
if [ ${MODE} = "infer" ]; then
|
||||
if [ ${MODE} = "infer" ] || [ ${MODE} = "klquant_infer" ]; then
|
||||
GPUID=$3
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
|
@ -447,7 +486,6 @@ if [ ${MODE} = "infer" ]; then
|
|||
func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
|
||||
Count=$(($Count + 1))
|
||||
done
|
||||
|
||||
elif [ ${MODE} = "cpp_infer" ]; then
|
||||
GPUID=$3
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
|
@ -481,6 +519,8 @@ elif [ ${MODE} = "serving_infer" ]; then
|
|||
#run serving
|
||||
func_serving "${web_service_cmd}"
|
||||
|
||||
|
||||
|
||||
else
|
||||
IFS="|"
|
||||
export Count=0
|
||||
|
|
|
@ -0,0 +1,204 @@
|
|||
#!/bin/bash
|
||||
source tests/common_func.sh
|
||||
|
||||
FILENAME=$1
|
||||
dataline=$(awk 'NR==52, NR==66{print}' $FILENAME)
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
lines=(${dataline})
|
||||
|
||||
# parser cpp inference model
|
||||
use_opencv=$(func_parser_value "${lines[1]}")
|
||||
cpp_infer_model_dir_list=$(func_parser_value "${lines[2]}")
|
||||
cpp_infer_is_quant=$(func_parser_value "${lines[3]}")
|
||||
# parser cpp inference
|
||||
inference_cmd=$(func_parser_value "${lines[4]}")
|
||||
cpp_use_gpu_key=$(func_parser_key "${lines[5]}")
|
||||
cpp_use_gpu_list=$(func_parser_value "${lines[5]}")
|
||||
cpp_use_mkldnn_key=$(func_parser_key "${lines[6]}")
|
||||
cpp_use_mkldnn_list=$(func_parser_value "${lines[6]}")
|
||||
cpp_cpu_threads_key=$(func_parser_key "${lines[7]}")
|
||||
cpp_cpu_threads_list=$(func_parser_value "${lines[7]}")
|
||||
cpp_batch_size_key=$(func_parser_key "${lines[8]}")
|
||||
cpp_batch_size_list=$(func_parser_value "${lines[8]}")
|
||||
cpp_use_trt_key=$(func_parser_key "${lines[9]}")
|
||||
cpp_use_trt_list=$(func_parser_value "${lines[9]}")
|
||||
cpp_precision_key=$(func_parser_key "${lines[10]}")
|
||||
cpp_precision_list=$(func_parser_value "${lines[10]}")
|
||||
cpp_infer_model_key=$(func_parser_key "${lines[11]}")
|
||||
cpp_image_dir_key=$(func_parser_key "${lines[12]}")
|
||||
cpp_infer_img_dir=$(func_parser_value "${lines[12]}")
|
||||
cpp_infer_key1=$(func_parser_key "${lines[13]}")
|
||||
cpp_infer_value1=$(func_parser_value "${lines[13]}")
|
||||
cpp_benchmark_key=$(func_parser_key "${lines[14]}")
|
||||
cpp_benchmark_value=$(func_parser_value "${lines[14]}")
|
||||
|
||||
|
||||
LOG_PATH="./tests/output"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results_cpp.log"
|
||||
|
||||
|
||||
function func_cpp_inference(){
|
||||
IFS='|'
|
||||
_script=$1
|
||||
_model_dir=$2
|
||||
_log_path=$3
|
||||
_img_dir=$4
|
||||
_flag_quant=$5
|
||||
# inference
|
||||
for use_gpu in ${cpp_use_gpu_list[*]}; do
|
||||
if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
|
||||
for use_mkldnn in ${cpp_use_mkldnn_list[*]}; do
|
||||
if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
|
||||
continue
|
||||
fi
|
||||
for threads in ${cpp_cpu_threads_list[*]}; do
|
||||
for batch_size in ${cpp_batch_size_list[*]}; do
|
||||
_save_log_path="${_log_path}/cpp_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
|
||||
set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
|
||||
set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
|
||||
set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
|
||||
set_cpu_threads=$(func_set_params "${cpp_cpu_threads_key}" "${threads}")
|
||||
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
|
||||
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
|
||||
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${cpp_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
|
||||
eval $command
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${command}" "${status_log}"
|
||||
done
|
||||
done
|
||||
done
|
||||
elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
|
||||
for use_trt in ${cpp_use_trt_list[*]}; do
|
||||
for precision in ${cpp_precision_list[*]}; do
|
||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
|
||||
continue
|
||||
fi
|
||||
for batch_size in ${cpp_batch_size_list[*]}; do
|
||||
_save_log_path="${_log_path}/cpp_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
|
||||
set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
|
||||
set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
|
||||
set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
|
||||
set_tensorrt=$(func_set_params "${cpp_use_trt_key}" "${use_trt}")
|
||||
set_precision=$(func_set_params "${cpp_precision_key}" "${precision}")
|
||||
set_model_dir=$(func_set_params "${cpp_infer_model_key}" "${_model_dir}")
|
||||
set_infer_params1=$(func_set_params "${cpp_infer_key1}" "${cpp_infer_value1}")
|
||||
command="${_script} ${cpp_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
|
||||
eval $command
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${command}" "${status_log}"
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
else
|
||||
echo "Does not support hardware other than CPU and GPU Currently!"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
cd deploy/cpp_infer
|
||||
if [ ${use_opencv} = "True" ]; then
|
||||
if [ -d "opencv-3.4.7/opencv3/" ] && [ $(md5sum opencv-3.4.7.tar.gz | awk -F ' ' '{print $1}') = "faa2b5950f8bee3f03118e600c74746a" ];then
|
||||
echo "################### build opencv skipped ###################"
|
||||
else
|
||||
echo "################### build opencv ###################"
|
||||
rm -rf opencv-3.4.7.tar.gz opencv-3.4.7/
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/opencv-3.4.7.tar.gz
|
||||
tar -xf opencv-3.4.7.tar.gz
|
||||
|
||||
cd opencv-3.4.7/
|
||||
install_path=$(pwd)/opencv3
|
||||
|
||||
rm -rf build
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_INSTALL_PREFIX=${install_path} \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DBUILD_SHARED_LIBS=OFF \
|
||||
-DWITH_IPP=OFF \
|
||||
-DBUILD_IPP_IW=OFF \
|
||||
-DWITH_LAPACK=OFF \
|
||||
-DWITH_EIGEN=OFF \
|
||||
-DCMAKE_INSTALL_LIBDIR=lib64 \
|
||||
-DWITH_ZLIB=ON \
|
||||
-DBUILD_ZLIB=ON \
|
||||
-DWITH_JPEG=ON \
|
||||
-DBUILD_JPEG=ON \
|
||||
-DWITH_PNG=ON \
|
||||
-DBUILD_PNG=ON \
|
||||
-DWITH_TIFF=ON \
|
||||
-DBUILD_TIFF=ON
|
||||
|
||||
make -j
|
||||
make install
|
||||
cd ../
|
||||
echo "################### build opencv finished ###################"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
echo "################### build PaddleOCR demo ####################"
|
||||
if [ ${use_opencv} = "True" ]; then
|
||||
OPENCV_DIR=$(pwd)/opencv-3.4.7/opencv3/
|
||||
else
|
||||
OPENCV_DIR=''
|
||||
fi
|
||||
LIB_DIR=$(pwd)/Paddle/build/paddle_inference_install_dir/
|
||||
CUDA_LIB_DIR=$(dirname `find /usr -name libcudart.so`)
|
||||
CUDNN_LIB_DIR=$(dirname `find /usr -name libcudnn.so`)
|
||||
|
||||
BUILD_DIR=build
|
||||
rm -rf ${BUILD_DIR}
|
||||
mkdir ${BUILD_DIR}
|
||||
cd ${BUILD_DIR}
|
||||
cmake .. \
|
||||
-DPADDLE_LIB=${LIB_DIR} \
|
||||
-DWITH_MKL=ON \
|
||||
-DWITH_GPU=OFF \
|
||||
-DWITH_STATIC_LIB=OFF \
|
||||
-DWITH_TENSORRT=OFF \
|
||||
-DOPENCV_DIR=${OPENCV_DIR} \
|
||||
-DCUDNN_LIB=${CUDNN_LIB_DIR} \
|
||||
-DCUDA_LIB=${CUDA_LIB_DIR} \
|
||||
-DTENSORRT_DIR=${TENSORRT_DIR} \
|
||||
|
||||
make -j
|
||||
cd ../../../
|
||||
echo "################### build PaddleOCR demo finished ###################"
|
||||
|
||||
|
||||
# set cuda device
|
||||
GPUID=$2
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
else
|
||||
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
|
||||
fi
|
||||
set CUDA_VISIBLE_DEVICES
|
||||
eval $env
|
||||
|
||||
|
||||
echo "################### run test ###################"
|
||||
export Count=0
|
||||
IFS="|"
|
||||
infer_quant_flag=(${cpp_infer_is_quant})
|
||||
for infer_model in ${cpp_infer_model_dir_list[*]}; do
|
||||
#run inference
|
||||
is_quant=${infer_quant_flag[Count]}
|
||||
func_cpp_inference "${inference_cmd}" "${infer_model}" "${LOG_PATH}" "${cpp_infer_img_dir}" ${is_quant}
|
||||
Count=$(($Count + 1))
|
||||
done
|
|
@ -0,0 +1,173 @@
|
|||
#!/bin/bash
|
||||
source tests/common_func.sh
|
||||
|
||||
FILENAME=$1
|
||||
dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
lines=(${dataline})
|
||||
|
||||
# The training params
|
||||
model_name=$(func_parser_value "${lines[1]}")
|
||||
python=$(func_parser_value "${lines[2]}")
|
||||
gpu_list=$(func_parser_value "${lines[3]}")
|
||||
train_use_gpu_key=$(func_parser_key "${lines[4]}")
|
||||
train_use_gpu_value=$(func_parser_value "${lines[4]}")
|
||||
autocast_list=$(func_parser_value "${lines[5]}")
|
||||
autocast_key=$(func_parser_key "${lines[5]}")
|
||||
epoch_key=$(func_parser_key "${lines[6]}")
|
||||
epoch_num=$(func_parser_params "${lines[6]}")
|
||||
save_model_key=$(func_parser_key "${lines[7]}")
|
||||
train_batch_key=$(func_parser_key "${lines[8]}")
|
||||
train_batch_value=$(func_parser_params "${lines[8]}")
|
||||
pretrain_model_key=$(func_parser_key "${lines[9]}")
|
||||
pretrain_model_value=$(func_parser_value "${lines[9]}")
|
||||
train_model_name=$(func_parser_value "${lines[10]}")
|
||||
train_infer_img_dir=$(func_parser_value "${lines[11]}")
|
||||
train_param_key1=$(func_parser_key "${lines[12]}")
|
||||
train_param_value1=$(func_parser_value "${lines[12]}")
|
||||
|
||||
trainer_list=$(func_parser_value "${lines[14]}")
|
||||
trainer_norm=$(func_parser_key "${lines[15]}")
|
||||
norm_trainer=$(func_parser_value "${lines[15]}")
|
||||
pact_key=$(func_parser_key "${lines[16]}")
|
||||
pact_trainer=$(func_parser_value "${lines[16]}")
|
||||
fpgm_key=$(func_parser_key "${lines[17]}")
|
||||
fpgm_trainer=$(func_parser_value "${lines[17]}")
|
||||
distill_key=$(func_parser_key "${lines[18]}")
|
||||
distill_trainer=$(func_parser_value "${lines[18]}")
|
||||
trainer_key1=$(func_parser_key "${lines[19]}")
|
||||
trainer_value1=$(func_parser_value "${lines[19]}")
|
||||
trainer_key2=$(func_parser_key "${lines[20]}")
|
||||
trainer_value2=$(func_parser_value "${lines[20]}")
|
||||
|
||||
eval_py=$(func_parser_value "${lines[23]}")
|
||||
eval_key1=$(func_parser_key "${lines[24]}")
|
||||
eval_value1=$(func_parser_value "${lines[24]}")
|
||||
|
||||
save_infer_key=$(func_parser_key "${lines[27]}")
|
||||
export_weight=$(func_parser_key "${lines[28]}")
|
||||
norm_export=$(func_parser_value "${lines[29]}")
|
||||
pact_export=$(func_parser_value "${lines[30]}")
|
||||
fpgm_export=$(func_parser_value "${lines[31]}")
|
||||
distill_export=$(func_parser_value "${lines[32]}")
|
||||
export_key1=$(func_parser_key "${lines[33]}")
|
||||
export_value1=$(func_parser_value "${lines[33]}")
|
||||
export_key2=$(func_parser_key "${lines[34]}")
|
||||
export_value2=$(func_parser_value "${lines[34]}")
|
||||
|
||||
# parser inference model
|
||||
infer_model_dir_list=$(func_parser_value "${lines[36]}")
|
||||
infer_export_list=$(func_parser_value "${lines[37]}")
|
||||
infer_is_quant=$(func_parser_value "${lines[38]}")
|
||||
# parser inference
|
||||
inference_py=$(func_parser_value "${lines[39]}")
|
||||
use_gpu_key=$(func_parser_key "${lines[40]}")
|
||||
use_gpu_list=$(func_parser_value "${lines[40]}")
|
||||
use_mkldnn_key=$(func_parser_key "${lines[41]}")
|
||||
use_mkldnn_list=$(func_parser_value "${lines[41]}")
|
||||
cpu_threads_key=$(func_parser_key "${lines[42]}")
|
||||
cpu_threads_list=$(func_parser_value "${lines[42]}")
|
||||
batch_size_key=$(func_parser_key "${lines[43]}")
|
||||
batch_size_list=$(func_parser_value "${lines[43]}")
|
||||
use_trt_key=$(func_parser_key "${lines[44]}")
|
||||
use_trt_list=$(func_parser_value "${lines[44]}")
|
||||
precision_key=$(func_parser_key "${lines[45]}")
|
||||
precision_list=$(func_parser_value "${lines[45]}")
|
||||
infer_model_key=$(func_parser_key "${lines[46]}")
|
||||
image_dir_key=$(func_parser_key "${lines[47]}")
|
||||
infer_img_dir=$(func_parser_value "${lines[47]}")
|
||||
save_log_key=$(func_parser_key "${lines[48]}")
|
||||
benchmark_key=$(func_parser_key "${lines[49]}")
|
||||
benchmark_value=$(func_parser_value "${lines[49]}")
|
||||
infer_key1=$(func_parser_key "${lines[50]}")
|
||||
infer_value1=$(func_parser_value "${lines[50]}")
|
||||
|
||||
|
||||
LOG_PATH="./tests/output"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results_python.log"
|
||||
|
||||
|
||||
function func_inference(){
|
||||
IFS='|'
|
||||
_python=$1
|
||||
_script=$2
|
||||
_model_dir=$3
|
||||
_log_path=$4
|
||||
_img_dir=$5
|
||||
_flag_quant=$6
|
||||
# inference
|
||||
for use_gpu in ${use_gpu_list[*]}; do
|
||||
if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
|
||||
for use_mkldnn in ${use_mkldnn_list[*]}; do
|
||||
if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
|
||||
continue
|
||||
fi
|
||||
for threads in ${cpu_threads_list[*]}; do
|
||||
for batch_size in ${batch_size_list[*]}; do
|
||||
_save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
|
||||
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
|
||||
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
|
||||
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
|
||||
set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
|
||||
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
|
||||
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
|
||||
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
|
||||
eval $command
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${command}" "${status_log}"
|
||||
done
|
||||
done
|
||||
done
|
||||
elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
|
||||
for use_trt in ${use_trt_list[*]}; do
|
||||
for precision in ${precision_list[*]}; do
|
||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
|
||||
continue
|
||||
fi
|
||||
for batch_size in ${batch_size_list[*]}; do
|
||||
_save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
|
||||
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
|
||||
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
|
||||
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
|
||||
set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
|
||||
set_precision=$(func_set_params "${precision_key}" "${precision}")
|
||||
set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
|
||||
set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
|
||||
command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
|
||||
eval $command
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${command}" "${status_log}"
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
else
|
||||
echo "Does not support hardware other than CPU and GPU Currently!"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
# set cuda device
|
||||
GPUID=$2
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
else
|
||||
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
|
||||
fi
|
||||
set CUDA_VISIBLE_DEVICES
|
||||
eval $env
|
||||
|
||||
|
||||
echo "################### run test ###################"
|
|
@ -0,0 +1,131 @@
|
|||
#!/bin/bash
|
||||
source tests/common_func.sh
|
||||
|
||||
FILENAME=$1
|
||||
dataline=$(awk 'NR==67, NR==81{print}' $FILENAME)
|
||||
|
||||
# parser params
|
||||
IFS=$'\n'
|
||||
lines=(${dataline})
|
||||
|
||||
# parser serving
|
||||
trans_model_py=$(func_parser_value "${lines[1]}")
|
||||
infer_model_dir_key=$(func_parser_key "${lines[2]}")
|
||||
infer_model_dir_value=$(func_parser_value "${lines[2]}")
|
||||
model_filename_key=$(func_parser_key "${lines[3]}")
|
||||
model_filename_value=$(func_parser_value "${lines[3]}")
|
||||
params_filename_key=$(func_parser_key "${lines[4]}")
|
||||
params_filename_value=$(func_parser_value "${lines[4]}")
|
||||
serving_server_key=$(func_parser_key "${lines[5]}")
|
||||
serving_server_value=$(func_parser_value "${lines[5]}")
|
||||
serving_client_key=$(func_parser_key "${lines[6]}")
|
||||
serving_client_value=$(func_parser_value "${lines[6]}")
|
||||
serving_dir_value=$(func_parser_value "${lines[7]}")
|
||||
web_service_py=$(func_parser_value "${lines[8]}")
|
||||
web_use_gpu_key=$(func_parser_key "${lines[9]}")
|
||||
web_use_gpu_list=$(func_parser_value "${lines[9]}")
|
||||
web_use_mkldnn_key=$(func_parser_key "${lines[10]}")
|
||||
web_use_mkldnn_list=$(func_parser_value "${lines[10]}")
|
||||
web_cpu_threads_key=$(func_parser_key "${lines[11]}")
|
||||
web_cpu_threads_list=$(func_parser_value "${lines[11]}")
|
||||
web_use_trt_key=$(func_parser_key "${lines[12]}")
|
||||
web_use_trt_list=$(func_parser_value "${lines[12]}")
|
||||
web_precision_key=$(func_parser_key "${lines[13]}")
|
||||
web_precision_list=$(func_parser_value "${lines[13]}")
|
||||
pipeline_py=$(func_parser_value "${lines[14]}")
|
||||
|
||||
|
||||
LOG_PATH="./tests/output"
|
||||
mkdir -p ${LOG_PATH}
|
||||
status_log="${LOG_PATH}/results_serving.log"
|
||||
|
||||
|
||||
function func_serving(){
|
||||
IFS='|'
|
||||
_python=$1
|
||||
_script=$2
|
||||
_model_dir=$3
|
||||
# pdserving
|
||||
set_dirname=$(func_set_params "${infer_model_dir_key}" "${infer_model_dir_value}")
|
||||
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
|
||||
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
|
||||
set_serving_server=$(func_set_params "${serving_server_key}" "${serving_server_value}")
|
||||
set_serving_client=$(func_set_params "${serving_client_key}" "${serving_client_value}")
|
||||
trans_model_cmd="${python} ${trans_model_py} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_serving_server} ${set_serving_client}"
|
||||
eval $trans_model_cmd
|
||||
cd ${serving_dir_value}
|
||||
echo $PWD
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
for use_gpu in ${web_use_gpu_list[*]}; do
|
||||
echo ${ues_gpu}
|
||||
if [ ${use_gpu} = "null" ]; then
|
||||
for use_mkldnn in ${web_use_mkldnn_list[*]}; do
|
||||
if [ ${use_mkldnn} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
for threads in ${web_cpu_threads_list[*]}; do
|
||||
_save_log_path="${_log_path}/server_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_1.log"
|
||||
set_cpu_threads=$(func_set_params "${web_cpu_threads_key}" "${threads}")
|
||||
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} &>${_save_log_path} &"
|
||||
eval $web_service_cmd
|
||||
sleep 2s
|
||||
pipeline_cmd="${python} ${pipeline_py}"
|
||||
eval $pipeline_cmd
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
||||
PID=$!
|
||||
kill $PID
|
||||
sleep 2s
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
done
|
||||
done
|
||||
elif [ ${use_gpu} = "0" ]; then
|
||||
for use_trt in ${web_use_trt_list[*]}; do
|
||||
for precision in ${web_precision_list[*]}; do
|
||||
if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
|
||||
continue
|
||||
fi
|
||||
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
|
||||
continue
|
||||
fi
|
||||
_save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_1.log"
|
||||
set_tensorrt=$(func_set_params "${web_use_trt_key}" "${use_trt}")
|
||||
set_precision=$(func_set_params "${web_precision_key}" "${precision}")
|
||||
web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} &>${_save_log_path} & "
|
||||
eval $web_service_cmd
|
||||
sleep 2s
|
||||
pipeline_cmd="${python} ${pipeline_py}"
|
||||
eval $pipeline_cmd
|
||||
last_status=${PIPESTATUS[0]}
|
||||
eval "cat ${_save_log_path}"
|
||||
status_check $last_status "${pipeline_cmd}" "${status_log}"
|
||||
PID=$!
|
||||
kill $PID
|
||||
sleep 2s
|
||||
ps ux | grep -E 'web_service|pipeline' | awk '{print $2}' | xargs kill -s 9
|
||||
done
|
||||
done
|
||||
else
|
||||
echo "Does not support hardware other than CPU and GPU Currently!"
|
||||
fi
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
# set cuda device
|
||||
GPUID=$2
|
||||
if [ ${#GPUID} -le 0 ];then
|
||||
env=" "
|
||||
else
|
||||
env="export CUDA_VISIBLE_DEVICES=${GPUID}"
|
||||
fi
|
||||
set CUDA_VISIBLE_DEVICES
|
||||
eval $env
|
||||
|
||||
|
||||
echo "################### run test ###################"
|
|
@ -54,8 +54,7 @@ def main():
|
|||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_sar = config['Architecture']['algorithm'] == "SAR"
|
||||
extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
|
@ -72,7 +71,7 @@ def main():
|
|||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn, use_sar)
|
||||
eval_class, model_type, extra_input)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model, load_dygraph_params
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
||||
def main():
|
||||
global_config = config['Global']
|
||||
# build dataloader
|
||||
config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
|
||||
config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
|
||||
'data_dir']
|
||||
config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
|
||||
'label_file_list']
|
||||
eval_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
global_config)
|
||||
|
||||
# build model
|
||||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
#set return_features = True
|
||||
config['Architecture']["Head"]["return_feats"] = True
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
best_model_dict = load_dygraph_params(config, model, logger, None)
|
||||
if len(best_model_dict):
|
||||
logger.info('metric in ckpt ***************')
|
||||
for k, v in best_model_dict.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
||||
# get features from train data
|
||||
char_center = program.get_center(model, eval_dataloader, post_process_class)
|
||||
|
||||
#serialize to disk
|
||||
with open("train_center.pkl", 'wb') as f:
|
||||
pickle.dump(char_center, f)
|
||||
return
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main()
|
|
@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
|
|||
]
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "SAR":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 48, 160], dtype="float32"),
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
else:
|
||||
infer_shape = [3, -1, -1]
|
||||
if arch_config["model_type"] == "rec":
|
||||
|
|
|
@ -141,7 +141,6 @@ if __name__ == "__main__":
|
|||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
img = img[:, :, ::-1]
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
|
|
|
@ -68,6 +68,13 @@ class TextRecognizer(object):
|
|||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
elif self.rec_algorithm == "SAR":
|
||||
postprocess_params = {
|
||||
'name': 'SARLabelDecode',
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
|
@ -194,6 +201,41 @@ class TextRecognizer(object):
|
|||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
def resize_norm_img_sar(self, img, image_shape,
|
||||
width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
|
@ -216,11 +258,19 @@ class TextRecognizer(object):
|
|||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
if self.rec_algorithm != "SRN":
|
||||
if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == "SAR":
|
||||
norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
|
||||
img_list[indices[ino]], self.rec_image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
valid_ratio = np.expand_dims(valid_ratio, axis=0)
|
||||
valid_ratios = []
|
||||
valid_ratios.append(valid_ratio)
|
||||
norm_img_batch.append(norm_img)
|
||||
else:
|
||||
norm_img = self.process_image_srn(
|
||||
img_list[indices[ino]], self.rec_image_shape, 8, 25)
|
||||
|
@ -266,6 +316,25 @@ class TextRecognizer(object):
|
|||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = {"predict": outputs[2]}
|
||||
elif self.rec_algorithm == "SAR":
|
||||
valid_ratios = np.concatenate(valid_ratios)
|
||||
inputs = [
|
||||
norm_img_batch,
|
||||
valid_ratios,
|
||||
]
|
||||
input_names = self.predictor.get_input_names()
|
||||
for i in range(len(input_names)):
|
||||
input_tensor = self.predictor.get_input_handle(input_names[
|
||||
i])
|
||||
input_tensor.copy_from_cpu(inputs[i])
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
if self.benchmark:
|
||||
self.autolog.times.stamp()
|
||||
preds = outputs[0]
|
||||
else:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.run()
|
||||
|
|
117
tools/program.py
|
@ -31,6 +31,7 @@ from ppocr.utils.stats import TrainingStats
|
|||
from ppocr.utils.save_load import save_model
|
||||
from ppocr.utils.utility import print_dict
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils import profiler
|
||||
from ppocr.data import build_dataloader
|
||||
import numpy as np
|
||||
|
||||
|
@ -42,6 +43,13 @@ class ArgsParser(ArgumentParser):
|
|||
self.add_argument("-c", "--config", help="configuration file to use")
|
||||
self.add_argument(
|
||||
"-o", "--opt", nargs='+', help="set configuration options")
|
||||
self.add_argument(
|
||||
'-p',
|
||||
'--profiler_options',
|
||||
type=str,
|
||||
default=None,
|
||||
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
|
||||
)
|
||||
|
||||
def parse_args(self, argv=None):
|
||||
args = super(ArgsParser, self).parse_args(argv)
|
||||
|
@ -158,6 +166,7 @@ def train(config,
|
|||
epoch_num = config['Global']['epoch_num']
|
||||
print_batch_step = config['Global']['print_batch_step']
|
||||
eval_batch_step = config['Global']['eval_batch_step']
|
||||
profiler_options = config['profiler_options']
|
||||
|
||||
global_step = 0
|
||||
if 'global_step' in pre_best_model_dict:
|
||||
|
@ -186,12 +195,13 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
|
||||
use_sar = config['Architecture']['algorithm'] == 'SAR'
|
||||
extra_input = config['Architecture'][
|
||||
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
model_type = None
|
||||
algorithm = config['Architecture']['algorithm']
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
|
@ -208,6 +218,7 @@ def train(config,
|
|||
max_iter = len(train_dataloader) - 1 if platform.system(
|
||||
) == "Windows" else len(train_dataloader)
|
||||
for idx, batch in enumerate(train_dataloader):
|
||||
profiler.add_profiler_step(profiler_options)
|
||||
train_reader_cost += time.time() - batch_start
|
||||
if idx >= max_iter:
|
||||
break
|
||||
|
@ -215,7 +226,7 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if use_srn or model_type == 'table' or use_nrtr or use_sar:
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -279,8 +290,7 @@ def train(config,
|
|||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=use_srn,
|
||||
use_sar=use_sar)
|
||||
extra_input=extra_input)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
logger.info(cur_metric_str)
|
||||
|
@ -351,9 +361,8 @@ def eval(model,
|
|||
valid_dataloader,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=False,
|
||||
use_sar=False):
|
||||
model_type=None,
|
||||
extra_input=False):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
|
@ -366,7 +375,7 @@ def eval(model,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
if use_srn or model_type == 'table' or use_sar:
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -390,25 +399,65 @@ def eval(model,
|
|||
return metric
|
||||
|
||||
|
||||
def update_center(char_center, post_result, preds):
|
||||
result, label = post_result
|
||||
feats, logits = preds
|
||||
logits = paddle.argmax(logits, axis=-1)
|
||||
feats = feats.numpy()
|
||||
logits = logits.numpy()
|
||||
|
||||
for idx_sample in range(len(label)):
|
||||
if result[idx_sample][0] == label[idx_sample][0]:
|
||||
feat = feats[idx_sample]
|
||||
logit = logits[idx_sample]
|
||||
for idx_time in range(len(logit)):
|
||||
index = logit[idx_time]
|
||||
if index in char_center.keys():
|
||||
char_center[index][0] = (
|
||||
char_center[index][0] * char_center[index][1] +
|
||||
feat[idx_time]) / (char_center[index][1] + 1)
|
||||
char_center[index][1] += 1
|
||||
else:
|
||||
char_center[index] = [feat[idx_time], 1]
|
||||
return char_center
|
||||
|
||||
|
||||
def get_center(model, eval_dataloader, post_process_class):
|
||||
pbar = tqdm(total=len(eval_dataloader), desc='get center:')
|
||||
max_iter = len(eval_dataloader) - 1 if platform.system(
|
||||
) == "Windows" else len(eval_dataloader)
|
||||
char_center = dict()
|
||||
for idx, batch in enumerate(eval_dataloader):
|
||||
if idx >= max_iter:
|
||||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
preds = model(images)
|
||||
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
total_time += time.time() - start
|
||||
# Evaluate the results of the current batch
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
|
||||
#update char_center
|
||||
char_center = update_center(char_center, post_result, preds)
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
for key in char_center.keys():
|
||||
char_center[key] = char_center[key][0]
|
||||
return char_center
|
||||
|
||||
|
||||
def preprocess(is_train=False):
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
profiler_options = FLAGS.profiler_options
|
||||
config = load_config(FLAGS.config)
|
||||
merge_config(FLAGS.opt)
|
||||
profile_dic = {"profiler_options": FLAGS.profiler_options}
|
||||
merge_config(profile_dic)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
if is_train:
|
||||
# save_config
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
|
@ -420,6 +469,28 @@ def preprocess(is_train=False):
|
|||
else:
|
||||
log_file = None
|
||||
logger = get_logger(name='root', log_file=log_file)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
|
||||
'SEED'
|
||||
]
|
||||
windows_not_support_list = ['PSE']
|
||||
if platform.system() == "Windows" and alg in windows_not_support_list:
|
||||
logger.warning('{} is not support in Windows now'.format(
|
||||
windows_not_support_list))
|
||||
sys.exit()
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
|
||||
if config['Global']['use_visualdl']:
|
||||
from visualdl import LogWriter
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
|
|