mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add Final annotation to attn_fas to avoid symbol lookup of new scaled_dot_product_attn fn on old PyTorch in jit
This commit is contained in:
parent
621e1b2182
commit
122621daef
@ -42,6 +42,7 @@ from typing import Callable, Optional, Union, Tuple, List
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.jit import Final
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
|
||||
@ -140,6 +141,8 @@ class MaxxVitCfg:
|
||||
|
||||
|
||||
class Attention2d(nn.Module):
|
||||
fast_attn: Final[bool]
|
||||
|
||||
""" multi-head attention for 2D NCHW tensors"""
|
||||
def __init__(
|
||||
self,
|
||||
@ -208,6 +211,8 @@ class Attention2d(nn.Module):
|
||||
|
||||
class AttentionCl(nn.Module):
|
||||
""" Channels-last multi-head attention (B, ..., C) """
|
||||
fast_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
|
@ -33,6 +33,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.jit import Final
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
|
||||
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
@ -51,6 +52,8 @@ _logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
fast_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
|
@ -11,6 +11,7 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.jit import Final
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
@ -25,6 +26,8 @@ _logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RelPosAttention(nn.Module):
|
||||
fast_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
|
Loading…
x
Reference in New Issue
Block a user