diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py
index 3c91f210..31fd176b 100644
--- a/mmengine/runner/amp.py
+++ b/mmengine/runner/amp.py
@@ -5,8 +5,9 @@ from typing import Optional
 
 import torch
 
-from mmengine import print_log
-from mmengine.utils import TORCH_VERSION, digit_version
+from ..device import get_device
+from ..logging import print_log
+from ..utils import TORCH_VERSION, digit_version
 
 
 @contextmanager
@@ -74,7 +75,7 @@ def autocast(device_type: Optional[str] = None,
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
         assert device_type == 'cuda' or device_type is None, (
-            'Pytorch version under 1.5.0 only supports running automatic '
+            'Pytorch version under 1.10.0 only supports running automatic '
             'mixed training with cuda')
         if dtype is not None or cache_enabled is not None:
             print_log(
@@ -96,19 +97,33 @@
                     '`autocast` is only available in gpu mode')
     else:
-        if torch.cuda.is_available():
-            device_type = 'cuda' if device_type is None else device_type
-        else:
-            device_type = 'cpu' if device_type is None else device_type
+        # Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py  # noqa: E501
+        # This code should update with the `torch.autocast`.
+        if cache_enabled is None:
+            cache_enabled = torch.is_autocast_cache_enabled()
+        device = get_device()
+        device_type = device if device_type is None else device_type
+
+        if device_type == 'cuda':
+            if dtype is None:
+                dtype = torch.get_autocast_gpu_dtype()
+
+            if dtype == torch.bfloat16 and not \
+                    torch.cuda.is_bf16_supported():
+                raise RuntimeError(
+                    'Current CUDA Device does not support bfloat16. Please '
+                    'switch dtype to float16.')
+
+        elif device_type == 'cpu':
+            if dtype is None:
+                dtype = torch.bfloat16
+            assert dtype == torch.bfloat16, (
+                'In CPU autocast, only support `torch.bfloat16` dtype')
+
+        else:
+            raise ValueError('User specified autocast device_type must be '
+                             f'cuda or cpu, but got {device_type}')
 
-        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
-            if dtype is not None and dtype != torch.bfloat16:
-                print_log(
-                    f'{dtype} must be `torch.bfloat16` with Pytorch '
-                    f'version: {TORCH_VERSION}',
-                    logger='current',
-                    level=logging.WARNING)
-            dtype = torch.bfloat16
 
         with torch.autocast(
                 device_type=device_type,
                 enabled=enabled,
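
For reference, a minimal usage sketch of the patched context manager, assuming `autocast` is exported from `mmengine.runner` as in the upstream package (the model and tensor shapes are illustrative only): with no `device_type` argument, the new branch resolves the backend via `get_device()` rather than probing `torch.cuda.is_available()`, and on CPU it only accepts `torch.bfloat16`.

```python
import torch

from mmengine.runner import autocast  # assumed public export of the patched API

model = torch.nn.Linear(4, 4)  # illustrative model
x = torch.randn(8, 4)

# Device type omitted: the patched code derives it from get_device(),
# so the same call works on CUDA and CPU hosts.
with autocast():
    y = model(x)

# Explicit CPU autocast: only torch.bfloat16 is accepted here; passing
# float16 would trip the assertion added in this diff.
with autocast(device_type='cpu', dtype=torch.bfloat16):
    y = model(x)
```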