diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py
index ea8a4afa..d603b528 100644
--- a/timm/optim/adafactor_bv.py
+++ b/timm/optim/adafactor_bv.py
@@ -51,7 +51,7 @@ class AdafactorBigVision(Optimizer):
             beta2_cap: float = 0.999,
             momentum: Optional[float] = 0.9,
             momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
-            eps: float = 1e-30,
+            eps: Optional[float] = None,
             weight_decay: float = 0.0,
             clipping_threshold: Optional[float] = None,
             unscaled_wd: bool = False,
@@ -66,6 +66,7 @@ class AdafactorBigVision(Optimizer):
             else:
                 assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
                 momentum_dtype = torch.float32
+        # FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.
 
         defaults = dict(
             lr=lr,
@@ -212,6 +213,9 @@ def _single_tensor_adafactor(
         exp_avg_sq = exp_avg_sqs[i]
         exp_avg = exp_avgs[i]
         step_t = state_steps[i]
+        if eps is None:
+            # use square of machine eps for grad dtype if not set
+            eps = torch.finfo(grad.dtype).eps ** 2
 
         # Update step
         step_t += 1
@@ -219,6 +223,7 @@ def _single_tensor_adafactor(
         one_minus_beta2_t = 1 - beta2_t
 
         grad_sqr = torch.square(grad) + eps
+        # NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
         if exp_avg_sq is None:
             # factorized second moment
             d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
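
For context on the new default (not part of the diff): a minimal standalone sketch of what the `torch.finfo(grad.dtype).eps ** 2` fallback resolves to for common gradient dtypes; the values in the comments are approximate.

import torch

# Approximate machine-epsilon-squared fallback values, mirroring the
# `eps = torch.finfo(grad.dtype).eps ** 2` logic added in the diff above.
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    eps = torch.finfo(dtype).eps ** 2
    print(f"{dtype}: eps ~= {eps:.3e}")
# torch.float32:  ~1.4e-14  (1.19e-07 ** 2)
# torch.bfloat16: ~6.1e-05  (7.81e-03 ** 2)
# torch.float16:  ~9.5e-07  (9.77e-04 ** 2)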