mmclassification/projects/dino/dataset/transform/processing.py

92 lines
3.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Tuple

from mmcv.transforms import RandomApply  # noqa: E501
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale
from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur,
                                            RandomResizedCrop, Solarize)
from mmpretrain.registry import TRANSFORMS
@TRANSFORMS.register_module()
class DINOMultiCrop(BaseTransform):
    """Multi-crop transform for DINO.

    Replaces ``results['img']`` with a list of augmented views of the input
    image: two "global" crops (``RandomResizedCrop`` with scale 224) followed
    by ``local_crops_number`` "local" crops (``RandomResizedCrop`` with
    scale 96). Every view additionally goes through a random horizontal
    flip, colour jitter, random grayscale and Gaussian blur; the second
    global view is also solarized with probability 0.2.

    Args:
        global_crops_scale (Tuple[float, float]): Forwarded as
            ``crop_ratio_range`` to the ``RandomResizedCrop`` of both global
            pipelines.
        local_crops_scale (Tuple[float, float]): Forwarded as
            ``crop_ratio_range`` to the ``RandomResizedCrop`` of the local
            pipeline.
        local_crops_number (int): Number of local crops.
    """

    def __init__(self, global_crops_scale: Tuple[float, float],
                 local_crops_scale: Tuple[float, float],
                 local_crops_number: int) -> None:
        super().__init__()
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number

        # Photometric augmentations shared by all three pipelines.
        flip_and_color_jitter = Compose([
            RandomFlip(prob=0.5, direction='horizontal'),
            RandomApply([
                ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ],
                        prob=0.8),
            RandomGrayscale(
                prob=0.2,
                keep_channels=True,
                channel_weights=(0.114, 0.587, 0.2989),
            )
        ])

        # NOTE(review): blur is applied with prob=1.0 in all three pipelines;
        # confirm this matches the intended DINO recipe.
        self.global_transform_1 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            self._random_blur(),
        ])
        self.global_transform_2 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            self._random_blur(),
            Solarize(thr=128, prob=0.2),
        ])
        self.local_transform = Compose([
            RandomResizedCrop(
                96,
                crop_ratio_range=local_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            self._random_blur(),
        ])

    @staticmethod
    def _random_blur() -> GaussianBlur:
        """Build an always-applied blur with a per-call random radius.

        Passing ``radius=random.uniform(0.1, 2.0)`` would draw the radius a
        single time at construction and reuse it for every image. Giving
        ``magnitude_range`` together with ``magnitude_std='inf'`` instead
        makes the transform draw the radius from the range on each call.
        """
        return GaussianBlur(
            prob=1.0, magnitude_range=(0.1, 2.0), magnitude_std='inf')

    def transform(self, results: dict) -> dict:
        """Generate all global and local crops of ``results['img']``.

        Args:
            results (dict): Result dict; the source image is read from the
                ``'img'`` key.

        Returns:
            dict: The same dict, with ``'img'`` replaced by a list holding
            the two global crops followed by the local crops.
        """
        ori_img = results['img']
        pipelines = [self.global_transform_1, self.global_transform_2]
        pipelines += [self.local_transform] * self.local_crops_number

        crops = []
        for pipeline in pipelines:
            # The previous pipeline overwrote ``results['img']`` with its
            # crop, so restore the source image before every view.
            results['img'] = ori_img
            crops.append(pipeline(results)['img'])

        results['img'] = crops
        return results

    def __repr__(self) -> str:
        """Return the class name followed by the constructor arguments."""
        return (f'{self.__class__.__name__}('
                f'global_crops_scale = {self.global_crops_scale}, '
                f'local_crops_scale = {self.local_crops_scale}, '
                f'local_crops_number = {self.local_crops_number})')