Make LaProp weight decay match typical PyTorch 'decoupled' behaviour where it's scaled by LR

Ross Wightman 2024-11-29 16:44:43 -08:00
parent 886eb77938
commit 82e8677690


@@ -116,6 +116,6 @@ class LaProp(Optimizer):
                 p.add_(exp_avg, alpha=-step_size)
                 if group['weight_decay'] != 0:
-                    p.add_(p, alpha=-group['weight_decay'])
+                    p.add_(p, alpha=-(group['lr'] * group['weight_decay']))
         return loss
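
A minimal sketch (not part of the commit) comparing the two decay rules side by side. The tensor `p` and the `group` dict below are stand-ins for an optimizer parameter and its param group; `p_old` and `p_new` are hypothetical names used only for illustration.

import torch

p = torch.ones(4)
group = {'lr': 1e-3, 'weight_decay': 1e-2}

# Pre-commit behaviour: p <- p - wd * p, i.e. p *= (1 - wd).
# The decay strength ignores the learning rate entirely.
p_old = p.clone()
p_old.add_(p_old, alpha=-group['weight_decay'])

# Post-commit behaviour: p <- p - lr * wd * p, i.e. p *= (1 - lr * wd),
# matching the AdamW-style 'decoupled' weight decay typical in PyTorch.
p_new = p.clone()
p_new.add_(p_new, alpha=-(group['lr'] * group['weight_decay']))

print(p_old[0].item())  # ~0.99    (strong decay even at a tiny lr)
print(p_new[0].item())  # ~0.99999 (decay strength tracks the lr)

The practical upshot is that after this change a `weight_decay` value tuned for AdamW behaves comparably in LaProp, instead of decaying parameters orders of magnitude faster at small learning rates.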