diff --git a/timm/models/_typing.py b/timm/models/_typing.py
new file mode 100644
index 00000000..9d4b46d7
--- /dev/null
+++ b/timm/models/_typing.py
@@ -0,0 +1,10 @@
+import functools
+import types
+from typing import Any, Dict, List, Tuple, Union
+
+import torch.nn
+
+
+BlockArgs = List[List[Dict[str, Any]]]
+LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
+PadType = Union[str, int, Tuple[int, int]]
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 8de94f7e..d56c0f54 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -7,11 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import List
+from typing import Callable, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -23,6 +24,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
 from ._features import FeatureInfo, FeatureHooks
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
+from ._typing import BlockArgs, LayerType, PadType
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']
 
@@ -44,23 +46,42 @@ class MobileNetV3(nn.Module):
 
     def __init__(
             self,
-            block_args,
-            num_classes=1000,
-            in_chans=3,
-            stem_size=16,
-            fix_stem=False,
-            num_features=1280,
-            head_bias=True,
-            pad_type='',
-            act_layer=None,
-            norm_layer=None,
-            se_layer=None,
-            se_from_exp=True,
-            round_chs_fn=round_channels,
-            drop_rate=0.,
-            drop_path_rate=0.,
-            global_pool='avg',
+            block_args: BlockArgs,
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            stem_size: int = 16,
+            fix_stem: bool = False,
+            num_features: int = 1280,
+            head_bias: bool = True,
+            pad_type: PadType = '',
+            act_layer: Optional[LayerType] = None,
+            norm_layer: Optional[LayerType] = None,
+            se_layer: Optional[LayerType] = None,
+            se_from_exp: bool = True,
+            round_chs_fn: Callable = round_channels,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            global_pool: str = 'avg',
     ):
+        """
+        Args:
+            block_args: Arguments for blocks of the network.
+            num_classes: Number of classes for classification head.
+            in_chans: Number of input image channels.
+            stem_size: Number of output channels of the initial stem convolution.
+            fix_stem: If True, don't scale stem by round_chs_fn.
+            num_features: Number of output channels of the conv head layer.
+            head_bias: If True, add a learnable bias to the conv head layer.
+            pad_type: Type of padding to use for convolution layers.
+            act_layer: Type of activation layer.
+            norm_layer: Type of normalization layer.
+            se_layer: Type of Squeeze-and-Excite layer.
+            se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
+            round_chs_fn: Callable to round number of filters based on depth multiplier.
+            drop_rate: Dropout rate.
+            drop_path_rate: Stochastic depth rate.
+            global_pool: Type of pooling to use for global pooling features of the FC head.
+ """ super(MobileNetV3, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d @@ -110,28 +131,28 @@ class MobileNetV3(nn.Module): return nn.Sequential(*layers) @torch.jit.ignore - def group_matcher(self, coarse=False): + def group_matcher(self, coarse: bool = False): return dict( stem=r'^conv_stem|bn1', blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)' ) @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): + def set_grad_checkpointing(self, enable: bool = True): self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - def forward_features(self, x): + def forward_features(self, x: Tensor) -> Tensor: x = self.conv_stem(x) x = self.bn1(x) if self.grad_checkpointing and not torch.jit.is_scripting(): @@ -140,7 +161,7 @@ class MobileNetV3(nn.Module): x = self.blocks(x) return x - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor: x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) @@ -151,7 +172,7 @@ class MobileNetV3(nn.Module): x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: x = self.forward_features(x) x = self.forward_head(x) return x @@ -166,22 +187,40 @@ class MobileNetV3Features(nn.Module): def __init__( self, - block_args, - out_indices=(0, 1, 2, 3, 4), - feature_location='bottleneck', - in_chans=3, - stem_size=16, - fix_stem=False, - output_stride=32, - pad_type='', - round_chs_fn=round_channels, - se_from_exp=True, - act_layer=None, - norm_layer=None, - se_layer=None, - drop_rate=0., - drop_path_rate=0., + block_args: BlockArgs, + out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + feature_location: str = 'bottleneck', + in_chans: int = 3, + stem_size: int = 16, + fix_stem: bool = False, + output_stride: int = 32, + pad_type: PadType = '', + round_chs_fn: Callable = round_channels, + se_from_exp: bool = True, + act_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + se_layer: Optional[LayerType] = None, + drop_rate: float = 0., + drop_path_rate: float = 0., ): + """ + Args: + block_args: Arguments for blocks of the network. + out_indices: Output from stages at indices. + feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'] + in_chans: Number of input image channels. + stem_size: Number of output channels of the initial stem convolution. + fix_stem: If True, don't scale stem by round_chs_fn. + output_stride: Output stride of the network. + pad_type: Type of padding to use for convolution layers. + round_chs_fn: Callable to round number of filters based on depth multiplier. + se_from_exp: If True, calculate SE channel reduction from expanded mid channels. + act_layer: Type of activation layer. + norm_layer: Type of normalization layer. + se_layer: Type of Squeeze-and-Excite layer. + drop_rate: Dropout rate. 
+            drop_path_rate: Stochastic depth rate.
+        """
         super(MobileNetV3Features, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
@@ -221,10 +260,10 @@ class MobileNetV3Features(nn.Module):
             self.feature_hooks = FeatureHooks(hooks, self.named_modules())
 
     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable=True):
+    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable
 
-    def forward(self, x) -> List[torch.Tensor]:
+    def forward(self, x: Tensor) -> List[Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
         x = self.act1(x)
@@ -246,7 +285,7 @@ class MobileNetV3Features(nn.Module):
         return list(out.values())
 
 
-def _create_mnv3(variant, pretrained=False, **kwargs):
+def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
     features_mode = ''
     model_cls = MobileNetV3
     kwargs_filter = None
@@ -272,7 +311,7 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
     return model
 
 
-def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
     """Creates a MobileNet-V3 model.
 
     Ref impl: ?
@@ -310,7 +349,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
     return model
 
 
-def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
     """Creates a MobileNet-V3 model.
 
     Ref impl: ?
@@ -407,7 +446,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
     return model
 
 
-def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
     """ FBNetV3
     Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining` - https://arxiv.org/abs/2006.02049
 
@@ -468,7 +507,7 @@ def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model
 
 
-def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
     """ LCNet
     Essentially a MobileNet-V3 crossed with a MobileNet-V1
 
@@ -506,7 +545,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model
 
 
-def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
     """ LCNet
     Essentially a MobileNet-V3 crossed with a MobileNet-V1
 
@@ -544,7 +583,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model
 
 
-def _cfg(url='', **kwargs):
+def _cfg(url: str = '', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
@@ -649,42 +688,42 @@ default_cfgs = generate_default_cfgs({
 
 
 @register_model
-def mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_small_050(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3: +def mobilenetv3_rw(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ if pretrained: # pretrained model trained with non-default BN epsilon @@ -694,7 +733,7 @@ def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -703,7 +742,7 @@ def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -712,7 +751,7 @@ def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_large_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -721,7 +760,7 @@ def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -730,7 +769,7 @@ def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -739,7 +778,7 @@ def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: +def tf_mobilenetv3_small_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3: """ MobileNet V3 """ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT kwargs['pad_type'] = 'same' @@ -748,56 +787,56 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3: @register_model -def fbnetv3_b(pretrained=False, **kwargs) -> 
+def fbnetv3_b(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ FBNetV3-B """
     model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def fbnetv3_d(pretrained=False, **kwargs) -> MobileNetV3:
+def fbnetv3_d(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ FBNetV3-D """
     model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def fbnetv3_g(pretrained=False, **kwargs) -> MobileNetV3:
+def fbnetv3_g(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ FBNetV3-G """
     model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def lcnet_035(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_035(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 0.35"""
     model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def lcnet_050(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 0.5"""
     model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def lcnet_075(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 0.75"""
     model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def lcnet_100(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 1.0"""
     model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, **kwargs)
     return model
 
 
 @register_model
-def lcnet_150(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 1.5"""
     model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
     return model
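As a usage sketch of what the new annotations accept: `decode_arch_def`, `MobileNetV3`, and the `_typing` aliases are the symbols this patch touches, but the two-stage `arch_def` below is illustrative only, not one of the real variant definitions.

```python
from functools import partial

import torch
import torch.nn as nn

from timm.models._efficientnet_builder import decode_arch_def
from timm.models._typing import BlockArgs, LayerType, PadType
from timm.models.mobilenetv3 import MobileNetV3

# Illustrative two-stage definition in timm's block-args string format;
# decoding it yields the List[List[Dict[str, Any]]] that BlockArgs names.
arch_def = [
    ['ds_r1_k3_s1_e1_c16'],  # stage 0: one depthwise-separable block
    ['ir_r2_k3_s2_e4_c24'],  # stage 1: two inverted-residual blocks, stride 2
]
block_args: BlockArgs = decode_arch_def(arch_def)

# LayerType deliberately admits a class, a functools.partial, or a module,
# so both of the styles below type-check for act_layer/norm_layer.
act_layer: LayerType = nn.Hardswish
norm_layer: LayerType = partial(nn.BatchNorm2d, eps=1e-3)
pad_type: PadType = 'same'

model = MobileNetV3(
    block_args,
    num_classes=10,
    act_layer=act_layer,
    norm_layer=norm_layer,
    pad_type=pad_type,
)
logits = model(torch.randn(1, 3, 224, 224))  # -> torch.Size([1, 10])
```

The breadth of `LayerType` (and the str-or-int-or-tuple union in `PadType`) mirrors what the constructors already accepted at runtime; the aliases document that contract rather than narrow it.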