Fix nn.Module type hints
parent 47811bc05a
commit 19aaea3c8f

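The hunks below change constructor parameters that receive a layer class (e.g. nn.GELU, nn.BatchNorm2d, Mlp) from the instance annotation nn.Module to the class annotation Type[nn.Module]; where the default is None, the hint becomes Optional[Type[nn.Module]]. A minimal sketch of the pattern, assuming a hypothetical ExampleBlock that is not part of this diff:

    # Illustrative only: shows why Type[nn.Module] matches how these
    # arguments are actually used (they are classes, instantiated inside
    # the module), whereas nn.Module would promise an instance.
    from typing import Type

    import torch
    import torch.nn as nn

    class ExampleBlock(nn.Module):
        def __init__(
                self,
                dim: int,
                act_layer: Type[nn.Module] = nn.GELU,          # a class, not an instance
                norm_layer: Type[nn.Module] = nn.BatchNorm2d,  # likewise a class
        ) -> None:
            super().__init__()
            self.norm = norm_layer(dim)  # instantiated here...
            self.act = act_layer()       # ...and here

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.act(self.norm(x))

    # Passing a class such as nn.ReLU type-checks against Type[nn.Module];
    # under the old nn.Module annotation a checker would expect nn.ReLU().
    block = ExampleBlock(32, act_layer=nn.ReLU)
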
@@ -106,7 +106,7 @@ class MultiQueryAttention2d(nn.Module):
             padding: Union[str, int, List[int]] = '',
             attn_drop: float = 0.,
             proj_drop: float = 0.,
-            norm_layer: nn.Module = nn.BatchNorm2d,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             use_bias: bool = False,
     ):
         """Initializer.

@@ -54,7 +54,7 @@ class MobileOneBlock(nn.Module):
             use_act: bool = True,
             use_scale_branch: bool = True,
             num_conv_branches: int = 1,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
     ) -> None:
         """Construct a MobileOneBlock module.
 
@@ -426,7 +426,7 @@ class ReparamLargeKernelConv(nn.Module):
 def convolutional_stem(
         in_chs: int,
         out_chs: int,
-        act_layer: nn.Module = nn.GELU,
+        act_layer: Type[nn.Module] = nn.GELU,
         inference_mode: bool = False
 ) -> nn.Sequential:
     """Build convolutional stem with MobileOne blocks.

@@ -545,7 +545,7 @@ class PatchEmbed(nn.Module):
             stride: int,
             in_chs: int,
             embed_dim: int,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
             lkc_use_act: bool = False,
             use_se: bool = False,
             inference_mode: bool = False,

@@ -718,7 +718,7 @@ class ConvMlp(nn.Module):
             in_chs: int,
             hidden_channels: Optional[int] = None,
             out_chs: Optional[int] = None,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
             drop: float = 0.0,
     ) -> None:
         """Build convolutional FFN module.

@@ -890,7 +890,7 @@ class RepMixerBlock(nn.Module):
             dim: int,
             kernel_size: int = 3,
             mlp_ratio: float = 4.0,
-            act_layer: nn.Module = nn.GELU,
+            act_layer: Type[nn.Module] = nn.GELU,
             proj_drop: float = 0.0,
             drop_path: float = 0.0,
             layer_scale_init_value: float = 1e-5,

@@ -947,8 +947,8 @@ class AttentionBlock(nn.Module):
             self,
             dim: int,
             mlp_ratio: float = 4.0,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             proj_drop: float = 0.0,
             drop_path: float = 0.0,
             layer_scale_init_value: float = 1e-5,

@@ -1007,8 +1007,8 @@ class FastVitStage(nn.Module):
             pos_emb_layer: Optional[nn.Module] = None,
             kernel_size: int = 3,
             mlp_ratio: float = 4.0,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
             proj_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
             layer_scale_init_value: Optional[float] = 1e-5,

@@ -1121,8 +1121,8 @@ class FastVit(nn.Module):
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
             global_pool: str = 'avg',
-            norm_layer: nn.Module = nn.BatchNorm2d,
-            act_layer: nn.Module = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.GELU,
             inference_mode: bool = False,
     ) -> None:
         super().__init__()

@@ -316,8 +316,8 @@ class HieraBlock(nn.Module):
             mlp_ratio: float = 4.0,
             drop_path: float = 0.0,
             init_values: Optional[float] = None,
-            norm_layer: nn.Module = nn.LayerNorm,
-            act_layer: nn.Module = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            act_layer: Type[nn.Module] = nn.GELU,
             q_stride: int = 1,
             window_size: int = 0,
             use_expand_proj: bool = True,

@@ -230,7 +230,7 @@ class SwinTransformerV2Block(nn.Module):
             attn_drop: float = 0.,
             drop_path: float = 0.,
             act_layer: LayerType = "gelu",
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
     ):
         """

@@ -422,7 +422,7 @@ class PatchMerging(nn.Module):
             self,
             dim: int,
             out_dim: Optional[int] = None,
-            norm_layer: nn.Module = nn.LayerNorm
+            norm_layer: Type[nn.Module] = nn.LayerNorm
     ):
         """
         Args:

@@ -470,7 +470,7 @@ class SwinTransformerV2Stage(nn.Module):
             attn_drop: float = 0.,
             drop_path: float = 0.,
             act_layer: Union[str, Callable] = 'gelu',
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
             pretrained_window_size: _int_or_tuple_2_t = 0,
             output_nchw: bool = False,
     ) -> None:

@@ -38,8 +38,8 @@ class ConvMlp(nn.Module):
             kernel_size=7,
             mlp_ratio=1.0,
             drop_rate: float = 0.2,
-            act_layer: nn.Module = None,
-            conv_layer: nn.Module = None,
+            act_layer: Optional[Type[nn.Module]] = None,
+            conv_layer: Optional[Type[nn.Module]] = None,
     ):
         super(ConvMlp, self).__init__()
         self.input_kernel_size = kernel_size

@@ -72,9 +72,9 @@ class VGG(nn.Module):
             in_chans: int = 3,
             output_stride: int = 32,
             mlp_ratio: float = 1.0,
-            act_layer: nn.Module = nn.ReLU,
-            conv_layer: nn.Module = nn.Conv2d,
-            norm_layer: nn.Module = None,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            conv_layer: Type[nn.Module] = nn.Conv2d,
+            norm_layer: Optional[Type[nn.Module]] = None,
             global_pool: str = 'avg',
             drop_rate: float = 0.,
     ) -> None:

@@ -295,4 +295,4 @@ def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
     `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
     """
     model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
-    return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
+    return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)

@@ -67,7 +67,7 @@ class Attention(nn.Module):
             proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
-            norm_layer: nn.Module = nn.LayerNorm,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'

@@ -135,9 +135,9 @@ class Block(nn.Module):
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: nn.Module = Mlp,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
         super().__init__()
         self.norm1 = norm_layer(dim)

@@ -184,9 +184,9 @@ class ResPostBlock(nn.Module):
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: nn.Module = Mlp,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
         super().__init__()
         self.init_values = init_values

@@ -247,9 +247,9 @@ class ParallelScalingBlock(nn.Module):
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: Optional[nn.Module] = None,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Optional[Type[nn.Module]] = None,
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'

@@ -342,9 +342,9 @@ class ParallelThingsBlock(nn.Module):
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             drop_path: float = 0.,
-            act_layer: nn.Module = nn.GELU,
-            norm_layer: nn.Module = nn.LayerNorm,
-            mlp_layer: nn.Module = Mlp,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
     ) -> None:
         super().__init__()
         self.num_parallel = num_parallel