mmsegmentation/tests/test_datasets/test_transform.py
Junhwa Song cb2d8fe085
[Enhance] Support multi-band image for Mosaic (#2748)
## Modification

I replaced the hard-coded channel count of 3 with the input image's actual
channel count in the `np.full` calls.
This modification enables the `RandomMosaic` transform to support
multispectral images (e.g., RGB images with an NIR band) and bi-temporal
image pairs for change detection tasks.

## Checklist

1. Pre-commit or other linting tools are used to fix potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit tests to ensure correctness.
3. If the modification may affect downstream projects, this PR should
be tested with those projects, such as MMDet or
MMDet3D.
4. The documentation has been updated accordingly, e.g., docstrings or
example tutorials.
2023-03-15 19:36:47 +08:00

1163 lines
41 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import mmcv
import numpy as np
import pytest
from mmengine.registry import init_default_scope
from PIL import Image
from mmseg.datasets.transforms import * # noqa
from mmseg.datasets.transforms import (LoadBiomedicalData,
LoadBiomedicalImageFromFile,
PhotoMetricDistortion, RandomCrop)
from mmseg.registry import TRANSFORMS
# Use 'mmseg' as the default registry scope for every TRANSFORMS.build call
# in this module (presumably so type names such as 'Resize' or 'RandomFlip'
# resolve through the mmseg registry -- NOTE(review): confirm with mmengine).
init_default_scope('mmseg')
def test_resize():
    """Test ``Resize``, ``RandomResize`` and ``RandomChoiceResize``.

    These transforms come from MMCV. Note that ``RandomChoiceResize`` takes
    its candidate sizes through ``scales``, while ``Resize`` and
    ``RandomResize`` take ``scale``.
    """
    transform = dict(type='Resize', scale=(1333, 800), keep_ratio=True)
    resize_module = TRANSFORMS.build(transform)

    results = dict()
    # (288, 512, 3)
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    resized_results = resize_module(results.copy())
    # MMCV ``Resize`` stores img_shape = results['img'].shape[:2], so it is
    # (750, 1333) rather than (750, 1333, 3)
    assert resized_results['img_shape'] == (750, 1333)

    # test keep_ratio=False
    transform = dict(
        type='RandomResize',
        scale=(1280, 800),
        ratio_range=(1.0, 1.0),
        resize_type='Resize',
        keep_ratio=False)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (800, 1280)

    # test `RandomResize` with a pair of scales, the replacement for the
    # older mmsegmentation `Resize` with multiscale_mode='range'
    transform = dict(type='RandomResize', scale=[(1333, 400), (1333, 1200)])
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert max(resized_results['img_shape'][:2]) <= 1333
    assert min(resized_results['img_shape'][:2]) >= 400
    assert min(resized_results['img_shape'][:2]) <= 1200

    # test `RandomChoiceResize`, the replacement for the older
    # mmsegmentation `Resize` with multiscale_mode='value'
    transform = dict(
        type='RandomChoiceResize',
        scales=[(1333, 800), (1333, 400)],
        resize_type='Resize',
        keep_ratio=False)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] in [(800, 1333), (400, 1333)]

    # test `Resize` driven only by a tuple scale_factor
    transform = dict(type='Resize', scale_factor=(0.9, 1.1), keep_ratio=True)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1

    # test `RandomChoiceResize` whose `resize_type` is `ResizeShortestEdge`
    transform = dict(
        type='RandomChoiceResize',
        scales=[128, 256, 512],
        resize_type='ResizeShortestEdge',
        max_size=1333)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'][0] in [128, 256, 512]

    transform = dict(
        type='RandomChoiceResize',
        scales=[512],
        resize_type='ResizeShortestEdge',
        max_size=512)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'][1] == 512

    transform = dict(
        type='RandomChoiceResize',
        scales=[(128, 256), (256, 512), (512, 1024)],
        resize_type='ResizeShortestEdge',
        max_size=1333)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'][0] in [128, 256, 512]

    # test scale=None and scale_factor is tuple.
    # img shape: (288, 512, 3)
    transform = dict(
        type='Resize', scale=None, scale_factor=(0.5, 2.0), keep_ratio=True)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert int(288 * 0.5) <= resized_results['img_shape'][0] <= 288 * 2.0
    assert int(512 * 0.5) <= resized_results['img_shape'][1] <= 512 * 2.0

    # test minimum resized image shape is 640
    transform = dict(type='Resize', scale=(2560, 640), keep_ratio=True)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (640, 1138)

    # test minimum resized image shape is 640 when img_scale=(512, 640);
    # with MMCV's ``Resize`` this is expressed through `scale_factor`.
    min_size_ratio = max(640 / img.shape[0], 640 / img.shape[1])
    transform = dict(
        type='Resize', scale_factor=min_size_ratio, keep_ratio=True)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (640, 1138)

    # test h > w
    img = np.random.randn(512, 288, 3)
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0
    min_size_ratio = max(640 / img.shape[0], 640 / img.shape[1])
    transform = dict(
        type='Resize',
        scale=(2560, 640),
        scale_factor=min_size_ratio,
        keep_ratio=True)
    resize_module = TRANSFORMS.build(transform)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (1138, 640)
def test_flip():
    """Test ``RandomFlip``: argument checks and a real horizontal flip."""
    # test assertion for invalid prob
    with pytest.raises(AssertionError):
        transform = dict(type='RandomFlip', prob=1.5)
        TRANSFORMS.build(transform)

    # test assertion for invalid direction
    with pytest.raises(AssertionError):
        transform = dict(type='RandomFlip', prob=1.0, direction='horizonta')
        TRANSFORMS.build(transform)

    transform = dict(type='RandomFlip', prob=1.0)
    flip_module = TRANSFORMS.build(transform)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    original_seg = copy.deepcopy(seg)
    results['img'] = img
    # MMCV's ``RandomFlip`` flips the segmentation map stored under
    # ``gt_seg_map``; the legacy ``gt_semantic_seg`` key is never touched,
    # which would make the seg assertions below pass trivially.
    results['gt_seg_map'] = seg
    results['seg_fields'] = ['gt_seg_map']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    # With prob=1.0 one call must mirror both image and seg map along the
    # width axis (horizontal is the default direction).
    results = flip_module(results)
    assert np.equal(original_img[:, ::-1], results['img']).all()
    assert np.equal(original_seg[:, ::-1], results['gt_seg_map']).all()

    # A second flip restores the original data.
    results = flip_module(results)
    assert np.equal(original_img, results['img']).all()
    assert np.equal(original_seg, results['gt_seg_map']).all()
def test_random_rotate_flip():
    """Test ``RandomRotFlip``: argument checks, flip-only and rotate-only."""
    # test assertion for invalid flip_prob, rotate_prob and degree
    invalid_cfgs = [
        dict(type='RandomRotFlip', flip_prob=1.5),
        dict(type='RandomRotFlip', rotate_prob=1.5),
        dict(type='RandomRotFlip', degree=[20, 20, 20]),
        dict(type='RandomRotFlip', degree=-20),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(cfg)

    data_root = osp.join(
        osp.dirname(__file__), '../data/pseudo_synapse_dataset')
    img = mmcv.imread(
        osp.join(data_root, 'img_dir/case0005_slice000.jpg'), 'color')
    original_img = copy.deepcopy(img)
    seg = np.array(
        Image.open(osp.join(data_root, 'ann_dir/case0005_slice000.png')))
    original_seg = copy.deepcopy(seg)

    results = dict()
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    # flip only; each case gets its own copy so the rotate-only case does
    # not start from the output of the flip-only case
    transform = dict(
        type='RandomRotFlip', flip_prob=1.0, rotate_prob=0, degree=20)
    rot_flip_module = TRANSFORMS.build(transform)
    result_flip = rot_flip_module(copy.deepcopy(results))
    assert original_img.shape == result_flip['img'].shape
    assert original_seg.shape == result_flip['gt_semantic_seg'].shape

    # rotate only
    transform = dict(
        type='RandomRotFlip', flip_prob=0, rotate_prob=1.0, degree=20)
    rot_flip_module = TRANSFORMS.build(transform)
    result_rotate = rot_flip_module(copy.deepcopy(results))
    assert original_img.shape == result_rotate['img'].shape
    assert original_seg.shape == result_rotate['gt_semantic_seg'].shape

    # Check the built module. ``str(transform)`` would only stringify the
    # config dict, which exercises Python's dict repr, not the transform.
    assert 'RandomRotFlip' in repr(rot_flip_module)
def test_pad():
    """``Pad`` with ``size_divisor`` leaves an already-aligned image as is."""
    # Pad requires either size_divisor or size.
    with pytest.raises(AssertionError):
        TRANSFORMS.build(dict(type='Pad'))

    pad = TRANSFORMS.build(dict(type='Pad', size_divisor=32))

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    untouched = copy.deepcopy(image)
    results = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)

    results = pad(results)

    # color.jpg is already a multiple of 32 on both sides, so no padding.
    assert np.equal(results['img'], untouched).all()
    padded_h, padded_w = results['img'].shape[:2]
    assert padded_h % 32 == 0
    assert padded_w % 32 == 0
def test_normalize():
    """``Normalize`` with to_rgb=True reverses channels, then standardizes."""
    mean_vals = [123.675, 116.28, 103.53]
    std_vals = [58.395, 57.12, 57.375]
    normalize = TRANSFORMS.build(
        dict(type='Normalize', mean=mean_vals, std=std_vals, to_rgb=True))

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    source = copy.deepcopy(image)
    results = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)

    results = normalize(results)

    expected = (source[..., ::-1] - np.array(mean_vals)) / np.array(std_vals)
    assert np.allclose(results['img'], expected)
def test_random_crop():
    """``RandomCrop`` crops the image, its shape meta and the seg map."""
    # test assertion for invalid random crop
    with pytest.raises(AssertionError):
        RandomCrop(crop_size=(-1, 0))

    # Resolve fixtures relative to this file, as the other tests do, so the
    # test does not depend on the current working directory.
    data_root = osp.join(osp.dirname(__file__), '../data')
    img = mmcv.imread(osp.join(data_root, 'color.jpg'), 'color')
    seg = np.array(Image.open(osp.join(data_root, 'seg.png')))

    results = dict()
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    h, w, _ = img.shape
    pipeline = RandomCrop(crop_size=(h - 20, w - 20))
    results = pipeline(results)
    assert results['img'].shape[:2] == (h - 20, w - 20)
    assert results['img_shape'] == (h - 20, w - 20)
    assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
def test_rgb2gray():
    """Test ``RGB2Gray`` argument checks, repr and output channel count."""
    # out_channels must be > 0 and weights must be a tuple of floats
    for bad_cfg in (dict(type='RGB2Gray', out_channels=-1),
                    dict(type='RGB2Gray', out_channels=1, weights=1.1)):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(bad_cfg)

    def _load_results():
        """Return the color image and a fresh results dict built from it."""
        image = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        label = np.array(
            Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
        return image, dict(
            img=image,
            gt_semantic_seg=label,
            seg_fields=['gt_semantic_seg'],
            img_shape=image.shape,
            ori_shape=image.shape,
            # Set initial values for default meta_keys
            pad_shape=image.shape,
            scale_factor=1.0)

    # out_channels=None keeps the number of input channels
    gray = TRANSFORMS.build(dict(type='RGB2Gray'))
    assert str(gray) == ('RGB2Gray('
                         f'out_channels={None}, '
                         f'weights={(0.299, 0.587, 0.114)})')
    image, results = _load_results()
    h, w, c = image.shape
    results = gray(results)
    assert results['img'].shape == (h, w, c)
    assert results['img_shape'] == (h, w, c)
    assert results['ori_shape'] == (h, w, c)

    # out_channels=2 produces a two-channel output
    gray = TRANSFORMS.build(dict(type='RGB2Gray', out_channels=2))
    assert str(gray) == ('RGB2Gray('
                         f'out_channels={2}, '
                         f'weights={(0.299, 0.587, 0.114)})')
    image, results = _load_results()
    h, w = image.shape[:2]
    results = gray(results)
    assert results['img'].shape == (h, w, 2)
    assert results['img_shape'] == (h, w, 2)
def test_photo_metric_distortion():
    """``PhotoMetricDistortion`` leaves the seg map and image shape intact."""
    # Resolve fixtures relative to this file, as the other tests do, so the
    # test does not depend on the current working directory.
    data_root = osp.join(osp.dirname(__file__), '../data')
    img = mmcv.imread(osp.join(data_root, 'color.jpg'), 'color')
    seg = np.array(Image.open(osp.join(data_root, 'seg.png')))

    results = dict()
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    pipeline = PhotoMetricDistortion(saturation_range=(1., 1.))
    results = pipeline(results)
    assert (results['gt_semantic_seg'] == seg).all()
    assert results['img_shape'] == img.shape
def test_rerange():
    """Test ``Rerange`` argument checks and min-max rescaling to [0, 255]."""
    # non-scalar bounds, and min_value >= max_value, are rejected
    for bad_cfg in (dict(type='Rerange', min_value=[0], max_value=[255]),
                    dict(type='Rerange', min_value=1, max_value=1)):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(bad_cfg)

    # a constant image (img_min_value == img_max_value) cannot be reranged
    with pytest.raises(AssertionError):
        flat_rerange = TRANSFORMS.build(
            dict(type='Rerange', min_value=0, max_value=1))
        flat_rerange(dict(img=np.array([[1, 1], [1, 1]])))

    rerange = TRANSFORMS.build(dict(type='Rerange'))

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    source = copy.deepcopy(image)
    results = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)
    results = rerange(results)

    lo, hi = np.min(source), np.max(source)
    expected = (source - lo) / (hi - lo) * 255
    assert np.allclose(results['img'], expected)
    assert str(rerange) == f'Rerange(min_value={0}, max_value={255})'
def test_CLAHE():
    """Test ``CLAHE`` argument checks and per-channel equalization."""
    # clip_limit=None, float grid sizes and a 3-element grid are rejected
    for bad_cfg in (dict(type='CLAHE', clip_limit=None),
                    dict(type='CLAHE', tile_grid_size=(8.0, 8.0)),
                    dict(type='CLAHE', tile_grid_size=(9, 9, 9))):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(bad_cfg)

    clahe = TRANSFORMS.build(dict(type='CLAHE', clip_limit=2))

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    source = copy.deepcopy(image)
    results = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)
    results = clahe(results)

    # reference: run mmcv.clahe independently on every channel
    expected = np.dstack([
        mmcv.clahe(np.array(source[:, :, ch], dtype=np.uint8), 2, (8, 8))
        for ch in range(source.shape[2])
    ]).astype(np.float64)
    assert np.allclose(results['img'], expected)
    assert str(clahe) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
def test_adjust_gamma():
    """Test ``AdjustGamma`` argument checks and its gamma lookup table."""
    # gamma <= 0 and a list-valued gamma are rejected
    for bad_cfg in (dict(type='AdjustGamma', gamma=0),
                    dict(type='AdjustGamma', gamma=[1.2])):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(bad_cfg)

    gamma = 1.2
    adjust = TRANSFORMS.build(dict(type='AdjustGamma', gamma=gamma))

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    source = copy.deepcopy(image)
    results = dict(
        img=image,
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)
    results = adjust(results)

    # reference: the same 256-entry gamma LUT applied with mmcv
    inv_gamma = 1.0 / gamma
    lut = np.array([((level / 255.0)**inv_gamma) * 255
                    for level in np.arange(0, 256)]).astype('uint8')
    expected = mmcv.lut_transform(np.array(source, dtype=np.uint8), lut)
    assert np.allclose(results['img'], expected)
    assert str(adjust) == f'AdjustGamma(gamma={1.2})'
def test_rotate():
    """Test ``RandomRotate`` argument checks, repr and output size."""
    # degree should be tuple[float] or float
    for bad_cfg in (dict(type='RandomRotate', prob=0.5, degree=-10),
                    dict(
                        type='RandomRotate',
                        prob=0.5,
                        degree=(10., 20., 30.))):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(bad_cfg)

    rotate = TRANSFORMS.build(dict(type='RandomRotate', degree=10., prob=1.))
    # a scalar degree shows up in the repr as the range (-degree, degree)
    assert str(rotate) == ('RandomRotate('
                           f'prob={1.}, '
                           f'degree=({-10.}, {10.}), '
                           f'pad_val={0}, '
                           f'seg_pad_val={255}, '
                           f'center={None}, '
                           f'auto_bound={False})')

    image = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    h, w = image.shape[:2]
    label = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results = dict(
        img=image,
        gt_semantic_seg=label,
        seg_fields=['gt_semantic_seg'],
        img_shape=image.shape,
        ori_shape=image.shape,
        # Set initial values for default meta_keys
        pad_shape=image.shape,
        scale_factor=1.0)

    results = rotate(results)
    assert results['img'].shape[:2] == (h, w)
def test_seg_rescale():
    """``SegRescale`` rescales the seg map by ``scale_factor``."""
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    h, w = seg.shape
    results = dict(gt_semantic_seg=seg, seg_fields=['gt_semantic_seg'])

    cases = (
        (1. / 2, (h // 2, w // 2)),
        (1, (h, w)),
    )
    for scale_factor, expected_shape in cases:
        rescale_module = TRANSFORMS.build(
            dict(type='SegRescale', scale_factor=scale_factor))
        rescaled = rescale_module(results.copy())
        assert rescaled['gt_semantic_seg'].shape == expected_shape
def test_mosaic():
    """Test ``RandomMosaic`` on 3-channel, single-channel and 6-band images."""
    # test prob
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1.5)
        TRANSFORMS.build(transform)
    # test assertion for invalid img_scale
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1, img_scale=640)
        TRANSFORMS.build(transform)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']

    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = TRANSFORMS.build(transform)
    assert 'Mosaic' in repr(mosaic_module)

    # test assertion for invalid mix_results
    with pytest.raises(AssertionError):
        mosaic_module(results)

    # with img_scale=(10, 12) the mosaic output is (20, 24) spatially
    results['mix_results'] = [copy.deepcopy(results)] * 3
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)

    # single-channel (2D) images; prob=0 returns the input unchanged
    results = dict()
    results['img'] = img[:, :, 0]
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
    mosaic_module = TRANSFORMS.build(transform)
    results['mix_results'] = [copy.deepcopy(results)] * 3
    results = mosaic_module(results)
    assert results['img'].shape[:2] == img.shape[:2]

    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = TRANSFORMS.build(transform)
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)

    # multi-band (6-channel) images: both the mosaic canvas size and the
    # band count must be preserved
    results = dict()
    results['img'] = np.concatenate((img, img), axis=2)
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = TRANSFORMS.build(transform)
    results['mix_results'] = [copy.deepcopy(results)] * 3
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)
    assert results['img'].shape[2] == 6
def test_cutout():
    """Test ``RandomCutOut`` argument checks and its effect on img / seg."""
    invalid_cfgs = [
        # test prob
        dict(type='RandomCutOut', prob=1.5, n_holes=1),
        # test n_holes
        dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=(5, 3),
            cutout_shape=(8, 8)),
        dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=(3, 4, 5),
            cutout_shape=(8, 8)),
        # test cutout_shape and cutout_ratio
        dict(type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8),
        dict(type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2),
        # either of cutout_shape and cutout_ratio should be given
        dict(type='RandomCutOut', prob=0.5, n_holes=1),
        dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=1,
            cutout_shape=(2, 2),
            cutout_ratio=(0.4, 0.4)),
        # test seg_fill_in
        dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=1,
            cutout_shape=(8, 8),
            seg_fill_in='a'),
        dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=1,
            cutout_shape=(8, 8),
            seg_fill_in=256),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(cfg)

    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results = dict(
        img=img,
        gt_semantic_seg=seg,
        seg_fields=['gt_semantic_seg'],
        img_shape=img.shape,
        ori_shape=img.shape,
        pad_shape=img.shape,
        img_fields=['img'])

    def _apply(**cfg):
        """Build a RandomCutOut from ``cfg`` and run it on a copy."""
        module = TRANSFORMS.build(dict(type='RandomCutOut', **cfg))
        return module, module(copy.deepcopy(results))

    # fixed-size hole with the default fill value lowers the image sum
    module, out = _apply(prob=1, n_holes=1, cutout_shape=(10, 10))
    assert 'cutout_shape' in repr(module)
    assert out['img'].sum() < img.sum()

    # ratio-sized hole with the default fill value lowers the image sum
    module, out = _apply(prob=1, n_holes=1, cutout_ratio=(0.8, 0.8))
    assert 'cutout_ratio' in repr(module)
    assert out['img'].sum() < img.sum()

    # prob=0 leaves image and seg sums unchanged
    _, out = _apply(prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
    assert out['img'].sum() == img.sum()
    assert out['gt_semantic_seg'].sum() == seg.sum()

    # holes filled with (255, 255, 255); seg_fill_in=None keeps seg as is
    _, out = _apply(
        prob=1,
        n_holes=(2, 4),
        cutout_shape=[(10, 10), (15, 15)],
        fill_in=(255, 255, 255),
        seg_fill_in=None)
    assert out['img'].sum() > img.sum()
    assert out['gt_semantic_seg'].sum() == seg.sum()

    # holes filled with (255, 255, 255) in img and 255 in seg
    _, out = _apply(
        prob=1,
        n_holes=1,
        cutout_ratio=(0.8, 0.8),
        fill_in=(255, 255, 255),
        seg_fill_in=255)
    assert out['img'].sum() > img.sum()
    assert out['gt_semantic_seg'].sum() > seg.sum()
def test_resize_to_multiple():
    """``ResizeToMultiple`` rounds image and seg sizes up to multiples of 32."""
    resize = TRANSFORMS.build(dict(type='ResizeToMultiple', size_divisor=32))

    img = np.random.randn(213, 232, 3)
    seg = np.random.randint(0, 19, (213, 232))
    results = resize(
        dict(
            img=img,
            gt_semantic_seg=seg,
            seg_fields=['gt_semantic_seg'],
            img_shape=img.shape,
            pad_shape=img.shape))

    # 213 -> 224 and 232 -> 256
    assert results['img'].shape == (224, 256, 3)
    assert results['gt_semantic_seg'].shape == (224, 256)
    assert results['img_shape'] == (224, 256)
def test_generate_edge():
    """``GenerateEdge`` derives ``gt_edge_map`` from ``gt_seg_map``."""
    edge_module = TRANSFORMS.build(dict(type='GenerateEdge', edge_width=1))

    seg_map = np.array([
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 2],
        [1, 1, 1, 2, 2],
        [1, 1, 2, 2, 2],
        [1, 2, 2, 2, 2],
        [2, 2, 2, 2, 2],
    ])
    # pixels within edge_width of the 1/2 label boundary are marked with 1
    expected_edge = np.array([
        [0, 0, 0, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 1, 1, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0],
    ])

    results = edge_module(dict(gt_seg_map=seg_map, img_shape=seg_map.shape))
    assert np.all(results['gt_edge_map'] == expected_edge)
def test_biomedical3d_random_crop():
    """Test ``BioMedical3DRandomCrop`` with and without ``keep_foreground``."""
    # test assertion for invalid random crop
    with pytest.raises(AssertionError):
        transform = dict(type='BioMedical3DRandomCrop', crop_shape=(-2, -1, 0))
        transform = TRANSFORMS.build(transform)

    from mmseg.datasets.transforms import (LoadBiomedicalAnnotation,
                                           LoadBiomedicalImageFromFile)

    results = dict()
    results['img_path'] = osp.join(
        osp.dirname(__file__), '../data', 'biomedical.nii.gz')
    transform = LoadBiomedicalImageFromFile()
    results = transform(copy.deepcopy(results))
    results['seg_map_path'] = osp.join(
        osp.dirname(__file__), '../data', 'biomedical_ann.nii.gz')
    transform = LoadBiomedicalAnnotation()
    results = transform(copy.deepcopy(results))

    d, h, w = results['img_shape']
    crop_shape = (d - 20, h - 20, w - 20)
    for keep_foreground in (True, False):
        transform = TRANSFORMS.build(
            dict(
                type='BioMedical3DRandomCrop',
                crop_shape=crop_shape,
                keep_foreground=keep_foreground))
        # Crop a deep copy so each case starts from the full volume. If the
        # transform updates its input dict in place, reusing ``results``
        # would make the second case "crop" a volume already of crop_shape.
        crop_results = transform(copy.deepcopy(results))
        assert crop_results['img'].shape[1:] == crop_shape
        assert crop_results['img_shape'] == crop_shape
        assert crop_results['gt_seg_map'].shape == crop_shape
def test_biomedical_gaussian_noise():
    """Test ``BioMedicalGaussianNoise`` argument checks, repr and shape."""
    # invalid prob, and a negative std
    for bad_cfg in (dict(type='BioMedicalGaussianNoise', prob=1.5),
                    dict(
                        type='BioMedicalGaussianNoise',
                        prob=0.2,
                        mean=0.5,
                        std=-0.5)):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(bad_cfg)

    noise_module = TRANSFORMS.build(
        dict(type='BioMedicalGaussianNoise', prob=1.0))
    assert str(noise_module) == ('BioMedicalGaussianNoise'
                                 '(prob=1.0, '
                                 'mean=0.0, '
                                 'std=0.1)')

    loader = LoadBiomedicalImageFromFile()
    results = loader(
        dict(
            img_path=osp.join(
                osp.dirname(__file__), '../data/biomedical.nii.gz')))
    original_img = copy.deepcopy(results['img'])
    results = noise_module(results)
    assert original_img.shape == results['img'].shape
def test_biomedical_gaussian_blur():
    """Test ``BioMedicalGaussianBlur`` argument checks, repr and smoothing."""
    invalid_cfgs = [
        # invalid prob
        dict(type='BioMedicalGaussianBlur', prob=-1.5),
        # invalid sigma_range: scalar, 1-tuple, 3-tuple and string
        dict(type='BioMedicalGaussianBlur', prob=1.0, sigma_range=0.6),
        # ``(0.6)`` is just the float 0.6 (a duplicate of the case above);
        # the trailing comma makes this an actual one-element tuple.
        dict(type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.6, )),
        dict(
            type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(15, 8, 9)),
        dict(type='BioMedicalGaussianBlur', prob=1.0, sigma_range='0.16'),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(cfg)

    transform = dict(
        type='BioMedicalGaussianBlur', prob=1.0, sigma_range=(0.7, 0.8))
    smooth_module = TRANSFORMS.build(transform)
    assert str(smooth_module) == ('BioMedicalGaussianBlur(prob=1.0, '
                                  'prob_per_channel=0.5, '
                                  'sigma_range=(0.7, 0.8), '
                                  'different_sigma_per_channel=True, '
                                  'different_sigma_per_axis=True)')

    img_path = osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz')

    def _blur_and_check(module):
        """Load the test volume, blur it and check shape and value range."""
        results = LoadBiomedicalImageFromFile()(dict(img_path=img_path))
        original_img = copy.deepcopy(results['img'])
        results = module(results)
        assert original_img.shape == results['img'].shape
        # smoothing must not push values outside the original value range
        assert original_img.max() >= results['img'].max()
        assert original_img.min() <= results['img'].min()

    # default sigma_range
    smooth_module = TRANSFORMS.build(
        dict(type='BioMedicalGaussianBlur', prob=1.0))
    assert str(smooth_module) == ('BioMedicalGaussianBlur(prob=1.0, '
                                  'prob_per_channel=0.5, '
                                  'sigma_range=(0.5, 1.0), '
                                  'different_sigma_per_channel=True, '
                                  'different_sigma_per_axis=True)')
    _blur_and_check(smooth_module)

    # one sigma shared by all spatial axes
    smooth_module = TRANSFORMS.build(
        dict(
            type='BioMedicalGaussianBlur',
            prob=1.0,
            different_sigma_per_axis=False))
    _blur_and_check(smooth_module)
def test_BioMedicalRandomGamma():
    """Test ``BioMedicalRandomGamma`` argument checks and output shape."""
    invalid_cfgs = [
        # prob outside [0, 1]
        dict(type='BioMedicalRandomGamma', prob=-1, gamma_range=(0.7, 1.2)),
        dict(type='BioMedicalRandomGamma', prob=1.2, gamma_range=(0.7, 1.2)),
        # gamma_range that is not a pair. ``(0.7)`` is merely the float 0.7,
        # so both the scalar and a real one-element tuple are checked.
        dict(type='BioMedicalRandomGamma', prob=1.0, gamma_range=0.7),
        dict(type='BioMedicalRandomGamma', prob=1.0, gamma_range=(0.7, )),
        dict(
            type='BioMedicalRandomGamma',
            prob=1.0,
            gamma_range=(0.7, 0.2, 0.3)),
        # integer (non-bool) flag values
        dict(
            type='BioMedicalRandomGamma',
            prob=1.0,
            gamma_range=(0.7, 2),
            invert_image=1),
        dict(
            type='BioMedicalRandomGamma',
            prob=1.0,
            gamma_range=(0.7, 2),
            per_channel=1),
        dict(
            type='BioMedicalRandomGamma',
            prob=1.0,
            gamma_range=(0.7, 2),
            retain_stats=1),
    ]
    for cfg in invalid_cfgs:
        with pytest.raises(AssertionError):
            TRANSFORMS.build(cfg)

    # Resolve the fixture relative to this file, as the other tests do, so
    # the test does not depend on the current working directory.
    test_img = osp.join(osp.dirname(__file__), '../data/biomedical.nii.gz')
    results = dict(img_path=test_img)
    transform = LoadBiomedicalImageFromFile()
    results = transform(copy.deepcopy(results))
    origin_img = results['img']

    transform2 = dict(
        type='BioMedicalRandomGamma',
        prob=1.0,
        gamma_range=(0.7, 2),
    )
    transform2 = TRANSFORMS.build(transform2)
    results = transform2(results)
    transformed_img = results['img']
    assert origin_img.shape == transformed_img.shape
def test_BioMedical3DPad():
    """Exercise ``BioMedical3DPad`` argument checks and 3-D padding.

    Covers an image-only sample and an image + segmentation sample, each
    padded with two different requested ``pad_shape`` values.
    """
    # A missing or 2-element pad_shape must be rejected when building.
    for bad_pad_shape in (None, [256, 256]):
        with pytest.raises(AssertionError):
            TRANSFORMS.build(
                dict(type='BioMedical3DPad', pad_shape=bad_pad_shape))
    # Both requests are expected to yield (6, 6, 6): asking for 4 along the
    # first spatial axis, which is already 6 in the input, does not shrink it.
    expected = (6, 6, 6)
    requested_shapes = ((6, 6, 6), (4, 6, 6))
    img_only = dict(img=np.random.random((8, 6, 4, 4)))
    for requested in requested_shapes:
        pad = TRANSFORMS.build(
            dict(type='BioMedical3DPad', pad_shape=requested))
        padded = pad(copy.deepcopy(img_only))
        assert padded['img'].shape[1:] == expected
        assert padded['pad_shape'] == expected
    img_and_seg = dict(
        img=np.random.random((8, 6, 4, 4)),
        gt_seg_map=np.random.randint(0, 2, (6, 4, 4)))
    for requested in requested_shapes:
        pad = TRANSFORMS.build(
            dict(type='BioMedical3DPad', pad_shape=requested))
        padded = pad(copy.deepcopy(img_and_seg))
        assert padded['img'].shape[1:] == expected
        # NOTE(review): gt_seg_map starts 3-D (6, 4, 4), yet only shape[1:]
        # is checked, implying the transform returns a leading axis — confirm
        # against BioMedical3DPad whether the full shape should be asserted.
        assert padded['gt_seg_map'].shape[1:] == expected
        assert padded['pad_shape'] == expected
def test_biomedical_3d_flip():
    """Check ``BioMedical3DRandomFlip`` validation, flipping and label swap.

    Every case uses ``prob=1`` and relies on applying the same flip twice
    restoring the input: one application must change the data, and a second
    one must bring back the original arrays exactly (same shape and values).
    """
    # test assertion for invalid prob
    with pytest.raises(AssertionError):
        transform = dict(type='BioMedical3DRandomFlip', prob=1.5, axes=(0, 1))
        transform = TRANSFORMS.build(transform)
    # test assertion for invalid direction
    with pytest.raises(AssertionError):
        transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(0, 1, 3))
        transform = TRANSFORMS.build(transform)
    # test flip axes are (0, 1, 2)
    transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(0, 1, 2))
    transform = TRANSFORMS.build(transform)
    # test with random 3d data
    results = dict()
    results['img_path'] = 'Null'
    results['img_shape'] = (1, 16, 16, 16)
    results['img'] = np.random.randn(1, 16, 16, 16)
    results['gt_seg_map'] = np.random.randint(0, 4, (16, 16, 16))
    original_img = results['img'].copy()
    original_seg = results['gt_seg_map'].copy()
    # flip first time: both arrays must differ from the originals.
    # np.array_equal (unlike np.equal(...).all()) also compares shapes, so a
    # transform that silently added/dropped an axis cannot match by
    # broadcasting.
    results = transform(results)
    assert not np.array_equal(original_img, results['img'])
    assert not np.array_equal(original_seg, results['gt_seg_map'])
    # flip second time: the same flip applied again restores the data
    results = transform(results)
    assert np.array_equal(original_img, results['img'])
    assert np.array_equal(original_seg, results['gt_seg_map'])
    # test with actual data and flip axes are (0, 1)
    # load biomedical 3d img and seg
    data_prefix = osp.join(osp.dirname(__file__), '../data')
    input_results = dict(img_path=osp.join(data_prefix, 'biomedical.npy'))
    biomedical_loader = LoadBiomedicalData(with_seg=True)
    data = biomedical_loader(copy.deepcopy(input_results))
    # Deep-copy for every case: ``dict.copy`` is shallow, so the arrays in
    # ``data`` would be shared with the transform input and would be
    # affected if a transform ever modified them in place.
    results = copy.deepcopy(data)
    original_img = data['img'].copy()
    original_seg = data['gt_seg_map'].copy()
    # test flip axes are (0, 1)
    transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(0, 1))
    transform = TRANSFORMS.build(transform)
    # flip first time
    results = transform(results)
    assert not np.array_equal(original_img, results['img'])
    assert not np.array_equal(original_seg, results['gt_seg_map'])
    # flip second time
    results = transform(results)
    assert np.array_equal(original_img, results['img'])
    assert np.array_equal(original_seg, results['gt_seg_map'])
    # test transform with flip axes = (1)
    transform = dict(type='BioMedical3DRandomFlip', prob=1, axes=(1, ))
    transform = TRANSFORMS.build(transform)
    results = copy.deepcopy(data)
    results = transform(results)
    results = transform(results)
    assert np.array_equal(original_img, results['img'])
    assert np.array_equal(original_seg, results['gt_seg_map'])
    # test transform with swap_label_pairs: labels 0 and 1 are exchanged on
    # each application, so the seg map differs after one pass
    transform = dict(
        type='BioMedical3DRandomFlip',
        prob=1,
        axes=(1, 2),
        swap_label_pairs=[(0, 1)])
    transform = TRANSFORMS.build(transform)
    results = copy.deepcopy(data)
    results = transform(results)
    assert not np.array_equal(original_seg, results['gt_seg_map'])
    # swap twice: flip + swap applied again restores img and seg
    results = transform(results)
    assert np.array_equal(original_img, results['img'])
    assert np.array_equal(original_seg, results['gt_seg_map'])