Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Merge pull request #2400 from adamjstewart/types/nn-module

Fix nn.Module type hints

commit ff77dfa825
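The change is mechanical but meaningful: arguments such as act_layer, norm_layer, conv_layer and mlp_layer receive a module class (for example nn.GELU or nn.LayerNorm) that the constructor instantiates later, so the correct annotation is Type[nn.Module]; the old nn.Module annotation describes an already-constructed instance. A minimal sketch of the pattern, using a hypothetical TinyBlock rather than any class from the diff below:

from typing import Type

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    # Hypothetical example, not part of the commit: act_layer/norm_layer are
    # passed as classes and instantiated inside __init__, hence Type[nn.Module].
    def __init__(
            self,
            dim: int,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.norm = norm_layer(dim)  # the class is called here to build an instance
        self.act = act_layer()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(x))


block = TinyBlock(64)                     # defaults are classes, not instances
block = TinyBlock(64, act_layer=nn.SiLU)  # swap in another activation class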
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Type, Union

 import torch
 from torch import nn as nn
@@ -106,7 +106,7 @@ class MultiQueryAttention2d(nn.Module):
 padding: Union[str, int, List[int]] = '',
 attn_drop: float = 0.,
 proj_drop: float = 0.,
-norm_layer: nn.Module = nn.BatchNorm2d,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
 use_bias: bool = False,
 ):
 """Initializer.
@@ -7,7 +7,7 @@
 #
 import os
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -54,7 +54,7 @@ class MobileOneBlock(nn.Module):
 use_act: bool = True,
 use_scale_branch: bool = True,
 num_conv_branches: int = 1,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 ) -> None:
 """Construct a MobileOneBlock module.

@@ -426,7 +426,7 @@ class ReparamLargeKernelConv(nn.Module):
 def convolutional_stem(
 in_chs: int,
 out_chs: int,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 inference_mode: bool = False
 ) -> nn.Sequential:
 """Build convolutional stem with MobileOne blocks.
@@ -545,7 +545,7 @@ class PatchEmbed(nn.Module):
 stride: int,
 in_chs: int,
 embed_dim: int,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 lkc_use_act: bool = False,
 use_se: bool = False,
 inference_mode: bool = False,
@@ -718,7 +718,7 @@ class ConvMlp(nn.Module):
 in_chs: int,
 hidden_channels: Optional[int] = None,
 out_chs: Optional[int] = None,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 drop: float = 0.0,
 ) -> None:
 """Build convolutional FFN module.
@@ -890,7 +890,7 @@ class RepMixerBlock(nn.Module):
 dim: int,
 kernel_size: int = 3,
 mlp_ratio: float = 4.0,
-act_layer: nn.Module = nn.GELU,
+act_layer: Type[nn.Module] = nn.GELU,
 proj_drop: float = 0.0,
 drop_path: float = 0.0,
 layer_scale_init_value: float = 1e-5,
@@ -947,8 +947,8 @@ class AttentionBlock(nn.Module):
 self,
 dim: int,
 mlp_ratio: float = 4.0,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.BatchNorm2d,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
 proj_drop: float = 0.0,
 drop_path: float = 0.0,
 layer_scale_init_value: float = 1e-5,
@@ -1007,8 +1007,8 @@ class FastVitStage(nn.Module):
 pos_emb_layer: Optional[nn.Module] = None,
 kernel_size: int = 3,
 mlp_ratio: float = 4.0,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.BatchNorm2d,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
 proj_drop_rate: float = 0.0,
 drop_path_rate: float = 0.0,
 layer_scale_init_value: Optional[float] = 1e-5,
@@ -1121,8 +1121,8 @@ class FastVit(nn.Module):
 fork_feat: bool = False,
 cls_ratio: float = 2.0,
 global_pool: str = 'avg',
-norm_layer: nn.Module = nn.BatchNorm2d,
-act_layer: nn.Module = nn.GELU,
+norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+act_layer: Type[nn.Module] = nn.GELU,
 inference_mode: bool = False,
 ) -> None:
 super().__init__()
@@ -316,8 +316,8 @@ class HieraBlock(nn.Module):
 mlp_ratio: float = 4.0,
 drop_path: float = 0.0,
 init_values: Optional[float] = None,
-norm_layer: nn.Module = nn.LayerNorm,
-act_layer: nn.Module = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+act_layer: Type[nn.Module] = nn.GELU,
 q_stride: int = 1,
 window_size: int = 0,
 use_expand_proj: bool = True,
@@ -13,7 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Written by Ze Liu
 # --------------------------------------------------------
 import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -230,7 +230,7 @@ class SwinTransformerV2Block(nn.Module):
 attn_drop: float = 0.,
 drop_path: float = 0.,
 act_layer: LayerType = "gelu",
-norm_layer: nn.Module = nn.LayerNorm,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
 pretrained_window_size: _int_or_tuple_2_t = 0,
 ):
 """
@@ -422,7 +422,7 @@ class PatchMerging(nn.Module):
 self,
 dim: int,
 out_dim: Optional[int] = None,
-norm_layer: nn.Module = nn.LayerNorm
+norm_layer: Type[nn.Module] = nn.LayerNorm
 ):
 """
 Args:
@@ -470,7 +470,7 @@ class SwinTransformerV2Stage(nn.Module):
 attn_drop: float = 0.,
 drop_path: float = 0.,
 act_layer: Union[str, Callable] = 'gelu',
-norm_layer: nn.Module = nn.LayerNorm,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
 pretrained_window_size: _int_or_tuple_2_t = 0,
 output_nchw: bool = False,
 ) -> None:
@@ -5,7 +5,7 @@ timm functionality.

 Copyright 2021 Ross Wightman
 """
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Optional, Type, Union, cast

 import torch
 import torch.nn as nn
@@ -38,8 +38,8 @@ class ConvMlp(nn.Module):
 kernel_size=7,
 mlp_ratio=1.0,
 drop_rate: float = 0.2,
-act_layer: nn.Module = None,
-conv_layer: nn.Module = None,
+act_layer: Optional[Type[nn.Module]] = None,
+conv_layer: Optional[Type[nn.Module]] = None,
 ):
 super(ConvMlp, self).__init__()
 self.input_kernel_size = kernel_size
@@ -72,9 +72,9 @@ class VGG(nn.Module):
 in_chans: int = 3,
 output_stride: int = 32,
 mlp_ratio: float = 1.0,
-act_layer: nn.Module = nn.ReLU,
-conv_layer: nn.Module = nn.Conv2d,
-norm_layer: nn.Module = None,
+act_layer: Type[nn.Module] = nn.ReLU,
+conv_layer: Type[nn.Module] = nn.Conv2d,
+norm_layer: Optional[Type[nn.Module]] = None,
 global_pool: str = 'avg',
 drop_rate: float = 0.,
 ) -> None:
@@ -295,4 +295,4 @@ def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
 `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
 """
 model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
 return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
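The VGG hunks above also tighten arguments that default to None: a value that is either a layer class or None is annotated Optional[Type[nn.Module]] and resolved to a concrete class inside the constructor. A minimal sketch of that pattern (hypothetical helper, not the repository's code):

from typing import Optional, Type

import torch.nn as nn


def build_head(
        in_features: int,
        num_classes: int,
        act_layer: Optional[Type[nn.Module]] = None,   # None -> fall back to a default class
        norm_layer: Optional[Type[nn.Module]] = None,   # None -> skip normalization entirely
) -> nn.Sequential:
    act_layer = act_layer or nn.ReLU
    layers = [nn.Linear(in_features, in_features)]
    if norm_layer is not None:
        layers.append(norm_layer(in_features))
    layers.append(act_layer())
    layers.append(nn.Linear(in_features, num_classes))
    return nn.Sequential(*layers)


head = build_head(512, 1000, norm_layer=nn.LayerNorm)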
@@ -67,7 +67,7 @@ class Attention(nn.Module):
 proj_bias: bool = True,
 attn_drop: float = 0.,
 proj_drop: float = 0.,
-norm_layer: nn.Module = nn.LayerNorm,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
 ) -> None:
 super().__init__()
 assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -135,9 +135,9 @@ class Block(nn.Module):
 attn_drop: float = 0.,
 init_values: Optional[float] = None,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: nn.Module = Mlp,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Type[nn.Module] = Mlp,
 ) -> None:
 super().__init__()
 self.norm1 = norm_layer(dim)
@@ -184,9 +184,9 @@ class ResPostBlock(nn.Module):
 attn_drop: float = 0.,
 init_values: Optional[float] = None,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: nn.Module = Mlp,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Type[nn.Module] = Mlp,
 ) -> None:
 super().__init__()
 self.init_values = init_values
@@ -247,9 +247,9 @@ class ParallelScalingBlock(nn.Module):
 attn_drop: float = 0.,
 init_values: Optional[float] = None,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: Optional[nn.Module] = None,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Optional[Type[nn.Module]] = None,
 ) -> None:
 super().__init__()
 assert dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -342,9 +342,9 @@ class ParallelThingsBlock(nn.Module):
 proj_drop: float = 0.,
 attn_drop: float = 0.,
 drop_path: float = 0.,
-act_layer: nn.Module = nn.GELU,
-norm_layer: nn.Module = nn.LayerNorm,
-mlp_layer: nn.Module = Mlp,
+act_layer: Type[nn.Module] = nn.GELU,
+norm_layer: Type[nn.Module] = nn.LayerNorm,
+mlp_layer: Type[nn.Module] = Mlp,
 ) -> None:
 super().__init__()
 self.num_parallel = num_parallel
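Nothing changes at runtime; the point of the commit is that static checkers now agree with how these arguments are actually used. An illustrative before/after sketch (hypothetical helpers, not code from the repository):

from typing import Type

import torch.nn as nn


# Old-style hint: the default value is a class, not an nn.Module instance, so a
# checker such as mypy reports an incompatible default, and callers that pass
# nn.BatchNorm1d (again a class) are flagged as well.
def make_norm_old(dim: int, norm_layer: nn.Module = nn.LayerNorm) -> nn.Module:
    return norm_layer(dim)


# New-style hint: the class-valued default and class-valued arguments type-check.
def make_norm_new(dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> nn.Module:
    return norm_layer(dim)


norm = make_norm_new(64, norm_layer=nn.BatchNorm1d)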