Swin-V2 test fixes, typo
parent 9a86b900fa
commit c0211b0bf7
@@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*']
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
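This test-suite hunk (presumably timm's tests/test_models.py) adds the new swinv2_* models to the non-standard filter list so they take the relaxed test path. A minimal sketch of how such glob patterns are typically matched against model names; the helper name below is illustrative, not taken from the test file:

import fnmatch

NON_STD_FILTERS = ['swin_*', 'swinv2_*', 'vit_*']  # abbreviated for the example

def is_non_std(model_name, filters=NON_STD_FILTERS):
    # a model counts as "non-standard" if its name matches any glob pattern
    return any(fnmatch.fnmatch(model_name, pat) for pat in filters)

print(is_non_std('swinv2_tiny_window8_256'))  # True
print(is_non_std('resnet50'))                 # False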
@@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = {
-    'swinv2_tiny_window8_256.': _cfg(
+    'swinv2_tiny_window8_256': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
         input_size=(3, 256, 256)
     ),
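This is the typo from the commit title: a stray trailing dot in the default_cfgs key. The key has to equal the registered model name for the pretrained config lookup to find it; a toy illustration with a plain dict, not timm's actual resolution code:

default_cfgs = {'swinv2_tiny_window8_256.': {'url': '...'}}   # typo'd key
print(default_cfgs.get('swinv2_tiny_window8_256'))            # None -> cfg and weights not found
default_cfgs = {'swinv2_tiny_window8_256': {'url': '...'}}    # fixed key
print(default_cfgs.get('swinv2_tiny_window8_256'))            # {'url': '...'}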
@@ -106,6 +106,7 @@ def window_partition(x, window_size):
     return windows
 
 
+@register_notrace_function  # reason: int argument is a Proxy
 def window_reverse(windows, window_size, H, W):
     """
     Args:
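window_reverse computes its output batch size from a Python int derived from H, W and window_size, which torch.fx cannot handle when those arguments arrive as Proxy objects during symbolic tracing, hence the register_notrace_function exclusion. For reference, a self-contained version of the usual Swin-style window_reverse; it may differ in minor details from the function in this file:

import torch

def window_reverse(windows, window_size, H, W):
    # Reassemble (num_windows*B, ws, ws, C) windows into a (B, H, W, C) feature map.
    # The int() below is what trips FX tracing when H/W arrive as Proxy values.
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(4 * (8 // 4) * (8 // 4), 4, 4, 32)  # 4 images, 8x8 map, 4x4 windows
print(window_reverse(x, 4, 8, 8).shape)             # torch.Size([4, 8, 8, 32])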
@@ -190,9 +191,11 @@ class WindowAttention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(dim))
+            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(dim))
         else:
             self.q_bias = None
+            self.k_bias = None
             self.v_bias = None
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
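SwinV2 learns a bias for q and v but deliberately leaves k unbiased. Registering an all-zero k_bias as a non-persistent buffer keeps it out of checkpoints while letting it follow the module across devices and dtypes. A stripped-down sketch of the pattern; QKVProj is an illustrative name, not the class in this file:

import torch
import torch.nn as nn
import torch.nn.functional as F

class QKVProj(nn.Module):
    def __init__(self, dim, qkv_bias=True):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            # zero bias for k: a non-persistent buffer is excluded from state_dict
            # but still moves with .to(device) / .half() like the parameters
            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

    def forward(self, x):
        bias = None
        if self.q_bias is not None:
            bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
        return F.linear(x, self.qkv.weight, bias)

print('k_bias' in QKVProj(32).state_dict())  # False -> existing checkpoints still load cleanly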
@@ -208,7 +211,7 @@ class WindowAttention(nn.Module):
         B_, N, C = x.shape
         qkv_bias = None
         if self.q_bias is not None:
-            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
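Because the buffer is all zeros, the new concatenation is numerically identical to the old torch.zeros_like construction; it just avoids allocating a fresh tensor on every forward call. A quick equivalence check with illustrative values:

import torch

dim = 8
q_bias = torch.randn(dim)
v_bias = torch.randn(dim)
k_bias = torch.zeros(dim)  # the registered buffer is all zeros

old = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
new = torch.cat((q_bias, k_bias, v_bias))
print(torch.equal(old, new))  # True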
@@ -51,7 +51,7 @@ def _cfg(url='', **kwargs):
         'url': url,
         'num_classes': 1000,
         'input_size': (3, 224, 224),
-        'pool_size': None,
+        'pool_size': (7, 7),
         'crop_pct': 0.9,
         'interpolation': 'bicubic',
         'fixed_input_size': True,
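This hunk (presumably in the swinv2_cr model file) changes the default classifier pool_size from None to (7, 7). That value is the spatial size of the final feature map: with the standard Swin layout the network reduces resolution by 32x overall (4x patch embedding followed by three 2x patch merges, an assumption based on the usual Swin architecture), so 224px inputs end in a 7x7 grid. The arithmetic, spelled out:

input_size = (3, 224, 224)
total_reduction = 32          # 4 (patch embed) * 2 * 2 * 2 (patch merging stages)
pool_size = (input_size[1] // total_reduction, input_size[2] // total_reduction)
print(pool_size)              # (7, 7)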
@@ -65,14 +65,14 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     'swinv2_cr_tiny_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_tiny_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_tiny_ns_224': _cfg(
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_small_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_small_224': _cfg(
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
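For the 384px configs the same arithmetic gives 384 // 32 = 12, so each of them now passes pool_size=(12, 12) explicitly to override the new (7, 7) default. A simplified sketch of the _cfg default/override pattern; the real helper sets more keys than shown here:

def _cfg(url='', **kwargs):
    # defaults for 224px models; **kwargs lets individual entries override them
    return {'url': url, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.9, **kwargs}

cfg_384 = _cfg(input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12))
print(cfg_384['pool_size'])   # (12, 12)
print(_cfg()['pool_size'])    # (7, 7) -- the new default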
@@ -80,21 +80,21 @@ default_cfgs = {
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_base_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_base_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_base_ns_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_large_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_large_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_huge_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_huge_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_giant_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_giant_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
 }