mirror of https://github.com/open-mmlab/mmcv.git
[Fix] fix generalized attention fp16 (#1036)
* fix generalized attention fp16
* fix building without gpu error
* add comment
* Cast tensor at initialization
parent 1a66977f33
commit 5be9593499
@@ -170,18 +170,23 @@ class GeneralizedAttention(nn.Module):
                                q_stride,
                                kv_stride,
                                device,
+                               dtype,
                                feat_dim,
                                wave_length=1000):
-        h_idxs = torch.linspace(0, h - 1, h).to(device)
+        # the default type of Tensor is float32, leading to type mismatch
+        # in fp16 mode. Cast it to support fp16 mode.
+        h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
         h_idxs = h_idxs.view((h, 1)) * q_stride

-        w_idxs = torch.linspace(0, w - 1, w).to(device)
+        w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
         w_idxs = w_idxs.view((w, 1)) * q_stride

-        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(device)
+        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+            device=device, dtype=dtype)
         h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride

-        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(device)
+        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+            device=device, dtype=dtype)
         w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride

         # (h, h_kv, 1)
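For context, a minimal sketch (not part of this commit) of why the cast matters: torch.linspace follows the global default dtype, which is float32, so the index grids built here would not match a half-precision input unless a dtype is passed explicitly. The helper name make_idxs below is illustrative only.

import torch

# Illustrative helper (hypothetical name, not mmcv code): build a query index
# grid the same way the patched get_position_embedding does.
def make_idxs(h, q_stride, device, dtype=None):
    # torch.linspace uses the global default dtype (float32) unless told otherwise.
    h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
    return h_idxs.view((h, 1)) * q_stride

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(make_idxs(20, 1, device).dtype)                    # torch.float32
print(make_idxs(20, 1, device, dtype=torch.half).dtype)  # torch.float16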
@@ -192,9 +197,10 @@ class GeneralizedAttention(nn.Module):
         w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
         w_diff *= self.position_magnitude

-        feat_range = torch.arange(0, feat_dim / 4).to(device)
+        feat_range = torch.arange(0, feat_dim / 4).to(
+            device=device, dtype=dtype)

-        dim_mat = torch.Tensor([wave_length]).to(device)
+        dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
         dim_mat = dim_mat**((4. / feat_dim) * feat_range)
         dim_mat = dim_mat.view((1, 1, -1))

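The lines touched here build the geometric frequency table dim_mat = wave_length ** ((4 / feat_dim) * feat_range). As a rough sketch of how such a table turns positional offsets into a sinusoidal embedding: the code below reproduces the frequency computation from this hunk and a plausible sin/cos combination; the shapes and the downstream concatenation are assumptions for illustration, not the exact mmcv code outside this hunk.

import torch

# Sketch under assumptions: diff has shape (n, n_kv, 1) like h_diff above,
# and the embedding concatenates sin and cos along the last dimension.
def sinusoidal_embedding(diff, feat_dim, wave_length=1000, dtype=torch.float32):
    feat_range = torch.arange(0, feat_dim / 4, dtype=dtype)
    dim_mat = torch.tensor([wave_length], dtype=dtype)
    dim_mat = dim_mat**((4. / feat_dim) * feat_range)  # geometric frequencies
    dim_mat = dim_mat.view((1, 1, -1))
    scaled = diff / dim_mat
    return torch.cat((scaled.sin(), scaled.cos()), dim=2)

emb = sinusoidal_embedding(torch.randn(4, 2, 1), feat_dim=8)
print(emb.shape)  # torch.Size([4, 2, 4]), i.e. feat_dim // 2 per axis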
@@ -234,7 +240,7 @@ class GeneralizedAttention(nn.Module):
         if self.attention_type[1] or self.attention_type[3]:
             position_embed_x, position_embed_y = self.get_position_embedding(
                 h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
-                x_input.device, self.position_embedding_dim)
+                x_input.device, x_input.dtype, self.position_embedding_dim)
             # (n, num_heads, w, w_kv, dim)
             position_feat_x = self.appr_geom_fc_x(position_embed_x).\
                 view(1, w, w_kv, num_heads, self.qk_embed_dim).\

@@ -60,3 +60,16 @@ def test_context_block():
         assert gen_attention_block.kv_downsample is not None
     out = gen_attention_block(imgs)
     assert out.shape == imgs.shape
+
+    # test fp16 with attention_type='1111'
+    if torch.cuda.is_available():
+        imgs = torch.randn(2, 16, 20, 20).cuda().to(torch.half)
+        gen_attention_block = GeneralizedAttention(
+            16,
+            spatial_range=-1,
+            num_heads=8,
+            attention_type='1111',
+            kv_stride=2)
+        gen_attention_block.cuda().type(torch.half)
+        out = gen_attention_block(imgs)
+        assert out.shape == imgs.shape
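Beyond the unit test above, a usage sketch of the fixed block in half precision (assuming GeneralizedAttention is exported from mmcv.cnn; falls back to float32 when CUDA is unavailable, since CPU half-precision coverage is limited):

import torch
from mmcv.cnn import GeneralizedAttention  # assumed export path

# With this fix, the position-embedding grids built inside forward() follow
# the input dtype, so casting the module and its input together is enough.
use_fp16 = torch.cuda.is_available()
device = torch.device('cuda' if use_fp16 else 'cpu')
dtype = torch.half if use_fp16 else torch.float32

block = GeneralizedAttention(
    16, spatial_range=-1, num_heads=8, attention_type='1111',
    kv_stride=2).to(device=device, dtype=dtype)
imgs = torch.randn(2, 16, 20, 20, device=device, dtype=dtype)
out = block(imgs)
assert out.shape == imgs.shape and out.dtype == dtype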