[Fix] make `autocast` compatible with GTX1660 and make it more robust. (#344)
* fix amp
* fix amp
* make logic consistent with torch.autocast
* support multiple device
* fix as comment
* fix as comment
* avoid circle import
parent a3d2916790
commit 96378fa748
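For orientation, here is a rough usage sketch of the patched `autocast` context manager. It is not part of the commit: it assumes `autocast` is importable from `mmengine.runner`, uses a made-up toy model, and shows the case this fix targets, a card such as the GTX 1660 that reports no bfloat16 support, so float16 is requested explicitly.

# Sketch only, not from the diff. Assumes `mmengine.runner` exports `autocast`;
# the model and data below are placeholder examples.
import torch
import torch.nn as nn
from mmengine.runner import autocast

model = nn.Linear(4, 2).cuda()
data = torch.randn(8, 4).cuda()

# On GPUs without bfloat16 support (e.g. GTX 1660), request float16 explicitly;
# with this patch, asking for bfloat16 there fails fast with a RuntimeError.
with autocast(dtype=torch.float16):
    out = model(data)
print(out.dtype)  # torch.float16 inside the autocast region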
@@ -5,8 +5,9 @@ from typing import Optional

 import torch

-from mmengine import print_log
-from mmengine.utils import TORCH_VERSION, digit_version
+from ..device import get_device
+from ..logging import print_log
+from ..utils import TORCH_VERSION, digit_version


 @contextmanager
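The import changes above are what the "avoid circle import" message refers to: `print_log` and the version helpers now come from package-relative modules, and `get_device` is pulled in for the new default-device logic in the last hunk. A quick illustrative snippet (not from the diff) of what `get_device` returns, assuming the public `mmengine.device` path:

# Illustrative only. get_device() reports the accelerator mmengine detects,
# which the patched autocast uses as the default device_type.
from mmengine.device import get_device

print(get_device())  # e.g. 'cuda' on a GPU machine, otherwise 'cpu'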
@@ -74,7 +75,7 @@ def autocast(device_type: Optional[str] = None,
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
         assert device_type == 'cuda' or device_type is None, (
-            'Pytorch version under 1.5.0 only supports running automatic '
+            'Pytorch version under 1.10.0 only supports running automatic '
             'mixed training with cuda')
         if dtype is not None or cache_enabled is not None:
             print_log(
@@ -96,19 +97,33 @@
                     '`autocast` is only available in gpu mode')

     else:
-        if torch.cuda.is_available():
-            device_type = 'cuda' if device_type is None else device_type
-        else:
-            device_type = 'cpu' if device_type is None else device_type
+        # Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py # noqa: E501
+        # This code should update with the `torch.autocast`.
+        if cache_enabled is None:
+            cache_enabled = torch.is_autocast_cache_enabled()
+        device = get_device()
+        device_type = device if device_type is None else device_type
+
+        if device_type == 'cuda':
+            if dtype is None:
+                dtype = torch.get_autocast_gpu_dtype()
+
+            if dtype == torch.bfloat16 and not \
+                    torch.cuda.is_bf16_supported():
+                raise RuntimeError(
+                    'Current CUDA Device does not support bfloat16. Please '
+                    'switch dtype to float16.')
+
+        elif device_type == 'cpu':
+            if dtype is None:
+                dtype = torch.bfloat16
+            assert dtype == torch.bfloat16, (
+                'In CPU autocast, only support `torch.bfloat16` dtype')
+
+        else:
+            raise ValueError('User specified autocast device_type must be '
+                             F'cuda or cpu, but got {device_type}')

-        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
-            if dtype is not None and dtype != torch.bfloat16:
-                print_log(
-                    f'{dtype} must be `torch.bfloat16` with Pytorch '
-                    f'version: {TORCH_VERSION}',
-                    logger='current',
-                    level=logging.WARNING)
-                dtype = torch.bfloat16
         with torch.autocast(
                 device_type=device_type,
                 enabled=enabled,
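A small sketch of the new guard's behaviour (not part of the commit), assuming a CUDA device where `torch.cuda.is_bf16_supported()` returns False, which is the GTX 1660 case from the title:

# Sketch only. On a card without bfloat16 support, the patched autocast now
# rejects a bfloat16 request up front instead of failing inside torch.autocast.
import torch
from mmengine.runner import autocast

if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
    try:
        with autocast(dtype=torch.bfloat16):
            pass
    except RuntimeError as err:
        print(err)  # Current CUDA Device does not support bfloat16. ...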