[Refactor] Add ln to vit avg_featmap output (#1447)
parent
3a25b13eb3
commit
75dceaa78f
|
@ -338,6 +338,8 @@ class VisionTransformer(BaseBackbone):
|
|||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
|
||||
if self.out_type == 'avg_featmap':
|
||||
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
|
||||
|
||||
# freeze stages only when self.frozen_stages > 0
|
||||
if self.frozen_stages > 0:
|
||||
|
@ -454,7 +456,7 @@ class VisionTransformer(BaseBackbone):
|
|||
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
|
||||
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
|
||||
if self.out_type == 'avg_featmap':
|
||||
return patch_token.mean(dim=1)
|
||||
return self.ln2(patch_token.mean(dim=1))
|
||||
|
||||
def get_layer_depth(self, param_name: str, prefix: str = ''):
|
||||
"""Get the layer-wise depth of a parameter.
|
||||
|
|
Loading…
Reference in New Issue