Remove separate ConvNormActAa class, merge with ConvNormAct

Ross Wightman 2024-06-10 12:05:35 -07:00
parent 5efa15b2a2
commit f0fb471b26
5 changed files with 29 additions and 104 deletions
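
The change folds the anti-aliased variant into the base block: ConvNormAct gains an `aa_layer` argument, only instantiates the anti-aliasing layer when `stride > 1`, and keeps `ConvNormActAa` as a backwards-compatible alias so existing call sites and checkpoints keep working. A minimal usage sketch (the input shape and the choice of `BlurPool2d` are illustrative, not part of this commit):

```python
import torch
from timm.layers import ConvNormAct, ConvNormActAa, BlurPool2d

x = torch.randn(2, 32, 56, 56)  # hypothetical input

# Both names now resolve to the same class; with an aa_layer and stride > 1 the
# block applies anti-aliased downsampling after the conv + norm + act sequence.
blk = ConvNormAct(32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d)
assert ConvNormActAa is ConvNormAct   # backwards-compat alias
print(blk(x).shape)                   # torch.Size([2, 64, 28, 28])
```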

timm/layers/conv_bn_act.py

@@ -26,7 +26,8 @@ class ConvNormAct(nn.Module):
apply_norm: bool = True,
apply_act: bool = True,
norm_layer: LayerType = nn.BatchNorm2d,
act_layer: LayerType = nn.ReLU,
act_layer: Optional[LayerType] = nn.ReLU,
aa_layer: Optional[LayerType] = None,
drop_layer: Optional[Type[nn.Module]] = None,
conv_kwargs: Optional[Dict[str, Any]] = None,
norm_kwargs: Optional[Dict[str, Any]] = None,
@@ -36,83 +37,12 @@ class ConvNormAct(nn.Module):
conv_kwargs = conv_kwargs or {}
norm_kwargs = norm_kwargs or {}
act_kwargs = act_kwargs or {}
use_aa = aa_layer is not None and stride > 1
self.conv = create_conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
**conv_kwargs,
)
if apply_norm:
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn = norm_act_layer(
out_channels,
apply_act=apply_act,
act_kwargs=act_kwargs,
**norm_kwargs,
)
else:
self.bn = nn.Sequential()
if drop_layer:
norm_kwargs['drop_layer'] = drop_layer
self.bn.add_module('drop', drop_layer())
@property
def in_channels(self):
return self.conv.in_channels
@property
def out_channels(self):
return self.conv.out_channels
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
ConvBnAct = ConvNormAct
class ConvNormActAa(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 1,
stride: int = 1,
padding: PadType = '',
dilation: int = 1,
groups: int = 1,
bias: bool = False,
apply_norm: bool = True,
apply_act: bool = True,
norm_layer: LayerType = nn.BatchNorm2d,
act_layer: LayerType = nn.ReLU,
aa_layer: Optional[LayerType] = None,
drop_layer: Optional[Type[nn.Module]] = None,
conv_kwargs: Optional[Dict[str, Any]] = None,
norm_kwargs: Optional[Dict[str, Any]] = None,
act_kwargs: Optional[Dict[str, Any]] = None,
):
super(ConvNormActAa, self).__init__()
use_aa = aa_layer is not None and stride == 2
conv_kwargs = conv_kwargs or {}
norm_kwargs = norm_kwargs or {}
act_kwargs = act_kwargs or {}
self.conv = create_conv2d(
in_channels, out_channels, kernel_size,
stride=1 if use_aa else stride,
padding=padding,
dilation=dilation,
@@ -139,7 +69,7 @@ class ConvNormActAa(nn.Module):
norm_kwargs['drop_layer'] = drop_layer
self.bn.add_module('drop', drop_layer())
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None)
@property
def in_channels(self):
@@ -152,5 +82,10 @@ class ConvNormActAa(nn.Module):
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.aa(x)
if self.aa is not None:
x = self.aa(x)
return x
ConvBnAct = ConvNormAct
ConvNormActAa = ConvNormAct # backwards compat, when they were separate
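
The key to the merge is `create_aa(..., noop=None)`: when no `aa_layer` is given, or `stride == 1`, `self.aa` is simply `None` and `forward()` skips it, so existing stride-1 blocks behave exactly as before. A small check, assuming `BlurPool2d` as the anti-aliasing layer:

```python
from timm.layers import ConvNormAct, BlurPool2d

plain = ConvNormAct(32, 64, kernel_size=3, stride=2)                         # no aa_layer
blurred = ConvNormAct(32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d)  # aa enabled
unused = ConvNormAct(32, 64, kernel_size=3, stride=1, aa_layer=BlurPool2d)   # stride 1, aa disabled

print(plain.aa, unused.aa)        # None None -> forward() skips the aa step
print(type(blurred.aa).__name__)  # BlurPool2d
```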

timm/layers/selective_kernel.py

@@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
from torch import nn as nn
from .conv_bn_act import ConvNormActAa
from .conv_bn_act import ConvNormAct
from .helpers import make_divisible
from .trace_utils import _assert
@@ -100,7 +100,7 @@ class SelectiveKernel(nn.Module):
stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
aa_layer=aa_layer, drop_layer=drop_layer)
self.paths = nn.ModuleList([
ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
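
SelectiveKernel already threaded `aa_layer` through `conv_kwargs`, so swapping the path modules from ConvNormActAa to ConvNormAct is intended to be behaviour-preserving. A quick smoke test (channel counts and the blur layer are illustrative):

```python
import torch
from timm.layers import SelectiveKernel, BlurPool2d

sk = SelectiveKernel(64, 64, stride=2, aa_layer=BlurPool2d)
x = torch.randn(2, 64, 32, 32)
print(sk(x).shape)  # expected: torch.Size([2, 64, 16, 16])
```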

timm/models/_efficientnet_blocks.py

@@ -9,7 +9,7 @@ import torch.nn as nn
from torch.nn import functional as F
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d
ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d
__all__ = [
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
@@ -345,7 +345,7 @@ class UniversalInvertedResidual(nn.Module):
if dw_kernel_size_start:
dw_start_stride = stride if not dw_kernel_size_mid else 1
dw_start_groups = num_groups(group_size, in_chs)
self.dw_start = ConvNormActAa(
self.dw_start = ConvNormAct(
in_chs, in_chs, dw_kernel_size_start,
stride=dw_start_stride,
dilation=dilation, # FIXME
@@ -373,7 +373,7 @@ class UniversalInvertedResidual(nn.Module):
# Middle depth-wise convolution
if dw_kernel_size_mid:
groups = num_groups(group_size, mid_chs)
self.dw_mid = ConvNormActAa(
self.dw_mid = ConvNormAct(
mid_chs, mid_chs, dw_kernel_size_mid,
stride=stride,
dilation=dilation, # FIXME
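
UniversalInvertedResidual passes `aa_layer` to both depth-wise convs; since the merged ConvNormAct only enables anti-aliasing when `stride > 1`, the stride-1 `dw_start` path is unaffected. A standalone sketch of that behaviour (channel counts and kernel sizes are illustrative):

```python
from timm.layers import ConvNormAct, BlurPool2d

# dw_start typically runs at stride 1 (the stride lives in dw_mid); an aa_layer
# passed at stride 1 is a no-op, so both convs can share the same code path.
dw_start = ConvNormAct(96, 96, 3, stride=1, groups=96, apply_act=False, aa_layer=BlurPool2d)
dw_mid = ConvNormAct(96, 96, 5, stride=2, groups=96, aa_layer=BlurPool2d)
print(dw_start.aa is None, dw_mid.aa is not None)  # True True
```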

timm/models/cspnet.py

@@ -20,7 +20,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, MATCH_PREV_GROUP
from ._registry import register_model, generate_default_cfgs
@@ -296,10 +296,10 @@ class CrossStage(nn.Module):
if avg_down:
self.conv_down = nn.Sequential(
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
)
else:
self.conv_down = ConvNormActAa(
self.conv_down = ConvNormAct(
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
aa_layer=aa_layer, **conv_kwargs)
prev_chs = down_chs
@@ -375,10 +375,10 @@ class CrossStage3(nn.Module):
if avg_down:
self.conv_down = nn.Sequential(
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
)
else:
self.conv_down = ConvNormActAa(
self.conv_down = ConvNormAct(
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
aa_layer=aa_layer, **conv_kwargs)
prev_chs = down_chs
@@ -442,10 +442,10 @@ class DarkStage(nn.Module):
if avg_down:
self.conv_down = nn.Sequential(
nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
)
else:
self.conv_down = ConvNormActAa(
self.conv_down = ConvNormAct(
in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
aa_layer=aa_layer, **conv_kwargs)
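
In the CSP stages only the strided 3x3 branch ever used anti-aliasing; the avg-pool branch never passes `aa_layer`, so both can now use ConvNormAct directly. Roughly, with illustrative channel counts and `BlurPool2d` standing in for `aa_layer`:

```python
from torch import nn
from timm.layers import ConvNormAct, BlurPool2d

in_chs, down_chs, stride, groups, avg_down = 64, 128, 2, 1, False
if avg_down:
    conv_down = nn.Sequential(
        nn.AvgPool2d(2) if stride == 2 else nn.Identity(),
        ConvNormAct(in_chs, down_chs, kernel_size=1, stride=1, groups=groups),
    )
else:
    conv_down = ConvNormAct(
        in_chs, down_chs, kernel_size=3, stride=stride, groups=groups,
        aa_layer=BlurPool2d)
```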

timm/models/tresnet.py

@@ -12,8 +12,7 @@ from typing import Optional
import torch
import torch.nn as nn
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\
ConvNormActAa, ConvNormAct, DropPath
from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -39,13 +38,8 @@ class BasicBlock(nn.Module):
self.stride = stride
act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
if stride == 1:
self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
else:
self.conv1 = ConvNormActAa(
inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)
self.act = nn.ReLU(inplace=True)
rd_chs = max(planes * self.expansion // 4, 64)
@@ -87,18 +81,14 @@ class Bottleneck(nn.Module):
self.conv1 = ConvNormAct(
inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
if stride == 1:
self.conv2 = ConvNormAct(
planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
else:
self.conv2 = ConvNormActAa(
planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
self.conv2 = ConvNormAct(
planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
reduction_chs = max(planes * self.expansion // 8, 64)
self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
self.conv3 = ConvNormAct(
planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)
planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.act = nn.ReLU(inplace=True)
@@ -204,7 +194,7 @@ class TResNet(nn.Module):
# avg pooling before 1x1 conv
layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
layers += [ConvNormAct(
self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)]
downsample = nn.Sequential(*layers)
layers = []
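
In TResNet the stride-1 / stride-2 branches collapse into a single constructor call: at stride 1 the `aa_layer` is ignored, at stride 2 it performs the downsampling, and the redundant `act_layer=None` disappears because `apply_act=False` already disables the activation. A sketch of the new BasicBlock convs (the channel count is illustrative):

```python
from functools import partial
from torch import nn
from timm.layers import ConvNormAct, BlurPool2d

planes = 64  # illustrative
act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
conv1 = ConvNormAct(planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=BlurPool2d)
conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)
```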