removed unnecessary bug fix
parent afc9d19fbe
commit e9bbd22734
@@ -27,24 +27,25 @@ logger = logging.getLogger("dinov2")
 XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
 try:
     if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
+        # from xformers.ops import fmha, scaled_index_add as _scaled_index_add, index_select_cat as _index_select_cat
+        from xformers.ops import fmha, scaled_index_add, index_select_cat

-        def scaled_index_add(input, index, source, scaling, alpha):
-            is_proper_embed_dim = input.shape[-1] % 256 == 0
-            is_float16 = input.dtype == torch.half
-            if is_proper_embed_dim and is_float16:
-                return _scaled_index_add(input, index, source, scaling, alpha)
-            else:
-                return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
+        # def scaled_index_add(input, index, source, scaling, alpha):
+        #     is_proper_embed_dim = input.shape[-1] % 256 == 0
+        #     is_float16 = input.dtype == torch.half
+        #     if is_proper_embed_dim and is_float16:
+        #         return _scaled_index_add(input, index, source, scaling, alpha)
+        #     else:
+        #         return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)


-        def index_select_cat(sources, indices):
-            is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
-            is_float16 = all(s.dtype == torch.half for s in sources)
-            if is_proper_embed_dim and is_float16:
-                return _index_select_cat(sources, indices)
-            else:
-                return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
+        # def index_select_cat(sources, indices):
+        #     is_proper_embed_dim = all(s.shape[-1] % 256 == 0 for s in sources)
+        #     is_float16 = all(s.dtype == torch.half for s in sources)
+        #     if is_proper_embed_dim and is_float16:
+        #         return _index_select_cat(sources, indices)
+        #     else:
+        #         return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)

         XFORMERS_AVAILABLE = True
         warnings.warn("xFormers is available (Block)")
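For reference, the wrappers deleted in this diff only routed to the fused xFormers ops when the inputs were float16 with a last dimension divisible by 256, and otherwise fell back to equivalent plain-PyTorch expressions. Below is a minimal self-contained sketch of that dispatch logic, reconstructed from the removed lines; the _HAVE_XFORMERS flag and the guarded import are additions for the sketch, not part of block.py.

import torch

# Guarded import so the sketch also runs without xFormers installed
# (in block.py the import sits inside the surrounding try/except instead).
try:
    from xformers.ops import scaled_index_add as _scaled_index_add
    from xformers.ops import index_select_cat as _index_select_cat
    _HAVE_XFORMERS = True
except ImportError:
    _HAVE_XFORMERS = False


def scaled_index_add(input, index, source, scaling, alpha):
    # Fused xFormers path, taken only for float16 tensors whose embedding
    # dimension is a multiple of 256 (the checks from the removed wrapper).
    if _HAVE_XFORMERS and input.shape[-1] % 256 == 0 and input.dtype == torch.half:
        return _scaled_index_add(input, index, source, scaling, alpha)
    # Plain-PyTorch fallback: add alpha * (scaling * source) into the rows
    # of `input` selected by `index` along dim 0.
    return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)


def index_select_cat(sources, indices):
    # Fused xFormers path, with the same dtype / embedding-dim guard applied
    # to every source tensor.
    if _HAVE_XFORMERS and all(s.shape[-1] % 256 == 0 and s.dtype == torch.half for s in sources):
        return _index_select_cat(sources, indices)
    # Plain-PyTorch fallback: gather the indexed rows of each source, flatten
    # them, and concatenate the results.
    return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)

With the wrappers removed, the new import exposes xformers.ops.scaled_index_add and xformers.ops.index_select_cat directly, so callers in block.py use the fused kernels without this float16 / embedding-dimension fallback check.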