support fused attn (#3131)
parent a0ae182547
commit e3aaa3cefb
@@ -349,7 +349,8 @@ class Attention(nn.Layer):
                  attn_drop=0.,
                  proj_drop=0.,
                  model_name=None,
-                 window_size=None):
+                 window_size=None,
+                 use_fused_attn=False):
         super().__init__()
         self._model_name = model_name
 
@@ -368,9 +369,16 @@ class Attention(nn.Layer):
 
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop_value = attn_drop
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
+        self.use_fused_attn = use_fused_attn
+        if use_fused_attn:
+            if hasattr(self, 'relative_position_bias_table') or (_model_size in _model_diff['add_shared_rel_pos_bias'] and rel_pos_bias is not None):
+                logger.warning("Fused attn does not support `relative_position` yet, so it will not be used.")
+                self.use_fused_attn = False
+
     def _register_relative_position_index(
             self,
             window_size,
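A note on the fallback above: a fused flash-attention kernel computes softmax(q·kᵀ·scale)·v in one pass without materializing the attention matrix, and the call used later in this patch takes only q, k, v and a dropout rate, so there is no hook for adding a relative position bias to the attention logits. Downgrading `use_fused_attn` to False with a warning, rather than silently dropping the bias, keeps models that rely on `relative_position_bias_table` or a shared `rel_pos_bias` numerically unchanged.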
@@ -407,28 +415,33 @@ class Attention(nn.Layer):
     def forward(self, x, rel_pos_bias=None):
         # B= x.shape[0]
         N, C = x.shape[1], x.shape[2]
-        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
-                                   self.num_heads)).transpose((2, 0, 3, 1, 4))
-        q, k, v = qkv[0], qkv[1], qkv[2]
+        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads))
 
-        attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
-        if hasattr(self, 'relative_position_bias_table'):
-            relative_position_bias = \
-                self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
-                    self.window_size[0] * self.window_size[1] + 1,
-                    self.window_size[0] * self.window_size[1] + 1, -1])  # Wh*Ww,Wh*Ww,nH
-            relative_position_bias = relative_position_bias.transpose(
-                [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
-            attn = attn + relative_position_bias.unsqueeze(0)
+        if not self.use_fused_attn:
+            qkv = qkv.transpose((2, 0, 3, 1, 4))
+            q, k, v = qkv[0], qkv[1], qkv[2]
+            attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
+            if hasattr(self, 'relative_position_bias_table'):
+                relative_position_bias = \
+                    self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
+                        self.window_size[0] * self.window_size[1] + 1,
+                        self.window_size[0] * self.window_size[1] + 1, -1])  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.transpose(
+                    [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)
 
-        if _model_size in _model_diff[
-                'add_shared_rel_pos_bias'] and rel_pos_bias is not None:
-            attn = attn + rel_pos_bias
+            if _model_size in _model_diff[
+                    'add_shared_rel_pos_bias'] and rel_pos_bias is not None:
+                attn = attn + rel_pos_bias
 
-        attn = nn.functional.softmax(attn, axis=-1)
-        attn = self.attn_drop(attn)
-
-        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
+            attn = nn.functional.softmax(attn, axis=-1)
+            attn = self.attn_drop(attn).matmul(v)
+            attn = attn.transpose((0, 2, 1, 3))
+        else:
+            qkv = qkv.transpose((2, 0, 1, 3, 4))
+            q, k, v = qkv[0], qkv[1], qkv[2]
+            attn, _ = paddle.nn.functional.flash_attention.flash_attention(q, k, v, dropout=self.attn_drop_value)
+        x = attn.reshape((-1, N, C))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
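For reference, a minimal sketch of the shape bookkeeping behind the two branches above (sizes are made up for illustration): the manual path wants q/k/v laid out as [batch, num_heads, seq_len, head_dim], while Paddle's flash_attention consumes [batch, seq_len, num_heads, head_dim], which is why the fused branch transposes with (2, 0, 1, 3, 4) instead of (2, 0, 3, 1, 4).

# Illustrative shapes only; B, N, num_heads and head_dim are made-up sizes.
import paddle

B, N, num_heads, head_dim = 2, 197, 12, 64
C = num_heads * head_dim

qkv = paddle.randn([B, N, 3 * C]).reshape((-1, N, 3, num_heads, head_dim))

# Manual path: (3, B, num_heads, N, head_dim); q @ k^T materializes an N x N map per head.
t = qkv.transpose((2, 0, 3, 1, 4))
q, k, v = t[0], t[1], t[2]
print(q.shape)  # [2, 12, 197, 64]

# Fused path: (3, B, N, num_heads, head_dim), the layout flash_attention consumes.
t = qkv.transpose((2, 0, 1, 3, 4))
q, k, v = t[0], t[1], t[2]
print(q.shape)  # [2, 197, 12, 64]

# On a GPU with fp16/bf16 tensors (e.g. under AMP O2), the fused call is then roughly
#   out, _ = paddle.nn.functional.flash_attention.flash_attention(q, k, v, dropout=0.0)
# and out.reshape((-1, N, C)) lines up with the output of the manual branch.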
@@ -677,6 +690,7 @@ class VisionTransformer(nn.Layer):
 
         self.class_num = class_num
         self.return_embed = kwargs.get('return_embed', False)
+        self.use_fused_attn = kwargs.get('use_fused_attn', False)
         self.num_features = self.embed_dim = embed_dim
         _img_size = to_2tuple(img_size)
         _patch_size = to_2tuple(patch_size)
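As a usage sketch only — the import path is an assumption about how PaddleClas exposes its backbones, and the builder is assumed to forward **kwargs into VisionTransformer the way the other builders in this file do — the new kwarg can be flipped on when the backbone is constructed directly:

# Hypothetical direct construction; in normal training the flag comes from the
# Arch section of the config instead (see the YAML hunks below).
from ppcls.arch.backbone import CLIP_vit_base_patch16_224

model = CLIP_vit_base_patch16_224(
    pretrained=False,
    use_fused_attn=True)  # read back via kwargs.get('use_fused_attn', False) above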
@@ -985,7 +999,7 @@ def Unicom_vit_base_patch32_224(pretrained=False, use_ssld=False, **kwargs):
         num_heads=12,
         mlp_ratio=4,
         qkv_bias=False,
-        conv_bias=True,
+        conv_bias=True,
         feature_frame=True,
         hugging_face_framework=False,
         image_project=False,
@@ -1225,4 +1239,4 @@ def CAE_vit_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
         **kwargs, )
     _load_pretrained(
         pretrained, model, MODEL_URLS[model_name], use_ssld=use_ssld)
-    return model
+    return model
@@ -19,13 +19,15 @@ AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
   # O1: mixed fp16
-  level: O1
+  level: O2
+  use_fp16_test: True
 
 # model architecture
 Arch:
   name: CLIP_vit_base_patch16_224
   class_num: 1000
   return_embed: False
+  use_fused_attn: True  # fused attn can be used in AMP O2 mode only
   pretrained: True
 
 # loss function config for training/eval process
@@ -19,13 +19,15 @@ AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
   # O1: mixed fp16
-  level: O1
+  level: O2
+  use_fp16_test: True
 
 # model architecture
 Arch:
   name: CLIP_vit_large_patch14_224
   class_num: 1000
   return_embed: False
+  use_fused_attn: True  # fused attn can be used in AMP O2 mode only, and does not support relative_position yet
   pretrained: True
 
 # loss function config for training/eval process
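Both configs move AMP from O1 to O2 because, as the inline comments say, the fused path is only usable in O2 mode (the kernel expects fp16/bf16 tensors), and `use_fp16_test: True` presumably keeps evaluation in fp16 so the fused path is taken there as well. When launching through the usual PaddleClas entry point, the same switches can likely also be passed as overrides, e.g. `-o AMP.level=O2 -o Arch.use_fused_attn=True` to `tools/train.py`; treat those override names as PaddleClas conventions rather than something this diff introduces.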