add svtr FAQ and data_aug (#8865)
* Update rec_nrtr_head.py * add svtr FAQ and data_augpull/8918/head
parent
4aff082c5b
commit
29cdda4eda
|
@ -26,10 +26,10 @@ Optimizer:
|
||||||
name: AdamW
|
name: AdamW
|
||||||
beta1: 0.9
|
beta1: 0.9
|
||||||
beta2: 0.99
|
beta2: 0.99
|
||||||
epsilon: 8.e-8
|
epsilon: 1.e-8
|
||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
no_weight_decay_name: norm pos_embed
|
no_weight_decay_name: norm pos_embed
|
||||||
one_dim_param_no_weight_decay: true
|
one_dim_param_no_weight_decay: True
|
||||||
lr:
|
lr:
|
||||||
name: Cosine
|
name: Cosine
|
||||||
learning_rate: 0.0005
|
learning_rate: 0.0005
|
||||||
|
@ -48,7 +48,7 @@ Architecture:
|
||||||
Backbone:
|
Backbone:
|
||||||
name: SVTRNet
|
name: SVTRNet
|
||||||
img_size: [32, 100]
|
img_size: [32, 100]
|
||||||
out_char_num: 25
|
out_char_num: 25 # W//4 or W//8 or W//12
|
||||||
out_channels: 192
|
out_channels: 192
|
||||||
patch_merging: 'Conv'
|
patch_merging: 'Conv'
|
||||||
embed_dim: [64, 128, 256]
|
embed_dim: [64, 128, 256]
|
||||||
|
@ -57,7 +57,7 @@ Architecture:
|
||||||
mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
|
mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global']
|
||||||
local_mixer: [[7, 11], [7, 11], [7, 11]]
|
local_mixer: [[7, 11], [7, 11], [7, 11]]
|
||||||
last_stage: True
|
last_stage: True
|
||||||
prenorm: false
|
prenorm: False
|
||||||
Neck:
|
Neck:
|
||||||
name: SequenceEncoder
|
name: SequenceEncoder
|
||||||
encoder_type: reshape
|
encoder_type: reshape
|
||||||
|
@ -82,6 +82,8 @@ Train:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
img_mode: BGR
|
img_mode: BGR
|
||||||
channel_first: False
|
channel_first: False
|
||||||
|
- SVTRRecAug:
|
||||||
|
aug_type: 0 # or 1
|
||||||
- CTCLabelEncode: # Class handling label
|
- CTCLabelEncode: # Class handling label
|
||||||
- SVTRRecResizeImg:
|
- SVTRRecResizeImg:
|
||||||
image_shape: [3, 64, 256]
|
image_shape: [3, 64, 256]
|
||||||
|
@ -92,7 +94,7 @@ Train:
|
||||||
shuffle: True
|
shuffle: True
|
||||||
batch_size_per_card: 512
|
batch_size_per_card: 512
|
||||||
drop_last: True
|
drop_last: True
|
||||||
num_workers: 4
|
num_workers: 8
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
|
|
|
@ -23,7 +23,7 @@ Optimizer:
|
||||||
name: AdamW
|
name: AdamW
|
||||||
beta1: 0.9
|
beta1: 0.9
|
||||||
beta2: 0.99
|
beta2: 0.99
|
||||||
epsilon: 8.0e-08
|
epsilon: 1.0e-08
|
||||||
weight_decay: 0.05
|
weight_decay: 0.05
|
||||||
no_weight_decay_name: norm pos_embed
|
no_weight_decay_name: norm pos_embed
|
||||||
one_dim_param_no_weight_decay: true
|
one_dim_param_no_weight_decay: true
|
||||||
|
@ -40,7 +40,7 @@ Architecture:
|
||||||
img_size:
|
img_size:
|
||||||
- 32
|
- 32
|
||||||
- 320
|
- 320
|
||||||
out_char_num: 40
|
out_char_num: 40 # W//4 or W//8 or W//12
|
||||||
out_channels: 96
|
out_channels: 96
|
||||||
patch_merging: Conv
|
patch_merging: Conv
|
||||||
embed_dim:
|
embed_dim:
|
||||||
|
|
|
@ -159,7 +159,23 @@ Predicts of ./doc/imgs_words_en/word_10.png:('pain', 0.9999998807907104)
|
||||||
<a name="5"></a>
|
<a name="5"></a>
|
||||||
## 5. FAQ
|
## 5. FAQ
|
||||||
|
|
||||||
1. 由于`SVTR`使用的算子大多为矩阵相乘,在GPU环境下,速度具有优势,但在CPU开启mkldnn加速环境下,`SVTR`相比于被优化的卷积网络没有优势。
|
- 1. GPU和CPU速度对比
|
||||||
|
- 由于`SVTR`使用的算子大多为矩阵相乘,在GPU环境下,速度具有优势,但在CPU开启mkldnn加速环境下,`SVTR`相比于被优化的卷积网络没有优势。
|
||||||
|
- 2. SVTR模型转ONNX失败
|
||||||
|
- 保证`paddle2onnx`和`onnxruntime`版本最新,转onnx命令参考[SVTR模型转onnx步骤实例](https://github.com/PaddlePaddle/PaddleOCR/issues/7821#issuecomment-1271214273)。
|
||||||
|
- 3. SVTR转ONNX成功但是推理结果不正确
|
||||||
|
- 可能的原因是模型参数`out_char_num`设置不正确,应设置为W//4、W//8或者W//12,可以参考[高精度中文场景文本识别模型SVTR的3.3.3章节](https://aistudio.baidu.com/aistudio/projectdetail/5073182?contributionType=1)。
|
||||||
|
- 4. 长文本识别优化
|
||||||
|
- 参考[高精度中文场景文本识别模型SVTR的3.3章节](https://aistudio.baidu.com/aistudio/projectdetail/5073182?contributionType=1)。
|
||||||
|
- 5. 论文结果复现注意事项
|
||||||
|
- 数据集使用[ABINet](https://github.com/FangShancheng/ABINet)提供的数据集;
|
||||||
|
- 默认使用4卡GPU训练,单卡Batchsize默认为512,总Batchsize为2048,对应的学习率为0.0005,当修改Batchsize或者改变GPU卡数,学习率应等比例修改。
|
||||||
|
- 6. 进一步优化的探索点
|
||||||
|
- 学习率调整:可以调整为默认的两倍保持Batchsize不变;或者将Batchsize减小为默认的1/2,保持学习率不变;
|
||||||
|
- 数据增强策略:可选`RecConAug`和`RecAug`;
|
||||||
|
- 如果不使用STN时,可以将`mixer`的`Local`替换为`Conv`、`local_mixer`全部修改为`[5, 5]`;
|
||||||
|
- 网格搜索最优的`embed_dim`、`depth`、`num_heads`配置;
|
||||||
|
- 使用`后Normalization策略`,即是将模型配置`prenorm`修改为`True`。
|
||||||
|
|
||||||
|
|
||||||
## 引用
|
## 引用
|
||||||
|
|
|
@ -130,7 +130,23 @@ Not supported
|
||||||
<a name="5"></a>
|
<a name="5"></a>
|
||||||
## 5. FAQ
|
## 5. FAQ
|
||||||
|
|
||||||
1. Since most of the operators used by `SVTR` are matrix multiplication, in the GPU environment, the speed has an advantage, but in the environment where mkldnn is enabled on the CPU, `SVTR` has no advantage over the optimized convolutional network.
|
- 1. Speed comparison on GPU and CPU
|
||||||
|
- Since most of the operators used by `SVTR` are matrix multiplication, in the GPU environment, the speed has an advantage, but in the environment where mkldnn is enabled on the CPU, `SVTR` has no advantage over the optimized convolutional network.
|
||||||
|
- 2. SVTR model fails to convert to ONNX
|
||||||
|
- Ensure that the `paddle2onnx` and `onnxruntime` versions are up to date, and refer to the [SVTR model to onnx step-by-step example](https://github.com/PaddlePaddle/PaddleOCR/issues/7821#issuecomment-1271214273) for the ONNX conversion command.
|
||||||
|
- 3. SVTR model converts to ONNX successfully but the inference result is incorrect
|
||||||
|
- A possible reason is that the model parameter `out_char_num` is not set correctly; it should be set to W//4, W//8 or W//12. Please refer to [Section 3.3.3 of SVTR, a high-precision Chinese scene text recognition model](https://aistudio.baidu.com/aistudio/projectdetail/5073182?contributionType=1).
|
||||||
|
- 4. Optimization of long text recognition
|
||||||
|
- Refer to [Section 3.3 of SVTR, a high-precision Chinese scene text recognition model](https://aistudio.baidu.com/aistudio/projectdetail/5073182?contributionType=1).
|
||||||
|
- 5. Notes on the reproduction of the paper results
|
||||||
|
- Use the dataset provided by [ABINet](https://github.com/FangShancheng/ABINet).
|
||||||
|
- By default, 4 cards of GPUs are used for training, the default Batchsize of a single card is 512, and the total Batchsize is 2048, corresponding to a learning rate of 0.0005. When modifying the Batchsize or changing the number of GPU cards, the learning rate should be modified in equal proportion.
|
||||||
|
- 6. Exploration Directions for further optimization
|
||||||
|
- Learning rate adjustment: double the default learning rate while keeping the Batchsize unchanged, or halve the default Batchsize while keeping the learning rate unchanged.
|
||||||
|
- Data augmentation strategies: optionally `RecConAug` and `RecAug`.
|
||||||
|
- If STN is not used, `Local` of `mixer` can be replaced by `Conv` and `local_mixer` can all be modified to `[5, 5]`.
|
||||||
|
- Grid search for optimal `embed_dim`, `depth`, `num_heads` configurations.
|
||||||
|
- Use the `Post-Normalization strategy`, which is to modify the model configuration `prenorm` to `True`.
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt
|
||||||
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||||
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
|
||||||
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
|
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
|
||||||
RFLRecResizeImg
|
RFLRecResizeImg, SVTRRecAug
|
||||||
from .ssl_img_aug import SSLRotateResize
|
from .ssl_img_aug import SSLRotateResize
|
||||||
from .randaugment import RandAugment
|
from .randaugment import RandAugment
|
||||||
from .copy_paste import CopyPaste
|
from .copy_paste import CopyPaste
|
||||||
|
|
|
@ -405,3 +405,54 @@ class CVColorJitter(object):
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
if random.random() < self.p: return self.transforms(img)
|
if random.random() < self.p: return self.transforms(img)
|
||||||
else: return img
|
else: return img
|
||||||
|
|
||||||
|
|
||||||
|
class SVTRDeterioration(object):
    """Randomly apply a shuffled chain of deterioration transforms.

    Args:
        var: forwarded as ``CVGaussianNoise(var=var)``; ``None`` leaves
            that transform out.
        degrees: forwarded as ``CVMotionBlur(degrees=degrees)``; ``None``
            leaves that transform out.
        factor: forwarded as ``CVRescale(factor=factor)``; ``None`` leaves
            that transform out.
        p: probability that the chain is applied on a given call.
    """

    def __init__(self, var, degrees, factor, p=0.5):
        self.p = p
        # Builders are lazy so a transform is only constructed when its
        # controlling argument was actually supplied.
        builders = (
            (var, lambda: CVGaussianNoise(var=var)),
            (degrees, lambda: CVMotionBlur(degrees=degrees)),
            (factor, lambda: CVRescale(factor=factor)),
        )
        self.transforms = [
            build() for setting, build in builders if setting is not None
        ]

    def __call__(self, img):
        """Return ``img`` unchanged, or passed through the shuffled chain."""
        if not random.random() < self.p:
            return img
        # The stored list is reordered in place on every applied call.
        random.shuffle(self.transforms)
        pipeline = Compose(self.transforms)
        return pipeline(img)
|
||||||
|
|
||||||
|
|
||||||
|
class SVTRGeometry(object):
    """Randomly apply one or several geometric transforms to an image.

    The instance holds three transforms, built in this order:
    ``CVRandomRotation``, ``CVRandomAffine`` and ``CVRandomPerspective``.

    Args:
        aug_type: when truthy, a random-length prefix of the shuffled
            transform list is applied; otherwise exactly one transform is
            applied.
        degrees: passed to both ``CVRandomRotation`` and ``CVRandomAffine``.
        translate: passed to ``CVRandomAffine``.
        scale: passed to ``CVRandomAffine``.
        shear: passed to ``CVRandomAffine``.
        distortion: passed to ``CVRandomPerspective``.
        p: probability that any transform is applied on a given call.
    """

    def __init__(self,
                 aug_type=0,
                 degrees=15,
                 translate=(0.3, 0.3),
                 scale=(0.5, 2.),
                 shear=(45, 15),
                 distortion=0.5,
                 p=0.5):
        self.aug_type = aug_type
        self.p = p
        self.transforms = [
            CVRandomRotation(degrees=degrees),
            CVRandomAffine(
                degrees=degrees,
                translate=translate,
                scale=scale,
                shear=shear),
            CVRandomPerspective(distortion=distortion),
        ]

    def __call__(self, img):
        """Return ``img`` unchanged, or transformed per ``aug_type``."""
        if not random.random() < self.p:
            return img

        if not self.aug_type:
            # Single-transform mode: pick one entry by random index 0..2.
            chosen = self.transforms[random.randint(0, 2)]
            return chosen(img)

        # Multi-transform mode: reorder the stored list in place, then
        # chain its first 1..3 entries.
        random.shuffle(self.transforms)
        count = random.randint(1, 3)
        pipeline = Compose(self.transforms[:count])
        return pipeline(img)
|
|
@ -19,7 +19,7 @@ import random
|
||||||
import copy
|
import copy
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
||||||
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter
|
from .abinet_aug import CVGeometry, CVDeterioration, CVColorJitter, SVTRGeometry, SVTRDeterioration
|
||||||
from paddle.vision.transforms import Compose
|
from paddle.vision.transforms import Compose
|
||||||
|
|
||||||
|
|
||||||
|
@ -109,8 +109,9 @@ class ABINetRecAug(object):
|
||||||
scale=(0.5, 2.),
|
scale=(0.5, 2.),
|
||||||
shear=(45, 15),
|
shear=(45, 15),
|
||||||
distortion=0.5,
|
distortion=0.5,
|
||||||
p=geometry_p), CVDeterioration(
|
p=geometry_p),
|
||||||
var=20, degrees=6, factor=4, p=deterioration_p),
|
CVDeterioration(
|
||||||
|
var=20, degrees=6, factor=4, p=deterioration_p),
|
||||||
CVColorJitter(
|
CVColorJitter(
|
||||||
brightness=0.5,
|
brightness=0.5,
|
||||||
contrast=0.5,
|
contrast=0.5,
|
||||||
|
@ -169,6 +170,39 @@ class RecConAug(object):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class SVTRRecAug(object):
    """Data-pipeline operator that augments ``data['image']``.

    Three stages are chained with ``Compose`` in this order:
    ``SVTRGeometry``, ``SVTRDeterioration`` and ``CVColorJitter``.

    Args:
        aug_type: forwarded to ``SVTRGeometry``.
        geometry_p: ``p`` argument of the ``SVTRGeometry`` stage.
        deterioration_p: ``p`` argument of the ``SVTRDeterioration`` stage.
        colorjitter_p: ``p`` argument of the ``CVColorJitter`` stage.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 aug_type=0,
                 geometry_p=0.5,
                 deterioration_p=0.25,
                 colorjitter_p=0.25,
                 **kwargs):
        geometry = SVTRGeometry(
            aug_type=aug_type,
            degrees=45,
            translate=(0.0, 0.0),
            scale=(0.5, 2.),
            shear=(45, 15),
            distortion=0.5,
            p=geometry_p)
        deterioration = SVTRDeterioration(
            var=20, degrees=6, factor=4, p=deterioration_p)
        color_jitter = CVColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.1,
            p=colorjitter_p)
        self.transforms = Compose([geometry, deterioration, color_jitter])

    def __call__(self, data):
        """Replace ``data['image']`` with its augmented version and return ``data``."""
        data['image'] = self.transforms(data['image'])
        return data
|
||||||
|
|
||||||
|
|
||||||
class ClsResizeImg(object):
|
class ClsResizeImg(object):
|
||||||
def __init__(self, image_shape, **kwargs):
|
def __init__(self, image_shape, **kwargs):
|
||||||
self.image_shape = image_shape
|
self.image_shape = image_shape
|
||||||
|
|
Loading…
Reference in New Issue