Allow disabling xFormers via environment variable (#180)

Allow disabling the use of xFormers (for inference) by simply setting the XFORMERS_DISABLED environment variable. A usage sketch follows the changed-files summary below.
Patrick Labatut 2023-08-30 17:20:47 +02:00 committed by GitHub
parent be7e57252f
commit ebc1cba109
4 changed files with 41 additions and 15 deletions
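For reference, the variable must be set before any of the patched modules is imported, since the gate runs once at import time and only tests presence, never the value. A minimal usage sketch (the layer class and the dimensions here are illustrative; only the variable name and its presence-only semantics come from this commit):

import os

# Set before importing dinov2: the gate reads the environment exactly once,
# at module import time, and only tests presence, not the value.
os.environ["XFORMERS_DISABLED"] = "1"

import torch
from dinov2.layers import MemEffAttention  # now warns and skips xFormers

attn = MemEffAttention(dim=384, num_heads=6)
x = torch.randn(1, 197, 384)  # (batch, tokens, dim)
y = attn(x)                   # pure-PyTorch attention path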

dinov2/layers/attention.py

@@ -9,6 +9,8 @@
 # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

 import logging
+import os
+import warnings

 from torch import Tensor
 from torch import nn
@@ -17,13 +19,19 @@ from torch import nn
 logger = logging.getLogger("dinov2")


+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import memory_efficient_attention, unbind, fmha
+    if XFORMERS_ENABLED:
+        from xformers.ops import memory_efficient_attention, unbind

-    XFORMERS_AVAILABLE = True
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Attention)")
+    else:
+        warnings.warn("xFormers is disabled (Attention)")
+        raise ImportError
 except ImportError:
-    logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Attention)")


 class Attention(nn.Module):
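The gating idiom above is worth naming: compute the flag once, attempt the import only when enabled, and funnel the "explicitly disabled" case into the same except branch as "not installed" by raising ImportError by hand. A standalone sketch with a hypothetical dependency foo (nothing here is dinov2 code):

import os
import warnings

FOO_ENABLED = os.environ.get("FOO_DISABLED") is None  # presence check only

try:
    if FOO_ENABLED:
        import foo  # hypothetical optional dependency

        FOO_AVAILABLE = True
        warnings.warn("foo is available")
    else:
        raise ImportError  # route "disabled" through the "missing" handler
except ImportError:
    FOO_AVAILABLE = False
    warnings.warn("foo is not available")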
@@ -65,7 +73,8 @@ class Attention(nn.Module):
 class MemEffAttention(Attention):
     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
         if not XFORMERS_AVAILABLE:
-            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            if attn_bias is not None:
+                raise AssertionError("xFormers is required for using nested tensors")
             return super().forward(x)

         B, N, C = x.shape
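When xFormers is active, the body below this guard calls memory_efficient_attention(q, k, v, attn_bias=attn_bias); the fallback only reproduces the unbiased case, which is why a non-None attn_bias now raises instead of being silently dropped. The reference math the fallback computes, as a sketch (naive_attention and the shapes are illustrative; the fused kernel matches it up to numerics):

import torch

def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # softmax(q @ k^T / sqrt(head_dim)) @ v, materializing the full attention matrix.
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1)) * scale
    return attn.softmax(dim=-1) @ v

q = k = v = torch.randn(1, 8, 16, 32)  # (batch, heads, tokens, head_dim)
out = naive_attention(q, k, v)
print(out.shape)                       # torch.Size([1, 8, 16, 32])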

dinov2/layers/block.py

@@ -9,7 +9,9 @@
 # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

 import logging
+import os
 from typing import Callable, List, Any, Tuple, Dict
+import warnings

 import torch
 from torch import nn, Tensor
@@ -23,15 +25,21 @@ from .mlp import Mlp
 logger = logging.getLogger("dinov2")


+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import fmha
-    from xformers.ops import scaled_index_add, index_select_cat
+    if XFORMERS_ENABLED:
+        from xformers.ops import fmha, scaled_index_add, index_select_cat

-    XFORMERS_AVAILABLE = True
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Block)")
+    else:
+        warnings.warn("xFormers is disabled (Block)")
+        raise ImportError
 except ImportError:
-    logger.warning("xFormers not available")
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Block)")


 class Block(nn.Module):
     def __init__(
@@ -246,7 +254,8 @@ class NestedTensorBlock(Block):
         if isinstance(x_or_x_list, Tensor):
             return super().forward(x_or_x_list)
         elif isinstance(x_or_x_list, list):
-            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+            if not XFORMERS_AVAILABLE:
+                raise AssertionError("xFormers is required for using nested tensors")
             return self.forward_nested(x_or_x_list)
         else:
             raise AssertionError
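The list branch serves nested (variable-length) batches: sequences are concatenated and attention is confined to each original sequence through a block-diagonal bias, which the plain PyTorch path cannot express, hence the hard error. A hedged sketch of the xFormers primitive involved (a GPU build of xFormers is assumed; none of this is dinov2 code):

import torch
from xformers.ops import fmha

seqlens = [4, 7, 5]  # three sequences of different lengths
bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)

# Layout is (batch=1, total_tokens, heads, head_dim) for the fused kernel.
q = k = v = torch.randn(1, sum(seqlens), 8, 32, device="cuda", dtype=torch.float16)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=bias)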

dinov2/layers/swiglu_ffn.py

@@ -4,7 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+import os
 from typing import Callable, Optional
+import warnings

 from torch import Tensor, nn
 import torch.nn.functional as F
@@ -33,14 +35,22 @@ class SwiGLUFFN(nn.Module):
         return self.w3(hidden)


+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
-    from xformers.ops import SwiGLU
+    if XFORMERS_ENABLED:
+        from xformers.ops import SwiGLU

-    XFORMERS_AVAILABLE = True
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (SwiGLU)")
+    else:
+        warnings.warn("xFormers is disabled (SwiGLU)")
+        raise ImportError
 except ImportError:
     SwiGLU = SwiGLUFFN
     XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (SwiGLU)")


 class SwiGLUFFNFused(SwiGLU):
     def __init__(
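Note that the except branch also keeps SwiGLU = SwiGLUFFN, so SwiGLUFFNFused below always has a valid base class; the fused and unfused variants compute the same function, xFormers just fuses it. The math behind the return self.w3(hidden) line above, as a sketch (class and attribute names assumed, using a fused first projection):

import torch.nn.functional as F
from torch import Tensor, nn

class SwiGLUSketch(nn.Module):
    """Sketch of the SwiGLU feed-forward: silu-gated product, then projection."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w12 = nn.Linear(dim, 2 * hidden_dim)  # both input projections, fused
        self.w3 = nn.Linear(hidden_dim, dim)

    def forward(self, x: Tensor) -> Tensor:
        x1, x2 = self.w12(x).chunk(2, dim=-1)
        hidden = F.silu(x1) * x2  # the swish-gated linear unit
        return self.w3(hidden)    # matches the `return self.w3(hidden)` above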

dinov2/train/ssl_meta_arch.py

@@ -19,13 +19,11 @@ from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model
 from dinov2.models.vision_transformer import BlockChunk

 try:
     from xformers.ops import fmha
-
-    XFORMERS_AVAILABLE = True
 except ImportError:
-    XFORMERS_AVAILABLE = False
-assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training"
+    raise AssertionError("xFormers is required for training")


 logger = logging.getLogger("dinov2")
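Note the asymmetry this last file introduces: the layers honor XFORMERS_DISABLED, but the training stack never consults it and simply fails fast at import when xFormers is absent, so the variable is effectively inference-only. A sketch of the observable behavior (module paths as in this repository):

import os

os.environ["XFORMERS_DISABLED"] = "1"

# The inference layers honor the variable and fall back to plain PyTorch:
import dinov2.layers  # warns "xFormers is disabled (...)" per module

# The training entry point ignores the variable; it imports xformers.ops
# directly and raises AssertionError("xFormers is required for training")
# only when the package is actually missing.
import dinov2.train.ssl_meta_arch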