Need to init momentum with correct dtype

This commit is contained in:
Ross Wightman 2024-11-04 09:36:00 -08:00 committed by Ross Wightman
parent 484a88f4b4
commit 19090ea966

View File

@ -146,7 +146,7 @@ class AdafactorBigVision(Optimizer):
state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
if self.defaults['momentum'] is not None:
state['exp_avg'] = torch.zeros_like(p.grad, dtype=torch.bfloat16)
state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])
state_steps.append(state['step'])
exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))