# mmselfsup/tests/test_dataset/test_transforms.py

# 67 lines
# 1.7 KiB
# Python
# Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from PIL import Image
from mmselfsup.datasets.pipelines import (
BEiTMaskGenerator, RandomResizedCropAndInterpolationWithTwoPic,
SimMIMMaskGenerator)
def test_simmim_mask_gen():
    """Check the output shapes and repr of ``SimMIMMaskGenerator``."""
    sample = {'img': torch.rand((3, 192, 192))}
    generator = SimMIMMaskGenerator(
        input_size=192,
        mask_patch_size=32,
        model_patch_size=4,
        mask_ratio=0.6)
    outputs = generator(sample)

    # The image entry keeps the shape it was given; the generated mask is
    # expected to be 48 x 48 (presumably input_size / model_patch_size).
    assert tuple(outputs['img'].shape) == (3, 192, 192)
    assert tuple(outputs['mask'].shape) == (48, 48)

    # The transform must have a usable string representation.
    assert isinstance(str(generator), str)
def test_beit_mask_gen():
    """Check the mask shape and repr of ``BEiTMaskGenerator``."""
    generator = BEiTMaskGenerator(
        input_size=(14, 14),
        num_masking_patches=75,
        max_num_patches=None,
        min_num_patches=16)

    # The generator is called on an empty results dict and is expected to
    # add a 'mask' entry matching the configured 14 x 14 input size.
    outputs = generator({})
    assert tuple(outputs['mask'].shape) == (14, 14)

    # The transform must have a usable string representation.
    assert isinstance(str(generator), str)
def test_random_resize_crop_with_two_pic():
    """Check ``RandomResizedCropAndInterpolationWithTwoPic`` output sizes.

    A random 224 x 224 RGB PIL image is passed through the transform, and
    the test asserts that ``results['img']`` has size 224 x 224 while
    ``results['target_img']`` has size 112 x 112, matching the ``size`` and
    ``second_size`` arguments. It also checks that the transform can be
    converted to a string.
    """
    transform = dict(
        size=224,
        second_size=112,
        interpolation='bicubic',
        second_interpolation='lanczos',
        scale=(0.08, 1.0))
    module = RandomResizedCropAndInterpolationWithTwoPic(**transform)

    # torch.rand samples from [0, 1); casting that straight to uint8 would
    # truncate every pixel to 0 and yield an all-black image. Scale to the
    # 8-bit range first so the transform is exercised on real pixel data.
    fake_input = (torch.rand((224, 224, 3)) * 255).numpy().astype(np.uint8)
    fake_input = Image.fromarray(fake_input)

    results = {'img': fake_input}
    results = module(results)

    # test transform
    assert list(results['img'].size) == [224, 224]
    assert list(results['target_img'].size) == [112, 112]

    # test repr
    assert isinstance(str(module), str)