Add lamb/lars to optim init imports, remove stray comment
parent c207e02782
commit a16a753852
@@ -1,7 +1,10 @@
-from .adamp import AdamP
-from .adamw import AdamW
+from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adahessian import Adahessian
+from .adamp import AdamP
+from .adamw import AdamW
+from .lamb import Lamb
+from .lars import Lars
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
@@ -9,5 +12,4 @@ from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
-from .adabelief import AdaBelief
 from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
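The two hunks above reorder the package __init__ alphabetically and add Lamb and Lars to its exports alongside the existing optimizers. The import names strongly suggest this is the optim subpackage of timm (pytorch-image-models); assuming that package path, the newly exported classes can be constructed directly. A minimal usage sketch with placeholder hyperparameters:

    import torch
    from timm.optim import Lamb, Lars  # package path assumed; these exports are what this commit adds

    model = torch.nn.Linear(10, 2)
    optimizer = Lars(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    # optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)  # Lamb is constructed the same way

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()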
@@ -87,7 +87,6 @@ class Lars(Optimizer):
         device = self.param_groups[0]['params'][0].device
         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly

-        # exclude scaling for params with 0 weight decay
         for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
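The hunk above drops the stray "# exclude scaling for params with 0 weight decay" comment while keeping the one_tensor line, whose trailing comment notes that torch.where historically required tensor arguments rather than Python scalars. As an illustration of why that pre-built tensor matters, here is a sketch of a LARS-style trust-ratio computation built on torch.where, with made-up norm values and a hypothetical trust_coeff; it is not the exact code from this file:

    import torch

    device = torch.device('cpu')
    one_tensor = torch.tensor(1.0, device=device)  # a tensor, not the Python scalar 1.0
    trust_coeff = 0.001                            # hypothetical value, for illustration only

    w_norm = torch.tensor(2.0, device=device)      # stand-in for a parameter norm
    g_norm = torch.tensor(4.0, device=device)      # stand-in for a gradient norm

    # Fall back to 1.0 whenever either norm is zero; both branches must be tensors
    # on older PyTorch versions, hence one_tensor instead of the literal 1.0.
    trust_ratio = torch.where(
        w_norm > 0,
        torch.where(g_norm > 0, trust_coeff * w_norm / g_norm, one_tensor),
        one_tensor,
    )
    print(trust_ratio)  # tensor(0.0005)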