More fixes for new factory & tests, add back adahessian

small_384_weights
Ross Wightman 2024-11-12 17:19:49 -08:00 committed by Ross Wightman
parent 45490ac52f
commit dde990785e
6 changed files with 68 additions and 30 deletions

View File

@@ -18,10 +18,11 @@ This page contains the API reference documentation for learning rate optimizers
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.adan.Adan
[[autodoc]] timm.optim.adopt.Adopt
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lion,Lion
[[autodoc]] timm.optim.lion.Lion
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam

View File

@@ -12,7 +12,7 @@ import torch
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter
from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
from timm.optim import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler
@@ -294,28 +294,32 @@ def _build_params_dict_single(weight, bias, **kwargs):
@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
def test_optim_factory(optimizer):
get_optimizer_class(optimizer)
assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)
# test basic cases that don't need specific tuning via factory test
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
)
opt_info = get_optimizer_info(optimizer)
assert isinstance(opt_info, OptimInfo)
if not opt_info.second_order: # basic tests don't support second order right now
# test basic cases that don't need specific tuning via factory test
_test_basic_cases(
lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2),
optimizer,
lr=1e-3)
)
_test_basic_cases(
lambda weight, bias: create_optimizer_v2(
_build_params_dict_single(weight, bias, lr=1e-2), optimizer)
)
#@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
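
Note: the reworked test gates the generic factory cases on opt_info.second_order, since second-order optimizers such as Adahessian need the backward pass run with create_graph=True before step(). A minimal sketch of that pattern outside the test harness, using the accessors added in this commit (the model and hyperparameters are placeholders):

import torch
from timm.optim import create_optimizer_v2, get_optimizer_info

model = torch.nn.Linear(10, 2)
info = get_optimizer_info('adahessian')                  # OptimInfo with second_order=True
opt = create_optimizer_v2(model, 'adahessian', lr=1e-3)

loss = model(torch.randn(4, 10)).sum()
# keep the graph so step() can take Hessian-vector products; first-order
# optimizers would pass create_graph=False here
loss.backward(create_graph=info.second_order)
opt.step()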

View File

@@ -17,6 +17,6 @@ from .radam import RAdam
from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP
from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
create_optimizer_v2, create_optimizer, optimizer_kwargs
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers

View File

@@ -13,10 +13,11 @@ import torch
import torch.nn as nn
import torch.optim as optim
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
from .adahessian import Adahessian
from .adamp import AdamP
from .adan import Adan
from .adopt import Adopt
@@ -78,6 +79,7 @@ class OptimInfo:
has_momentum: bool = False
has_betas: bool = False
num_betas: int = 2
second_order: bool = False
defaults: Optional[Dict[str, Any]] = None
@@ -540,6 +542,13 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
has_betas=True,
num_betas=3
),
OptimInfo(
name='adahessian',
opt_class=Adahessian,
description='An Adaptive Second Order Optimizer',
has_betas=True,
second_order=True,
),
OptimInfo(
name='lion',
opt_class=Lion,
@@ -770,6 +779,21 @@ def list_optimizers(
return default_registry.list_optimizers(filter, exclude_filters, with_description)
def get_optimizer_info(name: str) -> OptimInfo:
"""Get the OptimInfo for an optimizer.
Args:
name: Name of the optimizer
Returns:
OptimInfo configuration
Raises:
ValueError: If optimizer is not found
"""
return default_registry.get_optimizer_info(name)
def get_optimizer_class(
name: str,
bind_defaults: bool = False,

View File

@@ -1,6 +1,6 @@
import logging
from itertools import islice
from typing import Collection, Optional, Tuple
from typing import Collection, Optional
from torch import nn as nn
@@ -37,7 +37,7 @@ def _group(it, size):
return iter(lambda: tuple(islice(it, size)), ())
def _layer_map(model, layers_per_group=12, num_groups=None):
def auto_group_layers(model, layers_per_group=12, num_groups=None):
def _in_head(n, hp):
if not hp:
return True
@@ -63,6 +63,8 @@ def _layer_map(model, layers_per_group=12, num_groups=None):
layer_map.update({n: num_trunk_groups for n in names_head})
return layer_map
_layer_map = auto_group_layers # backward compat
def param_groups_layer_decay(
model: nn.Module,
@@ -86,7 +88,7 @@ def param_groups_layer_decay(
layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
else:
# fallback
layer_map = _layer_map(model)
layer_map = auto_group_layers(model)
num_layers = max(layer_map.values()) + 1
layer_max = num_layers - 1
layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
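
The rename keeps _layer_map as a backward-compat alias, so existing callers keep working. A minimal sketch of the new public path (model choice and hyperparameters are arbitrary placeholders):

import timm
from timm.optim import create_optimizer_v2, param_groups_layer_decay, auto_group_layers

model = timm.create_model('resnet18', pretrained=False)
layer_map = auto_group_layers(model, layers_per_group=12)   # param name -> group index
print(max(layer_map.values()) + 1, 'layer groups')

# layer-wise lr decay param groups fed straight into the optimizer factory
param_groups = param_groups_layer_decay(model, weight_decay=0.05, layer_decay=0.75)
opt = create_optimizer_v2(param_groups, 'adamw', lr=1e-3)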

View File

@@ -0,0 +1,7 @@
# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/
from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group
import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
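
The shim keeps legacy module-level imports working while nudging users toward the package root. Roughly, with any nn.Module as a stand-in:

import torch
model = torch.nn.Linear(8, 2)

# legacy path: still importable, but now emits a FutureWarning
import timm.optim.optim_factory as optim_factory
pg = optim_factory.param_groups_weight_decay(model, weight_decay=1e-4)

# preferred path going forward
from timm.optim import param_groups_weight_decay
pg = param_groups_weight_decay(model, weight_decay=1e-4)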