274 lines
8.2 KiB
Python
274 lines
8.2 KiB
Python
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from paddle import ParamAttr
|
|
from paddle.nn.initializer import KaimingNormal
|
|
import numpy as np
|
|
import paddle
|
|
import paddle.nn as nn
|
|
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
|
|
|
|
# Module-level initializer objects. trunc_normal_, zeros_ and ones_ are
# applied to parameters in ViT (pos_embed and ViT._init_weights).
# Truncated-normal initializer with std=0.02.
trunc_normal_ = TruncatedNormal(std=0.02)
# NOTE: this is the Normal class itself, not an instance like the others;
# it is not referenced by the visible code.
normal_ = Normal
# Constant initializers that fill a parameter with 0.0 / 1.0.
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
|
|
|
|
|
|
def drop_path(x, drop_prob=0.0, training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    A random 0/1 mask with one entry per index of axis 0 is drawn; entries
    are 1 with probability ``1 - drop_prob``. The input is divided by
    ``1 - drop_prob`` and multiplied by that mask.

    The name 'Drop Connect' sometimes used for this is misleading, as that is
    a different form of dropout from a separate paper.
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: Input tensor. The mask has shape ``(x.shape[0], 1, ..., 1)`` so it
            broadcasts over all remaining axes.
        drop_prob: Probability of zeroing an entry of axis 0. ``None`` and
            ``0`` both disable dropping.
        training: Dropping is applied only when this is true.

    Returns:
        ``x`` itself when dropping is disabled, otherwise the rescaled and
        masked tensor.
    """
    # ``not drop_prob`` also covers None (the default of DropPath), which
    # would otherwise fail in ``1 - drop_prob`` below.
    if not drop_prob or not training:
        return x
    # Create keep_prob in x's dtype so the mask arithmetic never mixes dtypes
    # (drop_prob may arrive as a Python float or a numpy scalar).
    keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else to 0.
    mask = paddle.floor(keep_prob + paddle.rand(mask_shape, dtype=x.dtype))
    return x.divide(keep_prob) * mask
|
|
|
|
|
|
class DropPath(nn.Layer):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Layer wrapper around ``drop_path`` that supplies the layer's own
    ``training`` flag.

    Args:
        drop_prob: Drop probability forwarded to ``drop_path``. ``None``
            (the default) is treated as "never drop".
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Stored exactly as given so existing readers of the attribute see
        # the same value; None is normalised at call time instead.
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``drop_path(x)`` using this layer's rate and training flag."""
        # drop_path computes ``1 - drop_prob``; passing None through would
        # raise a TypeError in training mode, so map a missing rate to 0.0.
        return drop_path(x, self.drop_prob or 0.0, self.training)
|
|
|
|
|
|
class Identity(nn.Layer):
    """Pass-through layer: ``forward`` hands back its argument untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` as is."""
        passthrough = input
        return passthrough
|
|
|
|
|
|
class Mlp(nn.Layer):
    """Feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Input size of the first linear layer.
        hidden_features: Output size of the first linear layer; a falsy value
            falls back to ``in_features``.
        out_features: Output size of the second linear layer; a falsy value
            falls back to ``in_features``.
        act_layer: Zero-argument callable that builds the activation layer.
        drop: Rate handed to ``nn.Dropout``. One dropout layer is created and
            applied twice in ``forward``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # Sublayers are created in the order fc1, act, fc2, drop.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through fc1, activation, dropout, fc2 and dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
|
|
class Attention(nn.Layer):
    """Multi-head self-attention built from one fused q/k/v projection.

    Args:
        dim: Feature size of the input and of the output projection.
        num_heads: Number of heads; each head works on ``dim // num_heads``
            features.
        qkv_bias: Passed as ``bias_attr`` to the fused q/k/v linear layer.
        qk_scale: Factor applied to the queries; a falsy value falls back to
            ``(dim // num_heads) ** -0.5``.
        attn_drop: Dropout rate applied to the softmax output.
        proj_drop: Dropout rate applied after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        per_head = dim // num_heads
        if qk_scale:
            self.scale = qk_scale
        else:
            self.scale = per_head**-0.5

        # Sublayers are created in the order qkv, attn_drop, proj, proj_drop.
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` and return the projected result.

        Assumes ``x`` is laid out as (batch, tokens, dim) - TODO confirm
        against callers; the reshapes below fix only the last axis to ``dim``.
        """
        per_head = self.dim // self.num_heads
        fused = self.qkv(x)
        # Split the fused projection into (3, heads, per_head) and move the
        # axis of size 3 to the front so it can be indexed as q / k / v.
        fused = paddle.reshape(fused, (0, -1, 3, self.num_heads, per_head))
        fused = fused.transpose((2, 0, 3, 1, 4))
        query = fused[0] * self.scale
        key = fused[1]
        value = fused[2]

        scores = query.matmul(key.transpose((0, 1, 3, 2)))
        weights = self.attn_drop(nn.functional.softmax(scores, axis=-1))

        # Swap the head and token axes back, then merge the heads into dim.
        merged = weights.matmul(value).transpose((0, 2, 1, 3))
        merged = merged.reshape((0, -1, self.dim))
        return self.proj_drop(self.proj(merged))
|
|
|
|
|
|
class Block(nn.Layer):
    """Transformer block: an Attention mixer and an Mlp, each with a residual.

    Args:
        dim: Feature size used by the norms, the mixer and the MLP.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP hidden size is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to ``Attention``.
        qk_scale: Forwarded to ``Attention``.
        drop: Dropout rate for the attention output projection and the MLP.
        attn_drop: Dropout rate for the attention weights.
        drop_path: Rate for ``DropPath``; values not greater than 0 use
            ``Identity`` instead.
        act_layer: Activation layer factory handed to ``Mlp``.
        norm_layer: Either a string that is ``eval``-ed to a layer class and
            called with ``(dim, epsilon=epsilon)``, or a callable that is
            called with ``(dim)`` only.
        epsilon: Used only when ``norm_layer`` is a string.
        prenorm: Selects the residual layout in ``forward``. When true, each
            norm is applied to the residual sum (after the sub-layer); when
            false, each norm is applied to the sub-layer input.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer="nn.LayerNorm",
        epsilon=1e-6,
        prenorm=True,
    ):
        super().__init__()
        norm_by_name = isinstance(norm_layer, str)
        # NOTE(security): eval() runs the norm_layer string as code; only
        # pass values from trusted configuration.
        self.norm1 = (
            eval(norm_layer)(dim, epsilon=epsilon) if norm_by_name else norm_layer(dim)
        )
        self.mixer = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = Identity()
        self.norm2 = (
            eval(norm_layer)(dim, epsilon=epsilon) if norm_by_name else norm_layer(dim)
        )
        hidden_width = int(dim * mlp_ratio)
        self.mlp_ratio = mlp_ratio
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=hidden_width,
            act_layer=act_layer,
            drop=drop,
        )
        self.prenorm = prenorm

    def forward(self, x):
        """Apply the mixer and the MLP with residuals, per ``self.prenorm``."""
        if not self.prenorm:
            # Normalise the input of each sub-layer, then add the residual.
            x = x + self.drop_path(self.mixer(self.norm1(x)))
            return x + self.drop_path(self.mlp(self.norm2(x)))
        # Add the residual first, then normalise the sum.
        x = self.norm1(x + self.drop_path(self.mixer(x)))
        return self.norm2(x + self.drop_path(self.mlp(x)))
|
|
|
|
|
|
class ViT(nn.Layer):
    """Vision-transformer backbone: patch embedding, ``depth`` Blocks, pooled head.

    Args:
        img_size: Accepted for configuration compatibility; not referenced by
            this class (the positional table has a fixed 257 slots).
        patch_size: Used as both kernel size and stride of the patch
            embedding convolution.
        in_channels: Input channels of the patch embedding convolution.
        embed_dim: Token feature size; also exposed as ``out_channels``.
        depth: Number of ``Block`` layers.
        num_heads: Attention heads per block.
        mlp_ratio: Forwarded to every ``Block``.
        qkv_bias: Forwarded to every ``Block``.
        qk_scale: Forwarded to every ``Block``.
        drop_rate: Dropout rate after the positional embedding and inside
            the blocks.
        attn_drop_rate: Attention dropout rate inside the blocks.
        drop_path_rate: Final value of the per-block drop-path rate, which is
            spaced linearly from 0 over the blocks.
        norm_layer: String ``eval``-ed to a layer class, or a layer class /
            factory (handled the same way ``Block`` handles it).
        epsilon: Epsilon for norms built from a string ``norm_layer``.
        act: String ``eval``-ed to an activation layer class, or the class
            itself.
        prenorm: Forwarded to every ``Block``; when false a final norm is
            created and applied after the blocks.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        # The list defaults are never mutated here; they stay lists so the
        # signature is unchanged for existing callers.
        img_size=[32, 128],
        patch_size=[4, 4],
        in_channels=3,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer="nn.LayerNorm",
        epsilon=1e-6,
        act="nn.GELU",
        prenorm=False,
        **kwargs,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.out_channels = embed_dim
        self.prenorm = prenorm
        self.patch_embed = nn.Conv2D(
            in_channels, embed_dim, patch_size, patch_size, padding=(0, 0)
        )
        # NOTE(review): 257 slots are hard-coded; forward() skips slot 0 and
        # uses the remaining 256 - confirm this matches the input/patch sizes.
        self.pos_embed = self.create_parameter(
            shape=[1, 257, embed_dim], default_initializer=zeros_
        )
        self.add_parameter("pos_embed", self.pos_embed)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # NOTE(security): eval() runs the act / norm_layer strings as code;
        # only pass values from trusted configuration.
        # Resolved once instead of once per block; classes pass through
        # unchanged so ViT accepts the same kinds of values as Block.
        act_layer = eval(act) if isinstance(act, str) else act
        drop_path_rates = np.linspace(0, drop_path_rate, depth)
        self.blocks1 = nn.LayerList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    act_layer=act_layer,
                    attn_drop=attn_drop_rate,
                    drop_path=rate,
                    norm_layer=norm_layer,
                    epsilon=epsilon,
                    prenorm=prenorm,
                )
                for rate in drop_path_rates
            ]
        )
        if not prenorm:
            # Mirror Block: strings are eval-ed and receive epsilon, while a
            # class/factory is called with the feature size only.
            if isinstance(norm_layer, str):
                self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
            else:
                self.norm = norm_layer(embed_dim)

        self.avg_pool = nn.AdaptiveAvgPool2D([1, 25])
        self.last_conv = nn.Conv2D(
            in_channels=embed_dim,
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias_attr=False,
        )
        self.hardswish = nn.Hardswish()
        self.dropout = nn.Dropout(p=0.1, mode="downscale_in_infer")

        trunc_normal_(self.pos_embed)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``m`` if it is a Linear or LayerNorm layer.

        Linear weights get the truncated-normal initializer and a present
        bias is zeroed; LayerNorm gets zero bias and unit weight. Other layer
        types are left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        """Embed patches, run the blocks and return the pooled feature map."""
        # Conv output is flattened from axis 2 on and the last two axes are
        # swapped, giving one row per patch with embed_dim features.
        x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
        # Slot 0 of the positional table is skipped.
        # NOTE(review): this addition needs the token count to match the 256
        # remaining slots, while the reshape below needs it to be a multiple
        # of 25 - confirm which input/patch sizes are actually used.
        x = x + self.pos_embed[:, 1:, :]  # [:, :x.shape[1], :]
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)

        # Back to channels-first with a trailing axis of 25, then pooled to
        # the [1, 25] output size configured on avg_pool.
        x = self.avg_pool(x.transpose([0, 2, 1]).reshape([0, self.embed_dim, -1, 25]))
        x = self.last_conv(x)
        x = self.hardswish(x)
        x = self.dropout(x)
        return x
|