[Refactor]: refactor RandomRotation and RandomPatch

This commit is contained in:
renqin 2022-05-12 17:29:45 +00:00 committed by fangyixiao18
parent d9c7bd6a7b
commit 3426b49b39
3 changed files with 153 additions and 6 deletions

View File

@ -1,14 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackSelfSupInputs
from .transforms import (BEiTMaskGenerator, Lighting, RandomAug,
RandomGaussianBlur,
RandomGaussianBlur, RandomPatchWithLabels,
RandomResizedCropAndInterpolationWithTwoPic,
RandomSolarize, SimMIMMaskGenerator)
RandomRotationWithLabels, RandomSolarize,
SimMIMMaskGenerator)
from .wrappers import MultiView
# Public names re-exported by this package. Each entry appears exactly once;
# a missing comma between adjacent string literals would silently concatenate
# them into a single bogus name.
__all__ = [
    'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug',
    'SimMIMMaskGenerator', 'BEiTMaskGenerator',
    'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
    'MultiView', 'RandomRotationWithLabels', 'RandomPatchWithLabels'
]

View File

@ -11,6 +11,7 @@ from mmcv.image import adjust_lighting, solarize
from mmcv.transforms import BaseTransform
from PIL import Image, ImageFilter
from timm.data import create_transform
from torchvision.transforms import RandomCrop
from mmselfsup.registry import TRANSFORMS
@ -518,3 +519,124 @@ class RandomSolarize(BaseTransform):
repr_str += f'(threshold = {self.threshold}, '
repr_str += f'prob = {self.prob})'
return repr_str
@TRANSFORMS.register_module()
class RandomRotationWithLabels(BaseTransform):
    """Rotation prediction.

    Required Keys:

    - img

    Modified Keys:

    - img

    Added Keys:

    - rot_label

    Rotate each image with 0, 90, 180, and 270 degrees and give labels `0, 1,
    2, 3` correspondingly.
    """

    def __init__(self) -> None:
        super().__init__()

    def _rotate(self, img: torch.Tensor) -> list:
        """Rotate input image with 0, 90, 180, and 270 degrees.

        The k-th entry equals ``torch.rot90(img, k, dims=(1, 2))``.

        Args:
            img (Tensor): input image of shape (C, H, W).

        Returns:
            list[Tensor]: A list of four rotated images.
        """
        return [
            img,
            torch.flip(img.transpose(1, 2), [1]),
            torch.flip(img, [1, 2]),
            torch.flip(img, [1]).transpose(1, 2)
        ]

    def transform(self, results: Dict) -> Dict:
        """Replace ``img`` by its four rotations and add ``rot_label``.

        Args:
            results (dict): Result dict whose ``img`` entry is an array that
                is transposed from (H, W, C) to (C, H, W) before rotating.

        Returns:
            dict: A shallow copy of ``results`` in which ``img`` is the
            stacked rotations (first axis of size 4) and ``rot_label`` is
            ``np.array([0, 1, 2, 3])``. All other keys are passed through
            unchanged, matching the "Modified Keys" / "Added Keys" contract
            documented on this class.
        """
        chw = np.transpose(results['img'], (2, 0, 1))
        rotated = torch.stack(self._rotate(torch.from_numpy(chw)), dim=0)

        # Copy instead of rebuilding the dict from scratch so that keys set
        # by earlier transforms are not dropped, and the caller's dict is
        # left untouched.
        out = dict(results)
        out['img'] = rotated.numpy()
        out['rot_label'] = np.array([0, 1, 2, 3])
        return out

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        return repr_str
@TRANSFORMS.register_module()
class RandomPatchWithLabels(BaseTransform):
    """Relative patch location.

    Required Keys:

    - img

    Modified Keys:

    - img

    Added Keys:

    - patch_label

    Crops image into several patches and concatenates every surrounding patch
    with center one. Finally gives labels `0, 1, 2, 3, 4, 5, 6, 7`.
    """

    def __init__(self) -> None:
        super().__init__()

    def _image_to_patches(self, img: Image.Image) -> list:
        """Crop split_per_side x split_per_side patches from input image.

        The image is divided into a 3 x 3 grid; inside every grid cell a
        random crop that is 21 pixels smaller than the cell on each side
        length is taken.

        Args:
            img (PIL Image): input image.

        Returns:
            list[np.ndarray]: The cropped patches in row-major grid order,
            each transposed to (C, H, W).
        """
        split_per_side = 3  # split of patches per image side
        patch_jitter = 21  # jitter of each patch from each grid

        # PIL reports size as (width, height); unpacking it as (h, w) would
        # swap the axes and misplace the grid on non-square images.
        w, h = img.size
        h_grid = h // split_per_side
        w_grid = w // split_per_side
        h_patch = h_grid - patch_jitter
        w_patch = w_grid - patch_jitter
        assert h_patch > 0 and w_patch > 0, (
            f'image of size (h={h}, w={w}) is too small to be split into '
            f'{split_per_side}x{split_per_side} patches with a jitter of '
            f'{patch_jitter} pixels')

        # The crop size is the same for all cells, so build the op once; the
        # crop offset is still sampled anew on every call.
        random_crop = RandomCrop((h_patch, w_patch))
        patches = []
        for i in range(split_per_side):
            for j in range(split_per_side):
                # F.crop takes (top, left, height, width).
                cell = F.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
                patch = random_crop(cell)
                patches.append(np.transpose(np.asarray(patch), (2, 0, 1)))
        return patches

    def transform(self, results: Dict) -> Dict:
        """Replace ``img`` by eight (surrounding, center) patch pairs.

        Args:
            results (dict): Result dict whose ``img`` entry is an array
                accepted by ``PIL.Image.fromarray``.

        Returns:
            dict: A shallow copy of ``results`` in which ``img`` has shape
            (8, 2C, H, W) -- each surrounding patch concatenated with the
            center patch along the channel axis -- and ``patch_label`` is
            ``np.array([0, ..., 7])``. All other keys are passed through
            unchanged, matching the "Modified Keys" / "Added Keys" contract
            documented on this class.
        """
        img = Image.fromarray(results['img'])
        patches = self._image_to_patches(img)

        # Index 4 is the middle cell of the row-major 3 x 3 grid.
        center = patches[4]
        # create a list of patch pairs
        pairs = [
            np.concatenate((patches[i], center), axis=0) for i in range(9)
            if i != 4
        ]

        # Copy instead of rebuilding the dict from scratch so that keys set
        # by earlier transforms are not dropped, and the caller's dict is
        # left untouched.
        out = dict(results)
        out['img'] = np.stack(pairs, axis=0)  # 8(2C)HW
        # create corresponding labels for patch pairs
        out['patch_label'] = np.array([0, 1, 2, 3, 4, 5, 6, 7])
        return out

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        return repr_str

View File

@ -5,9 +5,9 @@ import torch
from PIL import Image
from mmselfsup.datasets.pipelines import (
BEiTMaskGenerator, Lighting, RandomGaussianBlur,
RandomResizedCropAndInterpolationWithTwoPic, RandomSolarize,
SimMIMMaskGenerator)
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels,
RandomSolarize, SimMIMMaskGenerator)
def test_simmim_mask_gen():
@ -118,3 +118,27 @@ def test_random_solarize():
results = transform(results)
assert results['img'].shape == original_img.shape
def test_random_rotation():
    """Check shapes, labels and pixel content of RandomRotationWithLabels."""
    transform = dict()
    module = RandomRotationWithLabels(**transform)
    # Use real random uint8 content: casting floats drawn from [0, 1) to
    # uint8 truncates every pixel to 0, which would make any rotation pass.
    image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
    results = {'img': image}
    results = module(results)

    # test transform
    assert list(results['img'].shape) == [4, 3, 224, 224]
    assert list(results['rot_label'].shape) == [4]
    assert results['rot_label'].tolist() == [0, 1, 2, 3]

    # the k-th output must be the input rotated by k * 90 degrees
    chw = image.transpose(2, 0, 1)
    for k in range(4):
        expected = np.rot90(chw, k, axes=(1, 2))
        assert np.array_equal(results['img'][k], expected)
def test_random_patch():
    """Check shapes, labels and pair layout of RandomPatchWithLabels."""
    transform = dict()
    module = RandomPatchWithLabels(**transform)
    # Use real random uint8 content: casting floats drawn from [0, 1) to
    # uint8 truncates every pixel to 0, leaving all patches identical.
    image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
    results = {'img': image}
    results = module(results)

    # test transform: 224 // 3 - 21 == 53 pixels per patch side
    assert list(results['img'].shape) == [8, 6, 53, 53]
    assert list(results['patch_label'].shape) == [8]
    assert results['patch_label'].tolist() == list(range(8))

    # the last three channels of every pair hold the same center patch
    center = results['img'][0, 3:]
    for pair in results['img']:
        assert np.array_equal(pair[3:], center)