mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix torchscript issue with legacy tnt
This commit is contained in:
parent
74ad32a67e
commit
16d0b26e19
@ -103,6 +103,7 @@ class Block(nn.Module):
|
|||||||
if self.legacy:
|
if self.legacy:
|
||||||
self.norm1_proj = norm_layer(dim)
|
self.norm1_proj = norm_layer(dim)
|
||||||
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
|
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
|
||||||
|
self.norm2_proj = None
|
||||||
else:
|
else:
|
||||||
self.norm1_proj = norm_layer(dim * num_pixel)
|
self.norm1_proj = norm_layer(dim * num_pixel)
|
||||||
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
|
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
|
||||||
@ -135,7 +136,7 @@ class Block(nn.Module):
|
|||||||
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
|
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
|
||||||
# outer
|
# outer
|
||||||
B, N, C = patch_embed.size()
|
B, N, C = patch_embed.size()
|
||||||
if self.legacy:
|
if self.norm2_proj is None:
|
||||||
patch_embed = torch.cat([
|
patch_embed = torch.cat([
|
||||||
patch_embed[:, 0:1],
|
patch_embed[:, 0:1],
|
||||||
patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
|
patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user