cls_config_fix (#3253)

* cls_config_fix

* fix flash attn
Sunflower7788 2024-09-23 15:00:04 +08:00 committed by GitHub
parent 37c1ddb0dd
commit f064291d6b
14 changed files with 44 additions and 30 deletions



@@ -452,7 +452,8 @@ class MobileAttention(nn.Layer):
                  dropout_prob=0.0,
                  layer_scale_init_value=0.0,
                  if_act=True,
-                 act=None):
+                 act=None,
+                 use_fused_attn=False):
         super(MobileAttention, self).__init__()
         self.if_shortcut = stride == 1 and in_c == out_c
@@ -461,6 +462,7 @@ class MobileAttention(nn.Layer):
         self.num_head = num_head
         self.query_dim = query_dim
         self.attn_drop_rate = attn_drop_rate
+        self.use_fused_attn = use_fused_attn
         self.norm = BatchNorm(
             num_channels=in_c,
             act=None,
@@ -515,6 +517,10 @@ class MobileAttention(nn.Layer):
             padding=0,
             groups=1,
             bias_attr=False)
+        if not self.use_fused_attn:
+            self.scale = query_dim**-0.5
+            self.softmax = nn.Softmax(-1)
+            self.attn_drop = Dropout(self.attn_drop_rate)
         self.drop = Dropout(dropout_prob)
         self.layer_scale_init_value = layer_scale_init_value
         if layer_scale_init_value > 0.0:
@@ -541,12 +547,22 @@ class MobileAttention(nn.Layer):
         v = v.reshape(
             [B, self.kv_dim, 1, H // self.kv_stride * W // self.kv_stride])
         v = v.transpose([0, 3, 2, 1])
-        attn = F.scaled_dot_product_attention(
-            query=q,
-            key=k,
-            value=v,
-            attn_mask=attn_mask,
-            dropout_p=self.attn_drop_rate if self.training else 0.0)
+        if self.use_fused_attn:
+            attn = F.scaled_dot_product_attention(
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop_rate if self.training else 0.0)
+        else:
+            q = q.transpose([0, 2, 1, 3]) * self.scale
+            v = v.transpose([0, 2, 1, 3])
+            attn = paddle.mm(q, k.transpose([0, 2, 3, 1]))
+            attn = self.softmax(attn)
+            attn = self.attn_drop(attn)
+            attn = paddle.mm(attn, v)
+            attn = attn.transpose([0, 2, 1, 3])
+            attn = attn.reshape([B, H, W, self.query_dim])
         x = self.proj(attn.transpose([0, 3, 1, 2]))
         x = self.drop(x)
@@ -586,6 +602,7 @@ class MobileNetV4(TheseusLayer):
                  layer_scale_init_value=0.0,
                  return_patterns=None,
                  return_stages=None,
+                 use_fused_attn=False,
                  **kwargs):
         super(MobileNetV4, self).__init__()
         self.cfg = config
@@ -662,7 +679,8 @@ class MobileNetV4(TheseusLayer):
                     drop_path_rate=self.drop_path_rate * i / block_count,
                     layer_scale_init_value=layer_scale_init_value,
                     if_act=True,
-                    act=act)
+                    act=act,
+                    use_fused_attn=use_fused_attn)
             blocks.append(block)
         self.blocks = nn.Sequential(*blocks)
         self.global_pool = AdaptiveAvgPool2D(1)
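The core of the "fix flash attn" change is the hunk above: paddle's fused F.scaled_dot_product_attention generally requires a GPU with flash-attention support (and fp16/bf16 inputs), so the commit adds an unfused branch that computes the same softmax(q·kᵀ·scale)·v with plain ops. Below is a minimal sketch of the equivalence, not part of the commit, assuming paddle's [batch, seq_len, num_heads, head_dim] layout and hypothetical sizes:

# Sketch: the unfused path mirrors the else-branch added in this commit
# (no dropout, no mask). H here is num_heads, not the feature-map height.
import paddle
import paddle.nn.functional as F

B, N, H, C = 2, 49, 4, 64                           # hypothetical sizes
q, k, v = (paddle.randn([B, N, H, C]) for _ in range(3))

def unfused_attention(q, k, v, scale):
    """Plain-op scaled dot-product attention, as in the new else-branch."""
    q = q.transpose([0, 2, 1, 3]) * scale           # [B, H, N, C]
    v = v.transpose([0, 2, 1, 3])                   # [B, H, N, C]
    attn = paddle.mm(q, k.transpose([0, 2, 3, 1]))  # [B, H, N, N]
    attn = F.softmax(attn, axis=-1)
    out = paddle.mm(attn, v)                        # [B, H, N, C]
    return out.transpose([0, 2, 1, 3])              # back to [B, N, H, C]

out = unfused_attention(q, k, v, scale=C ** -0.5)
# On a device where the fused kernel is supported, the two paths agree:
# fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
# assert paddle.allclose(fused, out, atol=1e-5)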


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
@@ -70,7 +70,6 @@ DataLoader:
       transform_ops:
         - DecodeImage:
             backend: pil
-            to_np: False
             channel_first: False
         - RandCropImage:
             size: 224
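Each training config flips use_amp from False to True while keeping scale_loss: 128.0 and dynamic loss scaling, and drops the unsupported to_np key from DecodeImage; the same changes repeat in the configs below. Roughly, the AMP keys map onto paddle's AMP API as in this illustrative sketch (a stand-in loop, not PaddleClas's actual trainer):

import paddle

model = paddle.nn.Linear(8, 2)                  # stand-in for the real model
opt = paddle.optimizer.Momentum(parameters=model.parameters())
# scale_loss / use_dynamic_loss_scaling map onto GradScaler's arguments.
scaler = paddle.amp.GradScaler(init_loss_scaling=128.0,
                               use_dynamic_loss_scaling=True)

x = paddle.randn([4, 8])
with paddle.amp.auto_cast():                    # use_amp: True
    loss = model(x).mean()
scaled = scaler.scale(loss)                     # scale to avoid fp16 underflow
scaled.backward()
scaler.step(opt)                                # unscales grads, then steps
scaler.update()                                 # adjusts the loss scale
opt.clear_grad()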


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
@@ -70,7 +70,6 @@ DataLoader:
       transform_ops:
         - DecodeImage:
             backend: pil
-            to_np: False
             channel_first: False
         - RandCropImage:
             size: 224


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
@@ -70,7 +70,6 @@ DataLoader:
       transform_ops:
         - DecodeImage:
             backend: pil
-            to_np: False
             channel_first: False
         - RandCropImage:
             size: 224


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
@@ -70,7 +70,6 @@ DataLoader:
       transform_ops:
         - DecodeImage:
             backend: pil
-            to_np: False
             channel_first: False
         - RandCropImage:
             size: 224


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
@@ -70,7 +70,6 @@ DataLoader:
       transform_ops:
         - DecodeImage:
             backend: pil
-            to_np: False
             channel_first: False
         - RandCropImage:
             size: 224


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
@@ -70,7 +70,6 @@ DataLoader:
       transform_ops:
         - DecodeImage:
             backend: pil
-            to_np: False
             channel_first: False
         - RandCropImage:
             size: 224


@@ -31,6 +31,7 @@ Arch:
   name: MobileNetV4_hybrid_large
   drop_rate: 0.2
   drop_path_rate: 0.1
+  use_fused_attn: False
   class_num: 1000
@@ -57,7 +58,7 @@ Optimizer:
   lr:
     # for 8 cards
     name: Cosine
-    learning_rate: 0.002
+    learning_rate: 0.001
     eta_min: 1.0e-06
     warmup_epoch: 20
     warmup_start_lr: 0
@@ -111,7 +112,7 @@ DataLoader:
             prob: 0.5
     sampler:
       name: DistributedBatchSampler
-      batch_size: 192
+      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
@@ -157,9 +158,9 @@ Infer:
         to_rgb: True
         channel_first: False
     - ResizeImage:
-        resize_short: 256
+        resize_short: 448
     - CropImage:
-        size: 224
+        size: 448
     - NormalizeImage:
         scale: 1.0/255.0
         mean: [0.485, 0.456, 0.406]
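For the hybrid_large variant, inference preprocessing moves from resize-short 256 plus a 224 crop to 448/448, matching the higher resolution this model is evaluated at; the smaller per-card batch (192 to 64) comes with a proportionally lower learning rate (0.002 to 0.001). A rough standalone equivalent of the new Infer transforms using paddle.vision.transforms (the std value is the standard ImageNet one, assumed here because the hunk is cut off before it):

from paddle.vision import transforms as T

infer_tfms = T.Compose([
    T.Resize(448),                  # ResizeImage: resize_short: 448
    T.CenterCrop(448),              # CropImage: size: 448
    T.ToTensor(),                   # scale: 1.0/255.0 and HWC -> CHW
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),  # assumed ImageNet std
])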


@@ -30,6 +30,7 @@ Arch:
   name: MobileNetV4_hybrid_medium
   drop_rate: 0.2
   drop_path_rate: 0.1
+  use_fused_attn: False
   class_num: 1000
@@ -108,7 +109,7 @@ DataLoader:
             prob: 0.5
     sampler:
       name: DistributedBatchSampler
-      batch_size: 512
+      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
@@ -154,7 +155,7 @@ Infer:
     - ResizeImage:
         resize_short: 256
     - CropImage:
-        size: 224
+        size: 256
     - NormalizeImage:
         scale: 1.0/255.0
         mean: [0.485, 0.456, 0.406]
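Both hybrid configs pin use_fused_attn: False under Arch, making the portable unfused attention path the default. PaddleClas forwards the Arch keys (other than name) as keyword arguments to the backbone constructor, so the YAML above behaves roughly like the call below (a sketch; the import path and the direct call are simplifying assumptions, not the actual PaddleClas factory):

from ppcls.arch.backbone import MobileNetV4_hybrid_medium  # assumed import path

model = MobileNetV4_hybrid_medium(
    drop_rate=0.2,
    drop_path_rate=0.1,
    class_num=1000,
    use_fused_attn=False,   # new flag from this commit: use the unfused path
)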


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True


@@ -17,7 +17,7 @@ Global:
 # mixed precision
 AMP:
-  use_amp: False
+  use_amp: True
   use_fp16_test: False
   scale_loss: 128.0
   use_dynamic_loss_scaling: True