fix concat error when fp16

parent 2c3ebe7b65
commit 4e988692dd

@@ -376,7 +376,7 @@ class PyramidVisionTransformer(nn.Layer):
         for i in range(len(self.depths)):
             x, (H, W) = self.patch_embeds[i](x)
             if i == len(self.depths) - 1:
-                cls_tokens = self.cls_token.expand([B, -1, -1])
+                cls_tokens = self.cls_token.expand([B, -1, -1]).astype(x.dtype)
                 x = paddle.concat([cls_tokens, x], dim=1)
             x = x + self.pos_embeds[i]
             x = self.pos_drops[i](x)

@@ -350,7 +350,9 @@ class TNT(nn.Layer):
             pixel_embed.reshape((-1, self.num_patches, pixel_embed.
                                  shape[-1] * pixel_embed.shape[-2])))))
         patch_embed = paddle.concat(
-            (self.cls_token.expand((B, -1, -1)), patch_embed), axis=1)
+            (self.cls_token.expand((B, -1, -1)).astype(patch_embed.dtype),
+             patch_embed),
+            axis=1)
         patch_embed = patch_embed + self.patch_pos
         patch_embed = self.pos_drop(patch_embed)
         for blk in self.blocks:

@@ -302,7 +302,7 @@ class VisionTransformer(nn.Layer):
         # B = x.shape[0]
         B = paddle.shape(x)[0]
         x = self.patch_embed(x)
-        cls_tokens = self.cls_token.expand((B, -1, -1))
+        cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype)
         x = paddle.concat((cls_tokens, x), axis=1)
         x = x + self.pos_embed
         x = self.pos_drop(x)
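
All three hunks apply the same pattern: the expanded cls_token parameter stays float32 while the surrounding activations can be float16, so paddle.concat receives mixed dtypes and fails; casting the token to the input's dtype first resolves it. Below is a minimal, hypothetical repro sketch (not part of this commit; shapes and names are illustrative only):

import paddle

# fp16 activations, as produced under AMP / fp16 inference
# (assumption: [batch, tokens, dim] = [1, 196, 768] is just an example shape).
x = paddle.randn([1, 196, 768]).astype("float16")

# The class token is a learned fp32 parameter, as in the models above.
cls_token = paddle.create_parameter(shape=[1, 1, 768], dtype="float32")
cls_tokens = cls_token.expand([1, -1, -1])

# paddle.concat([cls_tokens, x], axis=1)  # float32 vs float16 -> concat error

# The fix from this commit: cast the expanded token to the input dtype first.
cls_tokens = cls_tokens.astype(x.dtype)
out = paddle.concat([cls_tokens, x], axis=1)
print(out.dtype)  # paddle.float16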