include typing suggestions by @rwightman

This commit is contained in:
a-r-r-o-w 2023-10-20 02:47:27 +05:30 committed by Ross Wightman
parent 05b0aaca51
commit 5f14bdd564
5 changed files with 25 additions and 24 deletions

View File

@ -52,4 +52,5 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .typing import LayerType, PadType
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_

View File

@ -1,10 +1,9 @@
import functools
import types
from typing import Any, Dict, List, Tuple, Union
from typing import Tuple, Union
import torch.nn
BlockArgs = List[List[Dict[str, Any]]]
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
PadType = Union[str, int, Tuple[int, int]]

View File

@ -11,6 +11,7 @@ import math
import re
from copy import deepcopy
from functools import partial
from typing import Any, Dict, List
import torch.nn as nn
@ -34,6 +35,8 @@ BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
BlockArgs = List[List[Dict[str, Any]]]
def get_bn_args_tf():
return _BN_ARGS_TF.copy()

View File

@ -16,15 +16,14 @@ from torch import Tensor
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from ._typing import BlockArgs, LayerType, PadType
__all__ = ['MobileNetV3', 'MobileNetV3Features']

View File

@ -9,7 +9,7 @@ Copyright 2019, Ross Wightman
"""
import math
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
@ -17,22 +17,21 @@ import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
get_act_layer, get_norm_layer, create_classifier
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
get_attn, get_act_layer, get_norm_layer, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
from ._typing import LayerType
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
def get_padding(kernel_size: int, stride: int, dilation: int = 1):
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
def create_aa(aa_layer, channels, stride=2, enable=True):
def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
if not aa_layer or not enable:
return nn.Identity()
if issubclass(aa_layer, nn.AvgPool2d):
@ -55,11 +54,11 @@ class BasicBlock(nn.Module):
reduce_first: int = 1,
dilation: int = 1,
first_dilation: Optional[int] = None,
act_layer: nn.Module = nn.ReLU,
norm_layer: nn.Module = nn.BatchNorm2d,
attn_layer: Optional[nn.Module] = None,
aa_layer: Optional[nn.Module] = None,
drop_block: Type[nn.Module] = None,
act_layer: Type[nn.Module] = nn.ReLU,
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
attn_layer: Optional[Type[nn.Module]] = None,
aa_layer: Optional[Type[nn.Module]] = None,
drop_block: Optional[Type[nn.Module]] = None,
drop_path: Optional[nn.Module] = None,
):
"""
@ -153,11 +152,11 @@ class Bottleneck(nn.Module):
reduce_first: int = 1,
dilation: int = 1,
first_dilation: Optional[int] = None,
act_layer: nn.Module = nn.ReLU,
norm_layer: nn.Module = nn.BatchNorm2d,
attn_layer: Optional[nn.Module] = None,
aa_layer: Optional[nn.Module] = None,
drop_block: Type[nn.Module] = None,
act_layer: Type[nn.Module] = nn.ReLU,
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
attn_layer: Optional[Type[nn.Module]] = None,
aa_layer: Optional[Type[nn.Module]] = None,
drop_block: Optional[Type[nn.Module]] = None,
drop_path: Optional[nn.Module] = None,
):
"""
@ -296,7 +295,7 @@ def drop_blocks(drop_prob: float = 0.):
def make_blocks(
block_fn: nn.Module,
block_fn: Union[BasicBlock, Bottleneck],
channels: List[int],
block_repeats: List[int],
inplanes: int,
@ -395,7 +394,7 @@ class ResNet(nn.Module):
def __init__(
self,
block: nn.Module,
block: Union[BasicBlock, Bottleneck],
layers: List[int],
num_classes: int = 1000,
in_chans: int = 3,
@ -411,7 +410,7 @@ class ResNet(nn.Module):
avg_down: bool = False,
act_layer: LayerType = nn.ReLU,
norm_layer: LayerType = nn.BatchNorm2d,
aa_layer: Optional[nn.Module] = None,
aa_layer: Optional[Type[nn.Module]] = None,
drop_rate: float = 0.0,
drop_path_rate: float = 0.,
drop_block_rate: float = 0.,