mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
change for easy deployment of segmenter (#1642)
This commit is contained in:
parent
46326f63ce
commit
60655de194
@ -386,13 +386,13 @@ class VisionTransformer(BaseModule):
|
|||||||
"""
|
"""
|
||||||
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
|
||||||
pos_h, pos_w = pos_shape
|
pos_h, pos_w = pos_shape
|
||||||
cls_token_weight = pos_embed[:, 0]
|
# keep dim for easy deployment
|
||||||
|
cls_token_weight = pos_embed[:, 0:1]
|
||||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
|
||||||
pos_embed_weight = pos_embed_weight.reshape(
|
pos_embed_weight = pos_embed_weight.reshape(
|
||||||
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||||
pos_embed_weight = resize(
|
pos_embed_weight = resize(
|
||||||
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
|
||||||
cls_token_weight = cls_token_weight.unsqueeze(1)
|
|
||||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||||
return pos_embed
|
return pos_embed
|
||||||
|
Loading…
x
Reference in New Issue
Block a user