Change adafactor_bv epsilon default

parent 0b5ae49251
commit 7cfaeced67
@@ -51,7 +51,7 @@ class AdafactorBigVision(Optimizer):
             beta2_cap: float = 0.999,
             momentum: Optional[float] = 0.9,
             momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
-            eps: float = 1e-30,
+            eps: Optional[float] = None,
             weight_decay: float = 0.0,
             clipping_threshold: Optional[float] = None,
             unscaled_wd: bool = False,
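The constructor change makes eps optional; leaving it at the new None default defers the choice of epsilon1 to step time, where it is derived from the gradient dtype (third hunk below). A minimal usage sketch, assuming the class is importable from timm.optim.adafactor_bv (import path assumed):

    import torch
    from timm.optim.adafactor_bv import AdafactorBigVision  # import path assumed

    model = torch.nn.Linear(512, 512)

    # New default: eps=None, resolved per-tensor from the grad dtype at step time.
    opt = AdafactorBigVision(model.parameters())

    # The previous fixed default can still be requested explicitly:
    opt_old = AdafactorBigVision(model.parameters(), eps=1e-30)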
@@ -66,6 +66,7 @@ class AdafactorBigVision(Optimizer):
             else:
                 assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
                 momentum_dtype = torch.float32
+        # FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.
 
         defaults = dict(
             lr=lr,
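On the FIXME: torch indeed has no direct query for whether a dtype is usable on a given device. One hedged workaround (not part of this commit; helper name hypothetical) is a probe allocation:

    import torch

    def _dtype_usable(dtype: torch.dtype, device: torch.device) -> bool:
        # Hypothetical helper: probe support by attempting a tiny op on the device.
        try:
            x = torch.zeros(1, dtype=dtype, device=device)
            _ = x + x
            return True
        except (RuntimeError, TypeError):
            return False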
@@ -212,6 +213,9 @@ def _single_tensor_adafactor(
         exp_avg_sq = exp_avg_sqs[i]
         exp_avg = exp_avgs[i]
         step_t = state_steps[i]
+        if eps is None:
+            # use square of machine eps for grad dtype if not set
+            eps = torch.finfo(grad.dtype).eps ** 2
 
         # Update step
         step_t += 1
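With eps=None, epsilon1 now tracks the gradient dtype as the square of its machine epsilon, rather than the old fixed 1e-30. A quick check of what that resolves to (values from torch.finfo):

    import torch

    for dtype in (torch.float32, torch.bfloat16, torch.float16):
        print(dtype, torch.finfo(dtype).eps ** 2)
    # torch.float32  -> ~1.42e-14  (2**-46)
    # torch.bfloat16 -> ~6.10e-05  (2**-14)
    # torch.float16  -> ~9.54e-07  (2**-20)

For low-precision gradients the resolved value is many orders of magnitude larger than 1e-30, better matching the resolution of the dtype.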
@@ -219,6 +223,7 @@ def _single_tensor_adafactor(
         one_minus_beta2_t = 1 - beta2_t
 
         grad_sqr = torch.square(grad) + eps
+        # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
         if exp_avg_sq is None:
             # factorized second moment
             d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
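For context on the factorized branch: _factored_dims decides whether the second moment for a tensor can be stored as two rank-1 factors rather than in full. A sketch consistent with the optax/big vision helper this file mirrors (assumed, not verified against timm's actual implementation):

    from typing import Optional, Tuple
    import numpy as np

    def _factored_dims(
            shape: Tuple[int, ...],
            factored: bool,
            min_dim_size_to_factor: int,
    ) -> Optional[Tuple[int, int]]:
        # Factor only tensors with >= 2 dims whose second-largest dim is
        # at least min_dim_size_to_factor; returns the indices of the
        # (second-largest, largest) dims, else None for a full second moment.
        if not factored or len(shape) < 2:
            return None
        sorted_dims = np.argsort(shape)
        if shape[sorted_dims[-2]] < min_dim_size_to_factor:
            return None
        return int(sorted_dims[-2]), int(sorted_dims[-1])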