From 43e6143befbc3704bf99d7e2dc57e9e18d366dff Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 11 Mar 2023 15:26:09 -0800 Subject: [PATCH] Fix #1712 broken support for AMP w/ PyTorch < 1.10. Disable loss scaler for bfloat16 --- train.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 9f450ab8..816f4ae8 100755 --- a/train.py +++ b/train.py @@ -514,8 +514,14 @@ def main(): if utils.is_primary(args): _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') elif use_amp == 'native': - amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) - if device.type == 'cuda': + try: + amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) + except (AttributeError, TypeError): + # fallback to CUDA only AMP for PyTorch < 1.10 + assert device.type == 'cuda' + amp_autocast = torch.cuda.amp.autocast + if device.type == 'cuda' and amp_dtype == torch.float16: + # loss scaler only used for float16 (half) dtype, bfloat16 does not need it loss_scaler = NativeScaler() if utils.is_primary(args): _logger.info('Using native Torch AMP. Training in mixed precision.')