[Refactor] Add necks, heads and losses for the self-supervised task. (#1376)

* add necks

* refactor linear neck

* rename simmim neck

* add heads

* add losses

* fix

* add unittest

* update

* update cae

* remove mim head

* update config
Yixiao Fang 2023-02-28 10:05:00 +08:00 committed by GitHub
parent 75c79311f4
commit 63d9f27fde
46 changed files with 2795 additions and 45 deletions


@ -16,7 +16,7 @@ model = dict(
type='MAEPretrainHead',
norm_pix=True,
patch_size=16,
loss=dict(type='MAEReconstructionLoss')),
loss=dict(type='PixelReconstructionLoss', criterion='L2')),
init_cfg=[
dict(type='Xavier', layer='Linear', distribution='uniform'),
dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
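
The MAE-specific reconstruction loss is folded into the generic PixelReconstructionLoss added later in this commit. A minimal plain-PyTorch sketch (dummy shapes, not the library API) of what criterion='L2' computes for MAE:

import torch

pred = torch.rand(2, 196, 768)                # reconstructed patches, B x L x D
target = torch.rand(2, 196, 768)              # patchified target image
mask = (torch.rand(2, 196) > 0.25).float()    # 1 = masked patch

loss = (pred - target).pow(2)                 # element-wise L2, no reduction
loss = loss.mean(dim=-1)                      # average over the feature dim
loss = (loss * mask).sum() / mask.sum()       # average over masked patches only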


@ -71,7 +71,7 @@ model = dict(
type='BEiTV1Head',
embed_dims=768,
num_embed=8192,
loss=dict(type='BEiTLoss')),
loss=dict(type='CrossEntropyLoss')),
target_generator=dict(
type='DALL-E',
init_cfg=dict(


@ -56,7 +56,7 @@ model = dict(
type='BEiTV2Head',
embed_dims=768,
num_embed=8192,
loss=dict(type='BEiTLoss')),
loss=dict(type='CrossEntropyLoss')),
target_generator=dict(
type='VQKD',
encoder_config=vqkd_encoder,


@ -56,7 +56,7 @@ model = dict(
type='BEiTV2Head',
embed_dims=768,
num_embed=8192,
loss=dict(type='BEiTLoss')),
loss=dict(type='CrossEntropyLoss')),
target_generator=dict(
type='VQKD',
encoder_config=vqkd_encoder,


@ -1,4 +0,0 @@
_base_ = 'cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py'
# dataset 128 x 16
train_dataloader = dict(batch_size=128)


@ -56,7 +56,6 @@ model = dict(
qkv_bias=False),
neck=dict(
type='CAENeck',
patch_size=16,
embed_dims=768,
num_heads=12,
regressor_depth=4,


@ -4,7 +4,7 @@ Collections:
Training Data: ImageNet-1k
Training Techniques:
- AdamW
Training Resources: 16x A100-80G GPUs
Training Resources: 8x A100-80G GPUs
Architecture:
- ViT
Paper:
@ -19,8 +19,8 @@ Models:
Epochs: 300
Batch Size: 2048
Results: null
Config: configs/cae/cae_vit-base-p16_16xb128-amp-coslr-300e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/cae/cae_vit-base-p16_16xb128-fp16-coslr-300e_in1k/cae_vit-base-p16_16xb128-fp16-coslr-300e_in1k_20220825-404a1929.pth
Config: configs/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/1.x/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221230-808170f3.pth
Downstream:
- beit-base-p16_cae-pre_8xb128-coslr-100e_in1k
- Name: beit-base-p16_cae-pre_8xb128-coslr-100e_in1k


@ -56,7 +56,6 @@ model = dict(
type='LinearNeck',
in_channels=768,
out_channels=108,
with_avg_pool=False,
init_cfg=dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)),
head=dict(
type='MaskFeatPretrainHead',


@ -11,11 +11,12 @@ model = dict(
arch='base',
img_size=192,
stage_cfgs=dict(block_cfgs=dict(window_size=6))),
neck=dict(type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
neck=dict(
type='SimMIMLinearDecoder', in_channels=128 * 2**3, encoder_stride=32),
head=dict(
type='SimMIMHead',
patch_size=4,
loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
loss=dict(type='PixelReconstructionLoss', criterion='L1', channels=3)))
# optimizer wrapper
optim_wrapper = dict(


@ -11,11 +11,12 @@ model = dict(
arch='base',
img_size=192,
stage_cfgs=dict(block_cfgs=dict(window_size=6))),
neck=dict(type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
neck=dict(
type='SimMIMLinearDecoder', in_channels=128 * 2**3, encoder_stride=32),
head=dict(
type='SimMIMHead',
patch_size=4,
loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
loss=dict(type='PixelReconstructionLoss', criterion='L1', channels=3)))
# optimizer wrapper
optim_wrapper = dict(


@ -12,11 +12,12 @@ model = dict(
img_size=192,
stage_cfgs=dict(block_cfgs=dict(window_size=12)),
pad_small_map=True),
neck=dict(type='SimMIMNeck', in_channels=192 * 2**3, encoder_stride=32),
neck=dict(
type='SimMIMLinearDecoder', in_channels=192 * 2**3, encoder_stride=32),
head=dict(
type='SimMIMHead',
patch_size=4,
loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
loss=dict(type='PixelReconstructionLoss', criterion='L1', channels=3)))
# optimizer wrapper
optim_wrapper = dict(


@ -1,16 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beitv1_head import BEiTV1Head
from .beitv2_head import BEiTV2Head
from .cae_head import CAEHead
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .contrastive_head import ContrastiveHead
from .deit_head import DeiTClsHead
from .efficientformer_head import EfficientFormerClsHead
from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead
from .levit_head import LeViTClsHead
from .linear_head import LinearClsHead
from .mae_head import MAEPretrainHead
from .margin_head import ArcFaceClsHead
from .mixmim_head import MixMIMPretrainHead
from .multi_label_cls_head import MultiLabelClsHead
from .multi_label_csra_head import CSRAClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
from .multi_task_head import MultiTaskHead
from .stacked_head import StackedLinearClsHead
from .swav_head import SwAVHead
from .vig_head import VigClsHead
from .vision_transformer_head import VisionTransformerClsHead
@ -29,4 +37,13 @@ __all__ = [
'MultiTaskHead',
'LeViTClsHead',
'VigClsHead',
'BEiTV1Head',
'BEiTV2Head',
'CAEHead',
'ContrastiveHead',
'LatentCrossCorrelationHead',
'LatentPredictHead',
'MAEPretrainHead',
'MixMIMPretrainHead',
'SwAVHead',
]


@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class BEiTV1Head(BaseHead):
"""Pretrain Head for BEiT v1.
Compute the logits and the cross entropy loss.
Args:
embed_dims (int): The dimension of embedding.
num_embed (int): The number of classification types.
loss (dict): The config of loss.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
embed_dims: int,
num_embed: int,
loss: dict,
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
self.cls_head = nn.Linear(embed_dims, num_embed)
self.loss_module = MODELS.build(loss)
def loss(self, feats: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Generate loss.
Args:
feats (torch.Tensor): Features from backbone.
target (torch.Tensor): Target generated by target_generator.
mask (torch.Tensor): Generated mask for pretraining.
"""
mask = mask.flatten(1).to(torch.bool)
target = torch.argmax(target, dim=1).flatten(1)
target = target[mask]
# remove cls_token
feats = feats[:, 1:]
logits = self.cls_head(feats[mask])
loss = self.loss_module(logits, target)
return loss
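
For illustration, a hedged plain-PyTorch sketch of the masked token-classification objective this head computes (all shapes are assumed, not taken from a config):

import torch
import torch.nn.functional as F

feats = torch.rand(2, 197, 768)               # backbone output incl. cls_token
dalle_logits = torch.rand(2, 8192, 14, 14)    # target_generator output
mask = torch.rand(2, 14, 14) > 0.6            # masked patch positions

flat_mask = mask.flatten(1)                   # B x L
target = dalle_logits.argmax(dim=1).flatten(1)[flat_mask]   # codebook ids of masked patches
cls_head = torch.nn.Linear(768, 8192)
logits = cls_head(feats[:, 1:][flat_mask])    # classify masked tokens only, cls_token removed
loss = F.cross_entropy(logits, target)        # mirrors the CrossEntropyLoss config above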


@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class BEiTV2Head(BaseHead):
"""Pretrain Head for BEiT v2.
Compute the logits and the cross entropy loss.
Args:
embed_dims (int): The dimension of embedding.
num_embed (int): The number of classification types.
loss (dict): The config of loss.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
embed_dims: int,
num_embed: int,
loss: dict,
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
self.cls_head = nn.Linear(embed_dims, num_embed)
self.loss_module = MODELS.build(loss)
def loss(self, feats: torch.Tensor, feats_cls_pt: torch.Tensor,
target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate loss.
Args:
feats (torch.Tensor): Features from backbone.
feats_cls_pt (torch.Tensor): Features of the cls token combined with
early-layer patch tokens, used for pretraining.
target (torch.Tensor): Target generated by target_generator.
mask (torch.Tensor): Generated mask for pretraining.
"""
mask = mask.flatten(1).to(torch.bool)
target = target[mask]
# shared cls head
logits = self.cls_head(feats[mask])
logits_cls_pt = self.cls_head(feats_cls_pt[mask])
loss_1 = self.loss_module(logits, target)
loss_2 = self.loss_module(logits_cls_pt, target)
return loss_1, loss_2


@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
from mmpretrain.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class CAEHead(BaseHead):
"""Pretrain Head for CAE.
Compute the align loss and the main loss. In addition, this head also
generates the prediction target from the logits produced by dalle.
Args:
loss (dict): The config of loss.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
loss: dict,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.loss_module = MODELS.build(loss)
@torch.no_grad()
def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor:
"""Generate the reconstruction target.
Args:
logits_target (torch.Tensor): The logits generated by DALL-E.
Returns:
torch.Tensor: The logits target.
"""
target = torch.argmax(logits_target, dim=1)
return target.flatten(1)
def loss(self, logits: torch.Tensor, logits_target: torch.Tensor,
latent_pred: torch.Tensor, latent_target: torch.Tensor,
mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate loss.
Args:
logits (torch.Tensor): Logits generated by decoder.
logits_target (torch.Tensor): Target generated by dalle for decoder
prediction.
latent_pred (torch.Tensor): Latent prediction by regressor.
latent_target (torch.Tensor): Target for latent prediction,
generated by teacher.
mask (torch.Tensor): Generated mask of the masked patches.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The tuple of loss.
- loss_main (torch.Tensor): Cross entropy loss.
- loss_align (torch.Tensor): MSE loss.
"""
target = self._generate_target(logits_target) # target features
target = target[mask].detach()
# loss main for decoder, loss align for regressor
loss_main, loss_align = self.loss_module(logits, target, latent_pred,
latent_target)
return (loss_main, loss_align)


@ -0,0 +1,50 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
from mmpretrain.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class ContrastiveHead(BaseHead):
"""Head for contrastive learning.
The contrastive loss is implemented in this head and is used in SimCLR,
MoCo, DenseCL, etc.
Args:
loss (dict): Config dict for module of loss functions.
temperature (float): The temperature hyper-parameter that
controls the concentration level of the distribution.
Defaults to 0.1.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
loss: dict,
temperature: float = 0.1,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.loss_module = MODELS.build(loss)
self.temperature = temperature
def loss(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
"""Forward function to compute contrastive loss.
Args:
pos (torch.Tensor): Nx1 positive similarity.
neg (torch.Tensor): Nxk negative similarity.
Returns:
torch.Tensor: The contrastive loss.
"""
N = pos.size(0)
logits = torch.cat((pos, neg), dim=1)
logits /= self.temperature
labels = torch.zeros((N, ), dtype=torch.long).to(pos.device)
loss = self.loss_module(logits, labels)
return loss
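
A plain-PyTorch sketch of the InfoNCE computation above, assuming the configured loss is a cross-entropy loss and using dummy similarity tensors:

import torch
import torch.nn.functional as F

pos = torch.rand(8, 1)                        # Nx1 positive similarity
neg = torch.rand(8, 4096)                     # Nxk negative similarities

logits = torch.cat((pos, neg), dim=1) / 0.1   # temperature = 0.1
labels = torch.zeros(8, dtype=torch.long)     # the positive is always at index 0
loss = F.cross_entropy(logits, labels)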


@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.dist import all_reduce, get_world_size
from mmpretrain.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class LatentPredictHead(BaseHead):
"""Head for latent feature prediction.
This head builds a predictor, which can be any registered neck component.
For example, BYOL and SimSiam call this head and build NonLinearNeck.
It also implements similarity loss between two forward features.
Args:
loss (dict): Config dict for the loss.
predictor (dict): Config dict for the predictor.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
loss: dict,
predictor: dict,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.loss = MODELS.build(loss)
self.predictor = MODELS.build(predictor)
def forward(self, input: torch.Tensor,
target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward head.
Args:
input (torch.Tensor): NxC input features.
target (torch.Tensor): NxC target features.
Returns:
torch.Tensor: The latent predict loss.
"""
pred = self.predictor([input])[0]
target = target.detach()
loss = self.loss(pred, target)
return loss
@MODELS.register_module()
class LatentCrossCorrelationHead(BaseHead):
"""Head for latent feature cross correlation.
Part of the code is borrowed from `script
<https://github.com/facebookresearch/barlowtwins/blob/main/main.py>`_.
Args:
in_channels (int): Number of input channels.
loss (dict): Config dict for module of loss functions.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: int,
loss: dict,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.world_size = get_world_size()
self.bn = nn.BatchNorm1d(in_channels, affine=False)
self.loss = MODELS.build(loss)
def forward(self, input: torch.Tensor,
target: torch.Tensor) -> torch.Tensor:
"""Forward head.
Args:
input (torch.Tensor): NxC input features.
target (torch.Tensor): NxC target features.
Returns:
torch.Tensor: The cross correlation loss.
"""
# cross-correlation matrix
cross_correlation_matrix = self.bn(input).T @ self.bn(target)
cross_correlation_matrix.div_(input.size(0) * self.world_size)
all_reduce(cross_correlation_matrix)
loss = self.loss(cross_correlation_matrix)
return loss
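
A single-process sketch (world_size assumed to be 1, so all_reduce is a no-op) of the cross-correlation matrix this head hands to CrossCorrelationLoss:

import torch
import torch.nn as nn

N, C = 32, 256                                # assumed batch size and feature dim
z1, z2 = torch.rand(N, C), torch.rand(N, C)   # features of the two augmented views

bn = nn.BatchNorm1d(C, affine=False)
cross_corr = bn(z1).T @ bn(z2) / N            # C x C cross-correlation matrix
# cross_corr is then scored by CrossCorrelationLoss (added below in this commit)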


@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpretrain.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class MAEPretrainHead(BaseHead):
"""Pre-training head for MAE.
Args:
loss (dict): Config of loss.
norm_pix (bool): Whether or not to normalize the target.
Defaults to False.
patch_size (int): Patch size. Defaults to 16.
"""
def __init__(self,
loss: dict,
norm_pix: bool = False,
patch_size: int = 16) -> None:
super().__init__()
self.norm_pix = norm_pix
self.patch_size = patch_size
self.loss_module = MODELS.build(loss)
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
"""Split images into non-overlapped patches.
Args:
imgs (torch.Tensor): A batch of images of shape B x C x H x W.
Returns:
torch.Tensor: Patchified images. The shape is B x L x D.
"""
p = self.patch_size
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
"""Combine non-overlapped patches into images.
Args:
x (torch.Tensor): The shape is (N, L, patch_size**2 *3)
Returns:
imgs (torch.Tensor): The shape is (N, 3, H, W)
"""
p = self.patch_size
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
def construct_target(self, target: torch.Tensor) -> torch.Tensor:
"""Construct the reconstruction target.
In addition to splitting images into tokens, this module will also
normalize the image according to ``norm_pix``.
Args:
target (torch.Tensor): Image with the shape of B x 3 x H x W
Returns:
torch.Tensor: Tokenized images with the shape of B x L x C
"""
target = self.patchify(target)
if self.norm_pix:
# normalize the target image
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
return target
def loss(self, pred: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Forward function of MAE head.
Args:
pred (torch.Tensor): The reconstructed image.
target (torch.Tensor): The target image.
mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
target = self.construct_target(target)
loss = self.loss_module(pred, target, mask)
return loss
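
The patchify/unpatchify pair above is exactly invertible; a short sanity-check sketch with an assumed 224x224 input and patch size 16:

import torch

p, imgs = 16, torch.rand(2, 3, 224, 224)
h = w = imgs.shape[2] // p                    # 14 patches per side

x = imgs.reshape(2, 3, h, p, w, p)
x = torch.einsum('nchpwq->nhwpqc', x)
patches = x.reshape(2, h * w, p * p * 3)      # B x 196 x 768, as in patchify()

y = patches.reshape(2, h, w, p, p, 3)
y = torch.einsum('nhwpqc->nchpwq', y)
recovered = y.reshape(2, 3, h * p, w * p)
assert torch.equal(imgs, recovered)           # patchify and unpatchify are inverses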


@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
# TODO: delete and use NaiveMIMHead
@MODELS.register_module()
class MaskFeatPretrainHead(BaseModule):
"""Pre-training head for MaskFeat.
It computes reconstruction loss between prediction and target in masked
region.
Args:
loss (dict): Config dict for module of loss functions.
"""
def __init__(self, loss: dict) -> None:
super().__init__()
self.loss = MODELS.build(loss)
def forward(self, pred: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Forward head.
Args:
pred (torch.Tensor): Predictions,
which is of shape B x (1 + L) x C.
target (torch.Tensor): Hog features, which is of shape B x L x C.
mask (torch.Tensor): The mask of the hog features,
which is of shape B x H x W.
Returns:
torch.Tensor: The loss tensor.
"""
mask = mask.flatten(1).bool()
loss = self.loss(pred[:, 1:], target, mask)
return loss
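
A hedged plain-PyTorch sketch of the masked HOG regression this head performs, assuming an L2 criterion and 108-dimensional HOG targets as in the updated MaskFeat config above:

import torch

pred = torch.rand(2, 197, 108)                # predictions, incl. cls token
target = torch.rand(2, 196, 108)              # HOG features
mask = (torch.rand(2, 14, 14) > 0.4).float()  # masked patch positions

mask = mask.flatten(1)                        # B x L
err = (pred[:, 1:] - target).pow(2).mean(dim=-1)   # assumed L2 criterion, cls token dropped
loss = (err * mask).sum() / mask.sum()        # averaged over masked patches only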


@ -0,0 +1,49 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpretrain.registry import MODELS
from .mae_head import MAEPretrainHead
@MODELS.register_module()
class MixMIMPretrainHead(MAEPretrainHead):
"""MixMIM pretrain head.
Args:
loss (dict): Config of loss.
norm_pix (bool): Whether or not to normalize the target.
Defaults to False.
patch_size (int): Patch size. Defaults to 16.
"""
def __init__(self,
loss: dict,
norm_pix: bool = False,
patch_size: int = 16) -> None:
super().__init__(loss=loss, norm_pix=norm_pix, patch_size=patch_size)
def loss(self, x_rec: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Forward function of MixMIM head.
Args:
x_rec (torch.Tensor): The reconstructed image.
target (torch.Tensor): The target image.
mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
target = self.construct_target(target)
B, L, C = x_rec.shape
# unmix tokens
x1_rec = x_rec[:B // 2]
x2_rec = x_rec[B // 2:]
unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask)
loss_rec = self.loss_module(unmix_x_rec, target)
return loss_rec


@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.dist import get_rank
from mmpretrain.registry import MODELS
from mmpretrain.utils import concat_all_gather
from .base_head import BaseHead
@MODELS.register_module()
class MoCoV3Head(BaseHead):
"""Head for MoCo v3 algorithms.
This head builds a predictor, which can be any registered neck component.
It also implements latent contrastive loss between two forward features.
Part of the code is modified from:
`<https://github.com/facebookresearch/moco-v3/blob/main/moco/builder.py>`_.
Args:
predictor (dict): Config dict for module of predictor.
loss (dict): Config dict for module of loss functions.
temperature (float): The temperature hyper-parameter that
controls the concentration level of the distribution.
Defaults to 1.0.
"""
def __init__(self,
predictor: dict,
loss: dict,
temperature: float = 1.0) -> None:
super().__init__()
self.predictor = MODELS.build(predictor)
self.loss_module = MODELS.build(loss)
self.temperature = temperature
def loss(self, base_out: torch.Tensor,
momentum_out: torch.Tensor) -> torch.Tensor:
"""Forward head.
Args:
base_out (torch.Tensor): NxC features from base_encoder.
momentum_out (torch.Tensor): NxC features from momentum_encoder.
Returns:
torch.Tensor: The loss tensor.
"""
# predictor computation
pred = self.predictor([base_out])[0]
# normalize
pred = nn.functional.normalize(pred, dim=1)
target = nn.functional.normalize(momentum_out, dim=1)
# get negative samples
target = concat_all_gather(target)
# Einstein sum is more intuitive
logits = torch.einsum('nc,mc->nm', [pred, target]) / self.temperature
# generate labels
batch_size = logits.shape[0]
labels = (torch.arange(batch_size, dtype=torch.long) +
batch_size * get_rank()).to(logits.device)
loss = self.loss_module(logits, labels)
return loss
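
A single-GPU sketch (rank 0, so concat_all_gather(target) is just target) of the logits/labels construction above, with assumed shapes:

import torch
import torch.nn.functional as F

pred = F.normalize(torch.rand(16, 256), dim=1)    # predictor output
target = F.normalize(torch.rand(16, 256), dim=1)  # momentum encoder output

logits = torch.einsum('nc,mc->nm', pred, target) / 1.0   # temperature = 1.0
labels = torch.arange(16, dtype=torch.long)              # positives on the diagonal
loss = F.cross_entropy(logits, labels)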


@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
# TODO: delete and use NaiveMIMHead
@MODELS.register_module()
class SimMIMHead(BaseModule):
"""Pretrain Head for SimMIM.
Args:
patch_size (int): Patch size of each token.
loss (dict): The config for loss.
"""
def __init__(self, patch_size: int, loss: dict) -> None:
super().__init__()
self.patch_size = patch_size
self.loss = MODELS.build(loss)
def forward(self, pred: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Forward function of SimMIM head.
This method will expand mask to the size of the original image.
Args:
pred (torch.Tensor): The reconstructed image (B, C, H, W).
target (torch.Tensor): The target image (B, C, H, W).
mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(
self.patch_size, 2).unsqueeze(1).contiguous()
loss = self.loss(pred, target, mask)
return loss
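
A plain-PyTorch sketch of how the patch-level mask is expanded to pixel resolution and combined with the masked L1 pixel loss (192x192 inputs assumed, as in the SimMIM configs above):

import torch

patch_size = 4
pred = torch.rand(2, 3, 192, 192)             # reconstructed image
target = torch.rand(2, 3, 192, 192)
mask = (torch.rand(2, 48, 48) > 0.4).float()  # one entry per 4x4 patch

mask = mask.repeat_interleave(patch_size, 1).repeat_interleave(
    patch_size, 2).unsqueeze(1)               # B x 1 x 192 x 192 pixel-level mask
loss = (torch.abs(pred - target) * mask).sum() / mask.sum() / 3   # masked L1, averaged over 3 channels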


@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmpretrain.registry import MODELS
from .base_head import BaseHead
@MODELS.register_module()
class SwAVHead(BaseHead):
"""Head for SwAV.
Args:
loss (dict): Config dict for module of loss functions.
"""
def __init__(self, loss: dict) -> None:
super().__init__()
self.loss_module = MODELS.build(loss)
def loss(self, pred: torch.Tensor) -> torch.Tensor:
"""Forward function of SwAV head.
Args:
pred (torch.Tensor): NxC input features.
Returns:
torch.Tensor: The SwAV loss.
"""
loss = self.loss_module(pred)
return loss


@ -1,16 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
from .cae_loss import CAELoss
from .cosine_similarity_loss import CosineSimilarityLoss
from .cross_correlation_loss import CrossCorrelationLoss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy)
from .focal_loss import FocalLoss, sigmoid_focal_loss
from .label_smooth_loss import LabelSmoothLoss
from .reconstruction_loss import PixelReconstructionLoss
from .seesaw_loss import SeesawLoss
from .swav_loss import SwAVLoss
from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss,
weighted_loss)
__all__ = [
'asymmetric_loss', 'AsymmetricLoss', 'cross_entropy',
'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
'sigmoid_focal_loss', 'convert_to_one_hot', 'SeesawLoss'
'asymmetric_loss',
'AsymmetricLoss',
'cross_entropy',
'binary_cross_entropy',
'CrossEntropyLoss',
'reduce_loss',
'weight_reduce_loss',
'LabelSmoothLoss',
'weighted_loss',
'FocalLoss',
'sigmoid_focal_loss',
'convert_to_one_hot',
'SeesawLoss',
'CAELoss',
'CosineSimilarityLoss',
'CrossCorrelationLoss',
'PixelReconstructionLoss',
'SwAVLoss',
]


@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
from mmengine.model import BaseModule
from torch import nn
from mmpretrain.registry import MODELS
@MODELS.register_module()
class CAELoss(BaseModule):
"""Loss function for CAE.
Compute the align loss and the main loss.
Args:
lambd (float): The weight for the align loss.
"""
def __init__(self, lambd: float) -> None:
super().__init__()
self.lambd = lambd
self.loss_cross_entropy = nn.CrossEntropyLoss()
self.loss_mse = nn.MSELoss()
def forward(
self, logits: torch.Tensor, target: torch.Tensor,
latent_pred: torch.Tensor,
latent_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function of CAE Loss.
Args:
logits (torch.Tensor): The outputs from the decoder.
target (torch.Tensor): The targets generated by dalle.
latent_pred (torch.Tensor): The latent prediction from the
regressor.
latent_target (torch.Tensor): The latent target from the teacher
network.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The main loss and align loss.
"""
loss_main = self.loss_cross_entropy(logits, target)
loss_align = self.loss_mse(latent_pred,
latent_target.detach()) * self.lambd
return loss_main, loss_align
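
A hedged sketch of the two terms with dummy tensors; the lambd value is assumed here and normally comes from the CAE config:

import torch
import torch.nn as nn

lambd = 2.0                                   # assumed weight for the align loss
logits = torch.rand(98, 8192)                 # decoder logits for masked tokens
target = torch.randint(0, 8192, (98,))        # dalle codebook ids for masked tokens
latent_pred = torch.rand(2, 49, 768)          # regressor output
latent_target = torch.rand(2, 49, 768)        # teacher output

loss_main = nn.CrossEntropyLoss()(logits, target)
loss_align = nn.MSELoss()(latent_pred, latent_target.detach()) * lambd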


@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.model import BaseModule
from torch import nn
from mmpretrain.registry import MODELS
@MODELS.register_module()
class CosineSimilarityLoss(BaseModule):
"""Cosine similarity loss function.
Compute the similarity between two features and optimize that similarity as
loss.
Args:
shift_factor (float): The shift factor of cosine similarity.
Default: 0.0.
scale_factor (float): The scale factor of cosine similarity.
Default: 1.0.
"""
def __init__(self,
shift_factor: float = 0.0,
scale_factor: float = 1.0) -> None:
super().__init__()
self.shift_factor = shift_factor
self.scale_factor = scale_factor
def forward(self,
pred: torch.Tensor,
target: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward function of cosine similarity loss.
Args:
pred (torch.Tensor): The predicted features.
target (torch.Tensor): The target features.
Returns:
torch.Tensor: The cosine similarity loss.
"""
pred_norm = nn.functional.normalize(pred, dim=-1)
target_norm = nn.functional.normalize(target, dim=-1)
loss = self.shift_factor - self.scale_factor * (
pred_norm * target_norm).sum(dim=-1)
if mask is None:
loss = loss.mean()
else:
loss = (loss * mask).sum() / mask.sum()
return loss
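
A short sketch of the masked cosine-similarity term with the default shift and scale factors and dummy shapes:

import torch
import torch.nn.functional as F

pred = torch.rand(2, 196, 768)
target = torch.rand(2, 196, 768)
mask = (torch.rand(2, 196) > 0.4).float()

sim = (F.normalize(pred, dim=-1) * F.normalize(target, dim=-1)).sum(dim=-1)
loss = 0.0 - 1.0 * sim                        # shift_factor=0.0, scale_factor=1.0
loss = (loss * mask).sum() / mask.sum()       # average over masked positions only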


@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class CrossCorrelationLoss(BaseModule):
"""Cross correlation loss function.
Compute the on-diagonal and off-diagonal loss.
Args:
lambd (float): The weight for the off-diag loss.
"""
def __init__(self, lambd: float = 0.0051) -> None:
super().__init__()
self.lambd = lambd
def forward(self, cross_correlation_matrix: torch.Tensor) -> torch.Tensor:
"""Forward function of cross correlation loss.
Args:
cross_correlation_matrix (torch.Tensor): The cross correlation
matrix.
Returns:
torch.Tensor: cross correlation loss.
"""
# loss
on_diag = torch.diagonal(cross_correlation_matrix).add_(-1).pow_(
2).sum()
off_diag = self.off_diagonal(cross_correlation_matrix).pow_(2).sum()
loss = on_diag + self.lambd * off_diag
return loss
def off_diagonal(self, x: torch.Tensor) -> torch.Tensor:
"""Return a flattened view of the off-diagonal elements of a square
matrix."""
n, m = x.shape
assert n == m
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
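
The off_diagonal view trick above, spelled out on a small assumed matrix:

import torch

c = torch.rand(4, 4)                          # cross-correlation matrix
n = c.shape[0]

on_diag = torch.diagonal(c).add(-1).pow(2).sum()                    # push diagonal towards 1
off_diag = c.flatten()[:-1].view(n - 1, n + 1)[:, 1:].pow(2).sum()  # push off-diagonal towards 0
loss = on_diag + 0.0051 * off_diag            # default lambd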


@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class PixelReconstructionLoss(BaseModule):
"""Loss for the reconstruction of pixel in Masked Image Modeling.
This module measures the distance between the target image and the
reconstructed image and computes the loss to optimize the model. Currently,
this module only provides L1 and L2 loss to penalize the reconstruction
error. In addition, a mask can be passed in the ``forward`` function to
apply the loss only on the masked region, as in MAE.
Args:
criterion (str): The loss to penalize the reconstruction error.
Currently, only L1 and L2 loss are supported.
channel (int, optional): The number of channels to average the
reconstruction loss. If not None, the reconstruction loss
will be divided by the channel. Defaults to None.
"""
def __init__(self, criterion: str, channel: Optional[int] = None) -> None:
super().__init__()
if criterion == 'L1':
self.penalty = torch.nn.L1Loss(reduction='none')
elif criterion == 'L2':
self.penalty = torch.nn.MSELoss(reduction='none')
else:
raise NotImplementedError(f'Currently, PixelReconstructionLoss \
only supports L1 and L2 loss, but get {criterion}')
self.channel = channel if channel is not None else 1
def forward(self,
pred: torch.Tensor,
target: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward function to compute the reconstruction loss.
Args:
pred (torch.Tensor): The reconstructed image.
target (torch.Tensor): The target image.
mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
loss = self.penalty(pred, target)
# if the dim of the loss is 3, take the average of the loss
# along the last dim
if len(loss.shape) == 3:
loss = loss.mean(dim=-1)
if mask is None:
loss = loss.mean()
else:
loss = (loss * mask).sum() / mask.sum() / self.channel
return loss
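
A usage sketch of the unified loss, assuming mmpretrain with this refactor is importable; the shapes mimic the MAE case (B x L x D patch predictions):

import torch
from mmpretrain.models.losses import PixelReconstructionLoss

criterion = PixelReconstructionLoss(criterion='L2')   # MAE-style per-patch MSE
pred = torch.rand(2, 196, 768)
target = torch.rand(2, 196, 768)
mask = (torch.rand(2, 196) > 0.25).float()
loss = criterion(pred, target, mask)                  # averaged over masked patches only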


@ -0,0 +1,190 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmengine.dist import all_reduce
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@torch.no_grad()
def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int,
world_size: int, epsilon: float) -> torch.Tensor:
"""Apply the distributed sinkhorn optimization on the scores matrix to find
the assignments.
This function is modified from
https://github.com/facebookresearch/swav/blob/main/main_swav.py
Args:
out (torch.Tensor): The scores matrix
sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp
algorithm.
world_size (int): The world size of the process group.
epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm.
Returns:
torch.Tensor: Output of sinkhorn algorithm.
"""
eps_num_stab = 1e-12
Q = torch.exp(out / epsilon).t(
) # Q is K-by-B for consistency with notations from our paper
B = Q.shape[1] * world_size # number of samples to assign
K = Q.shape[0] # how many prototypes
# make the matrix sums to 1
sum_Q = torch.sum(Q)
all_reduce(sum_Q)
Q /= sum_Q
for it in range(sinkhorn_iterations):
# normalize each row: total weight per prototype must be 1/K
u = torch.sum(Q, dim=1, keepdim=True)
if len(torch.nonzero(u == 0)) > 0:
Q += eps_num_stab
u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype)
all_reduce(u)
Q /= u
Q /= K
# normalize each column: total weight per sample must be 1/B
Q /= torch.sum(Q, dim=0, keepdim=True)
Q /= B
Q *= B # the columns must sum to 1 so that Q is an assignment
return Q.t()
class MultiPrototypes(BaseModule):
"""Multi-prototypes for SwAV head.
Args:
output_dim (int): The output dim from SwAV neck.
num_prototypes (List[int]): The number of prototypes needed.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
output_dim: int,
num_prototypes: List[int],
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(init_cfg=init_cfg)
assert isinstance(num_prototypes, list)
self.num_heads = len(num_prototypes)
for i, k in enumerate(num_prototypes):
self.add_module('prototypes' + str(i),
nn.Linear(output_dim, k, bias=False))
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Run forward for every prototype."""
out = []
for i in range(self.num_heads):
out.append(getattr(self, 'prototypes' + str(i))(x))
return out
@MODELS.register_module()
class SwAVLoss(BaseModule):
"""The Loss for SwAV.
This Loss contains clustering and sinkhorn algorithms to compute Q codes.
Part of the code is borrowed from `script
<https://github.com/facebookresearch/swav>`_.
The queue is built in `engine/hooks/swav_hook.py`.
Args:
feat_dim (int): feature dimension of the prototypes.
sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp
algorithm. Defaults to 3.
epsilon (float): regularization parameter for Sinkhorn-Knopp algorithm.
Defaults to 0.05.
temperature (float): temperature parameter in training loss.
Defaults to 0.1.
crops_for_assign (List[int]): list of crops id used for computing
assignments. Defaults to [0, 1].
num_crops (List[int]): list of number of crops. Defaults to [2].
num_prototypes (int or List[int]): number of prototypes. Defaults to 3000.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
feat_dim: int,
sinkhorn_iterations: int = 3,
epsilon: float = 0.05,
temperature: float = 0.1,
crops_for_assign: List[int] = [0, 1],
num_crops: List[int] = [2],
num_prototypes: int = 3000,
init_cfg: Optional[Union[List[dict], dict]] = None):
super().__init__(init_cfg=init_cfg)
self.sinkhorn_iterations = sinkhorn_iterations
self.epsilon = epsilon
self.temperature = temperature
self.crops_for_assign = crops_for_assign
self.num_crops = num_crops
self.use_queue = False
self.queue = None
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
# prototype layer
self.prototypes = None
if isinstance(num_prototypes, list):
self.prototypes = MultiPrototypes(feat_dim, num_prototypes)
elif num_prototypes > 0:
self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False)
assert self.prototypes is not None
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function of SwAV loss.
Args:
x (torch.Tensor): NxC input features.
Returns:
torch.Tensor: The returned loss.
"""
# normalize the prototypes
with torch.no_grad():
w = self.prototypes.weight.data.clone()
w = nn.functional.normalize(w, dim=1, p=2)
self.prototypes.weight.copy_(w)
embedding, output = x, self.prototypes(x)
embedding = embedding.detach()
bs = int(embedding.size(0) / sum(self.num_crops))
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[bs * crop_id:bs * (crop_id + 1)].detach()
# time to use the queue
if self.queue is not None:
if self.use_queue or not torch.all(self.queue[i,
-1, :] == 0):
self.use_queue = True
out = torch.cat(
(torch.mm(self.queue[i],
self.prototypes.weight.t()), out))
# fill the queue
self.queue[i, bs:] = self.queue[i, :-bs].clone()
self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) *
bs]
# get assignments (batch_size * num_prototypes)
q = distributed_sinkhorn(out, self.sinkhorn_iterations,
self.world_size, self.epsilon)[-bs:]
# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(self.num_crops)), crop_id):
x = output[bs * v:bs * (v + 1)] / self.temperature
subloss -= torch.mean(
torch.sum(q * nn.functional.log_softmax(x, dim=1), dim=1))
loss += subloss / (np.sum(self.num_crops) - 1)
loss /= len(self.crops_for_assign)
return loss
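
A single-process sketch (world_size assumed 1, the numerical-stability branch skipped) of the Sinkhorn-Knopp normalization used above:

import torch

scores = torch.rand(8, 300)                   # B x K prototype scores for one crop
epsilon, iters = 0.05, 3

Q = torch.exp(scores / epsilon).t()           # K x B
Q /= Q.sum()                                  # make the matrix sum to 1
K, B = Q.shape
for _ in range(iters):
    Q /= Q.sum(dim=1, keepdim=True)
    Q /= K                                    # equal mass per prototype
    Q /= Q.sum(dim=0, keepdim=True)
    Q /= B                                    # equal mass per sample
Q *= B                                        # columns sum to 1 -> soft assignments
assignments = Q.t()                           # B x K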


@ -1,10 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .beitv2_neck import BEiTV2Neck
from .cae_neck import CAENeck
from .densecl_neck import DenseCLNeck
from .gap import GlobalAveragePooling
from .gem import GeneralizedMeanPooling
from .hr_fuse import HRFuseScales
from .reduction import LinearReduction
from .linear_neck import LinearNeck
from .mae_neck import MAEPretrainDecoder
from .milan_neck import MILANPretrainDecoder
from .mixmim_neck import MixMIMPretrainDecoder
from .mocov2_neck import MoCoV2Neck
from .nonlinear_neck import NonLinearNeck
from .simmim_neck import SimMIMLinearDecoder
__all__ = [
'GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales',
'LinearReduction'
'GlobalAveragePooling',
'GeneralizedMeanPooling',
'HRFuseScales',
'LinearNeck',
'BEiTV2Neck',
'CAENeck',
'DenseCLNeck',
'MAEPretrainDecoder',
'MILANPretrainDecoder',
'MixMIMPretrainDecoder',
'MoCoV2Neck',
'NonLinearNeck',
'SimMIMLinearDecoder',
]


@ -0,0 +1,153 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS
@MODELS.register_module()
class BEiTV2Neck(BaseModule):
"""Neck for BEiTV2 Pre-training.
This module constructs the decoder for the final prediction.
Args:
num_layers (int): Number of encoder layers of neck. Defaults to 2.
early_layers (int): The layer index of the early output from the
backbone. Defaults to 9.
backbone_arch (str): Vision Transformer architecture. Defaults to base.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): The initialization value for the
learnable scaling of attention and FFN. Defaults to 0.1.
use_rel_pos_bias (bool): Whether to use unique relative position bias,
if False, use shared relative position bias defined in backbone.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'depth': 12,
'num_heads': 12,
'feedforward_channels': 3072,
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'depth': 24,
'num_heads': 16,
'feedforward_channels': 4096,
}),
}
def __init__(
self,
num_layers: int = 2,
early_layers: int = 9,
backbone_arch: str = 'base',
drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: float = 0.1,
use_rel_pos_bias: bool = False,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
if isinstance(backbone_arch, str):
backbone_arch = backbone_arch.lower()
assert backbone_arch in set(self.arch_zoo), \
(f'Arch {backbone_arch} is not in default archs '
f'{set(self.arch_zoo)}')
self.arch_settings = self.arch_zoo[backbone_arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(backbone_arch, dict) and essential_keys <= set(
backbone_arch
), f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = backbone_arch
# stochastic depth decay rule
self.early_layers = early_layers
depth = self.arch_settings['depth']
dpr = np.linspace(0, drop_path_rate,
max(depth, early_layers + num_layers))
self.patch_aggregation = nn.ModuleList()
for i in range(early_layers, early_layers + num_layers):
_layer_cfg = dict(
embed_dims=self.arch_settings['embed_dims'],
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
norm_cfg=norm_cfg,
layer_scale_init_value=layer_scale_init_value,
window_size=None,
use_rel_pos_bias=use_rel_pos_bias)
self.patch_aggregation.append(
BEiTTransformerEncoderLayer(**_layer_cfg))
self.rescale_patch_aggregation_init_weight()
embed_dims = self.arch_settings['embed_dims']
_, norm = build_norm_layer(norm_cfg, embed_dims)
self.add_module('norm', norm)
def rescale_patch_aggregation_init_weight(self):
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.patch_aggregation):
rescale(layer.attn.proj.weight.data,
self.early_layers + layer_id + 1)
rescale(layer.ffn.layers[1].weight.data,
self.early_layers + layer_id + 1)
def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the latent prediction and final prediction.
Args:
x (Tuple[torch.Tensor]): Features of tokens.
rel_pos_bias (torch.Tensor): Shared relative position bias table.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- ``x``: The final layer features from backbone, which are normed
in ``BEiTV2Neck``.
- ``x_cls_pt``: The early state features from backbone, which consist
of the final layer cls_token and early state patch_tokens from the
backbone, and are sent to the PatchAggregation layers in the neck.
"""
early_states, x = inputs[0], inputs[1]
x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
for layer in self.patch_aggregation:
x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)
# shared norm
x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)
# remove cls_token
x = x[:, 1:]
x_cls_pt = x_cls_pt[:, 1:]
return x, x_cls_pt


@ -0,0 +1,273 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS
from ..utils import CrossMultiheadAttention
class CAETransformerRegressorLayer(BaseModule):
"""Transformer layer for the regressor of CAE.
This module is different from conventional transformer encoder layer, for
its queries are the masked tokens, but its keys and values are the
concatenation of the masked and unmasked tokens.
Args:
embed_dims (int): The feature dimension.
num_heads (int): The number of heads in multi-head attention.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 1024.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Default: 2.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
drop_rate (float): The dropout rate. Defaults to 0.0.
attn_drop_rate (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): The init value of gamma.
Defaults to 0.0.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
"""
def __init__(
self,
embed_dims: int,
num_heads: int,
feedforward_channels: int,
num_fcs: int = 2,
qkv_bias: bool = False,
qk_scale: float = None,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
layer_scale_init_value: float = 0.0,
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN', eps=1e-6)
) -> None:
super().__init__()
# NOTE: cross attention
_, self.norm1_q_cross = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
_, self.norm1_k_cross = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
_, self.norm1_v_cross = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
_, self.norm2_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2)
self.cross_attn = CrossMultiheadAttention(
embed_dims,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop_rate,
proj_drop=drop_rate)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=None,
act_cfg=act_cfg,
add_identity=False)
self.drop_path = DropPath(drop_prob=drop_path_rate)
if layer_scale_init_value > 0:
self.gamma_1_cross = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True)
self.gamma_2_cross = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True)
else:
self.gamma_1_cross = nn.Parameter(
torch.ones((embed_dims)), requires_grad=False)
self.gamma_2_cross = nn.Parameter(
torch.ones((embed_dims)), requires_grad=False)
def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor,
pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor:
"""Forward function."""
x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn(
self.norm1_q_cross(x_q + pos_q),
k=self.norm1_k_cross(x_kv + pos_k),
v=self.norm1_v_cross(x_kv)))
x = self.norm2_cross(x)
x = x + self.drop_path(self.gamma_2_cross * self.ffn(x))
return x
@MODELS.register_module()
class CAENeck(BaseModule):
"""Neck for CAE Pre-training.
This module constructs the latent prediction regressor and the decoder
for the latent prediction and final prediction.
Args:
num_classes (int): The number of classes for final prediction. Defaults
to 8192.
embed_dims (int): The embed dims of latent feature in regressor and
decoder. Defaults to 768.
regressor_depth (int): The number of regressor blocks. Defaults to 6.
decoder_depth (int): The number of decoder blocks. Defaults to 8.
num_heads (int): The number of head in multi-head attention. Defaults
to 12.
mlp_ratio (int): The expand ratio of latent features in MLP. Defaults
to 4.
qkv_bias (bool): Whether or not to use qkv bias. Defaults to True.
qk_scale (float, optional): The scale applied to the results of qk.
Defaults to None.
drop_rate (float): The dropout rate. Defaults to 0.
attn_drop_rate (float): The dropout rate in attention block. Defaults
to 0.
norm_cfg (dict): The config of normalization layer. Defaults to
dict(type='LN', eps=1e-6).
layer_scale_init_value (float, optional): The init value of gamma.
Defaults to None.
mask_tokens_num (int): The number of mask tokens. Defaults to 75.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
num_classes: int = 8192,
embed_dims: int = 768,
regressor_depth: int = 6,
decoder_depth: int = 8,
num_heads: int = 12,
mlp_ratio: int = 4,
qkv_bias: bool = True,
qk_scale: float = None,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_cfg: dict = dict(type='LN', eps=1e-6),
layer_scale_init_value: float = None,
mask_tokens_num: int = 75,
init_cfg: dict = None) -> None:
super().__init__(init_cfg=init_cfg)
self.num_features = self.embed_dim = embed_dims
self.mask_token_num = mask_tokens_num
# regressor
regressor_drop_path_rates = [
x.item()
for x in torch.linspace(0, drop_path_rate, regressor_depth)
]
self.regressors = nn.ModuleList([
CAETransformerRegressorLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=regressor_drop_path_rates[i],
norm_cfg=norm_cfg,
layer_scale_init_value=layer_scale_init_value)
for i in range(regressor_depth)
])
# decoder
decoder_drop_path_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, decoder_depth)
]
self.decoders = nn.ModuleList([
BEiTTransformerEncoderLayer(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=mlp_ratio * embed_dims,
layer_scale_init_value=layer_scale_init_value,
window_size=None,
# setting `use_rel_pos_bias` to False ignores the `window_size`
use_rel_pos_bias=False,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=decoder_drop_path_rates[i],
norm_cfg=norm_cfg) for i in range(decoder_depth)
])
_, self.norm_regressor = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
_, self.norm_decoder = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.head = nn.Linear(
embed_dims, num_classes) if num_classes > 0 else nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
def init_weights(self) -> None:
"""Initialization."""
super().init_weights()
self.apply(self._init_weights)
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.head.weight, std=0.02)
def _init_weights(self, m: nn.Module) -> None:
"""Initialization."""
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(
self, x_unmasked: torch.Tensor, pos_embed_masked: torch.Tensor,
pos_embed_unmasked: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get the latent prediction and final prediction.
Args:
x_unmasked (torch.Tensor): Features of unmasked tokens.
pos_embed_masked (torch.Tensor): Position embedding of masked
tokens.
pos_embed_unmasked (torch.Tensor): Position embedding of unmasked
tokens.
Returns:
Tuple[torch.Tensor, torch.Tensor]:
- ``logits``: Final prediction.
- ``latent_pred``: Latent prediction.
"""
x_masked = self.mask_token.expand(x_unmasked.shape[0],
self.mask_token_num, -1)
# regressor
for regressor in self.regressors:
x_masked = regressor(
x_masked, torch.cat([x_unmasked, x_masked], dim=1),
pos_embed_masked,
torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1))
x_masked = self.norm_regressor(x_masked)
latent_pred = x_masked
# decoder
x_masked = x_masked + pos_embed_masked
for decoder in self.decoders:
x_masked = decoder(x_masked)
x_masked = self.norm_decoder(x_masked)
logits = self.head(x_masked)
return logits, latent_pred


@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class DenseCLNeck(BaseModule):
"""The non-linear neck of DenseCL.
Single and dense neck in parallel: fc-relu-fc, conv-relu-conv.
Borrowed from the authors' `code <https://github.com/WXinlong/DenseCL>`_.
Args:
in_channels (int): Number of input channels.
hid_channels (int): Number of hidden channels.
out_channels (int): Number of output channels.
num_grid (int): The grid size of dense features. Defaults to None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: int,
hid_channels: int,
out_channels: int,
num_grid: Optional[int] = None,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.mlp = nn.Sequential(
nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True),
nn.Linear(hid_channels, out_channels))
self.with_pool = True if num_grid is not None else False
if self.with_pool:
self.pool = nn.AdaptiveAvgPool2d((num_grid, num_grid))
self.mlp2 = nn.Sequential(
nn.Conv2d(in_channels, hid_channels, 1), nn.ReLU(inplace=True),
nn.Conv2d(hid_channels, out_channels, 1))
self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
"""Forward function of neck.
Args:
x (Tuple[torch.Tensor]): feature map of backbone.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- ``avgpooled_x``: Global feature vectors.
- ``x``: Dense feature vectors.
- ``avgpooled_x2``: Dense feature vectors for queue.
"""
assert len(x) == 1
x = x[0]
avgpooled_x = self.avgpool(x)
avgpooled_x = self.mlp(avgpooled_x.view(avgpooled_x.size(0), -1))
if self.with_pool:
x = self.pool(x) # sxs
x = self.mlp2(x) # sxs: bxdxsxs
avgpooled_x2 = self.avgpool2(x) # 1x1: bxdx1x1
x = x.view(x.size(0), x.size(1), -1) # bxdxs^2
avgpooled_x2 = avgpooled_x2.view(avgpooled_x2.size(0), -1) # bxd
return avgpooled_x, x, avgpooled_x2
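
A shape walk-through of the three outputs, assuming mmpretrain is importable and ResNet-50 stage-4 features:

import torch
from mmpretrain.models.necks import DenseCLNeck

neck = DenseCLNeck(in_channels=2048, hid_channels=2048, out_channels=128)
feats = (torch.rand(2, 2048, 7, 7),)
global_q, dense_q, pooled_dense_q = neck(feats)
# global_q: (2, 128), dense_q: (2, 128, 49), pooled_dense_q: (2, 128)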


@ -11,12 +11,14 @@ from mmpretrain.registry import MODELS
@MODELS.register_module()
class LinearReduction(BaseModule):
"""Neck with Dimension reduction.
class LinearNeck(BaseModule):
"""Linear neck with Dimension projection.
Args:
in_channels (int): Number of channels in the input.
out_channels (int): Number of channels in the output.
gap_dim (int): The number of spatial dimensions of the input to apply
global average pooling over, one of {0, 1, 2, 3}; 0 means no
pooling. Defaults to 0.
norm_cfg (dict, optional): dictionary to construct and
config norm layer. Defaults to dict(type='BN1d').
act_cfg (dict, optional): dictionary to construct and
@ -28,22 +30,35 @@ class LinearReduction(BaseModule):
def __init__(self,
in_channels: int,
out_channels: int,
gap_dim: int = 0,
norm_cfg: Optional[dict] = dict(type='BN1d'),
act_cfg: Optional[dict] = None,
init_cfg: Optional[dict] = None):
super(LinearReduction, self).__init__(init_cfg=init_cfg)
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.norm_cfg = copy.deepcopy(norm_cfg)
self.act_cfg = copy.deepcopy(act_cfg)
self.reduction = nn.Linear(
in_features=in_channels, out_features=out_channels)
assert gap_dim in [0, 1, 2, 3], 'GlobalAveragePooling dim only ' \
f'support {0, 1, 2, 3}, get {gap_dim} instead.'
if gap_dim == 0:
self.gap = nn.Identity()
elif gap_dim == 1:
self.gap = nn.AdaptiveAvgPool1d(1)
elif gap_dim == 2:
self.gap = nn.AdaptiveAvgPool2d((1, 1))
elif gap_dim == 3:
self.gap = nn.AdaptiveAvgPool3d((1, 1, 1))
self.fc = nn.Linear(in_features=in_channels, out_features=out_channels)
if norm_cfg:
self.norm = build_norm_layer(norm_cfg, out_channels)[1]
else:
self.norm = nn.Identity()
if act_cfg:
self.act = build_activation_layer(act_cfg)
else:
@ -59,13 +74,15 @@ class LinearReduction(BaseModule):
the last stage will be used.
Returns:
Tuple(torch.Tensor)): A tuple of reducted features.
Tuple[torch.Tensor]: A tuple of output features.
"""
assert isinstance(inputs, (tuple, torch.Tensor)), (
'The inputs of `LinearReduction` neck must be tuple or '
f'`torch.Tensor`, but get {type(inputs)}.')
'The inputs of `LinearNeck` must be tuple or `torch.Tensor`, '
f'but get {type(inputs)}.')
if isinstance(inputs, tuple):
inputs = inputs[-1]
out = self.act(self.norm(self.reduction(inputs)))
x = self.gap(inputs)
x = x.view(x.size(0), -1)
out = self.act(self.norm(self.fc(x)))
return (out, )
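
A usage sketch of the refactored neck, assuming mmpretrain is importable; with the default gap_dim=0 the input is used as-is, matching the MaskFeat config above:

import torch
from mmpretrain.models.necks import LinearNeck

neck = LinearNeck(in_channels=768, out_channels=108)
out, = neck((torch.rand(2, 768),))            # gap_dim=0: no pooling, just fc + norm
# out: (2, 108)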


@ -0,0 +1,189 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding
@MODELS.register_module()
class MAEPretrainDecoder(BaseModule):
"""Decoder for MAE Pre-training.
Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
in_chans (int): The channel of input image. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
decoder_embed_dim (int): Decoder's embedding dimension.
Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 8.
decoder_num_heads (int): Number of attention heads of decoder.
Defaults to 16.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
predict_feature_dim (int, optional): The dimension of the predicted
features. Defaults to None, which means ``patch_size**2 * in_chans``.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
Example:
>>> from mmpretrain.models import MAEPretrainDecoder
>>> import torch
>>> self = MAEPretrainDecoder()
>>> self.eval()
>>> inputs = torch.rand(1, 50, 1024)
>>> ids_restore = torch.arange(0, 196).unsqueeze(0)
>>> level_outputs = self.forward(inputs, ids_restore)
>>> print(tuple(level_outputs.shape))
(1, 196, 768)
"""
def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 1024,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
predict_feature_dim: Optional[float] = None,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.num_patches = num_patches
# used to convert the dim of features from encoder to the dim
# compatible with that of decoder
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
# create new position embedding, different from that in encoder
# and is not learnable
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches + 1, decoder_embed_dim),
requires_grad=False)
self.decoder_blocks = nn.ModuleList([
TransformerEncoderLayer(
decoder_embed_dim,
decoder_num_heads,
int(mlp_ratio * decoder_embed_dim),
qkv_bias=True,
norm_cfg=norm_cfg) for _ in range(decoder_depth)
])
self.decoder_norm_name, decoder_norm = build_norm_layer(
norm_cfg, decoder_embed_dim, postfix=1)
self.add_module(self.decoder_norm_name, decoder_norm)
# Used to map features to pixels
if predict_feature_dim is None:
predict_feature_dim = patch_size**2 * in_chans
self.decoder_pred = nn.Linear(
decoder_embed_dim, predict_feature_dim, bias=True)
def init_weights(self) -> None:
"""Initialize position embedding and mask token of MAE decoder."""
super().init_weights()
decoder_pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.decoder_pos_embed.shape[-1],
cls_token=True)
self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
torch.nn.init.normal_(self.mask_token, std=.02)
@property
def decoder_norm(self):
"""The normalization layer of decoder."""
return getattr(self, self.decoder_norm_name)
def forward(self, x: torch.Tensor,
ids_restore: torch.Tensor) -> torch.Tensor:
"""The forward function.
The forward pass combines the encoded visible patches with the mask tokens
to produce feature vectors for every patch, which are then used for
reconstruction.
Args:
x (torch.Tensor): Features of the visible patches (including the cls
token), of shape B x (L * (1 - mask_ratio) + 1) x C.
ids_restore (torch.Tensor): ids to restore original image.
Returns:
x (torch.Tensor): The reconstructed feature vectors, which is of
shape B x (num_patches) x C.
"""
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
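# Editor's sketch (not part of the diff): a toy run of the unshuffle step above.
# With 4 patches and 2 kept, `ids_restore` maps the shuffled sequence
# [kept, kept, mask, mask] back to the original patch order via torch.gather.
import torch

ids_shuffle = torch.tensor([[2, 0, 3, 1]])        # patches 2 and 0 are kept
ids_restore = torch.argsort(ids_shuffle, dim=1)   # tensor([[1, 3, 0, 2]])
seq = torch.tensor([[[2.], [0.], [-1.], [-1.]]])  # kept values, then mask tokens (-1)
restored = torch.gather(seq, 1, ids_restore.unsqueeze(-1))
# restored[0, :, 0] == [0, -1, 2, -1]: kept patches are back at positions 0 and 2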
@MODELS.register_module()
class ClsBatchNormNeck(BaseModule):
"""Normalize cls token across batch before head.
This module is proposed by MAE, when running linear probing.
Args:
input_features (int): The dimension of features.
affine (bool): a boolean value that when set to ``True``, this module
has learnable affine parameters. Defaults to False.
eps (float): a value added to the denominator for numerical stability.
Defaults to 1e-6.
init_cfg (Dict or List[Dict], optional): Config dict for weight
initialization. Defaults to None.
"""
def __init__(self,
input_features: int,
affine: bool = False,
eps: float = 1e-6,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg)
self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps)
def forward(
self,
inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]:
"""The forward function."""
# Only apply batch norm to cls token, which is the second tensor in
# each item of the tuple
inputs = [[input_[0], self.bn(input_[1])] for input_ in inputs]
return tuple(inputs)
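# Editor's sketch (not part of the diff): ClsBatchNormNeck normalizes only the
# cls token (the second entry of each item), e.g. before the linear-probing head.
import torch

neck = ClsBatchNormNeck(input_features=768)
patch_tokens, cls_token = torch.rand(4, 196, 768), torch.rand(4, 768)
out = neck(((patch_tokens, cls_token), ))
assert out[0][1].shape == (4, 768)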

View File

@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
from torch import nn
from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import PromptMultiheadAttention
from .mae_neck import MAEPretrainDecoder
class PromptTransformerEncoderLayer(TransformerEncoderLayer):
"""Prompt Transformer Encoder Layer for MILAN.
This module is specific to the prompt encoder in MILAN. It does not update
the visible tokens from the encoder.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Defaults to 0.0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Defaults to False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
feedforward_channels: int,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
num_fcs: int = 2,
qkv_bias: bool = True,
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
embed_dims=embed_dims,
num_heads=num_heads,
feedforward_channels=feedforward_channels,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
num_fcs=num_fcs,
qkv_bias=qkv_bias,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
self.attn = PromptMultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias)
def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
ids_restore: torch.Tensor) -> torch.Tensor:
"""Forward function for `PromptMultiheadAttention`.
Args:
x (torch.Tensor): Mask token features with shape N x L_m x C.
visible_tokens (torch.Tensor): The visible tokens features from
encoder with shape N x L_v x C.
ids_restore (torch.Tensor): The ids of all tokens in the original
image with shape N x L.
Returns:
torch.Tensor: Output features with shape N x L x C.
"""
x = x + self.attn(self.norm1(x), visible_tokens, ids_restore)
x = self.ffn(self.norm2(x), identity=x)
return x
@MODELS.register_module()
class MILANPretrainDecoder(MAEPretrainDecoder):
"""Prompt decoder for MILAN.
This decoder is used in MILAN pretraining. It does not update the
visible tokens from the encoder.
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
in_chans (int): The channel of input image. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
decoder_embed_dim (int): Decoder's embedding dimension.
Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 8.
decoder_num_heads (int): Number of attention heads of decoder.
Defaults to 16.
predict_feature_dim (int): The dimension of the feature to be
predicted. Defaults to 512.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
"""
def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 1024,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
predict_feature_dim: int = 512,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
num_patches=num_patches,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
# map the dim of features from decoder to the dim compatible with
# that of CLIP
self.decoder_pred = nn.Linear(
decoder_embed_dim, predict_feature_dim, bias=True)
# use prompt transformer encoder layer, instead of the conventional
# transformer encoder layer
self.decoder_blocks = nn.ModuleList([
PromptTransformerEncoderLayer(
decoder_embed_dim,
decoder_num_heads,
int(mlp_ratio * decoder_embed_dim),
qkv_bias=True,
norm_cfg=norm_cfg) for _ in range(decoder_depth)
])
def forward(self, x: torch.Tensor, ids_restore: torch.Tensor,
ids_keep: torch.Tensor,
ids_dump: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
x (torch.Tensor): The input features, which is of shape (N, L, C).
ids_restore (torch.Tensor): The indices to restore these tokens
to the original image.
ids_keep (torch.Tensor): The indices of tokens to be kept.
ids_dump (torch.Tensor): The indices of tokens to be masked.
Returns:
torch.Tensor: The reconstructed features, which is of shape
(N, L, C).
"""
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1)
# add pos embed
x = x + self.decoder_pos_embed
# split mask tokens and visible tokens
visible_tokens = torch.cat([
x[:, :1, :],
torch.gather(
x[:, 1:, :],
dim=1,
index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
],
dim=1)
x = torch.gather(
x[:, 1:, :],
dim=1,
index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
for blk in self.decoder_blocks:
x = blk(x, visible_tokens, ids_restore)
# full sequence recovery
x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1,
x.shape[-1])) # unshuffle
x = torch.cat([visible_tokens[:, :1, :], x_], dim=1)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
return x
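# Editor's sketch (not part of the diff): expected shapes for the forward pass
# above. With 196 patches and a 0.75 mask ratio the encoder keeps 49 visible
# tokens plus the cls token, and the decoder predicts a CLIP-aligned feature
# (predict_feature_dim=512 by default) for the cls token and every patch.
import torch

decoder = MILANPretrainDecoder()                 # defaults defined above
x = torch.rand(2, 50, 1024)                      # cls token + 49 visible tokens
ids_shuffle = torch.rand(2, 196).argsort(dim=1)  # random shuffle per sample
ids_restore = ids_shuffle.argsort(dim=1)
ids_keep, ids_dump = ids_shuffle[:, :49], ids_shuffle[:, 49:]
out = decoder(x, ids_restore, ids_keep, ids_dump)
assert out.shape == (2, 197, 512)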

View File

@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
from .mae_neck import MAEPretrainDecoder
@MODELS.register_module()
class MixMIMPretrainDecoder(MAEPretrainDecoder):
"""Decoder for MixMIM Pretraining.
Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. # noqa
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
in_chans (int): The channel of input image. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
encoder_stride (int): The output stride of MixMIM backbone. Defaults
to 32.
decoder_embed_dim (int): Decoder's embedding dimension.
Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 8.
decoder_num_heads (int): Number of attention heads of decoder.
Defaults to 16.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
init_cfg (Union[List[dict], dict], optional): Initialization config
dict. Defaults to None.
"""
def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 1024,
encoder_stride: int = 32,
decoder_embed_dim: int = 512,
decoder_depth: int = 8,
decoder_num_heads: int = 16,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
num_patches=num_patches,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
decoder_embed_dim=decoder_embed_dim,
decoder_depth=decoder_depth,
decoder_num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
norm_cfg=norm_cfg,
init_cfg=init_cfg)
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, decoder_embed_dim),
requires_grad=False)
self.decoder_pred = nn.Linear(decoder_embed_dim, encoder_stride**2 * 3)
def init_weights(self) -> None:
"""Initialize position embedding and mask token of MixMIM decoder."""
super(MAEPretrainDecoder, self).init_weights()
decoder_pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.decoder_pos_embed.shape[-1],
cls_token=False)
self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
torch.nn.init.normal_(self.mask_token, std=.02)
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
x (torch.Tensor): The input features, which is of shape (N, L, C).
mask (torch.Tensor): The tensor to indicate which tokens are masked.
Returns:
torch.Tensor: The reconstructed features, which is of shape
(N, L, C).
"""
x = self.decoder_embed(x)
B, L, C = x.shape
mask_tokens = self.mask_token.expand(B, L, -1)
x1 = x * (1 - mask) + mask_tokens * mask
x2 = x * mask + mask_tokens * (1 - mask)
x = torch.cat([x1, x2], dim=0)
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for idx, blk in enumerate(self.decoder_blocks):
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
return x
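# Editor's sketch (not part of the diff): MixMIM reconstructs both the mixed
# view and its complement, so the batch dimension doubles and each token is
# mapped to encoder_stride**2 * 3 pixel values.
import torch

decoder = MixMIMPretrainDecoder()                # defaults defined above
x = torch.rand(2, 196, 1024)                     # backbone features
mask = torch.randint(0, 2, (2, 196, 1)).float()  # 1 marks a masked token
out = decoder(x, mask)
assert out.shape == (4, 196, 32**2 * 3)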

View File

@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class MoCoV2Neck(BaseModule):
"""The non-linear neck of MoCo v2: fc-relu-fc.
Args:
in_channels (int): Number of input channels.
hid_channels (int): Number of hidden channels.
out_channels (int): Number of output channels.
with_avg_pool (bool): Whether to apply the global
average pooling after backbone. Defaults to True.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
in_channels: int,
hid_channels: int,
out_channels: int,
with_avg_pool: bool = True,
init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
super().__init__(init_cfg)
self.with_avg_pool = with_avg_pool
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.mlp = nn.Sequential(
nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True),
nn.Linear(hid_channels, out_channels))
def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
"""Forward function.
Args:
x (Tuple[torch.Tensor]): The feature map of backbone.
Returns:
Tuple[torch.Tensor]: The output features.
"""
assert len(x) == 1
x = x[0]
if self.with_avg_pool:
x = self.avgpool(x)
return (self.mlp(x.view(x.size(0), -1)), )
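# Editor's sketch (not part of the diff): the fc-relu-fc projection head maps a
# pooled ResNet feature map to the contrastive embedding.
import torch

neck = MoCoV2Neck(in_channels=2048, hid_channels=2048, out_channels=128)
out = neck((torch.rand(2, 2048, 7, 7), ))        # stage-4 feature map
assert out[0].shape == (2, 128)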

View File

@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class NonLinearNeck(BaseModule):
"""The non-linear neck.
Structure: fc-bn-[relu-fc-bn], where the part in brackets can be repeated.
With the default setting, it is repeated once.
The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam.
Args:
in_channels (int): Number of input channels.
hid_channels (int): Number of hidden channels.
out_channels (int): Number of output channels.
num_layers (int): Number of fc layers. Defaults to 2.
with_bias (bool): Whether to use bias in fc layers (except for the
last). Defaults to False.
with_last_bn (bool): Whether to add the last BN layer.
Defaults to True.
with_last_bn_affine (bool): Whether to have learnable affine parameters
in the last BN layer (set False for SimSiam). Defaults to True.
with_last_bias (bool): Whether to use bias in the last fc layer.
Defaults to False.
with_avg_pool (bool): Whether to apply the global average pooling
after backbone. Defaults to True.
vit_backbone (bool): The key to indicate whether the upstream backbone
is ViT. Defaults to False.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='SyncBN').
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(
self,
in_channels: int,
hid_channels: int,
out_channels: int,
num_layers: int = 2,
with_bias: bool = False,
with_last_bn: bool = True,
with_last_bn_affine: bool = True,
with_last_bias: bool = False,
with_avg_pool: bool = True,
vit_backbone: bool = False,
norm_cfg: dict = dict(type='SyncBN'),
init_cfg: Optional[Union[dict, List[dict]]] = [
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
) -> None:
super(NonLinearNeck, self).__init__(init_cfg)
self.with_avg_pool = with_avg_pool
self.vit_backbone = vit_backbone
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.relu = nn.ReLU(inplace=True)
self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias)
self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1]
self.fc_names = []
self.bn_names = []
for i in range(1, num_layers):
this_channels = out_channels if i == num_layers - 1 \
else hid_channels
if i != num_layers - 1:
self.add_module(
f'fc{i}',
nn.Linear(hid_channels, this_channels, bias=with_bias))
self.add_module(f'bn{i}',
build_norm_layer(norm_cfg, this_channels)[1])
self.bn_names.append(f'bn{i}')
else:
self.add_module(
f'fc{i}',
nn.Linear(
hid_channels, this_channels, bias=with_last_bias))
if with_last_bn:
self.add_module(
f'bn{i}',
build_norm_layer(
dict(**norm_cfg, affine=with_last_bn_affine),
this_channels)[1])
self.bn_names.append(f'bn{i}')
else:
self.bn_names.append(None)
self.fc_names.append(f'fc{i}')
def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
"""Forward function.
Args:
x (Tuple[torch.Tensor]): The feature map of backbone.
Returns:
Tuple[torch.Tensor]: The output features.
"""
assert len(x) == 1
x = x[0]
if self.vit_backbone:
x = x[-1]
if self.with_avg_pool:
x = self.avgpool(x)
x = x.view(x.size(0), -1)
x = self.fc0(x)
x = self.bn0(x)
for fc_name, bn_name in zip(self.fc_names, self.bn_names):
fc = getattr(self, fc_name)
x = self.relu(x)
x = fc(x)
if bn_name is not None:
bn = getattr(self, bn_name)
x = bn(x)
return (x, )
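# Editor's sketch (not part of the diff): a SimSiam-style projector with three
# fc layers and a non-affine final BN; norm_cfg is switched from the SyncBN
# default to plain BN1d so the example runs in a single process.
import torch

neck = NonLinearNeck(
    in_channels=2048,
    hid_channels=2048,
    out_channels=2048,
    num_layers=3,
    with_last_bn_affine=False,
    norm_cfg=dict(type='BN1d'))
out = neck((torch.rand(2, 2048, 7, 7), ))
assert out[0].shape == (2, 2048)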

View File

@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class SimMIMLinearDecoder(BaseModule):
"""Linear Decoder For SimMIM pretraining.
This neck reconstructs the original image from the shrunk feature map.
Args:
in_channels (int): Channel dimension of the feature map.
encoder_stride (int): The total stride of the encoder.
"""
def __init__(self, in_channels: int, encoder_stride: int) -> None:
super().__init__()
self.decoder = nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=encoder_stride**2 * 3,
kernel_size=1),
nn.PixelShuffle(encoder_stride),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function."""
x = self.decoder(x)
return x
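# Editor's sketch (not part of the diff): the 1x1 conv + PixelShuffle pair
# upsamples the encoder feature map back to image resolution, e.g. for a
# Swin backbone with stride 32 on 192x192 inputs.
import torch

decoder = SimMIMLinearDecoder(in_channels=128 * 2**3, encoder_stride=32)
out = decoder(torch.rand(2, 1024, 6, 6))
assert out.shape == (2, 3, 192, 192)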

View File

@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class SwAVNeck(BaseModule):
"""The non-linear neck of SwAV: fc-bn-relu-fc-normalization.
Args:
in_channels (int): Number of input channels.
hid_channels (int): Number of hidden channels.
out_channels (int): Number of output channels.
with_avg_pool (bool): Whether to apply the global average pooling after
backbone. Defaults to True.
with_l2norm (bool): whether to normalize the output after projection.
Defaults to True.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='SyncBN').
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(
self,
in_channels: int,
hid_channels: int,
out_channels: int,
with_avg_pool: bool = True,
with_l2norm: bool = True,
norm_cfg: dict = dict(type='SyncBN'),
init_cfg: Optional[Union[dict, List[dict]]] = [
dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
) -> None:
super().__init__(init_cfg)
self.with_avg_pool = with_avg_pool
self.with_l2norm = with_l2norm
if with_avg_pool:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
if out_channels == 0:
self.projection_neck = nn.Identity()
elif hid_channels == 0:
self.projection_neck = nn.Linear(in_channels, out_channels)
else:
self.norm = build_norm_layer(norm_cfg, hid_channels)[1]
self.projection_neck = nn.Sequential(
nn.Linear(in_channels, hid_channels),
self.norm,
nn.ReLU(inplace=True),
nn.Linear(hid_channels, out_channels),
)
def forward_projection(self, x: torch.Tensor) -> torch.Tensor:
"""Compute projection.
Args:
x (torch.Tensor): The feature vectors after pooling.
Returns:
torch.Tensor: The output features with projection or L2-norm.
"""
x = self.projection_neck(x)
if self.with_l2norm:
x = nn.functional.normalize(x, dim=1, p=2)
return x
def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
"""Forward function.
Args:
x (List[torch.Tensor]): A list of feature maps, whose length equals
the number of crops (``len(num_crops)``).
Returns:
torch.Tensor: The projection vectors.
"""
avg_out = []
for _x in x:
_x = _x[0]
if self.with_avg_pool:
_out = self.avgpool(_x)
avg_out.append(_out)
feat_vec = torch.cat(avg_out) # [sum(num_crops) * N, C]
feat_vec = feat_vec.view(feat_vec.size(0), -1)
output = self.forward_projection(feat_vec)
return output
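# Editor's sketch (not part of the diff): SwAV concatenates the pooled features
# of all crops before projecting, so the output batch is sum(num_crops) * N;
# norm_cfg is switched from SyncBN to BN1d to keep the example single-process.
import torch

neck = SwAVNeck(
    in_channels=2048,
    hid_channels=2048,
    out_channels=128,
    norm_cfg=dict(type='BN1d'))
crops = [(torch.rand(2, 2048, 7, 7), ), (torch.rand(2, 2048, 3, 3), )]
out = neck(crops)
assert out.shape == (4, 128)                     # 2 crops x batch of 2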

View File

@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import (BEiTAttention, ChannelMultiheadAttention, LeAttention,
MultiheadAttention, ShiftWindowMSA, WindowMSA,
WindowMSAV2)
from .attention import (BEiTAttention, ChannelMultiheadAttention,
CrossMultiheadAttention, LeAttention,
MultiheadAttention, PromptMultiheadAttention,
ShiftWindowMSA, WindowMSA, WindowMSAV2)
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
from .channel_shuffle import channel_shuffle
from .data_preprocessor import ClsDataPreprocessor
@ -13,7 +14,8 @@ from .layer_scale import LayerScale
from .make_divisible import make_divisible
from .norm import GRN, LayerNorm2d, build_norm_layer
from .position_encoding import (ConditionalPositionEncoding,
PositionEncodingFourier)
PositionEncodingFourier,
build_2d_sincos_position_embedding)
from .se_layer import SELayer
__all__ = [
@ -25,5 +27,6 @@ __all__ = [
'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention',
'LayerScale', 'WindowMSA', 'WindowMSAV2', 'ChannelMultiheadAttention',
'PositionEncodingFourier', 'LeAttention', 'GRN', 'LayerNorm2d',
'build_norm_layer'
'build_norm_layer', 'CrossMultiheadAttention',
'build_2d_sincos_position_embedding', 'PromptMultiheadAttention'
]

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import List, Optional, Union
import numpy as np
import torch
@ -683,6 +684,7 @@ class BEiTAttention(BaseModule):
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
Wh = self.window_size[0]
Ww = self.window_size[1]
@ -879,3 +881,192 @@ class LeAttention(BaseModule):
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class CrossMultiheadAttention(BaseModule):
"""Cross attention between queries and the union of keys and values.
This module is different from ``MultiheadAttention``, for the attention
is computed between queries and the union of keys and values.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
"""
def __init__(self,
embed_dims: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_scale: Optional[float] = None,
attn_drop: float = 0.,
proj_drop: float = 0.) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = embed_dims // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(embed_dims, embed_dims, bias=False)
self.k = nn.Linear(embed_dims, embed_dims, bias=False)
self.v = nn.Linear(embed_dims, embed_dims, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(embed_dims))
self.v_bias = nn.Parameter(torch.zeros(embed_dims))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self,
x: torch.Tensor,
k: torch.Tensor = None,
v: torch.Tensor = None) -> torch.Tensor:
"""Forward function."""
B, N, _ = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
v_bias = self.v_bias
q = F.linear(
input=x, weight=self.q.weight, bias=q_bias) # (B, N_q, dim)
k = F.linear(
input=k, weight=self.k.weight, bias=k_bias) # (B, N_k, dim)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
q = q.reshape(B, N, 1, self.num_heads,
-1).permute(2, 0, 3, 1,
4).squeeze(0) # (B, num_heads, N_q, dim)
k = k.reshape(B, N_k, 1, self.num_heads,
-1).permute(2, 0, 3, 1,
4).squeeze(0) # (B, num_heads, N_k, dim)
v = v.reshape(B, N_v, 1, self.num_heads,
-1).permute(2, 0, 3, 1,
4).squeeze(0) # (B, num_heads, N_v, dim)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
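# Editor's sketch (not part of the diff): the queries attend to a separate
# key/value sequence, as used for the latent regressor in CAE.
import torch

attn = CrossMultiheadAttention(embed_dims=768, num_heads=12, qkv_bias=True)
queries = torch.rand(2, 147, 768)                # e.g. masked-token queries
context = torch.rand(2, 196, 768)                # full token sequence
out = attn(queries, k=context, v=context)
assert out.shape == (2, 147, 768)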
class PromptMultiheadAttention(MultiheadAttention):
"""Prompt Multihead Attention for MILAN.
This module is specific for the prompt encoder in MILAN. It will not update
the visible tokens from the encoder.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
input_dims (int, optional): The input dimension, and if None,
use ``embed_dims``. Defaults to None.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
dropout_layer (dict): The dropout config before adding the shortcut.
Defaults to ``dict(type='Dropout', drop_prob=0.)``.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): If True, add a learnable bias to the output projection.
Defaults to True.
v_shortcut (bool): Add a shortcut from value to output. It's usually
used if ``input_dims`` is different from ``embed_dims``.
Defaults to False.
use_layer_scale (bool): Whether to apply layer scale to the output of
the attention. Defaults to False.
init_cfg (Union[List[dict], dict], optional): The Config for
initialization. Defaults to None.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
input_dims: Optional[int] = None,
attn_drop: float = 0,
proj_drop: float = 0,
dropout_layer: dict = dict(type='Dropout', drop_prob=0.),
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
proj_bias: bool = True,
v_shortcut: bool = False,
use_layer_scale: bool = False,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(embed_dims, num_heads, input_dims, attn_drop,
proj_drop, dropout_layer, qkv_bias, qk_scale,
proj_bias, v_shortcut, use_layer_scale, init_cfg)
# no longer need qkv
del self.qkv
# to project the mask tokens
self.q = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
# to project all the tokens
self.kv = nn.Linear(embed_dims, embed_dims * 2, bias=qkv_bias)
def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
ids_restore: torch.Tensor) -> torch.Tensor:
"""Forward function for `PromptMultiheadAttention`.
Args:
x (torch.Tensor): Mask token features with shape N x L_m x C.
visible_tokens (torch.Tensor): The visible tokens features from
encoder with shape N x L_v x C.
ids_restore (torch.Tensor): The ids of all tokens in the original
image with shape N x L.
Returns:
torch.Tensor: Output features with shape N x L x C.
"""
x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
assert x_.shape[1] == ids_restore.shape[1]
x_ = torch.gather(
x_,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
x_ = torch.cat([visible_tokens[:, :1, :], x_], dim=1)
# full sequence shape
B, _, _ = x_.shape
q = self.q(x).reshape(B, x.shape[1], self.num_heads,
self.head_dims).permute(0, 2, 1, 3)
kv = self.kv(x_).reshape(B, x_.shape[1], 2, self.num_heads,
self.head_dims).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
x = self.proj(x)
x = self.out_drop(self.gamma1(self.proj_drop(x)))
return x
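# Editor's sketch (not part of the diff): the mask tokens act as queries while
# the restored full sequence provides keys and values, so the visible tokens
# themselves are never updated.
import torch

attn = PromptMultiheadAttention(embed_dims=512, num_heads=16)
mask_tokens = torch.rand(2, 147, 512)            # L_m masked positions
visible = torch.rand(2, 50, 512)                 # cls token + 49 visible tokens
ids_restore = torch.rand(2, 196).argsort(dim=1)  # 49 + 147 = 196 patches
out = attn(mask_tokens, visible, ids_restore)
assert out.shape == (2, 147, 512)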

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
from typing import Optional, Sequence, Union
import torch
import torch.nn as nn
@ -107,3 +108,57 @@ class PositionEncodingFourier(BaseModule):
pos = self.proj(pos)
return pos
def build_2d_sincos_position_embedding(
patches_resolution: Union[int, Sequence[int]],
embed_dims: int,
temperature: Optional[int] = 10000.,
cls_token: Optional[bool] = False) -> torch.Tensor:
"""The function is to build position embedding for model to obtain the
position information of the image patches.
Args:
patches_resolution (Union[int, Sequence[int]]): The resolution of the
patch grid, i.e. the number of patches along each spatial dimension.
embed_dims (int): The dimension of the embedding vector.
temperature (int, optional): The temperature parameter. Defaults to
10000.
cls_token (bool, optional): Whether to concatenate class token.
Defaults to False.
Returns:
torch.Tensor: The position embedding vector.
"""
if isinstance(patches_resolution, int):
patches_resolution = (patches_resolution, patches_resolution)
h, w = patches_resolution
grid_w = torch.arange(w, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
assert embed_dims % 4 == 0, \
'Embed dimension must be divisible by 4.'
pos_dim = embed_dims // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1. / (temperature**omega)
out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
pos_emb = torch.cat(
[
torch.sin(out_w),
torch.cos(out_w),
torch.sin(out_h),
torch.cos(out_h)
],
dim=1,
)[None, :, :]
if cls_token:
cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)
return pos_emb
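# Editor's sketch (not part of the diff): a fixed sin-cos table for a 14x14
# patch grid, with a zero vector prepended for the cls token.
pos_embed = build_2d_sincos_position_embedding(14, 768, cls_token=True)
assert pos_embed.shape == (1, 14 * 14 + 1, 768)
assert pos_embed.requires_grad is False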

View File

@ -371,3 +371,33 @@ def test_seesaw_loss():
fake_label = torch.Tensor([0]).long()
loss = loss_cls(fake_pred, fake_label)
assert torch.allclose(loss, torch.tensor(200.) + torch.tensor(100.).log())
def test_reconstruction_loss():
# test L2 loss
loss_config = dict(type='PixelReconstructionLoss', criterion='L2')
loss = build_loss(loss_config)
fake_pred = torch.rand((2, 196, 768))
fake_target = torch.rand((2, 196, 768))
fake_mask = torch.ones((2, 196))
loss_value = loss(fake_pred, fake_target, fake_mask)
assert isinstance(loss_value.item(), float)
# test L1 loss
loss_config = dict(
type='PixelReconstructionLoss', criterion='L1', channel=3)
loss = build_loss(loss_config)
fake_pred = torch.rand((2, 3, 192, 192))
fake_target = torch.rand((2, 3, 192, 192))
fake_mask = torch.ones((2, 1, 192, 192))
loss_value = loss(fake_pred, fake_target, fake_mask)
assert isinstance(loss_value.item(), float)
with pytest.raises(NotImplementedError):
loss_config = dict(type='PixelReconstructionLoss', criterion='L3')
loss = build_loss(loss_config)

View File

@ -4,7 +4,7 @@ import torch
from mmpretrain.models.necks import (GeneralizedMeanPooling,
GlobalAveragePooling, HRFuseScales,
LinearReduction)
LinearNeck)
def test_gap_neck():
@ -107,8 +107,9 @@ def test_hr_fuse_scales():
def test_linear_reduction():
# test linear_reduction without `act_cfg` and `norm_cfg`
neck = LinearReduction(10, 5, None, None)
neck = LinearNeck(10, 5, 0, None, None)
neck.eval()
assert isinstance(neck.gap, torch.nn.Identity)
assert isinstance(neck.act, torch.nn.Identity)
assert isinstance(neck.norm, torch.nn.Identity)
@ -125,12 +126,36 @@ def test_linear_reduction():
# batch_size, out_features
assert output[-1].shape == (1, 5)
# batch_size, in_channels, out_channels, gap_dim
neck = LinearNeck(10, 5, 1, None, None)
fake_input = torch.rand(1, 10, 10)
output = neck(fake_input)
# batch_size, out_features
assert output[-1].shape == (1, 5)
# batch_size, in_channels, out_channels, gap_dim
neck = LinearNeck(10, 5, 2, None, None)
fake_input = torch.rand(1, 10, 10, 10)
output = neck(fake_input)
# batch_size, out_features
assert output[-1].shape == (1, 5)
# batch_size, in_channels, out_channels, gap_dim
neck = LinearNeck(10, 5, 3, None, None)
fake_input = torch.rand(1, 10, 10, 10, 10)
output = neck(fake_input)
# batch_size, out_features
assert output[-1].shape == (1, 5)
# batch_size, in_channels, out_channels, gap_dim
with pytest.raises(AssertionError):
neck = LinearNeck(10, 5, None, None, None)
# test linear_reduction with `init_cfg`
neck = LinearReduction(
10, 5, init_cfg=dict(type='Xavier', layer=['Linear']))
neck = LinearNeck(10, 5, init_cfg=dict(type='Xavier', layer=['Linear']))
# test linear_reduction with `act_cfg` and `norm_cfg`
neck = LinearReduction(
neck = LinearNeck(
10, 5, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='BN1d'))
neck.eval()