diff --git a/.circleci/test.yml b/.circleci/test.yml index d77170611..6a1271b4f 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -77,6 +77,7 @@ jobs: command: | python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch<< parameters.torch >>/index.html python -m pip install -r requirements.txt + python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations - run: name: Build and install command: | @@ -119,6 +120,8 @@ jobs: command: | docker exec mmseg pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/${MMCV_CUDA}/torch${MMCV_TORCH}/index.html docker exec mmseg pip install -r requirements.txt + docker exec mmseg pip install typing-extensions -U + docker exec mmseg pip install albumentations --use-pep517 qudida albumentations - run: name: Build and install command: | diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 48c6c3569..b68b2cc28 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -64,6 +64,7 @@ jobs: - name: Install unittest dependencies run: | pip install -r requirements.txt + pip install albumentations>=0.3.2 --no-binary qudida,albumentations - name: Build and install run: rm -rf .eggs && pip install -e . - name: Run unittests and generate coverage report @@ -133,6 +134,7 @@ jobs: python -m pip install -U openmim mim install mmcv-full python -m pip install -r requirements.txt + python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations python -c 'import mmcv; print(mmcv.__version__)' - name: Build and install run: | @@ -200,6 +202,7 @@ jobs: python -m pip install openmim mim install mmcv-full python -m pip install -r requirements.txt + python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations python -c 'import mmcv; print(mmcv.__version__)' - name: Build and install run: | @@ -263,6 +266,7 @@ jobs: python -m pip install openmim mim install mmcv-full python -m pip install -r requirements.txt + python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations python -c 'import mmcv; print(mmcv.__version__)' - name: Build and install run: | @@ -301,7 +305,9 @@ jobs: pip install -U openmim mim install mmcv-full - name: Install unittest dependencies - run: pip install -r requirements/tests.txt -r requirements/optional.txt + run: | + pip install -r requirements/tests.txt -r requirements/optional.txt + pip install albumentations>=0.3.2 --no-binary qudida,albumentations - name: Build and install run: pip install -e . - name: Run unittests diff --git a/docs/en/get_started.md b/docs/en/get_started.md index bbe3d5795..8167bf4fa 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -63,6 +63,9 @@ Case b: If you use mmsegmentation as a dependency or third-party package, instal pip install mmsegmentation ``` +**Note:** +If you would like to use albumentations, we suggest using pip install -U albumentations --no-binary qudida,albumentations. If you simply use pip install albumentations>=0.3.2, it will install opencv-python-headless simultaneously (even though you have already installed opencv-python). We recommended checking the environment after installing albumentations to ensure that opencv-python and opencv-python-headless are not installed at the same time, because it might cause unexpected issues if they both installed. Please refer to [official documentation](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies) for more details. + ## Verify the installation To verify whether MMSegmentation is installed correctly, we provide some sample codes to run an inference demo. diff --git a/docs/zh_cn/get_started.md b/docs/zh_cn/get_started.md index e57fa0360..09d8d4116 100644 --- a/docs/zh_cn/get_started.md +++ b/docs/zh_cn/get_started.md @@ -67,6 +67,9 @@ pip install -v -e . pip install mmsegmentation ``` +**注意:** +如果你想使用 albumentations,我们建议使用 pip install-U albumentations --no-binary qudida,albumentations 进行安装。如果您仅使用 pip install albumentations>=0.3.2 进行安装,它将同时安装 opencv-python-headless(即使您已经安装了 opencv-python)。我们建议在安装了 albumentations 后检查环境,以确保没有同时安装 opencv-python 和 opencv-python-headless,因为如果两者都安装了,可能会导致意外问题。请参阅[官方文档](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies)了解更多详细信息。 + ## 验证安装 为了验证 MMSegmentation 是否安装正确,我们提供了一些示例代码来执行模型推理。 diff --git a/mmseg/datasets/pipelines/__init__.py b/mmseg/datasets/pipelines/__init__.py index 8256a6fe2..90d6a2444 100644 --- a/mmseg/datasets/pipelines/__init__.py +++ b/mmseg/datasets/pipelines/__init__.py @@ -4,7 +4,7 @@ from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor, Transpose, to_tensor) from .loading import LoadAnnotations, LoadImageFromFile from .test_time_aug import MultiScaleFlipAug -from .transforms import (CLAHE, AdjustGamma, Normalize, Pad, +from .transforms import (CLAHE, AdjustGamma, Albu, Normalize, Pad, PhotoMetricDistortion, RandomCrop, RandomCutOut, RandomFlip, RandomMosaic, RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) @@ -15,5 +15,5 @@ __all__ = [ 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut', - 'RandomMosaic' + 'RandomMosaic', 'Albu' ] diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 5673b646f..19899674f 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import inspect +import cv2 import mmcv import numpy as np from mmcv.utils import deprecated_api_warning, is_tuple_of @@ -8,6 +10,13 @@ from numpy import random from ..builder import PIPELINES +try: + import albumentations + from albumentations import Compose +except ImportError: + albumentations = None + Compose = None + @PIPELINES.register_module() class ResizeToMultiple(object): @@ -1333,3 +1342,146 @@ class RandomMosaic(object): repr_str += f'pad_val={self.pad_val}, ' repr_str += f'seg_pad_val={self.pad_val})' return repr_str + + +@PIPELINES.register_module() +class Albu: + """Albumentation augmentation. Adds custom transformations from + Albumentations library. Please, visit + `https://albumentations.readthedocs.io` to get more information. An example + of ``transforms`` is as followed: + + .. code-block:: + [ + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=0.5), + dict( + type='RandomBrightnessContrast', + brightness_limit=[0.1, 0.3], + contrast_limit=[0.1, 0.3], + p=0.2), + dict(type='ChannelShuffle', p=0.1), + dict( + type='OneOf', + transforms=[ + dict(type='Blur', blur_limit=3, p=1.0), + dict(type='MedianBlur', blur_limit=3, p=1.0) + ], + p=0.1), + ] + Args: + transforms (list[dict]): A list of albu transformations + keymap (dict): Contains {'input key':'albumentation-style key'} + update_pad_shape (bool): Whether to update padding shape according to \ + the output shape of the last transform + """ + + def __init__(self, transforms, keymap=None, update_pad_shape=False): + if Compose is None: + raise ImportError( + 'albumentations is not installed, ' + 'we suggest install albumentation by ' + '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa + ) + + # Args will be modified later, copying it will be safer + transforms = copy.deepcopy(transforms) + + self.transforms = transforms + self.filter_lost_elements = False + self.update_pad_shape = update_pad_shape + + self.aug = Compose([self.albu_builder(t) for t in self.transforms]) + + if not keymap: + self.keymap_to_albu = { + 'img': 'image', + 'gt_masks': 'masks', + } + else: + self.keymap_to_albu = copy.deepcopy(keymap) + self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()} + + def albu_builder(self, cfg): + """Import a module from albumentations. + + It inherits some of :func:`build_from_cfg` logic. + Args: + cfg (dict): Config dict. It should at least contain the key "type". + Returns: + obj: The constructed object. + """ + + assert isinstance(cfg, dict) and 'type' in cfg + args = cfg.copy() + + obj_type = args.pop('type') + if mmcv.is_str(obj_type): + if albumentations is None: + raise ImportError( + 'albumentations is not installed, ' + 'we suggest install albumentation by ' + '"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa + ) + obj_cls = getattr(albumentations, obj_type) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + f'type must be a str or valid type, but got {type(obj_type)}') + + if 'transforms' in args: + args['transforms'] = [ + self.albu_builder(transform) + for transform in args['transforms'] + ] + + return obj_cls(**args) + + @staticmethod + def mapper(d, keymap): + """Dictionary mapper. + + Renames keys according to keymap provided. + Args: + d (dict): old dict + keymap (dict): {'old_key':'new_key'} + Returns: + dict: new dict. + """ + + updated_dict = {} + for k, _ in zip(d.keys(), d.values()): + new_k = keymap.get(k, k) + updated_dict[new_k] = d[k] + return updated_dict + + def __call__(self, results): + # dict to albumentations format + results = self.mapper(results, self.keymap_to_albu) + + # Convert to RGB since Albumentations works with RGB images + results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB) + + results = self.aug(**results) + + # Convert back to BGR + results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR) + + # back to the original format + results = self.mapper(results, self.keymap_back) + + # update final shape + if self.update_pad_shape: + results['pad_shape'] = results['img'].shape + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(transforms={self.transforms})' + return repr_str diff --git a/requirements/readthedocs.txt b/requirements/readthedocs.txt index 22a894bd7..55f245d4c 100644 --- a/requirements/readthedocs.txt +++ b/requirements/readthedocs.txt @@ -1,4 +1,5 @@ mmcv prettytable +scipy torch torchvision diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 520408fe8..9a8df8149 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -3,3 +3,4 @@ mmcls>=0.20.1 numpy packaging prettytable +scipy diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py index fcc46e7d0..f5fde99d8 100644 --- a/tests/test_data/test_transform.py +++ b/tests/test_data/test_transform.py @@ -688,3 +688,63 @@ def test_mosaic(): mosaic_module = build_from_cfg(transform, PIPELINES) results = mosaic_module(results) assert results['img'].shape[:2] == (20, 24) + + +def test_albu_transform(): + results = dict( + img_prefix=osp.join(osp.dirname(__file__), '../data'), + img_info=dict(filename='color.jpg')) + + # Define simple pipeline + load = dict(type='LoadImageFromFile') + load = build_from_cfg(load, PIPELINES) + + albu_transform = dict( + type='Albu', transforms=[dict(type='ChannelShuffle', p=1)]) + albu_transform = build_from_cfg(albu_transform, PIPELINES) + + normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True) + normalize = build_from_cfg(normalize, PIPELINES) + + # Execute transforms + results = load(results) + results = albu_transform(results) + results = normalize(results) + + assert results['img'].dtype == np.float32 + + +def test_albu_channel_order(): + results = dict( + img_prefix=osp.join(osp.dirname(__file__), '../data'), + img_info=dict(filename='color.jpg')) + + # Define simple pipeline + load = dict(type='LoadImageFromFile') + load = build_from_cfg(load, PIPELINES) + + # Transform is modifying B channel + albu_transform = dict( + type='Albu', + transforms=[ + dict( + type='RGBShift', + r_shift_limit=0, + g_shift_limit=0, + b_shift_limit=200, + p=1) + ]) + albu_transform = build_from_cfg(albu_transform, PIPELINES) + + # Execute transforms + results_load = load(results) + results_albu = albu_transform(results_load) + + # assert only Green and Red channel are not modified + np.testing.assert_array_equal(results_albu['img'][..., 1:], + results_load['img'][..., 1:]) + + # assert Blue channel is modified + with pytest.raises(AssertionError): + np.testing.assert_array_equal(results_albu['img'][..., 0], + results_load['img'][..., 0]) diff --git a/tests/test_models/test_backbones/test_beit.py b/tests/test_models/test_backbones/test_beit.py index cf3960894..9495b60a1 100644 --- a/tests/test_models/test_backbones/test_beit.py +++ b/tests/test_models/test_backbones/test_beit.py @@ -140,8 +140,7 @@ def test_beit_init(): } } model = BEiT(img_size=(512, 512)) - with pytest.raises(AttributeError): - model.resize_rel_pos_embed(ckpt) + ckpt = model.resize_rel_pos_embed(ckpt) # pretrained=None # init_cfg=123, whose type is unsupported diff --git a/tests/test_models/test_backbones/test_mae.py b/tests/test_models/test_backbones/test_mae.py index 562d067a7..aa7a292db 100644 --- a/tests/test_models/test_backbones/test_mae.py +++ b/tests/test_models/test_backbones/test_mae.py @@ -138,10 +138,19 @@ def test_mae_init(): } } model = MAE(img_size=(512, 512)) - with pytest.raises(AttributeError): - model.resize_rel_pos_embed(ckpt) + ckpt = model.resize_rel_pos_embed(ckpt) # test resize abs pos embed + value = torch.randn(732, 16) + abs_pos_embed_value = torch.rand(1, 17, 768) + ckpt = { + 'state_dict': { + 'layers.0.attn.relative_position_index': 0, + 'layers.0.attn.relative_position_bias_table': value, + 'pos_embed': abs_pos_embed_value + } + } + model = MAE(img_size=(512, 512)) ckpt = model.resize_abs_pos_embed(ckpt['state_dict']) # pretrained=None