[Feature] Support torch.compile since PyTorch2.0 (#976)

* enable compile configurations to support torch.compile in Runner

* enable compilation in train, val and test

* fix as comments

* add docstring to illustrate usage

* minor refine error message

* add unittests

* fix ut skip

* add logging message to inform users

* compile `train_step`, `val_step`, `test_step` instead

* fix as comments

* revert to compile `train_step` only due to pt2 issue

* add documentation about torch.compile
Qian Zhao 2023-03-12 18:26:43 +08:00 committed by GitHub
parent 6ea23a2f71
commit 0d25625ba2
4 changed files with 150 additions and 1 deletions

@ -84,3 +84,32 @@ runner.train()
```{warning}
Up to PyTorch 1.13, `torch.bfloat16` performs poorly on `Convolution` unless you manually set the environment variable `TORCH_CUDNN_V8_API_ENABLED=1`. For more context, see this [PyTorch issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767).
```
## Model Compilation
PyTorch introduced [torch.compile](https://pytorch.org/docs/2.0/dynamo/get-started.html) in its 2.0 release. It compiles your model to speed up training and validation. This feature is available since MMEngine v0.7.0 and is enabled by passing `Runner` an extra `cfg` dict containing the `compile` key:
```python
runner = Runner(
    model=ResNet18(),
    ...  # other arguments you want
    cfg=dict(compile=True)
)
```
For advanced usage, you can also change compile options as illustrated in [torch.compile API Documentation](https://pytorch.org/docs/2.0/generated/torch.compile.html#torch-compile). For example:
```python
compile_options = dict(backend='inductor', mode='max-autotune')
runner = Runner(
    model=ResNet18(),
    ...  # other arguments you want
    cfg=dict(compile=compile_options)
)
```
This feature is only available for PyTorch >= 2.0.0.
```{warning}
`torch.compile` is still under development by the PyTorch team, and some models may fail to compile. If you encounter errors during compilation, refer to the [PyTorch Dynamo FAQ](https://pytorch.org/docs/2.0/dynamo/faq.html) for quick fixes, or to [TorchDynamo Troubleshooting](https://pytorch.org/docs/2.0/dynamo/troubleshooting.html) for how to post an issue in PyTorch.
```
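Because the option is only honored on PyTorch >= 2.0.0, a config can be built so that `compile` is set only when the installed version supports it. The following is a minimal sketch, not part of MMEngine: `compile_supported` and `build_cfg` are hypothetical helpers, and the version string is passed in explicitly (in practice it would be `torch.__version__`).

```python
import re


def compile_supported(torch_version, minimum=(2, 0, 0)):
    """Hypothetical helper: True if `torch_version` is at least `minimum`."""
    # Keep only the leading numeric release, e.g. '2.0.0+cu117' -> (2, 0, 0).
    # Pre-release suffixes are ignored in this sketch.
    match = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', torch_version)
    if match is None:
        raise ValueError(f'Unrecognized version string: {torch_version!r}')
    parts = tuple(int(p) if p is not None else 0 for p in match.groups())
    return parts >= minimum


def build_cfg(torch_version, compile_options=True):
    """Hypothetical helper: add `compile` to the cfg dict only if supported."""
    cfg = dict()
    if compile_supported(torch_version):
        cfg['compile'] = compile_options
    return cfg
```

Comparing integer tuples rather than raw strings matters here: as strings, `'1.13.1'` would sort before `'1.9.0'`.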

@ -85,3 +85,32 @@ runner.train()
```{warning}
As of PyTorch 1.13, using `torch.bfloat16` directly in `Convolution` performs poorly; you must manually set the environment variable `TORCH_CUDNN_V8_API_ENABLED=1` to enable the cuDNN BF16 Convolution. See the related discussion in this [PyTorch Issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767).
```
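## Model Compilation
PyTorch 2.0 introduced the new [torch.compile](https://pytorch.org/docs/2.0/dynamo/get-started.html) feature, which speeds up training and validation by compiling the model. MMEngine supports this feature since v0.7.0. You can enable model compilation by passing a dict containing the `compile` key to the `cfg` argument of `Runner`: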
```python
runner = Runner(
    model=ResNet18(),
    ...  # your other Runner arguments
    cfg=dict(compile=True)
)
```
In addition, you can pass more compile options. For all available options, see the [torch.compile API documentation](https://pytorch.org/docs/2.0/generated/torch.compile.html#torch-compile):
```python
compile_options = dict(backend='inductor', mode='max-autotune')
runner = Runner(
    model=ResNet18(),
    ...  # your other Runner arguments
    cfg=dict(compile=compile_options)
)
```
This feature is only available when PyTorch >= 2.0.0 is installed.
```{warning}
`torch.compile` is still under active development by the PyTorch team, and some models may fail to compile. If you run into such problems, consult the [PyTorch Dynamo FAQ](https://pytorch.org/docs/2.0/dynamo/faq.html) for common fixes, or refer to [TorchDynamo Troubleshooting](https://pytorch.org/docs/2.0/dynamo/troubleshooting.html) to file an issue with PyTorch.
```

@ -180,6 +180,14 @@ class Runner:
        cfg (dict or ConfigDict or :obj:`Config`, optional): Full config.
            Defaults to None.
    Note:
        Since PyTorch 2.0.0, you can enable ``torch.compile`` by passing in
        ``cfg.compile = True``. If you want to control compile options, you
        can pass a dict, e.g. ``cfg.compile = dict(backend='eager')``.
        Refer to `PyTorch API Documentation <https://pytorch.org/docs/
        master/generated/torch.compile.html#torch.compile>`_ for more valid
        options.
    Examples:
        >>> from mmengine.runner import Runner
        >>> cfg = dict(
@ -1686,6 +1694,10 @@ class Runner:
                self._train_loop.iter,  # type: ignore
                self._train_loop.max_iters)  # type: ignore
        # Maybe compile the model according to options in self.cfg.compile
        # This must be called **AFTER** model has been wrapped.
        self._maybe_compile('train_step')
        model = self.train_loop.run()  # type: ignore
        self.call_hook('after_run')
        return model
@ -2288,3 +2300,28 @@ class Runner:
                         '\nRuntime environment:' + runtime_env_info + '\n' +
                         dash_line + '\n')
        self.logger.info(f'Config:\n{self.cfg.pretty_text}')
    def _maybe_compile(self, target: str) -> None:
        """Use `torch.compile` to optimize model/wrapped_model."""
        compile_cfg = self.cfg.get('compile', None)
        if compile_cfg is None:
            # no compile options given, won't compile
            return
        if isinstance(compile_cfg, bool):
            if not compile_cfg:
                # compile=False, compilation is disabled
                return
            # compile=True, use default configurations
            compile_cfg = dict()
        assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
            'PyTorch >= 2.0.0 is required to enable torch.compile')
        assert isinstance(compile_cfg, dict), (
            f'`compile` should be a dict or bool, got {type(compile_cfg)}')
        func = getattr(self.model, target)
        compiled_func = torch.compile(func, **compile_cfg)
        setattr(self.model, target, compiled_func)
        self.logger.info('Model has been "compiled". The first few iterations'
                         ' will be slow, please be patient.')
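The control flow of `_maybe_compile` can be exercised without PyTorch. The sketch below mirrors how it treats `None`, `bool`, and `dict` values, and shows that `setattr` on an instance replaces the method for that instance only. All names here (`resolve_compile_cfg`, `fake_compile`, `ToyModel`, `maybe_wrap`) are hypothetical, `fake_compile` merely stands in for `torch.compile`, and the sketch raises `TypeError` where the method above uses an `assert`.

```python
def resolve_compile_cfg(compile_cfg):
    """Return keyword options for the compiler, or None if disabled."""
    if compile_cfg is None:
        # No compile option given: do not compile.
        return None
    if isinstance(compile_cfg, bool):
        # True means "compile with defaults"; False means "disabled".
        return dict() if compile_cfg else None
    if not isinstance(compile_cfg, dict):
        raise TypeError(
            f'`compile` should be a dict or bool, got {type(compile_cfg)}')
    return compile_cfg


def fake_compile(func, **options):
    """Stand-in for torch.compile: wraps `func` and counts its calls."""
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)
    wrapper.calls = 0
    wrapper.options = options
    return wrapper


class ToyModel:
    def train_step(self, x):
        return x * 2


def maybe_wrap(model, target, compile_cfg):
    """Replace `model.<target>` with a wrapped version if enabled."""
    options = resolve_compile_cfg(compile_cfg)
    if options is None:
        return False
    func = getattr(model, target)  # bound method of this instance
    setattr(model, target, fake_compile(func, **options))
    return True
```

After `maybe_wrap(model, 'train_step', True)`, the wrapper lives in `vars(model)`, so other `ToyModel` instances keep the original method.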

@ -5,7 +5,7 @@ import os
import os.path as osp
import shutil
import tempfile
from unittest import TestCase
from unittest import TestCase, skipIf
import numpy as np
import torch
@ -1705,6 +1705,24 @@ class TestRunner(TestCase):
        with self.assertRaisesRegex(AssertionError, 'If you want to validate'):
            runner.train()
    @skipIf(
        not hasattr(torch, 'compile'),
        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
    def test_train_with_compile(self):
        # 1. test with simple configuration
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train_compile_simple'
        cfg.compile = True
        runner = Runner.from_cfg(cfg)
        runner.train()
        # 2. test with advanced configuration
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_train_compile_advanced'
        cfg.compile = dict(backend='inductor', mode='default')
        runner = Runner.from_cfg(cfg)
        runner.train()
    def test_val(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_val1'
@ -1757,6 +1775,24 @@ class TestRunner(TestCase):
        self.assertIn(predictions[0].dtype,
                      (torch.float16, torch.bfloat16))
    @skipIf(
        not hasattr(torch, 'compile'),
        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
    def test_val_with_compile(self):
        # 1. test with simple configuration
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_val_compile_simple'
        cfg.compile = True
        runner = Runner.from_cfg(cfg)
        runner.val()
        # 2. test with advanced configuration
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_val_compile_advanced'
        cfg.compile = dict(backend='inductor', mode='default')
        runner = Runner.from_cfg(cfg)
        runner.val()
    def test_test(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_test1'
@ -1811,6 +1847,24 @@ class TestRunner(TestCase):
        self.assertIn(predictions[0].dtype,
                      (torch.float16, torch.bfloat16))
    @skipIf(
        not hasattr(torch, 'compile'),
        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
    def test_test_with_compile(self):
        # 1. test with simple configuration
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_test_compile_simple'
        cfg.compile = True
        runner = Runner.from_cfg(cfg)
        runner.test()
        # 2. test with advanced configuration
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_test_compile_advanced'
        cfg.compile = dict(backend='inductor', mode='default')
        runner = Runner.from_cfg(cfg)
        runner.test()
    def test_register_hook(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.experiment_name = 'test_register_hook'
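The `skipIf` guard used by the three compile tests in this diff can be tried in isolation. In this sketch, `fake_torch` is a hypothetical `types.SimpleNamespace` standing in for a PyTorch build that lacks `compile`, so the guarded test is reported as skipped rather than failed. `TestCompileGuard` and `run_suite` are likewise illustrative names.

```python
import types
import unittest
from unittest import TestCase, skipIf

# Hypothetical stand-in for an old PyTorch: it has no `compile` attribute.
fake_torch = types.SimpleNamespace()

SKIP_REASON = 'torch.compile is not valid, please install PyTorch>=2.0.0'


class TestCompileGuard(TestCase):

    @skipIf(not hasattr(fake_torch, 'compile'), reason=SKIP_REASON)
    def test_needs_compile(self):
        # Never reached in this sketch because the guard skips the test.
        self.fail('this test should have been skipped')

    def test_always_runs(self):
        self.assertTrue(True)


def run_suite():
    """Run the two tests above and return the collected TestResult."""
    loader = unittest.defaultTestLoader
    suite = loader.loadTestsFromTestCase(TestCompileGuard)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Checking `hasattr(torch, 'compile')` instead of parsing a version string keeps the guard independent of how a particular build formats its version.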