2022-03-01 15:38:01 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
from unittest.mock import Mock
|
|
|
|
|
|
|
|
from mmengine.hooks import DistSamplerSeedHook
|
|
|
|
|
|
|
|
|
|
|
|
class TestDistSamplerSeedHook:
    """Unit tests for :class:`DistSamplerSeedHook`."""

    def test_before_epoch(self):
        """``before_train_epoch`` forwards ``runner.epoch`` to a sampler.

        Two dataloader layouts are covered:

        1. ``dataloader.sampler`` exposes ``set_epoch``.
        2. ``dataloader.sampler`` has no ``set_epoch``; the method is only
           available on ``dataloader.batch_sampler.sampler``.
        """
        hook = DistSamplerSeedHook()

        # Case 1: `set_epoch` is available on `dataloader.sampler`.
        runner = Mock()
        runner.epoch = 1
        dataloader = Mock()
        dataloader.sampler = Mock()
        dataloader.sampler.set_epoch = Mock()
        runner.train_loop.dataloader = dataloader

        hook.before_train_epoch(runner)
        # Check the epoch value that was forwarded, not merely that some
        # call happened.
        dataloader.sampler.set_epoch.assert_called_once_with(runner.epoch)

        # Case 2: `set_epoch` is only available on `batch_sampler.sampler`.
        runner = Mock()
        runner.epoch = 2
        dataloader = Mock()
        # `spec=[]` yields a mock with no attributes at all, so
        # `dataloader.sampler` offers no `set_epoch` for the hook to call.
        dataloader.sampler = Mock(spec=[])
        dataloader.batch_sampler = Mock()
        dataloader.batch_sampler.sampler = Mock()
        dataloader.batch_sampler.sampler.set_epoch = Mock()
        runner.train_loop.dataloader = dataloader

        hook.before_train_epoch(runner)
        dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(
            runner.epoch)