mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature]: Add sampler seed hook (#64)
* [Feature]: Add sampler seed hook * [Fix]: Add call with to UT
This commit is contained in:
parent
1244e486ae
commit
2d3e91248c
@ -97,7 +97,7 @@ import numpy as np
|
|||||||
|
|
||||||
@EVALUATORS.register_module()
|
@EVALUATORS.register_module()
|
||||||
class Accuracy(BaseEvaluator):
|
class Accuracy(BaseEvaluator):
|
||||||
|
|
||||||
def process(self, data_samples: Dict, predictions: Dict):
|
def process(self, data_samples: Dict, predictions: Dict):
|
||||||
"""Process one batch of data and predictions. The processed
|
"""Process one batch of data and predictions. The processed
|
||||||
Results should be stored in `self.results`, which will be used
|
Results should be stored in `self.results`, which will be used
|
||||||
|
@ -276,7 +276,7 @@ class ModuleA:
|
|||||||
class ModuleB:
|
class ModuleB:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.instance = GlobalAccessible.get_instance(current=True)
|
self.instance = GlobalAccessible.get_instance(current=True)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
print(f'moduleB: {self.instance.instance_name} is called')
|
print(f'moduleB: {self.instance.instance_name} is called')
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
from .iter_timer_hook import IterTimerHook
|
from .iter_timer_hook import IterTimerHook
|
||||||
|
from .sampler_seed_hook import DistSamplerSeedHook
|
||||||
|
|
||||||
__all__ = ['Hook', 'IterTimerHook']
|
__all__ = ['Hook', 'IterTimerHook', 'DistSamplerSeedHook']
|
||||||
|
29
mmengine/hooks/sampler_seed_hook.py
Normal file
29
mmengine/hooks/sampler_seed_hook.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from mmengine.registry import HOOKS
|
||||||
|
from .hook import Hook
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
|
||||||
|
class DistSamplerSeedHook(Hook):
|
||||||
|
"""Data-loading sampler for distributed training.
|
||||||
|
|
||||||
|
When distributed training, it is only useful in conjunction with
|
||||||
|
:obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
|
||||||
|
purpose with :obj:`IterLoader`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def before_epoch(self, runner: object) -> None:
|
||||||
|
"""Set the seed for sampler and batch_sampler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (object): The runner of the training process.
|
||||||
|
"""
|
||||||
|
if hasattr(runner.data_loader.sampler, 'set_epoch'): # type: ignore
|
||||||
|
# in case the data loader uses `SequentialSampler` in Pytorch
|
||||||
|
runner.data_loader.sampler.set_epoch(runner.epoch) # type: ignore
|
||||||
|
elif hasattr(
|
||||||
|
runner.data_loader.batch_sampler.sampler, # type: ignore
|
||||||
|
'set_epoch'):
|
||||||
|
# batch sampler in pytorch warps the sampler as its attributes.
|
||||||
|
runner.data_loader.batch_sampler.sampler.set_epoch( # type: ignore
|
||||||
|
runner.epoch) # type: ignore
|
28
tests/test_hook/test_sampler_seed_hook.py
Normal file
28
tests/test_hook/test_sampler_seed_hook.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from mmengine.hooks import DistSamplerSeedHook
|
||||||
|
|
||||||
|
|
||||||
|
class TestDistSamplerSeedHook:
|
||||||
|
|
||||||
|
def test_before_epoch(self):
|
||||||
|
|
||||||
|
hook = DistSamplerSeedHook()
|
||||||
|
# Test dataset sampler
|
||||||
|
runner = Mock()
|
||||||
|
runner.epoch = 1
|
||||||
|
runner.data_loader = Mock()
|
||||||
|
runner.data_loader.sampler = Mock()
|
||||||
|
runner.data_loader.sampler.set_epoch = Mock()
|
||||||
|
hook.before_epoch(runner)
|
||||||
|
runner.data_loader.sampler.set_epoch.assert_called()
|
||||||
|
# Test batch sampler
|
||||||
|
runner = Mock()
|
||||||
|
runner.data_loader = Mock()
|
||||||
|
runner.data_loader.sampler = Mock(spec_set=True)
|
||||||
|
runner.data_loader.batch_sampler = Mock()
|
||||||
|
runner.data_loader.batch_sampler.sampler = Mock()
|
||||||
|
runner.data_loader.batch_sampler.sampler.set_epoch = Mock()
|
||||||
|
hook.before_epoch(runner)
|
||||||
|
runner.data_loader.batch_sampler.sampler.set_epoch.assert_called()
|
Loading…
x
Reference in New Issue
Block a user