diff --git a/configs/rec/rec_svtrnet.yml b/configs/rec/rec_svtrnet.yml
index e8ceefead..82b8273a1 100644
--- a/configs/rec/rec_svtrnet.yml
+++ b/configs/rec/rec_svtrnet.yml
@@ -26,10 +26,10 @@ Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
- epsilon: 8.e-8
+ epsilon: 1.e-8
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
- one_dim_param_no_weight_decay: true
+ one_dim_param_no_weight_decay: True
lr:
name: Cosine
learning_rate: 0.0005
@@ -48,7 +48,7 @@ Architecture:
Backbone:
name: SVTRNet
img_size: [32, 100]
- out_char_num: 25
+ out_char_num: 25 # W//4 or W//8 or W/12
out_channels: 192
patch_merging: 'Conv'
embed_dim: [64, 128, 256]
@@ -57,7 +57,7 @@ Architecture:
mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
local_mixer: [[7, 11], [7, 11], [7, 11]]
last_stage: True
- prenorm: false
+ prenorm: False
Neck:
name: SequenceEncoder
encoder_type: reshape
@@ -82,6 +82,8 @@ Train:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
+ - SVTRRecAug:
+ aug_type: 0 # or 1
- CTCLabelEncode: # Class handling label
- SVTRRecResizeImg:
image_shape: [3, 64, 256]
@@ -92,7 +94,7 @@ Train:
shuffle: True
batch_size_per_card: 512
drop_last: True
- num_workers: 4
+ num_workers: 8
Eval:
dataset:
diff --git a/configs/rec/rec_svtrnet_ch.yml b/configs/rec/rec_svtrnet_ch.yml
index 0d3f63d12..597e57fb4 100644
--- a/configs/rec/rec_svtrnet_ch.yml
+++ b/configs/rec/rec_svtrnet_ch.yml
@@ -23,7 +23,7 @@ Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.99
- epsilon: 8.0e-08
+ epsilon: 1.0e-08
weight_decay: 0.05
no_weight_decay_name: norm pos_embed
one_dim_param_no_weight_decay: true
@@ -40,7 +40,7 @@ Architecture:
img_size:
- 32
- 320
- out_char_num: 40
+ out_char_num: 40 # W//4 or W//8 or W/12
out_channels: 96
patch_merging: Conv
embed_dim:
diff --git a/doc/doc_ch/algorithm_rec_svtr.md b/doc/doc_ch/algorithm_rec_svtr.md
index c0e26433e..42a1a9a41 100644
--- a/doc/doc_ch/algorithm_rec_svtr.md
+++ b/doc/doc_ch/algorithm_rec_svtr.md
@@ -159,7 +159,23 @@ Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104)
## 5. FAQ
-1. 由于`SVTR`使用的算子大多为矩阵相乘,在GPU环境下,速度具有优势,但在CPU开启mkldnn加速环境下,`SVTR`相比于被优化的卷积网络没有优势。
+- 1. GPU和CPU速度对比
+ - 由于`SVTR`使用的算子大多为矩阵相乘,在GPU环境下,速度具有优势,但在CPU开启mkldnn加速环境下,`SVTR`相比于被优化的卷积网络没有优势。
+- 2. SVTR模型转ONNX失败
+ - 保证`paddle2onnx`和`onnxruntime`版本最新,转onnx命令参考[SVTR模型转onnx步骤实例](https://github.com/PaddlePaddle/PaddleOCR/issues/7821#issuecomment-1271214273)。
+- 3. SVTR转ONNX成功但是推理结果不正确
+ - 可能的原因模型参数`out_char_num`设置不正确,应设置为W//4、W//8或者W//12,可以参考[高精度中文场景文本识别模型SVTR的3.3.3章节](https://aistudio.baidu.com/aistudio/projectdetail/5073182?contributionType=1)。
+- 4. 长文本识别优化
+ - 参考[高精度中文场景文本识别模型SVTR的3.3章节](https://aistudio.baidu.com/aistudio/projectdetail/5073182?contributionType=1)。
+- 5. 论文结果复现注意事项
+ - 数据集使用[ABINet](https://github.com/FangShancheng/ABINet)提供的数据集;
+ - 默认使用4卡GPU训练,单卡Batchsize默认为512,总Batchsize为2048,对应的学习率为0.0005,当修改Batchsize或者改变GPU卡数,学习率应等比例修改。
+- 6. 进一步优化的探索点
+ - 学习率调整:可以调整为默认的两倍保持Batchsize不变;或者将Batchsize减小为默认的1/2,保持学习率不变;
+ - 数据增强策略:可选`RecConAug`和`RecAug`;
+ - 如果不使用STN时,可以将`mixer`的`Local`替换为`Conv`、`local_mixer`全部修改为`[5, 5]`;
+ - 网格搜索最优的`embed_dim`、`depth`、`num_heads`配置;
+ - 使用`后Normalization策略`,即是将模型配置`prenorm`修改为`True`。
## 引用
diff --git a/doc/doc_en/algorithm_rec_svtr_en.md b/doc/doc_en/algorithm_rec_svtr_en.md
index 37cd35f35..d22fe73e6 100644
--- a/doc/doc_en/algorithm_rec_svtr_en.md
+++ b/doc/doc_en/algorithm_rec_svtr_en.md
@@ -130,7 +130,23 @@ Not supported
## 5. FAQ
-1. Since most of the operators used by `SVTR` are matrix multiplication, in the GPU environment, the speed has an advantage, but in the environment where mkldnn is enabled on the CPU, `SVTR` has no advantage over the optimized convolutional network.
+- 1. Speed situation on CPU and GPU
+ - Since most of the operators used by `SVTR` are matrix multiplication, in the GPU environment, the speed has an advantage, but in the environment where mkldnn is enabled on the CPU, `SVTR` has no advantage over the optimized convolutional network.
+- 2. SVTR model convert to ONNX failed
+ - Ensure `paddle2onnx` and `onnxruntime` versions are up to date, refer to [SVTR model to onnx step-by-step example](https://github.com/PaddlePaddle/PaddleOCR/issues/7821#issuecomment-) for the convert onnx command. 1271214273).
+- 3. SVTR model convert to ONNX is successful but the inference result is incorrect
+ - The possible reason is that the model parameter `out_char_num` is not set correctly, it should be set to W//4, W//8 or W//12, please refer to [Section 3.3.3 of SVTR, a high-precision Chinese scene text recognition model](https://aistudio.baidu.com/aistudio/) projectdetail/5073182?contributionType=1).
+- 4. Optimization of long text recognition
+ - Refer to [Section 3.3 of SVTR, a high-precision Chinese scene text recognition model](https://aistudio.baidu.com/aistudio/projectdetail/5073182?contributionType=1).
+- 5. Notes on the reproduction of the paper results
+ - Dataset using provided by [ABINet](https://github.com/FangShancheng/ABINet).
+ - By default, 4 cards of GPUs are used for training, the default Batchsize of a single card is 512, and the total Batchsize is 2048, corresponding to a learning rate of 0.0005. When modifying the Batchsize or changing the number of GPU cards, the learning rate should be modified in equal proportion.
+- 6. Exploration Directions for further optimization
+ - Learning rate adjustment: adjusting to twice the default to keep Batchsize unchanged; or reducing Batchsize to 1/2 the default to keep the learning rate unchanged.
+ - Data augmentation strategies: optionally `RecConAug` and `RecAug`.
+ - If STN is not used, `Local` of `mixer` can be replaced by `Conv` and `local_mixer` can all be modified to `[5, 5]`.
+ - Grid search for optimal `embed_dim`, `depth`, `num_heads` configurations.
+ - Use the `Post-Normalization strategy`, which is to modify the model configuration `prenorm` to `True`.
## Citation
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 93d97446d..121582b49 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
- RFLRecResizeImg
+ RFLRecResizeImg, SVTRRecAug
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
diff --git a/ppocr/data/imaug/abinet_aug.py b/ppocr/data/imaug/abinet_aug.py
index eefdc75d5..1b93751bc 100644
--- a/ppocr/data/imaug/abinet_aug.py
+++ b/ppocr/data/imaug/abinet_aug.py
@@ -405,3 +405,54 @@ class CVColorJitter(object):
def __call__(self, img):
if random.random() < self.p: return self.transforms(img)
else: return img
+
+
+class SVTRDeterioration(object):
+ def __init__(self, var, degrees, factor, p=0.5):
+ self.p = p
+ transforms = []
+ if var is not None:
+ transforms.append(CVGaussianNoise(var=var))
+ if degrees is not None:
+ transforms.append(CVMotionBlur(degrees=degrees))
+ if factor is not None:
+ transforms.append(CVRescale(factor=factor))
+ self.transforms = transforms
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ random.shuffle(self.transforms)
+ transforms = Compose(self.transforms)
+ return transforms(img)
+ else:
+ return img
+
+
+class SVTRGeometry(object):
+ def __init__(self,
+ aug_type=0,
+ degrees=15,
+ translate=(0.3, 0.3),
+ scale=(0.5, 2.),
+ shear=(45, 15),
+ distortion=0.5,
+ p=0.5):
+ self.aug_type = aug_type
+ self.p = p
+ self.transforms = []
+ self.transforms.append(CVRandomRotation(degrees=degrees))
+ self.transforms.append(CVRandomAffine(
+ degrees=degrees, translate=translate, scale=scale, shear=shear))
+ self.transforms.append(CVRandomPerspective(distortion=distortion))
+
+ def __call__(self, img):
+ if random.random() < self.p:
+ if self.aug_type:
+ random.shuffle(self.transforms)
+ transforms = Compose(self.transforms[:random.randint(1, 3)])
+ img = transforms(img)
+ else:
+ img = self.transforms[random.randint(0, 2)](img)
+ return img
+ else:
+ return img
\ No newline at end of file
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index e22153bde..48404ab8d 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -19,7 +19,7 @@ import random
import copy
from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
-from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter
+from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration
from paddle.vision.transforms import Compose
@@ -109,8 +109,9 @@ class ABINetRecAug(object):
scale=(0.5, 2.),
shear=(45, 15),
distortion=0.5,
- p=geometry_p), CVDeterioration(
- var=20, degrees=6, factor=4, p=deterioration_p),
+ p=geometry_p),
+ CVDeterioration(
+ var=20, degrees=6, factor=4, p=deterioration_p),
CVColorJitter(
brightness=0.5,
contrast=0.5,
@@ -169,6 +170,39 @@ class RecConAug(object):
return data
+class SVTRRecAug(object):
+ def __init__(self,
+ aug_type=0,
+ geometry_p=0.5,
+ deterioration_p=0.25,
+ colorjitter_p=0.25,
+ **kwargs):
+ self.transforms = Compose([
+ SVTRGeometry(
+ aug_type=aug_type,
+ degrees=45,
+ translate=(0.0, 0.0),
+ scale=(0.5, 2.),
+ shear=(45, 15),
+ distortion=0.5,
+ p=geometry_p),
+ SVTRDeterioration(
+ var=20, degrees=6, factor=4, p=deterioration_p),
+ CVColorJitter(
+ brightness=0.5,
+ contrast=0.5,
+ saturation=0.5,
+ hue=0.1,
+ p=colorjitter_p)
+ ])
+
+ def __call__(self, data):
+ img = data['image']
+ img = self.transforms(img)
+ data['image'] = img
+ return data
+
+
class ClsResizeImg(object):
def __init__(self, image_shape, **kwargs):
self.image_shape = image_shape