mmpretrain/projects/starnet_backbone/mmpretrain/models/backbones/starnet.py

316 lines
10 KiB
Python

from typing import Optional, Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
build_norm_layer)
from mmengine.model import BaseModule, Sequential
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
class Block(BaseModule):
    """StarNet Block.

    The block applies a depthwise 7x7 convolution, two parallel pointwise
    expansions whose outputs are combined by element-wise multiplication
    (one branch is activated first), a pointwise projection back to
    ``in_channels``, a second depthwise 7x7 convolution and finally a
    residual connection with optional stochastic depth.

    Args:
        in_channels (int): The number of input channels.
        mlp_ratio (float): The expansion ratio in both pointwise convolution.
            The hidden width is ``int(mlp_ratio * in_channels)``.
            Defaults to 3.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        conv_cfg (dict): Config dict for convolution layer.
            Defaults to ``dict(type='Conv2d')``.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict, optional): The config dict for activation between
            pointwise convolution. ``None`` disables the activation.
            Defaults to ``dict(type='ReLU6')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        in_channels,
        mlp_ratio: float = 3.,
        drop_path: float = 0.,
        conv_cfg: Optional[dict] = dict(type='Conv2d'),
        norm_cfg: Optional[dict] = dict(type='BN'),
        act_cfg: Optional[dict] = dict(type='ReLU6'),
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        # ``mlp_ratio`` is a float, but a convolution needs an integer
        # channel count; cast once and reuse it for fc1, fc2 and g.
        hidden_channels = int(mlp_ratio * in_channels)

        # NOTE: every ConvModule below is built with ``act_cfg=None``.
        # MMCV's ConvModule documents ``act_cfg=dict(type='ReLU')`` as its
        # default, which would silently add a ReLU after each convolution.
        # The only activation of this block is ``self.act`` applied to the
        # fc1 branch in ``forward``.
        self.dwconv = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=7,
            stride=1,
            padding=(7 - 1) // 2,
            groups=in_channels,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None,
        )
        self.fc1 = ConvModule(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=None,
            act_cfg=None,
        )
        self.fc2 = ConvModule(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=None,
            act_cfg=None,
        )
        self.g = ConvModule(
            in_channels=hidden_channels,
            out_channels=in_channels,
            kernel_size=1,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None,
        )
        self.dwconv2 = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=7,
            stride=1,
            padding=(7 - 1) // 2,
            groups=in_channels,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=None,
            act_cfg=None,
        )

        # ``act_cfg`` is annotated Optional, so tolerate ``None``.
        if act_cfg is None:
            self.act = nn.Identity()
        else:
            self.act = build_activation_layer(act_cfg)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block and add the result to the input.

        Args:
            x (torch.Tensor): Input feature map with ``in_channels``
                channels.

        Returns:
            torch.Tensor: ``x`` plus the (optionally dropped) block output.
        """
        identity = x
        x = self.dwconv(x)
        x1, x2 = self.fc1(x), self.fc2(x)
        # Element-wise product of the activated branch and the linear one.
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        x = identity + self.drop_path(x)
        return x
@MODELS.register_module()
class StarNet(BaseBackbone):
    """StarNet.

    A PyTorch implementation of StarNet introduced by:
    `Rewrite the Stars <https://arxiv.org/abs/2403.19967>`_

    Modified from the `official repo
    <https://github.com/ma-xu/Rewrite-the-Stars?tab=readme-ov-file>`_.

    Args:
        arch (str | dict): The model's architecture. Either one of the
            keys of ``arch_settings`` or a dict that includes the
            following two keys:

            - layers (list[int]): Number of blocks at each stage.
            - embed_dims (list[int]): The number of channels at each stage.

            Defaults to 's1'.
        in_channels (int): Number of input image channels. Default: 3.
        out_channels (int): Output channels of the stem layer. Default: 32.
        out_indices (Sequence | int): Output from which stages.
            Negative values index from the last stage.
            Defaults to -1, means the last stage.
        frozen_stages (int): Number of leading stages to freeze (all param
            fixed). When positive, the stem is frozen as well.
            Defaults to 0, which means not freezing any parameters.
        mlp_ratio (float): The expansion ratio in pointwise convolution.
            Defaults to 4.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        conv_cfg (dict): Config dict for convolution layer.
            Defaults to ``dict(type='Conv2d')``.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='ReLU6')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to Kaiming init for ``Conv2d`` and constant 1 for
            ``_BatchNorm`` layers.
    """

    arch_settings = {
        's1': {
            'layers': [2, 2, 8, 3],
            'embed_dims': [24, 48, 96, 192],
        },
        's2': {
            'layers': [1, 2, 6, 2],
            'embed_dims': [32, 64, 128, 256],
        },
        's3': {
            'layers': [2, 2, 8, 4],
            'embed_dims': [32, 64, 128, 256],
        },
        's4': {
            'layers': [3, 3, 12, 5],
            'embed_dims': [32, 64, 128, 256],
        },
        's050': {
            'layers': [1, 1, 3, 1],
            'embed_dims': [16, 32, 64, 128],
        },
        's100': {
            'layers': [1, 2, 4, 1],
            'embed_dims': [20, 40, 80, 160],
        },
        's150': {
            'layers': [1, 2, 4, 2],
            'embed_dims': [24, 48, 96, 192],
        }
    }

    def __init__(
        self,
        arch='s1',
        in_channels: int = 3,
        out_channels: int = 32,
        out_indices=-1,
        frozen_stages=0,
        mlp_ratio: float = 4.,
        drop_path_rate: float = 0.,
        conv_cfg: Optional[dict] = dict(type='Conv2d'),
        norm_cfg: Optional[dict] = dict(type='BN'),
        act_cfg: Optional[dict] = dict(type='ReLU6'),
        init_cfg=[
            dict(type='Kaiming', layer=['Conv2d']),
            dict(type='Constant', val=1, layer=['_BatchNorm'])
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'layers' in arch and 'embed_dims' in arch, \
                f'The arch dict must have "layers" and "embed_dims", ' \
                f'but got {list(arch.keys())}.'
        else:
            # Fail early instead of with an obscure subscript error below.
            raise TypeError(
                f'"arch" must be a str or a dict, but got {type(arch)}.')

        self.layers = arch['layers']
        self.embed_dims = arch['embed_dims']
        assert len(self.layers) == len(self.embed_dims), \
            f'"layers" and "embed_dims" must have the same length, ' \
            f'but got {len(self.layers)} and {len(self.embed_dims)}.'
        self.num_stages = len(self.layers)
        self.mlp_ratio = mlp_ratio
        self.drop_path_rate = drop_path_rate
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Kept on the instance so that ``_make_stage`` builds every layer
        # with the same layer configs as the stem.
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.stem = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # One stochastic depth rate per block, growing linearly from 0 to
        # ``drop_path_rate`` over all blocks of the network.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.layers))
        ]

        # Stages are registered as ``stage{i}`` sub-modules; ``self.stages``
        # only stores their names, in order.
        self.stages = []
        cur = 0  # index into ``dpr`` of the first block of the stage
        for i in range(self.num_stages):
            stage = self._make_stage(
                planes=self.out_channels,
                num_blocks=self.layers[i],
                cur=cur,
                dpr=dpr,
                stages_num=i)
            # ``self.out_channels`` doubles as the running input width of
            # the next stage; after the loop it equals ``embed_dims[-1]``.
            self.out_channels = self.embed_dims[i]
            cur += self.layers[i]
            stage_name = f'stage{i}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            # Valid stage indices are 0 .. num_stages - 1 (strict upper
            # bound); anything else cannot be mapped to a stage.
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        # One norm layer, registered as ``norm{i}``, per output stage.
        for i_layer in self.out_indices:
            layer = build_norm_layer(norm_cfg, self.embed_dims[i_layer])[1]
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self.frozen_stages = frozen_stages
        self._freeze_stages()

    def _make_stage(self, planes, num_blocks, cur, dpr, stages_num):
        """Build one stage: a stride-2 down-sampler followed by blocks.

        Args:
            planes (int): Number of input channels of the stage.
            num_blocks (int): Number of ``Block`` modules in the stage.
            cur (int): Index into ``dpr`` of the first block of the stage.
            dpr (list[float]): Stochastic depth rate of every block of the
                network.
            stages_num (int): Index of the stage, used to look up its width
                in ``self.embed_dims``.

        Returns:
            Sequential: The down-sampler followed by ``num_blocks`` blocks.
        """
        stage_channels = self.embed_dims[stages_num]
        down_sampler = ConvModule(
            in_channels=planes,
            out_channels=stage_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=True,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            # MMCV's ConvModule documents ``act_cfg=dict(type='ReLU')`` as
            # its default; the down-sampler is conv + norm only.
            act_cfg=None,
        )
        blocks = [
            Block(
                in_channels=stage_channels,
                mlp_ratio=self.mlp_ratio,
                drop_path=dpr[cur + i],
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
            ) for i in range(num_blocks)
        ]
        return Sequential(down_sampler, *blocks)

    def forward(self, x):
        """Run the stem and all stages.

        Args:
            x (torch.Tensor): Input passed to the stem.

        Returns:
            tuple[torch.Tensor]: The normalized output of every stage whose
            index is in ``self.out_indices``, in stage order.
        """
        x = self.stem(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                # The norm is applied to the returned feature only; the
                # next stage consumes the un-normalized ``x``.
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)
        return tuple(outs)

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` stages.

        Frozen modules are switched to eval mode and their parameters stop
        requiring gradients. Nothing is frozen when ``frozen_stages`` is 0
        or negative.
        """
        # Strictly positive: ``frozen_stages=0`` is documented as "no
        # freezing", so it must leave the stem trainable as well.
        if self.frozen_stages > 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False

        for i in range(self.frozen_stages):
            stage_layer = getattr(self, f'stage{i}')
            stage_layer.eval()
            for param in stage_layer.parameters():
                param.requires_grad = False
            # The output norm of a frozen stage is frozen along with it.
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Set the training mode while keeping frozen stages frozen.

        Args:
            mode (bool): Whether to set training mode. Defaults to True.

        Returns:
            StarNet: ``self``, matching ``torch.nn.Module.train``.
        """
        super(StarNet, self).train(mode)
        # ``train()`` flips every sub-module; re-apply the freeze so frozen
        # stages stay in eval mode.
        self._freeze_stages()
        return self