Add files via upload

pull/1908/head
bear-coder-9527 2024-06-13 14:49:19 +08:00 committed by GitHub
parent 17a886cb58
commit e7dcde9631
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 453 additions and 0 deletions

@ -0,0 +1,118 @@
# Example Project
This is an example README for community `projects/`. You can write your own README for your project. Below are
some recommended sections that help others understand and use your project; copy or modify them as your
project requires.
## Usage
### Setup Environment
Please refer to [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) to install
MMPreTrain.
First, add the current folder to `PYTHONPATH` so that Python can find your code. Run the following command in the current directory to add it.
> Please run it every time you open a new shell.
```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
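To confirm that the export took effect, a small check can be run from the same shell. This is a minimal sketch using only the standard library; the helper name `is_on_pythonpath` is hypothetical and not part of MMPreTrain.

```python
import os


def is_on_pythonpath(directory: str, pythonpath: str) -> bool:
    """Return True if `directory` is one of the entries in `pythonpath`."""
    target = os.path.abspath(directory)
    # PYTHONPATH entries are separated by os.pathsep (":" on Linux/macOS).
    entries = [p for p in pythonpath.split(os.pathsep) if p]
    return any(os.path.abspath(p) == target for p in entries)


if __name__ == "__main__":
    # Compare the current directory with the PYTHONPATH of this shell.
    print(is_on_pythonpath(os.getcwd(), os.environ.get("PYTHONPATH", "")))
```

If this prints `False` in a new shell, the `export` command has not been run there yet.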
### Data Preparation
Prepare the ImageNet-2012 dataset according to the [instructions](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#imagenet).
### Training commands
**To train with single GPU:**
```bash
mim train mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py
```
**To train with multiple GPUs:**
```bash
mim train mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --launcher pytorch --gpus 8
```
**To train with multiple GPUs with Slurm:**
```bash
mim train mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --launcher slurm \
--gpus 16 --gpus-per-node 8 --partition $PARTITION
```
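The total batch size changes between the single-GPU, 8-GPU and 16-GPU commands. The arithmetic can be sketched as below, assuming one process per GPU, no gradient accumulation, and a per-GPU batch size of 24 (the value set by `train_dataloader` in this project's config). The helper names are hypothetical.

```python
def total_batch_size(num_gpus: int, batch_size_per_gpu: int) -> int:
    """Samples consumed per optimizer step with one process per GPU."""
    return num_gpus * batch_size_per_gpu


def num_nodes(gpus: int, gpus_per_node: int) -> int:
    """Nodes needed for `--gpus` GPUs with `--gpus-per-node` GPUs per node."""
    return -(-gpus // gpus_per_node)  # ceiling division


for gpus in (1, 8, 16):
    print(gpus, total_batch_size(gpus, 24))
# 1 24
# 8 192
# 16 384
print(num_nodes(16, 8))  # 2
```

The base schedule is named `imagenet_bs256`, so it may be worth checking whether the learning rate needs adjusting when the total batch size differs from 256.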
### Testing commands
**To test with single GPU:**
```bash
mim test mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --checkpoint $CHECKPOINT
```
**To test with multiple GPUs:**
```bash
mim test mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
```
**To test with multiple GPUs with Slurm:**
```bash
mim test mmpretrain configs/starnet/startnet_s1_32xb32_in1k.py --checkpoint $CHECKPOINT --launcher slurm \
--gpus 16 --gpus-per-node 8 --partition $PARTITION
```
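For scripted evaluation over several checkpoints, the commands above can be assembled programmatically. The sketch below only builds the argument list from the flags already shown in this README; `build_mim_test_command` is a hypothetical helper, and `work_dirs/latest.pth` is a placeholder checkpoint path.

```python
import shlex
from typing import List, Optional


def build_mim_test_command(config: str,
                           checkpoint: str,
                           gpus: Optional[int] = None) -> List[str]:
    """Build the `mim test` argument list for single- or multi-GPU testing."""
    cmd = ["mim", "test", "mmpretrain", config, "--checkpoint", checkpoint]
    if gpus is not None and gpus > 1:
        # Multi-GPU testing uses the PyTorch launcher, as in the README.
        cmd += ["--launcher", "pytorch", "--gpus", str(gpus)]
    return cmd


cmd = build_mim_test_command(
    "configs/starnet/startnet_s1_32xb32_in1k.py",
    "work_dirs/latest.pth",  # placeholder path
    gpus=8,
)
print(shlex.join(cmd))
```

The returned list can be passed to `subprocess.run(cmd, check=True)` if the command should actually be executed.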
## Citation
<!-- Replace with the citation of the paper your project refers to. -->
```BibTeX
@misc{2023mmpretrain,
title={OpenMMLab's Pre-training Toolbox and Benchmark},
author={MMPreTrain Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmpretrain}},
year={2023}
}
```
## Checklist
Here is a checklist of this project's progress. You can ignore this part if you don't plan to contribute
to MMPreTrain projects.
- [ ] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
  - [ ] Finish the code
    <!-- The code's design shall follow existing interfaces and conventions. For example, each model component should be registered into `mmpretrain.registry.MODELS` and be configurable via a config file. -->
  - [ ] Basic docstrings & proper citation
    <!-- Each major class should contain a docstring describing its functionality and arguments. If your code is copied or modified from other open-source projects, don't forget to cite the source project in the docstring and make sure your use does not violate its license. Typically, we do not accept any code snippet under the GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd) -->
  - [ ] Converted checkpoint and results (Only for reproduction)
    <!-- If you are reproducing the result from a paper, make sure the model in the project can match those results. Also please provide checkpoint links or a checkpoint conversion script for others to get the pre-trained model. -->
- [ ] Milestone 2: Indicates a successful model implementation.
  - [ ] Training results
    <!-- If you are reproducing the result from a paper, train your model from scratch and verify that the final result matches the original result. Usually, ±0.1% is acceptable for the image classification task on ImageNet-1k. -->
- [ ] Milestone 3: Good to be a part of our core package!
  - [ ] Unit tests
    <!-- Unit tests for the major modules are required. [Example](https://github.com/open-mmlab/mmpretrain/blob/main/tests/test_models/test_backbones/test_vision_transformer.py) -->
  - [ ] Code style
    <!-- Refactor your code according to the reviewers' comments. -->
  - [ ] `metafile.yml` and `README.md`
    <!-- It will be used by MMPreTrain to acquire your models. [Example](https://github.com/open-mmlab/mmpretrain/blob/main/configs/mvit/metafile.yml). In particular, you may have to refactor this README into a standard one. [Example](https://github.com/open-mmlab/mmpretrain/blob/main/configs/swin_transformer/README.md) -->
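
The ±0.1% tolerance mentioned under Milestone 2 can be checked mechanically. A minimal sketch follows; the function name and the accuracy values are placeholders, not results from this project.

```python
def matches_reference(reproduced_top1: float,
                      reference_top1: float,
                      tolerance: float = 0.1) -> bool:
    """Return True if two top-1 accuracies (in percent) agree within `tolerance`."""
    # A tiny epsilon absorbs floating-point rounding at the boundary.
    return abs(reproduced_top1 - reference_top1) <= tolerance + 1e-9


print(matches_reference(80.05, 80.0))  # True
print(matches_reference(79.5, 80.0))   # False
```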

@ -0,0 +1,11 @@
# Model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='StarNet', arch='s1'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,  # ImageNet-1k, matching the dataset config
        in_channels=192,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))

@ -0,0 +1,7 @@
_base_ = [
    '../_base_/models/starnet/starnet_s1.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py', '../_base_/default_runtime.py'
]
train_dataloader = dict(batch_size=24)
val_dataloader = dict(batch_size=24)

@ -0,0 +1,317 @@
from typing import Optional, Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
                             build_norm_layer)
from mmengine.model import BaseModule, Sequential

from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS


class Block(BaseModule):
    """StarNet Block.

    Args:
        in_channels (int): The number of input channels.
        mlp_ratio (float): The expansion ratio of the two parallel pointwise
            convolutions. Defaults to 3.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for the activation between the
            pointwise convolutions. Defaults to ``dict(type='ReLU6')``.
        conv_cfg (dict): Config dict for convolution layer.
            Defaults to ``dict(type='Conv2d')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(
        self,
        in_channels: int,
        mlp_ratio: float = 3.,
        drop_path: float = 0.,
        conv_cfg: Optional[dict] = dict(type='Conv2d'),
        norm_cfg: Optional[dict] = dict(type='BN'),
        act_cfg: Optional[dict] = dict(type='ReLU6'),
        init_cfg: Optional[dict] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        # Channel counts must be integers, while ``mlp_ratio`` may be a float.
        hidden_channels = int(mlp_ratio * in_channels)
        # ``ConvModule`` appends a ReLU unless ``act_cfg=None``. The official
        # block has no activation inside its conv layers, so the only
        # activation here is ``self.act``, applied in ``forward``.
        self.dwconv = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=7,
            stride=1,
            padding=(7 - 1) // 2,
            groups=in_channels,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None,
        )
        self.fc1 = ConvModule(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=None,
            act_cfg=None,
        )
        self.fc2 = ConvModule(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=None,
            act_cfg=None,
        )
        self.g = ConvModule(
            in_channels=hidden_channels,
            out_channels=in_channels,
            kernel_size=1,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None,
        )
        self.dwconv2 = ConvModule(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=7,
            stride=1,
            padding=(7 - 1) // 2,
            groups=in_channels,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=None,
            act_cfg=None,
        )
        self.act = build_activation_layer(act_cfg)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        x = self.dwconv(x)
        # Star operation: element-wise product of two pointwise branches.
        x1, x2 = self.fc1(x), self.fc2(x)
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        x = identity + self.drop_path(x)
        return x


@MODELS.register_module()
class StarNet(BaseBackbone):
    """StarNet.

    A PyTorch implementation of StarNet introduced by:
    `Rewrite the Stars <https://arxiv.org/abs/2403.19967>`_

    Modified from the `official repo
    <https://github.com/ma-xu/Rewrite-the-Stars?tab=readme-ov-file>`_.

    Args:
        arch (str | dict): The model's architecture. If a string, it should
            be one of the architectures in ``StarNet.arch_settings``. If a
            dict, it should include the following two keys:

            - layers (list[int]): Number of blocks at each stage.
            - embed_dims (list[int]): The number of channels at each stage.

            Defaults to 's1'.
        in_channels (int): Number of input image channels. Default: 3.
        out_channels (int): Output channels of the stem layer. Default: 32.
        mlp_ratio (float): The expansion ratio in pointwise convolution.
            Defaults to 4.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for the activation between the
            pointwise convolutions. Defaults to ``dict(type='ReLU6')``.
        conv_cfg (dict): Config dict for convolution layer.
            Defaults to ``dict(type='Conv2d')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
    """
    arch_settings = {
        's1': {
            'layers': [2, 2, 8, 3],
            'embed_dims': [24, 48, 96, 192],
        },
        's2': {
            'layers': [1, 2, 6, 2],
            'embed_dims': [32, 64, 128, 256],
        },
        's3': {
            'layers': [2, 2, 8, 4],
            'embed_dims': [32, 64, 128, 256],
        },
        's4': {
            'layers': [3, 3, 12, 5],
            'embed_dims': [32, 64, 128, 256],
        },
        's050': {
            'layers': [1, 1, 3, 1],
            'embed_dims': [16, 32, 64, 128],
        },
        's100': {
            'layers': [1, 2, 4, 1],
            'embed_dims': [20, 40, 80, 160],
        },
        's150': {
            'layers': [1, 2, 4, 2],
            'embed_dims': [24, 48, 96, 192],
        }
    }

    def __init__(
        self,
        arch='s1',
        in_channels: int = 3,
        out_channels: int = 32,
        out_indices=-1,
        frozen_stages=0,
        mlp_ratio: float = 4.,
        drop_path_rate: float = 0.,
        conv_cfg: Optional[dict] = dict(type='Conv2d'),
        norm_cfg: Optional[dict] = dict(type='BN'),
        act_cfg: Optional[dict] = dict(type='ReLU6'),
        init_cfg=[
            dict(type='Kaiming', layer=['Conv2d']),
            dict(type='Constant', val=1, layer=['_BatchNorm'])
        ]
    ) -> None:
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'layers' in arch and 'embed_dims' in arch, \
                f'The arch dict must have "layers" and "embed_dims", ' \
                f'but got {list(arch.keys())}.'

        self.layers = arch['layers']
        self.embed_dims = arch['embed_dims']
        depth = len(self.layers)
        self.num_stages = len(self.layers)
        self.mlp_ratio = mlp_ratio
        self.drop_path_rate = drop_path_rate
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.stem = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        # Stochastic depth rate grows linearly over all blocks.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.layers))
        ]

        self.stages = []
        cur = 0
        for i in range(depth):
            stage = self._make_stage(
                planes=self.out_channels,
                num_blocks=self.layers[i],
                cur=cur,
                dpr=dpr,
                stages_num=i)
            # The next stage takes this stage's output width as its input.
            self.out_channels = self.embed_dims[i]
            cur += self.layers[i]
            stage_name = f'stage{i}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            # Valid stage indices are 0 .. num_stages - 1.
            assert 0 <= out_indices[i] < self.num_stages, \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        if self.out_indices:
            for i_layer in self.out_indices:
                layer = build_norm_layer(norm_cfg,
                                         self.embed_dims[i_layer])[1]
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)

        self.frozen_stages = frozen_stages
        self._freeze_stages()

    def _make_stage(self, planes, num_blocks, cur, dpr, stages_num):
        # Strided 3x3 conv + BN halves the resolution. ``act_cfg=None``
        # disables the ReLU that ``ConvModule`` would otherwise append.
        down_sampler = ConvModule(
            in_channels=planes,
            out_channels=self.embed_dims[stages_num],
            kernel_size=3,
            stride=2,
            padding=1,
            bias=True,
            conv_cfg=None,
            norm_cfg=dict(type='BN'),
            act_cfg=None,
        )

        blocks = [
            Block(
                in_channels=self.embed_dims[stages_num],
                mlp_ratio=self.mlp_ratio,
                drop_path=dpr[cur + i],
            ) for i in range(num_blocks)
        ]

        return Sequential(down_sampler, *blocks)

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)
        return tuple(outs)

    def _freeze_stages(self):
        # Freeze the stem only when at least one stage is frozen, so that
        # the default ``frozen_stages=0`` leaves every parameter trainable.
        if self.frozen_stages > 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False

        for i in range(self.frozen_stages):
            stage_layer = getattr(self, f'stage{i}')
            stage_layer.eval()
            for param in stage_layer.parameters():
                param.requires_grad = False
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                norm_layer.eval()
                for param in norm_layer.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(StarNet, self).train(mode)
        # Re-apply freezing, since ``train()`` resets frozen modules.
        self._freeze_stages()