# Copyright (c) OpenMMLab. All rights reserved.
from functools import partial
from itertools import chain
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from mmengine.registry import MODELS

from .base_backbone import BaseBackbone


@MODELS.register_module('LN2d')
class LayerNorm2d(nn.LayerNorm):
    """LayerNorm on channels for 2d images.

    Args:
        num_channels (int): The number of channels of the input tensor.
        eps (float): A value added to the denominator for numerical stability.
            Defaults to 1e-5.
        elementwise_affine (bool): Whether the module has learnable
            per-element affine parameters, initialized to ones (for weights)
            and zeros (for biases). Defaults to True.
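
    Example:
        A minimal, illustrative sketch (the tensor sizes below are
        assumptions, not from the original file); normalization is applied
        over the channel dimension of an ``(N, C, H, W)`` input:

        >>> import torch
        >>> norm = LayerNorm2d(16)
        >>> x = torch.randn(2, 16, 8, 8)
        >>> norm(x).shape
        torch.Size([2, 16, 8, 8])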
    """

    def __init__(self, num_channels: int, **kwargs) -> None:
        super().__init__(num_channels, **kwargs)
        self.num_channels = self.normalized_shape[0]

    def forward(self, x):
        assert x.dim() == 4, 'LayerNorm2d only supports inputs with shape ' \
            f'(N, C, H, W), but got tensor with shape {x.shape}'
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight,
            self.bias, self.eps).permute(0, 3, 1, 2)


class ConvNeXtBlock(BaseModule):
    """ConvNeXt Block.

    Args:
        in_channels (int): The number of input channels.
        dw_conv_cfg (dict): Config of the depthwise convolution.
            Defaults to ``dict(kernel_size=7, padding=3)``.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for the activation between the
            pointwise convolutions. Defaults to ``dict(type='GELU')``.
        mlp_ratio (float): The expansion ratio of the hidden channels in the
            two pointwise convolutions. Defaults to 4.
        linear_pw_conv (bool): Whether to use a linear layer to do the
            pointwise convolutions. More details can be found in the note
            below. Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.

    Note:
        There are two equivalent implementations:

        1. DwConv -> LayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
           all outputs are in (N, C, H, W).
        2. DwConv -> LayerNorm -> Permute to (N, H, W, C) -> Linear -> GELU
           -> Linear; Permute back.

        By default, we use the second one to align with the official
        repository, and it may be slightly faster.
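
    Example:
        An illustrative sketch (the tensor sizes below are assumptions, not
        from the original file); the block is residual and preserves the
        input shape:

        >>> import torch
        >>> block = ConvNeXtBlock(in_channels=64)
        >>> x = torch.randn(1, 64, 56, 56)
        >>> block(x).shape
        torch.Size([1, 64, 56, 56])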
    """

    def __init__(self,
                 in_channels,
                 dw_conv_cfg=dict(kernel_size=7, padding=3),
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 mlp_ratio=4.,
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        self.depthwise_conv = nn.Conv2d(
            in_channels, in_channels, groups=in_channels, **dw_conv_cfg)

        self.linear_pw_conv = linear_pw_conv
        self.norm = build_norm_layer(norm_cfg, in_channels)[1]

        mid_channels = int(mlp_ratio * in_channels)
        if self.linear_pw_conv:
            # Use linear layer to do pointwise conv.
            pw_conv = nn.Linear
        else:
            pw_conv = partial(nn.Conv2d, kernel_size=1)

        self.pointwise_conv1 = pw_conv(in_channels, mid_channels)
        self.act = build_activation_layer(act_cfg)
        self.pointwise_conv2 = pw_conv(mid_channels, in_channels)

        self.gamma = nn.Parameter(
            layer_scale_init_value * torch.ones((in_channels)),
            requires_grad=True) if layer_scale_init_value > 0 else None

        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):

        def _inner_forward(x):
            shortcut = x
            x = self.depthwise_conv(x)
            x = self.norm(x)

            if self.linear_pw_conv:
                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)

            x = self.pointwise_conv1(x)
            x = self.act(x)
            x = self.pointwise_conv2(x)

            if self.linear_pw_conv:
                x = x.permute(0, 3, 1, 2)  # permute back

            if self.gamma is not None:
                x = x.mul(self.gamma.view(1, -1, 1, 1))

            x = shortcut + self.drop_path(x)
            return x

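        # Optional gradient checkpointing: re-run `_inner_forward` during the
        # backward pass instead of caching its intermediate activations,
        # trading extra compute for lower memory use.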
        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x


@MODELS.register_module()
class ConvNeXt(BaseBackbone):
    """ConvNeXt.

    A PyTorch implementation of `A ConvNet for the 2020s
    <https://arxiv.org/pdf/2201.03545.pdf>`_.

    Modified from the `official repo
    <https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convnext.py>`_.

    Args:
        arch (str | dict): The model's architecture. If a string, it should
            be one of the architectures in ``ConvNeXt.arch_settings``. If a
            dict, it should include the following two keys:

            - depths (list[int]): Number of blocks at each stage.
            - channels (list[int]): The number of channels at each stage.

            Defaults to 'tiny'.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_patch_size (int): The size of one patch in the stem layer.
            Defaults to 4.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='LN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for the activation between pointwise
            convolutions. Defaults to ``dict(type='GELU')``.
        linear_pw_conv (bool): Whether to use a linear layer to do the
            pointwise convolutions. Defaults to True.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        frozen_stages (int): Stages to be frozen (all parameters fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_final_norm (bool): Whether to globally average the feature
            map before the final norm layer. In the official repo, it's only
            used in the classification task. Defaults to True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): Initialization config dict.
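
    Example:
        An illustrative sketch of multi-scale feature extraction (the input
        size and the printed shapes below are assumptions based on the
        default ``stem_patch_size=4`` and 224x224 inputs):

        >>> import torch
        >>> model = ConvNeXt(arch='tiny', out_indices=(0, 1, 2, 3),
        ...                  gap_before_final_norm=False)
        >>> x = torch.rand(1, 3, 224, 224)
        >>> for feat in model(x):
        ...     print(feat.shape)
        torch.Size([1, 96, 56, 56])
        torch.Size([1, 192, 28, 28])
        torch.Size([1, 384, 14, 14])
        torch.Size([1, 768, 7, 7])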
    """  # noqa: E501
    arch_settings = {
        'tiny': {
            'depths': [3, 3, 9, 3],
            'channels': [96, 192, 384, 768]
        },
        'small': {
            'depths': [3, 3, 27, 3],
            'channels': [96, 192, 384, 768]
        },
        'base': {
            'depths': [3, 3, 27, 3],
            'channels': [128, 256, 512, 1024]
        },
        'large': {
            'depths': [3, 3, 27, 3],
            'channels': [192, 384, 768, 1536]
        },
        'xlarge': {
            'depths': [3, 3, 27, 3],
            'channels': [256, 512, 1024, 2048]
        },
    }

    def __init__(self,
                 arch='tiny',
                 in_channels=3,
                 stem_patch_size=4,
                 norm_cfg=dict(type='LN2d', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 linear_pw_conv=True,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 with_cp=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'depths' in arch and 'channels' in arch, \
                f'The arch dict must have "depths" and "channels", ' \
                f'but got {list(arch.keys())}.'

        self.depths = arch['depths']
        self.channels = arch['channels']
        assert (isinstance(self.depths, Sequence)
                and isinstance(self.channels, Sequence)
                and len(self.depths) == len(self.channels)), \
            f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
            'should both be sequences with the same length.'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
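        # e.g. with drop_path_rate=0.1 and depths=[3, 3, 9, 3], the 18 blocks
        # get drop-path rates increasing linearly from 0.0 to 0.1.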
        block_idx = 0

        # 4 downsample layers between stages, including the stem layer.
        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.channels[0],
                kernel_size=stem_patch_size,
                stride=stem_patch_size),
            build_norm_layer(norm_cfg, self.channels[0])[1],
        )
        self.downsample_layers.append(stem)

        # 4 feature resolution stages, each consisting of multiple residual
        # blocks
        self.stages = nn.ModuleList()

        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                downsample_layer = nn.Sequential(
                    LayerNorm2d(self.channels[i - 1]),
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2),
                )
                self.downsample_layers.append(downsample_layer)

            stage = Sequential(*[
                ConvNeXtBlock(
                    in_channels=channels,
                    drop_path_rate=dpr[block_idx + j],
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    linear_pw_conv=linear_pw_conv,
                    layer_scale_init_value=layer_scale_init_value,
                    with_cp=with_cp) for j in range(depth)
            ])
            block_idx += depth

            self.stages.append(stage)

            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, channels)[1]
                self.add_module(f'norm{i}', norm_layer)

        self._freeze_stages()

    def forward(self, x):
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
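                # For classification (the default), pool to (N, C, 1, 1)
                # before the norm and return a flattened (N, C) vector;
                # otherwise return the normalized (N, C, H, W) feature map.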
                if self.gap_before_final_norm:
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(norm_layer(gap).flatten(1))
                else:
                    # The output of LayerNorm2d may be discontiguous, which
                    # may cause problems in downstream tasks.
                    outs.append(norm_layer(x).contiguous())

        return tuple(outs)

    def _freeze_stages(self):
        for i in range(self.frozen_stages):
            downsample_layer = self.downsample_layers[i]
            stage = self.stages[i]
            downsample_layer.eval()
            stage.eval()
            for param in chain(downsample_layer.parameters(),
                               stage.parameters()):
                param.requires_grad = False

    def train(self, mode=True):
        super(ConvNeXt, self).train(mode)
        self._freeze_stages()