mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Change adafactor_bv epsilon default from 1e-30 to None (falls back to the squared machine epsilon of the gradient dtype)
This commit is contained in:
parent
0b5ae49251
commit
7cfaeced67
@ -51,7 +51,7 @@ class AdafactorBigVision(Optimizer):
|
|||||||
beta2_cap: float = 0.999,
|
beta2_cap: float = 0.999,
|
||||||
momentum: Optional[float] = 0.9,
|
momentum: Optional[float] = 0.9,
|
||||||
momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
|
momentum_dtype: Union[str, torch.dtype] = torch.bfloat16,
|
||||||
eps: float = 1e-30,
|
eps: Optional[float] = None,
|
||||||
weight_decay: float = 0.0,
|
weight_decay: float = 0.0,
|
||||||
clipping_threshold: Optional[float] = None,
|
clipping_threshold: Optional[float] = None,
|
||||||
unscaled_wd: bool = False,
|
unscaled_wd: bool = False,
|
||||||
@ -66,6 +66,7 @@ class AdafactorBigVision(Optimizer):
|
|||||||
else:
|
else:
|
||||||
assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
|
assert momentum_dtype == 'float32', f'{momentum_dtype} dtype not supported'
|
||||||
momentum_dtype = torch.float32
|
momentum_dtype = torch.float32
|
||||||
|
# FIXME try to check if momentum dtype is appropriate for device? Torch API not great for this.
|
||||||
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr,
|
lr=lr,
|
||||||
@ -212,6 +213,9 @@ def _single_tensor_adafactor(
|
|||||||
exp_avg_sq = exp_avg_sqs[i]
|
exp_avg_sq = exp_avg_sqs[i]
|
||||||
exp_avg = exp_avgs[i]
|
exp_avg = exp_avgs[i]
|
||||||
step_t = state_steps[i]
|
step_t = state_steps[i]
|
||||||
|
if eps is None:
|
||||||
|
# use square of machine eps for grad dtype if not set
|
||||||
|
eps = torch.finfo(grad.dtype).eps ** 2
|
||||||
|
|
||||||
# Update step
|
# Update step
|
||||||
step_t += 1
|
step_t += 1
|
||||||
@ -219,6 +223,7 @@ def _single_tensor_adafactor(
|
|||||||
one_minus_beta2_t = 1 - beta2_t
|
one_minus_beta2_t = 1 - beta2_t
|
||||||
|
|
||||||
grad_sqr = torch.square(grad) + eps
|
grad_sqr = torch.square(grad) + eps
|
||||||
|
# NOTE application of eps (epsilon1) mirrors the optax/big vision/t5x approach
|
||||||
if exp_avg_sq is None:
|
if exp_avg_sq is None:
|
||||||
# factorized second moment
|
# factorized second moment
|
||||||
d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
|
d1, d0 = _factored_dims(grad.shape, True, min_dim_size_to_factor=min_dim_size_to_factor)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user