mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
add type annotations in the code of swin_transformer_v2
This commit is contained in:
parent
bbe798317f
commit
7da34a999a
@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
|
||||
# Written by Ze Liu
|
||||
# --------------------------------------------------------
|
||||
import math
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
from typing import Callable, Optional, Tuple, Union, Set, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -32,7 +32,7 @@ __all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to
|
||||
_int_or_tuple_2_t = Union[int, Tuple[int, int]]
|
||||
|
||||
|
||||
def window_partition(x, window_size: Tuple[int, int]):
|
||||
def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
|
||||
|
||||
|
||||
@register_notrace_function # reason: int argument is a Proxy
|
||||
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
|
||||
def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows * B, window_size[0], window_size[1], C)
|
||||
@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
pretrained_window_size=[0, 0],
|
||||
):
|
||||
dim: int,
|
||||
window_size: Tuple[int, int],
|
||||
num_heads: int,
|
||||
qkv_bias: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
pretrained_window_size: Tuple[int, int] = (0, 0),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
@ -149,7 +149,7 @@ class WindowAttention(nn.Module):
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask: Optional[torch.Tensor] = None):
|
||||
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
input_resolution,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pretrained_window_size=0,
|
||||
):
|
||||
dim: int,
|
||||
input_resolution: _int_or_tuple_2_t,
|
||||
num_heads: int,
|
||||
window_size: _int_or_tuple_2_t = 7,
|
||||
shift_size: _int_or_tuple_2_t = 0,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
proj_drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
norm_layer: nn.Module = nn.LayerNorm,
|
||||
pretrained_window_size: _int_or_tuple_2_t = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
@ -282,14 +282,16 @@ class SwinTransformerV2Block(nn.Module):
|
||||
|
||||
self.register_buffer("attn_mask", attn_mask, persistent=False)
|
||||
|
||||
def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
def _calc_window_shift(self,
|
||||
target_window_size: _int_or_tuple_2_t,
|
||||
target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]:
|
||||
target_window_size = to_2tuple(target_window_size)
|
||||
target_shift_size = to_2tuple(target_shift_size)
|
||||
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
|
||||
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
|
||||
return tuple(window_size), tuple(shift_size)
|
||||
|
||||
def _attn(self, x):
|
||||
def _attn(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, C = x.shape
|
||||
|
||||
# cyclic shift
|
||||
@ -317,7 +319,7 @@ class SwinTransformerV2Block(nn.Module):
|
||||
x = shifted_x
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, C = x.shape
|
||||
x = x + self.drop_path1(self.norm1(self._attn(x)))
|
||||
x = x.reshape(B, -1, C)
|
||||
@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
|
||||
""" Patch Merging Layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
|
||||
def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
@ -343,7 +345,7 @@ class PatchMerging(nn.Module):
|
||||
self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
|
||||
self.norm = norm_layer(self.out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, H, W, C = x.shape
|
||||
_assert(H % 2 == 0, f"x height ({H}) is not even.")
|
||||
_assert(W % 2 == 0, f"x width ({W}) is not even.")
|
||||
@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
out_dim,
|
||||
input_resolution,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size,
|
||||
downsample=False,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
proj_drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pretrained_window_size=0,
|
||||
output_nchw=False,
|
||||
):
|
||||
dim: int,
|
||||
out_dim: int,
|
||||
input_resolution: _int_or_tuple_2_t,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
window_size: _int_or_tuple_2_t,
|
||||
downsample: bool = False,
|
||||
mlp_ratio: float = 4.,
|
||||
qkv_bias: bool = True,
|
||||
proj_drop: float = 0.,
|
||||
attn_drop: float = 0.,
|
||||
drop_path: float = 0.,
|
||||
norm_layer: nn.Module = nn.LayerNorm,
|
||||
pretrained_window_size: _int_or_tuple_2_t = 0,
|
||||
output_nchw: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim: Number of input channels.
|
||||
@ -428,7 +430,7 @@ class SwinTransformerV2Stage(nn.Module):
|
||||
)
|
||||
for i in range(depth)])
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.downsample(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
@ -438,7 +440,7 @@ class SwinTransformerV2Stage(nn.Module):
|
||||
x = blk(x)
|
||||
return x
|
||||
|
||||
def _init_respostnorm(self):
|
||||
def _init_respostnorm(self) -> None:
|
||||
for blk in self.blocks:
|
||||
nn.init.constant_(blk.norm1.bias, 0)
|
||||
nn.init.constant_(blk.norm1.weight, 0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user