Merge b82f9e869c into c8c4f256b8 (commit 61b7922f6b)
@@ -45,6 +45,7 @@ from .nasnet import *
from .nest import *
from .nextvit import *
from .nfnet import *
from .pe import *
from .pit import *
from .pnasnet import *
from .pvt_v2 import *

@@ -0,0 +1,897 @@
from collections import OrderedDict
from functools import partial
from math import pi  # needed by the "pixel" frequency branch in RotaryEmbedding below
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Literal

import torch
import torch.nn as nn
from einops import rearrange, repeat
from torch import nn, Tensor, broadcast_tensors, einsum
from torch.nn import functional as F
from torch.nn import Module, ModuleList
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.amp import autocast
from torch.utils.checkpoint import checkpoint

### Import timm layers
from timm.layers import (
    PatchEmbed,
    Mlp,
    DropPath,
    AttentionPoolLatent,
    RmsNorm,
    PatchDropout,
    SwiGLUPacked,
    SwiGLU,
    trunc_normal_,
    lecun_normal_,
    resample_patch_embed,
    resample_abs_pos_embed,
    use_fused_attn,
    get_act_layer,
    get_norm_layer,
    LayerType,
    LayerScale,
)

# from timm.layers import RotaryEmbeddingCat, RotaryEmbedding  # not compatible
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._registry import generate_default_cfgs, register_model, register_model_deprecations


__all__ = ['PE']


######## PE's Rope ########
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    dtype = t.dtype

    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t_left, t, t_right), dim=-1)

    return out.type(dtype)
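
# Editor's illustration (not part of the original diff): rotate_half() pairs adjacent
# channels and rotates each pair by 90 degrees, so [x0, x1, x2, x3] -> [-x1, x0, -x3, x2].
# apply_rotary_emb() then forms t * cos(freqs) + rotate_half(t) * sin(freqs) over the
# first rot_dim channels; e.g. with all-zero angles the rotation is a no-op:
#
#     freqs = torch.zeros(8, 16)                    # hypothetical angles for 8 positions
#     t = torch.randn(2, 4, 8, 16)                  # (batch, heads, seq, head_dim)
#     assert torch.allclose(apply_rotary_emb(freqs, t), t)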


class RotaryEmbedding(Module):
    def __init__(
        self,
        dim,
        custom_freqs: Optional[Tensor] = None,
        freqs_for: Union[Literal["lang"], Literal["pixel"], Literal["constant"]] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()

        self.cache_if_possible = cache_if_possible

        self.tmp_store("cached_freqs", None)
        self.tmp_store("cached_scales", None)

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy for device
        self.tmp_store("dummy", torch.tensor(0))

        # default sequence dimension
        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors
        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # xpos
        self.use_xpos = use_xpos
        if not use_xpos:
            self.tmp_store("scale", None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = xpos_scale_base
        self.tmp_store("scale", scale)

        # add apply_rotary_emb as static method
        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        return self.dummy.device

    def tmp_store(self, key, value):
        self.register_buffer(key, value, persistent=False)

    def get_seq_pos(self, seq_len, device, dtype, offset=0):
        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert (
            not self.use_xpos
        ), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        freqs = self.forward(
            self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset),
            seq_len=seq_len,
            offset=offset,
        )

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        seq_dim = default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
        assert self.use_xpos

        should_cache = self.cache_if_possible and exists(seq_len)

        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = torch.cat((scale, scale), dim=-1)

        if should_cache:
            self.tmp_store("cached_scales", scale)

        return scale

    def get_axial_freqs(self, *dims):
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            freqs = self.forward(pos, seq_len=dim)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len=None, offset=0):
        should_cache = (
            self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel"
        )

        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]:
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache:
            self.tmp_store("cached_freqs", freqs.detach())

        return freqs
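
# Editor's note: the 2D RoPE helper below only relies on RotaryEmbedding.forward(),
# which maps integer positions to per-channel angles with each angle repeated in pairs,
# e.g. (sketch with illustrative sizes):
#
#     rope = RotaryEmbedding(32)                    # 32 rotary channels per axis
#     angles = rope(torch.arange(16))               # -> shape (16, 32)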


class Rope2D:
    """Helper class to apply RoPE2D as well as interpolate on the fly."""

    def __init__(self, dim, use_cls_token=False):
        self.dim = dim
        self.use_cls_token = use_cls_token
        self.grid_size = None
        self.freq = None

    def init_tensors(self):
        self.rope = RotaryEmbedding(self.dim // 2)

    def update_grid(self, device, grid_h, grid_w):
        if self.grid_size != (grid_h, grid_w):
            self.grid_size = (grid_h, grid_w)

            self.rope = self.rope.to(device)
            if self.use_cls_token:
                # +1 to leave space for the cls token to be (0, 0)
                grid_y_range = torch.arange(grid_h, device=device) + 1
                grid_x_range = torch.arange(grid_w, device=device) + 1
            else:
                grid_y_range = torch.arange(grid_h, device=device)
                grid_x_range = torch.arange(grid_w, device=device)
            freqs_y = self.rope(grid_y_range)[:, None].expand(grid_h, grid_w, -1)
            freqs_x = self.rope(grid_x_range)[None, :].expand(grid_h, grid_w, -1)
            freq = torch.cat([freqs_x, freqs_y], dim=-1).reshape(grid_h * grid_w, -1)

            if self.use_cls_token:
                freq = torch.cat([freq, torch.zeros(1, freq.shape[-1], device=device)], dim=0)

            self.freq = freq[None, ...]

        self.freq = self.freq.to(device)

    def __call__(self, q, k):
        # batch, heads, seq, dim = q.shape
        q = apply_rotary_emb(self.freq[:, None, :, :], q)
        k = apply_rotary_emb(self.freq[:, None, :, :], k)

        return q, k
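
# Editor's usage sketch: PE.forward_features() calls rope.update_grid(device, grid_h, grid_w)
# once per input resolution to (re)build the per-token frequency table, and SelfAttention
# then rotates its per-head queries/keys of shape (batch, heads, seq, head_dim) via
#     q, k = self.rope(q, k)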


######## PE Modules ########
class AttentionPooling(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_probe: int = 1,
        mlp_ratio: int = 4,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads

        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.probe = nn.Parameter(torch.randn(1, num_probe, self.embed_dim))
        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads, batch_first=True)

        self.layernorm = norm_layer(embed_dim)
        self.mlp_width = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(self.embed_dim, self.mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(self.mlp_width, self.embed_dim)),
                ]
            )
        )

    def forward(self, x: torch.Tensor):
        batch, _, _ = x.shape

        q = self.probe.repeat((batch, 1, 1)).to(x.dtype)
        x = self.attn(q, x, x, need_weights=False)[0]
        x = x + self.mlp(self.layernorm(x))

        return x
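
# Editor's shape note: given x of shape (batch, seq, embed_dim), the learned probe
# (1, num_probe, embed_dim) is broadcast across the batch and used as the attention
# query, so the output is (batch, num_probe, embed_dim); PE._pool() later squeezes the
# default num_probe == 1 away to get (batch, embed_dim).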


class SelfAttention(nn.Module):
    r"""
    Implements sequence-packed attention and RoPE.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        rope: Optional[nn.Module] = None,
    ):
        super(SelfAttention, self).__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # To make this compatible with nn.MultiheadAttention
        self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        self.rope = rope
        self.scale = self.head_dim ** (-0.5)

    def init_tensors(self):
        xavier_uniform_(self.in_proj_weight)
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    def forward(self, x, attn_mask=None):
        batch, seq, embed_dim = x.shape
        proj = F.linear(x, self.in_proj_weight, self.in_proj_bias)

        # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
        proj = proj.unflatten(-1, (3, embed_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = proj[0], proj[1], proj[2]

        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)

        if self.rope:
            q, k = self.rope(q, k)

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False, scale=self.scale)
        attn = rearrange(attn, "b h s d -> b s (h d)")

        return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
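
# Editor's minimal sketch (illustrative values, no RoPE):
#
#     attn = SelfAttention(embed_dim=64, num_heads=4)
#     attn.init_tensors()                           # projections start as torch.empty until initialized
#     y = attn(torch.randn(2, 10, 64))              # -> (2, 10, 64)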


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()

        if rope:
            self.attn = SelfAttention(d_model, n_head, rope=rope)
        else:
            self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)

        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                ]
            )
        )

    def _call_attn(
        self,
        q_x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):

        if attn_mask is not None:
            # Leave boolean masks as is
            if not attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.to(q_x.dtype)

        if isinstance(self.attn, SelfAttention):
            return self.attn(q_x, attn_mask=attn_mask)
        else:
            return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        x = x + self.drop_path1(self.ls_1(self._call_attn(self.ln_1(x), attn_mask=attn_mask)))
        x = x + self.drop_path2(self.ls_2(self.mlp(self.ln_2(x))))
        return x
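
# Editor's note: this is a standard pre-norm residual block,
#     x = x + DropPath(LayerScale(Attn(LN_1(x))))
#     x = x + DropPath(LayerScale(MLP(LN_2(x))))
# where LayerScale and DropPath reduce to nn.Identity when ls_init_value / drop_path
# are left at their defaults.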


class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: float = None,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        drop_path: float = 0.0,
        rope: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    rope=rope,
                )
                for _ in range(layers)
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def truncate(self, layer_idx: int):
        """Delete layers so the last layer is the given layer index."""
        self.layers = ((self.layers + layer_idx) % self.layers) + 1
        self.resblocks = nn.ModuleList(self.resblocks[: self.layers])

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        layer_idx: int = -1,
    ):
        stop_idx = (self.layers + layer_idx) % self.layers

        for i, r in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)

            if i == stop_idx:
                break

        return x
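
# Editor's note on layer indexing: truncate(layer_idx) keeps blocks up to and including
# the (possibly negative) index, while forward(..., layer_idx=-1) runs every remaining
# block. Sketch with illustrative sizes:
#
#     t = Transformer(width=64, layers=4, heads=4)
#     t.truncate(-2)                                # keeps the first 3 of 4 blocks
#     y = t(torch.randn(2, 10, 64))                 # -> (2, 10, 64)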


class PE(nn.Module):
    def __init__(
        self,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = partial(nn.LayerNorm, eps=1e-5),
        use_ln_pre: bool = True,
        use_ln_post: bool = True,
        ls_init_value: float = None,
        drop_path: float = 0.0,
        image_size: int = 448,  # Pretrain image size only; you can pass in any image size
        use_abs_posemb: bool = True,
        use_rope2d: bool = True,
        use_cls_token: bool = False,
        output_dim: Optional[int] = 1280,
        attn_pooler_heads: int = 8,
        pool_type: Literal["attn", "tok", "avg", "none"] = "attn",
        num_classes: int = 1000,  # unused for now
        in_chans: int = 3,
    ):
        super().__init__()
        assert pool_type in ("attn", "tok", "avg", "none")
        self.pool_type = pool_type
        self.patch_size = patch_size

        self.output_dim = output_dim or width
        self.proj_dim = output_dim
        self.heads = heads
        self.width = width
        self.layers = layers

        self.use_abs_posemb = use_abs_posemb
        self.use_cls_token = use_cls_token
        self.use_rope2d = use_rope2d
        self.image_size = image_size

        self.conv1 = nn.Conv2d(
            in_channels=in_chans,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.rope = (
            Rope2D(
                dim=width // heads,
                use_cls_token=self.use_cls_token,
            )
            if self.use_rope2d
            else None
        )

        self.ln_pre = norm_layer(width) if use_ln_pre else nn.Identity()
        self.ln_post = norm_layer(self.width) if use_ln_post else nn.Identity()

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop_path=drop_path,
            rope=self.rope,
        )

        if pool_type == "attn":
            self.attn_pool = AttentionPooling(
                embed_dim=width,
                num_heads=attn_pooler_heads,
                act_layer=act_layer,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None

        self.init_tensors()

    def init_tensors(self):
        def init_submodule_tensors(module):
            for name, child in module.named_children():
                if hasattr(child, "init_tensors"):
                    # logger.debug(f"Initializing tensors for submodule: {name}")
                    child.init_tensors()
                init_submodule_tensors(child)

        init_submodule_tensors(self)
        if self.rope is not None:
            self.rope.init_tensors()

        # class embeddings and positional embeddings
        init_scale = self.width**-0.5

        if self.use_cls_token:
            self.class_embedding = nn.Parameter(init_scale * torch.randn(self.width))

        if self.use_abs_posemb:
            self.posemb_grid_size = self.image_size // self.patch_size
            self.positional_embedding = nn.Parameter(
                init_scale * torch.randn(int(self.use_cls_token) + self.posemb_grid_size**2, self.width)
            )

        if self.proj_dim is not None:
            self.proj = nn.Parameter(init_scale * torch.randn(self.width, self.proj_dim))

    def truncate(self, layer_idx: int):
        """Delete layers so the last layer is the given layer index."""
        self.transformer.truncate(layer_idx)
        self.layers = self.transformer.layers

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.transformer.set_grad_checkpointing(enable=enable)

    def _sample_abs_posemb(self, grid_h: int, grid_w: int):
        """Interpolates the absolute position embedding if necessary."""
        if self.posemb_grid_size == grid_h and self.posemb_grid_size == grid_w:
            return self.positional_embedding[None, ...]

        pos_embed = self.positional_embedding
        if self.use_cls_token:
            cls_token_embed, pos_embed = pos_embed[:1], pos_embed[1:]

        pos_embed = (
            pos_embed.reshape(1, self.posemb_grid_size, self.posemb_grid_size, -1).permute(0, 3, 1, 2).contiguous()
        )
        pos_embed = F.interpolate(pos_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
        pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(-1, self.width).contiguous()

        if self.use_cls_token:
            pos_embed = torch.cat([cls_token_embed, pos_embed], dim=0)

        return pos_embed[None, ...]

    def _pool(self, x: torch.Tensor):
        if self.pool_type == "tok":
            return x[:, 0]
        elif self.pool_type == "avg":
            return x.mean(dim=1)
        elif self.pool_type == "attn":
            return self.attn_pool(x).squeeze(1)
        elif self.pool_type == "none":
            return x
        else:
            raise NotImplementedError

    def forward_features(self, x: torch.Tensor, norm: bool = False, layer_idx: int = -1, strip_cls_token: bool = False):
        batch, _, h, w = x.shape
        grid_h, grid_w = h // self.patch_size, w // self.patch_size

        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1).reshape(batch, -1, self.width)

        if self.use_cls_token:
            x = torch.cat(
                [self.class_embedding.view(1, 1, -1).expand(batch, -1, -1), x],
                dim=1,
            )

        if self.use_abs_posemb:
            x = x + self._sample_abs_posemb(grid_h, grid_w)

        if self.use_rope2d:
            self.rope.update_grid(x.device, grid_h, grid_w)

        x = self.ln_pre(x)
        x = self.transformer(x, layer_idx=layer_idx)

        if norm:
            x = self.ln_post(x)

        if strip_cls_token and self.use_cls_token:
            x = x[:, 1:, :]

        return x

    def forward(self, x: torch.Tensor, **kwargs):
        x = self.forward_features(x, norm=True, **kwargs)
        x = self._pool(x)

        if self.proj_dim is not None:
            x = x @ self.proj

        return x
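
# Editor's shape sketch with illustrative hyper-parameters (roughly the "core base" config):
#
#     pe = PE(patch_size=16, width=768, layers=12, heads=12, mlp_ratio=4.0,
#             output_dim=1024, use_cls_token=True, image_size=224)
#     tokens = pe.forward_features(torch.randn(1, 3, 224, 224), norm=True)   # (1, 1 + 14 * 14, 768)
#     pooled = pe(torch.randn(1, 3, 224, 224))                               # (1, 1024) after attn pool + proj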


def checkpoint_filter_fn(
    state_dict: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    state_dict = state_dict.get('model', state_dict)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    if any(k.startswith("visual.") for k in state_dict):
        state_dict = {k.replace("visual.", ""): v for k, v in state_dict.items() if "visual" in k}
    return state_dict
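
# Editor's note: this remaps original PE checkpoints, e.g. an optional 'model' wrapper and
# 'module.' prefixes are stripped, and keys such as "visual.conv1.weight" become
# "conv1.weight" while keys outside the visual tower are dropped.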


######## PE Config ########
def _cfg(url='', **kwargs):
    return {
        'license': 'apache-2.0',
        'num_classes': 0,
        'interpolation': 'bilinear',
        'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN,
        'std': IMAGENET_INCEPTION_STD,
        **kwargs,
    }


default_cfgs = generate_default_cfgs(
    {
        'vit_pe_core_base_patch16_224': _cfg(hf_hub_id='timm/', input_size=(3, 224, 224)),
        'vit_pe_core_large_patch14_336': _cfg(hf_hub_id='timm/', input_size=(3, 336, 336)),
        'vit_pe_core_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
        'vit_pe_lang_large_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
        'vit_pe_lang_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
        'vit_pe_spatial_gigantic_patch14_448': _cfg(hf_hub_id='timm/', input_size=(3, 448, 448)),
    }
)


def _create_pe(variant: str, pretrained: bool = False, **kwargs) -> PE:
    out_indices = kwargs.pop('out_indices', 3)
    return build_model_with_cfg(
        PE,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_strict=True,
        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )


@register_model
def vit_pe_core_base_patch16_224(pretrained=False, **kwargs):
    model_args = dict(
        image_size=224,
        patch_size=16,
        width=768,
        layers=12,
        heads=12,
        mlp_ratio=4.0,
        output_dim=1024,
        use_cls_token=True,
        pool_type='attn',
    )
    return _create_pe('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def vit_pe_core_large_patch14_336(pretrained=False, **kwargs):
    model_args = dict(
        image_size=336,
        patch_size=14,
        width=1024,
        layers=24,
        heads=16,
        mlp_ratio=4.0,
        output_dim=1024,
        use_cls_token=True,
        pool_type='attn',
    )
    return _create_pe('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def vit_pe_core_gigantic_patch14_448(pretrained=False, **kwargs):
    model_args = dict(
        image_size=448,
        patch_size=14,
        width=1536,
        layers=50,
        heads=16,
        mlp_ratio=8960 / 1536,
        output_dim=1280,
        use_cls_token=False,
        pool_type='attn',
    )
    return _create_pe('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def vit_pe_lang_large_patch14_448(pretrained=False, **kwargs):
    model_args = dict(
        image_size=448,
        patch_size=14,
        width=1024,
        layers=23,
        heads=16,
        mlp_ratio=4.0,
        output_dim=None,
        use_cls_token=True,
        use_ln_post=False,
        pool_type='none',
        ls_init_value=0.1,
    )
    return _create_pe('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def vit_pe_lang_gigantic_patch14_448(pretrained=False, **kwargs):
    model_args = dict(
        image_size=448,
        patch_size=14,
        width=1536,
        layers=47,
        heads=16,
        mlp_ratio=8960 / 1536,
        output_dim=None,
        use_cls_token=False,
        use_ln_post=False,
        pool_type='none',
        ls_init_value=0.1,
    )
    return _create_pe('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
    model_args = dict(
        image_size=448,
        patch_size=14,
        width=1536,
        layers=50,
        heads=16,
        mlp_ratio=8960 / 1536,
        output_dim=None,
        use_cls_token=False,
        use_ln_post=False,
        pool_type='none',
        ls_init_value=0.1,
    )
    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
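
# Editor's usage sketch (illustrative, via the timm registry):
#
#     import timm
#     model = timm.create_model('vit_pe_core_base_patch16_224', pretrained=False)
#     feats = model(torch.randn(1, 3, 224, 224))    # -> (1, 1024) pooled, projected features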