Fix ADOPT on older PyTorch (tested back to 1.13)

Ross Wightman 2024-11-08 09:10:36 -08:00 committed by Ross Wightman
parent 79abc25f55
commit ff136b8d3a
1 changed file with 9 additions and 7 deletions

@@ -38,9 +38,13 @@ def _get_scalar_dtype(is_fused=None):
     )


+def _is_compiling():
+    return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False
+
+
 def _get_value(x):
     # item is significantly faster than a cpu tensor in eager mode
-    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
+    if not torch.jit.is_scripting() and _is_compiling():
         return x
     else:
         return x.item() if isinstance(x, torch.Tensor) else x
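The new `_is_compiling()` helper is the heart of the fix: `torch.compiler.is_compiling()` only exists on recent PyTorch, and the `hasattr` guard makes the check degrade to `False` on older releases. A minimal standalone sketch of the same pattern (the helper body is taken from the diff; the surrounding script is illustrative):

import torch

def _is_compiling():
    # torch.compiler only exists on newer PyTorch; on older releases
    # (e.g. 1.13) the hasattr() guard short-circuits and we report False
    return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False

print(_is_compiling())  # False in eager mode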
@@ -271,7 +275,7 @@ def _single_tensor_adopt(
         step_t = state_steps[i]

         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if capturable and not _is_compiling():
             from torch.optim.optimizer import _get_capturable_supported_devices
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
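Note the reordered condition: with `capturable` first, the default `capturable=False` path short-circuits before the compile check runs and never reaches the version-dependent import. A hedged sketch of that short-circuit (the flag value is illustrative; the shim is repeated for self-containment):

import torch

def _is_compiling():  # same shim as in the diff above
    return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False

capturable = False  # the common default

if capturable and not _is_compiling():
    # only reached on capturable runs; _get_capturable_supported_devices is a
    # private helper that older PyTorch releases may not provide
    from torch.optim.optimizer import _get_capturable_supported_devices
    capturable_supported_devices = _get_capturable_supported_devices()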
@@ -340,7 +344,7 @@ def _multi_tensor_adopt(
     )

     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if capturable and not _is_compiling():
         from torch.optim.optimizer import _get_capturable_supported_devices
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
@@ -384,7 +388,7 @@ def _multi_tensor_adopt(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not _is_compiling() and device_state_steps[0].is_cpu:
             torch._foreach_add_(
                 device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
             )
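The comment above describes the CPU fast path: wrap the scalar 1 into a tensor once and pass it to `_foreach_add_` with an explicit `alpha` to hit the Tensor overload, rather than letting foreach re-wrap 1 for every element. A small sketch of that call in isolation (the list contents are illustrative):

import torch

device_state_steps = [torch.tensor(0.0) for _ in range(4)]  # per-param step counters
torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
print([int(s) for s in device_state_steps])  # [1, 1, 1, 1]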
@@ -457,9 +461,7 @@ def adopt(

     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch._utils.is_compiling() and not all(
-        isinstance(t, torch.Tensor) for t in state_steps
-    ):
+    if not _is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )
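The rewritten guard still enforces the newer API contract: every entry of `state_steps` must be a singleton tensor, not a Python int. A hedged sketch of what passes the check and what would trip the `RuntimeError` (values are illustrative):

import torch

state_steps = [torch.tensor(0.0), torch.tensor(0.0)]  # accepted: singleton tensors
assert all(isinstance(t, torch.Tensor) for t in state_steps)

bad_steps = [0, 0]  # plain ints would raise the RuntimeError above
assert not all(isinstance(t, torch.Tensor) for t in bad_steps)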