mirror of https://github.com/huggingface/pytorch-image-models.git
set_input_size(), always_partition, strict_img_size, dynamic mask option for all swin models. More flexibility in resolution and window resizing.
This commit is contained in:
parent 2b3f1a4633
commit 4c531be479
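The headline addition is a set_input_size() method on all three Swin families. A minimal usage sketch, assuming the timm create_model API and a stock Swin V1 config (keyword names follow the diff below; the model name is just an example):

import torch
import timm

# Build a Swin V1 model at its native 224x224, window-7 configuration.
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)

# Re-target the model to 384x384 inputs; with window_ratio=8 the window is
# derived from the new patch grid (384 // 4 = 96, 96 // 8 = 12).
model.set_input_size(img_size=(384, 384), window_ratio=8)

with torch.no_grad():
    out = model(torch.randn(1, 3, 384, 384))
print(out.shape)  # expected: torch.Size([1, 1000])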
@@ -219,6 +219,7 @@ class SwinTransformerBlock(nn.Module):
window_size: _int_or_tuple_2_t = 7,
shift_size: int = 0,
always_partition: bool = False,
dynamic_mask: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
proj_drop: float = 0.,
@@ -235,6 +236,7 @@ class SwinTransformerBlock(nn.Module):
num_heads: Number of attention heads.
head_dim: Enforce the number of channels per head
shift_size: Shift size for SW-MSA.
always_partition: Always partition into full windows and shift
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
proj_drop: Dropout rate.
@@ -246,9 +248,10 @@ class SwinTransformerBlock(nn.Module):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.target_shift_size = to_2tuple(shift_size)
self.target_shift_size = to_2tuple(shift_size) # store for later resize
self.always_partition = always_partition
self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size)
self.dynamic_mask = dynamic_mask
self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size)
self.window_area = self.window_size[0] * self.window_size[1]
self.mlp_ratio = mlp_ratio

@@ -257,7 +260,7 @@ class SwinTransformerBlock(nn.Module):
dim,
num_heads=num_heads,
head_dim=head_dim,
window_size=to_2tuple(self.window_size),
window_size=self.window_size,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
@@ -273,25 +276,38 @@ class SwinTransformerBlock(nn.Module):
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

self._make_attention_mask()
self.register_buffer(
"attn_mask",
None if self.dynamic_mask else self.get_attn_mask(),
persistent=False,
)

def _make_attention_mask(self):
def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if any(self.shift_size):
# calculate attention mask for SW-MSA
H, W = self.input_resolution
if x is not None:
H, W = x.shape[1], x.shape[2]
device = x.device
dtype = x.dtype
else:
H, W = self.input_resolution
device = None
dtype = None
H = math.ceil(H / self.window_size[0]) * self.window_size[0]
W = math.ceil(W / self.window_size[1]) * self.window_size[1]
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None)):
(0, -self.window_size[0]),
(-self.window_size[0], -self.shift_size[0]),
(-self.shift_size[0], None),
):
for w in (
slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None)):
img_mask[:, h, w, :] = cnt
(0, -self.window_size[1]),
(-self.window_size[1], -self.shift_size[1]),
(-self.shift_size[1], None),
):
img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_area)
@@ -299,7 +315,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask, persistent=False)
return attn_mask
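For reference, the SW-MSA mask that get_attn_mask() builds above can be read as a standalone function; a rough, illustrative sketch with plain torch (function name and example sizes are not from the diff):

import math
import torch

def sw_msa_attn_mask(H, W, window_size=(7, 7), shift_size=(3, 3)):
    # Pad H, W up to window multiples, exactly as get_attn_mask() does.
    Hp = math.ceil(H / window_size[0]) * window_size[0]
    Wp = math.ceil(W / window_size[1]) * window_size[1]
    # Label each position with the index of its shifted-window group.
    img_mask = torch.zeros(1, Hp, Wp, 1)
    cnt = 0
    for hs, he in ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)):
        for ws, we in ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)):
            img_mask[:, hs:he, ws:we, :] = cnt
            cnt += 1
    # Partition into windows and compare group labels pairwise.
    wh, ww = window_size
    mask = img_mask.view(1, Hp // wh, wh, Wp // ww, ww, 1)
    mask = mask.permute(0, 1, 3, 2, 4, 5).reshape(-1, wh * ww)      # nW, wh*ww
    attn_mask = mask.unsqueeze(1) - mask.unsqueeze(2)               # nW, wh*ww, wh*ww
    return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)

# e.g. a 10x13 feature map with 7x7 windows is padded to 14x14 -> 4 windows
print(sw_msa_attn_mask(10, 13).shape)  # torch.Size([4, 49, 49])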
def _calc_window_shift(
self,
@@ -308,14 +324,16 @@ class SwinTransformerBlock(nn.Module):
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
if target_shift_size is None:
# if passed value is None, recalculate from default window_size // 2 if it was active
# if passed value is None, recalculate from default window_size // 2 if it was previously non-zero
target_shift_size = self.target_shift_size
if any(target_shift_size):
target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2]
target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)
else:
target_shift_size = to_2tuple(target_shift_size)

if self.always_partition:
return target_window_size, target_shift_size

window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return tuple(window_size), tuple(shift_size)
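The clamping rule in _calc_window_shift() is easy to misread; a small illustrative re-statement of the same rule as a free function, with a worked example (names are hypothetical):

from typing import Tuple

def calc_window_shift(
        feat_size: Tuple[int, int],
        target_window: Tuple[int, int],
        target_shift: Tuple[int, int],
        always_partition: bool = False,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    if always_partition:
        # keep full windows + shift even when the feature map is smaller
        return target_window, target_shift
    # otherwise shrink the window to the feature map, and drop the shift on
    # any axis where a single window already covers the whole map
    window = tuple(f if f <= w else w for f, w in zip(feat_size, target_window))
    shift = tuple(0 if f <= w else s for f, w, s in zip(feat_size, window, target_shift))
    return window, shift

print(calc_window_shift((6, 24), (7, 7), (3, 3)))                         # ((6, 7), (0, 3))
print(calc_window_shift((6, 24), (7, 7), (3, 3), always_partition=True))  # ((7, 7), (3, 3))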
@@ -338,7 +356,11 @@ class SwinTransformerBlock(nn.Module):
self.window_size, self.shift_size = self._calc_window_shift(window_size)
self.window_area = self.window_size[0] * self.window_size[1]
self.attn.set_window_size(self.window_size)
self._make_attention_mask()
self.register_buffer(
"attn_mask",
None if self.dynamic_mask else self.get_attn_mask(),
persistent=False,
)

def _attn(self, x):
B, H, W, C = x.shape
@@ -354,14 +376,18 @@ class SwinTransformerBlock(nn.Module):
pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
_, Hp, Wp, _ = shifted_x.shape

# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C

# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
if getattr(self, 'dynamic_mask', False):
attn_mask = self.get_attn_mask(shifted_x)
else:
attn_mask = self.attn_mask
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C

# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
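The _attn() change pads the (possibly shifted) map up to a window multiple and reads Hp, Wp back from the padded tensor, then crops after window_reverse(). A tiny sketch of that pad/crop pattern in plain torch (sizes are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 10, 13, 96)            # B, H, W, C feature map
win = (7, 7)
H, W = x.shape[1], x.shape[2]
pad_h = (win[0] - H % win[0]) % win[0]     # 4
pad_w = (win[1] - W % win[1]) % win[1]     # 1
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))   # pad pairs, last dim first: C, then W, then H
_, Hp, Wp, _ = x.shape                     # 14, 14
# ... windowed attention runs on the padded map ...
x = x[:, :H, :W, :].contiguous()           # crop back to the original size
print(Hp, Wp, x.shape)                     # 14 14 torch.Size([1, 10, 13, 96])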
@@ -408,8 +434,11 @@ class PatchMerging(nn.Module):

def forward(self, x):
B, H, W, C = x.shape
_assert(H % 2 == 0, f"x height ({H}) is not even.")
_assert(W % 2 == 0, f"x width ({W}) is not even.")

pad_values = (0, 0, 0, H % 2, 0, W % 2)
x = nn.functional.pad(x, pad_values)
_, H, W, _ = x.shape

x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.norm(x)
x = self.reduction(x)
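PatchMerging now pads odd feature maps instead of asserting evenness; a short standalone sketch of the padded 2x2 merge (plain torch, example sizes only):

import torch
import torch.nn.functional as F

x = torch.randn(2, 7, 9, 96)                 # B, H, W, C with odd spatial dims
B, H, W, C = x.shape
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))     # NHWC pad pairs: (C, C), (W, W), (H, H)
_, H, W, _ = x.shape                         # now 8, 10
x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
print(x.shape)                               # torch.Size([2, 4, 5, 384])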
@@ -431,6 +460,7 @@ class SwinTransformerStage(nn.Module):
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
always_partition: bool = False,
dynamic_mask: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
proj_drop: float = 0.,
@@ -485,6 +515,7 @@ class SwinTransformerStage(nn.Module):
window_size=window_size,
shift_size=0 if (i % 2 == 0) else shift_size,
always_partition=always_partition,
dynamic_mask=dynamic_mask,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop,
@@ -500,11 +531,12 @@ class SwinTransformerStage(nn.Module):
window_size: int,
always_partition: Optional[bool] = None,
):
"""Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
""" Updates the resolution, window size and so the pair-wise relative positions.

Args:
feat_size (Tuple[int, int]): New input resolution
window_size (int): New window size
feat_size: New input (feature) resolution
window_size: New window size
always_partition: Always partition / shift the window
"""
self.input_resolution = feat_size
if isinstance(self.downsample, nn.Identity):
@@ -548,6 +580,7 @@ class SwinTransformer(nn.Module):
head_dim: Optional[int] = None,
window_size: _int_or_tuple_2_t = 7,
always_partition: bool = False,
strict_img_size: bool = True,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
@@ -599,9 +632,10 @@ class SwinTransformer(nn.Module):
in_chans=in_chans,
embed_dim=embed_dim[0],
norm_layer=norm_layer,
strict_img_size=strict_img_size,
output_fmt='NHWC',
)
self.patch_grid = self.patch_embed.grid_size
patch_grid = self.patch_embed.grid_size

# build layers
head_dim = to_ntuple(self.num_layers)(head_dim)
@@ -621,8 +655,8 @@ class SwinTransformer(nn.Module):
dim=in_dim,
out_dim=out_dim,
input_resolution=(
self.patch_grid[0] // scale,
self.patch_grid[1] // scale
patch_grid[0] // scale,
patch_grid[1] // scale
),
depth=depths[i],
downsample=i > 0,
@@ -630,6 +664,7 @@ class SwinTransformer(nn.Module):
head_dim=head_dim[i],
window_size=window_size[i],
always_partition=always_partition,
dynamic_mask=not strict_img_size,
mlp_ratio=mlp_ratio[i],
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
@@ -673,27 +708,29 @@ class SwinTransformer(nn.Module):
img_size: Optional[Tuple[int, int]] = None,
patch_size: Optional[Tuple[int, int]] = None,
window_size: Optional[Tuple[int, int]] = None,
window_ratio: int = 32,
window_ratio: int = 8,
always_partition: Optional[bool] = None,
) -> None:
""" Updates the image resolution and window size.

Args:
img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
window_size (Optional[int]): New window size, if None based on new_img_size // window_div
window_ratio (int): divisor for calculating window size from image size
img_size: New input resolution, if None current resolution is used
patch_size (Optional[Tuple[int, int]): New patch size, if None use current patch size
window_size: New window size, if None based on new_img_size // window_div
window_ratio: divisor for calculating window size from grid size
always_partition: always partition into windows and shift (even if window size < feat size)
"""
if img_size is not None or patch_size is not None:
self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
self.patch_grid = self.patch_embed.grid_size
patch_grid = self.patch_embed.grid_size

if window_size is None:
img_size = self.patch_embed.img_size
window_size = tuple([s // window_ratio for s in img_size])
window_size = tuple([pg // window_ratio for pg in patch_grid])

for index, stage in enumerate(self.layers):
stage_scale = 2 ** max(index - 1, 0)
print(self.patch_grid, stage_scale)
stage.set_input_size(
feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale),
feat_size=(patch_grid[0] // stage_scale, patch_grid[1] // stage_scale),
window_size=window_size,
always_partition=always_partition,
)
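The other notable fix in set_input_size() above: the default window is now derived from the patch grid rather than the raw image size, which is why window_ratio's default drops from 32 to 8. A quick arithmetic check, assuming patch size 4:

img_size = (256, 256)
patch_size = 4
patch_grid = tuple(s // patch_size for s in img_size)    # (64, 64)
old_default = tuple(s // 32 for s in img_size)           # window from image size: (8, 8)
new_default = tuple(g // 8 for g in patch_grid)          # window from patch grid:  (8, 8)
print(patch_grid, old_default, new_default)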
@@ -94,7 +94,7 @@ class WindowAttention(nn.Module):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.pretrained_window_size = to_2tuple(pretrained_window_size)
self.num_heads = num_heads
self.qkv_bias_separate = qkv_bias_separate

@@ -107,21 +107,37 @@ class WindowAttention(nn.Module):
nn.Linear(512, num_heads, bias=False)
)

self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim))
self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)

self._make_pair_wise_relative_positions()

def _make_pair_wise_relative_positions(self):
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
if self.pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (self.pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (self.pretrained_window_size[1] - 1)
else:
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
torch.abs(relative_coords_table) + 1.0) / math.log2(8)

self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

# get pair-wise relative position index for each token inside the window
@@ -137,19 +153,15 @@ class WindowAttention(nn.Module):
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index, persistent=False)

self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim))
self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def set_window_size(self, window_size: Tuple[int, int]) -> None:
"""Update window size & interpolate position embeddings
Args:
window_size (int): New window size
"""
window_size = to_2tuple(window_size)
if window_size != self.window_size:
self.window_size = window_size
self._make_pair_wise_relative_positions()
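WindowAttention.set_window_size() above rebuilds the log-spaced relative-coordinate table whenever the window changes; a condensed sketch of that table, using torch.meshgrid in place of timm's ndgrid helper (illustrative, not the module code):

import math
import torch

def log_coords_table(window_size, pretrained_window_size=(0, 0)):
    wh, ww = window_size
    ch = torch.arange(-(wh - 1), wh, dtype=torch.float32)
    cw = torch.arange(-(ww - 1), ww, dtype=torch.float32)
    table = torch.stack(torch.meshgrid(ch, cw, indexing='ij'))
    table = table.permute(1, 2, 0).contiguous().unsqueeze(0)    # 1, 2*Wh-1, 2*Ww-1, 2
    denom = pretrained_window_size if pretrained_window_size[0] > 0 else window_size
    table[:, :, :, 0] /= (denom[0] - 1)
    table[:, :, :, 1] /= (denom[1] - 1)
    table *= 8  # normalize to [-8, 8]
    return torch.sign(table) * torch.log2(table.abs() + 1.0) / math.log2(8)

print(log_coords_table((7, 7)).shape)                                    # torch.Size([1, 13, 13, 2])
print(log_coords_table((12, 12), pretrained_window_size=(7, 7)).shape)   # torch.Size([1, 23, 23, 2])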
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
@@ -210,6 +222,8 @@ class SwinTransformerV2Block(nn.Module):
num_heads: int,
window_size: _int_or_tuple_2_t = 7,
shift_size: _int_or_tuple_2_t = 0,
always_partition: bool = False,
dynamic_mask: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
proj_drop: float = 0.,
@@ -218,7 +232,7 @@ class SwinTransformerV2Block(nn.Module):
act_layer: LayerType = "gelu",
norm_layer: nn.Module = nn.LayerNorm,
pretrained_window_size: _int_or_tuple_2_t = 0,
) -> None:
):
"""
Args:
dim: Number of input channels.
@@ -226,6 +240,7 @@ class SwinTransformerV2Block(nn.Module):
num_heads: Number of attention heads.
window_size: Window size.
shift_size: Shift size for SW-MSA.
always_partition: Always partition into full windows and shift
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
proj_drop: Dropout rate.
@@ -239,9 +254,10 @@ class SwinTransformerV2Block(nn.Module):
self.dim = dim
self.input_resolution = to_2tuple(input_resolution)
self.num_heads = num_heads
ws, ss = self._calc_window_shift(window_size, shift_size)
self.window_size: Tuple[int, int] = ws
self.shift_size: Tuple[int, int] = ss
self.target_shift_size = to_2tuple(shift_size) # store for later resize
self.always_partition = always_partition
self.dynamic_mask = dynamic_mask
self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size)
self.window_area = self.window_size[0] * self.window_size[1]
self.mlp_ratio = mlp_ratio
act_layer = get_act_layer(act_layer)
@@ -267,20 +283,31 @@ class SwinTransformerV2Block(nn.Module):
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

self.register_buffer(
"attn_mask",
None if self.dynamic_mask else self.get_attn_mask(),
persistent=False,
)

def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if any(self.shift_size):
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
if x is None:
img_mask = torch.zeros((1, *self.input_resolution, 1)) # 1 H W 1
else:
img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None)):
(0, -self.window_size[0]),
(-self.window_size[0], -self.shift_size[0]),
(-self.shift_size[0], None),
):
for w in (
slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None)):
img_mask[:, h, w, :] = cnt
(0, -self.window_size[1]),
(-self.window_size[1], -self.shift_size[1]),
(-self.shift_size[1], None),
):
img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_area)
@@ -288,18 +315,60 @@ class SwinTransformerV2Block(nn.Module):
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask

self.register_buffer("attn_mask", attn_mask, persistent=False)
def _calc_window_shift(
self,
target_window_size: _int_or_tuple_2_t,
target_shift_size: Optional[_int_or_tuple_2_t] = None,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
if target_shift_size is None:
# if passed value is None, recalculate from default window_size // 2 if it was active
target_shift_size = self.target_shift_size
if any(target_shift_size):
# if there was previously a non-zero shift, recalculate based on current window_size
target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)
else:
target_shift_size = to_2tuple(target_shift_size)

if self.always_partition:
print('ap', target_window_size, target_shift_size)
return target_window_size, target_shift_size

def _calc_window_shift(self,
target_window_size: _int_or_tuple_2_t,
target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
target_shift_size = to_2tuple(target_shift_size)
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
print('nap', window_size, shift_size)
return tuple(window_size), tuple(shift_size)

def set_input_size(
self,
feat_size: Tuple[int, int],
window_size: Tuple[int, int],
always_partition: Optional[bool] = None,
):
""" Updates the input resolution, window size.

Args:
feat_size (Tuple[int, int]): New input resolution
window_size (int): New window size
always_partition: Change always_partition attribute if not None
"""
# Update input resolution
self.input_resolution = feat_size
if always_partition is not None:
self.always_partition = always_partition
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
self.window_area = self.window_size[0] * self.window_size[1]
self.attn.set_window_size(self.window_size)
self.register_buffer(
"attn_mask",
None if self.dynamic_mask else self.get_attn_mask(),
persistent=False,
)

def _attn(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, C = x.shape

@@ -310,16 +379,26 @@ class SwinTransformerV2Block(nn.Module):
else:
shifted_x = x

pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
_, Hp, Wp, _ = shifted_x.shape

# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C

# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
if getattr(self, 'dynamic_mask', False):
attn_mask = self.get_attn_mask(shifted_x)
else:
attn_mask = self.attn_mask
attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C

# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution) # B H' W' C
shifted_x = window_reverse(attn_windows, self.window_size, (Hp, Wp)) # B H' W' C
shifted_x = shifted_x[:, :H, :W, :].contiguous()

# reverse cyclic shift
if has_shift:
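That forward-time selection is the whole point of dynamic_mask: the registered buffer is only valid for the construction-time resolution, so the dynamic path rebuilds the mask from the padded tensor on every call. A toy module showing the pattern (names are illustrative and the mask itself is a stand-in):

import torch
import torch.nn as nn

class MaskedBlock(nn.Module):
    """Toy block illustrating buffered vs. per-forward ("dynamic") attention masks."""

    def __init__(self, feat_size=(56, 56), dynamic_mask: bool = False):
        super().__init__()
        self.feat_size = feat_size
        self.dynamic_mask = dynamic_mask
        # Static: compute once for the construction-time size. Dynamic: cache nothing.
        self.register_buffer('attn_mask', None if dynamic_mask else self.get_attn_mask(), persistent=False)

    def get_attn_mask(self, x: torch.Tensor = None) -> torch.Tensor:
        H, W = (x.shape[1], x.shape[2]) if x is not None else self.feat_size
        return torch.zeros(H * W)  # stand-in for the real SW-MSA mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same selection the diff adds to _attn(): rebuild from the tensor if dynamic.
        attn_mask = self.get_attn_mask(x) if self.dynamic_mask else self.attn_mask
        # a real block would pass attn_mask to windowed attention here
        return x

blk = MaskedBlock(dynamic_mask=True)
print(blk(torch.randn(1, 60, 44, 96)).shape)  # works at sizes the static buffer never saw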
@@ -341,7 +420,12 @@ class PatchMerging(nn.Module):
""" Patch Merging Layer.
"""

def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None:
def __init__(
self,
dim: int,
out_dim: Optional[int] = None,
norm_layer: nn.Module = nn.LayerNorm
):
"""
Args:
dim (int): Number of input channels.
@@ -356,8 +440,11 @@ class PatchMerging(nn.Module):

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, C = x.shape
_assert(H % 2 == 0, f"x height ({H}) is not even.")
_assert(W % 2 == 0, f"x width ({W}) is not even.")

pad_values = (0, 0, 0, H % 2, 0, W % 2)
x = nn.functional.pad(x, pad_values)
_, H, W, _ = x.shape

x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.reduction(x)
x = self.norm(x)
@@ -376,6 +463,8 @@ class SwinTransformerV2Stage(nn.Module):
depth: int,
num_heads: int,
window_size: _int_or_tuple_2_t,
always_partition: bool = False,
dynamic_mask: bool = False,
downsample: bool = False,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
@@ -395,6 +484,8 @@ class SwinTransformerV2Stage(nn.Module):
depth: Number of blocks.
num_heads: Number of attention heads.
window_size: Local window size.
always_partition: Always partition into full windows and shift
dynamic_mask: Create attention mask in forward based on current input size
downsample: Use downsample layer at start of the block.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
@@ -431,6 +522,8 @@ class SwinTransformerV2Stage(nn.Module):
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else shift_size,
always_partition=always_partition,
dynamic_mask=dynamic_mask,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop,
@@ -442,6 +535,32 @@ class SwinTransformerV2Stage(nn.Module):
)
for i in range(depth)])

def set_input_size(
self,
feat_size: Tuple[int, int],
window_size: int,
always_partition: Optional[bool] = None,
):
""" Updates the resolution, window size and so the pair-wise relative positions.

Args:
feat_size: New input (feature) resolution
window_size: New window size
always_partition: Always partition / shift the window
"""
self.input_resolution = feat_size
if isinstance(self.downsample, nn.Identity):
self.output_resolution = feat_size
else:
assert isinstance(self.downsample, PatchMerging)
self.output_resolution = tuple(i // 2 for i in feat_size)
for block in self.blocks:
block.set_input_size(
feat_size=self.output_resolution,
window_size=window_size,
always_partition=always_partition,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.downsample(x)

@@ -478,6 +597,8 @@ class SwinTransformerV2(nn.Module):
depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
window_size: _int_or_tuple_2_t = 7,
always_partition: bool = False,
strict_img_size: bool = True,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
drop_rate: float = 0.,
@@ -532,8 +653,10 @@ class SwinTransformerV2(nn.Module):
in_chans=in_chans,
embed_dim=embed_dim[0],
norm_layer=norm_layer,
strict_img_size=strict_img_size,
output_fmt='NHWC',
)
grid_size = self.patch_embed.grid_size

dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
layers = []
@@ -544,13 +667,13 @@ class SwinTransformerV2(nn.Module):
layers += [SwinTransformerV2Stage(
dim=in_dim,
out_dim=out_dim,
input_resolution=(
self.patch_embed.grid_size[0] // scale,
self.patch_embed.grid_size[1] // scale),
input_resolution=(grid_size[0] // scale, grid_size[1] // scale),
depth=depths[i],
downsample=i > 0,
num_heads=num_heads[i],
window_size=window_size,
always_partition=always_partition,
dynamic_mask=not strict_img_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_drop=proj_drop_rate,
@@ -585,6 +708,38 @@ class SwinTransformerV2(nn.Module):
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)

def set_input_size(
self,
img_size: Optional[Tuple[int, int]] = None,
patch_size: Optional[Tuple[int, int]] = None,
window_size: Optional[Tuple[int, int]] = None,
window_ratio: Optional[int] = 8,
always_partition: Optional[bool] = None,
):
"""Updates the image resolution, window size, and so the pair-wise relative positions.

Args:
img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
patch_size (Optional[Tuple[int, int]): New patch size, if None use current patch size
window_size (Optional[int]): New window size, if None based on new_img_size // window_div
window_ratio (int): divisor for calculating window size from patch grid size
always_partition: always partition / shift windows even if feat size is < window
"""
if img_size is not None or patch_size is not None:
self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
grid_size = self.patch_embed.grid_size

if window_size is None and window_ratio is not None:
window_size = tuple([s // window_ratio for s in grid_size])

for index, stage in enumerate(self.layers):
stage_scale = 2 ** max(index - 1, 0)
stage.set_input_size(
feat_size=(grid_size[0] // stage_scale, grid_size[1] // stage_scale),
window_size=window_size,
always_partition=always_partition,
)

@torch.jit.ignore
def no_weight_decay(self):
nod = set()
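Across all three files the wiring is the same: strict_img_size=False relaxes the patch embed's exact-size check, and each stage gets dynamic_mask=not strict_img_size so masks follow the runtime input. A hedged usage sketch; the model name and kwarg pass-through via timm.create_model are assumptions:

import torch
import timm

# Build a Swin V2 at 256x256 but allow other (padding-compatible) input sizes.
model = timm.create_model(
    'swinv2_tiny_window8_256',
    pretrained=False,
    strict_img_size=False,   # PatchEmbed no longer asserts H/W; masks become dynamic
)

with torch.no_grad():
    print(model(torch.randn(1, 3, 256, 256)).shape)
    # a different size; with dynamic masks the SW-MSA masks are rebuilt on the fly
    print(model(torch.randn(1, 3, 288, 288)).shape)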
@@ -231,26 +231,30 @@ class SwinTransformerV2CrBlock(nn.Module):
"""

def __init__(
self,
dim: int,
num_heads: int,
feat_size: Tuple[int, int],
window_size: Tuple[int, int],
shift_size: Tuple[int, int] = (0, 0),
mlp_ratio: float = 4.0,
init_values: Optional[float] = 0,
proj_drop: float = 0.0,
drop_attn: float = 0.0,
drop_path: float = 0.0,
extra_norm: bool = False,
sequential_attn: bool = False,
norm_layer: Type[nn.Module] = nn.LayerNorm,
) -> None:
self,
dim: int,
num_heads: int,
feat_size: Tuple[int, int],
window_size: Tuple[int, int],
shift_size: Tuple[int, int] = (0, 0),
always_partition: bool = False,
dynamic_mask: bool = False,
mlp_ratio: float = 4.0,
init_values: Optional[float] = 0,
proj_drop: float = 0.0,
drop_attn: float = 0.0,
drop_path: float = 0.0,
extra_norm: bool = False,
sequential_attn: bool = False,
norm_layer: Type[nn.Module] = nn.LayerNorm,
):
super(SwinTransformerV2CrBlock, self).__init__()
self.dim: int = dim
self.feat_size: Tuple[int, int] = feat_size
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
self.always_partition = always_partition
self.dynamic_mask = dynamic_mask
self.window_size, self.shift_size = self._calc_window_shift(window_size)
self.window_area = self.window_size[0] * self.window_size[1]
self.init_values: Optional[float] = init_values

@@ -280,31 +284,53 @@ class SwinTransformerV2CrBlock(nn.Module):
# Also being used as final network norm and optional stage ending norm while still in a C-last format.
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()

self._make_attention_mask()
self.register_buffer(
"attn_mask",
None if self.dynamic_mask else self.get_attn_mask(),
persistent=False,
)
print(self.dynamic_mask)

self.init_weights()

def _calc_window_shift(self, target_window_size):
def _calc_window_shift(
self,
target_window_size: Tuple[int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
target_shift_size = self.target_shift_size
if any(target_shift_size):
# if non-zero, recalculate shift from current window size in case window size has changed
target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)

if self.always_partition:
return target_window_size, target_shift_size

window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)]
shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, target_shift_size)]
return tuple(window_size), tuple(shift_size)

def _make_attention_mask(self) -> None:
def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
"""Method generates the attention mask used in shift case."""
# Make masks for shift case
if any(self.shift_size):
# calculate attention mask for SW-MSA
H, W = self.feat_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
if x is None:
img_mask = torch.zeros((1, *self.feat_size, 1)) # 1 H W 1
else:
img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None)):
(0, -self.window_size[0]),
(-self.window_size[0], -self.shift_size[0]),
(-self.shift_size[0], None),
):
for w in (
slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None)):
img_mask[:, h, w, :] = cnt
(0, -self.window_size[1]),
(-self.window_size[1], -self.shift_size[1]),
(-self.shift_size[1], None),
):
img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # num_windows, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_area)
@@ -312,7 +338,7 @@ class SwinTransformerV2CrBlock(nn.Module):
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask, persistent=False)
return attn_mask

def init_weights(self):
# extra, module specific weight init
@@ -332,7 +358,11 @@ class SwinTransformerV2CrBlock(nn.Module):
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
self.window_area = self.window_size[0] * self.window_size[1]
self.attn.set_window_size(self.window_size)
self._make_attention_mask()
self.register_buffer(
"attn_mask",
None if self.dynamic_mask else self.get_attn_mask(),
persistent=False,
)

def _shifted_window_attn(self, x):
B, H, W, C = x.shape
@@ -346,16 +376,26 @@ class SwinTransformerV2CrBlock(nn.Module):
# x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))

pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))
_, Hp, Wp, _ = x.shape

# partition windows
x_windows = window_partition(x, self.window_size) # num_windows * B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)

# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_windows * B, window_size * window_size, C
if getattr(self, 'dynamic_mask', False):
attn_mask = self.get_attn_mask(x)
else:
attn_mask = self.attn_mask
attn_windows = self.attn(x_windows, mask=attn_mask) # num_windows * B, window_size * window_size, C

# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C
x = window_reverse(attn_windows, self.window_size, (Hp, Wp)) # B H' W' C
x = x[:, :H, :W, :].contiguous()

# reverse cyclic shift
if do_shift:
@@ -406,6 +446,11 @@ class PatchMerging(nn.Module):
output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
"""
B, H, W, C = x.shape

pad_values = (0, 0, 0, H % 2, 0, W % 2)
x = nn.functional.pad(x, pad_values)
_, H, W, _ = x.shape

x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
x = self.norm(x)
x = self.reduction(x)
@@ -414,7 +459,15 @@ class PatchMerging(nn.Module):

class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
strict_img_size=True,
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@@ -422,14 +475,23 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.strict_img_size = strict_img_size

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def set_input_size(self, img_size: Tuple[int, int]):
img_size = to_2tuple(img_size)
if img_size != self.img_size:
self.img_size = img_size
self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]

def forward(self, x):
B, C, H, W = x.shape
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
if self.strict_img_size:
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
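The CR file keeps its own minimal PatchEmbed, now with the same strict_img_size switch; a compact standalone sketch of that behaviour (toy module, not the diff's class):

import torch
import torch.nn as nn

class TinyPatchEmbed(nn.Module):
    """Minimal patchify stem with an optional exact-size check (strict_img_size)."""

    def __init__(self, img_size=(224, 224), patch_size=4, embed_dim=96, strict_img_size=True):
        super().__init__()
        self.img_size = img_size
        self.strict_img_size = strict_img_size
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.strict_img_size:
            assert (H, W) == self.img_size, f"input {H}x{W} != configured {self.img_size}"
        return self.proj(x)

embed = TinyPatchEmbed(strict_img_size=False)
print(embed(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 96, 56, 56])
print(embed(torch.randn(1, 3, 160, 192)).shape)  # torch.Size([1, 96, 40, 48])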
@@ -463,6 +525,8 @@ class SwinTransformerV2CrStage(nn.Module):
num_heads: int,
feat_size: Tuple[int, int],
window_size: Tuple[int, int],
always_partition: bool = False,
dynamic_mask: bool = False,
mlp_ratio: float = 4.0,
init_values: Optional[float] = 0.0,
proj_drop: float = 0.0,
@@ -472,7 +536,7 @@ class SwinTransformerV2CrStage(nn.Module):
extra_norm_period: int = 0,
extra_norm_stage: bool = False,
sequential_attn: bool = False,
) -> None:
):
super(SwinTransformerV2CrStage, self).__init__()
self.downscale: bool = downscale
self.grad_checkpointing: bool = False
@@ -496,6 +560,8 @@ class SwinTransformerV2CrStage(nn.Module):
num_heads=num_heads,
feat_size=self.feat_size,
window_size=window_size,
always_partition=always_partition,
dynamic_mask=dynamic_mask,
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
mlp_ratio=mlp_ratio,
init_values=init_values,
@@ -509,18 +575,24 @@ class SwinTransformerV2CrStage(nn.Module):
for index in range(depth)]
)

def set_input_size(self, feat_size: Tuple[int, int], window_size: int) -> None:
"""Method updates the resolution to utilize and the window size and so the pair-wise relative positions.
def set_input_size(
self,
feat_size: Tuple[int, int],
window_size: int,
always_partition: Optional[bool] = None,
):
""" Updates the resolution to utilize and the window size and so the pair-wise relative positions.

Args:
window_size (int): New window size
feat_size (Tuple[int, int]): New input resolution
"""
self.feat_size: Tuple[int, int] = (
(feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size
)
self.feat_size = (feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size
for block in self.blocks:
block.set_input_size(feat_size=self.feat_size, window_size=window_size)
block.set_input_size(
feat_size=self.feat_size,
window_size=window_size,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
@@ -548,8 +620,8 @@ class SwinTransformerV2Cr(nn.Module):

Args:
img_size: Input resolution.
window_size: Window size. If None, img_size // window_div
img_window_ratio: Window size to image size ratio.
window_size: Window size. If None, grid_size // window_div
window_ratio: Window size to patch grid ratio.
patch_size: Patch size.
in_chans: Number of input channels.
depths: Depth of the stage (number of layers).
@@ -572,7 +644,9 @@ class SwinTransformerV2Cr(nn.Module):
img_size: Tuple[int, int] = (224, 224),
patch_size: int = 4,
window_size: Optional[int] = None,
img_window_ratio: int = 32,
window_ratio: int = 8,
always_partition: bool = False,
strict_img_size: bool = True,
in_chans: int = 3,
num_classes: int = 1000,
embed_dim: int = 96,
@@ -594,13 +668,9 @@ class SwinTransformerV2Cr(nn.Module):
) -> None:
super(SwinTransformerV2Cr, self).__init__()
img_size = to_2tuple(img_size)
window_size = tuple([
s // img_window_ratio for s in img_size]) if window_size is None else to_2tuple(window_size)

self.num_classes: int = num_classes
self.patch_size: int = patch_size
self.img_size: Tuple[int, int] = img_size
self.window_size: int = window_size
self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
self.feature_info = []

@@ -610,8 +680,13 @@ class SwinTransformerV2Cr(nn.Module):
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer,
strict_img_size=strict_img_size,
)
patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
grid_size = self.patch_embed.grid_size
if window_size is None:
self.window_size = tuple([s // window_ratio for s in grid_size])
else:
self.window_size = to_2tuple(window_size)

dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
@@ -622,12 +697,11 @@ class SwinTransformerV2Cr(nn.Module):
embed_dim=in_dim,
depth=depth,
downscale=stage_idx != 0,
feat_size=(
patch_grid_size[0] // in_scale,
patch_grid_size[1] // in_scale
),
feat_size=(grid_size[0] // in_scale, grid_size[1] // in_scale),
num_heads=num_heads,
window_size=window_size,
window_size=self.window_size,
always_partition=always_partition,
dynamic_mask=not strict_img_size,
mlp_ratio=mlp_ratio,
init_values=init_values,
proj_drop=proj_drop_rate,
@@ -660,28 +734,30 @@ class SwinTransformerV2Cr(nn.Module):
self,
img_size: Optional[Tuple[int, int]] = None,
window_size: Optional[Tuple[int, int]] = None,
window_ratio: int = 32,
window_ratio: int = 8,
always_partition: Optional[bool] = None,
) -> None:
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
"""Updates the image resolution, window size and so the pair-wise relative positions.

Args:
img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
window_size (Optional[int]): New window size, if None based on new_img_size // window_div
window_ratio (int): divisor for calculating window size from image size
window_ratio (int): divisor for calculating window size from patch grid size
always_partition: always partition / shift windows even if feat size is < window
"""
if img_size is None:
img_size = self.img_size
else:
img_size = to_2tuple(img_size)
if window_size is None:
window_size = tuple([s // window_ratio for s in img_size])
# Compute new patch resolution & update resolution of each stage
patch_grid_size = (img_size[0] // self.patch_size, img_size[1] // self.patch_size)
if img_size is not None:
self.patch_embed.set_input_size(img_size=img_size)
grid_size = self.patch_embed.grid_size

if window_size is None and window_ratio is not None:
window_size = tuple([s // window_ratio for s in grid_size])

for index, stage in enumerate(self.stages):
stage_scale = 2 ** max(index - 1, 0)
stage.set_input_size(
feat_size=(grid_size[0] // stage_scale, grid_size[1] // stage_scale),
window_size=window_size,
always_partition=always_partition,
)

@torch.jit.ignore
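Putting the CR-side pieces together, resizing now routes through patch_embed.set_input_size() and the per-stage updates rather than recomputing the grid by hand; a hedged end-to-end sketch (model name assumed to exist in timm, expected output shape noted in the comment):

import torch
import timm

model = timm.create_model('swinv2_cr_tiny_ns_224', pretrained=False)

# 320 // 4 = 80 patch grid, 80 // 8 = 10 window per the new window_ratio default.
model.set_input_size(img_size=(320, 320), window_ratio=8, always_partition=False)

with torch.no_grad():
    print(model(torch.randn(1, 3, 320, 320)).shape)  # expected: torch.Size([1, 1000])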