diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 99af178c..1225e10a 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -470,7 +470,6 @@ class VisionTransformer(nn.Module):
             reg_tokens: int = 0,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
-            use_attn_pool: bool = False,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
@@ -514,8 +513,8 @@ class VisionTransformer(nn.Module):
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'token')
-        assert class_token or use_attn_pool or global_pool != 'token'
+        assert global_pool in ('', 'avg', 'token', 'map')
+        assert class_token or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
@@ -580,7 +579,7 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
 
         # Classifier Head
-        if use_attn_pool:
+        if global_pool == 'map':
             self.attn_pool = AttentionPoolLatent(
                 self.embed_dim,
                 num_heads=num_heads,
@@ -2243,7 +2242,7 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransfor
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True,
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
     )
     model = _create_vision_transformer(
         'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
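
Usage note: a minimal sketch of how the new option is exercised, assuming a timm checkout that includes this change. The SigLIP ViT-B/16 registration now requests latent attention pooling via global_pool='map' rather than the removed use_attn_pool flag, so building the model through timm.create_model picks up AttentionPoolLatent automatically.

import torch
import timm

# Sketch only: build the SigLIP ViT-B/16 as registered above.
# With this change, global_pool='map' in the model args causes the
# constructor to create an AttentionPoolLatent module for pooling.
model = timm.create_model('vit_base_patch16_siglip_224', pretrained=False)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(x)  # features are attention-pooled before the classifier head
print(out.shape)

The same global_pool='map' value can be passed as a kwarg override when constructing other VisionTransformer configs, since it now flows through the standard global_pool argument instead of a separate flag.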