[Refactor] Add necks, heads and losses for the self-supervised task. (#1376)
* add necks
* refactor linear neck
* rename simmim neck
* add heads
* add losses
* fix
* add unittest
* update
* update cae
* remove mim head
* update config
parent
75c79311f4
commit
63d9f27fde
@@ -16,7 +16,7 @@ model = dict(
         type='MAEPretrainHead',
         norm_pix=True,
         patch_size=16,
-        loss=dict(type='MAEReconstructionLoss')),
+        loss=dict(type='PixelReconstructionLoss', criterion='L2')),
     init_cfg=[
         dict(type='Xavier', layer='Linear', distribution='uniform'),
         dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
@@ -71,7 +71,7 @@ model = dict(
         type='BEiTV1Head',
         embed_dims=768,
         num_embed=8192,
-        loss=dict(type='BEiTLoss')),
+        loss=dict(type='CrossEntropyLoss')),
     target_generator=dict(
         type='DALL-E',
         init_cfg=dict(
@@ -56,7 +56,7 @@ model = dict(
         type='BEiTV2Head',
         embed_dims=768,
         num_embed=8192,
-        loss=dict(type='BEiTLoss')),
+        loss=dict(type='CrossEntropyLoss')),
     target_generator=dict(
         type='VQKD',
         encoder_config=vqkd_encoder,
@@ -56,7 +56,7 @@ model = dict(
         type='BEiTV2Head',
         embed_dims=768,
         num_embed=8192,
-        loss=dict(type='BEiTLoss')),
+        loss=dict(type='CrossEntropyLoss')),
     target_generator=dict(
         type='VQKD',
         encoder_config=vqkd_encoder,
@@ -1,4 +0,0 @@
-_base_ = 'cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py'
-
-# dataset 128 x 16
-train_dataloader = dict(batch_size=128)
@@ -56,7 +56,6 @@ model = dict(
         qkv_bias=False),
     neck=dict(
         type='CAENeck',
-        patch_size=16,
         embed_dims=768,
         num_heads=12,
         regressor_depth=4,
@@ -4,7 +4,7 @@ Collections:
     Training Data: ImageNet-1k
     Training Techniques:
       - AdamW
-    Training Resources: 16x A100-80G GPUs
+    Training Resources: 8x A100-80G GPUs
     Architecture:
       - ViT
     Paper:
@@ -19,8 +19,8 @@ Models:
     Epochs: 300
     Batch Size: 2048
     Results: null
-    Config: configs/cae/cae_vit-base-p16_16xb128-amp-coslr-300e_in1k.py
-    Weights: https://download.openmmlab.com/mmselfsup/1.x/cae/cae_vit-base-p16_16xb128-fp16-coslr-300e_in1k/cae_vit-base-p16_16xb128-fp16-coslr-300e_in1k_20220825-404a1929.pth
+    Config: configs/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k.py
+    Weights: https://download.openmmlab.com/mmselfsup/1.x/cae/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k/cae_vit-base-p16_8xb256-amp-coslr-300e_in1k_20221230-808170f3.pth
     Downstream:
       - beit-base-p16_cae-pre_8xb128-coslr-100e_in1k
   - Name: beit-base-p16_cae-pre_8xb128-coslr-100e_in1k
@@ -56,7 +56,6 @@ model = dict(
         type='LinearNeck',
         in_channels=768,
         out_channels=108,
         with_avg_pool=False,
         init_cfg=dict(type='TruncNormal', layer='Linear', std=0.02, bias=0)),
     head=dict(
         type='MaskFeatPretrainHead',
@@ -11,11 +11,12 @@ model = dict(
         arch='base',
         img_size=192,
         stage_cfgs=dict(block_cfgs=dict(window_size=6))),
-    neck=dict(type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
+    neck=dict(
+        type='SimMIMLinearDecoder', in_channels=128 * 2**3, encoder_stride=32),
     head=dict(
         type='SimMIMHead',
         patch_size=4,
-        loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
+        loss=dict(type='PixelReconstructionLoss', criterion='L1', channels=3)))

 # optimizer wrapper
 optim_wrapper = dict(
@@ -11,11 +11,12 @@ model = dict(
         arch='base',
         img_size=192,
         stage_cfgs=dict(block_cfgs=dict(window_size=6))),
-    neck=dict(type='SimMIMNeck', in_channels=128 * 2**3, encoder_stride=32),
+    neck=dict(
+        type='SimMIMLinearDecoder', in_channels=128 * 2**3, encoder_stride=32),
     head=dict(
         type='SimMIMHead',
         patch_size=4,
-        loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
+        loss=dict(type='PixelReconstructionLoss', criterion='L1', channels=3)))

 # optimizer wrapper
 optim_wrapper = dict(
@@ -12,11 +12,12 @@ model = dict(
         img_size=192,
         stage_cfgs=dict(block_cfgs=dict(window_size=12)),
         pad_small_map=True),
-    neck=dict(type='SimMIMNeck', in_channels=192 * 2**3, encoder_stride=32),
+    neck=dict(
+        type='SimMIMLinearDecoder', in_channels=192 * 2**3, encoder_stride=32),
     head=dict(
         type='SimMIMHead',
         patch_size=4,
-        loss=dict(type='SimMIMReconstructionLoss', encoder_in_channels=3)))
+        loss=dict(type='PixelReconstructionLoss', criterion='L1', channels=3)))

 # optimizer wrapper
 optim_wrapper = dict(
@@ -1,16 +1,24 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .beitv1_head import BEiTV1Head
+from .beitv2_head import BEiTV2Head
+from .cae_head import CAEHead
 from .cls_head import ClsHead
 from .conformer_head import ConformerHead
+from .contrastive_head import ContrastiveHead
 from .deit_head import DeiTClsHead
 from .efficientformer_head import EfficientFormerClsHead
+from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead
 from .levit_head import LeViTClsHead
 from .linear_head import LinearClsHead
+from .mae_head import MAEPretrainHead
 from .margin_head import ArcFaceClsHead
+from .mixmim_head import MixMIMPretrainHead
 from .multi_label_cls_head import MultiLabelClsHead
 from .multi_label_csra_head import CSRAClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
 from .multi_task_head import MultiTaskHead
 from .stacked_head import StackedLinearClsHead
+from .swav_head import SwAVHead
 from .vig_head import VigClsHead
 from .vision_transformer_head import VisionTransformerClsHead
@@ -29,4 +37,13 @@ __all__ = [
     'MultiTaskHead',
     'LeViTClsHead',
     'VigClsHead',
+    'BEiTV1Head',
+    'BEiTV2Head',
+    'CAEHead',
+    'ContrastiveHead',
+    'LatentCrossCorrelationHead',
+    'LatentPredictHead',
+    'MAEPretrainHead',
+    'MixMIMPretrainHead',
+    'SwAVHead',
 ]
@@ -0,0 +1,55 @@ (new file: mmpretrain/models/heads/beitv1_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class BEiTV1Head(BaseHead):
    """Pretrain Head for BEiT v1.

    Compute the logits and the cross entropy loss.

    Args:
        embed_dims (int): The dimension of embedding.
        num_embed (int): The number of classification types.
        loss (dict): The config of loss.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        embed_dims: int,
        num_embed: int,
        loss: dict,
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.cls_head = nn.Linear(embed_dims, num_embed)
        self.loss_module = MODELS.build(loss)

    def loss(self, feats: torch.Tensor, target: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """Generate loss.

        Args:
            feats (torch.Tensor): Features from backbone.
            target (torch.Tensor): Target generated by target_generator.
            mask (torch.Tensor): Generated mask for pretraining.
        """
        mask = mask.flatten(1).to(torch.bool)
        target = torch.argmax(target, dim=1).flatten(1)
        target = target[mask]

        # remove cls_token
        feats = feats[:, 1:]
        logits = self.cls_head(feats[mask])

        loss = self.loss_module(logits, target)
        return loss
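For reference, the masked-token selection performed by ``BEiTV1Head.loss`` can be reproduced with the following standalone sketch (plain torch, outside the registry; the batch size, 196-patch grid and 8192-entry codebook are illustrative assumptions, not values fixed by this commit):

import torch

B, L, V, D = 2, 196, 8192, 768
feats = torch.randn(B, L + 1, D)            # backbone output, incl. cls_token
target = torch.randn(B, V, 14, 14)          # tokenizer logits over the codebook
mask = torch.zeros(B, 14, 14)
mask[:, :7] = 1                             # hypothetical block mask

bool_mask = mask.flatten(1).to(torch.bool)  # (B, L)
token_ids = torch.argmax(target, dim=1).flatten(1)[bool_mask]

patch_feats = feats[:, 1:]                  # drop cls_token
logits = torch.nn.Linear(D, V)(patch_feats[bool_mask])
loss = torch.nn.functional.cross_entropy(logits, token_ids)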
@@ -0,0 +1,57 @@ (new file: mmpretrain/models/heads/beitv2_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class BEiTV2Head(BaseHead):
    """Pretrain Head for BEiT v2.

    Compute the logits and the cross entropy loss.

    Args:
        embed_dims (int): The dimension of embedding.
        num_embed (int): The number of classification types.
        loss (dict): The config of loss.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        embed_dims: int,
        num_embed: int,
        loss: dict,
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.cls_head = nn.Linear(embed_dims, num_embed)
        self.loss_module = MODELS.build(loss)

    def loss(self, feats: torch.Tensor, feats_cls_pt: torch.Tensor,
             target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Generate loss.

        Args:
            feats (torch.Tensor): Features from backbone.
            feats_cls_pt (torch.Tensor): Features composed of the final-layer
                cls_token and early-stage patch tokens, used for pretraining.
            target (torch.Tensor): Target generated by target_generator.
            mask (torch.Tensor): Generated mask for pretraining.
        """
        mask = mask.flatten(1).to(torch.bool)
        target = target[mask]

        # shared cls head
        logits = self.cls_head(feats[mask])
        logits_cls_pt = self.cls_head(feats_cls_pt[mask])

        loss_1 = self.loss_module(logits, target)
        loss_2 = self.loss_module(logits_cls_pt, target)
        return loss_1, loss_2
@@ -0,0 +1,69 @@ (new file: mmpretrain/models/heads/cae_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch

from mmpretrain.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class CAEHead(BaseHead):
    """Pretrain Head for CAE.

    Compute the align loss and the main loss. In addition, this head also
    generates the prediction target from the logits produced by DALL-E.

    Args:
        loss (dict): The config of loss.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)

    @torch.no_grad()
    def _generate_target(self, logits_target: torch.Tensor) -> torch.Tensor:
        """Generate the reconstruction target.

        Args:
            logits_target (torch.Tensor): The logits generated by DALL-E.

        Returns:
            torch.Tensor: The logits target.
        """
        target = torch.argmax(logits_target, dim=1)
        return target.flatten(1)

    def loss(self, logits: torch.Tensor, logits_target: torch.Tensor,
             latent_pred: torch.Tensor, latent_target: torch.Tensor,
             mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate loss.

        Args:
            logits (torch.Tensor): Logits generated by the decoder.
            logits_target (torch.Tensor): Target generated by DALL-E for the
                decoder prediction.
            latent_pred (torch.Tensor): Latent prediction from the regressor.
            latent_target (torch.Tensor): Target for latent prediction,
                generated by the teacher.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The tuple of loss.
            - loss_main (torch.Tensor): Cross entropy loss.
            - loss_align (torch.Tensor): MSE loss.
        """

        target = self._generate_target(logits_target)  # target features
        target = target[mask].detach()

        # loss main for decoder, loss align for regressor
        loss_main, loss_align = self.loss_module(logits, target, latent_pred,
                                                 latent_target)

        return (loss_main, loss_align)
@@ -0,0 +1,50 @@ (new file: mmpretrain/models/heads/contrastive_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch

from mmpretrain.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class ContrastiveHead(BaseHead):
    """Head for contrastive learning.

    The contrastive loss is implemented in this head and is used in SimCLR,
    MoCo, DenseCL, etc.

    Args:
        loss (dict): Config dict for module of loss functions.
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Defaults to 0.1.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 temperature: float = 0.1,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss_module = MODELS.build(loss)
        self.temperature = temperature

    def loss(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
        """Forward function to compute contrastive loss.

        Args:
            pos (torch.Tensor): Nx1 positive similarity.
            neg (torch.Tensor): Nxk negative similarity.

        Returns:
            torch.Tensor: The contrastive loss.
        """
        N = pos.size(0)
        logits = torch.cat((pos, neg), dim=1)
        logits /= self.temperature
        labels = torch.zeros((N, ), dtype=torch.long).to(pos.device)

        loss = self.loss_module(logits, labels)
        return loss
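The InfoNCE computation above places the positive similarity in column 0 of the logits, so an all-zero label vector marks it as the correct class. A standalone sketch (plain torch; N, k and the temperature are illustrative):

import torch
import torch.nn.functional as F

N, k, temperature = 8, 16, 0.1
pos = torch.randn(N, 1)                    # Nx1 positive similarity
neg = torch.randn(N, k)                    # Nxk negative similarities

logits = torch.cat((pos, neg), dim=1) / temperature
labels = torch.zeros(N, dtype=torch.long)  # the positive sits at index 0
loss = F.cross_entropy(logits, labels)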
@@ -0,0 +1,95 @@ (new file: mmpretrain/models/heads/latent_heads.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.dist import all_reduce, get_world_size

from mmpretrain.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class LatentPredictHead(BaseHead):
    """Head for latent feature prediction.

    This head builds a predictor, which can be any registered neck component.
    For example, BYOL and SimSiam call this head and build NonLinearNeck.
    It also implements similarity loss between two forward features.

    Args:
        loss (dict): Config dict for the loss.
        predictor (dict): Config dict for the predictor.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 loss: dict,
                 predictor: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.loss = MODELS.build(loss)
        self.predictor = MODELS.build(predictor)

    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward head.

        Args:
            input (torch.Tensor): NxC input features.
            target (torch.Tensor): NxC target features.

        Returns:
            torch.Tensor: The latent predict loss.
        """
        pred = self.predictor([input])[0]
        target = target.detach()

        loss = self.loss(pred, target)

        return loss


@MODELS.register_module()
class LatentCrossCorrelationHead(BaseHead):
    """Head for latent feature cross correlation.

    Part of the code is borrowed from `script
    <https://github.com/facebookresearch/barlowtwins/blob/main/main.py>`_.

    Args:
        in_channels (int): Number of input channels.
        loss (dict): Config dict for module of loss functions.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 loss: dict,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.world_size = get_world_size()
        self.bn = nn.BatchNorm1d(in_channels, affine=False)
        self.loss = MODELS.build(loss)

    def forward(self, input: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        """Forward head.

        Args:
            input (torch.Tensor): NxC input features.
            target (torch.Tensor): NxC target features.

        Returns:
            torch.Tensor: The cross correlation loss.
        """
        # cross-correlation matrix
        cross_correlation_matrix = self.bn(input).T @ self.bn(target)
        cross_correlation_matrix.div_(input.size(0) * self.world_size)

        all_reduce(cross_correlation_matrix)

        loss = self.loss(cross_correlation_matrix)
        return loss
@@ -0,0 +1,99 @@ (new file: mmpretrain/models/heads/mae_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpretrain.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class MAEPretrainHead(BaseHead):
    """Pre-training head for MAE.

    Args:
        loss (dict): Config of loss.
        norm_pix (bool): Whether or not to normalize the target.
            Defaults to False.
        patch_size (int): Patch size. Defaults to 16.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = False,
                 patch_size: int = 16) -> None:
        super().__init__()
        self.norm_pix = norm_pix
        self.patch_size = patch_size
        self.loss_module = MODELS.build(loss)

    def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
        """Split images into non-overlapped patches.

        Args:
            imgs (torch.Tensor): A batch of images, of shape B x C x H x W.

        Returns:
            torch.Tensor: Patchified images. The shape is B x L x D.
        """
        p = self.patch_size
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Combine non-overlapped patches into images.

        Args:
            x (torch.Tensor): The shape is (N, L, patch_size**2 * 3).

        Returns:
            imgs (torch.Tensor): The shape is (N, 3, H, W).
        """
        p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def construct_target(self, target: torch.Tensor) -> torch.Tensor:
        """Construct the reconstruction target.

        In addition to splitting images into tokens, this module will also
        normalize the image according to ``norm_pix``.

        Args:
            target (torch.Tensor): Image with the shape of B x 3 x H x W.

        Returns:
            torch.Tensor: Tokenized images with the shape of B x L x C.
        """
        target = self.patchify(target)
        if self.norm_pix:
            # normalize the target image
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        return target

    def loss(self, pred: torch.Tensor, target: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """Forward function of MAE head.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        target = self.construct_target(target)
        loss = self.loss_module(pred, target, mask)

        return loss
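The ``patchify``/``unpatchify`` pair above is exactly invertible, since both are pure reshape/permute operations. A quick standalone check mirroring the einsum-based reshapes (image size 224 and patch size 16 are illustrative):

import torch

p, B, H = 16, 2, 224
imgs = torch.randn(B, 3, H, H)

h = w = H // p
x = imgs.reshape(B, 3, h, p, w, p)
x = torch.einsum('nchpwq->nhwpqc', x).reshape(B, h * w, p**2 * 3)  # patchify

y = x.reshape(B, h, w, p, p, 3)
y = torch.einsum('nhwpqc->nchpwq', y).reshape(B, 3, h * p, h * p)  # unpatchify
assert torch.allclose(imgs, y)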
@@ -0,0 +1,41 @@ (new file: mmpretrain/models/heads/maskfeat_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


# TODO: delete and use NaiveMIMHead
@MODELS.register_module()
class MaskFeatPretrainHead(BaseModule):
    """Pre-training head for MaskFeat.

    It computes reconstruction loss between prediction and target in masked
    region.

    Args:
        loss (dict): Config dict for module of loss functions.
    """

    def __init__(self, loss: dict) -> None:
        super().__init__()
        self.loss = MODELS.build(loss)

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Forward head.

        Args:
            pred (torch.Tensor): Predictions, which is of shape
                B x (1 + L) x C.
            target (torch.Tensor): HOG features, which is of shape B x L x C.
            mask (torch.Tensor): The mask of the HOG features, which is of
                shape B x H x W.

        Returns:
            torch.Tensor: The loss tensor.
        """
        mask = mask.flatten(1).bool()
        loss = self.loss(pred[:, 1:], target, mask)

        return loss
@@ -0,0 +1,49 @@ (new file: mmpretrain/models/heads/mixmim_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpretrain.registry import MODELS
from .mae_head import MAEPretrainHead


@MODELS.register_module()
class MixMIMPretrainHead(MAEPretrainHead):
    """MixMIM pretrain head.

    Args:
        loss (dict): Config of loss.
        norm_pix (bool): Whether or not to normalize the target.
            Defaults to False.
        patch_size (int): Patch size. Defaults to 16.
    """

    def __init__(self,
                 loss: dict,
                 norm_pix: bool = False,
                 patch_size: int = 16) -> None:
        super().__init__(loss=loss, norm_pix=norm_pix, patch_size=patch_size)

    def loss(self, x_rec: torch.Tensor, target: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
        """Forward function of MixMIM head.

        Args:
            x_rec (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        target = self.construct_target(target)

        B, L, C = x_rec.shape

        # unmix tokens
        x1_rec = x_rec[:B // 2]
        x2_rec = x_rec[B // 2:]

        unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask)

        loss_rec = self.loss_module(unmix_x_rec, target)

        return loss_rec
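MixMIM reconstructs two images that were mixed into each slot, so the batch of reconstructions is split in half and re-assembled with the mask and its complement. A standalone sketch of the unmixing step (shapes are illustrative; the mask is assumed to broadcast as (B/2, L, 1)):

import torch

B, L, C = 4, 196, 768                      # B must be even: two mixed halves
x_rec = torch.randn(B, L, C)
mask = torch.randint(0, 2, (B // 2, L, 1)).float()

x1_rec, x2_rec = x_rec[:B // 2], x_rec[B // 2:]
unmix_x_rec = x1_rec * mask + x2_rec.flip(0) * (1 - mask)  # (B/2, L, C)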
@@ -0,0 +1,67 @@ (new file: mmpretrain/models/heads/mocov3_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.dist import get_rank

from mmpretrain.registry import MODELS
from mmpretrain.utils import concat_all_gather
from .base_head import BaseHead


@MODELS.register_module()
class MoCoV3Head(BaseHead):
    """Head for MoCo v3 algorithms.

    This head builds a predictor, which can be any registered neck component.
    It also implements latent contrastive loss between two forward features.
    Part of the code is modified from:
    `<https://github.com/facebookresearch/moco-v3/blob/main/moco/builder.py>`_.

    Args:
        predictor (dict): Config dict for module of predictor.
        loss (dict): Config dict for module of loss functions.
        temperature (float): The temperature hyper-parameter that
            controls the concentration level of the distribution.
            Defaults to 1.0.
    """

    def __init__(self,
                 predictor: dict,
                 loss: dict,
                 temperature: float = 1.0) -> None:
        super().__init__()
        self.predictor = MODELS.build(predictor)
        self.loss_module = MODELS.build(loss)
        self.temperature = temperature

    def loss(self, base_out: torch.Tensor,
             momentum_out: torch.Tensor) -> torch.Tensor:
        """Forward head.

        Args:
            base_out (torch.Tensor): NxC features from base_encoder.
            momentum_out (torch.Tensor): NxC features from momentum_encoder.

        Returns:
            torch.Tensor: The loss tensor.
        """
        # predictor computation
        pred = self.predictor([base_out])[0]

        # normalize
        pred = nn.functional.normalize(pred, dim=1)
        target = nn.functional.normalize(momentum_out, dim=1)

        # get negative samples
        target = concat_all_gather(target)

        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [pred, target]) / self.temperature

        # generate labels
        batch_size = logits.shape[0]
        labels = (torch.arange(batch_size, dtype=torch.long) +
                  batch_size * get_rank()).to(logits.device)

        loss = self.loss_module(logits, labels)
        return loss
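The label construction above matches sample i on rank r with column i + N * r of the all-gathered target matrix. A single-process sketch (world size 1, rank 0 assumed, so no gathering is needed; shapes are illustrative):

import torch
import torch.nn.functional as F

N, C, temperature = 4, 32, 1.0
pred = F.normalize(torch.randn(N, C), dim=1)
target = F.normalize(torch.randn(N, C), dim=1)  # would be all-gathered in DDP

logits = torch.einsum('nc,mc->nm', pred, target) / temperature
labels = torch.arange(N, dtype=torch.long)      # + N * rank with more ranks
loss = F.cross_entropy(logits, labels)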
@@ -0,0 +1,41 @@ (new file: mmpretrain/models/heads/simmim_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


# TODO: delete and use NaiveMIMHead
@MODELS.register_module()
class SimMIMHead(BaseModule):
    """Pretrain Head for SimMIM.

    Args:
        patch_size (int): Patch size of each token.
        loss (dict): The config for loss.
    """

    def __init__(self, patch_size: int, loss: dict) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.loss = MODELS.build(loss)

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                mask: torch.Tensor) -> torch.Tensor:
        """Forward function of the SimMIM head.

        This method will expand the mask to the size of the original image.

        Args:
            pred (torch.Tensor): The reconstructed image (B, C, H, W).
            target (torch.Tensor): The target image (B, C, H, W).
            mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(
            self.patch_size, 2).unsqueeze(1).contiguous()
        loss = self.loss(pred, target, mask)

        return loss
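The ``repeat_interleave`` expansion above upsamples a patch-level mask to pixel resolution before the loss is applied; with patch_size=4, a (B, H/4, W/4) mask becomes (B, 1, H, W). A standalone check (the 192px input size follows the SimMIM configs above):

import torch

patch_size, B = 4, 2
mask = torch.randint(0, 2, (B, 48, 48)).float()  # 192 / 4 = 48 patches per side

pixel_mask = mask.repeat_interleave(patch_size, 1).repeat_interleave(
    patch_size, 2).unsqueeze(1).contiguous()
assert pixel_mask.shape == (B, 1, 192, 192)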
@@ -0,0 +1,31 @@ (new file: mmpretrain/models/heads/swav_head.py)
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmpretrain.registry import MODELS
from .base_head import BaseHead


@MODELS.register_module()
class SwAVHead(BaseHead):
    """Head for SwAV.

    Args:
        loss (dict): Config dict for module of loss functions.
    """

    def __init__(self, loss: dict) -> None:
        super().__init__()
        self.loss_module = MODELS.build(loss)

    def loss(self, pred: torch.Tensor) -> torch.Tensor:
        """Forward function of SwAV head.

        Args:
            pred (torch.Tensor): NxC input features.

        Returns:
            torch.Tensor: The SwAV loss.
        """
        loss = self.loss_module(pred)

        return loss
@@ -1,16 +1,35 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
+from .cae_loss import CAELoss
+from .cosine_similarity_loss import CosineSimilarityLoss
+from .cross_correlation_loss import CrossCorrelationLoss
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy)
 from .focal_loss import FocalLoss, sigmoid_focal_loss
 from .label_smooth_loss import LabelSmoothLoss
+from .reconstruction_loss import PixelReconstructionLoss
 from .seesaw_loss import SeesawLoss
+from .swav_loss import SwAVLoss
 from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss,
                     weighted_loss)

 __all__ = [
-    'asymmetric_loss', 'AsymmetricLoss', 'cross_entropy',
-    'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
-    'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
-    'sigmoid_focal_loss', 'convert_to_one_hot', 'SeesawLoss'
+    'asymmetric_loss',
+    'AsymmetricLoss',
+    'cross_entropy',
+    'binary_cross_entropy',
+    'CrossEntropyLoss',
+    'reduce_loss',
+    'weight_reduce_loss',
+    'LabelSmoothLoss',
+    'weighted_loss',
+    'FocalLoss',
+    'sigmoid_focal_loss',
+    'convert_to_one_hot',
+    'SeesawLoss',
+    'CAELoss',
+    'CosineSimilarityLoss',
+    'CrossCorrelationLoss',
+    'PixelReconstructionLoss',
+    'SwAVLoss',
 ]
@@ -0,0 +1,48 @@ (new file: mmpretrain/models/losses/cae_loss.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


@MODELS.register_module()
class CAELoss(BaseModule):
    """Loss function for CAE.

    Compute the align loss and the main loss.

    Args:
        lambd (float): The weight for the align loss.
    """

    def __init__(self, lambd: float) -> None:
        super().__init__()
        self.lambd = lambd
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def forward(
            self, logits: torch.Tensor, target: torch.Tensor,
            latent_pred: torch.Tensor,
            latent_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function of CAE Loss.

        Args:
            logits (torch.Tensor): The outputs from the decoder.
            target (torch.Tensor): The targets generated by DALL-E.
            latent_pred (torch.Tensor): The latent prediction from the
                regressor.
            latent_target (torch.Tensor): The latent target from the teacher
                network.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The main loss and align loss.
        """
        loss_main = self.loss_cross_entropy(logits, target)
        loss_align = self.loss_mse(latent_pred,
                                   latent_target.detach()) * self.lambd

        return loss_main, loss_align
@@ -0,0 +1,55 @@ (new file: mmpretrain/models/losses/cosine_similarity_loss.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


@MODELS.register_module()
class CosineSimilarityLoss(BaseModule):
    """Cosine similarity loss function.

    Compute the similarity between two features and optimize that similarity
    as loss.

    Args:
        shift_factor (float): The shift factor of cosine similarity.
            Defaults to 0.0.
        scale_factor (float): The scale factor of cosine similarity.
            Defaults to 1.0.
    """

    def __init__(self,
                 shift_factor: float = 0.0,
                 scale_factor: float = 1.0) -> None:
        super().__init__()
        self.shift_factor = shift_factor
        self.scale_factor = scale_factor

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function of cosine similarity loss.

        Args:
            pred (torch.Tensor): The predicted features.
            target (torch.Tensor): The target features.
            mask (torch.Tensor, optional): The mask used to weight the loss.
                If None, the loss is averaged over all elements.
                Defaults to None.

        Returns:
            torch.Tensor: The cosine similarity loss.
        """
        pred_norm = nn.functional.normalize(pred, dim=-1)
        target_norm = nn.functional.normalize(target, dim=-1)
        loss = self.shift_factor - self.scale_factor * (
            pred_norm * target_norm).sum(dim=-1)

        if mask is None:
            loss = loss.mean()
        else:
            loss = (loss * mask).sum() / mask.sum()
        return loss
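Per sample the loss above is shift_factor - scale_factor * cos(pred, target). A quick standalone check against torch's built-in cosine similarity (default factors assumed):

import torch
import torch.nn.functional as F

pred, target = torch.randn(4, 128), torch.randn(4, 128)
shift, scale = 0.0, 1.0

loss = shift - scale * (F.normalize(pred, dim=-1) *
                        F.normalize(target, dim=-1)).sum(dim=-1)
assert torch.allclose(loss, -F.cosine_similarity(pred, target, dim=-1),
                      atol=1e-6)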
@@ -0,0 +1,44 @@ (new file: mmpretrain/models/losses/cross_correlation_loss.py)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class CrossCorrelationLoss(BaseModule):
    """Cross correlation loss function.

    Compute the on-diagonal and off-diagonal loss.

    Args:
        lambd (float): The weight for the off-diagonal loss.
    """

    def __init__(self, lambd: float = 0.0051) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, cross_correlation_matrix: torch.Tensor) -> torch.Tensor:
        """Forward function of cross correlation loss.

        Args:
            cross_correlation_matrix (torch.Tensor): The cross correlation
                matrix.

        Returns:
            torch.Tensor: The cross correlation loss.
        """
        # loss
        on_diag = torch.diagonal(cross_correlation_matrix).add_(-1).pow_(
            2).sum()
        off_diag = self.off_diagonal(cross_correlation_matrix).pow_(2).sum()
        loss = on_diag + self.lambd * off_diag
        return loss

    def off_diagonal(self, x: torch.Tensor) -> torch.Tensor:
        """Return a flattened view of the off-diagonal elements of a square
        matrix."""
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
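The loss above drives the cross-correlation matrix toward the identity: on-diagonal entries are pulled to 1 and off-diagonal entries to 0, the latter weighted by lambd. A standalone sanity check of the two extremes (identity matrix and all-zero matrix):

import torch

def cc_loss(c, lambd=0.0051):
    n = c.shape[0]
    on_diag = torch.diagonal(c).add(-1).pow(2).sum()
    off_diag = c.flatten()[:-1].view(n - 1, n + 1)[:, 1:].pow(2).sum()
    return on_diag + lambd * off_diag

assert cc_loss(torch.eye(8)) == 0       # identity matrix: zero loss
assert cc_loss(torch.zeros(8, 8)) == 8  # eight on-diagonal misses of (0 - 1)^2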
@@ -0,0 +1,67 @@ (new file: mmpretrain/models/losses/reconstruction_loss.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class PixelReconstructionLoss(BaseModule):
    """Loss for the reconstruction of pixel in Masked Image Modeling.

    This module measures the distance between the target image and the
    reconstructed image and computes the loss to optimize the model.
    Currently, this module only provides L1 and L2 loss to penalize the
    reconstruction error. In addition, a mask can be passed in the
    ``forward`` function to only apply the loss on the visible region, like
    that in MAE.

    Args:
        criterion (str): The criterion used to penalize the reconstruction
            error. Currently, only supports L1 and L2 loss.
        channel (int, optional): The number of channels to average the
            reconstruction loss. If not None, the reconstruction loss
            will be divided by the channel. Defaults to None.
    """

    def __init__(self, criterion: str, channel: Optional[int] = None) -> None:
        super().__init__()

        if criterion == 'L1':
            self.penalty = torch.nn.L1Loss(reduction='none')
        elif criterion == 'L2':
            self.penalty = torch.nn.MSELoss(reduction='none')
        else:
            raise NotImplementedError(f'Currently, PixelReconstructionLoss \
                only supports L1 and L2 loss, but got {criterion}')

        self.channel = channel if channel is not None else 1

    def forward(self,
                pred: torch.Tensor,
                target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function to compute the reconstruction loss.

        Args:
            pred (torch.Tensor): The reconstructed image.
            target (torch.Tensor): The target image.
            mask (torch.Tensor): The mask of the target image.

        Returns:
            torch.Tensor: The reconstruction loss.
        """
        loss = self.penalty(pred, target)

        # if the dim of the loss is 3, take the average of the loss
        # along the last dim
        if len(loss.shape) == 3:
            loss = loss.mean(dim=-1)

        if mask is None:
            loss = loss.mean()
        else:
            loss = (loss * mask).sum() / mask.sum() / self.channel

        return loss
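A standalone sketch of the masked averaging in the forward above, matching the MAE-style call with the L2 criterion (per-patch mean over the feature dim, then an average over masked patches only; shapes and mask ratio are illustrative):

import torch

B, L, D = 2, 196, 768
pred, target = torch.randn(B, L, D), torch.randn(B, L, D)
mask = torch.zeros(B, L)
mask[:, :147] = 1                        # e.g. 75% of patches masked

loss = torch.nn.MSELoss(reduction='none')(pred, target)
loss = loss.mean(dim=-1)                 # (B, L): per-patch error
loss = (loss * mask).sum() / mask.sum()  # average over masked patches only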
@@ -0,0 +1,190 @@ (new file: mmpretrain/models/losses/swav_loss.py)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmengine.dist import all_reduce
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@torch.no_grad()
def distributed_sinkhorn(out: torch.Tensor, sinkhorn_iterations: int,
                         world_size: int, epsilon: float) -> torch.Tensor:
    """Apply the distributed Sinkhorn optimization on the scores matrix to
    find the assignments.

    This function is modified from
    https://github.com/facebookresearch/swav/blob/main/main_swav.py

    Args:
        out (torch.Tensor): The scores matrix.
        sinkhorn_iterations (int): Number of iterations in Sinkhorn-Knopp
            algorithm.
        world_size (int): The world size of the process group.
        epsilon (float): regularization parameter for Sinkhorn-Knopp
            algorithm.

    Returns:
        torch.Tensor: Output of sinkhorn algorithm.
    """
    eps_num_stab = 1e-12
    Q = torch.exp(out / epsilon).t(
    )  # Q is K-by-B for consistency with notations from our paper
    B = Q.shape[1] * world_size  # number of samples to assign
    K = Q.shape[0]  # how many prototypes

    # make the matrix sums to 1
    sum_Q = torch.sum(Q)
    all_reduce(sum_Q)
    Q /= sum_Q

    for it in range(sinkhorn_iterations):
        # normalize each row: total weight per prototype must be 1/K
        u = torch.sum(Q, dim=1, keepdim=True)
        if len(torch.nonzero(u == 0)) > 0:
            Q += eps_num_stab
            u = torch.sum(Q, dim=1, keepdim=True, dtype=Q.dtype)
        all_reduce(u)
        Q /= u
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()


class MultiPrototypes(BaseModule):
    """Multi-prototypes for SwAV head.

    Args:
        output_dim (int): The output dim from SwAV neck.
        num_prototypes (List[int]): The number of prototypes needed.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 output_dim: int,
                 num_prototypes: List[int],
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(num_prototypes, list)
        self.num_heads = len(num_prototypes)
        for i, k in enumerate(num_prototypes):
            self.add_module('prototypes' + str(i),
                            nn.Linear(output_dim, k, bias=False))

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Run forward for every prototype."""
        out = []
        for i in range(self.num_heads):
            out.append(getattr(self, 'prototypes' + str(i))(x))
        return out


@MODELS.register_module()
class SwAVLoss(BaseModule):
    """The Loss for SwAV.

    This Loss contains clustering and sinkhorn algorithms to compute Q codes.
    Part of the code is borrowed from `script
    <https://github.com/facebookresearch/swav>`_.
    The queue is built in `engine/hooks/swav_hook.py`.

    Args:
        feat_dim (int): feature dimension of the prototypes.
        sinkhorn_iterations (int): number of iterations in Sinkhorn-Knopp
            algorithm. Defaults to 3.
        epsilon (float): regularization parameter for Sinkhorn-Knopp
            algorithm. Defaults to 0.05.
        temperature (float): temperature parameter in training loss.
            Defaults to 0.1.
        crops_for_assign (List[int]): list of crops id used for computing
            assignments. Defaults to [0, 1].
        num_crops (List[int]): list of number of crops. Defaults to [2].
        num_prototypes (int or List[int]): number of prototypes.
            Defaults to 3000.
        init_cfg (dict or List[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 feat_dim: int,
                 sinkhorn_iterations: int = 3,
                 epsilon: float = 0.05,
                 temperature: float = 0.1,
                 crops_for_assign: List[int] = [0, 1],
                 num_crops: List[int] = [2],
                 num_prototypes: int = 3000,
                 init_cfg: Optional[Union[List[dict], dict]] = None):
        super().__init__(init_cfg=init_cfg)
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon
        self.temperature = temperature
        self.crops_for_assign = crops_for_assign
        self.num_crops = num_crops
        self.use_queue = False
        self.queue = None
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1

        # prototype layer
        self.prototypes = None
        if isinstance(num_prototypes, list):
            self.prototypes = MultiPrototypes(feat_dim, num_prototypes)
        elif num_prototypes > 0:
            self.prototypes = nn.Linear(feat_dim, num_prototypes, bias=False)
        assert self.prototypes is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function of SwAV loss.

        Args:
            x (torch.Tensor): NxC input features.

        Returns:
            torch.Tensor: The returned loss.
        """
        # normalize the prototypes
        with torch.no_grad():
            w = self.prototypes.weight.data.clone()
            w = nn.functional.normalize(w, dim=1, p=2)
            self.prototypes.weight.copy_(w)

        embedding, output = x, self.prototypes(x)
        embedding = embedding.detach()

        bs = int(embedding.size(0) / sum(self.num_crops))
        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[bs * crop_id:bs * (crop_id + 1)].detach()
                # time to use the queue
                if self.queue is not None:
                    if self.use_queue or not torch.all(self.queue[i,
                                                                  -1, :] == 0):
                        self.use_queue = True
                        out = torch.cat(
                            (torch.mm(self.queue[i],
                                      self.prototypes.weight.t()), out))
                    # fill the queue
                    self.queue[i, bs:] = self.queue[i, :-bs].clone()
                    self.queue[i, :bs] = embedding[crop_id * bs:(crop_id + 1) *
                                                   bs]

                # get assignments (batch_size * num_prototypes)
                q = distributed_sinkhorn(out, self.sinkhorn_iterations,
                                         self.world_size, self.epsilon)[-bs:]

            # cluster assignment prediction
            subloss = 0
            for v in np.delete(np.arange(np.sum(self.num_crops)), crop_id):
                x = output[bs * v:bs * (v + 1)] / self.temperature
                subloss -= torch.mean(
                    torch.sum(q * nn.functional.log_softmax(x, dim=1), dim=1))
            loss += subloss / (np.sum(self.num_crops) - 1)
        loss /= len(self.crops_for_assign)
        return loss
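With world_size=1 the function above reduces to plain Sinkhorn-Knopp normalization, and the rows of the returned assignment sum to 1. A standalone sketch of that single-process case (the numerical-stability branch is omitted for brevity):

import torch

def sinkhorn(out, n_iters=3, epsilon=0.05):
    Q = torch.exp(out / epsilon).t()     # K x B
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(dim=1, keepdim=True)  # rows -> total weight 1/K
        Q /= K
        Q /= Q.sum(dim=0, keepdim=True)  # cols -> total weight 1/B
        Q /= B
    return (Q * B).t()                   # B x K; each row sums to ~1

q = sinkhorn(torch.randn(16, 8) * 0.1)   # 16 samples, 8 prototypes
assert torch.allclose(q.sum(dim=1), torch.ones(16), atol=1e-4)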
@@ -1,10 +1,30 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .beitv2_neck import BEiTV2Neck
+from .cae_neck import CAENeck
+from .densecl_neck import DenseCLNeck
 from .gap import GlobalAveragePooling
 from .gem import GeneralizedMeanPooling
 from .hr_fuse import HRFuseScales
-from .reduction import LinearReduction
+from .linear_neck import LinearNeck
+from .mae_neck import MAEPretrainDecoder
+from .milan_neck import MILANPretrainDecoder
+from .mixmim_neck import MixMIMPretrainDecoder
+from .mocov2_neck import MoCoV2Neck
+from .nonlinear_neck import NonLinearNeck
+from .simmim_neck import SimMIMLinearDecoder

 __all__ = [
-    'GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales',
-    'LinearReduction'
+    'GlobalAveragePooling',
+    'GeneralizedMeanPooling',
+    'HRFuseScales',
+    'LinearNeck',
+    'BEiTV2Neck',
+    'CAENeck',
+    'DenseCLNeck',
+    'MAEPretrainDecoder',
+    'MILANPretrainDecoder',
+    'MixMIMPretrainDecoder',
+    'MoCoV2Neck',
+    'NonLinearNeck',
+    'SimMIMLinearDecoder',
 ]
@@ -0,0 +1,153 @@ (new file: mmpretrain/models/necks/beitv2_neck.py)
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS


@MODELS.register_module()
class BEiTV2Neck(BaseModule):
    """Neck for BEiTV2 Pre-training.

    This module constructs the decoder for the final prediction.

    Args:
        num_layers (int): Number of encoder layers of neck. Defaults to 2.
        early_layers (int): The layer index of the early output from the
            backbone. Defaults to 9.
        backbone_arch (str): Vision Transformer architecture. Defaults to
            'base'.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initialization value for the
            learnable scaling of attention and FFN. Defaults to 0.1.
        use_rel_pos_bias (bool): Whether to use unique relative position
            bias; if False, use the shared relative position bias defined in
            the backbone.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'depth': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'depth': 24,
                'num_heads': 16,
                'feedforward_channels': 4096,
            }),
    }

    def __init__(
        self,
        num_layers: int = 2,
        early_layers: int = 9,
        backbone_arch: str = 'base',
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        layer_scale_init_value: float = 0.1,
        use_rel_pos_bias: bool = False,
        norm_cfg: dict = dict(type='LN', eps=1e-6),
        init_cfg: Optional[Union[dict, List[dict]]] = dict(
            type='TruncNormal', layer='Linear', std=0.02, bias=0)
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(backbone_arch, str):
            backbone_arch = backbone_arch.lower()
            assert backbone_arch in set(self.arch_zoo), \
                (f'Arch {backbone_arch} is not in default archs '
                 f'{set(self.arch_zoo)}')
            self.arch_settings = self.arch_zoo[backbone_arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(backbone_arch, dict) and essential_keys <= set(
                backbone_arch
            ), f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = backbone_arch

        # stochastic depth decay rule
        self.early_layers = early_layers
        depth = self.arch_settings['depth']
        dpr = np.linspace(0, drop_path_rate,
                          max(depth, early_layers + num_layers))

        self.patch_aggregation = nn.ModuleList()
        for i in range(early_layers, early_layers + num_layers):
            _layer_cfg = dict(
                embed_dims=self.arch_settings['embed_dims'],
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                norm_cfg=norm_cfg,
                layer_scale_init_value=layer_scale_init_value,
                window_size=None,
                use_rel_pos_bias=use_rel_pos_bias)
            self.patch_aggregation.append(
                BEiTTransformerEncoderLayer(**_layer_cfg))

        self.rescale_patch_aggregation_init_weight()

        embed_dims = self.arch_settings['embed_dims']
        _, norm = build_norm_layer(norm_cfg, embed_dims)
        self.add_module('norm', norm)

    def rescale_patch_aggregation_init_weight(self):
        """Rescale the initialized weights."""

        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.patch_aggregation):
            rescale(layer.attn.proj.weight.data,
                    self.early_layers + layer_id + 1)
            rescale(layer.ffn.layers[1].weight.data,
                    self.early_layers + layer_id + 1)

    def forward(self, inputs: Tuple[torch.Tensor], rel_pos_bias: torch.Tensor,
                **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            inputs (Tuple[torch.Tensor]): Features of tokens.
            rel_pos_bias (torch.Tensor): Shared relative position bias table.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
            - ``x``: The final layer features from the backbone, which are
              normed in ``BEiTV2Neck``.
            - ``x_cls_pt``: The early state features from the backbone,
              which consist of the final layer cls_token and the early state
              patch tokens, and are sent to the PatchAggregation layers in
              the neck.
        """

        early_states, x = inputs[0], inputs[1]
        x_cls_pt = torch.cat([x[:, [0]], early_states[:, 1:]], dim=1)
        for layer in self.patch_aggregation:
            x_cls_pt = layer(x_cls_pt, rel_pos_bias=rel_pos_bias)

        # shared norm
        x, x_cls_pt = self.norm(x), self.norm(x_cls_pt)

        # remove cls_token
        x = x[:, 1:]
        x_cls_pt = x_cls_pt[:, 1:]
        return x, x_cls_pt
@@ -0,0 +1,273 @@ (new file: mmpretrain/models/necks/cae_neck.py; hunk truncated in this view)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.models.backbones.beit import BEiTTransformerEncoderLayer
from mmpretrain.registry import MODELS
from ..utils import CrossMultiheadAttention


class CAETransformerRegressorLayer(BaseModule):
    """Transformer layer for the regressor of CAE.

    This module is different from the conventional transformer encoder layer:
    its queries are the masked tokens, but its keys and values are the
    concatenation of the masked and unmasked tokens.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): The number of heads in multi-head attention.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults to 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Defaults to 2.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        drop_rate (float): The dropout rate. Defaults to 0.0.
        attn_drop_rate (float): The drop out rate for attention output
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The init value of gamma.
            Defaults to 0.0.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
    """

    def __init__(
        self,
        embed_dims: int,
        num_heads: int,
        feedforward_channels: int,
        num_fcs: int = 2,
        qkv_bias: bool = False,
        qk_scale: float = None,
        drop_rate: float = 0.,
        attn_drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        layer_scale_init_value: float = 0.0,
        act_cfg: dict = dict(type='GELU'),
        norm_cfg: dict = dict(type='LN', eps=1e-6)
    ) -> None:
        super().__init__()

        # NOTE: cross attention
        _, self.norm1_q_cross = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        _, self.norm1_k_cross = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        _, self.norm1_v_cross = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        _, self.norm2_cross = build_norm_layer(norm_cfg, embed_dims, postfix=2)
        self.cross_attn = CrossMultiheadAttention(
            embed_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=None,
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = DropPath(drop_prob=drop_path_rate)

        if layer_scale_init_value > 0:
            self.gamma_1_cross = nn.Parameter(
                layer_scale_init_value * torch.ones((embed_dims)),
                requires_grad=True)
            self.gamma_2_cross = nn.Parameter(
                layer_scale_init_value * torch.ones((embed_dims)),
                requires_grad=True)
        else:
            self.gamma_1_cross = nn.Parameter(
                torch.ones((embed_dims)), requires_grad=False)
            self.gamma_2_cross = nn.Parameter(
                torch.ones((embed_dims)), requires_grad=False)

    def forward(self, x_q: torch.Tensor, x_kv: torch.Tensor,
                pos_q: torch.Tensor, pos_k: torch.Tensor) -> torch.Tensor:
        """Forward function."""
        x = x_q + self.drop_path(self.gamma_1_cross * self.cross_attn(
            self.norm1_q_cross(x_q + pos_q),
            k=self.norm1_k_cross(x_kv + pos_k),
            v=self.norm1_v_cross(x_kv)))
        x = self.norm2_cross(x)
        x = x + self.drop_path(self.gamma_2_cross * self.ffn(x))

        return x


@MODELS.register_module()
class CAENeck(BaseModule):
    """Neck for CAE Pre-training.

    This module constructs the latent prediction regressor and the decoder
    for the latent prediction and final prediction.

    Args:
        num_classes (int): The number of classes for final prediction.
            Defaults to 8192.
        embed_dims (int): The embed dims of latent feature in regressor and
            decoder. Defaults to 768.
        regressor_depth (int): The number of regressor blocks. Defaults to 6.
        decoder_depth (int): The number of decoder blocks. Defaults to 8.
        num_heads (int): The number of heads in multi-head attention.
            Defaults to 12.
        mlp_ratio (int): The expand ratio of latent features in MLP.
            Defaults to 4.
        qkv_bias (bool): Whether or not to use qkv bias. Defaults to True.
        qk_scale (float, optional): The scale applied to the results of qk.
            Defaults to None.
        drop_rate (float): The dropout rate. Defaults to 0.
        attn_drop_rate (float): The dropout rate in attention block.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): The config of normalization layer. Defaults to
            dict(type='LN', eps=1e-6).
        layer_scale_init_value (float, optional): The init value of gamma.
            Defaults to None.
        mask_tokens_num (int): The number of mask tokens. Defaults to 75.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_classes: int = 8192,
                 embed_dims: int = 768,
                 regressor_depth: int = 6,
                 decoder_depth: int = 8,
                 num_heads: int = 12,
                 mlp_ratio: int = 4,
                 qkv_bias: bool = True,
                 qk_scale: float = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 layer_scale_init_value: float = None,
                 mask_tokens_num: int = 75,
                 init_cfg: dict = None) -> None:
        super().__init__(init_cfg=init_cfg)

        self.num_features = self.embed_dim = embed_dims
        self.mask_token_num = mask_tokens_num

        # regressor
        regressor_drop_path_rates = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, regressor_depth)
        ]
        self.regressors = nn.ModuleList([
            CAETransformerRegressorLayer(
                embed_dims=embed_dims,
                num_heads=num_heads,
                feedforward_channels=mlp_ratio * embed_dims,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=regressor_drop_path_rates[i],
                norm_cfg=norm_cfg,
                layer_scale_init_value=layer_scale_init_value)
            for i in range(regressor_depth)
        ])

        # decoder
        decoder_drop_path_rates = [
            x.item() for x in torch.linspace(0, drop_path_rate, decoder_depth)
        ]
        self.decoders = nn.ModuleList([
            BEiTTransformerEncoderLayer(
                embed_dims=embed_dims,
                num_heads=num_heads,
                feedforward_channels=mlp_ratio * embed_dims,
                layer_scale_init_value=layer_scale_init_value,
                window_size=None,
                # setting `use_rel_pos_bias` to False ignores the `window_size`
                use_rel_pos_bias=False,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=decoder_drop_path_rates[i],
                norm_cfg=norm_cfg) for i in range(decoder_depth)
        ])

        _, self.norm_regressor = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        _, self.norm_decoder = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)

        self.head = nn.Linear(
            embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

    def init_weights(self) -> None:
        """Initialization."""
        super().init_weights()
        self.apply(self._init_weights)
        trunc_normal_(self.mask_token, std=0.02)
        trunc_normal_(self.head.weight, std=0.02)

    def _init_weights(self, m: nn.Module) -> None:
        """Initialization."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self, x_unmasked: torch.Tensor, pos_embed_masked: torch.Tensor,
        pos_embed_unmasked: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the latent prediction and final prediction.

        Args:
            x_unmasked (torch.Tensor): Features of unmasked tokens.
            pos_embed_masked (torch.Tensor): Position embedding of masked
                tokens.
            pos_embed_unmasked (torch.Tensor): Position embedding of unmasked
                tokens.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
            - ``logits``: Final prediction.
            - ``latent_pred``: Latent prediction.
        """
        x_masked = self.mask_token.expand(x_unmasked.shape[0],
                                          self.mask_token_num, -1)
        # regressor
        for regressor in self.regressors:
            x_masked = regressor(
                x_masked, torch.cat([x_unmasked, x_masked], dim=1),
|
||||
pos_embed_masked,
|
||||
torch.cat([pos_embed_unmasked, pos_embed_masked], dim=1))
|
||||
x_masked = self.norm_regressor(x_masked)
|
||||
latent_pred = x_masked
|
||||
|
||||
# decoder
|
||||
x_masked = x_masked + pos_embed_masked
|
||||
for decoder in self.decoders:
|
||||
x_masked = decoder(x_masked)
|
||||
x_masked = self.norm_decoder(x_masked)
|
||||
|
||||
logits = self.head(x_masked)
|
||||
|
||||
return logits, latent_pred
|
|
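
For reference, a minimal shape check for ``CAENeck`` looks like the sketch
below. It is not part of the diff: the batch size, the 122 visible tokens and
``layer_scale_init_value=0.1`` are illustrative assumptions, and ``CAENeck``
is assumed to be importable from ``mmpretrain.models``.

import torch
from mmpretrain.models import CAENeck

neck = CAENeck(embed_dims=768, mask_tokens_num=75, layer_scale_init_value=0.1)
neck.init_weights()

x_unmasked = torch.rand(2, 122, 768)          # features of visible tokens
pos_embed_masked = torch.rand(2, 75, 768)     # pos. embedding, masked tokens
pos_embed_unmasked = torch.rand(2, 122, 768)  # pos. embedding, visible tokens

logits, latent_pred = neck(x_unmasked, pos_embed_masked, pos_embed_unmasked)
assert logits.shape == (2, 75, 8192)      # codebook prediction per masked token
assert latent_pred.shape == (2, 75, 768)  # regressor output before the decoder
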
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class DenseCLNeck(BaseModule):
    """The non-linear neck of DenseCL.

    Single and dense necks in parallel: fc-relu-fc and conv-relu-conv.
    Borrowed from the authors' `code <https://github.com/WXinlong/DenseCL>`_.

    Args:
        in_channels (int): Number of input channels.
        hid_channels (int): Number of hidden channels.
        out_channels (int): Number of output channels.
        num_grid (int): The grid size of dense features. Defaults to None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 hid_channels: int,
                 out_channels: int,
                 num_grid: Optional[int] = None,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels))

        self.with_pool = True if num_grid is not None else False
        if self.with_pool:
            self.pool = nn.AdaptiveAvgPool2d((num_grid, num_grid))
        self.mlp2 = nn.Sequential(
            nn.Conv2d(in_channels, hid_channels, 1), nn.ReLU(inplace=True),
            nn.Conv2d(hid_channels, out_channels, 1))
        self.avgpool2 = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        """Forward function of neck.

        Args:
            x (Tuple[torch.Tensor]): The feature map of the backbone.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
              - ``avgpooled_x``: Global feature vectors.
              - ``x``: Dense feature vectors.
              - ``avgpooled_x2``: Dense feature vectors for the queue.
        """
        assert len(x) == 1
        x = x[0]

        avgpooled_x = self.avgpool(x)
        avgpooled_x = self.mlp(avgpooled_x.view(avgpooled_x.size(0), -1))

        if self.with_pool:
            x = self.pool(x)  # sxs
        x = self.mlp2(x)  # sxs: bxdxsxs
        avgpooled_x2 = self.avgpool2(x)  # 1x1: bxdx1x1
        x = x.view(x.size(0), x.size(1), -1)  # bxdxs^2
        avgpooled_x2 = avgpooled_x2.view(avgpooled_x2.size(0), -1)  # bxd
        return avgpooled_x, x, avgpooled_x2
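
A quick usage sketch for the DenseCL neck (not part of the diff; the
ResNet-50 channel count and the 7x7 feature map are illustrative
assumptions):

import torch
from mmpretrain.models import DenseCLNeck

neck = DenseCLNeck(in_channels=2048, hid_channels=2048, out_channels=128)
backbone_out = (torch.rand(4, 2048, 7, 7), )  # single-scale feature map

avgpooled_x, dense_x, avgpooled_x2 = neck(backbone_out)
assert avgpooled_x.shape == (4, 128)   # global projection
assert dense_x.shape == (4, 128, 49)   # one projected vector per location
assert avgpooled_x2.shape == (4, 128)  # pooled dense projection for the queue
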
@@ -11,12 +11,14 @@ from mmpretrain.registry import MODELS


@MODELS.register_module()
class LinearReduction(BaseModule):
    """Neck with Dimension reduction.
class LinearNeck(BaseModule):
    """Linear neck with dimension projection.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        gap_dim (int): Dimensions of each sample channel, can be one of
            {0, 1, 2, 3}. Defaults to 0.
        norm_cfg (dict, optional): dictionary to construct and
            config norm layer. Defaults to dict(type='BN1d').
        act_cfg (dict, optional): dictionary to construct and

@@ -28,22 +30,35 @@ class LinearReduction(BaseModule):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 gap_dim: int = 0,
                 norm_cfg: Optional[dict] = dict(type='BN1d'),
                 act_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        super(LinearReduction, self).__init__(init_cfg=init_cfg)
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = copy.deepcopy(norm_cfg)
        self.act_cfg = copy.deepcopy(act_cfg)

        self.reduction = nn.Linear(
            in_features=in_channels, out_features=out_channels)
        assert gap_dim in [0, 1, 2, 3], 'GlobalAveragePooling dim only ' \
            f'supports {0, 1, 2, 3}, got {gap_dim} instead.'
        if gap_dim == 0:
            self.gap = nn.Identity()
        elif gap_dim == 1:
            self.gap = nn.AdaptiveAvgPool1d(1)
        elif gap_dim == 2:
            self.gap = nn.AdaptiveAvgPool2d((1, 1))
        elif gap_dim == 3:
            self.gap = nn.AdaptiveAvgPool3d((1, 1, 1))

        self.fc = nn.Linear(in_features=in_channels, out_features=out_channels)

        if norm_cfg:
            self.norm = build_norm_layer(norm_cfg, out_channels)[1]
        else:
            self.norm = nn.Identity()

        if act_cfg:
            self.act = build_activation_layer(act_cfg)
        else:

@@ -59,13 +74,15 @@ class LinearReduction(BaseModule):
            the last stage will be used.

        Returns:
            Tuple(torch.Tensor)): A tuple of reducted features.
            Tuple[torch.Tensor]: A tuple of output features.
        """
        assert isinstance(inputs, (tuple, torch.Tensor)), (
            'The inputs of `LinearReduction` neck must be tuple or '
            f'`torch.Tensor`, but get {type(inputs)}.')
            'The inputs of `LinearNeck` must be tuple or `torch.Tensor`, '
            f'but get {type(inputs)}.')
        if isinstance(inputs, tuple):
            inputs = inputs[-1]

        out = self.act(self.norm(self.reduction(inputs)))
        x = self.gap(inputs)
        x = x.view(x.size(0), -1)
        out = self.act(self.norm(self.fc(x)))
        return (out, )
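
To illustrate the refactor, the new ``gap_dim`` argument selects the pooling
applied before the projection. A sketch with illustrative shapes (not part of
the diff; the default ``BN1d`` norm is kept):

import torch
from mmpretrain.models import LinearNeck

# gap_dim=2: a 2d feature map is averaged by AdaptiveAvgPool2d first
neck = LinearNeck(in_channels=768, out_channels=256, gap_dim=2)
out, = neck((torch.rand(2, 768, 7, 7), ))
assert out.shape == (2, 256)

# gap_dim=0: the input is already a vector, so the pooling is nn.Identity
neck = LinearNeck(in_channels=768, out_channels=256, gap_dim=0)
out, = neck((torch.rand(2, 768), ))
assert out.shape == (2, 256)
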
@@ -0,0 +1,189 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding


@MODELS.register_module()
class MAEPretrainDecoder(BaseModule):
    """Decoder for MAE Pre-training.

    Some of the code is borrowed from `https://github.com/facebookresearch/mae`. # noqa

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 8.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
        predict_feature_dim (float, optional): The dimension of the feature
            to be predicted; if None, ``patch_size**2 * in_chans`` is used.
            Defaults to None.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.

    Example:
        >>> from mmpretrain.models import MAEPretrainDecoder
        >>> import torch
        >>> self = MAEPretrainDecoder()
        >>> self.eval()
        >>> inputs = torch.rand(1, 50, 1024)
        >>> ids_restore = torch.arange(0, 196).unsqueeze(0)
        >>> level_outputs = self.forward(inputs, ids_restore)
        >>> print(tuple(level_outputs.shape))
        (1, 196, 768)
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 predict_feature_dim: Optional[float] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.num_patches = num_patches

        # used to convert the dim of features from encoder to the dim
        # compatible with that of decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # create new position embedding, different from that in encoder
        # and is not learnable
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, decoder_embed_dim),
            requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            TransformerEncoderLayer(
                decoder_embed_dim,
                decoder_num_heads,
                int(mlp_ratio * decoder_embed_dim),
                qkv_bias=True,
                norm_cfg=norm_cfg) for _ in range(decoder_depth)
        ])

        self.decoder_norm_name, decoder_norm = build_norm_layer(
            norm_cfg, decoder_embed_dim, postfix=1)
        self.add_module(self.decoder_norm_name, decoder_norm)

        # used to map features to pixels
        if predict_feature_dim is None:
            predict_feature_dim = patch_size**2 * in_chans
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, predict_feature_dim, bias=True)

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MAE decoder."""
        super().init_weights()

        decoder_pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.decoder_pos_embed.shape[-1],
            cls_token=True)
        self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

        torch.nn.init.normal_(self.mask_token, std=.02)

    @property
    def decoder_norm(self):
        """The normalization layer of decoder."""
        return getattr(self, self.decoder_norm_name)

    def forward(self, x: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """The forward function.

        The process computes the visible patches' feature vectors and the
        mask tokens to output feature vectors, which will be used for
        reconstruction.

        Args:
            x (torch.Tensor): hidden features, which is of shape
                B x (L * mask_ratio) x C.
            ids_restore (torch.Tensor): ids to restore original image.

        Returns:
            torch.Tensor: The reconstructed feature vectors, which is of
            shape B x (num_patches) x C.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x


@MODELS.register_module()
class ClsBatchNormNeck(BaseModule):
    """Normalize cls token across batch before head.

    This module is proposed by MAE, when running linear probing.

    Args:
        input_features (int): The dimension of features.
        affine (bool): a boolean value that when set to ``True``, this module
            has learnable affine parameters. Defaults to False.
        eps (float): a value added to the denominator for numerical stability.
            Defaults to 1e-6.
        init_cfg (Dict or List[Dict], optional): Config dict for weight
            initialization. Defaults to None.
    """

    def __init__(self,
                 input_features: int,
                 affine: bool = False,
                 eps: float = 1e-6,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg)
        self.bn = nn.BatchNorm1d(input_features, affine=affine, eps=eps)

    def forward(
            self,
            inputs: Tuple[List[torch.Tensor]]) -> Tuple[List[torch.Tensor]]:
        """The forward function."""
        # Only apply batch norm to the cls token, which is the second tensor
        # in each item of the input tuple
        inputs = [[input_[0], self.bn(input_[1])] for input_ in inputs]
        return tuple(inputs)
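
``ClsBatchNormNeck`` touches only the cls token of each backbone output. A
sketch (not part of the diff; the token counts are illustrative assumptions):

import torch
from mmpretrain.models import ClsBatchNormNeck

neck = ClsBatchNormNeck(input_features=768)
patch_tokens = torch.rand(8, 196, 768)
cls_token = torch.rand(8, 768)

outs = neck(([patch_tokens, cls_token], ))
assert torch.equal(outs[0][0], patch_tokens)  # patch tokens pass through
assert outs[0][1].shape == cls_token.shape    # cls token is batch-normalized
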
@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import PromptMultiheadAttention
from .mae_neck import MAEPretrainDecoder


class PromptTransformerEncoderLayer(TransformerEncoderLayer):
    """Prompt Transformer Encoder Layer for MILAN.

    This module is specific to the prompt encoder in MILAN. It will not
    update the visible tokens from the encoder.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Defaults to 0.0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=feedforward_channels,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            num_fcs=num_fcs,
            qkv_bias=qkv_bias,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)
        self.attn = PromptMultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

    def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """Forward function for `PromptMultiheadAttention`.

        Args:
            x (torch.Tensor): Mask token features with shape N x L_m x C.
            visible_tokens (torch.Tensor): The visible tokens features from
                encoder with shape N x L_v x C.
            ids_restore (torch.Tensor): The ids of all tokens in the original
                image with shape N x L.

        Returns:
            torch.Tensor: Output features with shape N x L_m x C.
        """
        x = x + self.attn(self.norm1(x), visible_tokens, ids_restore)
        x = self.ffn(self.norm2(x), identity=x)
        return x


@MODELS.register_module()
class MILANPretrainDecoder(MAEPretrainDecoder):
    """Prompt decoder for MILAN.

    This decoder is used in MILAN pretraining, and it will not update the
    visible tokens from the encoder.

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 8.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        predict_feature_dim (int): The dimension of the feature to be
            predicted. Defaults to 512.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 predict_feature_dim: int = 512,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            num_patches=num_patches,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        # map the dim of features from decoder to the dim compatible with
        # that of CLIP
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, predict_feature_dim, bias=True)

        # use prompt transformer encoder layers, instead of the conventional
        # transformer encoder layers
        self.decoder_blocks = nn.ModuleList([
            PromptTransformerEncoderLayer(
                decoder_embed_dim,
                decoder_num_heads,
                int(mlp_ratio * decoder_embed_dim),
                qkv_bias=True,
                norm_cfg=norm_cfg) for _ in range(decoder_depth)
        ])

    def forward(self, x: torch.Tensor, ids_restore: torch.Tensor,
                ids_keep: torch.Tensor,
                ids_dump: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): The input features, which is of shape (N, L, C).
            ids_restore (torch.Tensor): The indices to restore these tokens
                to the original image.
            ids_keep (torch.Tensor): The indices of tokens to be kept.
            ids_dump (torch.Tensor): The indices of tokens to be masked.

        Returns:
            torch.Tensor: The reconstructed features, which is of shape
            (N, L, C).
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        # add pos embed
        x = x + self.decoder_pos_embed

        # split mask tokens and visible tokens
        visible_tokens = torch.cat([
            x[:, :1, :],
            torch.gather(
                x[:, 1:, :],
                dim=1,
                index=ids_keep.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        ], dim=1)
        x = torch.gather(
            x[:, 1:, :],
            dim=1,
            index=ids_dump.unsqueeze(-1).repeat(1, 1, x.shape[-1]))

        for blk in self.decoder_blocks:
            x = blk(x, visible_tokens, ids_restore)

        # full sequence recovery
        x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1,
                                                   x.shape[-1]))  # unshuffle
        x = torch.cat([visible_tokens[:, :1, :], x_], dim=1)

        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x
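
A shape sketch for ``MILANPretrainDecoder`` under a 50% mask ratio. This is
not part of the diff: the index bookkeeping mirrors MAE's shuffle/restore
convention, and all sizes (196 patches, 98 kept, batch 2) are illustrative
assumptions.

import torch
from mmpretrain.models import MILANPretrainDecoder

decoder = MILANPretrainDecoder()  # 196 patches, CLIP feature dim 512
decoder.init_weights()

x = torch.rand(2, 99, 1024)  # cls token + 98 visible tokens from the encoder
ids_shuffle = torch.argsort(torch.rand(2, 196), dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep, ids_dump = ids_shuffle[:, :98], ids_shuffle[:, 98:]

out = decoder(x, ids_restore, ids_keep, ids_dump)
assert out.shape == (2, 197, 512)  # cls token + all patches, CLIP-dim features
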
@@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn

from mmpretrain.registry import MODELS
from ..utils import build_2d_sincos_position_embedding
from .mae_neck import MAEPretrainDecoder


@MODELS.register_module()
class MixMIMPretrainDecoder(MAEPretrainDecoder):
    """Decoder for MixMIM Pretraining.

    Some of the code is borrowed from `https://github.com/Sense-X/MixMIM`. # noqa

    Args:
        num_patches (int): The number of total patches. Defaults to 196.
        patch_size (int): Image patch size. Defaults to 16.
        in_chans (int): The channel of input image. Defaults to 3.
        embed_dim (int): Encoder's embedding dimension. Defaults to 1024.
        encoder_stride (int): The output stride of MixMIM backbone. Defaults
            to 32.
        decoder_embed_dim (int): Decoder's embedding dimension.
            Defaults to 512.
        decoder_depth (int): The depth of decoder. Defaults to 8.
        decoder_num_heads (int): Number of attention heads of decoder.
            Defaults to 16.
        mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
            Defaults to 4.
        norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
        init_cfg (Union[List[dict], dict], optional): Initialization config
            dict. Defaults to None.
    """

    def __init__(self,
                 num_patches: int = 196,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim: int = 1024,
                 encoder_stride: int = 32,
                 decoder_embed_dim: int = 512,
                 decoder_depth: int = 8,
                 decoder_num_heads: int = 16,
                 mlp_ratio: int = 4,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:

        super().__init__(
            num_patches=num_patches,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_cfg=norm_cfg,
            init_cfg=init_cfg)

        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, decoder_embed_dim),
            requires_grad=False)
        self.decoder_pred = nn.Linear(decoder_embed_dim, encoder_stride**2 * 3)

    def init_weights(self) -> None:
        """Initialize position embedding and mask token of MixMIM decoder."""
        super(MAEPretrainDecoder, self).init_weights()

        decoder_pos_embed = build_2d_sincos_position_embedding(
            int(self.num_patches**.5),
            self.decoder_pos_embed.shape[-1],
            cls_token=False)
        self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())

        torch.nn.init.normal_(self.mask_token, std=.02)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            x (torch.Tensor): The input features, which is of shape (N, L, C).
            mask (torch.Tensor): The tensor to indicate which tokens are
                masked.

        Returns:
            torch.Tensor: The reconstructed features for the two unmixed
            views, which is of shape (N * 2, L, C).
        """

        x = self.decoder_embed(x)
        B, L, C = x.shape

        mask_tokens = self.mask_token.expand(B, L, -1)
        x1 = x * (1 - mask) + mask_tokens * mask
        x2 = x * mask + mask_tokens * (1 - mask)
        x = torch.cat([x1, x2], dim=0)

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        return x
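
Because the decoder unmixes both views at once (``torch.cat([x1, x2],
dim=0)``), its output batch is twice the input batch. A sketch with
illustrative sizes (not part of the diff):

import torch
from mmpretrain.models import MixMIMPretrainDecoder

decoder = MixMIMPretrainDecoder()  # 196 patches, encoder_stride=32
decoder.init_weights()

x = torch.rand(2, 196, 1024)                     # mixed encoder features
mask = torch.randint(0, 2, (2, 196, 1)).float()  # 1 where a token was swapped

out = decoder(x, mask)
assert out.shape == (4, 196, 32**2 * 3)  # two reconstructions, pixel patches
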
@@ -0,0 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class MoCoV2Neck(BaseModule):
    """The non-linear neck of MoCo v2: fc-relu-fc.

    Args:
        in_channels (int): Number of input channels.
        hid_channels (int): Number of hidden channels.
        out_channels (int): Number of output channels.
        with_avg_pool (bool): Whether to apply the global
            average pooling after backbone. Defaults to True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 hid_channels: int,
                 out_channels: int,
                 with_avg_pool: bool = True,
                 init_cfg: Optional[Union[dict, List[dict]]] = None) -> None:
        super().__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hid_channels), nn.ReLU(inplace=True),
            nn.Linear(hid_channels, out_channels))

    def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            x (Tuple[torch.Tensor]): The feature map of backbone.

        Returns:
            Tuple[torch.Tensor]: The output features.
        """
        assert len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        return (self.mlp(x.view(x.size(0), -1)), )
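
A one-line usage sketch for the MoCo v2 projector (not part of the diff;
sizes are illustrative):

import torch
from mmpretrain.models import MoCoV2Neck

neck = MoCoV2Neck(in_channels=2048, hid_channels=2048, out_channels=128)
out, = neck((torch.rand(4, 2048, 7, 7), ))
assert out.shape == (4, 128)
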
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class NonLinearNeck(BaseModule):
    """The non-linear neck.

    Structure: fc-bn-[relu-fc-bn] where the substructure in [] can be
    repeated. For the default setting, the repeated time is 1.
    The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam.

    Args:
        in_channels (int): Number of input channels.
        hid_channels (int): Number of hidden channels.
        out_channels (int): Number of output channels.
        num_layers (int): Number of fc layers. Defaults to 2.
        with_bias (bool): Whether to use bias in fc layers (except for the
            last). Defaults to False.
        with_last_bn (bool): Whether to add the last BN layer.
            Defaults to True.
        with_last_bn_affine (bool): Whether to have learnable affine
            parameters in the last BN layer (set False for SimSiam).
            Defaults to True.
        with_last_bias (bool): Whether to use bias in the last fc layer.
            Defaults to False.
        with_avg_pool (bool): Whether to apply the global average pooling
            after backbone. Defaults to True.
        vit_backbone (bool): The key to indicate whether the upstream backbone
            is ViT. Defaults to False.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='SyncBN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        out_channels: int,
        num_layers: int = 2,
        with_bias: bool = False,
        with_last_bn: bool = True,
        with_last_bn_affine: bool = True,
        with_last_bias: bool = False,
        with_avg_pool: bool = True,
        vit_backbone: bool = False,
        norm_cfg: dict = dict(type='SyncBN'),
        init_cfg: Optional[Union[dict, List[dict]]] = [
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
    ) -> None:
        super(NonLinearNeck, self).__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        self.vit_backbone = vit_backbone
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias)
        self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1]

        self.fc_names = []
        self.bn_names = []
        for i in range(1, num_layers):
            this_channels = out_channels if i == num_layers - 1 \
                else hid_channels
            if i != num_layers - 1:
                self.add_module(
                    f'fc{i}',
                    nn.Linear(hid_channels, this_channels, bias=with_bias))
                self.add_module(f'bn{i}',
                                build_norm_layer(norm_cfg, this_channels)[1])
                self.bn_names.append(f'bn{i}')
            else:
                self.add_module(
                    f'fc{i}',
                    nn.Linear(
                        hid_channels, this_channels, bias=with_last_bias))
                if with_last_bn:
                    self.add_module(
                        f'bn{i}',
                        build_norm_layer(
                            dict(**norm_cfg, affine=with_last_bn_affine),
                            this_channels)[1])
                    self.bn_names.append(f'bn{i}')
                else:
                    self.bn_names.append(None)
            self.fc_names.append(f'fc{i}')

    def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            x (Tuple[torch.Tensor]): The feature map of backbone.

        Returns:
            Tuple[torch.Tensor]: The output features.
        """
        assert len(x) == 1
        x = x[0]
        if self.vit_backbone:
            x = x[-1]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc0(x)
        x = self.bn0(x)
        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
            fc = getattr(self, fc_name)
            x = self.relu(x)
            x = fc(x)
            if bn_name is not None:
                bn = getattr(self, bn_name)
                x = bn(x)
        return (x, )
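
``NonLinearNeck`` stacks ``num_layers`` fc-bn blocks; e.g. a SimSiam-style
projector disables the affine parameters of the last BN. A sketch (not part
of the diff; plain ``BN1d`` replaces the default ``SyncBN`` here because
SyncBN requires a distributed process group):

import torch
from mmpretrain.models import NonLinearNeck

# fc-bn-relu-fc-bn-relu-fc-bn(affine=False)
neck = NonLinearNeck(
    in_channels=2048,
    hid_channels=2048,
    out_channels=2048,
    num_layers=3,
    with_last_bn_affine=False,
    norm_cfg=dict(type='BN1d'))
out, = neck((torch.rand(4, 2048, 7, 7), ))
assert out.shape == (4, 2048)
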
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class SimMIMLinearDecoder(BaseModule):
    """Linear decoder for SimMIM pretraining.

    This neck reconstructs the original image from the shrunk feature map.

    Args:
        in_channels (int): Channel dimension of the feature map.
        encoder_stride (int): The total stride of the encoder.
    """

    def __init__(self, in_channels: int, encoder_stride: int) -> None:
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=encoder_stride**2 * 3,
                kernel_size=1),
            nn.PixelShuffle(encoder_stride),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function."""
        x = self.decoder(x)
        return x
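
``SimMIMLinearDecoder`` maps the stride-32 feature map back to pixels with a
1x1 conv followed by ``PixelShuffle``, matching the SimMIM config in this PR.
A sketch (not part of the diff; the 192x192 input is illustrative):

import torch
from mmpretrain.models import SimMIMLinearDecoder

neck = SimMIMLinearDecoder(in_channels=128 * 2**3, encoder_stride=32)
feats = torch.rand(2, 1024, 6, 6)  # a 192x192 image at stride 32
pixels = neck(feats)
assert pixels.shape == (2, 3, 192, 192)
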
@@ -0,0 +1,93 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class SwAVNeck(BaseModule):
    """The non-linear neck of SwAV: fc-bn-relu-fc-normalization.

    Args:
        in_channels (int): Number of input channels.
        hid_channels (int): Number of hidden channels.
        out_channels (int): Number of output channels.
        with_avg_pool (bool): Whether to apply the global average pooling
            after backbone. Defaults to True.
        with_l2norm (bool): Whether to normalize the output after projection.
            Defaults to True.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Defaults to dict(type='SyncBN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        out_channels: int,
        with_avg_pool: bool = True,
        with_l2norm: bool = True,
        norm_cfg: dict = dict(type='SyncBN'),
        init_cfg: Optional[Union[dict, List[dict]]] = [
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
    ) -> None:
        super().__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        self.with_l2norm = with_l2norm
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if out_channels == 0:
            self.projection_neck = nn.Identity()
        elif hid_channels == 0:
            self.projection_neck = nn.Linear(in_channels, out_channels)
        else:
            self.norm = build_norm_layer(norm_cfg, hid_channels)[1]
            self.projection_neck = nn.Sequential(
                nn.Linear(in_channels, hid_channels),
                self.norm,
                nn.ReLU(inplace=True),
                nn.Linear(hid_channels, out_channels),
            )

    def forward_projection(self, x: torch.Tensor) -> torch.Tensor:
        """Compute projection.

        Args:
            x (torch.Tensor): The feature vectors after pooling.

        Returns:
            torch.Tensor: The output features with projection or L2-norm.
        """
        x = self.projection_neck(x)
        if self.with_l2norm:
            x = nn.functional.normalize(x, dim=1, p=2)
        return x

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        """Forward function.

        Args:
            x (List[torch.Tensor]): A list of feature maps, one per crop;
                the length of the list equals the number of crops.

        Returns:
            torch.Tensor: The projection vectors.
        """
        avg_out = []
        for _x in x:
            _x = _x[0]
            if self.with_avg_pool:
                _out = self.avgpool(_x)
                avg_out.append(_out)
        feat_vec = torch.cat(avg_out)  # [sum(num_crops) * N, C]
        feat_vec = feat_vec.view(feat_vec.size(0), -1)
        output = self.forward_projection(feat_vec)
        return output
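
``SwAVNeck`` consumes one backbone output per crop and concatenates them
along the batch axis before projecting. A sketch (not part of the diff;
``BN1d`` again stands in for ``SyncBN``, and crop sizes are illustrative):

import torch
from mmpretrain.models import SwAVNeck

neck = SwAVNeck(
    in_channels=2048, hid_channels=2048, out_channels=128,
    norm_cfg=dict(type='BN1d'))
crops = [(torch.rand(4, 2048, 7, 7), ), (torch.rand(4, 2048, 3, 3), )]
out = neck(crops)
assert out.shape == (8, 128)  # 2 crops x batch 4, L2-normalized rows
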
@@ -1,7 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import (BEiTAttention, ChannelMultiheadAttention, LeAttention,
                        MultiheadAttention, ShiftWindowMSA, WindowMSA,
                        WindowMSAV2)
from .attention import (BEiTAttention, ChannelMultiheadAttention,
                        CrossMultiheadAttention, LeAttention,
                        MultiheadAttention, PromptMultiheadAttention,
                        ShiftWindowMSA, WindowMSA, WindowMSAV2)
from .batch_augments import CutMix, Mixup, RandomBatchAugment, ResizeMix
from .channel_shuffle import channel_shuffle
from .data_preprocessor import ClsDataPreprocessor

@@ -13,7 +14,8 @@ from .layer_scale import LayerScale
from .make_divisible import make_divisible
from .norm import GRN, LayerNorm2d, build_norm_layer
from .position_encoding import (ConditionalPositionEncoding,
                                PositionEncodingFourier)
                                PositionEncodingFourier,
                                build_2d_sincos_position_embedding)
from .se_layer import SELayer

__all__ = [

@@ -25,5 +27,6 @@ __all__ = [
    'ClsDataPreprocessor', 'Mixup', 'CutMix', 'ResizeMix', 'BEiTAttention',
    'LayerScale', 'WindowMSA', 'WindowMSAV2', 'ChannelMultiheadAttention',
    'PositionEncodingFourier', 'LeAttention', 'GRN', 'LayerNorm2d',
    'build_norm_layer'
    'build_norm_layer', 'CrossMultiheadAttention',
    'build_2d_sincos_position_embedding', 'PromptMultiheadAttention'
]
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import itertools
from typing import List, Optional, Union

import numpy as np
import torch

@@ -683,6 +684,7 @@ class BEiTAttention(BaseModule):
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.relative_position_bias_table is not None:
            Wh = self.window_size[0]
            Ww = self.window_size[1]

@@ -879,3 +881,192 @@ class LeAttention(BaseModule):
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x


class CrossMultiheadAttention(BaseModule):
    """Cross attention between queries and a separate set of keys and values.

    This module is different from ``MultiheadAttention``: the attention here
    is computed between the queries and an external set of keys and values.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: float = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(embed_dims, embed_dims, bias=False)
        self.k = nn.Linear(embed_dims, embed_dims, bias=False)
        self.v = nn.Linear(embed_dims, embed_dims, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(embed_dims))
            self.v_bias = nn.Parameter(torch.zeros(embed_dims))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self,
                x: torch.Tensor,
                k: torch.Tensor = None,
                v: torch.Tensor = None) -> torch.Tensor:
        """Forward function."""
        B, N, _ = x.shape

        N_k = k.shape[1]
        N_v = v.shape[1]

        q_bias, k_bias, v_bias = None, None, None
        if self.q_bias is not None:
            q_bias = self.q_bias
            k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            v_bias = self.v_bias

        q = F.linear(
            input=x, weight=self.q.weight, bias=q_bias)  # (B, N_q, dim)
        k = F.linear(
            input=k, weight=self.k.weight, bias=k_bias)  # (B, N_k, dim)
        v = F.linear(input=v, weight=self.v.weight, bias=v_bias)

        q = q.reshape(B, N, 1, self.num_heads, -1).permute(
            2, 0, 3, 1, 4).squeeze(0)  # (B, num_heads, N_q, dim)
        k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(
            2, 0, 3, 1, 4).squeeze(0)  # (B, num_heads, N_k, dim)
        v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(
            2, 0, 3, 1, 4).squeeze(0)  # (B, num_heads, N_v, dim)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B, N_head, N_q, N_k)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
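
``CrossMultiheadAttention`` attends from a set of queries to a separate set
of key/value tokens, as used by the CAE regressor above. A sketch (not part
of the diff; token counts are illustrative):

import torch
from mmpretrain.models.utils import CrossMultiheadAttention

attn = CrossMultiheadAttention(embed_dims=768, num_heads=12, qkv_bias=True)
queries = torch.rand(2, 75, 768)   # e.g. mask tokens in CAE
context = torch.rand(2, 197, 768)  # e.g. visible + mask tokens
out = attn(queries, k=context, v=context)
assert out.shape == (2, 75, 768)   # one output per query token
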

class PromptMultiheadAttention(MultiheadAttention):
    """Prompt Multihead Attention for MILAN.

    This module is specific to the prompt encoder in MILAN. It will not
    update the visible tokens from the encoder.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        use_layer_scale (bool): Whether to apply layer scale to the output
            projection. Defaults to False.
        init_cfg (Union[List[dict], dict], optional): The Config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 input_dims: Optional[int] = None,
                 attn_drop: float = 0,
                 proj_drop: float = 0,
                 dropout_layer: dict = dict(type='Dropout', drop_prob=0.),
                 qkv_bias: bool = True,
                 qk_scale: Optional[float] = None,
                 proj_bias: bool = True,
                 v_shortcut: bool = False,
                 use_layer_scale: bool = False,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(embed_dims, num_heads, input_dims, attn_drop,
                         proj_drop, dropout_layer, qkv_bias, qk_scale,
                         proj_bias, v_shortcut, use_layer_scale, init_cfg)
        # no longer need qkv
        del self.qkv

        # to project the mask tokens
        self.q = nn.Linear(embed_dims, embed_dims, bias=qkv_bias)
        # to project all the tokens
        self.kv = nn.Linear(embed_dims, embed_dims * 2, bias=qkv_bias)

    def forward(self, x: torch.Tensor, visible_tokens: torch.Tensor,
                ids_restore: torch.Tensor) -> torch.Tensor:
        """Forward function for `PromptMultiheadAttention`.

        Args:
            x (torch.Tensor): Mask token features with shape N x L_m x C.
            visible_tokens (torch.Tensor): The visible tokens features from
                encoder with shape N x L_v x C.
            ids_restore (torch.Tensor): The ids of all tokens in the original
                image with shape N x L.

        Returns:
            torch.Tensor: Output features with shape N x L_m x C.
        """
        x_ = torch.cat([visible_tokens[:, 1:, :], x], dim=1)
        assert x_.shape[1] == ids_restore.shape[1]
        x_ = torch.gather(
            x_,
            dim=1,
            index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[-1]))
        x_ = torch.cat([visible_tokens[:, :1, :], x_], dim=1)

        # full sequence shape
        B, _, _ = x_.shape
        q = self.q(x).reshape(B, x.shape[1], self.num_heads,
                              self.head_dims).permute(0, 2, 1, 3)
        kv = self.kv(x_).reshape(B, x_.shape[1], 2, self.num_heads,
                                 self.head_dims).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, x.shape[1], self.embed_dims)
        x = self.proj(x)
        x = self.out_drop(self.gamma1(self.proj_drop(x)))
        return x

@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn

@@ -107,3 +108,57 @@ class PositionEncodingFourier(BaseModule):
        pos = self.proj(pos)

        return pos


def build_2d_sincos_position_embedding(
        patches_resolution: Union[int, Sequence[int]],
        embed_dims: int,
        temperature: Optional[int] = 10000.,
        cls_token: Optional[bool] = False) -> torch.Tensor:
    """Build the 2D sine-cosine position embedding that gives the model the
    position information of the image patches.

    Args:
        patches_resolution (Union[int, Sequence[int]]): The resolution of the
            patch grid (the number of patches per side).
        embed_dims (int): The dimension of the embedding vector.
        temperature (int, optional): The temperature parameter. Defaults to
            10000.
        cls_token (bool, optional): Whether to concatenate a class token
            slot. Defaults to False.

    Returns:
        torch.Tensor: The position embedding vector.
    """

    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)

    h, w = patches_resolution
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])

    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]

    if cls_token:
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb
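
The sin-cos table is deterministic, so its shape and cls-token slot can be
checked directly (a sketch, not part of the diff):

import torch
from mmpretrain.models.utils import build_2d_sincos_position_embedding

pos_embed = build_2d_sincos_position_embedding(14, 768, cls_token=True)
assert pos_embed.shape == (1, 14 * 14 + 1, 768)
assert torch.all(pos_embed[:, 0] == 0)  # the cls token slot is all zeros
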
@@ -371,3 +371,33 @@ def test_seesaw_loss():
    fake_label = torch.Tensor([0]).long()
    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch.tensor(200.) + torch.tensor(100.).log())


def test_reconstruction_loss():

    # test L2 loss
    loss_config = dict(type='PixelReconstructionLoss', criterion='L2')
    loss = build_loss(loss_config)

    fake_pred = torch.rand((2, 196, 768))
    fake_target = torch.rand((2, 196, 768))
    fake_mask = torch.ones((2, 196))
    loss_value = loss(fake_pred, fake_target, fake_mask)

    assert isinstance(loss_value.item(), float)

    # test L1 loss
    loss_config = dict(
        type='PixelReconstructionLoss', criterion='L1', channel=3)
    loss = build_loss(loss_config)

    fake_pred = torch.rand((2, 3, 192, 192))
    fake_target = torch.rand((2, 3, 192, 192))
    fake_mask = torch.ones((2, 1, 192, 192))
    loss_value = loss(fake_pred, fake_target, fake_mask)

    assert isinstance(loss_value.item(), float)

    # an unsupported criterion raises NotImplementedError
    with pytest.raises(NotImplementedError):
        loss_config = dict(type='PixelReconstructionLoss', criterion='L3')
        loss = build_loss(loss_config)
@@ -4,7 +4,7 @@ import torch

from mmpretrain.models.necks import (GeneralizedMeanPooling,
                                     GlobalAveragePooling, HRFuseScales,
                                     LinearReduction)
                                     LinearNeck)


def test_gap_neck():

@@ -107,8 +107,9 @@ def test_hr_fuse_scales():

def test_linear_reduction():
    # test linear neck without `act_cfg` and `norm_cfg`
    neck = LinearReduction(10, 5, None, None)
    neck = LinearNeck(10, 5, 0, None, None)
    neck.eval()
    assert isinstance(neck.gap, torch.nn.Identity)
    assert isinstance(neck.act, torch.nn.Identity)
    assert isinstance(neck.norm, torch.nn.Identity)


@@ -125,12 +126,36 @@ def test_linear_reduction():
    # batch_size, out_features
    assert output[-1].shape == (1, 5)

    # in_channels=10, out_channels=5, gap_dim=1
    neck = LinearNeck(10, 5, 1, None, None)
    fake_input = torch.rand(1, 10, 10)
    output = neck(fake_input)
    # batch_size, out_features
    assert output[-1].shape == (1, 5)

    # in_channels=10, out_channels=5, gap_dim=2
    neck = LinearNeck(10, 5, 2, None, None)
    fake_input = torch.rand(1, 10, 10, 10)
    output = neck(fake_input)
    # batch_size, out_features
    assert output[-1].shape == (1, 5)

    # in_channels=10, out_channels=5, gap_dim=3
    neck = LinearNeck(10, 5, 3, None, None)
    fake_input = torch.rand(1, 10, 10, 10, 10)
    output = neck(fake_input)
    # batch_size, out_features
    assert output[-1].shape == (1, 5)

    # an invalid gap_dim raises an AssertionError
    with pytest.raises(AssertionError):
        neck = LinearNeck(10, 5, None, None, None)

    # test linear neck with `init_cfg`
    neck = LinearReduction(
        10, 5, init_cfg=dict(type='Xavier', layer=['Linear']))
    neck = LinearNeck(10, 5, init_cfg=dict(type='Xavier', layer=['Linear']))

    # test linear neck with `act_cfg` and `norm_cfg`
    neck = LinearReduction(
    neck = LinearNeck(
        10, 5, act_cfg=dict(type='ReLU'), norm_cfg=dict(type='BN1d'))
    neck.eval()