mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add autocast wrapper (#307)
* add autocast wrapper * fix docstring * fix docstring * fix compare version * fix unit test * fix incompatible arguments * fix as comment * fix unit test * rename auto_cast to autocast
This commit is contained in:
parent
216521a936
commit
312f264ecd
@ -120,5 +120,6 @@ class AmpOptimWrapper(OptimWrapper):
|
|||||||
Args:
|
Args:
|
||||||
model (nn.Module): The training model.
|
model (nn.Module): The training model.
|
||||||
"""
|
"""
|
||||||
with super().optim_context(model), torch.cuda.amp.autocast():
|
from mmengine.runner.amp import autocast
|
||||||
|
with super().optim_context(model), autocast():
|
||||||
yield
|
yield
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from .amp import autocast
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
|
from .checkpoint import (CheckpointLoader, find_latest_checkpoint,
|
||||||
get_deprecated_model_names, get_external_models,
|
get_deprecated_model_names, get_external_models,
|
||||||
@ -13,5 +14,5 @@ __all__ = [
|
|||||||
'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
|
'get_external_models', 'get_mmcls_models', 'get_deprecated_model_names',
|
||||||
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
|
'CheckpointLoader', 'load_checkpoint', 'weights_to_cpu', 'get_state_dict',
|
||||||
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
|
'save_checkpoint', 'EpochBasedTrainLoop', 'IterBasedTrainLoop', 'ValLoop',
|
||||||
'TestLoop', 'Runner', 'find_latest_checkpoint'
|
'TestLoop', 'Runner', 'find_latest_checkpoint', 'autocast'
|
||||||
]
|
]
|
||||||
|
87
mmengine/runner/amp.py
Normal file
87
mmengine/runner/amp.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.utils import TORCH_VERSION, digit_version
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def autocast(enabled: bool = True, **kwargs):
|
||||||
|
"""A wrapper of ``torch.autocast`` and ``toch.cuda.amp.autocast``.
|
||||||
|
|
||||||
|
Pytorch 1.6.0 provide ``torch.cuda.amp.autocast`` for running in
|
||||||
|
mixed precision , and update it to ``torch.autocast`` in 1.10.0.
|
||||||
|
Both interfaces have different arguments, and ``torch.autocast``
|
||||||
|
support running with cpu additionally.
|
||||||
|
|
||||||
|
This function provides a unified interface by wrapping
|
||||||
|
``torch.autocast`` and ``torch.cuda.amp.autocast``, which resolves the
|
||||||
|
compatibility issues that ``torch.cuda.amp.autocast`` does not support
|
||||||
|
running mixed precision with cpu, and both contexts have different
|
||||||
|
arguments. We suggest users using this function in the code
|
||||||
|
to achieve maximized compatibility of different PyTorch versions.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
``autocast`` requires pytorch version >= 1.5.0. If pytorch version
|
||||||
|
<= 1.10.0 and cuda is not available, it will raise an error with
|
||||||
|
``enabled=True``, since ``torch.cuda.amp.autocast`` only support cuda
|
||||||
|
mode.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # case1: 1.10 > Pytorch version >= 1.5.0
|
||||||
|
>>> with autocast():
|
||||||
|
>>> # run in mixed precision context
|
||||||
|
>>> pass
|
||||||
|
>>> with autocast(device_type='cpu')::
|
||||||
|
>>> # raise error, torch.cuda.amp.autocast only support cuda mode.
|
||||||
|
>>> pass
|
||||||
|
>>> # case2: Pytorch version >= 1.10.0
|
||||||
|
>>> with autocast():
|
||||||
|
>>> # default cuda mixed precision context
|
||||||
|
>>> pass
|
||||||
|
>>> with autocast(device_type='cpu'):
|
||||||
|
>>> # cpu mixed precision context
|
||||||
|
>>> pass
|
||||||
|
>>> with autocast(
|
||||||
|
>>> device_type='cuda', enabled=True, cache_enabled=True):
|
||||||
|
>>> # enable precision context with more specific arguments.
|
||||||
|
>>> pass
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled (bool): Whether autocasting should be enabled in the region.
|
||||||
|
Defaults to True.
|
||||||
|
kwargs (dict): Arguments of torch.autocast except for ``enabled``.
|
||||||
|
"""
|
||||||
|
# If `enabled` is True, enable an empty context and all calculations
|
||||||
|
# are performed under fp32.
|
||||||
|
assert digit_version(TORCH_VERSION) >= digit_version('1.5.0'), (
|
||||||
|
'The minimum pytorch version requirements of mmengine is 1.5.0, but '
|
||||||
|
f'got {TORCH_VERSION}')
|
||||||
|
|
||||||
|
if (digit_version('1.5.0') <= digit_version(TORCH_VERSION) <
|
||||||
|
digit_version('1.10.0')):
|
||||||
|
# If pytorch version is between 1.5.0 and 1.10.0, the default value of
|
||||||
|
# dtype for `torch.cuda.amp.autocast` is torch.float16.
|
||||||
|
assert not kwargs, (
|
||||||
|
f'autocast under pytorch {TORCH_VERSION} only accept `enabled` '
|
||||||
|
'arguments.')
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
with torch.cuda.amp.autocast(enabled=enabled):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
if not enabled:
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
'If pytorch versions is between 1.5.0 and 1.10, '
|
||||||
|
'`autocast` is only available in gpu mode')
|
||||||
|
|
||||||
|
elif digit_version(TORCH_VERSION) >= digit_version('1.10.0'):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
kwargs.setdefault('device_type', 'cuda')
|
||||||
|
else:
|
||||||
|
kwargs.setdefault('device_type', 'cpu')
|
||||||
|
|
||||||
|
with torch.autocast(enabled=enabled, **kwargs):
|
||||||
|
yield
|
@ -9,6 +9,7 @@ from torch.utils.data import DataLoader
|
|||||||
from mmengine.evaluator import Evaluator
|
from mmengine.evaluator import Evaluator
|
||||||
from mmengine.registry import LOOPS
|
from mmengine.registry import LOOPS
|
||||||
from mmengine.utils import is_list_of
|
from mmengine.utils import is_list_of
|
||||||
|
from .amp import autocast
|
||||||
from .base_loop import BaseLoop
|
from .base_loop import BaseLoop
|
||||||
|
|
||||||
|
|
||||||
@ -269,10 +270,15 @@ class ValLoop(BaseLoop):
|
|||||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||||
build a dataloader.
|
build a dataloader.
|
||||||
evaluator (Evaluator or dict or list): Used for computing metrics.
|
evaluator (Evaluator or dict or list): Used for computing metrics.
|
||||||
|
fp16 (bool): Whether to enable fp16 validation. Defaults to
|
||||||
|
False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
def __init__(self,
|
||||||
evaluator: Union[Evaluator, Dict, List]) -> None:
|
runner,
|
||||||
|
dataloader: Union[DataLoader, Dict],
|
||||||
|
evaluator: Union[Evaluator, Dict, List],
|
||||||
|
fp16: bool = False) -> None:
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
|
|
||||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||||
@ -288,6 +294,7 @@ class ValLoop(BaseLoop):
|
|||||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||||
'visualizer will be None.')
|
'visualizer will be None.')
|
||||||
|
self.fp16 = fp16
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Launch validation."""
|
"""Launch validation."""
|
||||||
@ -313,6 +320,7 @@ class ValLoop(BaseLoop):
|
|||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'before_val_iter', batch_idx=idx, data_batch=data_batch)
|
'before_val_iter', batch_idx=idx, data_batch=data_batch)
|
||||||
# outputs should be sequence of BaseDataElement
|
# outputs should be sequence of BaseDataElement
|
||||||
|
with autocast(enabled=self.fp16):
|
||||||
outputs = self.runner.model.val_step(data_batch)
|
outputs = self.runner.model.val_step(data_batch)
|
||||||
self.evaluator.process(data_batch, outputs)
|
self.evaluator.process(data_batch, outputs)
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
@ -331,10 +339,15 @@ class TestLoop(BaseLoop):
|
|||||||
dataloader (Dataloader or dict): A dataloader object or a dict to
|
dataloader (Dataloader or dict): A dataloader object or a dict to
|
||||||
build a dataloader.
|
build a dataloader.
|
||||||
evaluator (Evaluator or dict or list): Used for computing metrics.
|
evaluator (Evaluator or dict or list): Used for computing metrics.
|
||||||
|
fp16 (bool): Whether to enable fp16 testing. Defaults to
|
||||||
|
False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
|
def __init__(self,
|
||||||
evaluator: Union[Evaluator, Dict, List]):
|
runner,
|
||||||
|
dataloader: Union[DataLoader, Dict],
|
||||||
|
evaluator: Union[Evaluator, Dict, List],
|
||||||
|
fp16: bool = False):
|
||||||
super().__init__(runner, dataloader)
|
super().__init__(runner, dataloader)
|
||||||
|
|
||||||
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
|
||||||
@ -350,6 +363,7 @@ class TestLoop(BaseLoop):
|
|||||||
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
|
||||||
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
'metainfo. ``dataset_meta`` in evaluator, metric and '
|
||||||
'visualizer will be None.')
|
'visualizer will be None.')
|
||||||
|
self.fp16 = fp16
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""Launch test."""
|
"""Launch test."""
|
||||||
@ -374,6 +388,7 @@ class TestLoop(BaseLoop):
|
|||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
'before_test_iter', batch_idx=idx, data_batch=data_batch)
|
'before_test_iter', batch_idx=idx, data_batch=data_batch)
|
||||||
# predictions should be sequence of BaseDataElement
|
# predictions should be sequence of BaseDataElement
|
||||||
|
with autocast(enabled=self.fp16):
|
||||||
predictions = self.runner.model.test_step(data_batch)
|
predictions = self.runner.model.test_step(data_batch)
|
||||||
self.evaluator.process(data_batch, predictions)
|
self.evaluator.process(data_batch, predictions)
|
||||||
self.runner.call_hook(
|
self.runner.call_hook(
|
||||||
|
@ -97,13 +97,15 @@ class Runner:
|
|||||||
val_cfg (dict, optional): A dict to build a validation loop. If it does
|
val_cfg (dict, optional): A dict to build a validation loop. If it does
|
||||||
not provide "type" key, :class:`ValLoop` will be used by default.
|
not provide "type" key, :class:`ValLoop` will be used by default.
|
||||||
If ``val_cfg`` specified, :attr:`val_dataloader` should also be
|
If ``val_cfg`` specified, :attr:`val_dataloader` should also be
|
||||||
specified. Defaults to None.
|
specified. If ``ValLoop`` is built with `fp16=True``,
|
||||||
See :meth:`build_val_loop` for more details.
|
``runner.val()`` will be performed under fp16 precision.
|
||||||
|
Defaults to None. See :meth:`build_val_loop` for more details.
|
||||||
test_cfg (dict, optional): A dict to build a test loop. If it does
|
test_cfg (dict, optional): A dict to build a test loop. If it does
|
||||||
not provide "type" key, :class:`TestLoop` will be used by default.
|
not provide "type" key, :class:`TestLoop` will be used by default.
|
||||||
If ``test_cfg`` specified, :attr:`test_dataloader` should also be
|
If ``test_cfg`` specified, :attr:`test_dataloader` should also be
|
||||||
specified. Defaults to None.
|
specified. If ``ValLoop`` is built with `fp16=True``,
|
||||||
See :meth:`build_test_loop` for more details.
|
``runner.val()`` will be performed under fp16 precision.
|
||||||
|
Defaults to None. See :meth:`build_test_loop` for more details.
|
||||||
auto_scale_lr (dict, Optional): Config to scale the learning rate
|
auto_scale_lr (dict, Optional): Config to scale the learning rate
|
||||||
automatically. It includes ``base_batch_size`` and ``enable``.
|
automatically. It includes ``base_batch_size`` and ``enable``.
|
||||||
``base_batch_size`` is the batch size that the optimizer lr is
|
``base_batch_size`` is the batch size that the optimizer lr is
|
||||||
@ -1424,6 +1426,7 @@ class Runner:
|
|||||||
evaluator=self._val_evaluator))
|
evaluator=self._val_evaluator))
|
||||||
else:
|
else:
|
||||||
loop = ValLoop(
|
loop = ValLoop(
|
||||||
|
**loop_cfg,
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self._val_dataloader,
|
dataloader=self._val_dataloader,
|
||||||
evaluator=self._val_evaluator) # type: ignore
|
evaluator=self._val_evaluator) # type: ignore
|
||||||
@ -1465,6 +1468,7 @@ class Runner:
|
|||||||
evaluator=self._test_evaluator))
|
evaluator=self._test_evaluator))
|
||||||
else:
|
else:
|
||||||
loop = TestLoop(
|
loop = TestLoop(
|
||||||
|
**loop_cfg,
|
||||||
runner=self,
|
runner=self,
|
||||||
dataloader=self._test_dataloader,
|
dataloader=self._test_dataloader,
|
||||||
evaluator=self._test_evaluator) # type: ignore
|
evaluator=self._test_evaluator) # type: ignore
|
||||||
|
@ -24,7 +24,7 @@ class TestAveragedModel(TestCase):
|
|||||||
averaged_params = [
|
averaged_params = [
|
||||||
torch.zeros_like(param) for param in model.parameters()
|
torch.zeros_like(param) for param in model.parameters()
|
||||||
]
|
]
|
||||||
n_updates = 10
|
n_updates = 2
|
||||||
for i in range(n_updates):
|
for i in range(n_updates):
|
||||||
for p, p_avg in zip(model.parameters(), averaged_params):
|
for p, p_avg in zip(model.parameters(), averaged_params):
|
||||||
p.detach().add_(torch.randn_like(p))
|
p.detach().add_(torch.randn_like(p))
|
||||||
|
57
tests/test_runner/test_amp.py
Normal file
57
tests/test_runner/test_amp.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from mmengine.runner import autocast
|
||||||
|
from mmengine.utils import TORCH_VERSION, digit_version
|
||||||
|
|
||||||
|
|
||||||
|
class TestAmp(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_autocast(self):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
|
||||||
|
# `torch.cuda.amp.autocast` is only support in gpu mode, if
|
||||||
|
# cuda is not available, it will return an empty context and
|
||||||
|
# should not accept any arguments.
|
||||||
|
with self.assertRaisesRegex(RuntimeError,
|
||||||
|
'If pytorch versions is '):
|
||||||
|
with autocast():
|
||||||
|
pass
|
||||||
|
|
||||||
|
with autocast(enabled=False):
|
||||||
|
layer = nn.Conv2d(1, 1, 1)
|
||||||
|
res = layer(torch.randn(1, 1, 1, 1))
|
||||||
|
self.assertEqual(res.dtype, torch.float32)
|
||||||
|
|
||||||
|
else:
|
||||||
|
with autocast(device_type='cpu'):
|
||||||
|
# torch.autocast support cpu mode.
|
||||||
|
layer = nn.Conv2d(1, 1, 1)
|
||||||
|
res = layer(torch.randn(1, 1, 1, 1))
|
||||||
|
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
|
||||||
|
with autocast(enabled=False):
|
||||||
|
res = layer(torch.randn(1, 1, 1, 1))
|
||||||
|
self.assertEqual(res.dtype, torch.float32)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
|
||||||
|
devices = ['cuda']
|
||||||
|
else:
|
||||||
|
devices = ['cpu', 'cuda']
|
||||||
|
for device in devices:
|
||||||
|
with autocast():
|
||||||
|
# torch.autocast support cpu and cuda mode.
|
||||||
|
layer = nn.Conv2d(1, 1, 1).to(device)
|
||||||
|
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||||
|
self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
|
||||||
|
with autocast(enabled=False):
|
||||||
|
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||||
|
self.assertEqual(res.dtype, torch.float32)
|
||||||
|
# Test with fp32_enabled
|
||||||
|
with autocast(enabled=False):
|
||||||
|
layer = nn.Conv2d(1, 1, 1).to(device)
|
||||||
|
res = layer(torch.randn(1, 1, 1, 1).to(device))
|
||||||
|
self.assertEqual(res.dtype, torch.float32)
|
@ -31,7 +31,7 @@ from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
|
|||||||
Runner, TestLoop, ValLoop)
|
Runner, TestLoop, ValLoop)
|
||||||
from mmengine.runner.loops import _InfiniteDataloaderIterator
|
from mmengine.runner.loops import _InfiniteDataloaderIterator
|
||||||
from mmengine.runner.priority import Priority, get_priority
|
from mmengine.runner.priority import Priority, get_priority
|
||||||
from mmengine.utils import is_list_of
|
from mmengine.utils import TORCH_VERSION, digit_version, is_list_of
|
||||||
from mmengine.visualization import Visualizer
|
from mmengine.visualization import Visualizer
|
||||||
|
|
||||||
|
|
||||||
@ -55,7 +55,6 @@ class ToyModel(BaseModel):
|
|||||||
outputs = dict(loss=loss)
|
outputs = dict(loss=loss)
|
||||||
return outputs
|
return outputs
|
||||||
elif mode == 'predict':
|
elif mode == 'predict':
|
||||||
outputs = dict(log_vars=dict(a=1, b=0.5))
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@ -1273,7 +1272,31 @@ class TestRunner(TestCase):
|
|||||||
cfg.pop('test_cfg')
|
cfg.pop('test_cfg')
|
||||||
cfg.pop('test_evaluator')
|
cfg.pop('test_evaluator')
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
|
# Test default fp32 `autocast` context.
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
def get_outputs_callback(module, inputs, outputs):
|
||||||
|
predictions.append(outputs)
|
||||||
|
|
||||||
|
runner.model.register_forward_hook(get_outputs_callback)
|
||||||
runner.val()
|
runner.val()
|
||||||
|
self.assertEqual(predictions[0].dtype, torch.float32)
|
||||||
|
predictions.clear()
|
||||||
|
|
||||||
|
# Test fp16 `autocast` context.
|
||||||
|
cfg.experiment_name = 'test_val3'
|
||||||
|
cfg.val_cfg = dict(fp16=True)
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.model.register_forward_hook(get_outputs_callback)
|
||||||
|
if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
|
||||||
|
and not torch.cuda.is_available()):
|
||||||
|
with self.assertRaisesRegex(RuntimeError, 'If pytorch versions'):
|
||||||
|
runner.val()
|
||||||
|
else:
|
||||||
|
runner.val()
|
||||||
|
self.assertIn(predictions[0].dtype,
|
||||||
|
(torch.float16, torch.bfloat16))
|
||||||
|
|
||||||
def test_test(self):
|
def test_test(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
@ -1303,7 +1326,31 @@ class TestRunner(TestCase):
|
|||||||
cfg.pop('val_cfg')
|
cfg.pop('val_cfg')
|
||||||
cfg.pop('val_evaluator')
|
cfg.pop('val_evaluator')
|
||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
|
|
||||||
|
# Test default fp32 `autocast` context.
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
def get_outputs_callback(module, inputs, outputs):
|
||||||
|
predictions.append(outputs)
|
||||||
|
|
||||||
|
runner.model.register_forward_hook(get_outputs_callback)
|
||||||
runner.test()
|
runner.test()
|
||||||
|
self.assertEqual(predictions[0].dtype, torch.float32)
|
||||||
|
predictions.clear()
|
||||||
|
|
||||||
|
# Test fp16 `autocast` context.
|
||||||
|
cfg.experiment_name = 'test_val3'
|
||||||
|
cfg.test_cfg = dict(fp16=True)
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.model.register_forward_hook(get_outputs_callback)
|
||||||
|
if (digit_version(TORCH_VERSION) < digit_version('1.10.0')
|
||||||
|
and not torch.cuda.is_available()):
|
||||||
|
with self.assertRaisesRegex(RuntimeError, 'If pytorch versions'):
|
||||||
|
runner.test()
|
||||||
|
else:
|
||||||
|
runner.test()
|
||||||
|
self.assertIn(predictions[0].dtype,
|
||||||
|
(torch.float16, torch.bfloat16))
|
||||||
|
|
||||||
def test_register_hook(self):
|
def test_register_hook(self):
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user