Merge branch 'linfangjian/refactor_registies' into 'refactor_dev'

[Refactor] Refactor all registries

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!3
pull/1801/head
zhengmiao 2022-05-10 12:15:20 +00:00
commit 168c8831dd
94 changed files with 312 additions and 229 deletions

View File

@ -1,11 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
build_optimizer_constructor)
from .builder import build_optimizer, build_optimizer_constructor
from .evaluation import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
__all__ = [
'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
]
__all__ = ['build_optimizer', 'build_optimizer_constructor']

View File

@ -1,19 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
from mmcv.utils import Registry, build_from_cfg
OPTIMIZER_BUILDERS = Registry(
'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
from mmseg.registry import OPTIMIZER_CONSTRUCTORS
def build_optimizer_constructor(cfg):
constructor_type = cfg.get('type')
if constructor_type in OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
if constructor_type in OPTIMIZER_CONSTRUCTORS:
return OPTIMIZER_CONSTRUCTORS.build(cfg)
else:
raise KeyError(f'{constructor_type} is not registered '
'in the optimizer builder registry.')

View File

@ -2,10 +2,11 @@
import json
import warnings
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
from mmengine.dist import get_dist_info
from mmengine.optim import DefaultOptimizerConstructor
from mmseg.registry import OPTIMIZER_CONSTRUCTORS
from mmseg.utils import get_root_logger
from ..builder import OPTIMIZER_BUILDERS
def get_layer_id_for_convnext(var_name, max_layer_id):
@ -99,7 +100,7 @@ def get_layer_id_for_vit(var_name, max_layer_id):
return max_layer_id - 1
@OPTIMIZER_BUILDERS.register_module()
@OPTIMIZER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
"""Different learning rates are set for different layers of backbone.
@ -185,7 +186,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
params.extend(parameter_groups.values())
@OPTIMIZER_BUILDERS.register_module()
@OPTIMIZER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
"""Different learning rates are set for different layers of backbone.

View File

@ -1,9 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
import warnings
PIXEL_SAMPLERS = Registry('pixel sampler')
from mmseg.registry import TASK_UTILS
PIXEL_SAMPLERS = TASK_UTILS
def build_pixel_sampler(cfg, **default_args):
"""Build pixel sampler for segmentation map."""
return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
warnings.warn(
'``build_pixel_sampler`` would be deprecated soon, please use '
'``mmseg.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)

View File

@ -5,7 +5,7 @@ import mmcv
import numpy as np
from PIL import Image
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -8,9 +8,10 @@ import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version
from mmcv.utils import digit_version
from torch.utils.data import DataLoader
from mmseg.registry import DATASETS, TRANSFORMS
from .samplers import DistributedSampler
if platform.system() != 'Windows':
@ -22,8 +23,7 @@ if platform.system() != 'Windows':
soft_limit = min(max(4096, base_soft_limit), hard_limit)
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
DATASETS = Registry('dataset')
PIPELINES = Registry('pipeline')
PIPELINES = TRANSFORMS
def _concat_dataset(cfg, default_args=None):
@ -82,7 +82,7 @@ def build_dataset(cfg, default_args=None):
cfg.get('split', None), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)
dataset = DATASETS.build(cfg, default_args=default_args)
return dataset

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -6,7 +6,7 @@ import numpy as np
from mmcv.utils import print_log
from PIL import Image
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -10,8 +10,8 @@ from prettytable import PrettyTable
from torch.utils.data import Dataset
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.registry import DATASETS
from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .cityscapes import CityscapesDataset

View File

@ -6,10 +6,10 @@ from itertools import chain
import mmcv
import numpy as np
from mmcv.utils import build_from_cfg, print_log
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS, PIPELINES
from mmseg.registry import DATASETS, TRANSFORMS
from .cityscapes import CityscapesDataset
@ -225,7 +225,7 @@ class MultiImageMixDataset:
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform['type'])
transform = build_from_cfg(transform, PIPELINES)
transform = TRANSFORMS.build(transform)
self.pipeline.append(transform)
else:
raise TypeError('pipeline must be a dict')

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -3,8 +3,8 @@
import mmcv
from mmcv.utils import print_log
from mmseg.registry import DATASETS
from ..utils import get_root_logger
from .builder import DATASETS
from .custom import CustomDataset

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -5,7 +5,7 @@ import mmcv
import numpy as np
from PIL import Image
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .cityscapes import CityscapesDataset

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -1,12 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import collections
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
from mmseg.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Compose(object):
"""Compose multiple transforms sequentially.
@ -20,7 +18,7 @@ class Compose(object):
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
transform = TRANSFORMS.build(transform)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)

View File

@ -6,7 +6,7 @@ import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from ..builder import PIPELINES
from mmseg.registry import TRANSFORMS
def to_tensor(data):
@ -34,7 +34,7 @@ def to_tensor(data):
raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ToTensor(object):
"""Convert some results to :obj:`torch.Tensor` by given keys.
@ -64,7 +64,7 @@ class ToTensor(object):
return self.__class__.__name__ + f'(keys={self.keys})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ImageToTensor(object):
"""Convert image to :obj:`torch.Tensor` by given keys.
@ -102,7 +102,7 @@ class ImageToTensor(object):
return self.__class__.__name__ + f'(keys={self.keys})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Transpose(object):
"""Transpose some results by given keys.
@ -136,7 +136,7 @@ class Transpose(object):
f'(keys={self.keys}, order={self.order})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ToDataContainer(object):
"""Convert results to :obj:`mmcv.DataContainer` by given fields.
@ -175,7 +175,7 @@ class ToDataContainer(object):
return self.__class__.__name__ + f'(fields={self.fields})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class DefaultFormatBundle(object):
"""Default formatting bundle.
@ -216,7 +216,7 @@ class DefaultFormatBundle(object):
return self.__class__.__name__
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Collect(object):
"""Collect data from the loader relevant to the specific task.

View File

@ -4,10 +4,10 @@ import os.path as osp
import mmcv
import numpy as np
from ..builder import PIPELINES
from mmseg.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class LoadImageFromFile(object):
"""Load an image from file.
@ -87,7 +87,7 @@ class LoadImageFromFile(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class LoadAnnotations(object):
"""Load annotations for semantic segmentation.

View File

@ -3,11 +3,11 @@ import warnings
import mmcv
from ..builder import PIPELINES
from mmseg.registry import TRANSFORMS
from .compose import Compose
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class MultiScaleFlipAug(object):
"""Test-time augmentation with multiple scales and flipping.

View File

@ -6,10 +6,10 @@ import numpy as np
from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random
from ..builder import PIPELINES
from mmseg.registry import TRANSFORMS
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class ResizeToMultiple(object):
"""Resize images & seg to multiple of divisor.
@ -66,7 +66,7 @@ class ResizeToMultiple(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Resize(object):
"""Resize images & seg.
@ -321,7 +321,7 @@ class Resize(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomFlip(object):
"""Flip the image & seg.
@ -376,7 +376,7 @@ class RandomFlip(object):
return self.__class__.__name__ + f'(prob={self.prob})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Pad(object):
"""Pad the image & mask.
@ -447,7 +447,7 @@ class Pad(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Normalize(object):
"""Normalize the image.
@ -489,7 +489,7 @@ class Normalize(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class Rerange(object):
"""Rerange the image pixel value.
@ -535,7 +535,7 @@ class Rerange(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class CLAHE(object):
"""Use CLAHE method to process the image.
@ -580,7 +580,7 @@ class CLAHE(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomCrop(object):
"""Random crop the image & seg.
@ -653,7 +653,7 @@ class RandomCrop(object):
return self.__class__.__name__ + f'(crop_size={self.crop_size})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomRotate(object):
"""Rotate the image & seg.
@ -736,7 +736,7 @@ class RandomRotate(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RGB2Gray(object):
"""Convert RGB image to grayscale image.
@ -791,7 +791,7 @@ class RGB2Gray(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class AdjustGamma(object):
"""Using gamma correction to process the image.
@ -827,7 +827,7 @@ class AdjustGamma(object):
return self.__class__.__name__ + f'(gamma={self.gamma})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class SegRescale(object):
"""Rescale semantic segmentation maps.
@ -857,7 +857,7 @@ class SegRescale(object):
return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class PhotoMetricDistortion(object):
"""Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in
@ -976,7 +976,7 @@ class PhotoMetricDistortion(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomCutOut(object):
"""CutOut operation.
@ -1068,7 +1068,7 @@ class RandomCutOut(object):
return repr_str
@PIPELINES.register_module()
@TRANSFORMS.register_module()
class RandomMosaic(object):
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
one output image. The output image is composed of the parts from each sub-

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -7,8 +7,10 @@ from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _DistributedSampler
from mmseg.core.utils import sync_random_seed
from mmseg.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class DistributedSampler(_DistributedSampler):
"""DistributedSampler inheriting from
`torch.utils.data.DistributedSampler`.

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from .builder import DATASETS
from mmseg.registry import DATASETS
from .custom import CustomDataset

View File

@ -13,8 +13,8 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
@ -227,7 +227,7 @@ class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
return x
@BACKBONES.register_module()
@MODELS.register_module()
class BEiT(BaseModule):
"""BERT Pre-Training of Image Transformers.

View File

@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone
from mmseg.registry import MODELS
class SpatialPath(BaseModule):
@ -156,7 +156,7 @@ class ContextPath(BaseModule):
assert len(context_channels) == 3, 'Length of input channels \
of Context Path must be 3!'
self.backbone = build_backbone(backbone_cfg)
self.backbone = MODELS.build(backbone_cfg)
self.align_corners = align_corners
self.arm16 = AttentionRefinementModule(context_channels[1],
@ -262,7 +262,7 @@ class FeatureFusionModule(BaseModule):
return x_out
@BACKBONES.register_module()
@MODELS.register_module()
class BiSeNetV1(BaseModule):
"""BiSeNetV1 backbone.

View File

@ -6,7 +6,7 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import BACKBONES
from mmseg.registry import MODELS
class DetailBranch(BaseModule):
@ -541,7 +541,7 @@ class BGALayer(BaseModule):
return output
@BACKBONES.register_module()
@MODELS.register_module()
class BiSeNetV2(BaseModule):
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
Real-time Semantic Segmentation.

View File

@ -8,7 +8,7 @@ from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from mmseg.registry import MODELS
class GlobalContextExtractor(nn.Module):
@ -183,7 +183,7 @@ class InputInjection(nn.Module):
return x
@BACKBONES.register_module()
@MODELS.register_module()
class CGNet(BaseModule):
"""CGNet backbone.

View File

@ -5,7 +5,7 @@ from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import BACKBONES
from mmseg.registry import MODELS
class DownsamplerBlock(BaseModule):
@ -191,7 +191,7 @@ class UpsamplerBlock(BaseModule):
return output
@BACKBONES.register_module()
@MODELS.register_module()
class ERFNet(BaseModule):
"""ERFNet backbone.

View File

@ -6,7 +6,7 @@ from mmcv.runner import BaseModule
from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import InvertedResidual
@ -268,7 +268,7 @@ class FeatureFusionModule(nn.Module):
return self.relu(out)
@BACKBONES.register_module()
@MODELS.register_module()
class FastSCNN(BaseModule):
"""Fast-SCNN Backbone.

View File

@ -7,7 +7,7 @@ from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample, resize
from ..builder import BACKBONES
from mmseg.registry import MODELS
from .resnet import BasicBlock, Bottleneck
@ -214,7 +214,7 @@ class HRModule(BaseModule):
return x_fuse
@BACKBONES.register_module()
@MODELS.register_module()
class HRNet(BaseModule):
"""HRNet backbone.

View File

@ -5,11 +5,11 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone
from mmseg.registry import MODELS
from ..decode_heads.psp_head import PPM
@BACKBONES.register_module()
@MODELS.register_module()
class ICNet(BaseModule):
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
@ -66,7 +66,7 @@ class ICNet(BaseModule):
]
super(ICNet, self).__init__(init_cfg=init_cfg)
self.align_corners = align_corners
self.backbone = build_backbone(backbone_cfg)
self.backbone = MODELS.build(backbone_cfg)
# Note: Default `ceil_mode` is false in nn.MaxPool2d, set
# `ceil_mode=True` to keep information in the corner of feature map.

View File

@ -8,8 +8,8 @@ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
from mmcv.runner import ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
@ -42,7 +42,7 @@ class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
self.attn = MAEAttention(**attn_cfg)
@BACKBONES.register_module()
@MODELS.register_module()
class MAE(BEiT):
"""VisionTransformer with support for patch.

View File

@ -12,7 +12,7 @@ from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmcv.runner import BaseModule, ModuleList, Sequential
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
@ -295,7 +295,7 @@ class TransformerEncoderLayer(BaseModule):
return x
@BACKBONES.register_module()
@MODELS.register_module()
class MixVisionTransformer(BaseModule):
"""The backbone of Segformer.

View File

@ -6,11 +6,11 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import InvertedResidual, make_divisible
@BACKBONES.register_module()
@MODELS.register_module()
class MobileNetV2(BaseModule):
"""MobileNetV2 backbone.

View File

@ -7,11 +7,11 @@ from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import InvertedResidualV3 as InvertedResidual
@BACKBONES.register_module()
@MODELS.register_module()
class MobileNetV3(BaseModule):
"""MobileNetV3 backbone.

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d
@ -267,7 +267,7 @@ class Bottleneck(_Bottleneck):
return out
@BACKBONES.register_module()
@MODELS.register_module()
class ResNeSt(ResNetV1d):
"""ResNeSt backbone.

View File

@ -7,7 +7,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import ResLayer
@ -307,7 +307,7 @@ class Bottleneck(BaseModule):
return out
@BACKBONES.register_module()
@MODELS.register_module()
class ResNet(BaseModule):
"""ResNet backbone.
@ -685,7 +685,7 @@ class ResNet(BaseModule):
m.eval()
@BACKBONES.register_module()
@MODELS.register_module()
class ResNetV1c(ResNet):
"""ResNetV1c variant described in [1]_.
@ -700,7 +700,7 @@ class ResNetV1c(ResNet):
deep_stem=True, avg_down=False, **kwargs)
@BACKBONES.register_module()
@MODELS.register_module()
class ResNetV1d(ResNet):
"""ResNetV1d variant described in [1]_.

View File

@ -3,7 +3,7 @@ import math
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
@ -84,7 +84,7 @@ class Bottleneck(_Bottleneck):
self.add_module(self.norm3_name, norm3)
@BACKBONES.register_module()
@MODELS.register_module()
class ResNeXt(ResNet):
"""ResNeXt backbone.

View File

@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone
from mmseg.registry import MODELS
from .bisenetv1 import AttentionRefinementModule
@ -184,7 +184,7 @@ class FeatureFusionModule(BaseModule):
return x_attn + x
@BACKBONES.register_module()
@MODELS.register_module()
class STDCNet(BaseModule):
"""This backbone is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
@ -325,7 +325,7 @@ class STDCNet(BaseModule):
return tuple(outs)
@BACKBONES.register_module()
@MODELS.register_module()
class STDCContextPathNet(BaseModule):
"""STDCNet with Context Path. The `outs` below is a list of three feature
maps from deep to shallow, whose height and width is from small to big,
@ -371,7 +371,7 @@ class STDCContextPathNet(BaseModule):
norm_cfg=dict(type='BN'),
init_cfg=None):
super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
self.backbone = build_backbone(backbone_cfg)
self.backbone = MODELS.build(backbone_cfg)
self.arms = ModuleList()
self.convs = ModuleList()
for channels in last_in_channels:

View File

@ -15,8 +15,8 @@ from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
load_state_dict)
from mmcv.utils import to_2tuple
from mmseg.registry import MODELS
from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils.embed import PatchEmbed, PatchMerging
@ -462,7 +462,7 @@ class SwinBlockSequence(BaseModule):
return x, hw_shape, x, hw_shape
@BACKBONES.register_module()
@MODELS.register_module()
class SwinTransformer(BaseModule):
"""Swin Transformer backbone.

View File

@ -7,10 +7,10 @@ except ImportError:
from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmcv.runner import BaseModule
from ..builder import BACKBONES
from mmseg.registry import MODELS
@BACKBONES.register_module()
@MODELS.register_module()
class TIMMBackbone(BaseModule):
"""Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ .

View File

@ -14,7 +14,7 @@ from mmcv.runner import BaseModule, ModuleList
from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.backbones.mit import EfficientMultiheadAttention
from mmseg.models.builder import BACKBONES
from mmseg.registry import MODELS
from ..utils.embed import PatchEmbed
@ -349,7 +349,7 @@ class ConditionalPositionEncoding(BaseModule):
return x
@BACKBONES.register_module()
@MODELS.register_module()
class PCPVT(BaseModule):
"""The backbone of Twins-PCPVT.
@ -508,7 +508,7 @@ class PCPVT(BaseModule):
return tuple(outputs)
@BACKBONES.register_module()
@MODELS.register_module()
class SVT(PCPVT):
"""The backbone of Twins-SVT.

View File

@ -9,7 +9,7 @@ from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample
from ..builder import BACKBONES
from mmseg.registry import MODELS
from ..utils import UpConvBlock
@ -221,7 +221,7 @@ class InterpConv(nn.Module):
return out
@BACKBONES.register_module()
@MODELS.register_module()
class UNet(BaseModule):
"""UNet backbone.

View File

@ -15,8 +15,8 @@ from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed
@ -122,7 +122,7 @@ class TransformerEncoderLayer(BaseModule):
return x
@BACKBONES.register_module()
@MODELS.register_module()
class VisionTransformer(BaseModule):
"""Vision Transformer.

View File

@ -1,12 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)
ATTENTION = Registry('attention', parent=MMCV_ATTENTION)
from mmseg.registry import MODELS
BACKBONES = MODELS
NECKS = MODELS
@ -17,21 +12,29 @@ SEGMENTORS = MODELS
def build_backbone(cfg):
"""Build backbone."""
warnings.warn('``build_backbone`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return BACKBONES.build(cfg)
def build_neck(cfg):
"""Build neck."""
warnings.warn('``build_neck`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return NECKS.build(cfg)
def build_head(cfg):
"""Build head."""
warnings.warn('``build_head`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return HEADS.build(cfg)
def build_loss(cfg):
"""Build loss."""
warnings.warn('``build_loss`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return LOSSES.build(cfg)

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from ..builder import HEADS
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
@ -181,7 +181,7 @@ class APNB(nn.Module):
return output
@HEADS.register_module()
@MODELS.register_module()
class ANNHead(BaseDecodeHead):
"""Asymmetric Non-local Neural Networks for Semantic Segmentation.

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@ -107,7 +107,7 @@ class ACM(nn.Module):
return z_out
@HEADS.register_module()
@MODELS.register_module()
class APCHead(BaseDecodeHead):
"""Adaptive Pyramid Context Network for Semantic Segmentation.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@ -50,7 +50,7 @@ class ASPPModule(nn.ModuleList):
return aspp_outs
@HEADS.register_module()
@MODELS.register_module()
class ASPPHead(BaseDecodeHead):
"""Rethinking Atrous Convolution for Semantic Image Segmentation.

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..builder import HEADS
from mmseg.registry import MODELS
from .fcn_head import FCNHead
try:
@ -10,7 +10,7 @@ except ModuleNotFoundError:
CrissCrossAttention = None
@HEADS.register_module()
@MODELS.register_module()
class CCHead(FCNHead):
"""CCNet: Criss-Cross Attention for Semantic Segmentation.

View File

@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule, Scale
from torch import nn
from mmseg.core import add_prefix
from ..builder import HEADS
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
@ -72,7 +72,7 @@ class CAM(nn.Module):
return out
@HEADS.register_module()
@MODELS.register_module()
class DAHead(BaseDecodeHead):
"""Dual Attention Network for Scene Segmentation.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@ -89,7 +89,7 @@ class DCM(nn.Module):
return output
@HEADS.register_module()
@MODELS.register_module()
class DMHead(BaseDecodeHead):
"""Dynamic Multi-scale Filters for Semantic Segmentation.

View File

@ -3,7 +3,7 @@ import torch
from mmcv.cnn import NonLocal2d
from torch import nn
from ..builder import HEADS
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@ -89,7 +89,7 @@ class DisentangledNonLocal2d(NonLocal2d):
return output
@HEADS.register_module()
@MODELS.register_module()
class DNLHead(FCNHead):
"""Disentangled Non-Local Neural Networks.

View File

@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@ -212,7 +212,7 @@ class FeatureFusionBlock(BaseModule):
return x
@HEADS.register_module()
@MODELS.register_module()
class DPTHead(BaseDecodeHead):
"""Vision Transformers for Dense Prediction.

View File

@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@ -76,7 +76,7 @@ class EMAModule(nn.Module):
return feats_recon
@HEADS.register_module()
@MODELS.register_module()
class EMAHead(BaseDecodeHead):
"""Expectation Maximization Attention Networks for Semantic Segmentation.

View File

@ -5,7 +5,8 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from mmseg.registry import MODELS
from ..builder import build_loss
from .decode_head import BaseDecodeHead
@ -59,7 +60,7 @@ class EncModule(nn.Module):
return encoding_feat, output
@HEADS.register_module()
@MODELS.register_module()
class EncHead(BaseDecodeHead):
"""Context Encoding for Semantic Segmentation.

View File

@ -3,11 +3,11 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class FCNHead(BaseDecodeHead):
"""Fully Convolution Networks for Semantic Segmentation.

View File

@ -4,11 +4,11 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import Upsample, resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class FPNHead(BaseDecodeHead):
"""Panoptic Feature Pyramid Networks.

View File

@ -2,11 +2,11 @@
import torch
from mmcv.cnn import ContextBlock
from ..builder import HEADS
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@HEADS.register_module()
@MODELS.register_module()
class GCHead(FCNHead):
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from ..builder import HEADS
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead
@ -55,7 +55,7 @@ class SelfAttentionBlock(_SelfAttentionBlock):
return self.output_project(context)
@HEADS.register_module()
@MODELS.register_module()
class ISAHead(BaseDecodeHead):
"""Interlaced Sparse Self-Attention for Semantic Segmentation.

View File

@ -7,8 +7,8 @@ from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
MultiheadAttention,
build_transformer_layer)
from mmseg.models.builder import HEADS, build_head
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger
@ -139,7 +139,7 @@ class KernelUpdator(nn.Module):
return features
@HEADS.register_module()
@MODELS.register_module()
class KernelUpdateHead(nn.Module):
"""Kernel Update Head in K-Net.
@ -391,7 +391,7 @@ class KernelUpdateHead(nn.Module):
self.conv_kernel_size)
@HEADS.register_module()
@MODELS.register_module()
class IterativeDecodeHead(BaseDecodeHead):
"""K-Net: Towards Unified Image Segmentation.
@ -414,7 +414,7 @@ class IterativeDecodeHead(BaseDecodeHead):
super(BaseDecodeHead, self).__init__(**kwargs)
assert num_stages == len(kernel_update_head)
self.num_stages = num_stages
self.kernel_generate_head = build_head(kernel_generate_head)
self.kernel_generate_head = (kernel_generate_head)
self.kernel_update_head = nn.ModuleList()
self.align_corners = self.kernel_generate_head.align_corners
self.num_classes = self.kernel_generate_head.num_classes
@ -422,7 +422,7 @@ class IterativeDecodeHead(BaseDecodeHead):
self.ignore_index = self.kernel_generate_head.ignore_index
for head_cfg in kernel_update_head:
self.kernel_update_head.append(build_head(head_cfg))
self.kernel_update_head.append(MODELS.build(head_cfg))
def forward(self, inputs):
"""Forward function."""

View File

@ -5,11 +5,11 @@ from mmcv import is_tuple_of
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class LRASPPHead(BaseDecodeHead):
"""Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

View File

@ -2,11 +2,11 @@
import torch
from mmcv.cnn import NonLocal2d
from ..builder import HEADS
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@HEADS.register_module()
@MODELS.register_module()
class NLHead(FCNHead):
"""Non-local Neural Networks.

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .cascade_decode_head import BaseCascadeDecodeHead
@ -82,7 +82,7 @@ class ObjectAttentionBlock(_SelfAttentionBlock):
return output
@HEADS.register_module()
@MODELS.register_module()
class OCRHead(BaseCascadeDecodeHead):
"""Object-Contextual Representations for Semantic Segmentation.

View File

@ -10,8 +10,8 @@ try:
except ModuleNotFoundError:
point_sample = None
from mmseg.models.builder import HEADS
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..losses import accuracy
from .cascade_decode_head import BaseCascadeDecodeHead
@ -36,7 +36,7 @@ def calculate_uncertainty(seg_logits):
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@HEADS.register_module()
@MODELS.register_module()
class PointHead(BaseCascadeDecodeHead):
"""A mask point head use in PointRend.

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
try:
@ -14,7 +14,7 @@ except ModuleNotFoundError:
PSAMask = None
@HEADS.register_module()
@MODELS.register_module()
class PSAHead(BaseDecodeHead):
"""Point-wise Spatial Attention Network for Scene Parsing.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@ -59,7 +59,7 @@ class PPM(nn.ModuleList):
return ppm_outs
@HEADS.register_module()
@MODELS.register_module()
class PSPHead(BaseDecodeHead):
"""Pyramid Scene Parsing Network.

View File

@ -3,12 +3,12 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize
from mmseg.registry import MODELS
@HEADS.register_module()
@MODELS.register_module()
class SegformerHead(BaseDecodeHead):
"""The all mlp Head of segformer.

View File

@ -8,11 +8,11 @@ from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
from mmcv.runner import ModuleList
from mmseg.models.backbones.vit import TransformerEncoderLayer
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead):
"""Segmenter: Transformer for Semantic Segmentation.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .aspp_head import ASPPHead, ASPPModule
@ -26,7 +26,7 @@ class DepthwiseSeparableASPPModule(ASPPModule):
act_cfg=self.act_cfg)
@HEADS.register_module()
@MODELS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead):
"""Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import DepthwiseSeparableConvModule
from ..builder import HEADS
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@HEADS.register_module()
@MODELS.register_module()
class DepthwiseSeparableFCNHead(FCNHead):
"""Depthwise-Separable Fully Convolutional Network for Semantic
Segmentation.

View File

@ -4,11 +4,11 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import Upsample
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class SETRMLAHead(BaseDecodeHead):
"""Multi level feature aggretation head of SETR.

View File

@ -3,11 +3,11 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Upsample
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
@HEADS.register_module()
@MODELS.register_module()
class SETRUPHead(BaseDecodeHead):
"""Naive upsampling head and Progressive upsampling head of SETR.

View File

@ -2,11 +2,11 @@
import torch
import torch.nn.functional as F
from ..builder import HEADS
from mmseg.registry import MODELS
from .fcn_head import FCNHead
@HEADS.register_module()
@MODELS.register_module()
class STDCHead(FCNHead):
"""This head is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

View File

@ -4,12 +4,12 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from ..builder import HEADS
from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead
from .psp_head import PPM
@HEADS.register_module()
@MODELS.register_module()
class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss
@ -193,7 +193,7 @@ def mask_cross_entropy(pred,
pred_slice, target, weight=class_weight, reduction='mean')[None]
@LOSSES.register_module()
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):
"""CrossEntropyLoss.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss
@ -47,7 +47,7 @@ def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
return 1 - num / den
@LOSSES.register_module()
@MODELS.register_module()
class DiceLoss(nn.Module):
"""DiceLoss.

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from ..builder import LOSSES
from mmseg.registry import MODELS
from .utils import weight_reduce_loss
@ -133,7 +133,7 @@ def sigmoid_focal_loss(pred,
return loss
@LOSSES.register_module()
@MODELS.register_module()
class FocalLoss(nn.Module):
def __init__(self,

View File

@ -8,7 +8,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss
@ -222,7 +222,7 @@ def lovasz_softmax(probs,
return loss
@LOSSES.register_module()
@MODELS.register_module()
class LovaszLoss(nn.Module):
"""LovaszLoss.

View File

@ -2,10 +2,10 @@
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from ..builder import NECKS
from mmseg.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class Feature2Pyramid(nn.Module):
"""Feature2Pyramid.

View File

@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from mmseg.ops import resize
from ..builder import NECKS
from mmseg.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class FPN(BaseModule):
"""Feature Pyramid Network.

View File

@ -4,7 +4,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import NECKS
from mmseg.registry import MODELS
class CascadeFeatureFusion(BaseModule):
@ -77,7 +77,7 @@ class CascadeFeatureFusion(BaseModule):
return x, x_low
@NECKS.register_module()
@MODELS.register_module()
class ICNeck(BaseModule):
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.

View File

@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import NECKS
from mmseg.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class JPU(BaseModule):
"""FastFCN: Rethinking Dilated Convolution in the Backbone
for Semantic Segmentation.

View File

@ -2,7 +2,7 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from ..builder import NECKS
from mmseg.registry import MODELS
class MLAModule(nn.Module):
@ -59,7 +59,7 @@ class MLAModule(nn.Module):
return tuple(out_list)
@NECKS.register_module()
@MODELS.register_module()
class MLANeck(nn.Module):
"""Multi-level Feature Aggregation.

View File

@ -3,10 +3,10 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, xavier_init
from mmseg.ops import resize
from ..builder import NECKS
from mmseg.registry import MODELS
@NECKS.register_module()
@MODELS.register_module()
class MultiLevelNeck(nn.Module):
"""MultiLevelNeck.

View File

@ -3,12 +3,11 @@ from torch import nn
from mmseg.core import add_prefix
from mmseg.ops import resize
from .. import builder
from ..builder import SEGMENTORS
from mmseg.registry import MODELS
from .encoder_decoder import EncoderDecoder
@SEGMENTORS.register_module()
@MODELS.register_module()
class CascadeEncoderDecoder(EncoderDecoder):
"""Cascade Encoder Decoder segmentors.
@ -44,7 +43,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
assert len(decode_head) == self.num_stages
self.decode_head = nn.ModuleList()
for i in range(self.num_stages):
self.decode_head.append(builder.build_head(decode_head[i]))
self.decode_head.append(MODELS.build(decode_head[i]))
self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes

View File

@ -5,12 +5,11 @@ import torch.nn.functional as F
from mmseg.core import add_prefix
from mmseg.ops import resize
from .. import builder
from ..builder import SEGMENTORS
from mmseg.registry import MODELS
from .base import BaseSegmentor
@SEGMENTORS.register_module()
@MODELS.register_module()
class EncoderDecoder(BaseSegmentor):
"""Encoder Decoder segmentors.
@ -33,9 +32,9 @@ class EncoderDecoder(BaseSegmentor):
assert backbone.get('pretrained') is None, \
'both backbone and segmentor set pretrained weight'
backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone)
self.backbone = MODELS.build(backbone)
if neck is not None:
self.neck = builder.build_neck(neck)
self.neck = MODELS.build(neck)
self._init_decode_head(decode_head)
self._init_auxiliary_head(auxiliary_head)
@ -46,7 +45,7 @@ class EncoderDecoder(BaseSegmentor):
def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
self.decode_head = builder.build_head(decode_head)
self.decode_head = MODELS.build(decode_head)
self.align_corners = self.decode_head.align_corners
self.num_classes = self.decode_head.num_classes
@ -56,9 +55,9 @@ class EncoderDecoder(BaseSegmentor):
if isinstance(auxiliary_head, list):
self.auxiliary_head = nn.ModuleList()
for head_cfg in auxiliary_head:
self.auxiliary_head.append(builder.build_head(head_cfg))
self.auxiliary_head.append(MODELS.build(head_cfg))
else:
self.auxiliary_head = builder.build_head(auxiliary_head)
self.auxiliary_head = MODELS.build(auxiliary_head)
def extract_feat(self, img):
"""Extract features from images."""

View File

@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
VISUALIZERS, WEIGHT_INITIALIZERS)
__all__ = [
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
]

View File

@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""MMSegmentation provides 17 registry nodes to support using modules across
projects. Each node is a child of the root registry in MMEngine.
More details can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import HOOKS as MMENGINE_HOOKS
from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import METRICS as MMENGINE_METRICS
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
from mmengine.registry import MODELS as MMENGINE_MODELS
from mmengine.registry import \
OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
from mmengine.registry import \
RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
from mmengine.registry import \
WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
from mmengine.registry import Registry
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
# mangage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
# mangage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
# mangage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
# manage constructors that customize the optimization hyperparameters.
OPTIMIZER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
# manage task-specific modules like ohem pixel sampler
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)

View File

@ -2,10 +2,10 @@
import pytest
import torch
import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor
from mmengine.optim import DefaultOptimizerConstructor
from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
build_optimizer_constructor)
from mmseg.core.builder import build_optimizer, build_optimizer_constructor
from mmseg.registry import OPTIMIZER_CONSTRUCTORS
class ExampleModel(nn.Module):
@ -35,7 +35,7 @@ def test_build_optimizer_constructor():
# Test whether optimizer constructor can be built from parent.
assert type(optim_constructor) is DefaultOptimizerConstructor
@OPTIMIZER_BUILDERS.register_module()
@OPTIMIZER_CONSTRUCTORS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
pass