Update sgdw for older pytorch

parent 60b170b200
commit 711c5dee6d

@@ -1,6 +1,13 @@
+from functools import update_wrapper, wraps
 import torch
 from torch import Tensor
-from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from torch.optim.optimizer import Optimizer
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
+    has_recent_pt = True
+except ImportError:
+    has_recent_pt = False
+
 from typing import List, Optional
 
 __all__ = ['SGDW', 'sgdw']
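
The import change above follows a guarded-import pattern: try to pull in the private optimizer helpers that only newer PyTorch exposes, and record the outcome in a module-level flag that later code branches on. A minimal standalone sketch of the same pattern, using only names visible in this diff (the can_use_new_path variable and the final print are illustrative, not part of the commit):

import torch
from torch.optim.optimizer import Optimizer

try:
    # Private helpers that only exist in newer PyTorch releases.
    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False

# Call sites combine the flag with a hasattr() probe for a specific private
# method before taking the newer code path (see the sgdw() hunks below).
can_use_new_path = has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype')
print(f'newer optimizer internals available: {can_use_new_path}')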

@@ -62,7 +69,9 @@ class SGDW(Optimizer):
         return has_sparse_grad
 
-    @_use_grad_for_differentiable
+    # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
+    # without args, for backwards compatibility with old pytorch
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
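
The FIXME above flags the compatibility gap this commit works around by decorating step() with torch.no_grad() unconditionally: on older PyTorch _use_grad_for_differentiable may not exist at all, and the two decorators are applied differently (bare vs instantiated), so they cannot simply be swapped on the decorator line. One possible direction, not taken in this commit, is to pick the decorator once at import time based on the probe from the first hunk; the sketch below is hypothetical and the _step_decorator name is invented:

import torch

try:
    from torch.optim.optimizer import _use_grad_for_differentiable
    has_recent_pt = True
except ImportError:
    has_recent_pt = False

if has_recent_pt:
    # Newer PyTorch: keep the decorator that honours defaults['differentiable'].
    _step_decorator = _use_grad_for_differentiable
else:
    # Older PyTorch: fall back to a plain no-grad wrapper with the same shape.
    def _step_decorator(func):
        return torch.no_grad()(func)

SGDW.step could then be written with @_step_decorator instead of hard-coding either form.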

@@ -124,7 +133,7 @@ def sgdw(
     See :class:`~torch.optim.SGD` for details.
     """
 
+    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
         if foreach is None:
             # why must we be explicit about an if statement for torch.jit.is_scripting here?
             # because JIT can't handle Optionals nor fancy conditionals when scripting

@@ -135,6 +144,8 @@ def sgdw(
         if foreach and torch.jit.is_scripting():
             raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    else:
+        foreach = False  # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype
 
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_sgdw
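
Taken together, the two sgdw() hunks gate the foreach machinery: on newer PyTorch the default is resolved with _default_to_fused_or_foreach (unless TorchScript is scripting), while on older PyTorch foreach is forced off and the single-tensor path is used. A hedged sketch of that selection logic, assuming the upstream torch.optim.SGD call form for _default_to_fused_or_foreach; the _resolve_foreach helper is an invented name and not part of the commit:

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer
from typing import List, Optional

try:
    from torch.optim.optimizer import _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False


def _resolve_foreach(params: List[Tensor], foreach: Optional[bool] = None) -> bool:
    # Only newer PyTorch has the tensor-grouping helper the foreach path relies on.
    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
        if foreach is None:
            if not torch.jit.is_scripting():
                # Default choice mirrors upstream torch.optim.SGD (assumed call form).
                _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
            else:
                foreach = False
        if foreach and torch.jit.is_scripting():
            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
    else:
        # Older PyTorch: disable foreach altogether, as in this commit.
        foreach = False
    return bool(foreach) and not torch.jit.is_scripting()

The boolean result would then select between the multi-tensor implementation (_multi_tensor_sgdw above) and the single-tensor fallback.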