Add cautious mars, improve test reliability by skipping grad diff for first step

convnormact_aa_none
Ross Wightman 2024-12-02 09:38:25 -08:00 committed by Ross Wightman
parent 82e8677690
commit 303f7691a1
3 changed files with 68 additions and 35 deletions


@@ -300,6 +300,8 @@ def test_optim_factory(optimizer):
lr = (1e-2,) * 4
if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
lr = (1e-3,) * 4
elif optimizer in ('cmars',):
lr = (1e-4,) * 4
try:
if not opt_info.second_order: # basic tests don't support second order right now


@@ -572,6 +572,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cmars',
opt_class=Mars,
description='Cautious MARS',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cnadamw',
opt_class=NAdamW,


@@ -14,38 +14,50 @@ Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Optional, Tuple
import torch
from torch.optim.optimizer import Optimizer
from ._types import ParamsT
def mars_single_tensor(
p,
grad,
exp_avg,
exp_avg_sq,
lr,
weight_decay,
beta1,
beta2,
last_grad,
eps,
step,
gamma,
mars_type,
is_grad_2d,
optimize_1d,
lr_1d_factor,
betas_1d,
def _mars_single_tensor_step(
p: torch.Tensor,
grad: torch.Tensor,
exp_avg: torch.Tensor,
exp_avg_sq: torch.Tensor,
lr: float,
weight_decay: float,
beta1: float,
beta2: float,
last_grad: torch.Tensor,
eps: float,
step: int,
gamma: float,
mars_type: str,
is_grad_2d: bool,
optimize_1d: bool,
lr_1d_factor: float,
betas_1d: Tuple[float, float],
caution: bool,
):
# optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
# optimize_1d ==> use MARS for 1d param, else use AdamW
if optimize_1d or is_grad_2d:
one_minus_beta1 = 1. - beta1
c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
c_t_norm = torch.norm(c_t)
if c_t_norm > 1.:
c_t = c_t / c_t_norm
if step == 1:
# timm addition: there is no grad history on the first step, so use the raw grad for a more consistent update (otherwise tests fail)
c_t = grad
else:
c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
c_t_norm = torch.norm(c_t)
if c_t_norm > 1.:
c_t = c_t / c_t_norm
exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
if caution:
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
if mars_type == "adamw":
exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
bias_correction1 = 1.0 - beta1 ** step
@@ -64,6 +76,10 @@ def mars_single_tensor(
bias_correction1 = 1.0 - beta1_1d ** step
bias_correction2 = 1.0 - beta2_1d ** step
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
if caution:
mask = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
p.add_(update, alpha=-(lr * lr_1d_factor))
return exp_avg, exp_avg_sq
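
The caution branch added above applies a Cautious-Optimizers style mask: the momentum buffer only contributes where its sign agrees with the current gradient, and the mask is renormalized by the surviving fraction (clamped to avoid division by zero). A standalone numeric sketch with toy values, not taken from the diff:

import torch

exp_avg = torch.tensor([0.5, -0.2, 0.1])   # stand-in momentum buffer
grad = torch.tensor([1.0, 0.3, -0.4])      # stand-in current gradient

mask = (exp_avg * grad > 0).to(grad.dtype)   # sign agreement: [1., 0., 0.]
mask.div_(mask.mean().clamp_(min=1e-3))      # renormalize by surviving fraction -> [3., 0., 0.]
masked_exp_avg = exp_avg * mask              # [1.5, 0., 0.]; disagreeing components are zeroed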
@@ -78,16 +94,17 @@ class Mars(Optimizer):
"""
def __init__(
self,
params,
lr=3e-3,
betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0.,
gamma=0.025,
mars_type="adamw",
optimize_1d=False,
lr_1d_factor=1.0,
betas_1d=None,
params: ParamsT,
lr: float = 3e-3,
betas: Tuple[float, float] = (0.9, 0.99),
eps: float = 1e-8,
weight_decay: float = 0.,
gamma: float = 0.025,
mars_type: str = "adamw",
optimize_1d: bool = False,
lr_1d_factor: float = 1.0,
betas_1d: Optional[Tuple[float, float]] = None,
caution: bool = False
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
@@ -109,9 +126,15 @@ class Mars(Optimizer):
optimize_1d=optimize_1d,
lr_1d_factor=lr_1d_factor,
betas_1d=betas_1d or betas,
caution=caution,
)
super(Mars, self).__init__(params, defaults)
def __setstate__(self, state):
super(Mars, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('caution', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -134,7 +157,6 @@
raise RuntimeError('Mars does not support sparse gradients')
state = self.state[p]
# ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
# State initialization
if len(state) <= 1:
state['step'] = 0
@@ -155,7 +177,8 @@
beta1, beta2 = group['betas']
is_grad_2d = grad.ndim >= 2
mars_single_tensor(
# FIXME add multi-tensor (if usage warrants), make more standard
_mars_single_tensor_step(
p,
grad,
exp_avg,
@@ -173,6 +196,7 @@
optimize_1d=group['optimize_1d'],
lr_1d_factor=group['lr_1d_factor'],
betas_1d=group['betas_1d'],
caution=group['caution'],
)
state['last_grad'] = grad
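
Putting the pieces together, a hedged usage sketch of the updated optimizer class (the module path timm.optim.mars and the re-export from timm.optim are assumed from the upstream timm layout; values are illustrative):

import torch
from timm.optim.mars import Mars  # module path assumed; timm.optim also re-exports Mars upstream

model = torch.nn.Linear(16, 16)
opt = Mars(model.parameters(), lr=1e-3, weight_decay=1e-2, caution=True)

for _ in range(3):
    opt.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    # step 1 falls back to c_t = grad (no grad history yet); later steps apply the grad-diff correction
    opt.step()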