mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup original adafactor impl, add row/col dim heuristic that works with both conv and linear layers
This commit is contained in:
parent
1409ce2dbe
commit
e7b0480381
@ -38,16 +38,38 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
whether warm-up initialization is being used (default: False)
|
whether warm-up initialization is being used (default: False)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
|
def __init__(
|
||||||
decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
|
self,
|
||||||
|
params,
|
||||||
|
lr=None,
|
||||||
|
eps=1e-30,
|
||||||
|
eps_scale=1e-3,
|
||||||
|
clip_threshold=1.0,
|
||||||
|
decay_rate=-0.8,
|
||||||
|
betas=None,
|
||||||
|
weight_decay=0.0,
|
||||||
|
scale_parameter=True,
|
||||||
|
warmup_init=False,
|
||||||
|
min_dim_size_to_factor=32,
|
||||||
|
):
|
||||||
relative_step = not lr
|
relative_step = not lr
|
||||||
if warmup_init and not relative_step:
|
if warmup_init and not relative_step:
|
||||||
raise ValueError('warmup_init requires relative_step=True')
|
raise ValueError('warmup_init requires relative_step=True')
|
||||||
|
|
||||||
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
|
beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
|
||||||
defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
|
defaults = dict(
|
||||||
beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
|
lr=lr,
|
||||||
relative_step=relative_step, warmup_init=warmup_init)
|
eps=eps,
|
||||||
|
eps_scale=eps_scale,
|
||||||
|
clip_threshold=clip_threshold,
|
||||||
|
decay_rate=decay_rate,
|
||||||
|
beta1=beta1,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
scale_parameter=scale_parameter,
|
||||||
|
relative_step=relative_step,
|
||||||
|
warmup_init=warmup_init,
|
||||||
|
min_dim_size_to_factor=min_dim_size_to_factor,
|
||||||
|
)
|
||||||
super(Adafactor, self).__init__(params, defaults)
|
super(Adafactor, self).__init__(params, defaults)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -62,20 +84,34 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
return param_group['lr']
|
return param_group['lr']
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_options(param_group, param_shape):
|
def _get_options(param_group, param_shape, min_size_to_factor=32):
|
||||||
factored = len(param_shape) >= 2
|
|
||||||
use_first_moment = param_group['beta1'] is not None
|
use_first_moment = param_group['beta1'] is not None
|
||||||
|
factored = None
|
||||||
|
ndim = len(param_shape)
|
||||||
|
# Use a simple heuristic to pick factorization row & col, note other PyTorch impl tend to
|
||||||
|
# always use -2, -1 BUT this will not pick correct dims for convolutions. This is a simple
|
||||||
|
# approach that should work in most cases, compare to the slightly more involved approach
|
||||||
|
# in AdafactorBigVision that sorts dims by size, please report if wrong dims chosen.
|
||||||
|
if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
|
||||||
|
# nD convs in torch are ND + 2 dim weights with leading in/out chs
|
||||||
|
factored = 0, 1
|
||||||
|
elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
|
||||||
|
# if the criteria above didn't match, check trailing dims
|
||||||
|
factored = ndim - 2, ndim - 1
|
||||||
|
|
||||||
return factored, use_first_moment
|
return factored, use_first_moment
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _rms(tensor):
|
def _rms(tensor):
|
||||||
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
||||||
|
|
||||||
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
|
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row):
|
||||||
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
# from our dim heuristic, always dim_col < dim_row, so col reduction dim for factored row = dim_col
|
||||||
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=dim_col, keepdim=True)).rsqrt_().unsqueeze(dim_row)
|
||||||
|
c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt()
|
||||||
return torch.mul(r_factor, c_factor)
|
return torch.mul(r_factor, c_factor)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
"""Performs a single optimization step.
|
"""Performs a single optimization step.
|
||||||
@ -99,7 +135,11 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
|
|
||||||
state = self.state[p]
|
state = self.state[p]
|
||||||
|
|
||||||
factored, use_first_moment = self._get_options(group, grad.shape)
|
factored_dims, use_first_moment = self._get_options(
|
||||||
|
group,
|
||||||
|
grad.shape,
|
||||||
|
min_size_to_factor=group['min_dim_size_to_factor'],
|
||||||
|
)
|
||||||
# State Initialization
|
# State Initialization
|
||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
state['step'] = 0
|
state['step'] = 0
|
||||||
@ -107,9 +147,12 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
if use_first_moment:
|
if use_first_moment:
|
||||||
# Exponential moving average of gradient values
|
# Exponential moving average of gradient values
|
||||||
state['exp_avg'] = torch.zeros_like(grad)
|
state['exp_avg'] = torch.zeros_like(grad)
|
||||||
if factored:
|
if factored_dims is not None:
|
||||||
state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
|
dim_col, dim_row = factored_dims
|
||||||
state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
|
def _remove_dim(shape, dim):
|
||||||
|
return shape[:dim] + shape[dim + 1:]
|
||||||
|
state['exp_avg_sq_row'] = torch.zeros(_remove_dim(grad.shape, dim_row)).to(grad)
|
||||||
|
state['exp_avg_sq_col'] = torch.zeros(_remove_dim(grad.shape, dim_col)).to(grad)
|
||||||
else:
|
else:
|
||||||
state['exp_avg_sq'] = torch.zeros_like(grad)
|
state['exp_avg_sq'] = torch.zeros_like(grad)
|
||||||
|
|
||||||
@ -117,7 +160,7 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
else:
|
else:
|
||||||
if use_first_moment:
|
if use_first_moment:
|
||||||
state['exp_avg'] = state['exp_avg'].to(grad)
|
state['exp_avg'] = state['exp_avg'].to(grad)
|
||||||
if factored:
|
if factored_dims is not None:
|
||||||
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
|
state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
|
||||||
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
|
state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
|
||||||
else:
|
else:
|
||||||
@ -133,15 +176,16 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
|
|
||||||
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
|
beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
|
||||||
update = grad ** 2 + group['eps']
|
update = grad ** 2 + group['eps']
|
||||||
if factored:
|
if factored_dims is not None:
|
||||||
|
dim_col, dim_row = factored_dims
|
||||||
exp_avg_sq_row = state['exp_avg_sq_row']
|
exp_avg_sq_row = state['exp_avg_sq_row']
|
||||||
exp_avg_sq_col = state['exp_avg_sq_col']
|
exp_avg_sq_col = state['exp_avg_sq_col']
|
||||||
|
|
||||||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
|
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=dim_row), alpha=1.0 - beta2t)
|
||||||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
|
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=dim_col), alpha=1.0 - beta2t)
|
||||||
|
|
||||||
# Approximation of exponential moving average of square of gradient
|
# Approximation of exponential moving average of square of gradient
|
||||||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row)
|
||||||
update.mul_(grad)
|
update.mul_(grad)
|
||||||
else:
|
else:
|
||||||
exp_avg_sq = state['exp_avg_sq']
|
exp_avg_sq = state['exp_avg_sq']
|
||||||
|
Loading…
x
Reference in New Issue
Block a user