mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Need to init momentum with correct dtype
This commit is contained in:
parent
484a88f4b4
commit
19090ea966
@ -146,7 +146,7 @@ class AdafactorBigVision(Optimizer):
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
|
||||
|
||||
if self.defaults['momentum'] is not None:
|
||||
state['exp_avg'] = torch.zeros_like(p.grad, dtype=torch.bfloat16)
|
||||
state['exp_avg'] = torch.zeros_like(p.grad, dtype=self.defaults['momentum_dtype'])
|
||||
|
||||
state_steps.append(state['step'])
|
||||
exp_avg_sq_rs.append(state.get('exp_avg_sq_r', None))
|
||||
|
Loading…
x
Reference in New Issue
Block a user