From 17a47c0e351e1ca432c9a967fcd4f30b02e02118 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 11 Dec 2023 09:06:55 -0800
Subject: [PATCH] Add SGDW optimizer

---
 timm/optim/optim_factory.py |   8 ++
 timm/optim/sgdw.py          | 255 ++++++++++++++++++++++++++++++++++++
 2 files changed, 263 insertions(+)
 create mode 100644 timm/optim/sgdw.py

diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index af316dee..8187b55a 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -27,6 +27,7 @@ from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
+from .sgdw import SGDW
 
 _logger = logging.getLogger(__name__)
 
@@ -288,6 +289,13 @@ def create_optimizer_v2(
         optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
     elif opt_lower == 'sgdp':
         optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
+    elif opt_lower == 'sgdw' or opt_lower == 'nesterovw':
+        # NOTE 'sgdw' defaults to SGDW + nesterov momentum, mirroring the legacy 'sgd' / 'nesterov' aliases
+        opt_args.pop('eps', None)
+        optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args)
+    elif opt_lower == 'momentumw':
+        opt_args.pop('eps', None)
+        optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args)
 
     # adaptive
     elif opt_lower == 'adam':
diff --git a/timm/optim/sgdw.py b/timm/optim/sgdw.py
new file mode 100644
index 00000000..1d95bd4e
--- /dev/null
+++ b/timm/optim/sgdw.py
@@ -0,0 +1,255 @@
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
+from typing import List, Optional
+
+__all__ = ['SGDW', 'sgdw']
+
+
+class SGDW(Optimizer):
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            momentum=0,
+            dampening=0,
+            weight_decay=0,
+            nesterov=False,
+            *,
+            maximize: bool = False,
+            foreach: Optional[bool] = None,
+            differentiable: bool = False,
+    ):
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if momentum < 0.0:
+            raise ValueError(f"Invalid momentum value: {momentum}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = dict(
+            lr=lr, momentum=momentum, dampening=dampening,
+            weight_decay=weight_decay, nesterov=nesterov,
+            maximize=maximize, foreach=foreach,
+            differentiable=differentiable)
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('nesterov', False)
+            group.setdefault('maximize', False)
+            group.setdefault('foreach', None)
+            group.setdefault('differentiable', False)
+
+    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
+        has_sparse_grad = False
+
+        for p in group['params']:
+            if p.grad is not None:
+                params_with_grad.append(p)
+                d_p_list.append(p.grad)
+                if p.grad.is_sparse:
+                    has_sparse_grad = True
+
+                state = self.state[p]
+                if 'momentum_buffer' not in state:
+                    momentum_buffer_list.append(None)
+                else:
+                    momentum_buffer_list.append(state['momentum_buffer'])
+
+        return has_sparse_grad
+
+    @_use_grad_for_differentiable
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            d_p_list = []
+            momentum_buffer_list = []
+
+            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)
+
+            sgdw(
+                params_with_grad,
+                d_p_list,
+                momentum_buffer_list,
+                weight_decay=group['weight_decay'],
+                momentum=group['momentum'],
+                lr=group['lr'],
+                dampening=group['dampening'],
+                nesterov=group['nesterov'],
+                maximize=group['maximize'],
+                has_sparse_grad=has_sparse_grad,
+                foreach=group['foreach'],
+            )
+
+            # update momentum_buffers in state
+            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
+                state = self.state[p]
+                state['momentum_buffer'] = momentum_buffer
+
+        return loss
+
+
+def sgdw(
+        params: List[Tensor],
+        d_p_list: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
+        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
+        has_sparse_grad: Optional[bool] = None,
+        foreach: Optional[bool] = None,
+        *,
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool,
+        maximize: bool
+):
+    r"""Functional API that performs SGDW algorithm computation.
+
+    See :class:`SGDW` for details.
+    """
+
+    if foreach is None:
+        # why must we be explicit about an if statement for torch.jit.is_scripting here?
+        # because JIT can't handle Optionals nor fancy conditionals when scripting
+        if not torch.jit.is_scripting():
+            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
+        else:
+            foreach = False
+
+    if foreach and torch.jit.is_scripting():
+        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
+
+    if foreach and not torch.jit.is_scripting():
+        func = _multi_tensor_sgdw
+    else:
+        func = _single_tensor_sgdw
+
+    func(
+        params,
+        d_p_list,
+        momentum_buffer_list,
+        weight_decay=weight_decay,
+        momentum=momentum,
+        lr=lr,
+        dampening=dampening,
+        nesterov=nesterov,
+        has_sparse_grad=has_sparse_grad,
+        maximize=maximize,
+    )
+
+
+def _single_tensor_sgdw(
+        params: List[Tensor],
+        d_p_list: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        *,
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool,
+        maximize: bool,
+        has_sparse_grad: bool
+):
+    for i, param in enumerate(params):
+        d_p = d_p_list[i] if not maximize else -d_p_list[i]
+
+        param.mul_(1. - lr * weight_decay)
+
+        if momentum != 0:
+            buf = momentum_buffer_list[i]
+
+            if buf is None:
+                buf = torch.clone(d_p).detach()
+                momentum_buffer_list[i] = buf
+            else:
+                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+
+            if nesterov:
+                d_p = d_p.add(buf, alpha=momentum)
+            else:
+                d_p = buf
+
+        param.add_(d_p, alpha=-lr)
+
+
+def _multi_tensor_sgdw(
+        params: List[Tensor],
+        grads: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        *,
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool,
+        maximize: bool,
+        has_sparse_grad: bool
+):
+    if len(params) == 0:
+        return
+
+    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
+        [params, grads, momentum_buffer_list], with_indices=True)
+    for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
+        device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
+
+        if maximize:
+            device_grads = torch._foreach_neg(device_grads)
+
+        torch._foreach_mul_(device_params, 1. - lr * weight_decay)
+
+        if momentum != 0:
+            bufs = []
+
+            all_states_with_momentum_buffer = True
+            for i in range(len(device_momentum_buffer_list)):
+                if device_momentum_buffer_list[i] is None:
+                    all_states_with_momentum_buffer = False
+                    break
+                else:
+                    bufs.append(device_momentum_buffer_list[i])
+
+            if all_states_with_momentum_buffer:
+                torch._foreach_mul_(bufs, momentum)
+                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
+            else:
+                bufs = []
+                for i in range(len(device_momentum_buffer_list)):
+                    if device_momentum_buffer_list[i] is None:
+                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
+                            torch.clone(device_grads[i]).detach()
+                    else:
+                        buf = device_momentum_buffer_list[i]
+                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
+
+                    bufs.append(buf)
+
+            if nesterov:
+                torch._foreach_add_(device_grads, bufs, alpha=momentum)
+            else:
+                device_grads = bufs
+
+        if not device_has_sparse_grad:
+            torch._foreach_add_(device_params, device_grads, alpha=-lr)
+        else:
+            # foreach APIs don't support sparse
+            for i in range(len(device_params)):
+                device_params[i].add_(device_grads[i], alpha=-lr)
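
For context (not part of the patch), here is a minimal sketch of how the new optimizer can be exercised through the factory, assuming a timm checkout that includes this change; the model name, input shape, and hyper-parameter values are illustrative placeholders only. The 'sgdw' / 'nesterovw' aliases map to nesterov=True and 'momentumw' to nesterov=False, per the factory branch added above.

    import torch
    import timm
    from timm.optim import create_optimizer_v2

    # Any small model works here; resnet18 is just a convenient example.
    model = timm.create_model('resnet18', pretrained=False)

    # 'sgdw' resolves to SGDW with nesterov momentum via the new factory branch.
    optimizer = create_optimizer_v2(
        model, opt='sgdw', lr=0.1, momentum=0.9, weight_decay=1e-4)

    # One dummy training step.
    x = torch.randn(2, 3, 224, 224)
    loss = model(x).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()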
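
Also not part of the patch: a small self-contained sketch of why the decoupled update matters. Without momentum, SGDW and L2-regularized SGD take identical steps, but once a momentum buffer is involved the classic formulation accumulates the decay term in the buffer while SGDW never does. All numbers are made up, and the loop mirrors only the dampening=0, nesterov=False path of _single_tensor_sgdw.

    import torch

    lr, wd, momentum = 0.1, 0.01, 0.9
    p_sgdw = torch.tensor([1.0, -2.0])
    p_sgd = p_sgdw.clone()
    buf_sgdw = None
    buf_sgd = None

    for _ in range(2):
        g = torch.tensor([0.5, 0.5])  # stand-in gradient, identical for both variants

        # SGDW step: the decay multiplies the weights directly and never enters the buffer.
        p_sgdw.mul_(1. - lr * wd)
        buf_sgdw = g.clone() if buf_sgdw is None else buf_sgdw.mul_(momentum).add_(g)
        p_sgdw.add_(buf_sgdw, alpha=-lr)

        # Classic SGD + L2 folds wd * p into the gradient, so the decay term is
        # accumulated (and amplified) by the momentum buffer.
        d_p = g + wd * p_sgd
        buf_sgd = d_p.clone() if buf_sgd is None else buf_sgd.mul_(momentum).add_(d_p)
        p_sgd.add_(buf_sgd, alpha=-lr)

    # The two trajectories match after the first step but diverge once momentum kicks in.
    print(p_sgdw, p_sgd)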