add madgradw optimizer

parent 55fb5eedf6
commit a6af48be64
@@ -490,7 +490,7 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))


-@pytest.mark.parametrize('optimizer', ['madgrad'])
+@pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
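Note: the parametrized test above now covers both optimizer strings, and the create_optimizer_v2 call it wraps can be exercised directly. A minimal sketch along the same lines, assuming timm with this commit and PyTorch are installed (the tensor shapes and the single dummy step are illustrative, not taken from the test suite):

import torch
from timm.optim import create_optimizer_v2

# Toy parameters standing in for the test's (weight, bias) pair.
weight = torch.randn(10, 5, requires_grad=True)
bias = torch.randn(10, requires_grad=True)

# Same call shape as in the test: parameter list, optimizer name, learning rate.
optimizer = create_optimizer_v2([weight, bias], 'madgradw', lr=1e-3)

# One throwaway step to confirm the optimizer object is usable.
loss = (weight.sum() + bias.sum()) ** 2
loss.backward()
optimizer.step()
optimizer.zero_grad()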
@@ -53,7 +53,13 @@ class MADGRAD(torch.optim.Optimizer):
     """

     def __init__(
-            self, params: _params_t, lr: float = 1e-2, momentum: float = 0.9, weight_decay: float = 0, eps: float = 1e-6,
+            self,
+            params: _params_t,
+            lr: float = 1e-2,
+            momentum: float = 0.9,
+            weight_decay: float = 0,
+            eps: float = 1e-6,
+            decoupled_decay: bool = False,
     ):
         if momentum < 0 or momentum >= 1:
             raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
@@ -64,7 +70,8 @@ class MADGRAD(torch.optim.Optimizer):
         if eps < 0:
             raise ValueError(f"Eps must be non-negative")

-        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
         super().__init__(params, defaults)

     @property
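For orientation, the new decoupled_decay argument travels through the standard torch.optim plumbing: it is stored in the defaults dict here and read back per parameter group, which is how the step() hunk further down consumes group['decoupled_decay']. A stripped-down sketch of that pattern (an illustrative toy optimizer, not the real MADGRAD update):

import torch


class TinyOpt(torch.optim.Optimizer):
    """Illustration only: a constructor flag surfaces in every param group."""

    def __init__(self, params, lr=1e-2, decoupled_decay=False):
        defaults = dict(lr=lr, decoupled_decay=decoupled_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            # The value passed to __init__ is available here per group,
            # mirroring group['decoupled_decay'] in the MADGRAD update loop.
            assert 'decoupled_decay' in group
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])


params = [torch.zeros(3, requires_grad=True)]
opt = TinyOpt(params, decoupled_decay=True)
print(opt.param_groups[0]['decoupled_decay'])  # True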
@@ -95,7 +102,7 @@ class MADGRAD(torch.optim.Optimizer):
         for group in self.param_groups:
             eps = group["eps"]
             lr = group["lr"] + eps
-            decay = group["weight_decay"]
+            weight_decay = group["weight_decay"]
             momentum = group["momentum"]

             ck = 1 - momentum
@@ -120,11 +127,13 @@ class MADGRAD(torch.optim.Optimizer):
                 s = state["s"]

                 # Apply weight decay
-                if decay != 0:
-                    if grad.is_sparse:
-                        raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-
-                    grad.add_(p.data, alpha=decay)
+                if weight_decay != 0:
+                    if group['decoupled_decay']:
+                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                    else:
+                        if grad.is_sparse:
+                            raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+                        grad.add_(p.data, alpha=weight_decay)

                 if grad.is_sparse:
                     grad = grad.coalesce()
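The practical difference between the two branches can be seen on a bare tensor. This is a sketch of generic L2 vs. decoupled weight decay (as in AdamW), not of MADGRAD's full dual-averaging update:

import torch

lr, weight_decay = 1e-2, 1e-1
p = torch.full((3,), 2.0, requires_grad=True)
p.grad = torch.ones(3)

# Coupled (L2) decay, as in the 'else' branch: the decay term is folded into
# the gradient and then flows through the optimizer's update rule.
grad_l2 = p.grad + weight_decay * p.detach()

# Decoupled decay, as in the 'decoupled_decay' branch: the parameter is shrunk
# directly by lr * weight_decay, independent of the gradient-based update.
p_decoupled = p.detach() * (1.0 - lr * weight_decay)

print(grad_l2)      # tensor([1.2000, 1.2000, 1.2000])
print(p_decoupled)  # tensor([1.9980, 1.9980, 1.9980])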
@@ -165,6 +165,8 @@ def create_optimizer_v2(
         optimizer = Lamb(parameters, **opt_args)
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'madgradw':
+        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
     elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
     elif opt_lower == 'rmsprop':
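With both factory entries in place, switching between the coupled and decoupled variants is just a change of the opt string. A usage sketch, assuming a timm build that includes this commit (the model and hyperparameter values are placeholders):

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)  # placeholder model

# 'madgrad'  adds weight decay to the gradient (L2 regularization);
# 'madgradw' applies decoupled decay directly to the weights.
optimizer = create_optimizer_v2(model, 'madgradw', lr=1e-3, weight_decay=1e-2)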