Update sgdw for older pytorch

pull/2052/head
Ross Wightman 2023-12-11 10:46:04 -08:00 committed by Ross Wightman
parent 60b170b200
commit 711c5dee6d
1 changed file with 23 additions and 12 deletions


@@ -1,6 +1,13 @@
 from functools import update_wrapper, wraps
 import torch
 from torch import Tensor
-from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from torch.optim.optimizer import Optimizer
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
+    has_recent_pt = True
+except ImportError:
+    has_recent_pt = False
+
 from typing import List, Optional
 
 __all__ = ['SGDW', 'sgdw']
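
The guarded import is the core of the fix: _use_grad_for_differentiable and _default_to_fused_or_foreach are private helpers that only exist in recent PyTorch, so their availability is recorded in has_recent_pt instead of letting the module fail at import time on older installs. Below is a minimal standalone sketch of the same pattern, not taken from the commit; the helper resolve_foreach is hypothetical and only mirrors how the flag is used later in sgdw().

    import torch
    from torch import Tensor
    from typing import List, Optional

    try:
        # private helper, present only in recent PyTorch releases
        from torch.optim.optimizer import _default_to_fused_or_foreach
        has_recent_pt = True
    except ImportError:
        has_recent_pt = False

    def resolve_foreach(params: List[Tensor], foreach: Optional[bool] = None) -> bool:
        # Hypothetical helper mirroring the patched sgdw(): defer to PyTorch's
        # heuristic when the private API exists, otherwise fall back to False.
        if foreach is None:
            if has_recent_pt and not torch.jit.is_scripting():
                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
            else:
                foreach = False
        return bool(foreach)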
@@ -62,7 +69,9 @@ class SGDW(Optimizer):
         return has_sparse_grad
 
-    @_use_grad_for_differentiable
+    # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
+    #  without args, for backwards compatibility with old pytorch
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -124,7 +133,7 @@ def sgdw(
     See :class:`~torch.optim.SGD` for details.
     """
+    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
         if foreach is None:
             # why must we be explicit about an if statement for torch.jit.is_scripting here?
             # because JIT can't handle Optionals nor fancy conditionals when scripting
@@ -135,6 +144,8 @@ def sgdw(
         if foreach and torch.jit.is_scripting():
             raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    else:
+        foreach = False  # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype
 
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_sgdw
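
With the new else branch, an older PyTorch install simply gets foreach = False and takes the single-tensor path instead of relying on Optimizer._group_tensors_by_device_and_dtype. A rough usage sketch, assuming the class is importable as timm.optim.sgdw.SGDW (the model and data here are placeholders, not from the commit):

    import torch
    import torch.nn as nn
    from timm.optim.sgdw import SGDW

    model = nn.Linear(10, 2)
    opt = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

    x, y = torch.randn(8, 10), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()        # runs under @torch.no_grad() after this change
    opt.zero_grad()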