[Feature] Enable bf16 in AmpOptimWrapper (#960)

* support bf16 in AmpOptimWrapper

* add docstring

* modify docs

* add unittests for bf16 in AmpOptimWrapper

* fix type

* fix to pass ci

* fix ut skip logic to pass ci

* fix as comment

* add type hints

* fix docstring and add warning information

* remove check for pytorch>=1.6 in unittest

* modify unittest

* modify unittest

* remove torch.float32 && torch.float64 from valid dtypes

* fix as comments

* minor refine docstring

* fix unittest parameterized to pass CI

* fix unittest && add back torch.float32, torch.float64
Qian Zhao 2023-03-01 21:35:18 +08:00 committed by GitHub
parent 8a407ca214
commit 2ed8e343a0
4 changed files with 101 additions and 41 deletions


@@ -60,7 +60,7 @@ MMEngine supports training models with CPU, single GPU, multiple GPUs in single
 ## Mixed Precision Training
-Nvidia introduced the Tensor Core unit into the Volta and Turing architectures to support FP32 and FP16 mixed precision computing. With automatic mixed precision training enabled, some operators operate at FP16 and the rest operate at FP32, which reduces training time and storage requirements without changing the model or degrading its training precision, thus supporting training with larger batch sizes, larger models, and larger input sizes.
+Nvidia introduced the Tensor Core unit into the Volta and Turing architectures to support FP32 and FP16 mixed precision computing. The Ampere architecture further adds BF16 support. With automatic mixed precision training enabled, some operators operate at FP16/BF16 and the rest operate at FP32, which reduces training time and storage requirements without changing the model or degrading its training precision, thus supporting training with larger batch sizes, larger models, and larger input sizes.
 [PyTorch officially supports amp from 1.6](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/). If you are interested in the implementation of automatic mixed precision, you can refer to [Mixed Precision Training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html).
@@ -71,8 +71,16 @@ runner = Runner(
     model=ResNet18(),
     work_dir='./work_dir',
     train_dataloader=train_dataloader_cfg,
-    optim_wrapper=dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
+    optim_wrapper=dict(
+        type='AmpOptimWrapper',
+        # If you want to use bfloat16, uncomment the following line
+        # dtype='bfloat16',  # valid values: ('float16', 'bfloat16', None)
+        optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
     train_cfg=dict(by_epoch=True, max_epochs=3),
 )
 runner.train()
 ```
+
+```{warning}
+Up to PyTorch 1.13, `torch.bfloat16` performance on `Convolution` is poor unless the environment variable `TORCH_CUDNN_V8_API_ENABLED=1` is set manually. More context at this [PyTorch issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767).
+```
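For readers who hit the warning above: the flag is a plain environment variable, so it can be exported in the shell before launching training or set from Python before the first CUDA convolution runs. The snippet below is a minimal sketch, not part of this commit, and `train.py` is only a placeholder entry point:

```python
import os

# Equivalent shell form: TORCH_CUDNN_V8_API_ENABLED=1 python train.py
# Setting the flag before torch/cuDNN is first used is the safest option.
os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1'

import torch  # noqa: E402  (imported after the flag is set on purpose)

print(torch.backends.cudnn.version())  # sanity check that cuDNN is present
```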


@@ -61,7 +61,7 @@ MMEngine supports training on CPU, a single GPU, multiple GPUs on one machine, and multiple machines. When the environment
 ## Mixed Precision Training
-Nvidia introduced the Tensor Core unit in the Volta and Turing architectures to support FP32 and FP16 mixed precision computing. With automatic mixed precision training enabled, some operators run at FP16 and the rest at FP32, which shortens training time and lowers storage requirements without changing the model or degrading its training accuracy, thus enabling training with larger batch sizes, larger models, and larger input sizes.
+Nvidia introduced the Tensor Core unit in the Volta and Turing architectures to support FP32 and FP16 mixed precision computing. The Ampere architecture further adds BF16 support. With automatic mixed precision training enabled, some operators run at FP16/BF16 and the rest at FP32, which shortens training time and lowers storage requirements without changing the model or degrading its training accuracy, thus enabling training with larger batch sizes, larger models, and larger input sizes.
 [PyTorch has officially supported amp since 1.6](https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/). If you are interested in how automatic mixed precision is implemented, you can read [torch.cuda.amp: automatic mixed precision explained](https://zhuanlan.zhihu.com/p/348554267).
@@ -72,8 +72,16 @@ runner = Runner(
     model=ResNet18(),
     work_dir='./work_dir',
     train_dataloader=train_dataloader_cfg,
-    optim_wrapper=dict(type='AmpOptimWrapper', optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
+    optim_wrapper=dict(
+        type='AmpOptimWrapper',
+        # If you want to use BF16, uncomment the following line
+        # dtype='bfloat16',  # valid values: ('float16', 'bfloat16', None)
+        optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
     train_cfg=dict(by_epoch=True, max_epochs=3),
 )
 runner.train()
 ```
+
+```{warning}
+As of PyTorch 1.13, using `torch.bfloat16` directly in `Convolution` performs poorly; the environment variable `TORCH_CUDNN_V8_API_ENABLED=1` must be set manually to enable the cuDNN BF16 convolution. See the discussion in this [PyTorch issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767).
+```


@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from typing import Union

 import torch
 import torch.nn as nn
@@ -38,15 +39,30 @@ class AmpOptimWrapper(OptimWrapper):
             - float: Initialize GradScaler with ``init_scale``.
             - dict: Initialize GradScaler with more detail configuration.

+        dtype (str or torch.dtype, optional): The data type to autocast in amp.
+            If a ``str`` is given, it will be converted to ``torch.dtype``.
+            Valid ``str`` format are `'float16'`, `'bfloat16'`, `'float32'` and
+            `'float64'`. If set to ``None``, the default data type will be used.
+            Defaults to None.
+            `New in version 0.6.1.`
         **kwargs: Keyword arguments passed to OptimWrapper.

+    Warnings:
+        ``dtype`` argument is only available with PyTorch version >= 1.10.0. If
+        you use PyTorch of an older version, it will be ignored.
+
     Note:
         If you use ``IterBasedRunner`` and enable gradient accumulation,
         the original `max_iters` should be multiplied by
         ``accumulative_counts``.
     """

-    def __init__(self, loss_scale='dynamic', **kwargs):
+    valid_dtypes = ('float16', 'bfloat16', 'float32', 'float64')
+
+    def __init__(self,
+                 loss_scale: str = 'dynamic',
+                 dtype: Union[str, torch.dtype] = None,
+                 **kwargs):
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
         assert is_cuda_available() or is_npu_available(), (
@@ -68,6 +84,16 @@ class AmpOptimWrapper(OptimWrapper):
             raise TypeError('loss_scale must be of type float, dict, or '
                             f'"dynamic", but got {loss_scale}')

+        # convert string value to torch.dtype
+        if isinstance(dtype, str):
+            assert dtype in self.valid_dtypes, (
+                f'dtype should be any of {self.valid_dtypes}, got {dtype}')
+            dtype = getattr(torch, dtype)
+
+        assert dtype is None or isinstance(dtype, torch.dtype), (
+            f'dtype should be None or instance of torch.dtype, got {dtype}')
+        self.cast_dtype = dtype
+
     def backward(self, loss: torch.Tensor, **kwargs):
         """Perform gradient back propagation with :attr:`loss_scaler`.
@@ -133,5 +159,5 @@ class AmpOptimWrapper(OptimWrapper):
             model (nn.Module): The training model.
         """
         from mmengine.runner.amp import autocast
-        with super().optim_context(model), autocast():
+        with super().optim_context(model), autocast(dtype=self.cast_dtype):
             yield
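For orientation, here is a minimal sketch (not part of this diff) of how the new `dtype` argument is used when driving the wrapper by hand; the toy convolution, input, and learning rate are placeholders, and a CUDA device with BF16 support is assumed:

```python
import torch
import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import AmpOptimWrapper

model = nn.Conv2d(1, 1, 1).cuda()       # placeholder model
optimizer = SGD(model.parameters(), lr=0.1)

# 'bfloat16' is validated against `valid_dtypes`, converted to torch.bfloat16
# and forwarded to autocast() inside optim_context() via `cast_dtype`.
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype='bfloat16')

x = torch.randn(1, 1, 1, 1).cuda()
with amp_optim_wrapper.optim_context(model):
    y = model(x)                        # runs under autocast(dtype=torch.bfloat16)
    loss = y.float().sum()              # toy loss, cast back to fp32
amp_optim_wrapper.backward(loss)        # scales the loss through the GradScaler
amp_optim_wrapper.step()                # GradScaler.step() followed by update()
amp_optim_wrapper.zero_grad()
```

In regular training this bookkeeping is handled by the `Runner` through the `optim_wrapper` config shown in the documentation diff above.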


@@ -4,10 +4,10 @@ import unittest
 from unittest import TestCase
 from unittest.mock import MagicMock

-import pytest
 import torch
 import torch.distributed as torch_dist
 import torch.nn as nn
+from parameterized import parameterized
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD, Adam, Optimizer
@@ -17,8 +17,6 @@ from mmengine.logging import MessageHub, MMLogger
 from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
-from mmengine.utils import digit_version
-from mmengine.utils.dl_utils import TORCH_VERSION

 is_apex_available = False
 try:
@@ -27,6 +25,17 @@
 except ImportError:
     pass

+amp_valid_dtypes = ['float64', 'float32', 'float16', 'bfloat16', None]
+torch_dtypes = [
+    torch.float16 if dtype is None else getattr(torch, dtype)
+    for dtype in amp_valid_dtypes
+]
+
+
+def bf16_supported() -> bool:
+    return (hasattr(torch.cuda, 'is_bf16_supported')
+            and torch.cuda.is_bf16_supported())
+
+
 class ToyModel(nn.Module):
@@ -196,7 +205,7 @@
     # TODO: This unit test could cause CI to fail with some probability, which
     # is caused by MultiProcessTestCase. This problem should be solved
     # in the future).
-    @pytest.mark.skipif(True, reason='Solved in the future')
+    @unittest.skipIf(True, reason='Solved in the future')
     def test_clip_grads(self):
         # Test `clip_grad` with `clip_norm_`
         optim_wrapper = OptimWrapper(
@@ -392,10 +401,8 @@
         self.optimizer = SGD(self.model.parameters(), lr=0.1)

     @unittest.skipIf(
-        not torch.cuda.is_available()
-        and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
-        reason='`torch.cuda.amp` is only available when pytorch-gpu version '
-        '>= 1.6')
+        not torch.cuda.is_available(),
+        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_init(self):
         # Test with default arguments.
         amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
@@ -407,6 +414,16 @@
         self.assertIsNone(amp_optim_wrapper._scale_update_param)
         self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler)

+        # Test with dtype float16
+        amp_optim_wrapper = AmpOptimWrapper(
+            dtype='float16', optimizer=self.optimizer)
+        self.assertIs(amp_optim_wrapper.cast_dtype, torch.float16)
+
+        # Test with dtype bfloat16
+        amp_optim_wrapper = AmpOptimWrapper(
+            dtype='bfloat16', optimizer=self.optimizer)
+        self.assertIs(amp_optim_wrapper.cast_dtype, torch.bfloat16)
+
         # Test with dict loss_scale.
         amp_optim_wrapper = AmpOptimWrapper(
             dict(init_scale=1, growth_factor=2), optimizer=self.optimizer)
@@ -416,14 +433,15 @@
                                     'loss_scale must be of type float'):
             AmpOptimWrapper(optimizer=self.optimizer, loss_scale='unknown')

+    @parameterized.expand(list(zip(amp_valid_dtypes)))
     @unittest.skipIf(
-        not torch.cuda.is_available()
-        and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
-        reason='`torch.cuda.amp` is only available when pytorch-gpu version '
-        '>= 1.6')
-    def test_step(self):
+        not torch.cuda.is_available(),
+        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
+    def test_step(self, dtype):
+        if dtype == 'bfloat16' and not bf16_supported():
+            raise unittest.SkipTest('bfloat16 not supported by device')
         optimizer = MagicMock(spec=Optimizer)
-        amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)
+        amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer, dtype=dtype)
         amp_optim_wrapper.loss_scaler = MagicMock()
         amp_optim_wrapper.step()
         amp_optim_wrapper.loss_scaler.step.assert_called_with(
@@ -431,13 +449,15 @@
         amp_optim_wrapper.loss_scaler.update.assert_called_with(
             amp_optim_wrapper._scale_update_param)

+    @parameterized.expand(list(zip(amp_valid_dtypes)))
     @unittest.skipIf(
-        not torch.cuda.is_available()
-        and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
-        reason='`torch.cuda.amp` is only available when pytorch-gpu version '
-        '>= 1.6')
-    def test_backward(self):
-        amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
+        not torch.cuda.is_available(),
+        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
+    def test_backward(self, dtype):
+        if dtype == 'bfloat16' and not bf16_supported():
+            raise unittest.SkipTest('bfloat16 not supported by device')
+        amp_optim_wrapper = AmpOptimWrapper(
+            optimizer=self.optimizer, dtype=dtype)
         loss_scaler = MagicMock()
         scale_return = MagicMock()
         scale_fn = MagicMock(return_value=scale_return)
@@ -449,10 +469,8 @@
         scale_return.backward.assert_called_with()

     @unittest.skipIf(
-        not torch.cuda.is_available()
-        and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
-        reason='`torch.cuda.amp` is only available when pytorch-gpu version '
-        '>= 1.6')
+        not torch.cuda.is_available(),
+        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_state_dict(self):
         self.model = self.model.cuda()
         amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
@@ -468,10 +486,8 @@
                              amp_optim_wrapper.loss_scaler.state_dict())

     @unittest.skipIf(
-        not torch.cuda.is_available()
-        and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
-        reason='`torch.cuda.amp` is only available when pytorch-gpu version '
-        '>= 1.6')
+        not torch.cuda.is_available(),
+        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_load_state_dict(self):
         amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
         self.model = self.model.cuda()
@@ -491,14 +507,16 @@
         self.assertDictEqual(amp_optim_wrapper.loss_scaler.state_dict(),
                              amp_optim_wrapper_.loss_scaler.state_dict())

+    @parameterized.expand(list(zip(amp_valid_dtypes, torch_dtypes)))
     @unittest.skipIf(
-        not torch.cuda.is_available()
-        and (digit_version(TORCH_VERSION) >= digit_version('1.6.0')),
-        reason='`torch.cuda.amp` is only available when pytorch-gpu version '
-        '>= 1.6')
-    def test_optim_context(self):
-        amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer)
+        not torch.cuda.is_available(),
+        reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
+    def test_optim_context(self, dtype, target_dtype):
+        if dtype == 'bfloat16' and not bf16_supported():
+            raise unittest.SkipTest('bfloat16 not supported by device')
+        amp_optim_wrapper = AmpOptimWrapper(
+            optimizer=self.optimizer, dtype=dtype)
         with amp_optim_wrapper.optim_context(self.model):
             x = torch.randn(1, 1, 1, 1).cuda()
             y = nn.Conv2d(1, 1, 1).cuda()(x)
-        self.assertEqual(y.dtype, torch.float16)
+        self.assertEqual(y.dtype, target_dtype)
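A side note on the parametrization used above, since the `zip` call can look odd at first glance: `parameterized.expand` expects one tuple of arguments per generated test case, and zipping a single list produces exactly that, one 1-tuple per dtype. A quick illustration (plain Python, no test framework required):

```python
amp_valid_dtypes = ['float64', 'float32', 'float16', 'bfloat16', None]

# zip() over one iterable wraps each entry in a 1-tuple, matching the
# "one tuple of positional args per test case" shape parameterized.expand takes.
print(list(zip(amp_valid_dtypes)))
# -> [('float64',), ('float32',), ('float16',), ('bfloat16',), (None,)]
```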