[Refactor] Add ln to vit avg_featmap output (#1447)

pull/1464/head
Yixiao Fang 2023-04-06 11:59:39 +08:00 committed by GitHub
parent 3a25b13eb3
commit 75dceaa78f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 3 additions and 1 deletion

View File

@@ -338,6 +338,8 @@ class VisionTransformer(BaseBackbone):
self.final_norm = final_norm
if final_norm:
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
+ if self.out_type == 'avg_featmap':
+ self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
@@ -454,7 +456,7 @@ class VisionTransformer(BaseBackbone):
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
- return patch_token.mean(dim=1)
+ return self.ln2(patch_token.mean(dim=1))
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.