2022-05-09 20:18:15 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
import numpy as np
|
2022-05-10 11:30:17 +08:00
|
|
|
import pytest
|
2022-05-09 20:18:15 +08:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from mmselfsup.datasets.pipelines import (
|
2022-05-13 01:29:45 +08:00
|
|
|
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
|
|
|
|
RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels,
|
|
|
|
RandomSolarize, SimMIMMaskGenerator)
|
2022-05-09 20:18:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_simmim_mask_gen():
    """Check SimMIMMaskGenerator output shapes and its string repr."""
    generator = SimMIMMaskGenerator(
        input_size=192,
        mask_patch_size=32,
        model_patch_size=4,
        mask_ratio=0.6)
    output = generator({'img': torch.rand((3, 192, 192))})

    # the image keeps its shape; the mask is 48 x 48
    # (48 = input_size 192 / model_patch_size 4)
    assert list(output['img'].shape) == [3, 192, 192]
    assert list(output['mask'].shape) == [48, 48]

    # __repr__ must yield a string
    assert isinstance(str(generator), str)
|
|
|
|
|
|
|
|
|
|
|
|
def test_beit_mask_gen():
    """Check that BEiTMaskGenerator adds a mask of shape ``input_size``."""
    generator = BEiTMaskGenerator(
        input_size=(14, 14),
        num_masking_patches=75,
        max_num_patches=None,
        min_num_patches=16)
    # the generator needs no image: it is called on an empty results dict
    output = generator({})

    assert list(output['mask'].shape) == [14, 14]

    # __repr__ must yield a string
    assert isinstance(str(generator), str)
|
|
|
|
|
|
|
|
|
|
|
|
def test_random_resize_crop_with_two_pic():
    """Check that RandomResizedCropAndInterpolationWithTwoPic yields two
    crops with the configured ``size`` and ``second_size``."""
    transform = dict(
        size=224,
        second_size=112,
        interpolation='bicubic',
        second_interpolation='lanczos',
        scale=(0.08, 1.0))
    module = RandomResizedCropAndInterpolationWithTwoPic(**transform)

    # Use genuinely random pixels. ``torch.rand(...).astype(np.uint8)``
    # truncates every value in [0, 1) to 0, producing an all-black image.
    # A fixed seed keeps the test reproducible.
    rng = np.random.default_rng(0)
    fake_input = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
    results = module({'img': fake_input})

    # test transform: results['img'] holds the two crops
    assert list(results['img'][0].shape) == [224, 224, 3]
    assert list(results['img'][1].shape) == [112, 112, 3]

    # test repr
    assert isinstance(str(module), str)
|
2022-05-10 11:30:17 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_lighting():
    """Check Lighting's argument validation, shape preservation, and that
    ``alphastd=0.`` leaves pixel values unchanged."""
    # malformed eigval / eigvec arguments are rejected by the constructor
    with pytest.raises(AssertionError):
        Lighting(eigval=1)
    with pytest.raises(AssertionError):
        Lighting(eigvec=1)
    with pytest.raises(AssertionError):
        Lighting(eigvec=[1])

    original_img = np.ones((8, 8, 3), dtype=np.uint8)

    transform = Lighting()
    assert isinstance(str(transform), str)
    # Pass a copy so original_img remains a pristine reference even if the
    # transform were to modify its input array in place.
    results = transform(dict(img=original_img.copy()))
    assert results['img'].shape == original_img.shape

    # With alphastd=0. the output must equal the input exactly.
    # np.array_equal also rejects a shape mismatch, whereas the broadcasting
    # np.equal(...).all() would accept e.g. an (8, 8, 1) output.
    transform = Lighting(alphastd=0., to_rgb=False)
    results = transform(dict(img=original_img.copy()))
    assert np.array_equal(results['img'], original_img)
|
|
|
|
|
|
|
|
|
|
|
|
def test_random_gaussiablur():
    """Check RandomGaussianBlur's prob validation and shape preservation."""
    # a negative probability is refused by the constructor
    with pytest.raises(AssertionError):
        RandomGaussianBlur(sigma_min=0.1, sigma_max=1.0, prob=-1)

    img = np.ones((8, 8, 3), dtype=np.uint8)
    blur = RandomGaussianBlur(sigma_min=0.1, sigma_max=1.0)
    assert isinstance(str(blur), str)

    out = blur(dict(img=img))
    assert out['img'].shape == img.shape
|
|
|
|
|
|
|
|
|
|
|
|
def test_random_solarize():
    """Check RandomSolarize's prob validation and shape preservation."""
    # a negative probability is refused by the constructor
    with pytest.raises(AssertionError):
        RandomSolarize(prob=-1)

    img = np.ones((8, 8, 3), dtype=np.uint8)
    solarize = RandomSolarize()
    assert isinstance(str(solarize), str)

    out = solarize(dict(img=img))
    assert out['img'].shape == img.shape
|
2022-05-13 01:29:45 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_random_rotation():
    """Check RandomRotationWithLabels output image and label shapes."""
    transform = dict()
    module = RandomRotationWithLabels(**transform)

    # Use genuinely random pixels. ``torch.rand(...).astype(np.uint8)``
    # truncates every value in [0, 1) to 0, producing an all-black image.
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
    results = module({'img': image})

    # test transform: an HWC input comes back as a (4, C, H, W) stack,
    # presumably one entry per rotation label
    assert list(results['img'].shape) == [4, 3, 224, 224]
    assert list(results['rot_label'].shape) == [4]
|
|
|
|
|
|
|
|
|
|
|
|
def test_random_patch():
    """Check RandomPatchWithLabels output patch and label shapes."""
    transform = dict()
    module = RandomPatchWithLabels(**transform)

    # Use genuinely random pixels. ``torch.rand(...).astype(np.uint8)``
    # truncates every value in [0, 1) to 0, producing an all-black image.
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
    results = module({'img': image})

    # test transform: 8 image entries with a matching 8-element label vector
    assert list(results['img'].shape) == [8, 6, 53, 53]
    assert list(results['patch_label'].shape) == [8]
|