diff --git a/configs/_base_/datasets/imagenet_bs256_beitv2.py b/configs/_base_/datasets/imagenet_bs256_beitv2.py index 7e153d8c..34daf9d5 100644 --- a/configs/_base_/datasets/imagenet_bs256_beitv2.py +++ b/configs/_base_/datasets/imagenet_bs256_beitv2.py @@ -7,7 +7,7 @@ data_preprocessor = dict( std=[58.395, 57.12, 57.375], second_mean=[127.5, 127.5, 127.5], second_std=[127.5, 127.5, 127.5], - bgr_to_rgb=True) + to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), diff --git a/configs/_base_/datasets/imagenet_bs256_simmim_192.py b/configs/_base_/datasets/imagenet_bs256_simmim_192.py index 71e5e679..2d91665c 100644 --- a/configs/_base_/datasets/imagenet_bs256_simmim_192.py +++ b/configs/_base_/datasets/imagenet_bs256_simmim_192.py @@ -5,7 +5,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), diff --git a/configs/_base_/datasets/imagenet_bs32_byol.py b/configs/_base_/datasets/imagenet_bs32_byol.py index 935a1935..6bb7b75f 100644 --- a/configs/_base_/datasets/imagenet_bs32_byol.py +++ b/configs/_base_/datasets/imagenet_bs32_byol.py @@ -5,7 +5,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) view_pipeline1 = [ dict( diff --git a/configs/_base_/datasets/imagenet_bs32_mocov2.py b/configs/_base_/datasets/imagenet_bs32_mocov2.py index c43c01a2..fa710ad4 100644 --- a/configs/_base_/datasets/imagenet_bs32_mocov2.py +++ b/configs/_base_/datasets/imagenet_bs32_mocov2.py @@ -5,7 +5,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) # The difference between mocov2 and mocov1 is the transforms in the pipeline view_pipeline = [ diff --git a/configs/_base_/datasets/imagenet_bs32_simclr.py b/configs/_base_/datasets/imagenet_bs32_simclr.py index 2a6c488b..c04f19df 100644 --- a/configs/_base_/datasets/imagenet_bs32_simclr.py +++ b/configs/_base_/datasets/imagenet_bs32_simclr.py @@ -5,7 +5,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) view_pipeline = [ dict(type='RandomResizedCrop', size=224, backend='pillow'), diff --git a/configs/_base_/datasets/imagenet_bs512_mae.py b/configs/_base_/datasets/imagenet_bs512_mae.py index abb48dea..b37776a6 100644 --- a/configs/_base_/datasets/imagenet_bs512_mae.py +++ b/configs/_base_/datasets/imagenet_bs512_mae.py @@ -5,7 +5,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), diff --git a/configs/_base_/datasets/imagenet_bs512_mocov3.py b/configs/_base_/datasets/imagenet_bs512_mocov3.py index a69db4e9..c7a746cc 100644 --- a/configs/_base_/datasets/imagenet_bs512_mocov3.py +++ b/configs/_base_/datasets/imagenet_bs512_mocov3.py @@ -5,7 +5,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) view_pipeline1 = [ dict( diff --git a/configs/beit/beit_beit-base-p16_8xb256-amp-coslr-300e_in1k.py b/configs/beit/beit_beit-base-p16_8xb256-amp-coslr-300e_in1k.py index bf580cb2..fc773822 100644 --- 
a/configs/beit/beit_beit-base-p16_8xb256-amp-coslr-300e_in1k.py +++ b/configs/beit/beit_beit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -7,9 +7,9 @@ data_preprocessor = dict( type='TwoNormDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - second_mean=[-20.4, -20.4, -20.4], - second_std=[204., 204., 204.], - bgr_to_rgb=True) + second_mean=[-31.875, -31.875, -31.875], + second_std=[318.75, 318.75, 318.75], + to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), diff --git a/configs/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py b/configs/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py index c06ecd79..7d35c20c 100644 --- a/configs/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py +++ b/configs/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -5,10 +5,12 @@ dataset_type = 'ImageNet' data_root = 'data/imagenet/' file_client_args = dict(backend='disk') data_preprocessor = dict( - type='CAEDataPreprocessor', - mean=[124, 117, 104], - std=[59, 58, 58], - bgr_to_rgb=True) + type='TwoNormDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + second_mean=[-31.875, -31.875, -31.875], + second_std=[318.75, 318.75, 318.75], + to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile', file_client_args=file_client_args), @@ -75,7 +77,7 @@ model = dict( type='mmselfsup.CAEDataPreprocessor', mean=[124, 117, 104], std=[59, 58, 58], - bgr_to_rgb=True), + to_rgb=True), base_momentum=0.0) # optimizer wrapper diff --git a/configs/densecl/densecl_resnet50_8xb32-coslr-200e_in1k.py b/configs/densecl/densecl_resnet50_8xb32-coslr-200e_in1k.py index 9ebc6557..32ba3bd9 100644 --- a/configs/densecl/densecl_resnet50_8xb32-coslr-200e_in1k.py +++ b/configs/densecl/densecl_resnet50_8xb32-coslr-200e_in1k.py @@ -11,10 +11,6 @@ model = dict( feat_dim=128, momentum=0.999, loss_lambda=0.5, - data_preprocessor=dict( - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - bgr_to_rgb=True), backbone=dict( type='ResNet', depth=50, diff --git a/configs/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k.py b/configs/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k.py index 04bbafac..89a63d49 100644 --- a/configs/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k.py +++ b/configs/maskfeat/maskfeat_vit-base-p16_8xb256-amp-coslr-300e_in1k.py @@ -7,7 +7,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) train_pipeline = [ dict(type='LoadImageFromFile'), @@ -50,7 +50,7 @@ model = dict( data_preprocessor=dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True), + to_rgb=True), backbone=dict(type='MaskFeatViT', arch='b', patch_size=16), neck=dict( type='LinearNeck', diff --git a/configs/mixmim/mixmim_mixmim-base_16xb128-coslr-300e_in1k.py b/configs/mixmim/mixmim_mixmim-base_16xb128-coslr-300e_in1k.py index ade6d746..45cae815 100644 --- a/configs/mixmim/mixmim_mixmim-base_16xb128-coslr-300e_in1k.py +++ b/configs/mixmim/mixmim_mixmim-base_16xb128-coslr-300e_in1k.py @@ -36,7 +36,7 @@ model = dict( data_preprocessor=dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True), + to_rgb=True), backbone=dict( type='MixMIMTransformerPretrain', arch='B', diff --git a/configs/mocov2/mocov2_resnet50_8xb32-coslr-200e_in1k.py b/configs/mocov2/mocov2_resnet50_8xb32-coslr-200e_in1k.py index 04c4e6a6..a090bcbb 100644 --- a/configs/mocov2/mocov2_resnet50_8xb32-coslr-200e_in1k.py 
+++ b/configs/mocov2/mocov2_resnet50_8xb32-coslr-200e_in1k.py @@ -13,7 +13,7 @@ model = dict( data_preprocessor=dict( mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375), - bgr_to_rgb=True), + to_rgb=True), backbone=dict( type='ResNet', depth=50, diff --git a/configs/swav/swav_resnet50_8xb32-mcrop-coslr-200e_in1k-224px-96px.py b/configs/swav/swav_resnet50_8xb32-mcrop-coslr-200e_in1k-224px-96px.py index b3c86616..317dc2ff 100644 --- a/configs/swav/swav_resnet50_8xb32-mcrop-coslr-200e_in1k-224px-96px.py +++ b/configs/swav/swav_resnet50_8xb32-mcrop-coslr-200e_in1k-224px-96px.py @@ -10,7 +10,7 @@ data_preprocessor = dict( type='SelfSupDataPreprocessor', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], - bgr_to_rgb=True) + to_rgb=True) num_crops = [2, 6] color_distort_strength = 1.0 @@ -92,7 +92,7 @@ model = dict( data_preprocessor=dict( mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375), - bgr_to_rgb=True), + to_rgb=True), backbone=dict( type='ResNet', depth=50, diff --git a/mmpretrain/models/selfsup/simclr.py b/mmpretrain/models/selfsup/simclr.py new file mode 100644 index 00000000..6123382b --- /dev/null +++ b/mmpretrain/models/selfsup/simclr.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, List, Tuple + +import torch +from mmengine.dist import all_gather, get_rank + + +class GatherLayer(torch.autograd.Function): + """Gather tensors from all process, supporting backward propagation.""" + + @staticmethod + def forward(ctx: Any, input: torch.Tensor) -> Tuple[List]: + ctx.save_for_backward(input) + output = all_gather(input) + return tuple(output) + + @staticmethod + def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor: + input, = ctx.saved_tensors + grad_out = torch.zeros_like(input) + grad_out[:] = grads[get_rank()] + return grad_out diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py index b5391a72..4d3300ff 100644 --- a/mmpretrain/models/utils/__init__.py +++ b/mmpretrain/models/utils/__init__.py @@ -4,9 +4,12 @@ from .attention import (BEiTAttention, ChannelMultiheadAttention, MultiheadAttention, PromptMultiheadAttention, ShiftWindowMSA, WindowMSA, WindowMSAV2) from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix +from .batch_shuffle import batch_shuffle_ddp, batch_unshuffle_ddp from .channel_shuffle import channel_shuffle from .clip_generator_helper import build_clip_model -from .data_preprocessor import ClsDataPreprocessor +from .data_preprocessor import (ClsDataPreprocessor, SelfSupDataPreprocessor, + TwoNormDataPreprocessor, VideoDataPreprocessor) +from .ema import CosineEMA from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed, resize_relative_position_bias_table) from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple @@ -31,5 +34,7 @@ __all__ = [ 'PositionEncodingFourier', 'LeAttention', 'GRN', 'LayerNorm2d', 'build_norm_layer', 'CrossMultiheadAttention', 'build_2d_sincos_position_embedding', 'PromptMultiheadAttention', - 'NormEMAVectorQuantizer', 'build_clip_model' + 'NormEMAVectorQuantizer', 'build_clip_model', 'batch_shuffle_ddp', + 'batch_unshuffle_ddp', 'SelfSupDataPreprocessor', + 'TwoNormDataPreprocessor', 'VideoDataPreprocessor', 'CosineEMA' ] diff --git a/mmpretrain/models/utils/batch_shuffle.py b/mmpretrain/models/utils/batch_shuffle.py new file mode 100644 index 00000000..a0b03c5f --- /dev/null +++ b/mmpretrain/models/utils/batch_shuffle.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. 
All rights reserved. +from typing import Tuple + +import torch +from mmengine.dist import all_gather, broadcast, get_rank + + +@torch.no_grad() +def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Batch shuffle, for making use of BatchNorm. + + Args: + x (torch.Tensor): Data in each GPU. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation. + - x_gather[idx_this]: Shuffled data. + - idx_unshuffle: Index for restoring. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all) + + # broadcast to all gpus + broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + +@torch.no_grad() +def batch_unshuffle_ddp(x: torch.Tensor, + idx_unshuffle: torch.Tensor) -> torch.Tensor: + """Undo batch shuffle. + + Args: + x (torch.Tensor): Data in each GPU. + idx_unshuffle (torch.Tensor): Index for restoring. + + Returns: + torch.Tensor: Output of unshuffle operation. + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = torch.cat(all_gather(x), dim=0) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] diff --git a/mmpretrain/models/utils/data_preprocessor.py b/mmpretrain/models/utils/data_preprocessor.py index 6b65c972..d0317e09 100644 --- a/mmpretrain/models/utils/data_preprocessor.py +++ b/mmpretrain/models/utils/data_preprocessor.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import math from numbers import Number -from typing import Optional, Sequence +from typing import List, Optional, Sequence, Tuple, Union import torch import torch.nn.functional as F -from mmengine.model import BaseDataPreprocessor, stack_batch +from mmengine.model import (BaseDataPreprocessor, ImgDataPreprocessor, + stack_batch) from mmpretrain.registry import MODELS from mmpretrain.structures import (DataSample, MultiTaskDataSample, @@ -194,3 +195,294 @@ class ClsDataPreprocessor(BaseDataPreprocessor): data_samples = self.cast_data(data_samples) return {'inputs': inputs, 'data_samples': data_samples} + + +@MODELS.register_module() +class SelfSupDataPreprocessor(ImgDataPreprocessor): + """Image pre-processor for operations, like normalization and bgr to rgb. + + Compared with the :class:`mmengine.ImgDataPreprocessor`, this module treats + each item in `inputs` of input data as a list, instead of torch.Tensor. + Besides, Add key ``to_rgb`` to align with :class:`ClsDataPreprocessor`. 
+ """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + bgr_to_rgb=bgr_to_rgb, + rgb_to_bgr=rgb_to_bgr, + non_blocking=non_blocking) + + self._channel_conversion = to_rgb or bgr_to_rgb or rgb_to_bgr + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + assert isinstance(data, + dict), 'Please use default_collate in dataloader, \ + instead of pseudo_collate.' + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # Convert to float after channel conversion to ensure + # efficiency + batch_inputs = [input_.float() for input_ in batch_inputs] + + # Normalization. Here is what is different from + # :class:`mmengine.ImgDataPreprocessor`. Since there are multiple views + # for an image for some algorithms, e.g. SimCLR, each item in inputs + # is a list, containing multi-views for an image. + if self._enable_normalize: + batch_inputs = [(_input - self.mean) / self.std + for _input in batch_inputs] + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class TwoNormDataPreprocessor(SelfSupDataPreprocessor): + """Image pre-processor for CAE, BEiT v1/v2, etc. + + Compared with the :class:`mmselfsup.SelfSupDataPreprocessor`, this module + will normalize the prediction image and target image with different + normalization parameters. + + Args: + mean (Sequence[float or int], optional): The pixel mean of image + channels. If ``to_rgb=True`` it means the mean value of R, G, B + channels. If the length of `mean` is 1, it means all channels have + the same mean value, or the input is a gray image. If it is not + specified, images will not be normalized. Defaults to None. + std (Sequence[float or int], optional): The pixel standard deviation of + image channels. If ``to_rgb=True`` it means the standard deviation + of R, G, B channels. If the length of `std` is 1, it means all + channels have the same standard deviation, or the input is a gray + image. If it is not specified, images will not be normalized. + Defaults to None. + second_mean (Sequence[float or int], optional): The description is + like ``mean``, it can be customized for targe image. Defaults to + None. + second_std (Sequence[float or int], optional): The description is + like ``std``, it can be customized for targe image. Defaults to + None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. 
+ to_rgb (bool): Whether to convert image from BGR to RGB. + Defaults to False. + non_blocking (bool): Whether to block the current process when transferring + data to device. Defaults to False. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + second_mean: Sequence[Union[float, int]] = None, + second_std: Sequence[Union[float, int]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + non_blocking: Optional[bool] = False): + super().__init__( + mean=mean, + std=std, + pad_size_divisor=pad_size_divisor, + pad_value=pad_value, + to_rgb=to_rgb, + non_blocking=non_blocking) + assert (second_mean is not None) and (second_std is not None), ( + '`second_mean` and `second_std` should not be None while using ' + '`TwoNormDataPreprocessor`') + assert len(second_mean) == 3 or len(second_mean) == 1, ( + '`second_mean` should have 1 or 3 values, to be compatible with ' + f'RGB or gray image, but got {len(second_mean)} values') + assert len(second_std) == 3 or len(second_std) == 1, ( + '`second_std` should have 1 or 3 values, to be compatible with RGB ' # type: ignore # noqa: E501 + f'or gray image, but got {len(second_std)} values') # type: ignore + + self.register_buffer('second_mean', + torch.tensor(second_mean).view(-1, 1, 1), False) + self.register_buffer('second_std', + torch.tensor(second_std).view(-1, 1, 1), False) + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization, padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[torch.Tensor, Optional[list]]: Data in the same format as the + model input. + """ + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + # channel transform + if self._channel_conversion: + batch_inputs = [ + _input[:, [2, 1, 0], ...] for _input in batch_inputs + ] + + # Convert to float after channel conversion to ensure + # efficiency + batch_inputs = [input_.float() for input_ in batch_inputs] + + # Normalization. Here is what is different from + # :class:`mmselfsup.SelfSupDataPreprocessor`. Normalize the target + # image and prediction image with different normalization params + if self._enable_normalize: + batch_inputs = [ + (batch_inputs[0] - self.mean) / self.std, + (batch_inputs[1] - self.second_mean) / self.second_std + ] + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} + + +@MODELS.register_module() +class VideoDataPreprocessor(BaseDataPreprocessor): + """Video pre-processor for operations, like normalization and bgr to rgb + conversion. + + Compared with the :class:`mmaction.ActionDataPreprocessor`, this module + treats each item in `inputs` of input data as a list, instead of + torch.Tensor. + + Args: + mean (Sequence[float or int], optional): The pixel mean of channels + of images or stacked optical flow. Defaults to None. + std (Sequence[float or int], optional): The pixel standard deviation + of channels of images or stacked optical flow. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value.
Defaults to 0. + to_rgb (bool): Whether to convert image from BGR to RGB. + Defaults to False. + format_shape (str): Format shape of input data. + Defaults to ``'NCHW'``. + """ + + def __init__(self, + mean: Optional[Sequence[Union[float, int]]] = None, + std: Optional[Sequence[Union[float, int]]] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + to_rgb: bool = False, + format_shape: str = 'NCHW') -> None: + super().__init__() + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + self.to_rgb = to_rgb + self.format_shape = format_shape + + if mean is not None: + assert std is not None, 'To enable the normalization in ' \ + 'preprocessing, please specify both ' \ + '`mean` and `std`.' + # Enable the normalization in preprocessing. + self._enable_normalize = True + if self.format_shape == 'NCHW': + normalizer_shape = (-1, 1, 1) + elif self.format_shape == 'NCTHW': + normalizer_shape = (-1, 1, 1, 1) + else: + raise ValueError(f'Invalid format shape: {format_shape}') + + self.register_buffer( + 'mean', + torch.tensor(mean, dtype=torch.float32).view(normalizer_shape), + False) + self.register_buffer( + 'std', + torch.tensor(std, dtype=torch.float32).view(normalizer_shape), + False) + else: + self._enable_normalize = False + + def forward( + self, + data: dict, + training: bool = False + ) -> Tuple[List[torch.Tensor], Optional[list]]: + """Performs normalization、padding and bgr2rgb conversion based on + ``BaseDataPreprocessor``. + + Args: + data (dict): data sampled from dataloader. + training (bool): Whether to enable training time augmentation. If + subclasses override this method, they can perform different + preprocessing strategies for training and testing based on the + value of ``training``. + Returns: + Tuple[List[torch.Tensor], Optional[list]]: Data in the same format + as the model input. + """ + + data = [val for _, val in data.items()] + batch_inputs, batch_data_samples = self.cast_data(data) + + # ------ To RGB ------ + if self.to_rgb: + if self.format_shape == 'NCHW': + batch_inputs = [ + batch_input[..., [2, 1, 0], :, :] + for batch_input in batch_inputs + ] + elif self.format_shape == 'NCTHW': + batch_inputs = [ + batch_input[..., [2, 1, 0], :, :, :] + for batch_input in batch_inputs + ] + else: + raise ValueError(f'Invalid format shape: {self.format_shape}') + + # -- Normalization --- + if self._enable_normalize: + batch_inputs = [(batch_input - self.mean) / self.std + for batch_input in batch_inputs] + else: + batch_inputs = [ + batch_input.to(torch.float32) for batch_input in batch_inputs + ] + + return {'inputs': batch_inputs, 'data_samples': batch_data_samples} diff --git a/mmpretrain/models/utils/ema.py b/mmpretrain/models/utils/ema.py new file mode 100644 index 00000000..3f3869b4 --- /dev/null +++ b/mmpretrain/models/utils/ema.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from math import cos, pi +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.logging import MessageHub +from mmengine.model import ExponentialMovingAverage + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class CosineEMA(ExponentialMovingAverage): + """CosineEMA is implemented for updating momentum parameter, used in BYOL, + MoCoV3, etc. + + The momentum parameter is updated with cosine annealing, including momentum + adjustment following: + + .. math:: + m = m_1 - (m_1 - m_0) * (cos(pi * k / K) + 1) / 2 + + where :math:`k` is the current step, :math:`K` is the total steps. 
+ + Args: + model (nn.Module): The model to be averaged. + momentum (float): The momentum used for updating ema parameter. + Ema's parameter are updated with the formula: + `averaged_param = momentum * averaged_param + (1-momentum) * + source_param`. Defaults to 0.996. + end_momentum (float): The end momentum value for cosine annealing. + Defaults to 1. + interval (int): Interval between two updates. Defaults to 1. + device (torch.device, optional): If provided, the averaged model will + be stored on the :attr:`device`. Defaults to None. + update_buffers (bool): if True, it will compute running averages for + both the parameters and the buffers of the model. Defaults to + False. + """ + + def __init__(self, + model: nn.Module, + momentum: float = 0.996, + end_momentum: float = 1., + interval: int = 1, + device: Optional[torch.device] = None, + update_buffers: bool = False) -> None: + super().__init__( + model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) + self.end_momentum = end_momentum + + def avg_func(self, averaged_param: torch.Tensor, + source_param: torch.Tensor, steps: int) -> None: + """Compute the moving average of the parameters using the cosine + momentum strategy. + + Args: + averaged_param (Tensor): The averaged parameters. + source_param (Tensor): The source parameters. + steps (int): The number of times the parameters have been + updated. + Returns: + Tensor: The averaged parameters. + """ + message_hub = MessageHub.get_current_instance() + max_iters = message_hub.get_info('max_iters') + momentum = self.end_momentum - (self.end_momentum - self.momentum) * ( + cos(pi * steps / float(max_iters)) + 1) / 2 + averaged_param.mul_(momentum).add_(source_param, alpha=(1 - momentum)) diff --git a/mmpretrain/utils/__init__.py b/mmpretrain/utils/__init__.py index f110a784..328c01d7 100644 --- a/mmpretrain/utils/__init__.py +++ b/mmpretrain/utils/__init__.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .analyze import load_json_log from .collect_env import collect_env +from .misc import get_ori_model from .progress import track_on_main_process from .setup_env import register_all_modules __all__ = [ 'collect_env', 'register_all_modules', 'track_on_main_process', - 'load_json_log' + 'load_json_log', 'get_ori_model' ] diff --git a/mmpretrain/utils/misc.py b/mmpretrain/utils/misc.py new file mode 100644 index 00000000..cc532679 --- /dev/null +++ b/mmpretrain/utils/misc.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn +from mmengine.model import is_model_wrapper + + +def get_ori_model(model: nn.Module) -> nn.Module: + """Get original model if the input model is a model wrapper. + + Args: + model (nn.Module): A model may be a model wrapper. + + Returns: + nn.Module: The model without model wrapper. + """ + if is_model_wrapper(model): + return model.module + else: + return model diff --git a/tests/test_models/test_utils/test_data_preprocessor.py b/tests/test_models/test_utils/test_data_preprocessor.py index 1edfcfac..a0a2923d 100644 --- a/tests/test_models/test_utils/test_data_preprocessor.py +++ b/tests/test_models/test_utils/test_data_preprocessor.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from unittest import TestCase +import pytest import torch -from mmpretrain.models import ClsDataPreprocessor, RandomBatchAugment +from mmpretrain.models import (ClsDataPreprocessor, RandomBatchAugment, + SelfSupDataPreprocessor, + TwoNormDataPreprocessor, VideoDataPreprocessor) from mmpretrain.registry import MODELS from mmpretrain.structures import DataSample @@ -99,3 +102,111 @@ class TestClsDataPreprocessor(TestCase): processed_data = processor(data, training=True) self.assertIn('inputs', processed_data) self.assertIsNone(processed_data['data_samples']) + + +class TestSelfSupDataPreprocessor(TestCase): + + def test_to_rgb(self): + cfg = dict(type='SelfSupDataPreprocessor', to_rgb=True) + processor: SelfSupDataPreprocessor = MODELS.build(cfg) + self.assertTrue(processor._channel_conversion) + + fake_data = { + 'inputs': + [torch.randn((2, 3, 224, 224)), + torch.randn((2, 3, 224, 224))], + 'data_samples': [DataSample(), DataSample()] + } + inputs = processor(fake_data)['inputs'] + torch.testing.assert_allclose(fake_data['inputs'][0].flip(1).float(), + inputs[0]) + torch.testing.assert_allclose(fake_data['inputs'][1].flip(1).float(), + inputs[1]) + + def test_forward(self): + data_preprocessor = SelfSupDataPreprocessor( + to_rgb=True, mean=[124, 117, 104], std=[59, 58, 58]) + fake_data = { + 'inputs': [torch.randn((2, 3, 224, 224))], + 'data_samples': [DataSample(), DataSample()] + } + fake_output = data_preprocessor(fake_data) + self.assertEqual(len(fake_output['inputs']), 1) + self.assertEqual(len(fake_output['data_samples']), 2) + + +class TestTwoNormDataPreprocessor(TestCase): + + def test_assertion(self): + with pytest.raises(AssertionError): + _ = TwoNormDataPreprocessor( + to_rgb=True, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + ) + with pytest.raises(AssertionError): + _ = TwoNormDataPreprocessor( + to_rgb=True, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5), + second_std=(127.5, 127.5, 127.5), + ) + with pytest.raises(AssertionError): + _ = TwoNormDataPreprocessor( + to_rgb=True, + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5, 127.5), + second_std=(127.5, 127.5), + ) + + def test_forward(self): + data_preprocessor = dict( + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + second_mean=(127.5, 127.5, 127.5), + second_std=(127.5, 127.5, 127.5), + to_rgb=True) + + data_preprocessor = TwoNormDataPreprocessor(**data_preprocessor) + fake_data = { + 'inputs': + [torch.randn((2, 3, 224, 224)), + torch.randn((2, 3, 224, 224))], + 'data_sample': [DataSample(), DataSample()] + } + fake_output = data_preprocessor(fake_data) + self.assertEqual(len(fake_output['inputs']), 2) + self.assertEqual(len(fake_output['data_samples']), 2) + + +class TestVideoDataPreprocessor(TestCase): + + def test_NCTHW_format(self): + data_preprocessor = VideoDataPreprocessor( + mean=[114.75, 114.75, 114.75], + std=[57.375, 57.375, 57.375], + to_rgb=True, + format_shape='NCTHW') + fake_data = { + 'inputs': [torch.randn((2, 3, 4, 224, 224))], + 'data_sample': [DataSample(), DataSample()] + } + fake_output = data_preprocessor(fake_data) + self.assertEqual(len(fake_output['inputs']), 1) + self.assertEqual(len(fake_output['data_samples']), 2) + + def test_NCHW_format(self): + data_preprocessor = VideoDataPreprocessor( + mean=[114.75, 114.75, 114.75], + std=[57.375, 57.375, 57.375], + to_rgb=True, + format_shape='NCHW') + fake_data = { + 'inputs': [torch.randn((2, 3, 224, 224))], + 
'data_sample': [DataSample(), DataSample()] + } + fake_output = data_preprocessor(fake_data) + self.assertEqual(len(fake_output['inputs']), 1) + self.assertEqual(len(fake_output['data_samples']), 2) diff --git a/tests/test_models/test_utils/test_ema.py b/tests/test_models/test_utils/test_ema.py new file mode 100644 index 00000000..3166ae15 --- /dev/null +++ b/tests/test_models/test_utils/test_ema.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from unittest import TestCase + +import torch +import torch.nn as nn +from mmengine.logging import MessageHub +from mmengine.testing import assert_allclose + +from mmpretrain.models.utils import CosineEMA + + +class TestEMA(TestCase): + + def test_cosine_ema(self): + model = nn.Sequential(nn.Conv2d(1, 5, kernel_size=3), nn.Linear(5, 10)) + + # init message hub + max_iters = 5 + test = dict(name='ema_test') + message_hub = MessageHub.get_instance(**test) + message_hub.update_info('max_iters', max_iters) + + # test EMA + momentum = 0.996 + end_momentum = 1. + + ema_model = CosineEMA(model, momentum=momentum) + averaged_params = [ + torch.zeros_like(param) for param in model.parameters() + ] + + for i in range(max_iters): + updated_averaged_params = [] + for p, p_avg in zip(model.parameters(), averaged_params): + p.detach().add_(torch.randn_like(p)) + if i == 0: + updated_averaged_params.append(p.clone()) + else: + m = end_momentum - (end_momentum - momentum) * ( + math.cos(math.pi * i / float(max_iters)) + 1) / 2 + updated_averaged_params.append( + (p_avg * m + p * (1 - m)).clone()) + ema_model.update_parameters(model) + averaged_params = updated_averaged_params + + for p_target, p_ema in zip(averaged_params, ema_model.parameters()): + assert_allclose(p_target, p_ema)
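
Note (not part of the patch): the new ``second_mean=[-31.875, ...]`` / ``second_std=[318.75, ...]`` values in the BEiT and CAE configs above are not arbitrary. For 8-bit pixels x in [0, 255], ``(x - second_mean) / second_std`` equals ``x / 255 * 0.8 + 0.1``, i.e. the [0.1, 0.9] range presumably expected by the DALL-E dVAE tokenizer that provides the reconstruction target, whereas the old ``-20.4`` / ``204.`` pair overshoots to roughly [0.1, 1.35]. A quick sanity check of that arithmetic in plain PyTorch (no mmpretrain needed):

    import torch

    x = torch.linspace(0, 255, steps=256)  # 8-bit pixel values

    # What TwoNormDataPreprocessor computes for the second (target) image.
    second_mean, second_std = -31.875, 318.75
    normalized = (x - second_mean) / second_std

    # Linear mapping into [0.1, 0.9] (DALL-E dVAE style pixel range).
    reference = x / 255 * 0.8 + 0.1

    assert torch.allclose(normalized, reference, atol=1e-5)
    print(normalized.min().item(), normalized.max().item())  # ~0.1, ~0.9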
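
Note (not part of the patch): the cosine momentum schedule that ``CosineEMA.avg_func`` implements can be checked in isolation. A standalone sketch with assumed values (momentum=0.996, end_momentum=1.0, 300 total steps), independent of mmengine:

    from math import cos, pi

    def cosine_momentum(step: int, max_steps: int,
                        momentum: float = 0.996,
                        end_momentum: float = 1.0) -> float:
        """m = end - (end - start) * (cos(pi * k / K) + 1) / 2, as in the CosineEMA docstring."""
        return end_momentum - (end_momentum - momentum) * (
            cos(pi * step / float(max_steps)) + 1) / 2

    K = 300
    print(cosine_momentum(0, K))       # 0.996 -> EMA (target) network moves fastest at the start
    print(cosine_momentum(K // 2, K))  # 0.998
    print(cosine_momentum(K, K))       # 1.0   -> EMA network stops moving at the end

Each update then applies ``averaged_param.mul_(m).add_(source_param, alpha=1 - m)``, so the closer m gets to 1, the more slowly the EMA (target) network drifts toward the online network.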
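
Note (not part of the patch): a minimal round-trip check for ``batch_shuffle_ddp`` / ``batch_unshuffle_ddp``, assuming a single, non-distributed process in which mmengine's ``all_gather``, ``broadcast`` and ``get_rank`` helpers fall back to no-ops:

    import torch
    from mmpretrain.models.utils import batch_shuffle_ddp, batch_unshuffle_ddp

    x = torch.randn(8, 3, 224, 224)
    shuffled, idx_unshuffle = batch_shuffle_ddp(x)            # random batch order
    restored = batch_unshuffle_ddp(shuffled, idx_unshuffle)   # undo the shuffle
    assert torch.equal(restored, x)                           # original order is recovered

Under real DDP each rank receives a different slice of the globally shuffled batch, which is what prevents per-GPU BatchNorm statistics from leaking the query/key pairing in MoCo-style training.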