mmengine/tests/test_hook/test_sampler_seed_hook.py
RangiLyu 59cc08e3ac
[Refactor] Refactor data_batch type and remove cur_dataloader in runner. (#171)
* [Refactor] Refactor data_batch type.

* fix sampler

* [Refactor] Remove cur_dataloader in runner.

* fix set_epoch
2022-04-08 15:57:10 +08:00

30 lines
1.1 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from mmengine.hooks import DistSamplerSeedHook
class TestDistSamplerSeedHook:
    """Tests for ``DistSamplerSeedHook.before_train_epoch``."""

    def test_before_epoch(self):
        """Check that ``set_epoch`` reaches the sampler that exposes it.

        Two dataloader layouts are exercised:

        1. ``dataloader.sampler`` itself has ``set_epoch``.
        2. ``dataloader.sampler`` lacks ``set_epoch`` and only
           ``dataloader.batch_sampler.sampler`` has it.

        In both cases the test checks that ``set_epoch`` is called exactly
        once with ``runner.epoch``, not merely that it was called.
        """
        hook = DistSamplerSeedHook()

        # Case 1: the dataloader's own sampler exposes ``set_epoch``.
        runner = Mock()
        runner.epoch = 1
        dataloader = Mock()
        dataloader.sampler = Mock()
        dataloader.sampler.set_epoch = Mock()
        runner.train_loop.dataloader = dataloader
        hook.before_train_epoch(runner)
        # Pin the forwarded value: a bare ``assert_called()`` would also pass
        # if the hook passed a wrong epoch.
        dataloader.sampler.set_epoch.assert_called_once_with(runner.epoch)

        # Case 2: only the sampler wrapped by the batch sampler exposes it.
        runner = Mock()
        # Use a different epoch than case 1 so a stale value would be caught.
        runner.epoch = 2
        dataloader = Mock()
        # ``spec_set=True`` makes the object ``True`` the spec, so this mock
        # only allows the attributes of a ``bool``. ``hasattr(sampler,
        # 'set_epoch')`` is therefore False, and the hook is expected to fall
        # back to ``batch_sampler.sampler``.
        dataloader.sampler = Mock(spec_set=True)
        dataloader.batch_sampler = Mock()
        dataloader.batch_sampler.sampler = Mock()
        dataloader.batch_sampler.sampler.set_epoch = Mock()
        runner.train_loop.dataloader = dataloader
        hook.before_train_epoch(runner)
        dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(
            runner.epoch)