From ee819393ac018972a210b382caef84bf3f614a0e Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Wed, 15 Dec 2021 10:29:36 +0000 Subject: [PATCH] fix: fix -1 in dims and ignore swapdim in static --- ppcls/arch/backbone/model_zoo/pvt_v2.py | 30 ++++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/ppcls/arch/backbone/model_zoo/pvt_v2.py b/ppcls/arch/backbone/model_zoo/pvt_v2.py index 7c4749e71..94175754f 100644 --- a/ppcls/arch/backbone/model_zoo/pvt_v2.py +++ b/ppcls/arch/backbone/model_zoo/pvt_v2.py @@ -130,24 +130,32 @@ class Attention(nn.Layer): if not self.linear: if self.sr_ratio > 1: x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W]) - x_ = self.sr(x_).reshape([B, C, -1]).transpose([0, 2, 1]) + x_ = self.sr(x_) + h_, w_ = x_.shape[-2:] + x_ = x_.reshape([B, C, h_ * w_]).transpose([0, 2, 1]) x_ = self.norm(x_) - kv = self.kv(x_).reshape( - [B, -1, 2, self.num_heads, C // self.num_heads]).transpose( - [2, 0, 3, 1, 4]) + kv = self.kv(x_) + kv = kv.reshape([ + B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads, + C // self.num_heads + ]).transpose([2, 0, 3, 1, 4]) else: - kv = self.kv(x).reshape( - [B, -1, 2, self.num_heads, C // self.num_heads]).transpose( - [2, 0, 3, 1, 4]) + kv = self.kv(x) + kv = kv.reshape([ + B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads, + C // self.num_heads + ]).transpose([2, 0, 3, 1, 4]) else: x_ = x.transpose([0, 2, 1]).reshape([B, C, H, W]) x_ = self.sr(self.pool(x_)).reshape([B, C, -1]).transpose( [0, 2, 1]) x_ = self.norm(x_) x_ = self.act(x_) - kv = self.kv(x_).reshape( - [B, -1, 2, self.num_heads, C // self.num_heads]).transpose( - [2, 0, 3, 1, 4]) + kv = self.kv(x_) + kv = kv.reshape([ + B, kv.shape[2] * kv.shape[1] // 2 // C, 2, self.num_heads, + C // self.num_heads + ]).transpose([2, 0, 3, 1, 4]) k, v = kv[0], kv[1] attn = (q @swapdim(k, -2, -1)) * self.scale @@ -324,7 +332,7 @@ class PyramidVisionTransformerV2(nn.Layer): x = blk(x, H, W) x = norm(x) if i != self.num_stages - 1: - x = x.reshape([B, H, W, -1]).transpose([0, 3, 1, 2]) + x = x.reshape([B, H, W, x.shape[2]]).transpose([0, 3, 1, 2]) return x.mean(axis=1)