mirror of
https://github.com/open-mmlab/mmselfsup.git
synced 2025-06-03 14:59:38 +08:00
[Refactor]: refactor RandomRotation and RandomPatch
This commit is contained in:
parent
d9c7bd6a7b
commit
3426b49b39
@ -1,14 +1,15 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .formatting import PackSelfSupInputs
|
from .formatting import PackSelfSupInputs
|
||||||
from .transforms import (BEiTMaskGenerator, Lighting, RandomAug,
|
from .transforms import (BEiTMaskGenerator, Lighting, RandomAug,
|
||||||
RandomGaussianBlur,
|
RandomGaussianBlur, RandomPatchWithLabels,
|
||||||
RandomResizedCropAndInterpolationWithTwoPic,
|
RandomResizedCropAndInterpolationWithTwoPic,
|
||||||
RandomSolarize, SimMIMMaskGenerator)
|
RandomRotationWithLabels, RandomSolarize,
|
||||||
|
SimMIMMaskGenerator)
|
||||||
from .wrappers import MultiView
|
from .wrappers import MultiView
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug',
|
'RandomGaussianBlur', 'Lighting', 'RandomSolarize', 'RandomAug',
|
||||||
'SimMIMMaskGenerator', 'BEiTMaskGenerator',
|
'SimMIMMaskGenerator', 'BEiTMaskGenerator',
|
||||||
'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
|
'RandomResizedCropAndInterpolationWithTwoPic', 'PackSelfSupInputs',
|
||||||
'MultiView'
|
'MultiView', 'RandomRotationWithLabels', 'RandomPatchWithLabels'
|
||||||
]
|
]
|
||||||
|
@ -11,6 +11,7 @@ from mmcv.image import adjust_lighting, solarize
|
|||||||
from mmcv.transforms import BaseTransform
|
from mmcv.transforms import BaseTransform
|
||||||
from PIL import Image, ImageFilter
|
from PIL import Image, ImageFilter
|
||||||
from timm.data import create_transform
|
from timm.data import create_transform
|
||||||
|
from torchvision.transforms import RandomCrop
|
||||||
|
|
||||||
from mmselfsup.registry import TRANSFORMS
|
from mmselfsup.registry import TRANSFORMS
|
||||||
|
|
||||||
@ -518,3 +519,124 @@ class RandomSolarize(BaseTransform):
|
|||||||
repr_str += f'(threshold = {self.threshold}, '
|
repr_str += f'(threshold = {self.threshold}, '
|
||||||
repr_str += f'prob = {self.prob})'
|
repr_str += f'prob = {self.prob})'
|
||||||
return repr_str
|
return repr_str
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
class RandomRotationWithLabels(BaseTransform):
    """Build the four rotated views used for rotation prediction.

    Required Keys:

    - img

    Modified Keys:

    - img

    Added Keys:

    - rot_label

    The image is rotated by 0, 90, 180 and 270 degrees and the resulting
    views are labelled ``0, 1, 2, 3`` correspondingly.

    Note:
        ``transform`` returns a newly built dict that holds only ``img`` and
        ``rot_label``; other keys of the incoming results are not carried
        over.
    """

    def __init__(self) -> None:
        pass

    def _rotate(self, img: torch.Tensor):
        """Produce the 0, 90, 180 and 270 degree rotations of an image.

        Args:
            img (Tensor): input image of shape (C, H, W).

        Returns:
            list[Tensor]: The four rotated images, the unrotated input
            first.
        """
        # Rotate in the spatial plane (dims 1 and 2) by 1, 2 and 3 quarter
        # turns; the unrotated input itself is kept as the first entry.
        quarter_turns = [
            torch.rot90(img, k=turns, dims=(1, 2)) for turns in (1, 2, 3)
        ]
        return [img] + quarter_turns

    def transform(self, results: Dict) -> Dict:
        """Replace the image by its four stacked rotations plus labels.

        Args:
            results (dict): Must contain ``img``, which is transposed from
                channel-last to channel-first before rotating. The rotated
                views are stacked, so they must all share one shape (i.e.
                the image has to be square).

        Returns:
            dict: ``img`` as an array of the four stacked views and
            ``rot_label`` as ``np.array([0, 1, 2, 3])``.
        """
        chw = torch.from_numpy(np.transpose(results['img'], (2, 0, 1)))
        views = torch.stack(self._rotate(chw), dim=0)
        labels = np.array([0, 1, 2, 3])
        return dict(img=views.numpy(), rot_label=labels)

    def __repr__(self) -> str:
        return self.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
@TRANSFORMS.register_module()
class RandomPatchWithLabels(BaseTransform):
    """Relative patch location.

    Required Keys:

    - img

    Modified Keys:

    - img

    Added Keys:

    - patch_label

    Crops the image into a 3 x 3 grid of patches and concatenates every
    surrounding patch with the center one along the channel axis. Finally
    gives labels `0, 1, 2, 3, 4, 5, 6, 7`.

    Note:
        ``transform`` returns a newly built dict that holds only ``img`` and
        ``patch_label``; other keys of the incoming results are not carried
        over.
    """

    def __init__(self) -> None:
        pass

    def _image_to_patches(self, img: Image.Image):
        """Crop split_per_side x split_per_side patches from input image.

        The image is divided into a regular grid; inside every grid cell a
        patch that is ``patch_jitter`` pixels smaller per side is cropped at
        a random position.

        Args:
            img (PIL Image): input image.

        Returns:
            list[np.ndarray]: The cropped patches in row-major grid order,
            each transposed to channel-first layout.
        """
        split_per_side = 3  # split of patches per image side
        patch_jitter = 21  # jitter of each patch from each grid
        # PIL reports ``size`` as (width, height), not (height, width).
        w, h = img.size
        h_grid = h // split_per_side
        w_grid = w // split_per_side
        h_patch = h_grid - patch_jitter
        w_patch = w_grid - patch_jitter
        assert h_patch > 0 and w_patch > 0, (
            f'image of size (h={h}, w={w}) is too small to cut '
            f'{split_per_side}x{split_per_side} patches with a jitter of '
            f'{patch_jitter} pixels')

        # The crop size is identical for every cell, so build it once.
        random_crop = RandomCrop((h_patch, w_patch))
        patches = []
        for i in range(split_per_side):
            for j in range(split_per_side):
                # F.crop takes (top, left, height, width).
                cell = F.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
                patch = random_crop(cell)
                patches.append(np.transpose(np.asarray(patch), (2, 0, 1)))
        return patches

    def transform(self, results: Dict) -> Dict:
        """Pair each surrounding patch with the center patch.

        Args:
            results (dict): Must contain ``img``, an array accepted by
                ``Image.fromarray``.

        Returns:
            dict: ``img`` with the eight stacked patch pairs, shape
            8 x (2C) x H x W, and ``patch_label`` holding the labels 0-7.
        """
        img = Image.fromarray(results['img'])
        patches = self._image_to_patches(img)
        center_idx = 4  # middle cell of the 3 x 3 grid
        center = patches[center_idx]
        # create a list of patch pairs
        perms = [
            np.concatenate((patch, center), axis=0)
            for idx, patch in enumerate(patches) if idx != center_idx
        ]
        # create corresponding labels for patch pairs
        patch_labels = np.array([0, 1, 2, 3, 4, 5, 6, 7])
        results = dict(
            img=np.stack(perms, axis=0),
            patch_label=patch_labels)  # 8(2C)HW, 8
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        return repr_str
|
||||||
|
@ -5,9 +5,9 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from mmselfsup.datasets.pipelines import (
|
from mmselfsup.datasets.pipelines import (
|
||||||
BEiTMaskGenerator, Lighting, RandomGaussianBlur,
|
BEiTMaskGenerator, Lighting, RandomGaussianBlur, RandomPatchWithLabels,
|
||||||
RandomResizedCropAndInterpolationWithTwoPic, RandomSolarize,
|
RandomResizedCropAndInterpolationWithTwoPic, RandomRotationWithLabels,
|
||||||
SimMIMMaskGenerator)
|
RandomSolarize, SimMIMMaskGenerator)
|
||||||
|
|
||||||
|
|
||||||
def test_simmim_mask_gen():
|
def test_simmim_mask_gen():
|
||||||
@ -118,3 +118,27 @@ def test_random_solarize():
|
|||||||
|
|
||||||
results = transform(results)
|
results = transform(results)
|
||||||
assert results['img'].shape == original_img.shape
|
assert results['img'].shape == original_img.shape
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_rotation():
    """Check shapes, labels and pixel content of the four rotated views."""
    transform = dict()
    module = RandomRotationWithLabels(**transform)
    # Scale to [0, 255] before casting: torch.rand yields values in [0, 1),
    # which would all truncate to 0 as uint8 and make every view identical.
    image = (torch.rand((224, 224, 3)) * 255).numpy().astype(np.uint8)
    results = {'img': image}
    results = module(results)

    # test transform
    assert list(results['img'].shape) == [4, 3, 224, 224]
    assert list(results['rot_label'].shape) == [4]
    assert results['rot_label'].tolist() == [0, 1, 2, 3]

    # view k must be the channel-first image rotated by k quarter turns
    chw = image.transpose(2, 0, 1)
    for k in range(4):
        expected = np.rot90(chw, k, axes=(1, 2))
        assert np.array_equal(results['img'][k], expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_patch():
    """Check shapes, labels and the shared center patch of the pairs."""
    transform = dict()
    module = RandomPatchWithLabels(**transform)
    # Scale to [0, 255] before casting: torch.rand yields values in [0, 1),
    # which would all truncate to 0 as uint8 and give an all-black image.
    image = (torch.rand((224, 224, 3)) * 255).numpy().astype(np.uint8)
    results = {'img': image}
    results = module(results)

    # test transform
    # 224 // 3 = 74 pixel grid cells, minus 21 pixels of jitter -> 53
    assert list(results['img'].shape) == [8, 6, 53, 53]
    assert list(results['patch_label'].shape) == [8]
    assert results['patch_label'].tolist() == [0, 1, 2, 3, 4, 5, 6, 7]

    # the center patch fills the last three channels of every pair
    center = results['img'][0, 3:]
    for pair in results['img']:
        assert np.array_equal(pair[3:], center)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user