mirror of https://github.com/open-mmlab/mmengine.git
[Fix] Fix pytorch version compatibility of autocast (#339)

* fix unit test of autocast
* fix compatibility of the optimizer wrapper unit test
* clean code
* fix as per comments
* fix docstring

parent 5ac3c23338
commit 59b0ccfe6f
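
The core of the fix is dispatching on the installed PyTorch version: before 1.10 only ``torch.cuda.amp.autocast`` exists and it accepts nothing but ``enabled``, while from 1.10 onwards ``torch.autocast`` takes ``device_type``/``dtype``/``cache_enabled`` and also covers CPU. A minimal sketch of that dispatch, reusing the same ``TORCH_VERSION``/``digit_version`` helpers as the patch (``forward_with_amp`` is a hypothetical helper, not part of mmengine):

import torch

from mmengine.utils import TORCH_VERSION, digit_version


def forward_with_amp(model, data, enabled=True):
    # Hypothetical helper: pick whichever autocast API the installed torch has.
    if digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
        # torch >= 1.10: the generic context manager also supports CPU.
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        with torch.autocast(device_type=device_type, enabled=enabled):
            return model(data)
    else:
        # 1.5.0 <= torch < 1.10.0: only the CUDA-specific context manager
        # exists, and it only accepts `enabled`.
        with torch.cuda.amp.autocast(enabled=enabled):
            return model(data)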
@@ -1,16 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from contextlib import contextmanager
+from typing import Optional
 
 import torch
 
+from mmengine import print_log
 from mmengine.utils import TORCH_VERSION, digit_version
 
 
 @contextmanager
-def autocast(enabled: bool = True, **kwargs):
+def autocast(device_type: Optional[str] = None,
+             dtype: Optional[torch.dtype] = None,
+             enabled: bool = True,
+             cache_enabled: Optional[bool] = None):
     """A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``.
 
-    Pytorch 1.6.0 provide ``torch.cuda.amp.autocast`` for running in
+    Pytorch 1.5.0 provide ``torch.cuda.amp.autocast`` for running in
     mixed precision , and update it to ``torch.autocast`` in 1.10.0.
     Both interfaces have different arguments, and ``torch.autocast``
     support running with cpu additionally.
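
For reference, a short usage sketch of the widened signature (assuming ``autocast`` is importable from ``mmengine.runner``, PyTorch >= 1.10 and an available CUDA device; not part of the patch itself):

import torch
import torch.nn as nn

from mmengine.runner import autocast

layer = nn.Conv2d(1, 1, 1).cuda()
data = torch.randn(1, 1, 4, 4).cuda()

# Mixed precision on GPU with an explicit dtype.
with autocast(device_type='cuda', dtype=torch.float16):
    out = layer(data)
assert out.dtype == torch.float16

# Disabling the context keeps the computation in fp32.
with autocast(enabled=False, device_type='cuda'):
    out = layer(data)
assert out.dtype == torch.float32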
@@ -49,9 +55,13 @@ def autocast(enabled: bool = True, **kwargs):
         >>> pass
 
     Args:
+        device_type (str, required): Whether to use 'cuda' or 'cpu' device.
         enabled(bool): Whether autocasting should be enabled in the region.
             Defaults to True.
-        kwargs (dict): Arguments of torch.autocast except for ``enabled``.
+        dtype (torch_dtype, optional): Whether to use ``torch.float16`` or
+            ``torch.bfloat16``.
+        cache_enabled(bool, optional): Whether the weight cache inside
+            autocast should be enabled.
     """
     # If `enabled` is True, enable an empty context and all calculations
     # are performed under fp32.
@@ -63,9 +73,17 @@ def autocast(enabled: bool = True, **kwargs):
             digit_version('1.10.0')):
         # If pytorch version is between 1.5.0 and 1.10.0, the default value of
         # dtype for `torch.cuda.amp.autocast` is torch.float16.
-        assert not kwargs, (
-            f'autocast under pytorch {TORCH_VERSION} only accept `enabled` '
-            'arguments.')
+        assert device_type == 'cuda' or device_type is None, (
+            'Pytorch version under 1.5.0 only supports running automatic '
+            'mixed training with cuda')
+        if dtype is not None or cache_enabled is not None:
+            print_log(
+                f'{dtype} and {device_type} will not work for '
+                '`autocast` since your Pytorch version: '
+                f'{TORCH_VERSION} <= 1.10.0',
+                logger='current',
+                level=logging.WARNING)
+
         if torch.cuda.is_available():
             with torch.cuda.amp.autocast(enabled=enabled):
                 yield
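
In this pre-1.10 branch the new keyword arguments are tolerated rather than rejected: ``device_type`` must be ``'cuda'`` or ``None``, and ``dtype``/``cache_enabled`` are dropped with a logged warning instead of an error. A sketch of what a caller on an older torch (say 1.8, with a GPU) would observe; the version and exact warning text are illustrative:

import torch

from mmengine.runner import autocast

# Still enters torch.cuda.amp.autocast(enabled=True); the dtype request is
# ignored and a WARNING like "torch.float16 and cuda will not work for
# `autocast` since your Pytorch version: ..." is logged via print_log.
with autocast(device_type='cuda', dtype=torch.float16):
    pass

# Requesting the CPU on these versions trips the assert in the branch above.
# with autocast(device_type='cpu'):  # AssertionError
#     ...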
@@ -77,24 +95,23 @@ def autocast(enabled: bool = True, **kwargs):
                     'If pytorch versions is between 1.5.0 and 1.10, '
                     '`autocast` is only available in gpu mode')
 
-    elif (digit_version('1.11.0') > digit_version(TORCH_VERSION) >=
-          digit_version('1.10.0')):
-        if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
-        else:
-            kwargs.setdefault('device_type', 'cpu')
-            # torch.autocast only support `dtype=torch.bfloat16` in
-            # pytorch 1.10
-            kwargs.setdefault('dtype', torch.bfloat16)
-
-        with torch.autocast(enabled=enabled, **kwargs):
-            yield
-
-    elif digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
+    else:
         if torch.cuda.is_available():
-            kwargs.setdefault('device_type', 'cuda')
+            device_type = 'cuda' if device_type is None else device_type
         else:
-            kwargs.setdefault('device_type', 'cpu')
+            device_type = 'cpu' if device_type is None else device_type
 
-        with torch.autocast(enabled=enabled, **kwargs):
+        if digit_version(TORCH_VERSION) < digit_version('1.11.0'):
+            if dtype is not None and dtype != torch.bfloat16:
+                print_log(
+                    f'{dtype} must be `torch.bfloat16` with Pytorch '
+                    f'version: {TORCH_VERSION}',
+                    logger='current',
+                    level=logging.WARNING)
+            dtype = torch.bfloat16
+        with torch.autocast(
+                device_type=device_type,
+                enabled=enabled,
+                dtype=dtype,
+                cache_enabled=cache_enabled):
             yield
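
On PyTorch 1.10.x the rewritten branch still falls back to ``torch.bfloat16`` (warning if another dtype was requested), in line with the removed comment that ``torch.autocast`` only supports ``dtype=torch.bfloat16`` in 1.10; from 1.11 on the requested dtype is passed through unchanged. A small sketch of the CPU path (assuming PyTorch >= 1.10; not part of the patch):

import torch
import torch.nn as nn

from mmengine.runner import autocast

layer = nn.Linear(4, 4)
x = torch.randn(2, 4)

# CPU autocast; on a 1.10.x install the effective dtype is forced to bfloat16.
with autocast(device_type='cpu', dtype=torch.bfloat16):
    out = layer(x)
assert out.dtype == torch.bfloat16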
@@ -42,16 +42,16 @@ class TestAmp(unittest.TestCase):
         else:
             devices = ['cpu', 'cuda']
         for device in devices:
-            with autocast():
+            with autocast(device_type=device):
                 # torch.autocast support cpu and cuda mode.
                 layer = nn.Conv2d(1, 1, 1).to(device)
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
-                with autocast(enabled=False):
+                with autocast(enabled=False, device_type=device):
                     res = layer(torch.randn(1, 1, 1, 1).to(device))
                     self.assertEqual(res.dtype, torch.float32)
             # Test with fp32_enabled
-            with autocast(enabled=False):
+            with autocast(enabled=False, device_type=device):
                 layer = nn.Conv2d(1, 1, 1).to(device)
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertEqual(res.dtype, torch.float32)