diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 7e86103d..fa6e1fc9 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -103,6 +103,7 @@ class Block(nn.Module): if self.legacy: self.norm1_proj = norm_layer(dim) self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True) + self.norm2_proj = None else: self.norm1_proj = norm_layer(dim * num_pixel) self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False) @@ -135,7 +136,7 @@ class Block(nn.Module): pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) # outer B, N, C = patch_embed.size() - if self.legacy: + if self.norm2_proj is None: patch_embed = torch.cat([ patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),