[Feature] Support auto import modules from registry. (#2481)
## Motivation The registry now supports auto-import modules from the given location. register_all_modules before running is no longer needed. The modules will be lazy-imported during building. - [x] This PR can be merged after https://github.com/open-mmlab/mmengine/pull/643. The MMEngine version should be updated. Ref: https://github.com/open-mmlab/mmdetection/pull/9143pull/2413/head
parent
2d38bc8554
commit
039ba5d4ca
|
@ -1,7 +1,10 @@
|
|||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/poolformer/poolformer-s12_3rdparty_32xb128_in1k_20220414-f8d83051.pth' # noqa
|
||||
custom_imports = dict(imports='mmcls.models', allow_failed_imports=False)
|
||||
# TODO: delete custom_imports after mmcls supports auto import
|
||||
# please install mmcls>=1.0
|
||||
# import mmcls.models to trigger register_module in mmcls
|
||||
custom_imports = dict(imports=['mmcls.models'], allow_failed_imports=False)
|
||||
data_preprocessor = dict(
|
||||
type='SegDataPreProcessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
_base_ = ['../_base_/default_runtime.py', '../_base_/datasets/cityscapes.py']
|
||||
|
||||
custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
|
||||
|
||||
crop_size = (512, 1024)
|
||||
data_preprocessor = dict(
|
||||
type='SegDataPreProcessor',
|
||||
|
|
|
@ -3,7 +3,6 @@ _base_ = [
|
|||
]
|
||||
|
||||
pretrained = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window12_384_20220317-55b0104a.pth' # noqa
|
||||
custom_imports = dict(imports='mmdet.models', allow_failed_imports=False)
|
||||
|
||||
crop_size = (640, 640)
|
||||
data_preprocessor = dict(
|
||||
|
|
|
@ -460,12 +460,8 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"from mmengine.runner import Runner\n",
|
||||
"from mmseg.utils import register_all_modules\n",
|
||||
"\n",
|
||||
"# register all modules in mmseg into the registries\n",
|
||||
"# do not init the default scope here because it will be init in the runner\n",
|
||||
"register_all_modules(init_default_scope=False)\n",
|
||||
"runner = Runner.from_cfg(cfg)\n"
|
||||
"runner = Runner.from_cfg(cfg)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -523,7 +519,7 @@
|
|||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.8.5 ('tensorflow')",
|
||||
"display_name": "Python 3.10.6 ('pt1.12')",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -537,7 +533,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
"version": "3.10.6"
|
||||
},
|
||||
"pycharm": {
|
||||
"stem_cell": {
|
||||
|
@ -550,7 +546,7 @@
|
|||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
"hash": "20d4b83e0c8b3730b580c42434163d64f4b735d580303a8fade7c849d4d29eba"
|
||||
"hash": "0442e67aee3d9cbb788fa6e86d60c4ffa94ad7f1943c65abfecb99a6f4696c58"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
@ -4,7 +4,6 @@ from argparse import ArgumentParser
|
|||
from mmengine.model import revert_sync_batchnorm
|
||||
|
||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -24,8 +23,6 @@ def main():
|
|||
'--title', default='result', help='The image identifier.')
|
||||
args = parser.parse_args()
|
||||
|
||||
register_all_modules()
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||
if args.device == 'cpu':
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -6,7 +6,6 @@ from mmengine.model.utils import revert_sync_batchnorm
|
|||
|
||||
from mmseg.apis import inference_model, init_model
|
||||
from mmseg.apis.inference import show_result_pyplot
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -53,8 +52,6 @@ def main():
|
|||
assert args.show or args.output_file, \
|
||||
'At least one output should be enabled.'
|
||||
|
||||
register_all_modules()
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(args.config, args.checkpoint, device=args.device)
|
||||
if args.device == 'cpu':
|
||||
|
|
|
@ -15,8 +15,8 @@ Instantiate Cityscapes training dataset:
|
|||
|
||||
```python
|
||||
from mmseg.datasets import CityscapesDataset
|
||||
from mmseg.utils import register_all_modules
|
||||
register_all_modules()
|
||||
from mmengine.registry import init_default_scope
|
||||
init_default_scope('mmseg')
|
||||
|
||||
data_root = 'data/cityscapes/'
|
||||
data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train')
|
||||
|
|
|
@ -91,10 +91,8 @@ Option (b). If you install mmsegmentation with pip, open you python interpreter
|
|||
|
||||
```python
|
||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||
from mmseg.utils import register_all_modules
|
||||
import mmcv
|
||||
|
||||
register_all_modules()
|
||||
config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
|
||||
checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
|
||||
|
||||
|
|
|
@ -31,14 +31,10 @@ Example:
|
|||
|
||||
```python
|
||||
from mmseg.apis import init_model
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
|
||||
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
|
||||
|
||||
# register all modules in mmseg into the registries
|
||||
register_all_modules()
|
||||
|
||||
# initialize model without checkpoint
|
||||
model = init_model(config_path)
|
||||
|
||||
|
@ -76,14 +72,11 @@ Example:
|
|||
|
||||
```python
|
||||
from mmseg.apis import init_model, inference_model
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
|
||||
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
|
||||
img_path = 'demo/demo.png'
|
||||
|
||||
# register all modules in mmseg into the registries
|
||||
register_all_modules()
|
||||
|
||||
model = init_model(config_path, checkpoint_path)
|
||||
result = inference_model(model, img_path)
|
||||
|
@ -115,14 +108,11 @@ Example:
|
|||
|
||||
```python
|
||||
from mmseg.apis import init_model, inference_model, show_result_pyplot
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
config_path = 'configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
|
||||
checkpoint_path = 'checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
|
||||
img_path = 'demo/demo.png'
|
||||
|
||||
# register all modules in mmseg into the registries
|
||||
register_all_modules()
|
||||
|
||||
# build the model from a config file and a checkpoint file
|
||||
model = init_model(config_path, checkpoint_path, device='cuda:0')
|
||||
|
|
|
@ -9,9 +9,10 @@
|
|||
实例化 Cityscapes 训练数据集:
|
||||
|
||||
```python
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmseg.datasets import CityscapesDataset
|
||||
from mmseg.utils import register_all_modules
|
||||
register_all_modules()
|
||||
|
||||
init_default_scope('mmseg')
|
||||
|
||||
data_root = 'data/cityscapes/'
|
||||
data_prefix=dict(img_path='leftImg8bit/train', seg_map_path='gtFine/train')
|
||||
|
|
|
@ -92,10 +92,8 @@ python demo/image_demo.py demo/demo.png configs/pspnet/pspnet_r50-d8_4xb2-40k_ci
|
|||
|
||||
```python
|
||||
from mmseg.apis import inference_model, init_model, show_result_pyplot
|
||||
from mmseg.utils import register_all_modules
|
||||
import mmcv
|
||||
|
||||
register_all_modules()
|
||||
config_file = 'pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
|
||||
checkpoint_file = 'pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from .version import __version__, version_info
|
|||
|
||||
MMCV_MIN = '2.0.0rc4'
|
||||
MMCV_MAX = '2.1.0'
|
||||
MMENGINE_MIN = '0.2.0'
|
||||
MMENGINE_MIN = '0.4.0'
|
||||
MMENGINE_MAX = '1.0.0'
|
||||
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import numpy as np
|
|||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import load_checkpoint
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
|
@ -48,6 +49,8 @@ def init_model(config: Union[str, Path, Config],
|
|||
config.model.backbone.init_cfg = None
|
||||
config.model.pretrained = None
|
||||
config.model.train_cfg = None
|
||||
init_default_scope(config.get('default_scope', 'mmseg'))
|
||||
|
||||
model = MODELS.build(config.model)
|
||||
if checkpoint is not None:
|
||||
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
|
||||
MODEL_WRAPPERS, MODELS, OPTIM_WRAPPER_CONSTRUCTORS,
|
||||
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
|
||||
RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
|
||||
VISUALIZERS, WEIGHT_INITIALIZERS)
|
||||
from .registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS, INFERENCERS,
|
||||
LOG_PROCESSORS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS,
|
||||
OPTIM_WRAPPER_CONSTRUCTORS, OPTIM_WRAPPERS, OPTIMIZERS,
|
||||
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS,
|
||||
TASK_UTILS, TRANSFORMS, VISBACKENDS, VISUALIZERS,
|
||||
WEIGHT_INITIALIZERS)
|
||||
|
||||
__all__ = [
|
||||
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
|
||||
'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
|
||||
'OPTIM_WRAPPER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
|
||||
'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
|
||||
'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS',
|
||||
'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIM_WRAPPER_CONSTRUCTORS',
|
||||
'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', 'MODEL_WRAPPERS',
|
||||
'VISBACKENDS', 'VISUALIZERS', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'LOOPS',
|
||||
'EVALUATOR', 'LOG_PROCESSORS', 'OPTIM_WRAPPERS', 'INFERENCERS'
|
||||
]
|
||||
|
|
|
@ -10,6 +10,7 @@ from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
|
|||
from mmengine.registry import DATASETS as MMENGINE_DATASETS
|
||||
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
|
||||
from mmengine.registry import HOOKS as MMENGINE_HOOKS
|
||||
from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS
|
||||
from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
|
||||
from mmengine.registry import LOOPS as MMENGINE_LOOPS
|
||||
from mmengine.registry import METRICS as MMENGINE_METRICS
|
||||
|
@ -39,45 +40,82 @@ RUNNER_CONSTRUCTORS = Registry(
|
|||
# manage all kinds of loops like `EpochBasedTrainLoop`
|
||||
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
|
||||
# manage all kinds of hooks like `CheckpointHook`
|
||||
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
|
||||
HOOKS = Registry(
|
||||
'hook', parent=MMENGINE_HOOKS, locations=['mmseg.engine.hooks'])
|
||||
|
||||
# manage data-related modules
|
||||
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
|
||||
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
|
||||
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
|
||||
DATASETS = Registry(
|
||||
'dataset', parent=MMENGINE_DATASETS, locations=['mmseg.datasets'])
|
||||
DATA_SAMPLERS = Registry(
|
||||
'data sampler',
|
||||
parent=MMENGINE_DATA_SAMPLERS,
|
||||
locations=['mmseg.datasets.samplers'])
|
||||
TRANSFORMS = Registry(
|
||||
'transform',
|
||||
parent=MMENGINE_TRANSFORMS,
|
||||
locations=['mmseg.datasets.transforms'])
|
||||
|
||||
# mangage all kinds of modules inheriting `nn.Module`
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS)
|
||||
MODELS = Registry('model', parent=MMENGINE_MODELS, locations=['mmseg.models'])
|
||||
# mangage all kinds of model wrappers like 'MMDistributedDataParallel'
|
||||
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
|
||||
MODEL_WRAPPERS = Registry(
|
||||
'model_wrapper',
|
||||
parent=MMENGINE_MODEL_WRAPPERS,
|
||||
locations=['mmseg.models'])
|
||||
# mangage all kinds of weight initialization modules like `Uniform`
|
||||
WEIGHT_INITIALIZERS = Registry(
|
||||
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
|
||||
'weight initializer',
|
||||
parent=MMENGINE_WEIGHT_INITIALIZERS,
|
||||
locations=['mmseg.models'])
|
||||
|
||||
# mangage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
|
||||
OPTIMIZERS = Registry(
|
||||
'optimizer',
|
||||
parent=MMENGINE_OPTIMIZERS,
|
||||
locations=['mmseg.engine.optimizers'])
|
||||
# manage optimizer wrapper
|
||||
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
|
||||
OPTIM_WRAPPERS = Registry(
|
||||
'optim_wrapper',
|
||||
parent=MMENGINE_OPTIM_WRAPPERS,
|
||||
locations=['mmseg.engine.optimizers'])
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
|
||||
'optimizer wrapper constructor',
|
||||
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
|
||||
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS,
|
||||
locations=['mmseg.engine.optimizers'])
|
||||
# mangage all kinds of parameter schedulers like `MultiStepLR`
|
||||
PARAM_SCHEDULERS = Registry(
|
||||
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
|
||||
'parameter scheduler',
|
||||
parent=MMENGINE_PARAM_SCHEDULERS,
|
||||
locations=['mmseg.engine.schedulers'])
|
||||
|
||||
# manage all kinds of metrics
|
||||
METRICS = Registry('metric', parent=MMENGINE_METRICS)
|
||||
METRICS = Registry(
|
||||
'metric', parent=MMENGINE_METRICS, locations=['mmseg.evaluation'])
|
||||
# manage evaluator
|
||||
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
|
||||
EVALUATOR = Registry(
|
||||
'evaluator', parent=MMENGINE_EVALUATOR, locations=['mmseg.evaluation'])
|
||||
|
||||
# manage task-specific modules like ohem pixel sampler
|
||||
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
|
||||
TASK_UTILS = Registry(
|
||||
'task util', parent=MMENGINE_TASK_UTILS, locations=['mmseg.models'])
|
||||
|
||||
# manage visualizer
|
||||
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
|
||||
VISUALIZERS = Registry(
|
||||
'visualizer',
|
||||
parent=MMENGINE_VISUALIZERS,
|
||||
locations=['mmseg.visualization'])
|
||||
# manage visualizer backend
|
||||
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
|
||||
VISBACKENDS = Registry(
|
||||
'vis_backend',
|
||||
parent=MMENGINE_VISBACKENDS,
|
||||
locations=['mmseg.visualization'])
|
||||
|
||||
# manage logprocessor
|
||||
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
|
||||
LOG_PROCESSORS = Registry(
|
||||
'log_processor',
|
||||
parent=MMENGINE_LOG_PROCESSORS,
|
||||
locations=['mmseg.visualization'])
|
||||
|
||||
# manage inferencer
|
||||
INFERENCERS = Registry('inferencer', parent=MMENGINE_INFERENCERS)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
mmcls>=1.0.0rc0
|
||||
mmcv>=2.0.0rc4
|
||||
-e git+https://github.com/open-mmlab/mmdetection.git@dev-3.x#egg=mmdet
|
||||
mmengine>=0.2.0,<1.0.0
|
||||
mmengine>=0.4.0,<1.0.0
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
mmcv>=2.0.0rc1,<2.1.0
|
||||
mmengine>=0.1.0,<1.0.0
|
||||
mmengine>=0.4.0,<1.0.0
|
||||
prettytable
|
||||
scipy
|
||||
torch
|
||||
|
|
|
@ -6,10 +6,10 @@ from os.path import dirname, exists, isdir, join, relpath
|
|||
import numpy as np
|
||||
from mmengine import Config
|
||||
from mmengine.dataset import Compose
|
||||
from mmengine.registry import init_default_scope
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
|
@ -70,7 +70,7 @@ def test_config_data_pipeline():
|
|||
xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
"""
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
config_dpath = _get_config_directory()
|
||||
print(f'Found config_dpath = {config_dpath!r}')
|
||||
|
||||
|
|
|
@ -2,12 +2,12 @@
|
|||
import os.path as osp
|
||||
|
||||
from mmengine.dataset import ConcatDataset, RepeatDataset
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.datasets import MultiImageMixDataset
|
||||
from mmseg.registry import DATASETS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
|
|
@ -5,6 +5,7 @@ import os.path as osp
|
|||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mmengine.registry import init_default_scope
|
||||
from PIL import Image
|
||||
|
||||
from mmseg.datasets.transforms import * # noqa
|
||||
|
@ -12,9 +13,8 @@ from mmseg.datasets.transforms import (LoadBiomedicalData,
|
|||
LoadBiomedicalImageFromFile,
|
||||
PhotoMetricDistortion, RandomCrop)
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def test_resize():
|
||||
|
|
|
@ -5,12 +5,12 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.optim.optimizer import build_optim_wrapper
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.engine.optimizers.layer_decay_optimizer_constructor import \
|
||||
LearningRateDecayOptimizerConstructor
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
|
||||
base_lr = 1
|
||||
decay_rate = 2
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
import pytest
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.registry import init_default_scope
|
||||
|
||||
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
||||
InterpConv, UNet, UpConvBlock)
|
||||
from mmseg.models.utils import Upsample
|
||||
from mmseg.utils import register_all_modules
|
||||
from .utils import check_norm_state
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def test_unet_basic_conv_block():
|
||||
|
|
|
@ -9,14 +9,14 @@ import pytest
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import PixelData
|
||||
from mmengine.utils import is_list_of, is_tuple_of
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def _demo_mm_inputs(batch_size=2, image_shapes=(3, 32, 32), num_classes=5):
|
||||
|
|
|
@ -3,15 +3,15 @@ from os.path import dirname, join
|
|||
|
||||
import torch
|
||||
from mmengine import Config
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def test_maskformer_head():
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
repo_dpath = dirname(dirname(__file__))
|
||||
cfg = Config.fromfile(
|
||||
join(
|
||||
|
|
|
@ -2,14 +2,14 @@
|
|||
import torch
|
||||
from mmengine import ConfigDict
|
||||
from mmengine.model import BaseTTAModel
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.structures import PixelData
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import register_all_modules
|
||||
from .utils import * # noqa: F401,F403
|
||||
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
|
||||
|
||||
def test_encoder_decoder_tta():
|
||||
|
|
|
@ -8,11 +8,11 @@ import torch
|
|||
from mmengine import Config
|
||||
from mmengine.fileio import dump
|
||||
from mmengine.model.utils import revert_sync_batchnorm
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.runner import Runner, load_checkpoint
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -32,8 +32,10 @@ def parse_args():
|
|||
|
||||
def main():
|
||||
args = parse_args()
|
||||
register_all_modules()
|
||||
cfg = Config.fromfile(args.config)
|
||||
|
||||
init_default_scope(cfg.get('default_scope', 'mmseg'))
|
||||
|
||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
if args.work_dir is not None:
|
||||
mkdir_or_exist(osp.abspath(args.work_dir))
|
||||
|
|
|
@ -3,10 +3,10 @@ import argparse
|
|||
import os.path as osp
|
||||
|
||||
from mmengine import Config, DictAction
|
||||
from mmengine.registry import init_default_scope
|
||||
from mmengine.utils import ProgressBar
|
||||
|
||||
from mmseg.registry import DATASETS, VISUALIZERS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -44,7 +44,7 @@ def main():
|
|||
cfg.merge_from_dict(args.cfg_options)
|
||||
|
||||
# register all modules in mmseg into the registries
|
||||
register_all_modules()
|
||||
init_default_scope('mmseg')
|
||||
|
||||
dataset = DATASETS.build(cfg.train_dataloader.dataset)
|
||||
cfg.visualizer['save_dir'] = args.output_dir
|
||||
|
|
|
@ -6,8 +6,6 @@ import os.path as osp
|
|||
from mmengine.config import Config, DictAction
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
# TODO: support fuse_conv_bn, visualization, and format_only
|
||||
def parse_args():
|
||||
|
@ -77,10 +75,6 @@ def trigger_visualization_hook(cfg, args):
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmseg into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
|
|
|
@ -6,10 +6,9 @@ import os.path as osp
|
|||
|
||||
from mmengine.config import Config, DictAction
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import RUNNERS
|
||||
from mmengine.runner import Runner
|
||||
|
||||
from mmseg.utils import register_all_modules
|
||||
from mmseg.registry import RUNNERS
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -52,10 +51,6 @@ def parse_args():
|
|||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# register all modules in mmseg into the registries
|
||||
# do not init the default scope here because it will be init in the runner
|
||||
register_all_modules(init_default_scope=False)
|
||||
|
||||
# load config
|
||||
cfg = Config.fromfile(args.config)
|
||||
cfg.launcher = args.launcher
|
||||
|
|
Loading…
Reference in New Issue