mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
add type annotations in the code of swin_transformer_v2
This commit is contained in:
parent bbe798317f
commit 7da34a999a
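For context, the first functions annotated below are the window partition/reverse pair. A minimal usage sketch (shapes follow the docstrings shown in the diff; the import path timm.models.swin_transformer_v2 and the concrete sizes are assumptions for illustration only):

import torch
from timm.models.swin_transformer_v2 import window_partition, window_reverse  # assumed import path

x = torch.randn(2, 56, 56, 96)                 # (B, H, W, C)
windows = window_partition(x, (7, 7))          # (num_windows * B, 7, 7, C)
y = window_reverse(windows, (7, 7), (56, 56))  # back to (B, H, W, C)
assert torch.equal(x, y)                       # partition followed by reverse round-trips exactly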
@@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union, Set, Dict
 
 import torch
 import torch.nn as nn
@@ -32,7 +32,7 @@ __all__ = ['SwinTransformerV2']  # model_registry will add each entrypoint fn to
 _int_or_tuple_2_t = Union[int, Tuple[int, int]]
 
 
-def window_partition(x, window_size: Tuple[int, int]):
+def window_partition(x: torch.Tensor, window_size: Tuple[int, int]) -> torch.Tensor:
     """
     Args:
         x: (B, H, W, C)
@@ -48,7 +48,7 @@ def window_partition(x, window_size: Tuple[int, int]):
 
 
 @register_notrace_function  # reason: int argument is a Proxy
-def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
+def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], img_size: Tuple[int, int]) -> torch.Tensor:
     """
     Args:
         windows: (num_windows * B, window_size[0], window_size[1], C)
@@ -81,14 +81,14 @@ class WindowAttention(nn.Module):
 
     def __init__(
             self,
-            dim,
-            window_size,
-            num_heads,
-            qkv_bias=True,
-            attn_drop=0.,
-            proj_drop=0.,
-            pretrained_window_size=[0, 0],
-    ):
+            dim: int,
+            window_size: Tuple[int, int],
+            num_heads: int,
+            qkv_bias: bool = True,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            pretrained_window_size: Tuple[int, int] = (0, 0),
+    ) -> None:
         super().__init__()
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
@@ -149,7 +149,7 @@ class WindowAttention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop)
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x, mask: Optional[torch.Tensor] = None):
+    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Args:
             x: input features with shape of (num_windows*B, N, C)
@@ -197,20 +197,20 @@ class SwinTransformerV2Block(nn.Module):
 
     def __init__(
             self,
-            dim,
-            input_resolution,
-            num_heads,
-            window_size=7,
-            shift_size=0,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            pretrained_window_size=0,
-    ):
+            dim: int,
+            input_resolution: _int_or_tuple_2_t,
+            num_heads: int,
+            window_size: _int_or_tuple_2_t = 7,
+            shift_size: _int_or_tuple_2_t = 0,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: nn.Module = nn.GELU,
+            norm_layer: nn.Module = nn.LayerNorm,
+            pretrained_window_size: _int_or_tuple_2_t = 0,
+    ) -> None:
         """
         Args:
             dim: Number of input channels.
@@ -282,14 +282,16 @@ class SwinTransformerV2Block(nn.Module):
 
         self.register_buffer("attn_mask", attn_mask, persistent=False)
 
-    def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    def _calc_window_shift(self,
+                           target_window_size: _int_or_tuple_2_t,
+                           target_shift_size: _int_or_tuple_2_t) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         target_window_size = to_2tuple(target_window_size)
         target_shift_size = to_2tuple(target_shift_size)
         window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
         shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
         return tuple(window_size), tuple(shift_size)
 
-    def _attn(self, x):
+    def _attn(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
 
         # cyclic shift
@@ -317,7 +319,7 @@ class SwinTransformerV2Block(nn.Module):
             x = shifted_x
         return x
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
         x = x + self.drop_path1(self.norm1(self._attn(x)))
         x = x.reshape(B, -1, C)
@@ -330,7 +332,7 @@ class PatchMerging(nn.Module):
     """ Patch Merging Layer.
     """
 
-    def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm):
+    def __init__(self, dim: int, out_dim: Optional[int] = None, norm_layer: nn.Module = nn.LayerNorm) -> None:
         """
         Args:
             dim (int): Number of input channels.
@@ -343,7 +345,7 @@ class PatchMerging(nn.Module):
         self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)
         self.norm = norm_layer(self.out_dim)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, H, W, C = x.shape
         _assert(H % 2 == 0, f"x height ({H}) is not even.")
         _assert(W % 2 == 0, f"x width ({W}) is not even.")
@@ -359,22 +361,22 @@ class SwinTransformerV2Stage(nn.Module):
 
     def __init__(
             self,
-            dim,
-            out_dim,
-            input_resolution,
-            depth,
-            num_heads,
-            window_size,
-            downsample=False,
-            mlp_ratio=4.,
-            qkv_bias=True,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            norm_layer=nn.LayerNorm,
-            pretrained_window_size=0,
-            output_nchw=False,
-    ):
+            dim: int,
+            out_dim: int,
+            input_resolution: _int_or_tuple_2_t,
+            depth: int,
+            num_heads: int,
+            window_size: _int_or_tuple_2_t,
+            downsample: bool = False,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            norm_layer: nn.Module = nn.LayerNorm,
+            pretrained_window_size: _int_or_tuple_2_t = 0,
+            output_nchw: bool = False,
+    ) -> None:
         """
         Args:
             dim: Number of input channels.
@@ -428,7 +430,7 @@ class SwinTransformerV2Stage(nn.Module):
             )
             for i in range(depth)])
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.downsample(x)
 
         for blk in self.blocks:
@@ -438,7 +440,7 @@ class SwinTransformerV2Stage(nn.Module):
                 x = blk(x)
         return x
 
-    def _init_respostnorm(self):
+    def _init_respostnorm(self) -> None:
         for blk in self.blocks:
             nn.init.constant_(blk.norm1.bias, 0)
             nn.init.constant_(blk.norm1.weight, 0)
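Many of the new annotations use the _int_or_tuple_2_t = Union[int, Tuple[int, int]] alias: callers may pass either a single int or an explicit 2-tuple, and the code normalizes the value with to_2tuple (as _calc_window_shift does above). A minimal self-contained sketch of that normalization, using a stand-in for timm's to_2tuple helper:

from typing import Tuple, Union

_int_or_tuple_2_t = Union[int, Tuple[int, int]]

def _to_2tuple(v: _int_or_tuple_2_t) -> Tuple[int, int]:
    # stand-in for timm's to_2tuple: broadcast an int, pass a 2-tuple through
    return (v, v) if isinstance(v, int) else v

print(_to_2tuple(7))       # (7, 7)
print(_to_2tuple((8, 6)))  # (8, 6)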