Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
More optimizer cleanup. Change all to no longer use .data. Improve (b)float16 use with adabelief. Add XLA compatible Lars.
Commit a426511c95 (parent 9541f4963b)
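The hunks below all apply one pattern: step() gains a @torch.no_grad() decorator, the closure (if any) runs under torch.enable_grad(), and updates act on the parameter tensor p directly instead of p.data. A minimal standalone sketch of that pattern, using a simplified SGD rule for illustration (not code from this commit):

import torch
from torch.optim.optimizer import Optimizer


class PlainSGD(Optimizer):
    """Minimal sketch of the step() pattern used throughout this commit."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()  # the whole update runs without autograd tracking
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():  # the closure still needs grad enabled
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # in-place update on p itself, no .data indirection
                p.add_(p.grad, alpha=-group['lr'])
        return loss


w = torch.nn.Parameter(torch.randn(3))
opt = PlainSGD([w], lr=0.1)
(w ** 2).sum().backward()
opt.step()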
tests/test_optim.py
@@ -490,6 +490,33 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))


+@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
+def test_lars(optimizer):
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict(weight, bias, lr=1e-3),
+            optimizer,
+            lr=1e-1)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict_single(weight, bias, lr=1e-3),
+            optimizer,
+            lr=1e-3)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
+    )
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=1e-3))
+
+
 @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(
timm/optim/adabelief.py
@@ -68,6 +68,7 @@ class AdaBelief(Optimizer):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)

+    @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
@@ -77,14 +78,15 @@ class AdaBelief(Optimizer):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p)

                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data)
+                state['exp_avg_var'] = torch.zeros_like(p)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p)

+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
         Arguments:
@@ -93,50 +95,47 @@ class AdaBelief(Optimizer):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                # cast data type
-                half_precision = False
-                if p.data.dtype == torch.float16:
-                    half_precision = True
-                    p.data = p.data.float()
-                    p.grad = p.grad.float()
-
-                grad = p.grad.data
+                grad = p.grad
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
                 if grad.is_sparse:
                     raise RuntimeError(
                         'AdaBelief does not support sparse gradients, please consider SparseAdam instead')

+                p_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_fp32 = p_fp32.float()
+
                 amsgrad = group['amsgrad']

-                state = self.state[p]
-
                 beta1, beta2 = group['betas']
+                state = self.state[p]
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p_fp32)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p.data)
+                    state['exp_avg_var'] = torch.zeros_like(p_fp32)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p_fp32)

                 # perform weight decay, check if decoupled weight decay
                 if group['decoupled_decay']:
                     if not group['fixed_decay']:
-                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
-                        p.data.mul_(1.0 - group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['weight_decay'])
                 else:
                     if group['weight_decay'] != 0:
-                        grad.add_(p.data, alpha=group['weight_decay'])
+                        grad.add_(p_fp32, alpha=group['weight_decay'])

                 # get current state variable
                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
@@ -164,7 +163,7 @@ class AdaBelief(Optimizer):
                 if not group['rectify']:
                     # Default update
                     step_size = group['lr'] / bias_correction1
-                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                 else:
                     # Rectified update, forked from RAdam
                     buffered = group['buffer'][int(state['step'] % 10)]
@@ -192,12 +191,11 @@ class AdaBelief(Optimizer):

                     if num_sma >= 5:
                         denom = exp_avg_var.sqrt().add_(group['eps'])
-                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                        p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     elif step_size > 0:
-                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
+                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])

-                if half_precision:
-                    p.data = p.data.half()
-                    p.grad = p.grad.half()
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)

         return loss
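For reference, the (b)float16 handling added above keeps the parameter in its original dtype and does the arithmetic on a float32 shadow copy that is written back at the end. A self-contained sketch of that pattern with an assumed, simplified update rule (not the AdaBelief code itself):

import torch

def decayed_update_(p, grad, lr=1e-3, weight_decay=1e-2):
    # upcast grad and param if they are half precision
    if grad.dtype in {torch.float16, torch.bfloat16}:
        grad = grad.float()
    p_fp32 = p.float() if p.dtype in {torch.float16, torch.bfloat16} else p
    # decoupled weight decay plus a plain gradient step, all in fp32
    p_fp32.mul_(1.0 - lr * weight_decay)
    p_fp32.add_(grad, alpha=-lr)
    # write the result back into the low precision parameter
    if p.dtype in {torch.float16, torch.bfloat16}:
        p.copy_(p_fp32)

param = torch.randn(4, dtype=torch.bfloat16)
grad = torch.randn(4, dtype=torch.bfloat16)
decayed_update_(param, grad)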
timm/optim/adafactor.py
@@ -76,6 +76,7 @@ class Adafactor(torch.optim.Optimizer):
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

+    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
@@ -83,22 +84,22 @@ class Adafactor(torch.optim.Optimizer):
        """
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()
                if grad.is_sparse:
                    raise RuntimeError('Adafactor does not support sparse gradients.')

                state = self.state[p]
-                grad_shape = grad.shape

-                factored, use_first_moment = self._get_options(group, grad_shape)
+                factored, use_first_moment = self._get_options(group, grad.shape)
                # State Initialization
                if len(state) == 0:
                    state['step'] = 0
@@ -107,8 +108,8 @@ class Adafactor(torch.optim.Optimizer):
                        # Exponential moving average of gradient values
                        state['exp_avg'] = torch.zeros_like(grad)
                    if factored:
-                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
-                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+                        state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
+                        state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
                    else:
                        state['exp_avg_sq'] = torch.zeros_like(grad)

@@ -122,12 +123,12 @@ class Adafactor(torch.optim.Optimizer):
                    else:
                        state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)

-                p_data_fp32 = p.data
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
-                    p_data_fp32 = p_data_fp32.float()
+                p_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_fp32 = p_fp32.float()

                state['step'] += 1
-                state['RMS'] = self._rms(p_data_fp32)
+                state['RMS'] = self._rms(p_fp32)
                lr_t = self._get_lr(group, state)

                beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
@@ -157,11 +158,10 @@ class Adafactor(torch.optim.Optimizer):
                    update = exp_avg

                if group['weight_decay'] != 0:
-                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)

-                p_data_fp32.add_(-update)
-
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
-                    p.data.copy_(p_data_fp32)
+                p_fp32.add_(-update)
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)

        return loss
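The _approx_sq_grad context shown above reconstructs the squared-gradient statistics from row and column factors rather than storing the full matrix. A small illustrative sketch of that rank-1 reconstruction, using plain tensors in place of the optimizer's EMA state (assumed simplification):

import torch

# Adafactor keeps row and column means of the squared gradients instead of the
# full matrix; the reconstruction below mirrors the r_factor/c_factor context above
grad = torch.randn(6, 4)
sq = grad * grad
row = sq.mean(dim=-1)          # shape (6,)
col = sq.mean(dim=-2)          # shape (4,)
r_factor = (row / row.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
c_factor = col.unsqueeze(-2).rsqrt()
approx_rsqrt = torch.mul(r_factor, c_factor)   # elementwise ~ 1 / sqrt(sq)
print(approx_rsqrt.shape)  # torch.Size([6, 4])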
timm/optim/adamp.py
@@ -26,12 +26,13 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
    wd = 1.
    expand_size = (-1,) + (1,) * (len(p.shape) - 1)
    for view_func in [_channel_view, _layer_view]:
-        param_view = view_func(p.data)
+        param_view = view_func(p)
        grad_view = view_func(grad)
        cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()

+        # FIXME this is a problem for PyTorch XLA
        if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
-            p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
+            p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
            perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
            wd = wd_ratio
            return perturb, wd
@@ -47,17 +48,19 @@ class AdamP(Optimizer):
            delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
        super(AdamP, self).__init__(params, defaults)

+    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

-                grad = p.grad.data
+                grad = p.grad
                beta1, beta2 = group['betas']
                nesterov = group['nesterov']

@@ -66,8 +69,8 @@ class AdamP(Optimizer):
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)

                # Adam
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@@ -94,9 +97,9 @@ class AdamP(Optimizer):

                # Weight decay
                if group['weight_decay'] > 0:
-                    p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
+                    p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)

                # Step
-                p.data.add_(perturb, alpha=-step_size)
+                p.add_(perturb, alpha=-step_size)

        return loss
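The projection() helper patched above measures channel-wise cosine similarity and, when the gradient is nearly radial, removes the component of the perturbation parallel to the weights. A simplified sketch of that per-channel projection (illustrative only, not the AdamP implementation):

import torch

def remove_radial_component(p, perturb, eps=1e-8):
    # per output channel, drop the component of `perturb` parallel to `p`
    pf = p.reshape(p.size(0), -1)
    df = perturb.reshape(p.size(0), -1)
    p_unit = pf / pf.norm(dim=1, keepdim=True).add(eps)
    radial = (p_unit * df).sum(dim=1, keepdim=True) * p_unit
    return (df - radial).reshape_as(perturb)

w = torch.randn(8, 3, 3, 3)
d = torch.randn_like(w)
out = remove_radial_component(w, d)
# the projected perturbation is (near-)orthogonal to each channel of w
print((out.reshape(8, -1) * w.reshape(8, -1)).sum(dim=1).abs().max())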
timm/optim/adamw.py
@@ -55,6 +55,7 @@ class AdamW(Optimizer):
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

+    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

@@ -64,7 +65,8 @@ class AdamW(Optimizer):
        """
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
@@ -75,7 +77,7 @@ class AdamW(Optimizer):
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                # Perform optimization step
-                grad = p.grad.data
+                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']
@@ -86,12 +88,12 @@ class AdamW(Optimizer):
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
@@ -115,6 +117,6 @@ class AdamW(Optimizer):

                step_size = group['lr'] / bias_correction1

-                p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
timm/optim/lamb.py
@@ -5,10 +5,14 @@ This optimizer code was adapted from the following (starting with latest)
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb

-Use FusedLamb if you can. The reason for including this variant of Lamb is to have a version that is
-similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install APEX for whatever reason.
+Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
+similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
+
+In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.

Original copyrights for above sources are below.
+
+Modifications Copyright 2021 Ross Wightman
"""
# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
@@ -60,8 +64,7 @@ class Lamb(Optimizer):
    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups.
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
@@ -72,8 +75,7 @@ class Lamb(Optimizer):
            calculating running averages of gradient. (default: True)
        set_grad_none (bool, optional): whether set grad to None when zero_grad()
            method is called. (default: True)
-        max_grad_norm (float, optional): value used to clip global grad norm
-            (default: 1.0)
+        max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

@@ -91,25 +93,26 @@ class Lamb(Optimizer):
            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
        super().__init__(params, defaults)

+    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
-        device = self.param_groups[0]["params"][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

+        device = self.param_groups[0]["params"][0].device
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
                global_grad_norm.add_(grad.pow(2).sum())
@@ -145,15 +148,15 @@ class Lamb(Optimizer):
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data.div_(clip_global_grad_norm)
+                grad = p.grad.div_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of gradient valuesa
+                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

@@ -166,20 +169,21 @@ class Lamb(Optimizer):

                weight_decay = group['weight_decay']
                if weight_decay != 0:
-                    update.add_(p.data, alpha=weight_decay)
+                    update.add_(p, alpha=weight_decay)

-                trust_ratio = one_tensor
                if weight_decay != 0 or group['use_nvlamb']:
                    # Layer adaptation. By default, skip layer adaptation on parameters that are
                    # excluded from weight decay, unless use_nvlamb == True, then always enabled.
-                    w_norm = p.data.norm(2.0)
+                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
+                    # FIXME nested where required since logical and/or not working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    update.mul_(trust_ratio)
-                p.data.add_(update, alpha=-group['lr'])
+
+                p.add_(update, alpha=-group['lr'])

        return loss
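The FIXME added in the last Lamb hunk refers to the nested torch.where used in place of Python boolean logic, so the trust ratio stays a traced tensor op under PyTorch XLA. A standalone sketch of that computation (illustrative only):

import torch

def lamb_trust_ratio(param, update):
    one = torch.tensor(1.0, device=param.device)
    w_norm = param.norm(2.0)
    g_norm = update.norm(2.0)
    # nested where instead of `if w_norm > 0 and g_norm > 0`, which would
    # force a Python bool / host sync and break lazy XLA tracing
    return torch.where(
        w_norm > 0,
        torch.where(g_norm > 0, w_norm / g_norm, one),
        one,
    )

p = torch.randn(10, 10)
u = torch.randn(10, 10)
print(lamb_trust_ratio(p, u))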
timm/optim/lars.py (new file, 136 lines)
@@ -0,0 +1,136 @@
+""" PyTorch LARS / LARC Optimizer
+
+An implementation of LARS (SGD) + LARC in PyTorch
+
+Based on:
+  * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
+  * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
+
+Additional cleanup and modifications to properly support PyTorch XLA.
+
+Copyright 2021 Ross Wightman
+"""
+import torch
+from torch.optim.optimizer import Optimizer, required
+
+
+class Lars(Optimizer):
+    """ LARS for PyTorch
+
+    Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): dampening for momentum (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+        trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
+        eps (float): eps for division denominator (default: 1e-8)
+        larc (bool): enable LARC clipping (default: False)
+        always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False)
+    """
+
+    def __init__(
+            self,
+            params,
+            lr=required,
+            momentum=0,
+            dampening=0,
+            weight_decay=0,
+            nesterov=False,
+            trust_coeff=0.001,
+            eps=1e-8,
+            larc=False,
+            always_scale=False,
+    ):
+        if lr is not required and lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if momentum < 0.0:
+            raise ValueError(f"Invalid momentum value: {momentum}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            trust_coeff=trust_coeff,
+            eps=eps,
+            larc=larc,
+            always_scale=always_scale,
+        )
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("nesterov", False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        device = self.param_groups[0]["params"][0].device
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
+
+        # exclude scaling for params with 0 weight decay
+        for group in self.param_groups:
+            weight_decay = group['weight_decay']
+            momentum = group['momentum']
+            dampening = group['dampening']
+            nesterov = group['nesterov']
+            trust_coeff = group['trust_coeff']
+            eps = group['eps']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+
+                # apply LARS scaling, LARC clipping, weight decay
+                # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
+                if weight_decay != 0 or group['always_scale']:
+                    w_norm = p.norm(2.0)
+                    g_norm = grad.norm(2.0)
+                    trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
+                    # FIXME nested where required since logical and/or not working in PT XLA
+                    trust_ratio = torch.where(
+                        w_norm > 0,
+                        torch.where(g_norm > 0, trust_ratio, one_tensor),
+                        one_tensor,
+                    )
+                    if group['larc']:
+                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
+                    grad.add(p, alpha=weight_decay)
+                    grad.mul_(trust_ratio)
+
+                # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
+                if momentum != 0:
+                    param_state = self.state[p]
+                    if 'momentum_buffer' not in param_state:
+                        buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
+                    else:
+                        buf = param_state['momentum_buffer']
+                        buf.mul_(momentum).add_(grad, alpha=1. - dampening)
+                    if nesterov:
+                        grad = grad.add(buf, alpha=momentum)
+                    else:
+                        grad = buf
+
+                p.add_(grad, alpha=-group['lr'])
+
+        return loss
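A usage sketch for the new Lars class; the import path follows the file added above, while the toy model and hyper-parameters are illustrative assumptions:

import torch
import torch.nn as nn
from timm.optim.lars import Lars

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
# LARC-style clipping with Nesterov momentum, roughly what opt='nlarc' selects
optimizer = Lars(model.parameters(), lr=0.5, momentum=0.9, weight_decay=1e-5,
                 nesterov=True, larc=True)

x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()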
timm/optim/lookahead.py
@@ -27,22 +27,24 @@ class Lookahead(Optimizer):
        for group in self._base_optimizer.param_groups:
            group.setdefault(name, default)

+    @torch.no_grad()
    def update_slow(self, group):
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self._base_optimizer.state[fast_p]
            if 'lookahead_slow_buff' not in param_state:
-                param_state['lookahead_slow_buff'] = torch.empty_like(fast_p.data)
-                param_state['lookahead_slow_buff'].copy_(fast_p.data)
+                param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
+                param_state['lookahead_slow_buff'].copy_(fast_p)
            slow = param_state['lookahead_slow_buff']
-            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
-            fast_p.data.copy_(slow)
+            slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
+            fast_p.copy_(slow)

    def sync_lookahead(self):
        for group in self._base_optimizer.param_groups:
            self.update_slow(group)

+    @torch.no_grad()
    def step(self, closure=None):
        loss = self._base_optimizer.step(closure)
        for group in self._base_optimizer.param_groups:
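The update_slow() logic above is a plain interpolation of the slow weights toward the fast weights, followed by copying the result back. A tensor-level sketch of that rule with made-up values:

import torch

alpha = 0.5
fast = torch.tensor([1.0, 2.0, 3.0])
slow = torch.tensor([0.0, 0.0, 0.0])

# slow <- slow + alpha * (fast - slow), then fast is reset to slow
slow.add_(fast - slow, alpha=alpha)
fast.copy_(slow)
print(fast)  # tensor([0.5000, 1.0000, 1.5000])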
timm/optim/madgrad.py
@@ -82,6 +82,7 @@ class MADGRAD(torch.optim.Optimizer):
    def supports_flat_params(self) -> bool:
        return True

+    @torch.no_grad()
    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Performs a single optimization step.

@@ -91,13 +92,10 @@ class MADGRAD(torch.optim.Optimizer):
        """
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

-        # step counter must be stored in state to ensure correct behavior under
-        # optimizer sharding
-        if 'k' not in self.state:
-            self.state['k'] = torch.tensor([0], dtype=torch.long)
-        k = self.state['k'].item()
+        step = self.state.setdefault('step', 0)  # k

        for group in self.param_groups:
            eps = group["eps"]
@@ -106,19 +104,19 @@ class MADGRAD(torch.optim.Optimizer):
            momentum = group["momentum"]

            ck = 1 - momentum
-            lamb = lr * math.pow(k + 1, 0.5)
+            lamb = lr * math.sqrt(step + 1)

            for p in group["params"]:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad
                state = self.state[p]

                if "grad_sum_sq" not in state:
-                    state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
-                    state["s"] = torch.zeros_like(p.data).detach()
+                    state["grad_sum_sq"] = torch.zeros_like(p)
+                    state["s"] = torch.zeros_like(p)
                    if momentum != 0:
-                        state["x0"] = torch.clone(p.data).detach()
+                        state["x0"] = torch.clone(p).detach()

                if momentum != 0.0 and grad.is_sparse:
                    raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
@@ -129,11 +127,11 @@ class MADGRAD(torch.optim.Optimizer):
                # Apply weight decay
                if weight_decay != 0:
                    if group['decoupled_decay']:
-                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                        p.mul_(1.0 - group['lr'] * weight_decay)
                    else:
                        if grad.is_sparse:
                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-                        grad.add_(p.data, alpha=weight_decay)
+                        grad.add_(p, alpha=weight_decay)

                if grad.is_sparse:
                    grad = grad.coalesce()
@@ -161,12 +159,12 @@ class MADGRAD(torch.optim.Optimizer):
                    p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                    # Copy updated masked p to dense p using an add operation
                    p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
-                    p.data.add_(p_masked, alpha=-1)
+                    p.add_(p_masked, alpha=-1)
                else:
                    if momentum == 0:
                        # Compute x_0 from other known quantities
                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
-                        x0 = p.data.addcdiv(s, rms, value=1)
+                        x0 = p.addcdiv(s, rms, value=1)
                    else:
                        x0 = state["x0"]

@@ -175,16 +173,16 @@ class MADGRAD(torch.optim.Optimizer):
                    rms = grad_sum_sq.pow(1 / 3).add_(eps)

                # Update s
-                s.data.add_(grad, alpha=lamb)
+                s.add_(grad, alpha=lamb)

                # Step
                if momentum == 0:
-                    p.data.copy_(x0.addcdiv(s, rms, value=-1))
+                    p.copy_(x0.addcdiv(s, rms, value=-1))
                else:
                    z = x0.addcdiv(s, rms, value=-1)

                    # p is a moving average of z
-                    p.data.mul_(1 - ck).add_(z, alpha=ck)
+                    p.mul_(1 - ck).add_(z, alpha=ck)

-        self.state['k'] += 1
+        self.state['step'] += 1
        return loss
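The step-counter change above stores a plain int under a string key of the optimizer state instead of a long tensor named 'k'. A tiny sketch of why setdefault plus an increment behaves as intended (illustrative, not the MADGRAD code):

from collections import defaultdict

state = defaultdict(dict)  # torch.optim.Optimizer.state is a defaultdict like this

for _ in range(3):
    step = state.setdefault('step', 0)   # 0 on the first call, then the stored value
    lamb = 0.1 * (step + 1) ** 0.5       # lr * sqrt(step + 1), as in the hunk above
    state['step'] += 1

print(state['step'])  # 3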
timm/optim/nadam.py
@@ -1,3 +1,5 @@
+import math
+
import torch
from torch.optim.optimizer import Optimizer

@@ -33,6 +35,7 @@ class Nadam(Optimizer):
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
        super(Nadam, self).__init__(params, defaults)

+    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

@@ -42,21 +45,22 @@ class Nadam(Optimizer):
        """
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m_schedule'] = 1.
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)

                # Warming momentum schedule
                m_schedule = state['m_schedule']
@@ -66,9 +70,10 @@ class Nadam(Optimizer):
                eps = group['eps']
                state['step'] += 1
                t = state['step']
+                bias_correction2 = 1 - beta2 ** t

                if group['weight_decay'] != 0:
-                    grad = grad.add(p.data, alpha=group['weight_decay'])
+                    grad = grad.add(p, alpha=group['weight_decay'])

                momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
@@ -79,10 +84,9 @@ class Nadam(Optimizer):
                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
-                exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
-                denom = exp_avg_sq_prime.sqrt_().add_(eps)

-                p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
-                p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
+                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+                p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
+                p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))

        return loss
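The Nadam change replaces the explicitly bias-corrected second moment with an equivalent denominator, since sqrt(v / (1 - beta2**t)) equals sqrt(v) / sqrt(1 - beta2**t). A quick numerical check of that equivalence (illustrative):

import math
import torch

beta2, t, eps = 0.999, 10, 1e-8
exp_avg_sq = torch.rand(5)

old = (exp_avg_sq / (1. - beta2 ** t)).sqrt_().add_(eps)
new = (exp_avg_sq.sqrt() / math.sqrt(1. - beta2 ** t)).add_(eps)
print(torch.allclose(old, new))  # True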
timm/optim/nvnovograd.py
@@ -51,6 +51,7 @@ class NvNovoGrad(Optimizer):
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

+    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

@@ -60,13 +61,14 @@ class NvNovoGrad(Optimizer):
        """
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Sparse gradients are not supported.')
                amsgrad = group['amsgrad']
@@ -77,7 +79,7 @@ class NvNovoGrad(Optimizer):
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                    if amsgrad:
@@ -108,11 +110,11 @@ class NvNovoGrad(Optimizer):

                grad.div_(denom)
                if group['weight_decay'] != 0:
-                    grad.add_(p.data, alpha=group['weight_decay'])
+                    grad.add_(p, alpha=group['weight_decay'])
                if group['grad_averaging']:
                    grad.mul_(1 - beta1)
                exp_avg.mul_(beta1).add_(grad)

-                p.data.add_(exp_avg, alpha=-group['lr'])
+                p.add_(exp_avg, alpha=-group['lr'])

        return loss
timm/optim/optim_factory.py
@@ -12,6 +12,7 @@ from .adafactor import Adafactor
from .adahessian import Adahessian
from .adamp import AdamP
from .lamb import Lamb
+from .lars import Lars
from .lookahead import Lookahead
from .madgrad import MADGRAD
from .nadam import Nadam
@@ -163,6 +164,14 @@ def create_optimizer_v2(
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
+    elif opt_lower == 'larc':
+        optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args)
+    elif opt_lower == 'lars':
+        optimizer = Lars(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'nlarc':
+        optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args)
+    elif opt_lower == 'nlars':
+        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
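With the factory entries above, the new optimizers are selectable by name. A usage sketch that mirrors the call pattern exercised in tests/test_optim.py; the exact create_optimizer_v2 signature and import path are assumptions based on the surrounding code:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(16, 4)
# 'lars', 'larc', 'nlars', 'nlarc' all map to Lars with larc/nesterov toggled
opt = create_optimizer_v2(list(model.parameters()), 'nlarc', lr=0.5, weight_decay=1e-5)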
timm/optim/radam.py
@@ -18,31 +18,33 @@ class RAdam(Optimizer):
    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

+    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data.float()
+                grad = p.grad.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

-                p_data_fp32 = p.data.float()
+                p_fp32 = p.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg'] = torch.zeros_like(p_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_fp32)
                else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+                    state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
@@ -73,15 +75,15 @@ class RAdam(Optimizer):
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
-                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if num_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
-                    p_data_fp32.add_(exp_avg, alpha=-step_size)
+                    p_fp32.add_(exp_avg, alpha=-step_size)

-                p.data.copy_(p_data_fp32)
+                p.copy_(p_fp32)

        return loss
timm/optim/rmsprop_tf.py
@@ -4,7 +4,7 @@ Originally cut & paste from PyTorch RMSProp
https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE

-Modifications Copyright 2020 Ross Wightman
+Modifications Copyright 2021 Ross Wightman
"""

import torch
@@ -69,6 +69,7 @@ class RMSpropTF(Optimizer):
            group.setdefault('momentum', 0)
            group.setdefault('centered', False)

+    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

@@ -78,13 +79,14 @@ class RMSpropTF(Optimizer):
        """
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')
                state = self.state[p]
@@ -92,11 +94,11 @@ class RMSpropTF(Optimizer):
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
-                    state['square_avg'] = torch.ones_like(p.data)  # PyTorch inits to zero
+                    state['square_avg'] = torch.ones_like(p)  # PyTorch inits to zero
                    if group['momentum'] > 0:
-                        state['momentum_buffer'] = torch.zeros_like(p.data)
+                        state['momentum_buffer'] = torch.zeros_like(p)
                    if group['centered']:
-                        state['grad_avg'] = torch.zeros_like(p.data)
+                        state['grad_avg'] = torch.zeros_like(p)

                square_avg = state['square_avg']
                one_minus_alpha = 1. - group['alpha']
@@ -105,9 +107,9 @@ class RMSpropTF(Optimizer):

                if group['weight_decay'] != 0:
                    if group['decoupled_decay']:
-                        p.data.mul_(1. - group['lr'] * group['weight_decay'])
+                        p.mul_(1. - group['lr'] * group['weight_decay'])
                    else:
-                        grad = grad.add(p.data, alpha=group['weight_decay'])
+                        grad = grad.add(p, alpha=group['weight_decay'])

                # Tensorflow order of ops for updating squared avg
                square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
@@ -126,12 +128,12 @@ class RMSpropTF(Optimizer):
                    # Tensorflow accumulates the LR scaling in the momentum buffer
                    if group['lr_in_momentum']:
                        buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
-                        p.data.add_(-buf)
+                        p.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR
                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
-                        p.data.add_(buf, alpha=-group['lr'])
+                        p.add_(buf, alpha=-group['lr'])
                else:
-                    p.data.addcdiv_(grad, avg, value=-group['lr'])
+                    p.addcdiv_(grad, avg, value=-group['lr'])

        return loss
timm/optim/sgdp.py
@@ -24,10 +24,12 @@ class SGDP(Optimizer):
            nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
        super(SGDP, self).__init__(params, defaults)

+    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
@@ -38,12 +40,12 @@ class SGDP(Optimizer):
            for p in group['params']:
                if p.grad is None:
                    continue
-                grad = p.grad.data
+                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
-                    state['momentum'] = torch.zeros_like(p.data)
+                    state['momentum'] = torch.zeros_like(p)

                # SGD
                buf = state['momentum']
@@ -60,9 +62,9 @@ class SGDP(Optimizer):

                # Weight decay
                if weight_decay != 0:
-                    p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
+                    p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))

                # Step
-                p.data.add_(d_p, alpha=-group['lr'])
+                p.add_(d_p, alpha=-group['lr'])

        return loss