From 711c5dee6db9fa98ed0e69abb498a43e5348a042 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 11 Dec 2023 10:46:04 -0800
Subject: [PATCH] Update sgdw for older pytorch

---
 timm/optim/sgdw.py | 35 +++++++++++++++++++++++------------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/timm/optim/sgdw.py b/timm/optim/sgdw.py
index 1d95bd4e..b3d2c12f 100644
--- a/timm/optim/sgdw.py
+++ b/timm/optim/sgdw.py
@@ -1,6 +1,13 @@
+from functools import update_wrapper, wraps
 import torch
 from torch import Tensor
-from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from torch.optim.optimizer import Optimizer
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
+    has_recent_pt = True
+except ImportError:
+    has_recent_pt = False
+
 from typing import List, Optional
 
 __all__ = ['SGDW', 'sgdw']
@@ -62,7 +69,9 @@ class SGDW(Optimizer):
 
         return has_sparse_grad
 
-    @_use_grad_for_differentiable
+    # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
+    # without args, for backwards compatibility with old pytorch
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -124,17 +133,19 @@ def sgdw(
 
     See :class:`~torch.optim.SGD` for details.
     """
+    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
+        if foreach is None:
+            # why must we be explicit about an if statement for torch.jit.is_scripting here?
+            # because JIT can't handle Optionals nor fancy conditionals when scripting
+            if not torch.jit.is_scripting():
+                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
+            else:
+                foreach = False
-    if foreach is None:
-        # why must we be explicit about an if statement for torch.jit.is_scripting here?
-        # because JIT can't handle Optionals nor fancy conditionals when scripting
-        if not torch.jit.is_scripting():
-            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
-        else:
-            foreach = False
-
-    if foreach and torch.jit.is_scripting():
-        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+        if foreach and torch.jit.is_scripting():
+            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    else:
+        foreach = False  # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype
 
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_sgdw
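
Note (not part of the patch): one possible way to resolve the FIXME above is to choose the step() decorator once at import time, so recent PyTorch keeps the differentiable-aware guard while older releases fall back to a plain no-grad decorator. This is a minimal sketch, not the commit's approach; the `_step_guard` name is purely illustrative.

    import torch

    try:
        # Recent PyTorch: decorator that honours the optimizer's 'differentiable' flag
        from torch.optim.optimizer import _use_grad_for_differentiable
        _step_guard = _use_grad_for_differentiable
    except ImportError:
        # Older PyTorch: a torch.no_grad() instance is usable as an argument-less decorator
        _step_guard = torch.no_grad()

Either object can then be applied uniformly as `@_step_guard` above `def step(self, closure=None):`, avoiding the unconditional `@torch.no_grad()` used in the patch.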