Change adafactor_bv epsilon default

small_384_weights
Ross Wightman 2024-11-05 13:03:13 -08:00 committed by Ross Wightman
parent 0b5ae49251
commit 7cfaeced67
1 changed files with 6 additions and 1 deletions

View File

@ -51,7 +51,7 @@ class AdafactorBigVision(Optimizer):
beta2_cap: float = 0.999,
momentum: Optional[float] = 0.9,
momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
eps: float = 1e-30,
eps: Optional[float] = None,
weight_decay: float = 0.0,
clipping_threshold: Optional[float] = None,
unscaled_wd: bool = False,
@ -66,6 +66,7 @@ class AdafactorBigVision(Optimizer):
else:
assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
momentum_dtype = torch.float32
# FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.
defaults = dict(
lr=lr,
@ -212,6 +213,9 @@ def _single_tensor_adafactor(
exp_avg_sq = exp_avg_sqs[i]
exp_avg = exp_avgs[i]
step_t = state_steps[i]
if eps is None:
# use square of machine eps for grad dtype if not set
eps = torch.finfo(grad.dtype).eps ** 2
# Update step
step_t += 1
@ -219,6 +223,7 @@ def _single_tensor_adafactor(
one_minus_beta2_t = 1 - beta2_t
grad_sqr = torch.square(grad) + eps
# NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
if exp_avg_sq is None:
# factorized second moment
d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)