""" Sin-cos, fourier, rotary position embedding modules and functions
|
|
|
|
Hacked together by / Copyright 2022 Ross Wightman
|
|
"""
|
|
import math
|
|
from typing import List, Tuple, Optional, Union
|
|
|
|
import torch
|
|
from torch import nn as nn
|
|
|
|
from .grid import ndgrid
|
|
from .trace_utils import _assert
|
|
|
|
|
|


def pixel_freq_bands(
        num_bands: int,
        max_freq: float = 224.,
        linear_bands: bool = True,
        device: Optional[torch.device] = None,
):
    if linear_bands:
        bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
    else:
        bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
    return bands * torch.pi


def freq_bands(
        num_bands: int,
        temperature: float = 10000.,
        step: int = 2,
        device: Optional[torch.device] = None,
) -> torch.Tensor:
    exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
    bands = 1. / (temperature ** exp)
    return bands
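
# Illustrative values for the two band builders above (derived from their definitions):
#
#   pixel_freq_bands(8, max_freq=224.)         # 8 linearly spaced bands in [pi, 112 * pi]
#   freq_bands(8, temperature=10000., step=1)  # 1 / 10000 ** (i / 8) for i = 0..7, the
#                                              # standard language-model RoPE inverse frequencies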


def build_sincos2d_pos_embed(
        feat_shape: List[int],
        dim: int = 64,
        temperature: float = 10000.,
        reverse_coord: bool = False,
        interleave_sin_cos: bool = False,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None
) -> torch.Tensor:
    """ Build a fixed 2D sin-cos position embedding.

    Args:
        feat_shape: Feature map shape (H, W) to generate the embedding for.
        dim: Embedding dimension, must be divisible by 4.
        temperature: Frequency scaling temperature for the sin-cos bands.
        reverse_coord: stack grid order W, H instead of H, W
        interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
        dtype: Output dtype.
        device: Output device.

    Returns:
        Position embedding tensor of shape (feat_shape[0] * feat_shape[1], dim).
    """
    assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
    pos_dim = dim // 4
    bands = freq_bands(pos_dim, temperature=temperature, step=1, device=device)

    if reverse_coord:
        feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
    grid = torch.stack(ndgrid([
        torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
        for s in feat_shape
    ])).flatten(1).transpose(0, 1)
    pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
    # FIXME add support for unflattened spatial dim?

    stack_dim = 2 if interleave_sin_cos else 1  # stack sin, cos, sin, cos instead of sin, sin, cos, cos
    pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
    return pos_emb.to(dtype=dtype)
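
# Usage sketch: a 7x7 feature map with dim=64 yields one row per flattened position.
#
#   pos_emb = build_sincos2d_pos_embed([7, 7], dim=64)
#   assert pos_emb.shape == (49, 64)  # (H * W, dim)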


def build_fourier_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        num_bands: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        include_grid: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """ Build Fourier (sin/cos frequency) position embedding tensors.

    Args:
        feat_shape: Feature shape for embedding.
        bands: Pre-calculated frequency bands.
        num_bands: Number of frequency bands (determines output dim).
        max_res: Maximum resolution for pixel based freq.
        temperature: Temperature for non-pixel freq.
        linear_bands: Linear band spacing for pixel based freq.
        include_grid: Include the spatial grid in output.
        in_pixels: Output in pixel freq.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        dtype: Output dtype.
        device: Output device.

    Returns:
        List of tensors, [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos],
        with pos_sin / pos_cos shaped feat_shape + (len(feat_shape), num_bands).
    """
    if bands is None:
        if in_pixels:
            bands = pixel_freq_bands(
                num_bands,
                float(max_res),
                linear_bands=linear_bands,
                device=device,
            )
        else:
            bands = freq_bands(
                num_bands,
                temperature=temperature,
                step=1,
                device=device,
            )
    else:
        if device is None:
            device = bands.device
        if dtype is None:
            dtype = bands.dtype

    if in_pixels:
        t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
    else:
        t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]

    if ref_feat_shape is not None:
        # eva's scheme for resizing rope embeddings (ref shape = pretrain)
        t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]

    grid = torch.stack(ndgrid(t), dim=-1)
    grid = grid.unsqueeze(-1)
    pos = grid * bands

    pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype=dtype)
    out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
    return out
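
# Usage sketch: default (pixel) mode with a 7x7 feature map and num_bands=8 returns
# sin/cos tensors shaped feat_shape + (num_spatial_dims, num_bands).
#
#   pos_sin, pos_cos = build_fourier_pos_embed([7, 7], num_bands=8)
#   assert pos_sin.shape == (7, 7, 2, 8)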


class FourierEmbed(nn.Module):

    def __init__(
            self,
            max_res: int = 224,
            num_bands: int = 64,
            concat_grid: bool = True,
            keep_spatial: bool = False,
    ):
        super().__init__()
        self.max_res = max_res
        self.num_bands = num_bands
        self.concat_grid = concat_grid
        self.keep_spatial = keep_spatial
        self.register_buffer(
            'bands',
            pixel_freq_bands(num_bands, float(max_res)),
            persistent=False,
        )

    def forward(self, x):
        B, C = x.shape[:2]
        feat_shape = x.shape[2:]
        emb = build_fourier_pos_embed(
            feat_shape,
            self.bands,
            include_grid=self.concat_grid,
            dtype=x.dtype,
            device=x.device,
        )
        emb = torch.cat(emb, dim=-1)
        emb = emb.transpose(-1, -2).flatten(len(feat_shape))
        batch_expand = (B,) + (-1,) * (x.ndim - 1)

        # FIXME support nD
        if self.keep_spatial:
            x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
        else:
            x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
            x = x.reshape(B, feat_shape.numel(), -1)

        return x
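
# Usage sketch: with keep_spatial=False, NCHW input comes back as
# (B, H * W, C + fourier_channels); with the defaults (num_bands=64, concat_grid=True)
# fourier_channels = 2 * (2 * 64 + 1) = 258 (sin + cos bands per spatial dim, plus the grid).
#
#   embed = FourierEmbed(max_res=224, num_bands=64)
#   out = embed(torch.randn(2, 32, 7, 7))
#   assert out.shape == (2, 49, 32 + 258)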


def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
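
# rot() rotates consecutive (even, odd) channel pairs by 90 degrees: for a last dim of
# [x0, x1, x2, x3] it returns [-x1, x0, -x3, x2]. Weighted by cos/sin in the apply_*
# helpers below, this realizes the 2x2 rotation matrices of rotary position embedding.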


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    if sin_emb.ndim == 3:
        return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


def apply_rot_embed_cat(x: torch.Tensor, emb):
    sin_emb, cos_emb = emb.tensor_split(2, -1)
    if sin_emb.ndim == 3:
        return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
    return x * cos_emb + rot(x) * sin_emb
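
# Usage sketch: these helpers broadcast a (num_pos, dim) sin/cos table over any
# leading dims, so q/k shaped (batch, heads, num_pos, dim) can be rotated directly.
#
#   sin_emb, cos_emb = build_rotary_pos_embed([7, 7], dim=32)   # each (49, 32)
#   q = torch.randn(2, 4, 49, 32)
#   q = apply_rot_embed(q, sin_emb, cos_emb)                            # separate tables
#   q = apply_rot_embed_cat(q, torch.cat([sin_emb, cos_emb], dim=-1))   # concatenated table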


def apply_keep_indices_nlc(x, pos_embed, keep_indices):
    # expand pos_embed over the batch, then gather the rows for the kept token indices
    pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
    pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
    return pos_embed
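
# Usage sketch (MAE-style token masking): given x (batch, seq, ...), pos_embed
# (num_pos, dim), and keep_indices (batch, num_keep), this returns the per-sample
# position embeddings for the kept tokens, shape (batch, num_keep, dim).
#
#   kept_pos = apply_keep_indices_nlc(x, pos_embed, keep_indices)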


def build_rotary_pos_embed(
        feat_shape: List[int],
        bands: Optional[torch.Tensor] = None,
        dim: int = 64,
        max_res: int = 224,
        temperature: float = 10000.,
        linear_bands: bool = False,
        in_pixels: bool = True,
        ref_feat_shape: Optional[List[int]] = None,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
):
    """ Build rotary (RoPE) sin/cos position embedding tables.

    Args:
        feat_shape: Spatial shape of the target tensor for embedding.
        bands: Optional pre-generated frequency bands.
        dim: Output dimension of embedding tensor.
        max_res: Maximum resolution for pixel mode.
        temperature: Temperature (inv freq) for non-pixel mode.
        linear_bands: Linearly (instead of log) spaced bands for pixel mode.
        in_pixels: Pixel vs language (inv freq) mode.
        ref_feat_shape: Reference feature shape for resize / fine-tune.
        dtype: Output dtype.
        device: Output device.

    Returns:
        Tuple of (sin, cos) embedding tensors, each of shape (num_positions, dim)
        for a 2D feat_shape with default bands.
    """
    sin_emb, cos_emb = build_fourier_pos_embed(
        feat_shape,
        bands=bands,
        num_bands=dim // 4,
        max_res=max_res,
        temperature=temperature,
        linear_bands=linear_bands,
        in_pixels=in_pixels,
        ref_feat_shape=ref_feat_shape,
        device=device,
        dtype=dtype,
    )
    num_spatial_dim = 1
    # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
    for x in feat_shape:
        num_spatial_dim *= x
    sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
    return sin_emb, cos_emb
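
# Usage sketch: for a 2D feat_shape and default bands, each returned table has one row
# per flattened position and width dim (dim // 4 bands x 2 spatial dims, each value
# repeated twice for the interleaved channel pairs consumed by rot()).
#
#   sin_emb, cos_emb = build_rotary_pos_embed([7, 7], dim=64)
#   assert sin_emb.shape == cos_emb.shape == (49, 64)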


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at implementing rotary embedding for spatial use;
    it has not been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands,
                persistent=False,
            )
            self.pos_embed_sin = None
            self.pos_embed_cos = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            emb_sin, emb_cos = build_rotary_pos_embed(
                feat_shape=feat_shape,
                dim=dim,
                max_res=max_res,
                temperature=temperature,
                linear_bands=linear_bands,
                in_pixels=in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            self.bands = None
            self.register_buffer(
                'pos_embed_sin',
                emb_sin,
                persistent=False,
            )
            self.register_buffer(
                'pos_embed_cos',
                emb_cos,
                persistent=False,
            )

    def get_embed(self, shape: Optional[List[int]] = None):
        if self.bands is not None:
            # rebuild embeddings every call, use if target shape changes
            assert shape is not None
            return build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
        else:
            return self.pos_embed_sin, self.pos_embed_cos

    def forward(self, x):
        # assuming channel-first tensor where spatial dims start at index 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
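
# Usage sketch: dim is the per-head channel dim (divisible by 4); without feat_shape the
# module caches only the bands and rebuilds the tables per call via get_embed().
#
#   rope = RotaryEmbedding(dim=32)
#   sin_emb, cos_emb = rope.get_embed((7, 7))        # each (49, 32)
#   q = apply_rot_embed(torch.randn(2, 4, 49, 32), sin_emb, cos_emb)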


class RotaryEmbeddingCat(nn.Module):
    """ Rotary position embedding w/ concatenated sin & cos

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """

    def __init__(
            self,
            dim,
            max_res=224,
            temperature=10000,
            in_pixels=True,
            linear_bands: bool = False,
            feat_shape: Optional[List[int]] = None,
            ref_feat_shape: Optional[List[int]] = None,
    ):
        super().__init__()
        self.dim = dim
        self.max_res = max_res
        self.temperature = temperature
        self.in_pixels = in_pixels
        self.feat_shape = feat_shape
        self.ref_feat_shape = ref_feat_shape

        if feat_shape is None:
            # only cache bands
            if in_pixels:
                bands = pixel_freq_bands(
                    dim // 4,
                    float(max_res),
                    linear_bands=linear_bands,
                )
            else:
                bands = freq_bands(
                    dim // 4,
                    temperature=temperature,
                    step=1,
                )
            self.register_buffer(
                'bands',
                bands,
                persistent=False,
            )
            self.pos_embed = None
        else:
            # cache full sin/cos embeddings if shape provided up front
            embeds = build_rotary_pos_embed(
                feat_shape=feat_shape,
                dim=dim,
                max_res=max_res,
                temperature=temperature,
                linear_bands=linear_bands,
                in_pixels=in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            self.bands = None
            self.register_buffer(
                'pos_embed',
                torch.cat(embeds, -1),
                persistent=False,
            )

    def get_embed(self, shape: Optional[List[int]] = None):
        if self.bands is not None and shape is not None:
            # rebuild embeddings every call, use if target shape changes
            embeds = build_rotary_pos_embed(
                shape,
                self.bands,
                in_pixels=self.in_pixels,
                ref_feat_shape=self.ref_feat_shape,
            )
            return torch.cat(embeds, -1)
        elif self.pos_embed is not None:
            return self.pos_embed
        else:
            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"

    def forward(self, x):
        # assuming channel-first tensor where spatial dims start at index 2
        pos_embed = self.get_embed(x.shape[2:])
        return apply_rot_embed_cat(x, pos_embed)
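
# Usage sketch: same pattern as RotaryEmbedding, but get_embed() returns a single
# (num_pos, 2 * dim) tensor with sin and cos concatenated on the last dim, for
# consumption by apply_rot_embed_cat().
#
#   rope = RotaryEmbeddingCat(dim=32, feat_shape=[7, 7])
#   emb = rope.get_embed()                           # (49, 64), cached at init
#   q = apply_rot_embed_cat(torch.randn(2, 4, 49, 32), emb)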