PaddleOCR/notebook/notebook_ch/4.ppocr_system_strategy/PP-OCR系统及优化策略.ipynb

3492 lines
2.5 MiB
Plaintext
Raw Normal View History

2021-12-30 15:50:45 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 1. PP-OCR系统简介与总览\n",
"\n",
"前两章主要介绍了DBNet文字检测算法以及CRNN文字识别算法。然而对于我们实际场景中的一张图像想要单独基于文字检测或者识别模型是无法同时获取文字位置与文字内容的因此我们将文字检测算法以及文字识别算法进行串联构建了PP-OCR文字检测与识别系统。在实际使用过程中检测出的文字方向可能不是我们期望的方向最终导致文字识别错误因此我们在PP-OCR系统中也引入了方向分类器。\n",
"\n",
"本章主要介绍PP-OCR文字检测与识别系统以及该系统中涉及到的优化策略。通过本节课的学习您可以获得\n",
"\n",
"* PaddleOCR策略调优技巧\n",
"* 文本检测、识别、方向分类器模型的优化技巧和优化方法\n",
"\n",
"PP-OCR系统共经历了2次优化下面对PP-OCR系统和这2次优化进行简单介绍。\n",
"\n",
"## 1.1 PP-OCR系统与优化策略简介\n",
"\n",
"PP-OCR中对于一张图像如果希望提取其中的文字信息需要完成以下几个步骤\n",
"\n",
"* 使用文本检测的方法获取文本区域多边形信息PP-OCR中文本检测使用的是DBNet因此获取的是四点信息。\n",
"* 对上述文本多边形区域进行裁剪与透视变换校正,将文本区域转化成矩形框,再使用方向分类器对方向进行校正。\n",
"* 基于包含文字区域的矩形框进行文本识别,得到最终识别结果。\n",
"\n",
"上面便完成了对于一张图像的文本检测与识别过程。\n",
"\n",
"PP-OCR的系统框图如下所示。\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ai-studio-static-online.cdn.bcebos.com/9665bc66a29346d9bf04ef2f664cf488da855a37da2249a3975b809e5e9038c4\" align=\"middle\" width = \"1200\"/>\n",
"<p align=\"center\">\n",
"<center>PP-OCR系统框图</center>\n",
"\n",
"文本检测基于后处理方案比较简单的DBNet文字区域校正主要使用几何变换以及方向分类器文本识别使用了基于融合了卷积特征与序列特征的CRNN模型使用CTC loss解决预测结果与标签不一致的问题。\n",
"\n",
"PP-OCR从骨干网络、学习率策略、数据增广、模型裁剪量化等方面共使用了19个策略对模型进行优化瘦身最终打造了面向服务器端的PP-OCR server系统以及面向移动端的PP-OCR mobile系统。\n",
"\n",
"## 1.2 PP-OCRv2系统与优化策略简介\n",
"\n",
"相比于PP-OCR PP-OCRv2 在骨干网络、数据增广、损失函数这三个方面进行进一步优化,解决端侧预测效率较差、背景复杂以及相似字符的误识等问题,同时引入了知识蒸馏训练策略,进一步提升模型精度。具体地:\n",
"\n",
"* 检测模型优化: (1) 采用 CML 协同互学习知识蒸馏策略;(2) CopyPaste 数据增广策略;\n",
"* 识别模型优化: (1) PP-LCNet 轻量级骨干网络;(2) U-DML 改进知识蒸馏策略; (3) Enhanced CTC loss 损失函数改进。\n",
"\n",
"从效果上看,主要有三个方面提升:\n",
"\n",
"* 在模型效果上,相对于 PP-OCR mobile 版本提升超7%\n",
"* 在速度上,相对于 PP-OCR server 版本提升超过220%\n",
"* 在模型大小上11.6M 的总大小,服务器端和移动端都可以轻松部署。\n",
"\n",
"PP-OCRv2 模型与之前 PP-OCR 系列模型的精度、预测耗时、模型大小对比图如下所示。\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ai-studio-static-online.cdn.bcebos.com/18952983582a4a86853aab4c74e3c429dc0c5ab99f4c48eeb1b28cb0a174aa23\" align=\"middle\" width = \"800\"/>\n",
"<p align=\"center\">\n",
"<center>PP-OCRv2与PP-OCR的速度、精度、模型大小对比</center>\n",
"\n",
"PP-OCRv2的系统框图如下所示。\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ai-studio-static-online.cdn.bcebos.com/ccd0d94112bc4adbb42645c71a9a91abeefe4015781646b88da9f223b2d3401f\" align=\"middle\" width = \"1200\"/>\n",
"<p align=\"center\">\n",
"<center>PP-OCRv2系统框图</center>\n",
" \n",
"\n",
"本章将对上述PP-OCR以及PP-OCRv2系统优化策略进行详细的解读。"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 2. PP-OCR 优化策略\n",
"\n",
"PP-OCR系统包括文本检测器、方向分类器以及文本识别器。本节针对这三个方向的模型优化策略进行详细介绍。\n",
"\n",
"## 2.1 文本检测\n",
"\n",
"PP-OCR中的文本检测基于DBNet (Differentiable Binarization)模型它基于分割方案后处理简单。DBNet的具体模型结构如下图。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/33116a2a1f5a4c19b22f366d3f9a360083ff2d8fe7e343fea573eb22a0801143\" width = \"1200\" />\n",
"</div>\n",
"<center>DBNet框图</center>\n",
"\n",
"DBNet通过骨干网络(backbone)提取特征使用DBFPN的结构(neck)对各阶段的特征进行融合,得到融合后的特征。融合后的特征经过卷积等操作(head)进行解码,生成概率图和阈值图,二者融合后计算得到一个近似的二值图。计算损失函数时,对这三个特征图均计算损失函数,这里把二值化的监督也也加入训练过程,从而让模型学习到更准确的边界。\n",
"\n",
"DBNet中使用了6种优化策略用于提升模型精度与速度包括骨干网络、特征金字塔网络、头部结构、学习率策略、模型裁剪等策略。在验证集上不同模块的消融实验结论如下所示。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/a78fafbe853e410d9febb5ff412970081609f15a4059493f88c3aa6dc8278d25\" width = \"1000\" />\n",
"</div>\n",
"<center>DBNet消融实验</center>\n",
"\n",
"\n",
"下面进行详细说明。\n",
"\n",
"### 2.1.1 轻量级骨干网络\n",
"\n",
"骨干网络的大小对文本检测器的模型大小有重要影响。因此在构建超轻量检测模型时应选择轻量的骨干网络。随着图像分类技术的发展MobileNetV1、MobileNetV2、MobileNetV3和ShuffleNetV2系列常用作轻量骨干网络。每个系列都有不同的模型大小和性能表现。[PaddeClas](https://github.com/PaddlePaddle/PaddleClas)提供了20多种轻量级骨干网络。他们在ARM上的`精度-速度`曲线如下图所示。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/d3855eac989542d49e5dd69e2f09de284ec02fd2c3314f8b9db7491630e0cd14\" width = \"800\" />\n",
"</div>\n",
"<center>PaddleClas中骨干网络的\"速度-精度\"曲线</center>\n",
"\n",
"在预测时间相同的情况下MobileNetV3系列可以实现更高的精度。作者在设计的时候为了覆盖尽可能多的场景使用scale这个参数来调整特征图通道数标准为1x如果是0.5x则表示该网络中部分特征图通道数为1x对应网络的0.5倍。为了进一步平衡准确率和效率在V3的尺寸选择上我们采用了MobileNetV3_large 0.5x的结构。\n",
"\n",
"下面打印出DBNet中MobileNetV3各个阶段的特征图尺寸。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fatal: destination path 'PaddleOCR' already exists and is not an empty directory.\n",
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Requirement already satisfied: pip in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (21.3.1)\n",
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Requirement already satisfied: shapely in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.8.0)\n",
"Requirement already satisfied: scikit-image in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (0.19.1)\n",
"Requirement already satisfied: imgaug==0.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (0.4.0)\n",
"Requirement already satisfied: pyclipper in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (1.3.0.post2)\n",
"Requirement already satisfied: lmdb in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (1.2.1)\n",
"Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (4.27.0)\n",
"Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (1.20.3)\n",
"Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (2.2.0)\n",
"Requirement already satisfied: python-Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 9)) (0.12.2)\n",
"Requirement already satisfied: opencv-contrib-python==4.4.0.46 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 10)) (4.4.0.46)\n",
"Requirement already satisfied: cython in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 11)) (0.29)\n",
"Requirement already satisfied: lxml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 12)) (4.7.1)\n",
"Requirement already satisfied: premailer in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 13)) (3.10.0)\n",
"Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 14)) (3.0.5)\n",
"Requirement already satisfied: fasttext==0.9.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 15)) (0.9.1)\n",
"Requirement already satisfied: imageio in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (2.6.1)\n",
"Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.6.3)\n",
"Requirement already satisfied: Pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (7.1.2)\n",
"Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (4.1.1.26)\n",
"Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.15.0)\n",
"Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (2.2.3)\n",
"Requirement already satisfied: pybind11>=2.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from fasttext==0.9.1->-r requirements.txt (line 15)) (2.8.1)\n",
"Requirement already satisfied: setuptools>=0.7.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from fasttext==0.9.1->-r requirements.txt (line 15)) (56.2.0)\n",
"Requirement already satisfied: packaging>=20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 2)) (20.9)\n",
"Requirement already satisfied: networkx>=2.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 2)) (2.4)\n",
"Requirement already satisfied: tifffile>=2019.7.26 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 2)) (2021.11.2)\n",
"Requirement already satisfied: PyWavelets>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image->-r requirements.txt (line 2)) (1.2.0)\n",
"Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.14.0)\n",
"Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.8.2)\n",
"Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.0.0)\n",
"Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.5)\n",
"Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.21.0)\n",
"Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.7.1.1)\n",
"Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (2.22.0)\n",
"Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.1)\n",
"Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.8.53)\n",
"Requirement already satisfied: cssutils in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 13)) (2.3.0)\n",
"Requirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 13)) (4.0.0)\n",
"Requirement already satisfied: cssselect in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 13)) (1.1.0)\n",
"Requirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 14)) (1.0.1)\n",
"Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 14)) (1.4.1)\n",
"Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.6.0)\n",
"Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.23)\n",
"Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.6.1)\n",
"Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.2.0)\n",
"Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (0.16.0)\n",
"Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (7.0)\n",
"Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (2.11.0)\n",
"Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.0)\n",
"Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2.8.0)\n",
"Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2019.3)\n",
"Requirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.2->scikit-image->-r requirements.txt (line 2)) (4.4.2)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from packaging>=20.0->scikit-image->-r requirements.txt (line 2)) (2.4.2)\n",
"Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (0.18.0)\n",
"Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (3.9.9)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->-r requirements.txt (line 3)) (1.1.0)\n",
"Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->-r requirements.txt (line 3)) (0.10.0)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->imgaug==0.4.0->-r requirements.txt (line 3)) (2.8.0)\n",
"Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.0)\n",
"Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (2.0.1)\n",
"Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (5.1.2)\n",
"Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (0.10.0)\n",
"Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.4)\n",
"Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (16.7.9)\n",
"Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.4.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (1.25.6)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2.8)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2019.9.11)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.1)\n",
"Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (3.6.0)\n"
]
}
],
"source": [
"import os\n",
"import sys\n",
"\n",
"# 下载代码\n",
"os.chdir(\"/home/aistudio/\")\n",
"!git clone https://gitee.com/paddlepaddle/PaddleOCR.git\n",
"# 切换工作目录\n",
"os.chdir(\"/home/aistudio/PaddleOCR/\")\n",
"!pip install -U pip\n",
"!pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"the shape of 0 stage: [1, 16, 160, 160]\n",
"the shape of 1 stage: [1, 24, 80, 80]\n",
"the shape of 2 stage: [1, 56, 40, 40]\n",
"the shape of 3 stage: [1, 480, 20, 20]\n"
]
}
],
"source": [
"# 具体代码实现位于:\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/det_mobilenet_v3.py\n",
"import numpy as np\n",
"import paddle\n",
"\n",
"# 设置随机输入\n",
"inputs = np.random.rand(1, 3, 640, 640).astype(np.float32)\n",
"x = paddle.to_tensor(inputs)\n",
"\n",
"# 导入MobileNetV3库\n",
"from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3\n",
"\n",
"# 模型定义\n",
"backbone_mv3 = MobileNetV3(scale=0.5, model_name='large')\n",
"\n",
"# 模型forward\n",
"bk_out = backbone_mv3(x)\n",
"\n",
"# 模型中间层打印\n",
"for i, stage_out in enumerate(bk_out):\n",
" print(\"the shape of \",i,'stage: ',stage_out.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 2.1.2 轻量级特征金字塔网络DBFPN结构\n",
"\n",
"文本检测器的特征融合(neck)部分DBFPN与目标检测任务中的FPN结构类似融合不同尺度的特征图以提升不同尺度的文本区域检测效果。\n",
"\n",
"为了方便合并不同通道的特征图,这里使用`1×1`的卷积将特征图减少到相同数量的通道。\n",
"\n",
"概率图和阈值图是由卷积融合的特征图生成的卷积也与inner_channels相关联。因此inner_channels对模型尺寸有很大的影响。当inner_channels由256减小到96时模型尺寸由7M减小到4.1M速度提升48%,但精度只是略有下降。\n",
"\n",
"下面打印DBFPN的结构以及对于骨干网络特征图的融合结果。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DBFPN(\n",
" (in2_conv): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in3_conv): Conv2D(24, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in4_conv): Conv2D(56, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in5_conv): Conv2D(480, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (p5_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p4_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p3_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p2_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
")\n",
"the shape of output of DBFPN: [1, 96, 160, 160]\n"
]
}
],
"source": [
"# 具体代码实现位于:\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/necks/db_fpn.py\n",
"from ppocr.modeling.necks.db_fpn import DBFPN\n",
"\n",
"neck_bdfpn = DBFPN(in_channels=[16, 24, 56, 480], out_channels=96)\n",
"# 打印 DBFPN结构\n",
"print(neck_bdfpn)\n",
"\n",
"# 先对原始的通道数降到96再降到24最后4个feature map进行concat\n",
"fpn_out = neck_bdfpn(bk_out)\n",
"\n",
"print('the shape of output of DBFPN: ', fpn_out.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 2.1.3 骨干网络中SE模块分析\n",
"\n",
"SE是`squeeze-and-excitation`的缩写(Hu, Shen, and Sun 2018)。如图所示\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/9685731c2d65435b9fa6ede7e2dec5da77262d1f15704280a3207207e3946c7f\" width = \"1200\" />\n",
"</div>\n",
"<center>SE模块示意图</center>\n",
"\n",
"SE块显式地建模通道之间的相互依赖关系并自适应地重新校准通道特征响应。在网络中使用SE块可以明显提高视觉任务的准确性因此MobileNetV3的搜索空间包含了SE模块最终MobileNetV3中也包含很多个SE模块。然而当输入分辨率较大时例如`640×640`使用SE模块较难估计通道的特征响应精度提高有限但SE模块的时间成本非常高。在DBNet中**我们将SE模块从骨干网络中移除**,模型大小从`4.1M`降到`2.6M`,但精度没有影响。\n",
"\n",
"PaddleOCR中可以通过设置`disable_se=True`来移除骨干网络中的SE模块使用方法如下所示。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"the shape of 0 stage: [1, 16, 160, 160]\n",
"the shape of 1 stage: [1, 24, 80, 80]\n",
"the shape of 2 stage: [1, 56, 40, 40]\n",
"the shape of 3 stage: [1, 480, 20, 20]\n"
]
}
],
"source": [
"# 具体代码实现位于:\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/det_mobilenet_v3.py\n",
"\n",
"x = paddle.rand([1, 3, 640, 640])\n",
"\n",
"from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3\n",
"\n",
"# 定义模型\n",
"backbone_mv3 = MobileNetV3(scale=0.5, model_name='large', disable_se=True)\n",
"\n",
"# 模型forward\n",
"bk_out = backbone_mv3(x)\n",
"# 输出\n",
"for i, stage_out in enumerate(bk_out):\n",
" print(\"the shape of \",i,'stage: ',stage_out.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 2.1.4 学习率策略优化\n",
"\n",
"* Cosine 学习率下降策略\n",
"\n",
"梯度下降算法需要我们设置一个值用来控制权重更新幅度我们将其称之为学习率。它是控制模型学习速度的超参数。学习率越小loss的变化越慢。虽然使用较低的学习速率可以确保不会错过任何局部极小值但这也意味着模型收敛速度较慢。\n",
"\n",
"因此,在训练前期,权重处于随机初始化状态,我们可以设置一个相对较大的学习速率以加快收敛速度。在训练后期,权重接近最优值,使用相对较小的学习率可以防止模型在收敛的过程中发生震荡。\n",
"\n",
"Cosine学习率策略也就应运而生Cosine学习率策略指的是学习率在训练的过程中按照余弦的曲线变化。在整个训练过程中Cosine学习率衰减策略使得在网络在训练初期保持了较大的学习速率在后期学习率会逐渐衰减至0其收敛速度相对较慢但最终收敛精度较好。下图比较了两种不同的学习率衰减策略`piecewise decay`和`cosine decay`。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/e9f48fda04c14d5787c0cea52f0e75d85b8f50aee70a4ee3adfc17da5158a72a\" width = \"800\" />\n",
"</div>\n",
"<center>Cosine与Piecewise学习率下降策略</center>\n",
"\n",
"* 学习率预热策略\n",
"\n",
"模型刚开始训练时,模型权重是随机初始化的,此时若选择一个较大的学习率,可能造成模型训练不稳定的问题,因此**学习率预热**的概念被提出,用于解决模型训练初期不收敛的问题。\n",
"\n",
"学习率预热指的是将学习率从一个很小的值开始逐步增加到初始较大的学习率。它可以保证模型在训练初期的稳定性。使用学习率预热策略有助于提高图像分类任务的准确性。在DBNet中实验表明该策略也是有效的。学习率预热策略与Cosine学习率结合时学习率的变化趋势如下代码演示。\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4IAAAGDCAYAAAB+yq7tAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzs3XeY1NX5v/H70AWJImADBERBQbAhVgSsEKNogt+o0WjsiSVGjRJNbLHFaDSxJDGW2BI1Go297lAERUClKQhiAYwGETtIO78/zvBjQwAX2NnPlPt1XXPN7sxndp9dUPa955znCTFGJEmSJEmVo17WBUiSJEmS6pZBUJIkSZIqjEFQkiRJkiqMQVCSJEmSKoxBUJIkSZIqjEFQkiRJkiqMQVCStEIhhEkhhL4F+Lh/DSFcWtsfd7nPcVEI4e5a+lh9Qwgza/va2hRCeDKEcPQqni/491ySVFoMgpJUBkIIR4QQxoQQvggh/DsfDPZYm48ZY+wWYxxSSyWqgGKMA2KMdwCEEI4JIbywph8rhPB0COHcau+3CSHElTy28dpVLknKikFQkkpcCOFM4DrgcmAjYDPgJmBglnWtiRBCg6xryEoRfe3DgD2rvb8nMHkFj02NMX6wOh84JP7sIUlFwP8ZS1IJCyGsB1wCnBJj/GeM8csY48IY46Mxxp/nr2kcQrguhPB+/nZdCKFx/rlWIYTHQgifhBA+DiEMX/qDegjhnRDCPvm3Lwoh3B9CuDOE8Hl+22jPanVsGkJ4MIQwO4Twdgjh9BrW3zeEMDOEcG4I4QPg9hVcs0UIYWgI4dMQwkchhPuqPdcthPBsvvYPQwjnVXtpozWpN4SwTn4r5dwQwuvATsvVE0MIW1R7f6XbLr/h81wUQngghHB3COEz4JjlXtsx/+ey9M/jLyGE/1R7/q4Qwhn5t4eEEI4PIWwN/AnYNb86/Em1D9kihPB4/vsxKoTQaUU1k4Lg7tUCW2/SLxp6LvfYsPznbpH/OzQ7/z17LITQtlqdQ0IIl4UQRgBfAZvnH7s0hDAyX+ejIYSWIYR7QgifhRBGhxA65F/fIf89b7Dcxzw+//YxIYQRIYQb8n9HJocQ9l7J1yZJyjMISlJp2xVoAjy0imvOB3YBtgO2BXoBv8w/dxYwE2hNWk08D4gr+TgHAfcC6wOPADcA5MPBo8A4oA2wN3BGCGH/Gn4NGwMbAO2BE1fw/K+BZ4AWQFvg+vznbQ48BzwFbApsATxfC/VeCHTK3/YHVnr2blVq+H0ZCDyQr/Ge6q+PMb4NfAZsn39oT+CLfNgD6AMMXe41bwAnAy/GGNeNMa5f7enDgItJ38dpwGUrKf1loDHp78rSz/ts/jXVHxuWf7seKcC3J61GzyP/va7mKNKfbXPg3Wr1HEX63nQCXsx/nA2AN0h/DjW1M/AW0Cr/un+GEDZYjddLUsUxCEpSaWsJfBRjXLSKa34AXBJj/E+McTYpDByVf24hsAnQPr+SODzGuLIg+EKM8YkY42LgLpaFgp2A1jHGS2KMC2KM04G/kH7Qr4klwIUxxq9jjPNW8PxCUsjYNMY4P8a49Pzbd4APYozX5B//PMY4qhbq/T/gshjjxzHGGcAfavh1LK8m35cXY4wPxxiXrORrHwr0CcvO4j2Qf78j8C1SyKyph2KML+f/rtxD+sXA/4gxfg2MAvbMh6n18rUPr/ZY13xtxBjnxBgfjDF+FWP8nBQw+yz3Yf8aY5wUY1wUY1yYf+z2GONbMcZPgSeBt2KMz+Xr+wfLAnBN/Ae4Lv93+D5gCnDAarxekipOsZxHkCStmTlAqxBCg1WEwU1ZtgpD/u1N82//FrgIeCaEAHBzjPHKlXyc6ufBvgKa5LfrtQc2XW4bYn1ScKiJ2THG+at4/hzSquDLIYS5wDUxxtuAdqRVoJVZ03o3BWZUe67692511OT7MoNVG0pa2ZxJWoEbQgrx84HhMcYlq1HP8t+PdVdx7dJzgu8AI/KPvQD8KP/YjBjjuwAhhKbAtUB/0mojQPMQQv18CIcVf50fVnt73greX1V9y5u13C8wqv8dlyStgCuCklTaXgS+Bg5exTXvk0LJUpvlHyO/inZWjHFzUuA4cw3OV80A3o4xrl/t1jzG+O0avn5lK5Dka/wgxnhCjHFT4CTgpvwZvRnA5qtZa03q/TcpZC612XKv/wpoWu39lXXOrMn3ZZVfOykI9gb65t9+AdidFWwLXY2PWRPD8p93T5YF1xH5z119Wyik7cVdgJ1jjN9iWVOZUEs1fZm/X9X3vE3I/yYj7///HZckrZhBUJJKWH5b3QXAjSGEg0MITUMIDUMIA0IIV+Uv+zvwyxBC6xBCq/z1dwOEEL6Tb8YSgE+BxaStmqvjZeDzfMOXdUII9UMI24QQdvrGV9ZACOHQas1H5pJCxRLgMWCTEMIZITXEaR5C2LkW6r0f+EW+CUpb4LTlXv8acET+df35322QNf083yjGOJW0OnYkMDTG+Blp5ex7rDwIfgi0DSE0qunnWYEXSecWjyQfBGOMc4HZ+ceqB8Hm+Ro/yW8bXZ2zfd8ov515FnBk/nt4LOlMYXUbAqfn/+4fCmwNPFGbdUhSuTEISlKJizFeA5xJagAzm7QSdSrwcP6SS4ExwHhgAvBK/jGALUkNV74g/fB/U4wxt5qffzHpvN52wNvAR8AtwHpr/EX9t52AUSGEL0hNX34aY5yeP4+2L3AgadvjVKBfLdR7MWlr4dukJjV3Lfchfpr/nJ+Qzl8+zArU4vdlKDAnf15x6fuB9Oe4IlXAJOCDEMJHq/m5AIgxfgmMBRoBE6s9NZwUuqoHweuAdUhf30uk5j217QTg56St0N2Akcs9P4r0d/kj0hnFQTHGOQWoQ5LKRlh5TwBJkqTiFkI4Bjg+xrhH1rVIUilxRVCSJEmSKoxBUJIkSZIqjFtDJUmSJKnCuCIoSZIkSRXGIChJkiRJFaZB1gXUllatWsUOHTpkXYYkSZIkZWLs2LEfxRhb1+TasgmCHTp0YMyYMVmXIUmSJEmZCCG8W9Nr3RoqSZIkSRXGIChJkiRJFcYgKEmSJEkVpmzOCEqSJEnS2li4cCEzZ85k/vz5WZeySk2aNKFt27Y0bNhwjT+GQVCSJEmSgJkzZ9K8eXM6dOhACCHrclYoxsicOXOYOXMmHTt2XOOP49ZQSZIkSQLmz59Py5YtizYEAoQQaNmy5VqvWhoEJUmSJCmvmEPgUrVRo0FQkiRJkorEuuuuWyefxyAoSZIkSUVs0aJFtf4xCxoEQwj9QwhTQgjTQgiDV/D8niGEV0IIi0IIg5Z77ugQwtT87ehC1ilJkiRJxWTIkCH07t2bgw46iK5du9b6xy9Y19AQQn3gRmBfYCYwOoTwSIzx9WqXvQccA5y93Gs3AC4EegIRGJt/7dxC1StJkiRJ/98ZZ8Brr9Xux9xuO7juuhpf/sorrzBx4sS16g66MoUcH9ELmBZjnA4QQrgXGAj8/yAYY3wn/9yS5V67P/BsjPHj/PPPA
v2BvxewXmXt44/h5ZehXj1o0ADq11/xrUEDaNIE1lkn3Zo2hcaNoQQO9kqSJEk11atXr4KEQChsEGwDzKj2/kxg57V4bZvlLwohnAicCLDZZputWZUqHqecAvfeu2avDWFZOGzaNN03awbrrQfrr7/i+6Vvt2oFG26Y7hs1qt2vSZIkSaVpNVbuCqVZs2YF+9glPVA+xngzcDNAz549Y8blaG0sWQLPPgsHHQTnnAOLFsHixSu+LVoE8+fDvHnw1Vfpfumt+vtffAGffgrTp6f7Tz+Fzz6DuIq/Kuuvn0Jh69b/fb/RRtC2LbRpk+433DCtTkqSJEklqJBBcBbQrtr7bfOP1fS1fZd77ZBaqUrFaeJEmDMHvvtd2H33wn2eJUvg889TKPzkk3T76COYPRv+8590W/r21KkwYkR6fslyu5fr14dNN/3vcNi2LbRvD5tvnm7rr1+4r0OSJElaC4UMgqOBLUMIHUnB7jDgiBq+9mng8hBCi/z7+wG/qP0SVTRyuXTfr19hP0+9esu2hdZ
"text/plain": [
"<Figure size 1080x432 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 具体代码实现位于\r\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/optimizer/__init__.py\r\n",
"# 导入学习率优化器构建的函数\r\n",
"from ppocr.optimizer import build_lr_scheduler\r\n",
"import numpy as np\r\n",
"import matplotlib.pyplot as plt\r\n",
"%matplotlib inline\r\n",
"# 咱们也可以看看warmup_epoch为2时的效果\r\n",
"lr_config = {'name': 'Cosine', 'learning_rate': 0.1, 'warmup_epoch': 2}\r\n",
"epochs = 20 # config['Global']['epoch_num']\r\n",
"iters_epoch = 100 # len(train_dataloader)\r\n",
"lr_scheduler=build_lr_scheduler(lr_config, epochs, iters_epoch)\r\n",
"\r\n",
"iters = 0\r\n",
"lr = []\r\n",
"for epoch in range(epochs):\r\n",
" for _ in range(iters_epoch):\r\n",
" lr_scheduler.step() # 对应 https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/tools/program.py#L262\r\n",
" iters += 1\r\n",
" lr.append(lr_scheduler.get_lr())\r\n",
"\r\n",
"x = np.arange(iters,dtype=np.int64)\r\n",
"y = np.array(lr,dtype=np.float64)\r\n",
"\r\n",
"plt.figure(figsize=(15, 6))\r\n",
"plt.plot(x,y,color='red',label='lr')\r\n",
"\r\n",
"plt.title(u'Cosine lr scheduler with Warmup')\r\n",
"plt.xlabel(u'iters')\r\n",
"plt.ylabel(u'lr')\r\n",
"\r\n",
"plt.legend()\r\n",
"\r\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 2.1.5 模型裁剪策略-FPGM\n",
"\n",
"深度学习模型中一般有比较多的参数冗余,我们可以使用一些方法,去除模型中比较冗余的地方,从而提升模型推理效率。\n",
"\n",
"模型裁剪指的是通过去除网络中冗余的通道channel、滤波器filter、神经元neuron来得到一个更轻量的网络同时尽可能保证模型精度。\n",
"\n",
"相比于裁剪通道或者特征图的方法,裁剪滤波器的方法可以得到更加规则的模型,因此减少内存消耗,加速模型推理过程。\n",
"\n",
"之前的裁剪滤波器的方法大多基于范数进行裁剪认为范数较小的滤波器重要程度较小但是这种方法要求存在的滤波器的最小范数应该趋近于0否则我们难以去除。\n",
"\n",
"针对上面的问题,基于**几何中心点的裁剪算法**(Filter Pruning via Geometric Median, FPGM)被提出。FPGM将卷积层中的每个滤波器都作为欧几里德空间中的一个点它引入了几何中位数这样一个概念即**与所有采样点距离之和最小的点**。如果一个滤波器的接近这个几何中位数,那我们可以认为这个滤波器的信息和其他滤波器重合,可以去掉。\n",
"\n",
"FPGM与基于范数的裁剪算法的对比如下图所示。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/13f84e7c4ce84d39b3e06adec20efed0e600dbfff1dc4db9b6498ba33d73925d\" width = \"800\" />\n",
"</div>\n",
"<center>FPGM裁剪示意图</center>\n",
"\n",
"\n",
"在PP-OCR中我们使用FPGM对检测模型进行剪枝最终DBNet的模型精度只有轻微下降但是模型大小减小**46%**,预测速度加速**19%**。\n",
"\n",
"关于FPGM模型裁剪实现的更多细节可以参考[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/release/2.0.0/docs/zh_cn/api_cn/dygraph/pruners/fpgm_filter_pruner.rst#fpgmfilterpruner)。\n",
"\n",
"\n",
"**注意:**\n",
"\n",
"1. 模型裁剪需要重新训练模型,可以参考[PaddleOCR剪枝教程](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/deploy/slim/prune/README.md)。\n",
"2. 裁剪代码是根据DBNet进行适配如果您需要对自己的模型进行剪枝需要重新分析模型结构、参数的敏感度我们通常情况下只建议裁剪相对敏感度低的参数而跳过敏感度高的参数。\n",
"3. 每个卷积层的剪枝率对于裁剪后模型的性能也很重要,用完全相同的裁剪率去进行模型裁剪通常会导致显着的性能下降。\n",
"4. 模型裁剪不是一蹴而就的,需要进行反复的实验,才能得到符合要求的模型。\n",
"\n",
"### 2.1.6 文本检测配置说明\n",
"\n",
"下面给出DBNet的训练配置简要说明完整的配置文件可以参考[ch_det_mv3_db_v2.0.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)。\n",
"\n",
"```yaml\n",
"Architecture: # 模型结构定义\n",
" model_type: det\n",
" algorithm: DB\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV3 # 配置骨干网络\n",
" scale: 0.5\n",
" model_name: large\n",
" disable_se: True # 去除SE模块\n",
" Neck:\n",
" name: DBFPN # 配置DBFPN\n",
" out_channels: 96 # 配置 inner_channels\n",
" Head:\n",
" name: DBHead\n",
" k: 50\n",
"\n",
"Optimizer:\n",
" name: Adam\n",
" beta1: 0.9\n",
" beta2: 0.999\n",
" lr:\n",
" name: Cosine # 配置cosine学习率下降策略\n",
" learning_rate: 0.001 # 初始学习率\n",
" warmup_epoch: 2 # 配置学习率预热策略\n",
" regularizer:\n",
" name: 'L2' # 配置L2正则\n",
" factor: 0 # 正则项的权重\n",
"```\n",
"\n",
"### 2.1.7 PP-OCR 检测优化总结\n",
"\n",
"上面给大家介绍了PP-OCR中文字检测算法的优化策略这里再给大家回顾一下不同优化策略对应的消融实验与结论。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/a78fafbe853e410d9febb5ff412970081609f15a4059493f88c3aa6dc8278d25\" width = \"1000\" />\n",
"</div>\n",
"<center>DBNet消融实验</center>\n",
"\n",
"通过轻量级骨干网络、轻量级neck结构、SE模块的分析和去除、学习率调整及优化、模型裁剪等策略DBNet的模型大小从**7M**减少至**1.5M**。通过学习率策略优化等训练策略优化DBNet的模型精度提升超过**1%**。\n",
"\n",
"PP-OCR中超轻量DBNet检测效果如下所示\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/593321ac0ec34cd39ca4c3755de0bc6ee15b19d24ccc4fe7b37c194b7e825187\" width = \"800\" />\n",
"</div>\n",
"\n",
"下面展示快速使用文字检测模型的预测效果。具体的预测推理代码,我们在第五章会进行详细说明。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: cannot create directory inference: File exists\n",
"--2021-12-24 21:07:17-- https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar\n",
"Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.229, 182.61.200.195, 2409:8c04:1001:1002:0:ff:b001:368a\n",
"Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.229|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 3190272 (3.0M) [application/x-tar]\n",
"Saving to: ch_PP-OCRv2_det_infer.tar\n",
"\n",
"ch_PP-OCRv2_det_inf 100%[===================>] 3.04M 4.13MB/s in 0.7s \n",
"\n",
"2021-12-24 21:07:18 (4.13 MB/s) - ch_PP-OCRv2_det_infer.tar saved [3190272/3190272]\n",
"\n",
"[2021/12/24 21:07:22] root INFO: 00111002.jpg\t[[[78, 641], [408, 638], [408, 659], [78, 662]], [[76, 614], [214, 614], [214, 635], [76, 635]], [[103, 554], [150, 554], [150, 576], [103, 576]], [[74, 531], [349, 531], [349, 551], [74, 551]], [[75, 503], [310, 499], [311, 523], [75, 527]], [[162, 462], [320, 462], [320, 495], [162, 495]], [[326, 432], [415, 432], [415, 453], [326, 453]], [[306, 409], [429, 407], [430, 428], [306, 430]], [[74, 411], [212, 406], [213, 426], [75, 431]], [[74, 384], [219, 382], [219, 403], [74, 405]], [[309, 381], [429, 381], [429, 402], [309, 402]], [[74, 362], [201, 359], [201, 380], [75, 383]], [[304, 358], [426, 358], [426, 378], [304, 378]], [[70, 336], [242, 332], [242, 356], [71, 359]], [[72, 312], [206, 307], [206, 328], [73, 333]], [[304, 308], [419, 308], [419, 329], [304, 329]], [[114, 271], [249, 271], [249, 302], [114, 302]], [[363, 270], [383, 270], [383, 297], [363, 297]], [[68, 248], [246, 246], [246, 269], [69, 271]], [[65, 218], [188, 218], [188, 242], [65, 242]], [[337, 215], [384, 215], [384, 241], [337, 241]], [[67, 196], [248, 196], [248, 216], [67, 216]], [[296, 196], [424, 190], [425, 211], [296, 217]], [[65, 167], [245, 167], [245, 188], [65, 188]], [[67, 138], [290, 138], [290, 159], [67, 159]], [[68, 112], [411, 112], [411, 132], [68, 132]], [[278, 86], [417, 86], [417, 107], [278, 107]], [[167, 60], [412, 61], [412, 74], [167, 73]], [[165, 17], [412, 16], [412, 51], [165, 52]], [[7, 6], [61, 6], [61, 24], [7, 24]]]\n",
"\n",
"[2021/12/24 21:07:22] root INFO: The predict time of ./doc/imgs/00111002.jpg: 1.7913281917572021\n",
"[2021/12/24 21:07:22] root INFO: The visualized image saved in ./inference_results/det_res_00111002.jpg\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAckAAAJOCAYAAADcYvyMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvWmMZEt23/eLiLvlVpm1dFd3V7/uefPWmTePHA4lUqaHJi1KJkWYomELokxCtgTDtD7YMGwClgAKAm3IhmEIkgxDsCxv/CBBhA1IXGwZEihZBm1qMJx5HJIzb1+7q5faK7e7xuIPce/NzOqunjczb8yncf4b2ZmVeW9E3LgRcc75n3PiCucca6yxxhprrLHGo5C/3w1YY4011lhjjY8r1kJyjTXWWGONNS7BWkiuscYaa6yxxiVYC8k11lhjjTXWuARrIbnGGmusscYal2AtJNdYY4011ljjEqyF5BprfAdDCPE1IcQP/363Y401/nnFWkiusca3CCHE+0KITAgxFUKcCyF+Uwjx54QQX3d+CSF+WAix/xG14xeFEH95+Tvn3EvOuX/6UZS/xhr/f8RaSK6xxkeDn3DODYDbwH8J/Hngf/z9bdIaa6zxrWItJNdY4yOEc27snPtV4KeAf1sI8RkhRCyE+CtCiDtCiAMhxN8UQnSEED3g/wBuCCFm9euGEEIKIf6CEOIdIcSJEOJ/EUJsNXUIIT5fW6vnQoi7Qog/I4T4WeBngP+kLufX6mPfF0L8kfpzLIT460KI+/Xrrwsh4vq3HxZC7Ashfk4IcSiEeCCE+LP/X/ffGmt83LAWkmus8W2Ac+6LwD7wg3jL8nngs8CzwB7wl5xzc+CPAfedc/36dR/4D4B/Dfgh4AZwBvwNACHEbbxg/W+AK3WZX3HO/S3g7wD/VV3OTzymWT8P/KH6nO8Gvg/4i0u/XwOGdfv+HeBvCCE2P5oeWWONfz6xFpJrrPHtw31gC/hZ4D9yzp0656bAfwH8qSec9+eAn3fO7TvnCuAXgD8hhAiAnwZ+3Tn3d51zlXPuxDn3lQ/Znp8B/jPn3KFz7gj4T4E/vfR7Vf9eOef+ATADXvjwl7vGGt95CH6/G7DGGt/B2MPPsS7wZSFE870A1BPOuw38fSGEXfrOALvAU8A732R7bgAfLP39Qf1dgxPnnF76OwX632Rda6zxHYG1JbnGGt8GCCH+IF5I/jKQAS8550b1a+ica4TP4x7Dcxf4Y0vHj5xziXPuXv3bM5dU+/Ue6XMfL4Ab3Kq/W2ONNS7BWkiuscZHCCHEhhDiXwV+CfjbzrnfAf574K8JIa7Wx+wJIX60PuUA2BZCDJeK+ZvAf177HxFCXBFC/GT9298B/ogQ4k8KIQIhxLYQ4rNLZX3yCc37u8BfrMvbAf4S8Le/9ateY43vXKyF5BprfDT4NSHEFG/p/TzwV4EmOvTPA28DXxBCTIBfp/b1Oedexwuvd+to1RvAfw38KvCP6jK/AHx/ffwd4MeBnwNOga/gg3DAp5x8ui7nlx/Txr8MfAn4XeD3gFfq79ZYY41LINYPXV5jjTXWWGONx2NtSa6xxhprrLHGJVgLyTXWWGONNda4BN8WISmE+DEhxBtCiLeFEH/h21HHGmusscYaa3y78ZH7JIUQCngT+KP4HUd+C/g3nXOvfqQVrbHGGmussca3Gd+OzQS+D3jbOfcugBDil4CfBC4VkpvDgdu7toUVgAPhQOCwzuCURDiBAJxzCOFQSJyUWGN8YpgDpbxRLITEWosQ0Mj/9rMQCAQIcNbWSWUO58AneguU8jnexliElAghEc6Cc1hcW4b/J3FYkGCdwDmHNQYhBUKItkznHNZanHM+k23pN59f7tuEA1e3Sgjp21pfgLO2/ij8BeEekxUnFm+X6D6LfPZH0ShM4kkHXazriWgaIVbenlzqhyn3o8BlyuHXq3/5PFGPrdXvVm6D+0avx11+k56o0H6E/XbZ+BGAs6tfCbF06IX7ffHry+p63LHim1De6752SxVeHE/uYmPc0i8OcA7nXH3c4/vUObsyR1x9TvO5mevLn63VK/W1Y8b5eh2urdvW64ixDiFAIvx5zvrhJZZaZt1SKwV28XHl+8f3gWj/by9n6VpEU1HbvgvF1f19ca19pK6lLm/7zS1aIYXALo3ti8eItmC3qKf57zHNApBStvfxYJIeO+eu8E3i2yEk9/Bh8A32qcPXl1FvyPyzANevjPgf/tq/hw0TuqpLJBVCalygmZeO0oIMA0KpsEWFEII4TLDW0uv1yPMcIQTz+ZwgCIjjGOccxhiUUnS7XQByY9qBq3WJUoogCFBKkeUlQiisdVjjF7eqMnQHI5TViKpgVmoKZ3AClJP04x5GVQCELgQRciZSeqW/SUophBBobamqCqMdSZIgZYAQqhWkUii01OBk/Z1EyRiQ4PwxXkYuTU7h2kFa9+fK+2VQ8vLfvXIhVsoQQmDMkoC+UN+T4L6Bdi3w+CH54c//cBDCfFP1uAuTWUqJMeaRc5vPzoaPLUdKuVLe8ru8ZC+ei8eu1PNNeE4uu9a2zKUFs2lv5QoAFMIrs82xrbKwaEezYAv3+L5uyl2+nvbzJffnSe3GqZU2N8f6+WNXy2+uz6wKNK01VVWhtb60HmPLlXmitcbUa4sxhizLqKqKqqooioI8z8mmE19nXV9VVWBsW681xo8j6yjLkjRNmUzO0XmBLSoCpRh0O+RVQVnlJIFCGE1sBYGK0MKRa4OUq3OuGaMA1tbX3Orei+uL6/5uhPryeStK/jKsatc45xZzwGGWjl0ds2E9uK217T2JooiiKBbHhGG97ph2DV9cgz+vqdcYg7UWKSVSyvaaojhAa421lr/yD7+0vMvUN4zft23p6g2Z/xbAZz592w22h+gqJA4jlCmQzqJtgAkSAiHoRwldEiaqohIVylrKMiPLKubznF6vR5L06ff7bec4V+EcGOM7s8h1KwiCIMKWlqLQgCYrKqwFrS1xlNDtdpnMzphnh8SBIooiUH26CESgIIzJKsNQKghC5lKR64pPlpp5r99cIyBRISRdsbRw1n3QKkyOQERIqepFQ6ArwDlvUQqJHyMS0WjY36SQfNLPjRXdTK5mcjQD9Buta8W++rDnXGJ5ffRC8hsvb1l4NHDOoZRaEWDLL3mZxFuQBMt/4pmHxwsIKUW9gNEuZODfzQULb7l9y21f/nxZFyysY7c4Rvh5o630locTtVWzKERZh6y7xgqwEoyETnX5DnzOeFvh4sJuLpeRl947Z0U9Xxbj11l/HxTqkfsGS4LDOZwTCOcQtcrh7OOtWWmDth3OOYQNUK4W9jJEBgKNpqIiciGRC1FlLRC1aQWwExaHQzuNchLpJEpJkk5CP+6z2QVdVmTzOZPxjKOzQ8IkJukkBFGErUqyNMfpFITAKkGEbOedc96StrpWcKoKKSVBEPg+XBqnldPtXG/WgaadQoh2jC8LTFULpea7pk/s0viVtVJ+USlcvo+NkrG496sCe/m8Rngvt
6U5Xim1IjyX6/1W8O0Qkvfw+0s2uFl/dymsc+hS0AlCohBc2CELerxrFPtXrtJFMXIhEsUMjVRgioKqqkiSBFhoI80AaLQ68J1njMHohfYSBMGCFpECIRQ4iZTeEg3DmCzNsVikAqEU1gUoZ1EiwMmYQgu2+wnnOkdHMdFkyg/sv89+Ei4mar1INjfZLFM1y1Zbs02nk4BEdIJaM/KCs6q8xdpoissD6Bux1p5gSLYDb3kgaq0p8nK17NZ6eTLs0hFi6f1J7by0zI+ahf1myhMSh1vQVQ31LUT73fLvDocTlwivJ90HHn/OxXpYGkeXnLFKPV4YJ5cKG7fU/saSFNIzwU1NrQUpvJsEkNa3XNYvJ/zLPuHZ0y1VtqQ0CAH2CZbxZZS8xfrxRcPELf41boqFpVqvEQ2TWb+sFBjh3+0lQnJlTjfXr5YWbqlwyhcoHAjrkMvrjZO+b5Tw5rYUWDxbZBtlR0DU7xJai4piCEKm2pA7R1EWKKMJpCAOFa4yCOsIhXhkAvn1rraiMU1v0Ey
"text/plain": [
"<Figure size 1008x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"!mkdir inference\r\n",
"!cd inference && wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar -O ch_PP-OCRv2_det_infer.tar && tar -xf ch_PP-OCRv2_det_infer.tar\r\n",
"!python tools/infer/predict_det.py --image_dir=\"./doc/imgs/00111002.jpg\" --det_model_dir=\"./inference/ch_PP-OCRv2_det_infer\" --use_gpu=False\r\n",
"from PIL import Image\r\n",
"img_det = Image.open('./inference_results/det_res_00111002.jpg')\r\n",
"\r\n",
"plt.figure(figsize=(14, 10)) # 图像窗口大小\r\n",
"plt.imshow(img_det)\r\n",
"plt.axis('on')\r\n",
"plt.title('Detection')\r\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 2.2 方向分类器\n",
"\n",
"方向分类器的任务是用于分类出文本检测出的文本实例的方向将文本旋转到0度之后再送入后续的文本识别器中。PP-OCR中我们考虑了**0**度和**180**度2个方向。下面详细介绍针对方向分类器的速度、精度优化策略。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/65b6e5d75f22403aba50665a1ab6dbba51cdb47eff0845d1909e5e54fd48e336\" width = \"1200\" />\n",
"</div>\n",
"<center>方向分类器消融实验</center>\n",
"\n",
"### 2.2.1 轻量级骨干网络\n",
"\n",
"与文本检测器相同我们仍然采用MobileNetV3作为方向分类器的骨干网络。因为方向分类的任务相对简单我们使用MobileNetV3 small 0.35x来平衡模型精度与预测效率。实验表明,即使当使用更大的骨干时,精度不会有进一步的提升。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/9177cfe7d294423e9f7e330d3c11ed4cafb46948305346a69a0edaee2b128890\" width = \"1000\" />\n",
"</div>\n",
"<center>不同骨干网络下的方向分类器精度对比</center>\n",
"\n",
"### 2.2.2 数据增强\n",
"\n",
"数据增强指的是对图像变换送入网络进行训练它可以提升网络的泛化性能。常用的数据增强包括旋转、透视失真变换、运动模糊变换和高斯噪声变换等PP-OCR中我们统称这些数据增强方法为BDA(Base Data Augmentation)。结果表明BDA可以明显提升方向分类器的精度。\n",
"\n",
"下面展示一些BDA数据增广方法的效果\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/b78ce4772e684804bb22f20bf806dab2128793fd30e04e74b660f24dc636d68a\" width = \"1000\" />\n",
"</div>\n",
"<center>BDA数据增广效果</center>\n",
"\n",
"\n",
"除了BDA外我们还加入了一些更高阶的数据增强操作来提高分类的效果例如 AutoAugment (Cubuk et al. 2019), RandAugment (Cubuk et al. 2020), CutOut (DeVries and Taylor 2017), RandErasing (Zhong et al. 2020), HideAndSeek (Singh and Lee 2017), GridMask (Chen 2020), Mixup (Zhang et al. 2017) 和 Cutmix (Yun et al. 2019)。\n",
"\n",
"这些数据增广大体分为3个类别\n",
"\n",
"1图像变换类AutoAugment、RandAugment\n",
"\n",
"2图像裁剪类CutOut、RandErasing、HideAndSeek、GridMask\n",
"\n",
"3图像混叠类Mixup、Cutmix\n",
"\n",
"下面给出不同高阶数据增广的可视化对比结果。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/3d71ffaa2c3b4b7ebf7a39256d8607f205ce9a9e22eb4908bce75066a4a0f0f2\" width = \"1200\" />\n",
"</div>\n",
"<center>高阶数据增广可视化效果</center>\n",
"\n",
"\n",
"但是实验表明除了RandAugment 和 RandErasing 外,大多数方法都不适用于方向分类器。下图也给出了在不同数据增强策略下,模型精度的变化。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/0f85faaa6dd6429882e33827384a2d4fac9f4d8c3f804a249917779c97790e7c\" width = \"1000\" />\n",
"</div>\n",
"\n",
"最终我们在训练时结合BDA和RandAugment作为方向分类器的数据增强策略。\n",
"\n",
"* RandAugment代码演示"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAABzCAYAAACIEflfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvWmwZdd1HvbtM9/xzWO/ntADZpDgAJEQGQ6WLKoslfRLjhRrSClhElupuEouK1M5Kg1VchLZspOUBg+JnXIsMbQcTbbjiCJFiRRJECDRALob3ei539Bvfnc+486Pvffa61xcEE0RalLQWVUo3L7v3HP22eNa31rrW0JKiUoqqaSSSv78i/PNbkAllVRSSSVvjVQbeiWVVFLJ20SqDb2SSiqp5G0i1YZeSSWVVPI2kWpDr6SSSip5m0i1oVdSSSWVvE2k2tArectECPFfCCHuCSF6Qoi5b3Z73koRQtwUQnzHN7sdlVTytaTa0Csh0ZvWUG/IB0KI3xNCHL/P3/oA/h6AvyylbEop9/6M2/rTQohUt/VQCPF5IcT7/yyfOaENPyaEkEKIv/ogn/tWiRDiw0KIu9/sdlTy1km1oVcyLt8rpWwCWAFwD8D/cp+/WwIQAXjl632gUPKnmYu/ods6D+DTAP7vP8U9vhH5UQD7AH7kAT+3kkomSrWhVzJRpJQjAJ8E8Jj5TggRCiH+ZyHEbQ2t/IoQoiaEOA/gVX3ZoRDiD/T1zwohnhNCHOn/P8vu9RkhxM8LIT4HYADgISHElBDinwghNoUQ60KInxNCuPfR1gzAvwBwTAixoO8/I4T4XSHEjrY2flcIsTb2/J8VQnxOCNEVQvx7IcQ8+/sPCyFuCSH2hBD/3fgzhRAnAXwIwMcBfJcQYpn97ceEEH88dr0UQpzVn+eEEL8jhOjofvk5fr2+9q8LIa7qtv2sEOKMtkI6QohPCCECdv33CCG+yiyVp9jfbgoh/pYQ4oIeh98QQkRCiAaAfwtgVVs5PSHE6pv1dSXf2lJt6JVMFCFEHcBfBfAF9vUvADgP4J0AzgI4BuDvSCmvAHhcXzMtpfyoEGIWwO8B+IcA5qDgmN8bw9Z/GGpDbAG4BeD/AJDpez8N4C8D+E/uo60BlJa8B+BAf+0A+N8BnARwAsAQwP869tMfAvAfA1gEEAD4W/p+jwH4Zd2+Vd3+tbHf/giAL0sp/xWASwD+ozdrJ5P/DUAfwDKUlv+jE675LgDvBvA+AH8bwK8B+GsAjgN4AsAP6rY+DeCfAvjPdDt/FcBvCyFCdq8fAPAxAKcBPAXgx6SUfQDfDWBDQ2RNKeXG1/EOlXwripSy+q/6D1JKALgJoAfgEEAKYAPAk/pvAmoTOsOufz+AG/rzKQASgKf//cMAvjR2/z+B2kwA4DMAfob9bQlADKDGvvtBAJ9+g7b+NIBEtzWH2sw//DXe7Z0ADti/PwPgv2f//usA/p3+/HcA/Dr7W0M/6zvYd1cB/E39+b8B8CL7248B+OOx50uog8rVffsw+9vP8ev1td/O/v08gJ9i//5FAL+kP/8ygJ8de9arAD7ExvSvsb/9jwB+RX/+MIC73+x5V/331v1XaeiVjMv3SymnofDwnwDwhxpOWABQB/C8Nu0PAfw7/f0kWYXSurncgtLqjdxhn08C8AFssvv/KpT2/EbyCd3WJQAvQ2m0AJSFIYT4VQ2bdAB8FsD0GISzxT4PADRZ26ltUmmz5OQVQnw7lLb76/qr/wvAk0KId36NthpZAOCh/O53Jlx3j30eTvi3aetJAD9p+kz323H9Dkbe6D0reZtJtaFXMlGklLmU8jehtN8PANiF2kgel1JO6/+mpHJKTpINqM2GywkA6/wx7PMdKA19nt2/LaV8HG8iUspdKOjmp4UQK/rrnwTwMIBvk1K2AfwH+nvxZvcDsAm1KaofKPiJQ0U/qu/zVSHEFoAvsu8BZcnU2e+X2W93oGAlDuHcVyTRG8gdAD/P+mxaSlmXUv7L+/htRbX6NpNqQ69koujIk+8DMAPgkpSyAPCPAPx9IcSivuaYEOK73uAW/wbAeSHEDwkhPB3a9xiA3510sZRyE8C/B/CLQoi2EMLRjsAP3U97pZSvAvh/ofBmQOHyQygn7SyA/+F+7qPlkwC+RwjxAY3P/wz0WhFCRFCY9MehYBzz338J4IeEEB6AFwE8LoR4p77+p1k7cwC/CXX41IUQj+Abi5L5RwD+cyHEt+kxawgh/ooQonUfv70HYE4IMfUNPL+SbyGpNvRKxuV3hBA9AB0APw/gR6WUJhTxpwC8BuALGsb4fSgt+HUiVRz690BpyntQG+33aG36jeRHoJyTF6Gcm5+ECp+8X/mfAHxcHzi/BKAGZVl8AQoeui/R7/s3oKCUTd0WE6/9/VAHxT+XUm6Z/6Ackx6Aj0nlJP4ZqP65CuCPxx7xEwCmoKCQ/xPAv4SyTr5ukVJ+GcB/CuXwPYAanx+7z99e1s++ruGaKsrlz7kIKSurq5JKvpkihPi7AJallJOiXSqp5L6l0tArqeQBixDiESHEUxoieQbAjwP419/sdlXy51+8b3YDKqnkL6C0oKCOVSgc+xcB/NY3tUWVvC3kG4JchBAfA/APoGJr/7GU8hfeqoZVUkkllVTy9cmfekPX8bxXAHwnlMPoOQA/KKW8+NY1r5JKKqmkkvuVbwRyeQbAa1LK6wAghPh1AN8HFaEwUcIwlK6r8jo8z0MQKDoKIQTSNAUAJEmCMIwAAK7rmOw2DIcjeJ5qrvkdACRxjLwoAAC+78P3fQBAlmVIkljfx0UQhPpZQByr74tC0j0dx0GWZ+r7PIfvq2fU6zU4jmpznudIkoQ+m3fJ8hyp/t5xHUS6/VmWIcsy+txoNKg9cRxTO8IwhK/fKcszxLG6lywKhFGk31RCFqov4iRGGKr3CfwAhVTvH8cx8jwHAPUsfVZneYZhf6C+bzbhOA71e6/XU33nefQsKQtkWa77fYB2u039bcaj3+9TX0dRhEKPwWg4ovY0m037/mmGWPdRrVaD75upJzAcDFDo+9brdWpfkiTUR1FUg9AR5EII+l5KoFaL9HgWpf4288TzPOqXJEnpPp7n0+csTeid1TyyS8M8K0lTuqfv+8j19aN4BKHD22u1iObFaDSi9/J9j/orjmN4nvrcbrXh6e+FEMiyDH09JkmSUL/meUb94vsePS/LM7rGcVyaz8JxkOh2Z3mGeq1uuht5bucL9ZHrQuoJk6b2nr7vwdXzvyjs/Pc8j55l+luNh4Rr1pQQ9P5pktD3nufRuwwHA5hB8H2f7pmlKUbxCFFk9gLXzovY9ovnedTfRVFQ+yAEAt2vgECm95dCShoHx3Fo38nzjMbE9VxaO3ESU+KC6itB15t39nyf2uAIgRHNTUl7AQRKc9Osd9exe5zZQ9SzfNp3sizF4WFnV0r5Rkl8JN/Ihn4M5Qy3uwC+bfwiIcTHoWJ2UavV8OijjwIA5ufnceyYShqUUmJ7exsAsLGxgdOnzwAApqZse
Oy1a9cQBjUAQLvdpoHf2dlBt9sFAKyuHcPc3Bx9f/euijSbm5ujZwkh6PskSXDixAkAagPc2lIJdXfv3qX7vPe978XS0hIAYHZmAevrKi/muS9/EaPRCADQ7/dweHgIAFhaWsDx4ypPpNPp4NatW/qaIT7ykY8AABYWFvHiiy/iTz6v8lHe8Y53YGFNRYx1uof0m6Io8Mgjj+jf99HtdgAAm5ubOHv2LABgcXGR2nH9+nUMBgNqt5ncOzs7uPLKZQDABz/4wdIC+OIXVRtWV1dx+qFTAIA0TWmjv337Nt71rnfS9eb+zz//PFZXVZvPnj1Lk/XChQu0AX7kIx+htl27dgOXL1+m6xcW1Nx0HQ9XrlyhxfrEE0/Q2K6vb+LSpUsAgGfe915qd5IkuHHjBgC1CbzjHe8AoBbK9evXAQD7+/t46KGHAADT09P0Pnfu3KHDcGVlhZ67d2+TxrDVamF6ehqAWsRmXty7dw9rJ9V8adbqdM+DgwNq89raGn2+ePEi3X/t+DKmp2cBAK+99hq
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAABzCAYAAACIEflfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvWmwbcd1HvZ1957OfO707vwGPMwDQZAURYmEBA+SLMsUJdmUZFsSnUqiSiKXKynHUaJSRXEGRT9ciSelHMWabEulwbJEUqQshRpIcQIBkgLBh4cZb3733fHMe+zu/Oju1fuCAPFAgmAEnYV6hXPP2UPv3j2s9a1vrcW01pjLXOYyl7n8+Rf+9W7AXOYyl7nM5bWR+YI+l7nMZS5vEJkv6HOZy1zm8gaR+YI+l7nMZS5vEJkv6HOZy1zm8gaR+YI+l7nMZS5vEJkv6HP5ioUx1mCMfZAxNmSM/ebXuz2vhTDG/oQx9p99vdsxl7l8JTJf0OfyimIXuSPGWPyin/4WgFUAS1rr9zLG/h5j7OOv8b3/HmNMMsYmjLERY+wxxtjfeC3vcRNteIgxphljP/563ve1EsbYadv+4Ovdlrl8bWW+oM/lywpj7DSABwFoAN/9op9PAXhaa129Rvd6uQXnU1rrNoA+gP8LwK8xxvqvxT1vUt4H4BDAj7yO95zLXF61zBf0ubyS/AiATwP4JZiFDQDAGPvHAP5HAD9gtecfA/CvAHyT/Xtgj4sZY/+EMXaJMXaDMfavGGMN+9tDjLErjLEfZ4ztAPjFL9cQrbUC8G8BtADcVmvLbzLGdiz08zHG2D21336JMfazjLEPMcbGjLGHGWNna79/G2PsSXvuvwTA6vdkjLVgLJEfA3AbY+xttd8eYoxdedHxFxhjf9V+bjDGftlaN+cZY/9d/Xh77D9ijH2BMTZljP08Y2yVMfZ7tq0fYYwt1I5/B2Psk4yxgbVUHqr99ieMsf+FMfYJe+4fMMaW7c8fs/8f2HfzTV+un+fy51fmC/pcXkl+BMCv2H/fwRhbBQCt9U8B+GkAv661bmutfxbAfwGrTWutnQb9MwBuB/BmALcC2ITZCJysAViE0fZ/9Ms1hDEmAPwnAEoAF2s//R7MAn8CwOdsW+vygwD+MYAFAM8C+N/s9ZYB/AcAPwlgGcBzAN75onO/D8AEwG8C+H3UNrWbkJ8CcBrALQC+DcAPvcQxf9P+djuAd9tn+QkAKzDz8x/Ytm4C+BCA/xWmv/5bAL/FGFupXevvwPTPCQCRPQYAvsX+v2/fzadexTPM5c+RzBf0ubysMMbeBbPQ/obW+rMwC97feRXnM5hF+r/RWh9qrccwm8AP1g5TAH5Ka51rrdOXudQ7rMafAfgnAH5Ia73rftRa/4LWeqy1zgH8TwDuZ4z1auf/ttb6MxYa+hWYzQUA/jqAc1rrf6+1LgH8UwA7L7r3+2A2LQngVwH8IGMsvMku+H4AP621PtJaXwHwz1/imH+htb6htb4K4E8BPKy1/rzWOgPw2wAesMf9EIAPa60/rLVWWuv/F8Cj9hmc/KLW+mnbj79Re865/AWR+YI+ly8n7wPwB1rrffv3r+LVaagrAJoAPmthggGA/2i/d7JnF68vJ5+2Gv8CgA/AYPoAjNbOGPsZxthzjLERgAv2p+Xa+fVFegagbT9vALjsftAmUx39zRjbBvCX4DX+9wNIAHzXK7TXybHrv+izkxu1z+lL/O3aegrAe10/2r58F4D12vEv95xz+Qsic6/3XF5SLM79/QCExbcBIAbQZ4zdr7V+7CVOe3Hqzn2YRekeq4G+lNx0uk+t9YQx9l8CeJ4x9gta68/DWAzvAfBXYRbzHoAjvAgLfxm5DmDb/WEtiu3a7z8Mo/R80PwEwCzo7wPwOwCmMBuWO1/g+GZ1HcAWgCfs3/Vrv1q5DODfaq3/86/g3HlK1b8gMtfQ5/Jy8j0AJIC7YUz3NwO4CwYWeDm2xw0AW4yxCCAn5v8D4P9kjJ0ADBbMGPuOr7RRWutDAP8aHofvAMgBHMAsrj/9Ki73IQD3MMa+zzJs/gEMpu/kfTDY+5tr//4mgL/OGFsC8DSAhDH2XRaG+UmYTc/JbwD4HxhjCxYD//uv6mGPy78D8G7G2HdYqySxTtmtmzh3DwbauuWruP9c/hzIfEGfy8vJ+2Aw2Uta6x33D8C/BPB3X4Zi+EcAzgHYYYw5mObHYRyRn7aQyEcA3PFVtu2fwiyqbwLwb2AcpFdhNOFP3+xFLJT0XhjH7QGMY/UTgGGUwMAcP1t/fq31B+zz/G2t9RDAfwWzwVyF0djrrJf/2f79Asxz/3uYzedVi9b6Mowl8hMwC/RlAP8INzGHtdYzGEfwJyxc846vpA1z+f+/sHmBi7nM5fURCxf9oNb6W7/ebZnLG1PmGvpc5vI1EsbYOmPsnYwxzhi7A8A/hGGuzGUuXxOZO0XnMpevnUQA/m8AZwAMAPwaTKTrXObyNZGvCnJhjP01AP8MgADwr7XWP/NaNWwuc5nLXOby6uQrXtAtRetpmCi3KwAegXEUPfFlT5zLXOYyl7l8TeSrgVzeDuBZrfXzAMAY+zUYL/zLLuitVkvHkQmyC6MYSZIAMIThsiwBAGmaotlqAQCEEHAbzmQ8QRQZRlgcR3TNLMshpckNFScxwtBcv6oqpDMTeBiEIZIkpnulmSEaKCkR2WtxzlGWlT23pLZFUQTHQdZaQ1YSACCVghDcHi+RZyY2RgQCjUaD2lAUBQAgzwv0eyZ4kXGOsigxmUxNv7RbiCLzKiqpMEtTew+Jtu0LaEApBQCYzWZo2e+jMICyfZRmGT1Dr9uhPqqqCqPhCADQX1gAt8/DGDAYjGyfxmg0YnrOSprnHA1HWFpapGtpS2keDIb0nEkSw+kF0+mU2tntdCDt56Iokc5mAIBOt4Mw8ENvMp3ROZ1OG8xSyMuywtSe02636D0wALPUxSJptJqGCq60RlGU9n4FmrZ9QSCoHVmWwVHU4ziCo5cXeYGiMOMiiRNEkR9js9S0Ic9zJO6Z45jG7HQ6pWO7nS6EEOb72RTS9mOS+GumaUqfk0bjWF9UUqK0YybLCih7flHmCERA/U3tLkq6RxAECO11OWNIs5SerdPr2r5j9G5n0ym9wzAMaa7lRU7XjOMEgR3nUkqkti+iKD7WR67vtNK+DZzTe82zFIGdm2EYgXNzzclkDPcwSZIgsH1RFgUmkyk6HRMbJURA5+R5Ru2LwggiMP2tlKb2McZoDgMcRW7nvFKI7fecc5qfZZEjjs33IgjgqPtp6oOXG0mDohuqskLuxksS23MAzjhmMzMetAaaTRemwFBVJfVvkph+F5zTnHJrCGDGF7fjqKpKXL26s6+1rsc4vKR8NQv6Jo5Hvl0B8I0vPogx9qOwOTr6vR6+89u/HQCwffIkzpw1OZK01rh25RoA4Kknn8YDb3kLAKC/6BPqPf75c7TQLy8v0cJ9+fJlHOwfAABuv+s2rJwwAYI3dm7g3Llz5l7b2zhr7
8U4w7PPPAsASNMMd919JwCg0Uhw/ZoJ0jt//glsbZkYkPvuvweNhnkpvd4CDg/MvS5euILp1Ayeo8Ehdq6b2JuzZ8/gzNlTAIDxaIIvfvGLAIDDgyHe870m62uSJLhy6To+8MH3AwAe/NZvw8nThv58NJrhC+ceBwBUSuKbv+EbzLUmE+wfHpo+eupJfOM3mq7eOLGMzC5ij33hizgaDQEA3/mXH0JpN5+d67v4+Ec/CgB47/e/F4L7Der9H/gwAOCuO+/GXXcZmnIpJcajCQDgi088gYce/GYAgIJClppB/MHf+xDuvuNuAMC999wFKc3E/dOPfxzjiTn3e7/73ajs90+cfwaPPmxSiLzpgbfi5NZJAGax/fTDj9BEfte73kELyAsXdvCpT/wJAOBvvOd7aOLmWYHPPfYF897iBA++0+TLklLh8XPPAAAuX7mKb3ibiXxfXOhiPDET84knnkKzZSbTrbecATeXxLULl7Fzw8Q+LS0tYW3VvI84ifHC888DAJ587hncf7+55kKnh6N
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 参考代码:\r\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/imaug/__init__.py\r\n",
"import random\r\n",
"from PIL import Image\r\n",
"from ppocr.data.imaug import DecodeImage, RandAugment, transform\r\n",
"\r\n",
"np.random.seed(1)\r\n",
"random.seed(1)\r\n",
"\r\n",
"img = Image.open('./doc/imgs_words/ch/word_4.jpg')\r\n",
"\r\n",
"# 绘制原图\r\n",
"plt.figure(\"Image1\") # 图像窗口名称\r\n",
"plt.imshow(img)\r\n",
"plt.axis('on') # 关掉坐标轴为 off\r\n",
"plt.title('Before RandAugment') # 图像题目\r\n",
"plt.show()\r\n",
"\r\n",
"\r\n",
"data = {'image':None}\r\n",
"with open('./doc/imgs_words/ch/word_4.jpg', 'rb') as f:\r\n",
" img = f.read()\r\n",
" data['image'] = img\r\n",
"\r\n",
"# 定义变换算子\r\n",
"ops_list = [DecodeImage(), RandAugment()]\r\n",
"\r\n",
"# 数据变换\r\n",
"data = transform(data,ops_list)\r\n",
"\r\n",
"img_auged = data['image']\r\n",
"\r\n",
"# 显示\r\n",
"img_auged = Image.fromarray(img_auged, 'RGB')\r\n",
"plt.figure(\"Image\") # 图像窗口名称\r\n",
"plt.imshow(img_auged)\r\n",
"plt.axis('on') # 关掉坐标轴为 off\r\n",
"plt.title('After RandAugment') # 图像标题\r\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"下面展示快速使用方向分类器模型的预测效果。具体的预测推理代码,我们在第五章会进行详细说明。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-12-24 21:19:04-- https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar\n",
"Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a\n",
"Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1454080 (1.4M) [application/x-tar]\n",
"Saving to: ch_ppocr_mobile_v2.0_cls_infer.tar\n",
"\n",
"ch_ppocr_mobile_v2. 100%[===================>] 1.39M --.-KB/s in 0.1s \n",
"\n",
"2021-12-24 21:19:04 (14.3 MB/s) - ch_ppocr_mobile_v2.0_cls_infer.tar saved [1454080/1454080]\n",
"\n",
"[2021/12/24 21:19:06] root INFO: Predicts of ./doc/imgs_words/ch/word_1.jpg:['0', 0.9998784]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAABrCAYAAABnlHmpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvUezbVly3/fLtdY2x17z/Kt6ZdBggy2FEKGABH0CRWimqaQPwJGGGvCzcKCxPgEjNJYiNOBIokCRIECgu6v7lXnmunO2WSY1yLXPvdVVABsE2JAQNztu9Ttum2XS/POfuUVVeZRHeZRHeZT//4v7+76AR3mUR3mUR/m7kUeF/iiP8iiP8g9EHhX6ozzKozzKPxB5VOiP8iiP8ij/QORRoT/KozzKo/wDkUeF/iiP8iiP8g9E/lYKXUT+GxH5NyLyZyLyT/+uLupRHuVRHuVR/uYi/6E8dBHxwJ8C/zXwFfAvgP9eVf/V393lPcqjPMqjPMpvK38bD/2PgT9T1X+nqjPwvwD/7d/NZT3KozzKozzK31T+Ngr9E+CXD15/Vd97lEd5lEd5lL8HCf+xTyAi/wT4JwC49o9k/QwDeeT+Ow9enj6T5bMHkJAsXxIQQUtGnIAWVBURUFVQEOfr9+8PrCiCoPenvpfvIU9iP1W1n6uCKlLPr2qHVbVzI3Y9y1dPB1vgrIc3eDrfj0BdfxX69ZvXKw8G6wefPXz/Rz5UffCFHz/hw0+WI+iPHvOvP8733v/eAX/4G1l8izouKmrz/BvnkPrf01yiP7jL762vh+Msv3HND65JHv5D5YdDd5pLefC75R/3Z5T6XVuHhabxtG1DaDylZEpRSiloUXIupFwoBRCHiK1ZFXkw4MuaredQu+ffuKnv389vyo+td+p9/uB7y9p9+P6D7+nD+7arW/aJ3b/W/VhO49E0QvCO4B05JUoptp+cwzkHIuSipJRJOdcZfehryv00Cqd9eP/pg8+x9S1SV4cWQHGAOGibgPceLRkoeO9xThBx9ksFLUpBKfWaYsqUXOwAdgF1fO7XyW8OvekCefCqXndd/7qMzqIvlqX34DYerm09/Oqdqj7j3yN/G4X+K+DNg9ef1ve+J6r6z4B/BuDP3mjzX/1PRHWUul4FaFzBS11frkVcqHs5Q8k4CkVsWkQ8JYMTxZGYpzucjjQB9mcbRITjMDGMkZQKElZIu0WlsU2Gw3uPiFCwjZdzRkuBYmrFIbRNIKURSmIeB4LztI0nxQIIoelQ3zKmYjq9bQgoOUcotqByzkg1BM4528jLZnecDIXNniJl2bTOvrcsXufA3S8KVbXrdYILAXGuHkrx3pNztu+6YJun3jNAKQWh2EYqmVzS6XcAOSVwDsGuddnMTsQ2ST2/KSUzZo4CJVejKgR3vxlVlTTHOvaKc3YtqrZpVE4jbveaC0UzOMW7hlIU7xtKiRRNVR172/JajbqCR2yNCGQR8AFKHd+cEAXnBZyiTpCi93NBofHhgaJwdaz9jxp/L/a5CqhmcpxBM62D1ivzcAXzwOefv+A/+dmXNKHw7v1bbm8PCJ6u69huz8jZ8/a7K95+e83HmwkNa8JqB35FKgXvGlJKhBBsL6CneS7FrkWwtewUlFzn+17p2Zwt91l/o/fr6Ef262meH4qIh1L3jCSKt8+zKiUlvAidU7xmyAfyfKApE+e7ln/0k+esO8GVxM31R9o2sNudsd1uCU3DmDK3h5mr6wM//8WvGKbCODqcBIo4pnkm5oJvWrquO82JiMdJgOJAWzw2DqLR1ksZEB3pe3hyvuLsrOPp5Y6cRqbxwLOLPefne7quw7lAcZ45ZaZxJiY4zpF3H6759dvv+Hh1y5Q8WT0pFXJRGxPqXsXhg+CcjaFmgB5wFBziPb5p8E1LKQ1ZHYitb/M7C1BOc6C52F6tczH/H//05z9ciT+Uv41C/xfAPxKRLzFF/t8B/8Nf9wNFSBJOysn+3/wNrQpF8RSHKTddLJRDTtawPDACCZWZ3brh1csLgnMoGXmy4/b2jg9XN9xNA0pnE4BDpCAF1AlOhKIFh5IVHBkBvPNInvE5Ms0H+saz3/XkeeYwDzb5RUk5gQTAoTGRJJvtdYJTu6+yeJy5nKy5c87eP0Uh3CvwZbMJCAW8vXlyytWUsYopRXJBc7H7cQ53sv7OjuUcIoL3guaCdwJZTAmVgne2gTUnnCsUkh2jKuxF4ZlDYueAcpobQdA4m8Ksn0vhZChzzjQ+nJSnqCLqqjYs4B1KwPmC14A6M4LOOZxXijhyiqjLuGWNqKBl8SSrYqme7CLi3IPIzYPmqt8FMe/g3p6a+TIz4cRuT+4jOrt1QasSxdVNXPTkWUtRM/7TAU/k+YsdP/3yOa0b+PjuWwKZP/zZ7+Mk4H1D03QonvOzPX37Dfnnb7kbIxonRBqcs62pztdI8EG0uXjtuvhwQhH53rWxGF8piAoi5XsK/K9S5qfxW+bSQcmK5mpMqPdc16sXG7MgICVT8oij0Hllv3JcbjyNjHgc03wgNMrZ+Zbz8x1934M43Gxee0oNL56dMw6Zb765pZSZmAquRFoRhEyaBptPJwTfo0TyrHTtnsa3NARKUlKeIWfaBvarhsuzNZcXa+bphuFwTdcEznYr+jagpTDPR5K5OnjncK2gLrDb9Fycb3Gi3B4SKRUmZ967OHCorSmBEDzeizkeRYgpE3Mh1/Um2ZGZUQrOtbbiREAKahYA84UUnOJwP2pc/zr5D1boqppE5H8E/lfAA/+zqv7JX/sbBG3M6jrnEDGvTbNQnFvUhm0iBwFPIx4RZZgnmhDIOeMdNC5yd/WOT15e8Pmb52xXgfFwhzhhu91ys3a0TLy/Hvk43SJ0eN+AtGiZbYE6G1B/CjULATFPez4QiHz5+gnPnlzgpDAMB1TPORwnfvmrb0BbQujJ6ikqhL6xSEMLmm2xL56QKOAEkYA4peT8ADqpnrBmtIZ+ch/AghROukTMO9WieCxSATMYqoU0J3RZBA+86pITJWXEhhZPwYkphhwjOU2oRoIvJw9aqwYTp+SYENH7CIF7j7ytCpuikDmdswsCAUqZvrcKFoXjpTClBAo5NSge1I6fcYgL+GYNBdqmBe/IWckqpkQBvLPzUuzGhGqoq6oTQR0PlH0Nz8uy0mycMuDEPKllNZjNXcCYCoVwbwREC06FoB4tE6XM9C5zuV/z5ZsLzs88tx+/o2Xi8vKSJ9sdIo6cE1McaJuOp2cr9M1TshZ++faKq8NAnj2h21Yj5ev1CkHcyTCaoleLDuv/lApVyWKAH6At5YfQVN3HP6owfvO9gqJYtHf6nbN59ioIiZJnNI20rbBqhKfbnv1akXREvWMab2iaBmFGJFPKjAoEL6zXHqFn1Tzl47tbyjEyzzNTisxJUPGEtkGco4hQqpGb58RYEuSRkjO5OByZVhLSCl3r2HSeLhQaF5nygJaJftURgielhGZIJRNzplANmW9ovWPdedadYwyK27XVmTCIxjcBL646CRa1xjKj2
uB9IObAOMNxmIlJyTqRkqdINt3gPV6XyD2f9jdS9WFeHLm/Ckv7ofytMHRV/efAP/+tfyCC+mAbzTnE2YUWsk0QdYE4wRdFSqGUhJRMIwVNM5ozSCHNIy+erPnJZy/J8y1zUS73G/b7LU3TsWkCabij9YF+gG+vjuQ0IjSAx/mA0BiE4gSHQMnm2ZVEIPLq2Y5XT7c0ITGMBy7Wge3ZnuMQEZ357v0dt8cbSlRcu4Z5Ims2DzwJThqCVI9ZDPtMOaKpQFUcNiwe1cRps5TFIXOYK3m/uUzJmhdZSkachf2nUBmH8w5PQCloKieIxLtlA9uoxziRU0Kcow2wWTv6NuD9vfJbNnbwFRKp+KRiC9Bjm817TwiBNoQT3BRjNAMs7hR9GDZZqoFT2+gpMg6RaYqkKZEKBimoR1JBszDrgLi26uGVQUTf8zKdQUB+gd5rRPBgjIuU078X4wSgThFnUY5pcfeDYy8Ipyl6PUU9pEJOI8SJIDP7fcenr854+fSMNN1wd/2Oi/Mtz59ccjzcVmU
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021/12/24 21:19:09] root INFO: Predicts of ./test.png:['180', 0.9999759]\r\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAABrCAYAAABnlHmpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvcmSZEl2pvcdVb2DTT5FhHtMmVkDClOzqycRvgFFuOOW5AM0N3wArvsxesE1n6BFuMSWWFAE0qRIA6hCVdaQmZERPrvZHVT1cHH0XjOPTBQSKHQVCHEVCQl3c7NrenU4+p///OdcUVWe2lN7ak/tqf3/v7nfdwee2lN7ak/tqf3jtCeD/tSe2lN7av9M2pNBf2pP7ak9tX8m7cmgP7Wn9tSe2j+T9mTQn9pTe2pP7Z9JezLoT+2pPbWn9s+k/VYGXUT+exH5LyLy1yLyv/1jdeqpPbWn9tSe2t+/yT9Uhy4iHvhL4L8Dfgn8OfA/qer/+4/Xvaf21J7aU3tq37X9Ngj9vwX+WlV/qqoD8H8A/8M/Tree2lN7ak/tqf19229j0N8Avzj4/Zfltaf21J7aU3tqv4cW/mt/gYj8e+DfA+Cqf0f73F4vf1cUDmgfmT/3+DpGDSlaPiXz/1Jet+vsf8/ze9CM5mTfozpfxf5W/v+t7vHj3n/bZeXxW771K/WjNwgiUq4vdpflHkWmu5suVH7XqUPOXhN3cM3DQf2oA/LRHMjHY6Pse8F+zg7H81FX5G+5T/n4hW/+qcz14WvyrdTgR32Zvm/6/nks5Ntf+0YHv+U7HvXp4LMKiP/oTfLttzd/58cvTv8yaC7rM3/zbd/Wn4//JjxePt/ogx70YXrzb9M+7shH602mNSoonhk7Ho6DfHyd79I+nq/yTw/Hc/rzR/f48dg++u6P98jfMpffOv760XunX1wZ6Y+u/ei+P/6+w7V58Pfui/eq+uJbevSo/TYG/VfAJwe/vy2vPWqq+h+B/wggq9fq/uR/QQSEDGRySpAi4hSH4FRxnmLIxOyKKjGPpDSi5XPeZYSEU7uWaEI1ISSQiKQRJwmXI3nY0T9cksaBHAdbYm7qX4KcCcGhmsk5T/2e+wA8+nn63TmHiJRrOQRH0kzOkFIiRSWlhKqgbn8dVPYG9ODnR/EMTYhzhBBwzjqbc2bsOtR72roBHFEzOSkqkNUDAXyNCwvUtbiwsaPNeVQ8Hk/OGdEMYmOpoiAZlYyTjEMRp3gUNIImco44Mt4ppEhKiayRnDNpHKzLOYO6MlYeESHNe8jZZtcyDi4gIqgm+5yojSd2XdWE94J3DlACejAXiitzl3MmpYiqkrPaWKuSsoMMOAEXwAWEChUBqRD8fpxFECKqGSXZPpUMOZU+HhgKF0AqUI+XNVkcSEDFIy6As+uqK/NcDOnh+lEiTjMiIy73xLgl9w8QO8gjIDiEj+NbIvsxUNVHf5/XpuRH61RV0by/lnNumoL5c9Naf/xlNnH54D2TUZb5IANN5UAiIH4B0oCvwNU43yC+Btmg0pIlgKsQ58h+v78/7u/++ocGuOxL0vROlBHIiEZy3iJqdmCa05Qi5GhrXVMZ21xsiuCkAhxZDQAJYf+7OMTLR+OV532sqji1NSgiIJksMJ+qIni3hOTRsu41FHMrHhFPxhsokMq+zwXE1daHYsjNxnjSf/4PP+c7tN/GoP858CMR+T5myP9H4H/+uz6UBZwquLLYKDYNW7CJhOieCXqEWXXCh4rmOG9qISMozhV8rtOmTwgRxIyDy0KWgBPFe49zzNeYjPm08OVgE6qasfHeP3p9aq5sNFUhpkSMdrioL/grKZqKcXCCU0cmgwpeHCqCR0hiizVnQ2tmpPaLet7AqsQYbdFqLkZsepcZkZwz4somEHB4EDnYIs6AjbOxygVN5WK4nOaCFQRBEVHIZsxzGkk5ocWQ2qE4IcBUHCHdfw8gzsYsZ+u/yn6ccQ7NEc3JDmwpxgvzTryAw1kfmHCP3atzDi3rxc2HtJCGzIzcpn9SjFO5S2dHFmgqr9m8IUCejMOEAA8WsZQvctnuL9v4mFc4GQA3Ler5s0rGFeMiku0QISM5gdMCcr2NcRYzQgfr7RtG92Cd2i+Z34R4v+3z0zzlnPfrnYTm6QCbBxUmY1oObPMWtbw3oToaOMkKvhwGCq7OiCQb78M5FJnX+f7QKCDLdiazUZcJaE0r2ObLMRlrAx+qCcmZJBlRLWt0LB7QaHOtChmSTzgXEMzo2tdPXoZDs9qBraCSDYjqfjzsvnX+uxNBxdaEImgeIUXMQDsbP3HgEipVWTN2QCoC2fbFPJ3zfH3sWfzt7R9s0FU1isj/CvyfgAf+d1X9f77DJ8mScY/AR57sEKiSxJwUf+DJZymo5+De9GCzilO8UxxaDEMkpx5JI+hI8IA4crbhds7hJ1QogZwnZPhNZD4hceGxQZ8PAJ2QXNmsKN4V9KFCzNkMRlZEigFwGc2Gzl1ZNCK22KTQGKJlc+AMpQskV/aYZMQ5PIYuFEfSYtwebdyMiKH8PAFGnBkQPCpS7E0uCNr6YVdI+zEoaCSlgZxGu/c8bY5i+MTt99/EuGhBjOptntSXz5iB1ekD08Z2EJzHeftOJ2Jzq8WbUCXPGz/PSNXmyY6gpIL3Qpb8Da8bKOiX2YAbUtSyGXNBERnVaO+1HVtsZTEgCkjCiUPFl8/6ebx0pji0HGYgORrCFAVJBVkO5hV5h/M1aCKNtjZSTqDZvIpibB6h9unAmDzZnPaHyUETV/ZIQZAy7Z3DvtqKK98xHWQgBzQmOHtvMo9KgKzl4BQBl0ACSMIM/IAmQXREXGX3oglx5TvT3hPe39bU93K4YuZwOhVFJtBlIE6l2IAcARtTyp5EM0g0ZK7ZPisZ4ljWaiZnh0hlHhapeAdm0I2qlRlIWB+8GXVNj2hAW8diazQLOEEmtK9mYzQnO7BxIIpi3gqk4rzamkfV1j088qa+S/utOHRV/U/Af/r7fGaaPC2L0CyFTWyCYsQyzjmSUAxWLgYlg3OIQhB7XVC8KMGB00K3EEkaIQ62YYgEX05e7wx5ASIO54rBdhmn4RsGXVWp69oMSdrfx4xsNKIp2s+FarFrKD7YxBktUBZlLpxzzjjxholzJj9C49kokvLaobs89cs5t//deVvyWYxqkD2Nk6W42a4Y/pxsDLP9DfVlD3kbTwHnPMJYqJDiwmoi5wHNIzknQ9SHbr8IDsgFSUlBsWb0rSMirhhoIZUDjmnTlpXovdJUAee8bUhn9IrYkD36zml8DufKvlzx3uFcICbzYA7XH4du/uQ2k2cDgmScqq23EsPQQxuZzWhkl3DO48goAYqBEedJqmUOzCA4TaiksvZiMRYDohFxido5nBhllJ0h5hghxli6qfOhV27W/pvXamaKrxy6E4fI93DtHo7h5L3Y9+wP2OnvE6ixz9lBiPp5rFC1NeNlNkzq1LxQYkHOESQg0x6YDPEMKObeHXa0HKx68F0Tgs+IK+Anp0K3TIdtsjkSM
FCiiGSCg6yJ5BLeQxrMEzDPqsJ7wAk55dmQOgSRXMy12Jos3tO8L5U9HSbmeWsWsgiSBQNOyYz6BA7mtSiICpqdHTjEQsEYfamHa/U7tP/qQdGP2xzW1INQgZNyeisqbgJv5sLkyfgqwTuUEVHjeb1TnEpxcBLkjpwGNA+QRwQz7q5ww+YF6OS92eJS4/FVFe9tcRnSZqY1RgXvPbnQJj6EsrhzQdB23XEcZqMsEzrTPCP2GBMhVMTRTmQptEZKxSiNo8HvabGros4Z56iZpEoIAV/ZtEkQcxnFDLpkwWP8eSIU4+5Rlb0jLg4ngZSjGSvniuEuvGFxo52buMmMphHygOZISqPhF2dG+2MeN6W9N5NTmumJnM2QqWaquibHaDSUs7GtKvN4nFOcGPrPcSCRqH3AFS6tzJAZKnE23oU3Nl58QtuCc56UFU0ZCl02zaFzFt+QYgiRTHAOLRv10JBNdFqe3WtFc8YL5BRxwTZqzhnxZY15Zwe9m2gRtXHUCUWOFjdymRQjCTV6Xsq8qsM
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 参考代码:\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/tools/infer/predict_cls.py\n",
"!cd inference && wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar -O ch_ppocr_mobile_v2.0_cls_infer.tar && tar -xf ch_ppocr_mobile_v2.0_cls_infer.tar\n",
"\n",
"# 方向分类器分类\n",
"!python tools/infer/predict_cls.py --image_dir=\"./doc/imgs_words/ch/word_1.jpg\" --cls_model_dir=\"./inference/ch_ppocr_mobile_v2.0_cls_infer\" --use_gpu=False\n",
"\n",
"# 读入图像\n",
"import cv2\n",
"img = cv2.imread(\"./doc/imgs_words/ch/word_1.jpg\")\n",
"\n",
"plt.imshow(img[:,:,::-1])\n",
"plt.show()\n",
"\n",
"# 旋转180度\n",
"img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)\n",
"img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)\n",
"cv2.imwrite(\"./test.png\", img)\n",
"\n",
"# 对旋转后图像使用方向分类器进行分类\n",
"!python tools/infer/predict_cls.py --image_dir=\"./test.png\" --cls_model_dir=\"./inference/ch_ppocr_mobile_v2.0_cls_infer\" --use_gpu=False\n",
"\n",
"plt.imshow(img[:,:,::-1])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"\n",
"### 2.2.3 输入分辨率优化\n",
"\n",
"一般来说,当图像的输入分辨率提高时,精度也会提高。由于方向分类器的骨干网络参数量很小,即使提高了分辨率也不会导致推理时间的明显增加。我们将方向分类器的输入图像尺度从`3x32x100`增加到`3x48x192`,方向分类器的精度从`92.1%`提升至`94.0%`,但是预测耗时仅仅从`3.19ms`提升至`3.21ms`。\n",
"\n",
"下面给出两种尺度下的图像大小对比。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/7f7fbe95eb3a4d14ac6def76874ce27559436040b91840688337f9a9f72c1a8d\" width = \"1200\" />\n",
"</div>\n",
"<center>32x100和48x192尺度下的图像大小对比</center>\n",
"\n",
"\n",
"### 2.2.4 模型量化策略-PACT\n",
"\n",
"模型量化是一种将浮点计算转成低比特定点计算的技术,可以使神经网络模型具有更低的延迟、更小的体积以及更低的计算功耗。\n",
"\n",
"模型量化主要分为离线量化和在线量化。其中离线量化是指一种利用KL散度等方法来确定量化参数的定点量化方法量化后不需要再次训练在线量化是指在训练过程中确定量化参数相比离线量化模式它的精度损失更小。\n",
"\n",
"PACT(PArameterized Clipping acTivation)是一种新的在线量化方法,可以**提前从激活层中去除一些极端值**。在去除极端值后模型可以学习更合适的量化参数。普通PACT方法的激活值的预处理是基于RELU函数的公式如下\n",
"\n",
"$$\n",
"\n",
"y=P A C T(x)=0.5(|x|-|x-\\alpha|+\\alpha)=\\left\\{\\begin{array}{cc}\n",
"0 & x \\in(-\\infty, 0) \\\\\n",
"x & x \\in[0, \\alpha) \\\\\n",
"\\alpha & x \\in[\\alpha,+\\infty)\n",
"\\end{array}\\right.\n",
"\n",
"$$\n",
"\n",
"所有大于特定阈值的激活值都会被重置为一个常数。然而MobileNetV3中的激活函数不仅是ReLU还包括hardswish。因此使用普通的PACT量化会导致更高的精度损失。因此为了减少量化损失我们将激活函数的公式修改为\n",
"\n",
"$$\n",
"\n",
"y=P A C T(x)=\\left\\{\\begin{array}{rl}\n",
"-\\alpha & x \\in(-\\infty,-\\alpha) \\\\\n",
"x & x \\in[-\\alpha, \\alpha) \\\\\n",
"\\alpha & x \\in[\\alpha,+\\infty)\n",
"\\end{array}\\right.\n",
"\n",
"$$\n",
"\n",
"PaddleOCR中提供了适用于PP-OCR套件的量化脚本。具体链接可以参考[PaddleOCR模型量化教程](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/deploy/slim/quantization/README.md)。\n",
"\n",
"\n",
"### 2.2.5 方向分类器配置说明\n",
"\n",
"训练方向分类器时,配置文件中的部分关键字段和说明如下所示。完整配置文件可以参考[cls_mv3.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/cls/cls_mv3.yml)。\n",
"\n",
"```yaml\n",
"Architecture:\n",
" model_type: cls\n",
" algorithm: CLS\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV3 # 配置分类模型为MobileNetV3\n",
" scale: 0.35\n",
" model_name: small\n",
" Neck:\n",
" Head:\n",
" name: ClsHead\n",
" class_dim: 2\n",
"\n",
"Train:\n",
" dataset:\n",
" name: SimpleDataSet\n",
" data_dir: ./train_data/cls\n",
" label_file_list:\n",
" - ./train_data/cls/train.txt\n",
" transforms:\n",
" - DecodeImage: # load image\n",
" img_mode: BGR\n",
" channel_first: False\n",
" - ClsLabelEncode: # Class handling label\n",
" - RecAug: \n",
" use_tia: False # 配置BDA数据增强不使用TIA数据增强\n",
" - RandAugment: # 配置随机增强数据增强方法\n",
" - ClsResizeImg:\n",
" image_shape: [3, 48, 192] # 这里将[3, 32, 100]修改为[3, 48, 192],进行输入分辨率优化\n",
" - KeepKeys:\n",
" keep_keys: ['image', 'label'] # dataloader will return list in this order\n",
" loader:\n",
" shuffle: True\n",
" batch_size_per_card: 512\n",
" drop_last: True\n",
" num_workers: 8\n",
"```\n",
"\n",
"### 2.2.5 方向分类器实验总结\n",
"\n",
"在方向分类器模型优化中,我们使用轻量化骨干网络以及模型量化,最终将模型从**0.85M**降低到了**0.46M**,使用组合数据增广、高分辨率等特征,最终将模型精度提升了超过**2%**。消融实验对比如下所示。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/65b6e5d75f22403aba50665a1ab6dbba51cdb47eff0845d1909e5e54fd48e336\" width = \"1200\" />\n",
"</div>\n",
"<center>方向分类器消融实验</center>"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"\n",
"## 2.3 文本识别\n",
"\n",
"PP-OCR中文本识别器使用的是CRNN模型。训练的时候使用CTC loss去解决不定长文本的预测问题。\n",
"\n",
"CRNN模型结构如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/6c4f0d8e87de400691764107c7b1bd374c695a94d7de479da1c1de36a1cca238\" width = \"800\" />\n",
"</div>\n",
"<center>CRNN结构图</center>\n",
"\n",
"PP-OCR针对文本识别器从骨干网络、头部结构优化、数据增强、正则化策略、特征图下采样策略、量化等多个角度进行模型优化具体消融实验如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/44bcc2e797114171aaf584e4d7c79e19cf4ac1fc3e254ad7bdb5cbc01463c790\" width = \"800\" />\n",
"</div>\n",
"<center>CRNN识别模型消融实验</center>\n",
"\n",
"下面详细介绍文本识别模型的具体优化策略。\n",
"\n",
"\n",
"### 2.3.1 轻量级骨干网络和头部结构\n",
"\n",
"* 轻量级骨干网络\n",
"\n",
"在文本识别中仍然采用了与文本检测相同的MobileNetV3作为backbone。选自MobileNetV3_small_x0.5进一步地平衡精度和效率。如果不要求模型大小的话可以选择MobileNetV3_small_x1模型大小仅增加5M精度明显提高。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/6272a6cfdc7a431db44407c0e60c17abc259f8c83dc7422eb60ecb87f322a2e1\" width = \"800\" />\n",
"</div>\n",
"<center>不同骨干网络下的识别模型精度对比</center>\n",
"\n",
"* 轻量级头部结构\n",
"\n",
"CRNN中用于解码的轻量级头(head)是一个全连接层用于将序列特征解码为普通的预测字符。序列特征的维数对文本识别器的模型大小影响非常大特别是对于6000多个字符的中文识别场景序列特征维度若设置为256则仅仅是head部分的模型大小就为**6.7M**。在PP-OCR中我们针对序列特征的维度展开实验最终将其设置为48平衡了精度与效率。部分消融实验结论如下。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/e9f020c2748d4870b273a2c3ada03212d2791520baf949ac84b78429251bff2a\" width = \"600\" />\n",
"</div>\n",
"<center>不同序列特征维度的精度对比</center>\n",
"\n",
"\n",
"### 2.3.2 数据增强\n",
"\n",
"除了前面提到的经常用于文本识别的BDA基本数据增强TIALuo等人2020也是一种有效的文本识别数据增强方法。TIA是一种针对场景文字的数据增强方法它在图像中设置了多个基准点然后随机移动点通过几何变换生成新图像这样大大提升了数据的多样性以及模型的泛化能力。TIA的基本流程图如图所示\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/f7883b2263e64eaf853b9bf02fc2be5fb737aefd97e74a4db158943eae792a23\" width = \"600\" />\n",
"</div>\n",
"\n",
"实验证明使用TIA数据增广可以帮助文本识别模型的精度在一个极高的baseline上面进一步提升**0.9%**。\n",
"\n",
"下面是TIA中三种涉及到的数据增广的可视化效果图。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABIEAAABXCAYAAACTHlXgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvcnvLUmW5/U5Zubud/wNb34v5srKmrq6s9QloJtGokBi2LEFJBASUi8QW6D+D3a9YM2GDYIWiBVqMakautUDnVRXZVdWZkbG8KbfcAd3NzuHxTG/v9978SIjMzIyMzLzfqWI97v3+mBubnbsnO8ZTMyMI4444ogjjjjiiCOOOOKII4444ogjfrURftENOOKII4444ogjjjjiiCOOOOKII4444mePIwl0xBFHHHHEEUccccQRRxxxxBFHHPFrgCMJdMQRRxxxxBFHHHHEEUccccQRRxzxa4AjCXTEEUccccQRRxxxxBFHHHHEEUcc8WuAIwl0xBFHHHHEEUccccQRRxxxxBFHHPFrgCMJdMQRRxxxxBFHHHHEEUccccQRRxzxa4CfigQSkX9XRP4/EfkzEfnjr6pRRxxxxBE/CY6y6Igjjvg64CiLjjjiiK8LjvLoiCOO+DyImX25E0Ui8KfAvwV8H/gT4D8ws//3q2veEUccccSPxlEWHXHEEV8HHGXREUcc8XXBUR4dccQRPwo/TSTQvwz8mZl9x8wG4L8F/r2vpllHHHHEET82jrLoiCOO+DrgKIuOOOKIrwuO8uiII474XPw0JNBbwPduff5+/e6II4444ueJoyw64ogjvg44yqIjjjji64KjPDriiCM+F+lnfQMR+dvA3wYgtn/I4uFrv09/VD7KFAFMAAQzQxDMCoKSkh+Xs2KAhIQhN9erfxsGZlA/BxHM1H9Rq58NkYDWY0TEf59S5G5Or3/f+v72HybTzV99sOk6Qer59bvpoe214169+K0G3P5O6levpfHVz1Kvbaq32iSvHku4dc16/c+0o349HXm4xNR2e+2y9f6v30te/fzK05gB5t2jhbYJzGcduYyUkhEJiAghJMZc2O9HHyexOfTjq82w2mO37nn7meT1vvzsx1caOl3L6pVvvTYMghiYghVSgNREmhQQM1QLMUZiTCjGMBSGYaCU6foBVQURZBr7AmJSb16fwhRDCeJjv2uTt6Vev2kSIBigahQ1cimMY2Ycy6GPzF55qNdGVajdE/x4wutv6tCXBtC/wMbN64Pqa49XZFFo/pD53c85MtZ/bTrR3xOA3OpLkTq8tfZNfW8Sqgyp8/eVrrwlUEyR+p6no6f73MigH5Gu+5pYeOWP19/O6xP5TTLk9jVvzdvPjgS7+eHQIbeF5WuNs88bKlZFz20h69e029c7tPO2QLq5vsuJ6bzPea7Dg3hbBfE1QDMSoGkCEgQtRilKLgoEJDT1znLzyLV1n7283Hrv9eg3PfrhPdyMD3nlqvLqGnNYPuofNq1YipkRMESUNkVSCmhRTPWwhu6H0VsXoq93ai4T6r1CeGXhOvSvUZikkZkipsQY6LqEmWFWav9PYx5yUYYhYyIIoY4jubm2vP6O5KbfBJjOMe8REV7t61vnHb7afP+pmd1/Q09/bfF5etFhuZZbY78qRHaQGb4WuLwHQjzoEjfT4/ZK6ONhGj92WG9urfmTXnK439RSe7N+M51/8+Hmn9f1I27d4/Z5h3vZq6fent+v6EqHzvlMM+wgY75gfX+THvOmZ6tt+yJ95jM4yBe/uVghxsBy0WEYWvLNvJNACIExK7kU+qEAwd/n4cHedIv6bl97ZKty7abJr+urb4bceia7/e5u6zu4dJVJZ6MQBZrG5agE1zljSsQYXbYCpajrJGOhFK2yx9Db+uqt96rFny2EACZVort8i1Fo24iIUUomhEDbNDc6r91+VEHN6PuBoopq/d3sxiywKnsPen99Srtpj3dy4GaETv/cXt9uhqdtfvDLLYtC84fM7/OKnnI4MNx0yfRLnXfeh/rKsmY26UMc7Lhbp7xBJ8L16TqOb0RJlX1vspW+YGx/4fx/03R+XVc6fPc5N/uJdKbP05duHfOj9KXD9W86sa6Ut3693S9v0JnEe/fH05nstijwo01JTSAEcZ1JlXEsmBpIAgkHXXg650frS1PbfpS+dHPW5+tL9spHqXcWM0TM10CMIEoQmHVuR+acD/2QS6GYVJ0yoGYcTGmRV2TlpKPANGa9DaaFEIT5rAVTdNKTECS4nuQ6ZiEXq3bXZHNN43KSQ4fVwq8+rZkSXh0nt/ut9umPK4t+GhLoB8A7tz6/Xb97BWb2d4C/AxBO3zX5m3+M4s8QMVKsakpoEDVERwKgVTENasRg5P0z1vPIO0/uYhSuN3uev7zkyk5RGsz8BQVJEMNhYAZTmhBpYyL31/TDhvvnJ5RhYLcb0NDQM8MsIDFi5otbMBjHfHjxNhnAAaiDyUqpvFAAFFI8HK+qiMRKNAkWhCa1lFIQiU4MATFGrCgUpeQeCVavURhLJgTBVNFS/B7iC+SkgKcQyENPqAunmFFKIRwmmBKCK+3FFI0tMbaH31PToRoYcvZ7h4CZEApY8ecUEXIQCIIQfTLl7M8QpZIeCigxBaIECFIXgXDTh+rPq3VWaRmQkkn5gsiOf/Nv/S6LReDF8x/y4MEj5rMVMTakZsEnz674X/7eP+R6L8jyIRoSJg2lFFKdKGaGqRzGgojVRb/2aTCCcWO8HGSiHI6bPqeUkABlMMyEIsXfa7Q6LjNJM6FckXTL+3cjJyctd847+v2G9XrN2dkZs/mC/ZD5wUcvePr0JT/8wQXDaPRDps8jJG+9xESKM8ogiDQsZ6c0JIbdS0rZcP8ePHxwwmqZ2W0uuHf3lHvnd5jNFuScGYhAYJeVlxcbPnn6gg9/+Ay1wGY3omrE4POkaRpiFEopFA30Y6CoIGmBxQhhgUmL4oaeiPefloKpMP6j//pLC42fEX5iWSSrxyZ/9T/9zIXMBGThAjYUwCA1kBpSghiFsQhqkdQkJBqBPUUzWRNGR9O2lHF0gy34XMhqQCSQEIXCHsaBLiQCQjYhmxKbDrOC5ez9P7GGajeGGt4sKYo4BYBVeSLBFXKNwefyYWE0sMCB4BLDNWO/flDw1bPK3bpYTQr5QUHGKGEEzOUNQskjmMuO2s8HZS4goAkMQnDi3nUgAfXzutBADBQtYEokMpqSQ/RrlFKfvYAYWvza4P3bNbNqNAQUUJ3ICfM1JQiKQgBVfy8zjcxSYtg9ZzY3Hry1omlbdtd7rq63PHuxB5nTzB9SJFEsgBhRwLRgQVGDQKAUI0r0vtCMaanPG/0diFEM74vp3RFu5NItWaXgBEqZ1Jza7mCE2DjfOO5opBB0RxlHujCQ4sB795fcPVsxXF+z226xlFAJfPsvPgSZ0SxOSM2cbV8YxoJqQEKibRMxOKkQafDGFsZy5Ws0ipYeyVecnS74zfcfUsY9w3CBmJGaGe1sTp+FFxcbvvOXn6ISCc2KEFqQB
rMAkrAY/Rk1YyoEWmJoGPIIYoQ0A2nQ0gIDKQXsloKac/6MMtb/7//Fdz8zkX9x+FJ6kf3NP4YgRLLrBJJAjWgZE59OUQwZn3PvfMa9OwvUhG9/5yNKc4ZaQwgRYkK1oGMmEmhQNO/oWiFhjFm5Lh1qCYnBzVt14tMNd8GkGiklV10nMuk3iI9jNXHnjUxrvMu5IMHlihawjKlSbOcyRKsOY5Xcs+IEQvG1VaAq6tO7vdGh3BViFHNiNuCyQSQiIaHqzhFpnERUdX0jKKCGVN1Jg7iIiw0UrW2txmuoOoHVeQeEJMRqEFgl0EQm4vLGMLDJ02IKZSTqjlD2PDyHd57c4f7dls3mkpQCTx6/Q4wNMTSEGLncFD5+esGf/ON/wWZvlPYciz5v1Fy/OhC/BmqTc9Tb4j3j/eVGt7txmOTgzZi7RZjczKmU0kFnraLL+0gMEyNEAS2IZdhfkkLhwSJzbyGs5pm2C0gYGcae+/cfsD45I7Uu8/sBxkH56AfPub7csN8XDBg0k1WJTYMaZA2
"text/plain": [
"<Figure size 1440x576 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 参考代码:\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/imaug/text_image_aug/augment.py\n",
"import cv2\n",
"from ppocr.data.imaug.rec_img_aug import tia_distort, tia_stretch, tia_perspective\n",
"img = cv2.imread(\"./doc/imgs_words/ch/word_1.jpg\")\n",
"\n",
"img_out1 = tia_distort(img, 2.5)\n",
"img_out2 = tia_stretch(img, 3)\n",
"img_out3 = tia_perspective(img)\n",
"plt.figure(figsize=(20, 8))\n",
"plt.subplot(1,4,1)\n",
"plt.imshow(img[:,:,::-1])\n",
"plt.subplot(1,4,2)\n",
"plt.imshow(img_out1[:,:,::-1])\n",
"plt.subplot(1,4,3)\n",
"plt.imshow(img_out2[:,:,::-1])\n",
"plt.subplot(1,4,4)\n",
"plt.imshow(img_out3[:,:,::-1])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 2.3.3 学习率策略和正则化\n",
"\n",
"在识别模型训练中学习率下降策略与文本检测相同也使用了Cosine+Warmup的学习率策略。\n",
"\n",
"正则化是一种广泛使用的避免过度拟合的方法一般包含L1正则化和L2正则化。在大多数使用场景中我们都使用L2正则化。它主要的原理就是计算网络中权重的L2范数添加到损失函数中。在L2正则化的帮助下网络的权重趋向于选择一个较小的值最终整个网络中的参数趋向于0从而缓解模型的过拟合问题提高了模型的泛化性能。\n",
"\n",
"我们实验发现对于文本识别L2正则化对识别准确率有很大的影响。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/44bcc2e797114171aaf584e4d7c79e19cf4ac1fc3e254ad7bdb5cbc01463c790\" width = \"800\" />\n",
"</div>\n",
"<center>CRNN识别模型消融实验</center>"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 2.3.4 特征图降采样策略\n",
"\n",
"我们在做检测、分割、OCR等下游视觉任务时骨干网络一般都是使用的图像分类任务中的骨干网络它的输入分辨率一般设置为224x224降采样时一般宽度和高度会同时降采样。\n",
"\n",
"但是对于文本识别任务来说由于输入图像一般是32x100长宽比非常不平衡此时对宽度和高度同时降采样会导致特征损失严重因此图像分类任务中的骨干网络应用到文本识别任务中需要进行特征图降采样方面的适配**如果大家自己换骨干网络的话,这里也需要注意一下**)。\n",
"\n",
"在PaddleOCR中CRNN中文文本识别模型设置的输入图像的高度和宽度设置为32和320。原始MobileNetV3来自分类模型如前文所述需要调整降采样的步长适配文本图像输入分辨率。具体地为了保留更多的水平信息我们将下采样特征图的步长从 **2,2** 修改为 **2,1** ,第一次下采样除外。最终如下图所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/330608b3e6114eb49172b83864bd1caf160a46be759349f3a89ee08e583e9f46\" width = \"800\" />\n",
"</div>\n",
"<center>降采样步长策略优化可视化</center>\n",
"\n",
"为了保留更多的垂直信息,我们进一步将第二次下采样特征图的步长从 **2,1** 修改为 **1,1**。因此第二个下采样特征图的步长s2会显著影响整个特征图的分辨率和文本识别器的准确性。在PP-OCR中s2被设置为1,1可以获得更好的性能。同时由于水平的分辨率增加CPU的推理时间从`11.84ms` 增加到 `12.96ms`。\n",
"\n",
"下面给出了stride优化前后的特征图尺度对比。虽然最终输出特征图尺度相同但是stride从(2,1)修改为(1,1)之后,特征信息在编码的过程中被保留得更为完整。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 288, 1, 80]\n",
"[1, 288, 1, 80]\n"
]
}
],
"source": [
"# 参考代码:\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/rec_mobilenet_v3.py\n",
"from ppocr.modeling.backbones.rec_mobilenet_v3 import MobileNetV3\n",
"\n",
"mv3_ori = MobileNetV3(model_name=\"small\", scale=0.5, small_stride=[2,2,2,2])\n",
"mv3_new = MobileNetV3(model_name=\"small\", scale=0.5, small_stride=[1,2,2,2])\n",
"\n",
"x = paddle.rand([1, 3, 32, 320])\n",
"\n",
"y_ori = mv3_ori(x)\n",
"y_new = mv3_new(x)\n",
"\n",
"print(y_ori.shape)\n",
"print(y_new.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"\n",
"### 2.3.5 PACT 在线量化策略\n",
"\n",
"我们采用与方向分类器量化类似的方案来减小文本识别器的模型大小。由于LSTM量化的复杂性PP-OCR中没有对LSTM进行量化。使用该量化策略之后模型大小减小`67.4%`、预测速度加速`8%`、准确率提升`1.6%`,量化可以减少模型冗余,增强模型的表达能力。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/0840afd4a40c4cc39352287c18d6de57d47fe43587804c1baadb924894881397\" width = \"600\" />\n",
"</div>\n",
"<center>模型量化消融实验</center>\n",
"\n",
"### 2.3.6 文字识别预训练模型\n",
"\n",
"使用合适的预训练模型可以加快模型的收敛速度。在真实场景中用于文本识别的数据通常是有限的。PP-OCR中我们合成了千万级别的数据对模型进行训练之后再基于该模型在真实数据上微调最终识别准确率从从`65.81%`提升到`69%`。\n",
"\n",
"### 2.3.7 文本识别配置说明\n",
"\n",
"下面给出CRNN的训练配置简要说明完整的配置文件可以参考[rec_chinese_lite_train_v2.0.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)。\n",
"\n",
"\n",
"```yaml\n",
"Optimizer:\n",
" name: Adam\n",
" beta1: 0.9\n",
" beta2: 0.999\n",
" lr:\n",
" name: Cosine # 配置Cosine 学习率下降策略\n",
" learning_rate: 0.001 \n",
" warmup_epoch: 5 # 配置预热学习率\n",
" regularizer: \n",
" name: 'L2' # 配置L2正则\n",
" factor: 0.00001\n",
"\n",
"Architecture:\n",
" model_type: rec\n",
" algorithm: CRNN\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV3 # 配置Backbone\n",
" scale: 0.5\n",
" model_name: small\n",
" small_stride: [1, 2, 2, 2] # 配置下采样的stride\n",
" Neck:\n",
" name: SequenceEncoder\n",
" encoder_type: rnn\n",
" hidden_size: 48 # 配置最后一层全连接层的维度\n",
" Head:\n",
" name: CTCHead\n",
" fc_decay: 0.00001\n",
" \n",
" Train:\n",
" dataset:\n",
" name: SimpleDataSet\n",
" data_dir: ./train_data/\n",
" label_file_list: [\"./train_data/train_list.txt\"]\n",
" transforms:\n",
" - DecodeImage: # load image\n",
" img_mode: BGR\n",
" channel_first: False\n",
" - RecAug: # 配置数据增强BDA和TIATIA默认使用\n",
" - CTCLabelEncode: # Class handling label\n",
" - RecResizeImg:\n",
" image_shape: [3, 32, 320]\n",
" - KeepKeys:\n",
" keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order\n",
" loader:\n",
" shuffle: True\n",
" batch_size_per_card: 256\n",
" drop_last: True\n",
" num_workers: 8\n",
" ```\n",
"\n",
"### 2.3.8 识别优化小结\n",
"\n",
"在模型体积方面PP-OCR使用轻量级骨干网络、序列维度裁剪、模型量化的策略将模型大小从4.5M减小至1.6M。在精度方面使用TIA数据增强、Cosine-warmup学习率策略、L2正则、特征图分辨率改进、预训练模型等优化策略最终在验证集上提升`15.4%`。\n",
"\n",
"PP-OCR中部分识别效果如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/8da2664561ac47d2b3db3a9957fb81a39d5b8b06f85346a48fe4ac9fd6ff70b3\" width = \"600\" />\n",
"</div>\n",
"\n",
"文本识别模型的代码演示如下。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAABnCAYAAAAQVrnSAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvVmsZFd2JbbOnWN878WbX76cmAOTU1WxBqqqVIRUkgxIdsvdH0bDAxptoIH6MmADNtxqf/uj/WO7v2wIbgMyYEAtlwfJQBtGq1Sl6q6SWCSTZJI5MOfhzWO8mO94/HHO2WffyMhiUsXOhonYAMHIeDfuPfeMe6+99t5CSompTGUqU5nK///F+TfdgKlMZSpTmcoXI9MNfSpTmcpUviQy3dCnMpWpTOVLItMNfSpTmcpUviQy3dCnMpWpTOVLItMNfSpTmcpUviTyK23oQojfFUJ8KoS4K4T4gy+qUVOZylSmMpXPL+JvykMXQrgAbgP4twBsAHgXwH8gpbzxxTVvKlOZylSm8rzyq2jobwG4K6W8L6VMAPwxgL/9xTRrKlOZylSm8nnF+xV+ewrAE/bvDQC/9st+EIahdF1XPdjzEAQBAEAIgTRNAQBJkiAMIwCA6zowFsRwOILnqeaa3wFAEsfIiwIA4Ps+fN8HAGRZhiSJ9X1cBEGonwXEsfq+KCTd03EcZHmmvs9z+L56RrVageOoNud5jiRJ6LN5lyzPkervHddBpNufZRmyLKPPtVqN2hPHMbUjCEN6pyzP6XtZFAijSL+phCxUX8RJjDBU7xP4AQqp3j+OY+R5DgDqWdr4yvIMw/5AfV+vw3Ec6vder6f6zvPoWVIWyLJc9/sAzWaT+tuMR7/fp76OogiFHoPRcETtqdfrdJ8sTRHrPqpUKvRbABgOBij0favVKrUvSRLqiyiqQAhQu6mPJFCpRHo8i1J/mz71PI/6JUlSuo/n+fQ5S1P6rZpHdmmYZyVpSvcMfJ/ebRSPICD0u0U0L0ajEb2X73v0znEcw/PU52ajCU9/L4RAlmXo6zFJkoT6Nc8z6hff9/XT1NgWel44jlOaz6bdWZ6hUqnqZwB5bucL9ZHrQuoJk6YZPdf3fbj6uUVh57/n+fA8l/rI9J2UEq5pgxD0/mmS0Pee59G7DAcDmEHwfZ/an6Yp4niEKDJ7gWvnRWz7xfM86u+iKKh9EAIBzTGBTO8vhZQ0Do7j0L6T5xmNieu5tHbiJKa+Vn0l6PqUzRfTBkcIjGhuStoLIASyLKW+8nW/u47d48weop7l076TZSna7c6BlHIRnyG/yob+XCKE+AGAHwBqIb/yyisAgIWFBZw6dQqAevG9vT0AwNbWDs6fPw8AmJmZofvcu/uANrFms0kDv7+/j263CwBYW1/F/Pw8fb+xsQEAmJ+fp2cJIej7JElw5swZAGoD3NnZAQBsbGzQfb71rW9heXkVANBqzWNzYxsA8O5772A0GgEA+v0e2u02AGB5eRGnT58GAHQ6HTx69EhfM8T3v/99AMDi4hI++ugj/NXP3wEAvPG1r2L11BoAoN3p4NHjBwDUBL1y5Yr+fR/dbgcAsL29jYsXLwIAlpaWqB3379/HYDCgdpvJfbi3j1s3bgEA3n777dICeOcd1Ya1tTWcf+kcALWYzEb/+PFjfP3rXwMA5LKgg+G9q+/j1Kpq88WLF2lBX7t2jTaS73//+xiNVBvu3buHW7du0fWLC0sA1EK9ffs2LdbXX3+dxnZzcxs3b94EALz17V+jdidJggcPVB/5vo+vfvWrANRCuX//PgDg6OgIL730EgBgdnaW3ufJkyc0j1ZXV6Efi8PdXbTbRwCARqOB2dlZAGoRm3mxvbdL86VeqdI9j4+Pqc3r6+v0+caNG/Re66fXMDurDsa7d++jNbcAAPjN3/4dLC6qdVqJqmi32/jzP/9zfd1dSK1k9HodtObmdLuXkWcJPbvT7QMA5ubmUK/XYWR3X62pwWCASy+/DEBtYv3+UI/JHZqrtVqN2rq1tYXDw0MAwOXLlxHQJhtja2uL+q41p9anlBK7u7sA1Jwy7xNFEc2FnZ0d1BsNAGo9VqtV6qNMKwArKyuYm1P9vre3hwf37uFl3e5ms0m/efDgAa35laVlrK2donc27UvTFM2GupeQDo6PjwEAw3hECkoURTSGnU4Hzabqu+bsDIQ+ZHZ2dujwWFpYpD5qt9vYPzwAACyuLKPVUs+SmcTjx4/pmldeeQ2AUm5MG3b29jA7p9qwuLiI0UitqcO9fQh9kqwsLaNWrwAAjo6P8X/9n//PIzyH/CqQyyaA0+zf6/q7kkgp/1BK+U0p5TfNQprKVKYylal88fKraOjvArgkhDgPtZH/+wD+w1/2AyklaVnc5C6Kgk5BIQSZILVajTSl+fl5Oh0dYU22+fl5dDpKc23OzpAJub9/CCnVKeu6PszZNRrFSBJjHgrox2pNUl0TRVX4vjp80jTHycmJ/pzh6FidykdHR6QZD4cDDIdDar8x7/v9PmkAlUoNc6RhraLT6eD27dv6/TO612g0IjOQm6Cu65bMYKP5DAaDEgxk3t9xHLo+yzLq9ygISPsYjUZ0ovuuizxRz5V5jpcvXgIAfO2Nr2BhoQUAaHc7aB8qLXZwZYCZutK4giCgZ83MzNA7x3EMIRx6ljnQl5eX6b12d3fR7XZJq8uyjPovDH24roEyQnrGyUkfg4Hp1wp9dl0XUmqIJ0vot0WRlb6vVpUZ7HmONZUZXGHaoX5rzXhHqv/Mc1st1S9Xrlwhjbbf79P85WMGqLmkOtia2fxZSZJgNBrR+8dxjDS284L/xjy70WiQhl6vN0uQo+P69Nt6RWmfUkpIqd55fn4RjYbSsoMgIA04zyWaTbXu1tfXaU71ej16TyEECmnACDvnDWxk2mD6rigKBm9Y6IaLlJLgICkBKRxkGk7KCkmf41GCfk9ptWkrx1xLWTunT5/GV776pu7rFFGk3idLUlrzw+GwNLamTb1+H0IPbhiGDLps0DUzMzOoVCo0Nqf1PHd9H80Z1b/pKEUQqPm1v7+PMFDXJ1mB4xNlVXQ6PdoLIB2aF71eD3OzM9R3rmP3yOeVv/GGLqXMhBD/CYD/F4AL4H+WUl7/jN+AY+hmY+HYp/nOXGM29Gq1Sr+FtIbF6uoqdb7juWSK3bt3jwaCS5qm9nqGXyVJUppkZuCazSYtnsFgSCbd8fExLdYkiUubgfm+3++joc3Mc+deovvU63W88sorBNPcf/iQDoSTkxOCTebm5uwhxjZoIUQJfuj31YIeDoeEOSZJQu/GMcs8z2nR882GY5Gu61JbV1ZWIITFt1cWFVSyvLxMYyWEQG+o2txqtehdarUaTdaXX34ZFy5cAKAWhu2Xc1hbPUXvUKlUyCReWVmhcZifny9tdGbzqdfr1PfD4ZAO96Io6J35RpimKb23lBKOI+i35pCsVCp0+CRJQu+jNkNJ13z729+mtv3sZz8DoCAEI0EQUJullHR/02fjY2PejePD/W6if
2/HPIoivPG6gpmazSaGcUL3dITqi0JmSBON8Scx9WNRFJDsna0fS9JGNBj02Ebv0W/39nawvamM8NFoRPNICIFeV41/tWZ8PmaDzunfvN/NZ9d1kedPM+3MAWv6m49hFEVMKUnoGc1mE7Vag97N97TfTD8HUPPfXM8VnSzPSwqT2dwFXEDo772QPg8HMYYjNWfjNIfv670ptz6Xfr9Pn3vMB9jr9ZBlhe7rEUZDrTA4HkZ6LCXzK34eJuKvhKFLKf85gH/+q9xjKlOZylSm8sXIv3an6PMK19DN6es4DmkBHH/nGo3nW7M5L0DaLdc++b9HoxFp4pxtwU/uPM/peTMzM1hZWQEAdLs9+t51HfJaF0VB33OtREqJ1VXlUD137hw5rNI0Qa1Wxa/9miIF1RoN3Lj1KQDg6OAAI+14jJaW4WnIQmY5MnN6Zzmq2nvuOy6kZlsgL+Brz3g8GNIJPxgMSMtK05Q0Bf7OUkqgsFCB+Rz6ASS0OQ0LldWqVcvAKArMaOfdqVOnSCtxXReVSo2
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-12-24 21:50:31-- https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar\n",
"Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a\n",
"Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 8875520 (8.5M) [application/x-tar]\n",
"Saving to: ch_PP-OCRv2_rec_infer.tar\n",
"\n",
"ch_PP-OCRv2_rec_inf 100%[===================>] 8.46M 24.2MB/s in 0.4s \n",
"\n",
"2021-12-24 21:50:31 (24.2 MB/s) - ch_PP-OCRv2_rec_infer.tar saved [8875520/8875520]\n",
"\n",
"[2021/12/24 21:50:33] root INFO: Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9409585)\n"
]
}
],
"source": [
"# 可视化原图\r\n",
"img = cv2.imread(\"./doc/imgs_words/ch/word_4.jpg\")\r\n",
"plt.imshow(img[..., ::-1])\r\n",
"plt.show()\r\n",
"\r\n",
"!cd inference && wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar -O ch_PP-OCRv2_rec_infer.tar && tar -xf ch_PP-OCRv2_rec_infer.tar\r\n",
"!python tools/infer/predict_rec.py --image_dir=\"./doc/imgs_words/ch/word_4.jpg\" --rec_model_dir=\"./inference/ch_PP-OCRv2_rec_infer\" --use_gpu=False"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAABrCAYAAABnlHmpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvUezbVly3/fLtdY2x17z/Kt6ZdBggy2FEKGABH0CRWimqaQPwJGGGvCzcKCxPgEjNJYiNOBIokCRIECgu6v7lXnmunO2WSY1yLXPvdVVABsE2JAQNztu9Ttum2XS/POfuUVVeZRHeZRHeZT//4v7+76AR3mUR3mUR/m7kUeF/iiP8iiP8g9EHhX6ozzKozzKPxB5VOiP8iiP8ij/QORRoT/KozzKo/wDkUeF/iiP8iiP8g9E/lYKXUT+GxH5NyLyZyLyT/+uLupRHuVRHuVR/uYi/6E8dBHxwJ8C/zXwFfAvgP9eVf/V393lPcqjPMqjPMpvK38bD/2PgT9T1X+nqjPwvwD/7d/NZT3KozzKozzK31T+Ngr9E+CXD15/Vd97lEd5lEd5lL8HCf+xTyAi/wT4JwC49o9k/QwDeeT+Ow9enj6T5bMHkJAsXxIQQUtGnIAWVBURUFVQEOfr9+8PrCiCoPenvpfvIU9iP1W1n6uCKlLPr2qHVbVzI3Y9y1dPB1vgrIc3eDrfj0BdfxX69ZvXKw8G6wefPXz/Rz5UffCFHz/hw0+WI+iPHvOvP8733v/eAX/4G1l8izouKmrz/BvnkPrf01yiP7jL762vh+Msv3HND65JHv5D5YdDd5pLefC75R/3Z5T6XVuHhabxtG1DaDylZEpRSiloUXIupFwoBRCHiK1ZFXkw4MuaredQu+ffuKnv389vyo+td+p9/uB7y9p9+P6D7+nD+7arW/aJ3b/W/VhO49E0QvCO4B05JUoptp+cwzkHIuSipJRJOdcZfehryv00Cqd9eP/pg8+x9S1SV4cWQHGAOGibgPceLRkoeO9xThBx9ksFLUpBKfWaYsqUXOwAdgF1fO7XyW8OvekCefCqXndd/7qMzqIvlqX34DYerm09/Oqdqj7j3yN/G4X+K+DNg9ef1ve+J6r6z4B/BuDP3mjzX/1PRHWUul4FaFzBS11frkVcqHs5Q8k4CkVsWkQ8JYMTxZGYpzucjjQB9mcbRITjMDGMkZQKElZIu0WlsU2Gw3uPiFCwjZdzRkuBYmrFIbRNIKURSmIeB4LztI0nxQIIoelQ3zKmYjq9bQgoOUcotqByzkg1BM4528jLZnecDIXNniJl2bTOvrcsXufA3S8KVbXrdYILAXGuHkrx3pNztu+6YJun3jNAKQWh2EYqmVzS6XcAOSVwDsGuddnMTsQ2ST2/KSUzZo4CJVejKgR3vxlVlTTHOvaKc3YtqrZpVE4jbveaC0UzOMW7hlIU7xtKiRRNVR172/JajbqCR2yNCGQR8AFKHd+cEAXnBZyiTpCi93NBofHhgaJwdaz9jxp/L/a5CqhmcpxBM62D1ivzcAXzwOefv+A/+dmXNKHw7v1bbm8PCJ6u69huz8jZ8/a7K95+e83HmwkNa8JqB35FKgXvGlJKhBBsL6CneS7FrkWwtewUlFzn+17p2Zwt91l/o/fr6Ef262meH4qIh1L3jCSKt8+zKiUlvAidU7xmyAfyfKApE+e7ln/0k+esO8GVxM31R9o2sNudsd1uCU3DmDK3h5mr6wM//8WvGKbCODqcBIo4pnkm5oJvWrquO82JiMdJgOJAWzw2DqLR1ksZEB3pe3hyvuLsrOPp5Y6cRqbxwLOLPefne7quw7lAcZ45ZaZxJiY4zpF3H6759dvv+Hh1y5Q8WT0pFXJRGxPqXsXhg+CcjaFmgB5wFBziPb5p8E1LKQ1ZHYitb/M7C1BOc6C52F6tczH/H//05z9ciT+Uv41C/xfAPxKRLzFF/t8B/8Nf9wNFSBJOysn+3/wNrQpF8RSHKTddLJRDTtawPDACCZWZ3brh1csLgnMoGXmy4/b2jg9XN9xNA0pnE4BDpCAF1AlOhKIFh5IVHBkBvPNInvE5Ms0H+saz3/XkeeYwDzb5RUk5gQTAoTGRJJvtdYJTu6+yeJy5nKy5c87eP0Uh3CvwZbMJCAW8vXlyytWUsYopRXJBc7H7cQ53sv7OjuUcIoL3guaCdwJZTAmVgne2gTUnnCsUkh2jKuxF4ZlDYueAcpobQdA4m8Ksn0vhZChzzjQ+nJSnqCLqqjYs4B1KwPmC14A6M4LOOZxXijhyiqjLuGWNqKBl8SSrYqme7CLi3IPIzYPmqt8FMe/g3p6a+TIz4cRuT+4jOrt1QasSxdVNXPTkWUtRM/7TAU/k+YsdP/3yOa0b+PjuWwKZP/zZ7+Mk4H1D03QonvOzPX37Dfnnb7kbIxonRBqcs62pztdI8EG0uXjtuvhwQhH53rWxGF8piAoi5XsK/K9S5qfxW+bSQcmK5mpMqPdc16sXG7MgICVT8oij0Hllv3JcbjyNjHgc03wgNMrZ+Zbz8x1934M43Gxee0oNL56dMw6Zb765pZSZmAquRFoRhEyaBptPJwTfo0TyrHTtnsa3NARKUlKeIWfaBvarhsuzNZcXa+bphuFwTdcEznYr+jagpTDPR5K5OnjncK2gLrDb9Fycb3Gi3B4SKRUmZ967OHCorSmBEDzeizkeRYgpE3Mh1/Um2ZGZUQrOtbbiREAKahYA84UUnOJwP2pc/zr5D1boqppE5H8E/lfAA/+zqv7JX/sbBG3M6jrnEDGvTbNQnFvUhm0iBwFPIx4RZZgnmhDIOeMdNC5yd/WOT15e8Pmb52xXgfFwhzhhu91ys3a0TLy/Hvk43SJ0eN+AtGiZbYE6G1B/CjULATFPez4QiHz5+gnPnlzgpDAMB1TPORwnfvmrb0BbQujJ6ikqhL6xSEMLmm2xL56QKOAEkYA4peT8ADqpnrBmtIZ+ch/AghROukTMO9WieCxSATMYqoU0J3RZBA+86pITJWXEhhZPwYkphhwjOU2oRoIvJw9aqwYTp+SYENH7CIF7j7ytCpuikDmdswsCAUqZvrcKFoXjpTClBAo5NSge1I6fcYgL+GYNBdqmBe/IWckqpkQBvLPzUuzGhGqoq6oTQR0PlH0Nz8uy0mycMuDEPKllNZjNXcCYCoVwbwREC06FoB4tE6XM9C5zuV/z5ZsLzs88tx+/o2Xi8vKSJ9sdIo6cE1McaJuOp2cr9M1TshZ++faKq8NAnj2h21Yj5ev1CkHcyTCaoleLDuv/lApVyWKAH6At5YfQVN3HP6owfvO9gqJYtHf6nbN59ioIiZJnNI20rbBqhKfbnv1akXREvWMab2iaBmFGJFPKjAoEL6zXHqFn1Tzl47tbyjEyzzNTisxJUPGEtkGco4hQqpGb58RYEuSRkjO5OByZVhLSCl3r2HSeLhQaF5nygJaJftURgielhGZIJRNzplANmW9ovWPdedadYwyK27XVmTCIxjcBL646CRa1xjKj2
uB9IObAOMNxmIlJyTqRkqdINt3gPV6XyD2f9jdS9WFeHLm/Ckv7ofytMHRV/efAP/+tfyCC+mAbzTnE2YUWsk0QdYE4wRdFSqGUhJRMIwVNM5ozSCHNIy+erPnJZy/J8y1zUS73G/b7LU3TsWkCabij9YF+gG+vjuQ0IjSAx/mA0BiE4gSHQMnm2ZVEIPLq2Y5XT7c0ITGMBy7Wge3ZnuMQEZ357v0dt8cbSlRcu4Z5Ims2DzwJThqCVI9ZDPtMOaKpQFUcNiwe1cRps5TFIXOYK3m/uUzJmhdZSkachf2nUBmH8w5PQCloKieIxLtlA9uoxziRU0Kcow2wWTv6NuD9vfJbNnbwFRKp+KRiC9Bjm817TwiBNoQT3BRjNAMs7hR9GDZZqoFT2+gpMg6RaYqkKZEKBimoR1JBszDrgLi26uGVQUTf8zKdQUB+gd5rRPBgjIuU078X4wSgThFnUY5pcfeDYy8Ipyl6PUU9pEJOI8SJIDP7fcenr854+fSMNN1wd/2Oi/Mtz59ccjzcVmU
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021/12/24 21:52:00] root INFO: Predicts of ././doc/imgs_words/ch/word_1.jpg:('韩国小馆', 0.9967349)\r\n"
]
}
],
"source": [
"import cv2\n",
"# 对 ./doc/imgs_words/ch/word_1.jpg 旋转180度得到\n",
"img = cv2.imread(\"./doc/imgs_words/ch/word_1.jpg\")\n",
"\n",
"plt.imshow(img[:,:,::-1])\n",
"plt.show()\n",
"\n",
"!python tools/infer/predict_rec.py --image_dir=\"././doc/imgs_words/ch/word_1.jpg\" --rec_model_dir=\"./inference/ch_PP-OCRv2_rec_infer\" --use_gpu=False\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 3. PP-OCRv2优化策略解读\n",
"\n",
"\n",
"第2节的内容主要是对PP-OCR以及它的19个优化策略进行了详细介绍。\n",
"\n",
"相比于PP-OCR PP-OCRv2 在骨干网络、数据增广、损失函数这三个方面进行进一步优化,解决端侧预测效率较差、背景复杂以及相似字符的误识等问题,同时引入了知识蒸馏训练策略,进一步提升模型精度。具体地:\n",
"\n",
"* 检测模型优化: (1) 采用 CML 协同互学习知识蒸馏策略;(2) CopyPaste 数据增广策略;\n",
"* 识别模型优化: (1) PP-LCNet 轻量级骨干网络;(2) U-DML 改进知识蒸馏策略; (3) Enhanced CTC loss 损失函数改进。\n",
"\n",
"本节主要基于文字检测和识别模型的优化过程去解读PP-OCRv2的优化策略。\n",
"\n",
"## 3.1 文字检测模型优化详解\n",
"\n",
"文字检测模型优化过程中,采用 CML 协同互学习知识蒸馏以及 CopyPaste 数据增广策略最终将文字检测模型在大小不变的情况下Hmean从 **0.759** 提升至 **0.795**,具体消融实验如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/71c31fc78946459d9b2b0a5aeae75e9bf784399a73554bc79f8c25716ed9dcbe\" width = \"800\" />\n",
"</div>\n",
"<center>PP-OCRv2检测模型消融实验</center>\n",
"\n",
"### 3.1.1 CML知识蒸馏策略\n",
"\n",
"知识蒸馏的方法在部署中非常常用,通过使用大模型指导小模型学习的方式,在通常情况下可以使得小模型在预测耗时不变的情况下,精度得到进一步的提升,从而进一步提升实际部署的体验。\n",
"\n",
"标准的蒸馏方法是通过一个大模型作为 Teacher 模型来指导 Student 模型提升效果,而后来又发展出 DML 互学习蒸馏方法即通过两个结构相同的模型互相学习相比于前者DML 脱离了对大的 Teacher 模型的依赖,蒸馏训练的流程更加简单,模型产出效率也要更高一些。\n",
"\n",
"PP-OCRv2 文字检测模型中使用的是三个模型之间的 CML (Collaborative Mutual Learning) 协同互蒸馏方法,既包含两个相同结构的 Student 模型之间互学习,同时还引入了较大模型结构的 Teacher 模型。CML与其他蒸馏算法的对比如下所示。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/0d2eb64faebe41648631c656eeab8ba3d27e2f5617e9486b8eb26dd12bafbf43\" width = \"800\" />\n",
"</div>\n",
"<center>CML与其他知识蒸馏算法的对比</center>\n",
"\n",
"具体地文本检测任务中CML的结构框图如下所示。这里的 response maps 指的就是DBNet最后一层的概率图输出 (Probability map) 。在整个训练过程中总共包含3个损失函数。\n",
"\n",
"* GT loss\n",
"* DML loss\n",
"* Distill loss\n",
"\n",
"这里的 Teacher 模型的骨干网络为 ResNet18_vd2 个 Student 模型的骨干网络为 MobileNetV3。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/7886f63f94b84215812e24377e12e6f1655dbc4971054659ade9850355a6125f\" width = \"800\" />\n",
"</div>\n",
"<center>CML结构框图</center>\n",
"\n",
"* GT loss\n",
"\n",
"两个 Student 模型中大部分的参数都是从头初始化的,因此它们在训练的过程中需要受到 groundtruth (GT) 信息 的监督。DBNet 训练任务的 pipeline 如下所示。其输出主要包含 3 种 feature map具体如下所示。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/1506d74093f247589dc7f1adc3666c9de9075b2d2e704a668afda42a3b3c7a71\" width = \"600\" />\n",
"</div>\n",
"<center>DBNet头部结构</center>\n",
"\n",
"对这 3 种 feature map 使用不同的 loss function 进行监督,具体如下表所示。\n",
"\n",
"<center>\n",
" \n",
"| Feature map | Loss function | weight |\n",
"| :------------: | :-------------: | :------: |\n",
"| Probability map | Binary cross-entropy loss | 1.0 |\n",
"| Binary map | Dice loss | $\\alpha$ |\n",
"| Threshold map | L1 loss | $\\beta$ |\n",
"\n",
"</center>\n",
"\n",
"\n",
"最终GT loss可以表示为如下所示。\n",
"\n",
"$$ Loss_{gt}(T_{out}, gt) = l_{p}(S_{out}, gt) + \\alpha l_{b}(S_{out}, gt) + \\beta l_{t}(S_{out}, gt) $$ \n",
"\n",
"\n",
"* DML loss\n",
"\n",
"对于 2 个完全相同的 Student 模型来说因为它们的结构完全相同因此对于相同的输入应该具有相同的输出DBNet 最终输出的是概率图 (response maps),因此基于 KL 散度,计算 2 个 Student 模型的 DML loss具体计算方式如下。\n",
"\n",
"$$ Loss_{dml} = \\frac{KL(S1_{pout} || S2_{pout}) + KL(S2_{pout} || S1_{pout})}{2} $$\n",
"\n",
"其中 `KL(·|·)`是 KL 散度的计算公式,最终这种形式的 DML loss 具有对称性。\n",
"\n",
"\n",
"* Distill loss\n",
"\n",
"CML 中,引入了 Teacher 模型,来同时监督 2 个 Student 模型。PP-OCRv2 中只对特征 `Probability map` 进行蒸馏的监督。具体地,对于其中一个 Student 模型,计算方法如下所示, lp(·) 和 lb(·) 分别表示 Binary cross-entropy loss 和 Dice loss。另一个 Student 模型的 loss 计算过程完全相同。\n",
"\n",
"$$ Loss_{distill} = \\gamma l_{p}(S_{out}, f_{dila}(T_{out})) + l_{b}(S_{out}, f_{dila}(T_{out})) $$\n",
"\n",
"最终,将上述三个 loss 相加,就得到了用于 CML 训练的损失函数。\n",
"\n",
"\n",
"检测配置文件为[ch_PP-OCRv2_det_cml.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml),蒸馏结构部分的配置和部分解释如下。\n",
"\n",
"\n",
"```yaml\n",
"Architecture:\n",
" name: DistillationModel # 模型名称,这是通用的蒸馏模型表示。\n",
" algorithm: Distillation # 算法名称,\n",
" Models: # 模型,包含子网络的配置信息\n",
" Teacher: # Teacher子网络包含`pretrained`与`freeze_params`信息以及其他用于构建子网络的参数\n",
" freeze_params: true # 是否固定Teacher网络的参数\n",
" pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy # 预训练模型\n",
" return_all_feats: false # 是否返回所有的特征为True时会将backbone、neck、head等模块的输出都返回\n",
" model_type: det # 模型类别\n",
" algorithm: DB # Teacher网络的算法名称\n",
" Transform:\n",
" Backbone:\n",
" name: ResNet\n",
" layers: 18\n",
" Neck:\n",
" name: DBFPN\n",
" out_channels: 256\n",
" Head:\n",
" name: DBHead\n",
" k: 50\n",
" Student: # Student子网络\n",
" freeze_params: false\n",
" pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained\n",
" return_all_feats: false\n",
" model_type: det\n",
" algorithm: DB\n",
" Backbone:\n",
" name: MobileNetV3\n",
" scale: 0.5\n",
" model_name: large\n",
" disable_se: True\n",
" Neck:\n",
" name: DBFPN\n",
" out_channels: 96\n",
" Head:\n",
" name: DBHead\n",
" k: 50\n",
" Student2: # Student2子网络\n",
" freeze_params: false\n",
" pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained\n",
" return_all_feats: false\n",
" model_type: det\n",
" algorithm: DB\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV3\n",
" scale: 0.5\n",
" model_name: large\n",
" disable_se: True\n",
" Neck:\n",
" name: DBFPN\n",
" out_channels: 96\n",
" Head:\n",
" name: DBHead\n",
" k: 50\n",
"```\n",
"\n",
"`DistillationModel`类的实现在[distillation_model.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/architectures/distillation_model.py)文件中,`DistillationModel`类的实现与部分讲解如下。\n",
"\n",
"```python\n",
"class DistillationModel(nn.Layer):\n",
" def __init__(self, config):\n",
" \"\"\"\n",
" the module for OCR distillation.\n",
" args:\n",
" config (dict): the super parameters for module.\n",
" \"\"\"\n",
" super().__init__()\n",
" self.model_list = []\n",
" self.model_name_list = []\n",
" # 根据Models中的每个字段抽取出子网络的名称以及对应的配置\n",
" for key in config[\"Models\"]:\n",
" model_config = config[\"Models\"][key]\n",
" freeze_params = False\n",
" pretrained = None\n",
" if \"freeze_params\" in model_config:\n",
" freeze_params = model_config.pop(\"freeze_params\")\n",
" if \"pretrained\" in model_config:\n",
" pretrained = model_config.pop(\"pretrained\")\n",
" # 根据每个子网络的配置基于BaseModel生成子网络\n",
" model = BaseModel(model_config)\n",
" # 判断是否加载预训练模型\n",
" if pretrained is not None:\n",
" load_pretrained_params(model, pretrained)\n",
" # 判断是否需要固定该子网络的模型参数\n",
" if freeze_params:\n",
" for param in model.parameters():\n",
" param.trainable = False\n",
" self.model_list.append(self.add_sublayer(key, model))\n",
" self.model_name_list.append(key)\n",
"\n",
" def forward(self, x):\n",
" result_dict = dict()\n",
" for idx, model_name in enumerate(self.model_name_list):\n",
" result_dict[model_name] = self.model_list[idx](x)\n",
" return result_dict\n",
"```\n",
"\n",
"使用下面的命令,可以快速完成蒸馏模型的初始化过程。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DistillationModel(\n",
" (Teacher): BaseModel(\n",
" (backbone): ResNet(\n",
" (conv1_1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(3, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1_2): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(32, 32, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1_3): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(32, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (pool2d_max): MaxPool2D(kernel_size=3, stride=2, padding=1)\n",
" (bb_0_0): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (short): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(64, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" (bb_0_1): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(64, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" (bb_1_0): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(64, 128, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(128, 128, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (short): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(64, 128, kernel_size=[1, 1], data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" (bb_1_1): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(128, 128, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(128, 128, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" (bb_2_0): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(128, 256, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(256, 256, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (short): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(128, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" (bb_2_1): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(256, 256, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(256, 256, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" (bb_3_0): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(256, 512, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(512, 512, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (short): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(256, 512, kernel_size=[1, 1], data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" (bb_3_1): BasicBlock(\n",
" (conv0): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(512, 512, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" (conv1): ConvBNLayer(\n",
" (_pool2d_avg): AvgPool2D(kernel_size=2, stride=2, padding=0)\n",
" (_conv): Conv2D(512, 512, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (_batch_norm): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (neck): DBFPN(\n",
" (in2_conv): Conv2D(64, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (in3_conv): Conv2D(128, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (in4_conv): Conv2D(256, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (in5_conv): Conv2D(512, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (p5_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p4_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p3_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p2_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" )\n",
" (head): DBHead(\n",
" (binarize): Head(\n",
" (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
" (thresh): Head(\n",
" (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
" )\n",
" )\n",
" (Student): BaseModel(\n",
" (backbone): MobileNetV3(\n",
" (conv): ConvBNLayer(\n",
" (conv): Conv2D(3, 8, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (stage0): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[3, 3], padding=1, groups=8, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 32, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(32, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=32, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 40, kernel_size=[3, 3], padding=1, groups=40, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage1): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 40, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=40, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage2): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 120, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(120, 120, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=120, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(120, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 104, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(104, 104, kernel_size=[3, 3], padding=1, groups=104, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(104, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (3): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (4): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 240, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(240, 240, kernel_size=[3, 3], padding=1, groups=240, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(240, 56, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (5): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 336, kernel_size=[3, 3], padding=1, groups=336, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 56, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage3): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 336, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=336, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (3): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (neck): DBFPN(\n",
" (in2_conv): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in3_conv): Conv2D(24, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in4_conv): Conv2D(56, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in5_conv): Conv2D(480, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (p5_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p4_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p3_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p2_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" )\n",
" (head): DBHead(\n",
" (binarize): Head(\n",
" (conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
" (thresh): Head(\n",
" (conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
" )\n",
" )\n",
" (Student2): BaseModel(\n",
" (backbone): MobileNetV3(\n",
" (conv): ConvBNLayer(\n",
" (conv): Conv2D(3, 8, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (stage0): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[3, 3], padding=1, groups=8, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 32, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(32, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=32, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 40, kernel_size=[3, 3], padding=1, groups=40, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage1): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 40, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=40, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage2): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 120, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(120, 120, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=120, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(120, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 104, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(104, 104, kernel_size=[3, 3], padding=1, groups=104, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(104, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (3): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (4): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 240, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(240, 240, kernel_size=[3, 3], padding=1, groups=240, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(240, 56, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (5): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 336, kernel_size=[3, 3], padding=1, groups=336, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 56, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage3): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 336, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=336, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (3): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (neck): DBFPN(\n",
" (in2_conv): Conv2D(16, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in3_conv): Conv2D(24, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in4_conv): Conv2D(56, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (in5_conv): Conv2D(480, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (p5_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p4_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p3_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p2_conv): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" )\n",
" (head): DBHead(\n",
" (binarize): Head(\n",
" (conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
" (thresh): Head(\n",
" (conv1): Conv2D(96, 24, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(24, 24, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(24, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
" )\n",
" )\n",
")\n"
]
}
],
"source": [
"# 参考代码\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/architectures/__init__.py\n",
"from tools.program import load_config\n",
"from ppocr.modeling.architectures import build_model\n",
"config_path = \"./configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml\"\n",
"config = load_config(config_path)\n",
"model = build_model(config['Architecture'])\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"可以通过下面的方式快速体验CML蒸馏的训练过程。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: cannot create directory train_data: File exists\n",
"--2021-12-24 22:09:31-- https://paddleocr.bj.bcebos.com/dataset/det_data_lesson_demo.tar\n",
"Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a\n",
"Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 449465021 (429M) [application/x-tar]\n",
"Saving to: det_data_lesson_demo.tar\n",
"\n",
"det_data_lesson_dem 100%[===================>] 428.64M 53.9MB/s in 8.5s \n",
"\n",
"2021-12-24 22:09:39 (50.7 MB/s) - det_data_lesson_demo.tar saved [449465021/449465021]\n",
"\n",
"mkdir: cannot create directory pretrain_models: File exists\n",
"[2021/12/24 22:09:46] root INFO: Architecture : \n",
"[2021/12/24 22:09:46] root INFO: Models : \n",
"[2021/12/24 22:09:46] root INFO: Student : \n",
"[2021/12/24 22:09:46] root INFO: Backbone : \n",
"[2021/12/24 22:09:46] root INFO: disable_se : True\n",
"[2021/12/24 22:09:46] root INFO: model_name : large\n",
"[2021/12/24 22:09:46] root INFO: name : MobileNetV3\n",
"[2021/12/24 22:09:46] root INFO: scale : 0.5\n",
"[2021/12/24 22:09:46] root INFO: Head : \n",
"[2021/12/24 22:09:46] root INFO: k : 50\n",
"[2021/12/24 22:09:46] root INFO: name : DBHead\n",
"[2021/12/24 22:09:46] root INFO: Neck : \n",
"[2021/12/24 22:09:46] root INFO: name : DBFPN\n",
"[2021/12/24 22:09:46] root INFO: out_channels : 96\n",
"[2021/12/24 22:09:46] root INFO: algorithm : DB\n",
"[2021/12/24 22:09:46] root INFO: freeze_params : False\n",
"[2021/12/24 22:09:46] root INFO: model_type : det\n",
"[2021/12/24 22:09:46] root INFO: return_all_feats : False\n",
"[2021/12/24 22:09:46] root INFO: Student2 : \n",
"[2021/12/24 22:09:46] root INFO: Backbone : \n",
"[2021/12/24 22:09:46] root INFO: disable_se : True\n",
"[2021/12/24 22:09:46] root INFO: model_name : large\n",
"[2021/12/24 22:09:46] root INFO: name : MobileNetV3\n",
"[2021/12/24 22:09:46] root INFO: scale : 0.5\n",
"[2021/12/24 22:09:46] root INFO: Head : \n",
"[2021/12/24 22:09:46] root INFO: k : 50\n",
"[2021/12/24 22:09:46] root INFO: name : DBHead\n",
"[2021/12/24 22:09:46] root INFO: Neck : \n",
"[2021/12/24 22:09:46] root INFO: name : DBFPN\n",
"[2021/12/24 22:09:46] root INFO: out_channels : 96\n",
"[2021/12/24 22:09:46] root INFO: Transform : None\n",
"[2021/12/24 22:09:46] root INFO: algorithm : DB\n",
"[2021/12/24 22:09:46] root INFO: freeze_params : False\n",
"[2021/12/24 22:09:46] root INFO: model_type : det\n",
"[2021/12/24 22:09:46] root INFO: return_all_feats : False\n",
"[2021/12/24 22:09:46] root INFO: Teacher : \n",
"[2021/12/24 22:09:46] root INFO: Backbone : \n",
"[2021/12/24 22:09:46] root INFO: layers : 18\n",
"[2021/12/24 22:09:46] root INFO: name : ResNet\n",
"[2021/12/24 22:09:46] root INFO: Head : \n",
"[2021/12/24 22:09:46] root INFO: k : 50\n",
"[2021/12/24 22:09:46] root INFO: name : DBHead\n",
"[2021/12/24 22:09:46] root INFO: Neck : \n",
"[2021/12/24 22:09:46] root INFO: name : DBFPN\n",
"[2021/12/24 22:09:46] root INFO: out_channels : 256\n",
"[2021/12/24 22:09:46] root INFO: Transform : None\n",
"[2021/12/24 22:09:46] root INFO: algorithm : DB\n",
"[2021/12/24 22:09:46] root INFO: freeze_params : True\n",
"[2021/12/24 22:09:46] root INFO: model_type : det\n",
"[2021/12/24 22:09:46] root INFO: return_all_feats : False\n",
"[2021/12/24 22:09:46] root INFO: algorithm : Distillation\n",
"[2021/12/24 22:09:46] root INFO: model_type : det\n",
"[2021/12/24 22:09:46] root INFO: name : DistillationModel\n",
"[2021/12/24 22:09:46] root INFO: Eval : \n",
"[2021/12/24 22:09:46] root INFO: dataset : \n",
"[2021/12/24 22:09:46] root INFO: data_dir : ./det_data_lesson_demo/\n",
"[2021/12/24 22:09:46] root INFO: label_file_list : ['./det_data_lesson_demo/eval.txt']\n",
"[2021/12/24 22:09:46] root INFO: name : SimpleDataSet\n",
"[2021/12/24 22:09:46] root INFO: transforms : \n",
"[2021/12/24 22:09:46] root INFO: DecodeImage : \n",
"[2021/12/24 22:09:46] root INFO: channel_first : False\n",
"[2021/12/24 22:09:46] root INFO: img_mode : BGR\n",
"[2021/12/24 22:09:46] root INFO: DetLabelEncode : None\n",
"[2021/12/24 22:09:46] root INFO: DetResizeForTest : None\n",
"[2021/12/24 22:09:46] root INFO: NormalizeImage : \n",
"[2021/12/24 22:09:46] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/24 22:09:46] root INFO: order : hwc\n",
"[2021/12/24 22:09:46] root INFO: scale : 1./255.\n",
"[2021/12/24 22:09:46] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/24 22:09:46] root INFO: ToCHWImage : None\n",
"[2021/12/24 22:09:46] root INFO: KeepKeys : \n",
"[2021/12/24 22:09:46] root INFO: keep_keys : ['image', 'shape', 'polys', 'ignore_tags']\n",
"[2021/12/24 22:09:46] root INFO: loader : \n",
"[2021/12/24 22:09:46] root INFO: batch_size_per_card : 1\n",
"[2021/12/24 22:09:46] root INFO: drop_last : False\n",
"[2021/12/24 22:09:46] root INFO: num_workers : 0\n",
"[2021/12/24 22:09:46] root INFO: shuffle : False\n",
"[2021/12/24 22:09:46] root INFO: Global : \n",
"[2021/12/24 22:09:46] root INFO: cal_metric_during_train : False\n",
"[2021/12/24 22:09:46] root INFO: checkpoints : None\n",
"[2021/12/24 22:09:46] root INFO: debug : False\n",
"[2021/12/24 22:09:46] root INFO: distributed : False\n",
"[2021/12/24 22:09:46] root INFO: epoch_num : 1\n",
"[2021/12/24 22:09:46] root INFO: eval_batch_step : [3000, 2000]\n",
"[2021/12/24 22:09:46] root INFO: infer_img : doc/imgs_en/img_10.jpg\n",
"[2021/12/24 22:09:46] root INFO: log_smooth_window : 20\n",
"[2021/12/24 22:09:46] root INFO: pretrained_model : None\n",
"[2021/12/24 22:09:46] root INFO: print_batch_step : 2\n",
"[2021/12/24 22:09:46] root INFO: save_epoch_step : 1200\n",
"[2021/12/24 22:09:46] root INFO: save_inference_dir : None\n",
"[2021/12/24 22:09:46] root INFO: save_model_dir : ./output/ch_db_mv3/\n",
"[2021/12/24 22:09:46] root INFO: save_res_path : ./output/det_db/predicts_db.txt\n",
"[2021/12/24 22:09:46] root INFO: use_gpu : True\n",
"[2021/12/24 22:09:46] root INFO: use_visualdl : False\n",
"[2021/12/24 22:09:46] root INFO: Loss : \n",
"[2021/12/24 22:09:46] root INFO: loss_config_list : \n",
"[2021/12/24 22:09:46] root INFO: DistillationDilaDBLoss : \n",
"[2021/12/24 22:09:46] root INFO: alpha : 5\n",
"[2021/12/24 22:09:46] root INFO: balance_loss : True\n",
"[2021/12/24 22:09:46] root INFO: beta : 10\n",
"[2021/12/24 22:09:46] root INFO: key : maps\n",
"[2021/12/24 22:09:46] root INFO: main_loss_type : DiceLoss\n",
"[2021/12/24 22:09:46] root INFO: model_name_pairs : [['Student', 'Teacher'], ['Student2', 'Teacher']]\n",
"[2021/12/24 22:09:46] root INFO: ohem_ratio : 3\n",
"[2021/12/24 22:09:46] root INFO: weight : 1.0\n",
"[2021/12/24 22:09:46] root INFO: DistillationDMLLoss : \n",
"[2021/12/24 22:09:46] root INFO: key : maps\n",
"[2021/12/24 22:09:46] root INFO: maps_name : thrink_maps\n",
"[2021/12/24 22:09:46] root INFO: model_name_pairs : ['Student', 'Student2']\n",
"[2021/12/24 22:09:46] root INFO: weight : 1.0\n",
"[2021/12/24 22:09:46] root INFO: DistillationDBLoss : \n",
"[2021/12/24 22:09:46] root INFO: alpha : 5\n",
"[2021/12/24 22:09:46] root INFO: balance_loss : True\n",
"[2021/12/24 22:09:46] root INFO: beta : 10\n",
"[2021/12/24 22:09:46] root INFO: main_loss_type : DiceLoss\n",
"[2021/12/24 22:09:46] root INFO: model_name_list : ['Student', 'Student2']\n",
"[2021/12/24 22:09:46] root INFO: ohem_ratio : 3\n",
"[2021/12/24 22:09:46] root INFO: weight : 1.0\n",
"[2021/12/24 22:09:46] root INFO: name : CombinedLoss\n",
"[2021/12/24 22:09:46] root INFO: Metric : \n",
"[2021/12/24 22:09:46] root INFO: base_metric_name : DetMetric\n",
"[2021/12/24 22:09:46] root INFO: key : Student\n",
"[2021/12/24 22:09:46] root INFO: main_indicator : hmean\n",
"[2021/12/24 22:09:46] root INFO: name : DistillationMetric\n",
"[2021/12/24 22:09:46] root INFO: Optimizer : \n",
"[2021/12/24 22:09:46] root INFO: beta1 : 0.9\n",
"[2021/12/24 22:09:46] root INFO: beta2 : 0.999\n",
"[2021/12/24 22:09:46] root INFO: lr : \n",
"[2021/12/24 22:09:46] root INFO: learning_rate : 0.00025\n",
"[2021/12/24 22:09:46] root INFO: name : Cosine\n",
"[2021/12/24 22:09:46] root INFO: warmup_epoch : 2\n",
"[2021/12/24 22:09:46] root INFO: name : Adam\n",
"[2021/12/24 22:09:46] root INFO: regularizer : \n",
"[2021/12/24 22:09:46] root INFO: factor : 0\n",
"[2021/12/24 22:09:46] root INFO: name : L2\n",
"[2021/12/24 22:09:46] root INFO: PostProcess : \n",
"[2021/12/24 22:09:46] root INFO: box_thresh : 0.6\n",
"[2021/12/24 22:09:46] root INFO: max_candidates : 1000\n",
"[2021/12/24 22:09:46] root INFO: model_name : ['Student', 'Student2', 'Teacher']\n",
"[2021/12/24 22:09:46] root INFO: name : DistillationDBPostProcess\n",
"[2021/12/24 22:09:46] root INFO: thresh : 0.3\n",
"[2021/12/24 22:09:46] root INFO: unclip_ratio : 1.5\n",
"[2021/12/24 22:09:46] root INFO: Train : \n",
"[2021/12/24 22:09:46] root INFO: dataset : \n",
"[2021/12/24 22:09:46] root INFO: data_dir : ./det_data_lesson_demo/\n",
"[2021/12/24 22:09:46] root INFO: label_file_list : ['./det_data_lesson_demo/train.txt']\n",
"[2021/12/24 22:09:46] root INFO: name : SimpleDataSet\n",
"[2021/12/24 22:09:46] root INFO: ratio_list : [1.0]\n",
"[2021/12/24 22:09:46] root INFO: transforms : \n",
"[2021/12/24 22:09:46] root INFO: DecodeImage : \n",
"[2021/12/24 22:09:46] root INFO: channel_first : False\n",
"[2021/12/24 22:09:46] root INFO: img_mode : BGR\n",
"[2021/12/24 22:09:46] root INFO: DetLabelEncode : None\n",
"[2021/12/24 22:09:46] root INFO: CopyPaste : None\n",
"[2021/12/24 22:09:46] root INFO: IaaAugment : \n",
"[2021/12/24 22:09:46] root INFO: augmenter_args : \n",
"[2021/12/24 22:09:46] root INFO: args : \n",
"[2021/12/24 22:09:46] root INFO: p : 0.5\n",
"[2021/12/24 22:09:46] root INFO: type : Fliplr\n",
"[2021/12/24 22:09:46] root INFO: args : \n",
"[2021/12/24 22:09:46] root INFO: rotate : [-10, 10]\n",
"[2021/12/24 22:09:46] root INFO: type : Affine\n",
"[2021/12/24 22:09:46] root INFO: args : \n",
"[2021/12/24 22:09:46] root INFO: size : [0.5, 3]\n",
"[2021/12/24 22:09:46] root INFO: type : Resize\n",
"[2021/12/24 22:09:46] root INFO: EastRandomCropData : \n",
"[2021/12/24 22:09:46] root INFO: keep_ratio : True\n",
"[2021/12/24 22:09:46] root INFO: max_tries : 50\n",
"[2021/12/24 22:09:46] root INFO: size : [960, 960]\n",
"[2021/12/24 22:09:46] root INFO: MakeBorderMap : \n",
"[2021/12/24 22:09:46] root INFO: shrink_ratio : 0.4\n",
"[2021/12/24 22:09:46] root INFO: thresh_max : 0.7\n",
"[2021/12/24 22:09:46] root INFO: thresh_min : 0.3\n",
"[2021/12/24 22:09:46] root INFO: MakeShrinkMap : \n",
"[2021/12/24 22:09:46] root INFO: min_text_size : 8\n",
"[2021/12/24 22:09:46] root INFO: shrink_ratio : 0.4\n",
"[2021/12/24 22:09:46] root INFO: NormalizeImage : \n",
"[2021/12/24 22:09:46] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/24 22:09:46] root INFO: order : hwc\n",
"[2021/12/24 22:09:46] root INFO: scale : 1./255.\n",
"[2021/12/24 22:09:46] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/24 22:09:46] root INFO: ToCHWImage : None\n",
"[2021/12/24 22:09:46] root INFO: KeepKeys : \n",
"[2021/12/24 22:09:46] root INFO: keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']\n",
"[2021/12/24 22:09:46] root INFO: loader : \n",
"[2021/12/24 22:09:46] root INFO: batch_size_per_card : 8\n",
"[2021/12/24 22:09:46] root INFO: drop_last : False\n",
"[2021/12/24 22:09:46] root INFO: num_workers : 0\n",
"[2021/12/24 22:09:46] root INFO: shuffle : True\n",
"[2021/12/24 22:09:46] root INFO: profiler_options : None\n",
"[2021/12/24 22:09:46] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)\n",
"[2021/12/24 22:09:46] root INFO: Initialize indexs of datasets:['./det_data_lesson_demo/train.txt']\n",
"[2021/12/24 22:09:46] root INFO: Initialize indexs of datasets:['./det_data_lesson_demo/eval.txt']\n",
"W1224 22:09:46.106822 8398 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n",
"W1224 22:09:46.111670 8398 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
"[2021/12/24 22:09:50] root INFO: train from scratch\n",
"[2021/12/24 22:09:50] root INFO: train dataloader has 94 iters\n",
"[2021/12/24 22:09:50] root INFO: valid dataloader has 250 iters\n",
"[2021/12/24 22:09:50] root INFO: During the training process, after the 3000th iteration, an evaluation is run every 2000 iterations\n",
"[2021/12/24 22:09:50] root INFO: Initialize indexs of datasets:['./det_data_lesson_demo/train.txt']\n",
"[2021/12/24 22:09:59] root INFO: epoch: [1/1], iter: 2, lr: 0.000001, dila_dbloss_Student_Teacher: 1.971631, dila_dbloss_Student2_Teacher: 1.548899, loss: 22.388054, dml_thrink_maps_0: 0.166657, db_Student_loss_shrink_maps: 4.802422, db_Student_loss_threshold_maps: 3.800185, db_Student_loss_binary_maps: 0.966456, db_Student2_loss_shrink_maps: 4.827962, db_Student2_loss_threshold_maps: 3.306140, db_Student2_loss_binary_maps: 0.972999, reader_cost: 3.80953 s, batch_cost: 4.74377 s, samples: 24, ips: 2.52964\n",
"[2021/12/24 22:10:04] root INFO: epoch: [1/1], iter: 4, lr: 0.000003, dila_dbloss_Student_Teacher: 1.971631, dila_dbloss_Student2_Teacher: 1.579283, loss: 22.072165, dml_thrink_maps_0: 0.168828, db_Student_loss_shrink_maps: 4.764446, db_Student_loss_threshold_maps: 3.598955, db_Student_loss_binary_maps: 0.959983, db_Student2_loss_shrink_maps: 4.797078, db_Student2_loss_threshold_maps: 3.226031, db_Student2_loss_binary_maps: 0.967116, reader_cost: 1.46505 s, batch_cost: 2.07757 s, samples: 16, ips: 3.85066\n",
"[2021/12/24 22:10:10] root INFO: epoch: [1/1], iter: 6, lr: 0.000004, dila_dbloss_Student_Teacher: 1.971631, dila_dbloss_Student2_Teacher: 1.579283, loss: 22.026184, dml_thrink_maps_0: 0.180329, db_Student_loss_shrink_maps: 4.760996, db_Student_loss_threshold_maps: 3.598955, db_Student_loss_binary_maps: 0.954792, db_Student2_loss_shrink_maps: 4.784370, db_Student2_loss_threshold_maps: 3.226031, db_Student2_loss_binary_maps: 0.962342, reader_cost: 2.33646 s, batch_cost: 2.98103 s, samples: 16, ips: 2.68364\n",
"[2021/12/24 22:10:16] root INFO: epoch: [1/1], iter: 8, lr: 0.000005, dila_dbloss_Student_Teacher: 1.971220, dila_dbloss_Student2_Teacher: 1.580030, loss: 22.026184, dml_thrink_maps_0: 0.180329, db_Student_loss_shrink_maps: 4.760996, db_Student_loss_threshold_maps: 3.598955, db_Student_loss_binary_maps: 0.954792, db_Student2_loss_shrink_maps: 4.784370, db_Student2_loss_threshold_maps: 3.226031, db_Student2_loss_binary_maps: 0.962342, reader_cost: 2.51863 s, batch_cost: 3.17085 s, samples: 16, ips: 2.52298\n",
"[2021/12/24 22:10:22] root INFO: epoch: [1/1], iter: 10, lr: 0.000007, dila_dbloss_Student_Teacher: 1.967909, dila_dbloss_Student2_Teacher: 1.579283, loss: 21.956417, dml_thrink_maps_0: 0.182062, db_Student_loss_shrink_maps: 4.698996, db_Student_loss_threshold_maps: 3.476604, db_Student_loss_binary_maps: 0.944015, db_Student2_loss_shrink_maps: 4.730411, db_Student2_loss_threshold_maps: 3.181734, db_Student2_loss_binary_maps: 0.954959, reader_cost: 2.17333 s, batch_cost: 2.87094 s, samples: 16, ips: 2.78654\n",
"[2021/12/24 22:10:26] root INFO: epoch: [1/1], iter: 12, lr: 0.000008, dila_dbloss_Student_Teacher: 1.967516, dila_dbloss_Student2_Teacher: 1.579283, loss: 21.956417, dml_thrink_maps_0: 0.182062, db_Student_loss_shrink_maps: 4.698996, db_Student_loss_threshold_maps: 3.476604, db_Student_loss_binary_maps: 0.944015, db_Student2_loss_shrink_maps: 4.730411, db_Student2_loss_threshold_maps: 3.181734, db_Student2_loss_binary_maps: 0.954959, reader_cost: 1.36976 s, batch_cost: 2.01118 s, samples: 16, ips: 3.97776\n",
"^C\n"
]
}
],
"source": [
"# 参考代码\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/tools/train.py\n",
"os.chdir(\"/home/aistudio/PaddleOCR/\")\n",
"!mkdir train_data\n",
"!wget https://paddleocr.bj.bcebos.com/dataset/det_data_lesson_demo.tar -O det_data_lesson_demo.tar && tar -xf det_data_lesson_demo.tar && rm det_data_lesson_demo.tar\n",
"!mkdir pretrain_models && wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar && tar -xf ch_ppocr_server_v2.0_det_train.tar\n",
"!mv ch_ppocr_server_v2.0_det_train pretrain_models/ && rm ch_ppocr_server_v2.0_det_train.tar\n",
"# 训练脚本\n",
"# 注意这里只训练了一个epoch仅用于快速演示指标会很差\n",
"!python tools/train.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml \\\n",
" -o Global.pretrained_model=\"\" \\\n",
" Train.dataset.data_dir=\"./det_data_lesson_demo/\" \\\n",
" Train.dataset.label_file_list=[\"./det_data_lesson_demo/train.txt\"] \\\n",
" Train.loader.num_workers=0 \\\n",
" Eval.dataset.data_dir=\"./det_data_lesson_demo/\" \\\n",
" Eval.dataset.label_file_list=[\"./det_data_lesson_demo/eval.txt\"] \\\n",
" Eval.loader.num_workers=0 \\\n",
" Optimizer.lr.learning_rate=0.00025 \\\n",
" Global.epoch_num=1"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"\n",
"### 3.1.2 数据增广\n",
"\n",
"数据增广是提升模型泛化能力重要的手段之一CopyPaste 是一种新颖的数据增强技巧,已经在目标检测和实例分割任务中验证了有效性。利用 CopyPaste可以合成文本实例来平衡训练图像中的正负样本之间的比例。相比而言传统图像旋转、随机翻转和随机裁剪是无法做到的。\n",
"\n",
"CopyPaste 主要步骤包括:\n",
"\n",
"1. 随机选择两幅训练图像;\n",
"2. 随机尺度抖动缩放;\n",
"3. 随机水平翻转;\n",
"4. 随机选择一幅图像中的目标子集;\n",
"5. 粘贴在另一幅图像中随机的位置。\n",
"\n",
"\n",
"这样就比较好地提升了样本丰富度,同时也增加了模型对环境的鲁棒性。如下图所示,通过在左下角的图中裁剪出来的文本,随机旋转缩放之后粘贴到左上角的图像中,进一步丰富了该文本在不同背景下的多样性。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/b44c9cd5241d444c8c76010f105a3e1abf2209414922457385d8e0eb2800be2a\" width = \"1200\" />\n",
"</div>\n",
"\n",
"如果希望在模型训练中使用`CopyPaste`,只需在`Train.transforms`配置字段中添加`CopyPaste`即可,如下所示。\n",
"\n",
"```yaml\n",
"Train:\n",
" dataset:\n",
" name: SimpleDataSet\n",
" data_dir: ./train_data/icdar2015/text_localization/\n",
" label_file_list:\n",
" - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt\n",
" ratio_list: [1.0]\n",
" transforms:\n",
" - DecodeImage: # load image\n",
" img_mode: BGR\n",
" channel_first: False\n",
" - DetLabelEncode: # Class handling label\n",
" - CopyPaste: # 添加CopyPaste\n",
" - IaaAugment:\n",
" augmenter_args:\n",
" - { 'type': Fliplr, 'args': { 'p': 0.5 } }\n",
" - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }\n",
" - { 'type': Resize, 'args': { 'size': [0.5, 3] } }\n",
" - EastRandomCropData:\n",
" size: [960, 960]\n",
" max_tries: 50\n",
" keep_ratio: true\n",
" - MakeBorderMap:\n",
" shrink_ratio: 0.4\n",
" thresh_min: 0.3\n",
" thresh_max: 0.7\n",
" - MakeShrinkMap:\n",
" shrink_ratio: 0.4\n",
" min_text_size: 8\n",
" - NormalizeImage:\n",
" scale: 1./255.\n",
" mean: [0.485, 0.456, 0.406]\n",
" std: [0.229, 0.224, 0.225]\n",
" order: 'hwc'\n",
" - ToCHWImage:\n",
" - KeepKeys:\n",
" keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list\n",
" loader:\n",
" shuffle: True\n",
" drop_last: False\n",
" batch_size_per_card: 8\n",
" num_workers: 4\n",
"```\n",
"\n",
"`CopyPaste`的具体实现可以参考[copy_paste.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/imaug/copy_paste.py)。\n",
"\n",
"下面基于icdar2015检测数据集演示CopyPaste的实际运行过程。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"\n",
"os.chdir(\"/home/aistudio/PaddleOCR/\")\n",
"!unzip -oq /home/aistudio/data/data46088/icdar2015.zip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dict_keys(['img_path', 'label', 'image', 'ext_data', 'polys', 'texts', 'ignore_tags'])\n",
"./icdar2015/text_localization/icdar_c4_train_imgs/img_603.jpg\n",
"./icdar2015/text_localization/icdar_c4_train_imgs/img_233.jpg\n"
]
}
],
"source": [
"# 参考代码:\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/data/simple_dataset.py\n",
"import logging\n",
"import random\n",
"import numpy as np\n",
"\n",
"from ppocr.data.imaug import create_operators, transform\n",
"\n",
"logger = logging.basicConfig()\n",
"\n",
"# CopyPaste示例的类\n",
"class CopyPasteDemo(object):\n",
" def __init__(self, ):\n",
" self.data_dir = \"./icdar2015/text_localization/\"\n",
" self.label_file_list = \"./icdar2015/text_localization/train_icdar2015_label.txt\"\n",
" self.data_lines = self.get_image_info_list(self.label_file_list)\n",
" self.data_idx_order_list = list(range(len(self.data_lines)))\n",
" transforms = [\n",
" {\"DecodeImage\": {\"img_mode\": \"BGR\", \"channel_first\": False}},\n",
" {\"DetLabelEncode\": {}},\n",
" {\"CopyPaste\": {\"objects_paste_ratio\": 1.0}},\n",
" ]\n",
" self.ops = create_operators(transforms)\n",
" \n",
" # 选择一张图像,将其中的内容拷贝到当前图像中\n",
" def get_ext_data(self, idx):\n",
" ext_data_num = 1\n",
" ext_data = []\n",
"\n",
" load_data_ops = self.ops[:2]\n",
"\n",
" next_idx = idx\n",
"\n",
" while len(ext_data) < ext_data_num:\n",
" next_idx = (next_idx + 1) % len(self)\n",
" file_idx = self.data_idx_order_list[next_idx]\n",
" data_line = self.data_lines[file_idx]\n",
" data_line = data_line.decode('utf-8')\n",
" substr = data_line.strip(\"\\n\").split(\"\\t\")\n",
" file_name = substr[0]\n",
" label = substr[1]\n",
" img_path = os.path.join(self.data_dir, file_name)\n",
" data = {'img_path': img_path, 'label': label}\n",
" if not os.path.exists(img_path):\n",
" continue\n",
" with open(data['img_path'], 'rb') as f:\n",
" img = f.read()\n",
" data['image'] = img\n",
" data = transform(data, load_data_ops)\n",
" if data is None:\n",
" continue\n",
" ext_data.append(data)\n",
" return ext_data\n",
" \n",
" # 获取图像信息\n",
" def get_image_info_list(self, file_list):\n",
" if isinstance(file_list, str):\n",
" file_list = [file_list]\n",
" data_lines = []\n",
" for idx, file in enumerate(file_list):\n",
" with open(file, \"rb\") as f:\n",
" lines = f.readlines()\n",
" data_lines.extend(lines)\n",
" return data_lines\n",
"\n",
" # 获取DataSet中的一条数据\n",
" def __getitem__(self, idx):\n",
" file_idx = self.data_idx_order_list[idx]\n",
" data_line = self.data_lines[file_idx]\n",
" try:\n",
" data_line = data_line.decode('utf-8')\n",
" substr = data_line.strip(\"\\n\").split(\"\\t\")\n",
" file_name = substr[0]\n",
" label = substr[1]\n",
" img_path = os.path.join(self.data_dir, file_name)\n",
" data = {'img_path': img_path, 'label': label}\n",
" if not os.path.exists(img_path):\n",
" raise Exception(\"{} does not exist!\".format(img_path))\n",
" with open(data['img_path'], 'rb') as f:\n",
" img = f.read()\n",
" data['image'] = img\n",
" data['ext_data'] = self.get_ext_data(idx)\n",
" outs = transform(data, self.ops)\n",
" except Exception as e:\n",
" print(\n",
" \"When parsing line {}, error happened with msg: {}\".format(\n",
" data_line, e))\n",
" outs = None\n",
" if outs is None:\n",
" return\n",
" return outs\n",
" \n",
" def __len__(self):\n",
" return len(self.data_idx_order_list)\n",
"\n",
"copy_paste_demo = CopyPasteDemo()\n",
"\n",
"idx = 1\n",
"data1 = copy_paste_demo[idx]\n",
"print(data1.keys())\n",
"print(data1[\"img_path\"])\n",
"print(data1[\"ext_data\"][0][\"img_path\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"* 下面2张图是在CopyPaste之前的图像。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFdCAYAAAAwgXjMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvTuvJEmW5/c7ZuYeEffmq7Ie04/pGSx2lyTAB0AQoEBhAQoEqFHlQ1+JH4BfYiVqK1AjQG2xygJUqa7CJZbL4fRM93ZXd9crszJv5r0R4e5mdigcM3fzuHGzsnq6enI5cQpZN17ubm5udux//udhoqpc5CIXuchFLnKRi1zkDyvub7sBF7nIRS5ykYtc5CL/f5QLyLrIRS5ykYtc5CIX+QHkArIucpGLXOQiF7nIRX4AuYCsi1zkIhe5yEUucpEfQC4g6yIXuchFLnKRi1zkB5ALyLrIRS5ykYtc5CIX+QHkBwFZIvJfi8j/KyJ/JSL/0w9xjYtc5CIXuchFLnKRD1nkD10nS0Q88JfAfwX8BviXwH+nqv/mD3qhi1zkIhe5yEUucpEPWH4IJus/B/5KVX+hqiPwvwH/zQ9wnYtc5CIXuchFLnKRD1Z+CJD1U+Dz5v1vymcXuchFLnKRi1zkIn9nJPxtXVhE/jHwjwG22+1/9qd/+mfLl7MHU1CR5hgQ7B9SXkn5vX34wNXsB/VUgqCqHA4HQvDsdjtyzvZLVVLO5JTIOaMoquVf1vlci5tVsZfN+3rNs55YndsAWu6ptk1wzuFEEOfmz8T666QD25dy77PlgzONmJtW215/7q09ziFibVFVUsqogve+uQNdna+9c06/r29P2/cHc1WfnkfOvF7/RqS9/PLdOff5+TN8x3VV5/GmCqeP79wxdVzXNsQYQTOh66zvz/Vhc/QyL07Ofubip+Pp3HudX59e47QZcnKO9Q9kNYcfvIHfS8ZxpOu6P/h5L/JHFM02XmIkj0ckDmiaIEdEEnU8CWV+ZrXnLedm5sOzdfm+HSsKokY3lLkn4iAndIhoUhyOJG61xqw03Vm9W1+1uqXVO3UteaBL6pxanbMq7nsaeH0/qyNaPX1v4j4sja45VYmrVUXLfYiUY2ofCoInqzLFhOZyx/PjOWmIPvhmtbS3T1cRkkIuDcztlyLIA90r93pu3dstxnDa3LEsnaHAL16/faGqn56/yiI/BMj6LfCz5v2fls9Woqr/FPinAP/gH/x7+k/+yf+MqgBu9QAm6RGneATvQEQJUkBA2BooKaAATImLCNn5BRyR5s89giMwjiM///lf8vz5M/79/+AfGqBSJSeYpon94Q3jOPDtqxdM08QwDLjsmaaJaZrmcxsQy+Qcy30lUp5QzfNimUg4XQZIffiOjBdwztH7QE6Jbb/h6uqKzWZDtwkE59mURUTUfosTJJSFUHW+/2WhWd6fAw05Z3LOjONICIGcMzFlktsQQk+/27LbXvHoyVNEhN/89rc8f/4ZTjqGKSHeoUxkSXb93KGk0ie5/E12ryzgtW2PqiKa5n6sba99erow1985PJRzzpNRyjXKX9QBzpQa3s6lESSv+mN5vbSxXn/pLOvf2m8rkfX7ts1ecwOy9OQ3ubw20Oqcn9upUoB+Suzv3jLs91xdXfGzP/0JKQupgP72enZ/9swrEK4guZ6zHR/n5kz73ntrSy5jsx7bvvbo/FrEl++DAXMRxLvmezl7jrZPWlB3Th5aGz7//HN+8pOfzO2+yAcqKxx0Mo/SHqaB4Te/5PjFL9Gv/pL09iuuu5FNdwTnbAwr6BRJxxHnHH7T2/FFd6+A14MrrEPpqukKIcJWoUskl/FZYVKGX/6O/GrCyQaVLUe3AedRJzOQqIt6dG59e7I4iHzRg1W3rG47T4hYu091XlIPuLLIl9up65lkIN/TndYXzTUkF52VAcW5jlTIg6rTljafn2FGLtyfV/O1NZElgwYUR1Il+B5wfPvqwDAk3u4nEh0qjiSmp10KiLrV+aq4Mj7c/FFe1k5ngPcQJ46T8jYpCY+KRyUTNTfrvEey6fMQAkguwyLjRMv5FV9wtojgyrG+YIxNLnhCMhkFMjFPpJT4H/75//Grs512Ij8EyPqXwD8Ukb+Hgav/Fvjv33WAqprV3japAC0NCZeFTEJzxilkBw6ZF8Cu63Bix4kI4hQ0otqAjKLEIxCckFLk6dPH+CDc3d0honRdR+h6+s2Orhfu7m4Zxycchz3edUzHkZTSakA658iZmfUxxqyAxQoG2nHfUCiahSw2CaNGBEiaGYaBjJI00IcOgK4sfKJSFIRjtlW0Yu+6aL7nkyrt8d6jCDFFpklxDgbnuNZHbLc7joc7nj9/zv7uwBgzXdeRs5Akoqp4F1D1M4gxoBVQTbgToDcDX1Uku5WikFl55XuKeJnUAjPQqjZVvba7dwySQP1sBc9grSgZbRin1XG1f07erxXlu5m4lOzZttcyBWff52wg03tmtjJpLmPK+vmoyt3dHeM44sNmBsinRkVlM+t3LRAzAGPWpr23+SNIURzn71/n/rn/DJ13q/O/r5wuNCuwxXf16Hnx3s99fZEPWKQ+X4dpYi1GWGJ68QWH11/Spbe8ffMl135E0oj0BdSwjGewcf5OSC0PR8JIZSVEyR3glCyJ4BTGkfFXXxJf3rLlESl4BhE0mL6txkAFWCoGBrKcnP8BaUFRnYv19el3M3MMaF6AxjtvW3yZYxltG8Vy/tagelc7K/HgXbj3XZWMA3UoHsWBcwwTvH37lru7xJSErAGRvuhjEHW2fst5XRpm43lpey56e0rCPkWmKEwKSQNZBFXz/PjC9gngndrSALhVfwq+nl7AK/hGj3lxM5GTMG9WzkqSZEAOOIM7H5Q/OMhS1Sgi/yPwv2Or4f+iqv/3dxxFypO9yhFjYrwtNuNQoIqxPlkXdiSPgynYKczWdx1AmTqI7VzV0hfxjPEIKM8/fkbOkeNxT4zGTm02xiSF0PHkyVOePHnKOI7c3Nzw9vUNd3d3eO9t0fOeYRisLepZXIEACXGFaaKziZfNPSdqSsIBWcQeoCZjlYaJmJQwJYbeXJkqjnGMbDYbnAidGPuEQM6Kc4VOdxWoLAClBTCwsDGnrEcInqQTYxo5Hkb2h1t670Cf8Pr1K/recXeXefT4iqw2qFNVcylgllWqT7NMxIXabyetTbZMLuAHFpCqqngy5Egri8t2YSghN+yU4JtzgOL94vIs02e+Vm0LZWxVazKltGJEHGv2rW1L1ncoU9bsXJX182nZRuufXP821rkWoPX02ZYY40oZV1Yq5zTfQ5V5TlQGVExZSVBIds/JZcSBNAq5HifZ40IgabQ+sUkFwDQDRL/q55QSOGcmgHMLuHNutgPENS5wXf6euiFXgOyBRaHruvvW/EU+LNHlbxaImgiiCEcEZfjVX3D4+rfozRfcvvgVT7sjGz/OY3ZenMv4hTKPUkJCWLNYVc/pQwaA6R71glxtQY44r7iU4fUrXvz1r9kNHdf+MeQt0
XfEzoM4nJYxjoJrzquCywXUSS5GdF7dem3zfXGzTnBuYezFFR3UGoXVLVZ1a+2Hdp6oUkOtpVgti7G07pPWQ4C6+TyV0RGnK503gy7vF8PGd4wx42RDyo672wM3b47kDFN0ZAIqnXkXnBConiU14Niw2DMIjIoPguKIGKO+HwZihrucmMRIARVv/STm/QoScZJnkOtI1hPeQKqrbLoCuRjA4tiwZtezgBZ2csrMHpss9g/v8b4/8yzPyw8Sk6Wq/wL4F9/j98Q4IeJmRirnSM6g5b0rSpjZ9s6kGNHk0BTJDcjCCc5h7gvvjTaU4lIBEhS2S+YF6nA4cLd/y9OnT8ka2W6uZoDmfeD580/onD9hQKSwOpkYbcCmDCLGsqjWiVHYL2E1kFE1S4DCe+WyjiVFNZKT4MeIyMSm623gZkVTLGASqovVTut4Xy7g1KJxIgTnSMlYNc3C65sXbLYdL775ipQnco4415UHYdcyS6e656orzK3ch1Wc882kzShx1Z7mTWGs6nGLAsiEYjmlsqh
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFdCAYAAAAwgXjMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvVmsbFlynvdFrLV2Zp485w41dA09spvdZHVLTaibFtWiSEoyLFHwQNh6MEUZgm0Igh8MEIQBQwIM2AbsF3mA7Ac/0IKtB4Pgg/xgaAAJWSAtQZYNusWpOTV7qB6qurrGO52TmXuvFeGHWHtn3h7IfmCBBTujcHHrnpO5h7XXivXHH3/EFnfnbGc729nOdrazne1sf7Cmf9gXcLazne1sZzvb2c72/0U7g6yzne1sZzvb2c52trfBziDrbGc729nOdrazne1tsDPIOtvZzna2s53tbGd7G+wMss52trOd7WxnO9vZ3gY7g6yzne1sZzvb2c52trfB3haQJSI/KiK/IyKfE5G//nac42xnO9vZzna2s53tnWzyB90nS0QS8FngXwG+CvwS8Jfc/Tf/QE90trOd7WxnO9vZzvYOtreDyfrjwOfc/QvuPgI/C/zY23Ces53tbGc729nOdrZ3rL0dIOvdwFdO/v3V/rOzne1sZzvb2c52tv/fWP7DOrGI/DXgr/X//2TJBWi4NxABpP/dmDOajjCnN0UFMe8/X46J45iBSv8TRwJRXATR+FtVSKpo/BKJT4G0+IYrQhzP3XAszuNg7uAs1yUC5g0z78cCUKyBisb33XEsDs3x9talsNlc8NZbby03IqJ9CIRcMq1VECepgjfW6zVCRlVQVUQEkTiru8cY9dvxPkZmhpvhzTFv/Z4cM8f7vZgLbh4/d+/3P4+9L3d2TDBLnMfnk8UzELd55LD+HZk/4r5cr/ny2+O5nP5NiScuy630cx2vwE8fwDJfjh9dzndyAcdPzE/Kudhest1egeQ+dt5nm9DqDdD6hVk/jMXvnX6d0Po4J9XjuY8z83jpgHxTij7O6Tgifca6f9PVPjbycrx/EQdP/Rd2ck6lT9UYXfGTMetX5ZyMdlyLimBmQEITMW8doN+bOLj2cZifh5z8OblmcUQLuWwwT/E9OV7DeNhhdYe1PSr62D3Oz9V9HkvFPXF9c+jjNI9ye+ycKoporF8TRXEEQz3u1AF3WY4P81qer6uP2Dx4y/jP1vo0dFR18QMqEv5kniMyz++4FuQ0pp2f3TwvT5/BcZ3F81IQ6ec5XYVOzM3jPcjydenP6OR8cvx/X54dsX5PF5nL4987+dbpWDzmB5aVoP2ara/+uG/5Bod5KlPxk/kjJz+Vfm3zmHhfFkefILi3Pv/n8ZhH81ubPDZH4zvHa5Hjd32+2z7zJC3/7ovmsbHz5WdH79Wn/3I3shzN++e7L0RPxm0+7Mn69+N5RWI/mseRxQcJ0veaU1OJ+zFr/Vkcfa0vm9F87/O8sNhLHTDBJea7uC7nlBMfeZwXx3Ud15P7HfRjfsNTiX1gudKT56LLc1rGvPtlQY4f6/tDPK7Hn3/4hzimIH33Po5N0rjf8HNHXyJ6skEjuAgxl1v8bYJI+LFHj3avu/vT/D72doCsl4D3nvz7Pf1nj5m7/zTw0wBltfK7zzyJ+4463UdLpjYBzWRusEZsInnF4XDAxEhJGGYA4TBZgB93MIUyKCuBAVhl0GGFl8KwvU3SNUNRVkW4tbkgi1J0QMSpeoNboegVKgXziWnaMbZHNIdaK+M4Umul1fkhCYemjOMI5oiB+opWM6syUA/3uNnvuJl2iCpNDV1lBhX296/ZHQ489+wF+90IKFoyqoqmysXlwA/+8I/w3/2t/4r/65/9Iv/kH/8cn//si1xsnmNzsWK1WnFxsenOvlGtcagTYuE86jghIux3Ow43Ow6PbtjtH3KoI80r+9Fppoyjcn2T2N1M1GZMU8N13QGYATskKeYSx++ANaXERMVdMBdUnNQqiRFxYmMFUkq0EwcgIrQ+WZWEusbGZ+G3XMJDZYEyOG4tAKSNuDcco3nFXEALQoF0CQipH1/Foc2b3QRS49wO2h2JOPyRP/b9fOoH/xypPAmSkTTidqBo4eEbn6ZOb9DqSM4jSSpiB5wJb5DymrE6WjaxpB2kjaRsAWzdEUkd8DjugoqRuiOIzxjmI8MqQAEuy+IXOW6WztSPlZZ/OxVNRm4XgOLscRoqGZUtrWoAZm1ockQarfniKHIujOPI2CYuLi44HA5k6XOZxPZKML8BE9RWJFE0NfBCs4r7BGKoKvgQ9ynE80BxdfLlMzz59MeZpqewJlhx0Nhgvv7V3+LNV36Z/c0XWEtB3HAN0CSp4CSm0XEmql1QucMv/9oXcd0ChqjhTLTawQhCTgObzZZSVuzVWVHJ7cDKR9yFgyeqK2gBV1oLJ202xRirYR7PRki4h2OV7nCHYUQ1fp9SPAtVZVUKg0mMj+yQVBESKitUCrMzF+nhhwRIywoJ74GH4C54ny9TM9CES6FagpTx6ohNZHWaP1iOq5KW/8czbgPzBrdM/PCYNBoiCdW+BbjQWsMM3KSD9fieaIxPM4s1u5zjdGOMOYZtMK80ewR6QCgkMiqGWV3GMILO2LwnV/BMkvADSEOoiEOyjKkyNjDNuE2kJBQtWBNqfQRyQOQQa9rByTTPMZ5mi7+Z15IvIDLmjlnre1KMm+GLP2vNMJRUMu4Ns4aKow5mgljGBRqtz6GKUxFxRJ3kBtZ9o2aMhvmBySbMRrIMqKxoXhDyAlAN68F2Amt4q4hO3Y+F/3URWk1Iyky1rzmx5RiqimoE2LWN2CRoLgiJ5kprTjNZ7hWpuFzj1lAVaInka1p+2DHOiiTr8Ocyhk+hBADThvU502hxbr9DEqXkCn5AfERs9ocDq1WJ84hR8grv86AhJB3i3oHWGs1uEG3ktMZVlnmU1ZbnXO2akvJy/JwuSDqgmtmJcZhGUhJWQyFJ7AHxzGItmk+UkhBXRDK4MvkF5iPN9rg39vuRpGtUBv6PX/j0l/gO7O0AWb8EfFhEvosAVz8O/MTv9YVmxs20I6WJ9XbDg0c30J2aHeIzKqBJOvJPmBlTjzbdBVOwBiaBq3cVXAVL3r9vqEJJkKUEw9SZJ0RBA6Vr2mBITEIPVqyRQNZxAg/WwFqNB+QVvCDtguQZfMLcmKaJVhut3pDSdSwMrYjm2Kgs0cy4dXvLbnfg0f7QJ1WwSWjD/YAj/Pqv/TL/y9/5n3n2qSf4q3/l3+Xv/72f41d+/ctLJCAiqIK7Il5JKWFWFwdTa1wrCcY0MiXvDiSuoTaYzHGUJoKRMSnACkgdnFiPuCqJFE4Aw9wwnWgIrnDwSipO8oaSSJP2zQNE0sK0OQGyVBINRTzhnYUAQwnmCBHMvO8PDRewHtG2+aqEDqQEPLMwKICo9TnWEE/ExlExtXDKCR4+vI+7x7h5sIPN5nHt12+N1hqaYw64C6JC66F1XJeQfI56J5DWAar1cYQ5YjSJ+
zXNYI1mmewZZOpIbWbOFOeECZ0DXNfOynQQJoZYw70iGKIJbxPiiuNMrZElQNV+3IfzkcJbDx6RUiGXC3b7Qs5bXn7564gId65ijQiKdSYr/hxZMTshs1SPaxUPgGKiiCWsZfDSn3/toFPYbq94i4xQ+h+AmG/BrgaQGryE89RMKYWDxyYjksEKLi2CMQRYcRgzTsH1holGGxuG4SiVYGpNxsWX7HY7NIFIgPKkBYCUlZxzd+rBMK9W2teYLJu3qpJESK0TVjKgMiwOPzbAsX++czfifeMX1KUz8eHTVBQTJyu4KM3aQlwFBGu41c4lzCBQO4jwZaNd7ARggS7rxKou9+aWcGtLMHL8nizHEjsyQcE26HJc8Tj/TOqoaAAH60GCz1ydn7B8ApYRUVSHCIqwAIliNDUajmvMu+RGtghUXAPwReYi0frYxmWnTt7NewTMzGUACuv3kU+YtYZ5xdwDfLoiFHIaMB87+9/6ZxPfiR3ZUMd9Cg5rIZsLFQFrWBNSmhmZYJ5UHdVgazRZD3h
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"img1 = cv2.imread(data1[\"img_path\"])\n",
"img2 = cv2.imread(data1[\"ext_data\"][0][\"img_path\"])\n",
"plt.figure(figsize=(10,6))\n",
"plt.imshow(img1[:,:,::-1])\n",
"plt.show()\n",
"plt.figure(figsize=(10,6))\n",
"plt.imshow(img2[:,:,::-1])\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"* 将更新后的标注检测框画出来如下所示其中红色框是原始标注信息蓝色框是经过CopyPaste补充的标注框。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3AAAAH5CAYAAADAytzuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvXm0Ldld3/f57V11zr33jf261YMmuiUQFgZDMEOQkbBAwgECrIVJYoxBMTazgUAcMBhbTLG9FmNwjLPAgRgSO4YQBAoshMUosAmTJSQxCKEBDf1e95vfHc45VXv/8sdv76pddercd1/3a6kVzm+te+ucOlV7Hn7f37RFVdnSlra0pS1taUtb2tKWtrSlLT39yb2/C7ClLW1pS1va0pa2tKUtbWlLWzoZbQHclra0pS1taUtb2tKWtrSlLX2A0BbAbWlLW9rSlra0pS1taUtb2tIHCG0B3Ja2tKUtbWlLW9rSlra0pS19gNAWwG1pS1va0pa2tKUtbWlLW9rSBwhtAdyWtrSlLW1pS1va0pa2tKUtfYDQUwLgROS/EJE/FpG3isg/eCry2NKWtrSlLW1pS1va0pa2tKU/byR3+xw4EfHAW4CXA+8Gfhv4PFX9g7ua0Za2tKUtbWlLW9rSlra0pS39OaOnQgP3ccBbVfVtqroC/k/gs5+CfLa0pS1taUtb2tKWtrSlLW3pzxVVT0GazwLeVXx/N/Dx44dE5EuALwHY2dn5y89+9nOnU+sUhAIiw6/9x5Rm+W30kJY/lc8cQ1I+Kd3nECOr5QpQvHfs7u4RYwTUyqegGgnB7oUYUFVUFQH7nAqkMaI6LOag8prvjTWl2t85kRY1PyNFu2mqo3RXkeIPEOfseRm2mZNj2lDGX2Xy/vRLx9RF19tHU4MLgorryk4qs4jgnOueU4UYI6tVQ1VVVFWFiKQ+GeefB43A5O/lcxtuTdX5dv1Vtu1tns11vb0mvfi9a+px4aTIXlCNG4u1Xrz1/DeVSQaNosf0+NQv40LoSWfzdElk2MT9EmFjwrpe0RgJbUtV1Tjv8M5kX5rG5HHTYZAfxz98u9+7+brp+4bn14ZjsVaOl8VyzRsnNiyarD0iG8q+6f7dJlUd5HV0dERVVdR1/T7Jf0tbunNKs1MVNEBUWC0gBjQ0oIEYGoSISDGT8/TTtIpG7fYyESnWkePW0ZNZYGniv2TtcaW76SxZHa9ICrSBuGpBwYkDhSBizxZr0WRpZFMpj98L8rqem2F9r7qDuq/lOGq/Iu2TpprTyavz5JuanpHhWycl4zmHfPCmNNZHhBaNJuXd/p1ujNl+GKPlGTXzWpqrQMdPrfEdTBeqHGy5L/uMu5uJQxv0au6zqJA588R9F1lJn02ZwAaSjQ+spThGId0gdBvzGP7wp9dvXVbVZxxfoqcGwJ2IVPUHgR8E+OAPfoF+z/f8z+kXh5YdXHwOUhFxiFOc0i1mdVUBkdr5xMB7xNeDjVxVcc519/I1Ot/9DqCEwe8uMWoiQk1F2wauXbvKO97xDg4ObvH85z+fv/jhL6RpGitjCDipaNuWxWJB1JbDw32apmG5XLBqFqxWK9p2RQiKi54QeoAXQiDG2H3v/miLckaUQAgh164vvypR4mCgZNDoimHlEZyErk26tgO89+zM5njvmc/neC/U8xle+raY17UtDkoBmhTxoE66jYSUXleWov1zP5egaxPTP26Ptm2LdFLbuRmIx3uPcxX1zpyqqji1d5qgfV9evnKFP/qjt/ARH/5RnD17lhhh1UYiCtIC1oYQkViButQHoWjnmK6hK6OTvnzHkaR3yj4bt48JBG5PXXtqVqbn98oBMExL8/funbzzlvPDo3G59m6XxlrZ49r9yXZITEbZb5soz8WS1oFAxEXFueFcz+kO53ssPrvisyD44fPeEUJARGjbluXikGuXL3P69Gl2d3e55/xZ6rpG8cQB3p7KN11T+5ZrSlmvfB838W7x3ODZ4l75N05Tnaw9Oz0vFYciuCJfl9YJP3hfRcBtzru85ryG879vq27+c3KagoRt21JVtq2pKn/wB3/Avffey4MPPkgIAe89McZB+21pS3edSoFxxwMXa2lee6UlxBYfGnv+0ntpbl6lvfxufHvI4vI7aBa32JUFNQc4b3u2iqACDkFCNL5h2eCdI6riKo+raxIXbVmWa1Mn5DzZjGvF5r9TZ+xpTPXxCrMAlaBVAKc0QfEKvp7BsmH59nez2m/QZcTFipnfAxxH9Q4xCV3VSS80g+GaKpb/mGTtnuULcXIPGOxZEomxneQRyz1JVQja7xeZr/LiBgJ5o0gc8AjD/AclDXG4rjst+DmDHPZutbZP3qkwbLzfhvZ4IVufX0Q0WF+k8arijL+NwRQPaY9pQ03bwP7+gsVyxcEiEqNCVaNi/akYnyKtX8t7ig8w/r6Hz05z2/QKhJ5Pt8+rqIQIi6g0CosArSqt2pxpJQzwhacXStj4puO/vfcgsZgi0XjoAX9t5XOZD85tkvc920lxzuEUak37pLM2jmB5RKWJTdf2f+Onf/Wdm3u0p6cCwL0HeE7x/dnp3rEUY0yDLKJaaFEGIDxprrTXvACJ0VJCByYiIiD4boCVg7OUUjl6lrekbiFJ5ZKorMIKUmfM53PadoWIdKCrY0rEwEJd10SVBII8IbQ0berMtFA4XAfcyvzKcth3j2pIeeS6xcEzY+lzScZADuunyNpik8Fk8LZwGDPkcSGATDDfWjBjTpLsYTgZy4WsLOs0C3Yyyu2U+7ZtW5t45IUYXOgZOZAE7IwxDyGwd2qXqqpoGquz8w7VNGlF0Q7Q5MXCdfXOn/PVfu/belznQVuomwRG434f3yufK8enquLXpK1lH4z6o5T45jtJUtf3T0y/T4Os9U1qHbhNgtA43CSPBXEnYi7WQfCmz1NtYmPFr82hcdmdc8xmM2KMnfDAe0+ISap3G9BelmU4B9b7eFyW/NwY+I3H1vj92+W1aV6KG64Vd0N7tqlO7wvagrUtvc/pdlNGYmKKFScGKGhblke3aBYH0B4Qjw7S5xWubm1Py7yMSK81OIEA9MmST2qwTrfg0r7oFCoH3gSDEai9IKqwXMKtAxbXb+HUMfM7iK+J4nFSpaJLV58nW/Ik+ly7P26TAYBluI5OPSsUPMAozeFVT8zSTK35xz2X+cPj1uJNv5V7bOZZTkYOBkJUB+po2gbxHu8qsjD01s0jQlAOj1Y0QU3Y5x2avbSUjuc56X7iOtVYAm+mylqrG0BUh4onamSlLatWaaKyUogqhNTnTh0h8zt5P2K4J44FHcVISdrM0ootgbdUR0m/iUgP6jr+sd9XVZWgiqZpHBPIjapEPZnwHp4aAPfbwIeIyCMYcPsbwN887oUSwNh306L1N9Li5AUVxUVJAi5rjbZdGXJ2GQULro2oeMQpdV3jXd0xZZLgs+BBTdOSUflY0pGmpS0LSSxUVRXnzp1hPq/xlXBwcECMEe9tQFS7FZV3zOa7qCrzeZ0AhzF9i+UhIcwIr
dIsVobEEyM+JcXOAKQHDVUB5jKIKIDfxFpgg2fI2CpKTBqRhn4gqipNaHEhTZ7KE7TCe48X1wFQl9ON/UCNApI1Oh2TnEEf3dXKdNyoOBmV2gSNShNbVk3AVzOcMzONdm8PVKh259R1zeLogKpyXLhwgdVqhYhn2QTTqCRtW5DWgJFWkGB+ltwBvclstzAGvHOMF0tX3OvBX8AA+bqkLvepS5K3k2jAhouQJ4lH6QeCH6Sjur6AW15pDqoHwkC6mOfOnTDDk1K10f1jAdyJtnSbG+X6MdDoFOXPRc99l3/3HihAj4gQNHbMUtbYVFXFcrmkbVvOnN7DOYev5kAv+DhOC5W1fJ1QqAA0a5/RwdXSSRtD3kh0uoXW1jDVjlkat/06qMob/Ho9nizdDrhaG52s1++ESk3jlrb0PqFSWNoNc0cEHGblgQQkHiHOwWKf5vo1jq5fpGoOiQdXWN66gm9vUDuQsMJX5EWgWxuAdQB3J/P1hMy88Q9C2uxNuOaUWJMYByW4iKha9k3L6p0XWd0
"text/plain": [
"<Figure size 1080x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import json\n",
"infos = copy_paste_demo.data_lines[idx]\n",
"infos = json.loads(infos.decode('utf-8').split(\"\\t\")[1])\n",
"\n",
"img3 = data1[\"image\"].copy()\n",
"plt.figure(figsize=(15,10))\n",
"plt.imshow(img3[:,:,::-1])\n",
"# 原始标注信息\n",
"for info in infos:\n",
" xs, ys = zip(*info[\"points\"])\n",
" xs = list(xs)\n",
" ys = list(ys)\n",
" xs.append(xs[0])\n",
" ys.append(ys[0])\n",
" plt.plot(xs, ys, \"r\")\n",
"# 新增的标注信息\n",
"for poly_idx in range(len(infos), len(data1[\"polys\"])):\n",
" poly = data1[\"polys\"][poly_idx]\n",
" xs, ys = zip(*poly)\n",
" xs = list(xs)\n",
" ys = list(ys)\n",
" xs.append(xs[0])\n",
" ys.append(ys[0])\n",
" plt.plot(xs, ys, \"b\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 3.1.3 文字检测优化小结\n",
"\n",
"\n",
"PP-OCRv2中对文字检测模型采用使用知识蒸馏方案以及数据增广策略增加模型的泛化性能。最终文字检测模型在大小不变的情况下Hmean从 **0.759** 提升至 **0.795**,具体消融实验如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/71c31fc78946459d9b2b0a5aeae75e9bf784399a73554bc79f8c25716ed9dcbe\" width = \"1200\" />\n",
"</div>\n",
"<center>PP-OCRv2检测模型消融实验</center>\n",
"\n",
"PP-OCRv2中检测效果如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/5f1a7e4d193e439bb7b10b89460d76c5ee1e2787ea304641834677524c210795\" width = \"1000\" />\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.2 文本识别模型优化详解\n",
"\n",
"PP-OCRv2文字识别模型优化过程中采用骨干网络优化、UDML知识蒸馏策略、CTC loss改进等技巧最终将识别精度从 **66.7%** 提升至 **74.8%**,具体消融实验如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/b45a54b41d554858a8714e308c31863354fc544582704583b84205c95cba37c3\" width = \"1000\" />\n",
"</div>\n",
"<center>PP-OCRv2识别模型消融实验</center>\n",
"\n",
"### 3.2.1 PP-LCNet轻量级骨干网络\n",
"\n",
"百度提出了一种基于 MKLDNN 加速策略的轻量级 CPU 网络,即 PP-LCNet大幅提高了轻量级模型在图像分类任务上的性能对于计算机视觉的下游任务如文本识别、目标检测、语义分割等有很好的表现。这里需要注意的是PP-LCNet是针对**CPU+MKLDNN**这个场景进行定制优化,在分类任务上的速度和精度都远远优于其他模型,因此大家如果有这个使用场景的模型需求的话,也推荐大家去使用。\n",
"\n",
"PP-LCNet 论文地址:[PP-LCNet: A Lightweight CPU Convolutional Neural Network](https://arxiv.org/abs/2109.15099)\n",
"\n",
"PP-LCNet基于MobileNetV1改进得到其结构图如下所示。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/4d3794e3dd96464b80a3b50a80c7460761bdb1fbea2b4049a744f12d9bef9204\" width = \"1000\" />\n",
"</div>\n",
"\n",
"相比于MobileNetV1PP-LCNet中融合了MobileNetV3结构中激活函数、头部结构、SE模块等策略优化技巧同时分析了最后阶段卷积层的卷积核大小最终该模型在保证速度优势的基础上精度大幅超越MobileNet、GhostNet等轻量级模型。\n",
"\n",
"具体地PP-LCNet中共涉及到下面4个优化点。\n",
"\n",
"* 除了 SE 模块,网络中所有的 relu 激活函数替换为 h-swish精度提升1%-2%\n",
"* PP-LCNet 第五阶段DW 的 kernel size 变为5x5精度提升0.5%-1%\n",
"* PP-LCNet 第五阶段的最后两个 DepthSepConv block 添加 SE 模块, 精度提升0.5%-1%\n",
"* GAP 后添加 1280 维的 FC 层增加特征表达能力精度提升2%-3%\n",
"\n",
"\n",
"在ImageNet1k数据集上PP-LCNet相比于其他目前比较常用的轻量级分类模型Top1-Acc 与预测耗时如下图所示。可以看出,预测耗时和精度都是要更优的。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/91ac1cc5cf3e439d9bdff598a3fbac5dd059d324c5124dc0a1c1ad4c15b8cd9b\" width = \"800\" />\n",
"</div>\n",
"\n",
"通过下面这种方式便可以快速完成PP-LCNet识别模型的定义。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 512, 1, 80]\n"
]
}
],
"source": [
"# 参考代码\n",
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/modeling/backbones/rec_mv1_enhance.py\n",
"from ppocr.modeling.backbones.rec_mv1_enhance import MobileNetV1Enhance\n",
"\n",
"x = paddle.rand([1, 3, 23, 320])\n",
"\n",
"model = MobileNetV1Enhance(scale=0.5)\n",
"\n",
"y = model(x)\n",
"print(y.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 3.2.2 U-DML 知识蒸馏策略\n",
"\n",
"对于标准的 DML 策略,蒸馏的损失函数仅包括最后输出层监督,然而对于 2 个结构完全相同的模型来说,对于完全相同的输入,它们的中间特征输出期望也完全相同,因此在最后输出层监督的监督上,可以进一步添加中间输出的特征图的监督信号,作为损失函数,即 PP-OCRv2 中的 U-DML (Unified-Deep Mutual Learning) 知识蒸馏方法。\n",
"\n",
"U-DML 知识蒸馏的算法流程图如下所示。 Teacher 模型与 Student 模型的网络结构完全相同,初始化参数不同,此外,在新增在标准的 DML 知识蒸馏的基础上,新增引入了对于 Feature Map 的监督机制,新增 Feature Loss。\n",
"\n",
"<div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/15c3365ce87a49e4a26f91c93ab828470bceea1941aa47cf80fa7173cbcfcbd8\" width = \"1000\" />\n",
"</div>\n",
"\n",
"在训练的过程中,总共包含 3 种 loss GT lossDML lossFeature loss。\n",
"\n",
"* GT loss\n",
"\n",
"文本识别任务使用的模型结构是 CRNN因此使用 CTC loss 作为 GT loss GT loss 计算方法如下所示。\n",
"\n",
"$$ Loss_{ctc} = CTC(S_{hout}, gt) + CTC(T_{hout}, gt) $$\n",
"\n",
"* DML loss \n",
"\n",
"DML loss 计算方法如下,这里 Teacher 模型与 Student 模型互相计算 KL 散度,最终 DML loss具有对称性。\n",
"\n",
"$$ Loss_{dml} = \\frac{KL(S_{pout} || T_{pout}) + KL(T_{pout} || S_{pout})}{2} $$\n",
"\n",
"* Feature loss\n",
"\n",
"Feature loss 使用的是 L2 loss具体计算方法如下所示。\n",
"\n",
"$$ Loss_{feat} = L2(S_{bout}, T_{bout}) $$\n",
"\n",
"最终,训练过程中的 loss function 计算方法如下所示。\n",
"\n",
"$$ Loss_{total} = Loss_{ctc} + Loss_{dml} + Loss_{feat} $$\n",
"\n",
"\n",
"此外,在训练过程中通过增加迭代次数,在 Head 部分添加 FC 层等 trick平衡模型的特征编码与解码的能力进一步提升了模型效果。\n",
"\n",
"\n",
"配置文件在[ch_PP-OCRv2_rec_distillation.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)。\n",
"\n",
"```yaml\n",
"Architecture:\n",
" model_type: &model_type \"rec\" # 模型类别rec、det等每个子网络的模型类别都与该字段保持一致\n",
" name: DistillationModel # 结构名称蒸馏任务中为DistillationModel用于构建对应的结构\n",
" algorithm: Distillation # 算法名称\n",
" Models: # 模型,包含子网络的配置信息\n",
" Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数\n",
" pretrained: # 该子网络是否需要加载预训练模型\n",
" freeze_params: false # 是否需要固定参数\n",
" return_all_feats: true # 子网络的参数表示是否需要返回所有的features如果为False则只返回最后的输出\n",
" model_type: *model_type # 模型类别\n",
" algorithm: CRNN # 子网络的算法名称,该子网络剩余参与均为构造参数,与普通的模型训练配置一致\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV1Enhance\n",
" scale: 0.5\n",
" Neck:\n",
" name: SequenceEncoder\n",
" encoder_type: rnn\n",
" hidden_size: 64\n",
" Head:\n",
" name: CTCHead\n",
" mid_channels: 96 # Head解码过程中穿插一层\n",
" fc_decay: 0.00002\n",
" Student: # 另外一个子网络这里给的是DML的蒸馏示例两个子网络结构相同均需要学习参数\n",
" pretrained: # 下面的组网参数同上\n",
" freeze_params: false\n",
" return_all_feats: true\n",
" model_type: *model_type\n",
" algorithm: CRNN\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV1Enhance\n",
" scale: 0.5\n",
" Neck:\n",
" name: SequenceEncoder\n",
" encoder_type: rnn\n",
" hidden_size: 64\n",
" Head:\n",
" name: CTCHead\n",
" mid_channels: 96\n",
" fc_decay: 0.00002\n",
"```\n",
"\n",
"当然,这里如果希望添加更多的子网络进行训练,也可以按照`Student`与`Teacher`的添加方式在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督共同训练那么`Architecture`可以写为如下格式。\n",
"\n",
"```yaml\n",
"Architecture:\n",
" model_type: &model_type \"rec\"\n",
" name: DistillationModel\n",
" algorithm: Distillation\n",
" Models:\n",
" Teacher:\n",
" pretrained:\n",
" freeze_params: false\n",
" return_all_feats: true\n",
" model_type: *model_type\n",
" algorithm: CRNN\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV1Enhance\n",
" scale: 0.5\n",
" Neck:\n",
" name: SequenceEncoder\n",
" encoder_type: rnn\n",
" hidden_size: 64\n",
" Head:\n",
" name: CTCHead\n",
" mid_channels: 96\n",
" fc_decay: 0.00002\n",
" Student:\n",
" pretrained:\n",
" freeze_params: false\n",
" return_all_feats: true\n",
" model_type: *model_type\n",
" algorithm: CRNN\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV1Enhance\n",
" scale: 0.5\n",
" Neck:\n",
" name: SequenceEncoder\n",
" encoder_type: rnn\n",
" hidden_size: 64\n",
" Head:\n",
" name: CTCHead\n",
" mid_channels: 96\n",
" fc_decay: 0.00002\n",
" Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同\n",
" pretrained:\n",
" freeze_params: false\n",
" return_all_feats: true\n",
" model_type: *model_type\n",
" algorithm: CRNN\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV1Enhance\n",
" scale: 0.5\n",
" Neck:\n",
" name: SequenceEncoder\n",
" encoder_type: rnn\n",
" hidden_size: 64\n",
" Head:\n",
" name: CTCHead\n",
" mid_channels: 96\n",
" fc_decay: 0.00002\n",
"```\n",
"\n",
"最终该模型训练时包含3个子网络`Teacher`, `Student`, `Student2`。\n",
"\n",
"蒸馏模型`DistillationModel`类的具体实现代码可以参考[distillation_model.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/ppocr/modeling/architectures/distillation_model.py)。\n",
"\n",
"最终模型`forward`输出为一个字典key为所有的子网络名称例如这里为`Student`与`Teacher`value为对应子网络的输出可以为`Tensor`(只返回该网络的最后一层)和`dict`(也返回了中间的特征信息)。\n",
"\n",
"在识别任务中,为了添加更多损失函数,保证蒸馏方法的可扩展性,将每个子网络的输出保存为`dict`,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果均为`dict`key包含`backbone_out`,`neck_out`, `head_out``value`为对应模块的tensor最终对于上述配置文件`DistillationModel`的输出格式如下。\n",
"\n",
"```json\n",
"{\n",
" \"Teacher\": {\n",
" \"backbone_out\": tensor,\n",
" \"neck_out\": tensor,\n",
" \"head_out\": tensor,\n",
" },\n",
" \"Student\": {\n",
" \"backbone_out\": tensor,\n",
" \"neck_out\": tensor,\n",
" \"head_out\": tensor,\n",
" }\n",
"}\n",
"```\n",
"\n",
"知识蒸馏任务中,损失函数配置如下所示。\n",
"\n",
"```yaml\n",
"Loss:\n",
" name: CombinedLoss # 损失函数名称,基于该名称,构建用于损失函数的类\n",
" loss_config_list: # 损失函数配置文件列表为CombinedLoss的必备函数\n",
" - DistillationCTCLoss: # 基于蒸馏的CTC损失函数继承自标准的CTC loss\n",
" weight: 1.0 # 损失函数的权重loss_config_list中每个损失函数的配置都必须包含该字段\n",
" model_name_list: [\"Student\", \"Teacher\"] # 对于蒸馏模型的预测结果提取这两个子网络的输出与gt计算CTC loss\n",
" key: head_out # 取子网络输出dict中该key对应的tensor\n",
" - DistillationDMLLoss: # 蒸馏的DML损失函数继承自标准的DMLLoss\n",
" weight: 1.0 # 权重\n",
" act: \"softmax\" # 激活函数对输入使用激活函数处理可以为softmax, sigmoid或者为None默认为None\n",
" model_name_pairs: # 用于计算DML loss的子网络名称对如果希望计算其他子网络的DML loss可以在列表下面继续填充\n",
" - [\"Student\", \"Teacher\"]\n",
" key: head_out # 取子网络输出dict中该key对应的tensor\n",
" - DistillationDistanceLoss: # 蒸馏的距离损失函数\n",
" weight: 1.0 # 权重\n",
" mode: \"l2\" # 距离计算方法目前支持l1, l2, smooth_l1\n",
" model_name_pairs: # 用于计算distance loss的子网络名称对\n",
" - [\"Student\", \"Teacher\"]\n",
" key: backbone_out # 取子网络输出dict中该key对应的tensor\n",
"```\n",
"\n",
"上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。\n",
"\n",
"以上述配置为例最终蒸馏训练的损失函数包含下面3个部分。\n",
"\n",
"- `Student`和`Teacher`的最终输出(`head_out`)与gt的CTC loss权重为1。在这里因为2个子网络都需要更新参数因此2者都需要计算与gt的loss。\n",
"- `Student`和`Teacher`的最终输出(`head_out`)之间的DML loss权重为1。\n",
"- `Student`和`Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss权重为1。\n",
"\n",
"`CombinedLoss`类实现如下。\n",
"\n",
"```python\n",
"class CombinedLoss(nn.Layer):\n",
" \"\"\"\n",
" CombinedLoss:\n",
" a combionation of loss function\n",
" \"\"\"\n",
"\n",
" def __init__(self, loss_config_list=None):\n",
" super().__init__()\n",
" self.loss_func = []\n",
" self.loss_weight = []\n",
" assert isinstance(loss_config_list, list), (\n",
" 'operator config should be a list')\n",
" for config in loss_config_list:\n",
" assert isinstance(config,\n",
" dict) and len(config) == 1, \"yaml format error\"\n",
" name = list(config)[0]\n",
" param = config[name]\n",
" assert \"weight\" in param, \"weight must be in param, but param just contains {}\".format(\n",
" param.keys())\n",
" self.loss_weight.append(param.pop(\"weight\"))\n",
" self.loss_func.append(eval(name)(**param))\n",
"\n",
" def forward(self, input, batch, **kargs):\n",
" loss_dict = {}\n",
" loss_all = 0.\n",
" for idx, loss_func in enumerate(self.loss_func):\n",
" loss = loss_func(input, batch, **kargs)\n",
" if isinstance(loss, paddle.Tensor):\n",
" loss = {\"loss_{}_{}\".format(str(loss), idx): loss}\n",
"\n",
" weight = self.loss_weight[idx]\n",
"\n",
" loss = {key: loss[key] * weight for key in loss}\n",
"\n",
" if \"loss\" in loss:\n",
" loss_all += loss[\"loss\"]\n",
" else:\n",
" loss_all += paddle.add_n(list(loss.values()))\n",
" loss_dict.update(loss)\n",
" loss_dict[\"loss\"] = loss_all\n",
" return loss_dict\n",
"```\n",
"\n",
"关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/losses/combined_loss.py)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/ppocr/losses/distillation_loss.py)。\n",
"\n",
"\n",
"\n",
"对于上面3个模型的蒸馏Loss字段也需要相应修改同时考虑3个子网络之间的损失如下所示。\n",
"\n",
"```yaml\n",
"Loss:\n",
" name: CombinedLoss # 损失函数名称,基于该名称,构建用于损失函数的类\n",
" loss_config_list: # 损失函数配置文件列表为CombinedLoss的必备函数\n",
" - DistillationCTCLoss: # 基于蒸馏的CTC损失函数继承自标准的CTC loss\n",
" weight: 1.0 # 损失函数的权重loss_config_list中每个损失函数的配置都必须包含该字段\n",
" model_name_list: [\"Student\", \"Student2\", \"Teacher\"] # 对于蒸馏模型的预测结果提取这三个子网络的输出与gt计算CTC loss\n",
" key: head_out # 取子网络输出dict中该key对应的tensor\n",
" - DistillationDMLLoss: # 蒸馏的DML损失函数继承自标准的DMLLoss\n",
" weight: 1.0 # 权重\n",
" act: \"softmax\" # 激活函数对输入使用激活函数处理可以为softmax, sigmoid或者为None默认为None\n",
" model_name_pairs: # 用于计算DML loss的子网络名称对如果希望计算其他子网络的DML loss可以在列表下面继续填充\n",
" - [\"Student\", \"Teacher\"]\n",
" - [\"Student2\", \"Teacher\"]\n",
" - [\"Student\", \"Student2\"]\n",
" key: head_out # 取子网络输出dict中该key对应的tensor\n",
" - DistillationDistanceLoss: # 蒸馏的距离损失函数\n",
" weight: 1.0 # 权重\n",
" mode: \"l2\" # 距离计算方法目前支持l1, l2, smooth_l1\n",
" model_name_pairs: # 用于计算distance loss的子网络名称对\n",
" - [\"Student\", \"Teacher\"]\n",
" - [\"Student2\", \"Teacher\"]\n",
" - [\"Student\", \"Student2\"]\n",
" key: backbone_out # 取子网络输出dict中该key对应的tensor\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2021-12-24 22:30:18-- https://paddleocr.bj.bcebos.com/dataset/rec_data_lesson_demo.tar\n",
"Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.195, 182.61.200.229, 2409:8c04:1001:1002:0:ff:b001:368a\n",
"Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.195|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 699098618 (667M) [application/x-tar]\n",
"Saving to: rec_data_lesson_demo.tar\n",
"\n",
"rec_data_lesson_dem 100%[===================>] 666.71M 42.5MB/s in 16s \n",
"\n",
"2021-12-24 22:30:34 (40.7 MB/s) - rec_data_lesson_demo.tar saved [699098618/699098618]\n",
"\n",
"--2021-12-24 22:30:41-- https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar\n",
"Resolving paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)... 182.61.200.229, 182.61.200.195, 2409:8c04:1001:1002:0:ff:b001:368a\n",
"Connecting to paddleocr.bj.bcebos.com (paddleocr.bj.bcebos.com)|182.61.200.229|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 77350400 (74M) [application/x-tar]\n",
"Saving to: ch_PP-OCRv2_rec_train.tar\n",
"\n",
"ch_PP-OCRv2_rec_tra 100%[===================>] 73.77M 50.5MB/s in 1.5s \n",
"\n",
"2021-12-24 22:30:43 (50.5 MB/s) - ch_PP-OCRv2_rec_train.tar saved [77350400/77350400]\n",
"\n",
"[2021/12/24 22:30:45] root INFO: Architecture : \n",
"[2021/12/24 22:30:45] root INFO: Models : \n",
"[2021/12/24 22:30:45] root INFO: Student : \n",
"[2021/12/24 22:30:45] root INFO: Backbone : \n",
"[2021/12/24 22:30:45] root INFO: name : MobileNetV1Enhance\n",
"[2021/12/24 22:30:45] root INFO: scale : 0.5\n",
"[2021/12/24 22:30:45] root INFO: Head : \n",
"[2021/12/24 22:30:45] root INFO: fc_decay : 2e-05\n",
"[2021/12/24 22:30:45] root INFO: mid_channels : 96\n",
"[2021/12/24 22:30:45] root INFO: name : CTCHead\n",
"[2021/12/24 22:30:45] root INFO: Neck : \n",
"[2021/12/24 22:30:45] root INFO: encoder_type : rnn\n",
"[2021/12/24 22:30:45] root INFO: hidden_size : 64\n",
"[2021/12/24 22:30:45] root INFO: name : SequenceEncoder\n",
"[2021/12/24 22:30:45] root INFO: Transform : None\n",
"[2021/12/24 22:30:45] root INFO: algorithm : CRNN\n",
"[2021/12/24 22:30:45] root INFO: freeze_params : False\n",
"[2021/12/24 22:30:45] root INFO: model_type : rec\n",
"[2021/12/24 22:30:45] root INFO: pretrained : None\n",
"[2021/12/24 22:30:45] root INFO: return_all_feats : True\n",
"[2021/12/24 22:30:45] root INFO: Teacher : \n",
"[2021/12/24 22:30:45] root INFO: Backbone : \n",
"[2021/12/24 22:30:45] root INFO: name : MobileNetV1Enhance\n",
"[2021/12/24 22:30:45] root INFO: scale : 0.5\n",
"[2021/12/24 22:30:45] root INFO: Head : \n",
"[2021/12/24 22:30:45] root INFO: fc_decay : 2e-05\n",
"[2021/12/24 22:30:45] root INFO: mid_channels : 96\n",
"[2021/12/24 22:30:45] root INFO: name : CTCHead\n",
"[2021/12/24 22:30:45] root INFO: Neck : \n",
"[2021/12/24 22:30:45] root INFO: encoder_type : rnn\n",
"[2021/12/24 22:30:45] root INFO: hidden_size : 64\n",
"[2021/12/24 22:30:45] root INFO: name : SequenceEncoder\n",
"[2021/12/24 22:30:45] root INFO: Transform : None\n",
"[2021/12/24 22:30:45] root INFO: algorithm : CRNN\n",
"[2021/12/24 22:30:45] root INFO: freeze_params : False\n",
"[2021/12/24 22:30:45] root INFO: model_type : rec\n",
"[2021/12/24 22:30:45] root INFO: pretrained : None\n",
"[2021/12/24 22:30:45] root INFO: return_all_feats : True\n",
"[2021/12/24 22:30:45] root INFO: algorithm : Distillation\n",
"[2021/12/24 22:30:45] root INFO: model_type : rec\n",
"[2021/12/24 22:30:45] root INFO: name : DistillationModel\n",
"[2021/12/24 22:30:45] root INFO: Eval : \n",
"[2021/12/24 22:30:45] root INFO: dataset : \n",
"[2021/12/24 22:30:45] root INFO: data_dir : ./rec_data_lesson_demo/\n",
"[2021/12/24 22:30:45] root INFO: label_file_list : ['./rec_data_lesson_demo/val.txt']\n",
"[2021/12/24 22:30:45] root INFO: name : SimpleDataSet\n",
"[2021/12/24 22:30:45] root INFO: transforms : \n",
"[2021/12/24 22:30:45] root INFO: DecodeImage : \n",
"[2021/12/24 22:30:45] root INFO: channel_first : False\n",
"[2021/12/24 22:30:45] root INFO: img_mode : BGR\n",
"[2021/12/24 22:30:45] root INFO: CTCLabelEncode : None\n",
"[2021/12/24 22:30:45] root INFO: RecResizeImg : \n",
"[2021/12/24 22:30:45] root INFO: image_shape : [3, 32, 320]\n",
"[2021/12/24 22:30:45] root INFO: KeepKeys : \n",
"[2021/12/24 22:30:45] root INFO: keep_keys : ['image', 'label', 'length']\n",
"[2021/12/24 22:30:45] root INFO: loader : \n",
"[2021/12/24 22:30:45] root INFO: batch_size_per_card : 128\n",
"[2021/12/24 22:30:45] root INFO: drop_last : False\n",
"[2021/12/24 22:30:45] root INFO: num_workers : 0\n",
"[2021/12/24 22:30:45] root INFO: shuffle : False\n",
"[2021/12/24 22:30:45] root INFO: Global : \n",
"[2021/12/24 22:30:45] root INFO: cal_metric_during_train : True\n",
"[2021/12/24 22:30:45] root INFO: character_dict_path : ppocr/utils/ppocr_keys_v1.txt\n",
"[2021/12/24 22:30:45] root INFO: checkpoints : None\n",
"[2021/12/24 22:30:45] root INFO: debug : False\n",
"[2021/12/24 22:30:45] root INFO: distributed : False\n",
"[2021/12/24 22:30:45] root INFO: epoch_num : 1\n",
"[2021/12/24 22:30:45] root INFO: eval_batch_step : [0, 2000]\n",
"[2021/12/24 22:30:45] root INFO: infer_img : doc/imgs_words/ch/word_1.jpg\n",
"[2021/12/24 22:30:45] root INFO: infer_mode : False\n",
"[2021/12/24 22:30:45] root INFO: log_smooth_window : 20\n",
"[2021/12/24 22:30:45] root INFO: max_text_length : 25\n",
"[2021/12/24 22:30:45] root INFO: pretrained_model : ./ch_PP-OCRv2_rec_train/best_accuracy\n",
"[2021/12/24 22:30:45] root INFO: print_batch_step : 10\n",
"[2021/12/24 22:30:45] root INFO: save_epoch_step : 3\n",
"[2021/12/24 22:30:45] root INFO: save_inference_dir : None\n",
"[2021/12/24 22:30:45] root INFO: save_model_dir : ./output/rec_pp-OCRv2_distillation\n",
"[2021/12/24 22:30:45] root INFO: save_res_path : ./output/rec/predicts_pp-OCRv2_distillation.txt\n",
"[2021/12/24 22:30:45] root INFO: use_gpu : True\n",
"[2021/12/24 22:30:45] root INFO: use_space_char : True\n",
"[2021/12/24 22:30:45] root INFO: use_visualdl : False\n",
"[2021/12/24 22:30:45] root INFO: Loss : \n",
"[2021/12/24 22:30:45] root INFO: loss_config_list : \n",
"[2021/12/24 22:30:45] root INFO: DistillationCTCLoss : \n",
"[2021/12/24 22:30:45] root INFO: key : head_out\n",
"[2021/12/24 22:30:45] root INFO: model_name_list : ['Student', 'Teacher']\n",
"[2021/12/24 22:30:45] root INFO: weight : 1.0\n",
"[2021/12/24 22:30:45] root INFO: DistillationDMLLoss : \n",
"[2021/12/24 22:30:45] root INFO: act : softmax\n",
"[2021/12/24 22:30:45] root INFO: key : head_out\n",
"[2021/12/24 22:30:45] root INFO: model_name_pairs : [['Student', 'Teacher']]\n",
"[2021/12/24 22:30:45] root INFO: use_log : True\n",
"[2021/12/24 22:30:45] root INFO: weight : 1.0\n",
"[2021/12/24 22:30:45] root INFO: DistillationDistanceLoss : \n",
"[2021/12/24 22:30:45] root INFO: key : backbone_out\n",
"[2021/12/24 22:30:45] root INFO: mode : l2\n",
"[2021/12/24 22:30:45] root INFO: model_name_pairs : [['Student', 'Teacher']]\n",
"[2021/12/24 22:30:45] root INFO: weight : 1.0\n",
"[2021/12/24 22:30:45] root INFO: name : CombinedLoss\n",
"[2021/12/24 22:30:45] root INFO: Metric : \n",
"[2021/12/24 22:30:45] root INFO: base_metric_name : RecMetric\n",
"[2021/12/24 22:30:45] root INFO: key : Student\n",
"[2021/12/24 22:30:45] root INFO: main_indicator : acc\n",
"[2021/12/24 22:30:45] root INFO: name : DistillationMetric\n",
"[2021/12/24 22:30:45] root INFO: Optimizer : \n",
"[2021/12/24 22:30:45] root INFO: beta1 : 0.9\n",
"[2021/12/24 22:30:45] root INFO: beta2 : 0.999\n",
"[2021/12/24 22:30:45] root INFO: lr : \n",
"[2021/12/24 22:30:45] root INFO: decay_epochs : [700, 800]\n",
"[2021/12/24 22:30:45] root INFO: name : Piecewise\n",
"[2021/12/24 22:30:45] root INFO: values : [0.0001, 1e-05]\n",
"[2021/12/24 22:30:45] root INFO: warmup_epoch : 5\n",
"[2021/12/24 22:30:45] root INFO: name : Adam\n",
"[2021/12/24 22:30:45] root INFO: regularizer : \n",
"[2021/12/24 22:30:45] root INFO: factor : 2e-05\n",
"[2021/12/24 22:30:45] root INFO: name : L2\n",
"[2021/12/24 22:30:45] root INFO: PostProcess : \n",
"[2021/12/24 22:30:45] root INFO: key : head_out\n",
"[2021/12/24 22:30:45] root INFO: model_name : ['Student', 'Teacher']\n",
"[2021/12/24 22:30:45] root INFO: name : DistillationCTCLabelDecode\n",
"[2021/12/24 22:30:45] root INFO: Train : \n",
"[2021/12/24 22:30:45] root INFO: dataset : \n",
"[2021/12/24 22:30:45] root INFO: data_dir : ./rec_data_lesson_demo/\n",
"[2021/12/24 22:30:45] root INFO: label_file_list : ['./rec_data_lesson_demo/train.txt']\n",
"[2021/12/24 22:30:45] root INFO: name : SimpleDataSet\n",
"[2021/12/24 22:30:45] root INFO: transforms : \n",
"[2021/12/24 22:30:45] root INFO: DecodeImage : \n",
"[2021/12/24 22:30:45] root INFO: channel_first : False\n",
"[2021/12/24 22:30:45] root INFO: img_mode : BGR\n",
"[2021/12/24 22:30:45] root INFO: RecAug : None\n",
"[2021/12/24 22:30:45] root INFO: CTCLabelEncode : None\n",
"[2021/12/24 22:30:45] root INFO: RecResizeImg : \n",
"[2021/12/24 22:30:45] root INFO: image_shape : [3, 32, 320]\n",
"[2021/12/24 22:30:45] root INFO: KeepKeys : \n",
"[2021/12/24 22:30:45] root INFO: keep_keys : ['image', 'label', 'length']\n",
"[2021/12/24 22:30:45] root INFO: loader : \n",
"[2021/12/24 22:30:45] root INFO: batch_size_per_card : 64\n",
"[2021/12/24 22:30:45] root INFO: drop_last : True\n",
"[2021/12/24 22:30:45] root INFO: num_sections : 1\n",
"[2021/12/24 22:30:45] root INFO: num_workers : 0\n",
"[2021/12/24 22:30:45] root INFO: shuffle : True\n",
"[2021/12/24 22:30:45] root INFO: profiler_options : None\n",
"[2021/12/24 22:30:45] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)\n",
"[2021/12/24 22:30:45] root INFO: Initialize indexs of datasets:['./rec_data_lesson_demo/train.txt']\n",
"[2021/12/24 22:30:45] root INFO: Initialize indexs of datasets:['./rec_data_lesson_demo/val.txt']\n",
"W1224 22:30:45.741250 9254 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n",
"W1224 22:30:45.746162 9254 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
"[2021/12/24 22:30:50] root INFO: load pretrain successful from ./ch_PP-OCRv2_rec_train/best_accuracy\n",
"[2021/12/24 22:30:50] root INFO: train dataloader has 1562 iters\n",
"[2021/12/24 22:30:50] root INFO: valid dataloader has 24 iters\n",
"[2021/12/24 22:30:50] root INFO: During the training process, after the 0th iteration, an evaluation is run every 2000 iterations\n",
"[2021/12/24 22:30:50] root INFO: Initialize indexs of datasets:['./rec_data_lesson_demo/train.txt']\n",
"[2021/12/24 22:31:01] root INFO: epoch: [1/1], iter: 10, lr: 0.000000, loss_ctc_Student_0: 6.976444, loss_ctc_Teacher_1: 8.681884, dml_0: 7.565064, loss: 23.507660, loss_distance_l2_Student_Teacher_0: 0.025505, acc: 0.562491, norm_edit_dis: 0.740752, Teacher_acc: 0.609365, Teacher_norm_edit_dis: 0.739364, reader_cost: 0.36132 s, batch_cost: 0.66199 s, samples: 704, ips: 106.34653\n",
"[2021/12/24 22:31:11] root INFO: epoch: [1/1], iter: 20, lr: 0.000000, loss_ctc_Student_0: 7.744696, loss_ctc_Teacher_1: 8.654169, dml_0: 8.570195, loss: 26.458534, loss_distance_l2_Student_Teacher_0: 0.026079, acc: 0.531242, norm_edit_dis: 0.735827, Teacher_acc: 0.593741, Teacher_norm_edit_dis: 0.760099, reader_cost: 0.36109 s, batch_cost: 0.59941 s, samples: 640, ips: 106.77130\n",
"[2021/12/24 22:31:21] root INFO: epoch: [1/1], iter: 30, lr: 0.000000, loss_ctc_Student_0: 8.108994, loss_ctc_Teacher_1: 8.537874, dml_0: 9.982393, loss: 26.843945, loss_distance_l2_Student_Teacher_0: 0.026251, acc: 0.507805, norm_edit_dis: 0.716328, Teacher_acc: 0.593741, Teacher_norm_edit_dis: 0.770375, reader_cost: 0.39665 s, batch_cost: 0.64006 s, samples: 640, ips: 99.99085\n",
"[2021/12/24 22:31:32] root INFO: epoch: [1/1], iter: 40, lr: 0.000000, loss_ctc_Student_0: 7.732370, loss_ctc_Teacher_1: 8.648810, dml_0: 8.644684, loss: 25.863911, loss_distance_l2_Student_Teacher_0: 0.025766, acc: 0.507805, norm_edit_dis: 0.709930, Teacher_acc: 0.585928, Teacher_norm_edit_dis: 0.767394, reader_cost: 0.36238 s, batch_cost: 0.60918 s, samples: 640, ips: 105.05873\n",
"[2021/12/24 22:31:42] root INFO: epoch: [1/1], iter: 50, lr: 0.000001, loss_ctc_Student_0: 8.219507, loss_ctc_Teacher_1: 10.171026, dml_0: 8.194988, loss: 26.805073, loss_distance_l2_Student_Teacher_0: 0.025741, acc: 0.539054, norm_edit_dis: 0.709930, Teacher_acc: 0.562491, Teacher_norm_edit_dis: 0.763058, reader_cost: 0.43782 s, batch_cost: 0.69340 s, samples: 640, ips: 92.29917\n",
"[2021/12/24 22:31:53] root INFO: epoch: [1/1], iter: 60, lr: 0.000001, loss_ctc_Student_0: 7.573787, loss_ctc_Teacher_1: 9.168297, dml_0: 8.064046, loss: 25.023621, loss_distance_l2_Student_Teacher_0: 0.025920, acc: 0.562491, norm_edit_dis: 0.734843, Teacher_acc: 0.585928, Teacher_norm_edit_dis: 0.756837, reader_cost: 0.39662 s, batch_cost: 0.64859 s, samples: 640, ips: 98.67595\n",
"[2021/12/24 22:32:04] root INFO: epoch: [1/1], iter: 70, lr: 0.000001, loss_ctc_Student_0: 7.743058, loss_ctc_Teacher_1: 8.413120, dml_0: 9.065943, loss: 24.850718, loss_distance_l2_Student_Teacher_0: 0.026217, acc: 0.546866, norm_edit_dis: 0.715387, Teacher_acc: 0.593741, Teacher_norm_edit_dis: 0.751324, reader_cost: 0.42741 s, batch_cost: 0.66567 s, samples: 640, ips: 96.14405\n",
"[2021/12/24 22:32:14] root INFO: epoch: [1/1], iter: 80, lr: 0.000001, loss_ctc_Student_0: 8.279991, loss_ctc_Teacher_1: 8.347084, dml_0: 8.233022, loss: 24.850718, loss_distance_l2_Student_Teacher_0: 0.026422, acc: 0.531242, norm_edit_dis: 0.708250, Teacher_acc: 0.601553, Teacher_norm_edit_dis: 0.769606, reader_cost: 0.38362 s, batch_cost: 0.62691 s, samples: 640, ips: 102.08804\n",
"^C\n",
"main proc 9254 exit, kill process group 9254\n"
]
}
],
"source": [
"# 下载数据\n",
"!wget -nc https://paddleocr.bj.bcebos.com/dataset/rec_data_lesson_demo.tar && tar -xf rec_data_lesson_demo.tar && rm rec_data_lesson_demo.tar\n",
"# # 下载预训练模型\n",
"!wget -nc https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar && tar -xf ch_PP-OCRv2_rec_train.tar && rm ch_PP-OCRv2_rec_train.tar\n",
"\n",
"!python tools/train.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml \\\n",
" -o Train.dataset.data_dir=\"./rec_data_lesson_demo/\" \\\n",
" Train.dataset.label_file_list=[\"./rec_data_lesson_demo/train.txt\"] \\\n",
" Train.loader.num_workers=0 \\\n",
" Train.loader.batch_size_per_card=64 \\\n",
" Eval.dataset.data_dir=\"./rec_data_lesson_demo/\" \\\n",
" Eval.dataset.label_file_list=[\"./rec_data_lesson_demo/val.txt\"] \\\n",
" Eval.loader.num_workers=0 \\\n",
" Optimizer.lr.values=[0.0001,0.00001] \\\n",
" Global.epoch_num=1 \\\n",
" Global.pretrained_model=\"./ch_PP-OCRv2_rec_train/best_accuracy\""
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"### 3.2.3 Enhanced CTC loss 改进\n",
"\n",
"中文 OCR 任务经常遇到的识别难点是相似字符数太多,容易误识。借鉴 Metric Learning 中的想法,引入 Center loss进一步增大类间距离核心公式如下所示。\n",
"\n",
"$$ L = L_{ctc} + \\lambda * L_{center} $$\n",
"$$ L_{center} =\\sum_{t=1}^T||x_{t} - c_{y_{t}}||_{2}^{2} $$\n",
"\n",
"这里 $x_t$ 表示时间步长 $t$ 处的标签,$c_{y_{t}}$ 表示标签 $y_t$ 对应的 center。\n",
"\n",
"Enhance CTC 中center 的初始化对结果也有较大影响,在 PP-OCRv2 中center 初始化的具体步骤如下所示。\n",
"\n",
"1. 基于标准的 CTC loss训练一个网络\n",
"2. 提取出训练集合中识别正确的图像集合,记为 G \n",
"3. 将 G 中的图片依次输入网络, 提取head输出时序特征的 $x_t$ 和 $y_t$ 的对应关系,其中 $y_t$ 计算方式如下:\n",
"\n",
"$$ y_{t} = argmax(W * x_{t}) $$\n",
"\n",
"4. 将相同 $y_t$ 对应的 $x_t$ 聚合在一起,取其平均值,作为初始 center。\n",
"\n",
"\n",
"首先需要基于[configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)训练一个基础网络\n",
"\n",
"更多关于Center loss的训练步骤可以参考[Enhanced CTC Loss使用文档](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/doc/doc_ch/enhanced_ctc_loss.md)\n",
"\n",
"最后,使用[configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.4/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml)进行训练,命令如下所示。\n",
"\n",
"```shell\n",
"python tools/train.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml\n",
"```\n",
"\n",
"主要改进点为`Loss`字段,相比于标准的`CTCLoss`,添加了`CenterLoss`。配置类别数、特征维度、center路径即可。\n",
"\n",
"```yaml\n",
"Loss:\n",
" name: CombinedLoss\n",
" loss_config_list:\n",
" - CTCLoss:\n",
" use_focal_loss: false\n",
" weight: 1.0\n",
" - CenterLoss:\n",
" weight: 0.05\n",
" num_classes: 6625\n",
" feat_dim: 96\n",
" center_file_path: \"./train_center.pkl\"\n",
"```\n",
"\n",
"### 3.2.4 文本识别优化小结\n",
"\n",
"PP-OCRv2文字识别模型优化过程中对模型从骨干网络、损失函数等角度进行改进并引入知识蒸馏的训练方法最终将识别精度从 **66.7%** 提升至 **74.8%**,具体消融实验如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/b45a54b41d554858a8714e308c31863354fc544582704583b84205c95cba37c3\" width = \"800\" />\n",
"</div>\n",
"<center>PP-OCRv2识别模型消融实验</center>\n",
"\n",
"在PP-OCRv2文字检测的基础上识别模型的实验效果如下所示。\n",
"\n",
"</div><div align=\"center\">\n",
"<img src=\"https://ai-studio-static-online.cdn.bcebos.com/12bec5f8e2e94fedbf4f2851f182e6bc7a1cff80a31b4390bedfcfbb7d8d6a6d\" width = \"800\" />\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 4. 总结\n",
"\n",
"本章主要介绍PP-OCR以及PP-OCRv2的优化策略。\n",
"\n",
"PP-OCR从骨干网络、学习率策略、数据增广、模型裁剪量化等方面共使用了19个策略对模型进行优化瘦身最终打造了面向服务器端的PP-OCR server系统以及面向移动端的PP-OCR mobile系统。\n",
"\n",
"相比于PP-OCR PP-OCRv2 在骨干网络、数据增广、损失函数这三个方面进行进一步优化解决端侧预测效率较差、背景复杂以及相似字符的误识等问题同时引入了知识蒸馏训练策略进一步提升模型精度最终打造了精度、速度远超PP-OCR的文字检测与识别系统。"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 5. 作业\n",
"\n",
"具体内容见课程结业必修中的`优化策略客观题`以及`优化策略实战题`部分。"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "py35-paddle1.2.0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}