init layer linear

pull/511/head
Veronikkkka 2025-03-05 12:55:26 +00:00
parent 9c1bde505b
commit 8cb907ab93
1 changed files with 8 additions and 1 deletions

View File

@ -208,6 +208,8 @@ class DinoVisionTransformer(nn.Module):
self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]
self.first_linear_proj = nn.Linear(625, 50)
self.second_linear_proj = nn.Linear(3136, 257)
self.init_weights()
@ -296,7 +298,12 @@ class DinoVisionTransformer(nn.Module):
ada = ada.reshape(ada.shape[0], ada.shape[1], -1)
batch_size, channels, features = ada.shape
target_seq_len = x.shape[1]
linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
# print("Features: ", features, target_seq_len)
if x.shape[1] == 50:
linear_proj = self.first_linear_proj
elif x.shape[1] == 257:
linear_proj = self.second_linear_proj
# linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
ada = linear_proj(ada).permute(0, 2, 1)