[Feature] Support iTPN and HiViT (#1584)

* hivit added

* Update hivit.py

* Update hivit.py

* Add files via upload

* Update __init__.py

* Add files via upload

* Update __init__.py

* Add files via upload

* Update hivit.py

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Update itpn.py

* Add files via upload

* Update __init__.py

* Update mae_hivit-base-p16.py

* Delete mim_itpn-base-p16.py

* Add files via upload

* Update itpn_hivit-base-p16.py

* Update itpn.py

* Update hivit.py

* Update __init__.py

* Update mae.py

* Delete hivit.py

* Update __init__.py

* Delete configs/itpn directory

* Add files via upload

* Add files via upload

* Delete configs/hivit directory

* Add files via upload

* refactor and add metafile and readme

* update clip

* add ut

* update ut

* update

* update docstring

* update model.rst

---------

Co-authored-by: 田运杰 <48153283+sunsmarterjie@users.noreply.github.com>
Yixiao Fang 2023-05-26 12:08:34 +08:00 committed by GitHub
parent 1f07c92ed1
commit e4c4a81b56
42 changed files with 3184 additions and 2 deletions

@@ -0,0 +1,49 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
type='TwoNormDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
    # CLIP mean & std (for the second, target-branch normalization)
second_mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
second_std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ColorJitter',
brightness=0.4,
contrast=0.4,
saturation=0.4,
hue=0.),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
second_size=224,
interpolation='bicubic',
second_interpolation='bicubic',
scale=(0.2, 1.0)),
dict(
type='BEiTMaskGenerator',
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=75,
min_num_patches=16),
dict(type='PackInputs')
]
train_dataloader = dict(
batch_size=256,
num_workers=8,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
collate_fn=dict(type='default_collate'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix=dict(img_path='train/'),
pipeline=train_pipeline))
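
For reference, a minimal sketch of what `TwoNormDataPreprocessor` provides downstream: the same batch normalized twice, once with the ImageNet statistics for the student branch and once with the CLIP statistics for the target generator. The helper name and shapes below are illustrative, not the mmpretrain implementation:

```python
import torch

def two_norm_sketch(imgs: torch.Tensor):
    """Illustrative only: normalize one uint8 NCHW batch with two stat sets."""
    mean = torch.tensor([123.675, 116.28, 103.53]).view(1, 3, 1, 1)
    std = torch.tensor([58.395, 57.12, 57.375]).view(1, 3, 1, 1)
    second_mean = torch.tensor(
        [0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255]).view(1, 3, 1, 1)
    second_std = torch.tensor(
        [0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255]).view(1, 3, 1, 1)
    imgs = imgs.float()
    return (imgs - mean) / std, (imgs - second_mean) / second_std

student_view, clip_view = two_norm_sketch(
    torch.randint(0, 256, (2, 3, 224, 224)))
```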

@@ -0,0 +1,83 @@
# dataset settings
dataset_type = 'ImageNet'
data_root = 'data/imagenet/'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
# convert image from BGR to RGB
to_rgb=True,
)
bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackInputs'),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs'),
]
train_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)
val_dataloader = dict(
batch_size=64,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))
# If you want the standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator

@@ -0,0 +1,28 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='HiViT',
arch='base',
img_size=224,
ape=True,
rpe=True,
drop_path_rate=0.5),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=512,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
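
As a quick sanity check, the dimensions above line up: HiViT-base ends with 512 channels at 14x14 resolution for a 224 input, which is what `GlobalAveragePooling` feeds into the 512-dim `LinearClsHead`. A minimal sketch, assuming the `mmpretrain` registry import path:

```python
import torch
from mmpretrain.registry import MODELS

backbone = MODELS.build(
    dict(type='HiViT', arch='base', img_size=224, ape=True, rpe=True))
feats = backbone(torch.rand(1, 3, 224, 224))
print(feats[-1].shape)  # expected: torch.Size([1, 512, 14, 14])
```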

@@ -0,0 +1,28 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='HiViT',
arch='small',
img_size=224,
ape=True,
rpe=True,
drop_path_rate=0.3),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

@@ -0,0 +1,28 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='HiViT',
arch='tiny',
img_size=224,
ape=True,
rpe=True,
drop_path_rate=0.05),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=384,
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
cal_acc=False),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
],
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)

@@ -0,0 +1,33 @@
# model settings
model = dict(
type='iTPN',
backbone=dict(
type='iTPNHiViT',
arch='base',
reconstruction_type='pixel',
mask_ratio=0.75),
neck=dict(
type='iTPNPretrainDecoder',
num_patches=196,
patch_size=16,
in_chans=3,
embed_dim=512,
decoder_embed_dim=512,
decoder_depth=6,
decoder_num_heads=16,
mlp_ratio=4.,
reconstruction_type='pixel',
# transformer pyramid
fpn_dim=256,
fpn_depth=2,
num_outs=3,
),
head=dict(
type='MAEPretrainHead',
norm_pix=True,
patch_size=16,
loss=dict(type='PixelReconstructionLoss', criterion='L2')),
init_cfg=[
dict(type='Xavier', layer='Linear', distribution='uniform'),
dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
])
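
The `num_patches=196` in the neck and `mask_ratio=0.75` in the backbone follow the usual 224/16 patch arithmetic; a quick check of the implied token counts (pure arithmetic, not mmpretrain code):

```python
img_size, patch_size, mask_ratio = 224, 16, 0.75
num_patches = (img_size // patch_size) ** 2
num_masked = int(num_patches * mask_ratio)
print(num_patches, num_masked, num_patches - num_masked)  # 196 147 49
```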

@@ -0,0 +1,24 @@
# model settings
model = dict(
type='MAE',
backbone=dict(
        type='MAEHiViT', patch_size=16, arch='base', mask_ratio=0.75),
neck=dict(
type='MAEPretrainDecoder',
patch_size=16,
in_chans=3,
embed_dim=512,
decoder_embed_dim=512,
decoder_depth=6,
decoder_num_heads=16,
mlp_ratio=4.,
),
head=dict(
type='MAEPretrainHead',
norm_pix=True,
patch_size=16,
loss=dict(type='PixelReconstructionLoss', criterion='L2')),
init_cfg=[
dict(type='Xavier', layer='Linear', distribution='uniform'),
dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
])
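
With `norm_pix=True`, the reconstruction target follows the standard MAE recipe of per-patch normalized pixels before the L2 loss; a sketch of that normalization (illustrative shapes, with p=16 patches flattened to 768 values):

```python
import torch

patches = torch.rand(2, 196, 16 * 16 * 3)  # (B, num_patches, p*p*3)
mean = patches.mean(dim=-1, keepdim=True)
var = patches.var(dim=-1, keepdim=True)
target = (patches - mean) / (var + 1e-6).sqrt()
```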

@@ -0,0 +1,41 @@
# batch size is 128 per GPU, with 8 GPUs in total (1024 images per step)
# lr = 5e-4 * 128 * 8 / 512 = 1e-3
optim_wrapper = dict(
optimizer=dict(
type='AdamW',
lr=5e-4 * 1024 / 512,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999)),
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
flat_decay_mult=0.0,
custom_keys={
'.pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0)
}),
)
# learning policy
param_scheduler = [
# warm up learning rate scheduler
dict(
type='LinearLR',
start_factor=1e-3,
by_epoch=True,
end=20,
# update by iter
convert_to_iter_based=True),
# main learning rate scheduler
dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=True, begin=20)
]
# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
val_cfg = dict()
test_cfg = dict()
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=1024)
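
The `auto_scale_lr` hook applies the linear scaling rule relative to `base_batch_size`; the arithmetic it performs is roughly this sketch:

```python
base_lr = 5e-4 * 1024 / 512  # the configured lr, 1e-3
base_batch_size = 1024

def scaled_lr(actual_batch_size: int) -> float:
    """Linear scaling rule as applied by auto_scale_lr (sketch)."""
    return base_lr * actual_batch_size / base_batch_size

print(scaled_lr(512))  # 5e-4 when training with half the reference batch
```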

@@ -0,0 +1,81 @@
# HiViT
> [HiViT: A Simple and More Efficient Design of Hierarchical Vision Transformer](https://arxiv.org/abs/2205.14949)
<!-- [ALGORITHM] -->
## Abstract
Recently, masked image modeling (MIM) has offered a new methodology of self-supervised pre-training of vision transformers. A key idea of efficient implementation is to discard the masked image patches (or tokens) throughout the target network (encoder), which requires the encoder to be a plain vision transformer (e.g., ViT), albeit hierarchical vision transformers (e.g., Swin Transformer) have potentially better properties in formulating vision inputs. In this paper, we offer a new design of hierarchical vision transformers named HiViT (short for Hierarchical ViT) that enjoys both high efficiency and good performance in MIM. The key is to remove the unnecessary "local inter-unit operations", deriving structurally simple hierarchical vision transformers in which mask-units can be serialized like plain vision transformers. For this purpose, we start with Swin Transformer and (i) set the masking unit size to be the token size in the main stage of Swin Transformer, (ii) switch off inter-unit self-attentions before the main stage, and (iii) eliminate all operations after the main stage. Empirical studies demonstrate the advantageous performance of HiViT in terms of fully-supervised, self-supervised, and transfer learning. In particular, in running MAE on ImageNet-1K, HiViT-B reports a +0.6% accuracy gain over ViT-B and a 1.9$\times$ speed-up over Swin-B, and the performance gain generalizes to downstream tasks of detection and segmentation. Code will be made publicly available.
<div align=center>
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/4a99cf9d-15df-4866-8750-bd2c3db5d894" width="80%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
<!-- **Predict image**
```python
from mmpretrain import inference_model
predict = inference_model('hivit-tiny-p16_16xb64_in1k', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
``` -->
<!-- **Use the model**

```python
import torch
from mmpretrain import get_model
model = get_model('hivit-tiny-p16_16xb64_in1k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
``` -->
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/hivit/hivit-tiny-p16_16xb64_in1k.py
```
<!-- Test:
```shell
python tools/test.py configs/hivit/hivit-tiny-p16_16xb64_in1k.py None
``` -->
<!-- [TABS-END] -->
## Models and results
### Image Classification on ImageNet-1k
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Config | Download |
| :---------------------------- | :----------: | :--------: | :-------: | :-------: | :--------------------------------------: | :------: |
| `hivit-tiny-p16_16xb64_in1k` | From scratch | 19.18 | 4.60 | 82.10 | [config](hivit-tiny-p16_16xb64_in1k.py) | N/A |
| `hivit-small-p16_16xb64_in1k` | From scratch | 37.53 | 9.07 | N/A | [config](hivit-small-p16_16xb64_in1k.py) | N/A |
| `hivit-base-p16_16xb64_in1k` | From scratch | 79.05 | 18.47 | N/A | [config](hivit-base-p16_16xb64_in1k.py) | N/A |
## Citation
```bibtex
@inproceedings{zhanghivit,
title={HiViT: A Simpler and More Efficient Design of Hierarchical Vision Transformer},
author={Zhang, Xiaosong and Tian, Yunjie and Xie, Lingxi and Huang, Wei and Dai, Qi and Ye, Qixiang and Tian, Qi},
booktitle={International Conference on Learning Representations},
year={2023},
}
```

@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/hivit/base_224.py',
'../_base_/datasets/imagenet_bs64_hivit_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_hivit.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/hivit/small_224.py',
'../_base_/datasets/imagenet_bs64_hivit_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_hivit.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/hivit/tiny_224.py',
'../_base_/datasets/imagenet_bs64_hivit_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_hivit.py',
'../_base_/default_runtime.py'
]
# schedule settings
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))

@@ -0,0 +1,63 @@
Collections:
- Name: HiViT
Metadata:
Architecture:
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
Paper:
Title: 'HiViT: A Simple and More Efficient Design of Hierarchical Vision Transformer'
URL: https://arxiv.org/abs/2205.14949
README: configs/hivit/README.md
Code:
URL: null
Version: null
Models:
- Name: hivit-tiny-p16_16xb64_in1k
Metadata:
FLOPs: 4603000000
Parameters: 19181000
Training Data:
- ImageNet-1k
In Collection: HiViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.1
Task: Image Classification
Weights:
Config: configs/hivit/hivit-tiny-p16_16xb64_in1k.py
- Name: hivit-small-p16_16xb64_in1k
Metadata:
FLOPs: 9072000000
Parameters: 37526000
Training Data:
- ImageNet-1k
In Collection: HiViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy:
Task: Image Classification
Weights:
Config: configs/hivit/hivit-small-p16_16xb64_in1k.py
- Name: hivit-base-p16_16xb64_in1k
Metadata:
FLOPs: 18474000000
Parameters: 79051000
Training Data:
- ImageNet-1k
In Collection: HiViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy:
Task: Image Classification
Weights:
Config: configs/hivit/hivit-base-p16_16xb64_in1k.py

@@ -0,0 +1,65 @@
# iTPN
> [Integrally Pre-Trained Transformer Pyramid Networks](https://arxiv.org/abs/2211.12735)
<!-- [ALGORITHM] -->
## Abstract
In this paper, we present an integral pre-training framework based on masked image modeling (MIM). We advocate for pre-training the backbone and neck jointly so that the transfer gap between MIM and downstream recognition tasks is minimal. We make two technical contributions. First, we unify the reconstruction and recognition necks by inserting a feature pyramid into the pre-training stage. Second, we complement mask image modeling (MIM) with masked feature modeling (MFM) that offers multi-stage supervision to the feature pyramid. The pre-trained models, termed integrally pre-trained transformer pyramid networks (iTPNs), serve as powerful foundation models for visual recognition. In particular, the base/large-level iTPN achieves an 86.2%/87.8% top-1 accuracy on ImageNet-1K, a 53.2%/55.6% box AP on COCO object detection with 1x training schedule using Mask-RCNN, and a 54.7%/57.7% mIoU on ADE20K semantic segmentation using UPerHead -- all these results set new records. Our work inspires the community to work on unifying upstream pre-training and downstream fine-tuning tasks. Code and the pre-trained models will be released at https://github.com/sunsmarterjie/iTPN.
<div align=center>
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/2e53d5b5-300e-4640-8507-c1173965ca62" width="80%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
<!-- **Use the model**
```python
import torch
from mmpretrain import get_model
model = get_model('itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
``` -->
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/itpn/itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k.py
```
<!-- [TABS-END] -->
## Models and results
### Pretrained models
| Model | Params (M) | Flops (G) | Config | Download |
| :------------------------------------------------------ | :--------: | :-------: | :----------------------------------------------------------------: | :------: |
| `itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k` | 233.00 | 18.47 | [config](itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k.py) | N/A |
| `itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k` | 103.00 | 18.47 | [config](itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k.py) | N/A |
| `itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k` | 314.00 | 63.98 | [config](itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k.py) | N/A |
## Citation
```bibtex
@article{tian2022integrally,
title={Integrally Pre-Trained Transformer Pyramid Networks},
author={Tian, Yunjie and Xie, Lingxi and Wang, Zhaozhi and Wei, Longhui and Zhang, Xiaopeng and Jiao, Jianbin and Wang, Yaowei and Tian, Qi and Ye, Qixiang},
journal={arXiv preprint arXiv:2211.12735},
year={2022}
}
```

@@ -0,0 +1,84 @@
_base_ = [
'../_base_/datasets/imagenet_bs256_itpn.py',
'../_base_/default_runtime.py',
]
model = dict(
type='iTPN',
backbone=dict(
type='iTPNHiViT',
arch='base',
drop_path_rate=0.0,
rpe=True,
layer_scale_init_value=0.1,
reconstruction_type='clip'),
neck=dict(
type='iTPNPretrainDecoder',
patch_size=16,
in_chans=3,
embed_dim=512,
mlp_ratio=4.,
reconstruction_type='clip',
# transformer pyramid
fpn_dim=256,
fpn_depth=2,
num_outs=3,
),
head=dict(
type='iTPNClipHead',
embed_dims=512,
num_embed=512,
loss=dict(type='CosineSimilarityLoss')),
target_generator=dict(
type='CLIPGenerator',
tokenizer_path= # noqa
'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa
),
)
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
# betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 1600 epochs.
optimizer=dict(
type='AdamW', lr=1.5e-3, betas=(0.9, 0.98), weight_decay=0.05),
clip_grad=dict(max_norm=3.0),
paramwise_cfg=dict(
custom_keys={
'.norm': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
'.gamma': dict(decay_mult=0.0),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
eta_min=1e-5,
by_epoch=True,
begin=10,
end=300,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=2048)
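
The two schedulers above chain a 10-epoch linear warmup into a cosine decay; an approximate epoch-level sketch of the resulting learning rate (the real schedulers step per iteration once `convert_to_iter_based=True`):

```python
import math

def lr_at(epoch, base_lr=1.5e-3, warmup=10, total=300,
          start_factor=1e-4, eta_min=1e-5):
    """Approximate warmup + cosine schedule (epoch granularity)."""
    if epoch < warmup:
        alpha = epoch / warmup
        return base_lr * (start_factor * (1 - alpha) + alpha)
    t = (epoch - warmup) / (total - warmup)
    return eta_min + (base_lr - eta_min) * 0.5 * (1 + math.cos(math.pi * t))

print(lr_at(0), lr_at(10), lr_at(300))  # ~1.5e-7, 1.5e-3, 1e-5
```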

@@ -0,0 +1,84 @@
_base_ = [
'../_base_/datasets/imagenet_bs256_itpn.py',
'../_base_/default_runtime.py',
]
model = dict(
type='iTPN',
backbone=dict(
type='iTPNHiViT',
arch='base',
drop_path_rate=0.1,
rpe=True,
layer_scale_init_value=0.1,
reconstruction_type='clip'),
neck=dict(
type='iTPNPretrainDecoder',
patch_size=16,
in_chans=3,
embed_dim=512,
mlp_ratio=4.,
reconstruction_type='clip',
# transformer pyramid
fpn_dim=256,
fpn_depth=2,
num_outs=3,
),
head=dict(
type='iTPNClipHead',
embed_dims=512,
num_embed=512,
loss=dict(type='CrossEntropyLoss')),
target_generator=dict(
type='CLIPGenerator',
tokenizer_path= # noqa
'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa
),
)
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
# betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 800/1600 epochs.
optimizer=dict(
type='AdamW', lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05),
clip_grad=dict(max_norm=3.0),
paramwise_cfg=dict(
custom_keys={
'.norm': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0),
'.gamma': dict(decay_mult=0.0),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
eta_min=1e-5,
by_epoch=True,
begin=10,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=2048)

@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/itpn_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=1560,
by_epoch=True,
begin=40,
end=1600,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/itpn_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=360,
by_epoch=True,
begin=40,
end=400,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/itpn_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=760,
by_epoch=True,
begin=40,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,61 @@
_base_ = [
'../_base_/models/itpn_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(
backbone=dict(type='iTPNHiViT', arch='large'),
neck=dict(type='iTPNPretrainDecoder', embed_dim=768))
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=1560,
by_epoch=True,
begin=40,
end=1600,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,61 @@
_base_ = [
'../_base_/models/itpn_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(
backbone=dict(type='iTPNHiViT', arch='large'),
neck=dict(type='iTPNPretrainDecoder', embed_dim=768))
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=360,
by_epoch=True,
begin=40,
end=400,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,61 @@
_base_ = [
'../_base_/models/itpn_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(
backbone=dict(type='iTPNHiViT', arch='large'),
neck=dict(type='iTPNPretrainDecoder', embed_dim=768))
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'ln': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=760,
by_epoch=True,
begin=40,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,50 @@
Collections:
- Name: iTPN
Metadata:
Architecture:
- Dense Connections
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
Paper:
Title: 'Integrally Pre-Trained Transformer Pyramid Networks'
URL: https://arxiv.org/abs/2211.12735
README: configs/itpn/README.md
Code:
URL: null
Version: null
Models:
- Name: itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k
Metadata:
FLOPs: 18474000000
Parameters: 233000000
Training Data:
- ImageNet-1k
In Collection: iTPN
Results: null
Weights:
Config: configs/itpn/itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k.py
- Name: itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k
Metadata:
FLOPs: 18474000000
Parameters: 103000000
Training Data:
- ImageNet-1k
In Collection: iTPN
Results: null
Weights:
Config: configs/itpn/itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k.py
- Name: itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k
Metadata:
FLOPs: 63977000000
Parameters: 314000000
Training Data:
- ImageNet-1k
In Collection: iTPN
Results: null
Weights:
Config: configs/itpn/itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k.py

@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/mae_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=1560,
by_epoch=True,
begin=40,
end=1600,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/mae_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=360,
by_epoch=True,
begin=40,
end=400,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,56 @@
_base_ = [
'../_base_/models/mae_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=760,
by_epoch=True,
begin=40,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,61 @@
_base_ = [
'../_base_/models/mae_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(
backbone=dict(type='MAEHiViT', arch='large'),
neck=dict(type='MAEPretrainDecoder', embed_dim=768))
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=1560,
by_epoch=True,
begin=40,
end=1600,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,61 @@
_base_ = [
'../_base_/models/mae_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(
    backbone=dict(type='MAEHiViT', arch='large'),
neck=dict(type='MAEPretrainDecoder', embed_dim=768))
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=360,
by_epoch=True,
begin=40,
end=400,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -0,0 +1,61 @@
_base_ = [
'../_base_/models/mae_hivit-base-p16.py',
'../_base_/datasets/imagenet_bs512_mae.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(
backbone=dict(type='MAEHiViT', arch='large'),
neck=dict(type='MAEPretrainDecoder', embed_dim=768))
# optimizer wrapper
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='AdamW',
lr=1.5e-4 * 4096 / 256,
betas=(0.9, 0.95),
weight_decay=0.05),
paramwise_cfg=dict(
custom_keys={
'norm': dict(decay_mult=0.0),
'bias': dict(decay_mult=0.0),
'pos_embed': dict(decay_mult=0.),
'mask_token': dict(decay_mult=0.),
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=0.0001,
by_epoch=True,
begin=0,
end=40,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=760,
by_epoch=True,
begin=40,
end=800,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
default_hooks = dict(
# only keeps the latest 3 checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
# auto resume
resume = True
find_unused_parameters = True
# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=4096)

@@ -66,6 +66,7 @@ Self-supervised Algorithms
CAE
DenseCL
EVA
iTPN
MAE
MILAN
MaskFeat
@@ -88,6 +89,8 @@ like ``mask``, and here is a list of these **modified backbone** modules.
BEiTPretrainViT
CAEPretrainViT
iTPNHiViT
MAEHiViT
MAEViT
MILANViT
MaskFeatViT
@@ -167,6 +170,7 @@ Backbones
EfficientFormer
EfficientNet
EfficientNetV2
HiViT
HRNet
HorNet
InceptionV3
@@ -235,6 +239,7 @@ Necks
NonLinearNeck
SimMIMLinearDecoder
SwAVNeck
iTPNPretrainDecoder
.. module:: mmpretrain.models.heads
@@ -271,6 +276,7 @@ Heads
SwAVHead
VigClsHead
VisionTransformerClsHead
iTPNClipHead
.. module:: mmpretrain.models.losses

@@ -13,6 +13,7 @@ from .edgenext import EdgeNeXt
from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
from .efficientnet_v2 import EfficientNetV2
from .hivit import HiViT
from .hornet import HorNet
from .hrnet import HRNet
from .inception_v3 import InceptionV3
@@ -120,4 +121,5 @@ __all__ = [
'XCiT',
'ViTSAM',
'ViTEVA02',
'HiViT',
]

@@ -0,0 +1,656 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer, to_2tuple
from .base_backbone import BaseBackbone
class Mlp(nn.Module):
"""MLP block.
Args:
in_features (int): Number of input dims.
hidden_features (int): Number of hidden dims.
        out_features (int): Number of output dims.
act_layer: MLP activation layer.
drop (float): MLP dropout rate.
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
"""Attention.
Args:
        input_size (int): Input size.
dim (int): Number of input dims.
num_heads (int): Number of attention heads.
qkv_bias (bool): Enable bias for qkv projections if True.
        qk_scale (float): Override the default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
attn_drop (float): The drop out rate for attention output weights.
Defaults to 0.
proj_drop (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
rpe (bool): If True, add relative position embedding to
the patch embedding.
"""
def __init__(self,
input_size,
dim,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
rpe=True):
super().__init__()
self.input_size = input_size
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * input_size - 1) *
(2 * input_size - 1), num_heads)) if rpe else None
if rpe:
coords_h = torch.arange(input_size)
coords_w = torch.arange(input_size)
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += input_size - 1
relative_coords[:, :, 1] += input_size - 1
relative_coords[:, :, 0] *= 2 * input_size - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer('relative_position_index',
relative_position_index)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, rpe_index=None, mask=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
        if rpe_index is not None:
            # The incoming ``rpe_index`` is only used as a flag here; the
            # actual index is the buffered ``relative_position_index``.
            rpe_index = self.relative_position_index.view(-1)
S = int(math.sqrt(rpe_index.size(-1)))
relative_position_bias = self.relative_position_bias_table[
rpe_index].view(-1, S, S, self.num_heads)
relative_position_bias = relative_position_bias.permute(
0, 3, 1, 2).contiguous()
attn = attn + relative_position_bias
if mask is not None:
mask = mask.bool()
attn = attn.masked_fill(~mask[:, None, None, :], float('-inf'))
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class BlockWithRPE(nn.Module):
"""HiViT block.
Args:
input_size (int): Input size.
dim (int): Number of input dims.
num_heads (int): Number of attention heads.
mlp_ratio (int): Ratio of MLP hidden dim to embedding dim.
qkv_bias (bool): Enable bias for qkv projections if True.
        qk_scale (float): Override the default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
drop (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path (float): Stochastic depth rate. Defaults to 0.
rpe (bool): If True, add relative position embedding to
the patch embedding.
layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
act_layer: MLP activation layer.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
"""
def __init__(self,
input_size,
dim,
num_heads=0.,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
rpe=True,
layer_scale_init_value=0.0,
act_layer=nn.GELU,
norm_cfg=dict(type='LN')):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
with_attn = num_heads > 0.
self.norm1 = build_norm_layer(norm_cfg, dim) if with_attn else None
self.attn = Attention(
input_size,
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
rpe=rpe,
) if with_attn else None
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
if layer_scale_init_value > 0:
self.gamma_1 = nn.Parameter(
layer_scale_init_value * torch.ones(
(dim)), requires_grad=True) if with_attn else None
self.gamma_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rpe_index=None, mask=None):
if self.attn is not None:
if self.gamma_1 is not None:
x = x + self.drop_path(
self.gamma_1 * self.attn(self.norm1(x), rpe_index, mask))
else:
x = x + self.drop_path(
self.attn(self.norm1(x), rpe_index, mask))
if self.gamma_2 is not None:
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""PatchEmbed for HiViT.
Args:
img_size (int): Input image size.
patch_size (int): Patch size. Defaults to 16.
inner_patches (int): Inner patch. Defaults to 4.
in_chans (int): Number of image input channels.
embed_dim (int): Transformer embedding dimension.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
kernel_size (int): Kernel size.
pad_size (int): Pad size.
"""
def __init__(self,
img_size=224,
patch_size=16,
inner_patches=4,
in_chans=3,
embed_dim=128,
norm_cfg=None,
kernel_size=None,
pad_size=None):
super().__init__()
img_size = to_2tuple(img_size) if not isinstance(img_size,
tuple) else img_size
patch_size = to_2tuple(patch_size)
patches_resolution = [
img_size[0] // patch_size[0], img_size[1] // patch_size[1]
]
self.img_size = img_size
self.patch_size = patch_size
self.inner_patches = inner_patches
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
conv_size = [size // inner_patches for size in patch_size]
kernel_size = kernel_size or conv_size
pad_size = pad_size or 0
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=kernel_size,
stride=conv_size,
padding=pad_size)
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
patches_resolution = (H // self.patch_size[0], W // self.patch_size[1])
num_patches = patches_resolution[0] * patches_resolution[1]
x = self.proj(x).view(
B,
-1,
patches_resolution[0],
self.inner_patches,
patches_resolution[1],
self.inner_patches,
).permute(0, 2, 4, 3, 5, 1).reshape(B, num_patches, self.inner_patches,
self.inner_patches, -1)
if self.norm is not None:
x = self.norm(x)
return x
class PatchMerge(nn.Module):
"""PatchMerge for HiViT.
Args:
dim (int): Number of input channels.
norm_cfg (dict): Config dict for normalization layer.
"""
def __init__(self, dim, norm_cfg):
super().__init__()
self.norm = build_norm_layer(norm_cfg, dim * 4)
self.reduction = nn.Linear(dim * 4, dim * 2, bias=False)
def forward(self, x, *args, **kwargs):
is_main_stage = len(x.shape) == 3
if is_main_stage:
B, N, C = x.shape
S = int(math.sqrt(N))
x = x.reshape(B, S // 2, 2, S // 2, 2, C) \
.permute(0, 1, 3, 2, 4, 5) \
.reshape(B, -1, 2, 2, C)
x0 = x[..., 0::2, 0::2, :]
x1 = x[..., 1::2, 0::2, :]
x2 = x[..., 0::2, 1::2, :]
x3 = x[..., 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3], dim=-1)
x = self.norm(x)
x = self.reduction(x)
if is_main_stage:
x = x[:, :, 0, 0, :]
return x
@MODELS.register_module()
class HiViT(BaseBackbone):
"""HiViT.
A PyTorch implement of: `HiViT: A Simple and More Efficient Design
of Hierarchical Vision Transformer <https://arxiv.org/abs/2205.14949>`_.
Args:
        arch (str | dict): HiViT architecture. If a string, choose from
            'tiny', 'small', 'base' and 'large'. If a dict, it should
            have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (int): The number of heads in attention
modules of each stage.
Defaults to 'tiny'.
img_size (int): Input image size.
patch_size (int): Patch size. Defaults to 16.
inner_patches (int): Inner patch. Defaults to 4.
in_chans (int): Number of image input channels.
embed_dim (int): Transformer embedding dimension.
depths (list[int]): Number of successive HiViT blocks.
num_heads (int): Number of attention heads.
stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim
in the first two stages.
mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in
the last stage.
qkv_bias (bool): Enable bias for qkv projections if True.
        qk_scale (float): Override the default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
ape (bool): If True, add absolute position embedding to
the patch embedding.
rpe (bool): If True, add relative position embedding to
the patch embedding.
patch_norm (bool): If True, use norm_cfg for normalization layer.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
kernel_size (int): Kernel size.
pad_size (int): Pad size.
layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 384,
'depths': [1, 1, 10],
'num_heads': 6}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 384,
'depths': [2, 2, 20],
'num_heads': 6}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 512,
'depths': [2, 2, 24],
'num_heads': 8}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 768,
'depths': [2, 2, 40],
'num_heads': 12}),
} # yapf: disable
num_extra_tokens = 0
def __init__(self,
arch='base',
img_size=224,
patch_size=16,
inner_patches=4,
in_chans=3,
stem_mlp_ratio=3.,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.0,
norm_cfg=dict(type='LN'),
out_indices=[23],
ape=True,
rpe=False,
patch_norm=True,
frozen_stages=-1,
kernel_size=None,
pad_size=None,
layer_scale_init_value=0.0,
init_cfg=None):
super(HiViT, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'num_heads'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.num_stages = len(self.depths)
self.ape = ape
self.rpe = rpe
self.patch_size = patch_size
self.num_features = self.embed_dims
self.mlp_ratio = mlp_ratio
self.num_main_blocks = self.depths[-1]
self.out_indices = out_indices
self.out_indices[-1] = self.depths[-1] - 1
img_size = to_2tuple(img_size) if not isinstance(img_size,
tuple) else img_size
embed_dim = self.embed_dims // 2**(self.num_stages - 1)
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
inner_patches=inner_patches,
in_chans=in_chans,
embed_dim=embed_dim,
norm_cfg=norm_cfg if patch_norm else None,
kernel_size=kernel_size,
pad_size=pad_size)
num_patches = self.patch_embed.num_patches
Hp, Wp = self.patch_embed.patches_resolution
if rpe:
            assert Hp == Wp, ('If you use relative position, make sure '
                              'H == W of the input size')
# absolute position embedding
if ape:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.num_features))
trunc_normal_(self.pos_embed, std=.02)
if rpe:
# get pair-wise relative position index for each token inside the
# window
coords_h = torch.arange(Hp)
coords_w = torch.arange(Wp)
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += Hp - 1
relative_coords[:, :, 1] += Wp - 1
relative_coords[:, :, 0] *= 2 * Wp - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer('relative_position_index',
relative_position_index)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = iter(
x.item()
for x in torch.linspace(0, drop_path_rate,
sum(self.depths) + sum(self.depths[:-1])))
# build blocks
self.blocks = nn.ModuleList()
for stage_i, stage_depth in enumerate(self.depths):
is_main_stage = embed_dim == self.num_features
nhead = self.num_heads if is_main_stage else 0
ratio = mlp_ratio if is_main_stage else stem_mlp_ratio
# every block not in main stage includes two mlp blocks
stage_depth = stage_depth if is_main_stage else stage_depth * 2
for _ in range(stage_depth):
self.blocks.append(
BlockWithRPE(
Hp,
embed_dim,
nhead,
ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=next(dpr),
rpe=rpe,
norm_cfg=norm_cfg,
layer_scale_init_value=layer_scale_init_value,
))
if stage_i + 1 < self.num_stages:
self.blocks.append(PatchMerge(embed_dim, norm_cfg))
embed_dim *= 2
self.frozen_stages = frozen_stages
if self.frozen_stages > 0:
self._freeze_stages()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def interpolate_pos_encoding(self, x, h, w):
npatch = x.shape[1]
N = self.pos_embed.shape[1]
if npatch == N and w == h:
return self.pos_embed
patch_pos_embed = self.pos_embed
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
dim).permute(0, 3, 1, 2),
scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
mode='bicubic',
)
assert int(h0) == patch_pos_embed.shape[-2] and int(
w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def forward(self, x):
B, C, H, W = x.shape
Hp, Wp = H // self.patch_size, W // self.patch_size
x = self.patch_embed(x)
outs = []
for i, blk in enumerate(self.blocks[:-self.num_main_blocks]):
x = blk(x)
if i in self.out_indices:
x = x.reshape(B, Hp, Wp, *x.shape[-3:]).permute(
0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * x.shape[-3],
Wp * x.shape[-2]).contiguous()
outs.append(x)
x = x[..., 0, 0, :]
if self.ape:
x = x + self.interpolate_pos_encoding(x, H, W)
x = self.pos_drop(x)
rpe_index = True if self.rpe else None
for i, blk in enumerate(self.blocks[-self.num_main_blocks:]):
x = blk(x, rpe_index)
if i in self.out_indices:
x = x.transpose(1, 2).view(B, -1, Hp, Wp).contiguous()
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
# freeze position embedding
if self.pos_embed is not None:
self.pos_embed.requires_grad = False
# set dropout to eval model
self.pos_drop.eval()
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.blocks[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
        # freeze the final norm layer if present (plain HiViT defines no
        # ``fc_norm``, so guard the access instead of assuming it exists)
        if getattr(self, 'fc_norm', None) is not None:
            for param in self.fc_norm.parameters():
                param.requires_grad = False
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``)
"""
self.num_layers = len(self.blocks)
num_layers = self.num_layers + 2
if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
        if param_name == 'pos_embed':
layer_depth = 0
elif param_name.startswith('patch_embed'):
layer_depth = 0
        elif param_name.startswith('blocks'):
layer_id = int(param_name.split('.')[1])
layer_depth = layer_id + 1
else:
layer_depth = num_layers - 1
return layer_depth, num_layers
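
Because `PatchEmbed` computes the patch grid from the actual input and `interpolate_pos_encoding` rescales the absolute position embedding, the backbone can take inputs other than 224 when `ape=True` and `rpe=False` (the relative-position table is tied to the training resolution). A minimal sketch:

```python
import torch
from mmpretrain.registry import MODELS

backbone = MODELS.build(
    dict(type='HiViT', arch='tiny', img_size=224, ape=True, rpe=False))
feats = backbone(torch.rand(1, 3, 256, 256))
print(feats[-1].shape)  # expected: torch.Size([1, 384, 16, 16])
```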

@@ -10,6 +10,7 @@ from .efficientformer_head import EfficientFormerClsHead
from .grounding_head import GroundingHead
from .itc_head import ITCHead
from .itm_head import ITMHead
from .itpn_clip_head import iTPNClipHead
from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead
from .levit_head import LeViTClsHead
from .linear_head import LinearClsHead
@@ -62,4 +63,5 @@ __all__ = [
'ITCHead',
'ITMHead',
'GroundingHead',
'iTPNClipHead',
]

View File

@@ -0,0 +1,56 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class iTPNClipHead(BaseModule):
"""Head for iTPN Pre-training using Clip.
Compute the logits and the cross entropy loss.
Args:
embed_dims (int): The dimension of embedding.
num_embed (int): The number of classification types.
loss (dict): The config of loss.
init_cfg (dict or List[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
embed_dims: int,
num_embed: int,
loss: dict,
init_cfg: Optional[Union[dict, List[dict]]] = dict(
type='TruncNormal', layer='Linear', std=0.02, bias=0)
) -> None:
super().__init__(init_cfg=init_cfg)
self.cls_head = nn.Linear(embed_dims, num_embed)
self.loss_module = MODELS.build(loss)
def loss(self, feats: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""Generate loss.
Args:
feats (torch.Tensor): Features from backbone.
target (torch.Tensor): Target generated by target_generator.
mask (torch.Tensor): Generated mask for pre-training.
"""
mask = mask.to(feats.device, non_blocking=True)
mask = mask.flatten(1).to(torch.bool)
target = target[mask]
# remove cls_token
# feats = feats[:, 1:]
logits = self.cls_head(feats[mask])
loss = self.loss_module(logits, target)
return loss
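A shape-level sketch of how this head is driven. In the CLIP branch the targets are continuous CLIP features, so ``num_embed`` matches the CLIP feature dimension; the ``CosineSimilarityLoss`` config and all tensor sizes here are assumptions, not taken from a shipped config:
import torch
head = iTPNClipHead(
    embed_dims=512,
    num_embed=512,  # assumed CLIP-B feature dimension
    loss=dict(type='CosineSimilarityLoss'))  # assumed loss config
feats = torch.randn(2, 196, 512)              # patch features from the neck
target = torch.randn(2, 196, 512)             # CLIP features as targets
mask = torch.zeros(2, 196, dtype=torch.bool)
mask[:, :147] = True                          # ~75% of patches supervised
loss = head.loss(feats, target, mask)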

View File

@@ -5,6 +5,7 @@ from .densecl_neck import DenseCLNeck
from .gap import GlobalAveragePooling
from .gem import GeneralizedMeanPooling
from .hr_fuse import HRFuseScales
from .itpn_neck import iTPNPretrainDecoder
from .linear_neck import LinearNeck
from .mae_neck import ClsBatchNormNeck, MAEPretrainDecoder
from .milan_neck import MILANPretrainDecoder
@@ -30,4 +31,5 @@ __all__ = [
'NonLinearNeck',
'SimMIMLinearDecoder',
'SwAVNeck',
'iTPNPretrainDecoder',
]

View File

@@ -0,0 +1,388 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.models.backbones.hivit import BlockWithRPE
from mmpretrain.registry import MODELS
from ..backbones.vision_transformer import TransformerEncoderLayer
from ..utils import build_2d_sincos_position_embedding
class PatchSplit(nn.Module):
"""The up-sample module used in neck (transformer pyramid network)
Args:
dim (int): the input dimension (channel number).
fpn_dim (int): the fpn dimension (channel number).
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
"""
def __init__(self, dim, fpn_dim, norm_cfg):
super().__init__()
_, self.norm = build_norm_layer(norm_cfg, dim)
self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False)
self.fpn_dim = fpn_dim
def forward(self, x):
B, N, H, W, C = x.shape
x = self.norm(x)
x = self.reduction(x)
x = x.reshape(B, N, H, W, 2, 2,
self.fpn_dim).permute(0, 1, 2, 4, 3, 5,
6).reshape(B, N, 2 * H, 2 * W,
self.fpn_dim)
return x
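A quick shape check for ``PatchSplit``: it doubles the token grid's height and width while projecting channels down to ``fpn_dim``. The tensor sizes are illustrative:
import torch
split = PatchSplit(dim=512, fpn_dim=256, norm_cfg=dict(type='LN'))
x = torch.randn(2, 196, 1, 1, 512)  # (B, N, H, W, C) stage-3 tokens
y = split(x)
assert y.shape == (2, 196, 2, 2, 256)  # 2x upsampled, channels reduced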
@MODELS.register_module()
class iTPNPretrainDecoder(BaseModule):
"""The neck module of iTPN (transformer pyramid network).
Args:
num_patches (int): The number of total patches. Defaults to 196.
patch_size (int): Image patch size. Defaults to 16.
in_chans (int): The channel of input image. Defaults to 3.
embed_dim (int): Encoder's embedding dimension. Defaults to 512.
fpn_dim (int): The fpn dimension (channel number). Defaults to 256.
fpn_depth (int): The layer number of the feature pyramid. Defaults to 2.
decoder_embed_dim (int): Decoder's embedding dimension.
    Defaults to 512.
decoder_depth (int): The depth of decoder. Defaults to 6.
decoder_num_heads (int): Number of attention heads of decoder.
    Defaults to 16.
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
    Defaults to 4.
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
reconstruction_type (str): iTPN supports two kinds of supervision,
    'pixel' and 'clip'. Defaults to 'pixel'.
num_outs (int): The output number of neck (transformer pyramid
    network). Defaults to 3.
qkv_bias (bool): Enable bias for qkv projections if True.
    Defaults to True.
qk_scale (float, optional): The scaling factor applied after q@k.
    Defaults to None.
drop_rate (float): Probability of an element to be zeroed.
    Defaults to 0.
attn_drop_rate (float): The dropout rate for attention weights.
    Defaults to 0.
predict_feature_dim (int, optional): The output dimension for
    supervision. Defaults to None.
init_cfg (Union[List[dict], dict], optional): Initialization config
    dict. Defaults to None.
"""
def __init__(self,
num_patches: int = 196,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 512,
fpn_dim: int = 256,
fpn_depth: int = 2,
decoder_embed_dim: int = 512,
decoder_depth: int = 6,
decoder_num_heads: int = 16,
mlp_ratio: int = 4,
norm_cfg: dict = dict(type='LN', eps=1e-6),
reconstruction_type: str = 'pixel',
num_outs: int = 3,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
predict_feature_dim: Optional[int] = None,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(init_cfg=init_cfg)
self.num_patches = num_patches
assert reconstruction_type in ['pixel', 'clip'], \
'iTPN method only supports `pixel` and `clip`, ' \
f'but got `{reconstruction_type}`.'
self.reconstruction_type = reconstruction_type
self.num_outs = num_outs
self.build_transformer_pyramid(
num_outs=num_outs,
embed_dim=embed_dim,
fpn_dim=fpn_dim,
fpn_depth=fpn_depth,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
rpe=False,
norm_cfg=norm_cfg,
)
# merge the output
self.decoder_embed = nn.ModuleList()
self.decoder_embed.append(
nn.Sequential(
nn.LayerNorm(fpn_dim),
nn.Linear(fpn_dim, decoder_embed_dim, bias=True),
))
if self.num_outs >= 2:
self.decoder_embed.append(
nn.Sequential(
nn.LayerNorm(fpn_dim),
nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True),
))
if self.num_outs >= 3:
self.decoder_embed.append(
nn.Sequential(
nn.LayerNorm(fpn_dim),
nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True),
))
if reconstruction_type == 'pixel':
self.mask_token = nn.Parameter(
torch.zeros(1, 1, decoder_embed_dim))
# create a new position embedding, different from the one in the
# encoder, and it is not learnable
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, self.num_patches, decoder_embed_dim),
requires_grad=False)
self.decoder_blocks = nn.ModuleList([
TransformerEncoderLayer(
decoder_embed_dim,
decoder_num_heads,
int(mlp_ratio * decoder_embed_dim),
qkv_bias=True,
norm_cfg=norm_cfg) for _ in range(decoder_depth)
])
self.decoder_norm_name, decoder_norm = build_norm_layer(
norm_cfg, decoder_embed_dim, postfix=1)
self.add_module(self.decoder_norm_name, decoder_norm)
# Used to map features to pixels
if predict_feature_dim is None:
predict_feature_dim = patch_size**2 * in_chans
self.decoder_pred = nn.Linear(
decoder_embed_dim, predict_feature_dim, bias=True)
else:
_, norm = build_norm_layer(norm_cfg, embed_dim)
self.add_module('norm', norm)
def build_transformer_pyramid(self,
num_outs=3,
embed_dim=512,
fpn_dim=256,
fpn_depth=2,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
rpe=False,
norm_cfg=None):
Hp = None
mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim}
if num_outs > 1:
if embed_dim != fpn_dim:
self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim)
else:
self.align_dim_16tofpn = None
self.fpn_modules = nn.ModuleList()
self.fpn_modules.append(
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg))
self.fpn_modules.append(
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=False,
norm_cfg=norm_cfg,
))
self.align_dim_16to8 = nn.Linear(
mlvl_dims['8'], fpn_dim, bias=False)
self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg)
self.block_16to8 = nn.Sequential(*[
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg,
) for _ in range(fpn_depth)
])
if num_outs > 2:
self.align_dim_8to4 = nn.Linear(
mlvl_dims['4'], fpn_dim, bias=False)
self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg)
self.block_8to4 = nn.Sequential(*[
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg,
) for _ in range(fpn_depth)
])
self.fpn_modules.append(
BlockWithRPE(
Hp,
fpn_dim,
0,
mlp_ratio,
qkv_bias,
qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0.,
rpe=rpe,
norm_cfg=norm_cfg))
def init_weights(self) -> None:
"""Initialize position embedding and mask token of MAE decoder."""
super().init_weights()
if self.reconstruction_type == 'pixel':
decoder_pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.decoder_pos_embed.shape[-1],
cls_token=False)
self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
torch.nn.init.normal_(self.mask_token, std=.02)
else:
self.rescale_init_weight()
def rescale_init_weight(self) -> None:
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.fpn_modules):
if isinstance(layer, BlockWithRPE):
if layer.attn is not None:
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
@property
def decoder_norm(self):
"""The normalization layer of decoder."""
return getattr(self, self.decoder_norm_name)
def forward(self,
x: torch.Tensor,
ids_restore: Optional[torch.Tensor] = None) -> torch.Tensor:
"""The forward function.
The process combines the visible patches' feature vectors with the
mask tokens to produce the feature vectors used for reconstruction.
Args:
    x (Tuple[torch.Tensor]): Hidden features from the backbone, where
        the last element is of shape B x (L * (1 - mask_ratio)) x C.
    ids_restore (torch.Tensor): ids to restore original image.
Returns:
    torch.Tensor: The reconstructed feature vectors, which are of
    shape B x num_patches x C.
"""
features = x[:2]
x = x[-1]
B, L, _ = x.shape
x = x[..., None, None, :]
Hp = Wp = int(math.sqrt(L))
outs = [x] if self.align_dim_16tofpn is None else [
self.align_dim_16tofpn(x)
]
if self.num_outs >= 2:
x = self.block_16to8(
self.split_16to8(x) + self.align_dim_16to8(features[1]))
outs.append(x)
if self.num_outs >= 3:
x = self.block_8to4(
self.split_8to4(x) + self.align_dim_8to4(features[0]))
outs.append(x)
if self.num_outs > 3:
outs = [
out.reshape(B, Hp, Wp, *out.shape[-3:]).permute(
0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3],
Wp * out.shape[-2]).contiguous()
for out in outs
]
if self.num_outs >= 4:
outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
if self.num_outs >= 5:
outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
for i, out in enumerate(outs):
out = self.fpn_modules[i](out)
outs[i] = out
if self.reconstruction_type == 'pixel':
feats = []
for feat, layer in zip(outs, self.decoder_embed):
x = layer(feat).reshape(B, L, -1)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x = torch.cat([x, mask_tokens], dim=1)
x = torch.gather(
x,
dim=1,
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
feats.append(x)
x = feats.pop(0)
# add pos embed
x = x + self.decoder_pos_embed
for feat in feats:
    x = x + feat
# apply Transformer blocks
for i, blk in enumerate(self.decoder_blocks):
x = blk(x)
x = self.decoder_norm(x)
x = self.decoder_pred(x)
return x
else:
feats = []
for feat, layer in zip(outs, self.decoder_embed):
x = layer(feat).reshape(B, L, -1)
feats.append(x)
x = feats.pop(0)
for feat in feats:
    x = x + feat
x = self.norm(x)
return x
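A shape sketch for the pixel branch of this decoder, assuming a 0.75 mask ratio (49 of 196 patches visible) and HiViT-B stage widths of 128/256/512; that the 5-D stage tokens pass through the MLP-only ``BlockWithRPE`` path as written is an assumption of this illustration:
import torch
neck = iTPNPretrainDecoder(
    num_patches=196, embed_dim=512, fpn_dim=256, fpn_depth=2,
    decoder_embed_dim=512, decoder_depth=6, decoder_num_heads=16,
    reconstruction_type='pixel', num_outs=3)
feats = (
    torch.randn(2, 49, 4, 4, 128),  # stage-1 tokens, inner 4x4 grid
    torch.randn(2, 49, 2, 2, 256),  # stage-2 tokens, inner 2x2 grid
    torch.randn(2, 49, 512),        # stage-3 tokens
)
ids_restore = torch.argsort(torch.rand(2, 196), dim=1)
pred = neck(feats, ids_restore)
assert pred.shape == (2, 196, 16 * 16 * 3)  # per-patch pixel predictions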

View File

@@ -6,7 +6,8 @@ from .byol import BYOL
from .cae import CAE, CAEPretrainViT, DALLEEncoder
from .densecl import DenseCL
from .eva import EVA
from .mae import MAE, MAEViT
from .itpn import iTPN, iTPNHiViT
from .mae import MAE, MAEHiViT, MAEViT
from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT
from .milan import MILAN, CLIPGenerator, MILANViT
from .mixmim import MixMIM, MixMIMPretrainTransformer
@@ -24,6 +25,9 @@ __all__ = [
'CAEPretrainViT',
'DALLEEncoder',
'MAEViT',
'MAEHiViT',
'iTPNHiViT',
'iTPN',
'HOGGenerator',
'MaskFeatViT',
'CLIPGenerator',

View File

@@ -0,0 +1,356 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
from .base import BaseSelfSupervisor
@MODELS.register_module()
class iTPNHiViT(HiViT):
"""HiViT for iTPN pre-training.
Args:
img_size (int | tuple): Input image size. Defaults to 224.
patch_size (int | tuple): The patch size. Defaults to 16.
inner_patches (int): Inner patch. Defaults to 4.
stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim
in the first two stages. Defaults to 3.
mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in
the last stage. Defaults to 4.
qkv_bias (bool): Enable bias for qkv projections if True.
    Defaults to True.
qk_scale (float, optional): The scaling factor applied after q@k.
    Defaults to None.
drop_rate (float): Probability of an element to be zeroed.
    Defaults to 0.
attn_drop_rate (float): The dropout rate for attention output weights.
    Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
    Defaults to ``dict(type='LN')``.
ape (bool): If True, add absolute position embedding to
    the patch embedding. Defaults to True.
rpe (bool): If True, add relative position embedding to
    the patch embedding. Defaults to False.
layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
mask_ratio (float): The ratio of total number of patches to be masked.
    Defaults to 0.75.
reconstruction_type (str): The supervision target used in pre-training,
    either 'pixel' or 'clip'. Defaults to 'pixel'.
"""
def __init__(
self,
arch='base',
img_size: int = 224,
patch_size: int = 16,
inner_patches: int = 4,
stem_mlp_ratio: float = 3.,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
ape: bool = True,
rpe: bool = False,
layer_scale_init_value: float = 0.0,
mask_ratio: float = 0.75,
reconstruction_type: str = 'pixel',
):
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
inner_patches=inner_patches,
stem_mlp_ratio=stem_mlp_ratio,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
ape=ape,
rpe=rpe,
layer_scale_init_value=layer_scale_init_value)
self.pos_embed.requires_grad = False
self.mask_ratio = mask_ratio
assert reconstruction_type in ['pixel', 'clip'], \
'iTPN method only supports `pixel` and `clip`, ' \
f'but got `{reconstruction_type}`.'
self.reconstruction_type = reconstruction_type
self.num_patches = self.patch_embed.num_patches
if reconstruction_type == 'clip':
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def init_weights(self) -> None:
"""Initialize position embedding, patch embedding and cls token."""
super().apply(self._init_weights)
if self.reconstruction_type == 'clip':
trunc_normal_(self.mask_token, std=0.02)
self.rescale_init_weight()
else:
pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.pos_embed.shape[-1],
cls_token=False)
self.pos_embed.data.copy_(pos_embed.float())
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
def rescale_init_weight(self) -> None:
"""Rescale the initialized weights."""
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
if isinstance(layer, BlockWithRPE):
if layer.attn is not None:
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def masking_id(self, batch_size, mask_ratio):
N, L = batch_size, self.pos_embed.size(1)
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(
N, L, device=self.pos_embed.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=self.pos_embed.device)
mask[:, :ids_keep.size(1)] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return ids_keep, ids_restore, mask
def forward_pixel(
self,
x: torch.Tensor,
mask: Optional[bool] = True
) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
"""Generate features for masked images.
The function supports two kinds of forward behaviors. If the ``mask`` is
``True``, the function generates a random mask, masks some patches and
computes the hidden features of the visible patches, which means the
function runs as masked image modeling pre-training; if the ``mask`` is
``None`` or ``False``, the forward function calls ``super().forward()``,
which extracts features from images without masking.
Args:
    x (torch.Tensor): Input images, which is of shape B x C x H x W.
    mask (bool, optional): To indicate whether the forward function
        generates a ``mask`` or not. Defaults to True.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
mask and the ids to restore original image.
- ``x`` (torch.Tensor): hidden features, which is of shape
    B x (L * (1 - mask_ratio)) x C.
- ``mask`` (torch.Tensor): mask used to mask image.
- ``ids_restore`` (torch.Tensor): ids to restore original image.
"""
if mask is None or mask is False:
return super().forward(x)
else:
B, C, H, W = x.shape
ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio)
x = self.patch_embed(x)
x = torch.gather(
x,
dim=1,
index=ids_keep[:, :, None, None,
None].expand(-1, -1, *x.shape[2:]))
outs = []
for blk in self.blocks[:-self.num_main_blocks]:
if isinstance(blk, PatchMerge):
outs.append(x)
x = blk(x)
x = x[..., 0, 0, :]
if self.ape:
pos_embed = self.interpolate_pos_encoding(x, H, W)
pos_embed = torch.gather(
pos_embed.expand(B, -1, -1),
dim=1,
index=ids_keep[:, :, None].expand(-1, -1,
pos_embed.shape[2]),
)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks[-self.num_main_blocks:]:
x = blk(x)
outs.append(x)
return (tuple(outs), mask, ids_restore)
def forward_clip(self,
x: torch.Tensor,
mask: Optional[bool] = True) -> Tuple:
"""Generate features for masked images.
The function supports two kinds of forward behaviors. If the ``mask`` is
``True``, the function replaces the masked patches with mask tokens and
computes the hidden features, which means the function runs as masked
image modeling pre-training; if the ``mask`` is ``None`` or ``False``,
the forward function calls ``super().forward()``, which extracts
features from images without masking.
Args:
    x (torch.Tensor): Input images, which is of shape B x C x H x W.
    mask (torch.Tensor, optional): The mask of the image patches; if
        ``None`` or ``False``, no masking is applied. Defaults to True.
Returns:
    Tuple[torch.Tensor, ...]: Hidden features from the intermediate
    stages and the final stage.
"""
if mask is None or mask is False:
return super().forward(x)
else:
B, C, H, W = x.shape
x = self.patch_embed(x)
outs = []
for blk in self.blocks[:-self.num_main_blocks]:
if isinstance(blk, PatchMerge):
outs.append(x)
x = blk(x)
x = x[..., 0, 0, :]
B, L, _ = x.shape
mask_token = self.mask_token.expand(B, L, -1)
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
x = x * (1. - w) + mask_token * w
if self.ape:
pos_embed = self.interpolate_pos_encoding(x, H, W)
x = x + pos_embed
x = self.pos_drop(x)
rpe_index = True if self.rpe else None
for blk in self.blocks[-self.num_main_blocks:]:
x = blk(x, rpe_index)
outs.append(x)
return tuple(outs)
def forward(self, x: torch.Tensor, mask: Optional[bool] = True) -> Tuple:
"""Generate features for masked images.
The function supports two kinds of forward behaviors. If the ``mask`` is
``True``, the function generates a mask, masks some patches randomly and
computes the hidden features of the visible patches, which means the
function runs as masked image modeling pre-training; if the ``mask`` is
``None`` or ``False``, the forward function calls ``super().forward()``,
which extracts features from images without masking.
Args:
    x (torch.Tensor): Input images, which is of shape B x C x H x W.
    mask (bool, optional): To indicate whether the forward function
        generates a ``mask`` or not. Defaults to True.
Returns:
    Tuple: Hidden features; in the 'pixel' branch also the mask and the
    ids to restore the original image.
"""
if self.reconstruction_type == 'pixel':
return self.forward_pixel(x, mask)
return self.forward_clip(x, mask)
@MODELS.register_module()
class iTPN(BaseSelfSupervisor):
"""iTPN.
Implementation of `iTPN: Integrally Pre-Trained Transformer Pyramid
Networks <https://arxiv.org/abs/2211.12735>`_.
"""
def extract_feat(self, inputs: torch.Tensor):
return self.backbone(inputs, mask=None)
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
**kwargs) -> Dict[str, torch.Tensor]:
"""The forward function in training.
Args:
inputs (torch.Tensor): The input images.
data_samples (List[DataSample]): All elements required
during the forward function.
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
if self.backbone.reconstruction_type == 'pixel':
latent, mask, ids_restore = self.backbone(inputs)
pred = self.neck(latent, ids_restore)
loss = self.head.loss(pred, inputs, mask)
else:
mask = torch.stack(
[data_sample.mask for data_sample in data_samples])
img_latent = self.backbone(inputs[0], mask)
# inputs[1] is the target image
with torch.no_grad():
target = self.target_generator(inputs[1])[0]
target = target.detach()
# iTPN contains a neck module
feats = self.neck(img_latent)
loss = self.head.loss(feats, target[:, 1:, :], mask)
losses = dict(loss=loss)
return losses

View File

@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from mmpretrain.models import VisionTransformer
from mmpretrain.models import HiViT, VisionTransformer
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
from ..utils import build_2d_sincos_position_embedding
@@ -234,3 +234,183 @@ class MAE(BaseSelfSupervisor):
loss = self.head.loss(pred, inputs, mask)
losses = dict(loss=loss)
return losses
@MODELS.register_module()
class MAEHiViT(HiViT):
"""HiViT for MAE pre-training.
A PyTorch implementation of: `HiViT: A Simple and More Efficient Design
of Hierarchical Vision Transformer <https://arxiv.org/abs/2205.14949>`_.
This module implements the patch masking in MAE and initializes the
position embedding with sine-cosine position embedding.
Args:
    arch (str | dict): HiViT architecture. Defaults to 'b'.
    img_size (int | tuple): Input image size. Defaults to 224.
    patch_size (int | tuple): The patch size. Defaults to 16.
    inner_patches (int): The inner patches within a token. Defaults to 4.
    out_indices (Sequence | int): Output from which stages.
        Defaults to ``[23]``, the last stage.
    drop_rate (float): Probability of an element to be zeroed.
        Defaults to 0.
    drop_path_rate (float): Stochastic depth rate. Defaults to 0.
    norm_cfg (dict): Config dict for normalization layer.
        Defaults to ``dict(type='LN')``.
    ape (bool): If True, add absolute position embedding to
        the patch embedding. Defaults to True.
    rpe (bool): If True, add relative position embedding to
        the patch embedding. Defaults to False.
    layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
    mask_ratio (float): The ratio of total number of patches to be masked.
        Defaults to 0.75.
    init_cfg (Union[List[dict], dict], optional): Initialization config
        dict. Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'b',
img_size: int = 224,
patch_size: int = 16,
inner_patches: int = 4,
out_indices: Union[list, int] = [23],
drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
norm_cfg: dict = dict(type='LN', eps=1e-6),
ape: bool = True,
rpe: bool = False,
layer_scale_init_value: float = 0.0,
mask_ratio: float = 0.75,
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
inner_patches=inner_patches,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
ape=ape,
rpe=rpe,
layer_scale_init_value=layer_scale_init_value,
init_cfg=init_cfg)
self.pos_embed.requires_grad = False
self.mask_ratio = mask_ratio
self.num_patches = self.patch_embed.num_patches
def init_weights(self) -> None:
"""Initialize position embedding, patch embedding."""
super().apply(self._init_weights)
pos_embed = build_2d_sincos_position_embedding(
int(self.num_patches**.5),
self.pos_embed.shape[-1],
cls_token=False)
self.pos_embed.data.copy_(pos_embed.float())
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
def masking_id(
self, batch_size,
mask_ratio) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate the mask for MAE Pre-training.
Args:
batch_size: The batch size of input data
mask_ratio: The mask ratio of total patches.
Defaults to 0.75.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the ids
for the tokens retained, the ids to restore original image,
and the mask
"""
N, L = batch_size, self.pos_embed.size(1)
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(
N, L, device=self.pos_embed.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=self.pos_embed.device)
mask[:, :ids_keep.size(1)] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return ids_keep, ids_restore, mask
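The masking scheme above is the standard MAE recipe: a per-sample random shuffle, keeping the first ``1 - mask_ratio`` fraction of tokens, then rebuilding a binary mask in the original patch order. A standalone sketch of the same steps:
import torch
N, L, mask_ratio = 2, 196, 0.75
len_keep = int(L * (1 - mask_ratio))       # 49 visible patches
noise = torch.rand(N, L)
ids_shuffle = torch.argsort(noise, dim=1)  # low noise -> kept
ids_restore = torch.argsort(ids_shuffle, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
mask = torch.ones(N, L)
mask[:, :len_keep] = 0                     # 0 = keep, 1 = masked
mask = torch.gather(mask, dim=1, index=ids_restore)
assert mask.sum(dim=1).eq(L - len_keep).all()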
def forward(
self,
x: torch.Tensor,
mask: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate features for masked images.
The function supports two kinds of forward behaviors. If the ``mask`` is
``True``, the function generates a random mask, masks some patches and
computes the hidden features of the visible patches, which means the
function runs as masked image modeling pre-training; if the ``mask`` is
``None`` or ``False``, the forward function calls ``super().forward()``,
which extracts features from images without masking.
Args:
    x (torch.Tensor): Input images, which is of shape B x C x H x W.
    mask (bool, optional): To indicate whether the forward function
        generates a ``mask`` or not. Defaults to True.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
mask and the ids to restore original image.
- ``x`` (torch.Tensor): hidden features, which is of shape
    B x (L * (1 - mask_ratio)) x C.
- ``mask`` (torch.Tensor): mask used to mask image.
- ``ids_restore`` (torch.Tensor): ids to restore original image.
"""
if mask is None or mask is False:
return super().forward(x)
else:
B, C, H, W = x.shape
ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio)
x = self.patch_embed(x)
x = torch.gather(
x,
dim=1,
index=ids_keep[:, :, None, None,
None].expand(-1, -1, *x.shape[2:]))
for blk in self.blocks[:-self.num_main_blocks]:
x = blk(x)
x = x[..., 0, 0, :]
if self.ape:
pos_embed = self.interpolate_pos_encoding(x, H, W)
pos_embed = torch.gather(
pos_embed.expand(B, -1, -1),
dim=1,
index=ids_keep[:, :, None].expand(-1, -1,
pos_embed.shape[2]),
)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks[-self.num_main_blocks:]:
x = blk(x)
return (x, mask, ids_restore)
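For reference, a hypothetical config wiring ``MAEHiViT`` into the MAE algorithm, mirroring the iTPN unit test further below; the decoder settings are illustrative and not copied from a shipped config:
model = dict(
    type='MAE',
    backbone=dict(type='MAEHiViT', arch='b', patch_size=16, mask_ratio=0.75),
    neck=dict(
        type='MAEPretrainDecoder',
        patch_size=16,
        in_chans=3,
        embed_dim=512,  # assumed to match HiViT-B's final stage width
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.),
    head=dict(
        type='MAEPretrainHead',
        norm_pix=True,
        patch_size=16,
        loss=dict(type='PixelReconstructionLoss', criterion='L2')))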

View File

@@ -76,3 +76,5 @@ Import:
- configs/flamingo/metafile.yml
- configs/blip2/metafile.yml
- configs/chinese_clip/metafile.yml
- configs/itpn/metafile.yml
- configs/hivit/metafile.yml

View File

@@ -34,6 +34,7 @@ test_list = [
backbone=mmpretrain.models.VisionTransformer,
forward=False,
backward=False),
Cfg(name='hivit-tiny-p16_16xb64_in1k', backbone=mmpretrain.models.HiViT),
]

View File

@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmpretrain.models import iTPN
from mmpretrain.structures import DataSample
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_itpn():
data_preprocessor = {
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'to_rgb': True
}
backbone = dict(
type='iTPNHiViT',
arch='base',
reconstruction_type='pixel',
mask_ratio=0.75)
neck = dict(
type='iTPNPretrainDecoder',
num_patches=196,
patch_size=16,
in_chans=3,
embed_dim=512,
decoder_embed_dim=512,
decoder_depth=6,
decoder_num_heads=16,
mlp_ratio=4.,
reconstruction_type='pixel',
# transformer pyramid
fpn_dim=256,
fpn_depth=2,
num_outs=3,
)
head = dict(
type='MAEPretrainHead',
norm_pix=True,
patch_size=16,
loss=dict(type='PixelReconstructionLoss', criterion='L2'))
alg = iTPN(
backbone=backbone,
neck=neck,
head=head,
data_preprocessor=data_preprocessor)
fake_data = {
'inputs': torch.randn((2, 3, 224, 224)),
'data_samples': [DataSample() for _ in range(2)]
}
fake_inputs = alg.data_preprocessor(fake_data)
fake_outputs = alg(**fake_inputs, mode='loss')
assert isinstance(fake_outputs['loss'].item(), float)