pull/514/merge
Riadh Fezzani 2025-03-31 15:31:53 +02:00 committed by GitHub
commit 937a73bae7
4 changed files with 39 additions and 47 deletions


@@ -0,0 +1,22 @@
+import os
+import logging
+
+
+logger = logging.getLogger("dinov2")
+
+
+def _xformers_is_available(layer):
+    XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+    xformers = None
+    try:
+        if XFORMERS_ENABLED:
+            import xformers
+
+            logger.info(f"xFormers is available ({layer})")
+        else:
+            logger.warning(f"xFormers is disabled ({layer})")
+            raise ImportError
+    except ImportError:
+        logger.warning(f"xFormers is not available ({layer})")
+    return xformers is not None
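For context, a rough usage sketch (not part of this diff): each layer module is expected to call the helper once at import time, and setting the XFORMERS_DISABLED environment variable (any value; only its presence is checked) forces the plain-PyTorch path. Because availability is now reported through the "dinov2" logger rather than warnings.warn, the messages only show up when logging is configured. The dinov2.layers.attention module path below is an assumption used for illustration.

import logging
import os

# Opt out of xFormers before the layer modules are imported; only the
# variable's presence matters, not its value.
os.environ["XFORMERS_DISABLED"] = "1"

# The helper logs through the "dinov2" logger, so configure logging to see
# the "xFormers is available/disabled/not available (...)" messages.
logging.basicConfig(level=logging.INFO)

from dinov2.layers import attention  # assumed module path; importing runs the check

print(attention.XFORMERS_AVAILABLE)  # False while XFORMERS_DISABLED is set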


@@ -8,29 +8,19 @@
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
-import os
-import warnings
from torch import Tensor
from torch import nn
+from ._utils import _xformers_is_available
logger = logging.getLogger("dinov2")
+XFORMERS_AVAILABLE = _xformers_is_available("Attention")
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import memory_efficient_attention, unbind
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (Attention)")
-    else:
-        warnings.warn("xFormers is disabled (Attention)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-    warnings.warn("xFormers is not available (Attention)")
+if XFORMERS_AVAILABLE:
+    from xformers.ops import memory_efficient_attention, unbind
class Attention(nn.Module):
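The XFORMERS_AVAILABLE flag is consumed further down in this file by MemEffAttention (outside this hunk). Below is a minimal sketch of the gating pattern only, not the file's actual forward pass; the fallback uses torch.nn.functional.scaled_dot_product_attention and the helper name _attend is hypothetical.

import torch.nn.functional as F

def _attend(q, k, v):
    # q, k, v: (batch, seq_len, num_heads, head_dim), the layout memory_efficient_attention expects
    if XFORMERS_AVAILABLE:
        return memory_efficient_attention(q, k, v)  # fused xFormers kernel
    # Plain PyTorch fallback: scaled_dot_product_attention wants (batch, num_heads, seq_len, head_dim)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)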


@@ -8,9 +8,7 @@
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
-import os
from typing import Callable, List, Any, Tuple, Dict
-import warnings
import torch
from torch import nn, Tensor
@@ -19,25 +17,15 @@ from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
+from ._utils import _xformers_is_available
logger = logging.getLogger("dinov2")
+XFORMERS_AVAILABLE = _xformers_is_available("Block")
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add, index_select_cat
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (Block)")
-    else:
-        warnings.warn("xFormers is disabled (Block)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-    warnings.warn("xFormers is not available (Block)")
+if XFORMERS_AVAILABLE:
+    from xformers.ops import fmha, scaled_index_add, index_select_cat
class Block(nn.Module):


@@ -3,13 +3,16 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
-import os
from typing import Callable, Optional
-import warnings
from torch import Tensor, nn
import torch.nn.functional as F
+from ._utils import _xformers_is_available
+XFORMERS_AVAILABLE = _xformers_is_available("SwiGLU")
class SwiGLUFFN(nn.Module):
    def __init__(
@@ -34,21 +37,10 @@ class SwiGLUFFN(nn.Module):
        return self.w3(hidden)
-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import SwiGLU
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (SwiGLU)")
-    else:
-        warnings.warn("xFormers is disabled (SwiGLU)")
-        raise ImportError
-except ImportError:
+if XFORMERS_AVAILABLE:
+    from xformers.ops import SwiGLU
+else:
    SwiGLU = SwiGLUFFN
-    XFORMERS_AVAILABLE = False
-    warnings.warn("xFormers is not available (SwiGLU)")
class SwiGLUFFNFused(SwiGLU):
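With this change, SwiGLU resolves to the fused xformers.ops.SwiGLU when available and falls back to the pure-PyTorch SwiGLUFFN otherwise, so SwiGLUFFNFused (whose definition starts at the context line above) should behave the same on either path. A small usage sketch; the dinov2.layers.swiglu_ffn module path and the concrete sizes are assumptions for illustration.

import torch
from dinov2.layers.swiglu_ffn import SwiGLUFFNFused  # assumed module path

# Same constructor and call signature whether the fused xFormers op or the
# SwiGLUFFN fallback sits underneath.
ffn = SwiGLUFFNFused(in_features=384, hidden_features=1024)
out = ffn(torch.randn(2, 16, 384))
print(out.shape)  # torch.Size([2, 16, 384])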