# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch

from mmengine.testing import RunnerTestCase


class TestEmptyCacheHook(RunnerTestCase):
    """Tests for ``EmptyCacheHook``.

    ``torch.cuda.empty_cache`` is replaced with a mock so the number of
    cache-clearing calls triggered by the hook can be counted without
    requiring a GPU.
    """

    def test_with_runner(self):
        """Check ``torch.cuda.empty_cache`` call counts for several hook
        configurations.

        Each scenario builds a fresh runner from the same config and runs
        train, val and test:

        1. Default ``EmptyCacheHook``: cleared after each epoch.
        2. ``before_epoch=True``: cleared before and after each epoch.
        3. ``after_iter=True, before_epoch=True``: additionally cleared
           after every iteration.
        """
        with patch('torch.cuda.empty_cache') as mock_empty_cache:
            cfg = self.epoch_based_cfg
            cfg.custom_hooks = [dict(type='EmptyCacheHook')]
            # Disable validation during training so that only the explicit
            # `runner.val()` call below contributes val-loop calls.
            cfg.train_cfg.val_interval = 1e6
            runner = self.build_runner(cfg)

            runner.train()
            runner.test()
            runner.val()

            # Call `torch.cuda.empty_cache` after each epoch:
            #   runner.train: `max_epochs` times.
            #   runner.val: `1` time.
            #   runner.test: `1` time.
            target_called_times = runner.max_epochs + 2
            self.assertEqual(mock_empty_cache.call_count, target_called_times)

        with patch('torch.cuda.empty_cache') as mock_empty_cache:
            cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)]
            runner = self.build_runner(cfg)

            runner.train()
            runner.val()
            runner.test()

            # Call `torch.cuda.empty_cache` after/before each epoch:
            #   runner.train: `max_epochs*2` times.
            #   runner.val: `1*2` times.
            #   runner.test: `1*2` times.

            target_called_times = runner.max_epochs * 2 + 4
            self.assertEqual(mock_empty_cache.call_count, target_called_times)

        with patch('torch.cuda.empty_cache') as mock_empty_cache:
            cfg.custom_hooks = [
                dict(
                    type='EmptyCacheHook', after_iter=True, before_epoch=True)
            ]
            runner = self.build_runner(cfg)

            runner.train()
            runner.val()
            runner.test()

            # Call `torch.cuda.empty_cache` after/before each epoch,
            # after each iteration:
            #   runner.train:
            #       `max_epochs*2 + len(train_dataloader)*max_epochs` times.
            #   runner.val: `1*2 + len(val_dataloader)` times.
            #   runner.test: `1*2 + len(test_dataloader)` times.

            target_called_times = \
                runner.max_epochs * 2 + 4 + \
                len(runner.train_dataloader) * runner.max_epochs + \
                len(runner.val_dataloader) + \
                len(runner.test_dataloader)
            self.assertEqual(mock_empty_cache.call_count, target_called_times)