From 2d597b126db04b44d96514c81756af87e8187ca8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 13 Jun 2023 20:51:31 -0700 Subject: [PATCH] Missed extra nadam algo step for capturable path --- timm/optim/nadamw.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/timm/optim/nadamw.py b/timm/optim/nadamw.py index 1360f2ca..c823f3d5 100644 --- a/timm/optim/nadamw.py +++ b/timm/optim/nadamw.py @@ -315,6 +315,11 @@ def _multi_tensor_nadamw( bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2) + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + exp_avgs = torch._foreach_mul(exp_avgs, beta1) + torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) + exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs) torch._foreach_div_( exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)