[Fix] make `autocast` compatible with GTX1660 and make it more robust. (#344)
* fix amp * fix amp * make logic consistent with torch.autocast * support multiple devices * fix as comment * fix as comment * avoid circular import (pull/354/head)
parent
a3d2916790
commit
96378fa748
|
@ -5,8 +5,9 @@ from typing import Optional
|
|||
|
||||
import torch
|
||||
|
||||
from mmengine import print_log
|
||||
from mmengine.utils import TORCH_VERSION, digit_version
|
||||
from ..device import get_device
|
||||
from ..logging import print_log
|
||||
from ..utils import TORCH_VERSION, digit_version
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
@ -74,7 +75,7 @@ def autocast(device_type: Optional[str] = None,
|
|||
# If pytorch version is between 1.5.0 and 1.10.0, the default value of
|
||||
# dtype for `torch.cuda.amp.autocast` is torch.float16.
|
||||
assert device_type == 'cuda' or device_type is None, (
|
||||
'Pytorch version under 1.5.0 only supports running automatic '
|
||||
'Pytorch version under 1.10.0 only supports running automatic '
|
||||
'mixed training with cuda')
|
||||
if dtype is not None or cache_enabled is not None:
|
||||
print_log(
|
||||
|
@ -96,19 +97,33 @@ def autocast(device_type: Optional[str] = None,
|
|||
'`autocast` is only available in gpu mode')
|
||||
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
device_type = 'cuda' if device_type is None else device_type
|
||||
else:
|
||||
device_type = 'cpu' if device_type is None else device_type
|
||||
# Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py # noqa: E501
|
||||
# This code should update with the `torch.autocast`.
|
||||
if cache_enabled is None:
|
||||
cache_enabled = torch.is_autocast_cache_enabled()
|
||||
device = get_device()
|
||||
device_type = device if device_type is None else device_type
|
||||
|
||||
if device_type == 'cuda':
|
||||
if dtype is None:
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
|
||||
if dtype == torch.bfloat16 and not \
|
||||
torch.cuda.is_bf16_supported():
|
||||
raise RuntimeError(
|
||||
'Current CUDA Device does not support bfloat16. Please '
|
||||
'switch dtype to float16.')
|
||||
|
||||
elif device_type == 'cpu':
|
||||
if dtype is None:
|
||||
dtype = torch.bfloat16
|
||||
assert dtype == torch.bfloat16, (
|
||||
'In CPU autocast, only support `torch.bfloat16` dtype')
|
||||
|
||||
else:
|
||||
raise ValueError('User specified autocast device_type must be '
|
||||
F'cuda or cpu, but got {device_type}')
|
||||
|
||||
if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
|
||||
if dtype is not None and dtype != torch.bfloat16:
|
||||
print_log(
|
||||
f'{dtype} must be `torch.bfloat16` with Pytorch '
|
||||
f'version: {TORCH_VERSION}',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
dtype = torch.bfloat16
|
||||
with torch.autocast(
|
||||
device_type=device_type,
|
||||
enabled=enabled,
|
||||
|
|
Loading…
Reference in New Issue