Mirror of https://github.com/huggingface/pytorch-image-models.git
Minor changes: has_eps=False was missing for bnb Lion
This commit is contained in:
parent 61305cc26a
commit 8b9b6824ae
@@ -270,9 +270,9 @@ class OptimizerRegistry:
             if param_group_fn:
                 # run custom fn to generate param groups from nn.Module
-                parameters = param_group_fn(model_or_params)
+                params = param_group_fn(model_or_params)
             elif layer_decay is not None:
-                parameters = param_groups_layer_decay(
+                params = param_groups_layer_decay(
                     model_or_params,
                     weight_decay=weight_decay,
                     layer_decay=layer_decay,
@@ -281,17 +281,17 @@ class OptimizerRegistry:
                 )
                 weight_decay = 0.
             elif weight_decay and weight_decay_exclude_1d:
-                parameters = param_groups_weight_decay(
+                params = param_groups_weight_decay(
                     model_or_params,
                     weight_decay=weight_decay,
                     no_weight_decay_list=no_weight_decay,
                 )
                 weight_decay = 0.
             else:
-                parameters = model_or_params.parameters()
+                params = model_or_params.parameters()
         else:
             # pass parameters / parameter groups through to optimizer
-            parameters = model_or_params
+            params = model_or_params
 
         # Parse optimizer name
         opt_split = opt.lower().split('_')
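
For context, the param-group branches above are typically reached through timm's create_optimizer_v2 factory. A minimal sketch, assuming timm is installed; the model name and the grouping function are hypothetical illustration, not part of this commit:

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('resnet50')

    # layer_decay set -> params built via param_groups_layer_decay()
    optimizer = create_optimizer_v2(
        model, opt='adamw', lr=1e-3, weight_decay=0.05, layer_decay=0.75)

    # a custom param_group_fn takes precedence over the other branches
    def my_param_groups(m):
        # hypothetical grouping: everything in a single group
        return [{'params': m.parameters()}]

    optimizer = create_optimizer_v2(
        model, opt='adamw', lr=1e-3, param_group_fn=my_param_groups)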
@@ -330,7 +330,7 @@ class OptimizerRegistry:
         # Create optimizer
         opt_class = self.get_optimizer_class(opt_info, bind_defaults=False)
-        optimizer = opt_class(parameters, **opt_args)
+        optimizer = opt_class(params, **opt_args)
 
         # Apply Lookahead if requested
         if use_lookahead:
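
As the context lines suggest, Lookahead is requested through the optimizer name itself. A sketch assuming timm's usual 'lookahead_' name prefix convention:

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('resnet50')

    # opt.lower().split('_') peels off the 'lookahead' prefix, which sets
    # use_lookahead, and the created base optimizer is wrapped in Lookahead
    optimizer = create_optimizer_v2(model, opt='lookahead_adamw', lr=1e-3)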
@@ -685,12 +685,14 @@ def _register_bnb_optimizers(registry: OptimizerRegistry) -> None:
             'bnblion',
             'bitsandbytes.optim.Lion',
             description='bitsandbytes Lion',
+            has_eps=False,
             has_betas=True
         ),
         OptimInfo(
             'bnblion8bit',
             'bitsandbytes.optim.Lion8bit',
             description='bitsandbytes 8-bit Lion with dynamic quantization',
+            has_eps=False,
             has_betas=True
         ),
         OptimInfo(
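
The fix itself: bitsandbytes' Lion variants take betas but no eps argument, so their OptimInfo entries need has_eps=False to keep the factory from forwarding an eps value the constructor would reject. A minimal usage sketch, assuming bitsandbytes is installed; 'bnblion' is the registry name shown above:

    import timm
    from timm.optim import create_optimizer_v2

    model = timm.create_model('resnet50')

    # with has_eps=False registered, no eps kwarg is passed through to
    # bitsandbytes.optim.Lion, which does not define one
    optimizer = create_optimizer_v2(model, opt='bnblion', lr=1e-4, weight_decay=0.01)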