removed unnnecessairy bug fix

pull/393/head
cm090999 2024-03-14 10:18:25 +01:00
parent afc9d19fbe
commit e9bbd22734
1 changed files with 16 additions and 15 deletions

View File

@ -27,24 +27,25 @@ logger = logging.getLogger("dinov2")
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try: try:
if XFORMERS_ENABLED: if XFORMERS_ENABLED:
from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat # from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
from xformers.ops import fmha, scaled_index_add, index_select_cat
def scaled_index_add(input, index, source, scaling, alpha): # def scaled_index_add(input, index, source, scaling, alpha):
is_proper_embed_dim = input.shape[-1] % 256 == 0 # is_proper_embed_dim = input.shape[-1] % 256 == 0
is_float16 = input.dtype == torch.half # is_float16 = input.dtype == torch.half
if is_proper_embed_dim and is_float16: # if is_proper_embed_dim and is_float16:
return _scaled_index_add(input, index, source, scaling, alpha) # return _scaled_index_add(input, index, source, scaling, alpha)
else: # else:
return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha) # return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
def index_select_cat(sources, indices): # def index_select_cat(sources, indices):
is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources) # is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
is_float16 = all(s.dtype == torch.half for s in sources) # is_float16 = all(s.dtype == torch.half for s in sources)
if is_proper_embed_dim and is_float16: # if is_proper_embed_dim and is_float16:
return _index_select_cat(sources, indices) # return _index_select_cat(sources, indices)
else: # else:
return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0) # return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
XFORMERS_AVAILABLE = True XFORMERS_AVAILABLE = True
warnings.warn("xFormers is available (Block)") warnings.warn("xFormers is available (Block)")