[Refactor] Move and refactor utils from mmselfsup. (#1385)
* add heads * add losses * fix * remove mim head * add modified backbones and target generators * fix lint * fix lint * add heads * add losses * fix * add data preprocessor from mmselfsup * add ut for data preprocessor * add GatherLayer * add ema * add batch shuffle * add misc * fix lint * update * update docstring pull/1400/head
parent
414ba80274
commit
c9670173aa
|
@ -7,7 +7,7 @@ data_preprocessor = dict(
|
|||
std=[58.395, 57.12, 57.375],
|
||||
second_mean=[127.5, 127.5, 127.5],
|
||||
second_std=[127.5, 127.5, 127.5],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
|
|
|
@ -5,7 +5,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
|
|
|
@ -5,7 +5,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
view_pipeline1 = [
|
||||
dict(
|
||||
|
|
|
@ -5,7 +5,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
# The difference between mocov2 and mocov1 is the transforms in the pipeline
|
||||
view_pipeline = [
|
||||
|
|
|
@ -5,7 +5,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
view_pipeline = [
|
||||
dict(type='RandomResizedCrop', size=224, backend='pillow'),
|
||||
|
|
|
@ -5,7 +5,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
|
|
|
@ -5,7 +5,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
view_pipeline1 = [
|
||||
dict(
|
||||
|
|
|
@ -7,9 +7,9 @@ data_preprocessor = dict(
|
|||
type='TwoNormDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
second_mean=[-20.4, -20.4, -20.4],
|
||||
second_std=[204., 204., 204.],
|
||||
bgr_to_rgb=True)
|
||||
second_mean=[-31.875, -31.875, -31.875],
|
||||
second_std=[318.75, 318.75, 318.75],
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
|
|
|
@ -5,10 +5,12 @@ dataset_type = 'ImageNet'
|
|||
data_root = 'data/imagenet/'
|
||||
file_client_args = dict(backend='disk')
|
||||
data_preprocessor = dict(
|
||||
type='CAEDataPreprocessor',
|
||||
mean=[124, 117, 104],
|
||||
std=[59, 58, 58],
|
||||
bgr_to_rgb=True)
|
||||
type='TwoNormDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
second_mean=[-31.875, -31.875, -31.875],
|
||||
second_std=[318.75, 318.75, 318.75],
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=file_client_args),
|
||||
|
@ -75,7 +77,7 @@ model = dict(
|
|||
type='mmselfsup.CAEDataPreprocessor',
|
||||
mean=[124, 117, 104],
|
||||
std=[59, 58, 58],
|
||||
bgr_to_rgb=True),
|
||||
to_rgb=True),
|
||||
base_momentum=0.0)
|
||||
|
||||
# optimizer wrapper
|
||||
|
|
|
@ -11,10 +11,6 @@ model = dict(
|
|||
feat_dim=128,
|
||||
momentum=0.999,
|
||||
loss_lambda=0.5,
|
||||
data_preprocessor=dict(
|
||||
mean=(123.675, 116.28, 103.53),
|
||||
std=(58.395, 57.12, 57.375),
|
||||
bgr_to_rgb=True),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
|
|
@ -7,7 +7,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
|
@ -50,7 +50,7 @@ model = dict(
|
|||
data_preprocessor=dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True),
|
||||
to_rgb=True),
|
||||
backbone=dict(type='MaskFeatViT', arch='b', patch_size=16),
|
||||
neck=dict(
|
||||
type='LinearNeck',
|
||||
|
|
|
@ -36,7 +36,7 @@ model = dict(
|
|||
data_preprocessor=dict(
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True),
|
||||
to_rgb=True),
|
||||
backbone=dict(
|
||||
type='MixMIMTransformerPretrain',
|
||||
arch='B',
|
||||
|
|
|
@ -13,7 +13,7 @@ model = dict(
|
|||
data_preprocessor=dict(
|
||||
mean=(123.675, 116.28, 103.53),
|
||||
std=(58.395, 57.12, 57.375),
|
||||
bgr_to_rgb=True),
|
||||
to_rgb=True),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
|
|
@ -10,7 +10,7 @@ data_preprocessor = dict(
|
|||
type='SelfSupDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True)
|
||||
to_rgb=True)
|
||||
|
||||
num_crops = [2, 6]
|
||||
color_distort_strength = 1.0
|
||||
|
@ -92,7 +92,7 @@ model = dict(
|
|||
data_preprocessor=dict(
|
||||
mean=(123.675, 116.28, 103.53),
|
||||
std=(58.395, 57.12, 57.375),
|
||||
bgr_to_rgb=True),
|
||||
to_rgb=True),
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=50,
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.dist import all_gather, get_rank
|
||||
|
||||
|
||||
class GatherLayer(torch.autograd.Function):
    """Differentiable all-gather over every process.

    ``forward`` collects ``input`` from all processes via ``all_gather``;
    ``backward`` hands back, for the local ``input``, the incoming gradient
    whose position matches the rank of the current process.
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[List]:
        """Gather ``input`` from all processes and return them as a tuple."""
        # Remember the local tensor so that backward can build a gradient of
        # the same shape/dtype/device.
        ctx.save_for_backward(input)
        return tuple(all_gather(input))

    @staticmethod
    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
        """Return the gradient slot that belongs to the current rank."""
        (local_input, ) = ctx.saved_tensors
        local_grad = torch.zeros_like(local_input)
        # One gradient arrives per gathered output; keep only the one at the
        # index of this process.
        local_grad[:] = grads[get_rank()]
        return local_grad
|
|
@ -4,9 +4,12 @@ from .attention import (BEiTAttention, ChannelMultiheadAttention,
|
|||
MultiheadAttention, PromptMultiheadAttention,
|
||||
ShiftWindowMSA, WindowMSA, WindowMSAV2)
|
||||
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
|
||||
from .batch_shuffle import batch_shuffle_ddp, batch_unshuffle_ddp
|
||||
from .channel_shuffle import channel_shuffle
|
||||
from .clip_generator_helper import build_clip_model
|
||||
from .data_preprocessor import ClsDataPreprocessor
|
||||
from .data_preprocessor import (ClsDataPreprocessor, SelfSupDataPreprocessor,
|
||||
TwoNormDataPreprocessor, VideoDataPreprocessor)
|
||||
from .ema import CosineEMA
|
||||
from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
|
||||
resize_relative_position_bias_table)
|
||||
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
|
||||
|
@ -31,5 +34,7 @@ __all__ = [
|
|||
'PositionEncodingFourier', 'LeAttention', 'GRN', 'LayerNorm2d',
|
||||
'build_norm_layer', 'CrossMultiheadAttention',
|
||||
'build_2d_sincos_position_embedding', 'PromptMultiheadAttention',
|
||||
'NormEMAVectorQuantizer', 'build_clip_model'
|
||||
'NormEMAVectorQuantizer', 'build_clip_model', 'batch_shuffle_ddp',
|
||||
'batch_unshuffle_ddp', 'SelfSupDataPreprocessor',
|
||||
'TwoNormDataPreprocessor', 'VideoDataPreprocessor', 'CosineEMA'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from mmengine.dist import all_gather, broadcast, get_rank
|
||||
|
||||
|
||||
@torch.no_grad()
def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Shuffle the batch across processes, for making use of BatchNorm.

    Args:
        x (torch.Tensor): Data in each GPU.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation.

        - The shuffled data assigned to the current process.
        - The index tensor that restores the original order.
    """
    local_batch = x.shape[0]

    # Concatenate the batches of every process along the batch dimension.
    gathered = torch.cat(all_gather(x), dim=0)
    total_batch = gathered.shape[0]
    world = total_batch // local_batch

    # Draw a permutation and synchronise it from rank 0 so that every
    # process works with the same one.
    permutation = torch.randperm(total_batch)
    broadcast(permutation, src=0)

    # argsort of a permutation is its inverse, i.e. the restoring index.
    restore_index = torch.argsort(permutation)

    # Each process takes its own row of the permutation.
    rank = get_rank()
    local_index = permutation.view(world, -1)[rank]

    return gathered[local_index], restore_index
|
||||
|
||||
|
||||
@torch.no_grad()
def batch_unshuffle_ddp(x: torch.Tensor,
                        idx_unshuffle: torch.Tensor) -> torch.Tensor:
    """Undo a previous batch shuffle.

    Args:
        x (torch.Tensor): Data in each GPU.
        idx_unshuffle (torch.Tensor): Index for restoring.

    Returns:
        torch.Tensor: Output of unshuffle operation.
    """
    local_batch = x.shape[0]

    # Concatenate the batches of every process along the batch dimension.
    gathered = torch.cat(all_gather(x), dim=0)
    total_batch = gathered.shape[0]
    world = total_batch // local_batch

    # Pick the row of the restoring index that belongs to this process.
    rank = get_rank()
    local_index = idx_unshuffle.view(world, -1)[rank]

    return gathered[local_index]
|
|
@ -1,11 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from numbers import Number
|
||||
from typing import Optional, Sequence
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from mmengine.model import BaseDataPreprocessor, stack_batch
|
||||
from mmengine.model import (BaseDataPreprocessor, ImgDataPreprocessor,
|
||||
stack_batch)
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import (DataSample, MultiTaskDataSample,
|
||||
|
@ -194,3 +195,294 @@ class ClsDataPreprocessor(BaseDataPreprocessor):
|
|||
data_samples = self.cast_data(data_samples)
|
||||
|
||||
return {'inputs': inputs, 'data_samples': data_samples}
|
||||
|
||||
|
||||
@MODELS.register_module()
class SelfSupDataPreprocessor(ImgDataPreprocessor):
    """Image pre-processor for operations, like normalization and bgr to rgb.

    Compared with the :class:`mmengine.ImgDataPreprocessor`, this module treats
    each item in `inputs` of input data as a list, instead of torch.Tensor.
    Besides, add key ``to_rgb`` to align with :class:`ClsDataPreprocessor`.

    Args:
        mean (Sequence[float or int], optional): Per-channel mean, forwarded
            to :class:`mmengine.ImgDataPreprocessor`. Defaults to None.
        std (Sequence[float or int], optional): Per-channel standard
            deviation, forwarded to :class:`mmengine.ImgDataPreprocessor`.
            Defaults to None.
        pad_size_divisor (int): Forwarded to the parent class. Defaults to 1.
        pad_value (float or int): Forwarded to the parent class.
            Defaults to 0.
        to_rgb (bool): If True, the channel order of every input is reversed
            in :meth:`forward`. Defaults to False.
        bgr_to_rgb (bool): Forwarded to the parent class; also enables the
            channel reversal in :meth:`forward`. Defaults to False.
        rgb_to_bgr (bool): Forwarded to the parent class; also enables the
            channel reversal in :meth:`forward`. Defaults to False.
        non_blocking (bool, optional): Forwarded to the parent class.
            Defaults to False.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 to_rgb: bool = False,
                 bgr_to_rgb: bool = False,
                 rgb_to_bgr: bool = False,
                 non_blocking: Optional[bool] = False):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            bgr_to_rgb=bgr_to_rgb,
            rgb_to_bgr=rgb_to_bgr,
            non_blocking=non_blocking)

        # ``to_rgb`` is not a parent argument, so the flag is (re)computed
        # here. Any of the three switches triggers the same channel reversal.
        self._channel_conversion = to_rgb or bgr_to_rgb or rgb_to_bgr

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform channel conversion, float casting and normalization.

        No padding is applied by this method.

        Args:
            data (dict): Data sampled from dataloader. Its values are taken
                in insertion order and unpacked as
                ``(inputs, data_samples)``, so it must hold exactly these
                two entries, with ``inputs`` being an iterable of tensors.
            training (bool): Whether to enable training time augmentation.
                Not used by this implementation. Defaults to False.

        Returns:
            dict: A dict with key ``inputs`` (list of processed tensors) and
            key ``data_samples`` (the cast data samples).
        """
        # Implicit string concatenation keeps source indentation out of the
        # message (a backslash continuation inside the literal would not).
        assert isinstance(data, dict), (
            'Please use default_collate in dataloader, '
            'instead of pseudo_collate.')

        batch_inputs, batch_data_samples = self.cast_data(list(data.values()))

        # Channel transform: reverse the order of the three channels on dim 1.
        if self._channel_conversion:
            batch_inputs = [
                _input[:, [2, 1, 0], ...] for _input in batch_inputs
            ]

        # Convert to float after channel conversion to ensure efficiency.
        batch_inputs = [input_.float() for input_ in batch_inputs]

        # Normalization. Here is what is different from
        # :class:`mmengine.ImgDataPreprocessor`: since there are multiple
        # views of an image for some algorithms, e.g. SimCLR, ``inputs`` is a
        # list and every item is normalized separately.
        # ``_enable_normalize``, ``mean`` and ``std`` are provided by the
        # parent class.
        if self._enable_normalize:
            batch_inputs = [(_input - self.mean) / self.std
                            for _input in batch_inputs]

        return {'inputs': batch_inputs, 'data_samples': batch_data_samples}
|
||||
|
||||
|
||||
@MODELS.register_module()
class TwoNormDataPreprocessor(SelfSupDataPreprocessor):
    """Image pre-processor for CAE, BEiT v1/v2, etc.

    Compared with the :class:`SelfSupDataPreprocessor`, this module
    will normalize the prediction image and target image with different
    normalization parameters.

    Args:
        mean (Sequence[float or int], optional): The pixel mean of image
            channels. If ``to_rgb=True`` it means the mean value of R, G, B
            channels. If the length of `mean` is 1, it means all channels have
            the same mean value, or the input is a gray image. If it is not
            specified, images will not be normalized. Defaults to None.
        std (Sequence[float or int], optional): The pixel standard deviation of
            image channels. If ``to_rgb=True`` it means the standard deviation
            of R, G, B channels. If the length of `std` is 1, it means all
            channels have the same standard deviation, or the input is a gray
            image. If it is not specified, images will not be normalized.
            Defaults to None.
        second_mean (Sequence[float or int], optional): The description is
            like ``mean``, it can be customized for the target image. It is
            required: passing None fails an assertion. Defaults to None.
        second_std (Sequence[float or int], optional): The description is
            like ``std``, it can be customized for the target image. It is
            required: passing None fails an assertion. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Defaults to 1.
        pad_value (float or int): The padded pixel value. Defaults to 0.
        to_rgb (bool): whether to convert image from BGR to RGB.
            Defaults to False.
        non_blocking (bool): Whether block current process when transferring
            data to device. Defaults to False.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 second_mean: Optional[Sequence[Union[float, int]]] = None,
                 second_std: Optional[Sequence[Union[float, int]]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 to_rgb: bool = False,
                 non_blocking: Optional[bool] = False):
        super().__init__(
            mean=mean,
            std=std,
            pad_size_divisor=pad_size_divisor,
            pad_value=pad_value,
            to_rgb=to_rgb,
            non_blocking=non_blocking)

        # The messages name the ``second_*`` arguments that are actually
        # being validated, and report the length of ``second_std`` itself
        # (not of ``std``, which may even be None).
        assert (second_mean is not None) and (second_std is not None), (
            '`second_mean` and `second_std` should not be None while using '
            '`TwoNormDataPreprocessor`')
        assert len(second_mean) == 3 or len(second_mean) == 1, (
            '`second_mean` should have 1 or 3 values, to be compatible with '
            f'RGB or gray image, but got {len(second_mean)} values')
        assert len(second_std) == 3 or len(second_std) == 1, (
            '`second_std` should have 1 or 3 values, to be compatible with '
            f'RGB or gray image, but got {len(second_std)} values')

        # Stored as non-persistent buffers shaped (C, 1, 1) so that they
        # broadcast over the spatial dimensions.
        self.register_buffer('second_mean',
                             torch.tensor(second_mean).view(-1, 1, 1), False)
        self.register_buffer('second_std',
                             torch.tensor(second_std).view(-1, 1, 1), False)

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform channel conversion, float casting and normalization.

        No padding is applied by this method.

        Args:
            data (dict): Data sampled from dataloader. Its values are taken
                in insertion order and unpacked as
                ``(inputs, data_samples)``.
            training (bool): Whether to enable training time augmentation.
                Not used by this implementation. Defaults to False.

        Returns:
            dict: A dict with key ``inputs`` and key ``data_samples``. When
            normalization is enabled, ``inputs`` holds exactly two tensors:
            the first input normalized with ``mean``/``std`` and the second
            input normalized with ``second_mean``/``second_std``.
        """
        data = [val for _, val in data.items()]
        batch_inputs, batch_data_samples = self.cast_data(data)

        # Channel transform: reverse the order of the three channels on dim 1.
        if self._channel_conversion:
            batch_inputs = [
                _input[:, [2, 1, 0], ...] for _input in batch_inputs
            ]

        # Convert to float after channel conversion to ensure efficiency.
        batch_inputs = [input_.float() for input_ in batch_inputs]

        # Normalization. Here is what is different from
        # :class:`SelfSupDataPreprocessor`: normalize the target image and
        # prediction image with different normalization params. Only the
        # first two inputs are kept.
        if self._enable_normalize:
            batch_inputs = [
                (batch_inputs[0] - self.mean) / self.std,
                (batch_inputs[1] - self.second_mean) / self.second_std
            ]

        return {'inputs': batch_inputs, 'data_samples': batch_data_samples}
|
||||
|
||||
|
||||
@MODELS.register_module()
class VideoDataPreprocessor(BaseDataPreprocessor):
    """Video pre-processor for operations, like normalization and bgr to rgb
    conversion.

    Compared with the :class:`mmaction.ActionDataPreprocessor`, this module
    treats each item in `inputs` of input data as a list, instead of
    torch.Tensor.

    Args:
        mean (Sequence[float or int], optional): The pixel mean of channels
            of images or stacked optical flow. Defaults to None.
        std (Sequence[float or int], optional): The pixel standard deviation
            of channels of images or stacked optical flow. Required whenever
            ``mean`` is given. Defaults to None.
        pad_size_divisor (int): The size of padded image should be
            divisible by ``pad_size_divisor``. Stored on the instance; not
            used by :meth:`forward` of this class. Defaults to 1.
        pad_value (float or int): The padded pixel value. Stored on the
            instance; not used by :meth:`forward` of this class.
            Defaults to 0.
        to_rgb (bool): Whether to convert image from BGR to RGB.
            Defaults to False.
        format_shape (str): Format shape of input data, either ``'NCHW'`` or
            ``'NCTHW'``. Defaults to ``'NCHW'``.
    """

    def __init__(self,
                 mean: Optional[Sequence[Union[float, int]]] = None,
                 std: Optional[Sequence[Union[float, int]]] = None,
                 pad_size_divisor: int = 1,
                 pad_value: Union[float, int] = 0,
                 to_rgb: bool = False,
                 format_shape: str = 'NCHW') -> None:
        super().__init__()
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value
        self.to_rgb = to_rgb
        self.format_shape = format_shape

        if mean is not None:
            assert std is not None, 'To enable the normalization in ' \
                                    'preprocessing, please specify both ' \
                                    '`mean` and `std`.'
            # Enable the normalization in preprocessing.
            self._enable_normalize = True
            # One trailing singleton axis per non-channel dimension, so the
            # statistics broadcast against the inputs.
            if self.format_shape == 'NCHW':
                normalizer_shape = (-1, 1, 1)
            elif self.format_shape == 'NCTHW':
                normalizer_shape = (-1, 1, 1, 1)
            else:
                raise ValueError(f'Invalid format shape: {format_shape}')

            self.register_buffer(
                'mean',
                torch.tensor(mean, dtype=torch.float32).view(normalizer_shape),
                False)
            self.register_buffer(
                'std',
                torch.tensor(std, dtype=torch.float32).view(normalizer_shape),
                False)
        else:
            self._enable_normalize = False

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform bgr2rgb conversion and normalization (or float casting).

        No padding is applied by this method.

        Args:
            data (dict): Data sampled from dataloader. Its values are taken
                in insertion order and unpacked as
                ``(inputs, data_samples)``.
            training (bool): Whether to enable training time augmentation.
                Not used by this implementation. Defaults to False.

        Returns:
            dict: A dict with key ``inputs`` (list of processed tensors) and
            key ``data_samples`` (the cast data samples).

        Raises:
            ValueError: If ``to_rgb`` is enabled and ``format_shape`` is
                neither ``'NCHW'`` nor ``'NCTHW'``.
        """

        data = [val for _, val in data.items()]
        batch_inputs, batch_data_samples = self.cast_data(data)

        # ------ To RGB ------
        # The channel axis is located from the right, so any number of
        # leading dimensions is accepted.
        if self.to_rgb:
            if self.format_shape == 'NCHW':
                batch_inputs = [
                    batch_input[..., [2, 1, 0], :, :]
                    for batch_input in batch_inputs
                ]
            elif self.format_shape == 'NCTHW':
                batch_inputs = [
                    batch_input[..., [2, 1, 0], :, :, :]
                    for batch_input in batch_inputs
                ]
            else:
                raise ValueError(f'Invalid format shape: {self.format_shape}')

        # -- Normalization ---
        if self._enable_normalize:
            batch_inputs = [(batch_input - self.mean) / self.std
                            for batch_input in batch_inputs]
        else:
            # Without normalization the inputs are still cast to float32.
            batch_inputs = [
                batch_input.to(torch.float32) for batch_input in batch_inputs
            ]

        return {'inputs': batch_inputs, 'data_samples': batch_data_samples}
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from math import cos, pi
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.logging import MessageHub
|
||||
from mmengine.model import ExponentialMovingAverage
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
class CosineEMA(ExponentialMovingAverage):
    """Exponential moving average whose momentum follows a cosine schedule,
    used in BYOL, MoCoV3, etc.

    At every update the momentum is recomputed as:

    .. math::
        m = m_1 - (m_1 - m_0) * (cos(pi * k / K) + 1) / 2

    where :math:`m_0` is ``momentum``, :math:`m_1` is ``end_momentum``,
    :math:`k` is the current step and :math:`K` is the total number of
    steps, read from the ``max_iters`` entry of the current message hub.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The starting momentum of the schedule. The averaged
            parameter is updated with the formula:
            `averaged_param = m * averaged_param + (1 - m) * source_param`.
            Defaults to 0.996.
        end_momentum (float): The end momentum value for cosine annealing.
            Defaults to 1.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.996,
                 end_momentum: float = 1.,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        parent_kwargs = dict(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        super().__init__(**parent_kwargs)
        self.end_momentum = end_momentum

    def avg_func(self, averaged_param: torch.Tensor,
                 source_param: torch.Tensor, steps: int) -> None:
        """Update ``averaged_param`` in place with the cosine momentum.

        Args:
            averaged_param (Tensor): The averaged parameters, modified in
                place.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        total_steps = MessageHub.get_current_instance().get_info('max_iters')

        # Cosine annealing between ``self.momentum`` and ``self.end_momentum``.
        momentum_span = self.end_momentum - self.momentum
        cosine_term = cos(pi * steps / float(total_steps)) + 1
        momentum = self.end_momentum - momentum_span * cosine_term / 2

        averaged_param.mul_(momentum)
        averaged_param.add_(source_param, alpha=(1 - momentum))
|
|
@ -1,10 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .analyze import load_json_log
|
||||
from .collect_env import collect_env
|
||||
from .misc import get_ori_model
|
||||
from .progress import track_on_main_process
|
||||
from .setup_env import register_all_modules
|
||||
|
||||
__all__ = [
|
||||
'collect_env', 'register_all_modules', 'track_on_main_process',
|
||||
'load_json_log'
|
||||
'load_json_log', 'get_ori_model'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmengine.model import is_model_wrapper
|
||||
|
||||
|
||||
def get_ori_model(model: nn.Module) -> nn.Module:
    """Unwrap ``model`` when it is a model wrapper.

    Args:
        model (nn.Module): A model that may be a model wrapper.

    Returns:
        nn.Module: ``model.module`` if ``is_model_wrapper(model)`` is true,
        otherwise ``model`` itself.
    """
    return model.module if is_model_wrapper(model) else model
|
|
@ -1,9 +1,12 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import ClsDataPreprocessor, RandomBatchAugment
|
||||
from mmpretrain.models import (ClsDataPreprocessor, RandomBatchAugment,
|
||||
SelfSupDataPreprocessor,
|
||||
TwoNormDataPreprocessor, VideoDataPreprocessor)
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
@ -99,3 +102,111 @@ class TestClsDataPreprocessor(TestCase):
|
|||
processed_data = processor(data, training=True)
|
||||
self.assertIn('inputs', processed_data)
|
||||
self.assertIsNone(processed_data['data_samples'])
|
||||
|
||||
|
||||
class TestSelfSupDataPreprocessor(TestCase):
    """Tests for :class:`SelfSupDataPreprocessor`."""

    def test_to_rgb(self):
        """``to_rgb=True`` must flip the channel axis of every view."""
        processor: SelfSupDataPreprocessor = MODELS.build(
            dict(type='SelfSupDataPreprocessor', to_rgb=True))
        self.assertTrue(processor._channel_conversion)

        views = [torch.randn((2, 3, 224, 224)), torch.randn((2, 3, 224, 224))]
        fake_data = {
            'inputs': views,
            'data_samples': [DataSample(), DataSample()]
        }
        inputs = processor(fake_data)['inputs']

        # Check the two views one after the other, in order.
        for view_idx in (0, 1):
            expected = fake_data['inputs'][view_idx].flip(1).float()
            torch.testing.assert_allclose(expected, inputs[view_idx])

    def test_forward(self):
        """The output keeps the number of views and of data samples."""
        data_preprocessor = SelfSupDataPreprocessor(
            to_rgb=True, mean=[124, 117, 104], std=[59, 58, 58])
        samples = [DataSample(), DataSample()]
        fake_data = {
            'inputs': [torch.randn((2, 3, 224, 224))],
            'data_samples': samples
        }

        fake_output = data_preprocessor(fake_data)

        self.assertEqual(len(fake_output['inputs']), 1)
        self.assertEqual(len(fake_output['data_samples']), 2)
|
||||
|
||||
|
||||
class TestTwoNormDataPreprocessor(TestCase):
    """Tests for :class:`TwoNormDataPreprocessor`."""

    def test_assertion(self):
        """Missing or wrongly sized second statistics must be rejected."""
        common = dict(
            to_rgb=True,
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
        )
        invalid_second_stats = [
            # no second statistics at all
            dict(),
            # ``second_mean`` with two values
            dict(second_mean=(127.5, 127.5), second_std=(127.5, 127.5, 127.5)),
            # ``second_std`` with two values
            dict(second_mean=(127.5, 127.5, 127.5), second_std=(127.5, 127.5)),
        ]
        for second_stats in invalid_second_stats:
            with pytest.raises(AssertionError):
                _ = TwoNormDataPreprocessor(**common, **second_stats)

    def test_forward(self):
        """The output holds two normalized views and two data samples."""
        cfg = dict(
            mean=(123.675, 116.28, 103.53),
            std=(58.395, 57.12, 57.375),
            second_mean=(127.5, 127.5, 127.5),
            second_std=(127.5, 127.5, 127.5),
            to_rgb=True)
        data_preprocessor = TwoNormDataPreprocessor(**cfg)

        views = [torch.randn((2, 3, 224, 224)), torch.randn((2, 3, 224, 224))]
        fake_data = {
            'inputs': views,
            'data_sample': [DataSample(), DataSample()]
        }

        fake_output = data_preprocessor(fake_data)

        self.assertEqual(len(fake_output['inputs']), 2)
        self.assertEqual(len(fake_output['data_samples']), 2)
|
||||
|
||||
|
||||
class TestVideoDataPreprocessor(TestCase):
    """Tests for :class:`VideoDataPreprocessor`."""

    def test_NCTHW_format(self):
        """A single 5-D clip goes through with ``format_shape='NCTHW'``."""
        cfg = dict(
            mean=[114.75, 114.75, 114.75],
            std=[57.375, 57.375, 57.375],
            to_rgb=True,
            format_shape='NCTHW')
        data_preprocessor = VideoDataPreprocessor(**cfg)

        clips = [torch.randn((2, 3, 4, 224, 224))]
        fake_data = {
            'inputs': clips,
            'data_sample': [DataSample(), DataSample()]
        }

        fake_output = data_preprocessor(fake_data)

        self.assertEqual(len(fake_output['inputs']), 1)
        self.assertEqual(len(fake_output['data_samples']), 2)

    def test_NCHW_format(self):
        """A single 4-D batch goes through with ``format_shape='NCHW'``."""
        cfg = dict(
            mean=[114.75, 114.75, 114.75],
            std=[57.375, 57.375, 57.375],
            to_rgb=True,
            format_shape='NCHW')
        data_preprocessor = VideoDataPreprocessor(**cfg)

        frames = [torch.randn((2, 3, 224, 224))]
        fake_data = {
            'inputs': frames,
            'data_sample': [DataSample(), DataSample()]
        }

        fake_output = data_preprocessor(fake_data)

        self.assertEqual(len(fake_output['inputs']), 1)
        self.assertEqual(len(fake_output['data_samples']), 2)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from unittest import TestCase
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.logging import MessageHub
|
||||
from mmengine.testing import assert_allclose
|
||||
|
||||
from mmpretrain.models.utils import CosineEMA
|
||||
|
||||
|
||||
class TestEMA(TestCase):
    """Check :class:`CosineEMA` against a hand-computed reference."""

    def test_cosine_ema(self):
        """Run ``max_iters`` updates and compare with the expected average."""
        model = nn.Sequential(nn.Conv2d(1, 5, kernel_size=3), nn.Linear(5, 10))

        # init message hub
        # ``CosineEMA.avg_func`` reads 'max_iters' from the current message
        # hub, so it has to be published before the first update.
        max_iters = 5
        test = dict(name='ema_test')
        message_hub = MessageHub.get_instance(**test)
        message_hub.update_info('max_iters', max_iters)

        # test EMA
        momentum = 0.996
        end_momentum = 1.

        ema_model = CosineEMA(model, momentum=momentum)
        averaged_params = [
            torch.zeros_like(param) for param in model.parameters()
        ]

        for i in range(max_iters):
            updated_averaged_params = []
            for p, p_avg in zip(model.parameters(), averaged_params):
                # Perturb the source parameter in place before each update.
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    # The reference takes a plain copy of the source
                    # parameters for the first update.
                    updated_averaged_params.append(p.clone())
                else:
                    # Same cosine schedule as ``CosineEMA.avg_func``.
                    m = end_momentum - (end_momentum - momentum) * (
                        math.cos(math.pi * i / float(max_iters)) + 1) / 2
                    updated_averaged_params.append(
                        (p_avg * m + p * (1 - m)).clone())
            ema_model.update_parameters(model)
            averaged_params = updated_averaged_params

        # The EMA parameters must match the reference element-wise.
        for p_target, p_ema in zip(averaged_params, ema_model.parameters()):
            assert_allclose(p_target, p_ema)
|
Loading…
Reference in New Issue