[Enhance] Improve Exception in call_hook (#247)

* improve exception in call_hook

* refine unit test

* add test_call_hook

* refine

* update docstring and ut
pull/271/head
Jiazhen Wang 2022-05-31 11:34:30 +08:00 committed by GitHub
parent 38b22d9e68
commit f2190de787
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 1 deletions

View File

@ -1287,11 +1287,17 @@ class Runner:
fn_name (str): The function name in each hook to be called, such as
"before_train_epoch".
**kwargs: Keyword arguments passed to hook.
Raises:
TypeError: if Hook got unexpected arguments.
"""
for hook in self._hooks:
# support adding additional custom hook methods
if hasattr(hook, fn_name):
getattr(hook, fn_name)(self, **kwargs)
try:
getattr(hook, fn_name)(self, **kwargs)
except TypeError as e:
raise TypeError(f'{e} in {hook}') from None
def register_hook(
self,

View File

@ -130,6 +130,13 @@ class ToyHook2(Hook):
pass
@HOOKS.register_module()
class ToyHook3(Hook):
def before_train_iter(self, runner, data_batch):
pass
@LOOPS.register_module()
class CustomTrainLoop(BaseLoop):
@ -1091,6 +1098,29 @@ class TestRunner(TestCase):
self.assertEqual(len(runner._hooks), 8)
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
def test_call_hook(self):
# test unexpected argument in `call_hook`
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_call_hook1'
runner = Runner.from_cfg(cfg)
runner._hooks = []
custom_hooks = [dict(type='ToyHook3')]
runner.register_custom_hooks(custom_hooks)
with self.assertRaisesRegex(
TypeError,
r"got an unexpected keyword argument 'batch_idx' in "
r'<test_runner.ToyHook3 object at \w+>'):
runner.call_hook('before_train_iter', batch_idx=0, data_batch=None)
# test call hook with expected arguments
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_call_hook2'
runner = Runner.from_cfg(cfg)
runner._hooks = []
custom_hooks = [dict(type='ToyHook3')]
runner.register_custom_hooks(custom_hooks)
runner.call_hook('before_train_iter', data_batch=None)
def test_register_hooks(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_register_hooks'