An impl of adafactor as per big vision (scaling vit) changes
parent
e31e5d2d64
commit
7c16adca83
|
@ -4,6 +4,7 @@ from .adahessian import Adahessian
|
|||
from .adamp import AdamP
|
||||
from .adamw import AdamW
|
||||
from .adan import Adan
|
||||
from .adafactor_bv import AdafactorBigVision
|
||||
from .lamb import Lamb
|
||||
from .lars import Lars
|
||||
from .lookahead import Lookahead
|
||||
|
|
|
@ -0,0 +1,288 @@
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
def _get_scalar_dtype():
    """Return the dtype used for scalar optimizer state such as the per-parameter step counter."""
    scalar_dtype = torch.float64
    return scalar_dtype
|
||||
|
||||
def _factored_dims(
        shape: Tuple[int, ...],
        factored: bool,
        min_dim_size_to_factor: int
) -> Optional[Tuple[int, int]]:
    """Whether to use a factored second moment estimator.

    This function returns a tuple with the two largest axes to reduce over.
    If no two dimensions have size >= min_dim_size_to_factor, return None.

    Args:
        shape: an input shape
        factored: whether to use factored second-moment estimator for > 2d vars.
        min_dim_size_to_factor: only factor accumulator if two array dimensions have at least this size.

    Returns:
        None when no factoring applies, otherwise ``(second_largest_axis, largest_axis)``
        as plain Python ints.
    """
    # NOTE: the return annotation uses typing.Tuple rather than the builtin `tuple[...]`
    # so the module still imports on Python versions without PEP 585 generics.
    if not factored or len(shape) < 2:
        return None
    # Axis indices ordered by ascending size; the last two are the largest axes.
    sorted_dims = np.argsort(shape)
    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
        # The second-largest axis is too small, so fewer than two axes qualify.
        return None
    return int(sorted_dims[-2]), int(sorted_dims[-1])
|
||||
|
||||
|
||||
class AdafactorBigVision(Optimizer):
    """
    PyTorch implementation of BigVision's Adafactor variant.

    Per-parameter state holds a step counter, a second-moment estimate that is either factored
    (``exp_avg_sq_r`` / ``exp_avg_sq_c``) or full (``exp_avg_sq``) and, when momentum is enabled,
    a momentum buffer ``exp_avg`` stored in ``momentum_dtype``.

    NOTE: ``foreach=True`` dispatches to the multi-tensor function, which is not implemented yet.

    Args:
        params: Iterable of parameters or dicts defining parameter groups.
        lr: Learning rate applied to the update.
        min_dim_size_to_factor: The second moment is factored only if the two largest dims of a
            parameter are both at least this size.
        decay_rate: Exponent of the beta2 schedule ``min(beta2_cap, 1 - t ** -decay_rate)``.
        decay_offset: Stored in the group options; not read anywhere else in this class.
        beta2_cap: Upper bound for the scheduled beta2.
        momentum: Momentum coefficient, or None to disable the momentum buffer.
        momentum_dtype: dtype of the momentum buffer, a ``torch.dtype`` or one of the strings
            'float16', 'bfloat16', 'float32'.
        eps: Small constant forwarded to the update function.
        weight_decay: Weight decay coefficient forwarded to the update function.
        clipping_threshold: Optional update clipping threshold forwarded to the update function.
        unscaled_wd: Forwarded to the update function; selects weight decay without lr scaling.
        foreach: Select the multi-tensor implementation when truthy.
    """

    def __init__(
            self,
            params,
            lr: float = 1.0,
            min_dim_size_to_factor: int = 32,
            decay_rate: float = 0.8,
            decay_offset: int = 0,
            beta2_cap: float = 0.999,
            momentum: Optional[float] = 0.9,
            momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
            eps: float = 1e-30,
            weight_decay: float = 0.0,
            clipping_threshold: Optional[float] = None,
            unscaled_wd: bool = False,
            *,
            foreach: Optional[bool] = False,
    ):
        # Accept the dtype as a string (e.g. from a config / CLI) and map it to a torch dtype.
        if isinstance(momentum_dtype, str):
            if momentum_dtype == 'float16':
                momentum_dtype = torch.float16
            elif momentum_dtype == 'bfloat16':
                momentum_dtype = torch.bfloat16
            else:
                assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
                momentum_dtype = torch.float32

        defaults = dict(
            lr=lr,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            beta2_cap=beta2_cap,
            momentum=momentum,
            momentum_dtype=momentum_dtype,
            eps=eps,
            weight_decay=weight_decay,
            clipping_threshold=clipping_threshold,
            unscaled_wd=unscaled_wd,
            foreach=foreach,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in 'foreach' and converting non-tensor steps to tensors."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('foreach', None)
            for p in group['params']:
                p_state = self.state.get(p, {})
                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                    p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
        """Computes beta2 according to the step schedule"""
        t = float(step + 1)
        return min(beta2_cap, 1.0 - t ** (-decay_rate))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss;
                it is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avg_sq_rs = []
            exp_avg_sq_cs = []
            exp_avg_sqs = []
            state_steps = []
            exp_avgs = []  # For momentum

            for p in group['params']:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("Sparse gradients not supported")

                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    # NOTE step on CPU, probably need some more thought to make capturable
                    state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                    # Use the per-group options (not self.defaults) so the state layout always
                    # matches the values handed to the update function for this group.
                    shape = p.grad.shape
                    factored_dims = _factored_dims(
                        shape,
                        factored=True,
                        min_dim_size_to_factor=group['min_dim_size_to_factor']
                    )

                    if factored_dims is not None:
                        # d1 is the second-largest axis, d0 the largest one.
                        d1, d0 = factored_dims
                        row_shape = list(p.grad.shape)
                        row_shape[d0] = 1
                        col_shape = list(p.grad.shape)
                        col_shape[d1] = 1
                        state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
                        state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

                    if group['momentum'] is not None:
                        # Honour the configured momentum dtype instead of a fixed bfloat16 buffer.
                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=group['momentum_dtype'])

                # Missing entries are passed as None; the update function branches on them.
                state_steps.append(state['step'])
                exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
                exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
                exp_avg_sqs.append(state.get('exp_avg_sq', None))
                exp_avgs.append(state.get('exp_avg', None))

            if group['foreach']:
                func = _multi_tensor_adafactor
            else:
                func = _single_tensor_adafactor

            func(
                params=params_with_grad,
                grads=grads,
                exp_avg_sq_rs=exp_avg_sq_rs,
                exp_avg_sq_cs=exp_avg_sq_cs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avgs=exp_avgs,
                state_steps=state_steps,
                beta2_decay=group['decay_rate'],
                beta2_cap=group['beta2_cap'],
                min_dim_size_to_factor=group['min_dim_size_to_factor'],
                eps=group['eps'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                momentum_dtype=group['momentum_dtype'],
                clipping_threshold=group['clipping_threshold'],
                unscaled_wd=group['unscaled_wd'],
            )

        return loss
|
||||
|
||||
def _single_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    """Apply one Adafactor update to each parameter, one tensor at a time.

    All state tensors and the parameters are modified in place.

    Args:
        params: Parameters to update.
        grads: Gradients, aligned with ``params``.
        exp_avg_sq_rs: Factored second-moment "row" accumulators (None entries for non-factored params).
        exp_avg_sq_cs: Factored second-moment "column" accumulators (None entries for non-factored params).
        exp_avg_sqs: Full second-moment accumulators; a None entry selects the factored path.
        exp_avgs: Momentum buffers, or None entries when momentum is disabled.
        state_steps: Step counters, incremented in place.
        beta2_decay: Exponent of the schedule ``beta2_t = min(beta2_cap, 1 - step ** -beta2_decay)``.
        beta2_cap: Upper bound on ``beta2_t``.
        min_dim_size_to_factor: Passed to ``_factored_dims`` to recover the factored axes.
        eps: Added to the squared gradient (factored) or second moment (non-factored).
        lr: Learning rate multiplied into the update.
        weight_decay: Decoupled weight decay coefficient; 0 disables decay.
        momentum: EMA coefficient for the momentum buffer, or None to skip momentum.
        momentum_dtype: Configured momentum dtype. Kept for interface compatibility; the dtype of
            each momentum buffer itself decides whether a cast is needed.
        clipping_threshold: If not None, the update is rescaled so its RMS does not exceed this value.
        unscaled_wd: If True apply ``1 - weight_decay``; otherwise ``1 - lr * weight_decay``.
    """
    for param, grad, exp_avg_sq_r, exp_avg_sq_c, exp_avg_sq, exp_avg, step_t in zip(
            params, grads, exp_avg_sq_rs, exp_avg_sq_cs, exp_avg_sqs, exp_avgs, state_steps):
        # Update step
        step_t += 1
        beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
        one_minus_beta2_t = 1 - beta2_t

        if exp_avg_sq is None:
            # Factored second moment: d1 is the second-largest axis, d0 the largest.
            d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
            grad_sqr = torch.square(grad) + eps
            exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
            exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)

            # The accumulators keep the reduced axes as size-1 dims (keepdim=True), so axis d1 is
            # still located at index d1; no index shift is needed, unlike reductions that drop dims.
            row_col_mean = exp_avg_sq_r.mean(dim=d1, keepdim=True)
            row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
            col_factor = exp_avg_sq_c.rsqrt()

            update = grad * row_factor * col_factor
        else:
            # Handle non-factored
            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
            update = grad * exp_avg_sq.add(eps).rsqrt_()

        # Clip by RMS value: divide by max(1, rms / threshold) so only updates whose RMS exceeds
        # the threshold are scaled down, and smaller updates are left untouched.
        if clipping_threshold is not None:
            rms = update.norm(2) / (update.numel() ** 0.5)
            denom = (rms / clipping_threshold).clamp_(min=1.0)
            update.div_(denom)

        # Apply momentum (in different dtype)
        if momentum is not None and exp_avg is not None:
            if exp_avg.dtype != grad.dtype:
                # Accumulate in the buffer's own dtype, then cast back for the parameter update.
                exp_avg.lerp_(update.to(exp_avg.dtype), 1 - momentum)  # ema
                update = exp_avg.to(grad.dtype)
            else:
                exp_avg.lerp_(update, 1 - momentum)  # ema
                # Clone so the in-place lr scaling below does not modify the momentum buffer.
                update = exp_avg.clone()

        # Scale by learning rate
        update.mul_(lr)

        # Perform weight decay
        if weight_decay != 0:
            if unscaled_wd:
                # match big vision impl, 'fully decoupled' decay w/o LR scaling
                param.mul_(1. - weight_decay)
            else:
                # match typical pytorch behaviour for decoupled decay, eg adamw where wd is scaled by LR
                param.mul_(1. - lr * weight_decay)

        # Update parameters
        param.add_(update, alpha=-1.0)
|
||||
|
||||
def _multi_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    """Placeholder for the multi-tensor (foreach) Adafactor update.

    The signature mirrors the single-tensor update function so the optimizer can dispatch to
    either; this variant is not implemented and always fails.

    Raises:
        AssertionError: Always.
    """
    # Raise explicitly instead of `assert False`: assert statements are removed under
    # `python -O`, which would turn this stub into a silent no-op that skips the update.
    raise AssertionError('multi-tensor fn (foreach=True) not implemented yet')
|
|
@ -10,6 +10,7 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
|
||||
from timm.models import group_parameters
|
||||
from . import AdafactorBigVision
|
||||
|
||||
from .adabelief import AdaBelief
|
||||
from .adafactor import Adafactor
|
||||
|
@ -356,6 +357,8 @@ def create_optimizer_v2(
|
|||
elif opt_lower == 'lion':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = Lion(parameters, **opt_args)
|
||||
elif opt_lower == 'adafactorbv':
|
||||
optimizer = AdafactorBigVision(parameters, **opt_args)
|
||||
|
||||
# second order
|
||||
elif opt_lower == 'adahessian':
|
||||
|
|
Loading…
Reference in New Issue