diff --git a/dinov2/layers/_utils.py b/dinov2/layers/_utils.py
new file mode 100644
index 0000000..00a8639
--- /dev/null
+++ b/dinov2/layers/_utils.py
@@ -0,0 +1,22 @@
+import os
+import logging
+
+logger = logging.getLogger("dinov2")
+
+
+def _xformers_is_available(layer):
+
+    XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+    xformers = None
+    try:
+        if XFORMERS_ENABLED:
+            import xformers
+
+            logger.info(f"xFormers is available ({layer})")
+        else:
+            logger.warning(f"xFormers is disabled ({layer})")
+            raise ImportError
+    except ImportError:
+        logger.warning(f"xFormers is not available ({layer})")
+
+    return xformers is not None
diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py
index 0fb76ef..2023d63 100644
--- a/dinov2/layers/attention.py
+++ b/dinov2/layers/attention.py
@@ -8,29 +8,19 @@
 #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
 
 import logging
-import os
-import warnings
 
 from torch import Tensor
 from torch import nn
 
+from ._utils import _xformers_is_available
+
 
 logger = logging.getLogger("dinov2")
 
+XFORMERS_AVAILABLE = _xformers_is_available("Attention")
 
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import memory_efficient_attention, unbind
-
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (Attention)")
-    else:
-        warnings.warn("xFormers is disabled (Attention)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-    warnings.warn("xFormers is not available (Attention)")
+if XFORMERS_AVAILABLE:
+    from xformers.ops import memory_efficient_attention, unbind
 
 
 class Attention(nn.Module):
diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py
index 930787b..ca6ca98 100644
--- a/dinov2/layers/block.py
+++ b/dinov2/layers/block.py
@@ -8,9 +8,7 @@
 #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
 
 import logging
-import os
 from typing import Callable, List, Any, Tuple, Dict
-import warnings
 
 import torch
 from torch import nn, Tensor
@@ -19,25 +17,15 @@ from .attention import Attention, MemEffAttention
 from .drop_path import DropPath
 from .layer_scale import LayerScale
 from .mlp import Mlp
+from ._utils import _xformers_is_available
 
 
 logger = logging.getLogger("dinov2")
 
+XFORMERS_AVAILABLE = _xformers_is_available("Block")
 
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add, index_select_cat
-
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (Block)")
-    else:
-        warnings.warn("xFormers is disabled (Block)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-
-    warnings.warn("xFormers is not available (Block)")
+if XFORMERS_AVAILABLE:
+    from xformers.ops import fmha, scaled_index_add, index_select_cat
 
 
 class Block(nn.Module):
diff --git a/dinov2/layers/swiglu_ffn.py b/dinov2/layers/swiglu_ffn.py
index 5e9dafa..c24365d 100644
--- a/dinov2/layers/swiglu_ffn.py
+++ b/dinov2/layers/swiglu_ffn.py
@@ -3,13 +3,16 @@
 # This source code is licensed under the Apache License, Version 2.0
 # found in the LICENSE file in the root directory of this source tree.
 
-import os
 from typing import Callable, Optional
-import warnings
 
 from torch import Tensor, nn
 import torch.nn.functional as F
 
+from ._utils import _xformers_is_available
+
+
+XFORMERS_AVAILABLE = _xformers_is_available("SwiGLU")
+
 
 class SwiGLUFFN(nn.Module):
     def __init__(
@@ -34,21 +37,10 @@ class SwiGLUFFN(nn.Module):
         return self.w3(hidden)
 
 
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import SwiGLU
-
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (SwiGLU)")
-    else:
-        warnings.warn("xFormers is disabled (SwiGLU)")
-        raise ImportError
-except ImportError:
+if XFORMERS_AVAILABLE:
+    from xformers.ops import SwiGLU
+else:
     SwiGLU = SwiGLUFFN
-    XFORMERS_AVAILABLE = False
-
-    warnings.warn("xFormers is not available (SwiGLU)")
 
 
 class SwiGLUFFNFused(SwiGLU):
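A quick way to exercise the new centralized check (a minimal sketch, not part of the patch; it assumes torch is installed and the patched dinov2 checkout is on PYTHONPATH, with or without xFormers present):

    import logging
    import os

    # Surface the "dinov2" logger messages that replace the old warnings.warn calls.
    logging.basicConfig(level=logging.INFO)

    # The environment variable is read at import time, so it must be set first.
    os.environ["XFORMERS_DISABLED"] = "1"

    from dinov2.layers import attention, block, swiglu_ffn

    # With XFORMERS_DISABLED set, _xformers_is_available() logs
    # "xFormers is disabled (...)" through the "dinov2" logger and returns False,
    # so all three module-level flags are False.
    print(attention.XFORMERS_AVAILABLE, block.XFORMERS_AVAILABLE, swiglu_ffn.XFORMERS_AVAILABLE)

When the flags are False, the layers fall back to their plain PyTorch implementations exactly as before; the only behavioral change in this patch is that xFormers availability is detected in one place and reported through the "dinov2" logger instead of via warnings.warn at import time.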