Use global pool arg to select attention pooling in head
parent 82cc53237e
commit 99cfd6702f
@@ -470,7 +470,6 @@ class VisionTransformer(nn.Module):
             reg_tokens: int = 0,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
-            use_attn_pool: bool = False,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
@@ -514,8 +513,8 @@ class VisionTransformer(nn.Module):
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'token')
-        assert class_token or use_attn_pool or global_pool != 'token'
+        assert global_pool in ('', 'avg', 'token', 'map')
+        assert class_token or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
@@ -580,7 +579,7 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

         # Classifier Head
-        if use_attn_pool:
+        if global_pool == 'map':
             self.attn_pool = AttentionPoolLatent(
                 self.embed_dim,
                 num_heads=num_heads,
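The hunk above carries the core behaviour change: rather than a separate use_attn_pool flag, the existing global_pool string now selects the attention-pool head when set to 'map'. Below is a minimal, illustrative sketch of that selection pattern in plain PyTorch; the PooledHead module and its use of nn.MultiheadAttention are stand-ins for timm's AttentionPoolLatent, not the file's actual code.

import torch
import torch.nn as nn


class PooledHead(nn.Module):
    """Illustrative head that keys its pooling mode off a single global_pool string."""

    def __init__(self, embed_dim: int, num_classes: int, global_pool: str = 'token'):
        super().__init__()
        assert global_pool in ('', 'avg', 'token', 'map')
        self.global_pool = global_pool
        if global_pool == 'map':
            # Stand-in for AttentionPoolLatent: a learned latent query attends over all tokens.
            self.latent = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.attn_pool = nn.MultiheadAttention(embed_dim, num_heads=8, batch_first=True)
        else:
            self.attn_pool = None
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, C) token sequence from the transformer trunk
        if self.global_pool == 'map':
            q = self.latent.expand(x.shape[0], -1, -1)
            x, _ = self.attn_pool(q, x, x)   # (B, 1, C)
            x = x[:, 0]
        elif self.global_pool == 'avg':
            x = x.mean(dim=1)
        elif self.global_pool == 'token':
            x = x[:, 0]                      # class token
        return self.head(x)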
@@ -2243,7 +2242,7 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True,
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
     )
     model = _create_vision_transformer(
         'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
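On the caller side, configs that previously set use_attn_pool=True now pass global_pool='map', as the SigLIP entry above does. A hedged usage sketch against timm's create_model factory (assuming a timm build that includes this change; the explicit override works because **kwargs is merged over model_args in the registration above):

import timm

# The registered config already selects the attention-pool ('map') head.
model = timm.create_model('vit_base_patch16_siglip_224', pretrained=False)

# Equivalent explicit override; use_attn_pool is no longer a VisionTransformer
# argument, so the pooling mode is carried by global_pool alone.
model = timm.create_model(
    'vit_base_patch16_siglip_224', pretrained=False, global_pool='map')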