support fused attn (#3131)

Tingquan Gao 2024-05-16 13:33:46 +08:00 committed by GitHub
parent a0ae182547
commit e3aaa3cefb
3 changed files with 42 additions and 24 deletions

Changed file 1 of 3: the ViT backbone (Attention, VisionTransformer, and the model builders).

@@ -349,7 +349,8 @@ class Attention(nn.Layer):
                  attn_drop=0.,
                  proj_drop=0.,
                  model_name=None,
-                 window_size=None):
+                 window_size=None,
+                 use_fused_attn=False):
         super().__init__()
         self._model_name = model_name
@@ -368,9 +369,16 @@ class Attention(nn.Layer):
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
+        self.attn_drop_value = attn_drop
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)

+        self.use_fused_attn = use_fused_attn
+        if use_fused_attn:
+            if hasattr(self, 'relative_position_bias_table') or (_model_size in _model_diff['add_shared_rel_pos_bias'] and rel_pos_bias is not None):
+                logger.warning("The fused attn don't support `relative_position` yet, so fused attn will not be used.")
+                self.use_fused_attn = False
+
     def _register_relative_position_index(
             self,
             window_size,
@@ -407,28 +415,33 @@
     def forward(self, x, rel_pos_bias=None):
         # B= x.shape[0]
         N, C = x.shape[1], x.shape[2]
-        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //
-                                   self.num_heads)).transpose((2, 0, 3, 1, 4))
-        q, k, v = qkv[0], qkv[1], qkv[2]
+        qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads))

-        attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
-        if hasattr(self, 'relative_position_bias_table'):
-            relative_position_bias = \
-                self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
-                    self.window_size[0] * self.window_size[1] + 1,
-                    self.window_size[0] * self.window_size[1] + 1, -1])  # Wh*Ww,Wh*Ww,nH
-            relative_position_bias = relative_position_bias.transpose(
-                [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
-            attn = attn + relative_position_bias.unsqueeze(0)
+        if not self.use_fused_attn:
+            qkv = qkv.transpose((2, 0, 3, 1, 4))
+            q, k, v = qkv[0], qkv[1], qkv[2]
+            attn = (q.matmul(k.transpose((0, 1, 3, 2)))) * self.scale
+            if hasattr(self, 'relative_position_bias_table'):
+                relative_position_bias = \
+                    self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
+                        self.window_size[0] * self.window_size[1] + 1,
+                        self.window_size[0] * self.window_size[1] + 1, -1])  # Wh*Ww,Wh*Ww,nH
+                relative_position_bias = relative_position_bias.transpose(
+                    [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+                attn = attn + relative_position_bias.unsqueeze(0)

-        if _model_size in _model_diff[
-                'add_shared_rel_pos_bias'] and rel_pos_bias is not None:
-            attn = attn + rel_pos_bias
+            if _model_size in _model_diff[
+                    'add_shared_rel_pos_bias'] and rel_pos_bias is not None:
+                attn = attn + rel_pos_bias

-        attn = nn.functional.softmax(attn, axis=-1)
-        attn = self.attn_drop(attn)
-        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((-1, N, C))
+            attn = nn.functional.softmax(attn, axis=-1)
+            attn = self.attn_drop(attn).matmul(v)
+            attn = attn.transpose((0, 2, 1, 3))
+        else:
+            qkv = qkv.transpose((2, 0, 1, 3, 4))
+            q, k, v = qkv[0], qkv[1], qkv[2]
+            attn, _ = paddle.nn.functional.flash_attention.flash_attention(q, k, v, dropout=self.attn_drop_value)
+
+        x = attn.reshape((-1, N, C))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
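
With dropout disabled and no relative position bias, the unfused branch above and the flash_attention call compute the same attention up to fp16 rounding; flash_attention takes (batch, seq_len, num_heads, head_dim) inputs and applies the default 1/sqrt(head_dim) scaling internally, matching the head_dim ** -0.5 scale of the unfused branch. A minimal standalone sketch, not part of the commit, assuming a GPU build of Paddle with FlashAttention support and an fp16-friendly head size:

import paddle
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import flash_attention

paddle.set_device("gpu")  # flash attention only runs on GPU

# Toy ViT-B/16-like shapes: 197 tokens (196 patches + class token), 12 heads of dim 64.
B, N, num_heads, head_dim = 2, 197, 12, 64
scale = head_dim ** -0.5
qkv = paddle.randn((B, N, 3, num_heads, head_dim)).astype("float16")

# Unfused path, as in the "if not self.use_fused_attn" branch: [3, B, num_heads, N, head_dim].
q, k, v = paddle.unbind(qkv.transpose((2, 0, 3, 1, 4)), axis=0)
attn = q.matmul(k.transpose((0, 1, 3, 2))) * scale
attn = F.softmax(attn, axis=-1)
out_unfused = attn.matmul(v).transpose((0, 2, 1, 3)).reshape((B, N, num_heads * head_dim))

# Fused path, as in the "else" branch: flash_attention expects [B, N, num_heads, head_dim].
q2, k2, v2 = paddle.unbind(qkv.transpose((2, 0, 1, 3, 4)), axis=0)
out_fused, _ = flash_attention(q2, k2, v2, dropout=0.0)
out_fused = out_fused.reshape((B, N, num_heads * head_dim))

print(float((out_unfused - out_fused).abs().max()))  # expect roughly 1e-3 in fp16

Because the flash_attention call takes only q, k, v and a dropout rate, there is no place to add a relative position bias to the pre-softmax logits, which is why the constructor falls back to the unfused path whenever such a bias is present.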
@@ -677,6 +690,7 @@ class VisionTransformer(nn.Layer):
         self.class_num = class_num
         self.return_embed = kwargs.get('return_embed', False)
+        self.use_fused_attn = kwargs.get('use_fused_attn', False)
         self.num_features = self.embed_dim = embed_dim
         _img_size = to_2tuple(img_size)
         _patch_size = to_2tuple(patch_size)
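
Since the flag is read from kwargs here, it can be passed straight through the model builders. A minimal sketch, not part of the commit; the import path and the builder's keyword handling are assumptions based on how PaddleClas usually exposes these backbones:

import paddle
# Assumed module path for the CLIP/Unicom/CAE ViT builders shown in this file.
from ppcls.arch.backbone.model_zoo.foundation_vit import CLIP_vit_base_patch16_224

paddle.set_device("gpu")
model = CLIP_vit_base_patch16_224(pretrained=False, use_fused_attn=True)  # forwarded via **kwargs
model = paddle.amp.decorate(models=model, level="O2")  # fused attn expects fp16/bf16 activations

x = paddle.randn((1, 3, 224, 224))
with paddle.amp.auto_cast(level="O2"):
    out = model(x)
print(out.shape)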
@@ -985,7 +999,7 @@ def Unicom_vit_base_patch32_224(pretrained=False, use_ssld=False, **kwargs):
         num_heads=12,
         mlp_ratio=4,
         qkv_bias=False,
-        conv_bias=True,
+        conv_bias=True,
         feature_frame=True,
         hugging_face_framework=False,
         image_project=False,
@@ -1225,4 +1239,4 @@ def CAE_vit_base_patch16_224(pretrained=False, use_ssld=False, **kwargs):
         **kwargs, )
     _load_pretrained(
         pretrained, model, MODEL_URLS[model_name], use_ssld=use_ssld)
-    return model
+    return model

Changed file 2 of 3: the AMP training config for CLIP_vit_base_patch16_224.

@@ -19,13 +19,15 @@ AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
   # O1: mixed fp16
-  level: O1
+  level: O2
   use_fp16_test: True
 
 # model architecture
 Arch:
   name: CLIP_vit_base_patch16_224
   class_num: 1000
   return_embed: False
+  use_fused_attn: True # fused attn can be used in AMP O2 mode only
   pretrained: True
 
 # loss function config for traing/eval process

Changed file 3 of 3: the AMP training config for CLIP_vit_large_patch14_224.

@@ -19,13 +19,15 @@ AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
   # O1: mixed fp16
-  level: O1
+  level: O2
   use_fp16_test: True
 
 # model architecture
Arch:
   name: CLIP_vit_large_patch14_224
   class_num: 1000
   return_embed: False
+  use_fused_attn: True # fused attn can be used in AMP O2 mode only, and dont support relative_position yet
   pretrained: True
 
 # loss function config for traing/eval process
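
Both configs pair the new Arch.use_fused_attn switch with AMP level O2 and a loss scale of 128. For orientation, a rough raw-Paddle equivalent of that training setup, not part of the commit; the real trainer builds all of this from the YAML, and the backbone import path plus the class_num forwarding are assumptions:

import paddle
import paddle.nn.functional as F
from ppcls.arch.backbone.model_zoo.foundation_vit import CLIP_vit_base_patch16_224  # assumed path

paddle.set_device("gpu")
model = CLIP_vit_base_patch16_224(pretrained=False, use_fused_attn=True, class_num=1000)
opt = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())

# Mirrors AMP.level: O2, scale_loss: 128.0, use_dynamic_loss_scaling: True from the config.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level="O2")
scaler = paddle.amp.GradScaler(init_loss_scaling=128.0, use_dynamic_loss_scaling=True)

images = paddle.randn((8, 3, 224, 224))
labels = paddle.randint(0, 1000, shape=(8,))

with paddle.amp.auto_cast(level="O2"):
    logits = model(images)
    loss = F.cross_entropy(logits, labels)

scaled = scaler.scale(loss)
scaled.backward()
scaler.minimize(opt, scaled)
opt.clear_grad()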