mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Update train.py (#428)
* Update train.py Add user-defined hooks. * Update train.py * Update train.py
This commit is contained in:
parent
4b69af7b13
commit
d7f82e5dc8
@ -5,7 +5,8 @@ import warnings
|
||||
import numpy as np
|
||||
import torch
|
||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||
from mmcv.runner import build_optimizer, build_runner
|
||||
from mmcv.runner import HOOKS, build_optimizer, build_runner
|
||||
from mmcv.utils import build_from_cfg
|
||||
|
||||
from mmseg.core import DistEvalHook, EvalHook
|
||||
from mmseg.datasets import build_dataloader, build_dataset
|
||||
@ -113,6 +114,20 @@ def train_segmentor(model,
|
||||
runner.register_hook(
|
||||
eval_hook(val_dataloader, **eval_cfg), priority='LOW')
|
||||
|
||||
# user-defined hooks
|
||||
if cfg.get('custom_hooks', None):
|
||||
custom_hooks = cfg.custom_hooks
|
||||
assert isinstance(custom_hooks, list), \
|
||||
f'custom_hooks expect list type, but got {type(custom_hooks)}'
|
||||
for hook_cfg in cfg.custom_hooks:
|
||||
assert isinstance(hook_cfg, dict), \
|
||||
'Each item in custom_hooks expects dict type, but got ' \
|
||||
f'{type(hook_cfg)}'
|
||||
hook_cfg = hook_cfg.copy()
|
||||
priority = hook_cfg.pop('priority', 'NORMAL')
|
||||
hook = build_from_cfg(hook_cfg, HOOKS)
|
||||
runner.register_hook(hook, priority=priority)
|
||||
|
||||
if cfg.resume_from:
|
||||
runner.resume(cfg.resume_from)
|
||||
elif cfg.load_from:
|
||||
|
Loading…
x
Reference in New Issue
Block a user