mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
include typing suggestions by @rwightman
This commit is contained in:
parent
05b0aaca51
commit
5f14bdd564
@ -52,4 +52,5 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
|||||||
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
||||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||||
from .trace_utils import _assert, _float_to_int
|
from .trace_utils import _assert, _float_to_int
|
||||||
|
from .typing import LayerType, PadType
|
||||||
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import functools
|
import functools
|
||||||
import types
|
import types
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
import torch.nn
|
import torch.nn
|
||||||
|
|
||||||
|
|
||||||
BlockArgs = List[List[Dict[str, Any]]]
|
|
||||||
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
|
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
|
||||||
PadType = Union[str, int, Tuple[int, int]]
|
PadType = Union[str, int, Tuple[int, int]]
|
@ -11,6 +11,7 @@ import math
|
|||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -34,6 +35,8 @@ BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
|
|||||||
BN_EPS_TF_DEFAULT = 1e-3
|
BN_EPS_TF_DEFAULT = 1e-3
|
||||||
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
|
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
|
||||||
|
|
||||||
|
BlockArgs = List[List[Dict[str, Any]]]
|
||||||
|
|
||||||
|
|
||||||
def get_bn_args_tf():
|
def get_bn_args_tf():
|
||||||
return _BN_ARGS_TF.copy()
|
return _BN_ARGS_TF.copy()
|
||||||
|
@ -16,15 +16,14 @@ from torch import Tensor
|
|||||||
from torch.utils.checkpoint import checkpoint
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
|
from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
|
||||||
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
|
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
|
||||||
from ._efficientnet_blocks import SqueezeExcite
|
from ._efficientnet_blocks import SqueezeExcite
|
||||||
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
|
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
|
||||||
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||||
from ._features import FeatureInfo, FeatureHooks
|
from ._features import FeatureInfo, FeatureHooks
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
|
||||||
from ._typing import BlockArgs, LayerType, PadType
|
|
||||||
|
|
||||||
__all__ = ['MobileNetV3', 'MobileNetV3Features']
|
__all__ = ['MobileNetV3', 'MobileNetV3Features']
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ Copyright 2019, Ross Wightman
|
|||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -17,22 +17,21 @@ import torch.nn.functional as F
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
|
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
|
||||||
get_act_layer, get_norm_layer, create_classifier
|
get_attn, get_act_layer, get_norm_layer, create_classifier
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._manipulate import checkpoint_seq
|
from ._manipulate import checkpoint_seq
|
||||||
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
||||||
from ._typing import LayerType
|
|
||||||
|
|
||||||
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
|
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
|
||||||
|
|
||||||
|
|
||||||
def get_padding(kernel_size: int, stride: int, dilation: int = 1):
|
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
|
||||||
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
|
||||||
return padding
|
return padding
|
||||||
|
|
||||||
|
|
||||||
def create_aa(aa_layer, channels, stride=2, enable=True):
|
def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
|
||||||
if not aa_layer or not enable:
|
if not aa_layer or not enable:
|
||||||
return nn.Identity()
|
return nn.Identity()
|
||||||
if issubclass(aa_layer, nn.AvgPool2d):
|
if issubclass(aa_layer, nn.AvgPool2d):
|
||||||
@ -55,11 +54,11 @@ class BasicBlock(nn.Module):
|
|||||||
reduce_first: int = 1,
|
reduce_first: int = 1,
|
||||||
dilation: int = 1,
|
dilation: int = 1,
|
||||||
first_dilation: Optional[int] = None,
|
first_dilation: Optional[int] = None,
|
||||||
act_layer: nn.Module = nn.ReLU,
|
act_layer: Type[nn.Module] = nn.ReLU,
|
||||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
|
||||||
attn_layer: Optional[nn.Module] = None,
|
attn_layer: Optional[Type[nn.Module]] = None,
|
||||||
aa_layer: Optional[nn.Module] = None,
|
aa_layer: Optional[Type[nn.Module]] = None,
|
||||||
drop_block: Type[nn.Module] = None,
|
drop_block: Optional[Type[nn.Module]] = None,
|
||||||
drop_path: Optional[nn.Module] = None,
|
drop_path: Optional[nn.Module] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -153,11 +152,11 @@ class Bottleneck(nn.Module):
|
|||||||
reduce_first: int = 1,
|
reduce_first: int = 1,
|
||||||
dilation: int = 1,
|
dilation: int = 1,
|
||||||
first_dilation: Optional[int] = None,
|
first_dilation: Optional[int] = None,
|
||||||
act_layer: nn.Module = nn.ReLU,
|
act_layer: Type[nn.Module] = nn.ReLU,
|
||||||
norm_layer: nn.Module = nn.BatchNorm2d,
|
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
|
||||||
attn_layer: Optional[nn.Module] = None,
|
attn_layer: Optional[Type[nn.Module]] = None,
|
||||||
aa_layer: Optional[nn.Module] = None,
|
aa_layer: Optional[Type[nn.Module]] = None,
|
||||||
drop_block: Type[nn.Module] = None,
|
drop_block: Optional[Type[nn.Module]] = None,
|
||||||
drop_path: Optional[nn.Module] = None,
|
drop_path: Optional[nn.Module] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -296,7 +295,7 @@ def drop_blocks(drop_prob: float = 0.):
|
|||||||
|
|
||||||
|
|
||||||
def make_blocks(
|
def make_blocks(
|
||||||
block_fn: nn.Module,
|
block_fn: Union[BasicBlock, Bottleneck],
|
||||||
channels: List[int],
|
channels: List[int],
|
||||||
block_repeats: List[int],
|
block_repeats: List[int],
|
||||||
inplanes: int,
|
inplanes: int,
|
||||||
@ -395,7 +394,7 @@ class ResNet(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
block: nn.Module,
|
block: Union[BasicBlock, Bottleneck],
|
||||||
layers: List[int],
|
layers: List[int],
|
||||||
num_classes: int = 1000,
|
num_classes: int = 1000,
|
||||||
in_chans: int = 3,
|
in_chans: int = 3,
|
||||||
@ -411,7 +410,7 @@ class ResNet(nn.Module):
|
|||||||
avg_down: bool = False,
|
avg_down: bool = False,
|
||||||
act_layer: LayerType = nn.ReLU,
|
act_layer: LayerType = nn.ReLU,
|
||||||
norm_layer: LayerType = nn.BatchNorm2d,
|
norm_layer: LayerType = nn.BatchNorm2d,
|
||||||
aa_layer: Optional[nn.Module] = None,
|
aa_layer: Optional[Type[nn.Module]] = None,
|
||||||
drop_rate: float = 0.0,
|
drop_rate: float = 0.0,
|
||||||
drop_path_rate: float = 0.,
|
drop_path_rate: float = 0.,
|
||||||
drop_block_rate: float = 0.,
|
drop_block_rate: float = 0.,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user