diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py
index de738045..73ad6705 100644
--- a/timm/layers/conv_bn_act.py
+++ b/timm/layers/conv_bn_act.py
@@ -26,7 +26,8 @@ class ConvNormAct(nn.Module):
             apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
+            act_layer: Optional[LayerType] = nn.ReLU,
+            aa_layer: Optional[LayerType] = None,
             drop_layer: Optional[Type[nn.Module]] = None,
             conv_kwargs: Optional[Dict[str, Any]] = None,
             norm_kwargs: Optional[Dict[str, Any]] = None,
@@ -36,83 +37,12 @@ class ConvNormAct(nn.Module):
         conv_kwargs = conv_kwargs or {}
         norm_kwargs = norm_kwargs or {}
         act_kwargs = act_kwargs or {}
+        use_aa = aa_layer is not None and stride > 1
 
         self.conv = create_conv2d(
             in_channels, out_channels, kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias,
-            **conv_kwargs,
-        )
-
-        if apply_norm:
-            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-            self.bn = norm_act_layer(
-                out_channels,
-                apply_act=apply_act,
-                act_kwargs=act_kwargs,
-                **norm_kwargs,
-            )
-        else:
-            self.bn = nn.Sequential()
-            if drop_layer:
-                norm_kwargs['drop_layer'] = drop_layer
-                self.bn.add_module('drop', drop_layer())
-
-    @property
-    def in_channels(self):
-        return self.conv.in_channels
-
-    @property
-    def out_channels(self):
-        return self.conv.out_channels
-
-    def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
-        return x
-
-
-ConvBnAct = ConvNormAct
-
-
-class ConvNormActAa(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_size: int = 1,
-            stride: int = 1,
-            padding: PadType = '',
-            dilation: int = 1,
-            groups: int = 1,
-            bias: bool = False,
-            apply_norm: bool = True,
-            apply_act: bool = True,
-            norm_layer: LayerType = nn.BatchNorm2d,
-            act_layer: LayerType = nn.ReLU,
-            aa_layer: Optional[LayerType] = None,
-            drop_layer: Optional[Type[nn.Module]] = None,
-            conv_kwargs: Optional[Dict[str, Any]] = None,
-            norm_kwargs: Optional[Dict[str, Any]] = None,
-            act_kwargs: Optional[Dict[str, Any]] = None,
-    ):
-        super(ConvNormActAa, self).__init__()
-        use_aa = aa_layer is not None and stride == 2
-        conv_kwargs = conv_kwargs or {}
-        norm_kwargs = norm_kwargs or {}
-        act_kwargs = act_kwargs or {}
-
-        self.conv = create_conv2d(
-            in_channels, out_channels, kernel_size,
             stride=1 if use_aa else stride,
             padding=padding,
             dilation=dilation,
@@ -139,7 +69,7 @@ class ConvNormActAa(nn.Module):
                 norm_kwargs['drop_layer'] = drop_layer
                 self.bn.add_module('drop', drop_layer())
 
-        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
+        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None)
 
     @property
     def in_channels(self):
@@ -152,5 +82,10 @@ class ConvNormActAa(nn.Module):
     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
-        x = self.aa(x)
+        if self.aa is not None:
+            x = self.aa(x)
         return x
+
+
+ConvBnAct = ConvNormAct
+ConvNormActAa = ConvNormAct  # backwards compat, when they were separate
diff --git a/timm/layers/selective_kernel.py b/timm/layers/selective_kernel.py
index 3d71e3aa..ec8ee6ce 100644
--- a/timm/layers/selective_kernel.py
+++ b/timm/layers/selective_kernel.py
@@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 from torch import nn as nn
 
-from .conv_bn_act import ConvNormActAa
+from .conv_bn_act import ConvNormAct
 from .helpers import make_divisible
 from .trace_utils import _assert
 
@@ -100,7 +100,7 @@ class SelectiveKernel(nn.Module):
             stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
             aa_layer=aa_layer, drop_layer=drop_layer)
         self.paths = nn.ModuleList([
-            ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
+            ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
             for k, d in zip(kernel_size, dilation)])
 
         attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py
index f33dacd5..5f98c90c 100644
--- a/timm/models/_efficientnet_blocks.py
+++ b/timm/models/_efficientnet_blocks.py
@@ -9,7 +9,7 @@ import torch.nn as nn
 from torch.nn import functional as F
 
 from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\
-    ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d
+    ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d
 
 __all__ = [
     'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual',
@@ -345,7 +345,7 @@ class UniversalInvertedResidual(nn.Module):
         if dw_kernel_size_start:
             dw_start_stride = stride if not dw_kernel_size_mid else 1
             dw_start_groups = num_groups(group_size, in_chs)
-            self.dw_start = ConvNormActAa(
+            self.dw_start = ConvNormAct(
                 in_chs, in_chs, dw_kernel_size_start,
                 stride=dw_start_stride,
                 dilation=dilation,  # FIXME
@@ -373,7 +373,7 @@
         # Middle depth-wise convolution
         if dw_kernel_size_mid:
             groups = num_groups(group_size, mid_chs)
-            self.dw_mid = ConvNormActAa(
+            self.dw_mid = ConvNormAct(
                 mid_chs, mid_chs, dw_kernel_size_mid,
                 stride=stride,
                 dilation=dilation,  # FIXME
diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py
index 7d63096a..a7368821 100644
--- a/timm/models/cspnet.py
+++ b/timm/models/cspnet.py
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
+from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, create_act_layer, make_divisible
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, MATCH_PREV_GROUP
 from ._registry import register_model, generate_default_cfgs
@@ -296,10 +296,10 @@ class CrossStage(nn.Module):
         if avg_down:
             self.conv_down = nn.Sequential(
                 nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
             )
         else:
-            self.conv_down = ConvNormActAa(
+            self.conv_down = ConvNormAct(
                 in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -375,10 +375,10 @@ class CrossStage3(nn.Module):
         if avg_down:
             self.conv_down = nn.Sequential(
                 nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
             )
         else:
-            self.conv_down = ConvNormActAa(
+            self.conv_down = ConvNormAct(
                 in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=aa_layer, **conv_kwargs)
             prev_chs = down_chs
@@ -442,10 +442,10 @@ class DarkStage(nn.Module):
         if avg_down:
             self.conv_down = nn.Sequential(
                 nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
-                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
+                ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
             )
         else:
-            self.conv_down = ConvNormActAa(
+            self.conv_down = ConvNormAct(
                 in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=aa_layer, **conv_kwargs)
 
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index 006b7e0b..dec24c1f 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -12,8 +12,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\
-    ConvNormActAa, ConvNormAct, DropPath
+from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
@@ -39,13 +38,8 @@ class BasicBlock(nn.Module):
         self.stride = stride
         act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)
 
-        if stride == 1:
-            self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv1 = ConvNormActAa(
-                inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
-
-        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
+        self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False)
         self.act = nn.ReLU(inplace=True)
 
         rd_chs = max(planes * self.expansion // 4, 64)
@@ -87,18 +81,14 @@
         self.conv1 = ConvNormAct(
             inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
-        if stride == 1:
-            self.conv2 = ConvNormAct(
-                planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
-        else:
-            self.conv2 = ConvNormActAa(
-                planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)
+        self.conv2 = ConvNormAct(
+            planes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer)
 
         reduction_chs = max(planes * self.expansion // 8, 64)
         self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
 
         self.conv3 = ConvNormAct(
-            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)
+            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
         self.act = nn.ReLU(inplace=True)
@@ -204,7 +194,7 @@ class TResNet(nn.Module):
                 # avg pooling before 1x1 conv
                 layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
             layers += [ConvNormAct(
-                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
+                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)]
             downsample = nn.Sequential(*layers)
 
         layers = []
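
Not part of the patch itself: a minimal usage sketch of the consolidated module as it behaves once the diff above is applied. `ConvNormAct`, `BlurPool2d`, and the `ConvNormActAa` alias come from the patch and from existing `timm.layers` exports; the channel counts and input shape are made up for illustration. With `aa_layer` set and `stride > 1`, the conv runs at stride 1 and the anti-aliasing layer performs the downsampling, so output shapes match the non-AA path.

import torch
from timm.layers import BlurPool2d
from timm.layers.conv_bn_act import ConvNormAct, ConvNormActAa  # alias kept by the patch

# No anti-aliasing: aa_layer defaults to None, the 3x3 conv itself strides by 2,
# and forward() skips the aa step because self.aa is None.
block = ConvNormAct(32, 64, kernel_size=3, stride=2)

# With anti-aliasing: the conv runs at stride 1 and BlurPool2d does the stride-2
# downsampling after norm/act (use_aa = aa_layer is not None and stride > 1).
block_aa = ConvNormAct(32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d)

x = torch.randn(1, 32, 56, 56)
print(block(x).shape)     # torch.Size([1, 64, 28, 28])
print(block_aa(x).shape)  # torch.Size([1, 64, 28, 28])

# Old call sites that imported ConvNormActAa keep working; it is now the same class.
assert ConvNormActAa is ConvNormAct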
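
Two behavioral details of the merge worth noting: the gate changes from the old `stride == 2` to `stride > 1`, so any stride greater than 2 now also routes through the aa layer when one is given; and passing `noop=None` to `create_aa` means a disabled aa path is stored as `None` and guarded in `forward()`, rather than inserting an `nn.Identity()` placeholder module.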