[Fix] make `autocast` compatible with GTX1660 and make it more robust. (#344)

* fix amp

* fix amp

* make logic consistent with torch.autocast

* support multiple device

* fix as comment

* fix as comment

* avoid circular import
Mashiro 2022-07-05 20:37:56 +08:00 committed by GitHub
parent a3d2916790
commit 96378fa748
1 changed file with 30 additions and 15 deletions
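
The heart of the compatibility fix is refusing a bfloat16 autocast dtype on GPUs that cannot execute it, such as the GTX 1660, instead of failing later inside the forward pass. A minimal standalone sketch of that check, independent of mmengine (the helper name pick_autocast_dtype is illustrative, not part of this patch):

    import torch

    def pick_autocast_dtype(device_type: str) -> torch.dtype:
        # Mirror torch.autocast defaults: float16 on CUDA, bfloat16 on CPU.
        if device_type == 'cuda':
            dtype = torch.get_autocast_gpu_dtype()
            # Reject bfloat16 on GPUs without bf16 support (e.g. GTX 1660)
            # up front rather than crashing in the middle of a forward pass.
            if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('This GPU does not support bfloat16; '
                                   'please use float16 instead.')
            return dtype
        if device_type == 'cpu':
            # CPU autocast only works with bfloat16.
            return torch.bfloat16
        raise ValueError(f'device_type must be cuda or cpu, got {device_type}')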


@@ -5,8 +5,9 @@ from typing import Optional

 import torch

-from mmengine import print_log
-from mmengine.utils import TORCH_VERSION, digit_version
+from ..device import get_device
+from ..logging import print_log
+from ..utils import TORCH_VERSION, digit_version


 @contextmanager
@@ -74,7 +75,7 @@ def autocast(device_type: Optional[str] = None,
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
         assert device_type == 'cuda' or device_type is None, (
-            'Pytorch version under 1.5.0 only supports running automatic '
+            'Pytorch version under 1.10.0 only supports running automatic '
             'mixed training with cuda')
         if dtype is not None or cache_enabled is not None:
             print_log(
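
On PyTorch older than 1.10 the device-generic torch.autocast does not exist, so this branch can only delegate to the CUDA-specific context manager, whose autocast dtype is fixed to float16; that is why dtype and cache_enabled are ignored with a warning here. Roughly, as a sketch of the fallback rather than the verbatim code (_legacy_autocast is an illustrative name):

    from contextlib import contextmanager

    import torch

    @contextmanager
    def _legacy_autocast(enabled: bool = True):
        # Pre-1.10 fallback: torch.cuda.amp.autocast always casts to float16,
        # so user-supplied dtype and cache_enabled cannot be applied.
        with torch.cuda.amp.autocast(enabled=enabled):
            yield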
@@ -96,19 +97,33 @@
                     '`autocast` is only available in gpu mode')
     else:
-        if torch.cuda.is_available():
-            device_type = 'cuda' if device_type is None else device_type
-        else:
-            device_type = 'cpu' if device_type is None else device_type
+        # Modified from https://github.com/pytorch/pytorch/blob/master/torch/amp/autocast_mode.py  # noqa: E501
+        # This code should update with the `torch.autocast`.
+        if cache_enabled is None:
+            cache_enabled = torch.is_autocast_cache_enabled()
+        device = get_device()
+        device_type = device if device_type is None else device_type
+
+        if device_type == 'cuda':
+            if dtype is None:
+                dtype = torch.get_autocast_gpu_dtype()
+            if dtype == torch.bfloat16 and not \
+                    torch.cuda.is_bf16_supported():
+                raise RuntimeError(
+                    'Current CUDA Device does not support bfloat16. Please '
+                    'switch dtype to float16.')
+        elif device_type == 'cpu':
+            if dtype is None:
+                dtype = torch.bfloat16
+            assert dtype == torch.bfloat16, (
+                'In CPU autocast, only support `torch.bfloat16` dtype')
+        else:
+            raise ValueError('User specified autocast device_type must be '
+                             f'cuda or cpu, but got {device_type}')

-        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
-            if dtype is not None and dtype != torch.bfloat16:
-                print_log(
-                    f'{dtype} must be `torch.bfloat16` with Pytorch '
-                    f'version: {TORCH_VERSION}',
-                    logger='current',
-                    level=logging.WARNING)
-                dtype = torch.bfloat16

         with torch.autocast(
                 device_type=device_type,
                 enabled=enabled,
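
Assuming the context manager is exported as mmengine.runner.autocast, as in later mmengine releases, typical usage after this patch looks like:

    import torch
    from mmengine.runner import autocast

    model = torch.nn.Linear(4, 2)
    data = torch.randn(8, 4)

    # device_type and dtype are resolved automatically via get_device(); on
    # CUDA the default stays float16, and an explicit bfloat16 request on
    # hardware without bf16 support (e.g. GTX 1660) now raises early.
    with autocast(enabled=True):
        out = model(data)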