mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Support albu transform (#2943)
## Motivation
3a14c35974/mmseg/datasets/pipelines/transforms.py (L1348)
## Modification
Add albu to dev-1.x
This commit is contained in:
parent
60a542cc66
commit
7ff58d7074
@ -65,6 +65,7 @@ jobs:
|
|||||||
pip install mmcls==1.0.0rc6
|
pip install mmcls==1.0.0rc6
|
||||||
pip install git+https://github.com/open-mmlab/mmdetection.git@main
|
pip install git+https://github.com/open-mmlab/mmdetection.git@main
|
||||||
pip install -r requirements/tests.txt -r requirements/optional.txt
|
pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||||
|
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
- run:
|
- run:
|
||||||
name: Build and install
|
name: Build and install
|
||||||
command: |
|
command: |
|
||||||
@ -111,6 +112,7 @@ jobs:
|
|||||||
docker exec mmseg pip install mmcls==1.0.0rc6
|
docker exec mmseg pip install mmcls==1.0.0rc6
|
||||||
docker exec mmseg pip install -e /mmdetection
|
docker exec mmseg pip install -e /mmdetection
|
||||||
docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt
|
docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||||
|
docker exec mmseg python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
- run:
|
- run:
|
||||||
name: Build and install
|
name: Build and install
|
||||||
command: |
|
command: |
|
||||||
|
@ -22,7 +22,7 @@ from .refuge import REFUGEDataset
|
|||||||
from .stare import STAREDataset
|
from .stare import STAREDataset
|
||||||
from .synapse import SynapseDataset
|
from .synapse import SynapseDataset
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
|
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||||
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
|
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
|
||||||
@ -51,5 +51,5 @@ __all__ = [
|
|||||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||||
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
|
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
|
||||||
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
|
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
|
||||||
'MapillaryDataset_v2'
|
'MapillaryDataset_v2', 'Albu'
|
||||||
]
|
]
|
||||||
|
@ -4,7 +4,7 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
|||||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||||
LoadImageFromNDArray)
|
LoadImageFromNDArray)
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
|
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||||
BioMedicalRandomGamma, GenerateEdge,
|
BioMedicalRandomGamma, GenerateEdge,
|
||||||
@ -22,5 +22,5 @@ __all__ = [
|
|||||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||||
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
|
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
|
||||||
'RandomRotFlip'
|
'RandomRotFlip', 'Albu'
|
||||||
]
|
]
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import mmcv
|
import mmcv
|
||||||
|
import mmengine
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.transforms.base import BaseTransform
|
from mmcv.transforms.base import BaseTransform
|
||||||
from mmcv.transforms.utils import cache_randomness
|
from mmcv.transforms.utils import cache_randomness
|
||||||
@ -15,6 +17,15 @@ from scipy.ndimage import gaussian_filter
|
|||||||
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
||||||
from mmseg.registry import TRANSFORMS
|
from mmseg.registry import TRANSFORMS
|
||||||
|
|
||||||
|
try:
|
||||||
|
import albumentations
|
||||||
|
from albumentations import Compose
|
||||||
|
ALBU_INSTALLED = True
|
||||||
|
except ImportError:
|
||||||
|
albumentations = None
|
||||||
|
Compose = None
|
||||||
|
ALBU_INSTALLED = False
|
||||||
|
|
||||||
|
|
||||||
@TRANSFORMS.register_module()
|
@TRANSFORMS.register_module()
|
||||||
class ResizeToMultiple(BaseTransform):
|
class ResizeToMultiple(BaseTransform):
|
||||||
@ -2135,3 +2146,148 @@ class BioMedical3DRandomFlip(BaseTransform):
|
|||||||
repr_str += f'(prob={self.prob}, axes={self.axes}, ' \
|
repr_str += f'(prob={self.prob}, axes={self.axes}, ' \
|
||||||
f'swap_label_pairs={self.swap_label_pairs})'
|
f'swap_label_pairs={self.swap_label_pairs})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
|
||||||
|
class Albu(BaseTransform):
|
||||||
|
"""Albumentation augmentation. Adds custom transformations from
|
||||||
|
Albumentations library. Please, visit
|
||||||
|
`https://albumentations.readthedocs.io` to get more information. An example
|
||||||
|
of ``transforms`` is as followed:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
[
|
||||||
|
dict(
|
||||||
|
type='ShiftScaleRotate',
|
||||||
|
shift_limit=0.0625,
|
||||||
|
scale_limit=0.0,
|
||||||
|
rotate_limit=0,
|
||||||
|
interpolation=1,
|
||||||
|
p=0.5),
|
||||||
|
dict(
|
||||||
|
type='RandomBrightnessContrast',
|
||||||
|
brightness_limit=[0.1, 0.3],
|
||||||
|
contrast_limit=[0.1, 0.3],
|
||||||
|
p=0.2),
|
||||||
|
dict(type='ChannelShuffle', p=0.1),
|
||||||
|
dict(
|
||||||
|
type='OneOf',
|
||||||
|
transforms=[
|
||||||
|
dict(type='Blur', blur_limit=3, p=1.0),
|
||||||
|
dict(type='MedianBlur', blur_limit=3, p=1.0)
|
||||||
|
],
|
||||||
|
p=0.1),
|
||||||
|
]
|
||||||
|
Args:
|
||||||
|
transforms (list[dict]): A list of albu transformations
|
||||||
|
keymap (dict): Contains {'input key':'albumentation-style key'}
|
||||||
|
update_pad_shape (bool): Whether to update padding shape according to \
|
||||||
|
the output shape of the last transform
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
transforms: List[dict],
|
||||||
|
keymap: Optional[dict] = None,
|
||||||
|
update_pad_shape: bool = False):
|
||||||
|
if not ALBU_INSTALLED:
|
||||||
|
raise ImportError(
|
||||||
|
'albumentations is not installed, '
|
||||||
|
'we suggest install albumentation by '
|
||||||
|
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||||
|
)
|
||||||
|
|
||||||
|
# Args will be modified later, copying it will be safer
|
||||||
|
transforms = copy.deepcopy(transforms)
|
||||||
|
|
||||||
|
self.transforms = transforms
|
||||||
|
self.keymap = keymap
|
||||||
|
self.update_pad_shape = update_pad_shape
|
||||||
|
|
||||||
|
self.aug = Compose([self.albu_builder(t) for t in self.transforms])
|
||||||
|
|
||||||
|
if not keymap:
|
||||||
|
self.keymap_to_albu = {
|
||||||
|
'img': 'image',
|
||||||
|
'gt_masks': 'masks',
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.keymap_to_albu = copy.deepcopy(keymap)
|
||||||
|
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
|
||||||
|
|
||||||
|
def albu_builder(self, cfg: dict) -> object:
|
||||||
|
"""Build a callable object from a dict containing albu arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (dict): Config dict. It should at least contain the key "type".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Callable: A callable object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert isinstance(cfg, dict) and 'type' in cfg
|
||||||
|
args = cfg.copy()
|
||||||
|
|
||||||
|
obj_type = args.pop('type')
|
||||||
|
if mmengine.is_str(obj_type):
|
||||||
|
if not ALBU_INSTALLED:
|
||||||
|
raise ImportError(
|
||||||
|
'albumentations is not installed, '
|
||||||
|
'we suggest install albumentation by '
|
||||||
|
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||||
|
)
|
||||||
|
obj_cls = getattr(albumentations, obj_type)
|
||||||
|
elif inspect.isclass(obj_type):
|
||||||
|
obj_cls = obj_type
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f'type must be a valid type or str, but got {type(obj_type)}')
|
||||||
|
|
||||||
|
if 'transforms' in args:
|
||||||
|
args['transforms'] = [
|
||||||
|
self.albu_builder(t) for t in args['transforms']
|
||||||
|
]
|
||||||
|
|
||||||
|
return obj_cls(**args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mapper(d: dict, keymap: dict):
|
||||||
|
"""Dictionary mapper.
|
||||||
|
|
||||||
|
Renames keys according to keymap provided.
|
||||||
|
Args:
|
||||||
|
d (dict): old dict
|
||||||
|
keymap (dict): {'old_key':'new_key'}
|
||||||
|
Returns:
|
||||||
|
dict: new dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
updated_dict = {}
|
||||||
|
for k, _ in zip(d.keys(), d.values()):
|
||||||
|
new_k = keymap.get(k, k)
|
||||||
|
updated_dict[new_k] = d[k]
|
||||||
|
return updated_dict
|
||||||
|
|
||||||
|
def transform(self, results):
|
||||||
|
# dict to albumentations format
|
||||||
|
results = self.mapper(results, self.keymap_to_albu)
|
||||||
|
|
||||||
|
# Convert to RGB since Albumentations works with RGB images
|
||||||
|
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
results = self.aug(**results)
|
||||||
|
|
||||||
|
# Convert back to BGR
|
||||||
|
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# back to the original format
|
||||||
|
results = self.mapper(results, self.keymap_back)
|
||||||
|
|
||||||
|
# update final shape
|
||||||
|
if self.update_pad_shape:
|
||||||
|
results['pad_shape'] = results['img'].shape
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
|
||||||
|
return repr_str
|
||||||
|
@ -1160,3 +1160,61 @@ def test_biomedical_3d_flip():
|
|||||||
results = transform(results)
|
results = transform(results)
|
||||||
assert np.equal(original_img, results['img']).all()
|
assert np.equal(original_img, results['img']).all()
|
||||||
assert np.equal(original_seg, results['gt_seg_map']).all()
|
assert np.equal(original_seg, results['gt_seg_map']).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_albu_transform():
|
||||||
|
results = dict(
|
||||||
|
img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||||
|
|
||||||
|
# Define simple pipeline
|
||||||
|
load = dict(type='LoadImageFromFile')
|
||||||
|
load = TRANSFORMS.build(load)
|
||||||
|
|
||||||
|
albu_transform = dict(
|
||||||
|
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
|
||||||
|
albu_transform = TRANSFORMS.build(albu_transform)
|
||||||
|
|
||||||
|
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
||||||
|
normalize = TRANSFORMS.build(normalize)
|
||||||
|
|
||||||
|
# Execute transforms
|
||||||
|
results = load(results)
|
||||||
|
results = albu_transform(results)
|
||||||
|
results = normalize(results)
|
||||||
|
|
||||||
|
assert results['img'].dtype == np.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_albu_channel_order():
|
||||||
|
results = dict(
|
||||||
|
img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||||
|
|
||||||
|
# Define simple pipeline
|
||||||
|
load = dict(type='LoadImageFromFile')
|
||||||
|
load = TRANSFORMS.build(load)
|
||||||
|
|
||||||
|
# Transform is modifying B channel
|
||||||
|
albu_transform = dict(
|
||||||
|
type='Albu',
|
||||||
|
transforms=[
|
||||||
|
dict(
|
||||||
|
type='RGBShift',
|
||||||
|
r_shift_limit=0,
|
||||||
|
g_shift_limit=0,
|
||||||
|
b_shift_limit=200,
|
||||||
|
p=1)
|
||||||
|
])
|
||||||
|
albu_transform = TRANSFORMS.build(albu_transform)
|
||||||
|
|
||||||
|
# Execute transforms
|
||||||
|
results_load = load(results)
|
||||||
|
results_albu = albu_transform(results_load)
|
||||||
|
|
||||||
|
# assert only Green and Red channel are not modified
|
||||||
|
np.testing.assert_array_equal(results_albu['img'][..., 1:],
|
||||||
|
results_load['img'][..., 1:])
|
||||||
|
|
||||||
|
# assert Blue channel is modified
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
np.testing.assert_array_equal(results_albu['img'][..., 0],
|
||||||
|
results_load['img'][..., 0])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user