mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
67 lines
1.7 KiB
Python
67 lines
1.7 KiB
Python
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from PIL import Image
|
||
|
|
||
|
from mmselfsup.datasets.pipelines import (
|
||
|
BEiTMaskGenerator, RandomResizedCropAndInterpolationWithTwoPic,
|
||
|
SimMIMMaskGenerator)
|
||
|
|
||
|
|
||
|
def test_simmim_mask_gen():
    """Check SimMIMMaskGenerator output shapes and repr."""
    mask_generator = SimMIMMaskGenerator(
        input_size=192,
        mask_patch_size=32,
        model_patch_size=4,
        mask_ratio=0.6)

    # Feed a random 3x192x192 tensor through the generator.
    outputs = mask_generator({'img': torch.rand((3, 192, 192))})

    # test transform
    assert tuple(outputs['img'].shape) == (3, 192, 192)
    assert tuple(outputs['mask'].shape) == (48, 48)

    # test repr
    assert isinstance(str(mask_generator), str)
|
||
|
|
||
|
|
||
|
def test_beit_mask_gen():
    """Check BEiTMaskGenerator mask shape and repr."""
    mask_generator = BEiTMaskGenerator(
        input_size=(14, 14),
        num_masking_patches=75,
        max_num_patches=None,
        min_num_patches=16)

    # The generator is called with an empty results dict.
    outputs = mask_generator({})

    # test transform
    assert tuple(outputs['mask'].shape) == (14, 14)

    # test repr
    assert isinstance(str(mask_generator), str)
|
||
|
|
||
|
|
||
|
def test_random_resize_crop_with_two_pic():
    """Test the two-picture random resized crop transform.

    A random 224x224 RGB PIL image is passed through the transform; the
    test asserts that ``img`` comes back at 224x224, ``target_img`` at
    112x112, and that the transform has a string representation.
    """
    transform = dict(
        size=224,
        second_size=112,
        interpolation='bicubic',
        second_interpolation='lanczos',
        scale=(0.08, 1.0))
    module = RandomResizedCropAndInterpolationWithTwoPic(**transform)

    # Sample pixel values across the full uint8 range. Casting
    # ``torch.rand`` output (values in [0, 1)) straight to uint8 truncates
    # every pixel to 0, so the transform would only ever see a black image.
    fake_input = np.random.randint(
        0, 256, size=(224, 224, 3), dtype=np.uint8)
    fake_input = Image.fromarray(fake_input)

    results = {'img': fake_input}
    results = module(results)

    # test transform
    assert list(results['img'].size) == [224, 224]
    assert list(results['target_img'].size) == [112, 112]

    # test repr
    assert isinstance(str(module), str)
|