commit 81089b10a2
parent 4224529ebe
@@ -67,7 +67,7 @@ class Attention(torch.nn.Module):
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos))
+        self.register_buffer('attention_bias_idxs', rel_pos)
         self.attention_bias_cache = {}  # per-device attention_biases cache (data-parallel compat)

     @torch.no_grad()
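For context, a minimal runnable sketch of the bias-index construction this hunk touches, assuming `pos` is built from a meshgrid of integer coordinates over the feature resolution, as in LeViT-style attention; the `num_heads` and `resolution` values are illustrative, and the rest of the module is not shown. Because `rel_pos` is derived from `torch.arange`, it is already an `int64` tensor, so the removed `torch.LongTensor(rel_pos)` wrapper amounted to a redundant copy through the legacy CPU constructor.

import torch

num_heads = 8            # illustrative values, not taken from the hunk
resolution = (14, 14)    # feature-map (height, width)

# Integer (y, x) coordinates for every position on the feature map: (2, H*W).
# indexing='ij' requires PyTorch >= 1.10; older code omitted the argument.
pos = torch.stack(torch.meshgrid(
    torch.arange(resolution[0]),
    torch.arange(resolution[1]),
    indexing='ij')).flatten(1)

# Pairwise absolute offsets between all positions: (2, H*W, H*W).
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
# Flatten (dy, dx) into a single bias index per pair: (H*W, H*W), with values
# in [0, resolution[0] * resolution[1] - 1], matching the parameter's width.
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]

# Already int64, so registering it directly (the new line) is sufficient.
assert rel_pos.dtype == torch.long

attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
# Gathering per-head biases by index yields the (num_heads, H*W, H*W) bias map
# that gets added to the attention logits.
print(attention_biases[:, rel_pos].shape)  # torch.Size([8, 196, 196])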