add type annotations in the code of swin_transformer_v2

Li zhuoqun 2023-12-13 18:53:42 +08:00 committed by Ross Wightman
parent bbe798317f
commit 7da34a999a


@@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# Written by Ze Liu
# --------------------------------------------------------
import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union, Set, Dict
import torch
import torch.nn as nn
@@ -32,7 +32,7 @@ __all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
-def window_partition(x, window_size: Tuple[int, int]):
+def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor:
"""
Args:
x: (B, H, W, C)
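
For reference, the _int_or_tuple_2_t alias that the new annotations lean on accepts either a bare int or an explicit (h, w) pair; timm normalizes such values with its to_2tuple helper. A minimal standalone sketch of that behaviour (the local to_2tuple below is an illustrative stand-in, not the commit's code):

from typing import Tuple, Union

_int_or_tuple_2_t = Union[int, Tuple[int, int]]

def to_2tuple(v: _int_or_tuple_2_t) -> Tuple[int, int]:
    # Normalize an int-or-pair argument to an explicit (h, w) pair,
    # mirroring what timm's to_2tuple helper does for these parameters.
    return (v, v) if isinstance(v, int) else v

assert to_2tuple(7) == (7, 7)
assert to_2tuple((8, 16)) == (8, 16)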
@@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
@register_notrace_function # reason: int argument is a Proxy
-def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
+def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor:
"""
Args:
windows: (num_windows * B, window_size[0], window_size[1], C)
@@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
def __init__(
self,
-dim,
-window_size,
-num_heads,
-qkv_bias=True,
-attn_drop=0.,
-proj_drop=0.,
-pretrained_window_size=[0, 0],
-):
+dim: int,
+window_size: Tuple[int, int],
+num_heads: int,
+qkv_bias: bool = True,
+attn_drop: float = 0.,
+proj_drop: float = 0.,
+pretrained_window_size: Tuple[int, int] = (0, 0),
+) -> None:
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
@@ -149,7 +149,7 @@ class WindowAttention(nn.Module):
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
-def forward(self, x, mask: Optional[torch.Tensor] = None):
+def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
x: input features with shape of (num_windows*B, N, C)
@@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
def __init__(
self,
-dim,
-input_resolution,
-num_heads,
-window_size=7,
-shift_size=0,
-mlp_ratio=4.,
-qkv_bias=True,
-proj_drop=0.,
-attn_drop=0.,
-drop_path=0.,
-act_layer=nn.GELU,
-norm_layer=nn.LayerNorm,
-pretrained_window_size=0,
-):
+dim: int,
+input_resolution: _int_or_tuple_2_t,
+num_heads: int,
+window_size: _int_or_tuple_2_t = 7,
+shift_size: _int_or_tuple_2_t = 0,
+mlp_ratio: float = 4.,
+qkv_bias: bool = True,
+proj_drop: float = 0.,
+attn_drop: float = 0.,
+drop_path: float = 0.,
+act_layer: nn.Module = nn.GELU,
+norm_layer: nn.Module = nn.LayerNorm,
+pretrained_window_size: _int_or_tuple_2_t = 0,
+) -> None:
"""
Args:
dim: Number of input channels.
@@ -282,14 +282,16 @@ class SwinTransformerV2Block(nn.Module):
self.register_buffer("attn_mask", attn_mask, persistent=False)
-def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+def _calc_window_shift(self,
+target_window_size: _int_or_tuple_2_t,
+target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
target_shift_size = to_2tuple(target_shift_size)
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return tuple(window_size), tuple(shift_size)
-def _attn(self, x):
+def _attn(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, C = x.shape
# cyclic shift
@@ -317,7 +319,7 @@ class SwinTransformerV2Block(nn.Module):
x = shifted_x
return x
-def forward(self, x):
+def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, C = x.shape
x = x + self.drop_path1(self.norm1(self._attn(x)))
x = x.reshape(B, -1, C)
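
The annotated _calc_window_shift above clamps the window to the stage's input resolution and zeroes the shift on any axis where the window already covers that axis. A self-contained sketch of the same rule on concrete numbers (hypothetical values, not part of the commit):

from typing import Tuple

def calc_window_shift(
    input_resolution: Tuple[int, int],
    target_window_size: Tuple[int, int],
    target_shift_size: Tuple[int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    # Clamp the window to the feature-map size, and drop the shift wherever
    # the window already spans the whole axis (shifting would be a no-op there).
    window_size = tuple(min(r, w) for r, w in zip(input_resolution, target_window_size))
    shift_size = tuple(0 if r <= w else s for r, w, s in zip(input_resolution, window_size, target_shift_size))
    return window_size, shift_size

# An 8x8 feature map with a 16x16 target window collapses to an 8x8 window and no shift;
# a 32x32 map keeps both the 16x16 window and the (8, 8) shift.
assert calc_window_shift((8, 8), (16, 16), (8, 8)) == ((8, 8), (0, 0))
assert calc_window_shift((32, 32), (16, 16), (8, 8)) == ((16, 16), (8, 8))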
@@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
""" Patch Merging Layer.
"""
-def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
+def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None:
"""
Args:
dim (int): Number of input channels.
@@ -343,7 +345,7 @@ class PatchMerging(nn.Module):
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
self.norm = norm_layer(self.out_dim)
-def forward(self, x):
+def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, C = x.shape
_assert(H % 2 == 0, f"x height ({H}) is not even.")
_assert(W % 2 == 0, f"x width ({W}) is not even.")
@@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
def __init__(
self,
-dim,
-out_dim,
-input_resolution,
-depth,
-num_heads,
-window_size,
-downsample=False,
-mlp_ratio=4.,
-qkv_bias=True,
-proj_drop=0.,
-attn_drop=0.,
-drop_path=0.,
-norm_layer=nn.LayerNorm,
-pretrained_window_size=0,
-output_nchw=False,
-):
+dim: int,
+out_dim: int,
+input_resolution: _int_or_tuple_2_t,
+depth: int,
+num_heads: int,
+window_size: _int_or_tuple_2_t,
+downsample: bool = False,
+mlp_ratio: float = 4.,
+qkv_bias: bool = True,
+proj_drop: float = 0.,
+attn_drop: float = 0.,
+drop_path: float = 0.,
+norm_layer: nn.Module = nn.LayerNorm,
+pretrained_window_size: _int_or_tuple_2_t = 0,
+output_nchw: bool = False,
+) -> None:
"""
Args:
dim: Number of input channels.
@@ -428,7 +430,7 @@ class SwinTransformerV2Stage(nn.Module):
)
for i in range(depth)])
-def forward(self, x):
+def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.downsample(x)
for blk in self.blocks:
@@ -438,7 +440,7 @@ class SwinTransformerV2Stage(nn.Module):
x = blk(x)
return x
-def _init_respostnorm(self):
+def _init_respostnorm(self) -> None:
for blk in self.blocks:
nn.init.constant_(blk.norm1.bias, 0)
nn.init.constant_(blk.norm1.weight, 0)
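
The functions annotated in this commit can be exercised directly; below is a quick round-trip sanity check, written as a standalone sketch that assumes timm is installed and that window_partition / window_reverse remain importable from timm.models.swin_transformer_v2:

import torch
from timm.models.swin_transformer_v2 import window_partition, window_reverse

# (B, H, W, C) input; the annotations now document both the tensor arguments
# and the Tuple[int, int] window / image sizes.
x = torch.randn(2, 56, 56, 96)
windows = window_partition(x, (7, 7))                  # (num_windows * B, 7, 7, C)
restored = window_reverse(windows, (7, 7), (56, 56))   # back to (B, H, W, C)
assert restored.shape == x.shape
assert torch.equal(restored, x)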