mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Refactor]: refactor RandomRotation and RandomPatch
This commit is contained in:
parent
d9c7bd6a7b
commit
3426b49b39
@ -1,14 +1,15 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .formatting import PackSelfSupInputs
|
||||
from .transforms import (BEiTMaskGenerator, Lighting, RandomAug,
|
||||
RandomGaussianBlur,
|
||||
RandomGaussianBlur, RandomPatchWithLabels,
|
||||
RandomResizedCropAndInterpolationWithTwoPic,
|
||||
RandomSolarize, SimMIMMaskGenerator)
|
||||
RandomRotationWithLabels, RandomSolarize,
|
||||
SimMIMMaskGenerator)
|
||||
from .wrappers import MultiView
|
||||
|
||||
# Public names re-exported by this package, one entry per line.
__all__ = [
    'RandomGaussianBlur',
    'Lighting',
    'RandomSolarize',
    'RandomAug',
    'SimMIMMaskGenerator',
    'BEiTMaskGenerator',
    'RandomResizedCropAndInterpolationWithTwoPic',
    'PackSelfSupInputs',
    'MultiView',
    'RandomRotationWithLabels',
    'RandomPatchWithLabels',
]
|
||||
|
@ -11,6 +11,7 @@ from mmcv.image import adjust_lighting, solarize
|
||||
from mmcv.transforms import BaseTransform
|
||||
from PIL import Image, ImageFilter
|
||||
from timm.data import create_transform
|
||||
from torchvision.transforms import RandomCrop
|
||||
|
||||
from mmselfsup.registry import TRANSFORMS
|
||||
|
||||
@ -518,3 +519,124 @@ class RandomSolarize(BaseTransform):
|
||||
repr_str += f'(threshold = {self.threshold}, '
|
||||
repr_str += f'prob = {self.prob})'
|
||||
return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class RandomRotationWithLabels(BaseTransform):
    """Rotation prediction.

    Required Keys:

    - img

    Modified Keys:

    - img

    Added Keys:

    - rot_label

    Rotate each image with 0, 90, 180, and 270 degrees and give labels `0, 1,
    2, 3` correspondingly.

    Note:
        The four rotated views are stacked into one array, so the input image
        must be square (H == W); otherwise the 90 and 270 degree views have a
        different shape from the 0 and 180 degree ones and stacking fails.
    """

    def __init__(self) -> None:
        super().__init__()

    def _rotate(self, img: torch.Tensor) -> list:
        """Rotate input image with 0, 90, 180, and 270 degrees.

        Args:
            img (Tensor): input image of shape (C, H, W).

        Returns:
            list[Tensor]: A list of four rotated images, ordered by rotation
            angle (0, 90, 180, 270 degrees).
        """
        return [
            img,
            # 90 degrees: swap H and W, then flip the new first spatial axis
            torch.flip(img.transpose(1, 2), [1]),
            # 180 degrees: flip both spatial axes
            torch.flip(img, [1, 2]),
            # 270 degrees: flip H, then swap H and W
            torch.flip(img, [1]).transpose(1, 2)
        ]

    def transform(self, results: Dict) -> Dict:
        """Replace ``img`` with its four rotated views and add their labels.

        Args:
            results (dict): Pipeline data; ``results['img']`` is transposed
                with ``(2, 0, 1)``, i.e. it is read as an (H, W, C) array.

        Returns:
            dict: The same dict, where ``img`` is now an array of shape
            (4, C, H, W) and ``rot_label`` is ``[0, 1, 2, 3]``. All other
            keys are kept untouched.
        """
        # HWC -> CHW so that dims 1 and 2 are the spatial axes in `_rotate`
        chw = torch.from_numpy(np.transpose(results['img'], (2, 0, 1)))
        rotated = torch.stack(self._rotate(chw), dim=0)
        # update in place instead of building a new dict, so that keys set by
        # earlier transforms are not silently dropped
        results['img'] = rotated.numpy()
        results['rot_label'] = np.array([0, 1, 2, 3])
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        return repr_str
|
||||
|
||||
|
||||
@TRANSFORMS.register_module()
class RandomPatchWithLabels(BaseTransform):
    """Relative patch location.

    Required Keys:

    - img

    Modified Keys:

    - img

    Added Keys:

    - patch_label

    Crops image into several patches and concatenates every surrounding patch
    with center one. Finally gives labels `0, 1, 2, 3, 4, 5, 6, 7`.
    """

    def __init__(self) -> None:
        super().__init__()

    def _image_to_patches(self, img: Image.Image) -> list:
        """Crop split_per_side x split_per_side patches from input image.

        The image is divided into a regular grid; inside every grid cell a
        patch that is ``patch_jitter`` pixels smaller per side is cropped at
        a random position.

        Args:
            img (PIL Image): input image.

        Returns:
            list[np.ndarray]: The cropped patches in row-major grid order,
            each transposed to channel-first layout.
        """
        split_per_side = 3  # split of patches per image side
        patch_jitter = 21  # jitter of each patch from each grid
        # PIL reports the size as (width, height), not (height, width)
        w, h = img.size
        h_grid = h // split_per_side
        w_grid = w // split_per_side
        h_patch = h_grid - patch_jitter
        w_patch = w_grid - patch_jitter
        assert h_patch > 0 and w_patch > 0, (
            f'image of height {h} and width {w} is too small for '
            f'{split_per_side}x{split_per_side} patches with a jitter of '
            f'{patch_jitter}')
        # one cropper is enough: it draws a new random offset on every call
        random_crop = RandomCrop((h_patch, w_patch))
        patches = []
        for i in range(split_per_side):
            for j in range(split_per_side):
                # F.crop takes (img, top, left, height, width)
                cell = F.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
                patch = random_crop(cell)
                patches.append(np.transpose(np.asarray(patch), (2, 0, 1)))
        return patches

    def transform(self, results: Dict) -> Dict:
        """Replace ``img`` with eight (surrounding, center) patch pairs.

        Args:
            results (dict): Pipeline data; ``results['img']`` is converted
                with ``Image.fromarray``.

        Returns:
            dict: The same dict, where ``img`` is now an array of shape
            (8, 2C, H, W) and ``patch_label`` is ``[0, ..., 7]``; label ``k``
            belongs to the k-th surrounding patch in row-major order. All
            other keys are kept untouched.
        """
        img = Image.fromarray(results['img'])
        patches = self._image_to_patches(img)
        center = patches[4]  # middle cell of the 3x3 grid
        # pair every surrounding patch with the center one along channels
        pairs = [
            np.concatenate((patch, center), axis=0)
            for idx, patch in enumerate(patches) if idx != 4
        ]
        # update in place instead of building a new dict, so that keys set by
        # earlier transforms are not silently dropped
        results['img'] = np.stack(pairs, axis=0)  # 8(2C)HW
        results['patch_label'] = np.array([0, 1, 2, 3, 4, 5, 6, 7])  # 8
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        return repr_str
|
||||
|
@ -5,9 +5,9 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from mmselfsup.datasets.pipelines import (
|
||||
BEiTMaskGenerator, Lighting, RandomGaussianBlur,
|
||||
RandomResizedCropAndInterpolationWithTwoPic, RandomSolarize,
|
||||
SimMIMMaskGenerator)
|
||||
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
|
||||
RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels,
|
||||
RandomSolarize, SimMIMMaskGenerator)
|
||||
|
||||
|
||||
def test_simmim_mask_gen():
|
||||
@ -118,3 +118,27 @@ def test_random_solarize():
|
||||
|
||||
results = transform(results)
|
||||
assert results['img'].shape == original_img.shape
|
||||
|
||||
|
||||
def test_random_rotation():
    """Check shapes, labels and pixel content of the four rotated views."""
    transform = dict()
    module = RandomRotationWithLabels(**transform)
    # use real uint8 noise: casting values from [0, 1) to uint8 would give an
    # all-zero image, for which any (even wrong) rotation looks correct
    image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    results = {'img': image}
    results = module(results)

    # test transform
    assert list(results['img'].shape) == [4, 3, 224, 224]
    assert list(results['rot_label'].shape) == [4]
    assert results['rot_label'].tolist() == [0, 1, 2, 3]

    # view k must be the channel-first image rotated by k * 90 degrees
    chw = np.transpose(image, (2, 0, 1))
    for k in range(4):
        expected = np.rot90(chw, k=k, axes=(1, 2))
        assert np.array_equal(results['img'][k], expected)
|
||||
|
||||
|
||||
def test_random_patch():
    """Check shapes, labels and the shared center patch of the pairs."""
    transform = dict()
    module = RandomPatchWithLabels(**transform)
    # use real uint8 noise: casting values from [0, 1) to uint8 would give an
    # all-zero image, which cannot reveal mixed-up patches
    image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
    results = {'img': image}
    results = module(results)

    # test transform
    assert list(results['img'].shape) == [8, 6, 53, 53]
    assert list(results['patch_label'].shape) == [8]
    assert results['patch_label'].tolist() == [0, 1, 2, 3, 4, 5, 6, 7]

    # the last three channels of every pair hold the same center patch
    center = results['img'][0, 3:]
    assert all(np.array_equal(pair[3:], center) for pair in results['img'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user