dbg: support fused attn

parent e3aaa3cefb
commit 40042f89fa
@@ -462,7 +462,8 @@ class Block(nn.Layer):
                  act_layer=nn.GELU,
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5,
-                 window_size=None):
+                 window_size=None,
+                 use_fused_attn=False):
         super().__init__()
         global _model_size
         global _model_diff
@@ -482,7 +483,8 @@ class Block(nn.Layer):
             attn_drop=attn_drop,
             proj_drop=drop,
             model_name=self._model_name,
-            window_size=window_size)
+            window_size=window_size,
+            use_fused_attn=use_fused_attn)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
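The Block now forwards use_fused_attn down into its Attention module. Below is a minimal sketch of what such a toggle could look like inside an attention layer; it is not the repository's implementation, it assumes Paddle's paddle.nn.functional.scaled_dot_product_attention is available (Paddle 2.5+), and it omits the window_size / relative-position handling that the real module carries.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class Attention(nn.Layer):
    """Sketch of an attention layer with an optional fused path (hypothetical)."""

    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.,
                 use_fused_attn=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_fused_attn = use_fused_attn
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, self.head_dim])
        q, k, v = paddle.unbind(qkv, axis=2)  # each: [B, N, num_heads, head_dim]

        if self.use_fused_attn:
            # Fused kernel works on [batch, seq_len, num_heads, head_dim] and
            # fp16/bf16 inputs, hence the AMP O2 requirement noted in the config.
            x = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop, training=self.training)
        else:
            # Unfused reference path: explicit softmax(QK^T / sqrt(d)) V.
            q = q.transpose([0, 2, 1, 3])
            k = k.transpose([0, 2, 1, 3])
            v = v.transpose([0, 2, 1, 3])
            attn = paddle.matmul(q, k, transpose_y=True) * self.scale
            attn = F.softmax(attn, axis=-1)
            attn = F.dropout(attn, p=self.attn_drop, training=self.training)
            x = paddle.matmul(attn, v).transpose([0, 2, 1, 3])

        x = x.reshape([B, N, C])
        return self.proj_drop(self.proj(x))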
@@ -690,8 +692,8 @@ class VisionTransformer(nn.Layer):
 
         self.class_num = class_num
         self.return_embed = kwargs.get('return_embed', False)
-        self.use_fused_attn = kwargs.get('use_fused_attn', False)
         self.num_features = self.embed_dim = embed_dim
+        use_fused_attn = kwargs.get('use_fused_attn', False)
         _img_size = to_2tuple(img_size)
         _patch_size = to_2tuple(patch_size)
         self.window_size = (_img_size[0] // _patch_size[0],
@@ -768,7 +770,8 @@ class VisionTransformer(nn.Layer):
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
                 epsilon=epsilon,
-                window_size=self.window_size) for i in range(depth)
+                window_size=self.window_size,
+                use_fused_attn=use_fused_attn) for i in range(depth)
         ])
 
         self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
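On the VisionTransformer side, the flag is now read from kwargs into a local use_fused_attn and handed to every Block, rather than kept as self.use_fused_attn. A hedged construction sketch follows; the import path and the non-flag arguments are assumptions chosen to match the CLIP_vit_large_patch14_224 config below, not values taken from the diff.

# Import path depends on the repo layout; assumed here for illustration only.
# from the CLIP backbone module import VisionTransformer

model = VisionTransformer(
    img_size=224,         # assumed ViT-L/14 values, see the config hunk below
    patch_size=14,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    use_fused_attn=True,  # picked up via kwargs.get('use_fused_attn', False)
)

Keeping the default at False means existing configs that never set the key continue to use the unfused attention path.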
@@ -27,7 +27,7 @@ Arch:
   name: CLIP_vit_large_patch14_224
   class_num: 1000
   return_embed: False
-  use_fused_attn: True # fused attn can be used in AMP O2 mode only, and dont support relative_position yet
+  use_fused_attn: True # fused attn can be used in AMP O2 mode only
   pretrained: True
 
   # loss function config for traing/eval process
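The use_fused_attn comment ties the fused kernel to AMP O2, since it expects half-precision tensors. Outside the YAML-driven trainer, the corresponding Paddle-level setup would look roughly like the sketch below; the stand-in model, optimizer choice, and loss-scaling value are assumptions, not part of this change, and a CUDA device is assumed since float16 AMP is a GPU feature.

import paddle
import paddle.nn as nn

# Stand-in network for the backbone built with use_fused_attn=True.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 1000))
optimizer = paddle.optimizer.AdamW(learning_rate=1e-4, parameters=model.parameters())

# Level O2 casts parameters to float16, which the fused attention path needs.
model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=128.0)

images = paddle.randn([2, 3, 224, 224])
labels = paddle.randint(0, 1000, [2])

with paddle.amp.auto_cast(level='O2'):
    loss = paddle.nn.functional.cross_entropy(model(images), labels)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()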