mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Add albu transform (#2710)
This commit is contained in:
parent
d6079bc3f3
commit
49f2a71953
@ -77,6 +77,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch<< parameters.torch >>/index.html
|
python -m pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cpu/torch<< parameters.torch >>/index.html
|
||||||
python -m pip install -r requirements.txt
|
python -m pip install -r requirements.txt
|
||||||
|
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
- run:
|
- run:
|
||||||
name: Build and install
|
name: Build and install
|
||||||
command: |
|
command: |
|
||||||
@ -119,6 +120,8 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
docker exec mmseg pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/${MMCV_CUDA}/torch${MMCV_TORCH}/index.html
|
docker exec mmseg pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/${MMCV_CUDA}/torch${MMCV_TORCH}/index.html
|
||||||
docker exec mmseg pip install -r requirements.txt
|
docker exec mmseg pip install -r requirements.txt
|
||||||
|
docker exec mmseg pip install typing-extensions -U
|
||||||
|
docker exec mmseg pip install albumentations --use-pep517 qudida albumentations
|
||||||
- run:
|
- run:
|
||||||
name: Build and install
|
name: Build and install
|
||||||
command: |
|
command: |
|
||||||
|
8
.github/workflows/build.yml
vendored
8
.github/workflows/build.yml
vendored
@ -64,6 +64,7 @@ jobs:
|
|||||||
- name: Install unittest dependencies
|
- name: Install unittest dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
- name: Build and install
|
- name: Build and install
|
||||||
run: rm -rf .eggs && pip install -e .
|
run: rm -rf .eggs && pip install -e .
|
||||||
- name: Run unittests and generate coverage report
|
- name: Run unittests and generate coverage report
|
||||||
@ -133,6 +134,7 @@ jobs:
|
|||||||
python -m pip install -U openmim
|
python -m pip install -U openmim
|
||||||
mim install mmcv-full
|
mim install mmcv-full
|
||||||
python -m pip install -r requirements.txt
|
python -m pip install -r requirements.txt
|
||||||
|
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
python -c 'import mmcv; print(mmcv.__version__)'
|
python -c 'import mmcv; print(mmcv.__version__)'
|
||||||
- name: Build and install
|
- name: Build and install
|
||||||
run: |
|
run: |
|
||||||
@ -200,6 +202,7 @@ jobs:
|
|||||||
python -m pip install openmim
|
python -m pip install openmim
|
||||||
mim install mmcv-full
|
mim install mmcv-full
|
||||||
python -m pip install -r requirements.txt
|
python -m pip install -r requirements.txt
|
||||||
|
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
python -c 'import mmcv; print(mmcv.__version__)'
|
python -c 'import mmcv; print(mmcv.__version__)'
|
||||||
- name: Build and install
|
- name: Build and install
|
||||||
run: |
|
run: |
|
||||||
@ -263,6 +266,7 @@ jobs:
|
|||||||
python -m pip install openmim
|
python -m pip install openmim
|
||||||
mim install mmcv-full
|
mim install mmcv-full
|
||||||
python -m pip install -r requirements.txt
|
python -m pip install -r requirements.txt
|
||||||
|
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
python -c 'import mmcv; print(mmcv.__version__)'
|
python -c 'import mmcv; print(mmcv.__version__)'
|
||||||
- name: Build and install
|
- name: Build and install
|
||||||
run: |
|
run: |
|
||||||
@ -301,7 +305,9 @@ jobs:
|
|||||||
pip install -U openmim
|
pip install -U openmim
|
||||||
mim install mmcv-full
|
mim install mmcv-full
|
||||||
- name: Install unittest dependencies
|
- name: Install unittest dependencies
|
||||||
run: pip install -r requirements/tests.txt -r requirements/optional.txt
|
run: |
|
||||||
|
pip install -r requirements/tests.txt -r requirements/optional.txt
|
||||||
|
pip install albumentations>=0.3.2 --no-binary qudida,albumentations
|
||||||
- name: Build and install
|
- name: Build and install
|
||||||
run: pip install -e .
|
run: pip install -e .
|
||||||
- name: Run unittests
|
- name: Run unittests
|
||||||
|
@ -63,6 +63,9 @@ Case b: If you use mmsegmentation as a dependency or third-party package, instal
|
|||||||
pip install mmsegmentation
|
pip install mmsegmentation
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Note:**
|
||||||
|
If you would like to use albumentations, we suggest using pip install -U albumentations --no-binary qudida,albumentations. If you simply use pip install albumentations>=0.3.2, it will install opencv-python-headless simultaneously (even though you have already installed opencv-python). We recommended checking the environment after installing albumentations to ensure that opencv-python and opencv-python-headless are not installed at the same time, because it might cause unexpected issues if they both installed. Please refer to [official documentation](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies) for more details.
|
||||||
|
|
||||||
## Verify the installation
|
## Verify the installation
|
||||||
|
|
||||||
To verify whether MMSegmentation is installed correctly, we provide some sample codes to run an inference demo.
|
To verify whether MMSegmentation is installed correctly, we provide some sample codes to run an inference demo.
|
||||||
|
@ -67,6 +67,9 @@ pip install -v -e .
|
|||||||
pip install mmsegmentation
|
pip install mmsegmentation
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**注意:**
|
||||||
|
如果你想使用 albumentations,我们建议使用 pip install-U albumentations --no-binary qudida,albumentations 进行安装。如果您仅使用 pip install albumentations>=0.3.2 进行安装,它将同时安装 opencv-python-headless(即使您已经安装了 opencv-python)。我们建议在安装了 albumentations 后检查环境,以确保没有同时安装 opencv-python 和 opencv-python-headless,因为如果两者都安装了,可能会导致意外问题。请参阅[官方文档](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies)了解更多详细信息。
|
||||||
|
|
||||||
## 验证安装
|
## 验证安装
|
||||||
|
|
||||||
为了验证 MMSegmentation 是否安装正确,我们提供了一些示例代码来执行模型推理。
|
为了验证 MMSegmentation 是否安装正确,我们提供了一些示例代码来执行模型推理。
|
||||||
|
@ -4,7 +4,7 @@ from .formatting import (Collect, ImageToTensor, ToDataContainer, ToTensor,
|
|||||||
Transpose, to_tensor)
|
Transpose, to_tensor)
|
||||||
from .loading import LoadAnnotations, LoadImageFromFile
|
from .loading import LoadAnnotations, LoadImageFromFile
|
||||||
from .test_time_aug import MultiScaleFlipAug
|
from .test_time_aug import MultiScaleFlipAug
|
||||||
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
|
from .transforms import (CLAHE, AdjustGamma, Albu, Normalize, Pad,
|
||||||
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
PhotoMetricDistortion, RandomCrop, RandomCutOut,
|
||||||
RandomFlip, RandomMosaic, RandomRotate, Rerange,
|
RandomFlip, RandomMosaic, RandomRotate, Rerange,
|
||||||
Resize, RGB2Gray, SegRescale)
|
Resize, RGB2Gray, SegRescale)
|
||||||
@ -15,5 +15,5 @@ __all__ = [
|
|||||||
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
|
||||||
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
|
||||||
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
|
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
|
||||||
'RandomMosaic'
|
'RandomMosaic', 'Albu'
|
||||||
]
|
]
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import cv2
|
||||||
import mmcv
|
import mmcv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mmcv.utils import deprecated_api_warning, is_tuple_of
|
from mmcv.utils import deprecated_api_warning, is_tuple_of
|
||||||
@ -8,6 +10,13 @@ from numpy import random
|
|||||||
|
|
||||||
from ..builder import PIPELINES
|
from ..builder import PIPELINES
|
||||||
|
|
||||||
|
try:
|
||||||
|
import albumentations
|
||||||
|
from albumentations import Compose
|
||||||
|
except ImportError:
|
||||||
|
albumentations = None
|
||||||
|
Compose = None
|
||||||
|
|
||||||
|
|
||||||
@PIPELINES.register_module()
|
@PIPELINES.register_module()
|
||||||
class ResizeToMultiple(object):
|
class ResizeToMultiple(object):
|
||||||
@ -1333,3 +1342,146 @@ class RandomMosaic(object):
|
|||||||
repr_str += f'pad_val={self.pad_val}, '
|
repr_str += f'pad_val={self.pad_val}, '
|
||||||
repr_str += f'seg_pad_val={self.pad_val})'
|
repr_str += f'seg_pad_val={self.pad_val})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@PIPELINES.register_module()
|
||||||
|
class Albu:
|
||||||
|
"""Albumentation augmentation. Adds custom transformations from
|
||||||
|
Albumentations library. Please, visit
|
||||||
|
`https://albumentations.readthedocs.io` to get more information. An example
|
||||||
|
of ``transforms`` is as followed:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
[
|
||||||
|
dict(
|
||||||
|
type='ShiftScaleRotate',
|
||||||
|
shift_limit=0.0625,
|
||||||
|
scale_limit=0.0,
|
||||||
|
rotate_limit=0,
|
||||||
|
interpolation=1,
|
||||||
|
p=0.5),
|
||||||
|
dict(
|
||||||
|
type='RandomBrightnessContrast',
|
||||||
|
brightness_limit=[0.1, 0.3],
|
||||||
|
contrast_limit=[0.1, 0.3],
|
||||||
|
p=0.2),
|
||||||
|
dict(type='ChannelShuffle', p=0.1),
|
||||||
|
dict(
|
||||||
|
type='OneOf',
|
||||||
|
transforms=[
|
||||||
|
dict(type='Blur', blur_limit=3, p=1.0),
|
||||||
|
dict(type='MedianBlur', blur_limit=3, p=1.0)
|
||||||
|
],
|
||||||
|
p=0.1),
|
||||||
|
]
|
||||||
|
Args:
|
||||||
|
transforms (list[dict]): A list of albu transformations
|
||||||
|
keymap (dict): Contains {'input key':'albumentation-style key'}
|
||||||
|
update_pad_shape (bool): Whether to update padding shape according to \
|
||||||
|
the output shape of the last transform
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, transforms, keymap=None, update_pad_shape=False):
|
||||||
|
if Compose is None:
|
||||||
|
raise ImportError(
|
||||||
|
'albumentations is not installed, '
|
||||||
|
'we suggest install albumentation by '
|
||||||
|
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||||
|
)
|
||||||
|
|
||||||
|
# Args will be modified later, copying it will be safer
|
||||||
|
transforms = copy.deepcopy(transforms)
|
||||||
|
|
||||||
|
self.transforms = transforms
|
||||||
|
self.filter_lost_elements = False
|
||||||
|
self.update_pad_shape = update_pad_shape
|
||||||
|
|
||||||
|
self.aug = Compose([self.albu_builder(t) for t in self.transforms])
|
||||||
|
|
||||||
|
if not keymap:
|
||||||
|
self.keymap_to_albu = {
|
||||||
|
'img': 'image',
|
||||||
|
'gt_masks': 'masks',
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
self.keymap_to_albu = copy.deepcopy(keymap)
|
||||||
|
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
|
||||||
|
|
||||||
|
def albu_builder(self, cfg):
|
||||||
|
"""Import a module from albumentations.
|
||||||
|
|
||||||
|
It inherits some of :func:`build_from_cfg` logic.
|
||||||
|
Args:
|
||||||
|
cfg (dict): Config dict. It should at least contain the key "type".
|
||||||
|
Returns:
|
||||||
|
obj: The constructed object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert isinstance(cfg, dict) and 'type' in cfg
|
||||||
|
args = cfg.copy()
|
||||||
|
|
||||||
|
obj_type = args.pop('type')
|
||||||
|
if mmcv.is_str(obj_type):
|
||||||
|
if albumentations is None:
|
||||||
|
raise ImportError(
|
||||||
|
'albumentations is not installed, '
|
||||||
|
'we suggest install albumentation by '
|
||||||
|
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
|
||||||
|
)
|
||||||
|
obj_cls = getattr(albumentations, obj_type)
|
||||||
|
elif inspect.isclass(obj_type):
|
||||||
|
obj_cls = obj_type
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f'type must be a str or valid type, but got {type(obj_type)}')
|
||||||
|
|
||||||
|
if 'transforms' in args:
|
||||||
|
args['transforms'] = [
|
||||||
|
self.albu_builder(transform)
|
||||||
|
for transform in args['transforms']
|
||||||
|
]
|
||||||
|
|
||||||
|
return obj_cls(**args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mapper(d, keymap):
|
||||||
|
"""Dictionary mapper.
|
||||||
|
|
||||||
|
Renames keys according to keymap provided.
|
||||||
|
Args:
|
||||||
|
d (dict): old dict
|
||||||
|
keymap (dict): {'old_key':'new_key'}
|
||||||
|
Returns:
|
||||||
|
dict: new dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
updated_dict = {}
|
||||||
|
for k, _ in zip(d.keys(), d.values()):
|
||||||
|
new_k = keymap.get(k, k)
|
||||||
|
updated_dict[new_k] = d[k]
|
||||||
|
return updated_dict
|
||||||
|
|
||||||
|
def __call__(self, results):
|
||||||
|
# dict to albumentations format
|
||||||
|
results = self.mapper(results, self.keymap_to_albu)
|
||||||
|
|
||||||
|
# Convert to RGB since Albumentations works with RGB images
|
||||||
|
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
results = self.aug(**results)
|
||||||
|
|
||||||
|
# Convert back to BGR
|
||||||
|
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# back to the original format
|
||||||
|
results = self.mapper(results, self.keymap_back)
|
||||||
|
|
||||||
|
# update final shape
|
||||||
|
if self.update_pad_shape:
|
||||||
|
results['pad_shape'] = results['img'].shape
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
|
||||||
|
return repr_str
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
mmcv
|
mmcv
|
||||||
prettytable
|
prettytable
|
||||||
|
scipy
|
||||||
torch
|
torch
|
||||||
torchvision
|
torchvision
|
||||||
|
@ -3,3 +3,4 @@ mmcls>=0.20.1
|
|||||||
numpy
|
numpy
|
||||||
packaging
|
packaging
|
||||||
prettytable
|
prettytable
|
||||||
|
scipy
|
||||||
|
@ -688,3 +688,63 @@ def test_mosaic():
|
|||||||
mosaic_module = build_from_cfg(transform, PIPELINES)
|
mosaic_module = build_from_cfg(transform, PIPELINES)
|
||||||
results = mosaic_module(results)
|
results = mosaic_module(results)
|
||||||
assert results['img'].shape[:2] == (20, 24)
|
assert results['img'].shape[:2] == (20, 24)
|
||||||
|
|
||||||
|
|
||||||
|
def test_albu_transform():
|
||||||
|
results = dict(
|
||||||
|
img_prefix=osp.join(osp.dirname(__file__), '../data'),
|
||||||
|
img_info=dict(filename='color.jpg'))
|
||||||
|
|
||||||
|
# Define simple pipeline
|
||||||
|
load = dict(type='LoadImageFromFile')
|
||||||
|
load = build_from_cfg(load, PIPELINES)
|
||||||
|
|
||||||
|
albu_transform = dict(
|
||||||
|
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
|
||||||
|
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
||||||
|
|
||||||
|
normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
|
||||||
|
normalize = build_from_cfg(normalize, PIPELINES)
|
||||||
|
|
||||||
|
# Execute transforms
|
||||||
|
results = load(results)
|
||||||
|
results = albu_transform(results)
|
||||||
|
results = normalize(results)
|
||||||
|
|
||||||
|
assert results['img'].dtype == np.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_albu_channel_order():
|
||||||
|
results = dict(
|
||||||
|
img_prefix=osp.join(osp.dirname(__file__), '../data'),
|
||||||
|
img_info=dict(filename='color.jpg'))
|
||||||
|
|
||||||
|
# Define simple pipeline
|
||||||
|
load = dict(type='LoadImageFromFile')
|
||||||
|
load = build_from_cfg(load, PIPELINES)
|
||||||
|
|
||||||
|
# Transform is modifying B channel
|
||||||
|
albu_transform = dict(
|
||||||
|
type='Albu',
|
||||||
|
transforms=[
|
||||||
|
dict(
|
||||||
|
type='RGBShift',
|
||||||
|
r_shift_limit=0,
|
||||||
|
g_shift_limit=0,
|
||||||
|
b_shift_limit=200,
|
||||||
|
p=1)
|
||||||
|
])
|
||||||
|
albu_transform = build_from_cfg(albu_transform, PIPELINES)
|
||||||
|
|
||||||
|
# Execute transforms
|
||||||
|
results_load = load(results)
|
||||||
|
results_albu = albu_transform(results_load)
|
||||||
|
|
||||||
|
# assert only Green and Red channel are not modified
|
||||||
|
np.testing.assert_array_equal(results_albu['img'][..., 1:],
|
||||||
|
results_load['img'][..., 1:])
|
||||||
|
|
||||||
|
# assert Blue channel is modified
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
np.testing.assert_array_equal(results_albu['img'][..., 0],
|
||||||
|
results_load['img'][..., 0])
|
||||||
|
@ -140,8 +140,7 @@ def test_beit_init():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
model = BEiT(img_size=(512, 512))
|
model = BEiT(img_size=(512, 512))
|
||||||
with pytest.raises(AttributeError):
|
ckpt = model.resize_rel_pos_embed(ckpt)
|
||||||
model.resize_rel_pos_embed(ckpt)
|
|
||||||
|
|
||||||
# pretrained=None
|
# pretrained=None
|
||||||
# init_cfg=123, whose type is unsupported
|
# init_cfg=123, whose type is unsupported
|
||||||
|
@ -138,10 +138,19 @@ def test_mae_init():
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
model = MAE(img_size=(512, 512))
|
model = MAE(img_size=(512, 512))
|
||||||
with pytest.raises(AttributeError):
|
ckpt = model.resize_rel_pos_embed(ckpt)
|
||||||
model.resize_rel_pos_embed(ckpt)
|
|
||||||
|
|
||||||
# test resize abs pos embed
|
# test resize abs pos embed
|
||||||
|
value = torch.randn(732, 16)
|
||||||
|
abs_pos_embed_value = torch.rand(1, 17, 768)
|
||||||
|
ckpt = {
|
||||||
|
'state_dict': {
|
||||||
|
'layers.0.attn.relative_position_index': 0,
|
||||||
|
'layers.0.attn.relative_position_bias_table': value,
|
||||||
|
'pos_embed': abs_pos_embed_value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
model = MAE(img_size=(512, 512))
|
||||||
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
|
ckpt = model.resize_abs_pos_embed(ckpt['state_dict'])
|
||||||
|
|
||||||
# pretrained=None
|
# pretrained=None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user