[Feature] Add autocast wrapper (#307)
* add autocast wrapper
* fix docstring
* fix docstring
* fix compare version
* fix unit test
* fix incompatible arguments
* fix as comment
* fix unit test
* rename auto_cast to autocast
parent 216521a936
commit 312f264ecd
@@ -120,5 +120,6 @@ class AmpOptimWrapper(OptimWrapper):
         Args:
             model (nn.Module): The training model.
         """
-        with super().optim_context(model), torch.cuda.amp.autocast():
+        from mmengine.runner.amp import autocast
+        with super().optim_context(model), autocast():
             yield
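For context, a minimal sketch of how the patched context is consumed in a training step. The setup below is an assumption for illustration (a toy model and data on a CUDA build, using mmengine's OptimWrapper API), not part of this diff:

import torch
import torch.nn as nn
from mmengine.optim import AmpOptimWrapper

# Toy model and data, for illustration only; AmpOptimWrapper needs CUDA.
model = nn.Linear(2, 2).cuda()
inputs = torch.randn(4, 2).cuda()

optim_wrapper = AmpOptimWrapper(
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01))
with optim_wrapper.optim_context(model):
    loss = model(inputs).sum()  # forward pass runs under the unified autocast
optim_wrapper.update_params(loss)  # scaled backward and optimizer step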
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .amp import autocast
 from .base_loop import BaseLoop
 from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
                          get_deprecated_model_names, get_external_models,
@@ -13,5 +14,5 @@ __all__ = [
     'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
     'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
     'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
-    'TestLoop', 'Runner', 'find_latest_checkpoint'
+    'TestLoop', 'Runner', 'find_latest_checkpoint', 'autocast'
 ]
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from contextlib import contextmanager
+
+import torch
+
+from mmengine.utils import TORCH_VERSION, digit_version
+
+
+@contextmanager
+def autocast(enabled: bool = True, **kwargs):
+    """A wrapper of ``torch.autocast`` and ``torch.cuda.amp.autocast``.
+
+    PyTorch 1.6.0 provides ``torch.cuda.amp.autocast`` for running in
+    mixed precision, and PyTorch 1.10.0 updates it to ``torch.autocast``.
+    The two interfaces accept different arguments, and ``torch.autocast``
+    additionally supports running on CPU.
+
+    This function provides a unified interface by wrapping
+    ``torch.autocast`` and ``torch.cuda.amp.autocast``, which resolves
+    two compatibility issues: ``torch.cuda.amp.autocast`` does not
+    support mixed precision on CPU, and the two contexts accept
+    different arguments. We suggest users use this function in their
+    code to maximize compatibility across PyTorch versions.
+
+    Note:
+        ``autocast`` requires PyTorch version >= 1.5.0. If the PyTorch
+        version is < 1.10.0 and CUDA is not available, it will raise an
+        error with ``enabled=True``, since ``torch.cuda.amp.autocast``
+        only supports CUDA mode.
+
+    Examples:
+        >>> # case 1: 1.10.0 > PyTorch version >= 1.5.0
+        >>> with autocast():
+        >>>     # run in mixed precision context
+        >>>     pass
+        >>> with autocast(device_type='cpu'):
+        >>>     # raise error, torch.cuda.amp.autocast only supports cuda mode.
+        >>>     pass
+        >>> # case 2: PyTorch version >= 1.10.0
+        >>> with autocast():
+        >>>     # default cuda mixed precision context
+        >>>     pass
+        >>> with autocast(device_type='cpu'):
+        >>>     # cpu mixed precision context
+        >>>     pass
+        >>> with autocast(
+        >>>         device_type='cuda', enabled=True, cache_enabled=True):
+        >>>     # enable precision context with more specific arguments.
+        >>>     pass
+
+    Args:
+        enabled (bool): Whether autocasting should be enabled in the region.
+            Defaults to True.
+        kwargs (dict): Arguments of ``torch.autocast`` except for ``enabled``.
+    """
+    assert digit_version(TORCH_VERSION) >= digit_version('1.5.0'), (
+        'The minimum pytorch version requirement of mmengine is 1.5.0, but '
+        f'got {TORCH_VERSION}')
+
+    if (digit_version('1.5.0') <= digit_version(TORCH_VERSION) <
+            digit_version('1.10.0')):
+        # If pytorch version is between 1.5.0 and 1.10.0, the default value of
+        # dtype for `torch.cuda.amp.autocast` is torch.float16.
+        assert not kwargs, (
+            f'autocast under pytorch {TORCH_VERSION} only accepts the '
+            '`enabled` argument.')
+        if torch.cuda.is_available():
+            with torch.cuda.amp.autocast(enabled=enabled):
+                yield
+        else:
+            if not enabled:
+                # If `enabled` is False, enter an empty context and perform
+                # all calculations under fp32.
+                yield
+            else:
+                raise RuntimeError(
+                    'If the pytorch version is between 1.5.0 and 1.10, '
+                    '`autocast` is only available in gpu mode')
+
+    elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
+        if torch.cuda.is_available():
+            kwargs.setdefault('device_type', 'cuda')
+        else:
+            kwargs.setdefault('device_type', 'cpu')
+
+        with torch.autocast(enabled=enabled, **kwargs):
+            yield
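A quick usage sketch of the new wrapper, mirroring the unit test added later in this diff; nothing beyond the wrapper itself is assumed:

import torch
import torch.nn as nn
from mmengine.runner import autocast

layer = nn.Conv2d(1, 1, 1)
x = torch.randn(1, 1, 1, 1)

# With autocasting disabled, computation stays in fp32 on any build.
with autocast(enabled=False):
    assert layer(x).dtype == torch.float32

# With autocasting enabled, CUDA is required on PyTorch < 1.10.0;
# newer versions fall back to torch.autocast on CPU.
if torch.cuda.is_available():
    layer, x = layer.cuda(), x.cuda()
    with autocast():
        assert layer(x).dtype in (torch.float16, torch.bfloat16)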
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
 from mmengine.evaluator import Evaluator
 from mmengine.registry import LOOPS
 from mmengine.utils import is_list_of
+from .amp import autocast
 from .base_loop import BaseLoop
 
 
@@ -269,10 +270,15 @@ class ValLoop(BaseLoop):
         dataloader (Dataloader or dict): A dataloader object or a dict to
             build a dataloader.
         evaluator (Evaluator or dict or list): Used for computing metrics.
+        fp16 (bool): Whether to enable fp16 validation. Defaults to
+            False.
     """
 
-    def __init__(self, runner, dataloader: Union[DataLoader, Dict],
-                 evaluator: Union[Evaluator, Dict, List]) -> None:
+    def __init__(self,
+                 runner,
+                 dataloader: Union[DataLoader, Dict],
+                 evaluator: Union[Evaluator, Dict, List],
+                 fp16: bool = False) -> None:
         super().__init__(runner, dataloader)
 
         if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
@@ -288,6 +294,7 @@ class ValLoop(BaseLoop):
             f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
             'metainfo. ``dataset_meta`` in evaluator, metric and '
             'visualizer will be None.')
+        self.fp16 = fp16
 
     def run(self):
         """Launch validation."""
@@ -313,7 +320,8 @@ class ValLoop(BaseLoop):
         self.runner.call_hook(
             'before_val_iter', batch_idx=idx, data_batch=data_batch)
         # outputs should be sequence of BaseDataElement
-        outputs = self.runner.model.val_step(data_batch)
+        with autocast(enabled=self.fp16):
+            outputs = self.runner.model.val_step(data_batch)
         self.evaluator.process(data_batch, outputs)
         self.runner.call_hook(
             'after_val_iter',
@@ -331,10 +339,15 @@ class TestLoop(BaseLoop):
         dataloader (Dataloader or dict): A dataloader object or a dict to
             build a dataloader.
         evaluator (Evaluator or dict or list): Used for computing metrics.
+        fp16 (bool): Whether to enable fp16 testing. Defaults to
+            False.
     """
 
-    def __init__(self, runner, dataloader: Union[DataLoader, Dict],
-                 evaluator: Union[Evaluator, Dict, List]):
+    def __init__(self,
+                 runner,
+                 dataloader: Union[DataLoader, Dict],
+                 evaluator: Union[Evaluator, Dict, List],
+                 fp16: bool = False):
         super().__init__(runner, dataloader)
 
         if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
@@ -350,6 +363,7 @@ class TestLoop(BaseLoop):
             f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
             'metainfo. ``dataset_meta`` in evaluator, metric and '
             'visualizer will be None.')
+        self.fp16 = fp16
 
     def run(self) -> None:
         """Launch test."""
@@ -374,7 +388,8 @@ class TestLoop(BaseLoop):
         self.runner.call_hook(
             'before_test_iter', batch_idx=idx, data_batch=data_batch)
         # predictions should be sequence of BaseDataElement
-        predictions = self.runner.model.test_step(data_batch)
+        with autocast(enabled=self.fp16):
+            predictions = self.runner.model.test_step(data_batch)
         self.evaluator.process(data_batch, predictions)
         self.runner.call_hook(
             'after_test_iter',
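With this change the two loops accept ``fp16`` directly, so mixed-precision evaluation can also be requested when a loop is built by hand. A sketch based on the new signature (the runner, dataloader, and evaluator below are placeholders for an existing setup):

from mmengine.runner.loops import TestLoop

loop = TestLoop(
    runner=runner,                # an existing Runner instance
    dataloader=test_dataloader,   # or a dataloader config dict
    evaluator=test_evaluator,     # or an evaluator config dict
    fp16=True)
loop.run()  # test_step now executes under autocast(enabled=True)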
@@ -97,13 +97,15 @@ class Runner:
         val_cfg (dict, optional): A dict to build a validation loop. If it does
             not provide "type" key, :class:`ValLoop` will be used by default.
             If ``val_cfg`` specified, :attr:`val_dataloader` should also be
-            specified. Defaults to None.
-            See :meth:`build_val_loop` for more details.
+            specified. If ``ValLoop`` is built with ``fp16=True``,
+            ``runner.val()`` will be performed under fp16 precision.
+            Defaults to None. See :meth:`build_val_loop` for more details.
         test_cfg (dict, optional): A dict to build a test loop. If it does
             not provide "type" key, :class:`TestLoop` will be used by default.
             If ``test_cfg`` specified, :attr:`test_dataloader` should also be
-            specified. Defaults to None.
-            See :meth:`build_test_loop` for more details.
+            specified. If ``TestLoop`` is built with ``fp16=True``,
+            ``runner.test()`` will be performed under fp16 precision.
+            Defaults to None. See :meth:`build_test_loop` for more details.
         auto_scale_lr (dict, Optional): Config to scale the learning rate
             automatically. It includes ``base_batch_size`` and ``enable``.
             ``base_batch_size`` is the batch size that the optimizer lr is
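In config-driven use, the same switch is just a key in ``val_cfg`` or ``test_cfg``. A sketch matching the unit tests later in this diff (all other Runner config keys are assumed from the usual setup):

cfg.val_cfg = dict(fp16=True)    # runner.val() runs under fp16
cfg.test_cfg = dict(fp16=True)   # runner.test() runs under fp16
runner = Runner.from_cfg(cfg)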
@@ -1424,6 +1426,7 @@ class Runner:
                     evaluator=self._val_evaluator))
         else:
             loop = ValLoop(
+                **loop_cfg,
                 runner=self,
                 dataloader=self._val_dataloader,
                 evaluator=self._val_evaluator)  # type: ignore
@@ -1465,6 +1468,7 @@ class Runner:
                     evaluator=self._test_evaluator))
         else:
             loop = TestLoop(
+                **loop_cfg,
                 runner=self,
                 dataloader=self._test_dataloader,
                 evaluator=self._test_evaluator)  # type: ignore
@@ -24,7 +24,7 @@ class TestAveragedModel(TestCase):
         averaged_params = [
             torch.zeros_like(param) for param in model.parameters()
         ]
-        n_updates = 10
+        n_updates = 2
         for i in range(n_updates):
             for p, p_avg in zip(model.parameters(), averaged_params):
                 p.detach().add_(torch.randn_like(p))
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import torch
+import torch.nn as nn
+
+from mmengine.runner import autocast
+from mmengine.utils import TORCH_VERSION, digit_version
+
+
+class TestAmp(unittest.TestCase):
+
+    def test_autocast(self):
+        if not torch.cuda.is_available():
+            if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
+                # `torch.cuda.amp.autocast` is only supported in gpu mode;
+                # if cuda is not available, it should return an empty
+                # context and should not accept any arguments.
+                with self.assertRaisesRegex(RuntimeError,
+                                            'If the pytorch version is'):
+                    with autocast():
+                        pass
+
+                with autocast(enabled=False):
+                    layer = nn.Conv2d(1, 1, 1)
+                    res = layer(torch.randn(1, 1, 1, 1))
+                    self.assertEqual(res.dtype, torch.float32)
+
+            else:
+                with autocast(device_type='cpu'):
+                    # torch.autocast supports cpu mode.
+                    layer = nn.Conv2d(1, 1, 1)
+                    res = layer(torch.randn(1, 1, 1, 1))
+                    self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                    with autocast(enabled=False):
+                        res = layer(torch.randn(1, 1, 1, 1))
+                        self.assertEqual(res.dtype, torch.float32)
+
+        else:
+            if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
+                devices = ['cuda']
+            else:
+                devices = ['cpu', 'cuda']
+            for device in devices:
+                with autocast():
+                    # torch.autocast supports cpu and cuda mode.
+                    layer = nn.Conv2d(1, 1, 1).to(device)
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                    with autocast(enabled=False):
+                        res = layer(torch.randn(1, 1, 1, 1).to(device))
+                        self.assertEqual(res.dtype, torch.float32)
+                # Test with fp32 enabled
+                with autocast(enabled=False):
+                    layer = nn.Conv2d(1, 1, 1).to(device)
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertEqual(res.dtype, torch.float32)
@@ -31,7 +31,7 @@ from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
                              Runner, TestLoop, ValLoop)
 from mmengine.runner.loops import _InfiniteDataloaderIterator
 from mmengine.runner.priority import Priority, get_priority
-from mmengine.utils import is_list_of
+from mmengine.utils import TORCH_VERSION, digit_version, is_list_of
 from mmengine.visualization import Visualizer
 
 
@@ -55,7 +55,6 @@ class ToyModel(BaseModel):
             outputs = dict(loss=loss)
             return outputs
         elif mode == 'predict':
-            outputs = dict(log_vars=dict(a=1, b=0.5))
             return outputs
 
 
@@ -1273,7 +1272,31 @@ class TestRunner(TestCase):
         cfg.pop('test_cfg')
         cfg.pop('test_evaluator')
         runner = Runner.from_cfg(cfg)
+        # Test default fp32 `autocast` context.
+        predictions = []
+
+        def get_outputs_callback(module, inputs, outputs):
+            predictions.append(outputs)
+
+        runner.model.register_forward_hook(get_outputs_callback)
+        runner.val()
+        self.assertEqual(predictions[0].dtype, torch.float32)
+        predictions.clear()
+
+        # Test fp16 `autocast` context.
+        cfg.experiment_name = 'test_val3'
+        cfg.val_cfg = dict(fp16=True)
+        runner = Runner.from_cfg(cfg)
+        runner.model.register_forward_hook(get_outputs_callback)
+        if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
+                and not torch.cuda.is_available()):
+            with self.assertRaisesRegex(RuntimeError, 'If the pytorch version'):
+                runner.val()
+        else:
+            runner.val()
+            self.assertIn(predictions[0].dtype,
+                          (torch.float16, torch.bfloat16))
 
     def test_test(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -1303,7 +1326,31 @@ class TestRunner(TestCase):
         cfg.pop('val_cfg')
         cfg.pop('val_evaluator')
         runner = Runner.from_cfg(cfg)
+        # Test default fp32 `autocast` context.
+        predictions = []
+
+        def get_outputs_callback(module, inputs, outputs):
+            predictions.append(outputs)
+
+        runner.model.register_forward_hook(get_outputs_callback)
+        runner.test()
+        self.assertEqual(predictions[0].dtype, torch.float32)
+        predictions.clear()
+
+        # Test fp16 `autocast` context.
+        cfg.experiment_name = 'test_test3'
+        cfg.test_cfg = dict(fp16=True)
+        runner = Runner.from_cfg(cfg)
+        runner.model.register_forward_hook(get_outputs_callback)
+        if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
+                and not torch.cuda.is_available()):
+            with self.assertRaisesRegex(RuntimeError, 'If the pytorch version'):
+                runner.test()
+        else:
+            runner.test()
+            self.assertIn(predictions[0].dtype,
+                          (torch.float16, torch.bfloat16))
 
     def test_register_hook(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)