mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Support albu transform (#2943)
## Motivation
3a14c35974/mmseg/datasets/pipelines/transforms.py (L1348)
## Modification
Add albu to dev-1.x
This commit is contained in:
parent
60a542cc66
commit
7ff58d7074
@ -65,6 +65,7 @@ jobs:
|
||||
pip install mmcls==1.0.0rc6
|
||||
pip install git+https://github.com/open-mmlab/mmdetection.git@main
|
||||
pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
- run:
|
||||
name: Build and install
|
||||
command: |
|
||||
@ -111,6 +112,7 @@ jobs:
|
||||
docker exec mmseg pip install mmcls==1.0.0rc6
|
||||
docker exec mmseg pip install -e /mmdetection
|
||||
docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||
docker exec mmseg python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
- run:
|
||||
name: Build and install
|
||||
command: |
|
||||
|
@ -22,7 +22,7 @@ from .refuge import REFUGEDataset
|
||||
from .stare import STAREDataset
|
||||
from .synapse import SynapseDataset
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
|
||||
@ -51,5 +51,5 @@ __all__ = [
|
||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
|
||||
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
|
||||
'MapillaryDataset_v2'
|
||||
'MapillaryDataset_v2', 'Albu'
|
||||
]
|
||||
|
@ -4,7 +4,7 @@ from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
|
||||
LoadBiomedicalData, LoadBiomedicalImageFromFile,
|
||||
LoadImageFromNDArray)
|
||||
# yapf: disable
|
||||
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
|
||||
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
|
||||
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
|
||||
BioMedicalRandomGamma, GenerateEdge,
|
||||
@ -22,5 +22,5 @@ __all__ = [
|
||||
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
|
||||
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
|
||||
'RandomRotFlip'
|
||||
'RandomRotFlip', 'Albu'
|
||||
]
|
||||
|
@ -1,10 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import mmengine
|
||||
import numpy as np
|
||||
from mmcv.transforms.base import BaseTransform
|
||||
from mmcv.transforms.utils import cache_randomness
|
||||
@ -15,6 +17,15 @@ from scipy.ndimage import gaussian_filter
|
||||
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
|
||||
from mmseg.registry import TRANSFORMS
|
||||
|
||||
try:
|
||||
import albumentations
|
||||
from albumentations import Compose
|
||||
ALBU_INSTALLED = True
|
||||
except ImportError:
|
||||
albumentations = None
|
||||
Compose = None
|
||||
ALBU_INSTALLED = False
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class ResizeToMultiple(BaseTransform):
|
||||
@ -2135,3 +2146,148 @@ class BioMedical3DRandomFlip(BaseTransform):
|
||||
repr_str += f'(prob={self.prob}, axes={self.axes}, ' \
|
||||
f'swap_label_pairs={self.swap_label_pairs})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
|
||||
class Albu(BaseTransform):
|
||||
"""Albumentation augmentation. Adds custom transformations from
|
||||
Albumentations library. Please, visit
|
||||
`https://albumentations.readthedocs.io` to get more information. An example
|
||||
of ``transforms`` is as followed:
|
||||
|
||||
.. code-block::
|
||||
[
|
||||
dict(
|
||||
type='ShiftScaleRotate',
|
||||
shift_limit=0.0625,
|
||||
scale_limit=0.0,
|
||||
rotate_limit=0,
|
||||
interpolation=1,
|
||||
p=0.5),
|
||||
dict(
|
||||
type='RandomBrightnessContrast',
|
||||
brightness_limit=[0.1, 0.3],
|
||||
contrast_limit=[0.1, 0.3],
|
||||
p=0.2),
|
||||
dict(type='ChannelShuffle', p=0.1),
|
||||
dict(
|
||||
type='OneOf',
|
||||
transforms=[
|
||||
dict(type='Blur', blur_limit=3, p=1.0),
|
||||
dict(type='MedianBlur', blur_limit=3, p=1.0)
|
||||
],
|
||||
p=0.1),
|
||||
]
|
||||
Args:
|
||||
transforms (list[dict]): A list of albu transformations
|
||||
keymap (dict): Contains {'input key':'albumentation-style key'}
|
||||
update_pad_shape (bool): Whether to update padding shape according to \
|
||||
the output shape of the last transform
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
transforms: List[dict],
|
||||
keymap: Optional[dict] = None,
|
||||
update_pad_shape: bool = False):
|
||||
if not ALBU_INSTALLED:
|
||||
raise ImportError(
|
||||
'albumentations is not installed, '
|
||||
'we suggest install albumentation by '
|
||||
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||
)
|
||||
|
||||
# Args will be modified later, copying it will be safer
|
||||
transforms = copy.deepcopy(transforms)
|
||||
|
||||
self.transforms = transforms
|
||||
self.keymap = keymap
|
||||
self.update_pad_shape = update_pad_shape
|
||||
|
||||
self.aug = Compose([self.albu_builder(t) for t in self.transforms])
|
||||
|
||||
if not keymap:
|
||||
self.keymap_to_albu = {
|
||||
'img': 'image',
|
||||
'gt_masks': 'masks',
|
||||
}
|
||||
else:
|
||||
self.keymap_to_albu = copy.deepcopy(keymap)
|
||||
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
|
||||
|
||||
def albu_builder(self, cfg: dict) -> object:
|
||||
"""Build a callable object from a dict containing albu arguments.
|
||||
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
|
||||
Returns:
|
||||
Callable: A callable object.
|
||||
"""
|
||||
|
||||
assert isinstance(cfg, dict) and 'type' in cfg
|
||||
args = cfg.copy()
|
||||
|
||||
obj_type = args.pop('type')
|
||||
if mmengine.is_str(obj_type):
|
||||
if not ALBU_INSTALLED:
|
||||
raise ImportError(
|
||||
'albumentations is not installed, '
|
||||
'we suggest install albumentation by '
|
||||
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||
)
|
||||
obj_cls = getattr(albumentations, obj_type)
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a valid type or str, but got {type(obj_type)}')
|
||||
|
||||
if 'transforms' in args:
|
||||
args['transforms'] = [
|
||||
self.albu_builder(t) for t in args['transforms']
|
||||
]
|
||||
|
||||
return obj_cls(**args)
|
||||
|
||||
@staticmethod
|
||||
def mapper(d: dict, keymap: dict):
|
||||
"""Dictionary mapper.
|
||||
|
||||
Renames keys according to keymap provided.
|
||||
Args:
|
||||
d (dict): old dict
|
||||
keymap (dict): {'old_key':'new_key'}
|
||||
Returns:
|
||||
dict: new dict.
|
||||
"""
|
||||
|
||||
updated_dict = {}
|
||||
for k, _ in zip(d.keys(), d.values()):
|
||||
new_k = keymap.get(k, k)
|
||||
updated_dict[new_k] = d[k]
|
||||
return updated_dict
|
||||
|
||||
def transform(self, results):
|
||||
# dict to albumentations format
|
||||
results = self.mapper(results, self.keymap_to_albu)
|
||||
|
||||
# Convert to RGB since Albumentations works with RGB images
|
||||
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)
|
||||
|
||||
results = self.aug(**results)
|
||||
|
||||
# Convert back to BGR
|
||||
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)
|
||||
|
||||
# back to the original format
|
||||
results = self.mapper(results, self.keymap_back)
|
||||
|
||||
# update final shape
|
||||
if self.update_pad_shape:
|
||||
results['pad_shape'] = results['img'].shape
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
|
||||
return repr_str
|
||||
|
@ -1160,3 +1160,61 @@ def test_biomedical_3d_flip():
|
||||
results = transform(results)
|
||||
assert np.equal(original_img, results['img']).all()
|
||||
assert np.equal(original_seg, results['gt_seg_map']).all()
|
||||
|
||||
|
||||
def test_albu_transform():
|
||||
results = dict(
|
||||
img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||
|
||||
# Define simple pipeline
|
||||
load = dict(type='LoadImageFromFile')
|
||||
load = TRANSFORMS.build(load)
|
||||
|
||||
albu_transform = dict(
|
||||
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
|
||||
albu_transform = TRANSFORMS.build(albu_transform)
|
||||
|
||||
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
||||
normalize = TRANSFORMS.build(normalize)
|
||||
|
||||
# Execute transforms
|
||||
results = load(results)
|
||||
results = albu_transform(results)
|
||||
results = normalize(results)
|
||||
|
||||
assert results['img'].dtype == np.float32
|
||||
|
||||
|
||||
def test_albu_channel_order():
|
||||
results = dict(
|
||||
img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))
|
||||
|
||||
# Define simple pipeline
|
||||
load = dict(type='LoadImageFromFile')
|
||||
load = TRANSFORMS.build(load)
|
||||
|
||||
# Transform is modifying B channel
|
||||
albu_transform = dict(
|
||||
type='Albu',
|
||||
transforms=[
|
||||
dict(
|
||||
type='RGBShift',
|
||||
r_shift_limit=0,
|
||||
g_shift_limit=0,
|
||||
b_shift_limit=200,
|
||||
p=1)
|
||||
])
|
||||
albu_transform = TRANSFORMS.build(albu_transform)
|
||||
|
||||
# Execute transforms
|
||||
results_load = load(results)
|
||||
results_albu = albu_transform(results_load)
|
||||
|
||||
# assert only Green and Red channel are not modified
|
||||
np.testing.assert_array_equal(results_albu['img'][..., 1:],
|
||||
results_load['img'][..., 1:])
|
||||
|
||||
# assert Blue channel is modified
|
||||
with pytest.raises(AssertionError):
|
||||
np.testing.assert_array_equal(results_albu['img'][..., 0],
|
||||
results_load['img'][..., 0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user