PaddleOCR/notebook/notebook_ch/2.text_detection/文本检测实践篇.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# OCR Text Detection in Practice\n",
"\n",
"This section shows how to train and run the DB text detection algorithm with PaddleOCR, covering how to:\n",
"1. Quickly try out text detection with the paddleocr package\n",
"2. Understand the principles of the DB text detection algorithm\n",
"3. Build a text detection model step by step\n",
"4. Train a text detection model\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 1. Quick Start\n",
"\n",
"This section uses [paddleocr](https://pypi.org/project/paddleocr/) as an example and shows how to run text detection in three steps:\n",
"1. Install [paddleocr](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/doc/doc_ch/whl.md)\n",
"2. Run the DB algorithm with a single command to obtain detection results\n",
"3. Visualize the text detection results\n",
"\n",
"\n",
"\n",
"\n",
"**Install the paddleocr whl package**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"!pip install --upgrade pip\n",
"!pip install paddleocr"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**Run text detection with a single command**\n",
"\n",
"On first run, paddleocr automatically downloads and uses the PaddleOCR [PP-OCRv2 lightweight model](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/README.md#pp-ocr-series-model-listupdate-on-september-8th).\n",
"\n",
"With paddleocr installed, using ./12.jpg as the input image produces the following prediction:\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/4e31d512f7e147d4847cb1a0ee27a8260ef05506c9254fc1b19137bab1831ac8\"\n",
"width=\"200\", height=\"400\" ></center>\n",
"\n",
"<br><center>Figure: 12.jpg</center>\n",
"\n",
"```\n",
"[[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]]\n",
"[[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]]\n",
"[[174.0, 458.0], [397.0, 449.0], [398.0, 480.0], [175.0, 489.0]]\n",
"[[42.0, 414.0], [482.0, 392.0], [484.0, 428.0], [44.0, 450.0]]\n",
"```\n",
"The prediction contains four text boxes. Each line holds the four coordinate points of one box, listed clockwise starting from the top-left corner.\n",
"\n",
"\n",
"The paddleocr command line runs the text detection model on the image ./12.jpg as follows:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\n",
"# Change the default working directory of the AI Studio runtime to /home/aistudio/\n",
"os.chdir(\"/home/aistudio/\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n",
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n",
"[2021/12/22 21:07:19] root WARNING: version PP-OCRv2 not support cls models, auto switch to version PP-OCR\n",
"Namespace(benchmark=False, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, det=True, det_algorithm='DB', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/det/ch/ch_PP-OCRv2_det_infer', det_pse_box_thresh=0.85, det_pse_box_type='box', det_pse_min_area=16, det_pse_scale=1, det_pse_thresh=0, det_sast_nms_thresh=0.2, det_sast_polygon=False, det_sast_score_thresh=0.5, drop_score=0.5, e2e_algorithm='PGNet', e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model_dir=None, e2e_pgnet_mode='fast', e2e_pgnet_polygon=True, e2e_pgnet_score_thresh=0.5, e2e_pgnet_valid_set='totaltext', enable_mkldnn=False, gpu_mem=500, help='==SUPPRESS==', image_dir='./12.jpg', ir_optim=True, label_list=['0', '180'], lang='ch', layout_path_model='lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config', max_batch_size=10, max_text_length=25, min_subgraph_size=15, ocr_version='PP-OCRv2', output='./output/table', precision='fp32', process_id=0, rec=False, rec_algorithm='CRNN', rec_batch_num=6, rec_char_dict_path='/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddleocr/ppocr/utils/ppocr_keys_v1.txt', rec_image_shape='3, 32, 320', rec_model_dir='/home/aistudio/.paddleocr/2.3.0.2/ocr/rec/ch/ch_PP-OCRv2_rec_infer', save_log_path='./log_output/', show_log=True, structure_version='STRUCTURE', table_char_dict_path=None, table_char_type='en', table_max_len=488, table_model_dir=None, total_process_num=1, type='ocr', use_angle_cls=False, use_dilation=False, use_gpu=True, use_mp=False, use_onnx=False, use_pdserving=False, use_space_char=True, use_tensorrt=False, 
vis_font_path='./doc/fonts/simfang.ttf', warmup=True)\n",
"[2021/12/22 21:07:21] root INFO: **********./12.jpg**********\n",
"[2021/12/22 21:07:23] root INFO: [[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]]\n",
"[2021/12/22 21:07:23] root INFO: [[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]]\n",
"[2021/12/22 21:07:23] root INFO: [[174.0, 458.0], [397.0, 449.0], [398.0, 480.0], [175.0, 489.0]]\n",
"[2021/12/22 21:07:23] root INFO: [[42.0, 414.0], [482.0, 392.0], [484.0, 428.0], [44.0, 450.0]]\n"
]
}
],
"source": [
"# --image_dir points to the image to predict; --rec false disables recognition so only text detection runs\n",
"! paddleocr --image_dir ./12.jpg --rec false"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"In addition to the command-line interface, paddleocr can also be called from Python code, as shown below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2021/12/22 21:07:58] root WARNING: version 2.1 not support cls models, use version 2.0 instead\n",
"Namespace(benchmark=False, cls_batch_num=6, cls_image_shape='3, 48, 192', cls_model_dir='/home/aistudio/.paddleocr/2.2.1/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer', cls_thresh=0.9, cpu_threads=10, det=True, det_algorithm='DB', det_db_box_thresh=0.6, det_db_score_mode='fast', det_db_thresh=0.3, det_db_unclip_ratio=1.5, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, det_east_score_thresh=0.8, det_limit_side_len=960, det_limit_type='max', det_model_dir='/home/aistudio/.paddleocr/2.2.1/ocr/det/ch/ch_PP-OCRv2_det_infer', det_sast_nms_thresh=0.2, det_sast_polygon=False, det_sast_score_thresh=0.5, drop_score=0.5, e2e_algorithm='PGNet', e2e_char_dict_path='./ppocr/utils/ic15_dict.txt', e2e_limit_side_len=768, e2e_limit_type='max', e2e_model_dir=None, e2e_pgnet_mode='fast', e2e_pgnet_polygon=True, e2e_pgnet_score_thresh=0.5, e2e_pgnet_valid_set='totaltext', enable_mkldnn=False, gpu_mem=500, help='==SUPPRESS==', image_dir=None, ir_optim=True, label_list=['0', '180'], lang='ch', layout_path_model='lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config', max_batch_size=10, max_text_length=25, min_subgraph_size=15, output='./output/table', precision='fp32', process_id=0, rec=True, rec_algorithm='CRNN', rec_batch_num=6, rec_char_dict_path='/home/aistudio/PaddleOCR/ppocr/utils/ppocr_keys_v1.txt', rec_char_type='ch', rec_image_shape='3, 32, 320', rec_model_dir='/home/aistudio/.paddleocr/2.2.1/ocr/rec/ch/ch_PP-OCRv2_rec_infer', save_log_path='./log_output/', show_log=True, table_char_dict_path=None, table_char_type='en', table_max_len=488, table_model_dir=None, total_process_num=1, type='ocr', use_angle_cls=False, use_dilation=False, use_gpu=True, use_mp=False, use_pdserving=False, use_space_char=True, use_tensorrt=False, version='2.1', vis_font_path='./doc/fonts/simfang.ttf', warmup=True)\n",
"[2021/12/22 21:07:59] root WARNING: Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process\n",
"The predicted text box of ./12.jpg are follows.\n",
"[[[79.0, 555.0], [398.0, 542.0], [399.0, 571.0], [80.0, 584.0]], [[21.0, 507.0], [512.0, 491.0], [513.0, 532.0], [22.0, 548.0]], [[174.0, 458.0], [397.0, 449.0], [398.0, 480.0], [175.0, 489.0]], [[42.0, 414.0], [482.0, 392.0], [484.0, 428.0], [44.0, 450.0]]]\n"
]
}
],
"source": [
"# 1. Import the PaddleOCR class from paddleocr\n",
"from paddleocr import PaddleOCR\n",
"\n",
"# 2. Instantiate the PaddleOCR class\n",
"ocr = PaddleOCR()\n",
"img_path = './12.jpg'\n",
"# 3. Run prediction; rec=False returns detection boxes only\n",
"result = ocr.ocr(img_path, rec=False)\n",
"print(f\"The predicted text boxes of {img_path} are as follows.\")\n",
"print(result)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**Visualize the text detection results**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff48841bcd0>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAJCCAYAAADEEWDaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvXm8nVV97/9eaz3THs88ZyQDSYCQMCUIiKBIRLFUVBw6iRa1DvdW0dbrWFtrW/21v1rb/q61DkXR61BRZhkVMAySBIgmZA4Zzzzs6ZnWWvePZ59DUNRQ+SHXuz+v1yacfZ6z99r7eZ7v9Pl+P0tYa2mhhRZaaKGFWcjf9AJaaKGFFlp4fqHlGFpooYUWWngKWo6hhRZaaKGFp6DlGFpooYUWWngKWo6hhRZaaKGFp6DlGFpooYUWWngKnnPHIITYIIR4XAixSwjx58/1+7fQQgsttPDLIZ7LOQYhhAJ2ABcBB4GHgNdba3/6nC2ihRZaaKGFX4rnOmM4C9hlrd1jrY2BrwO/8xyvoYUWWmihhV8C5zl+vyHgwDE/HwTWHXuAEOIq4CoA11Wnd3UXEUIQxQnlcjuptnR3D2DSlOnxozgSBJYs7xHYp/zXYizEKXT29OG6HqNHD+EIg7CQagPW4noeQkCWPVk0gkRbOjp78H2fifFhrE5wBAghsYA1FqSiEcX4+SKxNnR0dlOt1GgrFRg+cpQkiojCECkFQkhyOR8cRW9vH46UTE2OgkkQ1iCQgEAby2SlSt/gfJSU7Nm9C9/z6O3toxFFWKDc1kYu8Nm29SdIIcjlchSKRTzPRTkOxUIRxP/v57KF/4tgrQUhMMYAEIYNarU6juMQBAFe4COFxFiLEAKtDUIIho8exXFcAHp7u7EWtNaEjZB8voAxmtHRUZT06OnpRKrs/Y4OD9PX24exBiUlk5NThFGC67i4nkO5VJxdGVIIjIVqtQZCEMcxXZ3tAIyOjpHL5YmjCN/3KRbzT/lMQggmxscpl8tYmmsLQ8rlMkI+GTf/Jm8nYwyyuRZrDNPjYyAEjSgCKag3QhKtMYDrOlhjQECapBQKBaIwRAiB4zhMjI6PWWt7ftV7PteO4VfCWvs54HMAA4Nt9i1vfRHl9k527NnL+nNezHRV8pYrryathXz3c39NR6DxHIFFYKQCKVHNs5gKQ2wsW/fUuPwP3snipSfwH//0VxSiSXypcDwfnRqODI+Sy/l093SBsMQyYNfhaS6+7A2sXbuWb3z1X4gnnqDDkzjKRTiSMDEYGbBt31GCvkUcnkx5x/vez49+sJkXrlvLjd/8Bh2eYnJ4hJtuuolKZYaO7jZEMc+X/uMroGPuvPnrVId30Jl3iCsxWiumU/jh5m3845e+gdAJa1aeyAkLFuF7eV71+jcwVqtw6eWvQhjLx99zNXt37eTS33kFTxw+zJvf9CaMMZx99tkIpRASTPN7bXUZtPDrIE5THMfh85//PFe+5S0AWCwCQSNscPDQIfxcjp6BQay13HzL93lk8yO8/73v54n9B/jyNdfwoY98IAvy6hF33XEHr3j5y3E9iEL43Ge/yuo1qzj/orXESYKJXe67dyN33H0rH/jzq9l4/4MsXb6ayYkpHGU56aRluA4IYxBSMj1T58Zbb6Oru5eR4SO85tWX4buSVIM1sGfnfk48cSFS2qZ9EFhrsdayd+9eSqUSPT09aK2RUmbOrfl7IcTc41iI5u+AOYf5dMc9Y8xW9495mbBaIYljSu1lttx1B73z5/Gxv/lbDkyMcXR8Ar9UxCDJF3zGxkfI5/NUp6ssW7KEwwcPUSyUCMOQ+79/9/7jWcJzbS8OAfOP+Xle87mnhUDiuTmslgjrkCQJQlogJci5WCFQKgsxhABHSLASpRyEo1BKkfM9coFLrV7FWoufy2HQGBJSHaFtwsBAH7lcjkMHDzMyPEq1WiXn+8SNEK01SqmneG1js5
tESonWmkJQII4T4gYsXLiYmZkGrutSKBQYGhriTW96Ey+68AImxqeoTtd56x+/HWMsaWxwnAJhaLDKwSiJF+QplNoIPI8g8Dhh4XxmxkYxjTq33HADA909+Mqhu6uLpStWMG/hQpacuJI1p58GMLeu2Wuz5RBaeDYQxzEAR48eRQJxGKIQYAyFIMfyJSewYHCIQAiUtVx68QZ0FFLwBe3lMkkUEXgSTwp27/gpbaUCAoMUoE1MUMjRNzhAPUyo1kKsgEcf28boyBhtbUVeeP55PPboJo4eOUCprYDjAsIiBGhjKBTzvPIVl3DmutN45e+8HOUI0lQTxTFJajLbIZgz2vYYo75kyRKKxawykaYpUsrm8Zl9EcdkSsdCCMEsRyulnLMPvy6MaD5mn7CGm6+/gVKpDEJycHSMm++4k/3DR7GuQ6GjPVtnkjI+OoorXSpTFXSScO6557Jo4WIOHTrEzMzMca/hubYbDwHLhBCLhRAe8Drge7/o4DhO0KGi4LaTD8pgHExsidMkM34OWAlIjRAmi5BtihUiO1EKlFKUCnmSqAHCEAQB0nHwfIVQBiEN0gHHVfT19dHe1kGjWuPo0aNMTk4ihEAKhTVZ5DF7gUiZXQxRGBOFMTo2zExV8JRHWzFHHIY4jiKXz1Nua+P0M8/kLVe9lUK+zK5du7jvvvuoRympUThuCStcpBPQiDS9ff0YY/Cl4l1vvYqu9jK1mWnGjxzhpu9eR297G6UgR7lcZuEJi+nq6mLz5s04UjLQ15d9ecf0FLScQwu/LvL5PNZakiRBAvkgQAKOlFkR1FqUtSjAUxIlLH/18Y8iyO7BxYsWMDMxTRo1OOvM03jJi19Eo1EhjiMC36Wnu4P2cglXKTrbSygJg4Nd/NVffwxrNTnf5eKXvIiXb7gAR2jq1RoSMFaDsCgFu/fsAqOJozrSglQC3/dQCHK5gCiyzSwBBFlGgJQkWqOtJTUGx/PQ1iId5ymR/2wAeixmM47h4WHCMCRN018/W2jiSaeQlfB+97WX8/EPfoCvfu5z7Ny/j9vvvYehJUuYqFQJCnkk0BbkaMzUCDyPYj5PsVDiputv4uGHH8b3ffpmbcNx4Dm1GdbaFHgncCuwDfiGtfYnv+h4nRq2PbqTTQ89QjijUcano9xNXE+QKIywGBKMTZCOIdGZ8RfSIpVCSolFI5VgbGwMIQSltjaEcuf4BMeVGJsglMFYjRCCjs52ujramJ6awFpLW7GNXFAgSQHhorUgTQ3GWIIgTxAE9PX0IAz4fkCaZo4mTqK51/S9gK6eXt7wut9nbHSCl1z8UgbnL+Yr115HLVaEqYuWPtINMNLBdyVSCsJGjav/+38Dq9FRg72Pb+fu79+G70CxVKBcLtPb28s9d9/D/T/amF2YQrQ4hhZ+baRp+pSfhRCcfvrpP/c81mINmaGdPRYIw4hEa8odJf7kT95CsZAj53tIbZECSoU8nucihMVxFLV6hTCqE6cRXgDnv3g9fX1tSCGRAh7ZspnDhw+zaOF8SsUCjbCBVA5SZnziSStPJAproA0WgxSSNI74u0/9DY899hh+8GRJaHp6hiNHjnDw4EGiJv9wbKZgjEFrDUC9XidJkrl/a7UaP/rRj/jCF77ANddcw9TUFNVqlUqlQhRFwJMB5OxrPBMYYzOnl6ZYo7nvBz/g/VdfTb6tzLdu+C7fv/de8p2d3P/jhxCOImzEVGZqzFSrFPN5+rr76OvuY3J8guHhYaSUhGHIoUO/sDjzc3jOOQZr7U3ATcdzrOd6LBicT60RsXfPfh7dup3zLnwZa9cM01Fsw1oXYzXGWHCzC9fzHaSjkEJirUAAjoIwrGOMoaOrk9pBj8BRpDY7aVYKrJYkSQoWrImxJjPonnJwHY9NmzZzwbozAUmqdWZ8LRQKBYqFAk6titYpQZA5I60THFcibHbDWKFQykXImJ7BAbSAUmcPMw34yte+x9nrz2DlquUIHRHkLEmS4ipDvV7H7+
7jghe/mJtuuRUvl+cv/+KjvHTDBoJ8niCfZ3BwkEA5BJ7H49u2M3/RIjCWObKlhRb+C3CczDykacqmxx5lcnKShzY
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"# Needed so that matplotlib figures are displayed inline in the notebook\n",
"%matplotlib inline\n",
"\n",
"# 4. Visualize the detection results\n",
"image = cv2.imread(img_path)\n",
"# With rec=False, each entry of result is one box: four [x, y] points\n",
"for box in result:\n",
"    box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)\n",
"    image = cv2.polylines(image, [box], True, (255, 0, 0), 2)\n",
"\n",
"# OpenCV loads images as BGR; convert to RGB before displaying with matplotlib\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 2. DB Text Detection Algorithm in Detail\n",
"\n",
"\n",
"## 2.1 Principles of the DB Algorithm\n",
"\n",
"\n",
"[DB](https://arxiv.org/pdf/1911.08947.pdf) is a segmentation-based text detection algorithm. Its key contribution is a Differentiable Binarization (DB) module that predicts a dynamic, per-pixel threshold to separate text regions from the background.\n",
"\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/5eabdb59916a4267a049e5440f5093a63b6bfac9010844fb971aad0607d455a1\" width = \"600\"></center>\n",
"<center><br>Figure 1: DB compared with other segmentation-based methods</center>\n",
"<br>\n",
"\n",
"A conventional segmentation-based text detector follows the blue arrows in the figure above: after obtaining the segmentation result, it binarizes it with a fixed threshold and then extracts text regions with heuristics such as pixel clustering.\n",
"\n",
"The DB pipeline follows the red arrows. The biggest difference is the threshold map: the network predicts a threshold at every position of the image instead of using one fixed value, which separates foreground text from background more cleanly.\n",
"\n",
"The DB algorithm has the following advantages:\n",
"1. A simple structure that needs no elaborate post-processing\n",
"2. Good accuracy and speed on public datasets\n",
"\n",
"\n",
"Traditional segmentation pipelines apply standard binarization to the probability map, setting pixels below a threshold to 0 and pixels above it to 1:\n",
"$$ B_{i,j}=\\begin{cases}\n",
"1, & \\text{if } P_{i,j} \\ge t, \\\\\\\\\n",
"0, & \\text{otherwise.}\n",
"\\end{cases}\n",
"$$\n",
"Standard binarization is not differentiable, so the network cannot be trained end to end through it. To solve this, DB proposes Differentiable Binarization, which approximates the step function of standard binarization with\n",
"$$\n",
"\\hat{B} = \\frac{1}{1 + e^{-k(P_{i,j}-T_{i,j})}}\n",
"$$\n",
"where $P$ is the probability map obtained above, $T$ is the threshold map, and $k$ is an amplifying factor, set empirically to 50. Standard and differentiable binarization are compared in **Figure 2(a)** below.\n",
"\n",
"With the cross-entropy loss, the losses for positive and negative samples, $l_+$ and $l_-$, are:\n",
"$$\n",
"l_+ = -\\log\\left(\\frac{1}{1 + e^{-k(P_{i,j}-T_{i,j})}}\\right)\n",
"$$\n",
"$$\n",
"l_- = -\\log\\left(1-\\frac{1}{1 + e^{-k(P_{i,j}-T_{i,j})}}\\right)\n",
"$$\n",
"Differentiating with respect to the input $x = P_{i,j}-T_{i,j}$ gives:\n",
"$$\n",
"\\frac{\\partial l_+}{\\partial x} = -kf(x)e^{-kx}\n",
"$$\n",
"$$\n",
"\\frac{\\partial l_-}{\\partial x} = kf(x)\n",
"$$\n",
"The amplifying factor magnifies the gradient of wrongly predicted pixels, which drives the model towards better results. In **Figure 2(b)**, the region $x<0$ corresponds to positive samples predicted as negative; the gain factor $k$ amplifies the gradient there. Likewise, in **Figure 2(c)** the region $x>0$ corresponds to negative samples predicted as positive, and the gradient is amplified in the same way.\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/29255d870bd74403af37c8f88cb10ebca0c3117282614774a3d607efc8be8c84\" width = \"600\"></center>\n",
"<center><br>Figure 2: Illustration of differentiable binarization</center>\n",
"<br>\n",
"\n",
"\n",
"\n",
"The overall structure of the DB algorithm is shown below:\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/6e1f293e9a1f4c90b6c26919f16b95a4a85dcf7be73f4cc99c9dc5477bb956e6\" width = \"1000\"></center>\n",
"<center><br>Figure 3: DB network architecture</center>\n",
"<br>\n",
"\n",
"The input image passes through the backbone and an FPN to extract features; the multi-scale features are concatenated into a feature map at one quarter of the input resolution. Convolutional layers then predict the text-region probability map and the threshold map, and DB post-processing produces the text boundary curves.\n"
]
},
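{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"To make the formulas above concrete, the following small NumPy sketch (the function name and sample values are illustrative, not part of PaddleOCR) evaluates differentiable binarization for a few pixels and shows how the gain factor k pushes the output towards 0 or 1:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def diff_binarize(P, T, k=50):\n",
"    # B_hat = 1 / (1 + exp(-k * (P - T)))\n",
"    return 1.0 / (1.0 + np.exp(-k * (P - T)))\n",
"\n",
"P = np.array([0.9, 0.4, 0.1])  # predicted probabilities\n",
"T = np.array([0.3, 0.3, 0.3])  # per-pixel thresholds\n",
"print(diff_binarize(P, T))       # above-threshold pixels approach 1, below-threshold approach 0\n",
"print(diff_binarize(P, T, k=1))  # a small k gives a much softer transition\n",
"```"
]
},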
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 2.2 Building the DB Text Detection Model\n",
"\n",
"\n",
"A DB text detection model consists of three parts:\n",
"- Backbone network: extracts image features\n",
"- FPN network: a feature-pyramid structure that enhances the features\n",
"- Head network: predicts the text-region probability map\n",
"\n",
"This section implements each of the three modules with PaddlePaddle and then assembles the complete network.\n",
"\n",
"\n",
"**Backbone network**\n",
"\n",
"The DB backbone is an image classification network. The paper uses ResNet-50; to speed up training, this experiment uses the MobileNetV3-large architecture as the backbone instead.\n"
]
},
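{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"Conceptually the three modules are chained one after another. The sketch below illustrates the forward pass (the class name and constructor arguments are illustrative; the real PaddleOCR implementation wires the modules together from a YAML config):\n",
"\n",
"```python\n",
"import paddle.nn as nn\n",
"\n",
"class DBModel(nn.Layer):\n",
"    def __init__(self, backbone, neck, head):\n",
"        super().__init__()\n",
"        self.backbone = backbone  # e.g. MobileNetV3: image -> 4 multi-scale feature maps\n",
"        self.neck = neck          # FPN: 4 feature maps -> fused 1/4-scale feature map\n",
"        self.head = head          # DB head: feature map -> probability / threshold maps\n",
"\n",
"    def forward(self, x):\n",
"        return self.head(self.neck(self.backbone(x)))\n",
"```"
]
},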
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Requirement already satisfied: pip in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (21.3.1)\n",
"Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
"Requirement already satisfied: shapely in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.8.0)\n",
"Requirement already satisfied: scikit-image==0.17.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (0.17.2)\n",
"Requirement already satisfied: imgaug==0.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (0.4.0)\n",
"Requirement already satisfied: pyclipper in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (1.3.0.post2)\n",
"Requirement already satisfied: lmdb in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (1.2.1)\n",
"Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (4.27.0)\n",
"Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (1.20.3)\n",
"Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (2.2.0)\n",
"Requirement already satisfied: python-Levenshtein in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 9)) (0.12.2)\n",
"Requirement already satisfied: opencv-contrib-python==4.4.0.46 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 10)) (4.4.0.46)\n",
"Requirement already satisfied: lxml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 11)) (4.7.1)\n",
"Requirement already satisfied: premailer in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 12)) (3.10.0)\n",
"Requirement already satisfied: openpyxl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 13)) (3.0.5)\n",
"Requirement already satisfied: pillow!=7.1.0,!=7.1.1,>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (7.1.2)\n",
"Requirement already satisfied: matplotlib!=3.0.0,>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.2.3)\n",
"Requirement already satisfied: scipy>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.6.3)\n",
"Requirement already satisfied: PyWavelets>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (1.2.0)\n",
"Requirement already satisfied: tifffile>=2019.7.26 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2021.11.2)\n",
"Requirement already satisfied: networkx>=2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4)\n",
"Requirement already satisfied: imageio>=2.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-image==0.17.2->-r requirements.txt (line 2)) (2.6.1)\n",
"Requirement already satisfied: opencv-python in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (4.1.1.26)\n",
"Requirement already satisfied: six in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from imgaug==0.4.0->-r requirements.txt (line 3)) (1.15.0)\n",
"Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (2.22.0)\n",
"Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.8.2)\n",
"Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.21.0)\n",
"Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.1)\n",
"Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.8.53)\n",
"Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.0.0)\n",
"Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (3.14.0)\n",
"Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (0.7.1.1)\n",
"Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl->-r requirements.txt (line 8)) (1.1.5)\n",
"Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from python-Levenshtein->-r requirements.txt (line 9)) (56.2.0)\n",
"Requirement already satisfied: cssutils in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (2.3.0)\n",
"Requirement already satisfied: cachetools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (4.0.0)\n",
"Requirement already satisfied: cssselect in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from premailer->-r requirements.txt (line 12)) (1.1.0)\n",
"Requirement already satisfied: jdcal in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.4.1)\n",
"Requirement already satisfied: et-xmlfile in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from openpyxl->-r requirements.txt (line 13)) (1.0.1)\n",
"Requirement already satisfied: pycodestyle<2.7.0,>=2.6.0a1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.6.0)\n",
"Requirement already satisfied: pyflakes<2.3.0,>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (2.2.0)\n",
"Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.6.1)\n",
"Requirement already satisfied: importlib-metadata in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (0.23)\n",
"Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.0)\n",
"Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (0.16.0)\n",
"Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (7.0)\n",
"Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (2.11.0)\n",
"Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2.8.0)\n",
"Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl->-r requirements.txt (line 8)) (2019.3)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (1.1.0)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.8.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (2.4.2)\n",
"Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib!=3.0.0,>=2.0.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (0.10.0)\n",
"Requirement already satisfied: decorator>=4.3.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from networkx>=2.0->scikit-image==0.17.2->-r requirements.txt (line 2)) (4.4.2)\n",
"Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (3.9.9)\n",
"Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl->-r requirements.txt (line 8)) (0.18.0)\n",
"Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (16.7.9)\n",
"Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.4.10)\n",
"Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (2.0.1)\n",
"Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (0.10.0)\n",
"Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (5.1.2)\n",
"Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.0)\n",
"Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl->-r requirements.txt (line 8)) (1.3.4)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2.8)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (1.25.6)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl->-r requirements.txt (line 8)) (2019.9.11)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl->-r requirements.txt (line 8)) (1.1.1)\n",
"Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata->flake8>=3.7.9->visualdl->-r requirements.txt (line 8)) (3.6.0)\n"
]
}
],
"source": [
"# On first run, uncomment the next line to download the PaddleOCR code\n",
"#!git clone https://gitee.com/paddlepaddle/PaddleOCR\n",
"import os\n",
"# Change the default working directory to /home/aistudio/PaddleOCR\n",
"os.chdir(\"/home/aistudio/PaddleOCR\")\n",
"# Install the third-party dependencies of PaddleOCR\n",
"!pip install --upgrade pip\n",
"!pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/backbones/det_mobilenet_v3.py\n",
"from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"If you prefer ResNet as the backbone, you can select [ResNet](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/backbones/det_resnet_vd.py) from the PaddleOCR code, or pick a backbone model from [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.0/ppcls/modeling/architectures).\n",
"\n",
"\n",
"The DB backbone extracts multi-scale features from the image. As the code below shows, for an input of shape [640, 640] the backbone outputs four feature maps with shapes [1, 16, 160, 160], [1, 24, 80, 80], [1, 56, 40, 40], and [1, 480, 20, 20].\n",
"These features are then fed into the feature pyramid network (FPN) to be further enhanced."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MobileNetV3(\n",
" (conv): ConvBNLayer(\n",
" (conv): Conv2D(3, 8, kernel_size=[3, 3], stride=[2, 2], padding=1, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (stage0): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[3, 3], padding=1, groups=8, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 8, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(8, 32, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(32, 32, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=32, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(32, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 40, kernel_size=[3, 3], padding=1, groups=40, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage1): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(16, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 40, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=40, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(40, 10, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(10, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(64, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(16, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 64, kernel_size=[5, 5], padding=2, groups=64, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(64, 16, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(16, 64, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(64, 24, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage2): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(24, 120, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(120, 120, kernel_size=[3, 3], stride=[2, 2], padding=1, groups=120, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(120, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 104, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(104, 104, kernel_size=[3, 3], padding=1, groups=104, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(104, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (3): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 96, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 96, kernel_size=[3, 3], padding=1, groups=96, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(96, 40, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (4): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(40, 240, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(240, 240, kernel_size=[3, 3], padding=1, groups=240, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(240, 60, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(60, 240, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(240, 56, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (5): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 336, kernel_size=[3, 3], padding=1, groups=336, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(336, 84, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(84, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 56, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" )\n",
" (stage3): Sequential(\n",
" (0): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(56, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 336, kernel_size=[5, 5], stride=[2, 2], padding=2, groups=336, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(336, 84, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(84, 336, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(336, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (1): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(480, 120, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(120, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (2): ResidualUnit(\n",
" (expand_conv): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (bottleneck_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 480, kernel_size=[5, 5], padding=2, groups=480, data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" (mid_se): SEModule(\n",
" (avg_pool): AdaptiveAvgPool2D(output_size=1)\n",
" (conv1): Conv2D(480, 120, kernel_size=[1, 1], data_format=NCHW)\n",
" (conv2): Conv2D(120, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" )\n",
" (linear_conv): ConvBNLayer(\n",
" (conv): Conv2D(480, 80, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
" (3): ConvBNLayer(\n",
" (conv): Conv2D(80, 480, kernel_size=[1, 1], data_format=NCHW)\n",
" (bn): BatchNorm()\n",
" )\n",
" )\n",
")\n",
"The index is 0 and the shape of output is [1, 16, 160, 160]\n",
"The index is 1 and the shape of output is [1, 24, 80, 80]\n",
"The index is 2 and the shape of output is [1, 56, 40, 40]\n",
"The index is 3 and the shape of output is [1, 480, 20, 20]\n"
]
}
],
"source": [
"import paddle \n",
"\n",
"fake_inputs = paddle.randn([1, 3, 640, 640], dtype=\"float32\")\n",
"\n",
"# 1. 声明Backbone\n",
"model_backbone = MobileNetV3()\n",
"model_backbone.eval()\n",
"\n",
"# 2. 执行预测\n",
"outs = model_backbone(fake_inputs)\n",
"\n",
"# 3. 打印网络结构\n",
"print(model_backbone)\n",
"\n",
"# 4. 打印输出特征形状\n",
"for idx, out in enumerate(outs):\n",
" print(\"The index is \", idx, \"and the shape of output is \", out.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**FPN网络**\n",
"\n",
    "特征金字塔FPNFeature Pyramid Network是一种利用卷积网络高效提取图片中多尺度特征的常用结构。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/necks/db_fpn.py\n",
"\n",
"import paddle\n",
"from paddle import nn\n",
"import paddle.nn.functional as F\n",
"from paddle import ParamAttr\n",
"\n",
"class DBFPN(nn.Layer):\n",
" def __init__(self, in_channels, out_channels, **kwargs):\n",
" super(DBFPN, self).__init__()\n",
" self.out_channels = out_channels\n",
"\n",
    "        # DBFPN详细实现参考 https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/necks/db_fpn.py\n",
"\n",
" def forward(self, x):\n",
" c2, c3, c4, c5 = x\n",
"\n",
" in5 = self.in5_conv(c5)\n",
" in4 = self.in4_conv(c4)\n",
" in3 = self.in3_conv(c3)\n",
" in2 = self.in2_conv(c2)\n",
"\n",
" # 特征上采样\n",
" out4 = in4 + F.upsample(\n",
" in5, scale_factor=2, mode=\"nearest\", align_mode=1) # 1/16\n",
" out3 = in3 + F.upsample(\n",
" out4, scale_factor=2, mode=\"nearest\", align_mode=1) # 1/8\n",
" out2 = in2 + F.upsample(\n",
" out3, scale_factor=2, mode=\"nearest\", align_mode=1) # 1/4\n",
"\n",
" p5 = self.p5_conv(in5)\n",
" p4 = self.p4_conv(out4)\n",
" p3 = self.p3_conv(out3)\n",
" p2 = self.p2_conv(out2)\n",
"\n",
" # 特征上采样\n",
" p5 = F.upsample(p5, scale_factor=8, mode=\"nearest\", align_mode=1)\n",
" p4 = F.upsample(p4, scale_factor=4, mode=\"nearest\", align_mode=1)\n",
" p3 = F.upsample(p3, scale_factor=2, mode=\"nearest\", align_mode=1)\n",
"\n",
" fuse = paddle.concat([p5, p4, p3, p2], axis=1)\n",
" return fuse"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
    "FPN网络的输入为Backbone部分的输出输出特征图的高度和宽度为原图的四分之一。假设输入图像的形状为[1, 3, 640, 640]则FPN输出特征的高度和宽度为[160, 160]。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DBFPN(\n",
" (in2_conv): Conv2D(16, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (in3_conv): Conv2D(24, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (in4_conv): Conv2D(56, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (in5_conv): Conv2D(480, 256, kernel_size=[1, 1], data_format=NCHW)\n",
" (p5_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p4_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p3_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (p2_conv): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
")\n",
"The shape of fpn outs [1, 256, 160, 160]\n"
]
}
],
"source": [
"\n",
"import paddle \n",
"\n",
"# 1. 从PaddleOCR中import DBFPN\n",
"from ppocr.modeling.necks.db_fpn import DBFPN\n",
"\n",
"# 2. 获得Backbone网络输出结果\n",
"fake_inputs = paddle.randn([1, 3, 640, 640], dtype=\"float32\")\n",
"model_backbone = MobileNetV3()\n",
    "in_channels = model_backbone.out_channels\n",
    "\n",
    "# 3. 声明FPN网络\n",
    "model_fpn = DBFPN(in_channels=in_channels, out_channels=256)\n",
"\n",
"# 4. 打印FPN网络\n",
"print(model_fpn)\n",
"\n",
"# 5. 计算得到FPN结果输出\n",
"outs = model_backbone(fake_inputs)\n",
"fpn_outs = model_fpn(outs)\n",
"\n",
"# 6. 打印FPN输出特征形状\n",
"print(f\"The shape of fpn outs {fpn_outs.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**Head网络**\n",
"\n",
    "计算文本区域概率图、文本区域阈值图以及文本区域二值图。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\n",
"import paddle\n",
"from paddle import nn\n",
"import paddle.nn.functional as F\n",
"from paddle import ParamAttr\n",
"\n",
"class DBHead(nn.Layer):\n",
" \"\"\"\n",
" Differentiable Binarization (DB) for text detection:\n",
" see https://arxiv.org/abs/1911.08947\n",
" args:\n",
" params(dict): super parameters for build DB network\n",
" \"\"\"\n",
"\n",
" def __init__(self, in_channels, k=50, **kwargs):\n",
" super(DBHead, self).__init__()\n",
" self.k = k\n",
"\n",
" # DBHead详细实现参考 https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/modeling/heads/det_db_head.py\n",
"\n",
" def step_function(self, x, y):\n",
" # 可微二值化实现,通过概率图和阈值图计算文本分割二值图\n",
" return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))\n",
"\n",
" def forward(self, x, targets=None):\n",
" shrink_maps = self.binarize(x)\n",
" if not self.training:\n",
" return {'maps': shrink_maps}\n",
"\n",
" threshold_maps = self.thresh(x)\n",
" binary_maps = self.step_function(shrink_maps, threshold_maps)\n",
" y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)\n",
" return {'maps': y}"
]
},
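  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "上面 `step_function` 实现的可微二值化对应 DB 论文https://arxiv.org/abs/1911.08947中的近似二值化公式其中 $k$ 为放大因子,代码里默认取 50\n",
    "\n",
    "$$\\hat{B}_{i,j} = \\frac{1}{1 + e^{-k(P_{i,j} - T_{i,j})}}$$\n",
    "\n",
    "其中 $P$ 为概率图,$T$ 为阈值图,$\\hat{B}$ 为近似二值图。该函数处处可导,因此训练时梯度可以同时回传到概率图与阈值图两个分支。"
   ]
  },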
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
    "DB Head网络会在FPN特征的基础上进行上采样将FPN特征由原图的四分之一大小映射到原图大小。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DBHead(\n",
" (binarize): Head(\n",
" (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
" (thresh): Head(\n",
" (conv1): Conv2D(256, 64, kernel_size=[3, 3], padding=1, data_format=NCHW)\n",
" (conv_bn1): BatchNorm()\n",
" (conv2): Conv2DTranspose(64, 64, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" (conv_bn2): BatchNorm()\n",
" (conv3): Conv2DTranspose(64, 1, kernel_size=[2, 2], stride=[2, 2], data_format=NCHW)\n",
" )\n",
")\n",
"The shape of fpn outs [1, 256, 160, 160]\n",
"The shape of DB head outs [1, 3, 640, 640]\n"
]
}
],
"source": [
    "# 1. 从PaddleOCR中import DBHead\n",
"from ppocr.modeling.heads.det_db_head import DBHead\n",
"import paddle \n",
"\n",
"# 2. 计算DBFPN网络输出结果\n",
"fake_inputs = paddle.randn([1, 3, 640, 640], dtype=\"float32\")\n",
"model_backbone = MobileNetV3()\n",
    "in_channels = model_backbone.out_channels\n",
    "model_fpn = DBFPN(in_channels=in_channels, out_channels=256)\n",
"outs = model_backbone(fake_inputs)\n",
"fpn_outs = model_fpn(outs)\n",
"\n",
"# 3. 声明Head网络\n",
"model_db_head = DBHead(in_channels=256)\n",
"\n",
"# 4. 打印DBhead网络\n",
"print(model_db_head)\n",
"\n",
"# 5. 计算Head网络的输出\n",
"db_head_outs = model_db_head(fpn_outs)\n",
"print(f\"The shape of fpn outs {fpn_outs.shape}\")\n",
"print(f\"The shape of DB head outs {db_head_outs['maps'].shape}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
    "# 3. 训练DB文字检测模型\n",
"PaddleOCR提供DB文本检测算法支持MobileNetV3、ResNet50_vd两种骨干网络可以根据需要选择相应的配置文件启动训练。\n",
"\n",
    "本节以icdar15数据集为例使用以MobileNetV3为骨干网络的DB检测模型即超轻量模型使用的配置介绍如何完成PaddleOCR中文字检测模型的训练、评估与测试。\n",
"\n",
"## 3.1 数据准备\n",
"\n",
"本次实验选取了场景文本检测和识别(Scene Text Detection and Recognition)任务最知名和常用的数据集ICDAR2015。icdar2015数据集的示意图如下图所示\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/e1b06e0c8e904a2aa412e9eea41f45cce3d58543232948fa88200298fd3cd2e4\" width = \"600\"></center>\n",
"<center><br>图 icdar2015数据集示意图 </br></center>\n",
"<br></br>\n",
"\n",
    "该项目中已经下载了icdar2015数据集存放在 /home/aistudio/data/data96799 中,可以运行如下指令完成数据集解压,也可以自行下载该数据集。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"!cd ~/data/data96799/ && tar xf icdar2015.tar "
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"运行上述指令后 ~/train_data/icdar2015/text_localization 有两个文件夹和两个文件,分别是:\n",
"```\n",
"~/train_data/icdar2015/text_localization \n",
" └─ icdar_c4_train_imgs/ icdar数据集的训练数据\n",
" └─ ch4_test_images/ icdar数据集的测试数据\n",
" └─ train_icdar2015_label.txt icdar数据集的训练标注\n",
" └─ test_icdar2015_label.txt icdar数据集的测试标注\n",
"```\n",
"提供的标注文件格式为:\n",
"```\n",
"\" 图像文件名 json.dumps编码的图像标注信息\"\n",
"ch4_test_images/img_61.jpg [{\"transcription\": \"MASA\", \"points\": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]\n",
"```\n",
"\n",
    "json.dumps编码前的图像标注信息是包含多个字典的list字典中的points字段表示文本框四个顶点的坐标(x, y)从左上角的顶点开始按顺时针排列transcription字段表示当前文本框的文字文本检测任务中并不需要这个信息。如果您想在其他数据集上训练PaddleOCR可以按照上述形式构建标注文件。\n",
"\n",
    "如果\"transcription\"字段的文字为'*'或者'###',表示对应的标注可以被忽略掉;因此,如果没有文字标签,可以将transcription字段设置为空字符串。\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.2 数据预处理\n",
"\n",
    "训练时对输入图片的格式、大小有一定的要求,同时,还需要根据标注信息获取阈值图以及概率图的真实标签。所以,在数据输入模型前,需要对数据进行预处理操作,使得图片和标签满足网络训练和预测的需要。另外,为了扩大训练数据集、抑制过拟合、提升模型的泛化能力,还需要使用几种基础的数据增广方法。\n",
"\n",
"本实验的数据预处理共包括如下方法:\n",
"\n",
    "- 图像解码将图像转为Numpy格式\n",
    "- 标签编码解析txt文件中的标签信息并按统一格式进行保存\n",
    "- 基础数据增广:包括随机水平翻转、随机旋转、随机缩放、随机裁剪等;\n",
    "- 获取阈值图标签:使用扩张的方式获取算法训练需要的阈值图标签;\n",
    "- 获取概率图标签:使用收缩的方式获取算法训练需要的概率图标签;\n",
    "- 归一化把神经网络每层中任意神经元的输入值分布变换为均值为0、方差为1的标准正态分布使得寻优过程更加平缓训练过程更容易收敛\n",
    "- 通道变换:图像的数据格式为[H, W, C](即高度、宽度和通道数),而神经网络使用的训练数据的格式为[C, H, W],因此需要对图像数据重新排列,例如[224, 224, 3]变为[3, 224, 224]。\n",
"\n",
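    "其中归一化这一步可以写成如下公式mean、std 为逐通道的均值和标准差,常见取值为 ImageNet 统计值,实际请以配置文件为准):\n",
    "\n",
    "$$x' = \\frac{x / 255 - mean}{std}$$\n",
    "\n",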
"\n",
"**图像解码**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import sys\n",
"import six\n",
"import cv2\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/operators.py\n",
"class DecodeImage(object):\n",
" \"\"\" decode image \"\"\"\n",
"\n",
" def __init__(self, img_mode='RGB', channel_first=False, **kwargs):\n",
" self.img_mode = img_mode\n",
" self.channel_first = channel_first\n",
"\n",
" def __call__(self, data):\n",
" img = data['image']\n",
" if six.PY2:\n",
" assert type(img) is str and len(\n",
" img) > 0, \"invalid input 'img' in DecodeImage\"\n",
" else:\n",
" assert type(img) is bytes and len(\n",
" img) > 0, \"invalid input 'img' in DecodeImage\"\n",
" # 1. 图像解码\n",
" img = np.frombuffer(img, dtype='uint8')\n",
" img = cv2.imdecode(img, 1)\n",
"\n",
" if img is None:\n",
" return None\n",
" if self.img_mode == 'GRAY':\n",
" img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)\n",
" elif self.img_mode == 'RGB':\n",
" assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)\n",
" img = img[:, :, ::-1]\n",
"\n",
" if self.channel_first:\n",
" img = img.transpose((2, 0, 1))\n",
" # 2. 解码后的图像放在字典中\n",
" data['image'] = img\n",
" return data"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"接下来从训练数据的标注中读取图像演示DecodeImage类的使用方式。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The first data in train_icdar2015_label.txt is as follows.\n",
" icdar_c4_train_imgs/img_61.jpg\t[{\"transcription\": \"###\", \"points\": [[427, 293], [469, 293], [468, 315], [425, 314]]}, {\"transcription\": \"###\", \"points\": [[480, 291], [651, 289], [650, 311], [479, 313]]}, {\"transcription\": \"Ave\", \"points\": [[655, 287], [698, 287], [696, 309], [652, 309]]}, {\"transcription\": \"West\", \"points\": [[701, 285], [759, 285], [759, 308], [701, 308]]}, {\"transcription\": \"YOU\", \"points\": [[1044, 531], [1074, 536], [1076, 585], [1046, 579]]}, {\"transcription\": \"CAN\", \"points\": [[1077, 535], [1114, 539], [1117, 595], [1079, 585]]}, {\"transcription\": \"PAY\", \"points\": [[1119, 539], [1160, 543], [1158, 601], [1120, 593]]}, {\"transcription\": \"LESS?\", \"points\": [[1164, 542], [1252, 545], [1253, 624], [1166, 602]]}, {\"transcription\": \"Singapore's\", \"points\": [[1032, 177], [1185, 73], [1191, 143], [1038, 223]]}, {\"transcription\": \"no.1\", \"points\": [[1190, 73], [1270, 19], [1278, 91], [1194, 133]]}]\n",
"\n"
]
}
],
"source": [
"import json\n",
"import cv2\n",
"import os\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"# 在notebook中使用matplotlib.pyplot绘图时需要添加该命令进行显示\n",
"%matplotlib inline\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"\n",
"label_path = \"/home/aistudio/data/data96799/icdar2015/text_localization/train_icdar2015_label.txt\"\n",
"img_dir = \"/home/aistudio/data/data96799/icdar2015/text_localization/\"\n",
"\n",
"# 1. 读取训练标签的第一条数据\n",
    "with open(label_path, \"r\") as f:\n",
    "    lines = f.readlines()\n",
"\n",
"# 2. 取第一条数据\n",
"line = lines[0]\n",
"\n",
"print(\"The first data in train_icdar2015_label.txt is as follows.\\n\", line)\n",
"img_name, gt_label = line.strip().split(\"\\t\")\n",
"\n",
"# 3. 读取图像\n",
"image = open(os.path.join(img_dir, img_name), 'rb').read()\n",
"data = {'image': image, 'label': gt_label}\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"声明DecodeImage类解码图像并返回一个新的字典data。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shape of decoded image is (720, 1280, 3)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFdCAYAAAAwgXjMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvcuubcmSpvWZufsYc6619444ESdP3uoiQIgHKG5NQEJCokEXeIBq8QA8SzWgS5dOSfUM1QZUKIWUkJnnUieiduy91pxzDHczo2E+5tqRpKjTyFCegmmhpdjrMsfFh7v5b7/9ZkMigoc97GEPe9jDHvawh/3tmv5dX8DDHvawhz3sYQ972P8X7QGyHvawhz3sYQ972MN+AnuArIc97GEPe9jDHvawn8AeIOthD3vYwx72sIc97CewB8h62MMe9rCHPexhD/sJ7AGyHvawhz3sYQ972MN+AvtJQJaI/Bci8i9E5M9E5L//Kc7xsIc97GEPe9jDHvb7bPK33SdLRArwvwP/OfAXwD8H/puI+F//Vk/0sIc97GEPe9jDHvZ7bD8Fk/UfAn8WEf9HROzA/wT8Vz/BeR72sIc97GEPe9jDfm/tpwBZfwr8X198/xfzZw972MMe9rCHPexh/7+x+nd1YhH5x8A/BtBS/9Hz89eAUEphjIGIoqqYCAL55aA6/00QAhEgokTkH4kABIIdZ0JECI6/FQJHJX9OBBH5e5GJOSNQLUT4j675y8SqAKKCal4fbpgZ7pHXEuAeuDuB3z8VAaJ5byLcYa5bYGaAA4575PVI4OHz2gsiFRB03m9elc9rARVFVUBq3gMg+ByTyON5XoPkKB6jNP9zRIKIyE/EMYb5vcwjvY1H3I8i8wHcvxOZn8gxDXkbSTn+/otBFflrx5U5kMdAhSA/mrJBhM3zBYKCfHFeybExG1jvjDGOQcPc8QjCHQLWdeXd+3eEOzY6Nkb+zp37A83RYVhgTn6e43f+xf3KcZov/h2ICEtbWM8LOufAfX6+nYJ7Bl/e/hcRb4ec5pHHdc/7cHd67znngvu9MZ+pqs7nB4ISHljE2wkl76e2SmsNpMw18nYbcr8+v88NYo75/LzwdjPHXLjP/bdFOv8mfvTvYz7OBfnF/y3vKRx8/j+OeRVfzCO9z+RgToFj/svb+VQkl568ne64XvNg2J6jpErRer/xfA5vk/YYHwEQnfcWc8378S1fLnbVQq11jhnp2Mh7uD/7iPt43VeF3P9s/m6u7ekjEXk71RezRUWQ+YH7euTtGRzHF5kPORxVQZn+7Xhec/xE5D6X3n43n5VKfglQC5Qvfv//mMFvz+ztmr78GwEWcpsqf8Pn/Iuv8cW/53jNdYkb2FzHIWDyNp7H3305uT0gHA5fDrjnejmuy6f/4IvZDrlO86jHPpDnzLmreAQec0bORX+s7cOLxZy9Ms8V4Xc/k3uSzp/nsQ6/nHPjWEPyo7mUo5XnQvRH8+BYzxB5T188gnwib/ce8/dvPloQKZjPa5TIuRbHCuRHY+TTHWiUt+c952/eWwCWn5TDb8TdxwQjx0s0v9C51/o8voIUEEUlpsc+9mC/H9djzgdVzm0h5h4KgZRjDUOg/HjHe/NXP3z36bcR8Qd/fVb+dfspQNZfAn//i+//3vzZjywi/gnwTwA+fPg2/qP/+L+cm1SjlHRAFsLehKaNRQpLqZyqUOhIOK/2idP6DmFBdKGUAgXMN1ps8+FUdHmCUHZzaq308ZmlCpXAh1GXFaOALhQaAMuyMMYO4gk4IjCODSy/b61RSkFj5+cnoffB5XVn3wdmzuV15/X1lT6uDIdhkuCpNGpdcIz1K+W0PuFDefnhE8Fg9Cv7viO+ELpjtmNmtPpEK1/jCKKdIhAMzDe0GHVRnk5nPrz/A87vvoVyQqRQ2CjsIJ1yKry+dkQqy+kJcEIDcaGEUvxC0UGE4bKCVFyUKhAhEw
hCqCCUHB9xLITejd47H95/xbIsoMJ+2+beI3Qb97HDjad2RlXnoslpfIwvLdBqRHQiCj4KboUW70kn4yAd8ysQaHFUznlNKGbG6VwZdqWPC+Ple374+JFt3xnDcYe+G5fLhZfLxrfffsu//x/8I9YW2Hbj83ff47cberuybzfwnR5AO3PphV99d+H1OvAYeGwgnVqVUgpFFry+ASipSq2VWiv/6X/yn/H+mydqVdyhaEvQK8K+d1wLIenEzYxSckPe952TQpjf94qQgjtcbzds3Li+vvLx40c+//DCx48f2a473gclOqUGz0+NWvIZFir75nx6uSSOFEcKdDf+7X/v3+Ef/MN/iLEipaLTmRaB8MEYA48btZywURkdRILSdpCB+iDoqCr7GHhkcGCulGgTXDIBcp+bjYMYFuM+H8K4b+gRnTBj7J0Yxn69wsjAJnTDhuMjAyY1mUGSUKwiDJDOUgaEUUVprXEuhkbOv+GBK+wIl/3Gd59/iZtyWt/z/t2392sefUPHuM/XpSoyg7dSClEWIjpIB90opeDu1LIgpaBS+cXP/5jT6YxqUKqAdET3nP89CIPRnTBDKUg4ojlGpRTMgu3mMOCpCWggVTAGoaBVpl9aWEqlhKMCTaAcgaTUCQAtgzOM2jQxkW48t8ZJK+elUVvJv2kgZWdZFtZaEIV1UdwdDRBd4d0JTso4N+q3H+BphaJ4KSiau6zoBJYTEknN30XN30WFSMDq8sdofAM85eaJwvTRHv8SuKHccL4HLij7EZVAdBg39u9+Ay83lqhglbgWzAaiRmkBjHk9ORfYd9h2ZB+MHoQXPn1S+p7rTmhctp1t74DyGtDDiWEzkEtgO8ag904MYfTC3hvXLmwebDi7Oz6MsIGEYaIMCrtDc6eiOJLHEMcYmF8RWzHynH2kX5YQYh90CSzyGrfuoMK2bZQCN+sZkKGU0hAaIMRQbNxwPD+vGbApQhEh+g4elPC5VpxRHNNARrC0D1x2YeiG6I4GxGjoPvLzpREomzm7KwGcikIoIhWViseg71dgUBfDcIzp5wbYHiy1URvcto57YoWmzwiN3ge7fcLqwigNE2EpzoqxqFHcMetYOLoUZFG0Fs7nMz/76hssOmVxQp3leQXN/arHoDZl33fO68qw2zFr+Z//x3/2578LIPopQNY/B/5dEfm3SHD1XwP/7f/bB0RA/AqhtOpUMcZw1lKQvVPLkpuUVooISG5uz8vAxkdqPSN6pkibqD6jGhGdUb5QSqNEIua1NpoGRZ3hg6KG4lArYemgSg2GzehIHOgoA9FgDIMIfAyUChrc7AKitCfFxMGc6p0Wg9LBzNl6OvPwQbcbjhFXEIyvnr8megMan364sUghQglTqhcQoUZQ/AoqWN1y0piBGOHBdh3E2LCbodJoy2Bd17nJ5AZDVFpZsAi8Dyi5cUbkpMVmJCck+BIm05JRhkrgPlhcENkxlI6iBEtRFjnhtxt960gRQtMzBaB+RKOa4C0G4T+iEjI6c6e44ntGIBFGeCUsGHyex8hN2X0gmqxGKQUbGf3UVrndLrRWKbrAunB6OuE44RsqpINdhdt18O79mdvtBXNh3G5QYDk1bLwgNnAbYIGZI9I4LTu2C+aOWeDiKDmPRIxiS96nvjE9v/iTP6KsyuVyyQ05gqIJ1EXyeyUyolLFNUGnhLG2QgnLeNOFOCgSFdalouczblfWtXKpzrpWwgauio9CKUKo4zha8jxaB20djOEMM4IEgx/ef835/Mxty+huXU8o+VxszGssz5SyEq6YTZCkCcDVJR1oCUoBp4A0QJGQSRgkozpGghCRgrjmZivJClKdwBBxhpdkT1rOy6WeMOuEDcbudOlsvuG9Z8zqk3jRZ4oYRY2IkQxWmfuw5Hh4gEUkeZFDTykNt7jPRwDRg+1pk7SZzEI4yMEUOBEDkUCi4CZAo7WF53PldHri6VQR8eQrRq5fKYGEUjzZiRKkr4gBGvhkuSMs55sLuG
Be72xl6GRf1CEqIg3DkkcQGJNMk2CCNu6MWXCw11CkThB5MACSz0QrtSQLm4FTXjsypiOf2QNpiCpozC9BhWSU3NL
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 4. 声明DecodeImage类解码图像\n",
"decode_image = DecodeImage(img_mode='RGB', channel_first=False)\n",
"data = decode_image(data)\n",
"\n",
"# 5. 打印解码后图像的shape并可视化图像\n",
"print(\"The shape of decoded image is \", data['image'].shape)\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(data['image'])\n",
"src_img = data['image']"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**标签编码**\n",
"\n",
"解析txt文件中的标签信息并按统一格式进行保存"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import string\n",
"import json\n",
"\n",
"# 详细实现参考: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/label_ops.py#L38\n",
"class DetLabelEncode(object):\n",
" def __init__(self, **kwargs):\n",
" pass\n",
"\n",
" def __call__(self, data):\n",
" label = data['label']\n",
" # 1. 使用json读入标签\n",
" label = json.loads(label)\n",
" nBox = len(label)\n",
" boxes, txts, txt_tags = [], [], []\n",
" for bno in range(0, nBox):\n",
" box = label[bno]['points']\n",
" txt = label[bno]['transcription']\n",
" boxes.append(box)\n",
" txts.append(txt)\n",
" # 1.1 如果文本标注是*或者###,表示此标注无效\n",
" if txt in ['*', '###']:\n",
" txt_tags.append(True)\n",
" else:\n",
" txt_tags.append(False)\n",
" if len(boxes) == 0:\n",
" return None\n",
    "        # expand_points_num 将所有box补齐到相同的点数完整实现见上面的链接\n",
    "        boxes = self.expand_points_num(boxes)\n",
" boxes = np.array(boxes, dtype=np.float32)\n",
    "        txt_tags = np.array(txt_tags, dtype=bool)\n",
" \n",
" # 2. 得到文字、box等信息\n",
" data['polys'] = boxes\n",
" data['texts'] = txts\n",
" data['ignore_tags'] = txt_tags\n",
" return data"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"运行下述代码观察DetLabelEncode类解码标签前后的对比。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The label before decode are: [{\"transcription\": \"###\", \"points\": [[427, 293], [469, 293], [468, 315], [425, 314]]}, {\"transcription\": \"###\", \"points\": [[480, 291], [651, 289], [650, 311], [479, 313]]}, {\"transcription\": \"Ave\", \"points\": [[655, 287], [698, 287], [696, 309], [652, 309]]}, {\"transcription\": \"West\", \"points\": [[701, 285], [759, 285], [759, 308], [701, 308]]}, {\"transcription\": \"YOU\", \"points\": [[1044, 531], [1074, 536], [1076, 585], [1046, 579]]}, {\"transcription\": \"CAN\", \"points\": [[1077, 535], [1114, 539], [1117, 595], [1079, 585]]}, {\"transcription\": \"PAY\", \"points\": [[1119, 539], [1160, 543], [1158, 601], [1120, 593]]}, {\"transcription\": \"LESS?\", \"points\": [[1164, 542], [1252, 545], [1253, 624], [1166, 602]]}, {\"transcription\": \"Singapore's\", \"points\": [[1032, 177], [1185, 73], [1191, 143], [1038, 223]]}, {\"transcription\": \"no.1\", \"points\": [[1190, 73], [1270, 19], [1278, 91], [1194, 133]]}]\n",
"\n",
"\n",
"The polygon after decode are: [[[ 427. 293.]\n",
" [ 469. 293.]\n",
" [ 468. 315.]\n",
" [ 425. 314.]]\n",
"\n",
" [[ 480. 291.]\n",
" [ 651. 289.]\n",
" [ 650. 311.]\n",
" [ 479. 313.]]\n",
"\n",
" [[ 655. 287.]\n",
" [ 698. 287.]\n",
" [ 696. 309.]\n",
" [ 652. 309.]]\n",
"\n",
" [[ 701. 285.]\n",
" [ 759. 285.]\n",
" [ 759. 308.]\n",
" [ 701. 308.]]\n",
"\n",
" [[1044. 531.]\n",
" [1074. 536.]\n",
" [1076. 585.]\n",
" [1046. 579.]]\n",
"\n",
" [[1077. 535.]\n",
" [1114. 539.]\n",
" [1117. 595.]\n",
" [1079. 585.]]\n",
"\n",
" [[1119. 539.]\n",
" [1160. 543.]\n",
" [1158. 601.]\n",
" [1120. 593.]]\n",
"\n",
" [[1164. 542.]\n",
" [1252. 545.]\n",
" [1253. 624.]\n",
" [1166. 602.]]\n",
"\n",
" [[1032. 177.]\n",
" [1185. 73.]\n",
" [1191. 143.]\n",
" [1038. 223.]]\n",
"\n",
" [[1190. 73.]\n",
" [1270. 19.]\n",
" [1278. 91.]\n",
" [1194. 133.]]]\n",
"The text after decode are: ['###', '###', 'Ave', 'West', 'YOU', 'CAN', 'PAY', 'LESS?', \"Singapore's\", 'no.1']\n"
]
}
],
"source": [
"# 从PaddleOCR中import DetLabelEncode\n",
"from ppocr.data.imaug.label_ops import DetLabelEncode\n",
"\n",
"# 1. 声明标签解码的类\n",
"decode_label = DetLabelEncode()\n",
"\n",
"# 2. 打印解码前的标签\n",
"print(\"The label before decode are: \", data['label'])\n",
"\n",
"# 3. 标签解码\n",
"data = decode_label(data)\n",
"print(\"\\n\")\n",
"\n",
"# 4. 打印解码后的标签\n",
"print(\"The polygon after decode are: \", data['polys'])\n",
"print(\"The text after decode are: \", data['texts'])\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**基础数据增广**\n",
"\n",
"数据增广是提高模型训练精度,增加模型泛化性的常用方法,文本检测常用的数据增广包括随机水平翻转、随机旋转、随机缩放以及随机裁剪等等。\n",
"\n",
    "随机水平翻转、随机旋转、随机缩放的代码实现参考[代码](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/iaa_augment.py)。随机裁剪的数据增广代码实现参考[代码](https://github.com/PaddlePaddle/PaddleOCR/blob/81ee76ad7f9ff534a0ae5439d2a5259c4263993c/ppocr/data/imaug/random_crop_data.py#L127)。\n",
"\n",
"\n",
"**获取阈值图标签**\n",
"\n",
"使用扩张的方式获取算法训练需要的阈值图标签;\n"
]
},
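  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "参考 DB 论文的做法,扩张(以及后续的收缩)所用的偏移量 $D$ 由标注多边形的面积 $A$、周长 $L$ 与收缩比例 $r$(对应下方代码中的 `shrink_ratio`,默认为 0.4)计算得到:\n",
    "\n",
    "$$D = \\frac{A \\times (1 - r^{2})}{L}$$\n",
    "\n",
    "下方 `MakeBorderMap` 实现中的 `distance` 变量即按该公式计算。"
   ]
  },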
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n",
"import cv2\n",
"\n",
"np.seterr(divide='ignore', invalid='ignore')\n",
"import pyclipper\n",
"from shapely.geometry import Polygon\n",
"import sys\n",
"import warnings\n",
"\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"# 计算文本区域阈值图标签类\n",
"# 详细实现代码参考https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/make_border_map.py\n",
"class MakeBorderMap(object):\n",
" def __init__(self,\n",
" shrink_ratio=0.4,\n",
" thresh_min=0.3,\n",
" thresh_max=0.7,\n",
" **kwargs):\n",
" self.shrink_ratio = shrink_ratio\n",
" self.thresh_min = thresh_min\n",
" self.thresh_max = thresh_max\n",
"\n",
" def __call__(self, data):\n",
"\n",
" img = data['image']\n",
" text_polys = data['polys']\n",
" ignore_tags = data['ignore_tags']\n",
"\n",
" # 1. 生成空模版\n",
" canvas = np.zeros(img.shape[:2], dtype=np.float32)\n",
" mask = np.zeros(img.shape[:2], dtype=np.float32)\n",
"\n",
" for i in range(len(text_polys)):\n",
" if ignore_tags[i]:\n",
" continue\n",
"\n",
" # 2. draw_border_map函数根据解码后的box信息计算阈值图标签\n",
" self.draw_border_map(text_polys[i], canvas, mask=mask)\n",
" canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min\n",
"\n",
" data['threshold_map'] = canvas\n",
" data['threshold_mask'] = mask\n",
" return data\n",
"\n",
" def draw_border_map(self, polygon, canvas, mask):\n",
" polygon = np.array(polygon)\n",
" assert polygon.ndim == 2\n",
" assert polygon.shape[1] == 2\n",
"\n",
" polygon_shape = Polygon(polygon)\n",
" if polygon_shape.area <= 0:\n",
" return\n",
" # 多边形内缩\n",
" distance = polygon_shape.area * (\n",
" 1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length\n",
" subject = [tuple(l) for l in polygon]\n",
" padding = pyclipper.PyclipperOffset()\n",
" padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)\n",
" # 计算mask\n",
" padded_polygon = np.array(padding.Execute(distance)[0])\n",
" cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)\n",
"\n",
" xmin = padded_polygon[:, 0].min()\n",
" xmax = padded_polygon[:, 0].max()\n",
" ymin = padded_polygon[:, 1].min()\n",
" ymax = padded_polygon[:, 1].max()\n",
" width = xmax - xmin + 1\n",
" height = ymax - ymin + 1\n",
"\n",
" polygon[:, 0] = polygon[:, 0] - xmin\n",
" polygon[:, 1] = polygon[:, 1] - ymin\n",
"\n",
" xs = np.broadcast_to(\n",
" np.linspace(\n",
" 0, width - 1, num=width).reshape(1, width), (height, width))\n",
" ys = np.broadcast_to(\n",
" np.linspace(\n",
" 0, height - 1, num=height).reshape(height, 1), (height, width))\n",
"\n",
" distance_map = np.zeros(\n",
" (polygon.shape[0], height, width), dtype=np.float32)\n",
" for i in range(polygon.shape[0]):\n",
" j = (i + 1) % polygon.shape[0]\n",
" # 计算点到线的距离\n",
" absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])\n",
" distance_map[i] = np.clip(absolute_distance / distance, 0, 1)\n",
" distance_map = distance_map.min(axis=0)\n",
"\n",
" xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)\n",
" xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)\n",
" ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)\n",
" ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)\n",
" canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(\n",
" 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,\n",
" xmin_valid - xmin:xmax_valid - xmax + width],\n",
" canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff488039d10>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFdCAYAAAAwgXjMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvcuubcmSpvWZufsYc6619444ESdP3uoiQIgHKG5NQEJCokEXeIBq8QA8SzWgS5dOSfUM1QZUKIWUkJnnUieiduy91pxzDHczo2E+5tqRpKjTyFCegmmhpdjrMsfFh7v5b7/9ZkMigoc97GEPe9jDHvawh/3tmv5dX8DDHvawhz3sYQ972P8X7QGyHvawhz3sYQ972MN+AnuArIc97GEPe9jDHvawn8AeIOthD3vYwx72sIc97CewB8h62MMe9rCHPexhD/sJ7AGyHvawhz3sYQ972MN+AvtJQJaI/Bci8i9E5M9E5L//Kc7xsIc97GEPe9jDHvb7bPK33SdLRArwvwP/OfAXwD8H/puI+F//Vk/0sIc97GEPe9jDHvZ7bD8Fk/UfAn8WEf9HROzA/wT8Vz/BeR72sIc97GEPe9jDfm/tpwBZfwr8X198/xfzZw972MMe9rCHPexh/7+x+nd1YhH5x8A/BtBS/9Hz89eAUEphjIGIoqqYCAL55aA6/00QAhEgokTkH4kABIIdZ0JECI6/FQJHJX9OBBH5e5GJOSNQLUT4j675y8SqAKKCal4fbpgZ7pHXEuAeuDuB3z8VAaJ5byLcYa5bYGaAA4575PVI4OHz2gsiFRB03m9elc9rARVFVUBq3gMg+ByTyON5XoPkKB6jNP9zRIKIyE/EMYb5vcwjvY1H3I8i8wHcvxOZn8gxDXkbSTn+/otBFflrx5U5kMdAhSA/mrJBhM3zBYKCfHFeybExG1jvjDGOQcPc8QjCHQLWdeXd+3eEOzY6Nkb+zp37A83RYVhgTn6e43f+xf3KcZov/h2ICEtbWM8LOufAfX6+nYJ7Bl/e/hcRb4ec5pHHdc/7cHd67znngvu9MZ+pqs7nB4ISHljE2wkl76e2SmsNpMw18nYbcr8+v88NYo75/LzwdjPHXLjP/bdFOv8mfvTvYz7OBfnF/y3vKRx8/j+OeRVfzCO9z+RgToFj/svb+VQkl568ne64XvNg2J6jpErRer/xfA5vk/YYHwEQnfcWc8378S1fLnbVQq11jhnp2Mh7uD/7iPt43VeF3P9s/m6u7ekjEXk71RezRUWQ+YH7euTtGRzHF5kPORxVQZn+7Xhec/xE5D6X3n43n5VKfglQC5Qvfv//mMFvz+ztmr78GwEWcpsqf8Pn/Iuv8cW/53jNdYkb2FzHIWDyNp7H3305uT0gHA5fDrjnejmuy6f/4IvZDrlO86jHPpDnzLmreAQec0bORX+s7cOLxZy9Ms8V4Xc/k3uSzp/nsQ6/nHPjWEPyo7mUo5XnQvRH8+BYzxB5T188gnwib/ce8/dvPloQKZjPa5TIuRbHCuRHY+TTHWiUt+c952/eWwCWn5TDb8TdxwQjx0s0v9C51/o8voIUEEUlpsc+9mC/H9djzgdVzm0h5h4KgZRjDUOg/HjHe/NXP3z36bcR8Qd/fVb+dfspQNZfAn//i+//3vzZjywi/gnwTwA+fPg2/qP/+L+cm1SjlHRAFsLehKaNRQpLqZyqUOhIOK/2idP6DmFBdKGUAgXMN1ps8+FUdHmCUHZzaq308ZmlCpXAh1GXFaOALhQaAMuyMMYO4gk4IjCODSy/b61RSkFj5+cnoffB5XVn3wdmzuV15/X1lT6uDIdhkuCpNGpdcIz1K+W0PuFDefnhE8Fg9Cv7viO+ELpjtmNmtPpEK1/jCKKdIhAMzDe0GHVRnk5nPrz/A87vvoVyQqRQ2CjsIJ1yKry+dkQqy+kJcEIDcaGEUvxC0UGE4bKCVFyUKhAhEw
hCqCCUHB9xLITejd47H95/xbIsoMJ+2+beI3Qb97HDjad2RlXnoslpfIwvLdBqRHQiCj4KboUW70kn4yAd8ysQaHFUznlNKGbG6VwZdqWPC+Ple374+JFt3xnDcYe+G5fLhZfLxrfffsu//x/8I9YW2Hbj83ff47cberuybzfwnR5AO3PphV99d+H1OvAYeGwgnVqVUgpFFry+ASipSq2VWiv/6X/yn/H+mydqVdyhaEvQK8K+d1wLIenEzYxSckPe952TQpjf94qQgjtcbzds3Li+vvLx40c+//DCx48f2a473gclOqUGz0+NWvIZFir75nx6uSSOFEcKdDf+7X/v3+Ef/MN/iLEipaLTmRaB8MEYA48btZywURkdRILSdpCB+iDoqCr7GHhkcGCulGgTXDIBcp+bjYMYFuM+H8K4b+gRnTBj7J0Yxn69wsjAJnTDhuMjAyY1mUGSUKwiDJDOUgaEUUVprXEuhkbOv+GBK+wIl/3Gd59/iZtyWt/z/t2392sefUPHuM/XpSoyg7dSClEWIjpIB90opeDu1LIgpaBS+cXP/5jT6YxqUKqAdET3nP89CIPRnTBDKUg4ojlGpRTMgu3mMOCpCWggVTAGoaBVpl9aWEqlhKMCTaAcgaTUCQAtgzOM2jQxkW48t8ZJK+elUVvJv2kgZWdZFtZaEIV1UdwdDRBd4d0JTso4N+q3H+BphaJ4KSiau6zoBJYTEknN30XN30WFSMDq8sdofAM85eaJwvTRHv8SuKHccL4HLij7EZVAdBg39u9+Ay83lqhglbgWzAaiRmkBjHk9ORfYd9h2ZB+MHoQXPn1S+p7rTmhctp1t74DyGtDDiWEzkEtgO8ag904MYfTC3hvXLmwebDi7Oz6MsIGEYaIMCrtDc6eiOJLHEMcYmF8RWzHynH2kX5YQYh90CSzyGrfuoMK2bZQCN+sZkKGU0hAaIMRQbNxwPD+vGbApQhEh+g4elPC5VpxRHNNARrC0D1x2YeiG6I4GxGjoPvLzpREomzm7KwGcikIoIhWViseg71dgUBfDcIzp5wbYHiy1URvcto57YoWmzwiN3ge7fcLqwigNE2EpzoqxqFHcMetYOLoUZFG0Fs7nMz/76hssOmVxQp3leQXN/arHoDZl33fO68qw2zFr+Z//x3/2578LIPopQNY/B/5dEfm3SHD1XwP/7f/bB0RA/AqhtOpUMcZw1lKQvVPLkpuUVooISG5uz8vAxkdqPSN6pkibqD6jGhGdUb5QSqNEIua1NpoGRZ3hg6KG4lArYemgSg2GzehIHOgoA9FgDIMIfAyUChrc7AKitCfFxMGc6p0Wg9LBzNl6OvPwQbcbjhFXEIyvnr8megMan364sUghQglTqhcQoUZQ/AoqWN1y0piBGOHBdh3E2LCbodJoy2Bd17nJ5AZDVFpZsAi8Dyi5cUbkpMVmJCck+BIm05JRhkrgPlhcENkxlI6iBEtRFjnhtxt960gRQtMzBaB+RKOa4C0G4T+iEjI6c6e44ntGIBFGeCUsGHyex8hN2X0gmqxGKQUbGf3UVrndLrRWKbrAunB6OuE44RsqpINdhdt18O79mdvtBXNh3G5QYDk1bLwgNnAbYIGZI9I4LTu2C+aOWeDiKDmPRIxiS96nvjE9v/iTP6KsyuVyyQ05gqIJ1EXyeyUyolLFNUGnhLG2QgnLeNOFOCgSFdalouczblfWtXKpzrpWwgauio9CKUKo4zha8jxaB20djOEMM4IEgx/ef835/Mxty+huXU8o+VxszGssz5SyEq6YTZCkCcDVJR1oCUoBp4A0QJGQSRgkozpGghCRgrjmZivJClKdwBBxhpdkT1rOy6WeMOuEDcbudOlsvuG9Z8zqk3jRZ4oYRY2IkQxWmfuw5Hh4gEUkeZFDTykNt7jPRwDRg+1pk7SZzEI4yMEUOBEDkUCi4CZAo7WF53PldHri6VQR8eQrRq5fKYGEUjzZiRKkr4gBGvhkuSMs55sLuG
Be72xl6GRf1CEqIg3DkkcQGJNMk2CCNu6MWXCw11CkThB5MACSz0QrtSQLm4FTXjsypiOf2QNpiCpozC9BhWSU3NL
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFdCAYAAAAwgXjMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzs3VmsZdl93/fv2uOZ53PuXFPX1Owustkk1aRMOTQpKhYlR3lIZDmBIhkK+CI9BPGDhbwYAfKgBEhiIUYUEJYhMjBCCbINyrIkSk1SJmlKFEWym9VzVXWNdzz3zNOeVx72qVvVzSb7VtW9NXT9P0Dh3rvvuXuvcx7Yf671W/+ltNYIIYQQQoiDZTzoAQghhBBCvBdJkSWEEEIIcQikyBJCCCGEOARSZAkhhBBCHAIpsoQQQgghDoEUWUIIIYQQh+BQiiyl1N9XSr2ulLqolPrNw3iGEEIIIcTDTB10nyyllAm8AXwauAF8B/hHWutXDvRBQgghhBAPscOYyfoJ4KLW+k2tdQB8EfiFQ3iOEEIIIcRD6zCKrBXg+m0/35hfE0IIIYR4bFgP6sFKqc8CnwUwMT+Uo/SghiKEEEIIsW8jerta6+a7ve4wiqx1YO22n1fn195Ca/054HMAJVXTz6lPHcJQhBBCCCEO1vP6D6/u53WHUWR9BzillDpOWlz9EvDfHMJzhBBCCCEOjZHPozIuAMlwjA6DO/r7Ay+ytNaRUuo3gC8DJvCvtNYvH/RzhBBCCCEOg7W2SnCsiV+yiN00vm4PI5y+j7ndh2v7vM9hDE5r/SfAnxzGvYUQQgghDoNZKcNSC3+hyPCoS1BSxOlEFs7QINO3yeXtB1tkCSGEEEI8SsxSieTkGt5ClmnDYryqCCqaJJsAEPQNorxBYrr7vqcUWUIIIYR4rFmrK8SLVUbH88zqBl5dMVuNsKse5bwHQDdfJLEcVLz/7ldSZAkhhBDisWQdP0q4WGG04DKtm0xWFX4tQdU9Ti+1OVPe5gP5tPXnF3Mf4ZJq4fkykyWEEEII8Y6U7WCuLjF7osGsaePVDbw6eCsh2dqMleqAn2pe5MO5yzzrdmnHBl92n8J0YhJ7/8cRSpElhBBCiMeG2aijF5t4C3lGRxy8hsKvacJqRGN5wBPVXZ4ubvDzxRd5xnX5QWBSVBEZM8QwExJz/8+SIksIIYQQ73lGPo9aWWT6RI1Zw8KrKyYrmrgWUKxPOF3p88nG6zybvcKH3SkvBg7/V+8o78vcYM0ZUbACHCdi6shMlhBCCCEEANbSIvFiHW8hx2gtLbD8mkatzFiuDXmivMszxRv8fPE8Ry2H7/s2fzr8ANdmVY46bXKGR9YIsM0YbUmRJYQQQgiBtbhA8MQiXstl2jCYrEJQizFqPqeWdniqvMkz+Ws8614nozR/PKnz/elRvrnzBK3ciGni4iqbrBniWDHalCJLCCGEEI8xI5dDLS8QLJYZHcswqxv4NfCXAwr1KWuVPn+veXN5cMwgSfj34yf5wXiV1/sLXN+uUjriMUlcfB2SMwIsIwEpsoQQQgjxWDJMzLNPECwUmDWd+eyVJqxFOFWPDy5tcq68wQdzV/m53ICLoc+Xxkd4cXKEb+0cp90rknRcdC5mEjqMkgyDJKBsTcnbATjJvociRZYQQggh3hNuHoszPV5m2ryVvYqWA6q1MccqXT7VeI0PZa7wIRe+7jl8b3aa86MVLvSbbK1Xsbo22Z5ieizBjyymscsk0WRUiGtFGJYUWUIIIYR4jFirK8RLNbxWhuERC6+m8OsJuhZyZLnDqXKbZ4rX+AeFVykbJv/Jy/Nng/fz8mCJ6/0K424Od8PG7SjsqWa6pogSg2niMNUmecOft3GQ5UIhhBBCPAbMZhO9VGd0osS0YeLVFdOVGFULaNZGnCh3+Jn6yzybucZTtsN/mLb4/vQYL4+WOL+5jNfJYvdMCl1F8XpCth0wWn
VQocEssBlGGUaJQ8WcUrR8LDve99ikyBJCCCHEI0e5LubyIuFSBa/lMlyz0mB7PSa7Mma1MuBMeZtnC1f5+fxlQq35s1mFP++f49XBAhu9MsF6nkzXINOF7G5C8coUFWuMyMaIIIoNZrHDRDvkDB/XjLAsKbKEEEII8R6lXBd15jizVp5Z02bWMJisaKJaRLY249ziJu8vrvNs7grPul1uRBZ/4x3nB+M1vr19lG4vj+665DcMMruabCcms+NjXlxHry5gxKAiRRSZ+InJJHFZsfp7vbL2S4osIYQQQjwSjGIRtdQiXCoxPJrBqyu8miaoR5SWRxyt9niytMXPlV/gWcdjO474y9kKX+6e48KgyVa3RLKeJdM1cLua0rWQzPYUc6tHtL5BDJjNGmagMUJFHJmMQ5d+nOOs06ZkeWSdcN/jlSJLCCGEEA89o1hEnz7CrJVj1rAYr6TB9rgWUaxNeHbxBucK63w49yYfdgK+Hzh8b3aGF0Zr/O3WGsNeDrNrk9swyO5qsrsRuUs92NwhGg5vPSiKMUKNESnCSOHFNtPEJac0OdPHNmR3oRBCCCHeI24eizM+VmDWMPBqitlqjFn1Wa4NOV1p859XX+LZzA2OWg5/PGnw3ckxXhkucbVfZbhZxOmYuF1F8XpMth3ibA6JX7/4Q89SYYSKNCoGHRn4sYWnbfLKIGcEuGa0/3Ef5IcghBBCCHGQzFMnCBbLzFoOozUTv6YJajH525YHn8lf42fzG3Tj9FicP+uly4Mb3RJhL0P2hkWmq8nuxhQvT/aWB99RGGGECSoCIrXXKytn2Hu9svZLiiwhhBBCPFwME2ttmWihQv94Pj0Spw6zlQi3NuNodcjfbV3kQ7nLPOvukFcGXxof4XuTY7w2WODiZouk6+B2THJd0tmrHR97s0/05hV+XJmkPQ8zTDBCUKGBH1kM4iwWJhVzSsHy9/02pMgSQgghxANl1muoXA5sC5IEncvgLxSYthxGawZ+TRPWYqrLA05UOzxV2uTnSi/wQcfgjdDkq94af959mkuDOu1uCXUjQ66nyHQ02d2EwptD1FaHaHvnXceigxAVJhgRECuCyGQW28x0QN7wyZoSfBdCCCHEQ07ZDuZii2ilRpSzSRwDlWhix2B2s7HocoKuB1SrEz6yeI33F27wkeybnLJCvu4V+d7sNC8O13hha4VJN4vVtcltKLK7CdndiMz2FC5cJZ5O9zeoMMSYF1m398qa6nivV9Z+SZElhBBCiPvKyOcxFppErRLjRobJkkmYVyQOkIC2wK9qolrIwkqPk5VdzhXX+fniDzhru3w3gH/e/QgvD5e43K/R6xaw1l0K3XT2qnTVx90awWabuNe7o7Elnofhh+kOw1ARBhaTyKGfQEn5FC1v3/eSIksIIYQQ942Ry8HxNbxWHq9u4VUNpkuKsKBJMgkqVmhDY9QDFqojPrpwhWfy1/hI5iplI+Y/TAt8f3qMb7afYL1bxu+mx+Lk1+ezV+2QzKUd4vVNdLT/WafbqTBGRaAiiBMDL7aYJhY5IyJnBPu+jxRZQgghhLgvzIUWulljdrTIrGbhVxV+BbzFCKsckM8GxLGBUprVyoAnK1t8pvwiz7p9Qq3595PTvDA+wuuDFpc3GqiuQ7Zj4HahdC0i0/awtvpEV6/f20CjeL5ceKtXlqctGiokZ0rwXQghhBAPEev4UaJmCb+RYbw8L7CqmrASU1gcs1waspwfMIkcLJXwdHGDZ3NX+JlcyHd9m+96x3i+8ySX+3U63QLWhovbmYfbOzG5y0PU5g5Rp3vPY1VhhBHpdFZt3itroh3yypeZLCGEEEI8BJTCWl4iaVYYHS3gVU38qmK6qIkqIW7VY6085iONq7wvt8G5zHXeDFos2z0+7ARsxgH/vPc050ervN5vsb5Rw+xaZLsGhRs6XRrcmf74vld3w/MxQzBCIDTwIot+nKdkhBSN2b5vI0WWEEIIIQ6eUpgnjxM2i/h1l/FSWm
AFFU286FOpTFkpDzhZbPOflV7jnLPDgunQjmdUDI+rUcKXJ0/z1d2zXO1VGfZyOBs2mc7NcLuHvTVCb2wTjUYHOnQ
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 从PaddleOCR中import MakeBorderMap\n",
"from ppocr.data.imaug.make_border_map import MakeBorderMap\n",
"\n",
"# 1. 声明MakeBorderMap函数\n",
"generate_text_border = MakeBorderMap()\n",
"\n",
"# 2. 根据解码后的输入数据计算bordermap信息\n",
"data = generate_text_border(data)\n",
"\n",
"# 3. 阈值图可视化\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(src_img)\n",
"\n",
"text_border_map = data['threshold_map']\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(text_border_map)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**获取概率图标签**\n",
"\n",
"使用收缩的方式获取算法训练需要的概率图标签;\n"
]
},
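{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"概率图标签的内缩距离与阈值图标签的扩张距离均按公式 $D = A \\times (1 - r^2) / L$ 计算,其中 $A$ 为多边形面积,$L$ 为周长,$r$ 为收缩比例 shrink_ratio。下面以一个假设的 100x40 矩形文本框为例演示该距离的计算(示意代码,非 PaddleOCR 源码):\n",
"\n",
"```python\n",
"# 示意代码:按 D = A * (1 - r^2) / L 计算内缩距离(矩形尺寸为假设示例)\n",
"shrink_ratio = 0.4\n",
"w, h = 100, 40          # 假设的矩形文本框宽高\n",
"area = w * h            # 多边形面积 A\n",
"length = 2 * (w + h)    # 多边形周长 L\n",
"distance = area * (1 - shrink_ratio ** 2) / length\n",
"print(distance)  # 12.0,即每条边向内收缩约 12 个像素\n",
"```\n"
]
},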
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\n",
"import numpy as np\n",
"import cv2\n",
"from shapely.geometry import Polygon\n",
"import pyclipper\n",
"\n",
"# 计算概率图标签\n",
"# 详细代码实现参考: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/imaug/make_shrink_map.py\n",
"class MakeShrinkMap(object):\n",
" r'''\n",
" Making binary mask from detection data with ICDAR format.\n",
" Typically following the process of class `MakeICDARData`.\n",
" '''\n",
"\n",
" def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):\n",
" self.min_text_size = min_text_size\n",
" self.shrink_ratio = shrink_ratio\n",
"\n",
" def __call__(self, data):\n",
" image = data['image']\n",
" text_polys = data['polys']\n",
" ignore_tags = data['ignore_tags']\n",
"\n",
" h, w = image.shape[:2]\n",
" # 1. 校验文本检测标签\n",
" text_polys, ignore_tags = self.validate_polygons(text_polys,\n",
" ignore_tags, h, w)\n",
" gt = np.zeros((h, w), dtype=np.float32)\n",
" mask = np.ones((h, w), dtype=np.float32)\n",
"\n",
" # 2. 根据文本检测框计算文本区域概率图\n",
" for i in range(len(text_polys)):\n",
" polygon = text_polys[i]\n",
" height = max(polygon[:, 1]) - min(polygon[:, 1])\n",
" width = max(polygon[:, 0]) - min(polygon[:, 0])\n",
" if ignore_tags[i] or min(height, width) < self.min_text_size:\n",
" cv2.fillPoly(mask,\n",
" polygon.astype(np.int32)[np.newaxis, :, :], 0)\n",
" ignore_tags[i] = True\n",
" else:\n",
" # 多边形内缩\n",
" polygon_shape = Polygon(polygon)\n",
" subject = [tuple(l) for l in polygon]\n",
" padding = pyclipper.PyclipperOffset()\n",
" padding.AddPath(subject, pyclipper.JT_ROUND,\n",
" pyclipper.ET_CLOSEDPOLYGON)\n",
" shrinked = []\n",
"\n",
"                # 若一次内缩得到多个多边形,则逐步增大收缩比例重试\n",
"                possible_ratios = np.arange(self.shrink_ratio, 1,\n",
"                                            self.shrink_ratio)\n",
"                # 注意np.append不是原地操作需要接收返回值\n",
"                possible_ratios = np.append(possible_ratios, 1)\n",
"                for ratio in possible_ratios:\n",
" distance = polygon_shape.area * (\n",
" 1 - np.power(ratio, 2)) / polygon_shape.length\n",
" shrinked = padding.Execute(-distance)\n",
" if len(shrinked) == 1:\n",
" break\n",
"\n",
" if shrinked == []:\n",
" cv2.fillPoly(mask,\n",
" polygon.astype(np.int32)[np.newaxis, :, :], 0)\n",
" ignore_tags[i] = True\n",
" continue\n",
" # 填充\n",
" for each_shrink in shrinked:\n",
" shrink = np.array(each_shrink).reshape(-1, 2)\n",
" cv2.fillPoly(gt, [shrink.astype(np.int32)], 1)\n",
"\n",
" data['shrink_map'] = gt\n",
" data['shrink_mask'] = mask\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff43450c8d0>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFdCAYAAAAwgXjMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvcuubcmSpvWZufsYc6619444ESdP3uoiQIgHKG5NQEJCokEXeIBq8QA8SzWgS5dOSfUM1QZUKIWUkJnnUieiduy91pxzDHczo2E+5tqRpKjTyFCegmmhpdjrMsfFh7v5b7/9ZkMigoc97GEPe9jDHvawh/3tmv5dX8DDHvawhz3sYQ972P8X7QGyHvawhz3sYQ972MN+AnuArIc97GEPe9jDHvawn8AeIOthD3vYwx72sIc97CewB8h62MMe9rCHPexhD/sJ7AGyHvawhz3sYQ972MN+AvtJQJaI/Bci8i9E5M9E5L//Kc7xsIc97GEPe9jDHvb7bPK33SdLRArwvwP/OfAXwD8H/puI+F//Vk/0sIc97GEPe9jDHvZ7bD8Fk/UfAn8WEf9HROzA/wT8Vz/BeR72sIc97GEPe9jDfm/tpwBZfwr8X198/xfzZw972MMe9rCHPexh/7+x+nd1YhH5x8A/BtBS/9Hz89eAUEphjIGIoqqYCAL55aA6/00QAhEgokTkH4kABIIdZ0JECI6/FQJHJX9OBBH5e5GJOSNQLUT4j675y8SqAKKCal4fbpgZ7pHXEuAeuDuB3z8VAaJ5byLcYa5bYGaAA4575PVI4OHz2gsiFRB03m9elc9rARVFVUBq3gMg+ByTyON5XoPkKB6jNP9zRIKIyE/EMYb5vcwjvY1H3I8i8wHcvxOZn8gxDXkbSTn+/otBFflrx5U5kMdAhSA/mrJBhM3zBYKCfHFeybExG1jvjDGOQcPc8QjCHQLWdeXd+3eEOzY6Nkb+zp37A83RYVhgTn6e43f+xf3KcZov/h2ICEtbWM8LOufAfX6+nYJ7Bl/e/hcRb4ec5pHHdc/7cHd67znngvu9MZ+pqs7nB4ISHljE2wkl76e2SmsNpMw18nYbcr8+v88NYo75/LzwdjPHXLjP/bdFOv8mfvTvYz7OBfnF/y3vKRx8/j+OeRVfzCO9z+RgToFj/svb+VQkl568ne64XvNg2J6jpErRer/xfA5vk/YYHwEQnfcWc8378S1fLnbVQq11jhnp2Mh7uD/7iPt43VeF3P9s/m6u7ekjEXk71RezRUWQ+YH7euTtGRzHF5kPORxVQZn+7Xhec/xE5D6X3n43n5VKfglQC5Qvfv//mMFvz+ztmr78GwEWcpsqf8Pn/Iuv8cW/53jNdYkb2FzHIWDyNp7H3305uT0gHA5fDrjnejmuy6f/4IvZDrlO86jHPpDnzLmreAQec0bORX+s7cOLxZy9Ms8V4Xc/k3uSzp/nsQ6/nHPjWEPyo7mUo5XnQvRH8+BYzxB5T188gnwib/ce8/dvPloQKZjPa5TIuRbHCuRHY+TTHWiUt+c952/eWwCWn5TDb8TdxwQjx0s0v9C51/o8voIUEEUlpsc+9mC/H9djzgdVzm0h5h4KgZRjDUOg/HjHe/NXP3z36bcR8Qd/fVb+dfspQNZfAn//i+//3vzZjywi/gnwTwA+fPg2/qP/+L+cm1SjlHRAFsLehKaNRQpLqZyqUOhIOK/2idP6DmFBdKGUAgXMN1ps8+FUdHmCUHZzaq308ZmlCpXAh1GXFaOALhQaAMuyMMYO4gk4IjCODSy/b61RSkFj5+cnoffB5XVn3wdmzuV15/X1lT6uDIdhkuCpNGpdcIz1K+W0PuFDefnhE8Fg9Cv7viO+ELpjtmNmtPpEK1/jCKKdIhAMzDe0GHVRnk5nPrz/A87vvoVyQqRQ2CjsIJ1yKry+dkQqy+kJcEIDcaGEUvxC0UGE4bKCVFyUKhAhEw
hCqCCUHB9xLITejd47H95/xbIsoMJ+2+beI3Qb97HDjad2RlXnoslpfIwvLdBqRHQiCj4KboUW70kn4yAd8ysQaHFUznlNKGbG6VwZdqWPC+Ple374+JFt3xnDcYe+G5fLhZfLxrfffsu//x/8I9YW2Hbj83ff47cberuybzfwnR5AO3PphV99d+H1OvAYeGwgnVqVUgpFFry+ASipSq2VWiv/6X/yn/H+mydqVdyhaEvQK8K+d1wLIenEzYxSckPe952TQpjf94qQgjtcbzds3Li+vvLx40c+//DCx48f2a473gclOqUGz0+NWvIZFir75nx6uSSOFEcKdDf+7X/v3+Ef/MN/iLEipaLTmRaB8MEYA48btZywURkdRILSdpCB+iDoqCr7GHhkcGCulGgTXDIBcp+bjYMYFuM+H8K4b+gRnTBj7J0Yxn69wsjAJnTDhuMjAyY1mUGSUKwiDJDOUgaEUUVprXEuhkbOv+GBK+wIl/3Gd59/iZtyWt/z/t2392sefUPHuM/XpSoyg7dSClEWIjpIB90opeDu1LIgpaBS+cXP/5jT6YxqUKqAdET3nP89CIPRnTBDKUg4ojlGpRTMgu3mMOCpCWggVTAGoaBVpl9aWEqlhKMCTaAcgaTUCQAtgzOM2jQxkW48t8ZJK+elUVvJv2kgZWdZFtZaEIV1UdwdDRBd4d0JTso4N+q3H+BphaJ4KSiau6zoBJYTEknN30XN30WFSMDq8sdofAM85eaJwvTRHv8SuKHccL4HLij7EZVAdBg39u9+Ay83lqhglbgWzAaiRmkBjHk9ORfYd9h2ZB+MHoQXPn1S+p7rTmhctp1t74DyGtDDiWEzkEtgO8ag904MYfTC3hvXLmwebDi7Oz6MsIGEYaIMCrtDc6eiOJLHEMcYmF8RWzHynH2kX5YQYh90CSzyGrfuoMK2bZQCN+sZkKGU0hAaIMRQbNxwPD+vGbApQhEh+g4elPC5VpxRHNNARrC0D1x2YeiG6I4GxGjoPvLzpREomzm7KwGcikIoIhWViseg71dgUBfDcIzp5wbYHiy1URvcto57YoWmzwiN3ge7fcLqwigNE2EpzoqxqFHcMetYOLoUZFG0Fs7nMz/76hssOmVxQp3leQXN/arHoDZl33fO68qw2zFr+Z//x3/2578LIPopQNY/B/5dEfm3SHD1XwP/7f/bB0RA/AqhtOpUMcZw1lKQvVPLkpuUVooISG5uz8vAxkdqPSN6pkibqD6jGhGdUb5QSqNEIua1NpoGRZ3hg6KG4lArYemgSg2GzehIHOgoA9FgDIMIfAyUChrc7AKitCfFxMGc6p0Wg9LBzNl6OvPwQbcbjhFXEIyvnr8megMan364sUghQglTqhcQoUZQ/AoqWN1y0piBGOHBdh3E2LCbodJoy2Bd17nJ5AZDVFpZsAi8Dyi5cUbkpMVmJCck+BIm05JRhkrgPlhcENkxlI6iBEtRFjnhtxt960gRQtMzBaB+RKOa4C0G4T+iEjI6c6e44ntGIBFGeCUsGHyex8hN2X0gmqxGKQUbGf3UVrndLrRWKbrAunB6OuE44RsqpINdhdt18O79mdvtBXNh3G5QYDk1bLwgNnAbYIGZI9I4LTu2C+aOWeDiKDmPRIxiS96nvjE9v/iTP6KsyuVyyQ05gqIJ1EXyeyUyolLFNUGnhLG2QgnLeNOFOCgSFdalouczblfWtXKpzrpWwgauio9CKUKo4zha8jxaB20djOEMM4IEgx/ef835/Mxty+huXU8o+VxszGssz5SyEq6YTZCkCcDVJR1oCUoBp4A0QJGQSRgkozpGghCRgrjmZivJClKdwBBxhpdkT1rOy6WeMOuEDcbudOlsvuG9Z8zqk3jRZ4oYRY2IkQxWmfuw5Hh4gEUkeZFDTykNt7jPRwDRg+1pk7SZzEI4yMEUOBEDkUCi4CZAo7WF53PldHri6VQR8eQrRq5fKYGEUjzZiRKkr4gBGvhkuSMs55sLuG
Be72xl6GRf1CEqIg3DkkcQGJNMk2CCNu6MWXCw11CkThB5MACSz0QrtSQLm4FTXjsypiOf2QNpiCpozC9BhWSU3NL
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFdCAYAAAAwgXjMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGBZJREFUeJzt3W/MpWddJ/DvbzultbilLbqTdqbZ1jDBsGYd6gRqNIalq6WVMLxArCEy1iazL1gXVxMp+oJssi8wuxEh2a2ZULU1LH+ssJ2QSi0FYky2XVpsKoK1A1I7038KpaBEKOtvX5z7gdOhZZ5n5lzznPM8n09y59z3dV/nnOtcuTvz7XVd9z3V3QEAYLH+xWY3AABgKxKyAAAGELIAAAYQsgAABhCyAAAGELIAAAYYErKq6lVV9UBVHamq60d8BwDAMqtFPyerqs5I8tdJfjLJ0SSfTPJz3f2ZhX4RAMASGzGS9bIkR7r78939jSTvS7J/wPcAACytESFrV5KH546PTmUAANvGjs364qo6mORgkpyRM37knJy7WU0BAFi3r+bJv+/u7z9RvREh61iSi+eOd09lz9Ddh5IcSpJz64J+eV0xoCkAAIv10b7lofXUGzFd+Mkke6rq0qp6XpJrkhwe8D0AAEtr4SNZ3f3NqvqPSW5PckaS3+3uv1z09wAALLMha7K6+7Ykt434bACAVeCJ7wAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQAAAwhZAAADCFkAAAMIWQDAtnf7I/fl9kfuW+hnClkAwLa26HC1ZseQTwUAWHKjwtUaI1kAwLYzOmAlRrIAgG3kdISrNUayAIBt4XQGrETIAgC2uI3cObjIICZkAQBb1ukevZpnTRYAsOVsZrhaYyQLANhSliFgJUIWALCFLEvASkwXAgBbwDKFqzVGsgCAlbaMASsxkgUArKhlDVdrjGQBACtlI8+9OtnPXwQhCwBgACELAGAAa7IAACZXXrR3YZ8lZAEA294iw9UaIQsA2LZGhKs1J1yTVVW/W1VPVNWn58ouqKo7qurB6fX8qbyq6l1VdaSq7q+qy4a1HADgJF150d6hAStZ38L330/yquPKrk9yZ3fvSXLndJwkVyXZM20Hk9ywmGYCAJy60xGu1pxwurC7/7SqLjmueH+SV0z7NyX5RJK3TOU3d3cnuauqzquqC7v70UU1GABgo05XsJp3smuyds4Fp8eS7Jz2dyV5eK7e0alMyAIATrvNCFdrTnnhe3d3VfVG31dVBzObUszZOedUmwEAkGRzg9W8k30Y6eNVdWGSTK9PTOXHklw8V2/3VPYduvtQd+/r7n1n5qyTbAYAsEoW8U/WfLcQtSwBKzn5kHU4yYFp/0CSW+fK3zjdZXh5kqesxwIAkrH/oPPpXNC+XiecLqyq92a2yP37qupokrcleXuSD1TVdUkeSvL6qfptSa5OciTJ15JcO6DNAMCKGRWwli1YzavZjYCb69y6oF9eV2x2MwCABXuucLXM4ehEPtq33Nvd+05Uzz8QDQAMMXJ6cBX4Z3UAgIXa7uFqjZAFACyEcPVMpgsBgFMmYH0nIQsAOCUC1rMTsgCAkyZgPTdrsgCADROuTsxIFgCwbrc/ct9CAtZ2CGlCFgCwLtshGC2S6UIA4LsaEa5W+Ynv62UkCwB4TgLWyTOSBQA8q0UHrO0SrtYIWQDAMywyXG23YDVPyAIAkhi5WjQhCwAwejWAkAUA29yiApZw9UxCFgBsU8LVWB7hAADbkIA1npEsANhGhKvTR8gCgG3iVAOWYLUxQhYAbHHC1eawJgsAtjABa/MYyQKALepUApZwdeqELADYYoSr5SBkAcAWcrIBS7haPCELALYA4Wr5CFkAsMKEq+UlZAHAijqZgCVcnT4e4QAAK0jAWn5GsgBghW
w0XAlWm0fIAoAVYORq9QhZALDkjF6tJiELALYI4Wq5CFkAsOKEq+V0wrsLq+riqvp4VX2mqv6yqt48lV9QVXdU1YPT6/lTeVXVu6rqSFXdX1WXjf4RALAdXXnRXgFria3nEQ7fTPKr3f2SJJcneVNVvSTJ9Unu7O49Se6cjpPkqiR7pu1gkhsW3moA2OaEq+V3wunC7n40yaPT/ler6rNJdiXZn+QVU7WbknwiyVum8pu7u5PcVVXnVdWF0+cAABt05UV7c/sj9wlWK2ZDDyOtqkuSvDTJ3Ul2zgWnx5LsnPZ3JXl47m1HpzIA4CQJWKtn3SGrqr43yR8l+eXu/sr8uWnUqjfyxVV1sKruqap7ns7XN/JWAIClt66QVVVnZhaw3tPdH5yKH6+qC6fzFyZ5Yio/luTiubfvnsqeobsPdfe+7t53Zs462fYDACyl9dxdWEluTPLZ7v6tuVOHkxyY9g8kuXWu/I3TXYaXJ3nKeiwAYLtZz3OyfizJzyf5i6pae+Tsryd5e5IPVNV1SR5K8vrp3G1Jrk5yJMnXkly70BYDAKyA9dxd+GdJ6jlOX/Es9TvJm06xXQAAK21DdxcCALA+QhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwABCFgDAAEIWAMAAQhYAwAA7TlShqs5O8qdJzprq39Ldb6uqS5O8L8kLk9yb5Oe7+xtVdVaSm5P8SJIvJvnZ7v7CoPYDLJ3bH7lvYZ915UV7F/ZZwOl1wpCV5OtJXtnd/1BVZyb5s6r64yS/kuQd3f2+qvqdJNcluWF6fbK7X1RV1yT5zSQ/O6j9AKfkZAOR8AOcyAmnC3vmH6bDM6etk7wyyS1T+U1JXjvt75+OM52/oqpqYS0GAFgB61qTVVVnVNV9SZ5IckeSzyX5cnd/c6pyNMmuaX9XkoeTZDr/VGZTigBLxSgWMNK6QlZ3/7/u3ptkd5KXJfnBU/3iqjpYVfdU1T1P5+un+nEAAEtlPWuyvqW7v1xVH0/yo0nOq6od02jV7iTHpmrHklyc5GhV7UjygswWwB//WYeSHEqSc+uCPvmfALBcjHQByfruLvz+JE9PAet7kvxkZovZP57kdZndYXggya3TWw5Px/9nOv+x7haigKUjDAEjrWck68IkN1XVGZlNL36guz9cVZ9J8r6q+q9J/jzJjVP9G5P8QVUdSfKlJNcMaDcAwFI7Ycjq7vuTvPRZyj+f2fqs48v/KcnPLKR1AAAryhPfAQAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAYQsgAABhCyAAAGELIAAAZYd8iqqjOq6s+r6sPT8aVVdXdVHamq91fV86bys6bjI9P5S8Y0HQBgeW1kJOvNST47d/ybSd7R3S9K8mSS66by65I8OZW/Y6oHALCtrCtkVdXuJD+d5N3TcSV5ZZJbpio3JXnttL9/Os50/oqpPgDAtrHekazfTvJrSf55On5hki939zen46NJdk37u5I8nCTT+aem+gAA28YJQ1ZVvTrJE9197yK/uKoOVtU9VXXP0/n6Ij8aAGDT7VhHnR9L8pqqujrJ2UnOTfLOJOdV1Y5ptGp3kmNT/WNJLk5ytKp2JHlBki8e/6HdfSjJoSQ5ty7oU/0hAADL5IQjWd391u7e3d2XJLkmyc
e6+w1JPp7kdVO1A0lunfYPT8eZzn+su4UoAGBbOZXnZL0lya9U1ZHM1lzdOJXfmOSFU/mvJLn+1JoIALB61jNd+C3
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 从 PaddleOCR 中 import MakeShrinkMap\n",
"from ppocr.data.imaug.make_shrink_map import MakeShrinkMap\n",
"\n",
"# 1. 声明文本概率图标签生成\n",
"generate_shrink_map = MakeShrinkMap()\n",
"\n",
"# 2. 根据解码后的标签计算文本区域概率图\n",
"data = generate_shrink_map(data)\n",
"\n",
"# 3. 文本区域概率图可视化\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(src_img)\n",
"text_border_map = data['shrink_map']\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(text_border_map)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**归一化**\n",
"\n",
"通过规范化手段把神经网络每层中任意神经元的输入值分布变换为均值为0、方差为1的标准正态分布使最优解的寻优过程更加平缓训练过程更容易收敛\n"
]
},
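{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"下面给出一个最小示意(像素值为假设的示例数据,均值与方差沿用上面的 ImageNet 常用值),验证归一化公式 `(x * scale - mean) / std`当输入像素接近各通道均值时归一化结果接近0\n",
"\n",
"```python\n",
"# 示意代码:验证归一化公式 (x * scale - mean) / std\n",
"scale = 1.0 / 255.0\n",
"mean = [0.485, 0.456, 0.406]\n",
"std = [0.229, 0.224, 0.225]\n",
"\n",
"pixel = [124, 117, 104]  # 假设的像素值,约等于各通道均值对应的原始值\n",
"normalized = [(p * scale - m) / s for p, m, s in zip(pixel, mean, std)]\n",
"print(normalized)  # 每个通道的结果都接近 0\n",
"```\n"
]
},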
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# 图像归一化类\n",
"class NormalizeImage(object):\n",
" \"\"\" normalize image such as substract mean, divide std\n",
" \"\"\"\n",
"\n",
" def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):\n",
" if isinstance(scale, str):\n",
" scale = eval(scale)\n",
" self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)\n",
" # 1. 获得归一化的均值和方差\n",
" mean = mean if mean is not None else [0.485, 0.456, 0.406]\n",
" std = std if std is not None else [0.229, 0.224, 0.225]\n",
"\n",
" shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)\n",
" self.mean = np.array(mean).reshape(shape).astype('float32')\n",
" self.std = np.array(std).reshape(shape).astype('float32')\n",
"\n",
" def __call__(self, data):\n",
" # 2. 从字典中获取图像数据\n",
" img = data['image']\n",
" from PIL import Image\n",
" if isinstance(img, Image.Image):\n",
" img = np.array(img)\n",
" assert isinstance(img, np.ndarray), \"invalid input 'img' in NormalizeImage\"\n",
"\n",
" # 3. 图像归一化\n",
" data['image'] = (img.astype('float32') * self.scale - self.mean) / self.std\n",
" return data\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**通道变换**\n",
"\n",
"图像的数据格式为[H, W, C](即高度、宽度和通道数),而神经网络使用的训练数据的格式为[C, H, W],因此需要对图像数据重新排列,例如[224, 224, 3]变为[3, 224, 224]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shape of image before transpose (720, 1280, 3)\n",
"The shape of image after transpose (3, 720, 1280)\n"
]
}
],
"source": [
"# 改变图像的通道顺序HWC to CHW\n",
"class ToCHWImage(object):\n",
" \"\"\" convert hwc image to chw image\n",
" \"\"\"\n",
" def __init__(self, **kwargs):\n",
" pass\n",
"\n",
" def __call__(self, data):\n",
" # 1. 从字典中获取图像数据\n",
" img = data['image']\n",
" from PIL import Image\n",
" if isinstance(img, Image.Image):\n",
" img = np.array(img)\n",
" \n",
" # 2. 通过转置改变图像的通道顺序\n",
" data['image'] = img.transpose((2, 0, 1))\n",
" return data\n",
" \n",
"# 1. 声明通道变换类\n",
"transpose = ToCHWImage()\n",
"\n",
"# 2. 打印变换前的图像\n",
"print(\"The shape of image before transpose\", data['image'].shape)\n",
"\n",
"# 3. 图像通道变换\n",
"data = transpose(data)\n",
"\n",
"# 4. 打印通道变换后的图像形状\n",
"print(\"The shape of image after transpose\", data['image'].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.3 构建数据读取器\n",
"\n",
"\n",
"上面的代码仅展示了读取一张图片和预处理的方法,在实际模型训练时,多采用批量数据读取处理的方式。\n",
"\n",
"本节采用PaddlePaddle中的[Dataset](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/Dataset_cn.html)和[DataLoader](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader) API构建数据读取器。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# dataloader构建详细代码参考https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/data/simple_dataset.py\n",
"\n",
"import numpy as np\n",
"import os\n",
"import random\n",
"from paddle.io import Dataset\n",
"\n",
"def transform(data, ops=None):\n",
" \"\"\" transform \"\"\"\n",
" if ops is None:\n",
" ops = []\n",
" for op in ops:\n",
" data = op(data)\n",
" if data is None:\n",
" return None\n",
" return data\n",
"\n",
"\n",
"def create_operators(op_param_list, global_config=None):\n",
" \"\"\"\n",
" create operators based on the config\n",
" Args:\n",
" params(list): a dict list, used to create some operators\n",
" \"\"\"\n",
" assert isinstance(op_param_list, list), ('operator config should be a list')\n",
" ops = []\n",
" for operator in op_param_list:\n",
" assert isinstance(operator,\n",
" dict) and len(operator) == 1, \"yaml format error\"\n",
" op_name = list(operator)[0]\n",
" param = {} if operator[op_name] is None else operator[op_name]\n",
" if global_config is not None:\n",
" param.update(global_config)\n",
" op = eval(op_name)(**param)\n",
" ops.append(op)\n",
" return ops\n",
"\n",
" \n",
"class SimpleDataSet(Dataset):\n",
" def __init__(self, mode, label_file, data_dir, seed=None):\n",
" super(SimpleDataSet, self).__init__()\n",
" # 标注文件中,使用'\\t'作为分隔符区分图片名称与标签\n",
" self.delimiter = '\\t'\n",
" # 数据集路径\n",
" self.data_dir = data_dir\n",
" # 随机数种子\n",
" self.seed = seed\n",
" # 获取所有数据,以列表形式返回\n",
" self.data_lines = self.get_image_info_list(label_file)\n",
" # 新建列表存放数据索引\n",
" self.data_idx_order_list = list(range(len(self.data_lines)))\n",
" self.mode = mode\n",
" # 如果是训练过程,将数据集进行随机打乱\n",
" if self.mode.lower() == \"train\":\n",
" self.shuffle_data_random()\n",
"\n",
" def get_image_info_list(self, label_file):\n",
" # 获取标签文件中的所有数据\n",
" with open(label_file, \"rb\") as f:\n",
" lines = f.readlines()\n",
" return lines\n",
"\n",
" def shuffle_data_random(self):\n",
" #随机打乱数据\n",
" random.seed(self.seed)\n",
" random.shuffle(self.data_lines)\n",
" return\n",
"\n",
" def __getitem__(self, idx):\n",
" # 1. 获取索引为idx的数据\n",
" file_idx = self.data_idx_order_list[idx]\n",
" data_line = self.data_lines[file_idx]\n",
" try:\n",
" # 2. 获取图片名称以及标签\n",
" data_line = data_line.decode('utf-8')\n",
" substr = data_line.strip(\"\\n\").split(self.delimiter)\n",
" file_name = substr[0]\n",
" label = substr[1]\n",
" # 3. 获取图片路径\n",
" img_path = os.path.join(self.data_dir, file_name)\n",
" data = {'img_path': img_path, 'label': label}\n",
" if not os.path.exists(img_path):\n",
" raise Exception(\"{} does not exist!\".format(img_path))\n",
" # 4. 读取图片并进行预处理\n",
" with open(data['img_path'], 'rb') as f:\n",
" img = f.read()\n",
" data['image'] = img\n",
"\n",
"            # 5. 完成数据增强操作ops为数据增强算子列表\n",
"            #    完整实现中由配置文件经create_operators创建此处默认为空\n",
"            outs = transform(data, getattr(self, 'ops', None))\n",
"\n",
" # 6. 如果当前数据读取失败,重新随机读取一个新数据\n",
" except Exception as e:\n",
" outs = None\n",
" if outs is None:\n",
" return self.__getitem__(np.random.randint(self.__len__()))\n",
" return outs\n",
"\n",
" def __len__(self):\n",
" # 返回数据集的大小\n",
" return len(self.data_idx_order_list)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"PaddlePaddle的[Dataloader API](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/io/DataLoader_cn.html#dataloader)支持多进程数据读取并可以自由设置进程数量。多进程数据读取可以加快数据处理速度和模型训练速度实现代码如下"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\n",
"from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler\n",
"\n",
"def build_dataloader(mode, label_file, data_dir, batch_size, drop_last, shuffle, num_workers, seed=None):\n",
" # 创建数据读取类\n",
" dataset = SimpleDataSet(mode, label_file, data_dir, seed)\n",
" # 定义 batch_sampler\n",
" batch_sampler = BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)\n",
" # 使用paddle.io.DataLoader创建数据读取器并设置batchsize进程数量num_workers等参数\n",
" data_loader = DataLoader(dataset=dataset, batch_sampler=batch_sampler, num_workers=num_workers, return_list=True, use_shared_memory=False)\n",
"\n",
" return data_loader\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"ic15_data_path = \"/home/aistudio/data/data96799/icdar2015/text_localization/\"\n",
"train_data_label = \"/home/aistudio/data/data96799/icdar2015/text_localization/train_icdar2015_label.txt\"\n",
"eval_data_label = \"/home/aistudio/data/data96799/icdar2015/text_localization/test_icdar2015_label.txt\"\n",
"\n",
"# 定义训练集数据读取器batch_size设置为8\n",
"train_dataloader = build_dataloader('Train', train_data_label, ic15_data_path, batch_size=8, drop_last=False, shuffle=True, num_workers=0)\n",
"# 定义验证集数据读取器\n",
"eval_dataloader = build_dataloader('Eval', eval_data_label, ic15_data_path, batch_size=1, drop_last=False, shuffle=False, num_workers=0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.4 DB模型后处理\n",
"\n",
"DB head网络的输出形状和原图相同实际上DB head网络输出的三个通道特征分别为文本区域的概率图、阈值图和二值图。\n",
"\n",
"在训练阶段3个预测图与真实标签共同完成损失函数的计算以及模型训练\n",
"\n",
"在预测阶段只需要使用概率图即可DB后处理函数根据概率图中文本区域的响应计算出包围文本响应区域的文本框坐标。\n",
"\n",
"由于网络预测的概率图是经过收缩后的结果,所以在后处理步骤中,使用相同的偏移值将预测的多边形区域进行扩张,即可得到最终的文本框。代码实现如下所示。"
]
},
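{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"后处理中的扩张与训练时的收缩互为逆过程:扩张偏移量通常按 $D' = A \\times unclip\\_ratio / L$ 计算($A$、$L$ 为收缩多边形的面积和周长),再用 pyclipper 将多边形向外偏移 $D'$ 得到最终文本框。下面仅演示偏移量的计算(示意代码,数值为假设,非 PaddleOCR 源码):\n",
"\n",
"```python\n",
"# 示意代码:按 D' = A * unclip_ratio / L 计算扩张偏移量\n",
"unclip_ratio = 2.0\n",
"w, h = 76, 16            # 假设从概率图中得到的收缩框宽高\n",
"area = w * h             # 收缩多边形面积 A\n",
"length = 2 * (w + h)     # 收缩多边形周长 L\n",
"expand_distance = area * unclip_ratio / length\n",
"print(expand_distance)\n",
"```\n"
]
},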
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/postprocess/db_postprocess.py\n",
"\n",
"import numpy as np\n",
"import cv2\n",
"import paddle\n",
"from shapely.geometry import Polygon\n",
"import pyclipper\n",
"\n",
"\n",
"class DBPostProcess(object):\n",
" \"\"\"\n",
" The post process for Differentiable Binarization (DB).\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" thresh=0.3,\n",
" box_thresh=0.7,\n",
" max_candidates=1000,\n",
" unclip_ratio=2.0,\n",
" use_dilation=False,\n",
" score_mode=\"fast\",\n",
" **kwargs):\n",
" # 1. 获取后处理超参数\n",
" self.thresh = thresh\n",
" self.box_thresh = box_thresh\n",
" self.max_candidates = max_candidates\n",
" self.unclip_ratio = unclip_ratio\n",
" self.min_size = 3\n",
" self.score_mode = score_mode\n",
" assert score_mode in [\n",
" \"slow\", \"fast\"\n",
" ], \"Score mode must be in [slow, fast] but got: {}\".format(score_mode)\n",
"\n",
" self.dilation_kernel = None if not use_dilation else np.array(\n",
" [[1, 1], [1, 1]])\n",
"\n",
" # DB后处理代码详细实现参考 https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/postprocess/db_postprocess.py\n",
"\n",
" def __call__(self, outs_dict, shape_list):\n",
"\n",
" # 1. 从字典中获取网络预测结果\n",
" pred = outs_dict['maps']\n",
" if isinstance(pred, paddle.Tensor):\n",
" pred = pred.numpy()\n",
" pred = pred[:, 0, :, :]\n",
"\n",
" # 2. 大于后处理参数阈值self.thresh的\n",
" segmentation = pred > self.thresh\n",
"\n",
" boxes_batch = []\n",
" for batch_index in range(pred.shape[0]):\n",
" # 3. 获取原图的形状和resize比例\n",
" src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]\n",
" if self.dilation_kernel is not None:\n",
" mask = cv2.dilate(\n",
" np.array(segmentation[batch_index]).astype(np.uint8),\n",
" self.dilation_kernel)\n",
" else:\n",
" mask = segmentation[batch_index]\n",
" \n",
" # 4. 使用boxes_from_bitmap函数 完成 从预测的文本概率图中计算得到文本框\n",
" boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,\n",
" src_w, src_h)\n",
"\n",
" boxes_batch.append({'points': boxes})\n",
" return boxes_batch\n"
]
},
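{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"To make steps 2-4 of `__call__` concrete, here is a tiny pure-Python sketch (no paddle or cv2) of binarizing a probability map and scoring one candidate box, mirroring the roles of the `thresh` and `box_thresh` parameters; the map values and the box location are made up for illustration.\n",
"\n",
"```python\n",
"# A made-up 4x6 probability map (values in [0, 1])\n",
"pred = [\n",
"    [0.0, 0.1, 0.1, 0.0, 0.0, 0.0],\n",
"    [0.1, 0.9, 0.8, 0.9, 0.1, 0.0],\n",
"    [0.1, 0.8, 0.9, 0.8, 0.1, 0.0],\n",
"    [0.0, 0.1, 0.0, 0.1, 0.0, 0.0],\n",
"]\n",
"thresh, box_thresh = 0.3, 0.7\n",
"\n",
"# Step 2: binarize the probability map\n",
"segmentation = [[v > thresh for v in row] for row in pred]\n",
"\n",
"# Score a candidate box (here rows 1-2, cols 1-3) by its mean probability\n",
"vals = [pred[r][c] for r in (1, 2) for c in (1, 2, 3)]\n",
"score = sum(vals) / len(vals)\n",
"keep = score >= box_thresh  # boxes scoring below box_thresh are filtered out\n",
"print(round(score, 3), keep)  # 0.85 True\n",
"```"
]
},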
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"\n",
"可以发现每个单词都有一个蓝色的框包围着。这些蓝色的框即是在DB输出的分割结果上做一些后处理得到的。将如下代码添加到PaddleOCR/ppocr/postprocess/db_postprocess.py的177行可以可视化DB输出的分割图分割图的可视化结果保存为图像vis_segmentation.png。\n",
"\n",
"```\n",
"_maps = np.array(pred[0, :, :] * 255).astype(np.uint8)\n",
"import cv2\n",
"cv2.imwrite(\"vis_segmentation.png\", _maps)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File ./pretrain_models/det_mv3_db_v2.0_train.tar already there; not retrieving.\n",
"\n",
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n",
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n",
"[2021/12/22 21:58:41] root INFO: Architecture : \n",
"[2021/12/22 21:58:41] root INFO: Backbone : \n",
"[2021/12/22 21:58:41] root INFO: model_name : large\n",
"[2021/12/22 21:58:41] root INFO: name : MobileNetV3\n",
"[2021/12/22 21:58:41] root INFO: scale : 0.5\n",
"[2021/12/22 21:58:41] root INFO: Head : \n",
"[2021/12/22 21:58:41] root INFO: k : 50\n",
"[2021/12/22 21:58:41] root INFO: name : DBHead\n",
"[2021/12/22 21:58:41] root INFO: Neck : \n",
"[2021/12/22 21:58:41] root INFO: name : DBFPN\n",
"[2021/12/22 21:58:41] root INFO: out_channels : 256\n",
"[2021/12/22 21:58:41] root INFO: Transform : None\n",
"[2021/12/22 21:58:41] root INFO: algorithm : DB\n",
"[2021/12/22 21:58:41] root INFO: model_type : det\n",
"[2021/12/22 21:58:41] root INFO: Eval : \n",
"[2021/12/22 21:58:41] root INFO: dataset : \n",
"[2021/12/22 21:58:41] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 21:58:41] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']\n",
"[2021/12/22 21:58:41] root INFO: name : SimpleDataSet\n",
"[2021/12/22 21:58:41] root INFO: transforms : \n",
"[2021/12/22 21:58:41] root INFO: DecodeImage : \n",
"[2021/12/22 21:58:41] root INFO: channel_first : False\n",
"[2021/12/22 21:58:41] root INFO: img_mode : BGR\n",
"[2021/12/22 21:58:41] root INFO: DetLabelEncode : None\n",
"[2021/12/22 21:58:41] root INFO: DetResizeForTest : \n",
"[2021/12/22 21:58:41] root INFO: image_shape : [736, 1280]\n",
"[2021/12/22 21:58:41] root INFO: NormalizeImage : \n",
"[2021/12/22 21:58:41] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 21:58:41] root INFO: order : hwc\n",
"[2021/12/22 21:58:41] root INFO: scale : 1./255.\n",
"[2021/12/22 21:58:41] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 21:58:41] root INFO: ToCHWImage : None\n",
"[2021/12/22 21:58:41] root INFO: KeepKeys : \n",
"[2021/12/22 21:58:41] root INFO: keep_keys : ['image', 'shape', 'polys', 'ignore_tags']\n",
"[2021/12/22 21:58:41] root INFO: loader : \n",
"[2021/12/22 21:58:41] root INFO: batch_size_per_card : 1\n",
"[2021/12/22 21:58:41] root INFO: drop_last : False\n",
"[2021/12/22 21:58:41] root INFO: num_workers : 8\n",
"[2021/12/22 21:58:41] root INFO: shuffle : False\n",
"[2021/12/22 21:58:41] root INFO: use_shared_memory : False\n",
"[2021/12/22 21:58:41] root INFO: Global : \n",
"[2021/12/22 21:58:41] root INFO: cal_metric_during_train : False\n",
"[2021/12/22 21:58:41] root INFO: checkpoints : ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy\n",
"[2021/12/22 21:58:41] root INFO: debug : False\n",
"[2021/12/22 21:58:41] root INFO: distributed : False\n",
"[2021/12/22 21:58:41] root INFO: epoch_num : 1200\n",
"[2021/12/22 21:58:41] root INFO: eval_batch_step : [0, 2000]\n",
"[2021/12/22 21:58:41] root INFO: infer_img : ./doc/imgs_en/img_12.jpg\n",
"[2021/12/22 21:58:41] root INFO: log_smooth_window : 20\n",
"[2021/12/22 21:58:41] root INFO: pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained\n",
"[2021/12/22 21:58:41] root INFO: print_batch_step : 10\n",
"[2021/12/22 21:58:41] root INFO: save_epoch_step : 1200\n",
"[2021/12/22 21:58:41] root INFO: save_inference_dir : None\n",
"[2021/12/22 21:58:41] root INFO: save_model_dir : ./output/db_mv3/\n",
"[2021/12/22 21:58:41] root INFO: save_res_path : ./output/det_db/predicts_db.txt\n",
"[2021/12/22 21:58:41] root INFO: use_gpu : True\n",
"[2021/12/22 21:58:41] root INFO: use_visualdl : False\n",
"[2021/12/22 21:58:41] root INFO: Loss : \n",
"[2021/12/22 21:58:41] root INFO: alpha : 5\n",
"[2021/12/22 21:58:41] root INFO: balance_loss : True\n",
"[2021/12/22 21:58:41] root INFO: beta : 10\n",
"[2021/12/22 21:58:41] root INFO: main_loss_type : DiceLoss\n",
"[2021/12/22 21:58:41] root INFO: name : DBLoss\n",
"[2021/12/22 21:58:41] root INFO: ohem_ratio : 3\n",
"[2021/12/22 21:58:41] root INFO: Metric : \n",
"[2021/12/22 21:58:41] root INFO: main_indicator : hmean\n",
"[2021/12/22 21:58:41] root INFO: name : DetMetric\n",
"[2021/12/22 21:58:41] root INFO: Optimizer : \n",
"[2021/12/22 21:58:41] root INFO: beta1 : 0.9\n",
"[2021/12/22 21:58:41] root INFO: beta2 : 0.999\n",
"[2021/12/22 21:58:41] root INFO: lr : \n",
"[2021/12/22 21:58:41] root INFO: learning_rate : 0.001\n",
"[2021/12/22 21:58:41] root INFO: name : Adam\n",
"[2021/12/22 21:58:41] root INFO: regularizer : \n",
"[2021/12/22 21:58:41] root INFO: factor : 0\n",
"[2021/12/22 21:58:41] root INFO: name : L2\n",
"[2021/12/22 21:58:41] root INFO: PostProcess : \n",
"[2021/12/22 21:58:41] root INFO: box_thresh : 0.6\n",
"[2021/12/22 21:58:41] root INFO: max_candidates : 1000\n",
"[2021/12/22 21:58:41] root INFO: name : DBPostProcess\n",
"[2021/12/22 21:58:41] root INFO: thresh : 0.3\n",
"[2021/12/22 21:58:41] root INFO: unclip_ratio : 1.5\n",
"[2021/12/22 21:58:41] root INFO: Train : \n",
"[2021/12/22 21:58:41] root INFO: dataset : \n",
"[2021/12/22 21:58:41] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 21:58:41] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']\n",
"[2021/12/22 21:58:41] root INFO: name : SimpleDataSet\n",
"[2021/12/22 21:58:41] root INFO: ratio_list : [1.0]\n",
"[2021/12/22 21:58:41] root INFO: transforms : \n",
"[2021/12/22 21:58:41] root INFO: DecodeImage : \n",
"[2021/12/22 21:58:41] root INFO: channel_first : False\n",
"[2021/12/22 21:58:41] root INFO: img_mode : BGR\n",
"[2021/12/22 21:58:41] root INFO: DetLabelEncode : None\n",
"[2021/12/22 21:58:41] root INFO: IaaAugment : \n",
"[2021/12/22 21:58:41] root INFO: augmenter_args : \n",
"[2021/12/22 21:58:41] root INFO: args : \n",
"[2021/12/22 21:58:41] root INFO: p : 0.5\n",
"[2021/12/22 21:58:41] root INFO: type : Fliplr\n",
"[2021/12/22 21:58:41] root INFO: args : \n",
"[2021/12/22 21:58:41] root INFO: rotate : [-10, 10]\n",
"[2021/12/22 21:58:41] root INFO: type : Affine\n",
"[2021/12/22 21:58:41] root INFO: args : \n",
"[2021/12/22 21:58:41] root INFO: size : [0.5, 3]\n",
"[2021/12/22 21:58:41] root INFO: type : Resize\n",
"[2021/12/22 21:58:41] root INFO: EastRandomCropData : \n",
"[2021/12/22 21:58:41] root INFO: keep_ratio : True\n",
"[2021/12/22 21:58:41] root INFO: max_tries : 50\n",
"[2021/12/22 21:58:41] root INFO: size : [640, 640]\n",
"[2021/12/22 21:58:41] root INFO: MakeBorderMap : \n",
"[2021/12/22 21:58:41] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 21:58:41] root INFO: thresh_max : 0.7\n",
"[2021/12/22 21:58:41] root INFO: thresh_min : 0.3\n",
"[2021/12/22 21:58:41] root INFO: MakeShrinkMap : \n",
"[2021/12/22 21:58:41] root INFO: min_text_size : 8\n",
"[2021/12/22 21:58:41] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 21:58:41] root INFO: NormalizeImage : \n",
"[2021/12/22 21:58:41] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 21:58:41] root INFO: order : hwc\n",
"[2021/12/22 21:58:41] root INFO: scale : 1./255.\n",
"[2021/12/22 21:58:41] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 21:58:41] root INFO: ToCHWImage : None\n",
"[2021/12/22 21:58:41] root INFO: KeepKeys : \n",
"[2021/12/22 21:58:41] root INFO: keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']\n",
"[2021/12/22 21:58:41] root INFO: loader : \n",
"[2021/12/22 21:58:41] root INFO: batch_size_per_card : 16\n",
"[2021/12/22 21:58:41] root INFO: drop_last : False\n",
"[2021/12/22 21:58:41] root INFO: num_workers : 8\n",
"[2021/12/22 21:58:41] root INFO: shuffle : True\n",
"[2021/12/22 21:58:41] root INFO: use_shared_memory : False\n",
"[2021/12/22 21:58:41] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)\n",
"W1222 21:58:41.294615 18214 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1\n",
"W1222 21:58:41.298939 18214 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
"[2021/12/22 21:58:44] root INFO: resume from ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy\n",
"[2021/12/22 21:58:44] root INFO: infer_img: ./doc/imgs_en/img_12.jpg\n",
"[2021/12/22 21:58:44] root INFO: The detected Image saved in ./output/det_db/det_results/img_12.jpg\n",
"[2021/12/22 21:58:44] root INFO: success!\n"
]
}
],
"source": [
"# 1. 下载训练好的模型\n",
"!wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar\n",
"!cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../\n",
"\n",
"# 2. 执行文本检测预测得到结果\n",
"!python tools/infer_det.py -c configs/det/det_mv3_db.yml \\\n",
" -o Global.checkpoints=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy \\\n",
" Global.infer_img=./doc/imgs_en/img_12.jpg \n",
" #PostProcess.unclip_ratio=4.0\n",
"# 注有关PostProcess参数和Global参数介绍与使用参考 https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.3/doc/doc_ch/config.md"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"可视化预测模型预测的文本概率图,以及最终预测文本框结果。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff4344889d0>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmAAAAF/CAYAAADuA3UDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvc2ubcuSHvRFZOYYa1e5QcfulGmAgBdA4g2Q6NGFFzBI8AC0aNDhHfwEiCYNSzyD24BAFhLCYCTTtO/Zc4zMCBrxkzHm3mXu6ZzaKs+s2jp3rTXn+MmMjIz44osIUlV8xmd8xmd8xmd8xmd8xh83+G/6AT7jMz7jMz7jMz7jM/51Gx8D7DM+4zM+4zM+4zM+4w8eHwPsMz7jMz7jMz7jMz7jDx4fA+wzPuMzPuMzPuMzPuMPHh8D7DM+4zM+4zM+4zM+4w8eHwPsMz7jMz7jMz7jMz7jDx5/uAFGRP8REf2vRPRPiOi/+qPv/xmf8Rmf8Rmf8Rmf8Tc96I+sA0ZEDcD/BuA/BPBPAfxjAP+pqv7Pf9hDfMZnfMZnfMZnfMZn/A2PPxoB+w8A/BNV/d9V9QLw3wH4j//gZ/iMz/iMz/iMz/iMz/gbHX+0AfZXAP7P8vM/9d99xmd8xmd8xmd8xmf8azP63/QDvA8i+gcA/gEAHGP8+3/v7/49KHaYVAHQ/rT/BlAFiPZnUD5HoH0NorwI+feAH8Ow5J/TuKICxPE7tWuXG9rvKD9vf6IfrvvDe2h9rrxyXvqnLxS/jnfWH/8W71z/KypgZqjuZ/Xb+sXiJva+Ss/Lkv+k5YbxfftvmcecG81r0vsCxaMTfvKOBFXZ67WfLO9FROXe2NeH/vCzXU/r5fx99Hlvf9y6LM9Z0DJV/6oFqff3uXmsU33pny8xEUFlC/aWjh+/b3Pxw5s9PwOClnfc8l/kAHu9ysWfz16/X5/3TQ61rP1PN2eVQzyH1lfM+5Yfys3rmubn8pd7LTSvgecE+LPl7bTcv1wz5vh5rTLXfr0qe7H3nm9Ge7/UefnpmuzPx2fEP/DjHt+XYr9YTMF+jjdJKzd8imfZBI+/PXXu46Z/nWA87lP1xtvP5XIiri9yiyrY9/vPFF7s7f2nsvfyJ63LnX9TBfh9r5ZX1dBdujVfzuLbg5tuLVPy103D8xHfRRGhr35418cFyueK7o51yqfNz5f18R/zu1VmVUFM+S5P/V6l4GfP9qMertd8yIn/XaGQJbheL/Q+cH47H3vGtffbOwuIqJy58Un7+0P//XXnTTlQVBXMpm/je/W8e+w1f5f436ETYvzf/88/+39V9e/izxh/tAH2fwH4N8vPf99/l0NV/yGAfwgAf/+v/kr/i//sPwcRgZkx5wQ4NiyhtYa1ln3RN6wSwGzAni5B7x0iAsB+f6+Zk9WIwcwQEbAKhBhrLbTRbUFA+VkReQooKVpruO+F1lo+v8gEyO5H2L+/ronzPKGqEF9iBoEZmHOmYdRag6q971oLzPaMWAJVxVR7JyLCnLMYIvbdeFciAjXO+0xZNmfE+Z2Yu5gvInp8H2zvr2v/Lp41PjfGwDVvtNbQWsOcNr85b9Pmh5kxZYHRAAgEmmvTBPksynbf1pod6ABUCYDk9QHgdV/49u0bZC4QtXynwfaZCc057b1DxeZtzok5J0QEvTPGGPh+vTBax33f9p6NAd3gMBGht4YGwoRCls2lxLVdUUm5BzWTk0Ymt/F77g0QhYjk3MVcka8fA5Blcys6fa7j2jYP13XZOsyF89sXILoPDiYzIlw5xN45jiNlT1V9bQZEFUyuSBpjiqC5fIwxHteJ+6jrGzGhR+eW+++6LruXyEM5ERFk2jUFClJgXTf6eeQzzTkxxkh5aK1hqcufIvfEvSZ6s3u0Xo1P5LPG70JW9/7SlOG6v2P/kNoc1r1FaLmvqz5gZpBi72kyeRCSvL9q6Kob3dcidMYUwS
CGuLxOn7Pmeyj0UA4mXL7HmDsUNxoPQAQMgog9G/c956r6eN54rtjPoQNivUUEYMa6b9OHrUGJQEKpK0LWl9wp3yEnvXfcl+2R4zhwXZfpLQYa77UF2X3i+WJuTPcwmj9f7x2v+8bgkc+Zcq2CTmz6bF62V3uHQKFCKYPKhAb7mRuwpsnCVFt3nespO1gQhA72OS33FhEcxwFmO0JfrxfGGDaXtiuwXH6qfo9/IYchY83PimuZbBAR5Db9zqNjrQVqQIftIVWFMIFdlygh12Ap+X/Fz6jbzjpZECVQt3t3uFyK2Jqo4rUmjma6ahx2vjJ1fL8vjEYYY+D1eqER23fJ9lHIV/z97DYX933nuwsJhj9PzGVrDb/99htev1345//sn+P49g3/zr/3b+M8T7CfR3MtrGk6JWS3tQYswVTsM5z5cTYBew/K7Xu9MUQm2uj2XzpSHvPcWjPfZ6qAe0PXBtWV97H/diwsdOpbpgH81//tf/N/4M8cf3QI8h8D+HeJ6N8iogPAfwLgf/hXfSEO9lC88TsAj4WPIS5MoSCqkIfgVwMh/i7EUF1ojcyI8sVba+Fe2/AD4AJFbhggr2P3PNLwsnswiBrO88zFI8VDqbY2MMaJ1kZeM953rYX7viG0BSQ2dSg4AGmUtdbAvbmA2VwoIY2gKiipbAhYKphin6XGWLrnccrKue29g5nRe8/fxbVjfuPd67zHAUNs104DUgiXLlBv6L2j947mm5pEMedG7VL5QfOQrgqNmfGaN5SfB81932lcLn+Hr69jG3RsB+DX1xeoccocM7uXtQ+rtRY05iIU5bohMrH8v70zGnEakLFmcT3uLee49573yvsRmVGBVb7X/B9hqQK+Bl9f2/iK70NsjqixyS0T2uhYKkU5bcVVje44/Os+q4ZMrOl2YsgP6L0HwtGo8grYgRRGTjWYcl7972EUxD6on49net/H73NcDb+Q15DZGFVvhLzEvnxcFy11T8xH3D/kMQ05NQeMgfzXiaFzpUERexYAul83dRvRntNwlFRS/4gbx3CjXGUbUaE7FGYYp24LvVN0Xqx77GdmTj1rDgogYHA/oNTM2GM7wGJe5pzQJblf4vdXMTCvNc348fkK/RN6O/Z1yDX3bvfz5wwdPvx6oRfrWudc+l6qcvrYV72BR7e9cx5QJnRiM8wI4NHRjpG69rGnYFPe+4ExTpznNxC1dNrGGHm/NRWy8HzfdKTKWvuctNYgJEDjNOLe1y0dA5h+psYpJ3Xt7vsGQ6DrRiMAohjN1pgaPwxoZsZ937ZGvg6hE1preL1eOZdf44As4HqZ40PcMZeC/ZnSoXhddtYRcl3j73H+hH6In2XZ89hcOtIpko5yU6RjkN+RianTJV3yvda6c+7SaQRAndCO9ng3oobr+g5moNlk+T8bJgcNpIxrTTNcecuKkD331Kf+/D3jDzXAVHUC+C8B/I8A/hcA/72q/k//6m/J42APjzy8+t471D3vQL9CaAMheveO0ysqxpnAFOmUlYdnIALxPSU8Dqv7vtMqz823Fno/0PiEYD8XNQa3AdWVG8qUhT3X9+/fc8PGgRGbJgwoNPv9dV25cV+vVyopAABvFEtVcfQBXVuYMwTpB1YYabGJYyNsgV52cBWj5n1TVYP4sencMw3lUJUKACgRrjUB7hAlTIV7gmzeq3t1V/GYApVLxQSFkqA1wlp3GtDEiiU3Wif73+s2A7sTlATXde13UAajQafi5MN2ni7IutEY7gFPKPuBOIuX6qgBEZnX5nIZB0xFVgLhifvGHIQ8kSpIN0KVsunXsN2Ah2LJwYRr3o89UL+fxpAWRcEEduUfBlTnBpVtOFXHh9kMOiWkE8FwgzCey2Jk+fPPDPGys1O2lJB7kNo2oiqiPWWZEYmtvKuhoqrQJWlEvct5/K7KazXiYh1s7vZ7v3//3bCsvwfwQM5jD8U9v76+HgexqgK00ZCY03rYJHrieiD2JMP2C6vtlyWSqKztC3
oYYrGe785nNTaY+aHT6kGZ6xbGE5a9rbIZHGhQMgSfiHDL3q8hu2gblRaD6zBaQ0fDoA4WQi9GcMxV7IGc18a2TVX
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFkCAYAAAAT9C6pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xl8VPW9//HXZ2aSsIkEVJSEVVDcI/tirYoVDAh67VVaa7V6RevSetveVttfb5fb9dra1l2rrWhtudZWRYxSxVqrgEiAIopI2CQBd1ZZksx8f3/MmTBZ50wyh5kk7+fjkcecOfM953zPmZOcd77fM/M15xwiIiIiklmhbFdAREREpCNSyBIREREJgEKWiIiISAAUskREREQCoJAlIiIiEgCFLBEREZEABBKyzGyKma0xswozuymIbYiIiIjkMsv092SZWRh4G/gMUAm8BnzOOfdmRjckIiIiksOCaMkaA1Q459Y756qBOcCMALYjIiIikrOCCFlFwOak55XePBEREZFOI5KtDZvZLGAWQJjwyG70zFZVRERERHzbxbYPnXOHpyoXRMiqAvonPS/25tXjnLsPuA+gp/V2Y21SAFURERERyazn3WOb/JQLorvwNWCYmQ02s3xgJjA3gO2IiIiI5KyMt2Q552rN7HpgPhAGfueceyPT2xERERHJZYHck+WcKwPKgli3iIiISHugb3wXERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCoJAlIiIiEgCFLBEREZEAKGSJiIiIBEAhS0RERCQAClkiIiIiAVDIEhEREQmAQpaIiIhIABSyRERERAKgkCUiIiISAIUsERERkQAoZImIiIgEQCFLREREJAAKWSIiIiIBUMgSERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCkDJkmdnvzOx9M1uVNK+3mT1nZmu9x0JvvpnZbWZWYWYrzWxEkJUXERERyVV+WrIeBKY0mHcTsMA5NwxY4D0HOBcY5v3MAu7OTDVFRERE2peUIcs59xLwcYPZM4DZ3vRs4Pyk+Q+5uMVALzM7KlOVFREREWkvWntPVl/n3FZv+l2grzddBGxOKlfpzRMRERHpVNp847tzzgEu3eXMbJaZLTWzpTXsb2s1RERERHJKa0PWe4luQO/xfW9+FdA/qVyxN68R59x9zrlRzrlReRS0shoiIiIiuam1IWsucJk3fRnwZNL8L3qfMhwH7EjqVhQRERHpNCKpCpjZn4AzgMPMrBL4HvAz4FEzuxLYBFzkFS8DSoEKYA/wpQDqLCIiIpLzUoYs59znmnlpUhNlHXBdWyslIiIi0t7pG99FREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFI+RUOIh3BPZteZkCkW93z0qIR7DtvDF2eWpLFWomISEemkCUdXuSoIxmc16PevPlbVgAr4F6IuhgAP/zwJF4d1R0AV1N9sKspIiIdjEKWdHg/XvgktDA+ZtjiveY/OPyN+PgFTYi6GGEL1T0u2Bvm9C7VhDBG/OJ6ln3jDkqLRnDzupX87+QZRCs2EBk0gNqN79Rbz/YvjqfXQ4uIFPWjtmpLyxU3A+dv7PXQycOZNucVttV257/6vA5AhHC9Oic/bovtpTDUlQsqSnl8aBlTp1/K03Mf5rwTJ/HUqgVNHp/k45D8HGBnbB8zB0zkU//ayz9P7tL0/nn7YwUF7JpxKoNvfIufFc/jqmPO5tRFeyg/NcS6X45jzcy76rab2N7gJ2dRMf0epk39AvOe/kOzxyFRt8Ryu2P76Gr5xHCEsGaPR8PtJT9etu
ksZg98gakDx/D0piV162rq+CRral2JsvtdDRHCjZbZ66q56DOXMvPxv7Nk19F8/8gXuKT/RD586hgOO+/t+PaOP4bom2/z9gOjOOaq5Tz+ziIKLMJ33y/hB0csZ1rRSCpvnkDxTxcS7nUoD75exuVnfoHvP/co3x08mg0/Gc/gby+q/9aMOhG3dFW9cy48bAjRtesb1TEyeCAbf9GD3494kGPzaimZ91UqzruH6aNKqd36LpaXj6upxkaeAP9aw8/WvsIPN0/jO/2fpiQ/fsmJ4ZhWNLLZ91GkozDn8494kHpabzfWGn2BvEjbmVFWWd7shVBEsmN3bB8XFo/LdjVEWuV591i5c25UqnK68kiHNq9yqQKWSA7qEeqS7SqIBE5XH+nQvv1eyn80REREAqF7sqRDe31MGN5JXa6jS9x35Md33y+h/NQQn1w4lu5/eZWyqmX1Xn+rZj+DI2EuGDCei96o4oHvn88/b72rxXUm3ysFsNvtp5vl17u3Cai7/2ld7V4GRvKJECaGY8SSS1k6ZjbTi0Y3XnniXi/vXqCESP9iajdX+t7vdIV79iS6axc4x4ezxtP3xff5yfw/UrbrZP5xctdG5W/f9Ao3DJwIwDfXvc7X7riaZd+4o/F6O0nLa42LZrsKIoHTPVnSaWz6wQQGfm8hP9mwhG8PHsPcqteavmgH4K5NL3NYOMzMoWcS27cvvYVDYYjVvyCFhw4mWrEhgzUUadm8qnLG/s/1vPbdOwEYWnY1FaX3AjS6sX9Dze56X5mSKAMHgnZp0YiDVXWRjPN7T5ZCloiIBCsUBhcjfNwwom++ne3aiLSZ35Cl7kIREQmW1xKrgCWdTefo/BcRERE5yBSyRERERAKgkCUiIiISAIUsERERkQAoZImIiIgEQCFLREREJAAKWSIiIiIBUMgSERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCkDJkmVl/M/u7mb1pZm+Y2Ve9+b3N7DkzW+s9FnrzzcxuM7MKM1tpZiOC3gkRERGRXOOnJasW+Lpz7nhgHHCdmR0P3AQscM4NAxZ4zwHOBYZ5P7OAuzNeaxEREZEclzJkOee2OueWedO7gNVAETADmO0Vmw2c703PAB5ycYuBXmZ2VMZrLiIiIpLD0rony8wGAacCrwJ9nXNbvZfeBfp600XA5qTFKr15IiIiIp2G75BlZj2AvwA3Oud2Jr/mnHOAS2fDZjbLzJaa2dIa9qezqIiIiEjO8xWyzCyPeMB6xDn3V2/2e4luQO/xfW9+FdA/afFib149zrn7nHOjnHOj8ihobf1FREREcpKfTxca8ACw2jl3a9JLc4HLvOnLgCeT5n/R+5ThOGBHUreiiIiISKcQ8VFmInAp8LqZrfDmfRv4GfComV0JbAIu8l4rA0qBCmAP8KWM1lhERESkHUgZspxzLwPWzMuTmijvgOvaWC8RERGRdk3f+C4iIiISAIUsERERkQAoZImIiIgEQCFLREREJAAKWSIiIiIBUMgSERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCoJAlIiIiEoBItisg0tncs+llrp0xi9iKN30vU1a1rN7z/a6WC4rHZLpqmWGGhcOEenQnun1H69cz7mTCb27k8dULiBAmbCHej37CxD9+gyHfWpTWqsqqlhG25v+nnNyvpMXlv17xBpO67m+0jqiLcew/ruC/R8zjkeHFadWpKet+OY7VM+8khFFaNKLN6xOR7DLnXLbrQE/r7cbapGxXQ3KY5eXjaqrjF/BRJxKqqOQny57lpuGns+5/RjD4iT1EVr/DoL/tZd3ofVmqpFFWWV73tKWLOsQv0H4upKFu3XimYmHKclEXq9tuqt
DQlO+sX8GPh8SXSz7e+PwbMX/LCt/bSq7rh9FPKAx1ZWdsHz1DXermtySd/fNTr1TrSxXSGoq6GOedOInotm2+lzl
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = Image.open('./output/det_db/det_results/img_12.jpg')\n",
"img = np.array(img)\n",
"\n",
"# 画出读取的图片\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(img)\n",
"\n",
"img = Image.open('./vis_segmentation.png')\n",
"img = np.array(img)\n",
"\n",
"# 画出读取的图片\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(img)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"从可视化结果中可以发现DB的输出结果是文本区域的二值图属于文本区域的响应更高非文本的背景区域响应值低。DB的后处理即是求这些响应区域的最小包围框进而得到每个文本区域的坐标。\n",
"另外,通过修改后处理参数可以调整文本框的大小,或者过滤检测效果差的文本框。\n",
"\n",
"DB后处理有四个参数分别是\n",
"- thresh: DBPostProcess中分割图进行二值化的阈值默认值为0.3\n",
"- box_thresh: DBPostProcess中对输出框进行过滤的阈值低于此阈值的框不会输出\n",
"- unclip_ratio: DBPostProcess中对文本框进行放大的比例\n",
"- max_candidates: DBPostProcess中输出的最大文本框数量默认1000\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n",
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n",
"[2021/12/22 21:59:35] root INFO: Architecture : \n",
"[2021/12/22 21:59:35] root INFO: Backbone : \n",
"[2021/12/22 21:59:35] root INFO: model_name : large\n",
"[2021/12/22 21:59:35] root INFO: name : MobileNetV3\n",
"[2021/12/22 21:59:35] root INFO: scale : 0.5\n",
"[2021/12/22 21:59:35] root INFO: Head : \n",
"[2021/12/22 21:59:35] root INFO: k : 50\n",
"[2021/12/22 21:59:35] root INFO: name : DBHead\n",
"[2021/12/22 21:59:35] root INFO: Neck : \n",
"[2021/12/22 21:59:35] root INFO: name : DBFPN\n",
"[2021/12/22 21:59:35] root INFO: out_channels : 256\n",
"[2021/12/22 21:59:35] root INFO: Transform : None\n",
"[2021/12/22 21:59:35] root INFO: algorithm : DB\n",
"[2021/12/22 21:59:35] root INFO: model_type : det\n",
"[2021/12/22 21:59:35] root INFO: Eval : \n",
"[2021/12/22 21:59:35] root INFO: dataset : \n",
"[2021/12/22 21:59:35] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 21:59:35] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']\n",
"[2021/12/22 21:59:35] root INFO: name : SimpleDataSet\n",
"[2021/12/22 21:59:35] root INFO: transforms : \n",
"[2021/12/22 21:59:35] root INFO: DecodeImage : \n",
"[2021/12/22 21:59:35] root INFO: channel_first : False\n",
"[2021/12/22 21:59:35] root INFO: img_mode : BGR\n",
"[2021/12/22 21:59:35] root INFO: DetLabelEncode : None\n",
"[2021/12/22 21:59:35] root INFO: DetResizeForTest : \n",
"[2021/12/22 21:59:35] root INFO: image_shape : [736, 1280]\n",
"[2021/12/22 21:59:35] root INFO: NormalizeImage : \n",
"[2021/12/22 21:59:35] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 21:59:35] root INFO: order : hwc\n",
"[2021/12/22 21:59:35] root INFO: scale : 1./255.\n",
"[2021/12/22 21:59:35] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 21:59:35] root INFO: ToCHWImage : None\n",
"[2021/12/22 21:59:35] root INFO: KeepKeys : \n",
"[2021/12/22 21:59:35] root INFO: keep_keys : ['image', 'shape', 'polys', 'ignore_tags']\n",
"[2021/12/22 21:59:35] root INFO: loader : \n",
"[2021/12/22 21:59:35] root INFO: batch_size_per_card : 1\n",
"[2021/12/22 21:59:35] root INFO: drop_last : False\n",
"[2021/12/22 21:59:35] root INFO: num_workers : 8\n",
"[2021/12/22 21:59:35] root INFO: shuffle : False\n",
"[2021/12/22 21:59:35] root INFO: use_shared_memory : False\n",
"[2021/12/22 21:59:35] root INFO: Global : \n",
"[2021/12/22 21:59:35] root INFO: cal_metric_during_train : False\n",
"[2021/12/22 21:59:35] root INFO: checkpoints : ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy\n",
"[2021/12/22 21:59:35] root INFO: debug : False\n",
"[2021/12/22 21:59:35] root INFO: distributed : False\n",
"[2021/12/22 21:59:35] root INFO: epoch_num : 1200\n",
"[2021/12/22 21:59:35] root INFO: eval_batch_step : [0, 2000]\n",
"[2021/12/22 21:59:35] root INFO: infer_img : ./doc/imgs_en/img_12.jpg\n",
"[2021/12/22 21:59:35] root INFO: log_smooth_window : 20\n",
"[2021/12/22 21:59:35] root INFO: pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained\n",
"[2021/12/22 21:59:35] root INFO: print_batch_step : 10\n",
"[2021/12/22 21:59:35] root INFO: save_epoch_step : 1200\n",
"[2021/12/22 21:59:35] root INFO: save_inference_dir : None\n",
"[2021/12/22 21:59:35] root INFO: save_model_dir : ./output/db_mv3/\n",
"[2021/12/22 21:59:35] root INFO: save_res_path : ./output/det_db/predicts_db.txt\n",
"[2021/12/22 21:59:35] root INFO: use_gpu : True\n",
"[2021/12/22 21:59:35] root INFO: use_visualdl : False\n",
"[2021/12/22 21:59:35] root INFO: Loss : \n",
"[2021/12/22 21:59:35] root INFO: alpha : 5\n",
"[2021/12/22 21:59:35] root INFO: balance_loss : True\n",
"[2021/12/22 21:59:35] root INFO: beta : 10\n",
"[2021/12/22 21:59:35] root INFO: main_loss_type : DiceLoss\n",
"[2021/12/22 21:59:35] root INFO: name : DBLoss\n",
"[2021/12/22 21:59:35] root INFO: ohem_ratio : 3\n",
"[2021/12/22 21:59:35] root INFO: Metric : \n",
"[2021/12/22 21:59:35] root INFO: main_indicator : hmean\n",
"[2021/12/22 21:59:35] root INFO: name : DetMetric\n",
"[2021/12/22 21:59:35] root INFO: Optimizer : \n",
"[2021/12/22 21:59:35] root INFO: beta1 : 0.9\n",
"[2021/12/22 21:59:35] root INFO: beta2 : 0.999\n",
"[2021/12/22 21:59:35] root INFO: lr : \n",
"[2021/12/22 21:59:35] root INFO: learning_rate : 0.001\n",
"[2021/12/22 21:59:35] root INFO: name : Adam\n",
"[2021/12/22 21:59:35] root INFO: regularizer : \n",
"[2021/12/22 21:59:35] root INFO: factor : 0\n",
"[2021/12/22 21:59:35] root INFO: name : L2\n",
"[2021/12/22 21:59:35] root INFO: PostProcess : \n",
"[2021/12/22 21:59:35] root INFO: box_thresh : 0.6\n",
"[2021/12/22 21:59:35] root INFO: max_candidates : 1000\n",
"[2021/12/22 21:59:35] root INFO: name : DBPostProcess\n",
"[2021/12/22 21:59:35] root INFO: thresh : 0.3\n",
"[2021/12/22 21:59:35] root INFO: unclip_ratio : 4.0\n",
"[2021/12/22 21:59:35] root INFO: Train : \n",
"[2021/12/22 21:59:35] root INFO: dataset : \n",
"[2021/12/22 21:59:35] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 21:59:35] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']\n",
"[2021/12/22 21:59:35] root INFO: name : SimpleDataSet\n",
"[2021/12/22 21:59:35] root INFO: ratio_list : [1.0]\n",
"[2021/12/22 21:59:35] root INFO: transforms : \n",
"[2021/12/22 21:59:35] root INFO: DecodeImage : \n",
"[2021/12/22 21:59:35] root INFO: channel_first : False\n",
"[2021/12/22 21:59:35] root INFO: img_mode : BGR\n",
"[2021/12/22 21:59:35] root INFO: DetLabelEncode : None\n",
"[2021/12/22 21:59:35] root INFO: IaaAugment : \n",
"[2021/12/22 21:59:35] root INFO: augmenter_args : \n",
"[2021/12/22 21:59:35] root INFO: args : \n",
"[2021/12/22 21:59:35] root INFO: p : 0.5\n",
"[2021/12/22 21:59:35] root INFO: type : Fliplr\n",
"[2021/12/22 21:59:35] root INFO: args : \n",
"[2021/12/22 21:59:35] root INFO: rotate : [-10, 10]\n",
"[2021/12/22 21:59:35] root INFO: type : Affine\n",
"[2021/12/22 21:59:35] root INFO: args : \n",
"[2021/12/22 21:59:35] root INFO: size : [0.5, 3]\n",
"[2021/12/22 21:59:35] root INFO: type : Resize\n",
"[2021/12/22 21:59:35] root INFO: EastRandomCropData : \n",
"[2021/12/22 21:59:35] root INFO: keep_ratio : True\n",
"[2021/12/22 21:59:35] root INFO: max_tries : 50\n",
"[2021/12/22 21:59:35] root INFO: size : [640, 640]\n",
"[2021/12/22 21:59:35] root INFO: MakeBorderMap : \n",
"[2021/12/22 21:59:35] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 21:59:35] root INFO: thresh_max : 0.7\n",
"[2021/12/22 21:59:35] root INFO: thresh_min : 0.3\n",
"[2021/12/22 21:59:35] root INFO: MakeShrinkMap : \n",
"[2021/12/22 21:59:35] root INFO: min_text_size : 8\n",
"[2021/12/22 21:59:35] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 21:59:35] root INFO: NormalizeImage : \n",
"[2021/12/22 21:59:35] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 21:59:35] root INFO: order : hwc\n",
"[2021/12/22 21:59:35] root INFO: scale : 1./255.\n",
"[2021/12/22 21:59:35] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 21:59:35] root INFO: ToCHWImage : None\n",
"[2021/12/22 21:59:35] root INFO: KeepKeys : \n",
"[2021/12/22 21:59:35] root INFO: keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']\n",
"[2021/12/22 21:59:35] root INFO: loader : \n",
"[2021/12/22 21:59:35] root INFO: batch_size_per_card : 16\n",
"[2021/12/22 21:59:35] root INFO: drop_last : False\n",
"[2021/12/22 21:59:35] root INFO: num_workers : 8\n",
"[2021/12/22 21:59:35] root INFO: shuffle : True\n",
"[2021/12/22 21:59:35] root INFO: use_shared_memory : False\n",
"[2021/12/22 21:59:35] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)\n",
"W1222 21:59:35.610255 18271 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1\n",
"W1222 21:59:35.614423 18271 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
"[2021/12/22 21:59:38] root INFO: resume from ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy\n",
"[2021/12/22 21:59:38] root INFO: infer_img: ./doc/imgs_en/img_12.jpg\n",
"[2021/12/22 21:59:39] root INFO: The detected Image saved in ./output/det_db/det_results/img_12.jpg\n",
"[2021/12/22 21:59:39] root INFO: success!\n"
]
}
],
"source": [
"\n",
"# 3. 增大DB后处理的参数unlip_ratio为4.0默认为1.5,改变输出的文本框大小,参数执行文本检测预测得到结果\n",
"!python tools/infer_det.py -c configs/det/det_mv3_db.yml \\\n",
" -o Global.checkpoints=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy \\\n",
" Global.infer_img=./doc/imgs_en/img_12.jpg \\\n",
" PostProcess.unclip_ratio=4.0\n",
"# 注有关PostProcess参数和Global参数介绍与使用参考 https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/doc/doc_ch/config.md"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff4344b8bd0>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmAAAAF/CAYAAADuA3UDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzsvc2ubcuSHvRFZOYYa1e5QcfulGmAgBdA4g2Q6NGFFzBI8AC0aNDhHfwEiCYNSzyD24BAFhLCYCTTtO/Zc4zMCBrxkzHm3mXu6ZzaKs+s2jp3rTXn+MmMjIz44osIUlV8xmd8xmd8xmd8xmd8xh83+G/6AT7jMz7jMz7jMz7jM/51Gx8D7DM+4zM+4zM+4zM+4w8eHwPsMz7jMz7jMz7jMz7jDx4fA+wzPuMzPuMzPuMzPuMPHh8D7DM+4zM+4zM+4zM+4w8eHwPsMz7jMz7jMz7jMz7jDx5/uAFGRP8REf2vRPRPiOi/+qPv/xmf8Rmf8Rmf8Rmf8Tc96I+sA0ZEDcD/BuA/BPBPAfxjAP+pqv7Pf9hDfMZnfMZnfMZnfMZn/A2PPxoB+w8A/BNV/d9V9QLw3wH4j//gZ/iMz/iMz/iMz/iMz/gbHX+0AfZXAP7P8vM/9d99xmd8xmd8xmd8xmf8azP63/QDvA8i+gcA/gEAHGP8+3/v7/49KHaYVAHQ/rT/BlAFiPZnUD5HoH0NorwI+feAH8Ow5J/TuKICxPE7tWuXG9rvKD9vf6IfrvvDe2h9rrxyXvqnLxS/jnfWH/8W71z/KypgZqjuZ/Xb+sXiJva+Ss/Lkv+k5YbxfftvmcecG81r0vsCxaMTfvKOBFXZ67WfLO9FROXe2NeH/vCzXU/r5fx99Hlvf9y6LM9Z0DJV/6oFqff3uXmsU33pny8xEUFlC/aWjh+/b3Pxw5s9PwOClnfc8l/kAHu9ysWfz16/X5/3TQ61rP1PN2eVQzyH1lfM+5Yfys3rmubn8pd7LTSvgecE+LPl7bTcv1wz5vh5rTLXfr0qe7H3nm9Ge7/UefnpmuzPx2fEP/DjHt+XYr9YTMF+jjdJKzd8imfZBI+/PXXu46Z/nWA87lP1xtvP5XIiri9yiyrY9/vPFF7s7f2nsvfyJ63LnX9TBfh9r5ZX1dBdujVfzuLbg5tuLVPy103D8xHfRRGhr35418cFyueK7o51yqfNz5f18R/zu1VmVUFM+S5P/V6l4GfP9qMertd8yIn/XaGQJbheL/Q+cH47H3vGtffbOwuIqJy58Un7+0P//XXnTTlQVBXMpm/je/W8e+w1f5f436ETYvzf/88/+39V9e/izxh/tAH2fwH4N8vPf99/l0NV/yGAfwgAf/+v/kr/i//sPwcRgZkx5wQ4NiyhtYa1ln3RN6wSwGzAni5B7x0iAsB+f6+Zk9WIwcwQEbAKhBhrLbTRbUFA+VkReQooKVpruO+F1lo+v8gEyO5H2L+/ronzPKGqEF9iBoEZmHOmYdRag6q971oLzPaMWAJVxVR7JyLCnLMYIvbdeFciAjXO+0xZNmfE+Z2Yu5gvInp8H2zvr2v/Lp41PjfGwDVvtNbQWsOcNr85b9Pmh5kxZYHRAAgEmmvTBPksynbf1pod6ABUCYDk9QHgdV/49u0bZC4QtXynwfaZCc057b1DxeZtzok5J0QEvTPGGPh+vTBax33f9p6NAd3gMBGht4YGwoRCls2lxLVdUUm5BzWTk0Ymt/F77g0QhYjk3MVcka8fA5Blcys6fa7j2jYP13XZOsyF89sXILoPDiYzIlw5xN45jiNlT1V9bQZEFUyuSBpjiqC5fIwxHteJ+6jrGzGhR+eW+++6LruXyEM5ERFk2jUFClJgXTf6eeQzzTkxxkh5aK1hqcufIvfEvSZ6s3u0Xo1P5LPG70JW9/7SlOG6v2P/kNoc1r1FaLmvqz5gZpBi72kyeRCSvL9q6Kob3dcidMYUwS
CGuLxOn7Pmeyj0UA4mXL7HmDsUNxoPQAQMgog9G/c956r6eN54rtjPoQNivUUEYMa6b9OHrUGJQEKpK0LWl9wp3yEnvXfcl+2R4zhwXZfpLQYa77UF2X3i+WJuTPcwmj9f7x2v+8bgkc+Zcq2CTmz6bF62V3uHQKFCKYPKhAb7mRuwpsnCVFt3nespO1gQhA72OS33FhEcxwFmO0JfrxfGGDaXtiuwXH6qfo9/IYchY83PimuZbBAR5Db9zqNjrQVqQIftIVWFMIFdlygh12Ap+X/Fz6jbzjpZECVQt3t3uFyK2Jqo4rUmjma6ahx2vjJ1fL8vjEYYY+D1eqER23fJ9lHIV/z97DYX933nuwsJhj9PzGVrDb/99htev1345//sn+P49g3/zr/3b+M8T7CfR3MtrGk6JWS3tQYswVTsM5z5cTYBew/K7Xu9MUQm2uj2XzpSHvPcWjPfZ6qAe0PXBtWV97H/diwsdOpbpgH81//tf/N/4M8cf3QI8h8D+HeJ6N8iogPAfwLgf/hXfSEO9lC88TsAj4WPIS5MoSCqkIfgVwMh/i7EUF1ojcyI8sVba+Fe2/AD4AJFbhggr2P3PNLwsnswiBrO88zFI8VDqbY2MMaJ1kZeM953rYX7viG0BSQ2dSg4AGmUtdbAvbmA2VwoIY2gKiipbAhYKphin6XGWLrnccrKue29g5nRe8/fxbVjfuPd67zHAUNs104DUgiXLlBv6L2j947mm5pEMedG7VL5QfOQrgqNmfGaN5SfB81932lcLn+Hr69jG3RsB+DX1xeoccocM7uXtQ+rtRY05iIU5bohMrH8v70zGnEakLFmcT3uLee49573yvsRmVGBVb7X/B9hqQK+Bl9f2/iK70NsjqixyS0T2uhYKkU5bcVVje44/Os+q4ZMrOl2YsgP6L0HwtGo8grYgRRGTjWYcl7972EUxD6on49net/H73NcDb+Q15DZGFVvhLzEvnxcFy11T8xH3D/kMQ05NQeMgfzXiaFzpUERexYAul83dRvRntNwlFRS/4gbx3CjXGUbUaE7FGYYp24LvVN0Xqx77GdmTj1rDgogYHA/oNTM2GM7wGJe5pzQJblf4vdXMTCvNc348fkK/RN6O/Z1yDX3bvfz5wwdPvx6oRfrWudc+l6qcvrYV72BR7e9cx5QJnRiM8wI4NHRjpG69rGnYFPe+4ExTpznNxC1dNrGGHm/NRWy8HzfdKTKWvuctNYgJEDjNOLe1y0dA5h+psYpJ3Xt7vsGQ6DrRiMAohjN1pgaPwxoZsZ937ZGvg6hE1preL1eOZdf44As4HqZ40PcMZeC/ZnSoXhddtYRcl3j73H+hH6In2XZ89hcOtIpko5yU6RjkN+RianTJV3yvda6c+7SaQRAndCO9ng3oobr+g5moNlk+T8bJgcNpIxrTTNcecuKkD331Kf+/D3jDzXAVHUC+C8B/I8A/hcA/72q/k//6m/J42APjzy8+t471D3vQL9CaAMheveO0ysqxpnAFOmUlYdnIALxPSU8Dqv7vtMqz823Fno/0PiEYD8XNQa3AdWVG8qUhT3X9+/fc8PGgRGbJgwoNPv9dV25cV+vVyopAABvFEtVcfQBXVuYMwTpB1YYabGJYyNsgV52cBWj5n1TVYP4sencMw3lUJUKACgRrjUB7hAlTIV7gmzeq3t1V/GYApVLxQSFkqA1wlp3GtDEiiU3Wif73+s2A7sTlATXde13UAajQafi5MN2ni7IutEY7gFPKPuBOIuX6qgBEZnX5nIZB0xFVgLhifvGHIQ8kSpIN0KVsunXsN2Ah2LJwYRr3o89UL+fxpAWRcEEduUfBlTnBpVtOFXHh9kMOiWkE8FwgzCey2Jk+fPPDPGys1O2lJB7kNo2oiqiPWWZEYmtvKuhoqrQJWlEvct5/K7KazXiYh1s7vZ7v3//3bCsvwfwQM5jD8U9v76+HgexqgK00ZCY03rYJHrieiD2JMP2C6vtlyWSqKztC3
oYYrGe785nNTaY+aHT6kGZ6xbGE5a9rbIZHGhQMgSfiHDL3q8hu2gblRaD6zBaQ0fDoA4WQi9GcMxV7IGc18a2TVX
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAFkCAYAAAAT9C6pAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xl8VPW9//HXZ2aSsIkEVJSEVVDcI/tirYoVDAh67VVaa7V6RevSetveVttfb5fb9dra1l2rrWhtudZWRYxSxVqrgEiAIopI2CQBd1ZZksx8f3/MmTBZ50wyh5kk7+fjkcecOfM953zPmZOcd77fM/M15xwiIiIiklmhbFdAREREpCNSyBIREREJgEKWiIiISAAUskREREQCoJAlIiIiEgCFLBEREZEABBKyzGyKma0xswozuymIbYiIiIjkMsv092SZWRh4G/gMUAm8BnzOOfdmRjckIiIiksOCaMkaA1Q459Y756qBOcCMALYjIiIikrOCCFlFwOak55XePBEREZFOI5KtDZvZLGAWQJjwyG70zFZVRERERHzbxbYPnXOHpyoXRMiqAvonPS/25tXjnLsPuA+gp/V2Y21SAFURERERyazn3WOb/JQLorvwNWCYmQ02s3xgJjA3gO2IiIiI5KyMt2Q552rN7HpgPhAGfueceyPT2xERERHJZYHck+WcKwPKgli3iIiISHugb3wXERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCoJAlIiIiEgCFLBEREZEAKGSJiIiIBEAhS0RERCQAClkiIiIiAVDIEhEREQmAQpaIiIhIABSyRERERAKgkCUiIiISAIUsERERkQAoZImIiIgEQCFLREREJAAKWSIiIiIBUMgSERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCkDJkmdnvzOx9M1uVNK+3mT1nZmu9x0JvvpnZbWZWYWYrzWxEkJUXERERyVV+WrIeBKY0mHcTsMA5NwxY4D0HOBcY5v3MAu7OTDVFRERE2peUIcs59xLwcYPZM4DZ3vRs4Pyk+Q+5uMVALzM7KlOVFREREWkvWntPVl/n3FZv+l2grzddBGxOKlfpzRMRERHpVNp847tzzgEu3eXMbJaZLTWzpTXsb2s1RERERHJKa0PWe4luQO/xfW9+FdA/qVyxN68R59x9zrlRzrlReRS0shoiIiIiuam1IWsucJk3fRnwZNL8L3qfMhwH7EjqVhQRERHpNCKpCpjZn4AzgMPMrBL4HvAz4FEzuxLYBFzkFS8DSoEKYA/wpQDqLCIiIpLzUoYs59znmnlpUhNlHXBdWyslIiIi0t7pG99FREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFI+RUOIh3BPZteZkCkW93z0qIR7DtvDF2eWpLFWomISEemkCUdXuSoIxmc16PevPlbVgAr4F6IuhgAP/zwJF4d1R0AV1N9sKspIiIdjEKWdHg/XvgktDA+ZtjiveY/OPyN+PgFTYi6GGEL1T0u2Bvm9C7VhDBG/OJ6ln3jDkqLRnDzupX87+QZRCs2EBk0gNqN79Rbz/YvjqfXQ4uIFPWjtmpLyxU3A+dv7PXQycOZNucVttV257/6vA5AhHC9Oic/bovtpTDUlQsqSnl8aBlTp1/K03Mf5rwTJ/HUqgVNHp/k45D8HGBnbB8zB0zkU//ayz9P7tL0/nn7YwUF7JpxKoNvfIufFc/jqmPO5tRFeyg/NcS6X45jzcy76rab2N7gJ2dRMf0epk39AvOe/kOzxyFRt8Ryu2P76Gr5xHCEsGaPR8PtJT9etu
ksZg98gakDx/D0piV162rq+CRral2JsvtdDRHCjZbZ66q56DOXMvPxv7Nk19F8/8gXuKT/RD586hgOO+/t+PaOP4bom2/z9gOjOOaq5Tz+ziIKLMJ33y/hB0csZ1rRSCpvnkDxTxcS7nUoD75exuVnfoHvP/co3x08mg0/Gc/gby+q/9aMOhG3dFW9cy48bAjRtesb1TEyeCAbf9GD3494kGPzaimZ91UqzruH6aNKqd36LpaXj6upxkaeAP9aw8/WvsIPN0/jO/2fpiQ/fsmJ4ZhWNLLZ91GkozDn8494kHpabzfWGn2BvEjbmVFWWd7shVBEsmN3bB8XFo/LdjVEWuV591i5c25UqnK68kiHNq9yqQKWSA7qEeqS7SqIBE5XH+nQvv1eyn80REREAqF7sqRDe31MGN5JXa6jS9x35Md33y+h/NQQn1w4lu5/eZWyqmX1Xn+rZj+DI2EuGDCei96o4oHvn88/b72rxXUm3ysFsNvtp5vl17u3Cai7/2ld7V4GRvKJECaGY8SSS1k6ZjbTi0Y3XnniXi/vXqCESP9iajdX+t7vdIV79iS6axc4x4ezxtP3xff5yfw/UrbrZP5xctdG5W/f9Ao3DJwIwDfXvc7X7riaZd+4o/F6O0nLa42LZrsKIoHTPVnSaWz6wQQGfm8hP9mwhG8PHsPcqteavmgH4K5NL3NYOMzMoWcS27cvvYVDYYjVvyCFhw4mWrEhgzUUadm8qnLG/s/1vPbdOwEYWnY1FaX3AjS6sX9Dze56X5mSKAMHgnZp0YiDVXWRjPN7T5ZCloiIBCsUBhcjfNwwom++ne3aiLSZ35Cl7kIREQmW1xKrgCWdTefo/BcRERE5yBSyRERERAKgkCUiIiISAIUsERERkQAoZImIiIgEQCFLREREJAAKWSIiIiIBUMgSERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCkDJkmVl/M/u7mb1pZm+Y2Ve9+b3N7DkzW+s9FnrzzcxuM7MKM1tpZiOC3gkRERGRXOOnJasW+Lpz7nhgHHCdmR0P3AQscM4NAxZ4zwHOBYZ5P7OAuzNeaxEREZEclzJkOee2OueWedO7gNVAETADmO0Vmw2c703PAB5ycYuBXmZ2VMZrLiIiIpLD0rony8wGAacCrwJ9nXNbvZfeBfp600XA5qTFKr15IiIiIp2G75BlZj2AvwA3Oud2Jr/mnHOAS2fDZjbLzJaa2dIa9qezqIiIiEjO8xWyzCyPeMB6xDn3V2/2e4luQO/xfW9+FdA/afFib149zrn7nHOjnHOj8ihobf1FREREcpKfTxca8ACw2jl3a9JLc4HLvOnLgCeT5n/R+5ThOGBHUreiiIiISKcQ8VFmInAp8LqZrfDmfRv4GfComV0JbAIu8l4rA0qBCmAP8KWM1lhERESkHUgZspxzLwPWzMuTmijvgOvaWC8RERGRdk3f+C4iIiISAIUsERERkQAoZImIiIgEQCFLREREJAAKWSIiIiIBUMgSERERCYBCloiIiEgAFLJEREREAqCQJSIiIhIAhSwRERGRAChkiYiIiARAIUtEREQkAApZIiIiIgFQyBIREREJgEKWiIiISAAUskREREQCoJAlIiIiEoBItisg0tncs+llrp0xi9iKN30vU1a1rN7z/a6WC4rHZLpqmWGGhcOEenQnun1H69cz7mTCb27k8dULiBAmbCHej37CxD9+gyHfWpTWqsqqlhG25v+nnNyvpMXlv17xBpO67m+0jqiLcew/ruC/R8zjkeHFadWpKet+OY7VM+8khFFaNKLN6xOR7DLnXLbrQE/r7cbapGxXQ3KY5eXjaqrjF/BRJxKqqOQny57lpuGns+5/RjD4iT1EVr/DoL/tZd3ofVmqpFFWWV73tKWLOsQv0H4upKFu3XimYmHKclEXq9tuqt
DQlO+sX8GPh8SXSz7e+PwbMX/LCt/bSq7rh9FPKAx1ZWdsHz1DXermtySd/fNTr1TrSxXSGoq6GOedOInotm2+lzl
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from PIL import Image\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"img = Image.open('./output/det_db/det_results/img_12.jpg')\n",
"img = np.array(img)\n",
"\n",
"# Display the loaded image\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(img)\n",
"\n",
"img = Image.open('./vis_segmentation.png')\n",
"img = np.array(img)\n",
"\n",
"# Display the loaded image\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(img)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"As the output above shows, after increasing the unclip_ratio parameter of the DB post-processing, the predicted text boxes become noticeably larger. So when the training results do not meet expectations, you can adjust the detection results by tuning the post-processing parameters. You can also try adjusting the other three parameters, thresh, box_thresh and max_candidates, and compare the detection results."
]
},
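  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The effect of unclip_ratio can also be reasoned about numerically: DB expands each shrunk region by an offset distance $d = \\frac{A \\times r}{L}$, where $A$ is the polygon area, $L$ its perimeter, and $r$ the unclip_ratio. The sketch below is a simplified illustration only (the actual DBPostProcess performs the expansion with pyclipper); it computes this distance for a hypothetical 100x20 box:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def unclip_distance(box, unclip_ratio):\n",
    "    # offset distance d = area * ratio / perimeter, as in the DB paper\n",
    "    box = np.asarray(box, dtype=np.float64)\n",
    "    x, y = box[:, 0], box[:, 1]\n",
    "    # shoelace formula for the polygon area\n",
    "    area = 0.5 * abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))\n",
    "    perimeter = np.sum(np.linalg.norm(box - np.roll(box, 1, axis=0), axis=1))\n",
    "    return area * unclip_ratio / perimeter\n",
    "\n",
    "box = [[0, 0], [100, 0], [100, 20], [0, 20]]  # a hypothetical 100x20 text region\n",
    "print(unclip_distance(box, 1.5))  # default ratio\n",
    "print(unclip_distance(box, 4.0))  # larger ratio expands the box further\n",
    "```\n",
    "\n",
    "With unclip_ratio=1.5 the offset distance is 12.5 pixels; with 4.0 it grows to about 33.3 pixels, consistent with the visibly larger boxes above."
   ]
  },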
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.5 Defining the Loss Function\n",
"\n",
"Since the training stage produces three prediction maps, the loss function combines each of the three maps with its corresponding ground-truth label to form three loss terms. The total loss is defined as:\n",
"\n",
"$L = L_b + \\alpha \\times L_s + \\beta \\times L_t$\n",
"\n",
"where $L$ is the total loss; $L_s$ is the probability-map loss, for which this experiment uses a Dice loss with OHEM (online hard example mining); $L_t$ is the threshold-map loss, for which this experiment uses the $L_1$ distance between prediction and label; and $L_b$ is the loss of the binary text map. The weighting coefficients $\\alpha$ and $\\beta$ are set to 5 and 10 respectively in this experiment.\n",
"\n",
"The three losses $L_b$, $L_s$ and $L_t$ are a Dice loss, a Dice loss with OHEM, and a MaskL1 loss, defined as follows:\n",
"\n",
"- Dice Loss measures the similarity between the predicted binary text map and the label and is commonly used for binary image segmentation; see the implementation at [link](https://github.com/PaddlePaddle/PaddleOCR/blob/81ee76ad7f9ff534a0ae5439d2a5259c4263993c/ppocr/losses/det_basic_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L109). The formula is:\n",
"\n",
"$dice\\_loss = 1 - \\frac{2 \\times intersection\\_area}{total\\_area}$\n",
"\n",
"- Dice Loss (OHEM) is a Dice loss with OHEM, which mitigates the imbalance between positive and negative samples. OHEM is a sampling strategy that automatically selects hard examples for the loss computation, improving training; the positive-to-negative sampling ratio is set to 1:3 here. See the implementation at [link](https://github.com/PaddlePaddle/PaddleOCR/blob/81ee76ad7f9ff534a0ae5439d2a5259c4263993c/ppocr/losses/det_basic_loss.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L95).\n",
"\n",
"- MaskL1 Loss computes the $L_1$ distance between the predicted threshold map and its label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import paddle\n",
"from paddle import nn\n",
"import paddle.nn.functional as F\n",
"\n",
"\n",
"# DB loss function\n",
"# Full implementation: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/losses/det_db_loss.py\n",
"class DBLoss(nn.Layer):\n",
" \"\"\"\n",
" Differentiable Binarization (DB) Loss Function\n",
" args:\n",
"        param (dict): the hyper-parameters for the DB loss\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
" balance_loss=True,\n",
" main_loss_type='DiceLoss',\n",
" alpha=5,\n",
" beta=10,\n",
" ohem_ratio=3,\n",
" eps=1e-6,\n",
" **kwargs):\n",
" super(DBLoss, self).__init__()\n",
" self.alpha = alpha\n",
" self.beta = beta\n",
"        # Instantiate the individual loss functions\n",
" self.dice_loss = DiceLoss(eps=eps)\n",
" self.l1_loss = MaskL1Loss(eps=eps)\n",
" self.bce_loss = BalanceLoss(\n",
" balance_loss=balance_loss,\n",
" main_loss_type=main_loss_type,\n",
" negative_ratio=ohem_ratio)\n",
"\n",
" def forward(self, predicts, labels):\n",
" predict_maps = predicts['maps']\n",
" label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[\n",
" 1:]\n",
" shrink_maps = predict_maps[:, 0, :, :]\n",
" threshold_maps = predict_maps[:, 1, :, :]\n",
" binary_maps = predict_maps[:, 2, :, :]\n",
"        # 1. For the predicted probability (shrink) map, use the balanced loss (configured as Dice loss with OHEM here)\n",
" loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,\n",
" label_shrink_mask)\n",
"        # 2. For the predicted threshold map, use the L1 distance loss\n",
" loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,\n",
" label_threshold_mask)\n",
"        # 3. For the predicted binary map, use the Dice loss\n",
" loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,\n",
" label_shrink_mask)\n",
"\n",
"        # 4. Multiply each loss by its weight\n",
" loss_shrink_maps = self.alpha * loss_shrink_maps\n",
" loss_threshold_maps = self.beta * loss_threshold_maps\n",
"\n",
" loss_all = loss_shrink_maps + loss_threshold_maps \\\n",
" + loss_binary_maps\n",
" losses = {'loss': loss_all, \\\n",
" \"loss_shrink_maps\": loss_shrink_maps, \\\n",
" \"loss_threshold_maps\": loss_threshold_maps, \\\n",
" \"loss_binary_maps\": loss_binary_maps}\n",
" return losses"
]
},
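  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "The component losses used above can be sketched in plain NumPy (a simplified illustration of the formulas only; the paddle implementations linked above additionally handle OHEM sampling and batching):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def dice_loss(pred, gt, mask, eps=1e-6):\n",
    "    # dice_loss = 1 - 2 * intersection_area / total_area, restricted to the mask\n",
    "    intersection = np.sum(pred * gt * mask)\n",
    "    total = np.sum(pred * mask) + np.sum(gt * mask) + eps\n",
    "    return 1.0 - 2.0 * intersection / total\n",
    "\n",
    "def mask_l1_loss(pred, gt, mask, eps=1e-6):\n",
    "    # mean L1 distance between prediction and label inside the mask\n",
    "    return np.sum(np.abs(pred - gt) * mask) / (np.sum(mask) + eps)\n",
    "\n",
    "pred = np.array([[0.9, 0.1], [0.8, 0.2]])\n",
    "gt = np.array([[1.0, 0.0], [1.0, 0.0]])\n",
    "mask = np.ones_like(gt)\n",
    "print(dice_loss(pred, gt, mask))\n",
    "print(mask_l1_loss(pred, gt, mask))\n",
    "```\n",
    "\n",
    "The total loss then combines these terms with the weights alpha=5 and beta=10, as in the forward method above."
   ]
  },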
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.6 Evaluation Metrics\n",
"\n",
"Since the boxes produced by the DB post-processing come in many shapes and are not necessarily horizontal, this experiment evaluates them by simply computing IOU, following the [ICDAR Challenge 4 text detection evaluation protocol](https://rrc.cvc.uab.es/?ch=4&com=mymethods&task=1).\n",
"\n",
"Text detection uses three metrics: Precision, Recall and Hmean. They are computed as follows:\n",
"1. Build a matrix of size [n, m] called iouMat, where n is the number of GT (ground truth) boxes and m is the number of detected boxes; n and m exclude boxes whose text is annotated as ###;\n",
"2. Count the entries of iouMat with IOU above the threshold 0.5 and divide by the number of GT boxes n to obtain Recall;\n",
"3. Count the entries of iouMat with IOU above the threshold 0.5 and divide by the number of detected boxes m to obtain Precision;\n",
"4. Hmean is computed the same way as the F1-score:\n",
"\n",
"$$\n",
"Hmean = 2.0 \\times \\frac{Precision \\times Recall}{Precision + Recall}\n",
"$$\n",
"\n",
"The core code for computing the text detection metrics is shown below; for the full implementation see [link](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/metrics/det_metric.py):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Text detection metrics are computed as follows\n",
"# Full code: https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/ppocr/metrics/det_metric.py\n",
"if len(gtPols) > 0 and len(detPols) > 0:\n",
" outputShape = [len(gtPols), len(detPols)]\n",
"\n",
"    # 1. Create an [n, m] matrix to store the computed IOU values\n",
" iouMat = np.empty(outputShape)\n",
" gtRectMat = np.zeros(len(gtPols), np.int8)\n",
" detRectMat = np.zeros(len(detPols), np.int8)\n",
" for gtNum in range(len(gtPols)):\n",
" for detNum in range(len(detPols)):\n",
" pG = gtPols[gtNum]\n",
" pD = detPols[detNum]\n",
"\n",
"            # 2. Compute the IOU between each predicted box and GT box\n",
" iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)\n",
" for gtNum in range(len(gtPols)):\n",
" for detNum in range(len(detPols)):\n",
" if gtRectMat[gtNum] == 0 and detRectMat[\n",
" detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:\n",
"\n",
"                # 2.1 Count matches whose IOU exceeds the threshold 0.5\n",
" if iouMat[gtNum, detNum] > self.iou_constraint:\n",
" gtRectMat[gtNum] = 1\n",
" detRectMat[detNum] = 1\n",
" detMatched += 1\n",
" pairs.append({'gt': gtNum, 'det': detNum})\n",
" detMatchedNums.append(detNum)\n",
" \n",
"    # 3. Divide the number of matches (IOU > 0.5) by the number of GT boxes (numGtCare) to get recall\n",
" recall = float(detMatched) / numGtCare\n",
"\n",
"    # 4. Divide the number of matches (IOU > 0.5) by the number of predicted boxes (numDetCare) to get precision\n",
" precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare\n",
"\n",
"    # 5. Compute the Hmean metric from the formula\n",
" hmean = 0 if (precision + recall) == 0 else 2.0 * \\\n",
" precision * recall / (precision + recall)\n",
" "
]
},
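  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "A small worked example of the metric computation (the IOU values here are hypothetical; the real code builds iouMat from polygon intersections as above):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# 3 GT boxes and 4 detections; each entry is the IOU of one (gt, det) pair\n",
    "iouMat = np.array([\n",
    "    [0.8, 0.1, 0.0, 0.0],\n",
    "    [0.0, 0.6, 0.0, 0.0],\n",
    "    [0.0, 0.0, 0.3, 0.2],  # this GT box has no match above 0.5: a miss\n",
    "])\n",
    "# each row has at most one IOU above 0.5, so one-to-one matching reduces to a row max\n",
    "detMatched = int(np.sum(iouMat.max(axis=1) > 0.5))\n",
    "recall = detMatched / iouMat.shape[0]\n",
    "precision = detMatched / iouMat.shape[1]\n",
    "hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)\n",
    "print(recall, precision, hmean)\n",
    "```\n",
    "\n",
    "Here recall is 2/3, precision is 2/4 = 0.5, and Hmean is their harmonic mean, about 0.571."
   ]
  },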
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"**Questions to consider:**\n",
"1. For the case shown in the figure below, where the IOU between the GT box and the predicted box is greater than 0.5 but part of the text is still missed, do the metrics above accurately reflect the model's accuracy?\n",
"2. If you run into this kind of problem in practice, how would you optimize the model?\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/e23f47c7c39f4b92bb494444d3724758401cd9810a8d469690093857c7f05d9e\" width = \"600\"></center>\n",
"<center><br>Figure: annotation example of a GT box and a predicted box </br></center>\n",
"<br></br>\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.7 Model Training\n",
"\n",
"With data processing, network definition and loss definition in place, we can start training the model.\n",
"\n",
"Training is based on PaddleOCR and is driven by a parameter configuration file; see the config file at [link](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.4/configs/det/det_mv3_db.yml). The network architecture parameters are:\n",
"```\n",
"Architecture:\n",
" model_type: det\n",
" algorithm: DB\n",
" Transform:\n",
" Backbone:\n",
" name: MobileNetV3\n",
" scale: 0.5\n",
" model_name: large\n",
" Neck:\n",
" name: DBFPN\n",
" out_channels: 256\n",
" Head:\n",
" name: DBHead\n",
" k: 50\n",
"```\n",
"\n",
"The optimizer parameters are:\n",
"```\n",
"Optimizer:\n",
" name: Adam\n",
" beta1: 0.9\n",
" beta2: 0.999\n",
" lr:\n",
" learning_rate: 0.001\n",
" regularizer:\n",
" name: 'L2'\n",
" factor: 0\n",
"```\n",
"\n",
"The post-processing parameters are:\n",
"```\n",
"PostProcess:\n",
" name: DBPostProcess\n",
" thresh: 0.3\n",
" box_thresh: 0.6\n",
" max_candidates: 1000\n",
" unclip_ratio: 1.5\n",
"```\n",
"\n",
"...\n",
"\n",
"See [det_mv3_db.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.3/configs/det/det_mv3_db.yml) for the complete configuration file.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mkdir: cannot create directory train_data: File exists\n",
"--2021-12-22 22:04:50-- https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams\n",
"Resolving paddle-imagenet-models-name.bj.bcebos.com (paddle-imagenet-models-name.bj.bcebos.com)... 100.67.200.6\n",
"Connecting to paddle-imagenet-models-name.bj.bcebos.com (paddle-imagenet-models-name.bj.bcebos.com)|100.67.200.6|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 16255295 (16M) [application/octet-stream]\n",
"Saving to: ./pretrain_models/MobileNetV3_large_x0_5_pretrained.pdparams.3\n",
"\n",
"MobileNetV3_large_x 100%[===================>] 15.50M 70.9MB/s in 0.2s \n",
"\n",
"2021-12-22 22:04:50 (70.9 MB/s) - ./pretrain_models/MobileNetV3_large_x0_5_pretrained.pdparams.3 saved [16255295/16255295]\n",
"\n"
]
}
],
"source": [
"!mkdir train_data \n",
"!cd train_data && ln -s /home/aistudio/data/data96799/icdar2015 icdar2015\n",
"!wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n",
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n",
"[2021/12/22 22:05:35] root INFO: Architecture : \n",
"[2021/12/22 22:05:35] root INFO: Backbone : \n",
"[2021/12/22 22:05:35] root INFO: model_name : large\n",
"[2021/12/22 22:05:35] root INFO: name : MobileNetV3\n",
"[2021/12/22 22:05:35] root INFO: scale : 0.5\n",
"[2021/12/22 22:05:35] root INFO: Head : \n",
"[2021/12/22 22:05:35] root INFO: k : 50\n",
"[2021/12/22 22:05:35] root INFO: name : DBHead\n",
"[2021/12/22 22:05:35] root INFO: Neck : \n",
"[2021/12/22 22:05:35] root INFO: name : DBFPN\n",
"[2021/12/22 22:05:35] root INFO: out_channels : 256\n",
"[2021/12/22 22:05:35] root INFO: Transform : None\n",
"[2021/12/22 22:05:35] root INFO: algorithm : DB\n",
"[2021/12/22 22:05:35] root INFO: model_type : det\n",
"[2021/12/22 22:05:35] root INFO: Eval : \n",
"[2021/12/22 22:05:35] root INFO: dataset : \n",
"[2021/12/22 22:05:35] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 22:05:35] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']\n",
"[2021/12/22 22:05:35] root INFO: name : SimpleDataSet\n",
"[2021/12/22 22:05:35] root INFO: transforms : \n",
"[2021/12/22 22:05:35] root INFO: DecodeImage : \n",
"[2021/12/22 22:05:35] root INFO: channel_first : False\n",
"[2021/12/22 22:05:35] root INFO: img_mode : BGR\n",
"[2021/12/22 22:05:35] root INFO: DetLabelEncode : None\n",
"[2021/12/22 22:05:35] root INFO: DetResizeForTest : \n",
"[2021/12/22 22:05:35] root INFO: image_shape : [736, 1280]\n",
"[2021/12/22 22:05:35] root INFO: NormalizeImage : \n",
"[2021/12/22 22:05:35] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 22:05:35] root INFO: order : hwc\n",
"[2021/12/22 22:05:35] root INFO: scale : 1./255.\n",
"[2021/12/22 22:05:35] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 22:05:35] root INFO: ToCHWImage : None\n",
"[2021/12/22 22:05:35] root INFO: KeepKeys : \n",
"[2021/12/22 22:05:35] root INFO: keep_keys : ['image', 'shape', 'polys', 'ignore_tags']\n",
"[2021/12/22 22:05:35] root INFO: loader : \n",
"[2021/12/22 22:05:35] root INFO: batch_size_per_card : 1\n",
"[2021/12/22 22:05:35] root INFO: drop_last : False\n",
"[2021/12/22 22:05:35] root INFO: num_workers : 8\n",
"[2021/12/22 22:05:35] root INFO: shuffle : False\n",
"[2021/12/22 22:05:35] root INFO: use_shared_memory : False\n",
"[2021/12/22 22:05:35] root INFO: Global : \n",
"[2021/12/22 22:05:35] root INFO: cal_metric_during_train : False\n",
"[2021/12/22 22:05:35] root INFO: checkpoints : None\n",
"[2021/12/22 22:05:35] root INFO: debug : False\n",
"[2021/12/22 22:05:35] root INFO: distributed : False\n",
"[2021/12/22 22:05:35] root INFO: epoch_num : 1200\n",
"[2021/12/22 22:05:35] root INFO: eval_batch_step : [0, 2000]\n",
"[2021/12/22 22:05:35] root INFO: infer_img : doc/imgs_en/img_10.jpg\n",
"[2021/12/22 22:05:35] root INFO: log_smooth_window : 20\n",
"[2021/12/22 22:05:35] root INFO: pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained\n",
"[2021/12/22 22:05:35] root INFO: print_batch_step : 10\n",
"[2021/12/22 22:05:35] root INFO: save_epoch_step : 1200\n",
"[2021/12/22 22:05:35] root INFO: save_inference_dir : None\n",
"[2021/12/22 22:05:35] root INFO: save_model_dir : ./output/db_mv3/\n",
"[2021/12/22 22:05:35] root INFO: save_res_path : ./output/det_db/predicts_db.txt\n",
"[2021/12/22 22:05:35] root INFO: use_gpu : True\n",
"[2021/12/22 22:05:35] root INFO: use_visualdl : False\n",
"[2021/12/22 22:05:35] root INFO: Loss : \n",
"[2021/12/22 22:05:35] root INFO: alpha : 5\n",
"[2021/12/22 22:05:35] root INFO: balance_loss : True\n",
"[2021/12/22 22:05:35] root INFO: beta : 10\n",
"[2021/12/22 22:05:35] root INFO: main_loss_type : DiceLoss\n",
"[2021/12/22 22:05:35] root INFO: name : DBLoss\n",
"[2021/12/22 22:05:35] root INFO: ohem_ratio : 3\n",
"[2021/12/22 22:05:35] root INFO: Metric : \n",
"[2021/12/22 22:05:35] root INFO: main_indicator : hmean\n",
"[2021/12/22 22:05:35] root INFO: name : DetMetric\n",
"[2021/12/22 22:05:35] root INFO: Optimizer : \n",
"[2021/12/22 22:05:35] root INFO: beta1 : 0.9\n",
"[2021/12/22 22:05:35] root INFO: beta2 : 0.999\n",
"[2021/12/22 22:05:35] root INFO: lr : \n",
"[2021/12/22 22:05:35] root INFO: learning_rate : 0.001\n",
"[2021/12/22 22:05:35] root INFO: name : Adam\n",
"[2021/12/22 22:05:35] root INFO: regularizer : \n",
"[2021/12/22 22:05:35] root INFO: factor : 0\n",
"[2021/12/22 22:05:35] root INFO: name : L2\n",
"[2021/12/22 22:05:35] root INFO: PostProcess : \n",
"[2021/12/22 22:05:35] root INFO: box_thresh : 0.6\n",
"[2021/12/22 22:05:35] root INFO: max_candidates : 1000\n",
"[2021/12/22 22:05:35] root INFO: name : DBPostProcess\n",
"[2021/12/22 22:05:35] root INFO: thresh : 0.3\n",
"[2021/12/22 22:05:35] root INFO: unclip_ratio : 1.5\n",
"[2021/12/22 22:05:35] root INFO: Train : \n",
"[2021/12/22 22:05:35] root INFO: dataset : \n",
"[2021/12/22 22:05:35] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 22:05:35] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']\n",
"[2021/12/22 22:05:35] root INFO: name : SimpleDataSet\n",
"[2021/12/22 22:05:35] root INFO: ratio_list : [1.0]\n",
"[2021/12/22 22:05:35] root INFO: transforms : \n",
"[2021/12/22 22:05:35] root INFO: DecodeImage : \n",
"[2021/12/22 22:05:35] root INFO: channel_first : False\n",
"[2021/12/22 22:05:35] root INFO: img_mode : BGR\n",
"[2021/12/22 22:05:35] root INFO: DetLabelEncode : None\n",
"[2021/12/22 22:05:35] root INFO: IaaAugment : \n",
"[2021/12/22 22:05:35] root INFO: augmenter_args : \n",
"[2021/12/22 22:05:35] root INFO: args : \n",
"[2021/12/22 22:05:35] root INFO: p : 0.5\n",
"[2021/12/22 22:05:35] root INFO: type : Fliplr\n",
"[2021/12/22 22:05:35] root INFO: args : \n",
"[2021/12/22 22:05:35] root INFO: rotate : [-10, 10]\n",
"[2021/12/22 22:05:35] root INFO: type : Affine\n",
"[2021/12/22 22:05:35] root INFO: args : \n",
"[2021/12/22 22:05:35] root INFO: size : [0.5, 3]\n",
"[2021/12/22 22:05:35] root INFO: type : Resize\n",
"[2021/12/22 22:05:35] root INFO: EastRandomCropData : \n",
"[2021/12/22 22:05:35] root INFO: keep_ratio : True\n",
"[2021/12/22 22:05:35] root INFO: max_tries : 50\n",
"[2021/12/22 22:05:35] root INFO: size : [640, 640]\n",
"[2021/12/22 22:05:35] root INFO: MakeBorderMap : \n",
"[2021/12/22 22:05:35] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 22:05:35] root INFO: thresh_max : 0.7\n",
"[2021/12/22 22:05:35] root INFO: thresh_min : 0.3\n",
"[2021/12/22 22:05:35] root INFO: MakeShrinkMap : \n",
"[2021/12/22 22:05:35] root INFO: min_text_size : 8\n",
"[2021/12/22 22:05:35] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 22:05:35] root INFO: NormalizeImage : \n",
"[2021/12/22 22:05:35] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 22:05:35] root INFO: order : hwc\n",
"[2021/12/22 22:05:35] root INFO: scale : 1./255.\n",
"[2021/12/22 22:05:35] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 22:05:35] root INFO: ToCHWImage : None\n",
"[2021/12/22 22:05:35] root INFO: KeepKeys : \n",
"[2021/12/22 22:05:35] root INFO: keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']\n",
"[2021/12/22 22:05:35] root INFO: loader : \n",
"[2021/12/22 22:05:35] root INFO: batch_size_per_card : 16\n",
"[2021/12/22 22:05:35] root INFO: drop_last : False\n",
"[2021/12/22 22:05:35] root INFO: num_workers : 8\n",
"[2021/12/22 22:05:35] root INFO: shuffle : True\n",
"[2021/12/22 22:05:35] root INFO: use_shared_memory : False\n",
"[2021/12/22 22:05:35] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)\n",
"[2021/12/22 22:05:35] root INFO: Initialize indexs of datasets:['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']\n",
"[2021/12/22 22:05:35] root INFO: Initialize indexs of datasets:['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']\n",
"W1222 22:05:35.374327 18562 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1\n",
"W1222 22:05:35.378737 18562 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
"[2021/12/22 22:05:38] root INFO: The shape of model params neck.in2_conv.weight [256, 16, 1, 1] not matched with loaded params last_conv.weight [1280, 480, 1, 1] !\n",
"[2021/12/22 22:05:38] root INFO: The shape of model params neck.in3_conv.weight [256, 24, 1, 1] not matched with loaded params out.weight [1280, 1000] !\n",
"[2021/12/22 22:05:38] root INFO: The shape of model params neck.in4_conv.weight [256, 56, 1, 1] not matched with loaded params out.bias [1000] !\n",
"[2021/12/22 22:05:38] root INFO: loaded pretrained_model successful from ./pretrain_models/MobileNetV3_large_x0_5_pretrained.pdparams\n",
"[2021/12/22 22:05:38] root INFO: train dataloader has 63 iters\n",
"[2021/12/22 22:05:38] root INFO: valid dataloader has 500 iters\n",
"[2021/12/22 22:05:38] root INFO: During the training process, after the 0th iteration, an evaluation is run every 2000 iterations\n",
"[2021/12/22 22:05:38] root INFO: Initialize indexs of datasets:['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']\n",
"[2021/12/22 22:05:57] root INFO: epoch: [1/1200], iter: 10, lr: 0.001000, loss: 8.111980, loss_shrink_maps: 4.886614, loss_threshold_maps: 2.247729, loss_binary_maps: 0.977636, reader_cost: 1.03414 s, batch_cost: 1.90380 s, samples: 176, ips: 9.24466\n",
"[2021/12/22 22:06:04] root INFO: epoch: [1/1200], iter: 20, lr: 0.001000, loss: 7.016824, loss_shrink_maps: 4.852873, loss_threshold_maps: 1.217604, loss_binary_maps: 0.969366, reader_cost: 0.00025 s, batch_cost: 0.68818 s, samples: 160, ips: 23.24974\n",
"[2021/12/22 22:06:12] root INFO: epoch: [1/1200], iter: 30, lr: 0.001000, loss: 6.830338, loss_shrink_maps: 4.769650, loss_threshold_maps: 1.075279, loss_binary_maps: 0.953149, reader_cost: 0.13629 s, batch_cost: 0.77564 s, samples: 160, ips: 20.62814\n",
"^C\n",
"main proc 18582 exit, kill process group 18562\n",
"main proc 18583 exit, kill process group 18562\n",
"main proc 18584 exit, kill process group 18562\n",
"main proc 18580 exit, kill process group 18562\n",
"main proc 18581 exit, kill process group 18562\n",
"main proc 18577 exit, kill process group 18562\n"
]
}
],
"source": [
"!python tools/train.py -c configs/det/det_mv3_db.yml"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"The trained model is saved in the PaddleOCR/output/db_mv3/ directory by default. To change the save directory, set the Global.save_model_dir parameter during training, for example:\n",
"```\n",
"# Setting Global.save_model_dir in the config file changes the model save directory\n",
"python tools/train.py -c configs/det/det_mv3_db.yml -o Global.save_model_dir=\"./output/save_db_train/\"\n",
"```\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.8 Model Evaluation\n",
"\n",
"During training, two kinds of checkpoints are saved by default: one named latest (the most recently trained model) and one named best_accuracy (the model with the highest accuracy). Next, use the saved model parameters to evaluate precision, recall and hmean on the test set.\n",
"\n",
"The text detection evaluation code lives in PaddleOCR/ppocr/metrics/det_metric.py; call tools/eval.py to evaluate the accuracy of a trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"!python tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./output/db_mv3/best_accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"## 3.9 Model Prediction\n",
"\n",
"After training, you can also use a saved model to run inference on a single image or a folder of images from the dataset and inspect the prediction results.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:241: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 1, 1, 0, 0, 1, 0, 0, 0], dtype=np.bool)\n",
"/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/skimage/morphology/_skeletonize.py:256: DeprecationWarning: `np.bool` is a deprecated alias for the builtin `bool`. To silence this warning, use `bool` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.bool_` here.\n",
"Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.bool)\n",
"[2021/12/22 22:07:35] root INFO: Architecture : \n",
"[2021/12/22 22:07:35] root INFO: Backbone : \n",
"[2021/12/22 22:07:35] root INFO: model_name : large\n",
"[2021/12/22 22:07:35] root INFO: name : MobileNetV3\n",
"[2021/12/22 22:07:35] root INFO: scale : 0.5\n",
"[2021/12/22 22:07:35] root INFO: Head : \n",
"[2021/12/22 22:07:35] root INFO: k : 50\n",
"[2021/12/22 22:07:35] root INFO: name : DBHead\n",
"[2021/12/22 22:07:35] root INFO: Neck : \n",
"[2021/12/22 22:07:35] root INFO: name : DBFPN\n",
"[2021/12/22 22:07:35] root INFO: out_channels : 256\n",
"[2021/12/22 22:07:35] root INFO: Transform : None\n",
"[2021/12/22 22:07:35] root INFO: algorithm : DB\n",
"[2021/12/22 22:07:35] root INFO: model_type : det\n",
"[2021/12/22 22:07:35] root INFO: Eval : \n",
"[2021/12/22 22:07:35] root INFO: dataset : \n",
"[2021/12/22 22:07:35] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 22:07:35] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/test_icdar2015_label.txt']\n",
"[2021/12/22 22:07:35] root INFO: name : SimpleDataSet\n",
"[2021/12/22 22:07:35] root INFO: transforms : \n",
"[2021/12/22 22:07:35] root INFO: DecodeImage : \n",
"[2021/12/22 22:07:35] root INFO: channel_first : False\n",
"[2021/12/22 22:07:35] root INFO: img_mode : BGR\n",
"[2021/12/22 22:07:35] root INFO: DetLabelEncode : None\n",
"[2021/12/22 22:07:35] root INFO: DetResizeForTest : \n",
"[2021/12/22 22:07:35] root INFO: image_shape : [736, 1280]\n",
"[2021/12/22 22:07:35] root INFO: NormalizeImage : \n",
"[2021/12/22 22:07:35] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 22:07:35] root INFO: order : hwc\n",
"[2021/12/22 22:07:35] root INFO: scale : 1./255.\n",
"[2021/12/22 22:07:35] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 22:07:35] root INFO: ToCHWImage : None\n",
"[2021/12/22 22:07:35] root INFO: KeepKeys : \n",
"[2021/12/22 22:07:35] root INFO: keep_keys : ['image', 'shape', 'polys', 'ignore_tags']\n",
"[2021/12/22 22:07:35] root INFO: loader : \n",
"[2021/12/22 22:07:35] root INFO: batch_size_per_card : 1\n",
"[2021/12/22 22:07:35] root INFO: drop_last : False\n",
"[2021/12/22 22:07:35] root INFO: num_workers : 8\n",
"[2021/12/22 22:07:35] root INFO: shuffle : False\n",
"[2021/12/22 22:07:35] root INFO: use_shared_memory : False\n",
"[2021/12/22 22:07:35] root INFO: Global : \n",
"[2021/12/22 22:07:35] root INFO: cal_metric_during_train : False\n",
"[2021/12/22 22:07:35] root INFO: checkpoints : ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy\n",
"[2021/12/22 22:07:35] root INFO: debug : False\n",
"[2021/12/22 22:07:35] root INFO: distributed : False\n",
"[2021/12/22 22:07:35] root INFO: epoch_num : 1200\n",
"[2021/12/22 22:07:35] root INFO: eval_batch_step : [0, 2000]\n",
"[2021/12/22 22:07:35] root INFO: infer_img : ./doc/imgs_en/img_12.jpg\n",
"[2021/12/22 22:07:35] root INFO: log_smooth_window : 20\n",
"[2021/12/22 22:07:35] root INFO: pretrained_model : ./pretrain_models/MobileNetV3_large_x0_5_pretrained\n",
"[2021/12/22 22:07:35] root INFO: print_batch_step : 10\n",
"[2021/12/22 22:07:35] root INFO: save_epoch_step : 1200\n",
"[2021/12/22 22:07:35] root INFO: save_inference_dir : None\n",
"[2021/12/22 22:07:35] root INFO: save_model_dir : ./output/db_mv3/\n",
"[2021/12/22 22:07:35] root INFO: save_res_path : ./output/det_db/predicts_db.txt\n",
"[2021/12/22 22:07:35] root INFO: use_gpu : True\n",
"[2021/12/22 22:07:35] root INFO: use_visualdl : False\n",
"[2021/12/22 22:07:35] root INFO: Loss : \n",
"[2021/12/22 22:07:35] root INFO: alpha : 5\n",
"[2021/12/22 22:07:35] root INFO: balance_loss : True\n",
"[2021/12/22 22:07:35] root INFO: beta : 10\n",
"[2021/12/22 22:07:35] root INFO: main_loss_type : DiceLoss\n",
"[2021/12/22 22:07:35] root INFO: name : DBLoss\n",
"[2021/12/22 22:07:35] root INFO: ohem_ratio : 3\n",
"[2021/12/22 22:07:35] root INFO: Metric : \n",
"[2021/12/22 22:07:35] root INFO: main_indicator : hmean\n",
"[2021/12/22 22:07:35] root INFO: name : DetMetric\n",
"[2021/12/22 22:07:35] root INFO: Optimizer : \n",
"[2021/12/22 22:07:35] root INFO: beta1 : 0.9\n",
"[2021/12/22 22:07:35] root INFO: beta2 : 0.999\n",
"[2021/12/22 22:07:35] root INFO: lr : \n",
"[2021/12/22 22:07:35] root INFO: learning_rate : 0.001\n",
"[2021/12/22 22:07:35] root INFO: name : Adam\n",
"[2021/12/22 22:07:35] root INFO: regularizer : \n",
"[2021/12/22 22:07:35] root INFO: factor : 0\n",
"[2021/12/22 22:07:35] root INFO: name : L2\n",
"[2021/12/22 22:07:35] root INFO: PostProcess : \n",
"[2021/12/22 22:07:35] root INFO: box_thresh : 0.6\n",
"[2021/12/22 22:07:35] root INFO: max_candidates : 1000\n",
"[2021/12/22 22:07:35] root INFO: name : DBPostProcess\n",
"[2021/12/22 22:07:35] root INFO: thresh : 0.3\n",
"[2021/12/22 22:07:35] root INFO: unclip_ratio : 1.5\n",
"[2021/12/22 22:07:35] root INFO: Train : \n",
"[2021/12/22 22:07:35] root INFO: dataset : \n",
"[2021/12/22 22:07:35] root INFO: data_dir : ./train_data/icdar2015/text_localization/\n",
"[2021/12/22 22:07:35] root INFO: label_file_list : ['./train_data/icdar2015/text_localization/train_icdar2015_label.txt']\n",
"[2021/12/22 22:07:35] root INFO: name : SimpleDataSet\n",
"[2021/12/22 22:07:35] root INFO: ratio_list : [1.0]\n",
"[2021/12/22 22:07:35] root INFO: transforms : \n",
"[2021/12/22 22:07:35] root INFO: DecodeImage : \n",
"[2021/12/22 22:07:35] root INFO: channel_first : False\n",
"[2021/12/22 22:07:35] root INFO: img_mode : BGR\n",
"[2021/12/22 22:07:35] root INFO: DetLabelEncode : None\n",
"[2021/12/22 22:07:35] root INFO: IaaAugment : \n",
"[2021/12/22 22:07:35] root INFO: augmenter_args : \n",
"[2021/12/22 22:07:35] root INFO: args : \n",
"[2021/12/22 22:07:35] root INFO: p : 0.5\n",
"[2021/12/22 22:07:35] root INFO: type : Fliplr\n",
"[2021/12/22 22:07:35] root INFO: args : \n",
"[2021/12/22 22:07:35] root INFO: rotate : [-10, 10]\n",
"[2021/12/22 22:07:35] root INFO: type : Affine\n",
"[2021/12/22 22:07:35] root INFO: args : \n",
"[2021/12/22 22:07:35] root INFO: size : [0.5, 3]\n",
"[2021/12/22 22:07:35] root INFO: type : Resize\n",
"[2021/12/22 22:07:35] root INFO: EastRandomCropData : \n",
"[2021/12/22 22:07:35] root INFO: keep_ratio : True\n",
"[2021/12/22 22:07:35] root INFO: max_tries : 50\n",
"[2021/12/22 22:07:35] root INFO: size : [640, 640]\n",
"[2021/12/22 22:07:35] root INFO: MakeBorderMap : \n",
"[2021/12/22 22:07:35] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 22:07:35] root INFO: thresh_max : 0.7\n",
"[2021/12/22 22:07:35] root INFO: thresh_min : 0.3\n",
"[2021/12/22 22:07:35] root INFO: MakeShrinkMap : \n",
"[2021/12/22 22:07:35] root INFO: min_text_size : 8\n",
"[2021/12/22 22:07:35] root INFO: shrink_ratio : 0.4\n",
"[2021/12/22 22:07:35] root INFO: NormalizeImage : \n",
"[2021/12/22 22:07:35] root INFO: mean : [0.485, 0.456, 0.406]\n",
"[2021/12/22 22:07:35] root INFO: order : hwc\n",
"[2021/12/22 22:07:35] root INFO: scale : 1./255.\n",
"[2021/12/22 22:07:35] root INFO: std : [0.229, 0.224, 0.225]\n",
"[2021/12/22 22:07:35] root INFO: ToCHWImage : None\n",
"[2021/12/22 22:07:35] root INFO: KeepKeys : \n",
"[2021/12/22 22:07:35] root INFO: keep_keys : ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']\n",
"[2021/12/22 22:07:35] root INFO: loader : \n",
"[2021/12/22 22:07:35] root INFO: batch_size_per_card : 16\n",
"[2021/12/22 22:07:35] root INFO: drop_last : False\n",
"[2021/12/22 22:07:35] root INFO: num_workers : 8\n",
"[2021/12/22 22:07:35] root INFO: shuffle : True\n",
"[2021/12/22 22:07:35] root INFO: use_shared_memory : False\n",
"[2021/12/22 22:07:35] root INFO: train with paddle 2.2.1 and device CUDAPlace(0)\n",
"W1222 22:07:35.524910 18707 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.0, Runtime API Version: 10.1\n",
"W1222 22:07:35.529196 18707 device_context.cc:465] device: 0, cuDNN Version: 7.6.\n",
"[2021/12/22 22:07:38] root INFO: resume from ./pretrain_models/det_mv3_db_v2.0_train/best_accuracy\n",
"[2021/12/22 22:07:38] root INFO: infer_img: ./doc/imgs_en/img_12.jpg\n",
"[2021/12/22 22:07:39] root INFO: The detected Image saved in ./output/det_db/det_results/img_12.jpg\n",
"[2021/12/22 22:07:39] root INFO: success!\n"
]
}
],
"source": [
"!python tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy Global.infer_img=./doc/imgs_en/img_12.jpg"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"预测后的图像默认保存在./output/det_db/det_results/目录下使用PIL库可视化结果如下"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff434483cd0>"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABI4AAALbCAYAAACc+MvnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzs/buubcu2LQSWGtFaH3OdcwWJdK/yJg8Dgy9A4g+Q8NLN5AMwyFTaWBjp5D9gpM0XIPEN2CCBEFIqeRlYiLPWHL21iIpRo9QoEb3Pvfdx9t1IPbTmGmP03h4RNer7Febu+IzP+IzP+IzP+IzP+IzP+IzP+IzP+IzP+IzP2Ef5Fz2Bz/iMz/iMz/iMz/iMz/iMz/iMz/iMz/iMz/jbHB/H0Wd8xmd8xmd8xmd8xmd8xmd8xmd8xmd8xme8HR/H0Wd8xmd8xmd8xmd8xmd8xmd8xmd8xmd8xme8HR/H0Wd8xmd8xmd8xmd8xmd8xmd8xmd8xmd8xme8HR/H0Wd8xmd8xmd8xmd8xmd8xmd8xmd8xmd8xme8HR/H0Wd8xmd8xmd8xmd8xmd8xmd8xmd8xmd8xme8HX91x5GZ/Xtm9l+b2X9rZv/RX/v9n/EZn/EZn/EZn/EZn/EZn/EZn/EZn/EZn/GXDXP3v97LzCqA/wbAvwvgvwfwXwD4v7v7f/VXm8RnfMZnfMZnfMZnfMZnfMZnfMZnfMZnfMZn/EXjr51x9O8A+G/d/b9z9yeA/xTA//WvPIfP+IzP+IzP+IzP+IzP+IzP+IzP+IzP+IzP+AvGX9tx9K8B+P/L3//9+OwzPuMzPuMzPuMzPuMzPuMzPuMzPuMzPuMz/sbG8S96Avsws/8AwH8AAI/z8W//03/6T2EW30VZnQFwGD+Uwaq7N1/lffnTeeFWqsdL8i4DTJ6ddzjnm3MzGHx/Xk5eXzC/M/00//dmXiaTkPUWm/OLuzznCTOwFNHM8vku8+bc8x558wLGDajznlizmY3H+3LtchfXNzd0XrVsy1yJ6RN0/W/3WN5j6+9v94Z7WriH8uH2HIPBiqH3jsTB8aXDN9yzsV2ee9zdx9X73u5Q/wvHgqfxDh9IkHsCS1xNPNUZuK/4m3s4ryTavZKK7iN/TFyb+Cgr3Z7l8hnnuOyByboIPZ3Ihqi+rAeKVkkPcN+AsD1GbsjLfMJigT/mGhS2eQkXu6OtwmF/pLxL36tTNnm25XPnPiawCRO+c7vO9gUn7wg80DLmCfeVsALP1vW8W5tskiCTyf6PZ75jfxs/dsy1v967ItmkUWwPnctROpkTMlmqwFTwU/Ec+ZXwV2UluV/7/u5gEi7xglAKG/lD79Fl6hJyHa94PHmFv8JiPupFLig7132wIXeIbyY3TfzFitebLHjH3n35Zd0Tzv31jgV5Xn/VRe0L1CcpKf8JfrXj7wt+63X6Ln8zkTeKzCbWNpm1MzxdwJs16SscgJGppdBb1qMwJt3PidnLq1dxsTEIme9Kx/rqFV8XMO7Inngw77UXaMml+rBFzgiPXCAV36gO8R7UO/Oan6eMFd5rBvTustW2wHbSig25ITNYZMA6i6n3/YofT/nCW5eZv1uCfGWmc1MW+R5GuSvCL9+/bCPSfU93eZrv4yTeI//KHRSB5kRe+O0yLOeQv6tu/bIOffu6mPX6lRm7Plc3ZnmYv9z9nnXpjPUaw3v62ODxdjXypF/gyMTxeH73nr8XtRO2saxn52u8YiHebbN+RXp8tr177QZrX9fOteja4vqJCbpn874xXcQO6PellOVdeh/nwnv4/pI4M/nvsuxNZPxJuSAMYUGt3IC4+EWnhGCGyEPb+ecCs/l79w6DofcGd+C+Lrg7jvNErQfqUUJfeGvbrzY/5zHt83fyXxZmUw/Zzcht54Q/y8s2XsR3gvPdDZt8wcr3x0eLP2GBq3DzV+7wyp42lXReaaF5+S954uv4H//n/+
l/cfd/9mcvxF/fcfQ/APg35O9/fXyWw93/EwD/CQD86//qv+b/r//w/4FSClpr6L0LQnb03vH19YXruuBW0HtHKQXHEcuywaTv+8Z5nrjvG7VWuDt673ldInbv8bfF71/nI79vraGUknPgc/gTiHf1HvMyM7gBtda8J643tNbGcwPpaq1jSXEN18vfUeKa+77jsrvl3MmE+Ewbyh8/j/caSiljnjG/WivMKkop+P7+BgDUajlPzmsXjj6USs7vOA50eMKGe5CCAobe72Au3YE6k9wIl+N45PMA5F6hGKqVSXh3y99LKfB05CTu5Dx4zXVdKKXgPE9c17URcFyj+8p7UQze4ndlWK21/Ix73FrLPWyt4evrC+6O+75xHAfMDM/nE2aGsx7LWmMfAkc5n947rBbc940CSxwifvG+UuIad8fj8cB1Xfmdu+M8T7TWch4AcPeW+3Pfd/7+fD7x9fWV1xJuhH+tNdYJQ8Ocu8KfsCHucb60Q45SkwZzPUMZ7L3jKDX3t3lPPOae9YbAt9Zg5olrbljwgrDUYWY4Sk0+ojiadC/rtFqSZri/Ok+u9b5vlKPmvikulVJQa819OXzQ18BbwknhmDyurHhHeCU+dq73WOZv3dENOfcUuGboBhy28YvxfF7fe8d5nugDfwmPUgrq4EtKO3cPmup3W96589aYR00Hij6D94wLxx4+Y6/H2ogD9TxyXmYG85J8sbWWdJAKl8D4as+JU46UAWbBk13gprAmje1OVuWPKkPIX0spgE++mGuoFbUg6a/3nvut895phYPwopw4z3P5bsd7ft4RKcaJmxb3txby5HlfOI4j8X3y5yPpgPBLvARwCDzjAktcKqUA3VFh6HWubVfGFa92Hq2Da889GnzXB47uMDDUYVy2FyNggZPPZ9fzeNnXXC95jiiDimfkn3ofeUbvHeWoC04Rh/i8o1Rc7V7g/A5mhI3KLL1e5xbrG3g2rrPBD0qJ+VEm6ZwW3OmTP05eN/nSKbTAe5o8w+R3laGBLpNXhAHQofyUc+mbVeito/nQx4bOQD7D5xeP9zU4zlKX9xzHgTb4l87rKDXliq7HzHAJnqJPGcW1kyYTxqSh+kiZymvMLGVt7iem3lcGdy4O+OAPV7sXXsJ11nIu71Mc6L1jLH2hK84n92g8R2Gh97iHbEk8wORr8BX3ynmgtUvwaO7nWWryaNiUe6FP1Reer7+rfCee3veN3x5fCy/TfbnvG14G3ndPeXE+AijXFXzvvu/gkUlHsf4KS5k9defQRxMXAFgZusu94engh/d94yw1eSyAlHF8xu331BeOxyL3Y29b4hDXoTitusSUbaH3L/gwdBMzQxvvI8z4zrbxQNUR4v7Aue/v7+AB54nn85nP4D48728cAwEPyvsWeGq1LPNXeidPUlxPnexer915pEoNG7C6esg1hT/fnSJNbEzvBhv8mnrjYUEvByZsbgda77HPouv63VBLQe83rnbDxx4fxxGlPt3z+l5CJlQ7Bvzu3PPWGh5HfO596Kw26HzwGuUn1Qp+/vw59Vnv4+fUtVXunJShFnSZti3shQ/03nH1BqDjx48fgQtY5TZ1hcTpC+O+K/TT1vAP//AP+J//h/8R99Xxz/75P8f/+f/yz/Hbv/QbqgVuE84+aK4Pfm027Y/FHhoOHLWtiuAL+ZzaWokf5gudKU4FSkzba7cVBtXA3fGoR+4bYXX8+C1porWGMujkuq7YB8iczgM23ncMfwbXCSDX7QBOKygFy5qVV077HskXqX+r7NV9+o//P//v/x/+wvHXdhz9FwD+LTP7NxEOo/8bgH//l1ePvaVCQCYbjOtEKY77DnWYCo0SRHfHYXVBMAWUKvQAUjEx2MKYiSTKuPj5VCjuNLrKYJIVBu8ezhN3HAOx0IdiUIdCfsV7rJYRVUEomD4Ro3uDeTDvcqxESiZNJlPKZPYhSKZS37ujlEAqrv84hOGWiHaFrbkyjGnsliQad8f5CELgHi
1j09N3IyjWEUz8PM/FQEF3dBODWZwI+iwVdpwT91wNWDqkdiV8d3hwDucZgtHbfE46HSoZQcHXOY3VagX9bqk8cm8
"text/plain": [
"<Figure size 1440x1440 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"# 在notebook中使用matplotlib.pyplot绘图时需要添加该命令进行显示\n",
"%matplotlib inline\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"img = Image.open('./output/det_db/det_results/img_12.jpg')\n",
"img = np.array(img)\n",
"\n",
"# 画出读取的图片\n",
"plt.figure(figsize=(20, 20))\n",
"plt.imshow(img)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 4. 总结\n",
"\n",
"本节介绍了PaddleOCR文本检测模型的快速使用方法并且以DB算法为例介绍了从数据处理到完成文本检测算法训练的实现过程。下一节将介绍文本识别算法的相关内容。"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# **FAQ**\n",
"\n",
"1. 遇到如下图文字漏检测部分,该如何处理?\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/ccf08d89e0974a848e1a929eefbc1e7176ba211121cb4b76a6b6501cb27b1c9f\" width = \"500\"></center>\n",
"\n",
"<center><br>图 文字区域漏检测 </br></center>\n",
"<br></br>\n",
"\n",
"上述问题表现检测了一部分文字但是文本预测框和GT框的IOU大于阈值0.5检测指标无法正常反馈出来如果此类结果较多建议增大IOU阈值。另外漏检测的本质原因在于一部分文字的特征没有响应归根结底是网络没有学习到漏检测部分文字的特征。建议具体问题具体分析可视化预测结果分析漏检测的原因是否是因为光照形变文字较长等因素导致的然后针对性的使用数据增强、调整网络、或者调整后处理等方法优化检测结果。\n",
"\n",
"\n",
"更多文本检测FAQ内容参考下一节内容。"
]
},
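{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"To build intuition for the IoU threshold discussed above, here is a minimal sketch of IoU for two axis-aligned boxes (a simplifying assumption for illustration; DetMetric actually computes IoU between quadrilateral/polygon regions):\n",
"\n",
"```python\n",
"def iou(a, b):\n",
"    # a, b: boxes as (x1, y1, x2, y2)\n",
"    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])\n",
"    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])\n",
"    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)\n",
"    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter\n",
"    return inter / union\n",
"\n",
"# A box covering only the left half of the GT still reaches IoU = 0.5,\n",
"# which is why partial detections can pass the default 0.5 threshold.\n",
"print(iou((0, 0, 100, 20), (0, 0, 50, 20)))  # 0.5\n",
"```"
]
},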
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# 作业\n",
"简答题:\n",
"- 1. 根据DB Backbone和FPN的输出特征图的大小判断DB的输入图像高度和宽度需要是_的倍数\n",
"A: 32 B: 64\n",
"\n",
"实验题:\n",
"- 1. 使用DB算法配置文件configs/det/det_mv3_db.yml在数据集[det_data_lesson_demo.tar](https://paddleocr.bj.bcebos.com/dataset/det_data_lesson_demo.tar)上训练文本检测模型,并调优实验精度。\n",
"\n",
"<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/1b68e58548f94603854cd342602686bc4f0ada68b6214e3986f332816164c518\"\n",
"width = \"700\"></center>\n",
"\n",
"<center><br>图 det_data_lesson_demo训练数据示例 </br></center>\n",
"<br></br>\n",
"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "py35-paddle1.2.0"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 1
}