mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix repo based on refactoring standard (#1869)
* [Fix] Fix repo based on refactory standard * fix ut
This commit is contained in:
parent
bfe0fbe04d
commit
e0499d5a77
@ -55,8 +55,8 @@ mmcv_max_version = digit_version(MMCV_MAX)
|
||||
mmcv_version = digit_version(mmcv.__version__)
|
||||
|
||||
|
||||
assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \
|
||||
assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \
|
||||
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
|
||||
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
|
||||
f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.'
|
||||
|
||||
__all__ = ['__version__', 'version_info', 'digit_version']
|
||||
|
@ -1,9 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .ade import ADE20KDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
from .chase_db1 import ChaseDB1Dataset
|
||||
from .cityscapes import CityscapesDataset
|
||||
from .coco_stuff import COCOStuffDataset
|
||||
from .custom import BaseSegDataset
|
||||
from .dark_zurich import DarkZurichDataset
|
||||
from .dataset_wrappers import MultiImageMixDataset
|
||||
from .drive import DRIVEDataset
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -2,7 +2,7 @@
|
||||
import os.path as osp
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -1,50 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import collections
|
||||
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Compose(object):
|
||||
"""Compose multiple transforms sequentially.
|
||||
|
||||
Args:
|
||||
transforms (Sequence[dict | callable]): Sequence of transform object or
|
||||
config dict to be composed.
|
||||
"""
|
||||
|
||||
def __init__(self, transforms):
|
||||
assert isinstance(transforms, collections.abc.Sequence)
|
||||
self.transforms = []
|
||||
for transform in transforms:
|
||||
if isinstance(transform, dict):
|
||||
transform = TRANSFORMS.build(transform)
|
||||
self.transforms.append(transform)
|
||||
elif callable(transform):
|
||||
self.transforms.append(transform)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict')
|
||||
|
||||
def __call__(self, data):
|
||||
"""Call function to apply transforms sequentially.
|
||||
|
||||
Args:
|
||||
data (dict): A result dict contains the data to transform.
|
||||
|
||||
Returns:
|
||||
dict: Transformed data.
|
||||
"""
|
||||
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
format_string = self.__class__.__name__ + '('
|
||||
for t in self.transforms:
|
||||
format_string += '\n'
|
||||
format_string += f' {t}'
|
||||
format_string += '\n)'
|
||||
return format_string
|
@ -1,6 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
from mmcv.parallel import DataContainer as DC
|
||||
from mmcv.transforms import to_tensor
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmengine.data import PixelData
|
||||
@ -87,114 +86,3 @@ class PackSegInputs(BaseTransform):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(meta_keys={self.meta_keys})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ImageToTensor(object):
|
||||
"""Convert image to :obj:`torch.Tensor` by given keys.
|
||||
|
||||
The dimension order of input image is (H, W, C). The pipeline will convert
|
||||
it to (C, H, W). If only 2 dimension (H, W) is given, the output would be
|
||||
(1, H, W).
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): Key of images to be converted to Tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, keys):
|
||||
self.keys = keys
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to convert image in results to :obj:`torch.Tensor` and
|
||||
transpose the channel order.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict contains the image data to convert.
|
||||
|
||||
Returns:
|
||||
dict: The result dict contains the image converted
|
||||
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
|
||||
"""
|
||||
|
||||
for key in self.keys:
|
||||
img = results[key]
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
results[key] = to_tensor(img.transpose(2, 0, 1))
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(keys={self.keys})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Transpose(object):
|
||||
"""Transpose some results by given keys.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): Keys of results to be transposed.
|
||||
order (Sequence[int]): Order of transpose.
|
||||
"""
|
||||
|
||||
def __init__(self, keys, order):
|
||||
self.keys = keys
|
||||
self.order = order
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to convert image in results to :obj:`torch.Tensor` and
|
||||
transpose the channel order.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict contains the image data to convert.
|
||||
|
||||
Returns:
|
||||
dict: The result dict contains the image converted
|
||||
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
|
||||
"""
|
||||
|
||||
for key in self.keys:
|
||||
results[key] = results[key].transpose(self.order)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + \
|
||||
f'(keys={self.keys}, order={self.order})'
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ToDataContainer(object):
|
||||
"""Convert results to :obj:`mmcv.DataContainer` by given fields.
|
||||
|
||||
Args:
|
||||
fields (Sequence[dict]): Each field is a dict like
|
||||
``dict(key='xxx', **kwargs)``. The ``key`` in result will
|
||||
be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
|
||||
Default: ``(dict(key='img', stack=True),
|
||||
dict(key='gt_semantic_seg'))``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
fields=(dict(key='img',
|
||||
stack=True), dict(key='gt_semantic_seg'))):
|
||||
self.fields = fields
|
||||
|
||||
def __call__(self, results):
|
||||
"""Call function to convert data in results to
|
||||
:obj:`mmcv.DataContainer`.
|
||||
|
||||
Args:
|
||||
results (dict): Result dict contains the data to convert.
|
||||
|
||||
Returns:
|
||||
dict: The result dict contains the data converted to
|
||||
:obj:`mmcv.DataContainer`.
|
||||
"""
|
||||
|
||||
for field in self.fields:
|
||||
field = field.copy()
|
||||
key = field.pop('key')
|
||||
results[key] = DC(results[key], **field)
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + f'(fields={self.fields})'
|
||||
|
@ -2,7 +2,7 @@
|
||||
import os.path as osp
|
||||
|
||||
from mmseg.registry import DATASETS
|
||||
from .custom import BaseSegDataset
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class SpatialPath(BaseModule):
|
||||
|
@ -5,8 +5,8 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
|
||||
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class DownsamplerBlock(BaseModule):
|
||||
|
@ -5,9 +5,8 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.models.decode_heads.psp_head import PPM
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import InvertedResidual
|
||||
from ..utils import InvertedResidual, resize
|
||||
|
||||
|
||||
class LearningToDownsample(nn.Module):
|
||||
|
@ -6,8 +6,8 @@ from mmcv.cnn import build_conv_layer, build_norm_layer
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmseg.ops import Upsample, resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .resnet import BasicBlock, Bottleneck
|
||||
|
||||
|
||||
|
@ -4,9 +4,9 @@ import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..decode_heads.psp_head import PPM
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -6,8 +6,8 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule, ModuleList, Sequential
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .bisenetv1 import AttentionRefinementModule
|
||||
|
||||
|
||||
|
@ -8,9 +8,8 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
|
||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import UpConvBlock
|
||||
from ..utils import UpConvBlock, Upsample
|
||||
|
||||
|
||||
class BasicConvBlock(nn.Module):
|
||||
|
@ -15,9 +15,8 @@ from mmengine.model import BaseModule, ModuleList
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.modules.utils import _pair as to_2tuple
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import PatchEmbed
|
||||
from ..utils import PatchEmbed, resize
|
||||
|
||||
|
||||
class TransformerEncoderLayer(BaseModule):
|
||||
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -3,8 +3,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -7,11 +7,11 @@ import torch.nn as nn
|
||||
from mmengine.model import BaseModule
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.structures import build_pixel_sampler
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..builder import build_loss
|
||||
from ..losses import accuracy
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||
|
@ -6,8 +6,8 @@ import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, Linear, build_activation_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -7,10 +7,10 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.ops import Encoding, resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import ConfigType, SampleList
|
||||
from ..builder import build_loss
|
||||
from ..utils import Encoding, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -3,8 +3,8 @@ import numpy as np
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import Upsample, resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample, resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
||||
from mmcv import is_tuple_of
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -4,9 +4,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
|
||||
from ..utils import resize
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
||||
|
@ -12,10 +12,10 @@ except ModuleNotFoundError:
|
||||
|
||||
from typing import List
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import SampleList
|
||||
from ..losses import accuracy
|
||||
from ..utils import resize
|
||||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||||
|
||||
|
||||
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
try:
|
||||
|
@ -3,8 +3,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -3,8 +3,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .aspp_head import ASPPHead, ASPPModule
|
||||
|
||||
|
||||
|
@ -3,8 +3,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -2,8 +2,8 @@
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, build_norm_layer
|
||||
|
||||
from mmseg.ops import Upsample
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import Upsample
|
||||
from .decode_head import BaseDecodeHead
|
||||
|
||||
|
||||
|
@ -3,8 +3,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
from .decode_head import BaseDecodeHead
|
||||
from .psp_head import PPM
|
||||
|
||||
|
@ -4,8 +4,8 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -3,8 +3,8 @@ import torch.nn.functional as F
|
||||
from mmcv.cnn import ConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class CascadeFeatureFusion(BaseModule):
|
||||
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -2,8 +2,8 @@
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ConvModule, xavier_init
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.registry import MODELS
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -6,10 +6,10 @@ from mmengine.data import PixelData
|
||||
from mmengine.model import BaseModel
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.ops import resize
|
||||
from mmseg.structures import SegDataSample
|
||||
from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig,
|
||||
OptSampleList, SampleList)
|
||||
from ..utils import resize
|
||||
|
||||
|
||||
class BaseSegmentor(BaseModel, metaclass=ABCMeta):
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .embed import PatchEmbed
|
||||
from .encoding import Encoding
|
||||
from .inverted_residual import InvertedResidual, InvertedResidualV3
|
||||
from .make_divisible import make_divisible
|
||||
from .res_layer import ResLayer
|
||||
@ -8,9 +9,11 @@ from .self_attention_block import SelfAttentionBlock
|
||||
from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc,
|
||||
nlc_to_nchw)
|
||||
from .up_conv_block import UpConvBlock
|
||||
from .wrappers import Upsample, resize
|
||||
|
||||
__all__ = [
|
||||
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
|
||||
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
|
||||
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc'
|
||||
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
|
||||
'Upsample', 'resize'
|
||||
]
|
||||
|
75
mmseg/models/utils/encoding.py
Normal file
75
mmseg/models/utils/encoding.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Encoding(nn.Module):
|
||||
"""Encoding Layer: a learnable residual encoder.
|
||||
|
||||
Input is of shape (batch_size, channels, height, width).
|
||||
Output is of shape (batch_size, num_codes, channels).
|
||||
|
||||
Args:
|
||||
channels: dimension of the features or feature channels
|
||||
num_codes: number of code words
|
||||
"""
|
||||
|
||||
def __init__(self, channels, num_codes):
|
||||
super(Encoding, self).__init__()
|
||||
# init codewords and smoothing factor
|
||||
self.channels, self.num_codes = channels, num_codes
|
||||
std = 1. / ((num_codes * channels)**0.5)
|
||||
# [num_codes, channels]
|
||||
self.codewords = nn.Parameter(
|
||||
torch.empty(num_codes, channels,
|
||||
dtype=torch.float).uniform_(-std, std),
|
||||
requires_grad=True)
|
||||
# [num_codes]
|
||||
self.scale = nn.Parameter(
|
||||
torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
|
||||
requires_grad=True)
|
||||
|
||||
@staticmethod
|
||||
def scaled_l2(x, codewords, scale):
|
||||
num_codes, channels = codewords.size()
|
||||
batch_size = x.size(0)
|
||||
reshaped_scale = scale.view((1, 1, num_codes))
|
||||
expanded_x = x.unsqueeze(2).expand(
|
||||
(batch_size, x.size(1), num_codes, channels))
|
||||
reshaped_codewords = codewords.view((1, 1, num_codes, channels))
|
||||
|
||||
scaled_l2_norm = reshaped_scale * (
|
||||
expanded_x - reshaped_codewords).pow(2).sum(dim=3)
|
||||
return scaled_l2_norm
|
||||
|
||||
@staticmethod
|
||||
def aggregate(assignment_weights, x, codewords):
|
||||
num_codes, channels = codewords.size()
|
||||
reshaped_codewords = codewords.view((1, 1, num_codes, channels))
|
||||
batch_size = x.size(0)
|
||||
|
||||
expanded_x = x.unsqueeze(2).expand(
|
||||
(batch_size, x.size(1), num_codes, channels))
|
||||
encoded_feat = (assignment_weights.unsqueeze(3) *
|
||||
(expanded_x - reshaped_codewords)).sum(dim=1)
|
||||
return encoded_feat
|
||||
|
||||
def forward(self, x):
|
||||
assert x.dim() == 4 and x.size(1) == self.channels
|
||||
# [batch_size, channels, height, width]
|
||||
batch_size = x.size(0)
|
||||
# [batch_size, height x width, channels]
|
||||
x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
|
||||
# assignment_weights: [batch_size, channels, num_codes]
|
||||
assignment_weights = F.softmax(
|
||||
self.scaled_l2(x, self.codewords, self.scale), dim=2)
|
||||
# aggregate
|
||||
encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
|
||||
return encoded_feat
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
|
||||
f'x{self.channels})'
|
||||
return repr_str
|
51
mmseg/models/utils/wrappers.py
Normal file
51
mmseg/models/utils/wrappers.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def resize(input,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
mode='nearest',
|
||||
align_corners=None,
|
||||
warning=True):
|
||||
if warning:
|
||||
if size is not None and align_corners:
|
||||
input_h, input_w = tuple(int(x) for x in input.shape[2:])
|
||||
output_h, output_w = tuple(int(x) for x in size)
|
||||
if output_h > input_h or output_w > output_h:
|
||||
if ((output_h > 1 and output_w > 1 and input_h > 1
|
||||
and input_w > 1) and (output_h - 1) % (input_h - 1)
|
||||
and (output_w - 1) % (input_w - 1)):
|
||||
warnings.warn(
|
||||
f'When align_corners={align_corners}, '
|
||||
'the output would more aligned if '
|
||||
f'input size {(input_h, input_w)} is `x+1` and '
|
||||
f'out size {(output_h, output_w)} is `nx+1`')
|
||||
return F.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
size=None,
|
||||
scale_factor=None,
|
||||
mode='nearest',
|
||||
align_corners=None):
|
||||
super(Upsample, self).__init__()
|
||||
self.size = size
|
||||
if isinstance(scale_factor, tuple):
|
||||
self.scale_factor = tuple(float(factor) for factor in scale_factor)
|
||||
else:
|
||||
self.scale_factor = float(scale_factor) if scale_factor else None
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
if not self.size:
|
||||
size = [int(t * self.scale_factor) for t in x.shape[-2:]]
|
||||
else:
|
||||
size = self.size
|
||||
return resize(x, size, None, self.mode, self.align_corners)
|
@ -1,5 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .encoding import Encoding
|
||||
from .wrappers import Upsample, resize
|
||||
|
||||
__all__ = ['Upsample', 'resize', 'Encoding']
|
@ -8,13 +8,16 @@ https://mmengine.readthedocs.io/en/latest/tutorials/registry.html.
|
||||
|
||||
from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS
|
||||
from mmengine.registry import DATASETS as MMENGINE_DATASETS
|
||||
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
|
||||
from mmengine.registry import HOOKS as MMENGINE_HOOKS
|
||||
from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
|
||||
from mmengine.registry import LOOPS as MMENGINE_LOOPS
|
||||
from mmengine.registry import METRICS as MMENGINE_METRICS
|
||||
from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS
|
||||
from mmengine.registry import MODELS as MMENGINE_MODELS
|
||||
from mmengine.registry import \
|
||||
OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS
|
||||
from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS
|
||||
from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS
|
||||
from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS
|
||||
from mmengine.registry import \
|
||||
@ -53,14 +56,20 @@ WEIGHT_INITIALIZERS = Registry(
|
||||
|
||||
# mangage all kinds of optimizers like `SGD` and `Adam`
|
||||
OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS)
|
||||
# manage optimizer wrapper
|
||||
OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS)
|
||||
# manage constructors that customize the optimization hyperparameters.
|
||||
OPTIM_WRAPPER_CONSTRUCTORS = Registry(
|
||||
'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
|
||||
'optimizer wrapper constructor',
|
||||
parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS)
|
||||
# mangage all kinds of parameter schedulers like `MultiStepLR`
|
||||
PARAM_SCHEDULERS = Registry(
|
||||
'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS)
|
||||
|
||||
# manage all kinds of metrics
|
||||
METRICS = Registry('metric', parent=MMENGINE_METRICS)
|
||||
# manage evaluator
|
||||
EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR)
|
||||
|
||||
# manage task-specific modules like ohem pixel sampler
|
||||
TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
|
||||
@ -69,3 +78,6 @@ TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS)
|
||||
VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS)
|
||||
# manage visualizer backend
|
||||
VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS)
|
||||
|
||||
# manage logprocessor
|
||||
LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS)
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
|
||||
__version__ = '0.24.1'
|
||||
__version__ = '1.0.0rc0'
|
||||
|
||||
|
||||
def parse_version_info(version_str):
|
||||
|
@ -1,2 +1,2 @@
|
||||
mmcls>=0.20.1
|
||||
mmcv-full>=1.4.4,<=1.6.0
|
||||
mmcv-full>=2.0.0,<=2.0.0rc0
|
||||
|
@ -3,10 +3,13 @@ import glob
|
||||
import os
|
||||
from os.path import dirname, exists, isdir, join, relpath
|
||||
|
||||
from mmcv import Config
|
||||
import numpy as np
|
||||
from mmengine import Config
|
||||
from mmengine.dataset import Compose
|
||||
from torch import nn
|
||||
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def _get_config_directory():
|
||||
@ -60,72 +63,69 @@ def test_config_build_segmentor():
|
||||
_check_decode_head(head_config, segmentor.decode_head)
|
||||
|
||||
|
||||
# def test_config_data_pipeline():
|
||||
# """Test whether the data pipeline is valid and can process corner cases.
|
||||
def test_config_data_pipeline():
|
||||
"""Test whether the data pipeline is valid and can process corner cases.
|
||||
|
||||
# CommandLine:
|
||||
# xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
# """
|
||||
# import numpy as np
|
||||
# from mmcv import Config
|
||||
CommandLine:
|
||||
xdoctest -m tests/test_config.py test_config_build_data_pipeline
|
||||
"""
|
||||
|
||||
# from mmseg.datasets.transforms import Compose
|
||||
register_all_modules()
|
||||
config_dpath = _get_config_directory()
|
||||
print('Found config_dpath = {!r}'.format(config_dpath))
|
||||
|
||||
# config_dpath = _get_config_directory()
|
||||
# print('Found config_dpath = {!r}'.format(config_dpath))
|
||||
import glob
|
||||
config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||
config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
|
||||
# import glob
|
||||
# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
|
||||
# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
|
||||
# config_names = [relpath(p, config_dpath) for p in config_fpaths]
|
||||
print('Using {} config files'.format(len(config_names)))
|
||||
|
||||
# print('Using {} config files'.format(len(config_names)))
|
||||
for config_fname in config_names:
|
||||
config_fpath = join(config_dpath, config_fname)
|
||||
print(
|
||||
'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
|
||||
config_mod = Config.fromfile(config_fpath)
|
||||
|
||||
# for config_fname in config_names:
|
||||
# config_fpath = join(config_dpath, config_fname)
|
||||
# print(
|
||||
# 'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
|
||||
# config_mod = Config.fromfile(config_fpath)
|
||||
# remove loading pipeline
|
||||
load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||
to_float32 = load_img_pipeline.get('to_float32', False)
|
||||
config_mod.train_pipeline.pop(0)
|
||||
config_mod.test_pipeline.pop(0)
|
||||
# remove loading annotation in test pipeline
|
||||
config_mod.test_pipeline.pop(1)
|
||||
|
||||
# # remove loading pipeline
|
||||
# load_img_pipeline = config_mod.train_pipeline.pop(0)
|
||||
# to_float32 = load_img_pipeline.get('to_float32', False)
|
||||
# config_mod.train_pipeline.pop(0)
|
||||
# config_mod.test_pipeline.pop(0)
|
||||
# # remove loading annotation in test pipeline
|
||||
# config_mod.test_pipeline.pop(1)
|
||||
train_pipeline = Compose(config_mod.train_pipeline)
|
||||
test_pipeline = Compose(config_mod.test_pipeline)
|
||||
|
||||
# train_pipeline = Compose(config_mod.train_pipeline)
|
||||
# test_pipeline = Compose(config_mod.test_pipeline)
|
||||
img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||
if to_float32:
|
||||
img = img.astype(np.float32)
|
||||
seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||
|
||||
# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
|
||||
# if to_float32:
|
||||
# img = img.astype(np.float32)
|
||||
# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
|
||||
results = dict(
|
||||
filename='test_img.png',
|
||||
ori_filename='test_img.png',
|
||||
img=img,
|
||||
img_shape=img.shape,
|
||||
ori_shape=img.shape,
|
||||
gt_seg_map=seg)
|
||||
results['seg_fields'] = ['gt_seg_map']
|
||||
|
||||
# results = dict(
|
||||
# filename='test_img.png',
|
||||
# ori_filename='test_img.png',
|
||||
# img=img,
|
||||
# img_shape=img.shape,
|
||||
# ori_shape=img.shape,
|
||||
# gt_seg_map=seg)
|
||||
# results['seg_fields'] = ['gt_seg_map']
|
||||
print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
||||
output_results = train_pipeline(results)
|
||||
assert output_results is not None
|
||||
|
||||
# print('Test training data pipeline: \n{!r}'.format(train_pipeline))
|
||||
# output_results = train_pipeline(results)
|
||||
# assert output_results is not None
|
||||
|
||||
# results = dict(
|
||||
# filename='test_img.png',
|
||||
# ori_filename='test_img.png',
|
||||
# img=img,
|
||||
# img_shape=img.shape,
|
||||
# ori_shape=img.shape,
|
||||
# )
|
||||
# print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
||||
# output_results = test_pipeline(results)
|
||||
# assert output_results is not None
|
||||
results = dict(
|
||||
filename='test_img.png',
|
||||
ori_filename='test_img.png',
|
||||
img=img,
|
||||
img_shape=img.shape,
|
||||
ori_shape=img.shape,
|
||||
)
|
||||
print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
|
||||
output_results = test_pipeline(results)
|
||||
assert output_results is not None
|
||||
|
||||
|
||||
def _check_decode_head(decode_head_cfg, decode_head):
|
||||
|
@ -1,4 +1 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .utils import all_zeros, check_norm_state, is_block, is_norm
|
||||
|
||||
__all__ = ['is_norm', 'is_block', 'all_zeros', 'check_norm_state']
|
||||
|
@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
|
||||
InterpConv, UNet, UpConvBlock)
|
||||
from mmseg.ops import Upsample
|
||||
from mmseg.models.utils import Upsample
|
||||
from .utils import check_norm_state
|
||||
|
||||
|
||||
|
@ -15,6 +15,14 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Train a segmentor')
|
||||
parser.add_argument('config', help='train config file path')
|
||||
parser.add_argument('--work-dir', help='the dir to save logs and models')
|
||||
parser.add_argument(
|
||||
'--resume',
|
||||
nargs='?',
|
||||
type=str,
|
||||
const='auto',
|
||||
help='If specify checkpint path, resume from it, while if not '
|
||||
'specify, try to auto resume from the latest checkpoint '
|
||||
'in the work directory.')
|
||||
parser.add_argument(
|
||||
'--amp',
|
||||
action='store_true',
|
||||
@ -80,6 +88,14 @@ def main():
|
||||
cfg.optim_wrapper.type = 'AmpOptimWrapper'
|
||||
cfg.optim_wrapper.loss_scale = 'dynamic'
|
||||
|
||||
# resume training
|
||||
if args.resume == 'auto':
|
||||
cfg.resume = True
|
||||
cfg.load_from = None
|
||||
elif args.resume is not None:
|
||||
cfg.resume = True
|
||||
cfg.load_from = args.resume
|
||||
|
||||
# build the runner from config
|
||||
runner = Runner.from_cfg(cfg)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user