An impl of adafactor as per big vision (scaling vit) changes

small_384_weights
Ross Wightman 2024-11-03 17:08:58 -08:00 committed by Ross Wightman
parent e31e5d2d64
commit 7c16adca83
3 changed files with 292 additions and 0 deletions

View File

@ -4,6 +4,7 @@ from .adahessian import Adahessian
from .adamp import AdamP
from .adamw import AdamW
from .adan import Adan
from .adafactor_bv import AdafactorBigVision
from .lamb import Lamb
from .lars import Lars
from .lookahead import Lookahead

View File

@ -0,0 +1,288 @@
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from torch.optim import Optimizer
def _get_scalar_dtype():
"""Get the scalar dtype that the optimizer uses for state"""
return torch.float64
def _factored_dims(
shape: Tuple[int, ...],
factored: bool,
min_dim_size_to_factor: int
) -> Optional[tuple[int, int]]:
"""Whether to use a factored second moment estimator.
This function returns a tuple with the two largest axes to reduce over.
If no two dimensions have size >= min_dim_size_to_factor, return None.
Args:
shape: an input shape
factored: whether to use factored second-moment estimator for > 2d vars.
min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.
Returns:
None or a tuple of ints
"""
if not factored or len(shape) < 2:
return None
sorted_dims = np.argsort(shape)
if shape[sorted_dims[-2]] < min_dim_size_to_factor:
return None
return int(sorted_dims[-2]), int(sorted_dims[-1])
class AdafactorBigVision(Optimizer):
"""
PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
"""
def __init__(
self,
params,
lr: float = 1.0,
min_dim_size_to_factor: int = 32,
decay_rate: float = 0.8,
decay_offset: int = 0,
beta2_cap: float = 0.999,
momentum: Optional[float] = 0.9,
momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
eps: float = 1e-30,
weight_decay: float = 0.0,
clipping_threshold: Optional[float] = None,
unscaled_wd: bool = False,
*,
foreach: Optional[bool] = False,
):
if isinstance(momentum_dtype, str):
if momentum_dtype == 'float16':
momentum_dtype = torch.float16
elif momentum_dtype == 'bfloat16':
momentum_dtype = torch.bfloat16
else:
assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
momentum_dtype = torch.float32
defaults = dict(
lr=lr,
min_dim_size_to_factor=min_dim_size_to_factor,
decay_rate=decay_rate,
decay_offset=decay_offset,
beta2_cap=beta2_cap,
momentum=momentum,
momentum_dtype=momentum_dtype,
eps=eps,
weight_decay=weight_decay,
clipping_threshold=clipping_threshold,
unscaled_wd=unscaled_wd,
foreach=foreach,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('foreach', None)
for p in group['params']:
p_state = self.state.get(p, {})
if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
"""Computes beta2 according to the step schedule"""
t = float(step + 1)
return min(beta2_cap, 1.0 - t ** (-decay_rate))
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avg_sq_rs = []
exp_avg_sq_cs = []
exp_avg_sqs = []
state_steps = []
exp_avgs = [] # For momentum
for p in group['params']:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError("Sparse gradients not supported")
params_with_grad.append(p)
grads.append(p.grad)
state = self.state[p]
if len(state) == 0:
# NOTE step on CPU, probably need some more though to make capturable
state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())
shape = p.grad.shape
factored_dims = _factored_dims(
shape,
factored=True,
min_dim_size_to_factor=self.defaults['min_dim_size_to_factor']
)
if factored_dims is not None:
d1, d0 = factored_dims
row_shape = list(p.grad.shape)
row_shape[d0] = 1
col_shape = list(p.grad.shape)
col_shape[d1] = 1
state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
else:
state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
if self.defaults['momentum'] is not None:
state['exp_avg'] = torch.zeros_like(p.grad, dtype=torch.bfloat16)
state_steps.append(state['step'])
exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
exp_avg_sqs.append(state.get('exp_avg_sq', None))
exp_avgs.append(state.get('exp_avg', None))
if group['foreach']:
func = _multi_tensor_adafactor
else:
func = _single_tensor_adafactor
func(
params=params_with_grad,
grads=grads,
exp_avg_sq_rs=exp_avg_sq_rs,
exp_avg_sq_cs=exp_avg_sq_cs,
exp_avg_sqs=exp_avg_sqs,
exp_avgs=exp_avgs,
state_steps=state_steps,
beta2_decay=group['decay_rate'],
beta2_cap=group['beta2_cap'],
min_dim_size_to_factor=group['min_dim_size_to_factor'],
eps=group['eps'],
lr=group['lr'],
weight_decay=group['weight_decay'],
momentum=group['momentum'],
momentum_dtype=group['momentum_dtype'],
clipping_threshold=group['clipping_threshold'],
unscaled_wd=group['unscaled_wd'],
)
return loss
def _single_tensor_adafactor(
params: List[Tensor],
grads: List[Tensor],
exp_avg_sq_rs: List[Optional[Tensor]],
exp_avg_sq_cs: List[Optional[Tensor]],
exp_avg_sqs: List[Optional[Tensor]],
exp_avgs: List[Optional[Tensor]],
state_steps: List[Tensor],
*,
beta2_decay: float,
beta2_cap: float,
min_dim_size_to_factor: int,
eps: float,
lr: float,
weight_decay: float,
momentum: Optional[float],
momentum_dtype: Union[str, torch.dtype],
clipping_threshold: Optional[float],
unscaled_wd: bool,
):
for i, param in enumerate(params):
grad = grads[i]
exp_avg_sq_r = exp_avg_sq_rs[i]
exp_avg_sq_c = exp_avg_sq_cs[i]
exp_avg_sq = exp_avg_sqs[i]
exp_avg = exp_avgs[i]
step_t = state_steps[i]
# Update step
step_t += 1
beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
one_minus_beta2_t = 1 - beta2_t
if exp_avg_sq is None:
d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
grad_sqr = torch.square(grad) + eps
exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
reduced_d1 = d1 - 1 if d1 > d0 else d1
row_col_mean = exp_avg_sq_r.mean(dim=reduced_d1, keepdim=True)
row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
col_factor = exp_avg_sq_c.rsqrt()
update = grad * row_factor * col_factor
else:
# Handle non-factored
exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
update = grad * exp_avg_sq.add(eps).rsqrt_()
# Clip by RMS value
if clipping_threshold is not None:
denom = (update.norm(2) / ((update.numel() ** 0.5) / clipping_threshold)).clamp_(max=1.0)
update.div_(denom)
# Apply momentum (in different dtype)
if momentum is not None and exp_avg is not None:
if momentum_dtype != grad.dtype:
exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum) # ema
update = exp_avg.to(grad.dtype)
else:
exp_avg.lerp_(update, 1 - momentum) # ema
update = exp_avg.clone()
# Scale by learning rate
update.mul_(lr)
# Perform weight decay
if weight_decay != 0:
if unscaled_wd:
# match big vision impl, 'fully decoupled' decay w/o LR scaling
param.mul_(1. - weight_decay)
else:
# match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
param.mul_(1. - lr * weight_decay)
# Update parameters
param.add_(update, alpha=-1.0)
def _multi_tensor_adafactor(
params: List[Tensor],
grads: List[Tensor],
exp_avg_sq_rs: List[Optional[Tensor]],
exp_avg_sq_cs: List[Optional[Tensor]],
exp_avg_sqs: List[Optional[Tensor]],
exp_avgs: List[Optional[Tensor]],
state_steps: List[Tensor],
*,
beta2_decay: float,
beta2_cap: float,
min_dim_size_to_factor: int,
eps: float,
lr: float,
weight_decay: float,
momentum: Optional[float],
momentum_dtype: Union[str, torch.dtype],
clipping_threshold: Optional[float],
unscaled_wd: bool,
):
assert False, 'multi-tensor fn (foreach=True) not implemented yet'

View File

@ -10,6 +10,7 @@ import torch.nn as nn
import torch.optim as optim
from timm.models import group_parameters
from . import AdafactorBigVision
from .adabelief import AdaBelief
from .adafactor import Adafactor
@ -356,6 +357,8 @@ def create_optimizer_v2(
elif opt_lower == 'lion':
opt_args.pop('eps', None)
optimizer = Lion(parameters, **opt_args)
elif opt_lower == 'adafactorbv':
optimizer = AdafactorBigVision(parameters, **opt_args)
# second order
elif opt_lower == 'adahessian':