mmclassification/mmpretrain/models/multimodal/clip/clip_transformer.py

# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zejiangh/MILAN
from typing import Optional, Tuple

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.models.utils.clip_generator_helper import \
    ResidualAttentionBlock
from mmpretrain.registry import MODELS


@MODELS.register_module()
class CLIPTransformer(nn.Module):
    """Transformer.

    Both visual and text branches use this transformer.

    Args:
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()

        # All blocks share the same configuration; only the last block is
        # built with ``return_attention=True`` so that its attention map can
        # be returned from ``forward``.
        for _ in range(layers - 1):
            self.resblocks.append(
                ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks.append(
            ResidualAttentionBlock(
                width, heads, attn_mask, return_attention=True))
    def forward(
            self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward function."""
        z = []
        for idx, blk in enumerate(self.resblocks):
            if idx < self.layers - 1:
                # Intermediate blocks return only the hidden states.
                x = blk(x)
                z.append(x.permute(1, 0, 2))
            else:
                # The last block also returns its attention map.
                x, attention = blk(x)
                z.append(x.permute(1, 0, 2))
        return x, attention, z
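
# A minimal usage sketch (illustrative, not part of the upstream file). It
# assumes the (seq_len, batch, width) input layout implied by the
# ``permute(1, 0, 2)`` calls above; the width/layers/heads values are
# placeholder assumptions, not required settings.
#
#     transformer = CLIPTransformer(width=768, layers=12, heads=12)
#     x = torch.randn(197, 2, 768)            # (seq_len, batch, width)
#     x, attention, z = transformer(x)
#     # x: hidden states of the last block, (seq_len, batch, width)
#     # attention: attention map returned by the last block
#     # z: list of per-block outputs permuted to (batch, seq_len, width)
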
@MODELS.register_module()
class CLIPProjection(BaseModule):
    """Neck with CLIP Projection.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 init_cfg: Optional[dict] = None):
        super(CLIPProjection, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        scale = in_channels**-0.5
        self.proj = nn.Parameter(scale *
                                 torch.randn(in_channels, out_channels))
    def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            inputs (Tuple): The features extracted from the backbone.
                Multiple stage inputs are acceptable, but only the last
                stage is used.

        Returns:
            Tuple[torch.Tensor]: A single-element tuple of projected
            features.
        """
        if isinstance(inputs, tuple):
            # Only the last stage of a multi-stage backbone output is used.
            inputs = inputs[-1]
            out = inputs @ self.proj
        elif isinstance(inputs, torch.Tensor):
            out = inputs @ self.proj
        else:
            raise TypeError('`CLIPProjection` neck inputs should be tuple '
                            'or torch.Tensor')
        return (out, )
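
# A minimal usage sketch (illustrative, not part of the upstream file). The
# channel sizes are placeholder assumptions; the neck projects the last-stage
# feature with a single learned matrix of shape (in_channels, out_channels).
#
#     proj = CLIPProjection(in_channels=768, out_channels=512)
#     feats = (torch.randn(2, 768), )         # tuple of backbone outputs
#     (out, ) = proj(feats)                   # out: shape (2, 512)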