[Refactor]: refactor RandomRotation and RandomPatch

This commit is contained in:
renqin 2022-05-12 17:29:45 +00:00 committed by fangyixiao18
parent d9c7bd6a7b
commit 3426b49b39
3 changed files with 153 additions and 6 deletions

View File

@ -1,14 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from .formatting import PackSelfSupInputs from .formatting import PackSelfSupInputs
from .transforms import (BEiTMaskGenerator, Lighting, RandomAug, from .transforms import (BEiTMaskGenerator, Lighting, RandomAug,
RandomGaussianBlur, RandomGaussianBlur, RandomPatchWithLabels,
RandomResizedCropAndInterpolationWithTwoPic, RandomResizedCropAndInterpolationWithTwoPic,
RandomSolarize, SimMIMMaskGenerator) RandomRotationWithLabels, RandomSolarize,
SimMIMMaskGenerator)
from .wrappers import MultiView from .wrappers import MultiView
__all__ = [ __all__ = [
'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug', 'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug',
'SimMIMMaskGenerator', 'BEiTMaskGenerator', 'SimMIMMaskGenerator', 'BEiTMaskGenerator',
'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs', 'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
'MultiView' 'MultiView', 'RandomRotationWithLabels', 'RandomPatchWithLabels'
] ]

View File

@ -11,6 +11,7 @@ from mmcv.image import adjust_lighting, solarize
from mmcv.transforms import BaseTransform from mmcv.transforms import BaseTransform
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from timm.data import create_transform from timm.data import create_transform
from torchvision.transforms import RandomCrop
from mmselfsup.registry import TRANSFORMS from mmselfsup.registry import TRANSFORMS
@ -518,3 +519,124 @@ class RandomSolarize(BaseTransform):
repr_str += f'(threshold = {self.threshold}, ' repr_str += f'(threshold = {self.threshold}, '
repr_str += f'prob = {self.prob})' repr_str += f'prob = {self.prob})'
return repr_str return repr_str
@TRANSFORMS.register_module()
class RandomRotationWithLabels(BaseTransform):
"""Rotation prediction.
Required Keys:
- img
Modified Keys:
- img
Added Keys:
- rot_label
Rotate each image with 0, 90, 180, and 270 degrees and give labels `0, 1,
2, 3` correspodingly.
"""
def __init__(self) -> None:
pass
def _rotate(self, img: torch.Tensor):
"""Rotate input image with 0, 90, 180, and 270 degrees.
Args:
img (Tensor): input image of shape (C, H, W).
Returns:
list[Tensor]: A list of four rotated images.
"""
return [
img,
torch.flip(img.transpose(1, 2), [1]),
torch.flip(img, [1, 2]),
torch.flip(img, [1]).transpose(1, 2)
]
def transform(self, results: Dict) -> Dict:
img = np.transpose(results['img'], (2, 0, 1))
img = torch.from_numpy(img)
img = torch.stack(self._rotate(img), dim=0)
rotation_labels = np.array([0, 1, 2, 3])
results = dict(img=img.numpy(), rot_label=rotation_labels)
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
return repr_str
@TRANSFORMS.register_module()
class RandomPatchWithLabels(BaseTransform):
"""Relative patch location.
Required Keys:
- img
Modified Keys:
- img
Added Keys:
- patch_label
Crops image into several patches and concatenates every surrounding patch
with center one. Finally gives labels `0, 1, 2, 3, 4, 5, 6, 7`.
"""
def __init__(self) -> None:
pass
def _image_to_patches(self, img: Image):
"""Crop split_per_side x split_per_side patches from input image.
Args:
img (PIL Image): input image.
Returns:
list[PIL Image]: A list of cropped patches.
"""
split_per_side = 3 # split of patches per image side
patch_jitter = 21 # jitter of each patch from each grid
h, w = img.size
h_grid = h // split_per_side
w_grid = w // split_per_side
h_patch = h_grid - patch_jitter
w_patch = w_grid - patch_jitter
assert h_patch > 0 and w_patch > 0
patches = []
for i in range(split_per_side):
for j in range(split_per_side):
p = F.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
p = RandomCrop((h_patch, w_patch))(p)
patches.append(np.transpose(np.asarray(p), (2, 0, 1)))
return patches
def transform(self, results: Dict) -> Dict:
img = Image.fromarray(results['img'])
patches = self._image_to_patches(img)
perms = []
# create a list of patch pairs
[
perms.append(np.concatenate((patches[i], patches[4]), axis=0))
for i in range(9) if i != 4
]
# create corresponding labels for patch pairs
patch_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7])
results = dict(
img=np.stack(perms, axis=0),
patch_label=patch_labels) # 8(2C)HW, 8
return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
return repr_str

View File

@ -5,9 +5,9 @@ import torch
from PIL import Image from PIL import Image
from mmselfsup.datasets.pipelines import ( from mmselfsup.datasets.pipelines import (
BEiTMaskGenerator, Lighting, RandomGaussianBlur, BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
RandomResizedCropAndInterpolationWithTwoPic, RandomSolarize, RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels,
SimMIMMaskGenerator) RandomSolarize, SimMIMMaskGenerator)
def test_simmim_mask_gen(): def test_simmim_mask_gen():
@ -118,3 +118,27 @@ def test_random_solarize():
results = transform(results) results = transform(results)
assert results['img'].shape == original_img.shape assert results['img'].shape == original_img.shape
def test_random_rotation():
transform = dict()
module = RandomRotationWithLabels(**transform)
image = torch.rand((224, 224, 3)).numpy().astype(np.uint8)
results = {'img': image}
results = module(results)
# test transform
assert list(results['img'].shape) == [4, 3, 224, 224]
assert list(results['rot_label'].shape) == [4]
def test_random_patch():
transform = dict()
module = RandomPatchWithLabels(**transform)
image = torch.rand((224, 224, 3)).numpy().astype(np.uint8)
results = {'img': image}
results = module(results)
# test transform
assert list(results['img'].shape) == [8, 6, 53, 53]
assert list(results['patch_label'].shape) == [8]