[Refactor] Refactor unit test of Base hook (#806)

* refactor base hooks

* Fix CI
pull/948/head
Mashiro 2022-12-09 15:54:00 +08:00 committed by Zaida Zhou
parent 67acdbe245
commit 425ca99e90
2 changed files with 39 additions and 15 deletions

View File

@ -29,7 +29,6 @@ class Hook:
runner (Runner): The runner of the training, validation or testing
process.
"""
pass
def after_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -39,7 +38,6 @@ class Hook:
runner (Runner): The runner of the training, validation or testing
process.
"""
pass
def before_train(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -48,7 +46,6 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
"""
pass
def after_train(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -57,7 +54,6 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
"""
pass
def before_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -66,7 +62,6 @@ class Hook:
Args:
runner (Runner): The runner of the validation process.
"""
pass
def after_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -75,7 +70,6 @@ class Hook:
Args:
runner (Runner): The runner of the validation process.
"""
pass
def before_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -84,7 +78,6 @@ class Hook:
Args:
runner (Runner): The runner of the testing process.
"""
pass
def after_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -93,7 +86,6 @@ class Hook:
Args:
runner (Runner): The runner of the testing process.
"""
pass
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
@ -104,7 +96,6 @@ class Hook:
process.
checkpoint (dict): Model's checkpoint.
"""
pass
def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
@ -115,7 +106,6 @@ class Hook:
process.
checkpoint (dict): Model's checkpoint.
"""
pass
def before_train_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -300,7 +290,6 @@ class Hook:
process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def _after_epoch(self, runner, mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
@ -311,7 +300,6 @@ class Hook:
process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def _before_iter(self,
runner,
@ -328,7 +316,6 @@ class Hook:
data_batch (dict or tuple or list, optional): Data from dataloader.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def _after_iter(self,
runner,
@ -347,7 +334,6 @@ class Hook:
outputs (dict or Sequence, optional): Outputs from model.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
def every_n_epochs(self, runner, n: int) -> bool:
"""Test whether current epoch can be evenly divided by n.
@ -427,6 +413,11 @@ class Hook:
return runner.iter + 1 == runner.max_iters
def get_triggered_stages(self) -> list:
"""Get all triggered stages with method name of the hook.
Returns:
list: List of triggered stages.
"""
trigger_stages = set()
for stage in Hook.stages:
if is_method_overridden(stage, Hook, self):

View File

@ -2,9 +2,10 @@
from unittest.mock import Mock
from mmengine.hooks import Hook
from mmengine.testing import RunnerTestCase
class TestHook:
class TestHook(RunnerTestCase):
def test_before_run(self):
hook = Hook()
@ -192,3 +193,35 @@ class TestHook:
runner.max_iters = 2
return_val = hook.is_last_train_iter(runner)
assert return_val
def test_get_triggered_stages(self):
class CustomHook(Hook):
def after_train(self, runner):
return super().after_train(runner)
hook = CustomHook()
triggered_stages = hook.get_triggered_stages()
self.assertListEqual(triggered_stages, ['after_train'])
class CustomHook(Hook):
def _before_iter(self, runner):
...
hook = CustomHook()
triggered_stages = hook.get_triggered_stages()
self.assertEqual(len(triggered_stages), 3)
self.assertSetEqual(
set(triggered_stages),
{'before_train_iter', 'before_val_iter', 'before_test_iter'})
class CustomHook(Hook):
def is_last_train_epoch(self, runner):
...
hook = CustomHook()
triggered_stages = hook.get_triggered_stages()
self.assertEqual(len(triggered_stages), 0)