Missing optimizers in __init__.py, add bind_defaults=False for unit tests
parent
d0161f303a
commit
0b5264a108
|
@ -294,7 +294,7 @@ def _build_params_dict_single(weight, bias, **kwargs):
|
|||
|
||||
@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
|
||||
def test_optim_factory(optimizer):
|
||||
assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)
|
||||
assert issubclass(get_optimizer_class(optimizer, bind_defaults=False), torch.optim.Optimizer)
|
||||
|
||||
opt_info = get_optimizer_info(optimizer)
|
||||
assert isinstance(opt_info, OptimInfo)
|
||||
|
|
|
@ -12,10 +12,12 @@ from .lion import Lion
|
|||
from .lookahead import Lookahead
|
||||
from .madgrad import MADGRAD
|
||||
from .nadam import Nadam
|
||||
from .nadamw import NAdamW
|
||||
from .nvnovograd import NvNovoGrad
|
||||
from .radam import RAdam
|
||||
from .rmsprop_tf import RMSpropTF
|
||||
from .sgdp import SGDP
|
||||
from .sgdw import SGDW
|
||||
|
||||
from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
|
||||
create_optimizer_v2, create_optimizer, optimizer_kwargs
|
||||
|
|
Loading…
Reference in New Issue