diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index aa5bf26..c13db5b 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -368,7 +368,7 @@ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=384,
         depth=12,
-        num_heads=6,
+        num_heads=6,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
@@ -382,7 +382,7 @@ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=768,
         depth=12,
-        num_heads=12,
+        num_heads=12,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
@@ -396,7 +396,7 @@ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=1024,
         depth=24,
-        num_heads=16,
+        num_heads=16,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
@@ -410,7 +410,21 @@ def vit_huge(patch_size=16, num_register_tokens=0, **kwargs):
         patch_size=patch_size,
         embed_dim=1280,
         depth=32,
-        num_heads=16,
+        num_heads=16,  # embed_dim per head is 80
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_huge2(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1280,
+        depth=32,
+        num_heads=20,  # embed_dim per head is 64
         mlp_ratio=4,
         block_fn=partial(Block, attn_class=MemEffAttention),
         num_register_tokens=num_register_tokens,
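
A minimal usage sketch (not part of the patch): the new vit_huge2 factory keeps the 1280-dim embedding of vit_huge but uses 20 attention heads, so each head is 64-dim, matching the small/base/large variants. The argument values below are illustrative assumptions, not taken from the diff.

# Usage sketch only; patch_size=14 is an illustrative value.
from dinov2.models.vision_transformer import vit_huge, vit_huge2

huge = vit_huge(patch_size=14)    # 1280 dims / 16 heads -> 80 dims per head
huge2 = vit_huge2(patch_size=14)  # 1280 dims / 20 heads -> 64 dims per head

# The per-head width of vit_huge2 matches vit_small (384 / 6), vit_base
# (768 / 12) and vit_large (1024 / 16), all of which come out to 64.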