From 7a32d610e41bec47a6955186331699fa853267c7 Mon Sep 17 00:00:00 2001
From: "linfangjian.vendor"
Date: Tue, 10 May 2022 12:15:20 +0000
Subject: [PATCH] [Refactor] Refactor all registries

---
 mmseg/core/__init__.py | 7 +-
 mmseg/core/builder.py | 12 +---
 .../layer_decay_optimizer_constructor.py | 9 +--
 mmseg/core/seg/builder.py | 11 ++-
 mmseg/datasets/ade.py | 2 +-
 mmseg/datasets/builder.py | 8 +--
 mmseg/datasets/chase_db1.py | 2 +-
 mmseg/datasets/cityscapes.py | 2 +-
 mmseg/datasets/coco_stuff.py | 2 +-
 mmseg/datasets/custom.py | 2 +-
 mmseg/datasets/dark_zurich.py | 2 +-
 mmseg/datasets/dataset_wrappers.py | 6 +-
 mmseg/datasets/drive.py | 2 +-
 mmseg/datasets/hrf.py | 2 +-
 mmseg/datasets/isaid.py | 2 +-
 mmseg/datasets/isprs.py | 2 +-
 mmseg/datasets/loveda.py | 2 +-
 mmseg/datasets/night_driving.py | 2 +-
 mmseg/datasets/pascal_context.py | 2 +-
 mmseg/datasets/pipelines/compose.py | 8 +--
 mmseg/datasets/pipelines/formatting.py | 14 ++--
 mmseg/datasets/pipelines/loading.py | 6 +-
 mmseg/datasets/pipelines/test_time_aug.py | 4 +-
 mmseg/datasets/pipelines/transforms.py | 32 ++++-----
 mmseg/datasets/potsdam.py | 2 +-
 .../datasets/samplers/distributed_sampler.py | 2 +
 mmseg/datasets/stare.py | 2 +-
 mmseg/datasets/voc.py | 2 +-
 mmseg/models/backbones/beit.py | 4 +-
 mmseg/models/backbones/bisenetv1.py | 6 +-
 mmseg/models/backbones/bisenetv2.py | 4 +-
 mmseg/models/backbones/cgnet.py | 4 +-
 mmseg/models/backbones/erfnet.py | 4 +-
 mmseg/models/backbones/fast_scnn.py | 4 +-
 mmseg/models/backbones/hrnet.py | 4 +-
 mmseg/models/backbones/icnet.py | 6 +-
 mmseg/models/backbones/mae.py | 4 +-
 mmseg/models/backbones/mit.py | 4 +-
 mmseg/models/backbones/mobilenet_v2.py | 4 +-
 mmseg/models/backbones/mobilenet_v3.py | 4 +-
 mmseg/models/backbones/resnest.py | 4 +-
 mmseg/models/backbones/resnet.py | 8 +--
 mmseg/models/backbones/resnext.py | 4 +-
 mmseg/models/backbones/stdc.py | 8 +--
 mmseg/models/backbones/swin.py | 4 +-
 mmseg/models/backbones/timm_backbone.py | 4 +-
 mmseg/models/backbones/twins.py | 6 +-
 mmseg/models/backbones/unet.py | 4 +-
 mmseg/models/backbones/vit.py | 4 +-
 mmseg/models/builder.py | 15 ++--
 mmseg/models/decode_heads/ann_head.py | 4 +-
 mmseg/models/decode_heads/apc_head.py | 4 +-
 mmseg/models/decode_heads/aspp_head.py | 4 +-
 mmseg/models/decode_heads/cc_head.py | 4 +-
 mmseg/models/decode_heads/da_head.py | 4 +-
 mmseg/models/decode_heads/dm_head.py | 4 +-
 mmseg/models/decode_heads/dnl_head.py | 4 +-
 mmseg/models/decode_heads/dpt_head.py | 4 +-
 mmseg/models/decode_heads/ema_head.py | 4 +-
 mmseg/models/decode_heads/enc_head.py | 5 +-
 mmseg/models/decode_heads/fcn_head.py | 4 +-
 mmseg/models/decode_heads/fpn_head.py | 4 +-
 mmseg/models/decode_heads/gc_head.py | 4 +-
 mmseg/models/decode_heads/isa_head.py | 4 +-
 mmseg/models/decode_heads/knet_head.py | 10 +--
 mmseg/models/decode_heads/lraspp_head.py | 4 +-
 mmseg/models/decode_heads/nl_head.py | 4 +-
 mmseg/models/decode_heads/ocr_head.py | 4 +-
 mmseg/models/decode_heads/point_head.py | 4 +-
 mmseg/models/decode_heads/psa_head.py | 4 +-
 mmseg/models/decode_heads/psp_head.py | 4 +-
 mmseg/models/decode_heads/segformer_head.py | 4 +-
 .../decode_heads/segmenter_mask_head.py | 4 +-
 mmseg/models/decode_heads/sep_aspp_head.py | 4 +-
 mmseg/models/decode_heads/sep_fcn_head.py | 4 +-
 mmseg/models/decode_heads/setr_mla_head.py | 4 +-
 mmseg/models/decode_heads/setr_up_head.py | 4 +-
 mmseg/models/decode_heads/stdc_head.py | 4 +-
 mmseg/models/decode_heads/uper_head.py | 4 +-
 mmseg/models/losses/cross_entropy_loss.py | 4 +-
 mmseg/models/losses/dice_loss.py | 4 +-
 mmseg/models/losses/focal_loss.py | 4 +-
 mmseg/models/losses/lovasz_loss.py | 4 +-
 mmseg/models/necks/featurepyramid.py | 4 +-
 mmseg/models/necks/fpn.py | 4 +-
 mmseg/models/necks/ic_neck.py | 4 +-
 mmseg/models/necks/jpu.py | 4 +-
 mmseg/models/necks/mla_neck.py | 4 +-
 mmseg/models/necks/multilevel_neck.py | 4 +-
 .../segmentors/cascade_encoder_decoder.py | 7 +-
 mmseg/models/segmentors/encoder_decoder.py | 15 ++--
 mmseg/registry/__init__.py | 13 ++++
 mmseg/registry/registry.py | 71 +++++++++++++++++++
 tests/test_core/test_optimizer.py | 8 +--
 94 files changed, 312 insertions(+), 229 deletions(-)
 create mode 100644 mmseg/registry/__init__.py
 create mode 100644 mmseg/registry/registry.py

diff --git a/mmseg/core/__init__.py b/mmseg/core/__init__.py
index 1a077d2f1..0f2fcf13c 100644
--- a/mmseg/core/__init__.py
+++ b/mmseg/core/__init__.py
@@ -1,11 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
-                      build_optimizer_constructor)
+from .builder import build_optimizer, build_optimizer_constructor
 from .evaluation import *  # noqa: F401, F403
 from .optimizers import *  # noqa: F401, F403
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403
 
-__all__ = [
-    'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
-]
+__all__ = ['build_optimizer', 'build_optimizer_constructor']
diff --git a/mmseg/core/builder.py b/mmseg/core/builder.py
index 406dd9b4b..b155a4166 100644
--- a/mmseg/core/builder.py
+++ b/mmseg/core/builder.py
@@ -1,19 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 
-from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
-from mmcv.utils import Registry, build_from_cfg
-
-OPTIMIZER_BUILDERS = Registry(
-    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
+from mmseg.registry import OPTIMIZER_CONSTRUCTORS
 
 
 def build_optimizer_constructor(cfg):
     constructor_type = cfg.get('type')
-    if constructor_type in OPTIMIZER_BUILDERS:
-        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
-    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
-        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
+    if constructor_type in OPTIMIZER_CONSTRUCTORS:
+        return OPTIMIZER_CONSTRUCTORS.build(cfg)
     else:
         raise KeyError(f'{constructor_type} is not registered '
                        'in the optimizer builder registry.')
diff --git a/mmseg/core/optimizers/layer_decay_optimizer_constructor.py b/mmseg/core/optimizers/layer_decay_optimizer_constructor.py
index 2b6b8ff9c..6b86944f7 100644
--- a/mmseg/core/optimizers/layer_decay_optimizer_constructor.py
+++ b/mmseg/core/optimizers/layer_decay_optimizer_constructor.py
@@ -2,10 +2,11 @@
 import json
 import warnings
 
-from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
+from mmengine.dist import get_dist_info
+from mmengine.optim import DefaultOptimizerConstructor
 
+from mmseg.registry import OPTIMIZER_CONSTRUCTORS
 from mmseg.utils import get_root_logger
-from ..builder import OPTIMIZER_BUILDERS
 
 
 def get_layer_id_for_convnext(var_name, max_layer_id):
@@ -99,7 +100,7 @@ def get_layer_id_for_vit(var_name, max_layer_id):
         return max_layer_id - 1
 
 
-@OPTIMIZER_BUILDERS.register_module()
+@OPTIMIZER_CONSTRUCTORS.register_module()
 class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
     """Different learning rates are set for different layers of backbone.
@@ -185,7 +186,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): params.extend(parameter_groups.values()) -@OPTIMIZER_BUILDERS.register_module() +@OPTIMIZER_CONSTRUCTORS.register_module() class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): """Different learning rates are set for different layers of backbone. diff --git a/mmseg/core/seg/builder.py b/mmseg/core/seg/builder.py index 1cecd347b..48e147902 100644 --- a/mmseg/core/seg/builder.py +++ b/mmseg/core/seg/builder.py @@ -1,9 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmcv.utils import Registry, build_from_cfg +import warnings -PIXEL_SAMPLERS = Registry('pixel sampler') +from mmseg.registry import TASK_UTILS + +PIXEL_SAMPLERS = TASK_UTILS def build_pixel_sampler(cfg, **default_args): """Build pixel sampler for segmentation map.""" - return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) + warnings.warn( + '``build_pixel_sampler`` would be deprecated soon, please use ' + '``mmseg.registry.TASK_UTILS.build()`` ') + return TASK_UTILS.build(cfg, default_args=default_args) diff --git a/mmseg/datasets/ade.py b/mmseg/datasets/ade.py index db94cebd3..440c25256 100644 --- a/mmseg/datasets/ade.py +++ b/mmseg/datasets/ade.py @@ -5,7 +5,7 @@ import mmcv import numpy as np from PIL import Image -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/builder.py b/mmseg/datasets/builder.py index 4d852d365..0eba0e1dd 100644 --- a/mmseg/datasets/builder.py +++ b/mmseg/datasets/builder.py @@ -8,9 +8,10 @@ import numpy as np import torch from mmcv.parallel import collate from mmcv.runner import get_dist_info -from mmcv.utils import Registry, build_from_cfg, digit_version +from mmcv.utils import digit_version from torch.utils.data import DataLoader +from mmseg.registry import DATASETS, TRANSFORMS from .samplers import DistributedSampler if platform.system() != 'Windows': @@ -22,8 +23,7 @@ if platform.system() != 'Windows': soft_limit = min(max(4096, base_soft_limit), hard_limit) resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) -DATASETS = Registry('dataset') -PIPELINES = Registry('pipeline') +PIPELINES = TRANSFORMS def _concat_dataset(cfg, default_args=None): @@ -82,7 +82,7 @@ def build_dataset(cfg, default_args=None): cfg.get('split', None), (list, tuple)): dataset = _concat_dataset(cfg, default_args) else: - dataset = build_from_cfg(cfg, DATASETS, default_args) + dataset = DATASETS.build(cfg, default_args=default_args) return dataset diff --git a/mmseg/datasets/chase_db1.py b/mmseg/datasets/chase_db1.py index 5cdc8d8d4..a5e35e86a 100644 --- a/mmseg/datasets/chase_db1.py +++ b/mmseg/datasets/chase_db1.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/cityscapes.py b/mmseg/datasets/cityscapes.py index ed633d00d..2e44db7a7 100644 --- a/mmseg/datasets/cityscapes.py +++ b/mmseg/datasets/cityscapes.py @@ -6,7 +6,7 @@ import numpy as np from mmcv.utils import print_log from PIL import Image -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/coco_stuff.py b/mmseg/datasets/coco_stuff.py index 24d089556..6f00ff602 100644 --- a/mmseg/datasets/coco_stuff.py +++ b/mmseg/datasets/coco_stuff.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/custom.py index 4615d4114..f2920cb2f 100644 --- a/mmseg/datasets/custom.py +++ b/mmseg/datasets/custom.py @@ -10,8 +10,8 @@ from prettytable import PrettyTable from torch.utils.data import Dataset from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics +from mmseg.registry import DATASETS from mmseg.utils import get_root_logger -from .builder import DATASETS from .pipelines import Compose, LoadAnnotations diff --git a/mmseg/datasets/dark_zurich.py b/mmseg/datasets/dark_zurich.py index 0b6fda6e9..83bb16bbb 100644 --- a/mmseg/datasets/dark_zurich.py +++ b/mmseg/datasets/dark_zurich.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import DATASETS +from mmseg.registry import DATASETS from .cityscapes import CityscapesDataset diff --git a/mmseg/datasets/dataset_wrappers.py b/mmseg/datasets/dataset_wrappers.py index 1fb089f9f..54b8fe856 100644 --- a/mmseg/datasets/dataset_wrappers.py +++ b/mmseg/datasets/dataset_wrappers.py @@ -6,10 +6,10 @@ from itertools import chain import mmcv import numpy as np -from mmcv.utils import build_from_cfg, print_log +from mmcv.utils import print_log from torch.utils.data.dataset import ConcatDataset as _ConcatDataset -from .builder import DATASETS, PIPELINES +from mmseg.registry import DATASETS, TRANSFORMS from .cityscapes import CityscapesDataset @@ -225,7 +225,7 @@ class MultiImageMixDataset: for transform in pipeline: if isinstance(transform, dict): self.pipeline_types.append(transform['type']) - transform = build_from_cfg(transform, PIPELINES) + transform = TRANSFORMS.build(transform) self.pipeline.append(transform) else: raise TypeError('pipeline must be a dict') diff --git a/mmseg/datasets/drive.py b/mmseg/datasets/drive.py index d44fb0da7..efd9de519 100644 --- a/mmseg/datasets/drive.py +++ b/mmseg/datasets/drive.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/hrf.py b/mmseg/datasets/hrf.py index cf3ea8d79..774eae972 100644 --- a/mmseg/datasets/hrf.py +++ b/mmseg/datasets/hrf.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/isaid.py b/mmseg/datasets/isaid.py index db24f9376..8c77e3256 100644 --- a/mmseg/datasets/isaid.py +++ b/mmseg/datasets/isaid.py @@ -3,8 +3,8 @@ import mmcv from mmcv.utils import print_log +from mmseg.registry import DATASETS from ..utils import get_root_logger -from .builder import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/isprs.py b/mmseg/datasets/isprs.py index 5f23e1a9b..b3c6eca49 100644 --- a/mmseg/datasets/isprs.py +++ b/mmseg/datasets/isprs.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/loveda.py b/mmseg/datasets/loveda.py index 90d654f62..b0ae2d140 100644 --- a/mmseg/datasets/loveda.py +++ b/mmseg/datasets/loveda.py @@ -5,7 +5,7 @@ import mmcv import numpy as np from PIL import Image -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/night_driving.py b/mmseg/datasets/night_driving.py index 6620586e3..11a3fade7 100644 --- a/mmseg/datasets/night_driving.py +++ b/mmseg/datasets/night_driving.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import DATASETS +from mmseg.registry import DATASETS from .cityscapes import CityscapesDataset diff --git a/mmseg/datasets/pascal_context.py b/mmseg/datasets/pascal_context.py index efacee0f3..b93440ca8 100644 --- a/mmseg/datasets/pascal_context.py +++ b/mmseg/datasets/pascal_context.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/pipelines/compose.py b/mmseg/datasets/pipelines/compose.py index 30280c133..5bfaa7046 100644 --- a/mmseg/datasets/pipelines/compose.py +++ b/mmseg/datasets/pipelines/compose.py @@ -1,12 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import collections -from mmcv.utils import build_from_cfg - -from ..builder import PIPELINES +from mmseg.registry import TRANSFORMS -@PIPELINES.register_module() +@TRANSFORMS.register_module() class Compose(object): """Compose multiple transforms sequentially. @@ -20,7 +18,7 @@ class Compose(object): self.transforms = [] for transform in transforms: if isinstance(transform, dict): - transform = build_from_cfg(transform, PIPELINES) + transform = TRANSFORMS.build(transform) self.transforms.append(transform) elif callable(transform): self.transforms.append(transform) diff --git a/mmseg/datasets/pipelines/formatting.py b/mmseg/datasets/pipelines/formatting.py index 4e057c1b8..c4be7aaf9 100644 --- a/mmseg/datasets/pipelines/formatting.py +++ b/mmseg/datasets/pipelines/formatting.py @@ -6,7 +6,7 @@ import numpy as np import torch from mmcv.parallel import DataContainer as DC -from ..builder import PIPELINES +from mmseg.registry import TRANSFORMS def to_tensor(data): @@ -34,7 +34,7 @@ def to_tensor(data): raise TypeError(f'type {type(data)} cannot be converted to tensor.') -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ToTensor(object): """Convert some results to :obj:`torch.Tensor` by given keys. @@ -64,7 +64,7 @@ class ToTensor(object): return self.__class__.__name__ + f'(keys={self.keys})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ImageToTensor(object): """Convert image to :obj:`torch.Tensor` by given keys. @@ -102,7 +102,7 @@ class ImageToTensor(object): return self.__class__.__name__ + f'(keys={self.keys})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class Transpose(object): """Transpose some results by given keys. @@ -136,7 +136,7 @@ class Transpose(object): f'(keys={self.keys}, order={self.order})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ToDataContainer(object): """Convert results to :obj:`mmcv.DataContainer` by given fields. 
@@ -175,7 +175,7 @@ class ToDataContainer(object): return self.__class__.__name__ + f'(fields={self.fields})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class DefaultFormatBundle(object): """Default formatting bundle. @@ -216,7 +216,7 @@ class DefaultFormatBundle(object): return self.__class__.__name__ -@PIPELINES.register_module() +@TRANSFORMS.register_module() class Collect(object): """Collect data from the loader relevant to the specific task. diff --git a/mmseg/datasets/pipelines/loading.py b/mmseg/datasets/pipelines/loading.py index 572e43431..c0641a3c0 100644 --- a/mmseg/datasets/pipelines/loading.py +++ b/mmseg/datasets/pipelines/loading.py @@ -4,10 +4,10 @@ import os.path as osp import mmcv import numpy as np -from ..builder import PIPELINES +from mmseg.registry import TRANSFORMS -@PIPELINES.register_module() +@TRANSFORMS.register_module() class LoadImageFromFile(object): """Load an image from file. @@ -87,7 +87,7 @@ class LoadImageFromFile(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class LoadAnnotations(object): """Load annotations for semantic segmentation. diff --git a/mmseg/datasets/pipelines/test_time_aug.py b/mmseg/datasets/pipelines/test_time_aug.py index 5c17cbbba..a582b5082 100644 --- a/mmseg/datasets/pipelines/test_time_aug.py +++ b/mmseg/datasets/pipelines/test_time_aug.py @@ -3,11 +3,11 @@ import warnings import mmcv -from ..builder import PIPELINES +from mmseg.registry import TRANSFORMS from .compose import Compose -@PIPELINES.register_module() +@TRANSFORMS.register_module() class MultiScaleFlipAug(object): """Test-time augmentation with multiple scales and flipping. diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 5673b646f..e65f9857e 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -6,10 +6,10 @@ import numpy as np from mmcv.utils import deprecated_api_warning, is_tuple_of from numpy import random -from ..builder import PIPELINES +from mmseg.registry import TRANSFORMS -@PIPELINES.register_module() +@TRANSFORMS.register_module() class ResizeToMultiple(object): """Resize images & seg to multiple of divisor. @@ -66,7 +66,7 @@ class ResizeToMultiple(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class Resize(object): """Resize images & seg. @@ -321,7 +321,7 @@ class Resize(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomFlip(object): """Flip the image & seg. @@ -376,7 +376,7 @@ class RandomFlip(object): return self.__class__.__name__ + f'(prob={self.prob})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class Pad(object): """Pad the image & mask. @@ -447,7 +447,7 @@ class Pad(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class Normalize(object): """Normalize the image. @@ -489,7 +489,7 @@ class Normalize(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class Rerange(object): """Rerange the image pixel value. @@ -535,7 +535,7 @@ class Rerange(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class CLAHE(object): """Use CLAHE method to process the image. @@ -580,7 +580,7 @@ class CLAHE(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomCrop(object): """Random crop the image & seg. 
@@ -653,7 +653,7 @@ class RandomCrop(object): return self.__class__.__name__ + f'(crop_size={self.crop_size})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomRotate(object): """Rotate the image & seg. @@ -736,7 +736,7 @@ class RandomRotate(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RGB2Gray(object): """Convert RGB image to grayscale image. @@ -791,7 +791,7 @@ class RGB2Gray(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class AdjustGamma(object): """Using gamma correction to process the image. @@ -827,7 +827,7 @@ class AdjustGamma(object): return self.__class__.__name__ + f'(gamma={self.gamma})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class SegRescale(object): """Rescale semantic segmentation maps. @@ -857,7 +857,7 @@ class SegRescale(object): return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' -@PIPELINES.register_module() +@TRANSFORMS.register_module() class PhotoMetricDistortion(object): """Apply photometric distortion to image sequentially, every transformation is applied with a probability of 0.5. The position of random contrast is in @@ -976,7 +976,7 @@ class PhotoMetricDistortion(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomCutOut(object): """CutOut operation. @@ -1068,7 +1068,7 @@ class RandomCutOut(object): return repr_str -@PIPELINES.register_module() +@TRANSFORMS.register_module() class RandomMosaic(object): """Mosaic augmentation. Given 4 images, mosaic transform combines them into one output image. The output image is composed of the parts from each sub- diff --git a/mmseg/datasets/potsdam.py b/mmseg/datasets/potsdam.py index 2986b8faa..78bbca8ab 100644 --- a/mmseg/datasets/potsdam.py +++ b/mmseg/datasets/potsdam.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/samplers/distributed_sampler.py b/mmseg/datasets/samplers/distributed_sampler.py index d1a13c716..84b8762c3 100644 --- a/mmseg/datasets/samplers/distributed_sampler.py +++ b/mmseg/datasets/samplers/distributed_sampler.py @@ -7,8 +7,10 @@ from torch.utils.data import Dataset from torch.utils.data import DistributedSampler as _DistributedSampler from mmseg.core.utils import sync_random_seed +from mmseg.registry import DATA_SAMPLERS +@DATA_SAMPLERS.register_module() class DistributedSampler(_DistributedSampler): """DistributedSampler inheriting from `torch.utils.data.DistributedSampler`. diff --git a/mmseg/datasets/stare.py b/mmseg/datasets/stare.py index a24d1d957..fdb98f1c0 100644 --- a/mmseg/datasets/stare.py +++ b/mmseg/datasets/stare.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/datasets/voc.py b/mmseg/datasets/voc.py index 3cec9e350..e0fafb069 100644 --- a/mmseg/datasets/voc.py +++ b/mmseg/datasets/voc.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os.path as osp -from .builder import DATASETS +from mmseg.registry import DATASETS from .custom import CustomDataset diff --git a/mmseg/models/backbones/beit.py b/mmseg/models/backbones/beit.py index fade60137..f3cb98bba 100644 --- a/mmseg/models/backbones/beit.py +++ b/mmseg/models/backbones/beit.py @@ -13,8 +13,8 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.utils import _pair as to_2tuple +from mmseg.registry import MODELS from mmseg.utils import get_root_logger -from ..builder import BACKBONES from ..utils import PatchEmbed from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer @@ -227,7 +227,7 @@ class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer): return x -@BACKBONES.register_module() +@MODELS.register_module() class BEiT(BaseModule): """BERT Pre-Training of Image Transformers. diff --git a/mmseg/models/backbones/bisenetv1.py b/mmseg/models/backbones/bisenetv1.py index 4beb7b394..4eded9076 100644 --- a/mmseg/models/backbones/bisenetv1.py +++ b/mmseg/models/backbones/bisenetv1.py @@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule from mmcv.runner import BaseModule from mmseg.ops import resize -from ..builder import BACKBONES, build_backbone +from mmseg.registry import MODELS class SpatialPath(BaseModule): @@ -156,7 +156,7 @@ class ContextPath(BaseModule): assert len(context_channels) == 3, 'Length of input channels \ of Context Path must be 3!' - self.backbone = build_backbone(backbone_cfg) + self.backbone = MODELS.build(backbone_cfg) self.align_corners = align_corners self.arm16 = AttentionRefinementModule(context_channels[1], @@ -262,7 +262,7 @@ class FeatureFusionModule(BaseModule): return x_out -@BACKBONES.register_module() +@MODELS.register_module() class BiSeNetV1(BaseModule): """BiSeNetV1 backbone. diff --git a/mmseg/models/backbones/bisenetv2.py b/mmseg/models/backbones/bisenetv2.py index d908b321c..50693a4f2 100644 --- a/mmseg/models/backbones/bisenetv2.py +++ b/mmseg/models/backbones/bisenetv2.py @@ -6,7 +6,7 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, from mmcv.runner import BaseModule from mmseg.ops import resize -from ..builder import BACKBONES +from mmseg.registry import MODELS class DetailBranch(BaseModule): @@ -541,7 +541,7 @@ class BGALayer(BaseModule): return output -@BACKBONES.register_module() +@MODELS.register_module() class BiSeNetV2(BaseModule): """BiSeNetV2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation. diff --git a/mmseg/models/backbones/cgnet.py b/mmseg/models/backbones/cgnet.py index 168194c10..a3da0a2ae 100644 --- a/mmseg/models/backbones/cgnet.py +++ b/mmseg/models/backbones/cgnet.py @@ -8,7 +8,7 @@ from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm -from ..builder import BACKBONES +from mmseg.registry import MODELS class GlobalContextExtractor(nn.Module): @@ -183,7 +183,7 @@ class InputInjection(nn.Module): return x -@BACKBONES.register_module() +@MODELS.register_module() class CGNet(BaseModule): """CGNet backbone. 
diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py index 8921c18f3..7c0da2da2 100644 --- a/mmseg/models/backbones/erfnet.py +++ b/mmseg/models/backbones/erfnet.py @@ -5,7 +5,7 @@ from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer from mmcv.runner import BaseModule from mmseg.ops import resize -from ..builder import BACKBONES +from mmseg.registry import MODELS class DownsamplerBlock(BaseModule): @@ -191,7 +191,7 @@ class UpsamplerBlock(BaseModule): return output -@BACKBONES.register_module() +@MODELS.register_module() class ERFNet(BaseModule): """ERFNet backbone. diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index cbfbcaf4f..3d40e46b8 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -6,7 +6,7 @@ from mmcv.runner import BaseModule from mmseg.models.decode_heads.psp_head import PPM from mmseg.ops import resize -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import InvertedResidual @@ -268,7 +268,7 @@ class FeatureFusionModule(nn.Module): return self.relu(out) -@BACKBONES.register_module() +@MODELS.register_module() class FastSCNN(BaseModule): """Fast-SCNN Backbone. diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py index 90feadcf6..dbbd2c8a4 100644 --- a/mmseg/models/backbones/hrnet.py +++ b/mmseg/models/backbones/hrnet.py @@ -7,7 +7,7 @@ from mmcv.runner import BaseModule, ModuleList, Sequential from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.ops import Upsample, resize -from ..builder import BACKBONES +from mmseg.registry import MODELS from .resnet import BasicBlock, Bottleneck @@ -214,7 +214,7 @@ class HRModule(BaseModule): return x_fuse -@BACKBONES.register_module() +@MODELS.register_module() class HRNet(BaseModule): """HRNet backbone. diff --git a/mmseg/models/backbones/icnet.py b/mmseg/models/backbones/icnet.py index 6faaeab01..3cd7037b3 100644 --- a/mmseg/models/backbones/icnet.py +++ b/mmseg/models/backbones/icnet.py @@ -5,11 +5,11 @@ from mmcv.cnn import ConvModule from mmcv.runner import BaseModule from mmseg.ops import resize -from ..builder import BACKBONES, build_backbone +from mmseg.registry import MODELS from ..decode_heads.psp_head import PPM -@BACKBONES.register_module() +@MODELS.register_module() class ICNet(BaseModule): """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. @@ -66,7 +66,7 @@ class ICNet(BaseModule): ] super(ICNet, self).__init__(init_cfg=init_cfg) self.align_corners = align_corners - self.backbone = build_backbone(backbone_cfg) + self.backbone = MODELS.build(backbone_cfg) # Note: Default `ceil_mode` is false in nn.MaxPool2d, set # `ceil_mode=True` to keep information in the corner of feature map. 
diff --git a/mmseg/models/backbones/mae.py b/mmseg/models/backbones/mae.py index d3e8754bd..688aff42b 100644 --- a/mmseg/models/backbones/mae.py +++ b/mmseg/models/backbones/mae.py @@ -8,8 +8,8 @@ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init, from mmcv.runner import ModuleList, _load_checkpoint from torch.nn.modules.batchnorm import _BatchNorm +from mmseg.registry import MODELS from mmseg.utils import get_root_logger -from ..builder import BACKBONES from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer @@ -42,7 +42,7 @@ class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer): self.attn = MAEAttention(**attn_cfg) -@BACKBONES.register_module() +@MODELS.register_module() class MAE(BEiT): """VisionTransformer with support for patch. diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py index 4417cf113..83c3bb3ca 100644 --- a/mmseg/models/backbones/mit.py +++ b/mmseg/models/backbones/mit.py @@ -12,7 +12,7 @@ from mmcv.cnn.utils.weight_init import (constant_init, normal_init, trunc_normal_init) from mmcv.runner import BaseModule, ModuleList, Sequential -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw @@ -295,7 +295,7 @@ class TransformerEncoderLayer(BaseModule): return x -@BACKBONES.register_module() +@MODELS.register_module() class MixVisionTransformer(BaseModule): """The backbone of Segformer. diff --git a/mmseg/models/backbones/mobilenet_v2.py b/mmseg/models/backbones/mobilenet_v2.py index cbb9c6cd0..67269182a 100644 --- a/mmseg/models/backbones/mobilenet_v2.py +++ b/mmseg/models/backbones/mobilenet_v2.py @@ -6,11 +6,11 @@ from mmcv.cnn import ConvModule from mmcv.runner import BaseModule from torch.nn.modules.batchnorm import _BatchNorm -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import InvertedResidual, make_divisible -@BACKBONES.register_module() +@MODELS.register_module() class MobileNetV2(BaseModule): """MobileNetV2 backbone. diff --git a/mmseg/models/backbones/mobilenet_v3.py b/mmseg/models/backbones/mobilenet_v3.py index dd3d6eb17..ac73233b0 100644 --- a/mmseg/models/backbones/mobilenet_v3.py +++ b/mmseg/models/backbones/mobilenet_v3.py @@ -7,11 +7,11 @@ from mmcv.cnn.bricks import Conv2dAdaptivePadding from mmcv.runner import BaseModule from torch.nn.modules.batchnorm import _BatchNorm -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import InvertedResidualV3 as InvertedResidual -@BACKBONES.register_module() +@MODELS.register_module() class MobileNetV3(BaseModule): """MobileNetV3 backbone. diff --git a/mmseg/models/backbones/resnest.py b/mmseg/models/backbones/resnest.py index 91952c2ca..519bd9738 100644 --- a/mmseg/models/backbones/resnest.py +++ b/mmseg/models/backbones/resnest.py @@ -7,7 +7,7 @@ import torch.nn.functional as F import torch.utils.checkpoint as cp from mmcv.cnn import build_conv_layer, build_norm_layer -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import ResLayer from .resnet import Bottleneck as _Bottleneck from .resnet import ResNetV1d @@ -267,7 +267,7 @@ class Bottleneck(_Bottleneck): return out -@BACKBONES.register_module() +@MODELS.register_module() class ResNeSt(ResNetV1d): """ResNeSt backbone. 
diff --git a/mmseg/models/backbones/resnet.py b/mmseg/models/backbones/resnet.py index e8b961d5f..9eda906e6 100644 --- a/mmseg/models/backbones/resnet.py +++ b/mmseg/models/backbones/resnet.py @@ -7,7 +7,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import ResLayer @@ -307,7 +307,7 @@ class Bottleneck(BaseModule): return out -@BACKBONES.register_module() +@MODELS.register_module() class ResNet(BaseModule): """ResNet backbone. @@ -685,7 +685,7 @@ class ResNet(BaseModule): m.eval() -@BACKBONES.register_module() +@MODELS.register_module() class ResNetV1c(ResNet): """ResNetV1c variant described in [1]_. @@ -700,7 +700,7 @@ class ResNetV1c(ResNet): deep_stem=True, avg_down=False, **kwargs) -@BACKBONES.register_module() +@MODELS.register_module() class ResNetV1d(ResNet): """ResNetV1d variant described in [1]_. diff --git a/mmseg/models/backbones/resnext.py b/mmseg/models/backbones/resnext.py index 805c27bf3..2f7cacab7 100644 --- a/mmseg/models/backbones/resnext.py +++ b/mmseg/models/backbones/resnext.py @@ -3,7 +3,7 @@ import math from mmcv.cnn import build_conv_layer, build_norm_layer -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import ResLayer from .resnet import Bottleneck as _Bottleneck from .resnet import ResNet @@ -84,7 +84,7 @@ class Bottleneck(_Bottleneck): self.add_module(self.norm3_name, norm3) -@BACKBONES.register_module() +@MODELS.register_module() class ResNeXt(ResNet): """ResNeXt backbone. diff --git a/mmseg/models/backbones/stdc.py b/mmseg/models/backbones/stdc.py index 04f2f7a2a..ece7da172 100644 --- a/mmseg/models/backbones/stdc.py +++ b/mmseg/models/backbones/stdc.py @@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from mmseg.ops import resize -from ..builder import BACKBONES, build_backbone +from mmseg.registry import MODELS from .bisenetv1 import AttentionRefinementModule @@ -184,7 +184,7 @@ class FeatureFusionModule(BaseModule): return x_attn + x -@BACKBONES.register_module() +@MODELS.register_module() class STDCNet(BaseModule): """This backbone is the implementation of `Rethinking BiSeNet For Real-time Semantic Segmentation `_. @@ -325,7 +325,7 @@ class STDCNet(BaseModule): return tuple(outs) -@BACKBONES.register_module() +@MODELS.register_module() class STDCContextPathNet(BaseModule): """STDCNet with Context Path. 
The `outs` below is a list of three feature maps from deep to shallow, whose height and width is from small to big, @@ -371,7 +371,7 @@ class STDCContextPathNet(BaseModule): norm_cfg=dict(type='BN'), init_cfg=None): super(STDCContextPathNet, self).__init__(init_cfg=init_cfg) - self.backbone = build_backbone(backbone_cfg) + self.backbone = MODELS.build(backbone_cfg) self.arms = ModuleList() self.convs = ModuleList() for channels in last_in_channels: diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py index cbf13288a..dae660eeb 100644 --- a/mmseg/models/backbones/swin.py +++ b/mmseg/models/backbones/swin.py @@ -15,8 +15,8 @@ from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList, load_state_dict) from mmcv.utils import to_2tuple +from mmseg.registry import MODELS from ...utils import get_root_logger -from ..builder import BACKBONES from ..utils.embed import PatchEmbed, PatchMerging @@ -462,7 +462,7 @@ class SwinBlockSequence(BaseModule): return x, hw_shape, x, hw_shape -@BACKBONES.register_module() +@MODELS.register_module() class SwinTransformer(BaseModule): """Swin Transformer backbone. diff --git a/mmseg/models/backbones/timm_backbone.py b/mmseg/models/backbones/timm_backbone.py index 01b29fc5e..478e8bdea 100644 --- a/mmseg/models/backbones/timm_backbone.py +++ b/mmseg/models/backbones/timm_backbone.py @@ -7,10 +7,10 @@ except ImportError: from mmcv.cnn.bricks.registry import NORM_LAYERS from mmcv.runner import BaseModule -from ..builder import BACKBONES +from mmseg.registry import MODELS -@BACKBONES.register_module() +@MODELS.register_module() class TIMMBackbone(BaseModule): """Wrapper to use backbones from timm library. More details can be found in `timm `_ . diff --git a/mmseg/models/backbones/twins.py b/mmseg/models/backbones/twins.py index 6bd946911..ce1faaa21 100644 --- a/mmseg/models/backbones/twins.py +++ b/mmseg/models/backbones/twins.py @@ -14,7 +14,7 @@ from mmcv.runner import BaseModule, ModuleList from torch.nn.modules.batchnorm import _BatchNorm from mmseg.models.backbones.mit import EfficientMultiheadAttention -from mmseg.models.builder import BACKBONES +from mmseg.registry import MODELS from ..utils.embed import PatchEmbed @@ -349,7 +349,7 @@ class ConditionalPositionEncoding(BaseModule): return x -@BACKBONES.register_module() +@MODELS.register_module() class PCPVT(BaseModule): """The backbone of Twins-PCPVT. @@ -508,7 +508,7 @@ class PCPVT(BaseModule): return tuple(outputs) -@BACKBONES.register_module() +@MODELS.register_module() class SVT(PCPVT): """The backbone of Twins-SVT. diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py index c2d33667f..b07edd5f2 100644 --- a/mmseg/models/backbones/unet.py +++ b/mmseg/models/backbones/unet.py @@ -9,7 +9,7 @@ from mmcv.runner import BaseModule from mmcv.utils.parrots_wrapper import _BatchNorm from mmseg.ops import Upsample -from ..builder import BACKBONES +from mmseg.registry import MODELS from ..utils import UpConvBlock @@ -221,7 +221,7 @@ class InterpConv(nn.Module): return out -@BACKBONES.register_module() +@MODELS.register_module() class UNet(BaseModule): """UNet backbone. 
diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 94ea5eacb..e179f2835 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -15,8 +15,8 @@ from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.utils import _pair as to_2tuple from mmseg.ops import resize +from mmseg.registry import MODELS from mmseg.utils import get_root_logger -from ..builder import BACKBONES from ..utils import PatchEmbed @@ -122,7 +122,7 @@ class TransformerEncoderLayer(BaseModule): return x -@BACKBONES.register_module() +@MODELS.register_module() class VisionTransformer(BaseModule): """Vision Transformer. diff --git a/mmseg/models/builder.py b/mmseg/models/builder.py index 5e18e4e64..081c646b4 100644 --- a/mmseg/models/builder.py +++ b/mmseg/models/builder.py @@ -1,12 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from mmcv.cnn import MODELS as MMCV_MODELS -from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION -from mmcv.utils import Registry - -MODELS = Registry('models', parent=MMCV_MODELS) -ATTENTION = Registry('attention', parent=MMCV_ATTENTION) +from mmseg.registry import MODELS BACKBONES = MODELS NECKS = MODELS @@ -17,21 +12,29 @@ SEGMENTORS = MODELS def build_backbone(cfg): """Build backbone.""" + warnings.warn('``build_backbone`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') return BACKBONES.build(cfg) def build_neck(cfg): """Build neck.""" + warnings.warn('``build_neck`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') return NECKS.build(cfg) def build_head(cfg): """Build head.""" + warnings.warn('``build_head`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') return HEADS.build(cfg) def build_loss(cfg): """Build loss.""" + warnings.warn('``build_loss`` would be deprecated soon, please use ' + '``mmseg.registry.MODELS.build()`` ') return LOSSES.build(cfg) diff --git a/mmseg/models/decode_heads/ann_head.py b/mmseg/models/decode_heads/ann_head.py index c8d882e31..9cc791b26 100644 --- a/mmseg/models/decode_heads/ann_head.py +++ b/mmseg/models/decode_heads/ann_head.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from ..builder import HEADS +from mmseg.registry import MODELS from ..utils import SelfAttentionBlock as _SelfAttentionBlock from .decode_head import BaseDecodeHead @@ -181,7 +181,7 @@ class APNB(nn.Module): return output -@HEADS.register_module() +@MODELS.register_module() class ANNHead(BaseDecodeHead): """Asymmetric Non-local Neural Networks for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/apc_head.py b/mmseg/models/decode_heads/apc_head.py index 3198fd188..45f4e2850 100644 --- a/mmseg/models/decode_heads/apc_head.py +++ b/mmseg/models/decode_heads/apc_head.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmseg.ops import resize -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead @@ -107,7 +107,7 @@ class ACM(nn.Module): return z_out -@HEADS.register_module() +@MODELS.register_module() class APCHead(BaseDecodeHead): """Adaptive Pyramid Context Network for Semantic Segmentation. 
diff --git a/mmseg/models/decode_heads/aspp_head.py b/mmseg/models/decode_heads/aspp_head.py index 7059aee96..acf9eedfa 100644 --- a/mmseg/models/decode_heads/aspp_head.py +++ b/mmseg/models/decode_heads/aspp_head.py @@ -4,7 +4,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import resize -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead @@ -50,7 +50,7 @@ class ASPPModule(nn.ModuleList): return aspp_outs -@HEADS.register_module() +@MODELS.register_module() class ASPPHead(BaseDecodeHead): """Rethinking Atrous Convolution for Semantic Image Segmentation. diff --git a/mmseg/models/decode_heads/cc_head.py b/mmseg/models/decode_heads/cc_head.py index ed19eb46d..03ad3db76 100644 --- a/mmseg/models/decode_heads/cc_head.py +++ b/mmseg/models/decode_heads/cc_head.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from ..builder import HEADS +from mmseg.registry import MODELS from .fcn_head import FCNHead try: @@ -10,7 +10,7 @@ except ModuleNotFoundError: CrissCrossAttention = None -@HEADS.register_module() +@MODELS.register_module() class CCHead(FCNHead): """CCNet: Criss-Cross Attention for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/da_head.py b/mmseg/models/decode_heads/da_head.py index 77fd6639c..2a66fc204 100644 --- a/mmseg/models/decode_heads/da_head.py +++ b/mmseg/models/decode_heads/da_head.py @@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule, Scale from torch import nn from mmseg.core import add_prefix -from ..builder import HEADS +from mmseg.registry import MODELS from ..utils import SelfAttentionBlock as _SelfAttentionBlock from .decode_head import BaseDecodeHead @@ -72,7 +72,7 @@ class CAM(nn.Module): return out -@HEADS.register_module() +@MODELS.register_module() class DAHead(BaseDecodeHead): """Dual Attention Network for Scene Segmentation. diff --git a/mmseg/models/decode_heads/dm_head.py b/mmseg/models/decode_heads/dm_head.py index ffaa870ab..30405e3eb 100644 --- a/mmseg/models/decode_heads/dm_head.py +++ b/mmseg/models/decode_heads/dm_head.py @@ -4,7 +4,7 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead @@ -89,7 +89,7 @@ class DCM(nn.Module): return output -@HEADS.register_module() +@MODELS.register_module() class DMHead(BaseDecodeHead): """Dynamic Multi-scale Filters for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/dnl_head.py b/mmseg/models/decode_heads/dnl_head.py index dabf15421..400a17556 100644 --- a/mmseg/models/decode_heads/dnl_head.py +++ b/mmseg/models/decode_heads/dnl_head.py @@ -3,7 +3,7 @@ import torch from mmcv.cnn import NonLocal2d from torch import nn -from ..builder import HEADS +from mmseg.registry import MODELS from .fcn_head import FCNHead @@ -89,7 +89,7 @@ class DisentangledNonLocal2d(NonLocal2d): return output -@HEADS.register_module() +@MODELS.register_module() class DNLHead(FCNHead): """Disentangled Non-Local Neural Networks. 
diff --git a/mmseg/models/decode_heads/dpt_head.py b/mmseg/models/decode_heads/dpt_head.py index 6c895d02d..04ade62cf 100644 --- a/mmseg/models/decode_heads/dpt_head.py +++ b/mmseg/models/decode_heads/dpt_head.py @@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule, Linear, build_activation_layer from mmcv.runner import BaseModule from mmseg.ops import resize -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead @@ -212,7 +212,7 @@ class FeatureFusionBlock(BaseModule): return x -@HEADS.register_module() +@MODELS.register_module() class DPTHead(BaseDecodeHead): """Vision Transformers for Dense Prediction. diff --git a/mmseg/models/decode_heads/ema_head.py b/mmseg/models/decode_heads/ema_head.py index f6de16711..d7923f424 100644 --- a/mmseg/models/decode_heads/ema_head.py +++ b/mmseg/models/decode_heads/ema_head.py @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead @@ -76,7 +76,7 @@ class EMAModule(nn.Module): return feats_recon -@HEADS.register_module() +@MODELS.register_module() class EMAHead(BaseDecodeHead): """Expectation Maximization Attention Networks for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py index 648c8906b..c6f20b1f3 100644 --- a/mmseg/models/decode_heads/enc_head.py +++ b/mmseg/models/decode_heads/enc_head.py @@ -5,7 +5,8 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule, build_norm_layer from mmseg.ops import Encoding, resize -from ..builder import HEADS, build_loss +from mmseg.registry import MODELS +from ..builder import build_loss from .decode_head import BaseDecodeHead @@ -59,7 +60,7 @@ class EncModule(nn.Module): return encoding_feat, output -@HEADS.register_module() +@MODELS.register_module() class EncHead(BaseDecodeHead): """Context Encoding for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/fcn_head.py b/mmseg/models/decode_heads/fcn_head.py index fb79a0d7c..4e3b974a8 100644 --- a/mmseg/models/decode_heads/fcn_head.py +++ b/mmseg/models/decode_heads/fcn_head.py @@ -3,11 +3,11 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead -@HEADS.register_module() +@MODELS.register_module() class FCNHead(BaseDecodeHead): """Fully Convolution Networks for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py index e41f324cc..be92ceed5 100644 --- a/mmseg/models/decode_heads/fpn_head.py +++ b/mmseg/models/decode_heads/fpn_head.py @@ -4,11 +4,11 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import Upsample, resize -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead -@HEADS.register_module() +@MODELS.register_module() class FPNHead(BaseDecodeHead): """Panoptic Feature Pyramid Networks. 
diff --git a/mmseg/models/decode_heads/gc_head.py b/mmseg/models/decode_heads/gc_head.py index eed507425..e89b92d8b 100644 --- a/mmseg/models/decode_heads/gc_head.py +++ b/mmseg/models/decode_heads/gc_head.py @@ -2,11 +2,11 @@ import torch from mmcv.cnn import ContextBlock -from ..builder import HEADS +from mmseg.registry import MODELS from .fcn_head import FCNHead -@HEADS.register_module() +@MODELS.register_module() class GCHead(FCNHead): """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. diff --git a/mmseg/models/decode_heads/isa_head.py b/mmseg/models/decode_heads/isa_head.py index 0bf345557..3769bdff4 100644 --- a/mmseg/models/decode_heads/isa_head.py +++ b/mmseg/models/decode_heads/isa_head.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F from mmcv.cnn import ConvModule -from ..builder import HEADS +from mmseg.registry import MODELS from ..utils import SelfAttentionBlock as _SelfAttentionBlock from .decode_head import BaseDecodeHead @@ -55,7 +55,7 @@ class SelfAttentionBlock(_SelfAttentionBlock): return self.output_project(context) -@HEADS.register_module() +@MODELS.register_module() class ISAHead(BaseDecodeHead): """Interlaced Sparse Self-Attention for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/knet_head.py b/mmseg/models/decode_heads/knet_head.py index f73daccb6..46de21434 100644 --- a/mmseg/models/decode_heads/knet_head.py +++ b/mmseg/models/decode_heads/knet_head.py @@ -7,8 +7,8 @@ from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER, MultiheadAttention, build_transformer_layer) -from mmseg.models.builder import HEADS, build_head from mmseg.models.decode_heads.decode_head import BaseDecodeHead +from mmseg.registry import MODELS from mmseg.utils import get_root_logger @@ -139,7 +139,7 @@ class KernelUpdator(nn.Module): return features -@HEADS.register_module() +@MODELS.register_module() class KernelUpdateHead(nn.Module): """Kernel Update Head in K-Net. @@ -391,7 +391,7 @@ class KernelUpdateHead(nn.Module): self.conv_kernel_size) -@HEADS.register_module() +@MODELS.register_module() class IterativeDecodeHead(BaseDecodeHead): """K-Net: Towards Unified Image Segmentation. 
@@ -414,7 +414,7 @@ class IterativeDecodeHead(BaseDecodeHead):
         super(BaseDecodeHead, self).__init__(**kwargs)
         assert num_stages == len(kernel_update_head)
         self.num_stages = num_stages
-        self.kernel_generate_head = build_head(kernel_generate_head)
+        self.kernel_generate_head = MODELS.build(kernel_generate_head)
         self.kernel_update_head = nn.ModuleList()
         self.align_corners = self.kernel_generate_head.align_corners
         self.num_classes = self.kernel_generate_head.num_classes
@@ -422,7 +422,7 @@ class IterativeDecodeHead(BaseDecodeHead):
         self.ignore_index = self.kernel_generate_head.ignore_index
 
         for head_cfg in kernel_update_head:
-            self.kernel_update_head.append(build_head(head_cfg))
+            self.kernel_update_head.append(MODELS.build(head_cfg))
 
     def forward(self, inputs):
         """Forward function."""
diff --git a/mmseg/models/decode_heads/lraspp_head.py b/mmseg/models/decode_heads/lraspp_head.py
index c10ff0d82..36999f056 100644
--- a/mmseg/models/decode_heads/lraspp_head.py
+++ b/mmseg/models/decode_heads/lraspp_head.py
@@ -5,11 +5,11 @@ from mmcv import is_tuple_of
 from mmcv.cnn import ConvModule
 
 from mmseg.ops import resize
-from ..builder import HEADS
+from mmseg.registry import MODELS
 from .decode_head import BaseDecodeHead
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class LRASPPHead(BaseDecodeHead):
     """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
diff --git a/mmseg/models/decode_heads/nl_head.py b/mmseg/models/decode_heads/nl_head.py
index 637517e7a..7903f1ace 100644
--- a/mmseg/models/decode_heads/nl_head.py
+++ b/mmseg/models/decode_heads/nl_head.py
@@ -2,11 +2,11 @@
 import torch
 from mmcv.cnn import NonLocal2d
 
-from ..builder import HEADS
+from mmseg.registry import MODELS
 from .fcn_head import FCNHead
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class NLHead(FCNHead):
     """Non-local Neural Networks.
diff --git a/mmseg/models/decode_heads/ocr_head.py b/mmseg/models/decode_heads/ocr_head.py
index 09eadfb1a..ce3582413 100644
--- a/mmseg/models/decode_heads/ocr_head.py
+++ b/mmseg/models/decode_heads/ocr_head.py
@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule
 
 from mmseg.ops import resize
-from ..builder import HEADS
+from mmseg.registry import MODELS
 from ..utils import SelfAttentionBlock as _SelfAttentionBlock
 from .cascade_decode_head import BaseCascadeDecodeHead
 
@@ -82,7 +82,7 @@ class ObjectAttentionBlock(_SelfAttentionBlock):
         return output
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class OCRHead(BaseCascadeDecodeHead):
     """Object-Contextual Representations for Semantic Segmentation.
diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py
index 5e605271c..a94d778b0 100644
--- a/mmseg/models/decode_heads/point_head.py
+++ b/mmseg/models/decode_heads/point_head.py
@@ -10,8 +10,8 @@ try:
 except ModuleNotFoundError:
     point_sample = None
 
-from mmseg.models.builder import HEADS
 from mmseg.ops import resize
+from mmseg.registry import MODELS
 from ..losses import accuracy
 from .cascade_decode_head import BaseCascadeDecodeHead
 
@@ -36,7 +36,7 @@ def calculate_uncertainty(seg_logits):
     return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
 
 
-@HEADS.register_module()
+@MODELS.register_module()
 class PointHead(BaseCascadeDecodeHead):
     """A mask point head use in PointRend.
diff --git a/mmseg/models/decode_heads/psa_head.py b/mmseg/models/decode_heads/psa_head.py index df7593cbc..4b292c600 100644 --- a/mmseg/models/decode_heads/psa_head.py +++ b/mmseg/models/decode_heads/psa_head.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmseg.ops import resize -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead try: @@ -14,7 +14,7 @@ except ModuleNotFoundError: PSAMask = None -@HEADS.register_module() +@MODELS.register_module() class PSAHead(BaseDecodeHead): """Point-wise Spatial Attention Network for Scene Parsing. diff --git a/mmseg/models/decode_heads/psp_head.py b/mmseg/models/decode_heads/psp_head.py index 6990676ff..734de8f1a 100644 --- a/mmseg/models/decode_heads/psp_head.py +++ b/mmseg/models/decode_heads/psp_head.py @@ -4,7 +4,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import resize -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead @@ -59,7 +59,7 @@ class PPM(nn.ModuleList): return ppm_outs -@HEADS.register_module() +@MODELS.register_module() class PSPHead(BaseDecodeHead): """Pyramid Scene Parsing Network. diff --git a/mmseg/models/decode_heads/segformer_head.py b/mmseg/models/decode_heads/segformer_head.py index 2e75d5069..8c14602a3 100644 --- a/mmseg/models/decode_heads/segformer_head.py +++ b/mmseg/models/decode_heads/segformer_head.py @@ -3,12 +3,12 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from mmseg.models.builder import HEADS from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.ops import resize +from mmseg.registry import MODELS -@HEADS.register_module() +@MODELS.register_module() class SegformerHead(BaseDecodeHead): """The all mlp Head of segformer. diff --git a/mmseg/models/decode_heads/segmenter_mask_head.py b/mmseg/models/decode_heads/segmenter_mask_head.py index 6a9b3d47e..95a85a9e3 100644 --- a/mmseg/models/decode_heads/segmenter_mask_head.py +++ b/mmseg/models/decode_heads/segmenter_mask_head.py @@ -8,11 +8,11 @@ from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_, from mmcv.runner import ModuleList from mmseg.models.backbones.vit import TransformerEncoderLayer -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead -@HEADS.register_module() +@MODELS.register_module() class SegmenterMaskTransformerHead(BaseDecodeHead): """Segmenter: Transformer for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/sep_aspp_head.py b/mmseg/models/decode_heads/sep_aspp_head.py index 4e894e28e..c63217931 100644 --- a/mmseg/models/decode_heads/sep_aspp_head.py +++ b/mmseg/models/decode_heads/sep_aspp_head.py @@ -4,7 +4,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmseg.ops import resize -from ..builder import HEADS +from mmseg.registry import MODELS from .aspp_head import ASPPHead, ASPPModule @@ -26,7 +26,7 @@ class DepthwiseSeparableASPPModule(ASPPModule): act_cfg=self.act_cfg) -@HEADS.register_module() +@MODELS.register_module() class DepthwiseSeparableASPPHead(ASPPHead): """Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. 
diff --git a/mmseg/models/decode_heads/sep_fcn_head.py b/mmseg/models/decode_heads/sep_fcn_head.py index 7f9658e08..5c8b79bd0 100644 --- a/mmseg/models/decode_heads/sep_fcn_head.py +++ b/mmseg/models/decode_heads/sep_fcn_head.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmcv.cnn import DepthwiseSeparableConvModule -from ..builder import HEADS +from mmseg.registry import MODELS from .fcn_head import FCNHead -@HEADS.register_module() +@MODELS.register_module() class DepthwiseSeparableFCNHead(FCNHead): """Depthwise-Separable Fully Convolutional Network for Semantic Segmentation. diff --git a/mmseg/models/decode_heads/setr_mla_head.py b/mmseg/models/decode_heads/setr_mla_head.py index 6bb94ae33..228c31100 100644 --- a/mmseg/models/decode_heads/setr_mla_head.py +++ b/mmseg/models/decode_heads/setr_mla_head.py @@ -4,11 +4,11 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import Upsample -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead -@HEADS.register_module() +@MODELS.register_module() class SETRMLAHead(BaseDecodeHead): """Multi level feature aggretation head of SETR. diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py index 87e7ea7fa..11ab9bb7d 100644 --- a/mmseg/models/decode_heads/setr_up_head.py +++ b/mmseg/models/decode_heads/setr_up_head.py @@ -3,11 +3,11 @@ import torch.nn as nn from mmcv.cnn import ConvModule, build_norm_layer from mmseg.ops import Upsample -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead -@HEADS.register_module() +@MODELS.register_module() class SETRUPHead(BaseDecodeHead): """Naive upsampling head and Progressive upsampling head of SETR. diff --git a/mmseg/models/decode_heads/stdc_head.py b/mmseg/models/decode_heads/stdc_head.py index 1cf3732ce..e18354ffb 100644 --- a/mmseg/models/decode_heads/stdc_head.py +++ b/mmseg/models/decode_heads/stdc_head.py @@ -2,11 +2,11 @@ import torch import torch.nn.functional as F -from ..builder import HEADS +from mmseg.registry import MODELS from .fcn_head import FCNHead -@HEADS.register_module() +@MODELS.register_module() class STDCHead(FCNHead): """This head is the implementation of `Rethinking BiSeNet For Real-time Semantic Segmentation `_. diff --git a/mmseg/models/decode_heads/uper_head.py b/mmseg/models/decode_heads/uper_head.py index 06b152a3f..347ef3292 100644 --- a/mmseg/models/decode_heads/uper_head.py +++ b/mmseg/models/decode_heads/uper_head.py @@ -4,12 +4,12 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.ops import resize -from ..builder import HEADS +from mmseg.registry import MODELS from .decode_head import BaseDecodeHead from .psp_head import PPM -@HEADS.register_module() +@MODELS.register_module() class UPerHead(BaseDecodeHead): """Unified Perceptual Parsing for Scene Understanding. 
diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py index 623fd58db..e607248fb 100644 --- a/mmseg/models/losses/cross_entropy_loss.py +++ b/mmseg/models/losses/cross_entropy_loss.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..builder import LOSSES +from mmseg.registry import MODELS from .utils import get_class_weight, weight_reduce_loss @@ -193,7 +193,7 @@ def mask_cross_entropy(pred, pred_slice, target, weight=class_weight, reduction='mean')[None] -@LOSSES.register_module() +@MODELS.register_module() class CrossEntropyLoss(nn.Module): """CrossEntropyLoss. diff --git a/mmseg/models/losses/dice_loss.py b/mmseg/models/losses/dice_loss.py index 79a3abfc2..4a98aaee9 100644 --- a/mmseg/models/losses/dice_loss.py +++ b/mmseg/models/losses/dice_loss.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..builder import LOSSES +from mmseg.registry import MODELS from .utils import get_class_weight, weighted_loss @@ -47,7 +47,7 @@ def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards): return 1 - num / den -@LOSSES.register_module() +@MODELS.register_module() class DiceLoss(nn.Module): """DiceLoss. diff --git a/mmseg/models/losses/focal_loss.py b/mmseg/models/losses/focal_loss.py index af1c711df..cd7eff4f6 100644 --- a/mmseg/models/losses/focal_loss.py +++ b/mmseg/models/losses/focal_loss.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss -from ..builder import LOSSES +from mmseg.registry import MODELS from .utils import weight_reduce_loss @@ -133,7 +133,7 @@ def sigmoid_focal_loss(pred, return loss -@LOSSES.register_module() +@MODELS.register_module() class FocalLoss(nn.Module): def __init__(self, diff --git a/mmseg/models/losses/lovasz_loss.py b/mmseg/models/losses/lovasz_loss.py index 2bb0fad39..457a23316 100644 --- a/mmseg/models/losses/lovasz_loss.py +++ b/mmseg/models/losses/lovasz_loss.py @@ -8,7 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..builder import LOSSES +from mmseg.registry import MODELS from .utils import get_class_weight, weight_reduce_loss @@ -222,7 +222,7 @@ def lovasz_softmax(probs, return loss -@LOSSES.register_module() +@MODELS.register_module() class LovaszLoss(nn.Module): """LovaszLoss. diff --git a/mmseg/models/necks/featurepyramid.py b/mmseg/models/necks/featurepyramid.py index 82a00ceb1..40453653d 100644 --- a/mmseg/models/necks/featurepyramid.py +++ b/mmseg/models/necks/featurepyramid.py @@ -2,10 +2,10 @@ import torch.nn as nn from mmcv.cnn import build_norm_layer -from ..builder import NECKS +from mmseg.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class Feature2Pyramid(nn.Module): """Feature2Pyramid. diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py index 6997de9d4..ee0e23240 100644 --- a/mmseg/models/necks/fpn.py +++ b/mmseg/models/necks/fpn.py @@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule from mmcv.runner import BaseModule, auto_fp16 from mmseg.ops import resize -from ..builder import NECKS +from mmseg.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class FPN(BaseModule): """Feature Pyramid Network. 
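With the former LOSSES and NECKS registries also folded into MODELS, losses and necks are built from plain config dicts through the same call. A small sketch, under the assumption that CrossEntropyLoss and FPN keep their existing constructor arguments:

from mmseg.registry import MODELS

# CrossEntropyLoss is registered on MODELS by the hunk above.
criterion = MODELS.build(
    dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))

# The same call builds a neck such as the FPN registered above.
neck = MODELS.build(
    dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=4))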
diff --git a/mmseg/models/necks/ic_neck.py b/mmseg/models/necks/ic_neck.py index a5d81cef8..973683c2b 100644 --- a/mmseg/models/necks/ic_neck.py +++ b/mmseg/models/necks/ic_neck.py @@ -4,7 +4,7 @@ from mmcv.cnn import ConvModule from mmcv.runner import BaseModule from mmseg.ops import resize -from ..builder import NECKS +from mmseg.registry import MODELS class CascadeFeatureFusion(BaseModule): @@ -77,7 +77,7 @@ class CascadeFeatureFusion(BaseModule): return x, x_low -@NECKS.register_module() +@MODELS.register_module() class ICNeck(BaseModule): """ICNet for Real-Time Semantic Segmentation on High-Resolution Images. diff --git a/mmseg/models/necks/jpu.py b/mmseg/models/necks/jpu.py index 3cc6b9f42..9de2435f9 100644 --- a/mmseg/models/necks/jpu.py +++ b/mmseg/models/necks/jpu.py @@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmcv.runner import BaseModule from mmseg.ops import resize -from ..builder import NECKS +from mmseg.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class JPU(BaseModule): """FastFCN: Rethinking Dilated Convolution in the Backbone for Semantic Segmentation. diff --git a/mmseg/models/necks/mla_neck.py b/mmseg/models/necks/mla_neck.py index 1513e296d..64a409239 100644 --- a/mmseg/models/necks/mla_neck.py +++ b/mmseg/models/necks/mla_neck.py @@ -2,7 +2,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule, build_norm_layer -from ..builder import NECKS +from mmseg.registry import MODELS class MLAModule(nn.Module): @@ -59,7 +59,7 @@ class MLAModule(nn.Module): return tuple(out_list) -@NECKS.register_module() +@MODELS.register_module() class MLANeck(nn.Module): """Multi-level Feature Aggregation. diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py index 5151f8762..14942bb63 100644 --- a/mmseg/models/necks/multilevel_neck.py +++ b/mmseg/models/necks/multilevel_neck.py @@ -3,10 +3,10 @@ import torch.nn as nn from mmcv.cnn import ConvModule, xavier_init from mmseg.ops import resize -from ..builder import NECKS +from mmseg.registry import MODELS -@NECKS.register_module() +@MODELS.register_module() class MultiLevelNeck(nn.Module): """MultiLevelNeck. diff --git a/mmseg/models/segmentors/cascade_encoder_decoder.py b/mmseg/models/segmentors/cascade_encoder_decoder.py index 1913a22e2..8e0676563 100644 --- a/mmseg/models/segmentors/cascade_encoder_decoder.py +++ b/mmseg/models/segmentors/cascade_encoder_decoder.py @@ -3,12 +3,11 @@ from torch import nn from mmseg.core import add_prefix from mmseg.ops import resize -from .. import builder -from ..builder import SEGMENTORS +from mmseg.registry import MODELS from .encoder_decoder import EncoderDecoder -@SEGMENTORS.register_module() +@MODELS.register_module() class CascadeEncoderDecoder(EncoderDecoder): """Cascade Encoder Decoder segmentors. 
@@ -44,7 +43,7 @@ class CascadeEncoderDecoder(EncoderDecoder): assert len(decode_head) == self.num_stages self.decode_head = nn.ModuleList() for i in range(self.num_stages): - self.decode_head.append(builder.build_head(decode_head[i])) + self.decode_head.append(MODELS.build(decode_head[i])) self.align_corners = self.decode_head[-1].align_corners self.num_classes = self.decode_head[-1].num_classes diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 72467b469..c22705f7e 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -5,12 +5,11 @@ import torch.nn.functional as F from mmseg.core import add_prefix from mmseg.ops import resize -from .. import builder -from ..builder import SEGMENTORS +from mmseg.registry import MODELS from .base import BaseSegmentor -@SEGMENTORS.register_module() +@MODELS.register_module() class EncoderDecoder(BaseSegmentor): """Encoder Decoder segmentors. @@ -33,9 +32,9 @@ class EncoderDecoder(BaseSegmentor): assert backbone.get('pretrained') is None, \ 'both backbone and segmentor set pretrained weight' backbone.pretrained = pretrained - self.backbone = builder.build_backbone(backbone) + self.backbone = MODELS.build(backbone) if neck is not None: - self.neck = builder.build_neck(neck) + self.neck = MODELS.build(neck) self._init_decode_head(decode_head) self._init_auxiliary_head(auxiliary_head) @@ -46,7 +45,7 @@ class EncoderDecoder(BaseSegmentor): def _init_decode_head(self, decode_head): """Initialize ``decode_head``""" - self.decode_head = builder.build_head(decode_head) + self.decode_head = MODELS.build(decode_head) self.align_corners = self.decode_head.align_corners self.num_classes = self.decode_head.num_classes @@ -56,9 +55,9 @@ class EncoderDecoder(BaseSegmentor): if isinstance(auxiliary_head, list): self.auxiliary_head = nn.ModuleList() for head_cfg in auxiliary_head: - self.auxiliary_head.append(builder.build_head(head_cfg)) + self.auxiliary_head.append(MODELS.build(head_cfg)) else: - self.auxiliary_head = builder.build_head(auxiliary_head) + self.auxiliary_head = MODELS.build(auxiliary_head) def extract_feat(self, img): """Extract features from images.""" diff --git a/mmseg/registry/__init__.py b/mmseg/registry/__init__.py new file mode 100644 index 000000000..3ea642d99 --- /dev/null +++ b/mmseg/registry/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, + MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, + OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, + RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS, + VISUALIZERS, WEIGHT_INITIALIZERS) + +__all__ = [ + 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', + 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', + 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS', + 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS' +] diff --git a/mmseg/registry/registry.py b/mmseg/registry/registry.py new file mode 100644 index 000000000..6e104e1da --- /dev/null +++ b/mmseg/registry/registry.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""MMSegmentation provides 17 registry nodes to support using modules across +projects. Each node is a child of the root registry in MMEngine. + +More details can be found at +https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. 
+""" + +from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS +from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOOPS as MMENGINE_LOOPS +from mmengine.registry import METRICS as MMENGINE_METRICS +from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS +from mmengine.registry import MODELS as MMENGINE_MODELS +from mmengine.registry import \ + OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS +from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS +from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS +from mmengine.registry import \ + RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS +from mmengine.registry import RUNNERS as MMENGINE_RUNNERS +from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS +from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS +from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS +from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS +from mmengine.registry import \ + WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS +from mmengine.registry import Registry + +# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner` +RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS) +# manage runner constructors that define how to initialize runners +RUNNER_CONSTRUCTORS = Registry( + 'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS) +# manage all kinds of loops like `EpochBasedTrainLoop` +LOOPS = Registry('loop', parent=MMENGINE_LOOPS) +# manage all kinds of hooks like `CheckpointHook` +HOOKS = Registry('hook', parent=MMENGINE_HOOKS) + +# manage data-related modules +DATASETS = Registry('dataset', parent=MMENGINE_DATASETS) +DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS) +TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS) + +# manage all kinds of modules inheriting `nn.Module` +MODELS = Registry('model', parent=MMENGINE_MODELS) +# manage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS) +# manage all kinds of weight initialization modules like `Uniform` +WEIGHT_INITIALIZERS = Registry( + 'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS) + +# manage all kinds of optimizers like `SGD` and `Adam` +OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS) +# manage constructors that customize the optimization hyperparameters. 
+OPTIMIZER_CONSTRUCTORS = Registry( + 'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS) +# manage all kinds of parameter schedulers like `MultiStepLR` +PARAM_SCHEDULERS = Registry( + 'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS) +# manage all kinds of metrics +METRICS = Registry('metric', parent=MMENGINE_METRICS) + +# manage task-specific modules like ohem pixel sampler +TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS) + +# manage visualizer +VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS) +# manage visualizer backend +VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS) diff --git a/tests/test_core/test_optimizer.py b/tests/test_core/test_optimizer.py index 247f9feb1..41c8fc252 100644 --- a/tests/test_core/test_optimizer.py +++ b/tests/test_core/test_optimizer.py @@ -2,10 +2,10 @@ import pytest import torch import torch.nn as nn -from mmcv.runner import DefaultOptimizerConstructor +from mmengine.optim import DefaultOptimizerConstructor -from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer, - build_optimizer_constructor) +from mmseg.core.builder import build_optimizer, build_optimizer_constructor +from mmseg.registry import OPTIMIZER_CONSTRUCTORS class ExampleModel(nn.Module): @@ -35,7 +35,7 @@ def test_build_optimizer_constructor(): # Test whether optimizer constructor can be built from parent. assert type(optim_constructor) is DefaultOptimizerConstructor - @OPTIMIZER_BUILDERS.register_module() + @OPTIMIZER_CONSTRUCTORS.register_module() class MyOptimizerConstructor(DefaultOptimizerConstructor): pass
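The updated test above registers a custom constructor on mmseg's OPTIMIZER_CONSTRUCTORS, which registry.py creates as a child of the MMEngine registry. A minimal sketch of the same flow outside the test suite; ToyOptimizerConstructor is hypothetical, and the sketch assumes DefaultOptimizerConstructor keeps its optimizer_cfg argument:

from mmengine.optim import DefaultOptimizerConstructor

from mmseg.core.builder import build_optimizer_constructor
from mmseg.registry import OPTIMIZER_CONSTRUCTORS


@OPTIMIZER_CONSTRUCTORS.register_module()
class ToyOptimizerConstructor(DefaultOptimizerConstructor):
    """Hypothetical constructor used only to illustrate registration."""


# build_optimizer_constructor() resolves ``type`` against mmseg's
# OPTIMIZER_CONSTRUCTORS and instantiates it with the remaining config keys.
constructor = build_optimizer_constructor(
    dict(
        type='ToyOptimizerConstructor',
        optimizer_cfg=dict(type='SGD', lr=0.01, momentum=0.9)))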