From 484a88f4b4cbd1fefcaad2e24140ac59492222e8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 4 Nov 2024 09:23:04 -0800
Subject: [PATCH] Remove unused beta2 fn, make eps grad^2 handling same across
 factorized and non-factorized cases

---
 timm/optim/adafactor_bv.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py
index d320f63d..62736a33 100644
--- a/timm/optim/adafactor_bv.py
+++ b/timm/optim/adafactor_bv.py
@@ -95,11 +95,6 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
-    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
-        """Computes beta2 according to the step schedule"""
-        t = float(step + 1)
-        return min(beta2_cap, 1.0 - t ** (-decay_rate))
-
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
@@ -219,9 +214,10 @@ def _single_tensor_adafactor(
         beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
         one_minus_beta2_t = 1 - beta2_t
 
+        grad_sqr = torch.square(grad) + eps
         if exp_avg_sq is None:
+            # factorized second moment
             d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
-            grad_sqr = torch.square(grad) + eps
             exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
             exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
 
@@ -232,9 +228,10 @@ def _single_tensor_adafactor(
 
             update = grad * row_factor * col_factor
         else:
-            # Handle non-factored
-            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
-            update = grad * exp_avg_sq.add(eps).rsqrt_()
+            # non-factorized second moment
+            assert exp_avg_sq_r is None and exp_avg_sq_c is None
+            exp_avg_sq.lerp_(grad_sqr, one_minus_beta2_t)
+            update = grad * exp_avg_sq.rsqrt()
 
         # Clip by RMS value
         if clipping_threshold is not None:
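
Note (not part of the patch): below is a minimal standalone sketch of the behavior the diff converges on, assuming PyTorch and illustrative default values for beta2_cap, beta2_decay, and eps. It shows beta2 computed inline from the step schedule (the role of the removed _get_beta2 helper) and eps added to grad**2 before the EMA, which is what makes the factorized and non-factorized branches treat eps the same way. The function name and defaults here are hypothetical, not part of the timm API.

    import torch


    def second_moment_update(grad, exp_avg_sq, step, beta2_cap=0.999, beta2_decay=0.8, eps=1e-30):
        """Sketch of the non-factorized second-moment update after the patch."""
        # beta2 follows the step schedule inline, as in the patched _single_tensor_adafactor
        beta2_t = min(beta2_cap, 1.0 - float(step) ** (-beta2_decay))
        # eps is applied to grad^2 up front, shared by both branches in the patch
        grad_sqr = torch.square(grad) + eps
        # EMA of (grad^2 + eps); no second eps is added inside the rsqrt
        exp_avg_sq.lerp_(grad_sqr, 1 - beta2_t)
        return grad * exp_avg_sq.rsqrt()


    # Example usage
    g = torch.randn(4, 4)
    v = torch.zeros_like(g)
    update = second_moment_update(g, v, step=1)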