mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove unused beta2 fn, make eps grad^2 handling same across factorized and non-factorized cases
This commit is contained in:
parent 7c16adca83
commit 484a88f4b4
@@ -95,11 +95,6 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
-    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
-        """Computes beta2 according to the step schedule"""
-        t = float(step + 1)
-        return min(beta2_cap, 1.0 - t ** (-decay_rate))
-
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
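The removed _get_beta2 helper was unused; the same step-dependent beta2 schedule is computed inline in _single_tensor_adafactor (see the next hunk). A minimal standalone sketch of that schedule, for illustration only (the function name and example values below are not part of the timm API):

def beta2_schedule(step: int, decay_rate: float, beta2_cap: float) -> float:
    # beta2 grows toward beta2_cap as training progresses: 1 - t^(-decay_rate), capped
    t = float(step + 1)
    return min(beta2_cap, 1.0 - t ** (-decay_rate))

# Example, assuming decay_rate=0.8 and beta2_cap=0.999:
#   beta2_schedule(0, 0.8, 0.999)  -> 0.0
#   beta2_schedule(99, 0.8, 0.999) -> ~0.9749
#   very large step counts saturate at the 0.999 cap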
@@ -219,9 +214,10 @@ def _single_tensor_adafactor(
         beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
         one_minus_beta2_t = 1 - beta2_t
 
+        grad_sqr = torch.square(grad) + eps
         if exp_avg_sq is None:
             # factorized second moment
             d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
-            grad_sqr = torch.square(grad) + eps
             exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
             exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
@@ -232,9 +228,10 @@ def _single_tensor_adafactor(
 
             update = grad * row_factor * col_factor
         else:
-            # Handle non-factored
-            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
-            update = grad * exp_avg_sq.add(eps).rsqrt_()
+            # non-factorized second moment
+            assert exp_avg_sq_r is None and exp_avg_sq_c is None
+            exp_avg_sq.lerp_(grad_sqr, one_minus_beta2_t)
+            update = grad * exp_avg_sq.rsqrt()
 
         # Clip by RMS value
         if clipping_threshold is not None:
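After this change, eps is folded into the squared gradient once, up front, and both branches consume the same grad_sqr: the factorized path lerps the row/column means toward it, and the non-factorized path now lerps the full exp_avg_sq toward it instead of using mul_/addcmul_ and adding eps again before rsqrt. A minimal sketch of the unified non-factorized update under that reading (illustrative helper, not the timm API):

import torch

def nonfactored_update(grad: torch.Tensor, exp_avg_sq: torch.Tensor, beta2_t: float, eps: float) -> torch.Tensor:
    # eps is handled inside grad^2, matching the factorized branch
    grad_sqr = torch.square(grad) + eps
    # EMA of the second moment via in-place lerp: v <- v + (1 - beta2_t) * (grad_sqr - v)
    exp_avg_sq.lerp_(grad_sqr, 1 - beta2_t)
    # no extra eps before rsqrt; grad_sqr already includes it
    return grad * exp_avg_sq.rsqrt()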