From 965d0a2d363668b7f8d1794e45c52d525bdb6278 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Apr 2023 12:04:33 -0700 Subject: [PATCH] fast_attn -> fused_attn, implement global config for enable/disable fused_attn, add to more models. vit clip openai 336 weights. --- timm/layers/__init__.py | 4 +- timm/layers/config.py | 38 ++++++++++- timm/models/beit.py | 11 ++-- timm/models/cait.py | 27 +++++--- timm/models/davit.py | 16 +++-- timm/models/eva.py | 43 ++++++++++--- timm/models/maxxvit.py | 22 +++---- timm/models/swin_transformer.py | 38 +++++++---- timm/models/twins.py | 82 ++++++++++++++++++------ timm/models/visformer.py | 8 +-- timm/models/vision_transformer.py | 18 ++++-- timm/models/vision_transformer_relpos.py | 12 ++-- 12 files changed, 225 insertions(+), 94 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 576af8d1..2a14ba0c 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -5,8 +5,8 @@ from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbeddi from .blur_pool import BlurPool2d from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer -from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ - set_layer_config +from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \ + set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn from .conv2d_same import Conv2dSame, conv2d_same from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn diff --git a/timm/layers/config.py b/timm/layers/config.py index f07b9d78..47d5d0a3 100644 --- a/timm/layers/config.py +++ b/timm/layers/config.py @@ -1,10 +1,14 @@ """ Model / Layer Config singleton state """ +import os +import warnings from typing import Any, Optional +import torch + __all__ = [ - 'is_exportable', 'is_scriptable', 'is_no_jit', - 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' + 'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn' ] # Set to True if prefer to have layers with no jit optimization (includes activations) @@ -22,6 +26,14 @@ _EXPORTABLE = False _SCRIPTABLE = False +# use torch.scaled_dot_product_attention where possible +_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention') +if 'TIMM_FUSED_ATTN' in os.environ: + _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN']) +else: + _USE_FUSED_ATTN = 1 # 0 == off, 1 == on (for tested use), 2 == on (for experimental use) + + def is_no_jit(): return _NO_JIT @@ -113,3 +125,25 @@ class set_layer_config: global _NO_ACTIVATION_JIT _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev return False + + +def use_fused_attn(experimental: bool = False) -> bool: + # NOTE: ONNX export cannot handle F.scaled_dot_product_attention as of pytorch 2.0 + if not _HAS_FUSED_ATTN or _EXPORTABLE: + return False + if experimental: + return _USE_FUSED_ATTN > 1 + return _USE_FUSED_ATTN > 0 + + +def set_fused_attn(enable: bool = True, experimental: bool = False): + global _USE_FUSED_ATTN + if not _HAS_FUSED_ATTN: + warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.') + return + if experimental and enable: + _USE_FUSED_ATTN = 2 + elif enable: + _USE_FUSED_ATTN = 1 + 
else: + _USE_FUSED_ATTN = 0 diff --git a/timm/models/beit.py b/timm/models/beit.py index 457d86eb..d3d77ba8 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -48,7 +48,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_ +from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -80,7 +80,7 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor: class Attention(nn.Module): - fast_attn: Final[bool] + fused_attn: Final[bool] def __init__( self, @@ -99,7 +99,7 @@ class Attention(nn.Module): head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = head_dim ** -0.5 - self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: @@ -142,15 +142,14 @@ class Attention(nn.Module): qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim - if self.fast_attn: + if self.fused_attn: + rel_pos_bias = None if self.relative_position_bias_table is not None: rel_pos_bias = self._get_rel_pos_bias() if shared_rel_pos_bias is not None: rel_pos_bias = rel_pos_bias + shared_rel_pos_bias elif shared_rel_pos_bias is not None: rel_pos_bias = shared_rel_pos_bias - else: - rel_pos_bias = None x = F.scaled_dot_product_attention( q, k, v, diff --git a/timm/models/cait.py b/timm/models/cait.py index 98c58397..4cc4fdd4 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model @@ -73,12 +73,15 @@ default_cfgs = dict( class ClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py - # with slight modifications to do CA + # with slight modifications to do CA + fused_attn: torch.jit.Final[bool] + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.q = nn.Linear(dim, dim, bias=qkv_bias) self.k = nn.Linear(dim, dim, bias=qkv_bias) @@ -91,15 +94,21 @@ class ClassAttn(nn.Module): B, N, C = x.shape q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - - q = q * self.scale v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) - attn = (q @ k.transpose(-2, -1)) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + if self.fused_attn: + x_cls = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x_cls = 
attn @ v - x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = x_cls.transpose(1, 2).reshape(B, 1, C) x_cls = self.proj(x_cls) x_cls = self.proj_drop(x_cls) @@ -179,7 +188,7 @@ class TalkingHeadAttn(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) + attn = q @ k.transpose(-2, -1) attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) diff --git a/timm/models/davit.py b/timm/models/davit.py index bbc5e421..181d84b9 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from torch import Tensor from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer +from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function @@ -229,6 +229,7 @@ class WindowAttention(nn.Module): num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True """ + fused_attn: torch.jit.Final[bool] def __init__(self, dim, window_size, num_heads, qkv_bias=True): super().__init__() @@ -237,6 +238,7 @@ class WindowAttention(nn.Module): self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) @@ -249,11 +251,15 @@ class WindowAttention(nn.Module): qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = self.softmax(attn) + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v) + else: + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + x = attn @ v - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = x.transpose(1, 2).reshape(B_, N, C) x = self.proj(x) return x diff --git a/timm/models/eva.py b/timm/models/eva.py index b5cd1fec..be7f3996 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -35,7 +35,8 @@ from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \ - apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple + apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \ + to_2tuple, use_fused_attn from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -44,7 +45,7 @@ __all__ = ['Eva'] class EvaAttention(nn.Module): - fast_attn: Final[bool] + fused_attn: Final[bool] def __init__( self, @@ -76,7 +77,7 @@ class EvaAttention(nn.Module): head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = head_dim ** -0.5 - self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + self.fused_attn = use_fused_attn() if qkv_fused: self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) @@ -121,7 +122,7 @@ class 
EvaAttention(nn.Module): q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v) k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v) - if self.fast_attn: + if self.fused_attn: x = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, @@ -488,11 +489,8 @@ class Eva(nn.Module): def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.zeros_(m.bias) @torch.jit.ignore def no_weight_decay(self): @@ -787,6 +785,11 @@ default_cfgs = generate_default_cfgs({ #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt', num_classes=768, ), + 'eva02_large_patch14_clip_336.clip': _cfg( + # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt', + input_size=(3, 336, 336), crop_pct=1.0, + num_classes=768, + ), 'eva02_enormous_patch14_clip_224.clip': _cfg( #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt', num_classes=1024, @@ -1011,6 +1014,28 @@ def eva02_large_patch14_clip_224(pretrained=False, **kwargs): return model +@register_model +def eva02_large_patch14_clip_336(pretrained=False, **kwargs): + # A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large + model_args = dict( + img_size=336, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4 * 2 / 3, + qkv_fused=False, + swiglu_mlp=True, + scale_mlp=True, + scale_attn_inner=True, + use_rot_pos_emb=True, + ref_feat_shape=(16, 16), # 224/14 + global_pool='token', + ) + model = _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs): # A EVA-CLIP specific variant that uses residual post-norm in blocks diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 0773ffab..17b64a9d 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -48,7 +48,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert -from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf +from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint_seq @@ -140,7 +140,7 @@ class MaxxVitCfg: class Attention2d(nn.Module): - fast_attn: Final[bool] + fused_attn: Final[bool] """ multi-head attention for 2D NCHW tensors""" def __init__( @@ -162,7 +162,7 @@ class Attention2d(nn.Module): self.dim_head = dim_head self.head_first = head_first self.scale = dim_head ** -0.5 - self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME + self.fused_attn = use_fused_attn() self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None @@ -178,13 +178,13 @@ class Attention2d(nn.Module): else: q, k, v = self.qkv(x).reshape(B, 3, 
self.num_heads, self.dim_head, -1).unbind(1) - if self.fast_attn: + if self.fused_attn: + attn_bias = None if self.rel_pos is not None: attn_bias = self.rel_pos.get_bias() elif shared_rel_pos is not None: attn_bias = shared_rel_pos - else: - attn_bias = None + x = torch.nn.functional.scaled_dot_product_attention( q.transpose(-1, -2), k.transpose(-1, -2), @@ -210,7 +210,7 @@ class Attention2d(nn.Module): class AttentionCl(nn.Module): """ Channels-last multi-head attention (B, ..., C) """ - fast_attn: Final[bool] + fused_attn: Final[bool] def __init__( self, @@ -232,7 +232,7 @@ class AttentionCl(nn.Module): self.dim_head = dim_head self.head_first = head_first self.scale = dim_head ** -0.5 - self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME + self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None @@ -249,13 +249,13 @@ class AttentionCl(nn.Module): else: q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2) - if self.fast_attn: + if self.fused_attn: + attn_bias = None if self.rel_pos is not None: attn_bias = self.rel_pos.get_bias() elif shared_rel_pos is not None: attn_bias = shared_rel_pos - else: - attn_bias = None + x = torch.nn.functional.scaled_dot_product_attention( q, k, v, attn_mask=attn_bias, diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index f4df31c6..5cc99f23 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -23,7 +23,8 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, _assert +from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ + _assert, use_fused_attn from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply @@ -86,6 +87,7 @@ class WindowAttention(nn.Module): """ Window based multi-head self attention (W-MSA) module with relative position bias. It supports shifted and non-shifted windows. 
""" + fused_attn: torch.jit.Final[bool] def __init__( self, @@ -116,6 +118,7 @@ class WindowAttention(nn.Module): head_dim = head_dim or dim // num_heads attn_dim = head_dim * num_heads self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn(experimental=True) # NOTE not tested for prime-time yet # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) @@ -147,21 +150,30 @@ class WindowAttention(nn.Module): qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - attn = attn + self._get_rel_pos_bias() - - if mask is not None: - num_win = mask.shape[0] - attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) + if self.fused_attn: + attn_mask = self._get_rel_pos_bias() + if mask is not None: + num_win = mask.shape[0] + mask = mask.view(1, num_win, 1, N, N).expand(B_ // num_win, -1, self.num_heads, -1, -1) + attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N) + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p, + ) else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn + self._get_rel_pos_bias() + if mask is not None: + num_win = mask.shape[0] + attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = attn @ v - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, -1) + x = x.transpose(1, 2).reshape(B_, N, -1) x = self.proj(x) x = self.proj_drop(x) return x diff --git a/timm/models/twins.py b/timm/models/twins.py index 25fc95c7..ddf7897d 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -20,7 +20,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from timm.layers import Mlp, DropPath, to_2tuple, trunc_normal_, use_fused_attn from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._registry import register_model @@ -68,6 +68,8 @@ Size_ = Tuple[int, int] class LocallyGroupedAttn(nn.Module): """ LSA: self attention within a group """ + fused_attn: torch.jit.Final[bool] + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): assert ws != 1 super(LocallyGroupedAttn, self).__init__() @@ -77,6 +79,7 @@ class LocallyGroupedAttn(nn.Module): self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.qkv = nn.Linear(dim, dim * 3, bias=True) self.attn_drop = nn.Dropout(attn_drop) @@ -100,12 +103,22 @@ class LocallyGroupedAttn(nn.Module): x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) qkv = self.qkv(x).reshape( B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) - q, k, v = qkv[0], qkv[1], qkv[2] - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) - x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + q, k, v = qkv.unbind(0) + + 
if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = x.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) if pad_r > 0 or pad_b > 0: x = x[:, :H, :W, :].contiguous() x = x.reshape(B, N, C) @@ -152,6 +165,8 @@ class LocallyGroupedAttn(nn.Module): class GlobalSubSampleAttn(nn.Module): """ GSA: using a key to summarize the information for a group to be efficient. """ + fused_attn: torch.jit.Final[bool] + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): super().__init__() assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." @@ -160,6 +175,7 @@ class GlobalSubSampleAttn(nn.Module): self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn() self.q = nn.Linear(dim, dim, bias=True) self.kv = nn.Linear(dim, dim * 2, bias=True) @@ -184,13 +200,21 @@ class GlobalSubSampleAttn(nn.Module): x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) x = self.norm(x) kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - k, v = kv[0], kv[1] + k, v = kv.unbind(0) - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + if self.fused_attn: + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v - x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) @@ -200,8 +224,18 @@ class GlobalSubSampleAttn(nn.Module): class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): + self, + dim, + num_heads, + mlp_ratio=4., + proj_drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ws=None, + ): super().__init__() self.norm1 = norm_layer(dim) if ws is None: @@ -210,14 +244,20 @@ class Block(nn.Module): self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio) else: self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() def forward(self, x, size: Size_): - x = x + self.drop_path(self.attn(self.norm1(x), size)) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.attn(self.norm1(x), size)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) return x @@ -225,7 +265,9 @@ class PosConv(nn.Module): # PEG from https://arxiv.org/abs/2102.10882 def __init__(self, in_chans, embed_dim=768, stride=1): super(PosConv, self).__init__() - self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), ) + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), + ) self.stride = stride def forward(self, x, size: Size_): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index d44d56fd..ae83c963 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier +from timm.layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model @@ -90,7 +90,7 @@ class SpatialMlp(nn.Module): class Attention(nn.Module): - fast_attn: torch.jit.Final[bool] + fused_attn: torch.jit.Final[bool] def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): super().__init__() @@ -99,7 +99,7 @@ class Attention(nn.Module): head_dim = round(dim // num_heads * head_dim_ratio) self.head_dim = head_dim self.scale = head_dim ** -0.5 - self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME + self.fused_attn = use_fused_attn() self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) self.attn_drop = nn.Dropout(attn_drop) @@ -111,7 +111,7 @@ class Attention(nn.Module): x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) q, k, v = x.unbind(0) - if self.fast_attn: + if self.fused_attn: x = torch.nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=self.attn_drop.p, diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 9034bc5d..2c47c1f0 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -38,7 +38,7 @@ from torch.jit import Final from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \ - resample_abs_pos_embed, RmsNorm, PatchDropout + resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv from ._registry import generate_default_cfgs, register_model, register_model_deprecations @@ -50,7 +50,7 @@ _logger = logging.getLogger(__name__) class Attention(nn.Module): - fast_attn: Final[bool] + fused_attn: Final[bool] def __init__( self, @@ -67,7 +67,7 @@ class Attention(nn.Module): self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 - self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME + self.fused_attn = use_fused_attn() self.qkv = 
nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -82,7 +82,7 @@ class Attention(nn.Module):
         q, k, v = qkv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)
 
-        if self.fast_attn:
+        if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 dropout_p=self.attn_drop.p,
@@ -215,7 +215,7 @@ class ParallelScalingBlock(nn.Module):
     Based on:
       'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
     """
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]
 
     def __init__(
             self,
@@ -236,7 +236,7 @@ class ParallelScalingBlock(nn.Module):
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = use_fused_attn()
 
         mlp_hidden_dim = int(mlp_ratio * dim)
         in_proj_out_dim = mlp_hidden_dim + 3 * dim
@@ -279,7 +279,7 @@ class ParallelScalingBlock(nn.Module):
         q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
-        if self.fast_attn:
+        if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
                 dropout_p=self.attn_drop.p,
@@ -1219,6 +1219,10 @@ default_cfgs = generate_default_cfgs({
     'vit_large_patch14_clip_224.openai': _cfg(
         hf_hub_id='timm/',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
+    'vit_large_patch14_clip_336.openai': _cfg(
+        hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
 
     # experimental (may be removed)
     'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py
index a511f66b..239dfada 100644
--- a/timm/models/vision_transformer_relpos.py
+++ b/timm/models/vision_transformer_relpos.py
@@ -15,7 +15,7 @@ from torch.jit import Final
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
+from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model
@@ -25,7 +25,7 @@ _logger = logging.getLogger(__name__)
 
 class RelPosAttention(nn.Module):
-    fast_attn: Final[bool]
+    fused_attn: Final[bool]
 
     def __init__(
             self,
@@ -43,7 +43,7 @@ class RelPosAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
-        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+        self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
@@ -60,13 +60,13 @@ class RelPosAttention(nn.Module):
         q = self.q_norm(q)
         k = self.k_norm(k)
 
-        if self.fast_attn:
+        if self.fused_attn:
+            attn_bias = None
             if self.rel_pos is not None:
                 attn_bias = self.rel_pos.get_bias()
             elif shared_rel_pos is not None:
                 attn_bias = shared_rel_pos
-            else:
-                attn_bias = None
+
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,