To be technically correct, need to check the in-place _ ver of op

This commit is contained in:
Ross Wightman 2024-11-28 13:46:17 -08:00
parent b0a121bed0
commit 9b27f84876
2 changed files with 2 additions and 2 deletions

View File

@ -143,7 +143,7 @@ def lion(
if foreach is None:
try:
# cannot do foreach if this overload doesn't exist when caution enabled
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads()
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
except:
foreach = False

View File

@ -171,7 +171,7 @@ def nadamw(
if foreach is None:
try:
# cannot do foreach if this overload doesn't exist when caution enabled
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads()
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads()
except:
foreach = False