mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove sdpa context mgrs
This commit is contained in:
parent
2734bb76ce
commit
379780bb6c
@ -124,7 +124,7 @@ def pack_images(
|
||||
):
|
||||
max_seq_len = max_grid_size[0] * max_grid_size[1]
|
||||
|
||||
# patchify if needed, generate position indices, apply patch drop, record seq lengths
|
||||
# patchify, generate position indices, apply patch drop, record seq lengths
|
||||
img_tokens = []
|
||||
img_pos_indices = []
|
||||
img_seq_lens = []
|
||||
@ -144,6 +144,7 @@ def pack_images(
|
||||
indexing='ij'),
|
||||
dim=-1,
|
||||
)
|
||||
# FIXME patch drop here
|
||||
img_tokens.append(patches.flatten(0, 1))
|
||||
img_pos_indices.append(pos_indices.flatten(0, 1))
|
||||
img_seq_lens.append(seq_len)
|
||||
@ -221,12 +222,11 @@ class Attention(nn.Module):
|
||||
attn_mask = attn_mask.expand((-1, self.num_heads, -1, -1))
|
||||
|
||||
if self.fused_attn:
|
||||
with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
@ -374,12 +374,11 @@ class ParallelScalingBlock(nn.Module):
|
||||
k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
|
||||
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
if self.fused_attn:
|
||||
with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
|
||||
x_attn = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
x_attn = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
@ -507,11 +506,10 @@ class AttentionPoolLatent(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if False:
|
||||
with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user