From 4c531be4790bd3f354ed373568712e25eb404c57 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 24 Jul 2024 16:41:31 -0700 Subject: [PATCH] set_input_size(), always_partition, strict_img_size, dynamic mask option for all swin models. More flexibility in resolution, window resizing. --- timm/models/swin_transformer.py | 111 ++++++++---- timm/models/swin_transformer_v2.py | 241 +++++++++++++++++++++----- timm/models/swin_transformer_v2_cr.py | 212 ++++++++++++++-------- 3 files changed, 416 insertions(+), 148 deletions(-) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 71c4e639..7b396e4a 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -219,6 +219,7 @@ class SwinTransformerBlock(nn.Module): window_size: _int_or_tuple_2_t = 7, shift_size: int = 0, always_partition: bool = False, + dynamic_mask: bool = False, mlp_ratio: float = 4., qkv_bias: bool = True, proj_drop: float = 0., @@ -235,6 +236,7 @@ class SwinTransformerBlock(nn.Module): num_heads: Number of attention heads. head_dim: Enforce the number of channels per head shift_size: Shift size for SW-MSA. + always_partition: Always partition into full windows and shift mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. proj_drop: Dropout rate. @@ -246,9 +248,10 @@ class SwinTransformerBlock(nn.Module): super().__init__() self.dim = dim self.input_resolution = input_resolution - self.target_shift_size = to_2tuple(shift_size) + self.target_shift_size = to_2tuple(shift_size) # store for later resize self.always_partition = always_partition - self.window_size, self.shift_size = self._calc_window_shift(window_size, target_shift_size=shift_size) + self.dynamic_mask = dynamic_mask + self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size) self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio @@ -257,7 +260,7 @@ class SwinTransformerBlock(nn.Module): dim, num_heads=num_heads, head_dim=head_dim, - window_size=to_2tuple(self.window_size), + window_size=self.window_size, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, @@ -273,25 +276,38 @@ class SwinTransformerBlock(nn.Module): ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self._make_attention_mask() + self.register_buffer( + "attn_mask", + None if self.dynamic_mask else self.get_attn_mask(), + persistent=False, + ) - def _make_attention_mask(self): + def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: if any(self.shift_size): # calculate attention mask for SW-MSA - H, W = self.input_resolution + if x is not None: + H, W = x.shape[1], x.shape[2] + device = x.device + dtype = x.dtype + else: + H, W = self.input_resolution + device = None + dtype = None H = math.ceil(H / self.window_size[0]) * self.window_size[0] W = math.ceil(W / self.window_size[1]) * self.window_size[1] - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device) # 1 H W 1 cnt = 0 for h in ( - slice(0, -self.window_size[0]), - slice(-self.window_size[0], -self.shift_size[0]), - slice(-self.shift_size[0], None)): + (0, -self.window_size[0]), + (-self.window_size[0], -self.shift_size[0]), + (-self.shift_size[0], None), + ): for w in ( - slice(0, -self.window_size[1]), - slice(-self.window_size[1], -self.shift_size[1]), - slice(-self.shift_size[1], None)): - img_mask[:, h, w, :] = cnt + (0, -self.window_size[1]), + (-self.window_size[1], -self.shift_size[1]), + (-self.shift_size[1], None), + ): + img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_area) @@ -299,7 +315,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - self.register_buffer("attn_mask", attn_mask, persistent=False) + return attn_mask def _calc_window_shift( self, @@ -308,14 +324,16 @@ class SwinTransformerBlock(nn.Module): ) -> Tuple[Tuple[int, int], Tuple[int, int]]: target_window_size = to_2tuple(target_window_size) if target_shift_size is None: - # if passed value is None, recalculate from default window_size // 2 if it was active + # if passed value is None, recalculate from default window_size // 2 if it was previously non-zero target_shift_size = self.target_shift_size if any(target_shift_size): - target_shift_size = [target_window_size[0] // 2, target_window_size[1] // 2] + target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2) else: target_shift_size = to_2tuple(target_shift_size) + if self.always_partition: return target_window_size, target_shift_size + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] return tuple(window_size), tuple(shift_size) @@ -338,7 +356,11 @@ class SwinTransformerBlock(nn.Module): self.window_size, self.shift_size = self._calc_window_shift(window_size) self.window_area = self.window_size[0] * self.window_size[1] self.attn.set_window_size(self.window_size) - self._make_attention_mask() + self.register_buffer( + "attn_mask", + None if self.dynamic_mask else self.get_attn_mask(), + persistent=False, + ) def _attn(self, x): B, H, W, C = x.shape @@ -354,14 +376,18 @@ class SwinTransformerBlock(nn.Module): pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h)) - Hp, Wp = H + pad_h, W + pad_w + _, Hp, Wp, _ = shifted_x.shape # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + if getattr(self, 'dynamic_mask', False): + attn_mask = self.get_attn_mask(shifted_x) + else: + attn_mask = self.attn_mask + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) @@ -408,8 +434,11 @@ class PatchMerging(nn.Module): def forward(self, x): B, H, W, C = x.shape - _assert(H % 2 == 0, f"x height ({H}) is not even.") - _assert(W % 2 == 0, f"x width ({W}) is not even.") + + pad_values = (0, 0, 0, H % 2, 0, W % 2) + x = nn.functional.pad(x, pad_values) + _, H, W, _ = x.shape + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) x = self.norm(x) x = self.reduction(x) @@ -431,6 +460,7 @@ class SwinTransformerStage(nn.Module): head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, always_partition: bool = False, + dynamic_mask: bool = False, mlp_ratio: float = 4., qkv_bias: bool = True, proj_drop: float = 0., @@ -485,6 +515,7 @@ class SwinTransformerStage(nn.Module): window_size=window_size, shift_size=0 if (i % 2 == 0) else shift_size, always_partition=always_partition, + dynamic_mask=dynamic_mask, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop, @@ -500,11 +531,12 @@ class SwinTransformerStage(nn.Module): window_size: int, always_partition: Optional[bool] = None, ): - """Method updates the resolution to utilize and the window size and so the pair-wise relative positions. + """ Updates the resolution, window size and so the pair-wise relative positions. Args: - feat_size (Tuple[int, int]): New input resolution - window_size (int): New window size + feat_size: New input (feature) resolution + window_size: New window size + always_partition: Always partition / shift the window """ self.input_resolution = feat_size if isinstance(self.downsample, nn.Identity): @@ -548,6 +580,7 @@ class SwinTransformer(nn.Module): head_dim: Optional[int] = None, window_size: _int_or_tuple_2_t = 7, always_partition: bool = False, + strict_img_size: bool = True, mlp_ratio: float = 4., qkv_bias: bool = True, drop_rate: float = 0., @@ -599,9 +632,10 @@ class SwinTransformer(nn.Module): in_chans=in_chans, embed_dim=embed_dim[0], norm_layer=norm_layer, + strict_img_size=strict_img_size, output_fmt='NHWC', ) - self.patch_grid = self.patch_embed.grid_size + patch_grid = self.patch_embed.grid_size # build layers head_dim = to_ntuple(self.num_layers)(head_dim) @@ -621,8 +655,8 @@ class SwinTransformer(nn.Module): dim=in_dim, out_dim=out_dim, input_resolution=( - self.patch_grid[0] // scale, - self.patch_grid[1] // scale + patch_grid[0] // scale, + patch_grid[1] // scale ), depth=depths[i], downsample=i > 0, @@ -630,6 +664,7 @@ class SwinTransformer(nn.Module): head_dim=head_dim[i], window_size=window_size[i], always_partition=always_partition, + dynamic_mask=not strict_img_size, mlp_ratio=mlp_ratio[i], qkv_bias=qkv_bias, proj_drop=proj_drop_rate, @@ -673,27 +708,29 @@ class SwinTransformer(nn.Module): img_size: Optional[Tuple[int, int]] = None, patch_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None, - window_ratio: int = 32, + window_ratio: int = 8, always_partition: Optional[bool] = None, ) -> None: """ Updates the image resolution and window size. Args: - img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used - window_size (Optional[int]): New window size, if None based on new_img_size // window_div - window_ratio (int): divisor for calculating window size from image size + img_size: New input resolution, if None current resolution is used + patch_size (Optional[Tuple[int, int]): New patch size, if None use current patch size + window_size: New window size, if None based on new_img_size // window_div + window_ratio: divisor for calculating window size from grid size + always_partition: always partition into windows and shift (even if window size < feat size) """ if img_size is not None or patch_size is not None: self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) - self.patch_grid = self.patch_embed.grid_size + patch_grid = self.patch_embed.grid_size + if window_size is None: - img_size = self.patch_embed.img_size - window_size = tuple([s // window_ratio for s in img_size]) + window_size = tuple([pg // window_ratio for pg in patch_grid]) + for index, stage in enumerate(self.layers): stage_scale = 2 ** max(index - 1, 0) - print(self.patch_grid, stage_scale) stage.set_input_size( - feat_size=(self.patch_grid[0] // stage_scale, self.patch_grid[1] // stage_scale), + feat_size=(patch_grid[0] // stage_scale, patch_grid[1] // stage_scale), window_size=window_size, always_partition=always_partition, ) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 9651189c..8ee247a5 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -94,7 +94,7 @@ class WindowAttention(nn.Module): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww - self.pretrained_window_size = pretrained_window_size + self.pretrained_window_size = to_2tuple(pretrained_window_size) self.num_heads = num_heads self.qkv_bias_separate = qkv_bias_separate @@ -107,21 +107,37 @@ class WindowAttention(nn.Module): nn.Linear(512, num_heads, bias=False) ) + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.register_buffer('k_bias', torch.zeros(dim), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + self._make_pair_wise_relative_positions() + + def _make_pair_wise_relative_positions(self): # get relative_coords_table relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32) relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32) relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w)) relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + if self.pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (self.pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.pretrained_window_size[1] - 1) else: relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table = torch.sign(relative_coords_table) * torch.log2( torch.abs(relative_coords_table) + 1.0) / math.log2(8) - self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) # get pair-wise relative position index for each token inside the window @@ -137,19 +153,15 @@ class WindowAttention(nn.Module): relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index, persistent=False) - self.qkv = nn.Linear(dim, dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.register_buffer('k_bias', torch.zeros(dim), persistent=False) - self.v_bias = nn.Parameter(torch.zeros(dim)) - else: - self.q_bias = None - self.k_bias = None - self.v_bias = None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) + def set_window_size(self, window_size: Tuple[int, int]) -> None: + """Update window size & interpolate position embeddings + Args: + window_size (int): New window size + """ + window_size = to_2tuple(window_size) + if window_size != self.window_size: + self.window_size = window_size + self._make_pair_wise_relative_positions() def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -210,6 +222,8 @@ class SwinTransformerV2Block(nn.Module): num_heads: int, window_size: _int_or_tuple_2_t = 7, shift_size: _int_or_tuple_2_t = 0, + always_partition: bool = False, + dynamic_mask: bool = False, mlp_ratio: float = 4., qkv_bias: bool = True, proj_drop: float = 0., @@ -218,7 +232,7 @@ class SwinTransformerV2Block(nn.Module): act_layer: LayerType = "gelu", norm_layer: nn.Module = nn.LayerNorm, pretrained_window_size: _int_or_tuple_2_t = 0, - ) -> None: + ): """ Args: dim: Number of input channels. @@ -226,6 +240,7 @@ class SwinTransformerV2Block(nn.Module): num_heads: Number of attention heads. window_size: Window size. shift_size: Shift size for SW-MSA. + always_partition: Always partition into full windows and shift mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. proj_drop: Dropout rate. @@ -239,9 +254,10 @@ class SwinTransformerV2Block(nn.Module): self.dim = dim self.input_resolution = to_2tuple(input_resolution) self.num_heads = num_heads - ws, ss = self._calc_window_shift(window_size, shift_size) - self.window_size: Tuple[int, int] = ws - self.shift_size: Tuple[int, int] = ss + self.target_shift_size = to_2tuple(shift_size) # store for later resize + self.always_partition = always_partition + self.dynamic_mask = dynamic_mask + self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size) self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio act_layer = get_act_layer(act_layer) @@ -267,20 +283,31 @@ class SwinTransformerV2Block(nn.Module): self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.register_buffer( + "attn_mask", + None if self.dynamic_mask else self.get_attn_mask(), + persistent=False, + ) + + def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: if any(self.shift_size): # calculate attention mask for SW-MSA - H, W = self.input_resolution - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + if x is None: + img_mask = torch.zeros((1, *self.input_resolution, 1)) # 1 H W 1 + else: + img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device) # 1 H W 1 cnt = 0 for h in ( - slice(0, -self.window_size[0]), - slice(-self.window_size[0], -self.shift_size[0]), - slice(-self.shift_size[0], None)): + (0, -self.window_size[0]), + (-self.window_size[0], -self.shift_size[0]), + (-self.shift_size[0], None), + ): for w in ( - slice(0, -self.window_size[1]), - slice(-self.window_size[1], -self.shift_size[1]), - slice(-self.shift_size[1], None)): - img_mask[:, h, w, :] = cnt + (0, -self.window_size[1]), + (-self.window_size[1], -self.shift_size[1]), + (-self.shift_size[1], None), + ): + img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_area) @@ -288,18 +315,60 @@ class SwinTransformerV2Block(nn.Module): attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None + return attn_mask - self.register_buffer("attn_mask", attn_mask, persistent=False) + def _calc_window_shift( + self, + target_window_size: _int_or_tuple_2_t, + target_shift_size: Optional[_int_or_tuple_2_t] = None, + ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + if target_shift_size is None: + # if passed value is None, recalculate from default window_size // 2 if it was active + target_shift_size = self.target_shift_size + if any(target_shift_size): + # if there was previously a non-zero shift, recalculate based on current window_size + target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2) + else: + target_shift_size = to_2tuple(target_shift_size) + + if self.always_partition: + print('ap', target_window_size, target_shift_size) + return target_window_size, target_shift_size - def _calc_window_shift(self, - target_window_size: _int_or_tuple_2_t, - target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]: target_window_size = to_2tuple(target_window_size) target_shift_size = to_2tuple(target_shift_size) window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + print('nap', window_size, shift_size) return tuple(window_size), tuple(shift_size) + def set_input_size( + self, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + always_partition: Optional[bool] = None, + ): + """ Updates the input resolution, window size. + + Args: + feat_size (Tuple[int, int]): New input resolution + window_size (int): New window size + always_partition: Change always_partition attribute if not None + """ + # Update input resolution + self.input_resolution = feat_size + if always_partition is not None: + self.always_partition = always_partition + self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) + self.window_area = self.window_size[0] * self.window_size[1] + self.attn.set_window_size(self.window_size) + self.register_buffer( + "attn_mask", + None if self.dynamic_mask else self.get_attn_mask(), + persistent=False, + ) + def _attn(self, x: torch.Tensor) -> torch.Tensor: B, H, W, C = x.shape @@ -310,16 +379,26 @@ class SwinTransformerV2Block(nn.Module): else: shifted_x = x + pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h)) + _, Hp, Wp, _ = shifted_x.shape + # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + if getattr(self, 'dynamic_mask', False): + attn_mask = self.get_attn_mask(shifted_x) + else: + attn_mask = self.attn_mask + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) - shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution) # B H' W' C + shifted_x = window_reverse(attn_windows, self.window_size, (Hp, Wp)) # B H' W' C + shifted_x = shifted_x[:, :H, :W, :].contiguous() # reverse cyclic shift if has_shift: @@ -341,7 +420,12 @@ class PatchMerging(nn.Module): """ Patch Merging Layer. """ - def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None: + def __init__( + self, + dim: int, + out_dim: Optional[int] = None, + norm_layer: nn.Module = nn.LayerNorm + ): """ Args: dim (int): Number of input channels. @@ -356,8 +440,11 @@ class PatchMerging(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: B, H, W, C = x.shape - _assert(H % 2 == 0, f"x height ({H}) is not even.") - _assert(W % 2 == 0, f"x width ({W}) is not even.") + + pad_values = (0, 0, 0, H % 2, 0, W % 2) + x = nn.functional.pad(x, pad_values) + _, H, W, _ = x.shape + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) x = self.reduction(x) x = self.norm(x) @@ -376,6 +463,8 @@ class SwinTransformerV2Stage(nn.Module): depth: int, num_heads: int, window_size: _int_or_tuple_2_t, + always_partition: bool = False, + dynamic_mask: bool = False, downsample: bool = False, mlp_ratio: float = 4., qkv_bias: bool = True, @@ -395,6 +484,8 @@ class SwinTransformerV2Stage(nn.Module): depth: Number of blocks. num_heads: Number of attention heads. window_size: Local window size. + always_partition: Always partition into full windows and shift + dynamic_mask: Create attention mask in forward based on current input size downsample: Use downsample layer at start of the block. mlp_ratio: Ratio of mlp hidden dim to embedding dim. qkv_bias: If True, add a learnable bias to query, key, value. @@ -431,6 +522,8 @@ class SwinTransformerV2Stage(nn.Module): num_heads=num_heads, window_size=window_size, shift_size=0 if (i % 2 == 0) else shift_size, + always_partition=always_partition, + dynamic_mask=dynamic_mask, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop, @@ -442,6 +535,32 @@ class SwinTransformerV2Stage(nn.Module): ) for i in range(depth)]) + def set_input_size( + self, + feat_size: Tuple[int, int], + window_size: int, + always_partition: Optional[bool] = None, + ): + """ Updates the resolution, window size and so the pair-wise relative positions. + + Args: + feat_size: New input (feature) resolution + window_size: New window size + always_partition: Always partition / shift the window + """ + self.input_resolution = feat_size + if isinstance(self.downsample, nn.Identity): + self.output_resolution = feat_size + else: + assert isinstance(self.downsample, PatchMerging) + self.output_resolution = tuple(i // 2 for i in feat_size) + for block in self.blocks: + block.set_input_size( + feat_size=self.output_resolution, + window_size=window_size, + always_partition=always_partition, + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.downsample(x) @@ -478,6 +597,8 @@ class SwinTransformerV2(nn.Module): depths: Tuple[int, ...] = (2, 2, 6, 2), num_heads: Tuple[int, ...] = (3, 6, 12, 24), window_size: _int_or_tuple_2_t = 7, + always_partition: bool = False, + strict_img_size: bool = True, mlp_ratio: float = 4., qkv_bias: bool = True, drop_rate: float = 0., @@ -532,8 +653,10 @@ class SwinTransformerV2(nn.Module): in_chans=in_chans, embed_dim=embed_dim[0], norm_layer=norm_layer, + strict_img_size=strict_img_size, output_fmt='NHWC', ) + grid_size = self.patch_embed.grid_size dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] layers = [] @@ -544,13 +667,13 @@ class SwinTransformerV2(nn.Module): layers += [SwinTransformerV2Stage( dim=in_dim, out_dim=out_dim, - input_resolution=( - self.patch_embed.grid_size[0] // scale, - self.patch_embed.grid_size[1] // scale), + input_resolution=(grid_size[0] // scale, grid_size[1] // scale), depth=depths[i], downsample=i > 0, num_heads=num_heads[i], window_size=window_size, + always_partition=always_partition, + dynamic_mask=not strict_img_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop_rate, @@ -585,6 +708,38 @@ class SwinTransformerV2(nn.Module): if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) + def set_input_size( + self, + img_size: Optional[Tuple[int, int]] = None, + patch_size: Optional[Tuple[int, int]] = None, + window_size: Optional[Tuple[int, int]] = None, + window_ratio: Optional[int] = 8, + always_partition: Optional[bool] = None, + ): + """Updates the image resolution, window size, and so the pair-wise relative positions. + + Args: + img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used + patch_size (Optional[Tuple[int, int]): New patch size, if None use current patch size + window_size (Optional[int]): New window size, if None based on new_img_size // window_div + window_ratio (int): divisor for calculating window size from patch grid size + always_partition: always partition / shift windows even if feat size is < window + """ + if img_size is not None or patch_size is not None: + self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) + grid_size = self.patch_embed.grid_size + + if window_size is None and window_ratio is not None: + window_size = tuple([s // window_ratio for s in grid_size]) + + for index, stage in enumerate(self.layers): + stage_scale = 2 ** max(index - 1, 0) + stage.set_input_size( + feat_size=(grid_size[0] // stage_scale, grid_size[1] // stage_scale), + window_size=window_size, + always_partition=always_partition, + ) + @torch.jit.ignore def no_weight_decay(self): nod = set() diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index d7c5f672..47a8c123 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -231,26 +231,30 @@ class SwinTransformerV2CrBlock(nn.Module): """ def __init__( - self, - dim: int, - num_heads: int, - feat_size: Tuple[int, int], - window_size: Tuple[int, int], - shift_size: Tuple[int, int] = (0, 0), - mlp_ratio: float = 4.0, - init_values: Optional[float] = 0, - proj_drop: float = 0.0, - drop_attn: float = 0.0, - drop_path: float = 0.0, - extra_norm: bool = False, - sequential_attn: bool = False, - norm_layer: Type[nn.Module] = nn.LayerNorm, - ) -> None: + self, + dim: int, + num_heads: int, + feat_size: Tuple[int, int], + window_size: Tuple[int, int], + shift_size: Tuple[int, int] = (0, 0), + always_partition: bool = False, + dynamic_mask: bool = False, + mlp_ratio: float = 4.0, + init_values: Optional[float] = 0, + proj_drop: float = 0.0, + drop_attn: float = 0.0, + drop_path: float = 0.0, + extra_norm: bool = False, + sequential_attn: bool = False, + norm_layer: Type[nn.Module] = nn.LayerNorm, + ): super(SwinTransformerV2CrBlock, self).__init__() self.dim: int = dim self.feat_size: Tuple[int, int] = feat_size self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) - self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) + self.always_partition = always_partition + self.dynamic_mask = dynamic_mask + self.window_size, self.shift_size = self._calc_window_shift(window_size) self.window_area = self.window_size[0] * self.window_size[1] self.init_values: Optional[float] = init_values @@ -280,31 +284,53 @@ class SwinTransformerV2CrBlock(nn.Module): # Also being used as final network norm and optional stage ending norm while still in a C-last format. self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() - self._make_attention_mask() + self.register_buffer( + "attn_mask", + None if self.dynamic_mask else self.get_attn_mask(), + persistent=False, + ) + print(self.dynamic_mask) + self.init_weights() - def _calc_window_shift(self, target_window_size): + def _calc_window_shift( + self, + target_window_size: Tuple[int, int], + ) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = self.target_shift_size + if any(target_shift_size): + # if non-zero, recalculate shift from current window size in case window size has changed + target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2) + + if self.always_partition: + return target_window_size, target_shift_size + window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] - shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)] + shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, target_shift_size)] return tuple(window_size), tuple(shift_size) - def _make_attention_mask(self) -> None: + def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: """Method generates the attention mask used in shift case.""" # Make masks for shift case if any(self.shift_size): # calculate attention mask for SW-MSA - H, W = self.feat_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + if x is None: + img_mask = torch.zeros((1, *self.feat_size, 1)) # 1 H W 1 + else: + img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device) # 1 H W 1 cnt = 0 for h in ( - slice(0, -self.window_size[0]), - slice(-self.window_size[0], -self.shift_size[0]), - slice(-self.shift_size[0], None)): + (0, -self.window_size[0]), + (-self.window_size[0], -self.shift_size[0]), + (-self.shift_size[0], None), + ): for w in ( - slice(0, -self.window_size[1]), - slice(-self.window_size[1], -self.shift_size[1]), - slice(-self.shift_size[1], None)): - img_mask[:, h, w, :] = cnt + (0, -self.window_size[1]), + (-self.window_size[1], -self.shift_size[1]), + (-self.shift_size[1], None), + ): + img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # num_windows, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_area) @@ -312,7 +338,7 @@ class SwinTransformerV2CrBlock(nn.Module): attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - self.register_buffer("attn_mask", attn_mask, persistent=False) + return attn_mask def init_weights(self): # extra, module specific weight init @@ -332,7 +358,11 @@ class SwinTransformerV2CrBlock(nn.Module): self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] self.attn.set_window_size(self.window_size) - self._make_attention_mask() + self.register_buffer( + "attn_mask", + None if self.dynamic_mask else self.get_attn_mask(), + persistent=False, + ) def _shifted_window_attn(self, x): B, H, W, C = x.shape @@ -346,16 +376,26 @@ class SwinTransformerV2CrBlock(nn.Module): # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2) x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2)) + pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + _, Hp, Wp, _ = x.shape + # partition windows x_windows = window_partition(x, self.window_size) # num_windows * B, window_size, window_size, C x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C) # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_windows * B, window_size * window_size, C + if getattr(self, 'dynamic_mask', False): + attn_mask = self.get_attn_mask(x) + else: + attn_mask = self.attn_mask + attn_windows = self.attn(x_windows, mask=attn_mask) # num_windows * B, window_size * window_size, C # merge windows attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) - x = window_reverse(attn_windows, self.window_size, self.feat_size) # B H' W' C + x = window_reverse(attn_windows, self.window_size, (Hp, Wp)) # B H' W' C + x = x[:, :H, :W, :].contiguous() # reverse cyclic shift if do_shift: @@ -406,6 +446,11 @@ class PatchMerging(nn.Module): output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] """ B, H, W, C = x.shape + + pad_values = (0, 0, 0, H % 2, 0, W % 2) + x = nn.functional.pad(x, pad_values) + _, H, W, _ = x.shape + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) x = self.norm(x) x = self.reduction(x) @@ -414,7 +459,15 @@ class PatchMerging(nn.Module): class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + strict_img_size=True, + ): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -422,14 +475,23 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] + self.strict_img_size = strict_img_size self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def set_input_size(self, img_size: Tuple[int, int]): + img_size = to_2tuple(img_size) + if img_size != self.img_size: + self.img_size = img_size + self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + def forward(self, x): B, C, H, W = x.shape - _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") - _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + if self.strict_img_size: + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") x = self.proj(x) x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x @@ -463,6 +525,8 @@ class SwinTransformerV2CrStage(nn.Module): num_heads: int, feat_size: Tuple[int, int], window_size: Tuple[int, int], + always_partition: bool = False, + dynamic_mask: bool = False, mlp_ratio: float = 4.0, init_values: Optional[float] = 0.0, proj_drop: float = 0.0, @@ -472,7 +536,7 @@ class SwinTransformerV2CrStage(nn.Module): extra_norm_period: int = 0, extra_norm_stage: bool = False, sequential_attn: bool = False, - ) -> None: + ): super(SwinTransformerV2CrStage, self).__init__() self.downscale: bool = downscale self.grad_checkpointing: bool = False @@ -496,6 +560,8 @@ class SwinTransformerV2CrStage(nn.Module): num_heads=num_heads, feat_size=self.feat_size, window_size=window_size, + always_partition=always_partition, + dynamic_mask=dynamic_mask, shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), mlp_ratio=mlp_ratio, init_values=init_values, @@ -509,18 +575,24 @@ class SwinTransformerV2CrStage(nn.Module): for index in range(depth)] ) - def set_input_size(self, feat_size: Tuple[int, int], window_size: int) -> None: - """Method updates the resolution to utilize and the window size and so the pair-wise relative positions. + def set_input_size( + self, + feat_size: Tuple[int, int], + window_size: int, + always_partition: Optional[bool] = None, + ): + """ Updates the resolution to utilize and the window size and so the pair-wise relative positions. Args: window_size (int): New window size feat_size (Tuple[int, int]): New input resolution """ - self.feat_size: Tuple[int, int] = ( - (feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size - ) + self.feat_size = (feat_size[0] // 2, feat_size[1] // 2) if self.downscale else feat_size for block in self.blocks: - block.set_input_size(feat_size=self.feat_size, window_size=window_size) + block.set_input_size( + feat_size=self.feat_size, + window_size=window_size, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. @@ -548,8 +620,8 @@ class SwinTransformerV2Cr(nn.Module): Args: img_size: Input resolution. - window_size: Window size. If None, img_size // window_div - img_window_ratio: Window size to image size ratio. + window_size: Window size. If None, grid_size // window_div + window_ratio: Window size to patch grid ratio. patch_size: Patch size. in_chans: Number of input channels. depths: Depth of the stage (number of layers). @@ -572,7 +644,9 @@ class SwinTransformerV2Cr(nn.Module): img_size: Tuple[int, int] = (224, 224), patch_size: int = 4, window_size: Optional[int] = None, - img_window_ratio: int = 32, + window_ratio: int = 8, + always_partition: bool = False, + strict_img_size: bool = True, in_chans: int = 3, num_classes: int = 1000, embed_dim: int = 96, @@ -594,13 +668,9 @@ class SwinTransformerV2Cr(nn.Module): ) -> None: super(SwinTransformerV2Cr, self).__init__() img_size = to_2tuple(img_size) - window_size = tuple([ - s // img_window_ratio for s in img_size]) if window_size is None else to_2tuple(window_size) - self.num_classes: int = num_classes self.patch_size: int = patch_size self.img_size: Tuple[int, int] = img_size - self.window_size: int = window_size self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.feature_info = [] @@ -610,8 +680,13 @@ class SwinTransformerV2Cr(nn.Module): in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer, + strict_img_size=strict_img_size, ) - patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size + grid_size = self.patch_embed.grid_size + if window_size is None: + self.window_size = tuple([s // window_ratio for s in grid_size]) + else: + self.window_size = to_2tuple(window_size) dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] stages = [] @@ -622,12 +697,11 @@ class SwinTransformerV2Cr(nn.Module): embed_dim=in_dim, depth=depth, downscale=stage_idx != 0, - feat_size=( - patch_grid_size[0] // in_scale, - patch_grid_size[1] // in_scale - ), + feat_size=(grid_size[0] // in_scale, grid_size[1] // in_scale), num_heads=num_heads, - window_size=window_size, + window_size=self.window_size, + always_partition=always_partition, + dynamic_mask=not strict_img_size, mlp_ratio=mlp_ratio, init_values=init_values, proj_drop=proj_drop_rate, @@ -660,28 +734,30 @@ class SwinTransformerV2Cr(nn.Module): self, img_size: Optional[Tuple[int, int]] = None, window_size: Optional[Tuple[int, int]] = None, - window_ratio: int = 32, + window_ratio: int = 8, + always_partition: Optional[bool] = None, ) -> None: - """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. + """Updates the image resolution, window size and so the pair-wise relative positions. Args: img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used window_size (Optional[int]): New window size, if None based on new_img_size // window_div - window_ratio (int): divisor for calculating window size from image size + window_ratio (int): divisor for calculating window size from patch grid size + always_partition: always partition / shift windows even if feat size is < window """ - if img_size is None: - img_size = self.img_size - else: - img_size = to_2tuple(img_size) - if window_size is None: - window_size = tuple([s // window_ratio for s in img_size]) - # Compute new patch resolution & update resolution of each stage - patch_grid_size = (img_size[0] // self.patch_size, img_size[1] // self.patch_size) + if img_size is not None: + self.patch_embed.set_input_size(img_size=img_size) + grid_size = self.patch_embed.grid_size + + if window_size is None and window_ratio is not None: + window_size = tuple([s // window_ratio for s in grid_size]) + for index, stage in enumerate(self.stages): stage_scale = 2 ** max(index - 1, 0) stage.set_input_size( - feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale), + feat_size=(grid_size[0] // stage_scale, grid_size[1] // stage_scale), window_size=window_size, + always_partition=always_partition, ) @torch.jit.ignore