Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
add madgradw optimizer
This commit is contained in:
commit a6af48be64
parent 55fb5eedf6
@@ -490,7 +490,7 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))


-@pytest.mark.parametrize('optimizer', ['madgrad'])
+@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
@@ -53,7 +53,13 @@ class MADGRAD(torch.optim.Optimizer):
     """

     def __init__(
-            self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
+            self,
+            params: _params_t,
+            lr: float = 1e-2,
+            momentum: float = 0.9,
+            weight_decay: float = 0,
+            eps: float = 1e-6,
+            decoupled_decay: bool = False,
     ):
         if momentum < 0 or momentum >= 1:
             raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
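Not part of the commit, but for context: a minimal sketch of constructing the optimizer directly with the new flag. The import path is an assumption based on the hunk context (the MADGRAD class in timm's optim package); the toy model and hyperparameters are hypothetical.

import torch
from timm.optim.madgrad import MADGRAD  # assumed module path, not shown in the diff

# toy parameters to optimize
model = torch.nn.Linear(16, 4)
# decoupled_decay=True enables the new AdamW-style weight decay path
opt = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=1e-4, decoupled_decay=True)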
@@ -64,7 +70,8 @@ class MADGRAD(torch.optim.Optimizer):
         if eps < 0:
             raise ValueError(f"Eps must be non-negative")

-        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
         super().__init__(params, defaults)

     @property
@@ -95,7 +102,7 @@ class MADGRAD(torch.optim.Optimizer):
         for group in self.param_groups:
             eps = group["eps"]
             lr = group["lr"] + eps
-            decay = group["weight_decay"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]

             ck = 1 - momentum
@@ -120,11 +127,13 @@ class MADGRAD(torch.optim.Optimizer):
                     s = state["s"]

                 # Apply weight decay
-                if decay != 0:
-                    if grad.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-
-                    grad.add_(p.data, alpha=decay)
+                if weight_decay != 0:
+                    if group['decoupled_decay']:
+                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                    else:
+                        if grad.is_sparse:
+                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+                        grad.add_(p.data, alpha=weight_decay)

                 if grad.is_sparse:
                     grad = grad.coalesce()
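This hunk is the behavioral core of the change: when decoupled_decay is set, the decay multiplicatively shrinks the weights (AdamW-style) instead of being folded into the gradient as an L2 term. A standalone sketch of the two branches, using hypothetical tensors rather than the optimizer's real state:

import torch

p = torch.randn(8)          # stand-in for a parameter tensor (p.data in the optimizer)
grad = torch.randn(8)       # stand-in for its dense gradient
lr, weight_decay = 1e-2, 1e-4

decoupled_decay = True      # True mirrors 'madgradw'; False mirrors the previous behavior
if weight_decay != 0:
    if decoupled_decay:
        # decoupled: shrink the weights directly, leave the gradient untouched
        p.mul_(1.0 - lr * weight_decay)
    else:
        # coupled (old path): add the L2 penalty into the gradient
        grad.add_(p, alpha=weight_decay)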
@@ -165,6 +165,8 @@ def create_optimizer_v2(
         optimizer = Lamb(parameters, **opt_args)
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'madgradw':
+        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
     elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
     elif opt_lower == 'rmsprop':
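A hedged usage sketch of the new 'madgradw' factory string (assumes the create_optimizer_v2 entry point seen in the test hunk above; the toy model and hyperparameters are illustrative only):

import torch
from timm.optim import create_optimizer_v2  # assumed public import; the tests call this function directly

model = torch.nn.Linear(16, 4)
# 'madgradw' routes to MADGRAD(..., decoupled_decay=True) per the factory branch above
opt = create_optimizer_v2(list(model.parameters()), 'madgradw', lr=1e-3, weight_decay=1e-2)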