Add support to ConvNeXtBlock for downsample and channel expansion to improve standalone use. Fix #1699
parent 43e6143bef
commit ad94d737b7
@@ -45,7 +45,7 @@ import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
+from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
@@ -56,6 +56,28 @@ from ._registry import register_model
 __all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this
 
 
+class Downsample(nn.Module):
+
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+        super().__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+        else:
+            self.pool = nn.Identity()
+
+        if in_chs != out_chs:
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+        else:
+            self.conv = nn.Identity()
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+        return x
+
+
 class ConvNeXtBlock(nn.Module):
     """ ConvNeXt Block
     There are two equivalent implementations:
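For context, here is a minimal sketch of what the new Downsample shortcut does on its own: a 2x2 average pool when striding, then a 1x1 conv only when the channel count changes. The import path assumes a timm version that already includes this change; shapes are illustrative.

# Sketch only; assumes a timm build containing the Downsample class added above.
import torch
from timm.models.convnext import Downsample

ds = Downsample(in_chs=64, out_chs=128, stride=2)
x = torch.randn(1, 64, 56, 56)
print(ds(x).shape)  # expected: torch.Size([1, 128, 28, 28]) -- avg pool halves H/W, 1x1 conv expands channels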
@@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
     Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
     choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
     is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
-
-    Args:
-        in_chs (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """
 
     def __init__(
             self,
-            in_chs,
-            out_chs=None,
-            kernel_size=7,
-            stride=1,
-            dilation=1,
-            mlp_ratio=4,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            ls_init_value=1e-6,
-            act_layer='gelu',
-            norm_layer=None,
-            drop_path=0.,
+            in_chs: int,
+            out_chs: Optional[int] = None,
+            kernel_size: int = 7,
+            stride: int = 1,
+            dilation: Union[int, Tuple[int, int]] = (1, 1),
+            mlp_ratio: float = 4,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            ls_init_value: Optional[float] = 1e-6,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Callable] = None,
+            drop_path: float = 0.,
     ):
+        """
+
+        Args:
+            in_chs: Block input channels.
+            out_chs: Block output channels (same as in_chs if None).
+            kernel_size: Depthwise convolution kernel size.
+            stride: Stride of depthwise convolution.
+            dilation: Tuple specifying input and output dilation of block.
+            mlp_ratio: MLP expansion ratio.
+            conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+            conv_bias: Apply bias for all convolution (linear) layers.
+            use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+            ls_init_value: Layer-scale init values, layer-scale applied if not None.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer (defaults to LN if not specified).
+            drop_path: Stochastic depth probability.
+        """
         super().__init__()
         out_chs = out_chs or in_chs
+        dilation = to_ntuple(2)(dilation)
         act_layer = get_act_layer(act_layer)
         if not norm_layer:
             norm_layer = LayerNorm2d if conv_mlp else LayerNorm
         mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
         self.use_conv_mlp = conv_mlp
         self.conv_dw = create_conv2d(
-            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+            in_chs,
+            out_chs,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation[0],
+            depthwise=True,
+            bias=conv_bias,
+        )
         self.norm = norm_layer(out_chs)
         self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+        else:
+            self.shortcut = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
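The constructor now normalizes dilation to a 2-tuple and only inserts the Downsample shortcut when the residual shapes would otherwise mismatch. A rough restatement of that selection logic follows; needs_downsample() is an illustrative helper, not part of timm's API.

# Illustrative re-statement of the shortcut condition above; not a timm function.
from timm.layers import to_ntuple

def needs_downsample(in_chs, out_chs, stride=1, dilation=1):
    dilation = to_ntuple(2)(dilation)  # int -> (d, d); a 2-tuple passes through unchanged
    return in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]

print(needs_downsample(96, 96))             # False -> self.shortcut = nn.Identity()
print(needs_downsample(96, 192, stride=2))  # True  -> self.shortcut = Downsample(96, 192, stride=2, dilation=1)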
@@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module):
         if self.gamma is not None:
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))
 
-        x = self.drop_path(x) + shortcut
+        x = self.drop_path(x) + self.shortcut(shortcut)
         return x
 
 
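With the shortcut applied on the residual path, a ConvNeXtBlock can now change channels and stride on its own, which is the standalone use case from #1699. A usage sketch, again assuming a timm build that includes this commit:

# Standalone-use sketch; kwargs follow the new ConvNeXtBlock signature above.
import torch
from timm.models.convnext import ConvNeXtBlock

block = ConvNeXtBlock(in_chs=96, out_chs=192, stride=2)  # channel expansion + spatial downsample
x = torch.randn(2, 96, 56, 56)
print(block(x).shape)  # expected: torch.Size([2, 192, 28, 28]); residual passes through the Downsample shortcut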
@@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module):
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
                 create_conv2d(
-                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
-                    dilation=dilation[0], padding=pad, bias=conv_bias),
+                    in_chs,
+                    out_chs,
+                    kernel_size=ds_ks,
+                    stride=stride,
+                    dilation=dilation[0],
+                    padding=pad,
+                    bias=conv_bias,
+                ),
             )
             in_chs = out_chs
         else: