[Refactor] Re-write get_sinusoid_encoding from third-party implementation. (#965)

pull/1034/head
Kai Hu 2022-09-13 03:24:29 -04:00 committed by GitHub
parent 6d8c91892c
commit 0b4a67dd31
2 changed files with 40 additions and 12 deletions


@@ -218,27 +218,24 @@ def get_sinusoid_encoding(n_position, embed_dims):
     Sinusoid encoding is a kind of relative position encoding method that
     comes from `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
 
     Args:
         n_position (int): The length of the input token.
         embed_dims (int): The position embedding dimension.
 
     Returns:
         :obj:`torch.FloatTensor`: The sinusoid encoding table.
     """
-
-    def get_position_angle_vec(position):
-        return [
-            position / np.power(10000, 2 * (i // 2) / embed_dims)
-            for i in range(embed_dims)
-        ]
+    vec = torch.arange(embed_dims, dtype=torch.float64)
+    vec = (vec - vec % 2) / embed_dims
+    vec = torch.pow(10000, -vec).view(1, -1)
 
-    sinusoid_table = np.array(
-        [get_position_angle_vec(pos) for pos in range(n_position)])
-    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
-    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+    sinusoid_table = torch.arange(n_position).view(-1, 1) * vec
+    sinusoid_table[:, 0::2].sin_()  # dim 2i
+    sinusoid_table[:, 1::2].cos_()  # dim 2i+1
 
-    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+    sinusoid_table = sinusoid_table.to(torch.float32)
+    return sinusoid_table.unsqueeze(0)
 
 
 @BACKBONES.register_module()

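For readers who want to poke at the change in isolation, here is a minimal, self-contained sketch of what the rewritten function computes: the classic table from Attention Is All You Need, PE[pos, 2i] = sin(pos / 10000**(2i / embed_dims)) and PE[pos, 2i+1] = cos(pos / 10000**(2i / embed_dims)), built without the Python-level loop of the old numpy version. The helper name sinusoid_encoding_sketch is only illustrative; the real function is get_sinusoid_encoding in mmcls.models.backbones.t2t_vit.

# Illustrative sketch only (assumes a recent PyTorch); mirrors the new code above.
import math

import torch


def sinusoid_encoding_sketch(n_position, embed_dims):
    # Exponent 2 * (i // 2) / embed_dims, i.e. the dim index with its parity bit cleared.
    vec = torch.arange(embed_dims, dtype=torch.float64)
    vec = (vec - vec % 2) / embed_dims
    vec = torch.pow(10000, -vec).view(1, -1)             # (1, embed_dims)
    table = torch.arange(n_position).view(-1, 1) * vec   # (n_position, embed_dims)
    table[:, 0::2].sin_()  # even dims -> sin
    table[:, 1::2].cos_()  # odd dims  -> cos
    return table.to(torch.float32).unsqueeze(0)          # (1, n_position, embed_dims)


# Spot-check one entry against the closed form: pos=2, column 2 (i.e. 2i with i=1).
table = sinusoid_encoding_sketch(n_position=4, embed_dims=6)
print(table.shape)                                            # torch.Size([1, 4, 6])
print(table[0, 2, 2].item(), math.sin(2 / 10000 ** (2 / 6)))  # agree to ~1e-7 (float32 cast)
print(table[0, 2, 3].item(), math.cos(2 / 10000 ** (2 / 6)))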

@@ -5,10 +5,12 @@ import tempfile
 from copy import deepcopy
 from unittest import TestCase
 
+import numpy as np
 import torch
 from mmcv.runner import load_checkpoint, save_checkpoint
 
 from mmcls.models.backbones import T2T_ViT
+from mmcls.models.backbones.t2t_vit import get_sinusoid_encoding
 from .utils import timm_resize_pos_embed
@@ -155,3 +157,32 @@ class TestT2TViT(TestCase):
             math.ceil(imgs.shape[3] / 16))
         self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape))
         self.assertEqual(cls_token.shape, (1, 384))
+
+
+def test_get_sinusoid_encoding():
+    # The original numpy-based implementation that mmcls used before,
+    # copied from the third-party repository:
+    # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
+    def get_sinusoid_encoding_numpy(n_position, d_hid):
+
+        def get_position_angle_vec(position):
+            return [
+                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+                for hid_j in range(d_hid)
+            ]
+
+        sinusoid_table = np.array(
+            [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+        return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+    n_positions = [128, 256, 512, 1024]
+    embed_dims = [128, 256, 512, 1024]
+
+    for n_position in n_positions:
+        for embed_dim in embed_dims:
+            out_mmcls = get_sinusoid_encoding(n_position, embed_dim)
+            out_numpy = get_sinusoid_encoding_numpy(n_position, embed_dim)
+            error = (out_mmcls - out_numpy).abs().max()
+            assert error < 1e-9, (
+                'Test case n_position=%d, embed_dim=%d failed' %
+                (n_position, embed_dim))
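Finally, a short usage sketch of the patched helper, using the same import path as the test above. The 197-token, 384-dim sizes are only an illustrative choice (roughly a ViT-style 14x14 patch grid plus a class token), not something fixed by this change.

# Usage sketch; the import path comes from the test imports above.
from mmcls.models.backbones.t2t_vit import get_sinusoid_encoding

pos_embed = get_sinusoid_encoding(n_position=197, embed_dims=384)
print(pos_embed.shape)  # torch.Size([1, 197, 384])
print(pos_embed.dtype)  # torch.float32 -- the new code builds the table in float64, then casts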