diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py
index 970d1a11..04204b3d 100644
--- a/timm/models/efficientformer.py
+++ b/timm/models/efficientformer.py
@@ -67,7 +67,7 @@ class Attention(torch.nn.Module):
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos))
+        self.register_buffer('attention_bias_idxs', rel_pos)
         self.attention_bias_cache = {}  # per-device attention_biases cache (data-parallel compat)
 
     @torch.no_grad()
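
For context: `rel_pos` is already an int64 tensor by construction (it is derived from `torch.arange`), so wrapping it in the legacy `torch.LongTensor(...)` constructor is redundant, and that constructor is discouraged in modern PyTorch; registering the tensor directly is equivalent. Below is a minimal standalone sketch of the index construction and the post-patch registration. The `Toy` module and the fixed `resolution = (4, 4)` are hypothetical, for illustration only, not part of the patch.

```python
import torch

# Reproduce the index construction from Attention.__init__ with a small,
# hypothetical resolution. torch.arange yields int64, so rel_pos is already
# a LongTensor: no constructor wrapping is needed before register_buffer.
resolution = (4, 4)
pos = torch.stack(torch.meshgrid(
    torch.arange(resolution[0]),
    torch.arange(resolution[1]),
    indexing='ij',  # explicit to avoid the meshgrid indexing warning
)).flatten(1)
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
assert rel_pos.dtype == torch.int64  # already a LongTensor


class Toy(torch.nn.Module):  # hypothetical stand-in for Attention
    def __init__(self):
        super().__init__()
        # After the patch: register the precomputed index tensor as-is.
        self.register_buffer('attention_bias_idxs', rel_pos)


m = Toy()
print(m.attention_bias_idxs.shape)  # torch.Size([16, 16])
```

Registering the buffer directly also keeps the tensor's dtype and avoids routing through the legacy per-type constructor, so the buffer still moves with the module on `.to(device)` as before.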