Update gvt.py
parent e6eeb09a82
commit c4d131ebd3
@@ -78,9 +78,9 @@ class GroupAttention(nn.Layer):
         total_groups = h_group * w_group
         x = x.reshape([B, h_group, self.ws, w_group, self.ws, C]).transpose(
             [0, 1, 3, 2, 4, 5])
-        qkv = self.qkv(x).reshape(
-            [B, total_groups, -1, 3, self.num_heads,
-             C // self.num_heads]).transpose([3, 0, 1, 4, 2, 5])
+        qkv = self.qkv(x).reshape([
+            B, total_groups, self.ws**2, 3, self.num_heads, C // self.num_heads
+        ]).transpose([3, 0, 1, 4, 2, 5])
         q, k, v = qkv[0], qkv[1], qkv[2]
         attn = (q @k.transpose([0, 1, 2, 4, 3])) * self.scale

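Note: the edited reshape spells out the per-window token count `self.ws**2` instead of inferring it with `-1`, presumably to keep every dimension explicit (e.g., for static-shape export). Both forms give the same shape; a minimal sketch with toy sizes (values are hypothetical, not taken from gvt.py):

```python
import paddle

# Toy sizes standing in for the real model dimensions.
B, C, ws, num_heads = 2, 64, 4, 4
h_group, w_group = 2, 2                      # windows per spatial axis
total_groups = h_group * w_group

# Stand-in for self.qkv(x): one 3*C-dim vector per token in every window.
qkv_flat = paddle.randn([B, h_group, w_group, ws, ws, 3 * C])

old = qkv_flat.reshape([B, total_groups, -1, 3, num_heads, C // num_heads])
new = qkv_flat.reshape(
    [B, total_groups, ws**2, 3, num_heads, C // num_heads])
assert old.shape == new.shape  # [2, 4, 16, 3, 4, 16] either way
```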
@@ -135,14 +135,15 @@ class Attention(nn.Layer):

         if self.sr_ratio > 1:
             x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
-            x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
+            tmp_n = H * W // self.sr_ratio**2
+            x_ = self.sr(x_).reshape([B, C, tmp_n]).transpose([0, 2, 1])
             x_ = self.norm(x_)
             kv = self.kv(x_).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, tmp_n, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         else:
             kv = self.kv(x).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
+                [B, N, 2, self.num_heads, C // self.num_heads]).transpose(
                     [2, 0, 3, 1, 4])
         k, v = kv[0], kv[1]

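The new `tmp_n` is the token count left after the spatial-reduction convolution: a stride-`sr_ratio` conv shrinks an `H x W` map to `(H // sr_ratio) x (W // sr_ratio)`, i.e. `H * W // sr_ratio**2` tokens. A toy check of that arithmetic (sizes are hypothetical; `sr` stands in for the module's spatial-reduction conv):

```python
import paddle
import paddle.nn as nn

B, C, H, W, sr_ratio = 2, 64, 16, 16, 4
x = paddle.randn([B, H * W, C])                    # token sequence
x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])  # back to a feature map

# Stand-in for self.sr: kernel == stride == sr_ratio, as in PVT-style attention.
sr = nn.Conv2D(C, C, kernel_size=sr_ratio, stride=sr_ratio)

tmp_n = H * W // sr_ratio**2                       # 256 // 16 = 16 tokens
x_ = sr(x_).reshape([B, C, tmp_n]).transpose([0, 2, 1])
assert x_.shape == [B, tmp_n, C]
```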
@@ -317,7 +318,6 @@ class PyramidVisionTransformer(nn.Layer):
                 self.create_parameter(
                     shape=[1, patch_num, embed_dims[i]],
                     default_initializer=zeros_))
             self.add_parameter(f"pos_embeds_{i}", self.pos_embeds[i])
             self.pos_drops.append(nn.Dropout(p=drop_rate))

         dpr = [
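For readers unfamiliar with the registration pattern in this hunk: `create_parameter` builds the positional-embedding tensor and `add_parameter` registers it under an explicit name so it appears in `state_dict()`. A minimal sketch of the same pattern with a hypothetical layer and toy sizes (only the registration calls are taken from the hunk above):

```python
import paddle.nn as nn
from paddle.nn.initializer import Constant

zeros_ = Constant(value=0.0)  # same role as the zeros_ initializer used above

class ToyStage(nn.Layer):
    def __init__(self, patch_num=196, embed_dim=64, drop_rate=0.0):
        super().__init__()
        self.pos_embeds, self.pos_drops = [], nn.LayerList()
        self.pos_embeds.append(
            self.create_parameter(
                shape=[1, patch_num, embed_dim],
                default_initializer=zeros_))
        # Without add_parameter, a tensor kept only in a Python list would
        # not be tracked by the layer.
        self.add_parameter("pos_embeds_0", self.pos_embeds[0])
        self.pos_drops.append(nn.Dropout(p=drop_rate))

print(list(ToyStage().state_dict().keys()))  # ['pos_embeds_0']
```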
@@ -350,7 +350,6 @@ class PyramidVisionTransformer(nn.Layer):
             shape=[1, 1, embed_dims[-1]],
             default_initializer=zeros_,
             attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
         self.add_parameter("cls_token", self.cls_token)

         # classification head
         self.head = nn.Linear(embed_dims[-1],
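The `ParamAttr(regularizer=L2Decay(0.0))` on the class token gives that single parameter a zero-coefficient L2 regularizer, which in Paddle takes precedence over optimizer-level regularization, so the cls_token is effectively excluded from weight decay. A minimal sketch of the same registration (toy embedding width, hypothetical layer name):

```python
import paddle
import paddle.nn as nn
from paddle.nn.initializer import Constant
from paddle.regularizer import L2Decay

zeros_ = Constant(value=0.0)

class ToyToken(nn.Layer):
    def __init__(self, embed_dim=64):
        super().__init__()
        # Zero-coefficient L2Decay overrides any optimizer-level weight decay
        # for this one parameter.
        self.cls_token = self.create_parameter(
            shape=[1, 1, embed_dim],
            default_initializer=zeros_,
            attr=paddle.ParamAttr(regularizer=L2Decay(0.0)))
        self.add_parameter("cls_token", self.cls_token)
```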
@@ -433,7 +432,7 @@ class CPVTV2(PyramidVisionTransformer):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256, 512],
                  num_heads=[1, 2, 4, 8],
                  mlp_ratios=[4, 4, 4, 4],
@@ -446,7 +445,7 @@ class CPVTV2(PyramidVisionTransformer):
                  depths=[3, 4, 6, 3],
                  sr_ratios=[8, 4, 2, 1],
                  block_cls=Block):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)
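This hunk and the PCPVT hunks below only rename the constructor argument from `class_num` to `num_classes` and thread it through `super().__init__`. A usage sketch of the caller-side impact, assuming CPVTV2 is imported from gvt.py; the 224x224 input and [1, 1000] output shape are assumptions, not taken from this diff:

```python
import paddle

# Before this commit the keyword was class_num=1000.
model = CPVTV2(
    img_size=224,
    patch_size=4,
    in_chans=3,
    num_classes=1000,                  # renamed keyword
    embed_dims=[64, 128, 256, 512],
    num_heads=[1, 2, 4, 8],
    mlp_ratios=[4, 4, 4, 4])

x = paddle.randn([1, 3, 224, 224])
logits = model(x)                      # expected shape: [1, 1000]
```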
@@ -488,7 +487,7 @@ class CPVTV2(PyramidVisionTransformer):
             x = self.pos_block[i](x, H, W)  # PEG here

             if i < len(self.depths) - 1:
-                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+                x = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])

         x = self.norm(x)
         return x.mean(axis=1)  # GAP here
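As in the earlier hunks, `-1` is swapped for a concrete size, here the channel count `x.shape[-1]`; both forms produce the same tensor. A toy check (hypothetical sizes, not from gvt.py):

```python
import paddle

B, H, W, C = 2, 7, 7, 64
x = paddle.randn([B, H * W, C])        # H*W tokens per image

a = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
b = x.reshape([B, H, W, x.shape[-1]]).transpose([0, 3, 1, 2])
assert a.shape == b.shape == [B, C, H, W]
```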
@@ -499,7 +498,7 @@ class PCPVT(CPVTV2):
                  img_size=224,
                  patch_size=4,
                  in_chans=3,
-                 class_num=1000,
+                 num_classes=1000,
                  embed_dims=[64, 128, 256],
                  num_heads=[1, 2, 4],
                  mlp_ratios=[4, 4, 4],
@@ -512,7 +511,7 @@ class PCPVT(CPVTV2):
                  depths=[4, 4, 4],
                  sr_ratios=[4, 2, 1],
                  block_cls=SBlock):
-        super().__init__(img_size, patch_size, in_chans, class_num,
+        super().__init__(img_size, patch_size, in_chans, num_classes,
                          embed_dims, num_heads, mlp_ratios, qkv_bias, qk_scale,
                          drop_rate, attn_drop_rate, drop_path_rate, norm_layer,
                          depths, sr_ratios, block_cls)