diff --git a/dinov2/hub/backbones.py b/dinov2/hub/backbones.py
--- a/dinov2/hub/backbones.py
+++ b/dinov2/hub/backbones.py
@@ -9,7 +9,8 @@ from typing import Union
 import torch
 
 from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+from dinov2.layers.attention import Attention
 
 
 class Weights(Enum):
     LVD142M = "LVD142M"
@@ -61,38 +61,55 @@ def _make_dinov2_model(
     return model
 
 
-def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+def dinov2_vits14(*, pretrained: bool = True, for_onnx: bool = False, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
     """
     DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
     """
+    if for_onnx:
+        # Swap in the plain Attention class for ONNX export; `weights` is still
+        # forwarded so an explicit weights= argument is not silently ignored.
+        return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, attn_class=Attention, **kwargs)
     return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
 
 
-def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+def dinov2_vitb14(*, pretrained: bool = True, for_onnx: bool = False, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
     """
     DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
     """
+    if for_onnx:
+        return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, attn_class=Attention, **kwargs)
     return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
 
 
-def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+def dinov2_vitl14(*, pretrained: bool = True, for_onnx: bool = False, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
     """
     DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
     """
+    if for_onnx:
+        return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, attn_class=Attention, **kwargs)
     return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
 
 
-def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+def dinov2_vitg14(*, pretrained: bool = True, for_onnx: bool = False, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
     """
     DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
     """
+    if for_onnx:
+        return _make_dinov2_model(
+            arch_name="vit_giant2",
+            ffn_layer="swiglufused",
+            weights=weights,
+            pretrained=pretrained,
+            attn_class=Attention,
+            **kwargs,
+        )
     return _make_dinov2_model(
         arch_name="vit_giant2",
         ffn_layer="swiglufused",
         weights=weights,
         pretrained=pretrained,
         **kwargs,
     )
 
 
 def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):