Add lamb/lars to optim init imports, remove stray comment
parent c207e02782
commit a16a753852
@@ -1,7 +1,10 @@
-from .adamp import AdamP
-from .adamw import AdamW
+from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adahessian import Adahessian
+from .adamp import AdamP
+from .adamw import AdamW
+from .lamb import Lamb
+from .lars import Lars
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
@@ -9,5 +12,4 @@ from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
-from .adabelief import AdaBelief
 from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
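With Lamb and Lars now re-exported from the optim package's __init__, both can be imported directly alongside the other optimizers. A minimal usage sketch, assuming the package is timm; the model and hyperparameter values below are placeholders, not recommendations:

import torch
from timm.optim import Lamb, Lars  # exports added by this commit

model = torch.nn.Linear(10, 2)  # placeholder model

# Illustrative settings only; construct either optimizer.
optimizer = Lars(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-5)
# optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)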
@@ -87,7 +87,6 @@ class Lars(Optimizer):
         device = self.param_groups[0]['params'][0].device
         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
 
-        # exclude scaling for params with 0 weight decay
         for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
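For context on the one_tensor line kept above: torch.where has historically required both value arguments to be tensors, so the optimizer pre-builds a 1.0 tensor on the parameters' device instead of passing the Python scalar 1.0. A simplified sketch of the trust-ratio computation that uses it, not the full step() logic (weight-decay handling and per-layer norm computation are omitted):

import torch

trust_coeff = 0.001
one_tensor = torch.tensor(1.0)  # reusable tensor constant for torch.where
w_norm = torch.tensor(3.0)      # stand-in for a parameter's L2 norm
g_norm = torch.tensor(0.5)      # stand-in for its gradient's L2 norm

# Fall back to a ratio of 1.0 whenever either norm is zero.
trust_ratio = torch.where(
    w_norm > 0,
    torch.where(g_norm > 0, trust_coeff * w_norm / g_norm, one_tensor),
    one_tensor,
)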