diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 2cdcfd98..e4d1499f 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -52,4 +52,5 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .trace_utils import _assert, _float_to_int +from .typing import LayerType, PadType from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ diff --git a/timm/models/_typing.py b/timm/layers/typing.py similarity index 66% rename from timm/models/_typing.py rename to timm/layers/typing.py index 9d4b46d7..35aa9f88 100644 --- a/timm/models/_typing.py +++ b/timm/layers/typing.py @@ -1,10 +1,9 @@ import functools import types -from typing import Any, Dict, List, Tuple, Union +from typing import Tuple, Union import torch.nn -BlockArgs = List[List[Dict[str, Any]]] LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module] PadType = Union[str, int, Tuple[int, int]] diff --git a/timm/models/_efficientnet_builder.py b/timm/models/_efficientnet_builder.py index b5dbeaae..1e3161d6 100644 --- a/timm/models/_efficientnet_builder.py +++ b/timm/models/_efficientnet_builder.py @@ -11,6 +11,7 @@ import math import re from copy import deepcopy from functools import partial +from typing import Any, Dict, List import torch.nn as nn @@ -34,6 +35,8 @@ BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 BN_EPS_TF_DEFAULT = 1e-3 _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) +BlockArgs = List[List[Dict[str, Any]]] + def get_bn_args_tf(): return _BN_ARGS_TF.copy() diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index d56c0f54..f6cd8e08 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -16,15 +16,14 @@ from torch import Tensor from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer +from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer from ._builder import build_model_with_cfg, pretrained_cfg_for_features from ._efficientnet_blocks import SqueezeExcite -from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ +from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from ._features import FeatureInfo, FeatureHooks from ._manipulate import checkpoint_seq from ._registry import generate_default_cfgs, register_model, register_model_deprecations -from ._typing import BlockArgs, LayerType, PadType __all__ = ['MobileNetV3', 'MobileNetV3Features'] diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 67869261..429a54c4 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -9,7 +9,7 @@ Copyright 2019, Ross Wightman """ import math from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -17,22 +17,21 @@ import torch.nn.functional as F from torch import Tensor from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ - get_act_layer, get_norm_layer, create_classifier +from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \ + get_attn, get_act_layer, get_norm_layer, create_classifier from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations -from ._typing import LayerType __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this -def get_padding(kernel_size: int, stride: int, dilation: int = 1): +def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int: padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding -def create_aa(aa_layer, channels, stride=2, enable=True): +def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module: if not aa_layer or not enable: return nn.Identity() if issubclass(aa_layer, nn.AvgPool2d): @@ -55,11 +54,11 @@ class BasicBlock(nn.Module): reduce_first: int = 1, dilation: int = 1, first_dilation: Optional[int] = None, - act_layer: nn.Module = nn.ReLU, - norm_layer: nn.Module = nn.BatchNorm2d, - attn_layer: Optional[nn.Module] = None, - aa_layer: Optional[nn.Module] = None, - drop_block: Type[nn.Module] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, drop_path: Optional[nn.Module] = None, ): """ @@ -153,11 +152,11 @@ class Bottleneck(nn.Module): reduce_first: int = 1, dilation: int = 1, first_dilation: Optional[int] = None, - act_layer: nn.Module = nn.ReLU, - norm_layer: nn.Module = nn.BatchNorm2d, - attn_layer: Optional[nn.Module] = None, - aa_layer: Optional[nn.Module] = None, - drop_block: Type[nn.Module] = None, + act_layer: Type[nn.Module] = nn.ReLU, + norm_layer: Type[nn.Module] = nn.BatchNorm2d, + attn_layer: Optional[Type[nn.Module]] = None, + aa_layer: Optional[Type[nn.Module]] = None, + drop_block: Optional[Type[nn.Module]] = None, drop_path: Optional[nn.Module] = None, ): """ @@ -296,7 +295,7 @@ def drop_blocks(drop_prob: float = 0.): def make_blocks( - block_fn: nn.Module, + block_fn: Union[BasicBlock, Bottleneck], channels: List[int], block_repeats: List[int], inplanes: int, @@ -395,7 +394,7 @@ class ResNet(nn.Module): def __init__( self, - block: nn.Module, + block: Union[BasicBlock, Bottleneck], layers: List[int], num_classes: int = 1000, in_chans: int = 3, @@ -411,7 +410,7 @@ class ResNet(nn.Module): avg_down: bool = False, act_layer: LayerType = nn.ReLU, norm_layer: LayerType = nn.BatchNorm2d, - aa_layer: Optional[nn.Module] = None, + aa_layer: Optional[Type[nn.Module]] = None, drop_rate: float = 0.0, drop_path_rate: float = 0., drop_block_rate: float = 0.,