Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
torchscript typing fixes
This commit is contained in:
parent 2a1a6b1236
commit 70176a2dae
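The diff below adds explicit type annotations to the private reshape helpers of MultiQueryAttention2d (un-annotated arguments default to Tensor under TorchScript) and changes the int literal 0 to the float literal 0. in the scaled_dot_product_attention call, where dropout_p is typed as float.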
@@ -207,7 +207,7 @@ class MultiQueryAttention2d(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.einsum = False

-    def _reshape_input(self, t):
+    def _reshape_input(self, t: torch.Tensor):
         """Reshapes a tensor to three dimensions, keeping the batch and channels."""
         s = t.shape
         t = t.reshape(s[0], s[1], -1).transpose(1, 2)
@@ -216,7 +216,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.unsqueeze(1).contiguous()

-    def _reshape_projected_query(self, t, num_heads, key_dim):
+    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
         """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
         s = t.shape
         t = t.reshape(s[0], num_heads, key_dim, -1)
@@ -225,7 +225,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.transpose(-1, -2).contiguous()

-    def _reshape_output(self, t, num_heads, h_px, w_px):
+    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
         """Reshape output:[b, n x n x h, k] -> [b, n, n, hk]."""
         s = t.shape
         feat_dim = s[-1] * num_heads
@@ -233,8 +233,6 @@ class MultiQueryAttention2d(nn.Module):
             t = t.transpose(1, 2)
         return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()

-
-
     def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """Run layer computation."""
         B, C, H, W = s = x.shape
@@ -273,7 +271,7 @@ class MultiQueryAttention2d(nn.Module):
             o = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p if self.training else 0
+                dropout_p=self.attn_drop.p if self.training else 0.
             )
         else:
             q = q * self.scale
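A minimal sketch of why these changes matter, assuming TorchScript's usual strict typing rules; TinyMQA and its shapes are hypothetical illustrations, not code from this commit. Under torch.jit.script, an un-annotated argument is assumed to be Tensor, so an int such as num_heads needs an explicit annotation, and because dropout_p is declared as float, the else branch of the conditional must be the float literal 0. rather than the int 0.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMQA(nn.Module):
    """Hypothetical toy module (not from timm) showing the two TorchScript fixes."""

    def __init__(self, dim: int = 32, num_heads: int = 4, attn_drop: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.attn_drop = nn.Dropout(attn_drop)

    def _reshape_input(self, t: torch.Tensor, num_heads: int):
        # Without the `int` annotation TorchScript assumes `num_heads: Tensor`
        # and rejects its use as a reshape() dimension.
        s = t.shape
        return t.reshape(s[0], num_heads, s[1] // num_heads, -1)

    def forward(self, x: torch.Tensor):
        qkv = self._reshape_input(x, self.num_heads)
        # `0.` rather than `0`: dropout_p is typed float, and TorchScript does
        # not implicitly promote the int branch of the conditional expression.
        return F.scaled_dot_product_attention(
            qkv, qkv, qkv,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )


scripted = torch.jit.script(TinyMQA())   # should compile with the annotations above
out = scripted(torch.randn(2, 32, 8))    # -> shape (2, 4, 8, 8)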