change for easy deployment of segmenter (#1642)
parent
46326f63ce
commit
60655de194
|
@ -386,13 +386,13 @@ class VisionTransformer(BaseModule):
|
|||
"""
|
||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||
pos_h, pos_w = pos_shape
|
||||
cls_token_weight = pos_embed[:, 0]
|
||||
# keep dim for easy deployment
|
||||
cls_token_weight = pos_embed[:, 0:1]
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||
pos_embed_weight = pos_embed_weight.reshape(
|
||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(
|
||||
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
||||
cls_token_weight = cls_token_weight.unsqueeze(1)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
||||
|
|
Loading…
Reference in New Issue