[Feature] Support auto import modules from registry. (#2481)

## Motivation

The registry now supports auto-import modules from the given location.

register_all_modules before running is no longer needed. The modules
will be lazy-imported during building.

- [x] This PR can be merged after
https://github.com/open-mmlab/mmengine/pull/643. The MMEngine version
should be updated.

Ref: https://github.com/open-mmlab/mmdetection/pull/9143
This commit is contained in:
谢昕辰 2023-02-23 20:33:17 +08:00 committed by GitHub
parent 2d38bc8554
commit 039ba5d4ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 188 additions and 113 deletions

View File

@ -1,7 +1,10 @@
# model settings # model settings
norm_cfg = dict(type='SyncBN', requires_grad=True) norm_cfg = dict(type='SyncBN', requires_grad=True)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth' # noqa checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth' # noqa
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False) # TODO: delete custom_imports after mmcls supports auto import
# please install mmcls>=1.0
# import mmcls.models to trigger register_module in mmcls
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
data_preprocessor = dict( data_preprocessor = dict(
type='SegDataPreProcessor', type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53], mean=[123.675, 116.28, 103.53],

View File

@ -1,7 +1,5 @@
_base_ = ['../_base_/default_runtime.py', '../_base_/datasets/cityscapes.py'] _base_ = ['../_base_/default_runtime.py', '../_base_/datasets/cityscapes.py']
custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
crop_size = (512, 1024) crop_size = (512, 1024)
data_preprocessor = dict( data_preprocessor = dict(
type='SegDataPreProcessor', type='SegDataPreProcessor',

View File

@ -3,7 +3,6 @@ _base_ = [
] ]
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa
custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
crop_size = (640, 640) crop_size = (640, 640)
data_preprocessor = dict( data_preprocessor = dict(

View File

@ -460,12 +460,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from mmengine.runner import Runner\n", "from mmengine.runner import Runner\n",
"from mmseg.utils import register_all_modules\n",
"\n", "\n",
"# register all modules in mmseg into the registries\n", "runner = Runner.from_cfg(cfg)"
"# do not init the default scope here because it will be init in the runner\n",
"register_all_modules(init_default_scope=False)\n",
"runner = Runner.from_cfg(cfg)\n"
] ]
}, },
{ {
@ -523,7 +519,7 @@
"provenance": [] "provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "Python 3.8.5 ('tensorflow')", "display_name": "Python 3.10.6 ('pt1.12')",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -537,7 +533,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.5" "version": "3.10.6"
}, },
"pycharm": { "pycharm": {
"stem_cell": { "stem_cell": {
@ -550,7 +546,7 @@
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {
"hash": "20d4b83e0c8b3730b580c42434163d64f4b735d580303a8fade7c849d4d29eba" "hash": "0442e67aee3d9cbb788fa6e86d60c4ffa94ad7f1943c65abfecb99a6f4696c58"
} }
} }
}, },

View File

@ -4,7 +4,6 @@ from argparse import ArgumentParser
from mmengine.model import revert_sync_batchnorm from mmengine.model import revert_sync_batchnorm
from mmseg.apis import inference_model, init_model, show_result_pyplot from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
def main(): def main():
@ -24,8 +23,6 @@ def main():
'--title', default='result', help='The image identifier.') '--title', default='result', help='The image identifier.')
args = parser.parse_args() args = parser.parse_args()
register_all_modules()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device) model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu': if args.device == 'cpu':

File diff suppressed because one or more lines are too long

View File

@ -6,7 +6,6 @@ from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import inference_model, init_model from mmseg.apis import inference_model, init_model
from mmseg.apis.inference import show_result_pyplot from mmseg.apis.inference import show_result_pyplot
from mmseg.utils import register_all_modules
def main(): def main():
@ -53,8 +52,6 @@ def main():
assert args.show or args.output_file, \ assert args.show or args.output_file, \
'At least one output should be enabled.' 'At least one output should be enabled.'
register_all_modules()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device) model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu': if args.device == 'cpu':

View File

@ -15,8 +15,8 @@ Instantiate Cityscapes training dataset:
```python ```python
from mmseg.datasets import CityscapesDataset from mmseg.datasets import CityscapesDataset
from mmseg.utils import register_all_modules from mmengine.registry import init_default_scope
register_all_modules() init_default_scope('mmseg')
data_root = 'data/cityscapes/' data_root = 'data/cityscapes/'
data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train') data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train')

View File

@ -91,10 +91,8 @@ Option (b). If you install mmsegmentation with pip, open you python interpreter
```python ```python
from mmseg.apis import inference_model, init_model, show_result_pyplot from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
import mmcv import mmcv
register_all_modules()
config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

View File

@ -31,14 +31,10 @@ Example:
```python ```python
from mmseg.apis import init_model from mmseg.apis import init_model
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
# register all modules in mmseg into the registries
register_all_modules()
# initialize model without checkpoint # initialize model without checkpoint
model = init_model(config_path) model = init_model(config_path)
@ -76,14 +72,11 @@ Example:
```python ```python
from mmseg.apis import init_model, inference_model from mmseg.apis import init_model, inference_model
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
img_path = 'demo/demo.png' img_path = 'demo/demo.png'
# register all modules in mmseg into the registries
register_all_modules()
model = init_model(config_path, checkpoint_path) model = init_model(config_path, checkpoint_path)
result = inference_model(model, img_path) result = inference_model(model, img_path)
@ -115,14 +108,11 @@ Example:
```python ```python
from mmseg.apis import init_model, inference_model, show_result_pyplot from mmseg.apis import init_model, inference_model, show_result_pyplot
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
img_path = 'demo/demo.png' img_path = 'demo/demo.png'
# register all modules in mmseg into the registries
register_all_modules()
# build the model from a config file and a checkpoint file # build the model from a config file and a checkpoint file
model = init_model(config_path, checkpoint_path, device='cuda:0') model = init_model(config_path, checkpoint_path, device='cuda:0')

View File

@ -9,9 +9,10 @@
实例化 Cityscapes 训练数据集: 实例化 Cityscapes 训练数据集:
```python ```python
from mmengine.registry import init_default_scope
from mmseg.datasets import CityscapesDataset from mmseg.datasets import CityscapesDataset
from mmseg.utils import register_all_modules
register_all_modules() init_default_scope('mmseg')
data_root = 'data/cityscapes/' data_root = 'data/cityscapes/'
data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train') data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train')

View File

@ -92,10 +92,8 @@ python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_4xb2-40k_ci
```python ```python
from mmseg.apis import inference_model, init_model, show_result_pyplot from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
import mmcv import mmcv
register_all_modules()
config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py' config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth' checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

View File

@ -9,7 +9,7 @@ from .version import __version__, version_info
MMCV_MIN = '2.0.0rc4' MMCV_MIN = '2.0.0rc4'
MMCV_MAX = '2.1.0' MMCV_MAX = '2.1.0'
MMENGINE_MIN = '0.2.0' MMENGINE_MIN = '0.4.0'
MMENGINE_MAX = '1.0.0' MMENGINE_MAX = '1.0.0'

View File

@ -9,6 +9,7 @@ import numpy as np
import torch import torch
from mmengine import Config from mmengine import Config
from mmengine.dataset import Compose from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist from mmengine.utils import mkdir_or_exist
@ -48,6 +49,8 @@ def init_model(config: Union[str, Path, Config],
config.model.backbone.init_cfg = None config.model.backbone.init_cfg = None
config.model.pretrained = None config.model.pretrained = None
config.model.train_cfg = None config.model.train_cfg = None
init_default_scope(config.get('default_scope', 'mmseg'))
model = MODELS.build(config.model) model = MODELS.build(config.model)
if checkpoint is not None: if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')

View File

@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS,
MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS, LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
VISUALIZERS, WEIGHT_INITIALIZERS) TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
WEIGHT_INITIALIZERS)
__all__ = [ __all__ = [
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS',
'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS',
'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS',
'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS' 'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS',
'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS'
] ]

View File

@ -10,6 +10,7 @@ from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
from mmengine.registry import HOOKS as MMENGINE_HOOKS from mmengine.registry import HOOKS as MMENGINE_HOOKS
from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS
from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
from mmengine.registry import LOOPS as MMENGINE_LOOPS from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import METRICS as MMENGINE_METRICS from mmengine.registry import METRICS as MMENGINE_METRICS
@ -39,45 +40,82 @@ RUNNER_CONSTRUCTORS = Registry(
# manage all kinds of loops like `EpochBasedTrainLoop` # manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS) LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
# manage all kinds of hooks like `CheckpointHook` # manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS) HOOKS = Registry(
'hook', parent=MMENGINE_HOOKS, locations=['mmseg.engine.hooks'])
# manage data-related modules # manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS) DATASETS = Registry(
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS) 'dataset', parent=MMENGINE_DATASETS, locations=['mmseg.datasets'])
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS) DATA_SAMPLERS = Registry(
'data sampler',
parent=MMENGINE_DATA_SAMPLERS,
locations=['mmseg.datasets.samplers'])
TRANSFORMS = Registry(
'transform',
parent=MMENGINE_TRANSFORMS,
locations=['mmseg.datasets.transforms'])
# mangage all kinds of modules inheriting `nn.Module` # mangage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS) MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models'])
# mangage all kinds of model wrappers like 'MMDistributedDataParallel' # mangage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS) MODEL_WRAPPERS = Registry(
'model_wrapper',
parent=MMENGINE_MODEL_WRAPPERS,
locations=['mmseg.models'])
# mangage all kinds of weight initialization modules like `Uniform` # mangage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry( WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS) 'weight initializer',
parent=MMENGINE_WEIGHT_INITIALIZERS,
locations=['mmseg.models'])
# mangage all kinds of optimizers like `SGD` and `Adam` # mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS) OPTIMIZERS = Registry(
'optimizer',
parent=MMENGINE_OPTIMIZERS,
locations=['mmseg.engine.optimizers'])
# manage optimizer wrapper # manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS) OPTIM_WRAPPERS = Registry(
'optim_wrapper',
parent=MMENGINE_OPTIM_WRAPPERS,
locations=['mmseg.engine.optimizers'])
# manage constructors that customize the optimization hyperparameters. # manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry( OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer wrapper constructor', 'optimizer wrapper constructor',
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS) parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
locations=['mmseg.engine.optimizers'])
# mangage all kinds of parameter schedulers like `MultiStepLR` # mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry( PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS) 'parameter scheduler',
parent=MMENGINE_PARAM_SCHEDULERS,
locations=['mmseg.engine.schedulers'])
# manage all kinds of metrics # manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS) METRICS = Registry(
'metric', parent=MMENGINE_METRICS, locations=['mmseg.evaluation'])
# manage evaluator # manage evaluator
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR) EVALUATOR = Registry(
'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmseg.evaluation'])
# manage task-specific modules like ohem pixel sampler # manage task-specific modules like ohem pixel sampler
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS) TASK_UTILS = Registry(
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmseg.models'])
# manage visualizer # manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS) VISUALIZERS = Registry(
'visualizer',
parent=MMENGINE_VISUALIZERS,
locations=['mmseg.visualization'])
# manage visualizer backend # manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS) VISBACKENDS = Registry(
'vis_backend',
parent=MMENGINE_VISBACKENDS,
locations=['mmseg.visualization'])
# manage logprocessor # manage logprocessor
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS) LOG_PROCESSORS = Registry(
'log_processor',
parent=MMENGINE_LOG_PROCESSORS,
locations=['mmseg.visualization'])
# manage inferencer
INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS)

View File

@ -1,4 +1,4 @@
mmcls>=1.0.0rc0 mmcls>=1.0.0rc0
mmcv>=2.0.0rc4 mmcv>=2.0.0rc4
-e git+https://github.com/open-mmlab/mmdetection.git@dev-3.x#egg=mmdet -e git+https://github.com/open-mmlab/mmdetection.git@dev-3.x#egg=mmdet
mmengine>=0.2.0,<1.0.0 mmengine>=0.4.0,<1.0.0

View File

@ -1,5 +1,5 @@
mmcv>=2.0.0rc1,<2.1.0 mmcv>=2.0.0rc1,<2.1.0
mmengine>=0.1.0,<1.0.0 mmengine>=0.4.0,<1.0.0
prettytable prettytable
scipy scipy
torch torch

View File

@ -6,10 +6,10 @@ from os.path import dirname, exists, isdir, join, relpath
import numpy as np import numpy as np
from mmengine import Config from mmengine import Config
from mmengine.dataset import Compose from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from torch import nn from torch import nn
from mmseg.models import build_segmentor from mmseg.models import build_segmentor
from mmseg.utils import register_all_modules
def _get_config_directory(): def _get_config_directory():
@ -70,7 +70,7 @@ def test_config_data_pipeline():
xdoctest -m tests/test_config.py test_config_build_data_pipeline xdoctest -m tests/test_config.py test_config_build_data_pipeline
""" """
register_all_modules() init_default_scope('mmseg')
config_dpath = _get_config_directory() config_dpath = _get_config_directory()
print(f'Found config_dpath = {config_dpath!r}') print(f'Found config_dpath = {config_dpath!r}')

View File

@ -2,12 +2,12 @@
import os.path as osp import os.path as osp
from mmengine.dataset import ConcatDataset, RepeatDataset from mmengine.dataset import ConcatDataset, RepeatDataset
from mmengine.registry import init_default_scope
from mmseg.datasets import MultiImageMixDataset from mmseg.datasets import MultiImageMixDataset
from mmseg.registry import DATASETS from mmseg.registry import DATASETS
from mmseg.utils import register_all_modules
register_all_modules() init_default_scope('mmseg')
@DATASETS.register_module() @DATASETS.register_module()

View File

@ -5,6 +5,7 @@ import os.path as osp
import mmcv import mmcv
import numpy as np import numpy as np
import pytest import pytest
from mmengine.registry import init_default_scope
from PIL import Image from PIL import Image
from mmseg.datasets.transforms import * # noqa from mmseg.datasets.transforms import * # noqa
@ -12,9 +13,8 @@ from mmseg.datasets.transforms import (LoadBiomedicalData,
LoadBiomedicalImageFromFile, LoadBiomedicalImageFromFile,
PhotoMetricDistortion, RandomCrop) PhotoMetricDistortion, RandomCrop)
from mmseg.registry import TRANSFORMS from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules
register_all_modules() init_default_scope('mmseg')
def test_resize(): def test_resize():

View File

@ -5,12 +5,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.optim.optimizer import build_optim_wrapper from mmengine.optim.optimizer import build_optim_wrapper
from mmengine.registry import init_default_scope
from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \ from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor LearningRateDecayOptimizerConstructor
from mmseg.utils import register_all_modules
register_all_modules() init_default_scope('mmseg')
base_lr = 1 base_lr = 1
decay_rate = 2 decay_rate = 2

View File

@ -2,14 +2,14 @@
import pytest import pytest
import torch import torch
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmengine.registry import init_default_scope
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock) InterpConv, UNet, UpConvBlock)
from mmseg.models.utils import Upsample from mmseg.models.utils import Upsample
from mmseg.utils import register_all_modules
from .utils import check_norm_state from .utils import check_norm_state
register_all_modules() init_default_scope('mmseg')
def test_unet_basic_conv_block(): def test_unet_basic_conv_block():

View File

@ -9,14 +9,14 @@ import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.model.utils import revert_sync_batchnorm from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import PixelData from mmengine.structures import PixelData
from mmengine.utils import is_list_of, is_tuple_of from mmengine.utils import is_list_of, is_tuple_of
from torch import Tensor from torch import Tensor
from mmseg.structures import SegDataSample from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
register_all_modules() init_default_scope('mmseg')
def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5): def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):

View File

@ -3,15 +3,15 @@ from os.path import dirname, join
import torch import torch
from mmengine import Config from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.structures import PixelData from mmengine.structures import PixelData
from mmseg.registry import MODELS from mmseg.registry import MODELS
from mmseg.structures import SegDataSample from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
def test_maskformer_head(): def test_maskformer_head():
register_all_modules() init_default_scope('mmseg')
repo_dpath = dirname(dirname(__file__)) repo_dpath = dirname(dirname(__file__))
cfg = Config.fromfile( cfg = Config.fromfile(
join( join(

View File

@ -2,14 +2,14 @@
import torch import torch
from mmengine import ConfigDict from mmengine import ConfigDict
from mmengine.model import BaseTTAModel from mmengine.model import BaseTTAModel
from mmengine.registry import init_default_scope
from mmengine.structures import PixelData from mmengine.structures import PixelData
from mmseg.registry import MODELS from mmseg.registry import MODELS
from mmseg.structures import SegDataSample from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
from .utils import * # noqa: F401,F403 from .utils import * # noqa: F401,F403
register_all_modules() init_default_scope('mmseg')
def test_encoder_decoder_tta(): def test_encoder_decoder_tta():

View File

@ -8,11 +8,11 @@ import torch
from mmengine import Config from mmengine import Config
from mmengine.fileio import dump from mmengine.fileio import dump
from mmengine.model.utils import revert_sync_batchnorm from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, load_checkpoint from mmengine.runner import Runner, load_checkpoint
from mmengine.utils import mkdir_or_exist from mmengine.utils import mkdir_or_exist
from mmseg.registry import MODELS from mmseg.registry import MODELS
from mmseg.utils import register_all_modules
def parse_args(): def parse_args():
@ -32,8 +32,10 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
register_all_modules()
cfg = Config.fromfile(args.config) cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmseg'))
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
if args.work_dir is not None: if args.work_dir is not None:
mkdir_or_exist(osp.abspath(args.work_dir)) mkdir_or_exist(osp.abspath(args.work_dir))

View File

@ -3,10 +3,10 @@ import argparse
import os.path as osp import os.path as osp
from mmengine import Config, DictAction from mmengine import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar from mmengine.utils import ProgressBar
from mmseg.registry import DATASETS, VISUALIZERS from mmseg.registry import DATASETS, VISUALIZERS
from mmseg.utils import register_all_modules
def parse_args(): def parse_args():
@ -44,7 +44,7 @@ def main():
cfg.merge_from_dict(args.cfg_options) cfg.merge_from_dict(args.cfg_options)
# register all modules in mmseg into the registries # register all modules in mmseg into the registries
register_all_modules() init_default_scope('mmseg')
dataset = DATASETS.build(cfg.train_dataloader.dataset) dataset = DATASETS.build(cfg.train_dataloader.dataset)
cfg.visualizer['save_dir'] = args.output_dir cfg.visualizer['save_dir'] = args.output_dir

View File

@ -6,8 +6,6 @@ import os.path as osp
from mmengine.config import Config, DictAction from mmengine.config import Config, DictAction
from mmengine.runner import Runner from mmengine.runner import Runner
from mmseg.utils import register_all_modules
# TODO: support fuse_conv_bn, visualization, and format_only # TODO: support fuse_conv_bn, visualization, and format_only
def parse_args(): def parse_args():
@ -77,10 +75,6 @@ def trigger_visualization_hook(cfg, args):
def main(): def main():
args = parse_args() args = parse_args()
# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config # load config
cfg = Config.fromfile(args.config) cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher cfg.launcher = args.launcher

View File

@ -6,10 +6,9 @@ import os.path as osp
from mmengine.config import Config, DictAction from mmengine.config import Config, DictAction
from mmengine.logging import print_log from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner from mmengine.runner import Runner
from mmseg.utils import register_all_modules from mmseg.registry import RUNNERS
def parse_args(): def parse_args():
@ -52,10 +51,6 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config # load config
cfg = Config.fromfile(args.config) cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher cfg.launcher = args.launcher