Update Adan with newer impl (from original source) that includes multi-tensor fn

Ross Wightman 2024-11-26 10:55:20 -08:00
parent b59058bd88
commit c5690c044e

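For context on the "multi-tensor fn" in the title: the new implementation adds a torch._foreach code path that collects each param group's tensors into lists and updates them with batched calls instead of a Python loop over parameters. A minimal illustrative sketch of the difference (not part of this commit):

    import torch

    params = [torch.zeros(3) for _ in range(4)]
    grads = [torch.ones(3) for _ in range(4)]

    # per-tensor path: one in-place op per parameter
    for p, g in zip(params, grads):
        p.add_(g, alpha=-0.1)

    # multi-tensor (_foreach) path: one batched call over the whole list
    torch._foreach_add_(params, grads, alpha=-0.1)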

@@ -5,52 +5,94 @@ Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J].
Implementation adapted from https://github.com/sail-sg/Adan
"""
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Tuple
import torch
from torch import Tensor
from torch.optim import Optimizer
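# Helper from the upstream (sail-sg) source: it forwards a chunk size, a no-op flag buffer,
# and lists of tensors to a multi-tensor op. Note that it is not referenced by the hunks
# shown in this diff, which use torch._foreach_* ops directly.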
class MultiTensorApply(object):
available = False
warned = False
def __init__(self, chunk_size):
try:
MultiTensorApply.available = True
self.chunk_size = chunk_size
except ImportError as err:
MultiTensorApply.available = False
MultiTensorApply.import_err = err
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
class Adan(Optimizer):
    """ Implements a PyTorch variant of Adan.

    Adan was proposed in Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models,
    arXiv preprint arXiv:2208.06677, 2022.
    https://arxiv.org/abs/2208.06677

    Arguments:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        betas: Coefficients for the gradient, gradient-difference, and squared-gradient moving averages.
        eps: Term added to the denominator to improve numerical stability.
        weight_decay: Decoupled weight decay (L2 penalty).
        no_prox: If True, apply the decoupled weight decay before the parameter update (AdamW style);
            otherwise divide by (1 + lr * weight_decay) after the update.
        foreach: If True, use the torch._foreach (multi-tensor) implementation. Faster, but uses
            slightly more memory.
    """
    def __init__(
            self,
            params,
            lr: float = 1e-3,
            betas: Tuple[float, float, float] = (0.98, 0.92, 0.99),
            eps: float = 1e-8,
            weight_decay: float = 0.0,
            no_prox: bool = False,
            foreach: bool = True,
    ):
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        if not 0.0 <= betas[2] < 1.0:
            raise ValueError('Invalid beta parameter at index 2: {}'.format(betas[2]))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            no_prox=no_prox,
            foreach=foreach,
        )
        super().__init__(params, defaults)
def __setstate__(self, state):
super(Adan, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('no_prox', False)
    @torch.no_grad()
    def restart_opt(self):
@@ -70,17 +112,23 @@ class Adan(Optimizer):
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
exp_avg_diffs = []
neg_pre_grads = []
            beta1, beta2, beta3 = group['betas']

            # assume same step across group now to simplify things
            # per parameter step can be easily supported by making it a tensor, or pass list into kernel
            if 'step' in group:
                group['step'] += 1
            else:
@@ -93,32 +141,155 @@ class Adan(Optimizer):
            for p in group['params']:
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                    state['exp_avg_diff'] = torch.zeros_like(p)

                if 'neg_pre_grad' not in state or group['step'] == 1:
                    state['neg_pre_grad'] = -p.grad.clone()

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                exp_avg_diffs.append(state['exp_avg_diff'])
                neg_pre_grads.append(state['neg_pre_grad'])

            if not params_with_grad:
                continue

            kwargs = dict(
params=params_with_grad,
grads=grads,
exp_avgs=exp_avgs,
exp_avg_sqs=exp_avg_sqs,
exp_avg_diffs=exp_avg_diffs,
neg_pre_grads=neg_pre_grads,
beta1=beta1,
beta2=beta2,
beta3=beta3,
bias_correction1=bias_correction1,
bias_correction2=bias_correction2,
bias_correction3_sqrt=math.sqrt(bias_correction3),
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
no_prox=group['no_prox'],
)
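            # Both branches below apply the same Adan update; the foreach path batches the
            # elementwise work across all tensors in the group with torch._foreach_* calls.
            # bias_correction1/2/3 are the usual 1 - beta_i ** step debiasing terms, computed
            # earlier in step() outside the hunks shown here.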
if group['foreach']:
_multi_tensor_adan(**kwargs)
else:
_single_tensor_adan(**kwargs)
        return loss
def _single_tensor_adan(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
exp_avg_diffs: List[Tensor],
neg_pre_grads: List[Tensor],
*,
beta1: float,
beta2: float,
beta3: float,
bias_correction1: float,
bias_correction2: float,
bias_correction3_sqrt: float,
lr: float,
weight_decay: float,
eps: float,
no_prox: bool,
):
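    # Per-parameter Adan update implemented below, in the paper's notation:
    #   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    #   v_t = beta2 * v_{t-1} + (1 - beta2) * (g_t - g_{t-1})
    #   n_t = beta3 * n_{t-1} + (1 - beta3) * (g_t + beta2 * (g_t - g_{t-1}))**2
    #   param_t = param_{t-1} - lr * (m_t / bias_correction1 + beta2 * v_t / bias_correction2)
    #             / (sqrt(n_t) / bias_correction3_sqrt + eps)
    # with decoupled weight decay applied before the update when no_prox is True, and as a
    # division by (1 + lr * weight_decay) after the update when it is False.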
for i, param in enumerate(params):
grad = grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
exp_avg_diff = exp_avg_diffs[i]
neg_grad_or_diff = neg_pre_grads[i]
        # to save memory, `neg_grad_or_diff` is reused as an in-place temporary
neg_grad_or_diff.add_(grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t
exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff, alpha=1 - beta2) # diff_t
neg_grad_or_diff.mul_(beta2).add_(grad)
exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff, neg_grad_or_diff, value=1 - beta3) # n_t
denom = (exp_avg_sq.sqrt() / bias_correction3_sqrt).add_(eps)
step_size_diff = lr * beta2 / bias_correction2
step_size = lr / bias_correction1
if no_prox:
param.mul_(1 - lr * weight_decay)
param.addcdiv_(exp_avg, denom, value=-step_size)
param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
else:
param.addcdiv_(exp_avg, denom, value=-step_size)
param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
param.div_(1 + lr * weight_decay)
neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)
def _multi_tensor_adan(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
exp_avg_diffs: List[Tensor],
neg_pre_grads: List[Tensor],
*,
beta1: float,
beta2: float,
beta3: float,
bias_correction1: float,
bias_correction2: float,
bias_correction3_sqrt: float,
lr: float,
weight_decay: float,
eps: float,
no_prox: bool,
):
if len(params) == 0:
return
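    # This mirrors _single_tensor_adan, but every elementwise op is applied to the whole list
    # of tensors at once via torch._foreach_* calls, avoiding a Python-level loop over parameters.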
    # to save memory, the `neg_pre_grads` buffers are reused as in-place temporaries
torch._foreach_add_(neg_pre_grads, grads)
torch._foreach_mul_(exp_avgs, beta1)
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t
torch._foreach_mul_(exp_avg_diffs, beta2)
torch._foreach_add_(exp_avg_diffs, neg_pre_grads, alpha=1 - beta2) # diff_t
torch._foreach_mul_(neg_pre_grads, beta2)
torch._foreach_add_(neg_pre_grads, grads)
torch._foreach_mul_(exp_avg_sqs, beta3)
torch._foreach_addcmul_(exp_avg_sqs, neg_pre_grads, neg_pre_grads, value=1 - beta3) # n_t
denom = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_(denom, bias_correction3_sqrt)
torch._foreach_add_(denom, eps)
step_size_diff = lr * beta2 / bias_correction2
step_size = lr / bias_correction1
if no_prox:
torch._foreach_mul_(params, 1 - lr * weight_decay)
torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
else:
torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
torch._foreach_addcdiv_(params, exp_avg_diffs, denom, value=-step_size_diff)
torch._foreach_div_(params, 1 + lr * weight_decay)
torch._foreach_zero_(neg_pre_grads)
torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)