add svtr large model (#10937)

* add svtr large model

* [WIP]add svtr large model
zhangyubo0722 2023-09-26 14:38:29 +08:00 committed by GitHub
parent 2751cb3a11
commit e49e491417
7 changed files with 457 additions and 18 deletions

@@ -0,0 +1,144 @@
Global:
  debug: false
  use_gpu: true
  epoch_num: 200
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec/svtr_large/
  save_epoch_step: 10
  # evaluation is run every 2000 iterations after the 0th iteration
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 40
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_svtr_large.txt

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.99
  epsilon: 1.0e-08
  weight_decay: 0.05
  no_weight_decay_name: norm pos_embed char_node_embed pos_node_embed char_pos_embed vis_pos_embed
  one_dim_param_no_weight_decay: true
  lr:
    name: Cosine
    learning_rate: 0.00025 # 8gpus 64bs
    warmup_epoch: 5

Architecture:
  model_type: rec
  algorithm: SVTR_LCNet
  Transform: null
  Backbone:
    name: SVTRNet
    img_size:
      - 48
      - 320
    out_char_num: 40
    out_channels: 512
    patch_merging: Conv
    embed_dim: [192, 256, 512]
    depth: [6, 6, 9]
    num_heads: [6, 8, 16]
    mixer: ['Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
    local_mixer: [[5, 5], [5, 5], [5, 5]]
    last_stage: False
    prenorm: True
  Head:
    name: MultiHead
    use_pool: true
    use_pos: true
    head_list:
      - CTCHead:
          Neck:
            name: svtr
            dims: 256
            depth: 2
            hidden_dims: 256
            kernel_size: [1, 3]
            use_guide: True
          Head:
            fc_decay: 0.00001
      - NRTRHead:
          nrtr_dim: 512
          max_text_length: *max_text_length

Loss:
  name: MultiLoss
  loss_config_list:
    - CTCLoss:
    - NRTRLoss:

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc
  ignore_space: true

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
      - ./train_data/train_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - RecAug:
      - MultiLabelEncode:
          gtc_encode: NRTRLabelEncode
      - RecResizeImg:
          image_shape: [3, 48, 320]
      - KeepKeys:
          keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 64
    drop_last: true
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
      - ./train_data/val_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - MultiLabelEncode:
          gtc_encode: NRTRLabelEncode
      - SVTRRecResizeImg:
          image_shape: [3, 48, 320]
      - KeepKeys:
          keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4
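
A side note on this config: max_text_length is declared once as a YAML anchor (&max_text_length) and reused by the NRTR head through the *max_text_length alias, so the two values cannot drift apart. A minimal sketch of how the anchor resolves, assuming PyYAML (which PaddleOCR's config loader builds on):

import yaml

snippet = """
Global:
  max_text_length: &max_text_length 40
Head:
  max_text_length: *max_text_length
"""
cfg = yaml.safe_load(snippet)
assert cfg["Head"]["max_text_length"] == 40  # the alias reuses the anchored value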

@@ -47,7 +47,7 @@ class RecAug(object):
            if h >= 20 and w >= 20:
                img = tia_distort(img, random.randint(3, 6))
                img = tia_stretch(img, random.randint(3, 6))
-                img = tia_perspective(img)
+            img = tia_perspective(img)

        # bda
        data['image'] = img

(The only change here is indentation: tia_perspective moves out of the h >= 20 and w >= 20 guard, so perspective augmentation also runs on small crops.)

@@ -24,6 +24,7 @@ def build_backbone(config, model_type):
        from .det_pp_lcnet import PPLCNet
        from .rec_lcnetv3 import PPLCNetV3
        from .rec_hgnet import PPHGNet_small
+        from .rec_vit import ViT
        support_dict = [
            "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
            "PPLCNetV3", "PPHGNet_small"
@@ -55,7 +56,7 @@ def build_backbone(config, model_type):
            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
            'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
            'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
-            'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ'
+            'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ', 'ViT'
        ]
    elif model_type == 'e2e':
        from .e2e_resnet_vd_pg import ResNet
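
For context, build_backbone dispatches on the Backbone "name" field from the YAML, which is why 'ViT' has to be registered in the support list above. A minimal sketch of the pattern (illustrative only; the demo class and function names here are not PaddleOCR's actual code):

class ViT:
    def __init__(self, **kwargs):
        self.cfg = kwargs  # remaining YAML keys become constructor kwargs

def build_backbone_sketch(config, support_dict):
    module_name = config.pop("name")  # e.g. "ViT" from the Backbone section
    assert module_name in support_dict, "{} is not supported".format(module_name)
    return globals()[module_name](**config)

backbone = build_backbone_sketch({"name": "ViT", "embed_dim": 384}, ["ViT"])
print(backbone.cfg)  # {'embed_dim': 384}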

@@ -0,0 +1,258 @@
# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import ParamAttr
from paddle.nn.initializer import KaimingNormal
import numpy as np
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant, Normal

trunc_normal_ = TruncatedNormal(std=.02)
normal_ = Normal
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    The original name is misleading as 'Drop Connect' is a different form of dropout in a
    separate paper. See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Identity(nn.Layer):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, input):
        return input


class Mlp(nn.Layer):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        # project to qkv and split heads: [N, L, 3*C] -> [3, N, num_heads, L, C/num_heads]
        qkv = paddle.reshape(self.qkv(x), (0, -1, 3, self.num_heads, self.dim //
                                           self.num_heads)).transpose((2, 0, 3, 1, 4))
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        attn = (q.matmul(k.transpose((0, 1, 3, 2))))
        attn = nn.functional.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Layer):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-6,
                 prenorm=True):
        super().__init__()
        if isinstance(norm_layer, str):
            self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
        else:
            self.norm1 = norm_layer(dim)
        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
        if isinstance(norm_layer, str):
            self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
        else:
            self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)
        self.prenorm = prenorm

    def forward(self, x):
        if self.prenorm:
            x = self.norm1(x + self.drop_path(self.mixer(x)))
            x = self.norm2(x + self.drop_path(self.mlp(x)))
        else:
            x = x + self.drop_path(self.mixer(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class ViT(nn.Layer):
    def __init__(self,
                 img_size=[32, 128],
                 patch_size=[4, 4],
                 in_channels=3,
                 embed_dim=384,
                 depth=12,
                 num_heads=6,
                 mlp_ratio=4,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer='nn.LayerNorm',
                 epsilon=1e-6,
                 act='nn.GELU',
                 prenorm=False,
                 **kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        self.out_channels = embed_dim
        self.prenorm = prenorm
        # non-overlapping patch embedding: Conv2D with kernel size = stride = patch size
        self.patch_embed = nn.Conv2D(
            in_channels, embed_dim, patch_size, patch_size, padding=(0, 0))
        self.pos_embed = self.create_parameter(
            shape=[1, 257, embed_dim], default_initializer=zeros_)
        self.add_parameter("pos_embed", self.pos_embed)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = np.linspace(0, drop_path_rate, depth)  # stochastic depth decay rule

        self.blocks1 = nn.LayerList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=eval(act),
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                epsilon=epsilon,
                prenorm=prenorm) for i in range(depth)
        ])
        if not prenorm:
            self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
        self.avg_pool = nn.AdaptiveAvgPool2D([1, 25])
        self.last_conv = nn.Conv2D(
            in_channels=embed_dim,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=False)
        self.hardswish = nn.Hardswish()
        self.dropout = nn.Dropout(p=0.1, mode="downscale_in_infer")

        trunc_normal_(self.pos_embed)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
        x = x + self.pos_embed[:, 1:, :]  #[:, :paddle.shape(x)[1], :]
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)
        # reassemble the token sequence into a 2-D map (width 25) and pool height to 1
        x = self.avg_pool(x.transpose([0, 2, 1]).reshape(
            [0, self.embed_dim, -1, 25]))
        x = self.last_conv(x)
        x = self.hardswish(x)
        x = self.dropout(x)
        return x
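
As a quick sanity check on the stochastic-depth helper in this new file, a minimal sketch (the import path is inferred from the "from .rec_vit import ViT" hunk above; adjust if the module lives elsewhere):

import paddle
from ppocr.modeling.backbones.rec_vit import DropPath  # assumed module path

dp = DropPath(drop_prob=0.2)
dp.train()
x = paddle.ones([8, 4, 16])
y = dp(x)  # roughly 20% of the 8 samples are zeroed; survivors are scaled by 1 / 0.8
dp.eval()
assert paddle.allclose(dp(x), x)  # identity at inference time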

@@ -22,7 +22,7 @@
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
-from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
+from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR, trunc_normal_, zeros_
from .rec_ctc_head import CTCHead
from .rec_sar_head import SARHead
from .rec_nrtr_head import Transformer

@@ -41,12 +41,28 @@ class FCTranspose(nn.Layer):
        else:
            return self.fc(x.transpose([0, 2, 1]))


+class AddPos(nn.Layer):
+    def __init__(self, dim, w):
+        super().__init__()
+        self.dec_pos_embed = self.create_parameter(
+            shape=[1, w, dim], default_initializer=zeros_)
+        self.add_parameter("dec_pos_embed", self.dec_pos_embed)
+        trunc_normal_(self.dec_pos_embed)
+
+    def forward(self, x):
+        # add a learnable positional table, sliced to the input sequence length
+        x = x + self.dec_pos_embed[:, :paddle.shape(x)[1], :]
+        return x
+
+
class MultiHead(nn.Layer):
    def __init__(self, in_channels, out_channels_list, **kwargs):
        super().__init__()
        self.head_list = kwargs.pop('head_list')
+        self.use_pool = kwargs.get('use_pool', False)
+        self.use_pos = kwargs.get('use_pos', False)
        self.in_channels = in_channels
+        if self.use_pool:
+            self.pool = nn.AvgPool2D(kernel_size=[3, 2], stride=[3, 2], padding=0)
        self.gtc_head = 'sar'
        assert len(self.head_list) >= 2
        for idx, head_name in enumerate(self.head_list):
@@ -61,8 +77,13 @@ class MultiHead(nn.Layer):
                max_text_length = gtc_args.get('max_text_length', 25)
                nrtr_dim = gtc_args.get('nrtr_dim', 256)
                num_decoder_layers = gtc_args.get('num_decoder_layers', 4)
-                self.before_gtc = nn.Sequential(
-                    nn.Flatten(2), FCTranspose(in_channels, nrtr_dim))
+                if self.use_pos:
+                    self.before_gtc = nn.Sequential(
+                        nn.Flatten(2), FCTranspose(in_channels, nrtr_dim), AddPos(nrtr_dim, 80))
+                else:
+                    self.before_gtc = nn.Sequential(
+                        nn.Flatten(2), FCTranspose(in_channels, nrtr_dim))
                self.gtc_head = Transformer(
                    d_model=nrtr_dim,
                    nhead=nrtr_dim // 32,
@@ -88,7 +109,8 @@
                '{} is not supported in MultiHead yet'.format(name))

    def forward(self, x, targets=None):
+        if self.use_pool:
+            x = self.pool(x.reshape([0, 3, -1, self.in_channels]).transpose([0, 3, 1, 2]))
        ctc_encoder = self.ctc_encoder(x)
        ctc_out = self.ctc_head(ctc_encoder, targets)
        head_out = dict()
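
The new AddPos layer is just a learnable positional table sliced to the input length, applied ahead of the NRTR decoder; a minimal sketch of its behavior (the module path ppocr/modeling/heads/rec_multi_head.py is assumed from the imports above):

import paddle
from ppocr.modeling.heads.rec_multi_head import AddPos  # assumed module path

pos = AddPos(dim=512, w=80)        # mirrors AddPos(nrtr_dim, 80) in __init__
feat = paddle.randn([2, 64, 512])  # [batch, seq_len, dim] with seq_len <= w
out = pos(feat)                    # a [1, 80, 512] table, sliced to 64, is added
print(out.shape)                   # [2, 64, 512]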

ppocr/utils/utility.py 100755 → 100644

@@ -57,23 +57,35 @@ def _check_image_file(path):
    return any([path.lower().endswith(e) for e in img_end])


-def get_image_file_list(img_file):
+def get_image_file_list(img_file, infer_list=None):
    imgs_lists = []
-    if img_file is None or not os.path.exists(img_file):
-        raise Exception("not found any img file in {}".format(img_file))
+    if infer_list and not os.path.exists(infer_list):
+        raise Exception("not found infer list {}".format(infer_list))
+    if infer_list:
+        with open(infer_list, "r") as f:
+            lines = f.readlines()
+            for line in lines:
+                image_path = line.strip().split("\t")[0]
+                image_path = os.path.join(img_file, image_path)
+                imgs_lists.append(image_path)
+    else:
+        if img_file is None or not os.path.exists(img_file):
+            raise Exception("not found any img file in {}".format(img_file))

-    img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
-    if os.path.isfile(img_file) and _check_image_file(img_file):
-        imgs_lists.append(img_file)
-    elif os.path.isdir(img_file):
-        for single_file in os.listdir(img_file):
-            file_path = os.path.join(img_file, single_file)
-            if os.path.isfile(file_path) and _check_image_file(file_path):
-                imgs_lists.append(file_path)
+        img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
+        if os.path.isfile(img_file) and _check_image_file(img_file):
+            imgs_lists.append(img_file)
+        elif os.path.isdir(img_file):
+            for single_file in os.listdir(img_file):
+                file_path = os.path.join(img_file, single_file)
+                if os.path.isfile(file_path) and _check_image_file(file_path):
+                    imgs_lists.append(file_path)
    if len(imgs_lists) == 0:
        raise Exception("not found any img file in {}".format(img_file))
    imgs_lists = sorted(imgs_lists)
    return imgs_lists


def binarize_img(img):
    if len(img.shape) == 3 and img.shape[2] == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # conversion to grayscale image
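
With this change, get_image_file_list can be driven by a label file instead of a directory scan; a minimal usage sketch reusing the paths from the config above (hypothetical data layout, where each line of the list file is "relative/path<TAB>label"):

from ppocr.utils.utility import get_image_file_list

# Old behavior, unchanged: scan a file or directory for images.
files = get_image_file_list("./train_data")

# New behavior: take the first tab-separated column of each line in the list
# file and join it onto img_file. tools/infer_rec.py uses this when
# Global.infer_list is set (see the final hunk below).
files = get_image_file_list("./train_data", infer_list="./train_data/val_list.txt")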

@@ -118,9 +118,11 @@ def main():
            os.makedirs(os.path.dirname(save_res_path))

    model.eval()
+    infer_imgs = config['Global']['infer_img']
+    infer_list = config['Global'].get('infer_list', None)
    with open(save_res_path, "w") as fout:
-        for file in get_image_file_list(config['Global']['infer_img']):
+        for file in get_image_file_list(infer_imgs, infer_list=infer_list):
            logger.info("infer_img: {}".format(file))
            with open(file, 'rb') as f:
                img = f.read()