Use global pool arg to select attention pooling in head

vit_siglip_and_reg
Ross Wightman 2023-09-30 16:16:21 -07:00
parent 82cc53237e
commit 99cfd6702f
1 changed files with 4 additions and 5 deletions

View File

@ -470,7 +470,6 @@ class VisionTransformer(nn.Module):
reg_tokens: int = 0,
pre_norm: bool = False,
fc_norm: Optional[bool] = None,
use_attn_pool: bool = False,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
drop_rate: float = 0.,
@ -514,8 +513,8 @@ class VisionTransformer(nn.Module):
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
assert class_token or use_attn_pool or global_pool != 'token'
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
@ -580,7 +579,7 @@ class VisionTransformer(nn.Module):
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
if use_attn_pool == 'map':
if global_pool == 'pool':
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=num_heads,
@ -2243,7 +2242,7 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransfor
@register_model
def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True,
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
)
model = _create_vision_transformer(
'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))