# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock

from mmengine.hooks import DistSamplerSeedHook
from mmengine.testing import RunnerTestCase


class TestDistSamplerSeedHook(RunnerTestCase):
    """Tests for :class:`DistSamplerSeedHook`."""

    def test_before_train_epoch(self):
        """``before_train_epoch`` forwards ``runner.epoch`` to ``set_epoch``.

        Two dataloader layouts are exercised: one exposing ``sampler``
        directly, and one exposing only ``batch_sampler``, in which case the
        call is expected on ``batch_sampler.sampler``.
        """
        hook = DistSamplerSeedHook()

        # Dataloader with a ``sampler`` attribute (a plain MagicMock has
        # every attribute, so ``sampler.set_epoch`` exists).
        runner = MagicMock()
        runner.epoch = 1
        hook.before_train_epoch(runner)
        sampler = runner.train_loop.dataloader.sampler
        # Pin the forwarded epoch value, not merely that a call happened.
        sampler.set_epoch.assert_called_once_with(runner.epoch)

        # Dataloader exposing only ``batch_sampler``: ``spec_set`` makes
        # ``sampler`` absent on the dataloader. A different epoch is used so
        # a stale value from the previous scenario cannot satisfy the check.
        runner.epoch = 2
        runner.train_loop.dataloader = MagicMock(spec_set=['batch_sampler'])
        hook.before_train_epoch(runner)
        batch_sampler = runner.train_loop.dataloader.batch_sampler
        batch_sampler.sampler.set_epoch.assert_called_once_with(runner.epoch)

    def test_with_runner(self):
        """Smoke test: training completes with the hook set in config."""
        cfg = self.epoch_based_cfg
        cfg.custom_hooks = [dict(type='DistSamplerSeedHook')]
        runner = self.build_runner(cfg)
        runner.train()