Add support to ConvNeXtBlock for downsample and channel expansion to improve stand-alone use. Fix #1699
parent 43e6143bef
commit ad94d737b7
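For context, a minimal stand-alone usage sketch of what this change enables; it assumes a timm build that already contains this commit, and the shapes are illustrative only:

# Hedged sketch: a stand-alone ConvNeXtBlock that expands channels and downsamples,
# which is what this commit adds. Assumes timm with this change installed.
import torch
from timm.models.convnext import ConvNeXtBlock

# out_chs != in_chs and stride > 1 now work because the block builds its own
# Downsample shortcut (see diff below). Note out_chs should stay a multiple of
# in_chs since conv_dw is depthwise.
block = ConvNeXtBlock(in_chs=64, out_chs=128, stride=2)
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)  # expected: torch.Size([2, 128, 28, 28])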
@@ -45,7 +45,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
+from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
@@ -56,6 +56,28 @@ from ._registry import register_model
 __all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this


+class Downsample(nn.Module):
+
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+        super().__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+        else:
+            self.pool = nn.Identity()
+
+        if in_chs != out_chs:
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+        else:
+            self.conv = nn.Identity()
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+        return x
+
+
 class ConvNeXtBlock(nn.Module):
     """ ConvNeXt Block
     There are two equivalent implementations:
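A quick shape check for the new Downsample shortcut above; Downsample is module-level in timm.models.convnext but not exported via __all__, so this import path is an assumption about how you would reach it:

# Hedged sketch: the shortcut path is an avg-pool (only when striding/dilating)
# followed by an optional 1x1 conv to match channels.
import torch
from timm.models.convnext import Downsample

ds = Downsample(in_chs=64, out_chs=128, stride=2)   # AvgPool2d(2, 2) then 1x1 conv
same = Downsample(in_chs=64, out_chs=64, stride=1)  # both pool and conv are Identity

x = torch.randn(1, 64, 56, 56)
print(ds(x).shape)    # torch.Size([1, 128, 28, 28])
print(same(x).shape)  # torch.Size([1, 64, 56, 56])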
@@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
     Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
     choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
     is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
-
-    Args:
-        in_chs (int): Number of input channels.
-        drop_path (float): Stochastic depth rate. Default: 0.0
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """

     def __init__(
             self,
-            in_chs,
-            out_chs=None,
-            kernel_size=7,
-            stride=1,
-            dilation=1,
-            mlp_ratio=4,
-            conv_mlp=False,
-            conv_bias=True,
-            use_grn=False,
-            ls_init_value=1e-6,
-            act_layer='gelu',
-            norm_layer=None,
-            drop_path=0.,
+            in_chs: int,
+            out_chs: Optional[int] = None,
+            kernel_size: int = 7,
+            stride: int = 1,
+            dilation: Union[int, Tuple[int, int]] = (1, 1),
+            mlp_ratio: float = 4,
+            conv_mlp: bool = False,
+            conv_bias: bool = True,
+            use_grn: bool = False,
+            ls_init_value: Optional[float] = 1e-6,
+            act_layer: Union[str, Callable] = 'gelu',
+            norm_layer: Optional[Callable] = None,
+            drop_path: float = 0.,
     ):
+        """
+
+        Args:
+            in_chs: Block input channels.
+            out_chs: Block output channels (same as in_chs if None).
+            kernel_size: Depthwise convolution kernel size.
+            stride: Stride of depthwise convolution.
+            dilation: Tuple specifying input and output dilation of block.
+            mlp_ratio: MLP expansion ratio.
+            conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+            conv_bias: Apply bias for all convolution (linear) layers.
+            use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+            ls_init_value: Layer-scale init values, layer-scale applied if not None.
+            act_layer: Activation layer.
+            norm_layer: Normalization layer (defaults to LN if not specified).
+            drop_path: Stochastic depth probability.
+        """
         super().__init__()
         out_chs = out_chs or in_chs
+        dilation = to_ntuple(2)(dilation)
         act_layer = get_act_layer(act_layer)
         if not norm_layer:
             norm_layer = LayerNorm2d if conv_mlp else LayerNorm
         mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
         self.use_conv_mlp = conv_mlp
         self.conv_dw = create_conv2d(
-            in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
+            in_chs,
+            out_chs,
+            kernel_size=kernel_size,
+            stride=stride,
+            dilation=dilation[0],
+            depthwise=True,
+            bias=conv_bias,
+        )
         self.norm = norm_layer(out_chs)
         self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+        else:
+            self.shortcut = nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

     def forward(self, x):
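One detail worth noting in the new __init__: dilation may be passed as a scalar or a 2-tuple, and to_ntuple(2) normalizes it before dilation[0] / dilation[1] are indexed. A tiny sketch of that helper (already exported from timm.layers):

# to_ntuple(2) repeats a scalar into a 2-tuple and passes iterables through as tuples.
from timm.layers import to_ntuple

print(to_ntuple(2)(1))       # (1, 1)
print(to_ntuple(2)((1, 2)))  # (1, 2)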
@@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module):
         if self.gamma is not None:
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))

-        x = self.drop_path(x) + shortcut
+        x = self.drop_path(x) + self.shortcut(shortcut)
         return x

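Because self.shortcut is nn.Identity whenever in_chs == out_chs, stride == 1 and the dilations match, the residual sum above reduces to the previous x = self.drop_path(x) + shortcut, so default and pretrained configurations should behave as before. A small sanity sketch, again assuming a timm build with this commit:

# Default blocks keep an Identity shortcut; expanding/striding blocks get Downsample.
import torch.nn as nn
from timm.models.convnext import ConvNeXtBlock

assert isinstance(ConvNeXtBlock(in_chs=96).shortcut, nn.Identity)
print(type(ConvNeXtBlock(in_chs=96, out_chs=192, stride=2).shortcut).__name__)  # Downsample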
@@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module):
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
                 create_conv2d(
-                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
-                    dilation=dilation[0], padding=pad, bias=conv_bias),
+                    in_chs,
+                    out_chs,
+                    kernel_size=ds_ks,
+                    stride=stride,
+                    dilation=dilation[0],
+                    padding=pad,
+                    bias=conv_bias,
+                ),
             )
             in_chs = out_chs
         else: