Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix #1712: broken support for AMP w/ PyTorch < 1.10. Disable loss scaler for bfloat16
commit 43e6143bef
parent 3a636eee71
1 changed file: train.py, 8 additions and 2 deletions
@@ -514,8 +514,14 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        if device.type == 'cuda':
+        try:
+            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        except (AttributeError, TypeError):
+            # fallback to CUDA only AMP for PyTorch < 1.10
+            assert device.type == 'cuda'
+            amp_autocast = torch.cuda.amp.autocast
+        if device.type == 'cuda' and amp_dtype == torch.float16:
+            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
             loss_scaler = NativeScaler()
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
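
For reference, a minimal standalone sketch (not part of the commit) of the behavior the new branch produces. The helper name resolve_native_amp is hypothetical, and timm's NativeScaler is stood in for by torch.cuda.amp.GradScaler, which it wraps:

# Sketch only: hypothetical helper mirroring the diff's logic, assuming
# a CUDA device is available when the pre-1.10 fallback path is taken.
from functools import partial

import torch

def resolve_native_amp(device: torch.device, amp_dtype: torch.dtype):
    try:
        # torch.autocast is the device-agnostic API added in PyTorch 1.10;
        # on older releases the attribute lookup raises AttributeError
        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
    except (AttributeError, TypeError):
        # fall back to the CUDA-only context manager shipped before 1.10
        assert device.type == 'cuda'
        amp_autocast = torch.cuda.amp.autocast
    # float16 has only 5 exponent bits, so small gradients underflow to zero
    # without loss scaling; bfloat16 shares float32's 8-bit exponent range
    # and can skip the scaler entirely
    loss_scaler = None
    if device.type == 'cuda' and amp_dtype == torch.float16:
        loss_scaler = torch.cuda.amp.GradScaler()  # stand-in for NativeScaler
    return amp_autocast, loss_scaler

# Quick check of the underflow rationale behind the float16-only scaler:
g = torch.tensor(1e-8)
print(g.to(torch.float16))   # tensor(0., dtype=torch.float16) -- underflows
print(g.to(torch.bfloat16))  # ~1e-8, representable in bfloat16

The scaler stays None for bfloat16, so a training loop would call loss.backward() directly; only float16 runs need GradScaler's scale/unscale round trip to keep small gradients from flushing to zero.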