[Feature] Support torch.compile since PyTorch2.0 (#976)

* enable compile configurations to support torch.compile in Runner
* enable compilation in train, val and test
* fix as comments
* add docstring to illustrate usage
* minor refine error message
* add unittests
* fix ut skip
* add logging message to inform users
* compile `train_step`, `val_step`, `test_step` instead
* fix as comments
* revert to compile `train_step` only due to pt2 issue
* add documentation about torch.compile

parent 6ea23a2f71
commit 0d25625ba2

@@ -84,3 +84,32 @@ runner.train()
 ```{warning}
 Up to PyTorch 1.13, `torch.bfloat16` performs poorly on `Convolution` unless you manually set the environment variable `TORCH_CUDNN_V8_API_ENABLED=1`. More context in this [PyTorch issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767)
 ```
+
+## Model Compilation
+
+PyTorch introduced [torch.compile](https://pytorch.org/docs/2.0/dynamo/get-started.html) in its 2.0 release. It compiles your model to speed up training and validation. Since MMEngine v0.7.0, you can enable this feature by passing `Runner` an extra `cfg` dict with a `compile` key:
+
+```python
+runner = Runner(
+    model=ResNet18(),
+    # ... other arguments you want
+    cfg=dict(compile=True)
+)
+```
+
+For advanced usage, you can also change the compile options described in the [torch.compile API Documentation](https://pytorch.org/docs/2.0/generated/torch.compile.html#torch-compile). For example:
+
+```python
+compile_options = dict(backend='inductor', mode='max-autotune')
+runner = Runner(
+    model=ResNet18(),
+    # ... other arguments you want
+    cfg=dict(compile=compile_options)
+)
+```
+
+This feature is only available for PyTorch >= 2.0.0.
+
+```{warning}
+`torch.compile` is still under development by the PyTorch team, and some models may fail to compile. If you encounter errors during compilation, refer to the [PyTorch Dynamo FAQ](https://pytorch.org/docs/2.0/dynamo/faq.html) for quick fixes, or to [TorchDynamo Troubleshooting](https://pytorch.org/docs/2.0/dynamo/troubleshooting.html) to post an issue to PyTorch.
+```
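
The documentation above accepts two forms for the `compile` key: a bool, or a dict of options. One way to picture this is as a normalization step that turns either form into keyword arguments for the compiler. The following is a minimal sketch in plain Python under that assumption; `resolve_compile_options` is a hypothetical helper name used only for illustration and is not part of MMEngine or PyTorch.

```python
from typing import Optional, Union


def resolve_compile_options(value: Union[bool, dict, None]) -> Optional[dict]:
    """Map a `compile` setting to keyword arguments, or None if disabled.

    Hypothetical helper for illustration only.
    """
    if value is None or value is False:
        return None  # compilation was not requested
    if value is True:
        return {}  # compile with the compiler's default options
    if isinstance(value, dict):
        return dict(value)  # copy so the caller's config is not mutated
    raise TypeError(f'`compile` should be a dict or bool, got {type(value)}')


# The two documented forms: a plain switch and a dict of options.
default_options = resolve_compile_options(True)
tuned_options = resolve_compile_options(
    dict(backend='inductor', mode='max-autotune'))
```

The sketch raises `TypeError` for unsupported values, which is a choice made here for clarity; the commit itself uses an `assert` for the same check.
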
@@ -85,3 +85,32 @@ runner.train()
 ```{warning}
 Up to PyTorch 1.13, using `torch.bfloat16` directly in `Convolution` is slow; you must manually set the environment variable `TORCH_CUDNN_V8_API_ENABLED=1` to enable the cuDNN BF16 Convolution. See this [PyTorch Issue](https://github.com/pytorch/pytorch/issues/57707#issuecomment-1166656767) for the related discussion.
 ```
+
+## Model Compilation
+
+PyTorch 2.0 introduced the new [torch.compile](https://pytorch.org/docs/2.0/dynamo/get-started.html) feature, which speeds up training and validation by compiling the model. MMEngine supports this feature since v0.7.0. To enable model compilation, pass a dict containing the `compile` key to the `cfg` argument of `Runner`:
+
+```python
+runner = Runner(
+    model=ResNet18(),
+    # ... your other Runner arguments
+    cfg=dict(compile=True)
+)
+```
+
+You can also pass more compile options. For all available options, see the [torch.compile API documentation](https://pytorch.org/docs/2.0/generated/torch.compile.html#torch-compile):
+
+```python
+compile_options = dict(backend='inductor', mode='max-autotune')
+runner = Runner(
+    model=ResNet18(),
+    # ... your other Runner arguments
+    cfg=dict(compile=compile_options)
+)
+```
+
+This feature is only available when PyTorch >= 2.0.0 is installed.
+
+```{warning}
+`torch.compile` is still under active development by the PyTorch team, and some models may fail to compile. If you run into such problems, consult the [PyTorch Dynamo FAQ](https://pytorch.org/docs/2.0/dynamo/faq.html) for common issues, or refer to [TorchDynamo Troubleshooting](https://pytorch.org/docs/2.0/dynamo/troubleshooting.html) to file an issue with PyTorch.
+```
@@ -180,6 +180,14 @@ class Runner:
         cfg (dict or Configdict or :obj:`Config`, optional): Full config.
             Defaults to None.

+    Note:
+        Since PyTorch 2.0.0, you can enable ``torch.compile`` by passing in
+        `cfg.compile = True`. If you want to control compile options, you
+        can pass a dict, e.g. ``cfg.compile = dict(backend='eager')``.
+        Refer to `PyTorch API Documentation <https://pytorch.org/docs/
+        master/generated/torch.compile.html#torch.compile>`_ for more valid
+        options.
+
     Examples:
         >>> from mmengine.runner import Runner
         >>> cfg = dict(
@@ -1686,6 +1694,10 @@ class Runner:
             self._train_loop.iter,  # type: ignore
             self._train_loop.max_iters)  # type: ignore

+        # Maybe compile the model according to options in self.cfg.compile
+        # This must be called **AFTER** model has been wrapped.
+        self._maybe_compile('train_step')
+
         model = self.train_loop.run()  # type: ignore
         self.call_hook('after_run')
         return model
@@ -2288,3 +2300,28 @@ class Runner:
                          '\nRuntime environment:' + runtime_env_info + '\n' +
                          dash_line + '\n')
         self.logger.info(f'Config:\n{self.cfg.pretty_text}')
+
+    def _maybe_compile(self, target: str) -> None:
+        """Use `torch.compile` to optimize model/wrapped_model."""
+        compile_cfg = self.cfg.get('compile', None)
+        if compile_cfg is None:
+            # no compile options given, won't compile
+            return
+
+        if isinstance(compile_cfg, bool):
+            if not compile_cfg:
+                # compile=False, compilation is disabled
+                return
+            # compile=True, use default configurations
+            compile_cfg = dict()
+
+        assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
+            'PyTorch >= 2.0.0 is required to enable torch.compile')
+        assert isinstance(compile_cfg, dict), (
+            f'`compile` should be a dict or bool, got {type(compile_cfg)}')
+
+        func = getattr(self.model, target)
+        compiled_func = torch.compile(func, **compile_cfg)
+        setattr(self.model, target, compiled_func)
+        self.logger.info('Model has been "compiled". The first few iterations'
+                         ' will be slow, please be patient.')
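
The last three statements of `_maybe_compile` fetch a bound method by name, wrap it, and store the wrapper back on the model instance. A minimal sketch of that pattern follows, written so it runs without PyTorch: `fake_compile` is a hypothetical stand-in for `torch.compile` that only records its options and counts calls, and `ToyModel` and `compile_method` are likewise illustrative names, not MMEngine APIs.

```python
import functools


class ToyModel:
    """Illustrative model with a `train_step` method to be replaced."""

    def train_step(self, x):
        return x * 2


def fake_compile(func, **options):
    """Hypothetical stand-in for torch.compile: wraps `func` unchanged."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1  # count invocations for demonstration
        return func(*args, **kwargs)

    wrapper.calls = 0
    wrapper.options = options  # remember the options that were passed
    return wrapper


def compile_method(model, target, compiler, **options):
    """Replace `model.<target>` with a compiled version of itself."""
    func = getattr(model, target)  # bound method looked up by name
    # The instance attribute set here shadows the method on the class.
    setattr(model, target, compiler(func, **options))


model = ToyModel()
compile_method(model, 'train_step', fake_compile, backend='inductor')
result = model.train_step(3)
```

Because `setattr` writes to the instance, only this one object is affected; other `ToyModel` instances keep the original method.
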
@@ -5,7 +5,7 @@ import os
 import os.path as osp
 import shutil
 import tempfile
-from unittest import TestCase
+from unittest import TestCase, skipIf

 import numpy as np
 import torch
@@ -1705,6 +1705,24 @@ class TestRunner(TestCase):
         with self.assertRaisesRegex(AssertionError, 'If you want to validate'):
             runner.train()

+    @skipIf(
+        not hasattr(torch, 'compile'),
+        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
+    def test_train_with_compile(self):
+        # 1. test with simple configuration
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_train_compile_simple'
+        cfg.compile = True
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+
+        # 2. test with advanced configuration
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_train_compile_advanced'
+        cfg.compile = dict(backend='inductor', mode='default')
+        runner = Runner.from_cfg(cfg)
+        runner.train()
+
     def test_val(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_val1'
@@ -1757,6 +1775,24 @@ class TestRunner(TestCase):
         self.assertIn(predictions[0].dtype,
                       (torch.float16, torch.bfloat16))

+    @skipIf(
+        not hasattr(torch, 'compile'),
+        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
+    def test_val_with_compile(self):
+        # 1. test with simple configuration
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_val_compile_simple'
+        cfg.compile = True
+        runner = Runner.from_cfg(cfg)
+        runner.val()
+
+        # 2. test with advanced configuration
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_val_compile_advanced'
+        cfg.compile = dict(backend='inductor', mode='default')
+        runner = Runner.from_cfg(cfg)
+        runner.val()
+
     def test_test(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_test1'
@@ -1811,6 +1847,24 @@ class TestRunner(TestCase):
         self.assertIn(predictions[0].dtype,
                       (torch.float16, torch.bfloat16))

+    @skipIf(
+        not hasattr(torch, 'compile'),
+        reason='torch.compile is not valid, please install PyTorch>=2.0.0')
+    def test_test_with_compile(self):
+        # 1. test with simple configuration
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_test_compile_simple'
+        cfg.compile = True
+        runner = Runner.from_cfg(cfg)
+        runner.test()
+
+        # 2. test with advanced configuration
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_test_compile_advanced'
+        cfg.compile = dict(backend='inductor', mode='default')
+        runner = Runner.from_cfg(cfg)
+        runner.test()
+
     def test_register_hook(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_register_hook'
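
The new tests are guarded by `skipIf(not hasattr(torch, 'compile'), ...)`, so they are skipped rather than failed on older PyTorch builds. The sketch below shows the same feature-detection pattern in isolation. It assumes nothing about PyTorch being installed: `old_torch` is a hypothetical stand-in namespace that lacks a `compile` attribute, and `CompileGuardDemo` is an illustrative test case, not part of the MMEngine test suite.

```python
import types
import unittest

# Hypothetical stand-in for an older `torch` module without `compile`.
old_torch = types.SimpleNamespace()


class CompileGuardDemo(unittest.TestCase):

    @unittest.skipIf(
        not hasattr(old_torch, 'compile'),
        reason='compile is not available in this stand-in module')
    def test_needs_compile(self):
        # Never reached here, because the guard above skips the test.
        self.fail('this test should have been skipped')

    def test_always_runs(self):
        self.assertEqual(1 + 1, 2)


# Run the two tests programmatically and collect the outcome.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(CompileGuardDemo)
result = unittest.TestResult()
suite.run(result)
```

Checking for the attribute, rather than comparing version strings, keeps the guard independent of how a given build reports its version.
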