Remove unused beta2 fn, make eps grad^2 handling same across factorized and non-factorized cases

Ross Wightman 2024-11-04 09:23:04 -08:00 committed by Ross Wightman
parent 7c16adca83
commit 484a88f4b4


@@ -95,11 +95,6 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
-    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
-        """Computes beta2 according to the step schedule"""
-        t = float(step + 1)
-        return min(beta2_cap, 1.0 - t ** (-decay_rate))
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
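For context, the helper removed above computed the step-dependent beta2; after this change the same schedule is computed inline in _single_tensor_adafactor (see the next hunk). A minimal standalone sketch of that schedule, with a hypothetical function name and assuming a 1-based step count:

    def beta2_schedule(step: float, beta2_decay: float, beta2_cap: float) -> float:
        # Step-dependent beta2 that approaches beta2_cap as step grows, e.g. with
        # beta2_decay=0.8: step=10 -> ~0.84, step=1000 -> ~0.996 (capped at beta2_cap).
        return min(beta2_cap, 1.0 - step ** (-beta2_decay))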
@@ -219,9 +214,10 @@ def _single_tensor_adafactor(
         beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
         one_minus_beta2_t = 1 - beta2_t
+        grad_sqr = torch.square(grad) + eps
         if exp_avg_sq is None:
             # factorized second moment
             d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
-            grad_sqr = torch.square(grad) + eps
             exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
             exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
@@ -232,9 +228,10 @@ def _single_tensor_adafactor(
             update = grad * row_factor * col_factor
         else:
-            # Handle non-factored
-            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
-            update = grad * exp_avg_sq.add(eps).rsqrt_()
+            # non-factorized second moment
+            assert exp_avg_sq_r is None and exp_avg_sq_c is None
+            exp_avg_sq.lerp_(grad_sqr, one_minus_beta2_t)
+            update = grad * exp_avg_sq.rsqrt()
         # Clip by RMS value
         if clipping_threshold is not None:
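As a rough illustration (not part of the diff), the non-factorized second-moment update after this change can be sketched in isolation; the wrapper function and its signature are hypothetical, while the update steps mirror the added lines above:

    import torch

    def nonfactorized_update(exp_avg_sq: torch.Tensor, grad: torch.Tensor,
                             beta2_t: float, eps: float) -> torch.Tensor:
        # grad^2 + eps is formed once, exactly as in the factorized branch,
        # then EMA'd into exp_avg_sq via lerp_ and used to precondition grad.
        # No extra eps is added inside rsqrt(), since eps is already folded into grad_sqr.
        grad_sqr = torch.square(grad) + eps
        exp_avg_sq.lerp_(grad_sqr, 1.0 - beta2_t)
        return grad * exp_avg_sq.rsqrt()

Computing grad_sqr once above the branch means the factorized and non-factorized paths now apply eps identically, which is what the commit title refers to.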