parent
67acdbe245
commit
425ca99e90
|
@ -29,7 +29,6 @@ class Hook:
|
|||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_run(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -39,7 +38,6 @@ class Hook:
|
|||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -48,7 +46,6 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -57,7 +54,6 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_val(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -66,7 +62,6 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_val(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -75,7 +70,6 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -84,7 +78,6 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -93,7 +86,6 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -104,7 +96,6 @@ class Hook:
|
|||
process.
|
||||
checkpoint (dict): Model's checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -115,7 +106,6 @@ class Hook:
|
|||
process.
|
||||
checkpoint (dict): Model's checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -300,7 +290,6 @@ class Hook:
|
|||
process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _after_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -311,7 +300,6 @@ class Hook:
|
|||
process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _before_iter(self,
|
||||
runner,
|
||||
|
@ -328,7 +316,6 @@ class Hook:
|
|||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
|
@ -347,7 +334,6 @@ class Hook:
|
|||
outputs (dict or Sequence, optional): Outputs from model.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
||||
def every_n_epochs(self, runner, n: int) -> bool:
|
||||
"""Test whether current epoch can be evenly divided by n.
|
||||
|
@ -427,6 +413,11 @@ class Hook:
|
|||
return runner.iter + 1 == runner.max_iters
|
||||
|
||||
def get_triggered_stages(self) -> list:
|
||||
"""Get all triggered stages with method name of the hook.
|
||||
|
||||
Returns:
|
||||
list: List of triggered stages.
|
||||
"""
|
||||
trigger_stages = set()
|
||||
for stage in Hook.stages:
|
||||
if is_method_overridden(stage, Hook, self):
|
||||
|
|
|
@ -2,9 +2,10 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
||||
|
||||
class TestHook:
|
||||
class TestHook(RunnerTestCase):
|
||||
|
||||
def test_before_run(self):
|
||||
hook = Hook()
|
||||
|
@ -192,3 +193,35 @@ class TestHook:
|
|||
runner.max_iters = 2
|
||||
return_val = hook.is_last_train_iter(runner)
|
||||
assert return_val
|
||||
|
||||
def test_get_triggered_stages(self):
|
||||
|
||||
class CustomHook(Hook):
|
||||
|
||||
def after_train(self, runner):
|
||||
return super().after_train(runner)
|
||||
|
||||
hook = CustomHook()
|
||||
triggered_stages = hook.get_triggered_stages()
|
||||
self.assertListEqual(triggered_stages, ['after_train'])
|
||||
|
||||
class CustomHook(Hook):
|
||||
|
||||
def _before_iter(self, runner):
|
||||
...
|
||||
|
||||
hook = CustomHook()
|
||||
triggered_stages = hook.get_triggered_stages()
|
||||
self.assertEqual(len(triggered_stages), 3)
|
||||
self.assertSetEqual(
|
||||
set(triggered_stages),
|
||||
{'before_train_iter', 'before_val_iter', 'before_test_iter'})
|
||||
|
||||
class CustomHook(Hook):
|
||||
|
||||
def is_last_train_epoch(self, runner):
|
||||
...
|
||||
|
||||
hook = CustomHook()
|
||||
triggered_stages = hook.get_triggered_stages()
|
||||
self.assertEqual(len(triggered_stages), 0)
|
||||
|
|
Loading…
Reference in New Issue