From dde990785e3afef2cdbbc078baf8b894761f82e6 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 12 Nov 2024 17:19:49 -0800
Subject: [PATCH] More fixes for new factory & tests, add back adahessian

---
 hfdocs/source/reference/optimizers.mdx |  3 +-
 tests/test_optim.py                    | 48 ++++++++++++++------------
 timm/optim/__init__.py                 |  6 ++--
 timm/optim/_optim_factory.py           | 26 +++++++++++++-
 timm/optim/_param_groups.py            |  8 +++--
 timm/optim/optim_factory.py            |  7 ++++
 6 files changed, 68 insertions(+), 30 deletions(-)
 create mode 100644 timm/optim/optim_factory.py

diff --git a/hfdocs/source/reference/optimizers.mdx b/hfdocs/source/reference/optimizers.mdx
index 212152fb..66f32ca2 100644
--- a/hfdocs/source/reference/optimizers.mdx
+++ b/hfdocs/source/reference/optimizers.mdx
@@ -18,10 +18,11 @@ This page contains the API reference documentation for learning rate optimizers
 [[autodoc]] timm.optim.adahessian.Adahessian
 [[autodoc]] timm.optim.adamp.AdamP
 [[autodoc]] timm.optim.adamw.AdamW
+[[autodoc]] timm.optim.adan.Adan
 [[autodoc]] timm.optim.adopt.Adopt
 [[autodoc]] timm.optim.lamb.Lamb
 [[autodoc]] timm.optim.lars.Lars
-[[autodoc]] timm.optim.lion,Lion
+[[autodoc]] timm.optim.lion.Lion
 [[autodoc]] timm.optim.lookahead.Lookahead
 [[autodoc]] timm.optim.madgrad.MADGRAD
 [[autodoc]] timm.optim.nadam.Nadam
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 26c09a50..d70ec98d 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -12,7 +12,7 @@ import torch
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter
 
-from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
+from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
 from timm.optim import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler
 
@@ -294,28 +294,32 @@ def _build_params_dict_single(weight, bias, **kwargs):
 
 @pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
 def test_optim_factory(optimizer):
-    get_optimizer_class(optimizer)
+    assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)
 
-    # test basic cases that don't need specific tuning via factory test
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-    )
+    opt_info = get_optimizer_info(optimizer)
+    assert isinstance(opt_info, OptimInfo)
+
+    if not opt_info.second_order:  # basic tests don't support second order right now
+        # test basic cases that don't need specific tuning via factory test
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
+        )
 
 
 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py
index 53921b76..552585c9 100644
--- a/timm/optim/__init__.py
+++ b/timm/optim/__init__.py
@@ -17,6 +17,6 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 
-from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
-    create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
\ No newline at end of file
+from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
+    create_optimizer_v2, create_optimizer, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py
index 0ea20b6e..b3759a37 100644
--- a/timm/optim/_optim_factory.py
+++ b/timm/optim/_optim_factory.py
@@ -13,10 +13,11 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
+from .adahessian import Adahessian
 from .adamp import AdamP
 from .adan import Adan
 from .adopt import Adopt
@@ -78,6 +79,7 @@ class OptimInfo:
     has_momentum: bool = False
     has_betas: bool = False
     num_betas: int = 2
+    second_order: bool = False
     defaults: Optional[Dict[str, Any]] = None
 
 
@@ -540,6 +542,13 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             num_betas=3
         ),
+        OptimInfo(
+            name='adahessian',
+            opt_class=Adahessian,
+            description='An Adaptive Second Order Optimizer',
+            has_betas=True,
+            second_order=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
@@ -770,6 +779,21 @@ def list_optimizers(
     return default_registry.list_optimizers(filter, exclude_filters, with_description)
 
 
+def get_optimizer_info(name: str) -> OptimInfo:
+    """Get the OptimInfo for an optimizer.
+
+    Args:
+        name: Name of the optimizer
+
+    Returns:
+        OptimInfo configuration
+
+    Raises:
+        ValueError: If optimizer is not found
+    """
+    return default_registry.get_optimizer_info(name)
+
+
 def get_optimizer_class(
         name: str,
         bind_defaults: bool = False,
diff --git a/timm/optim/_param_groups.py b/timm/optim/_param_groups.py
index ef9faacf..a756c5e0 100644
--- a/timm/optim/_param_groups.py
+++ b/timm/optim/_param_groups.py
@@ -1,6 +1,6 @@
 import logging
 from itertools import islice
-from typing import Collection, Optional, Tuple
+from typing import Collection, Optional
 
 from torch import nn as nn
 
@@ -37,7 +37,7 @@ def _group(it, size):
     return iter(lambda: tuple(islice(it, size)), ())
 
 
-def _layer_map(model, layers_per_group=12, num_groups=None):
+def auto_group_layers(model, layers_per_group=12, num_groups=None):
     def _in_head(n, hp):
         if not hp:
             return True
@@ -63,6 +63,8 @@ def _layer_map(model, layers_per_group=12, num_groups=None):
     layer_map.update({n: num_trunk_groups for n in names_head})
     return layer_map
 
+_layer_map = auto_group_layers  # backward compat
+
 
 def param_groups_layer_decay(
         model: nn.Module,
@@ -86,7 +88,7 @@ def param_groups_layer_decay(
         layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
     else:
         # fallback
-        layer_map = _layer_map(model)
+        layer_map = auto_group_layers(model)
     num_layers = max(layer_map.values()) + 1
     layer_max = num_layers - 1
     layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
new file mode 100644
index 00000000..a4227a98
--- /dev/null
+++ b/timm/optim/optim_factory.py
@@ -0,0 +1,7 @@
+# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/
+
+from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group
+
+import warnings
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
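
Usage sketch (not part of the patch, added for review context): the snippet below exercises the factory APIs this commit touches: list_optimizers, get_optimizer_info / OptimInfo (including the new second_order flag), get_optimizer_class, and create_optimizer_v2. The toy model and the lr / weight_decay values are illustrative placeholders, not values taken from the patch.

import torch
import torch.nn as nn

from timm.optim import (
    OptimInfo,
    create_optimizer_v2,
    get_optimizer_class,
    get_optimizer_info,
    list_optimizers,
)

# Toy model, only for illustration.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

# Registry queries added/extended by this patch.
print(list_optimizers(exclude_filters=('fused*', 'bnb*'))[:5])
info = get_optimizer_info('adahessian')
assert isinstance(info, OptimInfo)
assert info.second_order  # new flag, set for adahessian in this commit
assert issubclass(get_optimizer_class('adahessian'), torch.optim.Optimizer)

# Build an optimizer through the factory; second-order optimizers such as
# adahessian need loss.backward(create_graph=True) in the training loop.
optimizer = create_optimizer_v2(model, 'adahessian', lr=1e-3, weight_decay=0.05)

# The old import path still resolves via the new shim module, but warns:
# import timm.optim.optim_factory as optim_factory  # emits FutureWarning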