Merge branch 'linfangjian/refactor_registies' into 'refactor_dev'

[Refactor] Refactor all registries

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!3
pull/1801/head

commit 168c8831dd
@@ -1,11 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import (OPTIMIZER_BUILDERS, build_optimizer,
-                      build_optimizer_constructor)
+from .builder import build_optimizer, build_optimizer_constructor
 from .evaluation import *  # noqa: F401, F403
 from .optimizers import *  # noqa: F401, F403
 from .seg import *  # noqa: F401, F403
 from .utils import *  # noqa: F401, F403

-__all__ = [
-    'OPTIMIZER_BUILDERS', 'build_optimizer', 'build_optimizer_constructor'
-]
+__all__ = ['build_optimizer', 'build_optimizer_constructor']
@@ -1,19 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy

-from mmcv.runner.optimizer import OPTIMIZER_BUILDERS as MMCV_OPTIMIZER_BUILDERS
-from mmcv.utils import Registry, build_from_cfg
-
-OPTIMIZER_BUILDERS = Registry(
-    'optimizer builder', parent=MMCV_OPTIMIZER_BUILDERS)
+from mmseg.registry import OPTIMIZER_CONSTRUCTORS


 def build_optimizer_constructor(cfg):
     constructor_type = cfg.get('type')
-    if constructor_type in OPTIMIZER_BUILDERS:
-        return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
-    elif constructor_type in MMCV_OPTIMIZER_BUILDERS:
-        return build_from_cfg(cfg, MMCV_OPTIMIZER_BUILDERS)
+    if constructor_type in OPTIMIZER_CONSTRUCTORS:
+        return OPTIMIZER_CONSTRUCTORS.build(cfg)
     else:
         raise KeyError(f'{constructor_type} is not registered '
                        'in the optimizer builder registry.')
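Note: a minimal sketch of how a project-level constructor resolves through the new registry after this change. `MyConstructor` and the config values are illustrative only, and the `optimizer_cfg` keyword is assumed to match the mmcv-style constructor signature.

```python
from mmengine.optim import DefaultOptimizerConstructor

from mmseg.core.builder import build_optimizer_constructor
from mmseg.registry import OPTIMIZER_CONSTRUCTORS


@OPTIMIZER_CONSTRUCTORS.register_module()
class MyConstructor(DefaultOptimizerConstructor):
    """Toy constructor, used only to illustrate the new lookup path."""


# Found in mmseg's OPTIMIZER_CONSTRUCTORS; a type registered only in the
# MMEngine parent registry would be found through the parent fallback,
# which is what the updated unit test at the end of this diff asserts.
constructor = build_optimizer_constructor(
    dict(type='MyConstructor', optimizer_cfg=dict(type='SGD', lr=0.01)))
```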
@@ -2,10 +2,11 @@
 import json
 import warnings

-from mmcv.runner import DefaultOptimizerConstructor, get_dist_info
+from mmengine.dist import get_dist_info
+from mmengine.optim import DefaultOptimizerConstructor

+from mmseg.registry import OPTIMIZER_CONSTRUCTORS
 from mmseg.utils import get_root_logger
-from ..builder import OPTIMIZER_BUILDERS


 def get_layer_id_for_convnext(var_name, max_layer_id):
@@ -99,7 +100,7 @@ def get_layer_id_for_vit(var_name, max_layer_id):
         return max_layer_id - 1


-@OPTIMIZER_BUILDERS.register_module()
+@OPTIMIZER_CONSTRUCTORS.register_module()
 class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
     """Different learning rates are set for different layers of backbone.

@@ -185,7 +186,7 @@ class LearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
         params.extend(parameter_groups.values())


-@OPTIMIZER_BUILDERS.register_module()
+@OPTIMIZER_CONSTRUCTORS.register_module()
 class LayerDecayOptimizerConstructor(LearningRateDecayOptimizerConstructor):
     """Different learning rates are set for different layers of backbone.

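Note: after this change the decorated constructors are referenced from configs by name and looked up in OPTIMIZER_CONSTRUCTORS rather than the removed OPTIMIZER_BUILDERS. A hedged config sketch; the learning rate, decay and layer counts are illustrative, not taken from this merge request.

```python
# Typical optimizer config fragment; values are placeholders.
optimizer = dict(
    type='AdamW',
    lr=1e-4,
    weight_decay=0.05,
    constructor='LearningRateDecayOptimizerConstructor',
    paramwise_cfg=dict(decay_rate=0.9, decay_type='stage_wise', num_layers=12))
```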
@@ -1,9 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmcv.utils import Registry, build_from_cfg
+import warnings

-PIXEL_SAMPLERS = Registry('pixel sampler')
+from mmseg.registry import TASK_UTILS
+
+PIXEL_SAMPLERS = TASK_UTILS


 def build_pixel_sampler(cfg, **default_args):
     """Build pixel sampler for segmentation map."""
-    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
+    warnings.warn(
+        '``build_pixel_sampler`` would be deprecated soon, please use '
+        '``mmseg.registry.TASK_UTILS.build()`` ')
+    return TASK_UTILS.build(cfg, default_args=default_args)
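Note: a hedged sketch of the two call paths after this change. The FCNHead below only exists to supply the sampler's `context` argument; its channel and class numbers are placeholders.

```python
from mmseg.models.decode_heads import FCNHead
from mmseg.registry import TASK_UTILS

# Hypothetical decode head, only needed because pixel samplers expect a
# ``context`` argument; the numbers are illustrative.
head = FCNHead(in_channels=16, channels=8, num_classes=2)

sampler_cfg = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)

# Preferred path after the refactor: build through TASK_UTILS directly.
sampler = TASK_UTILS.build(sampler_cfg, default_args=dict(context=head))
```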
@@ -5,7 +5,7 @@ import mmcv
 import numpy as np
 from PIL import Image

-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .custom import CustomDataset


@@ -8,9 +8,10 @@ import numpy as np
 import torch
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
-from mmcv.utils import Registry, build_from_cfg, digit_version
+from mmcv.utils import digit_version
 from torch.utils.data import DataLoader

+from mmseg.registry import DATASETS, TRANSFORMS
 from .samplers import DistributedSampler

 if platform.system() != 'Windows':
@@ -22,8 +23,7 @@ if platform.system() != 'Windows':
     soft_limit = min(max(4096, base_soft_limit), hard_limit)
     resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

-DATASETS = Registry('dataset')
-PIPELINES = Registry('pipeline')
+PIPELINES = TRANSFORMS


 def _concat_dataset(cfg, default_args=None):
@@ -82,7 +82,7 @@ def build_dataset(cfg, default_args=None):
             cfg.get('split', None), (list, tuple)):
         dataset = _concat_dataset(cfg, default_args)
     else:
-        dataset = build_from_cfg(cfg, DATASETS, default_args)
+        dataset = DATASETS.build(cfg, default_args=default_args)

     return dataset

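Note: `build_dataset` keeps handling the split/concat logic, but a plain dataset config now goes through `DATASETS.build`. A hedged sketch; the paths are placeholders and must point at real data for the dataset to load.

```python
from mmseg.datasets import build_dataset

cfg = dict(
    type='CityscapesDataset',
    data_root='data/cityscapes',   # placeholder path
    img_dir='leftImg8bit/train',
    ann_dir='gtFine/train',
    pipeline=[])

# Internally resolved via DATASETS.build(cfg, default_args=default_args).
dataset = build_dataset(cfg)
```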
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.

-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .custom import CustomDataset


@@ -6,7 +6,7 @@ import numpy as np
 from mmcv.utils import print_log
 from PIL import Image

-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .custom import CustomDataset


@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .custom import CustomDataset


@@ -10,8 +10,8 @@ from prettytable import PrettyTable
 from torch.utils.data import Dataset

 from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics
+from mmseg.registry import DATASETS
 from mmseg.utils import get_root_logger
-from .builder import DATASETS
 from .pipelines import Compose, LoadAnnotations


@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .cityscapes import CityscapesDataset


@@ -6,10 +6,10 @@ from itertools import chain

 import mmcv
 import numpy as np
-from mmcv.utils import build_from_cfg, print_log
+from mmcv.utils import print_log
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

-from .builder import DATASETS, PIPELINES
+from mmseg.registry import DATASETS, TRANSFORMS
 from .cityscapes import CityscapesDataset


@@ -225,7 +225,7 @@ class MultiImageMixDataset:
         for transform in pipeline:
             if isinstance(transform, dict):
                 self.pipeline_types.append(transform['type'])
-                transform = build_from_cfg(transform, PIPELINES)
+                transform = TRANSFORMS.build(transform)
                 self.pipeline.append(transform)
             else:
                 raise TypeError('pipeline must be a dict')
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from .builder import DATASETS
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from .builder import DATASETS
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
|
|
|
@ -3,8 +3,8 @@
|
|||
import mmcv
|
||||
from mmcv.utils import print_log
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from ..utils import get_root_logger
|
||||
from .builder import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .builder import DATASETS
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import mmcv
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .builder import DATASETS
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .builder import DATASETS
|
||||
from mmseg.registry import DATASETS
|
||||
from .cityscapes import CityscapesDataset
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from .builder import DATASETS
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import CustomDataset
|
||||
|
||||
|
||||
|
|
|
@@ -1,12 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import collections

-from mmcv.utils import build_from_cfg
-
-from ..builder import PIPELINES
+from mmseg.registry import TRANSFORMS


-@PIPELINES.register_module()
+@TRANSFORMS.register_module()
 class Compose(object):
     """Compose multiple transforms sequentially.

@@ -20,7 +18,7 @@ class Compose(object):
         self.transforms = []
         for transform in transforms:
             if isinstance(transform, dict):
-                transform = build_from_cfg(transform, PIPELINES)
+                transform = TRANSFORMS.build(transform)
                 self.transforms.append(transform)
             elif callable(transform):
                 self.transforms.append(transform)
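Note: `Compose` now builds each dict entry through `TRANSFORMS.build`, and the same call can be used directly. `RandomFlip` and its `prob` argument are existing mmseg transforms shown only for illustration.

```python
from mmseg.datasets.pipelines import Compose
from mmseg.registry import TRANSFORMS

# Build a single transform straight from the registry ...
flip = TRANSFORMS.build(dict(type='RandomFlip', prob=0.5))

# ... or let Compose do it, as in the hunk above.
pipeline = Compose([dict(type='RandomFlip', prob=0.5)])
```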
@ -6,7 +6,7 @@ import numpy as np
|
|||
import torch
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
|
||||
from ..builder import PIPELINES
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
def to_tensor(data):
|
||||
|
@ -34,7 +34,7 @@ def to_tensor(data):
|
|||
raise TypeError(f'type {type(data)} cannot be converted to tensor.')
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ToTensor(object):
|
||||
"""Convert some results to :obj:`torch.Tensor` by given keys.
|
||||
|
||||
|
@ -64,7 +64,7 @@ class ToTensor(object):
|
|||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ImageToTensor(object):
|
||||
"""Convert image to :obj:`torch.Tensor` by given keys.
|
||||
|
||||
|
@ -102,7 +102,7 @@ class ImageToTensor(object):
|
|||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class Transpose(object):
|
||||
"""Transpose some results by given keys.
|
||||
|
||||
|
@ -136,7 +136,7 @@ class Transpose(object):
|
|||
f'(keys={self.keys}, order={self.order})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ToDataContainer(object):
|
||||
"""Convert results to :obj:`mmcv.DataContainer` by given fields.
|
||||
|
||||
|
@ -175,7 +175,7 @@ class ToDataContainer(object):
|
|||
return self.__class__.__name__ + f'(fields={self.fields})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class DefaultFormatBundle(object):
|
||||
"""Default formatting bundle.
|
||||
|
||||
|
@ -216,7 +216,7 @@ class DefaultFormatBundle(object):
|
|||
return self.__class__.__name__
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class Collect(object):
|
||||
"""Collect data from the loader relevant to the specific task.
|
||||
|
||||
|
|
|
@ -4,10 +4,10 @@ import os.path as osp
|
|||
import mmcv
|
||||
import numpy as np
|
||||
|
||||
from ..builder import PIPELINES
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadImageFromFile(object):
|
||||
"""Load an image from file.
|
||||
|
||||
|
@ -87,7 +87,7 @@ class LoadImageFromFile(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class LoadAnnotations(object):
|
||||
"""Load annotations for semantic segmentation.
|
||||
|
||||
|
|
|
@ -3,11 +3,11 @@ import warnings
|
|||
|
||||
import mmcv
|
||||
|
||||
from ..builder import PIPELINES
|
||||
from mmseg.registry import TRANSFORMS
|
||||
from .compose import Compose
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class MultiScaleFlipAug(object):
|
||||
"""Test-time augmentation with multiple scales and flipping.
|
||||
|
||||
|
|
|
@ -6,10 +6,10 @@ import numpy as np
|
|||
from mmcv.utils import deprecated_api_warning, is_tuple_of
|
||||
from numpy import random
|
||||
|
||||
from ..builder import PIPELINES
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeToMultiple(object):
|
||||
"""Resize images & seg to multiple of divisor.
|
||||
|
||||
|
@ -66,7 +66,7 @@ class ResizeToMultiple(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class Resize(object):
|
||||
"""Resize images & seg.
|
||||
|
||||
|
@ -321,7 +321,7 @@ class Resize(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomFlip(object):
|
||||
"""Flip the image & seg.
|
||||
|
||||
|
@ -376,7 +376,7 @@ class RandomFlip(object):
|
|||
return self.__class__.__name__ + f'(prob={self.prob})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class Pad(object):
|
||||
"""Pad the image & mask.
|
||||
|
||||
|
@ -447,7 +447,7 @@ class Pad(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class Normalize(object):
|
||||
"""Normalize the image.
|
||||
|
||||
|
@ -489,7 +489,7 @@ class Normalize(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class Rerange(object):
|
||||
"""Rerange the image pixel value.
|
||||
|
||||
|
@ -535,7 +535,7 @@ class Rerange(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class CLAHE(object):
|
||||
"""Use CLAHE method to process the image.
|
||||
|
||||
|
@ -580,7 +580,7 @@ class CLAHE(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCrop(object):
|
||||
"""Random crop the image & seg.
|
||||
|
||||
|
@ -653,7 +653,7 @@ class RandomCrop(object):
|
|||
return self.__class__.__name__ + f'(crop_size={self.crop_size})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomRotate(object):
|
||||
"""Rotate the image & seg.
|
||||
|
||||
|
@ -736,7 +736,7 @@ class RandomRotate(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RGB2Gray(object):
|
||||
"""Convert RGB image to grayscale image.
|
||||
|
||||
|
@ -791,7 +791,7 @@ class RGB2Gray(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class AdjustGamma(object):
|
||||
"""Using gamma correction to process the image.
|
||||
|
||||
|
@ -827,7 +827,7 @@ class AdjustGamma(object):
|
|||
return self.__class__.__name__ + f'(gamma={self.gamma})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class SegRescale(object):
|
||||
"""Rescale semantic segmentation maps.
|
||||
|
||||
|
@ -857,7 +857,7 @@ class SegRescale(object):
|
|||
return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class PhotoMetricDistortion(object):
|
||||
"""Apply photometric distortion to image sequentially, every transformation
|
||||
is applied with a probability of 0.5. The position of random contrast is in
|
||||
|
@ -976,7 +976,7 @@ class PhotoMetricDistortion(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomCutOut(object):
|
||||
"""CutOut operation.
|
||||
|
||||
|
@ -1068,7 +1068,7 @@ class RandomCutOut(object):
|
|||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
@TRANSFORMS.register_module()
|
||||
class RandomMosaic(object):
|
||||
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
|
||||
one output image. The output image is composed of the parts from each sub-
|
||||
|
|
|
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .custom import CustomDataset


@@ -7,8 +7,10 @@ from torch.utils.data import Dataset
 from torch.utils.data import DistributedSampler as _DistributedSampler

 from mmseg.core.utils import sync_random_seed
+from mmseg.registry import DATA_SAMPLERS


+@DATA_SAMPLERS.register_module()
 class DistributedSampler(_DistributedSampler):
     """DistributedSampler inheriting from
     `torch.utils.data.DistributedSampler`.
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp

-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .custom import CustomDataset


@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp

-from .builder import DATASETS
+from mmseg.registry import DATASETS
 from .custom import CustomDataset


@ -13,8 +13,8 @@ from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
|
|||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed
|
||||
from .vit import TransformerEncoderLayer as VisionTransformerEncoderLayer
|
||||
|
||||
|
@ -227,7 +227,7 @@ class BEiTTransformerEncoderLayer(VisionTransformerEncoderLayer):
|
|||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class BEiT(BaseModule):
|
||||
"""BERT Pre-Training of Image Transformers.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES, build_backbone
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class SpatialPath(BaseModule):
|
||||
|
@ -156,7 +156,7 @@ class ContextPath(BaseModule):
|
|||
assert len(context_channels) == 3, 'Length of input channels \
|
||||
of Context Path must be 3!'
|
||||
|
||||
self.backbone = build_backbone(backbone_cfg)
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
|
||||
self.align_corners = align_corners
|
||||
self.arm16 = AttentionRefinementModule(context_channels[1],
|
||||
|
@ -262,7 +262,7 @@ class FeatureFusionModule(BaseModule):
|
|||
return x_out
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class BiSeNetV1(BaseModule):
|
||||
"""BiSeNetV1 backbone.
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
|||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
|
||||
|
@ -541,7 +541,7 @@ class BGALayer(BaseModule):
|
|||
return output
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class BiSeNetV2(BaseModule):
|
||||
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
|
||||
Real-time Semantic Segmentation.
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
|
|||
from mmcv.runner import BaseModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class GlobalContextExtractor(nn.Module):
|
||||
|
@ -183,7 +183,7 @@ class InputInjection(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class CGNet(BaseModule):
|
||||
"""CGNet backbone.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
|||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class DownsamplerBlock(BaseModule):
|
||||
|
@ -191,7 +191,7 @@ class UpsamplerBlock(BaseModule):
|
|||
return output
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class ERFNet(BaseModule):
|
||||
"""ERFNet backbone.
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from mmcv.runner import BaseModule
|
|||
|
||||
from mmseg.models.decode_heads.psp_head import PPM
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual
|
||||
|
||||
|
||||
|
@ -268,7 +268,7 @@ class FeatureFusionModule(nn.Module):
|
|||
return self.relu(out)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class FastSCNN(BaseModule):
|
||||
"""Fast-SCNN Backbone.
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmcv.runner import BaseModule, ModuleList, Sequential
|
|||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.ops import Upsample, resize
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from .resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
|
@ -214,7 +214,7 @@ class HRModule(BaseModule):
|
|||
return x_fuse
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class HRNet(BaseModule):
|
||||
"""HRNet backbone.
|
||||
|
||||
|
|
|
@ -5,11 +5,11 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES, build_backbone
|
||||
from mmseg.registry import MODELS
|
||||
from ..decode_heads.psp_head import PPM
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class ICNet(BaseModule):
|
||||
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
|
||||
|
||||
|
@ -66,7 +66,7 @@ class ICNet(BaseModule):
|
|||
]
|
||||
super(ICNet, self).__init__(init_cfg=init_cfg)
|
||||
self.align_corners = align_corners
|
||||
self.backbone = build_backbone(backbone_cfg)
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
|
||||
# Note: Default `ceil_mode` is false in nn.MaxPool2d, set
|
||||
# `ceil_mode=True` to keep information in the corner of feature map.
|
||||
|
|
|
@ -8,8 +8,8 @@ from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
|
|||
from mmcv.runner import ModuleList, _load_checkpoint
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from .beit import BEiT, BEiTAttention, BEiTTransformerEncoderLayer
|
||||
|
||||
|
||||
|
@ -42,7 +42,7 @@ class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
|
|||
self.attn = MAEAttention(**attn_cfg)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class MAE(BEiT):
|
||||
"""VisionTransformer with support for patch.
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
|
|||
trunc_normal_init)
|
||||
from mmcv.runner import BaseModule, ModuleList, Sequential
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
|
||||
|
||||
|
||||
|
@ -295,7 +295,7 @@ class TransformerEncoderLayer(BaseModule):
|
|||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class MixVisionTransformer(BaseModule):
|
||||
"""The backbone of Segformer.
|
||||
|
||||
|
|
|
@ -6,11 +6,11 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.runner import BaseModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual, make_divisible
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class MobileNetV2(BaseModule):
|
||||
"""MobileNetV2 backbone.
|
||||
|
||||
|
|
|
@ -7,11 +7,11 @@ from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
|||
from mmcv.runner import BaseModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidualV3 as InvertedResidual
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class MobileNetV3(BaseModule):
|
||||
"""MobileNetV3 backbone.
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNetV1d
|
||||
|
@ -267,7 +267,7 @@ class Bottleneck(_Bottleneck):
|
|||
return out
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class ResNeSt(ResNetV1d):
|
||||
"""ResNeSt backbone.
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer
|
|||
from mmcv.runner import BaseModule
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
|
||||
|
||||
|
@ -307,7 +307,7 @@ class Bottleneck(BaseModule):
|
|||
return out
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class ResNet(BaseModule):
|
||||
"""ResNet backbone.
|
||||
|
||||
|
@ -685,7 +685,7 @@ class ResNet(BaseModule):
|
|||
m.eval()
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class ResNetV1c(ResNet):
|
||||
"""ResNetV1c variant described in [1]_.
|
||||
|
||||
|
@ -700,7 +700,7 @@ class ResNetV1c(ResNet):
|
|||
deep_stem=True, avg_down=False, **kwargs)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class ResNetV1d(ResNet):
|
||||
"""ResNetV1d variant described in [1]_.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import math
|
|||
|
||||
from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import ResLayer
|
||||
from .resnet import Bottleneck as _Bottleneck
|
||||
from .resnet import ResNet
|
||||
|
@ -84,7 +84,7 @@ class Bottleneck(_Bottleneck):
|
|||
self.add_module(self.norm3_name, norm3)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class ResNeXt(ResNet):
|
||||
"""ResNeXt backbone.
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES, build_backbone
|
||||
from mmseg.registry import MODELS
|
||||
from .bisenetv1 import AttentionRefinementModule
|
||||
|
||||
|
||||
|
@ -184,7 +184,7 @@ class FeatureFusionModule(BaseModule):
|
|||
return x_attn + x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class STDCNet(BaseModule):
|
||||
"""This backbone is the implementation of `Rethinking BiSeNet For Real-time
|
||||
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
|
||||
|
@ -325,7 +325,7 @@ class STDCNet(BaseModule):
|
|||
return tuple(outs)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class STDCContextPathNet(BaseModule):
|
||||
"""STDCNet with Context Path. The `outs` below is a list of three feature
|
||||
maps from deep to shallow, whose height and width is from small to big,
|
||||
|
@ -371,7 +371,7 @@ class STDCContextPathNet(BaseModule):
|
|||
norm_cfg=dict(type='BN'),
|
||||
init_cfg=None):
|
||||
super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
|
||||
self.backbone = build_backbone(backbone_cfg)
|
||||
self.backbone = MODELS.build(backbone_cfg)
|
||||
self.arms = ModuleList()
|
||||
self.convs = ModuleList()
|
||||
for channels in last_in_channels:
|
||||
|
|
|
@ -15,8 +15,8 @@ from mmcv.runner import (BaseModule, CheckpointLoader, ModuleList,
|
|||
load_state_dict)
|
||||
from mmcv.utils import to_2tuple
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from ...utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils.embed import PatchEmbed, PatchMerging
|
||||
|
||||
|
||||
|
@ -462,7 +462,7 @@ class SwinBlockSequence(BaseModule):
|
|||
return x, hw_shape, x, hw_shape
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class SwinTransformer(BaseModule):
|
||||
"""Swin Transformer backbone.
|
||||
|
||||
|
|
|
@ -7,10 +7,10 @@ except ImportError:
|
|||
from mmcv.cnn.bricks.registry import NORM_LAYERS
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class TIMMBackbone(BaseModule):
|
||||
"""Wrapper to use backbones from timm library. More details can be found in
|
||||
`timm <https://github.com/rwightman/pytorch-image-models>`_ .
|
||||
|
|
|
@ -14,7 +14,7 @@ from mmcv.runner import BaseModule, ModuleList
|
|||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmseg.models.backbones.mit import EfficientMultiheadAttention
|
||||
from mmseg.models.builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils.embed import PatchEmbed
|
||||
|
||||
|
||||
|
@ -349,7 +349,7 @@ class ConditionalPositionEncoding(BaseModule):
|
|||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class PCPVT(BaseModule):
|
||||
"""The backbone of Twins-PCPVT.
|
||||
|
||||
|
@ -508,7 +508,7 @@ class PCPVT(BaseModule):
|
|||
return tuple(outputs)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class SVT(PCPVT):
|
||||
"""The backbone of Twins-SVT.
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from mmcv.runner import BaseModule
|
|||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from ..builder import BACKBONES
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import UpConvBlock
|
||||
|
||||
|
||||
|
@ -221,7 +221,7 @@ class InterpConv(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class UNet(BaseModule):
|
||||
"""UNet backbone.
|
||||
|
||||
|
|
|
@ -15,8 +15,8 @@ from torch.nn.modules.batchnorm import _BatchNorm
|
|||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import get_root_logger
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import PatchEmbed
|
||||
|
||||
|
||||
|
@ -122,7 +122,7 @@ class TransformerEncoderLayer(BaseModule):
|
|||
return x
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
@MODELS.register_module()
|
||||
class VisionTransformer(BaseModule):
|
||||
"""Vision Transformer.
|
||||
|
||||
|
|
|
@@ -1,12 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings

-from mmcv.cnn import MODELS as MMCV_MODELS
-from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
-from mmcv.utils import Registry
-
-MODELS = Registry('models', parent=MMCV_MODELS)
-ATTENTION = Registry('attention', parent=MMCV_ATTENTION)
+from mmseg.registry import MODELS

 BACKBONES = MODELS
 NECKS = MODELS
@@ -17,21 +12,29 @@ SEGMENTORS = MODELS

 def build_backbone(cfg):
     """Build backbone."""
+    warnings.warn('``build_backbone`` would be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` ')
     return BACKBONES.build(cfg)


 def build_neck(cfg):
     """Build neck."""
+    warnings.warn('``build_neck`` would be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` ')
     return NECKS.build(cfg)


 def build_head(cfg):
     """Build head."""
+    warnings.warn('``build_head`` would be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` ')
     return HEADS.build(cfg)


 def build_loss(cfg):
     """Build loss."""
+    warnings.warn('``build_loss`` would be deprecated soon, please use '
+                  '``mmseg.registry.MODELS.build()`` ')
     return LOSSES.build(cfg)


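Note: the deprecated wrappers keep working but now warn; new code is expected to call `MODELS.build` directly. A hedged sketch, with an arbitrary registered backbone and default arguments for illustration.

```python
from mmseg.models import build_backbone  # deprecated wrapper, still importable
from mmseg.registry import MODELS

cfg = dict(type='MobileNetV2')  # any registered backbone config

backbone = MODELS.build(cfg)    # preferred after this refactor
legacy = build_backbone(cfg)    # same object type, plus a deprecation warning
```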
@ -3,7 +3,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
@ -181,7 +181,7 @@ class APNB(nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class ANNHead(BaseDecodeHead):
|
||||
"""Asymmetric Non-local Neural Networks for Semantic Segmentation.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -107,7 +107,7 @@ class ACM(nn.Module):
|
|||
return z_out
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class APCHead(BaseDecodeHead):
|
||||
"""Adaptive Pyramid Context Network for Semantic Segmentation.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -50,7 +50,7 @@ class ASPPModule(nn.ModuleList):
|
|||
return aspp_outs
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class ASPPHead(BaseDecodeHead):
|
||||
"""Rethinking Atrous Convolution for Semantic Image Segmentation.
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
try:
|
||||
|
@ -10,7 +10,7 @@ except ModuleNotFoundError:
|
|||
CrissCrossAttention = None
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class CCHead(FCNHead):
|
||||
"""CCNet: Criss-Cross Attention for Semantic Segmentation.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule, Scale
|
|||
from torch import nn
|
||||
|
||||
from mmseg.core import add_prefix
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
@ -72,7 +72,7 @@ class CAM(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class DAHead(BaseDecodeHead):
|
||||
"""Dual Attention Network for Scene Segmentation.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -89,7 +89,7 @@ class DCM(nn.Module):
|
|||
return output
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class DMHead(BaseDecodeHead):
|
||||
"""Dynamic Multi-scale Filters for Semantic Segmentation.
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch
|
|||
from mmcv.cnn import NonLocal2d
|
||||
from torch import nn
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
|
@ -89,7 +89,7 @@ class DisentangledNonLocal2d(NonLocal2d):
|
|||
return output
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class DNLHead(FCNHead):
|
||||
"""Disentangled Non-Local Neural Networks.
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from mmcv.cnn import ConvModule, Linear, build_activation_layer
|
|||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -212,7 +212,7 @@ class FeatureFusionBlock(BaseModule):
|
|||
return x
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class DPTHead(BaseDecodeHead):
|
||||
"""Vision Transformers for Dense Prediction.
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -76,7 +76,7 @@ class EMAModule(nn.Module):
|
|||
return feats_recon
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class EMAHead(BaseDecodeHead):
|
||||
"""Expectation Maximization Attention Networks for Semantic Segmentation.
|
||||
|
||||
|
|
|
@@ -5,7 +5,8 @@ import torch.nn.functional as F
 from mmcv.cnn import ConvModule, build_norm_layer

 from mmseg.ops import Encoding, resize
-from ..builder import HEADS, build_loss
+from mmseg.registry import MODELS
+from ..builder import build_loss
 from .decode_head import BaseDecodeHead


@@ -59,7 +60,7 @@ class EncModule(nn.Module):
         return encoding_feat, output


-@HEADS.register_module()
+@MODELS.register_module()
 class EncHead(BaseDecodeHead):
     """Context Encoding for Semantic Segmentation.

@ -3,11 +3,11 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class FCNHead(BaseDecodeHead):
|
||||
"""Fully Convolution Networks for Semantic Segmentation.
|
||||
|
||||
|
|
|
@ -4,11 +4,11 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import Upsample, resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class FPNHead(BaseDecodeHead):
|
||||
"""Panoptic Feature Pyramid Networks.
|
||||
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
import torch
|
||||
from mmcv.cnn import ContextBlock
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class GCHead(FCNHead):
|
||||
"""GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
@ -55,7 +55,7 @@ class SelfAttentionBlock(_SelfAttentionBlock):
|
|||
return self.output_project(context)
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class ISAHead(BaseDecodeHead):
|
||||
"""Interlaced Sparse Self-Attention for Semantic Segmentation.
|
||||
|
||||
|
|
|
@@ -7,8 +7,8 @@ from mmcv.cnn.bricks.transformer import (FFN, TRANSFORMER_LAYER,
                                          MultiheadAttention,
                                          build_transformer_layer)

-from mmseg.models.builder import HEADS, build_head
 from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.registry import MODELS
 from mmseg.utils import get_root_logger


@@ -139,7 +139,7 @@ class KernelUpdator(nn.Module):
         return features


-@HEADS.register_module()
+@MODELS.register_module()
 class KernelUpdateHead(nn.Module):
     """Kernel Update Head in K-Net.

@@ -391,7 +391,7 @@ class KernelUpdateHead(nn.Module):
             self.conv_kernel_size)


-@HEADS.register_module()
+@MODELS.register_module()
 class IterativeDecodeHead(BaseDecodeHead):
     """K-Net: Towards Unified Image Segmentation.

@@ -414,7 +414,7 @@ class IterativeDecodeHead(BaseDecodeHead):
         super(BaseDecodeHead, self).__init__(**kwargs)
         assert num_stages == len(kernel_update_head)
         self.num_stages = num_stages
-        self.kernel_generate_head = build_head(kernel_generate_head)
+        self.kernel_generate_head = MODELS.build(kernel_generate_head)
         self.kernel_update_head = nn.ModuleList()
         self.align_corners = self.kernel_generate_head.align_corners
         self.num_classes = self.kernel_generate_head.num_classes
@@ -422,7 +422,7 @@ class IterativeDecodeHead(BaseDecodeHead):
         self.ignore_index = self.kernel_generate_head.ignore_index

         for head_cfg in kernel_update_head:
-            self.kernel_update_head.append(build_head(head_cfg))
+            self.kernel_update_head.append(MODELS.build(head_cfg))

     def forward(self, inputs):
         """Forward function."""
@ -5,11 +5,11 @@ from mmcv import is_tuple_of
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class LRASPPHead(BaseDecodeHead):
|
||||
"""Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
|
||||
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
import torch
|
||||
from mmcv.cnn import NonLocal2d
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class NLHead(FCNHead):
|
||||
"""Non-local Neural Networks.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
@ -82,7 +82,7 @@ class ObjectAttentionBlock(_SelfAttentionBlock):
|
|||
return output
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class OCRHead(BaseCascadeDecodeHead):
|
||||
"""Object-Contextual Representations for Semantic Segmentation.
|
||||
|
||||
|
|
|
@ -10,8 +10,8 @@ try:
|
|||
except ModuleNotFoundError:
|
||||
point_sample = None
|
||||
|
||||
from mmseg.models.builder import HEADS
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..losses import accuracy
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
@ -36,7 +36,7 @@ def calculate_uncertainty(seg_logits):
|
|||
return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class PointHead(BaseCascadeDecodeHead):
|
||||
"""A mask point head use in PointRend.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
try:
|
||||
|
@ -14,7 +14,7 @@ except ModuleNotFoundError:
|
|||
PSAMask = None
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class PSAHead(BaseDecodeHead):
|
||||
"""Point-wise Spatial Attention Network for Scene Parsing.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -59,7 +59,7 @@ class PPM(nn.ModuleList):
|
|||
return ppm_outs
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class PSPHead(BaseDecodeHead):
|
||||
"""Pyramid Scene Parsing Network.
|
||||
|
||||
|
|
|
@ -3,12 +3,12 @@ import torch
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.models.builder import HEADS
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class SegformerHead(BaseDecodeHead):
|
||||
"""The all mlp Head of segformer.
|
||||
|
||||
|
|
|
@ -8,11 +8,11 @@ from mmcv.cnn.utils.weight_init import (constant_init, trunc_normal_,
|
|||
from mmcv.runner import ModuleList
|
||||
|
||||
from mmseg.models.backbones.vit import TransformerEncoderLayer
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class SegmenterMaskTransformerHead(BaseDecodeHead):
|
||||
"""Segmenter: Transformer for Semantic Segmentation.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .aspp_head import ASPPHead, ASPPModule
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ class DepthwiseSeparableASPPModule(ASPPModule):
|
|||
act_cfg=self.act_cfg)
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class DepthwiseSeparableASPPHead(ASPPHead):
|
||||
"""Encoder-Decoder with Atrous Separable Convolution for Semantic Image
|
||||
Segmentation.
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmcv.cnn import DepthwiseSeparableConvModule
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class DepthwiseSeparableFCNHead(FCNHead):
|
||||
"""Depthwise-Separable Fully Convolutional Network for Semantic
|
||||
Segmentation.
|
||||
|
|
|
@ -4,11 +4,11 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class SETRMLAHead(BaseDecodeHead):
|
||||
"""Multi level feature aggretation head of SETR.
|
||||
|
||||
|
|
|
@ -3,11 +3,11 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class SETRUPHead(BaseDecodeHead):
|
||||
"""Naive upsampling head and Progressive upsampling head of SETR.
|
||||
|
||||
|
|
|
@ -2,11 +2,11 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .fcn_head import FCNHead
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class STDCHead(FCNHead):
|
||||
"""This head is the implementation of `Rethinking BiSeNet For Real-time
|
||||
Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.
|
||||
|
|
|
@ -4,12 +4,12 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import HEADS
|
||||
from mmseg.registry import MODELS
|
||||
from .decode_head import BaseDecodeHead
|
||||
from .psp_head import PPM
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
@MODELS.register_module()
|
||||
class UPerHead(BaseDecodeHead):
|
||||
"""Unified Perceptual Parsing for Scene Understanding.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
|
@ -193,7 +193,7 @@ def mask_cross_entropy(pred,
|
|||
pred_slice, target, weight=class_weight, reduction='mean')[None]
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
@MODELS.register_module()
|
||||
class CrossEntropyLoss(nn.Module):
|
||||
"""CrossEntropyLoss.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weighted_loss
|
||||
|
||||
|
||||
|
@ -47,7 +47,7 @@ def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwards):
|
|||
return 1 - num / den
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
@MODELS.register_module()
|
||||
class DiceLoss(nn.Module):
|
||||
"""DiceLoss.
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
|
||||
|
||||
from ..builder import LOSSES
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import weight_reduce_loss
|
||||
|
||||
|
||||
|
@ -133,7 +133,7 @@ def sigmoid_focal_loss(pred,
|
|||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
@MODELS.register_module()
|
||||
class FocalLoss(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import LOSSES
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
|
@ -222,7 +222,7 @@ def lovasz_softmax(probs,
|
|||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module()
|
||||
@MODELS.register_module()
|
||||
class LovaszLoss(nn.Module):
|
||||
"""LovaszLoss.
|
||||
|
||||
|
|
|
@ -2,10 +2,10 @@
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
|
||||
from ..builder import NECKS
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
@MODELS.register_module()
|
||||
class Feature2Pyramid(nn.Module):
|
||||
"""Feature2Pyramid.
|
||||
|
||||
|
|
|
@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.runner import BaseModule, auto_fp16
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import NECKS
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
@MODELS.register_module()
|
||||
class FPN(BaseModule):
|
||||
"""Feature Pyramid Network.
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from mmcv.cnn import ConvModule
|
|||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import NECKS
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class CascadeFeatureFusion(BaseModule):
|
||||
|
@ -77,7 +77,7 @@ class CascadeFeatureFusion(BaseModule):
|
|||
return x, x_low
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
@MODELS.register_module()
|
||||
class ICNeck(BaseModule):
|
||||
"""ICNet for Real-Time Semantic Segmentation on High-Resolution Images.
|
||||
|
||||
|
|
|
@ -5,10 +5,10 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
|||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import NECKS
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
@MODELS.register_module()
|
||||
class JPU(BaseModule):
|
||||
"""FastFCN: Rethinking Dilated Convolution in the Backbone
|
||||
for Semantic Segmentation.
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
|
||||
from ..builder import NECKS
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
class MLAModule(nn.Module):
|
||||
|
@ -59,7 +59,7 @@ class MLAModule(nn.Module):
|
|||
return tuple(out_list)
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
@MODELS.register_module()
|
||||
class MLANeck(nn.Module):
|
||||
"""Multi-level Feature Aggregation.
|
||||
|
||||
|
|
|
@ -3,10 +3,10 @@ import torch.nn as nn
|
|||
from mmcv.cnn import ConvModule, xavier_init
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import NECKS
|
||||
from mmseg.registry import MODELS
|
||||
|
||||
|
||||
@NECKS.register_module()
|
||||
@MODELS.register_module()
|
||||
class MultiLevelNeck(nn.Module):
|
||||
"""MultiLevelNeck.
|
||||
|
||||
|
|
|
@@ -3,12 +3,11 @@ from torch import nn

 from mmseg.core import add_prefix
 from mmseg.ops import resize
-from .. import builder
-from ..builder import SEGMENTORS
+from mmseg.registry import MODELS
 from .encoder_decoder import EncoderDecoder


-@SEGMENTORS.register_module()
+@MODELS.register_module()
 class CascadeEncoderDecoder(EncoderDecoder):
     """Cascade Encoder Decoder segmentors.

@@ -44,7 +43,7 @@ class CascadeEncoderDecoder(EncoderDecoder):
         assert len(decode_head) == self.num_stages
         self.decode_head = nn.ModuleList()
         for i in range(self.num_stages):
-            self.decode_head.append(builder.build_head(decode_head[i]))
+            self.decode_head.append(MODELS.build(decode_head[i]))
         self.align_corners = self.decode_head[-1].align_corners
         self.num_classes = self.decode_head[-1].num_classes

@@ -5,12 +5,11 @@ import torch.nn.functional as F

 from mmseg.core import add_prefix
 from mmseg.ops import resize
-from .. import builder
-from ..builder import SEGMENTORS
+from mmseg.registry import MODELS
 from .base import BaseSegmentor


-@SEGMENTORS.register_module()
+@MODELS.register_module()
 class EncoderDecoder(BaseSegmentor):
     """Encoder Decoder segmentors.

@@ -33,9 +32,9 @@ class EncoderDecoder(BaseSegmentor):
             assert backbone.get('pretrained') is None, \
                 'both backbone and segmentor set pretrained weight'
             backbone.pretrained = pretrained
-        self.backbone = builder.build_backbone(backbone)
+        self.backbone = MODELS.build(backbone)
         if neck is not None:
-            self.neck = builder.build_neck(neck)
+            self.neck = MODELS.build(neck)
         self._init_decode_head(decode_head)
         self._init_auxiliary_head(auxiliary_head)

@@ -46,7 +45,7 @@ class EncoderDecoder(BaseSegmentor):

     def _init_decode_head(self, decode_head):
         """Initialize ``decode_head``"""
-        self.decode_head = builder.build_head(decode_head)
+        self.decode_head = MODELS.build(decode_head)
         self.align_corners = self.decode_head.align_corners
         self.num_classes = self.decode_head.num_classes

@@ -56,9 +55,9 @@ class EncoderDecoder(BaseSegmentor):
         if isinstance(auxiliary_head, list):
             self.auxiliary_head = nn.ModuleList()
             for head_cfg in auxiliary_head:
-                self.auxiliary_head.append(builder.build_head(head_cfg))
+                self.auxiliary_head.append(MODELS.build(head_cfg))
         else:
-            self.auxiliary_head = builder.build_head(auxiliary_head)
+            self.auxiliary_head = MODELS.build(auxiliary_head)

     def extract_feat(self, img):
         """Extract features from images."""
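Note: with backbones, necks, heads and segmentors all in the single MODELS registry, a whole segmentor can be assembled through nested configs. A hedged sketch; the channel and class numbers below are illustrative and not taken from this merge request.

```python
from mmseg.registry import MODELS

segmentor = MODELS.build(
    dict(
        type='EncoderDecoder',
        backbone=dict(
            type='ResNetV1c', depth=18, num_stages=4,
            out_indices=(0, 1, 2, 3)),
        decode_head=dict(
            type='FCNHead',
            in_channels=512,   # ResNet-18 stage-4 channels (assumed)
            in_index=3,
            channels=128,
            num_classes=19)))
```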
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
+                       MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS,
+                       OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS,
+                       RUNNERS, TASK_UTILS, TRANSFORMS, VISBACKENDS,
+                       VISUALIZERS, WEIGHT_INITIALIZERS)
+
+__all__ = [
+    'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS',
+    'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS',
+    'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', 'METRICS',
+    'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS'
+]
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""MMSegmentation provides 17 registry nodes to support using modules across
+projects. Each node is a child of the root registry in MMEngine.
+
+More details can be found at
+https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
+"""
+
+from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
+from mmengine.registry import DATASETS as MMENGINE_DATASETS
+from mmengine.registry import HOOKS as MMENGINE_HOOKS
+from mmengine.registry import LOOPS as MMENGINE_LOOPS
+from mmengine.registry import METRICS as MMENGINE_METRICS
+from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
+from mmengine.registry import MODELS as MMENGINE_MODELS
+from mmengine.registry import \
+    OPTIMIZER_CONSTRUCTORS as MMENGINE_OPTIMIZER_CONSTRUCTORS
+from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
+from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
+from mmengine.registry import \
+    RUNNER_CONSTRUCTORS as MMENGINE_RUNNER_CONSTRUCTORS
+from mmengine.registry import RUNNERS as MMENGINE_RUNNERS
+from mmengine.registry import TASK_UTILS as MMENGINE_TASK_UTILS
+from mmengine.registry import TRANSFORMS as MMENGINE_TRANSFORMS
+from mmengine.registry import VISBACKENDS as MMENGINE_VISBACKENDS
+from mmengine.registry import VISUALIZERS as MMENGINE_VISUALIZERS
+from mmengine.registry import \
+    WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS
+from mmengine.registry import Registry
+
+# manage all kinds of runners like `EpochBasedRunner` and `IterBasedRunner`
+RUNNERS = Registry('runner', parent=MMENGINE_RUNNERS)
+# manage runner constructors that define how to initialize runners
+RUNNER_CONSTRUCTORS = Registry(
+    'runner constructor', parent=MMENGINE_RUNNER_CONSTRUCTORS)
+# manage all kinds of loops like `EpochBasedTrainLoop`
+LOOPS = Registry('loop', parent=MMENGINE_LOOPS)
+# manage all kinds of hooks like `CheckpointHook`
+HOOKS = Registry('hook', parent=MMENGINE_HOOKS)
+
+# manage data-related modules
+DATASETS = Registry('dataset', parent=MMENGINE_DATASETS)
+DATA_SAMPLERS = Registry('data sampler', parent=MMENGINE_DATA_SAMPLERS)
+TRANSFORMS = Registry('transform', parent=MMENGINE_TRANSFORMS)
+
+# manage all kinds of modules inheriting `nn.Module`
+MODELS = Registry('model', parent=MMENGINE_MODELS)
+# manage all kinds of model wrappers like 'MMDistributedDataParallel'
+MODEL_WRAPPERS = Registry('model_wrapper', parent=MMENGINE_MODEL_WRAPPERS)
+# manage all kinds of weight initialization modules like `Uniform`
+WEIGHT_INITIALIZERS = Registry(
+    'weight initializer', parent=MMENGINE_WEIGHT_INITIALIZERS)
+
+# manage all kinds of optimizers like `SGD` and `Adam`
+OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
+# manage constructors that customize the optimization hyperparameters.
+OPTIMIZER_CONSTRUCTORS = Registry(
+    'optimizer constructor', parent=MMENGINE_OPTIMIZER_CONSTRUCTORS)
+# manage all kinds of parameter schedulers like `MultiStepLR`
+PARAM_SCHEDULERS = Registry(
+    'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
+# manage all kinds of metrics
+METRICS = Registry('metric', parent=MMENGINE_METRICS)
+
+# manage task-specific modules like ohem pixel sampler
+TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
+
+# manage visualizer
+VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
+# manage visualizer backend
+VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
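Note: every node above is a child of the corresponding MMEngine root registry, which is what allows types registered only upstream to be found from mmseg. A self-contained sketch of that fallback using throwaway registries (all names are illustrative):

```python
from mmengine.registry import Registry

ROOT = Registry('widget')
CHILD = Registry('widget', parent=ROOT, scope='demo')


@ROOT.register_module()
class OnlyInRoot:
    pass


# 'OnlyInRoot' is not registered in CHILD itself; the lookup falls back to the
# parent, mirroring how e.g. mmseg's MODELS defers to MMEngine's MODELS.
obj = CHILD.build(dict(type='OnlyInRoot'))
assert isinstance(obj, OnlyInRoot)
```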
@@ -2,10 +2,10 @@
 import pytest
 import torch
 import torch.nn as nn
-from mmcv.runner import DefaultOptimizerConstructor
+from mmengine.optim import DefaultOptimizerConstructor

-from mmseg.core.builder import (OPTIMIZER_BUILDERS, build_optimizer,
-                                build_optimizer_constructor)
+from mmseg.core.builder import build_optimizer, build_optimizer_constructor
+from mmseg.registry import OPTIMIZER_CONSTRUCTORS


 class ExampleModel(nn.Module):
@@ -35,7 +35,7 @@ def test_build_optimizer_constructor():
     # Test whether optimizer constructor can be built from parent.
     assert type(optim_constructor) is DefaultOptimizerConstructor

-    @OPTIMIZER_BUILDERS.register_module()
+    @OPTIMIZER_CONSTRUCTORS.register_module()
     class MyOptimizerConstructor(DefaultOptimizerConstructor):
         pass
