Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Remove sdpa context mgrs
parent 2734bb76ce
commit 379780bb6c
@@ -124,7 +124,7 @@ def pack_images(
 ):
     max_seq_len = max_grid_size[0] * max_grid_size[1]
 
-    # patchify if needed, generate position indices, apply patch drop, record seq lengths
+    # patchify, generate position indices, apply patch drop, record seq lengths
     img_tokens = []
     img_pos_indices = []
     img_seq_lens = []
@@ -144,6 +144,7 @@ def pack_images(
                 indexing='ij'),
             dim=-1,
         )
+        # FIXME patch drop here
         img_tokens.append(patches.flatten(0, 1))
         img_pos_indices.append(pos_indices.flatten(0, 1))
         img_seq_lens.append(seq_len)
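For context, the position-index lines in this hunk follow the usual torch.meshgrid / torch.stack / flatten pattern. The sketch below is illustrative only; the grid sizes and names (grid_h, grid_w, flat_pos) are assumptions, not the actual pack_images signature.

import torch

# Hypothetical grid size; in pack_images this would come from the image's patch grid.
grid_h, grid_w = 4, 6

# Build per-patch (row, col) indices, matching the indexing='ij' / dim=-1 lines above.
pos_indices = torch.stack(
    torch.meshgrid(
        torch.arange(grid_h),
        torch.arange(grid_w),
        indexing='ij'),
    dim=-1,
)  # (grid_h, grid_w, 2)

# Flatten the grid dims into a token dim, as pos_indices.flatten(0, 1) does above.
flat_pos = pos_indices.flatten(0, 1)  # (grid_h * grid_w, 2)
print(flat_pos.shape)  # torch.Size([24, 2])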
@@ -221,12 +222,11 @@ class Attention(nn.Module):
             attn_mask = attn_mask.expand((-1, self.num_heads, -1, -1))
 
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
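The change above drops the deprecated torch.backends.cuda.sdp_kernel context manager and calls F.scaled_dot_product_attention directly. Below is a self-contained sketch of the two branches in this hunk (fused vs. manual); shapes, scale, and the all-zero additive mask are illustrative, not the timm Attention module itself.

import torch
import torch.nn.functional as F

B, num_heads, N, head_dim = 2, 4, 16, 32
scale = head_dim ** -0.5
q = torch.randn(B, num_heads, N, head_dim)
k = torch.randn(B, num_heads, N, head_dim)
v = torch.randn(B, num_heads, N, head_dim)
# Additive float mask broadcast over heads, all-zero here for a like-for-like comparison.
attn_mask = torch.zeros(B, 1, N, N).expand(-1, num_heads, -1, -1)

# Fused path: no sdp_kernel() context manager; PyTorch picks an available SDPA backend.
x_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)

# Manual path, mirroring the else branch above.
attn = (q * scale) @ k.transpose(-2, -1) + attn_mask
x_manual = attn.softmax(dim=-1) @ v

print(torch.allclose(x_fused, x_manual, atol=1e-5))  # True up to numerical tolerance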
@@ -374,12 +374,11 @@ class ParallelScalingBlock(nn.Module):
         k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x_attn = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x_attn = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
@@ -507,11 +506,10 @@ class AttentionPoolLatent(nn.Module):
         q = self.q_norm(q)
         k = self.k_norm(k)
         if False:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
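Not part of this commit, but for reference: if backend selection still needs to be constrained after removing the deprecated torch.backends.cuda.sdp_kernel(enable_mem_efficient=False) calls, the torch.nn.attention.sdpa_kernel context manager (PyTorch 2.3+) covers the same need. A minimal sketch, with illustrative tensor shapes:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 4, 16, 32)

# Restrict SDPA to the listed backends within this scope (flash attention, with math fallback).
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
    x = F.scaled_dot_product_attention(q, k, v)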