mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Missed extra nadam algo step for capturable path
This commit is contained in:
parent
4790c0fa16
commit
2d597b126d
@ -315,6 +315,11 @@ def _multi_tensor_nadamw(
|
||||
|
||||
bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)
|
||||
|
||||
# Only difference between NAdamW and AdamW in this implementation.
|
||||
# The official PyTorch implementation of NAdam uses a different algorithm.
|
||||
exp_avgs = torch._foreach_mul(exp_avgs, beta1)
|
||||
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
|
||||
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
|
||||
torch._foreach_div_(
|
||||
exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
|
||||
|
Loading…
x
Reference in New Issue
Block a user