Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Commit 0e6da65c95 (parent 5dae91812d): More fixes for new factory & tests, add back adahessian
@@ -18,10 +18,11 @@ This page contains the API reference documentation for learning rate optimizers
 [[autodoc]] timm.optim.adahessian.Adahessian
 [[autodoc]] timm.optim.adamp.AdamP
 [[autodoc]] timm.optim.adamw.AdamW
+[[autodoc]] timm.optim.adan.Adan
 [[autodoc]] timm.optim.adopt.Adopt
 [[autodoc]] timm.optim.lamb.Lamb
 [[autodoc]] timm.optim.lars.Lars
-[[autodoc]] timm.optim.lion,Lion
+[[autodoc]] timm.optim.lion.Lion
 [[autodoc]] timm.optim.lookahead.Lookahead
 [[autodoc]] timm.optim.madgrad.MADGRAD
 [[autodoc]] timm.optim.nadam.Nadam
@@ -12,7 +12,7 @@ import torch
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter
 
-from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
+from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo
 from timm.optim import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler
 
@@ -294,28 +294,32 @@ def _build_params_dict_single(weight, bias, **kwargs):
 
 @pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
 def test_optim_factory(optimizer):
-    get_optimizer_class(optimizer)
+    assert issubclass(get_optimizer_class(optimizer), torch.optim.Optimizer)
 
-    # test basic cases that don't need specific tuning via factory test
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2),
-            optimizer,
-            lr=1e-3)
-    )
-    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2(
-            _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
-    )
+    opt_info = get_optimizer_info(optimizer)
+    assert isinstance(opt_info, OptimInfo)
+
+    if not opt_info.second_order:  # basic tests don't support second order right now
+        # test basic cases that don't need specific tuning via factory test
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2),
+                optimizer,
+                lr=1e-3)
+        )
+        _test_basic_cases(
+            lambda weight, bias: create_optimizer_v2(
+                _build_params_dict_single(weight, bias, lr=1e-2), optimizer)
+        )
 
 
 #@pytest.mark.parametrize('optimizer', ['sgd', 'momentum'])
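The new `second_order` gate exists because optimizers like Adahessian expect gradients built with `create_graph=True` so `step()` can take Hessian-vector products against them, which the generic `_test_basic_cases` harness does not do. A minimal sketch of exercising such an optimizer directly, assuming a toy model and data (not part of this commit):

```python
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2

# Hypothetical toy setup for illustration only.
model = nn.Linear(8, 2)
opt = create_optimizer_v2(model, 'adahessian', lr=1e-3)

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = nn.functional.mse_loss(model(x), y)
# Keep the graph so the second-order step can differentiate through the gradients.
loss.backward(create_graph=True)
opt.step()
opt.zero_grad()
```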
@@ -17,6 +17,6 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 
-from ._optim_factory import list_optimizers, get_optimizer_class, create_optimizer_v2, \
-    create_optimizer, optimizer_kwargs, OptimInfo, OptimizerRegistry
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
+    create_optimizer_v2, create_optimizer, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
@@ -13,10 +13,11 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 
-from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
+from .adahessian import Adahessian
 from .adamp import AdamP
 from .adan import Adan
 from .adopt import Adopt
@@ -78,6 +79,7 @@ class OptimInfo:
     has_momentum: bool = False
     has_betas: bool = False
     num_betas: int = 2
+    second_order: bool = False
     defaults: Optional[Dict[str, Any]] = None
 
 
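With the new flag, an `OptimInfo` entry carries enough metadata for the factory and the tests to special-case second-order optimizers without instantiating them. A rough sketch of the dataclass limited to fields visible in this commit (the real class may define more; `name`, `opt_class`, and `description` are taken from the registration hunk below):

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type
import torch

@dataclass
class OptimInfo:
    name: str
    opt_class: Type[torch.optim.Optimizer]
    description: str = ''
    has_momentum: bool = False
    has_betas: bool = False
    num_betas: int = 2
    second_order: bool = False          # new: signals callers to use create_graph=True
    defaults: Optional[Dict[str, Any]] = None
```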
@@ -540,6 +542,13 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             num_betas=3
         ),
+        OptimInfo(
+            name='adahessian',
+            opt_class=Adahessian,
+            description='An Adaptive Second Order Optimizer',
+            has_betas=True,
+            second_order=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,
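Once registered, the optimizer is reachable through the same factory paths the tests exercise. A hedged usage sketch (the model is a stand-in, not from this commit):

```python
import torch.nn as nn
from timm.optim import create_optimizer_v2, list_optimizers

assert 'adahessian' in list_optimizers()

model = nn.Linear(16, 4)  # placeholder model for illustration
opt = create_optimizer_v2(model, 'adahessian', lr=1e-3, weight_decay=0.0)
print(type(opt).__name__)  # Adahessian
```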
@@ -770,6 +779,21 @@ def list_optimizers(
     return default_registry.list_optimizers(filter, exclude_filters, with_description)
 
 
+def get_optimizer_info(name: str) -> OptimInfo:
+    """Get the OptimInfo for an optimizer.
+
+    Args:
+        name: Name of the optimizer
+
+    Returns:
+        OptimInfo configuration
+
+    Raises:
+        ValueError: If optimizer is not found
+    """
+    return default_registry.get_optimizer_info(name)
+
+
 def get_optimizer_class(
     name: str,
     bind_defaults: bool = False,
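The new lookup lets callers branch on optimizer metadata without constructing anything. A small sketch of the intended use, with the same filters the test suite applies (how a training loop consumes the flag is an assumption, not shown in this commit):

```python
from timm.optim import get_optimizer_info, list_optimizers

# Skip fused/bitsandbytes variants, mirroring the test parametrization.
for name in list_optimizers(exclude_filters=('fused*', 'bnb*')):
    info = get_optimizer_info(name)
    # e.g. a training loop could pass create_graph=info.second_order to backward()
    print(f"{name}: betas={info.has_betas}, second_order={info.second_order}")
```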
@@ -1,6 +1,6 @@
 import logging
 from itertools import islice
-from typing import Collection, Optional, Tuple
+from typing import Collection, Optional
 
 from torch import nn as nn
 
@@ -37,7 +37,7 @@ def _group(it, size):
     return iter(lambda: tuple(islice(it, size)), ())
 
 
-def _layer_map(model, layers_per_group=12, num_groups=None):
+def auto_group_layers(model, layers_per_group=12, num_groups=None):
     def _in_head(n, hp):
         if not hp:
             return True
@@ -63,6 +63,8 @@ def _layer_map(model, layers_per_group=12, num_groups=None):
     layer_map.update({n: num_trunk_groups for n in names_head})
     return layer_map
 
+_layer_map = auto_group_layers  # backward compat
+
 
 def param_groups_layer_decay(
     model: nn.Module,
@@ -86,7 +88,7 @@ def param_groups_layer_decay(
         layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
     else:
         # fallback
-        layer_map = _layer_map(model)
+        layer_map = auto_group_layers(model)
     num_layers = max(layer_map.values()) + 1
     layer_max = num_layers - 1
     layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
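The scale computation at the end of this hunk decays multiplicatively from the final layer group backwards. A quick numeric check of that formula, plus a hedged call through the renamed fallback path (the model choice and keyword names are assumptions for illustration):

```python
import timm
from timm.optim import param_groups_layer_decay

# Formula from the hunk above: deeper groups keep a larger fraction of the LR.
layer_decay, num_layers = 0.75, 4
layer_scales = [layer_decay ** (num_layers - 1 - i) for i in range(num_layers)]
print(layer_scales)  # [0.421875, 0.5625, 0.75, 1.0]

# param_groups_layer_decay prefers model.group_matcher() when available and
# now falls back to auto_group_layers() otherwise.
model = timm.create_model('resnet18', pretrained=False)
groups = param_groups_layer_decay(model, weight_decay=0.05, layer_decay=0.75)
print(len(groups))  # one param group per (layer group, decay/no-decay) combination
```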
timm/optim/optim_factory.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+# lots of uses of these functions directly, ala 'import timm.optim.optim_factory as optim_factory', fun :/
+
+from ._optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
+from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters, _layer_map, _group
+
+import warnings
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.optim", FutureWarning)
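The shim keeps old call sites working while nudging them toward `timm.optim`. A quick sketch of what a legacy import now produces; the warning text comes from the file above, and it only fires the first time the module is imported in a session:

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    import timm.optim.optim_factory as optim_factory  # legacy path, still resolves

# The re-exported names behave as before.
optim_factory.create_optimizer_v2
# A FutureWarning is among the warnings raised on first import.
assert any(issubclass(w.category, FutureWarning) for w in caught)
```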