diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 28fd7fe1..eb1feeb5 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union, Set, Dict
 
 import torch
 import torch.nn as nn
@@ -32,7 +32,7 @@ __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to
 _int_or_tuple_2_t = Union[int, Tuple[int, int]]
 
 
-def window_partition(x, window_size: Tuple[int, int]):
+def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor:
     """
     Args:
         x: (B, H, W, C)
@@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
+def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor:
     """
     Args:
         windows: (num_windows * B, window_size[0], window_size[1], C)
@@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
 
     def __init__(
             self,
-            dim,
-            window_size,
-            num_heads,
-            qkv_bias=True,
-            attn_drop=0.,
-            proj_drop=0.,
-            pretrained_window_size=[0, 0],
-    ):
+            dim: int,
+            window_size: Tuple[int, int],
+            num_heads: int,
+            qkv_bias: bool = True,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            pretrained_window_size: Tuple[int, int] = (0, 0),
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
@@ -149,7 +149,7 @@ class WindowAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, mask: Optional[torch.Tensor] = None):
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Args:
             x: input features with shape of (num_windows*B, N, C)
@@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
 
     def __init__(
             self,
-            dim,
-            input_resolution,
-            num_heads,
-            window_size=7,
-            shift_size=0,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            pretrained_window_size=0,
-    ):
+            dim: int,
+            input_resolution: _int_or_tuple_2_t,
+            num_heads: int,
+            window_size: _int_or_tuple_2_t = 7,
+            shift_size: _int_or_tuple_2_t = 0,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: nn.Module = nn.GELU,
+            norm_layer: nn.Module = nn.LayerNorm,
+            pretrained_window_size: _int_or_tuple_2_t = 0,
+    ) -> None:
         """
         Args:
             dim: Number of input channels.
@@ -282,14 +282,16 @@ class SwinTransformerV2Block(nn.Module):
 
         self.register_buffer("attn_mask", attn_mask, persistent=False)
 
-    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    def _calc_window_shift(self,
+                           target_window_size: _int_or_tuple_2_t,
+                           target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
         target_shift_size = to_2tuple(target_shift_size)
         window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
         shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
         return tuple(window_size), tuple(shift_size)
 
-    def _attn(self, x):
+    def _attn(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
 
         # cyclic shift
@@ -317,7 +319,7 @@ class SwinTransformerV2Block(nn.Module):
             x = shifted_x
         return x
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
         x = x + self.drop_path1(self.norm1(self._attn(x)))
         x = x.reshape(B, -1, C)
@@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
     """ Patch Merging Layer.
     """
 
-    def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
+    def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None:
         """
         Args:
             dim (int): Number of input channels.
@@ -343,7 +345,7 @@ class PatchMerging(nn.Module):
         self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
         self.norm = norm_layer(self.out_dim)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
         _assert(H % 2 == 0, f"x height ({H}) is not even.")
         _assert(W % 2 == 0, f"x width ({W}) is not even.")
@@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
 
     def __init__(
             self,
-            dim,
-            out_dim,
-            input_resolution,
-            depth,
-            num_heads,
-            window_size,
-            downsample=False,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            norm_layer=nn.LayerNorm,
-            pretrained_window_size=0,
-            output_nchw=False,
-    ):
+            dim: int,
+            out_dim: int,
+            input_resolution: _int_or_tuple_2_t,
+            depth: int,
+            num_heads: int,
+            window_size: _int_or_tuple_2_t,
+            downsample: bool = False,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            norm_layer: nn.Module = nn.LayerNorm,
+            pretrained_window_size: _int_or_tuple_2_t = 0,
+            output_nchw: bool = False,
+    ) -> None:
         """
         Args:
             dim: Number of input channels.
@@ -428,7 +430,7 @@ class SwinTransformerV2Stage(nn.Module):
             )
             for i in range(depth)])
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)
 
         for blk in self.blocks:
@@ -438,7 +440,7 @@ class SwinTransformerV2Stage(nn.Module):
                 x = blk(x)
         return x
 
-    def _init_respostnorm(self):
+    def _init_respostnorm(self) -> None:
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
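A quick way to exercise the window helpers whose signatures are annotated above is the round-trip below. This is a minimal illustrative sketch, not part of the patch; it assumes a timm installation that contains this module and imports window_partition / window_reverse directly even though they are not listed in __all__:

# Illustrative round-trip check for the annotated helpers (not part of the patch).
import torch
from timm.models.swin_transformer_v2 import window_partition, window_reverse

x = torch.randn(2, 8, 8, 96)                        # (B, H, W, C) feature map
windows = window_partition(x, window_size=(4, 4))   # (num_windows * B, 4, 4, 96)
restored = window_reverse(windows, window_size=(4, 4), img_size=(8, 8))
assert torch.equal(x, restored)                     # partition/reverse is lossless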