Merge pull request #1900 from huggingface/swin_maxvit_resize

Add support for resizing swin transformer, maxvit, coatnet at creation time
Ross Wightman 2023-08-11 15:05:28 -07:00 committed by GitHub
commit da75cdd212
7 changed files with 427 additions and 100 deletions
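
A rough sketch of the intended usage (model names and sizes are illustrative; the actual resizing happens in each model's pretrained checkpoint filter):

import timm

# swin / maxvit / coatnet models can now be created at a non-default image size,
# with pretrained relative position bias tables interpolated to fit at load time
swin = timm.create_model('swin_base_patch4_window7_224', pretrained=True, img_size=288)
maxvit = timm.create_model('maxvit_tiny_tf_224.in1k', pretrained=True, img_size=192)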

View File

@@ -37,7 +37,8 @@ from .patch_dropout import PatchDropout
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
-from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
+from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
+    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
     build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
     FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat

View File

@@ -0,0 +1,68 @@
""" Interpolation helpers for timm layers
RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
Copyright Shane Barratt, Apache 2.0 license
"""
import torch
from itertools import product
class RegularGridInterpolator:
    """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
    Produces similar results to scipy RegularGridInterpolator or interp2d
    in 'linear' mode.

    Taken from https://github.com/sbarratt/torch_interpolations
    """

    def __init__(self, points, values):
        self.points = points
        self.values = values

        assert isinstance(self.points, tuple) or isinstance(self.points, list)
        assert isinstance(self.values, torch.Tensor)

        self.ms = list(self.values.shape)
        self.n = len(self.points)

        assert len(self.ms) == self.n

        for i, p in enumerate(self.points):
            assert isinstance(p, torch.Tensor)
            assert p.shape[0] == self.values.shape[i]

    def __call__(self, points_to_interp):
        assert self.points is not None
        assert self.values is not None

        assert len(points_to_interp) == len(self.points)

        K = points_to_interp[0].shape[0]
        for x in points_to_interp:
            assert x.shape[0] == K

        idxs = []
        dists = []
        overalls = []
        for p, x in zip(self.points, points_to_interp):
            idx_right = torch.bucketize(x, p)
            idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
            idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
            dist_left = x - p[idx_left]
            dist_right = p[idx_right] - x
            dist_left[dist_left < 0] = 0.
            dist_right[dist_right < 0] = 0.
            both_zero = (dist_left == 0) & (dist_right == 0)
            dist_left[both_zero] = dist_right[both_zero] = 1.

            idxs.append((idx_left, idx_right))
            dists.append((dist_left, dist_right))
            overalls.append(dist_left + dist_right)

        numerator = 0.
        for indexer in product([0, 1], repeat=self.n):
            as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
            bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
            numerator += self.values[as_s] * torch.prod(torch.stack(bs_s), dim=0)

        denominator = torch.prod(torch.stack(overalls), dim=0)
        return numerator / denominator
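
A minimal usage sketch of the class above (grid coordinates and values are made up):

import torch

# values sampled on an unevenly spaced 3x4 (y, x) grid
ys = torch.tensor([0., 1., 3.])
xs = torch.tensor([0., 0.5, 1.5, 4.])
values = torch.arange(12, dtype=torch.float32).reshape(3, 4)

interp = RegularGridInterpolator((ys, xs), values)

# linearly interpolate at two query points, (0.5, 0.25) and (2.0, 3.0)
out = interp([torch.tensor([0.5, 2.0]), torch.tensor([0.25, 3.0])])
print(out.shape)  # torch.Size([2])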

View File

@@ -3,15 +3,19 @@
 Hacked together by / Copyright 2022 Ross Wightman
 """
 import math
+import os
 from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from .interpolate import RegularGridInterpolator
 from .mlp import Mlp
 from .weight_init import trunc_normal_

+_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0


 def gen_relative_position_index(
         q_size: Tuple[int, int],
@@ -20,51 +24,219 @@ def gen_relative_position_index(
 ) -> torch.Tensor:
     # Adapted with significant modifications from Swin / BeiT codebases
     # get pair-wise relative position index for each token inside the window
-    if k_size is None:
-        coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(q_size[0]),
-                torch.arange(q_size[1])
-            ])
-        ).flatten(1)  # 2, Wh, Ww
-        relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
-        num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
-    else:
-        # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
-        q_coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(q_size[0]),
-                torch.arange(q_size[1])
-            ])
-        ).flatten(1)  # 2, Wh, Ww
-        k_coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(k_size[0]),
-                torch.arange(k_size[1])
-            ])
-        ).flatten(1)
-        relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
-        # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1  # shift to start from 0
-        # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
-        # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
-        # relative_position_index = relative_coords.sum(-1)  # Qh*Qw, Kh*Kw
-        num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3
-    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
+    assert k_size is None, 'Different q & k sizes not currently supported'  # FIXME
+
+    coords = torch.stack(
+        torch.meshgrid([
+            torch.arange(q_size[0]),
+            torch.arange(q_size[1])
+        ])
+    ).flatten(1)  # 2, Wh, Ww
+    relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
+    relative_coords[:, :, 0] += q_size[0] - 1  # shift to start from 0
+    relative_coords[:, :, 1] += q_size[1] - 1
+    relative_coords[:, :, 0] *= 2 * q_size[1] - 1
+    num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)
+
+    # else:
+    #     # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
+    #     q_coords = torch.stack(
+    #         torch.meshgrid([
+    #             torch.arange(q_size[0]),
+    #             torch.arange(q_size[1])
+    #         ])
+    #     ).flatten(1)  # 2, Wh, Ww
+    #     k_coords = torch.stack(
+    #         torch.meshgrid([
+    #             torch.arange(k_size[0]),
+    #             torch.arange(k_size[1])
+    #         ])
+    #     ).flatten(1)
+    #     relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    #     relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
+    #     relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1  # shift to start from 0
+    #     relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
+    #     relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
+    #     relative_position_index = relative_coords.sum(-1)  # Qh*Qw, Kh*Kw
+    #     num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3
+
+    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww

     if class_token:
         # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
         # NOTE not intended or tested with MLP log-coords
         relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
-        relative_position_index[0, 0:] = num_relative_distance - 3
-        relative_position_index[0:, 0] = num_relative_distance - 2
-        relative_position_index[0, 0] = num_relative_distance - 1
+        relative_position_index[0, 0:] = num_relative_distance
+        relative_position_index[0:, 0] = num_relative_distance + 1
+        relative_position_index[0, 0] = num_relative_distance + 2

     return relative_position_index.contiguous()
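
A quick sanity check of the rewritten index generation, assuming the function is imported from timm.layers.pos_embed_rel:

from timm.layers.pos_embed_rel import gen_relative_position_index

idx = gen_relative_position_index((3, 2))
print(idx.shape)           # torch.Size([6, 6]): one index per (query, key) token pair
print(int(idx.max()) + 1)  # 15 == (2*3 - 1) * (2*2 - 1) distinct relative offsets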
def resize_rel_pos_bias_table_simple(
        rel_pos_bias,
        new_window_size: Tuple[int, int],
        new_bias_shape: Tuple[int, ...],
):
    dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
    if rel_pos_bias.ndim == 3:
        # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
        _, dst_h, dst_w = new_bias_shape
        num_attn_heads, src_h, src_w = rel_pos_bias.shape
        assert dst_h == dst_size[0] and dst_w == dst_size[1]
        if src_h != dst_h or src_w != dst_w:
            rel_pos_bias = torch.nn.functional.interpolate(
                rel_pos_bias.unsqueeze(0),
                size=dst_size,
                mode="bicubic",
                align_corners=False,
            ).squeeze(0)
    else:
        assert rel_pos_bias.ndim == 2
        # (num_pos, num_heads) (aka flat) bias shape
        dst_num_pos, _ = new_bias_shape
        src_num_pos, num_attn_heads = rel_pos_bias.shape
        num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
        src_size = (src_size, src_size)  # FIXME could support non-equal src if argument passed

        if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
            if num_extra_tokens:
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
            else:
                extra_tokens = None

            rel_pos_bias = torch.nn.functional.interpolate(
                rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])),
                size=dst_size,
                mode="bicubic",
                align_corners=False,
            ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1)

            if extra_tokens is not None:
                rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)

    return rel_pos_bias
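
A hedged sketch of the simple resize path for a TF/MaxViT-style (num_heads, H, W) bias, growing a 7x7 window (13x13 bias grid) to a 10x10 window (19x19 grid):

import torch

old_bias = torch.randn(8, 13, 13)
new_bias = resize_rel_pos_bias_table_simple(
    old_bias,
    new_window_size=(10, 10),
    new_bias_shape=(8, 19, 19),
)
print(new_bias.shape)  # torch.Size([8, 19, 19])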
def resize_rel_pos_bias_table(
        rel_pos_bias,
        new_window_size: Tuple[int, int],
        new_bias_shape: Tuple[int, ...],
):
    """ Resize relative position bias table using more advanced interpolation.

    Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).
    https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351

    Args:
        rel_pos_bias:
        new_window_size:
        new_bias_shape:

    Returns:

    """
    if _USE_SCIPY:
        from scipy import interpolate

    dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
    if rel_pos_bias.ndim == 3:
        # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
        num_extra_tokens = 0
        _, dst_h, dst_w = new_bias_shape
        assert dst_h == dst_size[0] and dst_w == dst_size[1]
        num_attn_heads, src_h, src_w = rel_pos_bias.shape
        src_size = (src_h, src_w)
        has_flat_shape = False
    else:
        assert rel_pos_bias.ndim == 2
        # (num_pos, num_heads) (aka flat) bias shape
        dst_num_pos, _ = new_bias_shape
        src_num_pos, num_attn_heads = rel_pos_bias.shape
        num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
        src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
        src_size = (src_size, src_size)
        has_flat_shape = True

    if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
        # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1]))
        if num_extra_tokens:
            extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
            rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
        else:
            extra_tokens = None

        def geometric_progression(a, r, n):
            return a * (1.0 - r ** n) / (1.0 - r)

        def _calc(src, dst):
            left, right = 1.01, 1.5
            while right - left > 1e-6:
                q = (left + right) / 2.0
                gp = geometric_progression(1, q, src // 2)
                if gp > dst // 2:
                    right = q
                else:
                    left = q

            dis = []
            cur = 1
            for i in range(src // 2):
                dis.append(cur)
                cur += q ** (i + 1)
            r_ids = [-_ for _ in reversed(dis)]
            return r_ids + [0] + dis

        y = _calc(src_size[0], dst_size[0])
        x = _calc(src_size[1], dst_size[1])
        yx = [torch.tensor(y), torch.tensor(x)]
        # print("Original positions = %s" % str(x))

        ty = dst_size[0] // 2.0
        tx = dst_size[1] // 2.0
        dy = torch.arange(-ty, ty + 0.1, 1.0)
        dx = torch.arange(-tx, tx + 0.1, 1.0)
        dyx = torch.meshgrid([dy, dx])
        # print("Target positions = %s" % str(dx))

        all_rel_pos_bias = []
        for i in range(num_attn_heads):
            if has_flat_shape:
                z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
            else:
                z = rel_pos_bias[i, :, :].float()

            if _USE_SCIPY:
                # Original beit code uses scipy w/ cubic interpolation
                f = interpolate.interp2d(x, y, z.numpy(), kind='cubic')
                r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device)
            else:
                # Without scipy dependency, I've found a reasonably simple impl
                # that supports uneven spaced interpolation pts with 'linear' interp.
                # Results are comparable to scipy for model accuracy in most cases.
                f = RegularGridInterpolator(yx, z)
                r = f(dyx).contiguous().to(rel_pos_bias.device)

            if has_flat_shape:
                r = r.view(-1, 1)
            all_rel_pos_bias.append(r)

        if has_flat_shape:
            rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
        else:
            rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0)

        if extra_tokens is not None:
            assert has_flat_shape
            rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)

    return rel_pos_bias
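
And a sketch of the same operation via the geometric-progression interpolation above, for a BeiT/Swin-style flat (num_pos, num_heads) table with no extra tokens:

import torch

old_table = torch.randn(13 * 13, 4)   # 7x7 window -> 13x13 = 169 relative positions
new_table = resize_rel_pos_bias_table(
    old_table,
    new_window_size=(12, 12),         # 23x23 = 529 relative positions
    new_bias_shape=(23 * 23, 4),
)
print(new_table.shape)  # torch.Size([529, 4])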
class RelPosBias(nn.Module):
    """ Relative Position Bias
    Adapted from Swin-V1 relative position bias impl, modularized.

View File

@@ -48,6 +48,8 @@ from torch.utils.checkpoint import checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
+from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table

 from ._builder import build_model_with_cfg
 from ._registry import generate_default_cfgs, register_model
@@ -115,7 +117,7 @@ class Attention(nn.Module):
             self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
             self.relative_position_bias_table = nn.Parameter(
                 torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
-            self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
+            self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
         else:
             self.window_size = None
             self.relative_position_bias_table = None
@@ -504,11 +506,46 @@ default_cfgs = generate_default_cfgs({
 })


-def _beit_checkpoint_filter_fn(state_dict, model):
-    if 'module' in state_dict:
-        # beit v2 didn't strip module
-        state_dict = state_dict['module']
-    return checkpoint_filter_fn(state_dict, model)
+def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
+    state_dict = state_dict.get('model', state_dict)
+    state_dict = state_dict.get('module', state_dict)
+    # beit v2 didn't strip module
+
+    out_dict = {}
+    for k, v in state_dict.items():
+        if 'relative_position_index' in k:
+            continue
+        if 'patch_embed.proj.weight' in k:
+            O, I, H, W = model.patch_embed.proj.weight.shape
+            if v.shape[-1] != W or v.shape[-2] != H:
+                v = resample_patch_embed(
+                    v,
+                    (H, W),
+                    interpolation=interpolation,
+                    antialias=antialias,
+                    verbose=True,
+                )
+        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
+            # To resize pos embedding when using model at different size from pretrained weights
+            num_prefix_tokens = 1
+            v = resample_abs_pos_embed(
+                v,
+                new_size=model.patch_embed.grid_size,
+                num_prefix_tokens=num_prefix_tokens,
+                interpolation=interpolation,
+                antialias=antialias,
+                verbose=True,
+            )
+        elif k.endswith('relative_position_bias_table'):
+            m = model.get_submodule(k[:-29])
+            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
+                v = resize_rel_pos_bias_table(
+                    v,
+                    new_window_size=m.window_size,
+                    new_bias_shape=m.relative_position_bias_table.shape,
+                )
+        out_dict[k] = v
+    return out_dict
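
With this filter in place, a BeiT checkpoint can be loaded into a model built at a different resolution; a hedged sketch (model name and size illustrative):

import timm

# patch_embed, pos_embed and relative position bias tables are resampled on load
model = timm.create_model('beit_base_patch16_224.in22k_ft_in22k_in1k', pretrained=True, img_size=384)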

 def _create_beit(variant, pretrained=False, **kwargs):

View File

@@ -48,7 +48,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
 from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
 from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
-from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn
+from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import named_apply, checkpoint_seq
@@ -186,9 +186,9 @@ class Attention2d(nn.Module):
                 attn_bias = shared_rel_pos
             x = torch.nn.functional.scaled_dot_product_attention(
-                q.transpose(-1, -2),
-                k.transpose(-1, -2),
-                v.transpose(-1, -2),
+                q.transpose(-1, -2).contiguous(),
+                k.transpose(-1, -2).contiguous(),
+                v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
                 dropout_p=self.attn_drop.p,
             ).transpose(-1, -2).reshape(B, -1, H, W)
@@ -1790,6 +1790,15 @@ def checkpoint_filter_fn(state_dict, model: nn.Module):
     model_state_dict = model.state_dict()
     out_dict = {}
     for k, v in state_dict.items():
+        if k.endswith('relative_position_bias_table'):
+            m = model.get_submodule(k[:-29])
+            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
+                v = resize_rel_pos_bias_table(
+                    v,
+                    new_window_size=m.window_size,
+                    new_bias_shape=m.relative_position_bias_table.shape,
+                )
         if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
             # adapt between conv2d / linear layers
             assert v.ndim in (2, 4)

View File

@@ -24,7 +24,7 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
-    _assert, use_fused_attn
+    _assert, use_fused_attn, resize_rel_pos_bias_table
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq, named_apply
@@ -38,23 +38,28 @@ _logger = logging.getLogger(__name__)

 _int_or_tuple_2_t = Union[int, Tuple[int, int]]
-def window_partition(x, window_size: int):
+def window_partition(
+        x: torch.Tensor,
+        window_size: Tuple[int, int],
+) -> torch.Tensor:
     """
+    Partition into non-overlapping windows with padding if needed.
     Args:
-        x: (B, H, W, C)
-        window_size (int): window size
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.

     Returns:
-        windows: (num_windows*B, window_size, window_size, C)
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
     """
     B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
     return windows


 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size: int, H: int, W: int):
+def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
     """
     Args:
         windows: (num_windows*B, window_size, window_size, C)
@@ -66,7 +71,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
         x: (B, H, W, C)
     """
     C = windows.shape[-1]
-    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
     return x
@@ -124,7 +129,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))

         # get pair-wise relative position index for each token inside the window
-        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w))
+        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)

         self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -218,14 +223,11 @@ class SwinTransformerBlock(nn.Module):
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
-        self.window_size = window_size
-        self.shift_size = shift_size
+        ws, ss = self._calc_window_shift(window_size, shift_size)
+        self.window_size: Tuple[int, int] = ws
+        self.shift_size: Tuple[int, int] = ss
+        self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
@@ -237,8 +239,8 @@ class SwinTransformerBlock(nn.Module):
             attn_drop=attn_drop,
             proj_drop=proj_drop,
         )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp = Mlp(
             in_features=dim,
@@ -246,66 +248,81 @@ class SwinTransformerBlock(nn.Module):
             act_layer=act_layer,
             drop=proj_drop,
         )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        if self.shift_size > 0:
+        if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
+            H = math.ceil(H / self.window_size[0]) * self.window_size[0]
+            W = math.ceil(W / self.window_size[1]) * self.window_size[1]
             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size),
-                    slice(-self.window_size, -self.shift_size),
-                    slice(-self.shift_size, None)):
+                    slice(0, -self.window_size[0]),
+                    slice(-self.window_size[0], -self.shift_size[0]),
+                    slice(-self.shift_size[0], None)):
                 for w in (
-                        slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None)):
+                        slice(0, -self.window_size[1]),
+                        slice(-self.window_size[1], -self.shift_size[1]),
+                        slice(-self.shift_size[1], None)):
                     img_mask[:, h, w, :] = cnt
                     cnt += 1
-            mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
-        self.register_buffer("attn_mask", attn_mask)
+        self.register_buffer("attn_mask", attn_mask, persistent=False)

-    def forward(self, x):
+    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        target_window_size = to_2tuple(target_window_size)
+        target_shift_size = to_2tuple(target_shift_size)
+        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
+        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
+        return tuple(window_size), tuple(shift_size)
+
+    def _attn(self, x):
         B, H, W, C = x.shape
-        _assert(H == self.input_resolution[0], "input feature has wrong size")
-        _assert(W == self.input_resolution[1], "input feature has wrong size")
-
-        shortcut = x
-        x = self.norm1(x)

         # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        has_shift = any(self.shift_size)
+        if has_shift:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
         else:
             shifted_x = x

+        # pad for resolution not divisible by window size
+        pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
+        pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
+        shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
+        Hp, Wp = H + pad_h, W + pad_w
+
         # partition windows
-        x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_win*B, window_size*window_size, C
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C

         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

         # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+        shifted_x = shifted_x[:, :H, :W, :].contiguous()

         # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        if has_shift:
+            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
         else:
             x = shifted_x
+        return x

-        # FFN
-        x = shortcut + self.drop_path(x)
+    def forward(self, x):
+        B, H, W, C = x.shape
+        x = x + self.drop_path1(self._attn(self.norm1(x)))
         x = x.reshape(B, -1, C)
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
         x = x.reshape(B, H, W, C)
         return x
@@ -385,6 +402,8 @@ class SwinTransformerStage(nn.Module):
         self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
         self.depth = depth
         self.grad_checkpointing = False
+        window_size = to_2tuple(window_size)
+        shift_size = tuple([w // 2 for w in window_size])

         # patch merging layer
         if downsample:
@@ -405,7 +424,7 @@
                 num_heads=num_heads,
                 head_dim=head_dim,
                 window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                shift_size=0 if (i % 2 == 0) else shift_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -499,7 +518,11 @@ class SwinTransformer(nn.Module):
         # build layers
         head_dim = to_ntuple(self.num_layers)(head_dim)
-        window_size = to_ntuple(self.num_layers)(window_size)
+        if not isinstance(window_size, (list, tuple)):
+            window_size = to_ntuple(self.num_layers)(window_size)
+        elif len(window_size) == 2:
+            window_size = (window_size,) * self.num_layers
+        assert len(window_size) == self.num_layers
         mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         layers = []
@@ -598,15 +621,30 @@
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    old_weights = True
     if 'head.fc.weight' in state_dict:
-        return state_dict
+        old_weights = False
     import re
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
     for k, v in state_dict.items():
-        k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
-        k = k.replace('head.', 'head.fc.')
+        if any([n in k for n in ('relative_position_index', 'attn_mask')]):
+            continue  # skip buffers that should not be persistent
+        if k.endswith('relative_position_bias_table'):
+            m = model.get_submodule(k[:-29])
+            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
+                v = resize_rel_pos_bias_table(
+                    v,
+                    new_window_size=m.window_size,
+                    new_bias_shape=m.relative_position_bias_table.shape,
+                )
+        if old_weights:
+            k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
+            k = k.replace('head.', 'head.fc.')
         out_dict[k] = v
     return out_dict
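
Taken together, the tuple window handling, padding, and this filter let a Swin model be instantiated at a new input and window size while reusing pretrained weights; a hedged sketch (name and sizes illustrative):

import timm

# 256px input with an 8x8 window: the pretrained 7x7-window bias tables (13x13)
# are resized to 15x15 by the relative_position_bias_table branch above
model = timm.create_model(
    'swin_base_patch4_window7_224.ms_in22k_ft_in1k',
    pretrained=True,
    img_size=256,
    window_size=8,
)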

View File

@@ -398,6 +398,8 @@ class SwinTransformerV2Stage(nn.Module):
         self.depth = depth
         self.output_nchw = output_nchw
         self.grad_checkpointing = False
+        window_size = to_2tuple(window_size)
+        shift_size = tuple([w // 2 for w in window_size])

         # patch merging / downsample layer
         if downsample:
@@ -413,7 +415,7 @@
                 input_resolution=self.output_resolution,
                 num_heads=num_heads,
                 window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                shift_size=0 if (i % 2 == 0) else shift_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -568,7 +570,7 @@ class SwinTransformerV2(nn.Module):
     def no_weight_decay(self):
         nod = set()
         for n, m in self.named_modules():
-            if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
+            if any([kw in n for kw in ("cpb_mlp", "logit_scale")]):
                 nod.add(n)
         return nod