[Feature] Add autocast wrapper (#307)

* add autocast wrapper

* fix docstring

* fix docstring

* fix compare version

* fix unit test

* fix incompatible arguments

* fix as comment

* fix unit test

* rename auto_cast to autocast
Mashiro 2022-06-22 19:49:20 +08:00 committed by GitHub
parent 216521a936
commit 312f264ecd
8 changed files with 227 additions and 15 deletions

View File

@@ -120,5 +120,6 @@ class AmpOptimWrapper(OptimWrapper):
Args:
model (nn.Module): The training model.
"""
with super().optim_context(model), torch.cuda.amp.autocast():
from mmengine.runner.amp import autocast
with super().optim_context(model), autocast():
yield
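
The hunk above routes ``AmpOptimWrapper.optim_context`` through the new unified wrapper instead of calling ``torch.cuda.amp.autocast`` directly, so the optimizer wrapper inherits the version handling added below. A minimal sketch of the resulting usage, with a placeholder model and optimizer (both hypothetical, not part of this commit):

```python
import torch
import torch.nn as nn
from mmengine.runner.amp import autocast

# Placeholder module and optimizer; any nn.Module behaves the same way.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# The unified context dispatches to torch.autocast (>= 1.10.0) or
# torch.cuda.amp.autocast (1.5.0 - 1.10.0) under the hood.
with autocast(enabled=torch.cuda.is_available()):
    loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```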

View File

@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .amp import autocast
from .base_loop import BaseLoop
from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
get_deprecated_model_names, get_external_models,
@@ -13,5 +14,5 @@ __all__ = [
'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
'TestLoop', 'Runner', 'find_latest_checkpoint'
'TestLoop', 'Runner', 'find_latest_checkpoint', 'autocast'
]

View File

@@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
import torch
from mmengine.utils import TORCH_VERSION, digit_version
@contextmanager
def autocast(enabled: bool = True, **kwargs):
"""A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``.
Pytorch 1.6.0 provide ``torch.cuda.amp.autocast`` for running in
mixed precision , and update it to ``torch.autocast`` in 1.10.0.
Both interfaces have different arguments, and ``torch.autocast``
support running with cpu additionally.
This function provides a unified interface by wrapping
``torch.autocast`` and ``torch.cuda.amp.autocast``, which resolves the
compatibility issues that ``torch.cuda.amp.autocast`` does not support
running mixed precision with cpu, and both contexts have different
arguments. We suggest users using this function in the code
to achieve maximized compatibility of different PyTorch versions.
Note:
``autocast`` requires pytorch version >= 1.5.0. If pytorch version
< 1.10.0 and cuda is not available, it will raise an error with
``enabled=True``, since ``torch.cuda.amp.autocast`` only supports
cuda mode.
Examples:
>>> # case 1: 1.5.0 <= PyTorch version < 1.10.0
>>> with autocast():
>>> # run in mixed precision context
>>> pass
>>> with autocast(device_type='cpu'):
>>> # raise error, torch.cuda.amp.autocast only supports cuda mode.
>>> pass
>>> # case 2: PyTorch version >= 1.10.0
>>> with autocast():
>>> # default cuda mixed precision context
>>> pass
>>> with autocast(device_type='cpu'):
>>> # cpu mixed precision context
>>> pass
>>> with autocast(
>>> device_type='cuda', enabled=True, cache_enabled=True):
>>> # mixed precision context with more specific arguments.
>>> pass
Args:
enabled (bool): Whether autocasting should be enabled in the region.
Defaults to True.
kwargs (dict): Arguments of ``torch.autocast`` except ``enabled``.
"""
# If `enabled` is False, autocast is a no-op context and all
# calculations are performed in fp32.
assert digit_version(TORCH_VERSION) >= digit_version('1.5.0'), (
'The minimum pytorch version requirement of mmengine is 1.5.0, but '
f'got {TORCH_VERSION}')
if (digit_version('1.5.0') <= digit_version(TORCH_VERSION) <
digit_version('1.10.0')):
# If pytorch version is between 1.5.0 and 1.10.0, the default value of
# dtype for `torch.cuda.amp.autocast` is torch.float16.
assert not kwargs, (
f'autocast under pytorch {TORCH_VERSION} only accepts the `enabled` '
'argument.')
if torch.cuda.is_available():
with torch.cuda.amp.autocast(enabled=enabled):
yield
else:
if not enabled:
yield
else:
raise RuntimeError(
'If pytorch version is between 1.5.0 and 1.10, '
'`autocast` is only available in gpu mode')
elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
if torch.cuda.is_available():
kwargs.setdefault('device_type', 'cuda')
else:
kwargs.setdefault('device_type', 'cpu')
with torch.autocast(enabled=enabled, **kwargs):
yield
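
In the 1.5.0 - 1.10.0 branch the wrapper forwards only ``enabled`` to ``torch.cuda.amp.autocast``; from 1.10.0 on it fills in a ``device_type`` default and forwards the remaining kwargs to ``torch.autocast``. For training, the context is typically paired with a gradient scaler. A hedged sketch, assuming a CUDA device and PyTorch >= 1.6 (the model and loop are placeholders):

```python
import torch
import torch.nn as nn
from mmengine.runner import autocast

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # loss scaling for fp16 gradients

for _ in range(2):
    inputs = torch.randn(8, 16, device='cuda')
    with autocast():  # forward pass under the unified mixed precision context
        loss = model(inputs).float().sum()
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```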

View File

@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
from mmengine.evaluator import Evaluator
from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .amp import autocast
from .base_loop import BaseLoop
@@ -269,10 +270,15 @@ class ValLoop(BaseLoop):
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 validation. Defaults to
False.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List]) -> None:
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False) -> None:
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
@@ -288,6 +294,7 @@ class ValLoop(BaseLoop):
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.')
self.fp16 = fp16
def run(self):
"""Launch validation."""
@@ -313,7 +320,8 @@
self.runner.call_hook(
'before_val_iter', batch_idx=idx, data_batch=data_batch)
# outputs should be sequence of BaseDataElement
outputs = self.runner.model.val_step(data_batch)
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)
self.evaluator.process(data_batch, outputs)
self.runner.call_hook(
'after_val_iter',
@@ -331,10 +339,15 @@ class TestLoop(BaseLoop):
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (Evaluator or dict or list): Used for computing metrics.
fp16 (bool): Whether to enable fp16 testing. Defaults to
False.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List]):
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[Evaluator, Dict, List],
fp16: bool = False):
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
@@ -350,6 +363,7 @@ class TestLoop(BaseLoop):
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
'metainfo. ``dataset_meta`` in evaluator, metric and '
'visualizer will be None.')
self.fp16 = fp16
def run(self) -> None:
"""Launch test."""
@@ -374,7 +388,8 @@
self.runner.call_hook(
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be sequence of BaseDataElement
predictions = self.runner.model.test_step(data_batch)
with autocast(enabled=self.fp16):
predictions = self.runner.model.test_step(data_batch)
self.evaluator.process(data_batch, predictions)
self.runner.call_hook(
'after_test_iter',
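
Both loops apply the same pattern: the model's ``val_step``/``test_step`` call is wrapped in ``autocast(enabled=self.fp16)``, so with ``fp16=False`` the context is a no-op and evaluation stays in fp32. A minimal sketch of that pattern in isolation (the conv layer and batch are placeholders, not mmengine objects):

```python
import torch
import torch.nn as nn
from mmengine.runner import autocast

model = nn.Conv2d(3, 8, 3)
data_batch = torch.randn(1, 3, 16, 16)
fp16 = False  # mirrors the loops' `self.fp16` attribute

with autocast(enabled=fp16):
    outputs = model(data_batch)
assert outputs.dtype == torch.float32  # stays fp32 when fp16 is disabled
```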

View File

@@ -97,13 +97,15 @@ class Runner:
val_cfg (dict, optional): A dict to build a validation loop. If it does
not provide "type" key, :class:`ValLoop` will be used by default.
If ``val_cfg`` specified, :attr:`val_dataloader` should also be
specified. Defaults to None.
See :meth:`build_val_loop` for more details.
specified. If ``ValLoop`` is built with ``fp16=True``,
``runner.val()`` will be performed under fp16 precision.
Defaults to None. See :meth:`build_val_loop` for more details.
test_cfg (dict, optional): A dict to build a test loop. If it does
not provide "type" key, :class:`TestLoop` will be used by default.
If ``test_cfg`` specified, :attr:`test_dataloader` should also be
specified. Defaults to None.
See :meth:`build_test_loop` for more details.
specified. If ``TestLoop`` is built with ``fp16=True``,
``runner.test()`` will be performed under fp16 precision.
Defaults to None. See :meth:`build_test_loop` for more details.
auto_scale_lr (dict, Optional): Config to scale the learning rate
automatically. It includes ``base_batch_size`` and ``enable``.
``base_batch_size`` is the batch size that the optimizer lr is
@@ -1424,6 +1426,7 @@ class Runner:
evaluator=self._val_evaluator))
else:
loop = ValLoop(
**loop_cfg,
runner=self,
dataloader=self._val_dataloader,
evaluator=self._val_evaluator) # type: ignore
@@ -1465,6 +1468,7 @@ class Runner:
evaluator=self._test_evaluator))
else:
loop = TestLoop(
**loop_cfg,
runner=self,
dataloader=self._test_dataloader,
evaluator=self._test_evaluator) # type: ignore
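
With the ``**loop_cfg`` forwarding above, any extra key in ``val_cfg``/``test_cfg`` reaches the loop constructor, which is how ``fp16`` is enabled from a config. A hedged sketch showing only the relevant keys (a real config also needs model, dataloaders, evaluators, and so on):

```python
# Extra keys in the loop configs are forwarded to ValLoop/TestLoop.
val_cfg = dict(fp16=True)   # -> ValLoop(..., fp16=True); val runs in fp16
test_cfg = dict(fp16=True)  # -> TestLoop(..., fp16=True); test runs in fp16
```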

View File

@@ -24,7 +24,7 @@ class TestAveragedModel(TestCase):
averaged_params = [
torch.zeros_like(param) for param in model.parameters()
]
n_updates = 10
n_updates = 2
for i in range(n_updates):
for p, p_avg in zip(model.parameters(), averaged_params):
p.detach().add_(torch.randn_like(p))

View File

@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest
import torch
import torch.nn as nn
from mmengine.runner import autocast
from mmengine.utils import TORCH_VERSION, digit_version
class TestAmp(unittest.TestCase):
def test_autocast(self):
if not torch.cuda.is_available():
if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
# `torch.cuda.amp.autocast` only supports gpu mode. If cuda is not
# available, enabling autocast should raise an error, and the
# context should not accept any extra arguments.
with self.assertRaisesRegex(RuntimeError,
'If pytorch version is '):
with autocast():
pass
with autocast(enabled=False):
layer = nn.Conv2d(1, 1, 1)
res = layer(torch.randn(1, 1, 1, 1))
self.assertEqual(res.dtype, torch.float32)
else:
with autocast(device_type='cpu'):
# torch.autocast supports cpu mode.
layer = nn.Conv2d(1, 1, 1)
res = layer(torch.randn(1, 1, 1, 1))
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
with autocast(enabled=False):
res = layer(torch.randn(1, 1, 1, 1))
self.assertEqual(res.dtype, torch.float32)
else:
if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
devices = ['cuda']
else:
devices = ['cpu', 'cuda']
for device in devices:
with autocast():
# torch.autocast supports cpu and cuda mode.
layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
with autocast(enabled=False):
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32)
# Test with autocast disabled (pure fp32).
with autocast(enabled=False):
layer = nn.Conv2d(1, 1, 1).to(device)
res = layer(torch.randn(1, 1, 1, 1).to(device))
self.assertEqual(res.dtype, torch.float32)

View File

@@ -31,7 +31,7 @@ from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.loops import _InfiniteDataloaderIterator
from mmengine.runner.priority import Priority, get_priority
from mmengine.utils import is_list_of
from mmengine.utils import TORCH_VERSION, digit_version, is_list_of
from mmengine.visualization import Visualizer
@@ -55,7 +55,6 @@ class ToyModel(BaseModel):
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
outputs = dict(log_vars=dict(a=1, b=0.5))
return outputs
@@ -1273,7 +1272,31 @@ class TestRunner(TestCase):
cfg.pop('test_cfg')
cfg.pop('test_evaluator')
runner = Runner.from_cfg(cfg)
# Test default fp32 `autocast` context.
predictions = []
def get_outputs_callback(module, inputs, outputs):
predictions.append(outputs)
runner.model.register_forward_hook(get_outputs_callback)
runner.val()
self.assertEqual(predictions[0].dtype, torch.float32)
predictions.clear()
# Test fp16 `autocast` context.
cfg.experiment_name = 'test_val3'
cfg.val_cfg = dict(fp16=True)
runner = Runner.from_cfg(cfg)
runner.model.register_forward_hook(get_outputs_callback)
if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
and not torch.cuda.is_available()):
with self.assertRaisesRegex(RuntimeError, 'If pytorch version'):
runner.val()
else:
runner.val()
self.assertIn(predictions[0].dtype,
(torch.float16, torch.bfloat16))
def test_test(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -1303,7 +1326,31 @@
cfg.pop('val_cfg')
cfg.pop('val_evaluator')
runner = Runner.from_cfg(cfg)
# Test default fp32 `autocast` context.
predictions = []
def get_outputs_callback(module, inputs, outputs):
predictions.append(outputs)
runner.model.register_forward_hook(get_outputs_callback)
runner.test()
self.assertEqual(predictions[0].dtype, torch.float32)
predictions.clear()
# Test fp16 `autocast` context.
cfg.experiment_name = 'test_test3'
cfg.test_cfg = dict(fp16=True)
runner = Runner.from_cfg(cfg)
runner.model.register_forward_hook(get_outputs_callback)
if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
and not torch.cuda.is_available()):
with self.assertRaisesRegex(RuntimeError, 'If pytorch version'):
runner.test()
else:
runner.test()
self.assertIn(predictions[0].dtype,
(torch.float16, torch.bfloat16))
def test_register_hook(self):
cfg = copy.deepcopy(self.epoch_based_cfg)