diff --git a/configs/spark/README.md b/configs/spark/README.md new file mode 100644 index 00000000..60f510e9 --- /dev/null +++ b/configs/spark/README.md @@ -0,0 +1,87 @@ +# SparK + +> [Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling](https://arxiv.org/abs/2301.03580) + + + +## Abstract + +We identify and overcome two key obstacles in extending the success of BERT-style pre-training, or the masked image modeling, to convolutional networks (convnets): (i) convolution operation cannot handle irregular, random-masked input images; (ii) the single-scale nature of BERT pre-training is inconsistent with convnet's hierarchical structure. For (i), we treat unmasked pixels as sparse voxels of 3D point clouds and use sparse convolution to encode. This is the first use of sparse convolution for 2D masked modeling. For (ii), we develop a hierarchical decoder to reconstruct images from multi-scale encoded features. Our method called Sparse masKed modeling (SparK) is general: it can be used directly on any convolutional model without backbone modifications. We validate it on both classical (ResNet) and modern (ConvNeXt) models: on three downstream tasks, it surpasses both state-of-the-art contrastive learning and transformer-based masked modeling by similarly large margins (around +1.0%). Improvements on object detection and instance segmentation are more substantial (up to +3.5%), verifying the strong transferability of features learned. We also find its favorable scaling behavior by observing more gains on larger models. All this evidence reveals a promising future of generative pre-training on convnets. Codes and models are released at https://github.com/keyu-tian/SparK. + +
+ +
+ +## How to use it? + + + +**Predict image** + +```python +from mmpretrain import inference_model + +predict = inference_model('resnet50_spark-pre_300e_in1k', 'demo/bird.JPEG') +print(predict['pred_class']) +print(predict['pred_score']) +``` + +**Use the model** + +```python +import torch +from mmpretrain import get_model + +model = get_model('spark_sparse-resnet50_800e_in1k', pretrained=True) +inputs = torch.rand(1, 3, 224, 224) +out = model(inputs) +print(type(out)) +# To extract features. +feats = model.extract_feat(inputs) +print(type(feats)) +``` + +**Train/Test Command** + +Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset). + +Train: + +```shell +python tools/train.py configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py +``` + +Test: + +```shell +python tools/test.py configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth +``` + + + +## Models and results + +### Pretrained models + +| Model | Params (M) | Flops (G) | Config | Download | +| :--------------------------------------- | :--------: | :-------: | :-------------------------------------------------------------------: | :----------------------------------------------------------------------: | +| `spark_sparse-resnet50_800e_in1k` | 37.97 | 4.10 | [config](spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.json) | +| `spark_sparse-convnextv2-tiny_800e_in1k` | 39.73 | 4.47 | [config](spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.json) | + +### Image Classification on ImageNet-1k + +| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Top-5 (%) | Config | Download | +| :------------------------------------ | :----------------------------------------: | :--------: | :-------: | :-------: | :-------: | :---------------------------------------: | :-----------------------------------------: | +| `resnet50_spark-pre_300e_in1k` | [SPARK](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth) | 23.52 | 1.31 | 80.10 | 94.90 | [config](benchmarks/resnet50_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.json) | +| `convnextv2-tiny_spark-pre_300e_in1k` | [SPARK](https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth) | 28.64 | 4.47 | 82.80 | 96.30 | [config](benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.pth) \| [log](https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.json) | + +## Citation + +```bibtex +@Article{tian2023designing, + author = {Keyu Tian and Yi Jiang and Qishuai Diao and Chen Lin and Liwei Wang and Zehuan Yuan}, + title = {Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling}, + journal = {arXiv:2301.03580}, + year = {2023}, +} +``` diff --git a/configs/spark/benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py b/configs/spark/benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py new file mode 100644 index 00000000..53eff37a --- /dev/null +++ b/configs/spark/benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py @@ -0,0 +1,122 @@ +_base_ = [ + '../../_base_/datasets/imagenet_bs64_swin_224.py', + '../../_base_/default_runtime.py', +] + +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='NumpyToPIL', to_rgb=True), + dict( + type='torchvision/TrivialAugmentWide', + num_magnitude_bins=31, + interpolation='bicubic', + fill=None), + dict(type='PILToNumpy', to_bgr=True), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type='PackInputs'), +] + +train_dataloader = dict( + dataset=dict(pipeline=train_pipeline), + sampler=dict(type='RepeatAugSampler', shuffle=True), +) + +# Model settings +model = dict( + type='ImageClassifier', + backbone=dict( + type='ConvNeXt', + arch='tiny', + drop_path_rate=0.1, + layer_scale_init_value=0., + use_grn=True, + ), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=768, + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + init_cfg=dict(type='TruncNormal', layer='Linear', std=.02, bias=0.), + ), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.8), + dict(type='CutMix', alpha=1.0), + ]), +) + +custom_hooks = [ + dict( + type='EMAHook', + momentum=1e-4, + evaluate_on_origin=True, + priority='ABOVE_NORMAL') +] + +# schedule settings +# optimizer +optim_wrapper = dict( + optimizer=dict( + type='AdamW', lr=3.2e-3, betas=(0.9, 0.999), weight_decay=0.05), + constructor='LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + layer_decay_rate=0.7, + norm_decay_mult=0.0, + bias_decay_mult=0.0, + flat_decay_mult=0.0)) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=0.0001, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type='CosineAnnealingLR', + T_max=280, + eta_min=1.0e-5, + by_epoch=True, + begin=20, + end=300) +] +train_cfg = dict(by_epoch=True, max_epochs=300) +val_cfg = dict() +test_cfg = dict() + +default_hooks = dict( + # only keeps the latest 2 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=2048) diff --git a/configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py b/configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py new file mode 100644 index 00000000..c82bcf6b --- /dev/null +++ b/configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py @@ -0,0 +1,107 @@ +_base_ = [ + '../../_base_/models/resnet50.py', + '../../_base_/datasets/imagenet_bs256_rsb_a12.py', + '../../_base_/default_runtime.py' +] +# modification is based on ResNets RSB settings +data_preprocessor = dict( + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True, +) + +bgr_mean = data_preprocessor['mean'][::-1] +bgr_std = data_preprocessor['std'][::-1] + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + scale=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', prob=0.5, direction='horizontal'), + dict(type='NumpyToPIL', to_rgb=True), + dict( + type='torchvision/TrivialAugmentWide', + num_magnitude_bins=31, + interpolation='bicubic', + fill=None), + dict(type='PILToNumpy', to_bgr=True), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=bgr_mean, + fill_std=bgr_std), + dict(type='PackInputs'), +] +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +# model settings +model = dict( + backbone=dict( + norm_cfg=dict(type='SyncBN', requires_grad=True), + drop_path_rate=0.05, + ), + head=dict( + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, use_sigmoid=True)), + train_cfg=dict(augments=[ + dict(type='Mixup', alpha=0.1), + dict(type='CutMix', alpha=1.0) + ])) + +# schedule settings +# optimizer +optim_wrapper = dict( + optimizer=dict( + type='Lamb', + lr=0.016, + weight_decay=0.02, + ), + constructor='LearningRateDecayOptimWrapperConstructor', + paramwise_cfg=dict( + layer_decay_rate=0.7, + norm_decay_mult=0.0, + bias_decay_mult=0.0, + flat_decay_mult=0.0)) + +# learning policy +param_scheduler = [ + # warm up learning rate scheduler + dict( + type='LinearLR', + start_factor=0.0001, + by_epoch=True, + begin=0, + end=5, + # update by iter + convert_to_iter_based=True), + # main learning rate scheduler + dict( + type='CosineAnnealingLR', + T_max=295, + eta_min=1.0e-6, + by_epoch=True, + begin=5, + end=300) +] +train_cfg = dict(by_epoch=True, max_epochs=300) +val_cfg = dict() +test_cfg = dict() + +default_hooks = dict( + # only keeps the latest 2 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) +# randomness +randomness = dict(seed=0, diff_rank_seed=True) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# based on the actual training batch size. +auto_scale_lr = dict(base_batch_size=2048) diff --git a/configs/spark/metafile.yml b/configs/spark/metafile.yml new file mode 100644 index 00000000..81ca3a70 --- /dev/null +++ b/configs/spark/metafile.yml @@ -0,0 +1,73 @@ +Collections: + - Name: SparK + Metadata: + Architecture: + - Dense Connections + - GELU + - Layer Normalization + - Multi-Head Attention + - Scaled Dot-Product Attention + Paper: + Title: 'Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling' + URL: https://arxiv.org/abs/2301.03580 + README: configs/spark/README.md + Code: + URL: null + Version: null + +Models: + - Name: spark_sparse-resnet50_800e_in1k + Metadata: + FLOPs: 4100000000 + Parameters: 37971000 + Training Data: + - ImageNet-1k + In Collection: SparK + Results: null + Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k_20230612-e403c28f.pth + Config: configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py + Downstream: + - resnet50_spark-pre_300e_in1k + - Name: resnet50_spark-pre_300e_in1k + Metadata: + FLOPs: 1310000000 + Parameters: 23520000 + Training Data: + - ImageNet-1k + In Collection: SparK + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 80.1 + Top 5 Accuracy: 94.9 + Task: Image Classification + Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k/resnet50_8xb256-coslr-300e_in1k/resnet50_8xb256-coslr-300e_in1k_20230612-f86aab51.pth + Config: configs/spark/benchmarks/resnet50_8xb256-coslr-300e_in1k.py + + - Name: spark_sparse-convnextv2-tiny_800e_in1k + Metadata: + FLOPs: 4470000000 + Parameters: 39732000 + Training Data: + - ImageNet-1k + In Collection: SparK + Results: null + Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k_20230612-b0ea712e.pth + Config: configs/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py + Downstream: + - convnextv2-tiny_spark-pre_300e_in1k + - Name: convnextv2-tiny_spark-pre_300e_in1k + Metadata: + FLOPs: 4469631744 + Parameters: 28635496 + Training Data: + - ImageNet-1k + In Collection: SparK + Results: + - Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 82.8 + Top 5 Accuracy: 96.3 + Task: Image Classification + Weights: https://download.openmmlab.com/mmpretrain/v1.0/spark//spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k/convnextv2-tiny_8xb256-coslr-300e_in1k_20230612-ffc78743.pth + Config: configs/spark/benchmarks/convnextv2-tiny_8xb256-coslr-300e_in1k.py diff --git a/configs/spark/spark_sparse-convnext-small_16xb256-amp-coslr-800e_in1k.py b/configs/spark/spark_sparse-convnext-small_16xb256-amp-coslr-800e_in1k.py new file mode 100644 index 00000000..5cefb5b9 --- /dev/null +++ b/configs/spark/spark_sparse-convnext-small_16xb256-amp-coslr-800e_in1k.py @@ -0,0 +1,81 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs512_mae.py', + '../_base_/default_runtime.py', +] + +# dataset 8 x 512 +train_dataloader = dict(batch_size=256, num_workers=8) + +# model settings +model = dict( + type='SparK', + input_size=224, + downsample_raito=32, + mask_ratio=0.6, + enc_dec_norm_cfg=dict(type='SparseLN2d', eps=1e-6), + enc_dec_norm_dim=768, + backbone=dict( + type='SparseConvNeXt', + arch='small', + drop_path_rate=0.2, + out_indices=(0, 1, 2, 3), + gap_before_output=False), + neck=dict( + type='SparKLightDecoder', + feature_dim=512, + upsample_ratio=32, # equal to downsample_raito + mid_channels=0, + last_act=False), + head=dict( + type='SparKPretrainHead', + loss=dict(type='PixelReconstructionLoss', criterion='L2'))) + +# optimizer wrapper +optimizer = dict( + type='Lamb', lr=2e-4 * 4096 / 512, betas=(0.9, 0.95), weight_decay=0.04) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=optimizer, + clip_grad=dict(max_norm=5.0), + paramwise_cfg=dict( + bias_decay_mult=0.0, + flat_decay_mult=0.0, + custom_keys={ + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True), + dict( + type='CosineAnnealingWeightDecay', + eta_min=0.2, + T_max=800, + by_epoch=True, + begin=0, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800) +default_hooks = dict( + logger=dict(type='LoggerHook', interval=100), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) + +# randomness +randomness = dict(seed=0, diff_rank_seed=True) diff --git a/configs/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py b/configs/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py new file mode 100644 index 00000000..3a1afc80 --- /dev/null +++ b/configs/spark/spark_sparse-convnextv2-tiny_16xb256-amp-coslr-800e_in1k.py @@ -0,0 +1,84 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs512_mae.py', + '../_base_/default_runtime.py', +] + +# dataset 16 x 256 +train_dataloader = dict(batch_size=256, num_workers=8) + +# model settings, use ConvNeXt V2 +model = dict( + type='SparK', + input_size=224, + downsample_raito=32, + mask_ratio=0.6, + enc_dec_norm_cfg=dict(type='SparseLN2d', eps=1e-6), + enc_dec_norm_dim=768, + backbone=dict( + type='SparseConvNeXt', + arch='tiny', + drop_path_rate=0.2, + out_indices=(0, 1, 2, 3), + gap_before_output=False, + layer_scale_init_value=0., + use_grn=True, + ), + neck=dict( + type='SparKLightDecoder', + feature_dim=512, + upsample_ratio=32, # equal to downsample_raito + mid_channels=0, + last_act=False), + head=dict( + type='SparKPretrainHead', + loss=dict(type='PixelReconstructionLoss', criterion='L2'))) + +# optimizer wrapper +optimizer = dict( + type='Lamb', lr=2e-4 * 4096 / 512, betas=(0.9, 0.95), weight_decay=0.04) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=optimizer, + clip_grad=dict(max_norm=5.0), + paramwise_cfg=dict( + bias_decay_mult=0.0, + flat_decay_mult=0.0, + custom_keys={ + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=20, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=780, + by_epoch=True, + begin=20, + end=800, + convert_to_iter_based=True), + dict( + type='CosineAnnealingWeightDecay', + eta_min=0.2, + T_max=800, + by_epoch=True, + begin=0, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800) +default_hooks = dict( + logger=dict(type='LoggerHook', interval=100), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) + +# randomness +randomness = dict(seed=0, diff_rank_seed=True) diff --git a/configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-1600e_in1k.py b/configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-1600e_in1k.py new file mode 100644 index 00000000..10fc6757 --- /dev/null +++ b/configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-1600e_in1k.py @@ -0,0 +1,30 @@ +_base_ = 'spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py' + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=1560, + by_epoch=True, + begin=40, + end=1600, + convert_to_iter_based=True), + dict( + type='CosineAnnealingWeightDecay', + eta_min=0.2, + T_max=1600, + by_epoch=True, + begin=0, + end=1600, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(max_epochs=1600) diff --git a/configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py b/configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py new file mode 100644 index 00000000..864f6162 --- /dev/null +++ b/configs/spark/spark_sparse-resnet50_8xb512-amp-coslr-800e_in1k.py @@ -0,0 +1,80 @@ +_base_ = [ + '../_base_/datasets/imagenet_bs512_mae.py', + '../_base_/default_runtime.py', +] + +# dataset 8 x 512 +train_dataloader = dict(batch_size=512, num_workers=8) + +# model settings +model = dict( + type='SparK', + input_size=224, + downsample_raito=32, + mask_ratio=0.6, + enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'), + enc_dec_norm_dim=2048, + backbone=dict( + type='SparseResNet', + depth=50, + out_indices=(0, 1, 2, 3), + drop_path_rate=0.05), + neck=dict( + type='SparKLightDecoder', + feature_dim=512, + upsample_ratio=32, # equal to downsample_raito + mid_channels=0, + last_act=False), + head=dict( + type='SparKPretrainHead', + loss=dict(type='PixelReconstructionLoss', criterion='L2'))) + +# optimizer wrapper +optimizer = dict( + type='Lamb', lr=2e-4 * 4096 / 512, betas=(0.9, 0.95), weight_decay=0.04) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=optimizer, + clip_grad=dict(max_norm=5.0), + paramwise_cfg=dict( + bias_decay_mult=0.0, + flat_decay_mult=0.0, + custom_keys={ + 'mask_token': dict(decay_mult=0.), + })) + +# learning rate scheduler +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1e-4, + by_epoch=True, + begin=0, + end=40, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=760, + by_epoch=True, + begin=40, + end=800, + convert_to_iter_based=True), + dict( + type='CosineAnnealingWeightDecay', + eta_min=0.2, + T_max=800, + by_epoch=True, + begin=0, + end=800, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800) +default_hooks = dict( + logger=dict(type='LoggerHook', interval=100), + # only keeps the latest 3 checkpoints + checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=2)) + +# randomness +randomness = dict(seed=0, diff_rank_seed=True) diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst index e7c4fce6..93e3e841 100644 --- a/docs/en/api/models.rst +++ b/docs/en/api/models.rst @@ -76,6 +76,7 @@ Self-supervised Algorithms SimCLR SimMIM SimSiam + SparK SwAV .. _selfsup_backbones: @@ -205,6 +206,8 @@ Backbones SVT ShuffleNetV1 ShuffleNetV2 + SparseResNet + SparseConvNeXt SwinTransformer SwinTransformerV2 T2T_ViT @@ -243,6 +246,7 @@ Necks SimMIMLinearDecoder SwAVNeck iTPNPretrainDecoder + SparKLightDecoder .. module:: mmpretrain.models.heads @@ -280,6 +284,7 @@ Heads VigClsHead VisionTransformerClsHead iTPNClipHead + SparKPretrainHead .. module:: mmpretrain.models.losses diff --git a/mmpretrain/engine/__init__.py b/mmpretrain/engine/__init__.py index 7785da7b..332fea09 100644 --- a/mmpretrain/engine/__init__.py +++ b/mmpretrain/engine/__init__.py @@ -2,3 +2,4 @@ from .hooks import * # noqa: F401, F403 from .optimizers import * # noqa: F401, F403 from .runners import * # noqa: F401, F403 +from .schedulers import * # noqa: F401, F403 diff --git a/mmpretrain/engine/schedulers/__init__.py b/mmpretrain/engine/schedulers/__init__.py new file mode 100644 index 00000000..68b6a547 --- /dev/null +++ b/mmpretrain/engine/schedulers/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .weight_decay_scheduler import CosineAnnealingWeightDecay + +__all__ = ['CosineAnnealingWeightDecay'] diff --git a/mmpretrain/engine/schedulers/weight_decay_scheduler.py b/mmpretrain/engine/schedulers/weight_decay_scheduler.py new file mode 100644 index 00000000..7e725a4c --- /dev/null +++ b/mmpretrain/engine/schedulers/weight_decay_scheduler.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +from mmengine.optim.scheduler import CosineAnnealingParamScheduler + +from mmpretrain.registry import PARAM_SCHEDULERS + + +class WeightDecaySchedulerMixin: + """A mixin class for learning rate schedulers.""" + + def __init__(self, optimizer, *args, **kwargs): + super().__init__(optimizer, 'weight_decay', *args, **kwargs) + + +@PARAM_SCHEDULERS.register_module() +class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin, + CosineAnnealingParamScheduler): + """Set the weight decay value of each parameter group using a cosine + annealing schedule. + + If the weight decay was set to be 0 initially, the weight decay value will + be 0 constantly during the training. + """ + + def _get_value(self) -> list: + """Compute value using chainable form of the scheduler.""" + + def _get_eta_min(base_value): + if self.eta_min_ratio is None: + return self.eta_min + return base_value * self.eta_min_ratio + + if self.last_step == 0: + return [ + group[self.param_name] for group in self.optimizer.param_groups + ] + elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0: + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = group[self.param_name] + ( + base_value - _get_eta_min(base_value)) * ( + 1 - math.cos(math.pi / self.T_max)) / 2 + weight_decay_value_list.append(group_value) + return weight_decay_value_list + + weight_decay_value_list = [] + for base_value, group in zip(self.base_values, + self.optimizer.param_groups): + if base_value == 0: + group_value = 0 + else: + group_value = ( + 1 + math.cos(math.pi * self.last_step / self.T_max)) / ( + 1 + math.cos(math.pi * + (self.last_step - 1) / self.T_max) + ) * (group[self.param_name] - + _get_eta_min(base_value)) + _get_eta_min(base_value) + weight_decay_value_list.append(group_value) + return weight_decay_value_list diff --git a/mmpretrain/models/backbones/__init__.py b/mmpretrain/models/backbones/__init__.py index 3da72c1d..60e37fb7 100644 --- a/mmpretrain/models/backbones/__init__.py +++ b/mmpretrain/models/backbones/__init__.py @@ -42,6 +42,8 @@ from .seresnet import SEResNet from .seresnext import SEResNeXt from .shufflenet_v1 import ShuffleNetV1 from .shufflenet_v2 import ShuffleNetV2 +from .sparse_convnext import SparseConvNeXt +from .sparse_resnet import SparseResNet from .swin_transformer import SwinTransformer from .swin_transformer_v2 import SwinTransformerV2 from .t2t_vit import T2T_ViT @@ -122,4 +124,6 @@ __all__ = [ 'ViTSAM', 'ViTEVA02', 'HiViT', + 'SparseResNet', + 'SparseConvNeXt', ] diff --git a/mmpretrain/models/backbones/convnext.py b/mmpretrain/models/backbones/convnext.py index f9c29cf2..6a954f5b 100644 --- a/mmpretrain/models/backbones/convnext.py +++ b/mmpretrain/models/backbones/convnext.py @@ -366,3 +366,47 @@ class ConvNeXt(BaseBackbone): def train(self, mode=True): super(ConvNeXt, self).train(mode) self._freeze_stages() + + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer-wise depth of a parameter. + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + """ + + max_layer_id = 12 if self.depths[-2] > 9 else 6 + + if not param_name.startswith(prefix): + # For subsequent module like head + return max_layer_id + 1, max_layer_id + 2 + + param_name = param_name[len(prefix):] + if param_name.startswith('downsample_layers'): + stage_id = int(param_name.split('.')[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + else: # stage_id == 3: + layer_id = max_layer_id + + elif param_name.startswith('stages'): + stage_id = int(param_name.split('.')[1]) + block_id = int(param_name.split('.')[2]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + else: # stage_id == 3: + layer_id = max_layer_id + + # final norm layer + else: + layer_id = max_layer_id + 1 + + return layer_id, max_layer_id + 2 diff --git a/mmpretrain/models/backbones/resnet.py b/mmpretrain/models/backbones/resnet.py index e4df601d..4a254f7c 100644 --- a/mmpretrain/models/backbones/resnet.py +++ b/mmpretrain/models/backbones/resnet.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math + import torch import torch.nn as nn import torch.utils.checkpoint as cp @@ -674,6 +676,64 @@ class ResNet(BaseBackbone): if isinstance(m, _BatchNorm): m.eval() + def get_layer_depth(self, param_name: str, prefix: str = ''): + """Get the layer id to set the different learning rates for ResNet. + + ResNet stages: + 50 : [3, 4, 6, 3] + 101 : [3, 4, 23, 3] + 152 : [3, 8, 36, 3] + 200 : [3, 24, 36, 3] + eca269d: [3, 30, 48, 8] + + Args: + param_name (str): The name of the parameter. + prefix (str): The prefix for the parameter. + Defaults to an empty string. + + Returns: + Tuple[int, int]: The layer-wise depth and the num of layers. + """ + depths = self.stage_blocks + if depths[1] == 4 and depths[2] == 6: + blk2, blk3 = 2, 3 + elif depths[1] == 4 and depths[2] == 23: + blk2, blk3 = 2, 3 + elif depths[1] == 8 and depths[2] == 36: + blk2, blk3 = 4, 4 + elif depths[1] == 24 and depths[2] == 36: + blk2, blk3 = 4, 4 + elif depths[1] == 30 and depths[2] == 48: + blk2, blk3 = 5, 6 + else: + raise NotImplementedError + + N2, N3 = math.ceil(depths[1] / blk2 - + 1e-5), math.ceil(depths[2] / blk3 - 1e-5) + N = 2 + N2 + N3 # r50: 2 + 2 + 2 = 6 + max_layer_id = N + 1 # r50: 2 + 2 + 2 + 1(like head) = 7 + + if not param_name.startswith(prefix): + # For subsequent module like head + return max_layer_id, max_layer_id + 1 + + if param_name.startswith('backbone.layer'): + stage_id = int(param_name.split('.')[1][5:]) + block_id = int(param_name.split('.')[2]) + + if stage_id == 1: + layer_id = 1 + elif stage_id == 2: + layer_id = 2 + block_id // blk2 # r50: 2, 3 + elif stage_id == 3: + layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5 + else: # stage_id == 4 + layer_id = N # r50: 6 + return layer_id, max_layer_id + 1 + + else: + return 0, max_layer_id + 1 + @MODELS.register_module() class ResNetV1c(ResNet): diff --git a/mmpretrain/models/backbones/sparse_convnext.py b/mmpretrain/models/backbones/sparse_convnext.py new file mode 100644 index 00000000..8f361360 --- /dev/null +++ b/mmpretrain/models/backbones/sparse_convnext.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Sequence, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmengine.model import ModuleList, Sequential + +from mmpretrain.registry import MODELS +from ..utils import (SparseAvgPooling, SparseConv2d, SparseHelper, + SparseMaxPooling, build_norm_layer) +from .convnext import ConvNeXt, ConvNeXtBlock + + +class SparseConvNeXtBlock(ConvNeXtBlock): + """Sparse ConvNeXt Block. + + Note: + There are two equivalent implementations: + 1. DwConv -> SparseLayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv; + all outputs are in (N, C, H, W). + 2. DwConv -> SparseLayerNorm -> Permute to (N, H, W, C) -> Linear -> + GELU -> Linear; Permute back + As default, we use the second to align with the official repository. + And it may be slightly faster. + """ + + def forward(self, x): + + def _inner_forward(x): + shortcut = x + x = self.depthwise_conv(x) + + if self.linear_pw_conv: + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x, data_format='channel_last') + x = self.pointwise_conv1(x) + x = self.act(x) + if self.grn is not None: + x = self.grn(x, data_format='channel_last') + x = self.pointwise_conv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + else: + x = self.norm(x, data_format='channel_first') + x = self.pointwise_conv1(x) + x = self.act(x) + + if self.grn is not None: + x = self.grn(x, data_format='channel_first') + x = self.pointwise_conv2(x) + + if self.gamma is not None: + x = x.mul(self.gamma.view(1, -1, 1, 1)) + + x *= SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=True) + + x = shortcut + self.drop_path(x) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +@MODELS.register_module() +class SparseConvNeXt(ConvNeXt): + """ConvNeXt with sparse module conversion function. + + Modified from + https://github.com/keyu-tian/SparK/blob/main/models/convnext.py + and + https://github.com/keyu-tian/SparK/blob/main/encoder.py + To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``. + + Args: + arch (str | dict): The model's architecture. If string, it should be + one of architecture in ``ConvNeXt.arch_settings``. And if dict, it + should include the following two keys: + - depths (list[int]): Number of blocks at each stage. + - channels (list[int]): The number of channels at each stage. + Defaults to 'tiny'. + in_channels (int): Number of input image channels. Defaults to 3. + stem_patch_size (int): The size of one patch in the stem layer. + Defaults to 4. + norm_cfg (dict): The config dict for norm layers. + Defaults to ``dict(type='SparseLN2d', eps=1e-6)``. + act_cfg (dict): The config dict for activation between pointwise + convolution. Defaults to ``dict(type='GELU')``. + linear_pw_conv (bool): Whether to use linear layer to do pointwise + convolution. Defaults to True. + use_grn (bool): Whether to add Global Response Normalization in the + blocks. Defaults to False. + drop_path_rate (float): Stochastic depth rate. Defaults to 0. + layer_scale_init_value (float): Init value for Layer Scale. + Defaults to 1e-6. + out_indices (Sequence | int): Output from which stages. + Defaults to -1, means the last stage. + frozen_stages (int): Stages to be frozen (all param fixed). + Defaults to 0, which means not freezing any parameters. + gap_before_output (bool): Whether to globally average the feature + map before the final norm layer. In the official repo, it's only + used in classification task. Defaults to True. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + init_cfg (dict, optional): Initialization config dict. + """ # noqa: E501 + + def __init__(self, + arch: str = 'small', + in_channels: int = 3, + stem_patch_size: int = 4, + norm_cfg: dict = dict(type='SparseLN2d', eps=1e-6), + act_cfg: dict = dict(type='GELU'), + linear_pw_conv: bool = True, + use_grn: bool = False, + drop_path_rate: float = 0, + layer_scale_init_value: float = 1e-6, + out_indices: int = -1, + frozen_stages: int = 0, + gap_before_output: bool = True, + with_cp: bool = False, + init_cfg: Optional[Union[dict, List[dict]]] = [ + dict( + type='TruncNormal', + layer=['Conv2d', 'Linear'], + std=.02, + bias=0.), + dict( + type='Constant', layer=['LayerNorm'], val=1., + bias=0.), + ]): + super(ConvNeXt, self).__init__(init_cfg=init_cfg) + + if isinstance(arch, str): + assert arch in self.arch_settings, \ + f'Unavailable arch, please choose from ' \ + f'({set(self.arch_settings)}) or pass a dict.' + arch = self.arch_settings[arch] + elif isinstance(arch, dict): + assert 'depths' in arch and 'channels' in arch, \ + f'The arch dict must have "depths" and "channels", ' \ + f'but got {list(arch.keys())}.' + + self.depths = arch['depths'] + self.channels = arch['channels'] + assert (isinstance(self.depths, Sequence) + and isinstance(self.channels, Sequence) + and len(self.depths) == len(self.channels)), \ + f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \ + 'should be both sequence with the same length.' + + self.num_stages = len(self.depths) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = 4 + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.out_indices = out_indices + + self.frozen_stages = frozen_stages + self.gap_before_output = gap_before_output + + # 4 downsample layers between stages, including the stem layer. + self.downsample_layers = ModuleList() + stem = nn.Sequential( + nn.Conv2d( + in_channels, + self.channels[0], + kernel_size=stem_patch_size, + stride=stem_patch_size), + build_norm_layer(norm_cfg, self.channels[0]), + ) + self.downsample_layers.append(stem) + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(self.depths)) + ] + block_idx = 0 + + # 4 feature resolution stages, each consisting of multiple residual + # blocks + self.stages = nn.ModuleList() + for i in range(self.num_stages): + depth = self.depths[i] + channels = self.channels[i] + + if i >= 1: + downsample_layer = nn.Sequential( + build_norm_layer(norm_cfg, self.channels[i - 1]), + nn.Conv2d( + self.channels[i - 1], + channels, + kernel_size=2, + stride=2), + ) + self.downsample_layers.append(downsample_layer) + + stage = Sequential(*[ + SparseConvNeXtBlock( + in_channels=channels, + drop_path_rate=dpr[block_idx + j], + norm_cfg=norm_cfg, + act_cfg=act_cfg, + linear_pw_conv=linear_pw_conv, + layer_scale_init_value=layer_scale_init_value, + use_grn=use_grn, + with_cp=with_cp) for j in range(depth) + ]) + block_idx += depth + + self.stages.append(stage) + + self.dense_model_to_sparse(m=self) + + def forward(self, x): + outs = [] + for i, stage in enumerate(self.stages): + x = self.downsample_layers[i](x) + x = stage(x) + if i in self.out_indices: + if self.gap_before_output: + gap = x.mean([-2, -1], keepdim=True) + outs.append(gap.flatten(1)) + else: + outs.append(x) + + return tuple(outs) + + def dense_model_to_sparse(self, m: nn.Module) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + # elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + # m: nn.BatchNorm2d + # output = (SparseSyncBatchNorm2d + # if enable_sync_bn else SparseBatchNorm2d)( + # m.weight.shape[0], + # eps=m.eps, + # momentum=m.momentum, + # affine=m.affine, + # track_running_stats=m.track_running_stats) + # output.weight.data.copy_(m.weight.data) + # output.bias.data.copy_(m.bias.data) + # output.running_mean.data.copy_(m.running_mean.data) + # output.running_var.data.copy_(m.running_var.data) + # output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + for name, child in m.named_children(): + output.add_module(name, self.dense_model_to_sparse(child)) + del m + return output diff --git a/mmpretrain/models/backbones/sparse_resnet.py b/mmpretrain/models/backbones/sparse_resnet.py new file mode 100644 index 00000000..67597f1f --- /dev/null +++ b/mmpretrain/models/backbones/sparse_resnet.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import re +from typing import Optional, Tuple + +import torch.nn as nn + +from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling, + SparseBatchNorm2d, + SparseConv2d, + SparseMaxPooling, + SparseSyncBatchNorm2d) +from mmpretrain.registry import MODELS +from .resnet import ResNet + + +@MODELS.register_module() +class SparseResNet(ResNet): + """ResNet with sparse module conversion function. + + Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py + + Args: + depth (int): Network depth, from {18, 34, 50, 101, 152}. + in_channels (int): Number of input image channels. Defaults to 3. + stem_channels (int): Output channels of the stem layer. Defaults to 64. + base_channels (int): Middle channels of the first stage. + Defaults to 64. + num_stages (int): Stages of the network. Defaults to 4. + strides (Sequence[int]): Strides of the first block of each stage. + Defaults to ``(1, 2, 2, 2)``. + dilations (Sequence[int]): Dilation of each stage. + Defaults to ``(1, 1, 1, 1)``. + out_indices (Sequence[int]): Output from which stages. + Defaults to ``(3, )``. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv. + Defaults to False. + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. Defaults to False. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. Defaults to -1. + conv_cfg (dict | None): The config dict for conv layers. + Defaults to None. + norm_cfg (dict): The config dict for norm layers. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. Defaults to False. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. Defaults to False. + zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. Defaults to True. + drop_path_rate (float): stochastic depth rate. Defaults to 0. + """ + + def __init__(self, + depth: int, + in_channels: int = 3, + stem_channels: int = 64, + base_channels: int = 64, + expansion: Optional[int] = None, + num_stages: int = 4, + strides: Tuple[int] = (1, 2, 2, 2), + dilations: Tuple[int] = (1, 1, 1, 1), + out_indices: Tuple[int] = (3, ), + style: str = 'pytorch', + deep_stem: bool = False, + avg_down: bool = False, + frozen_stages: int = -1, + conv_cfg: Optional[dict] = None, + norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'), + norm_eval: bool = False, + with_cp: bool = False, + zero_init_residual: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ], + drop_path_rate: float = 0, + **kwargs): + super().__init__( + depth=depth, + in_channels=in_channels, + stem_channels=stem_channels, + base_channels=base_channels, + expansion=expansion, + num_stages=num_stages, + strides=strides, + dilations=dilations, + out_indices=out_indices, + style=style, + deep_stem=deep_stem, + avg_down=avg_down, + frozen_stages=frozen_stages, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + norm_eval=norm_eval, + with_cp=with_cp, + zero_init_residual=zero_init_residual, + init_cfg=init_cfg, + drop_path_rate=drop_path_rate, + **kwargs) + norm_type = norm_cfg['type'] + enable_sync_bn = False + if re.search('Sync', norm_type) is not None: + enable_sync_bn = True + self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn) + + def dense_model_to_sparse(self, m: nn.Module, + enable_sync_bn: bool) -> nn.Module: + """Convert regular dense modules to sparse modules.""" + output = m + if isinstance(m, nn.Conv2d): + m: nn.Conv2d + bias = m.bias is not None + output = SparseConv2d( + m.in_channels, + m.out_channels, + kernel_size=m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + groups=m.groups, + bias=bias, + padding_mode=m.padding_mode, + ) + output.weight.data.copy_(m.weight.data) + if bias: + output.bias.data.copy_(m.bias.data) + + elif isinstance(m, nn.MaxPool2d): + m: nn.MaxPool2d + output = SparseMaxPooling( + m.kernel_size, + stride=m.stride, + padding=m.padding, + dilation=m.dilation, + return_indices=m.return_indices, + ceil_mode=m.ceil_mode) + + elif isinstance(m, nn.AvgPool2d): + m: nn.AvgPool2d + output = SparseAvgPooling( + m.kernel_size, + m.stride, + m.padding, + ceil_mode=m.ceil_mode, + count_include_pad=m.count_include_pad, + divisor_override=m.divisor_override) + + elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)): + m: nn.BatchNorm2d + output = (SparseSyncBatchNorm2d + if enable_sync_bn else SparseBatchNorm2d)( + m.weight.shape[0], + eps=m.eps, + momentum=m.momentum, + affine=m.affine, + track_running_stats=m.track_running_stats) + output.weight.data.copy_(m.weight.data) + output.bias.data.copy_(m.bias.data) + output.running_mean.data.copy_(m.running_mean.data) + output.running_var.data.copy_(m.running_var.data) + output.num_batches_tracked.data.copy_(m.num_batches_tracked.data) + + elif isinstance(m, (nn.Conv1d, )): + raise NotImplementedError + + for name, child in m.named_children(): + output.add_module( + name, + self.dense_model_to_sparse( + child, enable_sync_bn=enable_sync_bn)) + del m + return output diff --git a/mmpretrain/models/heads/__init__.py b/mmpretrain/models/heads/__init__.py index 899dbaa4..4364fb56 100644 --- a/mmpretrain/models/heads/__init__.py +++ b/mmpretrain/models/heads/__init__.py @@ -25,6 +25,7 @@ from .multi_label_linear_head import MultiLabelLinearClsHead from .multi_task_head import MultiTaskHead from .seq_gen_head import SeqGenerationHead from .simmim_head import SimMIMHead +from .spark_head import SparKPretrainHead from .stacked_head import StackedLinearClsHead from .swav_head import SwAVHead from .vig_head import VigClsHead @@ -64,4 +65,5 @@ __all__ = [ 'ITMHead', 'GroundingHead', 'iTPNClipHead', + 'SparKPretrainHead', ] diff --git a/mmpretrain/models/heads/spark_head.py b/mmpretrain/models/heads/spark_head.py new file mode 100644 index 00000000..a2748762 --- /dev/null +++ b/mmpretrain/models/heads/spark_head.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS + + +@MODELS.register_module() +class SparKPretrainHead(BaseModule): + """Pre-training head for SparK. + + Args: + loss (dict): Config of loss. + norm_pix (bool): Whether or not normalize target. Defaults to True. + patch_size (int): Patch size, equal to downsample ratio of backbone. + Defaults to 32. + """ + + def __init__(self, + loss: dict, + norm_pix: bool = True, + patch_size: int = 32) -> None: + super().__init__() + self.norm_pix = norm_pix + self.patch_size = patch_size + self.loss = MODELS.build(loss) + + def patchify(self, imgs): + """Split images into non-overlapped patches. + + Args: + imgs (torch.Tensor): A batch of images, of shape B x C x H x W. + Returns: + torch.Tensor: Patchified images. The shape is B x L x D. + """ + p = self.patch_size + assert len(imgs.shape + ) == 4 and imgs.shape[2] % p == 0 and imgs.shape[3] % p == 0 + + B, C, ori_h, ori_w = imgs.shape + h = ori_h // p + w = ori_w // p + x = imgs.reshape(shape=(B, C, h, p, w, p)) + x = torch.einsum('bchpwq->bhwpqc', x) + + # (B, f*f, downsample_raito*downsample_raito*3) + x = x.reshape(shape=(B, h * w, p**2 * C)) + return x + + def construct_target(self, target: torch.Tensor) -> torch.Tensor: + """Construct the reconstruction target. + + In addition to splitting images into tokens, this module will also + normalize the image according to ``norm_pix``. + Args: + target (torch.Tensor): Image with the shape of B x 3 x H x W + Returns: + torch.Tensor: Tokenized images with the shape of B x L x C + """ + target = self.patchify(target) + if self.norm_pix: + # normalize the target image + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + return target + + def forward(self, pred: torch.Tensor, target: torch.Tensor, + active_mask: torch.Tensor) -> torch.Tensor: + """Forward function of MAE head. + + Args: + pred (torch.Tensor): The reconstructed image. + target (torch.Tensor): The target image. + active_mask (torch.Tensor): The mask of the target image. + Returns: + torch.Tensor: The reconstruction loss. + """ + # (B, C, H, W) -> (B, L, C) and perform normalization + target = self.construct_target(target) + + # (B, C, H, W) -> (B, L, C) + pred = self.patchify(pred) + + # (B, 1, f, f) -> (B, L) + non_active_mask = active_mask.logical_not().int().view( + active_mask.shape[0], -1) + + # MSE loss on masked patches + loss = self.loss(pred, target, non_active_mask) + return loss diff --git a/mmpretrain/models/necks/__init__.py b/mmpretrain/models/necks/__init__.py index 91300e7e..2952a691 100644 --- a/mmpretrain/models/necks/__init__.py +++ b/mmpretrain/models/necks/__init__.py @@ -13,6 +13,7 @@ from .mixmim_neck import MixMIMPretrainDecoder from .mocov2_neck import MoCoV2Neck from .nonlinear_neck import NonLinearNeck from .simmim_neck import SimMIMLinearDecoder +from .spark_neck import SparKLightDecoder from .swav_neck import SwAVNeck __all__ = [ @@ -32,4 +33,5 @@ __all__ = [ 'SimMIMLinearDecoder', 'SwAVNeck', 'iTPNPretrainDecoder', + 'SparKLightDecoder', ] diff --git a/mmpretrain/models/necks/spark_neck.py b/mmpretrain/models/necks/spark_neck.py new file mode 100644 index 00000000..ac129da3 --- /dev/null +++ b/mmpretrain/models/necks/spark_neck.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmpretrain.registry import MODELS +from ..utils import build_norm_layer + + +def is_pow2n(x): + return x > 0 and (x & (x - 1) == 0) + + +class ConvBlock2x(BaseModule): + """The definition of convolution block.""" + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: int, + norm_cfg: dict, + act_cfg: dict, + last_act: bool, + init_cfg: Optional[dict] = None) -> None: + super().__init__(init_cfg=init_cfg) + + self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False) + self.norm1 = build_norm_layer(norm_cfg, mid_channels) + self.activate1 = MODELS.build(act_cfg) + + self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False) + self.norm2 = build_norm_layer(norm_cfg, out_channels) + self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity() + + def forward(self, x: torch.Tensor): + out = self.conv1(x) + out = self.norm1(out) + out = self.activate1(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.activate2(out) + return out + + +class DecoderConvModule(BaseModule): + """The convolution module of decoder with upsampling.""" + + def __init__(self, + in_channels: int, + out_channels: int, + mid_channels: int, + kernel_size: int = 4, + scale_factor: int = 2, + num_conv_blocks: int = 1, + norm_cfg: dict = dict(type='SyncBN'), + act_cfg: dict = dict(type='ReLU6'), + last_act: bool = True, + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + assert (kernel_size - scale_factor >= 0) and\ + (kernel_size - scale_factor) % 2 == 0,\ + f'kernel_size should be greater than or equal to scale_factor '\ + f'and (kernel_size - scale_factor) should be even numbers, '\ + f'while the kernel size is {kernel_size} and scale_factor is '\ + f'{scale_factor}.' + + padding = (kernel_size - scale_factor) // 2 + self.upsample = nn.ConvTranspose2d( + in_channels, + in_channels, + kernel_size=kernel_size, + stride=scale_factor, + padding=padding, + bias=True) + + conv_blocks_list = [ + ConvBlock2x( + in_channels=in_channels, + out_channels=out_channels, + mid_channels=mid_channels, + norm_cfg=norm_cfg, + last_act=last_act, + act_cfg=act_cfg) for _ in range(num_conv_blocks) + ] + self.conv_blocks = nn.Sequential(*conv_blocks_list) + + def forward(self, x): + x = self.upsample(x) + return self.conv_blocks(x) + + +@MODELS.register_module() +class SparKLightDecoder(BaseModule): + """The decoder for SparK, which upsamples the feature maps. + + Args: + feature_dim (int): The dimension of feature map. + upsample_ratio (int): The ratio of upsample, equal to downsample_raito + of the algorithm. + mid_channels (int): The middle channel of `DecoderConvModule`. Defaults + to 0. + kernel_size (int): The kernel size of `ConvTranspose2d` in + `DecoderConvModule`. Defaults to 4. + scale_factor (int): The scale_factor of `ConvTranspose2d` in + `DecoderConvModule`. Defaults to 2. + num_conv_blocks (int): The number of convolution blocks in + `DecoderConvModule`. Defaults to 1. + norm_cfg (dict): Normalization config. Defaults to dict(type='SyncBN'). + act_cfg (dict): Activation config. Defaults to dict(type='ReLU6'). + last_act (bool): Whether apply the last activation in + `DecoderConvModule`. Defaults to False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__( + self, + feature_dim: int, + upsample_ratio: int, + mid_channels: int = 0, + kernel_size: int = 4, + scale_factor: int = 2, + num_conv_blocks: int = 1, + norm_cfg: dict = dict(type='SyncBN'), + act_cfg: dict = dict(type='ReLU6'), + last_act: bool = False, + init_cfg: Optional[dict] = [ + dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']), + dict(type='TruncNormal', std=0.02, layer=['Linear']), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm']) + ], + ): + super().__init__(init_cfg=init_cfg) + self.feature_dim = feature_dim + + assert is_pow2n(upsample_ratio) + n = round(math.log2(upsample_ratio)) + channels = [feature_dim // 2**i for i in range(n + 1)] + + self.decoder = nn.ModuleList([ + DecoderConvModule( + in_channels=c_in, + out_channels=c_out, + mid_channels=c_in if mid_channels == 0 else mid_channels, + kernel_size=kernel_size, + scale_factor=scale_factor, + num_conv_blocks=num_conv_blocks, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + last_act=last_act) + for (c_in, c_out) in zip(channels[:-1], channels[1:]) + ]) + self.proj = nn.Conv2d( + channels[-1], 3, kernel_size=1, stride=1, bias=True) + + def forward(self, to_dec): + x = 0 + for i, d in enumerate(self.decoder): + if i < len(to_dec) and to_dec[i] is not None: + x = x + to_dec[i] + x = self.decoder[i](x) + return self.proj(x) diff --git a/mmpretrain/models/selfsup/__init__.py b/mmpretrain/models/selfsup/__init__.py index 31607220..1052dedc 100644 --- a/mmpretrain/models/selfsup/__init__.py +++ b/mmpretrain/models/selfsup/__init__.py @@ -16,6 +16,7 @@ from .mocov3 import MoCoV3, MoCoV3ViT from .simclr import SimCLR from .simmim import SimMIM, SimMIMSwinTransformer from .simsiam import SimSiam +from .spark import SparK from .swav import SwAV __all__ = [ @@ -51,4 +52,5 @@ __all__ = [ 'DenseCL', 'BarlowTwins', 'SwAV', + 'SparK', ] diff --git a/mmpretrain/models/selfsup/spark.py b/mmpretrain/models/selfsup/spark.py new file mode 100644 index 00000000..d5570a5a --- /dev/null +++ b/mmpretrain/models/selfsup/spark.py @@ -0,0 +1,163 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from mmengine.model.weight_init import trunc_normal_ + +from mmpretrain.registry import MODELS +from mmpretrain.structures import DataSample +from ..utils.norm import build_norm_layer +from ..utils.sparse_modules import SparseHelper +from .base import BaseSelfSupervisor + + +@MODELS.register_module() +class SparK(BaseSelfSupervisor): + """Implementation of SparK. + + Implementation of `Designing BERT for Convolutional Networks: Sparse and + Hierarchical Masked Modeling `_. + + Modified from + https://github.com/keyu-tian/SparK/blob/main/pretrain/spark.py + """ + + def __init__( + self, + backbone: dict, + neck: dict, + head: dict, + pretrained: Optional[str] = None, + data_preprocessor: Optional[dict] = None, + input_size: int = 224, + downsample_raito: int = 32, + mask_ratio: float = 0.6, + enc_dec_norm_cfg=dict(type='SparseSyncBatchNorm2d'), + enc_dec_norm_dim: int = 2048, + init_cfg: Optional[dict] = None, + ) -> None: + super().__init__( + backbone=backbone, + neck=neck, + head=head, + pretrained=pretrained, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + self.input_size = input_size + self.downsample_raito = downsample_raito + feature_map_size = input_size // downsample_raito + self.feature_map_size = feature_map_size + + self.mask_ratio = mask_ratio + self.len_keep = round(feature_map_size * feature_map_size * + (1 - mask_ratio)) + + self.enc_dec_norm_cfg = enc_dec_norm_cfg + self.enc_dec_norms = nn.ModuleList() + self.enc_dec_projectors = nn.ModuleList() + self.mask_tokens = nn.ParameterList() + + proj_out_dim = self.neck.feature_dim + for i in range(len(self.backbone.out_indices)): + enc_dec_norm = build_norm_layer(self.enc_dec_norm_cfg, + enc_dec_norm_dim) + self.enc_dec_norms.append(enc_dec_norm) + + kernel_size = 1 if i <= 0 else 3 + proj_layer = nn.Conv2d( + enc_dec_norm_dim, + proj_out_dim, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + bias=True) + if i == 0 and enc_dec_norm_dim == proj_out_dim: + proj_layer = nn.Identity() + self.enc_dec_projectors.append(proj_layer) + + mask_token = nn.Parameter(torch.zeros(1, enc_dec_norm_dim, 1, 1)) + trunc_normal_(mask_token, mean=0, std=.02, a=-.02, b=.02) + self.mask_tokens.append(mask_token) + + enc_dec_norm_dim //= 2 + proj_out_dim //= 2 + feature_map_size *= 2 + + def mask(self, + shape: torch.Size, + device: Union[torch.device, str], + generator: Optional[torch.Generator] = None): + """Mask generation. + + Args: + shape (torch.Size): The shape of the input images. + device (Union[torch.device, str]): The device of the tensor. + generator (torch.Generator, optional): Generator for random + functions. Defaults to None + Returns: + torch.Tensor: The generated mask. + """ + B, C, H, W = shape + f = self.feature_map_size + idx = torch.rand(B, f * f, generator=generator).argsort(dim=1) + idx = idx[:, :self.len_keep].to(device) # (B, len_keep) + return torch.zeros( + B, f * f, dtype=torch.bool, device=device).scatter_( + dim=1, index=idx, value=True).view(B, 1, f, f) + + def loss(self, inputs: torch.Tensor, data_samples: List[DataSample], + **kwargs) -> Dict[str, torch.Tensor]: + """The forward function in training. + + Args: + inputs (List[torch.Tensor]): The input images. + data_samples (List[DataSample]): All elements required + during the forward function. + Returns: + Dict[str, torch.Tensor]: A dictionary of loss components. + """ + + # active mask of feature map, (B, 1, f, f) + active_mask_feature_map = self.mask(inputs.shape, inputs.device) + SparseHelper._cur_active = active_mask_feature_map + + # active mask of original input, (B, 1, H, W) + active_mask_origin = active_mask_feature_map.repeat_interleave( + self.downsample_raito, + 2).repeat_interleave(self.downsample_raito, 3) + masked_img = inputs * active_mask_origin + + # get hierarchical encoded sparse features in a list + # containing four feature maps + feature_maps = self.backbone(masked_img) + + # from the smallest feature map to the largest + feature_maps = list(feature_maps) + feature_maps.reverse() + + cur_active = active_mask_feature_map + feature_maps_to_dec = [] + for i, feature_map in enumerate(feature_maps): + if feature_map is not None: + # fill in empty positions with [mask] embeddings + feature_map = self.enc_dec_norms[i](feature_map) + mask_token = self.mask_tokens[i].expand_as(feature_map) + feature_map = torch.where( + cur_active.expand_as(feature_map), feature_map, + mask_token.to(feature_map.dtype)) + feature_map = self.enc_dec_projectors[i](feature_map) + feature_maps_to_dec.append(feature_map) + + # dilate the mask map + cur_active = cur_active.repeat_interleave( + 2, dim=2).repeat_interleave( + 2, dim=3) + + # decode and reconstruct + rec_img = self.neck(feature_maps_to_dec) + + # compute loss + loss = self.head(rec_img, inputs, active_mask_feature_map) + losses = dict(loss=loss) + return losses diff --git a/mmpretrain/models/utils/__init__.py b/mmpretrain/models/utils/__init__.py index 4c13dca0..e59d71d5 100644 --- a/mmpretrain/models/utils/__init__.py +++ b/mmpretrain/models/utils/__init__.py @@ -25,6 +25,9 @@ from .position_encoding import (ConditionalPositionEncoding, build_2d_sincos_position_embedding) from .res_layer_extra_norm import ResLayerExtraNorm from .se_layer import SELayer +from .sparse_modules import (SparseAvgPooling, SparseBatchNorm2d, SparseConv2d, + SparseHelper, SparseLayerNorm2D, SparseMaxPooling, + SparseSyncBatchNorm2d) from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused from .vector_quantizer import NormEMAVectorQuantizer @@ -78,6 +81,13 @@ __all__ = [ 'SwiGLUFFN', 'SwiGLUFFNFused', 'RotaryEmbeddingFast', + 'SparseAvgPooling', + 'SparseConv2d', + 'SparseHelper', + 'SparseMaxPooling', + 'SparseBatchNorm2d', + 'SparseLayerNorm2D', + 'SparseSyncBatchNorm2d', ] if WITH_MULTIMODAL: diff --git a/mmpretrain/models/utils/sparse_modules.py b/mmpretrain/models/utils/sparse_modules.py new file mode 100644 index 00000000..dd6bf345 --- /dev/null +++ b/mmpretrain/models/utils/sparse_modules.py @@ -0,0 +1,149 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved. +# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py +import torch +import torch.nn as nn + +from mmpretrain.registry import MODELS + + +class SparseHelper: + """The helper to compute sparse operation with pytorch, such as sparse + convlolution, sparse batch norm, etc.""" + + _cur_active: torch.Tensor = None + + @staticmethod + def _get_active_map_or_index(H: int, + returning_active_map: bool = True + ) -> torch.Tensor: + """Get current active map with (B, 1, f, f) shape or index format.""" + # _cur_active with shape (B, 1, f, f) + downsample_raito = H // SparseHelper._cur_active.shape[-1] + active_ex = SparseHelper._cur_active.repeat_interleave( + downsample_raito, 2).repeat_interleave(downsample_raito, 3) + return active_ex if returning_active_map else active_ex.squeeze( + 1).nonzero(as_tuple=True) + + @staticmethod + def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse convolution forward function.""" + x = super(type(self), self).forward(x) + + # (b, c, h, w) *= (b, 1, h, w), mask the output of conv + x *= SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=True) + return x + + @staticmethod + def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor: + """Sparse batch norm forward function.""" + active_index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + + # (b, c, h, w) -> (b, h, w, c) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten features + # with shape (n, c) + x_flattened = x_permuted[active_index] + + # use BN1d to normalize this flatten feature (n, c) + x_flattened = super(type(self), self).forward(x_flattened) + + # generate output + output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + output[active_index] = x_flattened + + # (b, h, w, c) -> (b, c, h, w) + output = output.permute(0, 3, 1, 2) + return output + + +class SparseConv2d(nn.Conv2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseMaxPooling(nn.MaxPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +class SparseAvgPooling(nn.AvgPool2d): + """hack: override the forward function. + See `sp_conv_forward` above for more details + """ + forward = SparseHelper.sp_conv_forward + + +@MODELS.register_module() +class SparseBatchNorm2d(nn.BatchNorm1d): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module() +class SparseSyncBatchNorm2d(nn.SyncBatchNorm): + """hack: override the forward function. + See `sp_bn_forward` above for more details + """ + forward = SparseHelper.sp_bn_forward + + +@MODELS.register_module('SparseLN2d') +class SparseLayerNorm2D(nn.LayerNorm): + """Implementation of sparse LayerNorm on channels for 2d images.""" + + def forward(self, + x: torch.Tensor, + data_format='channel_first') -> torch.Tensor: + """Sparse layer norm forward function with 2D data. + + Args: + x (torch.Tensor): The input tensor. + data_format (str): The format of the input tensor. If + ``"channel_first"``, the shape of the input tensor should be + (B, C, H, W). If ``"channel_last"``, the shape of the input + tensor should be (B, H, W, C). Defaults to "channel_first". + """ + assert x.dim() == 4, ( + f'LayerNorm2d only supports inputs with shape ' + f'(N, C, H, W), but got tensor with shape {x.shape}') + if data_format == 'channel_last': + index = SparseHelper._get_active_map_or_index( + H=x.shape[1], returning_active_map=False) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x, dtype=x_flattened.dtype) + x[index] = x_flattened + elif data_format == 'channel_first': + index = SparseHelper._get_active_map_or_index( + H=x.shape[2], returning_active_map=False) + x_permuted = x.permute(0, 2, 3, 1) + + # select the features on non-masked positions to form flatten + # features with shape (n, c) + x_flattened = x_permuted[index] + # use LayerNorm to normalize this flatten feature (n, c) + x_flattened = super().forward(x_flattened) + + # generate output + x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype) + x[index] = x_flattened + x = x.permute(0, 3, 1, 2).contiguous() + else: + raise NotImplementedError + return x diff --git a/model-index.yml b/model-index.yml index 7d6fd0be..3fb3d045 100644 --- a/model-index.yml +++ b/model-index.yml @@ -78,6 +78,7 @@ Import: - configs/chinese_clip/metafile.yml - configs/itpn/metafile.yml - configs/hivit/metafile.yml + - configs/spark/metafile.yml - configs/minigpt4/metafile.yml - configs/llava/metafile.yml - configs/otter/metafile.yml diff --git a/tests/test_models/test_selfsup/test_spark.py b/tests/test_models/test_selfsup/test_spark.py new file mode 100644 index 00000000..cb4fe3d5 --- /dev/null +++ b/tests/test_models/test_selfsup/test_spark.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest +import torch + +from mmpretrain.models import SparK +from mmpretrain.structures import DataSample + + +@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit') +def test_spark(): + data_preprocessor = { + 'mean': (123.675, 116.28, 103.53), + 'std': (58.395, 57.12, 57.375), + 'to_rgb': True + } + + backbone = dict( + type='SparseResNet', + depth=50, + out_indices=(0, 1, 2, 3), + drop_path_rate=0.05, + norm_cfg=dict(type='BN')) + neck = dict( + type='SparKLightDecoder', + feature_dim=512, + upsample_ratio=32, # equal to downsample_raito + mid_channels=0, + norm_cfg=dict(type='BN'), + last_act=False) + head = dict( + type='SparKPretrainHead', + loss=dict(type='PixelReconstructionLoss', criterion='L2')) + + alg = SparK( + backbone=backbone, + neck=neck, + head=head, + data_preprocessor=data_preprocessor, + enc_dec_norm_cfg=dict(type='BN'), + ) + + fake_data = { + 'inputs': torch.randn((2, 3, 224, 224)), + 'data_sample': [DataSample() for _ in range(2)] + } + + fake_inputs = alg.data_preprocessor(fake_data) + fake_loss = alg(**fake_inputs, mode='loss') + assert isinstance(fake_loss['loss'].item(), float)