patch embedding for ada

pull/511/head
Veronikkkka 2025-03-09 12:25:56 +00:00
parent 8cb907ab93
commit badc5ddfc8
3 changed files with 42 additions and 20 deletions

View File

@@ -68,7 +68,10 @@ class PatchEmbed(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size
        # Round H and W down to the nearest multiple of the patch size so the
        # divisibility checks below pass; the strided projection conv drops the
        # remainder rows/columns anyway.
        if H % patch_H != 0 or W % patch_W != 0:
            H -= H % patch_H
            W -= W % patch_W
        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
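
The new check rounds H and W down to the nearest multiple of the patch size instead of failing the divisibility asserts. A minimal, self-contained sketch of that arithmetic, with a plain strided Conv2d standing in for PatchEmbed's projection (sizes are illustrative):

    import torch
    import torch.nn as nn

    patch = 14
    proj = nn.Conv2d(3, 768, kernel_size=patch, stride=patch)  # stand-in for PatchEmbed.proj

    x = torch.randn(1, 3, 225, 225)        # 225 = 16 * 14 + 1, not a multiple of the patch size
    H, W = x.shape[-2:]
    H, W = H - H % patch, W - W % patch    # 224, 224: what the new check reduces H and W to
    tokens = proj(x).flatten(2).transpose(1, 2)
    print(tokens.shape)                    # torch.Size([1, 256, 768]); the strided conv drops the extra row/column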

View File

@@ -25,21 +25,21 @@ class Merge_block(BaseModule):
        self.ada_c = ada_c
        # 784 - embedded dim + adapter_c
        self.embeded_dim = 768
        self.fc_1 = nn.Linear(self.embeded_dim + ada_c, mid_c)
        self.fc_1 = nn.Linear(self.embeded_dim * 2, mid_c)
        self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
        self.return_ada = return_ada
        if self.return_ada:
            self.conv_3 = nn.Conv1d(mid_c, ada_c * 2, kernel_size=1)  # 1D Conv instead of 3x3
            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1)  # 1D Conv instead of 3x3
        else:
            self.conv_3 = None

    def forward(self, fea, adapter, ratio=1.0):
        res = fea
        # print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c)
        # print("before concatenation: ", fea.shape, adapter.shape)
        fea = torch.cat([fea, adapter], dim=-1)  # (B, seq_len, fea_c + ada_c)
        # print("after concatenation: ", fea.shape, adapter.shape)
        B, seq_len, C = fea.shape
        fea = fea.view(B * seq_len, C)
        fea = self.fc_1(fea)
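
With this change fc_1 expects the adapter stream to already have the ViT embedding width (768), so the concatenated input is 2 * 768 rather than 768 + ada_c, and conv_3 now emits embed_dim channels. A shape-only sketch under that assumption (batch size, sequence length and mid_c are made up):

    import torch
    import torch.nn as nn

    B, seq_len, embed_dim, mid_c = 2, 257, 768, 384
    fc_1 = nn.Linear(embed_dim * 2, mid_c)                    # new fc_1 input width
    conv_3 = nn.Conv1d(mid_c, embed_dim, kernel_size=1)       # new conv_3 output width

    fea = torch.randn(B, seq_len, embed_dim)
    adapter = torch.randn(B, seq_len, embed_dim)              # adapter tokens already at embed_dim
    fused = fc_1(torch.cat([fea, adapter], dim=-1))           # (B, seq_len, mid_c)
    ada_out = conv_3(fused.transpose(1, 2)).transpose(1, 2)   # (B, seq_len, embed_dim)
    print(fused.shape, ada_out.shape)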
@@ -143,8 +143,9 @@ class Model_level_Adapeter(BaseModule):
        else:
            adapter = torch.cat([self.conv_1(IMGS[0]), self.conv_2(IMGS[1]), self.conv_3(IMGS[2])], dim=1)
        # print("Adapter:", adapter.shape)
        adapter = self.uni_conv(adapter)
        # print("Adapter:", adapter.shape)
        # adapter = self.res_1(adapter)
        # adapter = self.res_2(adapter)
        return adapter
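
For context, the hunk above builds the model-level adapter by concatenating per-input conv features along the channel axis and unifying them with uni_conv. A rough sketch with placeholder channel counts; the (B, 16, 56, 56) output is only a guess based on the img_size=56, in_chans=16 patch embedding added in the next file:

    import torch
    import torch.nn as nn

    convs = nn.ModuleList(nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1) for _ in range(3))
    uni_conv = nn.Conv2d(3 * 8, 16, kernel_size=3, stride=2, padding=1)   # placeholder for self.uni_conv

    IMGS = [torch.randn(2, 3, 224, 224) for _ in range(3)]
    adapter = torch.cat([conv(img) for conv, img in zip(convs, IMGS)], dim=1)  # (2, 24, 112, 112)
    adapter = uni_conv(adapter)                                                # (2, 16, 56, 56)
    print(adapter.shape)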

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
import torch.nn.functional as F
from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from dinov2.models.help import Merge_block, Model_level_Adapeter
from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter
@@ -123,6 +123,7 @@ class DinoVisionTransformer(nn.Module):
        self.merge_ratio = merge_ratio
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.patch_embed_for_model_adapter = embed_layer(img_size=56, patch_size=4, in_chans=16, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
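
The new patch_embed_for_model_adapter turns a 16-channel 56x56 adapter map into embed_dim tokens. Token-count arithmetic, sketched with a plain strided conv standing in for embed_layer:

    import torch
    import torch.nn as nn

    adapter_embed = nn.Conv2d(16, 768, kernel_size=4, stride=4)   # img_size=56, patch_size=4, in_chans=16
    ada = torch.randn(2, 16, 56, 56)
    tokens = adapter_embed(ada).flatten(2).transpose(1, 2)
    print(tokens.shape)   # torch.Size([2, 196, 768]); (56 // 4) ** 2 = 196 adapter tokens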
@@ -206,8 +207,9 @@ class DinoVisionTransformer(nn.Module):
        self.merge_1 = Merge_block(fea_c=fea_c_s[0], ada_c=ada_c_s[0], mid_c=mid_c_s[0], return_ada=True)
        self.merge_2 = Merge_block(fea_c=fea_c_s[1], ada_c=ada_c_s[1], mid_c=mid_c_s[1], return_ada=True)
        self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
        self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]
        print(self.merge_blocks)
        self.first_linear_proj = nn.Linear(625, 50)
        self.second_linear_proj = nn.Linear(3136, 257)
@@ -279,11 +281,30 @@ class DinoVisionTransformer(nn.Module):
        x = x_raw[-1]
        x = self.patch_embed(x)
        # print("ada.shape ", ada.shape, x.shape)
        ada = self.patch_embed_for_model_adapter(ada)
        tensor2_reshaped = ada.transpose(1, 2)  # [32, 768, 196]
        ada = F.interpolate(
            tensor2_reshaped,
            size=x.shape[1],
            mode='linear',
            align_corners=False,
        )
        ada = ada.transpose(1, 2)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1)
        ada = ada + self.interpolate_pos_encoding(ada, w, h)
        x = x + self.interpolate_pos_encoding(x, w, h)
        # ada = ada + self.interpolate_pos_encoding(ada, w, h)
        if self.register_tokens is not None:
            x = torch.cat(
@@ -294,19 +315,9 @@ class DinoVisionTransformer(nn.Module):
                ),
                dim=1,
            )
        # print("x.shape", x.shape)
        ada = ada.reshape(ada.shape[0], ada.shape[1], -1)
        batch_size, channels, features = ada.shape
        target_seq_len = x.shape[1]
        # print("Features: ", features, target_seq_len)
        if x.shape[1] == 50:
            linear_proj = self.first_linear_proj
        elif x.shape[1] == 257:
            linear_proj = self.second_linear_proj
        # linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
        ada = linear_proj(ada).permute(0, 2, 1)
        return x, ada
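
The new path in prepare_tokens_with_masks patch-embeds the adapter map and resamples it along the token axis with F.interpolate so that it matches x.shape[1], rather than relying on the fixed Linear(625, 50) / Linear(3136, 257) projections. A self-contained sketch of that alignment; shapes beyond those visible in the diff are assumptions:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    embed_dim = 768
    adapter_embed = nn.Conv2d(16, embed_dim, kernel_size=4, stride=4)   # stand-in for patch_embed_for_model_adapter

    x = torch.randn(2, 256, embed_dim)                    # image patch tokens, e.g. (224 // 14) ** 2
    ada = torch.randn(2, 16, 56, 56)                      # model-level adapter feature map
    ada = adapter_embed(ada).flatten(2).transpose(1, 2)   # (2, 196, 768) adapter tokens
    ada = F.interpolate(ada.transpose(1, 2), size=x.shape[1], mode='linear', align_corners=False)
    ada = ada.transpose(1, 2)                             # (2, 256, 768), same sequence length as x
    print(ada.shape)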
    def forward_features_list(self, x_list, masks_list):
@@ -314,8 +325,10 @@ class DinoVisionTransformer(nn.Module):
        x_s = []
        ada_list = []
        # print("x_list", [i.shape for i in x_list])
        for x, masks in zip(x_list, masks_list):
            x_, ada = self.prepare_tokens_with_masks(x, masks)
            # print(x.shape, ada.shape, self.model_adapter)
            x_s.append(x_)
            ada_list.append(ada)
@@ -323,9 +336,13 @@ class DinoVisionTransformer(nn.Module):
        x = x_s
        for i, blk in enumerate(self.blocks):
            # print([j.shape for j in x])
            # print(ada.shape)
            x = blk(x)
            if self.w_lut and ada is not None and i < len(self.merge_blocks):
                # print("HERE 22")
                x_ada_pairs = [self.merge_blocks[i](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)]
                x, ada_list = map(list, zip(*x_ada_pairs))
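
In the list forward path, each of the first len(self.merge_blocks) transformer blocks is followed by a merge of every (tokens, adapter) pair, and zip(*...) unpacks the resulting pairs back into two parallel lists. A toy illustration of that unpacking, with fake_merge as a stand-in for Merge_block:

    import torch

    def fake_merge(x_i, ada_i, ratio=1.0):
        # stand-in for Merge_block: returns an updated (tokens, adapter) pair
        return x_i + ratio * ada_i, ada_i

    x = [torch.randn(2, 257, 768) for _ in range(2)]
    ada_list = [torch.randn(2, 257, 768) for _ in range(2)]

    x_ada_pairs = [fake_merge(x_i, ada_i, ratio=1.0) for x_i, ada_i in zip(x, ada_list)]
    x, ada_list = map(list, zip(*x_ada_pairs))   # back to two parallel lists
    print(len(x), x[0].shape, len(ada_list), ada_list[0].shape)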
@@ -353,6 +370,7 @@ class DinoVisionTransformer(nn.Module):
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if self.w_lut and ada is not None and i < len(self.merge_blocks):
                # print("HERE 11", x.shape, ada.shape)
                x, ada = self.merge_blocks[i](x, ada, ratio=self.merge_ratio)
        x_norm = self.norm(x)