From e9bbd227342b4104f9f4936f6df429c9daa663b0 Mon Sep 17 00:00:00 2001
From: cm090999
Date: Thu, 14 Mar 2024 10:18:25 +0100
Subject: [PATCH] removed unnecessary bug fix

---
 dinov2/layers/block.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/dinov2/layers/block.py b/dinov2/layers/block.py
index 0dab4ef..dd9011f 100644
--- a/dinov2/layers/block.py
+++ b/dinov2/layers/block.py
@@ -27,24 +27,25 @@ logger = logging.getLogger("dinov2")
 XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
     if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
+        # from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
+        from xformers.ops import fmha, scaled_index_add, index_select_cat
 
-        def scaled_index_add(input, index, source, scaling, alpha):
-            is_proper_embed_dim = input.shape[-1] % 256 == 0
-            is_float16 = input.dtype == torch.half
-            if is_proper_embed_dim and is_float16:
-                return _scaled_index_add(input, index, source, scaling, alpha)
-            else:
-                return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
+        # def scaled_index_add(input, index, source, scaling, alpha):
+        #     is_proper_embed_dim = input.shape[-1] % 256 == 0
+        #     is_float16 = input.dtype == torch.half
+        #     if is_proper_embed_dim and is_float16:
+        #         return _scaled_index_add(input, index, source, scaling, alpha)
+        #     else:
+        #         return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
 
 
-        def index_select_cat(sources, indices):
-            is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
-            is_float16 = all(s.dtype == torch.half for s in sources)
-            if is_proper_embed_dim and is_float16:
-                return _index_select_cat(sources, indices)
-            else:
-                return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
+        # def index_select_cat(sources, indices):
+        #     is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
+        #     is_float16 = all(s.dtype == torch.half for s in sources)
+        #     if is_proper_embed_dim and is_float16:
+        #         return _index_select_cat(sources, indices)
+        #     else:
+        #         return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
 
         XFORMERS_AVAILABLE = True
         warnings.warn("xFormers is available (Block)")
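
For context on what the commented-out wrappers did: they only routed calls to the fused xFormers kernels when the tensors were float16 and the last (embedding) dimension was a multiple of 256, and fell back to plain PyTorch ops otherwise. The sketch below restates that fallback logic as a small self-contained module; it is a minimal sketch based on the code removed in this patch, the fp16 and embed-dim constraints are inferred from the removed checks rather than from xFormers documentation, and the XFORMERS_AVAILABLE guard is added here only so the example also runs without xFormers installed.

# Minimal sketch of the fallback wrappers removed by this patch.
# Assumption (taken from the removed checks, not from xFormers docs): the fused
# kernels are only used for float16 tensors whose last dim is a multiple of 256.
import torch

try:
    from xformers.ops import scaled_index_add as _scaled_index_add
    from xformers.ops import index_select_cat as _index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False


def scaled_index_add(input, index, source, scaling, alpha):
    # Fused path: fp16 input with embed_dim % 256 == 0.
    if XFORMERS_AVAILABLE and input.shape[-1] % 256 == 0 and input.dtype == torch.half:
        return _scaled_index_add(input, index, source, scaling, alpha)
    # Fallback path: same result via a plain PyTorch index_add along dim 0.
    return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)


def index_select_cat(sources, indices):
    # Fused path: every source is fp16 with embed_dim % 256 == 0.
    if XFORMERS_AVAILABLE and all(s.shape[-1] % 256 == 0 and s.dtype == torch.half for s in sources):
        return _index_select_cat(sources, indices)
    # Fallback path: gather rows with plain indexing, flatten, and concatenate.
    return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)

With the patch applied, dinov2/layers/block.py imports xformers.ops.scaled_index_add and index_select_cat directly, so inputs that do not meet those constraints are no longer redirected to the torch fallbacks.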