mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor unit test of Base hook (#806)
* refactor base hooks * Fix CI
This commit is contained in:
parent
67acdbe245
commit
425ca99e90
@ -29,7 +29,6 @@ class Hook:
|
|||||||
runner (Runner): The runner of the training, validation or testing
|
runner (Runner): The runner of the training, validation or testing
|
||||||
process.
|
process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def after_run(self, runner) -> None:
|
def after_run(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -39,7 +38,6 @@ class Hook:
|
|||||||
runner (Runner): The runner of the training, validation or testing
|
runner (Runner): The runner of the training, validation or testing
|
||||||
process.
|
process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def before_train(self, runner) -> None:
|
def before_train(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -48,7 +46,6 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def after_train(self, runner) -> None:
|
def after_train(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -57,7 +54,6 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def before_val(self, runner) -> None:
|
def before_val(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -66,7 +62,6 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the validation process.
|
runner (Runner): The runner of the validation process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def after_val(self, runner) -> None:
|
def after_val(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -75,7 +70,6 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the validation process.
|
runner (Runner): The runner of the validation process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def before_test(self, runner) -> None:
|
def before_test(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -84,7 +78,6 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the testing process.
|
runner (Runner): The runner of the testing process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def after_test(self, runner) -> None:
|
def after_test(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -93,7 +86,6 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the testing process.
|
runner (Runner): The runner of the testing process.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
|
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -104,7 +96,6 @@ class Hook:
|
|||||||
process.
|
process.
|
||||||
checkpoint (dict): Model's checkpoint.
|
checkpoint (dict): Model's checkpoint.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
|
def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -115,7 +106,6 @@ class Hook:
|
|||||||
process.
|
process.
|
||||||
checkpoint (dict): Model's checkpoint.
|
checkpoint (dict): Model's checkpoint.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def before_train_epoch(self, runner) -> None:
|
def before_train_epoch(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -300,7 +290,6 @@ class Hook:
|
|||||||
process.
|
process.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def _after_epoch(self, runner, mode: str = 'train') -> None:
|
def _after_epoch(self, runner, mode: str = 'train') -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
@ -311,7 +300,6 @@ class Hook:
|
|||||||
process.
|
process.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def _before_iter(self,
|
def _before_iter(self,
|
||||||
runner,
|
runner,
|
||||||
@ -328,7 +316,6 @@ class Hook:
|
|||||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def _after_iter(self,
|
def _after_iter(self,
|
||||||
runner,
|
runner,
|
||||||
@ -347,7 +334,6 @@ class Hook:
|
|||||||
outputs (dict or Sequence, optional): Outputs from model.
|
outputs (dict or Sequence, optional): Outputs from model.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def every_n_epochs(self, runner, n: int) -> bool:
|
def every_n_epochs(self, runner, n: int) -> bool:
|
||||||
"""Test whether current epoch can be evenly divided by n.
|
"""Test whether current epoch can be evenly divided by n.
|
||||||
@ -427,6 +413,11 @@ class Hook:
|
|||||||
return runner.iter + 1 == runner.max_iters
|
return runner.iter + 1 == runner.max_iters
|
||||||
|
|
||||||
def get_triggered_stages(self) -> list:
|
def get_triggered_stages(self) -> list:
|
||||||
|
"""Get all triggered stages with method name of the hook.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of triggered stages.
|
||||||
|
"""
|
||||||
trigger_stages = set()
|
trigger_stages = set()
|
||||||
for stage in Hook.stages:
|
for stage in Hook.stages:
|
||||||
if is_method_overridden(stage, Hook, self):
|
if is_method_overridden(stage, Hook, self):
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from mmengine.hooks import Hook
|
from mmengine.hooks import Hook
|
||||||
|
from mmengine.testing import RunnerTestCase
|
||||||
|
|
||||||
|
|
||||||
class TestHook:
|
class TestHook(RunnerTestCase):
|
||||||
|
|
||||||
def test_before_run(self):
|
def test_before_run(self):
|
||||||
hook = Hook()
|
hook = Hook()
|
||||||
@ -192,3 +193,35 @@ class TestHook:
|
|||||||
runner.max_iters = 2
|
runner.max_iters = 2
|
||||||
return_val = hook.is_last_train_iter(runner)
|
return_val = hook.is_last_train_iter(runner)
|
||||||
assert return_val
|
assert return_val
|
||||||
|
|
||||||
|
def test_get_triggered_stages(self):
|
||||||
|
|
||||||
|
class CustomHook(Hook):
|
||||||
|
|
||||||
|
def after_train(self, runner):
|
||||||
|
return super().after_train(runner)
|
||||||
|
|
||||||
|
hook = CustomHook()
|
||||||
|
triggered_stages = hook.get_triggered_stages()
|
||||||
|
self.assertListEqual(triggered_stages, ['after_train'])
|
||||||
|
|
||||||
|
class CustomHook(Hook):
|
||||||
|
|
||||||
|
def _before_iter(self, runner):
|
||||||
|
...
|
||||||
|
|
||||||
|
hook = CustomHook()
|
||||||
|
triggered_stages = hook.get_triggered_stages()
|
||||||
|
self.assertEqual(len(triggered_stages), 3)
|
||||||
|
self.assertSetEqual(
|
||||||
|
set(triggered_stages),
|
||||||
|
{'before_train_iter', 'before_val_iter', 'before_test_iter'})
|
||||||
|
|
||||||
|
class CustomHook(Hook):
|
||||||
|
|
||||||
|
def is_last_train_epoch(self, runner):
|
||||||
|
...
|
||||||
|
|
||||||
|
hook = CustomHook()
|
||||||
|
triggered_stages = hook.get_triggered_stages()
|
||||||
|
self.assertEqual(len(triggered_stages), 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user