mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add albu transform (#2710)
This commit is contained in:
parent
d6079bc3f3
commit
49f2a71953
@ -77,6 +77,7 @@ jobs:
|
||||
command: |
|
||||
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch<< parameters.torch >>/index.html
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
- run:
|
||||
name: Build and install
|
||||
command: |
|
||||
@ -119,6 +120,8 @@ jobs:
|
||||
command: |
|
||||
docker exec mmseg pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/${MMCV_CUDA}/torch${MMCV_TORCH}/index.html
|
||||
docker exec mmseg pip install -r requirements.txt
|
||||
docker exec mmseg pip install typing-extensions -U
|
||||
docker exec mmseg pip install albumentations --use-pep517 qudida albumentations
|
||||
- run:
|
||||
name: Build and install
|
||||
command: |
|
||||
|
8
.github/workflows/build.yml
vendored
8
.github/workflows/build.yml
vendored
@ -64,6 +64,7 @@ jobs:
|
||||
- name: Install unittest dependencies
|
||||
run: |
|
||||
pip install -r requirements.txt
|
||||
pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
- name: Build and install
|
||||
run: rm -rf .eggs && pip install -e .
|
||||
- name: Run unittests and generate coverage report
|
||||
@ -133,6 +134,7 @@ jobs:
|
||||
python -m pip install -U openmim
|
||||
mim install mmcv-full
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Build and install
|
||||
run: |
|
||||
@ -200,6 +202,7 @@ jobs:
|
||||
python -m pip install openmim
|
||||
mim install mmcv-full
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Build and install
|
||||
run: |
|
||||
@ -263,6 +266,7 @@ jobs:
|
||||
python -m pip install openmim
|
||||
mim install mmcv-full
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
python -c 'import mmcv; print(mmcv.__version__)'
|
||||
- name: Build and install
|
||||
run: |
|
||||
@ -301,7 +305,9 @@ jobs:
|
||||
pip install -U openmim
|
||||
mim install mmcv-full
|
||||
- name: Install unittest dependencies
|
||||
run: pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||
run: |
|
||||
pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||
pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||
- name: Build and install
|
||||
run: pip install -e .
|
||||
- name: Run unittests
|
||||
|
@ -63,6 +63,9 @@ Case b: If you use mmsegmentation as a dependency or third-party package, instal
|
||||
pip install mmsegmentation
|
||||
```
|
||||
|
||||
**Note:**
|
||||
If you would like to use albumentations, we suggest using pip install -U albumentations --no-binary qudida,albumentations. If you simply use pip install albumentations>=0.3.2, it will install opencv-python-headless simultaneously (even though you have already installed opencv-python). We recommended checking the environment after installing albumentations to ensure that opencv-python and opencv-python-headless are not installed at the same time, because it might cause unexpected issues if they both installed. Please refer to [official documentation](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies) for more details.
|
||||
|
||||
## Verify the installation
|
||||
|
||||
To verify whether MMSegmentation is installed correctly, we provide some sample codes to run an inference demo.
|
||||
|
@ -67,6 +67,9 @@ pip install -v -e .
|
||||
pip install mmsegmentation
|
||||
```
|
||||
|
||||
**注意:**
|
||||
如果你想使用 albumentations,我们建议使用 pip install-U albumentations --no-binary qudida,albumentations 进行安装。如果您仅使用 pip install albumentations>=0.3.2 进行安装,它将同时安装 opencv-python-headless(即使您已经安装了 opencv-python)。我们建议在安装了 albumentations 后检查环境,以确保没有同时安装 opencv-python 和 opencv-python-headless,因为如果两者都安装了,可能会导致意外问题。请参阅[官方文档](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies)了解更多详细信息。
|
||||
|
||||
## 验证安装
|
||||
|
||||
为了验证 MMSegmentation 是否安装正确,我们提供了一些示例代码来执行模型推理。
|
||||
|
@ -4,7 +4,7 @@ from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor,
|
||||
Transpose, to_tensor)
|
||||
from .loading import LoadAnnotations, LoadImageFromFile
|
||||
from .test_time_aug import MultiScaleFlipAug
|
||||
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
||||
from .transforms import (CLAHE, AdjustGamma, Albu, Normalize, Pad,
|
||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||
RandomFlip, RandomMosaic, RandomRotate, Rerange,
|
||||
Resize, RGB2Gray, SegRescale)
|
||||
@ -15,5 +15,5 @@ __all__ = [
|
||||
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
||||
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
|
||||
'RandomMosaic'
|
||||
'RandomMosaic', 'Albu'
|
||||
]
|
||||
|
@ -1,6 +1,8 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmcv.utils import deprecated_api_warning, is_tuple_of
|
||||
@ -8,6 +10,13 @@ from numpy import random
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
try:
|
||||
import albumentations
|
||||
from albumentations import Compose
|
||||
except ImportError:
|
||||
albumentations = None
|
||||
Compose = None
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ResizeToMultiple(object):
|
||||
@ -1333,3 +1342,146 @@ class RandomMosaic(object):
|
||||
repr_str += f'pad_val={self.pad_val}, '
|
||||
repr_str += f'seg_pad_val={self.pad_val})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Albu:
|
||||
"""Albumentation augmentation. Adds custom transformations from
|
||||
Albumentations library. Please, visit
|
||||
`https://albumentations.readthedocs.io` to get more information. An example
|
||||
of ``transforms`` is as followed:
|
||||
|
||||
.. code-block::
|
||||
[
|
||||
dict(
|
||||
type='ShiftScaleRotate',
|
||||
shift_limit=0.0625,
|
||||
scale_limit=0.0,
|
||||
rotate_limit=0,
|
||||
interpolation=1,
|
||||
p=0.5),
|
||||
dict(
|
||||
type='RandomBrightnessContrast',
|
||||
brightness_limit=[0.1, 0.3],
|
||||
contrast_limit=[0.1, 0.3],
|
||||
p=0.2),
|
||||
dict(type='ChannelShuffle', p=0.1),
|
||||
dict(
|
||||
type='OneOf',
|
||||
transforms=[
|
||||
dict(type='Blur', blur_limit=3, p=1.0),
|
||||
dict(type='MedianBlur', blur_limit=3, p=1.0)
|
||||
],
|
||||
p=0.1),
|
||||
]
|
||||
Args:
|
||||
transforms (list[dict]): A list of albu transformations
|
||||
keymap (dict): Contains {'input key':'albumentation-style key'}
|
||||
update_pad_shape (bool): Whether to update padding shape according to \
|
||||
the output shape of the last transform
|
||||
"""
|
||||
|
||||
def __init__(self, transforms, keymap=None, update_pad_shape=False):
|
||||
if Compose is None:
|
||||
raise ImportError(
|
||||
'albumentations is not installed, '
|
||||
'we suggest install albumentation by '
|
||||
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||
)
|
||||
|
||||
# Args will be modified later, copying it will be safer
|
||||
transforms = copy.deepcopy(transforms)
|
||||
|
||||
self.transforms = transforms
|
||||
self.filter_lost_elements = False
|
||||
self.update_pad_shape = update_pad_shape
|
||||
|
||||
self.aug = Compose([self.albu_builder(t) for t in self.transforms])
|
||||
|
||||
if not keymap:
|
||||
self.keymap_to_albu = {
|
||||
'img': 'image',
|
||||
'gt_masks': 'masks',
|
||||
}
|
||||
else:
|
||||
self.keymap_to_albu = copy.deepcopy(keymap)
|
||||
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
|
||||
|
||||
def albu_builder(self, cfg):
|
||||
"""Import a module from albumentations.
|
||||
|
||||
It inherits some of :func:`build_from_cfg` logic.
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
Returns:
|
||||
obj: The constructed object.
|
||||
"""
|
||||
|
||||
assert isinstance(cfg, dict) and 'type' in cfg
|
||||
args = cfg.copy()
|
||||
|
||||
obj_type = args.pop('type')
|
||||
if mmcv.is_str(obj_type):
|
||||
if albumentations is None:
|
||||
raise ImportError(
|
||||
'albumentations is not installed, '
|
||||
'we suggest install albumentation by '
|
||||
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||
)
|
||||
obj_cls = getattr(albumentations, obj_type)
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(
|
||||
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
|
||||
if 'transforms' in args:
|
||||
args['transforms'] = [
|
||||
self.albu_builder(transform)
|
||||
for transform in args['transforms']
|
||||
]
|
||||
|
||||
return obj_cls(**args)
|
||||
|
||||
@staticmethod
|
||||
def mapper(d, keymap):
|
||||
"""Dictionary mapper.
|
||||
|
||||
Renames keys according to keymap provided.
|
||||
Args:
|
||||
d (dict): old dict
|
||||
keymap (dict): {'old_key':'new_key'}
|
||||
Returns:
|
||||
dict: new dict.
|
||||
"""
|
||||
|
||||
updated_dict = {}
|
||||
for k, _ in zip(d.keys(), d.values()):
|
||||
new_k = keymap.get(k, k)
|
||||
updated_dict[new_k] = d[k]
|
||||
return updated_dict
|
||||
|
||||
def __call__(self, results):
|
||||
# dict to albumentations format
|
||||
results = self.mapper(results, self.keymap_to_albu)
|
||||
|
||||
# Convert to RGB since Albumentations works with RGB images
|
||||
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)
|
||||
|
||||
results = self.aug(**results)
|
||||
|
||||
# Convert back to BGR
|
||||
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)
|
||||
|
||||
# back to the original format
|
||||
results = self.mapper(results, self.keymap_back)
|
||||
|
||||
# update final shape
|
||||
if self.update_pad_shape:
|
||||
results['pad_shape'] = results['img'].shape
|
||||
|
||||
return results
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
|
||||
return repr_str
|
||||
|
@ -1,4 +1,5 @@
|
||||
mmcv
|
||||
prettytable
|
||||
scipy
|
||||
torch
|
||||
torchvision
|
||||
|
@ -3,3 +3,4 @@ mmcls>=0.20.1
|
||||
numpy
|
||||
packaging
|
||||
prettytable
|
||||
scipy
|
||||
|
@ -688,3 +688,63 @@ def test_mosaic():
|
||||
mosaic_module = build_from_cfg(transform, PIPELINES)
|
||||
results = mosaic_module(results)
|
||||
assert results['img'].shape[:2] == (20, 24)
|
||||
|
||||
|
||||
def test_albu_transform():
|
||||
results = dict(
|
||||
img_prefix=osp.join(osp.dirname(__file__), '../data'),
|
||||
img_info=dict(filename='color.jpg'))
|
||||
|
||||
# Define simple pipeline
|
||||
load = dict(type='LoadImageFromFile')
|
||||
load = build_from_cfg(load, PIPELINES)
|
||||
|
||||
albu_transform = dict(
|
||||
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
|
||||
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
||||
|
||||
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
||||
normalize = build_from_cfg(normalize, PIPELINES)
|
||||
|
||||
# Execute transforms
|
||||
results = load(results)
|
||||
results = albu_transform(results)
|
||||
results = normalize(results)
|
||||
|
||||
assert results['img'].dtype == np.float32
|
||||
|
||||
|
||||
def test_albu_channel_order():
|
||||
results = dict(
|
||||
img_prefix=osp.join(osp.dirname(__file__), '../data'),
|
||||
img_info=dict(filename='color.jpg'))
|
||||
|
||||
# Define simple pipeline
|
||||
load = dict(type='LoadImageFromFile')
|
||||
load = build_from_cfg(load, PIPELINES)
|
||||
|
||||
# Transform is modifying B channel
|
||||
albu_transform = dict(
|
||||
type='Albu',
|
||||
transforms=[
|
||||
dict(
|
||||
type='RGBShift',
|
||||
r_shift_limit=0,
|
||||
g_shift_limit=0,
|
||||
b_shift_limit=200,
|
||||
p=1)
|
||||
])
|
||||
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
||||
|
||||
# Execute transforms
|
||||
results_load = load(results)
|
||||
results_albu = albu_transform(results_load)
|
||||
|
||||
# assert only Green and Red channel are not modified
|
||||
np.testing.assert_array_equal(results_albu['img'][..., 1:],
|
||||
results_load['img'][..., 1:])
|
||||
|
||||
# assert Blue channel is modified
|
||||
with pytest.raises(AssertionError):
|
||||
np.testing.assert_array_equal(results_albu['img'][..., 0],
|
||||
results_load['img'][..., 0])
|
||||
|
@ -140,8 +140,7 @@ def test_beit_init():
|
||||
}
|
||||
}
|
||||
model = BEiT(img_size=(512, 512))
|
||||
with pytest.raises(AttributeError):
|
||||
model.resize_rel_pos_embed(ckpt)
|
||||
ckpt = model.resize_rel_pos_embed(ckpt)
|
||||
|
||||
# pretrained=None
|
||||
# init_cfg=123, whose type is unsupported
|
||||
|
@ -138,10 +138,19 @@ def test_mae_init():
|
||||
}
|
||||
}
|
||||
model = MAE(img_size=(512, 512))
|
||||
with pytest.raises(AttributeError):
|
||||
model.resize_rel_pos_embed(ckpt)
|
||||
ckpt = model.resize_rel_pos_embed(ckpt)
|
||||
|
||||
# test resize abs pos embed
|
||||
value = torch.randn(732, 16)
|
||||
abs_pos_embed_value = torch.rand(1, 17, 768)
|
||||
ckpt = {
|
||||
'state_dict': {
|
||||
'layers.0.attn.relative_position_index': 0,
|
||||
'layers.0.attn.relative_position_bias_table': value,
|
||||
'pos_embed': abs_pos_embed_value
|
||||
}
|
||||
}
|
||||
model = MAE(img_size=(512, 512))
|
||||
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
|
||||
|
||||
# pretrained=None
|
||||
|
Loading…
x
Reference in New Issue
Block a user