Merge branch 'linfangjian/refactor_registies' into 'refactor_dev'

[Refactor] Refactor all registries

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!3
pull/1801/head
zhengmiao 2022-05-10 12:15:20 +00:00
commit 168c8831dd
94 changed files with 312 additions and 229 deletions

View File

@ -1,11 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import (OPTIMIZER_BUILDERS, build_optimizer, from .builder import build_optimizer, build_optimizer_constructor
build_optimizer_constructor)
from .evaluation import * # noqa: F401, F403 from .evaluation import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403 from .optimizers import * # noqa: F401, F403
from .seg import * # noqa: F401, F403 from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403 from .utils import * # noqa: F401, F403
__all__ = [ __all__ = ['build_optimizer', 'build_optimizer_constructor']
'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
]

View File

@ -1,19 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS from mmseg.registry import OPTIMIZER_CONSTRUCTORS
from mmcv.utils import Registry, build_from_cfg
OPTIMIZER_BUILDERS = Registry(
'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
def build_optimizer_constructor(cfg): def build_optimizer_constructor(cfg):
constructor_type = cfg.get('type') constructor_type = cfg.get('type')
if constructor_type in OPTIMIZER_BUILDERS: if constructor_type in OPTIMIZER_CONSTRUCTORS:
return build_from_cfg(cfg, OPTIMIZER_BUILDERS) return OPTIMIZER_CONSTRUCTORS.build(cfg)
elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
else: else:
raise KeyError(f'{constructor_type} is not registered ' raise KeyError(f'{constructor_type} is not registered '
'in the optimizer builder registry.') 'in the optimizer builder registry.')

View File

@ -2,10 +2,11 @@
import json import json
import warnings import warnings
from mmcv.runner import DefaultOptimizerConstructor, get_dist_info from mmengine.dist import get_dist_info
from mmengine.optim import DefaultOptimizerConstructor
from mmseg.registry import OPTIMIZER_CONSTRUCTORS
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from ..builder import OPTIMIZER_BUILDERS
def get_layer_id_for_convnext(var_name, max_layer_id): def get_layer_id_for_convnext(var_name, max_layer_id):
@ -99,7 +100,7 @@ def get_layer_id_for_vit(var_name, max_layer_id):
return max_layer_id - 1 return max_layer_id - 1
@OPTIMIZER_BUILDERS.register_module() @OPTIMIZER_CONSTRUCTORS.register_module()
class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
"""Different learning rates are set for different layers of backbone. """Different learning rates are set for different layers of backbone.
@ -185,7 +186,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
params.extend(parameter_groups.values()) params.extend(parameter_groups.values())
@OPTIMIZER_BUILDERS.register_module() @OPTIMIZER_CONSTRUCTORS.register_module()
class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor): class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
"""Different learning rates are set for different layers of backbone. """Different learning rates are set for different layers of backbone.

View File

@ -1,9 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg import warnings
PIXEL_SAMPLERS = Registry('pixel sampler') from mmseg.registry import TASK_UTILS
PIXEL_SAMPLERS = TASK_UTILS
def build_pixel_sampler(cfg, **default_args): def build_pixel_sampler(cfg, **default_args):
"""Build pixel sampler for segmentation map.""" """Build pixel sampler for segmentation map."""
return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) warnings.warn(
'``build_pixel_sampler`` would be deprecated soon, please use '
'``mmseg.registry.TASK_UTILS.build()`` ')
return TASK_UTILS.build(cfg, default_args=default_args)

View File

@ -5,7 +5,7 @@ import mmcv
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -8,9 +8,10 @@ import numpy as np
import torch import torch
from mmcv.parallel import collate from mmcv.parallel import collate
from mmcv.runner import get_dist_info from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg, digit_version from mmcv.utils import digit_version
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from mmseg.registry import DATASETS, TRANSFORMS
from .samplers import DistributedSampler from .samplers import DistributedSampler
if platform.system() != 'Windows': if platform.system() != 'Windows':
@ -22,8 +23,7 @@ if platform.system() != 'Windows':
soft_limit = min(max(4096, base_soft_limit), hard_limit) soft_limit = min(max(4096, base_soft_limit), hard_limit)
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
DATASETS = Registry('dataset') PIPELINES = TRANSFORMS
PIPELINES = Registry('pipeline')
def _concat_dataset(cfg, default_args=None): def _concat_dataset(cfg, default_args=None):
@ -82,7 +82,7 @@ def build_dataset(cfg, default_args=None):
cfg.get('split', None), (list, tuple)): cfg.get('split', None), (list, tuple)):
dataset = _concat_dataset(cfg, default_args) dataset = _concat_dataset(cfg, default_args)
else: else:
dataset = build_from_cfg(cfg, DATASETS, default_args) dataset = DATASETS.build(cfg, default_args=default_args)
return dataset return dataset

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -6,7 +6,7 @@ import numpy as np
from mmcv.utils import print_log from mmcv.utils import print_log
from PIL import Image from PIL import Image
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -10,8 +10,8 @@ from prettytable import PrettyTable
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
from mmseg.registry import DATASETS
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from .builder import DATASETS
from .pipelines import Compose, LoadAnnotations from .pipelines import Compose, LoadAnnotations

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .cityscapes import CityscapesDataset from .cityscapes import CityscapesDataset

View File

@ -6,10 +6,10 @@ from itertools import chain
import mmcv import mmcv
import numpy as np import numpy as np
from mmcv.utils import build_from_cfg, print_log from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
from .builder import DATASETS, PIPELINES from mmseg.registry import DATASETS, TRANSFORMS
from .cityscapes import CityscapesDataset from .cityscapes import CityscapesDataset
@ -225,7 +225,7 @@ class MultiImageMixDataset:
for transform in pipeline: for transform in pipeline:
if isinstance(transform, dict): if isinstance(transform, dict):
self.pipeline_types.append(transform['type']) self.pipeline_types.append(transform['type'])
transform = build_from_cfg(transform, PIPELINES) transform = TRANSFORMS.build(transform)
self.pipeline.append(transform) self.pipeline.append(transform)
else: else:
raise TypeError('pipeline must be a dict') raise TypeError('pipeline must be a dict')

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -3,8 +3,8 @@
import mmcv import mmcv
from mmcv.utils import print_log from mmcv.utils import print_log
from mmseg.registry import DATASETS
from ..utils import get_root_logger from ..utils import get_root_logger
from .builder import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -5,7 +5,7 @@ import mmcv
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .cityscapes import CityscapesDataset from .cityscapes import CityscapesDataset

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -1,12 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import collections import collections
from mmcv.utils import build_from_cfg from mmseg.registry import TRANSFORMS
from ..builder import PIPELINES
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Compose(object): class Compose(object):
"""Compose multiple transforms sequentially. """Compose multiple transforms sequentially.
@ -20,7 +18,7 @@ class Compose(object):
self.transforms = [] self.transforms = []
for transform in transforms: for transform in transforms:
if isinstance(transform, dict): if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES) transform = TRANSFORMS.build(transform)
self.transforms.append(transform) self.transforms.append(transform)
elif callable(transform): elif callable(transform):
self.transforms.append(transform) self.transforms.append(transform)

View File

@ -6,7 +6,7 @@ import numpy as np
import torch import torch
from mmcv.parallel import DataContainer as DC from mmcv.parallel import DataContainer as DC
from ..builder import PIPELINES from mmseg.registry import TRANSFORMS
def to_tensor(data): def to_tensor(data):
@ -34,7 +34,7 @@ def to_tensor(data):
raise TypeError(f'type {type(data)} cannot be converted to tensor.') raise TypeError(f'type {type(data)} cannot be converted to tensor.')
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ToTensor(object): class ToTensor(object):
"""Convert some results to :obj:`torch.Tensor` by given keys. """Convert some results to :obj:`torch.Tensor` by given keys.
@ -64,7 +64,7 @@ class ToTensor(object):
return self.__class__.__name__ + f'(keys={self.keys})' return self.__class__.__name__ + f'(keys={self.keys})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ImageToTensor(object): class ImageToTensor(object):
"""Convert image to :obj:`torch.Tensor` by given keys. """Convert image to :obj:`torch.Tensor` by given keys.
@ -102,7 +102,7 @@ class ImageToTensor(object):
return self.__class__.__name__ + f'(keys={self.keys})' return self.__class__.__name__ + f'(keys={self.keys})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Transpose(object): class Transpose(object):
"""Transpose some results by given keys. """Transpose some results by given keys.
@ -136,7 +136,7 @@ class Transpose(object):
f'(keys={self.keys}, order={self.order})' f'(keys={self.keys}, order={self.order})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ToDataContainer(object): class ToDataContainer(object):
"""Convert results to :obj:`mmcv.DataContainer` by given fields. """Convert results to :obj:`mmcv.DataContainer` by given fields.
@ -175,7 +175,7 @@ class ToDataContainer(object):
return self.__class__.__name__ + f'(fields={self.fields})' return self.__class__.__name__ + f'(fields={self.fields})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class DefaultFormatBundle(object): class DefaultFormatBundle(object):
"""Default formatting bundle. """Default formatting bundle.
@ -216,7 +216,7 @@ class DefaultFormatBundle(object):
return self.__class__.__name__ return self.__class__.__name__
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Collect(object): class Collect(object):
"""Collect data from the loader relevant to the specific task. """Collect data from the loader relevant to the specific task.

View File

@ -4,10 +4,10 @@ import os.path as osp
import mmcv import mmcv
import numpy as np import numpy as np
from ..builder import PIPELINES from mmseg.registry import TRANSFORMS
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadImageFromFile(object): class LoadImageFromFile(object):
"""Load an image from file. """Load an image from file.
@ -87,7 +87,7 @@ class LoadImageFromFile(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class LoadAnnotations(object): class LoadAnnotations(object):
"""Load annotations for semantic segmentation. """Load annotations for semantic segmentation.

View File

@ -3,11 +3,11 @@ import warnings
import mmcv import mmcv
from ..builder import PIPELINES from mmseg.registry import TRANSFORMS
from .compose import Compose from .compose import Compose
@PIPELINES.register_module() @TRANSFORMS.register_module()
class MultiScaleFlipAug(object): class MultiScaleFlipAug(object):
"""Test-time augmentation with multiple scales and flipping. """Test-time augmentation with multiple scales and flipping.

View File

@ -6,10 +6,10 @@ import numpy as np
from mmcv.utils import deprecated_api_warning, is_tuple_of from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random from numpy import random
from ..builder import PIPELINES from mmseg.registry import TRANSFORMS
@PIPELINES.register_module() @TRANSFORMS.register_module()
class ResizeToMultiple(object): class ResizeToMultiple(object):
"""Resize images & seg to multiple of divisor. """Resize images & seg to multiple of divisor.
@ -66,7 +66,7 @@ class ResizeToMultiple(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Resize(object): class Resize(object):
"""Resize images & seg. """Resize images & seg.
@ -321,7 +321,7 @@ class Resize(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomFlip(object): class RandomFlip(object):
"""Flip the image & seg. """Flip the image & seg.
@ -376,7 +376,7 @@ class RandomFlip(object):
return self.__class__.__name__ + f'(prob={self.prob})' return self.__class__.__name__ + f'(prob={self.prob})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Pad(object): class Pad(object):
"""Pad the image & mask. """Pad the image & mask.
@ -447,7 +447,7 @@ class Pad(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Normalize(object): class Normalize(object):
"""Normalize the image. """Normalize the image.
@ -489,7 +489,7 @@ class Normalize(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class Rerange(object): class Rerange(object):
"""Rerange the image pixel value. """Rerange the image pixel value.
@ -535,7 +535,7 @@ class Rerange(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class CLAHE(object): class CLAHE(object):
"""Use CLAHE method to process the image. """Use CLAHE method to process the image.
@ -580,7 +580,7 @@ class CLAHE(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomCrop(object): class RandomCrop(object):
"""Random crop the image & seg. """Random crop the image & seg.
@ -653,7 +653,7 @@ class RandomCrop(object):
return self.__class__.__name__ + f'(crop_size={self.crop_size})' return self.__class__.__name__ + f'(crop_size={self.crop_size})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomRotate(object): class RandomRotate(object):
"""Rotate the image & seg. """Rotate the image & seg.
@ -736,7 +736,7 @@ class RandomRotate(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RGB2Gray(object): class RGB2Gray(object):
"""Convert RGB image to grayscale image. """Convert RGB image to grayscale image.
@ -791,7 +791,7 @@ class RGB2Gray(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class AdjustGamma(object): class AdjustGamma(object):
"""Using gamma correction to process the image. """Using gamma correction to process the image.
@ -827,7 +827,7 @@ class AdjustGamma(object):
return self.__class__.__name__ + f'(gamma={self.gamma})' return self.__class__.__name__ + f'(gamma={self.gamma})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class SegRescale(object): class SegRescale(object):
"""Rescale semantic segmentation maps. """Rescale semantic segmentation maps.
@ -857,7 +857,7 @@ class SegRescale(object):
return self.__class__.__name__ + f'(scale_factor={self.scale_factor})' return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
@PIPELINES.register_module() @TRANSFORMS.register_module()
class PhotoMetricDistortion(object): class PhotoMetricDistortion(object):
"""Apply photometric distortion to image sequentially, every transformation """Apply photometric distortion to image sequentially, every transformation
is applied with a probability of 0.5. The position of random contrast is in is applied with a probability of 0.5. The position of random contrast is in
@ -976,7 +976,7 @@ class PhotoMetricDistortion(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomCutOut(object): class RandomCutOut(object):
"""CutOut operation. """CutOut operation.
@ -1068,7 +1068,7 @@ class RandomCutOut(object):
return repr_str return repr_str
@PIPELINES.register_module() @TRANSFORMS.register_module()
class RandomMosaic(object): class RandomMosaic(object):
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into """Mosaic augmentation. Given 4 images, mosaic transform combines them into
one output image. The output image is composed of the parts from each sub- one output image. The output image is composed of the parts from each sub-

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -7,8 +7,10 @@ from torch.utils.data import Dataset
from torch.utils.data import DistributedSampler as _DistributedSampler from torch.utils.data import DistributedSampler as _DistributedSampler
from mmseg.core.utils import sync_random_seed from mmseg.core.utils import sync_random_seed
from mmseg.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class DistributedSampler(_DistributedSampler): class DistributedSampler(_DistributedSampler):
"""DistributedSampler inheriting from """DistributedSampler inheriting from
`torch.utils.data.DistributedSampler`. `torch.utils.data.DistributedSampler`.

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
from .builder import DATASETS from mmseg.registry import DATASETS
from .custom import CustomDataset from .custom import CustomDataset

View File

@ -13,8 +13,8 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed from ..utils import PatchEmbed
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
@ -227,7 +227,7 @@ class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
return x return x
@BACKBONES.register_module() @MODELS.register_module()
class BEiT(BaseModule): class BEiT(BaseModule):
"""BERT Pre-Training of Image Transformers. """BERT Pre-Training of Image Transformers.

View File

@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone from mmseg.registry import MODELS
class SpatialPath(BaseModule): class SpatialPath(BaseModule):
@ -156,7 +156,7 @@ class ContextPath(BaseModule):
assert len(context_channels) == 3, 'Length of input channels \ assert len(context_channels) == 3, 'Length of input channels \
of Context Path must be 3!' of Context Path must be 3!'
self.backbone = build_backbone(backbone_cfg) self.backbone = MODELS.build(backbone_cfg)
self.align_corners = align_corners self.align_corners = align_corners
self.arm16 = AttentionRefinementModule(context_channels[1], self.arm16 = AttentionRefinementModule(context_channels[1],
@ -262,7 +262,7 @@ class FeatureFusionModule(BaseModule):
return x_out return x_out
@BACKBONES.register_module() @MODELS.register_module()
class BiSeNetV1(BaseModule): class BiSeNetV1(BaseModule):
"""BiSeNetV1 backbone. """BiSeNetV1 backbone.

View File

@ -6,7 +6,7 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import BACKBONES from mmseg.registry import MODELS
class DetailBranch(BaseModule): class DetailBranch(BaseModule):
@ -541,7 +541,7 @@ class BGALayer(BaseModule):
return output return output
@BACKBONES.register_module() @MODELS.register_module()
class BiSeNetV2(BaseModule): class BiSeNetV2(BaseModule):
"""BiSeNetV2: Bilateral Network with Guided Aggregation for """BiSeNetV2: Bilateral Network with Guided Aggregation for
Real-time Semantic Segmentation. Real-time Semantic Segmentation.

View File

@ -8,7 +8,7 @@ from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES from mmseg.registry import MODELS
class GlobalContextExtractor(nn.Module): class GlobalContextExtractor(nn.Module):
@ -183,7 +183,7 @@ class InputInjection(nn.Module):
return x return x
@BACKBONES.register_module() @MODELS.register_module()
class CGNet(BaseModule): class CGNet(BaseModule):
"""CGNet backbone. """CGNet backbone.

View File

@ -5,7 +5,7 @@ from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import BACKBONES from mmseg.registry import MODELS
class DownsamplerBlock(BaseModule): class DownsamplerBlock(BaseModule):
@ -191,7 +191,7 @@ class UpsamplerBlock(BaseModule):
return output return output
@BACKBONES.register_module() @MODELS.register_module()
class ERFNet(BaseModule): class ERFNet(BaseModule):
"""ERFNet backbone. """ERFNet backbone.

View File

@ -6,7 +6,7 @@ from mmcv.runner import BaseModule
from mmseg.models.decode_heads.psp_head import PPM from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import InvertedResidual from ..utils import InvertedResidual
@ -268,7 +268,7 @@ class FeatureFusionModule(nn.Module):
return self.relu(out) return self.relu(out)
@BACKBONES.register_module() @MODELS.register_module()
class FastSCNN(BaseModule): class FastSCNN(BaseModule):
"""Fast-SCNN Backbone. """Fast-SCNN Backbone.

View File

@ -7,7 +7,7 @@ from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample, resize from mmseg.ops import Upsample, resize
from ..builder import BACKBONES from mmseg.registry import MODELS
from .resnet import BasicBlock, Bottleneck from .resnet import BasicBlock, Bottleneck
@ -214,7 +214,7 @@ class HRModule(BaseModule):
return x_fuse return x_fuse
@BACKBONES.register_module() @MODELS.register_module()
class HRNet(BaseModule): class HRNet(BaseModule):
"""HRNet backbone. """HRNet backbone.

View File

@ -5,11 +5,11 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone from mmseg.registry import MODELS
from ..decode_heads.psp_head import PPM from ..decode_heads.psp_head import PPM
@BACKBONES.register_module() @MODELS.register_module()
class ICNet(BaseModule): class ICNet(BaseModule):
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images. """ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
@ -66,7 +66,7 @@ class ICNet(BaseModule):
] ]
super(ICNet, self).__init__(init_cfg=init_cfg) super(ICNet, self).__init__(init_cfg=init_cfg)
self.align_corners = align_corners self.align_corners = align_corners
self.backbone = build_backbone(backbone_cfg) self.backbone = MODELS.build(backbone_cfg)
# Note: Default `ceil_mode` is false in nn.MaxPool2d, set # Note: Default `ceil_mode` is false in nn.MaxPool2d, set
# `ceil_mode=True` to keep information in the corner of feature map. # `ceil_mode=True` to keep information in the corner of feature map.

View File

@ -8,8 +8,8 @@ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
from mmcv.runner import ModuleList, _load_checkpoint from mmcv.runner import ModuleList, _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
@ -42,7 +42,7 @@ class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
self.attn = MAEAttention(**attn_cfg) self.attn = MAEAttention(**attn_cfg)
@BACKBONES.register_module() @MODELS.register_module()
class MAE(BEiT): class MAE(BEiT):
"""VisionTransformer with support for patch. """VisionTransformer with support for patch.

View File

@ -12,7 +12,7 @@ from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
trunc_normal_init) trunc_normal_init)
from mmcv.runner import BaseModule, ModuleList, Sequential from mmcv.runner import BaseModule, ModuleList, Sequential
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
@ -295,7 +295,7 @@ class TransformerEncoderLayer(BaseModule):
return x return x
@BACKBONES.register_module() @MODELS.register_module()
class MixVisionTransformer(BaseModule): class MixVisionTransformer(BaseModule):
"""The backbone of Segformer. """The backbone of Segformer.

View File

@ -6,11 +6,11 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import InvertedResidual, make_divisible from ..utils import InvertedResidual, make_divisible
@BACKBONES.register_module() @MODELS.register_module()
class MobileNetV2(BaseModule): class MobileNetV2(BaseModule):
"""MobileNetV2 backbone. """MobileNetV2 backbone.

View File

@ -7,11 +7,11 @@ from mmcv.cnn.bricks import Conv2dAdaptivePadding
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import InvertedResidualV3 as InvertedResidual from ..utils import InvertedResidualV3 as InvertedResidual
@BACKBONES.register_module() @MODELS.register_module()
class MobileNetV3(BaseModule): class MobileNetV3(BaseModule):
"""MobileNetV3 backbone. """MobileNetV3 backbone.

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import ResLayer from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNetV1d from .resnet import ResNetV1d
@ -267,7 +267,7 @@ class Bottleneck(_Bottleneck):
return out return out
@BACKBONES.register_module() @MODELS.register_module()
class ResNeSt(ResNetV1d): class ResNeSt(ResNetV1d):
"""ResNeSt backbone. """ResNeSt backbone.

View File

@ -7,7 +7,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import ResLayer from ..utils import ResLayer
@ -307,7 +307,7 @@ class Bottleneck(BaseModule):
return out return out
@BACKBONES.register_module() @MODELS.register_module()
class ResNet(BaseModule): class ResNet(BaseModule):
"""ResNet backbone. """ResNet backbone.
@ -685,7 +685,7 @@ class ResNet(BaseModule):
m.eval() m.eval()
@BACKBONES.register_module() @MODELS.register_module()
class ResNetV1c(ResNet): class ResNetV1c(ResNet):
"""ResNetV1c variant described in [1]_. """ResNetV1c variant described in [1]_.
@ -700,7 +700,7 @@ class ResNetV1c(ResNet):
deep_stem=True, avg_down=False, **kwargs) deep_stem=True, avg_down=False, **kwargs)
@BACKBONES.register_module() @MODELS.register_module()
class ResNetV1d(ResNet): class ResNetV1d(ResNet):
"""ResNetV1d variant described in [1]_. """ResNetV1d variant described in [1]_.

View File

@ -3,7 +3,7 @@ import math
from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import ResLayer from ..utils import ResLayer
from .resnet import Bottleneck as _Bottleneck from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet from .resnet import ResNet
@ -84,7 +84,7 @@ class Bottleneck(_Bottleneck):
self.add_module(self.norm3_name, norm3) self.add_module(self.norm3_name, norm3)
@BACKBONES.register_module() @MODELS.register_module()
class ResNeXt(ResNet): class ResNeXt(ResNet):
"""ResNeXt backbone. """ResNeXt backbone.

View File

@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone from mmseg.registry import MODELS
from .bisenetv1 import AttentionRefinementModule from .bisenetv1 import AttentionRefinementModule
@ -184,7 +184,7 @@ class FeatureFusionModule(BaseModule):
return x_attn + x return x_attn + x
@BACKBONES.register_module() @MODELS.register_module()
class STDCNet(BaseModule): class STDCNet(BaseModule):
"""This backbone is the implementation of `Rethinking BiSeNet For Real-time """This backbone is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_. Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
@ -325,7 +325,7 @@ class STDCNet(BaseModule):
return tuple(outs) return tuple(outs)
@BACKBONES.register_module() @MODELS.register_module()
class STDCContextPathNet(BaseModule): class STDCContextPathNet(BaseModule):
"""STDCNet with Context Path. The `outs` below is a list of three feature """STDCNet with Context Path. The `outs` below is a list of three feature
maps from deep to shallow, whose height and width is from small to big, maps from deep to shallow, whose height and width is from small to big,
@ -371,7 +371,7 @@ class STDCContextPathNet(BaseModule):
norm_cfg=dict(type='BN'), norm_cfg=dict(type='BN'),
init_cfg=None): init_cfg=None):
super(STDCContextPathNet, self).__init__(init_cfg=init_cfg) super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
self.backbone = build_backbone(backbone_cfg) self.backbone = MODELS.build(backbone_cfg)
self.arms = ModuleList() self.arms = ModuleList()
self.convs = ModuleList() self.convs = ModuleList()
for channels in last_in_channels: for channels in last_in_channels:

View File

@ -15,8 +15,8 @@ from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
load_state_dict) load_state_dict)
from mmcv.utils import to_2tuple from mmcv.utils import to_2tuple
from mmseg.registry import MODELS
from ...utils import get_root_logger from ...utils import get_root_logger
from ..builder import BACKBONES
from ..utils.embed import PatchEmbed, PatchMerging from ..utils.embed import PatchEmbed, PatchMerging
@ -462,7 +462,7 @@ class SwinBlockSequence(BaseModule):
return x, hw_shape, x, hw_shape return x, hw_shape, x, hw_shape
@BACKBONES.register_module() @MODELS.register_module()
class SwinTransformer(BaseModule): class SwinTransformer(BaseModule):
"""Swin Transformer backbone. """Swin Transformer backbone.

View File

@ -7,10 +7,10 @@ except ImportError:
from mmcv.cnn.bricks.registry import NORM_LAYERS from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from ..builder import BACKBONES from mmseg.registry import MODELS
@BACKBONES.register_module() @MODELS.register_module()
class TIMMBackbone(BaseModule): class TIMMBackbone(BaseModule):
"""Wrapper to use backbones from timm library. More details can be found in """Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ . `timm <https://github.com/rwightman/pytorch-image-models>`_ .

View File

@ -14,7 +14,7 @@ from mmcv.runner import BaseModule, ModuleList
from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.batchnorm import _BatchNorm
from mmseg.models.backbones.mit import EfficientMultiheadAttention from mmseg.models.backbones.mit import EfficientMultiheadAttention
from mmseg.models.builder import BACKBONES from mmseg.registry import MODELS
from ..utils.embed import PatchEmbed from ..utils.embed import PatchEmbed
@ -349,7 +349,7 @@ class ConditionalPositionEncoding(BaseModule):
return x return x
@BACKBONES.register_module() @MODELS.register_module()
class PCPVT(BaseModule): class PCPVT(BaseModule):
"""The backbone of Twins-PCPVT. """The backbone of Twins-PCPVT.
@ -508,7 +508,7 @@ class PCPVT(BaseModule):
return tuple(outputs) return tuple(outputs)
@BACKBONES.register_module() @MODELS.register_module()
class SVT(PCPVT): class SVT(PCPVT):
"""The backbone of Twins-SVT. """The backbone of Twins-SVT.

View File

@ -9,7 +9,7 @@ from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
from mmseg.ops import Upsample from mmseg.ops import Upsample
from ..builder import BACKBONES from mmseg.registry import MODELS
from ..utils import UpConvBlock from ..utils import UpConvBlock
@ -221,7 +221,7 @@ class InterpConv(nn.Module):
return out return out
@BACKBONES.register_module() @MODELS.register_module()
class UNet(BaseModule): class UNet(BaseModule):
"""UNet backbone. """UNet backbone.

View File

@ -15,8 +15,8 @@ from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize from mmseg.ops import resize
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import PatchEmbed from ..utils import PatchEmbed
@ -122,7 +122,7 @@ class TransformerEncoderLayer(BaseModule):
return x return x
@BACKBONES.register_module() @MODELS.register_module()
class VisionTransformer(BaseModule): class VisionTransformer(BaseModule):
"""Vision Transformer. """Vision Transformer.

View File

@ -1,12 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import warnings import warnings
from mmcv.cnn import MODELS as MMCV_MODELS from mmseg.registry import MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)
ATTENTION = Registry('attention', parent=MMCV_ATTENTION)
BACKBONES = MODELS BACKBONES = MODELS
NECKS = MODELS NECKS = MODELS
@ -17,21 +12,29 @@ SEGMENTORS = MODELS
def build_backbone(cfg): def build_backbone(cfg):
"""Build backbone.""" """Build backbone."""
warnings.warn('``build_backbone`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return BACKBONES.build(cfg) return BACKBONES.build(cfg)
def build_neck(cfg): def build_neck(cfg):
"""Build neck.""" """Build neck."""
warnings.warn('``build_neck`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return NECKS.build(cfg) return NECKS.build(cfg)
def build_head(cfg): def build_head(cfg):
"""Build head.""" """Build head."""
warnings.warn('``build_head`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return HEADS.build(cfg) return HEADS.build(cfg)
def build_loss(cfg): def build_loss(cfg):
"""Build loss.""" """Build loss."""
warnings.warn('``build_loss`` would be deprecated soon, please use '
'``mmseg.registry.MODELS.build()`` ')
return LOSSES.build(cfg) return LOSSES.build(cfg)

View File

@ -3,7 +3,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from ..builder import HEADS from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -181,7 +181,7 @@ class APNB(nn.Module):
return output return output
@HEADS.register_module() @MODELS.register_module()
class ANNHead(BaseDecodeHead): class ANNHead(BaseDecodeHead):
"""Asymmetric Non-local Neural Networks for Semantic Segmentation. """Asymmetric Non-local Neural Networks for Semantic Segmentation.

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -107,7 +107,7 @@ class ACM(nn.Module):
return z_out return z_out
@HEADS.register_module() @MODELS.register_module()
class APCHead(BaseDecodeHead): class APCHead(BaseDecodeHead):
"""Adaptive Pyramid Context Network for Semantic Segmentation. """Adaptive Pyramid Context Network for Semantic Segmentation.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -50,7 +50,7 @@ class ASPPModule(nn.ModuleList):
return aspp_outs return aspp_outs
@HEADS.register_module() @MODELS.register_module()
class ASPPHead(BaseDecodeHead): class ASPPHead(BaseDecodeHead):
"""Rethinking Atrous Convolution for Semantic Image Segmentation. """Rethinking Atrous Convolution for Semantic Image Segmentation.

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import torch import torch
from ..builder import HEADS from mmseg.registry import MODELS
from .fcn_head import FCNHead from .fcn_head import FCNHead
try: try:
@ -10,7 +10,7 @@ except ModuleNotFoundError:
CrissCrossAttention = None CrissCrossAttention = None
@HEADS.register_module() @MODELS.register_module()
class CCHead(FCNHead): class CCHead(FCNHead):
"""CCNet: Criss-Cross Attention for Semantic Segmentation. """CCNet: Criss-Cross Attention for Semantic Segmentation.

View File

@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule, Scale
from torch import nn from torch import nn
from mmseg.core import add_prefix from mmseg.core import add_prefix
from ..builder import HEADS from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -72,7 +72,7 @@ class CAM(nn.Module):
return out return out
@HEADS.register_module() @MODELS.register_module()
class DAHead(BaseDecodeHead): class DAHead(BaseDecodeHead):
"""Dual Attention Network for Scene Segmentation. """Dual Attention Network for Scene Segmentation.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -89,7 +89,7 @@ class DCM(nn.Module):
return output return output
@HEADS.register_module() @MODELS.register_module()
class DMHead(BaseDecodeHead): class DMHead(BaseDecodeHead):
"""Dynamic Multi-scale Filters for Semantic Segmentation. """Dynamic Multi-scale Filters for Semantic Segmentation.

View File

@ -3,7 +3,7 @@ import torch
from mmcv.cnn import NonLocal2d from mmcv.cnn import NonLocal2d
from torch import nn from torch import nn
from ..builder import HEADS from mmseg.registry import MODELS
from .fcn_head import FCNHead from .fcn_head import FCNHead
@ -89,7 +89,7 @@ class DisentangledNonLocal2d(NonLocal2d):
return output return output
@HEADS.register_module() @MODELS.register_module()
class DNLHead(FCNHead): class DNLHead(FCNHead):
"""Disentangled Non-Local Neural Networks. """Disentangled Non-Local Neural Networks.

View File

@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -212,7 +212,7 @@ class FeatureFusionBlock(BaseModule):
return x return x
@HEADS.register_module() @MODELS.register_module()
class DPTHead(BaseDecodeHead): class DPTHead(BaseDecodeHead):
"""Vision Transformers for Dense Prediction. """Vision Transformers for Dense Prediction.

View File

@ -7,7 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -76,7 +76,7 @@ class EMAModule(nn.Module):
return feats_recon return feats_recon
@HEADS.register_module() @MODELS.register_module()
class EMAHead(BaseDecodeHead): class EMAHead(BaseDecodeHead):
"""Expectation Maximization Attention Networks for Semantic Segmentation. """Expectation Maximization Attention Networks for Semantic Segmentation.

View File

@ -5,7 +5,8 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Encoding, resize from mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss from mmseg.registry import MODELS
from ..builder import build_loss
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -59,7 +60,7 @@ class EncModule(nn.Module):
return encoding_feat, output return encoding_feat, output
@HEADS.register_module() @MODELS.register_module()
class EncHead(BaseDecodeHead): class EncHead(BaseDecodeHead):
"""Context Encoding for Semantic Segmentation. """Context Encoding for Semantic Segmentation.

View File

@ -3,11 +3,11 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@HEADS.register_module() @MODELS.register_module()
class FCNHead(BaseDecodeHead): class FCNHead(BaseDecodeHead):
"""Fully Convolution Networks for Semantic Segmentation. """Fully Convolution Networks for Semantic Segmentation.

View File

@ -4,11 +4,11 @@ import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import Upsample, resize from mmseg.ops import Upsample, resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@HEADS.register_module() @MODELS.register_module()
class FPNHead(BaseDecodeHead): class FPNHead(BaseDecodeHead):
"""Panoptic Feature Pyramid Networks. """Panoptic Feature Pyramid Networks.

View File

@ -2,11 +2,11 @@
import torch import torch
from mmcv.cnn import ContextBlock from mmcv.cnn import ContextBlock
from ..builder import HEADS from mmseg.registry import MODELS
from .fcn_head import FCNHead from .fcn_head import FCNHead
@HEADS.register_module() @MODELS.register_module()
class GCHead(FCNHead): class GCHead(FCNHead):
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond. """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from ..builder import HEADS from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -55,7 +55,7 @@ class SelfAttentionBlock(_SelfAttentionBlock):
return self.output_project(context) return self.output_project(context)
@HEADS.register_module() @MODELS.register_module()
class ISAHead(BaseDecodeHead): class ISAHead(BaseDecodeHead):
"""Interlaced Sparse Self-Attention for Semantic Segmentation. """Interlaced Sparse Self-Attention for Semantic Segmentation.

View File

@ -7,8 +7,8 @@ from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
MultiheadAttention, MultiheadAttention,
build_transformer_layer) build_transformer_layer)
from mmseg.models.builder import HEADS, build_head
from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.registry import MODELS
from mmseg.utils import get_root_logger from mmseg.utils import get_root_logger
@ -139,7 +139,7 @@ class KernelUpdator(nn.Module):
return features return features
@HEADS.register_module() @MODELS.register_module()
class KernelUpdateHead(nn.Module): class KernelUpdateHead(nn.Module):
"""Kernel Update Head in K-Net. """Kernel Update Head in K-Net.
@ -391,7 +391,7 @@ class KernelUpdateHead(nn.Module):
self.conv_kernel_size) self.conv_kernel_size)
@HEADS.register_module() @MODELS.register_module()
class IterativeDecodeHead(BaseDecodeHead): class IterativeDecodeHead(BaseDecodeHead):
"""K-Net: Towards Unified Image Segmentation. """K-Net: Towards Unified Image Segmentation.
@ -414,7 +414,7 @@ class IterativeDecodeHead(BaseDecodeHead):
super(BaseDecodeHead, self).__init__(**kwargs) super(BaseDecodeHead, self).__init__(**kwargs)
assert num_stages == len(kernel_update_head) assert num_stages == len(kernel_update_head)
self.num_stages = num_stages self.num_stages = num_stages
self.kernel_generate_head = build_head(kernel_generate_head) self.kernel_generate_head = (kernel_generate_head)
self.kernel_update_head = nn.ModuleList() self.kernel_update_head = nn.ModuleList()
self.align_corners = self.kernel_generate_head.align_corners self.align_corners = self.kernel_generate_head.align_corners
self.num_classes = self.kernel_generate_head.num_classes self.num_classes = self.kernel_generate_head.num_classes
@ -422,7 +422,7 @@ class IterativeDecodeHead(BaseDecodeHead):
self.ignore_index = self.kernel_generate_head.ignore_index self.ignore_index = self.kernel_generate_head.ignore_index
for head_cfg in kernel_update_head: for head_cfg in kernel_update_head:
self.kernel_update_head.append(build_head(head_cfg)) self.kernel_update_head.append(MODELS.build(head_cfg))
def forward(self, inputs): def forward(self, inputs):
"""Forward function.""" """Forward function."""

View File

@ -5,11 +5,11 @@ from mmcv import is_tuple_of
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@HEADS.register_module() @MODELS.register_module()
class LRASPPHead(BaseDecodeHead): class LRASPPHead(BaseDecodeHead):
"""Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3. """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.

View File

@ -2,11 +2,11 @@
import torch import torch
from mmcv.cnn import NonLocal2d from mmcv.cnn import NonLocal2d
from ..builder import HEADS from mmseg.registry import MODELS
from .fcn_head import FCNHead from .fcn_head import FCNHead
@HEADS.register_module() @MODELS.register_module()
class NLHead(FCNHead): class NLHead(FCNHead):
"""Non-local Neural Networks. """Non-local Neural Networks.

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .cascade_decode_head import BaseCascadeDecodeHead from .cascade_decode_head import BaseCascadeDecodeHead
@ -82,7 +82,7 @@ class ObjectAttentionBlock(_SelfAttentionBlock):
return output return output
@HEADS.register_module() @MODELS.register_module()
class OCRHead(BaseCascadeDecodeHead): class OCRHead(BaseCascadeDecodeHead):
"""Object-Contextual Representations for Semantic Segmentation. """Object-Contextual Representations for Semantic Segmentation.

View File

@ -10,8 +10,8 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
point_sample = None point_sample = None
from mmseg.models.builder import HEADS
from mmseg.ops import resize from mmseg.ops import resize
from mmseg.registry import MODELS
from ..losses import accuracy from ..losses import accuracy
from .cascade_decode_head import BaseCascadeDecodeHead from .cascade_decode_head import BaseCascadeDecodeHead
@ -36,7 +36,7 @@ def calculate_uncertainty(seg_logits):
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@HEADS.register_module() @MODELS.register_module()
class PointHead(BaseCascadeDecodeHead): class PointHead(BaseCascadeDecodeHead):
"""A mask point head use in PointRend. """A mask point head use in PointRend.

View File

@ -5,7 +5,7 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
try: try:
@ -14,7 +14,7 @@ except ModuleNotFoundError:
PSAMask = None PSAMask = None
@HEADS.register_module() @MODELS.register_module()
class PSAHead(BaseDecodeHead): class PSAHead(BaseDecodeHead):
"""Point-wise Spatial Attention Network for Scene Parsing. """Point-wise Spatial Attention Network for Scene Parsing.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@ -59,7 +59,7 @@ class PPM(nn.ModuleList):
return ppm_outs return ppm_outs
@HEADS.register_module() @MODELS.register_module()
class PSPHead(BaseDecodeHead): class PSPHead(BaseDecodeHead):
"""Pyramid Scene Parsing Network. """Pyramid Scene Parsing Network.

View File

@ -3,12 +3,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.models.builder import HEADS
from mmseg.models.decode_heads.decode_head import BaseDecodeHead from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize from mmseg.ops import resize
from mmseg.registry import MODELS
@HEADS.register_module() @MODELS.register_module()
class SegformerHead(BaseDecodeHead): class SegformerHead(BaseDecodeHead):
"""The all mlp Head of segformer. """The all mlp Head of segformer.

View File

@ -8,11 +8,11 @@ from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
from mmcv.runner import ModuleList from mmcv.runner import ModuleList
from mmseg.models.backbones.vit import TransformerEncoderLayer from mmseg.models.backbones.vit import TransformerEncoderLayer
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@HEADS.register_module() @MODELS.register_module()
class SegmenterMaskTransformerHead(BaseDecodeHead): class SegmenterMaskTransformerHead(BaseDecodeHead):
"""Segmenter: Transformer for Semantic Segmentation. """Segmenter: Transformer for Semantic Segmentation.

View File

@ -4,7 +4,7 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .aspp_head import ASPPHead, ASPPModule from .aspp_head import ASPPHead, ASPPModule
@ -26,7 +26,7 @@ class DepthwiseSeparableASPPModule(ASPPModule):
act_cfg=self.act_cfg) act_cfg=self.act_cfg)
@HEADS.register_module() @MODELS.register_module()
class DepthwiseSeparableASPPHead(ASPPHead): class DepthwiseSeparableASPPHead(ASPPHead):
"""Encoder-Decoder with Atrous Separable Convolution for Semantic Image """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation. Segmentation.

View File

@ -1,11 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import DepthwiseSeparableConvModule from mmcv.cnn import DepthwiseSeparableConvModule
from ..builder import HEADS from mmseg.registry import MODELS
from .fcn_head import FCNHead from .fcn_head import FCNHead
@HEADS.register_module() @MODELS.register_module()
class DepthwiseSeparableFCNHead(FCNHead): class DepthwiseSeparableFCNHead(FCNHead):
"""Depthwise-Separable Fully Convolutional Network for Semantic """Depthwise-Separable Fully Convolutional Network for Semantic
Segmentation. Segmentation.

View File

@ -4,11 +4,11 @@ import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import Upsample from mmseg.ops import Upsample
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@HEADS.register_module() @MODELS.register_module()
class SETRMLAHead(BaseDecodeHead): class SETRMLAHead(BaseDecodeHead):
"""Multi level feature aggretation head of SETR. """Multi level feature aggretation head of SETR.

View File

@ -3,11 +3,11 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Upsample from mmseg.ops import Upsample
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
@HEADS.register_module() @MODELS.register_module()
class SETRUPHead(BaseDecodeHead): class SETRUPHead(BaseDecodeHead):
"""Naive upsampling head and Progressive upsampling head of SETR. """Naive upsampling head and Progressive upsampling head of SETR.

View File

@ -2,11 +2,11 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ..builder import HEADS from mmseg.registry import MODELS
from .fcn_head import FCNHead from .fcn_head import FCNHead
@HEADS.register_module() @MODELS.register_module()
class STDCHead(FCNHead): class STDCHead(FCNHead):
"""This head is the implementation of `Rethinking BiSeNet For Real-time """This head is the implementation of `Rethinking BiSeNet For Real-time
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_. Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

View File

@ -4,12 +4,12 @@ import torch.nn as nn
from mmcv.cnn import ConvModule from mmcv.cnn import ConvModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import HEADS from mmseg.registry import MODELS
from .decode_head import BaseDecodeHead from .decode_head import BaseDecodeHead
from .psp_head import PPM from .psp_head import PPM
@HEADS.register_module() @MODELS.register_module()
class UPerHead(BaseDecodeHead): class UPerHead(BaseDecodeHead):
"""Unified Perceptual Parsing for Scene Understanding. """Unified Perceptual Parsing for Scene Understanding.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..builder import LOSSES from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss from .utils import get_class_weight, weight_reduce_loss
@ -193,7 +193,7 @@ def mask_cross_entropy(pred,
pred_slice, target, weight=class_weight, reduction='mean')[None] pred_slice, target, weight=class_weight, reduction='mean')[None]
@LOSSES.register_module() @MODELS.register_module()
class CrossEntropyLoss(nn.Module): class CrossEntropyLoss(nn.Module):
"""CrossEntropyLoss. """CrossEntropyLoss.

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..builder import LOSSES from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss from .utils import get_class_weight, weighted_loss
@ -47,7 +47,7 @@ def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
return 1 - num / den return 1 - num / den
@LOSSES.register_module() @MODELS.register_module()
class DiceLoss(nn.Module): class DiceLoss(nn.Module):
"""DiceLoss. """DiceLoss.

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
from ..builder import LOSSES from mmseg.registry import MODELS
from .utils import weight_reduce_loss from .utils import weight_reduce_loss
@ -133,7 +133,7 @@ def sigmoid_focal_loss(pred,
return loss return loss
@LOSSES.register_module() @MODELS.register_module()
class FocalLoss(nn.Module): class FocalLoss(nn.Module):
def __init__(self, def __init__(self,

View File

@ -8,7 +8,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..builder import LOSSES from mmseg.registry import MODELS
from .utils import get_class_weight, weight_reduce_loss from .utils import get_class_weight, weight_reduce_loss
@ -222,7 +222,7 @@ def lovasz_softmax(probs,
return loss return loss
@LOSSES.register_module() @MODELS.register_module()
class LovaszLoss(nn.Module): class LovaszLoss(nn.Module):
"""LovaszLoss. """LovaszLoss.

View File

@ -2,10 +2,10 @@
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import build_norm_layer from mmcv.cnn import build_norm_layer
from ..builder import NECKS from mmseg.registry import MODELS
@NECKS.register_module() @MODELS.register_module()
class Feature2Pyramid(nn.Module): class Feature2Pyramid(nn.Module):
"""Feature2Pyramid. """Feature2Pyramid.

View File

@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16 from mmcv.runner import BaseModule, auto_fp16
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import NECKS from mmseg.registry import MODELS
@NECKS.register_module() @MODELS.register_module()
class FPN(BaseModule): class FPN(BaseModule):
"""Feature Pyramid Network. """Feature Pyramid Network.

View File

@ -4,7 +4,7 @@ from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import NECKS from mmseg.registry import MODELS
class CascadeFeatureFusion(BaseModule): class CascadeFeatureFusion(BaseModule):
@ -77,7 +77,7 @@ class CascadeFeatureFusion(BaseModule):
return x, x_low return x, x_low
@NECKS.register_module() @MODELS.register_module()
class ICNeck(BaseModule): class ICNeck(BaseModule):
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images. """ICNet for Real-Time Semantic Segmentation on High-Resolution Images.

View File

@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule from mmcv.runner import BaseModule
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import NECKS from mmseg.registry import MODELS
@NECKS.register_module() @MODELS.register_module()
class JPU(BaseModule): class JPU(BaseModule):
"""FastFCN: Rethinking Dilated Convolution in the Backbone """FastFCN: Rethinking Dilated Convolution in the Backbone
for Semantic Segmentation. for Semantic Segmentation.

View File

@ -2,7 +2,7 @@
import torch.nn as nn import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer from mmcv.cnn import ConvModule, build_norm_layer
from ..builder import NECKS from mmseg.registry import MODELS
class MLAModule(nn.Module): class MLAModule(nn.Module):
@ -59,7 +59,7 @@ class MLAModule(nn.Module):
return tuple(out_list) return tuple(out_list)
@NECKS.register_module() @MODELS.register_module()
class MLANeck(nn.Module): class MLANeck(nn.Module):
"""Multi-level Feature Aggregation. """Multi-level Feature Aggregation.

View File

@ -3,10 +3,10 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, xavier_init from mmcv.cnn import ConvModule, xavier_init
from mmseg.ops import resize from mmseg.ops import resize
from ..builder import NECKS from mmseg.registry import MODELS
@NECKS.register_module() @MODELS.register_module()
class MultiLevelNeck(nn.Module): class MultiLevelNeck(nn.Module):
"""MultiLevelNeck. """MultiLevelNeck.

View File

@ -3,12 +3,11 @@ from torch import nn
from mmseg.core import add_prefix from mmseg.core import add_prefix
from mmseg.ops import resize from mmseg.ops import resize
from .. import builder from mmseg.registry import MODELS
from ..builder import SEGMENTORS
from .encoder_decoder import EncoderDecoder from .encoder_decoder import EncoderDecoder
@SEGMENTORS.register_module() @MODELS.register_module()
class CascadeEncoderDecoder(EncoderDecoder): class CascadeEncoderDecoder(EncoderDecoder):
"""Cascade Encoder Decoder segmentors. """Cascade Encoder Decoder segmentors.
@ -44,7 +43,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
assert len(decode_head) == self.num_stages assert len(decode_head) == self.num_stages
self.decode_head = nn.ModuleList() self.decode_head = nn.ModuleList()
for i in range(self.num_stages): for i in range(self.num_stages):
self.decode_head.append(builder.build_head(decode_head[i])) self.decode_head.append(MODELS.build(decode_head[i]))
self.align_corners = self.decode_head[-1].align_corners self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes self.num_classes = self.decode_head[-1].num_classes

View File

@ -5,12 +5,11 @@ import torch.nn.functional as F
from mmseg.core import add_prefix from mmseg.core import add_prefix
from mmseg.ops import resize from mmseg.ops import resize
from .. import builder from mmseg.registry import MODELS
from ..builder import SEGMENTORS
from .base import BaseSegmentor from .base import BaseSegmentor
@SEGMENTORS.register_module() @MODELS.register_module()
class EncoderDecoder(BaseSegmentor): class EncoderDecoder(BaseSegmentor):
"""Encoder Decoder segmentors. """Encoder Decoder segmentors.
@ -33,9 +32,9 @@ class EncoderDecoder(BaseSegmentor):
assert backbone.get('pretrained') is None, \ assert backbone.get('pretrained') is None, \
'both backbone and segmentor set pretrained weight' 'both backbone and segmentor set pretrained weight'
backbone.pretrained = pretrained backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone) self.backbone = MODELS.build(backbone)
if neck is not None: if neck is not None:
self.neck = builder.build_neck(neck) self.neck = MODELS.build(neck)
self._init_decode_head(decode_head) self._init_decode_head(decode_head)
self._init_auxiliary_head(auxiliary_head) self._init_auxiliary_head(auxiliary_head)
@ -46,7 +45,7 @@ class EncoderDecoder(BaseSegmentor):
def _init_decode_head(self, decode_head): def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``""" """Initialize ``decode_head``"""
self.decode_head = builder.build_head(decode_head) self.decode_head = MODELS.build(decode_head)
self.align_corners = self.decode_head.align_corners self.align_corners = self.decode_head.align_corners
self.num_classes = self.decode_head.num_classes self.num_classes = self.decode_head.num_classes
@ -56,9 +55,9 @@ class EncoderDecoder(BaseSegmentor):
if isinstance(auxiliary_head, list): if isinstance(auxiliary_head, list):
self.auxiliary_head = nn.ModuleList() self.auxiliary_head = nn.ModuleList()
for head_cfg in auxiliary_head: for head_cfg in auxiliary_head:
self.auxiliary_head.append(builder.build_head(head_cfg)) self.auxiliary_head.append(MODELS.build(head_cfg))
else: else:
self.auxiliary_head = builder.build_head(auxiliary_head) self.auxiliary_head = MODELS.build(auxiliary_head)
def extract_feat(self, img): def extract_feat(self, img):
"""Extract features from images.""" """Extract features from images."""

View File

@ -0,0 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
VISUALIZERS, WEIGHT_INITIALIZERS)
__all__ = [
'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
]

View File

@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""MMSegmentation provides 17 registry nodes to support using modules across
projects. Each node is a child of the root registry in MMEngine.
More details can be found at
https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
"""
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import HOOKS as MMENGINE_HOOKS
from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import METRICS as MMENGINE_METRICS
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
from mmengine.registry import MODELS as MMENGINE_MODELS
from mmengine.registry import \
OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
from mmengine.registry import \
RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
from mmengine.registry import \
WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
from mmengine.registry import Registry
# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
# manage runner constructors that define how to initialize runners
RUNNER_CONSTRUCTORS = Registry(
'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
# manage all kinds of loops like `EpochBasedTrainLoop`
LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
# manage all kinds of hooks like `CheckpointHook`
HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
# manage data-related modules
DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
# mangage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model', parent=MMENGINE_MODELS)
# mangage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
# mangage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry(
'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
# mangage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
# manage constructors that customize the optimization hyperparameters.
OPTIMIZER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
# manage task-specific modules like ohem pixel sampler
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
# manage visualizer
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)

View File

@ -2,10 +2,10 @@
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.runner import DefaultOptimizerConstructor from mmengine.optim import DefaultOptimizerConstructor
from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer, from mmseg.core.builder import build_optimizer, build_optimizer_constructor
build_optimizer_constructor) from mmseg.registry import OPTIMIZER_CONSTRUCTORS
class ExampleModel(nn.Module): class ExampleModel(nn.Module):
@ -35,7 +35,7 @@ def test_build_optimizer_constructor():
# Test whether optimizer constructor can be built from parent. # Test whether optimizer constructor can be built from parent.
assert type(optim_constructor) is DefaultOptimizerConstructor assert type(optim_constructor) is DefaultOptimizerConstructor
@OPTIMIZER_BUILDERS.register_module() @OPTIMIZER_CONSTRUCTORS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor): class MyOptimizerConstructor(DefaultOptimizerConstructor):
pass pass