fix: fix -1 in dims and ignore swapdim in static

pull/1632/head
gaotingquan 2021-12-15 10:29:36 +00:00 committed by Tingquan Gao
parent 00fb3f7519
commit ee819393ac
1 changed file with 19 additions and 11 deletions
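The pattern throughout the patch is to replace reshape dims written as -1 with values computed from the tensor's own shape list, since an inferred -1 dim can come through as an unknown size when the model is exported to a static graph. A minimal before/after sketch of the pattern, with made-up sizes (the variable names here are illustrative, not from the patch):

import paddle

feat = paddle.randn([2, 64, 7, 7])  # [B, C, h, w], e.g. the output of self.sr
B, C = feat.shape[0], feat.shape[1]
# Before: -1 asks the framework to infer h * w at runtime.
flat_inferred = feat.reshape([B, C, -1])
# After: read h and w off the tensor and spell the product out.
h, w = feat.shape[-2:]
flat_explicit = feat.reshape([B, C, h * w])
assert flat_inferred.shape == flat_explicit.shape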


@@ -130,24 +130,32 @@ class Attention(nn.Layer):
         if not self.linear:
             if self.sr_ratio > 1:
                 x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
-                x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1])
+                x_ = self.sr(x_)
+                h_, w_ = x_.shape[-2:]
+                x_ = x_.reshape([B, C, h_ * w_]).transpose([0, 2, 1])
                 x_ = self.norm(x_)
-                kv = self.kv(x_).reshape(
-                    [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
-                        [2, 0, 3, 1, 4])
+                kv = self.kv(x_)
+                kv = kv.reshape([
+                    B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads,
+                    C // self.num_heads
+                ]).transpose([2, 0, 3, 1, 4])
             else:
-                kv = self.kv(x).reshape(
-                    [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
-                        [2, 0, 3, 1, 4])
+                kv = self.kv(x)
+                kv = kv.reshape([
+                    B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads,
+                    C // self.num_heads
+                ]).transpose([2, 0, 3, 1, 4])
         else:
             x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W])
             x_ = self.sr(self.pool(x_)).reshape([B, C, -1]).transpose(
                 [0, 2, 1])
             x_ = self.norm(x_)
             x_ = self.act(x_)
-            kv = self.kv(x_).reshape(
-                [B, -1, 2, self.num_heads, C // self.num_heads]).transpose(
-                    [2, 0, 3, 1, 4])
+            kv = self.kv(x_)
+            kv = kv.reshape([
+                B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads,
+                C // self.num_heads
+            ]).transpose([2, 0, 3, 1, 4])
         k, v = kv[0], kv[1]
         attn = (q @swapdim(k, -2, -1)) * self.scale
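The new kv reshape is shape-equivalent to the old -1 version: assuming self.kv is a Linear(dim, dim * 2) as in PVTv2, kv comes out as [B, N, 2 * C], so kv.shape[2] * kv.shape[1] // 2 // C reduces to (2 * C * N) // (2 * C) == N, exactly the dim that -1 used to infer. A quick numeric check with assumed sizes:

B, N, C, num_heads = 2, 196, 64, 8
kv_shape = [B, N, 2 * C]  # output shape of self.kv(x_)
n_tokens = kv_shape[2] * kv_shape[1] // 2 // C
assert n_tokens == N  # matches what -1 would have inferred
# Target layout: [B, N, 2, num_heads, C // num_heads], transposed to
# [2, B, num_heads, N, C // num_heads] so that k, v = kv[0], kv[1] works.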
@@ -324,7 +332,7 @@ class PyramidVisionTransformerV2(nn.Layer):
                 x = blk(x, H, W)
             x = norm(x)
             if i != self.num_stages - 1:
-                x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2])
+                x = x.reshape([B, H, W, x.shape[2]]).transpose([0, 3, 1, 2])
 
         return x.mean(axis=1)
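swapdim in the attention hunk is a helper defined elsewhere in the file, not a Paddle builtin, and the commit title suggests it gets special treatment in static mode outside the lines shown here. A plausible minimal implementation, offered as an assumption since its definition is not part of this diff:

import paddle

def swapdim(x, dim1, dim2):
    # Swap two axes via a full permutation; the Paddle analogue of
    # torch.transpose(x, dim1, dim2).
    perm = list(range(len(x.shape)))
    perm[dim1], perm[dim2] = perm[dim2], perm[dim1]
    return paddle.transpose(x, perm)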