Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
include typing suggestions by @rwightman
This commit is contained in:
parent 05b0aaca51
commit 5f14bdd564
@@ -52,4 +52,5 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .trace_utils import _assert, _float_to_int
+from .typing import LayerType, PadType
 from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
@@ -1,10 +1,9 @@
 import functools
 import types
-from typing import Any, Dict, List, Tuple, Union
+from typing import Tuple, Union
 
 import torch.nn
 
 
-BlockArgs = List[List[Dict[str, Any]]]
 LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
 PadType = Union[str, int, Tuple[int, int]]
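
For reference, a minimal sketch (not part of the diff) of what the two aliases accept; the variable names below are illustrative only:

import functools

import torch.nn as nn

from timm.layers import LayerType, PadType

# Any of these satisfies LayerType: a layer class, a string name, a partial, or a module instance.
act: LayerType = nn.GELU
act = 'relu'
act = functools.partial(nn.LeakyReLU, negative_slope=0.1)

# PadType covers string padding modes (e.g. '', 'same', 'valid'), a fixed int, or an explicit 2-tuple.
pad: PadType = 'same'
pad = 1
pad = (1, 1)
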
@@ -11,6 +11,7 @@ import math
 import re
 from copy import deepcopy
 from functools import partial
+from typing import Any, Dict, List
 
 import torch.nn as nn
 
@@ -34,6 +35,8 @@ BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
 BN_EPS_TF_DEFAULT = 1e-3
 _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
 
+BlockArgs = List[List[Dict[str, Any]]]
+
 
 def get_bn_args_tf():
     return _BN_ARGS_TF.copy()
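
As a rough illustration (not part of the diff), a value of this BlockArgs alias is a per-stage list of per-block kwargs dicts, in the shape decode_arch_def produces; the dict keys shown here are indicative only:

from typing import Any, Dict, List

BlockArgs = List[List[Dict[str, Any]]]

# Outer list: network stages; inner list: blocks within a stage; each dict: kwargs for one block.
block_args: BlockArgs = [
    [{'block_type': 'ds', 'out_chs': 16, 'stride': 1}],
    [{'block_type': 'ir', 'out_chs': 24, 'stride': 2, 'exp_ratio': 4.0},
     {'block_type': 'ir', 'out_chs': 24, 'stride': 1, 'exp_ratio': 4.0}],
]
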
@@ -16,15 +16,14 @@ from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
+from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
 from ._builder import build_model_with_cfg, pretrained_cfg_for_features
 from ._efficientnet_blocks import SqueezeExcite
-from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
+from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from ._features import FeatureInfo, FeatureHooks
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
-from ._typing import BlockArgs, LayerType, PadType
 
 __all__ = ['MobileNetV3', 'MobileNetV3Features']
 
@@ -9,7 +9,7 @@ Copyright 2019, Ross Wightman
 """
 import math
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -17,22 +17,21 @@ import torch.nn.functional as F
 from torch import Tensor
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
-    get_act_layer, get_norm_layer, create_classifier
+from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
+    get_attn, get_act_layer, get_norm_layer, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
-from ._typing import LayerType
 
 __all__ = ['ResNet', 'BasicBlock', 'Bottleneck']  # model_registry will add each entrypoint fn to this
 
 
-def get_padding(kernel_size: int, stride: int, dilation: int = 1):
+def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
     return padding
 
 
-def create_aa(aa_layer, channels, stride=2, enable=True):
+def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
     if not aa_layer or not enable:
         return nn.Identity()
     if issubclass(aa_layer, nn.AvgPool2d):
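
A small usage sketch of the newly annotated helper (assumptions: the non-AvgPool2d branch of create_aa instantiates the class as aa_layer(channels=..., stride=...), and BlurPool2d accepts those keywords):

import torch

from timm.layers import BlurPool2d
from timm.models.resnet import create_aa

# aa_layer is hinted as Type[nn.Module]: a class, not an instance.
aa = create_aa(BlurPool2d, channels=64, stride=2, enable=True)
x = torch.randn(1, 64, 56, 56)
print(aa(x).shape)                   # stride-2 anti-aliased downsample -> (1, 64, 28, 28)
print(create_aa(None, channels=64))  # disabled -> nn.Identity()
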
@@ -55,11 +54,11 @@ class BasicBlock(nn.Module):
             reduce_first: int = 1,
             dilation: int = 1,
             first_dilation: Optional[int] = None,
-            act_layer: nn.Module = nn.ReLU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
-            attn_layer: Optional[nn.Module] = None,
-            aa_layer: Optional[nn.Module] = None,
-            drop_block: Type[nn.Module] = None,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            attn_layer: Optional[Type[nn.Module]] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
+            drop_block: Optional[Type[nn.Module]] = None,
             drop_path: Optional[nn.Module] = None,
     ):
         """
@@ -153,11 +152,11 @@ class Bottleneck(nn.Module):
             reduce_first: int = 1,
             dilation: int = 1,
             first_dilation: Optional[int] = None,
-            act_layer: nn.Module = nn.ReLU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
-            attn_layer: Optional[nn.Module] = None,
-            aa_layer: Optional[nn.Module] = None,
-            drop_block: Type[nn.Module] = None,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            attn_layer: Optional[Type[nn.Module]] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
+            drop_block: Optional[Type[nn.Module]] = None,
             drop_path: Optional[nn.Module] = None,
     ):
         """
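
For context, a minimal sketch of what the tightened hints express: these arguments take layer classes, which the block instantiates internally, while drop_path remains an nn.Module instance. A construction like the following is assumed valid against the signature shown above (default stride, no downsample):

import torch
import torch.nn as nn

from timm.models.resnet import BasicBlock

block = BasicBlock(
    inplanes=64,
    planes=64,
    act_layer=nn.ReLU,          # class, not nn.ReLU()
    norm_layer=nn.BatchNorm2d,  # class, not nn.BatchNorm2d(64)
)
print(block(torch.randn(1, 64, 56, 56)).shape)  # (1, 64, 56, 56)
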
@@ -296,7 +295,7 @@ def drop_blocks(drop_prob: float = 0.):
 
 
 def make_blocks(
-        block_fn: nn.Module,
+        block_fn: Union[BasicBlock, Bottleneck],
         channels: List[int],
         block_repeats: List[int],
         inplanes: int,
@@ -395,7 +394,7 @@ class ResNet(nn.Module):
 
     def __init__(
             self,
-            block: nn.Module,
+            block: Union[BasicBlock, Bottleneck],
             layers: List[int],
             num_classes: int = 1000,
             in_chans: int = 3,
@@ -411,7 +410,7 @@ class ResNet(nn.Module):
             avg_down: bool = False,
             act_layer: LayerType = nn.ReLU,
             norm_layer: LayerType = nn.BatchNorm2d,
-            aa_layer: Optional[Type[nn.Module]] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
             drop_rate: float = 0.0,
             drop_path_rate: float = 0.,
             drop_block_rate: float = 0.,
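
Finally, a construction sketch against the updated ResNet signature (illustrative only; the ResNet-50 stage layout is assumed for the layers list, and act_layer is a LayerType, so a class such as nn.ReLU fits here):

import torch.nn as nn

from timm.models.resnet import ResNet, Bottleneck

# `block` is the block class itself, per the Union[BasicBlock, Bottleneck] hint above.
model = ResNet(
    block=Bottleneck,
    layers=[3, 4, 6, 3],   # ResNet-50 stage layout
    num_classes=1000,
    act_layer=nn.ReLU,
)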