Implement patch dropout for eva / vision_transformer, refactor / improve consistency of dropout args across all vit based models
parent 1bb3989b61
commit 4d135421a3
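The convention this commit applies across the ViT-family models below (a sketch, not part of the diff): drop_rate now feeds only the classifier-head dropout, while positional, projection and attention dropout get their own pos_drop_rate / proj_drop_rate / attn_drop_rate arguments; drop_path_rate is untouched. A minimal illustration using the Beit signature changed further down, with arbitrary small sizes; the import path is an assumption for this timm version.

import torch
from timm.models.beit import Beit  # assumed import path

model = Beit(
    embed_dim=192, depth=2, num_heads=3,
    drop_rate=0.1,       # classifier head dropout only (model.head_drop)
    pos_drop_rate=0.1,   # dropout after the position embedding (model.pos_drop)
    proj_drop_rate=0.1,  # attention/MLP projection dropout inside each block
    attn_drop_rate=0.1,  # attention weight dropout
    drop_path_rate=0.1,  # stochastic depth, unchanged by this refactor
)
_ = model(torch.randn(1, 3, 224, 224))  # [1, 1000] logits with the default classifier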
@@ -33,12 +33,14 @@ from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
+from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed
 from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
-    build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
+    build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
+    FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
@@ -0,0 +1,53 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+
+class PatchDropout(nn.Module):
+    """
+    https://arxiv.org/abs/2212.00794
+    """
+    return_indices: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            prob: float = 0.5,
+            num_prefix_tokens: int = 1,
+            ordered: bool = False,
+            return_indices: bool = False,
+    ):
+        super().__init__()
+        assert 0 <= prob < 1.
+        self.prob = prob
+        self.num_prefix_tokens = num_prefix_tokens  # exclude CLS token (or other prefix tokens)
+        self.ordered = ordered
+        self.return_indices = return_indices
+
+    def forward(self, x) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        if not self.training or self.prob == 0.:
+            if self.return_indices:
+                return x, None
+            return x
+
+        if self.num_prefix_tokens:
+            prefix_tokens, x = x[:, :self.num_prefix_tokens], x[:, self.num_prefix_tokens:]
+        else:
+            prefix_tokens = None
+
+        B = x.shape[0]
+        L = x.shape[1]
+        num_keep = max(1, int(L * (1. - self.prob)))
+        keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
+        if self.ordered:
+            # NOTE does not need to maintain patch order in typical transformer use,
+            # but possibly useful for debug / visualization
+            keep_indices = keep_indices.sort(dim=-1)[0]
+        x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
+
+        if prefix_tokens is not None:
+            x = torch.cat((prefix_tokens, x), dim=1)
+
+        if self.return_indices:
+            return x, keep_indices
+        return x
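A minimal usage sketch of the module added above (not part of the diff); the token counts are illustrative.

import torch
from timm.layers import PatchDropout  # exported via the layers __init__ change above

patch_drop = PatchDropout(prob=0.25, num_prefix_tokens=1, return_indices=True)
patch_drop.train()                 # dropout only applies in training mode
x = torch.randn(2, 1 + 196, 192)   # [batch, 1 CLS token + 196 patch tokens, dim]
out, keep_indices = patch_drop(x)
print(out.shape)                   # torch.Size([2, 148, 192]): CLS + int(196 * 0.75) kept patches
print(keep_indices.shape)          # torch.Size([2, 147]): indices into the 196 patch positions
patch_drop.eval()
out, keep_indices = patch_drop(x)  # no-op at inference: out is x, keep_indices is None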
@@ -194,6 +194,8 @@ def rot(x):
 
 
 def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
+    if sin_emb.ndim == 3:
+        return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
     return x * cos_emb + rot(x) * sin_emb
 
 
@@ -205,9 +207,17 @@ def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
 
 def apply_rot_embed_cat(x: torch.Tensor, emb):
     sin_emb, cos_emb = emb.tensor_split(2, -1)
+    if sin_emb.ndim == 3:
+        return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
     return x * cos_emb + rot(x) * sin_emb
 
 
+def apply_keep_indices_nlc(x, pos_embed, keep_indices):
+    pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
+    pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
+    return pos_embed
+
+
 def build_rotary_pos_embed(
         feat_shape: List[int],
         bands: Optional[torch.Tensor] = None,
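apply_keep_indices_nlc keeps a per-position embedding table (e.g. the concatenated ROPE sin/cos table EVA-style models use) aligned with the tokens that survive patch dropout. A rough sketch with a stand-in table; in a real model the table would come from RotaryEmbeddingCat / build_rotary_pos_embed.

import torch
from timm.layers import PatchDropout, apply_keep_indices_nlc

patch_drop = PatchDropout(prob=0.25, num_prefix_tokens=1, return_indices=True)
patch_drop.train()
x = torch.randn(2, 1 + 196, 192)
x, keep_indices = patch_drop(x)              # keep_indices refer to the 196 patch positions

rot_pos_embed = torch.randn(196, 64)         # stand-in for a [num_patches, dim] ROPE table
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
print(rot_pos_embed.shape)                   # torch.Size([2, 147, 64]), one row per kept patch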
@@ -279,6 +279,8 @@ class Beit(nn.Module):
             swiglu_mlp: bool = False,
             scale_mlp: bool = False,
             drop_rate: float = 0.,
+            pos_drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
             norm_layer: Callable = LayerNorm,
@@ -306,7 +308,7 @@ class Beit(nn.Module):
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
-        self.pos_drop = nn.Dropout(p=drop_rate)
+        self.pos_drop = nn.Dropout(p=pos_drop_rate)
 
         if use_shared_rel_pos_bias:
             self.rel_pos_bias = RelativePositionBias(
@@ -325,7 +327,7 @@ class Beit(nn.Module):
                 mlp_ratio=mlp_ratio,
                 scale_mlp=scale_mlp,
                 swiglu_mlp=swiglu_mlp,
-                proj_drop=drop_rate,
+                proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
@@ -337,6 +339,7 @@ class Beit(nn.Module):
         use_fc_norm = self.global_pool == 'avg'
         self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
         self.apply(self._init_weights)
@@ -417,6 +420,7 @@ class Beit(nn.Module):
         if self.global_pool:
             x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
         x = self.fc_norm(x)
+        x = self.head_drop(x)
         return x if pre_logits else self.head(x)
 
     def forward(self, x):
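The models below all converge on the same classifier path: pooled features -> optional fc_norm -> the new head_drop (driven by drop_rate) -> linear head. A standalone sketch of that path, following the Beit hunks above; the sizes are arbitrary.

import torch
import torch.nn as nn

embed_dim, num_classes, drop_rate = 192, 10, 0.1
fc_norm = nn.LayerNorm(embed_dim)         # nn.Identity() when fc_norm is not used
head_drop = nn.Dropout(drop_rate)         # drop_rate now only affects this module
head = nn.Linear(embed_dim, num_classes)

x = torch.randn(2, 1 + 196, embed_dim)    # [B, CLS + patches, C] as produced by the blocks
pooled = x[:, 1:].mean(dim=1)             # 'avg' global pooling over the patch tokens
logits = head(head_drop(fc_norm(pooled)))
print(logits.shape)                       # torch.Size([2, 10])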
@ -110,17 +110,38 @@ class LayerScaleBlockClassAttn(nn.Module):
|
|||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add CA and LayerScale
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
|
||||
mlp_block=Mlp, init_values=1e-4):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
attn_block=ClassAttn,
|
||||
mlp_block=Mlp,
|
||||
init_values=1e-4,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_block(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp = mlp_block(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
|
||||
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
|
@ -177,17 +198,38 @@ class LayerScaleBlock(nn.Module):
|
|||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to add layerScale
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
|
||||
mlp_block=Mlp, init_values=1e-4):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
attn_block=TalkingHeadAttn,
|
||||
mlp_block=Mlp,
|
||||
init_values=1e-4,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_block(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp = mlp_block(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
|
||||
self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
|
@ -201,9 +243,22 @@ class Cait(nn.Module):
|
|||
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# with slight modifications to adapt to our cait models
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
||||
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='token',
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
pos_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
block_layers=LayerScaleBlock,
|
||||
block_layers_token=LayerScaleBlockClassAttn,
|
||||
patch_layer=PatchEmbed,
|
||||
|
@ -226,33 +281,50 @@ class Cait(nn.Module):
|
|||
self.grad_checkpointing = False
|
||||
|
||||
self.patch_embed = patch_layer(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
|
||||
dpr = [drop_path_rate for i in range(depth)]
|
||||
self.blocks = nn.Sequential(*[
|
||||
block_layers(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_values)
|
||||
for i in range(depth)])
|
||||
self.blocks = nn.Sequential(*[block_layers(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
attn_block=attn_block,
|
||||
mlp_block=mlp_block,
|
||||
init_values=init_values,
|
||||
) for i in range(depth)])
|
||||
|
||||
self.blocks_token_only = nn.ModuleList([
|
||||
block_layers_token(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_token_only, qkv_bias=qkv_bias,
|
||||
drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
|
||||
act_layer=act_layer, attn_block=attn_block_token_only,
|
||||
mlp_block=mlp_block_token_only, init_values=init_values)
|
||||
for i in range(depth_token_only)])
|
||||
self.blocks_token_only = nn.ModuleList([block_layers_token(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio_token_only,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
attn_block=attn_block_token_only,
|
||||
mlp_block=mlp_block_token_only,
|
||||
init_values=init_values,
|
||||
) for _ in range(depth_token_only)])
|
||||
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
|
@ -322,6 +394,7 @@ class Cait(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool:
|
||||
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -344,77 +417,80 @@ def _create_cait(variant, pretrained=False, **kwargs):
|
|||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
|
||||
model = build_model_with_cfg(
|
||||
Cait, variant, pretrained,
|
||||
Cait,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_xxs24_224(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
|
||||
model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
|
||||
model = _create_cait('cait_xxs24_224', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
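The entrypoints now build the per-variant defaults first and merge caller kwargs afterwards, so user-supplied values override defaults instead of colliding with them. A small plain-Python illustration (not part of the diff):

model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
kwargs = dict(num_heads=8, num_classes=0)   # what a caller might pass to the entrypoint

merged = dict(model_args, **kwargs)         # caller wins: num_heads == 8, depth stays 24
assert merged['num_heads'] == 8 and merged['depth'] == 24

# the previous pattern, dict(..., num_heads=4, ..., **kwargs), raises
# "TypeError: ... got multiple values for keyword argument 'num_heads'" for the same call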
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_xxs24_384(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs)
|
||||
model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5)
|
||||
model = _create_cait('cait_xxs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_xxs36_224(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
|
||||
model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
|
||||
model = _create_cait('cait_xxs36_224', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_xxs36_384(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs)
|
||||
model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5)
|
||||
model = _create_cait('cait_xxs36_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_xs24_384(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5, **kwargs)
|
||||
model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5)
|
||||
model = _create_cait('cait_xs24_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_s24_224(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
|
||||
model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
|
||||
model = _create_cait('cait_s24_224', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_s24_384(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs)
|
||||
model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5)
|
||||
model = _create_cait('cait_s24_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_s36_384(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6, **kwargs)
|
||||
model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6)
|
||||
model = _create_cait('cait_s36_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_m36_384(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6, **kwargs)
|
||||
model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6)
|
||||
model = _create_cait('cait_m36_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def cait_m48_448(pretrained=False, **kwargs):
|
||||
model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6, **kwargs)
|
||||
model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args)
|
||||
model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6)
|
||||
model = _create_cait('cait_m48_448', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
|
|
@ -54,7 +54,7 @@ default_cfgs = {
|
|||
|
||||
class ConvRelPosEnc(nn.Module):
|
||||
""" Convolutional relative position encoding. """
|
||||
def __init__(self, Ch, h, window):
|
||||
def __init__(self, head_chs, num_heads, window):
|
||||
"""
|
||||
Initialization.
|
||||
Ch: Channels per head.
|
||||
|
@ -70,7 +70,7 @@ class ConvRelPosEnc(nn.Module):
|
|||
|
||||
if isinstance(window, int):
|
||||
# Set the same window size for all attention heads.
|
||||
window = {window: h}
|
||||
window = {window: num_heads}
|
||||
self.window = window
|
||||
elif isinstance(window, dict):
|
||||
self.window = window
|
||||
|
@ -84,18 +84,20 @@ class ConvRelPosEnc(nn.Module):
|
|||
# Determine padding size.
|
||||
# Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
|
||||
padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
|
||||
cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch,
|
||||
cur_conv = nn.Conv2d(
|
||||
cur_head_split * head_chs,
|
||||
cur_head_split * head_chs,
|
||||
kernel_size=(cur_window, cur_window),
|
||||
padding=(padding_size, padding_size),
|
||||
dilation=(dilation, dilation),
|
||||
groups=cur_head_split*Ch,
|
||||
groups=cur_head_split * head_chs,
|
||||
)
|
||||
self.conv_list.append(cur_conv)
|
||||
self.head_splits.append(cur_head_split)
|
||||
self.channel_splits = [x*Ch for x in self.head_splits]
|
||||
self.channel_splits = [x * head_chs for x in self.head_splits]
|
||||
|
||||
def forward(self, q, v, size: Tuple[int, int]):
|
||||
B, h, N, Ch = q.shape
|
||||
B, num_heads, N, C = q.shape
|
||||
H, W = size
|
||||
_assert(N == 1 + H * W, '')
|
||||
|
||||
|
@ -103,13 +105,13 @@ class ConvRelPosEnc(nn.Module):
|
|||
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
|
||||
v_img = v[:, :, 1:, :] # [B, h, H*W, Ch]
|
||||
|
||||
v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W)
|
||||
v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W)
|
||||
v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels
|
||||
conv_v_img_list = []
|
||||
for i, conv in enumerate(self.conv_list):
|
||||
conv_v_img_list.append(conv(v_img_list[i]))
|
||||
conv_v_img = torch.cat(conv_v_img_list, dim=1)
|
||||
conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2)
|
||||
conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2)
|
||||
|
||||
EV_hat = q_img * conv_v_img
|
||||
EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch].
|
||||
|
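For orientation, the renamed ConvRelPosEnc reshapes the per-head value tokens into image layout, splits channels by window-size group, and runs one depthwise conv per group. A rough standalone sketch of that shape flow (dilation fixed at 1, sizes arbitrary, not CoaT's actual module):

import torch
import torch.nn as nn

B, num_heads, head_chs, H, W = 2, 8, 32, 14, 14
window = {3: 2, 5: 3, 7: 3}                  # kernel size -> number of heads in that group

convs, channel_splits = nn.ModuleList(), []
for cur_window, cur_head_split in window.items():
    chs = cur_head_split * head_chs
    convs.append(nn.Conv2d(chs, chs, kernel_size=cur_window, padding=cur_window // 2, groups=chs))
    channel_splits.append(chs)

v_img = torch.randn(B, num_heads, H * W, head_chs)                      # image tokens only, CLS removed
v_img = v_img.transpose(-1, -2).reshape(B, num_heads * head_chs, H, W)
v_img = torch.cat(
    [conv(chunk) for conv, chunk in zip(convs, torch.split(v_img, channel_splits, dim=1))], dim=1)
v_img = v_img.reshape(B, num_heads, head_chs, H * W).transpose(-1, -2)  # back to [B, h, H*W, Ch]
print(v_img.shape)                                                      # torch.Size([2, 8, 196, 32])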
@ -118,7 +120,15 @@ class ConvRelPosEnc(nn.Module):
|
|||
|
||||
class FactorAttnConvRelPosEnc(nn.Module):
|
||||
""" Factorized attention with convolutional relative position encoding class. """
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
shared_crpe=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
@ -188,8 +198,20 @@ class ConvPosEnc(nn.Module):
|
|||
class SerialBlock(nn.Module):
|
||||
""" Serial block class.
|
||||
Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
shared_cpe=None,
|
||||
shared_crpe=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Conv-Attention.
|
||||
|
@ -197,13 +219,24 @@ class SerialBlock(nn.Module):
|
|||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.factoratt_crpe = FactorAttnConvRelPosEnc(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
shared_crpe=shared_crpe,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
# MLP.
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
def forward(self, x, size: Tuple[int, int]):
|
||||
# Conv-Attention.
|
||||
|
@ -222,8 +255,19 @@ class SerialBlock(nn.Module):
|
|||
|
||||
class ParallelBlock(nn.Module):
|
||||
""" Parallel block class. """
|
||||
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None):
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
num_heads,
|
||||
mlp_ratios=[],
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
shared_crpes=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Conv-Attention.
|
||||
|
@ -231,16 +275,28 @@ class ParallelBlock(nn.Module):
|
|||
self.norm13 = norm_layer(dims[2])
|
||||
self.norm14 = norm_layer(dims[3])
|
||||
self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
|
||||
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
||||
shared_crpe=shared_crpes[1]
|
||||
dims[1],
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
shared_crpe=shared_crpes[1],
|
||||
)
|
||||
self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
|
||||
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
||||
shared_crpe=shared_crpes[2]
|
||||
dims[2],
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
shared_crpe=shared_crpes[2],
|
||||
)
|
||||
self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
|
||||
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
|
||||
shared_crpe=shared_crpes[3]
|
||||
dims[3],
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
shared_crpe=shared_crpes[3],
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
|
@ -253,7 +309,11 @@ class ParallelBlock(nn.Module):
|
|||
assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
|
||||
mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
|
||||
self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
|
||||
in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
in_features=dims[1],
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
def upsample(self, x, factor: float, size: Tuple[int, int]):
|
||||
""" Feature map up-sampling. """
|
||||
|
@ -319,10 +379,27 @@ class ParallelBlock(nn.Module):
|
|||
class CoaT(nn.Module):
|
||||
""" CoaT class. """
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
|
||||
serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
return_interm_layers=False, out_features=None, crpe_window=None, global_pool='token'):
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dims=(0, 0, 0, 0),
|
||||
serial_depths=(0, 0, 0, 0),
|
||||
parallel_depth=0,
|
||||
num_heads=0,
|
||||
mlp_ratios=(0, 0, 0, 0),
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
return_interm_layers=False,
|
||||
out_features=None,
|
||||
crpe_window=None,
|
||||
global_pool='token',
|
||||
):
|
||||
super().__init__()
|
||||
assert global_pool in ('token', 'avg')
|
||||
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
|
||||
|
@ -361,21 +438,31 @@ class CoaT(nn.Module):
|
|||
self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)
|
||||
|
||||
# Convolutional relative position encodings.
|
||||
self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window)
|
||||
self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window)
|
||||
self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window)
|
||||
self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window)
|
||||
self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window)
|
||||
self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window)
|
||||
self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window)
|
||||
self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window)
|
||||
|
||||
# Disable stochastic depth.
|
||||
dpr = drop_path_rate
|
||||
assert dpr == 0.0
|
||||
|
||||
skwargs = dict(
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
|
||||
# Serial blocks 1.
|
||||
self.serial_blocks1 = nn.ModuleList([
|
||||
SerialBlock(
|
||||
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
|
||||
shared_cpe=self.cpe1, shared_crpe=self.crpe1
|
||||
dim=embed_dims[0],
|
||||
mlp_ratio=mlp_ratios[0],
|
||||
shared_cpe=self.cpe1,
|
||||
shared_crpe=self.crpe1,
|
||||
**skwargs,
|
||||
)
|
||||
for _ in range(serial_depths[0])]
|
||||
)
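The four serial stages (and the parallel blocks further down) now share one skwargs dict instead of repeating the same keyword list. A tiny self-contained illustration of the pattern, using a stub class rather than CoaT's real SerialBlock:

import torch.nn as nn

class StubBlock(nn.Module):  # stand-in, only to show the shared-kwargs construction pattern
    def __init__(self, dim, mlp_ratio, num_heads, qkv_bias, proj_drop, attn_drop, drop_path, norm_layer):
        super().__init__()
        self.norm = norm_layer(dim)

skwargs = dict(num_heads=8, qkv_bias=True, proj_drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm)
stage1 = nn.ModuleList([StubBlock(dim=64, mlp_ratio=8., **skwargs) for _ in range(3)])
stage2 = nn.ModuleList([StubBlock(dim=128, mlp_ratio=8., **skwargs) for _ in range(4)])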
|
||||
|
@ -383,9 +470,11 @@ class CoaT(nn.Module):
|
|||
# Serial blocks 2.
|
||||
self.serial_blocks2 = nn.ModuleList([
|
||||
SerialBlock(
|
||||
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
|
||||
shared_cpe=self.cpe2, shared_crpe=self.crpe2
|
||||
dim=embed_dims[1],
|
||||
mlp_ratio=mlp_ratios[1],
|
||||
shared_cpe=self.cpe2,
|
||||
shared_crpe=self.crpe2,
|
||||
**skwargs,
|
||||
)
|
||||
for _ in range(serial_depths[1])]
|
||||
)
|
||||
|
@ -393,9 +482,11 @@ class CoaT(nn.Module):
|
|||
# Serial blocks 3.
|
||||
self.serial_blocks3 = nn.ModuleList([
|
||||
SerialBlock(
|
||||
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
|
||||
shared_cpe=self.cpe3, shared_crpe=self.crpe3
|
||||
dim=embed_dims[2],
|
||||
mlp_ratio=mlp_ratios[2],
|
||||
shared_cpe=self.cpe3,
|
||||
shared_crpe=self.crpe3,
|
||||
**skwargs,
|
||||
)
|
||||
for _ in range(serial_depths[2])]
|
||||
)
|
||||
|
@ -403,9 +494,11 @@ class CoaT(nn.Module):
|
|||
# Serial blocks 4.
|
||||
self.serial_blocks4 = nn.ModuleList([
|
||||
SerialBlock(
|
||||
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
|
||||
shared_cpe=self.cpe4, shared_crpe=self.crpe4
|
||||
dim=embed_dims[3],
|
||||
mlp_ratio=mlp_ratios[3],
|
||||
shared_cpe=self.cpe4,
|
||||
shared_crpe=self.crpe4,
|
||||
**skwargs,
|
||||
)
|
||||
for _ in range(serial_depths[3])]
|
||||
)
|
||||
|
@ -415,9 +508,10 @@ class CoaT(nn.Module):
|
|||
if self.parallel_depth > 0:
|
||||
self.parallel_blocks = nn.ModuleList([
|
||||
ParallelBlock(
|
||||
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
|
||||
shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4)
|
||||
dims=embed_dims,
|
||||
mlp_ratios=mlp_ratios,
|
||||
shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4),
|
||||
**skwargs,
|
||||
)
|
||||
for _ in range(parallel_depth)]
|
||||
)
|
||||
|
@ -437,10 +531,12 @@ class CoaT(nn.Module):
|
|||
# CoaT series: Aggregate features of last three scales for classification.
|
||||
assert embed_dims[1] == embed_dims[2] == embed_dims[3]
|
||||
self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
else:
|
||||
# CoaT-Lite series: Use feature of last scale for classification.
|
||||
self.aggregate = None
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
# Initialize weights.
|
||||
|
@ -587,6 +683,7 @@ class CoaT(nn.Module):
|
|||
x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
|
||||
else:
|
||||
x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
|
|
|
@ -62,7 +62,14 @@ default_cfgs = {
|
|||
@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
|
||||
class GPSA(nn.Module):
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.):
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
locality_strength=1.,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.dim = dim
|
||||
|
@ -145,7 +152,14 @@ class GPSA(nn.Module):
|
|||
|
||||
|
||||
class MHSA(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
@ -195,20 +209,48 @@ class MHSA(nn.Module):
|
|||
class Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
use_gpsa=True,
|
||||
locality_strength=1.,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.use_gpsa = use_gpsa
|
||||
if self.use_gpsa:
|
||||
self.attn = GPSA(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
locality_strength=locality_strength,
|
||||
)
|
||||
else:
|
||||
self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.attn = MHSA(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
|
@ -221,10 +263,28 @@ class ConViT(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
||||
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,
|
||||
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='token',
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop_rate=0.,
|
||||
pos_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
hybrid_backbone=None,
|
||||
norm_layer=nn.LayerNorm,
|
||||
local_up_to_layer=3,
|
||||
locality_strength=1.,
|
||||
use_pos_embed=True,
|
||||
):
|
||||
super().__init__()
|
||||
assert global_pool in ('', 'avg', 'token')
|
||||
embed_dim *= num_heads
|
||||
|
@ -245,7 +305,7 @@ class ConViT(nn.Module):
|
|||
self.num_patches = num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
|
||||
if self.use_pos_embed:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
||||
|
@ -254,20 +314,22 @@ class ConViT(nn.Module):
|
|||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
use_gpsa=True,
|
||||
locality_strength=locality_strength)
|
||||
if i < local_up_to_layer else
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
use_gpsa=False)
|
||||
for i in range(depth)])
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
use_gpsa=i < local_up_to_layer,
|
||||
locality_strength=locality_strength,
|
||||
) for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
# Classifier head
|
||||
self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
@ -327,6 +389,7 @@ class ConViT(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool:
|
||||
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -42,8 +42,18 @@ class Residual(nn.Module):
|
|||
|
||||
class ConvMixer(nn.Module):
|
||||
def __init__(
|
||||
self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
act_layer=nn.GELU, **kwargs):
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
kernel_size=9,
|
||||
patch_size=7,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='avg',
|
||||
drop_rate=0.,
|
||||
act_layer=nn.GELU,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = dim
|
||||
|
@ -67,6 +77,7 @@ class ConvMixer(nn.Module):
|
|||
) for i in range(depth)]
|
||||
)
|
||||
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
@torch.jit.ignore
|
||||
|
@ -98,6 +109,7 @@ class ConvMixer(nn.Module):
|
|||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
x = self.pooling(x)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -130,12 +130,19 @@ class PatchEmbed(nn.Module):
|
|||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
self.wk = nn.Linear(dim, dim, bias=qkv_bias)
|
||||
|
@ -166,12 +173,26 @@ class CrossAttention(nn.Module):
|
|||
class CrossAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = CrossAttention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
|
@ -182,8 +203,20 @@ class CrossAttentionBlock(nn.Module):
|
|||
|
||||
class MultiScaleBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
patches,
|
||||
depth,
|
||||
num_heads,
|
||||
mlp_ratio,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_branches = len(dim)
|
||||
|
@ -194,8 +227,15 @@ class MultiScaleBlock(nn.Module):
|
|||
tmp = []
|
||||
for i in range(depth[d]):
|
||||
tmp.append(Block(
|
||||
dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
|
||||
drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))
|
||||
dim=dim[d],
|
||||
num_heads=num_heads[d],
|
||||
mlp_ratio=mlp_ratio[d],
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i],
|
||||
norm_layer=norm_layer,
|
||||
))
|
||||
if len(tmp) != 0:
|
||||
self.blocks.append(nn.Sequential(*tmp))
|
||||
|
||||
|
@ -217,14 +257,28 @@ class MultiScaleBlock(nn.Module):
|
|||
if depth[-1] == 0: # backward capability:
|
||||
self.fusion.append(
|
||||
CrossAttentionBlock(
|
||||
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
|
||||
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
|
||||
dim=dim[d_],
|
||||
num_heads=nh,
|
||||
mlp_ratio=mlp_ratio[d],
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[-1],
|
||||
norm_layer=norm_layer,
|
||||
))
|
||||
else:
|
||||
tmp = []
|
||||
for _ in range(depth[-1]):
|
||||
tmp.append(CrossAttentionBlock(
|
||||
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
|
||||
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
|
||||
dim=dim[d_],
|
||||
num_heads=nh,
|
||||
mlp_ratio=mlp_ratio[d],
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[-1],
|
||||
norm_layer=norm_layer,
|
||||
))
|
||||
self.fusion.append(nn.Sequential(*tmp))
|
||||
|
||||
self.revert_projs = nn.ModuleList()
|
||||
|
@ -288,10 +342,26 @@ class CrossViT(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
|
||||
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
|
||||
multi_conv=False, crop_scale=False, qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), global_pool='token',
|
||||
self,
|
||||
img_size=224,
|
||||
img_scale=(1.0, 1.0),
|
||||
patch_size=(8, 16),
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=(192, 384),
|
||||
depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
|
||||
num_heads=(6, 12),
|
||||
mlp_ratio=(2., 2., 4.),
|
||||
multi_conv=False,
|
||||
crop_scale=False,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
pos_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
global_pool='token',
|
||||
):
|
||||
super().__init__()
|
||||
assert global_pool in ('token', 'avg')
|
||||
|
@ -315,9 +385,15 @@ class CrossViT(nn.Module):
|
|||
|
||||
for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
|
||||
self.patch_embed.append(
|
||||
PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
|
||||
PatchEmbed(
|
||||
img_size=im_s,
|
||||
patch_size=p,
|
||||
in_chans=in_chans,
|
||||
embed_dim=d,
|
||||
multi_conv=multi_conv,
|
||||
))
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
|
||||
total_depth = sum([sum(x[-2:]) for x in depth])
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
|
||||
|
@ -327,12 +403,22 @@ class CrossViT(nn.Module):
|
|||
curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
|
||||
dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
|
||||
blk = MultiScaleBlock(
|
||||
embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer)
|
||||
embed_dim,
|
||||
num_patches,
|
||||
block_cfg,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr_,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
dpr_ptr += curr_depth
|
||||
self.blocks.append(blk)
|
||||
|
||||
self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.ModuleList([
|
||||
nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
|
||||
for i in range(self.num_branches)])
|
||||
|
@ -411,6 +497,7 @@ class CrossViT(nn.Module):
|
|||
|
||||
def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
|
||||
xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
|
||||
xs = [self.head_drop(x) for x in xs]
|
||||
if pre_logits or isinstance(self.head[0], nn.Identity):
|
||||
return torch.cat([x for x in xs], dim=1)
|
||||
return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)
|
||||
|
@ -436,17 +523,20 @@ def _create_crossvit(variant, pretrained=False, **kwargs):
|
|||
return new_state_dict
|
||||
|
||||
return build_model_with_cfg(
|
||||
CrossViT, variant, pretrained,
|
||||
CrossViT,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=pretrained_filter_fn,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@register_model
|
||||
def crossvit_tiny_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
||||
num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[3, 3], mlp_ratio=[4, 4, 1])
|
||||
model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -454,8 +544,8 @@ def crossvit_tiny_240(pretrained=False, **kwargs):
|
|||
def crossvit_small_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
||||
num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[6, 6], mlp_ratio=[4, 4, 1])
|
||||
model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -463,8 +553,8 @@ def crossvit_small_240(pretrained=False, **kwargs):
|
|||
def crossvit_base_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
||||
num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[12, 12], mlp_ratio=[4, 4, 1])
|
||||
model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -472,8 +562,8 @@ def crossvit_base_240(pretrained=False, **kwargs):
|
|||
def crossvit_9_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
|
||||
num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[4, 4], mlp_ratio=[3, 3, 1])
|
||||
model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -481,8 +571,8 @@ def crossvit_9_240(pretrained=False, **kwargs):
|
|||
def crossvit_15_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
||||
num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[6, 6], mlp_ratio=[3, 3, 1])
|
||||
model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -491,7 +581,7 @@ def crossvit_18_240(pretrained=False, **kwargs):
|
|||
model_args = dict(
|
||||
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
||||
num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args)
|
||||
model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -499,8 +589,8 @@ def crossvit_18_240(pretrained=False, **kwargs):
|
|||
def crossvit_9_dagger_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
|
||||
num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True)
|
||||
model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -508,8 +598,8 @@ def crossvit_9_dagger_240(pretrained=False, **kwargs):
|
|||
def crossvit_15_dagger_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
||||
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
|
||||
model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -517,8 +607,8 @@ def crossvit_15_dagger_240(pretrained=False, **kwargs):
|
|||
def crossvit_15_dagger_408(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
||||
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args)
|
||||
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
|
||||
model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -526,8 +616,8 @@ def crossvit_15_dagger_408(pretrained=False, **kwargs):
|
|||
def crossvit_18_dagger_240(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
||||
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args)
|
||||
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
|
||||
model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -535,6 +625,6 @@ def crossvit_18_dagger_240(pretrained=False, **kwargs):
|
|||
def crossvit_18_dagger_408(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
||||
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
||||
model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args)
|
||||
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
|
||||
model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
|
|
@ -468,7 +468,6 @@ class DaViT(nn.Module):
|
|||
ffn=True,
|
||||
cpe_act=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_classes=1000,
|
||||
global_pool='avg',
|
||||
|
|
|
@ -21,49 +21,11 @@ from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdap
|
|||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._registry import register_model
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
|
||||
__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.0', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = dict(
|
||||
edgenext_xx_small=_cfg(
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth",
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
edgenext_x_small=_cfg(
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth",
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
# edgenext_small=_cfg(
|
||||
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"),
|
||||
edgenext_small=_cfg( # USI weights
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
|
||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
# edgenext_base=_cfg(
|
||||
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
|
||||
edgenext_base=_cfg( # USI weights
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
|
||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
|
||||
edgenext_small_rw=_cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
|
||||
test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
|
||||
class PositionalEncodingFourier(nn.Module):
|
||||
def __init__(self, hidden_dim=32, dim=768, temperature=10000):
|
||||
|
@ -519,6 +481,43 @@ def _create_edgenext(variant, pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.0', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'edgenext_xx_small.in1k': _cfg(
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth",
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'edgenext_x_small.in1k': _cfg(
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth",
|
||||
test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'edgenext_small.usi_in1k': _cfg( # USI weights
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
|
||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
'edgenext_base.usi_in1k': _cfg( # USI weights
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
|
||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
'edgenext_base.in21k_ft_in1k': _cfg( # IN21K pretrained, IN1K fine-tuned weights
|
||||
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.21/edgenext_base_IN21K.pth",
|
||||
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
'edgenext_small_rw.sw_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
|
||||
test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
def edgenext_xx_small(pretrained=False, **kwargs):
|
||||
# 1.33M & 260.58M @ 256 resolution
|
||||
|
|
|
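The EdgeNeXt hunks above move the weight registry to `generate_default_cfgs`, keying each checkpoint as `architecture.tag` (e.g. `edgenext_small.usi_in1k`). A minimal usage sketch of the tagged names; it assumes a timm build new enough to resolve `model.tag` strings and network access to the release URLs listed above:

import timm

# Tagged names registered above can be listed and instantiated directly.
print(timm.list_models('edgenext*', pretrained=True))

# An explicit tag selects a specific weight set; the bare name falls back to the default tag.
m_usi = timm.create_model('edgenext_small.usi_in1k', pretrained=True)
m_def = timm.create_model('edgenext_small', pretrained=True)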
@ -211,7 +211,7 @@ class MetaBlock1d(nn.Module):
|
|||
mlp_ratio=4.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
drop_path=0.,
|
||||
layer_scale_init_value=1e-5
|
||||
):
|
||||
|
@ -219,7 +219,12 @@ class MetaBlock1d(nn.Module):
|
|||
self.norm1 = norm_layer(dim)
|
||||
self.token_mixer = Attention(dim)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.ls1 = LayerScale(dim, layer_scale_init_value)
|
||||
|
@ -251,7 +256,7 @@ class MetaBlock2d(nn.Module):
|
|||
mlp_ratio=4.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
drop_path=0.,
|
||||
layer_scale_init_value=1e-5
|
||||
):
|
||||
|
@ -261,7 +266,12 @@ class MetaBlock2d(nn.Module):
|
|||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.mlp = ConvMlpWithNorm(
|
||||
dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, norm_layer=norm_layer, drop=drop)
|
||||
dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.ls2 = LayerScale2d(dim, layer_scale_init_value)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
|
@ -285,7 +295,7 @@ class EfficientFormerStage(nn.Module):
|
|||
act_layer=nn.GELU,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
norm_layer_cl=nn.LayerNorm,
|
||||
drop=.0,
|
||||
proj_drop=.0,
|
||||
drop_path=0.,
|
||||
layer_scale_init_value=1e-5,
|
||||
):
|
||||
|
@ -312,7 +322,7 @@ class EfficientFormerStage(nn.Module):
|
|||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer_cl,
|
||||
drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
drop_path=drop_path[block_idx],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
))
|
||||
|
@ -324,7 +334,7 @@ class EfficientFormerStage(nn.Module):
|
|||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
drop_path=drop_path[block_idx],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
))
|
||||
|
@ -360,6 +370,7 @@ class EfficientFormer(nn.Module):
|
|||
norm_layer=nn.BatchNorm2d,
|
||||
norm_layer_cl=nn.LayerNorm,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
**kwargs
|
||||
):
|
||||
|
@ -386,7 +397,7 @@ class EfficientFormer(nn.Module):
|
|||
act_layer=act_layer,
|
||||
norm_layer_cl=norm_layer_cl,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
)
|
||||
|
@ -398,6 +409,7 @@ class EfficientFormer(nn.Module):
|
|||
# Classifier head
|
||||
self.num_features = embed_dims[-1]
|
||||
self.norm = norm_layer_cl(self.num_features)
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
# assuming model is always distilled (valid for current checkpoints, will split def if that changes)
|
||||
self.head_dist = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
@ -453,6 +465,7 @@ class EfficientFormer(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
x = self.head_drop(x)
|
||||
if pre_logits:
|
||||
return x
|
||||
x, x_dist = self.head(x), self.head_dist(x)
|
||||
|
|
|
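Net effect of the EfficientFormer changes above: `drop_rate` now only feeds the new `head_drop` before the classifier, while dropout inside each block's MLP/projection is controlled by the separate `proj_drop_rate`. A hedged sketch of passing both through `create_model`; the entrypoint name and rates are illustrative assumptions:

import timm

model = timm.create_model(
    'efficientformer_l1',   # assumed EfficientFormer entrypoint; other variants follow the same pattern
    pretrained=False,
    drop_rate=0.2,          # classifier/head dropout (self.head_drop)
    proj_drop_rate=0.1,     # per-block MLP / projection dropout
)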
@ -342,7 +342,8 @@ class ConvMlpWithNorm(nn.Module):
|
|||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = ConvNormAct(
|
||||
in_features, hidden_features, 1, bias=True, norm_layer=norm_layer, act_layer=act_layer)
|
||||
in_features, hidden_features, 1,
|
||||
bias=True, norm_layer=norm_layer, act_layer=act_layer)
|
||||
if mid_conv:
|
||||
self.mid = ConvNormAct(
|
||||
hidden_features, hidden_features, 3,
|
||||
|
@ -380,7 +381,7 @@ class EfficientFormerV2Block(nn.Module):
|
|||
mlp_ratio=4.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.BatchNorm2d,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
drop_path=0.,
|
||||
layer_scale_init_value=1e-5,
|
||||
resolution=7,
|
||||
|
@ -409,7 +410,7 @@ class EfficientFormerV2Block(nn.Module):
|
|||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
mid_conv=True,
|
||||
)
|
||||
self.ls2 = LayerScale2d(
|
||||
|
@ -451,7 +452,7 @@ class EfficientFormerV2Stage(nn.Module):
|
|||
block_use_attn=False,
|
||||
num_vit=1,
|
||||
mlp_ratio=4.,
|
||||
drop=.0,
|
||||
proj_drop=.0,
|
||||
drop_path=0.,
|
||||
layer_scale_init_value=1e-5,
|
||||
act_layer=nn.GELU,
|
||||
|
@ -487,7 +488,7 @@ class EfficientFormerV2Stage(nn.Module):
|
|||
stride=block_stride,
|
||||
mlp_ratio=mlp_ratio[block_idx],
|
||||
use_attn=block_use_attn and block_idx > remain_idx,
|
||||
drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
drop_path=drop_path[block_idx],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
act_layer=act_layer,
|
||||
|
@ -520,6 +521,7 @@ class EfficientFormerV2(nn.Module):
|
|||
act_layer='gelu',
|
||||
num_classes=1000,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
layer_scale_init_value=1e-5,
|
||||
num_vit=0,
|
||||
|
@ -556,7 +558,7 @@ class EfficientFormerV2(nn.Module):
|
|||
block_use_attn=i >= 2,
|
||||
num_vit=num_vit,
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
act_layer=act_layer,
|
||||
|
@ -572,6 +574,7 @@ class EfficientFormerV2(nn.Module):
|
|||
# Classifier head
|
||||
self.num_features = embed_dims[-1]
|
||||
self.norm = norm_layer(embed_dims[-1])
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.dist = distillation
|
||||
if self.dist:
|
||||
|
@ -630,6 +633,7 @@ class EfficientFormerV2(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=(2, 3))
|
||||
x = self.head_drop(x)
|
||||
if pre_logits:
|
||||
return x
|
||||
x, x_dist = self.head(x), self.head_dist(x)
|
||||
|
|
|
@ -34,8 +34,8 @@ import torch.nn.functional as F
|
|||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, RotaryEmbeddingCat, \
|
||||
apply_rot_embed_cat, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
|
||||
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
|
||||
apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, to_2tuple
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import generate_default_cfgs, register_model
|
||||
|
@ -355,10 +355,14 @@ class Eva(nn.Module):
|
|||
scale_mlp: bool = False,
|
||||
scale_attn_inner: bool = False,
|
||||
drop_rate: float = 0.,
|
||||
pos_drop_rate: float = 0.,
|
||||
patch_drop_rate: float = 0.,
|
||||
proj_drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
norm_layer: Callable = LayerNorm,
|
||||
init_values: Optional[float] = None,
|
||||
class_token: bool = True,
|
||||
use_abs_pos_emb: bool = True,
|
||||
use_rot_pos_emb: bool = False,
|
||||
use_post_norm: bool = False,
|
||||
|
@ -383,10 +387,13 @@ class Eva(nn.Module):
|
|||
scale_mlp:
|
||||
scale_attn_inner:
|
||||
drop_rate:
|
||||
pos_drop_rate:
|
||||
patch_drop_rate:
|
||||
proj_drop_rate:
|
||||
attn_drop_rate:
|
||||
drop_path_rate:
|
||||
norm_layer:
|
||||
init_values:
|
||||
class_token:
|
||||
use_abs_pos_emb:
|
||||
use_rot_pos_emb:
|
||||
use_post_norm:
|
||||
|
@ -397,7 +404,7 @@ class Eva(nn.Module):
|
|||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_prefix_tokens = 1
|
||||
self.num_prefix_tokens = 1 if class_token else 0
|
||||
self.grad_checkpointing = False
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
|
@ -408,9 +415,19 @@ class Eva(nn.Module):
|
|||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
||||
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
if patch_drop_rate > 0:
|
||||
self.patch_drop = PatchDropout(
|
||||
patch_drop_rate,
|
||||
num_prefix_tokens=self.num_prefix_tokens,
|
||||
return_indices=True,
|
||||
)
|
||||
else:
|
||||
self.patch_drop = None
|
||||
|
||||
if use_rot_pos_emb:
|
||||
ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
|
||||
|
@ -435,7 +452,7 @@ class Eva(nn.Module):
|
|||
swiglu_mlp=swiglu_mlp,
|
||||
scale_mlp=scale_mlp,
|
||||
scale_attn_inner=scale_attn_inner,
|
||||
proj_drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
|
@ -446,6 +463,7 @@ class Eva(nn.Module):
|
|||
use_fc_norm = self.global_pool == 'avg'
|
||||
self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
@ -505,12 +523,21 @@ class Eva(nn.Module):
|
|||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
|
||||
# apply abs position embedding
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
# obtain shared rotary position embedding and apply patch dropout
|
||||
rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
|
||||
if self.patch_drop is not None:
|
||||
x, keep_indices = self.patch_drop(x)
|
||||
if rot_pos_embed is not None and keep_indices is not None:
|
||||
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
|
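The forward path above drops a random subset of patch tokens during training and, because the module was built with `return_indices=True`, gets back the kept indices so the shared rotary embedding can be subset with `apply_keep_indices_nlc`. A standalone sketch of the token-dropping behaviour with illustrative shapes:

import torch
from timm.layers import PatchDropout

x = torch.randn(2, 1 + 196, 768)           # [batch, cls + patch tokens, dim], sizes chosen for illustration
patch_drop = PatchDropout(0.25, num_prefix_tokens=1, return_indices=True)
patch_drop.train()                          # token dropout is only active in training mode
out, keep_indices = patch_drop(x)
print(out.shape)                            # torch.Size([2, 148, 768]): cls token kept, 147 of 196 patches kept
print(keep_indices.shape)                   # torch.Size([2, 147]); used above to subset rot_pos_embed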
@ -525,6 +552,7 @@ class Eva(nn.Module):
|
|||
if self.global_pool:
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.fc_norm(x)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -75,12 +75,6 @@ class FocalModulation(nn.Module):
|
|||
self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: input features with shape of (B, H, W, C)
|
||||
"""
|
||||
C = x.shape[1]
|
||||
|
||||
# pre linear projection
|
||||
x = self.f(x)
|
||||
q, ctx, gates = torch.split(x, self.input_split, 1)
|
||||
|
|
|
@ -192,6 +192,7 @@ class MlpMixer(nn.Module):
|
|||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
act_layer=nn.GELU,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
nlhb=False,
|
||||
stem_norm=False,
|
||||
|
@ -219,11 +220,12 @@ class MlpMixer(nn.Module):
|
|||
mlp_layer=mlp_layer,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
drop=drop_rate,
|
||||
drop=proj_drop_rate,
|
||||
drop_path=drop_path_rate,
|
||||
)
|
||||
for _ in range(num_blocks)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.init_weights(nlhb=nlhb)
|
||||
|
@ -267,6 +269,7 @@ class MlpMixer(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -271,7 +271,7 @@ class MobileVitBlock(nn.Module):
|
|||
num_heads=num_heads,
|
||||
qkv_bias=True,
|
||||
attn_drop=attn_drop,
|
||||
drop=drop,
|
||||
proj_drop=drop,
|
||||
drop_path=drop_path_rate,
|
||||
act_layer=layers.act,
|
||||
norm_layer=transformer_norm_layer,
|
||||
|
|
|
@ -27,43 +27,11 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|||
from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._registry import register_model
|
||||
from ._registry import register_model, register_model_deprecations, generate_default_cfgs
|
||||
|
||||
__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
'fixed_input_size': True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = dict(
|
||||
mvitv2_tiny=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
|
||||
mvitv2_small=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
|
||||
mvitv2_base=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
|
||||
mvitv2_large=_cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
|
||||
|
||||
mvitv2_base_in21k=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
|
||||
num_classes=19168),
|
||||
mvitv2_large_in21k=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
|
||||
num_classes=19168),
|
||||
mvitv2_huge_in21k=_cfg(
|
||||
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
|
||||
num_classes=19168),
|
||||
|
||||
mvitv2_small_cls=_cfg(url=''),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiScaleVitCfg:
|
||||
depths: Tuple[int, ...] = (2, 3, 16, 3)
|
||||
|
@ -113,40 +81,6 @@ class MultiScaleVitCfg:
|
|||
self.stride_kv = tuple(pool_kv_stride)
|
||||
|
||||
|
||||
model_cfgs = dict(
|
||||
mvitv2_tiny=MultiScaleVitCfg(
|
||||
depths=(1, 2, 5, 2),
|
||||
),
|
||||
mvitv2_small=MultiScaleVitCfg(
|
||||
depths=(1, 2, 11, 2),
|
||||
),
|
||||
mvitv2_base=MultiScaleVitCfg(
|
||||
depths=(2, 3, 16, 3),
|
||||
),
|
||||
mvitv2_large=MultiScaleVitCfg(
|
||||
depths=(2, 6, 36, 4),
|
||||
embed_dim=144,
|
||||
num_heads=2,
|
||||
expand_attn=False,
|
||||
),
|
||||
|
||||
mvitv2_base_in21k=MultiScaleVitCfg(
|
||||
depths=(2, 3, 16, 3),
|
||||
),
|
||||
mvitv2_large_in21k=MultiScaleVitCfg(
|
||||
depths=(2, 6, 36, 4),
|
||||
embed_dim=144,
|
||||
num_heads=2,
|
||||
expand_attn=False,
|
||||
),
|
||||
|
||||
mvitv2_small_cls=MultiScaleVitCfg(
|
||||
depths=(1, 2, 11, 2),
|
||||
use_cls_token=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def prod(iterable):
|
||||
return reduce(operator.mul, iterable, 1)
|
||||
|
||||
|
@ -229,26 +163,32 @@ def cal_rel_pos_type(
|
|||
# Scale up rel pos if shapes for q and k are different.
|
||||
q_h_ratio = max(k_h / q_h, 1.0)
|
||||
k_h_ratio = max(q_h / k_h, 1.0)
|
||||
dist_h = torch.arange(q_h)[:, None] * q_h_ratio - torch.arange(k_h)[None, :] * k_h_ratio
|
||||
dist_h = (
|
||||
torch.arange(q_h, device=q.device).unsqueeze(-1) * q_h_ratio -
|
||||
torch.arange(k_h, device=q.device).unsqueeze(0) * k_h_ratio
|
||||
)
|
||||
dist_h += (k_h - 1) * k_h_ratio
|
||||
q_w_ratio = max(k_w / q_w, 1.0)
|
||||
k_w_ratio = max(q_w / k_w, 1.0)
|
||||
dist_w = torch.arange(q_w)[:, None] * q_w_ratio - torch.arange(k_w)[None, :] * k_w_ratio
|
||||
dist_w = (
|
||||
torch.arange(q_w, device=q.device).unsqueeze(-1) * q_w_ratio -
|
||||
torch.arange(k_w, device=q.device).unsqueeze(0) * k_w_ratio
|
||||
)
|
||||
dist_w += (k_w - 1) * k_w_ratio
|
||||
|
||||
Rh = rel_pos_h[dist_h.long()]
|
||||
Rw = rel_pos_w[dist_w.long()]
|
||||
rel_h = rel_pos_h[dist_h.long()]
|
||||
rel_w = rel_pos_w[dist_w.long()]
|
||||
|
||||
B, n_head, q_N, dim = q.shape
|
||||
|
||||
r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, Rw)
|
||||
rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h)
|
||||
rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, rel_w)
|
||||
|
||||
attn[:, :, sp_idx:, sp_idx:] = (
|
||||
attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
|
||||
+ rel_h[:, :, :, :, :, None]
|
||||
+ rel_w[:, :, :, :, None, :]
|
||||
+ rel_h.unsqueeze(-1)
|
||||
+ rel_w.unsqueeze(-2)
|
||||
).view(B, -1, q_h * q_w, k_h * k_w)
|
||||
|
||||
return attn
|
||||
|
@ -390,18 +330,18 @@ class MultiScaleAttentionPoolFirst(nn.Module):
|
|||
v = self.norm_v(v)
|
||||
|
||||
q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
|
||||
q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1)
|
||||
q = self.q(q).reshape(B, q_N, self.num_heads, -1).permute(0, 2, 1, 3)
|
||||
q = q.transpose(1, 2).reshape(B, q_N, -1)
|
||||
q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
|
||||
k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1)
|
||||
k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 1, 3)
|
||||
k = k.transpose(1, 2).reshape(B, k_N, -1)
|
||||
k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 3, 1)
|
||||
|
||||
v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
|
||||
v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1)
|
||||
v = self.v(v).reshape(B, v_N, self.num_heads, -1).permute(0, 2, 1, 3)
|
||||
v = v.transpose(1, 2).reshape(B, v_N, -1)
|
||||
v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
attn = (q * self.scale) @ k
|
||||
if self.rel_pos_type == 'spatial':
|
||||
attn = cal_rel_pos_type(
|
||||
attn,
|
||||
|
@ -764,7 +704,7 @@ class MultiScaleVit(nn.Module):
|
|||
cfg: MultiScaleVitCfg,
|
||||
img_size: Tuple[int, int] = (224, 224),
|
||||
in_chans: int = 3,
|
||||
global_pool: str = 'avg',
|
||||
global_pool: Optional[str] = None,
|
||||
num_classes: int = 1000,
|
||||
drop_path_rate: float = 0.,
|
||||
drop_rate: float = 0.,
|
||||
|
@ -774,6 +714,8 @@ class MultiScaleVit(nn.Module):
|
|||
norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
if global_pool is None:
|
||||
global_pool = 'token' if cfg.use_cls_token else 'avg'
|
||||
self.global_pool = global_pool
|
||||
self.depths = tuple(cfg.depths)
|
||||
self.expand_attn = cfg.expand_attn
|
||||
|
@ -963,13 +905,89 @@ def checkpoint_filter_fn(state_dict, model):
|
|||
return out_dict
|
||||
|
||||
|
||||
model_cfgs = dict(
|
||||
mvitv2_tiny=MultiScaleVitCfg(
|
||||
depths=(1, 2, 5, 2),
|
||||
),
|
||||
mvitv2_small=MultiScaleVitCfg(
|
||||
depths=(1, 2, 11, 2),
|
||||
),
|
||||
mvitv2_base=MultiScaleVitCfg(
|
||||
depths=(2, 3, 16, 3),
|
||||
),
|
||||
mvitv2_large=MultiScaleVitCfg(
|
||||
depths=(2, 6, 36, 4),
|
||||
embed_dim=144,
|
||||
num_heads=2,
|
||||
expand_attn=False,
|
||||
),
|
||||
|
||||
mvitv2_small_cls=MultiScaleVitCfg(
|
||||
depths=(1, 2, 11, 2),
|
||||
use_cls_token=True,
|
||||
),
|
||||
mvitv2_base_cls=MultiScaleVitCfg(
|
||||
depths=(2, 3, 16, 3),
|
||||
use_cls_token=True,
|
||||
),
|
||||
mvitv2_large_cls=MultiScaleVitCfg(
|
||||
depths=(2, 6, 36, 4),
|
||||
embed_dim=144,
|
||||
num_heads=2,
|
||||
use_cls_token=True,
|
||||
expand_attn=True,
|
||||
),
|
||||
mvitv2_huge_cls=MultiScaleVitCfg(
|
||||
depths=(4, 8, 60, 8),
|
||||
embed_dim=192,
|
||||
num_heads=3,
|
||||
use_cls_token=True,
|
||||
expand_attn=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
|
||||
return build_model_with_cfg(
|
||||
MultiScaleVit, variant, pretrained,
|
||||
MultiScaleVit,
|
||||
variant,
|
||||
pretrained,
|
||||
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
feature_cfg=dict(flatten_sequential=True),
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
||||
'fixed_input_size': True,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'mvitv2_tiny.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth'),
|
||||
'mvitv2_small.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth'),
|
||||
'mvitv2_base.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth'),
|
||||
'mvitv2_large.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth'),
|
||||
|
||||
'mvitv2_small_cls': _cfg(url=''),
|
||||
'mvitv2_base_cls.fb_inw21k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
|
||||
num_classes=19168),
|
||||
'mvitv2_large_cls.fb_inw21k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
|
||||
num_classes=19168),
|
||||
'mvitv2_huge_cls.fb_inw21k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
|
||||
num_classes=19168),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -992,21 +1010,21 @@ def mvitv2_large(pretrained=False, **kwargs):
|
|||
return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
# @register_model
|
||||
# def mvitv2_base_in21k(pretrained=False, **kwargs):
|
||||
# return _create_mvitv2('mvitv2_base_in21k', pretrained=pretrained, **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def mvitv2_large_in21k(pretrained=False, **kwargs):
|
||||
# return _create_mvitv2('mvitv2_large_in21k', pretrained=pretrained, **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
|
||||
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def mvitv2_small_cls(pretrained=False, **kwargs):
|
||||
return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def mvitv2_base_cls(pretrained=False, **kwargs):
|
||||
return _create_mvitv2('mvitv2_base_cls', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def mvitv2_large_cls(pretrained=False, **kwargs):
|
||||
return _create_mvitv2('mvitv2_large_cls', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def mvitv2_huge_cls(pretrained=False, **kwargs):
|
||||
return _create_mvitv2('mvitv2_huge_cls', pretrained=pretrained, **kwargs)
|
||||
|
|
|
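With the registry rewrite above, the separate `*_in21k` entrypoints are retired (left commented out) and those checkpoints are published under the cls-token variants with an `.fb_inw21k` tag; the URL that previously backed `mvitv2_base_in21k` now backs `mvitv2_base_cls.fb_inw21k`. A short sketch, assuming the hosted weights remain reachable:

import timm

model = timm.create_model('mvitv2_base_cls.fb_inw21k', pretrained=True)
print(model.num_classes)   # expected 19168, per the num_classes registered in the cfg above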
@ -104,15 +104,25 @@ class TransformerLayer(nn.Module):
|
|||
- Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
|
||||
- Uses modified Attention layer that handles the "block" dimension
|
||||
"""
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.norm1(x)
|
||||
|
@ -176,9 +186,23 @@ class NestLevel(nn.Module):
|
|||
""" Single hierarchical level of a Nested Transformer
|
||||
"""
|
||||
def __init__(
|
||||
self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None,
|
||||
mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[],
|
||||
norm_layer=None, act_layer=None, pad_type=''):
|
||||
self,
|
||||
num_blocks,
|
||||
block_size,
|
||||
seq_length,
|
||||
num_heads,
|
||||
depth,
|
||||
embed_dim,
|
||||
prev_embed_dim=None,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=[],
|
||||
norm_layer=None,
|
||||
act_layer=None,
|
||||
pad_type='',
|
||||
):
|
||||
super().__init__()
|
||||
self.block_size = block_size
|
||||
self.grad_checkpointing = False
|
||||
|
@ -191,13 +215,20 @@ class NestLevel(nn.Module):
|
|||
self.pool = nn.Identity()
|
||||
|
||||
# Transformer encoder
|
||||
if len(drop_path_rates):
|
||||
assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
|
||||
if len(drop_path):
|
||||
assert len(drop_path) == depth, 'Must provide as many drop path rates as there are transformer layers'
|
||||
self.transformer_encoder = nn.Sequential(*[
|
||||
TransformerLayer(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i],
|
||||
norm_layer=norm_layer, act_layer=act_layer)
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -225,10 +256,26 @@ class Nest(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
|
||||
num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
|
||||
pad_type='', weight_init='', global_pool='avg'
|
||||
self,
|
||||
img_size=224,
|
||||
in_chans=3,
|
||||
patch_size=4,
|
||||
num_levels=3,
|
||||
embed_dims=(128, 256, 512),
|
||||
num_heads=(4, 8, 16),
|
||||
depths=(2, 2, 20),
|
||||
num_classes=1000,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.5,
|
||||
norm_layer=None,
|
||||
act_layer=None,
|
||||
pad_type='',
|
||||
weight_init='',
|
||||
global_pool='avg',
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -292,7 +339,12 @@ class Nest(nn.Module):
|
|||
|
||||
# Patch embedding
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dims[0],
|
||||
flatten=False,
|
||||
)
|
||||
self.num_patches = self.patch_embed.num_patches
|
||||
self.seq_length = self.num_patches // self.num_blocks[0]
|
||||
|
||||
|
@ -304,8 +356,22 @@ class Nest(nn.Module):
|
|||
for i in range(len(self.num_blocks)):
|
||||
dim = embed_dims[i]
|
||||
levels.append(NestLevel(
|
||||
self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
|
||||
mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
|
||||
self.num_blocks[i],
|
||||
self.block_size,
|
||||
self.seq_length,
|
||||
num_heads[i],
|
||||
depths[i],
|
||||
dim,
|
||||
prev_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dp_rates[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
pad_type=pad_type,
|
||||
))
|
||||
self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
|
||||
prev_dim = dim
|
||||
curr_stride *= 2
|
||||
|
@ -315,7 +381,10 @@ class Nest(nn.Module):
|
|||
self.norm = norm_layer(embed_dims[-1])
|
||||
|
||||
# Classifier
|
||||
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
self.global_pool = global_pool
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = head
|
||||
|
||||
self.init_weights(weight_init)
|
||||
|
||||
|
@ -366,8 +435,7 @@ class Nest(nn.Module):
|
|||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
x = self.global_pool(x)
|
||||
if self.drop_rate > 0.:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
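The Nest head change above replaces the hand-rolled `F.dropout(x, p=self.drop_rate, training=self.training)` call with a registered `head_drop = nn.Dropout(drop_rate)` module. The two are functionally identical; the module form simply tracks train/eval mode itself and shows up in the module tree. A small equivalence check:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4, 8)
drop = nn.Dropout(p=0.5)

drop.train()
torch.manual_seed(0)
a = drop(x)                              # module form, as used by head_drop
torch.manual_seed(0)
b = F.dropout(x, p=0.5, training=True)   # functional form used before this change
print(torch.allclose(a, b))              # True: same op, same RNG state, same mask

drop.eval()
print(torch.equal(drop(x), x))           # True: identity in eval mode, no manual training flag needed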
@ -78,7 +78,16 @@ class SequentialTuple(nn.Sequential):
|
|||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
|
||||
self,
|
||||
base_dim,
|
||||
depth,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
pool=None,
|
||||
proj_drop=.0,
|
||||
attn_drop=.0,
|
||||
drop_path_prob=None,
|
||||
):
|
||||
super(Transformer, self).__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
embed_dim = base_dim * heads
|
||||
|
@ -89,8 +98,8 @@ class Transformer(nn.Module):
|
|||
num_heads=heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path_prob[i],
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6)
|
||||
)
|
||||
|
@ -122,8 +131,14 @@ class ConvHeadPooling(nn.Module):
|
|||
super(ConvHeadPooling, self).__init__()
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride,
|
||||
padding_mode=padding_mode, groups=in_feature)
|
||||
in_feature,
|
||||
out_feature,
|
||||
kernel_size=stride + 1,
|
||||
padding=stride // 2,
|
||||
stride=stride,
|
||||
padding_mode=padding_mode,
|
||||
groups=in_feature,
|
||||
)
|
||||
self.fc = nn.Linear(in_feature, out_feature)
|
||||
|
||||
def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
@ -150,9 +165,24 @@ class PoolingVisionTransformer(nn.Module):
|
|||
- https://arxiv.org/abs/2103.16302
|
||||
"""
|
||||
def __init__(
|
||||
self, img_size, patch_size, stride, base_dims, depth, heads,
|
||||
mlp_ratio, num_classes=1000, in_chans=3, global_pool='token',
|
||||
distilled=False, attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
|
||||
self,
|
||||
img_size,
|
||||
patch_size,
|
||||
stride,
|
||||
base_dims,
|
||||
depth,
|
||||
heads,
|
||||
mlp_ratio,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
global_pool='token',
|
||||
distilled=False,
|
||||
drop_rate=0.,
|
||||
pos_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
):
|
||||
super(PoolingVisionTransformer, self).__init__()
|
||||
assert global_pool in ('token',)
|
||||
|
||||
|
@ -173,7 +203,7 @@ class PoolingVisionTransformer(nn.Module):
|
|||
self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0]))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
|
||||
transformers = []
|
||||
# stochastic depth decay rule
|
||||
|
@ -182,16 +212,27 @@ class PoolingVisionTransformer(nn.Module):
|
|||
pool = None
|
||||
if stage < len(heads) - 1:
|
||||
pool = ConvHeadPooling(
|
||||
base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2)
|
||||
base_dims[stage] * heads[stage],
|
||||
base_dims[stage + 1] * heads[stage + 1],
|
||||
stride=2,
|
||||
)
|
||||
transformers += [Transformer(
|
||||
base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool,
|
||||
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage])
|
||||
base_dims[stage],
|
||||
depth[stage],
|
||||
heads[stage],
|
||||
mlp_ratio,
|
||||
pool=pool,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path_prob=dpr[stage],
|
||||
)
|
||||
]
|
||||
self.transformers = SequentialTuple(*transformers)
|
||||
self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
|
||||
self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
|
||||
|
||||
# Classifier head
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head_dist = None
|
||||
if distilled:
|
||||
|
@ -243,6 +284,8 @@ class PoolingVisionTransformer(nn.Module):
|
|||
if self.head_dist is not None:
|
||||
assert self.global_pool == 'token'
|
||||
x, x_dist = x[:, 0], x[:, 1]
|
||||
x = self.head_drop(x)
|
||||
x_dist = self.head_drop(x_dist)
|
||||
if not pre_logits:
|
||||
x = self.head(x)
|
||||
x_dist = self.head_dist(x_dist)
|
||||
|
@ -255,6 +298,7 @@ class PoolingVisionTransformer(nn.Module):
|
|||
else:
|
||||
if self.global_pool == 'token':
|
||||
x = x[:, 0]
|
||||
x = self.head_drop(x)
|
||||
if not pre_logits:
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
@ -284,71 +328,70 @@ def _create_pit(variant, pretrained=False, **kwargs):
|
|||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
|
||||
model = build_model_with_cfg(
|
||||
PoolingVisionTransformer, variant, pretrained,
|
||||
PoolingVisionTransformer,
|
||||
variant,
|
||||
pretrained,
|
||||
pretrained_filter_fn=checkpoint_filter_fn,
|
||||
**kwargs)
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_b_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=14,
|
||||
stride=7,
|
||||
base_dims=[64, 64, 64],
|
||||
depth=[3, 6, 4],
|
||||
heads=[4, 8, 16],
|
||||
mlp_ratio=4,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_b_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_b_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_s_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=16,
|
||||
stride=8,
|
||||
base_dims=[48, 48, 48],
|
||||
depth=[2, 6, 4],
|
||||
heads=[3, 6, 12],
|
||||
mlp_ratio=4,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_s_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_s_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_xs_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=16,
|
||||
stride=8,
|
||||
base_dims=[48, 48, 48],
|
||||
depth=[2, 6, 4],
|
||||
heads=[2, 4, 8],
|
||||
mlp_ratio=4,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_xs_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_xs_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_ti_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=16,
|
||||
stride=8,
|
||||
base_dims=[32, 32, 32],
|
||||
depth=[2, 6, 4],
|
||||
heads=[2, 4, 8],
|
||||
mlp_ratio=4,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_ti_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_ti_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_b_distilled_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=14,
|
||||
stride=7,
|
||||
base_dims=[64, 64, 64],
|
||||
|
@ -356,14 +399,13 @@ def pit_b_distilled_224(pretrained, **kwargs):
|
|||
heads=[4, 8, 16],
|
||||
mlp_ratio=4,
|
||||
distilled=True,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_b_distilled_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_s_distilled_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=16,
|
||||
stride=8,
|
||||
base_dims=[48, 48, 48],
|
||||
|
@ -371,14 +413,13 @@ def pit_s_distilled_224(pretrained, **kwargs):
|
|||
heads=[3, 6, 12],
|
||||
mlp_ratio=4,
|
||||
distilled=True,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_s_distilled_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_xs_distilled_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=16,
|
||||
stride=8,
|
||||
base_dims=[48, 48, 48],
|
||||
|
@ -386,14 +427,13 @@ def pit_xs_distilled_224(pretrained, **kwargs):
|
|||
heads=[2, 4, 8],
|
||||
mlp_ratio=4,
|
||||
distilled=True,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_xs_distilled_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def pit_ti_distilled_224(pretrained, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=16,
|
||||
stride=8,
|
||||
base_dims=[32, 32, 32],
|
||||
|
@ -401,6 +441,5 @@ def pit_ti_distilled_224(pretrained, **kwargs):
|
|||
heads=[2, 4, 8],
|
||||
mlp_ratio=4,
|
||||
distilled=True,
|
||||
**kwargs
|
||||
)
|
||||
return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs)
|
||||
return _create_pit('pit_ti_distilled_224', pretrained, **dict(model_args, **kwargs))
|
||||
|
|
|
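All PiT entrypoints above now build a plain `model_args` dict and merge caller kwargs with `dict(model_args, **kwargs)`, so user-supplied values override the per-variant defaults; the old `dict(..., mlp_ratio=4, **kwargs)` form instead raises a TypeError when a caller passes a key that is also a default. A quick illustration of the merge semantics:

model_args = dict(patch_size=16, stride=8, mlp_ratio=4)
kwargs = dict(mlp_ratio=2, num_classes=10)

merged = dict(model_args, **kwargs)
print(merged)   # {'patch_size': 16, 'stride': 8, 'mlp_ratio': 2, 'num_classes': 10} -> caller wins

# The previous pattern fails on the duplicate key instead of overriding it:
try:
    dict(patch_size=16, stride=8, mlp_ratio=4, **kwargs)
except TypeError as e:
    print(e)    # multiple values for keyword argument 'mlp_ratio'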
@ -103,9 +103,16 @@ class PoolFormerBlock(nn.Module):
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim, pool_size=3, mlp_ratio=4.,
|
||||
act_layer=nn.GELU, norm_layer=GroupNorm1,
|
||||
drop=0., drop_path=0., layer_scale_init_value=1e-5):
|
||||
self,
|
||||
dim,
|
||||
pool_size=3,
|
||||
mlp_ratio=4.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=GroupNorm1,
|
||||
drop=0.,
|
||||
drop_path=0.,
|
||||
layer_scale_init_value=1e-5,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
@ -116,7 +123,7 @@ class PoolFormerBlock(nn.Module):
|
|||
self.mlp = ConvMlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
if layer_scale_init_value:
|
||||
if layer_scale_init_value is not None:
|
||||
self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
|
||||
self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
|
||||
else:
|
||||
|
@ -134,10 +141,15 @@ class PoolFormerBlock(nn.Module):
|
|||
|
||||
|
||||
def basic_blocks(
|
||||
dim, index, layers,
|
||||
pool_size=3, mlp_ratio=4.,
|
||||
act_layer=nn.GELU, norm_layer=GroupNorm1,
|
||||
drop_rate=.0, drop_path_rate=0.,
|
||||
dim,
|
||||
index,
|
||||
layers,
|
||||
pool_size=3,
|
||||
mlp_ratio=4.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=GroupNorm1,
|
||||
drop_rate=.0,
|
||||
drop_path_rate=0.,
|
||||
layer_scale_init_value=1e-5,
|
||||
):
|
||||
""" generate PoolFormer blocks for a stage """
|
||||
|
@ -145,9 +157,13 @@ def basic_blocks(
|
|||
for block_idx in range(layers[index]):
|
||||
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
|
||||
blocks.append(PoolFormerBlock(
|
||||
dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer, norm_layer=norm_layer,
|
||||
drop=drop_rate, drop_path=block_dpr,
|
||||
dim,
|
||||
pool_size=pool_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop=drop_rate,
|
||||
drop_path=block_dpr,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
))
|
||||
blocks = nn.Sequential(*blocks)
|
||||
|
@ -176,9 +192,12 @@ class PoolFormer(nn.Module):
|
|||
down_patch_size=3,
|
||||
down_stride=2,
|
||||
down_pad=1,
|
||||
drop_rate=0., drop_path_rate=0.,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
layer_scale_init_value=1e-5,
|
||||
**kwargs):
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
|
@ -187,28 +206,42 @@ class PoolFormer(nn.Module):
|
|||
self.grad_checkpointing = False
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=in_patch_size, stride=in_stride, padding=in_pad,
|
||||
in_chs=in_chans, embed_dim=embed_dims[0])
|
||||
patch_size=in_patch_size,
|
||||
stride=in_stride,
|
||||
padding=in_pad,
|
||||
in_chs=in_chans,
|
||||
embed_dim=embed_dims[0],
|
||||
)
|
||||
|
||||
# set the main block in network
|
||||
network = []
|
||||
for i in range(len(layers)):
|
||||
network.append(basic_blocks(
|
||||
embed_dims[i], i, layers,
|
||||
pool_size=pool_size, mlp_ratio=mlp_ratios[i],
|
||||
act_layer=act_layer, norm_layer=norm_layer,
|
||||
drop_rate=drop_rate, drop_path_rate=drop_path_rate,
|
||||
layer_scale_init_value=layer_scale_init_value)
|
||||
)
|
||||
embed_dims[i],
|
||||
i,
|
||||
layers,
|
||||
pool_size=pool_size,
|
||||
mlp_ratio=mlp_ratios[i],
|
||||
act_layer=act_layer,
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=proj_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
layer_scale_init_value=layer_scale_init_value
|
||||
))
|
||||
if i < len(layers) - 1 and (downsamples[i] or embed_dims[i] != embed_dims[i + 1]):
|
||||
# downsampling between stages
|
||||
network.append(PatchEmbed(
|
||||
in_chs=embed_dims[i], embed_dim=embed_dims[i + 1],
|
||||
patch_size=down_patch_size, stride=down_stride, padding=down_pad)
|
||||
)
|
||||
in_chs=embed_dims[i],
|
||||
embed_dim=embed_dims[i + 1],
|
||||
patch_size=down_patch_size,
|
||||
stride=down_stride,
|
||||
padding=down_pad,
|
||||
))
|
||||
|
||||
self.network = nn.Sequential(*network)
|
||||
self.norm = norm_layer(self.num_features)
|
||||
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
@ -254,6 +287,7 @@ class PoolFormer(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean([-2, -1])
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
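The PoolFormerBlock tweak above switches the layer-scale guard from truthiness to `is not None`, so an explicit `layer_scale_init_value=0.` still creates (zero-initialised) scale parameters and only `None` disables them. The difference in one loop:

for layer_scale_init_value in (1e-5, 0., None):
    old_branch = bool(layer_scale_init_value)        # old check: 0. silently disabled layer scale
    new_branch = layer_scale_init_value is not None  # new check: only None disables it
    print(layer_scale_init_value, old_branch, new_branch)
# 1e-05 True True
# 0.0 False True
# None False False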
@ -54,8 +54,14 @@ default_cfgs = {
|
|||
|
||||
class MlpWithDepthwiseConv(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
|
||||
drop=0., extra_relu=False):
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.,
|
||||
extra_relu=False,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
@ -154,8 +160,19 @@ class Attention(nn.Module):
|
|||
class Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., sr_ratio=1, linear_attn=False, qkv_bias=False,
|
||||
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
sr_ratio=1,
|
||||
linear_attn=False,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
|
@ -165,7 +182,7 @@ class Block(nn.Module):
|
|||
linear_attn=linear_attn,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
|
@ -173,8 +190,8 @@ class Block(nn.Module):
|
|||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
extra_relu=linear_attn
|
||||
drop=proj_drop,
|
||||
extra_relu=linear_attn,
|
||||
)
|
||||
|
||||
def forward(self, x, feat_size: List[int]):
|
||||
|
@ -193,8 +210,8 @@ class OverlapPatchEmbed(nn.Module):
|
|||
assert max(patch_size) > stride, "Set larger patch_size than stride"
|
||||
self.patch_size = patch_size
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
||||
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
||||
in_chans, embed_dim, patch_size,
|
||||
stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -217,7 +234,7 @@ class PyramidVisionTransformerStage(nn.Module):
|
|||
linear_attn: bool = False,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: Union[List[float], float] = 0.0,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
|
@ -242,7 +259,7 @@ class PyramidVisionTransformerStage(nn.Module):
|
|||
linear_attn=linear_attn,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
|
@ -278,6 +295,7 @@ class PyramidVisionTransformerV2(nn.Module):
|
|||
qkv_bias=True,
|
||||
linear=False,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
|
@ -314,16 +332,17 @@ class PyramidVisionTransformerV2(nn.Module):
|
|||
mlp_ratio=mlp_ratios[i],
|
||||
linear_attn=linear,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer
|
||||
norm_layer=norm_layer,
|
||||
))
|
||||
prev_dim = embed_dims[i]
|
||||
cur += depths[i]
|
||||
|
||||
# classification head
|
||||
self.num_features = embed_dims[-1]
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
@ -379,6 +398,7 @@ class PyramidVisionTransformerV2(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool:
|
||||
x = x.mean(dim=(-1, -2))
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -181,7 +181,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
shift_size: int = 0,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: Callable = nn.GELU,
|
||||
|
@ -197,7 +197,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
shift_size: Shift size for SW-MSA.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
proj_drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
|
@ -223,7 +223,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
window_size=to_2tuple(self.window_size),
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -232,7 +232,7 @@ class SwinTransformerBlock(nn.Module):
|
|||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
if self.shift_size > 0:
|
||||
|
@ -346,7 +346,7 @@ class SwinTransformerStage(nn.Module):
|
|||
window_size: _int_or_tuple_2_t = 7,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: Union[List[float], float] = 0.,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
|
@ -362,7 +362,7 @@ class SwinTransformerStage(nn.Module):
|
|||
window_size: Local window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
proj_drop: Projection dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
|
@ -396,7 +396,7 @@ class SwinTransformerStage(nn.Module):
|
|||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
|
@ -435,6 +435,7 @@ class SwinTransformer(nn.Module):
|
|||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop_rate: float = 0.,
|
||||
proj_drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Union[str, Callable] = nn.LayerNorm,
|
||||
|
@ -508,7 +509,7 @@ class SwinTransformer(nn.Module):
|
|||
window_size=window_size[i],
|
||||
mlp_ratio=mlp_ratio[i],
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
|
|
|
@ -203,7 +203,7 @@ class SwinTransformerV2Block(nn.Module):
|
|||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
|
@ -219,7 +219,7 @@ class SwinTransformerV2Block(nn.Module):
|
|||
shift_size: Shift size for SW-MSA.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate.
|
||||
proj_drop: Dropout rate.
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
act_layer: Activation layer.
|
||||
|
@ -242,7 +242,7 @@ class SwinTransformerV2Block(nn.Module):
|
|||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
pretrained_window_size=to_2tuple(pretrained_window_size),
|
||||
)
|
||||
self.norm1 = norm_layer(dim)
|
||||
|
@ -252,7 +252,7 @@ class SwinTransformerV2Block(nn.Module):
|
|||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -367,7 +367,7 @@ class SwinTransformerV2Stage(nn.Module):
|
|||
downsample=False,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
|
@ -384,7 +384,7 @@ class SwinTransformerV2Stage(nn.Module):
|
|||
downsample: Use downsample layer at start of the block.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop: Dropout rate
|
||||
proj_drop: Projection dropout rate
|
||||
attn_drop: Attention dropout rate.
|
||||
drop_path: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
|
@ -416,7 +416,7 @@ class SwinTransformerV2Stage(nn.Module):
|
|||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
|
@ -463,6 +463,7 @@ class SwinTransformerV2(nn.Module):
|
|||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
drop_rate: float = 0.,
|
||||
proj_drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.1,
|
||||
norm_layer: Callable = nn.LayerNorm,
|
||||
|
@ -481,7 +482,8 @@ class SwinTransformerV2(nn.Module):
|
|||
window_size: Window size.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: If True, add a learnable bias to query, key, value.
|
||||
drop_rate: Dropout rate.
|
||||
drop_rate: Head dropout rate.
|
||||
proj_drop_rate: Projection dropout rate.
|
||||
attn_drop_rate: Attention dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
norm_layer: Normalization layer.
|
||||
|
@ -531,7 +533,8 @@ class SwinTransformerV2(nn.Module):
|
|||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
pretrained_window_size=pretrained_window_sizes[i],
|
||||
|
|
|
@ -221,7 +221,7 @@ class SwinTransformerV2CrBlock(nn.Module):
|
|||
window_size (Tuple[int, int]): Window size to be utilized
|
||||
shift_size (int): Shifting size to be used
|
||||
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
|
||||
drop (float): Dropout in input mapping
|
||||
proj_drop (float): Dropout in input mapping
|
||||
drop_attn (float): Dropout rate of attention map
|
||||
drop_path (float): Dropout in main path
|
||||
extra_norm (bool): Insert extra norm on 'main' branch if True
|
||||
|
@ -238,7 +238,7 @@ class SwinTransformerV2CrBlock(nn.Module):
|
|||
shift_size: Tuple[int, int] = (0, 0),
|
||||
mlp_ratio: float = 4.0,
|
||||
init_values: Optional[float] = 0,
|
||||
drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
drop_attn: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
extra_norm: bool = False,
|
||||
|
@ -259,7 +259,7 @@ class SwinTransformerV2CrBlock(nn.Module):
|
|||
num_heads=num_heads,
|
||||
window_size=self.window_size,
|
||||
drop_attn=drop_attn,
|
||||
drop_proj=drop,
|
||||
drop_proj=proj_drop,
|
||||
sequential_attn=sequential_attn,
|
||||
)
|
||||
self.norm1 = norm_layer(dim)
|
||||
|
@ -269,7 +269,7 @@ class SwinTransformerV2CrBlock(nn.Module):
|
|||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
out_features=dim,
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
|
@ -445,7 +445,7 @@ class SwinTransformerV2CrStage(nn.Module):
|
|||
num_heads (int): Number of attention heads to be utilized
|
||||
window_size (int): Window size to be utilized
|
||||
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
|
||||
drop (float): Dropout in input mapping
|
||||
proj_drop (float): Dropout in input mapping
|
||||
drop_attn (float): Dropout rate of attention map
|
||||
drop_path (float): Dropout in main path
|
||||
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
|
||||
|
@ -464,7 +464,7 @@ class SwinTransformerV2CrStage(nn.Module):
|
|||
window_size: Tuple[int, int],
|
||||
mlp_ratio: float = 4.0,
|
||||
init_values: Optional[float] = 0.0,
|
||||
drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
drop_attn: float = 0.0,
|
||||
drop_path: Union[List[float], float] = 0.0,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
|
@ -498,7 +498,7 @@ class SwinTransformerV2CrStage(nn.Module):
|
|||
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
|
||||
mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
drop_attn=drop_attn,
|
||||
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
|
||||
extra_norm=_extra_norm(index),
|
||||
|
@ -546,23 +546,24 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
https://arxiv.org/pdf/2111.09883
|
||||
|
||||
Args:
|
||||
img_size (Tuple[int, int]): Input resolution.
|
||||
window_size (Optional[int]): Window size. If None, img_size // window_div. Default: None
|
||||
img_window_ratio (int): Window size to image size ratio. Default: 32
|
||||
patch_size (int | tuple(int)): Patch size. Default: 4
|
||||
in_chans (int): Number of input channels.
|
||||
depths (int): Depth of the stage (number of layers).
|
||||
num_heads (int): Number of attention heads to be utilized.
|
||||
embed_dim (int): Patch embedding dimension. Default: 96
|
||||
num_classes (int): Number of output classes. Default: 1000
|
||||
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4
|
||||
drop_rate (float): Dropout rate. Default: 0.0
|
||||
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
|
||||
drop_path_rate (float): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
|
||||
extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
|
||||
extra_norm_stage (bool): End each stage with an extra norm layer in main branch
|
||||
sequential_attn (bool): If true sequential self-attention is performed. Default: False
|
||||
img_size: Input resolution.
|
||||
window_size: Window size. If None, img_size // window_div
|
||||
img_window_ratio: Window size to image size ratio.
|
||||
patch_size: Patch size.
|
||||
in_chans: Number of input channels.
|
||||
depths: Depth of the stage (number of layers).
|
||||
num_heads: Number of attention heads to be utilized.
|
||||
embed_dim: Patch embedding dimension.
|
||||
num_classes: Number of output classes.
|
||||
mlp_ratio: Ratio of the hidden dimension in the FFN to the input channels.
|
||||
drop_rate: Dropout rate.
|
||||
proj_drop_rate: Projection dropout rate.
|
||||
attn_drop_rate: Dropout rate of attention map.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
norm_layer: Type of normalization layer to be utilized.
|
||||
extra_norm_period: Insert extra norm layer on main branch every N (period) blocks in stage
|
||||
extra_norm_stage: End each stage with an extra norm layer in main branch
|
||||
sequential_attn: If true sequential self-attention is performed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -579,6 +580,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
mlp_ratio: float = 4.0,
|
||||
init_values: Optional[float] = 0.,
|
||||
drop_rate: float = 0.0,
|
||||
proj_drop_rate: float = 0.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
|
@ -627,7 +629,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
drop_attn=attn_drop_rate,
|
||||
drop_path=dpr[stage_idx],
|
||||
extra_norm_period=extra_norm_period,
|
||||
|
|
|
@ -80,31 +80,64 @@ class Block(nn.Module):
|
|||
""" TNT Block
|
||||
"""
|
||||
def __init__(
|
||||
self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
|
||||
qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
dim_out,
|
||||
num_pixel,
|
||||
num_heads_in=4,
|
||||
num_heads_out=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
# Inner transformer
|
||||
self.norm_in = norm_layer(in_dim)
|
||||
self.norm_in = norm_layer(dim)
|
||||
self.attn_in = Attention(
|
||||
in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
dim,
|
||||
num_heads=num_heads_in,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
|
||||
self.norm_mlp_in = norm_layer(in_dim)
|
||||
self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4),
|
||||
out_features=in_dim, act_layer=act_layer, drop=drop)
|
||||
self.norm_mlp_in = norm_layer(dim)
|
||||
self.mlp_in = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * 4),
|
||||
out_features=dim,
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
self.norm1_proj = norm_layer(in_dim)
|
||||
self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True)
|
||||
self.norm1_proj = norm_layer(dim)
|
||||
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
|
||||
|
||||
# Outer transformer
|
||||
self.norm_out = norm_layer(dim)
|
||||
self.norm_out = norm_layer(dim_out)
|
||||
self.attn_out = Attention(
|
||||
dim, dim, num_heads=num_heads, qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop, proj_drop=drop)
|
||||
dim_out,
|
||||
dim_out,
|
||||
num_heads=num_heads_out,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm_mlp = norm_layer(dim)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
|
||||
out_features=dim, act_layer=act_layer, drop=drop)
|
||||
self.norm_mlp = norm_layer(dim_out)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim_out,
|
||||
hidden_features=int(dim_out * mlp_ratio),
|
||||
out_features=dim_out,
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
)
|
||||
|
||||
def forward(self, pixel_embed, patch_embed):
|
||||
# inner
|
||||
|
@ -157,9 +190,27 @@ class TNT(nn.Module):
|
|||
""" Transformer in Transformer - https://arxiv.org/abs/2103.00112
|
||||
"""
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
||||
embed_dim=768, in_dim=48, depth=12, num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4):
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='token',
|
||||
embed_dim=768,
|
||||
inner_dim=48,
|
||||
depth=12,
|
||||
num_heads_inner=4,
|
||||
num_heads_outer=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop_rate=0.,
|
||||
pos_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
first_stride=4,
|
||||
):
|
||||
super().__init__()
|
||||
assert global_pool in ('', 'token', 'avg')
|
||||
self.num_classes = num_classes
|
||||
|
@ -168,31 +219,46 @@ class TNT(nn.Module):
|
|||
self.grad_checkpointing = False
|
||||
|
||||
self.pixel_embed = PixelEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
in_dim=inner_dim,
|
||||
stride=first_stride,
|
||||
)
|
||||
num_patches = self.pixel_embed.num_patches
|
||||
self.num_patches = num_patches
|
||||
new_patch_size = self.pixel_embed.new_patch_size
|
||||
num_pixel = new_patch_size[0] * new_patch_size[1]
|
||||
|
||||
self.norm1_proj = norm_layer(num_pixel * in_dim)
|
||||
self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
|
||||
self.norm1_proj = norm_layer(num_pixel * inner_dim)
|
||||
self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
|
||||
self.norm2_proj = norm_layer(embed_dim)
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||
self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1]))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1]))
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
blocks = []
|
||||
for i in range(depth):
|
||||
blocks.append(Block(
|
||||
dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head,
|
||||
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i], norm_layer=norm_layer))
|
||||
dim=inner_dim,
|
||||
dim_out=embed_dim,
|
||||
num_pixel=num_pixel,
|
||||
num_heads_in=num_heads_inner,
|
||||
num_heads_out=num_heads_outer,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
))
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
@ -260,6 +326,7 @@ class TNT(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool:
|
||||
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -290,16 +357,16 @@ def _create_tnt(variant, pretrained=False, **kwargs):
|
|||
@register_model
|
||||
def tnt_s_patch16_224(pretrained=False, **kwargs):
|
||||
model_cfg = dict(
|
||||
patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
|
||||
qkv_bias=False, **kwargs)
|
||||
model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
|
||||
patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
|
||||
qkv_bias=False)
|
||||
model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tnt_b_patch16_224(pretrained=False, **kwargs):
|
||||
model_cfg = dict(
|
||||
patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
|
||||
qkv_bias=False, **kwargs)
|
||||
model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg)
|
||||
patch_size=16, embed_dim=640, inner_dim=40, depth=12, num_heads_outer=10,
|
||||
qkv_bias=False)
|
||||
model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
|
||||
return model
|
||||
|
|
|
@ -200,20 +200,20 @@ class GlobalSubSampleAttn(nn.Module):
|
|||
class Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
|
||||
self, dim, num_heads, mlp_ratio=4., proj_drop=0., attn_drop=0., drop_path=0.,
|
||||
act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if ws is None:
|
||||
self.attn = Attention(dim, num_heads, False, None, attn_drop, drop)
|
||||
self.attn = Attention(dim, num_heads, False, None, attn_drop, proj_drop)
|
||||
elif ws == 1:
|
||||
self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio)
|
||||
self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, proj_drop, sr_ratio)
|
||||
else:
|
||||
self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws)
|
||||
self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, proj_drop, ws)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
|
||||
|
||||
def forward(self, x, size: Size_):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), size))
|
||||
|
@ -275,10 +275,26 @@ class Twins(nn.Module):
|
|||
Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git
|
||||
"""
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
embed_dims=(64, 128, 256, 512), num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), depths=(3, 4, 6, 3),
|
||||
sr_ratios=(8, 4, 2, 1), wss=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), block_cls=Block):
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=4,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='avg',
|
||||
embed_dims=(64, 128, 256, 512),
|
||||
num_heads=(1, 2, 4, 8),
|
||||
mlp_ratios=(4, 4, 4, 4),
|
||||
depths=(3, 4, 6, 3),
|
||||
sr_ratios=(8, 4, 2, 1),
|
||||
wss=None,
|
||||
drop_rate=0.,
|
||||
pos_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
block_cls=Block,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
|
@ -293,7 +309,7 @@ class Twins(nn.Module):
|
|||
self.pos_drops = nn.ModuleList()
|
||||
for i in range(len(depths)):
|
||||
self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i]))
|
||||
self.pos_drops.append(nn.Dropout(p=drop_rate))
|
||||
self.pos_drops.append(nn.Dropout(p=pos_drop_rate))
|
||||
prev_chs = embed_dims[i]
|
||||
img_size = tuple(t // patch_size for t in img_size)
|
||||
patch_size = 2
|
||||
|
@ -303,9 +319,16 @@ class Twins(nn.Module):
|
|||
cur = 0
|
||||
for k in range(len(depths)):
|
||||
_block = nn.ModuleList([block_cls(
|
||||
dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate,
|
||||
attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k],
|
||||
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])])
|
||||
dim=embed_dims[k],
|
||||
num_heads=num_heads[k],
|
||||
mlp_ratio=mlp_ratios[k],
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[cur + i],
|
||||
norm_layer=norm_layer,
|
||||
sr_ratio=sr_ratios[k],
|
||||
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])],
|
||||
)
|
||||
self.blocks.append(_block)
|
||||
cur += depths[k]
|
||||
|
||||
|
@ -314,6 +337,7 @@ class Twins(nn.Module):
|
|||
self.norm = norm_layer(self.num_features)
|
||||
|
||||
# classification head
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
# init weights
|
||||
|
@ -386,6 +410,7 @@ class Twins(nn.Module):
|
|||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool == 'avg':
|
||||
x = x.mean(dim=1)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -404,47 +429,47 @@ def _create_twins(variant, pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def twins_pcpvt_small(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
||||
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
|
||||
return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs)
|
||||
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1])
|
||||
return _create_twins('twins_pcpvt_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def twins_pcpvt_base(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
||||
depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
|
||||
return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs)
|
||||
depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1])
|
||||
return _create_twins('twins_pcpvt_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def twins_pcpvt_large(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
|
||||
depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
|
||||
return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs)
|
||||
depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1])
|
||||
return _create_twins('twins_pcpvt_large', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def twins_svt_small(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4],
|
||||
depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
|
||||
return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs)
|
||||
depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1])
|
||||
return _create_twins('twins_svt_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def twins_svt_base(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4],
|
||||
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
|
||||
return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs)
|
||||
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1])
|
||||
return _create_twins('twins_svt_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
||||
|
||||
@register_model
|
||||
def twins_svt_large(pretrained=False, **kwargs):
|
||||
model_kwargs = dict(
|
||||
model_args = dict(
|
||||
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
|
||||
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
|
||||
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
|
||||
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1])
|
||||
return _create_twins('twins_svt_large', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
|
|
|
@ -40,8 +40,15 @@ default_cfgs = dict(
|
|||
|
||||
class SpatialMlp(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, hidden_features=None, out_features=None,
|
||||
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.,
|
||||
group=8,
|
||||
spatial_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
@ -83,6 +90,8 @@ class SpatialMlp(nn.Module):
|
|||
|
||||
|
||||
class Attention(nn.Module):
|
||||
fast_attn: torch.jit.Final[bool]
|
||||
|
||||
def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
@ -90,6 +99,8 @@ class Attention(nn.Module):
|
|||
head_dim = round(dim // num_heads * head_dim_ratio)
|
||||
self.head_dim = head_dim
|
||||
self.scale = head_dim ** -0.5
|
||||
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
|
||||
|
||||
self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
|
||||
|
@ -100,10 +111,16 @@ class Attention(nn.Module):
|
|||
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
|
||||
q, k, v = x.unbind(0)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
if self.fast_attn:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v
|
||||
|
||||
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
|
||||
x = self.proj(x)
|
||||
|
@ -113,9 +130,20 @@ class Attention(nn.Module):
|
|||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
|
||||
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
|
||||
group=8, attn_disabled=False, spatial_conv=False):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
head_dim_ratio=1.,
|
||||
mlp_ratio=4.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=LayerNorm2d,
|
||||
group=8,
|
||||
attn_disabled=False,
|
||||
spatial_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.spatial_conv = spatial_conv
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -125,12 +153,22 @@ class Block(nn.Module):
|
|||
else:
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
head_dim_ratio=head_dim_ratio,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = SpatialMlp(
|
||||
in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop,
|
||||
group=group, spatial_conv=spatial_conv) # new setting
|
||||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=proj_drop,
|
||||
group=group,
|
||||
spatial_conv=spatial_conv,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.attn is not None:
|
||||
|
@ -141,10 +179,31 @@ class Block(nn.Module):
|
|||
|
||||
class Visformer(nn.Module):
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
|
||||
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
|
||||
vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
init_channels=32,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
num_heads=6,
|
||||
mlp_ratio=4.,
|
||||
drop_rate=0.,
|
||||
pos_drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=LayerNorm2d,
|
||||
attn_stage='111',
|
||||
use_pos_embed=True,
|
||||
spatial_conv='111',
|
||||
vit_stem=False,
|
||||
group=8,
|
||||
global_pool='avg',
|
||||
conv_init=False,
|
||||
embed_norm=None,
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
self.num_classes = num_classes
|
||||
|
@ -159,7 +218,7 @@ class Visformer(nn.Module):
|
|||
else:
|
||||
self.stage_num1 = self.stage_num3 = depth // 3
|
||||
self.stage_num2 = depth - self.stage_num1 - self.stage_num3
|
||||
self.pos_embed = pos_embed
|
||||
self.use_pos_embed = use_pos_embed
|
||||
self.grad_checkpointing = False
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
||||
|
@ -167,15 +226,25 @@ class Visformer(nn.Module):
|
|||
if self.vit_stem:
|
||||
self.stem = None
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
||||
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=embed_norm,
|
||||
flatten=False,
|
||||
)
|
||||
img_size = [x // patch_size for x in img_size]
|
||||
else:
|
||||
if self.init_channels is None:
|
||||
self.stem = None
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
|
||||
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size // 2,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim // 2,
|
||||
norm_layer=embed_norm,
|
||||
flatten=False,
|
||||
)
|
||||
img_size = [x // (patch_size // 2) for x in img_size]
|
||||
else:
|
||||
self.stem = nn.Sequential(
|
||||
|
@ -185,21 +254,37 @@ class Visformer(nn.Module):
|
|||
)
|
||||
img_size = [x // 2 for x in img_size]
|
||||
self.patch_embed1 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
|
||||
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size // 4,
|
||||
in_chans=self.init_channels,
|
||||
embed_dim=embed_dim // 2,
|
||||
norm_layer=embed_norm,
|
||||
flatten=False,
|
||||
)
|
||||
img_size = [x // (patch_size // 4) for x in img_size]
|
||||
|
||||
if self.pos_embed:
|
||||
if self.use_pos_embed:
|
||||
if self.vit_stem:
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
|
||||
else:
|
||||
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
else:
|
||||
self.pos_embed1 = None
|
||||
|
||||
self.stage1 = nn.Sequential(*[
|
||||
Block(
|
||||
dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1')
|
||||
dim=embed_dim//2,
|
||||
num_heads=num_heads,
|
||||
head_dim_ratio=0.5,
|
||||
mlp_ratio=mlp_ratio,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
group=group,
|
||||
attn_disabled=(attn_stage[0] == '0'),
|
||||
spatial_conv=(spatial_conv[0] == '1'),
|
||||
)
|
||||
for i in range(self.stage_num1)
|
||||
])
|
||||
|
@ -207,16 +292,33 @@ class Visformer(nn.Module):
|
|||
# stage2
|
||||
if not self.vit_stem:
|
||||
self.patch_embed2 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
|
||||
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size // 8,
|
||||
in_chans=embed_dim // 2,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=embed_norm,
|
||||
flatten=False,
|
||||
)
|
||||
img_size = [x // (patch_size // 8) for x in img_size]
|
||||
if self.pos_embed:
|
||||
if self.use_pos_embed:
|
||||
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
|
||||
else:
|
||||
self.pos_embed2 = None
|
||||
else:
|
||||
self.patch_embed2 = None
|
||||
self.stage2 = nn.Sequential(*[
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1')
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
head_dim_ratio=1.0,
|
||||
mlp_ratio=mlp_ratio,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
group=group,
|
||||
attn_disabled=(attn_stage[1] == '0'),
|
||||
spatial_conv=(spatial_conv[1] == '1'),
|
||||
)
|
||||
for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
|
||||
])
|
||||
|
@ -224,27 +326,48 @@ class Visformer(nn.Module):
|
|||
# stage 3
|
||||
if not self.vit_stem:
|
||||
self.patch_embed3 = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
|
||||
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
|
||||
img_size=img_size,
|
||||
patch_size=patch_size // 8,
|
||||
in_chans=embed_dim,
|
||||
embed_dim=embed_dim * 2,
|
||||
norm_layer=embed_norm,
|
||||
flatten=False,
|
||||
)
|
||||
img_size = [x // (patch_size // 8) for x in img_size]
|
||||
if self.pos_embed:
|
||||
if self.use_pos_embed:
|
||||
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
|
||||
else:
|
||||
self.pos_embed3 = None
|
||||
else:
|
||||
self.patch_embed3 = None
|
||||
self.stage3 = nn.Sequential(*[
|
||||
Block(
|
||||
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
||||
group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1')
|
||||
dim=embed_dim * 2,
|
||||
num_heads=num_heads,
|
||||
head_dim_ratio=1.0,
|
||||
mlp_ratio=mlp_ratio,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
group=group,
|
||||
attn_disabled=(attn_stage[2] == '0'),
|
||||
spatial_conv=(spatial_conv[2] == '1'),
|
||||
)
|
||||
for i in range(self.stage_num1+self.stage_num2, depth)
|
||||
])
|
||||
|
||||
# head
|
||||
self.num_features = embed_dim if self.vit_stem else embed_dim * 2
|
||||
self.norm = norm_layer(self.num_features)
|
||||
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
|
||||
# head
|
||||
global_pool, head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
|
||||
self.global_pool = global_pool
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = head
|
||||
|
||||
# weights init
|
||||
if self.pos_embed:
|
||||
if self.use_pos_embed:
|
||||
trunc_normal_(self.pos_embed1, std=0.02)
|
||||
if not self.vit_stem:
|
||||
trunc_normal_(self.pos_embed2, std=0.02)
|
||||
|
@ -293,7 +416,7 @@ class Visformer(nn.Module):
|
|||
|
||||
# stage 1
|
||||
x = self.patch_embed1(x)
|
||||
if self.pos_embed:
|
||||
if self.pos_embed1 is not None:
|
||||
x = self.pos_drop(x + self.pos_embed1)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.stage1, x)
|
||||
|
@ -301,9 +424,9 @@ class Visformer(nn.Module):
|
|||
x = self.stage1(x)
|
||||
|
||||
# stage 2
|
||||
if not self.vit_stem:
|
||||
if self.patch_embed2 is not None:
|
||||
x = self.patch_embed2(x)
|
||||
if self.pos_embed:
|
||||
if self.pos_embed2 is not None:
|
||||
x = self.pos_drop(x + self.pos_embed2)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.stage2, x)
|
||||
|
@ -311,9 +434,9 @@ class Visformer(nn.Module):
|
|||
x = self.stage2(x)
|
||||
|
||||
# stage3
|
||||
if not self.vit_stem:
|
||||
if self.patch_embed3 is not None:
|
||||
x = self.patch_embed3(x)
|
||||
if self.pos_embed:
|
||||
if self.pos_embed3 is not None:
|
||||
x = self.pos_drop(x + self.pos_embed3)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.stage3, x)
|
||||
|
@ -325,6 +448,7 @@ class Visformer(nn.Module):
|
|||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
x = self.global_pool(x)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -345,8 +469,8 @@ def visformer_tiny(pretrained=False, **kwargs):
|
|||
model_cfg = dict(
|
||||
init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
|
||||
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
|
||||
embed_norm=nn.BatchNorm2d, **kwargs)
|
||||
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
|
||||
embed_norm=nn.BatchNorm2d)
|
||||
model = _create_visformer('visformer_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
@ -355,8 +479,8 @@ def visformer_small(pretrained=False, **kwargs):
|
|||
model_cfg = dict(
|
||||
init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
|
||||
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
|
||||
embed_norm=nn.BatchNorm2d, **kwargs)
|
||||
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
|
||||
embed_norm=nn.BatchNorm2d)
|
||||
model = _create_visformer('visformer_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ import logging
|
|||
import math
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Optional, List
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -38,7 +38,7 @@ from torch.jit import Final
|
|||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
|
||||
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
|
||||
resample_abs_pos_embed, RmsNorm
|
||||
resample_abs_pos_embed, RmsNorm, PatchDropout
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
|
@ -119,7 +119,7 @@ class Block(nn.Module):
|
|||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_norm=False,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
|
@ -134,7 +134,7 @@ class Block(nn.Module):
|
|||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
|
@ -145,7 +145,7 @@ class Block(nn.Module):
|
|||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -165,7 +165,7 @@ class ResPostBlock(nn.Module):
|
|||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_norm=False,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
|
@ -181,7 +181,7 @@ class ResPostBlock(nn.Module):
|
|||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
norm_layer=norm_layer,
|
||||
)
|
||||
self.norm1 = norm_layer(dim)
|
||||
|
@ -191,7 +191,7 @@ class ResPostBlock(nn.Module):
|
|||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -224,7 +224,7 @@ class ParallelScalingBlock(nn.Module):
|
|||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_norm=False,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
init_values=None,
|
||||
drop_path=0.,
|
||||
|
@ -255,7 +255,7 @@ class ParallelScalingBlock(nn.Module):
|
|||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.attn_out_proj = nn.Linear(dim, dim)
|
||||
|
||||
self.mlp_drop = nn.Dropout(drop)
|
||||
self.mlp_drop = nn.Dropout(proj_drop)
|
||||
self.mlp_act = act_layer()
|
||||
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
|
||||
|
||||
|
@ -318,7 +318,7 @@ class ParallelThingsBlock(nn.Module):
|
|||
qkv_bias=False,
|
||||
qk_norm=False,
|
||||
init_values=None,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
|
@ -337,7 +337,7 @@ class ParallelThingsBlock(nn.Module):
|
|||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
norm_layer=norm_layer,
|
||||
)),
|
||||
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
|
||||
|
@ -349,7 +349,7 @@ class ParallelThingsBlock(nn.Module):
|
|||
dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)),
|
||||
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
|
||||
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
|
||||
|
@ -382,53 +382,58 @@ class VisionTransformer(nn.Module):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='token',
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_norm=False,
|
||||
init_values=None,
|
||||
class_token=True,
|
||||
no_embed_class=False,
|
||||
pre_norm=False,
|
||||
fc_norm=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
weight_init='',
|
||||
embed_layer=PatchEmbed,
|
||||
norm_layer=None,
|
||||
act_layer=None,
|
||||
block_fn=Block,
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
global_pool: str = 'token',
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
qk_norm: bool = False,
|
||||
init_values: Optional[float] = None,
|
||||
class_token: bool = True,
|
||||
no_embed_class: bool = False,
|
||||
pre_norm: bool = False,
|
||||
fc_norm: Optional[bool] = None,
|
||||
drop_rate: float = 0.,
|
||||
pos_drop_rate: float = 0.,
|
||||
patch_drop_rate: float = 0.,
|
||||
proj_drop_rate: float = 0.,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
weight_init: str = '',
|
||||
embed_layer: Callable = PatchEmbed,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
act_layer: Optional[Callable] = None,
|
||||
block_fn: Callable = Block,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
num_classes (int): number of classes for classification head
|
||||
global_pool (str): type of global pooling for final sequence (default: 'token')
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
init_values: (float): layer-scale init values
|
||||
class_token (bool): use class token
|
||||
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
|
||||
drop_rate (float): dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
weight_init (str): weight init scheme
|
||||
embed_layer (nn.Module): patch embedding layer
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
act_layer: (nn.Module): MLP activation layer
|
||||
img_size: Input image size.
|
||||
patch_size: Patch size.
|
||||
in_chans: Number of image input channels.
|
||||
num_classes: Mumber of classes for classification head.
|
||||
global_pool: Type of global pooling for final sequence (default: 'token').
|
||||
embed_dim: Transformer embedding dimension.
|
||||
depth: Depth of transformer.
|
||||
num_heads: Number of attention heads.
|
||||
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias: Enable bias for qkv projections if True.
|
||||
init_values: Layer-scale init values (layer-scale enabled if not None).
|
||||
class_token: Use class token.
|
||||
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
|
||||
drop_rate: Head dropout rate.
|
||||
pos_drop_rate: Position embedding dropout rate.
|
||||
attn_drop_rate: Attention dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
weight_init: Weight initialization scheme.
|
||||
embed_layer: Patch embedding layey.
|
||||
norm_layer: Normalization layer.
|
||||
act_layer: MLP activation layer.
|
||||
block_fn: Transformer block layer.
|
||||
"""
|
||||
super().__init__()
|
||||
assert global_pool in ('', 'avg', 'token')
|
||||
|
@ -456,7 +461,14 @@ class VisionTransformer(nn.Module):
|
|||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
||||
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
self.pos_drop = nn.Dropout(p=pos_drop_rate)
|
||||
if patch_drop_rate > 0:
|
||||
self.patch_drop = PatchDropout(
|
||||
patch_drop_rate,
|
||||
num_prefix_tokens=self.num_prefix_tokens,
|
||||
)
|
||||
else:
|
||||
self.patch_drop = nn.Identity()
|
||||
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
@ -468,7 +480,7 @@ class VisionTransformer(nn.Module):
|
|||
qkv_bias=qkv_bias,
|
||||
qk_norm=qk_norm,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
|
@ -479,6 +491,7 @@ class VisionTransformer(nn.Module):
|
|||
|
||||
# Classifier Head
|
||||
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if weight_init != 'skip':
|
||||
|
@ -544,6 +557,7 @@ class VisionTransformer(nn.Module):
|
|||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self._pos_embed(x)
|
||||
x = self.patch_drop(x)
|
||||
x = self.norm_pre(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
|
@ -556,6 +570,7 @@ class VisionTransformer(nn.Module):
|
|||
if self.global_pool:
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.fc_norm(x)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -110,7 +110,7 @@ class RelPosBlock(nn.Module):
|
|||
qk_norm=False,
|
||||
rel_pos_cls=None,
|
||||
init_values=None,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
|
@ -125,7 +125,7 @@ class RelPosBlock(nn.Module):
|
|||
qk_norm=qk_norm,
|
||||
rel_pos_cls=rel_pos_cls,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
|
@ -136,7 +136,7 @@ class RelPosBlock(nn.Module):
|
|||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -158,7 +158,7 @@ class ResPostRelPosBlock(nn.Module):
|
|||
qk_norm=False,
|
||||
rel_pos_cls=None,
|
||||
init_values=None,
|
||||
drop=0.,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
|
@ -174,7 +174,7 @@ class ResPostRelPosBlock(nn.Module):
|
|||
qk_norm=qk_norm,
|
||||
rel_pos_cls=rel_pos_cls,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
proj_drop=proj_drop,
|
||||
)
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -183,7 +183,7 @@ class ResPostRelPosBlock(nn.Module):
|
|||
in_features=dim,
|
||||
hidden_features=int(dim * mlp_ratio),
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
drop=proj_drop,
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
@ -232,6 +232,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||
rel_pos_dim=None,
|
||||
shared_rel_pos=False,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
weight_init='skip',
|
||||
|
@ -259,6 +260,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||
rel_pos_ty pe (str): type of relative position
|
||||
shared_rel_pos (bool): share relative pos across all blocks
|
||||
drop_rate (float): dropout rate
|
||||
proj_drop_rate (float): projection dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
weight_init (str): weight init scheme
|
||||
|
@ -313,7 +315,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||
qk_norm=qk_norm,
|
||||
rel_pos_cls=rel_pos_cls,
|
||||
init_values=init_values,
|
||||
drop=drop_rate,
|
||||
proj_drop=proj_drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
|
@ -324,6 +326,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||
|
||||
# Classifier Head
|
||||
self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity()
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if weight_init != 'skip':
|
||||
|
@ -380,6 +383,7 @@ class VisionTransformerRelPos(nn.Module):
|
|||
if self.global_pool:
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.fc_norm(x)
|
||||
x = self.head_drop(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -85,7 +85,17 @@ default_cfgs = {
|
|||
|
||||
class OutlookAttention(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, kernel_size=3, padding=1, stride=1, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = dim // num_heads
|
||||
self.num_heads = num_heads
|
||||
|
@ -133,21 +143,40 @@ class OutlookAttention(nn.Module):
|
|||
|
||||
class Outlooker(nn.Module):
|
||||
def __init__(
|
||||
self, dim, kernel_size, padding, stride=1, num_heads=1, mlp_ratio=3., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, qkv_bias=False
|
||||
self,
|
||||
dim,
|
||||
kernel_size,
|
||||
padding,
|
||||
stride=1,
|
||||
num_heads=1,
|
||||
mlp_ratio=3.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
qkv_bias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = OutlookAttention(
|
||||
dim, num_heads, kernel_size=kernel_size,
|
||||
padding=padding, stride=stride,
|
||||
qkv_bias=qkv_bias, attn_drop=attn_drop)
|
||||
dim,
|
||||
num_heads,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
)
|
||||
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
|
@ -158,7 +187,13 @@ class Outlooker(nn.Module):
|
|||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
@ -189,8 +224,16 @@ class Attention(nn.Module):
|
|||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
|
||||
attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop)
|
||||
|
@ -211,7 +254,14 @@ class Transformer(nn.Module):
|
|||
class ClassAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads=8, head_dim=None, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
head_dim=None,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
if head_dim is not None:
|
||||
|
@ -246,17 +296,38 @@ class ClassAttention(nn.Module):
|
|||
class ClassBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads, head_dim=None, mlp_ratio=4., qkv_bias=False,
|
||||
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
head_dim=None,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = ClassAttention(
|
||||
dim, num_heads=num_heads, head_dim=head_dim, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
cls_embed = x[:, :1]
|
||||
|
@ -354,18 +425,33 @@ def outlooker_blocks(
|
|||
blocks = []
|
||||
for block_idx in range(layers[index]):
|
||||
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
|
||||
blocks.append(
|
||||
block_fn(
|
||||
dim, kernel_size=kernel_size, padding=padding,
|
||||
stride=stride, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias, attn_drop=attn_drop, drop_path=block_dpr))
|
||||
blocks.append(block_fn(
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
stride=stride,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=block_dpr,
|
||||
))
|
||||
blocks = nn.Sequential(*blocks)
|
||||
return blocks
|
||||
|
||||
|
||||
def transformer_blocks(
|
||||
block_fn, index, dim, layers, num_heads, mlp_ratio=3.,
|
||||
qkv_bias=False, attn_drop=0, drop_path_rate=0., **kwargs):
|
||||
block_fn,
|
||||
index,
|
||||
dim,
|
||||
layers,
|
||||
num_heads,
|
||||
mlp_ratio=3.,
|
||||
qkv_bias=False,
|
||||
attn_drop=0,
|
||||
drop_path_rate=0.,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
generate transformer layers in stage2
|
||||
return: transformer layers
|
||||
|
@@ -373,13 +459,14 @@ def transformer_blocks(
    blocks = []
    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
        blocks.append(
            block_fn(
                dim, num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                drop_path=block_dpr))
        blocks.append(block_fn(
            dim,
            num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            drop_path=block_dpr,
        ))
    blocks = nn.Sequential(*blocks)
    return blocks

@@ -405,6 +492,7 @@ class VOLO(nn.Module):
            mlp_ratio=3.0,
            qkv_bias=False,
            drop_rate=0.,
            pos_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
@@ -429,14 +517,18 @@ class VOLO(nn.Module):
        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            stem_conv=True, stem_stride=2, patch_size=patch_size,
            in_chans=in_chans, hidden_dim=stem_hidden_dim,
            embed_dim=embed_dims[0])
            stem_conv=True,
            stem_stride=2,
            patch_size=patch_size,
            in_chans=in_chans,
            hidden_dim=stem_hidden_dim,
            embed_dim=embed_dims[0],
        )

        # inital positional encoding, we add positional encoding after outlooker blocks
        patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
        self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1]))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # set the main block in network
        network = []
@@ -444,14 +536,31 @@ class VOLO(nn.Module):
            if outlook_attention[i]:
                # stage 1
                stage = outlooker_blocks(
                    Outlooker, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias, attn_drop=attn_drop_rate, norm_layer=norm_layer)
                    Outlooker,
                    i,
                    embed_dims[i],
                    layers,
                    num_heads[i],
                    mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
                network.append(stage)
            else:
                # stage 2
                stage = transformer_blocks(
                    Transformer, i, embed_dims[i], layers, num_heads[i], mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias,
                    drop_path_rate=drop_path_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer)
                    Transformer,
                    i,
                    embed_dims[i],
                    layers,
                    num_heads[i],
                    mlp_ratio=mlp_ratio[i],
                    qkv_bias=qkv_bias,
                    drop_path_rate=drop_path_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                )
                network.append(stage)

            if downsamples[i]:
@@ -463,19 +572,18 @@ class VOLO(nn.Module):
        # set post block, for example, class attention layers
        self.post_network = None
        if post_layers is not None:
            self.post_network = nn.ModuleList(
                [
                    get_block(
                        post_layers[i],
                        dim=embed_dims[-1],
                        num_heads=num_heads[-1],
                        mlp_ratio=mlp_ratio[-1],
                        qkv_bias=qkv_bias,
                        attn_drop=attn_drop_rate,
                        drop_path=0.,
                        norm_layer=norm_layer)
                    for i in range(len(post_layers))
                ])
            self.post_network = nn.ModuleList([
                get_block(
                    post_layers[i],
                    dim=embed_dims[-1],
                    num_heads=num_heads[-1],
                    mlp_ratio=mlp_ratio[-1],
                    qkv_bias=qkv_bias,
                    attn_drop=attn_drop_rate,
                    drop_path=0.,
                    norm_layer=norm_layer)
                for i in range(len(post_layers))
            ])
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1]))
            trunc_normal_(self.cls_token, std=.02)

@@ -487,6 +595,7 @@ class VOLO(nn.Module):
        self.norm = norm_layer(self.num_features)

        # Classifier head
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
@@ -630,6 +739,7 @@ class VOLO(nn.Module):
            out = x[:, 0]
        else:
            out = x
        x = self.head_drop(x)
        if pre_logits:
            return out
        out = self.head(out)
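# --- Editor's illustrative sketch (not part of the diff) ---
# With the VOLO changes above, drop_rate now only feeds the classifier-head
# dropout (head_drop), while pos_drop_rate controls the dropout applied after
# the positional embedding. Assuming the timm factory API and the existing
# 'volo_d1_224' entry point, the split can be exercised roughly like this:
import timm

volo = timm.create_model(
    'volo_d1_224',
    pretrained=False,
    drop_rate=0.1,       # classifier head dropout only
    pos_drop_rate=0.05,  # dropout after the positional embedding
    drop_path_rate=0.1,  # stochastic depth, unchanged by this refactor
)
# --- end of sketch ---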
@@ -219,17 +219,28 @@ class ClassAttentionBlock(nn.Module):
    """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
            act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False):
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            eta=1.,
            tokens_norm=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)

        self.attn = ClassAttn(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)

        if eta is not None:  # LayerScale Initialization (no layerscale when None)
            self.gamma1 = nn.Parameter(eta * torch.ones(dim))
@@ -297,18 +308,28 @@ class XCA(nn.Module):

class XCABlock(nn.Module):
    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.):
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            eta=1.,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm3 = norm_layer(dim)
        self.local_mp = LPI(in_features=dim, act_layer=act_layer)

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)

        self.gamma1 = nn.Parameter(eta * torch.ones(dim))
        self.gamma3 = nn.Parameter(eta * torch.ones(dim))
@@ -331,9 +352,29 @@ class XCiT(nn.Module):
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768,
            depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
            act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            drop_rate=0.,
            pos_drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            act_layer=None,
            norm_layer=None,
            cls_attn_layers=2,
            use_pos_embed=True,
            eta=1.,
            tokens_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
@@ -346,6 +387,8 @@ class XCiT(nn.Module):
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
            pos_drop_rate: position embedding dropout rate
            proj_drop_rate (float): projection dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate (constant across all layers)
            norm_layer: (nn.Module): normalization layer
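# --- Editor's illustrative sketch (not part of the diff) ---
# For XCiT, the refactor splits the former single drop_rate roughly as:
#   pos_drop_rate  -> dropout applied after the positional encoding,
#   proj_drop_rate -> projection/MLP dropout inside the XCA blocks,
#   drop_rate      -> classifier-head dropout (and the class-attention blocks' proj_drop).
# Assuming the timm factory API and the existing 'xcit_nano_12_p16_224' entry
# point, the new arguments would be passed through roughly like this:
import timm

xcit = timm.create_model(
    'xcit_nano_12_p16_224',
    pretrained=False,
    drop_rate=0.1,
    pos_drop_rate=0.05,
    proj_drop_rate=0.05,
    attn_drop_rate=0.0,
)
# --- end of sketch ---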
@@ -372,28 +415,53 @@ class XCiT(nn.Module):
        self.grad_checkpointing = False

        self.patch_embed = ConvPatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            act_layer=act_layer,
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
            self.pos_drop = nn.Dropout(p=drop_rate)
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        self.blocks = nn.ModuleList([
            XCABlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta)
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_rate,
                act_layer=act_layer,
                norm_layer=norm_layer,
                eta=eta,
            )
            for _ in range(depth)])

        self.cls_attn_blocks = nn.ModuleList([
            ClassAttentionBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=drop_rate,
                attn_drop=attn_drop_rate,
                act_layer=act_layer,
                norm_layer=norm_layer,
                eta=eta,
                tokens_norm=tokens_norm,
            )
            for _ in range(cls_attn_layers)])

        # Classifier head
        self.norm = norm_layer(embed_dim)
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Init weights
@@ -438,7 +506,7 @@ class XCiT(nn.Module):
        # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches)
        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos_embed:
        if self.pos_embed is not None:
            # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding
@@ -464,6 +532,7 @@ class XCiT(nn.Module):
    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
@@ -509,336 +578,336 @@ def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):

@register_model
def xcit_nano_12_p16_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
    model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
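# --- Editor's note on the registration pattern (not part of the diff) ---
# The switch from building model_kwargs as dict(<defaults>, **kwargs) to
# passing **dict(model_args, **kwargs) changes how caller overrides are merged:
# dict(mapping, **overrides) lets user kwargs override the registered defaults,
# whereas the old form raised a TypeError whenever a caller passed a key that
# collided with one of the defaults. A minimal demonstration:
defaults = dict(patch_size=16, embed_dim=128)
overrides = dict(embed_dim=256)

merged = dict(defaults, **overrides)
assert merged == {'patch_size': 16, 'embed_dim': 256}

# dict(patch_size=16, embed_dim=128, **overrides)
# ^ would raise: TypeError: dict() got multiple values for keyword argument 'embed_dim'
# --- end of note ---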


@register_model
def xcit_nano_12_p16_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
    model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_nano_12_p16_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384, **kwargs)
    model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384)
    model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_12_p16_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_12_p16_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_12_p16_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_12_p16_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_12_p16_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_12_p16_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_24_p16_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_24_p16_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_24_p16_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_24_p16_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_24_p16_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_24_p16_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_medium_24_p16_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_medium_24_p16_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_medium_24_p16_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_large_24_p16_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_large_24_p16_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_large_24_p16_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model

# Patch size 8x8 models
@register_model
def xcit_nano_12_p8_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
    model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_nano_12_p8_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
    model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_nano_12_p8_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
    model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False)
    model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_12_p8_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_12_p8_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_12_p8_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_12_p8_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_12_p8_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_12_p8_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True)
    model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_24_p8_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_24_p8_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_tiny_24_p8_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_24_p8_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_24_p8_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_small_24_p8_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_medium_24_p8_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_medium_24_p8_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_medium_24_p8_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_large_24_p8_224(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_large_24_p8_224_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def xcit_large_24_p8_384_dist(pretrained=False, **kwargs):
    model_kwargs = dict(
        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
    model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
    model_args = dict(
        patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True)
    model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **dict(model_args, **kwargs))
    return model