[Feature]: Add param scheduler hook (#63)
* [Feature]: Add param scheduler hook * [Fix]: update docstring and add assert_call to UTpull/70/head
parent
2d3e91248c
commit
ab31e1936e
|
@ -2,5 +2,8 @@
|
|||
from .hook import Hook
|
||||
from .iter_timer_hook import IterTimerHook
|
||||
from .sampler_seed_hook import DistSamplerSeedHook
|
||||
from .param_scheduler_hook import ParamSchedulerHook
|
||||
|
||||
__all__ = ['Hook', 'IterTimerHook', 'DistSamplerSeedHook']
|
||||
__all__ = [
|
||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ParamSchedulerHook(Hook):
|
||||
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
|
||||
and momentum."""
|
||||
|
||||
def after_iter(self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""Call step function for each scheduler after each iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader. In
|
||||
order to keep this interface consistent with other hooks, we
|
||||
keep ``data_batch`` here. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample]): Outputs from model. In
|
||||
order to keep this interface consistent with other hooks, we
|
||||
keep ``data_batch`` here. Defaults to None.
|
||||
"""
|
||||
for scheduler in runner.schedulers: # type: ignore
|
||||
if not scheduler.by_epoch:
|
||||
scheduler.step()
|
||||
|
||||
def after_epoch(self, runner: object) -> None:
|
||||
"""Call step function for each scheduler after each epoch.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
"""
|
||||
for scheduler in runner.schedulers: # type: ignore
|
||||
if scheduler.by_epoch:
|
||||
scheduler.step()
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import Mock
|
||||
|
||||
from mmengine.hooks import ParamSchedulerHook
|
||||
|
||||
|
||||
class TestParamSchedulerHook:
|
||||
|
||||
def test_after_iter(self):
|
||||
Hook = ParamSchedulerHook()
|
||||
Runner = Mock()
|
||||
scheduler = Mock()
|
||||
scheduler.step = Mock()
|
||||
scheduler.by_epoch = False
|
||||
Runner.schedulers = [scheduler]
|
||||
Hook.after_iter(Runner)
|
||||
scheduler.step.assert_called()
|
||||
|
||||
def test_after_epoch(self):
|
||||
Hook = ParamSchedulerHook()
|
||||
Runner = Mock()
|
||||
scheduler = Mock()
|
||||
scheduler.step = Mock()
|
||||
scheduler.by_epoch = True
|
||||
Runner.schedulers = [scheduler]
|
||||
Hook.after_epoch(Runner)
|
||||
scheduler.step.assert_called()
|
Loading…
Reference in New Issue