improvement: add typehints and docs to timm/models/mobilenetv3.py

a-r-r-o-w 2023-10-19 18:50:38 +05:30 committed by Ross Wightman
parent d023154bb5
commit c2fe0a2268
2 changed files with 116 additions and 67 deletions

timm/models/_typing.py (new file)

@@ -0,0 +1,10 @@
import functools
import types
from typing import Any, Dict, List, Tuple, Union
import torch.nn
BlockArgs = List[List[Dict[str, Any]]]
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
PadType = Union[str, int, Tuple[int, int]]
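For illustration (not part of this commit), a minimal sketch of what the new aliases admit; BlockArgs matches the nested list-of-dicts structure that decode_arch_def produces, and conv3x3 / make_act below are hypothetical helpers:

import functools

import torch.nn as nn

from timm.models._typing import LayerType, PadType  # module added by this commit


def conv3x3(cin: int, cout: int, pad_type: PadType = 1) -> nn.Conv2d:
    # PadType covers int / (h, w) tuple / string; nn.Conv2d accepts all of these
    # directly (the 'same' / 'valid' strings need PyTorch >= 1.9).
    return nn.Conv2d(cin, cout, 3, padding=pad_type)


def make_act(act_layer: LayerType = nn.ReLU) -> nn.Module:
    # LayerType admits a class, functools.partial, plain function, string name,
    # or module; this sketch assumes something callable into an nn.Module.
    return act_layer(inplace=True)


act = make_act(functools.partial(nn.Hardswish))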

timm/models/mobilenetv3.py

@@ -7,11 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import List
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -23,6 +24,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
from ._typing import BlockArgs, LayerType, PadType
__all__ = ['MobileNetV3', 'MobileNetV3Features']
@@ -44,23 +46,42 @@ class MobileNetV3(nn.Module):
def __init__(
self,
block_args,
num_classes=1000,
in_chans=3,
stem_size=16,
fix_stem=False,
num_features=1280,
head_bias=True,
pad_type='',
act_layer=None,
norm_layer=None,
se_layer=None,
se_from_exp=True,
round_chs_fn=round_channels,
drop_rate=0.,
drop_path_rate=0.,
global_pool='avg',
block_args: BlockArgs,
num_classes: int = 1000,
in_chans: int = 3,
stem_size: int = 16,
fix_stem: bool = False,
num_features: int = 1280,
head_bias: bool = True,
pad_type: PadType = '',
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
se_from_exp: bool = True,
round_chs_fn: Callable = round_channels,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
global_pool: str = 'avg',
):
"""
Args:
block_args: Arguments for blocks of the network.
num_classes: Number of classes for classification head.
in_chans: Number of input image channels.
stem_size: Number of output channels of the initial stem convolution.
fix_stem: If True, don't scale stem by round_chs_fn.
num_features: Number of output channels of the conv head layer.
head_bias: If True, add a learnable bias to the conv head layer.
pad_type: Type of padding to use for convolution layers.
act_layer: Type of activation layer.
norm_layer: Type of normalization layer.
se_layer: Type of Squeeze-and-Excite layer.
se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
round_chs_fn: Callable to round number of filters based on depth multiplier.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
global_pool: Type of pooling to use for global pooling features of the FC head.
"""
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
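For context (not part of the diff), the typed constructor is normally reached through the factory functions further down; a brief usage sketch:

import torch
import timm

# create_model routes through _gen_mobilenet_v3 / _create_mnv3; the typed kwargs
# above (num_classes, drop_rate, global_pool, ...) can be overridden by name.
model = timm.create_model('mobilenetv3_large_100', pretrained=False, num_classes=10, drop_rate=0.2)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # shape (1, 10)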
@@ -110,28 +131,28 @@ class MobileNetV3(nn.Module):
return nn.Sequential(*layers)
@torch.jit.ignore
def group_matcher(self, coarse=False):
def group_matcher(self, coarse: bool = False):
return dict(
stem=r'^conv_stem|bn1',
blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.classifier
def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
def forward_features(self, x: Tensor) -> Tensor:
x = self.conv_stem(x)
x = self.bn1(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -140,7 +161,7 @@ class MobileNetV3(nn.Module):
x = self.blocks(x)
return x
def forward_head(self, x, pre_logits: bool = False):
def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
@@ -151,7 +172,7 @@ class MobileNetV3(nn.Module):
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
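A hedged sketch of the two-stage forward the annotations above describe (output shapes shown are for mobilenetv3_large_100 at 224x224):

import torch
import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)                 # unpooled feature map, (1, 960, 7, 7)
emb = model.forward_head(feats, pre_logits=True)  # pooled embedding before the classifier, (1, 1280)
logits = model(x)                                 # full forward, (1, 1000)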
@@ -166,22 +187,40 @@ class MobileNetV3Features(nn.Module):
def __init__(
self,
block_args,
out_indices=(0, 1, 2, 3, 4),
feature_location='bottleneck',
in_chans=3,
stem_size=16,
fix_stem=False,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
se_from_exp=True,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_rate=0.,
drop_path_rate=0.,
block_args: BlockArgs,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
feature_location: str = 'bottleneck',
in_chans: int = 3,
stem_size: int = 16,
fix_stem: bool = False,
output_stride: int = 32,
pad_type: PadType = '',
round_chs_fn: Callable = round_channels,
se_from_exp: bool = True,
act_layer: Optional[LayerType] = None,
norm_layer: Optional[LayerType] = None,
se_layer: Optional[LayerType] = None,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
):
"""
Args:
block_args: Arguments for blocks of the network.
out_indices: Output from stages at indices.
feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'].
in_chans: Number of input image channels.
stem_size: Number of output channels of the initial stem convolution.
fix_stem: If True, don't scale stem by round_chs_fn.
output_stride: Output stride of the network.
pad_type: Type of padding to use for convolution layers.
round_chs_fn: Callable to round number of filters based on depth multiplier.
se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
act_layer: Type of activation layer.
norm_layer: Type of normalization layer.
se_layer: Type of Squeeze-and-Excite layer.
drop_rate: Dropout rate.
drop_path_rate: Stochastic depth rate.
"""
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@@ -221,10 +260,10 @@ class MobileNetV3Features(nn.Module):
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
def forward(self, x) -> List[torch.Tensor]:
def forward(self, x: Tensor) -> List[Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
@@ -246,7 +285,7 @@ class MobileNetV3Features(nn.Module):
return list(out.values())
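Usage sketch: MobileNetV3Features is what timm builds when features_only is requested, so the typed forward above is what multi-scale consumers see:

import torch
import timm

fx = timm.create_model('mobilenetv3_large_100', features_only=True, out_indices=(0, 1, 2, 3, 4))
outs = fx(torch.randn(2, 3, 224, 224))   # List[Tensor], one map per selected stage
print([tuple(o.shape) for o in outs])    # reductions 2/4/8/16/32 for these indices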
def _create_mnv3(variant, pretrained=False, **kwargs):
def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
features_mode = ''
model_cls = MobileNetV3
kwargs_filter = None
@@ -272,7 +311,7 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
return model
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
"""Creates a MobileNet-V3 model.
Ref impl: ?
@@ -310,7 +349,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
return model
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
"""Creates a MobileNet-V3 model.
Ref impl: ?
@@ -407,7 +446,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
return model
def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
""" FBNetV3
Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
- https://arxiv.org/abs/2006.02049
@@ -468,7 +507,7 @@ def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
return model
def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
""" LCNet
Essentially a MobileNet-V3 crossed with a MobileNet-V1
@@ -506,7 +545,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
return model
def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
""" LCNet
Essentially a MobileNet-V3 crossed with a MobileNet-V1
@@ -544,7 +583,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
def _cfg(url: str = '', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
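The hunk cuts off mid-dict; the rest of timm's usual _cfg idiom (assumed here, not shown in the diff) merges per-model overrides into these defaults:

def _cfg_sketch(url: str = '', **kwargs):
    # assumption: mirrors the common timm pattern of `{**defaults, **kwargs}`
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        **kwargs,
    }

assert _cfg_sketch(crop_pct=0.95)['crop_pct'] == 0.95  # overrides win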
@@ -649,42 +688,42 @@ default_cfgs = generate_default_cfgs({
@register_model
def mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
def mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_small_050(pretrained=False, **kwargs) -> MobileNetV3:
def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
def mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
def mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model
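Each @register_model function becomes addressable by its function name through timm's registry; a brief usage sketch:

import timm

assert 'mobilenetv3_small_100' in timm.list_models('mobilenetv3*')
model = timm.create_model('mobilenetv3_small_100', pretrained=False)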
@register_model
def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3:
def mobilenetv3_rw(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
if pretrained:
# pretrained model trained with non-default BN epsilon
@@ -694,7 +733,7 @@ def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3:
@register_model
def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
def tf_mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
@@ -703,7 +742,7 @@ def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
@register_model
def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
def tf_mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
@@ -712,7 +751,7 @@ def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
@register_model
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
def tf_mobilenetv3_large_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
@@ -721,7 +760,7 @@ def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
@register_model
def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
def tf_mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
@@ -730,7 +769,7 @@ def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
@register_model
def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
def tf_mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
@@ -739,7 +778,7 @@ def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
@register_model
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
def tf_mobilenetv3_small_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
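The tf_* variants above all apply the same two presets (TF BatchNorm epsilon, TF 'SAME' padding) before delegating, so ported TensorFlow weights line up numerically. A sketch (treating BN_EPS_TF_DEFAULT as 1e-3 is an assumption; the code uses the named constant):

import timm

# Registered names apply the presets internally:
m = timm.create_model('tf_mobilenetv3_small_100', pretrained=False)

# The same presets can be passed to any variant by hand (values assumed, see above):
m2 = timm.create_model('mobilenetv3_small_100', bn_eps=1e-3, pad_type='same')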
@@ -748,56 +787,56 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
@register_model
def fbnetv3_b(pretrained=False, **kwargs) -> MobileNetV3:
def fbnetv3_b(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" FBNetV3-B """
model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
return model
@register_model
def fbnetv3_d(pretrained=False, **kwargs) -> MobileNetV3:
def fbnetv3_d(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" FBNetV3-D """
model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
return model
@register_model
def fbnetv3_g(pretrained=False, **kwargs) -> MobileNetV3:
def fbnetv3_g(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" FBNetV3-G """
model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
return model
@register_model
def lcnet_035(pretrained=False, **kwargs) -> MobileNetV3:
def lcnet_035(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" PP-LCNet 0.35"""
model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs)
return model
@register_model
def lcnet_050(pretrained=False, **kwargs) -> MobileNetV3:
def lcnet_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" PP-LCNet 0.5"""
model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs)
return model
@register_model
def lcnet_075(pretrained=False, **kwargs) -> MobileNetV3:
def lcnet_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" PP-LCNet 1.0"""
model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def lcnet_100(pretrained=False, **kwargs) -> MobileNetV3:
def lcnet_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" PP-LCNet 1.0"""
model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def lcnet_150(pretrained=False, **kwargs) -> MobileNetV3:
def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
""" PP-LCNet 1.5"""
model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
return model