Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
fast_attn -> fused_attn, implement global config for enable/disable fused_attn, add to more models. vit clip openai 336 weights.
This commit is contained in:
parent 4d135421a3
commit 965d0a2d36
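Usage note (not part of the diff): the commit routes everything through one global switch in the layer-config singleton (timm/layers/config.py, changed below). A minimal sketch of driving it, using only names introduced in this commit; the model name is just an example:

from timm.layers import set_fused_attn, use_fused_attn

# 0 == off, 1 == on for tested models (default), 2 == also enable experimental paths.
# The same level can be set via the TIMM_FUSED_ATTN environment variable before timm is imported.
set_fused_attn(enable=True)                # prefer F.scaled_dot_product_attention where wired up
print(use_fused_attn())                    # True on PyTorch >= 2.0, False otherwise or when exporting
print(use_fused_attn(experimental=True))   # True only at level 2

import timm
model = timm.create_model('vit_base_patch16_224', pretrained=False)  # Attention blocks read the flag in __init__
set_fused_attn(enable=False)               # models created after this call use the manual softmax path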
@@ -5,8 +5,8 @@ from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbeddi
 from .blur_pool import BlurPool2d
 from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
-    set_layer_config
+from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
+    set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
@@ -1,10 +1,14 @@
 """ Model / Layer Config singleton state
 """
+import os
+import warnings
 from typing import Any, Optional

+import torch
+
 __all__ = [
-    'is_exportable', 'is_scriptable', 'is_no_jit',
-    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
+    'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
 ]

 # Set to True if prefer to have layers with no jit optimization (includes activations)
@@ -22,6 +26,14 @@ _EXPORTABLE = False
 _SCRIPTABLE = False


+# use torch.scaled_dot_product_attention where possible
+_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+if 'TIMM_FUSED_ATTN' in os.environ:
+    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
+else:
+    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
+
+
 def is_no_jit():
     return _NO_JIT

@@ -113,3 +125,25 @@ class set_layer_config:
         global _NO_ACTIVATION_JIT
         _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
         return False
+
+
+def use_fused_attn(experimental: bool = False) -> bool:
+    # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0
+    if not _HAS_FUSED_ATTN or _EXPORTABLE:
+        return False
+    if experimental:
+        return _USE_FUSED_ATTN > 1
+    return _USE_FUSED_ATTN > 0
+
+
+def set_fused_attn(enable: bool = True, experimental: bool = False):
+    global _USE_FUSED_ATTN
+    if not _HAS_FUSED_ATTN:
+        warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
+        return
+    if experimental and enable:
+        _USE_FUSED_ATTN = 2
+    elif enable:
+        _USE_FUSED_ATTN = 1
+    else:
+        _USE_FUSED_ATTN = 0
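For orientation (not part of the diff): the model changes that follow all consume this config the same way — the flag is resolved once in __init__ via use_fused_attn() and the forward pass branches between F.scaled_dot_product_attention and the explicit softmax path. A condensed, illustrative module written under those assumptions (TinyAttention is not a real timm class):

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import use_fused_attn


class TinyAttention(nn.Module):
    # Condensed from the pattern applied across the models below; illustrative only.
    fused_attn: torch.jit.Final[bool]

    def __init__(self, dim: int, num_heads: int = 8, attn_drop: float = 0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()  # resolved once, baked in as a jit-final flag
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)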
@@ -48,7 +48,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_
+from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn

 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model
@@ -80,7 +80,7 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:


 class Attention(nn.Module):
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]

     def __init__(
             self,
@@ -99,7 +99,7 @@ class Attention(nn.Module):
             head_dim = attn_head_dim
         all_head_dim = head_dim * self.num_heads
         self.scale = head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
@@ -142,15 +142,14 @@ class Attention(nn.Module):
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim

-        if self.fast_attn:
+        if self.fused_attn:
+            rel_pos_bias = None
             if self.relative_position_bias_table is not None:
                 rel_pos_bias = self._get_rel_pos_bias()
                 if shared_rel_pos_bias is not None:
                     rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
             elif shared_rel_pos_bias is not None:
                 rel_pos_bias = shared_rel_pos_bias
-            else:
-                rel_pos_bias = None

             x = F.scaled_dot_product_attention(
                 q, k, v,
@@ -14,7 +14,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
+from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
@@ -73,12 +73,15 @@ default_cfgs = dict(

 class ClassAttn(nn.Module):
     # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-    # with slight modifications to do CA
+    # with slight modifications to do CA
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()

         self.q = nn.Linear(dim, dim, bias=qkv_bias)
         self.k = nn.Linear(dim, dim, bias=qkv_bias)
@@ -91,15 +94,21 @@
         B, N, C = x.shape
         q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
         k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-
-        q = q * self.scale
         v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

-        attn = (q @ k.transpose(-2, -1))
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fused_attn:
+            x_cls = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x_cls = attn @ v

-        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
+        x_cls = x_cls.transpose(1, 2).reshape(B, 1, C)
         x_cls = self.proj(x_cls)
         x_cls = self.proj_drop(x_cls)

@@ -179,7 +188,7 @@ class TalkingHeadAttn(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

-        attn = (q @ k.transpose(-2, -1))
+        attn = q @ k.transpose(-2, -1)

         attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

@@ -20,7 +20,7 @@ import torch.nn.functional as F
 from torch import Tensor

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer
+from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
@@ -229,6 +229,7 @@ class WindowAttention(nn.Module):
         num_heads (int): Number of attention heads.
         qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
     """
+    fused_attn: torch.jit.Final[bool]

     def __init__(self, dim, window_size, num_heads, qkv_bias=True):
         super().__init__()
@@ -237,6 +238,7 @@
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
@@ -249,11 +251,15 @@
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)

-        q = q * self.scale
-        attn = (q @ k.transpose(-2, -1))
-        attn = self.softmax(attn)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v)
+        else:
+            q = q * self.scale
+            attn = (q @ k.transpose(-2, -1))
+            attn = self.softmax(attn)
+            x = attn @ v

-        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = x.transpose(1, 2).reshape(B_, N, C)
         x = self.proj(x)
         return x

@@ -35,7 +35,8 @@ from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
-    apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
+    apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \
+    to_2tuple, use_fused_attn

 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model
@@ -44,7 +45,7 @@ __all__ = ['Eva']


 class EvaAttention(nn.Module):
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]

     def __init__(
             self,
@@ -76,7 +77,7 @@ class EvaAttention(nn.Module):
             head_dim = attn_head_dim
         all_head_dim = head_dim * self.num_heads
         self.scale = head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        self.fused_attn = use_fused_attn()

         if qkv_fused:
             self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
@@ -121,7 +122,7 @@ class EvaAttention(nn.Module):
             q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
             k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)

-        if self.fast_attn:
+        if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
@@ -488,11 +489,8 @@ class Eva(nn.Module):
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
+            if m.bias is not None:
+                nn.init.zeros_(m.bias)

     @torch.jit.ignore
     def no_weight_decay(self):
@@ -787,6 +785,11 @@ default_cfgs = generate_default_cfgs({
         #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
         num_classes=768,
     ),
+    'eva02_large_patch14_clip_336.clip': _cfg(
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+        input_size=(3, 336, 336), crop_pct=1.0,
+        num_classes=768,
+    ),
     'eva02_enormous_patch14_clip_224.clip': _cfg(
         #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
         num_classes=1024,
@@ -1011,6 +1014,28 @@ def eva02_large_patch14_clip_224(pretrained=False, **kwargs):
     return model


+@register_model
+def eva02_large_patch14_clip_336(pretrained=False, **kwargs):
+    # A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large
+    model_args = dict(
+        img_size=336,
+        patch_size=14,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4 * 2 / 3,
+        qkv_fused=False,
+        swiglu_mlp=True,
+        scale_mlp=True,
+        scale_attn_inner=True,
+        use_rot_pos_emb=True,
+        ref_feat_shape=(16, 16),  # 224/14
+        global_pool='token',
+    )
+    model = _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs):
     # A EVA-CLIP specific variant that uses residual post-norm in blocks
@@ -48,7 +48,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
 from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
 from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
-from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
+from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply, checkpoint_seq
@@ -140,7 +140,7 @@ class MaxxVitCfg:


 class Attention2d(nn.Module):
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]

     """ multi-head attention for 2D NCHW tensors"""
     def __init__(
@@ -162,7 +162,7 @@
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -178,13 +178,13 @@
         else:
             q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

-        if self.fast_attn:
+        if self.fused_attn:
+            attn_bias = None
             if self.rel_pos is not None:
                 attn_bias = self.rel_pos.get_bias()
             elif shared_rel_pos is not None:
                 attn_bias = shared_rel_pos
-            else:
-                attn_bias = None
+
             x = torch.nn.functional.scaled_dot_product_attention(
                 q.transpose(-1, -2),
                 k.transpose(-1, -2),
@@ -210,7 +210,7 @@

 class AttentionCl(nn.Module):
     """ Channels-last multi-head attention (B, ..., C) """
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]

     def __init__(
             self,
@@ -232,7 +232,7 @@
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@@ -249,13 +249,13 @@
         else:
             q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)

-        if self.fast_attn:
+        if self.fused_attn:
+            attn_bias = None
             if self.rel_pos is not None:
                 attn_bias = self.rel_pos.get_bias()
             elif shared_rel_pos is not None:
                 attn_bias = shared_rel_pos
-            else:
-                attn_bias = None
+
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
@@ -23,7 +23,8 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, _assert
+from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
+    _assert, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply
@@ -86,6 +87,7 @@ class WindowAttention(nn.Module):
     """ Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports shifted and non-shifted windows.
     """
+    fused_attn: torch.jit.Final[bool]

     def __init__(
             self,
@@ -116,6 +118,7 @@
         head_dim = head_dim or dim // num_heads
         attn_dim = head_dim * num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn(experimental=True)  # NOTE not tested for prime-time yet

         # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
         self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
@@ -147,21 +150,30 @@
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)

-        q = q * self.scale
-        attn = (q @ k.transpose(-2, -1))
-        attn = attn + self._get_rel_pos_bias()
-
-        if mask is not None:
-            num_win = mask.shape[0]
-            attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
-            attn = attn.view(-1, self.num_heads, N, N)
-            attn = self.softmax(attn)
+        if self.fused_attn:
+            attn_mask = self._get_rel_pos_bias()
+            if mask is not None:
+                num_win = mask.shape[0]
+                mask = mask.view(1, num_win, 1, N, N).expand(B_ // num_win, -1, self.num_heads, -1, -1)
+                attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N)
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn + self._get_rel_pos_bias()
+            if mask is not None:
+                num_win = mask.shape[0]
+                attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+                attn = attn.view(-1, self.num_heads, N, N)
             attn = self.softmax(attn)
+            attn = self.attn_drop(attn)
+            x = attn @ v

-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B_, N, -1)
+        x = x.transpose(1, 2).reshape(B_, N, -1)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
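Aside (not part of the diff): unlike most modules in this commit, the swin WindowAttention above asks for use_fused_attn(experimental=True), so it only takes the fused path when the global level is set to 2. A minimal sketch of opting in; the model name is just an example:

import os
os.environ['TIMM_FUSED_ATTN'] = '2'  # level 2; only takes effect if set before timm is first imported

import timm
from timm.layers import set_fused_attn

set_fused_attn(enable=True, experimental=True)  # equivalent after import: level 2
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)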
@@ -20,7 +20,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._registry import register_model
@@ -68,6 +68,8 @@ Size_ = Tuple[int, int]
 class LocallyGroupedAttn(nn.Module):
     """ LSA: self attention within a group
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1):
         assert ws != 1
         super(LocallyGroupedAttn, self).__init__()
@@ -77,6 +79,7 @@
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -100,12 +103,22 @@
         x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3)
         qkv = self.qkv(x).reshape(
             B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
-        x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
+        q, k, v = qkv.unbind(0)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
+        x = x.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
         if pad_r > 0 or pad_b > 0:
             x = x[:, :H, :W, :].contiguous()
         x = x.reshape(B, N, C)
@@ -152,6 +165,8 @@
 class GlobalSubSampleAttn(nn.Module):
     """ GSA: using a key to summarize the information for a group to be efficient.
     """
+    fused_attn: torch.jit.Final[bool]
+
     def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1):
         super().__init__()
         assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
@@ -160,6 +175,7 @@
         self.num_heads = num_heads
         head_dim = dim // num_heads
         self.scale = head_dim ** -0.5
+        self.fused_attn = use_fused_attn()

         self.q = nn.Linear(dim, dim, bias=True)
         self.kv = nn.Linear(dim, dim * 2, bias=True)
@@ -184,13 +200,21 @@
             x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
             x = self.norm(x)
         kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        k, v = kv[0], kv[1]
+        k, v = kv.unbind(0)

-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fused_attn:
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v

-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)

@@ -200,8 +224,18 @@
 class Block(nn.Module):

     def __init__(
-            self, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., drop_path=0.,
-            act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            proj_drop=0.,
+            attn_drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+            sr_ratio=1,
+            ws=None,
+    ):
         super().__init__()
         self.norm1 = norm_layer(dim)
         if ws is None:
@@ -210,14 +244,20 @@
             self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio)
         else:
             self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws)
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
         self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=proj_drop,
+        )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x, size: Size_):
-        x = x + self.drop_path(self.attn(self.norm1(x), size))
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x + self.drop_path1(self.attn(self.norm1(x), size))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
         return x


@@ -225,7 +265,9 @@ class PosConv(nn.Module):
     # PEG from https://arxiv.org/abs/2102.10882
     def __init__(self, in_chans, embed_dim=768, stride=1):
         super(PosConv, self).__init__()
-        self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
+        self.proj = nn.Sequential(
+            nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim),
+        )
         self.stride = stride

     def forward(self, x, size: Size_):
@@ -11,7 +11,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
+from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model
@@ -90,7 +90,7 @@ class SpatialMlp(nn.Module):


 class Attention(nn.Module):
-    fast_attn: torch.jit.Final[bool]
+    fused_attn: torch.jit.Final[bool]

     def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
         super().__init__()
@@ -99,7 +99,7 @@
         head_dim = round(dim // num_heads * head_dim_ratio)
         self.head_dim = head_dim
         self.scale = head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -111,7 +111,7 @@
         x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
         q, k, v = x.unbind(0)

-        if self.fast_attn:
+        if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 dropout_p=self.attn_drop.p,
@@ -38,7 +38,7 @@ from torch.jit import Final
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout
+    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -50,7 +50,7 @@ _logger = logging.getLogger(__name__)


 class Attention(nn.Module):
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]

     def __init__(
             self,
@@ -67,7 +67,7 @@
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = use_fused_attn()

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -82,7 +82,7 @@
         q, k, v = qkv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)

-        if self.fast_attn:
+        if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 dropout_p=self.attn_drop.p,
@@ -215,7 +215,7 @@ class ParallelScalingBlock(nn.Module):
     Based on:
       'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
     """
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]

     def __init__(
             self,
@@ -236,7 +236,7 @@
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = use_fused_attn()
         mlp_hidden_dim = int(mlp_ratio * dim)
         in_proj_out_dim = mlp_hidden_dim + 3 * dim

@@ -279,7 +279,7 @@
         q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
-        if self.fast_attn:
+        if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
                 dropout_p=self.attn_drop.p,
@@ -1219,6 +1219,10 @@ default_cfgs = generate_default_cfgs({
     'vit_large_patch14_clip_224.openai': _cfg(
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_large_patch14_clip_336.openai': _cfg(
+        hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),

     # experimental (may be removed)
     'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
@@ -15,7 +15,7 @@ from torch.jit import Final
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
+from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model

@@ -25,7 +25,7 @@ _logger = logging.getLogger(__name__)


 class RelPosAttention(nn.Module):
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]

     def __init__(
             self,
@@ -43,7 +43,7 @@
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -60,13 +60,13 @@
         q = self.q_norm(q)
         k = self.k_norm(k)

-        if self.fast_attn:
+        if self.fused_attn:
+            attn_bias = None
             if self.rel_pos is not None:
                 attn_bias = self.rel_pos.get_bias()
             elif shared_rel_pos is not None:
                 attn_bias = shared_rel_pos
-            else:
-                attn_bias = None
+
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,