fix: avoid -1 in reshape dims and ignore swapdim in static mode
parent 00fb3f7519
commit ee819393ac
@@ -130,24 +130,32 @@ class Attention(nn.Layer):
         if not self.linear:
             if self.sr_ratio > 1:
                 x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
-                x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
+                x_ = self.sr(x_)
+                h_, w_ = x_.shape[-2:]
+                x_ = x_.reshape([B, C, h_ * w_]).transpose([0, 2, 1])
                 x_ = self.norm(x_)
-                kv = self.kv(x_).reshape(
-                    [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
-                        [2, 0, 3, 1, 4])
+                kv = self.kv(x_)
+                kv = kv.reshape([
+                    B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads,
+                    C // self.num_heads
+                ]).transpose([2, 0, 3, 1, 4])
             else:
-                kv = self.kv(x).reshape(
-                    [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
-                        [2, 0, 3, 1, 4])
+                kv = self.kv(x)
+                kv = kv.reshape([
+                    B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads,
+                    C // self.num_heads
+                ]).transpose([2, 0, 3, 1, 4])
         else:
             x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
             x_ = self.sr(self.pool(x_)).reshape([B, C, -1]).transpose(
                 [0, 2, 1])
             x_ = self.norm(x_)
             x_ = self.act(x_)
-            kv = self.kv(x_).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
-                    [2, 0, 3, 1, 4])
+            kv = self.kv(x_)
+            kv = kv.reshape([
+                B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads,
+                C // self.num_heads
+            ]).transpose([2, 0, 3, 1, 4])
         k, v = kv[0], kv[1]
 
         attn = (q @swapdim(k, -2, -1)) * self.scale
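The hunk above replaces reshape calls that infer one dimension with -1 by reshapes whose every dimension is computed from the tensor's shape, which is what static-graph conversion needs. A minimal sketch, assuming hypothetical sizes and a throwaway Linear layer standing in for self.kv (not the commit's code), showing that kv.shape[2] * kv.shape[1] // 2 // C recovers the same sequence length that -1 used to infer:

import paddle

# Hypothetical sizes, chosen only for illustration.
B, N, C, num_heads = 2, 196, 64, 8

x_ = paddle.randn([B, N, C])
kv = paddle.nn.Linear(C, C * 2)(x_)  # [B, N, 2 * C], stand-in for self.kv(x_)

# Old form: let reshape infer the sequence length with -1.
kv_old = kv.reshape([B, -1, 2, num_heads, C // num_heads])

# New form: compute the sequence length from the tensor's own shape so that
# every target dimension is concrete: (2 * C) * N // 2 // C == N.
seq_len = kv.shape[2] * kv.shape[1] // 2 // C
kv_new = kv.reshape([B, seq_len, 2, num_heads, C // num_heads])

assert kv_old.shape == kv_new.shape == [B, N, 2, num_heads, C // num_heads]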
@@ -324,7 +332,7 @@ class PyramidVisionTransformerV2(nn.Layer):
                 x = blk(x, H, W)
             x = norm(x)
             if i != self.num_stages - 1:
-                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+                x = x.reshape([B, H, W, x.shape[2]]).transpose([0, 3, 1, 2])
 
         return x.mean(axis=1)
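The same idea applies to the stage-output reshape above: x.shape[2] replaces -1 so the target shape stays fully specified. A small sketch of the pattern, using a toy function traced with paddle.jit.to_static and assumed shapes (not the commit's code):

import paddle

def stage_reshape(x, H, W):
    B = x.shape[0]
    # Explicit channel dim (x.shape[2]) instead of -1 keeps the reshape
    # fully specified when the function is traced to a static graph.
    return x.reshape([B, H, W, x.shape[2]]).transpose([0, 3, 1, 2])

static_reshape = paddle.jit.to_static(stage_reshape)
out = static_reshape(paddle.randn([2, 49, 64]), 7, 7)
print(out.shape)  # [2, 64, 7, 7]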