torchscript typing fixes

Ross Wightman 2024-05-23 11:43:05 -07:00
parent 2a1a6b1236
commit 70176a2dae


@@ -207,7 +207,7 @@ class MultiQueryAttention2d(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.einsum = False
 
-    def _reshape_input(self, t):
+    def _reshape_input(self, t: torch.Tensor):
         """Reshapes a tensor to three dimensions, keeping the batch and channels."""
         s = t.shape
         t = t.reshape(s[0], s[1], -1).transpose(1, 2)
@@ -216,7 +216,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.unsqueeze(1).contiguous()
 
-    def _reshape_projected_query(self, t, num_heads, key_dim):
+    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
         """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
         s = t.shape
         t = t.reshape(s[0], num_heads, key_dim, -1)
@@ -225,7 +225,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.transpose(-1, -2).contiguous()
 
-    def _reshape_output(self, t, num_heads, h_px, w_px):
+    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
         """Reshape output:[b, n x n x h, k] -> [b, n, n, hk]."""
         s = t.shape
         feat_dim = s[-1] * num_heads
@@ -233,8 +233,6 @@ class MultiQueryAttention2d(nn.Module):
         t = t.transpose(1, 2)
         return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()
 
-
-
     def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """Run layer computation."""
         B, C, H, W = s = x.shape
@@ -273,7 +271,7 @@ class MultiQueryAttention2d(nn.Module):
             o = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p if self.training else 0
+                dropout_p=self.attn_drop.p if self.training else 0.
             )
         else:
             q = q * self.scale
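
Both changes follow from how torch.jit.script types code: an unannotated parameter defaults to Tensor, so scalars such as num_heads need an explicit int annotation before they can be used as reshape sizes, and the two arms of a conditional expression must share one static type, so the literal 0 becomes 0. to match the float self.attn_drop.p. A minimal sketch of the two rules, using hypothetical standalone helpers rather than the timm methods:

import torch

# Hypothetical helpers illustrating the TorchScript typing rules addressed
# by the diff above; they are not part of the commit itself.

@torch.jit.script
def split_tokens(t: torch.Tensor, num_heads: int) -> torch.Tensor:
    # Without the `int` annotation, TorchScript would assume `num_heads` is
    # a Tensor and reject its use as a reshape size.
    s = t.shape
    return t.reshape(s[0], num_heads, -1)

@torch.jit.script
def effective_dropout_p(p: float, training: bool) -> float:
    # Both branches of the conditional must carry the same static type, so
    # the false branch uses the float literal `0.` rather than the int `0`.
    return p if training else 0.

Eager PyTorch is indifferent to either change; the annotations only matter once the module is scripted.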