Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge pull request #1900 from huggingface/swin_maxvit_resize
Add support for resizing swin transformer, maxvit, coatnet at creation time

Commit da75cdd212
timm/layers/__init__.py
@@ -37,7 +37,8 @@ from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
-from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
+from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
+    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
     build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
     FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat

timm/layers/interpolate.py (new file, 68 lines)
@@ -0,0 +1,68 @@
""" Interpolation helpers for timm layers

RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
Copyright Shane Barratt, Apache 2.0 license
"""
import torch
from itertools import product


class RegularGridInterpolator:
    """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
    Produces similar results to scipy RegularGridInterpolator or interp2d
    in 'linear' mode.

    Taken from https://github.com/sbarratt/torch_interpolations
    """

    def __init__(self, points, values):
        self.points = points
        self.values = values

        assert isinstance(self.points, tuple) or isinstance(self.points, list)
        assert isinstance(self.values, torch.Tensor)

        self.ms = list(self.values.shape)
        self.n = len(self.points)

        assert len(self.ms) == self.n

        for i, p in enumerate(self.points):
            assert isinstance(p, torch.Tensor)
            assert p.shape[0] == self.values.shape[i]

    def __call__(self, points_to_interp):
        assert self.points is not None
        assert self.values is not None

        assert len(points_to_interp) == len(self.points)
        K = points_to_interp[0].shape[0]
        for x in points_to_interp:
            assert x.shape[0] == K

        idxs = []
        dists = []
        overalls = []
        for p, x in zip(self.points, points_to_interp):
            idx_right = torch.bucketize(x, p)
            idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
            idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
            dist_left = x - p[idx_left]
            dist_right = p[idx_right] - x
            dist_left[dist_left < 0] = 0.
            dist_right[dist_right < 0] = 0.
            both_zero = (dist_left == 0) & (dist_right == 0)
            dist_left[both_zero] = dist_right[both_zero] = 1.

            idxs.append((idx_left, idx_right))
            dists.append((dist_left, dist_right))
            overalls.append(dist_left + dist_right)

        numerator = 0.
        for indexer in product([0, 1], repeat=self.n):
            as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
            bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
            numerator += self.values[as_s] * \
                torch.prod(torch.stack(bs_s), dim=0)
        denominator = torch.prod(torch.stack(overalls), dim=0)
        return numerator / denominator
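
Illustrative usage sketch (not part of the commit; import path follows the new file above): linear interpolation on an unevenly spaced 2D grid, the same way resize_rel_pos_bias_table uses it further down.

    import torch
    from timm.layers.interpolate import RegularGridInterpolator

    ys = torch.tensor([0., 1., 3.])      # uneven grid coordinates along dim 0
    xs = torch.tensor([0., 2., 4.])      # grid coordinates along dim 1
    values = torch.arange(9, dtype=torch.float32).reshape(3, 3)

    interp = RegularGridInterpolator((ys, xs), values)
    out = interp([torch.tensor([0.5, 2.0]), torch.tensor([1.0, 3.0])])
    print(out.shape)  # torch.Size([2]), one interpolated value per query point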

timm/layers/pos_embed_rel.py
@@ -3,15 +3,19 @@
 Hacked together by / Copyright 2022 Ross Wightman
 """
 import math
+import os
 from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .interpolate import RegularGridInterpolator
 from .mlp import Mlp
 from .weight_init import trunc_normal_
 
+_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0
+
 
 def gen_relative_position_index(
         q_size: Tuple[int, int],
@@ -20,7 +24,8 @@ def gen_relative_position_index(
 ) -> torch.Tensor:
     # Adapted with significant modifications from Swin / BeiT codebases
     # get pair-wise relative position index for each token inside the window
-    if k_size is None:
-        coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(q_size[0]),
+    assert k_size is None, 'Different q & k sizes not currently supported'  # FIXME
+
+    coords = torch.stack(
+        torch.meshgrid([
+            torch.arange(q_size[0]),
@@ -29,42 +34,209 @@
     ).flatten(1)  # 2, Wh, Ww
-        relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
-        num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
-    else:
-        # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
-        q_coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(q_size[0]),
-                torch.arange(q_size[1])
-            ])
-        ).flatten(1)  # 2, Wh, Ww
-        k_coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(k_size[0]),
-                torch.arange(k_size[1])
-            ])
-        ).flatten(1)
-        relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
+    relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
+    relative_coords[:, :, 0] += q_size[0] - 1  # shift to start from 0
+    relative_coords[:, :, 1] += q_size[1] - 1
+    relative_coords[:, :, 0] *= 2 * q_size[1] - 1
+    num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)
+
+    # else:
+    #     # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
+    #     q_coords = torch.stack(
+    #         torch.meshgrid([
+    #             torch.arange(q_size[0]),
+    #             torch.arange(q_size[1])
+    #         ])
+    #     ).flatten(1)  # 2, Wh, Ww
+    #     k_coords = torch.stack(
+    #         torch.meshgrid([
+    #             torch.arange(k_size[0]),
+    #             torch.arange(k_size[1])
+    #         ])
+    #     ).flatten(1)
+    #     relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    #     relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
     # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1  # shift to start from 0
     # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
     # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
     # relative_position_index = relative_coords.sum(-1)  # Qh*Qw, Kh*Kw
-        num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3
-
-    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
+    # num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3
+
+    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
 
     if class_token:
         # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
         # NOTE not intended or tested with MLP log-coords
         relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
-        relative_position_index[0, 0:] = num_relative_distance - 3
-        relative_position_index[0:, 0] = num_relative_distance - 2
-        relative_position_index[0, 0] = num_relative_distance - 1
+        relative_position_index[0, 0:] = num_relative_distance
+        relative_position_index[0:, 0] = num_relative_distance + 1
+        relative_position_index[0, 0] = num_relative_distance + 2
 
     return relative_position_index.contiguous()
 
 
+def resize_rel_pos_bias_table_simple(
+        rel_pos_bias,
+        new_window_size: Tuple[int, int],
+        new_bias_shape: Tuple[int, ...],
+):
+    dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
+    if rel_pos_bias.ndim == 3:
+        # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
+        _, dst_h, dst_w = new_bias_shape
+        num_attn_heads, src_h, src_w = rel_pos_bias.shape
+        assert dst_h == dst_size[0] and dst_w == dst_size[1]
+        if src_h != dst_h or src_w != dst_w:
+            rel_pos_bias = torch.nn.functional.interpolate(
+                rel_pos_bias.unsqueeze(0),
+                size=dst_size,
+                mode="bicubic",
+                align_corners=False,
+            ).squeeze(0)
+    else:
+        assert rel_pos_bias.ndim == 2
+        # (num_pos, num_heads) (aka flat) bias shape
+        dst_num_pos, _ = new_bias_shape
+        src_num_pos, num_attn_heads = rel_pos_bias.shape
+        num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
+        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
+        src_size = (src_size, src_size)  # FIXME could support non-equal src if argument passed
+
+        if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
+            if num_extra_tokens:
+                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+            else:
+                extra_tokens = None
+
+            rel_pos_bias = torch.nn.functional.interpolate(
+                rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])),
+                size=dst_size,
+                mode="bicubic",
+                align_corners=False,
+            ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1)
+
+            if extra_tokens is not None:
+                rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
+
+    return rel_pos_bias
+
+
+def resize_rel_pos_bias_table(
+        rel_pos_bias,
+        new_window_size: Tuple[int, int],
+        new_bias_shape: Tuple[int, ...],
+):
+    """ Resize relative position bias table using more advanced interpolation.
+
+    Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).
+
+    https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351
+
+    Args:
+        rel_pos_bias:
+        new_window_size:
+        new_bias_shape:
+
+    Returns:
+
+    """
+    if _USE_SCIPY:
+        from scipy import interpolate
+
+    dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
+    if rel_pos_bias.ndim == 3:
+        # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
+        num_extra_tokens = 0
+        _, dst_h, dst_w = new_bias_shape
+        assert dst_h == dst_size[0] and dst_w == dst_size[1]
+        num_attn_heads, src_h, src_w = rel_pos_bias.shape
+        src_size = (src_h, src_w)
+        has_flat_shape = False
+    else:
+        assert rel_pos_bias.ndim == 2
+        # (num_pos, num_heads) (aka flat) bias shape
+        dst_num_pos, _ = new_bias_shape
+        src_num_pos, num_attn_heads = rel_pos_bias.shape
+        num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
+        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
+        src_size = (src_size, src_size)
+        has_flat_shape = True
+
+    if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
+        # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1]))
+        if num_extra_tokens:
+            extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+            rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+        else:
+            extra_tokens = None
+
+        def geometric_progression(a, r, n):
+            return a * (1.0 - r ** n) / (1.0 - r)
+
+        def _calc(src, dst):
+            left, right = 1.01, 1.5
+            while right - left > 1e-6:
+                q = (left + right) / 2.0
+                gp = geometric_progression(1, q, src // 2)
+                if gp > dst // 2:
+                    right = q
+                else:
+                    left = q
+
+            dis = []
+            cur = 1
+            for i in range(src // 2):
+                dis.append(cur)
+                cur += q ** (i + 1)
+            r_ids = [-_ for _ in reversed(dis)]
+            return r_ids + [0] + dis
+
+        y = _calc(src_size[0], dst_size[0])
+        x = _calc(src_size[1], dst_size[1])
+        yx = [torch.tensor(y), torch.tensor(x)]
+        # print("Original positions = %s" % str(x))
+
+        ty = dst_size[0] // 2.0
+        tx = dst_size[1] // 2.0
+        dy = torch.arange(-ty, ty + 0.1, 1.0)
+        dx = torch.arange(-tx, tx + 0.1, 1.0)
+        dyx = torch.meshgrid([dy, dx])
+        # print("Target positions = %s" % str(dx))
+
+        all_rel_pos_bias = []
+        for i in range(num_attn_heads):
+            if has_flat_shape:
+                z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
+            else:
+                z = rel_pos_bias[i, :, :].float()
+
+            if _USE_SCIPY:
+                # Original beit code uses scipy w/ cubic interpolation
+                f = interpolate.interp2d(x, y, z.numpy(), kind='cubic')
+                r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device)
+            else:
+                # Without scipy dependency, I've found a reasonably simple impl
+                # that supports uneven spaced interpolation pts with 'linear' interp.
+                # Results are comparable to scipy for model accuracy in most cases.
+                f = RegularGridInterpolator(yx, z)
+                r = f(dyx).contiguous().to(rel_pos_bias.device)
+
+            if has_flat_shape:
+                r = r.view(-1, 1)
+            all_rel_pos_bias.append(r)
+
+        if has_flat_shape:
+            rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+        else:
+            rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0)
+
+        if extra_tokens is not None:
+            assert has_flat_shape
+            rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
+
+    return rel_pos_bias
+
+
 class RelPosBias(nn.Module):
     """ Relative Position Bias
     Adapted from Swin-V1 relative position bias impl, modularized.
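
Illustrative usage sketch (not part of the commit): resizing a BeiT-style flat (num_pos, num_heads) bias table, including the 3 extra cls-related entries, from a 7x7 to a 12x12 window with the simpler of the two new helpers.

    import torch
    from timm.layers import resize_rel_pos_bias_table_simple

    num_heads = 12
    src_num_pos = (2 * 7 - 1) * (2 * 7 - 1) + 3      # 172 entries for a 7x7 window
    dst_num_pos = (2 * 12 - 1) * (2 * 12 - 1) + 3    # 532 entries for a 12x12 window

    old_table = torch.randn(src_num_pos, num_heads)
    new_table = resize_rel_pos_bias_table_simple(
        old_table,
        new_window_size=(12, 12),
        new_bias_shape=(dst_num_pos, num_heads),
    )
    print(new_table.shape)  # torch.Size([532, 12])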

timm/models/beit.py
@@ -48,6 +48,8 @@ from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
+from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table
+
 
 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model
@@ -115,7 +117,7 @@ class Attention(nn.Module):
             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
             self.relative_position_bias_table = nn.Parameter(
                 torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
-            self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
+            self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
         else:
             self.window_size = None
             self.relative_position_bias_table = None
@@ -504,11 +506,46 @@ default_cfgs = generate_default_cfgs({
 })
 
 
-def _beit_checkpoint_filter_fn(state_dict, model):
-    if 'module' in state_dict:
-        # beit v2 didn't strip module
-        state_dict = state_dict['module']
-    return checkpoint_filter_fn(state_dict, model)
+def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('module', state_dict)
+    # beit v2 didn't strip module
+
+    out_dict = {}
+    for k, v in state_dict.items():
+        if 'relative_position_index' in k:
+            continue
+        if 'patch_embed.proj.weight' in k:
+            O, I, H, W = model.patch_embed.proj.weight.shape
+            if v.shape[-1] != W or v.shape[-2] != H:
+                v = resample_patch_embed(
+                    v,
+                    (H, W),
+                    interpolation=interpolation,
+                    antialias=antialias,
+                    verbose=True,
+                )
+        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
+            # To resize pos embedding when using model at different size from pretrained weights
+            num_prefix_tokens = 1
+            v = resample_abs_pos_embed(
+                v,
+                new_size=model.patch_embed.grid_size,
+                num_prefix_tokens=num_prefix_tokens,
+                interpolation=interpolation,
+                antialias=antialias,
+                verbose=True,
+            )
+        elif k.endswith('relative_position_bias_table'):
+            m = model.get_submodule(k[:-29])
+            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
+                v = resize_rel_pos_bias_table(
+                    v,
+                    new_window_size=m.window_size,
+                    new_bias_shape=m.relative_position_bias_table.shape,
+                )
+        out_dict[k] = v
+    return out_dict
 
 
 def _create_beit(variant, pretrained=False, **kwargs):
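
Illustrative sketch (not part of the diff): with the filter above, pretrained BeiT weights can be loaded at a different input resolution at creation time; patch embeddings, absolute position embeddings and relative position bias tables are resampled to fit. The model name and size below are examples, not prescribed by this commit.

    import timm

    model = timm.create_model('beitv2_base_patch16_224', pretrained=True, img_size=384)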

timm/models/maxxvit.py
@@ -48,7 +48,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
 from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
 from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
-from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn
+from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply, checkpoint_seq
@@ -186,9 +186,9 @@ class Attention2d(nn.Module):
                 attn_bias = shared_rel_pos
 
             x = torch.nn.functional.scaled_dot_product_attention(
-                q.transpose(-1, -2),
-                k.transpose(-1, -2),
-                v.transpose(-1, -2),
+                q.transpose(-1, -2).contiguous(),
+                k.transpose(-1, -2).contiguous(),
+                v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
                 dropout_p=self.attn_drop.p,
             ).transpose(-1, -2).reshape(B, -1, H, W)
@@ -1790,6 +1790,15 @@ def checkpoint_filter_fn(state_dict, model: nn.Module):
     model_state_dict = model.state_dict()
     out_dict = {}
     for k, v in state_dict.items():
+        if k.endswith('relative_position_bias_table'):
+            m = model.get_submodule(k[:-29])
+            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
+                v = resize_rel_pos_bias_table(
+                    v,
+                    new_window_size=m.window_size,
+                    new_bias_shape=m.relative_position_bias_table.shape,
+                )
+
         if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
             # adapt between conv2d / linear layers
             assert v.ndim in (2, 4)
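
Illustrative sketch (not part of the diff): the same resize path now applies to MaxViT / CoAtNet relative position bias tables when pretrained weights are loaded at a new resolution. The model name below is an example, not prescribed by this commit.

    import timm

    model = timm.create_model('maxvit_tiny_tf_224.in1k', pretrained=True, img_size=384)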

timm/models/swin_transformer.py
@@ -24,7 +24,7 @@ import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
-    _assert, use_fused_attn
+    _assert, use_fused_attn, resize_rel_pos_bias_table
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply
@@ -38,23 +38,28 @@ _logger = logging.getLogger(__name__)
 _int_or_tuple_2_t = Union[int, Tuple[int, int]]
 
 
-def window_partition(x, window_size: int):
+def window_partition(
+        x: torch.Tensor,
+        window_size: Tuple[int, int],
+) -> torch.Tensor:
     """
+    Partition into non-overlapping windows with padding if needed.
     Args:
-        x: (B, H, W, C)
-        window_size (int): window size
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
 
     Returns:
-        windows: (num_windows*B, window_size, window_size, C)
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
     """
     B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
     return windows
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size: int, H: int, W: int):
+def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
     """
     Args:
         windows: (num_windows*B, window_size, window_size, C)
@@ -66,7 +71,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
         x: (B, H, W, C)
     """
     C = windows.shape[-1]
-    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
     return x
 
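
Illustrative round-trip sketch (not part of the diff): window_partition and window_reverse now take a (h, w) tuple, so rectangular windows work directly when H and W divide evenly (padding added in the block's _attn handles the remainder otherwise).

    import torch
    from timm.models.swin_transformer import window_partition, window_reverse

    x = torch.randn(2, 8, 12, 32)               # B, H, W, C
    windows = window_partition(x, (4, 6))       # (2 * 2 * 2, 4, 6, 32)
    y = window_reverse(windows, (4, 6), 8, 12)  # back to (2, 8, 12, 32)
    assert torch.equal(x, y)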
@@ -124,7 +129,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
 
         # get pair-wise relative position index for each token inside the window
-        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w))
+        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
 
         self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -218,14 +223,11 @@ class SwinTransformerBlock(nn.Module):
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
-        self.window_size = window_size
-        self.shift_size = shift_size
+        ws, ss = self._calc_window_shift(window_size, shift_size)
+        self.window_size: Tuple[int, int] = ws
+        self.shift_size: Tuple[int, int] = ss
+        self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
 
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
@@ -237,8 +239,8 @@ class SwinTransformerBlock(nn.Module):
             attn_drop=attn_drop,
             proj_drop=proj_drop,
         )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp = Mlp(
             in_features=dim,
@@ -246,66 +248,81 @@ class SwinTransformerBlock(nn.Module):
             act_layer=act_layer,
             drop=proj_drop,
         )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        if self.shift_size > 0:
+        if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
+            H = math.ceil(H / self.window_size[0]) * self.window_size[0]
+            W = math.ceil(W / self.window_size[1]) * self.window_size[1]
             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size),
-                    slice(-self.window_size, -self.shift_size),
-                    slice(-self.shift_size, None)):
+                    slice(0, -self.window_size[0]),
+                    slice(-self.window_size[0], -self.shift_size[0]),
+                    slice(-self.shift_size[0], None)):
                 for w in (
-                        slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None)):
+                        slice(0, -self.window_size[1]),
+                        slice(-self.window_size[1], -self.shift_size[1]),
+                        slice(-self.shift_size[1], None)):
                     img_mask[:, h, w, :] = cnt
                     cnt += 1
-            mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
            attn_mask = None
-        self.register_buffer("attn_mask", attn_mask)
+        self.register_buffer("attn_mask", attn_mask, persistent=False)
 
-    def forward(self, x):
+    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        target_window_size = to_2tuple(target_window_size)
+        target_shift_size = to_2tuple(target_shift_size)
+        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
+        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
+        return tuple(window_size), tuple(shift_size)
+
+    def _attn(self, x):
         B, H, W, C = x.shape
-        _assert(H == self.input_resolution[0], "input feature has wrong size")
-        _assert(W == self.input_resolution[1], "input feature has wrong size")
-
-        shortcut = x
-        x = self.norm1(x)
 
         # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        has_shift = any(self.shift_size)
+        if has_shift:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
         else:
            shifted_x = x
 
+        # pad for resolution not divisible by window size
+        pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
+        pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
+        shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
+        Hp, Wp = H + pad_h, W + pad_w
+
         # partition windows
-        x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_win*B, window_size*window_size, C
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+        shifted_x = shifted_x[:, :H, :W, :].contiguous()
 
         # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        if has_shift:
+            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
         else:
            x = shifted_x
+        return x
 
-        # FFN
-        x = shortcut + self.drop_path(x)
+    def forward(self, x):
+        B, H, W, C = x.shape
+        x = x + self.drop_path1(self._attn(self.norm1(x)))
         x = x.reshape(B, -1, C)
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
         x = x.reshape(B, H, W, C)
         return x
 
@@ -385,6 +402,8 @@ class SwinTransformerStage(nn.Module):
         self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
         self.depth = depth
         self.grad_checkpointing = False
+        window_size = to_2tuple(window_size)
+        shift_size = tuple([w // 2 for w in window_size])
 
         # patch merging layer
         if downsample:
@@ -405,7 +424,7 @@ class SwinTransformerStage(nn.Module):
                 num_heads=num_heads,
                 head_dim=head_dim,
                 window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                shift_size=0 if (i % 2 == 0) else shift_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -499,7 +518,11 @@ class SwinTransformer(nn.Module):
 
         # build layers
         head_dim = to_ntuple(self.num_layers)(head_dim)
-        window_size = to_ntuple(self.num_layers)(window_size)
+        if not isinstance(window_size, (list, tuple)):
+            window_size = to_ntuple(self.num_layers)(window_size)
+        elif len(window_size) == 2:
+            window_size = (window_size,) * self.num_layers
+        assert len(window_size) == self.num_layers
         mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         layers = []
@@ -598,15 +621,30 @@ class SwinTransformer(nn.Module):
 
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    old_weights = True
     if 'head.fc.weight' in state_dict:
-        return state_dict
+        old_weights = False
     import re
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
     for k, v in state_dict.items():
-        k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
-        k = k.replace('head.', 'head.fc.')
+        if any([n in k for n in ('relative_position_index', 'attn_mask')]):
+            continue  # skip buffers that should not be persistent
+
+        if k.endswith('relative_position_bias_table'):
+            m = model.get_submodule(k[:-29])
+            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
+                v = resize_rel_pos_bias_table(
+                    v,
+                    new_window_size=m.window_size,
+                    new_bias_shape=m.relative_position_bias_table.shape,
+                )
+
+        if old_weights:
+            k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
+            k = k.replace('head.', 'head.fc.')
+
         out_dict[k] = v
     return out_dict
 
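
Illustrative sketch (not part of the diff): taken together, these changes let a pretrained Swin checkpoint be created at a different image and window size, with relative position bias tables resized and non-persistent buffers rebuilt. Model name and sizes are examples, not prescribed by this commit.

    import timm

    model = timm.create_model(
        'swin_base_patch4_window7_224.ms_in22k_ft_in1k',
        pretrained=True,
        img_size=288,    # new input resolution
        window_size=9,   # new window size; 288 / 4 = 72 tokens per side, divisible by 9
    )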

timm/models/swin_transformer_v2.py
@@ -398,6 +398,8 @@ class SwinTransformerV2Stage(nn.Module):
         self.depth = depth
         self.output_nchw = output_nchw
         self.grad_checkpointing = False
+        window_size = to_2tuple(window_size)
+        shift_size = tuple([w // 2 for w in window_size])
 
         # patch merging / downsample layer
         if downsample:
@@ -413,7 +415,7 @@ class SwinTransformerV2Stage(nn.Module):
                 input_resolution=self.output_resolution,
                 num_heads=num_heads,
                 window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                shift_size=0 if (i % 2 == 0) else shift_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -568,7 +570,7 @@ class SwinTransformerV2(nn.Module):
     def no_weight_decay(self):
         nod = set()
         for n, m in self.named_modules():
-            if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
+            if any([kw in n for kw in ("cpb_mlp", "logit_scale")]):
                 nod.add(n)
         return nod
 