Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 99cfd6702f: Use global pool arg to select attention pooling in head
Parent: 82cc53237e
@@ -470,7 +470,6 @@ class VisionTransformer(nn.Module):
             reg_tokens: int = 0,
             pre_norm: bool = False,
             fc_norm: Optional[bool] = None,
-            use_attn_pool: bool = False,
             dynamic_img_size: bool = False,
             dynamic_img_pad: bool = False,
             drop_rate: float = 0.,
@@ -514,8 +513,8 @@ class VisionTransformer(nn.Module):
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'token')
-        assert class_token or use_attn_pool or global_pool != 'token'
+        assert global_pool in ('', 'avg', 'token', 'map')
+        assert class_token or global_pool != 'token'
         use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
         act_layer = act_layer or nn.GELU
@@ -580,7 +579,7 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

         # Classifier Head
-        if use_attn_pool:
+        if global_pool == 'map':
             self.attn_pool = AttentionPoolLatent(
                 self.embed_dim,
                 num_heads=num_heads,
@@ -2243,7 +2242,7 @@ def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransfor
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, use_attn_pool=True,
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
     )
     model = _create_vision_transformer(
         'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
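
With this change, attention pooling in the classifier head is requested through the existing global_pool argument (value 'map') instead of a separate use_attn_pool flag. A minimal usage sketch, assuming a timm checkout that includes this commit; the model name and pooling value come from the diff above, while the input size and num_classes are illustrative:

import torch
import timm

# 'map' selects the AttentionPoolLatent head; valid values after this commit
# are '', 'avg', 'token', and 'map' (see the assert in __init__ above).
model = timm.create_model(
    'vit_base_patch16_siglip_224',  # registered above with global_pool='map'
    pretrained=False,
    num_classes=10,
)

x = torch.randn(1, 3, 224, 224)  # dummy batch at the model's 224x224 input size
out = model(x)
print(out.shape)  # expected: torch.Size([1, 10])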