add go deploy for paddleocr
parent
d7d4892e00
commit
835b69c04b
|
@ -0,0 +1,328 @@
|
|||
# PaddleOCR-GO
|
||||
|
||||
本服务是[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)的golang部署版本。
|
||||
|
||||
## 1. 环境准备
|
||||
|
||||
### 运行环境
|
||||
|
||||
- go: 1.14
|
||||
- OpenCV: 4.3.0
|
||||
- PaddlePaddle: 1.8.4
|
||||
- 编译环境:cmake 3.15.4 | gcc 4.8.5
|
||||
- 基于Centos 7.4运行环境编译,Windows请自行解决`OpenCV`和`PaddlePaddle`的编译问题
|
||||
|
||||
*另外,以下编译以`.bashrc`个人环境变量配置文件,如果使用`zsh`,请自行更换为`.zshrc`*
|
||||
|
||||
### 1.1 安装golang
|
||||
|
||||
从官网下载[golang](https://golang.org/dl/),建议选择1.13版本以上进行安装。下载完成后,直接解压你需要的安装目录,并配置相关环境变量,此处以1.14版本为例。
|
||||
|
||||
```shell
|
||||
# 下载golang
|
||||
wget https://golang.org/dl/go1.14.10.linux-amd64.tar.gz
|
||||
|
||||
# 解压到 /usr/local 目录下
|
||||
tar -xzvf go1.14.10.linux-amd64.tar.gz -C /usr/local
|
||||
|
||||
# 配置GOROOT,即go的安装目录
|
||||
echo "export GOROOT=/usr/local/go" >> ~/.bashrc
|
||||
# 配置GOPATH,即go相关package的安装目录,可自定义一个目录
|
||||
echo "export GOPATH=$HOME/golang" >> ~/.bashrc
|
||||
# 配置GOPROXY,即go mod包管理器的下载代理,同时打开mod模式
|
||||
echo "export GO111MODULE=on" >> ~/.bashrc
|
||||
echo "export GOPROXY=https://mirrors.aliyun.com/goproxy/" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
```
|
||||
|
||||
### 1.2 编译OpenCV库
|
||||
|
||||
go语言中,OpenCV的使用主要以[gocv](https://github.com/hybridgroup/gocv)包为主,gocv使用cgo调用OpenCV提供接口,因此还是需要编译OpenCV库。
|
||||
|
||||
**踩坑指南之一:[gocv官方实现](https://github.com/hybridgroup/gocv)中,部分接口并没有与原版C++的OpenCV的API保持一致,导致图片处理结果会出现一定的数值偏差。为处理这种偏差,[该仓库](https://github.com/LKKlein/gocv)fork了一份gocv官方源码,并对部分这些不一致的API进行了修正,保证结果与其他语言的一致性。**
|
||||
|
||||
对于OpenCV的编译,gocv官方提供了[Makefile](https://github.com/LKKlein/gocv/blob/lk/Makefile),可以一键进行安装,具体安装步骤详见[官方指南](https://github.com/LKKlein/gocv/blob/lk/README_ORIGIN.md#ubuntulinux)。
|
||||
|
||||
这里提供逐步安装的方式,方便排查错误。
|
||||
|
||||
- 下载并解压OpenCV-4.3.0和OpenCV-Contrib-4.3.0
|
||||
|
||||
```shell
|
||||
# 创建opencv安装目录
|
||||
mkdir -p ~/opencv
|
||||
|
||||
# 下载OpenCV
|
||||
cd ~/opencv
|
||||
curl -sL https://github.com/opencv/opencv/archive/4.3.0.zip > opencv.zip
|
||||
unzip -q opencv.zip
|
||||
rm -rf opencv.zip
|
||||
|
||||
# 下载OpenCV-Contrib
|
||||
curl -sL https://github.com/opencv/opencv_contrib/archive/4.3.0.zip > opencv-contrib.zip
|
||||
unzip -q opencv-contrib.zip
|
||||
rm -rf opencv-contrib.zip
|
||||
```
|
||||
|
||||
- 安装相关依赖
|
||||
|
||||
```shell
|
||||
sudo yum -y install pkgconfig cmake git gtk2-devel libpng-devel libjpeg-devel libtiff-devel tbb tbb-devel libdc1394-devel
|
||||
```
|
||||
|
||||
- 编译安装
|
||||
|
||||
```shell
|
||||
mkdir -p ~/.local/opencv-4.3.0
|
||||
cd ~/opencv/opencv-4.3.0
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -D WITH_IPP=OFF \
|
||||
-D WITH_OPENGL=OFF \
|
||||
-D WITH_QT=OFF \
|
||||
-D BUILD_EXAMPLES=OFF \
|
||||
-D BUILD_TESTS=OFF \
|
||||
-D BUILD_PERF_TESTS=OFF \
|
||||
-D BUILD_opencv_java=OFF \
|
||||
-D BUILD_opencv_python=OFF \
|
||||
-D BUILD_opencv_python2=OFF \
|
||||
-D BUILD_opencv_python3=OFF \
|
||||
-D OPENCV_GENERATE_PKGCONFIG=ON \
|
||||
-D CMAKE_INSTALL_PREFIX=$HOME/.local/opencv-4.3.0 \
|
||||
-D OPENCV_ENABLE_NONFREE=ON \
|
||||
-D OPENCV_EXTRA_MODULES_PATH=$HOME/opencv/opencv_contrib-4.3.0/modules ..
|
||||
make -j8
|
||||
make install
|
||||
sudo ldconfig
|
||||
```
|
||||
|
||||
make进行编译时,可能出现因`xfeatures2d`的两个模块下载失败导致的编译失败,这里只需要手动下载这部分文件到`$HOME/opencv/opencv_contrib-4.3.0/modules/xfeatures2d/src`目录下,然后重新执行`make -j8`即可。这部分文件地址可参考[这里](https://github.com/opencv/opencv_contrib/issues/1301#issuecomment-447181426)给出的链接。
|
||||
|
||||
- 配置环境变量
|
||||
|
||||
```shell
|
||||
echo "export PKG_CONFIG_PATH=$PKG_CONFIG_PATH:$HOME/.local/opencv-4.3.0/lib64/pkgconfig" >> ~/.bashrc
|
||||
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/opencv-4.3.0/lib64" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
```
|
||||
|
||||
- 验证安装
|
||||
|
||||
```shell
|
||||
# 安装gocv包,先mod init
|
||||
go mod init opencv
|
||||
go get -u github.com/LKKlein/gocv
|
||||
|
||||
# 验证安装结果
|
||||
cd $GOPATH/pkg/mod/github.com/!l!k!klein/gocv@v0.28.0
|
||||
go run ./cmd/version/main.go
|
||||
|
||||
# 输出
|
||||
# gocv version: 0.28.0
|
||||
# opencv lib version: 4.3.0
|
||||
```
|
||||
|
||||
### 1.3 编译PaddlePaddle的C语言API
|
||||
|
||||
go语言只能通过cgo调用C语言API,而不能直接与C++进行交互,因此需要编译PaddlePaddle的C语言API。当然,也可以自己写C语言调用C++的代码和头文件,这样就可以直接使用PaddlePaddle提供的已编译的C++推理库,无需自己手动编译,详见[该仓库](https://github.com/LKKlein/paddleocr-go/tree/dev_cxx)。
|
||||
|
||||
- 获取PaddlePaddle源代码
|
||||
|
||||
```shell
|
||||
cd ~
|
||||
git clone --recurse-submodules https://github.com/paddlepaddle/paddle
|
||||
|
||||
# 切换到v1.8.4版本
|
||||
cd paddle
|
||||
git checkout v1.8.4
|
||||
|
||||
# 目前版本无论单卡还是多卡都需要先安装nccl
|
||||
git clone https://github.com/NVIDIA/nccl.git
cd nccl
make -j8
make install
cd ..
|
||||
```
|
||||
|
||||
- 编译Paddle源代码
|
||||
|
||||
**踩坑指南之二:PaddlePaddle的C语言API实现有一个bug,即获取输入输出变量名时只能获取到第一个模型的变量名,后续模型都无法获取输入输出变量名,进而无法获取到模型输出,详情见[issue](https://github.com/PaddlePaddle/Paddle/issues/28309)。因此,编译前需要手动将`paddle/fluid/inference/capi/pd_predictor.cc`文件中`210行`与`215行`的`static`删除。**
|
||||
|
||||
在处理完该bug之后,才能进行后续编译。相关编译参数见[官方文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id12),注意部分参数需要相关依赖,请确保依赖完整再启用。
|
||||
|
||||
```shell
|
||||
# 创建c++推理库文件夹
|
||||
mkdir -p ~/paddle_inference
|
||||
export PADDLE_ROOT=$HOME/paddle_inference
|
||||
|
||||
# 执行编译
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DFLUID_INFERENCE_INSTALL_DIR=$PADDLE_ROOT \
|
||||
-DWITH_CONTRIB=OFF \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DWITH_PYTHON=OFF \
|
||||
-DWITH_MKL=ON \
|
||||
-DWITH_GPU=ON \
|
||||
-DON_INFER=ON \
|
||||
      -DWITH_MKLDNN=ON \
      -DWITH_XBYAK=ON \
      -DWITH_DSO=OFF ..
|
||||
make
|
||||
make inference_lib_dist
|
||||
```
|
||||
|
||||
编译完成后,可以在`build/fluid_inference_c_install_dir`目录下,看到以下生成的文件
|
||||
|
||||
```
|
||||
build/fluid_inference_c_install_dir
|
||||
├── paddle
|
||||
├── third_party
|
||||
└── version.txt
|
||||
```
|
||||
|
||||
其中`paddle`就是Paddle库的C语言预测API,`version.txt`中包含当前预测库的版本信息。
|
||||
|
||||
|
||||
## 2. paddleocr-go预测库
|
||||
|
||||
### 2.1 安装paddleocr-go
|
||||
|
||||
直接执行安装命令
|
||||
|
||||
```shell
|
||||
go get github.com/PaddlePaddle/PaddleOCR/deploy/paddleocr-go
|
||||
```
|
||||
|
||||
### 2.2 相关使用API
|
||||
|
||||
在go中使用import引入包
|
||||
|
||||
```go
|
||||
import github.com/PaddlePaddle/PaddleOCR/deploy/paddleocr-go/ocr
|
||||
```
|
||||
|
||||
- 预测结果结构体
|
||||
|
||||
```go
|
||||
type OCRText struct {
|
||||
BBox [][]int `json:"bbox"`
|
||||
Text string `json:"text"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
```
|
||||
|
||||
一张图的OCR结果包含多个`OCRText`结果,每个结果包含预测框、预测文本、预测文本得分。
|
||||
|
||||
- OCR预测类
|
||||
|
||||
|
||||
```go
|
||||
func NewOCRSystem(confFile string, a map[string]interface{}) *OCRSystem
|
||||
```
|
||||
|
||||
`OCRSystem`是主要对外提供API的结构;
|
||||
|
||||
`confFile`是yaml配置文件的路径,可在配置文件中修改相关预测参数,也可以传空字符串,这时会全部使用默认配置;
|
||||
|
||||
`a`是可以在代码中直接定义的配置参数,优先级高于配置文件,会覆盖配置文件和默认配置的参数。
|
||||
|
||||
- 单张图预测API
|
||||
|
||||
```go
|
||||
func (ocr *OCRSystem) PredictOneImage(img gocv.Mat) []OCRText
|
||||
```
|
||||
|
||||
|
||||
- 图片文件夹预测API
|
||||
|
||||
```go
|
||||
func (ocr *OCRSystem) PredictDirImages(dirname string) map[string][]OCRText
|
||||
```
|
||||
|
||||
`dirname`图片文件夹的目录,默认会预测该目录下所有`jpg`和`png`图片,并返回每张图的预测结果。
|
||||
|
||||
- OCR Server
|
||||
|
||||
```go
|
||||
func (ocr *OCRSystem) StartServer(port string)
|
||||
```
|
||||
|
||||
开启OCR预测Server,开启后,使用`post`请求上传需要识别的图片至`http://$ip:$port/ocr`即可直接获取该图片上所有文本的识别结果。其中,`$ip`是开启服务的主机`ip`或`127.0.0.1`的本地ip, `$port`是传入的端口参数。
|
||||
|
||||
|
||||
## 3. 预测demo
|
||||
|
||||
### 3.1 修改预测配置
|
||||
|
||||
当前给定的配置文件`config/conf.yaml`中,包含了默认的OCR预测配置参数,可根据个人需要更改相关参数。
|
||||
|
||||
比如,将`use_gpu`改为`false`,使用CPU执行预测;将`det_model_dir`, `rec_model_dir`, `cls_model_dir`都更改为自己的本地模型路径,也或者是更改字典`rec_char_dict_path`的路径。配置参数包含了预测引擎、检测模型、检测阈值、方向分类模型、识别模型及阈值的相关参数,具体参数的意义可参见[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/whl.md#%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)。
|
||||
|
||||
### 3.2 编译预测demo
|
||||
|
||||
- 下载`paddleocr-go`代码
|
||||
|
||||
```shell
|
||||
git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
cd PaddleOCR/deploy/paddleocr-go
|
||||
```
|
||||
|
||||
- 准备paddle_c环境
|
||||
|
||||
```shell
|
||||
cp -r ~/paddle/build/fluid_inference_c_install_dir/* paddle_c/
|
||||
```
|
||||
|
||||
- 编译demo
|
||||
|
||||
```shell
|
||||
go build demo.go
|
||||
```
|
||||
|
||||
### 3.3 执行预测demo
|
||||
|
||||
预测demo提供了三种预测方式,分别是单张图预测、文件夹批量预测、OCR Server预测。三者命令行优先级依次降低。
|
||||
|
||||
#### 3.3.1 单张图预测
|
||||
|
||||
```shell
|
||||
./demo --config config/conf.yaml --image images/test.jpg
|
||||
```
|
||||
|
||||
执行完成,会输出以下内容:
|
||||
|
||||
<img src="./images/result/single_img_result.jpg" style="zoom:80%;" />
|
||||
|
||||
#### 3.3.2 文件夹批量预测
|
||||
|
||||
```shell
|
||||
./demo --config config/conf.yaml --image_dir ./images
|
||||
```
|
||||
|
||||
执行完成,会输出以下内容:
|
||||
|
||||
<img src="./images/result/img_dir_result.jpg" style="zoom:80%;" />
|
||||
|
||||
#### 3.3.3 开启OCR Server
|
||||
|
||||
```shell
|
||||
./demo --use_servering --port=18600
|
||||
```
|
||||
|
||||
开启服务后,可以在其他客户端中通过`post`请求进行ocr预测。此处以`Python`客户端为例,如下所示
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
files = {'image': open('images/test.jpg','rb')}
|
||||
url = "http://127.0.0.1:18600/ocr"
|
||||
|
||||
r = requests.post(url, files=files)
|
||||
print(r.text)
|
||||
```
|
||||
|
||||
执行完成可以得到以下结果
|
||||
|
||||

|
||||
|
||||
最后,在Python中将上述结果可视化可以得到以下结果
|
||||
|
||||

|
|
@ -0,0 +1,47 @@
|
|||
# params for prediction engine
|
||||
use_gpu: true
|
||||
ir_optim: true
|
||||
enable_mkldnn: false
|
||||
# use_zero_copy_run: true
|
||||
use_tensorrt: false
|
||||
num_cpu_threads: 6
|
||||
gpu_id: 0
|
||||
gpu_mem: 2000
|
||||
|
||||
# params for text detector
|
||||
det_algorithm: "DB"
|
||||
det_model_dir: "https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar"
|
||||
det_max_side_len: 960
|
||||
|
||||
# DB parmas
|
||||
det_db_thresh: 0.3
|
||||
det_db_box_thresh: 0.5
|
||||
det_db_unclip_ratio: 2.0
|
||||
|
||||
# EAST parmas
|
||||
det_east_score_thresh: 0.8
|
||||
det_east_cover_thresh: 0.1
|
||||
det_east_nms_thresh: 0.2
|
||||
|
||||
# params for text recognizer
|
||||
rec_algorithm: "CRNN"
|
||||
rec_model_dir: "https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar"
|
||||
rec_image_shape: [3, 32, 320]
|
||||
rec_char_type: "ch"
|
||||
rec_batch_num: 30
|
||||
max_text_length: 25
|
||||
rec_char_dict_path: "config/ppocr_keys_v1.txt"
|
||||
use_space_char: true
|
||||
|
||||
# params for text classifier
|
||||
use_angle_cls: false
|
||||
cls_model_dir: "https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar"
|
||||
cls_image_shape: [3, 48, 192]
|
||||
label_list: ["0", "180"]
|
||||
cls_batch_num: 30
|
||||
cls_thresh: 0.9
|
||||
|
||||
lang: ch
|
||||
det: true
|
||||
rec: true
|
||||
cls: false
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,51 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"paddleocr-go/ocr"
|
||||
)
|
||||
|
||||
var (
|
||||
confFile string
|
||||
image string
|
||||
imageDir string
|
||||
useServering bool
|
||||
port string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&confFile, "config", "config/conf.yaml", "config from ocr system. If not given, will use default config.")
|
||||
flag.StringVar(&image, "image", "", "image to predict. if not given, will use image_dir")
|
||||
flag.StringVar(&imageDir, "image_dir", "", "imgs in dir to be predicted. if not given, will check servering")
|
||||
flag.BoolVar(&useServering, "use_servering", false, "whether to use ocr server. [default: false]")
|
||||
flag.StringVar(&port, "port", "18600", "which port to serve ocr server. [default: 18600].")
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
sys := ocr.NewOCRSystem(confFile, nil)
|
||||
|
||||
if image != "" {
|
||||
img := ocr.ReadImage(image)
|
||||
results := sys.PredictOneImage(img)
|
||||
for _, res := range results {
|
||||
log.Println(res)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if imageDir != "" {
|
||||
results := sys.PredictDirImages(imageDir)
|
||||
for k, vs := range results {
|
||||
log.Printf("======== image: %v =======\n", k)
|
||||
for _, res := range vs {
|
||||
log.Println(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if useServering {
|
||||
sys.StartServer(port)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,9 @@
|
|||
module paddleocr-go
|
||||
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/LKKlein/gocv v0.28.0
|
||||
github.com/ctessum/go.clipper v0.0.0-20200522184404-9c744fa3e86c
|
||||
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776
|
||||
)
|
|
@ -0,0 +1,9 @@
|
|||
github.com/LKKlein/gocv v0.27.0 h1:JGNBMa2HY7HC0VlVHB4gdFjoc9NlyyrQvlUdBMWWSYw=
|
||||
github.com/LKKlein/gocv v0.27.0/go.mod h1:MP408EL7eakRU3vzjsozzfELSX7HDDGdMpWANV1IOHY=
|
||||
github.com/LKKlein/gocv v0.28.0 h1:1MMvs9uYf+QGPi86it2pUmN8RRoyMnPLUefKB/Jf1Q0=
|
||||
github.com/LKKlein/gocv v0.28.0/go.mod h1:MP408EL7eakRU3vzjsozzfELSX7HDDGdMpWANV1IOHY=
|
||||
github.com/ctessum/go.clipper v0.0.0-20200522184404-9c744fa3e86c h1:VXCsVlam0R2Yl7VET2GxZBPdOa7gFRexyhfWb9v9QtM=
|
||||
github.com/ctessum/go.clipper v0.0.0-20200522184404-9c744fa3e86c/go.mod h1:KRMo3PCsooJP3LmCwKI76dkd7f3ki3zwYLHR7Iwbi5k=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
Binary file not shown.
After Width: | Height: | Size: 36 KiB |
Binary file not shown.
After Width: | Height: | Size: 279 KiB |
Binary file not shown.
After Width: | Height: | Size: 98 KiB |
Binary file not shown.
After Width: | Height: | Size: 162 KiB |
Binary file not shown.
After Width: | Height: | Size: 141 KiB |
Binary file not shown.
After Width: | Height: | Size: 48 KiB |
|
@ -0,0 +1,259 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"image"
|
||||
"image/color"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"net/http"
|
||||
"paddleocr-go/paddle"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/LKKlein/gocv"
|
||||
)
|
||||
|
||||
// PaddleModel wraps a Paddle inference predictor together with its cached
// zero-copy input/output tensors and the engine options used to build it.
type PaddleModel struct {
	predictor *paddle.Predictor
	input     *paddle.ZeroCopyTensor
	outputs   []*paddle.ZeroCopyTensor

	useGPU      bool // run inference on GPU when true
	deviceID    int  // GPU card id passed to EnableUseGpu
	initGPUMem  int  // initial GPU memory pool size -- unit presumably MB, confirm against Paddle docs
	numThreads  int  // CPU math-library threads (CPU mode only)
	useMKLDNN   bool // enable MKL-DNN acceleration (CPU mode only)
	useTensorRT bool // NOTE(review): stored but never applied in LoadModel
	useIROptim  bool // enable IR graph optimization pass
}
|
||||
|
||||
func NewPaddleModel(args map[string]interface{}) *PaddleModel {
|
||||
return &PaddleModel{
|
||||
useGPU: getBool(args, "use_gpu", false),
|
||||
deviceID: getInt(args, "gpu_id", 0),
|
||||
initGPUMem: getInt(args, "gpu_mem", 1000),
|
||||
numThreads: getInt(args, "num_cpu_threads", 6),
|
||||
useMKLDNN: getBool(args, "enable_mkldnn", false),
|
||||
useTensorRT: getBool(args, "use_tensorrt", false),
|
||||
useIROptim: getBool(args, "ir_optim", true),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadModel creates an analysis config for the model/params pair inside
// modelDir, applies the GPU/CPU options stored on the receiver, and builds
// the predictor plus its cached zero-copy input/output tensors.
func (model *PaddleModel) LoadModel(modelDir string) {
	config := paddle.NewAnalysisConfig()
	config.DisableGlogInfo()

	// The exported model is expected as <modelDir>/model + <modelDir>/params.
	config.SetModel(modelDir+"/model", modelDir+"/params")
	if model.useGPU {
		config.EnableUseGpu(model.initGPUMem, model.deviceID)
	} else {
		config.DisableGpu()
		config.SetCpuMathLibraryNumThreads(model.numThreads)
		if model.useMKLDNN {
			config.EnableMkldnn()
		}
	}

	// config.EnableMemoryOptim()
	if model.useIROptim {
		config.SwitchIrOptim(true)
	}

	// false for zero copy tensor
	config.SwitchUseFeedFetchOps(false)
	config.SwitchSpecifyInputNames(true)

	model.predictor = paddle.NewPredictor(config)
	// Cache the first input tensor and all output tensors so Run calls
	// can reuse them without re-querying the predictor.
	model.input = model.predictor.GetInputTensors()[0]
	model.outputs = model.predictor.GetOutputTensors()
}
|
||||
|
||||
type OCRText struct {
|
||||
BBox [][]int `json:"bbox"`
|
||||
Text string `json:"text"`
|
||||
Score float64 `json:"score"`
|
||||
}
|
||||
|
||||
type TextPredictSystem struct {
|
||||
detector *DBDetector
|
||||
cls *TextClassifier
|
||||
rec *TextRecognizer
|
||||
}
|
||||
|
||||
func NewTextPredictSystem(args map[string]interface{}) *TextPredictSystem {
|
||||
sys := &TextPredictSystem{
|
||||
detector: NewDBDetector(getString(args, "det_model_dir", ""), args),
|
||||
rec: NewTextRecognizer(getString(args, "rec_model_dir", ""), args),
|
||||
}
|
||||
if getBool(args, "use_angle_cls", false) {
|
||||
sys.cls = NewTextClassifier(getString(args, "cls_model_dir", ""), args)
|
||||
}
|
||||
return sys
|
||||
}
|
||||
|
||||
func (sys *TextPredictSystem) sortBoxes(boxes [][][]int) [][][]int {
|
||||
sort.Slice(boxes, func(i, j int) bool {
|
||||
if boxes[i][0][1] < boxes[j][0][1] {
|
||||
return true
|
||||
}
|
||||
if boxes[i][0][1] > boxes[j][0][1] {
|
||||
return false
|
||||
}
|
||||
return boxes[i][0][0] < boxes[j][0][0]
|
||||
})
|
||||
|
||||
for i := 0; i < len(boxes)-1; i++ {
|
||||
if math.Abs(float64(boxes[i+1][0][1]-boxes[i][0][1])) < 10 && boxes[i+1][0][0] < boxes[i][0][0] {
|
||||
boxes[i], boxes[i+1] = boxes[i+1], boxes[i]
|
||||
}
|
||||
}
|
||||
return boxes
|
||||
}
|
||||
|
||||
// getRotateCropImage perspective-warps the quadrilateral box out of img
// into an axis-aligned crop. If the crop ends up much taller than wide
// (rows >= 1.5*cols) it is transposed and flipped so the text runs
// horizontally for the recognizer. The caller owns the returned Mat.
// box is assumed to hold 4 [x, y] corners starting at top-left -- TODO
// confirm against the detector's corner ordering.
func (sys *TextPredictSystem) getRotateCropImage(img gocv.Mat, box [][]int) gocv.Mat {
	// Target width/height come from the lengths of two box edges.
	cropW := int(math.Sqrt(math.Pow(float64(box[0][0]-box[1][0]), 2) + math.Pow(float64(box[0][1]-box[1][1]), 2)))
	cropH := int(math.Sqrt(math.Pow(float64(box[0][0]-box[3][0]), 2) + math.Pow(float64(box[0][1]-box[3][1]), 2)))
	// Destination rectangle corners, clockwise from top-left.
	ptsstd := make([]image.Point, 4)
	ptsstd[0] = image.Pt(0, 0)
	ptsstd[1] = image.Pt(cropW, 0)
	ptsstd[2] = image.Pt(cropW, cropH)
	ptsstd[3] = image.Pt(0, cropH)

	// Source quadrilateral corners.
	points := make([]image.Point, 4)
	points[0] = image.Pt(box[0][0], box[0][1])
	points[1] = image.Pt(box[1][0], box[1][1])
	points[2] = image.Pt(box[2][0], box[2][1])
	points[3] = image.Pt(box[3][0], box[3][1])

	// Map the quadrilateral onto the upright rectangle.
	M := gocv.GetPerspectiveTransform(points, ptsstd)
	defer M.Close()
	dstimg := gocv.NewMat()
	gocv.WarpPerspectiveWithParams(img, &dstimg, M, image.Pt(cropW, cropH),
		gocv.InterpolationCubic, gocv.BorderReplicate, color.RGBA{0, 0, 0, 0})

	// Tall crop => vertical text: transpose + vertical flip rotates it
	// 90 degrees; the intermediate dstimg is released before returning.
	if float64(dstimg.Rows()) >= float64(dstimg.Cols())*1.5 {
		srcCopy := gocv.NewMat()
		gocv.Transpose(dstimg, &srcCopy)
		defer dstimg.Close()
		gocv.Flip(srcCopy, &srcCopy, 0)
		return srcCopy
	}
	return dstimg
}
|
||||
|
||||
func (sys *TextPredictSystem) Run(img gocv.Mat) []OCRText {
|
||||
srcimg := gocv.NewMat()
|
||||
img.CopyTo(&srcimg)
|
||||
boxes := sys.detector.Run(img)
|
||||
if len(boxes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
boxes = sys.sortBoxes(boxes)
|
||||
cropimages := make([]gocv.Mat, len(boxes))
|
||||
for i := 0; i < len(boxes); i++ {
|
||||
tmpbox := make([][]int, len(boxes[i]))
|
||||
for j := 0; j < len(tmpbox); j++ {
|
||||
tmpbox[j] = make([]int, len(boxes[i][j]))
|
||||
copy(tmpbox[j], boxes[i][j])
|
||||
}
|
||||
cropimg := sys.getRotateCropImage(srcimg, tmpbox)
|
||||
cropimages[i] = cropimg
|
||||
}
|
||||
if sys.cls != nil {
|
||||
cropimages = sys.cls.Run(cropimages)
|
||||
}
|
||||
recResult := sys.rec.Run(cropimages, boxes)
|
||||
return recResult
|
||||
}
|
||||
|
||||
type OCRSystem struct {
|
||||
args map[string]interface{}
|
||||
tps *TextPredictSystem
|
||||
}
|
||||
|
||||
func NewOCRSystem(confFile string, a map[string]interface{}) *OCRSystem {
|
||||
args, err := ReadYaml(confFile)
|
||||
if err != nil {
|
||||
log.Printf("Read config file %v failed! Please check. err: %v\n", confFile, err)
|
||||
log.Println("Program will use default config.")
|
||||
args = defaultArgs
|
||||
}
|
||||
for k, v := range a {
|
||||
args[k] = v
|
||||
}
|
||||
return &OCRSystem{
|
||||
args: args,
|
||||
tps: NewTextPredictSystem(args),
|
||||
}
|
||||
}
|
||||
|
||||
func (ocr *OCRSystem) StartServer(port string) {
|
||||
http.HandleFunc("/ocr", ocr.predictHandler)
|
||||
log.Println("OCR Server has been started on port :", port)
|
||||
err := http.ListenAndServe(":"+port, nil)
|
||||
if err != nil {
|
||||
log.Panicf("http error! error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ocr *OCRSystem) predictHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
w.Write([]byte(errors.New("post method only").Error()))
|
||||
return
|
||||
}
|
||||
r.ParseMultipartForm(32 << 20)
|
||||
var buf bytes.Buffer
|
||||
file, header, err := r.FormFile("image")
|
||||
if err != nil {
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
ext := strings.ToLower(path.Ext(header.Filename))
|
||||
if ext != ".jpg" && ext != ".png" {
|
||||
w.Write([]byte(errors.New("only support image endswith jpg/png").Error()))
|
||||
return
|
||||
}
|
||||
|
||||
io.Copy(&buf, file)
|
||||
img, err2 := gocv.IMDecode(buf.Bytes(), gocv.IMReadColor)
|
||||
if err2 != nil {
|
||||
w.Write([]byte(err2.Error()))
|
||||
return
|
||||
}
|
||||
result := ocr.PredictOneImage(img)
|
||||
if output, err3 := json.Marshal(result); err3 != nil {
|
||||
w.Write([]byte(err3.Error()))
|
||||
} else {
|
||||
w.Write(output)
|
||||
}
|
||||
}
|
||||
|
||||
// PredictOneImage runs the full detect/(classify)/recognize pipeline on a
// single image and returns all recognized text snippets with their boxes.
func (ocr *OCRSystem) PredictOneImage(img gocv.Mat) []OCRText {
	return ocr.tps.Run(img)
}
|
||||
|
||||
func (ocr *OCRSystem) PredictDirImages(dirname string) map[string][]OCRText {
|
||||
if dirname == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
imgs, _ := filepath.Glob(dirname + "/*.jpg")
|
||||
tmpimgs, _ := filepath.Glob(dirname + "/*.png")
|
||||
imgs = append(imgs, tmpimgs...)
|
||||
results := make(map[string][]OCRText, len(imgs))
|
||||
for i := 0; i < len(imgs); i++ {
|
||||
imgname := imgs[i]
|
||||
img := ReadImage(imgname)
|
||||
res := ocr.PredictOneImage(img)
|
||||
results[imgname] = res
|
||||
}
|
||||
return results
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
package ocr
|
||||
|
||||
var (
|
||||
defaultArgs = map[string]interface{}{
|
||||
"use_gpu": true,
|
||||
"ir_optim": true,
|
||||
"enable_mkldnn": false,
|
||||
"use_tensorrt": false,
|
||||
"num_cpu_threads": 6,
|
||||
"gpu_id": 0,
|
||||
"gpu_mem": 2000,
|
||||
|
||||
"det_algorithm": "DB",
|
||||
"det_model_dir": "https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar",
|
||||
"det_max_side_len": 960,
|
||||
|
||||
"det_db_thresh": 0.3,
|
||||
"det_db_box_thresh": 0.5,
|
||||
"det_db_unclip_ratio": 2.0,
|
||||
|
||||
"det_east_score_thresh": 0.8,
|
||||
"det_east_cover_thresh": 0.1,
|
||||
"det_east_nms_thresh": 0.2,
|
||||
|
||||
"rec_algorithm": "CRNN",
|
||||
"rec_model_dir": "https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar",
|
||||
"rec_image_shape": []interface{}{3, 32, 320},
|
||||
"rec_char_type": "ch",
|
||||
"rec_batch_num": 30,
|
||||
"max_text_length": 25,
|
||||
"rec_char_dict_path": "config/ppocr_keys_v1.txt",
|
||||
"use_space_char": true,
|
||||
|
||||
"use_angle_cls": false,
|
||||
"cls_model_dir": "https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar",
|
||||
"cls_image_shape": []interface{}{3, 48, 192},
|
||||
"label_list": []interface{}{"0", "180"},
|
||||
"cls_batch_num": 30,
|
||||
"cls_thresh": 0.9,
|
||||
|
||||
"lang": "ch",
|
||||
"det": true,
|
||||
"rec": true,
|
||||
"cls": false,
|
||||
}
|
||||
)
|
|
@ -0,0 +1,105 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/LKKlein/gocv"
|
||||
)
|
||||
|
||||
type TextClassifier struct {
|
||||
*PaddleModel
|
||||
batchNum int
|
||||
thresh float64
|
||||
shape []int
|
||||
labels []string
|
||||
}
|
||||
|
||||
type ClsResult struct {
|
||||
Score float32
|
||||
Label int64
|
||||
}
|
||||
|
||||
func NewTextClassifier(modelDir string, args map[string]interface{}) *TextClassifier {
|
||||
shapes := []int{3, 48, 192}
|
||||
if v, ok := args["cls_image_shape"]; ok {
|
||||
for i, s := range v.([]interface{}) {
|
||||
shapes[i] = s.(int)
|
||||
}
|
||||
}
|
||||
cls := &TextClassifier{
|
||||
PaddleModel: NewPaddleModel(args),
|
||||
batchNum: getInt(args, "cls_batch_num", 30),
|
||||
thresh: getFloat64(args, "cls_thresh", 0.9),
|
||||
shape: shapes,
|
||||
}
|
||||
if checkModelExists(modelDir) {
|
||||
modelDir, _ = downloadModel("./inference/cls", modelDir)
|
||||
} else {
|
||||
log.Panicf("cls model path: %v not exist! Please check!", modelDir)
|
||||
}
|
||||
cls.LoadModel(modelDir)
|
||||
return cls
|
||||
}
|
||||
|
||||
// Run classifies the orientation of each cropped text image and rotates
// any crop the model labels as flipped (odd label index, presumably the
// "180" entry of label_list -- confirm) with confidence above cls.thresh.
// It returns fresh copies of the inputs; the caller keeps ownership of
// imgs and becomes owner of the returned Mats.
func (cls *TextClassifier) Run(imgs []gocv.Mat) []gocv.Mat {
	batch := cls.batchNum
	var clsTime int64 = 0
	clsout := make([]ClsResult, len(imgs))
	srcimgs := make([]gocv.Mat, len(imgs))
	c, h, w := cls.shape[0], cls.shape[1], cls.shape[2]
	// Process the crops in batches of at most cls.batchNum.
	for i := 0; i < len(imgs); i += batch {
		j := i + batch
		if len(imgs) < j {
			j = len(imgs)
		}

		// Resize+normalize every crop in the batch into one flat NCHW
		// float buffer, keeping an untouched copy in srcimgs.
		normImgs := make([]float32, (j-i)*c*h*w)
		for k := i; k < j; k++ {
			tmp := gocv.NewMat()
			imgs[k].CopyTo(&tmp)
			srcimgs[k] = tmp
			img := clsResize(imgs[k], cls.shape)
			data := normPermute(img, []float32{0.5, 0.5, 0.5}, []float32{0.5, 0.5, 0.5}, 255.0)
			copy(normImgs[(k-i)*c*h*w:], data)
		}

		st := time.Now()
		cls.input.SetValue(normImgs)
		cls.input.Reshape([]int32{int32(j - i), int32(c), int32(h), int32(w)})

		cls.predictor.SetZeroCopyInput(cls.input)
		cls.predictor.ZeroCopyRun()
		cls.predictor.GetZeroCopyOutput(cls.outputs[0])
		cls.predictor.GetZeroCopyOutput(cls.outputs[1])

		// The model produces probabilities (rank-2 tensor) and labels
		// (rank-1 tensor), but the order of the two outputs is not fixed,
		// so pick each one by its shape.
		var probout [][]float32
		var labelout []int64
		if len(cls.outputs[0].Shape()) == 2 {
			probout = cls.outputs[0].Value().([][]float32)
		} else {
			labelout = cls.outputs[0].Value().([]int64)
		}

		if len(cls.outputs[1].Shape()) == 2 {
			probout = cls.outputs[1].Value().([][]float32)
		} else {
			labelout = cls.outputs[1].Value().([]int64)
		}
		clsTime += int64(time.Since(st).Milliseconds())

		for no, label := range labelout {
			score := probout[no][label]
			clsout[i+no] = ClsResult{
				Score: score,
				Label: label,
			}

			// Rotate the untouched copy in place when the model is
			// confident the crop is upside-down.
			if label%2 == 1 && float64(score) > cls.thresh {
				gocv.Rotate(srcimgs[i+no], &srcimgs[i+no], gocv.Rotate180Clockwise)
			}
		}
	}
	log.Println("cls num: ", len(clsout), ", cls time elapse: ", clsTime, "ms")
	return srcimgs
}
|
|
@ -0,0 +1,52 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/LKKlein/gocv"
|
||||
)
|
||||
|
||||
type DBDetector struct {
|
||||
*PaddleModel
|
||||
preProcess DetPreProcess
|
||||
postProcess DetPostProcess
|
||||
}
|
||||
|
||||
func NewDBDetector(modelDir string, args map[string]interface{}) *DBDetector {
|
||||
maxSideLen := getInt(args, "det_max_side_len", 960)
|
||||
thresh := getFloat64(args, "det_db_thresh", 0.3)
|
||||
boxThresh := getFloat64(args, "det_db_box_thresh", 0.5)
|
||||
unClipRatio := getFloat64(args, "det_db_unclip_ratio", 2.0)
|
||||
|
||||
detector := &DBDetector{
|
||||
PaddleModel: NewPaddleModel(args),
|
||||
preProcess: NewDBProcess(make([]int, 0), maxSideLen),
|
||||
postProcess: NewDBPostProcess(thresh, boxThresh, unClipRatio),
|
||||
}
|
||||
if checkModelExists(modelDir) {
|
||||
modelDir, _ = downloadModel("./inference/det", modelDir)
|
||||
} else {
|
||||
log.Panicf("det model path: %v not exist! Please check!", modelDir)
|
||||
}
|
||||
detector.LoadModel(modelDir)
|
||||
return detector
|
||||
}
|
||||
|
||||
// Run detects text regions in img and returns them as boxes of [x, y]
// corner points, mapped back to original-image coordinates.
func (det *DBDetector) Run(img gocv.Mat) [][][]int {
	oriH := img.Rows()
	oriW := img.Cols()
	// Resize/normalize for the network; keep the resized dims so results
	// can be rescaled to the original image.
	data, resizeH, resizeW := det.preProcess.Run(img)
	st := time.Now()
	det.input.SetValue(data)
	det.input.Reshape([]int32{1, 3, int32(resizeH), int32(resizeW)})

	det.predictor.SetZeroCopyInput(det.input)
	det.predictor.ZeroCopyRun()
	det.predictor.GetZeroCopyOutput(det.outputs[0])

	// Scale factors from resized back to original coordinates.
	ratioH, ratioW := float64(resizeH)/float64(oriH), float64(resizeW)/float64(oriW)
	boxes := det.postProcess.Run(det.outputs[0], oriH, oriW, ratioH, ratioW)
	log.Println("det_box num: ", len(boxes), ", time elapse: ", time.Since(st))
	return boxes
}
|
|
@ -0,0 +1,128 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/LKKlein/gocv"
|
||||
)
|
||||
|
||||
type TextRecognizer struct {
|
||||
*PaddleModel
|
||||
batchNum int
|
||||
textLen int
|
||||
shape []int
|
||||
charType string
|
||||
labels []string
|
||||
}
|
||||
|
||||
func NewTextRecognizer(modelDir string, args map[string]interface{}) *TextRecognizer {
|
||||
shapes := []int{3, 32, 320}
|
||||
if v, ok := args["rec_image_shape"]; ok {
|
||||
for i, s := range v.([]interface{}) {
|
||||
shapes[i] = s.(int)
|
||||
}
|
||||
}
|
||||
labelpath := getString(args, "rec_char_dict_path", "./config/ppocr_keys_v1.txt")
|
||||
labels := readLines2StringSlice(labelpath)
|
||||
if getBool(args, "use_space_char", true) {
|
||||
labels = append(labels, " ")
|
||||
}
|
||||
rec := &TextRecognizer{
|
||||
PaddleModel: NewPaddleModel(args),
|
||||
batchNum: getInt(args, "rec_batch_num", 30),
|
||||
textLen: getInt(args, "max_text_length", 25),
|
||||
charType: getString(args, "rec_char_type", "ch"),
|
||||
shape: shapes,
|
||||
labels: labels,
|
||||
}
|
||||
if checkModelExists(modelDir) {
|
||||
modelDir, _ = downloadModel("./inference/rec/ch", modelDir)
|
||||
} else {
|
||||
log.Panicf("rec model path: %v not exist! Please check!", modelDir)
|
||||
}
|
||||
rec.LoadModel(modelDir)
|
||||
return rec
|
||||
}
|
||||
|
||||
// Run recognizes the text in each cropped region and pairs it with the
// matching detection box from bboxes (expected parallel to imgs). Crops
// are processed in batches of rec.batchNum; for the "ch" char type the
// network width is stretched to fit the widest crop in the batch.
func (rec *TextRecognizer) Run(imgs []gocv.Mat, bboxes [][][]int) []OCRText {
	recResult := make([]OCRText, 0, len(imgs))
	batch := rec.batchNum
	var recTime int64 = 0
	c, h, w := rec.shape[0], rec.shape[1], rec.shape[2]
	for i := 0; i < len(imgs); i += batch {
		j := i + batch
		if len(imgs) < j {
			j = len(imgs)
		}

		// Widest width/height ratio in the batch decides the resize width.
		maxwhratio := 0.0
		for k := i; k < j; k++ {
			h, w := imgs[k].Rows(), imgs[k].Cols()
			ratio := float64(w) / float64(h)
			if ratio > maxwhratio {
				maxwhratio = ratio
			}
		}

		if rec.charType == "ch" {
			w = int(32 * maxwhratio)
		}
		normimgs := make([]float32, (j-i)*c*h*w)

		// Normalize every crop into one flat NCHW float buffer.
		for k := i; k < j; k++ {
			data := crnnPreprocess(imgs[k], rec.shape, []float32{0.5, 0.5, 0.5},
				[]float32{0.5, 0.5, 0.5}, 255.0, maxwhratio, rec.charType)
			copy(normimgs[(k-i)*c*h*w:], data)
		}

		st := time.Now()
		rec.input.SetValue(normimgs)
		rec.input.Reshape([]int32{int32(j - i), int32(c), int32(h), int32(w)})

		rec.predictor.SetZeroCopyInput(rec.input)
		rec.predictor.ZeroCopyRun()
		rec.predictor.GetZeroCopyOutput(rec.outputs[0])
		rec.predictor.GetZeroCopyOutput(rec.outputs[1])

		// Output 0: predicted label indices as a LoD tensor (one variable
		// length sequence per crop). Output 1: per-step class scores.
		recIdxBatch := rec.outputs[0].Value().([][]int64)
		recIdxLod := rec.outputs[0].Lod()

		predictBatch := rec.outputs[1].Value().([][]float32)
		predictLod := rec.outputs[1].Lod()
		recTime += int64(time.Since(st).Milliseconds())

		// Decode each sequence in the batch using its LoD offsets.
		for rno := 0; rno < len(recIdxLod)-1; rno++ {
			predIdx := make([]int, 0, 2)
			for beg := recIdxLod[rno]; beg < recIdxLod[rno+1]; beg++ {
				predIdx = append(predIdx, int(recIdxBatch[beg][0]))
			}
			if len(predIdx) == 0 {
				continue
			}
			// Map label indices to dictionary characters.
			words := ""
			for n := 0; n < len(predIdx); n++ {
				words += rec.labels[predIdx[n]]
			}

			// Confidence: average of the per-step max scores over steps
			// whose argmax is not the blank class (last class index --
			// presumably CTC blank, confirm against the model).
			score := 0.0
			count := 0
			blankPosition := int(rec.outputs[1].Shape()[1])
			for beg := predictLod[rno]; beg < predictLod[rno+1]; beg++ {
				argMaxID, maxVal := argmax(predictBatch[beg])
				if blankPosition-1-argMaxID > 0 {
					score += float64(maxVal)
					count++
				}
			}
			score = score / float64(count)
			recResult = append(recResult, OCRText{
				BBox:  bboxes[i+rno],
				Text:  words,
				Score: score,
			})
		}
	}
	log.Println("rec num: ", len(recResult), ", rec time elapse: ", recTime, "ms")
	return recResult
}
|
|
@ -0,0 +1,264 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"math"
|
||||
"paddleocr-go/paddle"
|
||||
"sort"
|
||||
|
||||
"github.com/LKKlein/gocv"
|
||||
clipper "github.com/ctessum/go.clipper"
|
||||
)
|
||||
|
||||
// xFloatSortBy sorts point slices ([]float32{x, y}) by their x coordinate;
// it implements sort.Interface.
type xFloatSortBy [][]float32

func (a xFloatSortBy) Len() int           { return len(a) }
func (a xFloatSortBy) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a xFloatSortBy) Less(i, j int) bool { return a[i][0] < a[j][0] }
|
||||
|
||||
// xIntSortBy sorts point slices ([]int{x, y}) by their x coordinate;
// it implements sort.Interface.
type xIntSortBy [][]int

func (a xIntSortBy) Len() int           { return len(a) }
func (a xIntSortBy) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a xIntSortBy) Less(i, j int) bool { return a[i][0] < a[j][0] }
|
||||
|
||||
// DetPostProcess converts a raw detection output tensor into text boxes in
// original-image coordinates, given the original size and the resize ratios
// that were applied during preprocessing.
type DetPostProcess interface {
	Run(output *paddle.ZeroCopyTensor, oriH, oriW int, ratioH, ratioW float64) [][][]int
}
|
||||
|
||||
// DBPostProcess turns a DB (Differentiable Binarization) probability map
// into text bounding boxes.
type DBPostProcess struct {
	thresh        float64 // binarization threshold applied to the probability map
	boxThresh     float64 // minimum mean score a candidate box must reach
	maxCandidates int     // cap on the number of contours examined
	unClipRatio   float64 // expansion ratio used when un-clipping shrunken boxes
	minSize       int     // minimum side length (px) of a candidate box
}
|
||||
|
||||
func NewDBPostProcess(thresh, boxThresh, unClipRatio float64) *DBPostProcess {
|
||||
return &DBPostProcess{
|
||||
thresh: thresh,
|
||||
boxThresh: boxThresh,
|
||||
unClipRatio: unClipRatio,
|
||||
maxCandidates: 1000,
|
||||
minSize: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// getMinBoxes extracts the four corners of a rotated rectangle and returns
// them ordered top-left, top-right, bottom-right, bottom-left.
func (d *DBPostProcess) getMinBoxes(rect gocv.RotatedRect) [][]float32 {
	points := gocv.NewMat()
	gocv.BoxPoints(rect, &points)
	defer points.Close()
	array := d.mat2slice(points)
	// Sort corners by x: the first two become the leftmost pair, the last
	// two the rightmost pair.
	sort.Sort(xFloatSortBy(array))

	point1, point2, point3, point4 := array[0], array[1], array[2], array[3]
	// Of the rightmost pair, the corner with the smaller y is top-right.
	if array[3][1] <= array[2][1] {
		point2, point3 = array[3], array[2]
	} else {
		point2, point3 = array[2], array[3]
	}

	// Of the leftmost pair, the corner with the smaller y is top-left.
	if array[1][1] <= array[0][1] {
		point1, point4 = array[1], array[0]
	} else {
		point1, point4 = array[0], array[1]
	}

	// Clockwise order starting at top-left.
	array = [][]float32{point1, point2, point3, point4}
	return array
}
|
||||
|
||||
func (d *DBPostProcess) mat2slice(mat gocv.Mat) [][]float32 {
|
||||
array := make([][]float32, mat.Rows())
|
||||
for i := 0; i < mat.Rows(); i++ {
|
||||
tmp := make([]float32, mat.Cols())
|
||||
for j := 0; j < mat.Cols(); j++ {
|
||||
tmp[j] = mat.GetFloatAt(i, j)
|
||||
}
|
||||
array[i] = tmp
|
||||
}
|
||||
return array
|
||||
}
|
||||
|
||||
func (d *DBPostProcess) boxScoreFast(array [][]float32, pred gocv.Mat) float64 {
|
||||
height, width := pred.Rows(), pred.Cols()
|
||||
boxX := []float32{array[0][0], array[1][0], array[2][0], array[3][0]}
|
||||
boxY := []float32{array[0][1], array[1][1], array[2][1], array[3][1]}
|
||||
|
||||
xmin := clip(int(math.Floor(float64(minf(boxX)))), 0, width-1)
|
||||
xmax := clip(int(math.Ceil(float64(maxf(boxX)))), 0, width-1)
|
||||
ymin := clip(int(math.Floor(float64(minf(boxY)))), 0, height-1)
|
||||
ymax := clip(int(math.Ceil(float64(maxf(boxY)))), 0, height-1)
|
||||
|
||||
mask := gocv.NewMatWithSize(ymax-ymin+1, xmax-xmin+1, gocv.MatTypeCV8UC1)
|
||||
ppt := make([][]image.Point, 1)
|
||||
ppt[0] = make([]image.Point, 4)
|
||||
ppt[0][0] = image.Point{int(array[0][0]) - xmin, int(array[0][1]) - ymin}
|
||||
ppt[0][1] = image.Point{int(array[1][0]) - xmin, int(array[1][1]) - ymin}
|
||||
ppt[0][2] = image.Point{int(array[2][0]) - xmin, int(array[2][1]) - ymin}
|
||||
ppt[0][3] = image.Point{int(array[3][0]) - xmin, int(array[3][1]) - ymin}
|
||||
gocv.FillPoly(&mask, ppt, color.RGBA{0, 0, 1, 0})
|
||||
croppedImg := pred.Region(image.Rect(xmin, ymin, xmax+1, ymax+1))
|
||||
s := croppedImg.MeanWithMask(mask)
|
||||
return s.Val1
|
||||
}
|
||||
|
||||
// unClip expands a (shrunken) detected box outward using polygon offsetting
// from the Clipper library. The offset distance follows the DB paper:
// area * unClipRatio / perimeter.
func (d *DBPostProcess) unClip(box [][]float32) gocv.RotatedRect {
	var area, dist float64
	// Shoelace formula accumulates the doubled signed area; dist accumulates
	// the perimeter.
	for i := 0; i < 4; i++ {
		area += float64(box[i][0]*box[(i+1)%4][1] - box[i][1]*box[(i+1)%4][0])
		dist += math.Sqrt(float64(
			(box[i][0]-box[(i+1)%4][0])*(box[i][0]-box[(i+1)%4][0]) +
				(box[i][1]-box[(i+1)%4][1])*(box[i][1]-box[(i+1)%4][1]),
		))
	}
	area = math.Abs(area / 2.0)
	distance := area * d.unClipRatio / dist

	// Offset the closed polygon outward by `distance` with rounded joins.
	offset := clipper.NewClipperOffset()
	path := make([]*clipper.IntPoint, 4)
	path[0] = &clipper.IntPoint{X: clipper.CInt(box[0][0]), Y: clipper.CInt(box[0][1])}
	path[1] = &clipper.IntPoint{X: clipper.CInt(box[1][0]), Y: clipper.CInt(box[1][1])}
	path[2] = &clipper.IntPoint{X: clipper.CInt(box[2][0]), Y: clipper.CInt(box[2][1])}
	path[3] = &clipper.IntPoint{X: clipper.CInt(box[3][0]), Y: clipper.CInt(box[3][1])}
	offset.AddPath(clipper.Path(path), clipper.JtRound, clipper.EtClosedPolygon)
	soln := offset.Execute(distance)

	// Collect every vertex of the offset solution polygons.
	points := make([]image.Point, 0, 4)
	for i := 0; i < len(soln); i++ {
		for j := 0; j < len(soln[i]); j++ {
			points = append(points, image.Point{int(soln[i][j].X), int(soln[i][j].Y)})
		}
	}

	// Degenerate result: fall back to a unit 1x1 box instead of crashing;
	// callers discard boxes below minSize anyway.
	var res gocv.RotatedRect
	if len(points) <= 0 {
		points = make([]image.Point, 4)
		points[0] = image.Pt(0, 0)
		points[1] = image.Pt(1, 0)
		points[2] = image.Pt(1, 1)
		points[3] = image.Pt(0, 1)
		res = gocv.RotatedRect{
			Contour:      points,
			BoundingRect: image.Rect(0, 0, 1, 1),
			Center:       gocv.Point2f{X: 0.5, Y: 0.5},
			Width:        1,
			Height:       1,
			Angle:        0,
		}
	} else {
		res = gocv.MinAreaRect(points)
	}
	return res
}
|
||||
|
||||
// boxesFromBitmap finds text-region contours in the binarized mask, scores
// each candidate against the probability map pred, expands survivors via
// unClip, and maps the corners back to pre-resize image coordinates using
// ratioH/ratioW.
func (d *DBPostProcess) boxesFromBitmap(pred gocv.Mat, mask gocv.Mat, ratioH float64, ratioW float64) [][][]int {
	height, width := mask.Rows(), mask.Cols()
	// Scale the 0/1 mask to 0/255 in place (mutates the caller's Mat) so
	// FindContours sees a standard binary image.
	mask.MultiplyUChar(255)
	contours := gocv.FindContours(mask, gocv.RetrievalList, gocv.ChainApproxSimple)
	numContours := len(contours)
	if numContours > d.maxCandidates {
		numContours = d.maxCandidates
	}

	boxes := make([][][]int, 0, numContours)
	for i := 0; i < numContours; i++ {
		contour := contours[i]
		boundingbox := gocv.MinAreaRect(contour)
		// Drop regions smaller than minSize on either side.
		if boundingbox.Width < float32(d.minSize) || boundingbox.Height < float32(d.minSize) {
			continue
		}
		points := d.getMinBoxes(boundingbox)
		// Mean probability inside the quad must clear boxThresh.
		score := d.boxScoreFast(points, pred)
		if score < d.boxThresh {
			continue
		}

		// Expand the shrunken box; re-check size with a +2 margin.
		box := d.unClip(points)
		if box.Width < float32(d.minSize+2) || box.Height < float32(d.minSize+2) {
			continue
		}

		// Convert corners to ints, rescale from mask space to the network
		// output, then divide by the resize ratios for original coords.
		cliparray := d.getMinBoxes(box)
		dstHeight, dstWidth := pred.Rows(), pred.Cols()
		intcliparray := make([][]int, 4)
		for i := 0; i < 4; i++ { // note: deliberately shadows the contour index
			p := []int{
				int(float64(clip(int(math.Round(
					float64(cliparray[i][0]/float32(width)*float32(dstWidth)))), 0, dstWidth)) / ratioW),
				int(float64(clip(int(math.Round(
					float64(cliparray[i][1]/float32(height)*float32(dstHeight)))), 0, dstHeight)) / ratioH),
			}
			intcliparray[i] = p
		}
		boxes = append(boxes, intcliparray)
	}
	return boxes
}
|
||||
|
||||
func (d *DBPostProcess) orderPointsClockwise(box [][]int) [][]int {
|
||||
sort.Sort(xIntSortBy(box))
|
||||
leftmost := [][]int{box[0], box[1]}
|
||||
rightmost := [][]int{box[2], box[3]}
|
||||
|
||||
if leftmost[0][1] > leftmost[1][1] {
|
||||
leftmost[0], leftmost[1] = leftmost[1], leftmost[0]
|
||||
}
|
||||
|
||||
if rightmost[0][1] > rightmost[1][1] {
|
||||
rightmost[0], rightmost[1] = rightmost[1], rightmost[0]
|
||||
}
|
||||
|
||||
return [][]int{leftmost[0], rightmost[0], rightmost[1], leftmost[1]}
|
||||
}
|
||||
|
||||
func (d *DBPostProcess) filterTagDetRes(boxes [][][]int, oriH, oriW int) [][][]int {
|
||||
points := make([][][]int, 0, len(boxes))
|
||||
for i := 0; i < len(boxes); i++ {
|
||||
boxes[i] = d.orderPointsClockwise(boxes[i])
|
||||
for j := 0; j < len(boxes[i]); j++ {
|
||||
boxes[i][j][0] = clip(boxes[i][j][0], 0, oriW-1)
|
||||
boxes[i][j][1] = clip(boxes[i][j][1], 0, oriH-1)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(boxes); i++ {
|
||||
rectW := int(math.Sqrt(math.Pow(float64(boxes[i][0][0]-boxes[i][1][0]), 2.0) +
|
||||
math.Pow(float64(boxes[i][0][1]-boxes[i][1][1]), 2.0)))
|
||||
rectH := int(math.Sqrt(math.Pow(float64(boxes[i][0][0]-boxes[i][3][0]), 2.0) +
|
||||
math.Pow(float64(boxes[i][0][1]-boxes[i][3][1]), 2.0)))
|
||||
if rectW <= 4 || rectH <= 4 {
|
||||
continue
|
||||
}
|
||||
points = append(points, boxes[i])
|
||||
}
|
||||
return points
|
||||
}
|
||||
|
||||
func (d *DBPostProcess) Run(output *paddle.ZeroCopyTensor, oriH, oriW int, ratioH, ratioW float64) [][][]int {
|
||||
v := output.Value().([][][][]float32)
|
||||
|
||||
shape := output.Shape()
|
||||
height, width := int(shape[2]), int(shape[3])
|
||||
|
||||
pred := gocv.NewMatWithSize(height, width, gocv.MatTypeCV32F)
|
||||
bitmap := gocv.NewMatWithSize(height, width, gocv.MatTypeCV8UC1)
|
||||
thresh := float32(d.thresh)
|
||||
for i := 0; i < height; i++ {
|
||||
for j := 0; j < width; j++ {
|
||||
pred.SetFloatAt(i, j, v[0][0][i][j])
|
||||
if v[0][0][i][j] > thresh {
|
||||
bitmap.SetUCharAt(i, j, 1)
|
||||
} else {
|
||||
bitmap.SetUCharAt(i, j, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mask := gocv.NewMat()
|
||||
kernel := gocv.GetStructuringElement(gocv.MorphRect, image.Point{2, 2})
|
||||
gocv.Dilate(bitmap, &mask, kernel)
|
||||
boxes := d.boxesFromBitmap(pred, mask, ratioH, ratioW)
|
||||
dtboxes := d.filterTagDetRes(boxes, oriH, oriW)
|
||||
return dtboxes
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
package ocr
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"math"
|
||||
|
||||
"github.com/LKKlein/gocv"
|
||||
)
|
||||
|
||||
func resizeByShape(img gocv.Mat, resizeShape []int) (gocv.Mat, int, int) {
|
||||
resizeH := resizeShape[0]
|
||||
resizeW := resizeShape[1]
|
||||
gocv.Resize(img, &img, image.Pt(resizeW, resizeH), 0, 0, gocv.InterpolationLinear)
|
||||
return img, resizeH, resizeW
|
||||
}
|
||||
|
||||
func resizeByMaxLen(img gocv.Mat, maxLen int) (gocv.Mat, int, int) {
|
||||
oriH := img.Rows()
|
||||
oriW := img.Cols()
|
||||
var resizeH, resizeW int = oriH, oriW
|
||||
|
||||
var ratio float64 = 1.0
|
||||
if resizeH > maxLen || resizeW > maxLen {
|
||||
if resizeH > resizeW {
|
||||
ratio = float64(maxLen) / float64(resizeH)
|
||||
} else {
|
||||
ratio = float64(maxLen) / float64(resizeW)
|
||||
}
|
||||
}
|
||||
|
||||
resizeH = int(float64(resizeH) * ratio)
|
||||
resizeW = int(float64(resizeW) * ratio)
|
||||
|
||||
if resizeH%32 == 0 {
|
||||
resizeH = resizeH
|
||||
} else if resizeH/32 <= 1 {
|
||||
resizeH = 32
|
||||
} else {
|
||||
resizeH = (resizeH/32 - 1) * 32
|
||||
}
|
||||
|
||||
if resizeW%32 == 0 {
|
||||
resizeW = resizeW
|
||||
} else if resizeW/32 <= 1 {
|
||||
resizeW = 32
|
||||
} else {
|
||||
resizeW = (resizeW/32 - 1) * 32
|
||||
}
|
||||
|
||||
if resizeW <= 0 || resizeH <= 0 {
|
||||
return gocv.NewMat(), 0, 0
|
||||
}
|
||||
|
||||
gocv.Resize(img, &img, image.Pt(resizeW, resizeH), 0, 0, gocv.InterpolationLinear)
|
||||
return img, resizeH, resizeW
|
||||
}
|
||||
|
||||
// normPermute scales img by 1/scaleFactor, normalizes each channel with
// (x - mean) / std, and returns the pixels packed channel-first (CHW) as a
// flat float32 slice. NOTE: the input Mat is closed before returning, so
// the caller must not use img afterwards.
func normPermute(img gocv.Mat, mean []float32, std []float32, scaleFactor float32) []float32 {
	img.ConvertTo(&img, gocv.MatTypeCV32F)
	img.DivideFloat(scaleFactor)
	defer img.Close()

	// Split into per-channel planes; each plane is normalized in place and
	// copied into its CHW slot.
	c := gocv.Split(img)
	data := make([]float32, img.Rows()*img.Cols()*img.Channels())
	for i := 0; i < 3; i++ {
		c[i].SubtractFloat(mean[i])
		c[i].DivideFloat(std[i])
		defer c[i].Close() // defer in loop: planes stay alive until return
		x, _ := c[i].DataPtrFloat32()
		copy(data[i*img.Rows()*img.Cols():], x)
	}
	return data
}
|
||||
|
||||
// DetPreProcess converts an input image into normalized detector input,
// returning flat CHW float data plus the resized height and width.
type DetPreProcess interface {
	Run(gocv.Mat) ([]float32, int, int)
}
|
||||
|
||||
// DBPreProcess implements DetPreProcess for the DB text detector.
type DBPreProcess struct {
	resizeType  int       // 0: limit max side length; 1: fixed shape
	imageShape  []int     // fixed [h, w] used when resizeType == 1
	maxSideLen  int       // longest allowed side when resizeType == 0
	mean        []float32 // per-channel normalization mean
	std         []float32 // per-channel normalization std
	scaleFactor float32   // pixel scale divisor applied before mean/std
}
|
||||
|
||||
// NewDBProcess builds the detection preprocessor. A non-empty shape forces
// fixed-shape resizing (resizeType 1); sideLen == 0 falls back to the
// default max side length of 2400. Mean/std are ImageNet statistics.
func NewDBProcess(shape []int, sideLen int) *DBPreProcess {
	db := &DBPreProcess{
		resizeType:  0,
		imageShape:  shape,
		maxSideLen:  sideLen,
		mean:        []float32{0.485, 0.456, 0.406},
		std:         []float32{0.229, 0.224, 0.225},
		scaleFactor: 255.0,
	}
	if len(shape) > 0 {
		db.resizeType = 1
	}
	if sideLen == 0 {
		db.maxSideLen = 2400
	}
	return db
}
|
||||
|
||||
func (d *DBPreProcess) Run(img gocv.Mat) ([]float32, int, int) {
|
||||
var resizeH, resizeW int
|
||||
if d.resizeType == 0 {
|
||||
img, resizeH, resizeW = resizeByMaxLen(img, d.maxSideLen)
|
||||
} else {
|
||||
img, resizeH, resizeW = resizeByShape(img, d.imageShape)
|
||||
}
|
||||
|
||||
im := normPermute(img, d.mean, d.std, d.scaleFactor)
|
||||
return im, resizeH, resizeW
|
||||
}
|
||||
|
||||
func clsResize(img gocv.Mat, resizeShape []int) gocv.Mat {
|
||||
imgH, imgW := resizeShape[1], resizeShape[2]
|
||||
h, w := img.Rows(), img.Cols()
|
||||
ratio := float64(w) / float64(h)
|
||||
var resizeW int
|
||||
if math.Ceil(float64(imgH)*ratio) > float64(imgW) {
|
||||
resizeW = imgW
|
||||
} else {
|
||||
resizeW = int(math.Ceil(float64(imgH) * ratio))
|
||||
}
|
||||
gocv.Resize(img, &img, image.Pt(resizeW, imgH), 0, 0, gocv.InterpolationLinear)
|
||||
if resizeW < imgW {
|
||||
gocv.CopyMakeBorder(img, &img, 0, 0, 0, imgW-resizeW, gocv.BorderConstant, color.RGBA{0, 0, 0, 0})
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
// crnnPreprocess prepares a cropped text image for the CRNN recognizer:
// resize to the model height keeping aspect ratio, scale/normalize, pad to
// the batch width, and return the pixels channel-first (CHW).
// NOTE: the input Mat is closed before returning.
func crnnPreprocess(img gocv.Mat, resizeShape []int, mean []float32, std []float32,
	scaleFactor float32, whRatio float64, charType string) []float32 {
	imgH := resizeShape[1]
	imgW := resizeShape[2]
	// Chinese models use a dynamic width driven by the batch max w/h ratio.
	if charType == "ch" {
		imgW = int(32 * whRatio)
	}
	h, w := img.Rows(), img.Cols()
	ratio := float64(w) / float64(h)
	var resizeW int
	// Keep aspect ratio but never exceed the target width.
	if math.Ceil(float64(imgH)*ratio) > float64(imgW) {
		resizeW = imgW
	} else {
		resizeW = int(math.Ceil(float64(imgH) * ratio))
	}
	gocv.Resize(img, &img, image.Pt(resizeW, imgH), 0, 0, gocv.InterpolationLinear)

	// x = (x/scaleFactor - mean) / std, applied in place to all channels.
	img.ConvertTo(&img, gocv.MatTypeCV32F)
	img.DivideFloat(scaleFactor)
	img.SubtractScalar(gocv.NewScalar(float64(mean[0]), float64(mean[1]), float64(mean[2]), 0))
	img.DivideScalar(gocv.NewScalar(float64(std[0]), float64(std[1]), float64(std[2]), 0))
	defer img.Close()

	// Right-pad with black so every sample in the batch has width imgW.
	if resizeW < imgW {
		gocv.CopyMakeBorder(img, &img, 0, 0, 0, imgW-resizeW, gocv.BorderConstant, color.RGBA{0, 0, 0, 0})
	}

	// Pack channel-first (CHW) for the network input.
	c := gocv.Split(img)
	data := make([]float32, img.Rows()*img.Cols()*img.Channels())
	for i := 0; i < 3; i++ {
		defer c[i].Close()
		x, _ := c[i].DataPtrFloat32()
		copy(data[i*img.Rows()*img.Cols():], x)
	}
	return data
}
|
|
@ -0,0 +1,268 @@
|
|||
package ocr
|
||||
|
||||
import (
	"archive/tar"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/LKKlein/gocv"
	"gopkg.in/yaml.v3"
)
|
||||
|
||||
// getString returns args[key] as a string, or dv when the key is absent or
// the stored value is not a string.
// BUGFIX: the previous bare type assertion panicked on wrongly-typed values
// in user-supplied YAML configs.
func getString(args map[string]interface{}, key string, dv string) string {
	if s, ok := args[key].(string); ok {
		return s
	}
	return dv
}
|
||||
|
||||
// getFloat64 returns args[key] as a float64, or dv when the key is absent
// or the value has an incompatible type. Integer values (how YAML decodes
// numbers without a decimal point) are accepted and widened.
// BUGFIX: the previous bare type assertion panicked on wrongly-typed values.
func getFloat64(args map[string]interface{}, key string, dv float64) float64 {
	switch v := args[key].(type) {
	case float64:
		return v
	case int:
		return float64(v)
	}
	return dv
}
|
||||
|
||||
// getInt returns args[key] as an int, or dv when the key is absent or the
// stored value is not an int.
// BUGFIX: the previous bare type assertion panicked on wrongly-typed values.
func getInt(args map[string]interface{}, key string, dv int) int {
	if i, ok := args[key].(int); ok {
		return i
	}
	return dv
}
|
||||
|
||||
// getBool returns args[key] as a bool, or dv when the key is absent or the
// stored value is not a bool.
// BUGFIX: the previous bare type assertion panicked on wrongly-typed values.
func getBool(args map[string]interface{}, key string, dv bool) bool {
	if b, ok := args[key].(bool); ok {
		return b
	}
	return dv
}
|
||||
|
||||
// ReadImage loads image_path as a BGR color image. On failure it logs and
// terminates the process, since no OCR step can proceed without input.
func ReadImage(image_path string) gocv.Mat {
	img := gocv.IMRead(image_path, gocv.IMReadColor)
	if img.Empty() {
		log.Printf("Could not read image %s\n", image_path)
		os.Exit(1)
	}
	return img
}
|
||||
|
||||
// clip bounds value to the inclusive range [min, max].
func clip(value, min, max int) int {
	switch {
	case value < min:
		return min
	case value > max:
		return max
	default:
		return value
	}
}
|
||||
|
||||
// minf returns the smallest value in data. data must be non-empty.
func minf(data []float32) float32 {
	smallest := data[0]
	for _, x := range data[1:] {
		if x < smallest {
			smallest = x
		}
	}
	return smallest
}
|
||||
|
||||
// maxf returns the largest value in data. data must be non-empty.
func maxf(data []float32) float32 {
	largest := data[0]
	for _, x := range data[1:] {
		if x > largest {
			largest = x
		}
	}
	return largest
}
|
||||
|
||||
// mini returns the smallest value in data. data must be non-empty.
func mini(data []int) int {
	smallest := data[0]
	for _, x := range data[1:] {
		if x < smallest {
			smallest = x
		}
	}
	return smallest
}
|
||||
|
||||
// maxi returns the largest value in data. data must be non-empty.
func maxi(data []int) int {
	largest := data[0]
	for _, x := range data[1:] {
		if x > largest {
			largest = x
		}
	}
	return largest
}
|
||||
|
||||
// argmax returns the index and value of the largest element of arr.
// arr must be non-empty; ties resolve to the first occurrence.
func argmax(arr []float32) (int, float32) {
	bestIdx, best := 0, arr[0]
	for i := 1; i < len(arr); i++ {
		if arr[i] > best {
			bestIdx, best = i, arr[i]
		}
	}
	return bestIdx, best
}
|
||||
|
||||
func checkModelExists(modelPath string) bool {
|
||||
if isPathExist(modelPath+"/model") && isPathExist(modelPath+"/params") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(modelPath, "http://") ||
|
||||
strings.HasPrefix(modelPath, "ftp://") || strings.HasPrefix(modelPath, "https://") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// downloadFile fetches url and writes the response body to filepath.
// BUGFIX: a non-2xx response is now reported as an error; previously an
// HTTP error page (e.g. a 404) was silently saved as the model file.
func downloadFile(filepath, url string) error {
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("download %s: unexpected status %s", url, resp.Status)
	}

	out, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer out.Close()

	_, err = io.Copy(out, resp.Body)
	log.Println("[download_file] from:", url, " to:", filepath)
	return err
}
|
||||
|
||||
func isPathExist(path string) bool {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return true
|
||||
} else if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// downloadModel resolves modelPath to a local path. Remote paths
// (http/https/ftp) are downloaded under modelDir, .tar archives are
// extracted, and the resulting local path is returned. Local paths are
// returned unchanged.
func downloadModel(modelDir, modelPath string) (string, error) {
	if modelPath != "" && (strings.HasPrefix(modelPath, "http://") ||
		strings.HasPrefix(modelPath, "ftp://") || strings.HasPrefix(modelPath, "https://")) {
		// Mirror the URL path (everything after the host) under modelDir.
		reg := regexp.MustCompile("^(http|https|ftp)://[^/]+/(.+)")
		suffix := reg.FindStringSubmatch(modelPath)[2]
		outPath := filepath.Join(modelDir, suffix)
		outDir := filepath.Dir(outPath)
		if !isPathExist(outDir) {
			os.MkdirAll(outDir, os.ModePerm)
		}

		// Download only when not already cached on disk.
		if !isPathExist(outPath) {
			err := downloadFile(outPath, modelPath)
			if err != nil {
				return "", err
			}
		}
		// Archives are extracted next to the download; the directory named
		// after the archive is the final model path.
		// NOTE(review): the unTar error is ignored here — a truncated
		// download would surface later as a missing model; consider
		// propagating it.
		if strings.HasSuffix(outPath, ".tar") {
			_, f := path.Split(suffix)
			nextDir := strings.TrimSuffix(f, ".tar")
			finalPath := modelDir + "/" + nextDir
			if !checkModelExists(finalPath) {
				unTar(modelDir, outPath)
			}
			return finalPath, nil
		}
		return outPath, nil
	}
	return modelPath, nil
}
|
||||
|
||||
func unTar(dst, src string) (err error) {
|
||||
fr, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fr.Close()
|
||||
|
||||
tr := tar.NewReader(fr)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
return nil
|
||||
case err != nil:
|
||||
return err
|
||||
case hdr == nil:
|
||||
continue
|
||||
}
|
||||
|
||||
dstFileDir := filepath.Join(dst, hdr.Name)
|
||||
|
||||
switch hdr.Typeflag {
|
||||
case tar.TypeDir:
|
||||
if b := isPathExist(dstFileDir); !b {
|
||||
if err := os.MkdirAll(dstFileDir, 0775); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case tar.TypeReg:
|
||||
file, err := os.OpenFile(dstFileDir, os.O_CREATE|os.O_RDWR, os.FileMode(hdr.Mode))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err2 := io.Copy(file, tr)
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
file.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readLines2StringSlice(path string) []string {
|
||||
content, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
log.Println("read file error!")
|
||||
return nil
|
||||
}
|
||||
lines := strings.Split(string(content), "\n")
|
||||
return lines
|
||||
}
|
||||
|
||||
func ReadYaml(yamlPath string) (map[string]interface{}, error) {
|
||||
data, err := ioutil.ReadFile(yamlPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var body interface{}
|
||||
if err := yaml.Unmarshal(data, &body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body = convertYaml2Map(body)
|
||||
return body.(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
// convertYaml2Map recursively rewrites map[interface{}]interface{} values
// (as produced by YAML decoding) into map[string]interface{} so results are
// addressable by string keys. Slices are converted element-wise in place.
func convertYaml2Map(i interface{}) interface{} {
	switch value := i.(type) {
	case map[interface{}]interface{}:
		converted := make(map[string]interface{}, len(value))
		for key, val := range value {
			converted[key.(string)] = convertYaml2Map(val)
		}
		return converted
	case []interface{}:
		for idx, val := range value {
			value[idx] = convertYaml2Map(val)
		}
	}
	return i
}
|
|
@ -0,0 +1,28 @@
|
|||
package paddle
|
||||
|
||||
// #cgo CFLAGS: -I../paddle_c/paddle/include
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -Wl,-rpath=\$ORIGIN/paddle_c/paddle/lib -lpaddle_fluid_c
|
||||
// #include <stdbool.h>
|
||||
// #include "paddle_c_api.h"
|
||||
import "C"
|
||||
import "fmt"
|
||||
|
||||
// ConvertCBooleanToGo converts a C bool to a Go bool by comparing it
// against the C zero value (any non-zero representation counts as true).
func ConvertCBooleanToGo(b C.bool) bool {
	var c_false C.bool
	if b != c_false {
		return true
	}
	return false
}
|
||||
|
||||
// numel returns the total element count implied by shape (the product of
// its dimensions; 1 for an empty shape).
func numel(shape []int32) int32 {
	total := int32(1)
	for _, dim := range shape {
		total *= dim
	}
	return total
}
|
||||
|
||||
// bug wraps an internal-inconsistency message in a uniform "Bug ..." error.
func bug(format string, args ...interface{}) error {
	msg := fmt.Sprintf(format, args...)
	return fmt.Errorf("Bug %v", msg)
}
|
|
@ -0,0 +1,183 @@
|
|||
package paddle
|
||||
|
||||
// #cgo CFLAGS: -I../paddle_c/paddle/include
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -Wl,-rpath,$ORIGIN/paddle_c/paddle/lib -lpaddle_fluid_c
|
||||
// #include <stdbool.h>
|
||||
// #include <stdlib.h>
|
||||
// #include <paddle_c_api.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Precision mirrors the C Precision enum used when configuring the
// TensorRT subgraph engine.
type Precision C.Precision

const (
	Precision_FLOAT32 Precision = C.kFloat32
	Precision_INT8    Precision = C.kInt8
	Precision_HALF    Precision = C.kHalf
)
|
||||
|
||||
// AnalysisConfig wraps the C PD_AnalysisConfig used to configure a Paddle
// inference predictor.
type AnalysisConfig struct {
	c *C.PD_AnalysisConfig
}
|
||||
|
||||
// NewAnalysisConfig allocates a fresh Paddle inference configuration; the
// underlying C object is released by a runtime finalizer.
func NewAnalysisConfig() *AnalysisConfig {
	c_config := C.PD_NewAnalysisConfig()
	config := &AnalysisConfig{c: c_config}
	runtime.SetFinalizer(config, (*AnalysisConfig).finalize)
	return config
}
|
||||
|
||||
// finalize releases the underlying C config; invoked by the GC finalizer.
func (config *AnalysisConfig) finalize() {
	C.PD_DeleteAnalysisConfig(config.c)
}
|
||||
|
||||
// SetModel points the config at a model file and (optionally) a separate
// params file; pass params == "" for combined-format models, in which case
// a NULL params pointer is handed to C.
func (config *AnalysisConfig) SetModel(model, params string) {
	c_model := C.CString(model)
	defer C.free(unsafe.Pointer(c_model))
	var c_params *C.char
	if params == "" {
		c_params = nil
	} else {
		c_params = C.CString(params)
		defer C.free(unsafe.Pointer(c_params))
	}

	C.PD_SetModel(config.c, c_model, c_params)
}
|
||||
|
||||
// The methods below are thin 1:1 wrappers over the corresponding
// PD_* functions of the Paddle C API; each simply marshals Go values
// to/from C.

// ModelDir returns the directory-form model path, if one was set.
func (config *AnalysisConfig) ModelDir() string {
	return C.GoString(C.PD_ModelDir(config.c))
}

// ProgFile returns the program (model) file path.
func (config *AnalysisConfig) ProgFile() string {
	return C.GoString(C.PD_ProgFile(config.c))
}

// ParamsFile returns the parameters file path.
func (config *AnalysisConfig) ParamsFile() string {
	return C.GoString(C.PD_ParamsFile(config.c))
}

// EnableUseGpu turns on GPU inference with an initial memory pool of
// memory_pool_init_size_mb MB on device device_id.
func (config *AnalysisConfig) EnableUseGpu(memory_pool_init_size_mb int, device_id int) {
	C.PD_EnableUseGpu(config.c, C.int(memory_pool_init_size_mb), C.int(device_id))
}

// DisableGpu forces CPU-only inference.
func (config *AnalysisConfig) DisableGpu() {
	C.PD_DisableGpu(config.c)
}

// UseGpu reports whether GPU inference is enabled.
func (config *AnalysisConfig) UseGpu() bool {
	return ConvertCBooleanToGo(C.PD_UseGpu(config.c))
}

// GpuDeviceId returns the configured GPU device id.
func (config *AnalysisConfig) GpuDeviceId() int {
	return int(C.PD_GpuDeviceId(config.c))
}

// MemoryPoolInitSizeMb returns the initial GPU memory pool size in MB.
func (config *AnalysisConfig) MemoryPoolInitSizeMb() int {
	return int(C.PD_MemoryPoolInitSizeMb(config.c))
}

// EnableCudnn enables cuDNN kernels.
func (config *AnalysisConfig) EnableCudnn() {
	C.PD_EnableCUDNN(config.c)
}

// CudnnEnabled reports whether cuDNN is enabled.
func (config *AnalysisConfig) CudnnEnabled() bool {
	return ConvertCBooleanToGo(C.PD_CudnnEnabled(config.c))
}

// SwitchIrOptim toggles IR graph optimization passes.
func (config *AnalysisConfig) SwitchIrOptim(x bool) {
	C.PD_SwitchIrOptim(config.c, C.bool(x))
}

// IrOptim reports whether IR optimization is enabled.
func (config *AnalysisConfig) IrOptim() bool {
	return ConvertCBooleanToGo(C.PD_IrOptim(config.c))
}

// SwitchUseFeedFetchOps toggles feed/fetch ops.
func (config *AnalysisConfig) SwitchUseFeedFetchOps(x bool) {
	C.PD_SwitchUseFeedFetchOps(config.c, C.bool(x))
}

// UseFeedFetchOpsEnabled reports whether feed/fetch ops are enabled.
func (config *AnalysisConfig) UseFeedFetchOpsEnabled() bool {
	return ConvertCBooleanToGo(C.PD_UseFeedFetchOpsEnabled(config.c))
}

// SwitchSpecifyInputNames toggles addressing inputs by name.
func (config *AnalysisConfig) SwitchSpecifyInputNames(x bool) {
	C.PD_SwitchSpecifyInputNames(config.c, C.bool(x))
}

// SpecifyInputName reports whether inputs are addressed by name.
func (config *AnalysisConfig) SpecifyInputName() bool {
	return ConvertCBooleanToGo(C.PD_SpecifyInputName(config.c))
}

// EnableTensorRtEngine enables the TensorRT subgraph engine with the given
// workspace size, batching limits, precision and caching flags.
func (config *AnalysisConfig) EnableTensorRtEngine(workspace_size int, max_batch_size int, min_subgraph_size int, precision Precision, use_static bool, use_calib_mode bool) {
	C.PD_EnableTensorRtEngine(config.c, C.int(workspace_size), C.int(max_batch_size), C.int(min_subgraph_size), C.Precision(precision), C.bool(use_static), C.bool(use_calib_mode))
}

// TensorrtEngineEnabled reports whether the TensorRT engine is enabled.
func (config *AnalysisConfig) TensorrtEngineEnabled() bool {
	return ConvertCBooleanToGo(C.PD_TensorrtEngineEnabled(config.c))
}

// SwitchIrDebug toggles IR debug mode.
func (config *AnalysisConfig) SwitchIrDebug(x bool) {
	C.PD_SwitchIrDebug(config.c, C.bool(x))
}

// EnableMkldnn enables MKL-DNN CPU kernels.
func (config *AnalysisConfig) EnableMkldnn() {
	C.PD_EnableMKLDNN(config.c)
}

// SetCpuMathLibraryNumThreads sets the CPU math library thread count.
func (config *AnalysisConfig) SetCpuMathLibraryNumThreads(n int) {
	C.PD_SetCpuMathLibraryNumThreads(config.c, C.int(n))
}

// CpuMathLibraryNumThreads returns the CPU math library thread count.
func (config *AnalysisConfig) CpuMathLibraryNumThreads() int {
	return int(C.PD_CpuMathLibraryNumThreads(config.c))
}

// EnableMkldnnQuantizer enables the MKL-DNN quantizer.
func (config *AnalysisConfig) EnableMkldnnQuantizer() {
	C.PD_EnableMkldnnQuantizer(config.c)
}

// MkldnnQuantizerEnabled reports whether the MKL-DNN quantizer is enabled.
func (config *AnalysisConfig) MkldnnQuantizerEnabled() bool {
	return ConvertCBooleanToGo(C.PD_MkldnnQuantizerEnabled(config.c))
}

// SetModelBuffer
// ModelFromMemory

// EnableMemoryOptim enables memory reuse optimization.
func (config *AnalysisConfig) EnableMemoryOptim() {
	C.PD_EnableMemoryOptim(config.c)
}

// MemoryOptimEnabled reports whether memory optimization is enabled.
func (config *AnalysisConfig) MemoryOptimEnabled() bool {
	return ConvertCBooleanToGo(C.PD_MemoryOptimEnabled(config.c))
}

// EnableProfile turns on Paddle's built-in profiler.
func (config *AnalysisConfig) EnableProfile() {
	C.PD_EnableProfile(config.c)
}

// ProfileEnabled reports whether profiling is enabled.
func (config *AnalysisConfig) ProfileEnabled() bool {
	return ConvertCBooleanToGo(C.PD_ProfileEnabled(config.c))
}

// DisableGlogInfo suppresses glog INFO output from Paddle.
func (config *AnalysisConfig) DisableGlogInfo() {
	C.PD_DisableGlogInfo(config.c)
}

// DeletePass removes a named optimization pass from the pipeline.
func (config *AnalysisConfig) DeletePass(pass string) {
	c_pass := C.CString(pass)
	defer C.free(unsafe.Pointer(c_pass))
	C.PD_DeletePass(config.c, c_pass)
}

// SetInValid marks the config as invalid.
func (config *AnalysisConfig) SetInValid() {
	C.PD_SetInValid(config.c)
}

// IsValid reports whether the config is valid.
func (config *AnalysisConfig) IsValid() bool {
	return ConvertCBooleanToGo(C.PD_IsValid(config.c))
}
|
|
@ -0,0 +1,103 @@
|
|||
package paddle
|
||||
|
||||
// #cgo CFLAGS: -I../paddle_c/paddle/include
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -Wl,-rpath,$ORIGIN/paddle_c/paddle/lib -lpaddle_fluid_c
|
||||
// #include <stdbool.h>
|
||||
// #include "paddle_c_api.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Predictor wraps a C PD_Predictor that runs Paddle inference.
type Predictor struct {
	c *C.PD_Predictor
}
|
||||
|
||||
// NewPredictor creates a predictor from config; the underlying C object is
// released by a runtime finalizer.
func NewPredictor(config *AnalysisConfig) *Predictor {
	c_predictor := C.PD_NewPredictor((*config).c)
	predictor := &Predictor{c: c_predictor}
	runtime.SetFinalizer(predictor, (*Predictor).finalize)
	return predictor
}
|
||||
|
||||
// finalize releases the C predictor; invoked by the GC finalizer.
func (predictor *Predictor) finalize() {
	C.PD_DeletePredictor(predictor.c)
}

// DeletePredictor releases the C predictor explicitly.
// NOTE(review): combined with the finalizer this risks a double free —
// verify PD_DeletePredictor tolerates being called twice.
func DeletePredictor(predictor *Predictor) {
	C.PD_DeletePredictor(predictor.c)
}

// GetInputNum returns the number of model inputs.
func (predictor *Predictor) GetInputNum() int {
	return int(C.PD_GetInputNum(predictor.c))
}

// GetOutputNum returns the number of model outputs.
func (predictor *Predictor) GetOutputNum() int {
	return int(C.PD_GetOutputNum(predictor.c))
}

// GetInputName returns the name of input n.
func (predictor *Predictor) GetInputName(n int) string {
	return C.GoString(C.PD_GetInputName(predictor.c, C.int(n)))
}

// GetOutputName returns the name of output n.
func (predictor *Predictor) GetOutputName(n int) string {
	return C.GoString(C.PD_GetOutputName(predictor.c, C.int(n)))
}
|
||||
|
||||
// GetInputTensors allocates one ZeroCopyTensor per model input, each
// pre-assigned the corresponding C-side input name.
func (predictor *Predictor) GetInputTensors() [](*ZeroCopyTensor) {
	var result [](*ZeroCopyTensor)
	for i := 0; i < predictor.GetInputNum(); i++ {
		tensor := NewZeroCopyTensor()
		tensor.c.name = C.PD_GetInputName(predictor.c, C.int(i))
		result = append(result, tensor)
	}
	return result
}
|
||||
|
||||
// GetOutputTensors allocates one ZeroCopyTensor per model output, each
// pre-assigned the corresponding C-side output name.
func (predictor *Predictor) GetOutputTensors() [](*ZeroCopyTensor) {
	var result [](*ZeroCopyTensor)
	for i := 0; i < predictor.GetOutputNum(); i++ {
		tensor := NewZeroCopyTensor()
		tensor.c.name = C.PD_GetOutputName(predictor.c, C.int(i))
		result = append(result, tensor)
	}
	return result
}
|
||||
|
||||
// GetInputNames returns all model input names in input order.
func (predictor *Predictor) GetInputNames() []string {
	names := make([]string, predictor.GetInputNum())
	for i := 0; i < len(names); i++ {
		names[i] = predictor.GetInputName(i)
	}
	return names
}

// GetOutputNames returns all model output names in output order.
func (predictor *Predictor) GetOutputNames() []string {
	names := make([]string, predictor.GetOutputNum())
	for i := 0; i < len(names); i++ {
		names[i] = predictor.GetOutputName(i)
	}
	return names
}
|
||||
|
||||
// SetZeroCopyInput binds tensor as one of the predictor's inputs.
func (predictor *Predictor) SetZeroCopyInput(tensor *ZeroCopyTensor) {
	C.PD_SetZeroCopyInput(predictor.c, tensor.c)
}

// GetZeroCopyOutput pulls one output tensor's data/metadata from the
// predictor, syncs the Go-side name, and reshapes the tensor to the C-side
// shape.
func (predictor *Predictor) GetZeroCopyOutput(tensor *ZeroCopyTensor) {
	C.PD_GetZeroCopyOutput(predictor.c, tensor.c)
	tensor.name = C.GoString(tensor.c.name)
	// Build a []int32 view over the C shape array without copying.
	// NOTE(review): this reflect.SliceHeader trick aliases C memory and is
	// only safe while tensor.c stays alive — confirm Reshape does not
	// retain the aliased slice beyond this call.
	var shape []int32
	shape_hdr := (*reflect.SliceHeader)(unsafe.Pointer(&shape))
	shape_hdr.Data = uintptr(unsafe.Pointer(tensor.c.shape.data))
	shape_hdr.Len = int(tensor.c.shape.length / C.sizeof_int)
	shape_hdr.Cap = int(tensor.c.shape.length / C.sizeof_int)
	tensor.Reshape(shape)
}

// ZeroCopyRun executes one inference pass over the bound inputs.
func (predictor *Predictor) ZeroCopyRun() {
	C.PD_ZeroCopyRun(predictor.c)
}
|
|
@ -0,0 +1,249 @@
|
|||
package paddle
|
||||
|
||||
// #cgo CFLAGS: -I../paddle_c/paddle/include
|
||||
// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -Wl,-rpath,$ORIGIN/paddle_c/paddle/lib -lpaddle_fluid_c
|
||||
// #include <stdbool.h>
|
||||
// #include <stdlib.h>
|
||||
// #include <string.h>
|
||||
// #include <paddle_c_api.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// PaddleDType mirrors the C enum PD_DataType and identifies the element
// type stored in a tensor.
type PaddleDType C.PD_DataType

// Supported tensor element types, mapped one-to-one onto the C API's
// PD_DataType values.
const (
	FLOAT32  PaddleDType = C.PD_FLOAT32
	INT32    PaddleDType = C.PD_INT32
	INT64    PaddleDType = C.PD_INT64
	UINT8    PaddleDType = C.PD_UINT8
	UNKDTYPE PaddleDType = C.PD_UNKDTYPE
)
|
||||
|
||||
// types maps each supported Go element type to its PaddleDType tag. The
// reflection helpers (TypeOf, TypeOfShape, ShapeAndTypeOf) use it to
// translate between Go values and Paddle tensor dtypes.
var types = []struct {
	gotype reflect.Type
	dtype  PaddleDType
}{
	{reflect.TypeOf(float32(0)), FLOAT32},
	{reflect.TypeOf(int32(0)), INT32},
	{reflect.TypeOf(int64(0)), INT64},
	{reflect.TypeOf(uint8(0)), UINT8},
}
|
||||
|
||||
func TypeOfShape(dtype PaddleDType, shape []int32) reflect.Type {
|
||||
var ret reflect.Type
|
||||
for _, t := range types {
|
||||
if dtype == PaddleDType(t.dtype) {
|
||||
ret = t.gotype
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ret == nil {
|
||||
panic(bug("Data %v type is not support", dtype))
|
||||
}
|
||||
|
||||
for range shape {
|
||||
ret = reflect.SliceOf(ret)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// ZeroCopyTensor wraps the C API's PD_ZeroCopyTensor and caches the
// tensor's name and shape on the Go side.
type ZeroCopyTensor struct {
	c     *C.PD_ZeroCopyTensor // owned C tensor; released by the finalizer
	name  string               // Go-side copy of the tensor name
	shape []int32              // Go-side copy of the tensor shape
}
|
||||
|
||||
// NewZeroCopyTensor allocates a tensor through the C API and registers a
// finalizer so the C-side memory is released when the Go wrapper is
// garbage collected.
func NewZeroCopyTensor() *ZeroCopyTensor {
	c_tensor := C.PD_NewZeroCopyTensor()

	tensor := &ZeroCopyTensor{c: c_tensor}
	runtime.SetFinalizer(tensor, (*ZeroCopyTensor).finalize)
	return tensor
}
|
||||
|
||||
// finalize releases the underlying C tensor. It is invoked by the garbage
// collector via the finalizer registered in NewZeroCopyTensor.
func (tensor *ZeroCopyTensor) finalize() {
	C.PD_DeleteZeroCopyTensor(tensor.c)
}
|
||||
|
||||
// Shape returns the tensor's cached Go-side shape.
// NOTE(review): the internal slice is returned without copying, so a
// caller mutating the result would corrupt the tensor's bookkeeping.
func (tensor *ZeroCopyTensor) Shape() []int32 {
	return tensor.shape
}
|
||||
|
||||
// Name returns the tensor name as stored on the C side (not the cached
// Go-side field, which Rename updates separately).
func (tensor *ZeroCopyTensor) Name() string {
	return C.GoString(tensor.c.name)
}
|
||||
|
||||
func (tensor *ZeroCopyTensor) Rename(name string) {
|
||||
tensor.name = name
|
||||
tensor.c.name = (*C.char)(unsafe.Pointer(tensor.c.name))
|
||||
//tensor.c.name = C.CString(tensor.name)
|
||||
//defer C.free(unsafe.Pointer(tensor.c.name))
|
||||
}
|
||||
|
||||
func (tensor *ZeroCopyTensor) Reshape(shape []int32) {
|
||||
tensor.shape = make([]int32, len(shape))
|
||||
copy(tensor.shape, shape)
|
||||
length := C.sizeof_int * C.size_t(len(shape))
|
||||
if tensor.c.shape.capacity < C.size_t(length) {
|
||||
if tensor.c.shape.capacity != C.size_t(0) {
|
||||
C.free(tensor.c.shape.data)
|
||||
}
|
||||
tensor.c.shape.data = C.malloc(length)
|
||||
tensor.c.shape.capacity = length
|
||||
}
|
||||
tensor.c.shape.length = length
|
||||
C.memcpy(tensor.c.shape.data, unsafe.Pointer(&shape[0]), length)
|
||||
}
|
||||
|
||||
// DataType reports the element type tag recorded on the C tensor.
func (tensor *ZeroCopyTensor) DataType() PaddleDType {
	return PaddleDType(tensor.c.dtype)
}
|
||||
|
||||
// SetValue copies a Go slice into the tensor's C-side data buffer, growing
// the buffer if needed, and records the matching Paddle dtype.
//
// NOTE(review): the type assertions below only accept flat 1-D slices
// ([]uint8, []int32, []int64, []float32); a nested slice would fail the
// assertion and an empty slice would panic on &data[0] — confirm callers
// always pass non-empty 1-D data.
func (tensor *ZeroCopyTensor) SetValue(value interface{}) {
	val := reflect.ValueOf(value)
	shape, dtype := ShapeAndTypeOf(val)
	num := numel(shape)
	length := C.size_t(SizeofDataType(dtype) * num)
	// Grow the C buffer only when the current capacity is insufficient;
	// an existing allocation is freed before reallocating.
	if tensor.c.data.capacity < length {
		if tensor.c.data.capacity != C.size_t(0) {
			C.free(tensor.c.data.data)
		}
		tensor.c.data.data = C.malloc(length)
		tensor.c.data.capacity = length
	}
	tensor.c.data.length = length

	// Copy the raw element bytes into the C buffer; one case per dtype
	// because the concrete slice type differs.
	switch dtype {
	case PaddleDType(UINT8):
		data := val.Interface().([]uint8)
		C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
	case PaddleDType(INT32):
		data := val.Interface().([]int32)
		C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
	case PaddleDType(INT64):
		data := val.Interface().([]int64)
		C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
	case PaddleDType(FLOAT32):
		data := val.Interface().([]float32)
		C.memcpy(tensor.c.data.data, unsafe.Pointer(&data[0]), length)
	}
	tensor.c.dtype = C.PD_DataType(dtype)
}
|
||||
|
||||
func TypeOf(dtype PaddleDType, shape []int32) reflect.Type {
|
||||
var ret reflect.Type
|
||||
for _, t := range types {
|
||||
if t.dtype == dtype {
|
||||
ret = t.gotype
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for range shape {
|
||||
ret = reflect.SliceOf(ret)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// Value decodes the tensor's C data buffer into a (possibly nested) Go
// slice whose element type matches the tensor's dtype, using the cached
// Go-side shape.
func (tensor *ZeroCopyTensor) Value() interface{} {
	t := TypeOf(PaddleDType(tensor.c.dtype), tensor.shape)
	value := reflect.New(t)
	c_bytes := tensor.c.data.data
	length := tensor.c.data.length
	// View the C buffer as a []byte without copying: cast through a huge
	// fixed-size array pointer, sized by the platform pointer width so the
	// array type itself stays valid on 32-bit builds.
	var slice []byte
	if unsafe.Sizeof(unsafe.Pointer(nil)) == 8 {
		slice = (*[1<<50 - 1]byte)(unsafe.Pointer(c_bytes))[:length:length]
	} else {
		slice = (*[1 << 30]byte)(unsafe.Pointer(c_bytes))[:length:length]
	}
	r := bytes.NewReader(slice)
	DecodeTensor(r, tensor.Shape(), t, value)
	return reflect.Indirect(value).Interface()
}
|
||||
|
||||
// Lod exposes the tensor's level-of-detail offsets by aliasing the C lod
// buffer as a []uint through a hand-built slice header.
// NOTE(review): the returned slice points into C memory and is only valid
// while the underlying tensor is alive.
func (tensor *ZeroCopyTensor) Lod() []uint {
	var val []uint
	valHdr := (*reflect.SliceHeader)(unsafe.Pointer(&val))
	valHdr.Data = uintptr(unsafe.Pointer(tensor.c.lod.data))
	valHdr.Len = int(tensor.c.lod.length / C.sizeof_size_t)
	valHdr.Cap = int(tensor.c.lod.length / C.sizeof_size_t)
	return val
}
|
||||
|
||||
func Endian() binary.ByteOrder {
|
||||
buf := [2]byte{}
|
||||
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
|
||||
|
||||
var endian binary.ByteOrder
|
||||
|
||||
switch buf {
|
||||
case [2]byte{0xCD, 0xAB}:
|
||||
endian = binary.LittleEndian
|
||||
case [2]byte{0xAB, 0xCD}:
|
||||
endian = binary.BigEndian
|
||||
default:
|
||||
panic("Could not determine native endianness.")
|
||||
}
|
||||
return endian
|
||||
}
|
||||
|
||||
// DecodeTensor recursively reads binary tensor data from r into the value
// pointed to by ptr, consuming one entry of shape per slice dimension.
//
// ptr must be a pointer to a value of type t. Scalar numeric kinds are
// read directly; slices are allocated to shape[0] and filled element-wise,
// with a fast path that bulk-reads an innermost numeric slice in one call.
//
// NOTE(review): errors from binary.Read are ignored, so a short buffer
// silently yields zero values.
func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Value) {
	switch t.Kind() {
	case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
		binary.Read(r, Endian(), ptr.Interface())
	case reflect.Slice:
		value := reflect.Indirect(ptr)
		value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0])))
		// Fast path: a 1-D numeric slice decodes in a single bulk read.
		if len(shape) == 1 && value.Len() > 0 {
			switch value.Index(0).Kind() {
			case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
				binary.Read(r, Endian(), value.Interface())
				return
			}
		}

		// General path: recurse into each element with the remaining
		// dimensions.
		for i := 0; i < value.Len(); i++ {
			DecodeTensor(r, shape[1:], t.Elem(), value.Index(i).Addr())
		}
	}
}
|
||||
|
||||
// SizeofDataType returns the size in bytes of one element of the given
// dtype, using the C compiler's storage sizes, or -1 for an unrecognized
// dtype.
func SizeofDataType(dtype PaddleDType) int32 {
	switch dtype {
	case UINT8:
		return int32(C.sizeof_uchar)
	case INT32:
		return int32(C.sizeof_int)
	case INT64:
		return int32(C.sizeof_longlong)
	case FLOAT32:
		return int32(C.sizeof_float)
	}
	return -1
}
|
||||
|
||||
func ShapeAndTypeOf(val reflect.Value) (shape []int32, dt PaddleDType) {
|
||||
gotype := val.Type()
|
||||
for gotype.Kind() == reflect.Array || gotype.Kind() == reflect.Slice {
|
||||
shape = append(shape, int32(val.Len()))
|
||||
if val.Len() > 0 {
|
||||
val = val.Index(0)
|
||||
}
|
||||
gotype = gotype.Elem()
|
||||
}
|
||||
|
||||
for _, t := range types {
|
||||
if gotype.Kind() == t.gotype.Kind() {
|
||||
return shape, PaddleDType(t.dtype)
|
||||
}
|
||||
}
|
||||
return shape, dt
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
# Paddle C预测库目录
|
||||
|
||||
## 编译安装
|
||||
使用cmake编译paddle,并打开-DON_INFER=ON,在编译目录下得到paddle_inference_c_install_dir,将该目录下的所有文件复制到本目录下。
|
||||
|
||||
详细编译步骤请参见[README.md](../README.md) 或者官方文档指导 https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/build_and_install_lib_cn.html#id12
|
Loading…
Reference in New Issue