Feature: add option global_pool='max' to VisionTransformer

Most CNNs offer a max global pooling option. This change extends ViT (VisionTransformer) with the same option, selected via global_pool='max'.
This commit is contained in:
Fernando Cossio 2024-06-14 15:24:54 +02:00 committed by GitHub
parent 22de845add
commit 9567cf6d84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -400,7 +400,7 @@ class VisionTransformer(nn.Module):
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
-global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
+global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
@@ -459,10 +459,10 @@ class VisionTransformer(nn.Module):
block_fn: Transformer block layer.
"""
super().__init__()
-assert global_pool in ('', 'avg', 'token', 'map')
+assert global_pool in ('', 'avg', 'max', 'token', 'map')
assert class_token or global_pool != 'token'
assert pos_embed in ('', 'none', 'learn')
-use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
+use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
act_layer = get_act_layer(act_layer) or nn.GELU
@@ -761,6 +761,8 @@ class VisionTransformer(nn.Module):
x = self.attn_pool(x)
elif self.global_pool == 'avg':
x = x[:, self.num_prefix_tokens:].mean(dim=1)
+elif self.global_pool == 'max':
+x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
elif self.global_pool:
x = x[:, 0] # class token
x = self.fc_norm(x)