From badc5ddfc8a5c263ea12f20bd0018e62c6e6466e Mon Sep 17 00:00:00 2001
From: Veronikkkka
Date: Sun, 9 Mar 2025 12:25:56 +0000
Subject: [PATCH] Add patch embedding for the adapter (ada) tokens

---
 dinov2/layers/patch_embed.py        |  5 +++-
 dinov2/models/help.py               | 11 +++----
 dinov2/models/vision_transformer.py | 46 ++++++++++++++++++++---------
 3 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/dinov2/layers/patch_embed.py b/dinov2/layers/patch_embed.py
index 8b7c080..8d12ca9 100644
--- a/dinov2/layers/patch_embed.py
+++ b/dinov2/layers/patch_embed.py
@@ -68,7 +68,10 @@ class PatchEmbed(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         _, _, H, W = x.shape
         patch_H, patch_W = self.patch_size
-
+        # Tolerate inputs that are one pixel larger than a multiple of the patch size.
+        if H % patch_H != 0:
+            H -= 1
+            W -= 1
         assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

diff --git a/dinov2/models/help.py b/dinov2/models/help.py
index 987f7c3..d22ef2d 100644
--- a/dinov2/models/help.py
+++ b/dinov2/models/help.py
@@ -25,21 +25,21 @@ class Merge_block(BaseModule):
         self.ada_c = ada_c
         # 784 - embedded dim + adapter_c
         self.embeded_dim = 768
-        self.fc_1 = nn.Linear(self.embeded_dim + ada_c, mid_c)
+        self.fc_1 = nn.Linear(self.embeded_dim * 2, mid_c)
         self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
         self.return_ada = return_ada
         if self.return_ada:
-            self.conv_3 = nn.Conv1d(mid_c, ada_c * 2, kernel_size=1)  # 1D Conv instead of 3x3
+            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1)  # 1D Conv instead of 3x3
         else:
             self.conv_3 = None

     def forward(self, fea, adapter, ratio=1.0):
         res = fea
         # print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c)
-
+        # adapter is already projected to the ViT embedding dim, so it concatenates directly with fea.
         fea = torch.cat([fea, adapter], dim=-1)  # (B, seq_len, fea_c + ada_c)
-
+        # fea is now (B, seq_len, 2 * embed_dim), matching fc_1's input size.
         B, seq_len, C = fea.shape
         fea = fea.view(B * seq_len, C)
         fea = self.fc_1(fea)
@@ -143,8 +143,9 @@ class Model_level_Adapeter(BaseModule):
         else:
             adapter = torch.cat([self.conv_1(IMGS[0]), self.conv_2(IMGS[1]), self.conv_3(IMGS[2])], dim=1)
-
+        # uni_conv fuses the concatenated per-input features into a single adapter map.
         adapter = self.uni_conv(adapter)
+        # adapter is later patch-embedded and merged into the ViT tokens (see vision_transformer.py).
         # adapter = self.res_1(adapter)
         # adapter = self.res_2(adapter)
         return adapter
diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index 2995900..e347380 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -16,7 +16,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 from torch.nn.init import trunc_normal_
-
+import torch.nn.functional as F
 from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
 from dinov2.models.help import Merge_block, Model_level_Adapeter
 from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter
@@ -123,6 +123,7 @@ class DinoVisionTransformer(nn.Module):
         self.merge_ratio = merge_ratio

         self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        self.patch_embed_for_model_adapter = embed_layer(img_size=56, patch_size=4, in_chans=16, embed_dim=embed_dim)
         num_patches = self.patch_embed.num_patches

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -206,8 +207,9 @@ class DinoVisionTransformer(nn.Module):
             self.merge_1 = Merge_block(fea_c=fea_c_s[0], ada_c=ada_c_s[0], mid_c=mid_c_s[0], return_ada=True)
             self.merge_2 = Merge_block(fea_c=fea_c_s[1], ada_c=ada_c_s[1], mid_c=mid_c_s[1], return_ada=True)
             self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
+
             self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]
-
+            # merge_blocks is indexed by block depth when adapter tokens are fused in the forward passes.
             self.first_linear_proj = nn.Linear(625, 50)
             self.second_linear_proj = nn.Linear(3136, 257)

@@ -279,11 +281,30 @@ class DinoVisionTransformer(nn.Module):
             x = x_raw[-1]

         x = self.patch_embed(x)
+
+        # Patch-embed the adapter features, then resample them to the same token count as x.
+        ada = self.patch_embed_for_model_adapter(ada)
+        ada_channels_first = ada.transpose(1, 2)  # (B, embed_dim, num_adapter_tokens)
+
+        ada = F.interpolate(
+            ada_channels_first,
+            size=x.shape[1],
+            mode='linear',
+            align_corners=False
+        )
+        ada = ada.transpose(1, 2)
+
         if masks is not None:
             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+            ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada)

         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1)
+        ada = ada + self.interpolate_pos_encoding(ada, w, h)
+
         x = x + self.interpolate_pos_encoding(x, w, h)
+        # The adapter tokens receive the cls token and positional encoding above, mirroring x.
+

         if self.register_tokens is not None:
             x = torch.cat(
@@ -294,19 +315,9 @@ class DinoVisionTransformer(nn.Module):
                 ),
                 dim=1,
             )
+        # ada already matches x's token layout, so no extra projection of ada is needed before returning.

-        ada = ada.reshape(ada.shape[0], ada.shape[1], -1)
-        batch_size, channels, features = ada.shape
-        target_seq_len = x.shape[1]
-        # print("Features: ", features, target_seq_len)
-        if x.shape[1] == 50:
-            linear_proj = self.first_linear_proj
-        elif x.shape[1] == 257:
-            linear_proj = self.second_linear_proj
-        # linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
-        ada = linear_proj(ada).permute(0, 2, 1)
-
-
+
         return x, ada

     def forward_features_list(self, x_list, masks_list):
@@ -314,8 +325,10 @@ class DinoVisionTransformer(nn.Module):
         x_s = []
         ada_list = []
+        # Build token sequences and matching adapter tokens for every crop in the list.
         for x, masks in zip(x_list, masks_list):
            x_, ada = self.prepare_tokens_with_masks(x, masks)
+            # x_ and ada have the same sequence length after prepare_tokens_with_masks.
            x_s.append(x_)
            ada_list.append(ada)

@@ -323,9 +336,13 @@ class DinoVisionTransformer(nn.Module):
        x = x_s

        for i, blk in enumerate(self.blocks):
+            # x is a list of per-crop token tensors; ada_list holds the matching
+            # adapter tokens returned by prepare_tokens_with_masks.
            x = blk(x)
+
            if self.w_lut and ada is not None and i < len(self.merge_blocks):
+                # Fuse adapter tokens into each crop's tokens for the first merge stages.
                x_ada_pairs = [self.merge_blocks[i](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)]
                x, ada_list = map(list, zip(*x_ada_pairs))
@@ -353,6 +370,7 @@ class DinoVisionTransformer(nn.Module):
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if self.w_lut and ada is not None and i < len(self.merge_blocks):
+                # Fuse adapter tokens into the image tokens for the first merge stages.
                x, ada = self.merge_blocks[i](x, ada, ratio=self.merge_ratio)
        x_norm = self.norm(x)
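
Note: the core of this change is the adapter-token alignment added in prepare_tokens_with_masks: the adapter map is run through its own patch embedding and then linearly resampled along the token axis to match the number of image tokens. Below is a minimal standalone sketch of that resampling step; the helper name align_adapter_tokens and the example sizes (image patch size 14, batch size 2) are illustrative assumptions, not part of the diff.

    import torch
    import torch.nn.functional as F

    def align_adapter_tokens(ada_tokens: torch.Tensor, num_x_tokens: int) -> torch.Tensor:
        """Resample adapter tokens (B, N_ada, C) to (B, num_x_tokens, C).

        Mirrors the F.interpolate call in prepare_tokens_with_masks: 1D linear
        interpolation along the token axis, channels left untouched.
        """
        ada = ada_tokens.transpose(1, 2)  # (B, C, N_ada): channels-first for F.interpolate
        ada = F.interpolate(ada, size=num_x_tokens, mode="linear", align_corners=False)
        return ada.transpose(1, 2)  # back to (B, num_x_tokens, C)

    # With the sizes hard-coded in the diff, a 56x56, 16-channel adapter map and
    # patch size 4 give 14 * 14 = 196 adapter tokens; a 224x224 image with an
    # assumed patch size of 14 gives 256 image tokens (before the cls token).
    ada_tokens = torch.randn(2, 196, 768)
    aligned = align_adapter_tokens(ada_tokens, num_x_tokens=256)
    assert aligned.shape == (2, 256, 768)

The cls token and positional encoding are added to the resampled adapter tokens afterwards, exactly as the patch does for x, so both sequences stay index-aligned when the Merge_block modules concatenate them along the channel dimension.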