Fix torchscript issue with legacy tnt

Ross Wightman 2025-05-14 09:33:41 -07:00
parent 74ad32a67e
commit 16d0b26e19


@@ -103,6 +103,7 @@ class Block(nn.Module):
         if self.legacy:
             self.norm1_proj = norm_layer(dim)
             self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
+            self.norm2_proj = None
         else:
             self.norm1_proj = norm_layer(dim * num_pixel)
             self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
@@ -135,7 +136,7 @@ class Block(nn.Module):
         pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
         # outer
         B, N, C = patch_embed.size()
-        if self.legacy:
+        if self.norm2_proj is None:
             patch_embed = torch.cat([
                 patch_embed[:, 0:1],
                 patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
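
For context, a minimal sketch of the TorchScript-friendly pattern this commit adopts (simplified module, not the actual timm Block; names and shapes are illustrative): the submodule unused by the legacy path is assigned None in __init__ so the attribute always exists, and forward branches on `is None`, which torch.jit.script can resolve per instance and prune the dead branch.

    # Sketch only, assuming simplified dims; not the timm TNT Block itself.
    import torch
    import torch.nn as nn

    class ProjBlock(nn.Module):
        def __init__(self, dim: int, legacy: bool = False):
            super().__init__()
            if legacy:
                self.norm1_proj = nn.LayerNorm(dim)
                self.proj = nn.Linear(dim, dim, bias=True)
                # attribute always defined, so scripting never hits a missing name
                self.norm2_proj = None
            else:
                self.norm1_proj = nn.LayerNorm(dim)
                self.proj = nn.Linear(dim, dim, bias=False)
                self.norm2_proj = nn.LayerNorm(dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.norm2_proj is None:
                # legacy path: no second norm after the projection
                return self.proj(self.norm1_proj(x))
            return self.norm2_proj(self.proj(self.norm1_proj(x)))

    # Both configurations now script cleanly:
    scripted_legacy = torch.jit.script(ProjBlock(16, legacy=True))
    scripted_new = torch.jit.script(ProjBlock(16, legacy=False))

Branching on a bool attribute like `self.legacy` is not enough on its own, because TorchScript still compiles both branches and the non-legacy branch references an attribute that would not exist on a legacy instance; the None-attribute check avoids that.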