mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor unit test of EmptyCacheHook (#805)
* test EmptyCacheHook with runner * Add coments
This commit is contained in:
parent
425ca99e90
commit
4b781c336b
@ -1,14 +1,65 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from unittest.mock import Mock
|
from unittest.mock import patch
|
||||||
|
|
||||||
from mmengine.hooks import EmptyCacheHook
|
from mmengine.testing import RunnerTestCase
|
||||||
|
|
||||||
|
|
||||||
class TestEmptyCacheHook:
|
class TestEmptyCacheHook(RunnerTestCase):
|
||||||
|
|
||||||
def test_emtpy_cache_hook(self):
|
def test_with_runner(self):
|
||||||
hook = EmptyCacheHook(True, True, True)
|
with patch('torch.cuda.empty_cache') as mock_empty_cache:
|
||||||
runner = Mock()
|
cfg = self.epoch_based_cfg
|
||||||
hook._after_iter(runner, 0)
|
cfg.custom_hooks = [dict(type='EmptyCacheHook')]
|
||||||
hook._before_epoch(runner)
|
cfg.train_cfg.val_interval = 1e6 # disable validation during training # noqa: E501
|
||||||
hook._after_epoch(runner)
|
runner = self.build_runner(cfg)
|
||||||
|
|
||||||
|
runner.train()
|
||||||
|
runner.test()
|
||||||
|
runner.val()
|
||||||
|
|
||||||
|
# Call `torch.cuda.empty_cache` after each epoch:
|
||||||
|
# runner.train: `max_epochs` times.
|
||||||
|
# runner.val: `1` time.
|
||||||
|
# runner.test: `1` time.
|
||||||
|
target_called_times = runner.max_epochs + 2
|
||||||
|
self.assertEqual(mock_empty_cache.call_count, target_called_times)
|
||||||
|
|
||||||
|
with patch('torch.cuda.empty_cache') as mock_empty_cache:
|
||||||
|
cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)]
|
||||||
|
runner = self.build_runner(cfg)
|
||||||
|
|
||||||
|
runner.train()
|
||||||
|
runner.val()
|
||||||
|
runner.test()
|
||||||
|
|
||||||
|
# Call `torch.cuda.empty_cache` after/before each epoch:
|
||||||
|
# runner.train: `max_epochs*2` times.
|
||||||
|
# runner.val: `1*2` times.
|
||||||
|
# runner.test: `1*2` times.
|
||||||
|
|
||||||
|
target_called_times = runner.max_epochs * 2 + 4
|
||||||
|
self.assertEqual(mock_empty_cache.call_count, target_called_times)
|
||||||
|
|
||||||
|
with patch('torch.cuda.empty_cache') as mock_empty_cache:
|
||||||
|
cfg.custom_hooks = [
|
||||||
|
dict(
|
||||||
|
type='EmptyCacheHook', after_iter=True, before_epoch=True)
|
||||||
|
]
|
||||||
|
runner = self.build_runner(cfg)
|
||||||
|
|
||||||
|
runner.train()
|
||||||
|
runner.val()
|
||||||
|
runner.test()
|
||||||
|
|
||||||
|
# Call `torch.cuda.empty_cache` after/before each epoch,
|
||||||
|
# after each iteration:
|
||||||
|
# runner.train: `max_epochs*2 + len(dataloader)*max_epochs` times. # noqa: E501
|
||||||
|
# runner.val: `1*2 + len(val_dataloader)` times.
|
||||||
|
# runner.test: `1*2 + len(val_dataloader)` times.
|
||||||
|
|
||||||
|
target_called_times = \
|
||||||
|
runner.max_epochs * 2 + 4 + \
|
||||||
|
len(runner.train_dataloader) * runner.max_epochs + \
|
||||||
|
len(runner.val_dataloader) + \
|
||||||
|
len(runner.test_dataloader)
|
||||||
|
self.assertEqual(mock_empty_cache.call_count, target_called_times)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user