diff --git a/tools/train.py b/tools/train.py index 1cd2e53f8..2023b680e 100644 --- a/tools/train.py +++ b/tools/train.py @@ -6,6 +6,7 @@ import os.path as osp from mmengine.config import Config, DictAction from mmengine.logging import print_log +from mmengine.registry import RUNNERS from mmengine.runner import Runner from mmseg.utils import register_all_modules @@ -97,7 +98,13 @@ def main(): cfg.load_from = args.resume # build the runner from config - runner = Runner.from_cfg(cfg) + if 'runner_type' not in cfg: + # build the default runner + runner = Runner.from_cfg(cfg) + else: + # build customized runner from the registry + # if 'runner_type' is set in the cfg + runner = RUNNERS.build(cfg) # start training runner.train()