diff --git a/timm/optim/adopt.py b/timm/optim/adopt.py
index a40b3010..9647aa9a 100644
--- a/timm/optim/adopt.py
+++ b/timm/optim/adopt.py
@@ -38,9 +38,13 @@ def _get_scalar_dtype(is_fused=None):
     )
 
 
+def _is_compiling():
+    return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False
+
+
 def _get_value(x):
     # item is significantly faster than a cpu tensor in eager mode
-    if not torch.jit.is_scripting() and torch.compiler.is_compiling():
+    if not torch.jit.is_scripting() and _is_compiling():
         return x
     else:
         return x.item() if isinstance(x, torch.Tensor) else x
@@ -271,7 +275,7 @@ def _single_tensor_adopt(
         step_t = state_steps[i]
 
         # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-        if not torch._utils.is_compiling() and capturable:
+        if capturable and not _is_compiling():
             from torch.optim.optimizer import _get_capturable_supported_devices
             capturable_supported_devices = _get_capturable_supported_devices()
             assert (
@@ -340,7 +344,7 @@ def _multi_tensor_adopt(
         )
     # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
-    if not torch._utils.is_compiling() and capturable:
+    if capturable and not _is_compiling():
         from torch.optim.optimizer import _get_capturable_supported_devices
         capturable_supported_devices = _get_capturable_supported_devices(
             supports_xla=False
         )
@@ -384,7 +388,7 @@ def _multi_tensor_adopt(
         # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
-        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
+        if not _is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
@@ -457,9 +461,7 @@ def adopt(
 
     # this check is slow during compilation, so we skip it
     # if it's strictly needed we can add this check back in dynamo
-    if not torch._utils.is_compiling() and not all(
-        isinstance(t, torch.Tensor) for t in state_steps
-    ):
+    if not _is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )
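
Note (not part of the patch): the diff swaps direct torch._utils.is_compiling() / torch.compiler.is_compiling() calls for a small shim, so PyTorch builds that predate the torch.compiler namespace fall back to "not compiling" instead of raising AttributeError. A minimal, self-contained sketch of that pattern follows; the helper mirrors the one added above, while maybe_item is a hypothetical call site used only for illustration:

    import torch


    def _is_compiling():
        # Older PyTorch builds have no torch.compiler namespace; treat
        # them as "not compiling" rather than raising AttributeError.
        return torch.compiler.is_compiling() if hasattr(torch, 'compiler') else False


    def maybe_item(x):
        # Hypothetical call site: .item() is fast in eager mode but should
        # be avoided while torch.compile is tracing the graph.
        if _is_compiling():
            return x
        return x.item() if isinstance(x, torch.Tensor) else x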