From 4e988692dd8a58b462985a543a77ec2cc97ce691 Mon Sep 17 00:00:00 2001
From: gaotingquan
Date: Mon, 20 Mar 2023 11:38:44 +0000
Subject: [PATCH] fix concat error when fp16

---
 ppcls/arch/backbone/model_zoo/gvt.py                | 2 +-
 ppcls/arch/backbone/model_zoo/tnt.py                | 4 +++-
 ppcls/arch/backbone/model_zoo/vision_transformer.py | 2 +-
 3 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/ppcls/arch/backbone/model_zoo/gvt.py b/ppcls/arch/backbone/model_zoo/gvt.py
index d1afbecaa..3e0592389 100644
--- a/ppcls/arch/backbone/model_zoo/gvt.py
+++ b/ppcls/arch/backbone/model_zoo/gvt.py
@@ -376,7 +376,7 @@ class PyramidVisionTransformer(nn.Layer):
         for i in range(len(self.depths)):
             x, (H, W) = self.patch_embeds[i](x)
             if i == len(self.depths) - 1:
-                cls_tokens = self.cls_token.expand([B, -1, -1])
+                cls_tokens = self.cls_token.expand([B, -1, -1]).astype(x.dtype)
                 x = paddle.concat([cls_tokens, x], dim=1)
             x = x + self.pos_embeds[i]
             x = self.pos_drops[i](x)
diff --git a/ppcls/arch/backbone/model_zoo/tnt.py b/ppcls/arch/backbone/model_zoo/tnt.py
index c313a1402..6025a1d2d 100644
--- a/ppcls/arch/backbone/model_zoo/tnt.py
+++ b/ppcls/arch/backbone/model_zoo/tnt.py
@@ -350,7 +350,9 @@ class TNT(nn.Layer):
             pixel_embed.reshape((-1, self.num_patches, pixel_embed.
                                  shape[-1] * pixel_embed.shape[-2])))))
         patch_embed = paddle.concat(
-            (self.cls_token.expand((B, -1, -1)), patch_embed), axis=1)
+            (self.cls_token.expand((B, -1, -1)).astype(patch_embed.dtype),
+             patch_embed),
+            axis=1)
         patch_embed = patch_embed + self.patch_pos
         patch_embed = self.pos_drop(patch_embed)
         for blk in self.blocks:
diff --git a/ppcls/arch/backbone/model_zoo/vision_transformer.py b/ppcls/arch/backbone/model_zoo/vision_transformer.py
index fbec1fcb4..e12b66e82 100644
--- a/ppcls/arch/backbone/model_zoo/vision_transformer.py
+++ b/ppcls/arch/backbone/model_zoo/vision_transformer.py
@@ -302,7 +302,7 @@ class VisionTransformer(nn.Layer):
         # B = x.shape[0]
         B = paddle.shape(x)[0]
         x = self.patch_embed(x)
-        cls_tokens = self.cls_token.expand((B, -1, -1))
+        cls_tokens = self.cls_token.expand((B, -1, -1)).astype(x.dtype)
         x = paddle.concat((cls_tokens, x), axis=1)
         x = x + self.pos_embed
         x = self.pos_drop(x)
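
Background on the fix: under fp16 execution, activations such as the patch-embedding output are float16, while the cls_token parameter stays float32, and paddle.concat requires all of its inputs to share one dtype. The following is a minimal standalone sketch of the failure mode and of the cast the patch applies; shapes and values are illustrative only, not taken from the repo.

    import paddle

    # A float32 "parameter" alongside float16 activations, as happens
    # under fp16 inference (shapes are hypothetical).
    cls_token = paddle.zeros([1, 1, 768], dtype="float32")   # stays fp32
    x = paddle.randn([4, 196, 768]).astype("float16")        # fp16 features

    cls_tokens = cls_token.expand([4, -1, -1])
    # Concatenating float32 with float16 inputs raises a dtype error:
    #   paddle.concat([cls_tokens, x], axis=1)

    # The patch's fix: cast the expanded token to the runtime dtype of x.
    cls_tokens = cls_tokens.astype(x.dtype)
    out = paddle.concat([cls_tokens, x], axis=1)   # shape [4, 197, 768]

Casting via x.dtype rather than hard-coding float16 keeps the same code path correct in both fp32 and fp16 runs, since astype to the tensor's own dtype is effectively a no-op.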