From 1ad7bdcb5aeea00e27c5b0707433303bfbd54a65 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Fri, 27 May 2022 20:57:38 +0800 Subject: [PATCH] [Feature] Add default scope and register modules --- mmseg/utils/__init__.py | 4 ++-- mmseg/utils/set_env.py | 36 +++++++++++++++++++++++++++++++ tests/test_utils/test_set_env.py | 37 +++++++++++++++++++++++++++++++- 3 files changed, 74 insertions(+), 3 deletions(-) diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index ed002c7de..122e20b40 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -2,9 +2,9 @@ from .collect_env import collect_env from .logger import get_root_logger from .misc import find_latest_checkpoint -from .set_env import setup_multi_processes +from .set_env import register_all_modules, setup_multi_processes __all__ = [ 'get_root_logger', 'collect_env', 'find_latest_checkpoint', - 'setup_multi_processes' + 'setup_multi_processes', 'register_all_modules' ] diff --git a/mmseg/utils/set_env.py b/mmseg/utils/set_env.py index b2d3aaf14..e68cc1532 100644 --- a/mmseg/utils/set_env.py +++ b/mmseg/utils/set_env.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +import datetime import os import platform +import warnings import cv2 import torch.multiprocessing as mp +from mmengine import DefaultScope from ..utils import get_root_logger @@ -53,3 +56,36 @@ def setup_multi_processes(cfg): os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) else: logger.info(f'MKL num threads is {os.environ["MKL_NUM_THREADS"]}') + + +def register_all_modules(init_default_scope: bool = True) -> None: + """Register all modules in mmseg into the registries. + + Args: + init_default_scope (bool): Whether initialize the mmseg default scope. + When `init_default_scope=True`, the global default scope will be + set to `mmseg`, and all registries will build modules from mmseg's + registry node. To understand more about the registry, please refer + to https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/registry.md + Defaults to True. + """ # noqa + import mmseg.core # noqa: F401,F403 + import mmseg.datasets # noqa: F401,F403 + import mmseg.models # noqa: F401,F403 + + if init_default_scope: + never_created = DefaultScope.get_current_instance() is None \ + or not DefaultScope.check_instance_created('mmseg') + if never_created: + DefaultScope.get_instance('mmseg', scope_name='mmseg') + return + current_scope = DefaultScope.get_current_instance() + if current_scope.scope_name != 'mmseg': + warnings.warn('The current default scope ' + f'"{current_scope.scope_name}" is not "mmseg", ' + '`register_all_modules` will force the current' + 'default scope to be "mmseg". If this is not ' + 'expected, please set `init_default_scope=False`.') + # avoid name conflict + new_instance_name = f'mmseg-{datetime.datetime.now()}' + DefaultScope.get_instance(new_instance_name, scope_name='mmseg') diff --git a/tests/test_utils/test_set_env.py b/tests/test_utils/test_set_env.py index 0af4424b1..7d48f616c 100644 --- a/tests/test_utils/test_set_env.py +++ b/tests/test_utils/test_set_env.py @@ -1,13 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. +import datetime import multiprocessing as mp import os import platform +import sys +from unittest import TestCase import cv2 import pytest from mmcv import Config +from mmengine import DefaultScope -from mmseg.utils import setup_multi_processes +from mmseg.utils import register_all_modules, setup_multi_processes @pytest.mark.parametrize('workers_per_gpu', (0, 2)) @@ -83,3 +87,34 @@ def test_setup_multi_processes(workers_per_gpu, valid, env_cfg): assert cv2.getNumThreads() == sys_cv_threads assert 'OMP_NUM_THREADS' not in os.environ assert 'MKL_NUM_THREADS' not in os.environ + + +class TestSetupEnv(TestCase): + + def test_register_all_modules(self): + from mmseg.registry import DATASETS + + # not init default scope + sys.modules.pop('mmseg.datasets', None) + sys.modules.pop('mmseg.datasets.ade', None) + DATASETS._module_dict.pop('ADE20KDataset', None) + self.assertFalse('ADE20KDataset' in DATASETS.module_dict) + register_all_modules(init_default_scope=False) + self.assertTrue('ADE20KDataset' in DATASETS.module_dict) + + # init default scope + sys.modules.pop('mmseg.datasets') + sys.modules.pop('mmseg.datasets.ade') + DATASETS._module_dict.pop('ADE20KDataset', None) + self.assertFalse('ADE20KDataset' in DATASETS.module_dict) + register_all_modules(init_default_scope=True) + self.assertTrue('ADE20KDataset' in DATASETS.module_dict) + self.assertEqual(DefaultScope.get_current_instance().scope_name, + 'mmseg') + + # init default scope when another scope is init + name = f'test-{datetime.datetime.now()}' + DefaultScope.get_instance(name, scope_name='test') + with self.assertWarnsRegex( + Warning, 'The current default scope "test" is not "mmseg"'): + register_all_modules(init_default_scope=True)