[Feature]: Add MAE transform

This commit is contained in:
liuyuan 2023-03-17 10:00:48 +08:00 committed by Yuan Liu
parent 7695bb8134
commit dd1f496821
3 changed files with 22 additions and 16 deletions

View File

@ -3,7 +3,7 @@ _base_ = 'vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='mmcls.ToPIL', to_rgb=True),
dict(type='MAERandomResizedCrop', size=224, interpolation=3),
dict(type='mmselfsup.MAERandomResizedCrop', size=224, interpolation=3),
dict(type='mmcls.torchvision/RandomHorizontalFlip', p=0.5),
dict(type='mmcls.ToNumpy', to_rgb=True),
dict(type='PackClsInputs'),

View File

@ -7,20 +7,12 @@ from .processing import (BEiTMaskGenerator, ColorJitter, RandomCrop,
RandomRotation, RandomSolarize, RotationWithLabels,
SimMIMMaskGenerator)
from .wrappers import MultiView
from .pytorch_transform import MAERandomResizedCrop
# Names this module exports for ``from <module> import *``.
# NOTE(review): as written, this list names several entries more than once
# (e.g. 'PackSelfSupInputs', 'RandomSolarize', 'MultiView'): a
# one-name-per-line run is followed by a wrapped run of the same names plus
# 'MAERandomResizedCrop'. Confirm which run is intended before deduplicating.
__all__ = [
'PackSelfSupInputs',
'RandomGaussianBlur',
'RandomSolarize',
'SimMIMMaskGenerator',
'BEiTMaskGenerator',
'ColorJitter',
'RandomResizedCropAndInterpolationWithTwoPic',
'PackSelfSupInputs',
'MultiView',
'RotationWithLabels',
'RandomPatchWithLabels',
'RandomRotation',
'RandomResizedCrop',
'RandomCrop',
'PackSelfSupInputs', 'RandomGaussianBlur', 'RandomSolarize',
'SimMIMMaskGenerator', 'BEiTMaskGenerator', 'ColorJitter',
'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
'MultiView', 'RotationWithLabels', 'RandomPatchWithLabels',
'RandomRotation', 'RandomResizedCrop', 'RandomCrop', 'MAERandomResizedCrop'
]

View File

@ -40,4 +40,18 @@ class MAERandomResizedCrop(transforms.RandomResizedCrop):
i = torch.randint(0, height - h + 1, size=(1, )).item()
j = torch.randint(0, width - w + 1, size=(1, )).item()
return i, j, h, w
return i, j, h, w
def forward(self, results):
    """Apply a random resized crop to the image held in ``results``.

    Args:
        results (dict): Mapping whose ``'img'`` entry is the image
            (PIL Image or Tensor) to be cropped and resized.

    Returns:
        dict: The same ``results`` object, with ``'img'`` replaced by the
        randomly cropped and resized image.
    """
    source = results['img']
    # Sample the crop window (top, left, height, width) for this image.
    top, left, crop_h, crop_w = self.get_params(source, self.scale,
                                                self.ratio)
    results['img'] = F.resized_crop(source, top, left, crop_h, crop_w,
                                    self.size, self.interpolation)
    return results