[Enhance] Improve Exception in call_hook (#247)
* improve exception in call_hook * refine unit test * add test_call_hook * refine * update docstring and utpull/271/head
parent
38b22d9e68
commit
f2190de787
|
@ -1287,11 +1287,17 @@ class Runner:
|
|||
fn_name (str): The function name in each hook to be called, such as
|
||||
"before_train_epoch".
|
||||
**kwargs: Keyword arguments passed to hook.
|
||||
|
||||
Raises:
|
||||
TypeError: if Hook got unexpected arguments.
|
||||
"""
|
||||
for hook in self._hooks:
|
||||
# support adding additional custom hook methods
|
||||
if hasattr(hook, fn_name):
|
||||
getattr(hook, fn_name)(self, **kwargs)
|
||||
try:
|
||||
getattr(hook, fn_name)(self, **kwargs)
|
||||
except TypeError as e:
|
||||
raise TypeError(f'{e} in {hook}') from None
|
||||
|
||||
def register_hook(
|
||||
self,
|
||||
|
|
|
@ -130,6 +130,13 @@ class ToyHook2(Hook):
|
|||
pass
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class ToyHook3(Hook):
|
||||
|
||||
def before_train_iter(self, runner, data_batch):
|
||||
pass
|
||||
|
||||
|
||||
@LOOPS.register_module()
|
||||
class CustomTrainLoop(BaseLoop):
|
||||
|
||||
|
@ -1091,6 +1098,29 @@ class TestRunner(TestCase):
|
|||
self.assertEqual(len(runner._hooks), 8)
|
||||
self.assertTrue(isinstance(runner._hooks[7], ToyHook))
|
||||
|
||||
def test_call_hook(self):
|
||||
# test unexpected argument in `call_hook`
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_call_hook1'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner._hooks = []
|
||||
custom_hooks = [dict(type='ToyHook3')]
|
||||
runner.register_custom_hooks(custom_hooks)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
r"got an unexpected keyword argument 'batch_idx' in "
|
||||
r'<test_runner.ToyHook3 object at \w+>'):
|
||||
runner.call_hook('before_train_iter', batch_idx=0, data_batch=None)
|
||||
|
||||
# test call hook with expected arguments
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_call_hook2'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner._hooks = []
|
||||
custom_hooks = [dict(type='ToyHook3')]
|
||||
runner.register_custom_hooks(custom_hooks)
|
||||
runner.call_hook('before_train_iter', data_batch=None)
|
||||
|
||||
def test_register_hooks(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_register_hooks'
|
||||
|
|
Loading…
Reference in New Issue