Mirror of https://github.com/huggingface/pytorch-image-models.git
Add anti-aliasing support to mobilenetv3 and efficientnet family models. Update MobileNetV4 model defs, resolutions. Fix #599

* create_aa helper function centralized for all timm uses (resnet, conv_bn_act helper)
* allow BlurPool w/o pre-defined channels (filter expanded to match input at runtime)
* mobilenetv4 UIB block using ConvNormAct layers for improved clarity, esp. with AA added
* improve more mobilenetv3 and efficientnet related type annotations
This commit is contained in: parent 4ff7c25766, commit 5fa6efa158
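With this change, anti-aliased downsampling can be switched on per model through the new aa_layer argument. A minimal usage sketch (assuming the standard timm kwargs passthrough into the model constructors changed below; not part of the diff):

import timm

# 'avg', 'blur', and 'blurpc' are the string aliases handled by the new create_aa helper.
model = timm.create_model('mobilenetv4_conv_medium', pretrained=False, aa_layer='avg')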
@@ -4,7 +4,7 @@ from .adaptive_avgmax_pool import \
from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .blur_pool import BlurPool2d, create_aa
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
@@ -5,12 +5,16 @@ BlurPool layer inspired by

Hacked together by Chris Ha and Ross Wightman
"""
from functools import partial
from typing import Optional, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .padding import get_padding
from .typing import LayerType


class BlurPool2d(nn.Module):
@@ -26,17 +30,62 @@ class BlurPool2d(nn.Module):
    Returns:
        torch.Tensor: the transformed tensor.
    """
    def __init__(self, channels, filt_size=3, stride=2) -> None:
    def __init__(
            self,
            channels: Optional[int] = None,
            filt_size: int = 3,
            stride: int = 2,
            pad_mode: str = 'reflect',
    ) -> None:
        super(BlurPool2d, self).__init__()
        assert filt_size > 1
        self.channels = channels
        self.filt_size = filt_size
        self.stride = stride
        self.pad_mode = pad_mode
        self.padding = [get_padding(filt_size, stride, dilation=1)] * 4

        coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
        blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :]
        if channels is not None:
            blur_filter = blur_filter.repeat(self.channels, 1, 1, 1)
        self.register_buffer('filt', blur_filter, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.pad(x, self.padding, 'reflect')
        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
        x = F.pad(x, self.padding, mode=self.pad_mode)
        if self.channels is None:
            channels = x.shape[1]
            weight = self.filt.expand(channels, 1, self.filt_size, self.filt_size)
        else:
            channels = self.channels
            weight = self.filt
        return F.conv2d(x, weight, stride=self.stride, groups=channels)


def create_aa(
        aa_layer: LayerType,
        channels: Optional[int] = None,
        stride: int = 2,
        enable: bool = True,
        noop: Optional[Type[nn.Module]] = nn.Identity
) -> nn.Module:
    """ Anti-aliasing """
    if not aa_layer or not enable:
        return noop() if noop is not None else None

    if isinstance(aa_layer, str):
        aa_layer = aa_layer.lower().replace('_', '').replace('-', '')
        if aa_layer == 'avg' or aa_layer == 'avgpool':
            aa_layer = nn.AvgPool2d
        elif aa_layer == 'blur' or aa_layer == 'blurpool':
            aa_layer = BlurPool2d
        elif aa_layer == 'blurpc':
            aa_layer = partial(BlurPool2d, pad_mode='constant')

        else:
            assert False, f"Unknown anti-aliasing layer ({aa_layer})."

    try:
        return aa_layer(channels=channels, stride=stride)
    except TypeError as e:
        return aa_layer(stride)
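A short sketch of how the centralized helper above behaves, based only on the code shown (string aliases resolve to nn.AvgPool2d, BlurPool2d, or BlurPool2d with constant padding):

import torch
from timm.layers import BlurPool2d, create_aa

aa = create_aa('blur', channels=64, stride=2)      # BlurPool2d with a fixed 64-channel filter
aa_dyn = BlurPool2d(channels=None, stride=2)       # filter expanded to the input's channels at runtime
x = torch.randn(1, 64, 56, 56)
assert aa(x).shape == aa_dyn(x).shape == (1, 64, 28, 28)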
@@ -2,9 +2,12 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
import functools
from typing import Any, Dict, Optional, Type

from torch import nn as nn

from .typing import LayerType, PadType
from .blur_pool import create_aa
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer

@@ -12,28 +15,38 @@ from .create_norm_act import get_norm_act_layer
class ConvNormAct(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            drop_layer=None,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 1,
            stride: int = 1,
            padding: PadType = '',
            dilation: int = 1,
            groups: int = 1,
            bias: bool = False,
            apply_act: bool = True,
            norm_layer: LayerType = nn.BatchNorm2d,
            act_layer: LayerType = nn.ReLU,
            drop_layer: Optional[Type[nn.Module]] = None,
            conv_kwargs: Optional[Dict[str, Any]] = None,
            norm_kwargs: Optional[Dict[str, Any]] = None,
            act_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(ConvNormAct, self).__init__()
        conv_kwargs = conv_kwargs or {}
        norm_kwargs = norm_kwargs or {}
        act_kwargs = act_kwargs or {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **conv_kwargs,
        )

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
@@ -64,54 +77,53 @@ class ConvNormAct(nn.Module):
ConvBnAct = ConvNormAct


def create_aa(aa_layer, channels, stride=2, enable=True):
    if not aa_layer or not enable:
        return nn.Identity()
    if isinstance(aa_layer, functools.partial):
        if issubclass(aa_layer.func, nn.AvgPool2d):
            return aa_layer()
        else:
            return aa_layer(channels)
    elif issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)
    else:
        return aa_layer(channels=channels, stride=stride)


class ConvNormActAa(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding='',
            dilation=1,
            groups=1,
            bias=False,
            apply_act=True,
            norm_layer=nn.BatchNorm2d,
            norm_kwargs=None,
            act_layer=nn.ReLU,
            act_kwargs=None,
            aa_layer=None,
            drop_layer=None,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 1,
            stride: int = 1,
            padding: PadType = '',
            dilation: int = 1,
            groups: int = 1,
            bias: bool = False,
            apply_act: bool = True,
            norm_layer: LayerType = nn.BatchNorm2d,
            act_layer: LayerType = nn.ReLU,
            aa_layer: Optional[LayerType] = None,
            drop_layer: Optional[Type[nn.Module]] = None,
            conv_kwargs: Optional[Dict[str, Any]] = None,
            norm_kwargs: Optional[Dict[str, Any]] = None,
            act_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super(ConvNormActAa, self).__init__()
        use_aa = aa_layer is not None and stride == 2
        conv_kwargs = conv_kwargs or {}
        norm_kwargs = norm_kwargs or {}
        act_kwargs = act_kwargs or {}

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)
            in_channels, out_channels, kernel_size,
            stride=1 if use_aa else stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **conv_kwargs,
        )

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        if drop_layer:
            norm_kwargs['drop_layer'] = drop_layer
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, act_kwargs=act_kwargs, **norm_kwargs)
        self.bn = norm_act_layer(
            out_channels,
            apply_act=apply_act,
            act_kwargs=act_kwargs,
            **norm_kwargs,
        )
        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

    @property
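Usage sketch for the updated ConvNormActAa (assumes the layer as defined above, with aa_layer now routed through the shared create_aa): when stride == 2 the convolution runs at stride 1 and the anti-aliasing layer performs the downsample, so the output shape is unchanged.

import torch
from timm.layers import ConvNormActAa

block = ConvNormActAa(32, 64, kernel_size=3, stride=2, aa_layer='blur')
y = block(torch.randn(1, 32, 64, 64))   # expected shape: (1, 64, 32, 32)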
@@ -2,22 +2,24 @@

Hacked together by / Copyright 2019, Ross Wightman
"""
from typing import Optional
from typing import Callable, Dict, Optional, Type

import torch
import torch.nn as nn
from torch.nn import functional as F

from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, to_2tuple,\
    get_norm_act_layer, MultiQueryAttention2d, MultiQueryAttentionV2, Attention2d
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
    ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d

__all__ = [
    'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
    'UniversalInvertedResidual', 'MobileAttention'
]

ModuleType = Type[nn.Module]


def num_groups(group_size, channels):

def num_groups(group_size: Optional[int], channels: int):
    if not group_size:  # 0 or None
        return 1  # normal conv with 1 group
    else:
@@ -40,13 +42,13 @@ class SqueezeExcite(nn.Module):

    def __init__(
            self,
            in_chs,
            rd_ratio=0.25,
            rd_channels=None,
            act_layer=nn.ReLU,
            gate_layer=nn.Sigmoid,
            force_act_layer=None,
            rd_round_fn=None,
            in_chs: int,
            rd_ratio: float = 0.25,
            rd_channels: Optional[int] = None,
            act_layer: LayerType = nn.ReLU,
            gate_layer: LayerType = nn.Sigmoid,
            force_act_layer: Optional[LayerType] = None,
            rd_round_fn: Optional[Callable] = None,
    ):
        super(SqueezeExcite, self).__init__()
        if rd_channels is None:
@@ -71,27 +73,31 @@ class ConvBnAct(nn.Module):
    """
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size,
            stride=1,
            dilation=1,
            group_size=0,
            pad_type='',
            skip=False,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            drop_path_rate=0.,
            in_chs: int,
            out_chs: int,
            kernel_size: int,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 0,
            pad_type: str = '',
            skip: bool = False,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            drop_path_rate: float = 0.,
    ):
        super(ConvBnAct, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        groups = num_groups(group_size, in_chs)
        self.has_skip = skip and stride == 1 and in_chs == out_chs
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        self.conv = create_conv2d(
            in_chs, out_chs, kernel_size,
            stride=stride, dilation=dilation, groups=groups, padding=pad_type)
            stride=1 if use_aa else stride,
            dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(out_chs, inplace=True)
        self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def feature_info(self, location):
@@ -104,6 +110,7 @@ class ConvBnAct(nn.Module):
        shortcut = x
        x = self.conv(x)
        x = self.bn1(x)
        x = self.aa(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x
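The same pattern recurs in ConvBnAct and the blocks below: use_aa is true only for stride > 1, the main conv then runs at stride 1, and self.aa (built by create_aa) supplies the stride, so the block's overall downsampling is identical with or without AA. A minimal sketch (hypothetical values, not part of the diff):

import torch
from timm.models._efficientnet_blocks import ConvBnAct

plain = ConvBnAct(32, 64, kernel_size=3, stride=2)
aa = ConvBnAct(32, 64, kernel_size=3, stride=2, aa_layer='avg')
x = torch.randn(1, 32, 64, 64)
assert plain(x).shape == aa(x).shape == (1, 64, 32, 32)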
@@ -116,37 +123,38 @@ class DepthwiseSeparableConv(nn.Module):
    """
    def __init__(
            self,
            in_chs,
            out_chs,
            dw_kernel_size=3,
            stride=1,
            dilation=1,
            group_size=1,
            pad_type='',
            noskip=False,
            pw_kernel_size=1,
            pw_act=False,
            s2d=0,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            se_layer=None,
            drop_path_rate=0.,
            in_chs: int,
            out_chs: int,
            dw_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            pw_kernel_size: int = 1,
            pw_act: bool = False,
            s2d: int = 0,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            drop_path_rate: float = 0.,
    ):
        super(DepthwiseSeparableConv, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        # Space to depth
        if s2d == 1:
            sd_chs = int(in_chs * 4)
            #sd_pad_type = 'sam'
            self.conv_s2d = create_conv2d(
                in_chs, sd_chs, kernel_size=2, stride=2, padding=0)  #'same')
            self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
            self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
            dw_kernel_size = (dw_kernel_size + 1) // 2
            dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
            in_chs = sd_chs
            use_aa = False  # disable AA
        else:
            self.conv_s2d = None
            self.bn_s2d = None
@@ -156,8 +164,10 @@ class DepthwiseSeparableConv(nn.Module):

        self.conv_dw = create_conv2d(
            in_chs, in_chs, dw_kernel_size,
            stride=stride, dilation=dilation, padding=dw_pad_type, groups=groups)
            stride=1 if use_aa else stride,
            dilation=dilation, padding=dw_pad_type, groups=groups)
        self.bn1 = norm_act_layer(in_chs, inplace=True)
        self.aa = create_aa(aa_layer, channels=out_chs, stride=stride, enable=use_aa)

        # Squeeze-and-excitation
        self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
@@ -174,13 +184,12 @@ class DepthwiseSeparableConv(nn.Module):

    def forward(self, x):
        shortcut = x
        #print('ii', x.shape)  # FIXME debug s2d
        if self.conv_s2d is not None:
            x = self.conv_s2d(x)
            x = self.bn_s2d(x)
            #print('id', x.shape)  # FIXME debug s2d
        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.aa(x)
        x = self.se(x)
        x = self.conv_pw(x)
        x = self.bn2(x)
@@ -201,37 +210,40 @@ class InvertedResidual(nn.Module):

    def __init__(
            self,
            in_chs,
            out_chs,
            dw_kernel_size=3,
            stride=1,
            dilation=1,
            group_size=1,
            pad_type='',
            noskip=False,
            exp_ratio=1.0,
            exp_kernel_size=1,
            pw_kernel_size=1,
            s2d=0,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            se_layer=None,
            conv_kwargs=None,
            drop_path_rate=0.,
            in_chs: int,
            out_chs: int,
            dw_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            exp_ratio: float = 1.0,
            exp_kernel_size: int = 1,
            pw_kernel_size: int = 1,
            s2d: int = 0,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            conv_kwargs: Optional[Dict] = None,
            drop_path_rate: float = 0.,
    ):
        super(InvertedResidual, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        conv_kwargs = conv_kwargs or {}
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        # Space to depth
        if s2d == 1:
            sd_chs = int(in_chs * 4)
            self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding=pad_type)
            self.conv_s2d = create_conv2d(in_chs, sd_chs, kernel_size=2, stride=2, padding='same')
            self.bn_s2d = norm_act_layer(sd_chs, sd_chs)
            dw_kernel_size = (dw_kernel_size + 1) // 2
            dw_pad_type = 'same' if dw_kernel_size == 2 else pad_type
            in_chs = sd_chs
            use_aa = False  # disable AA
        else:
            self.conv_s2d = None
            self.bn_s2d = None
@@ -247,8 +259,10 @@ class InvertedResidual(nn.Module):
        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size,
            stride=stride, dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs)
            stride=1 if use_aa else stride,
            dilation=dilation, groups=groups, padding=dw_pad_type, **conv_kwargs)
        self.bn2 = norm_act_layer(mid_chs, inplace=True)
        self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
@@ -273,6 +287,7 @@ class InvertedResidual(nn.Module):
        x = self.bn1(x)
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.aa(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn3(x)
@@ -282,7 +297,7 @@ class InvertedResidual(nn.Module):


class LayerScale2d(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))
@@ -293,7 +308,7 @@ class LayerScale2d(nn.Module):


class UniversalInvertedResidual(nn.Module):
    """ Universal Inverted Residual Block
    """ Universal Inverted Residual Block (aka Universal Inverted Bottleneck, UIB)

    For MobileNetV4 - https://arxiv.org/abs/, referenced from
    https://github.com/tensorflow/models/blob/d93c7e932de27522b2fa3b115f58d06d6f640537/official/vision/modeling/layers/nn_blocks.py#L778
@@ -301,89 +316,109 @@ class UniversalInvertedResidual(nn.Module):

    def __init__(
            self,
            in_chs,
            out_chs,
            in_chs: int,
            out_chs: int,
            dw_kernel_size_start: int = 0,
            dw_kernel_size_mid: int = 3,
            dw_kernel_size_end: int = 0,
            stride=1,
            dilation=1,
            group_size=1,
            pad_type='',
            noskip=False,
            exp_ratio=1.0,
            act_layer=nn.ReLU,
            dw_act_layer=None,
            norm_layer=nn.BatchNorm2d,
            se_layer=None,
            conv_kwargs=None,
            drop_path_rate=0.,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            exp_ratio: float = 1.0,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            conv_kwargs: Optional[Dict] = None,
            drop_path_rate: float = 0.,
            layer_scale_init_value: Optional[float] = 1e-5,
    ):
        super(UniversalInvertedResidual, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
        dw_act_layer = dw_act_layer or act_layer
        dw_norm_act_layer = get_norm_act_layer(norm_layer, dw_act_layer)
        conv_kwargs = conv_kwargs or {}
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
        if stride > 1:
            assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end

        # FIXME dilation isn't right w/ extra ks > 1 convs
        if dw_kernel_size_start:
            self.conv_dw_start = create_conv2d(
            dw_start_stride = stride if not dw_kernel_size_mid else 1
            dw_start_groups = num_groups(group_size, in_chs)
            self.dw_start = ConvNormActAa(
                in_chs, in_chs, dw_kernel_size_start,
                stride=dw_start_stride,
                dilation=dilation,  # FIXME
                depthwise=True,
                groups=dw_start_groups,
                padding=pad_type,
                apply_act=False,
                act_layer=act_layer,
                norm_layer=norm_layer,
                aa_layer=aa_layer,
                **conv_kwargs,
            )
            self.norm_dw_start = dw_norm_act_layer(in_chs, apply_act=False)
        else:
            # start is None when not used for cleaner repr
            self.conv_dw_start = None
            self.norm_dw_start = None
            self.dw_start = nn.Identity()

        # Point-wise expansion
        mid_chs = make_divisible(in_chs * exp_ratio)
        self.conv_pw = create_conv2d(in_chs, mid_chs, 1, padding=pad_type, **conv_kwargs)
        self.norm_pw = norm_act_layer(mid_chs, inplace=True)
        self.pw_exp = ConvNormAct(
            in_chs, mid_chs, 1,
            padding=pad_type,
            act_layer=act_layer,
            norm_layer=norm_layer,
            **conv_kwargs,
        )

        # Depth-wise convolution
        # Middle depth-wise convolution
        if dw_kernel_size_mid:
            groups = num_groups(group_size, mid_chs)
            self.conv_dw_mid = create_conv2d(
            self.dw_mid = ConvNormActAa(
                mid_chs, mid_chs, dw_kernel_size_mid,
                stride=stride,
                dilation=dilation,  # FIXME
                groups=groups,
                padding=pad_type,
                act_layer=act_layer,
                norm_layer=norm_layer,
                aa_layer=aa_layer,
                **conv_kwargs,
            )
            self.norm_dw_mid = dw_norm_act_layer(mid_chs, inplace=True)
        else:
            # keeping mid as identity so it can be hooked more easily for features
            self.conv_dw_mid = nn.Identity()
            self.norm_dw_mid = nn.Identity()
            self.dw_mid = nn.Identity()

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, 1, padding=pad_type, **conv_kwargs)
        self.norm_pwl = norm_act_layer(out_chs, apply_act=False)
        self.pw_proj = ConvNormAct(
            mid_chs, out_chs, 1,
            padding=pad_type,
            apply_act=False,
            act_layer=act_layer,
            norm_layer=norm_layer,
            **conv_kwargs,
        )

        if dw_kernel_size_end:
            self.conv_dw_end = create_conv2d(
            dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1
            dw_end_groups = num_groups(group_size, out_chs)
            if dw_end_stride > 1:
                assert not aa_layer
            self.dw_end = ConvNormAct(
                out_chs, out_chs, dw_kernel_size_end,
                stride=dw_end_stride,
                dilation=dilation,
                depthwise=True,
                groups=dw_end_groups,
                padding=pad_type,
                apply_act=False,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **conv_kwargs,
            )
            self.norm_dw_end = dw_norm_act_layer(out_chs, apply_act=False)
        else:
            # end is None when not in use for cleaner repr
            self.conv_dw_end = None
            self.norm_dw_end = None
            self.dw_end = nn.Identity()

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
@@ -393,25 +428,18 @@ class UniversalInvertedResidual(nn.Module):

    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PWL
            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
            return dict(module='pw_proj.conv', hook_type='forward_pre', num_chs=self.pw_proj.conv.in_channels)
        else:  # location == 'bottleneck', block output
            return dict(module='', num_chs=self.conv_pwl.out_channels)
            return dict(module='', num_chs=self.pw_proj.conv.out_channels)

    def forward(self, x):
        shortcut = x
        if self.conv_dw_start is not None:
            x = self.conv_dw_start(x)
            x = self.norm_dw_start(x)
        x = self.conv_pw(x)
        x = self.norm_pw(x)
        x = self.conv_dw_mid(x)
        x = self.norm_dw_mid(x)
        x = self.dw_start(x)
        x = self.pw_exp(x)
        x = self.dw_mid(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.norm_pwl(x)
        if self.conv_dw_end is not None:
            x = self.conv_dw_end(x)
            x = self.norm_dw_end(x)
        x = self.pw_proj(x)
        x = self.dw_end(x)
        x = self.layer_scale(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
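The refactored forward above runs through the new named submodules (dw_start, pw_exp, dw_mid, pw_proj, dw_end) rather than separate conv/norm pairs. A rough instantiation sketch (an assumed "IR-style" configuration; exact kwargs and behavior depend on the block as defined in this commit):

import torch
from timm.models._efficientnet_blocks import UniversalInvertedResidual

uib = UniversalInvertedResidual(64, 96, dw_kernel_size_mid=5, stride=2, exp_ratio=4.0, aa_layer='blur')
out = uib(torch.randn(1, 64, 32, 32))   # expected shape: (1, 96, 16, 16)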
@@ -426,29 +454,30 @@ class MobileAttention(nn.Module):
    """
    def __init__(
            self,
            in_chs,
            out_chs,
            stride=1,
            dw_kernel_size=3,
            dilation=1,
            group_size=1,
            pad_type='',
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dw_kernel_size: int = 3,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            num_heads: int = 8,
            key_dim: int = 64,
            value_dim: int = 64,
            use_multi_query: bool = False,
            query_strides: int = (1, 1),
            kv_stride: int = 1,
            cpe_dw_kernel_size=3,
            noskip=False,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            drop_path_rate=0.,
            attn_drop=0.0,
            proj_drop=0.0,
            cpe_dw_kernel_size: int = 3,
            noskip: bool = False,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            drop_path_rate: float = 0.,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
            layer_scale_init_value: Optional[float] = 1e-5,
            use_bias=False,
            use_cpe=False,
            use_bias: bool = False,
            use_cpe: bool = False,
    ):
        super(MobileAttention, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
@@ -512,7 +541,6 @@ class MobileAttention(nn.Module):

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()


    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PW
            return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
@@ -539,22 +567,23 @@ class CondConvResidual(InvertedResidual):

    def __init__(
            self,
            in_chs,
            out_chs,
            dw_kernel_size=3,
            stride=1,
            dilation=1,
            group_size=1,
            pad_type='',
            noskip=False,
            exp_ratio=1.0,
            exp_kernel_size=1,
            pw_kernel_size=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            se_layer=None,
            num_experts=0,
            drop_path_rate=0.,
            in_chs: int,
            out_chs: int,
            dw_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 1,
            pad_type: str = '',
            noskip: bool = False,
            exp_ratio: float = 1.0,
            exp_kernel_size: int = 1,
            pw_kernel_size: int = 1,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            num_experts: int = 0,
            drop_path_rate: float = 0.,
    ):

        self.num_experts = num_experts
@@ -567,13 +596,14 @@ class CondConvResidual(InvertedResidual):
            dilation=dilation,
            group_size=group_size,
            pad_type=pad_type,
            act_layer=act_layer,
            noskip=noskip,
            exp_ratio=exp_ratio,
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            se_layer=se_layer,
            act_layer=act_layer,
            norm_layer=norm_layer,
            aa_layer=aa_layer,
            se_layer=se_layer,
            conv_kwargs=conv_kwargs,
            drop_path_rate=drop_path_rate,
        )
@@ -609,21 +639,22 @@ class EdgeResidual(nn.Module):

    def __init__(
            self,
            in_chs,
            out_chs,
            exp_kernel_size=3,
            stride=1,
            dilation=1,
            group_size=0,
            pad_type='',
            force_in_chs=0,
            noskip=False,
            exp_ratio=1.0,
            pw_kernel_size=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            se_layer=None,
            drop_path_rate=0.,
            in_chs: int,
            out_chs: int,
            exp_kernel_size: int = 3,
            stride: int = 1,
            dilation: int = 1,
            group_size: int = 0,
            pad_type: str = '',
            force_in_chs: int = 0,
            noskip: bool = False,
            exp_ratio: float = 1.0,
            pw_kernel_size: int = 1,
            act_layer: LayerType = nn.ReLU,
            norm_layer: LayerType = nn.BatchNorm2d,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[ModuleType] = None,
            drop_path_rate: float = 0.,
    ):
        super(EdgeResidual, self).__init__()
        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
@@ -633,13 +664,17 @@ class EdgeResidual(nn.Module):
        mid_chs = make_divisible(in_chs * exp_ratio)
        groups = num_groups(group_size, in_chs)
        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
        use_aa = aa_layer is not None and stride > 1  # FIXME handle dilation

        # Expansion convolution
        self.conv_exp = create_conv2d(
            in_chs, mid_chs, exp_kernel_size,
            stride=stride, dilation=dilation, groups=groups, padding=pad_type)
            stride=1 if use_aa else stride,
            dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_act_layer(mid_chs, inplace=True)

        self.aa = create_aa(aa_layer, channels=mid_chs, stride=stride, enable=use_aa)

        # Squeeze-and-excitation
        self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()

@@ -658,6 +693,7 @@ class EdgeResidual(nn.Module):
        shortcut = x
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.aa(x)
        x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn2(x)
@@ -17,7 +17,7 @@ from typing import Any, Dict, List
import torch.nn as nn

from ._efficientnet_blocks import *
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType

__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
    'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@@ -326,9 +326,10 @@ class EfficientNetBuilder:
            pad_type: str = '',
            round_chs_fn: Callable = round_channels,
            se_from_exp: bool = False,
            act_layer: Optional[Callable] = None,
            norm_layer: Optional[Callable] = None,
            se_layer: Optional[Callable] = None,
            act_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[LayerType] = None,
            drop_path_rate: float = 0.,
            layer_scale_init_value: Optional[float] = None,
            feature_location: str = '',
@@ -339,6 +340,7 @@ class EfficientNetBuilder:
        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
        self.act_layer = act_layer
        self.norm_layer = norm_layer
        self.aa_layer = aa_layer
        self.se_layer = get_attn(se_layer)
        try:
            self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
@@ -378,6 +380,9 @@ class EfficientNetBuilder:
        ba['norm_layer'] = self.norm_layer
        ba['drop_path_rate'] = drop_path_rate

        if self.aa_layer is not None:
            ba['aa_layer'] = self.aa_layer

        se_ratio = ba.pop('se_ratio', None)
        if se_ratio and self.se_layer is not None:
            if not self.se_from_exp:
@@ -461,6 +466,7 @@ class EfficientNetBuilder:
            space2depth = 1

        if space2depth > 0:
            # FIXME s2d is a WIP
            if space2depth == 2 and block_args['stride'] == 2:
                block_args['stride'] = 1
                # to end s2d region, need to correct expansion and se ratio relative to input
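EfficientNetBuilder now stores aa_layer and injects it into every block's kwargs, so a single argument at model construction enables AA for all stride-2 blocks. A hedged check (not part of the diff):

import timm
from timm.layers import BlurPool2d

m = timm.create_model('efficientnet_b0', pretrained=False, aa_layer='blur')
print(any(isinstance(mod, BlurPool2d) for mod in m.modules()))  # expected: True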
@@ -36,7 +36,7 @@ the models and weights open source!
Hacked together by / Copyright 2019, Ross Wightman
"""
from functools import partial
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -44,10 +44,10 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct, LayerType
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
from ._efficientnet_blocks import SqueezeExcite
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
    round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks, feature_take_indices
from ._manipulate import checkpoint_seq
@@ -74,21 +74,22 @@ class EfficientNet(nn.Module):

    def __init__(
            self,
            block_args,
            num_classes=1000,
            num_features=1280,
            in_chans=3,
            stem_size=32,
            fix_stem=False,
            output_stride=32,
            pad_type='',
            round_chs_fn=round_channels,
            act_layer=None,
            norm_layer=None,
            se_layer=None,
            drop_rate=0.,
            drop_path_rate=0.,
            global_pool='avg'
            block_args: BlockArgs,
            num_classes: int = 1000,
            num_features: int = 1280,
            in_chans: int = 3,
            stem_size: int = 32,
            fix_stem: bool = False,
            output_stride: int = 32,
            pad_type: str = '',
            act_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[LayerType] = None,
            round_chs_fn: Callable = round_channels,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            global_pool: str = 'avg'
    ):
        super(EfficientNet, self).__init__()
        act_layer = act_layer or nn.ReLU
@@ -113,6 +114,7 @@ class EfficientNet(nn.Module):
            round_chs_fn=round_chs_fn,
            act_layer=act_layer,
            norm_layer=norm_layer,
            aa_layer=aa_layer,
            se_layer=se_layer,
            drop_path_rate=drop_path_rate,
        )
@@ -270,20 +272,21 @@ class EfficientNetFeatures(nn.Module):

    def __init__(
            self,
            block_args,
            out_indices=(0, 1, 2, 3, 4),
            feature_location='bottleneck',
            in_chans=3,
            stem_size=32,
            fix_stem=False,
            output_stride=32,
            pad_type='',
            round_chs_fn=round_channels,
            act_layer=None,
            norm_layer=None,
            se_layer=None,
            drop_rate=0.,
            drop_path_rate=0.
            block_args: BlockArgs,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            feature_location: str = 'bottleneck',
            in_chans: int = 3,
            stem_size: int = 32,
            fix_stem: bool = False,
            output_stride: int = 32,
            pad_type: str = '',
            act_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[LayerType] = None,
            round_chs_fn: Callable = round_channels,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
    ):
        super(EfficientNetFeatures, self).__init__()
        act_layer = act_layer or nn.ReLU
@@ -306,6 +309,7 @@ class EfficientNetFeatures(nn.Module):
            round_chs_fn=round_chs_fn,
            act_layer=act_layer,
            norm_layer=norm_layer,
            aa_layer=aa_layer,
            se_layer=se_layer,
            drop_path_rate=drop_path_rate,
            feature_location=feature_location,
@@ -1154,6 +1158,7 @@ default_cfgs = generate_default_cfgs({
        input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
    'efficientnet_b3_g8_gn.untrained': _cfg(
        input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
    'efficientnet_blur_b0.untrained': _cfg(),

    'efficientnet_es.ra_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
@@ -1850,6 +1855,17 @@ def efficientnet_b3_g8_gn(pretrained=False, **kwargs) -> EfficientNet:
    return model


@register_model
def efficientnet_blur_b0(pretrained=False, **kwargs) -> EfficientNet:
    """ EfficientNet-B0 w/ BlurPool """
    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
    model = _gen_efficientnet(
        'efficientnet_blur_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained,
        aa_layer='blurpc', **kwargs
    )
    return model


@register_model
def efficientnet_es(pretrained=False, **kwargs) -> EfficientNet:
    """ EfficientNet-Edge Small. """
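Usage sketch for the new registration above (the config is .untrained, so no pretrained weights are expected yet); 'blurpc' selects BlurPool2d with constant padding via create_aa:

import timm

model = timm.create_model('efficientnet_blur_b0', pretrained=False)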
@@ -40,6 +40,7 @@ class MobileNetV3(nn.Module):
      * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 (defn in hardcorenas.py uses this class)
      * FBNet-V3 - https://arxiv.org/abs/2006.02049
      * LCNet - https://arxiv.org/abs/2109.15099
      * MobileNet-V4 - https://arxiv.org/abs/2404.10518
    """

    def __init__(
@@ -52,9 +53,10 @@ class MobileNetV3(nn.Module):
            num_features: int = 1280,
            head_bias: bool = True,
            head_norm: bool = False,
            pad_type: PadType = '',
            pad_type: str = '',
            act_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[LayerType] = None,
            se_from_exp: bool = True,
            round_chs_fn: Callable = round_channels,
@@ -75,6 +77,7 @@ class MobileNetV3(nn.Module):
            pad_type: Type of padding to use for convolution layers.
            act_layer: Type of activation layer.
            norm_layer: Type of normalization layer.
            aa_layer: Type of anti-aliasing layer.
            se_layer: Type of Squeeze-and-Excite layer.
            se_from_exp: If True, calculate SE channel reduction from expanded mid channels.
            round_chs_fn: Callable to round number of filters based on depth multiplier.
@@ -107,6 +110,7 @@ class MobileNetV3(nn.Module):
            se_from_exp=se_from_exp,
            act_layer=act_layer,
            norm_layer=norm_layer,
            aa_layer=aa_layer,
            se_layer=se_layer,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
@@ -291,6 +295,7 @@ class MobileNetV3Features(nn.Module):
            se_from_exp: bool = True,
            act_layer: Optional[LayerType] = None,
            norm_layer: Optional[LayerType] = None,
            aa_layer: Optional[LayerType] = None,
            se_layer: Optional[LayerType] = None,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
@@ -337,6 +342,7 @@ class MobileNetV3Features(nn.Module):
            se_from_exp=se_from_exp,
            act_layer=act_layer,
            norm_layer=norm_layer,
            aa_layer=aa_layer,
            se_layer=se_layer,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
@@ -649,15 +655,17 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
    Args:
        channel_multiplier: multiplier to number of channels per layer.
    """
    num_features = 1280
    if 'hybrid' in variant:
        layer_scale_init_value = 1e-5
        if 'medium' in variant:
            stem_size = 32
            num_features = 1280
            act_layer = resolve_act_layer(kwargs, 'relu')
            arch_def = [
                # stage 0, 112x112 in
                ['er_r1_k3_s2_e4_c48'],  # FusedIB (EdgeResidual)
                [
                    'er_r1_k3_s2_e4_c48'  # FusedIB (EdgeResidual)
                ],
                # stage 1, 56x56 in
                [
                    'uir_r1_a3_k5_s2_e4_c80',  # ExtraDW
@@ -689,23 +697,26 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
                    'uir_r1_a0_k0_s1_e4_c256',  # FFN
                    'mqa_r1_k3_h4_s1_d64_c256',  # MQA
                    'uir_r1_a3_k0_s1_e4_c256',  # ConvNeXt
                    'mqa_r1_k3_h4_s1_d64_c256',  # MQA
                    'mqa_r1_k3_h4_s1_d64_c256',  # MQA
                    'uir_r1_a5_k5_s1_e4_c256',  # ExtraDW
                    'mqa_r1_k3_h4_s1_d64_c256',  # MQA
                    'mqa_r1_k3_h4_s1_d64_c256',  # MQA
                    'uir_r1_a5_k0_s1_e4_c256',  # ConvNeXt
                    'mqa_r1_k3_h4_s1_d64_c256',  # MQA
                    'uir_r1_a5_k0_s1_e4_c256',  # ConvNeXt
                ],
                # stage 4, 7x7 in
                ['cn_r1_k1_s1_c960'],  # Conv
                [
                    'cn_r1_k1_s1_c960'  # Conv
                ],
            ]
        elif 'large' in variant:
            stem_size = 24
            num_features = 1280
            act_layer = resolve_act_layer(kwargs, 'gelu')
            arch_def = [
                # stage 0, 112x112 in
                ['er_r1_k3_s2_e4_c48'],  # FusedIB (EdgeResidual)
                [
                    'er_r1_k3_s2_e4_c48',  # FusedIB (EdgeResidual)
                ],
                # stage 1, 56x56 in
                [
                    'uir_r1_a3_k5_s2_e4_c96',  # ExtraDW
@@ -734,17 +745,19 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
                    'uir_r2_a5_k0_s1_e4_c512',  # ConvNeXt
                    'uir_r1_a5_k3_s1_e4_c512',  # ExtraDW
                    'uir_r1_a5_k5_s1_e4_c512',  # ExtraDW
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'uir_r1_a5_k0_s1_e4_c512',  # ConvNeXt
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'uir_r1_a5_k0_s1_e4_c512',  # ConvNeXt
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'uir_r1_a5_k0_s1_e4_c512',  # ConvNeXt
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'mqa_r1_k3_h8_s1_d64_c512',  # MQA
                    'uir_r1_a5_k0_s1_e4_c512',  # ConvNeXt
                ],
                # stage 4, 7x7 in
                ['cn_r1_k1_s1_c960'],
                [
                    'cn_r1_k1_s1_c960',  # Conv
                ],
            ]
        else:
            assert False, f'Unknown variant {variant}.'
@@ -752,7 +765,6 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
        layer_scale_init_value = None
        if 'small' in variant:
            stem_size = 32
            num_features = 1280
            act_layer = resolve_act_layer(kwargs, 'relu')
            arch_def = [
                # stage 0, 112x112 in
@@ -780,15 +792,18 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
                    'uir_r2_a0_k3_s1_e4_c128',  # IR
                ],
                # stage 4, 7x7 in
                ['cn_r1_k1_s1_c960'],  # Conv
                [
                    'cn_r1_k1_s1_c960',  # Conv
                ],
            ]
        elif 'medium' in variant:
            stem_size = 32
            num_features = 1280
            act_layer = resolve_act_layer(kwargs, 'relu')
            arch_def = [
                # stage 0, 112x112 in
                ['er_r1_k3_s2_e4_c48'],  # FusedIB (EdgeResidual)
                [
                    'er_r1_k3_s2_e4_c48',  # FusedIB (EdgeResidual)
                ],
                # stage 1, 56x56 in
                [
                    'uir_r1_a3_k5_s2_e4_c80',  # ExtraDW
@@ -817,15 +832,18 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
                    'uir_r1_a5_k0_s1_e2_c256',  # ConvNeXt
                ],
                # stage 4, 7x7 in
                ['cn_r1_k1_s1_c960'],  # Conv
                [
                    'cn_r1_k1_s1_c960',  # Conv
                ],
            ]
        elif 'large' in variant:
            stem_size = 24
            num_features = 1280
            act_layer = resolve_act_layer(kwargs, 'relu')
            arch_def = [
                # stage 0, 112x112 in
                ['er_r1_k3_s2_e4_c48'],  # FusedIB (EdgeResidual)
                [
                    'er_r1_k3_s2_e4_c48',  # FusedIB (EdgeResidual)
                ],
                # stage 1, 56x56 in
                [
                    'uir_r1_a3_k5_s2_e4_c96',  # ExtraDW
@@ -851,24 +869,23 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:

                ],
                # stage 4, 7x7 in
                ['cn_r1_k1_s1_c960'],  # Conv
                [
                    'cn_r1_k1_s1_c960',  # Conv
                ],
            ]
        else:
            assert False, f'Unknown variant {variant}.'

    # NOTE SE not used in initial MobileNet-v4 definitions
    se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        head_bias=False,
        head_norm=True,
        num_features=num_features,
        stem_size=stem_size,
        fix_stem=channel_multiplier < 0.75,
        fix_stem=channel_multiplier < 1.0,
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
        act_layer=act_layer,
        se_layer=se_layer,
        layer_scale_init_value=layer_scale_init_value,
        **kwargs,
    )
@@ -904,9 +921,6 @@ default_cfgs = generate_default_cfgs({
        origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
        paper_ids='arXiv:2104.10972v4',
        interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
    'mobilenetv3_large_150.untrained': _cfg(
        interpolation='bicubic'),

    'mobilenetv3_small_050.lamb_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
@@ -985,28 +999,48 @@ default_cfgs = generate_default_cfgs({
    'mobilenetv4_conv_small': _cfg(
        # hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv4_conv_medium': _cfg(
        #hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv4_conv_large': _cfg(
    'mobilenetv4_conv_medium.r224': _cfg(
        # hf_hub_id='timm/',
        interpolation='bicubic'),
        crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_medium.r256': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_large.r256': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_large.r384': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'),

    'mobilenetv4_hybrid_small': _cfg(
        # hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv4_hybrid_medium': _cfg(
    'mobilenetv4_hybrid_medium.r224': _cfg(
        # hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv4_hybrid_large': _cfg(
        crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium.r256': _cfg(
        # hf_hub_id='timm/',
        interpolation='bicubic'),
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_hybrid_large.r256': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_hybrid_large.r384': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=0.95, interpolation='bicubic'),

    # experimental
    'mobilenetv4_conv_aa_medium.r256': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_conv_blur_medium.r256': _cfg(
        # hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_hybrid_medium_075': _cfg(
        # hf_hub_id='timm/',
        interpolation='bicubic'),
    'mobilenetv4_hybrid_medium_150': _cfg(
        crop_pct=0.95, interpolation='bicubic'),
    'mobilenetv4_hybrid_large_075.r256': _cfg(
        # hf_hub_id='timm/',
        interpolation='bicubic'),
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95, interpolation='bicubic'),
})


@@ -1024,13 +1058,6 @@ def mobilenetv3_large_100(pretrained: bool = False, **kwargs) -> MobileNetV3:
    return model


@register_model
def mobilenetv3_large_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V3 """
    model = _gen_mobilenet_v3('mobilenetv3_large_150', 1.5, pretrained=pretrained, **kwargs)
    return model


@register_model
def mobilenetv3_small_050(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V3 """
@@ -1191,13 +1218,6 @@ def mobilenetv4_conv_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
    return model


@register_model
def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 Hybrid """
    model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs)
    return model


@register_model
def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 Hybrid """
@@ -1205,13 +1225,6 @@ def mobilenetv4_hybrid_medium(pretrained: bool = False, **kwargs) -> MobileNetV3
    return model


@register_model
def mobilenetv4_hybrid_medium_150(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 Hybrid """
    model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_150', 1.5, pretrained=pretrained, **kwargs)
    return model


@register_model
def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 Hybrid"""
@@ -1219,6 +1232,33 @@ def mobilenetv4_hybrid_large(pretrained: bool = False, **kwargs) -> MobileNetV3:
    return model


@register_model
def mobilenetv4_conv_aa_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 w/ AvgPool AA """
    model = _gen_mobilenet_v4('mobilenetv4_conv_aa_medium', 1.0, pretrained=pretrained, aa_layer='avg', **kwargs)
    return model


@register_model
def mobilenetv4_conv_blur_medium(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 Conv w/ Blur AA """
    model = _gen_mobilenet_v4('mobilenetv4_conv_blur_medium', 1.0, pretrained=pretrained, aa_layer='blurpc', **kwargs)
    return model


@register_model
def mobilenetv4_hybrid_medium_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 Hybrid """
    model = _gen_mobilenet_v4('mobilenetv4_hybrid_medium_075', 0.75, pretrained=pretrained, **kwargs)
    return model


@register_model
def mobilenetv4_hybrid_large_075(pretrained: bool = False, **kwargs) -> MobileNetV3:
    """ MobileNet V4 Hybrid"""
    model = _gen_mobilenet_v4('mobilenetv4_hybrid_large', 0.75, pretrained=pretrained, **kwargs)
    return model


register_model_deprecations(__name__, {
    'mobilenetv3_large_100_miil': 'mobilenetv3_large_100.miil_in21k_ft_in1k',
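Usage sketch for the experimental registrations above (both currently untrained): 'avg' uses nn.AvgPool2d for the anti-aliased downsample, 'blurpc' uses BlurPool2d with constant padding.

import timm

aa_model = timm.create_model('mobilenetv4_conv_aa_medium', pretrained=False)
blur_model = timm.create_model('mobilenetv4_conv_blur_medium', pretrained=False)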
@@ -17,7 +17,7 @@ import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
    get_attn, get_act_layer, get_norm_layer, create_classifier
    get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
@@ -31,15 +31,6 @@ def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
    return padding


def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
    if not aa_layer or not enable:
        return nn.Identity()
    if issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)
    else:
        return aa_layer(channels=channels, stride=stride)


class BasicBlock(nn.Module):
    expansion = 1