diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py index e909fc5d..4a9ea997 100644 --- a/tests/test_hooks/test_empty_cache_hook.py +++ b/tests/test_hooks/test_empty_cache_hook.py @@ -1,14 +1,65 @@ # Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import Mock +from unittest.mock import patch -from mmengine.hooks import EmptyCacheHook +from mmengine.testing import RunnerTestCase -class TestEmptyCacheHook: +class TestEmptyCacheHook(RunnerTestCase): - def test_emtpy_cache_hook(self): - hook = EmptyCacheHook(True, True, True) - runner = Mock() - hook._after_iter(runner, 0) - hook._before_epoch(runner) - hook._after_epoch(runner) + def test_with_runner(self): + with patch('torch.cuda.empty_cache') as mock_empty_cache: + cfg = self.epoch_based_cfg + cfg.custom_hooks = [dict(type='EmptyCacheHook')] + cfg.train_cfg.val_interval = 1e6 # disable validation during training # noqa: E501 + runner = self.build_runner(cfg) + + runner.train() + runner.test() + runner.val() + + # Call `torch.cuda.empty_cache` after each epoch: + # runner.train: `max_epochs` times. + # runner.val: `1` time. + # runner.test: `1` time. + target_called_times = runner.max_epochs + 2 + self.assertEqual(mock_empty_cache.call_count, target_called_times) + + with patch('torch.cuda.empty_cache') as mock_empty_cache: + cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)] + runner = self.build_runner(cfg) + + runner.train() + runner.val() + runner.test() + + # Call `torch.cuda.empty_cache` after/before each epoch: + # runner.train: `max_epochs*2` times. + # runner.val: `1*2` times. + # runner.test: `1*2` times. + + target_called_times = runner.max_epochs * 2 + 4 + self.assertEqual(mock_empty_cache.call_count, target_called_times) + + with patch('torch.cuda.empty_cache') as mock_empty_cache: + cfg.custom_hooks = [ + dict( + type='EmptyCacheHook', after_iter=True, before_epoch=True) + ] + runner = self.build_runner(cfg) + + runner.train() + runner.val() + runner.test() + + # Call `torch.cuda.empty_cache` after/before each epoch, + # after each iteration: + # runner.train: `max_epochs*2 + len(dataloader)*max_epochs` times. # noqa: E501 + # runner.val: `1*2 + len(val_dataloader)` times. + # runner.test: `1*2 + len(val_dataloader)` times. + + target_called_times = \ + runner.max_epochs * 2 + 4 + \ + len(runner.train_dataloader) * runner.max_epochs + \ + len(runner.val_dataloader) + \ + len(runner.test_dataloader) + self.assertEqual(mock_empty_cache.call_count, target_called_times)