diff --git a/mmseg/__init__.py b/mmseg/__init__.py index 17d43de1b..8ffaf446b 100644 --- a/mmseg/__init__.py +++ b/mmseg/__init__.py @@ -55,8 +55,8 @@ mmcv_max_version = digit_version(MMCV_MAX) mmcv_version = digit_version(mmcv.__version__) -assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ +assert (mmcv_min_version <= mmcv_version < mmcv_max_version), \ f'MMCV=={mmcv.__version__} is used but incompatible. ' \ - f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' + f'Please install mmcv>={mmcv_min_version}, <{mmcv_max_version}.' __all__ = ['__version__', 'version_info', 'digit_version'] diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index 873b566ab..5cd5a772f 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ade import ADE20KDataset +from .basesegdataset import BaseSegDataset from .chase_db1 import ChaseDB1Dataset from .cityscapes import CityscapesDataset from .coco_stuff import COCOStuffDataset -from .custom import BaseSegDataset from .dark_zurich import DarkZurichDataset from .dataset_wrappers import MultiImageMixDataset from .drive import DRIVEDataset diff --git a/mmseg/datasets/ade.py b/mmseg/datasets/ade.py index 2153ff2d7..1f97fcb2c 100644 --- a/mmseg/datasets/ade.py +++ b/mmseg/datasets/ade.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/custom.py b/mmseg/datasets/basesegdataset.py similarity index 100% rename from mmseg/datasets/custom.py rename to mmseg/datasets/basesegdataset.py diff --git a/mmseg/datasets/chase_db1.py b/mmseg/datasets/chase_db1.py index 49ef84d53..71139f2aa 100644 --- a/mmseg/datasets/chase_db1.py +++ b/mmseg/datasets/chase_db1.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/cityscapes.py b/mmseg/datasets/cityscapes.py index a72dcbe1b..f494d6242 100644 --- a/mmseg/datasets/cityscapes.py +++ b/mmseg/datasets/cityscapes.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/coco_stuff.py b/mmseg/datasets/coco_stuff.py index 72fd0e4b2..b0891a387 100644 --- a/mmseg/datasets/coco_stuff.py +++ b/mmseg/datasets/coco_stuff.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/drive.py b/mmseg/datasets/drive.py index 2fa78d985..fbaf729be 100644 --- a/mmseg/datasets/drive.py +++ b/mmseg/datasets/drive.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/hrf.py b/mmseg/datasets/hrf.py index 91ed9c1e3..f8d330918 100644 --- a/mmseg/datasets/hrf.py +++ b/mmseg/datasets/hrf.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/isaid.py b/mmseg/datasets/isaid.py index c5a435779..a91a12b3b 100644 --- a/mmseg/datasets/isaid.py +++ b/mmseg/datasets/isaid.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/isprs.py b/mmseg/datasets/isprs.py index 67bccc32d..78df4e385 100644 --- a/mmseg/datasets/isprs.py +++ b/mmseg/datasets/isprs.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/loveda.py b/mmseg/datasets/loveda.py index 2dae213ad..b5d80f25a 100644 --- a/mmseg/datasets/loveda.py +++ b/mmseg/datasets/loveda.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/pascal_context.py b/mmseg/datasets/pascal_context.py index 47ec544ea..a98372352 100644 --- a/mmseg/datasets/pascal_context.py +++ b/mmseg/datasets/pascal_context.py @@ -2,7 +2,7 @@ import os.path as osp from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/potsdam.py b/mmseg/datasets/potsdam.py index e8801c4f3..808cf6ec7 100644 --- a/mmseg/datasets/potsdam.py +++ b/mmseg/datasets/potsdam.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/stare.py b/mmseg/datasets/stare.py index 73086dd3b..485470277 100644 --- a/mmseg/datasets/stare.py +++ b/mmseg/datasets/stare.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/datasets/transforms/compose.py b/mmseg/datasets/transforms/compose.py deleted file mode 100644 index 5bfaa7046..000000000 --- a/mmseg/datasets/transforms/compose.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import collections - -from mmseg.registry import TRANSFORMS - - -@TRANSFORMS.register_module() -class Compose(object): - """Compose multiple transforms sequentially. - - Args: - transforms (Sequence[dict | callable]): Sequence of transform object or - config dict to be composed. - """ - - def __init__(self, transforms): - assert isinstance(transforms, collections.abc.Sequence) - self.transforms = [] - for transform in transforms: - if isinstance(transform, dict): - transform = TRANSFORMS.build(transform) - self.transforms.append(transform) - elif callable(transform): - self.transforms.append(transform) - else: - raise TypeError('transform must be callable or a dict') - - def __call__(self, data): - """Call function to apply transforms sequentially. - - Args: - data (dict): A result dict contains the data to transform. - - Returns: - dict: Transformed data. - """ - - for t in self.transforms: - data = t(data) - if data is None: - return None - return data - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += f' {t}' - format_string += '\n)' - return format_string diff --git a/mmseg/datasets/transforms/formatting.py b/mmseg/datasets/transforms/formatting.py index 4f343c3d4..71c99c97b 100644 --- a/mmseg/datasets/transforms/formatting.py +++ b/mmseg/datasets/transforms/formatting.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -from mmcv.parallel import DataContainer as DC from mmcv.transforms import to_tensor from mmcv.transforms.base import BaseTransform from mmengine.data import PixelData @@ -87,114 +86,3 @@ class PackSegInputs(BaseTransform): repr_str = self.__class__.__name__ repr_str += f'(meta_keys={self.meta_keys})' return repr_str - - -@TRANSFORMS.register_module() -class ImageToTensor(object): - """Convert image to :obj:`torch.Tensor` by given keys. - - The dimension order of input image is (H, W, C). The pipeline will convert - it to (C, H, W). If only 2 dimension (H, W) is given, the output would be - (1, H, W). - - Args: - keys (Sequence[str]): Key of images to be converted to Tensor. - """ - - def __init__(self, keys): - self.keys = keys - - def __call__(self, results): - """Call function to convert image in results to :obj:`torch.Tensor` and - transpose the channel order. - - Args: - results (dict): Result dict contains the image data to convert. - - Returns: - dict: The result dict contains the image converted - to :obj:`torch.Tensor` and transposed to (C, H, W) order. - """ - - for key in self.keys: - img = results[key] - if len(img.shape) < 3: - img = np.expand_dims(img, -1) - results[key] = to_tensor(img.transpose(2, 0, 1)) - return results - - def __repr__(self): - return self.__class__.__name__ + f'(keys={self.keys})' - - -@TRANSFORMS.register_module() -class Transpose(object): - """Transpose some results by given keys. - - Args: - keys (Sequence[str]): Keys of results to be transposed. - order (Sequence[int]): Order of transpose. - """ - - def __init__(self, keys, order): - self.keys = keys - self.order = order - - def __call__(self, results): - """Call function to convert image in results to :obj:`torch.Tensor` and - transpose the channel order. - - Args: - results (dict): Result dict contains the image data to convert. - - Returns: - dict: The result dict contains the image converted - to :obj:`torch.Tensor` and transposed to (C, H, W) order. - """ - - for key in self.keys: - results[key] = results[key].transpose(self.order) - return results - - def __repr__(self): - return self.__class__.__name__ + \ - f'(keys={self.keys}, order={self.order})' - - -@TRANSFORMS.register_module() -class ToDataContainer(object): - """Convert results to :obj:`mmcv.DataContainer` by given fields. - - Args: - fields (Sequence[dict]): Each field is a dict like - ``dict(key='xxx', **kwargs)``. The ``key`` in result will - be converted to :obj:`mmcv.DataContainer` with ``**kwargs``. - Default: ``(dict(key='img', stack=True), - dict(key='gt_semantic_seg'))``. - """ - - def __init__(self, - fields=(dict(key='img', - stack=True), dict(key='gt_semantic_seg'))): - self.fields = fields - - def __call__(self, results): - """Call function to convert data in results to - :obj:`mmcv.DataContainer`. - - Args: - results (dict): Result dict contains the data to convert. - - Returns: - dict: The result dict contains the data converted to - :obj:`mmcv.DataContainer`. - """ - - for field in self.fields: - field = field.copy() - key = field.pop('key') - results[key] = DC(results[key], **field) - return results - - def __repr__(self): - return self.__class__.__name__ + f'(fields={self.fields})' diff --git a/mmseg/datasets/voc.py b/mmseg/datasets/voc.py index fe4070205..1defcbf6c 100644 --- a/mmseg/datasets/voc.py +++ b/mmseg/datasets/voc.py @@ -2,7 +2,7 @@ import os.path as osp from mmseg.registry import DATASETS -from .custom import BaseSegDataset +from .basesegdataset import BaseSegDataset @DATASETS.register_module() diff --git a/mmseg/models/backbones/bisenetv1.py b/mmseg/models/backbones/bisenetv1.py index 709b53dfb..fa9e51383 100644 --- a/mmseg/models/backbones/bisenetv1.py +++ b/mmseg/models/backbones/bisenetv1.py @@ -4,8 +4,8 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize class SpatialPath(BaseModule): diff --git a/mmseg/models/backbones/bisenetv2.py b/mmseg/models/backbones/bisenetv2.py index 71bfe5949..256de7952 100644 --- a/mmseg/models/backbones/bisenetv2.py +++ b/mmseg/models/backbones/bisenetv2.py @@ -5,8 +5,8 @@ from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, build_activation_layer, build_norm_layer) from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize class DetailBranch(BaseModule): diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py index 638efedbb..88f41dfb1 100644 --- a/mmseg/models/backbones/erfnet.py +++ b/mmseg/models/backbones/erfnet.py @@ -4,8 +4,8 @@ import torch.nn as nn from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize class DownsamplerBlock(BaseModule): diff --git a/mmseg/models/backbones/fast_scnn.py b/mmseg/models/backbones/fast_scnn.py index d7e13f4ce..9884b2e18 100644 --- a/mmseg/models/backbones/fast_scnn.py +++ b/mmseg/models/backbones/fast_scnn.py @@ -5,9 +5,8 @@ from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmengine.model import BaseModule from mmseg.models.decode_heads.psp_head import PPM -from mmseg.ops import resize from mmseg.registry import MODELS -from ..utils import InvertedResidual +from ..utils import InvertedResidual, resize class LearningToDownsample(nn.Module): diff --git a/mmseg/models/backbones/hrnet.py b/mmseg/models/backbones/hrnet.py index ff2998c2c..3e209b0af 100644 --- a/mmseg/models/backbones/hrnet.py +++ b/mmseg/models/backbones/hrnet.py @@ -6,8 +6,8 @@ from mmcv.cnn import build_conv_layer, build_norm_layer from mmcv.utils.parrots_wrapper import _BatchNorm from mmengine.model import BaseModule, ModuleList, Sequential -from mmseg.ops import Upsample, resize from mmseg.registry import MODELS +from ..utils import Upsample, resize from .resnet import BasicBlock, Bottleneck diff --git a/mmseg/models/backbones/icnet.py b/mmseg/models/backbones/icnet.py index 87f3cf144..faa6d15cc 100644 --- a/mmseg/models/backbones/icnet.py +++ b/mmseg/models/backbones/icnet.py @@ -4,9 +4,9 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS from ..decode_heads.psp_head import PPM +from ..utils import resize @MODELS.register_module() diff --git a/mmseg/models/backbones/stdc.py b/mmseg/models/backbones/stdc.py index fefaecbdf..340d6ee4c 100644 --- a/mmseg/models/backbones/stdc.py +++ b/mmseg/models/backbones/stdc.py @@ -6,8 +6,8 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmengine.model import BaseModule, ModuleList, Sequential -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .bisenetv1 import AttentionRefinementModule diff --git a/mmseg/models/backbones/unet.py b/mmseg/models/backbones/unet.py index 99f217f18..1a14b8c58 100644 --- a/mmseg/models/backbones/unet.py +++ b/mmseg/models/backbones/unet.py @@ -8,9 +8,8 @@ from mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer, from mmcv.utils.parrots_wrapper import _BatchNorm from mmengine.model import BaseModule -from mmseg.ops import Upsample from mmseg.registry import MODELS -from ..utils import UpConvBlock +from ..utils import UpConvBlock, Upsample class BasicConvBlock(nn.Module): diff --git a/mmseg/models/backbones/vit.py b/mmseg/models/backbones/vit.py index 3707ff529..fa5a012b7 100644 --- a/mmseg/models/backbones/vit.py +++ b/mmseg/models/backbones/vit.py @@ -15,9 +15,8 @@ from mmengine.model import BaseModule, ModuleList from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.utils import _pair as to_2tuple -from mmseg.ops import resize from mmseg.registry import MODELS -from ..utils import PatchEmbed +from ..utils import PatchEmbed, resize class TransformerEncoderLayer(BaseModule): diff --git a/mmseg/models/decode_heads/apc_head.py b/mmseg/models/decode_heads/apc_head.py index 45f4e2850..187fdb0e9 100644 --- a/mmseg/models/decode_heads/apc_head.py +++ b/mmseg/models/decode_heads/apc_head.py @@ -4,8 +4,8 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/aspp_head.py b/mmseg/models/decode_heads/aspp_head.py index acf9eedfa..757e69359 100644 --- a/mmseg/models/decode_heads/aspp_head.py +++ b/mmseg/models/decode_heads/aspp_head.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 5889bd673..6fdfbba1d 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -7,11 +7,11 @@ import torch.nn as nn from mmengine.model import BaseModule from torch import Tensor -from mmseg.ops import resize from mmseg.structures import build_pixel_sampler from mmseg.utils import ConfigType, SampleList from ..builder import build_loss from ..losses import accuracy +from ..utils import resize class BaseDecodeHead(BaseModule, metaclass=ABCMeta): diff --git a/mmseg/models/decode_heads/dpt_head.py b/mmseg/models/decode_heads/dpt_head.py index 243eec5b5..989b53aa4 100644 --- a/mmseg/models/decode_heads/dpt_head.py +++ b/mmseg/models/decode_heads/dpt_head.py @@ -6,8 +6,8 @@ import torch.nn as nn from mmcv.cnn import ConvModule, Linear, build_activation_layer from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py index 1b8eecbff..cfceb9428 100644 --- a/mmseg/models/decode_heads/enc_head.py +++ b/mmseg/models/decode_heads/enc_head.py @@ -7,10 +7,10 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule, build_norm_layer from torch import Tensor -from mmseg.ops import Encoding, resize from mmseg.registry import MODELS from mmseg.utils import ConfigType, SampleList from ..builder import build_loss +from ..utils import Encoding, resize from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/fpn_head.py b/mmseg/models/decode_heads/fpn_head.py index be92ceed5..a9af1feec 100644 --- a/mmseg/models/decode_heads/fpn_head.py +++ b/mmseg/models/decode_heads/fpn_head.py @@ -3,8 +3,8 @@ import numpy as np import torch.nn as nn from mmcv.cnn import ConvModule -from mmseg.ops import Upsample, resize from mmseg.registry import MODELS +from ..utils import Upsample, resize from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/lraspp_head.py b/mmseg/models/decode_heads/lraspp_head.py index 36999f056..84dcc40eb 100644 --- a/mmseg/models/decode_heads/lraspp_head.py +++ b/mmseg/models/decode_heads/lraspp_head.py @@ -4,8 +4,8 @@ import torch.nn as nn from mmcv import is_tuple_of from mmcv.cnn import ConvModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/ocr_head.py b/mmseg/models/decode_heads/ocr_head.py index ce3582413..57d91613d 100644 --- a/mmseg/models/decode_heads/ocr_head.py +++ b/mmseg/models/decode_heads/ocr_head.py @@ -4,9 +4,9 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from mmseg.ops import resize from mmseg.registry import MODELS from ..utils import SelfAttentionBlock as _SelfAttentionBlock +from ..utils import resize from .cascade_decode_head import BaseCascadeDecodeHead diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py index 781ed1ee8..a57d531bd 100644 --- a/mmseg/models/decode_heads/point_head.py +++ b/mmseg/models/decode_heads/point_head.py @@ -12,10 +12,10 @@ except ModuleNotFoundError: from typing import List -from mmseg.ops import resize from mmseg.registry import MODELS from mmseg.utils import SampleList from ..losses import accuracy +from ..utils import resize from .cascade_decode_head import BaseCascadeDecodeHead diff --git a/mmseg/models/decode_heads/psa_head.py b/mmseg/models/decode_heads/psa_head.py index 4b292c600..cf3749f9c 100644 --- a/mmseg/models/decode_heads/psa_head.py +++ b/mmseg/models/decode_heads/psa_head.py @@ -4,8 +4,8 @@ import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import ConvModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .decode_head import BaseDecodeHead try: diff --git a/mmseg/models/decode_heads/psp_head.py b/mmseg/models/decode_heads/psp_head.py index 734de8f1a..c22337fde 100644 --- a/mmseg/models/decode_heads/psp_head.py +++ b/mmseg/models/decode_heads/psp_head.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/segformer_head.py b/mmseg/models/decode_heads/segformer_head.py index 8c14602a3..f9eb0b320 100644 --- a/mmseg/models/decode_heads/segformer_head.py +++ b/mmseg/models/decode_heads/segformer_head.py @@ -4,8 +4,8 @@ import torch.nn as nn from mmcv.cnn import ConvModule from mmseg.models.decode_heads.decode_head import BaseDecodeHead -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize @MODELS.register_module() diff --git a/mmseg/models/decode_heads/sep_aspp_head.py b/mmseg/models/decode_heads/sep_aspp_head.py index c63217931..a1ffae708 100644 --- a/mmseg/models/decode_heads/sep_aspp_head.py +++ b/mmseg/models/decode_heads/sep_aspp_head.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .aspp_head import ASPPHead, ASPPModule diff --git a/mmseg/models/decode_heads/setr_mla_head.py b/mmseg/models/decode_heads/setr_mla_head.py index 228c31100..887e11bc2 100644 --- a/mmseg/models/decode_heads/setr_mla_head.py +++ b/mmseg/models/decode_heads/setr_mla_head.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from mmseg.ops import Upsample from mmseg.registry import MODELS +from ..utils import Upsample from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/setr_up_head.py b/mmseg/models/decode_heads/setr_up_head.py index 11ab9bb7d..151e9ebce 100644 --- a/mmseg/models/decode_heads/setr_up_head.py +++ b/mmseg/models/decode_heads/setr_up_head.py @@ -2,8 +2,8 @@ import torch.nn as nn from mmcv.cnn import ConvModule, build_norm_layer -from mmseg.ops import Upsample from mmseg.registry import MODELS +from ..utils import Upsample from .decode_head import BaseDecodeHead diff --git a/mmseg/models/decode_heads/uper_head.py b/mmseg/models/decode_heads/uper_head.py index 347ef3292..651236b20 100644 --- a/mmseg/models/decode_heads/uper_head.py +++ b/mmseg/models/decode_heads/uper_head.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn from mmcv.cnn import ConvModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize from .decode_head import BaseDecodeHead from .psp_head import PPM diff --git a/mmseg/models/necks/fpn.py b/mmseg/models/necks/fpn.py index bb49b0b17..98fb650f5 100644 --- a/mmseg/models/necks/fpn.py +++ b/mmseg/models/necks/fpn.py @@ -4,8 +4,8 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize @MODELS.register_module() diff --git a/mmseg/models/necks/ic_neck.py b/mmseg/models/necks/ic_neck.py index f46f45d56..c554d227b 100644 --- a/mmseg/models/necks/ic_neck.py +++ b/mmseg/models/necks/ic_neck.py @@ -3,8 +3,8 @@ import torch.nn.functional as F from mmcv.cnn import ConvModule from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize class CascadeFeatureFusion(BaseModule): diff --git a/mmseg/models/necks/jpu.py b/mmseg/models/necks/jpu.py index 333230be4..6aad895c3 100644 --- a/mmseg/models/necks/jpu.py +++ b/mmseg/models/necks/jpu.py @@ -4,8 +4,8 @@ import torch.nn as nn from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule from mmengine.model import BaseModule -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize @MODELS.register_module() diff --git a/mmseg/models/necks/multilevel_neck.py b/mmseg/models/necks/multilevel_neck.py index 14942bb63..50b7049d4 100644 --- a/mmseg/models/necks/multilevel_neck.py +++ b/mmseg/models/necks/multilevel_neck.py @@ -2,8 +2,8 @@ import torch.nn as nn from mmcv.cnn import ConvModule, xavier_init -from mmseg.ops import resize from mmseg.registry import MODELS +from ..utils import resize @MODELS.register_module() diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py index 586f718b6..e9f05a5b7 100644 --- a/mmseg/models/segmentors/base.py +++ b/mmseg/models/segmentors/base.py @@ -6,10 +6,10 @@ from mmengine.data import PixelData from mmengine.model import BaseModel from torch import Tensor -from mmseg.ops import resize from mmseg.structures import SegDataSample from mmseg.utils import (ForwardResults, OptConfigType, OptMultiConfig, OptSampleList, SampleList) +from ..utils import resize class BaseSegmentor(BaseModel, metaclass=ABCMeta): diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py index 6d8329021..7aaa600c2 100644 --- a/mmseg/models/utils/__init__.py +++ b/mmseg/models/utils/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .embed import PatchEmbed +from .encoding import Encoding from .inverted_residual import InvertedResidual, InvertedResidualV3 from .make_divisible import make_divisible from .res_layer import ResLayer @@ -8,9 +9,11 @@ from .self_attention_block import SelfAttentionBlock from .shape_convert import (nchw2nlc2nchw, nchw_to_nlc, nlc2nchw2nlc, nlc_to_nchw) from .up_conv_block import UpConvBlock +from .wrappers import Upsample, resize __all__ = [ 'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual', 'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed', - 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc' + 'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding', + 'Upsample', 'resize' ] diff --git a/mmseg/models/utils/encoding.py b/mmseg/models/utils/encoding.py new file mode 100644 index 000000000..f397cc54e --- /dev/null +++ b/mmseg/models/utils/encoding.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn +from torch.nn import functional as F + + +class Encoding(nn.Module): + """Encoding Layer: a learnable residual encoder. + + Input is of shape (batch_size, channels, height, width). + Output is of shape (batch_size, num_codes, channels). + + Args: + channels: dimension of the features or feature channels + num_codes: number of code words + """ + + def __init__(self, channels, num_codes): + super(Encoding, self).__init__() + # init codewords and smoothing factor + self.channels, self.num_codes = channels, num_codes + std = 1. / ((num_codes * channels)**0.5) + # [num_codes, channels] + self.codewords = nn.Parameter( + torch.empty(num_codes, channels, + dtype=torch.float).uniform_(-std, std), + requires_grad=True) + # [num_codes] + self.scale = nn.Parameter( + torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), + requires_grad=True) + + @staticmethod + def scaled_l2(x, codewords, scale): + num_codes, channels = codewords.size() + batch_size = x.size(0) + reshaped_scale = scale.view((1, 1, num_codes)) + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + + scaled_l2_norm = reshaped_scale * ( + expanded_x - reshaped_codewords).pow(2).sum(dim=3) + return scaled_l2_norm + + @staticmethod + def aggregate(assignment_weights, x, codewords): + num_codes, channels = codewords.size() + reshaped_codewords = codewords.view((1, 1, num_codes, channels)) + batch_size = x.size(0) + + expanded_x = x.unsqueeze(2).expand( + (batch_size, x.size(1), num_codes, channels)) + encoded_feat = (assignment_weights.unsqueeze(3) * + (expanded_x - reshaped_codewords)).sum(dim=1) + return encoded_feat + + def forward(self, x): + assert x.dim() == 4 and x.size(1) == self.channels + # [batch_size, channels, height, width] + batch_size = x.size(0) + # [batch_size, height x width, channels] + x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous() + # assignment_weights: [batch_size, channels, num_codes] + assignment_weights = F.softmax( + self.scaled_l2(x, self.codewords, self.scale), dim=2) + # aggregate + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) + return encoded_feat + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \ + f'x{self.channels})' + return repr_str diff --git a/mmseg/models/utils/wrappers.py b/mmseg/models/utils/wrappers.py new file mode 100644 index 000000000..ce67e4beb --- /dev/null +++ b/mmseg/models/utils/wrappers.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch.nn as nn +import torch.nn.functional as F + + +def resize(input, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None, + warning=True): + if warning: + if size is not None and align_corners: + input_h, input_w = tuple(int(x) for x in input.shape[2:]) + output_h, output_w = tuple(int(x) for x in size) + if output_h > input_h or output_w > output_h: + if ((output_h > 1 and output_w > 1 and input_h > 1 + and input_w > 1) and (output_h - 1) % (input_h - 1) + and (output_w - 1) % (input_w - 1)): + warnings.warn( + f'When align_corners={align_corners}, ' + 'the output would more aligned if ' + f'input size {(input_h, input_w)} is `x+1` and ' + f'out size {(output_h, output_w)} is `nx+1`') + return F.interpolate(input, size, scale_factor, mode, align_corners) + + +class Upsample(nn.Module): + + def __init__(self, + size=None, + scale_factor=None, + mode='nearest', + align_corners=None): + super(Upsample, self).__init__() + self.size = size + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + if not self.size: + size = [int(t * self.scale_factor) for t in x.shape[-2:]] + else: + size = self.size + return resize(x, size, None, self.mode, self.align_corners) diff --git a/mmseg/ops/__init__.py b/mmseg/ops/__init__.py deleted file mode 100644 index bc075cd4e..000000000 --- a/mmseg/ops/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .encoding import Encoding -from .wrappers import Upsample, resize - -__all__ = ['Upsample', 'resize', 'Encoding'] diff --git a/mmseg/registry/registry.py b/mmseg/registry/registry.py index a3fcf40c5..5c9977ab8 100644 --- a/mmseg/registry/registry.py +++ b/mmseg/registry/registry.py @@ -8,13 +8,16 @@ https://mmengine.readthedocs.io/en/latest/tutorials/registry.html. from mmengine.registry import DATA_SAMPLERS as MMENGINE_DATA_SAMPLERS from mmengine.registry import DATASETS as MMENGINE_DATASETS +from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR from mmengine.registry import HOOKS as MMENGINE_HOOKS +from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS from mmengine.registry import LOOPS as MMENGINE_LOOPS from mmengine.registry import METRICS as MMENGINE_METRICS from mmengine.registry import MODEL_WRAPPERS as MMENGINE_MODEL_WRAPPERS from mmengine.registry import MODELS as MMENGINE_MODELS from mmengine.registry import \ OPTIM_WRAPPER_CONSTRUCTORS as MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS +from mmengine.registry import OPTIM_WRAPPERS as MMENGINE_OPTIM_WRAPPERS from mmengine.registry import OPTIMIZERS as MMENGINE_OPTIMIZERS from mmengine.registry import PARAM_SCHEDULERS as MMENGINE_PARAM_SCHEDULERS from mmengine.registry import \ @@ -53,14 +56,20 @@ WEIGHT_INITIALIZERS = Registry( # mangage all kinds of optimizers like `SGD` and `Adam` OPTIMIZERS = Registry('optimizer', parent=MMENGINE_OPTIMIZERS) +# manage optimizer wrapper +OPTIM_WRAPPERS = Registry('optim_wrapper', parent=MMENGINE_OPTIM_WRAPPERS) # manage constructors that customize the optimization hyperparameters. OPTIM_WRAPPER_CONSTRUCTORS = Registry( - 'optimizer constructor', parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS) + 'optimizer wrapper constructor', + parent=MMENGINE_OPTIM_WRAPPER_CONSTRUCTORS) # mangage all kinds of parameter schedulers like `MultiStepLR` PARAM_SCHEDULERS = Registry( 'parameter scheduler', parent=MMENGINE_PARAM_SCHEDULERS) + # manage all kinds of metrics METRICS = Registry('metric', parent=MMENGINE_METRICS) +# manage evaluator +EVALUATOR = Registry('evaluator', parent=MMENGINE_EVALUATOR) # manage task-specific modules like ohem pixel sampler TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS) @@ -69,3 +78,6 @@ TASK_UTILS = Registry('task util', parent=MMENGINE_TASK_UTILS) VISUALIZERS = Registry('visualizer', parent=MMENGINE_VISUALIZERS) # manage visualizer backend VISBACKENDS = Registry('vis_backend', parent=MMENGINE_VISBACKENDS) + +# manage logprocessor +LOG_PROCESSORS = Registry('log_processor', parent=MMENGINE_LOG_PROCESSORS) diff --git a/mmseg/version.py b/mmseg/version.py index e05146f0a..6ee77798f 100644 --- a/mmseg/version.py +++ b/mmseg/version.py @@ -1,6 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. -__version__ = '0.24.1' +__version__ = '1.0.0rc0' def parse_version_info(version_str): diff --git a/requirements/mminstall.txt b/requirements/mminstall.txt index bd43faf87..91571a9c2 100644 --- a/requirements/mminstall.txt +++ b/requirements/mminstall.txt @@ -1,2 +1,2 @@ mmcls>=0.20.1 -mmcv-full>=1.4.4,<=1.6.0 +mmcv-full>=2.0.0,<=2.0.0rc0 diff --git a/tests/test_config.py b/tests/test_config.py index cd99dad5d..d644a34ba 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,10 +3,13 @@ import glob import os from os.path import dirname, exists, isdir, join, relpath -from mmcv import Config +import numpy as np +from mmengine import Config +from mmengine.dataset import Compose from torch import nn from mmseg.models import build_segmentor +from mmseg.utils import register_all_modules def _get_config_directory(): @@ -60,72 +63,69 @@ def test_config_build_segmentor(): _check_decode_head(head_config, segmentor.decode_head) -# def test_config_data_pipeline(): -# """Test whether the data pipeline is valid and can process corner cases. +def test_config_data_pipeline(): + """Test whether the data pipeline is valid and can process corner cases. -# CommandLine: -# xdoctest -m tests/test_config.py test_config_build_data_pipeline -# """ -# import numpy as np -# from mmcv import Config + CommandLine: + xdoctest -m tests/test_config.py test_config_build_data_pipeline + """ -# from mmseg.datasets.transforms import Compose + register_all_modules() + config_dpath = _get_config_directory() + print('Found config_dpath = {!r}'.format(config_dpath)) -# config_dpath = _get_config_directory() -# print('Found config_dpath = {!r}'.format(config_dpath)) + import glob + config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py'))) + config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1] + config_names = [relpath(p, config_dpath) for p in config_fpaths] -# import glob -# config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py'))) -# config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1] -# config_names = [relpath(p, config_dpath) for p in config_fpaths] + print('Using {} config files'.format(len(config_names))) -# print('Using {} config files'.format(len(config_names))) + for config_fname in config_names: + config_fpath = join(config_dpath, config_fname) + print( + 'Building data pipeline, config_fpath = {!r}'.format(config_fpath)) + config_mod = Config.fromfile(config_fpath) -# for config_fname in config_names: -# config_fpath = join(config_dpath, config_fname) -# print( -# 'Building data pipeline, config_fpath = {!r}'.format(config_fpath)) -# config_mod = Config.fromfile(config_fpath) + # remove loading pipeline + load_img_pipeline = config_mod.train_pipeline.pop(0) + to_float32 = load_img_pipeline.get('to_float32', False) + config_mod.train_pipeline.pop(0) + config_mod.test_pipeline.pop(0) + # remove loading annotation in test pipeline + config_mod.test_pipeline.pop(1) -# # remove loading pipeline -# load_img_pipeline = config_mod.train_pipeline.pop(0) -# to_float32 = load_img_pipeline.get('to_float32', False) -# config_mod.train_pipeline.pop(0) -# config_mod.test_pipeline.pop(0) -# # remove loading annotation in test pipeline -# config_mod.test_pipeline.pop(1) + train_pipeline = Compose(config_mod.train_pipeline) + test_pipeline = Compose(config_mod.test_pipeline) -# train_pipeline = Compose(config_mod.train_pipeline) -# test_pipeline = Compose(config_mod.test_pipeline) + img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8) + if to_float32: + img = img.astype(np.float32) + seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8) -# img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8) -# if to_float32: -# img = img.astype(np.float32) -# seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8) + results = dict( + filename='test_img.png', + ori_filename='test_img.png', + img=img, + img_shape=img.shape, + ori_shape=img.shape, + gt_seg_map=seg) + results['seg_fields'] = ['gt_seg_map'] -# results = dict( -# filename='test_img.png', -# ori_filename='test_img.png', -# img=img, -# img_shape=img.shape, -# ori_shape=img.shape, -# gt_seg_map=seg) -# results['seg_fields'] = ['gt_seg_map'] + print('Test training data pipeline: \n{!r}'.format(train_pipeline)) + output_results = train_pipeline(results) + assert output_results is not None -# print('Test training data pipeline: \n{!r}'.format(train_pipeline)) -# output_results = train_pipeline(results) -# assert output_results is not None - -# results = dict( -# filename='test_img.png', -# ori_filename='test_img.png', -# img=img, -# img_shape=img.shape, -# ori_shape=img.shape, -# ) -# print('Test testing data pipeline: \n{!r}'.format(test_pipeline)) -# output_results = test_pipeline(results) -# assert output_results is not None + results = dict( + filename='test_img.png', + ori_filename='test_img.png', + img=img, + img_shape=img.shape, + ori_shape=img.shape, + ) + print('Test testing data pipeline: \n{!r}'.format(test_pipeline)) + output_results = test_pipeline(results) + assert output_results is not None def _check_decode_head(decode_head_cfg, decode_head): diff --git a/tests/test_evaluation/test_citys_metric.py b/tests/test_evaluation/test_metrics/test_citys_metric.py similarity index 100% rename from tests/test_evaluation/test_citys_metric.py rename to tests/test_evaluation/test_metrics/test_citys_metric.py diff --git a/tests/test_evaluation/test_iou_metric.py b/tests/test_evaluation/test_metrics/test_iou_metric.py similarity index 100% rename from tests/test_evaluation/test_iou_metric.py rename to tests/test_evaluation/test_metrics/test_iou_metric.py diff --git a/tests/test_models/test_backbones/__init__.py b/tests/test_models/test_backbones/__init__.py index 8b673fa5c..ef101fec6 100644 --- a/tests/test_models/test_backbones/__init__.py +++ b/tests/test_models/test_backbones/__init__.py @@ -1,4 +1 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .utils import all_zeros, check_norm_state, is_block, is_norm - -__all__ = ['is_norm', 'is_block', 'all_zeros', 'check_norm_state'] diff --git a/tests/test_models/test_backbones/test_unet.py b/tests/test_models/test_backbones/test_unet.py index 9beb7279a..63d3774da 100644 --- a/tests/test_models/test_backbones/test_unet.py +++ b/tests/test_models/test_backbones/test_unet.py @@ -5,7 +5,7 @@ from mmcv.cnn import ConvModule from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, InterpConv, UNet, UpConvBlock) -from mmseg.ops import Upsample +from mmseg.models.utils import Upsample from .utils import check_norm_state diff --git a/tests/test_data/test_seg_data_sample.py b/tests/test_structures/test_seg_data_sample.py similarity index 100% rename from tests/test_data/test_seg_data_sample.py rename to tests/test_structures/test_seg_data_sample.py diff --git a/tools/analyze_logs.py b/tools/analysis_tools/analyze_logs.py similarity index 100% rename from tools/analyze_logs.py rename to tools/analysis_tools/analyze_logs.py diff --git a/tools/benchmark.py b/tools/analysis_tools/benchmark.py similarity index 100% rename from tools/benchmark.py rename to tools/analysis_tools/benchmark.py diff --git a/tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py similarity index 100% rename from tools/confusion_matrix.py rename to tools/analysis_tools/confusion_matrix.py diff --git a/tools/get_flops.py b/tools/analysis_tools/get_flops.py similarity index 100% rename from tools/get_flops.py rename to tools/analysis_tools/get_flops.py diff --git a/tools/convert_datasets/chase_db1.py b/tools/dataset_converters/chase_db1.py similarity index 100% rename from tools/convert_datasets/chase_db1.py rename to tools/dataset_converters/chase_db1.py diff --git a/tools/convert_datasets/cityscapes.py b/tools/dataset_converters/cityscapes.py similarity index 100% rename from tools/convert_datasets/cityscapes.py rename to tools/dataset_converters/cityscapes.py diff --git a/tools/convert_datasets/coco_stuff10k.py b/tools/dataset_converters/coco_stuff10k.py similarity index 100% rename from tools/convert_datasets/coco_stuff10k.py rename to tools/dataset_converters/coco_stuff10k.py diff --git a/tools/convert_datasets/coco_stuff164k.py b/tools/dataset_converters/coco_stuff164k.py similarity index 100% rename from tools/convert_datasets/coco_stuff164k.py rename to tools/dataset_converters/coco_stuff164k.py diff --git a/tools/convert_datasets/drive.py b/tools/dataset_converters/drive.py similarity index 100% rename from tools/convert_datasets/drive.py rename to tools/dataset_converters/drive.py diff --git a/tools/convert_datasets/hrf.py b/tools/dataset_converters/hrf.py similarity index 100% rename from tools/convert_datasets/hrf.py rename to tools/dataset_converters/hrf.py diff --git a/tools/convert_datasets/isaid.py b/tools/dataset_converters/isaid.py similarity index 100% rename from tools/convert_datasets/isaid.py rename to tools/dataset_converters/isaid.py diff --git a/tools/convert_datasets/loveda.py b/tools/dataset_converters/loveda.py similarity index 100% rename from tools/convert_datasets/loveda.py rename to tools/dataset_converters/loveda.py diff --git a/tools/convert_datasets/pascal_context.py b/tools/dataset_converters/pascal_context.py similarity index 100% rename from tools/convert_datasets/pascal_context.py rename to tools/dataset_converters/pascal_context.py diff --git a/tools/convert_datasets/potsdam.py b/tools/dataset_converters/potsdam.py similarity index 100% rename from tools/convert_datasets/potsdam.py rename to tools/dataset_converters/potsdam.py diff --git a/tools/convert_datasets/stare.py b/tools/dataset_converters/stare.py similarity index 100% rename from tools/convert_datasets/stare.py rename to tools/dataset_converters/stare.py diff --git a/tools/convert_datasets/vaihingen.py b/tools/dataset_converters/vaihingen.py similarity index 100% rename from tools/convert_datasets/vaihingen.py rename to tools/dataset_converters/vaihingen.py diff --git a/tools/convert_datasets/voc_aug.py b/tools/dataset_converters/voc_aug.py similarity index 100% rename from tools/convert_datasets/voc_aug.py rename to tools/dataset_converters/voc_aug.py diff --git a/tools/pytorch2torchscript.py b/tools/deployment/pytorch2torchscript.py similarity index 100% rename from tools/pytorch2torchscript.py rename to tools/deployment/pytorch2torchscript.py diff --git a/tools/browse_dataset.py b/tools/misc/browse_dataset.py similarity index 100% rename from tools/browse_dataset.py rename to tools/misc/browse_dataset.py diff --git a/tools/print_config.py b/tools/misc/print_config.py similarity index 100% rename from tools/print_config.py rename to tools/misc/print_config.py diff --git a/tools/publish_model.py b/tools/misc/publish_model.py similarity index 100% rename from tools/publish_model.py rename to tools/misc/publish_model.py diff --git a/tools/train.py b/tools/train.py index 878d78c31..1cd2e53f8 100644 --- a/tools/train.py +++ b/tools/train.py @@ -15,6 +15,14 @@ def parse_args(): parser = argparse.ArgumentParser(description='Train a segmentor') parser.add_argument('config', help='train config file path') parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume', + nargs='?', + type=str, + const='auto', + help='If specify checkpint path, resume from it, while if not ' + 'specify, try to auto resume from the latest checkpoint ' + 'in the work directory.') parser.add_argument( '--amp', action='store_true', @@ -80,6 +88,14 @@ def main(): cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.loss_scale = 'dynamic' + # resume training + if args.resume == 'auto': + cfg.resume = True + cfg.load_from = None + elif args.resume is not None: + cfg.resume = True + cfg.load_from = args.resume + # build the runner from config runner = Runner.from_cfg(cfg)