ViT-H/14 variant (#6)

pull/509/head
Han Chong, 2025-02-27 17:54:50 +08:00 (committed by GitHub)
parent 391342d78d
commit 0909dbc557
1 changed file with 18 additions and 4 deletions

@@ -368,7 +368,7 @@ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=384,
         depth=12,
-        num_heads=6,
+        num_heads=6,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
@@ -382,7 +382,7 @@ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=768,
         depth=12,
-        num_heads=12,
+        num_heads=12,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
@@ -396,7 +396,7 @@ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=1024,
         depth=24,
-        num_heads=16,
+        num_heads=16,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
@@ -410,7 +410,21 @@ def vit_huge(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=1280,
         depth=32,
-        num_heads=16,
+        num_heads=16,  # embed_dim per head is 80
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_huge2(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1280,
+        depth=32,
+        num_heads=20,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
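
A note on the commented numbers: the per-head width is simply embed_dim / num_heads. ViT-S/B/L all come out to 64-dim heads (384/6, 768/12, 1024/16), while vit_huge is the outlier at 1280/16 = 80; the new vit_huge2 keeps embed_dim=1280 but restores head_dim = 64 by using 20 heads. A minimal sketch of that arithmetic (the variants table below is assembled from the diff above for illustration and is not part of the patch):

# Illustrative only: per-head dim = embed_dim // num_heads for each variant.
variants = {
    "vit_small": (384, 6),    # 384 / 6   = 64
    "vit_base": (768, 12),    # 768 / 12  = 64
    "vit_large": (1024, 16),  # 1024 / 16 = 64
    "vit_huge": (1280, 16),   # 1280 / 16 = 80, the outlier
    "vit_huge2": (1280, 20),  # 1280 / 20 = 64, the new variant
}
for name, (embed_dim, num_heads) in variants.items():
    assert embed_dim % num_heads == 0, f"{name}: heads must divide embed_dim"
    print(f"{name}: head_dim = {embed_dim // num_heads}")

Running this prints head_dim = 64 for every variant except vit_huge, which prints 80.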