mmengine/tests/test_hooks/test_sampler_seed_hook.py

29 lines
937 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock
from mmengine.hooks import DistSamplerSeedHook
from mmengine.testing import RunnerTestCase
class TestDistSamplerSeedHook(RunnerTestCase):
    """Tests for :class:`DistSamplerSeedHook`."""

    def test_before_train_epoch(self):
        """Check that ``before_train_epoch`` calls ``set_epoch`` exactly once
        with ``runner.epoch``.

        Two dataloader layouts are covered: one where the sampler is reached
        via ``dataloader.sampler``, and one where the dataloader only exposes
        a ``batch_sampler`` that wraps the sampler.
        """
        hook = DistSamplerSeedHook()

        # Test dataset sampler: a bare MagicMock dataloader exposes
        # ``sampler`` (every attribute exists on a plain MagicMock).
        runner = MagicMock()
        runner.epoch = 1
        hook.before_train_epoch(runner)
        sampler = runner.train_loop.dataloader.sampler
        # Verify the epoch value is forwarded, not merely that a call happened.
        sampler.set_epoch.assert_called_once_with(runner.epoch)

        # Test batch sampler: ``spec_set`` restricts the dataloader mock to a
        # single ``batch_sampler`` attribute, so ``dataloader.sampler`` does
        # not exist and the hook has to go through ``batch_sampler.sampler``.
        runner.train_loop.dataloader = MagicMock(spec_set=['batch_sampler'])
        # Guard the precondition this branch of the test relies on.
        self.assertFalse(hasattr(runner.train_loop.dataloader, 'sampler'))
        hook.before_train_epoch(runner)
        batch_sampler = runner.train_loop.dataloader.batch_sampler
        batch_sampler.sampler.set_epoch.assert_called_once_with(runner.epoch)

    def test_with_runner(self):
        """Smoke test: building a runner with ``DistSamplerSeedHook``
        registered as a custom hook and calling ``train()`` does not raise."""
        cfg = self.epoch_based_cfg
        cfg.custom_hooks = [dict(type='DistSamplerSeedHook')]
        runner = self.build_runner(cfg)
        runner.train()