[Feature] Add Maskfeat Support (#485)

* [Feature]: Add MaskfeatMaskGenerator Pipeline

* [Feature]: Add HogLayerC for MaskFeat

* [Feature]: Add Backbone of MaskFeat

* [Feature]: Add Head of MaskFeat

* [Feature]: Add Algorithms of MaskFeat

* [Feature]: Add Config of MaskFeat

* [Doc] Update Readme of MaskFeat

* [Fix] fix ut and hog_layer.

* [fix] Add and correct docstring

* [Fix] Refine the docstring of MaskFeat

* [fix] fix value of trunc_normal_

* [fix] rename the finetune config of maskfeat

* [fix] rename the fine-tuning config of maskfeat

* [fix] rename the fine-tuning config of maskfeat

* [fix] add new paramwise_options in fine-tuning config

* [fix] update the top-1 accuracy of maskfeat

* [fix] update the top-1 accuracy of maskfeat in model_zoo

* [fix] rename MaskfeatMaskGenerator
lkylkylky 2022-09-24 17:48:36 +08:00 committed by Yixiao Fang
parent 6732025f48
commit 9e015762d1
23 changed files with 927 additions and 19 deletions

View File

@ -0,0 +1,76 @@
_base_ = [
'../_base_/models/vit-base-p16_ft.py',
'../_base_/datasets/imagenet.py',
'../_base_/schedules/adamw_coslr-100e_in1k.py',
'../_base_/default_runtime.py',
]
# maskfeat fine-tuning setting
# dataset
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(
type='RandomAug',
input_size=224,
color_jitter=0.4,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225))
]
test_pipeline = [
dict(type='Resize', size=256, interpolation=3),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg)
]
data = dict(
samples_per_gpu=256,
drop_last=False,
workers_per_gpu=32,
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline))
# model
model = dict(
backbone=dict(init_cfg=dict()),
head=dict(
type='MaskFeatFinetuneHead',
num_classes=1000,
embed_dim=768,
label_smooth_val=0.1))
# optimizer
optimizer = dict(
lr=0.002 * 8 / 2,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_options={
'ln': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
'pos_embed': dict(weight_decay=0.),
'cls_token': dict(weight_decay=0.),
},
constructor='TransformerFinetuneConstructor',
model_type='vit',
layer_decay=0.65)
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-6,
warmup='linear',
warmup_iters=20,
warmup_ratio=1e-08,
warmup_by_epoch=True)
# runtime
checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
persistent_workers = True
log_config = dict(
interval=100, hooks=[
dict(type='TextLoggerHook'),
])
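The fine-tuning config above delegates parameter grouping to `TransformerFinetuneConstructor` with `layer_decay=0.65`. As a rough, hedged sketch of the usual layer-wise decay rule (the constructor's exact grouping of embeddings and head may differ), earlier transformer blocks get exponentially smaller learning rates:

```python
# Hedged sketch of layer-wise lr decay; illustrative only, not the exact
# grouping used by TransformerFinetuneConstructor.
base_lr = 0.002 * 8 / 2   # base lr from the optimizer config above
layer_decay = 0.65
num_layers = 12           # depth of ViT-B/16

# layer_id 0 covers patch/pos embeddings, num_layers + 1 covers the head.
for layer_id in range(num_layers + 2):
    scale = layer_decay ** (num_layers + 1 - layer_id)
    print(f'group {layer_id:2d}: lr = {base_lr * scale:.2e}')
```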

View File

@ -0,0 +1,35 @@
# dataset settings
data_source = 'ImageNet'
dataset_type = 'SingleViewDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
scale=(0.5, 1.0),
ratio=(0.75, 1.3333),
interpolation='bicubic'),
dict(type='RandomHorizontalFlip')
]
# prefetch
prefetch = False
if not prefetch:
train_pipeline.extend(
[dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg)])
train_pipeline.append(dict(type='MaskFeatMaskGenerator', mask_ratio=0.4))
# dataset summary
data = dict(
samples_per_gpu=256,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/train',
ann_file='data/imagenet/meta/train.txt'),
pipeline=train_pipeline,
prefetch=prefetch))

View File

@ -0,0 +1,15 @@
# model settings
model = dict(
type='MaskFeat',
backbone=dict(
type='MaskFeatViT',
arch='b',
patch_size=16,
drop_path_rate=0,
),
head=dict(type='MaskFeatPretrainHead', hog_dim=108),
hog_para=dict(
        nbins=9,  # Number of orientation bins. Defaults to 9.
        pool=8,  # Size of each pooling cell. Defaults to 8.
        gaussian_window=16  # Size of the Gaussian kernel. Defaults to 16.
))
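As a hedged sanity check, the default HOG settings above determine the `hog_dim=108` used by `MaskFeatPretrainHead`; the variable names below are illustrative and assume 224x224 inputs with the 14x14 ViT-B/16 token grid:

```python
# Illustrative arithmetic: hog_dim = channels * nbins * cells_per_token**2.
img_size, tokens_per_side = 224, 14       # assumed ViT-B/16 geometry
nbins, pool, channels = 9, 8, 3           # values from the config above

cells_per_side = img_size // pool                     # 28 HOG cells per side
cells_per_token = cells_per_side // tokens_per_side   # 2x2 cells per patch token
hog_dim = channels * nbins * cells_per_token ** 2
print(hog_dim)  # 108
```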

View File

@ -0,0 +1,34 @@
# MaskFeat
> [Masked Feature Prediction for Self-Supervised Visual Pre-Training](https://arxiv.org/abs/2112.09133v1)
<!-- [ALGORITHM] -->
## Abstract
We present Masked Feature Prediction (MaskFeat) for self-supervised pre-training of video models. Our approach first randomly masks out a portion of the input sequence and then predicts the feature of the masked regions. We study five different types of features and find Histograms of Oriented Gradients (HOG), a hand-crafted feature descriptor, works particularly well in terms of both performance and efficiency. We observe that the local contrast normalization in HOG is essential for good results, which is in line with earlier work using HOG for visual recognition. Our approach can learn abundant visual knowledge and drive large-scale Transformer-based models. Without using extra model weights or supervision, MaskFeat pre-trained on unlabeled videos achieves unprecedented results of 86.7% with MViT-L on Kinetics-400, 88.3% on Kinetics-600, 80.4% on Kinetics-700, 38.8 mAP on AVA, and 75.0% on SSv2. MaskFeat further generalizes to image input, which can be interpreted as a video with a single frame and obtains competitive results on ImageNet.
<div align="center">
<img src="https://user-images.githubusercontent.com/48178838/190090285-428f07c0-0887-4ce8-b94f-f719cfd25622.png" width="60%"/>
</div>
## Models and Benchmarks
Here, we report the results of the model, which is pre-trained on ImageNet-1k
for 300 epochs. The details are below:
| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download |
| :------: | :-------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| ViT-B/16 | 300 | 83.5 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth) \| [log](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220829_225552.log.json) |
## Citation
```bibtex
@article{wei2021masked,
  title={Masked Feature Prediction for Self-Supervised Visual Pre-Training},
  author={Wei, Chen and Fan, Haoqi and Xie, Saining and Wu, Chao-Yuan and
          Yuille, Alan and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2112.09133},
  year={2021}
}
```

View File

@ -0,0 +1,40 @@
_base_ = [
'../_base_/models/maskfeat_vit-base-p16.py',
'../_base_/datasets/imagenet_maskfeat.py',
'../_base_/schedules/adamw_coslr-300e_in1k.py',
'../_base_/default_runtime.py',
]
# dataset
data = dict(samples_per_gpu=256, workers_per_gpu=32)
# optimizer
optimizer = dict(
lr=2e-4 * 8,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_options={
'ln': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
})
optimizer_config = dict(grad_clip=dict(max_norm=0.02))
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-6,
warmup='linear',
warmup_iters=30,
warmup_ratio=1e-06,
warmup_by_epoch=True)
# schedule
runner = dict(max_epochs=300)
# runtime
checkpoint_config = dict(interval=1, max_keep_ckpts=3, out_dir='')
persistent_workers = True
log_config = dict(
interval=100, hooks=[
dict(type='TextLoggerHook'),
])
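For reference, a small sketch of the batch and learning-rate arithmetic implied by this config; treating 2e-4 per 256 samples as the assumed base rate reproduces the `lr=2e-4 * 8` above and the 2048 batch size listed in the metafile:

```python
# Illustrative only: effective batch size and linearly scaled lr.
gpus, samples_per_gpu = 8, 256
effective_batch = gpus * samples_per_gpu   # 2048
lr = 2e-4 * effective_batch / 256          # 1.6e-3, i.e. 2e-4 * 8
print(effective_batch, lr)
```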

View File

@ -0,0 +1,27 @@
Collections:
- Name: MaskFeat
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- AdamW
Training Resources: 8x A100-80G GPUs
Architecture:
- ViT
Paper:
URL: https://arxiv.org/abs/2112.09133v1
Title: "Masked Feature Prediction for Self-Supervised Visual Pre-Training"
README: configs/selfsup/maskfeat/README.md
Models:
- Name: maskfeat_vit-base-p16_8xb256-coslr-300e_in1k
In Collection: MaskFeat
Metadata:
Epochs: 300
Batch Size: 2048
Results:
- Task: Self-Supervised Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.5
Config: configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth

View File

@ -0,0 +1,34 @@
# MaskFeat
> [Masked Feature Prediction for Self-Supervised Visual Pre-Training](https://arxiv.org/abs/2112.09133v1)
<!-- [ALGORITHM] -->
## Abstract
We present Masked Feature Prediction (MaskFeat) for self-supervised pre-training of video models. Our approach first randomly masks out a portion of the input sequence and then predicts the feature of the masked regions. We study five different types of features and find Histograms of Oriented Gradients (HOG), a hand-crafted feature descriptor, works particularly well in terms of both performance and efficiency. We observe that the local contrast normalization in HOG is essential for good results, which is in line with earlier work using HOG for visual recognition. Our approach can learn abundant visual knowledge and drive large-scale Transformer-based models. Without using extra model weights or supervision, MaskFeat pre-trained on unlabeled videos achieves unprecedented results of 86.7% with MViT-L on Kinetics-400, 88.3% on Kinetics-600, 80.4% on Kinetics-700, 38.8 mAP on AVA, and 75.0% on SSv2. MaskFeat further generalizes to image input, which can be interpreted as a video with a single frame and obtains competitive results on ImageNet.
<div align="center">
<img src="https://user-images.githubusercontent.com/48178838/190090285-428f07c0-0887-4ce8-b94f-f719cfd25622.png" width="60%"/>
</div>
## Models and Benchmarks
Here, we report the results of the model, which is pre-trained on ImageNet-1k
for 300 epochs. The details are below:
| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download |
| :------: | :-------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| ViT-B/16 | 300 | 83.5 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth) \| [log](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220829_225552.log.json) |
## Citation
```bibtex
@article{wei2021masked,
  title={Masked Feature Prediction for Self-Supervised Visual Pre-Training},
  author={Wei, Chen and Fan, Haoqi and Xie, Saining and Wu, Chao-Yuan and
          Yuille, Alan and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2112.09133},
  year={2021}
}
```

View File

@ -25,7 +25,8 @@ All models and part of benchmark results are recorded below.
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) \| [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
| [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) \| [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) |
| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/RAEDME.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |
| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/README.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |
| [MaskFeat](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/README.md) | [maskfeat_vit-base-p16_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth) \| [log](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220829_225552.log.json) |
Remarks:
@ -63,11 +64,12 @@ If not specified, we use linear evaluation setting from [MoCo](http://openaccess
### ImageNet Fine-tuning
| Algorithm | Config | Remarks | Top-1 (%) |
| --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 |
| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 |
| Algorithm | Config | Remarks | Top-1 (%) |
| --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 |
| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 |
| MaskFeat  | [maskfeat_vit-base-p16_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py)                            |         | 83.5      |
### COCO17 Object Detection and Instance Segmentation

View File

@ -0,0 +1,34 @@
# MaskFeat
> [Masked Feature Prediction for Self-Supervised Visual Pre-Training](https://arxiv.org/abs/2112.09133v1)
<!-- [ALGORITHM] -->
## Abstract
We present Masked Feature Prediction (MaskFeat) for self-supervised pre-training of video models. Our approach first randomly masks out a portion of the input sequence and then predicts the feature of the masked regions. We study five different types of features and find Histograms of Oriented Gradients (HOG), a hand-crafted feature descriptor, works particularly well in terms of both performance and efficiency. We observe that the local contrast normalization in HOG is essential for good results, which is in line with earlier work using HOG for visual recognition. Our approach can learn abundant visual knowledge and drive large-scale Transformer-based models. Without using extra model weights or supervision, MaskFeat pre-trained on unlabeled videos achieves unprecedented results of 86.7% with MViT-L on Kinetics-400, 88.3% on Kinetics-600, 80.4% on Kinetics-700, 38.8 mAP on AVA, and 75.0% on SSv2. MaskFeat further generalizes to image input, which can be interpreted as a video with a single frame and obtains competitive results on ImageNet.
<div align="center">
<img src="https://user-images.githubusercontent.com/48178838/190090285-428f07c0-0887-4ce8-b94f-f719cfd25622.png" width="60%"/>
</div>
## Models and Benchmarks
Here, we report the results of the model, which is pre-trained on ImageNet-1k
for 300 epochs. The details are below:
| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download |
| :------: | :-------------: | :---------------: | :------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| ViT-B/16 | 300 | 83.5 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/maskfeat_vit-base-p16_ft-8xb512-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth) \| [log](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220829_225552.log.json) |
## Citation
```bibtex
@article{wei2021masked,
  title={Masked Feature Prediction for Self-Supervised Visual Pre-Training},
  author={Wei, Chen and Fan, Haoqi and Xie, Saining and Wu, Chao-Yuan and
          Yuille, Alan and Feichtenhofer, Christoph},
  journal={arXiv preprint arXiv:2112.09133},
  year={2021}
}
```

View File

@ -25,7 +25,8 @@
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) \| [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) \| [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
| [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) \| [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) |
| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/RAEDME.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |
| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/README.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) \| [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |
| [MaskFeat](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/README.md) | [maskfeat_vit-base-p16_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220913-591d4c4b.pth) \| [log](https://download.openmmlab.com/mmselfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k_20220829_225552.log.json) |
备注:
@ -63,11 +64,12 @@
### ImageNet 微调
| 算法 | 配置文件 | 备注 | Top-1 (%) |
| ------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 |
| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 |
| 算法 | 配置文件 | 备注 | Top-1 (%) |
| -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 |
| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 |
| MaskFeat | [maskfeat_vit-base-p16_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/maskfeat/maskfeat_vit-base-p16_8xb256-coslr-300e_in1k.py)                            |      | 83.5      |
### COCO17 目标检测和实例分割

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .transforms import (BEiTMaskGenerator, GaussianBlur, Lighting,
RandomAppliedTrans, RandomAug, SimMIMMaskGenerator,
Solarization, ToTensor)
MaskFeatMaskGenerator, RandomAppliedTrans, RandomAug,
SimMIMMaskGenerator, Solarization, ToTensor)
__all__ = [
'GaussianBlur', 'Lighting', 'RandomAppliedTrans', 'Solarization',
'RandomAug', 'SimMIMMaskGenerator', 'ToTensor', 'BEiTMaskGenerator'
'RandomAug', 'SimMIMMaskGenerator', 'ToTensor', 'BEiTMaskGenerator',
'MaskFeatMaskGenerator'
]

View File

@ -482,3 +482,113 @@ class Solarization(object):
repr_str += f'threshold = {self.threshold}, '
repr_str += f'prob = {self.prob}'
return repr_str
@PIPELINES.register_module()
class MaskFeatMaskGenerator(object):
"""Generate random block mask for each image.
This module is borrowed from
https://github.com/facebookresearch/SlowFast/blob/main/slowfast/datasets/transform.py
Args:
mask_window_size (int): Size of input image. Defaults to 14.
        mask_ratio (float): The ratio of patches to be masked. Defaults to 0.4.
min_num_patches (int): Minimum number of patches that require masking.
Defaults to 15.
max_num_patches (int, optional): Maximum number of patches that
require masking. Defaults to None.
        min_aspect (float): Minimum aspect ratio of masked blocks.
            Defaults to 0.3.
        max_aspect (float, optional): Maximum aspect ratio of masked blocks.
            Defaults to None.
"""
def __init__(
self,
mask_window_size: int = 14,
mask_ratio: float = 0.4,
min_num_patches: int = 15,
max_num_patches: Optional[int] = None,
min_aspect: float = 0.3,
max_aspect: Optional[float] = None,
) -> None:
self.height, self.width = mask_window_size, mask_window_size
self.num_patches = self.height * self.width
self.num_masking_patches = int(mask_window_size**2 * mask_ratio)
self.min_num_patches = min_num_patches
self.max_num_patches = (
self.num_masking_patches
if max_num_patches is None else max_num_patches)
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(height={self.height}, '
repr_str += f'width={self.width}, '
repr_str += f'num_patches={self.num_patches}, '
repr_str += f'num_masking_patches={self.num_masking_patches}, '
repr_str += f'min_num_patches={self.min_num_patches}, '
repr_str += f'max_num_patches={self.max_num_patches}, '
repr_str += f'log_aspect_ratio={self.log_aspect_ratio})'
return repr_str
def get_shape(self) -> Tuple[int, int]:
return self.height, self.width
    def _random_masking(self, mask: np.ndarray, max_mask_patches: int) -> int:
        """Try up to 10 times to add one random block to the mask.
        Args:
            mask (np.ndarray): Initial mask of shape (mask_window_size,
                mask_window_size).
            max_mask_patches (int): Maximum number of patches that may be
                newly masked in this call.
        Returns:
            int: Number of newly masked patches.
"""
delta = 0
for _ in range(10):
target_area = random.uniform(self.min_num_patches,
max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < self.width and h < self.height:
top = random.randint(0, self.height - h)
left = random.randint(0, self.width - w)
num_masked = mask[top:top + h, left:left + w].sum()
# Overlap
if 0 < h * w - num_masked <= max_mask_patches:
for i in range(top, top + h):
for j in range(left, left + w):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate random block mask for each image.
Args:
img (torch.Tensor): Input image of shape (C, H, W).
Returns:
Tuple[torch.Tensor, torch.Tensor]: Input image and mask.
"""
        mask = np.zeros(shape=self.get_shape(), dtype=int)
mask_count = 0
while mask_count < self.num_masking_patches:
max_mask_patches = self.num_masking_patches - mask_count
max_mask_patches = min(max_mask_patches, self.max_num_patches)
delta = self._random_masking(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return img, torch.Tensor(mask).bool()
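A minimal usage sketch of the generator defined above; the import path is assumed from the pipelines package `__init__` shown earlier, and the printed values are illustrative:

```python
import torch

from mmselfsup.datasets.pipelines import MaskFeatMaskGenerator  # assumed path

# Applied as the last pipeline step: the image passes through unchanged and a
# 14x14 boolean block mask covering roughly mask_ratio of the patches is added.
generator = MaskFeatMaskGenerator(mask_window_size=14, mask_ratio=0.4)
img = torch.rand(3, 224, 224)
img, mask = generator(img)
print(mask.shape, mask.dtype)        # torch.Size([14, 14]) torch.bool
print(mask.float().mean().item())    # ~0.4
```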

View File

@ -7,6 +7,7 @@ from .classification import Classification
from .deepcluster import DeepCluster
from .densecl import DenseCL
from .mae import MAE
from .maskfeat import MaskFeat
from .mmcls_classifier_wrapper import MMClsImageClassifierWrapper
from .moco import MoCo
from .mocov3 import MoCoV3
@ -23,5 +24,5 @@ __all__ = [
'BaseModel', 'BarlowTwins', 'BYOL', 'Classification', 'DeepCluster',
'DenseCL', 'MoCo', 'NPID', 'ODC', 'RelativeLoc', 'RotationPred', 'SimCLR',
'SimSiam', 'SwAV', 'MAE', 'MoCoV3', 'SimMIM',
'MMClsImageClassifierWrapper', 'CAE'
'MMClsImageClassifierWrapper', 'CAE', 'MaskFeat'
]

View File

@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from ..builder import ALGORITHMS, build_backbone, build_head
from ..utils.hog_layer import HOGLayerC
from .base import BaseModel
@ALGORITHMS.register_module()
class MaskFeat(BaseModel):
"""MaskFeat.
Implementation of `Masked Feature Prediction for
Self-Supervised Visual Pre-Training <https://arxiv.org/abs/2112.09133>`_.
Args:
backbone (dict): Config dict for encoder.
head (dict): Config dict for loss functions.
        hog_para (dict): Config dict for the HOG layer.
            dict['nbins', int]: Number of orientation bins. Defaults to 9.
            dict['pool', int]: Size of each pooling cell. Defaults to 8.
            dict['gaussian_window', int]: Size of the Gaussian kernel.
                Defaults to 16.
init_cfg (dict): Config dict for weight initialization.
Defaults to None.
"""
def __init__(self,
backbone: dict,
head: dict,
hog_para: dict,
init_cfg: Optional[dict] = None) -> None:
super().__init__(init_cfg)
assert backbone is not None
self.backbone = build_backbone(backbone)
assert head is not None
self.head = build_head(head)
assert hog_para is not None
self.hog_layer = HOGLayerC(**hog_para)
def extract_feat(self, input: List[torch.Tensor]) -> torch.Tensor:
"""Function to extract features from backbone.
Args:
            input (List[torch.Tensor]): Input images and masks.
        Returns:
            torch.Tensor: Backbone outputs.
"""
img = input[0]
mask = input[1]
return self.backbone(img, mask)
def forward_train(self, input: List[torch.Tensor], **kwargs) -> dict:
"""Forward computation during training.
Args:
            input (List[torch.Tensor]): Input images and masks.
kwargs: Any keyword arguments to be used to forward.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
img = input[0]
mask = input[1]
hog = self.hog_layer(img)
latent = self.backbone(img, mask)
losses = self.head(latent, hog, mask)
return losses

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cae_vit import CAEViT
from .mae_vit import MAEViT
from .maskfeat_vit import MaskFeatViT
from .mim_cls_vit import MIMVisionTransformer
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
@ -9,5 +10,5 @@ from .vision_transformer import VisionTransformer
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'MAEViT', 'MIMVisionTransformer',
'VisionTransformer', 'SimMIMSwinTransformer', 'CAEViT'
'VisionTransformer', 'SimMIMSwinTransformer', 'CAEViT', 'MaskFeatViT'
]

View File

@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import torch
from mmcls.models import VisionTransformer
from mmcv.cnn.utils.weight_init import trunc_normal_
from torch import nn
from ..builder import BACKBONES
@BACKBONES.register_module()
class MaskFeatViT(VisionTransformer):
"""Vision Transformer for MaskFeat pre-training.
A PyTorch implement of: `Masked Feature Prediction for Self-Supervised
Visual Pre-Training <https://arxiv.org/abs/2112.09133>`_.
Args:
        arch (str | dict): Vision Transformer architecture. Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
arch: Union[str, dict] = 'b',
img_size: Union[Tuple[int, int], int] = 224,
patch_size: int = 16,
out_indices: int = -1,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_cfg: dict = dict(type='LN', eps=1e-6),
final_norm: bool = True,
output_cls_token: bool = True,
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
init_cfg: Optional[dict] = None) -> None:
super().__init__(
arch=arch,
img_size=img_size,
patch_size=patch_size,
out_indices=out_indices,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
final_norm=final_norm,
output_cls_token=output_cls_token,
interpolate_mode=interpolate_mode,
patch_cfg=patch_cfg,
layer_cfgs=layer_cfgs,
init_cfg=init_cfg)
self.mask_token = nn.parameter.Parameter(
torch.zeros(1, 1, self.embed_dims))
self.num_patches = self.patch_resolution[0] * self.patch_resolution[1]
def init_weights(self) -> None:
super().init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.mask_token, std=.02)
trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m: torch.nn.Module) -> None:
if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
nn.init.trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Generate features for masked images.
Args:
x (torch.Tensor): Input images.
mask (torch.Tensor): Input masks.
Returns:
torch.Tensor: Features with cls_tokens.
"""
B = x.shape[0]
x = self.patch_embed(x)[0]
# masking: length -> length * mask_ratio
B, L, _ = x.shape
mask_tokens = self.mask_token.expand(B, L, -1)
mask = mask.flatten(1).unsqueeze(-1)
x = x * (1 - mask.int()) + mask_tokens * mask
# append cls token
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.drop_after_pos(x)
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
return x
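A minimal usage sketch of the backbone; the output shape matches the unit test added later in this PR, and the half-grid mask is purely illustrative:

```python
import torch

from mmselfsup.models.backbones import MaskFeatViT

backbone = MaskFeatViT(arch='b', patch_size=16)
backbone.init_weights()

imgs = torch.randn(2, 3, 224, 224)
mask = torch.zeros(2, 14, 14, dtype=torch.bool)
mask[:, :7, :] = True                 # mask the top half of the patch grid

feats = backbone(imgs, mask)          # masked tokens replaced by mask_token
print(feats.shape)                    # torch.Size([2, 197, 768]) for ViT-B/16
```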

View File

@ -5,6 +5,7 @@ from .contrastive_head import ContrastiveHead
from .latent_pred_head import (LatentClsHead, LatentCrossCorrelationHead,
LatentPredictHead)
from .mae_head import MAEFinetuneHead, MAELinprobeHead, MAEPretrainHead
from .maskfeat_head import MaskFeatFinetuneHead, MaskFeatPretrainHead
from .mocov3_head import MoCoV3Head
from .multi_cls_head import MultiClsHead
from .simmim_head import SimMIMHead
@ -14,5 +15,6 @@ __all__ = [
'ContrastiveHead', 'ClsHead', 'LatentPredictHead', 'LatentClsHead',
'LatentCrossCorrelationHead', 'MultiClsHead', 'SwAVHead',
'MAEFinetuneHead', 'MAEPretrainHead', 'MoCoV3Head', 'SimMIMHead',
'CAEHead', 'MAELinprobeHead'
'CAEHead', 'MAELinprobeHead', 'MaskFeatFinetuneHead',
'MaskFeatPretrainHead'
]

View File

@ -0,0 +1,103 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models import LabelSmoothLoss
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner import BaseModule
from torch import nn
from ..builder import HEADS
@HEADS.register_module()
class MaskFeatPretrainHead(BaseModule):
"""Pre-training head for MaskFeat.
Args:
embed_dim (int): The dim of the feature before the classifier head.
Defaults to 768.
hog_dim (int): The dim of the hog feature. Defaults to 108.
"""
def __init__(self, embed_dim: int = 768, hog_dim: int = 108) -> None:
super().__init__()
self.head = nn.Linear(embed_dim, hog_dim)
def init_weights(self) -> None:
nn.init.constant_(self.head.bias, 0)
trunc_normal_(self.head.weight, std=0.02)
def loss(self, pred: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> dict:
"""Compute the loss.
Args:
pred (torch.Tensor): Input prediction of shape (N, L, C).
target (torch.Tensor): Input target of shape (N, L, C).
mask (torch.Tensor): Input mask of shape (N, L, 1).
Returns:
dict[str, torch.Tensor]: A dictionary of loss components.
"""
losses = dict()
pred = pred[mask]
target = target[mask]
loss = ((pred - target)**2).mean(-1).mean()
losses['loss'] = loss
return losses
def forward(self, latent: torch.Tensor, hog: torch.Tensor,
mask: torch.Tensor) -> dict:
"""Pre-training head for MaskFeat.
Args:
latent (torch.Tensor): Input latent of shape (N, 1+L, C).
hog (torch.Tensor): Input hog feature of shape (N, L, C).
mask (torch.Tensor): Input mask of shape (N, H, W).
Returns:
Dict[str, torch.Tensor]: A dictionary of loss components.
"""
latent = self.head(latent)
mask = mask.flatten(1).bool()
losses = self.loss(latent[:, 1:], hog, mask)
return losses
@HEADS.register_module()
class MaskFeatFinetuneHead(BaseModule):
"""Fine-tuning head for MaskFeat.
Args:
embed_dim (int): The dim of the feature before the classifier head.
num_classes (int): The total classes. Defaults to 1000.
label_smooth_val (float): The degree of label smoothing.
Defaults to 0.1.
"""
def __init__(self,
embed_dim: int,
num_classes: int = 1000,
label_smooth_val: float = 0.1) -> None:
super().__init__()
self.head = nn.Linear(embed_dim, num_classes, bias=True)
self.act = nn.Softmax(dim=1)
self.criterion = LabelSmoothLoss(label_smooth_val, num_classes)
def init_weights(self) -> None:
nn.init.constant_(self.head.bias, 0)
trunc_normal_(self.head.weight, std=.02)
def forward(self, x: torch.Tensor) -> list:
""""Get the logits."""
outputs = self.head(x)
if not self.training:
outputs = self.act(outputs)
return [outputs]
def loss(self, outputs: torch.Tensor, labels: torch.Tensor) -> dict:
"""Compute the loss."""
losses = dict()
losses['loss'] = self.criterion(outputs[0], labels)
return losses

View File

@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class HOGLayerC(nn.Module):
"""Generate hog feature for each batch images. This module is used in
Maskfeat to generate hog feature. This code is borrowed from.
<https://github.com/facebookresearch/SlowFast/blob/main/slowfast/models/operators.py>
Args:
nbins (int): Number of bin. Defaults to 9.
pool (float): Number of cell. Defaults to 8.
gaussian_window (int): Size of gaussian kernel. Defaults to 16.
"""
def __init__(self,
nbins: int = 9,
pool: int = 8,
gaussian_window: int = 16) -> None:
super().__init__()
self.nbins = nbins
self.pool = pool
self.pi = math.pi
weight_x = torch.FloatTensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
weight_x = weight_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1)
weight_y = weight_x.transpose(2, 3)
self.register_buffer('weight_x', weight_x)
self.register_buffer('weight_y', weight_y)
self.gaussian_window = gaussian_window
if gaussian_window:
gkern = self.get_gkern(gaussian_window, gaussian_window // 2)
self.register_buffer('gkern', gkern)
def get_gkern(self, kernlen: int, std: int) -> torch.Tensor:
"""Returns a 2D Gaussian kernel array."""
def _gaussian_fn(kernlen: int, std: int) -> torch.Tensor:
n = torch.arange(0, kernlen).float()
n -= n.mean()
n /= std
w = torch.exp(-0.5 * n**2)
return w
gkern1d = _gaussian_fn(kernlen, std)
gkern2d = gkern1d[:, None] * gkern1d[None, :]
return gkern2d / gkern2d.sum()
def _reshape(self, hog_feat: torch.Tensor) -> torch.Tensor:
hog_feat = hog_feat.flatten(1, 2)
unfold_size = hog_feat.shape[-1] // 14
hog_feat = (
hog_feat.permute(0, 2, 3,
1).unfold(1, unfold_size, unfold_size).unfold(
2, unfold_size,
unfold_size).flatten(1, 2).flatten(2))
return hog_feat
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Generate hog feature for each batch images.
Args:
x (torch.Tensor): Input images of shape (N, 3, H, W).
Returns:
            torch.Tensor: HOG features of shape (N, L, C).
"""
# input is RGB image with shape [B 3 H W]
x = F.pad(x, pad=(1, 1, 1, 1), mode='reflect')
gx_rgb = F.conv2d(
x, self.weight_x, bias=None, stride=1, padding=0, groups=3)
gy_rgb = F.conv2d(
x, self.weight_y, bias=None, stride=1, padding=0, groups=3)
norm_rgb = torch.stack([gx_rgb, gy_rgb], dim=-1).norm(dim=-1)
phase = torch.atan2(gx_rgb, gy_rgb)
phase = phase / self.pi * self.nbins # [-9, 9]
b, c, h, w = norm_rgb.shape
out = torch.zeros((b, c, self.nbins, h, w),
dtype=torch.float,
device=x.device)
phase = phase.view(b, c, 1, h, w)
norm_rgb = norm_rgb.view(b, c, 1, h, w)
if self.gaussian_window:
if h != self.gaussian_window:
assert h % self.gaussian_window == 0, 'h {} gw {}'.format(
h, self.gaussian_window)
repeat_rate = h // self.gaussian_window
temp_gkern = self.gkern.repeat([repeat_rate, repeat_rate])
else:
temp_gkern = self.gkern
norm_rgb *= temp_gkern
out.scatter_add_(2, phase.floor().long() % self.nbins, norm_rgb)
out = out.unfold(3, self.pool, self.pool)
out = out.unfold(4, self.pool, self.pool)
out = out.sum(dim=[-1, -2])
out = F.normalize(out, p=2, dim=2)
return self._reshape(out)
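A minimal usage sketch of the HOG layer; the import path is assumed from the algorithm's relative import, and the shapes follow the defaults above (224x224 inputs, 16x16 patches):

```python
import torch

from mmselfsup.models.utils.hog_layer import HOGLayerC  # assumed path

hog_layer = HOGLayerC(nbins=9, pool=8, gaussian_window=16)
imgs = torch.randn(4, 3, 224, 224)
hog = hog_layer(imgs)        # one 108-d descriptor per 16x16 patch
print(hog.shape)             # torch.Size([4, 196, 108])
```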

View File

@ -187,3 +187,13 @@ def test_random_resize_crop_with_two_pic():
fake_output = module(fake_input)
assert list(fake_output[0].size) == [224, 224]
assert list(fake_output[1].size) == [112, 112]
def test_maskfeat_mask_gen():
transform = dict(
type='MaskFeatMaskGenerator', mask_window_size=14, mask_ratio=0.6)
img = torch.rand((3, 224, 224))
module = build_from_cfg(transform, PIPELINES)
res = module(img)
assert list(res[1].shape) == [14, 14]

View File

@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.models.algorithms import MaskFeat
backbone = dict(
type='MaskFeatViT',
arch='b',
patch_size=16,
drop_path_rate=0,
)
head = dict(type='MaskFeatPretrainHead', hog_dim=108)
hog_para = dict(nbins=9, pool=8, gaussian_window=16)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat():
with pytest.raises(AssertionError):
alg = MaskFeat(backbone=backbone, head=None, hog_para=hog_para)
with pytest.raises(AssertionError):
alg = MaskFeat(backbone=None, head=head, hog_para=hog_para)
alg = MaskFeat(backbone=backbone, head=head, hog_para=hog_para)
fake_img = torch.randn((2, 3, 224, 224))
fake_mask = torch.randn((2, 14, 14)).bool()
fake_input = (fake_img, fake_mask)
fake_loss = alg.forward_train(fake_input)
fake_feature = alg.extract_feat(fake_input)
assert isinstance(fake_loss['loss'].item(), float)
assert list(fake_feature.shape) == [2, 197, 768]

View File

@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import platform
import pytest
import torch
from mmselfsup.models.backbones import MaskFeatViT
backbone = dict(arch='b', patch_size=16)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_maskfeat_pretrain_vit():
maskfeat_pretrain_backbone = MaskFeatViT(**backbone)
maskfeat_pretrain_backbone.init_weights()
fake_inputs = torch.randn((2, 3, 224, 224))
fake_mask = torch.randn((2, 14, 14))
fake_outputs = maskfeat_pretrain_backbone(fake_inputs, fake_mask)
assert list(fake_outputs.shape) == [2, 197, 768]

View File

@ -5,7 +5,9 @@ import torch.nn.functional as F
from mmselfsup.models.heads import (ClsHead, ContrastiveHead, LatentClsHead,
LatentCrossCorrelationHead,
LatentPredictHead, MAEFinetuneHead,
MAEPretrainHead, MultiClsHead, SwAVHead)
MAEPretrainHead, MaskFeatFinetuneHead,
MaskFeatPretrainHead, MultiClsHead,
SwAVHead)
def test_cls_head():
@ -120,3 +122,28 @@ def test_mae_finetune_head():
loss = head.loss(fake_features, fake_labels)
assert loss['loss'].item() > 0
def test_maskfeat_pretrain_head():
head = MaskFeatPretrainHead(hog_dim=108)
fake_mask = torch.ones((2, 14, 14)).bool()
fake_pred = torch.rand((2, 197, 768))
fake_hog = torch.rand((2, 196, 108))
loss = head.forward(fake_pred, fake_hog, fake_mask)
assert loss['loss'].item() > 0
def test_maskfeat_finetune_head():
head = MaskFeatFinetuneHead(num_classes=1000, embed_dim=768)
fake_input = torch.rand((2, 768))
fake_labels = F.normalize(torch.rand((2, 1000)), dim=-1)
fake_features = head.forward(fake_input)
assert list(fake_features[0].shape) == [2, 1000]
loss = head.loss(fake_features, fake_labels)
assert loss['loss'].item() > 0