Implement patch dropout for eva / vision_transformer; refactor and improve consistency of dropout args across all ViT-based models

patch_drop_refactor
Ross Wightman 2023-04-07 14:43:15 -07:00
parent 1bb3989b61
commit 4d135421a3
32 changed files with 1899 additions and 853 deletions
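For orientation, a hedged sketch of the dropout-argument scheme this refactor converges on. The keyword names below are taken from the updated constructors in this diff; the model name and rates are illustrative, and not every architecture accepts patch_drop_rate.

import timm

model = timm.create_model(
    'vit_base_patch16_224',   # any vit / eva style model touched by this refactor
    pretrained=False,
    drop_rate=0.1,            # classifier dropout, applied via the new head_drop
    pos_drop_rate=0.0,        # dropout right after the position embedding
    patch_drop_rate=0.3,      # PatchDropout rate (eva / vision_transformer)
    proj_drop_rate=0.0,       # dropout inside blocks (attn projection + MLP)
    attn_drop_rate=0.0,       # dropout on the attention weights
    drop_path_rate=0.1,       # stochastic depth, unchanged by this commit
)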

View File

@ -33,12 +33,14 @@ from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_dropout import PatchDropout
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct

View File

@ -0,0 +1,53 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
class PatchDropout(nn.Module):
"""
https://arxiv.org/abs/2212.00794
"""
return_indices: torch.jit.Final[bool]
def __init__(
self,
prob: float = 0.5,
num_prefix_tokens: int = 1,
ordered: bool = False,
return_indices: bool = False,
):
super().__init__()
assert 0 <= prob < 1.
self.prob = prob
self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
self.ordered = ordered
self.return_indices = return_indices
def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
if not self.training or self.prob == 0.:
if self.return_indices:
return x, None
return x
if self.num_prefix_tokens:
prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
else:
prefix_tokens = None
B = x.shape[0]
L = x.shape[1]
num_keep = max(1, int(L * (1. - self.prob)))
keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
if self.ordered:
# NOTE does not need to maintain patch order in typical transformer use,
# but possibly useful for debug / visualization
keep_indices = keep_indices.sort(dim=-1)[0]
x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
if prefix_tokens is not None:
x = torch.cat((prefix_tokens, x), dim=1)
if self.return_indices:
return x, keep_indices
return x
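A minimal usage sketch for the new PatchDropout layer (shapes and the drop probability are illustrative; the import path follows the layers/__init__.py change above).

import torch
from timm.layers import PatchDropout

tokens = torch.randn(2, 197, 768)   # [B, 1 CLS + 196 patch tokens, C]
patch_drop = PatchDropout(prob=0.5, num_prefix_tokens=1, return_indices=True)

patch_drop.train()
out, keep_indices = patch_drop(tokens)
print(out.shape)            # torch.Size([2, 99, 768]) -- CLS kept, 98 of 196 patches kept
print(keep_indices.shape)   # torch.Size([2, 98]) -- per-sample indices of the kept patches

patch_drop.eval()
out, keep_indices = patch_drop(tokens)
print(out.shape, keep_indices)   # torch.Size([2, 197, 768]) None -- no-op outside training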

View File

@ -194,6 +194,8 @@ def rot(x):
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
if sin_emb.ndim == 3:
return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
return x * cos_emb + rot(x) * sin_emb
@ -205,9 +207,17 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
def apply_rot_embed_cat(x: torch.Tensor, emb):
sin_emb, cos_emb = emb.tensor_split(2, -1)
if sin_emb.ndim == 3:
return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
return x * cos_emb + rot(x) * sin_emb
def apply_keep_indices_nlc(x, pos_embed, keep_indices):
pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
return pos_embed
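A small sketch of how this helper pairs with PatchDropout (tensor shapes illustrative): after patch tokens are dropped, the shared rotary embedding table is gathered to the same kept positions so every surviving token still receives its own rotation.

import torch
from timm.layers import apply_keep_indices_nlc

B, L, dim, num_keep = 2, 196, 64, 98
x = torch.randn(B, num_keep, dim)                 # patch tokens after PatchDropout
rot_pos_embed = torch.randn(L, dim)               # shared rope table, one row per patch position
keep_indices = torch.stack([torch.randperm(L)[:num_keep] for _ in range(B)])

rope_kept = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
print(rope_kept.shape)   # torch.Size([2, 98, 64]) -- per-sample, aligned with the kept tokens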
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,

View File

@ -279,6 +279,8 @@ class Beit(nn.Module):
swiglu_mlp: bool = False,
scale_mlp: bool = False,
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_layer: Callable = LayerNorm,
@ -306,7 +308,7 @@ class Beit(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
@ -325,7 +327,7 @@ class Beit(nn.Module):
mlp_ratio=mlp_ratio,
scale_mlp=scale_mlp,
swiglu_mlp=swiglu_mlp,
proj_drop=drop_rate,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
@ -337,6 +339,7 @@ class Beit(nn.Module):
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
@ -417,6 +420,7 @@ class Beit(nn.Module):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
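A hedged sketch of the resulting Beit behaviour (class and argument names from the constructor above, values illustrative): drop_rate now only feeds the new head_drop before the classifier, while block-level dropout is controlled separately via proj_drop_rate and attn_drop_rate.

from timm.models.beit import Beit

model = Beit(
    img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
    drop_rate=0.2,        # head dropout only (self.head_drop)
    pos_drop_rate=0.0,    # dropout after the abs position embedding
    proj_drop_rate=0.0,   # per-block projection / MLP dropout
    attn_drop_rate=0.0,
)
print(model.head_drop)    # Dropout(p=0.2, inplace=False)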

View File

@ -110,17 +110,38 @@ class LayerScaleBlockClassAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add CA and LayerScale
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
mlp_block=Mlp, init_values=1e-4):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attn_block=ClassAttn,
mlp_block=Mlp,
init_values=1e-4,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = mlp_block(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=proj_drop,
)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
@ -177,17 +198,38 @@ class LayerScaleBlock(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add layerScale
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
mlp_block=Mlp, init_values=1e-4):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attn_block=TalkingHeadAttn,
mlp_block=Mlp,
init_values=1e-4,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = mlp_block(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=proj_drop,
)
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
@ -201,9 +243,22 @@ class Cait(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to adapt to our cait models
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
block_layers=LayerScaleBlock,
block_layers_token=LayerScaleBlockClassAttn,
patch_layer=PatchEmbed,
@ -226,33 +281,50 @@ class Cait(nn.Module):
self.grad_checkpointing = False
self.patch_embed = patch_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
dpr = [drop_path_rate for i in range(depth)]
self.blocks = nn.Sequential(*[
block_layers(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_values)
for i in range(depth)])
self.blocks = nn.Sequential(*[block_layers(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
attn_block=attn_block,
mlp_block=mlp_block,
init_values=init_values,
) for i in range(depth)])
self.blocks_token_only = nn.ModuleList([
block_layers_token(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_token_only, qkv_bias=qkv_bias,
drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block_token_only,
mlp_block=mlp_block_token_only, init_values=init_values)
for i in range(depth_token_only)])
self.blocks_token_only = nn.ModuleList([block_layers_token(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio_token_only,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
attn_block=attn_block_token_only,
mlp_block=mlp_block_token_only,
init_values=init_values,
) for _ in range(depth_token_only)])
self.norm = norm_layer(embed_dim)
self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
@ -322,6 +394,7 @@ class Cait(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
@ -344,77 +417,80 @@ def _create_cait(variant, pretrained=False, **kwargs):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Cait, variant, pretrained,
Cait,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
**kwargs,
)
return model
@register_model
def cait_xxs24_224(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
model = _create_cait('cait_xxs24_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_xxs24_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
model = _create_cait('cait_xxs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_xxs36_224(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
model = _create_cait('cait_xxs36_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_xxs36_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
model = _create_cait('cait_xxs36_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_xs24_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5, **kwargs)
model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5)
model = _create_cait('cait_xs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_s24_224(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
model = _create_cait('cait_s24_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_s24_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
model = _create_cait('cait_s24_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_s36_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6, **kwargs)
model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6)
model = _create_cait('cait_s36_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_m36_384(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6, **kwargs)
model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6)
model = _create_cait('cait_m36_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def cait_m48_448(pretrained=False, **kwargs):
model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6, **kwargs)
model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args)
model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6)
model = _create_cait('cait_m48_448', pretrained=pretrained, **dict(model_args, **kwargs))
return model
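A hedged note on the registration pattern change above: building the per-variant defaults with dict(..., **kwargs) raises a TypeError as soon as a caller repeats one of the default keys, whereas merging with dict(model_args, **kwargs) lets caller kwargs override the defaults cleanly.

model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
user_kwargs = dict(init_values=1e-6, num_classes=0)   # e.g. custom LayerScale init, feature extraction

# old style: dict(patch_size=16, ..., init_values=1e-5, **user_kwargs) -> TypeError (duplicate keyword)
merged = dict(model_args, **user_kwargs)
print(merged['init_values'], merged['num_classes'])    # 1e-06 0 -- caller wins over the variant default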

View File

@ -54,7 +54,7 @@ default_cfgs = {
class ConvRelPosEnc(nn.Module):
""" Convolutional relative position encoding. """
def __init__(self, Ch, h, window):
def __init__(self, head_chs, num_heads, window):
"""
Initialization.
Ch: Channels per head.
@ -70,7 +70,7 @@ class ConvRelPosEnc(nn.Module):
if isinstance(window, int):
# Set the same window size for all attention heads.
window = {window: h}
window = {window: num_heads}
self.window = window
elif isinstance(window, dict):
self.window = window
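A hedged illustration of the renamed ConvRelPosEnc arguments (window values from the crpe_window default used later in this file; the direct import is an assumption about module layout): window maps each depthwise-conv kernel size to the number of attention heads that use it.

from timm.models.coat import ConvRelPosEnc

# 2 heads get a 3x3 depthwise conv, 3 heads a 5x5, 3 heads a 7x7 (2 + 3 + 3 = 8 heads, 64 channels each)
crpe = ConvRelPosEnc(head_chs=64, num_heads=8, window={3: 2, 5: 3, 7: 3})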
@ -84,18 +84,20 @@ class ConvRelPosEnc(nn.Module):
# Determine padding size.
# Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch,
cur_conv = nn.Conv2d(
cur_head_split * head_chs,
cur_head_split * head_chs,
kernel_size=(cur_window, cur_window),
padding=(padding_size, padding_size),
dilation=(dilation, dilation),
groups=cur_head_split*Ch,
groups=cur_head_split * head_chs,
)
self.conv_list.append(cur_conv)
self.head_splits.append(cur_head_split)
self.channel_splits = [x*Ch for x in self.head_splits]
self.channel_splits = [x * head_chs for x in self.head_splits]
def forward(self, q, v, size: Tuple[int, int]):
B, h, N, Ch = q.shape
B, num_heads, N, C = q.shape
H, W = size
_assert(N == 1 + H * W, '')
@ -103,13 +105,13 @@ class ConvRelPosEnc(nn.Module):
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
v_img = v[:, :, 1:, :] # [B, h, H*W, Ch]
v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W)
v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W)
v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels
conv_v_img_list = []
for i, conv in enumerate(self.conv_list):
conv_v_img_list.append(conv(v_img_list[i]))
conv_v_img = torch.cat(conv_v_img_list, dim=1)
conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2)
conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2)
EV_hat = q_img * conv_v_img
EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch].
@ -118,7 +120,15 @@ class ConvRelPosEnc(nn.Module):
class FactorAttnConvRelPosEnc(nn.Module):
""" Factorized attention with convolutional relative position encoding class. """
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
shared_crpe=None,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@ -188,8 +198,20 @@ class ConvPosEnc(nn.Module):
class SerialBlock(nn.Module):
""" Serial block class.
Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
shared_cpe=None,
shared_crpe=None,
):
super().__init__()
# Conv-Attention.
@ -197,13 +219,24 @@ class SerialBlock(nn.Module):
self.norm1 = norm_layer(dim)
self.factoratt_crpe = FactorAttnConvRelPosEnc(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
shared_crpe=shared_crpe,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# MLP.
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=proj_drop,
)
def forward(self, x, size: Tuple[int, int]):
# Conv-Attention.
@ -222,8 +255,19 @@ class SerialBlock(nn.Module):
class ParallelBlock(nn.Module):
""" Parallel block class. """
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None):
def __init__(
self,
dims,
num_heads,
mlp_ratios=[],
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
shared_crpes=None,
):
super().__init__()
# Conv-Attention.
@ -231,16 +275,28 @@ class ParallelBlock(nn.Module):
self.norm13 = norm_layer(dims[2])
self.norm14 = norm_layer(dims[3])
self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[1]
dims[1],
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
shared_crpe=shared_crpes[1],
)
self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[2]
dims[2],
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
shared_crpe=shared_crpes[2],
)
self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[3]
dims[3],
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
shared_crpe=shared_crpes[3],
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -253,7 +309,11 @@ class ParallelBlock(nn.Module):
assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
in_features=dims[1],
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=proj_drop,
)
def upsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map up-sampling. """
@ -319,10 +379,27 @@ class ParallelBlock(nn.Module):
class CoaT(nn.Module):
""" CoaT class. """
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features=None, crpe_window=None, global_pool='token'):
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dims=(0, 0, 0, 0),
serial_depths=(0, 0, 0, 0),
parallel_depth=0,
num_heads=0,
mlp_ratios=(0, 0, 0, 0),
qkv_bias=True,
drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False,
out_features=None,
crpe_window=None,
global_pool='token',
):
super().__init__()
assert global_pool in ('token', 'avg')
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
@ -361,21 +438,31 @@ class CoaT(nn.Module):
self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)
# Convolutional relative position encodings.
self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window)
self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window)
self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window)
self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window)
self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window)
self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window)
self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window)
self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window)
# Disable stochastic depth.
dpr = drop_path_rate
assert dpr == 0.0
skwargs = dict(
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr,
norm_layer=norm_layer,
)
# Serial blocks 1.
self.serial_blocks1 = nn.ModuleList([
SerialBlock(
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe1, shared_crpe=self.crpe1
dim=embed_dims[0],
mlp_ratio=mlp_ratios[0],
shared_cpe=self.cpe1,
shared_crpe=self.crpe1,
**skwargs,
)
for _ in range(serial_depths[0])]
)
@ -383,9 +470,11 @@ class CoaT(nn.Module):
# Serial blocks 2.
self.serial_blocks2 = nn.ModuleList([
SerialBlock(
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe2, shared_crpe=self.crpe2
dim=embed_dims[1],
mlp_ratio=mlp_ratios[1],
shared_cpe=self.cpe2,
shared_crpe=self.crpe2,
**skwargs,
)
for _ in range(serial_depths[1])]
)
@ -393,9 +482,11 @@ class CoaT(nn.Module):
# Serial blocks 3.
self.serial_blocks3 = nn.ModuleList([
SerialBlock(
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe3, shared_crpe=self.crpe3
dim=embed_dims[2],
mlp_ratio=mlp_ratios[2],
shared_cpe=self.cpe3,
shared_crpe=self.crpe3,
**skwargs,
)
for _ in range(serial_depths[2])]
)
@ -403,9 +494,11 @@ class CoaT(nn.Module):
# Serial blocks 4.
self.serial_blocks4 = nn.ModuleList([
SerialBlock(
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe4, shared_crpe=self.crpe4
dim=embed_dims[3],
mlp_ratio=mlp_ratios[3],
shared_cpe=self.cpe4,
shared_crpe=self.crpe4,
**skwargs,
)
for _ in range(serial_depths[3])]
)
@ -415,9 +508,10 @@ class CoaT(nn.Module):
if self.parallel_depth > 0:
self.parallel_blocks = nn.ModuleList([
ParallelBlock(
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4)
dims=embed_dims,
mlp_ratios=mlp_ratios,
shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4),
**skwargs,
)
for _ in range(parallel_depth)]
)
@ -437,10 +531,12 @@ class CoaT(nn.Module):
# CoaT series: Aggregate features of last three scales for classification.
assert embed_dims[1] == embed_dims[2] == embed_dims[3]
self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
else:
# CoaT-Lite series: Use feature of last scale for classification.
self.aggregate = None
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# Initialize weights.
@ -587,6 +683,7 @@ class CoaT(nn.Module):
x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
else:
x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x) -> torch.Tensor:

View File

@ -62,7 +62,14 @@ default_cfgs = {
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.):
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
locality_strength=1.,
):
super().__init__()
self.num_heads = num_heads
self.dim = dim
@ -145,7 +152,14 @@ class GPSA(nn.Module):
class MHSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@ -195,20 +209,48 @@ class MHSA(nn.Module):
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
use_gpsa=True,
locality_strength=1.,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.use_gpsa = use_gpsa
if self.use_gpsa:
self.attn = GPSA(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
locality_strength=locality_strength,
)
else:
self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.attn = MHSA(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=proj_drop,
)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
@ -221,10 +263,28 @@ class ConViT(nn.Module):
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
hybrid_backbone=None,
norm_layer=nn.LayerNorm,
local_up_to_layer=3,
locality_strength=1.,
use_pos_embed=True,
):
super().__init__()
assert global_pool in ('', 'avg', 'token')
embed_dim *= num_heads
@ -245,7 +305,7 @@ class ConViT(nn.Module):
self.num_patches = num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if self.use_pos_embed:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
@ -254,20 +314,22 @@ class ConViT(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=True,
locality_strength=locality_strength)
if i < local_up_to_layer else
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=False)
for i in range(depth)])
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
use_gpsa=i < local_up_to_layer,
locality_strength=locality_strength,
) for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.cls_token, std=.02)
@ -327,6 +389,7 @@ class ConViT(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
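A hedged sketch of the block layout this refactor produces (using the constructor defaults shown above): the first local_up_to_layer blocks use gated positional self-attention, the remainder plain MHSA, all built through the single Block(..., use_gpsa=i < local_up_to_layer) path instead of two separate constructions.

depth, local_up_to_layer = 12, 3   # defaults from the ConViT signature above
layout = ['GPSA' if i < local_up_to_layer else 'MHSA' for i in range(depth)]
print(layout)   # ['GPSA', 'GPSA', 'GPSA', 'MHSA', 'MHSA', ..., 'MHSA']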

View File

@ -42,8 +42,18 @@ class Residual(nn.Module):
class ConvMixer(nn.Module):
def __init__(
self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg',
act_layer=nn.GELU, **kwargs):
self,
dim,
depth,
kernel_size=9,
patch_size=7,
in_chans=3,
num_classes=1000,
global_pool='avg',
drop_rate=0.,
act_layer=nn.GELU,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_features = dim
@ -67,6 +77,7 @@ class ConvMixer(nn.Module):
) for i in range(depth)]
)
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
@torch.jit.ignore
@ -98,6 +109,7 @@ class ConvMixer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
x = self.pooling(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):

View File

@ -130,12 +130,19 @@ class PatchEmbed(nn.Module):
class CrossAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
@ -166,12 +173,26 @@ class CrossAttention(nn.Module):
class CrossAttentionBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CrossAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -182,8 +203,20 @@ class CrossAttentionBlock(nn.Module):
class MultiScaleBlock(nn.Module):
def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
def __init__(
self,
dim,
patches,
depth,
num_heads,
mlp_ratio,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
num_branches = len(dim)
@ -194,8 +227,15 @@ class MultiScaleBlock(nn.Module):
tmp = []
for i in range(depth[d]):
tmp.append(Block(
dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))
dim=dim[d],
num_heads=num_heads[d],
mlp_ratio=mlp_ratio[d],
qkv_bias=qkv_bias,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[i],
norm_layer=norm_layer,
))
if len(tmp) != 0:
self.blocks.append(nn.Sequential(*tmp))
@ -217,14 +257,28 @@ class MultiScaleBlock(nn.Module):
if depth[-1] == 0: # backward capability:
self.fusion.append(
CrossAttentionBlock(
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
dim=dim[d_],
num_heads=nh,
mlp_ratio=mlp_ratio[d],
qkv_bias=qkv_bias,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[-1],
norm_layer=norm_layer,
))
else:
tmp = []
for _ in range(depth[-1]):
tmp.append(CrossAttentionBlock(
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
dim=dim[d_],
num_heads=nh,
mlp_ratio=mlp_ratio[d],
qkv_bias=qkv_bias,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[-1],
norm_layer=norm_layer,
))
self.fusion.append(nn.Sequential(*tmp))
self.revert_projs = nn.ModuleList()
@ -288,10 +342,26 @@ class CrossViT(nn.Module):
"""
def __init__(
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
multi_conv=False, crop_scale=False, qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), global_pool='token',
self,
img_size=224,
img_scale=(1.0, 1.0),
patch_size=(8, 16),
in_chans=3,
num_classes=1000,
embed_dim=(192, 384),
depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
num_heads=(6, 12),
mlp_ratio=(2., 2., 4.),
multi_conv=False,
crop_scale=False,
qkv_bias=True,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
global_pool='token',
):
super().__init__()
assert global_pool in ('token', 'avg')
@ -315,9 +385,15 @@ class CrossViT(nn.Module):
for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
self.patch_embed.append(
PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
PatchEmbed(
img_size=im_s,
patch_size=p,
in_chans=in_chans,
embed_dim=d,
multi_conv=multi_conv,
))
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
total_depth = sum([sum(x[-2:]) for x in depth])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
@ -327,12 +403,22 @@ class CrossViT(nn.Module):
curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
blk = MultiScaleBlock(
embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer)
embed_dim,
num_patches,
block_cfg,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr_,
norm_layer=norm_layer,
)
dpr_ptr += curr_depth
self.blocks.append(blk)
self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.ModuleList([
nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
for i in range(self.num_branches)])
@ -411,6 +497,7 @@ class CrossViT(nn.Module):
def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
xs = [self.head_drop(x) for x in xs]
if pre_logits or isinstance(self.head[0], nn.Identity):
return torch.cat([x for x in xs], dim=1)
return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)
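A hedged sketch of the multi-branch head behaviour above (dimensions illustrative): every branch is pooled, passed through the shared head_drop and its own linear head, and the final prediction is the mean of the per-branch logits; pre_logits instead concatenates the pooled features.

import torch
import torch.nn as nn

embed_dim, num_classes = (192, 384), 1000
heads = nn.ModuleList([nn.Linear(d, num_classes) for d in embed_dim])
head_drop = nn.Dropout(0.1)

xs = [torch.randn(2, d) for d in embed_dim]   # pooled features, one tensor per branch
xs = [head_drop(x) for x in xs]
logits = torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(heads)], dim=0), dim=0)
print(logits.shape)   # torch.Size([2, 1000])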
@ -436,17 +523,20 @@ def _create_crossvit(variant, pretrained=False, **kwargs):
return new_state_dict
return build_model_with_cfg(
CrossViT, variant, pretrained,
CrossViT,
variant,
pretrained,
pretrained_filter_fn=pretrained_filter_fn,
**kwargs)
**kwargs,
)
@register_model
def crossvit_tiny_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs)
model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args)
num_heads=[3, 3], mlp_ratio=[4, 4, 1])
model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -454,8 +544,8 @@ def crossvit_tiny_240(pretrained=False, **kwargs):
def crossvit_small_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs)
model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args)
num_heads=[6, 6], mlp_ratio=[4, 4, 1])
model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -463,8 +553,8 @@ def crossvit_small_240(pretrained=False, **kwargs):
def crossvit_base_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs)
model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args)
num_heads=[12, 12], mlp_ratio=[4, 4, 1])
model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -472,8 +562,8 @@ def crossvit_base_240(pretrained=False, **kwargs):
def crossvit_9_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs)
model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args)
num_heads=[4, 4], mlp_ratio=[3, 3, 1])
model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -481,8 +571,8 @@ def crossvit_9_240(pretrained=False, **kwargs):
def crossvit_15_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs)
model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args)
num_heads=[6, 6], mlp_ratio=[3, 3, 1])
model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -491,7 +581,7 @@ def crossvit_18_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -499,8 +589,8 @@ def crossvit_18_240(pretrained=False, **kwargs):
def crossvit_9_dagger_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args)
num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True)
model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -508,8 +598,8 @@ def crossvit_9_dagger_240(pretrained=False, **kwargs):
def crossvit_15_dagger_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args)
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -517,8 +607,8 @@ def crossvit_15_dagger_240(pretrained=False, **kwargs):
def crossvit_15_dagger_408(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args)
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -526,8 +616,8 @@ def crossvit_15_dagger_408(pretrained=False, **kwargs):
def crossvit_18_dagger_240(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args)
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -535,6 +625,6 @@ def crossvit_18_dagger_240(pretrained=False, **kwargs):
def crossvit_18_dagger_408(pretrained=False, **kwargs):
model_args = dict(
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args)
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
return model

View File

@ -468,7 +468,6 @@ class DaViT(nn.Module):
ffn=True,
cpe_act=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_classes=1000,
global_pool='avg',

View File

@ -21,49 +21,11 @@ from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdap
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model
from ._registry import register_model, generate_default_cfgs
__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
edgenext_xx_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
edgenext_x_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
# edgenext_small=_cfg(
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"),
edgenext_small=_cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
# edgenext_base=_cfg(
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
edgenext_base=_cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
edgenext_small_rw=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
)
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
class PositionalEncodingFourier(nn.Module):
def __init__(self, hidden_dim=32, dim=768, temperature=10000):
@ -519,6 +481,43 @@ def _create_edgenext(variant, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = generate_default_cfgs({
'edgenext_xx_small.in1k': _cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'edgenext_x_small.in1k': _cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth",
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'edgenext_small.usi_in1k': _cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
'edgenext_base.usi_in1k': _cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
'edgenext_base.in21k_ft_in1k': _cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.21/edgenext_base_IN21K.pth",
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
'edgenext_small_rw.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
test_input_size=(3, 320, 320), test_crop_pct=1.0,
),
})
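A hedged usage note on the tagged cfgs above: with generate_default_cfgs, weight sets are addressed as '<model>.<tag>' (e.g. the USI checkpoint for edgenext_small below), while the bare model name resolves to the default tag; pretrained=True would download the selected weights.

import timm

model = timm.create_model('edgenext_small.usi_in1k', pretrained=False)
print(type(model).__name__)   # EdgeNeXt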
@register_model
def edgenext_xx_small(pretrained=False, **kwargs):
# 1.33M & 260.58M @ 256 resolution

View File

@ -211,7 +211,7 @@ class MetaBlock1d(nn.Module):
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
drop=0.,
proj_drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5
):
@ -219,7 +219,12 @@ class MetaBlock1d(nn.Module):
self.norm1 = norm_layer(dim)
self.token_mixer = Attention(dim)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ls1 = LayerScale(dim, layer_scale_init_value)
@ -251,7 +256,7 @@ class MetaBlock2d(nn.Module):
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=nn.BatchNorm2d,
drop=0.,
proj_drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5
):
@ -261,7 +266,12 @@ class MetaBlock2d(nn.Module):
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = ConvMlpWithNorm(
dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop)
dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
norm_layer=norm_layer,
drop=proj_drop,
)
self.ls2 = LayerScale2d(dim, layer_scale_init_value)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -285,7 +295,7 @@ class EfficientFormerStage(nn.Module):
act_layer=nn.GELU,
norm_layer=nn.BatchNorm2d,
norm_layer_cl=nn.LayerNorm,
drop=.0,
proj_drop=.0,
drop_path=0.,
layer_scale_init_value=1e-5,
):
@ -312,7 +322,7 @@ class EfficientFormerStage(nn.Module):
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer_cl,
drop=drop,
proj_drop=proj_drop,
drop_path=drop_path[block_idx],
layer_scale_init_value=layer_scale_init_value,
))
@ -324,7 +334,7 @@ class EfficientFormerStage(nn.Module):
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop,
proj_drop=proj_drop,
drop_path=drop_path[block_idx],
layer_scale_init_value=layer_scale_init_value,
))
@ -360,6 +370,7 @@ class EfficientFormer(nn.Module):
norm_layer=nn.BatchNorm2d,
norm_layer_cl=nn.LayerNorm,
drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.,
**kwargs
):
@ -386,7 +397,7 @@ class EfficientFormer(nn.Module):
act_layer=act_layer,
norm_layer_cl=norm_layer_cl,
norm_layer=norm_layer,
drop=drop_rate,
proj_drop=proj_drop_rate,
drop_path=dpr[i],
layer_scale_init_value=layer_scale_init_value,
)
@ -398,6 +409,7 @@ class EfficientFormer(nn.Module):
# Classifier head
self.num_features = embed_dims[-1]
self.norm = norm_layer_cl(self.num_features)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# assuming model is always distilled (valid for current checkpoints, will split def if that changes)
self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
@ -453,6 +465,7 @@ class EfficientFormer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
x = self.head_drop(x)
if pre_logits:
return x
x, x_dist = self.head(x), self.head_dist(x)

View File

@ -342,7 +342,8 @@ class ConvMlpWithNorm(nn.Module):
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = ConvNormAct(
in_features, hidden_features, 1, bias=True, norm_layer=norm_layer, act_layer=act_layer)
in_features, hidden_features, 1,
bias=True, norm_layer=norm_layer, act_layer=act_layer)
if mid_conv:
self.mid = ConvNormAct(
hidden_features, hidden_features, 3,
@ -380,7 +381,7 @@ class EfficientFormerV2Block(nn.Module):
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=nn.BatchNorm2d,
drop=0.,
proj_drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5,
resolution=7,
@ -409,7 +410,7 @@ class EfficientFormerV2Block(nn.Module):
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop,
drop=proj_drop,
mid_conv=True,
)
self.ls2 = LayerScale2d(
@ -451,7 +452,7 @@ class EfficientFormerV2Stage(nn.Module):
block_use_attn=False,
num_vit=1,
mlp_ratio=4.,
drop=.0,
proj_drop=.0,
drop_path=0.,
layer_scale_init_value=1e-5,
act_layer=nn.GELU,
@ -487,7 +488,7 @@ class EfficientFormerV2Stage(nn.Module):
stride=block_stride,
mlp_ratio=mlp_ratio[block_idx],
use_attn=block_use_attn and block_idx > remain_idx,
drop=drop,
proj_drop=proj_drop,
drop_path=drop_path[block_idx],
layer_scale_init_value=layer_scale_init_value,
act_layer=act_layer,
@ -520,6 +521,7 @@ class EfficientFormerV2(nn.Module):
act_layer='gelu',
num_classes=1000,
drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.,
layer_scale_init_value=1e-5,
num_vit=0,
@ -556,7 +558,7 @@ class EfficientFormerV2(nn.Module):
block_use_attn=i >= 2,
num_vit=num_vit,
mlp_ratio=mlp_ratios[i],
drop=drop_rate,
proj_drop=proj_drop_rate,
drop_path=dpr[i],
layer_scale_init_value=layer_scale_init_value,
act_layer=act_layer,
@ -572,6 +574,7 @@ class EfficientFormerV2(nn.Module):
# Classifier head
self.num_features = embed_dims[-1]
self.norm = norm_layer(embed_dims[-1])
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.dist = distillation
if self.dist:
@ -630,6 +633,7 @@ class EfficientFormerV2(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=(2, 3))
x = self.head_drop(x)
if pre_logits:
return x
x, x_dist = self.head(x), self.head_dist(x)

View File

@ -34,8 +34,8 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, RotaryEmbeddingCat, \
apply_rot_embed_cat, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model
@ -355,10 +355,14 @@ class Eva(nn.Module):
scale_mlp: bool = False,
scale_attn_inner: bool = False,
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_layer: Callable = LayerNorm,
init_values: Optional[float] = None,
class_token: bool = True,
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
use_post_norm: bool = False,
@ -383,10 +387,13 @@ class Eva(nn.Module):
scale_mlp:
scale_attn_inner:
drop_rate:
pos_drop_rate:
proj_drop_rate:
attn_drop_rate:
drop_path_rate:
norm_layer:
init_values:
class_token:
use_abs_pos_emb:
use_rot_pos_emb:
use_post_norm:
@ -397,7 +404,7 @@ class Eva(nn.Module):
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1
self.num_prefix_tokens = 1 if class_token else 0
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
@ -408,9 +415,19 @@ class Eva(nn.Module):
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
self.pos_drop = nn.Dropout(p=drop_rate)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
return_indices=True,
)
else:
self.patch_drop = None
if use_rot_pos_emb:
ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
@ -435,7 +452,7 @@ class Eva(nn.Module):
swiglu_mlp=swiglu_mlp,
scale_mlp=scale_mlp,
scale_attn_inner=scale_attn_inner,
proj_drop=drop_rate,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
@ -446,6 +463,7 @@ class Eva(nn.Module):
use_fc_norm = self.global_pool == 'avg'
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
@ -505,12 +523,21 @@ class Eva(nn.Module):
def forward_features(self, x):
x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
# apply abs position embedding
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
# obtain shared rotary position embedding and apply patch dropout
rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
if self.patch_drop is not None:
x, keep_indices = self.patch_drop(x)
if rot_pos_embed is not None and keep_indices is not None:
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
@ -525,6 +552,7 @@ class Eva(nn.Module):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
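A minimal usage sketch (not part of the diff; batch size, patch count, and embedding sizes below are illustrative) of how the new patch dropout path works together with the shared rotary embedding: PatchDropout returns the indices of the kept patch tokens, and apply_keep_indices_nlc subsets the rope table so it stays aligned with the surviving tokens.

import torch
from timm.layers import PatchDropout, apply_keep_indices_nlc

x = torch.randn(2, 1 + 196, 768)          # cls token + 14*14 patch tokens, embed_dim=768
rot_pos_embed = torch.randn(196, 64)      # shared per-patch rope table (sin/cos concatenated)

patch_drop = PatchDropout(0.25, num_prefix_tokens=1, return_indices=True)
patch_drop.train()                        # patch dropout is only active in training mode
x, keep_indices = patch_drop(x)           # x -> (2, 1 + 147, 768); the cls token is always kept
if keep_indices is not None:
    # gather the rope rows belonging to the kept patches, per sample
    rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)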

View File

@ -75,12 +75,6 @@ class FocalModulation(nn.Module):
self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity()
def forward(self, x):
"""
Args:
x: input features with shape of (B, H, W, C)
"""
C = x.shape[1]
# pre linear projection
x = self.f(x)
q, ctx, gates = torch.split(x, self.input_split, 1)

View File

@ -192,6 +192,7 @@ class MlpMixer(nn.Module):
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.,
nlhb=False,
stem_norm=False,
@ -219,11 +220,12 @@ class MlpMixer(nn.Module):
mlp_layer=mlp_layer,
norm_layer=norm_layer,
act_layer=act_layer,
drop=drop_rate,
drop=proj_drop_rate,
drop_path=drop_path_rate,
)
for _ in range(num_blocks)])
self.norm = norm_layer(embed_dim)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(nlhb=nlhb)
@ -267,6 +269,7 @@ class MlpMixer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):

View File

@ -271,7 +271,7 @@ class MobileVitBlock(nn.Module):
num_heads=num_heads,
qkv_bias=True,
attn_drop=attn_drop,
drop=drop,
proj_drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer,

View File

@ -27,43 +27,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import register_model
from ._registry import register_model, register_model_deprecations, generate_default_cfgs
__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
'fixed_input_size': True,
**kwargs
}
default_cfgs = dict(
mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
mvitv2_base_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
num_classes=19168),
mvitv2_large_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
num_classes=19168),
mvitv2_huge_in21k=_cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
num_classes=19168),
mvitv2_small_cls=_cfg(url=''),
)
@dataclass
class MultiScaleVitCfg:
depths: Tuple[int, ...] = (2, 3, 16, 3)
@ -113,40 +81,6 @@ class MultiScaleVitCfg:
self.stride_kv = tuple(pool_kv_stride)
model_cfgs = dict(
mvitv2_tiny=MultiScaleVitCfg(
depths=(1, 2, 5, 2),
),
mvitv2_small=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
),
mvitv2_base=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
),
mvitv2_large=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
expand_attn=False,
),
mvitv2_base_in21k=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
),
mvitv2_large_in21k=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
expand_attn=False,
),
mvitv2_small_cls=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
use_cls_token=True,
),
)
def prod(iterable):
return reduce(operator.mul, iterable, 1)
@ -229,26 +163,32 @@ def cal_rel_pos_type(
# Scale up rel pos if shapes for q and k are different.
q_h_ratio = max(k_h / q_h, 1.0)
k_h_ratio = max(q_h / k_h, 1.0)
dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
dist_h = (
torch.arange(q_h, device=q.device).unsqueeze(-1) * q_h_ratio -
torch.arange(k_h, device=q.device).unsqueeze(0) * k_h_ratio
)
dist_h += (k_h - 1) * k_h_ratio
q_w_ratio = max(k_w / q_w, 1.0)
k_w_ratio = max(q_w / k_w, 1.0)
dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
dist_w = (
torch.arange(q_w, device=q.device).unsqueeze(-1) * q_w_ratio -
torch.arange(k_w, device=q.device).unsqueeze(0) * k_w_ratio
)
dist_w += (k_w - 1) * k_w_ratio
Rh = rel_pos_h[dist_h.long()]
Rw = rel_pos_w[dist_w.long()]
rel_h = rel_pos_h[dist_h.long()]
rel_w = rel_pos_w[dist_w.long()]
B, n_head, q_N, dim = q.shape
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h)
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, rel_w)
attn[:, :, sp_idx:, sp_idx:] = (
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
+ rel_h[:, :, :, :, :, None]
+ rel_w[:, :, :, :, None, :]
+ rel_h.unsqueeze(-1)
+ rel_w.unsqueeze(-2)
).view(B, -1, q_h * q_w, k_h * k_w)
return attn
@ -390,18 +330,18 @@ class MultiScaleAttentionPoolFirst(nn.Module):
v = self.norm_v(v)
q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)
q = q.transpose(1, 2).reshape(B, q_N, -1)
q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)
k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)
k = k.transpose(1, 2).reshape(B, k_N, -1)
k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 3, 1)
v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)
v = v.transpose(1, 2).reshape(B, v_N, -1)
v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)
attn = (q * self.scale) @ k.transpose(-2, -1)
attn = (q * self.scale) @ k
if self.rel_pos_type == 'spatial':
attn = cal_rel_pos_type(
attn,
@ -764,7 +704,7 @@ class MultiScaleVit(nn.Module):
cfg: MultiScaleVitCfg,
img_size: Tuple[int, int] = (224, 224),
in_chans: int = 3,
global_pool: str = 'avg',
global_pool: Optional[str] = None,
num_classes: int = 1000,
drop_path_rate: float = 0.,
drop_rate: float = 0.,
@ -774,6 +714,8 @@ class MultiScaleVit(nn.Module):
norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
self.num_classes = num_classes
self.drop_rate = drop_rate
if global_pool is None:
global_pool = 'token' if cfg.use_cls_token else 'avg'
self.global_pool = global_pool
self.depths = tuple(cfg.depths)
self.expand_attn = cfg.expand_attn
@ -963,13 +905,89 @@ def checkpoint_filter_fn(state_dict, model):
return out_dict
model_cfgs = dict(
mvitv2_tiny=MultiScaleVitCfg(
depths=(1, 2, 5, 2),
),
mvitv2_small=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
),
mvitv2_base=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
),
mvitv2_large=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
expand_attn=False,
),
mvitv2_small_cls=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
use_cls_token=True,
),
mvitv2_base_cls=MultiScaleVitCfg(
depths=(2, 3, 16, 3),
use_cls_token=True,
),
mvitv2_large_cls=MultiScaleVitCfg(
depths=(2, 6, 36, 4),
embed_dim=144,
num_heads=2,
use_cls_token=True,
expand_attn=True,
),
mvitv2_huge_cls=MultiScaleVitCfg(
depths=(4, 8, 60, 8),
embed_dim=192,
num_heads=3,
use_cls_token=True,
expand_attn=True,
),
)
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
MultiScaleVit, variant, pretrained,
MultiScaleVit,
variant,
pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True),
**kwargs)
**kwargs,
)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
'fixed_input_size': True,
**kwargs
}
default_cfgs = generate_default_cfgs({
'mvitv2_tiny.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
'mvitv2_small.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
'mvitv2_base.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
'mvitv2_large.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
'mvitv2_small_cls': _cfg(url=''),
'mvitv2_base_cls.fb_inw21k': _cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
num_classes=19168),
'mvitv2_large_cls.fb_inw21k': _cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
num_classes=19168),
'mvitv2_huge_cls.fb_inw21k': _cfg(
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
num_classes=19168),
})
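For context, a hedged example (assuming a timm version that resolves 'model.tag' pretrained names via create_model) of loading one of the tag-qualified weights registered above:

import timm
# the '.fb_inw21k' tag selects the 19168-class ImageNet-21k (winter) weights
model = timm.create_model('mvitv2_base_cls.fb_inw21k', pretrained=True)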
@register_model
@ -992,21 +1010,21 @@ def mvitv2_large(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
# @register_model
# def mvitv2_base_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_large_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs)
#
#
# @register_model
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small_cls(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_base_cls(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_base_cls', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_large_cls(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_large_cls', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_huge_cls(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_huge_cls', pretrained=pretrained, **kwargs)

View File

@ -104,15 +104,25 @@ class TransformerLayer(nn.Module):
- Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
- Uses modified Attention layer that handles the "block" dimension
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
def forward(self, x):
y = self.norm1(x)
@ -176,9 +186,23 @@ class NestLevel(nn.Module):
""" Single hierarchical level of a Nested Transformer
"""
def __init__(
self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None,
mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[],
norm_layer=None, act_layer=None, pad_type=''):
self,
num_blocks,
block_size,
seq_length,
num_heads,
depth,
embed_dim,
prev_embed_dim=None,
mlp_ratio=4.,
qkv_bias=True,
proj_drop=0.,
attn_drop=0.,
drop_path=[],
norm_layer=None,
act_layer=None,
pad_type='',
):
super().__init__()
self.block_size = block_size
self.grad_checkpointing = False
@ -191,13 +215,20 @@ class NestLevel(nn.Module):
self.pool = nn.Identity()
# Transformer encoder
if len(drop_path_rates):
assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
if len(drop_path):
assert len(drop_path) == depth, 'Must provide as many drop path rates as there are transformer layers'
self.transformer_encoder = nn.Sequential(*[
TransformerLayer(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i],
norm_layer=norm_layer, act_layer=act_layer)
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[i],
norm_layer=norm_layer,
act_layer=act_layer,
)
for i in range(depth)])
def forward(self, x):
@ -225,10 +256,26 @@ class Nest(nn.Module):
"""
def __init__(
self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
pad_type='', weight_init='', global_pool='avg'
self,
img_size=224,
in_chans=3,
patch_size=4,
num_levels=3,
embed_dims=(128, 256, 512),
num_heads=(4, 8, 16),
depths=(2, 2, 20),
num_classes=1000,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.5,
norm_layer=None,
act_layer=None,
pad_type='',
weight_init='',
global_pool='avg',
):
"""
Args:
@ -292,7 +339,12 @@ class Nest(nn.Module):
# Patch embedding
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dims[0],
flatten=False,
)
self.num_patches = self.patch_embed.num_patches
self.seq_length = self.num_patches // self.num_blocks[0]
@ -304,8 +356,22 @@ class Nest(nn.Module):
for i in range(len(self.num_blocks)):
dim = embed_dims[i]
levels.append(NestLevel(
self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
self.num_blocks[i],
self.block_size,
self.seq_length,
num_heads[i],
depths[i],
dim,
prev_dim,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dp_rates[i],
norm_layer=norm_layer,
act_layer=act_layer,
pad_type=pad_type,
))
self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
prev_dim = dim
curr_stride *= 2
@ -315,7 +381,10 @@ class Nest(nn.Module):
self.norm = norm_layer(embed_dims[-1])
# Classifier
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.global_pool = global_pool
self.head_drop = nn.Dropout(drop_rate)
self.head = head
self.init_weights(weight_init)
@ -366,8 +435,7 @@ class Nest(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):

View File

@ -78,7 +78,16 @@ class SequentialTuple(nn.Sequential):
class Transformer(nn.Module):
def __init__(
self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
self,
base_dim,
depth,
heads,
mlp_ratio,
pool=None,
proj_drop=.0,
attn_drop=.0,
drop_path_prob=None,
):
super(Transformer, self).__init__()
self.layers = nn.ModuleList([])
embed_dim = base_dim * heads
@ -89,8 +98,8 @@ class Transformer(nn.Module):
num_heads=heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
drop=drop_rate,
attn_drop=attn_drop_rate,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path_prob[i],
norm_layer=partial(nn.LayerNorm, eps=1e-6)
)
@ -122,8 +131,14 @@ class ConvHeadPooling(nn.Module):
super(ConvHeadPooling, self).__init__()
self.conv = nn.Conv2d(
in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride,
padding_mode=padding_mode, groups=in_feature)
in_feature,
out_feature,
kernel_size=stride + 1,
padding=stride // 2,
stride=stride,
padding_mode=padding_mode,
groups=in_feature,
)
self.fc = nn.Linear(in_feature, out_feature)
def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
@ -150,9 +165,24 @@ class PoolingVisionTransformer(nn.Module):
- https://arxiv.org/abs/2103.16302
"""
def __init__(
self, img_size, patch_size, stride, base_dims, depth, heads,
mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
self,
img_size,
patch_size,
stride,
base_dims,
depth,
heads,
mlp_ratio,
num_classes=1000,
in_chans=3,
global_pool='token',
distilled=False,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
):
super(PoolingVisionTransformer, self).__init__()
assert global_pool in ('token',)
@ -173,7 +203,7 @@ class PoolingVisionTransformer(nn.Module):
self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)
self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0]))
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
transformers = []
# stochastic depth decay rule
@ -182,16 +212,27 @@ class PoolingVisionTransformer(nn.Module):
pool = None
if stage < len(heads) - 1:
pool = ConvHeadPooling(
base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2)
base_dims[stage] * heads[stage],
base_dims[stage + 1] * heads[stage + 1],
stride=2,
)
transformers += [Transformer(
base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage])
base_dims[stage],
depth[stage],
heads[stage],
mlp_ratio,
pool=pool,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path_prob=dpr[stage],
)
]
self.transformers = SequentialTuple(*transformers)
self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
# Classifier head
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distilled:
@ -243,6 +284,8 @@ class PoolingVisionTransformer(nn.Module):
if self.head_dist is not None:
assert self.global_pool == 'token'
x, x_dist = x[:, 0], x[:, 1]
x = self.head_drop(x)
x_dist = self.head_drop(x_dist)
if not pre_logits:
x = self.head(x)
x_dist = self.head_dist(x_dist)
@ -255,6 +298,7 @@ class PoolingVisionTransformer(nn.Module):
else:
if self.global_pool == 'token':
x = x[:, 0]
x = self.head_drop(x)
if not pre_logits:
x = self.head(x)
return x
@ -284,71 +328,70 @@ def _create_pit(variant, pretrained=False, **kwargs):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
PoolingVisionTransformer, variant, pretrained,
PoolingVisionTransformer,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
**kwargs,
)
return model
@register_model
def pit_b_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=14,
stride=7,
base_dims=[64, 64, 64],
depth=[3, 6, 4],
heads=[4, 8, 16],
mlp_ratio=4,
**kwargs
)
return _create_pit('pit_b_224', pretrained, **model_kwargs)
return _create_pit('pit_b_224', pretrained, **dict(model_args, **kwargs))
@register_model
def pit_s_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=16,
stride=8,
base_dims=[48, 48, 48],
depth=[2, 6, 4],
heads=[3, 6, 12],
mlp_ratio=4,
**kwargs
)
return _create_pit('pit_s_224', pretrained, **model_kwargs)
return _create_pit('pit_s_224', pretrained, **dict(model_args, **kwargs))
@register_model
def pit_xs_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=16,
stride=8,
base_dims=[48, 48, 48],
depth=[2, 6, 4],
heads=[2, 4, 8],
mlp_ratio=4,
**kwargs
)
return _create_pit('pit_xs_224', pretrained, **model_kwargs)
return _create_pit('pit_xs_224', pretrained, **dict(model_args, **kwargs))
@register_model
def pit_ti_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=16,
stride=8,
base_dims=[32, 32, 32],
depth=[2, 6, 4],
heads=[2, 4, 8],
mlp_ratio=4,
**kwargs
)
return _create_pit('pit_ti_224', pretrained, **model_kwargs)
return _create_pit('pit_ti_224', pretrained, **dict(model_args, **kwargs))
@register_model
def pit_b_distilled_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=14,
stride=7,
base_dims=[64, 64, 64],
@ -356,14 +399,13 @@ def pit_b_distilled_224(pretrained, **kwargs):
heads=[4, 8, 16],
mlp_ratio=4,
distilled=True,
**kwargs
)
return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs)
return _create_pit('pit_b_distilled_224', pretrained, **dict(model_args, **kwargs))
@register_model
def pit_s_distilled_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=16,
stride=8,
base_dims=[48, 48, 48],
@ -371,14 +413,13 @@ def pit_s_distilled_224(pretrained, **kwargs):
heads=[3, 6, 12],
mlp_ratio=4,
distilled=True,
**kwargs
)
return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs)
return _create_pit('pit_s_distilled_224', pretrained, **dict(model_args, **kwargs))
@register_model
def pit_xs_distilled_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=16,
stride=8,
base_dims=[48, 48, 48],
@ -386,14 +427,13 @@ def pit_xs_distilled_224(pretrained, **kwargs):
heads=[2, 4, 8],
mlp_ratio=4,
distilled=True,
**kwargs
)
return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs)
return _create_pit('pit_xs_distilled_224', pretrained, **dict(model_args, **kwargs))
@register_model
def pit_ti_distilled_224(pretrained, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=16,
stride=8,
base_dims=[32, 32, 32],
@ -401,6 +441,5 @@ def pit_ti_distilled_224(pretrained, **kwargs):
heads=[2, 4, 8],
mlp_ratio=4,
distilled=True,
**kwargs
)
return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs)
return _create_pit('pit_ti_distilled_224', pretrained, **dict(model_args, **kwargs))
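The entrypoints above now merge defaults and overrides with dict(model_args, **kwargs), so caller-supplied kwargs take precedence over the per-variant defaults; a small sketch of the merge semantics (values are illustrative):

model_args = dict(patch_size=16, stride=8, mlp_ratio=4)
kwargs = dict(mlp_ratio=3, drop_path_rate=0.1)      # e.g. extra args forwarded from create_model
merged = dict(model_args, **kwargs)
# merged == {'patch_size': 16, 'stride': 8, 'mlp_ratio': 3, 'drop_path_rate': 0.1}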

View File

@ -103,9 +103,16 @@ class PoolFormerBlock(nn.Module):
"""
def __init__(
self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm1,
drop=0., drop_path=0., layer_scale_init_value=1e-5):
self,
dim,
pool_size=3,
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=GroupNorm1,
drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5,
):
super().__init__()
@ -116,7 +123,7 @@ class PoolFormerBlock(nn.Module):
self.mlp = ConvMlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if layer_scale_init_value:
if layer_scale_init_value is not None:
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
else:
@ -134,10 +141,15 @@ class PoolFormerBlock(nn.Module):
def basic_blocks(
dim, index, layers,
pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm1,
drop_rate=.0, drop_path_rate=0.,
dim,
index,
layers,
pool_size=3,
mlp_ratio=4.,
act_layer=nn.GELU,
norm_layer=GroupNorm1,
drop_rate=.0,
drop_path_rate=0.,
layer_scale_init_value=1e-5,
):
""" generate PoolFormer blocks for a stage """
@ -145,9 +157,13 @@ def basic_blocks(
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(PoolFormerBlock(
dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
act_layer=act_layer, norm_layer=norm_layer,
drop=drop_rate, drop_path=block_dpr,
dim,
pool_size=pool_size,
mlp_ratio=mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
drop=drop_rate,
drop_path=block_dpr,
layer_scale_init_value=layer_scale_init_value,
))
blocks = nn.Sequential(*blocks)
@ -176,9 +192,12 @@ class PoolFormer(nn.Module):
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0., drop_path_rate=0.,
drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.,
layer_scale_init_value=1e-5,
**kwargs):
**kwargs,
):
super().__init__()
self.num_classes = num_classes
@ -187,28 +206,42 @@ class PoolFormer(nn.Module):
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
patch_size=in_patch_size, stride=in_stride, padding=in_pad,
in_chs=in_chans, embed_dim=embed_dims[0])
patch_size=in_patch_size,
stride=in_stride,
padding=in_pad,
in_chs=in_chans,
embed_dim=embed_dims[0],
)
# set the main block in network
network = []
for i in range(len(layers)):
network.append(basic_blocks(
embed_dims[i], i, layers,
pool_size=pool_size, mlp_ratio=mlp_ratios[i],
act_layer=act_layer, norm_layer=norm_layer,
drop_rate=drop_rate, drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value)
)
embed_dims[i],
i,
layers,
pool_size=pool_size,
mlp_ratio=mlp_ratios[i],
act_layer=act_layer,
norm_layer=norm_layer,
drop_rate=proj_drop_rate,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value
))
if i < len(layers) - 1 and (downsamples[i] or embed_dims[i] != embed_dims[i + 1]):
# downsampling between stages
network.append(PatchEmbed(
in_chs=embed_dims[i], embed_dim=embed_dims[i + 1],
patch_size=down_patch_size, stride=down_stride, padding=down_pad)
)
in_chs=embed_dims[i],
embed_dim=embed_dims[i + 1],
patch_size=down_patch_size,
stride=down_stride,
padding=down_pad,
))
self.network = nn.Sequential(*network)
self.norm = norm_layer(self.num_features)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
@ -254,6 +287,7 @@ class PoolFormer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean([-2, -1])
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):

View File

@ -54,8 +54,14 @@ default_cfgs = {
class MlpWithDepthwiseConv(nn.Module):
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
drop=0., extra_relu=False):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.,
extra_relu=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -154,8 +160,19 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., sr_ratio=1, linear_attn=False, qkv_bias=False,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
mlp_ratio=4.,
sr_ratio=1,
linear_attn=False,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
@ -165,7 +182,7 @@ class Block(nn.Module):
linear_attn=linear_attn,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
@ -173,8 +190,8 @@ class Block(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
extra_relu=linear_attn
drop=proj_drop,
extra_relu=linear_attn,
)
def forward(self, x, feat_size: List[int]):
@ -193,8 +210,8 @@ class OverlapPatchEmbed(nn.Module):
assert max(patch_size) > stride, "Set larger patch_size than stride"
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
in_chans, embed_dim, patch_size,
stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
@ -217,7 +234,7 @@ class PyramidVisionTransformerStage(nn.Module):
linear_attn: bool = False,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
drop: float = 0.,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.0,
norm_layer: Callable = nn.LayerNorm,
@ -242,7 +259,7 @@ class PyramidVisionTransformerStage(nn.Module):
linear_attn=linear_attn,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
@ -278,6 +295,7 @@ class PyramidVisionTransformerV2(nn.Module):
qkv_bias=True,
linear=False,
drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
@ -314,16 +332,17 @@ class PyramidVisionTransformerV2(nn.Module):
mlp_ratio=mlp_ratios[i],
linear_attn=linear,
qkv_bias=qkv_bias,
drop=drop_rate,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer
norm_layer=norm_layer,
))
prev_dim = embed_dims[i]
cur += depths[i]
# classification head
self.num_features = embed_dims[-1]
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
@ -379,6 +398,7 @@ class PyramidVisionTransformerV2(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x.mean(dim=(-1, -2))
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):

View File

@ -181,7 +181,7 @@ class SwinTransformerBlock(nn.Module):
shift_size: int = 0,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop: float = 0.,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: float = 0.,
act_layer: Callable = nn.GELU,
@ -197,7 +197,7 @@ class SwinTransformerBlock(nn.Module):
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
proj_drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
@ -223,7 +223,7 @@ class SwinTransformerBlock(nn.Module):
window_size=to_2tuple(self.window_size),
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -232,7 +232,7 @@ class SwinTransformerBlock(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
if self.shift_size > 0:
@ -346,7 +346,7 @@ class SwinTransformerStage(nn.Module):
window_size: _int_or_tuple_2_t = 7,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop: float = 0.,
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.,
norm_layer: Callable = nn.LayerNorm,
@ -362,7 +362,7 @@ class SwinTransformerStage(nn.Module):
window_size: Local window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
proj_drop: Projection dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
@ -396,7 +396,7 @@ class SwinTransformerStage(nn.Module):
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
@ -435,6 +435,7 @@ class SwinTransformer(nn.Module):
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Union[str, Callable] = nn.LayerNorm,
@ -508,7 +509,7 @@ class SwinTransformer(nn.Module):
window_size=window_size[i],
mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias,
drop=drop_rate,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,

View File

@ -203,7 +203,7 @@ class SwinTransformerV2Block(nn.Module):
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
@ -219,7 +219,7 @@ class SwinTransformerV2Block(nn.Module):
shift_size: Shift size for SW-MSA.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate.
proj_drop: Dropout rate.
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
@ -242,7 +242,7 @@ class SwinTransformerV2Block(nn.Module):
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
pretrained_window_size=to_2tuple(pretrained_window_size),
)
self.norm1 = norm_layer(dim)
@ -252,7 +252,7 @@ class SwinTransformerV2Block(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -367,7 +367,7 @@ class SwinTransformerV2Stage(nn.Module):
downsample=False,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
@ -384,7 +384,7 @@ class SwinTransformerV2Stage(nn.Module):
downsample: Use downsample layer at start of the block.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop: Dropout rate
proj_drop: Projection dropout rate
attn_drop: Attention dropout rate.
drop_path: Stochastic depth rate.
norm_layer: Normalization layer.
@ -416,7 +416,7 @@ class SwinTransformerV2Stage(nn.Module):
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
@ -463,6 +463,7 @@ class SwinTransformerV2(nn.Module):
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.1,
norm_layer: Callable = nn.LayerNorm,
@ -481,7 +482,8 @@ class SwinTransformerV2(nn.Module):
window_size: Window size.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
drop_rate: Dropout rate.
drop_rate: Head dropout rate.
proj_drop_rate: Projection dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
norm_layer: Normalization layer.
@ -531,7 +533,8 @@ class SwinTransformerV2(nn.Module):
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_sizes[i],

View File

@ -221,7 +221,7 @@ class SwinTransformerV2CrBlock(nn.Module):
window_size (Tuple[int, int]): Window size to be utilized
shift_size (int): Shifting size to be used
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
drop (float): Dropout in input mapping
proj_drop (float): Dropout in input mapping
drop_attn (float): Dropout rate of attention map
drop_path (float): Dropout in main path
extra_norm (bool): Insert extra norm on 'main' branch if True
@ -238,7 +238,7 @@ class SwinTransformerV2CrBlock(nn.Module):
shift_size: Tuple[int, int] = (0, 0),
mlp_ratio: float = 4.0,
init_values: Optional[float] = 0,
drop: float = 0.0,
proj_drop: float = 0.0,
drop_attn: float = 0.0,
drop_path: float = 0.0,
extra_norm: bool = False,
@ -259,7 +259,7 @@ class SwinTransformerV2CrBlock(nn.Module):
num_heads=num_heads,
window_size=self.window_size,
drop_attn=drop_attn,
drop_proj=drop,
drop_proj=proj_drop,
sequential_attn=sequential_attn,
)
self.norm1 = norm_layer(dim)
@ -269,7 +269,7 @@ class SwinTransformerV2CrBlock(nn.Module):
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
drop=drop,
drop=proj_drop,
out_features=dim,
)
self.norm2 = norm_layer(dim)
@ -445,7 +445,7 @@ class SwinTransformerV2CrStage(nn.Module):
num_heads (int): Number of attention heads to be utilized
window_size (int): Window size to be utilized
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
drop (float): Dropout in input mapping
proj_drop (float): Dropout in input mapping
drop_attn (float): Dropout rate of attention map
drop_path (float): Dropout in main path
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
@ -464,7 +464,7 @@ class SwinTransformerV2CrStage(nn.Module):
window_size: Tuple[int, int],
mlp_ratio: float = 4.0,
init_values: Optional[float] = 0.0,
drop: float = 0.0,
proj_drop: float = 0.0,
drop_attn: float = 0.0,
drop_path: Union[List[float], float] = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
@ -498,7 +498,7 @@ class SwinTransformerV2CrStage(nn.Module):
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop,
proj_drop=proj_drop,
drop_attn=drop_attn,
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
extra_norm=_extra_norm(index),
@ -546,23 +546,24 @@ class SwinTransformerV2Cr(nn.Module):
https://arxiv.org/pdf/2111.09883
Args:
img_size (Tuple[int, int]): Input resolution.
window_size (Optional[int]): Window size. If None, img_size // window_div. Default: None
img_window_ratio (int): Window size to image size ratio. Default: 32
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input channels.
depths (int): Depth of the stage (number of layers).
num_heads (int): Number of attention heads to be utilized.
embed_dim (int): Patch embedding dimension. Default: 96
num_classes (int): Number of output classes. Default: 1000
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4
drop_rate (float): Dropout rate. Default: 0.0
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default: 0.0
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
sequential_attn (bool): If true sequential self-attention is performed. Default: False
img_size: Input resolution.
window_size: Window size. If None, img_size // window_div
img_window_ratio: Window size to image size ratio.
patch_size: Patch size.
in_chans: Number of input channels.
depths: Depth of the stage (number of layers).
num_heads: Number of attention heads to be utilized.
embed_dim: Patch embedding dimension.
num_classes: Number of output classes.
mlp_ratio: Ratio of the hidden dimension in the FFN to the input channels.
drop_rate: Dropout rate.
proj_drop_rate: Projection dropout rate.
attn_drop_rate: Dropout rate of attention map.
drop_path_rate: Stochastic depth rate.
norm_layer: Type of normalization layer to be utilized.
extra_norm_period: Insert extra norm layer on main branch every N (period) blocks in stage
extra_norm_stage: End each stage with an extra norm layer in main branch
sequential_attn: If true sequential self-attention is performed.
"""
def __init__(
@ -579,6 +580,7 @@ class SwinTransformerV2Cr(nn.Module):
mlp_ratio: float = 4.0,
init_values: Optional[float] = 0.,
drop_rate: float = 0.0,
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_layer: Type[nn.Module] = nn.LayerNorm,
@ -627,7 +629,7 @@ class SwinTransformerV2Cr(nn.Module):
window_size=window_size,
mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop_rate,
proj_drop=proj_drop_rate,
drop_attn=attn_drop_rate,
drop_path=dpr[stage_idx],
extra_norm_period=extra_norm_period,

View File

@ -80,31 +80,64 @@ class Block(nn.Module):
""" TNT Block
"""
def __init__(
self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
dim_out,
num_pixel,
num_heads_in=4,
num_heads_out=12,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
# Inner transformer
self.norm_in = norm_layer(in_dim)
self.norm_in = norm_layer(dim)
self.attn_in = Attention(
in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
dim,
dim,
num_heads=num_heads_in,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.norm_mlp_in = norm_layer(in_dim)
self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4),
out_features=in_dim, act_layer=act_layer, drop=drop)
self.norm_mlp_in = norm_layer(dim)
self.mlp_in = Mlp(
in_features=dim,
hidden_features=int(dim * 4),
out_features=dim,
act_layer=act_layer,
drop=proj_drop,
)
self.norm1_proj = norm_layer(in_dim)
self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True)
self.norm1_proj = norm_layer(dim)
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
# Outer transformer
self.norm_out = norm_layer(dim)
self.norm_out = norm_layer(dim_out)
self.attn_out = Attention(
dim, dim, num_heads=num_heads, qkv_bias=qkv_bias,
attn_drop=attn_drop, proj_drop=drop)
dim_out,
dim_out,
num_heads=num_heads_out,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm_mlp = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
out_features=dim, act_layer=act_layer, drop=drop)
self.norm_mlp = norm_layer(dim_out)
self.mlp = Mlp(
in_features=dim_out,
hidden_features=int(dim_out * mlp_ratio),
out_features=dim_out,
act_layer=act_layer,
drop=proj_drop,
)
def forward(self, pixel_embed, patch_embed):
# inner
@ -157,9 +190,27 @@ class TNT(nn.Module):
""" Transformer in Transformer - https://arxiv.org/abs/2103.00112
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, in_dim=48, depth=12, num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4):
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
inner_dim=48,
depth=12,
num_heads_inner=4,
num_heads_outer=12,
mlp_ratio=4.,
qkv_bias=False,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
first_stride=4,
):
super().__init__()
assert global_pool in ('', 'token', 'avg')
self.num_classes = num_classes
@ -168,31 +219,46 @@ class TNT(nn.Module):
self.grad_checkpointing = False
self.pixel_embed = PixelEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
in_dim=inner_dim,
stride=first_stride,
)
num_patches = self.pixel_embed.num_patches
self.num_patches = num_patches
new_patch_size = self.pixel_embed.new_patch_size
num_pixel = new_patch_size[0] * new_patch_size[1]
self.norm1_proj = norm_layer(num_pixel * in_dim)
self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
self.norm1_proj = norm_layer(num_pixel * inner_dim)
self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
self.norm2_proj = norm_layer(embed_dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1]))
self.pos_drop = nn.Dropout(p=drop_rate)
self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1]))
self.pos_drop = nn.Dropout(p=pos_drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
blocks = []
for i in range(depth):
blocks.append(Block(
dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head,
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[i], norm_layer=norm_layer))
dim=inner_dim,
dim_out=embed_dim,
num_pixel=num_pixel,
num_heads_in=num_heads_inner,
num_heads_out=num_heads_outer,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
))
self.blocks = nn.ModuleList(blocks)
self.norm = norm_layer(embed_dim)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.cls_token, std=.02)
@ -260,6 +326,7 @@ class TNT(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
@ -290,16 +357,16 @@ def _create_tnt(variant, pretrained=False, **kwargs):
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs):
model_cfg = dict(
patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
qkv_bias=False, **kwargs)
model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
qkv_bias=False)
model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
return model
@register_model
def tnt_b_patch16_224(pretrained=False, **kwargs):
model_cfg = dict(
patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
qkv_bias=False, **kwargs)
model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg)
patch_size=16, embed_dim=640, inner_dim=40, depth=12, num_heads_outer=10,
qkv_bias=False)
model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
return model

View File

@ -200,20 +200,20 @@ class GlobalSubSampleAttn(nn.Module):
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
self, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
super().__init__()
self.norm1 = norm_layer(dim)
if ws is None:
self.attn = Attention(dim, num_heads, False, None, attn_drop, drop)
self.attn = Attention(dim, num_heads, False, None, attn_drop, proj_drop)
elif ws == 1:
self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio)
self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio)
else:
self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws)
self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
def forward(self, x, size: Size_):
x = x + self.drop_path(self.attn(self.norm1(x), size))
@ -275,10 +275,26 @@ class Twins(nn.Module):
Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dims=(64, 128, 256, 512), num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), depths=(3, 4, 6, 3),
sr_ratios=(8, 4, 2, 1), wss=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_cls=Block):
self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
global_pool='avg',
embed_dims=(64, 128, 256, 512),
num_heads=(1, 2, 4, 8),
mlp_ratios=(4, 4, 4, 4),
depths=(3, 4, 6, 3),
sr_ratios=(8, 4, 2, 1),
wss=None,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
block_cls=Block,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
@ -293,7 +309,7 @@ class Twins(nn.Module):
self.pos_drops = nn.ModuleList()
for i in range(len(depths)):
self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i]))
self.pos_drops.append(nn.Dropout(p=drop_rate))
self.pos_drops.append(nn.Dropout(p=pos_drop_rate))
prev_chs = embed_dims[i]
img_size = tuple(t // patch_size for t in img_size)
patch_size = 2
@ -303,9 +319,16 @@ class Twins(nn.Module):
cur = 0
for k in range(len(depths)):
_block = nn.ModuleList([block_cls(
dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k],
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])])
dim=embed_dims[k],
num_heads=num_heads[k],
mlp_ratio=mlp_ratios[k],
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + i],
norm_layer=norm_layer,
sr_ratio=sr_ratios[k],
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])],
)
self.blocks.append(_block)
cur += depths[k]
@ -314,6 +337,7 @@ class Twins(nn.Module):
self.norm = norm_layer(self.num_features)
# classification head
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# init weights
@ -386,6 +410,7 @@ class Twins(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
@ -404,47 +429,47 @@ def _create_twins(variant, pretrained=False, **kwargs):
@register_model
def twins_pcpvt_small(pretrained=False, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs)
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])
return _create_twins('twins_pcpvt_small', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def twins_pcpvt_base(pretrained=False, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs)
depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1])
return _create_twins('twins_pcpvt_base', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def twins_pcpvt_large(pretrained=False, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs)
depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1])
return _create_twins('twins_pcpvt_large', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def twins_svt_small(pretrained=False, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs)
depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1])
return _create_twins('twins_svt_small', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def twins_svt_base(pretrained=False, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs)
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1])
return _create_twins('twins_svt_base', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def twins_svt_large(pretrained=False, **kwargs):
model_kwargs = dict(
model_args = dict(
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1])
return _create_twins('twins_svt_large', pretrained=pretrained, **dict(model_args, **kwargs))

View File

@ -40,8 +40,15 @@ default_cfgs = dict(
class SpatialMlp(nn.Module):
def __init__(
self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.,
group=8,
spatial_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -83,6 +90,8 @@ class SpatialMlp(nn.Module):
class Attention(nn.Module):
fast_attn: torch.jit.Final[bool]
def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
@ -90,6 +99,8 @@ class Attention(nn.Module):
head_dim = round(dim // num_heads * head_dim_ratio)
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
@ -100,10 +111,16 @@ class Attention(nn.Module):
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x.unbind(0)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
if self.fast_attn:
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
x = self.proj(x)
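A small sanity-check sketch (shapes are illustrative) showing that the fused scaled_dot_product_attention path above matches the manual softmax-attention fallback when no dropout is applied:

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)                         # (B, num_heads, tokens, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)
fused = F.scaled_dot_product_attention(q, k, v)     # default scale = head_dim ** -0.5
manual = ((q @ k.transpose(-2, -1)) * (8 ** -0.5)).softmax(dim=-1) @ v
assert torch.allclose(fused, manual, atol=1e-5)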
@ -113,9 +130,20 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(
self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
group=8, attn_disabled=False, spatial_conv=False):
self,
dim,
num_heads,
head_dim_ratio=1.,
mlp_ratio=4.,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=LayerNorm2d,
group=8,
attn_disabled=False,
spatial_conv=False,
):
super().__init__()
self.spatial_conv = spatial_conv
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -125,12 +153,22 @@ class Block(nn.Module):
else:
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads=num_heads,
head_dim_ratio=head_dim_ratio,
attn_drop=attn_drop,
proj_drop=proj_drop,
)
self.norm2 = norm_layer(dim)
self.mlp = SpatialMlp(
in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop,
group=group, spatial_conv=spatial_conv) # new setting
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
group=group,
spatial_conv=spatial_conv,
)
def forward(self, x):
if self.attn is not None:
@ -141,10 +179,31 @@ class Block(nn.Module):
class Visformer(nn.Module):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
init_channels=32,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4.,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=LayerNorm2d,
attn_stage='111',
use_pos_embed=True,
spatial_conv='111',
vit_stem=False,
group=8,
global_pool='avg',
conv_init=False,
embed_norm=None,
):
super().__init__()
img_size = to_2tuple(img_size)
self.num_classes = num_classes
@ -159,7 +218,7 @@ class Visformer(nn.Module):
else:
self.stage_num1 = self.stage_num3 = depth // 3
self.stage_num2 = depth - self.stage_num1 - self.stage_num3
self.pos_embed = pos_embed
self.use_pos_embed = use_pos_embed
self.grad_checkpointing = False
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
@ -167,15 +226,25 @@ class Visformer(nn.Module):
if self.vit_stem:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=embed_norm,
flatten=False,
)
img_size = [x // patch_size for x in img_size]
else:
if self.init_channels is None:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size=img_size,
patch_size=patch_size // 2,
in_chans=in_chans,
embed_dim=embed_dim // 2,
norm_layer=embed_norm,
flatten=False,
)
img_size = [x // (patch_size // 2) for x in img_size]
else:
self.stem = nn.Sequential(
@ -185,21 +254,37 @@ class Visformer(nn.Module):
)
img_size = [x // 2 for x in img_size]
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size=img_size,
patch_size=patch_size // 4,
in_chans=self.init_channels,
embed_dim=embed_dim // 2,
norm_layer=embed_norm,
flatten=False,
)
img_size = [x // (patch_size // 4) for x in img_size]
if self.pos_embed:
if self.use_pos_embed:
if self.vit_stem:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
else:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
else:
self.pos_embed1 = None
self.stage1 = nn.Sequential(*[
Block(
dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1')
dim=embed_dim//2,
num_heads=num_heads,
head_dim_ratio=0.5,
mlp_ratio=mlp_ratio,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
group=group,
attn_disabled=(attn_stage[0] == '0'),
spatial_conv=(spatial_conv[0] == '1'),
)
for i in range(self.stage_num1)
])
@ -207,16 +292,33 @@ class Visformer(nn.Module):
# stage2
if not self.vit_stem:
self.patch_embed2 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size=img_size,
patch_size=patch_size // 8,
in_chans=embed_dim // 2,
embed_dim=embed_dim,
norm_layer=embed_norm,
flatten=False,
)
img_size = [x // (patch_size // 8) for x in img_size]
if self.pos_embed:
if self.use_pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
else:
self.pos_embed2 = None
else:
self.patch_embed2 = None
self.stage2 = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1')
dim=embed_dim,
num_heads=num_heads,
head_dim_ratio=1.0,
mlp_ratio=mlp_ratio,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
group=group,
attn_disabled=(attn_stage[1] == '0'),
spatial_conv=(spatial_conv[1] == '1'),
)
for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
])
@ -224,27 +326,48 @@ class Visformer(nn.Module):
# stage 3
if not self.vit_stem:
self.patch_embed3 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
img_size=img_size,
patch_size=patch_size // 8,
in_chans=embed_dim,
embed_dim=embed_dim * 2,
norm_layer=embed_norm,
flatten=False,
)
img_size = [x // (patch_size // 8) for x in img_size]
if self.pos_embed:
if self.use_pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
else:
self.pos_embed3 = None
else:
self.patch_embed3 = None
self.stage3 = nn.Sequential(*[
Block(
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1')
dim=embed_dim * 2,
num_heads=num_heads,
head_dim_ratio=1.0,
mlp_ratio=mlp_ratio,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
group=group,
attn_disabled=(attn_stage[2] == '0'),
spatial_conv=(spatial_conv[2] == '1'),
)
for i in range(self.stage_num1+self.stage_num2, depth)
])
# head
self.num_features = embed_dim if self.vit_stem else embed_dim * 2
self.norm = norm_layer(self.num_features)
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
# head
global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.global_pool = global_pool
self.head_drop = nn.Dropout(drop_rate)
self.head = head
# weights init
if self.pos_embed:
if self.use_pos_embed:
trunc_normal_(self.pos_embed1, std=0.02)
if not self.vit_stem:
trunc_normal_(self.pos_embed2, std=0.02)
@ -293,7 +416,7 @@ class Visformer(nn.Module):
# stage 1
x = self.patch_embed1(x)
if self.pos_embed:
if self.pos_embed1 is not None:
x = self.pos_drop(x + self.pos_embed1)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stage1, x)
@ -301,9 +424,9 @@ class Visformer(nn.Module):
x = self.stage1(x)
# stage 2
if not self.vit_stem:
if self.patch_embed2 is not None:
x = self.patch_embed2(x)
if self.pos_embed:
if self.pos_embed2 is not None:
x = self.pos_drop(x + self.pos_embed2)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stage2, x)
@ -311,9 +434,9 @@ class Visformer(nn.Module):
x = self.stage2(x)
# stage3
if not self.vit_stem:
if self.patch_embed3 is not None:
x = self.patch_embed3(x)
if self.pos_embed:
if self.pos_embed3 is not None:
x = self.pos_drop(x + self.pos_embed3)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stage3, x)
@ -325,6 +448,7 @@ class Visformer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
x = self.global_pool(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
@ -345,8 +469,8 @@ def visformer_tiny(pretrained=False, **kwargs):
model_cfg = dict(
init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
embed_norm=nn.BatchNorm2d)
model = _create_visformer('visformer_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
return model
@ -355,8 +479,8 @@ def visformer_small(pretrained=False, **kwargs):
model_cfg = dict(
init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
embed_norm=nn.BatchNorm2d)
model = _create_visformer('visformer_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
return model
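Net effect of the Visformer arg split visible in the hunks above: `drop_rate` previously fed both `pos_drop` and the block projection/MLP dropout; it now feeds only the new classifier `head_drop`, with `pos_drop_rate` and `proj_drop_rate` taking over the old roles. A hedged sketch of how to reproduce the pre-refactor behaviour when constructing the model directly (assumes the usual `timm.models.visformer` module path):

from timm.models.visformer import Visformer

model = Visformer(
    drop_rate=0.,         # now: classifier head dropout only (did not exist before)
    pos_drop_rate=0.1,    # was: covered by drop_rate
    proj_drop_rate=0.1,   # was: covered by drop_rate (attention proj + SpatialMlp)
    attn_drop_rate=0.,
)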

View File

@ -27,7 +27,7 @@ import logging
import math
from collections import OrderedDict
from functools import partial
from typing import Optional, List
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -38,7 +38,7 @@ from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
resample_abs_pos_embed, RmsNorm
resample_abs_pos_embed, RmsNorm, PatchDropout
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@ -119,7 +119,7 @@ class Block(nn.Module):
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
proj_drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
@ -134,7 +134,7 @@ class Block(nn.Module):
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
@ -145,7 +145,7 @@ class Block(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -165,7 +165,7 @@ class ResPostBlock(nn.Module):
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
proj_drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
@ -181,7 +181,7 @@ class ResPostBlock(nn.Module):
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.norm1 = norm_layer(dim)
@ -191,7 +191,7 @@ class ResPostBlock(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -224,7 +224,7 @@ class ParallelScalingBlock(nn.Module):
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
proj_drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
@ -255,7 +255,7 @@ class ParallelScalingBlock(nn.Module):
self.attn_drop = nn.Dropout(attn_drop)
self.attn_out_proj = nn.Linear(dim, dim)
self.mlp_drop = nn.Dropout(drop)
self.mlp_drop = nn.Dropout(proj_drop)
self.mlp_act = act_layer()
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
@ -318,7 +318,7 @@ class ParallelThingsBlock(nn.Module):
qkv_bias=False,
qk_norm=False,
init_values=None,
drop=0.,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
@ -337,7 +337,7 @@ class ParallelThingsBlock(nn.Module):
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
@ -349,7 +349,7 @@ class ParallelThingsBlock(nn.Module):
dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
@ -382,53 +382,58 @@ class VisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
qk_norm=False,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
block_fn=Block,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
no_embed_class: bool = False,
pre_norm: bool = False,
fc_norm: Optional[bool] = None,
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: str = '',
embed_layer: Callable = PatchEmbed,
norm_layer: Optional[Callable] = None,
act_layer: Optional[Callable] = None,
block_fn: Callable = Block,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values: (float): layer-scale init values
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of image input channels.
num_classes: Number of classes for classification head.
global_pool: Type of global pooling for final sequence (default: 'token').
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: Enable bias for qkv projections if True.
init_values: Layer-scale init values (layer-scale enabled if not None).
class_token: Use class token.
fc_norm: Pre-head norm after pool (instead of before); if None, enabled when global_pool == 'avg'.
drop_rate: Head dropout rate.
pos_drop_rate: Position embedding dropout rate.
patch_drop_rate: Patch dropout rate (fraction of patch tokens dropped during training).
proj_drop_rate: Projection dropout rate (attention and MLP projections within blocks).
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
embed_layer: Patch embedding layer.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
@ -456,7 +461,14 @@ class VisionTransformer(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
)
else:
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
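For reference, a short sketch of what the new `patch_drop` stage does to the token sequence: during training it drops a fraction `prob` of the patch tokens while always keeping the class/prefix tokens, and it is a no-op in eval mode or when the rate is 0. Assumes this branch of timm is installed so `PatchDropout` is importable from `timm.layers`:

import torch
from timm.layers import PatchDropout

patch_drop = PatchDropout(prob=0.25, num_prefix_tokens=1)
patch_drop.train()                      # only active in training mode

x = torch.randn(2, 1 + 196, 768)        # [CLS] + 14*14 patch tokens
y = patch_drop(x)
print(y.shape)                          # torch.Size([2, 148, 768]): CLS + 147 kept patches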
@ -468,7 +480,7 @@ class VisionTransformer(nn.Module):
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
drop=drop_rate,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
@ -479,6 +491,7 @@ class VisionTransformer(nn.Module):
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
@ -544,6 +557,7 @@ class VisionTransformer(nn.Module):
def forward_features(self, x):
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
@ -556,6 +570,7 @@ class VisionTransformer(nn.Module):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
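With kwargs forwarded through the model entrypoints, the split rates can be set when creating any vit_* model. A usage sketch (assumes the standard `timm.create_model` kwarg pass-through; `vit_base_patch16_224` is used only as an example name):

import timm

model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    patch_drop_rate=0.5,   # new: drop half of the patch tokens during training (https://arxiv.org/abs/2212.00794)
    pos_drop_rate=0.,      # new: dropout right after the position embedding
    proj_drop_rate=0.,     # new: dropout on attention/MLP projections inside blocks
    drop_rate=0.1,         # now applies only to the classifier head (head_drop)
)
model.train()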

View File

@ -110,7 +110,7 @@ class RelPosBlock(nn.Module):
qk_norm=False,
rel_pos_cls=None,
init_values=None,
drop=0.,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
@ -125,7 +125,7 @@ class RelPosBlock(nn.Module):
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
@ -136,7 +136,7 @@ class RelPosBlock(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -158,7 +158,7 @@ class ResPostRelPosBlock(nn.Module):
qk_norm=False,
rel_pos_cls=None,
init_values=None,
drop=0.,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
@ -174,7 +174,7 @@ class ResPostRelPosBlock(nn.Module):
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
attn_drop=attn_drop,
proj_drop=drop,
proj_drop=proj_drop,
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -183,7 +183,7 @@ class ResPostRelPosBlock(nn.Module):
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
drop=proj_drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -232,6 +232,7 @@ class VisionTransformerRelPos(nn.Module):
rel_pos_dim=None,
shared_rel_pos=False,
drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='skip',
@ -259,6 +260,7 @@ class VisionTransformerRelPos(nn.Module):
rel_pos_type (str): type of relative position
shared_rel_pos (bool): share relative pos across all blocks
drop_rate (float): dropout rate
proj_drop_rate (float): projection dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
@ -313,7 +315,7 @@ class VisionTransformerRelPos(nn.Module):
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
init_values=init_values,
drop=drop_rate,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
@ -324,6 +326,7 @@ class VisionTransformerRelPos(nn.Module):
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
@ -380,6 +383,7 @@ class VisionTransformerRelPos(nn.Module):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):

View File

@ -85,7 +85,17 @@ default_cfgs = {
class OutlookAttention(nn.Module):
def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, qkv_bias=False, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim,
num_heads,
kernel_size=3,
padding=1,
stride=1,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
):
super().__init__()
head_dim = dim // num_heads
self.num_heads = num_heads
@ -133,21 +143,40 @@ class OutlookAttention(nn.Module):
class Outlooker(nn.Module):
def __init__(
self, dim, kernel_size, padding, stride=1, num_heads=1, mlp_ratio=3., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qkv_bias=False
self,
dim,
kernel_size,
padding,
stride=1,
num_heads=1,
mlp_ratio=3.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
qkv_bias=False,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = OutlookAttention(
dim, num_heads, kernel_size=kernel_size,
padding=padding, stride=stride,
qkv_bias=qkv_bias, attn_drop=attn_drop)
dim,
num_heads,
kernel_size=kernel_size,
padding=padding,
stride=stride,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
@ -158,7 +187,13 @@ class Outlooker(nn.Module):
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
@ -189,8 +224,16 @@ class Attention(nn.Module):
class Transformer(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop)
@ -211,7 +254,14 @@ class Transformer(nn.Module):
class ClassAttention(nn.Module):
def __init__(
self, dim, num_heads=8, head_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0.):
self,
dim,
num_heads=8,
head_dim=None,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.,
):
super().__init__()
self.num_heads = num_heads
if head_dim is not None:
@ -246,17 +296,38 @@ class ClassAttention(nn.Module):
class ClassBlock(nn.Module):
def __init__(
self, dim, num_heads, head_dim=None, mlp_ratio=4., qkv_bias=False,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
head_dim=None,
mlp_ratio=4.,
qkv_bias=False,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = ClassAttention(
dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads=num_heads,
head_dim=head_dim,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
def forward(self, x):
cls_embed = x[:, :1]
@ -354,18 +425,33 @@ def outlooker_blocks(
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(
block_fn(
dim, kernel_size=kernel_size, padding=padding,
stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr))
blocks.append(block_fn(
dim,
kernel_size=kernel_size,
padding=padding,
stride=stride,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
drop_path=block_dpr,
))
blocks = nn.Sequential(*blocks)
return blocks
def transformer_blocks(
block_fn, index, dim, layers, num_heads, mlp_ratio=3.,
qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs):
block_fn,
index,
dim,
layers,
num_heads,
mlp_ratio=3.,
qkv_bias=False,
attn_drop=0,
drop_path_rate=0.,
**kwargs,
):
"""
generate transformer layers in stage2
return: transformer layers
@ -373,13 +459,14 @@ def transformer_blocks(
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(
block_fn(
dim, num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
drop_path=block_dpr))
blocks.append(block_fn(
dim,
num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
drop_path=block_dpr,
))
blocks = nn.Sequential(*blocks)
return blocks
@ -405,6 +492,7 @@ class VOLO(nn.Module):
mlp_ratio=3.0,
qkv_bias=False,
drop_rate=0.,
pos_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
@ -429,14 +517,18 @@ class VOLO(nn.Module):
self.grad_checkpointing = False
self.patch_embed = PatchEmbed(
stem_conv=True, stem_stride=2, patch_size=patch_size,
in_chans=in_chans, hidden_dim=stem_hidden_dim,
embed_dim=embed_dims[0])
stem_conv=True,
stem_stride=2,
patch_size=patch_size,
in_chans=in_chans,
hidden_dim=stem_hidden_dim,
embed_dim=embed_dims[0],
)
# initial positional encoding, we add positional encoding after outlooker blocks
patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1]))
self.pos_drop = nn.Dropout(p=drop_rate)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
# set the main block in network
network = []
@ -444,14 +536,31 @@ class VOLO(nn.Module):
if outlook_attention[i]:
# stage 1
stage = outlooker_blocks(
Outlooker, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias, attn_drop=attn_drop_rate, norm_layer=norm_layer)
Outlooker,
i,
embed_dims[i],
layers,
num_heads[i],
mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias,
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
)
network.append(stage)
else:
# stage 2
stage = transformer_blocks(
Transformer, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias,
drop_path_rate=drop_path_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
Transformer,
i,
embed_dims[i],
layers,
num_heads[i],
mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias,
drop_path_rate=drop_path_rate,
attn_drop=attn_drop_rate,
norm_layer=norm_layer,
)
network.append(stage)
if downsamples[i]:
@ -463,19 +572,18 @@ class VOLO(nn.Module):
# set post block, for example, class attention layers
self.post_network = None
if post_layers is not None:
self.post_network = nn.ModuleList(
[
get_block(
post_layers[i],
dim=embed_dims[-1],
num_heads=num_heads[-1],
mlp_ratio=mlp_ratio[-1],
qkv_bias=qkv_bias,
attn_drop=attn_drop_rate,
drop_path=0.,
norm_layer=norm_layer)
for i in range(len(post_layers))
])
self.post_network = nn.ModuleList([
get_block(
post_layers[i],
dim=embed_dims[-1],
num_heads=num_heads[-1],
mlp_ratio=mlp_ratio[-1],
qkv_bias=qkv_bias,
attn_drop=attn_drop_rate,
drop_path=0.,
norm_layer=norm_layer)
for i in range(len(post_layers))
])
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))
trunc_normal_(self.cls_token, std=.02)
@ -487,6 +595,7 @@ class VOLO(nn.Module):
self.norm = norm_layer(self.num_features)
# Classifier head
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02)
@ -630,6 +739,7 @@ class VOLO(nn.Module):
out = x[:, 0]
else:
out = x
x = self.head_drop(x)
if pre_logits:
return out
out = self.head(out)

View File

@ -219,17 +219,28 @@ class ClassAttentionBlock(nn.Module):
"""Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
eta=1.,
tokens_norm=False,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = ClassAttn(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
if eta is not None: # LayerScale Initialization (no layerscale when None)
self.gamma1 = nn.Parameter(eta * torch.ones(dim))
@ -297,18 +308,28 @@ class XCA(nn.Module):
class XCABlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
proj_drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
eta=1.,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm3 = norm_layer(dim)
self.local_mp = LPI(in_features=dim, act_layer=act_layer)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
self.gamma1 = nn.Parameter(eta * torch.ones(dim))
self.gamma3 = nn.Parameter(eta * torch.ones(dim))
@ -331,9 +352,29 @@ class XCiT(nn.Module):
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768,
depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
pos_drop_rate=0.,
proj_drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_layer=None,
norm_layer=None,
cls_attn_layers=2,
use_pos_embed=True,
eta=1.,
tokens_norm=False,
):
"""
Args:
img_size (int, tuple): input image size
@ -346,6 +387,8 @@ class XCiT(nn.Module):
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
pos_drop_rate: position embedding dropout rate
proj_drop_rate (float): projection dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate (constant across all layers)
norm_layer: (nn.Module): normalization layer
@ -372,28 +415,53 @@ class XCiT(nn.Module):
self.grad_checkpointing = False
self.patch_embed = ConvPatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
act_layer=act_layer,
)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.use_pos_embed = use_pos_embed
if use_pos_embed:
self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
self.pos_drop = nn.Dropout(p=drop_rate)
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=pos_drop_rate)
self.blocks = nn.ModuleList([
XCABlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta)
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=drop_path_rate,
act_layer=act_layer,
norm_layer=norm_layer,
eta=eta,
)
for _ in range(depth)])
self.cls_attn_blocks = nn.ModuleList([
ClassAttentionBlock(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=drop_rate,
attn_drop=attn_drop_rate,
act_layer=act_layer,
norm_layer=norm_layer,
eta=eta,
tokens_norm=tokens_norm,
)
for _ in range(cls_attn_layers)])
# Classifier head
self.norm = norm_layer(embed_dim)
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# Init weights
@ -438,7 +506,7 @@ class XCiT(nn.Module):
# x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches)
x, (Hp, Wp) = self.patch_embed(x)
if self.use_pos_embed:
if self.pos_embed is not None:
# `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
x = x + pos_encoding
@ -464,6 +532,7 @@ class XCiT(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x):
@ -509,336 +578,336 @@ def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
@register_model
def xcit_nano_12_p16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_nano_12_p16_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_nano_12_p16_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384, **kwargs)
model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384)
model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_12_p16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_12_p16_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_12_p16_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_12_p16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_12_p16_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_12_p16_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_24_p16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_24_p16_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_24_p16_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_24_p16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_24_p16_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_24_p16_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_medium_24_p16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_medium_24_p16_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_medium_24_p16_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_large_24_p16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_large_24_p16_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_large_24_p16_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
# Patch size 8x8 models
@register_model
def xcit_nano_12_p8_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_nano_12_p8_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_nano_12_p8_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_12_p8_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_12_p8_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_12_p8_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_12_p8_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_12_p8_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_12_p8_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_24_p8_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_24_p8_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_tiny_24_p8_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_24_p8_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_24_p8_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_small_24_p8_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_medium_24_p8_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_medium_24_p8_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_medium_24_p8_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_large_24_p8_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_large_24_p8_224_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def xcit_large_24_p8_384_dist(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
model_args = dict(
patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
return model