Fix #1712 broken support for AMP w/ PyTorch < 1.10. Disable loss scaler for bfloat16

Ross Wightman 2023-03-11 15:26:09 -08:00
parent 3a636eee71
commit 43e6143bef


@@ -514,8 +514,14 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        if device.type == 'cuda':
+        try:
+            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        except (AttributeError, TypeError):
+            # fallback to CUDA only AMP for PyTorch < 1.10
+            assert device.type == 'cuda'
+            amp_autocast = torch.cuda.amp.autocast
+        if device.type == 'cuda' and amp_dtype == torch.float16:
+            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
             loss_scaler = NativeScaler()
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
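
For context, a minimal standalone sketch of the behaviour this change implements. The helper name make_autocast_and_scaler and the direct use of torch.cuda.amp.GradScaler (rather than timm's NativeScaler wrapper) are illustrative assumptions, not code from this commit:

    from functools import partial

    import torch
    import torch.nn as nn


    def make_autocast_and_scaler(device_type='cuda', amp_dtype=torch.float16):
        """Pick an autocast context and, when needed, a grad scaler.

        Mirrors the diff above: torch.autocast only exists from PyTorch 1.10,
        so older versions fall back to the CUDA-only torch.cuda.amp.autocast.
        The scaler guards float16 against gradient underflow; bfloat16 has
        the same exponent range as float32, so it gets no scaler.
        """
        try:
            amp_autocast = partial(torch.autocast, device_type=device_type, dtype=amp_dtype)
        except (AttributeError, TypeError):
            # fallback to CUDA-only AMP for PyTorch < 1.10
            assert device_type == 'cuda'
            amp_autocast = torch.cuda.amp.autocast
        loss_scaler = None
        if device_type == 'cuda' and amp_dtype == torch.float16:
            loss_scaler = torch.cuda.amp.GradScaler()
        return amp_autocast, loss_scaler


    if __name__ == '__main__' and torch.cuda.is_available():
        model = nn.Linear(8, 2).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        amp_autocast, loss_scaler = make_autocast_and_scaler('cuda', torch.float16)

        inputs = torch.randn(4, 8, device='cuda')
        targets = torch.randint(0, 2, (4,), device='cuda')

        with amp_autocast():
            loss = nn.functional.cross_entropy(model(inputs), targets)

        if loss_scaler is not None:
            # float16 path: scale the loss so small gradients survive backward
            loss_scaler.scale(loss).backward()
            loss_scaler.step(optimizer)
            loss_scaler.update()
        else:
            # bfloat16 path: plain backward, no scaling required
            loss.backward()
            optimizer.step()

Note the try/except works because partial() evaluates torch.autocast eagerly: on PyTorch < 1.10 the attribute lookup itself raises AttributeError, so the CUDA-only context manager is selected instead.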