""" Packed Sequence Vision Transformer (ViT) in PyTorch
Base on ideas in NaViT paper
`Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution` - https://arxiv.org/abs/2307.06304
This is a WIP, TODO:
* significant additions to dataset pipeline (data loading / collation) to support sequences required
* token (patch) dropout needs to be implemented
* wider variety of position embedding options
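
Rough usage sketch (the interface is WIP and may change): the model accepts either a standard
(B, C, H, W) batch or a list of variable sized (C, H, W) images which are packed internally:

    model = navit_base_patch16_224(num_classes=10)
    images = [torch.randn(3, 160, 224), torch.randn(3, 224, 192)]
    logits = model(images)  # patchified & packed via pack_images(), one logit row per image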
"""
import logging
import math
from collections import OrderedDict
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, trunc_normal_tf_, \
resample_patch_embed, resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked, to_2tuple
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import get_init_weights_vit
__all__ = ['VisionTransformerPacked'] # model_registry will add each entrypoint fn to this
_logger = logging.getLogger(__name__)
def extract_patches(
x,
patch_size=(16, 16),
channels_last=False,
flatten_grid=True,
pad=False,
):
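    """ Split an image batch into flattened patch tokens.

    Reshapes a (B, C, H, W) tensor into a (gh, gw) grid of (ph, pw) patches, optionally padding
    H and W up to a multiple of the patch size first. Returns (B, gh * gw, C * ph * pw) when
    flatten_grid is True, otherwise (B, gh, gw, C * ph * pw). With channels_last, the per-patch
    flatten order is (ph, pw, C) instead of (C, ph, pw).
    """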
B, C, H, W = x.shape
ph, pw = patch_size
    if pad:
        pad_h = (ph - H % ph) % ph
        pad_w = (pw - W % pw) % pw
        x = F.pad(x, (0, pad_w, 0, pad_h))
        H += pad_h
        W += pad_w
gh, gw = H // ph, W // pw
if channels_last:
#x = x.unfold(2, ph, pw).unfold(3, ph, pw).permute(0, 2, 3, 4, 5, 1).reshape(B, -1, ph * pw * C)
x = x.reshape(B, C, gh, ph, gw, pw).permute(0, 2, 4, 3, 5, 1) # B, gH, gW, pH, pW, C
else:
#x = x.permute(0, 2, 3, 1).unfold(1, ph, pw).unfold(2, ph, pw).reshape(B, -1, C * ph * pw)
x = x.reshape(B, C, gh, ph, gw, pw).permute(0, 2, 4, 1, 3, 5)
if flatten_grid:
x = x.reshape(B, -1, C * ph * pw)
else:
x = x.reshape(B, gh, gw, -1)
return x
@dataclass
class PackedSequence:
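    """ Accumulator for patch tokens of multiple images packed into a single sequence.

    Each added image contributes its flattened patch tokens, per-token (h, w) grid indices, and a
    per-token sequence id (1-based; 0 is reserved for padding) so images can be separated again
    after transformer processing.
    """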
tokens: List[torch.Tensor] = field(default_factory=list)
pos_indices: List[torch.Tensor] = field(default_factory=list)
seq_ids: List[torch.Tensor] = field(default_factory=list)
seq_lens: List[int] = field(default_factory=list)
total_len: int = 0
num_images: int = 0
def add_image(self, tokens, pos_indices):
seq_id = self.num_images + 1
seq_len = len(tokens)
device = tokens.device
self.tokens.append(tokens)
self.pos_indices.append(pos_indices)
        self.seq_ids.append(torch.full((seq_len,), seq_id, dtype=torch.int64, device=device))
self.seq_lens.append(seq_len)
self.total_len += seq_len
self.num_images += 1
def to_tensors(self, max_seq_len, max_num_seq):
"""
Args:
max_seq_len: maximum sequence length (pad to this)
max_num_seq: maximum # of sequences (images) packed into one sequence (across the batch)
Returns:
Tuple of tensors for packed batch of images
"""
assert self.total_len > 0
assert max_seq_len >= self.total_len
device = self.tokens[-1].device
dim = self.tokens[-1].shape[-1]
pad_len = max_seq_len - self.total_len
seq_pad = max(0, max_num_seq - len(self.seq_lens))
        seq_lens = self.seq_lens + [0] * seq_pad
seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)
if pad_len:
tokens = self.tokens + [torch.zeros(pad_len, dim, device=device)]
pos_indices = self.pos_indices + [torch.zeros((pad_len, 2), dtype=torch.int64, device=device)]
seq_ids = self.seq_ids + [torch.zeros(pad_len, dtype=torch.int64, device=device)]
else:
tokens = self.tokens
pos_indices = self.pos_indices
seq_ids = self.seq_ids
tokens = torch.concat(tokens)
pos_indices = torch.concat(pos_indices)
seq_ids = torch.concat(seq_ids)
return tokens, pos_indices, seq_ids, seq_lens
def pack_images(
images: List[torch.Tensor],
patch_size: Tuple[int, int],
max_grid_size: Tuple[int, int],
pad_patches: bool = False,
max_images_per_sequence: int = 4,
):
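    """ Patchify a list of variable sized images and pack them into fixed length sequences.

    Images are patchified, sorted by token count (largest first), and greedily placed into the
    first existing packed sequence with room (first-fit decreasing), subject to the max sequence
    length implied by max_grid_size and to max_images_per_sequence.

    Returns:
        Tuple of (tokens, pos_indices, seq_ids, seq_lens) tensors, stacked over packed sequences.
    """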
max_seq_len = max_grid_size[0] * max_grid_size[1]
# patchify, generate position indices, apply patch drop, record seq lengths
img_tokens = []
img_pos_indices = []
img_seq_lens = []
for img in images:
assert img.ndim == 3
device = img.device
patches = extract_patches(img.unsqueeze(0), patch_size, flatten_grid=False, pad=pad_patches).squeeze(0)
grid_h, grid_w, dim = patches.shape
seq_len = grid_h * grid_w
if seq_len > max_seq_len:
            _logger.error(f'Image sequence length ({seq_len}) exceeds maximum ({max_seq_len}), skipping.')
continue
pos_indices = torch.stack(
torch.meshgrid((
torch.arange(grid_h, device=device),
torch.arange(grid_w, device=device)),
indexing='ij'),
dim=-1,
)
# FIXME patch drop here
img_tokens.append(patches.flatten(0, 1))
img_pos_indices.append(pos_indices.flatten(0, 1))
img_seq_lens.append(seq_len)
    del images
    assert len(img_tokens), 'no valid images to pack'
    # sort by seq length largest -> smallest
    img_seq_lens = torch.tensor(img_seq_lens, dtype=torch.long, device=device)
seq_sort_indices = torch.argsort(img_seq_lens, descending=True)
packed_sequences: List[PackedSequence] = [] # image sequences packed together
next_pos = 0
max_packed = 0
for _ in range(len(seq_sort_indices)):
idx_to_pack = seq_sort_indices[next_pos]
len_to_pack = img_seq_lens[idx_to_pack]
sequence = None
for p in packed_sequences:
# try over existing
if p.num_images >= max_images_per_sequence or p.total_len + len_to_pack > max_seq_len:
# will not fit in this sequence
continue
sequence = p
break
if sequence is None:
sequence = PackedSequence() # start fresh sequence
packed_sequences.append(sequence)
img_to_pack = img_tokens[idx_to_pack]
pos_to_pack = img_pos_indices[idx_to_pack]
sequence.add_image(img_to_pack, pos_to_pack)
max_packed = max(sequence.num_images, max_packed)
next_pos += 1
tensors = [p.to_tensors(max_seq_len=max_seq_len, max_num_seq=max_packed) for p in packed_sequences]
o = [torch.stack(t) for t in zip(*tensors)]
return tuple(o)
class Attention(nn.Module):
fused_attn: Final[bool]
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
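        # attn_mask is expected to be an additive float mask broadcastable to (B, num_heads, N, N);
        # in this model, bool masks are converted to additive masks in forward_features before the blocks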
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if attn_mask is not None:
assert attn_mask.ndim == 4
if attn_mask.shape[1] != self.num_heads:
attn_mask = attn_mask.expand((-1, self.num_heads, -1, -1))
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
proj_drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
mlp_layer=Mlp,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class ParallelScalingBlock(nn.Module):
""" Parallel ViT block (MLP & Attention in parallel)
Based on:
    `Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
"""
fused_attn: Final[bool]
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
proj_drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
mlp_layer=None, # NOTE: not used
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
mlp_hidden_dim = int(mlp_ratio * dim)
in_proj_out_dim = mlp_hidden_dim + 3 * dim
self.in_norm = norm_layer(dim)
self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
self.in_split = [mlp_hidden_dim] + [dim] * 3
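        # when in_proj carries its own bias no extra bias terms are needed; otherwise keep the qkv
        # portion bias-free (zeros buffer) but give the mlp fc1 portion a trainable bias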
if qkv_bias:
self.register_buffer('qkv_bias', None)
self.register_parameter('mlp_bias', None)
else:
self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False)
self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.attn_out_proj = nn.Linear(dim, dim)
self.mlp_drop = nn.Dropout(proj_drop)
self.mlp_act = act_layer()
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def init_weights(self):
trunc_normal_tf_(self.in_proj.weight, std=(self.head_dim * self.num_heads) ** -0.5)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
B, N, C = x.shape
# Combined MLP fc1 & qkv projections
y = self.in_norm(x)
if self.mlp_bias is not None:
# Concat constant zero-bias for qkv w/ trainable mlp_bias.
# Appears faster than adding to x_mlp separately
y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
else:
y = self.in_proj(y)
x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
# Dot product attention w/ qk norm
q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
if self.fused_attn:
x_attn = F.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x_attn = attn @ v
x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
x_attn = self.attn_out_proj(x_attn)
# MLP activation, dropout, fc2
x_mlp = self.mlp_act(x_mlp)
x_mlp = self.mlp_drop(x_mlp)
x_mlp = self.mlp_out_proj(x_mlp)
# Add residual w/ drop path & layer scale applied
y = self.drop_path(self.ls(x_attn + x_mlp))
x = x + y
return x
class AttentionPoolLatent(nn.Module):
""" Attention pooling w/ latent query
"""
def __init__(
self,
in_features: int,
            out_features: Optional[int] = None,
            embed_dim: Optional[int] = None,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
            feat_size: Optional[int] = None,
            flatten_input: bool = True,
latent_size: int = 1,
latent_proj: bool = False,
            latent_dim: Optional[int] = None,
pos_embed: str = '',
proj_type: str = '',
pool_type: str = '',
norm_layer: Optional[nn.Module] = None,
drop: float = 0.0,
):
super().__init__()
embed_dim = embed_dim or in_features
out_features = out_features or in_features
assert embed_dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
        self.flatten_input = flatten_input
        self.pool = pool_type
        self.feat_size = feat_size
        if pos_embed == 'abs':
            assert feat_size is not None, 'feat_size (sequence length) must be set for abs pos embed'
            spatial_len = feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
else:
self.pos_embed = None
self.latent_dim = latent_dim or embed_dim
latent_size = latent_size or self.feat_size
self.latent_len = latent_size
self.latent = nn.Parameter(torch.zeros(self.latent_len, embed_dim))
if latent_proj:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
else:
assert not latent_dim or latent_dim == embed_dim
self.q = None
self.kv = nn.Linear(in_features, embed_dim * 2, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
if proj_type == 'linear':
self.proj = nn.Linear(embed_dim, out_features)
self.proj_drop = nn.Dropout(drop)
elif proj_type == 'mlp':
self.proj = Mlp(
embed_dim,
hidden_features=embed_dim * 4,
out_features=out_features,
drop=drop)
self.proj_drop = nn.Identity()
else:
assert out_features == embed_dim
self.proj = None
self.proj_drop = nn.Dropout(drop)
def init_weights(self):
if self.pos_embed is not None:
trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
trunc_normal_tf_(self.latent, std=self.latent.shape[1] ** -0.5)
if self.q is not None:
trunc_normal_tf_(self.q.weight, std=self.q.weight.shape[1] ** -0.5)
if self.q.bias is not None:
nn.init.zeros_(self.q.bias)
trunc_normal_tf_(self.kv.weight, std=self.kv.weight.shape[1] ** -0.5)
if self.kv.bias is not None:
nn.init.zeros_(self.kv.bias)
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
B, N, _ = x.shape
if self.pos_embed is not None:
# FIXME interpolate
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
q = self.latent if self.q is None else self.q(self.latent)
q = q.reshape(1, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        if attn_mask is not None and attn_mask.shape[2] != q.shape[2]:
# expand latent q to match attention mask, TODO make this less implicit?
if q.shape[2] == 1:
q = q.expand(B, -1, attn_mask.shape[2], -1)
else:
assert attn_mask.shape[2] % q.shape[2] == 0
q = q.repeat(1, 1, attn_mask.shape[2] // q.shape[2], 1)
q = q.expand(B, -1, -1, -1)
else:
q = q.expand(B, -1, -1, -1)
latent_len = q.shape[2]
x = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = x.unbind(0)
q = self.q_norm(q)
k = self.k_norm(k)
        # NOTE fused attention path disabled for now in this WIP pooling module
        if False:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if attn_mask is not None:
                attn += attn_mask
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, latent_len, -1)
x = self.norm(x)
if self.proj is not None:
shortcut = x
x = self.proj(x)
x = self.proj_drop(x)
x = x + shortcut
else:
x = self.proj_drop(x)
if self.pool == 'token':
x = x[:, 0]
return x
class VisionTransformerPacked(nn.Module):
""" Vision Transformer
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: Optional[float] = None,
pre_norm: bool = False,
fc_norm: Optional[bool] = None,
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: str = '',
norm_layer: Optional[Callable] = None,
act_layer: Optional[Callable] = None,
block_fn: Callable = Block,
mlp_layer: Callable = Mlp,
):
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of image input channels.
num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence, one of '', 'avg', 'attn' (default: 'avg').
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            qk_norm: Enable normalization of query and key in attention if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            pre_norm: Apply a normalization layer to the embedded tokens before the blocks if True.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            patch_drop_rate: Patch (token) dropout rate (not yet implemented for packed sequences).
            proj_drop_rate: Projection dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
            block_fn: Transformer block layer.
            mlp_layer: MLP block layer used inside transformer blocks.
"""
super().__init__()
assert global_pool in ('', 'avg', 'attn')
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.grad_checkpointing = False
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_size = patch_h, patch_w = to_2tuple(patch_size)
        self.img_size = img_h, img_w = to_2tuple(img_size)  # NOTE this is treated as the maximum image size
self.grid_size = grid_h, grid_w = img_h // patch_h, img_w // patch_w
self.max_seq = grid_h * grid_w
patch_dim_in = in_chans * patch_h * patch_w
self.patch_embed = nn.Linear(patch_dim_in, embed_dim)
self.pos_embed_h = nn.Parameter(torch.randn(grid_h, embed_dim) * .02)
self.pos_embed_w = nn.Parameter(torch.randn(grid_w, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
mlp_layer=mlp_layer,
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
if global_pool == 'avg':
self.attn_pool = None
else:
# FIXME attention pooling appears less stable in initial trials
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
self.embed_dim,
num_heads=num_heads,
pos_embed='',
latent_proj=True,
proj_type='',
norm_layer=norm_layer,
)
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed_h, std=.02)
trunc_normal_(self.pos_embed_w, std=.02)
named_apply(get_init_weights_vit(mode, head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed_h', 'pos_embed_w'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
            stem=r'^patch_embed|pos_embed',  # stem and embed  # FIXME revisit when design finalized
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'attn')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(
self,
tokens: Union[List[torch.Tensor], torch.Tensor],
pos_indices: Optional[torch.Tensor] = None,
seq_ids: Optional[torch.Tensor] = None,
seq_lens: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
):
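        """ Forward features over a packed token sequence (raw images are packed on the fly).

        Args:
            tokens: (B, N, patch_dim) packed patch tokens, a (B, C, H, W) image batch, or a list
                of variable sized (C, H, W) image tensors.
            pos_indices: (B, N, 2) per-token (h, w) patch grid indices for the factorized pos embed.
            seq_ids: (B, N) per-token image id within each packed sequence, 0 for padding tokens.
            seq_lens: (B, max_images) token counts per packed image, 0 for unused slots.
            attn_mask: optional pre-built attention mask; derived from seq_ids when not provided.
        """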
        if isinstance(tokens, torch.Tensor) and tokens.ndim == 4:
# B, C, H, W batch tensor will be converted to list and packed
# for compatibility with common image model usage (and initial testing)
tokens = tokens.unbind(0)
if isinstance(tokens, (list, tuple)):
tokens, pos_indices, seq_ids, seq_lens = pack_images(
tokens,
self.patch_size,
max_grid_size=self.grid_size,
pad_patches=True,
max_images_per_sequence=4,
)
assert tokens.ndim == 3
assert pos_indices is not None
assert seq_ids is not None
assert seq_lens is not None
tokens = self.patch_embed(tokens)
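        # factorized 2d position embedding (NaViT style): learned row and column embeddings are
        # looked up by each token's (h, w) grid index and summed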
pos_index_h, pos_index_w = pos_indices.unbind(-1)
pos = self.pos_embed_h[pos_index_h] + self.pos_embed_w[pos_index_w]
tokens += pos
tokens = self.pos_drop(tokens)
tokens = self.norm_pre(tokens)
if attn_mask is None:
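            # each token may attend only to tokens from the same image; comparing seq_ids yields a
            # block-diagonal boolean mask over the packed sequence (padding tokens have seq_id 0)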
attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)
# NOTE: not applying key padding mask as padding tokens are already isolated to
# themselves via the above mask (padding has seq_id == 0). Doing an additional
# key padding mask results in fully masked rows which causes numerical issues.
# key_padding_mask = (seq_ids != 0).unsqueeze(1)
# attn_mask = attn_mask & key_padding_mask
attn_mask = attn_mask.unsqueeze(1)
if attn_mask.dtype == torch.bool:
dtype = tokens.dtype
min_val = torch.finfo(dtype).min
attn_mask = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, min_val)
for b in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
tokens = torch.utils.checkpoint.checkpoint(
b, tokens, use_reentrant=False, attn_mask=attn_mask)
else:
tokens = b(tokens, attn_mask=attn_mask)
tokens = self.norm(tokens)
device = tokens.device
max_packing = seq_lens.shape[1]
seq_id_range = torch.arange(1, 1 + max_packing, device=device)
unpack_mask = seq_ids.unsqueeze(1) == seq_id_range[:, None]
seq_lens = seq_lens.reshape(-1)
valid_rows = seq_lens > 0
if self.attn_pool is not None:
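            # attention-pool each packed image separately; the latent query is expanded to one query
            # per packed image slot and unpack_mask restricts each query to its own image's tokens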
# unpack_mask = unpack_mask & key_padding_mask
unpack_mask = unpack_mask.unsqueeze(1)
unpack_mask = torch.zeros_like(unpack_mask, dtype=tokens.dtype).masked_fill_(
~unpack_mask, torch.finfo(tokens.dtype).min)
tokens = self.attn_pool(tokens, attn_mask=unpack_mask)
tokens = tokens.reshape(-1, self.embed_dim)
tokens = tokens[valid_rows]
else:
            tokens = tokens.unsqueeze(1).expand(-1, max_packing, -1, -1)[unpack_mask]
            # split the flat valid-token tensor back into one variable length tensor per image,
            # dropping zero length padding slots so the split boundaries line up with real images
            seq_lens = seq_lens[valid_rows]
            tokens = tokens.tensor_split(seq_lens.cumsum(0)[:-1].cpu())
# tokens = tokens.unsqueeze(1) * unpack_mask.unsqueeze(-1).expand(-1, -1, -1, self.embed_dim)
# tokens = tokens.reshape(-1, tokens.shape[-2], tokens.shape[-1])
# seq_lens = seq_lens[valid_rows]
# tokens = tokens[valid_rows]
# FIXME sort out this mess, the boundary of features vs head is a bit messy with
# variable length sequence averaging vs attention pooling...
return tokens #, seq_lens
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
if isinstance(x, (list, tuple)):
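                # variable length per-image token lists come back from forward_features when packed,
                # average each image's tokens then stack into a batch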
x = torch.stack([t.mean(dim=0) for t in x], 0)
else:
# x = x.sum(dim=1) / seq_lens.reshape(-1, 1)
x = x.mean(dim=1)
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(
self,
tokens: Union[List[torch.Tensor], torch.Tensor],
pos_indices: Optional[torch.Tensor] = None,
seq_ids: Optional[torch.Tensor] = None,
seq_lens: Optional[torch.Tensor] = None,
):
x = self.forward_features(
tokens,
pos_indices=pos_indices,
seq_ids=seq_ids,
seq_lens=seq_lens,
)
x = self.forward_head(x)
return x
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'patch_embed', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
    'navit_medium_patch16_384': _cfg(input_size=(3, 384, 384)),
    'navit_base_patch32_224': _cfg(),
    'navit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
    'navit_base_patch16_224': _cfg(),
    'navit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
    'navit_base_patch16_xp_384': _cfg(input_size=(3, 384, 384)),
})
def _create_vision_transformer_packed(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
return build_model_with_cfg(
VisionTransformerPacked,
variant,
pretrained,
#pretrained_filter_fn=checkpoint_filter_fn,
**kwargs,
)
@register_model
def navit_medium_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(
img_size=384, patch_size=16, embed_dim=512, depth=12, num_heads=8,
fc_norm=False, init_values=1e-5, qkv_bias=False)
model = _create_vision_transformer_packed(
'navit_medium_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def navit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer_packed('navit_base_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def navit_base_patch32_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer_packed('navit_base_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def navit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer_packed('navit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def navit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer_packed('navit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def navit_base_patch16_xp_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
qk_norm=True, pre_norm=True, block_fn=ParallelScalingBlock)
    model = _create_vision_transformer_packed('navit_base_patch16_xp_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model