29 lines
937 B
Python
29 lines
937 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest.mock import MagicMock
|
|
|
|
from mmengine.hooks import DistSamplerSeedHook
|
|
from mmengine.testing import RunnerTestCase
|
|
|
|
|
|
class TestDistSamplerSeedHook(RunnerTestCase):
    """Unit and integration tests for :class:`DistSamplerSeedHook`."""

    def test_before_train_epoch(self):
        """``before_train_epoch`` must forward ``runner.epoch`` to ``set_epoch``.

        Two dataloader layouts are exercised:

        * a dataloader exposing ``sampler`` directly, and
        * a dataloader exposing only ``batch_sampler``, in which case the
          test expects the sampler wrapped by the batch sampler to be seeded.
        """
        hook = DistSamplerSeedHook()

        # Test dataset sampler: a plain MagicMock dataloader auto-creates a
        # `sampler` attribute, so the hook is expected to seed it directly.
        runner = MagicMock()
        runner.epoch = 1
        hook.before_train_epoch(runner)
        sampler = runner.train_loop.dataloader.sampler
        # Pin the argument and call count, not merely "was called", so a
        # wrong epoch value or a duplicate call is caught.
        sampler.set_epoch.assert_called_once_with(runner.epoch)

        # Test batch sampler: `spec_set` restricts the mock to the single
        # attribute `batch_sampler`, so `sampler` does not exist and the hook
        # has to reach the sampler through `batch_sampler.sampler` instead.
        dataloader = MagicMock(spec_set=['batch_sampler'])
        runner.train_loop.dataloader = dataloader
        hook.before_train_epoch(runner)
        wrapped_sampler = dataloader.batch_sampler.sampler
        wrapped_sampler.set_epoch.assert_called_once_with(runner.epoch)

    def test_with_runner(self):
        """Smoke test: training completes with the hook added via config."""
        cfg = self.epoch_based_cfg
        # Register the hook by its registry name, as a user config would.
        cfg.custom_hooks = [dict(type='DistSamplerSeedHook')]
        runner = self.build_runner(cfg)
        runner.train()
|