Merge pull request #1210 from cuicheng01/release/2.2

[Cherry-pick]fix tnt inference bug when bs > 1
pull/1297/head
cuicheng01 2021-09-07 14:15:47 +08:00 committed by GitHub
commit c363fa1df8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 4 deletions

View File

@ -193,9 +193,11 @@ class Block(nn.Layer):
self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))))
# outer
B, N, C = patch_embed.shape
patch_embed[:, 1:] = paddle.add(
patch_embed[:, 1:],
self.proj(self.norm1_proj(pixel_embed).reshape((B, N - 1, -1))))
norm1_proj = self.norm1_proj(pixel_embed)
norm1_proj = norm1_proj.reshape(
(B, N - 1, norm1_proj.shape[1] * norm1_proj.shape[2]))
patch_embed[:, 1:] = paddle.add(patch_embed[:, 1:],
self.proj(norm1_proj))
patch_embed = paddle.add(
patch_embed,
self.drop_path(self.attn_out(self.norm_out(patch_embed))))
@ -328,7 +330,7 @@ class TNT(nn.Layer):
ones_(m.weight)
def forward_features(self, x):
B = x.shape[0]
B = paddle.shape(x)[0]
pixel_embed = self.pixel_embed(x, self.pixel_pos)
patch_embed = self.norm2_proj(