Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Update sgdw for older pytorch

commit 711c5dee6d (parent 60b170b200)
@@ -1,6 +1,13 @@
+from functools import update_wrapper, wraps
 import torch
 from torch import Tensor
-from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from torch.optim.optimizer import Optimizer
+try:
+    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
+    has_recent_pt = True
+except ImportError:
+    has_recent_pt = False
+
 from typing import List, Optional
 
 __all__ = ['SGDW', 'sgdw']
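The pattern here is a try/except import guard: `_use_grad_for_differentiable` and `_default_to_fused_or_foreach` are private helpers that only newer PyTorch releases export, so the import is attempted once and the outcome recorded in `has_recent_pt` for the branches further down. A minimal sketch of the same pattern in isolation, where the consumer function `pick_foreach` is a hypothetical illustration rather than part of the commit:

import torch
from torch.optim.optimizer import Optimizer

try:
    # Private helpers that only recent PyTorch releases provide.
    from torch.optim.optimizer import _use_grad_for_differentiable, _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False


def pick_foreach(params):
    # Hypothetical consumer: defer to the upstream heuristic when available,
    # otherwise fall back to the single-tensor default on older PyTorch.
    if has_recent_pt:
        _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
        return foreach
    return False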
@@ -62,7 +69,9 @@ class SGDW(Optimizer):
 
         return has_sparse_grad
 
-    @_use_grad_for_differentiable
+    # FIXME figure out how to make _use_grad_for_differentiable interchangeable with no_grad decorator
+    # without args, for backwards compatibility with old pytorch
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
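Swapping `@_use_grad_for_differentiable` for `@torch.no_grad()` keeps `step()` working on PyTorch versions that lack the private decorator, at the cost of ignoring the `differentiable` option. One way to make the two interchangeable, as the FIXME hints, is to pick the decorator at import time; a sketch under that assumption (the `_step_guard` name is illustrative, not what the commit does — the commit simply hardcodes `@torch.no_grad()`):

import torch
from torch.optim.optimizer import Optimizer

try:
    from torch.optim.optimizer import _use_grad_for_differentiable
    _step_guard = _use_grad_for_differentiable  # honors the `differentiable` default
except ImportError:
    _step_guard = torch.no_grad()  # no_grad instances also work as decorators


class SGDW(Optimizer):
    def __init__(self, params, lr=1e-3):
        # `differentiable` must be present in defaults for the recent-PyTorch decorator.
        super().__init__(params, dict(lr=lr, differentiable=False))

    @_step_guard
    def step(self, closure=None):
        # update logic elided; the decorator only controls grad mode around it
        return None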
@@ -124,7 +133,7 @@ def sgdw(
 
     See :class:`~torch.optim.SGD` for details.
     """
-
-    if foreach is None:
-        # why must we be explicit about an if statement for torch.jit.is_scripting here?
-        # because JIT can't handle Optionals nor fancy conditionals when scripting
+    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
+        if foreach is None:
+            # why must we be explicit about an if statement for torch.jit.is_scripting here?
+            # because JIT can't handle Optionals nor fancy conditionals when scripting
@@ -135,6 +144,8 @@ def sgdw(
 
-    if foreach and torch.jit.is_scripting():
-        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+        if foreach and torch.jit.is_scripting():
+            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+    else:
+        foreach = False  # disabling altogether for older pytorch, as using _group_tensors_by_device_and_dtype
 
     if foreach and not torch.jit.is_scripting():
         func = _multi_tensor_sgdw
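Taken together, the last two hunks wrap the existing foreach auto-selection in the `has_recent_pt` guard: on recent PyTorch the old behavior is preserved, while on older releases `foreach` is forced off so only the single-tensor code path runs. A simplified control-flow sketch of the patched dispatch, with the update functions stubbed out and the upstream heuristic reduced to a plain default:

import torch
from torch.optim.optimizer import Optimizer

try:
    from torch.optim.optimizer import _default_to_fused_or_foreach
    has_recent_pt = True
except ImportError:
    has_recent_pt = False


def _single_tensor_sgdw(params):
    pass  # stand-in for the real per-tensor update loop


def _multi_tensor_sgdw(params):
    pass  # stand-in for the real foreach update path


def sgdw(params, foreach=None):
    if has_recent_pt and hasattr(Optimizer, '_group_tensors_by_device_and_dtype'):
        if foreach is None:
            # the real code consults _default_to_fused_or_foreach here;
            # reduced to a plain default for this sketch
            foreach = False
        if foreach and torch.jit.is_scripting():
            raise RuntimeError('torch.jit.script not supported with foreach optimizers')
    else:
        foreach = False  # older pytorch: grouping helper missing, force single-tensor path

    func = _multi_tensor_sgdw if (foreach and not torch.jit.is_scripting()) else _single_tensor_sgdw
    func(params)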