Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Add support for resizing Swin Transformer img_size and window_size on init and when loading pretrained weights. Add support for non-square window_size to both Swin V1/V2.
parent 81089b10a2
commit 7790ea709b
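The practical effect is that a pretrained checkpoint can be loaded into a Swin model built at a different input or window size. A minimal usage sketch, assuming `timm.create_model` forwards the `img_size`/`window_size` overrides to the Swin constructor as this change intends; the specific sizes are illustrative only:

import timm
import torch

# Build a pretrained Swin at a non-default, non-square resolution and window.
# Relative position bias tables are interpolated and index/mask buffers rebuilt on load.
model = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=True,
    img_size=(256, 192),   # after this change, need not match the pretrained 224
    window_size=(8, 6),    # non-square window, one (h, w) pair applied to every stage
)
out = model(torch.randn(1, 3, 256, 192))
print(out.shape)  # torch.Size([1, 1000])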
@@ -38,23 +38,28 @@ _logger = logging.getLogger(__name__)
 
+_int_or_tuple_2_t = Union[int, Tuple[int, int]]
+
 
-def window_partition(x, window_size: int):
+def window_partition(
+        x: torch.Tensor,
+        window_size: Tuple[int, int],
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
     """
+    Partition into non-overlapping windows with padding if needed.
     Args:
-        x: (B, H, W, C)
-        window_size (int): window size
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
 
     Returns:
-        windows: (num_windows*B, window_size, window_size, C)
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
     """
     B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
     return windows
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size: int, H: int, W: int):
+def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int):
     """
     Args:
         windows: (num_windows*B, window_size, window_size, C)
@@ -66,7 +71,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
         x: (B, H, W, C)
     """
     C = windows.shape[-1]
-    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
     return x
 
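The tuple-indexed reshapes above can be checked in isolation. A small sketch in plain torch (sizes are arbitrary) that round-trips an NHWC feature map through a non-square 4x8 window grid the same way window_partition/window_reverse now do:

import torch

B, H, W, C = 2, 16, 32, 96
win_h, win_w = 4, 8
x = torch.randn(B, H, W, C)

# partition: (B, H, W, C) -> (B * num_windows, win_h, win_w, C)
windows = (
    x.view(B, H // win_h, win_h, W // win_w, win_w, C)
    .permute(0, 1, 3, 2, 4, 5).contiguous()
    .view(-1, win_h, win_w, C)
)

# reverse: back to (B, H, W, C)
y = (
    windows.view(-1, H // win_h, W // win_w, win_h, win_w, C)
    .permute(0, 1, 3, 2, 4, 5).contiguous()
    .view(-1, H, W, C)
)
assert torch.equal(x, y)  # partition followed by reverse is lossless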
@@ -124,7 +129,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads))
 
         # get pair-wise relative position index for each token inside the window
-        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w))
+        self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False)
 
         self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
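Registering relative_position_index with persistent=False (and attn_mask likewise, further down) keeps these derived buffers out of saved checkpoints, so weights stay loadable after the window geometry changes. The PyTorch behaviour being relied on, in miniature:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('kept', torch.zeros(3))                       # included in state_dict
        self.register_buffer('derived', torch.zeros(3), persistent=False)  # recomputed, never saved

print(Demo().state_dict().keys())  # odict_keys(['kept'])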
@@ -218,14 +223,11 @@ class SwinTransformerBlock(nn.Module):
         super().__init__()
         self.dim = dim
         self.input_resolution = input_resolution
-        self.window_size = window_size
-        self.shift_size = shift_size
+        ws, ss = self._calc_window_shift(window_size, shift_size)
+        self.window_size: Tuple[int, int] = ws
+        self.shift_size: Tuple[int, int] = ss
+        self.window_area = self.window_size[0] * self.window_size[1]
         self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
 
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
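_calc_window_shift replaces the old scalar clamp: per dimension, the window shrinks to the feature resolution when it would not fit, and shifting is disabled along any dimension covered by a single window. A standalone restatement of the same rule with hypothetical numbers:

from typing import Tuple

def calc_window_shift(
        input_resolution: Tuple[int, int],
        target_window_size: Tuple[int, int],
        target_shift_size: Tuple[int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    # clamp window to resolution, zero the shift where one window spans the whole dim
    window_size = [r if r <= w else w for r, w in zip(input_resolution, target_window_size)]
    shift_size = [0 if r <= w else s for r, w, s in zip(input_resolution, window_size, target_shift_size)]
    return tuple(window_size), tuple(shift_size)

print(calc_window_shift((6, 28), (7, 7), (3, 3)))  # ((6, 7), (0, 3))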
@@ -237,8 +239,8 @@ class SwinTransformerBlock(nn.Module):
             attn_drop=attn_drop,
             proj_drop=proj_drop,
         )
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.norm2 = norm_layer(dim)
         self.mlp = Mlp(
             in_features=dim,
@@ -246,66 +248,82 @@ class SwinTransformerBlock(nn.Module):
             act_layer=act_layer,
             drop=proj_drop,
         )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        if self.shift_size > 0:
+        if any(self.shift_size):
             # calculate attention mask for SW-MSA
             H, W = self.input_resolution
+            H = math.ceil(H / self.window_size[0]) * self.window_size[0]
+            W = math.ceil(W / self.window_size[1]) * self.window_size[1]
             img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
             cnt = 0
             for h in (
-                    slice(0, -self.window_size),
-                    slice(-self.window_size, -self.shift_size),
-                    slice(-self.shift_size, None)):
+                    slice(0, -self.window_size[0]),
+                    slice(-self.window_size[0], -self.shift_size[0]),
+                    slice(-self.shift_size[0], None)):
                 for w in (
-                        slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None)):
+                        slice(0, -self.window_size[1]),
+                        slice(-self.window_size[1], -self.shift_size[1]),
+                        slice(-self.shift_size[1], None)):
                     img_mask[:, h, w, :] = cnt
                     cnt += 1
-            mask_windows = window_partition(img_mask, self.window_size)  # num_win, window_size, window_size, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.window_area)
             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
-        self.register_buffer("attn_mask", attn_mask)
-
-    def forward(self, x):
+        self.register_buffer("attn_mask", attn_mask, persistent=False)
+
+    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        target_window_size = to_2tuple(target_window_size)
+        target_shift_size = to_2tuple(target_shift_size)
+        window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
+        shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
+        return tuple(window_size), tuple(shift_size)
+
+    def _attn(self, x):
         B, H, W, C = x.shape
-        _assert(H == self.input_resolution[0], "input feature has wrong size")
-        _assert(W == self.input_resolution[1], "input feature has wrong size")
-
-        shortcut = x
-        x = self.norm1(x)
 
         # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        has_shift = any(self.shift_size)
+        if has_shift:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
         else:
             shifted_x = x
 
+        # pad for resolution not divisible by window size
+        pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
+        pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
+        if pad_h > 0 or pad_w > 0:
+            shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
+        Hp, Wp = H + pad_h, W + pad_w
+
         # partition windows
-        x_windows = window_partition(shifted_x, self.window_size)  # num_win*B, window_size, window_size, C
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_win*B, window_size*window_size, C
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_area, C)  # nW*B, window_size*window_size, C
 
         # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_win*B, window_size*window_size, C
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
 
         # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+        shifted_x = shifted_x[:, :H, :W, :].contiguous()
 
         # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        if has_shift:
+            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
         else:
             x = shifted_x
+        return x
 
-        # FFN
-        x = shortcut + self.drop_path(x)
+    def forward(self, x):
+        B, H, W, C = x.shape
+        x = x + self.drop_path1(self._attn(self.norm1(x)))
         x = x.reshape(B, -1, C)
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
         x = x.reshape(B, H, W, C)
         return x
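The padding added in _attn removes the old requirement that H and W divide evenly by the window size: each spatial dim is padded up to the next multiple and the pad is cropped off after window_reverse. The pad arithmetic, checked on a few hypothetical sizes:

# pad = (window - dim % window) % window  ->  0 when dim is already a multiple
for dim, win in [(56, 7), (57, 7), (60, 8)]:
    pad = (win - dim % win) % win
    print(dim, win, pad, (dim + pad) % win == 0)
# 56 7 0 True
# 57 7 6 True
# 60 8 4 True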
@@ -385,6 +403,8 @@ class SwinTransformerStage(nn.Module):
         self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
         self.depth = depth
         self.grad_checkpointing = False
+        window_size = to_2tuple(window_size)
+        shift_size = tuple([w // 2 for w in window_size])
 
         # patch merging layer
         if downsample:
@@ -405,7 +425,7 @@
                 num_heads=num_heads,
                 head_dim=head_dim,
                 window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                shift_size=0 if (i % 2 == 0) else shift_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -499,7 +519,11 @@ class SwinTransformer(nn.Module):
 
         # build layers
         head_dim = to_ntuple(self.num_layers)(head_dim)
-        window_size = to_ntuple(self.num_layers)(window_size)
+        if not isinstance(window_size, (list, tuple)):
+            window_size = to_ntuple(self.num_layers)(window_size)
+        elif len(window_size) == 2:
+            window_size = (window_size,) * self.num_layers
+        assert len(window_size) == self.num_layers
         mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
         layers = []
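With this branch, the window_size argument to SwinTransformer accepts a single int, one (h, w) pair reused for every stage, or one entry per stage. A sketch of just the normalization, with num_layers fixed at 4 for illustration:

num_layers = 4

def normalize_window_size(window_size):
    # mirrors the constructor logic above, restated outside the class
    if not isinstance(window_size, (list, tuple)):
        window_size = (window_size,) * num_layers
    elif len(window_size) == 2:
        window_size = (window_size,) * num_layers
    assert len(window_size) == num_layers
    return window_size

print(normalize_window_size(7))              # (7, 7, 7, 7)
print(normalize_window_size((8, 6)))         # ((8, 6), (8, 6), (8, 6), (8, 6))
print(normalize_window_size((7, 7, 14, 7)))  # (7, 7, 14, 7)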
@@ -598,15 +622,34 @@
 
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    old_weights = True
     if 'head.fc.weight' in state_dict:
-        return state_dict
+        old_weights = False
     import re
+    current_state_dict = model.state_dict()
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
     for k, v in state_dict.items():
-        k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
-        k = k.replace('head.', 'head.fc.')
+        if any([n in k for n in ('relative_position_index', 'attn_mask')]):
+            continue  # skip buffers that should not be persistent
+
+        if k.endswith('relative_position_bias_table'):
+            m = model.get_submodule(k[:-29])
+            bias_size = tuple([2 * x - 1 for x in m.window_size])
+            old_len = int(len(v) ** 0.5)  # we have to assume pretrained weight is square right now
+            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
+                new_pos_bias = torch.nn.functional.interpolate(
+                    v.transpose(1, 0).reshape(1, -1, old_len, old_len),
+                    size=bias_size,
+                    mode="bicubic",
+                )
+                v = new_pos_bias.reshape(-1, bias_size[0] * bias_size[1]).transpose(0, 1)
+
+        if old_weights:
+            k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
+            k = k.replace('head.', 'head.fc.')
+
         out_dict[k] = v
     return out_dict
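When a checkpoint's relative_position_bias_table no longer matches the model (resized or non-square window), the filter above resamples it bicubically over the (2h-1) x (2w-1) grid of relative offsets. The same resampling on a toy tensor, for a 7x7 to 8x6 window change (random weights, shapes only):

import torch
import torch.nn.functional as F

num_heads = 4
old_win, new_win = (7, 7), (8, 6)
old_size = (2 * old_win[0] - 1, 2 * old_win[1] - 1)  # 13 x 13 relative offsets
new_size = (2 * new_win[0] - 1, 2 * new_win[1] - 1)  # 15 x 11

table = torch.randn(old_size[0] * old_size[1], num_heads)  # stored as (169, num_heads)

resized = F.interpolate(
    table.transpose(1, 0).reshape(1, num_heads, old_size[0], old_size[1]),
    size=new_size,
    mode='bicubic',
)
table = resized.reshape(num_heads, new_size[0] * new_size[1]).transpose(0, 1)
print(table.shape)  # torch.Size([165, 4])

The remaining hunks below apply the same per-stage window/shift handling to the Swin V2 stage.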
@@ -398,6 +398,8 @@ class SwinTransformerV2Stage(nn.Module):
         self.depth = depth
         self.output_nchw = output_nchw
         self.grad_checkpointing = False
+        window_size = to_2tuple(window_size)
+        shift_size = tuple([w // 2 for w in window_size])
 
         # patch merging / downsample layer
         if downsample:
@@ -413,7 +415,7 @@
                 input_resolution=self.output_resolution,
                 num_heads=num_heads,
                 window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                shift_size=0 if (i % 2 == 0) else shift_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,