mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support bitsandbytes optimizers in factory
This commit is contained in:
parent
21e57c0b9e
commit
e3363a7159
@ -27,11 +27,6 @@ from .radam import RAdam
|
||||
from .rmsprop_tf import RMSpropTF
|
||||
from .sgdp import SGDP
|
||||
|
||||
try:
|
||||
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
||||
has_apex = True
|
||||
except ImportError:
|
||||
has_apex = False
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
@ -254,9 +249,23 @@ def create_optimizer_v2(
|
||||
opt_lower = opt.lower()
|
||||
opt_split = opt_lower.split('_')
|
||||
opt_lower = opt_split[-1]
|
||||
if 'fused' in opt_lower:
|
||||
|
||||
if opt_lower.startswith('fused'):
|
||||
try:
|
||||
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
||||
has_apex = True
|
||||
except ImportError:
|
||||
has_apex = False
|
||||
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
||||
|
||||
if opt_lower.startswith('bnb'):
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
has_bnb = True
|
||||
except ImportError:
|
||||
has_bnb = False
|
||||
assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'
|
||||
|
||||
opt_args = dict(weight_decay=weight_decay, **kwargs)
|
||||
|
||||
if lr is not None:
|
||||
@ -357,6 +366,40 @@ def create_optimizer_v2(
|
||||
opt_args.setdefault('betas', (0.95, 0.98))
|
||||
optimizer = FusedNovoGrad(parameters, **opt_args)
|
||||
|
||||
# bitsandbytes optimizers; these require the bitsandbytes package to be installed
|
||||
elif opt_lower == 'bnbsgd':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'bnbsgd8bit':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||
elif opt_lower == 'bnbmomentum':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'bnbmomentum8bit':
|
||||
opt_args.pop('eps', None)
|
||||
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
|
||||
elif opt_lower == 'bnbadam':
|
||||
optimizer = bnb.optim.Adam(parameters, **opt_args)
|
||||
elif opt_lower == 'bnbadam8bit':
|
||||
optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnbadamw':
|
||||
optimizer = bnb.optim.AdamW(parameters, **opt_args)
|
||||
elif opt_lower == 'bnbadamw8bit':
|
||||
optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblamb':
|
||||
optimizer = bnb.optim.LAMB(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblamb8bit':
|
||||
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblars':
|
||||
optimizer = bnb.optim.LARS(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblarsb8bit':
|
||||
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblion':
|
||||
optimizer = bnb.optim.Lion(parameters, **opt_args)
|
||||
elif opt_lower == 'bnblion8bit':
|
||||
optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
|
||||
|
||||
else:
|
||||
assert False and "Invalid optimizer"
|
||||
raise ValueError
|
||||
|
Loading…
x
Reference in New Issue
Block a user