from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.optim import Optimizer


def _get_scalar_dtype():
    """Get the scalar dtype that the optimizer uses for state."""
    return torch.float64


def _factored_dims(
        shape: Tuple[int, ...],
        factored: bool,
        min_dim_size_to_factor: int,
) -> Optional[Tuple[int, int]]:
    """Whether to use a factored second moment estimator.

    This function returns a tuple with the indices of the two largest axes to reduce over.
    If no two dimensions have size >= min_dim_size_to_factor, it returns None.

    Args:
        shape: an input shape
        factored: whether to use factored second-moment estimator for > 2d vars.
        min_dim_size_to_factor: only factor the accumulator if two array dimensions have at least this size.

    Returns:
        None or a tuple of ints
    """
    if not factored or len(shape) < 2:
        return None
    sorted_dims = np.argsort(shape)
    if shape[sorted_dims[-2]] < min_dim_size_to_factor:
        return None
    return int(sorted_dims[-2]), int(sorted_dims[-1])


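# Illustrative example (editorial note, not from the original source): for a conv
# weight of shape (256, 128, 3, 3) the two largest axes are 1 (size 128) and 0
# (size 256), so
#     _factored_dims((256, 128, 3, 3), factored=True, min_dim_size_to_factor=32)
# returns (1, 0) and the second moment is factored over those two axes. For a bias
# of shape (256,) or a small (16, 16) kernel it returns None and the full
# second-moment tensor is kept instead.

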
class AdafactorBigVision(Optimizer):
    """PyTorch implementation of the BigVision Adafactor variant.

    Provides a single-tensor update path; the multi-tensor (foreach=True) path is
    declared but not yet implemented.
    """

    def __init__(
            self,
            params,
            lr: float = 1.0,
            min_dim_size_to_factor: int = 32,
            decay_rate: float = 0.8,
            decay_offset: int = 0,
            beta2_cap: float = 0.999,
            momentum: Optional[float] = 0.9,
            momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
            eps: float = 1e-30,
            weight_decay: float = 0.0,
            clipping_threshold: Optional[float] = None,
            unscaled_wd: bool = False,
            *,
            foreach: Optional[bool] = False,
    ):
        """
        Args:
            params: iterable of parameters to optimize or dicts defining parameter groups
            lr: learning rate
            min_dim_size_to_factor: only factor the second moment if two dims have at least this size
            decay_rate: exponent of the step-dependent beta2 schedule, beta2 = 1 - t ** (-decay_rate)
            decay_offset: step offset for the beta2 schedule (stored in defaults, not used by the current update functions)
            beta2_cap: maximum value of the scheduled beta2
            momentum: EMA factor applied to the update, None to disable momentum
            momentum_dtype: dtype used to store the momentum buffer ('float16', 'bfloat16', 'float32' or a torch.dtype)
            eps: term added to the squared gradient for numerical stability
            weight_decay: decoupled weight decay factor
            clipping_threshold: RMS threshold used to clip the update, None to disable
            unscaled_wd: apply weight decay without scaling by lr ('fully decoupled', matching the big vision impl)
            foreach: use the multi-tensor path (not yet implemented)
        """
        if isinstance(momentum_dtype, str):
            if momentum_dtype == 'float16':
                momentum_dtype = torch.float16
            elif momentum_dtype == 'bfloat16':
                momentum_dtype = torch.bfloat16
            else:
                assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
                momentum_dtype = torch.float32

        defaults = dict(
            lr=lr,
            min_dim_size_to_factor=min_dim_size_to_factor,
            decay_rate=decay_rate,
            decay_offset=decay_offset,
            beta2_cap=beta2_cap,
            momentum=momentum,
            momentum_dtype=momentum_dtype,
            eps=eps,
            weight_decay=weight_decay,
            clipping_threshold=clipping_threshold,
            unscaled_wd=unscaled_wd,
            foreach=foreach,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('foreach', None)
            for p in group['params']:
                p_state = self.state.get(p, {})
                if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                    p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())

    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
        """Computes beta2 according to the step schedule."""
        t = float(step + 1)
        return min(beta2_cap, 1.0 - t ** (-decay_rate))

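    # The schedule above matches the step-dependent decay of the Adafactor paper:
    # beta2 grows toward 1 as 1 - t ** (-decay_rate), capped at beta2_cap. With the
    # defaults (decay_rate=0.8, beta2_cap=0.999) this gives roughly beta2 ~= 0.84 at
    # step 10, ~= 0.975 at step 100, ~= 0.996 at step 1000, reaching the 0.999 cap
    # around step 5600. The same formula is recomputed inline in
    # _single_tensor_adafactor.
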
    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avg_sq_rs = []
            exp_avg_sq_cs = []
            exp_avg_sqs = []
            state_steps = []
            exp_avgs = []  # For momentum

            for p in group['params']:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError("Sparse gradients not supported")

                params_with_grad.append(p)
                grads.append(p.grad)

                state = self.state[p]

                if len(state) == 0:
                    # NOTE step on CPU, probably need some more though to make capturable
                    state['step'] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                    shape = p.grad.shape
                    factored_dims = _factored_dims(
                        shape,
                        factored=True,
                        min_dim_size_to_factor=self.defaults['min_dim_size_to_factor']
                    )

                    if factored_dims is not None:
                        d1, d0 = factored_dims
                        row_shape = list(p.grad.shape)
                        row_shape[d0] = 1
                        col_shape = list(p.grad.shape)
                        col_shape[d1] = 1
                        state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape)
                        state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

                    if self.defaults['momentum'] is not None:
                        # Use the configured momentum dtype rather than hard-coding bfloat16.
                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])

                state_steps.append(state['step'])
                exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
                exp_avg_sq_cs.append(state.get('exp_avg_sq_c', None))
                exp_avg_sqs.append(state.get('exp_avg_sq', None))
                exp_avgs.append(state.get('exp_avg', None))

            if group['foreach']:
                func = _multi_tensor_adafactor
            else:
                func = _single_tensor_adafactor

            func(
                params=params_with_grad,
                grads=grads,
                exp_avg_sq_rs=exp_avg_sq_rs,
                exp_avg_sq_cs=exp_avg_sq_cs,
                exp_avg_sqs=exp_avg_sqs,
                exp_avgs=exp_avgs,
                state_steps=state_steps,
                beta2_decay=group['decay_rate'],
                beta2_cap=group['beta2_cap'],
                min_dim_size_to_factor=group['min_dim_size_to_factor'],
                eps=group['eps'],
                lr=group['lr'],
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                momentum_dtype=group['momentum_dtype'],
                clipping_threshold=group['clipping_threshold'],
                unscaled_wd=group['unscaled_wd'],
            )

        return loss


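# Per-parameter state created by AdafactorBigVision.step():
#   'step'                          scalar float64 tensor (kept on CPU)
#   'exp_avg_sq_r', 'exp_avg_sq_c'  factored second-moment EMAs (when two dims are large enough)
#   'exp_avg_sq'                    full second-moment EMA (otherwise)
#   'exp_avg'                       momentum EMA stored in momentum_dtype (when momentum is not None)

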
def _single_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg_sq_r = exp_avg_sq_rs[i]
        exp_avg_sq_c = exp_avg_sq_cs[i]
        exp_avg_sq = exp_avg_sqs[i]
        exp_avg = exp_avgs[i]
        step_t = state_steps[i]

        # Update step
        step_t += 1
        beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
        one_minus_beta2_t = 1 - beta2_t

        if exp_avg_sq is None:
            # Factored second moment: EMA of row/col means of the squared gradient.
            d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
            grad_sqr = torch.square(grad) + eps
            exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
            exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)

            # The row/col accumulators are stored with keepdim=True, so the mean is
            # taken directly over the original dim d1 (no index adjustment needed).
            row_col_mean = exp_avg_sq_r.mean(dim=d1, keepdim=True)
            row_factor = (exp_avg_sq_r / row_col_mean).rsqrt()
            col_factor = exp_avg_sq_c.rsqrt()

            update = grad * row_factor * col_factor
        else:
            # Handle non-factored second moment
            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
            update = grad * exp_avg_sq.add(eps).rsqrt_()

        # Clip update by RMS value: update /= max(1, RMS(update) / clipping_threshold)
        if clipping_threshold is not None:
            rms = update.norm(2) / (update.numel() ** 0.5)
            update.div_((rms / clipping_threshold).clamp_(min=1.0))

        # Apply momentum (in a possibly lower-precision dtype)
        if momentum is not None and exp_avg is not None:
            if momentum_dtype != grad.dtype:
                exp_avg.lerp_(update.to(momentum_dtype), 1 - momentum)  # ema
                update = exp_avg.to(grad.dtype)
            else:
                exp_avg.lerp_(update, 1 - momentum)  # ema
                update = exp_avg.clone()

        # Scale by learning rate
        update.mul_(lr)

        # Perform weight decay
        if weight_decay != 0:
            if unscaled_wd:
                # match big vision impl, 'fully decoupled' decay w/o LR scaling
                param.mul_(1. - weight_decay)
            else:
                # match typical pytorch behaviour for decoupled decay, e.g. adamw where wd is scaled by LR
                param.mul_(1. - lr * weight_decay)

        # Update parameters
        param.add_(update, alpha=-1.0)


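# Note on the factored branch in _single_tensor_adafactor above: for a parameter
# with two large axes d1 (second largest) and d0 (largest), the full second moment
# V is approximated by the rank-1 combination
#     V ~= exp_avg_sq_r * exp_avg_sq_c / mean(exp_avg_sq_r over dim d1)
# so that update = grad * row_factor * col_factor == grad / sqrt(V). This is the
# memory saving at the heart of Adafactor: O(n + m) second-moment state instead of
# O(n * m) for an n x m matrix.

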
def _multi_tensor_adafactor(
        params: List[Tensor],
        grads: List[Tensor],
        exp_avg_sq_rs: List[Optional[Tensor]],
        exp_avg_sq_cs: List[Optional[Tensor]],
        exp_avg_sqs: List[Optional[Tensor]],
        exp_avgs: List[Optional[Tensor]],
        state_steps: List[Tensor],
        *,
        beta2_decay: float,
        beta2_cap: float,
        min_dim_size_to_factor: int,
        eps: float,
        lr: float,
        weight_decay: float,
        momentum: Optional[float],
        momentum_dtype: Union[str, torch.dtype],
        clipping_threshold: Optional[float],
        unscaled_wd: bool,
):
    raise NotImplementedError('multi-tensor fn (foreach=True) not implemented yet')
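

# Minimal usage sketch (illustrative only, not part of the original module): runs a
# few optimization steps on a tiny model when this file is executed directly. The
# model shape and hyperparameter values below are arbitrary demo choices.
if __name__ == '__main__':
    model = torch.nn.Linear(64, 64)
    opt = AdafactorBigVision(model.parameters(), lr=1e-2, weight_decay=1e-4)
    for _ in range(3):
        x = torch.randn(8, 64)
        loss = model(x).square().mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        print(f'loss: {loss.item():.4f}')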