Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit a426511c95 (parent 9541f4963b)

More optimizer cleanup. Change all optimizers to no longer use .data. Improve (b)float16 use with AdaBelief. Add XLA-compatible LARS.
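
For illustration, a minimal sketch (not timm code) of the update pattern these diffs move the optimizers to: with step() wrapped in @torch.no_grad(), parameters and gradients are used directly (p, p.grad) instead of through the legacy .data attribute. The SketchSGD class and its plain SGD update are assumptions for the example, not part of the commit.

import torch
from torch.optim.optimizer import Optimizer


class SketchSGD(Optimizer):
    """Illustrative optimizer showing the .data-free update style used in the diffs below."""

    def __init__(self, params, lr=1e-3):
        super().__init__(params, dict(lr=lr))

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad                     # formerly p.grad.data
                p.add_(grad, alpha=-group['lr'])  # formerly p.data.add_(...)
        return loss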

tests/test_optim.py
@@ -490,6 +490,33 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))
 
 
+@pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
+def test_lars(optimizer):
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict(weight, bias, lr=1e-3),
+            optimizer,
+            lr=1e-1)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict_single(weight, bias, lr=1e-3),
+            optimizer,
+            lr=1e-3)
+    )
+    _test_basic_cases(
+        lambda weight, bias: create_optimizer_v2(
+            _build_params_dict_single(weight, bias, lr=1e-3), optimizer)
+    )
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=1e-3))
+
+
 @pytest.mark.parametrize('optimizer', ['madgrad', 'madgradw'])
 def test_madgrad(optimizer):
     _test_basic_cases(

timm/optim/adabelief.py
@@ -68,6 +68,7 @@ class AdaBelief(Optimizer):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
+    @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
@@ -77,14 +78,15 @@ class AdaBelief(Optimizer):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p)
 
                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data)
+                state['exp_avg_var'] = torch.zeros_like(p)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
         Arguments:
@@ -93,50 +95,47 @@ class AdaBelief(Optimizer):
         """
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
 
-                # cast data type
-                half_precision = False
-                if p.data.dtype == torch.float16:
-                    half_precision = True
-                    p.data = p.data.float()
-                    p.grad = p.grad.float()
-
-                grad = p.grad.data
+                grad = p.grad
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
                 if grad.is_sparse:
                     raise RuntimeError(
                         'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
 
+                p_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_fp32 = p_fp32.float()
+
                 amsgrad = group['amsgrad']
 
-                state = self.state[p]
-
                 beta1, beta2 = group['betas']
 
+                state = self.state[p]
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p_fp32)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p.data)
+                    state['exp_avg_var'] = torch.zeros_like(p_fp32)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p_fp32)
 
                 # perform weight decay, check if decoupled weight decay
                 if group['decoupled_decay']:
                     if not group['fixed_decay']:
-                        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
                     else:
-                        p.data.mul_(1.0 - group['weight_decay'])
+                        p_fp32.mul_(1.0 - group['weight_decay'])
                 else:
                     if group['weight_decay'] != 0:
-                        grad.add_(p.data, alpha=group['weight_decay'])
+                        grad.add_(p_fp32, alpha=group['weight_decay'])
 
                 # get current state variable
                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
@@ -164,7 +163,7 @@ class AdaBelief(Optimizer):
                 if not group['rectify']:
                     # Default update
                     step_size = group['lr'] / bias_correction1
-                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                 else:
                     # Rectified update, forked from RAdam
                     buffered = group['buffer'][int(state['step'] % 10)]
@@ -192,12 +191,11 @@ class AdaBelief(Optimizer):
 
                     if num_sma >= 5:
                         denom = exp_avg_var.sqrt().add_(group['eps'])
-                        p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                        p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                     elif step_size > 0:
-                        p.data.add_(exp_avg, alpha=-step_size * group['lr'])
+                        p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
 
-                if half_precision:
-                    p.data = p.data.half()
-                    p.grad = p.grad.half()
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)
 
         return loss
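
For illustration, a standalone sketch (not timm code) of the (b)float16 handling added to AdaBelief above: the update math runs on float32 copies of a low-precision parameter and its gradient, and the result is copied back into the parameter. The tensors and the plain SGD-style update are assumptions for the example.

import torch

p = torch.randn(4, dtype=torch.bfloat16, requires_grad=True)
p.grad = torch.randn(4, dtype=torch.bfloat16)

with torch.no_grad():
    grad = p.grad.float()           # upcast the gradient for the update math
    p_fp32 = p.float()              # fp32 working copy of the parameter
    p_fp32.add_(grad, alpha=-1e-3)  # any fp32 update would go here
    p.copy_(p_fp32)                 # write the result back to the bfloat16 param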

timm/optim/adafactor.py
@@ -76,6 +76,7 @@ class Adafactor(torch.optim.Optimizer):
         c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
         return torch.mul(r_factor, c_factor)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
         Arguments:
@@ -83,22 +84,22 @@ class Adafactor(torch.optim.Optimizer):
         """
         loss = None
         if closure is not None:
             with torch.enable_grad():
                 loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
                 if grad.is_sparse:
                     raise RuntimeError('Adafactor does not support sparse gradients.')
 
                 state = self.state[p]
-                grad_shape = grad.shape
 
-                factored, use_first_moment = self._get_options(group, grad_shape)
+                factored, use_first_moment = self._get_options(group, grad.shape)
                 # State Initialization
                 if len(state) == 0:
                     state['step'] = 0
@@ -107,8 +108,8 @@ class Adafactor(torch.optim.Optimizer):
                         # Exponential moving average of gradient values
                         state['exp_avg'] = torch.zeros_like(grad)
                     if factored:
-                        state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
-                        state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+                        state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
+                        state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
                     else:
                         state['exp_avg_sq'] = torch.zeros_like(grad)
 
@@ -122,12 +123,12 @@ class Adafactor(torch.optim.Optimizer):
                     else:
                         state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
 
-                p_data_fp32 = p.data
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
-                    p_data_fp32 = p_data_fp32.float()
+                p_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_fp32 = p_fp32.float()
 
                 state['step'] += 1
-                state['RMS'] = self._rms(p_data_fp32)
+                state['RMS'] = self._rms(p_fp32)
                 lr_t = self._get_lr(group, state)
 
                 beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
@@ -157,11 +158,10 @@ class Adafactor(torch.optim.Optimizer):
                     update = exp_avg
 
                 if group['weight_decay'] != 0:
-                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t)
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
 
-                p_data_fp32.add_(-update)
-
-                if p.data.dtype in {torch.float16, torch.bfloat16}:
-                    p.data.copy_(p_data_fp32)
+                p_fp32.add_(-update)
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_fp32)
 
         return loss

timm/optim/adamp.py
@@ -26,12 +26,13 @@ def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
     wd = 1.
     expand_size = (-1,) + (1,) * (len(p.shape) - 1)
     for view_func in [_channel_view, _layer_view]:
-        param_view = view_func(p.data)
+        param_view = view_func(p)
         grad_view = view_func(grad)
         cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
 
+        # FIXME this is a problem for PyTorch XLA
        if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
-            p_n = p.data / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
+            p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
             perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
             wd = wd_ratio
             return perturb, wd
@@ -47,9 +48,11 @@ class AdamP(Optimizer):
             delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
         super(AdamP, self).__init__(params, defaults)
 
+    @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
@@ -57,7 +60,7 @@ class AdamP(Optimizer):
                 if p.grad is None:
                     continue
 
-                grad = p.grad.data
+                grad = p.grad
                 beta1, beta2 = group['betas']
                 nesterov = group['nesterov']
 
@@ -66,8 +69,8 @@ class AdamP(Optimizer):
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
 
                 # Adam
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
@@ -94,9 +97,9 @@ class AdamP(Optimizer):
 
                 # Weight decay
                 if group['weight_decay'] > 0:
-                    p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
+                    p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
 
                 # Step
-                p.data.add_(perturb, alpha=-step_size)
+                p.add_(perturb, alpha=-step_size)
 
         return loss

timm/optim/adamw.py
@@ -55,6 +55,7 @@ class AdamW(Optimizer):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -64,6 +65,7 @@ class AdamW(Optimizer):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
@@ -75,7 +77,7 @@ class AdamW(Optimizer):
                 p.data.mul_(1 - group['lr'] * group['weight_decay'])
 
                 # Perform optimization step
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                 amsgrad = group['amsgrad']
@@ -86,12 +88,12 @@ class AdamW(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_sq'] = torch.zeros_like(p)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 if amsgrad:
@@ -115,6 +117,6 @@ class AdamW(Optimizer):
 
                 step_size = group['lr'] / bias_correction1
 
-                p.data.addcdiv_(exp_avg, denom, value=-step_size)
+                p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss

timm/optim/lamb.py
@@ -5,10 +5,14 @@ This optimizer code was adapted from the following (starting with latest)
 * https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
 * https://github.com/cybertronai/pytorch-lamb
 
-Use FusedLamb if you can. The reason for including this variant of Lamb is to have a version that is
-similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install APEX for whatever reason.
+Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
+similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
+
+In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
 
 Original copyrights for above sources are below.
 
 Modifications Copyright 2021 Ross Wightman
 """
 # Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
@@ -60,8 +64,7 @@ class Lamb(Optimizer):
     LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
 
     Arguments:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups.
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
         lr (float, optional): learning rate. (default: 1e-3)
         betas (Tuple[float, float], optional): coefficients used for computing
             running averages of gradient and its norm. (default: (0.9, 0.999))
@@ -72,8 +75,7 @@ class Lamb(Optimizer):
             calculating running averages of gradient. (default: True)
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)
-        max_grad_norm (float, optional): value used to clip global grad norm
-            (default: 1.0)
+        max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
         use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)
 
@@ -91,25 +93,26 @@ class Lamb(Optimizer):
             grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, use_nvlamb=use_nvlamb)
         super().__init__(params, defaults)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
         Arguments:
             closure (callable, optional): A closure that reevaluates the model
                 and returns the loss.
         """
-        device = self.param_groups[0]["params"][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
+        device = self.param_groups[0]["params"][0].device
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
         global_grad_norm = torch.zeros(1, device=device)
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
                 global_grad_norm.add_(grad.pow(2).sum())
@@ -145,15 +148,15 @@ class Lamb(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data.div_(clip_global_grad_norm)
+                grad = p.grad.div_(clip_global_grad_norm)
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of gradient valuesa
+                    state['exp_avg'] = torch.zeros_like(p)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
 
@@ -166,20 +169,21 @@ class Lamb(Optimizer):
 
                 weight_decay = group['weight_decay']
                 if weight_decay != 0:
-                    update.add_(p.data, alpha=weight_decay)
+                    update.add_(p, alpha=weight_decay)
 
                 trust_ratio = one_tensor
                 if weight_decay != 0 or group['use_nvlamb']:
                     # Layer adaptation. By default, skip layer adaptation on parameters that are
                     # excluded from weight decay, unless use_nvlamb == True, then always enabled.
-                    w_norm = p.data.norm(2.0)
+                    w_norm = p.norm(2.0)
                     g_norm = update.norm(2.0)
+                    # FIXME nested where required since logical and/or not working in PT XLA
                     trust_ratio = torch.where(
                         w_norm > 0,
                         torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                         one_tensor,
                     )
                 update.mul_(trust_ratio)
-                p.data.add_(update, alpha=-group['lr'])
+                p.add_(update, alpha=-group['lr'])
 
         return loss
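
For illustration, the nested torch.where pattern used above for the LAMB trust ratio, as a standalone snippet with made-up values: both zero-norm guards are expressed as tensor ops because, per the FIXME, logical and/or did not work under PyTorch XLA, and a data-dependent Python if would not trace.

import torch

w_norm = torch.tensor(3.0)
g_norm = torch.tensor(1.5)
one_tensor = torch.tensor(1.0)

trust_ratio = torch.where(
    w_norm > 0,
    torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
    one_tensor,
)
print(trust_ratio)  # tensor(2.); falls back to 1.0 if either norm is zero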

timm/optim/lars.py (new file, 136 lines)
@@ -0,0 +1,136 @@
+""" PyTorch LARS / LARC Optimizer
+
+An implementation of LARS (SGD) + LARC in PyTorch
+
+Based on:
+  * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
+  * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
+
+Additional cleanup and modifications to properly support PyTorch XLA.
+
+Copyright 2021 Ross Wightman
+"""
+import torch
+from torch.optim.optimizer import Optimizer, required
+
+
+class Lars(Optimizer):
+    """ LARS for PyTorch
+
+    Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        dampening (float, optional): dampening for momentum (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+        trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
+        eps (float): eps for division denominator (default: 1e-8)
+        larc (bool): enable LARC clipping (default: False)
+        always_scale (bool): always apply LARS scaling, otherwise only when group weight_decay != 0 (default: False)
+    """
+
+    def __init__(
+            self,
+            params,
+            lr=required,
+            momentum=0,
+            dampening=0,
+            weight_decay=0,
+            nesterov=False,
+            trust_coeff=0.001,
+            eps=1e-8,
+            larc=False,
+            always_scale=False,
+    ):
+        if lr is not required and lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if momentum < 0.0:
+            raise ValueError(f"Invalid momentum value: {momentum}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+        defaults = dict(
+            lr=lr,
+            momentum=momentum,
+            dampening=dampening,
+            weight_decay=weight_decay,
+            nesterov=nesterov,
+            trust_coeff=trust_coeff,
+            eps=eps,
+            larc=larc,
+            always_scale=always_scale,
+        )
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault("nesterov", False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        device = self.param_groups[0]["params"][0].device
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
+
+        # exclude scaling for params with 0 weight decay
+        for group in self.param_groups:
+            weight_decay = group['weight_decay']
+            momentum = group['momentum']
+            dampening = group['dampening']
+            nesterov = group['nesterov']
+            trust_coeff = group['trust_coeff']
+            eps = group['eps']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+
+                # apply LARS scaling, LARC clipping, weight decay
+                # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
+                if weight_decay != 0 or group['always_scale']:
+                    w_norm = p.norm(2.0)
+                    g_norm = grad.norm(2.0)
+                    trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
+                    # FIXME nested where required since logical and/or not working in PT XLA
+                    trust_ratio = torch.where(
+                        w_norm > 0,
+                        torch.where(g_norm > 0, trust_ratio, one_tensor),
+                        one_tensor,
+                    )
+                    if group['larc']:
+                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
+                    grad.add(p, alpha=weight_decay)
+                    grad.mul_(trust_ratio)
+
+                # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
+                if momentum != 0:
+                    param_state = self.state[p]
+                    if 'momentum_buffer' not in param_state:
+                        buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
+                    else:
+                        buf = param_state['momentum_buffer']
+                    buf.mul_(momentum).add_(grad, alpha=1. - dampening)
+                    if nesterov:
+                        grad = grad.add(buf, alpha=momentum)
+                    else:
+                        grad = buf
+
+                p.add_(grad, alpha=-group['lr'])
+
+        return loss
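
A possible usage sketch for the new Lars class; the toy model, data, and hyper-parameters are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.optim.lars import Lars

model = nn.Linear(10, 2)
# plain LARS-SGD with momentum; larc=True enables LARC-style clipping,
# nesterov=True gives the Nesterov variants wired up in the factory below
optimizer = Lars(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-4)

x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()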

timm/optim/lookahead.py
@@ -27,22 +27,24 @@ class Lookahead(Optimizer):
         for group in self._base_optimizer.param_groups:
             group.setdefault(name, default)
 
+    @torch.no_grad()
     def update_slow(self, group):
         for fast_p in group["params"]:
             if fast_p.grad is None:
                 continue
             param_state = self._base_optimizer.state[fast_p]
             if 'lookahead_slow_buff' not in param_state:
-                param_state['lookahead_slow_buff'] = torch.empty_like(fast_p.data)
-                param_state['lookahead_slow_buff'].copy_(fast_p.data)
+                param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
+                param_state['lookahead_slow_buff'].copy_(fast_p)
             slow = param_state['lookahead_slow_buff']
-            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
-            fast_p.data.copy_(slow)
+            slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
+            fast_p.copy_(slow)
 
     def sync_lookahead(self):
         for group in self._base_optimizer.param_groups:
             self.update_slow(group)
 
+    @torch.no_grad()
     def step(self, closure=None):
         loss = self._base_optimizer.step(closure)
         for group in self._base_optimizer.param_groups:

timm/optim/madgrad.py
@@ -82,6 +82,7 @@ class MADGRAD(torch.optim.Optimizer):
     def supports_flat_params(self) -> bool:
         return True
 
+    @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         """Performs a single optimization step.
 
@@ -91,13 +92,10 @@ class MADGRAD(torch.optim.Optimizer):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
-        # step counter must be stored in state to ensure correct behavior under
-        # optimizer sharding
-        if 'k' not in self.state:
-            self.state['k'] = torch.tensor([0], dtype=torch.long)
-        k = self.state['k'].item()
+        step = self.state.setdefault('step', 0)  # k
 
         for group in self.param_groups:
             eps = group["eps"]
@@ -106,19 +104,19 @@ class MADGRAD(torch.optim.Optimizer):
             momentum = group["momentum"]
 
             ck = 1 - momentum
-            lamb = lr * math.pow(k + 1, 0.5)
+            lamb = lr * math.sqrt(step + 1)
 
             for p in group["params"]:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 state = self.state[p]
 
                 if "grad_sum_sq" not in state:
-                    state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
-                    state["s"] = torch.zeros_like(p.data).detach()
+                    state["grad_sum_sq"] = torch.zeros_like(p)
+                    state["s"] = torch.zeros_like(p)
                     if momentum != 0:
-                        state["x0"] = torch.clone(p.data).detach()
+                        state["x0"] = torch.clone(p).detach()
 
                 if momentum != 0.0 and grad.is_sparse:
                     raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
@@ -129,11 +127,11 @@ class MADGRAD(torch.optim.Optimizer):
                 # Apply weight decay
                 if weight_decay != 0:
                     if group['decoupled_decay']:
-                        p.data.mul_(1.0 - group['lr'] * weight_decay)
+                        p.mul_(1.0 - group['lr'] * weight_decay)
                     else:
                         if grad.is_sparse:
                             raise RuntimeError("weight_decay option is not compatible with sparse gradients")
-                        grad.add_(p.data, alpha=weight_decay)
+                        grad.add_(p, alpha=weight_decay)
 
                 if grad.is_sparse:
                     grad = grad.coalesce()
@@ -161,12 +159,12 @@ class MADGRAD(torch.optim.Optimizer):
                     p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
                     # Copy updated masked p to dense p using an add operation
                     p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
-                    p.data.add_(p_masked, alpha=-1)
+                    p.add_(p_masked, alpha=-1)
                 else:
                     if momentum == 0:
                         # Compute x_0 from other known quantities
                         rms = grad_sum_sq.pow(1 / 3).add_(eps)
-                        x0 = p.data.addcdiv(s, rms, value=1)
+                        x0 = p.addcdiv(s, rms, value=1)
                     else:
                         x0 = state["x0"]
 
@@ -175,16 +173,16 @@ class MADGRAD(torch.optim.Optimizer):
                     rms = grad_sum_sq.pow(1 / 3).add_(eps)
 
                     # Update s
-                    s.data.add_(grad, alpha=lamb)
+                    s.add_(grad, alpha=lamb)
 
                     # Step
                     if momentum == 0:
-                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
+                        p.copy_(x0.addcdiv(s, rms, value=-1))
                     else:
                         z = x0.addcdiv(s, rms, value=-1)
 
                         # p is a moving average of z
-                        p.data.mul_(1 - ck).add_(z, alpha=ck)
+                        p.mul_(1 - ck).add_(z, alpha=ck)
 
-        self.state['k'] += 1
+        self.state['step'] += 1
         return loss
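
For illustration, a standalone contrast (not timm code) of the two weight-decay branches in the MADGRAD hunk above: decoupled decay shrinks the parameter directly, while the classic path folds the decay term into the gradient. The tensors and values are made up.

import torch

p, grad = torch.randn(4), torch.randn(4)
lr, weight_decay = 1e-2, 1e-4

decoupled_decay = True
if decoupled_decay:
    p.mul_(1.0 - lr * weight_decay)          # AdamW-style decay applied to the param
else:
    grad = grad.add(p, alpha=weight_decay)   # classic L2: decay folded into the gradient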

timm/optim/nadam.py
@@ -1,3 +1,5 @@
+import math
+
 import torch
 from torch.optim.optimizer import Optimizer
 
@@ -33,6 +35,7 @@ class Nadam(Optimizer):
             lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
         super(Nadam, self).__init__(params, defaults)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -42,21 +45,22 @@ class Nadam(Optimizer):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
                     state['m_schedule'] = 1.
-                    state['exp_avg'] = torch.zeros_like(p.data)
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
 
                 # Warming momentum schedule
                 m_schedule = state['m_schedule']
@@ -66,9 +70,10 @@ class Nadam(Optimizer):
                 eps = group['eps']
                 state['step'] += 1
                 t = state['step']
+                bias_correction2 = 1 - beta2 ** t
 
                 if group['weight_decay'] != 0:
-                    grad = grad.add(p.data, alpha=group['weight_decay'])
+                    grad = grad.add(p, alpha=group['weight_decay'])
 
                 momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
                 momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
@@ -79,10 +84,9 @@ class Nadam(Optimizer):
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
-                exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
-                denom = exp_avg_sq_prime.sqrt_().add_(eps)
 
-                p.data.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
-                p.data.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
+                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+                p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
+                p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
 
         return loss

timm/optim/nvnovograd.py
@@ -51,6 +51,7 @@ class NvNovoGrad(Optimizer):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -60,13 +61,14 @@ class NvNovoGrad(Optimizer):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('Sparse gradients are not supported.')
                 amsgrad = group['amsgrad']
@@ -77,7 +79,7 @@ class NvNovoGrad(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p)
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
                     if amsgrad:
@@ -108,11 +110,11 @@ class NvNovoGrad(Optimizer):
 
                 grad.div_(denom)
                 if group['weight_decay'] != 0:
-                    grad.add_(p.data, alpha=group['weight_decay'])
+                    grad.add_(p, alpha=group['weight_decay'])
                 if group['grad_averaging']:
                     grad.mul_(1 - beta1)
                 exp_avg.mul_(beta1).add_(grad)
 
-                p.data.add_(exp_avg, alpha=-group['lr'])
+                p.add_(exp_avg, alpha=-group['lr'])
 
         return loss

timm/optim/optim_factory.py
@@ -12,6 +12,7 @@ from .adafactor import Adafactor
 from .adahessian import Adahessian
 from .adamp import AdamP
 from .lamb import Lamb
+from .lars import Lars
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
@@ -163,6 +164,14 @@ def create_optimizer_v2(
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'lamb':
         optimizer = Lamb(parameters, **opt_args)
+    elif opt_lower == 'larc':
+        optimizer = Lars(parameters, momentum=momentum, larc=True, **opt_args)
+    elif opt_lower == 'lars':
+        optimizer = Lars(parameters, momentum=momentum, **opt_args)
+    elif opt_lower == 'nlarc':
+        optimizer = Lars(parameters, momentum=momentum, larc=True, nesterov=True, **opt_args)
+    elif opt_lower == 'nlars':
+        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'madgrad':
         optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
     elif opt_lower == 'madgradw':
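
A possible usage sketch for the newly registered factory strings ('lars', 'larc', 'nlars', 'nlarc'), assuming create_optimizer_v2 accepts a parameter iterable as in the tests above; the model is illustrative.

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(10, 2)
opt_lars = create_optimizer_v2(model.parameters(), 'lars', lr=1.0, momentum=0.9, weight_decay=1e-4)
opt_nlarc = create_optimizer_v2(model.parameters(), 'nlarc', lr=1.0, momentum=0.9)  # Nesterov momentum + LARC clipping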

timm/optim/radam.py
@@ -18,9 +18,11 @@ class RAdam(Optimizer):
     def __setstate__(self, state):
         super(RAdam, self).__setstate__(state)
 
+    @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
@@ -28,21 +30,21 @@ class RAdam(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data.float()
+                grad = p.grad.float()
                 if grad.is_sparse:
                     raise RuntimeError('RAdam does not support sparse gradients')
 
-                p_data_fp32 = p.data.float()
+                p_fp32 = p.float()
 
                 state = self.state[p]
 
                 if len(state) == 0:
                     state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg'] = torch.zeros_like(p_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_fp32)
                 else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+                    state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 beta1, beta2 = group['betas']
@@ -73,15 +75,15 @@ class RAdam(Optimizer):
                     buffered[2] = step_size
 
                 if group['weight_decay'] != 0:
-                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+                    p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
 
                 # more conservative since it's an approximated value
                 if num_sma >= 5:
                     denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+                    p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                 else:
-                    p_data_fp32.add_(exp_avg, alpha=-step_size)
+                    p_fp32.add_(exp_avg, alpha=-step_size)
 
-                p.data.copy_(p_data_fp32)
+                p.copy_(p_fp32)
 
         return loss

timm/optim/rmsprop_tf.py
@@ -4,7 +4,7 @@ Originally cut & paste from PyTorch RMSProp
 https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
 Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
 
-Modifications Copyright 2020 Ross Wightman
+Modifications Copyright 2021 Ross Wightman
 """
 
 import torch
@@ -69,6 +69,7 @@ class RMSpropTF(Optimizer):
             group.setdefault('momentum', 0)
             group.setdefault('centered', False)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -78,13 +79,14 @@ class RMSpropTF(Optimizer):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('RMSprop does not support sparse gradients')
                 state = self.state[p]
@@ -92,11 +94,11 @@ class RMSpropTF(Optimizer):
                 # State initialization
                 if len(state) == 0:
                     state['step'] = 0
-                    state['square_avg'] = torch.ones_like(p.data)  # PyTorch inits to zero
+                    state['square_avg'] = torch.ones_like(p)  # PyTorch inits to zero
                     if group['momentum'] > 0:
-                        state['momentum_buffer'] = torch.zeros_like(p.data)
+                        state['momentum_buffer'] = torch.zeros_like(p)
                     if group['centered']:
-                        state['grad_avg'] = torch.zeros_like(p.data)
+                        state['grad_avg'] = torch.zeros_like(p)
 
                 square_avg = state['square_avg']
                 one_minus_alpha = 1. - group['alpha']
@@ -105,9 +107,9 @@ class RMSpropTF(Optimizer):
 
                 if group['weight_decay'] != 0:
                     if group['decoupled_decay']:
-                        p.data.mul_(1. - group['lr'] * group['weight_decay'])
+                        p.mul_(1. - group['lr'] * group['weight_decay'])
                     else:
-                        grad = grad.add(p.data, alpha=group['weight_decay'])
+                        grad = grad.add(p, alpha=group['weight_decay'])
 
                 # Tensorflow order of ops for updating squared avg
                 square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
@@ -126,12 +128,12 @@ class RMSpropTF(Optimizer):
                     # Tensorflow accumulates the LR scaling in the momentum buffer
                     if group['lr_in_momentum']:
                         buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
-                        p.data.add_(-buf)
+                        p.add_(-buf)
                     else:
                         # PyTorch scales the param update by LR
                         buf.mul_(group['momentum']).addcdiv_(grad, avg)
-                        p.data.add_(buf, alpha=-group['lr'])
+                        p.add_(buf, alpha=-group['lr'])
                 else:
-                    p.data.addcdiv_(grad, avg, value=-group['lr'])
+                    p.addcdiv_(grad, avg, value=-group['lr'])
 
         return loss

timm/optim/sgdp.py
@@ -24,9 +24,11 @@ class SGDP(Optimizer):
             nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
         super(SGDP, self).__init__(params, defaults)
 
+    @torch.no_grad()
     def step(self, closure=None):
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
@@ -38,12 +40,12 @@ class SGDP(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 state = self.state[p]
 
                 # State initialization
                 if len(state) == 0:
-                    state['momentum'] = torch.zeros_like(p.data)
+                    state['momentum'] = torch.zeros_like(p)
 
                 # SGD
                 buf = state['momentum']
@@ -60,9 +62,9 @@ class SGDP(Optimizer):
 
                 # Weight decay
                 if weight_decay != 0:
-                    p.data.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
+                    p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
 
                 # Step
-                p.data.add_(d_p, alpha=-group['lr'])
+                p.add_(d_p, alpha=-group['lr'])
 
         return loss