Minor changes, has_eps=False missing for bnb lion

This commit is contained in:
Ross Wightman 2024-11-12 20:06:30 -08:00 committed by Ross Wightman
parent 61305cc26a
commit 8b9b6824ae

View File

@ -270,9 +270,9 @@ class OptimizerRegistry:
if param_group_fn: if param_group_fn:
# run custom fn to generate param groups from nn.Module # run custom fn to generate param groups from nn.Module
parameters = param_group_fn(model_or_params) params = param_group_fn(model_or_params)
elif layer_decay is not None: elif layer_decay is not None:
parameters = param_groups_layer_decay( params = param_groups_layer_decay(
model_or_params, model_or_params,
weight_decay=weight_decay, weight_decay=weight_decay,
layer_decay=layer_decay, layer_decay=layer_decay,
@ -281,17 +281,17 @@ class OptimizerRegistry:
) )
weight_decay = 0. weight_decay = 0.
elif weight_decay and weight_decay_exclude_1d: elif weight_decay and weight_decay_exclude_1d:
parameters = param_groups_weight_decay( params = param_groups_weight_decay(
model_or_params, model_or_params,
weight_decay=weight_decay, weight_decay=weight_decay,
no_weight_decay_list=no_weight_decay, no_weight_decay_list=no_weight_decay,
) )
weight_decay = 0. weight_decay = 0.
else: else:
parameters = model_or_params.parameters() params = model_or_params.parameters()
else: else:
# pass parameters / parameter groups through to optimizer # pass parameters / parameter groups through to optimizer
parameters = model_or_params params = model_or_params
# Parse optimizer name # Parse optimizer name
opt_split = opt.lower().split('_') opt_split = opt.lower().split('_')
@ -330,7 +330,7 @@ class OptimizerRegistry:
# Create optimizer # Create optimizer
opt_class = self.get_optimizer_class(opt_info, bind_defaults=False) opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
optimizer = opt_class(parameters, **opt_args) optimizer = opt_class(params, **opt_args)
# Apply Lookahead if requested # Apply Lookahead if requested
if use_lookahead: if use_lookahead:
@ -685,12 +685,14 @@ def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
'bnblion', 'bnblion',
'bitsandbytes.optim.Lion', 'bitsandbytes.optim.Lion',
description='bitsandbytes Lion', description='bitsandbytes Lion',
has_eps=False,
has_betas=True has_betas=True
), ),
OptimInfo( OptimInfo(
'bnblion8bit', 'bnblion8bit',
'bitsandbytes.optim.Lion8bit', 'bitsandbytes.optim.Lion8bit',
description='bitsandbytes 8-bit Lion with dynamic quantization', description='bitsandbytes 8-bit Lion with dynamic quantization',
has_eps=False,
has_betas=True has_betas=True
), ),
OptimInfo( OptimInfo(