Add fused_attn flag to HieraDet attn block

Ross Wightman 2024-08-16 22:57:49 -07:00
parent 691bb54443
commit 1bd92bca0e


@@ -9,7 +9,7 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
-    get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple
+    get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -58,6 +58,8 @@ def _calc_pad(H: int, W: int, window_size: Tuple[int, int]) -> Tuple[int, int, i
 class MultiScaleAttention(nn.Module):
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             dim: int,
@@ -66,13 +68,12 @@ class MultiScaleAttention(nn.Module):
             q_pool: nn.Module = None,
     ):
         super().__init__()
         self.dim = dim
         self.dim_out = dim_out
         self.num_heads = num_heads
         head_dim = dim_out // num_heads
-        self.scale = head_dim**-0.5
+        self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
         self.q_pool = q_pool
         self.qkv = nn.Linear(dim, dim_out * 3)
@@ -95,11 +96,17 @@ class MultiScaleAttention(nn.Module):
         q = q.reshape(B, H * W, self.num_heads, -1)
         # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
-        x = F.scaled_dot_product_attention(
-            q.transpose(1, 2),
-            k.transpose(1, 2),
-            v.transpose(1, 2),
-        )
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-1, -2)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
         # Transpose back
         x = x.transpose(1, 2).reshape(B, H, W, -1)
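
For reference, a minimal standalone sketch (not part of the commit) of why the new fallback branch is a drop-in replacement: the manual scale/softmax path should match torch's fused scaled_dot_product_attention, which applies the 1/sqrt(head_dim) scaling internally. The tensor shapes below are arbitrary test values, not taken from the HieraDet model.

    import torch
    import torch.nn.functional as F

    # Arbitrary [B, nheads, L, head_dim] test tensors.
    B, num_heads, L, head_dim = 2, 4, 64, 32
    q = torch.randn(B, num_heads, L, head_dim)
    k = torch.randn(B, num_heads, L, head_dim)
    v = torch.randn(B, num_heads, L, head_dim)

    # Fused path: SDPA scales q by 1/sqrt(head_dim) internally.
    x_fused = F.scaled_dot_product_attention(q, k, v)

    # Manual path, mirroring the `else` branch in the diff: scale q explicitly,
    # form the attention matrix, softmax over keys, then weight the values.
    scale = head_dim ** -0.5
    attn = (q * scale) @ k.transpose(-1, -2)
    attn = attn.softmax(dim=-1)
    x_manual = attn @ v

    print(torch.allclose(x_fused, x_manual, atol=1e-5))  # expected: True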