# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function try: from collections import Callable except: from collections.abc import Callable import numpy as np import paddle from paddle import nn from paddle.nn import functional as F from ppocr.modeling.heads.rec_nrtr_head import Embeddings from ppocr.modeling.backbones.rec_svtrnet import DropPath, Identity, trunc_normal_, zeros_, ones_, Mlp class Attention(nn.Layer): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim**-0.5 self.q = nn.Linear(dim, dim, bias_attr=qkv_bias) self.kv = nn.Linear(dim, dim * 2, bias_attr=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, q, kv): N, C = kv.shape[1:] QN = q.shape[1] q = self.q(q).reshape( [-1, QN, self.num_heads, C // self.num_heads]).transpose( [0, 2, 1, 3]) k, v = self.kv(kv).reshape( [-1, N, 2, self.num_heads, C // self.num_heads]).transpose( (2, 0, 3, 1, 4)) attn = q.matmul(k.transpose((0, 1, 3, 2))) * self.scale attn = F.softmax(attn, axis=-1) attn = self.attn_drop(attn) x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, QN, C)) x = self.proj(x) x = self.proj_drop(x) return x class EdgeDecoderLayer(nn.Layer): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=[0., 0.], act_layer=nn.GELU, norm_layer='nn.LayerNorm', epsilon=1e-6): super().__init__() self.head_dim = dim // num_heads self.scale = qk_scale or self.head_dim**-0.5 # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path[0]) if drop_path[ 0] > 0. else Identity() self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) self.p = nn.Linear(dim, dim) self.cv = nn.Linear(dim, dim) self.pv = nn.Linear(dim, dim) self.dim = dim self.num_heads = num_heads self.p_proj = nn.Linear(dim, dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp_ratio = mlp_ratio self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, p, cv, pv): pN = p.shape[1] vN = cv.shape[1] p_shortcut = p p1 = self.p(p).reshape( [-1, pN, self.num_heads, self.dim // self.num_heads]).transpose( [0, 2, 1, 3]) cv1 = self.cv(cv).reshape( [-1, vN, self.num_heads, self.dim // self.num_heads]).transpose( [0, 2, 1, 3]) pv1 = self.pv(pv).reshape( [-1, vN, self.num_heads, self.dim // self.num_heads]).transpose( [0, 2, 1, 3]) edge = F.softmax(p1.matmul(pv1.transpose((0, 1, 3, 2))), -1) # B h N N p_c = (edge @cv1).transpose((0, 2, 1, 3)).reshape((-1, pN, self.dim)) x1 = self.norm1(p_shortcut + self.drop_path1(self.p_proj(p_c))) x = self.norm2(x1 + self.drop_path1(self.mlp(x1))) return x class DecoderLayer(nn.Layer): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer='nn.LayerNorm', epsilon=1e-6): super().__init__() if isinstance(norm_layer, str): self.norm1 = eval(norm_layer)(dim, epsilon=epsilon) self.normkv = eval(norm_layer)(dim, epsilon=epsilon) elif isinstance(norm_layer, Callable): self.norm1 = norm_layer(dim) self.normkv = norm_layer(dim) else: raise TypeError( "The norm_layer must be str or paddle.nn.LayerNorm class") self.mixer = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() if isinstance(norm_layer, str): self.norm2 = eval(norm_layer)(dim, epsilon=epsilon) elif isinstance(norm_layer, Callable): self.norm2 = norm_layer(dim) else: raise TypeError( "The norm_layer must be str or paddle.nn.layer.Layer class") mlp_hidden_dim = int(dim * mlp_ratio) self.mlp_ratio = mlp_ratio self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, q, kv): x1 = self.norm1(q + self.drop_path(self.mixer(q, kv))) x = self.norm2(x1 + self.drop_path(self.mlp(x1))) return x class CPPDHead(nn.Layer): def __init__(self, in_channels, dim, out_channels, num_layer=2, drop_path_rate=0.1, max_len=25, vis_seq=50, ch=False, **kwargs): super(CPPDHead, self).__init__() self.out_channels = out_channels # none + 26 + 10 self.dim = dim self.ch = ch self.max_len = max_len + 1 # max_len + eos self.char_node_embed = Embeddings( d_model=dim, vocab=self.out_channels, scale_embedding=True) self.pos_node_embed = Embeddings( d_model=dim, vocab=self.max_len, scale_embedding=True) dpr = np.linspace(0, drop_path_rate, num_layer + 1) self.char_node_decoder = nn.LayerList([ DecoderLayer( dim=dim, num_heads=dim // 32, mlp_ratio=4.0, qkv_bias=True, drop_path=dpr[i]) for i in range(num_layer) ]) self.pos_node_decoder = nn.LayerList([ DecoderLayer( dim=dim, num_heads=dim // 32, mlp_ratio=4.0, qkv_bias=True, drop_path=dpr[i]) for i in range(num_layer) ]) self.edge_decoder = EdgeDecoderLayer( dim=dim, num_heads=dim // 32, mlp_ratio=4.0, qkv_bias=True, drop_path=dpr[num_layer:num_layer + 1]) self.char_pos_embed = self.create_parameter( shape=[1, self.max_len, dim], default_initializer=zeros_) self.add_parameter("char_pos_embed", self.char_pos_embed) self.vis_pos_embed = self.create_parameter( shape=[1, vis_seq, dim], default_initializer=zeros_) self.add_parameter("vis_pos_embed", self.vis_pos_embed) self.char_node_fc1 = nn.Linear(dim, max_len) self.pos_node_fc1 = nn.Linear(dim, self.max_len) self.edge_fc = nn.Linear(dim, self.out_channels) trunc_normal_(self.char_pos_embed) trunc_normal_(self.vis_pos_embed) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: zeros_(m.bias) elif isinstance(m, nn.LayerNorm): zeros_(m.bias) ones_(m.weight) def forward(self, x, targets=None, epoch=0): if self.training: return self.forward_train(x, targets, epoch) else: return self.forward_test(x) def forward_test(self, x): visual_feats = x + self.vis_pos_embed bs = visual_feats.shape[0] pos_node_embed = self.pos_node_embed(paddle.arange( self.max_len)).unsqueeze(0) + self.char_pos_embed pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1]) char_vis_node_query = visual_feats pos_vis_node_query = paddle.concat([pos_node_embed, visual_feats], 1) for char_decoder_layer, pos_decoder_layer in zip(self.char_node_decoder, self.pos_node_decoder): char_vis_node_query = char_decoder_layer(char_vis_node_query, char_vis_node_query) pos_vis_node_query = pos_decoder_layer( pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :]) pos_node_query = pos_vis_node_query[:, :self.max_len, :] char_vis_feats = char_vis_node_query pos_node_feats = self.edge_decoder(pos_node_query, char_vis_feats, char_vis_feats) # B, 26, dim edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37 edge_logits = F.softmax(edge_feats, -1) return edge_logits def forward_train(self, x, targets=None, epoch=0): visual_feats = x + self.vis_pos_embed bs = visual_feats.shape[0] if self.ch: char_node_embed = self.char_node_embed(targets[-2]) else: char_node_embed = self.char_node_embed( paddle.arange(self.out_channels)).unsqueeze(0) char_node_embed = paddle.tile(char_node_embed, [bs, 1, 1]) counting_char_num = paddle.shape(char_node_embed)[1] pos_node_embed = self.pos_node_embed(paddle.arange( self.max_len)).unsqueeze(0) + self.char_pos_embed pos_node_embed = paddle.tile(pos_node_embed, [bs, 1, 1]) node_feats = [] char_vis_node_query = paddle.concat([char_node_embed, visual_feats], 1) pos_vis_node_query = paddle.concat([pos_node_embed, visual_feats], 1) for char_decoder_layer, pos_decoder_layer in zip(self.char_node_decoder, self.pos_node_decoder): char_vis_node_query = char_decoder_layer( char_vis_node_query, char_vis_node_query[:, counting_char_num:, :]) pos_vis_node_query = pos_decoder_layer( pos_vis_node_query, pos_vis_node_query[:, self.max_len:, :]) char_node_query = char_vis_node_query[:, :counting_char_num, :] pos_node_query = pos_vis_node_query[:, :self.max_len, :] char_vis_feats = char_vis_node_query[:, counting_char_num:, :] char_node_feats1 = self.char_node_fc1(char_node_query) pos_node_feats1 = self.pos_node_fc1(pos_node_query) diag_mask = paddle.eye(pos_node_feats1.shape[1]).unsqueeze(0).tile( [pos_node_feats1.shape[0], 1, 1]) pos_node_feats1 = (pos_node_feats1 * diag_mask).sum(-1) node_feats.append(char_node_feats1) node_feats.append(pos_node_feats1) pos_node_feats = self.edge_decoder(pos_node_query, char_vis_feats, char_vis_feats) # B, 26, dim edge_feats = self.edge_fc(pos_node_feats) # B, 26, 37 return node_feats, edge_feats