From ad94d737b7892fa9d62f9a3392d70fd6dff0a90d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 13 Mar 2023 12:55:12 -0700 Subject: [PATCH] Add support to ConvNextBlock for downsample and channel expansion to improve stand alone use. Fix #1699 --- timm/models/convnext.py | 98 +++++++++++++++++++++++++++++++---------- 1 file changed, 75 insertions(+), 23 deletions(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index d00be5e4..b2e8fce7 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -45,7 +45,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \ +from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg @@ -56,6 +56,28 @@ from ._registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this +class Downsample(nn.Module): + + def __init__(self, in_chs, out_chs, stride=1, dilation=1): + super().__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + + if in_chs != out_chs: + self.conv = create_conv2d(in_chs, out_chs, 1, stride=1) + else: + self.conv = nn.Identity() + + def forward(self, x): + x = self.pool(x) + x = self.conv(x) + return x + + class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: @@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module): Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. - - Args: - in_chs (int): Number of input channels. - drop_path (float): Stochastic depth rate. Default: 0.0 - ls_init_value (float): Init value for Layer Scale. Default: 1e-6. """ def __init__( self, - in_chs, - out_chs=None, - kernel_size=7, - stride=1, - dilation=1, - mlp_ratio=4, - conv_mlp=False, - conv_bias=True, - use_grn=False, - ls_init_value=1e-6, - act_layer='gelu', - norm_layer=None, - drop_path=0., + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = (1, 1), + mlp_ratio: float = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + ls_init_value: Optional[float] = 1e-6, + act_layer: Union[str, Callable] = 'gelu', + norm_layer: Optional[Callable] = None, + drop_path: float = 0., ): + """ + + Args: + in_chs: Block input channels. + out_chs: Block output channels (same as in_chs if None). + kernel_size: Depthwise convolution kernel size. + stride: Stride of depthwise convolution. + dilation: Tuple specifying input and output dilation of block. + mlp_ratio: MLP expansion ratio. + conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True. + conv_bias: Apply bias for all convolution (linear) layers. 
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2) + ls_init_value: Layer-scale init values, layer-scale applied if not None. + act_layer: Activation layer. + norm_layer: Normalization layer (defaults to LN if not specified). + drop_path: Stochastic depth probability. + """ super().__init__() out_chs = out_chs or in_chs + dilation = to_ntuple(2)(dilation) act_layer = get_act_layer(act_layer) if not norm_layer: norm_layer = LayerNorm2d if conv_mlp else LayerNorm mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp) self.use_conv_mlp = conv_mlp self.conv_dw = create_conv2d( - in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias) + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation[0], + depthwise=True, + bias=conv_bias, + ) self.norm = norm_layer(out_chs) self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0]) + else: + self.shortcut = nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): @@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module): if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) - x = self.drop_path(x) + shortcut + x = self.drop_path(x) + self.shortcut(shortcut) return x @@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module): self.downsample = nn.Sequential( norm_layer(in_chs), create_conv2d( - in_chs, out_chs, kernel_size=ds_ks, stride=stride, - dilation=dilation[0], padding=pad, bias=conv_bias), + in_chs, + out_chs, + kernel_size=ds_ks, + stride=stride, + dilation=dilation[0], + padding=pad, + bias=conv_bias, + ), ) in_chs = out_chs else:
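For context, a minimal usage sketch (not part of the patch) of what this change enables: a ConvNeXtBlock used stand-alone with channel expansion and downsampling, with the residual routed through the new Downsample shortcut. It assumes a local timm checkout with this patch applied; the channel counts and input size below are illustrative only.

# Minimal sketch, assuming this patch is applied to a local timm checkout.
import torch
from timm.models.convnext import ConvNeXtBlock

# A stand-alone block that doubles channels and halves resolution in one step.
# Because in_chs != out_chs and stride != 1, the identity path is routed through
# the new Downsample shortcut (avg pool + 1x1 conv) rather than nn.Identity.
block = ConvNeXtBlock(in_chs=64, out_chs=128, stride=2)

x = torch.randn(1, 64, 56, 56)
y = block(x)
print(y.shape)  # expected: torch.Size([1, 128, 28, 28])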