diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py
index 62736a33..58b18032 100644
--- a/timm/optim/adafactor_bv.py
+++ b/timm/optim/adafactor_bv.py
@@ -146,7 +146,7 @@ class AdafactorBigVision(Optimizer):
                         state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
 
                     if self.defaults['momentum'] is not None:
-                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=torch.bfloat16)
+                        state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])
 
                 state_steps.append(state['step'])
                 exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
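
With this change the dtype of the momentum buffer (exp_avg) is no longer hardcoded to bfloat16 but read from the optimizer's defaults. A minimal usage sketch follows; it assumes AdafactorBigVision accepts a momentum_dtype constructor argument that the PyTorch Optimizer base class copies into self.defaults, and the lr/momentum values shown are illustrative only:

    import torch
    from timm.optim.adafactor_bv import AdafactorBigVision

    model = torch.nn.Linear(16, 16)

    # Keep the momentum (exp_avg) state in fp32 instead of the previously
    # hardcoded bfloat16; momentum_dtype is assumed to end up in self.defaults.
    optimizer = AdafactorBigVision(
        model.parameters(),
        lr=1e-3,
        momentum=0.9,
        momentum_dtype=torch.float32,
    )

    # One illustrative optimization step.
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    optimizer.step()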