Cleanup original adafactor impl, add row/col dim heuristic that works with both conv and linear layers

Ross Wightman 2024-11-08 08:35:25 -08:00 committed by Ross Wightman
parent 1409ce2dbe
commit e7b0480381


@@ -38,16 +38,38 @@ class Adafactor(torch.optim.Optimizer):
whether warm-up initialization is being used (default: False)
"""
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
def __init__(
self,
params,
lr=None,
eps=1e-30,
eps_scale=1e-3,
clip_threshold=1.0,
decay_rate=-0.8,
betas=None,
weight_decay=0.0,
scale_parameter=True,
warmup_init=False,
min_dim_size_to_factor=32,
):
relative_step = not lr
if warmup_init and not relative_step:
raise ValueError('warmup_init requires relative_step=True')
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
relative_step=relative_step, warmup_init=warmup_init)
defaults = dict(
lr=lr,
eps=eps,
eps_scale=eps_scale,
clip_threshold=clip_threshold,
decay_rate=decay_rate,
beta1=beta1,
weight_decay=weight_decay,
scale_parameter=scale_parameter,
relative_step=relative_step,
warmup_init=warmup_init,
min_dim_size_to_factor=min_dim_size_to_factor,
)
super(Adafactor, self).__init__(params, defaults)
@staticmethod
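For context, a minimal usage sketch of the updated constructor (assuming the class is importable as timm.optim.Adafactor; the Conv2d below is just an illustrative stand-in):

    import torch
    from timm.optim import Adafactor  # assumed import path for this class

    model = torch.nn.Conv2d(64, 128, kernel_size=3)
    # lr=None keeps the relative_step behaviour; min_dim_size_to_factor controls when
    # the second-moment estimate is factored rather than stored in full
    optimizer = Adafactor(model.parameters(), lr=None, min_dim_size_to_factor=32)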
@@ -62,20 +84,34 @@ class Adafactor(torch.optim.Optimizer):
return param_group['lr']
@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
def _get_options(param_group, param_shape, min_size_to_factor=32):
use_first_moment = param_group['beta1'] is not None
factored = None
ndim = len(param_shape)
# Use a simple heuristic to pick the factorization row & col. Note that other PyTorch impls tend
# to always use -2, -1, but this will not pick the correct dims for convolutions. This is a simple
# approach that should work in most cases; compare to the slightly more involved approach in
# AdafactorBigVision that sorts dims by size. Please report if the wrong dims are chosen.
if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
# nD convs in torch have (n + 2)-dim weights with leading out/in channel dims
factored = 0, 1
elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
# if the criteria above didn't match, check trailing dims
factored = ndim - 2, ndim - 1
return factored, use_first_moment
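To illustrate the heuristic, a standalone sketch (the function name is hypothetical, not part of the commit):

    def pick_factored_dims(shape, min_size_to_factor=32):
        # Mirrors _get_options above: prefer the leading (out_ch, in_ch) dims for
        # conv-style weights, otherwise fall back to the trailing two dims.
        ndim = len(shape)
        if ndim > 2 and shape[0] > min_size_to_factor and shape[1] > min_size_to_factor:
            return 0, 1
        if ndim >= 2 and shape[-2] > min_size_to_factor and shape[-1] > min_size_to_factor:
            return ndim - 2, ndim - 1
        return None

    pick_factored_dims((128, 64, 3, 3))  # -> (0, 1): conv weight, factor over out/in chs
    pick_factored_dims((1024, 768))      # -> (0, 1): linear weight, trailing-dims branch
    pick_factored_dims((16, 3, 3, 3))    # -> None: dims too small, keep full second moment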
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row):
# from our dim heuristic, dim_col < dim_row always holds, so dim_col is the reduction dim for the factored row
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=dim_col, keepdim=True)).rsqrt_().unsqueeze(dim_row)
c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
return torch.mul(r_factor, c_factor)
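For a 2D weight (dim_col=0, dim_row=1), the factored approximation of 1/sqrt(v) can be sketched standalone as follows (illustrative values, not the optimizer's actual state):

    import torch

    torch.manual_seed(0)
    v = torch.randn(128, 64).pow(2)             # stand-in for the EMA of grad**2
    row = v.mean(dim=1)                         # (128,) ~ exp_avg_sq_row (dim_row reduced)
    col = v.mean(dim=0)                         # (64,)  ~ exp_avg_sq_col (dim_col reduced)
    # rank-1 reconstruction of 1/sqrt(v), matching _approx_sq_grad
    r_factor = (row / row.mean()).rsqrt().unsqueeze(1)   # (128, 1)
    c_factor = col.rsqrt().unsqueeze(0)                  # (1, 64)
    inv_sqrt_v = r_factor * c_factor                     # broadcasts to (128, 64)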
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -99,7 +135,11 @@ class Adafactor(torch.optim.Optimizer):
state = self.state[p]
factored, use_first_moment = self._get_options(group, grad.shape)
factored_dims, use_first_moment = self._get_options(
group,
grad.shape,
min_size_to_factor=group['min_dim_size_to_factor'],
)
# State Initialization
if len(state) == 0:
state['step'] = 0
@@ -107,9 +147,12 @@ class Adafactor(torch.optim.Optimizer):
if use_first_moment:
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(grad)
if factored:
state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
if factored_dims is not None:
dim_col, dim_row = factored_dims
def _remove_dim(shape, dim):
return shape[:dim] + shape[dim + 1:]
state['exp_avg_sq_row'] = torch.zeros(_remove_dim(grad.shape, dim_row)).to(grad)
state['exp_avg_sq_col'] = torch.zeros(_remove_dim(grad.shape, dim_col)).to(grad)
else:
state['exp_avg_sq'] = torch.zeros_like(grad)
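As a concrete shape example for the factored state (hypothetical Conv2d weight, not from the commit):

    def _remove_dim(shape, dim):
        return shape[:dim] + shape[dim + 1:]

    shape = (128, 64, 3, 3)    # (out_ch, in_ch, kh, kw); heuristic picks dim_col=0, dim_row=1
    _remove_dim(shape, 1)      # -> (128, 3, 3): exp_avg_sq_row drops dim_row
    _remove_dim(shape, 0)      # -> (64, 3, 3):  exp_avg_sq_col drops dim_col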
@@ -117,7 +160,7 @@ class Adafactor(torch.optim.Optimizer):
else:
if use_first_moment:
state['exp_avg'] = state['exp_avg'].to(grad)
if factored:
if factored_dims is not None:
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
else:
@@ -133,15 +176,16 @@ class Adafactor(torch.optim.Optimizer):
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
update = grad ** 2 + group['eps']
if factored:
if factored_dims is not None:
dim_col, dim_row = factored_dims
exp_avg_sq_row = state['exp_avg_sq_row']
exp_avg_sq_col = state['exp_avg_sq_col']
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=dim_row), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=dim_col), alpha=1.0 - beta2t)
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row)
update.mul_(grad)
else:
exp_avg_sq = state['exp_avg_sq']
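For reference, the second-moment EMA coefficient beta2t = 1 - step**decay_rate approaches 1 as training progresses (decay_rate defaults to -0.8 in this implementation); a quick numeric check:

    import math

    decay_rate = -0.8
    [round(1.0 - math.pow(t, decay_rate), 3) for t in (1, 10, 100, 1000)]
    # -> [0.0, 0.842, 0.975, 0.996]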