mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
improvement: add typehints and docs to timm/models/mobilenetv3.py
This commit is contained in:
parent
d023154bb5
commit
c2fe0a2268
10
timm/models/_typing.py
Normal file
10
timm/models/_typing.py
Normal file
@ -0,0 +1,10 @@
|
||||
import functools
|
||||
import types
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import torch.nn
|
||||
|
||||
|
||||
BlockArgs = List[List[Dict[str, Any]]]
|
||||
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
|
||||
PadType = Union[str, int, Tuple[int, int]]
|
@ -7,11 +7,12 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
||||
Hacked together by / Copyright 2019, Ross Wightman
|
||||
"""
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
@ -23,6 +24,7 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
|
||||
from ._features import FeatureInfo, FeatureHooks
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||
from ._typing import BlockArgs, LayerType, PadType
|
||||
|
||||
__all__ = ['MobileNetV3', 'MobileNetV3Features']
|
||||
|
||||
@ -44,23 +46,42 @@ class MobileNetV3(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_args,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
stem_size=16,
|
||||
fix_stem=False,
|
||||
num_features=1280,
|
||||
head_bias=True,
|
||||
pad_type='',
|
||||
act_layer=None,
|
||||
norm_layer=None,
|
||||
se_layer=None,
|
||||
se_from_exp=True,
|
||||
round_chs_fn=round_channels,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
global_pool='avg',
|
||||
block_args: BlockArgs,
|
||||
num_classes: int = 1000,
|
||||
in_chans: int = 3,
|
||||
stem_size: int = 16,
|
||||
fix_stem: bool = False,
|
||||
num_features: int = 1280,
|
||||
head_bias: bool = True,
|
||||
pad_type: PadType = '',
|
||||
act_layer: Optional[LayerType] = None,
|
||||
norm_layer: Optional[LayerType] = None,
|
||||
se_layer: Optional[LayerType] = None,
|
||||
se_from_exp: bool = True,
|
||||
round_chs_fn: Callable = round_channels,
|
||||
drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
global_pool: str = 'avg',
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
block_args: Arguments for blocks of the network.
|
||||
num_classes: Number of classes for classification head.
|
||||
in_chans: Number of input image channels.
|
||||
stem_size: Number of output channels of the initial stem convolution.
|
||||
fix_stem: If True, don't scale stem by round_chs_fn.
|
||||
num_features: Number of output channels of the conv head layer.
|
||||
head_bias: If True, add a learnable bias to the conv head layer.
|
||||
pad_type: Type of padding to use for convolution layers.
|
||||
act_layer: Type of activation layer.
|
||||
norm_layer: Type of normalization layer.
|
||||
se_layer: Type of Squeeze-and-Excite layer.
|
||||
se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
|
||||
round_chs_fn: Callable to round number of filters based on depth multiplier.
|
||||
drop_rate: Dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
global_pool: Type of pooling to use for global pooling features of the FC head.
|
||||
"""
|
||||
super(MobileNetV3, self).__init__()
|
||||
act_layer = act_layer or nn.ReLU
|
||||
norm_layer = norm_layer or nn.BatchNorm2d
|
||||
@ -110,28 +131,28 @@ class MobileNetV3(nn.Module):
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
@torch.jit.ignore
|
||||
def group_matcher(self, coarse=False):
|
||||
def group_matcher(self, coarse: bool = False):
|
||||
return dict(
|
||||
stem=r'^conv_stem|bn1',
|
||||
blocks=r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)'
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.classifier
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
|
||||
self.num_classes = num_classes
|
||||
# cannot meaningfully change pooling of efficient head after creation
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
|
||||
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
def forward_features(self, x: Tensor) -> Tensor:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
@ -140,7 +161,7 @@ class MobileNetV3(nn.Module):
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
|
||||
x = self.global_pool(x)
|
||||
x = self.conv_head(x)
|
||||
x = self.act2(x)
|
||||
@ -151,7 +172,7 @@ class MobileNetV3(nn.Module):
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
return self.classifier(x)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.forward_features(x)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
@ -166,22 +187,40 @@ class MobileNetV3Features(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_args,
|
||||
out_indices=(0, 1, 2, 3, 4),
|
||||
feature_location='bottleneck',
|
||||
in_chans=3,
|
||||
stem_size=16,
|
||||
fix_stem=False,
|
||||
output_stride=32,
|
||||
pad_type='',
|
||||
round_chs_fn=round_channels,
|
||||
se_from_exp=True,
|
||||
act_layer=None,
|
||||
norm_layer=None,
|
||||
se_layer=None,
|
||||
drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
block_args: BlockArgs,
|
||||
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
|
||||
feature_location: str = 'bottleneck',
|
||||
in_chans: int = 3,
|
||||
stem_size: int = 16,
|
||||
fix_stem: bool = False,
|
||||
output_stride: int = 32,
|
||||
pad_type: PadType = '',
|
||||
round_chs_fn: Callable = round_channels,
|
||||
se_from_exp: bool = True,
|
||||
act_layer: Optional[LayerType] = None,
|
||||
norm_layer: Optional[LayerType] = None,
|
||||
se_layer: Optional[LayerType] = None,
|
||||
drop_rate: float = 0.,
|
||||
drop_path_rate: float = 0.,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
block_args: Arguments for blocks of the network.
|
||||
out_indices: Output from stages at indices.
|
||||
feature_location: Location of feature before/after each block, must be in ['bottleneck', 'expansion'].
|
||||
in_chans: Number of input image channels.
|
||||
stem_size: Number of output channels of the initial stem convolution.
|
||||
fix_stem: If True, don't scale stem by round_chs_fn.
|
||||
output_stride: Output stride of the network.
|
||||
pad_type: Type of padding to use for convolution layers.
|
||||
round_chs_fn: Callable to round number of filters based on depth multiplier.
|
||||
se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
|
||||
act_layer: Type of activation layer.
|
||||
norm_layer: Type of normalization layer.
|
||||
se_layer: Type of Squeeze-and-Excite layer.
|
||||
drop_rate: Dropout rate.
|
||||
drop_path_rate: Stochastic depth rate.
|
||||
"""
|
||||
super(MobileNetV3Features, self).__init__()
|
||||
act_layer = act_layer or nn.ReLU
|
||||
norm_layer = norm_layer or nn.BatchNorm2d
|
||||
@ -221,10 +260,10 @@ class MobileNetV3Features(nn.Module):
|
||||
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
def set_grad_checkpointing(self, enable: bool = True):
|
||||
self.grad_checkpointing = enable
|
||||
|
||||
def forward(self, x) -> List[torch.Tensor]:
|
||||
def forward(self, x: Tensor) -> List[Tensor]:
|
||||
x = self.conv_stem(x)
|
||||
x = self.bn1(x)
|
||||
x = self.act1(x)
|
||||
@ -246,7 +285,7 @@ class MobileNetV3Features(nn.Module):
|
||||
return list(out.values())
|
||||
|
||||
|
||||
def _create_mnv3(variant, pretrained=False, **kwargs):
|
||||
def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
features_mode = ''
|
||||
model_cls = MobileNetV3
|
||||
kwargs_filter = None
|
||||
@ -272,7 +311,7 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
"""Creates a MobileNet-V3 model.
|
||||
|
||||
Ref impl: ?
|
||||
@ -310,7 +349,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
|
||||
return model
|
||||
|
||||
|
||||
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
"""Creates a MobileNet-V3 model.
|
||||
|
||||
Ref impl: ?
|
||||
@ -407,7 +446,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
|
||||
return model
|
||||
|
||||
|
||||
def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
def _gen_fbnetv3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
|
||||
""" FBNetV3
|
||||
Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
|
||||
- https://arxiv.org/abs/2006.02049
|
||||
@ -468,7 +507,7 @@ def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
|
||||
""" LCNet
|
||||
Essentially a MobileNet-V3 crossed with a MobileNet-V1
|
||||
|
||||
@ -506,7 +545,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs):
|
||||
""" LCNet
|
||||
Essentially a MobileNet-V3 crossed with a MobileNet-V1
|
||||
|
||||
@ -544,7 +583,7 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
def _cfg(url: str = '', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
@ -649,42 +688,42 @@ default_cfgs = generate_default_cfgs({
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_small_050(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_050', 0.50, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def mobilenetv3_rw(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
if pretrained:
|
||||
# pretrained model trained with non-default BN epsilon
|
||||
@ -694,7 +733,7 @@ def mobilenetv3_rw(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def tf_mobilenetv3_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
@ -703,7 +742,7 @@ def tf_mobilenetv3_large_075(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def tf_mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
@ -712,7 +751,7 @@ def tf_mobilenetv3_large_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def tf_mobilenetv3_large_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
@ -721,7 +760,7 @@ def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def tf_mobilenetv3_small_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
@ -730,7 +769,7 @@ def tf_mobilenetv3_small_075(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def tf_mobilenetv3_small_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
@ -739,7 +778,7 @@ def tf_mobilenetv3_small_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def tf_mobilenetv3_small_minimal_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" MobileNet V3 """
|
||||
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
|
||||
kwargs['pad_type'] = 'same'
|
||||
@ -748,56 +787,56 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
|
||||
|
||||
@register_model
|
||||
def fbnetv3_b(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def fbnetv3_b(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" FBNetV3-B """
|
||||
model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def fbnetv3_d(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def fbnetv3_d(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" FBNetV3-D """
|
||||
model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def fbnetv3_g(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def fbnetv3_g(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" FBNetV3-G """
|
||||
model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def lcnet_035(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def lcnet_035(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" PP-LCNet 0.35"""
|
||||
model = _gen_lcnet('lcnet_035', 0.35, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def lcnet_050(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def lcnet_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" PP-LCNet 0.5"""
|
||||
model = _gen_lcnet('lcnet_050', 0.5, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def lcnet_075(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def lcnet_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" PP-LCNet 0.75"""
|
||||
model = _gen_lcnet('lcnet_075', 0.75, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def lcnet_100(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def lcnet_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" PP-LCNet 1.0"""
|
||||
model = _gen_lcnet('lcnet_100', 1.0, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def lcnet_150(pretrained=False, **kwargs) -> MobileNetV3:
|
||||
def lcnet_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
|
||||
""" PP-LCNet 1.5"""
|
||||
model = _gen_lcnet('lcnet_150', 1.5, pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
Loading…
x
Reference in New Issue
Block a user