Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Remove unused beta2 fn, make eps grad^2 handling same across factorized and non-factorized cases
parent 7c16adca83
commit 484a88f4b4
@@ -95,11 +95,6 @@ class AdafactorBigVision(Optimizer):
                 if len(p_state) != 0 and not torch.is_tensor(p_state['step']):
                     p_state['step'] = torch.tensor(float(p_state['step']), dtype=_get_scalar_dtype())
 
-    def _get_beta2(self, step: Tensor, decay_rate: float, beta2_cap: float) -> float:
-        """Computes beta2 according to the step schedule"""
-        t = float(step + 1)
-        return min(beta2_cap, 1.0 - t ** (-decay_rate))
-
     @torch.no_grad()
     def step(self, closure=None):
         loss = None
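For context, a minimal sketch of the schedule the removed `_get_beta2` helper implemented; the same quantity is still computed inline in `_single_tensor_adafactor` (the `beta2_t = ...` context line in the next hunk), which is why the method had become dead code. The defaults shown (decay_rate=0.8, beta2_cap=0.999) are illustrative assumptions, not values taken from this diff:

def beta2_schedule(step: int, decay_rate: float = 0.8, beta2_cap: float = 0.999) -> float:
    # Hypothetical standalone version of the removed method; defaults are assumed.
    t = float(step + 1)
    return min(beta2_cap, 1.0 - t ** (-decay_rate))

# beta2_schedule(0) == 0.0 and beta2_schedule(99) ~= 0.975; the value saturates
# at beta2_cap once 1.0 - t ** (-decay_rate) exceeds it.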
@@ -219,9 +214,10 @@ def _single_tensor_adafactor(
         beta2_t = min(beta2_cap, 1.0 - float(step_t) ** (-beta2_decay))
         one_minus_beta2_t = 1 - beta2_t
 
-        if exp_avg_sq is None:
-            d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
-            grad_sqr = torch.square(grad) + eps
+        grad_sqr = torch.square(grad) + eps
+        if exp_avg_sq is None:
+            # factorized second moment
+            d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
             exp_avg_sq_r.lerp_(grad_sqr.mean(dim=d0, keepdim=True), one_minus_beta2_t)
             exp_avg_sq_c.lerp_(grad_sqr.mean(dim=d1, keepdim=True), one_minus_beta2_t)
 
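Reading the factorized update: `Tensor.lerp_(end, weight)` computes v <- v + weight * (end - v), so each call above is the usual exponential moving average v <- beta2_t * v + (1 - beta2_t) * mean(grad^2 + eps) over the reduced dimension. A small sketch of that identity with made-up values:

import torch

v = torch.zeros(4)
g2 = torch.rand(4)                  # stands in for grad_sqr.mean(dim=..., keepdim=True)
beta2_t = 0.99
v_lerp = v.clone().lerp_(g2, 1 - beta2_t)                      # v + (1 - beta2_t) * (g2 - v)
v_ema = v.clone().mul_(beta2_t).add_(g2, alpha=1 - beta2_t)    # beta2_t * v + (1 - beta2_t) * g2
assert torch.allclose(v_lerp, v_ema)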
@@ -232,9 +228,10 @@ def _single_tensor_adafactor(
 
             update = grad * row_factor * col_factor
         else:
-            # Handle non-factored
-            exp_avg_sq.mul_(beta2_t).addcmul_(grad, grad, value=one_minus_beta2_t)
-            update = grad * exp_avg_sq.add(eps).rsqrt_()
+            # non-factorized second moment
+            assert exp_avg_sq_r is None and exp_avg_sq_c is None
+            exp_avg_sq.lerp_(grad_sqr, one_minus_beta2_t)
+            update = grad * exp_avg_sq.rsqrt()
 
         # Clip by RMS value
         if clipping_threshold is not None:
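Net effect of the change in the non-factorized branch: eps used to be added at use time (`rsqrt(v + eps)` with v accumulating grad^2 via mul_/addcmul_), whereas eps is now folded into grad^2 before the EMA, the same way the factorized branch builds its row/column statistics. A sketch of the two variants under assumed names (`old_update` / `new_update` are illustrative, not functions in timm):

import torch

def old_update(v, grad, beta2_t, eps):
    # previous non-factorized path: accumulate grad^2, add eps when forming the update
    v = beta2_t * v + (1 - beta2_t) * grad.square()
    return grad * (v + eps).rsqrt(), v

def new_update(v, grad, beta2_t, eps):
    # new non-factorized path: fold eps into grad^2 before the EMA,
    # matching the factorized branch's handling of grad_sqr
    v = torch.lerp(v, grad.square() + eps, 1 - beta2_t)
    return grad * v.rsqrt(), v

In both versions the resulting update is subsequently clipped by its RMS value when clipping_threshold is set, as the trailing context lines show.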