Update train.py (#428)

* Update train.py

Add user-defined hooks.

* Update train.py

* Update train.py
This commit is contained in:
gszh 2021-10-30 03:13:40 +08:00 committed by GitHub
parent 4b69af7b13
commit d7f82e5dc8

View File

@ -5,7 +5,8 @@ import warnings
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import build_optimizer, build_runner
from mmcv.runner import HOOKS, build_optimizer, build_runner
from mmcv.utils import build_from_cfg
from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
@ -113,6 +114,20 @@ def train_segmentor(model,
runner.register_hook(
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
# user-defined hooks
if cfg.get('custom_hooks', None):
custom_hooks = cfg.custom_hooks
assert isinstance(custom_hooks, list), \
f'custom_hooks expect list type, but got {type(custom_hooks)}'
for hook_cfg in cfg.custom_hooks:
assert isinstance(hook_cfg, dict), \
'Each item in custom_hooks expects dict type, but got ' \
f'{type(hook_cfg)}'
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from: