Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Remove separate ConvNormActAa class, merge with ConvNormAct
This commit is contained in:
parent 5efa15b2a2
commit f0fb471b26
@@ -26,7 +26,8 @@ class ConvNormAct(nn.Module):
             apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
+            act_layer: Optional[LayerType] = nn.ReLU,
+            aa_layer: Optional[LayerType] = None,
             drop_layer: Optional[Type[nn.Module]] = None,
             conv_kwargs: Optional[Dict[str, Any]] = None,
             norm_kwargs: Optional[Dict[str, Any]] = None,
@@ -36,83 +37,12 @@ class ConvNormAct(nn.Module):
         conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
+        use_aa = aa_layer is not None and stride > 1

         self.conv = create_conv2d(
             in_channels,
             out_channels,
             kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias,
-            **conv_kwargs,
-        )
-
-        if apply_norm:
-            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-            self.bn = norm_act_layer(
-                out_channels,
-                apply_act=apply_act,
-                act_kwargs=act_kwargs,
-                **norm_kwargs,
-            )
-        else:
-            self.bn = nn.Sequential()
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-                self.bn.add_module('drop', drop_layer())
-
-    @property
-    def in_channels(self):
-        return self.conv.in_channels
-
-    @property
-    def out_channels(self):
-        return self.conv.out_channels
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        return x
-
-
-ConvBnAct = ConvNormAct
-
-
-class ConvNormActAa(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_size: int = 1,
-            stride: int = 1,
-            padding: PadType = '',
-            dilation: int = 1,
-            groups: int = 1,
-            bias: bool = False,
-            apply_norm: bool = True,
-            apply_act: bool = True,
-            norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
-            aa_layer: Optional[LayerType] = None,
-            drop_layer: Optional[Type[nn.Module]] = None,
-            conv_kwargs: Optional[Dict[str, Any]] = None,
-            norm_kwargs: Optional[Dict[str, Any]] = None,
-            act_kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        super(ConvNormActAa, self).__init__()
-        use_aa = aa_layer is not None and stride == 2
-        conv_kwargs = conv_kwargs or {}
-        norm_kwargs = norm_kwargs or {}
-        act_kwargs = act_kwargs or {}
-
-        self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size,
             stride=1 if use_aa else stride,
             padding=padding,
             dilation=dilation,
@@ -139,7 +69,7 @@ class ConvNormActAa(nn.Module):
                 norm_kwargs['drop_layer'] = drop_layer
                 self.bn.add_module('drop', drop_layer())

-        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
+        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None)

     @property
     def in_channels(self):
@@ -152,5 +82,10 @@ class ConvNormActAa(nn.Module):
     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
-        x = self.aa(x)
+        if self.aa is not None:
+            x = self.aa(x)
         return x
+
+
+ConvBnAct = ConvNormAct
+ConvNormActAa = ConvNormAct  # backwards compat, when they were separate
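The net effect of the change above is easiest to see at a call site. Below is a minimal sketch, assuming a timm build that includes this commit; the channel sizes and the BlurPool2d choice are illustrative, and the shapes in the comments are what the merged ConvNormAct is expected to produce for a stride-2 block with and without an anti-aliasing layer.

import torch
from timm.layers import ConvNormAct, ConvNormActAa, BlurPool2d

# stride 2 with aa_layer: the 3x3 conv runs at stride 1 and the blur-pool
# module applies the stride-2 anti-aliased downsample afterwards
block = ConvNormAct(32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d)
x = torch.randn(1, 32, 56, 56)
print(block(x).shape)   # torch.Size([1, 64, 28, 28])

# no aa_layer: behaves like the old ConvNormAct, the conv carries the stride
plain = ConvNormAct(32, 64, kernel_size=3, stride=2)
print(plain(x).shape)   # torch.Size([1, 64, 28, 28])

# the old name is kept as an alias, so existing call sites keep working
assert ConvNormActAa is ConvNormAct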
@@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 from torch import nn as nn

-from .conv_bn_act import ConvNormActAa
+from .conv_bn_act import ConvNormAct
 from .helpers import make_divisible
 from .trace_utils import _assert

@@ -100,7 +100,7 @@ class SelectiveKernel(nn.Module):
             stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
             aa_layer=aa_layer, drop_layer=drop_layer)
         self.paths = nn.ModuleList([
-            ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
+            ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
             for k, d in zip(kernel_size, dilation)])

         attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
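For SelectiveKernel the edit is a pure rename: aa_layer was already riding along in conv_kwargs, and the merged ConvNormAct now consumes it directly. A small standalone sketch of that path-building pattern, with example channel, kernel, and layer choices rather than the model's actual configuration:

import torch.nn as nn
from timm.layers import ConvNormAct, BlurPool2d

in_channels, out_channels = 64, 64
kernel_size, dilation = [3, 5], [1, 1]   # example multi-branch setup
conv_kwargs = dict(
    stride=2, groups=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
    aa_layer=BlurPool2d, drop_layer=None)

# each mixing path is now a plain ConvNormAct; anti-aliasing is handled inside it
paths = nn.ModuleList([
    ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
    for k, d in zip(kernel_size, dilation)])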
@@ -9,7 +9,7 @@ import torch.nn as nn
 from torch.nn import functional as F

 from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
-    ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d
+    ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d

 __all__ = [
     'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
@@ -345,7 +345,7 @@ class UniversalInvertedResidual(nn.Module):
         if dw_kernel_size_start:
             dw_start_stride = stride if not dw_kernel_size_mid else 1
             dw_start_groups = num_groups(group_size, in_chs)
-            self.dw_start = ConvNormActAa(
+            self.dw_start = ConvNormAct(
                 in_chs, in_chs, dw_kernel_size_start,
                 stride=dw_start_stride,
                 dilation=dilation,  # FIXME
@@ -373,7 +373,7 @@ class UniversalInvertedResidual(nn.Module):
         # Middle depth-wise convolution
         if dw_kernel_size_mid:
             groups = num_groups(group_size, mid_chs)
-            self.dw_mid = ConvNormActAa(
+            self.dw_mid = ConvNormAct(
                 mid_chs, mid_chs, dw_kernel_size_mid,
                 stride=stride,
                 dilation=dilation,  # FIXME
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, MATCH_PREV_GROUP
 from ._registry import register_model, generate_default_cfgs
@@ -296,10 +296,10 @@ class CrossStage(nn.Module):
             if avg_down:
                 self.conv_down = nn.Sequential(
                     nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                    ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                 )
             else:
-                self.conv_down = ConvNormActAa(
+                self.conv_down = ConvNormAct(
                     in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                     aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -375,10 +375,10 @@ class CrossStage3(nn.Module):
             if avg_down:
                 self.conv_down = nn.Sequential(
                     nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                    ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                 )
             else:
-                self.conv_down = ConvNormActAa(
+                self.conv_down = ConvNormAct(
                     in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                     aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -442,10 +442,10 @@ class DarkStage(nn.Module):
         if avg_down:
             self.conv_down = nn.Sequential(
                 nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
             )
         else:
-            self.conv_down = ConvNormActAa(
+            self.conv_down = ConvNormAct(
                 in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=aa_layer, **conv_kwargs)

@@ -12,8 +12,7 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\
-    ConvNormActAa, ConvNormAct, DropPath
+from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -39,13 +38,8 @@ class BasicBlock(nn.Module):
         self.stride = stride
         act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)

-        if stride == 1:
-            self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv1 = ConvNormActAa(
-                inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
-
-        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
+        self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)
         self.act = nn.ReLU(inplace=True)

         rd_chs = max(planes * self.expansion // 4, 64)
@@ -87,18 +81,14 @@ class Bottleneck(nn.Module):

         self.conv1 = ConvNormAct(
             inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
-        if stride == 1:
-            self.conv2 = ConvNormAct(
-                planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv2 = ConvNormActAa(
-                planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(
+            planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)

         reduction_chs = max(planes * self.expansion // 8, 64)
         self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None

         self.conv3 = ConvNormAct(
-            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)
+            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False)

         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
         self.act = nn.ReLU(inplace=True)
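The TResNet block hunks show the main payoff of the merge: the stride-dependent branching disappears, and apply_act=False no longer needs an extra act_layer=None. A short illustration of the collapsed construction, using example channel numbers (TResNet itself wires BlurPool2d in as its aa_layer):

from functools import partial
import torch.nn as nn
from timm.layers import ConvNormAct, BlurPool2d

inplanes, planes, stride = 64, 64, 2
act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)

# before: if stride == 1 build ConvNormAct, else build ConvNormActAa
# after: one call covers both cases; the aa path only activates when stride > 1
conv1 = ConvNormAct(
    inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=BlurPool2d)
conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)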
@@ -204,7 +194,7 @@ class TResNet(nn.Module):
             # avg pooling before 1x1 conv
             layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
             layers += [ConvNormAct(
-                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
+                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)]
             downsample = nn.Sequential(*layers)

         layers = []