diff --git a/.circleci/test.yml b/.circleci/test.yml index 6ac32e565..1292b4e38 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -65,6 +65,7 @@ jobs: pip install mmcls==1.0.0rc6 pip install git+https://github.com/open-mmlab/mmdetection.git@main pip install -r requirements/tests.txt -r requirements/optional.txt + python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations - run: name: Build and install command: | @@ -111,6 +112,7 @@ jobs: docker exec mmseg pip install mmcls==1.0.0rc6 docker exec mmseg pip install -e /mmdetection docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt + docker exec mmseg python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations - run: name: Build and install command: | diff --git a/mmseg/datasets/__init__.py b/mmseg/datasets/__init__.py index a90d53c88..81c5a7363 100644 --- a/mmseg/datasets/__init__.py +++ b/mmseg/datasets/__init__.py @@ -22,7 +22,7 @@ from .refuge import REFUGEDataset from .stare import STAREDataset from .synapse import SynapseDataset # yapf: disable -from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad, +from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad, BioMedical3DRandomCrop, BioMedical3DRandomFlip, BioMedicalGaussianBlur, BioMedicalGaussianNoise, BioMedicalRandomGamma, GenerateEdge, LoadAnnotations, @@ -51,5 +51,5 @@ __all__ = [ 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur', 'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip', 'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1', - 'MapillaryDataset_v2' + 'MapillaryDataset_v2', 'Albu' ] diff --git a/mmseg/datasets/transforms/__init__.py b/mmseg/datasets/transforms/__init__.py index 25f4ee4a9..c732fef23 100644 --- a/mmseg/datasets/transforms/__init__.py +++ b/mmseg/datasets/transforms/__init__.py @@ -4,7 +4,7 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation, LoadBiomedicalData, LoadBiomedicalImageFromFile, 
# Albumentations is an optional dependency. Keep the module importable when it
# is missing; ``Albu`` raises an ImportError at construction time instead.
try:
    import albumentations
    from albumentations import Compose
    ALBU_INSTALLED = True
except ImportError:
    albumentations = None
    Compose = None
    ALBU_INSTALLED = False


@TRANSFORMS.register_module()
class Albu(BaseTransform):
    """Albumentation augmentation.

    Adds custom transformations from the Albumentations library. Please visit
    `https://albumentations.readthedocs.io` to get more information.

    The image is converted from BGR to RGB before the augmentations are
    applied and converted back to BGR afterwards.

    An example of ``transforms`` is as follows:

    .. code-block:: python

        [
            dict(
                type='ShiftScaleRotate',
                shift_limit=0.0625,
                scale_limit=0.0,
                rotate_limit=0,
                interpolation=1,
                p=0.5),
            dict(
                type='RandomBrightnessContrast',
                brightness_limit=[0.1, 0.3],
                contrast_limit=[0.1, 0.3],
                p=0.2),
            dict(type='ChannelShuffle', p=0.1),
            dict(
                type='OneOf',
                transforms=[
                    dict(type='Blur', blur_limit=3, p=1.0),
                    dict(type='MedianBlur', blur_limit=3, p=1.0)
                ],
                p=0.1),
        ]

    Args:
        transforms (list[dict]): A list of albu transformation configs. Each
            dict must contain the key ``type`` (the name of a class in the
            ``albumentations`` namespace, or a class object); the remaining
            items are passed to that class as keyword arguments. A nested
            ``transforms`` list is built recursively.
        keymap (dict, optional): Contains
            ``{'input key': 'albumentation-style key'}``. After renaming, the
            results must contain the key ``'image'``. If not given (or
            empty), ``img`` is mapped to ``image``, ``gt_seg_map`` to
            ``mask`` and ``gt_masks`` to ``masks``.
        update_pad_shape (bool): Whether to set ``results['pad_shape']`` to
            the shape of the augmented ``results['img']``. Defaults to False.

    Raises:
        ImportError: If albumentations is not installed.
    """

    def __init__(self,
                 transforms: List[dict],
                 keymap: Optional[dict] = None,
                 update_pad_shape: bool = False):
        self._require_albu()

        # Args will be modified later, copying it will be safer
        transforms = copy.deepcopy(transforms)

        self.transforms = transforms
        self.keymap = keymap
        self.update_pad_shape = update_pad_shape

        self.aug = Compose([self.albu_builder(t) for t in self.transforms])

        if not keymap:
            self.keymap_to_albu = {
                'img': 'image',
                # The segmentation label must go through the same spatial
                # augmentations as the image, otherwise the two no longer
                # line up after e.g. ShiftScaleRotate.
                'gt_seg_map': 'mask',
                # Kept so inputs that carry ``gt_masks`` behave as before.
                'gt_masks': 'masks',
            }
        else:
            self.keymap_to_albu = copy.deepcopy(keymap)
        # Inverse mapping used to restore the original key names.
        self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

    @staticmethod
    def _require_albu() -> None:
        """Raise ImportError when albumentations could not be imported."""
        if not ALBU_INSTALLED:
            raise ImportError(
                'albumentations is not installed, '
                'we suggest install albumentation by '
                '"pip install albumentations>=0.3.2 --no-binary '
                'qudida,albumentations"')

    def albu_builder(self, cfg: dict) -> object:
        """Build a callable object from a dict containing albu arguments.

        Args:
            cfg (dict): Config dict. It should at least contain the key "type".

        Returns:
            Callable: A callable object.

        Raises:
            ImportError: If ``type`` is a string and albumentations is not
                installed.
            TypeError: If ``type`` is neither a string nor a class.
        """

        assert isinstance(cfg, dict) and 'type' in cfg
        # Shallow copy so popping ``type`` does not alter the caller's dict.
        args = cfg.copy()

        obj_type = args.pop('type')
        if mmengine.is_str(obj_type):
            self._require_albu()
            obj_cls = getattr(albumentations, obj_type)
        elif inspect.isclass(obj_type):
            obj_cls = obj_type
        else:
            raise TypeError(
                f'type must be a valid type or str, but got {type(obj_type)}')

        # Container transforms (e.g. ``OneOf``) hold nested configs.
        if 'transforms' in args:
            args['transforms'] = [
                self.albu_builder(t) for t in args['transforms']
            ]

        return obj_cls(**args)

    @staticmethod
    def mapper(d: dict, keymap: dict) -> dict:
        """Dictionary mapper.

        Renames keys according to keymap provided. Keys that are not in
        ``keymap`` are kept unchanged.

        Args:
            d (dict): old dict
            keymap (dict): {'old_key':'new_key'}

        Returns:
            dict: new dict.
        """

        return {keymap.get(k, k): v for k, v in d.items()}

    def transform(self, results: dict) -> dict:
        """Apply the composed albumentations pipeline to ``results``.

        Args:
            results (dict): Result dict; after key renaming it must contain
                the key ``'image'``.

        Returns:
            dict: A new dict with the augmented data under the original key
            names, plus ``pad_shape`` when ``update_pad_shape`` is set.
        """
        # dict to albumentations format
        results = self.mapper(results, self.keymap_to_albu)

        # Convert to RGB since Albumentations works with RGB images
        results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)

        results = self.aug(**results)

        # Convert back to BGR
        results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)

        # back to the original format
        results = self.mapper(results, self.keymap_back)

        # update final shape
        if self.update_pad_shape:
            results['pad_shape'] = results['img'].shape

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str
def test_albu_transform():
    """Albu can be chained between image loading and normalization."""
    results = dict(
        img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))

    # Define simple pipeline
    load = TRANSFORMS.build(dict(type='LoadImageFromFile'))
    albu_transform = TRANSFORMS.build(
        dict(type='Albu', transforms=[dict(type='ChannelShuffle', p=1)]))
    # std must be non-zero: normalization divides by it.
    normalize = TRANSFORMS.build(
        dict(type='Normalize', mean=[0] * 3, std=[1] * 3, to_rgb=True))

    # Execute transforms
    results = load(results)
    results = albu_transform(results)
    results = normalize(results)

    assert results['img'].dtype == np.float32
    assert np.isfinite(results['img']).all()


def test_albu_channel_order():
    """Albu keeps the BGR channel order of the pipeline image.

    ``RGBShift`` is configured to shift only the blue channel, so only
    channel 0 of the BGR output may differ from the loaded image.
    """
    results = dict(
        img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))

    # Define simple pipeline
    load = TRANSFORMS.build(dict(type='LoadImageFromFile'))

    # Transform is modifying B channel
    albu_transform = TRANSFORMS.build(
        dict(
            type='Albu',
            transforms=[
                dict(
                    type='RGBShift',
                    r_shift_limit=0,
                    g_shift_limit=0,
                    b_shift_limit=200,
                    p=1)
            ]))

    # Execute transforms
    results_load = load(results)
    # Snapshot the input so the comparison cannot be fooled by aliasing
    # between the loaded and the augmented arrays.
    original_img = results_load['img'].copy()
    results_albu = albu_transform(results_load)

    # Green and Red channels must be untouched
    np.testing.assert_array_equal(results_albu['img'][..., 1:],
                                  original_img[..., 1:])

    # Blue channel must be modified
    assert not np.array_equal(results_albu['img'][..., 0],
                              original_img[..., 0])