From 70176a2dae0039d68d8eea3a0b9c8c4564af2e7e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 23 May 2024 11:43:05 -0700
Subject: [PATCH] torchscript typing fixes

---
 timm/layers/attention2d.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/timm/layers/attention2d.py b/timm/layers/attention2d.py
index 3213a9f8..3d3f6d01 100644
--- a/timm/layers/attention2d.py
+++ b/timm/layers/attention2d.py
@@ -207,7 +207,7 @@ class MultiQueryAttention2d(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.einsum = False
 
-    def _reshape_input(self, t):
+    def _reshape_input(self, t: torch.Tensor):
         """Reshapes a tensor to three dimensions, keeping the batch and channels."""
         s = t.shape
         t = t.reshape(s[0], s[1], -1).transpose(1, 2)
@@ -216,7 +216,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.unsqueeze(1).contiguous()
 
-    def _reshape_projected_query(self, t, num_heads, key_dim):
+    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
         """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
         s = t.shape
         t = t.reshape(s[0], num_heads, key_dim, -1)
@@ -225,7 +225,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.transpose(-1, -2).contiguous()
 
-    def _reshape_output(self, t, num_heads, h_px, w_px):
+    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
         """Reshape output:[b, n x n x h, k] -> [b, n, n, hk]."""
         s = t.shape
         feat_dim = s[-1] * num_heads
@@ -233,8 +233,6 @@ class MultiQueryAttention2d(nn.Module):
         t = t.transpose(1, 2)
         return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()
 
-
-
     def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """Run layer computation."""
         B, C, H, W = s = x.shape
@@ -273,7 +271,7 @@ class MultiQueryAttention2d(nn.Module):
             o = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p if self.training else 0
+                dropout_p=self.attn_drop.p if self.training else 0.
            )
         else:
            q = q * self.scale
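
Note (not part of the patch): the annotations above matter because torch.jit.script treats any unannotated function argument as a torch.Tensor, so integer helper arguments like num_heads, key_dim, h_px, and w_px fail to script without an explicit `int` annotation, and changing `0` to `0.` keeps both branches of the dropout_p conditional the same type (float), which TorchScript also requires. Below is a minimal, self-contained sketch of the same pattern; the TinyAttn module and its _split_heads helper are hypothetical stand-ins for illustration, not timm code.

```python
# Minimal sketch (hypothetical module, not from timm) showing why the patch's
# annotations are needed for torch.jit.script.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAttn(nn.Module):
    """Toy attention block that mirrors the typing pattern fixed in the patch."""

    def __init__(self, dim: int = 8, num_heads: int = 2, attn_drop: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.attn_drop = nn.Dropout(attn_drop)
        self.qkv = nn.Linear(dim, dim * 3)

    def _split_heads(self, t: torch.Tensor, num_heads: int) -> torch.Tensor:
        # Without the `int` annotation, TorchScript assumes num_heads is a
        # Tensor and the reshape below fails to compile.
        B, N, C = t.shape
        return t.reshape(B, N, num_heads, C // num_heads).transpose(1, 2)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, N, C = x.shape
        # Project and split into q, k, v: (3, B, N, C)
        qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
        q = self._split_heads(qkv[0], self.num_heads)
        k = self._split_heads(qkv[1], self.num_heads)
        v = self._split_heads(qkv[2], self.num_heads)
        o = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            # `0.` (float) keeps both branches of the conditional the same type.
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        # Merge heads back: (B, N, C)
        return o.transpose(1, 2).reshape(B, N, C)


scripted = torch.jit.script(TinyAttn())  # scripts cleanly with the annotations above
print(scripted(torch.randn(2, 4, 8)).shape)  # torch.Size([2, 4, 8])
```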