[Feature] Support SparK. (#1531)

* add spark configs

* fix configs

* remove repeat aug

* add module codes

* support lr layer decay of resnet

* update

* fix lint

* add metafile and readme

* fix lint

* add models and logs

* refactor codes

* fix lint

* update model rst

* update name

* add docstring

* add ut

* fix lint

---------

Co-authored-by: Ma Zerun <mzr1996@163.com>
Yixiao Fang 2023-06-19 11:27:50 +08:00 committed by GitHub
parent bfd49b0d52
commit a1cfe888e2
27 changed files with 1964 additions and 0 deletions

@@ -0,0 +1,87 @@
# SparK
> [Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling](https://arxiv.org/abs/2301.03580)
<!-- [ALGORITHM] -->
## Abstract
We identify and overcome two key obstacles in extending the success of BERT-style pre-training, or the masked image modeling, to convolutional networks (convnets): (i) convolution operation cannot handle irregular, random-masked input images; (ii) the single-scale nature of BERT pre-training is inconsistent with convnet's hierarchical structure. For (i), we treat unmasked pixels as sparse voxels of 3D point clouds and use sparse convolution to encode. This is the first use of sparse convolution for 2D masked modeling. For (ii), we develop a hierarchical decoder to reconstruct images from multi-scale encoded features. Our method called Sparse masKed modeling (SparK) is general: it can be used directly on any convolutional model without backbone modifications. We validate it on both classical (ResNet) and modern (ConvNeXt) models: on three downstream tasks, it surpasses both state-of-the-art contrastive learning and transformer-based masked modeling by similarly large margins (around +1.0%). Improvements on object detection and instance segmentation are more substantial (up to +3.5%), verifying the strong transferability of features learned. We also find its favorable scaling behavior by observing more gains on larger models. All this evidence reveals a promising future of generative pre-training on convnets. Codes and models are released at https://github.com/keyu-tian/SparK.
<div align=center>
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/b93e8d6f-ec1e-4f27-b986-da470fabe7df" width="80%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
**Predict image**
```python
from mmpretrain import inference_model
predict = inference_model('resnet50_spark-pre_300e_in1k', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```
**Use the model**
```python
import torch
from mmpretrain import get_model
model = get_model('spark_sparse-resnet50_800e_in1k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py
```
Test:
```shell
python tools/test.py configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth
```
<!-- [TABS-END] -->
## Models and results
### Pretrained models
| Model | Params (M) | Flops (G) | Config | Download |
| :--------------------------------------- | :--------: | :-------: | :-------------------------------------------------------------------: | :----------------------------------------------------------------------: |
| `spark_sparse-resnet50_800e_in1k` | 37.97 | 4.10 | [config](spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.json) |
| `spark_sparse-convnextv2-tiny_800e_in1k` | 39.73 | 4.47 | [config](spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.json) |
### Image Classification on ImageNet-1k
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :------------------------------------ | :----------------------------------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------------: | :-----------------------------------------: |
| `resnet50_spark-pre_300e_in1k` | [SPARK](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth) | 23.52 | 1.31 | 80.10 | 94.90 | [config](benchmarks/resnet50_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.json) |
| `convnextv2-tiny_spark-pre_300e_in1k` | [SPARK](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth) | 28.64 | 4.47 | 82.80 | 96.30 | [config](benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.json) |
## Citation
```bibtex
@Article{tian2023designing,
author = {Keyu Tian and Yi Jiang and Qishuai Diao and Chen Lin and Liwei Wang and Zehuan Yuan},
title = {Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling},
journal = {arXiv:2301.03580},
year = {2023},
}
```

@@ -0,0 +1,122 @@
_base_ = [
'../../_base_/datasets/imagenet_bs64_swin_224.py',
'../../_base_/default_runtime.py',
]
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='NumpyToPIL', to_rgb=True),
dict(
type='torchvision/TrivialAugmentWide',
num_magnitude_bins=31,
interpolation='bicubic',
fill=None),
dict(type='PILToNumpy', to_bgr=True),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
train_dataloader = dict(
dataset=dict(pipeline=train_pipeline),
sampler=dict(type='RepeatAugSampler', shuffle=True),
)
# Model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='ConvNeXt',
arch='tiny',
drop_path_rate=0.1,
layer_scale_init_value=0.,
use_grn=True,
),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02, bias=0.),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0),
]),
)
custom_hooks = [
dict(
type='EMAHook',
momentum=1e-4,
evaluate_on_origin=True,
priority='ABOVE_NORMAL')
]
# schedule settings
# optimizer
optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=3.2e-3, betas=(0.9, 0.999), weight_decay=0.05),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.7,
norm_decay_mult=0.0,
bias_decay_mult=0.0,
flat_decay_mult=0.0))
# learning policy
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=20,
convert_to_iter_based=True),
# main learning rate scheduler
dict(
type='CosineAnnealingLR',
T_max=280,
eta_min=1.0e-5,
by_epoch=True,
begin=20,
end=300)
]
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict()
test_cfg = dict()
default_hooks = dict(
# only keeps the latest 2 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=2048)
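`LearningRateDecayOptimWrapperConstructor` queries the backbone's `get_layer_depth` (added below for ConvNeXt and ResNet) and shrinks the learning rate of shallower layers geometrically. A minimal sketch of the assumed rule, `scale = layer_decay_rate ** (num_layers - layer_id - 1)` (the exact exponent convention lives in the constructor, so treat this as illustrative):

```python
# Hedged sketch of layer-wise LR decay with the values from this config.
base_lr, layer_decay_rate = 3.2e-3, 0.7
num_layers = 8  # as reported by ConvNeXt-T's get_layer_depth below
for layer_id in range(num_layers):
    scale = layer_decay_rate ** (num_layers - layer_id - 1)
    print(f'layer {layer_id}: lr = {base_lr * scale:.2e}')
# The stem (layer 0) gets the smallest LR; the largest id keeps the base LR.
```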

@@ -0,0 +1,107 @@
_base_ = [
'../../_base_/models/resnet50.py',
'../../_base_/datasets/imagenet_bs256_rsb_a12.py',
'../../_base_/default_runtime.py'
]
# modifications are based on ResNet's RSB settings
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='NumpyToPIL', to_rgb=True),
dict(
type='torchvision/TrivialAugmentWide',
num_magnitude_bins=31,
interpolation='bicubic',
fill=None),
dict(type='PILToNumpy', to_bgr=True),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
# model settings
model = dict(
backbone=dict(
norm_cfg=dict(type='SyncBN', requires_grad=True),
drop_path_rate=0.05,
),
head=dict(
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, use_sigmoid=True)),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.1),
dict(type='CutMix', alpha=1.0)
]))
# schedule settings
# optimizer
optim_wrapper = dict(
optimizer=dict(
type='Lamb',
lr=0.016,
weight_decay=0.02,
),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.7,
norm_decay_mult=0.0,
bias_decay_mult=0.0,
flat_decay_mult=0.0))
# learning policy
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=5,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(
type='CosineAnnealingLR',
T_max=295,
eta_min=1.0e-6,
by_epoch=True,
begin=5,
end=300)
]
train_cfg = dict(by_epoch=True, max_epochs=300)
val_cfg = dict()
test_cfg = dict()
default_hooks = dict(
# only keeps the latest 2 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
# randomness
randomness = dict(seed=0, diff_rank_seed=True)
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=2048)
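When `--auto-scale-lr` is passed to `tools/train.py`, the linear scaling rule rescales the LR by the ratio of the actual batch size to `base_batch_size`; a quick check of the numbers above:

```python
# Linear LR scaling sketch: 8 GPUs x 256 images equals the reference batch
# of 2048, so the configured lr=0.016 is used unchanged.
base_batch_size, base_lr = 2048, 0.016
actual_batch_size = 8 * 256
print(base_lr * actual_batch_size / base_batch_size)  # 0.016
```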

@@ -0,0 +1,73 @@
Collections:
- Name: SparK
Metadata:
Architecture:
- Dense Connections
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
Paper:
Title: 'Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling'
URL: https://arxiv.org/abs/2301.03580
README: configs/spark/README.md
Code:
URL: null
Version: null
Models:
- Name: spark_sparse-resnet50_800e_in1k
Metadata:
FLOPs: 4100000000
Parameters: 37971000
Training Data:
- ImageNet-1k
In Collection: SparK
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth
Config: configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py
Downstream:
- resnet50_spark-pre_300e_in1k
- Name: resnet50_spark-pre_300e_in1k
Metadata:
FLOPs: 1310000000
Parameters: 23520000
Training Data:
- ImageNet-1k
In Collection: SparK
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.1
Top 5 Accuracy: 94.9
Task: Image Classification
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth
Config: configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py
- Name: spark_sparse-convnextv2-tiny_800e_in1k
Metadata:
FLOPs: 4470000000
Parameters: 39732000
Training Data:
- ImageNet-1k
In Collection: SparK
Results: null
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth
Config: configs/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py
Downstream:
- convnextv2-tiny_spark-pre_300e_in1k
- Name: convnextv2-tiny_spark-pre_300e_in1k
Metadata:
FLOPs: 4469631744
Parameters: 28635496
Training Data:
- ImageNet-1k
In Collection: SparK
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.8
Top 5 Accuracy: 96.3
Task: Image Classification
Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.pth
Config: configs/spark/benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py

@@ -0,0 +1,81 @@
_base_ = [
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# dataset 16 x 256
train_dataloader = dict(batch_size=256, num_workers=8)
# model settings
model = dict(
type='SparK',
input_size=224,
downsample_raito=32,
mask_ratio=0.6,
enc_dec_norm_cfg=dict(type='SparseLN2d', eps=1e-6),
enc_dec_norm_dim=768,
backbone=dict(
type='SparseConvNeXt',
arch='small',
drop_path_rate=0.2,
out_indices=(0, 1, 2, 3),
gap_before_output=False),
neck=dict(
type='SparKLightDecoder',
feature_dim=512,
upsample_ratio=32, # equal to downsample_raito
mid_channels=0,
last_act=False),
head=dict(
type='SparKPretrainHead',
loss=dict(type='PixelReconstructionLoss', criterion='L2')))
# optimizer wrapper
optimizer = dict(
type='Lamb', lr=2e-4 * 4096 / 512, betas=(0.9, 0.95), weight_decay=0.04)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
clip_grad=dict(max_norm=5.0),
paramwise_cfg=dict(
bias_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys={
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=760,
by_epoch=True,
begin=40,
end=800,
convert_to_iter_based=True),
dict(
type='CosineAnnealingWeightDecay',
eta_min=0.2,
T_max=800,
by_epoch=True,
begin=0,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=100),
# only keeps the latest 2 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
# randomness
randomness = dict(seed=0, diff_rank_seed=True)

@@ -0,0 +1,84 @@
_base_ = [
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# dataset 16 x 256
train_dataloader = dict(batch_size=256, num_workers=8)
# model settings, use ConvNeXt V2
model = dict(
type='SparK',
input_size=224,
downsample_raito=32,
mask_ratio=0.6,
enc_dec_norm_cfg=dict(type='SparseLN2d', eps=1e-6),
enc_dec_norm_dim=768,
backbone=dict(
type='SparseConvNeXt',
arch='tiny',
drop_path_rate=0.2,
out_indices=(0, 1, 2, 3),
gap_before_output=False,
layer_scale_init_value=0.,
use_grn=True,
),
neck=dict(
type='SparKLightDecoder',
feature_dim=512,
upsample_ratio=32, # equal to downsample_raito
mid_channels=0,
last_act=False),
head=dict(
type='SparKPretrainHead',
loss=dict(type='PixelReconstructionLoss', criterion='L2')))
# optimizer wrapper
optimizer = dict(
type='Lamb', lr=2e-4 * 4096 / 512, betas=(0.9, 0.95), weight_decay=0.04)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
clip_grad=dict(max_norm=5.0),
paramwise_cfg=dict(
bias_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys={
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=20,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=780,
by_epoch=True,
begin=20,
end=800,
convert_to_iter_based=True),
dict(
type='CosineAnnealingWeightDecay',
eta_min=0.2,
T_max=800,
by_epoch=True,
begin=0,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=100),
# only keeps the latest 2 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
# randomness
randomness = dict(seed=0, diff_rank_seed=True)

@@ -0,0 +1,30 @@
_base_ = 'spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py'
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=1560,
by_epoch=True,
begin=40,
end=1600,
convert_to_iter_based=True),
dict(
type='CosineAnnealingWeightDecay',
eta_min=0.2,
T_max=1600,
by_epoch=True,
begin=0,
end=1600,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(max_epochs=1600)

@@ -0,0 +1,80 @@
_base_ = [
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# dataset 8 x 512
train_dataloader = dict(batch_size=512, num_workers=8)
# model settings
model = dict(
type='SparK',
input_size=224,
downsample_raito=32,
mask_ratio=0.6,
enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'),
enc_dec_norm_dim=2048,
backbone=dict(
type='SparseResNet',
depth=50,
out_indices=(0, 1, 2, 3),
drop_path_rate=0.05),
neck=dict(
type='SparKLightDecoder',
feature_dim=512,
upsample_ratio=32, # equal to downsample_raito
mid_channels=0,
last_act=False),
head=dict(
type='SparKPretrainHead',
loss=dict(type='PixelReconstructionLoss', criterion='L2')))
# optimizer wrapper
optimizer = dict(
type='Lamb', lr=2e-4 * 4096 / 512, betas=(0.9, 0.95), weight_decay=0.04)
optim_wrapper = dict(
type='AmpOptimWrapper',
optimizer=optimizer,
clip_grad=dict(max_norm=5.0),
paramwise_cfg=dict(
bias_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys={
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=760,
by_epoch=True,
begin=40,
end=800,
convert_to_iter_based=True),
dict(
type='CosineAnnealingWeightDecay',
eta_min=0.2,
T_max=800,
by_epoch=True,
begin=0,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=100),
# only keeps the latest 2 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2))
# randomness
randomness = dict(seed=0, diff_rank_seed=True)
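The optimizer LR above is written as a rule rather than a constant: a base LR of 2e-4 per 512 samples, scaled linearly to the total batch of 8 GPUs x 512 images:

```python
# Arithmetic behind lr=2e-4 * 4096 / 512 in the config above.
base_lr, base_batch = 2e-4, 512
total_batch = 8 * 512
print(base_lr * total_batch / base_batch)  # 0.0016
```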

@@ -76,6 +76,7 @@ Self-supervised Algorithms
SimCLR
SimMIM
SimSiam
SparK
SwAV
.. _selfsup_backbones:
@@ -205,6 +206,8 @@ Backbones
SVT
ShuffleNetV1
ShuffleNetV2
SparseResNet
SparseConvNeXt
SwinTransformer
SwinTransformerV2
T2T_ViT
@@ -243,6 +246,7 @@ Necks
SimMIMLinearDecoder
SwAVNeck
iTPNPretrainDecoder
SparKLightDecoder
.. module:: mmpretrain.models.heads
@@ -280,6 +284,7 @@ Heads
VigClsHead
VisionTransformerClsHead
iTPNClipHead
SparKPretrainHead
.. module:: mmpretrain.models.losses

@@ -2,3 +2,4 @@
from .hooks import * # noqa: F401, F403
from .optimizers import * # noqa: F401, F403
from .runners import * # noqa: F401, F403
from .schedulers import * # noqa: F401, F403

@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .weight_decay_scheduler import CosineAnnealingWeightDecay
__all__ = ['CosineAnnealingWeightDecay']

@@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from mmengine.optim.scheduler import CosineAnnealingParamScheduler
from mmpretrain.registry import PARAM_SCHEDULERS
class WeightDecaySchedulerMixin:
"""A mixin class for learning rate schedulers."""
def __init__(self, optimizer, *args, **kwargs):
super().__init__(optimizer, 'weight_decay', *args, **kwargs)
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin,
CosineAnnealingParamScheduler):
"""Set the weight decay value of each parameter group using a cosine
annealing schedule.
If the weight decay was set to be 0 initially, the weight decay value will
be 0 constantly during the training.
"""
def _get_value(self) -> list:
"""Compute value using chainable form of the scheduler."""
def _get_eta_min(base_value):
if self.eta_min_ratio is None:
return self.eta_min
return base_value * self.eta_min_ratio
if self.last_step == 0:
return [
group[self.param_name] for group in self.optimizer.param_groups
]
elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0:
weight_decay_value_list = []
for base_value, group in zip(self.base_values,
self.optimizer.param_groups):
if base_value == 0:
group_value = 0
else:
group_value = group[self.param_name] + (
base_value - _get_eta_min(base_value)) * (
1 - math.cos(math.pi / self.T_max)) / 2
weight_decay_value_list.append(group_value)
return weight_decay_value_list
weight_decay_value_list = []
for base_value, group in zip(self.base_values,
self.optimizer.param_groups):
if base_value == 0:
group_value = 0
else:
group_value = (
1 + math.cos(math.pi * self.last_step / self.T_max)) / (
1 + math.cos(math.pi *
(self.last_step - 1) / self.T_max)
) * (group[self.param_name] -
_get_eta_min(base_value)) + _get_eta_min(base_value)
weight_decay_value_list.append(group_value)
return weight_decay_value_list
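A minimal usage sketch of the scheduler, assuming it is importable from `mmpretrain.engine` as registered above. SparK anneals weight decay upward (0.04 toward `eta_min=0.2`), and groups whose initial weight decay is 0 stay at 0, per the docstring:

```python
import torch
from mmpretrain.engine import CosineAnnealingWeightDecay

params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
optimizer = torch.optim.SGD(
    [dict(params=[params[0]], weight_decay=0.04),
     dict(params=[params[1]], weight_decay=0.0)],  # stays 0 throughout
    lr=0.1)
scheduler = CosineAnnealingWeightDecay(
    optimizer, T_max=800, eta_min=0.2, by_epoch=False)
for _ in range(400):  # run half of the cosine cycle
    optimizer.step()
    scheduler.step()
print([round(g['weight_decay'], 4) for g in optimizer.param_groups])
# approximately [0.12, 0.0]
```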

@@ -42,6 +42,8 @@ from .seresnet import SEResNet
from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .sparse_convnext import SparseConvNeXt
from .sparse_resnet import SparseResNet
from .swin_transformer import SwinTransformer
from .swin_transformer_v2 import SwinTransformerV2
from .t2t_vit import T2T_ViT
@@ -122,4 +124,6 @@ __all__ = [
'ViTSAM',
'ViTEVA02',
'HiViT',
'SparseResNet',
'SparseConvNeXt',
]

@@ -366,3 +366,47 @@ class ConvNeXt(BaseBackbone):
def train(self, mode=True):
super(ConvNeXt, self).train(mode)
self._freeze_stages()
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
"""
max_layer_id = 12 if self.depths[-2] > 9 else 6
if not param_name.startswith(prefix):
# For subsequent module like head
return max_layer_id + 1, max_layer_id + 2
param_name = param_name[len(prefix):]
if param_name.startswith('downsample_layers'):
stage_id = int(param_name.split('.')[1])
if stage_id == 0:
layer_id = 0
elif stage_id == 1 or stage_id == 2:
layer_id = stage_id + 1
else: # stage_id == 3:
layer_id = max_layer_id
elif param_name.startswith('stages'):
stage_id = int(param_name.split('.')[1])
block_id = int(param_name.split('.')[2])
if stage_id == 0 or stage_id == 1:
layer_id = stage_id + 1
elif stage_id == 2:
layer_id = 3 + block_id // 3
else: # stage_id == 3:
layer_id = max_layer_id
# final norm layer
else:
layer_id = max_layer_id + 1
return layer_id, max_layer_id + 2
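A small sketch of what this method reports, with values worked out from the code above for ConvNeXt-T (depths `(3, 3, 9, 3)`, hence `max_layer_id = 6` and 8 ids overall):

```python
from mmpretrain.models import ConvNeXt

model = ConvNeXt(arch='tiny')
# a block in stage 2: layer_id = 3 + 4 // 3 = 4
print(model.get_layer_depth('backbone.stages.2.4.depthwise_conv.weight',
                            prefix='backbone.'))  # (4, 8)
# parameters outside the backbone (e.g. the head) get the largest id,
# i.e. the one closest to the base LR under layer-wise decay
print(model.get_layer_depth('head.fc.weight', prefix='backbone.'))  # (7, 8)
```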

@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
@@ -674,6 +676,64 @@ class ResNet(BaseBackbone):
if isinstance(m, _BatchNorm):
m.eval()
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer id to set the different learning rates for ResNet.
ResNet stages:
50 : [3, 4, 6, 3]
101 : [3, 4, 23, 3]
152 : [3, 8, 36, 3]
200 : [3, 24, 36, 3]
eca269d: [3, 30, 48, 8]
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
"""
depths = self.stage_blocks
if depths[1] == 4 and depths[2] == 6:
blk2, blk3 = 2, 3
elif depths[1] == 4 and depths[2] == 23:
blk2, blk3 = 2, 3
elif depths[1] == 8 and depths[2] == 36:
blk2, blk3 = 4, 4
elif depths[1] == 24 and depths[2] == 36:
blk2, blk3 = 4, 4
elif depths[1] == 30 and depths[2] == 48:
blk2, blk3 = 5, 6
else:
raise NotImplementedError
N2, N3 = math.ceil(depths[1] / blk2 -
1e-5), math.ceil(depths[2] / blk3 - 1e-5)
N = 2 + N2 + N3 # r50: 2 + 2 + 2 = 6
max_layer_id = N + 1 # r50: 2 + 2 + 2 + 1(like head) = 7
if not param_name.startswith(prefix):
# For subsequent module like head
return max_layer_id, max_layer_id + 1
if param_name.startswith('backbone.layer'):
stage_id = int(param_name.split('.')[1][5:])
block_id = int(param_name.split('.')[2])
if stage_id == 1:
layer_id = 1
elif stage_id == 2:
layer_id = 2 + block_id // blk2 # r50: 2, 3
elif stage_id == 3:
layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5
else: # stage_id == 4
layer_id = N # r50: 6
return layer_id, max_layer_id + 1
else:
return 0, max_layer_id + 1
@MODELS.register_module()
class ResNetV1c(ResNet):
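The same sketch for the new `ResNet.get_layer_depth` above, using ResNet-50 (stage blocks `[3, 4, 6, 3]`, so `N = 6` and `max_layer_id = 7`):

```python
from mmpretrain.models import ResNet

model = ResNet(depth=50)
# stage 3, block 5: layer_id = 2 + N2 + 5 // 3 = 5
print(model.get_layer_depth('backbone.layer3.5.conv1.weight',
                            prefix='backbone.'))  # (5, 8)
# stem parameters fall through to layer id 0
print(model.get_layer_depth('backbone.conv1.weight', prefix='backbone.'))  # (0, 8)
```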

@@ -0,0 +1,298 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmengine.model import ModuleList, Sequential
from mmpretrain.registry import MODELS
from ..utils import (SparseAvgPooling, SparseConv2d, SparseHelper,
SparseMaxPooling, build_norm_layer)
from .convnext import ConvNeXt, ConvNeXtBlock
class SparseConvNeXtBlock(ConvNeXtBlock):
"""Sparse ConvNeXt Block.
Note:
There are two equivalent implementations:
1. DwConv -> SparseLayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
all outputs are in (N, C, H, W).
2. DwConv -> SparseLayerNorm -> Permute to (N, H, W, C) -> Linear ->
GELU -> Linear; Permute back
    By default, we use the second implementation, to align with the official
    repository; it may also be slightly faster.
"""
def forward(self, x):
def _inner_forward(x):
shortcut = x
x = self.depthwise_conv(x)
if self.linear_pw_conv:
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x, data_format='channel_last')
x = self.pointwise_conv1(x)
x = self.act(x)
if self.grn is not None:
x = self.grn(x, data_format='channel_last')
x = self.pointwise_conv2(x)
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
else:
x = self.norm(x, data_format='channel_first')
x = self.pointwise_conv1(x)
x = self.act(x)
if self.grn is not None:
x = self.grn(x, data_format='channel_first')
x = self.pointwise_conv2(x)
if self.gamma is not None:
x = x.mul(self.gamma.view(1, -1, 1, 1))
x *= SparseHelper._get_active_map_or_index(
H=x.shape[2], returning_active_map=True)
x = shortcut + self.drop_path(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
@MODELS.register_module()
class SparseConvNeXt(ConvNeXt):
"""ConvNeXt with sparse module conversion function.
Modified from
https://github.com/keyu-tian/SparK/blob/main/models/convnext.py
and
https://github.com/keyu-tian/SparK/blob/main/encoder.py
To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``.
Args:
arch (str | dict): The model's architecture. If string, it should be
one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
should include the following two keys:
- depths (list[int]): Number of blocks at each stage.
- channels (list[int]): The number of channels at each stage.
Defaults to 'tiny'.
in_channels (int): Number of input image channels. Defaults to 3.
stem_patch_size (int): The size of one patch in the stem layer.
Defaults to 4.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='SparseLN2d', eps=1e-6)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
linear_pw_conv (bool): Whether to use linear layer to do pointwise
convolution. Defaults to True.
use_grn (bool): Whether to add Global Response Normalization in the
blocks. Defaults to False.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 1e-6.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
gap_before_output (bool): Whether to globally average the feature
map before the final norm layer. In the official repo, it's only
used in classification task. Defaults to True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): Initialization config dict.
""" # noqa: E501
def __init__(self,
arch: str = 'small',
in_channels: int = 3,
stem_patch_size: int = 4,
norm_cfg: dict = dict(type='SparseLN2d', eps=1e-6),
act_cfg: dict = dict(type='GELU'),
linear_pw_conv: bool = True,
use_grn: bool = False,
drop_path_rate: float = 0,
layer_scale_init_value: float = 1e-6,
out_indices: int = -1,
frozen_stages: int = 0,
gap_before_output: bool = True,
with_cp: bool = False,
init_cfg: Optional[Union[dict, List[dict]]] = [
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(
type='Constant', layer=['LayerNorm'], val=1.,
bias=0.),
]):
super(ConvNeXt, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
assert 'depths' in arch and 'channels' in arch, \
f'The arch dict must have "depths" and "channels", ' \
f'but got {list(arch.keys())}.'
self.depths = arch['depths']
self.channels = arch['channels']
assert (isinstance(self.depths, Sequence)
and isinstance(self.channels, Sequence)
and len(self.depths) == len(self.channels)), \
f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
'should be both sequence with the same length.'
self.num_stages = len(self.depths)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = 4 + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.gap_before_output = gap_before_output
# 4 downsample layers between stages, including the stem layer.
self.downsample_layers = ModuleList()
stem = nn.Sequential(
nn.Conv2d(
in_channels,
self.channels[0],
kernel_size=stem_patch_size,
stride=stem_patch_size),
build_norm_layer(norm_cfg, self.channels[0]),
)
self.downsample_layers.append(stem)
# stochastic depth decay rule
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(self.depths))
]
block_idx = 0
# 4 feature resolution stages, each consisting of multiple residual
# blocks
self.stages = nn.ModuleList()
for i in range(self.num_stages):
depth = self.depths[i]
channels = self.channels[i]
if i >= 1:
downsample_layer = nn.Sequential(
build_norm_layer(norm_cfg, self.channels[i - 1]),
nn.Conv2d(
self.channels[i - 1],
channels,
kernel_size=2,
stride=2),
)
self.downsample_layers.append(downsample_layer)
stage = Sequential(*[
SparseConvNeXtBlock(
in_channels=channels,
drop_path_rate=dpr[block_idx + j],
norm_cfg=norm_cfg,
act_cfg=act_cfg,
linear_pw_conv=linear_pw_conv,
layer_scale_init_value=layer_scale_init_value,
use_grn=use_grn,
with_cp=with_cp) for j in range(depth)
])
block_idx += depth
self.stages.append(stage)
self.dense_model_to_sparse(m=self)
def forward(self, x):
outs = []
for i, stage in enumerate(self.stages):
x = self.downsample_layers[i](x)
x = stage(x)
if i in self.out_indices:
if self.gap_before_output:
gap = x.mean([-2, -1], keepdim=True)
outs.append(gap.flatten(1))
else:
outs.append(x)
return tuple(outs)
def dense_model_to_sparse(self, m: nn.Module) -> nn.Module:
"""Convert regular dense modules to sparse modules."""
output = m
if isinstance(m, nn.Conv2d):
m: nn.Conv2d
bias = m.bias is not None
output = SparseConv2d(
m.in_channels,
m.out_channels,
kernel_size=m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
groups=m.groups,
bias=bias,
padding_mode=m.padding_mode,
)
output.weight.data.copy_(m.weight.data)
if bias:
output.bias.data.copy_(m.bias.data)
elif isinstance(m, nn.MaxPool2d):
m: nn.MaxPool2d
output = SparseMaxPooling(
m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
return_indices=m.return_indices,
ceil_mode=m.ceil_mode)
elif isinstance(m, nn.AvgPool2d):
m: nn.AvgPool2d
output = SparseAvgPooling(
m.kernel_size,
m.stride,
m.padding,
ceil_mode=m.ceil_mode,
count_include_pad=m.count_include_pad,
divisor_override=m.divisor_override)
# elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
# m: nn.BatchNorm2d
# output = (SparseSyncBatchNorm2d
# if enable_sync_bn else SparseBatchNorm2d)(
# m.weight.shape[0],
# eps=m.eps,
# momentum=m.momentum,
# affine=m.affine,
# track_running_stats=m.track_running_stats)
# output.weight.data.copy_(m.weight.data)
# output.bias.data.copy_(m.bias.data)
# output.running_mean.data.copy_(m.running_mean.data)
# output.running_var.data.copy_(m.running_var.data)
# output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
for name, child in m.named_children():
output.add_module(name, self.dense_model_to_sparse(child))
del m
return output
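A rough usage sketch: during pre-training, `SparK` sets `SparseHelper._cur_active` before each forward pass; reproducing that by hand shows the backbone emitting four sparse feature maps (shapes are for ConvNeXt-T channels):

```python
import torch
from mmpretrain.models import SparseConvNeXt
from mmpretrain.models.utils import SparseHelper

model = SparseConvNeXt(
    arch='tiny',
    out_indices=(0, 1, 2, 3),
    gap_before_output=False,
    layer_scale_init_value=0.,
    use_grn=True)
# 224 / 32 = 7x7 active map; keep roughly 40% of positions (mask_ratio 0.6)
SparseHelper._cur_active = torch.rand(1, 1, 7, 7) > 0.6
feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
```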

@@ -0,0 +1,179 @@
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import Optional, Tuple
import torch.nn as nn
from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling,
SparseBatchNorm2d,
SparseConv2d,
SparseMaxPooling,
SparseSyncBatchNorm2d)
from mmpretrain.registry import MODELS
from .resnet import ResNet
@MODELS.register_module()
class SparseResNet(ResNet):
"""ResNet with sparse module conversion function.
Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
Args:
depth (int): Network depth, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Defaults to 3.
stem_channels (int): Output channels of the stem layer. Defaults to 64.
base_channels (int): Middle channels of the first stage.
Defaults to 64.
num_stages (int): Stages of the network. Defaults to 4.
strides (Sequence[int]): Strides of the first block of each stage.
Defaults to ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Defaults to ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Defaults to False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
conv_cfg (dict | None): The config dict for conv layers.
Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Defaults to True.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
"""
def __init__(self,
depth: int,
in_channels: int = 3,
stem_channels: int = 64,
base_channels: int = 64,
expansion: Optional[int] = None,
num_stages: int = 4,
strides: Tuple[int] = (1, 2, 2, 2),
dilations: Tuple[int] = (1, 1, 1, 1),
out_indices: Tuple[int] = (3, ),
style: str = 'pytorch',
deep_stem: bool = False,
avg_down: bool = False,
frozen_stages: int = -1,
conv_cfg: Optional[dict] = None,
norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
norm_eval: bool = False,
with_cp: bool = False,
zero_init_residual: bool = False,
init_cfg: Optional[dict] = [
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
],
drop_path_rate: float = 0,
**kwargs):
super().__init__(
depth=depth,
in_channels=in_channels,
stem_channels=stem_channels,
base_channels=base_channels,
expansion=expansion,
num_stages=num_stages,
strides=strides,
dilations=dilations,
out_indices=out_indices,
style=style,
deep_stem=deep_stem,
avg_down=avg_down,
frozen_stages=frozen_stages,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
norm_eval=norm_eval,
with_cp=with_cp,
zero_init_residual=zero_init_residual,
init_cfg=init_cfg,
drop_path_rate=drop_path_rate,
**kwargs)
norm_type = norm_cfg['type']
enable_sync_bn = False
if re.search('Sync', norm_type) is not None:
enable_sync_bn = True
self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)
def dense_model_to_sparse(self, m: nn.Module,
enable_sync_bn: bool) -> nn.Module:
"""Convert regular dense modules to sparse modules."""
output = m
if isinstance(m, nn.Conv2d):
m: nn.Conv2d
bias = m.bias is not None
output = SparseConv2d(
m.in_channels,
m.out_channels,
kernel_size=m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
groups=m.groups,
bias=bias,
padding_mode=m.padding_mode,
)
output.weight.data.copy_(m.weight.data)
if bias:
output.bias.data.copy_(m.bias.data)
elif isinstance(m, nn.MaxPool2d):
m: nn.MaxPool2d
output = SparseMaxPooling(
m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
return_indices=m.return_indices,
ceil_mode=m.ceil_mode)
elif isinstance(m, nn.AvgPool2d):
m: nn.AvgPool2d
output = SparseAvgPooling(
m.kernel_size,
m.stride,
m.padding,
ceil_mode=m.ceil_mode,
count_include_pad=m.count_include_pad,
divisor_override=m.divisor_override)
elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
m: nn.BatchNorm2d
output = (SparseSyncBatchNorm2d
if enable_sync_bn else SparseBatchNorm2d)(
m.weight.shape[0],
eps=m.eps,
momentum=m.momentum,
affine=m.affine,
track_running_stats=m.track_running_stats)
output.weight.data.copy_(m.weight.data)
output.bias.data.copy_(m.bias.data)
output.running_mean.data.copy_(m.running_mean.data)
output.running_var.data.copy_(m.running_var.data)
output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
elif isinstance(m, (nn.Conv1d, )):
raise NotImplementedError
for name, child in m.named_children():
output.add_module(
name,
self.dense_model_to_sparse(
child, enable_sync_bn=enable_sync_bn))
del m
return output
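A small sanity-check sketch of the conversion, using plain `BN` as the unit test does (the default `SparseSyncBatchNorm2d` needs a distributed process group):

```python
from mmpretrain.models import SparseResNet
from mmpretrain.models.utils import SparseBatchNorm2d, SparseConv2d

model = SparseResNet(depth=50, out_indices=(0, 1, 2, 3), norm_cfg=dict(type='BN'))
# every dense Conv2d / BatchNorm2d is swapped for its sparse counterpart
print(sum(isinstance(m, SparseConv2d) for m in model.modules()))       # > 0
print(sum(isinstance(m, SparseBatchNorm2d) for m in model.modules()))  # > 0
```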

@@ -25,6 +25,7 @@ from .multi_label_linear_head import MultiLabelLinearClsHead
from .multi_task_head import MultiTaskHead
from .seq_gen_head import SeqGenerationHead
from .simmim_head import SimMIMHead
from .spark_head import SparKPretrainHead
from .stacked_head import StackedLinearClsHead
from .swav_head import SwAVHead
from .vig_head import VigClsHead
@@ -64,4 +65,5 @@ __all__ = [
'ITMHead',
'GroundingHead',
'iTPNClipHead',
'SparKPretrainHead',
]

@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class SparKPretrainHead(BaseModule):
"""Pre-training head for SparK.
Args:
loss (dict): Config of loss.
        norm_pix (bool): Whether to normalize the target image.
            Defaults to True.
patch_size (int): Patch size, equal to downsample ratio of backbone.
Defaults to 32.
"""
def __init__(self,
loss: dict,
norm_pix: bool = True,
patch_size: int = 32) -> None:
super().__init__()
self.norm_pix = norm_pix
self.patch_size = patch_size
self.loss = MODELS.build(loss)
def patchify(self, imgs):
"""Split images into non-overlapped patches.
Args:
imgs (torch.Tensor): A batch of images, of shape B x C x H x W.
Returns:
torch.Tensor: Patchified images. The shape is B x L x D.
"""
p = self.patch_size
assert len(imgs.shape
) == 4 and imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0
B, C, ori_h, ori_w = imgs.shape
h = ori_h // p
w = ori_w // p
x = imgs.reshape(shape=(B, C, h, p, w, p))
x = torch.einsum('bchpwq->bhwpqc', x)
# (B, f*f, downsample_raito*downsample_raito*3)
x = x.reshape(shape=(B, h * w, p**2 * C))
return x
def construct_target(self, target: torch.Tensor) -> torch.Tensor:
"""Construct the reconstruction target.
In addition to splitting images into tokens, this module will also
normalize the image according to ``norm_pix``.
Args:
target (torch.Tensor): Image with the shape of B x 3 x H x W
Returns:
torch.Tensor: Tokenized images with the shape of B x L x C
"""
target = self.patchify(target)
if self.norm_pix:
# normalize the target image
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
return target
def forward(self, pred: torch.Tensor, target: torch.Tensor,
active_mask: torch.Tensor) -> torch.Tensor:
"""Forward function of MAE head.
Args:
pred (torch.Tensor): The reconstructed image.
target (torch.Tensor): The target image.
active_mask (torch.Tensor): The mask of the target image.
Returns:
torch.Tensor: The reconstruction loss.
"""
# (B, C, H, W) -> (B, L, C) and perform normalization
target = self.construct_target(target)
# (B, C, H, W) -> (B, L, C)
pred = self.patchify(pred)
# (B, 1, f, f) -> (B, L)
non_active_mask = active_mask.logical_not().int().view(
active_mask.shape[0], -1)
# MSE loss on masked patches
loss = self.loss(pred, target, non_active_mask)
return loss
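A short shape sketch of `patchify` with the default 32-pixel patches (matching the backbone's downsample ratio): a 224x224 image becomes 7x7 = 49 patches of 32 * 32 * 3 = 3072 values each.

```python
import torch
from mmpretrain.models import SparKPretrainHead

head = SparKPretrainHead(loss=dict(type='PixelReconstructionLoss', criterion='L2'))
patches = head.patchify(torch.randn(2, 3, 224, 224))
print(patches.shape)  # torch.Size([2, 49, 3072])
```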

@@ -13,6 +13,7 @@ from .mixmim_neck import MixMIMPretrainDecoder
from .mocov2_neck import MoCoV2Neck
from .nonlinear_neck import NonLinearNeck
from .simmim_neck import SimMIMLinearDecoder
from .spark_neck import SparKLightDecoder
from .swav_neck import SwAVNeck
__all__ = [
@@ -32,4 +33,5 @@ __all__ = [
'SimMIMLinearDecoder',
'SwAVNeck',
'iTPNPretrainDecoder',
'SparKLightDecoder',
]

@@ -0,0 +1,169 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer
def is_pow2n(x):
    """Check whether x is a positive power of 2."""
    return x > 0 and (x & (x - 1) == 0)
class ConvBlock2x(BaseModule):
"""The definition of convolution block."""
def __init__(self,
in_channels: int,
out_channels: int,
mid_channels: int,
norm_cfg: dict,
act_cfg: dict,
last_act: bool,
init_cfg: Optional[dict] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False)
self.norm1 = build_norm_layer(norm_cfg, mid_channels)
self.activate1 = MODELS.build(act_cfg)
self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False)
self.norm2 = build_norm_layer(norm_cfg, out_channels)
self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity()
def forward(self, x: torch.Tensor):
out = self.conv1(x)
out = self.norm1(out)
out = self.activate1(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.activate2(out)
return out
class DecoderConvModule(BaseModule):
"""The convolution module of decoder with upsampling."""
def __init__(self,
in_channels: int,
out_channels: int,
mid_channels: int,
kernel_size: int = 4,
scale_factor: int = 2,
num_conv_blocks: int = 1,
norm_cfg: dict = dict(type='SyncBN'),
act_cfg: dict = dict(type='ReLU6'),
last_act: bool = True,
init_cfg: Optional[dict] = None):
super().__init__(init_cfg=init_cfg)
assert (kernel_size - scale_factor >= 0) and\
(kernel_size - scale_factor) % 2 == 0,\
f'kernel_size should be greater than or equal to scale_factor '\
f'and (kernel_size - scale_factor) should be even numbers, '\
f'while the kernel size is {kernel_size} and scale_factor is '\
f'{scale_factor}.'
padding = (kernel_size - scale_factor) // 2
self.upsample = nn.ConvTranspose2d(
in_channels,
in_channels,
kernel_size=kernel_size,
stride=scale_factor,
padding=padding,
bias=True)
conv_blocks_list = [
ConvBlock2x(
in_channels=in_channels,
out_channels=out_channels,
mid_channels=mid_channels,
norm_cfg=norm_cfg,
last_act=last_act,
act_cfg=act_cfg) for _ in range(num_conv_blocks)
]
self.conv_blocks = nn.Sequential(*conv_blocks_list)
def forward(self, x):
x = self.upsample(x)
return self.conv_blocks(x)
@MODELS.register_module()
class SparKLightDecoder(BaseModule):
"""The decoder for SparK, which upsamples the feature maps.
Args:
feature_dim (int): The dimension of feature map.
upsample_ratio (int): The ratio of upsample, equal to downsample_raito
of the algorithm.
mid_channels (int): The middle channel of `DecoderConvModule`. Defaults
to 0.
kernel_size (int): The kernel size of `ConvTranspose2d` in
`DecoderConvModule`. Defaults to 4.
scale_factor (int): The scale_factor of `ConvTranspose2d` in
`DecoderConvModule`. Defaults to 2.
num_conv_blocks (int): The number of convolution blocks in
`DecoderConvModule`. Defaults to 1.
norm_cfg (dict): Normalization config. Defaults to dict(type='SyncBN').
act_cfg (dict): Activation config. Defaults to dict(type='ReLU6').
last_act (bool): Whether apply the last activation in
`DecoderConvModule`. Defaults to False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(
self,
feature_dim: int,
upsample_ratio: int,
mid_channels: int = 0,
kernel_size: int = 4,
scale_factor: int = 2,
num_conv_blocks: int = 1,
norm_cfg: dict = dict(type='SyncBN'),
act_cfg: dict = dict(type='ReLU6'),
last_act: bool = False,
init_cfg: Optional[dict] = [
dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']),
dict(type='TruncNormal', std=0.02, layer=['Linear']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm'])
],
):
super().__init__(init_cfg=init_cfg)
self.feature_dim = feature_dim
assert is_pow2n(upsample_ratio)
n = round(math.log2(upsample_ratio))
channels = [feature_dim // 2**i for i in range(n + 1)]
self.decoder = nn.ModuleList([
DecoderConvModule(
in_channels=c_in,
out_channels=c_out,
mid_channels=c_in if mid_channels == 0 else mid_channels,
kernel_size=kernel_size,
scale_factor=scale_factor,
num_conv_blocks=num_conv_blocks,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
last_act=last_act)
for (c_in, c_out) in zip(channels[:-1], channels[1:])
])
self.proj = nn.Conv2d(
channels[-1], 3, kernel_size=1, stride=1, bias=True)
def forward(self, to_dec):
x = 0
for i, d in enumerate(self.decoder):
if i < len(to_dec) and to_dec[i] is not None:
x = x + to_dec[i]
x = self.decoder[i](x)
return self.proj(x)
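A minimal shape sketch, with plain `BN` instead of the default `SyncBN` so it runs without a process group. With `upsample_ratio=32`, the decoder stacks five 2x upsampling modules, halving channels from 512 down to 16 before the final 3-channel projection:

```python
import torch
from mmpretrain.models import SparKLightDecoder

decoder = SparKLightDecoder(
    feature_dim=512, upsample_ratio=32, norm_cfg=dict(type='BN'))
# feed only the deepest feature map; missing levels are skipped by the forward
out = decoder([torch.randn(2, 512, 7, 7)])
print(out.shape)  # torch.Size([2, 3, 224, 224])
```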

@@ -16,6 +16,7 @@ from .mocov3 import MoCoV3, MoCoV3ViT
from .simclr import SimCLR
from .simmim import SimMIM, SimMIMSwinTransformer
from .simsiam import SimSiam
from .spark import SparK
from .swav import SwAV
__all__ = [
@@ -51,4 +52,5 @@ __all__ = [
'DenseCL',
'BarlowTwins',
'SwAV',
'SparK',
]

@@ -0,0 +1,163 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils.norm import build_norm_layer
from ..utils.sparse_modules import SparseHelper
from .base import BaseSelfSupervisor
@MODELS.register_module()
class SparK(BaseSelfSupervisor):
"""Implementation of SparK.
Implementation of `Designing BERT for Convolutional Networks: Sparse and
Hierarchical Masked Modeling <https://arxiv.org/abs/2301.03580>`_.
Modified from
https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py
"""
def __init__(
self,
backbone: dict,
neck: dict,
head: dict,
pretrained: Optional[str] = None,
data_preprocessor: Optional[dict] = None,
input_size: int = 224,
downsample_raito: int = 32,
mask_ratio: float = 0.6,
enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'),
enc_dec_norm_dim: int = 2048,
init_cfg: Optional[dict] = None,
) -> None:
super().__init__(
backbone=backbone,
neck=neck,
head=head,
pretrained=pretrained,
data_preprocessor=data_preprocessor,
init_cfg=init_cfg)
self.input_size = input_size
self.downsample_raito = downsample_raito
feature_map_size = input_size // downsample_raito
self.feature_map_size = feature_map_size
self.mask_ratio = mask_ratio
self.len_keep = round(feature_map_size * feature_map_size *
(1 - mask_ratio))
self.enc_dec_norm_cfg = enc_dec_norm_cfg
self.enc_dec_norms = nn.ModuleList()
self.enc_dec_projectors = nn.ModuleList()
self.mask_tokens = nn.ParameterList()
proj_out_dim = self.neck.feature_dim
for i in range(len(self.backbone.out_indices)):
enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg,
enc_dec_norm_dim)
self.enc_dec_norms.append(enc_dec_norm)
kernel_size = 1 if i <= 0 else 3
proj_layer = nn.Conv2d(
enc_dec_norm_dim,
proj_out_dim,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
bias=True)
if i == 0 and enc_dec_norm_dim == proj_out_dim:
proj_layer = nn.Identity()
self.enc_dec_projectors.append(proj_layer)
mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1))
trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02)
self.mask_tokens.append(mask_token)
enc_dec_norm_dim //= 2
proj_out_dim //= 2
feature_map_size *= 2
def mask(self,
shape: torch.Size,
device: Union[torch.device, str],
generator: Optional[torch.Generator] = None):
"""Mask generation.
Args:
shape (torch.Size): The shape of the input images.
device (Union[torch.device, str]): The device of the tensor.
generator (torch.Generator, optional): Generator for random
                functions. Defaults to None.
Returns:
torch.Tensor: The generated mask.
"""
B, C, H, W = shape
f = self.feature_map_size
idx = torch.rand(B, f * f, generator=generator).argsort(dim=1)
idx = idx[:, :self.len_keep].to(device) # (B, len_keep)
return torch.zeros(
B, f * f, dtype=torch.bool, device=device).scatter_(
dim=1, index=idx, value=True).view(B, 1, f, f)
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (List[torch.Tensor]): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
# active mask of feature map, (B, 1, f, f)
active_mask_feature_map = self.mask(inputs.shape, inputs.device)
SparseHelper._cur_active = active_mask_feature_map
# active mask of original input, (B, 1, H, W)
active_mask_origin = active_mask_feature_map.repeat_interleave(
self.downsample_raito,
2).repeat_interleave(self.downsample_raito, 3)
masked_img = inputs * active_mask_origin
# get hierarchical encoded sparse features in a list
# containing four feature maps
feature_maps = self.backbone(masked_img)
# from the smallest feature map to the largest
feature_maps = list(feature_maps)
feature_maps.reverse()
cur_active = active_mask_feature_map
feature_maps_to_dec = []
for i, feature_map in enumerate(feature_maps):
if feature_map is not None:
# fill in empty positions with [mask] embeddings
feature_map = self.enc_dec_norms[i](feature_map)
mask_token = self.mask_tokens[i].expand_as(feature_map)
feature_map = torch.where(
cur_active.expand_as(feature_map), feature_map,
mask_token.to(feature_map.dtype))
feature_map = self.enc_dec_projectors[i](feature_map)
feature_maps_to_dec.append(feature_map)
# dilate the mask map
cur_active = cur_active.repeat_interleave(
2, dim=2).repeat_interleave(
2, dim=3)
# decode and reconstruct
rec_img = self.neck(feature_maps_to_dec)
# compute loss
loss = self.head(rec_img, inputs, active_mask_feature_map)
losses = dict(loss=loss)
return losses
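The arithmetic of `mask` above, as a standalone sketch: with a 224 input, 32x downsampling and `mask_ratio=0.6`, the 7x7 feature map keeps `round(49 * 0.4) = 20` active positions per sample.

```python
import torch

B, f, mask_ratio = 2, 7, 0.6
len_keep = round(f * f * (1 - mask_ratio))  # 20 of 49 positions stay visible
idx = torch.rand(B, f * f).argsort(dim=1)[:, :len_keep]
active = torch.zeros(B, f * f, dtype=torch.bool).scatter_(1, idx, True)
print(active.view(B, 1, f, f).sum(dim=(1, 2, 3)))  # tensor([20, 20])
```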

@@ -25,6 +25,9 @@ from .position_encoding import (ConditionalPositionEncoding,
build_2d_sincos_position_embedding)
from .res_layer_extra_norm import ResLayerExtraNorm
from .se_layer import SELayer
from .sparse_modules import (SparseAvgPooling, SparseBatchNorm2d, SparseConv2d,
SparseHelper, SparseLayerNorm2D, SparseMaxPooling,
SparseSyncBatchNorm2d)
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .vector_quantizer import NormEMAVectorQuantizer
@@ -78,6 +81,13 @@ __all__ = [
'SwiGLUFFN',
'SwiGLUFFNFused',
'RotaryEmbeddingFast',
'SparseAvgPooling',
'SparseConv2d',
'SparseHelper',
'SparseMaxPooling',
'SparseBatchNorm2d',
'SparseLayerNorm2D',
'SparseSyncBatchNorm2d',
]
if WITH_MULTIMODAL:

@@ -0,0 +1,149 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved.
# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
class SparseHelper:
"""The helper to compute sparse operation with pytorch, such as sparse
convlolution, sparse batch norm, etc."""
_cur_active: torch.Tensor = None
@staticmethod
def _get_active_map_or_index(H: int,
returning_active_map: bool = True
) -> torch.Tensor:
"""Get current active map with (B, 1, f, f) shape or index format."""
# _cur_active with shape (B, 1, f, f)
downsample_raito = H // SparseHelper._cur_active.shape[-1]
active_ex = SparseHelper._cur_active.repeat_interleave(
downsample_raito, 2).repeat_interleave(downsample_raito, 3)
return active_ex if returning_active_map else active_ex.squeeze(
1).nonzero(as_tuple=True)
@staticmethod
def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor:
"""Sparse convolution forward function."""
x = super(type(self), self).forward(x)
# (b, c, h, w) *= (b, 1, h, w), mask the output of conv
x *= SparseHelper._get_active_map_or_index(
H=x.shape[2], returning_active_map=True)
return x
@staticmethod
def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor:
"""Sparse batch norm forward function."""
active_index = SparseHelper._get_active_map_or_index(
H=x.shape[2], returning_active_map=False)
# (b, c, h, w) -> (b, h, w, c)
x_permuted = x.permute(0, 2, 3, 1)
# select the features on non-masked positions to form flatten features
# with shape (n, c)
x_flattened = x_permuted[active_index]
# use BN1d to normalize this flatten feature (n, c)
x_flattened = super(type(self), self).forward(x_flattened)
# generate output
output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
output[active_index] = x_flattened
# (b, h, w, c) -> (b, c, h, w)
output = output.permute(0, 3, 1, 2)
return output
class SparseConv2d(nn.Conv2d):
"""hack: override the forward function.
See `sp_conv_forward` above for more details
"""
forward = SparseHelper.sp_conv_forward
class SparseMaxPooling(nn.MaxPool2d):
"""hack: override the forward function.
See `sp_conv_forward` above for more details
"""
forward = SparseHelper.sp_conv_forward
class SparseAvgPooling(nn.AvgPool2d):
"""hack: override the forward function.
See `sp_conv_forward` above for more details
"""
forward = SparseHelper.sp_conv_forward
@MODELS.register_module()
class SparseBatchNorm2d(nn.BatchNorm1d):
"""hack: override the forward function.
See `sp_bn_forward` above for more details
"""
forward = SparseHelper.sp_bn_forward
@MODELS.register_module()
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
"""hack: override the forward function.
See `sp_bn_forward` above for more details
"""
forward = SparseHelper.sp_bn_forward
@MODELS.register_module('SparseLN2d')
class SparseLayerNorm2D(nn.LayerNorm):
"""Implementation of sparse LayerNorm on channels for 2d images."""
def forward(self,
x: torch.Tensor,
data_format='channel_first') -> torch.Tensor:
"""Sparse layer norm forward function with 2D data.
Args:
x (torch.Tensor): The input tensor.
data_format (str): The format of the input tensor. If
``"channel_first"``, the shape of the input tensor should be
(B, C, H, W). If ``"channel_last"``, the shape of the input
tensor should be (B, H, W, C). Defaults to "channel_first".
"""
assert x.dim() == 4, (
f'LayerNorm2d only supports inputs with shape '
f'(N, C, H, W), but got tensor with shape {x.shape}')
if data_format == 'channel_last':
index = SparseHelper._get_active_map_or_index(
H=x.shape[1], returning_active_map=False)
# select the features on non-masked positions to form flatten
# features with shape (n, c)
x_flattened = x[index]
# use LayerNorm to normalize this flatten feature (n, c)
x_flattened = super().forward(x_flattened)
# generate output
x = torch.zeros_like(x, dtype=x_flattened.dtype)
x[index] = x_flattened
elif data_format == 'channel_first':
index = SparseHelper._get_active_map_or_index(
H=x.shape[2], returning_active_map=False)
x_permuted = x.permute(0, 2, 3, 1)
# select the features on non-masked positions to form flatten
# features with shape (n, c)
x_flattened = x_permuted[index]
# use LayerNorm to normalize this flatten feature (n, c)
x_flattened = super().forward(x_flattened)
# generate output
x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
x[index] = x_flattened
x = x.permute(0, 3, 1, 2).contiguous()
else:
raise NotImplementedError
return x
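A tiny sketch of the masking contract: `sp_conv_forward` upsamples `_cur_active` to the input resolution and zeroes the masked positions of the convolution output.

```python
import torch
from mmpretrain.models.utils import SparseConv2d, SparseHelper

conv = SparseConv2d(3, 8, kernel_size=3, padding=1)
# a 2x2 feature-map mask; for an 8x8 input, each entry covers a 4x4 block
SparseHelper._cur_active = torch.tensor([[[[True, False], [False, True]]]])
out = conv(torch.randn(1, 3, 8, 8))
print(out[0, :, :4, 4:].abs().sum())  # tensor(0., ...): masked block is zeroed
```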

@@ -78,6 +78,7 @@ Import:
- configs/chinese_clip/metafile.yml
- configs/itpn/metafile.yml
- configs/hivit/metafile.yml
- configs/spark/metafile.yml
- configs/minigpt4/metafile.yml
- configs/llava/metafile.yml
- configs/otter/metafile.yml

@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import SparK
from mmpretrain.structures import DataSample
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_spark():
data_preprocessor = {
'mean': (123.675, 116.28, 103.53),
'std': (58.395, 57.12, 57.375),
'to_rgb': True
}
backbone = dict(
type='SparseResNet',
depth=50,
out_indices=(0, 1, 2, 3),
drop_path_rate=0.05,
norm_cfg=dict(type='BN'))
neck = dict(
type='SparKLightDecoder',
feature_dim=512,
upsample_ratio=32, # equal to downsample_raito
mid_channels=0,
norm_cfg=dict(type='BN'),
last_act=False)
head = dict(
type='SparKPretrainHead',
loss=dict(type='PixelReconstructionLoss', criterion='L2'))
alg = SparK(
backbone=backbone,
neck=neck,
head=head,
data_preprocessor=data_preprocessor,
enc_dec_norm_cfg=dict(type='BN'),
)
fake_data = {
'inputs': torch.randn((2, 3, 224, 224)),
'data_sample': [DataSample() for _ in range(2)]
}
fake_inputs = alg.data_preprocessor(fake_data)
fake_loss = alg(**fake_inputs, mode='loss')
assert isinstance(fake_loss['loss'].item(), float)