diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
index 12c7c49b..9d3a3421 100644
--- a/timm/optim/lamb.py
+++ b/timm/optim/lamb.py
@@ -85,14 +85,49 @@ class Lamb(Optimizer):
     """
 
     def __init__(
-            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
-            weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
+            self,
+            params,
+            lr=1e-3,
+            bias_correction=True,
+            betas=(0.9, 0.999),
+            eps=1e-6,
+            weight_decay=0.01,
+            grad_averaging=True,
+            max_grad_norm=1.0,
+            trust_clip=False,
+            always_adapt=False,
+    ):
         defaults = dict(
-            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
-            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
-            trust_clip=trust_clip, always_adapt=always_adapt)
+            lr=lr,
+            bias_correction=bias_correction,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            grad_averaging=grad_averaging,
+            max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip,
+            always_adapt=always_adapt,
+        )
         super().__init__(params, defaults)
 
+    def _get_clip_grad_norm(self):
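+        """Return the divisor for global gradient norm clipping, or None if clipping is disabled."""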
+        max_grad_norm = self.defaults['max_grad_norm']
+        if max_grad_norm is None:
+            return None
+
+        norms = []
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+                norms.append(torch.linalg.vector_norm(grad))
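+        # the norm of the stacked per-tensor norms equals the global L2 norm over all gradients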
+        global_norm = torch.linalg.vector_norm(torch.stack(norms))
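+        # divisor is clamped to >= 1.0 so gradients are only ever scaled down, never up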
+        clip_global_norm = (global_norm / max_grad_norm).clamp_(min=1.0)
+        return clip_global_norm
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -105,26 +140,7 @@ class Lamb(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-        global_grad_norm = torch.zeros(1, device=device)
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                grad = p.grad
-                if grad.is_sparse:
-                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
-                global_grad_norm.add_(grad.pow(2).sum())
-
-        global_grad_norm = torch.sqrt(global_grad_norm)
-        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
-        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
-        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
-        clip_global_grad_norm = torch.where(
-            global_grad_norm > max_grad_norm,
-            global_grad_norm / max_grad_norm,
-            one_tensor)
+        clip_grad_norm = self._get_clip_grad_norm()  # None if clipping is disabled
 
         for group in self.param_groups:
             bias_correction = 1 if group['bias_correction'] else 0
@@ -148,7 +164,11 @@ class Lamb(Optimizer):
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.div_(clip_global_grad_norm)
+                grad = p.grad
+
+                if clip_grad_norm is not None:
+                    grad.div_(clip_grad_norm)
+
                 state = self.state[p]
 
                 # State initialization
@@ -176,15 +196,17 @@ class Lamb(Optimizer):
                     # excluded from weight decay, unless always_adapt == True, then always enabled.
                     w_norm = p.norm(2.0)
                     g_norm = update.norm(2.0)
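+                    # layer-wise LAMB trust ratio: scale the update by ||w|| / ||update||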
+                    trust_ratio = w_norm / g_norm
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
                         # LAMBC trust clipping, upper bound fixed at one
-                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio, max=1.0)
                     update.mul_(trust_ratio)
 
                 p.add_(update, alpha=-group['lr'])
diff --git a/timm/optim/lars.py b/timm/optim/lars.py
index 38ca9e0b..d49efc6d 100644
--- a/timm/optim/lars.py
+++ b/timm/optim/lars.py
@@ -84,9 +84,6 @@ class Lars(Optimizer):
             with torch.enable_grad():
                 loss = closure()
 
-        device = self.param_groups[0]['params'][0].device
-        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
-
         for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
@@ -107,13 +104,14 @@ class Lars(Optimizer):
                     g_norm = grad.norm(2.0)
                     trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
                     # FIXME nested where required since logical and/or not working in PT XLA
+                    # Set the ratio to 1.0 (no change) if either weight norm or grad norm is zero
                     trust_ratio = torch.where(
                         w_norm > 0,
-                        torch.where(g_norm > 0, trust_ratio, one_tensor),
-                        one_tensor,
+                        torch.where(g_norm > 0, trust_ratio, 1.0),
+                        1.0,
                     )
                     if group['trust_clip']:
-                        trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
+                        trust_ratio = torch.clamp(trust_ratio / group['lr'], max=1.0)
                     grad.add_(p, alpha=weight_decay)
                     grad.mul_(trust_ratio)