""" Optimizer Factory w/ custom Weight Decay & Layer Decay support
|
||
|
|
||
|
Hacked together by / Copyright 2021 Ross Wightman
|
||
|
"""
|
||
|
import logging
|
||
|
from dataclasses import dataclass
|
||
|
from functools import partial
|
||
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
|
||
|
from fnmatch import fnmatch
|
||
|
import importlib
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.optim as optim
|
||
|
|
||
|
from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, group_parameters
|
||
|
from .adabelief import AdaBelief
|
||
|
from .adafactor import Adafactor
|
||
|
from .adafactor_bv import AdafactorBigVision
|
||
|
from .adamp import AdamP
|
||
|
from .adan import Adan
|
||
|
from .adopt import Adopt
|
||
|
from .lamb import Lamb
|
||
|
from .lars import Lars
|
||
|
from .lion import Lion
|
||
|
from .lookahead import Lookahead
|
||
|
from .madgrad import MADGRAD
|
||
|
from .nadam import Nadam
|
||
|
from .nadamw import NAdamW
|
||
|
from .nvnovograd import NvNovoGrad
|
||
|
from .radam import RAdam
|
||
|
from .rmsprop_tf import RMSpropTF
|
||
|
from .sgdp import SGDP
|
||
|
from .sgdw import SGDW
|
||
|
|
||
|
_logger = logging.getLogger(__name__)
|
||
|
|
||
|
# Type variables
|
||
|
T = TypeVar('T')
|
||
|
Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
|
||
|
OptimType = TypeVar('OptimType', bound='optim.Optimizer')
|
||
|
|
||
|
|
||
|
def _import_class(class_string: str) -> Type:
    """Dynamically import a class from a string."""
    try:
        module_name, class_name = class_string.rsplit(".", 1)
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(f"Could not import {class_string}: {e}")


class OptimizerCallable(Protocol):
    """Protocol for optimizer constructor signatures."""

    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...


@dataclass(frozen=True)
class OptimInfo:
    """Immutable configuration for an optimizer.

    Attributes:
        name: Unique identifier for the optimizer
        opt_class: The optimizer class
        description: Brief description of the optimizer's characteristics and behavior
        has_eps: Whether the optimizer accepts epsilon parameter
        has_momentum: Whether the optimizer accepts momentum parameter
        has_betas: Whether the optimizer accepts a tuple of beta parameters
        num_betas: Number of betas in the tuple (only relevant when has_betas is True)
        defaults: Optional default parameters for the optimizer
    """
    name: str
    opt_class: Union[str, Type[optim.Optimizer]]
    description: str = ''
    has_eps: bool = True
    has_momentum: bool = False
    has_betas: bool = False
    num_betas: int = 2
    defaults: Optional[Dict[str, Any]] = None


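# Example (illustrative only, not executed): an OptimInfo entry for a hypothetical
# custom optimizer class `MyOpt`. The fields mirror the dataclass above; `MyOpt`
# and its defaults are assumed placeholders, not part of timm.
#
#   my_opt_info = OptimInfo(
#       name='myopt',
#       opt_class=MyOpt,                      # class, or dotted import string for lazy import
#       description='Example custom optimizer',
#       has_betas=True,
#       defaults={'decoupled': True},         # merged into kwargs at creation time
#   )
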
class OptimizerRegistry:
    """Registry managing optimizer configurations and instantiation.

    This class provides a central registry for optimizer configurations and handles
    their instantiation with appropriate parameter groups and settings.
    """

    def __init__(self) -> None:
        self._optimizers: Dict[str, OptimInfo] = {}
        self._foreach_defaults: Set[str] = {'lion'}

    def register(self, info: OptimInfo) -> None:
        """Register an optimizer configuration.

        Args:
            info: The OptimInfo configuration containing name, type and description
        """
        name = info.name.lower()
        if name in self._optimizers:
            _logger.warning(f'Optimizer {name} already registered, overwriting')
        self._optimizers[name] = info

    def register_alias(self, alias: str, target: str) -> None:
        """Register an alias for an existing optimizer.

        Args:
            alias: The alias name
            target: The target optimizer name

        Raises:
            KeyError: If target optimizer doesn't exist
        """
        target = target.lower()
        if target not in self._optimizers:
            raise KeyError(f'Cannot create alias for non-existent optimizer {target}')
        self._optimizers[alias.lower()] = self._optimizers[target]

    def register_foreach_default(self, name: str) -> None:
        """Register an optimizer as defaulting to foreach=True."""
        self._foreach_defaults.add(name.lower())

    def list_optimizers(
            self,
            filter: str = '',
            exclude_filters: Optional[List[str]] = None,
            with_description: bool = False
    ) -> List[Union[str, Tuple[str, str]]]:
        """List available optimizer names, optionally filtered.

        Args:
            filter: Wildcard style filter string (e.g., 'adam*')
            exclude_filters: Optional list of wildcard patterns to exclude
            with_description: If True, return tuples of (name, description)

        Returns:
            List of either optimizer names or (name, description) tuples
        """
        names = sorted(self._optimizers.keys())

        if filter:
            names = [n for n in names if fnmatch(n, filter)]

        if exclude_filters:
            for exclude_filter in exclude_filters:
                names = [n for n in names if not fnmatch(n, exclude_filter)]

        if with_description:
            return [(name, self._optimizers[name].description) for name in names]
        return names

    def get_optimizer_info(self, name: str) -> OptimInfo:
        """Get the OptimInfo for an optimizer.

        Args:
            name: Name of the optimizer

        Returns:
            OptimInfo configuration

        Raises:
            ValueError: If optimizer is not found
        """
        name = name.lower()
        if name not in self._optimizers:
            raise ValueError(f'Optimizer {name} not found in registry')
        return self._optimizers[name]

    def get_optimizer_class(
            self,
            name_or_info: Union[str, OptimInfo],
            bind_defaults: bool = True,
    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
        """Get the optimizer class with any default arguments applied.

        This allows direct instantiation of optimizers with their default configs
        without going through the full factory.

        Args:
            name_or_info: Name of the optimizer
            bind_defaults: Bind default arguments to optimizer class via `partial` before returning

        Returns:
            Optimizer class or partial with defaults applied

        Raises:
            ValueError: If optimizer not found
        """
        if isinstance(name_or_info, str):
            opt_info = self.get_optimizer_info(name_or_info)
        else:
            assert isinstance(name_or_info, OptimInfo)
            opt_info = name_or_info

        if isinstance(opt_info.opt_class, str):
            # Special handling for APEX and BNB optimizers
            if opt_info.opt_class.startswith('apex.'):
                assert torch.cuda.is_available(), 'CUDA required for APEX optimizers'
                try:
                    opt_class = _import_class(opt_info.opt_class)
                except ImportError as e:
                    raise ImportError('APEX optimizers require apex to be installed') from e
            elif opt_info.opt_class.startswith('bitsandbytes.'):
                assert torch.cuda.is_available(), 'CUDA required for bitsandbytes optimizers'
                try:
                    opt_class = _import_class(opt_info.opt_class)
                except ImportError as e:
                    raise ImportError('bitsandbytes optimizers require bitsandbytes to be installed') from e
            else:
                opt_class = _import_class(opt_info.opt_class)
        else:
            opt_class = opt_info.opt_class

        # Return class or partial with defaults
        if bind_defaults and opt_info.defaults:
            opt_class = partial(opt_class, **opt_info.defaults)

        return opt_class

    def create_optimizer(
            self,
            model_or_params: Union[nn.Module, Params],
            opt: str,
            lr: Optional[float] = None,
            weight_decay: float = 0.,
            momentum: float = 0.9,
            foreach: Optional[bool] = None,
            weight_decay_exclude_1d: bool = True,
            layer_decay: Optional[float] = None,
            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
            **kwargs: Any,
    ) -> optim.Optimizer:
        """Create an optimizer instance.

        Args:
            model_or_params: Model or parameters to optimize
            opt: Name of optimizer to create
            lr: Learning rate
            weight_decay: Weight decay factor
            momentum: Momentum factor for applicable optimizers
            foreach: Enable/disable foreach operation
            weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
            layer_decay: Layer-wise learning rate decay
            param_group_fn: Optional custom parameter grouping function
            **kwargs: Additional optimizer-specific arguments

        Returns:
            Configured optimizer instance

        Raises:
            ValueError: If optimizer not found or configuration invalid
        """

        # Get parameters to optimize
        if isinstance(model_or_params, nn.Module):
            # Extract parameters from a nn.Module, build param groups w/ weight-decay and/or layer-decay applied
            no_weight_decay = getattr(model_or_params, 'no_weight_decay', lambda: set())()

            if param_group_fn:
                # run custom fn to generate param groups from nn.Module
                parameters = param_group_fn(model_or_params)
            elif layer_decay is not None:
                parameters = param_groups_layer_decay(
                    model_or_params,
                    weight_decay=weight_decay,
                    layer_decay=layer_decay,
                    no_weight_decay_list=no_weight_decay,
                    weight_decay_exclude_1d=weight_decay_exclude_1d,
                )
                weight_decay = 0.
            elif weight_decay and weight_decay_exclude_1d:
                parameters = param_groups_weight_decay(
                    model_or_params,
                    weight_decay=weight_decay,
                    no_weight_decay_list=no_weight_decay,
                )
                weight_decay = 0.
            else:
                parameters = model_or_params.parameters()
        else:
            # pass parameters / parameter groups through to optimizer
            parameters = model_or_params

        # Parse optimizer name
        opt_split = opt.lower().split('_')
        opt_name = opt_split[-1]
        use_lookahead = opt_split[0] == 'lookahead' if len(opt_split) > 1 else False

        opt_info = self.get_optimizer_info(opt_name)

        # Build optimizer arguments
        opt_args: Dict[str, Any] = {'weight_decay': weight_decay, **kwargs}

        # Add LR to args, if None optimizer default is used, some optimizers manage LR internally if None.
        if lr is not None:
            opt_args['lr'] = lr

        # Apply optimizer-specific settings
        if opt_info.defaults:
            for k, v in opt_info.defaults.items():
                opt_args.setdefault(k, v)

        # timm has always defaulted momentum to 0.9 if optimizer supports momentum, keep for backward compat.
        if opt_info.has_momentum:
            opt_args.setdefault('momentum', momentum)

        # Remove commonly used kwargs that aren't always supported
        if not opt_info.has_eps:
            opt_args.pop('eps', None)
        if not opt_info.has_betas:
            opt_args.pop('betas', None)

        if foreach is not None:
            # Explicitly activate or deactivate multi-tensor foreach impl.
            # Not all optimizers support this, and those that do usually default to using
            # multi-tensor impl if foreach is left as default 'None' and can be enabled.
            opt_args.setdefault('foreach', foreach)

        # Create optimizer
        opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
        optimizer = opt_class(parameters, **opt_args)

        # Apply Lookahead if requested
        if use_lookahead:
            optimizer = Lookahead(optimizer)

        return optimizer


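# Illustrative sketch (comments only, not executed at import) of how the registry
# above is typically used directly. The tiny nn.Sequential model is an arbitrary
# stand-in, not a timm model.
#
#   registry = OptimizerRegistry()
#   registry.register(OptimInfo(name='adamw', opt_class=optim.AdamW, has_betas=True))
#   model = nn.Sequential(nn.Linear(10, 10))
#   optimizer = registry.create_optimizer(model, opt='adamw', lr=1e-3, weight_decay=0.05)
#
# Prefixing the name with 'lookahead_' (e.g. 'lookahead_adamw') wraps the result in
# the Lookahead optimizer, as handled in create_optimizer() above.

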
def _register_sgd_variants(registry: OptimizerRegistry) -> None:
    """Register SGD-based optimizers"""
    sgd_optimizers = [
        OptimInfo(
            name='sgd',
            opt_class=optim.SGD,
            description='Stochastic Gradient Descent with Nesterov momentum (default)',
            has_eps=False,
            has_momentum=True,
            defaults={'nesterov': True}
        ),
        OptimInfo(
            name='momentum',
            opt_class=optim.SGD,
            description='Stochastic Gradient Descent with classical momentum',
            has_eps=False,
            has_momentum=True,
            defaults={'nesterov': False}
        ),
        OptimInfo(
            name='sgdp',
            opt_class=SGDP,
            description='SGD with built-in projection to unit norm sphere',
            has_momentum=True,
            defaults={'nesterov': True}
        ),
        OptimInfo(
            name='sgdw',
            opt_class=SGDW,
            description='SGD with decoupled weight decay and Nesterov momentum',
            has_eps=False,
            has_momentum=True,
            defaults={'nesterov': True}
        ),
    ]
    for opt in sgd_optimizers:
        registry.register(opt)


def _register_adam_variants(registry: OptimizerRegistry) -> None:
    """Register Adam-based optimizers"""
    adam_optimizers = [
        OptimInfo(
            name='adam',
            opt_class=optim.Adam,
            description='torch.optim Adam (Adaptive Moment Estimation)',
            has_betas=True
        ),
        OptimInfo(
            name='adamw',
            opt_class=optim.AdamW,
            description='torch.optim Adam with decoupled weight decay regularization',
            has_betas=True
        ),
        OptimInfo(
            name='adamp',
            opt_class=AdamP,
            description='Adam with built-in projection to unit norm sphere',
            has_betas=True,
            defaults={'wd_ratio': 0.01, 'nesterov': True}
        ),
        OptimInfo(
            name='nadam',
            opt_class=Nadam,
            description='Adam with Nesterov momentum',
            has_betas=True
        ),
        OptimInfo(
            name='nadamw',
            opt_class=NAdamW,
            description='Adam with Nesterov momentum and decoupled weight decay',
            has_betas=True
        ),
        OptimInfo(
            name='radam',
            opt_class=RAdam,
            description='Rectified Adam with variance adaptation',
            has_betas=True
        ),
        OptimInfo(
            name='adamax',
            opt_class=optim.Adamax,
            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
            has_betas=True
        ),
        OptimInfo(
            name='adafactor',
            opt_class=Adafactor,
            description='Memory-efficient implementation of Adam with factored gradients',
        ),
        OptimInfo(
            name='adafactorbv',
            opt_class=AdafactorBigVision,
            description='Big Vision variant of Adafactor with factored gradients, half precision momentum.',
        ),
        OptimInfo(
            name='adopt',
            opt_class=Adopt,
            description='Modified Adam that can converge with any beta2 with the optimal rate',
        ),
        OptimInfo(
            name='adoptw',
            opt_class=Adopt,
            description='Modified Adam that can converge with any beta2 with the optimal rate (decoupled weight decay)',
            defaults={'decoupled': True}
        ),
    ]
    for opt in adam_optimizers:
        registry.register(opt)


def _register_lamb_lars(registry: OptimizerRegistry) -> None:
    """Register LAMB and LARS variants"""
    lamb_lars_optimizers = [
        OptimInfo(
            name='lamb',
            opt_class=Lamb,
            description='Layer-wise Adaptive Moments for batch optimization',
            has_betas=True
        ),
        OptimInfo(
            name='lambc',
            opt_class=Lamb,
            description='LAMB with trust ratio clipping for stability',
            has_betas=True,
            defaults={'trust_clip': True}
        ),
        OptimInfo(
            name='lars',
            opt_class=Lars,
            description='Layer-wise Adaptive Rate Scaling',
            has_momentum=True
        ),
        OptimInfo(
            name='larc',
            opt_class=Lars,
            description='LARS with trust ratio clipping for stability',
            has_momentum=True,
            defaults={'trust_clip': True}
        ),
        OptimInfo(
            name='nlars',
            opt_class=Lars,
            description='LARS with Nesterov momentum',
            has_momentum=True,
            defaults={'nesterov': True}
        ),
        OptimInfo(
            name='nlarc',
            opt_class=Lars,
            description='LARS with Nesterov momentum & trust ratio clipping',
            has_momentum=True,
            defaults={'nesterov': True, 'trust_clip': True}
        ),
    ]
    for opt in lamb_lars_optimizers:
        registry.register(opt)


def _register_other_optimizers(registry: OptimizerRegistry) -> None:
    """Register miscellaneous optimizers"""
    other_optimizers = [
        OptimInfo(
            name='adabelief',
            opt_class=AdaBelief,
            description='Adapts learning rate based on gradient prediction error',
            has_betas=True,
            defaults={'rectify': False}
        ),
        OptimInfo(
            name='radabelief',
            opt_class=AdaBelief,
            description='Rectified AdaBelief with variance adaptation',
            has_betas=True,
            defaults={'rectify': True}
        ),
        OptimInfo(
            name='adadelta',
            opt_class=optim.Adadelta,
            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
        ),
        OptimInfo(
            name='adagrad',
            opt_class=optim.Adagrad,
            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
            defaults={'eps': 1e-8}
        ),
        OptimInfo(
            name='adan',
            opt_class=Adan,
            description='Adaptive Nesterov Momentum Algorithm',
            defaults={'no_prox': False},
            has_betas=True,
            num_betas=3
        ),
        OptimInfo(
            name='adanw',
            opt_class=Adan,
            description='Adaptive Nesterov Momentum with decoupled weight decay',
            defaults={'no_prox': True},
            has_betas=True,
            num_betas=3
        ),
        OptimInfo(
            name='lion',
            opt_class=Lion,
            description='Evolved Sign Momentum optimizer for improved convergence',
            has_eps=False,
            has_betas=True
        ),
        OptimInfo(
            name='madgrad',
            opt_class=MADGRAD,
            description='Momentum-based Adaptive gradient method',
            has_momentum=True
        ),
        OptimInfo(
            name='madgradw',
            opt_class=MADGRAD,
            description='MADGRAD with decoupled weight decay',
            has_momentum=True,
            defaults={'decoupled_decay': True}
        ),
        OptimInfo(
            name='novograd',
            opt_class=NvNovoGrad,
            description='Normalized Adam with L2 norm gradient normalization',
            has_betas=True
        ),
        OptimInfo(
            name='rmsprop',
            opt_class=optim.RMSprop,
            description='torch.optim RMSprop, Root Mean Square Propagation',
            has_momentum=True,
            defaults={'alpha': 0.9}
        ),
        OptimInfo(
            name='rmsproptf',
            opt_class=RMSpropTF,
            description='TensorFlow-style RMSprop implementation, Root Mean Square Propagation',
            has_momentum=True,
            defaults={'alpha': 0.9}
        ),
    ]
    for opt in other_optimizers:
        registry.register(opt)
    registry.register_foreach_default('lion')


def _register_apex_optimizers(registry: OptimizerRegistry) -> None:
    """Register APEX optimizers (lazy import)"""
    apex_optimizers = [
        OptimInfo(
            name='fusedsgd',
            opt_class='apex.optimizers.FusedSGD',
            description='NVIDIA APEX fused SGD implementation for faster training',
            has_eps=False,
            has_momentum=True,
            defaults={'nesterov': True}
        ),
        OptimInfo(
            name='fusedadam',
            opt_class='apex.optimizers.FusedAdam',
            description='NVIDIA APEX fused Adam implementation',
            has_betas=True,
            defaults={'adam_w_mode': False}
        ),
        OptimInfo(
            name='fusedadamw',
            opt_class='apex.optimizers.FusedAdam',
            description='NVIDIA APEX fused AdamW implementation',
            has_betas=True,
            defaults={'adam_w_mode': True}
        ),
        OptimInfo(
            name='fusedlamb',
            opt_class='apex.optimizers.FusedLAMB',
            description='NVIDIA APEX fused LAMB implementation',
            has_betas=True
        ),
        OptimInfo(
            name='fusednovograd',
            opt_class='apex.optimizers.FusedNovoGrad',
            description='NVIDIA APEX fused NovoGrad implementation',
            has_betas=True,
            defaults={'betas': (0.95, 0.98)}
        ),
    ]
    for opt in apex_optimizers:
        registry.register(opt)


def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
    """Register bitsandbytes optimizers (lazy import)"""
    bnb_optimizers = [
        OptimInfo(
            name='bnbsgd',
            opt_class='bitsandbytes.optim.SGD',
            description='bitsandbytes SGD',
            has_eps=False,
            has_momentum=True,
            defaults={'nesterov': True}
        ),
        OptimInfo(
            name='bnbsgd8bit',
            opt_class='bitsandbytes.optim.SGD8bit',
            description='bitsandbytes 8-bit SGD with dynamic quantization',
            has_eps=False,
            has_momentum=True,
            defaults={'nesterov': True}
        ),
        OptimInfo(
            name='bnbadam',
            opt_class='bitsandbytes.optim.Adam',
            description='bitsandbytes Adam',
            has_betas=True
        ),
        OptimInfo(
            name='bnbadam8bit',
            opt_class='bitsandbytes.optim.Adam8bit',
            description='bitsandbytes 8-bit Adam with dynamic quantization',
            has_betas=True
        ),
        OptimInfo(
            name='bnbadamw',
            opt_class='bitsandbytes.optim.AdamW',
            description='bitsandbytes AdamW',
            has_betas=True
        ),
        OptimInfo(
            name='bnbadamw8bit',
            opt_class='bitsandbytes.optim.AdamW8bit',
            description='bitsandbytes 8-bit AdamW with dynamic quantization',
            has_betas=True
        ),
        OptimInfo(
            name='bnblion',
            opt_class='bitsandbytes.optim.Lion',
            description='bitsandbytes Lion',
            has_betas=True
        ),
        OptimInfo(
            name='bnblion8bit',
            opt_class='bitsandbytes.optim.Lion8bit',
            description='bitsandbytes 8-bit Lion with dynamic quantization',
            has_betas=True
        ),
        OptimInfo(
            name='bnbademamix',
            opt_class='bitsandbytes.optim.AdEMAMix',
            description='bitsandbytes AdEMAMix',
            has_betas=True,
            num_betas=3,
        ),
        OptimInfo(
            name='bnbademamix8bit',
            opt_class='bitsandbytes.optim.AdEMAMix8bit',
            description='bitsandbytes 8-bit AdEMAMix with dynamic quantization',
            has_betas=True,
            num_betas=3,
        ),
    ]
    for opt in bnb_optimizers:
        registry.register(opt)


default_registry = OptimizerRegistry()

def _register_default_optimizers() -> None:
    """Register all default optimizers to the global registry."""
    # Register all optimizer groups
    _register_sgd_variants(default_registry)
    _register_adam_variants(default_registry)
    _register_lamb_lars(default_registry)
    _register_other_optimizers(default_registry)
    _register_apex_optimizers(default_registry)
    _register_bnb_optimizers(default_registry)

    # Register aliases
    default_registry.register_alias('nesterov', 'sgd')
    default_registry.register_alias('nesterovw', 'sgdw')


# Initialize default registry
_register_default_optimizers()

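# The default registry can also be extended at runtime with additional optimizers
# or aliases. Sketch only (not executed); `MyOpt` is an assumed user-defined class.
#
#   default_registry.register(OptimInfo(name='myopt', opt_class=MyOpt, has_betas=True))
#   default_registry.register_alias('myopt2', 'myopt')
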
# Public API

def list_optimizers(
        filter: str = '',
        exclude_filters: Optional[List[str]] = None,
        with_description: bool = False,
) -> List[Union[str, Tuple[str, str]]]:
    """List available optimizer names, optionally filtered.
    """
    return default_registry.list_optimizers(filter, exclude_filters, with_description)


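# Example of the filtering behaviour (illustrative; exact results depend on the
# optimizers registered above):
#
#   list_optimizers('adam*')                          # e.g. ['adam', 'adamax', 'adamp', 'adamw', ...]
#   list_optimizers(exclude_filters=['fused*', 'bnb*'])
#   list_optimizers(with_description=True)            # [(name, description), ...]

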
def get_optimizer_class(
        name: str,
        bind_defaults: bool = False,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
    """Get optimizer class by name with any defaults applied.
    """
    return default_registry.get_optimizer_class(name, bind_defaults=bind_defaults)


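# Example (illustrative): retrieve a class for manual construction. With
# bind_defaults=True, registry defaults (e.g. nesterov=True for 'sgd') are bound
# via functools.partial before the class/callable is returned.
#
#   opt_class = get_optimizer_class('adamw')
#   optimizer = opt_class(model.parameters(), lr=1e-3, weight_decay=0.05)

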
def create_optimizer_v2(
        model_or_params: Union[nn.Module, Params],
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        foreach: Optional[bool] = None,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
        **kwargs: Any,
) -> optim.Optimizer:
    """Create an optimizer instance using the default registry."""
    return default_registry.create_optimizer(
        model_or_params,
        opt=opt,
        lr=lr,
        weight_decay=weight_decay,
        momentum=momentum,
        foreach=foreach,
        weight_decay_exclude_1d=filter_bias_and_bn,
        layer_decay=layer_decay,
        param_group_fn=param_group_fn,
        **kwargs
    )


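# Typical usage (illustrative sketch; the model here is an arbitrary nn.Module,
# not a timm model):
#
#   model = nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32), nn.Linear(32, 10))
#   optimizer = create_optimizer_v2(
#       model,
#       opt='adamw',
#       lr=1e-3,
#       weight_decay=0.05,
#       layer_decay=0.75,        # layer-wise LR scaling applied via param groups
#   )
#
# With filter_bias_and_bn=True (the default), 1d params (biases, norm affine) are
# placed in a param group with weight_decay=0.

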
def optimizer_kwargs(cfg):
    """ cfg/argparse to kwargs helper
    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
    """
    kwargs = dict(
        opt=cfg.opt,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        momentum=cfg.momentum,
    )
    if getattr(cfg, 'opt_eps', None) is not None:
        kwargs['eps'] = cfg.opt_eps
    if getattr(cfg, 'opt_betas', None) is not None:
        kwargs['betas'] = cfg.opt_betas
    if getattr(cfg, 'layer_decay', None) is not None:
        kwargs['layer_decay'] = cfg.layer_decay
    if getattr(cfg, 'opt_args', None) is not None:
        kwargs.update(cfg.opt_args)
    if getattr(cfg, 'opt_foreach', None) is not None:
        kwargs['foreach'] = cfg.opt_foreach
    return kwargs


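# Example (illustrative): any object exposing the expected attributes works, e.g.
# an argparse.Namespace or types.SimpleNamespace.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9, opt_eps=1e-8)
#   optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg))

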
def create_optimizer(args, model, filter_bias_and_bn=True):
    """ Legacy optimizer factory for backwards compatibility.
    NOTE: Use create_optimizer_v2 for new code.
    """
    return create_optimizer_v2(
        model,
        **optimizer_kwargs(cfg=args),
        filter_bias_and_bn=filter_bias_and_bn,
    )
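

if __name__ == '__main__':
    # Minimal smoke test / usage demo, not part of the library API. It only runs when
    # the module is executed directly (run as a module, e.g. `python -m <pkg>.<module>`,
    # so the relative imports above resolve); it never executes on import.
    demo_model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    demo_opt = create_optimizer_v2(demo_model, opt='adamw', lr=1e-3, weight_decay=0.05)
    print(type(demo_opt).__name__, len(demo_opt.param_groups), 'param groups')
    print(list_optimizers('adam*'))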