# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer (layer-wise adaptive rate scaling). No rate scaling or
    weight decay is applied to parameters with <= 1 dimension, i.e. biases
    and normalization gamma/beta.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # skip normalization gamma/beta and biases
                    # Add the L2 weight-decay term to the gradient.
                    dp = dp.add(p, alpha=g['weight_decay'])
                    # Layer-wise trust ratio q = trust_coefficient * ||p|| / ||dp||,
                    # falling back to 1 whenever either norm is zero.
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm),
                                                one),
                                    one)
                    dp = dp.mul(q)

                # Momentum buffer on the scaled update, followed by the SGD step.
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])