diff --git a/train.py b/train.py
index 9f450ab8..816f4ae8 100755
--- a/train.py
+++ b/train.py
@@ -514,8 +514,14 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        if device.type == 'cuda':
+        try:
+            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        except (AttributeError, TypeError):
+            # fallback to CUDA only AMP for PyTorch < 1.10
+            assert device.type == 'cuda'
+            amp_autocast = torch.cuda.amp.autocast
+        if device.type == 'cuda' and amp_dtype == torch.float16:
+            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
             loss_scaler = NativeScaler()
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')