A bit of lars/lamb cleanup, torch.where supports scalars properly now, make lamb grad clipping optional, clean it up a bit
parent 7cfaeced67
commit 9d8ccd2ba7
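
For quick context, a minimal usage sketch of the headline change (Lamb grad clipping is now optional). The import path and the toy model are assumptions for illustration only, not part of this commit.

import torch
from timm.optim import Lamb  # import path assumed for illustration

model = torch.nn.Linear(10, 2)  # hypothetical toy model

# Default is unchanged: max_grad_norm=1.0 keeps global grad clipping enabled.
opt_clip = Lamb(model.parameters(), lr=1e-3)

# New: max_grad_norm=None disables clipping entirely;
# _get_clip_grad_norm() returns None and step() skips the division.
opt_noclip = Lamb(model.parameters(), lr=1e-3, max_grad_norm=None)
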
@@ -85,14 +85,49 @@ class Lamb(Optimizer):
     """
 
     def __init__(
-            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
-            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
+            self,
+            params,
+            lr=1e-3,
+            bias_correction=True,
+            betas=(0.9, 0.999),
+            eps=1e-6,
+            weight_decay=0.01,
+            grad_averaging=True,
+            max_grad_norm=1.0,
+            trust_clip=False,
+            always_adapt=False,
+    ):
         defaults = dict(
-            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
-            trust_clip=trust_clip, always_adapt=always_adapt)
+            lr=lr,
+            bias_correction=bias_correction,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            grad_averaging=grad_averaging,
+            max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip,
+            always_adapt=always_adapt,
+        )
         super().__init__(params, defaults)
 
+    def _get_clip_grad_norm(self):
+        max_grad_norm = self.defaults['max_grad_norm']
+        if max_grad_norm is None:
+            return None
+
+        norms = []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+                norms.append(torch.linalg.vector_norm(grad))
+        global_norm = torch.linalg.vector_norm(torch.stack(norms))
+        clip_global_norm = (global_norm / max_grad_norm).clamp_(min=1.0)
+        return clip_global_norm
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
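
A standalone sketch (toy tensors, not part of the commit) of the identity the new _get_clip_grad_norm helper relies on: the L2 norm of the per-tensor gradient norms equals the L2 norm of all gradient elements taken together, and the clamp turns it into a divisor of at least 1.0.

import torch

grads = [torch.randn(3, 4), torch.randn(5)]  # stand-ins for p.grad tensors
per_tensor = torch.stack([torch.linalg.vector_norm(g) for g in grads])
global_norm = torch.linalg.vector_norm(per_tensor)
reference = torch.linalg.vector_norm(torch.cat([g.flatten() for g in grads]))
assert torch.allclose(global_norm, reference)

max_grad_norm = 1.0
clip_global_norm = (global_norm / max_grad_norm).clamp(min=1.0)
# divisor is 1.0 (a no-op) when global_norm <= max_grad_norm, else global_norm / max_grad_norm
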
@@ -105,26 +140,7 @@ class Lamb(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-        global_grad_norm = torch.zeros(1, device=device)
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
-                global_grad_norm.add_(grad.pow(2).sum())
-
-        global_grad_norm = torch.sqrt(global_grad_norm)
-        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
-        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
-        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
-        clip_global_grad_norm = torch.where(
-            global_grad_norm > max_grad_norm,
-            global_grad_norm / max_grad_norm,
-            one_tensor)
+        clip_grad_norm = self._get_clip_grad_norm()  # None if disabled
 
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
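
A quick check (toy values, illustration only) that the clamp form used by the new helper matches the torch.where form removed above:

import torch

max_grad_norm = 1.0
for norm in (torch.tensor(0.5), torch.tensor(2.5)):
    old = torch.where(norm > max_grad_norm, norm / max_grad_norm, torch.tensor(1.0))
    new = (norm / max_grad_norm).clamp(min=1.0)
    assert torch.allclose(old, new)  # both leave small norms alone and scale large ones down
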
@@ -148,7 +164,11 @@ class Lamb(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.div_(clip_global_grad_norm)
+                grad = p.grad
+
+                if clip_grad_norm is not None:
+                    grad.div_(clip_grad_norm)
+
                 state = self.state[p]
 
                 # State initialization
@@ -176,15 +196,17 @@ class Lamb(Optimizer):
                     # excluded from weight decay, unless always_adapt == True, then always enabled.
                     w_norm = p.norm(2.0)
                     g_norm = update.norm(2.0)
+                    trust_ratio = w_norm / g_norm
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
                         # LAMBC trust clipping, upper bound fixed at one
-                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio, max=1.0)
                     update.mul_(trust_ratio)
 
                 p.add_(update, alpha=-group['lr'])
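
The trust-ratio changes above lean on the point in the commit message that torch.where now handles Python scalars for its value arguments, so the one_tensor workaround can go. A small sketch with toy norms (not from the commit):

import torch

w_norm = torch.tensor(0.0)   # e.g. a freshly zero-initialized weight
g_norm = torch.tensor(3.0)
trust_ratio = w_norm / g_norm
# Scalar 1.0 is accepted directly as the fallback value; no tensor wrapping needed.
trust_ratio = torch.where(
    w_norm > 0,
    torch.where(g_norm > 0, trust_ratio, 1.0),
    1.0,
)
print(trust_ratio)  # tensor(1.) -- no change applied when either norm is zero
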
@@ -84,9 +84,6 @@ class Lars(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-
         for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
@@ -107,13 +104,14 @@ class Lars(Optimizer):
                     g_norm = grad.norm(2.0)
                     trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, trust_ratio, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
-                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio / group['lr'], max=1.0)
                     grad.add_(p, alpha=weight_decay)
                     grad.mul_(trust_ratio)
 
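
Likewise, the trust_clip branches swap torch.minimum against a one-valued tensor for torch.clamp with a scalar bound; a quick equivalence check (illustration only):

import torch

trust_ratio = torch.tensor(1.7)
clamped = torch.clamp(trust_ratio, max=1.0)
minimum = torch.minimum(trust_ratio, torch.tensor(1.0))
assert torch.equal(clamped, minimum)  # both cap the ratio at 1.0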