[Fix] fix generalized attention fp16 (#1036)

* Fix generalized attention in fp16 mode

* Fix build error when no GPU is available

* Add explanatory comment

* Cast tensors to the input dtype at initialization
Guangchen Lin 2021-05-23 15:16:27 +08:00 committed by GitHub
parent 1a66977f33
commit 5be9593499
2 changed files with 26 additions and 7 deletions
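
Background: torch.linspace, torch.arange and torch.Tensor produce float32 tensors by default, so the position embedding stayed float32 even when the feature map was cast to half precision, and the fp16 projection layers then failed with a dtype mismatch. The sketch below illustrates the casting pattern the patch applies; the helper name make_position_idxs, the shapes, and the CPU-only setup are illustrative assumptions, not code from this commit.

import torch

def make_position_idxs(h, q_stride, device, dtype):
    # torch.linspace defaults to float32; passing an explicit dtype keeps the
    # position indices consistent with an fp16 feature map, mirroring the cast
    # added to get_position_embedding in this commit.
    idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
    return idxs.view((h, 1)) * q_stride

feat = torch.randn(2, 16, 20, 20, dtype=torch.half)       # fp16 feature map
pos = make_position_idxs(20, 1, feat.device, feat.dtype)   # same dtype as feat
assert pos.dtype == torch.half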


@@ -170,18 +170,23 @@ class GeneralizedAttention(nn.Module):
                                q_stride,
                                kv_stride,
                                device,
+                               dtype,
                                feat_dim,
                                wave_length=1000):
-        h_idxs = torch.linspace(0, h - 1, h).to(device)
+        # the default type of Tensor is float32, leading to type mismatch
+        # in fp16 mode. Cast it to support fp16 mode.
+        h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
         h_idxs = h_idxs.view((h, 1)) * q_stride
 
-        w_idxs = torch.linspace(0, w - 1, w).to(device)
+        w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
         w_idxs = w_idxs.view((w, 1)) * q_stride
 
-        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(device)
+        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+            device=device, dtype=dtype)
         h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
 
-        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(device)
+        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+            device=device, dtype=dtype)
         w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
 
         # (h, h_kv, 1)
@@ -192,9 +197,10 @@ class GeneralizedAttention(nn.Module):
         w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
         w_diff *= self.position_magnitude
 
-        feat_range = torch.arange(0, feat_dim / 4).to(device)
+        feat_range = torch.arange(0, feat_dim / 4).to(
+            device=device, dtype=dtype)
 
-        dim_mat = torch.Tensor([wave_length]).to(device)
+        dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
         dim_mat = dim_mat**((4. / feat_dim) * feat_range)
         dim_mat = dim_mat.view((1, 1, -1))
@@ -234,7 +240,7 @@ class GeneralizedAttention(nn.Module):
         if self.attention_type[1] or self.attention_type[3]:
             position_embed_x, position_embed_y = self.get_position_embedding(
                 h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
-                x_input.device, self.position_embedding_dim)
+                x_input.device, x_input.dtype, self.position_embedding_dim)
             # (n, num_heads, w, w_kv, dim)
             position_feat_x = self.appr_geom_fc_x(position_embed_x).\
                 view(1, w, w_kv, num_heads, self.qk_embed_dim).\


@@ -60,3 +60,16 @@ def test_context_block():
     assert gen_attention_block.kv_downsample is not None
     out = gen_attention_block(imgs)
     assert out.shape == imgs.shape
+
+    # test fp16 with attention_type='1111'
+    if torch.cuda.is_available():
+        imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
+        gen_attention_block = GeneralizedAttention(
+            16,
+            spatial_range=-1,
+            num_heads=8,
+            attention_type='1111',
+            kv_stride=2)
+        gen_attention_block.cuda().type(torch.half)
+        out = gen_attention_block(imgs)
+        assert out.shape == imgs.shape
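
For reference, a minimal usage sketch of the fp16 path that the new test exercises, assuming a CUDA-enabled PyTorch build; the import path mmcv.cnn for GeneralizedAttention is an assumption and is not shown in this diff.

import torch
from mmcv.cnn import GeneralizedAttention  # import path assumed

if torch.cuda.is_available():
    # cast both the module and the input to half precision
    block = GeneralizedAttention(
        16, spatial_range=-1, num_heads=8, attention_type='1111',
        kv_stride=2).cuda().half()
    x = torch.randn(2, 16, 20, 20, device='cuda', dtype=torch.half)
    out = block(x)  # with this fix, no dtype mismatch is raised
    assert out.shape == x.shape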