Remove unnecessary LongTensor in EfficientFormer. Possibly fixes #1878

pull/1903/head v0.9.5
Ross Wightman 2023-08-03 16:38:53 -07:00
parent 4224529ebe
commit 81089b10a2
1 changed files with 1 additions and 1 deletions


@@ -67,7 +67,7 @@ class Attention(torch.nn.Module):
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos))
+        self.register_buffer('attention_bias_idxs', rel_pos)
         self.attention_bias_cache = {}  # per-device attention_biases cache (data-parallel compat)

     @torch.no_grad()
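
The `rel_pos` index tensor is built from integer coordinates, so it is already `torch.int64`; the extra `torch.LongTensor(...)` wrapper only made a redundant copy via the legacy constructor. A minimal sketch to check this, assuming `pos` is constructed from a meshgrid of `torch.arange` coordinates as in the surrounding `Attention.__init__` and using a placeholder `resolution`:

import torch

# Placeholder resolution for illustration; the real value comes from the model config.
resolution = (7, 7)

# Assumed pos construction, mirroring the lines just above the changed hunk.
pos = torch.stack(torch.meshgrid(
    torch.arange(resolution[0]),
    torch.arange(resolution[1]),
    indexing='ij',
)).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]

# torch.arange defaults to int64, so rel_pos is already a "long" tensor;
# registering it directly as a buffer is equivalent to the old LongTensor wrap.
print(rel_pos.dtype)  # torch.int64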