# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zejiangh/MILAN
from typing import List, Optional, Tuple

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.models.utils.clip_generator_helper import \
    ResidualAttentionBlock
from mmpretrain.registry import MODELS


@MODELS.register_module()
class CLIPTransformer(nn.Module):
    """Transformer.

    Both visual and text branches of the CLIP model use this transformer.

    Args:
        width (int): The feature dimension.
        layers (int): The number of layers.
        heads (int): The number of attention heads.
        attn_mask (torch.Tensor, optional): The attention mask.
            Defaults to None.
    """

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: Optional[torch.Tensor] = None) -> None:
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        # The first ``layers - 1`` blocks only return features; the last
        # block is built with ``return_attention=True`` so that it also
        # returns its attention map.
        for _ in range(layers - 1):
            self.resblocks.append(
                ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks.append(
            ResidualAttentionBlock(
                width, heads, attn_mask, return_attention=True))

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward function.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: The
            final features, the attention map of the last block and the
            intermediate features of every block.
        """
        z = []
        for idx, blk in enumerate(self.resblocks):
            if idx < self.layers - 1:
                x = blk(x)
                # Collect intermediate features in (B, N, C) order.
                z.append(x.permute(1, 0, 2))
            else:
                # The last block also returns its attention map.
                x, attention = blk(x)
                z.append(x.permute(1, 0, 2))
        return x, attention, z
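

# A minimal usage sketch, assuming ``ResidualAttentionBlock`` consumes and
# produces tensors in (num_tokens, batch, width) order, which is what the
# ``permute(1, 0, 2)`` calls above convert back to (batch, tokens, width):
#
#     transformer = CLIPTransformer(width=768, layers=12, heads=12)
#     tokens = torch.randn(197, 2, 768)  # e.g. 196 patches + 1 cls token
#     feats, attn, hidden = transformer(tokens)
#     assert len(hidden) == 12  # one (B, N, C) tensor per block

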
@MODELS.register_module()
class CLIPProjection(BaseModule):
    """Neck with CLIP Projection.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Scale the random initialization by ``in_channels**-0.5``, as in
        # the original CLIP projection layer.
        scale = in_channels**-0.5
        self.proj = nn.Parameter(scale *
                                 torch.randn(in_channels, out_channels))

    def forward(self, inputs: Tuple) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            inputs (Tuple): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last
                stage will be used.

        Returns:
            Tuple[torch.Tensor]: A tuple containing the reduced features.
        """
        if isinstance(inputs, tuple):
            # Only project the feature of the last stage.
            inputs = inputs[-1]
            out = inputs @ self.proj
        elif isinstance(inputs, torch.Tensor):
            out = inputs @ self.proj
        else:
            raise TypeError(
                '`CLIPProjection` neck inputs should be tuple or '
                'torch.Tensor')
        return (out, )
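

# A minimal usage sketch, assuming a (batch, in_channels) feature tensor
# such as the class-token feature of a ViT backbone:
#
#     neck = CLIPProjection(in_channels=768, out_channels=512)
#     feats = (torch.randn(2, 768), )
#     (projected, ) = neck(feats)  # projected has shape (2, 512)
#
# Both classes are registered in ``MODELS``, so they can also be built
# from a config, e.g.
# ``MODELS.build(dict(type='CLIPProjection', in_channels=768,
# out_channels=512))``.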