[Fix] Fix repo based on refactoring standard (#1869)

* [Fix] Fix repo based on refactoring standard

* fix ut
Miao Zheng 2022-08-05 20:37:35 +08:00 committed by MeowZheng
parent bfe0fbe04d
commit e0499d5a77
83 changed files with 264 additions and 280 deletions

View File

@@ -55,8 +55,8 @@ mmcv_max_version = digit_version(MMCV_MAX)
mmcv_version = digit_version(mmcv.__version__)
assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.'
__all__ = ['__version__', 'version_info', 'digit_version']
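The version guard above switches the upper bound from inclusive (`<=`) to exclusive (`<`). A minimal sketch of the behavioural difference, assuming only that `digit_version` is the comparison helper exported in `__all__` above (the version strings are illustrative, not the real bounds):

```python
from mmseg import digit_version

MMCV_MIN, MMCV_MAX = '2.0.0', '2.1.0'   # illustrative bounds only
installed = '2.1.0'                     # exactly equal to the upper bound

# Old check (inclusive upper bound) would accept this installation:
print(digit_version(MMCV_MIN) <= digit_version(installed) <= digit_version(MMCV_MAX))  # True
# New check (exclusive upper bound) rejects it:
print(digit_version(MMCV_MIN) <= digit_version(installed) < digit_version(MMCV_MAX))   # False
```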

View File

@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ade import ADE20KDataset
from .basesegdataset import BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
from .custom import BaseSegDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import MultiImageMixDataset
from .drive import DRIVEDataset
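Many of the hunks that follow repeat one rename: `BaseSegDataset` moves from `mmseg/datasets/custom.py` to `mmseg/datasets/basesegdataset.py`, and every dataset switches its relative import accordingly. Because the class is still re-exported from the package `__init__.py` shown above, package-level imports are unaffected; a short sketch:

```python
# Unaffected by the rename: the package-level re-export shown in the hunk above.
from mmseg.datasets import BaseSegDataset

# Inside mmseg itself the relative import changes:
#   old: from .custom import BaseSegDataset
#   new: from .basesegdataset import BaseSegDataset
```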

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -2,7 +2,7 @@
import os.path as osp
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -1,50 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import collections
from mmseg.registry import TRANSFORMS
@TRANSFORMS.register_module()
class Compose(object):
"""Compose multiple transforms sequentially.
Args:
transforms (Sequence[dict | callable]): Sequence of transform object or
config dict to be composed.
"""
def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence)
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = TRANSFORMS.build(transform)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')
def __call__(self, data):
"""Call function to apply transforms sequentially.
Args:
data (dict): A result dict contains the data to transform.
Returns:
dict: Transformed data.
"""
for t in self.transforms:
data = t(data)
if data is None:
return None
return data
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += f' {t}'
format_string += '\n)'
return format_string
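The file above (mmseg's own `Compose`) is deleted outright; later in this commit the test suite imports `Compose` from `mmengine.dataset` instead, which keeps the same dict-or-callable contract. A hedged usage sketch with mmengine's class (only a plain callable is used so the snippet stays self-contained; dict entries would be built through the `TRANSFORMS` registry exactly as in the deleted code):

```python
from mmengine.dataset import Compose

# Plain callables are applied as-is; dict entries are resolved via TRANSFORMS.
pipeline = Compose([
    lambda results: {**results, 'visited': True},
])
print(pipeline(dict(img_path='demo.png')))  # {'img_path': 'demo.png', 'visited': True}
```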

View File

@@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.parallel import DataContainer as DC
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmengine.data import PixelData
@@ -87,114 +86,3 @@ class PackSegInputs(BaseTransform):
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str
@TRANSFORMS.register_module()
class ImageToTensor(object):
"""Convert image to :obj:`torch.Tensor` by given keys.
The dimension order of input image is (H, W, C). The pipeline will convert
it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
(1, H, W).
Args:
keys (Sequence[str]): Key of images to be converted to Tensor.
"""
def __init__(self, keys):
self.keys = keys
def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor` and
transpose the channel order.
Args:
results (dict): Result dict contains the image data to convert.
Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
"""
for key in self.keys:
img = results[key]
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
results[key] = to_tensor(img.transpose(2, 0, 1))
return results
def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
@TRANSFORMS.register_module()
class Transpose(object):
"""Transpose some results by given keys.
Args:
keys (Sequence[str]): Keys of results to be transposed.
order (Sequence[int]): Order of transpose.
"""
def __init__(self, keys, order):
self.keys = keys
self.order = order
def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor` and
transpose the channel order.
Args:
results (dict): Result dict contains the image data to convert.
Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
"""
for key in self.keys:
results[key] = results[key].transpose(self.order)
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(keys={self.keys}, order={self.order})'
@TRANSFORMS.register_module()
class ToDataContainer(object):
"""Convert results to :obj:`mmcv.DataContainer` by given fields.
Args:
fields (Sequence[dict]): Each field is a dict like
``dict(key='xxx', **kwargs)``. The ``key`` in result will
be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
Default: ``(dict(key='img', stack=True),
dict(key='gt_semantic_seg'))``.
"""
def __init__(self,
fields=(dict(key='img',
stack=True), dict(key='gt_semantic_seg'))):
self.fields = fields
def __call__(self, results):
"""Call function to convert data in results to
:obj:`mmcv.DataContainer`.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data converted to
:obj:`mmcv.DataContainer`.
"""
for field in self.fields:
field = field.copy()
key = field.pop('key')
results[key] = DC(results[key], **field)
return results
def __repr__(self):
return self.__class__.__name__ + f'(fields={self.fields})'
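With `DataContainer` gone, the legacy formatting transforms above (`ImageToTensor`, `Transpose`, `ToDataContainer`) are removed; `PackSegInputs`, kept at the top of this file, plus `mmcv.transforms.to_tensor` cover the remaining conversion work. A hedged sketch of the HWC-to-CHW conversion the deleted `ImageToTensor` used to perform, using only the `to_tensor` helper already imported in this hunk:

```python
import numpy as np
from mmcv.transforms import to_tensor  # same helper imported near the top of this file

img = np.zeros((512, 512, 3), dtype=np.uint8)   # (H, W, C) image
chw = to_tensor(img.transpose(2, 0, 1))         # -> torch.Tensor of shape (C, H, W)
print(chw.shape)                                # torch.Size([3, 512, 512])
```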

View File

@@ -2,7 +2,7 @@
import os.path as osp
from mmseg.registry import DATASETS
from .custom import BaseSegDataset
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
class SpatialPath(BaseModule):

View File

@@ -5,8 +5,8 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
build_activation_layer, build_norm_layer)
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
class DetailBranch(BaseModule):

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
class DownsamplerBlock(BaseModule):

View File

@@ -5,9 +5,8 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule
from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import InvertedResidual
from ..utils import InvertedResidual, resize
class LearningToDownsample(nn.Module):

View File

@@ -6,8 +6,8 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule, ModuleList, Sequential
from mmseg.ops import Upsample, resize
from mmseg.registry import MODELS
from ..utils import Upsample, resize
from .resnet import BasicBlock, Bottleneck

View File

@@ -4,9 +4,9 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..decode_heads.psp_head import PPM
from ..utils import resize
@MODELS.register_module()

View File

@@ -6,8 +6,8 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .bisenetv1 import AttentionRefinementModule

View File

@@ -8,9 +8,8 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
from mmcv.utils.parrots_wrapper import _BatchNorm
from mmengine.model import BaseModule
from mmseg.ops import Upsample
from mmseg.registry import MODELS
from ..utils import UpConvBlock
from ..utils import UpConvBlock, Upsample
class BasicConvBlock(nn.Module):

View File

@@ -15,9 +15,8 @@ from mmengine.model import BaseModule, ModuleList
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import PatchEmbed
from ..utils import PatchEmbed, resize
class TransformerEncoderLayer(BaseModule):

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

View File

@@ -3,8 +3,8 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

View File

@@ -7,11 +7,11 @@ import torch.nn as nn
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.ops import resize
from mmseg.structures import build_pixel_sampler
from mmseg.utils import ConfigType, SampleList
from ..builder import build_loss
from ..losses import accuracy
from ..utils import resize
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):

View File

@@ -6,8 +6,8 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

View File

@@ -7,10 +7,10 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer
from torch import Tensor
from mmseg.ops import Encoding, resize
from mmseg.registry import MODELS
from mmseg.utils import ConfigType, SampleList
from ..builder import build_loss
from ..utils import Encoding, resize
from .decode_head import BaseDecodeHead

View File

@@ -3,8 +3,8 @@ import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import Upsample, resize
from mmseg.registry import MODELS
from ..utils import Upsample, resize
from .decode_head import BaseDecodeHead

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
from mmcv import is_tuple_of
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

View File

@@ -4,9 +4,9 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead

View File

@@ -12,10 +12,10 @@ except ModuleNotFoundError:
from typing import List
from mmseg.ops import resize
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..losses import accuracy
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
try:

View File

@@ -3,8 +3,8 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()

View File

@@ -3,8 +3,8 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .aspp_head import ASPPHead, ASPPModule

View File

@@ -3,8 +3,8 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import Upsample
from mmseg.registry import MODELS
from ..utils import Upsample
from .decode_head import BaseDecodeHead

View File

@@ -2,8 +2,8 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmseg.ops import Upsample
from mmseg.registry import MODELS
from ..utils import Upsample
from .decode_head import BaseDecodeHead

View File

@@ -3,8 +3,8 @@ import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
from .decode_head import BaseDecodeHead
from .psp_head import PPM

View File

@@ -4,8 +4,8 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()

View File

@@ -3,8 +3,8 @@ import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
class CascadeFeatureFusion(BaseModule):

View File

@@ -4,8 +4,8 @@ import torch.nn as nn
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmengine.model import BaseModule
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()

View File

@@ -2,8 +2,8 @@
import torch.nn as nn
from mmcv.cnn import ConvModule, xavier_init
from mmseg.ops import resize
from mmseg.registry import MODELS
from ..utils import resize
@MODELS.register_module()

View File

@@ -6,10 +6,10 @@ from mmengine.data import PixelData
from mmengine.model import BaseModel
from torch import Tensor
from mmseg.ops import resize
from mmseg.structures import SegDataSample
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
OptSampleList, SampleList)
from ..utils import resize
class BaseSegmentor(BaseModel, metaclass=ABCMeta):

View File

@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .embed import PatchEmbed
from .encoding import Encoding
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .res_layer import ResLayer
@@ -8,9 +9,11 @@ from .self_attention_block import SelfAttentionBlock
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
nlc_to_nchw)
from .up_conv_block import UpConvBlock
from .wrappers import Upsample, resize
__all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc'
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
'Upsample', 'resize'
]
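This `__init__.py` update is the receiving end of the `mmseg.ops` removal: `resize`, `Upsample` and `Encoding` are now exposed from `mmseg.models.utils` (the old ops package `__init__` is deleted in a later hunk), which is why every backbone/head hunk in this commit swaps `from mmseg.ops import ...` for `from ..utils import ...`. For external code the change looks like this:

```python
# Old import path (removed by this commit):
#   from mmseg.ops import Encoding, Upsample, resize
# New import path after this refactor:
from mmseg.models.utils import Encoding, Upsample, resize
```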

View File

@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
from torch.nn import functional as F


class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels, num_codes):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels)**0.5)
        # [num_codes, channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # [num_codes]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        num_codes, channels = codewords.size()
        batch_size = x.size(0)
        reshaped_scale = scale.view((1, 1, num_codes))
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        scaled_l2_norm = reshaped_scale * (
            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
        batch_size = x.size(0)
        expanded_x = x.unsqueeze(2).expand(
            (batch_size, x.size(1), num_codes, channels))
        encoded_feat = (assignment_weights.unsqueeze(3) *
                        (expanded_x - reshaped_codewords)).sum(dim=1)
        return encoded_feat

    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.channels
        # [batch_size, channels, height, width]
        batch_size = x.size(0)
        # [batch_size, height x width, channels]
        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, height x width, num_codes]
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
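A minimal shape check for the `Encoding` layer above, illustrating the documented (batch, channels, H, W) to (batch, num_codes, channels) mapping; the sizes are arbitrary:

```python
import torch

from mmseg.models.utils import Encoding  # export added earlier in this commit

enc = Encoding(channels=16, num_codes=32)
feat = torch.randn(2, 16, 8, 8)          # (batch, channels, H, W)
out = enc(feat)
print(out.shape)                          # torch.Size([2, 32, 16]) == (batch, num_codes, channels)
```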

View File

@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
import torch.nn.functional as F


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            # warn only when the output is larger than the input
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class Upsample(nn.Module):

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        if not self.size:
            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
        else:
            size = self.size
        return resize(x, size, None, self.mode, self.align_corners)
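A hedged usage sketch for the `resize` wrapper above; it forwards to `F.interpolate` after the optional alignment warning:

```python
import torch

from mmseg.models.utils import resize  # export added earlier in this commit

x = torch.randn(1, 3, 32, 32)
y = resize(x, size=(64, 64), mode='bilinear', align_corners=False)
print(y.shape)  # torch.Size([1, 3, 64, 64])
```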

View File

@@ -1,5 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .encoding import Encoding
from .wrappers import Upsample, resize
__all__ = ['Upsample', 'resize', 'Encoding']

View File

@@ -8,13 +8,16 @@ https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
from mmengine.registry import HOOKS as MMENGINE_HOOKS
from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import METRICS as MMENGINE_METRICS
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
from mmengine.registry import MODELS as MMENGINE_MODELS
from mmengine.registry import \
OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS
from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
from mmengine.registry import \
@@ -53,14 +56,20 @@ WEIGHT_INITIALIZERS = Registry(
# manage all kinds of optimizers like `SGD` and `Adam`
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
# manage optimizer wrapper
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
# manage constructors that customize the optimization hyperparameters.
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
'optimizer wrapper constructor',
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry(
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
# manage all kinds of metrics
METRICS = Registry('metric', parent=MMENGINE_METRICS)
# manage evaluator
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
# manage task-specific modules like ohem pixel sampler
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
@@ -69,3 +78,6 @@ TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
# manage visualizer backend
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
# manage logprocessor
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
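The new registries above (`OPTIM_WRAPPERS`, `EVALUATOR`, `METRICS`, `LOG_PROCESSORS`) are children of their mmengine counterparts, so project-side classes register locally while configs still resolve through the parent. A hedged sketch with a hypothetical metric (`MyIoU` is not part of this commit):

```python
from mmengine.evaluator import BaseMetric

from mmseg.registry import METRICS


@METRICS.register_module()
class MyIoU(BaseMetric):  # hypothetical example, for illustration only
    def process(self, data_batch, data_samples):
        pass                     # would collect per-batch results into self.results

    def compute_metrics(self, results):
        return dict(mIoU=0.0)    # placeholder value
```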

View File

@@ -1,6 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
__version__ = '0.24.1'
__version__ = '1.0.0rc0'
def parse_version_info(version_str):

View File

@@ -1,2 +1,2 @@
mmcls>=0.20.1
mmcv-full>=1.4.4,<=1.6.0
mmcv-full>=2.0.0,<=2.0.0rc0

View File

@@ -3,10 +3,13 @@ import glob
import os
from os.path import dirname, exists, isdir, join, relpath
from mmcv import Config
import numpy as np
from mmengine import Config
from mmengine.dataset import Compose
from torch import nn
from mmseg.models import build_segmentor
from mmseg.utils import register_all_modules
def _get_config_directory():
@@ -60,72 +63,69 @@ def test_config_build_segmentor():
_check_decode_head(head_config, segmentor.decode_head)
# def test_config_data_pipeline():
# """Test whether the data pipeline is valid and can process corner cases.
def test_config_data_pipeline():
"""Test whether the data pipeline is valid and can process corner cases.
# CommandLine:
# xdoctest -m tests/test_config.py test_config_build_data_pipeline
# """
# import numpy as np
# from mmcv import Config
CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
# from mmseg.datasets.transforms import Compose
register_all_modules()
config_dpath = _get_config_directory()
print('Found config_dpath = {!r}'.format(config_dpath))
# config_dpath = _get_config_directory()
# print('Found config_dpath = {!r}'.format(config_dpath))
import glob
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
config_names = [relpath(p, config_dpath) for p in config_fpaths]
# import glob
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
print('Using {} config files'.format(len(config_names)))
# print('Using {} config files'.format(len(config_names)))
for config_fname in config_names:
config_fpath = join(config_dpath, config_fname)
print(
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
config_mod = Config.fromfile(config_fpath)
# for config_fname in config_names:
# config_fpath = join(config_dpath, config_fname)
# print(
# 'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
# config_mod = Config.fromfile(config_fpath)
# remove loading pipeline
load_img_pipeline = config_mod.train_pipeline.pop(0)
to_float32 = load_img_pipeline.get('to_float32', False)
config_mod.train_pipeline.pop(0)
config_mod.test_pipeline.pop(0)
# remove loading annotation in test pipeline
config_mod.test_pipeline.pop(1)
# # remove loading pipeline
# load_img_pipeline = config_mod.train_pipeline.pop(0)
# to_float32 = load_img_pipeline.get('to_float32', False)
# config_mod.train_pipeline.pop(0)
# config_mod.test_pipeline.pop(0)
# # remove loading annotation in test pipeline
# config_mod.test_pipeline.pop(1)
train_pipeline = Compose(config_mod.train_pipeline)
test_pipeline = Compose(config_mod.test_pipeline)
# train_pipeline = Compose(config_mod.train_pipeline)
# test_pipeline = Compose(config_mod.test_pipeline)
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
if to_float32:
img = img.astype(np.float32)
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
# if to_float32:
# img = img.astype(np.float32)
# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_seg_map=seg)
results['seg_fields'] = ['gt_seg_map']
# results = dict(
# filename='test_img.png',
# ori_filename='test_img.png',
# img=img,
# img_shape=img.shape,
# ori_shape=img.shape,
# gt_seg_map=seg)
# results['seg_fields'] = ['gt_seg_map']
print('Test training data pipeline: \n{!r}'.format(train_pipeline))
output_results = train_pipeline(results)
assert output_results is not None
# print('Test training data pipeline: \n{!r}'.format(train_pipeline))
# output_results = train_pipeline(results)
# assert output_results is not None
# results = dict(
# filename='test_img.png',
# ori_filename='test_img.png',
# img=img,
# img_shape=img.shape,
# ori_shape=img.shape,
# )
# print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
# output_results = test_pipeline(results)
# assert output_results is not None
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
)
print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
output_results = test_pipeline(results)
assert output_results is not None
def _check_decode_head(decode_head_cfg, decode_head):

View File

@@ -1,4 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import all_zeros, check_norm_state, is_block, is_norm
__all__ = ['is_norm', 'is_block', 'all_zeros', 'check_norm_state']

View File

@@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock)
from mmseg.ops import Upsample
from mmseg.models.utils import Upsample
from .utils import check_norm_state

View File

@@ -15,6 +15,14 @@ def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('config', help='train config file path')
parser.add_argument('--work-dir', help='the dir to save logs and models')
parser.add_argument(
'--resume',
nargs='?',
type=str,
const='auto',
help='If a checkpoint path is specified, resume from it; if the '
'flag is given without a path, try to auto-resume from the '
'latest checkpoint in the work directory.')
parser.add_argument(
'--amp',
action='store_true',
@@ -80,6 +88,14 @@ def main():
cfg.optim_wrapper.type = 'AmpOptimWrapper'
cfg.optim_wrapper.loss_scale = 'dynamic'
# resume training
if args.resume == 'auto':
cfg.resume = True
cfg.load_from = None
elif args.resume is not None:
cfg.resume = True
cfg.load_from = args.resume
# build the runner from config
runner = Runner.from_cfg(cfg)
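The new `--resume` flag takes an optional value (`nargs='?'`, `const='auto'`): a bare `--resume` auto-resumes from the latest checkpoint in the work directory, while `--resume <path>` resumes from that checkpoint. A small sketch mirroring the branch added to `main()` above (the checkpoint path is illustrative):

```python
def resolve_resume(resume_arg):
    """Mirror of the resume branch in main(), for illustration only."""
    if resume_arg == 'auto':              # bare --resume
        return dict(resume=True, load_from=None)
    elif resume_arg is not None:          # --resume <checkpoint path>
        return dict(resume=True, load_from=resume_arg)
    return {}                             # flag absent: config left untouched


print(resolve_resume('auto'))                      # {'resume': True, 'load_from': None}
print(resolve_resume('work_dirs/iter_4000.pth'))   # {'resume': True, 'load_from': 'work_dirs/iter_4000.pth'}
```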