diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py index c805217..2995900 100644 --- a/dinov2/models/vision_transformer.py +++ b/dinov2/models/vision_transformer.py @@ -208,6 +208,8 @@ class DinoVisionTransformer(nn.Module): self.merge_3 = Merge_block(fea_c=fea_c_s[2], ada_c=ada_c_s[2], mid_c=mid_c_s[2], return_ada=False) self.merge_blocks = [self.merge_1, self.merge_2, self.merge_3] + self.first_linear_proj = nn.Linear(625, 50) + self.second_linear_proj = nn.Linear(3136, 257) self.init_weights() @@ -296,7 +298,12 @@ class DinoVisionTransformer(nn.Module): ada = ada.reshape(ada.shape[0], ada.shape[1], -1) batch_size, channels, features = ada.shape target_seq_len = x.shape[1] - linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype) + # print("Features: ", features, target_seq_len) + if x.shape[1] == 50: + linear_proj = self.first_linear_proj + elif x.shape[1] == 257: + linear_proj = self.second_linear_proj + # linear_proj = nn.Linear(features, target_seq_len).to(ada.device).to(ada.dtype) ada = linear_proj(ada).permute(0, 2, 1)