Add cautious mars, improve test reliability by skipping grad diff for first step
parent 82e8677690
commit 303f7691a1
tests/test_optim.py

@@ -300,6 +300,8 @@ def test_optim_factory(optimizer):
     lr = (1e-2,) * 4
     if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
         lr = (1e-3,) * 4
+    elif optimizer in ('cmars',):
+        lr = (1e-4,) * 4
 
     try:
         if not opt_info.second_order:  # basic tests don't support second order right now
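The lower learning rate keeps the cautious MARS variant numerically stable over the short optimization run the factory test performs. As a rough, hypothetical miniature of that kind of smoke test (names and structure are mine, not the actual test code):

import torch
from timm.optim.mars import Mars

def smoke_test(opt_class=Mars, lr=1e-4, **kwargs):
    # Take a few steps on a toy quadratic and check the loss decreases.
    param = torch.nn.Parameter(torch.randn(4, 4))
    opt = opt_class([param], lr=lr, **kwargs)
    initial = (param ** 2).sum().item()
    for _ in range(5):
        opt.zero_grad()
        loss = (param ** 2).sum()
        loss.backward()
        opt.step()
    assert (param ** 2).sum().item() < initial

smoke_test(caution=True)  # roughly what the 'cmars' case exercises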
timm/optim/_optim_factory.py

@@ -572,6 +572,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             defaults={'caution': True}
         ),
+        OptimInfo(
+            name='cmars',
+            opt_class=Mars,
+            description='Cautious MARS',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
         OptimInfo(
             name='cnadamw',
             opt_class=NAdamW,
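With the registry entry in place, the new variant is reachable by name through timm's optimizer factory. A minimal sketch, assuming a standard timm install (model choice arbitrary):

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet18')
# 'cmars' resolves to Mars with caution=True via the registry defaults above
optimizer = create_optimizer_v2(model, opt='cmars', lr=1e-4, weight_decay=0.05)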
timm/optim/mars.py

@@ -14,38 +14,50 @@ Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models
 # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: Apache-2.0
 import math
+from typing import Optional, Tuple
 
 import torch
 from torch.optim.optimizer import Optimizer
 
+from ._types import ParamsT
+
 
-def mars_single_tensor(
-        p,
-        grad,
-        exp_avg,
-        exp_avg_sq,
-        lr,
-        weight_decay,
-        beta1,
-        beta2,
-        last_grad,
-        eps,
-        step,
-        gamma,
-        mars_type,
-        is_grad_2d,
-        optimize_1d,
-        lr_1d_factor,
-        betas_1d,
+def _mars_single_tensor_step(
+        p: torch.Tensor,
+        grad: torch.Tensor,
+        exp_avg: torch.Tensor,
+        exp_avg_sq: torch.Tensor,
+        lr: float,
+        weight_decay: float,
+        beta1: float,
+        beta2: float,
+        last_grad: torch.Tensor,
+        eps: float,
+        step: int,
+        gamma: float,
+        mars_type: str,
+        is_grad_2d: bool,
+        optimize_1d: bool,
+        lr_1d_factor: float,
+        betas_1d: Tuple[float, float],
+        caution: bool,
 ):
-    # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
+    # optimize_1d ==> use MARS for 1d param, else use AdamW
     if optimize_1d or is_grad_2d:
         one_minus_beta1 = 1. - beta1
-        c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
-        c_t_norm = torch.norm(c_t)
-        if c_t_norm > 1.:
-            c_t = c_t / c_t_norm
+        if step == 1:
+            # this is a timm addition, making first step more consistent when no grad history, otherwise tests fail
+            c_t = grad
+        else:
+            c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
+            c_t_norm = torch.norm(c_t)
+            if c_t_norm > 1.:
+                c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
         if mars_type == "adamw":
             exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
             bias_correction1 = 1.0 - beta1 ** step
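The new step == 1 branch exists because last_grad has no history at the first step (it is initialized to zeros), so the variance-reduction term would merely rescale the raw gradient rather than reduce any variance; that inflated first step is what made the gradient-diff tests flaky. A small standalone check of the observation (not part of the patch):

import torch

gamma, beta1 = 0.025, 0.9
grad = torch.randn(10)
last_grad = torch.zeros_like(grad)  # state at the very first step

# With last_grad == 0, the 'correction' collapses to a constant rescale:
c_t = (grad - last_grad).mul(gamma * (beta1 / (1. - beta1))).add(grad)
assert torch.allclose(c_t, grad * 1.225)  # 1 + 0.025 * 0.9 / 0.1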
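The caution block follows the 'Cautious Optimizers' recipe: zero momentum entries whose sign disagrees with the current gradient, then rescale by the fraction kept so the expected update magnitude is roughly preserved. The same three lines, isolated as a standalone helper (the function name is mine, not from the patch):

import torch

def apply_caution(exp_avg: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Keep only momentum entries that agree in sign with the gradient.
    mask = (exp_avg * grad > 0).to(grad.dtype)
    # Renormalize by the kept fraction; the clamp guards against dividing
    # by ~0 when nearly everything is masked out.
    mask.div_(mask.mean().clamp_(min=1e-3))
    return exp_avg * mask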
@@ -64,6 +76,10 @@ def mars_single_tensor(
         bias_correction1 = 1.0 - beta1_1d ** step
         bias_correction2 = 1.0 - beta2_1d ** step
         denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
         update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
         p.add_(update, alpha=-(lr * lr_1d_factor))
     return exp_avg, exp_avg_sq
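This second caution block covers the AdamW fallback used for 1d parameters: by default only tensors with two or more dims (weight matrices, conv kernels) take the MARS path, while 1d tensors (biases, norm scales) fall back to AdamW unless optimize_1d=True. A quick illustration of the routing predicate:

import torch

optimize_1d = False
for shape in [(64, 64), (8, 3, 3, 3), (64,)]:
    grad = torch.zeros(shape)
    is_grad_2d = grad.ndim >= 2  # same check step() performs per parameter
    print(shape, 'MARS' if optimize_1d or is_grad_2d else 'AdamW')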
@@ -78,16 +94,17 @@ class Mars(Optimizer):
     """
     def __init__(
             self,
-            params,
-            lr=3e-3,
-            betas=(0.9, 0.99),
-            eps=1e-8,
-            weight_decay=0.,
-            gamma=0.025,
-            mars_type="adamw",
-            optimize_1d=False,
-            lr_1d_factor=1.0,
-            betas_1d=None,
+            params: ParamsT,
+            lr: float = 3e-3,
+            betas: Tuple[float, float] = (0.9, 0.99),
+            eps: float = 1e-8,
+            weight_decay: float = 0.,
+            gamma: float = 0.025,
+            mars_type: str = "adamw",
+            optimize_1d: bool = False,
+            lr_1d_factor: float = 1.0,
+            betas_1d: Optional[Tuple[float, float]] = None,
+            caution: bool = False
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -109,9 +126,15 @@ class Mars(Optimizer):
             optimize_1d=optimize_1d,
             lr_1d_factor=lr_1d_factor,
             betas_1d=betas_1d or betas,
+            caution=caution,
         )
         super(Mars, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super(Mars, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -134,7 +157,6 @@ class Mars(Optimizer):
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
 
                 state = self.state[p]
-                # ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
                 # State initialization
                 if len(state) <= 1:
                     state['step'] = 0
@@ -155,7 +177,8 @@ class Mars(Optimizer):
                 beta1, beta2 = group['betas']
                 is_grad_2d = grad.ndim >= 2
 
-                mars_single_tensor(
+                # FIXME add multi-tensor (if usage warrants), make more standard
+                _mars_single_tensor_step(
                     p,
                     grad,
                     exp_avg,
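The FIXME refers to PyTorch's multi-tensor ('foreach') optimizer pattern, where the per-parameter Python loop is replaced by horizontally fused ops over lists of tensors. A generic sketch of the idea, not MARS-specific, using the private torch._foreach_* helpers that torch's built-in optimizers rely on:

import torch

params = [torch.randn(3, 3) for _ in range(4)]
grads = [torch.randn(3, 3) for _ in range(4)]

# Single-tensor style: one op per parameter, looped in Python.
for p, g in zip(params, grads):
    p.add_(g, alpha=-0.01)

# Multi-tensor style: one fused op across all parameters at once.
torch._foreach_add_(params, grads, alpha=-0.01)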
@@ -173,6 +196,7 @@ class Mars(Optimizer):
                     optimize_1d=group['optimize_1d'],
                     lr_1d_factor=group['lr_1d_factor'],
                     betas_1d=group['betas_1d'],
+                    caution=group['caution'],
                 )
 
                 state['last_grad'] = grad