233 lines
8.4 KiB
Python
233 lines
8.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) 2022 Microsoft
|
|
# Modified from
|
|
# https://github.com/microsoft/unilm/blob/master/beit2/norm_ema_quantizer.py
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from einops import rearrange, repeat
|
|
from mmengine.dist import all_reduce
|
|
|
|
|
|
def ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
                decay: torch.Tensor) -> None:
    """Apply one exponential-moving-average step to ``moving_avg`` in place.

    After the call the stored values are
    ``moving_avg * decay + new * (1 - decay)``. The update is written through
    ``moving_avg.data``.

    Args:
        moving_avg (torch.Tensor): Running average, modified in place.
        new (torch.Tensor): New observation blended into the average.
        decay: Weight kept from the previous running average.
    """
    running = moving_avg.data
    # Shrink the old statistics first, then blend in the new observation.
    running.mul_(decay)
    running.add_(new, alpha=1 - decay)
|
|
|
|
|
|
def norm_ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor,
                     decay: torch.Tensor) -> None:
    """Apply one EMA step to ``moving_avg`` and L2-normalize the result.

    The stored values first become
    ``moving_avg * decay + new * (1 - decay)`` and are then rescaled to unit
    L2 norm along the last dimension. Everything is written in place through
    ``moving_avg.data``.

    Args:
        moving_avg (torch.Tensor): Running average, modified in place.
        new (torch.Tensor): New observation blended into the average.
        decay: Weight kept from the previous running average.
    """
    running = moving_avg.data
    running.mul_(decay)
    running.add_(new, alpha=1 - decay)

    # Re-project the blended vectors onto the unit sphere.
    unit = F.normalize(moving_avg.data, p=2, dim=-1)
    moving_avg.data.copy_(unit)
|
|
|
|
|
|
def sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
    """Randomly pick ``num`` entries along the first dimension of ``samples``.

    When enough entries are available they are drawn without replacement
    (prefix of a random permutation); otherwise indices are drawn uniformly
    with replacement.

    Args:
        samples (torch.Tensor): Tensor to sample from; indexed along dim 0.
        num (int): Number of entries to return.

    Returns:
        torch.Tensor: ``samples`` indexed by the ``num`` drawn indices.
    """
    total = samples.shape[0]
    dev = samples.device

    if total < num:
        # Not enough distinct rows: fall back to sampling with replacement.
        picked = torch.randint(0, total, (num, ), device=dev)
    else:
        picked = torch.randperm(total, device=dev)[:num]

    return samples[picked]
|
|
|
|
|
|
def kmeans(samples: torch.Tensor,
           num_clusters: int,
           num_iters: int = 10,
           use_cosine_sim: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the k-means algorithm.

    Args:
        samples (torch.Tensor): Data to cluster, of shape ``(n, d)``.
        num_clusters (int): Number of centroids to fit.
        num_iters (int): Number of assignment/update rounds. With ``0`` the
            randomly sampled initial centroids are returned together with
            all-zero counts. Defaults to 10.
        use_cosine_sim (bool): If True, assign by largest dot product and
            L2-normalize the centroids after every update (the dot product
            equals cosine similarity only when the rows of ``samples`` are
            unit-norm). If False, assign by smallest squared Euclidean
            distance. Defaults to False.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The centroids of shape
        ``(num_clusters, d)`` and the number of samples assigned to each
        centroid in the last round, of shape ``(num_clusters,)``.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)
    # Initialise the counts so that ``num_iters == 0`` returns a well-formed
    # result instead of hitting an unbound local at the ``return``.
    bins = torch.zeros(
        num_clusters, dtype=torch.long, device=samples.device)

    for _ in range(num_iters):
        # ``dists`` is a similarity score: larger means closer.
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            diffs = (
                rearrange(samples, 'n d -> n () d') -
                rearrange(means, 'c d -> () c d'))
            dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for clusters that received no samples.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Sum the samples of every cluster, then divide by the cluster size.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = F.normalize(new_means, p=2, dim=-1)

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
|
|
|
|
|
|
class EmbeddingEMA(nn.Module):
    """Codebook of embedding vectors stored as a non-trainable parameter.

    The codebook ``weight`` has shape ``(num_tokens, codebook_dim)`` and is
    created with ``requires_grad=False``. The ``initted`` buffer records
    whether the codebook already holds usable values; while it is false,
    :meth:`init_embed_` fills the codebook with k-means centroids.

    Args:
        num_tokens (int): Number of embedding vectors in the codebook.
        codebook_dim (int): The dimension of embedding vectors in the
            codebook.
        kmeans_init (bool): Whether to use k-means to initialize the
            VectorQuantizer. Ignored when ``codebook_init_path`` is given.
            Defaults to True.
        codebook_init_path (str): The initialization checkpoint for codebook.
            Defaults to None.
    """

    def __init__(self,
                 num_tokens: int,
                 codebook_dim: int,
                 kmeans_init: bool = True,
                 codebook_init_path: Optional[str] = None):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim

        if codebook_init_path is not None:
            print(f'load init codebook weight from {codebook_init_path}')
            # NOTE: torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            loaded = torch.load(codebook_init_path, map_location='cpu')
            weight = loaded.clone()
            ready = True
        elif kmeans_init:
            # Placeholder values; overwritten later by ``init_embed_``.
            weight = torch.zeros(num_tokens, codebook_dim)
            ready = False
        else:
            # Random unit-norm vectors.
            weight = F.normalize(
                torch.randn(num_tokens, codebook_dim), p=2, dim=-1)
            ready = True

        self.register_buffer('initted', torch.Tensor([ready]))
        self.weight = nn.Parameter(weight, requires_grad=False)
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data: torch.Tensor) -> None:
        """Fill the codebook with k-means centroids of ``data`` (once).

        Does nothing when the ``initted`` buffer is already set.

        Args:
            data (torch.Tensor): Vectors handed to ``kmeans`` to compute the
                ``num_tokens`` centroids.
        """
        if not self.initted:
            print('Performing K-means init for codebook')
            centroids, _ = kmeans(
                data, self.num_tokens, 10, use_cosine_sim=True)
            self.weight.data.copy_(centroids)
            self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id: torch.Tensor) -> torch.Tensor:
        """Look up the codebook vectors for the given indices.

        Args:
            embed_id (torch.Tensor): Indices into the codebook.

        Returns:
            torch.Tensor: The selected embedding vectors.
        """
        codebook = self.weight
        return F.embedding(embed_id, codebook)
|
|
|
|
|
|
class NormEMAVectorQuantizer(nn.Module):
    """Normed EMA vector quantizer module.

    Inputs are L2-normalized along the channel dimension and mapped to the
    nearest codebook vector. During training the codebook is updated with a
    normalized exponential moving average instead of gradients.

    Args:
        num_embed (int): Number of embedding vectors in the codebook.
        embed_dims (int): The dimension of embedding vectors in the codebook.
        beta (float): The multiplier for the VectorQuantizer embedding loss.
        decay (float): The decay parameter of EMA. Defaults to 0.99.
        statistic_code_usage (bool): Whether to use cluster_size to record
            statistic. When False no ``cluster_size`` buffer is created and
            no usage statistics are tracked. Defaults to True.
        kmeans_init (bool): Whether to use k-means to initialize the
            VectorQuantizer. Defaults to True.
        codebook_init_path (str): The initialization checkpoint for codebook.
            Defaults to None.
    """

    def __init__(self,
                 num_embed: int,
                 embed_dims: int,
                 beta: float,
                 decay: float = 0.99,
                 statistic_code_usage: bool = True,
                 kmeans_init: bool = True,
                 codebook_init_path: Optional[str] = None) -> None:
        super().__init__()
        self.codebook_dim = embed_dims
        self.num_tokens = num_embed
        self.beta = beta
        self.decay = decay

        self.embedding = EmbeddingEMA(
            num_tokens=self.num_tokens,
            codebook_dim=self.codebook_dim,
            kmeans_init=kmeans_init,
            codebook_init_path=codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(num_embed))

    def reset_cluster_size(self, device):
        """Reset the code-usage statistics to zero on ``device``.

        Does nothing when ``statistic_code_usage`` is disabled.

        Args:
            device: Device on which the fresh ``cluster_size`` buffer lives.
        """
        if self.statistic_code_usage:
            self.register_buffer(
                'cluster_size', torch.zeros(self.num_tokens, device=device))

    def forward(self, z):
        """Forward function.

        Args:
            z (torch.Tensor): Feature map of shape ``(b, c, h, w)`` where
                ``c`` equals the codebook dimension.

        Returns:
            tuple: ``(z_q, loss, encoding_indices)`` where ``z_q`` is the
            quantized feature map of shape ``(b, c, h, w)`` (with a
            straight-through gradient to ``z``), ``loss`` is the scalar
            embedding loss scaled by ``beta`` and ``encoding_indices`` holds
            the selected codebook index for each of the ``b * h * w``
            positions.
        """
        # reshape z -> (batch, height, width, channel)
        z = rearrange(z, 'b c h w -> b h w c')
        z = F.normalize(z, p=2, dim=-1)
        z_flattened = z.reshape(-1, self.codebook_dim)

        # Lazily initialise the codebook (no-op once it is initted).
        self.embedding.init_embed_(z_flattened)

        # Squared Euclidean distance between every input vector and every
        # codebook vector: |z|^2 + |e|^2 - 2 * z.e, shape (b*h*w, num_tokens)
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)

        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        if not self.training:
            # ``cluster_size`` only exists when usage statistics are enabled.
            if self.statistic_code_usage:
                with torch.no_grad():
                    cluster_size = encodings.sum(0)
                    all_reduce(cluster_size)
                    ema_inplace(self.cluster_size, cluster_size, self.decay)
        elif self.embedding.update:
            # The codebook is updated through ``.data`` only, so no autograd
            # graph is needed for any of the statistics below.
            with torch.no_grad():
                # per-code usage counts, aggregated with all_reduce
                bins = encodings.sum(0)
                all_reduce(bins)
                if self.statistic_code_usage:
                    # update cluster size with EMA
                    ema_inplace(self.cluster_size, bins, self.decay)

                zero_mask = (bins == 0)
                # Avoid dividing by zero for codes that were never selected.
                bins = bins.masked_fill(zero_mask, 1.)

                # Per-code sum of assigned vectors, shape (dim, num_tokens).
                embed_sum = z_flattened.t() @ encodings
                all_reduce(embed_sum)

                embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
                embed_normalized = F.normalize(embed_normalized, p=2, dim=-1)
                # Unused codes keep their current embedding vector.
                embed_normalized = torch.where(zero_mask[..., None],
                                               self.embedding.weight,
                                               embed_normalized)

                # Update embedding vectors with EMA
                norm_ema_inplace(self.embedding.weight, embed_normalized,
                                 self.decay)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w')
        return z_q, loss, encoding_indices
|