# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is adapted from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from itertools import repeat
import collections
import math
from functools import partial

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppocr.modeling.backbones.rec_resnetv2 import (
    ResNetV2,
    StdConv2dSame,
    DropPath,
    get_padding,
)
from paddle.nn.initializer import (
    TruncatedNormal,
    Constant,
    Normal,
    KaimingUniform,
    XavierUniform,
)

normal_ = Normal(mean=0.0, std=1e-6)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
kaiming_normal_ = KaimingUniform(nonlinearity="relu")
trunc_normal_ = TruncatedNormal(std=0.02)
xavier_uniform_ = XavierUniform()


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


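# Tuple-broadcast helpers, e.g. to_2tuple(7) -> (7, 7); iterables pass through unchanged.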
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


class Conv2dAlign(nn.Conv2D):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2
    """

    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        eps=1e-6,
    ):
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias,
            weight_attr=True,
        )
        self.eps = eps

    def forward(self, x):
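        # Apply a plain convolution with this layer's weight and bias; stride,
        # padding, dilation and groups come from the parent nn.Conv2D attributes.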
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            self._stride,
            self._padding,
            self._dilation,
            self._groups,
        )
        return x


class HybridEmbed(nn.Layer):
    """CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """

    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Layer)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
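        # NOTE: feature_dim, feature_size and patch_size below are hard-coded to the
        # output of the ResNetV2 backbone configured in HybridTransformer (a
        # 1024-channel feature map on a 42 x 12 grid, projected patch-wise with a
        # 1x1 convolution); the feature_size and patch_size arguments are ignored here.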
        feature_dim = 1024
        feature_size = (42, 12)
        patch_size = (1, 1)
        assert (
            feature_size[0] % patch_size[0] == 0
            and feature_size[1] % patch_size[1] == 0
        )

        self.grid_size = (
            feature_size[0] // patch_size[0],
            feature_size[1] // patch_size[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2D(
            feature_dim,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            weight_attr=True,
            bias_attr=True,
        )

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
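        # [B, feature_dim, H', W'] -> [B, embed_dim, H', W'] -> [B, H' * W', embed_dim]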
        x = self.proj(x).flatten(2).transpose([0, 2, 1])

        return x


class myLinear(nn.Linear):
    def __init__(self, in_channel, out_channels, weight_attr=True, bias_attr=True):
        super().__init__(
            in_channel, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
        )

    def forward(self, x):
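        # Explicit affine map: x @ W^T + b (note the transpose of the stored weight).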
        return paddle.matmul(x, self.weight, transpose_y=True) + self.bias


class Attention(nn.Layer):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = myLinear(dim, dim, weight_attr=True, bias_attr=True)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
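        # Joint q/k/v projection, reshaped to [3, B, num_heads, N, head_dim].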
        qkv = (
            self.qkv(x)
            .reshape([B, N, 3, self.num_heads, C // self.num_heads])
            .transpose([2, 0, 3, 1, 4])
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

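        # Scaled dot-product attention: softmax(q @ k^T * scale) applied to v.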
        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale

        attn = F.softmax(attn, axis=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Layer):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Block(nn.Layer):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x):
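        # Pre-norm residual block: x + Attn(LN(x)), then x + MLP(LN(x)).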
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class HybridTransformer(nn.Layer):
    """Implementation of HybridTransformer.

    Args:
        x: input images with shape [N, 1, H, W]
        label: LaTeX-OCR labels with shape [N, L], where L is the max sequence length
        attention_mask: LaTeX-OCR attention mask with shape [N, L], where L is the max sequence length

    Returns:
        The encoded token sequence with shape [N, 1 + (H // 16) * (W // 16), num_classes];
        during training the label and attention_mask are returned alongside it.
    """

    def __init__(
        self,
        backbone_layers=[2, 3, 7],
        input_channel=1,
        is_predict=False,
        is_export=False,
        img_size=(224, 224),
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        representation_size=None,
        distilled=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        embed_layer=None,
        norm_layer=None,
        act_layer=None,
        weight_init="",
        **kwargs,
    ):
        super(HybridTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
        act_layer = act_layer or nn.GELU
        self.height, self.width = img_size
        self.patch_size = patch_size
        backbone = ResNetV2(
            layers=backbone_layers,
            num_classes=0,
            global_pool="",
            in_chans=input_channel,
            preact=False,
            stem_type="same",
            conv_layer=StdConv2dSame,
            is_export=is_export,
        )
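        # 2 ** (len(backbone_layers) + 1) is the overall downsampling factor of the
        # ResNetV2 backbone (16x for the default 3-stage config), so patch_size is
        # re-expressed relative to the backbone's output resolution.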
        min_patch_size = 2 ** (len(backbone_layers) + 1)
        self.patch_embed = HybridEmbed(
            img_size=img_size,
            patch_size=patch_size // min_patch_size,
            in_chans=input_channel,
            embed_dim=embed_dim,
            backbone=backbone,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = paddle.create_parameter([1, 1, embed_dim], dtype="float32")
        self.dist_token = (
            paddle.create_parameter(
                [1, 1, embed_dim],
                dtype="float32",
            )
            if distilled
            else None
        )
        self.pos_embed = paddle.create_parameter(
            [1, num_patches + self.num_tokens, embed_dim], dtype="float32"
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        zeros_(self.cls_token)
        if self.dist_token is not None:
            zeros_(self.dist_token)
        zeros_(self.pos_embed)

        dpr = [
            x.item() for x in paddle.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(
                ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh())
            )
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = (
            nn.Linear(self.num_features, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.head_dist = None
        if distilled:
            self.head_dist = (
                nn.Linear(self.embed_dim, self.num_classes)
                if num_classes > 0
                else nn.Identity()
            )
        self.init_weights(weight_init)
        self.out_channels = embed_dim
        self.is_predict = is_predict
        self.is_export = is_export

    def init_weights(self, mode=""):
        assert mode in ("jax", "jax_nlhb", "nlhb", "")
        head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
        trunc_normal_(self.pos_embed)
        trunc_normal_(self.cls_token)
        self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    def load_pretrained(self, checkpoint_path, prefix=""):
        raise NotImplementedError

    def no_weight_decay(self):
        return {"pos_embed", "cls_token", "dist_token"}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=""):
        self.num_classes = num_classes
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )
        if self.num_tokens == 2:
            self.head_dist = (
                nn.Linear(self.embed_dim, self.num_classes)
                if num_classes > 0
                else nn.Identity()
            )

    def forward_features(self, x):
        B, c, h, w = x.shape
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(
            [B, -1, -1]
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = paddle.concat((cls_tokens, x), axis=1)
        h, w = h // self.patch_size, w // self.patch_size
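        # Positional embeddings are laid out for the maximum grid of
        # self.width // patch_size columns; for a narrower input of w columns, skip
        # the unused trailing columns of each of the h rows when indexing.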
        repeat_tensor = (
            paddle.arange(h) * (self.width // self.patch_size - w)
        ).reshape([-1, 1])
        repeat_tensor = paddle.repeat_interleave(
            repeat_tensor, paddle.to_tensor(w), axis=1
        ).reshape([-1])
        pos_emb_ind = repeat_tensor + paddle.arange(h * w)
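        # Shift patch indices by one and prepend index 0 for the cls token.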
        pos_emb_ind = paddle.concat(
            (paddle.zeros([1], dtype="int64"), pos_emb_ind + 1), axis=0
        ).cast(paddle.int64)
        x += self.pos_embed[:, pos_emb_ind]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x

    def forward(self, input_data):
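        # During training input_data is (images, label, attention_mask); at inference
        # it is either a bare image tensor or a single-element list.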
        if self.training:
            x, label, attention_mask = input_data
        else:
            if isinstance(input_data, list):
                x = input_data[0]
            else:
                x = input_data
        x = self.forward_features(x)
        x = self.head(x)
        if self.training:
            return x, label, attention_mask
        else:
            return x


def _init_vit_weights(
    module: nn.Layer, name: str = "", head_bias: float = 0.0, jax_impl: bool = False
):
    """ViT weight initialization
    * When called without name, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid name (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(module, nn.Linear):
        if name.startswith("head"):
            zeros_(module.weight)
            constant_ = Constant(value=head_bias)
            constant_(module.bias)
        elif name.startswith("pre_logits"):
            zeros_(module.bias)
        else:
            if jax_impl:
                xavier_uniform_(module.weight)
                if module.bias is not None:
                    if "mlp" in name:
                        normal_(module.bias)
                    else:
                        zeros_(module.bias)
            else:
                trunc_normal_(module.weight)
                if module.bias is not None:
                    zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2D):
        # NOTE conv was left to pytorch default in my original init
        if module.bias is not None:
            zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2D)):
        zeros_(module.bias)
        ones_(module.weight)
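

# A minimal usage sketch (illustrative only; the input size below is an assumption,
# chosen so that 192 / 16 x 672 / 16 = 12 x 42 matches the 42 x 12 feature grid
# hard-coded in HybridEmbed):
#
#     model = HybridTransformer(input_channel=1, img_size=(192, 672), num_classes=0)
#     model.eval()
#     feats = model(paddle.rand([1, 1, 192, 672]))  # -> [1, 1 + 504, embed_dim]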