Remove separate ConvNormActAa class, merge with ConvNormAct

Ross Wightman 2024-06-10 12:05:35 -07:00
parent 5efa15b2a2
commit f0fb471b26
5 changed files with 29 additions and 104 deletions
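After this change, `ConvNormAct` takes an optional `aa_layer` and handles anti-aliased downsampling itself: when `aa_layer` is set and `stride > 1`, the conv runs at stride 1 and the aa layer applies the stride, so call sites no longer need to choose between two classes. A minimal usage sketch, assuming timm at this commit; the `BlurPool2d` choice and shapes are illustrative, not part of the commit:

import torch
from timm.layers import ConvNormAct, BlurPool2d

# stride > 1 with an aa_layer: the conv runs at stride 1 and BlurPool2d performs the 2x downsample
block = ConvNormAct(16, 32, kernel_size=3, stride=2, aa_layer=BlurPool2d)
x = torch.randn(1, 16, 64, 64)
print(block(x).shape)  # expected: torch.Size([1, 32, 32, 32])

# no aa_layer (or stride == 1): behaves like the previous ConvNormAct, conv carries the stride
plain = ConvNormAct(16, 32, kernel_size=3, stride=2)
print(plain(x).shape)  # expected: torch.Size([1, 32, 32, 32])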

View File

@@ -26,7 +26,8 @@ class ConvNormAct(nn.Module):
             apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
+            act_layer: Optional[LayerType] = nn.ReLU,
+            aa_layer: Optional[LayerType] = None,
             drop_layer: Optional[Type[nn.Module]] = None,
             conv_kwargs: Optional[Dict[str, Any]] = None,
             norm_kwargs: Optional[Dict[str, Any]] = None,
@@ -36,83 +37,12 @@ class ConvNormAct(nn.Module):
         conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
+        use_aa = aa_layer is not None and stride > 1

         self.conv = create_conv2d(
             in_channels,
             out_channels,
             kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias,
-            **conv_kwargs,
-        )
-
-        if apply_norm:
-            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-            self.bn = norm_act_layer(
-                out_channels,
-                apply_act=apply_act,
-                act_kwargs=act_kwargs,
-                **norm_kwargs,
-            )
-        else:
-            self.bn = nn.Sequential()
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-                self.bn.add_module('drop', drop_layer())
-
-    @property
-    def in_channels(self):
-        return self.conv.in_channels
-
-    @property
-    def out_channels(self):
-        return self.conv.out_channels
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        return x
-
-
-ConvBnAct = ConvNormAct
-
-
-class ConvNormActAa(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_size: int = 1,
-            stride: int = 1,
-            padding: PadType = '',
-            dilation: int = 1,
-            groups: int = 1,
-            bias: bool = False,
-            apply_norm: bool = True,
-            apply_act: bool = True,
-            norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
-            aa_layer: Optional[LayerType] = None,
-            drop_layer: Optional[Type[nn.Module]] = None,
-            conv_kwargs: Optional[Dict[str, Any]] = None,
-            norm_kwargs: Optional[Dict[str, Any]] = None,
-            act_kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        super(ConvNormActAa, self).__init__()
-        use_aa = aa_layer is not None and stride == 2
-        conv_kwargs = conv_kwargs or {}
-        norm_kwargs = norm_kwargs or {}
-        act_kwargs = act_kwargs or {}
-
-        self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size,
             stride=1 if use_aa else stride,
             padding=padding,
             dilation=dilation,
@@ -139,7 +69,7 @@ class ConvNormActAa(nn.Module):
                 norm_kwargs['drop_layer'] = drop_layer
                 self.bn.add_module('drop', drop_layer())

-        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
+        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None)

     @property
     def in_channels(self):
@@ -152,5 +82,10 @@ class ConvNormActAa(nn.Module):
     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
-        x = self.aa(x)
+        if self.aa is not None:
+            x = self.aa(x)
         return x
+
+
+ConvBnAct = ConvNormAct
+
+ConvNormActAa = ConvNormAct  # backwards compat, when they were separate
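The trailing aliases keep old imports working, and because the merged class keeps the `.conv` and `.bn` submodule names (per the NOTE on weight compatibility), checkpoints saved against either old class should still load. A small sketch of what that implies, assuming timm at this commit and that timm/layers/__init__.py still re-exports these names (it is not among the changed files):

from timm.layers import ConvNormAct, ConvNormActAa, ConvBnAct

# both legacy names now resolve to the single merged class
assert ConvNormActAa is ConvNormAct
assert ConvBnAct is ConvNormAct

# submodule names are unchanged, so old state dicts still match
m = ConvNormAct(8, 8, kernel_size=3)
print(list(m.state_dict().keys()))  # e.g. ['conv.weight', 'bn.weight', 'bn.bias', 'bn.running_mean', ...]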

View File

@@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 from torch import nn as nn

-from .conv_bn_act import ConvNormActAa
+from .conv_bn_act import ConvNormAct
 from .helpers import make_divisible
 from .trace_utils import _assert
@@ -100,7 +100,7 @@ class SelectiveKernel(nn.Module):
             stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
             aa_layer=aa_layer, drop_layer=drop_layer)
         self.paths = nn.ModuleList([
-            ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
+            ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
             for k, d in zip(kernel_size, dilation)])

         attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)

View File

@@ -9,7 +9,7 @@ import torch.nn as nn
 from torch.nn import functional as F

 from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
-    ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d
+    ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d

 __all__ = [
     'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
@@ -345,7 +345,7 @@ class UniversalInvertedResidual(nn.Module):
         if dw_kernel_size_start:
             dw_start_stride = stride if not dw_kernel_size_mid else 1
             dw_start_groups = num_groups(group_size, in_chs)
-            self.dw_start = ConvNormActAa(
+            self.dw_start = ConvNormAct(
                 in_chs, in_chs, dw_kernel_size_start,
                 stride=dw_start_stride,
                 dilation=dilation,  # FIXME
@@ -373,7 +373,7 @@ class UniversalInvertedResidual(nn.Module):
         # Middle depth-wise convolution
         if dw_kernel_size_mid:
             groups = num_groups(group_size, mid_chs)
-            self.dw_mid = ConvNormActAa(
+            self.dw_mid = ConvNormAct(
                 mid_chs, mid_chs, dw_kernel_size_mid,
                 stride=stride,
                 dilation=dilation,  # FIXME

View File

@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, MATCH_PREV_GROUP
 from ._registry import register_model, generate_default_cfgs
@@ -296,10 +296,10 @@ class CrossStage(nn.Module):
         if avg_down:
             self.conv_down = nn.Sequential(
                 nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
             )
         else:
-            self.conv_down = ConvNormActAa(
+            self.conv_down = ConvNormAct(
                 in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -375,10 +375,10 @@ class CrossStage3(nn.Module):
         if avg_down:
             self.conv_down = nn.Sequential(
                 nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
             )
         else:
-            self.conv_down = ConvNormActAa(
+            self.conv_down = ConvNormAct(
                 in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -442,10 +442,10 @@ class DarkStage(nn.Module):
         if avg_down:
             self.conv_down = nn.Sequential(
                 nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
             )
         else:
-            self.conv_down = ConvNormActAa(
+            self.conv_down = ConvNormAct(
                 in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=aa_layer, **conv_kwargs)

View File

@@ -12,8 +12,7 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\
-    ConvNormActAa, ConvNormAct, DropPath
+from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -39,13 +38,8 @@ class BasicBlock(nn.Module):
         self.stride = stride
         act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
-        if stride == 1:
-            self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv1 = ConvNormActAa(
-                inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
-
-        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
+        self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)
         self.act = nn.ReLU(inplace=True)

         rd_chs = max(planes * self.expansion // 4, 64)
@@ -87,18 +81,14 @@ class Bottleneck(nn.Module):
         self.conv1 = ConvNormAct(
             inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
-        if stride == 1:
-            self.conv2 = ConvNormAct(
-                planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv2 = ConvNormActAa(
-                planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(
+            planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)

         reduction_chs = max(planes * self.expansion // 8, 64)
         self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
         self.conv3 = ConvNormAct(
-            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)
+            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
         self.act = nn.ReLU(inplace=True)
@@ -204,7 +194,7 @@ class TResNet(nn.Module):
             # avg pooling before 1x1 conv
             layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
         layers += [ConvNormAct(
-            self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
+            self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)]
         downsample = nn.Sequential(*layers)
         layers = []
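The TResNet edits also drop the redundant `act_layer=None` wherever `apply_act=False` is passed, which the new `Optional[LayerType]` signature makes unnecessary. A minimal sketch of a norm-only block, assuming timm at this commit; the channel sizes are illustrative:

import torch
from timm.layers import ConvNormAct

# apply_act=False builds conv + norm with no activation; act_layer can stay at its default
proj = ConvNormAct(64, 256, kernel_size=1, stride=1, apply_act=False)
x = torch.randn(1, 64, 28, 28)
y = proj(x)
print(y.shape)       # expected: torch.Size([1, 256, 28, 28])
print((y < 0).any()) # expected: tensor(True), since no ReLU follows the BatchNorm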