From 9567cf6d84012f1fcfe95c55f5233d32ad7150f9 Mon Sep 17 00:00:00 2001 From: Fernando Cossio <39391180+fcossio@users.noreply.github.com> Date: Fri, 14 Jun 2024 15:24:54 +0200 Subject: [PATCH] Feature: add option global_pool='max' to VisionTransformer Most of the CNNs have a max global pooling option. I would like to extend ViT to have this option. --- timm/models/vision_transformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index a3ca0990..5beb77a2 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -400,7 +400,7 @@ class VisionTransformer(nn.Module): patch_size: Union[int, Tuple[int, int]] = 16, in_chans: int = 3, num_classes: int = 1000, - global_pool: Literal['', 'avg', 'token', 'map'] = 'token', + global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token', embed_dim: int = 768, depth: int = 12, num_heads: int = 12, @@ -459,10 +459,10 @@ class VisionTransformer(nn.Module): block_fn: Transformer block layer. """ super().__init__() - assert global_pool in ('', 'avg', 'token', 'map') + assert global_pool in ('', 'avg', 'max', 'token', 'map') assert class_token or global_pool != 'token' assert pos_embed in ('', 'none', 'learn') - use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) act_layer = get_act_layer(act_layer) or nn.GELU @@ -761,6 +761,8 @@ class VisionTransformer(nn.Module): x = self.attn_pool(x) elif self.global_pool == 'avg': x = x[:, self.num_prefix_tokens:].mean(dim=1) + elif self.global_pool == 'max': + x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1) elif self.global_pool: x = x[:, 0] # class token x = self.fc_norm(x)