92 lines
3.2 KiB
Python
92 lines
3.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import random
from typing import Tuple

from mmcv.transforms import RandomApply  # noqa: E501
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale

from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur,
                                            RandomResizedCrop, Solarize)
from mmpretrain.registry import TRANSFORMS
|
|
|
|
|
|
@TRANSFORMS.register_module()
class DINOMultiCrop(BaseTransform):
    """Multi-crop transform for DINO.

    This module applies the multi-crop transform for DINO: it turns one input
    image into ``2 + local_crops_number`` independently augmented views.

    - Two global views, cropped to ``global_crops_size``. Both go through a
      random horizontal flip, color jitter / grayscale and a Gaussian blur;
      the second one is additionally solarized with probability 0.2.
    - ``local_crops_number`` local views, cropped to ``local_crops_size``,
      with the same flip / color jitter / blur stages.

    After :meth:`transform`, ``results['img']`` holds the list
    ``[global_1, global_2, local_1, ..., local_n]``.

    Args:
        global_crops_scale (Tuple[float, float]): Crop ratio range of the
            global crops, passed as ``crop_ratio_range`` to
            :class:`RandomResizedCrop`.
        local_crops_scale (Tuple[float, float]): Crop ratio range of the
            local crops, passed as ``crop_ratio_range`` to
            :class:`RandomResizedCrop`.
        local_crops_number (int): Number of local crops.
        global_crops_size (int): Output size of the global crops.
            Defaults to 224.
        local_crops_size (int): Output size of the local crops.
            Defaults to 96.
    """

    def __init__(self,
                 global_crops_scale: Tuple[float, float],
                 local_crops_scale: Tuple[float, float],
                 local_crops_number: int,
                 global_crops_size: int = 224,
                 local_crops_size: int = 96) -> None:
        super().__init__()
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        # One Compose instance is shared by all three pipelines below.
        flip_and_color_jitter = Compose([
            RandomFlip(prob=0.5, direction='horizontal'),
            RandomApply(
                [
                    ColorJitter(
                        brightness=0.4,
                        contrast=0.4,
                        saturation=0.2,
                        hue=0.1)
                ],
                prob=0.8),
            RandomGrayscale(
                prob=0.2,
                keep_channels=True,
                channel_weights=(0.114, 0.587, 0.2989),
            )
        ])

        # NOTE(review): the reference DINO recipe blurs global view 1 with
        # p=1.0, global view 2 with p=0.1 and local views with p=0.5; this
        # implementation blurs every view with p=1.0 (kept unchanged) --
        # confirm this is intended.
        self.global_transform_1 = Compose([
            self._random_resized_crop(global_crops_size, global_crops_scale),
            flip_and_color_jitter,
            self._gaussian_blur(),
        ])

        self.global_transform_2 = Compose([
            self._random_resized_crop(global_crops_size, global_crops_scale),
            flip_and_color_jitter,
            self._gaussian_blur(),
            Solarize(thr=128, prob=0.2),
        ])

        self.local_transform = Compose([
            self._random_resized_crop(local_crops_size, local_crops_scale),
            flip_and_color_jitter,
            self._gaussian_blur(),
        ])

    @staticmethod
    def _random_resized_crop(size: int,
                             scale: Tuple[float, float]) -> RandomResizedCrop:
        """Build the bicubic random-resized crop used by every view."""
        return RandomResizedCrop(
            size, crop_ratio_range=scale, interpolation='bicubic')

    @staticmethod
    def _gaussian_blur(prob: float = 1.0) -> GaussianBlur:
        """Build a Gaussian blur whose radius is re-drawn on every call.

        Passing a fixed ``radius`` (e.g. ``random.uniform(0.1, 2.0)``
        evaluated here) would freeze one radius for the whole run. With
        ``magnitude_range`` and ``magnitude_std='inf'``, :class:`GaussianBlur`
        samples the radius uniformly in ``[0.1, 2.0]`` each time it runs.
        """
        return GaussianBlur(
            magnitude_range=(0.1, 2.0), magnitude_std='inf', prob=prob)

    def transform(self, results: dict) -> dict:
        """Replace ``results['img']`` with the list of multi-crop views.

        ``results['img']`` is reset to the original image before each
        pipeline runs, so every view is derived from the input image rather
        than from the previous view.

        Args:
            results (dict): Result dict containing the ``'img'`` key.

        Returns:
            dict: The same ``results`` dict with ``'img'`` set to
            ``[global_1, global_2, local_1, ..., local_n]``. Any other keys
            written by the sub-transforms keep the values from the last view.
        """
        ori_img = results['img']
        pipelines = [self.global_transform_1, self.global_transform_2]
        pipelines += [self.local_transform] * self.local_crops_number

        crops = []
        for pipeline in pipelines:
            results['img'] = ori_img
            crops.append(pipeline(results)['img'])

        results['img'] = crops
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(global_crops_scale = {self.global_crops_scale}, '
        repr_str += f'local_crops_scale = {self.local_crops_scale}, '
        repr_str += f'local_crops_number = {self.local_crops_number}, '
        repr_str += f'global_crops_size = {self.global_crops_size}, '
        repr_str += f'local_crops_size = {self.local_crops_size})'
        return repr_str
|