mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove separate ConvNormActAa class, merge with ConvNormAct
This commit is contained in:
parent
5efa15b2a2
commit
f0fb471b26
@ -26,7 +26,8 @@ class ConvNormAct(nn.Module):
|
||||
apply_norm: bool = True,
|
||||
apply_act: bool = True,
|
||||
norm_layer: LayerType = nn.BatchNorm2d,
|
||||
act_layer: LayerType = nn.ReLU,
|
||||
act_layer: Optional[LayerType] = nn.ReLU,
|
||||
aa_layer: Optional[LayerType] = None,
|
||||
drop_layer: Optional[Type[nn.Module]] = None,
|
||||
conv_kwargs: Optional[Dict[str, Any]] = None,
|
||||
norm_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@ -36,83 +37,12 @@ class ConvNormAct(nn.Module):
|
||||
conv_kwargs = conv_kwargs or {}
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
act_kwargs = act_kwargs or {}
|
||||
use_aa = aa_layer is not None and stride > 1
|
||||
|
||||
self.conv = create_conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
**conv_kwargs,
|
||||
)
|
||||
|
||||
if apply_norm:
|
||||
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
|
||||
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
|
||||
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
|
||||
if drop_layer:
|
||||
norm_kwargs['drop_layer'] = drop_layer
|
||||
self.bn = norm_act_layer(
|
||||
out_channels,
|
||||
apply_act=apply_act,
|
||||
act_kwargs=act_kwargs,
|
||||
**norm_kwargs,
|
||||
)
|
||||
else:
|
||||
self.bn = nn.Sequential()
|
||||
if drop_layer:
|
||||
norm_kwargs['drop_layer'] = drop_layer
|
||||
self.bn.add_module('drop', drop_layer())
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
return self.conv.in_channels
|
||||
|
||||
@property
|
||||
def out_channels(self):
|
||||
return self.conv.out_channels
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
ConvBnAct = ConvNormAct
|
||||
|
||||
|
||||
class ConvNormActAa(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 1,
|
||||
stride: int = 1,
|
||||
padding: PadType = '',
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = False,
|
||||
apply_norm: bool = True,
|
||||
apply_act: bool = True,
|
||||
norm_layer: LayerType = nn.BatchNorm2d,
|
||||
act_layer: LayerType = nn.ReLU,
|
||||
aa_layer: Optional[LayerType] = None,
|
||||
drop_layer: Optional[Type[nn.Module]] = None,
|
||||
conv_kwargs: Optional[Dict[str, Any]] = None,
|
||||
norm_kwargs: Optional[Dict[str, Any]] = None,
|
||||
act_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super(ConvNormActAa, self).__init__()
|
||||
use_aa = aa_layer is not None and stride == 2
|
||||
conv_kwargs = conv_kwargs or {}
|
||||
norm_kwargs = norm_kwargs or {}
|
||||
act_kwargs = act_kwargs or {}
|
||||
|
||||
self.conv = create_conv2d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
stride=1 if use_aa else stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
@ -139,7 +69,7 @@ class ConvNormActAa(nn.Module):
|
||||
norm_kwargs['drop_layer'] = drop_layer
|
||||
self.bn.add_module('drop', drop_layer())
|
||||
|
||||
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
|
||||
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None)
|
||||
|
||||
@property
|
||||
def in_channels(self):
|
||||
@ -152,5 +82,10 @@ class ConvNormActAa(nn.Module):
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.aa(x)
|
||||
if self.aa is not None:
|
||||
x = self.aa(x)
|
||||
return x
|
||||
|
||||
|
||||
ConvBnAct = ConvNormAct
|
||||
ConvNormActAa = ConvNormAct # backwards compat, when they were separate
|
||||
|
@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
from .conv_bn_act import ConvNormActAa
|
||||
from .conv_bn_act import ConvNormAct
|
||||
from .helpers import make_divisible
|
||||
from .trace_utils import _assert
|
||||
|
||||
@ -100,7 +100,7 @@ class SelectiveKernel(nn.Module):
|
||||
stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
|
||||
aa_layer=aa_layer, drop_layer=drop_layer)
|
||||
self.paths = nn.ModuleList([
|
||||
ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
|
||||
ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
|
||||
for k, d in zip(kernel_size, dilation)])
|
||||
|
||||
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
|
||||
|
@ -9,7 +9,7 @@ import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
|
||||
ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d
|
||||
ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d
|
||||
|
||||
__all__ = [
|
||||
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
|
||||
@ -345,7 +345,7 @@ class UniversalInvertedResidual(nn.Module):
|
||||
if dw_kernel_size_start:
|
||||
dw_start_stride = stride if not dw_kernel_size_mid else 1
|
||||
dw_start_groups = num_groups(group_size, in_chs)
|
||||
self.dw_start = ConvNormActAa(
|
||||
self.dw_start = ConvNormAct(
|
||||
in_chs, in_chs, dw_kernel_size_start,
|
||||
stride=dw_start_stride,
|
||||
dilation=dilation, # FIXME
|
||||
@ -373,7 +373,7 @@ class UniversalInvertedResidual(nn.Module):
|
||||
# Middle depth-wise convolution
|
||||
if dw_kernel_size_mid:
|
||||
groups = num_groups(group_size, mid_chs)
|
||||
self.dw_mid = ConvNormActAa(
|
||||
self.dw_mid = ConvNormAct(
|
||||
mid_chs, mid_chs, dw_kernel_size_mid,
|
||||
stride=stride,
|
||||
dilation=dilation, # FIXME
|
||||
|
@ -20,7 +20,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
|
||||
from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, MATCH_PREV_GROUP
|
||||
from ._registry import register_model, generate_default_cfgs
|
||||
@ -296,10 +296,10 @@ class CrossStage(nn.Module):
|
||||
if avg_down:
|
||||
self.conv_down = nn.Sequential(
|
||||
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
||||
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
)
|
||||
else:
|
||||
self.conv_down = ConvNormActAa(
|
||||
self.conv_down = ConvNormAct(
|
||||
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
aa_layer=aa_layer, **conv_kwargs)
|
||||
prev_chs = down_chs
|
||||
@ -375,10 +375,10 @@ class CrossStage3(nn.Module):
|
||||
if avg_down:
|
||||
self.conv_down = nn.Sequential(
|
||||
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
||||
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
)
|
||||
else:
|
||||
self.conv_down = ConvNormActAa(
|
||||
self.conv_down = ConvNormAct(
|
||||
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
aa_layer=aa_layer, **conv_kwargs)
|
||||
prev_chs = down_chs
|
||||
@ -442,10 +442,10 @@ class DarkStage(nn.Module):
|
||||
if avg_down:
|
||||
self.conv_down = nn.Sequential(
|
||||
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
|
||||
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
|
||||
)
|
||||
else:
|
||||
self.conv_down = ConvNormActAa(
|
||||
self.conv_down = ConvNormAct(
|
||||
in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
|
||||
aa_layer=aa_layer, **conv_kwargs)
|
||||
|
||||
|
@ -12,8 +12,7 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\
|
||||
ConvNormActAa, ConvNormAct, DropPath
|
||||
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
|
||||
@ -39,13 +38,8 @@ class BasicBlock(nn.Module):
|
||||
self.stride = stride
|
||||
act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
|
||||
|
||||
if stride == 1:
|
||||
self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
|
||||
else:
|
||||
self.conv1 = ConvNormActAa(
|
||||
inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
|
||||
|
||||
self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
|
||||
self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
|
||||
self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)
|
||||
self.act = nn.ReLU(inplace=True)
|
||||
|
||||
rd_chs = max(planes * self.expansion // 4, 64)
|
||||
@ -87,18 +81,14 @@ class Bottleneck(nn.Module):
|
||||
|
||||
self.conv1 = ConvNormAct(
|
||||
inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
|
||||
if stride == 1:
|
||||
self.conv2 = ConvNormAct(
|
||||
planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
|
||||
else:
|
||||
self.conv2 = ConvNormActAa(
|
||||
planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
|
||||
self.conv2 = ConvNormAct(
|
||||
planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
|
||||
|
||||
reduction_chs = max(planes * self.expansion // 8, 64)
|
||||
self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
|
||||
|
||||
self.conv3 = ConvNormAct(
|
||||
planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)
|
||||
planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False)
|
||||
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
self.act = nn.ReLU(inplace=True)
|
||||
@ -204,7 +194,7 @@ class TResNet(nn.Module):
|
||||
# avg pooling before 1x1 conv
|
||||
layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
|
||||
layers += [ConvNormAct(
|
||||
self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
|
||||
self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)]
|
||||
downsample = nn.Sequential(*layers)
|
||||
|
||||
layers = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user