torchscript typing fixes
commit 70176a2dae
parent 2a1a6b1236
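The diff below annotates the tensor-reshape helpers in MultiQueryAttention2d and turns an integer literal into a float literal so that the module type-checks under torch.jit.script. A minimal smoke test along these lines is one way to exercise the fix; the model name is only an assumption here, any timm variant that instantiates MultiQueryAttention2d would do:

import timm
import torch

# Assumed model name, used purely for illustration.
model = timm.create_model('mobilenetv4_hybrid_medium', pretrained=False).eval()
scripted = torch.jit.script(model)  # should compile once the annotations are in place
with torch.no_grad():
    out = scripted(torch.randn(1, 3, 256, 256))
print(out.shape)
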
@@ -207,7 +207,7 @@ class MultiQueryAttention2d(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.einsum = False
 
-    def _reshape_input(self, t):
+    def _reshape_input(self, t: torch.Tensor):
         """Reshapes a tensor to three dimensions, keeping the batch and channels."""
         s = t.shape
         t = t.reshape(s[0], s[1], -1).transpose(1, 2)
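The `t: torch.Tensor` annotation matches what TorchScript already assumes for an un-annotated parameter; making it explicit keeps this signature consistent with the helpers below that take non-Tensor arguments. For reference, the reshape visible in this hunk flattens the spatial dimensions, a quick sketch:

import torch

# (B, C, H, W) -> (B, H*W, C), mirroring the body of _reshape_input
x = torch.randn(2, 64, 7, 7)
s = x.shape
t = x.reshape(s[0], s[1], -1).transpose(1, 2)
print(t.shape)  # torch.Size([2, 49, 64])
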
@@ -216,7 +216,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.unsqueeze(1).contiguous()
 
-    def _reshape_projected_query(self, t, num_heads, key_dim):
+    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
         """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
         s = t.shape
         t = t.reshape(s[0], num_heads, key_dim, -1)
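The `num_heads: int` and `key_dim: int` annotations are the substantive part of this hunk: under torch.jit.script, parameters without annotations are treated as Tensor, so passing plain Python ints to an un-annotated helper is rejected at script time. A stripped-down illustration of the same pattern (toy module, not the timm code):

import torch

class Toy(torch.nn.Module):
    # Without ': int' TorchScript would assume num_heads is a Tensor and
    # reject the call in forward() when the module is scripted.
    def _split(self, t: torch.Tensor, num_heads: int):
        b, n, c = t.shape
        return t.reshape(b, n, num_heads, c // num_heads)

    def forward(self, x: torch.Tensor):
        return self._split(x, 4)

scripted = torch.jit.script(Toy())  # compiles with the annotations in place
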
@@ -225,7 +225,7 @@ class MultiQueryAttention2d(nn.Module):
         else:
             return t.transpose(-1, -2).contiguous()
 
-    def _reshape_output(self, t, num_heads, h_px, w_px):
+    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
         """Reshape output:[b, n x n x h, k] -> [b, n, n, hk]."""
         s = t.shape
         feat_dim = s[-1] * num_heads
@@ -233,8 +233,6 @@ class MultiQueryAttention2d(nn.Module):
             t = t.transpose(1, 2)
         return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()
 
-
-
     def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """Run layer computation."""
         B, C, H, W = s = x.shape
@@ -273,7 +271,7 @@ class MultiQueryAttention2d(nn.Module):
             o = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p if self.training else 0
+                dropout_p=self.attn_drop.p if self.training else 0.
             )
         else:
             q = q * self.scale
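The `0` -> `0.` change is also a typing fix: `F.scaled_dot_product_attention` declares `dropout_p` as a float, and TorchScript is stricter than eager Python about mixing an int literal with the float `self.attn_drop.p` in the conditional expression. A minimal sketch of the corrected pattern (toy module, not the timm code):

import torch
import torch.nn.functional as F

class ToyAttn(torch.nn.Module):
    def __init__(self, attn_drop: float = 0.1):
        super().__init__()
        self.attn_drop = torch.nn.Dropout(attn_drop)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # The float literal keeps both branches of the conditional as float,
        # matching dropout_p's declared type when the module is scripted.
        return F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

scripted = torch.jit.script(ToyAttn())
q = k = v = torch.randn(1, 4, 16, 8)
print(scripted(q, k, v).shape)  # torch.Size([1, 4, 16, 8])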