mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Feature]: Add MAE transform
This commit is contained in:
parent
7695bb8134
commit
dd1f496821
@ -3,7 +3,7 @@ _base_ = 'vit-base-p16_linear-8xb2048-coslr-torchvision-transform-90e_in1k.py'
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='mmcls.ToPIL', to_rgb=True),
|
||||
dict(type='MAERandomResizedCrop', size=224, interpolation=3),
|
||||
dict(type='mmselfsup.MAERandomResizedCrop', size=224, interpolation=3),
|
||||
dict(type='mmcls.torchvision/RandomHorizontalFlip', p=0.5),
|
||||
dict(type='mmcls.ToNumpy', to_rgb=True),
|
||||
dict(type='PackClsInputs'),
|
||||
|
@ -7,20 +7,12 @@ from .processing import (BEiTMaskGenerator, ColorJitter, RandomCrop,
|
||||
RandomRotation, RandomSolarize, RotationWithLabels,
|
||||
SimMIMMaskGenerator)
|
||||
from .wrappers import MultiView
|
||||
from .pytorch_transform import MAERandomResizedCrop
|
||||
|
||||
__all__ = [
|
||||
'PackSelfSupInputs',
|
||||
'RandomGaussianBlur',
|
||||
'RandomSolarize',
|
||||
'SimMIMMaskGenerator',
|
||||
'BEiTMaskGenerator',
|
||||
'ColorJitter',
|
||||
'RandomResizedCropAndInterpolationWithTwoPic',
|
||||
'PackSelfSupInputs',
|
||||
'MultiView',
|
||||
'RotationWithLabels',
|
||||
'RandomPatchWithLabels',
|
||||
'RandomRotation',
|
||||
'RandomResizedCrop',
|
||||
'RandomCrop',
|
||||
'PackSelfSupInputs', 'RandomGaussianBlur', 'RandomSolarize',
|
||||
'SimMIMMaskGenerator', 'BEiTMaskGenerator', 'ColorJitter',
|
||||
'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
|
||||
'MultiView', 'RotationWithLabels', 'RandomPatchWithLabels',
|
||||
'RandomRotation', 'RandomResizedCrop', 'RandomCrop', 'MAERandomResizedCrop'
|
||||
]
|
||||
|
@ -40,4 +40,18 @@ class MAERandomResizedCrop(transforms.RandomResizedCrop):
|
||||
i = torch.randint(0, height - h + 1, size=(1, )).item()
|
||||
j = torch.randint(0, width - w + 1, size=(1, )).item()
|
||||
|
||||
return i, j, h, w
|
||||
return i, j, h, w
|
||||
|
||||
def forward(self, results):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image or Tensor): Image to be cropped and resized.
|
||||
|
||||
Returns:
|
||||
PIL Image or Tensor: Randomly cropped and resized image.
|
||||
"""
|
||||
img = results['img']
|
||||
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
||||
img = F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
|
||||
results['img'] = img
|
||||
return results
|
Loading…
x
Reference in New Issue
Block a user