mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
One more scalar -> tensor fix for lamb optimizer
parent 8f68193c91
commit 9541f4963b
@@ -98,7 +98,7 @@ class Lamb(Optimizer):
             and returns the loss.
         """
         device = self.param_groups[0]["params"][0].device
-        one_tensor = torch.tensor(1.0, device=device)
+        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

         loss = None
         if closure is not None:
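The comment added in this hunk gives the reason for keeping a one_tensor constant around: per the code comments in this commit, torch.where doesn't handle Python scalars correctly on the targeted PyTorch versions, so both value arguments should already be tensors on the parameter device. A minimal sketch of that pattern follows; trust_ratio and use_ratio are hypothetical stand-ins, not values taken from this diff.

import torch

# Stand-ins for quantities Lamb.step() would compute; values are illustrative.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
one_tensor = torch.tensor(1.0, device=device)   # a tensor, not the Python float 1.0
trust_ratio = torch.tensor(0.37, device=device)
use_ratio = torch.tensor(True, device=device)

# Both result branches are tensors on the same device, so torch.where
# behaves consistently across PyTorch versions.
ratio = torch.where(use_ratio, trust_ratio, one_tensor)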
@@ -115,7 +115,9 @@ class Lamb(Optimizer):
                 global_grad_norm.add_(grad.pow(2).sum())

         global_grad_norm = torch.sqrt(global_grad_norm)
-        max_grad_norm = self.defaults['max_grad_norm']
+        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
+        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
         clip_global_grad_norm = torch.where(
             global_grad_norm > max_grad_norm,
             global_grad_norm / max_grad_norm,
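The pattern being patched here is the global gradient-norm clipping factor computed with torch.where. Below is a minimal, self-contained sketch of that pattern under stated assumptions: the gradients and the max_grad_norm value are made up, and the final else-branch of torch.where is assumed to be the one_tensor defined earlier in the same method (the hunk ends before that line). It illustrates the scalar-to-tensor conversion, not the full Lamb.step().

import torch

device = torch.device('cpu')
one_tensor = torch.tensor(1.0, device=device)

# Made-up gradients standing in for p.grad across the parameter groups.
grads = [torch.randn(8, device=device) for _ in range(3)]

global_grad_norm = torch.zeros(1, device=device)
for grad in grads:
    global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm = torch.sqrt(global_grad_norm)

# Convert the scalar default to a tensor so torch.where only sees tensors
# (see https://github.com/pytorch/pytorch/issues/9190).
max_grad_norm = torch.tensor(1.0, device=device)  # stand-in for self.defaults['max_grad_norm']

clip_global_grad_norm = torch.where(
    global_grad_norm > max_grad_norm,
    global_grad_norm / max_grad_norm,
    one_tensor)  # assumed else-branch: no scaling when the norm is within bounds

# Each gradient would then be divided by clip_global_grad_norm before the update.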