More fixes for new factory & tests, add back adahessian
parent 45490ac52f, commit dde990785e
changed paths: hfdocs/source/reference, tests
@@ -18,10 +18,11 @@ This page contains the API reference documentation for learning rate optimizers
+[[autodoc]] timm.optim.adahessian.Adahessian
 [[autodoc]] timm.optim.adamp.AdamP
 [[autodoc]] timm.optim.adamw.AdamW
 [[autodoc]] timm.optim.adan.Adan
 [[autodoc]] timm.optim.adopt.Adopt
 [[autodoc]] timm.optim.lamb.Lamb
 [[autodoc]] timm.optim.lars.Lars
-[[autodoc]] timm.optim.lion,Lion
+[[autodoc]] timm.optim.lion.Lion
 [[autodoc]] timm.optim.lookahead.Lookahead
 [[autodoc]] timm.optim.madgrad.MADGRAD
 [[autodoc]] timm.optim.nadam.Nadam
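The `[[autodoc]]` directives resolve a dotted `module.Class` import path at docs build time, which is presumably why the comma in `timm.optim.lion,Lion` is corrected to a dot here. A small illustrative check, not part of the docs build, that the newly added target resolves as such a path:

```python
# Illustrative only: the [[autodoc]] entry above should map to an importable
# module.Class pair; a comma instead of a dot would break that mapping.
from timm.optim.adahessian import Adahessian

assert Adahessian.__module__ == 'timm.optim.adahessian'
assert Adahessian.__name__ == 'Adahessian'
```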

@@ -12,7 +12,7 @@ import torch
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter

-from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
+from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
 from timm.optim import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler

@@ -294,28 +294,32 @@ def _build_params_dict_single(weight, bias, **kwargs):
 @pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
 def test_optim_factory(optimizer):
-    get_optimizer_class(optimizer)
+    assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)

-    # test basic cases that don't need specific tuning via factory test
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-    )
+    opt_info = get_optimizer_info(optimizer)
+    assert isinstance(opt_info, OptimInfo)
+
+    if not opt_info.second_order:  # basic tests don't support second order right now
+        # test basic cases that don't need specific tuning via factory test
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
+        )


 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
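The new `second_order` gate exists because `_test_basic_cases` only drives a plain `loss.backward()`, while a second-order optimizer such as Adahessian needs the autograd graph kept alive so `step()` can take Hessian-vector products. A minimal sketch of what a dedicated second-order smoke test could look like; the helper name `_test_second_order_step` is hypothetical and not part of this commit:

```python
import torch
from timm.optim import create_optimizer_v2


def _test_second_order_step(optimizer_name: str = 'adahessian'):
    # Hypothetical helper: one optimization step for a second-order optimizer.
    weight = torch.randn(10, 5, requires_grad=True)
    bias = torch.randn(10, requires_grad=True)
    optimizer = create_optimizer_v2([weight, bias], optimizer_name, lr=1e-3)

    x = torch.randn(5)
    loss = (weight.mv(x) + bias).pow(2).sum()
    before = weight.detach().clone()

    optimizer.zero_grad()
    # create_graph=True keeps the graph so step() can run a second backward
    # pass (e.g. a Hutchinson-style Hessian trace estimate).
    loss.backward(create_graph=True)
    optimizer.step()

    # the parameter should have moved after one step
    assert not torch.equal(before, weight.detach())
```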

@@ -17,6 +17,6 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP

-from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
-    create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
+    create_optimizer_v2, create_optimizer, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers

@@ -13,10 +13,11 @@ import torch
 import torch.nn as nn
 import torch.optim as optim

-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
+from .adahessian import Adahessian
 from .adamp import AdamP
 from .adan import Adan
 from .adopt import Adopt

@@ -78,6 +79,7 @@ class OptimInfo:
     has_momentum: bool = False
     has_betas: bool = False
     num_betas: int = 2
+    second_order: bool = False
     defaults: Optional[Dict[str, Any]] = None

@@ -540,6 +542,13 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             num_betas=3
         ),
+        OptimInfo(
+            name='adahessian',
+            opt_class=Adahessian,
+            description='An Adaptive Second Order Optimizer',
+            has_betas=True,
+            second_order=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
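The registry entry above is just an `OptimInfo` handed to an `OptimizerRegistry`. A hedged sketch of registering a custom entry in a standalone registry; it assumes `OptimizerRegistry.register()` accepts an `OptimInfo`, mirroring how the built-in `_register_*` helpers populate the default registry:

```python
from timm.optim import OptimInfo, OptimizerRegistry
from timm.optim.adahessian import Adahessian

# Standalone registry, populated the same way the built-in helpers do
# (assumes a register(info) method on OptimizerRegistry).
registry = OptimizerRegistry()
registry.register(OptimInfo(
    name='adahessian-custom',   # hypothetical alias, for illustration only
    opt_class=Adahessian,
    description='Adahessian registered under a custom name',
    has_betas=True,
    second_order=True,          # flags the double-backward requirement
))
```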

@@ -770,6 +779,21 @@ def list_optimizers(
     return default_registry.list_optimizers(filter, exclude_filters, with_description)


+def get_optimizer_info(name: str) -> OptimInfo:
+    """Get the OptimInfo for an optimizer.
+
+    Args:
+        name: Name of the optimizer
+
+    Returns:
+        OptimInfo configuration
+
+    Raises:
+        ValueError: If optimizer is not found
+    """
+    return default_registry.get_optimizer_info(name)
+
+
 def get_optimizer_class(
     name: str,
     bind_defaults: bool = False,
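`get_optimizer_info()` exposes the registry metadata without instantiating anything, which is what the updated factory test uses to skip second-order optimizers. A small usage sketch; the names queried are illustrative:

```python
from timm.optim import get_optimizer_info, list_optimizers

info = get_optimizer_info('adahessian')
print(info.opt_class.__name__, info.has_betas, info.second_order)

# e.g. enumerate every registered optimizer that needs backward(create_graph=True)
second_order_opts = [
    name for name in list_optimizers()
    if get_optimizer_info(name).second_order
]
print(second_order_opts)
```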

@@ -1,6 +1,6 @@
 import logging
 from itertools import islice
-from typing import Collection, Optional, Tuple
+from typing import Collection, Optional

 from torch import nn as nn

@@ -37,7 +37,7 @@ def _group(it, size):
     return iter(lambda: tuple(islice(it, size)), ())


-def _layer_map(model, layers_per_group=12, num_groups=None):
+def auto_group_layers(model, layers_per_group=12, num_groups=None):
     def _in_head(n, hp):
         if not hp:
             return True

@@ -63,6 +63,8 @@ def _layer_map(model, layers_per_group=12, num_groups=None):
     layer_map.update({n: num_trunk_groups for n in names_head})
     return layer_map

+_layer_map = auto_group_layers  # backward compat
+

 def param_groups_layer_decay(
     model: nn.Module,
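`auto_group_layers()` is the new public name for the helper that buckets parameters into layer-decay groups; `_layer_map` remains as an alias so existing imports keep working. A hedged usage sketch, assuming a locally available timm model definition such as 'resnet18' (no pretrained weights needed):

```python
import timm
from timm.optim import auto_group_layers, param_groups_layer_decay

model = timm.create_model('resnet18', pretrained=False)

# Map parameter names to group indices for layer-wise LR decay.
layer_map = auto_group_layers(model, layers_per_group=12)
print(max(layer_map.values()) + 1, 'groups')

# The old private name still resolves through the backward-compat alias.
from timm.optim._param_groups import _layer_map
assert _layer_map is auto_group_layers

# Typical downstream use: per-layer decayed LR scales for the optimizer.
param_groups = param_groups_layer_decay(model, weight_decay=0.05, layer_decay=0.75)
```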

@@ -86,7 +88,7 @@ def param_groups_layer_decay(
         layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
     else:
         # fallback
-        layer_map = _layer_map(model)
+        layer_map = auto_group_layers(model)
     num_layers = max(layer_map.values()) + 1
     layer_max = num_layers - 1
     layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))

@@ -0,0 +1,7 @@
+# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/
+
+from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group
+
+import warnings
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
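The new `timm/optim/optim_factory.py` is a thin compatibility shim: the legacy `import timm.optim.optim_factory` path keeps working but now emits a `FutureWarning`. A small sketch of what callers will observe; note the warning only fires the first time the module is imported in a process:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # Legacy access path, still supported via the shim above.
    import timm.optim.optim_factory as optim_factory

# First import in the process triggers the deprecation notice.
assert any(issubclass(w.category, FutureWarning) for w in caught)

# The re-exported helpers still resolve through the old module path.
opt_fn = optim_factory.create_optimizer_v2
groups_fn = optim_factory.param_groups_weight_decay
```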