mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup original adafactor impl, add row/col dim heuristic that works with both conv and linear layers
This commit is contained in:
parent
1409ce2dbe
commit
e7b0480381
@ -38,16 +38,38 @@ class Adafactor(torch.optim.Optimizer):
|
||||
whether warm-up initialization is being used (default: False)
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
|
||||
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=None,
|
||||
eps=1e-30,
|
||||
eps_scale=1e-3,
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
betas=None,
|
||||
weight_decay=0.0,
|
||||
scale_parameter=True,
|
||||
warmup_init=False,
|
||||
min_dim_size_to_factor=32,
|
||||
):
|
||||
relative_step = not lr
|
||||
if warmup_init and not relative_step:
|
||||
raise ValueError('warmup_init requires relative_step=True')
|
||||
|
||||
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
|
||||
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
|
||||
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
|
||||
relative_step=relative_step, warmup_init=warmup_init)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
eps=eps,
|
||||
eps_scale=eps_scale,
|
||||
clip_threshold=clip_threshold,
|
||||
decay_rate=decay_rate,
|
||||
beta1=beta1,
|
||||
weight_decay=weight_decay,
|
||||
scale_parameter=scale_parameter,
|
||||
relative_step=relative_step,
|
||||
warmup_init=warmup_init,
|
||||
min_dim_size_to_factor=min_dim_size_to_factor,
|
||||
)
|
||||
super(Adafactor, self).__init__(params, defaults)
|
||||
|
||||
@staticmethod
|
||||
@ -62,20 +84,34 @@ class Adafactor(torch.optim.Optimizer):
|
||||
return param_group['lr']
|
||||
|
||||
@staticmethod
|
||||
def _get_options(param_group, param_shape):
|
||||
factored = len(param_shape) >= 2
|
||||
def _get_options(param_group, param_shape, min_size_to_factor=32):
|
||||
use_first_moment = param_group['beta1'] is not None
|
||||
factored = None
|
||||
ndim = len(param_shape)
|
||||
# Use a simple heuristic to pick factorization row & col, note other PyTorch impl tend to
|
||||
# always use -2, -1 BUT this will not pick correct dims for convolutions. This is a simple
|
||||
# approach that should work in most cases, compare to the slightly more involved approach
|
||||
# in AdafactorBigVision that sorts dims by size, please report if wrong dims chosen.
|
||||
if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
|
||||
# nD convs in torch are ND + 2 dim weights with leading in/out chs
|
||||
factored = 0, 1
|
||||
elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
|
||||
# if the criteria above didn't match, check trailing dims
|
||||
factored = ndim - 2, ndim - 1
|
||||
|
||||
return factored, use_first_moment
|
||||
|
||||
@staticmethod
|
||||
def _rms(tensor):
|
||||
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
||||
|
||||
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
|
||||
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
||||
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
||||
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row):
|
||||
# from our dim heuristic, always dim_col < dim_row, so col reduction dim for factored row = dim_col
|
||||
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=dim_col, keepdim=True)).rsqrt_().unsqueeze(dim_row)
|
||||
c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
|
||||
return torch.mul(r_factor, c_factor)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
@ -99,7 +135,11 @@ class Adafactor(torch.optim.Optimizer):
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
factored, use_first_moment = self._get_options(group, grad.shape)
|
||||
factored_dims, use_first_moment = self._get_options(
|
||||
group,
|
||||
grad.shape,
|
||||
min_size_to_factor=group['min_dim_size_to_factor'],
|
||||
)
|
||||
# State Initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
@ -107,9 +147,12 @@ class Adafactor(torch.optim.Optimizer):
|
||||
if use_first_moment:
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(grad)
|
||||
if factored:
|
||||
state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
|
||||
state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
|
||||
if factored_dims is not None:
|
||||
dim_col, dim_row = factored_dims
|
||||
def _remove_dim(shape, dim):
|
||||
return shape[:dim] + shape[dim + 1:]
|
||||
state['exp_avg_sq_row'] = torch.zeros(_remove_dim(grad.shape, dim_row)).to(grad)
|
||||
state['exp_avg_sq_col'] = torch.zeros(_remove_dim(grad.shape, dim_col)).to(grad)
|
||||
else:
|
||||
state['exp_avg_sq'] = torch.zeros_like(grad)
|
||||
|
||||
@ -117,7 +160,7 @@ class Adafactor(torch.optim.Optimizer):
|
||||
else:
|
||||
if use_first_moment:
|
||||
state['exp_avg'] = state['exp_avg'].to(grad)
|
||||
if factored:
|
||||
if factored_dims is not None:
|
||||
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
|
||||
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
|
||||
else:
|
||||
@ -133,15 +176,16 @@ class Adafactor(torch.optim.Optimizer):
|
||||
|
||||
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
|
||||
update = grad ** 2 + group['eps']
|
||||
if factored:
|
||||
if factored_dims is not None:
|
||||
dim_col, dim_row = factored_dims
|
||||
exp_avg_sq_row = state['exp_avg_sq_row']
|
||||
exp_avg_sq_col = state['exp_avg_sq_col']
|
||||
|
||||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
|
||||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
|
||||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=dim_row), alpha=1.0 - beta2t)
|
||||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=dim_col), alpha=1.0 - beta2t)
|
||||
|
||||
# Approximation of exponential moving average of square of gradient
|
||||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
||||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row)
|
||||
update.mul_(grad)
|
||||
else:
|
||||
exp_avg_sq = state['exp_avg_sq']
|
||||
|
Loading…
x
Reference in New Issue
Block a user