improvement: add typehints and docs to timm/models/mobilenetv3.py

Authored by a-r-r-o-w on 2023-10-19 18:50:38 +05:30; committed by Ross Wightman
parent d023154bb5
commit c2fe0a2268
2 changed files with 116 additions and 67 deletions

timm/models/_typing.py  (new file, +10 lines)

@@ -0,0 +1,10 @@
+import functools
+import types
+from typing import Any, Dict, List, Tuple, Union
+
+import torch.nn
+
+BlockArgs = List[List[Dict[str, Any]]]
+LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
+PadType = Union[str, int, Tuple[int, int]]
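The aliases are plain typing constructs. A minimal sketch of how they tighten a signature (not part of the commit; the helper and its name are hypothetical):

from typing import Optional

import torch.nn as nn

from timm.models._typing import BlockArgs, LayerType, PadType

def summarize(block_args: BlockArgs, act_layer: Optional[LayerType] = None, pad_type: PadType = '') -> str:
    # BlockArgs is a list of stages, each stage a list of per-block kwargs dicts
    n_blocks = sum(len(stage) for stage in block_args)
    act = act_layer or nn.ReLU
    return f'{len(block_args)} stages / {n_blocks} blocks, act={act}, pad={pad_type!r}'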

timm/models/mobilenetv3.py  (+106 additions, −67 deletions)

@@ -7,11 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import List
+from typing import Callable, List, Optional, Tuple

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -23,6 +24,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
 from ._features import FeatureInfo, FeatureHooks
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
+from ._typing import BlockArgs, LayerType, PadType

 __all__ = ['MobileNetV3', 'MobileNetV3Features']
@@ -44,23 +46,42 @@ class MobileNetV3(nn.Module):
     def __init__(
             self,
-            block_args,
-            num_classes=1000,
-            in_chans=3,
-            stem_size=16,
-            fix_stem=False,
-            num_features=1280,
-            head_bias=True,
-            pad_type='',
-            act_layer=None,
-            norm_layer=None,
-            se_layer=None,
-            se_from_exp=True,
-            round_chs_fn=round_channels,
-            drop_rate=0.,
-            drop_path_rate=0.,
-            global_pool='avg',
+            block_args: BlockArgs,
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            stem_size: int = 16,
+            fix_stem: bool = False,
+            num_features: int = 1280,
+            head_bias: bool = True,
+            pad_type: PadType = '',
+            act_layer: Optional[LayerType] = None,
+            norm_layer: Optional[LayerType] = None,
+            se_layer: Optional[LayerType] = None,
+            se_from_exp: bool = True,
+            round_chs_fn: Callable = round_channels,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            global_pool: str = 'avg',
     ):
+        """
+        Args:
+            block_args: Arguments for blocks of the network.
+            num_classes: Number of classes for classification head.
+            in_chans: Number of input image channels.
+            stem_size: Number of output channels of the initial stem convolution.
+            fix_stem: If True, don't scale stem by round_chs_fn.
+            num_features: Number of output channels of the conv head layer.
+            head_bias: If True, add a learnable bias to the conv head layer.
+            pad_type: Type of padding to use for convolution layers.
+            act_layer: Type of activation layer.
+            norm_layer: Type of normalization layer.
+            se_layer: Type of Squeeze-and-Excite layer.
+            se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
+            round_chs_fn: Callable to round number of filters based on depth multiplier.
+            drop_rate: Dropout rate.
+            drop_path_rate: Stochastic depth rate.
+            global_pool: Type of pooling to use for global pooling features of the FC head.
+        """
         super(MobileNetV3, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
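Since the constructor arguments flow through the model entrypoints' **kwargs, the documented names above can be overridden from the timm factory. A hedged usage sketch (the values are arbitrary):

import timm

# num_classes, drop_rate and global_pool are the constructor args documented above
model = timm.create_model(
    'mobilenetv3_small_100',
    num_classes=10,      # replaces the default 1000-class head
    drop_rate=0.2,       # dropout before the classifier
    global_pool='max',   # max instead of average pooling
)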
@@ -110,28 +131,28 @@ class MobileNetV3(nn.Module):
         return nn.Sequential(*layers)

     @torch.jit.ignore
-    def group_matcher(self, coarse=False):
+    def group_matcher(self, coarse: bool = False):
         return dict(
             stem=r'^conv_stem|bn1',
             blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
         )

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable=True):
+    def set_grad_checkpointing(self, enable: bool = True):
         self.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self):
         return self.classifier

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.num_classes = num_classes
         # cannot meaningfully change pooling of efficient head after creation
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

-    def forward_features(self, x):
+    def forward_features(self, x: Tensor) -> Tensor:
         x = self.conv_stem(x)
         x = self.bn1(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -140,7 +161,7 @@ class MobileNetV3(nn.Module):
             x = self.blocks(x)
         return x

-    def forward_head(self, x, pre_logits: bool = False):
+    def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
         x = self.global_pool(x)
         x = self.conv_head(x)
         x = self.act2(x)
@@ -151,7 +172,7 @@ class MobileNetV3(nn.Module):
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         return self.classifier(x)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
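With the Tensor annotations in place, the split forward path reads cleanly end to end. A hedged sketch (shapes assume a 224x224 input to mobilenetv3_large_100):

import torch
import timm

model = timm.create_model('mobilenetv3_large_100').eval()
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)                    # unpooled features, e.g. (1, 960, 7, 7)
pooled = model.forward_head(feats, pre_logits=True)  # pooled pre-logits, e.g. (1, 1280)
logits = model(x)                                    # full forward: (1, 1000)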
@@ -166,22 +187,40 @@ class MobileNetV3Features(nn.Module):
     def __init__(
             self,
-            block_args,
-            out_indices=(0, 1, 2, 3, 4),
-            feature_location='bottleneck',
-            in_chans=3,
-            stem_size=16,
-            fix_stem=False,
-            output_stride=32,
-            pad_type='',
-            round_chs_fn=round_channels,
-            se_from_exp=True,
-            act_layer=None,
-            norm_layer=None,
-            se_layer=None,
-            drop_rate=0.,
-            drop_path_rate=0.,
+            block_args: BlockArgs,
+            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+            feature_location: str = 'bottleneck',
+            in_chans: int = 3,
+            stem_size: int = 16,
+            fix_stem: bool = False,
+            output_stride: int = 32,
+            pad_type: PadType = '',
+            round_chs_fn: Callable = round_channels,
+            se_from_exp: bool = True,
+            act_layer: Optional[LayerType] = None,
+            norm_layer: Optional[LayerType] = None,
+            se_layer: Optional[LayerType] = None,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
     ):
+        """
+        Args:
+            block_args: Arguments for blocks of the network.
+            out_indices: Output from stages at indices.
+            feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'].
+            in_chans: Number of input image channels.
+            stem_size: Number of output channels of the initial stem convolution.
+            fix_stem: If True, don't scale stem by round_chs_fn.
+            output_stride: Output stride of the network.
+            pad_type: Type of padding to use for convolution layers.
+            round_chs_fn: Callable to round number of filters based on depth multiplier.
+            se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
+            act_layer: Type of activation layer.
+            norm_layer: Type of normalization layer.
+            se_layer: Type of Squeeze-and-Excite layer.
+            drop_rate: Dropout rate.
+            drop_path_rate: Stochastic depth rate.
+        """
         super(MobileNetV3Features, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
@@ -221,10 +260,10 @@ class MobileNetV3Features(nn.Module):
         self.feature_hooks = FeatureHooks(hooks, self.named_modules())

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable=True):
+    def set_grad_checkpointing(self, enable: bool = True):
         self.grad_checkpointing = enable

-    def forward(self, x) -> List[torch.Tensor]:
+    def forward(self, x: Tensor) -> List[Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
         x = self.act1(x)
@@ -246,7 +285,7 @@ class MobileNetV3Features(nn.Module):
             return list(out.values())


-def _create_mnv3(variant, pretrained=False, **kwargs):
+def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
     features_mode = ''
     model_cls = MobileNetV3
     kwargs_filter = None
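MobileNetV3Features is normally reached through the factory's features_only path, which is what _create_mnv3 dispatches on. A hedged sketch (channel counts are the expected values for mobilenetv3_large_100):

import torch
import timm

fnet = timm.create_model('mobilenetv3_large_100', features_only=True, out_indices=(0, 1, 2, 3, 4))
outs = fnet(torch.randn(1, 3, 224, 224))  # List[Tensor], one per out index

print(fnet.feature_info.channels())   # e.g. [16, 24, 40, 112, 960]
print(fnet.feature_info.reduction())  # strides: [2, 4, 8, 16, 32]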
@@ -272,7 +311,7 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
     return model


-def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
     """Creates a MobileNet-V3 model.

     Ref impl: ?
@@ -310,7 +349,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
     return model


-def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
     """Creates a MobileNet-V3 model.

     Ref impl: ?
@@ -407,7 +446,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
     return model


-def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
     """ FBNetV3
     Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
         - https://arxiv.org/abs/2006.02049
@@ -468,7 +507,7 @@ def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model


-def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
     """ LCNet
     Essentially a MobileNet-V3 crossed with a MobileNet-V1
@@ -506,7 +545,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model


-def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
     """ LCNet
     Essentially a MobileNet-V3 crossed with a MobileNet-V1
@@ -544,7 +583,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model


-def _cfg(url='', **kwargs):
+def _cfg(url: str = '', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
@@ -649,42 +688,42 @@ default_cfgs = generate_default_cfgs({


 @register_model
-def mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def mobilenetv3_small_050(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3:
+def mobilenetv3_rw(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     if pretrained:
         # pretrained model trained with non-default BN epsilon
@@ -694,7 +733,7 @@ def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3:


 @register_model
-def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
+def tf_mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
@@ -703,7 +742,7 @@ def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:


 @register_model
-def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
+def tf_mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
@@ -712,7 +751,7 @@ def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:


 @register_model
-def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
+def tf_mobilenetv3_large_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
@@ -721,7 +760,7 @@ def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:


 @register_model
-def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
+def tf_mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
@@ -730,7 +769,7 @@ def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:


 @register_model
-def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
+def tf_mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
@@ -739,7 +778,7 @@ def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:


 @register_model
-def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
+def tf_mobilenetv3_small_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ MobileNet V3 """
     kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
     kwargs['pad_type'] = 'same'
@@ -748,56 +787,56 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:


 @register_model
-def fbnetv3_b(pretrained=False, **kwargs) -> MobileNetV3:
+def fbnetv3_b(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ FBNetV3-B """
     model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
     return model


 @register_model
-def fbnetv3_d(pretrained=False, **kwargs) -> MobileNetV3:
+def fbnetv3_d(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ FBNetV3-D """
     model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
     return model


 @register_model
-def fbnetv3_g(pretrained=False, **kwargs) -> MobileNetV3:
+def fbnetv3_g(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ FBNetV3-G """
     model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
     return model


 @register_model
-def lcnet_035(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_035(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 0.35"""
     model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def lcnet_050(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 0.5"""
     model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def lcnet_075(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 0.75"""
     model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def lcnet_100(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 1.0"""
     model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, **kwargs)
     return model


 @register_model
-def lcnet_150(pretrained=False, **kwargs) -> MobileNetV3:
+def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
     """ PP-LCNet 1.5"""
     model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
     return model
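For reference, each @register_model entrypoint above is addressable by name through the timm factory; a brief hedged sketch:

import timm

print(timm.list_models('lcnet_*'))  # the registered LCNet variants
model = timm.create_model('lcnet_100', pretrained=False, num_classes=100)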