[Feature] Support iTPN and HiViT (#1584)
* hivit added * Update hivit.py * Update hivit.py * Add files via upload * Update __init__.py * Add files via upload * Update __init__.py * Add files via upload * Update hivit.py * Add files via upload * Add files via upload * Add files via upload * Add files via upload * Update itpn.py * Add files via upload * Update __init__.py * Update mae_hivit-base-p16.py * Delete mim_itpn-base-p16.py * Add files via upload * Update itpn_hivit-base-p16.py * Update itpn.py * Update hivit.py * Update __init__.py * Update mae.py * Delete hivit.py * Update __init__.py * Delete configs/itpn directory * Add files via upload * Add files via upload * Delete configs/hivit directory * Add files via upload * refactor and add metafile and readme * update clip * add ut * update ut * update * update docstring * update model.rst --------- Co-authored-by: 田运杰 <48153283+sunsmarterjie@users.noreply.github.com>pull/1637/head
parent
1f07c92ed1
commit
e4c4a81b56
|
@ -0,0 +1,49 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
type='TwoNormDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# clip mean & std
|
||||
second_mean=[0.48145466 * 255, 0.4578275 * 255, 0.40821073 * 255],
|
||||
second_std=[0.26862954 * 255, 0.26130258 * 255, 0.27577711 * 255],
|
||||
to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ColorJitter',
|
||||
brightness=0.4,
|
||||
contrast=0.4,
|
||||
saturation=0.4,
|
||||
hue=0.),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandomResizedCropAndInterpolationWithTwoPic',
|
||||
size=224,
|
||||
second_size=224,
|
||||
interpolation='bicubic',
|
||||
second_interpolation='bicubic',
|
||||
scale=(0.2, 1.0)),
|
||||
dict(
|
||||
type='BEiTMaskGenerator',
|
||||
input_size=(14, 14),
|
||||
num_masking_patches=75,
|
||||
max_num_patches=75,
|
||||
min_num_patches=16),
|
||||
dict(type='PackInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=256,
|
||||
num_workers=8,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
collate_fn=dict(type='default_collate'),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix=dict(img_path='train/'),
|
||||
pipeline=train_pipeline))
|
|
@ -0,0 +1,83 @@
|
|||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
data_root = 'data/imagenet/'
|
||||
data_preprocessor = dict(
|
||||
num_classes=1000,
|
||||
# RGB format normalization parameters
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
# convert image from BGR to RGB
|
||||
to_rgb=True,
|
||||
)
|
||||
|
||||
bgr_mean = data_preprocessor['mean'][::-1]
|
||||
bgr_std = data_preprocessor['std'][::-1]
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
scale=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies='timm_increasing',
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=bgr_mean,
|
||||
fill_std=bgr_std),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='ResizeEdge',
|
||||
scale=256,
|
||||
edge='short',
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='PackInputs'),
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/train.txt',
|
||||
data_prefix='train',
|
||||
pipeline=train_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
)
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=64,
|
||||
num_workers=5,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
ann_file='meta/val.txt',
|
||||
data_prefix='val',
|
||||
pipeline=test_pipeline),
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
)
|
||||
val_evaluator = dict(type='Accuracy', topk=(1, 5))
|
||||
|
||||
# If you want standard test, please manually configure the test dataset
|
||||
test_dataloader = val_dataloader
|
||||
test_evaluator = val_evaluator
|
|
@ -0,0 +1,28 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='HiViT',
|
||||
arch='base',
|
||||
img_size=224,
|
||||
ape=True,
|
||||
rpe=True,
|
||||
drop_path_rate=0.5),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=512,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]),
|
||||
)
|
|
@ -0,0 +1,28 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='HiViT',
|
||||
arch='small',
|
||||
img_size=224,
|
||||
ape=True,
|
||||
rpe=True,
|
||||
drop_path_rate=0.3),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]),
|
||||
)
|
|
@ -0,0 +1,28 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='HiViT',
|
||||
arch='tiny',
|
||||
img_size=224,
|
||||
ape=True,
|
||||
rpe=True,
|
||||
drop_path_rate=0.05),
|
||||
neck=dict(type='GlobalAveragePooling'),
|
||||
head=dict(
|
||||
type='LinearClsHead',
|
||||
num_classes=1000,
|
||||
in_channels=384,
|
||||
init_cfg=None, # suppress the default init_cfg of LinearClsHead.
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
|
||||
cal_acc=False),
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
|
||||
],
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='Mixup', alpha=0.8),
|
||||
dict(type='CutMix', alpha=1.0)
|
||||
]),
|
||||
)
|
|
@ -0,0 +1,33 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='iTPN',
|
||||
backbone=dict(
|
||||
type='iTPNHiViT',
|
||||
arch='base',
|
||||
reconstruction_type='pixel',
|
||||
mask_ratio=0.75),
|
||||
neck=dict(
|
||||
type='iTPNPretrainDecoder',
|
||||
num_patches=196,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=512,
|
||||
decoder_embed_dim=512,
|
||||
decoder_depth=6,
|
||||
decoder_num_heads=16,
|
||||
mlp_ratio=4.,
|
||||
reconstruction_type='pixel',
|
||||
# transformer pyramid
|
||||
fpn_dim=256,
|
||||
fpn_depth=2,
|
||||
num_outs=3,
|
||||
),
|
||||
head=dict(
|
||||
type='MAEPretrainHead',
|
||||
norm_pix=True,
|
||||
patch_size=16,
|
||||
loss=dict(type='PixelReconstructionLoss', criterion='L2')),
|
||||
init_cfg=[
|
||||
dict(type='Xavier', layer='Linear', distribution='uniform'),
|
||||
dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
|
||||
])
|
|
@ -0,0 +1,24 @@
|
|||
# model settings
|
||||
model = dict(
|
||||
type='MAE',
|
||||
backbone=dict(
|
||||
type='MIMHiViT', patch_size=16, arch='base', mask_ratio=0.75),
|
||||
neck=dict(
|
||||
type='MAEPretrainDecoder',
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=512,
|
||||
decoder_embed_dim=512,
|
||||
decoder_depth=6,
|
||||
decoder_num_heads=16,
|
||||
mlp_ratio=4.,
|
||||
),
|
||||
head=dict(
|
||||
type='MAEPretrainHead',
|
||||
norm_pix=True,
|
||||
patch_size=16,
|
||||
loss=dict(type='PixelReconstructionLoss', criterion='L2')),
|
||||
init_cfg=[
|
||||
dict(type='Xavier', layer='Linear', distribution='uniform'),
|
||||
dict(type='Constant', layer='LayerNorm', val=1.0, bias=0.0)
|
||||
])
|
|
@ -0,0 +1,41 @@
|
|||
# for batch in each gpu is 128, 8 gpu
|
||||
# lr = 5e-4 * 128 * 8 / 512 = 0.001
|
||||
optim_wrapper = dict(
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=5e-4 * 1024 / 512,
|
||||
weight_decay=0.05,
|
||||
eps=1e-8,
|
||||
betas=(0.9, 0.999)),
|
||||
paramwise_cfg=dict(
|
||||
norm_decay_mult=0.0,
|
||||
bias_decay_mult=0.0,
|
||||
flat_decay_mult=0.0,
|
||||
custom_keys={
|
||||
'.pos_embed': dict(decay_mult=0.0),
|
||||
'.relative_position_bias_table': dict(decay_mult=0.0)
|
||||
}),
|
||||
)
|
||||
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
# warm up learning rate scheduler
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-3,
|
||||
by_epoch=True,
|
||||
end=20,
|
||||
# update by iter
|
||||
convert_to_iter_based=True),
|
||||
# main learning rate scheduler
|
||||
dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=True, begin=20)
|
||||
]
|
||||
|
||||
# train, val, test setting
|
||||
train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=1)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR,
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=1024)
|
|
@ -0,0 +1,81 @@
|
|||
# HiViT
|
||||
|
||||
> [HiViT: A Simple and More Efficient Design of Hierarchical Vision Transformer](https://arxiv.org/abs/2205.14949)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
Recently, masked image modeling (MIM) has offered a new methodology of self-supervised pre-training of vision transformers. A key idea of efficient implementation is to discard the masked image patches (or tokens) throughout the target network (encoder), which requires the encoder to be a plain vision transformer (e.g., ViT), albeit hierarchical vision transformers (e.g., Swin Transformer) have potentially better properties in formulating vision inputs. In this paper, we offer a new design of hierarchical vision transformers named HiViT (short for Hierarchical ViT) that enjoys both high efficiency and good performance in MIM. The key is to remove the unnecessary "local inter-unit operations", deriving structurally simple hierarchical vision transformers in which mask-units can be serialized like plain vision transformers. For this purpose, we start with Swin Transformer and (i) set the masking unit size to be the token size in the main stage of Swin Transformer, (ii) switch off inter-unit self-attentions before the main stage, and (iii) eliminate all operations after the main stage. Empirical studies demonstrate the advantageous performance of HiViT in terms of fully-supervised, self-supervised, and transfer learning. In particular, in running MAE on ImageNet-1K, HiViT-B reports a +0.6% accuracy gain over ViT-B and a 1.9$\times$ speed-up over Swin-B, and the performance gain generalizes to downstream tasks of detection and segmentation. Code will be made publicly available.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/4a99cf9d-15df-4866-8750-bd2c3db5d894" width="80%"/>
|
||||
</div>
|
||||
|
||||
## How to use it?
|
||||
|
||||
<!-- [TABS-BEGIN] -->
|
||||
|
||||
<!-- **Predict image**
|
||||
|
||||
```python
|
||||
from mmpretrain import inference_model
|
||||
|
||||
predict = inference_model('hivit-tiny-p16_16xb64_in1k', 'demo/bird.JPEG')
|
||||
print(predict['pred_class'])
|
||||
print(predict['pred_score'])
|
||||
```
|
||||
|
||||
<!-- **Use the model** -->
|
||||
|
||||
<!-- ```python
|
||||
import torch
|
||||
from mmpretrain import get_model
|
||||
|
||||
model = get_model('hivit-tiny-p16_16xb64_in1k', pretrained=True)
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
out = model(inputs)
|
||||
print(type(out))
|
||||
# To extract features.
|
||||
feats = model.extract_feat(inputs)
|
||||
print(type(feats))
|
||||
``` -->
|
||||
|
||||
**Train/Test Command**
|
||||
|
||||
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
|
||||
|
||||
Train:
|
||||
|
||||
```shell
|
||||
python tools/train.py configs/hivit/hivit-tiny-p16_16xb64_in1k.py
|
||||
```
|
||||
|
||||
<!-- Test:
|
||||
|
||||
```shell
|
||||
python tools/test.py configs/hivit/hivit-tiny-p16_16xb64_in1k.py None
|
||||
``` -->
|
||||
|
||||
<!-- [TABS-END] -->
|
||||
|
||||
## Models and results
|
||||
|
||||
### Image Classification on ImageNet-1k
|
||||
|
||||
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Config | Download |
|
||||
| :---------------------------- | :----------: | :--------: | :-------: | :-------: | :--------------------------------------: | :------: |
|
||||
| `hivit-tiny-p16_16xb64_in1k` | From scratch | 19.18 | 4.60 | 82.10 | [config](hivit-tiny-p16_16xb64_in1k.py) | N/A |
|
||||
| `hivit-small-p16_16xb64_in1k` | From scratch | 37.53 | 9.07 | N/A | [config](hivit-small-p16_16xb64_in1k.py) | N/A |
|
||||
| `hivit-base-p16_16xb64_in1k` | From scratch | 79.05 | 18.47 | N/A | [config](hivit-base-p16_16xb64_in1k.py) | N/A |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{zhanghivit,
|
||||
title={HiViT: A Simpler and More Efficient Design of Hierarchical Vision Transformer},
|
||||
author={Zhang, Xiaosong and Tian, Yunjie and Xie, Lingxi and Huang, Wei and Dai, Qi and Ye, Qixiang and Tian, Qi},
|
||||
booktitle={International Conference on Learning Representations},
|
||||
year={2023},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hivit/base_224.py',
|
||||
'../_base_/datasets/imagenet_bs64_hivit_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_hivit.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hivit/small_224.py',
|
||||
'../_base_/datasets/imagenet_bs64_hivit_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_hivit.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = [
|
||||
'../_base_/models/hivit/tiny_224.py',
|
||||
'../_base_/datasets/imagenet_bs64_hivit_224.py',
|
||||
'../_base_/schedules/imagenet_bs1024_adamw_hivit.py',
|
||||
'../_base_/default_runtime.py'
|
||||
]
|
||||
|
||||
# schedule settings
|
||||
optim_wrapper = dict(clip_grad=dict(max_norm=5.0))
|
|
@ -0,0 +1,63 @@
|
|||
Collections:
|
||||
- Name: HiViT
|
||||
Metadata:
|
||||
Architecture:
|
||||
- Dense Connections
|
||||
- Dropout
|
||||
- GELU
|
||||
- Layer Normalization
|
||||
- Multi-Head Attention
|
||||
- Scaled Dot-Product Attention
|
||||
Paper:
|
||||
Title: 'HiViT: A Simple and More Efficient Design of Hierarchical Vision Transformer'
|
||||
URL: https://arxiv.org/abs/2205.14949
|
||||
README: configs/hivit/README.md
|
||||
Code:
|
||||
URL: null
|
||||
Version: null
|
||||
|
||||
Models:
|
||||
- Name: hivit-tiny-p16_16xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 4603000000
|
||||
Parameters: 19181000
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
In Collection: HiViT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.1
|
||||
Task: Image Classification
|
||||
Weights:
|
||||
Config: configs/hivit/hivit-tiny-p16_16xb64_in1k.py
|
||||
|
||||
- Name: hivit-small-p16_16xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 9072000000
|
||||
Parameters: 37526000
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
In Collection: HiViT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy:
|
||||
Task: Image Classification
|
||||
Weights:
|
||||
Config: configs/hivit/hivit-small-p16_16xb64_in1k.py
|
||||
|
||||
- Name: hivit-base-p16_16xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 18474000000
|
||||
Parameters: 79051000
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
In Collection: HiViT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy:
|
||||
Task: Image Classification
|
||||
Weights:
|
||||
Config: configs/hivit/hivit-base-p16_16xb64_in1k.py
|
|
@ -0,0 +1,65 @@
|
|||
# iTPN
|
||||
|
||||
> [Integrally Pre-Trained Transformer Pyramid Networks](https://arxiv.org/abs/2211.12735)
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
## Abstract
|
||||
|
||||
In this paper, we present an integral pre-training framework based on masked image modeling (MIM). We advocate for pre-training the backbone and neck jointly so that the transfer gap between MIM and downstream recognition tasks is minimal. We make two technical contributions. First, we unify the reconstruction and recognition necks by inserting a feature pyramid into the pre-training stage. Second, we complement mask image modeling (MIM) with masked feature modeling (MFM) that offers multi-stage supervision to the feature pyramid. The pre-trained models, termed integrally pre-trained transformer pyramid networks (iTPNs), serve as powerful foundation models for visual recognition. In particular, the base/large-level iTPN achieves an 86.2%/87.8% top-1 accuracy on ImageNet-1K, a 53.2%/55.6% box AP on COCO object detection with 1x training schedule using Mask-RCNN, and a 54.7%/57.7% mIoU on ADE20K semantic segmentation using UPerHead -- all these results set new records. Our work inspires the community to work on unifying upstream pre-training and downstream fine-tuning tasks. Code and the pre-trained models will be released at https://github.com/sunsmarterjie/iTPN.
|
||||
|
||||
<div align=center>
|
||||
<img src="https://github.com/open-mmlab/mmpretrain/assets/36138628/2e53d5b5-300e-4640-8507-c1173965ca62" width="80%"/>
|
||||
</div>
|
||||
|
||||
## How to use it?
|
||||
|
||||
<!-- [TABS-BEGIN] -->
|
||||
|
||||
<!-- **Use the model**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from mmpretrain import get_model
|
||||
|
||||
model = get_model('itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k', pretrained=True)
|
||||
inputs = torch.rand(1, 3, 224, 224)
|
||||
out = model(inputs)
|
||||
print(type(out))
|
||||
# To extract features.
|
||||
feats = model.extract_feat(inputs)
|
||||
print(type(feats))
|
||||
``` -->
|
||||
|
||||
**Train/Test Command**
|
||||
|
||||
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
|
||||
|
||||
Train:
|
||||
|
||||
```shell
|
||||
python tools/train.py configs/itpn/itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k.py
|
||||
```
|
||||
|
||||
<!-- [TABS-END] -->
|
||||
|
||||
## Models and results
|
||||
|
||||
### Pretrained models
|
||||
|
||||
| Model | Params (M) | Flops (G) | Config | Download |
|
||||
| :------------------------------------------------------ | :--------: | :-------: | :----------------------------------------------------------------: | :------: |
|
||||
| `itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k` | 233.00 | 18.47 | [config](itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k.py) | N/A |
|
||||
| `itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k` | 103.00 | 18.47 | [config](itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k.py) | N/A |
|
||||
| `itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k` | 314.00 | 63.98 | [config](itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k.py) | N/A |
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{tian2022integrally,
|
||||
title={Integrally Pre-Trained Transformer Pyramid Networks},
|
||||
author={Tian, Yunjie and Xie, Lingxi and Wang, Zhaozhi and Wei, Longhui and Zhang, Xiaopeng and Jiao, Jianbin and Wang, Yaowei and Tian, Qi and Ye, Qixiang},
|
||||
journal={arXiv preprint arXiv:2211.12735},
|
||||
year={2022}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,84 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs256_itpn.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='iTPN',
|
||||
backbone=dict(
|
||||
type='iTPNHiViT',
|
||||
arch='base',
|
||||
drop_path_rate=0.0,
|
||||
rpe=True,
|
||||
layer_scale_init_value=0.1,
|
||||
reconstruction_type='clip'),
|
||||
neck=dict(
|
||||
type='iTPNPretrainDecoder',
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=512,
|
||||
mlp_ratio=4.,
|
||||
reconstruction_type='clip',
|
||||
# transformer pyramid
|
||||
fpn_dim=256,
|
||||
fpn_depth=2,
|
||||
num_outs=3,
|
||||
),
|
||||
head=dict(
|
||||
type='iTPNClipHead',
|
||||
embed_dims=512,
|
||||
num_embed=512,
|
||||
loss=dict(type='CosineSimilarityLoss')),
|
||||
target_generator=dict(
|
||||
type='CLIPGenerator',
|
||||
tokenizer_path= # noqa
|
||||
'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa
|
||||
),
|
||||
)
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
# betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 1600 epochs.
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=1.5e-3, betas=(0.9, 0.98), weight_decay=0.05),
|
||||
clip_grad=dict(max_norm=3.0),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'.norm': dict(decay_mult=0.0),
|
||||
'.pos_embed': dict(decay_mult=0.0),
|
||||
'.gamma': dict(decay_mult=0.0),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=10,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=1e-5,
|
||||
by_epoch=True,
|
||||
begin=10,
|
||||
end=300,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=2048)
|
|
@ -0,0 +1,84 @@
|
|||
_base_ = [
|
||||
'../_base_/datasets/imagenet_bs256_itpn.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
model = dict(
|
||||
type='iTPN',
|
||||
backbone=dict(
|
||||
type='iTPNHiViT',
|
||||
arch='base',
|
||||
drop_path_rate=0.1,
|
||||
rpe=True,
|
||||
layer_scale_init_value=0.1,
|
||||
reconstruction_type='clip'),
|
||||
neck=dict(
|
||||
type='iTPNPretrainDecoder',
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=512,
|
||||
mlp_ratio=4.,
|
||||
reconstruction_type='clip',
|
||||
# transformer pyramid
|
||||
fpn_dim=256,
|
||||
fpn_depth=2,
|
||||
num_outs=3,
|
||||
),
|
||||
head=dict(
|
||||
type='iTPNClipHead',
|
||||
embed_dims=512,
|
||||
num_embed=512,
|
||||
loss=dict(type='CrossEntropyLoss')),
|
||||
target_generator=dict(
|
||||
type='CLIPGenerator',
|
||||
tokenizer_path= # noqa
|
||||
'https://download.openmmlab.com/mmselfsup/1.x/target_generator_ckpt/clip_vit_base_16.pth.tar' # noqa
|
||||
),
|
||||
)
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
# betas: (0.9, 0.98) for 300 epochs and (0.9, 0.999) for 800/1600 epochs.
|
||||
optimizer=dict(
|
||||
type='AdamW', lr=1.5e-3, betas=(0.9, 0.999), weight_decay=0.05),
|
||||
clip_grad=dict(max_norm=3.0),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'.norm': dict(decay_mult=0.0),
|
||||
'.pos_embed': dict(decay_mult=0.0),
|
||||
'.gamma': dict(decay_mult=0.0),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=10,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=1e-5,
|
||||
by_epoch=True,
|
||||
begin=10,
|
||||
end=800,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=2048)
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = [
|
||||
'../_base_/models/itpn_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=1560,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=1600,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = [
|
||||
'../_base_/models/itpn_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=360,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=400,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = [
|
||||
'../_base_/models/itpn_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=760,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=800,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,61 @@
|
|||
_base_ = [
|
||||
'../_base_/models/itpn_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(type='iTPNHiViT', arch='large'),
|
||||
neck=dict(type='iTPNPretrainDecoder', embed_dim=768))
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'ln': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=1560,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=1600,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,61 @@
|
|||
_base_ = [
|
||||
'../_base_/models/itpn_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(type='iTPNHiViT', arch='large'),
|
||||
neck=dict(type='iTPNPretrainDecoder', embed_dim=768))
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'ln': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=360,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=400,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,61 @@
|
|||
_base_ = [
|
||||
'../_base_/models/itpn_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(type='iTPNHiViT', arch='large'),
|
||||
neck=dict(type='iTPNPretrainDecoder', embed_dim=768))
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'ln': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=760,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=800,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,50 @@
|
|||
Collections:
|
||||
- Name: iTPN
|
||||
Metadata:
|
||||
Architecture:
|
||||
- Dense Connections
|
||||
- GELU
|
||||
- Layer Normalization
|
||||
- Multi-Head Attention
|
||||
- Scaled Dot-Product Attention
|
||||
Paper:
|
||||
Title: 'Integrally Pre-Trained Transformer Pyramid Networks'
|
||||
URL: https://arxiv.org/abs/2211.12735
|
||||
README: configs/itpn/README.md
|
||||
Code:
|
||||
URL: null
|
||||
Version: null
|
||||
|
||||
Models:
|
||||
- Name: itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k
|
||||
Metadata:
|
||||
FLOPs: 18474000000
|
||||
Parameters: 233000000
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
In Collection: iTPN
|
||||
Results: null
|
||||
Weights:
|
||||
Config: configs/itpn/itpn-clip-b_hivit-base-p16_8xb256-amp-coslr-800e_in1k.py
|
||||
|
||||
- Name: itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k
|
||||
Metadata:
|
||||
FLOPs: 18474000000
|
||||
Parameters: 103000000
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
In Collection: iTPN
|
||||
Results: null
|
||||
Weights:
|
||||
Config: configs/itpn/itpn-pixel_hivit-base-p16_8xb512-amp-coslr-800e_in1k.py
|
||||
|
||||
- Name: itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k
|
||||
Metadata:
|
||||
FLOPs: 63977000000
|
||||
Parameters: 314000000
|
||||
Training Data:
|
||||
- ImageNet-1k
|
||||
In Collection: iTPN
|
||||
Results: null
|
||||
Weights:
|
||||
Config: configs/itpn/itpn-pixel_hivit-large-p16_8xb512-amp-coslr-800e_in1k.py
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mae_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=1560,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=1600,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mae_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=360,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=400,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,56 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mae_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1e-4,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=760,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=800,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,61 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mae_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(type='MAEHiViT', arch='large'),
|
||||
neck=dict(type='MAEPretrainDecoder', embed_dim=768))
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.0001,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=1560,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=1600,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=1600)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,61 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mae_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(type='MIMHiViT', arch='large'),
|
||||
neck=dict(type='MAEPretrainDecoder', embed_dim=768))
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.0001,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=360,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=400,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=400)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -0,0 +1,61 @@
|
|||
_base_ = [
|
||||
'../_base_/models/mae_hivit-base-p16.py',
|
||||
'../_base_/datasets/imagenet_bs512_mae.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
backbone=dict(type='MAEHiViT', arch='large'),
|
||||
neck=dict(type='MAEPretrainDecoder', embed_dim=768))
|
||||
|
||||
# optimizer wrapper
|
||||
optim_wrapper = dict(
|
||||
type='AmpOptimWrapper',
|
||||
loss_scale='dynamic',
|
||||
optimizer=dict(
|
||||
type='AdamW',
|
||||
lr=1.5e-4 * 4096 / 256,
|
||||
betas=(0.9, 0.95),
|
||||
weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'norm': dict(decay_mult=0.0),
|
||||
'bias': dict(decay_mult=0.0),
|
||||
'pos_embed': dict(decay_mult=0.),
|
||||
'mask_token': dict(decay_mult=0.),
|
||||
}))
|
||||
|
||||
# learning rate scheduler
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=0.0001,
|
||||
by_epoch=True,
|
||||
begin=0,
|
||||
end=40,
|
||||
convert_to_iter_based=True),
|
||||
dict(
|
||||
type='CosineAnnealingLR',
|
||||
T_max=760,
|
||||
by_epoch=True,
|
||||
begin=40,
|
||||
end=800,
|
||||
convert_to_iter_based=True)
|
||||
]
|
||||
|
||||
# runtime settings
|
||||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=800)
|
||||
default_hooks = dict(
|
||||
# only keeps the latest 3 checkpoints
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
|
||||
|
||||
randomness = dict(seed=0, diff_rank_seed=True)
|
||||
|
||||
# auto resume
|
||||
resume = True
|
||||
find_unused_parameters = True
|
||||
|
||||
# NOTE: `auto_scale_lr` is for automatically scaling LR
|
||||
# based on the actual training batch size.
|
||||
auto_scale_lr = dict(base_batch_size=4096)
|
|
@ -66,6 +66,7 @@ Self-supervised Algorithms
|
|||
CAE
|
||||
DenseCL
|
||||
EVA
|
||||
iTPN
|
||||
MAE
|
||||
MILAN
|
||||
MaskFeat
|
||||
|
@ -88,6 +89,8 @@ like ``mask``, and here is the a list of these **modified backbone** modules.
|
|||
|
||||
BEiTPretrainViT
|
||||
CAEPretrainViT
|
||||
iTPNHiViT
|
||||
MAEHiViT
|
||||
MAEViT
|
||||
MILANViT
|
||||
MaskFeatViT
|
||||
|
@ -167,6 +170,7 @@ Backbones
|
|||
EfficientFormer
|
||||
EfficientNet
|
||||
EfficientNetV2
|
||||
HiViT
|
||||
HRNet
|
||||
HorNet
|
||||
InceptionV3
|
||||
|
@ -235,6 +239,7 @@ Necks
|
|||
NonLinearNeck
|
||||
SimMIMLinearDecoder
|
||||
SwAVNeck
|
||||
iTPNPretrainDecoder
|
||||
|
||||
.. module:: mmpretrain.models.heads
|
||||
|
||||
|
@ -271,6 +276,7 @@ Heads
|
|||
SwAVHead
|
||||
VigClsHead
|
||||
VisionTransformerClsHead
|
||||
iTPNClipHead
|
||||
|
||||
.. module:: mmpretrain.models.losses
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from .edgenext import EdgeNeXt
|
|||
from .efficientformer import EfficientFormer
|
||||
from .efficientnet import EfficientNet
|
||||
from .efficientnet_v2 import EfficientNetV2
|
||||
from .hivit import HiViT
|
||||
from .hornet import HorNet
|
||||
from .hrnet import HRNet
|
||||
from .inception_v3 import InceptionV3
|
||||
|
@ -120,4 +121,5 @@ __all__ = [
|
|||
'XCiT',
|
||||
'ViTSAM',
|
||||
'ViTEVA02',
|
||||
'HiViT',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,656 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.bricks import DropPath
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..utils import build_norm_layer, to_2tuple
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP block.
|
||||
|
||||
Args:
|
||||
in_features (int): Number of input dims.
|
||||
hidden_features (int): Number of hidden dims.
|
||||
out_feature (int): Number of out dims.
|
||||
act_layer: MLP activation layer.
|
||||
drop (float): MLP dropout rate.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Attention.
|
||||
|
||||
Args:
|
||||
input size (int): Input size.
|
||||
dim (int): Number of input dims.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool): Enable bias for qkv projections if True.
|
||||
qk_scale (float): The number of divider after q@k. Default to None.
|
||||
attn_drop (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
proj_drop (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
rpe (bool): If True, add relative position embedding to
|
||||
the patch embedding.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
dim,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
rpe=True):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
# define a parameter table of relative position bias
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * input_size - 1) *
|
||||
(2 * input_size - 1), num_heads)) if rpe else None
|
||||
if rpe:
|
||||
coords_h = torch.arange(input_size)
|
||||
coords_w = torch.arange(input_size)
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:, None, :]
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
relative_coords[:, :, 0] += input_size - 1
|
||||
relative_coords[:, :, 1] += input_size - 1
|
||||
relative_coords[:, :, 0] *= 2 * input_size - 1
|
||||
relative_position_index = relative_coords.sum(-1)
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, rpe_index=None, mask=None):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[
|
||||
2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
if rpe_index is not None:
|
||||
rpe_index = self.relative_position_index.view(-1)
|
||||
S = int(math.sqrt(rpe_index.size(-1)))
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
rpe_index].view(-1, S, S, self.num_heads)
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
0, 3, 1, 2).contiguous()
|
||||
attn = attn + relative_position_bias
|
||||
if mask is not None:
|
||||
mask = mask.bool()
|
||||
attn = attn.masked_fill(~mask[:, None, None, :], float('-inf'))
|
||||
attn = self.softmax(attn)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class BlockWithRPE(nn.Module):
|
||||
"""HiViT block.
|
||||
|
||||
Args:
|
||||
input_size (int): Input size.
|
||||
dim (int): Number of input dims.
|
||||
num_heads (int): Number of attention heads.
|
||||
mlp_ratio (int): Ratio of MLP hidden dim to embedding dim.
|
||||
qkv_bias (bool): Enable bias for qkv projections if True.
|
||||
qk_scale (float): The number of divider after q@k. Default to None.
|
||||
drop (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
attn_drop (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
drop_path (float): Stochastic depth rate. Defaults to 0.
|
||||
rpe (bool): If True, add relative position embedding to
|
||||
the patch embedding.
|
||||
layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
|
||||
act_layer: MLP activation layer.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
dim,
|
||||
num_heads=0.,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
rpe=True,
|
||||
layer_scale_init_value=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_cfg=dict(type='LN')):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
with_attn = num_heads > 0.
|
||||
|
||||
self.norm1 = build_norm_layer(norm_cfg, dim) if with_attn else None
|
||||
self.attn = Attention(
|
||||
input_size,
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
rpe=rpe,
|
||||
) if with_attn else None
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = build_norm_layer(norm_cfg, dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
if layer_scale_init_value > 0:
|
||||
self.gamma_1 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones(
|
||||
(dim)), requires_grad=True) if with_attn else None
|
||||
self.gamma_2 = nn.Parameter(
|
||||
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rpe_index=None, mask=None):
|
||||
if self.attn is not None:
|
||||
if self.gamma_1 is not None:
|
||||
x = x + self.drop_path(
|
||||
self.gamma_1 * self.attn(self.norm1(x), rpe_index, mask))
|
||||
else:
|
||||
x = x + self.drop_path(
|
||||
self.attn(self.norm1(x), rpe_index, mask))
|
||||
if self.gamma_2 is not None:
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""PatchEmbed for HiViT.
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
patch_size (int): Patch size. Defaults to 16.
|
||||
inner_patches (int): Inner patch. Defaults to 4.
|
||||
in_chans (int): Number of image input channels.
|
||||
embed_dim (int): Transformer embedding dimension.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
kernel_size (int): Kernel size.
|
||||
pad_size (int): Pad size.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
inner_patches=4,
|
||||
in_chans=3,
|
||||
embed_dim=128,
|
||||
norm_cfg=None,
|
||||
kernel_size=None,
|
||||
pad_size=None):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size) if not isinstance(img_size,
|
||||
tuple) else img_size
|
||||
patch_size = to_2tuple(patch_size)
|
||||
patches_resolution = [
|
||||
img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
||||
]
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.inner_patches = inner_patches
|
||||
self.patches_resolution = patches_resolution
|
||||
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
conv_size = [size // inner_patches for size in patch_size]
|
||||
kernel_size = kernel_size or conv_size
|
||||
pad_size = pad_size or 0
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=conv_size,
|
||||
padding=pad_size)
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
patches_resolution = (H // self.patch_size[0], W // self.patch_size[1])
|
||||
num_patches = patches_resolution[0] * patches_resolution[1]
|
||||
x = self.proj(x).view(
|
||||
B,
|
||||
-1,
|
||||
patches_resolution[0],
|
||||
self.inner_patches,
|
||||
patches_resolution[1],
|
||||
self.inner_patches,
|
||||
).permute(0, 2, 4, 3, 5, 1).reshape(B, num_patches, self.inner_patches,
|
||||
self.inner_patches, -1)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerge(nn.Module):
|
||||
"""PatchMerge for HiViT.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_cfg):
|
||||
super().__init__()
|
||||
self.norm = build_norm_layer(norm_cfg, dim * 4)
|
||||
self.reduction = nn.Linear(dim * 4, dim * 2, bias=False)
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
is_main_stage = len(x.shape) == 3
|
||||
if is_main_stage:
|
||||
B, N, C = x.shape
|
||||
S = int(math.sqrt(N))
|
||||
x = x.reshape(B, S // 2, 2, S // 2, 2, C) \
|
||||
.permute(0, 1, 3, 2, 4, 5) \
|
||||
.reshape(B, -1, 2, 2, C)
|
||||
x0 = x[..., 0::2, 0::2, :]
|
||||
x1 = x[..., 1::2, 0::2, :]
|
||||
x2 = x[..., 0::2, 1::2, :]
|
||||
x3 = x[..., 1::2, 1::2, :]
|
||||
|
||||
x = torch.cat([x0, x1, x2, x3], dim=-1)
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
if is_main_stage:
|
||||
x = x[:, :, 0, 0, :]
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class HiViT(BaseBackbone):
|
||||
"""HiViT.
|
||||
|
||||
A PyTorch implement of: `HiViT: A Simple and More Efficient Design
|
||||
of Hierarchical Vision Transformer <https://arxiv.org/abs/2205.14949>`_.
|
||||
|
||||
Args:
|
||||
arch (str | dict): Swin Transformer architecture. If use string, choose
|
||||
from 'tiny', 'small', and'base'. If use dict, it should
|
||||
have below keys:
|
||||
|
||||
- **embed_dims** (int): The dimensions of embedding.
|
||||
- **depths** (List[int]): The number of blocks in each stage.
|
||||
- **num_heads** (int): The number of heads in attention
|
||||
modules of each stage.
|
||||
|
||||
Defaults to 'tiny'.
|
||||
img_size (int): Input image size.
|
||||
patch_size (int): Patch size. Defaults to 16.
|
||||
inner_patches (int): Inner patch. Defaults to 4.
|
||||
in_chans (int): Number of image input channels.
|
||||
embed_dim (int): Transformer embedding dimension.
|
||||
depths (list[int]): Number of successive HiViT blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim
|
||||
in the first two stages.
|
||||
mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in
|
||||
the last stage.
|
||||
qkv_bias (bool): Enable bias for qkv projections if True.
|
||||
qk_scale (float): The number of divider after q@k. Default to None.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
attn_drop_rate (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
ape (bool): If True, add absolute position embedding to
|
||||
the patch embedding.
|
||||
rpe (bool): If True, add relative position embedding to
|
||||
the patch embedding.
|
||||
patch_norm (bool): If True, use norm_cfg for normalization layer.
|
||||
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
|
||||
-1 means not freezing any parameters. Defaults to -1.
|
||||
kernel_size (int): Kernel size.
|
||||
pad_size (int): Pad size.
|
||||
layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
arch_zoo = {
|
||||
**dict.fromkeys(['t', 'tiny'],
|
||||
{'embed_dims': 384,
|
||||
'depths': [1, 1, 10],
|
||||
'num_heads': 6}),
|
||||
**dict.fromkeys(['s', 'small'],
|
||||
{'embed_dims': 384,
|
||||
'depths': [2, 2, 20],
|
||||
'num_heads': 6}),
|
||||
**dict.fromkeys(['b', 'base'],
|
||||
{'embed_dims': 512,
|
||||
'depths': [2, 2, 24],
|
||||
'num_heads': 8}),
|
||||
**dict.fromkeys(['l', 'large'],
|
||||
{'embed_dims': 768,
|
||||
'depths': [2, 2, 40],
|
||||
'num_heads': 12}),
|
||||
} # yapf: disable
|
||||
|
||||
num_extra_tokens = 0
|
||||
|
||||
def __init__(self,
|
||||
arch='base',
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
inner_patches=4,
|
||||
in_chans=3,
|
||||
stem_mlp_ratio=3.,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.0,
|
||||
norm_cfg=dict(type='LN'),
|
||||
out_indices=[23],
|
||||
ape=True,
|
||||
rpe=False,
|
||||
patch_norm=True,
|
||||
frozen_stages=-1,
|
||||
kernel_size=None,
|
||||
pad_size=None,
|
||||
layer_scale_init_value=0.0,
|
||||
init_cfg=None):
|
||||
super(HiViT, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
if isinstance(arch, str):
|
||||
arch = arch.lower()
|
||||
assert arch in set(self.arch_zoo), \
|
||||
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
|
||||
self.arch_settings = self.arch_zoo[arch]
|
||||
else:
|
||||
essential_keys = {'embed_dims', 'depths', 'num_heads'}
|
||||
assert isinstance(arch, dict) and set(arch) == essential_keys, \
|
||||
f'Custom arch needs a dict with keys {essential_keys}'
|
||||
self.arch_settings = arch
|
||||
self.embed_dims = self.arch_settings['embed_dims']
|
||||
self.depths = self.arch_settings['depths']
|
||||
self.num_heads = self.arch_settings['num_heads']
|
||||
|
||||
self.num_stages = len(self.depths)
|
||||
self.ape = ape
|
||||
self.rpe = rpe
|
||||
self.patch_size = patch_size
|
||||
self.num_features = self.embed_dims
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.num_main_blocks = self.depths[-1]
|
||||
self.out_indices = out_indices
|
||||
self.out_indices[-1] = self.depths[-1] - 1
|
||||
|
||||
img_size = to_2tuple(img_size) if not isinstance(img_size,
|
||||
tuple) else img_size
|
||||
|
||||
embed_dim = self.embed_dims // 2**(self.num_stages - 1)
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
inner_patches=inner_patches,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_cfg=norm_cfg if patch_norm else None,
|
||||
kernel_size=kernel_size,
|
||||
pad_size=pad_size)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
Hp, Wp = self.patch_embed.patches_resolution
|
||||
|
||||
if rpe:
|
||||
assert Hp == Wp, 'If you use relative position, make sure H == W '
|
||||
'of input size'
|
||||
|
||||
# absolute position embedding
|
||||
if ape:
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches, self.num_features))
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
if rpe:
|
||||
# get pair-wise relative position index for each token inside the
|
||||
# window
|
||||
coords_h = torch.arange(Hp)
|
||||
coords_w = torch.arange(Wp)
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
|
||||
coords_flatten = torch.flatten(coords, 1)
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:, None, :]
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
||||
relative_coords[:, :, 0] += Hp - 1
|
||||
relative_coords[:, :, 1] += Wp - 1
|
||||
relative_coords[:, :, 0] *= 2 * Wp - 1
|
||||
relative_position_index = relative_coords.sum(-1)
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
# stochastic depth
|
||||
dpr = iter(
|
||||
x.item()
|
||||
for x in torch.linspace(0, drop_path_rate,
|
||||
sum(self.depths) + sum(self.depths[:-1])))
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList()
|
||||
for stage_i, stage_depth in enumerate(self.depths):
|
||||
is_main_stage = embed_dim == self.num_features
|
||||
nhead = self.num_heads if is_main_stage else 0
|
||||
ratio = mlp_ratio if is_main_stage else stem_mlp_ratio
|
||||
# every block not in main stage includes two mlp blocks
|
||||
stage_depth = stage_depth if is_main_stage else stage_depth * 2
|
||||
for _ in range(stage_depth):
|
||||
self.blocks.append(
|
||||
BlockWithRPE(
|
||||
Hp,
|
||||
embed_dim,
|
||||
nhead,
|
||||
ratio,
|
||||
qkv_bias,
|
||||
qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=next(dpr),
|
||||
rpe=rpe,
|
||||
norm_cfg=norm_cfg,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
))
|
||||
if stage_i + 1 < self.num_stages:
|
||||
self.blocks.append(PatchMerge(embed_dim, norm_cfg))
|
||||
embed_dim *= 2
|
||||
|
||||
self.frozen_stages = frozen_stages
|
||||
if self.frozen_stages > 0:
|
||||
self._freeze_stages()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def interpolate_pos_encoding(self, x, h, w):
|
||||
npatch = x.shape[1]
|
||||
N = self.pos_embed.shape[1]
|
||||
if npatch == N and w == h:
|
||||
return self.pos_embed
|
||||
patch_pos_embed = self.pos_embed
|
||||
dim = x.shape[-1]
|
||||
w0 = w // self.patch_size
|
||||
h0 = h // self.patch_size
|
||||
# we add a small number to avoid floating point error in interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
w0, h0 = w0 + 0.1, h0 + 0.1
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
|
||||
dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
|
||||
mode='bicubic',
|
||||
)
|
||||
assert int(h0) == patch_pos_embed.shape[-2] and int(
|
||||
w0) == patch_pos_embed.shape[-1]
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return patch_pos_embed
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
Hp, Wp = H // self.patch_size, W // self.patch_size
|
||||
|
||||
x = self.patch_embed(x)
|
||||
|
||||
outs = []
|
||||
for i, blk in enumerate(self.blocks[:-self.num_main_blocks]):
|
||||
x = blk(x)
|
||||
if i in self.out_indices:
|
||||
x = x.reshape(B, Hp, Wp, *x.shape[-3:]).permute(
|
||||
0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * x.shape[-3],
|
||||
Wp * x.shape[-2]).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
x = x[..., 0, 0, :]
|
||||
if self.ape:
|
||||
x = x + self.interpolate_pos_encoding(x, H, W)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rpe_index = True if self.rpe else None
|
||||
|
||||
for i, blk in enumerate(self.blocks[-self.num_main_blocks:]):
|
||||
x = blk(x, rpe_index)
|
||||
if i in self.out_indices:
|
||||
x = x.transpose(1, 2).view(B, -1, Hp, Wp).contiguous()
|
||||
outs.append(x)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def _freeze_stages(self):
|
||||
# freeze position embedding
|
||||
if self.pos_embed is not None:
|
||||
self.pos_embed.requires_grad = False
|
||||
# set dropout to eval model
|
||||
self.pos_drop.eval()
|
||||
# freeze patch embedding
|
||||
self.patch_embed.eval()
|
||||
for param in self.patch_embed.parameters():
|
||||
param.requires_grad = False
|
||||
# freeze layers
|
||||
for i in range(1, self.frozen_stages + 1):
|
||||
m = self.blocks[i - 1]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
# freeze the last layer norm
|
||||
for param in self.fc_norm.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def get_layer_depth(self, param_name: str, prefix: str = ''):
|
||||
"""Get the layer-wise depth of a parameter.
|
||||
|
||||
Args:
|
||||
param_name (str): The name of the parameter.
|
||||
prefix (str): The prefix for the parameter.
|
||||
Defaults to an empty string.
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: The layer-wise depth and the num of layers.
|
||||
|
||||
Note:
|
||||
The first depth is the stem module (``layer_depth=0``), and the
|
||||
last depth is the subsequent module (``layer_depth=num_layers-1``)
|
||||
"""
|
||||
self.num_layers = len(self.blocks)
|
||||
num_layers = self.num_layers + 2
|
||||
|
||||
if not param_name.startswith(prefix):
|
||||
# For subsequent module like head
|
||||
return num_layers - 1, num_layers
|
||||
|
||||
param_name = param_name[len(prefix):]
|
||||
|
||||
if param_name in 'pos_embed':
|
||||
layer_depth = 0
|
||||
elif param_name.startswith('patch_embed'):
|
||||
layer_depth = 0
|
||||
elif param_name.startswith('layers'):
|
||||
layer_id = int(param_name.split('.')[1])
|
||||
layer_depth = layer_id + 1
|
||||
else:
|
||||
layer_depth = num_layers - 1
|
||||
|
||||
return layer_depth, num_layers
|
|
@ -10,6 +10,7 @@ from .efficientformer_head import EfficientFormerClsHead
|
|||
from .grounding_head import GroundingHead
|
||||
from .itc_head import ITCHead
|
||||
from .itm_head import ITMHead
|
||||
from .itpn_clip_head import iTPNClipHead
|
||||
from .latent_heads import LatentCrossCorrelationHead, LatentPredictHead
|
||||
from .levit_head import LeViTClsHead
|
||||
from .linear_head import LinearClsHead
|
||||
|
@ -62,4 +63,5 @@ __all__ = [
|
|||
'ITCHead',
|
||||
'ITMHead',
|
||||
'GroundingHead',
|
||||
'iTPNClipHead',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmpretrain.registry import MODELS
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class iTPNClipHead(BaseModule):
|
||||
"""Head for iTPN Pre-training using Clip.
|
||||
|
||||
Compute the logits and the cross entropy loss.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The dimension of embedding.
|
||||
num_embed (int): The number of classification types.
|
||||
loss (dict): The config of loss.
|
||||
init_cfg (dict or List[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dims: int,
|
||||
num_embed: int,
|
||||
loss: dict,
|
||||
init_cfg: Optional[Union[dict, List[dict]]] = dict(
|
||||
type='TruncNormal', layer='Linear', std=0.02, bias=0)
|
||||
) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.cls_head = nn.Linear(embed_dims, num_embed)
|
||||
self.loss_module = MODELS.build(loss)
|
||||
|
||||
def loss(self, feats: torch.Tensor, target: torch.Tensor,
|
||||
mask: torch.Tensor) -> torch.Tensor:
|
||||
"""Generate loss.
|
||||
|
||||
Args:
|
||||
feats (torch.Tensor): Features from backbone.
|
||||
target (torch.Tensor): Target generated by target_generator.
|
||||
mask (torch.Tensor): Generated mask for pretraing.
|
||||
"""
|
||||
|
||||
mask = mask.to(torch.device('cuda'), non_blocking=True)
|
||||
mask = mask.flatten(1).to(torch.bool)
|
||||
target = target[mask]
|
||||
|
||||
# remove cls_token
|
||||
# feats = feats[:, 1:]
|
||||
logits = self.cls_head(feats[mask])
|
||||
|
||||
loss = self.loss_module(logits, target)
|
||||
return loss
|
|
@ -5,6 +5,7 @@ from .densecl_neck import DenseCLNeck
|
|||
from .gap import GlobalAveragePooling
|
||||
from .gem import GeneralizedMeanPooling
|
||||
from .hr_fuse import HRFuseScales
|
||||
from .itpn_neck import iTPNPretrainDecoder
|
||||
from .linear_neck import LinearNeck
|
||||
from .mae_neck import ClsBatchNormNeck, MAEPretrainDecoder
|
||||
from .milan_neck import MILANPretrainDecoder
|
||||
|
@ -30,4 +31,5 @@ __all__ = [
|
|||
'NonLinearNeck',
|
||||
'SimMIMLinearDecoder',
|
||||
'SwAVNeck',
|
||||
'iTPNPretrainDecoder',
|
||||
]
|
||||
|
|
|
@ -0,0 +1,388 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmpretrain.models.backbones.hivit import BlockWithRPE
|
||||
from mmpretrain.registry import MODELS
|
||||
from ..backbones.vision_transformer import TransformerEncoderLayer
|
||||
from ..utils import build_2d_sincos_position_embedding
|
||||
|
||||
|
||||
class PatchSplit(nn.Module):
|
||||
"""The up-sample module used in neck (transformer pyramid network)
|
||||
|
||||
Args:
|
||||
dim (int): the input dimension (channel number).
|
||||
fpn_dim (int): the fpn dimension (channel number).
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, fpn_dim, norm_cfg):
|
||||
super().__init__()
|
||||
_, self.norm = build_norm_layer(norm_cfg, dim)
|
||||
self.reduction = nn.Linear(dim, fpn_dim * 4, bias=False)
|
||||
self.fpn_dim = fpn_dim
|
||||
|
||||
def forward(self, x):
|
||||
B, N, H, W, C = x.shape
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
x = x.reshape(B, N, H, W, 2, 2,
|
||||
self.fpn_dim).permute(0, 1, 2, 4, 3, 5,
|
||||
6).reshape(B, N, 2 * H, 2 * W,
|
||||
self.fpn_dim)
|
||||
return x
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class iTPNPretrainDecoder(BaseModule):
|
||||
"""The neck module of iTPN (transformer pyramid network).
|
||||
|
||||
Args:
|
||||
num_patches (int): The number of total patches. Defaults to 196.
|
||||
patch_size (int): Image patch size. Defaults to 16.
|
||||
in_chans (int): The channel of input image. Defaults to 3.
|
||||
embed_dim (int): Encoder's embedding dimension. Defaults to 512.
|
||||
fpn_dim (int): The fpn dimension (channel number).
|
||||
fpn_depth (int): The layer number of feature pyramid.
|
||||
decoder_embed_dim (int): Decoder's embedding dimension.
|
||||
Defaults to 512.
|
||||
decoder_depth (int): The depth of decoder. Defaults to 8.
|
||||
decoder_num_heads (int): Number of attention heads of decoder.
|
||||
Defaults to 16.
|
||||
mlp_ratio (int): Ratio of mlp hidden dim to decoder's embedding dim.
|
||||
Defaults to 4.
|
||||
norm_cfg (dict): Normalization layer. Defaults to LayerNorm.
|
||||
reconstruction_type (str): The itpn supports 2 kinds of supervisions.
|
||||
Defaults to 'pixel'.
|
||||
num_outs (int): The output number of neck (transformer pyramid
|
||||
network). Defaults to 3.
|
||||
predict_feature_dim (int): The output dimension to supervision.
|
||||
Defaults to None.
|
||||
init_cfg (Union[List[dict], dict], optional): Initialization config
|
||||
dict. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_patches: int = 196,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 512,
|
||||
fpn_dim: int = 256,
|
||||
fpn_depth: int = 2,
|
||||
decoder_embed_dim: int = 512,
|
||||
decoder_depth: int = 6,
|
||||
decoder_num_heads: int = 16,
|
||||
mlp_ratio: int = 4,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
reconstruction_type: str = 'pixel',
|
||||
num_outs: int = 3,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: Optional[bool] = None,
|
||||
drop_rate: float = 0.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
predict_feature_dim: Optional[float] = None,
|
||||
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.num_patches = num_patches
|
||||
assert reconstruction_type in ['pixel', 'clip'], \
|
||||
'iTPN method only support `pixel` and `clip`, ' \
|
||||
f'but got `{reconstruction_type}`.'
|
||||
self.reconstruction_type = reconstruction_type
|
||||
self.num_outs = num_outs
|
||||
|
||||
self.build_transformer_pyramid(
|
||||
num_outs=num_outs,
|
||||
embed_dim=embed_dim,
|
||||
fpn_dim=fpn_dim,
|
||||
fpn_depth=fpn_depth,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
rpe=False,
|
||||
norm_cfg=norm_cfg,
|
||||
)
|
||||
|
||||
# merge the output
|
||||
self.decoder_embed = nn.ModuleList()
|
||||
self.decoder_embed.append(
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(fpn_dim),
|
||||
nn.Linear(fpn_dim, decoder_embed_dim, bias=True),
|
||||
))
|
||||
|
||||
if self.num_outs >= 2:
|
||||
self.decoder_embed.append(
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(fpn_dim),
|
||||
nn.Linear(fpn_dim, decoder_embed_dim // 4, bias=True),
|
||||
))
|
||||
if self.num_outs >= 3:
|
||||
self.decoder_embed.append(
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(fpn_dim),
|
||||
nn.Linear(fpn_dim, decoder_embed_dim // 16, bias=True),
|
||||
))
|
||||
|
||||
if reconstruction_type == 'pixel':
|
||||
self.mask_token = nn.Parameter(
|
||||
torch.zeros(1, 1, decoder_embed_dim))
|
||||
|
||||
# create new position embedding, different from that in encoder
|
||||
# and is not learnable
|
||||
self.decoder_pos_embed = nn.Parameter(
|
||||
torch.zeros(1, self.num_patches, decoder_embed_dim),
|
||||
requires_grad=False)
|
||||
|
||||
self.decoder_blocks = nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
decoder_embed_dim,
|
||||
decoder_num_heads,
|
||||
int(mlp_ratio * decoder_embed_dim),
|
||||
qkv_bias=True,
|
||||
norm_cfg=norm_cfg) for _ in range(decoder_depth)
|
||||
])
|
||||
|
||||
self.decoder_norm_name, decoder_norm = build_norm_layer(
|
||||
norm_cfg, decoder_embed_dim, postfix=1)
|
||||
self.add_module(self.decoder_norm_name, decoder_norm)
|
||||
|
||||
# Used to map features to pixels
|
||||
if predict_feature_dim is None:
|
||||
predict_feature_dim = patch_size**2 * in_chans
|
||||
self.decoder_pred = nn.Linear(
|
||||
decoder_embed_dim, predict_feature_dim, bias=True)
|
||||
else:
|
||||
_, norm = build_norm_layer(norm_cfg, embed_dim)
|
||||
self.add_module('norm', norm)
|
||||
|
||||
def build_transformer_pyramid(self,
|
||||
num_outs=3,
|
||||
embed_dim=512,
|
||||
fpn_dim=256,
|
||||
fpn_depth=2,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
rpe=False,
|
||||
norm_cfg=None):
|
||||
Hp = None
|
||||
mlvl_dims = {'4': embed_dim // 4, '8': embed_dim // 2, '16': embed_dim}
|
||||
if num_outs > 1:
|
||||
if embed_dim != fpn_dim:
|
||||
self.align_dim_16tofpn = nn.Linear(embed_dim, fpn_dim)
|
||||
else:
|
||||
self.align_dim_16tofpn = None
|
||||
self.fpn_modules = nn.ModuleList()
|
||||
self.fpn_modules.append(
|
||||
BlockWithRPE(
|
||||
Hp,
|
||||
fpn_dim,
|
||||
0,
|
||||
mlp_ratio,
|
||||
qkv_bias,
|
||||
qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=0.,
|
||||
rpe=rpe,
|
||||
norm_cfg=norm_cfg))
|
||||
self.fpn_modules.append(
|
||||
BlockWithRPE(
|
||||
Hp,
|
||||
fpn_dim,
|
||||
0,
|
||||
mlp_ratio,
|
||||
qkv_bias,
|
||||
qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=0.,
|
||||
rpe=False,
|
||||
norm_cfg=norm_cfg,
|
||||
))
|
||||
|
||||
self.align_dim_16to8 = nn.Linear(
|
||||
mlvl_dims['8'], fpn_dim, bias=False)
|
||||
self.split_16to8 = PatchSplit(mlvl_dims['16'], fpn_dim, norm_cfg)
|
||||
self.block_16to8 = nn.Sequential(*[
|
||||
BlockWithRPE(
|
||||
Hp,
|
||||
fpn_dim,
|
||||
0,
|
||||
mlp_ratio,
|
||||
qkv_bias,
|
||||
qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=0.,
|
||||
rpe=rpe,
|
||||
norm_cfg=norm_cfg,
|
||||
) for _ in range(fpn_depth)
|
||||
])
|
||||
|
||||
if num_outs > 2:
|
||||
self.align_dim_8to4 = nn.Linear(
|
||||
mlvl_dims['4'], fpn_dim, bias=False)
|
||||
self.split_8to4 = PatchSplit(fpn_dim, fpn_dim, norm_cfg)
|
||||
self.block_8to4 = nn.Sequential(*[
|
||||
BlockWithRPE(
|
||||
Hp,
|
||||
fpn_dim,
|
||||
0,
|
||||
mlp_ratio,
|
||||
qkv_bias,
|
||||
qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=0.,
|
||||
rpe=rpe,
|
||||
norm_cfg=norm_cfg,
|
||||
) for _ in range(fpn_depth)
|
||||
])
|
||||
self.fpn_modules.append(
|
||||
BlockWithRPE(
|
||||
Hp,
|
||||
fpn_dim,
|
||||
0,
|
||||
mlp_ratio,
|
||||
qkv_bias,
|
||||
qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=0.,
|
||||
rpe=rpe,
|
||||
norm_cfg=norm_cfg))
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding and mask token of MAE decoder."""
|
||||
super().init_weights()
|
||||
|
||||
if self.reconstruction_type == 'pixel':
|
||||
decoder_pos_embed = build_2d_sincos_position_embedding(
|
||||
int(self.num_patches**.5),
|
||||
self.decoder_pos_embed.shape[-1],
|
||||
cls_token=False)
|
||||
self.decoder_pos_embed.data.copy_(decoder_pos_embed.float())
|
||||
|
||||
torch.nn.init.normal_(self.mask_token, std=.02)
|
||||
else:
|
||||
self.rescale_init_weight()
|
||||
|
||||
def rescale_init_weight(self) -> None:
|
||||
"""Rescale the initialized weights."""
|
||||
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.fpn_modules):
|
||||
if isinstance(layer, BlockWithRPE):
|
||||
if layer.attn is not None:
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||
|
||||
@property
|
||||
def decoder_norm(self):
|
||||
"""The normalization layer of decoder."""
|
||||
return getattr(self, self.decoder_norm_name)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
ids_restore: torch.Tensor = None) -> torch.Tensor:
|
||||
"""The forward function.
|
||||
|
||||
The process computes the visible patches' features vectors and the mask
|
||||
tokens to output feature vectors, which will be used for
|
||||
reconstruction.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): hidden features, which is of shape
|
||||
B x (L * mask_ratio) x C.
|
||||
ids_restore (torch.Tensor): ids to restore original image.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The reconstructed feature vectors, which is of
|
||||
shape B x (num_patches) x C.
|
||||
"""
|
||||
|
||||
features = x[:2]
|
||||
x = x[-1]
|
||||
B, L, _ = x.shape
|
||||
x = x[..., None, None, :]
|
||||
Hp = Wp = math.sqrt(L)
|
||||
|
||||
outs = [x] if self.align_dim_16tofpn is None else [
|
||||
self.align_dim_16tofpn(x)
|
||||
]
|
||||
if self.num_outs >= 2:
|
||||
x = self.block_16to8(
|
||||
self.split_16to8(x) + self.align_dim_16to8(features[1]))
|
||||
outs.append(x)
|
||||
if self.num_outs >= 3:
|
||||
x = self.block_8to4(
|
||||
self.split_8to4(x) + self.align_dim_8to4(features[0]))
|
||||
outs.append(x)
|
||||
if self.num_outs > 3:
|
||||
outs = [
|
||||
out.reshape(B, Hp, Wp, *out.shape[-3:]).permute(
|
||||
0, 5, 1, 3, 2, 4).reshape(B, -1, Hp * out.shape[-3],
|
||||
Wp * out.shape[-2]).contiguous()
|
||||
for out in outs
|
||||
]
|
||||
if self.num_outs >= 4:
|
||||
outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
|
||||
if self.num_outs >= 5:
|
||||
outs.insert(0, F.avg_pool2d(outs[0], kernel_size=2, stride=2))
|
||||
|
||||
for i, out in enumerate(outs):
|
||||
out = self.fpn_modules[i](out)
|
||||
outs[i] = out
|
||||
|
||||
if self.reconstruction_type == 'pixel':
|
||||
feats = []
|
||||
for feat, layer in zip(outs, self.decoder_embed):
|
||||
x = layer(feat).reshape(B, L, -1)
|
||||
# append mask tokens to sequence
|
||||
mask_tokens = self.mask_token.repeat(
|
||||
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
|
||||
x = torch.cat([x, mask_tokens], dim=1)
|
||||
x = torch.gather(
|
||||
x,
|
||||
dim=1,
|
||||
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
|
||||
feats.append(x)
|
||||
x = feats.pop(0)
|
||||
# add pos embed
|
||||
x = x + self.decoder_pos_embed
|
||||
|
||||
for i, feat in enumerate(feats):
|
||||
x = x + feats[i]
|
||||
# apply Transformer blocks
|
||||
for i, blk in enumerate(self.decoder_blocks):
|
||||
x = blk(x)
|
||||
x = self.decoder_norm(x)
|
||||
x = self.decoder_pred(x)
|
||||
return x
|
||||
else:
|
||||
feats = []
|
||||
for feat, layer in zip(outs, self.decoder_embed):
|
||||
x = layer(feat).reshape(B, L, -1)
|
||||
feats.append(x)
|
||||
x = feats.pop(0)
|
||||
for i, feat in enumerate(feats):
|
||||
x = x + feats[i]
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x
|
|
@ -6,7 +6,8 @@ from .byol import BYOL
|
|||
from .cae import CAE, CAEPretrainViT, DALLEEncoder
|
||||
from .densecl import DenseCL
|
||||
from .eva import EVA
|
||||
from .mae import MAE, MAEViT
|
||||
from .itpn import iTPN, iTPNHiViT
|
||||
from .mae import MAE, MAEHiViT, MAEViT
|
||||
from .maskfeat import HOGGenerator, MaskFeat, MaskFeatViT
|
||||
from .milan import MILAN, CLIPGenerator, MILANViT
|
||||
from .mixmim import MixMIM, MixMIMPretrainTransformer
|
||||
|
@ -24,6 +25,9 @@ __all__ = [
|
|||
'CAEPretrainViT',
|
||||
'DALLEEncoder',
|
||||
'MAEViT',
|
||||
'MAEHiViT',
|
||||
'iTPNHiViT',
|
||||
'iTPN',
|
||||
'HOGGenerator',
|
||||
'MaskFeatViT',
|
||||
'CLIPGenerator',
|
||||
|
|
|
@ -0,0 +1,356 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmengine.model.weight_init import trunc_normal_
|
||||
|
||||
from mmpretrain.models.backbones.hivit import BlockWithRPE, HiViT, PatchMerge
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import DataSample
|
||||
from ..utils import build_2d_sincos_position_embedding
|
||||
from .base import BaseSelfSupervisor
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class iTPNHiViT(HiViT):
|
||||
"""HiViT for iTPN pre-training.
|
||||
|
||||
Args:
|
||||
img_size (int | tuple): Input image size. Defaults to 224.
|
||||
patch_size (int | tuple): The patch size. Defaults to 16.
|
||||
inner_patches (int): Inner patch. Defaults to 4.
|
||||
stem_mlp_ratio (int): Ratio of MLP hidden dim to embedding dim
|
||||
in the first two stages. Defaults to 3.
|
||||
mlp_ratio (int): Ratio of MLP hidden dim to embedding dim in
|
||||
the last stage. Defaults to 4.
|
||||
qkv_bias (bool): Enable bias for qkv projections if True.
|
||||
qk_scale (float): The number of divider after q@k. Default to None.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
attn_drop_rate (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
ape (bool): If True, add absolute position embedding to
|
||||
the patch embedding.
|
||||
rpe (bool): If True, add relative position embedding to
|
||||
the patch embedding.
|
||||
layer_scale_init_value (float): Layer-scale init values. Defaults to 0.
|
||||
mask_ratio (bool): The ratio of total number of patches to be masked.
|
||||
Defaults to 0.75.
|
||||
reconstruction_type (str): The reconstruction of self-supervised
|
||||
learning. Defaults to 'pixel'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arch='base',
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
inner_patches: int = 4,
|
||||
stem_mlp_ratio: int = 3.,
|
||||
mlp_ratio: int = 4.,
|
||||
qkv_bias: bool = True,
|
||||
qk_scale: Optional[bool] = None,
|
||||
drop_rate: float = 0.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
ape: bool = True,
|
||||
rpe: bool = False,
|
||||
layer_scale_init_value: float = 0.0,
|
||||
mask_ratio: float = 0.75,
|
||||
reconstruction_type: str = 'pixel',
|
||||
):
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
inner_patches=inner_patches,
|
||||
stem_mlp_ratio=stem_mlp_ratio,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
ape=ape,
|
||||
rpe=rpe,
|
||||
layer_scale_init_value=layer_scale_init_value)
|
||||
|
||||
self.pos_embed.requires_grad = False
|
||||
self.mask_ratio = mask_ratio
|
||||
|
||||
assert reconstruction_type in ['pixel', 'clip'], \
|
||||
'iTPN method only support `pixel` and `clip`, ' \
|
||||
f'but got `{reconstruction_type}`.'
|
||||
self.reconstruction_type = reconstruction_type
|
||||
self.num_patches = self.patch_embed.num_patches
|
||||
|
||||
if reconstruction_type == 'clip':
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding, patch embedding and cls token."""
|
||||
super().apply(self._init_weights)
|
||||
|
||||
if self.reconstruction_type == 'clip':
|
||||
trunc_normal_(self.mask_token, std=0.02)
|
||||
self.rescale_init_weight()
|
||||
else:
|
||||
pos_embed = build_2d_sincos_position_embedding(
|
||||
int(self.num_patches**.5),
|
||||
self.pos_embed.shape[-1],
|
||||
cls_token=False)
|
||||
self.pos_embed.data.copy_(pos_embed.float())
|
||||
|
||||
w = self.patch_embed.proj.weight.data
|
||||
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
|
||||
def rescale_init_weight(self) -> None:
|
||||
"""Rescale the initialized weights."""
|
||||
|
||||
def rescale(param, layer_id):
|
||||
param.div_(math.sqrt(2.0 * layer_id))
|
||||
|
||||
for layer_id, layer in enumerate(self.blocks):
|
||||
if isinstance(layer, BlockWithRPE):
|
||||
if layer.attn is not None:
|
||||
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
||||
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
||||
|
||||
def masking_id(self, batch_size, mask_ratio):
|
||||
N, L = batch_size, self.pos_embed.size(1)
|
||||
len_keep = int(L * (1 - mask_ratio))
|
||||
|
||||
noise = torch.rand(
|
||||
N, L, device=self.pos_embed.device) # noise in [0, 1]
|
||||
|
||||
# sort noise for each sample
|
||||
ids_shuffle = torch.argsort(
|
||||
noise, dim=1) # ascend: small is keep, large is remove
|
||||
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
||||
|
||||
# keep the first subset
|
||||
ids_keep = ids_shuffle[:, :len_keep]
|
||||
# generate the binary mask: 0 is keep, 1 is remove
|
||||
mask = torch.ones([N, L], device=self.pos_embed.device)
|
||||
mask[:, :ids_keep.size(1)] = 0
|
||||
# unshuffle to get the binary mask
|
||||
mask = torch.gather(mask, dim=1, index=ids_restore)
|
||||
|
||||
return ids_keep, ids_restore, mask
|
||||
|
||||
def forward_pixel(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[bool] = True
|
||||
) -> Tuple[Tuple, torch.Tensor, torch.Tensor]:
|
||||
"""Generate features for masked images.
|
||||
|
||||
The function supports two kind of forward behaviors. If the ``mask`` is
|
||||
``True``, the function will generate mask to masking some patches
|
||||
randomly and get the hidden features for visible patches, which means
|
||||
the function will be executed as masked imagemodeling pre-training;
|
||||
if the ``mask`` is ``None`` or ``False``, the forward function will
|
||||
call ``super().forward()``, which extract features from images without
|
||||
mask.
|
||||
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x C x H x W.
|
||||
mask (bool, optional): To indicate whether the forward function
|
||||
generating ``mask`` or not.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
|
||||
mask and the ids to restore original image.
|
||||
|
||||
- ``x`` (torch.Tensor): hidden features, which is of shape
|
||||
B x (L * mask_ratio) x C.
|
||||
- ``mask`` (torch.Tensor): mask used to mask image.
|
||||
- ``ids_restore`` (torch.Tensor): ids to restore original image.
|
||||
"""
|
||||
if mask is None or False:
|
||||
return super().forward(x)
|
||||
|
||||
else:
|
||||
B, C, H, W = x.shape
|
||||
ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio)
|
||||
|
||||
x = self.patch_embed(x)
|
||||
|
||||
x = torch.gather(
|
||||
x,
|
||||
dim=1,
|
||||
index=ids_keep[:, :, None, None,
|
||||
None].expand(-1, -1, *x.shape[2:]))
|
||||
|
||||
outs = []
|
||||
for blk in self.blocks[:-self.num_main_blocks]:
|
||||
if isinstance(blk, PatchMerge):
|
||||
outs.append(x)
|
||||
x = blk(x)
|
||||
|
||||
x = x[..., 0, 0, :]
|
||||
if self.ape:
|
||||
pos_embed = self.interpolate_pos_encoding(x, H, W)
|
||||
pos_embed = torch.gather(
|
||||
pos_embed.expand(B, -1, -1),
|
||||
dim=1,
|
||||
index=ids_keep[:, :, None].expand(-1, -1,
|
||||
pos_embed.shape[2]),
|
||||
)
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks[-self.num_main_blocks:]:
|
||||
x = blk(x)
|
||||
|
||||
outs.append(x)
|
||||
|
||||
return (tuple(outs), mask, ids_restore)
|
||||
|
||||
def forward_clip(self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[bool] = True) -> Tuple:
|
||||
"""Generate features for masked images.
|
||||
|
||||
The function supports two kind of forward behaviors. If the ``mask`` is
|
||||
``True``, the function will generate mask to masking some patches
|
||||
randomly and get the hidden features for visible patches, which means
|
||||
the function will be executed as masked imagemodeling pre-training;
|
||||
if the ``mask`` is ``None`` or ``False``, the forward function will
|
||||
call ``super().forward()``, which extract features from images without
|
||||
mask.
|
||||
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x C x H x W.
|
||||
mask (bool, optional): To indicate whether the forward function
|
||||
generating ``mask`` or not.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
|
||||
mask and the ids to restore original image.
|
||||
|
||||
- ``x`` (torch.Tensor): hidden features, which is of shape
|
||||
B x (L * mask_ratio) x C.
|
||||
- ``mask`` (torch.Tensor): mask used to mask image.
|
||||
- ``ids_restore`` (torch.Tensor): ids to restore original image.
|
||||
"""
|
||||
if mask is None or False:
|
||||
return super().forward(x)
|
||||
|
||||
else:
|
||||
B, C, H, W = x.shape
|
||||
x = self.patch_embed(x)
|
||||
|
||||
outs = []
|
||||
for blk in self.blocks[:-self.num_main_blocks]:
|
||||
if isinstance(blk, PatchMerge):
|
||||
outs.append(x)
|
||||
x = blk(x)
|
||||
|
||||
x = x[..., 0, 0, :]
|
||||
B, L, _ = x.shape
|
||||
mask_token = self.mask_token.expand(B, L, -1)
|
||||
w = mask.flatten(1).unsqueeze(-1).type_as(mask_token)
|
||||
x = x * (1. - w) + mask_token * w
|
||||
|
||||
if self.ape:
|
||||
pos_embed = self.interpolate_pos_encoding(x, H, W)
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
rpe_index = True if self.rpe else None
|
||||
|
||||
for blk in self.blocks[-self.num_main_blocks:]:
|
||||
x = blk(x, rpe_index)
|
||||
|
||||
outs.append(x)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: Optional[bool] = True) -> Tuple:
|
||||
"""Generate features for masked images.
|
||||
|
||||
The function supports two kind of forward behaviors. If the ``mask`` is
|
||||
``True``, the function will generate mask to masking some patches
|
||||
randomly and get the hidden features for visible patches, which means
|
||||
the function will be executed as masked imagemodeling pre-training;
|
||||
if the ``mask`` is ``None`` or ``False``, the forward function will
|
||||
call ``super().forward()``, which extract features from images without
|
||||
mask.
|
||||
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x C x H x W.
|
||||
mask (bool, optional): To indicate whether the forward function
|
||||
generating ``mask`` or not.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
|
||||
mask and the ids to restore original image.
|
||||
|
||||
- ``x`` (torch.Tensor): hidden features, which is of shape
|
||||
B x (L * mask_ratio) x C.
|
||||
- ``mask`` (torch.Tensor): mask used to mask image.
|
||||
- ``ids_restore`` (torch.Tensor): ids to restore original image.
|
||||
"""
|
||||
|
||||
if self.reconstruction_type == 'pixel':
|
||||
return self.forward_pixel(x, mask)
|
||||
return self.forward_clip(x, mask)
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class iTPN(BaseSelfSupervisor):
|
||||
"""iTPN.
|
||||
|
||||
Implementation of `iTPN: Integrally Pre-Trained Transformer Pyramid
|
||||
Networks <https://arxiv.org/abs/2211.12735>`_.
|
||||
"""
|
||||
|
||||
def extract_feat(self, inputs: torch.Tensor):
|
||||
return self.backbone(inputs, mask=None)
|
||||
|
||||
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample],
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
"""The forward function in training.
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): The input images.
|
||||
data_samples (List[DataSample]): All elements required
|
||||
during the forward function.
|
||||
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
|
||||
if self.backbone.reconstruction_type == 'pixel':
|
||||
latent, mask, ids_restore = self.backbone(inputs)
|
||||
pred = self.neck(latent, ids_restore)
|
||||
|
||||
loss = self.head.loss(pred, inputs, mask)
|
||||
else:
|
||||
mask = torch.stack(
|
||||
[data_sample.mask for data_sample in data_samples])
|
||||
|
||||
img_latent = self.backbone(inputs[0], mask)
|
||||
|
||||
# inputs[1] is the target image
|
||||
with torch.no_grad():
|
||||
target = self.target_generator(inputs[1])[0]
|
||||
target = target.detach()
|
||||
|
||||
# iTPN contains a neck module
|
||||
feats = self.neck(img_latent)
|
||||
loss = self.head.loss(feats, target[:, 1:, :], mask)
|
||||
|
||||
losses = dict(loss=loss)
|
||||
return losses
|
|
@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|||
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import VisionTransformer
|
||||
from mmpretrain.models import HiViT, VisionTransformer
|
||||
from mmpretrain.registry import MODELS
|
||||
from mmpretrain.structures import DataSample
|
||||
from ..utils import build_2d_sincos_position_embedding
|
||||
|
@ -234,3 +234,183 @@ class MAE(BaseSelfSupervisor):
|
|||
loss = self.head.loss(pred, inputs, mask)
|
||||
losses = dict(loss=loss)
|
||||
return losses
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class MAEHiViT(HiViT):
|
||||
"""HiViT for MAE pre-training.
|
||||
|
||||
A PyTorch implement of: `HiViT: A Simple and More Efficient Design
|
||||
of Hierarchical Vision Transformer <https://arxiv.org/abs/2205.14949>`_.
|
||||
This module implements the patch masking in MAE and initialize the
|
||||
position embedding with sine-cosine position embedding.
|
||||
|
||||
Args:
|
||||
arch (str | dict): Vision Transformer architecture
|
||||
Default: 'b'
|
||||
img_size (int | tuple): Input image size
|
||||
patch_size (int | tuple): The patch size
|
||||
Defaults to 4, to downsample 4x at the first stage
|
||||
inner_patches (int): The inner patches within a token
|
||||
Defaults to 4
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
ape (bool): the absolute position embedding
|
||||
rpe (bool): the relative position embedding
|
||||
Defaults to False
|
||||
layer_scale_init_value (float): the layer scale init value
|
||||
mask_ratio (bool): The ratio of total number of patches to be masked.
|
||||
Defaults to 0.75.
|
||||
init_cfg (Union[List[dict], dict], optional): Initialization config
|
||||
dict. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch: Union[str, dict] = 'b',
|
||||
img_size: int = 224,
|
||||
patch_size: int = 16,
|
||||
inner_patches: int = 4,
|
||||
out_indices: Union[list, int] = [23],
|
||||
drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.0,
|
||||
norm_cfg: dict = dict(type='LN', eps=1e-6),
|
||||
ape: bool = True,
|
||||
rpe: bool = False,
|
||||
layer_scale_init_value: float = 0.0,
|
||||
mask_ratio: float = 0.75,
|
||||
init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
|
||||
super().__init__(
|
||||
arch=arch,
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
inner_patches=inner_patches,
|
||||
out_indices=out_indices,
|
||||
drop_rate=drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_cfg=norm_cfg,
|
||||
ape=ape,
|
||||
rpe=rpe,
|
||||
layer_scale_init_value=layer_scale_init_value,
|
||||
init_cfg=init_cfg)
|
||||
|
||||
self.pos_embed.requires_grad = False
|
||||
self.mask_ratio = mask_ratio
|
||||
self.num_patches = self.patch_embed.num_patches
|
||||
|
||||
def init_weights(self) -> None:
|
||||
"""Initialize position embedding, patch embedding."""
|
||||
super().apply(self._init_weights)
|
||||
pos_embed = build_2d_sincos_position_embedding(
|
||||
int(self.num_patches**.5),
|
||||
self.pos_embed.shape[-1],
|
||||
cls_token=False)
|
||||
self.pos_embed.data.copy_(pos_embed.float())
|
||||
|
||||
w = self.patch_embed.proj.weight.data
|
||||
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
|
||||
def masking_id(
|
||||
self, batch_size,
|
||||
mask_ratio) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate the mask for MAE Pre-training.
|
||||
|
||||
Args:
|
||||
batch_size: The batch size of input data
|
||||
mask_ratio: The mask ratio of total patches.
|
||||
Defaults to 0.75.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: the ids
|
||||
for the tokens retained, the ids to restore original image,
|
||||
and the mask
|
||||
"""
|
||||
N, L = batch_size, self.pos_embed.size(1)
|
||||
len_keep = int(L * (1 - mask_ratio))
|
||||
|
||||
noise = torch.rand(
|
||||
N, L, device=self.pos_embed.device) # noise in [0, 1]
|
||||
|
||||
# sort noise for each sample
|
||||
ids_shuffle = torch.argsort(
|
||||
noise, dim=1) # ascend: small is keep, large is remove
|
||||
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
||||
|
||||
# keep the first subset
|
||||
ids_keep = ids_shuffle[:, :len_keep]
|
||||
# generate the binary mask: 0 is keep, 1 is remove
|
||||
mask = torch.ones([N, L], device=self.pos_embed.device)
|
||||
mask[:, :ids_keep.size(1)] = 0
|
||||
# unshuffle to get the binary mask
|
||||
mask = torch.gather(mask, dim=1, index=ids_restore)
|
||||
|
||||
return ids_keep, ids_restore, mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[bool] = True
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate features for masked images.
|
||||
|
||||
The function supports two kind of forward behaviors. If the ``mask`` is
|
||||
``True``, the function will generate mask to masking some patches
|
||||
randomly and get the hidden features for visible patches, which means
|
||||
the function will be executed as masked imagemodeling pre-training;
|
||||
if the ``mask`` is ``None`` or ``False``, the forward function will
|
||||
call ``super().forward()``, which extract features from images without
|
||||
mask.
|
||||
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input images, which is of shape B x C x H x W.
|
||||
mask (bool, optional): To indicate whether the forward function
|
||||
generating ``mask`` or not.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Hidden features,
|
||||
mask and the ids to restore original image.
|
||||
|
||||
- ``x`` (torch.Tensor): hidden features, which is of shape
|
||||
B x (L * mask_ratio) x C.
|
||||
- ``mask`` (torch.Tensor): mask used to mask image.
|
||||
- ``ids_restore`` (torch.Tensor): ids to restore original image.
|
||||
"""
|
||||
if mask is None or False:
|
||||
return super().forward(x)
|
||||
|
||||
else:
|
||||
B, C, H, W = x.shape
|
||||
ids_keep, ids_restore, mask = self.masking_id(B, self.mask_ratio)
|
||||
|
||||
x = self.patch_embed(x)
|
||||
|
||||
x = torch.gather(
|
||||
x,
|
||||
dim=1,
|
||||
index=ids_keep[:, :, None, None,
|
||||
None].expand(-1, -1, *x.shape[2:]))
|
||||
|
||||
for blk in self.blocks[:-self.num_main_blocks]:
|
||||
x = blk(x)
|
||||
|
||||
x = x[..., 0, 0, :]
|
||||
if self.ape:
|
||||
pos_embed = self.interpolate_pos_encoding(x, H, W)
|
||||
pos_embed = torch.gather(
|
||||
pos_embed.expand(B, -1, -1),
|
||||
dim=1,
|
||||
index=ids_keep[:, :, None].expand(-1, -1,
|
||||
pos_embed.shape[2]),
|
||||
)
|
||||
x = x + pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks[-self.num_main_blocks:]:
|
||||
x = blk(x)
|
||||
|
||||
return (x, mask, ids_restore)
|
||||
|
|
|
@ -76,3 +76,5 @@ Import:
|
|||
- configs/flamingo/metafile.yml
|
||||
- configs/blip2/metafile.yml
|
||||
- configs/chinese_clip/metafile.yml
|
||||
- configs/itpn/metafile.yml
|
||||
- configs/hivit/metafile.yml
|
||||
|
|
|
@ -34,6 +34,7 @@ test_list = [
|
|||
backbone=mmpretrain.models.VisionTransformer,
|
||||
forward=False,
|
||||
backward=False),
|
||||
Cfg(name='hivit-tiny-p16_16xb64_in1k', backbone=mmpretrain.models.HiViT),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import platform
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmpretrain.models import iTPN
|
||||
from mmpretrain.structures import DataSample
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
|
||||
def test_itpn():
|
||||
data_preprocessor = {
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'to_rgb': True
|
||||
}
|
||||
backbone = dict(
|
||||
type='iTPNHiViT',
|
||||
arch='base',
|
||||
reconstruction_type='pixel',
|
||||
mask_ratio=0.75)
|
||||
neck = dict(
|
||||
type='iTPNPretrainDecoder',
|
||||
num_patches=196,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=512,
|
||||
decoder_embed_dim=512,
|
||||
decoder_depth=6,
|
||||
decoder_num_heads=16,
|
||||
mlp_ratio=4.,
|
||||
reconstruction_type='pixel',
|
||||
# transformer pyramid
|
||||
fpn_dim=256,
|
||||
fpn_depth=2,
|
||||
num_outs=3,
|
||||
)
|
||||
head = dict(
|
||||
type='MAEPretrainHead',
|
||||
norm_pix=True,
|
||||
patch_size=16,
|
||||
loss=dict(type='PixelReconstructionLoss', criterion='L2'))
|
||||
|
||||
alg = iTPN(
|
||||
backbone=backbone,
|
||||
neck=neck,
|
||||
head=head,
|
||||
data_preprocessor=data_preprocessor)
|
||||
|
||||
fake_data = {
|
||||
'inputs': torch.randn((2, 3, 224, 224)),
|
||||
'data_samples': [DataSample() for _ in range(2)]
|
||||
}
|
||||
fake_inputs = alg.data_preprocessor(fake_data)
|
||||
fake_outputs = alg(**fake_inputs, mode='loss')
|
||||
assert isinstance(fake_outputs['loss'].item(), float)
|
Loading…
Reference in New Issue