dbg: support fused attn

pull/3155/head
gaotingquan 2024-05-23 09:13:28 +00:00 committed by Tingquan Gao
parent e3aaa3cefb
commit 40042f89fa
2 changed files with 8 additions and 5 deletions


@@ -462,7 +462,8 @@ class Block(nn.Layer):
                  act_layer=nn.GELU,
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5,
-                 window_size=None):
+                 window_size=None,
+                 use_fused_attn=False):
         super().__init__()
         global _model_size
         global _model_diff
@@ -482,7 +483,8 @@ class Block(nn.Layer):
             attn_drop=attn_drop,
             proj_drop=drop,
             model_name=self._model_name,
-            window_size=window_size)
+            window_size=window_size,
+            use_fused_attn=use_fused_attn)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
@@ -690,8 +692,8 @@ class VisionTransformer(nn.Layer):
         self.class_num = class_num
         self.return_embed = kwargs.get('return_embed', False)
-        self.use_fused_attn = kwargs.get('use_fused_attn', False)
         self.num_features = self.embed_dim = embed_dim
+        use_fused_attn = kwargs.get('use_fused_attn', False)
         _img_size = to_2tuple(img_size)
         _patch_size = to_2tuple(patch_size)
         self.window_size = (_img_size[0] // _patch_size[0],
@@ -768,7 +770,8 @@ class VisionTransformer(nn.Layer):
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 epsilon=epsilon,
-                window_size=self.window_size) for i in range(depth)
+                window_size=self.window_size,
+                use_fused_attn=use_fused_attn) for i in range(depth)
         ])
         self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
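
To make the new flag concrete, below is a minimal, hypothetical sketch of how an attention layer can branch on use_fused_attn. It is not the repository's actual Attention class (which, per the old config comment, also handles window_size / relative position bias), and it assumes Paddle's paddle.nn.functional.scaled_dot_product_attention is available.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class Attention(nn.Layer):
    """Illustrative only: switches between the manual matmul-softmax path
    and Paddle's fused scaled_dot_product_attention based on use_fused_attn."""

    def __init__(self, dim, num_heads=8, use_fused_attn=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_fused_attn = use_fused_attn
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, self.head_dim])
        # each of q, k, v: [B, N, num_heads, head_dim]
        q, k, v = paddle.unbind(qkv, axis=2)

        if self.use_fused_attn:
            # Paddle's fused kernel expects [batch, seq_len, num_heads, head_dim]
            # and (per the config note) fp16/bf16 inputs, i.e. AMP O2 training.
            out = F.scaled_dot_product_attention(q, k, v)
        else:
            # manual path: [B, num_heads, N, head_dim]
            q, k, v = (t.transpose([0, 2, 1, 3]) for t in (q, k, v))
            attn = F.softmax(q @ k.transpose([0, 1, 3, 2]) * self.scale, axis=-1)
            out = (attn @ v).transpose([0, 2, 1, 3])

        return self.proj(out.reshape([B, N, C]))

A Block constructed with use_fused_attn=True would forward the flag into such a layer exactly the way the hunks above do.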


@@ -27,7 +27,7 @@ Arch:
   name: CLIP_vit_large_patch14_224
   class_num: 1000
   return_embed: False
-  use_fused_attn: True # fused attn can be used in AMP O2 mode only, and dont support relative_position yet
+  use_fused_attn: True # fused attn can be used in AMP O2 mode only
   pretrained: True
 # loss function config for traing/eval process
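
The comment kept in the config states the constraint: the fused attention path only applies under AMP O2 (pure fp16) training. As a rough illustration of what O2 means in plain Paddle, here is a minimal sketch assuming a GPU build of PaddlePaddle; it is not PaddleClas's actual trainer, and the tiny Linear model is just a stand-in for the ViT backbone.

import paddle

model = paddle.nn.Linear(8, 8)  # stand-in for the CLIP ViT backbone
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())
# Under O2, parameters and activations are cast to fp16, which is the
# precision the fused attention kernels expect.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=128.0)

x = paddle.randn([4, 8])
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()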