mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support bitsandbytes optimizers in factory
This commit is contained in:
parent
21e57c0b9e
commit
e3363a7159
@ -27,11 +27,6 @@ from .radam import RAdam
|
|||||||
from .rmsprop_tf import RMSpropTF
|
from .rmsprop_tf import RMSpropTF
|
||||||
from .sgdp import SGDP
|
from .sgdp import SGDP
|
||||||
|
|
||||||
try:
|
|
||||||
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
|
||||||
has_apex = True
|
|
||||||
except ImportError:
|
|
||||||
has_apex = False
|
|
||||||
|
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -254,9 +249,23 @@ def create_optimizer_v2(
|
|||||||
opt_lower = opt.lower()
|
opt_lower = opt.lower()
|
||||||
opt_split = opt_lower.split('_')
|
opt_split = opt_lower.split('_')
|
||||||
opt_lower = opt_split[-1]
|
opt_lower = opt_split[-1]
|
||||||
if 'fused' in opt_lower:
|
|
||||||
|
if opt_lower.startswith('fused'):
|
||||||
|
try:
|
||||||
|
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
|
||||||
|
has_apex = True
|
||||||
|
except ImportError:
|
||||||
|
has_apex = False
|
||||||
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
||||||
|
|
||||||
|
if opt_lower.startswith('bnb'):
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
has_bnb = True
|
||||||
|
except ImportError:
|
||||||
|
has_bnb = False
|
||||||
|
assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'
|
||||||
|
|
||||||
opt_args = dict(weight_decay=weight_decay, **kwargs)
|
opt_args = dict(weight_decay=weight_decay, **kwargs)
|
||||||
|
|
||||||
if lr is not None:
|
if lr is not None:
|
||||||
@ -357,6 +366,40 @@ def create_optimizer_v2(
|
|||||||
opt_args.setdefault('betas', (0.95, 0.98))
|
opt_args.setdefault('betas', (0.95, 0.98))
|
||||||
optimizer = FusedNovoGrad(parameters, **opt_args)
|
optimizer = FusedNovoGrad(parameters, **opt_args)
|
||||||
|
|
||||||
|
# bitsandbytes optimizers, require bitsandbytes to be installed
|
||||||
|
elif opt_lower == 'bnbsgd':
|
||||||
|
opt_args.pop('eps', None)
|
||||||
|
optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||||
|
elif opt_lower == 'bnbsgd8bit':
|
||||||
|
opt_args.pop('eps', None)
|
||||||
|
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
|
||||||
|
elif opt_lower == 'bnbmomentum':
|
||||||
|
opt_args.pop('eps', None)
|
||||||
|
optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
|
||||||
|
elif opt_lower == 'bnbmomentum8bit':
|
||||||
|
opt_args.pop('eps', None)
|
||||||
|
optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
|
||||||
|
elif opt_lower == 'bnbadam':
|
||||||
|
optimizer = bnb.optim.Adam(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnbadam8bit':
|
||||||
|
optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnbadamw':
|
||||||
|
optimizer = bnb.optim.AdamW(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnbadamw8bit':
|
||||||
|
optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnblamb':
|
||||||
|
optimizer = bnb.optim.LAMB(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnblamb8bit':
|
||||||
|
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnblars':
|
||||||
|
optimizer = bnb.optim.LARS(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnblarsb8bit':
|
||||||
|
optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnblion':
|
||||||
|
optimizer = bnb.optim.Lion(parameters, **opt_args)
|
||||||
|
elif opt_lower == 'bnblion8bit':
|
||||||
|
optimizer = bnb.optim.Lion8bit(parameters, **opt_args)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert False and "Invalid optimizer"
|
assert False and "Invalid optimizer"
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
Loading…
x
Reference in New Issue
Block a user