Add support to ConvNeXtBlock for downsample and channel expansion to improve standalone use. Fix #1699

convnext_shortcut
Ross Wightman 2023-03-13 12:55:12 -07:00
parent 43e6143bef
commit ad94d737b7
1 changed file with 75 additions and 23 deletions


@@ -45,7 +45,7 @@ import torch
 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
+from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
@@ -56,6 +56,28 @@ from ._registry import register_model
 
 __all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this
 
 
+class Downsample(nn.Module):
+
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+        super().__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+        else:
+            self.pool = nn.Identity()
+
+        if in_chs != out_chs:
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+        else:
+            self.conv = nn.Identity()
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+        return x
+
+
 class ConvNeXtBlock(nn.Module):
     """ ConvNeXt Block
     There are two equivalent implementations:
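
The new Downsample module gives the residual path a parameter-light way to match the main path: an average pool when stride or dilation changes resolution, and a 1x1 conv when channel counts differ. A minimal shape sketch, assuming this commit's timm is installed and the class is imported from timm.models.convnext (tensor sizes are illustrative):

import torch
from timm.models.convnext import Downsample

# stride-2 shortcut that also expands channels: 2x2 avg pool, then 1x1 conv
ds = Downsample(in_chs=64, out_chs=128, stride=2)
x = torch.randn(1, 64, 56, 56)
print(ds(x).shape)  # torch.Size([1, 128, 28, 28])

# matching channels at stride 1: both pool and conv resolve to nn.Identity
print(Downsample(64, 64)(x).shape)  # torch.Size([1, 64, 56, 56])
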
@@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
     Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
     choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
     is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
-
-    Args:
-        in_chs (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """
 
     def __init__(
             self,
-            in_chs,
-            out_chs=None,
-            kernel_size=7,
-            stride=1,
-            dilation=1,
-            mlp_ratio=4,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            ls_init_value=1e-6,
-            act_layer='gelu',
-            norm_layer=None,
-            drop_path=0.,
+            in_chs: int,
+            out_chs: Optional[int] = None,
+            kernel_size: int = 7,
+            stride: int = 1,
+            dilation: Union[int, Tuple[int, int]] = (1, 1),
+            mlp_ratio: float = 4,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            ls_init_value: Optional[float] = 1e-6,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Callable] = None,
+            drop_path: float = 0.,
     ):
+        """
+        Args:
+            in_chs: Block input channels.
+            out_chs: Block output channels (same as in_chs if None).
+            kernel_size: Depthwise convolution kernel size.
+            stride: Stride of depthwise convolution.
+            dilation: Tuple specifying input and output dilation of block.
+            mlp_ratio: MLP expansion ratio.
+            conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+            conv_bias: Apply bias for all convolution (linear) layers.
+            use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+            ls_init_value: Layer-scale init values, layer-scale applied if not None.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer (defaults to LN if not specified).
+            drop_path: Stochastic depth probability.
+        """
         super().__init__()
         out_chs = out_chs or in_chs
+        dilation = to_ntuple(2)(dilation)
         act_layer = get_act_layer(act_layer)
         if not norm_layer:
             norm_layer = LayerNorm2d if conv_mlp else LayerNorm
         mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
         self.use_conv_mlp = conv_mlp
         self.conv_dw = create_conv2d(
-            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+            in_chs,
+            out_chs,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation[0],
+            depthwise=True,
+            bias=conv_bias,
+        )
         self.norm = norm_layer(out_chs)
         self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+        else:
+            self.shortcut = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
@@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module):
         if self.gamma is not None:
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))
 
-        x = self.drop_path(x) + shortcut
+        x = self.drop_path(x) + self.shortcut(shortcut)
         return x
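
Together with the constructor changes above, this forward change is what enables standalone use (the #1699 ask): a block can now change resolution and width while keeping its residual connection. A usage sketch under the same assumptions as the earlier example:

import torch
from timm.models.convnext import ConvNeXtBlock

# stride + channel expansion in one block; self.shortcut pools and projects
# the identity path so the residual add lines up
block = ConvNeXtBlock(in_chs=96, out_chs=192, stride=2)
x = torch.randn(2, 96, 32, 32)
print(block(x).shape)  # torch.Size([2, 192, 16, 16])

One caveat worth noting: out_chs must remain divisible by in_chs in this configuration, since conv_dw is built with depthwise=True (groups == in_chs).
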
@@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module):
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
                 create_conv2d(
-                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
-                    dilation=dilation[0], padding=pad, bias=conv_bias),
+                    in_chs,
+                    out_chs,
+                    kernel_size=ds_ks,
+                    stride=stride,
+                    dilation=dilation[0],
+                    padding=pad,
+                    bias=conv_bias,
+                ),
             )
             in_chs = out_chs
         else:
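
Compatibility note: with the default arguments (out_chs=None, stride=1, equal dilations) the condition guarding self.shortcut is false, so it resolves to nn.Identity and existing stage-driven ConvNeXt models are unchanged; stages still downsample via the learned norm + strided-conv path above. A quick check of that reading of the diff:

import torch.nn as nn
from timm.models.convnext import ConvNeXtBlock

block = ConvNeXtBlock(in_chs=96)  # defaults: out_chs == in_chs, stride=1, dilation=(1, 1)
assert isinstance(block.shortcut, nn.Identity)  # residual path identical to pre-commit behavior
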