patch embedding for ada
parent 8cb907ab93
commit badc5ddfc8
@@ -68,7 +68,10 @@ class PatchEmbed(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        # print("H, w,m patch", H, W, patch_H, patch_W)
        if(H%patch_H !=0 ):
            H -= 1
            W -= 1
        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
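Side note: the H -= 1 / W -= 1 path only covers inputs that overshoot the patch grid by exactly one pixel. A minimal, hypothetical sketch of a more general crop to the nearest patch multiple (not part of this commit; the function name is chosen for illustration):

import torch

def crop_to_patch_multiple(x: torch.Tensor, patch_H: int, patch_W: int) -> torch.Tensor:
    # Crop spatial dims down to the nearest multiple of the patch size so the
    # patch grid always divides evenly, whatever the input resolution.
    _, _, H, W = x.shape
    return x[:, :, : H - H % patch_H, : W - W % patch_W]

# e.g. 225x225 -> 224x224 for 14x14 patches:
# crop_to_patch_multiple(torch.randn(1, 3, 225, 225), 14, 14).shape == (1, 3, 224, 224)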
@@ -25,21 +25,21 @@ class Merge_block(BaseModule):
        self.ada_c = ada_c
        # 784 - embedded dim + adapter_c
        self.embeded_dim = 768
        self.fc_1 = nn.Linear(self.embeded_dim + ada_c, mid_c)
        self.fc_1 = nn.Linear(self.embeded_dim*2, mid_c)
        self.fc_2 = nn.Linear(mid_c, self.embeded_dim)
        self.return_ada = return_ada

        if self.return_ada:
            self.conv_3 = nn.Conv1d(mid_c, ada_c * 2, kernel_size=1)  # 1D Conv instead of 3x3
            self.conv_3 = nn.Conv1d(mid_c, self.embeded_dim, kernel_size=1)  # 1D Conv instead of 3x3
        else:
            self.conv_3 = None

    def forward(self, fea, adapter, ratio=1.0):
        res = fea
        # print("Before concatenation: ", fea.shape, adapter.shape, self.fea_c, self.ada_c)

        # print("before concatenation: ", fea.shape, adapter.shape)
        fea = torch.cat([fea, adapter], dim=-1)  # (B, seq_len, fea_c + ada_c)

        # print("after concatenation: ", fea.shape, adapter.shape)
        B, seq_len, C = fea.shape
        fea = fea.view(B * seq_len, C)
        fea = self.fc_1(fea)
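For context, switching fc_1 to self.embeded_dim*2 means the adapter stream is now expected to arrive as 768-dim tokens, the same width as the ViT features, so the concatenated vector is 1536 wide. A rough standalone sketch of that fusion step (mid_c and the residual add are assumptions, not copied from Merge_block):

import torch
import torch.nn as nn

embed_dim, mid_c = 768, 1024                 # mid_c is an assumed value
fc_1 = nn.Linear(embed_dim * 2, mid_c)
fc_2 = nn.Linear(mid_c, embed_dim)

fea = torch.randn(2, 257, embed_dim)         # ViT tokens (B, seq_len, C)
adapter = torch.randn(2, 257, embed_dim)     # adapter tokens, same width after the new patch embed

fused = torch.cat([fea, adapter], dim=-1)    # (B, seq_len, 2 * embed_dim)
out = fc_2(fc_1(fused)) + fea                # project down and add back onto the ViT tokens
# out.shape == (2, 257, 768)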
@@ -143,8 +143,9 @@ class Model_level_Adapeter(BaseModule):

        else:
            adapter = torch.cat([self.conv_1(IMGS[0]), self.conv_2(IMGS[1]), self.conv_3(IMGS[2])], dim=1)

        # print("Adapter:", adapter.shape)
        adapter = self.uni_conv(adapter)
        # print("Adapter:", adapter.shape)
        # adapter = self.res_1(adapter)
        # adapter = self.res_2(adapter)
        return adapter
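This hunk follows a concat-then-unify pattern: one conv stem per input image, channel-wise concatenation, then a single conv down to the adapter width. A standalone sketch with placeholder channel counts (the repo's actual values may differ):

import torch
import torch.nn as nn

conv_1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)   # one stem per input image
conv_2 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv_3 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
uni_conv = nn.Conv2d(8 * 3, 16, kernel_size=1)       # unify the concatenated channels

IMGS = [torch.randn(2, 3, 56, 56) for _ in range(3)]
adapter = torch.cat([conv_1(IMGS[0]), conv_2(IMGS[1]), conv_3(IMGS[2])], dim=1)  # (2, 24, 56, 56)
adapter = uni_conv(adapter)                                                      # (2, 16, 56, 56)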
@@ -16,7 +16,7 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

import torch.nn.functional as F
from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
from dinov2.models.help import Merge_block, Model_level_Adapeter
from dinov2.models.help import VitInputLevelAdapter as Input_level_Adapeter
@@ -123,6 +123,7 @@ class DinoVisionTransformer(nn.Module):
        self.merge_ratio = merge_ratio

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.patch_embed_for_model_adapter = embed_layer(img_size=56, patch_size=4, in_chans=16, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
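The new patch_embed_for_model_adapter tokenizes a 56x56, 16-channel adapter map with 4x4 patches, i.e. (56/4)^2 = 196 tokens of width embed_dim. A quick shape check with a strided conv standing in for PatchEmbed:

import torch
import torch.nn as nn

embed_dim = 768
proj = nn.Conv2d(16, embed_dim, kernel_size=4, stride=4)   # 4x4 patches over 16 channels

ada = torch.randn(2, 16, 56, 56)                 # adapter map (B, 16, 56, 56)
tokens = proj(ada).flatten(2).transpose(1, 2)    # (2, 196, 768): (56 // 4) ** 2 = 196 tokens
assert tokens.shape == (2, 196, embed_dim)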
@@ -206,8 +207,9 @@ class DinoVisionTransformer(nn.Module):
        self.merge_1 = Merge_block(fea_c=fea_c_s[0], ada_c=ada_c_s[0], mid_c=mid_c_s[0], return_ada=True)
        self.merge_2 = Merge_block(fea_c=fea_c_s[1], ada_c=ada_c_s[1], mid_c=mid_c_s[1], return_ada=True)
        self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)

        self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]

        print(self.merge_blocks)
        self.first_linear_proj = nn.Linear(625, 50)
        self.second_linear_proj = nn.Linear(3136, 257)
@@ -279,11 +281,30 @@ class DinoVisionTransformer(nn.Module):
        x = x_raw[-1]

        x = self.patch_embed(x)

        # print("ada.shape ", ada.shape, x.shape)
        ada = self.patch_embed_for_model_adapter(ada)
        tensor2_reshaped = ada.transpose(1, 2)  # [32, 768, 196]

        ada = F.interpolate(
            tensor2_reshaped,
            size=x.shape[1],
            mode='linear',
            align_corners=False
        )
        ada = ada.transpose(1, 2)

        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            ada = torch.where(masks.unsqueeze(-1), self.mask_token.to(ada.dtype).unsqueeze(0), ada)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        ada = torch.cat((self.cls_token.expand(ada.shape[0], -1, -1), ada), dim=1)
        ada = ada + self.interpolate_pos_encoding(ada, w, h)

        x = x + self.interpolate_pos_encoding(x, w, h)
        # ada = ada + self.interpolate_pos_encoding(ada, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
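F.interpolate with mode='linear' expects an (N, C, L) tensor, so the adapter tokens are transposed to put the sequence axis last, resampled to x's token count, and transposed back. A self-contained shape check of that step (the 196 -> 256 sizes are illustrative):

import torch
import torch.nn.functional as F

ada = torch.randn(2, 196, 768)                      # adapter tokens (B, seq_len, C)
target_len = 256                                    # patch-token count of x, assumed

ada = F.interpolate(ada.transpose(1, 2),            # (B, C, seq_len)
                    size=target_len, mode='linear', align_corners=False)
ada = ada.transpose(1, 2)                           # (B, target_len, C)
assert ada.shape == (2, 256, 768)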
@@ -294,19 +315,9 @@ class DinoVisionTransformer(nn.Module):
                ),
                dim=1,
            )
        # print("x.shape",x.shape)

        ada = ada.reshape(ada.shape[0], ada.shape[1], -1)
        batch_size, channels, features = ada.shape
        target_seq_len = x.shape[1]
        # print("Features: ", features, target_seq_len)
        if x.shape[1] == 50:
            linear_proj = self.first_linear_proj
        elif x.shape[1] == 257:
            linear_proj = self.second_linear_proj
        # linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
        ada = linear_proj(ada).permute(0, 2, 1)

        return x, ada

    def forward_features_list(self, x_list, masks_list):
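Here the fixed linear layers resample the adapter along the sequence axis: ada is kept as (B, C, features), the last dimension is mapped from features to the ViT's token count, and permute restores (B, seq_len, C). A minimal sketch using the 625 -> 50 sizes of first_linear_proj (exact shapes of ada at this point are assumed):

import torch
import torch.nn as nn

B, C, features, target_seq_len = 2, 768, 625, 50   # 625 -> 50 mirrors first_linear_proj
linear_proj = nn.Linear(features, target_seq_len)

ada = torch.randn(B, C, features)                  # adapter flattened to (B, C, features)
ada = linear_proj(ada).permute(0, 2, 1)            # (B, C, 50) -> (B, 50, C)
assert ada.shape == (B, target_seq_len, C)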
@@ -314,8 +325,10 @@ class DinoVisionTransformer(nn.Module):

        x_s = []
        ada_list = []
        # print("x_list", [i.shape for i in x_list])
        for x, masks in zip(x_list, masks_list):
            x_, ada = self.prepare_tokens_with_masks(x, masks)
            # print(x.shape, ada.shape, self.model_adapter)
            x_s.append(x_)
            ada_list.append(ada)
@@ -323,9 +336,13 @@ class DinoVisionTransformer(nn.Module):
        x = x_s

        for i, blk in enumerate(self.blocks):
            # print([j.shape for j in x])
            # print(ada.shape)
            x = blk(x)

            if self.w_lut and ada is not None and i < len(self.merge_blocks):
                # print("HERE 22")
                x_ada_pairs = [self.merge_blocks[i](x_i, ada_i, ratio=self.merge_ratio) for x_i, ada_i in zip(x, ada_list)]
                x, ada_list = map(list, zip(*x_ada_pairs))
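The loop applies a merge block after each of the first len(self.merge_blocks) transformer blocks and threads the returned (x, ada) pairs back into the lists. A toy standalone version of that control flow (identity layers stand in for the transformer blocks, which in the real code consume the whole list at once):

import torch
import torch.nn as nn

class ToyMerge(nn.Module):
    # Stand-in for Merge_block: returns updated features plus the adapter.
    def forward(self, x, ada, ratio=1.0):
        return x + ratio * ada, ada

blocks = nn.ModuleList([nn.Identity() for _ in range(6)])
merge_blocks = nn.ModuleList([ToyMerge() for _ in range(3)])

x = [torch.randn(2, 50, 768)]                    # list of token tensors, one per crop
ada_list = [torch.randn(2, 50, 768)]             # matching adapter tokens

for i, blk in enumerate(blocks):
    x = [blk(x_i) for x_i in x]
    if i < len(merge_blocks):                    # only the first few layers get a merge step
        pairs = [merge_blocks[i](x_i, a_i) for x_i, a_i in zip(x, ada_list)]
        x, ada_list = map(list, zip(*pairs))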
@@ -353,6 +370,7 @@ class DinoVisionTransformer(nn.Module):
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if self.w_lut and ada is not None and i < len(self.merge_blocks):
                # print("HERE 11", x.shape, ada.shape)
                x, ada = self.merge_blocks[i](x, ada, ratio=self.merge_ratio)

        x_norm = self.norm(x)