Support bitsandbytes optimizers in factory

This commit is contained in:
Ross Wightman 2023-05-09 11:33:51 -07:00
parent 21e57c0b9e
commit e3363a7159

View File

@ -27,11 +27,6 @@ from .radam import RAdam
from .rmsprop_tf import RMSpropTF from .rmsprop_tf import RMSpropTF
from .sgdp import SGDP from .sgdp import SGDP
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -254,9 +249,23 @@ def create_optimizer_v2(
opt_lower = opt.lower() opt_lower = opt.lower()
opt_split = opt_lower.split('_') opt_split = opt_lower.split('_')
opt_lower = opt_split[-1] opt_lower = opt_split[-1]
if 'fused' in opt_lower:
if opt_lower.startswith('fused'):
try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True
except ImportError:
has_apex = False
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
if opt_lower.startswith('bnb'):
try:
import bitsandbytes as bnb
has_bnb = True
except ImportError:
has_bnb = False
assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'
opt_args = dict(weight_decay=weight_decay, **kwargs) opt_args = dict(weight_decay=weight_decay, **kwargs)
if lr is not None: if lr is not None:
@ -357,6 +366,40 @@ def create_optimizer_v2(
opt_args.setdefault('betas', (0.95, 0.98)) opt_args.setdefault('betas', (0.95, 0.98))
optimizer = FusedNovoGrad(parameters, **opt_args) optimizer = FusedNovoGrad(parameters, **opt_args)
# bitsandbytes optimizers, require bitsandbytes to be installed
elif opt_lower == 'bnbsgd':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'bnbsgd8bit':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
elif opt_lower == 'bnbmomentum':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'bnbmomentum8bit':
opt_args.pop('eps', None)
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
elif opt_lower == 'bnbadam':
optimizer = bnb.optim.Adam(parameters, **opt_args)
elif opt_lower == 'bnbadam8bit':
optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
elif opt_lower == 'bnbadamw':
optimizer = bnb.optim.AdamW(parameters, **opt_args)
elif opt_lower == 'bnbadamw8bit':
optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
elif opt_lower == 'bnblamb':
optimizer = bnb.optim.LAMB(parameters, **opt_args)
elif opt_lower == 'bnblamb8bit':
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
elif opt_lower == 'bnblars':
optimizer = bnb.optim.LARS(parameters, **opt_args)
elif opt_lower == 'bnblarsb8bit':
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
elif opt_lower == 'bnblion':
optimizer = bnb.optim.Lion(parameters, **opt_args)
elif opt_lower == 'bnblion8bit':
optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
else: else:
assert False and "Invalid optimizer" assert False and "Invalid optimizer"
raise ValueError raise ValueError