add type annotations in the code of swin_transformer_v2

Li zhuoqun 2023-12-13 18:53:42 +08:00 committed by Ross Wightman
parent bbe798317f
commit 7da34a999a


@@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union, Set, Dict
 
 import torch
 import torch.nn as nn
@@ -32,7 +32,7 @@ __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to
 _int_or_tuple_2_t = Union[int, Tuple[int, int]]
 
 
-def window_partition(x, window_size: Tuple[int, int]):
+def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor:
     """
     Args:
         x: (B, H, W, C)
@@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
+def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor:
     """
     Args:
         windows: (num_windows * B, window_size[0], window_size[1], C)
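As a sanity check on the newly annotated helper signatures, here is a minimal round-trip sketch. The import path and the concrete shapes are assumptions based on timm's module layout and the docstrings above, not part of this commit:

```python
import torch
# Assumed import path; both helpers live at module level in timm's swin_transformer_v2.
from timm.models.swin_transformer_v2 import window_partition, window_reverse

x = torch.randn(2, 56, 56, 96)                 # (B, H, W, C)
windows = window_partition(x, (7, 7))          # (num_windows * B, 7, 7, C)
assert windows.shape == (2 * 8 * 8, 7, 7, 96)

y = window_reverse(windows, (7, 7), (56, 56))  # back to (B, H, W, C)
assert torch.equal(x, y)                       # partition/reverse is a pure reshape round trip
```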
@@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
 
     def __init__(
             self,
-            dim,
-            window_size,
-            num_heads,
-            qkv_bias=True,
-            attn_drop=0.,
-            proj_drop=0.,
-            pretrained_window_size=[0, 0],
-    ):
+            dim: int,
+            window_size: Tuple[int, int],
+            num_heads: int,
+            qkv_bias: bool = True,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            pretrained_window_size: Tuple[int, int] = (0, 0),
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
@@ -149,7 +149,7 @@ class WindowAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, mask: Optional[torch.Tensor] = None):
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Args:
             x: input features with shape of (num_windows*B, N, C)
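A hedged usage sketch for the annotated constructor and forward above. The keyword arguments come straight from the diff; the import path and the window-token layout follow the docstring and are assumptions rather than part of this commit:

```python
import torch
from timm.models.swin_transformer_v2 import WindowAttention  # assumed import path

attn = WindowAttention(
    dim=96,
    window_size=(7, 7),
    num_heads=3,
    qkv_bias=True,
    attn_drop=0.,
    proj_drop=0.,
)
x = torch.randn(8, 7 * 7, 96)   # (num_windows * B, N, C), as documented above
out = attn(x)                   # mask is Optional[torch.Tensor], so it can be omitted
assert out.shape == x.shape
```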
@@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
 
     def __init__(
             self,
-            dim,
-            input_resolution,
-            num_heads,
-            window_size=7,
-            shift_size=0,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            pretrained_window_size=0,
-    ):
+            dim: int,
+            input_resolution: _int_or_tuple_2_t,
+            num_heads: int,
+            window_size: _int_or_tuple_2_t = 7,
+            shift_size: _int_or_tuple_2_t = 0,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: nn.Module = nn.GELU,
+            norm_layer: nn.Module = nn.LayerNorm,
+            pretrained_window_size: _int_or_tuple_2_t = 0,
+    ) -> None:
         """
         Args:
             dim: Number of input channels.
@@ -282,14 +282,16 @@ class SwinTransformerV2Block(nn.Module):
 
         self.register_buffer("attn_mask", attn_mask, persistent=False)
 
-    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    def _calc_window_shift(self,
+                           target_window_size: _int_or_tuple_2_t,
+                           target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
         target_shift_size = to_2tuple(target_shift_size)
         window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
         shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
         return tuple(window_size), tuple(shift_size)
 
-    def _attn(self, x):
+    def _attn(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
 
         # cyclic shift
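The clamping rule in `_calc_window_shift` is easy to check by hand. A standalone sketch of the same two comprehensions follows; the free function name is illustrative and not part of timm:

```python
from typing import Tuple

def calc_window_shift(
        input_resolution: Tuple[int, int],
        target_window_size: Tuple[int, int],
        target_shift_size: Tuple[int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    # If a spatial dim fits inside a single window, clamp the window to that dim
    # and disable shifting along it; otherwise keep the requested values.
    window_size = [r if r <= w else w for r, w in zip(input_resolution, target_window_size)]
    shift_size = [0 if r <= w else s for r, w, s in zip(input_resolution, window_size, target_shift_size)]
    return tuple(window_size), tuple(shift_size)

# A 6x24 feature map with a requested 8x8 window and 4x4 shift:
# height 6 <= 8, so the window shrinks to 6 and the shift is zeroed on that axis.
assert calc_window_shift((6, 24), (8, 8), (4, 4)) == ((6, 8), (0, 4))
```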
@@ -317,7 +319,7 @@ class SwinTransformerV2Block(nn.Module):
             x = shifted_x
         return x
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
         x = x + self.drop_path1(self.norm1(self._attn(x)))
         x = x.reshape(B, -1, C)
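A hedged construction sketch for the annotated block forward. The NHWC input layout follows the `B, H, W, C` unpacking above; the import path and the concrete sizes are assumptions:

```python
import torch
from timm.models.swin_transformer_v2 import SwinTransformerV2Block  # assumed import path

block = SwinTransformerV2Block(
    dim=96,
    input_resolution=(56, 56),   # _int_or_tuple_2_t: a plain int would be expanded via to_2tuple
    num_heads=3,
    window_size=7,
    shift_size=0,
)
x = torch.randn(1, 56, 56, 96)   # (B, H, W, C)
out = block(x)
assert out.shape == x.shape      # both branches are residual, so the shape is preserved
```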
@@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
     """ Patch Merging Layer.
     """
 
-    def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
+    def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None:
         """
         Args:
             dim (int): Number of input channels.
@@ -343,7 +345,7 @@ class PatchMerging(nn.Module):
         self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
         self.norm = norm_layer(self.out_dim)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
         _assert(H % 2 == 0, f"x height ({H}) is not even.")
         _assert(W % 2 == 0, f"x width ({W}) is not even.")
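A small sketch of the annotated `PatchMerging` in use. It assumes the timm import path and that an `out_dim` of `None` falls back to `2 * dim`, which is how the current timm implementation behaves:

```python
import torch
from timm.models.swin_transformer_v2 import PatchMerging  # assumed import path

merge = PatchMerging(dim=96)        # out_dim is Optional[int]; None is assumed to mean 2 * dim
x = torch.randn(2, 56, 56, 96)      # (B, H, W, C); H and W must be even (see the _assert calls)
y = merge(x)
assert y.shape == (2, 28, 28, 192)  # spatial dims halve, channels go dim -> out_dim
```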
@@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
 
     def __init__(
             self,
-            dim,
-            out_dim,
-            input_resolution,
-            depth,
-            num_heads,
-            window_size,
-            downsample=False,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            norm_layer=nn.LayerNorm,
-            pretrained_window_size=0,
-            output_nchw=False,
-    ):
+            dim: int,
+            out_dim: int,
+            input_resolution: _int_or_tuple_2_t,
+            depth: int,
+            num_heads: int,
+            window_size: _int_or_tuple_2_t,
+            downsample: bool = False,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            norm_layer: nn.Module = nn.LayerNorm,
+            pretrained_window_size: _int_or_tuple_2_t = 0,
+            output_nchw: bool = False,
+    ) -> None:
         """
         Args:
             dim: Number of input channels.
@@ -428,7 +430,7 @@ class SwinTransformerV2Stage(nn.Module):
             )
             for i in range(depth)])
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)
 
         for blk in self.blocks:
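A hedged sketch for the annotated stage forward, using the keyword arguments shown in the diff. With `downsample=False` the stage is assumed to keep both resolution and channel count; the NHWC layout and the import path are assumptions as above:

```python
import torch
from timm.models.swin_transformer_v2 import SwinTransformerV2Stage  # assumed import path

stage = SwinTransformerV2Stage(
    dim=96,
    out_dim=96,                 # assumed: without downsampling the channel count is unchanged
    input_resolution=(56, 56),
    depth=2,
    num_heads=3,
    window_size=7,
    downsample=False,           # self.downsample becomes an identity in this case
)
x = torch.randn(1, 56, 56, 96)  # (B, H, W, C)
out = stage(x)
assert out.shape == x.shape
```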
@@ -438,7 +440,7 @@ class SwinTransformerV2Stage(nn.Module):
                 x = blk(x)
         return x
 
-    def _init_respostnorm(self):
+    def _init_respostnorm(self) -> None:
         for blk in self.blocks:
             nn.init.constant_(blk.norm1.bias, 0)
             nn.init.constant_(blk.norm1.weight, 0)
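Since the annotations are type hints only, runtime behavior is unchanged and the registered entrypoints should build exactly as before; the hints can additionally be verified with a static checker such as mypy. A quick smoke test, assuming `swinv2_tiny_window8_256` is one of the registered SwinV2 variants:

```python
import torch
import timm

# Assumed model name; any registered swinv2_* entrypoint would do for this check.
model = timm.create_model('swinv2_tiny_window8_256', pretrained=False)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))
assert logits.shape == (1, 1000)
```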