Fix ADOPT on older PyTorch (tested back to 1.13)
parent
79abc25f55
commit
ff136b8d3a
|
@ -38,9 +38,13 @@ def _get_scalar_dtype(is_fused=None):
|
|||
)
|
||||
|
||||
|
||||
def _is_compiling():
|
||||
return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False
|
||||
|
||||
|
||||
def _get_value(x):
|
||||
# item is significantly faster than a cpu tensor in eager mode
|
||||
if not torch.jit.is_scripting() and torch.compiler.is_compiling():
|
||||
if not torch.jit.is_scripting() and _is_compiling():
|
||||
return x
|
||||
else:
|
||||
return x.item() if isinstance(x, torch.Tensor) else x
|
||||
|
@ -271,7 +275,7 @@ def _single_tensor_adopt(
|
|||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
if capturable and not _is_compiling():
|
||||
from torch.optim.optimizer import _get_capturable_supported_devices
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
|
@ -340,7 +344,7 @@ def _multi_tensor_adopt(
|
|||
)
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
if capturable and not _is_compiling():
|
||||
from torch.optim.optimizer import _get_capturable_supported_devices
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
|
@ -384,7 +388,7 @@ def _multi_tensor_adopt(
|
|||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
if not _is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
|
@ -457,9 +461,7 @@ def adopt(
|
|||
|
||||
# this check is slow during compilation, so we skip it
|
||||
# if it's strictly needed we can add this check back in dynamo
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
if not _is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
|
||||
raise RuntimeError(
|
||||
"API has changed, `state_steps` argument must contain a list of singleton tensors"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue