Merge pull request #2400 from adamjstewart/types/nn-module

Fix nn.Module type hints
Ross Wightman 2025-01-14 08:23:57 -08:00 committed by GitHub
commit ff77dfa825
6 changed files with 40 additions and 40 deletions
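
Every hunk in this commit makes the same kind of change: constructor parameters that receive a layer class (nn.BatchNorm2d, nn.GELU, nn.LayerNorm, Mlp, ...) were annotated as nn.Module, which describes an instance, and are now annotated as Type[nn.Module], which describes the class that the module instantiates itself. A minimal before/after sketch of the pattern (the function names below are illustrative, not part of the patch):

from typing import Type

from torch import nn


# Before: the default nn.BatchNorm2d is a class, so a type checker reports an
# incompatible default for a parameter annotated as an nn.Module instance.
def make_norm_old(norm_layer: nn.Module = nn.BatchNorm2d) -> nn.Module:
    return norm_layer(64)


# After: Type[nn.Module] matches the class default, and the call site still
# instantiates the class exactly as before.
def make_norm_new(norm_layer: Type[nn.Module] = nn.BatchNorm2d) -> nn.Module:
    return norm_layer(64)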

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union
 import torch
 from torch import nn as nn
@@ -106,7 +106,7 @@ class MultiQueryAttention2d(nn.Module):
 padding: Union[str, int, List[int]] = '',
 attn_drop: float = 0.,
 proj_drop: float = 0.,
-norm_layer: nn.Module = nn.BatchNorm2d,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
 use_bias: bool = False,
 ):
 """Initializer.

View File

@@ -7,7 +7,7 @@
 #
 import os
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
 import torch
 import torch.nn as nn
@@ -54,7 +54,7 @@ class MobileOneBlock(nn.Module):
 use_act: bool = True,
 use_scale_branch: bool = True,
 num_conv_branches: int = 1,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 ) -> None:
 """Construct a MobileOneBlock module.
@@ -426,7 +426,7 @@ class ReparamLargeKernelConv(nn.Module):
 def convolutional_stem(
 in_chs: int,
 out_chs: int,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 inference_mode: bool = False
 ) -> nn.Sequential:
 """Build convolutional stem with MobileOne blocks.
@@ -545,7 +545,7 @@ class PatchEmbed(nn.Module):
 stride: int,
 in_chs: int,
 embed_dim: int,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 lkc_use_act: bool = False,
 use_se: bool = False,
 inference_mode: bool = False,
@@ -718,7 +718,7 @@ class ConvMlp(nn.Module):
 in_chs: int,
 hidden_channels: Optional[int] = None,
 out_chs: Optional[int] = None,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 drop: float = 0.0,
 ) -> None:
 """Build convolutional FFN module.
@@ -890,7 +890,7 @@ class RepMixerBlock(nn.Module):
 dim: int,
 kernel_size: int = 3,
 mlp_ratio: float = 4.0,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 proj_drop: float = 0.0,
 drop_path: float = 0.0,
 layer_scale_init_value: float = 1e-5,
@@ -947,8 +947,8 @@ class AttentionBlock(nn.Module):
 self,
 dim: int,
 mlp_ratio: float = 4.0,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.BatchNorm2d,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
 proj_drop: float = 0.0,
 drop_path: float = 0.0,
 layer_scale_init_value: float = 1e-5,
@@ -1007,8 +1007,8 @@ class FastVitStage(nn.Module):
 pos_emb_layer: Optional[nn.Module] = None,
 kernel_size: int = 3,
 mlp_ratio: float = 4.0,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.BatchNorm2d,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
 proj_drop_rate: float = 0.0,
 drop_path_rate: float = 0.0,
 layer_scale_init_value: Optional[float] = 1e-5,
@@ -1121,8 +1121,8 @@ class FastVit(nn.Module):
 fork_feat: bool = False,
 cls_ratio: float = 2.0,
 global_pool: str = 'avg',
-norm_layer: nn.Module = nn.BatchNorm2d,
-act_layer: nn.Module = nn.GELU,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+act_layer: Type[nn.Module] = nn.GELU,
 inference_mode: bool = False,
 ) -> None:
 super().__init__()
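
One design note on the hunks above: this file imports functools.partial, and layer arguments in this codebase are often built with partial (for example, a norm class with a preset eps). A partial object is a callable but not a class, so it satisfies Callable[..., nn.Module] but, strictly speaking, not Type[nn.Module]. A minimal sketch with hypothetical helper names:

from functools import partial
from typing import Callable, Type

from torch import nn


# Hypothetical helpers, for illustration only.
def build_act_strict(act_layer: Type[nn.Module] = nn.GELU) -> nn.Module:
    return act_layer()


def build_act_loose(act_layer: Callable[..., nn.Module] = nn.GELU) -> nn.Module:
    return act_layer()


build_act_strict(nn.GELU)                    # a class: fits both annotations
build_act_loose(partial(nn.LeakyReLU, 0.1))  # a partial: callable, but not a class,
                                             # so it only fits the Callable form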

View File

@@ -316,8 +316,8 @@ class HieraBlock(nn.Module):
 mlp_ratio: float = 4.0,
 drop_path: float = 0.0,
 init_values: Optional[float] = None,
-norm_layer: nn.Module = nn.LayerNorm,
-act_layer: nn.Module = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+act_layer: Type[nn.Module] = nn.GELU,
 q_stride: int = 1,
 window_size: int = 0,
 use_expand_proj: bool = True,

View File

@@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Type, Union
 import torch
 import torch.nn as nn
@@ -230,7 +230,7 @@ class SwinTransformerV2Block(nn.Module):
 attn_drop: float = 0.,
 drop_path: float = 0.,
 act_layer: LayerType = "gelu",
-norm_layer: nn.Module = nn.LayerNorm,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
 pretrained_window_size: _int_or_tuple_2_t = 0,
 ):
 """
@@ -422,7 +422,7 @@ class PatchMerging(nn.Module):
 self,
 dim: int,
 out_dim: Optional[int] = None,
-norm_layer: nn.Module = nn.LayerNorm
+norm_layer: Type[nn.Module] = nn.LayerNorm
 ):
 """
 Args:
@@ -470,7 +470,7 @@ class SwinTransformerV2Stage(nn.Module):
 attn_drop: float = 0.,
 drop_path: float = 0.,
 act_layer: Union[str, Callable] = 'gelu',
-norm_layer: nn.Module = nn.LayerNorm,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
 pretrained_window_size: _int_or_tuple_2_t = 0,
 output_nchw: bool = False,
 ) -> None:

View File

@@ -5,7 +5,7 @@ timm functionality.
 Copyright 2021 Ross Wightman
 """
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Type, Union, cast
 import torch
 import torch.nn as nn
@@ -38,8 +38,8 @@ class ConvMlp(nn.Module):
 kernel_size=7,
 mlp_ratio=1.0,
 drop_rate: float = 0.2,
-act_layer: nn.Module = None,
-conv_layer: nn.Module = None,
+act_layer: Optional[Type[nn.Module]] = None,
+conv_layer: Optional[Type[nn.Module]] = None,
 ):
 super(ConvMlp, self).__init__()
 self.input_kernel_size = kernel_size
@@ -72,9 +72,9 @@ class VGG(nn.Module):
 in_chans: int = 3,
 output_stride: int = 32,
 mlp_ratio: float = 1.0,
-act_layer: nn.Module = nn.ReLU,
-conv_layer: nn.Module = nn.Conv2d,
-norm_layer: nn.Module = None,
+act_layer: Type[nn.Module] = nn.ReLU,
+conv_layer: Type[nn.Module] = nn.Conv2d,
+norm_layer: Optional[Type[nn.Module]] = None,
 global_pool: str = 'avg',
 drop_rate: float = 0.,
 ) -> None:
@@ -295,4 +295,4 @@ def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
 `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
 """
 model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
-return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
+return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
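
The hunks above also use Optional[Type[nn.Module]] = None for layer classes that may be omitted. A minimal sketch of how such a None default is typically resolved inside __init__ (the class and attribute names here are illustrative, not timm's):

from typing import Optional, Type

from torch import nn


class TinyHead(nn.Module):
    def __init__(
            self,
            in_features: int,
            hidden_features: int,
            act_layer: Optional[Type[nn.Module]] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
    ) -> None:
        super().__init__()
        act_layer = act_layer or nn.ReLU  # fall back to a concrete class
        self.fc = nn.Linear(in_features, hidden_features)
        # a missing norm class becomes a no-op layer
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.act = act_layer()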

View File

@@ -67,7 +67,7 @@ class Attention(nn.Module):
 proj_bias: bool = True,
 attn_drop: float = 0.,
 proj_drop: float = 0.,
-norm_layer: nn.Module = nn.LayerNorm,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
 ) -> None:
 super().__init__()
 assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -135,9 +135,9 @@ class Block(nn.Module):
 attn_drop: float = 0.,
 init_values: Optional[float] = None,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: nn.Module = Mlp,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Type[nn.Module] = Mlp,
 ) -> None:
 super().__init__()
 self.norm1 = norm_layer(dim)
@@ -184,9 +184,9 @@ class ResPostBlock(nn.Module):
 attn_drop: float = 0.,
 init_values: Optional[float] = None,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: nn.Module = Mlp,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Type[nn.Module] = Mlp,
 ) -> None:
 super().__init__()
 self.init_values = init_values
@@ -247,9 +247,9 @@ class ParallelScalingBlock(nn.Module):
 attn_drop: float = 0.,
 init_values: Optional[float] = None,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: Optional[nn.Module] = None,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Optional[Type[nn.Module]] = None,
 ) -> None:
 super().__init__()
 assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -342,9 +342,9 @@ class ParallelThingsBlock(nn.Module):
 proj_drop: float = 0.,
 attn_drop: float = 0.,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: nn.Module = Mlp,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Type[nn.Module] = Mlp,
 ) -> None:
 super().__init__()
 self.num_parallel = num_parallel
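
For context on how these class-valued parameters are consumed, a simplified sketch of a transformer-style block (an illustrative toy, not timm's Block): each parameter is called inside __init__ to build a sub-module, which is exactly what Type[nn.Module] expresses.

from typing import Type

import torch
from torch import nn


class TinyBlock(nn.Module):
    def __init__(
            self,
            dim: int,
            mlp_ratio: float = 4.0,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.norm = norm_layer(dim)      # norm class called with its own args
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            act_layer(),                 # activation class called with no args
            nn.Linear(hidden, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm(x))


x = torch.randn(2, 16, 64)
y = TinyBlock(64)(x)  # the defaults work; any nn.Module subclass can be swapped in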