Fix meshgrid deprecation warnings and backward compat with explicit 'ndgrid' and 'meshgrid' fn w/o indexing arg
parent fa247fd9ba
commit 88889de923
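For context: PyTorch 1.10 added an indexing argument to torch.meshgrid and warns whenever it is omitted (the legacy behaviour is 'ij'), while older releases reject the argument outright. A minimal sketch of the behaviour being worked around, not part of the diff:

    import torch

    a, b = torch.arange(3), torch.arange(5)
    torch.meshgrid(a, b)                 # warns on PyTorch >= 1.10, legacy 'ij' result
    torch.meshgrid(a, b, indexing='ij')  # explicit and silent, identical result
    # On PyTorch < 1.10 the indexing kwarg raises TypeError, hence the
    # try/except fallback in the new ndgrid() helper below.
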
@@ -24,6 +24,7 @@ from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct
 from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
 from .gather_excite import GatherExcite
 from .global_context import GlobalContext
+from .grid import ndgrid, meshgrid
 from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
 from .inplace_abn import InplaceAbn
 from .linear import Linear

@@ -18,10 +18,18 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .grid import ndgrid
+
 
 def drop_block_2d(
-        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
-        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+        x,
+        drop_prob: float = 0.1,
+        block_size: int = 7,
+        gamma_scale: float = 1.0,
+        with_noise: bool = False,
+        inplace: bool = False,
+        batchwise: bool = False
+):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
     DropBlock with an experimental gaussian noise option. This layer has been tested on a few training

@@ -35,7 +43,7 @@ def drop_block_2d(
         (W - block_size + 1) * (H - block_size + 1))
 
     # Forces the block to be inside the feature map.
-    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
+    w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
     valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                   ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
     valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

@@ -68,8 +76,13 @@ def drop_block_2d(
 
 
 def drop_block_fast_2d(
-        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
-        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
+        x: torch.Tensor,
+        drop_prob: float = 0.1,
+        block_size: int = 7,
+        gamma_scale: float = 1.0,
+        with_noise: bool = False,
+        inplace: bool = False,
+):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
     DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid

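Aside from the ndgrid swap, the hunk above also moves device placement into torch.arange itself. A small equivalence sketch (an illustration, not part of the commit):

    import torch

    device = torch.device('cpu')
    a = torch.arange(8).to(device)      # old form: allocate, then copy if devices differ
    b = torch.arange(8, device=device)  # new form: allocate directly on the target device
    assert torch.equal(a, b)
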
@@ -0,0 +1,49 @@
+from typing import Tuple
+
+import torch
+
+
+def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
+    """generate N-D grid in dimension order.
+
+    The ndgrid function is like meshgrid except that the order of the first two input arguments is switched.
+
+    That is, the statement
+
+    [X1,X2,X3] = ndgrid(x1,x2,x3)
+
+    produces the same result as
+
+    [X2,X1,X3] = meshgrid(x2,x1,x3)
+
+    This naming is based on MATLAB; the purpose is to avoid confusion due to torch changing
+    torch.meshgrid behaviour from matching ndgrid ('ij' indexing) to the numpy meshgrid default of 'xy'.
+    """
+    try:
+        return torch.meshgrid(*tensors, indexing='ij')
+    except TypeError:
+        # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
+        # the old behaviour of meshgrid was 'ij'
+        return torch.meshgrid(*tensors)
+
+
+def meshgrid(*tensors) -> Tuple[torch.Tensor, ...]:
+    """generate N-D grid in spatial dim order.
+
+    The meshgrid function is similar to ndgrid except that the order of the
+    first two input and output arguments is switched.
+
+    That is, the statement
+
+    [X,Y,Z] = meshgrid(x,y,z)
+
+    produces the same result as
+
+    [Y,X,Z] = ndgrid(y,x,z)
+
+    Because of this, meshgrid is better suited to problems in two- or three-dimensional Cartesian space,
+    while ndgrid is better suited to multidimensional problems that aren't spatially based.
+    """
+    # NOTE: this will throw in PyTorch < 1.10 as meshgrid did not support indexing arg or have
+    # capability of generating grid in xy order before then.
+    return torch.meshgrid(*tensors, indexing='xy')

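A quick sanity check of the two new wrappers (a sketch assuming PyTorch >= 1.10; both names are re-exported from timm.layers per the __init__ hunk above):

    import torch
    from timm.layers import ndgrid, meshgrid

    h, w = torch.arange(3), torch.arange(5)

    a, b = ndgrid(h, w)    # 'ij' (matrix) order: a[i, j] == h[i], shapes (3, 5)
    c, d = meshgrid(h, w)  # 'xy' (Cartesian) order: shapes (5, 3)
    assert a.shape == (3, 5) and c.shape == (5, 3)
    assert torch.equal(c, a.T) and torch.equal(d, b.T)
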
@@ -24,13 +24,14 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .grid import ndgrid
 from .helpers import to_2tuple, make_divisible
 from .weight_init import trunc_normal_
 
 
 def rel_pos_indices(size):
     size = to_2tuple(size)
-    pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
+    pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
     rel_pos = pos[:, None, :] - pos[:, :, None]
     rel_pos[0] += size[0] - 1
     rel_pos[1] += size[1] - 1

@@ -10,6 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .grid import ndgrid
 from .interpolate import RegularGridInterpolator
 from .mlp import Mlp
 from .weight_init import trunc_normal_

@@ -26,12 +27,7 @@ def gen_relative_position_index(
     # get pair-wise relative position index for each token inside the window
     assert k_size is None, 'Different q & k sizes not currently supported'  # FIXME
 
-    coords = torch.stack(
-        torch.meshgrid([
-            torch.arange(q_size[0]),
-            torch.arange(q_size[1])
-        ])
-    ).flatten(1)  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(q_size[0]), torch.arange(q_size[1]))).flatten(1)  # 2, Wh, Ww
     relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
     relative_coords[:, :, 0] += q_size[0] - 1  # shift to start from 0

@@ -42,16 +38,16 @@ def gen_relative_position_index(
     # else:
     #     # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
     #     q_coords = torch.stack(
-    #         torch.meshgrid([
+    #         ndgrid(
     #             torch.arange(q_size[0]),
     #             torch.arange(q_size[1])
-    #         ])
+    #         )
     #     ).flatten(1)  # 2, Wh, Ww
     #     k_coords = torch.stack(
-    #         torch.meshgrid([
+    #         ndgrid(
     #             torch.arange(k_size[0]),
     #             torch.arange(k_size[1])
-    #         ])
+    #         )
     #     ).flatten(1)
     #     relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
     # relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2

@@ -232,7 +228,7 @@ def resize_rel_pos_bias_table(
         tx = dst_size[1] // 2.0
         dy = torch.arange(-ty, ty + 0.1, 1.0)
         dx = torch.arange(-tx, tx + 0.1, 1.0)
-        dyx = torch.meshgrid([dy, dx])
+        dyx = ndgrid(dy, dx)
         # print("Target positions = %s" % str(dx))
 
         all_rel_pos_bias = []

@@ -313,7 +309,7 @@ def gen_relative_log_coords(
     # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
     relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
     relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
-    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
+    relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
     relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
     if mode == 'swin':
         if pretrained_win_size[0] > 0:

@@ -8,6 +8,7 @@ from typing import List, Tuple, Optional, Union
 import torch
 from torch import nn as nn
 
+from .grid import ndgrid
 from .trace_utils import _assert
 
 

@@ -64,10 +65,10 @@ def build_sincos2d_pos_embed(
 
     if reverse_coord:
         feat_shape = feat_shape[::-1]  # stack W, H instead of H, W
-    grid = torch.stack(torch.meshgrid(
-        [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
-         for s in feat_shape])
-    ).flatten(1).transpose(0, 1)
+    grid = torch.stack(ndgrid([
+        torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
+        for s in feat_shape
+    ])).flatten(1).transpose(0, 1)
     pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
     # FIXME add support for unflattened spatial dim?
 

@@ -137,7 +138,7 @@ def build_fourier_pos_embed(
         # eva's scheme for resizing rope embeddings (ref shape = pretrain)
         t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
 
-    grid = torch.stack(torch.meshgrid(t), dim=-1)
+    grid = torch.stack(ndgrid(t), dim=-1)
     grid = grid.unsqueeze(-1)
     pos = grid * bands
 

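One subtlety in the two hunks above: build_sincos2d_pos_embed and build_fourier_pos_embed hand ndgrid a single list of tensors rather than unpacked arguments. That still works because torch.meshgrid accepts either calling convention, so the wrapper forwards the list untouched; a small sketch (not from the diff):

    import torch

    ts = [torch.arange(2), torch.arange(3)]
    g1 = torch.meshgrid(*ts, indexing='ij')  # unpacked tensors
    g2 = torch.meshgrid(ts, indexing='ij')   # a single sequence is also accepted
    assert all(torch.equal(x, y) for x, y in zip(g1, g2))
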
@@ -48,7 +48,7 @@ from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
-from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table
+from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
 
 
 from ._builder import build_model_with_cfg

@@ -63,9 +63,7 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
     # cls to token & token 2 cls & cls to cls
     # get pair-wise relative position index for each token inside the window
     window_area = window_size[0] * window_size[1]
-    coords = torch.stack(torch.meshgrid(
-        [torch.arange(window_size[0]),
-         torch.arange(window_size[1])]))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1])))  # 2, Wh, Ww
     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
+from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -63,7 +63,7 @@ class Attention(torch.nn.Module):
         self.proj = nn.Linear(self.val_attn_dim, dim)
 
         resolution = to_2tuple(resolution)
-        pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))

@@ -23,7 +23,7 @@ import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
-from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
+from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -129,7 +129,7 @@ class Attention2d(torch.nn.Module):
         self.act = act_layer()
         self.proj = ConvNorm(self.dh, dim, 1)
 
-        pos = torch.stack(torch.meshgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
         self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N))

@@ -231,12 +231,11 @@ class Attention2dDownsample(torch.nn.Module):
         self.proj = ConvNorm(self.dh, self.out_dim, 1)
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N))
-        k_pos = torch.stack(torch.meshgrid(torch.arange(
-            self.resolution[0]),
-            torch.arange(self.resolution[1]))).flatten(1)
-        q_pos = torch.stack(torch.meshgrid(
-            torch.arange(0, self.resolution[0], step=2),
-            torch.arange(0, self.resolution[1], step=2))).flatten(1)
+        k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
+        q_pos = torch.stack(ndgrid(
+            torch.arange(0, self.resolution[0], step=2),
+            torch.arange(0, self.resolution[1], step=2)
+        )).flatten(1)
         rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
         self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)

@@ -31,7 +31,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
-from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
+from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -194,7 +194,7 @@ class Attention(nn.Module):
         ]))
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
         rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)

@@ -290,10 +290,11 @@ class AttentionDownsample(nn.Module):
         ]))
 
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
-        k_pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
-        q_pos = torch.stack(torch.meshgrid(
-            torch.arange(0, resolution[0], step=stride),
-            torch.arange(0, resolution[1], step=stride))).flatten(1)
+        k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        q_pos = torch.stack(ndgrid(
+            torch.arange(0, resolution[0], step=stride),
+            torch.arange(0, resolution[1], step=stride)
+        )).flatten(1)
         rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
         rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
         self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)

@@ -24,7 +24,7 @@ import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
-    _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed
+    _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply

@@ -78,7 +78,7 @@ def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
 
 def get_relative_position_index(win_h: int, win_w: int):
     # get pair-wise relative position index for each token inside the window
-    coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))  # 2, Wh, Ww
+    coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w)))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
     relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
     relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

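The get_relative_position_index pattern above (and its cousins in beit and levit) turns a window grid into a table of pairwise coordinate offsets. A tiny worked example for a 2x2 window (a sketch, not part of the commit):

    import torch
    from timm.layers import ndgrid

    win_h = win_w = 2
    coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w)))     # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()            # Wh*Ww, Wh*Ww, 2
    # relative_coords[i, j] is the (dy, dx) offset from window position j to i,
    # e.g. relative_coords[0, 3] == tensor([-1, -1]) for the 2x2 window.
    assert relative_coords.shape == (4, 4, 2)
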
@@ -22,7 +22,7 @@ import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,\
-    resample_patch_embed
+    resample_patch_embed, ndgrid
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -107,9 +107,8 @@ class WindowAttention(nn.Module):
         # get relative_coords_table
         relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
         relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
-        relative_coords_table = torch.stack(torch.meshgrid([
-            relative_coords_h,
-            relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
+        relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
+        relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
         if pretrained_window_size[0] > 0:
             relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
             relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)

@@ -125,7 +124,7 @@
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
         coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords = torch.stack(ndgrid(coords_h, coords_w))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
         relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2

@@ -37,7 +37,7 @@ import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert
+from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply

@@ -141,9 +141,10 @@ class WindowMultiHeadAttention(nn.Module):
     def _make_pair_wise_relative_positions(self) -> None:
         """Method initializes the pair-wise relative positions to compute the positional biases."""
         device = self.logit_scale.device
-        coordinates = torch.stack(torch.meshgrid([
-            torch.arange(self.window_size[0], device=device),
-            torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
+        coordinates = torch.stack(ndgrid(
+            torch.arange(self.window_size[0], device=device),
+            torch.arange(self.window_size[1], device=device)
+        ), dim=0).flatten(1)
         relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :]
         relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()
         relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(