[Feature] Support auto import modules from registry. (#2481)

## Motivation

The registry now supports auto-import modules from the given location.

register_all_modules before running is no longer needed. The modules
will be lazy-imported during building.

- [x] This PR can be merged after
https://github.com/open-mmlab/mmengine/pull/643. The MMEngine version
should be updated.

Ref: https://github.com/open-mmlab/mmdetection/pull/9143
pull/2413/head
谢昕辰 2023-02-23 20:33:17 +08:00 committed by GitHub
parent 2d38bc8554
commit 039ba5d4ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 188 additions and 113 deletions

View File

@ -1,7 +1,10 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth' # noqa
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
# TODO: delete custom_imports after mmcls supports auto import
# please install mmcls>=1.0
# import mmcls.models to trigger register_module in mmcls
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],

View File

@ -1,7 +1,5 @@
_base_ = ['../_base_/default_runtime.py', '../_base_/datasets/cityscapes.py']
custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
crop_size = (512, 1024)
data_preprocessor = dict(
type='SegDataPreProcessor',

View File

@ -3,7 +3,6 @@ _base_ = [
]
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa
custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
crop_size = (640, 640)
data_preprocessor = dict(

View File

@ -460,12 +460,8 @@
"outputs": [],
"source": [
"from mmengine.runner import Runner\n",
"from mmseg.utils import register_all_modules\n",
"\n",
"# register all modules in mmseg into the registries\n",
"# do not init the default scope here because it will be init in the runner\n",
"register_all_modules(init_default_scope=False)\n",
"runner = Runner.from_cfg(cfg)\n"
"runner = Runner.from_cfg(cfg)"
]
},
{
@ -523,7 +519,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3.8.5 ('tensorflow')",
"display_name": "Python 3.10.6 ('pt1.12')",
"language": "python",
"name": "python3"
},
@ -537,7 +533,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.10.6"
},
"pycharm": {
"stem_cell": {
@ -550,7 +546,7 @@
},
"vscode": {
"interpreter": {
"hash": "20d4b83e0c8b3730b580c42434163d64f4b735d580303a8fade7c849d4d29eba"
"hash": "0442e67aee3d9cbb788fa6e86d60c4ffa94ad7f1943c65abfecb99a6f4696c58"
}
}
},

View File

@ -4,7 +4,6 @@ from argparse import ArgumentParser
from mmengine.model import revert_sync_batchnorm
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
def main():
@ -24,8 +23,6 @@ def main():
'--title', default='result', help='The image identifier.')
args = parser.parse_args()
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':

File diff suppressed because one or more lines are too long

View File

@ -6,7 +6,6 @@ from mmengine.model.utils import revert_sync_batchnorm
from mmseg.apis import inference_model, init_model
from mmseg.apis.inference import show_result_pyplot
from mmseg.utils import register_all_modules
def main():
@ -53,8 +52,6 @@ def main():
assert args.show or args.output_file, \
'At least one output should be enabled.'
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(args.config, args.checkpoint, device=args.device)
if args.device == 'cpu':

View File

@ -15,8 +15,8 @@ Instantiate Cityscapes training dataset:
```python
from mmseg.datasets import CityscapesDataset
from mmseg.utils import register_all_modules
register_all_modules()
from mmengine.registry import init_default_scope
init_default_scope('mmseg')
data_root = 'data/cityscapes/'
data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train')

View File

@ -91,10 +91,8 @@ Option (b). If you install mmsegmentation with pip, open you python interpreter
```python
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
import mmcv
register_all_modules()
config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

View File

@ -31,14 +31,10 @@ Example:
```python
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
# register all modules in mmseg into the registries
register_all_modules()
# initialize model without checkpoint
model = init_model(config_path)
@ -76,14 +72,11 @@ Example:
```python
from mmseg.apis import init_model, inference_model
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
img_path = 'demo/demo.png'
# register all modules in mmseg into the registries
register_all_modules()
model = init_model(config_path, checkpoint_path)
result = inference_model(model, img_path)
@ -115,14 +108,11 @@ Example:
```python
from mmseg.apis import init_model, inference_model, show_result_pyplot
from mmseg.utils import register_all_modules
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
img_path = 'demo/demo.png'
# register all modules in mmseg into the registries
register_all_modules()
# build the model from a config file and a checkpoint file
model = init_model(config_path, checkpoint_path, device='cuda:0')

View File

@ -9,9 +9,10 @@
实例化 Cityscapes 训练数据集:
```python
from mmengine.registry import init_default_scope
from mmseg.datasets import CityscapesDataset
from mmseg.utils import register_all_modules
register_all_modules()
init_default_scope('mmseg')
data_root = 'data/cityscapes/'
data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train')

View File

@ -92,10 +92,8 @@ python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_4xb2-40k_ci
```python
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
import mmcv
register_all_modules()
config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

View File

@ -9,7 +9,7 @@ from .version import __version__, version_info
MMCV_MIN = '2.0.0rc4'
MMCV_MAX = '2.1.0'
MMENGINE_MIN = '0.2.0'
MMENGINE_MIN = '0.4.0'
MMENGINE_MAX = '1.0.0'

View File

@ -9,6 +9,7 @@ import numpy as np
import torch
from mmengine import Config
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint
from mmengine.utils import mkdir_or_exist
@ -48,6 +49,8 @@ def init_model(config: Union[str, Path, Config],
config.model.backbone.init_cfg = None
config.model.pretrained = None
config.model.train_cfg = None
init_default_scope(config.get('default_scope', 'mmseg'))
model = MODELS.build(config.model)
if checkpoint is not None:
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')

View File

@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
VISUALIZERS, WEIGHT_INITIALIZERS)
from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS,
LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
WEIGHT_INITIALIZERS)
__all__ = [
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS',
'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS',
'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS',
'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS',
'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS'
]

View File

@ -10,6 +10,7 @@ from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
from mmengine.registry import HOOKS as MMENGINE_HOOKS
from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS
from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import METRICS as MMENGINE_METRICS
@ -39,45 +40,82 @@ RUNNER_CONSTRUCTORS = Registry(
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
HOOKS = Registry(
'hook', parent=MMENGINE_HOOKS, locations=['mmseg.engine.hooks'])
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
DATASETS = Registry(
'dataset', parent=MMENGINE_DATASETS, locations=['mmseg.datasets'])
DATA_SAMPLERS = Registry(
'data sampler',
parent=MMENGINE_DATA_SAMPLERS,
locations=['mmseg.datasets.samplers'])
TRANSFORMS = Registry(
'transform',
parent=MMENGINE_TRANSFORMS,
locations=['mmseg.datasets.transforms'])
# mangage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models'])
# mangage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
MODEL_WRAPPERS = Registry(
'model_wrapper',
parent=MMENGINE_MODEL_WRAPPERS,
locations=['mmseg.models'])
# mangage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
'weight initializer',
parent=MMENGINE_WEIGHT_INITIALIZERS,
locations=['mmseg.models'])
# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
OPTIMIZERS = Registry(
'optimizer',
parent=MMENGINE_OPTIMIZERS,
locations=['mmseg.engine.optimizers'])
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
OPTIM_WRAPPERS = Registry(
'optim_wrapper',
parent=MMENGINE_OPTIM_WRAPPERS,
locations=['mmseg.engine.optimizers'])
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer wrapper constructor',
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
locations=['mmseg.engine.optimizers'])
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
'parameter scheduler',
parent=MMENGINE_PARAM_SCHEDULERS,
locations=['mmseg.engine.schedulers'])
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
METRICS = Registry(
'metric', parent=MMENGINE_METRICS, locations=['mmseg.evaluation'])
# manage evaluator
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
EVALUATOR = Registry(
'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmseg.evaluation'])
# manage task-specific modules like ohem pixel sampler
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
TASK_UTILS = Registry(
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmseg.models'])
# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
VISUALIZERS = Registry(
'visualizer',
parent=MMENGINE_VISUALIZERS,
locations=['mmseg.visualization'])
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
VISBACKENDS = Registry(
'vis_backend',
parent=MMENGINE_VISBACKENDS,
locations=['mmseg.visualization'])
# manage logprocessor
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
LOG_PROCESSORS = Registry(
'log_processor',
parent=MMENGINE_LOG_PROCESSORS,
locations=['mmseg.visualization'])
# manage inferencer
INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS)

View File

@ -1,4 +1,4 @@
mmcls>=1.0.0rc0
mmcv>=2.0.0rc4
-e git+https://github.com/open-mmlab/mmdetection.git@dev-3.x#egg=mmdet
mmengine>=0.2.0,<1.0.0
mmengine>=0.4.0,<1.0.0

View File

@ -1,5 +1,5 @@
mmcv>=2.0.0rc1,<2.1.0
mmengine>=0.1.0,<1.0.0
mmengine>=0.4.0,<1.0.0
prettytable
scipy
torch

View File

@ -6,10 +6,10 @@ from os.path import dirname, exists, isdir, join, relpath
import numpy as np
from mmengine import Config
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from torch import nn
from mmseg.models import build_segmentor
from mmseg.utils import register_all_modules
def _get_config_directory():
@ -70,7 +70,7 @@ def test_config_data_pipeline():
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
register_all_modules()
init_default_scope('mmseg')
config_dpath = _get_config_directory()
print(f'Found config_dpath = {config_dpath!r}')

View File

@ -2,12 +2,12 @@
import os.path as osp
from mmengine.dataset import ConcatDataset, RepeatDataset
from mmengine.registry import init_default_scope
from mmseg.datasets import MultiImageMixDataset
from mmseg.registry import DATASETS
from mmseg.utils import register_all_modules
register_all_modules()
init_default_scope('mmseg')
@DATASETS.register_module()

View File

@ -5,6 +5,7 @@ import os.path as osp
import mmcv
import numpy as np
import pytest
from mmengine.registry import init_default_scope
from PIL import Image
from mmseg.datasets.transforms import * # noqa
@ -12,9 +13,8 @@ from mmseg.datasets.transforms import (LoadBiomedicalData,
LoadBiomedicalImageFromFile,
PhotoMetricDistortion, RandomCrop)
from mmseg.registry import TRANSFORMS
from mmseg.utils import register_all_modules
register_all_modules()
init_default_scope('mmseg')
def test_resize():

View File

@ -5,12 +5,12 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.optim.optimizer import build_optim_wrapper
from mmengine.registry import init_default_scope
from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \
LearningRateDecayOptimizerConstructor
from mmseg.utils import register_all_modules
register_all_modules()
init_default_scope('mmseg')
base_lr = 1
decay_rate = 2

View File

@ -2,14 +2,14 @@
import pytest
import torch
from mmcv.cnn import ConvModule
from mmengine.registry import init_default_scope
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock)
from mmseg.models.utils import Upsample
from mmseg.utils import register_all_modules
from .utils import check_norm_state
register_all_modules()
init_default_scope('mmseg')
def test_unet_basic_conv_block():

View File

@ -9,14 +9,14 @@ import pytest
import torch
import torch.nn as nn
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import PixelData
from mmengine.utils import is_list_of, is_tuple_of
from torch import Tensor
from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
register_all_modules()
init_default_scope('mmseg')
def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):

View File

@ -3,15 +3,15 @@ from os.path import dirname, join
import torch
from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
def test_maskformer_head():
register_all_modules()
init_default_scope('mmseg')
repo_dpath = dirname(dirname(__file__))
cfg = Config.fromfile(
join(

View File

@ -2,14 +2,14 @@
import torch
from mmengine import ConfigDict
from mmengine.model import BaseTTAModel
from mmengine.registry import init_default_scope
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import register_all_modules
from .utils import * # noqa: F401,F403
register_all_modules()
init_default_scope('mmseg')
def test_encoder_decoder_tta():

View File

@ -8,11 +8,11 @@ import torch
from mmengine import Config
from mmengine.fileio import dump
from mmengine.model.utils import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.runner import Runner, load_checkpoint
from mmengine.utils import mkdir_or_exist
from mmseg.registry import MODELS
from mmseg.utils import register_all_modules
def parse_args():
@ -32,8 +32,10 @@ def parse_args():
def main():
args = parse_args()
register_all_modules()
cfg = Config.fromfile(args.config)
init_default_scope(cfg.get('default_scope', 'mmseg'))
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
if args.work_dir is not None:
mkdir_or_exist(osp.abspath(args.work_dir))

View File

@ -3,10 +3,10 @@ import argparse
import os.path as osp
from mmengine import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar
from mmseg.registry import DATASETS, VISUALIZERS
from mmseg.utils import register_all_modules
def parse_args():
@ -44,7 +44,7 @@ def main():
cfg.merge_from_dict(args.cfg_options)
# register all modules in mmseg into the registries
register_all_modules()
init_default_scope('mmseg')
dataset = DATASETS.build(cfg.train_dataloader.dataset)
cfg.visualizer['save_dir'] = args.output_dir

View File

@ -6,8 +6,6 @@ import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmseg.utils import register_all_modules
# TODO: support fuse_conv_bn, visualization, and format_only
def parse_args():
@ -77,10 +75,6 @@ def trigger_visualization_hook(cfg, args):
def main():
args = parse_args()
# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher

View File

@ -6,10 +6,9 @@ import os.path as osp
from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.registry import RUNNERS
from mmengine.runner import Runner
from mmseg.utils import register_all_modules
from mmseg.registry import RUNNERS
def parse_args():
@ -52,10 +51,6 @@ def parse_args():
def main():
args = parse_args()
# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
# load config
cfg = Config.fromfile(args.config)
cfg.launcher = args.launcher