Add ideas from 'Scaling ViT to 22-B Params', testing PyTorch 2.0 fused F.scaled_dot_product_attention impl in vit, vit_relpos, maxxvit / coatnet.
parent a3d528524a
commit 621e1b2182
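Every attention change in the hunks below follows the same guarded pattern; a minimal hedged sketch of it outside the diff (the helper name sdpa_or_fallback, the shapes, and the dropout handling are illustrative, not part of the commit):

import torch
import torch.nn.functional as F

def sdpa_or_fallback(q, k, v, attn_bias=None, drop_p=0., training=False):
    # q, k, v: (B, num_heads, N, head_dim); attn_bias broadcastable to (B, num_heads, N, N).
    # Fused PyTorch 2.0 kernel when available, otherwise the classic matmul/softmax path.
    if hasattr(F, 'scaled_dot_product_attention'):
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_bias, dropout_p=drop_p if training else 0.)
    scale = q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(dim=-1)
    attn = F.dropout(attn, p=drop_p, training=training)
    return attn @ v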
@@ -28,7 +28,7 @@ from .linear import Linear
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
@@ -17,6 +17,12 @@ try:
 except ImportError:
     has_apex = False
 
+try:
+    from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
+    has_apex_rmsnorm = True
+except ImportError:
+    has_apex_rmsnorm = False
+
 
 # fast (ie lower precision LN) can be disabled with this flag if issues crop up
 _USE_FAST_NORM = False  # defaulting to False for now

@@ -76,3 +82,32 @@ def fast_layer_norm(
 
     with torch.cuda.amp.autocast(enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
+
+
+def rms_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+):
+    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
+    v = torch.var(x, dim=dims, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight
+    return x
+
+
+def fast_rms_norm(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting() or not has_apex_rmsnorm:
+        return rms_norm(x, normalized_shape, weight, eps)
+
+    if weight is None:
+        return fused_rms_norm(x, normalized_shape, eps)
+    else:
+        return fused_rms_norm_affine(x, weight, normalized_shape, eps)
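A short usage sketch for the rms_norm / fast_rms_norm helpers added above (tensor shapes are illustrative; the import path assumes these land in timm/layers/fast_norm.py as in this diff):

import torch
from timm.layers.fast_norm import rms_norm, fast_rms_norm

x = torch.randn(2, 196, 768)            # (batch, tokens, channels)
w = torch.ones(768)                     # affine scale matching normalized_shape
y = rms_norm(x, [768], w)               # pure PyTorch reference path
y_fast = fast_rms_norm(x, [768], w)     # same result; APEX fused kernel when installed
assert y.shape == x.shape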
@@ -4,12 +4,14 @@ Norm layer definitions that support fast norm and consistent channel arg order (
 
 Hacked together by / Copyright 2022 Ross Wightman
 """
+import numbers
+from typing import Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
 
 
 class GroupNorm(nn.GroupNorm):

@@ -115,3 +117,38 @@ class LayerNormExp2d(nn.LayerNorm):
         else:
             x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
         return x
+
+
+class RmsNorm(nn.Module):
+    """ RmsNorm w/ fast (apex) norm if available
+    """
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
+        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
+        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        return x
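And the module form, exported from timm.layers via the __init__ hunk above (a hedged sketch; shapes and values are illustrative):

import torch
from timm.layers import RmsNorm

norm = RmsNorm(768)                     # channels-last, normalizes the final dim
x = torch.randn(2, 196, 768)
y = norm(x)                             # same shape; affine weight initialized to ones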
@@ -83,8 +83,8 @@ def gen_relative_log_coords(
         pretrained_win_size: Tuple[int, int] = (0, 0),
         mode='swin',
 ):
-    assert mode in ('swin', 'cr', 'rw')
-    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
+    assert mode in ('swin', 'cr')
+    # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
     relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
     relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
     relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))

@@ -100,18 +100,9 @@ def gen_relative_log_coords(
         relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
             1.0 + relative_coords_table.abs()) / math.log2(8)
     else:
-        if mode == 'rw':
-            # cr w/ window size normalization -> [-1,1] log coords
-            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
-            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
-            relative_coords_table *= 8  # scale to -8, 8
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-                1.0 + relative_coords_table.abs())
-            relative_coords_table /= math.log2(9)   # -> [-1, 1]
-        else:
-            # mode == 'cr'
-            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
-                1.0 + relative_coords_table.abs())
+        # mode == 'cr'
+        relative_coords_table = torch.sign(relative_coords_table) * torch.log(
+            1.0 + relative_coords_table.abs())
 
     return relative_coords_table

@@ -141,10 +132,6 @@ class RelPosMlp(nn.Module):
             self.bias_act = nn.Sigmoid()
             self.bias_gain = 16
             mlp_bias = (True, False)
-        elif mode == 'rw':
-            self.bias_act = nn.Tanh()
-            self.bias_gain = 4
-            mlp_bias = True
         else:
             self.bias_act = nn.Identity()
             self.bias_gain = None
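The 'cr' mode kept above maps relative offsets through sign(x) * log(1 + |x|); a small hedged sketch of just that transform (the window size and the indexing argument are illustrative, not from the diff):

import torch

win_size = (7, 7)
coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
table = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'), dim=-1)
table = torch.sign(table) * torch.log(1.0 + table.abs())   # 'cr' log-spaced coords
print(table.shape)   # torch.Size([13, 13, 2])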
@@ -160,6 +160,7 @@ class Attention2d(nn.Module):
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None

@@ -175,15 +176,31 @@ class Attention2d(nn.Module):
         else:
             q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
 
-        attn = (q.transpose(-2, -1) @ k) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q.transpose(-1, -2),
+                k.transpose(-1, -2),
+                v.transpose(-1, -2),
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            ).transpose(-1, -2).reshape(B, -1, H, W)
+        else:
+            q = q * self.scale
+            attn = q.transpose(-2, -1) @ k
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
 
-        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x

@@ -211,6 +228,7 @@ class AttentionCl(nn.Module):
         self.dim_head = dim_head
         self.head_first = head_first
         self.scale = dim_head ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
 
         self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
         self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None

@@ -227,15 +245,30 @@ class AttentionCl(nn.Module):
         else:
             q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
-        x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))
+        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
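The Attention2d / AttentionCl changes above hand the additive relative-position bias to the fused kernel via attn_mask instead of adding it to a materialized attention matrix; a hedged standalone sketch (shapes and the random bias tensor are illustrative):

import torch
import torch.nn.functional as F

B, num_heads, N, head_dim = 2, 8, 49, 32
q, k, v = (torch.randn(B, num_heads, N, head_dim) for _ in range(3))
rel_pos_bias = torch.randn(num_heads, N, N)          # e.g. a RelPosBias/RelPosMlp table

if hasattr(F, 'scaled_dot_product_attention'):       # PyTorch >= 2.0
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)
else:                                                 # explicit fallback, same math
    attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
    attn = (attn + rel_pos_bias).softmax(dim=-1)
    out = attn @ v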
@@ -37,7 +37,7 @@ import torch.utils.checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed
+    resample_abs_pos_embed, RmsNorm
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._pretrained import generate_default_cfgs

@@ -51,28 +51,49 @@ _logger = logging.getLogger(__name__)
 
 
 class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            qkv_bias=False,
+            qk_norm=False,
+            attn_drop=0.,
+            proj_drop=0.,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
     def forward(self, x):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q, k = self.q_norm(q), self.k_norm(k)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
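The qk_norm option added to Attention above normalizes queries and keys per head before the dot product, one of the stabilization ideas from the 22B-parameter ViT paper; a minimal hedged sketch (head count and shapes are illustrative, and the fused call assumes PyTorch >= 2.0):

import torch
import torch.nn as nn
import torch.nn.functional as F

head_dim = 64
q_norm = nn.LayerNorm(head_dim)          # applied over head_dim, per head
k_norm = nn.LayerNorm(head_dim)

q = torch.randn(2, 12, 196, head_dim)    # (B, num_heads, N, head_dim)
k = torch.randn(2, 12, 196, head_dim)
v = torch.randn(2, 12, 196, head_dim)

q, k = q_norm(q), k_norm(k)              # qk norm before attention
out = F.scaled_dot_product_attention(q, k, v)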
@@ -96,6 +117,7 @@ class Block(nn.Module):
             num_heads,
             mlp_ratio=4.,
             qkv_bias=False,
+            qk_norm=False,
             drop=0.,
             attn_drop=0.,
             init_values=None,

@@ -105,13 +127,25 @@ class Block(nn.Module):
     ):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            norm_layer=norm_layer,
+        )
         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

@@ -129,6 +163,7 @@ class ResPostBlock(nn.Module):
             num_heads,
             mlp_ratio=4.,
             qkv_bias=False,
+            qk_norm=False,
             drop=0.,
             attn_drop=0.,
             init_values=None,

@@ -139,11 +174,24 @@ class ResPostBlock(nn.Module):
         super().__init__()
         self.init_values = init_values
 
-        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            norm_layer=norm_layer,
+        )
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -161,8 +209,61 @@ class ResPostBlock(nn.Module):
         return x
 
 
-class ParallelBlock(nn.Module):
+class ParallelScalingBlock(nn.Module):
+    """ Parallel ViT block (MLP & Attention in parallel)
+    Based on:
+      'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
+    """
+    def __init__(
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            qk_norm=False,
+            drop=0.,
+            attn_drop=0.,
+            init_values=None,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            norm_layer=norm_layer,
+        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        y1 = self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+        y2 = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+        x = x + y1 + y2
+        return x
+
+
+class ParallelThingsBlock(nn.Module):
+    """ Parallel ViT block (N parallel attention followed by N parallel MLP)
+    Based on:
+      `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    """
     def __init__(
             self,
             dim,
@@ -170,6 +271,7 @@ class ParallelBlock(nn.Module):
             num_parallel=2,
             mlp_ratio=4.,
             qkv_bias=False,
+            qk_norm=False,
             init_values=None,
             drop=0.,
             attn_drop=0.,

@@ -184,13 +286,26 @@ class ParallelBlock(nn.Module):
         for _ in range(num_parallel):
             self.attns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
+                ('attn', Attention(
+                    dim,
+                    num_heads=num_heads,
+                    qkv_bias=qkv_bias,
+                    qk_norm=qk_norm,
+                    attn_drop=attn_drop,
+                    proj_drop=drop,
+                    norm_layer=norm_layer,
+                )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                 ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
             ])))
             self.ffns.append(nn.Sequential(OrderedDict([
                 ('norm', norm_layer(dim)),
-                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
+                ('mlp', Mlp(
+                    dim,
+                    hidden_features=int(dim * mlp_ratio),
+                    act_layer=act_layer,
+                    drop=drop,
+                )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
                 ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
             ])))
@@ -232,6 +347,7 @@ class VisionTransformer(nn.Module):
             num_heads=12,
             mlp_ratio=4.,
             qkv_bias=True,
+            qk_norm=False,
             init_values=None,
             class_token=True,
             no_embed_class=False,

@@ -305,6 +421,7 @@ class VisionTransformer(nn.Module):
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
                 init_values=init_values,
                 drop=drop_rate,
                 attn_drop=attn_drop_rate,
@@ -641,9 +758,8 @@ def checkpoint_filter_fn(
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     import re
     out_dict = {}
-    if 'model' in state_dict:
-        # For deit models
-        state_dict = state_dict['model']
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('state_dict', state_dict)
 
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -1129,6 +1245,9 @@ default_cfgs = generate_default_cfgs({
         url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
         hf_hub_id='timm/',
         input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
+
+    'vit_large_patch14_xp_224.untrained': _cfg(url=''),
+    'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
 })
 
 
@@ -1566,7 +1685,7 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
     Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
+        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock)
     model = _create_vision_transformer(
         'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1577,7 +1696,8 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
     """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
     Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock)
     model = _create_vision_transformer(
         'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
@@ -1625,3 +1745,29 @@ def flexivit_large(pretrained=False, **kwargs):
     model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
     model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
+
+
+@register_model
+def vit_large_patch14_xp_224(pretrained=False, **kwargs):
+    """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True,
+        norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch14_xp_224(pretrained=False, **kwargs):
+    """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True,
+        norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
+    )
+    model = _create_vision_transformer(
+        'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
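A quick way to exercise the models registered above (a sketch; it assumes a timm checkout that includes this commit, and no pretrained weights exist for these configs):

import torch
import timm

model = timm.create_model('vit_large_patch14_xp_224', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])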
@@ -25,14 +25,27 @@ _logger = logging.getLogger(__name__)
 
 
 class RelPosAttention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            qkv_bias=False,
+            qk_norm=False,
+            rel_pos_cls=None,
+            attn_drop=0.,
+            proj_drop=0.,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
+
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
@@ -40,18 +53,35 @@ class RelPosAttention(nn.Module):
 
     def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
 
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        if self.rel_pos is not None:
-            attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
-        elif shared_rel_pos is not None:
-            attn = attn + shared_rel_pos
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
+        if self.fast_attn:
+            if self.rel_pos is not None:
+                attn_bias = self.rel_pos.get_bias()
+            elif shared_rel_pos is not None:
+                attn_bias = shared_rel_pos
+            else:
+                attn_bias = None
+            x = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_bias,
+                dropout_p=self.attn_drop.p,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            if self.rel_pos is not None:
+                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
+            elif shared_rel_pos is not None:
+                attn = attn + shared_rel_pos
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
@@ -70,18 +100,42 @@ class LayerScale(nn.Module):
 class RelPosBlock(nn.Module):
 
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
-            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            qk_norm=False,
+            rel_pos_cls=None,
+            init_values=None,
+            drop=0.,
+            attn_drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = RelPosAttention(
-            dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
+            dim,
+            num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            rel_pos_cls=rel_pos_cls,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

@@ -94,17 +148,41 @@ class RelPosBlock(nn.Module):
 class ResPostRelPosBlock(nn.Module):
 
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
-            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            qk_norm=False,
+            rel_pos_cls=None,
+            init_values=None,
+            drop=0.,
+            attn_drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
         self.init_values = init_values
 
         self.attn = RelPosAttention(
-            dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
+            dim,
+            num_heads,
+            qkv_bias=qkv_bias,
+            qk_norm=qk_norm,
+            rel_pos_cls=rel_pos_cls,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
         self.norm1 = norm_layer(dim)
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
         self.norm2 = norm_layer(dim)
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -144,6 +222,7 @@ class VisionTransformerRelPos(nn.Module):
             num_heads=12,
             mlp_ratio=4.,
             qkv_bias=True,
+            qk_norm=False,
             init_values=1e-6,
             class_token=False,
             fc_norm=False,

@@ -171,6 +250,7 @@ class VisionTransformerRelPos(nn.Module):
             num_heads (int): number of attention heads
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
+            qk_norm (bool): Enable normalization of query and key in attention
             init_values: (float): layer-scale init values
             class_token (bool): use class token (default: False)
             fc_norm (bool): use pre classifier norm instead of pre-pool

@@ -197,18 +277,19 @@ class VisionTransformerRelPos(nn.Module):
         self.grad_checkpointing = False
 
         self.patch_embed = embed_layer(
-            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
         feat_size = self.patch_embed.grid_size
 
         rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
         if rel_pos_type.startswith('mlp'):
             if rel_pos_dim:
                 rel_pos_args['hidden_dim'] = rel_pos_dim
-            # FIXME experimenting with different relpos log coord configs
             if 'swin' in rel_pos_type:
                 rel_pos_args['mode'] = 'swin'
-            elif 'rw' in rel_pos_type:
-                rel_pos_args['mode'] = 'rw'
             rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
         else:
             rel_pos_cls = partial(RelPosBias, **rel_pos_args)
@@ -223,9 +304,19 @@ class VisionTransformerRelPos(nn.Module):
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList([
             block_fn(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls,
-                init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
-                norm_layer=norm_layer, act_layer=act_layer)
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_norm=qk_norm,
+                rel_pos_cls=rel_pos_cls,
+                init_values=init_values,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+            )
             for i in range(depth)])
         self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()