Add final attr for fast_attn on beit / eva

This commit is contained in:
Ross Wightman 2023-03-28 08:40:40 -07:00
parent a84abe6656
commit ac67098147
2 changed files with 6 additions and 2 deletions

View File

@ -40,7 +40,7 @@ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
import math import math
from functools import partial from functools import partial
from typing import Callable, Optional, Tuple, Union from typing import Callable, Final, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -80,6 +80,8 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
class Attention(nn.Module): class Attention(nn.Module):
fast_attn: Final[bool]
def __init__( def __init__(
self, self,
dim: int, dim: int,

View File

@ -26,7 +26,7 @@ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
# EVA02 models Copyright (c) 2023 BAAI-Vision # EVA02 models Copyright (c) 2023 BAAI-Vision
import math import math
from typing import Callable, Optional, Tuple, Union from typing import Callable, Final, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -44,6 +44,8 @@ __all__ = ['Eva']
class EvaAttention(nn.Module): class EvaAttention(nn.Module):
fast_attn: Final[bool]
def __init__( def __init__(
self, self,
dim: int, dim: int,