mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add fused_attn flag to HieraDet attn block
This commit is contained in:
parent 691bb54443
commit 1bd92bca0e
@@ -9,7 +9,7 @@ from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, PatchDropout, \
-    get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple
+    get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -58,6 +58,8 @@ def _calc_pad(H: int, W: int, window_size: Tuple[int, int]) -> Tuple[int, int, i
 
 
 class MultiScaleAttention(nn.Module):
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(
             self,
             dim: int,
@@ -66,13 +68,12 @@ class MultiScaleAttention(nn.Module):
             q_pool: nn.Module = None,
     ):
         super().__init__()
-
         self.dim = dim
         self.dim_out = dim_out
-
         self.num_heads = num_heads
         head_dim = dim_out // num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
 
         self.q_pool = q_pool
         self.qkv = nn.Linear(dim, dim_out * 3)
@@ -95,11 +96,17 @@ class MultiScaleAttention(nn.Module):
         q = q.reshape(B, H * W, self.num_heads, -1)
 
         # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
-        x = F.scaled_dot_product_attention(
-            q.transpose(1, 2),
-            k.transpose(1, 2),
-            v.transpose(1, 2),
-        )
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-1, -2)
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+
         # Transpose back
         x = x.transpose(1, 2).reshape(B, H, W, -1)
 
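For readers skimming the diff: fused_attn is evaluated once in __init__ via use_fused_attn(), and the torch.jit.Final[bool] annotation lets TorchScript treat the flag as a constant so the unused branch can be pruned when scripting. The sketch below is standalone, with made-up tensor sizes rather than the real HieraDet forward, and shows that the two branches compute the same result: F.scaled_dot_product_attention applies the 1/sqrt(head_dim) scaling internally, which is why the fallback multiplies q by self.scale explicitly.

import torch
import torch.nn.functional as F

# Illustrative sizes only (batch, heads, tokens, head_dim); not real model shapes.
B, num_heads, L, head_dim = 2, 4, 49, 32
q = torch.randn(B, num_heads, L, head_dim)
k = torch.randn(B, num_heads, L, head_dim)
v = torch.randn(B, num_heads, L, head_dim)

# Fused path: SDPA scales q by 1/sqrt(head_dim) internally.
x_fused = F.scaled_dot_product_attention(q, k, v)

# Unfused fallback, mirroring the else-branch added in this commit.
scale = head_dim ** -0.5
attn = (q * scale) @ k.transpose(-1, -2)
attn = attn.softmax(dim=-1)
x_manual = attn @ v

# The two paths agree up to floating point tolerance.
assert torch.allclose(x_fused, x_manual, atol=1e-5)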
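A quick usage note, hedged since the relevant helper lives outside this diff: use_fused_attn() reads a global setting in timm.layers, so the flag captured here can be flipped before model construction. A minimal sketch, assuming the set_fused_attn() toggle exported by timm.layers and using an illustrative model name:

import timm
from timm.layers import set_fused_attn

# Assumption: set_fused_attn() is the global switch paired with use_fused_attn();
# it must be called before create_model(), since MultiScaleAttention reads the
# setting once in __init__.
set_fused_attn(False)  # force the explicit softmax-attention fallback path

# Model name is illustrative; any HieraDet / SAM2-Hiera variant in timm would do.
model = timm.create_model('sam2_hiera_tiny', pretrained=False)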