Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 55fb5eedf6: Remove experiment from lamb impl
Parent: 8a9eca5157
@@ -463,7 +463,7 @@ def test_adafactor(optimizer):
     _test_model(optimizer, dict(lr=5e-2))


-@pytest.mark.parametrize('optimizer', ['lamb', 'lambw'])
+@pytest.mark.parametrize('optimizer', ['lamb'])
 def test_lamb(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)

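The 'lambw' test case goes away together with the experimental alias. As a minimal sketch of exercising the remaining 'lamb' registration outside the test harness, mirroring the create_optimizer_v2 call the test above uses (the import path and tensor shapes are assumptions, not shown in this diff):

import torch
from timm.optim import create_optimizer_v2

# Arbitrary toy parameters, shaped like the weight/bias pair the test helper builds.
weight = torch.randn(10, 5, requires_grad=True)
bias = torch.randn(10, requires_grad=True)

# Only 'lamb' remains registered after this commit; the experimental 'lambw'
# string would no longer resolve to an optimizer.
optimizer = create_optimizer_v2([weight, bias], 'lamb', lr=1e-3)

loss = (weight.mv(torch.randn(5)) + bias).pow(2).sum()
loss.backward()
optimizer.step()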
@@ -18,7 +18,7 @@ class AdaBelief(Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False)
-        decoupled_decay (boolean, optional): ( default: True) If set as True, then
+        decoupled_decay (boolean, optional): (default: True) If set as True, then
             the optimizer uses decoupled weight decay as in AdamW
         fixed_decay (boolean, optional): (default: False) This is used when weight_decouple
             is set as True.

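For context on the two flags this docstring describes, a hedged usage sketch; it assumes AdaBelief is exported from timm.optim, and the parameter shape is arbitrary:

import torch
from timm.optim import AdaBelief

param = torch.nn.Parameter(torch.randn(4, 4))
# decoupled_decay=True applies AdamW-style weight decay directly to the
# weights; fixed_decay only changes behaviour when decoupled_decay is on.
opt = AdaBelief([param], lr=1e-3, weight_decay=1e-2, decoupled_decay=True, fixed_decay=False)

param.sum().backward()
opt.step()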
@@ -194,7 +194,7 @@ class AdaBelief(Optimizer):
                     denom = exp_avg_var.sqrt().add_(group['eps'])
                     p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                 elif step_size > 0:
-                    p.data.add_( exp_avg, alpha=-step_size * group['lr'])
+                    p.data.add_(exp_avg, alpha=-step_size * group['lr'])

                 if half_precision:
                     p.data = p.data.half()

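The only change here is dropping a stray space. As a quick sanity check of the in-place call involved: add_(t, alpha=a) computes p += a * t, so the line is equivalent to p.data -= step_size * lr * exp_avg. The values below are illustrative only:

import torch

p = torch.ones(3)
exp_avg = torch.full((3,), 2.0)
step_size, lr = 0.5, 0.1

p.add_(exp_avg, alpha=-step_size * lr)  # in place: p += (-0.05) * exp_avg
print(p)  # tensor([0.9000, 0.9000, 0.9000])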
@@ -84,13 +84,11 @@ class Lamb(Optimizer):
     """

     def __init__(
-            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
-            grad_averaging=True, max_grad_norm=1.0, decoupled_decay=False, use_nvlamb=False):
+            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
+            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False):
         defaults = dict(
-            lr=lr, bias_correction=bias_correction,
-            betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
-            decoupled_decay=decoupled_decay, use_nvlamb=use_nvlamb)
+            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
+            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
         super().__init__(params, defaults)

     def step(self, closure=None):

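With decoupled_decay gone from the signature, constructing the optimizer directly looks like this. A sketch only: it assumes Lamb is exported from timm.optim, and the parameter shape is arbitrary; the keyword arguments are the ones kept in the new signature above.

import torch
from timm.optim import Lamb

param = torch.nn.Parameter(torch.randn(8, 8))
# decoupled_decay is no longer accepted here; weight_decay is always folded
# into the update before the trust ratio is applied (see the later hunks).
opt = Lamb([param], lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01, max_grad_norm=1.0)

param.pow(2).sum().backward()
opt.step()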
@@ -136,8 +134,6 @@ class Lamb(Optimizer):
             else:
                 group['step'] = 1

-            step_size = group['lr']
-
             if bias_correction:
                 bias_correction1 = 1 - beta1 ** group['step']
                 bias_correction2 = 1 - beta2 ** group['step']

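The removed step_size local was simply a copy of group['lr']; the final update now reads the learning rate directly (see the last lamb hunk below). The bias-correction terms kept as context rescale the running moments early in training, as a small numeric illustration shows:

import math

beta1, beta2, step = 0.9, 0.999, 1
bias_correction1 = 1 - beta1 ** step   # 0.1 at step 1, tends to 1.0 as step grows
bias_correction2 = 1 - beta2 ** step   # 0.001 at step 1, tends to 1.0 as step grows
# The update divides the first moment by bias_correction1 and the
# second-moment denominator by sqrt(bias_correction2).
print(bias_correction1, math.sqrt(bias_correction2))  # 0.1 0.0316...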
@@ -157,11 +153,6 @@ class Lamb(Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)

-                decoupled_decay = group['decoupled_decay']
-                weight_decay = group['weight_decay']
-                if decoupled_decay and weight_decay != 0:
-                    p.data.mul_(1. - group['lr'] * weight_decay)
-
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                 # Decay the first and second moment running average coefficient

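The deleted block was the decoupled (AdamW-style) path: it shrank the weights directly before the adaptive update. What remains (next hunk) folds weight decay into the update, so it is also scaled by the layer-wise trust ratio. A standalone sketch of the two styles on a plain tensor; the variable names are illustrative, not the library's, and the trust ratio guards are omitted:

import torch

p = torch.randn(4)
adaptive_update = torch.randn(4)   # stands in for exp_avg / denom in the code above
lr, weight_decay = 1e-3, 0.01

# Folded-in decay (what Lamb keeps): weight decay joins the update, so it is
# rescaled by the trust ratio together with the gradient term.
update = adaptive_update + weight_decay * p
trust_ratio = p.norm(2.0) / update.norm(2.0)
p_folded = p - lr * trust_ratio * update

# Decoupled decay (the removed experiment): weights shrink directly and only
# the gradient term goes through the trust ratio.
update = adaptive_update
trust_ratio = p.norm(2.0) / update.norm(2.0)
p_decoupled = p * (1. - lr * weight_decay) - lr * trust_ratio * update

print(torch.allclose(p_folded, p_decoupled))  # False in general: the decay is scaled differently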
@@ -171,7 +162,8 @@ class Lamb(Optimizer):
                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 update = (exp_avg / bias_correction1).div_(denom)

-                if not decoupled_decay and weight_decay != 0:
+                weight_decay = group['weight_decay']
+                if weight_decay != 0:
                     update.add_(p.data, alpha=weight_decay)

                 trust_ratio = one_tensor

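The trust_ratio initialised to one_tensor here is the layer-wise ratio ||w|| / ||update||, with 1.0 as the fallback when either norm is zero (that is what the one_tensor operands in the next hunk's context are for). A simplified standalone version of that guard, not the library code itself:

import torch

def layer_trust_ratio(w: torch.Tensor, update: torch.Tensor) -> torch.Tensor:
    # Ratio of parameter norm to update norm, falling back to 1.0 whenever
    # either norm is zero, in the same shape as the guarded ratio above.
    one = torch.tensor(1.0, device=w.device)
    w_norm = w.norm(2.0)
    u_norm = update.norm(2.0)
    return torch.where(w_norm > 0, torch.where(u_norm > 0, w_norm / u_norm, one), one)

print(layer_trust_ratio(torch.ones(4), torch.full((4,), 0.5)))  # tensor(2.)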
@@ -186,6 +178,6 @@ class Lamb(Optimizer):
                     one_tensor,
                 )
                 update.mul_(trust_ratio)
-                p.data.add_(update, alpha=-step_size)
+                p.data.add_(update, alpha=-group['lr'])

         return loss

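Putting the lamb hunks together, a condensed single-tensor sketch of the update path after this commit; this is not the library code, and bias correction, grad averaging, and gradient clipping are omitted for brevity:

import torch

def lamb_like_step(p: torch.Tensor, grad: torch.Tensor, state: dict,
                   lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01) -> None:
    beta1, beta2 = betas
    exp_avg = state.setdefault('exp_avg', torch.zeros_like(p))
    exp_avg_sq = state.setdefault('exp_avg_sq', torch.zeros_like(p))

    # Running first and second moments (bias correction omitted in this sketch).
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    update = exp_avg / exp_avg_sq.sqrt().add_(eps)
    if weight_decay != 0:
        update.add_(p, alpha=weight_decay)       # decay folded into the update

    one = torch.tensor(1.0, device=p.device)
    w_norm, u_norm = p.norm(2.0), update.norm(2.0)
    trust_ratio = torch.where(w_norm > 0, torch.where(u_norm > 0, w_norm / u_norm, one), one)

    p.add_(update.mul_(trust_ratio), alpha=-lr)  # matches the final line of the hunk above

state = {}
w = torch.randn(16)
lamb_like_step(w, torch.randn(16), state)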
@@ -163,8 +163,6 @@ def create_optimizer_v2(
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'lamb':
         optimizer = Lamb(parameters, **opt_args)
-    elif opt_lower == 'lambw':
-        optimizer = Lamb(parameters, decoupled_decay=True, **opt_args)  # FIXME experimental
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
     elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':

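With the factory branch removed, requesting the optimizer by name goes through the one remaining 'lamb' entry. A usage sketch, assuming the factory accepts an nn.Module as its first argument as in current timm releases (the model here is an arbitrary stand-in):

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(16, 4)   # any nn.Module stands in for a timm model here
# 'lamb' is the only name left for this optimizer; the experimental 'lambw'
# entry removed above would no longer be recognized by the factory.
opt = create_optimizer_v2(model, 'lamb', lr=5e-3, weight_decay=0.01)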