From f2190de7874730f7a05abad269de088ae8079983 Mon Sep 17 00:00:00 2001 From: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com> Date: Tue, 31 May 2022 11:34:30 +0800 Subject: [PATCH] [Enhance] Improve Exception in call_hook (#247) * improve exception in call_hook * refine unit test * add test_call_hook * refine * update docstring and ut --- mmengine/runner/runner.py | 8 +++++++- tests/test_runner/test_runner.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index bfa50603..217742b6 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1287,11 +1287,17 @@ class Runner: fn_name (str): The function name in each hook to be called, such as "before_train_epoch". **kwargs: Keyword arguments passed to hook. + + Raises: + TypeError: if Hook got unexpected arguments. """ for hook in self._hooks: # support adding additional custom hook methods if hasattr(hook, fn_name): - getattr(hook, fn_name)(self, **kwargs) + try: + getattr(hook, fn_name)(self, **kwargs) + except TypeError as e: + raise TypeError(f'{e} in {hook}') from None def register_hook( self, diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 06b12b5a..504a8731 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -130,6 +130,13 @@ class ToyHook2(Hook): pass +@HOOKS.register_module() +class ToyHook3(Hook): + + def before_train_iter(self, runner, data_batch): + pass + + @LOOPS.register_module() class CustomTrainLoop(BaseLoop): @@ -1091,6 +1098,29 @@ class TestRunner(TestCase): self.assertEqual(len(runner._hooks), 8) self.assertTrue(isinstance(runner._hooks[7], ToyHook)) + def test_call_hook(self): + # test unexpected argument in `call_hook` + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.experiment_name = 'test_call_hook1' + runner = Runner.from_cfg(cfg) + runner._hooks = [] + custom_hooks = [dict(type='ToyHook3')] + runner.register_custom_hooks(custom_hooks) + with self.assertRaisesRegex( + TypeError, + r"got an unexpected keyword argument 'batch_idx' in " + r''): + runner.call_hook('before_train_iter', batch_idx=0, data_batch=None) + + # test call hook with expected arguments + cfg = copy.deepcopy(self.epoch_based_cfg) + cfg.experiment_name = 'test_call_hook2' + runner = Runner.from_cfg(cfg) + runner._hooks = [] + custom_hooks = [dict(type='ToyHook3')] + runner.register_custom_hooks(custom_hooks) + runner.call_hook('before_train_iter', data_batch=None) + def test_register_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_register_hooks'