Missed extra nadam algo step for capturable path

This commit is contained in:
Ross Wightman 2023-06-13 20:51:31 -07:00
parent 4790c0fa16
commit 2d597b126d

View File

@ -315,6 +315,11 @@ def _multi_tensor_nadamw(
bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)
# Only difference between NAdamW and AdamW in this implementation.
# The official PyTorch implementation of NAdam uses a different algorithm.
exp_avgs = torch._foreach_mul(exp_avgs, beta1)
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_(
exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)