Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Fix] Fix pytorch version compatibility of autocast (#339)

* fix unit test of autocast
* fix compatibility of unit test of optimizerwrapper
* clean code
* fix as comment
* fix docstring
This commit is contained in:
parent: 5ac3c23338
commit: 59b0ccfe6f
@@ -1,16 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from contextlib import contextmanager
+from typing import Optional
 
 import torch
 
+from mmengine import print_log
 from mmengine.utils import TORCH_VERSION, digit_version
 
 
 @contextmanager
-def autocast(enabled: bool = True, **kwargs):
+def autocast(device_type: Optional[str] = None,
+             dtype: Optional[torch.dtype] = None,
+             enabled: bool = True,
+             cache_enabled: Optional[bool] = None):
     """A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``.
 
-    Pytorch 1.6.0 provide ``torch.cuda.amp.autocast`` for running in
+    Pytorch 1.5.0 provide ``torch.cuda.amp.autocast`` for running in
     mixed precision , and update it to ``torch.autocast`` in 1.10.0.
     Both interfaces have different arguments, and ``torch.autocast``
     support running with cpu additionally.
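For context on the two interfaces the docstring contrasts, here is a minimal sketch that is not part of the commit (it assumes PyTorch >= 1.10 so that both spellings exist, and a CUDA device for the first block):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
data = torch.randn(2, 4)

# Older interface: CUDA only, dtype is implicitly torch.float16.
if torch.cuda.is_available():
    with torch.cuda.amp.autocast(enabled=True):
        out = layer.cuda()(data.cuda())  # float16 on autocast-covered ops

# Unified interface (PyTorch >= 1.10): device_type and dtype are explicit,
# and 'cpu' is supported (bfloat16 only on the versions this commit targets).
with torch.autocast(device_type='cpu', dtype=torch.bfloat16, enabled=True):
    out = layer(data)  # typically torch.bfloat16 for ops such as nn.Linear

The different keyword arguments of the two interfaces are exactly why the wrapper grows the `device_type`, `dtype` and `cache_enabled` parameters in this commit.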
@@ -49,9 +55,13 @@ def autocast(enabled: bool = True, **kwargs):
         >>> pass
 
     Args:
-        enabled (bool): Whether autocasting should be enabled in the region.
-            Defaults to True.
-        kwargs (dict): Arguments of torch.autocast except for ``enabled``.
+        device_type (str, required): Whether to use 'cuda' or 'cpu' device.
+        enabled(bool): Whether autocasting should be enabled in the region.
+            Defaults to True
+        dtype (torch_dtype, optional): Whether to use ``torch.float16`` or
+            ``torch.bfloat16``.
+        cache_enabled(bool, optional): Whether the weight cache inside
+            autocast should be enabled.
     """
     # If `enabled` is True, enable an empty context and all calculations
     # are performed under fp32.
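A short usage sketch of the new signature (illustrative only, assuming PyTorch >= 1.10 and that the wrapper is exposed as `from mmengine.runner import autocast`, which this diff does not show):

import torch
import torch.nn as nn

from mmengine.runner import autocast  # assumed import path

model = nn.Conv2d(1, 1, 1)
data = torch.randn(1, 1, 3, 3)

# Let the wrapper pick the device: 'cuda' when available, otherwise 'cpu'.
with autocast():
    out = model(data)

# Or pin every argument explicitly; dtype and cache_enabled are only
# honoured on PyTorch >= 1.10, as the branches below show.
with autocast(device_type='cpu', dtype=torch.bfloat16,
              enabled=True, cache_enabled=True):
    out = model(data)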
@@ -63,9 +73,17 @@ def autocast(enabled: bool = True, **kwargs):
           digit_version('1.10.0')):
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
-        assert not kwargs, (
-            f'autocast under pytorch {TORCH_VERSION} only accept `enabled` '
-            'arguments.')
+        assert device_type == 'cuda' or device_type is None, (
+            'Pytorch version under 1.5.0 only supports running automatic '
+            'mixed training with cuda')
+        if dtype is not None or cache_enabled is not None:
+            print_log(
+                f'{dtype} and {device_type} will not work for '
+                '`autocast` since your Pytorch version: '
+                f'{TORCH_VERSION} <= 1.10.0',
+                logger='current',
+                level=logging.WARNING)
+
         if torch.cuda.is_available():
             with torch.cuda.amp.autocast(enabled=enabled):
                 yield
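The branch above only runs on PyTorch between 1.5.0 and 1.10.0, where only `torch.cuda.amp.autocast` is available, so the new arguments are asserted or warned away rather than forwarded. A standalone sketch of the same version-gating idea, using `packaging` in place of mmengine's `digit_version` (an illustrative assumption, not code from this commit):

import torch
from packaging.version import parse

TORCH_VERSION = parse(torch.__version__.split('+')[0])

if parse('1.5.0') <= TORCH_VERSION < parse('1.10.0'):
    # Only the CUDA interface exists here; dtype and cache_enabled
    # cannot be configured, which is why the wrapper only warns about them.
    ctx = torch.cuda.amp.autocast(enabled=True)
else:
    # PyTorch >= 1.10: the unified interface also handles CPU.
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    ctx = torch.autocast(device_type=device_type, enabled=True)

with ctx:
    pass  # mixed-precision region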
@@ -77,24 +95,23 @@ def autocast(enabled: bool = True, **kwargs):
             'If pytorch versions is between 1.5.0 and 1.10, '
             '`autocast` is only available in gpu mode')
 
-    elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
-          digit_version('1.10.0')):
+    else:
         if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
+            device_type = 'cuda' if device_type is None else device_type
         else:
-            kwargs.setdefault('device_type', 'cpu')
-            # torch.autocast only support `dtype=torch.bfloat16` in
-            # pytorch 1.10
-            kwargs.setdefault('dtype', torch.bfloat16)
-
-        with torch.autocast(enabled=enabled, **kwargs):
-            yield
-
-    elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
-        if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
-        else:
-            kwargs.setdefault('device_type', 'cpu')
-
-        with torch.autocast(enabled=enabled, **kwargs):
+            device_type = 'cpu' if device_type is None else device_type
+
+        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
+            if dtype is not None and dtype != torch.bfloat16:
+                print_log(
+                    f'{dtype} must be `torch.bfloat16` with Pytorch '
+                    f'version: {TORCH_VERSION}',
+                    logger='current',
+                    level=logging.WARNING)
+            dtype = torch.bfloat16
+        with torch.autocast(
+                device_type=device_type,
+                enabled=enabled,
+                dtype=dtype,
+                cache_enabled=cache_enabled):
             yield
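On PyTorch >= 1.10 the rewritten `else` branch simply fills in defaults and forwards everything to `torch.autocast`. A simplified re-statement of that defaulting logic, for illustration only (the standard `warnings` module stands in for mmengine's `print_log`, and the exact placement of the bfloat16 fallback is inferred from the diff above):

import warnings

import torch


def resolve_autocast_args(device_type=None, dtype=None):
    # Pick the device the same way the wrapper does.
    if device_type is None:
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    # The wrapper restricts torch.autocast on the 1.10.x series to
    # torch.bfloat16, so any other request is warned about and replaced.
    if torch.__version__.startswith('1.10'):
        if dtype is not None and dtype != torch.bfloat16:
            warnings.warn(f'{dtype} replaced by torch.bfloat16 '
                          f'on PyTorch {torch.__version__}')
        dtype = torch.bfloat16
    return device_type, dtype


device_type, dtype = resolve_autocast_args()
with torch.autocast(device_type=device_type, dtype=dtype, enabled=True):
    pass  # mixed-precision region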
@@ -42,16 +42,16 @@ class TestAmp(unittest.TestCase):
         else:
             devices = ['cpu', 'cuda']
         for device in devices:
-            with autocast():
+            with autocast(device_type=device):
                 # torch.autocast support cpu and cuda mode.
                 layer = nn.Conv2d(1, 1, 1).to(device)
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
-            with autocast(enabled=False):
+            with autocast(enabled=False, device_type=device):
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertEqual(res.dtype, torch.float32)
             # Test with fp32_enabled
-            with autocast(enabled=False):
+            with autocast(enabled=False, device_type=device):
                 layer = nn.Conv2d(1, 1, 1).to(device)
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertEqual(res.dtype, torch.float32)
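To exercise the updated test locally, a condensed sketch of the CPU-side checks is shown below (assumptions: PyTorch >= 1.10 and the import path `from mmengine.runner import autocast`; the test file's name is not visible in this extract):

import unittest

import torch
import torch.nn as nn

from mmengine.runner import autocast  # assumed import path


class TestAutocastSketch(unittest.TestCase):
    """Condensed restatement of the CPU part of the test above."""

    def test_cpu_autocast(self):
        layer = nn.Conv2d(1, 1, 1)
        with autocast(device_type='cpu'):
            res = layer(torch.randn(1, 1, 1, 1))
        self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
        with autocast(enabled=False, device_type='cpu'):
            res = layer(torch.randn(1, 1, 1, 1))
        self.assertEqual(res.dtype, torch.float32)


if __name__ == '__main__':
    unittest.main()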