diff --git a/timm/optim/lion.py b/timm/optim/lion.py index 980a0713..18607232 100644 --- a/timm/optim/lion.py +++ b/timm/optim/lion.py @@ -143,7 +143,7 @@ def lion( if foreach is None: try: # cannot do foreach if this overload doesn't exist when caution enabled - foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() + foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads() except: foreach = False diff --git a/timm/optim/nadamw.py b/timm/optim/nadamw.py index 17eb6fd0..d9933026 100644 --- a/timm/optim/nadamw.py +++ b/timm/optim/nadamw.py @@ -171,7 +171,7 @@ def nadamw( if foreach is None: try: # cannot do foreach if this overload doesn't exist when caution enabled - foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() + foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads() except: foreach = False