From d62f198d2b107ba1686e5e4d126033eca8f4bc28 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Tue, 15 Jun 2021 21:09:58 +0800 Subject: [PATCH] [Feature]Support custom hooks (#305) * add mytrain.py for test * test before layers * test attr in layers * test classifier * delete mytrain.py * register custom_hooks in runner * set custom_hooks_config to cfg.get(custom_hooks, None) --- mmcls/apis/train.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mmcls/apis/train.py b/mmcls/apis/train.py index 42b9a5b4..fb4ef629 100644 --- a/mmcls/apis/train.py +++ b/mmcls/apis/train.py @@ -127,9 +127,13 @@ def train_model(model, optimizer_config = cfg.optimizer_config # register hooks - runner.register_training_hooks(cfg.lr_config, optimizer_config, - cfg.checkpoint_config, cfg.log_config, - cfg.get('momentum_config', None)) + runner.register_training_hooks( + cfg.lr_config, + optimizer_config, + cfg.checkpoint_config, + cfg.log_config, + cfg.get('momentum_config', None), + custom_hooks_config=cfg.get('custom_hooks', None)) if distributed: runner.register_hook(DistSamplerSeedHook())