diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
index 12c7c49b..9d3a3421 100644
--- a/timm/optim/lamb.py
+++ b/timm/optim/lamb.py
@@ -85,14 +85,49 @@ class Lamb(Optimizer):
     """

     def __init__(
-            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
-            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
+            self,
+            params,
+            lr=1e-3,
+            bias_correction=True,
+            betas=(0.9, 0.999),
+            eps=1e-6,
+            weight_decay=0.01,
+            grad_averaging=True,
+            max_grad_norm=1.0,
+            trust_clip=False,
+            always_adapt=False,
+    ):
         defaults = dict(
-            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
-            trust_clip=trust_clip, always_adapt=always_adapt)
+            lr=lr,
+            bias_correction=bias_correction,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            grad_averaging=grad_averaging,
+            max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip,
+            always_adapt=always_adapt,
+        )
         super().__init__(params, defaults)

+    def _get_clip_grad_norm(self):
+        max_grad_norm = self.defaults['max_grad_norm']
+        if max_grad_norm is None:
+            return None
+
+        norms = []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+                norms.append(torch.linalg.vector_norm(grad))
+        global_norm = torch.linalg.vector_norm(torch.stack(norms))
+        clip_global_norm = (global_norm / max_grad_norm).clamp_(min=1.0)
+        return clip_global_norm
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -105,26 +140,7 @@ class Lamb(Optimizer):
             with torch.enable_grad():
                 loss = closure()

-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-        global_grad_norm = torch.zeros(1, device=device)
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
-                global_grad_norm.add_(grad.pow(2).sum())
-
-        global_grad_norm = torch.sqrt(global_grad_norm)
-        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
-        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
-        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
-        clip_global_grad_norm = torch.where(
-            global_grad_norm > max_grad_norm,
-            global_grad_norm / max_grad_norm,
-            one_tensor)
+        clip_grad_norm = self._get_clip_grad_norm()  # None if disabled

         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
@@ -148,7 +164,11 @@ class Lamb(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.div_(clip_global_grad_norm)
+                grad = p.grad
+
+                if clip_grad_norm is not None:
+                    grad.div_(clip_grad_norm)
+
                 state = self.state[p]

                 # State initialization
@@ -176,15 +196,17 @@ class Lamb(Optimizer):
                     # excluded from weight decay, unless always_adapt == True, then always enabled.
                     w_norm = p.norm(2.0)
                     g_norm = update.norm(2.0)
+                    trust_ratio = w_norm / g_norm
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
                         # LAMBC trust clipping, upper bound fixed at one
-                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio, max=1.0)
                     update.mul_(trust_ratio)

                 p.add_(update, alpha=-group['lr'])
diff --git a/timm/optim/lars.py b/timm/optim/lars.py
index 38ca9e0b..d49efc6d 100644
--- a/timm/optim/lars.py
+++ b/timm/optim/lars.py
@@ -84,9 +84,6 @@ class Lars(Optimizer):
             with torch.enable_grad():
                 loss = closure()

-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-
         for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
@@ -107,13 +104,14 @@ class Lars(Optimizer):
                     g_norm = grad.norm(2.0)
                     trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, trust_ratio, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
-                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio / group['lr'], max=1.0)
                     grad.add_(p, alpha=weight_decay)
                     grad.mul_(trust_ratio)

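For reference, a minimal standalone sketch of the clipping math that the new `_get_clip_grad_norm` helper performs (the `clip_factor` function and the example tensors below are illustrative only, not part of the patch): dividing every gradient by `max(global_norm / max_grad_norm, 1.0)` leaves gradients untouched when the global norm is within the limit and rescales them to exactly the limit otherwise, which matches the `torch.where` formulation the patch removes.

```python
import torch


def clip_factor(grads, max_grad_norm=1.0):
    # Per-tensor L2 norms, then the L2 norm of those norms == global L2 norm of all gradients.
    norms = [torch.linalg.vector_norm(g) for g in grads]
    global_norm = torch.linalg.vector_norm(torch.stack(norms))
    # Clamped to >= 1.0, so dividing by the factor never amplifies gradients.
    return (global_norm / max_grad_norm).clamp_(min=1.0)


grads = [torch.randn(3, 3), torch.randn(5)]
factor = clip_factor(grads, max_grad_norm=1.0)
clipped = [g / factor for g in grads]

# Equivalent to the old torch.where-based computation:
global_norm = torch.linalg.vector_norm(torch.cat([g.flatten() for g in grads]))
expected = torch.where(global_norm > 1.0, global_norm / 1.0, torch.tensor(1.0))
assert torch.allclose(factor, expected)
```

The remaining changes follow the same idea: recent PyTorch releases accept Python scalars for the value arguments of `torch.where` and for `torch.clamp` bounds, so the per-device `one_tensor` workaround can be dropped in both optimizers.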