[Refactor] Re-write get_sinusoid_encoding from third-party implementation. (#965)
parent 6d8c91892c
commit 0b4a67dd31
@@ -218,27 +218,24 @@ def get_sinusoid_encoding(n_position, embed_dims):
     Sinusoid encoding is a kind of relative position encoding method that
     comes from `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
 
     Args:
         n_position (int): The length of the input token sequence.
         embed_dims (int): The position embedding dimension.
 
     Returns:
         :obj:`torch.FloatTensor`: The sinusoid encoding table.
     """
 
-    def get_position_angle_vec(position):
-        return [
-            position / np.power(10000, 2 * (i // 2) / embed_dims)
-            for i in range(embed_dims)
-        ]
+    vec = torch.arange(embed_dims, dtype=torch.float64)
+    vec = (vec - vec % 2) / embed_dims
+    vec = torch.pow(10000, -vec).view(1, -1)
 
-    sinusoid_table = np.array(
-        [get_position_angle_vec(pos) for pos in range(n_position)])
-    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
-    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+    sinusoid_table = torch.arange(n_position).view(-1, 1) * vec
+    sinusoid_table[:, 0::2].sin_()  # dim 2i
+    sinusoid_table[:, 1::2].cos_()  # dim 2i+1
 
-    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+    sinusoid_table = sinusoid_table.to(torch.float32)
+
+    return sinusoid_table.unsqueeze(0)
 
 
 @BACKBONES.register_module()
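For readers comparing the two implementations above: both build the sinusoid table from `Attention Is All You Need`, where entry (pos, i) uses the angle pos / 10000^(2 * (i // 2) / embed_dims), with sin applied to even channels and cos to odd channels; the refactor only replaces the per-element numpy list comprehension with vectorized torch ops. The standalone sketch below is not part of the patch (the helper names sinusoid_reference and sinusoid_vectorized are illustrative); it spells out that equivalence:

import math

import torch


def sinusoid_reference(n_position, embed_dims):
    # Per-element formula from the paper:
    #   angle(pos, i) = pos / 10000 ** (2 * (i // 2) / embed_dims),
    # then sin on even channels (2i) and cos on odd channels (2i + 1).
    table = torch.empty(n_position, embed_dims, dtype=torch.float64)
    for pos in range(n_position):
        for i in range(embed_dims):
            angle = pos / 10000 ** (2 * (i // 2) / embed_dims)
            table[pos, i] = math.sin(angle) if i % 2 == 0 else math.cos(angle)
    return table


def sinusoid_vectorized(n_position, embed_dims):
    # The same table, built as in the patched get_sinusoid_encoding.
    vec = torch.arange(embed_dims, dtype=torch.float64)
    vec = (vec - vec % 2) / embed_dims        # 2 * (i // 2) / embed_dims
    vec = torch.pow(10000, -vec).view(1, -1)  # 10000 ** -(...)
    table = torch.arange(n_position).view(-1, 1) * vec
    table[:, 0::2].sin_()  # dim 2i
    table[:, 1::2].cos_()  # dim 2i+1
    return table


assert torch.allclose(sinusoid_reference(8, 16), sinusoid_vectorized(8, 16))

Keeping the intermediate computation in float64 and only casting to float32 at the end is what lets the new test below compare the refactored function against the original numpy version with a 1e-9 tolerance.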
@@ -5,10 +5,12 @@ import tempfile
 from copy import deepcopy
 from unittest import TestCase
 
+import numpy as np
 import torch
 from mmcv.runner import load_checkpoint, save_checkpoint
 
 from mmcls.models.backbones import T2T_ViT
+from mmcls.models.backbones.t2t_vit import get_sinusoid_encoding
 from .utils import timm_resize_pos_embed
 
 
@@ -155,3 +157,32 @@ class TestT2TViT(TestCase):
                              math.ceil(imgs.shape[3] / 16))
         self.assertEqual(patch_token.shape, (1, 384, *expect_feat_shape))
         self.assertEqual(cls_token.shape, (1, 384))
+
+
+def test_get_sinusoid_encoding():
+    # The numpy-based third-party implementation previously used in mmcls:
+    # https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
+    def get_sinusoid_encoding_numpy(n_position, d_hid):
+
+        def get_position_angle_vec(position):
+            return [
+                position / np.power(10000, 2 * (hid_j // 2) / d_hid)
+                for hid_j in range(d_hid)
+            ]
+
+        sinusoid_table = np.array(
+            [get_position_angle_vec(pos_i) for pos_i in range(n_position)])
+        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
+        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1
+
+        return torch.FloatTensor(sinusoid_table).unsqueeze(0)
+
+    n_positions = [128, 256, 512, 1024]
+    embed_dims = [128, 256, 512, 1024]
+    for n_position in n_positions:
+        for embed_dim in embed_dims:
+            out_mmcls = get_sinusoid_encoding(n_position, embed_dim)
+            out_numpy = get_sinusoid_encoding_numpy(n_position, embed_dim)
+            error = (out_mmcls - out_numpy).abs().max()
+            assert error < 1e-9, \
+                f'n_position={n_position}, embed_dim={embed_dim} failed'
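As a quick smoke check of the refactored function outside the unit test (a minimal sketch; the import path is the one used in the test file above, and the n_position/embed_dims values are only illustrative):

import torch

from mmcls.models.backbones.t2t_vit import get_sinusoid_encoding

# Illustrative sizes (384 matches the embedding dim checked in the tests
# above). The returned table has shape (1, n_position, embed_dims) and is
# float32 after the refactor.
table = get_sinusoid_encoding(n_position=197, embed_dims=384)
assert table.shape == (1, 197, 384)
assert table.dtype == torch.float32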