To be technically correct, need to check the in-place _ ver of op

This commit is contained in:
Ross Wightman 2024-11-28 13:46:17 -08:00
parent b0a121bed0
commit 9b27f84876
2 changed files with 2 additions and 2 deletions

View File

@ -143,7 +143,7 @@ def lion(
if foreach is None: if foreach is None:
try: try:
# cannot do foreach if this overload doesn't exist when caution enabled # cannot do foreach if this overload doesn't exist when caution enabled
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
except: except:
foreach = False foreach = False

View File

@ -171,7 +171,7 @@ def nadamw(
if foreach is None: if foreach is None:
try: try:
# cannot do foreach if this overload doesn't exist when caution enabled # cannot do foreach if this overload doesn't exist when caution enabled
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
except: except:
foreach = False foreach = False