Remove sdpa context mgrs

Author: Ross Wightman
Date:   2023-09-25 23:30:56 -07:00
parent 2734bb76ce
commit 379780bb6c

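The change is the same in each attention module below: the torch.backends.cuda.sdp_kernel(enable_mem_efficient=False) context manager is dropped, so F.scaled_dot_product_attention is called directly and PyTorch is left to pick the SDPA backend itself. A minimal standalone sketch of the fused branch as it reads after this commit (the bare tensors and shapes are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

# Illustrative shapes: batch 2, 8 heads, 16 tokens, head_dim 64.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
attn_mask = None  # or a mask broadcastable to (2, 8, 16, 16)

# Previously wrapped in: with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False): ...
# Now the backend choice (flash / mem-efficient / math) is left to PyTorch's own dispatch.
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(x.shape)  # torch.Size([2, 8, 16, 64])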

@@ -124,7 +124,7 @@ def pack_images(
 ):
     max_seq_len = max_grid_size[0] * max_grid_size[1]
 
-    # patchify if needed, generate position indices, apply patch drop, record seq lengths
+    # patchify, generate position indices, apply patch drop, record seq lengths
     img_tokens = []
     img_pos_indices = []
     img_seq_lens = []
@@ -144,6 +144,7 @@ def pack_images(
                 indexing='ij'),
             dim=-1,
         )
+        # FIXME patch drop here
         img_tokens.append(patches.flatten(0, 1))
         img_pos_indices.append(pos_indices.flatten(0, 1))
         img_seq_lens.append(seq_len)
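For context, the two hunks above sit in an image packing routine that flattens each image's patch grid into a token sequence and records per-token (y, x) position indices; the added FIXME marks where patch drop would subsample tokens before appending. A rough sketch of that pattern, with invented names and shapes (grid_h, grid_w, patches are illustrative, not the function's actual variables):

import torch

grid_h, grid_w, dim = 3, 4, 8          # illustrative patch grid and embed dim
patches = torch.randn(grid_h, grid_w, dim)

# (y, x) index for every grid position, mirroring the meshgrid(..., indexing='ij') above.
pos_indices = torch.stack(
    torch.meshgrid(
        torch.arange(grid_h),
        torch.arange(grid_w),
        indexing='ij'),
    dim=-1,
)
# FIXME in the diff: patch drop would go here, subsampling tokens and indices together.

seq_len = grid_h * grid_w
img_tokens = [patches.flatten(0, 1)]           # (seq_len, dim)
img_pos_indices = [pos_indices.flatten(0, 1)]  # (seq_len, 2)
img_seq_lens = [seq_len]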
@@ -221,12 +222,11 @@ class Attention(nn.Module):
             attn_mask = attn_mask.expand((-1, self.num_heads, -1, -1))
 
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
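If excluding a backend is ever needed again (the removed wrapper disabled the memory-efficient kernel), the restriction can be applied at the call site instead of inside the module. A sketch, assuming a recent PyTorch where torch.nn.attention.sdpa_kernel is available; older releases expose the now-deprecated torch.backends.cuda.sdp_kernel used in the removed lines:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Allow only the flash and math backends, i.e. exclude mem-efficient attention,
# roughly what enable_mem_efficient=False did in the old context manager.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]):
    x = F.scaled_dot_product_attention(q, k, v)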
@@ -374,12 +374,11 @@ class ParallelScalingBlock(nn.Module):
         k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x_attn = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x_attn = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
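The else branch in these hunks is the non-fused fallback; only its first lines appear in the diff. A sketch of how such a fallback typically completes (the softmax and value product are not shown in the diff; mask and dropout handling are omitted here):

import torch

B, H, N, D = 2, 8, 16, 64
q = torch.randn(B, H, N, D)
k = torch.randn(B, H, N, D)
v = torch.randn(B, H, N, D)
scale = D ** -0.5                      # matches a typical self.scale = head_dim ** -0.5

q = q * scale
attn = q @ k.transpose(-2, -1)         # (B, H, N, N) attention logits
attn = attn.softmax(dim=-1)
x = attn @ v                           # (B, H, N, D), same shape as the fused result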
@@ -507,11 +506,10 @@ class AttentionPoolLatent(nn.Module):
         q = self.q_norm(q)
         k = self.k_norm(k)
         if False:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
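Since the fused path in AttentionPoolLatent sits behind "if False:" here, a quick way to convince oneself the two branches agree is to compare them directly on dummy tensors (no mask, no dropout; shapes invented):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1, 64)   # e.g. a single pooled/latent query per head
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

fused = F.scaled_dot_product_attention(q, k, v)

scale = q.shape[-1] ** -0.5    # SDPA's default scaling is 1/sqrt(head_dim)
attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)
manual = attn @ v

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True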