init layer linear

parent 9c1bde505b
commit 8cb907ab93
@@ -208,6 +208,8 @@ class DinoVisionTransformer(nn.Module):
         self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False)
         self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3]
 
+        self.first_linear_proj = nn.Linear(625, 50)
+        self.second_linear_proj = nn.Linear(3136, 257)
 
         self.init_weights()
 
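The first hunk registers the two projection layers as submodules in `__init__` instead of constructing them on the fly in `forward`. A minimal sketch of why this matters, using stripped-down hypothetical modules (the class names `Broken`/`Fixed` are illustrative; only the 625/50 and 3136/257 sizes come from the commit):

    import torch
    import torch.nn as nn

    class Broken(nn.Module):
        def forward(self, ada, target_seq_len):
            # Anti-pattern: builds a fresh nn.Linear on every forward pass,
            # so the weights are re-randomized each call and never appear in
            # the optimizer's parameter groups or in state_dict.
            proj = nn.Linear(ada.shape[-1], target_seq_len).to(ada.device).to(ada.dtype)
            return proj(ada)

    class Fixed(nn.Module):
        def __init__(self):
            super().__init__()
            # Registered once at construction time, as in this commit, so the
            # projections are trained and checkpointed with the rest of the model.
            self.first_linear_proj = nn.Linear(625, 50)
            self.second_linear_proj = nn.Linear(3136, 257)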
@@ -296,7 +298,12 @@ class DinoVisionTransformer(nn.Module):
         ada = ada.reshape(ada.shape[0], ada.shape[1], -1)
         batch_size, channels, features = ada.shape
         target_seq_len = x.shape[1]
-        linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
+        # print("Features: ", features, target_seq_len)
+        if x.shape[1] == 50:
+            linear_proj = self.first_linear_proj
+        elif x.shape[1] == 257:
+            linear_proj = self.second_linear_proj
+        # linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype)
         ada = linear_proj(ada).permute(0, 2, 1)
 
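The second hunk replaces the per-call `nn.Linear` with a dispatch on the token count `x.shape[1]`. A sketch of the resulting tensor flow for the 257-token case, with hypothetical batch, channel, and embedding sizes (2, 64, and 384); only 3136 and 257 come from the commit:

    import torch
    import torch.nn as nn

    ada = torch.randn(2, 64, 56, 56)   # adapter features; 56*56 = 3136 flattened
    x = torch.randn(2, 257, 384)       # ViT tokens, sequence length 257

    ada = ada.reshape(ada.shape[0], ada.shape[1], -1)  # (2, 64, 3136)
    linear_proj = nn.Linear(3136, 257)                 # stands in for self.second_linear_proj
    ada = linear_proj(ada).permute(0, 2, 1)            # (2, 257, 64): one adapter
                                                       # vector per ViT token

Note that with the dynamic construction commented out, a sequence length other than 50 or 257 leaves `linear_proj` unbound and raises a `NameError`, so this change ties the model to those two fixed token counts.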