""" Packed Sequence Vision Transformer (ViT) in PyTorch
|
|
|
|
Base on ideas in NaViT paper
|
|
`Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution` - https://arxiv.org/abs/2307.06304
|
|
|
|
This is a WIP, TODO:
|
|
* significant additions to dataset pipeline (data loading / collation) to support sequences required
|
|
* token (patch) dropout needs to be implemented
|
|
* wider variety of position embedding options
|
|
|
|
"""
|
|
import logging
import math
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, trunc_normal_tf_, \
    resample_patch_embed, resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked, to_2tuple
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit

__all__ = ['VisionTransformerPacked']  # model_registry will add each entrypoint fn to this

_logger = logging.getLogger(__name__)


def extract_patches(
        x,
        patch_size=(16, 16),
        channels_last=False,
        flatten_grid=True,
        pad=False,
):
    B, C, H, W = x.shape
    ph, pw = patch_size
    if pad:
        pad_h = (patch_size[0] - H % patch_size[0]) % patch_size[0]
        pad_w = (patch_size[1] - W % patch_size[1]) % patch_size[1]
        x = F.pad(x, (0, pad_w, 0, pad_h))
        H += pad_h
        W += pad_w
    gh, gw = H // ph, W // pw
    if channels_last:
        #x = x.unfold(2, ph, pw).unfold(3, ph, pw).permute(0, 2, 3, 4, 5, 1).reshape(B, -1, ph * pw * C)
        x = x.reshape(B, C, gh, ph, gw, pw).permute(0, 2, 4, 3, 5, 1)  # B, gH, gW, pH, pW, C
    else:
        #x = x.permute(0, 2, 3, 1).unfold(1, ph, pw).unfold(2, ph, pw).reshape(B, -1, C * ph * pw)
        x = x.reshape(B, C, gh, ph, gw, pw).permute(0, 2, 4, 1, 3, 5)
    if flatten_grid:
        x = x.reshape(B, -1, C * ph * pw)
    else:
        x = x.reshape(B, gh, gw, -1)
    return x


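# Illustrative sketch (not part of the model): expected shapes from extract_patches for a
# small batch of RGB images with a 16x16 patch size. The function name below is hypothetical
# and exists only to document the behaviour of extract_patches above.
def _demo_extract_patches():
    x = torch.randn(2, 3, 224, 224)  # B, C, H, W
    flat = extract_patches(x, patch_size=(16, 16), flatten_grid=True)
    # 14x14 grid of patches, each flattened to 3 * 16 * 16 = 768 values
    assert flat.shape == (2, 196, 768)
    grid = extract_patches(x, patch_size=(16, 16), flatten_grid=False)
    # grid layout retained: B, gH, gW, patch_dim
    assert grid.shape == (2, 14, 14, 768)
    return flat, grid

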
@dataclass
class PackedSequence:
    tokens: List[torch.Tensor] = field(default_factory=list)
    pos_indices: List[torch.Tensor] = field(default_factory=list)
    seq_ids: List[torch.Tensor] = field(default_factory=list)
    seq_lens: List[int] = field(default_factory=list)
    total_len: int = 0
    num_images: int = 0

    def add_image(self, tokens, pos_indices):
        seq_id = self.num_images + 1
        seq_len = len(tokens)
        device = tokens.device
        self.tokens.append(tokens)
        self.pos_indices.append(pos_indices)
        self.seq_ids.append(torch.tensor([seq_id] * seq_len, dtype=torch.int64, device=device))
        self.seq_lens.append(seq_len)
        self.total_len += seq_len
        self.num_images += 1

    def to_tensors(self, max_seq_len, max_num_seq):
        """
        Args:
            max_seq_len: maximum sequence length (pad to this)
            max_num_seq: maximum # of sequences (images) packed into one sequence (across the batch)

        Returns:
            Tuple of tensors for packed batch of images
        """
        assert self.total_len > 0
        assert max_seq_len >= self.total_len
        device = self.tokens[-1].device
        dim = self.tokens[-1].shape[-1]
        pad_len = max_seq_len - self.total_len
        seq_pad = max(0, max_num_seq - len(self.seq_lens))
        seq_lens = self.seq_lens + [0] * seq_pad if seq_pad else self.seq_lens
        seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)
        if pad_len:
            tokens = self.tokens + [torch.zeros(pad_len, dim, device=device)]
            pos_indices = self.pos_indices + [torch.zeros((pad_len, 2), dtype=torch.int64, device=device)]
            seq_ids = self.seq_ids + [torch.zeros(pad_len, dtype=torch.int64, device=device)]
        else:
            tokens = self.tokens
            pos_indices = self.pos_indices
            seq_ids = self.seq_ids
        tokens = torch.concat(tokens)
        pos_indices = torch.concat(pos_indices)
        seq_ids = torch.concat(seq_ids)
        return tokens, pos_indices, seq_ids, seq_lens


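# Illustrative sketch (not part of the model): packing two token sequences into one
# PackedSequence and padding out to a fixed length, following to_tensors() above. The
# helper name and sizes are hypothetical, chosen only for illustration.
def _demo_packed_sequence():
    ps = PackedSequence()
    for grid_h, grid_w in ((4, 4), (2, 3)):
        tokens = torch.randn(grid_h * grid_w, 768)
        pos = torch.stack(torch.meshgrid(
            torch.arange(grid_h), torch.arange(grid_w), indexing='ij'), dim=-1).flatten(0, 1)
        ps.add_image(tokens, pos)
    tokens, pos_indices, seq_ids, seq_lens = ps.to_tensors(max_seq_len=32, max_num_seq=4)
    # 16 + 6 real tokens padded to 32 slots; padding slots carry seq_id == 0
    assert tokens.shape == (32, 768) and seq_ids.shape == (32,)
    assert seq_lens.tolist() == [16, 6, 0, 0]
    return tokens, pos_indices, seq_ids, seq_lens

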
def pack_images(
        images: List[torch.Tensor],
        patch_size: Tuple[int, int],
        max_grid_size: Tuple[int, int],
        pad_patches: bool = False,
        max_images_per_sequence: int = 4,
):
    max_seq_len = max_grid_size[0] * max_grid_size[1]

    # patchify, generate position indices, apply patch drop, record seq lengths
    img_tokens = []
    img_pos_indices = []
    img_seq_lens = []
    for img in images:
        assert img.ndim == 3
        device = img.device
        patches = extract_patches(img.unsqueeze(0), patch_size, flatten_grid=False, pad=pad_patches).squeeze(0)
        grid_h, grid_w, dim = patches.shape
        seq_len = grid_h * grid_w
        if seq_len > max_seq_len:
            _logger.error('Sequence length of image is too large, skipping.')
            continue
        pos_indices = torch.stack(
            torch.meshgrid((
                torch.arange(grid_h, device=device),
                torch.arange(grid_w, device=device)),
                indexing='ij'),
            dim=-1,
        )
        # FIXME patch drop here
        img_tokens.append(patches.flatten(0, 1))
        img_pos_indices.append(pos_indices.flatten(0, 1))
        img_seq_lens.append(seq_len)
    del images

    # sort by seq length largest -> smallest
    img_seq_lens = torch.tensor(img_seq_lens, dtype=torch.long, device=device)
    seq_sort_indices = torch.argsort(img_seq_lens, descending=True)

    packed_sequences: List[PackedSequence] = []  # image sequences packed together
    next_pos = 0
    max_packed = 0
    for _ in range(len(seq_sort_indices)):
        idx_to_pack = seq_sort_indices[next_pos]
        len_to_pack = img_seq_lens[idx_to_pack]
        sequence = None
        for p in packed_sequences:
            # try to fit into an existing packed sequence first
            if p.num_images >= max_images_per_sequence or p.total_len + len_to_pack > max_seq_len:
                # will not fit in this sequence
                continue
            sequence = p
            break

        if sequence is None:
            sequence = PackedSequence()  # start fresh sequence
            packed_sequences.append(sequence)

        img_to_pack = img_tokens[idx_to_pack]
        pos_to_pack = img_pos_indices[idx_to_pack]
        sequence.add_image(img_to_pack, pos_to_pack)
        max_packed = max(sequence.num_images, max_packed)
        next_pos += 1

    tensors = [p.to_tensors(max_seq_len=max_seq_len, max_num_seq=max_packed) for p in packed_sequences]
    o = [torch.stack(t) for t in zip(*tensors)]
    return tuple(o)


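# Illustrative sketch (not part of the model): packing a small batch of variable resolution
# images. Both fit within a 24x24 max grid (576 token slots), so they end up packed into a
# single sequence. The helper name and image sizes are hypothetical.
def _demo_pack_images():
    images = [torch.randn(3, 224, 160), torch.randn(3, 96, 128)]
    tokens, pos_indices, seq_ids, seq_lens = pack_images(
        images, patch_size=(16, 16), max_grid_size=(24, 24))
    # one packed sequence of max_grid_size[0] * max_grid_size[1] = 576 token slots
    assert tokens.shape[:2] == (1, 576)
    # 14*10 + 6*8 = 188 real tokens, the rest is padding with seq_id == 0
    assert int((seq_ids > 0).sum()) == 188
    return tokens, pos_indices, seq_ids, seq_lens

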
class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        if attn_mask is not None:
            assert attn_mask.ndim == 4
            if attn_mask.shape[1] != self.num_heads:
                attn_mask = attn_mask.expand((-1, self.num_heads, -1, -1))

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn += attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


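# Illustrative sketch (not part of the model): building the additive attention mask expected
# by the non-fused path above. Tokens may only attend to tokens with the same seq_id; the
# boolean mask is converted to a large-negative additive mask in the same way that
# forward_features() does further below. The helper name is hypothetical.
def _demo_additive_attn_mask(seq_ids: torch.Tensor, dtype=torch.float32):
    # seq_ids: (B, N) integer ids, 0 for padding
    allow = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)  # (B, N, N) boolean allow matrix
    allow = allow.unsqueeze(1)                            # (B, 1, N, N), broadcast over heads
    mask = torch.zeros_like(allow, dtype=dtype)
    return mask.masked_fill_(~allow, torch.finfo(dtype).min)

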
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            proj_drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            mlp_layer=Mlp,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class ParallelScalingBlock(nn.Module):
    """ Parallel ViT block (MLP & Attention in parallel)
    Based on:
      `Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            proj_drop=0.,
            attn_drop=0.,
            init_values=None,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            mlp_layer=None,  # NOTE: not used
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()
        mlp_hidden_dim = int(mlp_ratio * dim)
        in_proj_out_dim = mlp_hidden_dim + 3 * dim

        self.in_norm = norm_layer(dim)
        self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
        self.in_split = [mlp_hidden_dim] + [dim] * 3
        if qkv_bias:
            self.register_buffer('qkv_bias', None)
            self.register_parameter('mlp_bias', None)
        else:
            self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False)
            self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.attn_out_proj = nn.Linear(dim, dim)

        self.mlp_drop = nn.Dropout(proj_drop)
        self.mlp_act = act_layer()
        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)

        self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self):
        trunc_normal_tf_(self.in_proj.weight, std=(self.head_dim * self.num_heads) ** -0.5)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape

        # Combined MLP fc1 & qkv projections
        y = self.in_norm(x)
        if self.mlp_bias is not None:
            # Concat constant zero-bias for qkv w/ trainable mlp_bias.
            # Appears faster than adding to x_mlp separately
            y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
        else:
            y = self.in_proj(y)
        x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)

        # Dot product attention w/ qk norm
        q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
        v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        if self.fused_attn:
            x_attn = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn += attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x_attn = attn @ v
        x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
        x_attn = self.attn_out_proj(x_attn)

        # MLP activation, dropout, fc2
        x_mlp = self.mlp_act(x_mlp)
        x_mlp = self.mlp_drop(x_mlp)
        x_mlp = self.mlp_out_proj(x_mlp)

        # Add residual w/ drop path & layer scale applied
        y = self.drop_path(self.ls(x_attn + x_mlp))
        x = x + y
        return x


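# Illustrative sketch (not part of the model): ParallelScalingBlock is shape-compatible with
# Block and can be swapped in via the block_fn argument of the model further below (see the
# navit_base_patch16_xp_384 entrypoint). The helper name and sizes are hypothetical.
def _demo_parallel_block():
    blk = ParallelScalingBlock(dim=192, num_heads=3, qkv_bias=True, qk_norm=True)
    x = torch.randn(2, 64, 192)
    assert blk(x).shape == x.shape  # residual block, output shape matches input
    return blk

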
class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    def __init__(
            self,
            in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            flatten_input: bool = True,
            feat_size: Optional[int] = None,  # required for 'abs' pos_embed or latent_size=None
            latent_size: int = 1,
            latent_proj: bool = False,
            latent_dim: Optional[int] = None,
            pos_embed: str = '',
            proj_type: str = '',
            pool_type: str = '',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.flatten_input = flatten_input
        self.feat_size = feat_size
        self.pool = pool_type

        if pos_embed == 'abs':
            assert self.feat_size is not None
            spatial_len = self.feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        latent_size = latent_size or self.feat_size
        self.latent_len = latent_size
        self.latent = nn.Parameter(torch.zeros(self.latent_len, embed_dim))

        if latent_proj:
            self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
        else:
            assert not latent_dim or latent_dim == embed_dim
            self.q = None

        self.kv = nn.Linear(in_features, embed_dim * 2, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()

        if proj_type == 'linear':
            self.proj = nn.Linear(embed_dim, out_features)
            self.proj_drop = nn.Dropout(drop)
        elif proj_type == 'mlp':
            self.proj = Mlp(
                embed_dim,
                hidden_features=embed_dim * 4,
                out_features=out_features,
                drop=drop)
            self.proj_drop = nn.Identity()
        else:
            assert out_features == embed_dim
            self.proj = None
            self.proj_drop = nn.Dropout(drop)

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent.shape[1] ** -0.5)
        if self.q is not None:
            trunc_normal_tf_(self.q.weight, std=self.q.weight.shape[1] ** -0.5)
            if self.q.bias is not None:
                nn.init.zeros_(self.q.bias)
        trunc_normal_tf_(self.kv.weight, std=self.kv.weight.shape[1] ** -0.5)
        if self.kv.bias is not None:
            nn.init.zeros_(self.kv.bias)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, N, _ = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q = self.latent if self.q is None else self.q(self.latent)
        q = q.reshape(1, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        if attn_mask is not None and attn_mask.shape[2] != q.shape[2]:
            # expand latent q to match attention mask, TODO make this less implicit?
            if q.shape[2] == 1:
                q = q.expand(B, -1, attn_mask.shape[2], -1)
            else:
                assert attn_mask.shape[2] % q.shape[2] == 0
                q = q.repeat(1, 1, attn_mask.shape[2] // q.shape[2], 1)
                q = q.expand(B, -1, -1, -1)
        else:
            q = q.expand(B, -1, -1, -1)
        latent_len = q.shape[2]
        x = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = x.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        if False:
            # NOTE fused attention path disabled while masking behaviour is worked out (WIP)
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn += attn_mask
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, latent_len, -1)

        x = self.norm(x)
        if self.proj is not None:
            shortcut = x
            x = self.proj(x)
            x = self.proj_drop(x)
            x = x + shortcut
        else:
            x = self.proj_drop(x)
        if self.pool == 'token':
            x = x[:, 0]
        return x


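# Illustrative sketch (not part of the model): pooling a token sequence down to a single
# learned latent query per sample with AttentionPoolLatent, using the defaults and no
# attention mask. The helper name and sizes are hypothetical.
def _demo_attention_pool():
    pool = AttentionPoolLatent(192, num_heads=3, latent_proj=True)
    x = torch.randn(2, 64, 192)
    out = pool(x)  # one latent query per sample, pool_type '' keeps the latent dim
    assert out.shape == (2, 1, 192)
    return out

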
class VisionTransformerPacked(nn.Module):
    """ Vision Transformer (packed sequence, NaViT style)
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            pre_norm: bool = False,
            fc_norm: Optional[bool] = None,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: str = '',
            norm_layer: Optional[Callable] = None,
            act_layer: Optional[Callable] = None,
            block_fn: Callable = Block,
            mlp_layer: Callable = Mlp,
    ):
        """
        Args:
            img_size: Input image size (treated as the maximum size for packing).
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'avg').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            qk_norm: Enable normalization of query and key in attention if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            pre_norm: Enable norm layer after embeddings, before transformer blocks.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            patch_drop_rate: Patch (token) dropout rate (not yet implemented, see module TODO).
            proj_drop_rate: Projection dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
            mlp_layer: MLP layer used within blocks.
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'attn')
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.grad_checkpointing = False
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_size = patch_h, patch_w = to_2tuple(patch_size)
        self.img_size = img_h, img_w = to_2tuple(img_size)  # NOTE this === 'maximum size'
        self.grid_size = grid_h, grid_w = img_h // patch_h, img_w // patch_w
        self.max_seq = grid_h * grid_w
        patch_dim_in = in_chans * patch_h * patch_w

        self.patch_embed = nn.Linear(patch_dim_in, embed_dim)
        self.pos_embed_h = nn.Parameter(torch.randn(grid_h, embed_dim) * .02)
        self.pos_embed_w = nn.Parameter(torch.randn(grid_w, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        if global_pool == 'avg':
            self.attn_pool = None
        else:
            # FIXME attention pooling appears less stable in initial trials
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                self.embed_dim,
                num_heads=num_heads,
                pos_embed='',
                latent_proj=True,
                proj_type='',
                norm_layer=norm_layer,
            )

        # Classifier Head
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed_h, std=.02)
        trunc_normal_(self.pos_embed_w, std=.02)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed_h', 'pos_embed_w'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^embeds',  # stem and embed # FIXME correct when design finalized
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'attn')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(
            self,
            tokens: Union[List[torch.Tensor], torch.Tensor],
            pos_indices: Optional[torch.Tensor] = None,
            seq_ids: Optional[torch.Tensor] = None,
            seq_lens: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        if isinstance(tokens, torch.Tensor) and tokens.ndim == 4:
            # B, C, H, W batch tensor will be converted to list and packed
            # for compatibility with common image model usage (and initial testing)
            tokens = tokens.unbind(0)

        if isinstance(tokens, (list, tuple)):
            tokens, pos_indices, seq_ids, seq_lens = pack_images(
                tokens,
                self.patch_size,
                max_grid_size=self.grid_size,
                pad_patches=True,
                max_images_per_sequence=4,
            )

        assert tokens.ndim == 3
        assert pos_indices is not None
        assert seq_ids is not None
        assert seq_lens is not None

        tokens = self.patch_embed(tokens)
        pos_index_h, pos_index_w = pos_indices.unbind(-1)
        pos = self.pos_embed_h[pos_index_h] + self.pos_embed_w[pos_index_w]
        tokens += pos
        tokens = self.pos_drop(tokens)
        tokens = self.norm_pre(tokens)

        if attn_mask is None:
            attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)
            # NOTE: not applying key padding mask as padding tokens are already isolated to
            # themselves via the above mask (padding has seq_id == 0). Doing an additional
            # key padding mask results in fully masked rows which causes numerical issues.
            # key_padding_mask = (seq_ids != 0).unsqueeze(1)
            # attn_mask = attn_mask & key_padding_mask
            attn_mask = attn_mask.unsqueeze(1)

        if attn_mask.dtype == torch.bool:
            dtype = tokens.dtype
            min_val = torch.finfo(dtype).min
            attn_mask = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, min_val)

        for b in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                tokens = torch.utils.checkpoint.checkpoint(
                    b, tokens, use_reentrant=False, attn_mask=attn_mask)
            else:
                tokens = b(tokens, attn_mask=attn_mask)
        tokens = self.norm(tokens)

        device = tokens.device
        max_packing = seq_lens.shape[1]
        seq_id_range = torch.arange(1, 1 + max_packing, device=device)
        unpack_mask = seq_ids.unsqueeze(1) == seq_id_range[:, None]
        seq_lens = seq_lens.reshape(-1)
        valid_rows = seq_lens > 0
        if self.attn_pool is not None:
            # unpack_mask = unpack_mask & key_padding_mask
            unpack_mask = unpack_mask.unsqueeze(1)
            unpack_mask = torch.zeros_like(unpack_mask, dtype=tokens.dtype).masked_fill_(
                ~unpack_mask, torch.finfo(tokens.dtype).min)
            tokens = self.attn_pool(tokens, attn_mask=unpack_mask)
            tokens = tokens.reshape(-1, self.embed_dim)
            tokens = tokens[valid_rows]
        else:
            tokens = tokens.unsqueeze(1).expand(-1, max_packing, -1, -1)[unpack_mask]
            # split back out into one variable length token sequence per image, dropping
            # the zero-length (padded) sequence slots before computing split points
            tokens = tokens.tensor_split(seq_lens[valid_rows].cumsum(0)[:-1].cpu())
            # tokens = tokens.unsqueeze(1) * unpack_mask.unsqueeze(-1).expand(-1, -1, -1, self.embed_dim)
            # tokens = tokens.reshape(-1, tokens.shape[-2], tokens.shape[-1])
            # seq_lens = seq_lens[valid_rows]
            # tokens = tokens[valid_rows]

        # FIXME sort out this mess, the boundary of features vs head is a bit messy with
        # variable length sequence averaging vs attention pooling...
        return tokens  #, seq_lens

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            if isinstance(x, (list, tuple)):
                x = torch.stack([t.mean(dim=0) for t in x], 0)
            else:
                # x = x.sum(dim=1) / seq_lens.reshape(-1, 1)
                x = x.mean(dim=1)
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(
            self,
            tokens: Union[List[torch.Tensor], torch.Tensor],
            pos_indices: Optional[torch.Tensor] = None,
            seq_ids: Optional[torch.Tensor] = None,
            seq_lens: Optional[torch.Tensor] = None,
    ):
        x = self.forward_features(
            tokens,
            pos_indices=pos_indices,
            seq_ids=seq_ids,
            seq_lens=seq_lens,
        )
        x = self.forward_head(x)
        return x


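# Illustrative sketch (not part of the model): a minimal forward pass with a list of variable
# resolution images, relying on the packing path in forward_features(). The constructor
# arguments and helper name below are hypothetical, chosen small for illustration.
def _demo_packed_forward():
    model = VisionTransformerPacked(
        img_size=256, patch_size=16, embed_dim=192, depth=2, num_heads=3, num_classes=10)
    images = [torch.randn(3, 224, 160), torch.randn(3, 96, 128), torch.randn(3, 256, 256)]
    with torch.no_grad():
        logits = model(images)
    # one row of logits per input image, regardless of how the images were packed
    assert logits.shape == (3, 10)
    return logits

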
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'navit_medium_patch16_384': _cfg(),
    'navit_base_patch32_224': _cfg(),
    'navit_base_patch32_384': _cfg(),
    'navit_base_patch16_224': _cfg(),
    'navit_base_patch16_384': _cfg(),
    'navit_base_patch16_xp_384': _cfg(),
})


def _create_vision_transformer_packed(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(
        VisionTransformerPacked,
        variant,
        pretrained,
        #pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )


@register_model
def navit_medium_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
    model_args = dict(
        img_size=384, patch_size=16, embed_dim=512, depth=12, num_heads=8,
        fc_norm=False, init_values=1e-5, qkv_bias=False)
    model = _create_vision_transformer_packed(
        'navit_medium_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def navit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
    model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
    model = _create_vision_transformer_packed(
        'navit_base_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def navit_base_patch32_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
    model_args = dict(img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12)
    model = _create_vision_transformer_packed(
        'navit_base_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def navit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
    model = _create_vision_transformer_packed(
        'navit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def navit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
    model_args = dict(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
    model = _create_vision_transformer_packed(
        'navit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def navit_base_patch16_xp_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
    model_args = dict(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        qk_norm=True, pre_norm=True, block_fn=ParallelScalingBlock)
    model = _create_vision_transformer_packed(
        'navit_base_patch16_xp_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


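# Illustrative sketch (not part of the model): creating one of the registered variants via the
# timm factory, assuming this module ships as part of timm so the entrypoints are registered.
# There are no pretrained weights for these WIP configs, so pretrained must remain False.
def _demo_create_model():
    import timm
    model = timm.create_model('navit_base_patch16_224', pretrained=False, num_classes=0)
    images = [torch.randn(3, 192, 160), torch.randn(3, 224, 224)]
    with torch.no_grad():
        features = model(images)  # pooled features, one row per image
    return features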