A bit of lars/lamb cleanup: torch.where supports scalars properly now, make lamb grad clipping optional, and clean things up a bit

small_384_weights
Ross Wightman 2024-11-07 21:42:24 -08:00 committed by Ross Wightman
parent 7cfaeced67
commit 9d8ccd2ba7
2 changed files with 55 additions and 35 deletions
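The user-facing change is that Lamb's global gradient-norm clipping is now optional: pass max_grad_norm=None to skip it. A minimal usage sketch, assuming the timm.optim import path and with illustrative hyperparameters:

import torch
from timm.optim import Lamb  # assumed import path for the optimizer patched below

model = torch.nn.Linear(10, 2)

# Default behaviour: clip by global grad norm with max_grad_norm=1.0.
opt_clipped = Lamb(model.parameters(), lr=1e-3, max_grad_norm=1.0)

# New in this commit: None disables the extra norm pass entirely.
opt_unclipped = Lamb(model.parameters(), lr=1e-3, max_grad_norm=None)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt_unclipped.step()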

View File

@@ -85,14 +85,49 @@ class Lamb(Optimizer):
     """
     def __init__(
-            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
-            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
+            self,
+            params,
+            lr=1e-3,
+            bias_correction=True,
+            betas=(0.9, 0.999),
+            eps=1e-6,
+            weight_decay=0.01,
+            grad_averaging=True,
+            max_grad_norm=1.0,
+            trust_clip=False,
+            always_adapt=False,
+    ):
         defaults = dict(
-            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
-            trust_clip=trust_clip, always_adapt=always_adapt)
+            lr=lr,
+            bias_correction=bias_correction,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            grad_averaging=grad_averaging,
+            max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip,
+            always_adapt=always_adapt,
+        )
         super().__init__(params, defaults)
+
+    def _get_clip_grad_norm(self):
+        max_grad_norm = self.defaults['max_grad_norm']
+        if max_grad_norm is None:
+            return None
+
+        norms = []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+                norms.append(torch.linalg.vector_norm(grad))
+        global_norm = torch.linalg.vector_norm(torch.stack(norms))
+        clip_global_norm = (global_norm / max_grad_norm).clamp_(min=1.0)
+        return clip_global_norm
 
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -105,26 +140,7 @@ class Lamb(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-        global_grad_norm = torch.zeros(1, device=device)
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
-                global_grad_norm.add_(grad.pow(2).sum())
-
-        global_grad_norm = torch.sqrt(global_grad_norm)
-        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
-        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
-        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
-        clip_global_grad_norm = torch.where(
-            global_grad_norm > max_grad_norm,
-            global_grad_norm / max_grad_norm,
-            one_tensor)
+        clip_grad_norm = self._get_clip_grad_norm()  # None if disabled
 
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
@@ -148,7 +164,11 @@ class Lamb(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.div_(clip_global_grad_norm)
+                grad = p.grad
+                if clip_grad_norm is not None:
+                    grad.div_(clip_grad_norm)
+
                 state = self.state[p]
 
                 # State initialization
@@ -176,15 +196,17 @@ class Lamb(Optimizer):
                     # excluded from weight decay, unless always_adapt == True, then always enabled.
                     w_norm = p.norm(2.0)
                     g_norm = update.norm(2.0)
+                    trust_ratio = w_norm / g_norm
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
                         # LAMBC trust clipping, upper bound fixed at one
-                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio, max=1.0)
                     update.mul_(trust_ratio)
 
                 p.add_(update, alpha=-group['lr'])
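For reference, the new helper computes a single global scale factor and only ever scales gradients down, since the divisor is clamped to at least 1.0. A standalone sketch of that computation, with made-up gradient values:

import torch

max_grad_norm = 1.0
grads = [torch.tensor([3.0, 4.0]), torch.tensor([0.5])]  # illustrative per-parameter grads

# Same math as the new Lamb._get_clip_grad_norm: global L2 norm over all grads,
# then a divisor clamped to >= 1 so already-small gradients pass through unchanged.
global_norm = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(g) for g in grads]))
clip_global_norm = (global_norm / max_grad_norm).clamp_(min=1.0)

clipped = [g / clip_global_norm for g in grads]
print(round(global_norm.item(), 3), round(clip_global_norm.item(), 3))  # 5.025 5.025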

View File

@@ -84,9 +84,6 @@ class Lars(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-
         for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
@@ -107,13 +104,14 @@ class Lars(Optimizer):
                     g_norm = grad.norm(2.0)
                     trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, trust_ratio, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
-                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio / group['lr'], max=1.0)
                     grad.add_(p, alpha=weight_decay)
                     grad.mul_(trust_ratio)
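The shared cleanup in both files relies on torch.where now promoting Python scalars, so the one_tensor workaround and the pytorch/pytorch#9190 FIXME are no longer needed. A minimal sketch of the fallback pattern, with made-up norm values:

import torch

w_norm = torch.tensor(2.0)
g_norm = torch.tensor(0.0)  # e.g. an all-zero update
trust_ratio = w_norm / g_norm  # inf here, but the where() below falls back to 1.0

# The scalar 1.0 is promoted to the right dtype/device automatically,
# so no pre-built torch.tensor(1.0, device=...) is required anymore.
trust_ratio = torch.where(
    w_norm > 0,
    torch.where(g_norm > 0, trust_ratio, 1.0),
    1.0,
)
print(trust_ratio)  # tensor(1.)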