merge CAE
parent
a56366a9d1
commit
7b50ce6585
|
@ -69,6 +69,7 @@ from .model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG
|
|||
from .model_zoo.van import VAN_tiny
|
||||
from .model_zoo.peleenet import PeleeNet
|
||||
from .model_zoo.convnext import ConvNeXt_tiny
|
||||
from .model_zoo.cae import cae_base_patch16_224, cae_base_patch16_384, cae_large_patch16_224, cae_large_patch16_384, cae_large_patch16_512, cae_small_patch16_224
|
||||
|
||||
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
|
||||
from .variant_models.vgg_variant import VGG19Sigmoid
|
||||
|
|
|
@ -0,0 +1,978 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Code was heavily based on https://github.com/PaddlePaddle/VIMER/blob/main/CAE/models/modeling_finetune.py
|
||||
# reference: https://arxiv.org/abs/2202.03026
|
||||
|
||||
import collections
|
||||
from itertools import repeat
|
||||
import math
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0., std=1.):
|
||||
nn.initializer.TruncatedNormal(mean=mean, std=std)(tensor)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float=0., training: bool=False):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0], ) + (1, ) * (
|
||||
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x / keep_prob * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Layer):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return 'p={}'.format(self.drop_prob)
|
||||
|
||||
|
||||
class Mlp(nn.Layer):
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias_attr=True)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias_attr=True)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
# commit this for the orignal BERT implement
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
window_size=None,
|
||||
attn_head_dim=None):
|
||||
super().__init__()
|
||||
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.zeros_ = nn.initializer.Constant(value=0.)
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias_attr=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = self.create_parameter(
|
||||
[all_head_dim], default_initializer=self.zeros_)
|
||||
self.v_bias = self.create_parameter(
|
||||
[all_head_dim], default_initializer=self.zeros_)
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1) + 3
|
||||
self.relative_position_bias_table = self.create_parameter(
|
||||
[self.num_relative_distance, num_heads],
|
||||
default_initializer=self.zeros_) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = paddle.arange(window_size[0])
|
||||
coords_w = paddle.arange(window_size[1])
|
||||
coords = paddle.stack(paddle.meshgrid(
|
||||
[coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:,
|
||||
None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.transpose(
|
||||
[1, 2, 0]) # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[
|
||||
0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = \
|
||||
paddle.zeros((window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index",
|
||||
relative_position_index)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim, bias_attr=True)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
k_bias = paddle.zeros_like(self.v_bias)
|
||||
k_bias.stop_gradient = True
|
||||
qkv_bias = paddle.concat((self.q_bias, k_bias, self.v_bias))
|
||||
# qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, C // self.num_heads]).transpose([2, 0, 3, 1, 4])
|
||||
qkv = F.linear(x=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape([B, N, 3, self.num_heads, -1]).transpose(
|
||||
[2, 0, 3, 1, 4])
|
||||
q, k, v = qkv[0], qkv[1], qkv[
|
||||
2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @k.transpose([0, 1, 3, 2]))
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
relative_position_bias = \
|
||||
self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1, -1]) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.transpose(
|
||||
[2, 0, 1]) # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = F.softmax(attn, axis=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @v).transpose([0, 2, 1, 3]).reshape([B, N, -1])
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
init_values=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
window_size=None,
|
||||
attn_head_dim=None):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
window_size=window_size,
|
||||
attn_head_dim=attn_head_dim)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
if init_values > 0:
|
||||
self.gamma_1 = self.create_parameter(
|
||||
[dim],
|
||||
default_initializer=nn.initializer.Constant(value=init_values))
|
||||
self.gamma_2 = self.create_parameter(
|
||||
[dim],
|
||||
default_initializer=nn.initializer.Constant(value=init_values))
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(
|
||||
self.attn(
|
||||
self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(
|
||||
self.norm1(x), rel_pos_bias=rel_pos_bias))
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Layer):
|
||||
""" Image to Patch Embedding
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
to_2tuple = _ntuple(2)
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
|
||||
patch_size[0])
|
||||
self.patch_shape = (img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1])
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
self.in_chans = in_chans
|
||||
self.out_chans = embed_dim
|
||||
self.proj = nn.Conv2D(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias_attr=True)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).transpose([0, 2, 1])
|
||||
return x
|
||||
|
||||
def _init_weights(self):
|
||||
fan_out = self.out_chans
|
||||
fan_in = self.patch_size[0] * self.patch_size[1] * self.in_chans
|
||||
weight_attr = paddle.ParamAttr(
|
||||
initializer=nn.initializer.XavierUniform(fan_in, fan_out)) # MAE
|
||||
bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
|
||||
return weight_attr, bias_attr
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Layer):
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1) + 3
|
||||
self.zeros_ = nn.initializer.Constant(value=0.)
|
||||
self.relative_position_bias_table = self.create_parameter(
|
||||
[self.num_relative_distance, num_heads],
|
||||
default_initializer=self.zeros_) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = paddle.arange(window_size[0])
|
||||
coords_w = paddle.arange(window_size[1])
|
||||
coords = paddle.stack(paddle.meshgrid(
|
||||
[coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:,
|
||||
None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.transpose(
|
||||
[1, 2, 0]) # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = \
|
||||
paddle.zeros((window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer("relative_position_index",
|
||||
relative_position_index)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = \
|
||||
self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1, -1]) # Wh*Ww,Wh*Ww,nH
|
||||
return relative_position_bias.transpose([2, 0, 1]) # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid, token=False):
|
||||
''' Sinusoid position encoding table '''
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
||||
for hid_j in range(d_hid)
|
||||
]
|
||||
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
if token:
|
||||
sinusoid_table = np.concatenate(
|
||||
[sinusoid_table, np.zeros([1, d_hid])], dim=0)
|
||||
|
||||
return paddle.to_tensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Layer):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
class_num=1000,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=None,
|
||||
use_abs_pos_emb=True,
|
||||
use_rel_pos_bias=False,
|
||||
use_shared_rel_pos_bias=False,
|
||||
use_mean_pooling=True,
|
||||
init_scale=0.001,
|
||||
lin_probe=False,
|
||||
sin_pos_emb=True,
|
||||
args=None):
|
||||
super().__init__()
|
||||
self.class_num = class_num
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.use_mean_pooling = use_mean_pooling
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.zeros_ = nn.initializer.Constant(value=0.)
|
||||
self.ones_ = nn.initializer.Constant(value=1.)
|
||||
|
||||
self.cls_token = self.create_parameter(
|
||||
[1, 1, embed_dim], default_initializer=self.zeros_)
|
||||
|
||||
self.use_abs_pos_emb = use_abs_pos_emb
|
||||
if use_abs_pos_emb:
|
||||
self.pos_embed = self.create_parameter(
|
||||
[1, num_patches + 1, embed_dim],
|
||||
default_initializer=self.zeros_)
|
||||
elif sin_pos_emb:
|
||||
# sine-cosine positional embeddings is on the way
|
||||
self.pos_embed = self.create_parameter(
|
||||
[1, num_patches + 1, embed_dim],
|
||||
default_initializer=self.zeros_)
|
||||
self.pos_embed.set_value(
|
||||
self.build_2d_sincos_position_embedding(embed_dim))
|
||||
self.pos_embed.stop_gradient = True # fixed sin-cos embedding
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(
|
||||
window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
|
||||
dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self.blocks = nn.LayerList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
window_size=self.patch_embed.patch_shape
|
||||
if use_rel_pos_bias else None) for i in range(depth)
|
||||
])
|
||||
self.norm = nn.Identity() if use_mean_pooling else norm_layer(
|
||||
embed_dim)
|
||||
|
||||
self.lin_probe = lin_probe
|
||||
# NOTE: batch norm
|
||||
if lin_probe:
|
||||
# TODO
|
||||
from models.lincls_bn import LP_BatchNorm
|
||||
self.fc_norm = LP_BatchNorm(embed_dim, affine=False)
|
||||
else:
|
||||
if use_mean_pooling:
|
||||
self.fc_norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.fc_norm = None
|
||||
self.head = nn.Linear(embed_dim,
|
||||
class_num) if class_num > 0 else nn.Identity()
|
||||
|
||||
if self.pos_embed is not None and use_abs_pos_emb:
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
# trunc_normal_(self.mask_token, std=.02)
|
||||
trunc_normal_(self.head.weight, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
self.head.weight.set_value(self.head.weight * init_scale)
|
||||
self.head.bias.set_value(self.head.bias * init_scale)
|
||||
|
||||
def build_2d_sincos_position_embedding(self,
|
||||
embed_dim=768,
|
||||
temperature=10000.):
|
||||
h, w = self.patch_embed.patch_shape
|
||||
grid_w = paddle.arange(w, dtype=paddle.float32)
|
||||
grid_h = paddle.arange(h, dtype=paddle.float32)
|
||||
grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
|
||||
assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
|
||||
pos_dim = embed_dim // 4
|
||||
omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
|
||||
omega = 1. / (temperature**omega)
|
||||
out_w = paddle.einsum('m,d->md', grid_w.flatten(), omega)
|
||||
out_h = paddle.einsum('m,d->md', grid_h.flatten(), omega)
|
||||
pos_emb = paddle.concat(
|
||||
[
|
||||
paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h),
|
||||
paddle.cos(out_h)
|
||||
],
|
||||
axis=1)[None, :, :]
|
||||
|
||||
# if not self.use_mean_pooling:
|
||||
pe_token = paddle.zeros([1, 1, embed_dim], dtype=paddle.float32)
|
||||
pos_emb = paddle.concat([pe_token, pos_emb], axis=1)
|
||||
return pos_emb
|
||||
|
||||
def fix_init_weight(self):
|
||||
def rescale(param, layer_id):
|
||||
param.set_value(param / math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
rescale(layer.attn.proj.weight, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight, layer_id + 1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
self.zeros_(m.bias)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
self.zeros_(m.bias)
|
||||
self.ones_(m.weight)
|
||||
|
||||
def get_num_layers(self):
|
||||
return len(self.blocks)
|
||||
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, class_num, global_pool=''):
|
||||
self.class_num = class_num
|
||||
self.head = nn.Linear(self.embed_dim,
|
||||
class_num) if class_num > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x, is_train=True):
|
||||
x = self.patch_embed(x)
|
||||
batch_size, seq_len, _ = x.shape
|
||||
|
||||
cls_tokens = self.cls_token.expand(
|
||||
[batch_size, -1,
|
||||
-1]) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = paddle.concat((cls_tokens, x), axis=1)
|
||||
if self.pos_embed is not None:
|
||||
if self.use_abs_pos_emb:
|
||||
x = x + self.pos_embed.expand(
|
||||
[batch_size, -1, -1]).astype(x.dtype).clone().detach()
|
||||
else:
|
||||
x = x + self.pos_embed.expand(
|
||||
[batch_size, -1, -1]).astype(x.dtype).clone().detach()
|
||||
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rel_pos_bias = self.rel_pos_bias(
|
||||
) if self.rel_pos_bias is not None else None
|
||||
for blk in self.blocks:
|
||||
x = blk(x, rel_pos_bias=rel_pos_bias)
|
||||
|
||||
x = self.norm(x)
|
||||
if self.fc_norm is not None:
|
||||
t = x[:, 1:, :]
|
||||
if self.lin_probe:
|
||||
if self.use_mean_pooling:
|
||||
return self.fc_norm(t.mean(1), is_train=is_train)
|
||||
else:
|
||||
return self.fc_norm(x[:, 0], is_train=is_train)
|
||||
else:
|
||||
return self.fc_norm(t.mean(1))
|
||||
|
||||
else:
|
||||
return x[:, 0]
|
||||
|
||||
def forward(self, x, is_train=True):
|
||||
x = self.forward_features(x, is_train)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _enable_linear_eval(model):
|
||||
zeros_ = nn.initializer.Constant(value=0.)
|
||||
normal_ = nn.initializer.Normal(mean=0.0, std=0.01)
|
||||
linear_keyword = 'head'
|
||||
head_norm = 'fc_norm'
|
||||
requires_grad = []
|
||||
for name, param in model.named_parameters():
|
||||
if name not in [
|
||||
'%s.weight' % linear_keyword, '%s.bias' % linear_keyword
|
||||
] and head_norm not in name:
|
||||
param.stop_gradient = True
|
||||
else:
|
||||
requires_grad.append(name)
|
||||
# init the fc layer
|
||||
normal_(getattr(model, linear_keyword).weight)
|
||||
zeros_(getattr(model, linear_keyword).bias)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _load_pretrained(pretrained,
|
||||
model,
|
||||
model_keys,
|
||||
model_ema_configs,
|
||||
abs_pos_emb,
|
||||
rel_pos_bias,
|
||||
use_ssld=False):
|
||||
checkpoint = paddle.load(pretrained)
|
||||
checkpoint_model = None
|
||||
for model_key in model_keys.split('|'):
|
||||
if model_key in checkpoint:
|
||||
checkpoint_model = checkpoint[model_key]
|
||||
break
|
||||
|
||||
if checkpoint_model is None:
|
||||
checkpoint_model = checkpoint
|
||||
state_dict = model.state_dict()
|
||||
all_keys = list(checkpoint_model.keys())
|
||||
# NOTE: remove all decoder keys
|
||||
all_keys = [key for key in all_keys if key.startswith('encoder.')]
|
||||
for key in all_keys:
|
||||
new_key = key.replace('encoder.', '')
|
||||
checkpoint_model[new_key] = checkpoint_model[key]
|
||||
checkpoint_model.pop(key)
|
||||
|
||||
for key in list(checkpoint_model.keys()):
|
||||
if key.startswith('regressor_and_decoder.'):
|
||||
checkpoint_model.pop(key)
|
||||
if key.startswith('teacher_network.'):
|
||||
checkpoint_model.pop(key)
|
||||
|
||||
# NOTE: replace norm with fc_norm
|
||||
for key in list(checkpoint_model.keys()):
|
||||
if key.startswith('norm.'):
|
||||
new_key = key.replace('norm.', 'fc_norm.')
|
||||
checkpoint_model[new_key] = checkpoint_model[key]
|
||||
checkpoint_model.pop(key)
|
||||
|
||||
for k in ['head.weight', 'head.bias']:
|
||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[
|
||||
k].shape:
|
||||
del checkpoint_model[k]
|
||||
|
||||
if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
|
||||
num_layers = model.get_num_layers()
|
||||
rel_pos_bias = checkpoint_model[
|
||||
"rel_pos_bias.relative_position_bias_table"]
|
||||
for i in range(num_layers):
|
||||
checkpoint_model["blocks.%d.attn.relative_position_bias_table" %
|
||||
i] = rel_pos_bias.clone()
|
||||
|
||||
checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")
|
||||
|
||||
all_keys = list(checkpoint_model.keys())
|
||||
|
||||
for key in all_keys:
|
||||
if "relative_position_index" in key:
|
||||
checkpoint_model.pop(key)
|
||||
|
||||
if "relative_position_bias_table" in key and rel_pos_bias:
|
||||
rel_pos_bias = checkpoint_model[key]
|
||||
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
||||
dst_num_pos, _ = model.state_dict()[key].size()
|
||||
dst_patch_shape = model.patch_embed.patch_shape
|
||||
if dst_patch_shape[0] != dst_patch_shape[1]:
|
||||
raise NotImplementedError()
|
||||
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
|
||||
dst_patch_shape[1] * 2 - 1)
|
||||
src_size = int((src_num_pos - num_extra_tokens)**0.5)
|
||||
dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
|
||||
if src_size != dst_size:
|
||||
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
|
||||
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
|
||||
|
||||
def geometric_progression(a, r, n):
|
||||
return a * (1.0 - r**n) / (1.0 - r)
|
||||
|
||||
left, right = 1.01, 1.5
|
||||
while right - left > 1e-6:
|
||||
q = (left + right) / 2.0
|
||||
gp = geometric_progression(1, q, src_size // 2)
|
||||
if gp > dst_size // 2:
|
||||
right = q
|
||||
else:
|
||||
left = q
|
||||
|
||||
dis = []
|
||||
cur = 1
|
||||
for i in range(src_size // 2):
|
||||
dis.append(cur)
|
||||
cur += q**(i + 1)
|
||||
|
||||
r_ids = [-_ for _ in reversed(dis)]
|
||||
|
||||
x = r_ids + [0] + dis
|
||||
y = r_ids + [0] + dis
|
||||
|
||||
t = dst_size // 2.0
|
||||
dx = np.arange(-t, t + 0.1, 1.0)
|
||||
dy = np.arange(-t, t + 0.1, 1.0)
|
||||
|
||||
all_rel_pos_bias = []
|
||||
|
||||
for i in range(num_attn_heads):
|
||||
z = rel_pos_bias[:, i].view(src_size,
|
||||
src_size).float().numpy()
|
||||
f = interpolate.interp2d(x, y, z, kind='cubic')
|
||||
all_rel_pos_bias.append(
|
||||
paddle.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(
|
||||
rel_pos_bias.device))
|
||||
|
||||
rel_pos_bias = paddle.concat(all_rel_pos_bias, axis=-1)
|
||||
|
||||
new_rel_pos_bias = paddle.concat(
|
||||
(rel_pos_bias, extra_tokens), axis=0)
|
||||
checkpoint_model[key] = new_rel_pos_bias
|
||||
|
||||
# interpolate position embedding
|
||||
if 'pos_embed' in checkpoint_model and abs_pos_emb:
|
||||
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
||||
embedding_size = pos_embed_checkpoint.shape[-1]
|
||||
num_patches = model.patch_embed.num_patches
|
||||
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
||||
# height (== width) for the checkpoint position embedding
|
||||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**
|
||||
0.5)
|
||||
# height (== width) for the new position embedding
|
||||
new_size = int(num_patches**0.5)
|
||||
# class_token and dist_token are kept unchanged
|
||||
if orig_size != new_size:
|
||||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
||||
# only the position tokens are interpolated
|
||||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
||||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
|
||||
embedding_size).permute(0, 3, 1, 2)
|
||||
pos_tokens = paddle.nn.functional.interpolate(
|
||||
pos_tokens,
|
||||
size=(new_size, new_size),
|
||||
mode='bicubic',
|
||||
align_corners=False)
|
||||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
new_pos_embed = paddle.concat((extra_tokens, pos_tokens), axis=1)
|
||||
checkpoint_model['pos_embed'] = new_pos_embed
|
||||
msg = model.set_state_dict(checkpoint_model)
|
||||
|
||||
model_without_ddp = model
|
||||
n_parameters = sum(p.numel() for p in model.parameters()
|
||||
if not p.stop_gradient).item()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def cae_small_patch16_224(**kwargs):
|
||||
config = kwargs.copy()
|
||||
enable_linear_eval = config.pop('enable_linear_eval')
|
||||
model_keys = config.pop('model_key')
|
||||
model_ema_configs = config.pop('model_ema')
|
||||
abs_pos_emb = config.pop('abs_pos_emb')
|
||||
rel_pos_bias = config.pop('rel_pos_bias')
|
||||
pretrained = config.pop('pretrained')
|
||||
|
||||
model = VisionTransformer(
|
||||
patch_size=16,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(
|
||||
nn.LayerNorm, epsilon=1e-6),
|
||||
**kwargs)
|
||||
|
||||
if enable_linear_eval:
|
||||
_enable_linear_eval(model)
|
||||
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
model_keys,
|
||||
model_ema_configs,
|
||||
abs_pos_emb,
|
||||
rel_pos_bias,
|
||||
use_ssld=False)
|
||||
return model
|
||||
|
||||
|
||||
def cae_base_patch16_224(**kwargs):
|
||||
config = kwargs.copy()
|
||||
enable_linear_eval = config.pop('enable_linear_eval')
|
||||
model_keys = config.pop('model_key')
|
||||
model_ema_configs = config.pop('model_ema')
|
||||
abs_pos_emb = config.pop('abs_pos_emb')
|
||||
rel_pos_bias = config.pop('rel_pos_bias')
|
||||
pretrained = config.pop('pretrained')
|
||||
|
||||
model = VisionTransformer(
|
||||
patch_size=16,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(
|
||||
nn.LayerNorm, epsilon=1e-6),
|
||||
**config)
|
||||
|
||||
if enable_linear_eval:
|
||||
_enable_linear_eval(model)
|
||||
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
model_keys,
|
||||
model_ema_configs,
|
||||
abs_pos_emb,
|
||||
rel_pos_bias,
|
||||
use_ssld=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def cae_base_patch16_384(**kwargs):
|
||||
config = kwargs.copy()
|
||||
enable_linear_eval = config.pop('enable_linear_eval')
|
||||
model_keys = config.pop('model_key')
|
||||
model_ema_configs = config.pop('model_ema')
|
||||
abs_pos_emb = config.pop('abs_pos_emb')
|
||||
rel_pos_bias = config.pop('rel_pos_bias')
|
||||
pretrained = config.pop('pretrained')
|
||||
|
||||
model = VisionTransformer(
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(
|
||||
nn.LayerNorm, epsilon=1e-6),
|
||||
**kwargs)
|
||||
|
||||
if enable_linear_eval:
|
||||
_enable_linear_eval(model)
|
||||
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
model_keys,
|
||||
model_ema_configs,
|
||||
abs_pos_emb,
|
||||
rel_pos_bias,
|
||||
use_ssld=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def cae_large_patch16_224(**kwargs):
|
||||
config = kwargs.copy()
|
||||
enable_linear_eval = config.pop('enable_linear_eval')
|
||||
model_keys = config.pop('model_key')
|
||||
model_ema_configs = config.pop('model_ema')
|
||||
abs_pos_emb = config.pop('abs_pos_emb')
|
||||
rel_pos_bias = config.pop('rel_pos_bias')
|
||||
pretrained = config.pop('pretrained')
|
||||
|
||||
model = VisionTransformer(
|
||||
patch_size=16,
|
||||
embed_dim=1024,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(
|
||||
nn.LayerNorm, epsilon=1e-6),
|
||||
**kwargs)
|
||||
|
||||
if enable_linear_eval:
|
||||
_enable_linear_eval(model)
|
||||
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
model_keys,
|
||||
model_ema_configs,
|
||||
abs_pos_emb,
|
||||
rel_pos_bias,
|
||||
use_ssld=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def cae_large_patch16_384(**kwargs):
|
||||
config = kwargs.copy()
|
||||
enable_linear_eval = config.pop('enable_linear_eval')
|
||||
model_keys = config.pop('model_key')
|
||||
model_ema_configs = config.pop('model_ema')
|
||||
abs_pos_emb = config.pop('abs_pos_emb')
|
||||
rel_pos_bias = config.pop('rel_pos_bias')
|
||||
pretrained = config.pop('pretrained')
|
||||
|
||||
model = VisionTransformer(
|
||||
img_size=384,
|
||||
patch_size=16,
|
||||
embed_dim=1024,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(
|
||||
nn.LayerNorm, epsilon=1e-6),
|
||||
**kwargs)
|
||||
|
||||
if enable_linear_eval:
|
||||
_enable_linear_eval(model)
|
||||
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
model_keys,
|
||||
model_ema_configs,
|
||||
abs_pos_emb,
|
||||
rel_pos_bias,
|
||||
use_ssld=False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def cae_large_patch16_512(**kwargs):
|
||||
config = kwargs.copy()
|
||||
enable_linear_eval = config.pop('enable_linear_eval')
|
||||
model_keys = config.pop('model_key')
|
||||
model_ema_configs = config.pop('model_ema')
|
||||
abs_pos_emb = config.pop('abs_pos_emb')
|
||||
rel_pos_bias = config.pop('rel_pos_bias')
|
||||
pretrained = config.pop('pretrained')
|
||||
|
||||
model = VisionTransformer(
|
||||
img_size=512,
|
||||
patch_size=16,
|
||||
embed_dim=1024,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(
|
||||
nn.LayerNorm, epsilon=1e-6),
|
||||
**kwargs)
|
||||
|
||||
if enable_linear_eval:
|
||||
_enable_linear_eval(model)
|
||||
|
||||
_load_pretrained(
|
||||
pretrained,
|
||||
model,
|
||||
model_keys,
|
||||
model_ema_configs,
|
||||
abs_pos_emb,
|
||||
rel_pos_bias,
|
||||
use_ssld=False)
|
||||
|
||||
return model
|
|
@ -0,0 +1,170 @@
|
|||
# global configs
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
output_dir: ./output/
|
||||
device: gpu
|
||||
save_interval: 1
|
||||
eval_during_train: True
|
||||
eval_interval: 1
|
||||
epochs: 200
|
||||
print_batch_step: 10
|
||||
use_visualdl: False
|
||||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
name: cae_base_patch16_224
|
||||
class_num: 4
|
||||
drop_rate: 0.0
|
||||
drop_path_rate: 0.1
|
||||
attn_drop_rate: 0.0
|
||||
|
||||
use_mean_pooling: True
|
||||
init_scale: 0.001
|
||||
use_rel_pos_bias: True
|
||||
use_abs_pos_emb: False
|
||||
init_values: 0.1
|
||||
lin_probe: False
|
||||
|
||||
sin_pos_emb: True
|
||||
|
||||
abs_pos_emb: False
|
||||
enable_linear_eval: False
|
||||
model_key: model|module|state_dict
|
||||
rel_pos_bias: True
|
||||
model_ema:
|
||||
enable_model_ema: False
|
||||
model_ema_decay: 0.9999
|
||||
model_ema_force_cpu: False
|
||||
pretrained: ./pretrained/vit_base_cae_pretrained.pdparams
|
||||
|
||||
# loss function config for traing/eval process
|
||||
Loss:
|
||||
Train:
|
||||
- SoftTargetCrossEntropy:
|
||||
weight: 1.0
|
||||
Eval:
|
||||
- CELoss:
|
||||
weight: 1.0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: AdamWDL
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
epsilon: 1e-8
|
||||
weight_decay: 0.05
|
||||
layerwise_decay: 0.65
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.004
|
||||
eta_min: 1e-6
|
||||
warmup_epoch: 10
|
||||
warmup_start_lr: 1e-6
|
||||
|
||||
|
||||
# data loader for train and eval
|
||||
DataLoader:
|
||||
Train:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/paddle-job-153869-0/train_eval_data/images
|
||||
cls_label_path: ./dataset/paddle-job-153869-0/train_eval_data/train_data_list.txt
|
||||
batch_transform_ops:
|
||||
- MixupCutmixHybrid:
|
||||
mixup_alpha: 0.8
|
||||
cutmix_alpha: 1.0
|
||||
switch_prob: 0.5
|
||||
num_classes: 4
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- RandCropImage:
|
||||
size: 224
|
||||
interpolation: bilinear
|
||||
- RandFlipImage:
|
||||
flip_code: 1
|
||||
- RandAugment:
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- RandomErasing:
|
||||
EPSILON: 0.5
|
||||
sl: 0.02
|
||||
sh: 0.3
|
||||
r1: 0.3
|
||||
delimiter: ' '
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 128
|
||||
drop_last: True
|
||||
shuffle: True
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: ImageNetDataset
|
||||
image_root: ./dataset/paddle-job-153869-0/train_eval_data/images
|
||||
cls_label_path: ./dataset/paddle-job-153869-0/train_eval_data/eval_data_list.txt
|
||||
transform_ops:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
delimiter: ' '
|
||||
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 128
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
num_workers: 4
|
||||
use_shared_memory: True
|
||||
|
||||
Infer:
|
||||
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
|
||||
batch_size: 10
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
to_rgb: True
|
||||
channel_first: False
|
||||
- ResizeImage:
|
||||
resize_short: 256
|
||||
- CropImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 1.0/255.0
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
PostProcess:
|
||||
name: Topk
|
||||
topk: 5
|
||||
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
|
||||
|
||||
Metric:
|
||||
Train:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
|
@ -42,6 +42,7 @@ from ppcls.data.preprocess.ops.operators import RandomRotation
|
|||
from ppcls.data.preprocess.ops.operators import Padv2
|
||||
|
||||
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
|
||||
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
|
|
@ -23,6 +23,9 @@ import numpy as np
|
|||
from ppcls.utils import logger
|
||||
from ppcls.data.preprocess.ops.fmix import sample_mask
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class BatchOperator(object):
|
||||
""" BatchOperator """
|
||||
|
@ -229,3 +232,270 @@ class OpSampler(object):
|
|||
list(self.ops.keys()), weights=list(self.ops.values()), k=1)[0]
|
||||
# return batch directly when None Op
|
||||
return op(batch) if op else batch
|
||||
|
||||
|
||||
class MixupCutmixHybrid(object):
|
||||
""" Mixup/Cutmix that applies different params to each element or whole batch
|
||||
|
||||
Args:
|
||||
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
|
||||
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
|
||||
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
|
||||
prob (float): probability of applying mixup or cutmix per batch or element
|
||||
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
|
||||
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
|
||||
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
|
||||
label_smoothing (float): apply label smoothing to the mixed target tensor
|
||||
num_classes (int): number of classes for target
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mixup_alpha=1.,
|
||||
cutmix_alpha=0.,
|
||||
cutmix_minmax=None,
|
||||
prob=1.0,
|
||||
switch_prob=0.5,
|
||||
mode='batch',
|
||||
correct_lam=True,
|
||||
label_smoothing=0.1,
|
||||
num_classes=4):
|
||||
self.mixup_alpha = mixup_alpha
|
||||
self.cutmix_alpha = cutmix_alpha
|
||||
self.cutmix_minmax = cutmix_minmax
|
||||
if self.cutmix_minmax is not None:
|
||||
assert len(self.cutmix_minmax) == 2
|
||||
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
|
||||
self.cutmix_alpha = 1.0
|
||||
self.mix_prob = prob
|
||||
self.switch_prob = switch_prob
|
||||
self.label_smoothing = label_smoothing
|
||||
self.num_classes = num_classes
|
||||
self.mode = mode
|
||||
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
|
||||
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
|
||||
|
||||
def _one_hot(self, x, num_classes, on_value=1., off_value=0.):
|
||||
x = paddle.cast(x, dtype='int64')
|
||||
on_value = paddle.full([x.shape[0], num_classes], on_value)
|
||||
off_value = paddle.full([x.shape[0], num_classes], off_value)
|
||||
return paddle.where(
|
||||
F.one_hot(x, num_classes) == 1, on_value, off_value)
|
||||
|
||||
def _mixup_target(self, target, num_classes, lam=1., smoothing=0.0):
|
||||
off_value = smoothing / num_classes
|
||||
on_value = 1. - smoothing + off_value
|
||||
y1 = self._one_hot(
|
||||
target,
|
||||
num_classes,
|
||||
on_value=on_value,
|
||||
off_value=off_value, )
|
||||
y2 = self._one_hot(
|
||||
target.flip(0),
|
||||
num_classes,
|
||||
on_value=on_value,
|
||||
off_value=off_value)
|
||||
return y1 * lam + y2 * (1. - lam)
|
||||
|
||||
def _rand_bbox(self, img_shape, lam, margin=0., count=None):
|
||||
""" Standard CutMix bounding-box
|
||||
Generates a random square bbox based on lambda value. This impl includes
|
||||
support for enforcing a border margin as percent of bbox dimensions.
|
||||
|
||||
Args:
|
||||
img_shape (tuple): Image shape as tuple
|
||||
lam (float): Cutmix lambda value
|
||||
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
|
||||
count (int): Number of bbox to generate
|
||||
"""
|
||||
ratio = np.sqrt(1 - lam)
|
||||
img_h, img_w = img_shape[-2:]
|
||||
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
|
||||
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
|
||||
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
|
||||
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
|
||||
yl = np.clip(cy - cut_h // 2, 0, img_h)
|
||||
yh = np.clip(cy + cut_h // 2, 0, img_h)
|
||||
xl = np.clip(cx - cut_w // 2, 0, img_w)
|
||||
xh = np.clip(cx + cut_w // 2, 0, img_w)
|
||||
return yl, yh, xl, xh
|
||||
|
||||
def _rand_bbox_minmax(self, img_shape, minmax, count=None):
|
||||
""" Min-Max CutMix bounding-box
|
||||
Inspired by Darknet cutmix impl, generates a random rectangular bbox
|
||||
based on min/max percent values applied to each dimension of the input image.
|
||||
|
||||
Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
|
||||
|
||||
Args:
|
||||
img_shape (tuple): Image shape as tuple
|
||||
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
|
||||
count (int): Number of bbox to generate
|
||||
"""
|
||||
assert len(minmax) == 2
|
||||
img_h, img_w = img_shape[-2:]
|
||||
cut_h = np.random.randint(
|
||||
int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
|
||||
cut_w = np.random.randint(
|
||||
int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
|
||||
yl = np.random.randint(0, img_h - cut_h, size=count)
|
||||
xl = np.random.randint(0, img_w - cut_w, size=count)
|
||||
yu = yl + cut_h
|
||||
xu = xl + cut_w
|
||||
return yl, yu, xl, xu
|
||||
|
||||
def _cutmix_bbox_and_lam(self,
|
||||
img_shape,
|
||||
lam,
|
||||
ratio_minmax=None,
|
||||
correct_lam=True,
|
||||
count=None):
|
||||
""" Generate bbox and apply lambda correction.
|
||||
"""
|
||||
if ratio_minmax is not None:
|
||||
yl, yu, xl, xu = self._rand_bbox_minmax(
|
||||
img_shape, ratio_minmax, count=count)
|
||||
else:
|
||||
yl, yu, xl, xu = self._rand_bbox(img_shape, lam, count=count)
|
||||
if correct_lam or ratio_minmax is not None:
|
||||
bbox_area = (yu - yl) * (xu - xl)
|
||||
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
|
||||
return (yl, yu, xl, xu), lam
|
||||
|
||||
def _params_per_elem(self, batch_size):
|
||||
lam = np.ones(batch_size, dtype=np.float32)
|
||||
use_cutmix = np.zeros(batch_size, dtype=np.bool)
|
||||
if self.mixup_enabled:
|
||||
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
|
||||
use_cutmix = np.random.rand(batch_size) < self.switch_prob
|
||||
lam_mix = np.where(
|
||||
use_cutmix,
|
||||
np.random.beta(
|
||||
self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
|
||||
np.random.beta(
|
||||
self.mixup_alpha, self.mixup_alpha, size=batch_size))
|
||||
elif self.mixup_alpha > 0.:
|
||||
lam_mix = np.random.beta(
|
||||
self.mixup_alpha, self.mixup_alpha, size=batch_size)
|
||||
elif self.cutmix_alpha > 0.:
|
||||
use_cutmix = np.ones(batch_size, dtype=np.bool)
|
||||
lam_mix = np.random.beta(
|
||||
self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
|
||||
else:
|
||||
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
|
||||
lam = np.where(
|
||||
np.random.rand(batch_size) < self.mix_prob,
|
||||
lam_mix.astype(np.float32), lam)
|
||||
return lam, use_cutmix
|
||||
|
||||
def _params_per_batch(self):
|
||||
lam = 1.
|
||||
use_cutmix = False
|
||||
if self.mixup_enabled and np.random.rand() < self.mix_prob:
|
||||
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
|
||||
use_cutmix = np.random.rand() < self.switch_prob
|
||||
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
|
||||
np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
||||
elif self.mixup_alpha > 0.:
|
||||
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
|
||||
elif self.cutmix_alpha > 0.:
|
||||
use_cutmix = True
|
||||
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
|
||||
else:
|
||||
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
|
||||
lam = float(lam_mix)
|
||||
return lam, use_cutmix
|
||||
|
||||
def _mix_elem(self, x):
|
||||
batch_size = len(x)
|
||||
lam_batch, use_cutmix = self._params_per_elem(batch_size)
|
||||
x_orig = x.clone(
|
||||
) # need to keep an unmodified original for mixing source
|
||||
for i in range(batch_size):
|
||||
j = batch_size - i - 1
|
||||
lam = lam_batch[i]
|
||||
if lam != 1.:
|
||||
if use_cutmix[i]:
|
||||
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
|
||||
x[i].shape,
|
||||
lam,
|
||||
ratio_minmax=self.cutmix_minmax,
|
||||
correct_lam=self.correct_lam)
|
||||
if yl < yh and xl < xh:
|
||||
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
|
||||
lam_batch[i] = lam
|
||||
else:
|
||||
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
||||
return paddle.to_tensor(lam_batch, dtype=x.dtype).unsqueeze(1)
|
||||
|
||||
def _mix_pair(self, x):
|
||||
batch_size = len(x)
|
||||
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
|
||||
x_orig = x.clone(
|
||||
) # need to keep an unmodified original for mixing source
|
||||
for i in range(batch_size // 2):
|
||||
j = batch_size - i - 1
|
||||
lam = lam_batch[i]
|
||||
if lam != 1.:
|
||||
if use_cutmix[i]:
|
||||
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
|
||||
x[i].shape,
|
||||
lam,
|
||||
ratio_minmax=self.cutmix_minmax,
|
||||
correct_lam=self.correct_lam)
|
||||
if yl < yh and xl < xh:
|
||||
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
|
||||
x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
|
||||
lam_batch[i] = lam
|
||||
else:
|
||||
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
||||
x[j] = x[j] * lam + x_orig[i] * (1 - lam)
|
||||
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
|
||||
return paddle.to_tensor(lam_batch, dtype=x.dtype).unsqueeze(1)
|
||||
|
||||
def _mix_batch(self, x):
|
||||
lam, use_cutmix = self._params_per_batch()
|
||||
if lam == 1.:
|
||||
return 1.
|
||||
if use_cutmix:
|
||||
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
|
||||
x.shape,
|
||||
lam,
|
||||
ratio_minmax=self.cutmix_minmax,
|
||||
correct_lam=self.correct_lam)
|
||||
if yl < yh and xl < xh:
|
||||
x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
|
||||
|
||||
else:
|
||||
x_flipped = x.flip(0) * (1. - lam)
|
||||
x[:] = x * lam + x_flipped
|
||||
return lam
|
||||
|
||||
def _unpack(self, batch):
|
||||
""" _unpack """
|
||||
assert isinstance(batch, list), \
|
||||
'batch should be a list filled with tuples (img, label)'
|
||||
bs = len(batch)
|
||||
assert bs > 0, 'size of the batch data should > 0'
|
||||
#imgs, labels = list(zip(*batch))
|
||||
imgs = []
|
||||
labels = []
|
||||
for item in batch:
|
||||
imgs.append(item[0])
|
||||
labels.append(item[1])
|
||||
return np.array(imgs), np.array(labels), bs
|
||||
|
||||
def __call__(self, batch):
|
||||
x, target, bs = self._unpack(batch)
|
||||
x = paddle.to_tensor(x)
|
||||
target = paddle.to_tensor(target)
|
||||
assert len(x) % 2 == 0, 'Batch size should be even when using this'
|
||||
if self.mode == 'elem':
|
||||
lam = self._mix_elem(x)
|
||||
elif self.mode == 'pair':
|
||||
lam = self._mix_pair(x)
|
||||
else:
|
||||
lam = self._mix_batch(x)
|
||||
target = self._mixup_target(target, self.num_classes, lam,
|
||||
self.label_smoothing)
|
||||
|
||||
return list(zip(x.numpy(), target.numpy()))
|
||||
|
|
|
@ -17,6 +17,7 @@ from .supconloss import SupConLoss
|
|||
from .pairwisecosface import PairwiseCosface
|
||||
from .dmlloss import DMLLoss
|
||||
from .distanceloss import DistanceLoss
|
||||
from .softtargetceloss import SoftTargetCrossEntropy
|
||||
|
||||
from .distillationloss import DistillationCELoss
|
||||
from .distillationloss import DistillationGTCELoss
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class SoftTargetCrossEntropy(nn.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x, target):
|
||||
loss = paddle.sum(-target * F.log_softmax(x, axis=-1), axis=-1)
|
||||
loss = loss.mean()
|
||||
return {"SoftTargetCELoss": loss}
|
||||
|
||||
def __str__(self, ):
|
||||
return type(self).__name__
|
|
@ -272,3 +272,145 @@ class AdamW(object):
|
|||
|
||||
def _apply_decay_param_fun(self, name):
|
||||
return name not in self.no_weight_decay_param_name_list
|
||||
|
||||
|
||||
class AdamWDL(object):
|
||||
"""
|
||||
The AdamWDL optimizer is implemented based on the AdamW Optimization with dynamic lr setting.
|
||||
Generally it's used for transformer model.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8,
|
||||
weight_decay=None,
|
||||
multi_precision=False,
|
||||
grad_clip=None,
|
||||
layerwise_decay=None,
|
||||
filter_bias_and_bn=True,
|
||||
**args):
|
||||
self.learning_rate = learning_rate
|
||||
self.beta1 = beta1
|
||||
self.beta2 = beta2
|
||||
self.epsilon = epsilon
|
||||
self.grad_clip = grad_clip
|
||||
self.weight_decay = weight_decay
|
||||
self.multi_precision = multi_precision
|
||||
self.layerwise_decay = layerwise_decay
|
||||
self.filter_bias_and_bn = filter_bias_and_bn
|
||||
|
||||
class AdamWDLImpl(optim.AdamW):
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8,
|
||||
parameters=None,
|
||||
weight_decay=0.01,
|
||||
apply_decay_param_fun=None,
|
||||
grad_clip=None,
|
||||
lazy_mode=False,
|
||||
multi_precision=False,
|
||||
layerwise_decay=1.0,
|
||||
n_layers=12,
|
||||
name_dict=None,
|
||||
name=None):
|
||||
if not isinstance(layerwise_decay, float) and \
|
||||
not isinstance(layerwise_decay, fluid.framework.Variable):
|
||||
raise TypeError("coeff should be float or Tensor.")
|
||||
self.layerwise_decay = layerwise_decay
|
||||
self.name_dict = name_dict
|
||||
self.n_layers = n_layers
|
||||
self.set_param_lr_fun = self._layerwise_lr_decay
|
||||
super().__init__(
|
||||
learning_rate=learning_rate,
|
||||
parameters=parameters,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
epsilon=epsilon,
|
||||
grad_clip=grad_clip,
|
||||
name=name,
|
||||
apply_decay_param_fun=apply_decay_param_fun,
|
||||
weight_decay=weight_decay,
|
||||
lazy_mode=lazy_mode,
|
||||
multi_precision=multi_precision)
|
||||
|
||||
def _append_optimize_op(self, block, param_and_grad):
|
||||
if self.set_param_lr_fun is None:
|
||||
return super(AdamLW, self)._append_optimize_op(block,
|
||||
param_and_grad)
|
||||
|
||||
self._append_decoupled_weight_decay(block, param_and_grad)
|
||||
prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
|
||||
self.set_param_lr_fun(self.layerwise_decay, self.name_dict,
|
||||
self.n_layers, param_and_grad[0])
|
||||
# excute Adam op
|
||||
res = super(optim.AdamW, self)._append_optimize_op(block,
|
||||
param_and_grad)
|
||||
param_and_grad[0].optimize_attr["learning_rate"] = prev_lr
|
||||
return res
|
||||
|
||||
# Layerwise decay
|
||||
def _layerwise_lr_decay(self, decay_rate, name_dict, n_layers, param):
|
||||
"""
|
||||
Args:
|
||||
decay_rate (float):
|
||||
The layer-wise decay ratio.
|
||||
name_dict (dict):
|
||||
The keys of name_dict is dynamic name of model while the value
|
||||
of name_dict is static name.
|
||||
Use model.named_parameters() to get name_dict.
|
||||
n_layers (int):
|
||||
Total number of layers in the transformer encoder.
|
||||
"""
|
||||
ratio = 1.0
|
||||
static_name = name_dict[param.name]
|
||||
if "blocks" in static_name:
|
||||
idx = static_name.find("blocks.")
|
||||
layer = int(static_name[idx:].split(".")[1])
|
||||
ratio = decay_rate**(n_layers - layer)
|
||||
elif "embed" in static_name:
|
||||
ratio = decay_rate**(n_layers + 1)
|
||||
param.optimize_attr["learning_rate"] *= ratio
|
||||
|
||||
def __call__(self, model_list):
|
||||
model = model_list[0]
|
||||
if self.weight_decay and self.filter_bias_and_bn:
|
||||
skip = {}
|
||||
if hasattr(model, 'no_weight_decay'):
|
||||
skip = model.no_weight_decay()
|
||||
decay_dict = {
|
||||
param.name: not (len(param.shape) == 1 or
|
||||
name.endswith(".bias") or name in skip)
|
||||
for name, param in model.named_parameters()
|
||||
if not 'teacher' in name
|
||||
}
|
||||
parameters = [
|
||||
param for param in model.parameters()
|
||||
if 'teacher' not in param.name
|
||||
]
|
||||
weight_decay = 0.
|
||||
else:
|
||||
parameters = model.parameters()
|
||||
|
||||
opt_args = dict(
|
||||
learning_rate=self.learning_rate, weight_decay=self.weight_decay)
|
||||
opt_args['parameters'] = parameters
|
||||
if decay_dict is not None:
|
||||
opt_args['apply_decay_param_fun'] = lambda n: decay_dict[n]
|
||||
opt_args['epsilon'] = self.epsilon
|
||||
opt_args['beta1'] = self.beta1
|
||||
opt_args['beta2'] = self.beta2
|
||||
if self.layerwise_decay and self.layerwise_decay < 1.0:
|
||||
opt_args['layerwise_decay'] = self.layerwise_decay
|
||||
name_dict = dict()
|
||||
for n, p in model.named_parameters():
|
||||
name_dict[p.name] = n
|
||||
opt_args['name_dict'] = name_dict
|
||||
opt_args['n_layers'] = model.get_num_layers()
|
||||
|
||||
optimizer = self.AdamWDLImpl(**opt_args)
|
||||
|
||||
return optimizer
|
||||
|
|
Loading…
Reference in New Issue