From 91f0ea33386c11eabaf8c1f22186e47ce8743ebe Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 4 Nov 2024 09:36:00 -0800 Subject: [PATCH] Need to init momentum with correct dtype --- timm/optim/adafactor_bv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py index 62736a33..58b18032 100644 --- a/timm/optim/adafactor_bv.py +++ b/timm/optim/adafactor_bv.py @@ -146,7 +146,7 @@ class AdafactorBigVision(Optimizer): state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format) if self.defaults['momentum'] is not None: - state['exp_avg'] = torch.zeros_like(p.grad, dtype=torch.bfloat16) + state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype']) state_steps.append(state['step']) exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))