More fixes for new factory & tests, add back adahessian

Ross Wightman 2024-11-12 17:19:49 -08:00
parent 5dae91812d
commit 0e6da65c95
6 changed files with 68 additions and 30 deletions

View File

@@ -18,10 +18,11 @@ This page contains the API reference documentation for learning rate optimizers
 [[autodoc]] timm.optim.adahessian.Adahessian
 [[autodoc]] timm.optim.adamp.AdamP
 [[autodoc]] timm.optim.adamw.AdamW
+[[autodoc]] timm.optim.adan.Adan
 [[autodoc]] timm.optim.adopt.Adopt
 [[autodoc]] timm.optim.lamb.Lamb
 [[autodoc]] timm.optim.lars.Lars
-[[autodoc]] timm.optim.lion,Lion
+[[autodoc]] timm.optim.lion.Lion
 [[autodoc]] timm.optim.lookahead.Lookahead
 [[autodoc]] timm.optim.madgrad.MADGRAD
 [[autodoc]] timm.optim.nadam.Nadam

View File

@@ -12,7 +12,7 @@ import torch
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter
-from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
+from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
 from timm.optim import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler
@@ -294,8 +294,12 @@ def _build_params_dict_single(weight, bias, **kwargs):
 @pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
 def test_optim_factory(optimizer):
-    get_optimizer_class(optimizer)
-    # test basic cases that don't need specific tuning via factory test
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+    assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)
+    opt_info = get_optimizer_info(optimizer)
+    assert isinstance(opt_info, OptimInfo)
+    if not opt_info.second_order:  # basic tests don't support second order right now
+        # test basic cases that don't need specific tuning via factory test
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
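
For quick reference, a standalone sketch of what the updated test now exercises, runnable outside pytest; the optimizer name 'adamw', the toy Linear model, and the learning rate are illustrative choices, not part of the commit:

import torch
from timm.optim import OptimInfo, create_optimizer_v2, get_optimizer_class, get_optimizer_info

opt_name = 'adamw'  # any name returned by list_optimizers() works here

# the factory exposes both the raw optimizer class and its registry metadata
assert issubclass(get_optimizer_class(opt_name), torch.optim.Optimizer)
opt_info = get_optimizer_info(opt_name)
assert isinstance(opt_info, OptimInfo)

if not opt_info.second_order:
    # first-order optimizers can be constructed directly through the factory
    opt = create_optimizer_v2(torch.nn.Linear(8, 2), opt_name, lr=1e-3)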

View File

@@ -17,6 +17,6 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
-from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
-    create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
+    create_optimizer_v2, create_optimizer, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
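
A minimal smoke test of the re-exported names after this change, assuming a timm build that includes this commit; the 'ada*' filter and 'adopt' lookup are arbitrary examples:

from timm.optim import (
    OptimInfo,
    OptimizerRegistry,
    auto_group_layers,
    get_optimizer_info,
    list_optimizers,
)

print(list_optimizers('ada*'))      # filtered names, e.g. should include 'adahessian'
print(get_optimizer_info('adopt'))  # an OptimInfo dataclass instance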

View File

@@ -13,10 +13,11 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
+from .adahessian import Adahessian
 from .adamp import AdamP
 from .adan import Adan
 from .adopt import Adopt
@@ -78,6 +79,7 @@ class OptimInfo:
     has_momentum: bool = False
     has_betas: bool = False
     num_betas: int = 2
+    second_order: bool = False
     defaults: Optional[Dict[str, Any]] = None
@@ -540,6 +542,13 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             num_betas=3
         ),
+        OptimInfo(
+            name='adahessian',
+            opt_class=Adahessian,
+            description='An Adaptive Second Order Optimizer',
+            has_betas=True,
+            second_order=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
@@ -770,6 +779,21 @@ def list_optimizers(
     return default_registry.list_optimizers(filter, exclude_filters, with_description)
 
 
+def get_optimizer_info(name: str) -> OptimInfo:
+    """Get the OptimInfo for an optimizer.
+
+    Args:
+        name: Name of the optimizer
+
+    Returns:
+        OptimInfo configuration
+
+    Raises:
+        ValueError: If optimizer is not found
+    """
+    return default_registry.get_optimizer_info(name)
+
+
 def get_optimizer_class(
     name: str,
     bind_defaults: bool = False,
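
A sketch of how the restored adahessian registration and the new second_order flag surface to callers; the toy model, learning rate, and training step below are illustrative, not taken from the commit:

import torch
from timm.optim import create_optimizer_v2, get_optimizer_info

info = get_optimizer_info('adahessian')
assert info.second_order  # signals the optimizer needs gradients taken with create_graph=True

model = torch.nn.Linear(16, 4)
optimizer = create_optimizer_v2(model, 'adahessian', lr=1e-2)

loss = model(torch.randn(2, 16)).sum()
loss.backward(create_graph=True)  # keep the graph so Adahessian can estimate Hessian diagonals
optimizer.step()
optimizer.zero_grad()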

View File

@@ -1,6 +1,6 @@
 import logging
 from itertools import islice
-from typing import Collection, Optional, Tuple
+from typing import Collection, Optional
 from torch import nn as nn
@@ -37,7 +37,7 @@ def _group(it, size):
     return iter(lambda: tuple(islice(it, size)), ())
 
 
-def _layer_map(model, layers_per_group=12, num_groups=None):
+def auto_group_layers(model, layers_per_group=12, num_groups=None):
     def _in_head(n, hp):
         if not hp:
             return True
@@ -63,6 +63,8 @@ def _layer_map(model, layers_per_group=12, num_groups=None):
         layer_map.update({n: num_trunk_groups for n in names_head})
     return layer_map
 
 
+_layer_map = auto_group_layers  # backward compat
+
 def param_groups_layer_decay(
     model: nn.Module,
@@ -86,7 +88,7 @@ def param_groups_layer_decay(
         layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
     else:
         # fallback
-        layer_map = _layer_map(model)
+        layer_map = auto_group_layers(model)
     num_layers = max(layer_map.values()) + 1
     layer_max = num_layers - 1
     layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
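
A brief usage sketch for the renamed helper; the model choice and decay values are arbitrary examples:

import timm
from timm.optim import auto_group_layers, param_groups_layer_decay

model = timm.create_model('resnet18', pretrained=False)

# map parameter names to layer-group indices (the old _layer_map alias still works)
layer_map = auto_group_layers(model, layers_per_group=4)

# param_groups_layer_decay falls back to auto_group_layers when a model has no group_matcher
param_groups = param_groups_layer_decay(model, weight_decay=0.05, layer_decay=0.75)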

View File

@@ -0,0 +1,7 @@
+# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/
+from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group
+
+import warnings
+
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
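
For existing code the practical effect looks like the sketch below: the old module path still resolves through the shim but now emits a FutureWarning on first import; the toy model and hyperparameters are illustrative:

import torch
import timm.optim.optim_factory as optim_factory  # old-style import, triggers the FutureWarning once

# the shim still re-exports the legacy factory helpers
model = torch.nn.Linear(4, 2)
optimizer = optim_factory.create_optimizer_v2(model, 'sgd', lr=0.1, momentum=0.9)
print(type(optimizer).__name__)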