diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py index bad982ee..465f8f0e 100644 --- a/timm/optim/adafactor_bv.py +++ b/timm/optim/adafactor_bv.py @@ -139,11 +139,11 @@ class AdafactorBigVision(Optimizer): ) if factored_dims is not None: - d1, d0 = factored_dims + dc, dr = factored_dims row_shape = list(p.grad.shape) - row_shape[d0] = 1 + row_shape[dr] = 1 col_shape = list(p.grad.shape) - col_shape[d1] = 1 + col_shape[dc] = 1 state['exp_avg_sq_r'] = p.grad.new_zeros(row_shape) state['exp_avg_sq_c'] = p.grad.new_zeros(col_shape) else: @@ -226,12 +226,12 @@ def _single_tensor_adafactor( # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach if exp_avg_sq is None: # factorized second moment - d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor) - exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t) - exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t) + dc, dr = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor) + exp_avg_sq_r.lerp_(grad_sqr.mean(dim=dr, keepdim=True), one_minus_beta2_t) + exp_avg_sq_c.lerp_(grad_sqr.mean(dim=dc, keepdim=True), one_minus_beta2_t) - reduced_d1 = d1 - 1 if d1 > d0 else d1 - row_col_mean = exp_avg_sq_r.mean(dim=reduced_d1, keepdim=True) + reduce_dc = dc - 1 if dc > dr else dc + row_col_mean = exp_avg_sq_r.mean(dim=reduce_dc, keepdim=True) row_factor = (exp_avg_sq_r / row_col_mean).rsqrt() col_factor = exp_avg_sq_c.rsqrt()